Mamba slow_forward uses inplace operations causing errors in the backward pass #29526

Closed
2 of 4 tasks
vasqu opened this issue Mar 7, 2024 · 4 comments

@vasqu
Contributor

vasqu commented Mar 7, 2024

System Info

  • transformers version: 4.39.0.dev0 (commit hash 923733c)
  • Platform: Linux-6.5.0-10022-tuxedo-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm using Google Colab for demonstration purposes (the problem originally occurred in a private project). Note that this goes through slow_forward: I haven't upgraded CUDA on my private machine yet, so I tried out the slow version for now.

  1. !pip install git+https://github.com/huggingface/transformers.git@923733c22bf4d3cc6661c8cd3b730b275e9a938e (a Mamba-compatible version)
  2. The example script itself:
import torch
from transformers import MambaModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
mamba = MambaModel.from_pretrained("state-spaces/mamba-130m-hf").to('cuda')
optimizer = torch.optim.SGD(mamba.parameters(), lr=0.01, momentum=0.9)

inputs = tokenizer("Hello, this is a test sample to demonstrate the backward pass issue in mamba", return_tensors="pt").to('cuda')
outputs = mamba(**inputs).last_hidden_state

mamba.zero_grad()
optimizer.zero_grad()

loss = torch.log(1 + torch.abs(outputs.sum())).to('cuda')
loss.backward()
  3. The following error will occur: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 1536, 16]], which is output 0 of torch::autograd::CopyBackwards, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Expected behavior

Backward passes can be performed even under the slow variant of the forward pass.

Using anomaly detection in autograd (i.e. torch.autograd.detect_anomaly) reveals the issue to be in L256 of modeling_mamba.
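For reference, this is roughly how the repro above can be wrapped to surface the offending forward-pass line (a sketch reusing the mamba/inputs objects from the reproduction script):

# Sketch: rerun the forward/backward under anomaly detection so the traceback
# also points at the forward op that produced the tensor flagged in backward.
with torch.autograd.detect_anomaly():
    outputs = mamba(**inputs).last_hidden_state
    loss = torch.log(1 + torch.abs(outputs.sum()))
    loss.backward()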

A possible "dirty" solution would be to copy the ssm_state as in ssm_state = discrete_A[:, :, i, :] * ssm_state.clone() + deltaB_u[:, :, i, :]. Maybe there is something better.
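For illustration, a simplified sketch of where that change would sit in the recurrence (the surrounding loop is paraphrased from the identifiers quoted above, not copied from the actual modeling_mamba source):

# Hypothetical, simplified view of the slow scan; only the .clone() on the
# right-hand side is the proposed change. Note that it allocates one extra
# (batch, intermediate_size, ssm_state_size) tensor per timestep.
for i in range(seq_len):
    ssm_state = discrete_A[:, :, i, :] * ssm_state.clone() + deltaB_u[:, :, i, :]
    scan_outputs.append(torch.matmul(ssm_state, C[:, i, :].unsqueeze(-1)).squeeze(-1))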

@amyeroberts
Collaborator

Hi @vasqu, thanks for raising and for taking the time to dig into this. This is a duplicate of #29514

Very happy to review any PR to fix this. I'm surprised the issue is coming from ssm_state being used here - a = a * b wouldn't normally be considered an in-place operation (as opposed to a *= b).
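The multiplication itself is indeed out of place; the version-counter error can still be triggered if a tensor that an out-of-place op saved for its backward is later mutated in place, e.g. by a copy_ write-back into a cache. A standalone sketch of that mechanism, with purely illustrative names (the exact error wording differs from the traceback above):

import torch

# An out-of-place multiply saves `state` for w's gradient; the later in-place
# copy_ into that same tensor bumps its version counter, so backward() raises
# the "modified by an inplace operation" error.
state = torch.ones(3)                     # plays the role of the cached ssm_state
w = torch.randn(3, requires_grad=True)

out = w * state                           # out of place, but `state` is saved for backward
state.copy_(out.detach())                 # in-place write-back into the "cache"
out.sum().backward()                      # RuntimeError: ... modified by an inplace operation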

@ArthurZucker
Collaborator

I am also surprised, because we have 3 tests that make sure gradients propagate, and we have trained / seen successful training.
I would recommend setting use_cache to False; that should fix it.
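Applied to the repro above, that advice would presumably look like this (untested sketch):

# Pass use_cache=False so no cached ssm_state is written back in place during
# a forward pass that is part of a graph requiring gradients.
outputs = mamba(**inputs, use_cache=False).last_hidden_state
loss = torch.log(1 + torch.abs(outputs.sum()))
loss.backward()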

@vasqu
Contributor Author

vasqu commented Mar 8, 2024

Yeah, I'm also quite confused as to what causes it to be seen as an in-place operation.

Following @ArthurZucker's comment, I looked into the loop again and tried just cloning the initial ssm_state. More concretely, L214 seems to be the cause of the issue. Changing it to ssm_state = cache_params.ssm_states[self.layer_idx].clone() has a similar effect (at first glance) to turning the cache off. My first suggestion does work, but cloning inside the loop explodes the memory requirements even more.

If I understand Mamba correctly, it is to be expected that memory explodes in the slow variant? If I have time later today or over the weekend I'll submit a PR. Additionally, I'd suggest rewriting the initial warning to also advise against using the cache.

@ArthurZucker
Collaborator

Will have a look at the PR, but yes, the slow version needs a custom forward and backward that saves explicitly for the backward pass, instead of relying on the autograd engine.
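For anyone picking this up, here is a toy illustration of that pattern, i.e. a custom autograd.Function whose forward saves exactly the tensors its hand-written backward needs. It uses a simple linear recurrence h_t = a_t * h_{t-1} + b_t, not the actual selective scan, so everything below is illustrative rather than a drop-in fix:

import torch


class ToyScan(torch.autograd.Function):
    """Toy stand-in for a hand-written scan: h_t = a_t * h_{t-1} + b_t."""

    @staticmethod
    def forward(ctx, a, b, h0):
        # a, b: (seq_len, dim); h0: (dim,)
        h = h0
        hs = []
        for t in range(a.shape[0]):
            h = a[t] * h + b[t]
            hs.append(h)
        hs = torch.stack(hs)  # (seq_len, dim)
        # Save exactly what the backward needs instead of letting the engine
        # record every intermediate of the Python loop.
        ctx.save_for_backward(a, h0, hs)
        return hs

    @staticmethod
    def backward(ctx, grad_hs):
        a, h0, hs = ctx.saved_tensors
        grad_a = torch.zeros_like(a)
        grad_b = torch.zeros_like(a)
        grad_h = torch.zeros_like(h0)
        # Reverse-time accumulation: dL/da_t = dL/dh_t * h_{t-1},
        # dL/db_t = dL/dh_t, and dL/dh_{t-1} += a_t * dL/dh_t.
        for t in range(a.shape[0] - 1, -1, -1):
            grad_h = grad_h + grad_hs[t]
            prev = hs[t - 1] if t > 0 else h0
            grad_a[t] = grad_h * prev
            grad_b[t] = grad_h
            grad_h = grad_h * a[t]
        return grad_a, grad_b, grad_h


# Sanity-check the hand-written backward against finite differences.
a = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
b = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
h0 = torch.rand(3, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(ToyScan.apply, (a, b, h0))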
