
RoPE: fix back, CUDA support for back + noncont. #11240

Merged · 2 commits into ggerganov:master on Jan 15, 2025

Conversation

JohannesGaessler (Collaborator) commented:

This PR fixes the backward pass for RoPE. On master, test-backend-ops grad fails on a related assert. The backward pass can be constructed relatively simply by creating the tensor as for the forward pass and then changing the op from GGML_OP_ROPE to GGML_OP_ROPE_BACK (sketched below). One could instead set ggml_tensor.op_params, but I don't think that would reduce the overall complexity.
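As a rough illustration of that approach (a minimal sketch, not the actual diff of this PR; ggml_rope_ext and the GGML_OP_* values are from ggml.h, while the wrapper name ggml_rope_ext_back_sketch is made up here):

// Sketch only: build the node exactly like the forward op, then swap its op.
static struct ggml_tensor * ggml_rope_ext_back_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,  // incoming gradients w.r.t. the RoPE output
        struct ggml_tensor  * b,  // positions
        struct ggml_tensor  * c,  // optional frequency factors (may be NULL)
        int n_dims, int mode, int n_ctx_orig,
        float freq_base, float freq_scale, float ext_factor,
        float attn_factor, float beta_fast, float beta_slow) {
    // reuse the forward constructor so shape and op_params match the forward op
    struct ggml_tensor * result = ggml_rope_ext(
            ctx, a, b, c, n_dims, mode, n_ctx_orig,
            freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
    // mark the node as the backward op instead of the forward op
    result->op = GGML_OP_ROPE_BACK;
    return result;
}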

This PR also adds CUDA support for the RoPE backward pass and for non-contiguous inputs. The latter is needed for the backward pass of the KV cache. I also added __restrict__ and const where applicable and simplified the templating a bit.
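For the non-contiguous support, the general pattern (shown here only as an illustrative standalone kernel, not the kernel from this PR) is to address the input through its byte strides instead of assuming a dense layout, with const and __restrict__ on the pointers:

#include <cstddef>
#include <cstdint>

// Rotate consecutive element pairs of a possibly non-contiguous f32 source
// into a contiguous destination. nb0/nb1 are byte strides of the source.
// Simplified stand-in: real RoPE derives sin/cos per position and per
// dimension, which is omitted here.
__global__ void rotate_pairs_noncont(
        const char * __restrict__ src, float * __restrict__ dst,
        const int64_t ne0, const int64_t ne1,   // elements per row, number of rows
        const size_t  nb0, const size_t  nb1,   // source byte strides
        const float sin_theta, const float cos_theta) {
    const int64_t i0 = 2*(int64_t)(blockIdx.x*blockDim.x + threadIdx.x);
    const int64_t i1 = blockIdx.y;

    if (i0 + 1 >= ne0 || i1 >= ne1) {
        return;
    }

    // non-contiguous loads: compute byte offsets from the strides
    const float x0 = *(const float *)(src +  i0     *nb0 + i1*nb1);
    const float x1 = *(const float *)(src + (i0 + 1)*nb0 + i1*nb1);

    // contiguous stores into the destination
    dst[i1*ne0 + i0    ] = x0*cos_theta - x1*sin_theta;
    dst[i1*ne0 + i0 + 1] = x0*sin_theta + x1*cos_theta;
}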

test-backend-ops already tests GGML_OP_ROPE_BACK implicitly via gradients; I also added an explicit test.

github-actions bot added the testing (Everything test related), Nvidia GPU (Issues specific to Nvidia GPUs), and ggml (changes relating to the ggml tensor library for machine learning) labels on Jan 14, 2025.
ggerganov (Owner) left a comment:


Now that non-contiguous RoPE is supported, the extra ggml_cont calls in build_minicpm3 and build_deepseek2 could likely be removed:

llama.cpp/src/llama.cpp, lines 6499 to 6514 at 39509fb:
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
q_pe = ggml_rope_ext(
        ctx0, q_pe, inp_pos, nullptr,
        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
        ext_factor, attn_factor_scaled, beta_fast, beta_slow
);
cb(q_pe, "q_pe", il);

// shared RoPE key
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
k_pe = ggml_rope_ext(
        ctx0, k_pe, inp_pos, nullptr,
        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
        ext_factor, attn_factor_scaled, beta_fast, beta_slow
);
cb(k_pe, "k_pe", il);

But it would be good to test these models just in case.
If you don't have the models handy, you can simply update the TODO comments to ask for this test, so that somebody else can do it in the future.
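For reference, one way the suggested simplification could look (an untested sketch, not the change that was merged in this PR):

// the ggml_cont is no longer needed now that RoPE supports non-contiguous inputs
// TODO: verify DeepSeek2/MiniCPM3 outputs once the ggml_cont calls are removed
q_pe = ggml_rope_ext(
        ctx0, q_pe, inp_pos, nullptr,
        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
        ext_factor, attn_factor_scaled, beta_fast, beta_slow
);
cb(q_pe, "q_pe", il);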

ggerganov (Owner) left a comment:


The build_minicpm3 function has the same comments, which should also be updated.

JohannesGaessler merged commit 432df2d into ggerganov:master on Jan 15, 2025 · 48 checks passed
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request on Jan 15, 2025.