From 9fe76c042933e7d015d578179facdcdbf22fa5b8 Mon Sep 17 00:00:00 2001 From: jankinf Date: Wed, 26 Feb 2025 11:57:07 +0800 Subject: [PATCH] incorporate fsdp ckpt loading func into main_generation (#298) --- scripts/main_generation.sh | 19 ++++++++++ verl/trainer/config/generation.yaml | 5 +++ verl/trainer/main_generation.py | 17 +++------ verl/utils/fsdp_utils.py | 54 +++++++++++++++++++++++++++++ verl/workers/fsdp_workers.py | 14 ++++++-- 5 files changed, 93 insertions(+), 16 deletions(-) create mode 100644 scripts/main_generation.sh diff --git a/scripts/main_generation.sh b/scripts/main_generation.sh new file mode 100644 index 00000000..43ddccfe --- /dev/null +++ b/scripts/main_generation.sh @@ -0,0 +1,19 @@ +export HYDRA_FULL_ERROR=1 +export NCCL_DEBUG='WARN' +export TOKENIZERS_PARALLELISM='true' + +export HF_HUB_OFFLINE=1 +export HF_DATASETS_OFFLINE=1 + +export CUDA_VISIBLE_DEVICES=0 + +python3 -m verl.trainer.main_generation \ + trainer.n_gpus_per_node=1 \ + data.path='/data/countdown/test.parquet' \ + data.output_path='out_test.parquet' \ + data.batch_size=8 \ + data.n_samples=1 \ + model.path='Qwen/Qwen2.5-3B' \ + model.fsdp_model_path='/checkpoints/grpo-countdown-qwen2.5-3b/global_step_200/actor' \ + rollout.do_sample=False \ + rollout.response_length=1024 diff --git a/verl/trainer/config/generation.yaml b/verl/trainer/config/generation.yaml index 14cd2e5d..13805608 100644 --- a/verl/trainer/config/generation.yaml +++ b/verl/trainer/config/generation.yaml @@ -11,7 +11,9 @@ data: model: path: ~/models/Qwen2-7B-Instruct + fsdp_model_path: null external_lib: null + rollout: name: vllm temperature: 1.0 @@ -19,6 +21,7 @@ rollout: top_p: 0.7 prompt_length: 1536 response_length: 512 + # for vllm rollout dtype: bfloat16 # should align with FSDP gpu_memory_utilization: 0.5 @@ -31,11 +34,13 @@ rollout: max_num_seqs: 1024 log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu log_prob_micro_batch_size_per_gpu: 8 + # for hf rollout 
def _shard_concat_dim(value):
    """Return the dimension along which per-rank shards of *value* are joined.

    Returns ``0`` for values that carry no ``placements`` attribute (plain
    tensors and other objects), which keeps the historical "concatenate on
    dim 0" behaviour. For DTensor-like values the first shard placement's
    ``dim`` is returned; ``None`` is returned when no placement is a shard,
    i.e. every rank already holds the full value.

    NOTE(review): relies on the DTensor placement API (``is_shard()`` /
    ``dim``); only a 1-D device mesh is handled -- confirm if HSDP
    checkpoints must be supported.
    """
    placements = getattr(value, "placements", None)
    if placements is None:
        return 0
    for placement in placements:
        if placement.is_shard():
            return placement.dim
    return None


def load_sharded_model(fsdp_checkpoint_path):
    """Merge per-rank FSDP checkpoint shards into a single state dict.

    The directory is expected to contain one file per rank named
    ``model_world_size_<W>_rank_<R>.pt`` for ``R`` in ``[0, W)``.

    Args:
        fsdp_checkpoint_path: Directory holding the shard files.

    Returns:
        A dict mapping each parameter name to its consolidated value. Values
        that cannot be concatenated (e.g. 0-dim tensors or non-tensor
        entries) and replicated DTensors are taken from the first shard.

    Raises:
        ValueError: If no shard files are found, if the files disagree on the
            world size, or if the shard of any rank is missing.
    """
    checkpoint_dir = Path(fsdp_checkpoint_path)

    shard_files = list(checkpoint_dir.glob("model_world_size_*_rank_*.pt"))
    print("fsdp_checkpoint_path: ", fsdp_checkpoint_path)
    print("shard_files: ", shard_files)
    if not shard_files:
        raise ValueError(f"No checkpoint files found in {fsdp_checkpoint_path}")

    pattern = re.compile(r"model_world_size_(\d+)_rank_(\d+)\.pt")
    matches = (pattern.match(shard_file.name) for shard_file in shard_files)
    world_sizes = {int(match.group(1)) for match in matches if match}

    if len(world_sizes) != 1:
        raise ValueError(
            f"Inconsistent world_size found in checkpoint files: {world_sizes}"
        )

    world_size = world_sizes.pop()
    print(f"Found checkpoints with world_size = {world_size}")

    shards_by_key = defaultdict(list)
    concat_dims = {}
    for rank in range(world_size):
        filepath = checkpoint_dir / f"model_world_size_{world_size}_rank_{rank}.pt"
        if not filepath.exists():
            raise ValueError(f"Missing shard file: {filepath}")

        print(f"Loading shard: {filepath}")
        # Shards are written from GPU ranks; map them to CPU so loading works
        # on hosts with fewer (or no) GPUs and does not occupy device memory.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        shard_dict = torch.load(filepath, map_location="cpu")

        for key, value in shard_dict.items():
            # Placement info is only available before to_local() strips it.
            concat_dims.setdefault(key, _shard_concat_dim(value))
            if hasattr(value, "to_local"):
                value = value.to_local()
            shards_by_key[key].append(value)

    consolidated_state_dict = {}
    # Pop each key as it is merged so the per-rank pieces can be freed early.
    for key in list(shards_by_key):
        shards = shards_by_key.pop(key)
        dim = concat_dims[key]
        if dim is None:
            # Replicated value: every rank holds the same full tensor.
            consolidated_state_dict[key] = shards[0]
            continue
        try:
            consolidated_state_dict[key] = torch.cat(shards, dim=dim)
        except (RuntimeError, TypeError):
            # Best-effort fallback kept from the original implementation.
            consolidated_state_dict[key] = shards[0]
            print(
                f"Parameter '{key}' does not need concatenation, using first shard value"
            )

    return consolidated_state_dict
import register, Dispatch @@ -32,7 +32,7 @@ from verl.utils.fs import copy_local_path_from_hdfs from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_model_to_cpu, load_fsdp_optimizer, \ - load_fsdp_model_to_gpu + load_fsdp_model_to_gpu, load_sharded_model from verl.utils.import_utils import import_external_libs from verl.utils.model import compute_position_id_with_mask from verl.utils.flops_counter import FlopsCounter @@ -137,6 +137,7 @@ def __init__(self, config: DictConfig, role: str): def _build_model_optimizer(self, model_path, + fsdp_model_path, fsdp_config, optim_config, override_model_config, @@ -244,6 +245,12 @@ def _build_model_optimizer(self, # We force reference policy to use CPUOffload to save memory. # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation cpu_offload = None if role == 'actor' else CPUOffload(offload_params=True) + + if fsdp_model_path: + print("loading fsdp_model_path") + consolidated_state_dict = load_sharded_model(fsdp_model_path) + actor_module.load_state_dict(consolidated_state_dict) + actor_module_fsdp = FSDP( actor_module, cpu_offload=cpu_offload, @@ -348,6 +355,7 @@ def init_model(self): fsdp_config = OmegaConf.create() self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer( model_path=self.config.model.path, + fsdp_model_path=self.config.model.get("fsdp_model_path", None), fsdp_config=fsdp_config, optim_config=optim_config, override_model_config=override_model_config,