RuntimeError: Expected to mark a variable ready only once. #389

Closed
nforest opened this issue May 24, 2022 · 15 comments
nforest commented May 24, 2022

Hello!

I found that when gradient checkpointing is enabled and I call model.train(), a runtime error is raised.

I hope someone can give me some feedback on how to solve it.

Thanks!

  • Minimal PoC Code
#  grad_checking.py
import torch
import transformers
from accelerate import Accelerator
from transformers import AutoModel, AutoTokenizer

# seems not related to a specific model, longformer/gpt2 can also lead to the crash
model_path = "facebook/opt-125m"

accelerator = Accelerator()

x = ['whoami', 'hello world']
y = ['i am who', 'world hello']
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
x = tokenizer(x, padding=True, return_tensors='pt')
y = tokenizer(y, padding=True, return_tensors='pt')

model = AutoModel.from_pretrained(model_path, local_files_only=True)
model.gradient_checkpointing_enable()
model = accelerator.prepare(model)

# when gradient_checkpointing is enabled, this line below will lead to the error
model.train()

x_embed = model(**x, return_dict=True).last_hidden_state[:, -1, :]
y_embed = model(**y, return_dict=True).last_hidden_state[:, -1, :]
logits = torch.matmul(x_embed, y_embed.t())
loss = transformers.models.clip.modeling_clip.clip_loss(logits)
accelerator.backward(loss)

Command line used to launch the script:
TORCH_DISTRIBUTED_DEBUG=DETAIL accelerate launch grad_checking.py

RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 190 with name decoder.layers.11.fc2.weight has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration.
  • Environment
- `Accelerate` version: 0.9.0
- Platform: Linux-5.4.32-1-tlinux4-0001-x86_64-with-glibc2.2.5
- Python version: 3.8.3
- Numpy version: 1.22.4
- PyTorch version (GPU?): 1.11.0+cu113 (True)
- `Accelerate` default config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: MULTI_GPU
        - mixed_precision: fp16
        - use_cpu: False
        - num_processes: 1
        - machine_rank: 0
        - num_machines: 1
        - main_process_ip: None
        - main_process_port: None
        - main_training_function: main
        - deepspeed_config: {}
        - fsdp_config: {}
sgugger commented May 24, 2022

Hi there. I'm not able to reproduce your error on my side. Also, you are using distributed multi-GPU with only one process (num_processes: 1 in your config); maybe that's what is causing the bug?

nforest commented May 24, 2022

Hi, I actually tried different accelerate configurations, and I can confirm that specifying multiple GPUs leads to the same error on my side.

To reproduce the error, the script must be launched with accelerate launch xxx.py

nforest commented May 24, 2022

- `Accelerate` version: 0.9.0
- Platform: Linux-4.14.105-1-tlinux3-0013-x86_64-with-glibc2.17
- Python version: 3.8.13
- Numpy version: 1.21.6
- PyTorch version (GPU?): 1.11.0+cu113 (True)
- `Accelerate` default config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: MULTI_GPU
        - mixed_precision: bf16
        - use_cpu: False
        - num_processes: 8
        - machine_rank: 0
        - num_machines: 1
        - main_process_ip: None
        - main_process_port: None
        - main_training_function: main
        - deepspeed_config: {}
        - fsdp_config: {}

I tested the PoC code on a different GPU machine with a multi-GPU configuration, and it produces the same error.

sgugger commented May 24, 2022

Looking more, I can reproduce it when I launch on two processes. It's not linked to Accelerate, as I get the same error with vanilla PyTorch DDP:

import torch
import transformers
from transformers import AutoModel, AutoTokenizer
from torch.nn.parallel import DistributedDataParallel


torch.distributed.init_process_group(backend="nccl")
process_index = torch.distributed.get_rank()

# seems not related to a specific model, longformer/gpt2 can also lead to the crash
model_path = "facebook/opt-125m"


x = ['whoami', 'hello world']
y = ['i am who', 'world hello']
tokenizer = AutoTokenizer.from_pretrained(model_path)
x = tokenizer(x, padding=True, return_tensors='pt').to(process_index)
y = tokenizer(y, padding=True, return_tensors='pt').to(process_index)

model = AutoModel.from_pretrained(model_path)
model.gradient_checkpointing_enable()
model = model.to(process_index)
model = DistributedDataParallel(model, device_ids=[process_index], output_device=process_index, find_unused_parameters=True)

# when gradient_checkpointing is enabled, this line below will lead to the error
model.train()

x_embed = model(**x, return_dict=True).last_hidden_state[:, -1, :]
y_embed = model(**y, return_dict=True).last_hidden_state[:, -1, :]
logits = torch.matmul(x_embed, y_embed.t())
loss = transformers.models.clip.modeling_clip.clip_loss(logits)
loss.backward()

It just looks like this model does not support gradient checkpointing with distributed training.
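For reference, the error message itself points to _set_static_graph() as a possible workaround when the module graph does not change across iterations. Below is a minimal sketch of that workaround on top of the DDP repro above, assuming PyTorch 1.11 or newer (where static_graph is also available as a constructor argument); whether it actually resolves this particular case is not verified in this thread.

# Sketch of the static-graph workaround mentioned in the error message.
# Assumes PyTorch 1.11+, where DistributedDataParallel accepts static_graph=True;
# model and process_index are the ones from the repro above.
from torch.nn.parallel import DistributedDataParallel

model = DistributedDataParallel(
    model,
    device_ids=[process_index],
    output_device=process_index,
    static_graph=True,  # tell DDP the set of used parameters is fixed every iteration
)
# On older releases, the equivalent is to call model._set_static_graph()
# right after constructing the DDP wrapper.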

nforest commented May 24, 2022

Thanks for your finding. I searched the transformers repo and found:
huggingface/transformers#15191
huggingface/transformers#7446

However, after trying find_unused_parameters=False and different models (gpt2/longformer), I still get the same error :(
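(For reference: when the model is prepared by Accelerate rather than wrapped by hand, DDP options such as find_unused_parameters are passed through a kwargs handler. A minimal sketch using accelerate's DistributedDataParallelKwargs:)

# Sketch: passing DDP options through Accelerate instead of wrapping the model manually.
from accelerate import Accelerator, DistributedDataParallelKwargs

ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
# the model is then prepared as usual:
# model = accelerator.prepare(model)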

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this as completed Jul 1, 2022
afalf commented Apr 2, 2023

However, after I tried with find_unused_parameters=False and different models (gpt2/longformer), it still has the same error :(

Hi, I have met the same issue. Have you solved it yet?

@wei-hongbin

Get rid of this line of code:
model.gradient_checkpointing_enable()

@skye95git

Hi, I have met the same issue. Have you solved it yet?

zr-bee commented May 5, 2023

Have you solved it yet?

@thusinh1969

Throw away gradient_checkpointing; set it to False.
Steve

tb2-sy commented Jul 15, 2023

It just looks like this model does not support gradient checkpointing with distributed training.

Hi, I have met the same issue with a vanilla PyTorch DDP model. Have you solved it yet?

tb2-sy commented Jul 15, 2023

Hello, I encountered this problem with a vanilla PyTorch DDP model. It does not seem to have a setting for enabling gradient_checkpointing. Is there any other way?
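For reference, a plain nn.Module without a gradient_checkpointing_enable() method can apply activation checkpointing directly via torch.utils.checkpoint. A minimal sketch with made-up Block/Net classes, assuming a PyTorch version whose checkpoint() accepts use_reentrant; the non-reentrant path is the one also suggested further below for transformers models.

# Sketch: manual activation checkpointing in a plain PyTorch module.
# Block and Net are illustrative only, not taken from this issue.
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.ff = torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.ReLU())

    def forward(self, x):
        return self.ff(x)

class Net(torch.nn.Module):
    def __init__(self, dim=64, depth=4):
        super().__init__()
        self.blocks = torch.nn.ModuleList(Block(dim) for _ in range(depth))

    def forward(self, x):
        for block in self.blocks:
            # checkpoint each block; use_reentrant=False selects the non-reentrant
            # implementation, which tends to avoid the "marked ready twice" DDP error
            x = checkpoint(block, x, use_reentrant=False)
        return x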

tb2-sy commented Jul 15, 2023

Get rid of this line of code: model.gradient_checkpointing_enable()

Hello, I encountered this problem with a vanilla PyTorch DDP model, so there is no gradient_checkpointing_enable() call to remove. Is there any other way?

@shanguanma

You can try model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
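For context, that call only works on a transformers release new enough for gradient_checkpointing_enable() to accept gradient_checkpointing_kwargs (roughly 4.35 and later). A minimal sketch of the original PoC adapted to it, under that assumption:

# Sketch: non-reentrant gradient checkpointing with Accelerate.
# Assumes transformers >= ~4.35, where gradient_checkpointing_enable()
# accepts gradient_checkpointing_kwargs.
from accelerate import Accelerator
from transformers import AutoModel

accelerator = Accelerator()
model = AutoModel.from_pretrained("facebook/opt-125m")
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model = accelerator.prepare(model)
model.train()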
