diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 620cffd145..9ec05f4dd6 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import functools import os import textwrap import warnings @@ -20,8 +22,10 @@ from unittest.mock import patch import torch +import torch.distributed import torch.utils.data import transformers +from accelerate import PartialState from accelerate.utils import broadcast_object_list, gather, gather_object, set_seed from accelerate.utils.other import is_compiled_module from datasets import Dataset, IterableDataset @@ -372,21 +376,23 @@ def data_collator(features): # No data collation is needed in GRPO if self.accelerator.is_main_process: vllm_device = self.args.vllm_device + device_type = PartialState().default_device.type + device_module = getattr(torch, device_type) if vllm_device == "auto": - if torch.cuda.device_count() == 1: - vllm_device = "cuda:0" # particular case when training with onyl 1 GPU: share it + if device_module.device_count() == 1: + vllm_device = "{device_type}:0" # particular case when training with onyl 1 GPU: share it else: - vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx + vllm_device = f"{device_type}:{self.accelerator.num_processes}" # take the next GPU idx # Check that the requested device is available - if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count(): + if vllm_device.split(":")[0] == f"{device_type}" and int(vllm_device.split(":")[1]) >= device_module.device_count(): raise ValueError( f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM " "without restricting the number of GPUs for training. Set the `--num_processes` argument to a " "value lower than the number of GPUs available on your machine—typically, reducing it by one " - f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`." + f"is sufficient. In your case: `--num_processes {device_module.device_count() - 1}`." ) # Check that the requested device is not also used for training - if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}: + if vllm_device in {f"{device_type}:{idx}" for idx in range(self.accelerator.num_processes)}: warnings.warn( f"The requested device {vllm_device} is also being used for training. For higher throughput " "and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. " @@ -400,7 +406,22 @@ def data_collator(features): # No data collation is needed in GRPO profiling_patch = patch( "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None ) - with world_size_patch, profiling_patch: + # For Ascend NPU (torch-npu), collective communication requires the establishment of a communication group, + # and different processes must hold the same group number. However, multiple process groups will be created + # internally within vLLM. This will cause the group id of the communication group on rank 0 to be different from + # that of other ranks, causing backward to hang on because the communication domain cannot be established. So + # we need to patch it to make sure the group id of different ranks in the training phase are the same. + @contextlib.contextmanager + def new_group_context(): + original_new_group = torch.distributed.new_group + try: + torch.distributed.new_group = functools.partial(original_new_group, use_local_synchronization=True) + yield + finally: + torch.distributed.new_group = original_new_group + + new_group_patch = new_group_context() if device_type == "npu" else contextlib.nullcontext() + with world_size_patch, profiling_patch, new_group_patch: self.llm = LLM( model=model.name_or_path, device=vllm_device,