
[Bug] transformers TPU support broken on v4.45.0 #34176

Closed
3 of 4 tasks
steveepreston opened this issue Oct 15, 2024 · 23 comments · Fixed by #34197
@steveepreston (Contributor) commented Oct 15, 2024:

System Info

transformers: v4.45.0 and up (any of v4.45.0 / v4.45.1 / v4.45.2)
accelerate: v1.0.1 (same result on v0.34.2)

Who can help?

trainer experts: @muellerzr @SunMarc
accelerate expert: @muellerzr
text models expert: @ArthurZucker
Thank you guys!

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Minimal working code is here; it follows the GoogleCloudPlatform example.

On a TPU VM, training works like a charm on transformers v4.43.1 through v4.44.2, but upgrading to any of v4.45.0 / v4.45.1 / v4.45.2 throws this error: RuntimeError: There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'.
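Not from the original report, but a quick sanity check of the environment on the TPU VM can look like this (a minimal sketch; assumes torch_xla is installed):

```python
# Quick environment check on the TPU VM (a sketch, not part of the original repro).
import transformers
import accelerate
import torch_xla.core.xla_model as xm

print("transformers:", transformers.__version__)      # 4.45.x reproduces the error
print("accelerate:", accelerate.__version__)          # 1.0.1 (same result on 0.34.2)
print("XLA devices:", xm.get_xla_supported_devices()) # should list the TPU cores
```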

 

Error Traceback:

In short: calling SFTTrainer() ends up at self.accelerator = Accelerator(**args) in transformers/trainer.py, where the error is raised.

Full traceback:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[48], line 4
      1 from trl import SFTTrainer
      2 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
----> 4 trainer = SFTTrainer(
      5     model=base_model,
      6     train_dataset=data,
      7     args=TrainingArguments(
      8         per_device_train_batch_size=BATCH_SIZE,  # This is actually the global batch size for SPMD.
      9         num_train_epochs=1,
     10         max_steps=-1,
     11         output_dir="/output_dir",
     12         optim="adafactor",
     13         logging_steps=1,
     14         dataloader_drop_last = True,  # Required for SPMD.
     15         fsdp="full_shard",
     16         fsdp_config=fsdp_config,
     17     ),
     18     peft_config=lora_config,
     19     dataset_text_field="quote",
     20     max_seq_length=max_seq_length,
     21     packing=True,
     22 )

File /usr/local/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py:101, in _deprecate_arguments.<locals>._inner_deprecate_positional_args.<locals>.inner_f(*args, **kwargs)
     99         message += "\n\n" + custom_message
    100     warnings.warn(message, FutureWarning)
--> 101 return f(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:401, in SFTTrainer.__init__(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics, peft_config, dataset_text_field, packing, formatting_func, max_seq_length, infinite, num_of_sequences, chars_per_token, dataset_num_proc, dataset_batch_size, neftune_noise_alpha, model_init_kwargs, dataset_kwargs, eval_packing)
    395 if tokenizer.padding_side is not None and tokenizer.padding_side != "right":
    396     warnings.warn(
    397         "You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to "
    398         "overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code."
    399     )
--> 401 super().__init__(
    402     model=model,
    403     args=args,
    404     data_collator=data_collator,
    405     train_dataset=train_dataset,
    406     eval_dataset=eval_dataset,
    407     tokenizer=tokenizer,
    408     model_init=model_init,
    409     compute_metrics=compute_metrics,
    410     callbacks=callbacks,
    411     optimizers=optimizers,
    412     preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    413 )
    415 # Add tags for models that have been loaded with the correct transformers version
    416 if hasattr(self.model, "add_model_tags"):

File /usr/local/lib/python3.10/site-packages/transformers/trainer.py:411, in Trainer.__init__(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
    408 self.deepspeed = None
    409 self.is_in_train = False
--> 411 self.create_accelerator_and_postprocess()
    413 # memory metrics - must set up as early as possible
    414 self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)

File /usr/local/lib/python3.10/site-packages/transformers/trainer.py:4858, in Trainer.create_accelerator_and_postprocess(self)
   4855     args.update(accelerator_config)
   4857 # create accelerator object
-> 4858 self.accelerator = Accelerator(**args)
   4859 # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
   4860 self.gather_function = self.accelerator.gather_for_metrics

File /usr/local/lib/python3.10/site-packages/accelerate/accelerator.py:349, in Accelerator.__init__(self, device_placement, split_batches, mixed_precision, gradient_accumulation_steps, cpu, dataloader_config, deepspeed_plugin, fsdp_plugin, megatron_lm_plugin, rng_types, log_with, project_dir, project_config, gradient_accumulation_plugin, step_scheduler_with_optimizer, kwargs_handlers, dynamo_backend, deepspeed_plugins)
    345         raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
    347 if fsdp_plugin is None:  # init from env variables
    348     fsdp_plugin = (
--> 349         FullyShardedDataParallelPlugin() if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" else None
    350     )
    351 else:
    352     if not isinstance(fsdp_plugin, FullyShardedDataParallelPlugin):

File <string>:21, in __init__(self, sharding_strategy, backward_prefetch, mixed_precision_policy, auto_wrap_policy, cpu_offload, ignored_modules, state_dict_type, state_dict_config, optim_state_dict_config, limit_all_gathers, use_orig_params, param_init_fn, sync_module_states, forward_prefetch, activation_checkpointing, cpu_ram_efficient_loading, transformer_cls_names_to_wrap, min_num_params)

File /usr/local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:1684, in FullyShardedDataParallelPlugin.__post_init__(self)
   1682     device = torch.xpu.current_device()
   1683 else:
-> 1684     raise RuntimeError(
   1685         "There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'."
   1686     )
   1687 # Create a function that will be used to initialize the parameters of the model
   1688 # when using `sync_module_states`
   1689 self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)

RuntimeError: There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'.

 

My observation and guess

I tested multiple times and can confirm that this error is caused solely by changing the transformers version; the accelerate version was held fixed across all runs. My guess is that something changed in v4.45.0 (maybe in trainer.py) that affects the args passed to self.accelerator = Accelerator(**args), so that the error is then raised by accelerate.

Expected behavior

My guess: the args should be corrected so that self.accelerator = Accelerator(**args) is called correctly, and accelerate can work on TPU.

@muellerzr (Contributor):
@steveepreston you confirmed it works on 4.44.0?

@steveepreston (Contributor, Author) commented Oct 15, 2024:

Hey @muellerzr
Yes.

I tested the versions below, from 4.43.1 to 4.45.2, one by one; for each test I fully restarted the session/kernel.

  • 4.43.1 > Success
  • 4.43.4 > Success
  • 4.44.0 > Success
  • 4.44.1 > Success
  • 4.44.2 > Success
  • 4.45.0 > RuntimeError: There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'
  • 4.45.1 > RuntimeError: There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'
  • 4.45.2 > RuntimeError: There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'

@steveepreston (Contributor, Author) commented Oct 15, 2024:

same error on dev build:

  • 4.46.0.dev0 > RuntimeError: There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'

@steveepreston (Contributor, Author) commented Oct 16, 2024:

!pip install -qq git+https://github.com/huggingface/transformers.git@###

Test History for trainer.py on v4.45.0

@steveepreston (Contributor, Author) commented Oct 16, 2024:

!pip install -qq git+https://github.com/huggingface/transformers.git@###

Test History for training_args.py on v4.45.0

@steveepreston (Contributor, Author):
Test History for transformers on Aug 6 - Aug 7

@steveepreston (Contributor, Author) commented Oct 16, 2024:

Problem found:

The commit that causes the error is:

  • (Aug 7) 46d09af "enable xla fsdp" > RuntimeError: There are currently no available devices

@steveepreston (Contributor, Author):
⚠️ @hanwen-sun Your commit breaks the Trainer on TPU VMs. Please fix it.

@SunMarc (Member) commented Oct 16, 2024:

Hey @steveepreston, we probably need to revert this commit, as I just checked that the FSDP integration in accelerate does not support XLA yet. We only have this integration in Trainer, as you can see here. Another solution would be to add the integration in accelerate. Would you like to open a PR to revert this one first?

@steveepreston (Contributor, Author):
Hey @SunMarc, thanks for the attention!

I'm not deeply familiar with FSDP. I just tested and saw that SFTTrainer worked like a charm on a TPU VM with 4.44.2, and given the very fast training speed I assumed it was using the power of the TPU.

By the way, I created a PR to revert the error introduced by the 46d09af commit.

@steveepreston (Contributor, Author) commented Oct 16, 2024:

@SunMarc Thank you for your support. The error is gone now and Trainer works again ✅

But I'm confused after your explanation. Was that commit actually correct, and was RuntimeError: There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU' the expected, correct error for running this example on a TPU VM v3-8?

So are we now bypassing accelerate? If so, does that mean we are not using the TPU's power, i.e. no parallel/distributed training?
How is the model trained now? Is it trained on cpu0 while ignoring [xla0, xla1, xla2, xla3, xla4, xla5, xla6, xla7]?
Sorry for these newbie questions. Please explain a little, I'm really confused. Thanks.

I wonder whether accelerate supports NPU and XPU but not TPU.

@steveepreston (Contributor, Author):
And what about the official blog post on fine-tuning Gemma models on the Hugging Face website?

@hanwen-sun (Contributor):
Hi, actually accelerate supports XLA FSDP via this PR: huggingface/accelerate#2176, but we only integrate it in transformers: #29334.
TRL reads the fsdp_plugin from accelerate to decide whether to use FSDP: https://github.com/huggingface/trl/blob/02f4e750c07c5a470f2d82a3a59e011401b5c63e/trl/trainer/ppo_trainer.py#L204, and accelerate's fsdp_plugin is initialized from the environment variable ACCELERATE_USE_FSDP: https://github.com/huggingface/accelerate/blob/a84327e59652b79b1f6e3be58be634fbd35184f3/src/accelerate/accelerator.py#L348, which is set here:

if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
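To make the chain above concrete, here is a rough, simplified sketch of the logic (hypothetical helper names; not the exact upstream code):

```python
import os

# Rough sketch of the chain described above; helper names are hypothetical and the
# bodies are heavily simplified compared to the real transformers/accelerate code.

def set_fsdp_env(fsdp: str, fsdp_config: dict) -> None:
    # transformers/training_args.py: ACCELERATE_USE_FSDP is only exported when FSDP
    # is requested AND the config is not the XLA variant.
    if len(fsdp) > 0 and not fsdp_config.get("xla", False):
        os.environ["ACCELERATE_USE_FSDP"] = "true"

def maybe_build_fsdp_plugin(device_type: str):
    # accelerate/accelerator.py: a FullyShardedDataParallelPlugin is built from that
    # env variable; its __post_init__ only knows CUDA / NPU / XPU, so on a TPU-only
    # VM it raises the RuntimeError reported in this issue.
    if os.environ.get("ACCELERATE_USE_FSDP", "false") != "true":
        return None  # XLA path: Trainer wraps the model with torch_xla FSDP itself
    if device_type not in ("cuda", "npu", "xpu"):
        raise RuntimeError(
            "There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'."
        )
    return object()  # stands in for the real plugin

# With the xla flag respected, a TPU run never sets the env variable:
set_fsdp_env("full_shard", {"xla": True})
print(maybe_build_fsdp_plugin("tpu"))  # None, so no crash
```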

@steveepreston (Contributor, Author) commented Oct 17, 2024:

@hanwen-sun Hey, thanks for the explanation.

Am I understanding this correctly:

  • Accelerate supports XLA FSDP
  • The Transformers Trainer supports XLA FSDP
  • Because the TRL SFTTrainer inherits from the Transformers Trainer, it supports XLA FSDP too

But the not self.fsdp_config["xla"] part of that if statement seems to skip setting os.environ["ACCELERATE_USE_FSDP"] = "true".

@SunMarc (Member) commented Oct 17, 2024:

Correct me if I'm wrong @hanwen-sun, but XLA FSDP requires using torch_xla modules such as from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP. We have that in transformers but not in accelerate. TRL is a wrapper around Trainer, so it will use the same code path.
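For illustration, a minimal sketch of wrapping a model with the torch_xla FSDP class mentioned above (assumes a TPU VM with torch_xla installed; the actual Trainer integration also handles auto-wrap policies, checkpointing, etc.):

```python
# Minimal XLA FSDP sketch (illustrative only; not the Trainer integration itself).
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()                       # one TPU core
model = torch.nn.Linear(1024, 1024).to(device)
model = FSDP(model)                            # shards parameters across processes

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
x = torch.randn(8, 1024, device=device)
loss = model(x).sum()
loss.backward()
optimizer.step()
xm.mark_step()                                 # materialize the XLA graph
```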

@hanwen-sun (Contributor):
@steveepreston @SunMarc I will take some time to check this and give you a reply tomorrow.

@steveepreston (Contributor, Author):
@SunMarc @hanwen-sun Thank you both!

@steveepreston (Contributor, Author) commented Oct 17, 2024:

I agree with @SunMarc.
I think the problem is on the accelerate side.

Once again, see the error trace:

  1. we call SFTTrainer()
  2. TRL / SFTTrainer.__init__()
  3. Transformers / Trainer.__init__()
  4. Accelerate / Accelerator.__init__()
  5. Accelerate / FullyShardedDataParallelPlugin.__post_init__() > ⛔ must be one of 'XPU', 'CUDA', or 'NPU'

@hanwen-sun (Contributor) commented Oct 18, 2024:

@steveepreston @SunMarc Sorry, I made a mistake. The Accelerator does not support XLA FSDP; instead, XLA FSDP is wrapped within transformers/trainer.py. The Accelerator checks the device in FullyShardedDataParallelPlugin.__post_init__(). Previously we used GPU as the backend for XLA, which allowed the code to run successfully, but this approach will not work correctly for TPU.
@steveepreston, I am not sure whether SFTTrainer correctly wraps XLA FSDP; you might want to perform some checks. Setting accelerator_use_fsdp=False could potentially cause issues with the accelerator.clip_grad_norm method for XLA FSDP. I will open a new issue later and keep you informed.
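One possible way to perform such a check (a hypothetical snippet; `trainer` is assumed to be the SFTTrainer instance from the reproduction, and the distributed wrapper is only set up once train() runs):

```python
# Hypothetical check: inspect the wrapper that Trainer keeps in trainer.model_wrapped
# to see whether the model was actually wrapped with torch_xla's FSDP class.
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel

trainer.train()  # the wrapper is only applied once training starts
print(type(trainer.model_wrapped))
print("XLA FSDP active:", isinstance(trainer.model_wrapped, XlaFullyShardedDataParallel))
```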

@steveepreston (Contributor, Author) commented Oct 18, 2024:

@hanwen-sun Thank you for checking.

Can you please check whether I have this right? Then I can dig into this issue and debug it:

  • TPU = a hardware accelerator, contains a matrix of processing cores (acts like Multi-GPU)
  • XLA = a compiler for optimizing and executing PyTorch operations across multiple devices (like TPU/Multi-GPU)
  • FSDP = a technique to distribute model training across multiple devices by sharding model and data.

Without XLA, torch operations run on cpu0 and ignore [xla0, xla1, xla2, xla3, xla4, xla5, xla6, xla7].
XLA's xr.use_spmd() lets us distribute the training process across all TPU cores.
So what is the point of FSDP? Just to optimize and speed up this distribution?
Sorry for the newbie question.
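For reference, a minimal sketch of what enabling SPMD looks like with torch_xla (assumes torch_xla 2.x on a TPU VM; not taken from this issue):

```python
# Minimal SPMD sketch on a TPU VM (assumes torch_xla 2.x; illustrative only).
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

xr.use_spmd()                 # switch the runtime to SPMD mode
device = xm.xla_device()      # a single logical XLA device under SPMD

x = torch.randn(128, 128, device=device)
y = x @ x                     # compiled by XLA (sharding annotations would distribute it)
xm.mark_step()
print(y.shape)
```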

@hanwen-sun (Contributor):
@steveepreston FSDP is a distributed training strategy that shards model parameters, gradients, and optimizer state across devices so the hardware's compute and memory are fully utilized. You can refer to https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html. I'm not familiar with use_spmd(), but you are right in general.
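For completeness, a minimal sketch following the pattern in the linked PyTorch FSDP tutorial (assumes a multi-GPU host and a `torchrun --nproc_per_node=N fsdp_min.py` launch; this is the native GPU path, not the TPU/XLA path discussed in this issue):

```python
# fsdp_min.py - minimal native PyTorch FSDP sketch (tutorial-style, GPU only).
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def main():
    dist.init_process_group("nccl")
    rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank)

    model = torch.nn.Linear(1024, 1024).cuda(rank)
    model = FSDP(model)  # parameters, gradients, and optimizer state get sharded

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    x = torch.randn(8, 1024, device=f"cuda:{rank}")
    model(x).sum().backward()
    optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```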

@steveepreston (Contributor, Author):
@hanwen-sun Thanks for the note

