Skip to content

Commit

Permalink
cuda : better rope implementation
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed May 30, 2024
1 parent fb97b9e commit 4739018
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 88 deletions.
151 changes: 63 additions & 88 deletions ggml-cuda/rope.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,67 +28,38 @@ static __device__ void rope_yarn(
*sin_theta = sinf(theta) * mscale;
}

//// rope == RoPE == rotary positional embedding
//template<typename T, bool has_ff>
//static __global__ void rope_norm(
// const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
// float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
// const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
//
// if (col >= ncols) {
// return;
// }
//
// const int row = blockDim.x*blockIdx.x + threadIdx.x;
// const int i = row*ncols + col;
// const int i2 = row/p_delta_rows;
//
// const float theta_base = pos[i2]*powf(freq_base, -float(col)/ncols);
//
// const float freq_factor = has_ff ? freq_factors[col/2] : 1.0f;
//
// float cos_theta, sin_theta;
// rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
//
// const float x0 = x[i + 0];
// const float x1 = x[i + 1];
//
// dst[i + 0] = x0*cos_theta - x1*sin_theta;
// dst[i + 1] = x0*sin_theta + x1*cos_theta;
//}

template<typename T, bool has_ff>
static __global__ void rope_norm(
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

if (col >= ncols) {
if (i0 >= ne0) {
return;
}

const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int ib = col / n_dims;
const int ic = col % n_dims;

if (ib > 0) {
const int i = row*ncols + ib*n_dims + ic;
if (i0 >= n_dims) {
const int i = row*ne0 + i0;

dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];

return;
}

const int i = row*ncols + ib*n_dims + ic;
const int i = row*ne0 + i0;
const int i2 = row/p_delta_rows;

const float theta_base = pos[i2]*powf(theta_scale, col/2.0f);
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);

const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

const float freq_factor = has_ff ? freq_factors[ic/2] : 1.0f;
float cos_theta;
float sin_theta;

float cos_theta, sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

const float x0 = x[i + 0];
const float x1 = x[i + 1];
Expand All @@ -99,36 +70,36 @@ static __global__ void rope_norm(

template<typename T, bool has_ff>
static __global__ void rope_neox(
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

if (col >= ncols) {
if (i0 >= ne0) {
return;
}

const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int ib = col / n_dims;
const int ic = col % n_dims;

if (ib > 0) {
const int i = row*ncols + ib*n_dims + ic;
if (i0 >= n_dims) {
const int i = row*ne0 + i0;

dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];

return;
}

const int i = row*ncols + ib*n_dims + ic/2;
const int i = row*ne0 + i0/2;
const int i2 = row/p_delta_rows;

const float theta_base = pos[i2]*powf(theta_scale, col/2.0f);
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);

const float freq_factor = has_ff ? freq_factors[ic/2] : 1.0f;
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

float cos_theta, sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
float cos_theta;
float sin_theta;

rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

const float x0 = x[i + 0];
const float x1 = x[i + n_dims/2];
Expand All @@ -139,79 +110,79 @@ static __global__ void rope_neox(

template<typename T>
static void rope_norm_cuda(
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);

const float theta_scale = powf(freq_base, -2.0f/n_dims);

if (freq_factors == nullptr) {
rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors
);
} else {
rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors
);
}
}

template<typename T>
static void rope_neox_cuda(
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);

const float theta_scale = powf(freq_base, -2.0f/n_dims);

if (freq_factors == nullptr) {
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors
);
} else {
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors
);
}
}

static void rope_norm_cuda_f16(
const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {

rope_norm_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

static void rope_norm_cuda_f32(
const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {

rope_norm_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

static void rope_neox_cuda_f16(
const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {

rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

static void rope_neox_cuda_f32(
const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {

rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
Expand All @@ -232,30 +203,34 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t nrows = ggml_nrows(src0);
const int64_t nr = ggml_nrows(src0);

//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
//const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
//const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];

// RoPE alteration for extended context
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;

memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

const float * freq_factors = nullptr;
const int32_t * pos = nullptr;

const bool is_neox = mode & 2;

pos = (const int32_t *) src1_d;
const int32_t * pos = (const int32_t *) src1_d;

const float * freq_factors = nullptr;
if (src2 != nullptr) {
freq_factors = (const float *) src2->data;
}
Expand All @@ -267,12 +242,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (is_neox) {
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else {
Expand All @@ -281,12 +256,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
} else {
if (src0->type == GGML_TYPE_F32) {
rope_norm_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_norm_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else {
Expand Down
1 change: 1 addition & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2252,6 +2252,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
}
}

all = false;
}
}
Expand Down

0 comments on commit 4739018

Please sign in to comment.