Add GRPO Trainer support for Ascend NPU
ji-huazhong authored and 白超 committed Feb 12, 2025
1 parent 7347c29 commit 86c5569
Showing 1 changed file with 28 additions and 7 deletions.
35 changes: 28 additions & 7 deletions trl/trainer/grpo_trainer.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import os
import textwrap
import warnings
@@ -20,8 +22,10 @@
from unittest.mock import patch

import torch
import torch.distributed
import torch.utils.data
import transformers
from accelerate import PartialState
from accelerate.utils import broadcast_object_list, gather, gather_object, set_seed
from accelerate.utils.other import is_compiled_module
from datasets import Dataset, IterableDataset
@@ -372,21 +376,23 @@ def data_collator(features): # No data collation is needed in GRPO

if self.accelerator.is_main_process:
vllm_device = self.args.vllm_device
device_type = PartialState().default_device.type
device_module = getattr(torch, device_type)
if vllm_device == "auto":
if torch.cuda.device_count() == 1:
vllm_device = "cuda:0" # particular case when training with onyl 1 GPU: share it
if device_module.device_count() == 1:
vllm_device = "{device_type}:0" # particular case when training with onyl 1 GPU: share it
else:
vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
vllm_device = f"{device_type}:{self.accelerator.num_processes}" # take the next GPU idx
# Check that the requested device is available
if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
if vllm_device.split(":")[0] == f"{device_type}" and int(vllm_device.split(":")[1]) >= device_module.device_count():
raise ValueError(
f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
"without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
"value lower than the number of GPUs available on your machine—typically, reducing it by one "
f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
f"is sufficient. In your case: `--num_processes {device_module.device_count() - 1}`."
)
# Check that the requested device is not also used for training
if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
if vllm_device in {f"{device_type}:{idx}" for idx in range(self.accelerator.num_processes)}:
warnings.warn(
f"The requested device {vllm_device} is also being used for training. For higher throughput "
"and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. "
@@ -400,7 +406,22 @@ def data_collator(features): # No data collation is needed in GRPO
profiling_patch = patch(
"vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
)
with world_size_patch, profiling_patch:
# For Ascend NPU (torch-npu), collective communication requires establishing a communication group, and
# every process must hold the same group id. However, vLLM creates multiple process groups internally,
# which makes the group id of the communication group on rank 0 differ from that of the other ranks and
# causes the backward pass to hang because the communication domain cannot be established. We therefore
# patch torch.distributed.new_group so that all ranks use the same group id during the training phase.
@contextlib.contextmanager
def new_group_context():
original_new_group = torch.distributed.new_group
try:
torch.distributed.new_group = functools.partial(original_new_group, use_local_synchronization=True)
yield
finally:
torch.distributed.new_group = original_new_group

new_group_patch = new_group_context() if device_type == "npu" else contextlib.nullcontext()
with world_size_patch, profiling_patch, new_group_patch:
self.llm = LLM(
model=model.name_or_path,
device=vllm_device,
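
The new_group_context manager in the second hunk works by temporarily rebinding torch.distributed.new_group to a functools.partial that pre-binds use_local_synchronization=True. A minimal, self-contained sketch of that monkey-patch mechanism follows; make_group and patched_make_group are hypothetical stand-ins, not part of TRL or vLLM.

import contextlib
import functools

def make_group(ranks=None, use_local_synchronization=False):
    # Hypothetical stand-in for torch.distributed.new_group.
    return {"ranks": ranks, "use_local_synchronization": use_local_synchronization}

@contextlib.contextmanager
def patched_make_group():
    # Temporarily rebind the module-level name, then always restore it.
    original = make_group
    try:
        globals()["make_group"] = functools.partial(original, use_local_synchronization=True)
        yield
    finally:
        globals()["make_group"] = original

with patched_make_group():
    print(make_group(ranks=[0, 1]))  # use_local_synchronization is forced to True
print(make_group(ranks=[0, 1]))      # default behavior restored

Because the patch is wrapped in try/finally and is only activated when device_type == "npu" (otherwise contextlib.nullcontext() is used), the existing CUDA behavior is left untouched.
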
