RuntimeError: Expected to mark a variable ready only once. #389
Hi there. I'm not able to reproduce your error on my side. Also, you are using distributed multi-GPU but with only one GPU.
Hi, actually I tried different Accelerate configurations, and I can confirm that specifying multiple GPUs leads to the same error on my side. To reproduce the error, the script must be launched with `accelerate launch`.
- `Accelerate` version: 0.9.0
- Platform: Linux-4.14.105-1-tlinux3-0013-x86_64-with-glibc2.17
- Python version: 3.8.13
- Numpy version: 1.21.6
- PyTorch version (GPU?): 1.11.0+cu113 (True)
- `Accelerate` default config:
  - compute_environment: LOCAL_MACHINE
  - distributed_type: MULTI_GPU
  - mixed_precision: bf16
  - use_cpu: False
  - num_processes: 8
  - machine_rank: 0
  - num_machines: 1
  - main_process_ip: None
  - main_process_port: None
  - main_training_function: main
  - deepspeed_config: {}
  - fsdp_config: {}

I tested the PoC code on a different GPU machine with a multi-GPU configuration, and it hits the same error.
Looking more, I can reproduce when I launch on two processes. It's not linked to Accelerate, as I get the same error with vanilla PyTorch DDP:

```python
import torch
import transformers
from transformers import AutoModel, AutoTokenizer
from torch.nn.parallel import DistributedDataParallel

torch.distributed.init_process_group(backend="nccl")
process_index = torch.distributed.get_rank()

# Seems not related to a specific model: longformer/gpt2 also lead to the crash.
model_path = "facebook/opt-125m"
x = ["whoami", "hello world"]
y = ["i am who", "world hello"]
tokenizer = AutoTokenizer.from_pretrained(model_path)
x = tokenizer(x, padding=True, return_tensors="pt").to(process_index)
y = tokenizer(y, padding=True, return_tensors="pt").to(process_index)

model = AutoModel.from_pretrained(model_path)
# Enabling gradient checkpointing is what triggers the error in backward().
model.gradient_checkpointing_enable()
model = model.to(process_index)
model = DistributedDataParallel(
    model,
    device_ids=[process_index],
    output_device=process_index,
    find_unused_parameters=True,
)

model.train()
x_embed = model(**x, return_dict=True).last_hidden_state[:, -1, :]
y_embed = model(**y, return_dict=True).last_hidden_state[:, -1, :]
logits = torch.matmul(x_embed, y_embed.t())
loss = transformers.models.clip.modeling_clip.clip_loss(logits)
loss.backward()  # RuntimeError: Expected to mark a variable ready only once.
```

It just looks like this model does not support gradient checkpointing with distributed training.
Thanks for your finding, and I searched the transformers repo. However, after I tried find_unused_parameters=False and different models (gpt2/longformer), it still hits the same error :(
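For reference, a workaround commonly suggested for this combination (not from this thread) is non-reentrant activation checkpointing, which cooperates with DDP's reducer better than the default reentrant variant. A minimal sketch, assuming a transformers release where `gradient_checkpointing_enable()` accepts `gradient_checkpointing_kwargs` (newer than the versions in this issue):

```python
import torch
from torch.nn.parallel import DistributedDataParallel
from transformers import AutoModel

torch.distributed.init_process_group(backend="nccl")
rank = torch.distributed.get_rank()

model = AutoModel.from_pretrained("facebook/opt-125m")
# Non-reentrant checkpointing fires each parameter's autograd hook only once
# per backward, which typically avoids the DDP error discussed above.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model = model.to(rank)

# When all parameters are used, find_unused_parameters=False also keeps the
# reducer out of the code path that raises the error.
model = DistributedDataParallel(
    model, device_ids=[rank], output_device=rank, find_unused_parameters=False
)
```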
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi, I have met the same issue. Have you solved it yet?
Get rid of this line of code:
> Hi, I have met the same issue. Have you solved it yet?
Throw away gradient_checkpointing, i.e. set it to False.
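A minimal illustration of that suggestion, assuming a standard transformers model (the toggle names come from the transformers API, not from this thread's script):

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("facebook/opt-125m")
# Do not call gradient_checkpointing_enable(); if something upstream enabled it,
# turn it off explicitly before wrapping the model in DDP.
model.gradient_checkpointing_disable()
# With the HF Trainer, the equivalent is gradient_checkpointing=False in TrainingArguments.
```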
Hi, I have met the same issue with a vanilla PyTorch DDP model. Have you solved it yet?
Hello, I encountered this problem with a vanilla PyTorch DDP model. There does not seem to be a gradient_checkpointing setting I can turn off there. Is there any other way?
> Hello, I encountered this problem with a vanilla PyTorch DDP model. There does not seem to be a gradient_checkpointing setting I can turn off there. Is there any other way?
You can try
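For a plain PyTorch model with no transformers-style gradient_checkpointing toggle, one general option (an assumption of this sketch, not necessarily what was suggested above) is the non-reentrant mode of `torch.utils.checkpoint`. A minimal sketch with hypothetical toy modules, assuming a PyTorch release where `checkpoint()` accepts `use_reentrant`:

```python
# Hedged sketch: non-reentrant activation checkpointing in plain PyTorch.
# Block/Net are hypothetical toy modules used only for illustration.
import torch
from torch.utils.checkpoint import checkpoint


class Block(torch.nn.Module):
    def __init__(self, dim=128):
        super().__init__()
        self.ff = torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.ReLU())

    def forward(self, x):
        return self.ff(x)


class Net(torch.nn.Module):
    def __init__(self, dim=128, depth=4):
        super().__init__()
        self.blocks = torch.nn.ModuleList([Block(dim) for _ in range(depth)])

    def forward(self, x):
        for block in self.blocks:
            # Non-reentrant checkpointing runs each parameter's autograd hook
            # only once per backward, which typically avoids the DDP error here.
            x = checkpoint(block, x, use_reentrant=False)
        return x


# The Net instance would then be wrapped in DistributedDataParallel as usual.
```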
Hello!
I found that when I try to train with gradient_checkpointing enabled (and call model.train()), a runtime error is raised.
I hope someone can give me some feedback on how to solve it.
Thanks!
Command line to launch the script:
`TORCH_DISTRIBUTED_DEBUG=DETAIL accelerate launch grad_checking.py`
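A hypothetical reconstruction of what a minimal `grad_checking.py` could look like (the original script is not shown in this thread); it mirrors the vanilla-DDP snippet above but goes through Accelerate:

```python
# Hypothetical repro sketch; model choice and loss follow the DDP snippet above.
import torch
import transformers
from accelerate import Accelerator
from transformers import AutoModel, AutoTokenizer

accelerator = Accelerator()
device = accelerator.device

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
x = tokenizer(["whoami", "hello world"], padding=True, return_tensors="pt").to(device)
y = tokenizer(["i am who", "world hello"], padding=True, return_tensors="pt").to(device)

model = AutoModel.from_pretrained("facebook/opt-125m")
model.gradient_checkpointing_enable()  # the setting that triggers the crash
model = accelerator.prepare(model)     # wraps the model in DDP under multi-GPU

model.train()
x_embed = model(**x, return_dict=True).last_hidden_state[:, -1, :]
y_embed = model(**y, return_dict=True).last_hidden_state[:, -1, :]
logits = torch.matmul(x_embed, y_embed.t())
loss = transformers.models.clip.modeling_clip.clip_loss(logits)
accelerator.backward(loss)  # RuntimeError: Expected to mark a variable ready only once.
```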