ggml : fix iq4_nl dot product with odd number of blocks #8549

Merged 2 commits on Jul 19, 2024
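
Summary of the fix: every SIMD path of ggml_vec_dot_iq4_nl_q8_0 consumed blocks two at a time with for (int ib = 0; ib < nb; ib += 2), so a row with an odd number of blocks made the last iteration read one block past the end. The patch hoists ib and sumf out of the #if branches, tightens each vector loop to ib + 1 < nb, and finishes any leftover block in a scalar loop shared by all paths. Below is a minimal scalar sketch of that pattern; Block and block_dot are hypothetical stand-ins, not the real ggml types or kernels.

/* sketch.c - minimal illustration of the pair-loop + scalar-tail pattern
   applied by this PR; Block and block_dot are hypothetical stand-ins. */
#include <stdio.h>

typedef struct {
    float d;            /* per-block scale */
    signed char qs[4];  /* quantized values (shortened for the sketch) */
} Block;

static float block_dot(const Block * a, const Block * b) {
    int sumi = 0;
    for (int j = 0; j < 4; ++j) {
        sumi += a->qs[j] * b->qs[j];
    }
    return a->d * b->d * (float)sumi;
}

static float dot_rows(const Block * x, const Block * y, int nb) {
    int   ib   = 0;
    float sumf = 0;
    /* vector-style main loop: runs only while a full pair of blocks remains
       (the old condition ib < nb overran the row when nb was odd) */
    for (; ib + 1 < nb; ib += 2) {
        sumf += block_dot(&x[ib + 0], &y[ib + 0]);
        sumf += block_dot(&x[ib + 1], &y[ib + 1]);
    }
    /* shared scalar tail: picks up the leftover block when nb is odd */
    for (; ib < nb; ++ib) {
        sumf += block_dot(&x[ib], &y[ib]);
    }
    return sumf;
}

int main(void) {
    Block x[3] = {{1.0f, {1, 2, 3, 4}}, {0.5f, {4, 3, 2, 1}}, {2.0f, {1, 1, 1, 1}}};
    Block y[3] = {{1.0f, {1, 1, 1, 1}}, {1.0f, {2, 2, 2, 2}}, {0.5f, {3, 3, 3, 3}}};
    /* nb = 3 (odd): block 2 is handled by the tail loop */
    printf("%f\n", dot_rows(x, y, 3));
    return 0;
}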
ggml/src/ggml-quants.c (42 additions, 54 deletions)
@@ -11745,6 +11745,9 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
 
     const int nb = n / QK4_NL;
 
+    int ib = 0;
+    float sumf = 0;
+
 #if defined __ARM_NEON
     const int8x16_t values = vld1q_s8(kvalues_iq4nl);
     const uint8x16_t m4b = vdupq_n_u8(0x0f);
@@ -11753,16 +11756,14 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
     int8x16x4_t q8b;
     int32x4_t prod_1, prod_2;
 
-    float sumf = 0;
-
-    for (int ib = 0; ib < nb; ib += 2) {
+    for (; ib + 1 < nb; ib += 2) {
 
-        q4bits.val[0] = vld1q_u8(x[ib+0].qs);
-        q4bits.val[1] = vld1q_u8(x[ib+1].qs);
-        q8b.val[0] = vld1q_s8(y[ib+0].qs);
-        q8b.val[1] = vld1q_s8(y[ib+0].qs + 16);
-        q8b.val[2] = vld1q_s8(y[ib+1].qs);
-        q8b.val[3] = vld1q_s8(y[ib+1].qs + 16);
+        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
+        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
+        q8b.val[0] = vld1q_s8(y[ib + 0].qs);
+        q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16);
+        q8b.val[2] = vld1q_s8(y[ib + 1].qs);
+        q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16);
 
         q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
         q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
@@ -11773,12 +11774,10 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
         prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
 
         sumf +=
-            GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib+0].d) * vaddvq_s32(prod_1) +
-            GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib+1].d) * vaddvq_s32(prod_2);
+            GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
+            GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
     }
 
-    *s = sumf;
-
 #elif defined __AVX2__
 
     const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
@@ -11787,11 +11786,11 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
 
     __m256 accum1 = _mm256_setzero_ps();
     __m256 accum2 = _mm256_setzero_ps();
-    for (int ib = 0; ib < nb; ib += 2) {
-        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[0].qs);
-        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs);
-        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs);
-        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs);
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
+        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
+        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
         const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
                                               _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
         const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
@@ -11800,16 +11799,13 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
         const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
         const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
         const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
-        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
+        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
                 _mm256_cvtepi32_ps(p_1), accum1);
-        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
+        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
                 _mm256_cvtepi32_ps(p_2), accum2);
-
-        y += 2;
-        x += 2;
     }
 
-    *s = hsum_float_8(_mm256_add_ps(accum1, accum2));
+    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
 
 #elif defined __AVX__
     const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
@@ -11818,13 +11814,13 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
 
     __m256 accum1 = _mm256_setzero_ps();
     __m256 accum2 = _mm256_setzero_ps();
-    for (int ib = 0; ib < nb; ib += 2) {
-        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
-        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[1].qs);
-        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
-        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[0].qs + 1);
-        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[1].qs);
-        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[1].qs + 1);
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
 
         const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
         const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
@@ -11838,16 +11834,13 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
         const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
         const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
         const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
-        accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
+        accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
                 _mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1);
-        accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
+        accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
                 _mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2);
-
-        y += 2;
-        x += 2;
     }
 
-    *s = hsum_float_8(_mm256_add_ps(accum1, accum2));
+    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
 
 #elif defined(__POWER9_VECTOR__)
     const vector signed char lowMask = vec_splats((signed char)0xF);
@@ -11860,7 +11853,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
     const vector signed char values = vec_xl( 0, kvalues_iq4nl);
 
 #pragma GCC unroll 4
-    for (int ib = 0; ib < nb; ++ib) {
+    for (; ib < nb; ++ib) {
         __builtin_prefetch(x[ib].qs, 0, 1);
         __builtin_prefetch(y[ib].qs, 0, 1);
 
@@ -11897,7 +11890,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
     vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
     vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
 
-    *s = vec_extract(vsumf0, 0);
+    sumf = vec_extract(vsumf0, 0);
 
 #elif defined (__loongarch_asx)
 
@@ -11907,11 +11900,11 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
 
     __m256 accum1 = (__m256)__lasx_xvldi(0);
     __m256 accum2 = (__m256)__lasx_xvldi(0);
-    for (int ib = 0; ib < nb; ib += 2) {
-        const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[0].qs, 0);
-        const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[1].qs, 0);
-        const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[0].qs, 0);
-        const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[1].qs, 0);
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0);
+        const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0);
+        const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0);
+        const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0);
         const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)),
                                               lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b)));
         const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)),
@@ -11920,20 +11913,16 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
         const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
         const __m256i p_1 = lasx_madd_h(p16_1, mone);
         const __m256i p_2 = lasx_madd_h(p16_2, mone);
-        accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
+        accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
                 __lasx_xvffint_s_w(p_1), accum1);
-        accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
+        accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
                 __lasx_xvffint_s_w(p_2), accum2);
-
-        y += 2;
-        x += 2;
     }
 
-    *s = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
+    sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
 
-#else
-    float sumf = 0;
-    for (int ib = 0; ib < nb; ++ib) {
+#endif
+    for (; ib < nb; ++ib) {
        const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
        int sumi1 = 0, sumi2 = 0;
        for (int j = 0; j < QK4_NL/2; ++j) {
@@ -11943,7 +11932,6 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
         sumf += d * (sumi1 + sumi2);
     }
     *s = sumf;
-#endif
 }
 
 void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
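
Why this was broken: besides the off-by-one loop bound, the AVX2, AVX, and LoongArch bodies above indexed x[0]/x[1] and y[0]/y[1] and advanced the pointers with x += 2; y += 2;, and each branch wrote its own *s. The patch switches them to ib-relative indexing, accumulates into the shared sumf, and funnels every path through the single scalar tail and the one *s = sumf store. Worked example (QK4_NL is 32): n = 96 gives nb = 3, so the old condition ib < nb ran the pair loop at ib = 0 and ib = 2, and the second iteration touched block index 3, one past the last valid block. The new condition ib + 1 < nb stops after ib = 0 and leaves block 2 to the scalar tail.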
tests/test-backend-ops.cpp (27 additions, 0 deletions)
@@ -79,8 +79,16 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
                 im = nullptr;
             }
         }
+
         ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], im);
         GGML_ASSERT(ggml_validate_row_data(tensor->type, dataq.data(), dataq.size()));
+        // TODO: other cases
+        //#pragma omp parallel for
+        //for (int i = 0; i < tensor->ne[1]; i++) {
+        //    ggml_quantize_chunk(tensor->type, data.data(), dataq.data(),
+        //        i * tensor->ne[0], 1, tensor->ne[0], im);
+        //}
+
         ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
     } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
         // This is going to create some weird integers though.
@@ -2220,6 +2228,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
     }
 
+#if 1
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
@@ -2239,6 +2248,24 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
         }
     }
+#else
+    // m = a rows
+    // n = b rows
+    // k = cols
+    std::uniform_int_distribution<> dist_m(1, 128);
+    std::uniform_int_distribution<> dist_n(16, 128);
+    std::uniform_int_distribution<> dist_k(1, 16);
+    for (int i = 0; i < 1000; i++) {
+        for (ggml_type type_a : all_types) {
+            for (ggml_type type_b : {GGML_TYPE_F32}) {
+                int m = dist_m(rng);
+                int n = dist_n(rng);
+                int k = dist_k(rng) * ggml_blck_size(type_a);
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, m, n, k, { 1, 1}, {1, 1}));
+            }
+        }
+    }
+#endif
 
     for (ggml_type type_a : other_types) {
         for (ggml_type type_b : {GGML_TYPE_F32}) {
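
The #else branch added above is a disabled randomized sweep (the #if 1 keeps the original fixed-size cases active), presumably the harness used to catch this class of bug: k = dist_k(rng) * ggml_blck_size(type_a) draws a multiplier between 1 and 16, so roughly half of the generated inner dimensions contain an odd number of blocks. For an iq4_nl operand, for example, ggml_blck_size is 32, and k = 96 means nb = 3 in the dot product, exercising the new scalar tail.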