diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 9ae6acd473..6173030343 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
+import functools
 import os
 import textwrap
 import warnings
@@ -22,6 +24,7 @@
 import torch
 import torch.utils.data
 import transformers
+from accelerate import PartialState
 from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
 from accelerate.utils.other import is_compiled_module
 from datasets import Dataset, IterableDataset
@@ -444,21 +447,26 @@ def data_collator(features):  # No data collation is needed in GRPO
         if self.accelerator.is_main_process:
             vllm_device = self.args.vllm_device
+            device_type = PartialState().default_device.type
+            device_module = getattr(torch, device_type)
             if vllm_device == "auto":
-                if torch.cuda.device_count() == 1:
-                    vllm_device = "cuda:0"  # particular case when training with only 1 GPU: share it
+                if device_module.device_count() == 1:
+                    vllm_device = f"{device_type}:0"  # particular case when training with only 1 device: share it
                 else:
-                    vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
+                    vllm_device = f"{device_type}:{self.accelerator.num_processes}"  # take the next device idx
             # Check that the requested device is available
-            if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
+            if (
+                vllm_device.split(":")[0] == device_type
+                and int(vllm_device.split(":")[1]) >= device_module.device_count()
+            ):
                 raise ValueError(
                     f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
                     "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
                     "value lower than the number of GPUs available on your machine—typically, reducing it by one "
-                    f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
+                    f"is sufficient. In your case: `--num_processes {device_module.device_count() - 1}`."
                 )
             # Check that the requested device is not also used for training
-            if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
+            if vllm_device in {f"{device_type}:{idx}" for idx in range(self.accelerator.num_processes)}:
                 warnings.warn(
                     f"The requested device {vllm_device} is also being used for training. For higher throughput "
                     "and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. "
@@ -472,7 +480,25 @@ def data_collator(features):  # No data collation is needed in GRPO
             profiling_patch = patch(
                 "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
             )
-            with world_size_patch, profiling_patch:
+
+            # For Ascend NPU (torch-npu), collective communication requires the establishment of a communication
+            # group, and different processes must hold the same group number. However, multiple process groups will
+            # be created internally within vLLM. This will cause the group id of the communication group on rank 0
+            # to be different from that of other ranks, causing backward to hang because the communication domain
+            # cannot be established. So we need to patch it to make sure the group ids of different ranks in the
+            # training phase are the same.
+            @contextlib.contextmanager
+            def new_group_context():
+                new_group = torch.distributed.new_group
+                try:
+                    torch.distributed.new_group = functools.partial(new_group, use_local_synchronization=True)
+                    torch.npu.mem_get_info = functools.partial(torch.npu.mem_get_info, device=vllm_device)
+                    yield
+                finally:
+                    torch.distributed.new_group = new_group
+
+            new_group_patch = new_group_context() if device_type == "npu" else contextlib.nullcontext()
+            with world_size_patch, profiling_patch, new_group_patch:
                 self.llm = LLM(
                     model=model.name_or_path,
                     device=vllm_device,
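To make the device-abstraction part of the patch easier to review in isolation, here is a minimal standalone sketch (not part of the patch itself) of how `PartialState().default_device.type` pairs with `getattr(torch, device_type)` to resolve the right device module; the CPU branch and the `num_training_processes` stand-in are assumptions added for illustration.

# Minimal sketch of the device-module lookup used in the patch above.
# On a CUDA machine this resolves to torch.cuda; on an Ascend machine
# (with torch-npu installed) it resolves to torch.npu.
import torch
from accelerate import PartialState

device_type = PartialState().default_device.type  # e.g. "cuda", "npu", "cpu"

if device_type == "cpu":
    # Assumption: no accelerator present, so device_count() and memory
    # queries are not meaningful and vLLM device selection does not apply.
    print("No accelerator detected.")
else:
    device_module = getattr(torch, device_type)  # e.g. torch.cuda or torch.npu
    num_devices = device_module.device_count()
    num_training_processes = 1  # hypothetical stand-in for accelerator.num_processes
    # Mirror the "auto" logic from the patch: share the single device,
    # otherwise take the first device index past the training processes.
    if num_devices == 1:
        vllm_device = f"{device_type}:0"
    else:
        vllm_device = f"{device_type}:{num_training_processes}"
    print(f"{num_devices} {device_type} device(s); vLLM would use {vllm_device}")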
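The `new_group_context` monkey-patch is the subtle part of this change, so here is a self-contained sketch of the same pattern in isolation. `local_sync_new_group` and `run_vllm_init` are hypothetical names introduced for illustration; the NPU-specific `mem_get_info` patch is omitted so the sketch runs on any backend. The `use_local_synchronization` keyword is a real `torch.distributed.new_group` parameter (PyTorch >= 2.1).

# Standalone sketch of the monkey-patch pattern used for Ascend NPU above:
# temporarily wrap torch.distributed.new_group so every call made inside
# the `with` block passes use_local_synchronization=True, then restore the
# original function even if engine construction raises.
import contextlib
import functools

import torch.distributed


@contextlib.contextmanager
def local_sync_new_group():
    original_new_group = torch.distributed.new_group
    try:
        torch.distributed.new_group = functools.partial(
            original_new_group, use_local_synchronization=True
        )
        yield
    finally:
        # Restore unconditionally so the patch never leaks past the block.
        torch.distributed.new_group = original_new_group


def run_vllm_init():
    # Hypothetical stand-in for vLLM engine construction, which creates
    # process groups internally via torch.distributed.new_group.
    pass


with local_sync_new_group():
    run_vllm_init()  # any new_group call in here picks up the patched kwarg

Restoring the original function in `finally` is what keeps the patch scoped to vLLM initialization: per the comment in the diff, group creation during the training phase must use the stock behavior so that group ids stay consistent across ranks.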