Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel RoPE on metal #3024

Merged
merged 2 commits on
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
case GGML_OP_DUP:
case GGML_OP_CPY:
Expand Down
26 changes: 14 additions & 12 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -699,25 +699,27 @@ kernel void kernel_rope(
constant int & mode,
constant float & freq_base,
constant float & freq_scale,
uint3 tpig[[thread_position_in_grid]]) {
const int64_t i3 = tpig[2];
const int64_t i2 = tpig[1];
const int64_t i1 = tpig[0];
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int64_t i3 = tgpig[2];
const int64_t i2 = tgpig[1];
const int64_t i1 = tgpig[0];

const bool is_neox = mode & 2;
const float theta_scale = pow(freq_base, -2.0f/n_dims);

const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

float theta = freq_scale * (float)p;
const float theta_0 = freq_scale * (float)p;
const float inv_ndims = -1.f/n_dims;

if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {

const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);

theta *= theta_scale;

device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

Expand All @@ -729,12 +731,12 @@ kernel void kernel_rope(
}
} else {
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {

const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
cebtenzzre marked this conversation as resolved.
Show resolved Hide resolved
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);

theta *= theta_scale;

const int64_t i0 = ib*n_dims + ic/2;

device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
Expand Down