🧗 Add GRPO Trainer support for third-party accelerators (#2836)
* Add GRPO Trainer support for Ascend NPU

* Update grpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* code format

* Update grpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* patch mem_get_info

* style

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
3 people authored Feb 27, 2025
1 parent f074dcd commit 27a6f22
Showing 1 changed file with 33 additions and 7 deletions.
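For context, a rough usage sketch of the code path this commit generalizes: GRPO training with vLLM-backed generation, where `vllm_device="auto"` should now resolve to the right device family (e.g. `cuda` or `npu`) automatically. This is an illustration only; the model, dataset, and reward function below are placeholders and are not part of the commit.

```python
# Hypothetical usage sketch (model, dataset, and reward are illustrative placeholders).
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer


def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 20 characters long.
    return [-abs(20 - len(c)) for c in completions]


dataset = load_dataset("trl-lib/tldr", split="train")

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    use_vllm=True,        # generation handled by a vLLM instance on a dedicated device
    vllm_device="auto",   # resolved to "cuda:N" or "npu:N" by the logic patched in this commit
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```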
40 changes: 33 additions & 7 deletions trl/trainer/grpo_trainer.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
+import functools
 import os
 import textwrap
 import warnings
@@ -22,6 +24,7 @@
 import torch
 import torch.utils.data
 import transformers
+from accelerate import PartialState
 from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
 from accelerate.utils.other import is_compiled_module
 from datasets import Dataset, IterableDataset
@@ -444,21 +447,26 @@ def data_collator(features):  # No data collation is needed in GRPO
 
             if self.accelerator.is_main_process:
                 vllm_device = self.args.vllm_device
+                device_type = PartialState().default_device.type
+                device_module = getattr(torch, device_type)
                 if vllm_device == "auto":
-                    if torch.cuda.device_count() == 1:
-                        vllm_device = "cuda:0"  # particular case when training with only 1 GPU: share it
+                    if device_module.device_count() == 1:
+                        vllm_device = f"{device_type}:0"  # particular case when training with only 1 device: share it
                     else:
-                        vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
+                        vllm_device = f"{device_type}:{self.accelerator.num_processes}"  # take the next GPU idx
                 # Check that the requested device is available
-                if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
+                if (
+                    vllm_device.split(":")[0] == f"{device_type}"
+                    and int(vllm_device.split(":")[1]) >= device_module.device_count()
+                ):
                     raise ValueError(
                         f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
                         "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
                         "value lower than the number of GPUs available on your machine—typically, reducing it by one "
-                        f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
+                        f"is sufficient. In your case: `--num_processes {device_module.device_count() - 1}`."
                     )
                 # Check that the requested device is not also used for training
-                if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
+                if vllm_device in {f"{device_type}:{idx}" for idx in range(self.accelerator.num_processes)}:
                     warnings.warn(
                         f"The requested device {vllm_device} is also being used for training. For higher throughput "
                         "and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. "
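As a side note (not part of the diff), the device-agnostic lookup above hinges on `accelerate`'s `PartialState().default_device` and on `torch` exposing a matching backend module. A minimal standalone sketch of how it resolves, assuming a CUDA or NPU backend (with `torch-npu` installed on Ascend hosts); the process count is a stand-in value:

```python
# Standalone sketch of the device resolution used above; values shown are examples.
import torch
from accelerate import PartialState

device_type = PartialState().default_device.type  # "cuda" on NVIDIA, "npu" on Ascend, ...
device_module = getattr(torch, device_type)        # torch.cuda, torch.npu, ...

num_training_processes = 1  # stand-in for self.accelerator.num_processes
if device_module.device_count() == 1:
    vllm_device = f"{device_type}:0"  # single device: share it with training
else:
    vllm_device = f"{device_type}:{num_training_processes}"  # first device not used for training

print(device_type, vllm_device)
```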
@@ -472,7 +480,25 @@ def data_collator(features):  # No data collation is needed in GRPO
                 profiling_patch = patch(
                     "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
                 )
-                with world_size_patch, profiling_patch:
+
+                # For Ascend NPU (torch-npu), collective communication requires the establishment of a communication
+                # group, and different processes must hold the same group number. However, multiple process groups will
+                # be created internally within vLLM. This will cause the group id of the communication group on rank 0
+                # to be different from that of other ranks, causing backward to hang because the communication
+                # domain cannot be established. So we need to patch it to make sure the group ids of different ranks in
+                # the training phase are the same.
+                @contextlib.contextmanager
+                def new_group_context():
+                    new_group = torch.distributed.new_group
+                    try:
+                        torch.distributed.new_group = functools.partial(new_group, use_local_synchronization=True)
+                        torch.npu.mem_get_info = functools.partial(torch.npu.mem_get_info, device=vllm_device)
+                        yield
+                    finally:
+                        torch.distributed.new_group = new_group
+
+                new_group_patch = new_group_context() if device_type == "npu" else contextlib.nullcontext()
+                with world_size_patch, profiling_patch, new_group_patch:
                     self.llm = LLM(
                         model=model.name_or_path,
                         device=vllm_device,
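The `new_group_context` helper above is an instance of a general temporary-monkey-patch pattern: save the original attribute, install a wrapped version, and restore it in `finally`. A minimal, self-contained sketch of that pattern, independent of vLLM or NPUs, using `torch.zeros` as an illustrative target:

```python
import contextlib
import functools

import torch


@contextlib.contextmanager
def temporarily_patched(obj, attr, replacement):
    """Swap `obj.attr` for `replacement` inside the block, then restore the original."""
    original = getattr(obj, attr)
    setattr(obj, attr, replacement)
    try:
        yield original
    finally:
        setattr(obj, attr, original)


# Example: force torch.zeros to produce float64 tensors only inside the block.
with temporarily_patched(torch, "zeros", functools.partial(torch.zeros, dtype=torch.float64)):
    assert torch.zeros(2).dtype == torch.float64
assert torch.zeros(2).dtype == torch.float32  # original behavior restored
```

In the diff, `torch.distributed.new_group` gets exactly this treatment (pinned to `use_local_synchronization=True` and restored afterwards), while `torch.npu.mem_get_info` is simply re-bound to the chosen `vllm_device` for the profiling step.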
