diff --git a/include/llama/RecordRef.hpp b/include/llama/RecordRef.hpp index 596127ded1..5f52140dce 100644 --- a/include/llama/RecordRef.hpp +++ b/include/llama/RecordRef.hpp @@ -331,6 +331,12 @@ namespace llama template typename Tuple, typename... Args> constexpr inline auto isDirectListInitializableFromTuple> = isDirectListInitializable; + + template + LLAMA_FN_HOST_ACC_INLINE void loadSimdRecord(const T& srcRef, Simd& dstSimd, RecordCoord rc); + + template + LLAMA_FN_HOST_ACC_INLINE void storeSimdRecord(const Simd& srcSimd, T&& dstRef, RecordCoord rc); } // namespace internal /// Record reference type returned by \ref View after resolving an array dimensions coordinate or partially @@ -745,10 +751,10 @@ namespace llama // FIXME(bgruber): the SIMD load/store functions need to navigate back from a record ref to the contained view // to find subsequent elements. This is not a great design for now and the SIMD load/store functions should // probably take iterators to records. - template - friend void loadSimd(const T& srcRef, Simd& dstSimd); - template - friend void storeSimd(const Simd& srcSimd, T&& dstRef); + template + friend void internal::loadSimdRecord(const T& srcRef, Simd& dstSimd, RecordCoord rc); + template + friend void internal::storeSimdRecord(const Simd& srcSimd, T&& dstRef, RecordCoord rc); }; // swap for heterogeneous RecordRef diff --git a/include/llama/Simd.hpp b/include/llama/Simd.hpp index f590d57e1f..b8f7faf4d5 100644 --- a/include/llama/Simd.hpp +++ b/include/llama/Simd.hpp @@ -3,6 +3,7 @@ #include "Core.hpp" #include "RecordRef.hpp" #include "macros.hpp" +#include "mapping/AoS.hpp" #include "mapping/SoA.hpp" #include @@ -166,6 +167,98 @@ namespace llama return simdLanes; }(); + namespace internal + { + template + LLAMA_FN_HOST_ACC_INLINE void loadSimdRecord(const T& srcRef, Simd& dstSimd, RecordCoord rc) + { + using RecordDim = typename T::AccessibleRecordDim; + using FieldType = GetType; + using ElementSimd = std::decay_t; + using Traits = SimdTraits; + + // TODO(bgruber): can we generalize the logic whether we can load a dstSimd from that mapping? + using Mapping = typename T::View::Mapping; + if constexpr(mapping::isSoA) + { + LLAMA_BEGIN_SUPPRESS_HOST_DEVICE_WARNING + dstSimd(rc) = Traits::loadUnaligned(&srcRef(rc)); // SIMD load + LLAMA_END_SUPPRESS_HOST_DEVICE_WARNING + } + // else if constexpr(mapping::isAoSoA) + //{ + // // it turns out we do not need the specialization, because clang already fuses the scalar + // loads + // // into a vector load :D + // assert(srcRef.arrayDimsCoord()[0] % SIMD_WIDTH == 0); + // // if(srcRef.arrayDimsCoord()[0] % SIMD_WIDTH != 0) + // // __builtin_unreachable(); // this also helps nothing + // //__builtin_assume(srcRef.arrayDimsCoord()[0] % SIMD_WIDTH == 0); // this also helps nothing + // dstSimd(rc) = Traits::load_from(&srcRef(rc)); // SIMD load + //} + else if constexpr(mapping::isAoS) + { + static_assert(mapping::isAoS); + constexpr static auto srcStride + = flatSizeOf; + const auto* srcBaseAddr = reinterpret_cast(&srcRef(rc)); + ElementSimd elemSimd; // g++-12 really needs the intermediate elemSimd and memcpy + for(auto i = 0; i < Traits::lanes; i++) + reinterpret_cast(&elemSimd)[i] + = *reinterpret_cast(srcBaseAddr + i * srcStride); + std::memcpy(&dstSimd(rc), &elemSimd, sizeof(elemSimd)); + } + else + { + auto b = ArrayIndexIterator{srcRef.view.mapping().extents(), srcRef.arrayIndex()}; + ElementSimd elemSimd; // g++-12 really needs the intermediate elemSimd and memcpy + for(auto i = 0; i < Traits::lanes; i++) + reinterpret_cast(&elemSimd)[i] + = srcRef.view(*b++)(cat(typename T::BoundRecordCoord{}, rc)); // scalar loads + std::memcpy(&dstSimd(rc), &elemSimd, sizeof(elemSimd)); + } + } + + template + LLAMA_FN_HOST_ACC_INLINE void storeSimdRecord(const Simd& srcSimd, TFwd&& dstRef, RecordCoord rc) + { + using T = std::remove_reference_t; + using RecordDim = typename T::AccessibleRecordDim; + using FieldType = GetType; + using ElementSimd = std::decay_t; + using Traits = SimdTraits; + + // TODO(bgruber): can we generalize the logic whether we can store a srcSimd to that mapping? + using Mapping = typename std::remove_reference_t::View::Mapping; + if constexpr(mapping::isSoA) + { + LLAMA_BEGIN_SUPPRESS_HOST_DEVICE_WARNING + Traits::storeUnaligned(srcSimd(rc), &dstRef(rc)); // SIMD store + LLAMA_END_SUPPRESS_HOST_DEVICE_WARNING + } + else if constexpr(mapping::isAoS) + { + constexpr static auto stride + = flatSizeOf; + auto* dstBaseAddr = reinterpret_cast(&dstRef(rc)); + const ElementSimd elemSimd = srcSimd(rc); + for(auto i = 0; i < Traits::lanes; i++) + *reinterpret_cast(dstBaseAddr + i * stride) + = reinterpret_cast(&elemSimd)[i]; + } + else + { + // TODO(bgruber): how does this generalize conceptually to 2D and higher dimensions? in which + // direction should we collect SIMD values? + const ElementSimd elemSimd = srcSimd(rc); + auto b = ArrayIndexIterator{dstRef.view.mapping().extents(), dstRef.arrayIndex()}; + for(auto i = 0; i < Traits::lanes; i++) + dstRef.view (*b++)(cat(typename T::BoundRecordCoord{}, rc)) + = reinterpret_cast(&elemSimd)[i]; // scalar store + } + } + } // namespace internal + /// Loads SIMD vectors of data starting from the given record reference to dstSimd. Only field tags occurring in /// RecordRef are loaded. If Simd contains multiple fields of SIMD types, a SIMD vector will be fetched for each of /// the fields. The number of elements fetched per SIMD vector depends on the SIMD width of the vector. Simd is @@ -176,40 +269,8 @@ namespace llama // structured dstSimd type and record reference if constexpr(isRecordRef && isRecordRef) { - using RecordDim = typename T::AccessibleRecordDim; - forEachLeafCoord( - [&](auto rc) LLAMA_LAMBDA_INLINE - { - using FieldType = GetType; - using ElementSimd = std::decay_t; - using Traits = SimdTraits; - - // TODO(bgruber): can we generalize the logic whether we can load a dstSimd from that mapping? - if constexpr(mapping::isSoA) - { - LLAMA_BEGIN_SUPPRESS_HOST_DEVICE_WARNING - dstSimd(rc) = Traits::loadUnaligned(&srcRef(rc)); // SIMD load - LLAMA_END_SUPPRESS_HOST_DEVICE_WARNING - } - // else if constexpr(mapping::isAoSoA) - //{ - // // it turns out we do not need the specialization, because clang already fuses the scalar - // loads - // // into a vector load :D - // assert(srcRef.arrayDimsCoord()[0] % SIMD_WIDTH == 0); - // // if(srcRef.arrayDimsCoord()[0] % SIMD_WIDTH != 0) - // // __builtin_unreachable(); // this also helps nothing - // //__builtin_assume(srcRef.arrayDimsCoord()[0] % SIMD_WIDTH == 0); // this also helps nothing - // dstSimd(rc) = Traits::load_from(&srcRef(rc)); // SIMD load - //} - else - { - auto b = ArrayIndexIterator{srcRef.view.mapping().extents(), srcRef.arrayIndex()}; - for(auto i = 0; i < Traits::lanes; i++) - reinterpret_cast(&dstSimd(rc))[i] - = srcRef.view(*b++)(cat(typename T::BoundRecordCoord{}, rc)); // scalar loads - } - }); + forEachLeafCoord([&](auto rc) LLAMA_LAMBDA_INLINE + { internal::loadSimdRecord(srcRef, dstSimd, rc); }); } // unstructured dstSimd and reference type else if constexpr(!isRecordRef && !isRecordRef) @@ -235,31 +296,8 @@ namespace llama // structured Simd type and record reference if constexpr(isRecordRef && isRecordRef) { - using RecordDim = typename T::AccessibleRecordDim; - forEachLeafCoord( - [&](auto rc) LLAMA_LAMBDA_INLINE - { - using FieldType = GetType; - using ElementSimd = std::decay_t; - using Traits = SimdTraits; - - // TODO(bgruber): can we generalize the logic whether we can store a srcSimd to that mapping? - if constexpr(mapping::isSoA) - { - LLAMA_BEGIN_SUPPRESS_HOST_DEVICE_WARNING - Traits::storeUnaligned(srcSimd(rc), &dstRef(rc)); // SIMD store - LLAMA_END_SUPPRESS_HOST_DEVICE_WARNING - } - else - { - // TODO(bgruber): how does this generalize conceptually to 2D and higher dimensions? in which - // direction should we collect SIMD values? - auto b = ArrayIndexIterator{dstRef.view.mapping().extents(), dstRef.arrayIndex()}; - for(auto i = 0; i < Traits::lanes; i++) - dstRef.view (*b++)(cat(typename T::BoundRecordCoord{}, rc)) - = reinterpret_cast(&srcSimd(rc))[i]; // scalar store - } - }); + forEachLeafCoord([&](auto rc) LLAMA_LAMBDA_INLINE + { internal::storeSimdRecord(srcSimd, dstRef, rc); }); } // unstructured srcSimd and reference type else if constexpr(!isRecordRef && !isRecordRef)