
GRPOTrainer with Deepspeed: Getting device mismatch error #2745

Closed · 5 tasks done
3rdAT opened this issue Feb 3, 2025 · 2 comments · Fixed by #2766
Labels
🐛 bug Something isn't working 🚀 deepspeed Related to deepspeed 🏋 GRPO Related to GRPO

Comments

3rdAT commented Feb 3, 2025

Reproduction

During training, I allocated 4 GPUs for training the model and 1 GPU ("cuda:4") for vLLM, but I am getting a device mismatch error when vLLM constructs the CUDA graphs. How can I overcome this issue?

export CUDA_VISIBLE_DEVICES=0,1,2,3,4

ACCELERATE_LOG_LEVEL=info accelerate launch \
  --config_file \
  --main_process_port 29501 \
  grpo.py \
  --model_name_or_path  \
  --dataset_name  \
  --output_dir  \
  --bf16 True \
  --bf16_full_eval True \
  --per_device_train_batch_size=1 \

Accelerate configuration file (DeepSpeed ZeRO-3)

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
training_args = GRPOConfig(
    output_dir="/data/data/arrv/models/GRPO/model1",
    run_name="GRPO_OrcaMath",
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_generations=4,
    max_prompt_length=256,
    max_completion_length=512,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    log_on_each_node=False,
    use_vllm=True,
    vllm_gpu_memory_utilization=0.3,
    vllm_device="auto",
    report_to="wandb",  # I'm disabling Wandb.
)

# Initialize the GRPO trainer
trainer = GRPOTrainer(
    model=model,
    smol_model=None,
    smol_model_tokenizer=None,
    reward_funcs=[format_reward_function, correctness_reward_function],
    args=training_args,
    train_dataset=raw_datasets,
    # eval_dataset=eval_raw_datasets,
    processing_class=tokenizer,
)

Output:

    ...
[rank0]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:4 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)
    ...

System Info

  • Platform: Linux-5.15.0-119-generic-x86_64-with-glibc2.35
  • Python version: 3.11.7
  • PyTorch version: 2.5.1
  • CUDA device(s): NVIDIA H100 NVL, NVIDIA H100 NVL, NVIDIA H100 NVL, NVIDIA H100 NVL, NVIDIA H100 NVL, NVIDIA H100 NVL, NVIDIA H100 NVL, NVIDIA H100 NVL
  • Transformers version: 4.48.2
  • Accelerate version: 1.1.1
  • Accelerate config: not found
  • Datasets version: 3.1.0
  • HF Hub version: 0.26.2
  • TRL version: 0.13.0
  • bitsandbytes version: 0.45.0
  • DeepSpeed version: 0.16.2
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: 1.60.2
  • PEFT version: 0.13.2

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
  • Any traceback provided is complete
@github-actions github-actions bot added 🏋 GRPO Related to GRPO 🚀 deepspeed Related to deepspeed 🐛 bug Something isn't working labels Feb 3, 2025
tchang1997 (Contributor) commented:

What version of vLLM are you using? I had a similar issue when using vLLM for inference — for me, the issue was that I was using vLLM v0.6, and upgrading to 0.7.1 resolved this error.

Basically, vllm/worker/model_runner.py was using .cuda() to move tensors instead of setting device= to the correct device name; that was fixed in a bugfix on Jan 4, 2025, which is included in the 0.7 release.

Perhaps a brief line on vLLM version requirements could be added to the docs, if it isn't present already?
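Until something like that lands in the docs, a minimal runtime guard along these lines can fail fast on an old install (a sketch, not part of TRL; it uses importlib.metadata and packaging, and the 0.7.1 threshold is an assumption based on the fix mentioned above):

from importlib.metadata import version

from packaging.version import Version

# Fail fast if the installed vLLM predates the device-placement fix,
# instead of crashing later during CUDA graph capture.
MIN_VLLM = Version("0.7.1")  # assumed minimum, per the fix described above
installed = Version(version("vllm"))
if installed < MIN_VLLM:
    raise RuntimeError(
        f"vLLM {installed} is older than {MIN_VLLM}; please upgrade vLLM to avoid "
        "the cuda:0 / cuda:N device mismatch when vllm_device targets a dedicated GPU."
    )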

3rdAT (Author) commented Feb 3, 2025

Hi @tchang1997, I was actually using vLLM==0.6.6.post1. The updated version works! Thanks!

Also, I would like to note that the model needs to be loaded with FlashAttention to work flawlessly with vLLM (see the sketch below). Adding this to the GRPO documentation would also be beneficial for people who are new to this.
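For reference, a minimal sketch of loading the policy model with FlashAttention-2 via transformers (the model name is a placeholder, and the flash-attn package must be installed for attn_implementation="flash_attention_2" to be available):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "your-org/your-model"  # placeholder checkpoint

# Load the policy model in bf16 with FlashAttention-2 so it matches the
# bf16 training setup above; requires the flash-attn package to be installed.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)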

Thanks for the help!
