diff --git a/ggml.c b/ggml.c
index 0b2622ab0e4ee..6cea937c8d735 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1310,6 +1310,29 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
     }
 }
 
+#ifdef __AVX2__
+// There is no better way of doing this?
+// I guess not, AVX is not very good at horizontal sums.
+// The commented solution for a horizontal sum was suggested by @pubby as being slightly
+// faster than the solution below. As I don't have an AVX2 system handy right now to test,
+// I'm keeping the original.
+// TODO: Please try it, and if it does make a difference, uncomment it and remove the implementation below.
+//static inline float horizontal_sum(__m256i a) {
+//    __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a)));
+//    __m256i sum = _mm256_add_epi32(a, b);
+//    __m256i hi = _mm256_unpackhi_epi64(sum, sum);
+//    sum = _mm256_add_epi32(sum, hi);
+//    return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4);
+//}
+static inline float horizontal_sum(__m256i a) {
+    __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
+    __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+#endif
+
 static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
@@ -1399,14 +1422,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
 #if defined(__AVX2__)
 
-        // Compute the sum of the quants
-        // There is not better way of doing this???
-        __m256i acc = _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3));
-        __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(acc), _mm256_extracti128_si256(acc, 1));
-        __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
-        __m128i sum64 = _mm_add_epi32(hi64, sum128);
-        __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
-        y[i].s = d * _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+        // Compute the sum of the quants and set y[i].s
+        y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
 
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
@@ -2411,7 +2428,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         sum8 += x0->d * y0->s + x1->d * y1->s;
 
         const uint8x16_t m4b = vdupq_n_u8(0xf);
-        //const int8x16_t  s8b = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2422,12 +2438,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
-        // sub 8
-        //const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        //const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        //const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        //const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
@@ -2442,27 +2452,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        //const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
-        //const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        //const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        //const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        //const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        //const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
         const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
         const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
 
-        //const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        //const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        //const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        //const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
         const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
         const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
         const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
@@ -2644,19 +2644,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
         const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
-        // We no longer need this. We have computed the sum of the y quants during quantization,
-        // so we get the same as these via the scalar instruction above (summs += x0->m * y0->s + x1->m * y1->s)
-        //const int16x8_t s0i = vaddq_s16(
-        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
-        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
-
-        //const int16x8_t s1i = vaddq_s16(
-        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
-        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
-
-        //sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
-        //sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
@@ -2702,11 +2689,9 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 
         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
-        //const __m256 m0v = _mm256_broadcast_ss( m0 );
 
         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
-        //const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
 
         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
@@ -2728,17 +2713,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 
         // Accumulate d0*d1*x*y
        acc = _mm256_fmadd_ps( d0d1, xy, acc );
-
-        // We no longer need this. We have computed the sum of the y quants during quantization,
-        // so we get the same as these via the single scalar instruction above (summs += x[i].m * y[i].s)
-        //// Compute sum of y values
-        //const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        //const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        //const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
-        //const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
-
-        //// Accumulate d1*m0*y
-        //acc = _mm256_fmadd_ps( d1m0, ysum, acc );
     }
 
     // Return horizontal sum of the acc vector
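For anyone picking up the TODO above: a minimal standalone harness along these lines can check either horizontal_sum variant against a scalar reference before benchmarking. The file name and test values are illustrative, not part of the patch.

// hsum_check.c -- standalone harness, not part of the patch.
// Build: gcc -O2 -mavx2 hsum_check.c -o hsum_check
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

// Copy of the variant kept in the patch: reduce 8 int32 lanes to one scalar.
static inline float horizontal_sum(__m256i a) {
    __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
    __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
    __m128i sum64 = _mm_add_epi32(hi64, sum128);
    __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}

int main(void) {
    int32_t v[8] = { 1, -2, 3, -4, 5, -6, 7, -8 };
    int32_t ref = 0;
    for (int i = 0; i < 8; ++i) ref += v[i];   // scalar reference: -4
    float got = horizontal_sum(_mm256_loadu_si256((const __m256i *)v));
    printf("ref = %d, simd = %.0f\n", ref, got);
    return got == (float)ref ? 0 : 1;          // small int sums are exact in float
}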
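The q4_0 changes lean on an algebraic identity: with x dequantizing as dx*(q - 8) and y as dy*qy, each block satisfies sum(dx*(q-8) * dy*qy) = dx*dy*sum(q*qy) - 8*dx*(dy*sum(qy)), where dy*sum(qy) is exactly the y->s now stored during quantization. That is why the per-element vsubq_s8 by 0x8 can be dropped in favor of the scalar sum8 accumulation visible in the context lines. A scalar sketch of the identity follows; the hunk does not show where sum8 is finally applied, so the -8*dx*s_y term below is an assumption about that step, and QK and all data are illustrative.

// q4_0 offset folding, scalar sketch (hypothetical values throughout).
#include <stdint.h>
#include <stdio.h>

#define QK 32  // block size, matching QK8_0 in ggml.c

int main(void) {
    int8_t qx[QK], qy[QK];  // qx: raw 4-bit quants in [0,15]
    for (int j = 0; j < QK; ++j) { qx[j] = (int8_t)((j*7) % 16); qy[j] = (int8_t)(j - 16); }
    const float dx = 0.1f, dy = 0.05f;

    // Old form: subtract the offset 8 from every quant inside the hot loop.
    float with_sub = 0.0f;
    for (int j = 0; j < QK; ++j) with_sub += dx*(qx[j] - 8) * dy*qy[j];

    // New form: dot the raw quants; the offset becomes one scalar correction
    // using s_y = dy * sum(qy), which quantize_row_q8_0 now stores as y->s.
    int sumxy = 0, sumy = 0;
    for (int j = 0; j < QK; ++j) { sumxy += qx[j]*qy[j]; sumy += qy[j]; }
    const float s_y = dy * sumy;
    const float folded = dx*dy*sumxy - 8.0f*dx*s_y;  // the sum8-style term

    printf("with_sub = %f, folded = %f\n", with_sub, folded);  // agree up to rounding
    return 0;
}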
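The q4_1 deletions rest on the same trick: with x dequantizing as dx*q + mx, each block satisfies sum((dx*qx + mx) * dy*qy) = dx*dy*sum(qx*qy) + mx*(dy*sum(qy)), and the second factor is again the precomputed y->s, which is what the scalar summs += x->m * y->s accumulates. A scalar sketch, with illustrative block size and data:

// q4_1 m*s split, scalar sketch (hypothetical values throughout).
#include <stdint.h>
#include <stdio.h>

#define QK 32

int main(void) {
    int8_t qx[QK], qy[QK];
    for (int j = 0; j < QK; ++j) { qx[j] = (int8_t)(j % 16); qy[j] = (int8_t)(16 - j); }
    const float dx = 0.1f, mx = -0.5f, dy = 0.05f;

    // Naive form: dequantize both operands, then multiply-accumulate.
    float naive = 0.0f;
    for (int j = 0; j < QK; ++j) naive += (dx*qx[j] + mx) * (dy*qy[j]);

    // Split form: the mx term only needs s_y = dy * sum(qy), precomputed
    // during quantization, so no per-element work remains for it.
    int sumxy = 0, sumy = 0;
    for (int j = 0; j < QK; ++j) { sumxy += qx[j]*qy[j]; sumy += qy[j]; }
    const float s_y = dy * sumy;
    const float split = dx*dy*sumxy + mx*s_y;  // summs += x->m * y->s in the patch

    printf("naive = %f, split = %f\n", naive, split);  // agree up to rounding
    return 0;
}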