Cleaning up
Kawrakow committed Apr 21, 2023
1 parent 66a865b commit c542d5a
Showing 1 changed file with 25 additions and 51 deletions.
ggml.c
@@ -1310,6 +1310,29 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
 }
 }
 
+#ifdef __AVX2__
+// There is no better way of doing this?
+// I guess not, AVX is not very good at horizontal sums.
+// The commented solution for a horizontal sum was suggested by @pubby as being slightly
+// faster than the solution below. As I don't have an AVX2 system handy right now to test,
+// keeping the original.
+// TODO: Please try it, and if it does make a difference, uncomment it and remove the implementation below.
+//static inline float horizontal_sum(__m256i a) {
+//    __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a)));
+//    __m256i sum = _mm256_add_epi32(a, b);
+//    __m256i hi = _mm256_unpackhi_epi64(sum, sum);
+//    sum = _mm256_add_epi32(sum, hi);
+//    return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4);
+//}
+static inline float horizontal_sum(__m256i a) {
+    __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
+    __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+#endif
+
 static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
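Aside (not part of the commit): the new helper reduces the eight 32-bit lanes of a __m256i to a single value, returned as float because the caller multiplies it by the scale d right away. A minimal scalar reference for sanity-checking either AVX2 variant, including the commented @pubby version, might look like this (a sketch; horizontal_sum_ref is a hypothetical name, not in ggml.c):

    #include <immintrin.h>
    #include <stdint.h>

    // Scalar reference for horizontal_sum: spill the vector and add the lanes.
    // Slow, but handy for verifying the intrinsic versions against each other.
    static inline int32_t horizontal_sum_ref(__m256i a) {
        int32_t tmp[8];
        _mm256_storeu_si256((__m256i *)tmp, a);
        int32_t sum = 0;
        for (int j = 0; j < 8; ++j) sum += tmp[j];
        return sum;
    }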
@@ -1399,14 +1422,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int

 #if defined(__AVX2__)
 
-        // Compute the sum of the quants
-        // There is not better way of doing this???
-        __m256i acc = _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3));
-        __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(acc), _mm256_extracti128_si256(acc, 1));
-        __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
-        __m128i sum64 = _mm_add_epi32(hi64, sum128);
-        __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
-        y[i].s = d * _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+        // Compute the sum of the quants and set y[i].s
+        y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
 
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
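Aside (not part of the commit): per block, s stores d times the sum of the int8 quants, which is what horizontal_sum extracts from the four 8-lane int32 vectors i0..i3 above. A scalar sketch of the same bookkeeping, assuming the block size QK8_0 of 32 used at this point in the history (block_sum_term is a hypothetical name):

    #include <stdint.h>

    // Illustrative sketch: the per-block term s = d * sum(quants)
    // that the dot products below reuse.
    static float block_sum_term(float d, const int8_t qs[32]) {
        int sum = 0;
        for (int j = 0; j < 32; ++j) {
            sum += qs[j];
        }
        return d * (float)sum;
    }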
@@ -2411,7 +2428,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         sum8 += x0->d * y0->s + x1->d * y1->s;
 
         const uint8x16_t m4b = vdupq_n_u8(0xf);
-        //const int8x16_t s8b = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2422,12 +2438,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
-        // sub 8
-        //const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        //const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        //const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        //const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
@@ -2442,27 +2452,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *

 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        //const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
-        //const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        //const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        //const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        //const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        //const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
         const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
         const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
 
-        //const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        //const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        //const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        //const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
         const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
         const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
         const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
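Aside (not part of the commit): the deleted "sub 8" vectors are redundant because the q4_0 offset folds into one scalar term per block. With x_j = d_x*(q_j - 8) and y_j = d_y*p_j, the block dot product is d_x*d_y*sum(q_j*p_j) - 8*d_x*(d_y*sum(p_j)), and d_y*sum(p_j) is exactly the s stored during quantization, which is why the scalar line sum8 += x0->d * y0->s + x1->d * y1->s at the top of the loop suffices. A sketch of the identity with hypothetical names (dx, dy, ys; QK4_0 assumed to be 32):

    #include <stdint.h>

    // Illustrative check of the q4_0 x q8_0 identity; not part of ggml.c.
    static float dot_q4_0_block_ref(float dx, float dy, float ys,
                                    const uint8_t q[32], const int8_t p[32]) {
        int dot = 0;
        for (int j = 0; j < 32; ++j) {
            dot += (int)q[j] * (int)p[j];   // raw quants, no "- 8" needed
        }
        // sum_j dx*(q_j - 8) * dy*p_j = dx*dy*sum(q*p) - 8*dx*(dy*sum(p)),
        // and dy*sum(p) was stored as y->s during quantization:
        return dx*dy*(float)dot - 8.0f*dx*ys;
    }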
@@ -2644,19 +2644,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
         const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
-        // We no longer need this. We have computed the sum of the y quants during quantization,
-        // so we get the same as these via the scalar instruction above (summs += x0->m * y0->s + x1->m * y1->s)
-        //const int16x8_t s0i = vaddq_s16(
-        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
-        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
-
-        //const int16x8_t s1i = vaddq_s16(
-        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
-        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
-
-        //sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
-        //sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
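Aside (not part of the commit): the deleted vector code summed the y quants inside the loop, but for q4_1, where x_j = d_x*q_j + m, the minimum contributes m*(d_y*sum(p_j)) = m*s per block, already handled by the scalar summs accumulation the deleted comment refers to. A sketch under the same hypothetical naming as above:

    #include <stdint.h>

    // Illustrative check of the q4_1 x q8_0 identity; not part of ggml.c.
    static float dot_q4_1_block_ref(float dx, float m, float dy, float ys,
                                    const uint8_t q[32], const int8_t p[32]) {
        int dot = 0;
        for (int j = 0; j < 32; ++j) {
            dot += (int)q[j] * (int)p[j];
        }
        // sum_j (dx*q_j + m) * dy*p_j = dx*dy*sum(q*p) + m*(dy*sum(p)):
        return dx*dy*(float)dot + m*ys;
    }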
@@ -2702,11 +2689,9 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *

         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
-        //const __m256 m0v = _mm256_broadcast_ss( m0 );
 
         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
-        //const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
 
         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
@@ -2728,17 +2713,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *

         // Accumulate d0*d1*x*y
         acc = _mm256_fmadd_ps( d0d1, xy, acc );
-
-        // We no longer need this. We have computed the sum of the y quants during quantization,
-        // so we get the same as these via the single scalar instruction above (summs += x[i].m * y[i].s)
-        //// Compute sum of y values
-        //const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        //const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        //const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
-        //const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
-
-        //// Accumulate d1*m0*y
-        //acc = _mm256_fmadd_ps( d1m0, ysum, acc );
     }
 
     // Return horizontal sum of the acc vector
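Aside (not part of the commit): one common shape for the float reduction the comment above refers to, shown as a sketch (hsum_ps_256 is a hypothetical name; the actual reduction in ggml.c sits outside this hunk):

    #include <immintrin.h>

    // Reduce the 8 float lanes of an __m256 accumulator to one scalar.
    static inline float hsum_ps_256(__m256 x) {
        __m128 r = _mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1));
        r = _mm_add_ps(r, _mm_movehl_ps(r, r));   // lanes {0,1} += lanes {2,3}
        r = _mm_add_ss(r, _mm_movehdup_ps(r));    // lane 0 += lane 1
        return _mm_cvtss_f32(r);
    }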
