[Bug] transformers TPU support broken on v4.45.0 (#34176)
Comments
@steveepreston you confirmed it works on 4.44.0?
Hey @muellerzr I tested the versions below, from 4.43.1 to 4.45.2, one by one. For each test I fully restarted the session/kernel.
Same error on the dev build.

Test History for trainer.py on v4.45.0

Test History for training_args.py on v4.45.0

Test History for

Problem found: the commit that caused the error is:
Hey @steveepreston, we probably need to revert this commit, as I just checked that the FSDP integration in accelerate does not support XLA yet. We only have this integration in Trainer, as you can see here. Another solution would be to add the integration in accelerate. Would you like to open a PR to revert this PR first?
Hey @SunMarc. Thanks for the attention! I'm not deeply familiar with it; btw, I created a PR to revert the error thrown by
@SunMarc Thank you for your support. The error is gone now and Trainer works again ✅ But I'm confused after your explanation: was that past commit in fact correct, and so now are we bypassing it? I wonder if
And what about the official blog post for Fine-Tuning Gemma Models on the Hugging Face website?
Hi, actually accelerate supports XLA FSDP via this PR: huggingface/accelerate#2176. But we only integrate it in transformers: #29334 (see transformers/src/transformers/training_args.py, line 1939 at 3f06f95).
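For reference, the Trainer-side XLA FSDP path mentioned above is enabled through TrainingArguments roughly like this (a minimal sketch; the exact fsdp_config keys may differ between versions):

```python
# Minimal sketch (not from this thread): enabling the Trainer-side XLA FSDP path
# on a TPU VM via TrainingArguments. The fsdp_config keys follow the documented
# XLA FSDP integration; treat them as illustrative, not exhaustive.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./out",
    per_device_train_batch_size=1,
    fsdp="full_shard",
    fsdp_config={
        "xla": True,               # route FSDP through torch_xla instead of accelerate
        "xla_fsdp_grad_ckpt": True,
    },
)
```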
@hanwen-sun Hey, thanks for the explanation. Am I understanding this correctly:
But
Correct me if I'm wrong @hanwen-sun, but XLA FSDP requires using
@steveepreston @SunMarc I will take some time to check this and give you a reply tomorrow.

@SunMarc @hanwen-sun Thank you both!
I agree with @SunMarc. Once again, see the error trace:
@steveepreston @SunMarc Sorry, I made a mistake. The Accelerator does not support XLA FSDP; instead, FSDP is wrapped within transformers/trainer.py. The Accelerator checks the device in FullyShardedDataParallelPlugin.post_init(). Previously, we used GPU as the backend for XLA, which allowed us to run the code successfully. However, this approach will not work correctly for TPU.
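To illustrate the distinction (my own sketch, not code taken from trainer.py): on the XLA path the model is wrapped with torch_xla's FSDP class directly, rather than through accelerate's FSDP plugin.

```python
# Illustrative sketch: wrapping a model with torch_xla's FSDP implementation
# directly, which is conceptually what the Trainer-side XLA path does
# (the real trainer.py code carries much more configuration than shown here).
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XlaFSDP

device = xm.xla_device()                      # e.g. the first TPU core on a TPU VM
model = torch.nn.Linear(128, 128).to(device)  # toy model standing in for the real one
model = XlaFSDP(model)                        # shard parameters across XLA devices
```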
@hanwen-sun Thank you for checking. Can you please check whether I'm correct? Then I can dig into this issue and debug it:
Without XLA, torch operations run on cpu0 and ignore [xla0, xla1, xla2, xla3, xla4, xla5, xla6, xla7].
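As a side note, a quick way to see which devices torch_xla actually exposes on the VM (a small sketch, assuming torch_xla is installed):

```python
# Quick check (assumption: torch_xla is installed on the TPU VM): list the XLA
# devices visible to the runtime and the default device plain torch would use.
import torch
import torch_xla.core.xla_model as xm

print("torch default device:", "cuda" if torch.cuda.is_available() else "cpu")
print("XLA devices:", xm.get_xla_supported_devices())  # e.g. ['xla:0', ..., 'xla:7']
print("current XLA device:", xm.xla_device())
```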
@steveepreston FSDP is a type of distributed training strategy which aims to fully utilize the hardware's compute resources. You can refer to https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html. I'm not familiar with use_spmd(). But you are right in general.
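For reference, SPMD mode in recent torch_xla releases is switched on via the runtime module; whether the XLA FSDP path discussed here actually requires it is exactly the open question above, so treat this as a sketch rather than a confirmed requirement:

```python
# Minimal sketch: turning on torch_xla's SPMD mode before any XLA computation.
# Whether the XLA FSDP path in Trainer needs this is an assumption, not something
# confirmed in this thread.
import torch_xla.runtime as xr

xr.use_spmd()  # switch the XLA runtime to SPMD mode (must be called early)
```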
@hanwen-sun Thanks for the note |
System Info
transformers: v4.45.0 and up (any of v4.45.0 / v4.45.1 / v4.45.2)
accelerate: v1.0.1 (same result on v0.34.2)
Who can help?
trainer experts: @muellerzr @SunMarc
accelerate expert: @muellerzr
text models expert: @ArthurZucker
Thank you guys!
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
Minimal working code is here; the code follows the GoogleCloudPlatform example. On a TPU VM, training works like a charm on transformers v4.43.1 through v4.44.2, but when upgrading to any of v4.45.0 / v4.45.1 / v4.45.2 it throws this error:
RuntimeError: There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'.
Error traceback:
The general traceback is: calling SFTTrainer() > self.accelerator = Accelerator(**args) (transformers/trainer.py).
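For context, the failing setup boils down to roughly the following (a sketch of the linked notebook's flow rather than a copy of it; the model, dataset, and trl argument names are placeholders/assumptions that vary by version):

```python
# Reproduction sketch (assumptions: TPU VM with torch_xla, trl and transformers
# installed; model and dataset below are placeholders, not the exact ones used).
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

model_id = "google/gemma-2b"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")  # placeholder dataset

args = TrainingArguments(
    output_dir="./out",
    per_device_train_batch_size=1,
    num_train_epochs=1,
    fsdp="full_shard",
    fsdp_config={"xla": True, "xla_fsdp_grad_ckpt": True},
)

# On transformers v4.43.1–v4.44.2 this trains on TPU; on v4.45.x the Accelerator
# created inside Trainer raises:
#   RuntimeError: There are currently no available devices found,
#   must be one of 'XPU', 'CUDA', or 'NPU'.
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    tokenizer=tokenizer,          # older trl; newer versions use processing_class
    dataset_text_field="quote",   # field name is an assumption for the placeholder dataset
)
trainer.train()
```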
My observation and guess
I tested multiple times and can confirm that this error is directly caused only by changing the version of transformers. The accelerate version was held fixed during all runs, so my guess is that something changed in v4.45.0 (maybe in trainer.py) that affects the args in self.accelerator = Accelerator(**args), so that the error is raised by accelerate.
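One way to verify this guess (my own debugging sketch, not something from the linked notebook) is to log whatever Trainer forwards to Accelerator before constructing the trainer:

```python
# Debugging sketch (assumption: run this before constructing SFTTrainer): wrap
# accelerate.Accelerator.__init__ so the keyword arguments Trainer forwards to it
# get printed, making it easy to diff transformers v4.44.x against v4.45.x.
import accelerate

_original_init = accelerate.Accelerator.__init__

def _logging_init(self, *args, **kwargs):
    print("Accelerator kwargs from Trainer:", kwargs)
    return _original_init(self, *args, **kwargs)

accelerate.Accelerator.__init__ = _logging_init
```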
Expected behavior
My guess: args is corrected and self.accelerator = Accelerator(**args) is called correctly, so accelerate can work on TPU.