diff --git a/.github/workflows/e2e_sft.yml b/.github/workflows/e2e_sft.yml
index 7e80fe5a..4d42430a 100644
--- a/.github/workflows/e2e_sft.yml
+++ b/.github/workflows/e2e_sft.yml
@@ -43,4 +43,12 @@ jobs:
       - name: Running gsm8k e2e training tests on 8 L20 GPUs with rmpad using function rm
         run: |
           ray stop --force
-          bash tests/sft/run_sft.sh
\ No newline at end of file
+          bash tests/sft/run_sft.sh
+      - name: Running gsm8k e2e training tests on 8 L20 GPUs with sequence parallelism
+        run: |
+          ray stop --force
+          bash examples/sft/gsm8k/run_qwen_05_sp2.sh 8 $HOME/ckpts/
+      - name: Check loss difference between sequence parallel vs. default implementation
+        run: |
+          ray stop --force
+          bash tests/sft/run_sft_sp_loss_match.sh
diff --git a/examples/sft/gsm8k/run_qwen_05_sp2.sh b/examples/sft/gsm8k/run_qwen_05_sp2.sh
new file mode 100755
index 00000000..a27cef1d
--- /dev/null
+++ b/examples/sft/gsm8k/run_qwen_05_sp2.sh
@@ -0,0 +1,32 @@
+set -x
+
+if [ "$#" -lt 2 ]; then
+    echo "Usage: run_qwen_05_sp2.sh <nproc_per_node> <save_path> [other_configs...]"
+    exit 1
+fi
+
+nproc_per_node=$1
+save_path=$2
+
+# Shift the arguments so $@ refers to the rest
+shift 2
+
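+# Runs SFT with Ulysses sequence parallelism: each sequence is sharded across 2 GPUs
+# (ulysses_sequence_parallel_size=2); the trainer requires use_remove_padding=true for SP.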
+torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
+     -m verl.trainer.fsdp_sft_trainer \
+    data.train_files=$HOME/data/gsm8k/train.parquet \
+    data.val_files=$HOME/data/gsm8k/test.parquet \
+    data.prompt_key=extra_info \
+    data.response_key=extra_info \
+    optim.lr=1e-4 \
+    +data.prompt_dict_keys=['question'] \
+    +data.response_dict_keys=['answer'] \
+    data.micro_batch_size=4 \
+    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
+    trainer.default_local_dir=$save_path \
+    trainer.project_name=gsm8k-sft \
+    trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct-sp2 \
+    trainer.logger=['console'] \
+    trainer.total_training_steps=1 \
+    trainer.default_hdfs_dir=null $@ \
+    ulysses_sequence_parallel_size=2 \
+    use_remove_padding=true
diff --git a/tests/sft/run_sft_sp_loss_match.sh b/tests/sft/run_sft_sp_loss_match.sh
new file mode 100644
index 00000000..a63328ec
--- /dev/null
+++ b/tests/sft/run_sft_sp_loss_match.sh
@@ -0,0 +1,24 @@
+# Tested with 2 & 4 GPUs
+
+set -x
+
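+# Launches 8 ranks with ulysses_sequence_parallel_size=2, i.e. 4 data-parallel groups of 2 SP ranks each.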
+torchrun --standalone --nnodes=1 --nproc_per_node=8 \
+    tests/sft/test_sp_loss_match.py \
+    data.train_files=$HOME/data/gsm8k/train.parquet \
+    data.val_files=$HOME/data/gsm8k/test.parquet \
+    data.prompt_key=extra_info \
+    data.response_key=extra_info \
+    +data.prompt_dict_keys=['question'] \
+    +data.response_dict_keys=['answer'] \
+    data.micro_batch_size=32 \
+    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
+    ulysses_sequence_parallel_size=2 \
+    use_remove_padding=True \
+    trainer.default_local_dir=$HOME/ckpts/ \
+    trainer.project_name=qwen2.5-sft \
+    trainer.experiment_name=gsm8k-sft-gemma-2b-it \
+    trainer.total_training_steps=1 \
+    trainer.logger=['console'] \
+    trainer.default_hdfs_dir=null $@
+
+rm -rf $HOME/ckpts/
diff --git a/tests/sft/test_sp_loss_match.py b/tests/sft/test_sp_loss_match.py
new file mode 100644
index 00000000..69223d3d
--- /dev/null
+++ b/tests/sft/test_sp_loss_match.py
@@ -0,0 +1,128 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+from tensordict import TensorDict
+from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer
+from torch.distributed.device_mesh import init_device_mesh
+from verl.utils.distributed import initialize_global_process_group
+
+
+def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = 4):
+    """Test consistency between original forward pass and SP+rmpad forward passes.
+
+    Args:
+        trainer: The FSDPSFTTrainer instance to test
+        total_steps: Number of steps to test (default: 4)
+    """
+    if trainer.device_mesh.get_rank() == 0:
+        print("\nStarting debug comparison between original and SP+rmpad forward passes...")
+        print(f"Sequence parallel size: {trainer.config.ulysses_sequence_parallel_size}")
+        print(f"Remove padding: {trainer.use_remove_padding}\n")
+
+    steps_remaining = total_steps
+
+    for epoch in range(1):  # Just one epoch for testing
+        trainer.train_sampler.set_epoch(epoch=epoch)
+        for data in trainer.train_dataloader:
+            data = TensorDict(data, batch_size=trainer.config.data.train_batch_size).cuda()
+            trainer.fsdp_model.train()
+            micro_batches = data.split(trainer.config.data.micro_batch_size)
+
+            for idx, micro_batch in enumerate(micro_batches):
+                if trainer.device_mesh.get_rank() == 0:
+                    print(f"\nProcessing micro batch {idx + 1}/{len(micro_batches)}")
+
+                # Compute losses using both methods
+                # Disable SP and rmpad
+                trainer.use_remove_padding = False
+                old_sp = trainer.config.ulysses_sequence_parallel_size
+                trainer.config.ulysses_sequence_parallel_size = 1
+                loss_ref = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False)
+
+                # Do SP and rmpad
+                trainer.config.ulysses_sequence_parallel_size = old_sp
+                trainer.use_remove_padding = True
+                loss_sp = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False)
+
+                # Collect losses across all ranks
+                loss_ref_all = loss_ref.clone()
+                loss_sp_all = loss_sp.clone()
+                torch.distributed.all_reduce(loss_ref_all, op=torch.distributed.ReduceOp.AVG)
+                torch.distributed.all_reduce(loss_sp_all, op=torch.distributed.ReduceOp.AVG)
+
+                # Calculate relative difference of averaged losses
+                rel_diff = torch.abs(loss_ref_all - loss_sp_all) / (torch.abs(loss_ref_all) + 1e-8)
+
+                if trainer.device_mesh.get_rank() == 0:
+                    print("\nComparison Results (Averaged across ranks):")
+                    print(f"Reference Loss: {loss_ref_all.item():.6f}")
+                    print(f"SP+rmpad Loss: {loss_sp_all.item():.6f}")
+                    print(f"Relative Difference: {rel_diff.item():.6f}")
+
+                    assert rel_diff.item() < 1e-2, "Significant difference detected between averaged losses!"
+                    print("Loss difference is within the acceptable range.")
+
+                steps_remaining -= 1
+                if steps_remaining == 0:
+                    break
+            if steps_remaining == 0:
+                break
+            break
+
+    if trainer.device_mesh.get_rank() == 0:
+        print("\nDebug comparison completed successfully.")
+
+
+def create_trainer(config):
+    """Create and initialize a trainer instance with the given config.
+
+    Args:
+        config: Configuration object with training parameters
+
+    Returns:
+        FSDPSFTTrainer: Initialized trainer instance
+    """
+    local_rank, rank, world_size = initialize_global_process_group()
+
+    device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('fsdp',))
+
+    dp_size = world_size // config.ulysses_sequence_parallel_size
+    ulysses_device_mesh = init_device_mesh(device_type='cuda',
+                                           mesh_shape=(dp_size, config.ulysses_sequence_parallel_size),
+                                           mesh_dim_names=('dp', 'sp'))
+
+    return FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh)
+
+
+def main(config):
+    """Main function to run trainer tests.
+
+    Args:
+        config: Configuration object with training parameters
+    """
+    trainer = create_trainer(config)
+    test_trainer_forward_consistency(trainer)
+
+
+if __name__ == '__main__':
+    import hydra
+    from omegaconf import DictConfig
+
+    @hydra.main(config_path="../../verl/trainer/config", config_name="sft_trainer")
+    def hydra_entry(cfg: DictConfig) -> None:
+        main(cfg)
+
+    hydra_entry()
diff --git a/verl/trainer/config/sft_trainer.yaml b/verl/trainer/config/sft_trainer.yaml
index d68e6dcc..9ac707ad 100644
--- a/verl/trainer/config/sft_trainer.yaml
+++ b/verl/trainer/config/sft_trainer.yaml
@@ -22,14 +22,15 @@ model:
   trust_remote_code: False
   lora_rank: 0  # Set to positive value to enable LoRA (e.g., 32)
   lora_alpha: 16  # LoRA scaling factor
-  target_modules: [q_proj, v_proj]  # Target modules for LoRA adaptation
+  target_modules: all-linear  # Target modules for LoRA adaptation
 optim:
   lr: 1e-5
   betas: [0.9, 0.95]
   weight_decay: 0.01
   warmup_steps_ratio: 0.1
   clip_grad: 1.0
-
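+# Ulysses sequence parallelism: 1 disables SP; a value > 1 shards each sequence across that
+# many GPUs and requires use_remove_padding: True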
+ulysses_sequence_parallel_size: 1
+use_remove_padding: False
 trainer:
   default_local_dir: /tmp/sft_model
   default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here
diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py
index dc7e677a..51efb4b9 100644
--- a/verl/trainer/fsdp_sft_trainer.py
+++ b/verl/trainer/fsdp_sft_trainer.py
@@ -25,26 +25,33 @@
 import logging
 import re
+from contextlib import nullcontext
 import torch
 import torch.distributed
 from torch import nn, optim
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, CPUOffload
+from tqdm import tqdm
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, AutoConfig
 from verl.utils.torch_functional import get_cosine_schedule_with_warmup
 from tensordict import TensorDict
 from torch.utils.data import DataLoader, DistributedSampler
+from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis
 from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
 from verl.utils.dataset import SFTDataset
 from verl.utils.fs import copy_local_path_from_hdfs
 from verl.utils.tracking import Tracking
-
+from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group
 from torch.distributed.device_mesh import DeviceMesh
 import verl.utils.hdfs_io as hdfs_io
 from verl.utils.debug import log_gpu_memory_usage
 from peft import LoraConfig, TaskType, get_peft_model
+from verl.workers.sharding_manager import FSDPUlyssesShardingManager
+from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
+from verl import DataProto
+
 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN'))
@@ -70,9 +77,11 @@ def convert_to_regular_types(obj):
 class FSDPSFTTrainer(object):

-    def __init__(self, config, device_mesh: DeviceMesh):
+    def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh):
         self.config = config
         self.device_mesh = device_mesh
+        self.ulysses_device_mesh = ulysses_device_mesh
+        self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
         # build tokenizer first
         local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True)
         from verl.utils import hf_tokenizer
@@ -83,6 +92,13 @@ def __init__(self, config, device_mesh: DeviceMesh):
         # normalize dp size
         self._normalize_config_bsz()

+        # Set sequence parallel size
+        self.config.ulysses_sequence_parallel_size = getattr(self.config, 'ulysses_sequence_parallel_size', 1)
+        self.use_remove_padding = getattr(self.config, 'use_remove_padding', False)
+        if self.device_mesh.get_rank() == 0:
+            print(f'Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}')
+            print(f'Using remove padding: {self.use_remove_padding}')
+
         self._build_dataloader()
         # build model
         self._build_model_optimizer()
@@ -92,11 +108,11 @@ def __init__(self, config, device_mesh: DeviceMesh):
             print(self.config)

     def _normalize_config_bsz(self):
-        dp_size = self.device_mesh.size()
+        dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0)
         if self.device_mesh.get_rank() == 0:
             print(f'Normalize batch size by dp {dp_size}')

-        assert self.config.data.train_batch_size % dp_size == 0
+        assert self.config.data.train_batch_size % dp_size == 0, f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}"

         self.config.data.train_batch_size //= dp_size

@@ -123,8 +139,21 @@ def _build_dataloader(self):
                                          truncation=config.data.truncation)

         # build dataloader
-        rank = self.device_mesh.get_rank()
-        world_size = self.device_mesh.size()
+        # Use data parallel rank and size instead of global rank and world size
+
+        # If doing SP, we need to use the local rank and size
+        if self.config.ulysses_sequence_parallel_size > 1:
+            rank = self.ulysses_device_mesh.get_local_rank('dp')
+            world_size = self.ulysses_device_mesh.size(0)
+            if self.ulysses_device_mesh.get_rank() == 0:
+                print(f'Using SP rank {rank} and size {world_size} for data distribution')
+                print(f'Each SP rank gets different data, but the same data WITHIN the same rank')
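+                # e.g. with 8 GPUs and sp=2 there are 4 dp ranks: the sampler shards the dataset
+                # over those 4 dp ranks, and the 2 SP ranks inside a dp group read the same shard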
+        else:
+            rank = self.device_mesh.get_rank()
+            world_size = self.device_mesh.size()
+            if self.device_mesh.get_rank() == 0:
+                print(f'Using FSDP rank {rank} and size {world_size} for data distribution')
+
         self.train_sampler = DistributedSampler(self.train_dataset,
                                                 shuffle=True,
                                                 num_replicas=world_size,
@@ -165,6 +194,14 @@ def _build_model_optimizer(self):
         trust_remote_code = self.config.model.trust_remote_code
         # load config first
         config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)
+        if self.config.ulysses_sequence_parallel_size > 1:
+            assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled"
+            from verl.models.registry import check_model_support_rmpad
+            check_model_support_rmpad(config.model_type)
+
+        if self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1:
+            from verl.models.transformers.monkey_patch import apply_monkey_patch
+            apply_monkey_patch(config, verbose=True)

         # This may be very large
         init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings)
@@ -228,53 +265,116 @@ def _build_model_optimizer(self):
         log_gpu_memory_usage('After initialize optimizer', logger=logger)

-        steps_per_epoch = len(self.train_dataloader)
-        total_steps = steps_per_epoch * self.config.trainer.total_epochs
+        self.steps_per_epoch = len(self.train_dataloader)
+        self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs

         if self.device_mesh.get_rank() == 0:
             print(
-                f'Number of steps/epoch {steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {total_steps}'
+                f'Number of steps/epoch {self.steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {self.total_steps}'
             )

-        num_warmup_steps = int(total_steps * self.config.optim.warmup_steps_ratio)
+        num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio)

         self.lr_scheduler = get_cosine_schedule_with_warmup(optimizer=self.optimizer,
                                                             num_warmup_steps=num_warmup_steps,
-                                                            num_training_steps=total_steps)
-
-    def _compute_loss(self, batch):
-        loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).cuda()
-        labels = batch['input_ids'][:, 1:].cuda()
-
-        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
-            output = self.fsdp_model(input_ids=batch['input_ids'],
-                                     attention_mask=batch['attention_mask'],
-                                     position_ids=batch['position_ids'],
-                                     use_cache=False)  # prevent model thinks it it generating
+                                                            num_training_steps=self.total_steps)

-        logits = output.logits
+    def _compute_loss_and_backward(self, batch, do_backward=True):
+        """Compute loss with optional sequence parallelism and remove padding features"""
+        use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1
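+        # The two branches below (plain forward vs. remove-padding + Ulysses SP) are checked
+        # for matching losses by tests/sft/test_sp_loss_match.py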
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels.contiguous()
-        # Flatten the tokens
+        # Move inputs to GPU and prepare loss mask
+        input_ids = batch['input_ids'].cuda()
+        attention_mask = batch['attention_mask'].cuda()
+        position_ids = batch['position_ids'].cuda()
+        loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).cuda()
         loss_fct = nn.CrossEntropyLoss(reduction='none')
-        shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
-        shift_labels = shift_labels.view(-1)
-        # Enable model parallelism
-        shift_labels = shift_labels.to(shift_logits.device)
-        loss = loss_fct(shift_logits, shift_labels)
-        loss = loss * loss_mask
-
-        valid_token_this_rank = torch.sum(loss_mask)
-
-        if self.config.data.balance_dp_token:
-            torch.distributed.all_reduce(valid_token_this_rank)  # becomes total valid tokens in all ranks
-            dp_size = torch.distributed.get_world_size()
-        else:
-            dp_size = 1
-        loss = torch.sum(loss) / valid_token_this_rank * dp_size  # possible bugs here for dp
-        return loss
+        # Context manager for sequence parallel if needed
+        context = self.sharding_manager if use_sp else nullcontext()
+        with context:
+            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
+                if not use_sp:
+                    # Standard forward pass without sequence parallel
+                    labels = input_ids[:, 1:].contiguous()
+                    output = self.fsdp_model(input_ids=input_ids,
+                                             attention_mask=attention_mask,
+                                             position_ids=position_ids,
+                                             use_cache=False)
+                    logits = output.logits
+
+                    shift_logits = logits[..., :-1, :].contiguous()
+                    shift_labels = labels.contiguous()
+                    # Flatten the tokens
+                    shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
+                    shift_labels = shift_labels.view(-1)
+                    # Enable model parallelism
+                    shift_labels = shift_labels.to(shift_logits.device)
+                    loss = loss_fct(shift_logits, shift_labels)
+                    loss = loss * loss_mask.to(loss.device)
+                else:
+                    # IMPORTANT: We have a big assumption here, so we can shard the SAME sequence across SP ranks
+                    # i.e., each GPU has <1 sequence, and each SP group has 1 sequence
+                    # 1. All SP ranks will receive the *SAME* batch
+                    # 2. Different SP groups will receive *DIFFERENT* batches
+                    # This is implemented by the DistributedSampler
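+                    # Illustration: with sp=2, a micro batch packed into 1023 valid tokens is
+                    # padded to 1024 and each SP rank runs the model on its own 512-token slice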
+
+                    batch_size, seqlen = input_ids.shape
+                    # Remove padding
+                    input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
+                                                               attention_mask)  # input_ids_rmpad (total_nnz, ...)
+                    input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)
+
+                    # Unpad position_ids to align rotary
+                    position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
+                                                          indices).transpose(0, 1)
+
+                    # Pad and slice inputs for sequence parallelism
+                    input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
+                        input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size())
+                    # For computing loss
+                    input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)
+                    input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
+                        input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size())
+                    input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)
+
+                    # Forward pass
+                    output = self.fsdp_model(
+                        input_ids=input_ids_rmpad_sliced,
+                        attention_mask=None,  # Not needed with flash attention varlen
+                        position_ids=position_ids_rmpad_padded,
+                        use_cache=False)
+
+                    # Compute loss locally then aggregate
+                    logits_rmpad = output.logits.squeeze(0)
+                    input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device)
+                    loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled)
+                    # Gather and unpad for sequence parallelism
+                    loss = gather_outpus_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size)
+
+                    # This is the loss collected from all ulysses ranks
+                    full_loss = pad_input(hidden_states=loss.unsqueeze(-1),
+                                          indices=indices,
+                                          batch=batch_size,
+                                          seqlen=seqlen)
+                    full_loss = full_loss.squeeze(-1)[:, :-1]  # Remove last token's loss
+                    full_loss = full_loss.reshape(-1)
+                    loss_mask = loss_mask.to(full_loss.device)
+                    loss = full_loss * loss_mask
+
+        valid_token_this_rank = torch.sum(loss_mask)
+
+        if self.config.data.balance_dp_token:
+            torch.distributed.all_reduce(valid_token_this_rank)
+            dp_size = self.ulysses_device_mesh.size('dp') if use_sp else torch.distributed.get_world_size()
+        else:
+            dp_size = 1
+
+        loss = torch.sum(loss) / valid_token_this_rank * dp_size
+
+        if do_backward:
+            loss.backward()
+        return loss

     def training_step(self, batch: TensorDict):
         self.fsdp_model.train()
@@ -289,8 +389,7 @@ def training_step(self, batch: TensorDict):
         n_micro_batches = len(micro_batches)
         step_loss = 0
         for micro_batch in micro_batches:
-            loss = self._compute_loss(batch=micro_batch) / n_micro_batches
-            loss.backward()
+            loss = self._compute_loss_and_backward(batch=micro_batch) / n_micro_batches
             step_loss += loss.item()

         self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)
@@ -315,7 +414,7 @@ def training_step(self, batch: TensorDict):
     def validation_step(self, batch: TensorDict):
         self.fsdp_model.eval()
         with torch.no_grad():
-            loss = self._compute_loss(batch)
+            loss = self._compute_loss_and_backward(batch, do_backward=False)
             torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)
             return loss

@@ -361,7 +460,9 @@ def fit(self):
         for epoch in range(self.config.trainer.total_epochs):
             self.train_sampler.set_epoch(epoch=epoch)
-            for data in self.train_dataloader:
+            for data in tqdm(self.train_dataloader,
+                             total=self.steps_per_epoch,
+                             desc=f"Epoch {epoch+1}/{self.config.trainer.total_epochs}"):
                 data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda()
                 metric = self.training_step(data)
                 if rank == 0:
@@ -414,8 +515,12 @@ def fit(self):

 def main(config):
     local_rank, rank, world_size = initialize_global_process_group()
-    device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('dp',))
-    trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh)
+    device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('fsdp',))
+    dp_size = world_size // config.ulysses_sequence_parallel_size
+    ulysses_device_mesh = init_device_mesh(device_type='cuda',
+                                           mesh_shape=(dp_size, config.ulysses_sequence_parallel_size),
+                                           mesh_dim_names=('dp', 'sp'))
+    trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh)
     trainer.fit()
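
For reference, the new example and the loss-parity test can be launched directly with the same arguments the CI workflow uses (8 GPUs, checkpoints under $HOME/ckpts/):

    # SFT on GSM8K with Ulysses sequence parallelism (sp=2) on 8 GPUs
    bash examples/sft/gsm8k/run_qwen_05_sp2.sh 8 $HOME/ckpts/

    # Check that the SP + remove-padding loss matches the default implementation
    bash tests/sft/run_sft_sp_loss_match.sh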