ggml: aarch64: implement SVE kernels for q3_K_q8_K vector dot #11917

Merged
merged 5 commits into from Feb 20, 2025
Changes from 1 commit
183 changes: 182 additions & 1 deletion ggml/src/ggml-cpu/ggml-cpu-quants.c
@@ -5088,7 +5088,188 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r

const int nb = n / QK_K;

#ifdef __ARM_NEON
#if defined(__ARM_FEATURE_SVE)
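// q3_K packs each weight as 2 low bits in qs plus 1 high bit in hmask, with 16
// signed 6-bit sub-block scales packed into scales[12]. The SVE path below
// reconstructs the signed 3-bit values, takes int8 dot products against the
// q8_K activations, accumulates per sub-block, and dispatches on the runtime
// vector length (128-bit vs. 256/512-bit).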

uint32_t utmp[4];

const int8_t m32 = 32;
const int vector_length = svcntb()*8;
const svuint8_t m3b_sv = svdup_n_u8(0X3);
const svint32_t vzero_sv = svdup_n_s32(0);

const svuint8_t m0_sv = svdup_n_u8(1);
const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(),m0_sv,1);
const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(),m0_sv,2);
const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(),m0_sv,3);
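// m0..m3 (1, 2, 4, 8) select the four successive high-bit planes stored in
// hmask, one per 2-bit extraction pass.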
svbool_t pred_s32 = svnot_b_z (svptrue_b32(),svptrue_pat_b32(SV_VL4));

float sum = 0;

for (int i = 0; i < nb; ++i) {

const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);

const uint8_t * restrict q3_sv = x[i].qs;
const uint8_t * restrict qh_sv = x[i].hmask;
const int8_t * restrict q8_sv = y[i].qs;

// Set up scales
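// Unpack the 16 packed 6-bit sub-block scales into utmp and re-center them
// to signed values by subtracting 32.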
const uint32_t * aux = (const uint32_t *)x[i].scales;
utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);

int8_t * scale = (int8_t *)utmp;

for (int j = 0; j < 16; ++j) scale[j] -= m32;

switch(vector_length){
case 128:
{
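// 128-bit vectors: process 16 bytes per load, mirroring the NEON layout.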
svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(),qh_sv);
svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(),qh_sv+16);
svuint8_t q3h_sv;

svint32_t sumi1_1 = svdup_n_s32(0);
svint8_t q3bytes_sv;

for (int j = 0; j < QK_K/128; ++j) {

const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(),q3_sv); q3_sv += 16;
const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(),q3_sv); q3_sv += 16;
svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(),q8_sv); q8_sv += 16;
svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(),q8_sv); q8_sv += 16;
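// (mN AND NOT hmask) isolates weights whose high bit is clear; each pass
// shifts the selected bit so it carries the value 4, and (q3 & 3) - q3h
// reconstructs the signed 3-bit weight (subtract 4 when the high bit is clear).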

q3h_sv = svlsl_n_u8_x(svptrue_b8(),svbic_u8_x(svptrue_b8(),m0_sv, qhbits_sv_1),2);
q3bytes_sv = svsub_s8_x(svptrue_b8(),svreinterpret_s8_u8(svand_u8_m(svptrue_b8(),q3bits_sv,m3b_sv)), svreinterpret_s8_u8(q3h_sv));

sumi1_1 = svmla_s32_m(svptrue_b32(),sumi1_1,svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1),svdup_n_s32((int32_t)scale[0]));

q3h_sv = svlsl_n_u8_x(svptrue_b8(),svbic_u8_x(svptrue_b8(),m0_sv, qhbits_sv_2),2);
q3bytes_sv = svsub_s8_x(svptrue_b8(),svreinterpret_s8_u8(svand_u8_m(svptrue_b8(),q3bits_sv_1,m3b_sv)), svreinterpret_s8_u8(q3h_sv));

sumi1_1 = svmla_s32_m(svptrue_b32(),sumi1_1,svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2),svdup_n_s32((int32_t)scale[1]));

q8bytes_1_sv_1 = svld1_s8(svptrue_b8(),q8_sv); q8_sv += 16;
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(),q8_sv); q8_sv += 16;

q3h_sv = svlsl_n_u8_x(svptrue_b8(),svbic_u8_x(svptrue_b8(),m1_sv, qhbits_sv_1),1);
q3bytes_sv = svsub_s8_x(svptrue_b8(),svreinterpret_s8_u8(svand_u8_m(svptrue_b8(),svlsr_n_u8_x(svptrue_b8(),q3bits_sv,2),m3b_sv)), svreinterpret_s8_u8(q3h_sv));

sumi1_1 = svmla_s32_m(svptrue_b32(),sumi1_1,svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1),svdup_n_s32((int32_t)scale[2]));

q3h_sv = svlsl_n_u8_x(svptrue_b8(),svbic_u8_x(svptrue_b8(),m1_sv, qhbits_sv_2),1);
q3bytes_sv = svsub_s8_x(svptrue_b8(),svreinterpret_s8_u8(svand_u8_m(svptrue_b8(),svlsr_n_u8_x(svptrue_b8(),q3bits_sv_1,2),m3b_sv)), svreinterpret_s8_u8(q3h_sv));

sumi1_1 = svmla_s32_m(svptrue_b32(),sumi1_1,svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2),svdup_n_s32((int32_t)scale[3]));


scale += 4;
q8bytes_1_sv_1 = svld1_s8(svptrue_b8(),q8_sv); q8_sv += 16;
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(),q8_sv); q8_sv += 16;

q3h_sv = svbic_u8_x(svptrue_b8(),m2_sv, qhbits_sv_1);
q3bytes_sv = svsub_s8_x(svptrue_b8(),svreinterpret_s8_u8(svand_u8_m(svptrue_b8(),svlsr_n_u8_x(svptrue_b8(),q3bits_sv,4),m3b_sv)), svreinterpret_s8_u8(q3h_sv));

sumi1_1 = svmla_s32_m(svptrue_b32(),sumi1_1,svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1),svdup_n_s32((int32_t)scale[0]));

q3h_sv = svbic_u8_x(svptrue_b8(),m2_sv, qhbits_sv_2);
q3bytes_sv = svsub_s8_x(svptrue_b8(),svreinterpret_s8_u8(svand_u8_m(svptrue_b8(),svlsr_n_u8_x(svptrue_b8(),q3bits_sv_1,4),m3b_sv)), svreinterpret_s8_u8(q3h_sv));

sumi1_1 = svmla_s32_m(svptrue_b32(),sumi1_1,svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2),svdup_n_s32((int32_t)scale[1]));


q8bytes_1_sv_1 = svld1_s8(svptrue_b8(),q8_sv); q8_sv += 16;
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(),q8_sv); q8_sv += 16;

q3h_sv = svlsr_n_u8_x(svptrue_b8(),svbic_u8_x(svptrue_b8(),m3_sv, qhbits_sv_1),1);
q3bytes_sv = svsub_s8_x(svptrue_b8(),svreinterpret_s8_u8(svand_u8_m(svptrue_b8(),svlsr_n_u8_x(svptrue_b8(),q3bits_sv,6),m3b_sv)), svreinterpret_s8_u8(q3h_sv));

sumi1_1 = svmla_s32_m(svptrue_b32(),sumi1_1,svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1),svdup_n_s32((int32_t)scale[2]));

q3h_sv = svlsr_n_u8_x(svptrue_b8(),svbic_u8_x(svptrue_b8(),m3_sv, qhbits_sv_2),1);
q3bytes_sv = svsub_s8_x(svptrue_b8(),svreinterpret_s8_u8(svand_u8_m(svptrue_b8(),svlsr_n_u8_x(svptrue_b8(),q3bits_sv_1,6),m3b_sv)), svreinterpret_s8_u8(q3h_sv));

sumi1_1 = svmla_s32_m(svptrue_b32(),sumi1_1,svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2),svdup_n_s32((int32_t)scale[3]));
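// After the first 64 weights, the remaining high bits sit in the upper
// nibble of hmask; shift them down for the second half of the block.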


if(j==0)
{
qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(),qhbits_sv_1,4);
qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(),qhbits_sv_2,4);
}

scale += 4;

}
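// Horizontal add (svaddv) of the 32-bit partial sums, scaled by the
// combined super-block scale d.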

sum += d * (svaddv_s32(svptrue_b32(), sumi1_1));
}break;
case 256:
case 512:
{
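// 256/512-bit vectors: process 32 bytes per load (SV_VL32); two adjacent
// 16-byte sub-blocks share one accumulate, so their scales are paired below.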
svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32),qh_sv);
svuint8_t q3h_sv;

svint32_t sumi1_1 = svdup_n_s32(0);
svint8_t q3bytes_sv;

for (int j = 0; j < QK_K/128; ++j) {

const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32),q3_sv); q3_sv += 32;
svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32),q8_sv); q8_sv += 32;
svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32),q8_sv); q8_sv += 32;

q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32),svbic_u8_x(svptrue_pat_b8(SV_VL32),m0_sv, qhbits_sv),2);
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32),svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32),q3bits_sv,m3b_sv)), svreinterpret_s8_u8(q3h_sv));
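// Build a scale vector whose first four 32-bit lanes hold scale[0] and the
// next four hold scale[1], matching the two 16-byte halves of the load.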


svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4),svdup_n_s32((int32_t)scale[0]),svdup_n_s32((int32_t)scale[1]));
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8),sumi1_1,svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1),scale_1);

q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32),svbic_u8_x(svptrue_pat_b8(SV_VL32),m1_sv, qhbits_sv),1);
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32),svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32),svlsr_n_u8_x(svptrue_pat_b8(SV_VL32),q3bits_sv,2),m3b_sv)), svreinterpret_s8_u8(q3h_sv));

scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4),svdup_n_s32((int32_t)scale[2]),svdup_n_s32((int32_t)scale[3]));
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8),sumi1_1,svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2),scale_1);

scale += 4;
q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32),q8_sv); q8_sv += 32;
q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32),q8_sv); q8_sv += 32;

q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32),m2_sv, qhbits_sv);
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32),svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32),svlsr_n_u8_x(svptrue_pat_b8(SV_VL32),q3bits_sv,4),m3b_sv)), svreinterpret_s8_u8(q3h_sv));

scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4),svdup_n_s32((int32_t)scale[0]),svdup_n_s32((int32_t)scale[1]));
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8),sumi1_1,svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1),scale_1);

q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32),svbic_u8_x(svptrue_pat_b8(SV_VL32),m3_sv, qhbits_sv),1);
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32),svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32),svlsr_n_u8_x(svptrue_pat_b8(SV_VL32),q3bits_sv,6),m3b_sv)), svreinterpret_s8_u8(q3h_sv));

scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4),svdup_n_s32((int32_t)scale[2]),svdup_n_s32((int32_t)scale[3]));
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8),sumi1_1,svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2),scale_1);
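// As in the 128-bit path: expose the next set of high bits after the
// first 64 weights have been processed.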


if(j==0)
{
qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32),qhbits_sv,4);
}

scale += 4;

}

sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1));
}break;
default:
assert(false && "Unsupported vector length");
break;
}
}
*s = sum;

#elif __ARM_NEON

uint32_t aux[3];
uint32_t utmp[4];