Mamba slow_forward uses inplace operations causing errors in the backward pass #29526

Comments
I am also surprised, because we have 3 tests that make sure the gradient propagates, and we have trained / seen successful training.
Yeah, I'm also quite confused as to what causes it to be seen as an in-place operation. Following @ArthurZucker's comment, I looked into the loop again and tried just cloning the initial ssm_state; more concretely, L214 seems to be the cause of the issue, and changing it to use a cloned state fixes the error. If I understand Mamba correctly, it is to be expected that memory explodes in the slow variant? If I have time later today or over the weekend I'll submit a PR. Additionally, I'd suggest rewriting the initial warning to also advise against using the cache.
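For intuition, here is a toy example of my own (not the actual modeling_mamba code) showing why autograd reports an in-place operation: the multiplication saves the state tensor for its backward, and a later in-place cache update of that same tensor bumps its version counter, which is exactly what cloning avoids.

```python
import torch

w = torch.rand(2, 3, requires_grad=True)
state = torch.zeros(2, 3)      # stand-in for the cached ssm_state

out = w * state                # autograd saves `state` to compute the gradient w.r.t. `w`
state.copy_(out)               # in-place cache update bumps the version counter of `state`
out.sum().backward()           # raises: "... modified by an inplace operation ..."

# Cloning gives the graph a private copy that the cache update cannot touch:
# out = w * state.clone()
```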
Will have a look at the PR, but yes, the slow version needs a custom forward and backward that saves explicitly for the backward pass, instead of relying on the autograd engine.
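As an illustration of that idea, here is a minimal sketch of such a custom autograd.Function, using a toy elementwise recurrence h_i = a_i * h_{i-1} + b_i as a stand-in for the SSM scan (the class name, shapes, and simplified recurrence are my own assumptions, not the transformers implementation): the forward saves the intermediate states explicitly, and the backward walks the loop in reverse instead of relying on the engine's saved tensors.

```python
import torch


class LinearRecurrence(torch.autograd.Function):
    """Toy stand-in for the SSM scan: h_i = a_i * h_{i-1} + b_i (elementwise)."""

    @staticmethod
    def forward(ctx, a, b, h0):
        # a, b: (batch, seq_len, dim); h0: (batch, dim)
        states = [h0]
        for i in range(a.shape[1]):
            states.append(a[:, i] * states[-1] + b[:, i])  # out-of-place update
        hs = torch.stack(states[1:], dim=1)                # (batch, seq_len, dim)
        # Save the *previous* state of every step explicitly for the backward pass.
        ctx.save_for_backward(a, torch.stack(states[:-1], dim=1))
        return hs

    @staticmethod
    def backward(ctx, grad_hs):
        a, prev = ctx.saved_tensors
        grad_a = torch.zeros_like(a)
        grad_b = torch.zeros_like(a)
        grad_h = torch.zeros_like(prev[:, 0])    # gradient flowing into h_i from later steps
        for i in reversed(range(a.shape[1])):
            grad_h = grad_h + grad_hs[:, i]      # total gradient reaching h_i
            grad_a[:, i] = grad_h * prev[:, i]   # dL/da_i = dL/dh_i * h_{i-1}
            grad_b[:, i] = grad_h                # dL/db_i = dL/dh_i
            grad_h = grad_h * a[:, i]            # propagate to h_{i-1}
        return grad_a, grad_b, grad_h            # grad_h is now dL/dh0


# Quick check of the hand-written backward against numerical gradients.
a = torch.rand(2, 5, 3, dtype=torch.double, requires_grad=True)
b = torch.rand(2, 5, 3, dtype=torch.double, requires_grad=True)
h0 = torch.rand(2, 3, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(LinearRecurrence.apply, (a, b, h0))
```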
System Info

transformers version: 4.39.0.dev0 (commit hash 923733c)

Who can help?
@ArthurZucker
Information

Tasks

An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
I'm using Google Colab for demonstration purposes (the issue originally occurred in a private project). It is important to note that it uses slow_forward, since I haven't upgraded CUDA on my private machine yet and therefore tried out the slow version for now.

!pip install git+https://github.com/huggingface/transformers.git@923733c22bf4d3cc6661c8cd3b730b275e9a938e

(a Mamba-compatible version)

The backward pass then fails with:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 1536, 16]], which is output 0 of torch::autograd::CopyBackwards, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
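For concreteness, here is a minimal sketch of the kind of forward/backward step that produces the error above. The checkpoint name and the exact script are assumptions of mine, not the reporter's code; the slow path is taken when the mamba-ssm / causal-conv1d CUDA kernels are not installed.

```python
import torch
from transformers import AutoTokenizer, MambaForCausalLM

# Assumed checkpoint; any Mamba checkpoint run through the slow path should
# behave the same. Without mamba-ssm / causal-conv1d, slow_forward is used.
model_id = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = MambaForCausalLM.from_pretrained(model_id).cuda()

inputs = tokenizer("Hello world", return_tensors="pt").to("cuda")
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()   # fails with the RuntimeError shown above
```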
Expected behavior
Backward passes can be performed even under the slow variant of the forward pass.
Using anomaly detection in autograd (i.e. torch.autograd.detect_anomaly) reveals the issue to be at L256 of modeling_mamba.
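For reference, this is plain PyTorch anomaly detection wrapped around the failing step (continuing the hypothetical snippet from the Reproduction section above):

```python
# With anomaly detection enabled, the backward error includes a traceback
# pointing at the forward-pass operation whose output was later modified.
with torch.autograd.set_detect_anomaly(True):
    outputs = model(**inputs, labels=inputs["input_ids"])
    outputs.loss.backward()
```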
A possible "dirty" solution would be to copy the ssm_state as in
ssm_state = discrete_A[:, :, i, :] * ssm_state.clone() + deltaB_u[:, :, i, :]
. Maybe there is something better.The text was updated successfully, but these errors were encountered: