Skip to content

Commit

Permalink
Use SoA implementation in Copy specialization directly
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Oct 2, 2023
1 parent 8edac59 commit 9f46000
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 25 deletions.
27 changes: 4 additions & 23 deletions include/llama/Copy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,34 +177,15 @@ namespace llama
throw std::runtime_error{"Destination SoA mapping's total array elements must be evenly divisible by the "
"source AoSoA Lane count."};

// the same as SoA::blobNrAndOffset but takes a flat array index
auto mapSoA = [&](std::size_t flatArrayIndex, auto rc, bool mb) LLAMA_LAMBDA_INLINE
{
const auto blob = mb * flatRecordCoord<RecordDim, decltype(rc)>;
const auto offset = !mb * offsetOf<RecordDim, decltype(rc)> * flatSize
+ sizeof(GetType<RecordDim, decltype(rc)>) * flatArrayIndex;
return NrAndOffset{blob, offset};
};

auto mapSrc = [&](std::size_t flatArrayIndex, auto rc) LLAMA_LAMBDA_INLINE
{
if constexpr(srcIsAoSoA)
return &srcView.blobs()[0][0] + srcView.mapping().blobNrAndOffset(flatArrayIndex, rc).offset;
else
{
const auto [blob, off] = mapSoA(flatArrayIndex, rc, isSrcMB);
return &srcView.blobs()[blob][off];
}
const auto [blob, off] = srcView.mapping().blobNrAndOffset(flatArrayIndex, rc);
return &srcView.blobs()[blob][off];
};
auto mapDst = [&](std::size_t flatArrayIndex, auto rc) LLAMA_LAMBDA_INLINE
{
if constexpr(dstIsAoSoA)
return &dstView.blobs()[0][0] + dstView.mapping().blobNrAndOffset(flatArrayIndex, rc).offset;
else
{
const auto [blob, off] = mapSoA(flatArrayIndex, rc, isDstMB);
return &dstView.blobs()[blob][off];
}
const auto [blob, off] = dstView.mapping().blobNrAndOffset(flatArrayIndex, rc);
return &dstView.blobs()[blob][off];
};

static constexpr auto l = []
Expand Down
13 changes: 11 additions & 2 deletions include/llama/mapping/SoA.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,19 @@ namespace llama::mapping
/// Computes the blob number and byte offset of the element at array index \p ai
/// for the field identified by record coordinate \p rc.
/// Linearizes \p ai over the mapping's extents via LinearizeArrayIndexFunctor and
/// delegates to the flat-array-index overload, so both entry points share one
/// offset computation.
template<std::size_t... RecordCoords>
LLAMA_FN_HOST_ACC_INLINE constexpr auto blobNrAndOffset(
typename Base::ArrayIndex ai,
RecordCoord<RecordCoords...> rc = {}) const -> NrAndOffset<size_type>
{
return blobNrAndOffset(LinearizeArrayIndexFunctor{}(ai, Base::extents()), rc);
}

// TODO: this overload is public only because aosoaCommonBlockCopy needs direct
// flat-array-index access; make it private once that dependency is removed.
template<std::size_t... RecordCoords>
LLAMA_FN_HOST_ACC_INLINE constexpr auto blobNrAndOffset(
size_type flatArrayIndex,
RecordCoord<RecordCoords...> = {}) const -> NrAndOffset<size_type>
{
const size_type elementOffset = LinearizeArrayIndexFunctor{}(ai, Base::extents())
* static_cast<size_type>(sizeof(GetType<TRecordDim, RecordCoord<RecordCoords...>>));
const size_type elementOffset
= flatArrayIndex * static_cast<size_type>(sizeof(GetType<TRecordDim, RecordCoord<RecordCoords...>>));
if constexpr(blobs == Blobs::OnePerField)
{
constexpr auto blob = flatRecordCoord<TRecordDim, RecordCoord<RecordCoords...>>;
Expand Down

0 comments on commit 9f46000

Please sign in to comment.