From 7771f60f3c82981a9c30b834f68c5fef19f4817a Mon Sep 17 00:00:00 2001 From: ji-huazhong Date: Mon, 10 Feb 2025 16:55:07 +0800 Subject: [PATCH 1/6] Add GRPO Trainer support for Ascend NPU --- trl/trainer/grpo_trainer.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 1211a453fc..5c5836f93b 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import functools import os import textwrap import warnings @@ -22,6 +24,7 @@ import torch import torch.utils.data import transformers +from accelerate import PartialState from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed from accelerate.utils.other import is_compiled_module from datasets import Dataset, IterableDataset @@ -372,21 +375,23 @@ def data_collator(features): # No data collation is needed in GRPO if self.accelerator.is_main_process: vllm_device = self.args.vllm_device + device_type = PartialState().default_device.type + device_module = getattr(torch, device_type) if vllm_device == "auto": - if torch.cuda.device_count() == 1: - vllm_device = "cuda:0" # particular case when training with onyl 1 GPU: share it + if device_module.device_count() == 1: + vllm_device = "{device_type}:0" # particular case when training with onyl 1 GPU: share it else: - vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx + vllm_device = f"{device_type}:{self.accelerator.num_processes}" # take the next GPU idx # Check that the requested device is available - if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count(): + if vllm_device.split(":")[0] == f"{device_type}" and int(vllm_device.split(":")[1]) >= device_module.device_count(): raise ValueError( f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM " "without restricting the number of GPUs for training. Set the `--num_processes` argument to a " "value lower than the number of GPUs available on your machine—typically, reducing it by one " - f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`." + f"is sufficient. In your case: `--num_processes {device_module.device_count() - 1}`." ) # Check that the requested device is not also used for training - if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}: + if vllm_device in {f"{device_type}:{idx}" for idx in range(self.accelerator.num_processes)}: warnings.warn( f"The requested device {vllm_device} is also being used for training. For higher throughput " "and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. " @@ -400,7 +405,22 @@ def data_collator(features): # No data collation is needed in GRPO profiling_patch = patch( "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None ) - with world_size_patch, profiling_patch: + # For Ascend NPU (torch-npu), collective communication requires the establishment of a communication group, + # and different processes must hold the same group number. However, multiple process groups will be created + # internally within vLLM. 
This will cause the group id of the communication group on rank 0 to be different from + # that of other ranks, causing backward to hang on because the communication domain cannot be established. So + # we need to patch it to make sure the group id of different ranks in the training phase are the same. + @contextlib.contextmanager + def new_group_context(): + original_new_group = torch.distributed.new_group + try: + torch.distributed.new_group = functools.partial(original_new_group, use_local_synchronization=True) + yield + finally: + torch.distributed.new_group = original_new_group + + new_group_patch = new_group_context() if device_type == "npu" else contextlib.nullcontext() + with world_size_patch, profiling_patch, new_group_patch: self.llm = LLM( model=model.name_or_path, device=vllm_device, From df304ea34a2bc88c71d1f8805bffb827a559f16d Mon Sep 17 00:00:00 2001 From: Huazhong Ji Date: Fri, 14 Feb 2025 20:54:55 +0800 Subject: [PATCH 2/6] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20grpo=5Ftrainer.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 5c5836f93b..7431b8be36 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -406,7 +406,7 @@ def data_collator(features): # No data collation is needed in GRPO "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None ) # For Ascend NPU (torch-npu), collective communication requires the establishment of a communication group, - # and different processes must hold the same group number. However, multiple process groups will be created + # and different processes must hold the same group number. However, multiple process groups will be created # internally within vLLM. This will cause the group id of the communication group on rank 0 to be different from # that of other ranks, causing backward to hang on because the communication domain cannot be established. So # we need to patch it to make sure the group id of different ranks in the training phase are the same. From 622301f73db5ff9f5a9259dba989101bcbd9641e Mon Sep 17 00:00:00 2001 From: "hzji210@gmail.com" Date: Sat, 15 Feb 2025 01:04:36 +0800 Subject: [PATCH 3/6] code format --- trl/trainer/grpo_trainer.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 7431b8be36..2ea782b144 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -383,7 +383,10 @@ def data_collator(features): # No data collation is needed in GRPO else: vllm_device = f"{device_type}:{self.accelerator.num_processes}" # take the next GPU idx # Check that the requested device is available - if vllm_device.split(":")[0] == f"{device_type}" and int(vllm_device.split(":")[1]) >= device_module.device_count(): + if ( + vllm_device.split(":")[0] == f"{device_type}" + and int(vllm_device.split(":")[1]) >= device_module.device_count() + ): raise ValueError( f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM " "without restricting the number of GPUs for training. 
Set the `--num_processes` argument to a " @@ -405,8 +408,9 @@ def data_collator(features): # No data collation is needed in GRPO profiling_patch = patch( "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None ) - # For Ascend NPU (torch-npu), collective communication requires the establishment of a communication group, - # and different processes must hold the same group number. However, multiple process groups will be created + + # For Ascend NPU (torch-npu), collective communication requires the establishment of a communication group, + # and different processes must hold the same group number. However, multiple process groups will be created # internally within vLLM. This will cause the group id of the communication group on rank 0 to be different from # that of other ranks, causing backward to hang on because the communication domain cannot be established. So # we need to patch it to make sure the group id of different ranks in the training phase are the same. @@ -414,7 +418,9 @@ def data_collator(features): # No data collation is needed in GRPO def new_group_context(): original_new_group = torch.distributed.new_group try: - torch.distributed.new_group = functools.partial(original_new_group, use_local_synchronization=True) + torch.distributed.new_group = functools.partial( + original_new_group, use_local_synchronization=True + ) yield finally: torch.distributed.new_group = original_new_group From 44ca5ccab4d409766b91d276329bb9c2f8bbdc16 Mon Sep 17 00:00:00 2001 From: Huazhong Ji Date: Wed, 19 Feb 2025 00:18:18 +0800 Subject: [PATCH 4/6] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20grpo=5Ftrainer.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index db68af0f4c..57a14575df 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -379,7 +379,7 @@ def data_collator(features): # No data collation is needed in GRPO device_module = getattr(torch, device_type) if vllm_device == "auto": if device_module.device_count() == 1: - vllm_device = "{device_type}:0" # particular case when training with onyl 1 GPU: share it + vllm_device = f"{device_type}:0" # particular case when training with onyl 1 GPU: share it else: vllm_device = f"{device_type}:{self.accelerator.num_processes}" # take the next GPU idx # Check that the requested device is available From 82a36b49d0bc073a34e5b1a25d61b20cfffd517d Mon Sep 17 00:00:00 2001 From: "hzji210@gmail.com" Date: Fri, 21 Feb 2025 09:00:34 +0800 Subject: [PATCH 5/6] patch mem_get_info --- trl/trainer/grpo_trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 57a14575df..dc4cc9ee98 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -421,6 +421,9 @@ def new_group_context(): torch.distributed.new_group = functools.partial( original_new_group, use_local_synchronization=True ) + torch.npu.mem_get_info = functools.partial( + torch.npu.mem_get_info, device=vllm_device + ) yield finally: torch.distributed.new_group = original_new_group From dde817f65f50db6916b1c755aeee984d347fcee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 27 Feb 2025 11:21:03 +0000 Subject: [PATCH 6/6] stylre --- 
trl/trainer/grpo_trainer.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 02039419c6..6173030343 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -481,24 +481,21 @@ def data_collator(features): # No data collation is needed in GRPO "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None ) - # For Ascend NPU (torch-npu), collective communication requires the establishment of a communication group, - # and different processes must hold the same group number. However, multiple process groups will be created - # internally within vLLM. This will cause the group id of the communication group on rank 0 to be different from - # that of other ranks, causing backward to hang on because the communication domain cannot be established. So - # we need to patch it to make sure the group id of different ranks in the training phase are the same. + # For Ascend NPU (torch-npu), collective communication requires the establishment of a communication + # group, and different processes must hold the same group number. However, multiple process groups will + # be created internally within vLLM. This will cause the group id of the communication group on rank 0 + # to be different from that of other ranks, causing backward to hang on because the communication + # domain cannot be established. So we need to patch it to make sure the group id of different ranks in + # the training phase are the same. @contextlib.contextmanager def new_group_context(): - original_new_group = torch.distributed.new_group + new_group = torch.distributed.new_group try: - torch.distributed.new_group = functools.partial( - original_new_group, use_local_synchronization=True - ) - torch.npu.mem_get_info = functools.partial( - torch.npu.mem_get_info, device=vllm_device - ) + torch.distributed.new_group = functools.partial(new_group, use_local_synchronization=True) + torch.npu.mem_get_info = functools.partial(torch.npu.mem_get_info, device=vllm_device) yield finally: - torch.distributed.new_group = original_new_group + torch.distributed.new_group = new_group new_group_patch = new_group_context() if device_type == "npu" else contextlib.nullcontext() with world_size_patch, profiling_patch, new_group_patch:
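
For reference, the device-selection logic that PATCH 1/6 introduces (and that PATCH 3/6 and 4/6 clean up) can be read in isolation as the sketch below. It is a minimal illustration under stated assumptions, not the shipped code: upstream the logic is inlined in GRPOTrainer.__init__, and the resolve_vllm_device helper, its signature, and its shortened error message are invented here.

import torch
from accelerate import PartialState


def resolve_vllm_device(requested: str, num_processes: int) -> str:
    # Hypothetical helper; upstream this logic is inlined in GRPOTrainer.__init__.
    device_type = PartialState().default_device.type  # "cuda" on GPU machines, "npu" on Ascend
    device_module = getattr(torch, device_type)       # torch.cuda, or torch.npu when torch-npu is installed

    if requested == "auto":
        if device_module.device_count() == 1:
            requested = f"{device_type}:0"  # single accelerator: share it with training
        else:
            requested = f"{device_type}:{num_processes}"  # take the next accelerator index after the training ranks

    # Reject indices beyond what the machine actually has.
    if requested.split(":")[0] == device_type and int(requested.split(":")[1]) >= device_module.device_count():
        raise ValueError(f"The requested device for vLLM ({requested}) is not available.")
    return requested

A caller would resolve the device once on the main process, e.g. vllm_device = resolve_vllm_device(self.args.vllm_device, self.accelerator.num_processes), before constructing the vLLM engine.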
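
The NPU-specific patching that PATCH 1/6 adds, PATCH 5/6 extends with mem_get_info, and PATCH 6/6 reformats can likewise be sketched as a standalone context manager. Assumptions: the device_type and vllm_device parameters are added for illustration, torch.distributed.new_group(..., use_local_synchronization=True) requires PyTorch 2.1 or newer, torch.npu exists only when torch-npu is installed, and unlike the final hunk this sketch also restores mem_get_info on exit.

import contextlib
import functools

import torch


@contextlib.contextmanager
def npu_vllm_patch_context(device_type: str, vllm_device: str):
    # Hypothetical standalone wrapper around the patch applied before LLM(...) is built.
    if device_type != "npu":
        # CUDA and other backends need no patching; behaves like contextlib.nullcontext().
        yield
        return

    original_new_group = torch.distributed.new_group
    original_mem_get_info = torch.npu.mem_get_info  # present only when torch-npu is installed
    try:
        # Keep the group ids assigned by new_group consistent across ranks even though
        # vLLM creates additional process groups internally on rank 0.
        torch.distributed.new_group = functools.partial(original_new_group, use_local_synchronization=True)
        # Pin memory queries to the device handed to vLLM.
        torch.npu.mem_get_info = functools.partial(original_mem_get_info, device=vllm_device)
        yield
    finally:
        torch.distributed.new_group = original_new_group
        # The final hunk restores only new_group; restoring mem_get_info is an extra
        # tidy-up added for this sketch.
        torch.npu.mem_get_info = original_mem_get_info

It would wrap the engine construction alongside the other patches, e.g. with world_size_patch, profiling_patch, npu_vllm_patch_context(device_type, vllm_device): self.llm = LLM(model=model.name_or_path, device=vllm_device, ...).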