From afdb2ded6c6e03e1a4f1673625c1b2159b1ec5c5 Mon Sep 17 00:00:00 2001 From: Tomasz Szumski Date: Wed, 24 Jul 2024 10:06:36 +0200 Subject: [PATCH 1/4] Add AVX2 optimization --- src/lib/openjp2/dwt.c | 62 +++++++++++++++++++++++++++++++++++++++++ src/lib/openjp2/t1.c | 65 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+) diff --git a/src/lib/openjp2/dwt.c b/src/lib/openjp2/dwt.c index 6b18c5dd6..2d0e77fa1 100644 --- a/src/lib/openjp2/dwt.c +++ b/src/lib/openjp2/dwt.c @@ -363,6 +363,67 @@ static void opj_idwt53_h_cas0(OPJ_INT32* tmp, if (!(len & 1)) { /* if len is even */ tmp[len - 1] = in_odd[(len - 1) / 2] + tmp[len - 2]; } +#else +#ifdef __AVX2__ + OPJ_INT32* out_ptr = tmp; + int32_t prev_even = in_even[0] - ((in_odd[0] + 1) >> 1); + + const __m256i reg_permutevar_mask_move_right = _mm256_setr_epi32(0x00, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06); + const __m256i two = _mm256_set1_epi32(2); + + int32_t simd_batch = (len - 2) / 16; + int32_t next_even; + __m256i even_m1, odd, unpack1_avx2, unpack2_avx2; + + for (i = 0; i < simd_batch; i++) { + const __m256i lf_avx2 = _mm256_loadu_si256((__m256i*)(in_even + 1)); + const __m256i hf1_avx2 = _mm256_loadu_si256((__m256i*)(in_odd)); + const __m256i hf2_avx2 = _mm256_loadu_si256((__m256i*)(in_odd + 1)); + + __m256i even = _mm256_add_epi32(hf1_avx2, hf2_avx2); + even = _mm256_add_epi32(even, two); + even = _mm256_srai_epi32(even, 2); + even = _mm256_sub_epi32(lf_avx2, even); + + next_even = _mm256_extract_epi32(even, 7); + even_m1 = _mm256_permutevar8x32_epi32(even, reg_permutevar_mask_move_right); + even_m1 = _mm256_insert_epi32(even_m1, prev_even, 0); + + //out[0] + out[2] + odd = _mm256_add_epi32(even_m1, even); + odd = _mm256_srai_epi32(odd, 1); + odd = _mm256_add_epi32(odd, hf1_avx2); + + unpack1_avx2 = _mm256_unpacklo_epi32(even_m1, odd); + unpack2_avx2 = _mm256_unpackhi_epi32(even_m1, odd); + + _mm_storeu_si128((__m128i*)(out_ptr + 0), _mm256_castsi256_si128(unpack1_avx2)); + _mm_storeu_si128((__m128i*)(out_ptr + 4), _mm256_castsi256_si128(unpack2_avx2)); + _mm_storeu_si128((__m128i*)(out_ptr + 8), _mm256_extracti128_si256(unpack1_avx2, 0x1)); + _mm_storeu_si128((__m128i*)(out_ptr + 12), _mm256_extracti128_si256(unpack2_avx2, 0x1)); + + prev_even = next_even; + + out_ptr += 16; + in_even += 8; + in_odd += 8; + } + out_ptr[0] = prev_even; + for (j = simd_batch * 16 + 1; j < (len - 2); j += 2) { + out_ptr[2] = in_even[1] - ((in_odd[0] + in_odd[1] + 2) >> 2); + out_ptr[1] = in_odd[0] + ((out_ptr[0] + out_ptr[2]) >> 1); + in_even++; + in_odd++; + out_ptr += 2; + } + + if (len & 1) { + out_ptr[2] = in_even[1] - ((in_odd[0] + 1) >> 1); + out_ptr[1] = in_odd[0] + ((out_ptr[0] + out_ptr[2]) >> 1); + } + else { //!(len & 1) + out_ptr[1] = in_odd[0] + out_ptr[0]; + } #else OPJ_INT32 d1c, d1n, s1n, s0c, s0n; @@ -397,6 +458,7 @@ static void opj_idwt53_h_cas0(OPJ_INT32* tmp, } else { tmp[len - 1] = d1n + s0n; } +#endif /*__AVX2__*/ #endif memcpy(tiledp, tmp, (OPJ_UINT32)len * sizeof(OPJ_INT32)); } diff --git a/src/lib/openjp2/t1.c b/src/lib/openjp2/t1.c index b5adbf2fb..a08b973b2 100644 --- a/src/lib/openjp2/t1.c +++ b/src/lib/openjp2/t1.c @@ -47,6 +47,9 @@ #ifdef __SSE2__ #include #endif +#ifdef __AVX2__ +#include +#endif #if defined(__GNUC__) #pragma GCC poison malloc calloc realloc free @@ -1796,6 +1799,25 @@ static void opj_t1_clbl_decode_processor(void* user_data, opj_tls_t* tls) OPJ_INT32* OPJ_RESTRICT tiledp = &tilec->data[(OPJ_SIZE_T)y * tile_w + (OPJ_SIZE_T)x]; for (j = 0; j < cblk_h; ++j) { +#ifdef __AVX2__ + //positive -> round down aka. (83)/2 = 41.5 -> 41 + //negative -> round up aka. (-83)/2 = -41.5 -> -41 + + OPJ_INT32* ptr_in = datap + (j * cblk_w); + OPJ_INT32* ptr_out = tiledp + (j * (OPJ_SIZE_T)tile_w); + for (i = 0; i < cblk_w / 8; ++i) { + __m256i in_avx = _mm256_loadu_si256((__m256i*)(ptr_in)); + const __m256i add_avx = _mm256_srli_epi32(in_avx, 31); + in_avx = _mm256_add_epi32(in_avx, add_avx); + _mm256_storeu_si256((__m256i*)(ptr_out), _mm256_srai_epi32(in_avx, 1)); + ptr_in += 8; + ptr_out += 8; + } + + for (i = 0; i < cblk_w % 8; ++i) { + ptr_out[i] = ptr_in[i] / 2; + } +#else i = 0; for (; i < (cblk_w & ~(OPJ_UINT32)3U); i += 4U) { OPJ_INT32 tmp0 = datap[(j * cblk_w) + i + 0U]; @@ -1811,6 +1833,7 @@ static void opj_t1_clbl_decode_processor(void* user_data, opj_tls_t* tls) OPJ_INT32 tmp = datap[(j * cblk_w) + i]; ((OPJ_INT32*)tiledp)[(j * (OPJ_SIZE_T)tile_w) + i] = tmp / 2; } +#endif } } else { /* if (tccp->qmfbid == 0) */ const float stepsize = 0.5f * band->stepsize; @@ -2233,6 +2256,47 @@ static void opj_t1_cblk_encode_processor(void* user_data, opj_tls_t* tls) OPJ_UINT32* OPJ_RESTRICT t1data = (OPJ_UINT32*) t1->data; /* Change from "natural" order to "zigzag" order of T1 passes */ for (j = 0; j < (cblk_h & ~3U); j += 4) { +#ifdef __AVX2__ + OPJ_UINT32* ptr = tiledp_u; + for (i = 0; i < cblk_w / 8; ++i) { + // INPUT OUTPUT + // 00 01 02 03 04 05 06 07 00 10 20 30 01 11 21 31 + // 10 11 12 13 14 15 16 17 02 12 22 32 03 13 23 33 + // 20 21 22 23 24 25 26 27 04 14 24 34 05 15 25 35 + // 30 31 32 33 34 35 36 37 06 16 26 36 07 17 27 37 + __m256i in1 = _mm256_slli_epi32(_mm256_loadu_si256((__m256i*)(ptr + (j + 0) * tile_w)), T1_NMSEDEC_FRACBITS); + __m256i in2 = _mm256_slli_epi32(_mm256_loadu_si256((__m256i*)(ptr + (j + 1) * tile_w)), T1_NMSEDEC_FRACBITS); + __m256i in3 = _mm256_slli_epi32(_mm256_loadu_si256((__m256i*)(ptr + (j + 2) * tile_w)), T1_NMSEDEC_FRACBITS); + __m256i in4 = _mm256_slli_epi32(_mm256_loadu_si256((__m256i*)(ptr + (j + 3) * tile_w)), T1_NMSEDEC_FRACBITS); + + __m256i tmp1 = _mm256_unpacklo_epi32(in1, in2); + __m256i tmp2 = _mm256_unpacklo_epi32(in3, in4); + __m256i tmp3 = _mm256_unpackhi_epi32(in1, in2); + __m256i tmp4 = _mm256_unpackhi_epi32(in3, in4); + + in1 = _mm256_unpacklo_epi64(tmp1, tmp2); + in2 = _mm256_unpacklo_epi64(tmp3, tmp4); + in3 = _mm256_unpackhi_epi64(tmp1, tmp2); + in4 = _mm256_unpackhi_epi64(tmp3, tmp4); + + _mm_storeu_si128((__m128i*)(t1data + 0), _mm256_castsi256_si128(in1)); + _mm_storeu_si128((__m128i*)(t1data + 4), _mm256_castsi256_si128(in3)); + _mm_storeu_si128((__m128i*)(t1data + 8), _mm256_castsi256_si128(in2)); + _mm_storeu_si128((__m128i*)(t1data + 12), _mm256_castsi256_si128(in4)); + _mm256_storeu_si256((__m256i*)(t1data + 16), _mm256_permute2x128_si256(in1, in3, 0x31)); + _mm256_storeu_si256((__m256i*)(t1data + 24), _mm256_permute2x128_si256(in2, in4, 0x31)); + t1data += 32; + ptr += 8; + } + for (i = 0; i < cblk_w % 8; ++i) { + t1data[0] = ptr[(j + 0) * tile_w] << T1_NMSEDEC_FRACBITS; + t1data[1] = ptr[(j + 1) * tile_w] << T1_NMSEDEC_FRACBITS; + t1data[2] = ptr[(j + 2) * tile_w] << T1_NMSEDEC_FRACBITS; + t1data[3] = ptr[(j + 3) * tile_w] << T1_NMSEDEC_FRACBITS; + t1data += 4; + ptr += 1; + } +#else for (i = 0; i < cblk_w; ++i) { t1data[0] = tiledp_u[(j + 0) * tile_w + i] << T1_NMSEDEC_FRACBITS; t1data[1] = tiledp_u[(j + 1) * tile_w + i] << T1_NMSEDEC_FRACBITS; @@ -2240,6 +2304,7 @@ static void opj_t1_cblk_encode_processor(void* user_data, opj_tls_t* tls) t1data[3] = tiledp_u[(j + 3) * tile_w + i] << T1_NMSEDEC_FRACBITS; t1data += 4; } +#endif } if (j < cblk_h) { for (i = 0; i < cblk_w; ++i) { From 846538ccabc8a7896efb588b986cefdce3e44b81 Mon Sep 17 00:00:00 2001 From: Tomasz Szumski Date: Thu, 25 Jul 2024 07:31:41 +0200 Subject: [PATCH 2/4] AVX512 code --- src/lib/openjp2/dwt.c | 160 +++++++++++++++++++++++++++++++++++++++--- src/lib/openjp2/t1.c | 70 +++++++++++++++++- 2 files changed, 218 insertions(+), 12 deletions(-) diff --git a/src/lib/openjp2/dwt.c b/src/lib/openjp2/dwt.c index 2d0e77fa1..dcac03cc0 100644 --- a/src/lib/openjp2/dwt.c +++ b/src/lib/openjp2/dwt.c @@ -52,7 +52,7 @@ #ifdef __SSSE3__ #include #endif -#ifdef __AVX2__ +#if (defined(__AVX2__) || defined(__AVX512F__)) #include #endif @@ -66,7 +66,10 @@ #define OPJ_WS(i) v->mem[(i)*2] #define OPJ_WD(i) v->mem[(1+(i)*2)] -#ifdef __AVX2__ +#if defined(__AVX512F__) +/** Number of int32 values in a AVX512 register */ +#define VREG_INT_COUNT 16 +#elif defined(__AVX2__) /** Number of int32 values in a AVX2 register */ #define VREG_INT_COUNT 8 #else @@ -331,6 +334,50 @@ static void opj_dwt_decode_1(const opj_dwt_t *v) #endif /* STANDARD_SLOW_VERSION */ +#if defined(__AVX512F__) +static int32_t loop_short_sse(int32_t len, const int32_t** lf_ptr, + const int32_t** hf_ptr, int32_t** out_ptr, + int32_t* prev_even) { + int32_t next_even; + __m128i odd, even_m1, unpack1, unpack2; + const int32_t batch = (len - 2) / 8; + const __m128i two = _mm_set1_epi32(2); + + for (int32_t i = 0; i < batch; i++) { + const __m128i lf_ = _mm_loadu_si128((__m128i*)(*lf_ptr + 1)); + const __m128i hf1_ = _mm_loadu_si128((__m128i*)(*hf_ptr)); + const __m128i hf2_ = _mm_loadu_si128((__m128i*)(*hf_ptr + 1)); + + __m128i even = _mm_add_epi32(hf1_, hf2_); + even = _mm_add_epi32(even, two); + even = _mm_srai_epi32(even, 2); + even = _mm_sub_epi32(lf_, even); + + next_even = _mm_extract_epi32(even, 3); + even_m1 = _mm_bslli_si128(even, 4); + even_m1 = _mm_insert_epi32(even_m1, *prev_even, 0); + + //out[0] + out[2] + odd = _mm_add_epi32(even_m1, even); + odd = _mm_srai_epi32(odd, 1); + odd = _mm_add_epi32(odd, hf1_); + + unpack1 = _mm_unpacklo_epi32(even_m1, odd); + unpack2 = _mm_unpackhi_epi32(even_m1, odd); + + _mm_storeu_si128((__m128i*)(*out_ptr + 0), unpack1); + _mm_storeu_si128((__m128i*)(*out_ptr + 4), unpack2); + + *prev_even = next_even; + + *out_ptr += 8; + *lf_ptr += 4; + *hf_ptr += 4; + } + return batch; +} +#endif + #if !defined(STANDARD_SLOW_VERSION) static void opj_idwt53_h_cas0(OPJ_INT32* tmp, const OPJ_INT32 sn, @@ -364,7 +411,80 @@ static void opj_idwt53_h_cas0(OPJ_INT32* tmp, tmp[len - 1] = in_odd[(len - 1) / 2] + tmp[len - 2]; } #else -#ifdef __AVX2__ +#if defined(__AVX512F__) + OPJ_INT32* out_ptr = tmp; + int32_t prev_even = in_even[0] - ((in_odd[0] + 1) >> 1); + + const __m512i permutevar_mask = _mm512_setr_epi32( + 0x10, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e); + const __m512i store1_perm = _mm512_setr_epi64(0x00, 0x01, 0x08, 0x09, 0x02, 0x03, 0x0a, 0x0b); + const __m512i store2_perm = _mm512_setr_epi64(0x04, 0x05, 0x0c, 0x0d, 0x06, 0x07, 0x0e, 0x0f); + + const __m512i two = _mm512_set1_epi32(2); + + int32_t simd_batch_512 = (len - 2) / 32; + int32_t leftover; + + for (i = 0; i < simd_batch_512; i++) { + const __m512i lf_avx2 = _mm512_loadu_si512((__m512i*)(in_even + 1)); + const __m512i hf1_avx2 = _mm512_loadu_si512((__m512i*)(in_odd)); + const __m512i hf2_avx2 = _mm512_loadu_si512((__m512i*)(in_odd + 1)); + int32_t next_even; + __m512i duplicate, even_m1, odd, unpack1, unpack2, store1, store2; + + __m512i even = _mm512_add_epi32(hf1_avx2, hf2_avx2); + even = _mm512_add_epi32(even, two); + even = _mm512_srai_epi32(even, 2); + even = _mm512_sub_epi32(lf_avx2, even); + + next_even = _mm_extract_epi32(_mm512_extracti32x4_epi32(even, 3), 3); + + duplicate = _mm512_set1_epi32(prev_even); + even_m1 = _mm512_permutex2var_epi32(even, permutevar_mask, duplicate); + + //out[0] + out[2] + odd = _mm512_add_epi32(even_m1, even); + odd = _mm512_srai_epi32(odd, 1); + odd = _mm512_add_epi32(odd, hf1_avx2); + + unpack1 = _mm512_unpacklo_epi32(even_m1, odd); + unpack2 = _mm512_unpackhi_epi32(even_m1, odd); + + store1 = _mm512_permutex2var_epi64(unpack1, store1_perm, unpack2); + store2 = _mm512_permutex2var_epi64(unpack1, store2_perm, unpack2); + + _mm512_storeu_si512(out_ptr, store1); + _mm512_storeu_si512(out_ptr + 16, store2); + + prev_even = next_even; + + out_ptr += 32; + in_even += 16; + in_odd += 16; + } + + leftover = len - simd_batch_512 * 32; + if (leftover > 8) { + leftover -= 8 * loop_short_sse(leftover, &in_even, &in_odd, &out_ptr, &prev_even); + } + out_ptr[0] = prev_even; + + for (j = 1; j < (leftover - 2); j += 2) { + out_ptr[2] = in_even[1] - ((in_odd[0] + (in_odd[1]) + 2) >> 2); + out_ptr[1] = in_odd[0] + ((out_ptr[0] + out_ptr[2]) >> 1); + in_even++; + in_odd++; + out_ptr += 2; +} + + if (len & 1) { + out_ptr[2] = in_even[1] - ((in_odd[0] + 1) >> 1); + out_ptr[1] = in_odd[0] + ((out_ptr[0] + out_ptr[2]) >> 1); + } + else { //!(len & 1) + out_ptr[1] = in_odd[0] + out_ptr[0]; + } +#elif defined(__AVX2__) OPJ_INT32* out_ptr = tmp; int32_t prev_even = in_even[0] - ((in_odd[0] + 1) >> 1); @@ -458,8 +578,8 @@ static void opj_idwt53_h_cas0(OPJ_INT32* tmp, } else { tmp[len - 1] = d1n + s0n; } -#endif /*__AVX2__*/ -#endif +#endif /*(__AVX512F__ || __AVX2__)*/ +#endif /*TWO_PASS_VERSION*/ memcpy(tiledp, tmp, (OPJ_UINT32)len * sizeof(OPJ_INT32)); } @@ -573,10 +693,20 @@ static void opj_idwt53_h(const opj_dwt_t *dwt, #endif } -#if (defined(__SSE2__) || defined(__AVX2__)) && !defined(STANDARD_SLOW_VERSION) +#if (defined(__SSE2__) || defined(__AVX2__) || defined(__AVX512F__)) && !defined(STANDARD_SLOW_VERSION) /* Conveniency macros to improve the readability of the formulas */ -#if __AVX2__ +#if defined(__AVX512F__) +#define VREG __m512i +#define LOAD_CST(x) _mm512_set1_epi32(x) +#define LOAD(x) _mm512_loadu_si512((const VREG*)(x)) +#define LOADU(x) _mm512_loadu_si512((const VREG*)(x)) +#define STORE(x,y) _mm512_storeu_si512((VREG*)(x),(y)) +#define STOREU(x,y) _mm512_storeu_si512((VREG*)(x),(y)) +#define ADD(x,y) _mm512_add_epi32((x),(y)) +#define SUB(x,y) _mm512_sub_epi32((x),(y)) +#define SAR(x,y) _mm512_srai_epi32((x),(y)) +#elif defined(__AVX2__) #define VREG __m256i #define LOAD_CST(x) _mm256_set1_epi32(x) #define LOAD(x) _mm256_load_si256((const VREG*)(x)) @@ -638,7 +768,10 @@ static void opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2( const VREG two = LOAD_CST(2); assert(len > 1); -#if __AVX2__ +#if defined(__AVX512F__) + assert(PARALLEL_COLS_53 == 32); + assert(VREG_INT_COUNT == 16); +#elif defined(__AVX2__) assert(PARALLEL_COLS_53 == 16); assert(VREG_INT_COUNT == 8); #else @@ -646,10 +779,13 @@ static void opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2( assert(VREG_INT_COUNT == 4); #endif +//For AVX512 code aligned load/store is set to it's unaligned equivalents +#if !defined(__AVX512F__) /* Note: loads of input even/odd values must be done in a unaligned */ /* fashion. But stores in tmp can be done with aligned store, since */ /* the temporary buffer is properly aligned */ assert((OPJ_SIZE_T)tmp % (sizeof(OPJ_INT32) * VREG_INT_COUNT) == 0); +#endif s1n_0 = LOADU(in_even + 0); s1n_1 = LOADU(in_even + VREG_INT_COUNT); @@ -740,7 +876,10 @@ static void opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2( const OPJ_INT32* in_odd = &tiledp_col[0]; assert(len > 2); -#if __AVX2__ +#if defined(__AVX512F__) + assert(PARALLEL_COLS_53 == 32); + assert(VREG_INT_COUNT == 16); +#elif defined(__AVX2__) assert(PARALLEL_COLS_53 == 16); assert(VREG_INT_COUNT == 8); #else @@ -748,10 +887,13 @@ static void opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2( assert(VREG_INT_COUNT == 4); #endif +//For AVX512 code aligned load/store is set to it's unaligned equivalents +#if !defined(__AVX512F__) /* Note: loads of input even/odd values must be done in a unaligned */ /* fashion. But stores in tmp can be done with aligned store, since */ /* the temporary buffer is properly aligned */ assert((OPJ_SIZE_T)tmp % (sizeof(OPJ_INT32) * VREG_INT_COUNT) == 0); +#endif s1_0 = LOADU(in_even + stride); /* in_odd[0] - ((in_even[0] + s1 + 2) >> 2); */ diff --git a/src/lib/openjp2/t1.c b/src/lib/openjp2/t1.c index a08b973b2..856c2401b 100644 --- a/src/lib/openjp2/t1.c +++ b/src/lib/openjp2/t1.c @@ -47,7 +47,7 @@ #ifdef __SSE2__ #include #endif -#ifdef __AVX2__ +#if (defined(__AVX2__) || defined(__AVX512F__)) #include #endif @@ -1799,10 +1799,24 @@ static void opj_t1_clbl_decode_processor(void* user_data, opj_tls_t* tls) OPJ_INT32* OPJ_RESTRICT tiledp = &tilec->data[(OPJ_SIZE_T)y * tile_w + (OPJ_SIZE_T)x]; for (j = 0; j < cblk_h; ++j) { -#ifdef __AVX2__ //positive -> round down aka. (83)/2 = 41.5 -> 41 //negative -> round up aka. (-83)/2 = -41.5 -> -41 +#if defined(__AVX512F__) + OPJ_INT32* ptr_in = datap + (j * cblk_w); + OPJ_INT32* ptr_out = tiledp + (j * (OPJ_SIZE_T)tile_w); + for (i = 0; i < cblk_w / 16; ++i) { + __m512i in_avx = _mm512_loadu_si512((__m512i*)(ptr_in)); + const __m512i add_avx = _mm512_srli_epi32(in_avx, 31); + in_avx = _mm512_add_epi32(in_avx, add_avx); + _mm512_storeu_si512((__m512i*)(ptr_out), _mm512_srai_epi32(in_avx, 1)); + ptr_in += 16; + ptr_out += 16; + } + for (i = 0; i < cblk_w % 16; ++i) { + ptr_out[i] = ptr_in[i] / 2; + } +#elif defined(__AVX2__) OPJ_INT32* ptr_in = datap + (j * cblk_w); OPJ_INT32* ptr_out = tiledp + (j * (OPJ_SIZE_T)tile_w); for (i = 0; i < cblk_w / 8; ++i) { @@ -2256,7 +2270,57 @@ static void opj_t1_cblk_encode_processor(void* user_data, opj_tls_t* tls) OPJ_UINT32* OPJ_RESTRICT t1data = (OPJ_UINT32*) t1->data; /* Change from "natural" order to "zigzag" order of T1 passes */ for (j = 0; j < (cblk_h & ~3U); j += 4) { -#ifdef __AVX2__ +#if defined(__AVX512F__) + const __m512i perm1 = _mm512_setr_epi64(2, 3, 10, 11, 4, 5, 12, 13); + const __m512i perm2 = _mm512_setr_epi64(6, 7, 14, 15, 0, 0, 0, 0); + OPJ_UINT32* ptr = tiledp_u; + for (i = 0; i < cblk_w / 16; ++i) { + // INPUT OUTPUT + // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F 00 10 20 30 01 11 21 31 02 12 22 32 03 13 23 33 + // 10 11 12 13 14 15 16 17 18 19 1A 1B 1C 1D 1E 1F 04 14 24 34 05 15 25 35 06 16 26 36 07 17 27 37 + // 20 21 22 23 24 25 26 27 28 29 2A 2B 2C 2D 2E 2F 08 18 28 38 09 19 29 39 0A 1A 2A 3A 0B 1B 2B 3B + // 30 31 32 33 34 35 36 37 38 39 3A 3B 3C 3D 3E 3F 0C 1C 2C 3C 0D 1D 2D 3D 0E 1E 2E 3E 0F 1F 2F 3F + __m512i in1 = _mm512_slli_epi32(_mm512_loadu_si512((__m512i*)(ptr + (j + 0) * tile_w)), T1_NMSEDEC_FRACBITS); + __m512i in2 = _mm512_slli_epi32(_mm512_loadu_si512((__m512i*)(ptr + (j + 1) * tile_w)), T1_NMSEDEC_FRACBITS); + __m512i in3 = _mm512_slli_epi32(_mm512_loadu_si512((__m512i*)(ptr + (j + 2) * tile_w)), T1_NMSEDEC_FRACBITS); + __m512i in4 = _mm512_slli_epi32(_mm512_loadu_si512((__m512i*)(ptr + (j + 3) * tile_w)), T1_NMSEDEC_FRACBITS); + + __m512i tmp1 = _mm512_unpacklo_epi32(in1, in2); + __m512i tmp2 = _mm512_unpacklo_epi32(in3, in4); + __m512i tmp3 = _mm512_unpackhi_epi32(in1, in2); + __m512i tmp4 = _mm512_unpackhi_epi32(in3, in4); + + in1 = _mm512_unpacklo_epi64(tmp1, tmp2); + in2 = _mm512_unpacklo_epi64(tmp3, tmp4); + in3 = _mm512_unpackhi_epi64(tmp1, tmp2); + in4 = _mm512_unpackhi_epi64(tmp3, tmp4); + + _mm_storeu_si128((__m128i*)(t1data + 0), _mm512_castsi512_si128(in1)); + _mm_storeu_si128((__m128i*)(t1data + 4), _mm512_castsi512_si128(in3)); + _mm_storeu_si128((__m128i*)(t1data + 8), _mm512_castsi512_si128(in2)); + _mm_storeu_si128((__m128i*)(t1data + 12), _mm512_castsi512_si128(in4)); + + tmp1 = _mm512_permutex2var_epi64(in1, perm1, in3); + tmp2 = _mm512_permutex2var_epi64(in2, perm1, in4); + + _mm256_storeu_si256((__m256i*)(t1data + 16), _mm512_castsi512_si256(tmp1)); + _mm256_storeu_si256((__m256i*)(t1data + 24), _mm512_castsi512_si256(tmp2)); + _mm256_storeu_si256((__m256i*)(t1data + 32), _mm512_extracti64x4_epi64(tmp1, 0x1)); + _mm256_storeu_si256((__m256i*)(t1data + 40), _mm512_extracti64x4_epi64(tmp2, 0x1)); + _mm256_storeu_si256((__m256i*)(t1data + 48), _mm512_castsi512_si256(_mm512_permutex2var_epi64(in1, perm2, in3))); + _mm256_storeu_si256((__m256i*)(t1data + 56), _mm512_castsi512_si256(_mm512_permutex2var_epi64(in2, perm2, in4))); + t1data += 64; + ptr += 16; + } + for (i = 0; i < cblk_w % 16; ++i) { + t1data[0] = ptr[(j + 0) * tile_w] << T1_NMSEDEC_FRACBITS; + t1data[1] = ptr[(j + 1) * tile_w] << T1_NMSEDEC_FRACBITS; + t1data[2] = ptr[(j + 2) * tile_w] << T1_NMSEDEC_FRACBITS; + t1data[3] = ptr[(j + 3) * tile_w] << T1_NMSEDEC_FRACBITS; + t1data += 4; + ptr += 1; + } +#elif defined(__AVX2__) OPJ_UINT32* ptr = tiledp_u; for (i = 0; i < cblk_w / 8; ++i) { // INPUT OUTPUT From 342c6872cab8f778e4699c13e2f03cd4d102c03d Mon Sep 17 00:00:00 2001 From: Tomasz Szumski Date: Thu, 5 Sep 2024 12:27:55 +0200 Subject: [PATCH 3/4] Fix CI, _mm256_extract_epi32 and _mm256_insert_epi32 does not exist in MSVC 2015 toolset --- src/lib/openjp2/dwt.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib/openjp2/dwt.c b/src/lib/openjp2/dwt.c index dcac03cc0..546a91a60 100644 --- a/src/lib/openjp2/dwt.c +++ b/src/lib/openjp2/dwt.c @@ -505,9 +505,9 @@ static void opj_idwt53_h_cas0(OPJ_INT32* tmp, even = _mm256_srai_epi32(even, 2); even = _mm256_sub_epi32(lf_avx2, even); - next_even = _mm256_extract_epi32(even, 7); + next_even = _mm_extract_epi32(_mm256_extracti128_si256(even, 1), 3); even_m1 = _mm256_permutevar8x32_epi32(even, reg_permutevar_mask_move_right); - even_m1 = _mm256_insert_epi32(even_m1, prev_even, 0); + even_m1 = _mm256_blend_epi32(even_m1, _mm256_set1_epi32(prev_even), (1 << 0)); //out[0] + out[2] odd = _mm256_add_epi32(even_m1, even); From ececcd9210ad8113920ad48629f0176cfbf18e27 Mon Sep 17 00:00:00 2001 From: Tomasz Szumski Date: Thu, 5 Sep 2024 18:18:47 +0200 Subject: [PATCH 4/4] Fix Code Style --- src/lib/openjp2/dwt.c | 34 +++++++----- src/lib/openjp2/t1.c | 124 +++++++++++++++++++++++------------------- 2 files changed, 89 insertions(+), 69 deletions(-) diff --git a/src/lib/openjp2/dwt.c b/src/lib/openjp2/dwt.c index 546a91a60..11aae472d 100644 --- a/src/lib/openjp2/dwt.c +++ b/src/lib/openjp2/dwt.c @@ -336,8 +336,9 @@ static void opj_dwt_decode_1(const opj_dwt_t *v) #if defined(__AVX512F__) static int32_t loop_short_sse(int32_t len, const int32_t** lf_ptr, - const int32_t** hf_ptr, int32_t** out_ptr, - int32_t* prev_even) { + const int32_t** hf_ptr, int32_t** out_ptr, + int32_t* prev_even) +{ int32_t next_even; __m128i odd, even_m1, unpack1, unpack2; const int32_t batch = (len - 2) / 8; @@ -416,9 +417,12 @@ static void opj_idwt53_h_cas0(OPJ_INT32* tmp, int32_t prev_even = in_even[0] - ((in_odd[0] + 1) >> 1); const __m512i permutevar_mask = _mm512_setr_epi32( - 0x10, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e); - const __m512i store1_perm = _mm512_setr_epi64(0x00, 0x01, 0x08, 0x09, 0x02, 0x03, 0x0a, 0x0b); - const __m512i store2_perm = _mm512_setr_epi64(0x04, 0x05, 0x0c, 0x0d, 0x06, 0x07, 0x0e, 0x0f); + 0x10, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, 0x0e); + const __m512i store1_perm = _mm512_setr_epi64(0x00, 0x01, 0x08, 0x09, 0x02, + 0x03, 0x0a, 0x0b); + const __m512i store2_perm = _mm512_setr_epi64(0x04, 0x05, 0x0c, 0x0d, 0x06, + 0x07, 0x0e, 0x0f); const __m512i two = _mm512_set1_epi32(2); @@ -465,7 +469,8 @@ static void opj_idwt53_h_cas0(OPJ_INT32* tmp, leftover = len - simd_batch_512 * 32; if (leftover > 8) { - leftover -= 8 * loop_short_sse(leftover, &in_even, &in_odd, &out_ptr, &prev_even); + leftover -= 8 * loop_short_sse(leftover, &in_even, &in_odd, &out_ptr, + &prev_even); } out_ptr[0] = prev_even; @@ -475,20 +480,20 @@ static void opj_idwt53_h_cas0(OPJ_INT32* tmp, in_even++; in_odd++; out_ptr += 2; -} + } if (len & 1) { out_ptr[2] = in_even[1] - ((in_odd[0] + 1) >> 1); out_ptr[1] = in_odd[0] + ((out_ptr[0] + out_ptr[2]) >> 1); - } - else { //!(len & 1) + } else { //!(len & 1) out_ptr[1] = in_odd[0] + out_ptr[0]; } #elif defined(__AVX2__) OPJ_INT32* out_ptr = tmp; int32_t prev_even = in_even[0] - ((in_odd[0] + 1) >> 1); - const __m256i reg_permutevar_mask_move_right = _mm256_setr_epi32(0x00, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06); + const __m256i reg_permutevar_mask_move_right = _mm256_setr_epi32(0x00, 0x00, + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06); const __m256i two = _mm256_set1_epi32(2); int32_t simd_batch = (len - 2) / 16; @@ -519,8 +524,10 @@ static void opj_idwt53_h_cas0(OPJ_INT32* tmp, _mm_storeu_si128((__m128i*)(out_ptr + 0), _mm256_castsi256_si128(unpack1_avx2)); _mm_storeu_si128((__m128i*)(out_ptr + 4), _mm256_castsi256_si128(unpack2_avx2)); - _mm_storeu_si128((__m128i*)(out_ptr + 8), _mm256_extracti128_si256(unpack1_avx2, 0x1)); - _mm_storeu_si128((__m128i*)(out_ptr + 12), _mm256_extracti128_si256(unpack2_avx2, 0x1)); + _mm_storeu_si128((__m128i*)(out_ptr + 8), _mm256_extracti128_si256(unpack1_avx2, + 0x1)); + _mm_storeu_si128((__m128i*)(out_ptr + 12), + _mm256_extracti128_si256(unpack2_avx2, 0x1)); prev_even = next_even; @@ -540,8 +547,7 @@ static void opj_idwt53_h_cas0(OPJ_INT32* tmp, if (len & 1) { out_ptr[2] = in_even[1] - ((in_odd[0] + 1) >> 1); out_ptr[1] = in_odd[0] + ((out_ptr[0] + out_ptr[2]) >> 1); - } - else { //!(len & 1) + } else { //!(len & 1) out_ptr[1] = in_odd[0] + out_ptr[0]; } #else diff --git a/src/lib/openjp2/t1.c b/src/lib/openjp2/t1.c index 856c2401b..98dce47f5 100644 --- a/src/lib/openjp2/t1.c +++ b/src/lib/openjp2/t1.c @@ -2271,55 +2271,63 @@ static void opj_t1_cblk_encode_processor(void* user_data, opj_tls_t* tls) /* Change from "natural" order to "zigzag" order of T1 passes */ for (j = 0; j < (cblk_h & ~3U); j += 4) { #if defined(__AVX512F__) - const __m512i perm1 = _mm512_setr_epi64(2, 3, 10, 11, 4, 5, 12, 13); - const __m512i perm2 = _mm512_setr_epi64(6, 7, 14, 15, 0, 0, 0, 0); - OPJ_UINT32* ptr = tiledp_u; - for (i = 0; i < cblk_w / 16; ++i) { - // INPUT OUTPUT - // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F 00 10 20 30 01 11 21 31 02 12 22 32 03 13 23 33 - // 10 11 12 13 14 15 16 17 18 19 1A 1B 1C 1D 1E 1F 04 14 24 34 05 15 25 35 06 16 26 36 07 17 27 37 - // 20 21 22 23 24 25 26 27 28 29 2A 2B 2C 2D 2E 2F 08 18 28 38 09 19 29 39 0A 1A 2A 3A 0B 1B 2B 3B - // 30 31 32 33 34 35 36 37 38 39 3A 3B 3C 3D 3E 3F 0C 1C 2C 3C 0D 1D 2D 3D 0E 1E 2E 3E 0F 1F 2F 3F - __m512i in1 = _mm512_slli_epi32(_mm512_loadu_si512((__m512i*)(ptr + (j + 0) * tile_w)), T1_NMSEDEC_FRACBITS); - __m512i in2 = _mm512_slli_epi32(_mm512_loadu_si512((__m512i*)(ptr + (j + 1) * tile_w)), T1_NMSEDEC_FRACBITS); - __m512i in3 = _mm512_slli_epi32(_mm512_loadu_si512((__m512i*)(ptr + (j + 2) * tile_w)), T1_NMSEDEC_FRACBITS); - __m512i in4 = _mm512_slli_epi32(_mm512_loadu_si512((__m512i*)(ptr + (j + 3) * tile_w)), T1_NMSEDEC_FRACBITS); - - __m512i tmp1 = _mm512_unpacklo_epi32(in1, in2); - __m512i tmp2 = _mm512_unpacklo_epi32(in3, in4); - __m512i tmp3 = _mm512_unpackhi_epi32(in1, in2); - __m512i tmp4 = _mm512_unpackhi_epi32(in3, in4); - - in1 = _mm512_unpacklo_epi64(tmp1, tmp2); - in2 = _mm512_unpacklo_epi64(tmp3, tmp4); - in3 = _mm512_unpackhi_epi64(tmp1, tmp2); - in4 = _mm512_unpackhi_epi64(tmp3, tmp4); - - _mm_storeu_si128((__m128i*)(t1data + 0), _mm512_castsi512_si128(in1)); - _mm_storeu_si128((__m128i*)(t1data + 4), _mm512_castsi512_si128(in3)); - _mm_storeu_si128((__m128i*)(t1data + 8), _mm512_castsi512_si128(in2)); - _mm_storeu_si128((__m128i*)(t1data + 12), _mm512_castsi512_si128(in4)); - - tmp1 = _mm512_permutex2var_epi64(in1, perm1, in3); - tmp2 = _mm512_permutex2var_epi64(in2, perm1, in4); - - _mm256_storeu_si256((__m256i*)(t1data + 16), _mm512_castsi512_si256(tmp1)); - _mm256_storeu_si256((__m256i*)(t1data + 24), _mm512_castsi512_si256(tmp2)); - _mm256_storeu_si256((__m256i*)(t1data + 32), _mm512_extracti64x4_epi64(tmp1, 0x1)); - _mm256_storeu_si256((__m256i*)(t1data + 40), _mm512_extracti64x4_epi64(tmp2, 0x1)); - _mm256_storeu_si256((__m256i*)(t1data + 48), _mm512_castsi512_si256(_mm512_permutex2var_epi64(in1, perm2, in3))); - _mm256_storeu_si256((__m256i*)(t1data + 56), _mm512_castsi512_si256(_mm512_permutex2var_epi64(in2, perm2, in4))); - t1data += 64; - ptr += 16; - } - for (i = 0; i < cblk_w % 16; ++i) { - t1data[0] = ptr[(j + 0) * tile_w] << T1_NMSEDEC_FRACBITS; - t1data[1] = ptr[(j + 1) * tile_w] << T1_NMSEDEC_FRACBITS; - t1data[2] = ptr[(j + 2) * tile_w] << T1_NMSEDEC_FRACBITS; - t1data[3] = ptr[(j + 3) * tile_w] << T1_NMSEDEC_FRACBITS; - t1data += 4; - ptr += 1; - } + const __m512i perm1 = _mm512_setr_epi64(2, 3, 10, 11, 4, 5, 12, 13); + const __m512i perm2 = _mm512_setr_epi64(6, 7, 14, 15, 0, 0, 0, 0); + OPJ_UINT32* ptr = tiledp_u; + for (i = 0; i < cblk_w / 16; ++i) { + // INPUT OUTPUT + // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F 00 10 20 30 01 11 21 31 02 12 22 32 03 13 23 33 + // 10 11 12 13 14 15 16 17 18 19 1A 1B 1C 1D 1E 1F 04 14 24 34 05 15 25 35 06 16 26 36 07 17 27 37 + // 20 21 22 23 24 25 26 27 28 29 2A 2B 2C 2D 2E 2F 08 18 28 38 09 19 29 39 0A 1A 2A 3A 0B 1B 2B 3B + // 30 31 32 33 34 35 36 37 38 39 3A 3B 3C 3D 3E 3F 0C 1C 2C 3C 0D 1D 2D 3D 0E 1E 2E 3E 0F 1F 2F 3F + __m512i in1 = _mm512_slli_epi32(_mm512_loadu_si512((__m512i*)(ptr + + (j + 0) * tile_w)), T1_NMSEDEC_FRACBITS); + __m512i in2 = _mm512_slli_epi32(_mm512_loadu_si512((__m512i*)(ptr + + (j + 1) * tile_w)), T1_NMSEDEC_FRACBITS); + __m512i in3 = _mm512_slli_epi32(_mm512_loadu_si512((__m512i*)(ptr + + (j + 2) * tile_w)), T1_NMSEDEC_FRACBITS); + __m512i in4 = _mm512_slli_epi32(_mm512_loadu_si512((__m512i*)(ptr + + (j + 3) * tile_w)), T1_NMSEDEC_FRACBITS); + + __m512i tmp1 = _mm512_unpacklo_epi32(in1, in2); + __m512i tmp2 = _mm512_unpacklo_epi32(in3, in4); + __m512i tmp3 = _mm512_unpackhi_epi32(in1, in2); + __m512i tmp4 = _mm512_unpackhi_epi32(in3, in4); + + in1 = _mm512_unpacklo_epi64(tmp1, tmp2); + in2 = _mm512_unpacklo_epi64(tmp3, tmp4); + in3 = _mm512_unpackhi_epi64(tmp1, tmp2); + in4 = _mm512_unpackhi_epi64(tmp3, tmp4); + + _mm_storeu_si128((__m128i*)(t1data + 0), _mm512_castsi512_si128(in1)); + _mm_storeu_si128((__m128i*)(t1data + 4), _mm512_castsi512_si128(in3)); + _mm_storeu_si128((__m128i*)(t1data + 8), _mm512_castsi512_si128(in2)); + _mm_storeu_si128((__m128i*)(t1data + 12), _mm512_castsi512_si128(in4)); + + tmp1 = _mm512_permutex2var_epi64(in1, perm1, in3); + tmp2 = _mm512_permutex2var_epi64(in2, perm1, in4); + + _mm256_storeu_si256((__m256i*)(t1data + 16), _mm512_castsi512_si256(tmp1)); + _mm256_storeu_si256((__m256i*)(t1data + 24), _mm512_castsi512_si256(tmp2)); + _mm256_storeu_si256((__m256i*)(t1data + 32), _mm512_extracti64x4_epi64(tmp1, + 0x1)); + _mm256_storeu_si256((__m256i*)(t1data + 40), _mm512_extracti64x4_epi64(tmp2, + 0x1)); + _mm256_storeu_si256((__m256i*)(t1data + 48), + _mm512_castsi512_si256(_mm512_permutex2var_epi64(in1, perm2, in3))); + _mm256_storeu_si256((__m256i*)(t1data + 56), + _mm512_castsi512_si256(_mm512_permutex2var_epi64(in2, perm2, in4))); + t1data += 64; + ptr += 16; + } + for (i = 0; i < cblk_w % 16; ++i) { + t1data[0] = ptr[(j + 0) * tile_w] << T1_NMSEDEC_FRACBITS; + t1data[1] = ptr[(j + 1) * tile_w] << T1_NMSEDEC_FRACBITS; + t1data[2] = ptr[(j + 2) * tile_w] << T1_NMSEDEC_FRACBITS; + t1data[3] = ptr[(j + 3) * tile_w] << T1_NMSEDEC_FRACBITS; + t1data += 4; + ptr += 1; + } #elif defined(__AVX2__) OPJ_UINT32* ptr = tiledp_u; for (i = 0; i < cblk_w / 8; ++i) { @@ -2328,10 +2336,14 @@ static void opj_t1_cblk_encode_processor(void* user_data, opj_tls_t* tls) // 10 11 12 13 14 15 16 17 02 12 22 32 03 13 23 33 // 20 21 22 23 24 25 26 27 04 14 24 34 05 15 25 35 // 30 31 32 33 34 35 36 37 06 16 26 36 07 17 27 37 - __m256i in1 = _mm256_slli_epi32(_mm256_loadu_si256((__m256i*)(ptr + (j + 0) * tile_w)), T1_NMSEDEC_FRACBITS); - __m256i in2 = _mm256_slli_epi32(_mm256_loadu_si256((__m256i*)(ptr + (j + 1) * tile_w)), T1_NMSEDEC_FRACBITS); - __m256i in3 = _mm256_slli_epi32(_mm256_loadu_si256((__m256i*)(ptr + (j + 2) * tile_w)), T1_NMSEDEC_FRACBITS); - __m256i in4 = _mm256_slli_epi32(_mm256_loadu_si256((__m256i*)(ptr + (j + 3) * tile_w)), T1_NMSEDEC_FRACBITS); + __m256i in1 = _mm256_slli_epi32(_mm256_loadu_si256((__m256i*)(ptr + + (j + 0) * tile_w)), T1_NMSEDEC_FRACBITS); + __m256i in2 = _mm256_slli_epi32(_mm256_loadu_si256((__m256i*)(ptr + + (j + 1) * tile_w)), T1_NMSEDEC_FRACBITS); + __m256i in3 = _mm256_slli_epi32(_mm256_loadu_si256((__m256i*)(ptr + + (j + 2) * tile_w)), T1_NMSEDEC_FRACBITS); + __m256i in4 = _mm256_slli_epi32(_mm256_loadu_si256((__m256i*)(ptr + + (j + 3) * tile_w)), T1_NMSEDEC_FRACBITS); __m256i tmp1 = _mm256_unpacklo_epi32(in1, in2); __m256i tmp2 = _mm256_unpacklo_epi32(in3, in4); @@ -2347,8 +2359,10 @@ static void opj_t1_cblk_encode_processor(void* user_data, opj_tls_t* tls) _mm_storeu_si128((__m128i*)(t1data + 4), _mm256_castsi256_si128(in3)); _mm_storeu_si128((__m128i*)(t1data + 8), _mm256_castsi256_si128(in2)); _mm_storeu_si128((__m128i*)(t1data + 12), _mm256_castsi256_si128(in4)); - _mm256_storeu_si256((__m256i*)(t1data + 16), _mm256_permute2x128_si256(in1, in3, 0x31)); - _mm256_storeu_si256((__m256i*)(t1data + 24), _mm256_permute2x128_si256(in2, in4, 0x31)); + _mm256_storeu_si256((__m256i*)(t1data + 16), _mm256_permute2x128_si256(in1, in3, + 0x31)); + _mm256_storeu_si256((__m256i*)(t1data + 24), _mm256_permute2x128_si256(in2, in4, + 0x31)); t1data += 32; ptr += 8; }