megatron: drop the need for megatron patches #219

Closed
4 changes: 4 additions & 0 deletions .github/workflows/e2e_gsm8k_megatron.yml
@@ -45,5 +45,9 @@ jobs:
run: |
ray stop --force
[ ! -d "$HOME/Megatron-LM" ] && git clone -b core_v0.4.0_verl https://github.com/eric-haibin-lin/Megatron-LM $HOME/Megatron-LM
[ ! -d "$HOME/Megatron-LM-v0.4" ] && git clone -b core_v0.4.0 https://github.com/NVIDIA/Megatron-LM $HOME/Megatron-LM-v0.4
export PYTHONPATH=$PYTHONPATH:$HOME/Megatron-LM
bash tests/e2e/run_deepseek_megatron.sh
- name: Running gsm8k e2e training tests on 8 L20 GPUs with Megatron without patches
export PYTHONPATH=$HOME/Megatron-LM-v0.4:$PYTHONPATH
bash tests/e2e/run_deepseek_megatron.sh
76 changes: 39 additions & 37 deletions verl/models/llama/megatron/modeling_llama_megatron.py
@@ -84,7 +84,7 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
if megatron_config is not None:
assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
tp_utils.update_kwargs_with_config(embedding_kwargs, self.config)
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
**embedding_kwargs)
@@ -162,7 +162,7 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
tp_utils.update_kwargs_with_config(column_kwargs, self.config)

self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size,
output_size=config.vocab_size,
@@ -225,10 +225,10 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
self.megatron_config = megatron_config
self.config = megatron_config
if megatron_config is not None:
assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
tp_utils.update_kwargs_with_config(embedding_kwargs, self.config)
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
**embedding_kwargs)
@@ -257,7 +257,7 @@ def forward(self,

# (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
inputs_embeds = inputs_embeds.transpose(0, 1)
if self.megatron_config.sequence_parallel:
if self.config.sequence_parallel:
inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)

hidden_states = inputs_embeds
@@ -278,21 +278,22 @@ def forward(self,

class ParallelLlamaForCausalLMRmPad(nn.Module):

def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
def __init__(self, model_config: LlamaConfig, megatron_config: ModelParallelConfig):
super().__init__()
self.config = config
self.megatron_config = megatron_config
# Note(haibin.lin): to be compatible with Megatron APIs, model.config refers to megatron configs
self.config = megatron_config
self.model_config = model_config
self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config)
self.vocab_size = config.vocab_size
self.vocab_size = model_config.vocab_size
self._init_head()

def _init_head(self):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
if self.config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size,
output_size=self.config.vocab_size,
tp_utils.update_kwargs_with_config(column_kwargs, self.config)
self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.model_config.hidden_size,
output_size=self.model_config.vocab_size,
bias=False,
gather_output=False,
skip_bias_add=False,
@@ -328,7 +329,7 @@ def forward(

# pad input_ids to multiple of tp for all tp ranks
# TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
if self.megatron_config.sequence_parallel:
if self.config.sequence_parallel:
input_ids = sp_utils.pad_to_sequence_parallel(input_ids)

input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad)
@@ -345,7 +346,7 @@ def forward(
logits = self._forward_head(hidden_states)

# remove padding from sequence parallel
if self.megatron_config.sequence_parallel:
if self.config.sequence_parallel:
totol_nnz = cu_seqlens[-1]
logits = logits[:totol_nnz] # (total_nnz_padded)

@@ -367,17 +368,17 @@ class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):

def _init_head(self):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
if self.config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False)
tp_utils.update_kwargs_with_config(column_kwargs, self.config)
self.lm_head = nn.Linear(in_features=self.model_config.hidden_size, out_features=1, bias=False)
# lm_head is effectively the same as sequence parallel
sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)

def _forward_head(self, hidden_states):
logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1)
logits = logits.float()
if self.megatron_config.sequence_parallel:
if self.config.sequence_parallel:
logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
return logits

@@ -413,11 +414,11 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pr
self.vocab_size = config.vocab_size
self.pre_process = pre_process
self.post_process = post_process
self.megatron_config = megatron_config
self.config = megatron_config
embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
if megatron_config is not None:
assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
tp_utils.update_kwargs_with_config(embedding_kwargs, self.config)
if pre_process:
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
@@ -487,7 +488,7 @@ def forward(self,
# so need to deal with it by handle here:
# (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
inputs_embeds = inputs_embeds.transpose(0, 1)
if self.megatron_config.sequence_parallel:
if self.config.sequence_parallel:
inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)

hidden_states = inputs_embeds
@@ -513,16 +514,17 @@ def forward(self,

class ParallelLlamaForCausalLMRmPadPP(nn.Module):

def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
def __init__(self, model_config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
super().__init__()
self.config = config
self.megatron_config = megatron_config
self.model = ParallelLlamaModelRmPadPP(config,
# Note(haibin.lin): to be compatible with Megatron APIs, model.config refers to megatron configs
self.config = megatron_config
self.model_config = model_config
self.model = ParallelLlamaModelRmPadPP(model_config,
megatron_config=megatron_config,
pre_process=pre_process,
post_process=post_process)
self.share_embeddings_and_output_weights = None # workaround, megatron requires this attr
self.vocab_size = config.vocab_size
self.vocab_size = model_config.vocab_size
self.pre_process = pre_process
self.post_process = post_process
if post_process:
@@ -541,11 +543,11 @@ def set_input_tensor(self, input_tensor):

def _init_head(self):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
if self.config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size,
output_size=self.config.vocab_size,
tp_utils.update_kwargs_with_config(column_kwargs, self.config)
self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.model_config.hidden_size,
output_size=self.model_config.vocab_size,
bias=False,
gather_output=False,
skip_bias_add=False,
@@ -586,7 +588,7 @@ def forward(

# pad input_ids to multiple of tp for all tp ranks
# TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
if self.megatron_config.sequence_parallel:
if self.config.sequence_parallel:
input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)

input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad)
@@ -605,7 +607,7 @@ def forward(
logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16])

# remove padding from sequence parallel
if self.megatron_config.sequence_parallel:
if self.config.sequence_parallel:
totol_nnz = cu_seqlens[-1]
logits = logits[:totol_nnz] # (total_nnz_padded)
# add removed padding back. If input is already rmpad, we let the caller pad_input
@@ -627,17 +629,17 @@ class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):

def _init_head(self):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
if self.config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False)
tp_utils.update_kwargs_with_config(column_kwargs, self.config)
self.lm_head = nn.Linear(in_features=self.model_config.hidden_size, out_features=1, bias=False)
# lm_head is effectively the same as sequence parallel
sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)

def _forward_head(self, hidden_states):
logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1)
logits = logits.float()
if self.megatron_config.sequence_parallel:
if self.config.sequence_parallel:
logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
return logits

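The recurring change above is a naming convention: the wrapper modules now expose Megatron's ModelParallelConfig as self.config, which vanilla Megatron-LM code expects to find on the model, and keep the Hugging Face LlamaConfig under self.model_config. Below is a minimal sketch of that pattern, not the actual verl classes; the class name is illustrative and the imports assume Megatron-LM core v0.4 and transformers are installed.

from megatron.core import ModelParallelConfig   # Megatron-LM core parallel/runtime config
from torch import nn
from transformers import LlamaConfig            # Hugging Face model hyperparameters


class ParallelModelSketch(nn.Module):
    """Illustrative only: shows the config naming convention adopted in this PR."""

    def __init__(self, model_config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        # Megatron-LM APIs read parallelism settings from `model.config`
        # (e.g. config.sequence_parallel), so the ModelParallelConfig is stored there...
        self.config = megatron_config
        # ...while model shapes (hidden_size, vocab_size, ...) come from the HF config.
        self.model_config = model_config

With this split, unpatched Megatron-LM code that inspects model.config keeps working, which is what allows the new CI step above to run directly against NVIDIA/Megatron-LM core_v0.4.0.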
2 changes: 2 additions & 0 deletions verl/single_controller/ray/megatron.py
@@ -37,6 +37,8 @@ def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWi

class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
"""
Note(haibin.lin): this class is not used in the open source version of verl. Kept for internal reference only.

MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup
so that the dispatcher can use it to dispatch data.
"""
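As the docstring describes, the worker group queries each worker for its Megatron rank info once and caches it so the dispatcher can route data by parallel rank. The following is a minimal, self-contained sketch of that idea; the class, method, and field names are hypothetical and do not reflect verl's actual API.

from dataclasses import dataclass
from typing import Dict, List


@dataclass
class MegatronRankInfo:
    # Hypothetical container; verl's real structure may differ.
    tp_rank: int  # tensor-parallel rank
    pp_rank: int  # pipeline-parallel rank
    dp_rank: int  # data-parallel rank


class RankAwareWorkerGroup:
    """Illustrative only: query rank info from every worker once, then dispatch by rank."""

    def __init__(self, workers: List[object]):
        self._workers = workers
        # In verl this query would be a remote (Ray) call; here it is a plain method call.
        self._rank_info: Dict[int, MegatronRankInfo] = {
            idx: worker.get_megatron_rank_info() for idx, worker in enumerate(workers)
        }

    def workers_for_dp_rank(self, dp_rank: int) -> List[object]:
        # Dispatch helper: select the workers that hold a given data-parallel shard.
        return [self._workers[idx] for idx, info in self._rank_info.items() if info.dp_rank == dp_rank]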