Skip to content

Commit

Permalink
avx implementation of transpose for [u]int[8|16]
Browse files Browse the repository at this point in the history
  • Loading branch information
serge-sans-paille committed Dec 27, 2024
1 parent 1427dbf commit bbe95c1
Showing 1 changed file with 77 additions and 0 deletions.
77 changes: 77 additions & 0 deletions include/xsimd/arch/xsimd_avx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ namespace xsimd
template <class A, class T, size_t I>
XSIMD_INLINE batch<T, A> insert(batch<T, A> const& self, T val, index<I>, requires_arch<generic>) noexcept;

template <class A>
XSIMD_INLINE void transpose(batch<uint16_t, A>* matrix_begin, batch<uint16_t, A>* matrix_end, requires_arch<generic>) noexcept;
template <class A>
XSIMD_INLINE void transpose(batch<uint8_t, A>* matrix_begin, batch<uint8_t, A>* matrix_end, requires_arch<generic>) noexcept;

namespace detail
{
XSIMD_INLINE void split_avx(__m256i val, __m128i& low, __m128i& high) noexcept
Expand Down Expand Up @@ -1676,6 +1681,78 @@ namespace xsimd
return transpose(reinterpret_cast<batch<double, A>*>(matrix_begin), reinterpret_cast<batch<double, A>*>(matrix_end), A {});
}

template <class A>
XSIMD_INLINE void transpose(batch<uint16_t, A>* matrix_begin, batch<uint16_t, A>* matrix_end, requires_arch<avx>) noexcept
{
assert((matrix_end - matrix_begin == batch<uint16_t, A>::size) && "correctly sized matrix");
(void)matrix_end;
batch<uint16_t, sse4_2> tmp_lo0[8];
for (int i = 0; i < 8; ++i)
tmp_lo0[i] = _mm256_castsi256_si128(matrix_begin[i]);
transpose(tmp_lo0 + 0, tmp_lo0 + 8, sse4_2 {});

batch<uint16_t, sse4_2> tmp_hi0[8];
for (int i = 0; i < 8; ++i)
tmp_hi0[i] = _mm256_castsi256_si128(matrix_begin[8 + i]);
transpose(tmp_hi0 + 0, tmp_hi0 + 8, sse4_2 {});

batch<uint16_t, sse4_2> tmp_lo1[8];
for (int i = 0; i < 8; ++i)
tmp_lo1[i] = _mm256_extractf128_si256(matrix_begin[i], 1);
transpose(tmp_lo1 + 0, tmp_lo1 + 8, sse4_2 {});

batch<uint16_t, sse4_2> tmp_hi1[8];
for (int i = 0; i < 8; ++i)
tmp_hi1[i] = _mm256_extractf128_si256(matrix_begin[8 + i], 1);
transpose(tmp_hi1 + 0, tmp_hi1 + 8, sse4_2 {});

for (int i = 0; i < 8; ++i)
matrix_begin[i] = detail::merge_sse(tmp_lo0[i], tmp_hi0[i]);
for (int i = 0; i < 8; ++i)
matrix_begin[i + 8] = detail::merge_sse(tmp_lo1[i], tmp_hi1[i]);
}
template <class A>
XSIMD_INLINE void transpose(batch<int16_t, A>* matrix_begin, batch<int16_t, A>* matrix_end, requires_arch<avx>) noexcept
{
return transpose(reinterpret_cast<batch<uint16_t, A>*>(matrix_begin), reinterpret_cast<batch<uint16_t, A>*>(matrix_end), A {});
}

template <class A>
XSIMD_INLINE void transpose(batch<uint8_t, A>* matrix_begin, batch<uint8_t, A>* matrix_end, requires_arch<avx>) noexcept
{
assert((matrix_end - matrix_begin == batch<uint8_t, A>::size) && "correctly sized matrix");
(void)matrix_end;
batch<uint8_t, sse4_2> tmp_lo0[16];
for (int i = 0; i < 16; ++i)
tmp_lo0[i] = _mm256_castsi256_si128(matrix_begin[i]);
transpose(tmp_lo0 + 0, tmp_lo0 + 16, sse4_2 {});

batch<uint8_t, sse4_2> tmp_hi0[16];
for (int i = 0; i < 16; ++i)
tmp_hi0[i] = _mm256_castsi256_si128(matrix_begin[16 + i]);
transpose(tmp_hi0 + 0, tmp_hi0 + 16, sse4_2 {});

batch<uint8_t, sse4_2> tmp_lo1[16];
for (int i = 0; i < 16; ++i)
tmp_lo1[i] = _mm256_extractf128_si256(matrix_begin[i], 1);
transpose(tmp_lo1 + 0, tmp_lo1 + 16, sse4_2 {});

batch<uint8_t, sse4_2> tmp_hi1[16];
for (int i = 0; i < 16; ++i)
tmp_hi1[i] = _mm256_extractf128_si256(matrix_begin[16 + i], 1);
transpose(tmp_hi1 + 0, tmp_hi1 + 16, sse4_2 {});

for (int i = 0; i < 16; ++i)
matrix_begin[i] = detail::merge_sse(tmp_lo0[i], tmp_hi0[i]);
for (int i = 0; i < 16; ++i)
matrix_begin[i + 16] = detail::merge_sse(tmp_lo1[i], tmp_hi1[i]);
}
template <class A>
XSIMD_INLINE void transpose(batch<int8_t, A>* matrix_begin, batch<int8_t, A>* matrix_end, requires_arch<avx>) noexcept
{
return transpose(reinterpret_cast<batch<uint8_t, A>*>(matrix_begin), reinterpret_cast<batch<uint8_t, A>*>(matrix_end), A {});
}

// trunc
template <class A>
XSIMD_INLINE batch<float, A> trunc(batch<float, A> const& self, requires_arch<avx>) noexcept
Expand Down

0 comments on commit bbe95c1

Please sign in to comment.