diff --git a/.github/workflows/e2e_gsm8k_megatron.yml b/.github/workflows/e2e_gsm8k_megatron.yml
index 305d1724..5bb69bc3 100644
--- a/.github/workflows/e2e_gsm8k_megatron.yml
+++ b/.github/workflows/e2e_gsm8k_megatron.yml
@@ -45,5 +45,10 @@ jobs:
         run: |
           ray stop --force
           [ ! -d "$HOME/Megatron-LM" ] && git clone -b core_v0.4.0_verl https://github.com/eric-haibin-lin/Megatron-LM $HOME/Megatron-LM
+          [ ! -d "$HOME/Megatron-LM-v0.4" ] && git clone -b core_v0.4.0 https://github.com/NVIDIA/Megatron-LM $HOME/Megatron-LM-v0.4
           export PYTHONPATH=$PYTHONPATH:$HOME/Megatron-LM
+          bash tests/e2e/run_deepseek_megatron.sh
+      - name: Running gsm8k e2e training tests on 8 L20 GPUs with Megatron without patches
+        run: |
+          export PYTHONPATH=$HOME/Megatron-LM-v0.4:$PYTHONPATH
           bash tests/e2e/run_deepseek_megatron.sh
\ No newline at end of file
diff --git a/verl/models/llama/megatron/modeling_llama_megatron.py b/verl/models/llama/megatron/modeling_llama_megatron.py
index c693f33c..88ac4234 100644
--- a/verl/models/llama/megatron/modeling_llama_megatron.py
+++ b/verl/models/llama/megatron/modeling_llama_megatron.py
@@ -84,7 +84,7 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
         embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
         if megatron_config is not None:
             assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
-            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
+            tp_utils.update_kwargs_with_config(embedding_kwargs, self.config)
         self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                                    embedding_dim=config.hidden_size,
                                                                    **embedding_kwargs)
@@ -162,7 +162,7 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
         column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
         if megatron_config is not None:
             assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
-            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
+            tp_utils.update_kwargs_with_config(column_kwargs, self.config)

         self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size,
                                                             output_size=config.vocab_size,
@@ -225,10 +225,10 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
-        self.megatron_config = megatron_config
+        self.config = megatron_config
         if megatron_config is not None:
             assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
-            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
+            tp_utils.update_kwargs_with_config(embedding_kwargs, self.config)
         self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                                    embedding_dim=config.hidden_size,
                                                                    **embedding_kwargs)
@@ -257,7 +257,7 @@ def forward(self,
         # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
         inputs_embeds = inputs_embeds.transpose(0, 1)
-        if self.megatron_config.sequence_parallel:
+        if self.config.sequence_parallel:
             inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)

         hidden_states = inputs_embeds

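Editor's note: the hunks above route every `update_kwargs_with_config` call through the renamed `self.config` attribute. For readers unfamiliar with the helper, here is a minimal sketch of what `tp_utils.update_kwargs_with_config` is assumed to do; the actual implementation lives in verl's Megatron tensor-parallel utilities and may differ in detail.

```python
from megatron.core import ModelParallelConfig

def update_kwargs_with_config(kwargs: dict, config: ModelParallelConfig) -> dict:
    # Assumed behavior: Megatron's parallel layers (VocabParallelEmbedding,
    # ColumnParallelLinear, ...) accept a `config=ModelParallelConfig` kwarg,
    # so the helper simply injects the parallel config into the kwarg dict
    # built by get_default_kwargs_for_*().
    kwargs['config'] = config
    return kwargs
```
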
@@ -278,21 +278,22 @@

 class ParallelLlamaForCausalLMRmPad(nn.Module):

-    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
+    def __init__(self, model_config: LlamaConfig, megatron_config: ModelParallelConfig):
         super().__init__()
-        self.config = config
-        self.megatron_config = megatron_config
-        self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config)
-        self.vocab_size = config.vocab_size
+        # Note(haibin.lin): to be compatible with Megatron APIs, model.config refers to megatron configs
+        self.config = megatron_config
+        self.model_config = model_config
+        self.model = ParallelLlamaModelRmPad(model_config, megatron_config=megatron_config)
+        self.vocab_size = model_config.vocab_size
         self._init_head()

     def _init_head(self):
         column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
-        if self.megatron_config is not None:
+        if self.config is not None:
             assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
-            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
-        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size,
-                                                            output_size=self.config.vocab_size,
+            tp_utils.update_kwargs_with_config(column_kwargs, self.config)
+        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.model_config.hidden_size,
+                                                            output_size=self.model_config.vocab_size,
                                                             bias=False,
                                                             gather_output=False,
                                                             skip_bias_add=False,
@@ -328,7 +329,7 @@ def forward(

         # pad input_ids to multiple of tp for all tp ranks
         # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
-        if self.megatron_config.sequence_parallel:
+        if self.config.sequence_parallel:
             input_ids = sp_utils.pad_to_sequence_parallel(input_ids)

         input_ids = input_ids.transpose(0, 1)  # (1, total_nnz+pad)
@@ -345,7 +346,7 @@ def forward(
         logits = self._forward_head(hidden_states)

         # remove padding from sequence parallel
-        if self.megatron_config.sequence_parallel:
+        if self.config.sequence_parallel:
             totol_nnz = cu_seqlens[-1]
             logits = logits[:totol_nnz]  # (total_nnz_padded)

@@ -367,17 +368,17 @@ class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):

     def _init_head(self):
         column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
-        if self.megatron_config is not None:
+        if self.config is not None:
             assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
-            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
-        self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False)
+            tp_utils.update_kwargs_with_config(column_kwargs, self.config)
+        self.lm_head = nn.Linear(in_features=self.model_config.hidden_size, out_features=1, bias=False)
         # lm_head is effectively the same as sequence parallel
         sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)

     def _forward_head(self, hidden_states):
         logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)
         logits = logits.float()
-        if self.megatron_config.sequence_parallel:
+        if self.config.sequence_parallel:
             logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
         return logits

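Editor's note: with this rename, `self.config` holds the Megatron `ModelParallelConfig` (what Megatron-style call sites expect to find on a module) while the Hugging Face `LlamaConfig` moves to `self.model_config`. A minimal illustrative sketch of the convention; it assumes torch.distributed and the Megatron model-parallel groups are already initialized, and the constructor values are placeholders.

```python
from megatron.core import ModelParallelConfig
from transformers import LlamaConfig

from verl.models.llama.megatron.modeling_llama_megatron import ParallelLlamaForCausalLMRmPad

# Assumption: mpu.initialize_model_parallel(...) has already been called.
hf_config = LlamaConfig(hidden_size=4096, vocab_size=32000)
mp_config = ModelParallelConfig(tensor_model_parallel_size=2, sequence_parallel=True)

model = ParallelLlamaForCausalLMRmPad(model_config=hf_config, megatron_config=mp_config)

# Megatron-style code reads parallelism settings from `model.config` ...
assert model.config.sequence_parallel is True
# ... while model shapes still come from the Hugging Face config.
assert model.model_config.hidden_size == 4096
```
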
@@ -413,11 +414,11 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pr
         self.vocab_size = config.vocab_size
         self.pre_process = pre_process
         self.post_process = post_process
-        self.megatron_config = megatron_config
+        self.config = megatron_config
         embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
         if megatron_config is not None:
             assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
-            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
+            tp_utils.update_kwargs_with_config(embedding_kwargs, self.config)
         if pre_process:
             self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                                        embedding_dim=config.hidden_size,
@@ -487,7 +488,7 @@ def forward(self,
             # so need to deal with it by handle here:
             # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
             inputs_embeds = inputs_embeds.transpose(0, 1)
-            if self.megatron_config.sequence_parallel:
+            if self.config.sequence_parallel:
                 inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)

         hidden_states = inputs_embeds
@@ -513,16 +514,17 @@

 class ParallelLlamaForCausalLMRmPadPP(nn.Module):

-    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
+    def __init__(self, model_config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
         super().__init__()
-        self.config = config
-        self.megatron_config = megatron_config
-        self.model = ParallelLlamaModelRmPadPP(config,
+        # Note(haibin.lin): to be compatible with Megatron APIs, model.config refers to megatron configs
+        self.config = megatron_config
+        self.model_config = model_config
+        self.model = ParallelLlamaModelRmPadPP(model_config,
                                                megatron_config=megatron_config,
                                                pre_process=pre_process,
                                                post_process=post_process)
         self.share_embeddings_and_output_weights = None  # workaround, megatron requires this attr
-        self.vocab_size = config.vocab_size
+        self.vocab_size = model_config.vocab_size
         self.pre_process = pre_process
         self.post_process = post_process
         if post_process:
@@ -541,11 +543,11 @@ def set_input_tensor(self, input_tensor):

     def _init_head(self):
         column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
-        if self.megatron_config is not None:
+        if self.config is not None:
             assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
-            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
-        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size,
-                                                            output_size=self.config.vocab_size,
+            tp_utils.update_kwargs_with_config(column_kwargs, self.config)
+        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.model_config.hidden_size,
+                                                            output_size=self.model_config.vocab_size,
                                                             bias=False,
                                                             gather_output=False,
                                                             skip_bias_add=False,
@@ -586,7 +588,7 @@ def forward(

         # pad input_ids to multiple of tp for all tp ranks
         # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
-        if self.megatron_config.sequence_parallel:
+        if self.config.sequence_parallel:
             input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)

         input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz+pad)
@@ -605,7 +607,7 @@ def forward(
             logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension  # torch.Size([8, 32, 16])
             # remove padding from sequence parallel
-            if self.megatron_config.sequence_parallel:
+            if self.config.sequence_parallel:
                 totol_nnz = cu_seqlens[-1]
                 logits = logits[:totol_nnz]  # (total_nnz_padded)

             # add removed padding back. If input is already rmpad, we let the caller pad_input
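Editor's note: the `sequence_parallel` branches above only work when the packed (rmpad) token count is a multiple of the tensor-parallel world size, which is what `sp_utils.pad_to_sequence_parallel` guarantees before `scatter_to_sequence_parallel_region` splits the sequence. A rough sketch of the assumed helper; the real implementation in verl's sequence-parallel utils may differ, e.g. in the pad value.

```python
import torch
import torch.nn.functional as F
from megatron.core import parallel_state as mpu

def pad_to_sequence_parallel(unpad_tokens: torch.Tensor) -> torch.Tensor:
    """Pad a packed 1-D token tensor so its length divides evenly across
    the tensor-parallel ranks used for sequence parallelism (sketch)."""
    sp_world_size = mpu.get_tensor_model_parallel_world_size()
    total_nnz = unpad_tokens.shape[0]
    pad_size = (sp_world_size - total_nnz % sp_world_size) % sp_world_size
    if pad_size > 0:
        # pad value 0 is an assumption; the actual helper may use another filler
        unpad_tokens = F.pad(unpad_tokens, (0, pad_size), value=0)
    return unpad_tokens
```
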
@@ -627,17 +629,17 @@ class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):

     def _init_head(self):
         column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
-        if self.megatron_config is not None:
+        if self.config is not None:
             assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
-            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
-        self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False)
+            tp_utils.update_kwargs_with_config(column_kwargs, self.config)
+        self.lm_head = nn.Linear(in_features=self.model_config.hidden_size, out_features=1, bias=False)
         # lm_head is effectively the same as sequence parallel
         sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)

     def _forward_head(self, hidden_states):
         logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)
         logits = logits.float()
-        if self.megatron_config.sequence_parallel:
+        if self.config.sequence_parallel:
             logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
         return logits

diff --git a/verl/single_controller/ray/megatron.py b/verl/single_controller/ray/megatron.py
index 3ccb23a1..c24f47b9 100644
--- a/verl/single_controller/ray/megatron.py
+++ b/verl/single_controller/ray/megatron.py
@@ -37,6 +37,8 @@ def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWi

 class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
     """
+    Note(haibin.lin): this class is not used in the open source version of verl. Kept for internal reference only.
+
     MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup
     so that the dispatcher can use it to dispatch data.
     """
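Editor's note: for context on how the `pre_process` / `post_process` flags reach the pipeline-parallel model classes above, here is a hypothetical model-provider sketch in the usual Megatron style. The function name `megatron_llama_model_provider` is illustrative and not part of verl.

```python
from megatron.core import parallel_state as mpu

def megatron_llama_model_provider(hf_config, megatron_config):
    # Illustrative only: each pipeline stage builds its own chunk; the first
    # stage owns the embedding (pre_process) and the last owns the LM head
    # (post_process), matching how ParallelLlamaForCausalLMRmPadPP is gated.
    from verl.models.llama.megatron.modeling_llama_megatron import ParallelLlamaForCausalLMRmPadPP
    return ParallelLlamaForCausalLMRmPadPP(
        model_config=hf_config,
        megatron_config=megatron_config,
        pre_process=mpu.is_pipeline_first_stage(),
        post_process=mpu.is_pipeline_last_stage(),
    )
```
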
diff --git a/verl/utils/megatron/optimizer.py b/verl/utils/megatron/optimizer.py
index 9ae70b08..07aede38 100644
--- a/verl/utils/megatron/optimizer.py
+++ b/verl/utils/megatron/optimizer.py
@@ -90,3 +90,138 @@ def get_megatron_optimizer(
     # FP32.
     return FP32Optimizer(optimizer, config.clip_grad, config.log_num_zeros_in_grad, check_for_nan_in_loss_and_grad,
                          params_have_main_grad, model)
+
+
+def _init_distributed_optimizer(self, optimizer, clip_grad, log_num_zeros_in_grad, check_for_nan_in_grad,
+                                params_have_main_grad, fp16, bf16, params_dtype, grad_scaler, models,
+                                overlap_param_gather: bool):
+    """Megatron DistributedOptimizer __init__ without depending on the **get_args()** API.
+
+    See top of class definition for argument descriptions.
+
+    The steps in this method create the core mapping between DDP grad
+    buffers, parameters, and parameter shard ranges, that is needed for
+    converting between model param indexes and main parameter shard
+    indexes. This method also updates the optimizer parameter groups
+    with the newly created shards.
+    """
+    import torch
+    from megatron import get_args
+    from megatron import get_timers
+    from megatron import print_rank_0
+    from megatron.core import mpu, tensor_parallel
+
+    from megatron.optimizer.optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
+    from megatron.optimizer.utils import shard_buffer
+
+    super(DistributedOptimizer, self).__init__(optimizer, clip_grad, log_num_zeros_in_grad, check_for_nan_in_grad,
+                                               params_have_main_grad, fp16, bf16, params_dtype, grad_scaler, models)
+
+    assert isinstance(optimizer, Adam), \
+        "Only Adam currently supported, due to checkpointing requirements."
+
+    # Model grad buffer ranges.
+    self.model_gbuf_ranges = []
+    self.per_bucket_numel = []
+    for _, model_chunk in enumerate(self.models):
+        self.per_bucket_numel.append({
+            dtype: [bucket.data.numel() for bucket in model_chunk.grad_buffers[dtype].buckets]
+            for dtype in model_chunk.grad_buffers
+        })
+        self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model_chunk))
+    self.model_param_gbuf_map = \
+        self.build_model_param_gbuf_map(self.model_gbuf_ranges)
+
+    # Optimizer ranges.
+    self.model_param_group_index_map, self.opt_group_ranges = \
+        self.build_optimizer_group_ranges(self.optimizer.param_groups,
+                                          self.model_gbuf_ranges)
+
+    # Allocate main param shards.
+    (
+        self.model_float16_groups,
+        self.model_fp32_groups,
+        self.shard_float16_groups,
+        self.shard_fp32_groups,
+        self.shard_fp32_from_float16_groups,
+    ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges, self.model_param_gbuf_map,
+                                               self.opt_group_ranges)
+
+    # Initialize param buffers.
+    # - These are views on the DDP model's grad buffers, that share
+    #   storage & have their own dtype. This is safe because the param
+    #   dtype size is always <= grad dtype size.
+    self.param_buffers = []
+    for model_index, model in enumerate(self.models):
+        current_param_buffers = {}
+        for dtype, grad_buffer in model.grad_buffers.items():
+            size_ratio = torch.finfo(dtype).bits // torch.finfo(params_dtype).bits
+            current_param_buffers[dtype] = []
+            for bucket in grad_buffer.buckets:
+
+                # Handle older/newer method for getting untyped storage.
+                try:
+                    storage = bucket.data.storage()._untyped()
+                except:
+                    storage = bucket.data.storage().untyped()
+
+                # Typed param buffer.
+                param_buffer = torch.tensor(storage, dtype=params_dtype, device=bucket.data.device)
+
+                # .storage() ignores views / slices, so param_buffer now points to the start
+                # of the grad_buffer instead of to the start of each bucket. As a result,
+                # add bucket.offset to make sure param_buffers point to the right region of
+                # memory.
+                # Since we want the start of each bucket's param_buffer to coincide with the
+                # start of the same bucket's grad_buffer (this ensures that zeroing the grad
+                # buffer does not zero out params in the param_buffer before they are copied
+                # into the model_params), multiply the offset by the size ratio of grads and
+                # params.
+                offset = bucket.offset * size_ratio
+                param_buffer = param_buffer[offset:offset + bucket.data.numel()]
+                assert param_buffer.data_ptr() == bucket.data.data_ptr(), \
+                    "param_buffer and grad_buffer for same bucket should start at the same byte address"
+                assert param_buffer.numel() == bucket.data.numel(), \
+                    "param_buffer and grad_buffer for same bucket should have the same number of elements"
+                current_param_buffers[dtype].append(param_buffer)
+        self.param_buffers.append(current_param_buffers)
+
+    # Now construct data structures to manage all-gather handles.
+    self.all_gather_handles = []
+    self.all_gather_handle_index_to_bucket_index_map = []
+    self.model_index_to_all_gather_handle_index_map = {}
+    self.param_to_all_gather_handle_index_map = {}
+    self.param_buffer_copied = []
+
+    self.pbuf_view_items = self.get_model_param_buffer_dp_views()
+    for (model_index, dtype, bucket_index, _, _) in self.pbuf_view_items:
+        self.all_gather_handle_index_to_bucket_index_map.append((model_index, dtype, bucket_index))
+        all_gather_handle_index = len(self.all_gather_handle_index_to_bucket_index_map) - 1
+
+        # Store all all_gather_handle_indices relevant to a particular model chunk.
+        if model_index not in self.model_index_to_all_gather_handle_index_map:
+            self.model_index_to_all_gather_handle_index_map[model_index] = []
+        self.model_index_to_all_gather_handle_index_map[model_index].append(all_gather_handle_index)
+
+        for param in self.models[model_index].grad_buffers[dtype].buckets[bucket_index].params_list:
+            self.param_to_all_gather_handle_index_map[param] = all_gather_handle_index
+        self.param_buffer_copied.append(False)
+    self.num_all_gather_handles = len(self.all_gather_handle_index_to_bucket_index_map)
+
+    self.overlap_param_gather = overlap_param_gather
+    if self.overlap_param_gather:
+        self.remove_pre_hook_handle = torch.nn.modules.module.register_module_forward_pre_hook(
+            self._make_forward_pre_hook())
+    else:
+        self.remove_pre_hook_handle = None
+
+    self.update_successful = False
+
+    # Update optimizer groups.
+    # - Also, leverage state_dict() and load_state_dict() to
+    #   recast preexisting per-param state tensors.
+    self.optimizer.param_groups = \
+        [ g["orig_group"] for g in self.opt_group_ranges ]
+    self.optimizer.load_state_dict(self.optimizer.state_dict())
+
+
+DistributedOptimizer.__init__ = _init_distributed_optimizer
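Editor's note: the patch is installed by reassigning `DistributedOptimizer.__init__` at import time, so the verl module must be imported before any `DistributedOptimizer` is constructed. A minimal sketch of the assumed usage; the `megatron.optimizer.distrib_optimizer` import path follows the Megatron version targeted by this diff and may differ in other releases.

```python
# Importing the verl module installs the patched __init__, which takes
# `overlap_param_gather` explicitly instead of reading it from get_args().
import verl.utils.megatron.optimizer  # noqa: F401  (side-effect import)

from megatron.optimizer.distrib_optimizer import DistributedOptimizer

# Any DistributedOptimizer built after this point (e.g. through
# get_megatron_optimizer in the same file) now runs without get_args().
assert DistributedOptimizer.__init__.__name__ == "_init_distributed_optimizer"
```
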
diff --git a/verl/utils/megatron/pipeline_parallel.py b/verl/utils/megatron/pipeline_parallel.py
index 3a3790bb..27d197e2 100644
--- a/verl/utils/megatron/pipeline_parallel.py
+++ b/verl/utils/megatron/pipeline_parallel.py
@@ -49,3 +49,18 @@ def make_batch_generator(batches, vpp_size):
         # no vpp
         batch_generator = iter(batches)
     return batch_generator
+
+
+def require_extra_schedule_kwargs():
+    """Return True if the installed Megatron schedules expect the extra kwargs (e.g. hidden_size)
+    added by the patched core_v0.4.0_verl branch; vanilla mcore v0.4 does not accept them.
+    Works around megatron get_args() issues. To be dropped after mcore v0.7."""
+    from megatron.core.pipeline_parallel.schedules import forward_backward_no_pipelining
+    import inspect
+    num_args = len(inspect.signature(forward_backward_no_pipelining).parameters)
+    # vanilla mcore v0.4
+    if num_args == 9:
+        return False
+    elif num_args == 11:
+        # patched mcore v0.4 (core_v0.4.0_verl branch)
+        return True
+    else:
+        raise NotImplementedError("Unknown megatron version")
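Editor's note: the intended call pattern, which the actor/critic/reward workers below all follow, is to pass `hidden_size` (and, with sequence packing, `input_shapes`) only when the patched schedules accept them. A condensed sketch under those assumptions; it presumes a prepared `forward_step` and `batch_generator`, and `get_forward_backward_func` is the standard mcore entry point.

```python
from megatron.core import parallel_state as mpu
from megatron.core.pipeline_parallel import get_forward_backward_func

from verl.utils.megatron.pipeline_parallel import require_extra_schedule_kwargs

def run_forward_backward(module, forward_step, batch_generator, n_micro_batch,
                         seq_length, hidden_size, input_shapes, forward_only):
    schedule_kwargs = {}
    if require_extra_schedule_kwargs():
        # Patched core_v0.4.0_verl schedules take hidden_size explicitly.
        schedule_kwargs['hidden_size'] = hidden_size
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        # Needed for flash-attn sequence packing; only the patched schedules accept it.
        schedule_kwargs['input_shapes'] = input_shapes
    forward_backward_func = get_forward_backward_func()
    return forward_backward_func(
        forward_step_func=forward_step,
        data_iterator=batch_generator,
        model=module,
        num_microbatches=n_micro_batch,
        seq_length=seq_length,
        micro_batch_size=1,
        forward_only=forward_only,
        **schedule_kwargs,
    )
```
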
diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py
index 9ada0230..d27b95f9 100644
--- a/verl/workers/actor/megatron_actor.py
+++ b/verl/workers/actor/megatron_actor.py
@@ -302,20 +302,24 @@ def forward_step(batch_iter, model):

         # batch should be a list of batches inside micro-batches
         batch_generator = make_batch_generator(batches, vpp_size=len(self.actor_module))

+        from verl.utils.megatron.pipeline_parallel import require_extra_schedule_kwargs
+        schedule_kwargs = {}
+        if require_extra_schedule_kwargs():
+            schedule_kwargs = {'hidden_size': self.model_config.hidden_size}
         # TODO: we may use the new schedule instead
         # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
         if mpu.get_pipeline_model_parallel_world_size() > 1:
+            schedule_kwargs['input_shapes'] = input_shapes
             losses_reduced = forward_backward_func(
                 forward_step_func=forward_step,
                 data_iterator=batch_generator,
                 model=self.actor_module,
                 num_microbatches=n_micro_batch,
-                input_shapes=input_shapes,  # must set for flash-attn sequence packing
                 seq_length=batch_size * seq_len,  # no use when input_shapes was set
-                hidden_size=self.model_config.hidden_size,  # no use when input_shapes was set
                 micro_batch_size=1,  # no use when input_shapes was set
                 forward_only=forward_only,
+                **schedule_kwargs,
             )
         else:
             losses_reduced = forward_backward_func(
@@ -324,9 +328,9 @@ def forward_step(batch_iter, model):
                 model=self.actor_module,
                 num_microbatches=n_micro_batch,
                 seq_length=batch_size * seq_len,  # in use for pp = 1
-                hidden_size=self.model_config.hidden_size,  # in use for pp = 1
                 micro_batch_size=1,  # in use for pp = 1
                 forward_only=forward_only,
+                **schedule_kwargs,
             )
         # loss_reduces contains the stats returned from loss_func
         return losses_reduced
@@ -370,4 +374,4 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> Dict:
             # add empty cache after each compute
             torch.cuda.empty_cache()

-        return metrics
+        return metrics
\ No newline at end of file
diff --git a/verl/workers/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py
index f0b044ad..e2134c76 100644
--- a/verl/workers/critic/megatron_critic.py
+++ b/verl/workers/critic/megatron_critic.py
@@ -176,10 +176,15 @@ def forward_step(batch_iter, model):

         # batch should be a list of batches inside micro-batches
         batch_generator = make_batch_generator(batches, vpp_size=len(self.critic_module))

+        from verl.utils.megatron.pipeline_parallel import require_extra_schedule_kwargs
+        schedule_kwargs = {}
+        if require_extra_schedule_kwargs():
+            schedule_kwargs = {'hidden_size': self.model_config.hidden_size}
         # TODO: we may use the new schedule instead
         # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
         if mpu.get_pipeline_model_parallel_world_size() > 1:
+            schedule_kwargs['input_shapes'] = input_shapes
             losses_reduced = forward_backward_func(
                 forward_step_func=forward_step,
                 data_iterator=batch_generator,
@@ -187,9 +192,8 @@ def forward_step(batch_iter, model):
                 num_microbatches=n_micro_batch,
-                input_shapes=input_shapes,  # must set for flash-attn sequence packing
                 seq_length=self.config.ppo_micro_batch_size_per_gpu * seq_len,  # no use when input_shapes was set
-                hidden_size=self.model_config.hidden_size,  # no use when input_shapes was set
                 micro_batch_size=1,  # no use when input_shapes was set
                 forward_only=forward_only,
+                **schedule_kwargs,
             )
         else:
             losses_reduced = forward_backward_func(
@@ -198,9 +203,9 @@ def forward_step(batch_iter, model):
                 model=self.critic_module,
                 num_microbatches=n_micro_batch,
                 seq_length=self.config.ppo_micro_batch_size_per_gpu * seq_len,  # in use for pp = 1
-                hidden_size=self.model_config.hidden_size,  # in use for pp = 1
                 micro_batch_size=1,  # in use for pp = 1
                 forward_only=forward_only,
+                **schedule_kwargs,
             )
         # loss_reduces contains the stats returned from loss_func
         return losses_reduced
@@ -230,4 +235,4 @@ def update_critic(self, dataloader: Iterable[DataProto]):
             # add empty cache after each compute
             torch.cuda.empty_cache()

-        return metrics
+        return metrics
\ No newline at end of file
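Editor's note: all three workers rely on the same Megatron schedule contract: `forward_step_func` returns the output tensor together with a `loss_func`, and the stats dict that `loss_func` returns is what the schedule collects into `losses_reduced`. A schematic sketch of that contract; field names and the model call signature are illustrative rather than verl's exact ones.

```python
import torch

def forward_step(batch_iter, model):
    batch = next(batch_iter)
    output = model(input_ids=batch['input_ids'],
                   attention_mask=batch['attention_mask'],
                   position_ids=batch['position_ids'])

    def loss_func(output_tensor):
        # Illustrative: compute a scalar loss and return the per-micro-batch
        # stats that end up in `losses_reduced`.
        loss = output_tensor.float().mean()
        return loss, {'loss': loss.detach().item()}

    return output, loss_func

# forward_backward_func(...) returns a list with one stats dict per
# micro-batch, which update_policy()/update_critic() aggregate into metrics.
```
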
diff --git a/verl/workers/reward_model/megatron/reward_model.py b/verl/workers/reward_model/megatron/reward_model.py
index 1b58f42c..0cfc1d49 100644
--- a/verl/workers/reward_model/megatron/reward_model.py
+++ b/verl/workers/reward_model/megatron/reward_model.py
@@ -229,20 +229,23 @@ def forward_step(batch_iter, model):

         # batch should be a list of batches inside micro-batches
         batch_generator = make_batch_generator(batches, vpp_size=len(self.reward_model_module))
-
+        from verl.utils.megatron.pipeline_parallel import require_extra_schedule_kwargs
+        schedule_kwargs = {}
+        if require_extra_schedule_kwargs():
+            schedule_kwargs = {'hidden_size': self.model_config.hidden_size}
         # TODO: we may use the new schedule instead
         # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
         if mpu.get_pipeline_model_parallel_world_size() > 1:
+            schedule_kwargs['input_shapes'] = input_shapes  # must set for flash-attn sequence packing
             losses_reduced = forward_backward_func(
                 forward_step_func=forward_step,
                 data_iterator=batch_generator,
                 model=self.reward_model_module,
                 num_microbatches=n_micro_batch,
-                input_shapes=input_shapes,  # must set for flash-attn sequence packing
                 seq_length=infer_batch_size * seq_len,  # no use when input_shapes was set
-                hidden_size=self.model_config.hidden_size,  # no use when input_shapes was set
                 micro_batch_size=1,  # no use when input_shapes was set
                 forward_only=True,
+                **schedule_kwargs,  # hidden_size is not used when input_shapes is set
             )
         else:
             losses_reduced = forward_backward_func(
@@ -251,9 +254,9 @@ def forward_step(batch_iter, model):
                 model=self.reward_model_module,
                 num_microbatches=n_micro_batch,
                 seq_length=infer_batch_size * seq_len,  # in use for pp = 1
-                hidden_size=self.model_config.hidden_size,  # in use for pp = 1
                 micro_batch_size=1,  # in use for pp = 1
                 forward_only=True,
+                **schedule_kwargs,  # hidden_size is only used when pp = 1
             )
         # loss_reduces contains the stats returned from loss_func