Skip to content

Commit

Permalink
Optimize Simd handling for g++-12
Browse files (browse the repository at this point in the history)
  • Loading branch information
bernhardmgruber committed Nov 3, 2022
1 parent c19e566 commit 165d8c3
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions include/llama/Simd.hpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -206,19 +206,23 @@ namespace llama
//}
else if constexpr(mapping::isAoS<Mapping>)
{
- constexpr static auto stride
+ constexpr static auto srcStride
= flatSizeOf<typename Mapping::Flattener::FlatRecordDim, Mapping::alignAndPad>;
- auto* base = reinterpret_cast<const std::byte*>(&srcRef(rc));
+ const auto* srcBaseAddr = reinterpret_cast<const std::byte*>(&srcRef(rc));
+ ElementSimd elemSimd; // g++-12 really needs the intermediate elemSimd and memcpy
for(auto i = 0; i < Traits::lanes; i++)
- reinterpret_cast<FieldType*>(&dstSimd(rc))[i]
- = *reinterpret_cast<const FieldType*>(base + i * stride);
+ reinterpret_cast<FieldType*>(&elemSimd)[i]
+ = *reinterpret_cast<const FieldType*>(srcBaseAddr + i * srcStride);
+ std::memcpy(&dstSimd(rc), &elemSimd, sizeof(elemSimd));
}
else
{
auto b = ArrayIndexIterator{srcRef.view.mapping().extents(), srcRef.arrayIndex()};
+ ElementSimd elemSimd; // g++-12 really needs the intermediate elemSimd and memcpy
for(auto i = 0; i < Traits::lanes; i++)
- reinterpret_cast<FieldType*>(&dstSimd(rc))[i]
+ reinterpret_cast<FieldType*>(&elemSimd)[i]
= srcRef.view(*b++)(cat(typename T::BoundRecordCoord{}, rc)); // scalar loads
+ std::memcpy(&dstSimd(rc), &elemSimd, sizeof(elemSimd));
}
});
}
Expand Down Expand Up @@ -266,19 +270,21 @@ namespace llama
{
constexpr static auto stride
= flatSizeOf<typename Mapping::Flattener::FlatRecordDim, Mapping::alignAndPad>;
- auto* base = reinterpret_cast<std::byte*>(&dstRef(rc));
+ auto* dstBaseAddr = reinterpret_cast<std::byte*>(&dstRef(rc));
+ const ElementSimd elemSimd = srcSimd(rc);
for(auto i = 0; i < Traits::lanes; i++)
- *reinterpret_cast<FieldType*>(base + i * stride)
- = reinterpret_cast<const FieldType*>(&srcSimd(rc))[i];
+ *reinterpret_cast<FieldType*>(dstBaseAddr + i * stride)
+ = reinterpret_cast<const FieldType*>(&elemSimd)[i];
}
else
{
// TODO(bgruber): how does this generalize conceptually to 2D and higher dimensions? in which
// direction should we collect SIMD values?
+ const ElementSimd elemSimd = srcSimd(rc);
auto b = ArrayIndexIterator{dstRef.view.mapping().extents(), dstRef.arrayIndex()};
for(auto i = 0; i < Traits::lanes; i++)
dstRef.view (*b++)(cat(typename T::BoundRecordCoord{}, rc))
- = reinterpret_cast<const FieldType*>(&srcSimd(rc))[i]; // scalar store
+ = reinterpret_cast<const FieldType*>(&elemSimd)[i]; // scalar store
}
});
}
Expand Down

0 comments on commit 165d8c3

Please sign in to comment.