
support ASCEND NPU
as12138 committed Feb 25, 2025
1 parent ef8b6e7 commit d36c1c7
Showing 10 changed files with 123 additions and 112 deletions.
16 changes: 8 additions & 8 deletions docs/ascend/ascend.md
@@ -1,6 +1,6 @@
# veRL x Ascend

We add support for Huawei Ascend devices in verRL; using veRL on Huawei Ascend devices is almost the same as using it on NVIDIA GPUs.
We add support for Huawei Ascend devices in veRL; using veRL on Huawei Ascend devices is almost the same as using it on NVIDIA GPUs.

## Hardware Support

@@ -22,7 +22,7 @@
### Install from Source

```shell
git clone -b vllm-0.7-npu https://github.com/as12138/verl.git
git clone https://github.com/volcengine/verl.git
cd verl
pip install -r requirements-npu.txt
pip install -e .
```
@@ -50,17 +50,17 @@ pip install -e .

Empirically, we expect the average error between the loss on Huawei Ascend devices and the loss on NVIDIA GPUs, under the same configuration, to be less than 2%. It is computed as follows:

![loss_comparison](./images/loss_comparison.png)
![loss_comparison](https://github.com/eric-haibin-lin/verl-community/tree/main/docs/loss_comparison.png)

Here, N denotes the number of training steps. For more details, see the [accuracy computation guide](https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/LMaccuracy_0001.html).
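
As an illustrative aside (not part of this diff), here is a minimal Python sketch of one way such a comparison could be computed, assuming the metric is the mean per-step relative loss error over the N training steps; the authoritative definition is the formula in the linked image and accuracy guide.

```python
# Hedged sketch: assumes the comparison metric is the mean per-step relative
# loss error; the exact formula is defined in the linked accuracy guide.
import numpy as np

def mean_relative_loss_error(loss_npu, loss_gpu):
    """Average |loss_npu_i - loss_gpu_i| / |loss_gpu_i| over N training steps."""
    loss_npu = np.asarray(loss_npu, dtype=np.float64)
    loss_gpu = np.asarray(loss_gpu, dtype=np.float64)
    assert loss_npu.shape == loss_gpu.shape  # one loss value per step
    return float(np.mean(np.abs(loss_npu - loss_gpu) / np.abs(loss_gpu)))

# Example acceptance check: the run passes if the error stays below 2%.
# mean_relative_loss_error(npu_losses, gpu_losses) < 0.02
```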

### Progress

| Algorithm | Progress |
|------|---------------------------------------------------------------|
| SFT | Supported |
| PPO | Supported |
| GRPO | Supported |
| Algorithm | Progress |
|:------|:----|
| SFT | Supported |
| PPO | Supported |
| GRPO | Supported |


> Additional notes:
Binary file removed docs/ascend/images/loss_comparison.png
41 changes: 0 additions & 41 deletions examples/grpo_trainer/run_qwen2-7b_npu.sh

This file was deleted.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -44,6 +44,7 @@ dependencies = [
"ray>=2.10",
"tensordict<0.6",
"transformers",
"vllm<=0.7.3",
'wandb',
]

4 changes: 2 additions & 2 deletions requirements-npu.txt
@@ -14,5 +14,5 @@ ray
tensordict<0.6
transformers
wandb
vllm
vllm-ascend
vllm==0.7.1
vllm-ascend==0.7.1rc1
46 changes: 36 additions & 10 deletions verl/bert_padding.py
@@ -1,11 +1,39 @@
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
# Copyright (c) 2023, Tri Dao.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


class IndexFirstAxis(torch.autograd.Function):

@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
@@ -14,9 +42,8 @@ def forward(ctx, input, indices):
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(
rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
).reshape(-1, *other_shape)
return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d",
d=second_dim)).reshape(-1, *other_shape)

@staticmethod
def backward(ctx, grad_output):
@@ -39,14 +66,13 @@ def backward(ctx, grad_output):


class IndexPutFirstAxis(torch.autograd.Function):

@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(
first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
)
output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
@@ -65,6 +91,7 @@ def backward(ctx, grad_output):


class IndexFirstAxisResidual(torch.autograd.Function):

@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
@@ -182,9 +209,8 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
"""
length = attention_mask_in_length.sum(dim=-1)
seqlen = attention_mask_in_length.size(-1)
attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length),
seqlen) < length.unsqueeze(
1)
attention_mask_2d = torch.arange(seqlen, device=length.device,
dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
@@ -217,4 +243,4 @@ def pad_input(hidden_states, indices, batch, seqlen):
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, "(b s) ... -> b s ...", b=batch)
return rearrange(output, "(b s) ... -> b s ...", b=batch)
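
For context (not part of the commit), here is a small usage sketch of `pad_input` as defined above; the attention mask and hidden states are made-up example data.

```python
# Usage sketch for verl.bert_padding.pad_input(hidden_states, indices, batch, seqlen)
# as shown in this diff; the example tensors are placeholders.
import torch
from verl.bert_padding import pad_input

batch, seqlen, dim = 2, 4, 8
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
# Flattened positions of the real (non-padding) tokens, as unpad-style code computes them
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()

# Per-token hidden states after padding removal: shape (total_real_tokens, dim)
unpadded = torch.randn(indices.numel(), dim)

# Scatter the real tokens back into a zero-padded (batch, seqlen, dim) tensor
restored = pad_input(unpadded, indices, batch, seqlen)
assert restored.shape == (batch, seqlen, dim)
```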
37 changes: 19 additions & 18 deletions verl/trainer/fsdp_sft_trainer.py
@@ -209,12 +209,12 @@ def _build_model_optimizer(self):
init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings)

with init_context():
self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path,
config=config,
torch_dtype=torch.float32,
attn_implementation='flash_attention_2'
if is_cuda_available else 'sdpa',
trust_remote_code=trust_remote_code)
self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
local_model_path,
config=config,
torch_dtype=torch.float32,
attn_implementation='flash_attention_2' if is_cuda_available else 'sdpa',
trust_remote_code=trust_remote_code)

# Apply Liger kernel if use_liger is enabled
if self.config.model.get('use_liger', False):
@@ -253,17 +253,17 @@ def _build_model_optimizer(self):
else:
cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)

self.fsdp_model = FSDP(module=self.model,
auto_wrap_policy=auto_wrap_policy,
param_init_fn=init_fn,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
device_mesh=self.device_mesh,
sync_module_states=True,
device_id=torch.cuda.current_device() if is_cuda_available else
torch.npu.current_device(),
cpu_offload=cpu_offload,
use_orig_params=False)
self.fsdp_model = FSDP(
module=self.model,
auto_wrap_policy=auto_wrap_policy,
param_init_fn=init_fn,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
device_mesh=self.device_mesh,
sync_module_states=True,
device_id=torch.cuda.current_device() if is_cuda_available else torch.npu.current_device(),
cpu_offload=cpu_offload,
use_orig_params=False)

log_gpu_memory_usage('After FSDP wrapping', logger=logger)

@@ -489,7 +489,8 @@ def fit(self):
# Perform final validation
val_losses = []
for val_data in self.val_dataloader:
val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device)
val_data = TensorDict(val_data,
batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device)
val_loss = self.validation_step(val_data)
val_losses.append(val_loss)
if rank == 0:
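As a side note (not part of this diff), here is a minimal sketch of the CUDA/NPU switching pattern the trainer code above relies on: flash-attention-2 and torch.cuda on NVIDIA GPUs, PyTorch SDPA and torch.npu on Ascend. The model id is a placeholder, and `is_cuda_available` is derived locally here rather than imported from verl.

```python
# Hedged sketch of the device-conditional choices used in fsdp_sft_trainer.py;
# the model id is a placeholder and the flag mirrors verl's is_cuda_available.
import torch
from transformers import AutoModelForCausalLM

is_cuda_available = torch.cuda.is_available()

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",  # placeholder model id
    torch_dtype=torch.float32,
    # flash-attention-2 is not available on Ascend, so the trainer falls back to SDPA
    attn_implementation='flash_attention_2' if is_cuda_available else 'sdpa',
)

# FSDP needs the local accelerator index; torch.npu mirrors the torch.cuda API,
# and the torch.npu namespace exists only when torch_npu is installed.
device_id = torch.cuda.current_device() if is_cuda_available else torch.npu.current_device()
```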
35 changes: 29 additions & 6 deletions verl/utils/device.py
@@ -1,9 +1,34 @@
# This code is inspired by the torchtune.
# https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,this list
# of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this
# list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may
# be used to endorse or promote products derived from this software without specific
# prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
# SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.

import os
import logging
from enum import Enum
from typing import Optional

import torch
@@ -69,7 +94,5 @@ def get_torch_device() -> any:
try:
return getattr(torch, device_name)
except AttributeError:
logger.warning(
f"Device namespace '{device_name}' not found in torch, try to load torch.cuda."
)
return torch.cuda
logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.")
return torch.cuda
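For completeness (not part of the commit), a short usage sketch of the helper above; it assumes `verl.utils.device` exports `get_torch_device` as shown in this file.

```python
# get_torch_device() returns a device namespace (torch.cuda on GPU machines,
# torch.npu on Ascend) whose memory and device helpers share the same API.
from verl.utils.device import get_torch_device

device_ns = get_torch_device()
device_ns.empty_cache()                      # same call on CUDA and NPU
local_device_index = device_ns.current_device()
```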
6 changes: 3 additions & 3 deletions verl/utils/fsdp_utils.py
@@ -31,9 +31,9 @@

def init_fn(x: torch.nn.Module):
if not torch.distributed.get_rank() == 0:
x = x.to_empty(device=torch.cuda.current_device() if is_cuda_available else torch.npu.current_device() if is_cuda_available else torch.npu.current_device(),
x = x.to_empty(device=torch.cuda.current_device() if is_cuda_available else torch.npu.current_device(),
recurse=False)
torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() if is_cuda_available else torch.npu.empty_cache() if is_cuda_available else torch.npu.empty_cache()
torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache()
return x


@@ -127,7 +127,7 @@ def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):
flat_param._local_shard = flat_param.data
assert id(flat_param._local_shard) != id(flat_param.data)
if empty_cache:
torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() if is_cuda_available else torch.npu.empty_cache()
torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache()


@torch.no_grad()
