From 9064b1ca051c6ebff04a3ad9b77d9f8d309396a6 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 9 Dec 2023 14:04:54 +0200
Subject: [PATCH] ggml : fix ggml_get_rows to take into account ne02 / ne11

---
 ggml.c | 63 ++++++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 41 insertions(+), 22 deletions(-)

diff --git a/ggml.c b/ggml.c
index 9982c2dade94e..4bdb702480bc9 100644
--- a/ggml.c
+++ b/ggml.c
@@ -10342,20 +10342,27 @@ static void ggml_compute_forward_get_rows_q(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
     const enum ggml_type type = src0->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
 
-    assert( dst->ne[0] == nc);
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == ggml_type_size(type));
     assert(ggml_nrows(dst) == nr);
-    assert(src0->nb[0] == ggml_type_size(type));
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    // TODO: multi-thread
+    for (int64_t i = 0; i < nr; ++i) {
+        const int64_t r = ((int32_t *) src1->data)[i];
+
+        const int64_t i02 = i/ne10;
 
         dequantize_row_q(
-                (const void *) ((char *) src0->data + r*src0->nb[1]),
+                (const void *) ((char *) src0->data + i02*nb02 + r*nb01),
                      (float *) ((char *)  dst->data + i*dst->nb[1]), nc);
     }
 }
@@ -10371,19 +10378,25 @@ static void ggml_compute_forward_get_rows_f16(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
 
-    assert( dst->ne[0] == nc);
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(ggml_fp16_t));
     assert(ggml_nrows(dst) == nr);
-    assert(src0->nb[0] == sizeof(ggml_fp16_t));
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    // TODO: multi-thread
+    for (int64_t i = 0; i < nr; ++i) {
+        const int64_t r = ((int32_t *) src1->data)[i];
+
+        const int64_t i02 = i/ne10;
 
         for (int j = 0; j < nc; ++j) {
-            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
-            ((float *) ((char *)  dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
+            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j];
+            ((float *) ((char *)  dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
         }
     }
 }
@@ -10399,19 +10412,25 @@ static void ggml_compute_forward_get_rows_f32(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
 
-    assert( dst->ne[0] == nc);
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(float));
     assert(ggml_nrows(dst) == nr);
-    assert(src0->nb[0] == sizeof(float));
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    // TODO: multi-thread
+    for (int64_t i = 0; i < nr; ++i) {
+        const int64_t r = ((int32_t *) src1->data)[i];
+
+        const int64_t i02 = i/ne10;
 
         ggml_vec_cpy_f32(nc,
                 (float *) ((char *)  dst->data + i*dst->nb[1]),
-                (float *) ((char *) src0->data + r*src0->nb[1]));
+                (float *) ((char *) src0->data + i02*nb02 + r*nb01));
    }
 }
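
Note (not part of the patch): the sketch below is a minimal standalone illustration of the indexing scheme the fix introduces. With src0 of shape [ne00, ne01, ne02] and src1 holding ne10 row indices for each of ne11 == ne02 batches, each flattened index i gathers row r = src1[i] from the (i/ne10)-th src0 matrix. The array layout, names, and dimension values here are made-up assumptions for illustration; only the loop structure mirrors the patched ggml_compute_forward_get_rows_f32.

    #include <stdio.h>
    #include <stdint.h>

    int main(void) {
        enum { ne00 = 4, ne01 = 3, ne02 = 2 };   // src0: 2 matrices of 3 rows x 4 cols
        enum { ne10 = 2, ne11 = ne02 };          // src1: 2 row indices per matrix

        // fill src0 so each element encodes its own coordinates: 100*i02 + 10*i01 + i00
        float src0[ne02][ne01][ne00];
        for (int i02 = 0; i02 < ne02; ++i02)
            for (int i01 = 0; i01 < ne01; ++i01)
                for (int i00 = 0; i00 < ne00; ++i00)
                    src0[i02][i01][i00] = 100.0f*i02 + 10.0f*i01 + i00;

        const int32_t src1[ne11*ne10] = { 2, 0, 1, 1 };  // rows to gather

        // same loop structure as the patched kernel:
        const int64_t nr = ne11*ne10;            // ggml_nelements(src1)
        for (int64_t i = 0; i < nr; ++i) {
            const int64_t r   = src1[i];
            const int64_t i02 = i/ne10;          // the fix: pick the matching src0 matrix
            printf("dst row %lld <- src0[%lld][%lld][:] =",
                    (long long) i, (long long) i02, (long long) r);
            for (int j = 0; j < ne00; ++j) {
                printf(" %g", src0[i02][r][j]);
            }
            printf("\n");
        }
        return 0;
    }

Before the fix, the loop ignored i02 and always read from the first matrix (r*src0->nb[1] from the base pointer); rows gathered for the second batch would have come from src0[0] instead of src0[1].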