
🧗 Add GRPO Trainer support for third-party accelerators #2836

Merged (10 commits) on Feb 27, 2025
40 changes: 33 additions & 7 deletions trl/trainer/grpo_trainer.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import os
import textwrap
import warnings
@@ -22,6 +24,7 @@
import torch
import torch.utils.data
import transformers
from accelerate import PartialState
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from accelerate.utils.other import is_compiled_module
from datasets import Dataset, IterableDataset
@@ -444,21 +447,26 @@ def data_collator(features):  # No data collation is needed in GRPO

if self.accelerator.is_main_process:
vllm_device = self.args.vllm_device
device_type = PartialState().default_device.type
device_module = getattr(torch, device_type)
if vllm_device == "auto":
if torch.cuda.device_count() == 1:
vllm_device = "cuda:0" # particular case when training with only 1 GPU: share it
if device_module.device_count() == 1:
vllm_device = f"{device_type}:0" # particular case when training with onyl 1 device: share it
else:
vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
vllm_device = f"{device_type}:{self.accelerator.num_processes}" # take the next GPU idx
# Check that the requested device is available
if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
if (
vllm_device.split(":")[0] == f"{device_type}"
qgallouedec (Member) commented:

this should always be the case, no?

ji-huazhong (Contributor, Author) replied on Feb 18, 2025:
Hi @qgallouedec

Thanks for your review. In line 387, I kept the same logic as the original conditional statement, only replacing the hard-coded 'cuda' type with the more general device type.

I believe the check for device availability here is necessary. However, perhaps we could split this conditional statement into two parts: first check whether the device type matches, and only after that condition is met, check whether the device index is within the range of available devices (see the sketch below). WDYT?
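For illustration only, the proposed split might look roughly like the sketch below. It reuses the `vllm_device`, `device_type`, and `device_module` names from the diff above; the `requested_type`/`requested_index` helpers and the error message are hypothetical, not part of this PR.

```python
# Hypothetical refactor of the availability check, nested as two steps (not part of this PR).
requested_type, _, requested_index = vllm_device.partition(":")

# Step 1: only inspect the index if the requested device type matches the training device type.
if requested_type == device_type:
    # Step 2: the requested device index must be within the range of visible devices.
    if int(requested_index) >= device_module.device_count():
        raise ValueError(
            f"The requested device for vLLM ({vllm_device}) is not available: only "
            f"{device_module.device_count()} {device_type} device(s) are visible."
        )
```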

and int(vllm_device.split(":")[1]) >= device_module.device_count()
):
raise ValueError(
f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
"without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
"value lower than the number of GPUs available on your machine—typically, reducing it by one "
f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
f"is sufficient. In your case: `--num_processes {device_module.device_count() - 1}`."
)
# Check that the requested device is not also used for training
if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
if vllm_device in {f"{device_type}:{idx}" for idx in range(self.accelerator.num_processes)}:
warnings.warn(
f"The requested device {vllm_device} is also being used for training. For higher throughput "
"and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. "
@@ -472,7 +480,25 @@ def data_collator(features):  # No data collation is needed in GRPO
profiling_patch = patch(
"vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
)
with world_size_patch, profiling_patch:

# For Ascend NPU (torch-npu), collective communication requires the establishment of a communication
# group, and different processes must hold the same group number. However, multiple process groups will
# be created internally within vLLM. This will cause the group id of the communication group on rank 0
to be different from that of other ranks, causing the backward pass to hang because the communication
# domain cannot be established. So we need to patch it to make sure the group id of different ranks in
# the training phase are the same.
@contextlib.contextmanager
def new_group_context():
new_group = torch.distributed.new_group
try:
torch.distributed.new_group = functools.partial(new_group, use_local_synchronization=True)
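# Bind mem_get_info to the dedicated vLLM device so memory queries report that device rather than the current default.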
torch.npu.mem_get_info = functools.partial(torch.npu.mem_get_info, device=vllm_device)
yield
finally:
torch.distributed.new_group = new_group

new_group_patch = new_group_context() if device_type == "npu" else contextlib.nullcontext()
with world_size_patch, profiling_patch, new_group_patch:
self.llm = LLM(
model=model.name_or_path,
device=vllm_device,
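The core pattern this PR introduces, resolving the accelerator type and its torch backend module at runtime instead of hard-coding CUDA, can be exercised on its own. The following is a minimal sketch, assuming `accelerate` and `torch` (plus the relevant backend extension such as torch-npu) are installed; the `torch.cuda` fallback is an illustrative assumption, not behavior from the PR.

```python
# Minimal sketch of the device-agnostic resolution pattern used in this PR (illustrative only).
import torch
from accelerate import PartialState

# "cuda" on NVIDIA GPUs, "npu" on Ascend NPUs, "xpu" on Intel GPUs, etc.
device_type = PartialState().default_device.type

# torch exposes a per-backend module under the same name (torch.cuda, torch.npu, torch.xpu, ...).
# Falling back to torch.cuda here is an assumption for the sketch, not part of the PR.
device_module = getattr(torch, device_type, torch.cuda)

print(f"training on {device_type}, {device_module.device_count()} device(s) visible")
```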