AVX BF16 and single scale quant optimizations #10212

Merged: 22 commits into ggerganov:master from netrunnereve:avx_opt on Nov 15, 2024

Changes from all commits

Commits (22)
ad01d31
use 128 bit loads (I've tried 256->128 to death and it's slower)
netrunnereve Nov 2, 2024
34b9f0d
double accumulator
netrunnereve Nov 2, 2024
dca0deb
avx bf16 vec dot
netrunnereve Nov 2, 2024
e069375
+3% q4_0 inference
netrunnereve Nov 2, 2024
fffe7e6
+7% tg +5% pp compared to master
netrunnereve Nov 2, 2024
f8dd133
slower f16c version, kept for reference
netrunnereve Nov 2, 2024
1335c78
256b version, also slow. I tried :)
netrunnereve Nov 2, 2024
629befc
revert f16
netrunnereve Nov 2, 2024
7de0bdc
faster with madd
netrunnereve Nov 2, 2024
b8d592f
split to functions
netrunnereve Nov 2, 2024
6667ede
Q8_0 and IQ4_NL, 5-7% faster
netrunnereve Nov 3, 2024
6a4c080
fix potential overflow (performance reduced)
netrunnereve Nov 3, 2024
a83ac00
Merge branch 'ggerganov:master' into avx_opt
netrunnereve Nov 3, 2024
b0e9b96
rebase to master
netrunnereve Nov 4, 2024
ec6366f
Merge https://github.com/ggerganov/llama.cpp into avx_opt
netrunnereve Nov 4, 2024
8c29230
Merge branch 'ggerganov:master' into avx_opt
netrunnereve Nov 5, 2024
54e6c88
Merge branch 'avx_opt' of https://github.com/netrunnereve/llama.cpp i…
netrunnereve Nov 8, 2024
13dfe63
Merge https://github.com/ggerganov/llama.cpp into avx_opt
netrunnereve Nov 8, 2024
a847973
16 bit add for q4_0 only
netrunnereve Nov 12, 2024
c54b67c
Merge branch 'ggerganov:master' into avx_opt
netrunnereve Nov 12, 2024
9352321
pull in master (same changes removed)
netrunnereve Nov 15, 2024
f281ca3
merge
netrunnereve Nov 15, 2024
ggml/src/ggml-cpu/ggml-cpu-quants.c: 128 changes (77 additions, 51 deletions)
@@ -150,6 +150,28 @@ static inline __m128i packNibbles( __m256i bytes )
#endif
}
#elif defined(__AVX__)
+static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
+{
+// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+const __m128i lowByte = _mm_set1_epi16( 0xFF );
+__m128i high = _mm_andnot_si128( lowByte, bytes1 );
+__m128i low = _mm_and_si128( lowByte, bytes1 );
+high = _mm_srli_epi16( high, 4 );
+bytes1 = _mm_or_si128( low, high );
+high = _mm_andnot_si128( lowByte, bytes2 );
+low = _mm_and_si128( lowByte, bytes2 );
+high = _mm_srli_epi16( high, 4 );
+bytes2 = _mm_or_si128( low, high );
+
+return _mm_packus_epi16( bytes1, bytes2);
+}
+
+static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
+const __m128i ax = _mm_sign_epi8(x, x);
+const __m128i sy = _mm_sign_epi8(y, x);
+return _mm_maddubs_epi16(ax, sy);
+}
// spread 32 bits to 32 bytes { 0x00, 0xFF }
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
uint32_t x32;
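An aside on the mul_add_epi8_sse helper added above: SSE has no signed 8-bit multiply, so the sign of x is folded into y with _mm_sign_epi8 and the unsigned-by-signed _mm_maddubs_epi16 does the multiply. A scalar sketch of one output lane (the helper name and code below are illustrative, not part of the patch):

#include <stdint.h>
#include <stdlib.h>

// One 16-bit output lane of mul_add_epi8_sse, modeled in scalar code.
// ax = |x| and sy = sign(x)*y (zero when x == 0), so ax*sy == x*y even though
// _mm_maddubs_epi16 reads ax as unsigned; adjacent products are then summed.
// (_mm_maddubs_epi16 saturates that sum, but the quantized ranges used by
// these kernels stay well below the saturation point.)
static int16_t mul_add_lane(int8_t x0, int8_t y0, int8_t x1, int8_t y1) {
    const int p0 = abs(x0) * (x0 == 0 ? 0 : (x0 < 0 ? -y0 : y0));
    const int p1 = abs(x1) * (x1 == 0 ? 0 : (x1 < 0 ? -y1 : y1));
    return (int16_t)(p0 + p1);
}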
@@ -217,26 +239,29 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
return sum_i16_pairs_float(doth, dotl);
}

-static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
-{
-// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
-const __m128i lowByte = _mm_set1_epi16( 0xFF );
-__m128i high = _mm_andnot_si128( lowByte, bytes1 );
-__m128i low = _mm_and_si128( lowByte, bytes1 );
-high = _mm_srli_epi16( high, 4 );
-bytes1 = _mm_or_si128( low, high );
-high = _mm_andnot_si128( lowByte, bytes2 );
-low = _mm_and_si128( lowByte, bytes2 );
-high = _mm_srli_epi16( high, 4 );
-bytes2 = _mm_or_si128( low, high );
-
-return _mm_packus_epi16( bytes1, bytes2);
-}
-
-static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
-const __m128i ax = _mm_sign_epi8(x, x);
-const __m128i sy = _mm_sign_epi8(y, x);
-return _mm_maddubs_epi16(ax, sy);
-}
+// larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors
+static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1,
+const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) {
+const __m128i mone = _mm_set1_epi16(1);
+
+const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0);
+const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1);
+const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0);
+const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1);
+const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
+const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
+const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
+const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
+const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1);
+const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1);
+return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1));
+}
+
+// quad fp16 delta calculation
+static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) {
+// GGML_FP16_TO_FP32 is faster than Intel F16C
+return _mm256_set_m128(_mm_set1_ps(GGML_FP16_TO_FP32(x1) * GGML_FP16_TO_FP32(y1)),
+_mm_set1_ps(GGML_FP16_TO_FP32(x0) * GGML_FP16_TO_FP32(y0)));
+}
#endif
#elif defined(__SSSE3__)
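For orientation, the scalar quantity that mul_sum_i8_quad_float vectorizes is the plain int8 dot product below (an illustrative sketch, not code from the PR); the returned __m256 carries block 1's partial sums in its low 128 bits and block 2's in the high 128 bits, so each half can be scaled by its own delta.

#include <stdint.h>

// Illustrative scalar equivalent: the dot product of one 32-element int8
// block. mul_sum_i8_quad_float computes two of these at once, spread across
// the low and high halves of the 8-lane float result.
static int32_t dot_block_i8(const int8_t x[32], const int8_t y[32]) {
    int32_t sum = 0;
    for (int i = 0; i < 32; ++i) {
        sum += (int32_t)x[i] * (int32_t)y[i];
    }
    return sum;
}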
@@ -2004,10 +2029,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r

sumf = hsum_float_8(acc);
#elif defined(__AVX__)
-const __m128i mone = _mm_set1_epi16(1);
-
-__m256 accum1 = _mm256_setzero_ps();
-__m256 accum2 = _mm256_setzero_ps();
+__m256 accum = _mm256_setzero_ps();
for (; ib + 1 < nb; ib += 2) {
const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
@@ -2020,21 +2042,20 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8));
const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));
const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));

const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
-const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
-const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
-const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
-const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
-accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
-_mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1);
-accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
-_mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2);
+const __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1);
+const __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1);
+const __m256 p = sum_i16_pairs_float(p_2, p_1);
+
+const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
+accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
}

-sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+sumf = hsum_float_8(accum);
#elif defined(__SSSE3__)
// set constants
const __m128i lowMask = _mm_set1_epi8(0xF);
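A note on the 16-bit accumulation above, spelled out from the value ranges in the code: q4_0 quants lie in [-8, 7] after the subtract-8, and q8_0 quants stay within about ±127, so each 16-bit lane produced by mul_add_epi8_sse is at most 2 * 8 * 127 = 2032 in magnitude. The _mm_add_epi16 of two such vectors is then bounded by 4064, comfortably inside int16, and sum_i16_pairs_float widens the lanes to 32 bits before anything larger can accumulate. Q8_0 and IQ4_NL operands reach roughly ±127 on both sides, where adding two maddubs results could overflow 16 bits, which matches the commit history: "fix potential overflow" followed by "16 bit add for q4_0 only".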
@@ -3535,7 +3556,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
}

sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
-#elif defined(__AVX2__) || defined(__AVX__)
+#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();

@@ -3549,14 +3570,29 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
const __m256 q = mul_sum_i8_pairs_float(qx, qy);

// Multiply q with scale and accumulate
-#if defined(__AVX2__)
acc = _mm256_fmadd_ps( d, q, acc );
-#else
-acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
-#endif
}

sumf = hsum_float_8(acc);
+#elif defined(__AVX__)
+__m256 accum = _mm256_setzero_ps();
+
+for (; ib + 1 < nb; ib += 2) {
+const __m128i qx_1_0 = _mm_loadu_si128((const __m128i *)x[ib].qs);
+const __m128i qx_1_1 = _mm_loadu_si128((const __m128i *)x[ib].qs + 1);
+const __m128i qx_2_0 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+const __m128i qx_2_1 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs + 1);
+const __m128i qy_1_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
+const __m128i qy_1_1 = _mm_loadu_si128((const __m128i *)y[ib].qs + 1);
+const __m128i qy_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+const __m128i qy_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+const __m256 p = mul_sum_i8_quad_float(qx_1_0, qx_1_1, qx_2_0, qx_2_1, qy_1_0, qy_1_1, qy_2_0, qy_2_1);
+const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
+accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+}
+
+sumf = hsum_float_8(accum);
#elif defined(__riscv_v_intrinsic)
size_t vl = __riscv_vsetvl_e8m1(qk);

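For reference, a scalar version of what the new AVX branch computes, processing one block at a time instead of two (an illustrative sketch; the struct below assumes the standard 32-quant q8_0 block layout and uses a float delta where the real block_q8_0 stores ggml_fp16_t, to keep the sketch self-contained):

#include <stdint.h>

// Scalar reference for the q8_0 x q8_0 dot product: sum over blocks of
// d_x * d_y * (qs_x . qs_y). The AVX path above vectorizes two blocks
// per iteration.
typedef struct { float d; int8_t qs[32]; } q8_block_sketch; // assumed layout

static float vec_dot_q8_0_ref(int nb, const q8_block_sketch *x, const q8_block_sketch *y) {
    float sumf = 0.0f;
    for (int ib = 0; ib < nb; ++ib) {
        int32_t sumi = 0;
        for (int j = 0; j < 32; ++j) {        // 32 quants per block
            sumi += (int32_t)x[ib].qs[j] * (int32_t)y[ib].qs[j];
        }
        sumf += x[ib].d * y[ib].d * (float)sumi;
    }
    return sumf;
}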
@@ -10322,10 +10358,8 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
#elif defined __AVX__
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
const __m128i m4b = _mm_set1_epi8(0x0f);
-const __m128i mone = _mm_set1_epi16(1);

-__m256 accum1 = _mm256_setzero_ps();
-__m256 accum2 = _mm256_setzero_ps();
+__m256 accum = _mm256_setzero_ps();
for (; ib + 1 < nb; ib += 2) {
const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
@@ -10338,21 +10372,13 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
-const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
-const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
-const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
-const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
-const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
-const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
-const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
-const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
-accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
-_mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1);
-accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
-_mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2);
+
+const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
+const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
+accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
}

-sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+sumf = hsum_float_8(accum);

#elif defined(__POWER9_VECTOR__)
const vector signed char lowMask = vec_splats((signed char)0xF);
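Relative to the q4_0 path, the only change in dequantization here is that _mm_shuffle_epi8 uses each nibble as an index into the 16-entry kvalues_iq4nl codebook held in values128, rather than subtracting 8. A scalar sketch of that lookup (the table contents below are a placeholder, not the real codebook values):

#include <stdint.h>

// _mm_shuffle_epi8(values128, nibbles) performs 16 of these lookups at once;
// kvalues_iq4nl is the codebase's non-linear 16-entry codebook (placeholder
// values here, shown only to illustrate the indexing).
static const int8_t kvalues_iq4nl_sketch[16] = { /* 16 codebook entries */ 0 };

static int8_t iq4nl_dequant(uint8_t packed, int high_nibble) {
    const uint8_t idx = high_nibble ? (uint8_t)(packed >> 4) : (uint8_t)(packed & 0x0f);
    return kvalues_iq4nl_sketch[idx];
}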
ggml/src/ggml-cpu/ggml-cpu.c: 6 changes (5 additions, 1 deletion)
@@ -1469,8 +1469,12 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
sumf += (ggml_float)_mm512_reduce_add_ps(c2);

#undef LOAD
-#elif defined(__AVX2__)
+#elif defined(__AVX2__) || defined(__AVX__)
+#if defined(__AVX2__)
#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
+#else
+#define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1))
+#endif
__m256 c1 = _mm256_setzero_ps();
__m256 c2 = _mm256_setzero_ps();
__m256 c3 = _mm256_setzero_ps();
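The AVX fallback LOAD macro above is a single dense expression; unpacked into named steps it reads as follows (a restatement of the macro with an illustrative helper name, not new behavior). Since bf16 is exactly the high half of an f32, widening each 16-bit value to 32 bits and shifting left by 16 reconstructs the float.

#include <immintrin.h>

// Step-by-step equivalent of the AVX-only LOAD: load 8 bf16 values, widen
// each 128-bit half to 32-bit lanes, shift left 16 so each bf16 becomes the
// high half of an f32, then reassemble and reinterpret as 8 floats.
static inline __m256 load_8_bf16(const void * p) {
    const __m128i u  = _mm_loadu_si128((const __m128i *)p);                               // 8 x bf16
    const __m128i lo = _mm_slli_epi32(_mm_cvtepu16_epi32(u), 16);                         // elements 0-3
    const __m128i hi = _mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(u, 8)), 16);     // elements 4-7
    return _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1));
}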