diff --git a/include/llama/Copy.hpp b/include/llama/Copy.hpp index 57a93e7ea1..b02f188aa6 100644 --- a/include/llama/Copy.hpp +++ b/include/llama/Copy.hpp @@ -261,42 +261,24 @@ namespace llama const auto start = threadId * elementsPerThread; const auto stop = threadId == threadCount - 1 ? flatSize : (threadId + 1) * elementsPerThread; - auto copyLBlock = [&](const std::byte*& threadSrc, std::size_t dstIndex, auto rc) LLAMA_LAMBDA_INLINE - { - constexpr auto bytes = l * sizeof(GetType); - std::memcpy(mapDst(dstIndex, rc), threadSrc, bytes); - threadSrc += bytes; - }; - - // if the AoSoA is packed we can move the src pointer along - if constexpr(SrcMapping::fieldAlignment == mapping::FieldAlignment::Pack) - { - auto* threadSrc = mapSrc(start, RecordCoord<>{}); - for(std::size_t i = start; i < stop; i += lanesSrc) - forEachLeafCoord( - [&](auto rc) LLAMA_LAMBDA_INLINE - { - for(std::size_t j = 0; j < lanesSrc; j += l) - { - assert(threadSrc == mapSrc(i + j, rc)); - copyLBlock(threadSrc, i + j, rc); - } - }); - } - else - { - for(std::size_t i = start; i < stop; i += lanesSrc) - forEachLeafCoord( - [&](auto rc) LLAMA_LAMBDA_INLINE + static constexpr auto packed = SrcMapping::fieldAlignment == mapping::FieldAlignment::Pack; + decltype(mapSrc(start, RecordCoord<>{})) src; + if constexpr(packed) + src = mapSrc(start, RecordCoord<>{}); + for(std::size_t i = start; i < stop; i += lanesSrc) + forEachLeafCoord( + [&](auto rc) LLAMA_LAMBDA_INLINE + { + if constexpr(!packed) + src = mapSrc(i, rc); + for(std::size_t j = 0; j < lanesSrc; j += l) { - auto* threadSrc = mapSrc(i, rc); - for(std::size_t j = 0; j < lanesSrc; j += l) - { - assert(threadSrc == mapSrc(i + j, rc)); - copyLBlock(threadSrc, i + j, rc); - } - }); - } + assert(src == mapSrc(i + j, rc)); + static constexpr auto bytes = l * sizeof(GetType); + std::memcpy(mapDst(i + j, rc), src, bytes); + src += bytes; + } + }); } else { @@ -307,42 +289,24 @@ namespace llama const auto start = threadId * elementsPerThread; const auto stop = threadId == threadCount - 1 ? flatSize : (threadId + 1) * elementsPerThread; - auto copyLBlock = [&](std::byte*& threadDst, std::size_t srcIndex, auto rc) LLAMA_LAMBDA_INLINE - { - constexpr auto bytes = l * sizeof(GetType); - std::memcpy(threadDst, mapSrc(srcIndex, rc), bytes); - threadDst += bytes; - }; - - // if the AoSoA is packed we can move the dst pointer along - if constexpr(DstMapping::fieldAlignment == mapping::FieldAlignment::Pack) - { - auto* threadDst = mapDst(start, RecordCoord<>{}); - for(std::size_t i = start; i < stop; i += lanesDst) - forEachLeafCoord( - [&](auto rc) LLAMA_LAMBDA_INLINE - { - for(std::size_t j = 0; j < lanesDst; j += l) - { - assert(threadDst == mapDst(i + j, rc)); - copyLBlock(threadDst, i + j, rc); - } - }); - } - else - { - for(std::size_t i = start; i < stop; i += lanesDst) - forEachLeafCoord( - [&](auto rc) LLAMA_LAMBDA_INLINE + static constexpr auto packed = DstMapping::fieldAlignment == mapping::FieldAlignment::Pack; + decltype(mapDst(start, RecordCoord<>{})) dst; + if constexpr(packed) + dst = mapDst(start, RecordCoord<>{}); + for(std::size_t i = start; i < stop; i += lanesDst) + forEachLeafCoord( + [&](auto rc) LLAMA_LAMBDA_INLINE + { + if constexpr(!packed) + dst = mapDst(i, rc); + for(std::size_t j = 0; j < lanesDst; j += l) { - auto* threadDst = mapDst(i, rc); - for(std::size_t j = 0; j < lanesDst; j += l) - { - assert(threadDst == mapDst(i + j, rc)); - copyLBlock(threadDst, i + j, rc); - } - }); - } + assert(dst == mapDst(i + j, rc)); + constexpr auto bytes = l * sizeof(GetType); + std::memcpy(dst, mapSrc(i + j, rc), bytes); + dst += bytes; + } + }); } }