incorporate fsdp ckpt loading func into main_generation (volcengine#298)
jankinf committed Feb 26, 2025
1 parent b0514e2 commit 9fe76c0
Showing 5 changed files with 93 additions and 16 deletions.
19 changes: 19 additions & 0 deletions scripts/main_generation.sh
@@ -0,0 +1,19 @@
export HYDRA_FULL_ERROR=1
export NCCL_DEBUG='WARN'
export TOKENIZERS_PARALLELISM='true'

export HF_HUB_OFFLINE=1
export HF_DATASETS_OFFLINE=1

export CUDA_VISIBLE_DEVICES=0

python3 -m verl.trainer.main_generation \
trainer.n_gpus_per_node=1 \
data.path='/data/countdown/test.parquet' \
data.output_path='out_test.parquet' \
data.batch_size=8 \
data.n_samples=1 \
model.path='Qwen/Qwen2.5-3B' \
model.fsdp_model_path='/checkpoints/grpo-countdown-qwen2.5-3b/global_step_200/actor' \
rollout.do_sample=False \
rollout.response_length=1024
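
Note: model.fsdp_model_path points at the sharded actor checkpoint written at global_step_200; the new load_sharded_model helper in verl/utils/fsdp_utils.py (below) rebuilds a full state dict from the model_world_size_*_rank_*.pt shard files in that directory.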
5 changes: 5 additions & 0 deletions verl/trainer/config/generation.yaml
@@ -11,14 +11,17 @@ data:

model:
path: ~/models/Qwen2-7B-Instruct
fsdp_model_path: null
external_lib: null

rollout:
name: vllm
temperature: 1.0
top_k: 50 # 0 for hf rollout, -1 for vllm rollout
top_p: 0.7
prompt_length: 1536
response_length: 512

# for vllm rollout
dtype: bfloat16 # should align with FSDP
gpu_memory_utilization: 0.5
@@ -31,11 +34,13 @@ rollout:
max_num_seqs: 1024
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: 8

# for hf rollout
do_sample: True
disable_log_stats: True
enable_chunked_prefill: True
n: 1

actor:
strategy: fsdp # This is for backward-compatibility
ppo_mini_batch_size: 256
17 changes: 4 additions & 13 deletions verl/trainer/main_generation.py
@@ -14,25 +14,16 @@
"""
Generate responses given a dataset of prompts
"""
import os
import ray
import numpy as np
import hydra
import os

os.environ['NCCL_DEBUG'] = 'WARN'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
# os.environ['TORCH_COMPILE_DISABLE'] = '1'

from verl.utils.model import compute_position_id_with_mask

import numpy as np
import pandas as pd

from transformers import AutoTokenizer

from verl import DataProto
from verl.utils.hdfs_io import makedirs
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.model import compute_position_id_with_mask
from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.utils.hdfs_io import makedirs
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup


54 changes: 54 additions & 0 deletions verl/utils/fsdp_utils.py
@@ -18,6 +18,9 @@
import math
import itertools
import os
import re
from pathlib import Path
from collections import defaultdict
from contextlib import contextmanager
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -336,3 +339,54 @@ def init_fn(sub_mod: torch.nn.Module, recurse: bool = True):
return sub_mod

return init_fn


def load_sharded_model(fsdp_checkpoint_path):
state_dict = defaultdict(list)
checkpoint_dir = Path(fsdp_checkpoint_path)

shard_files = list(checkpoint_dir.glob("model_world_size_*_rank_*.pt"))
print("fsdp_checkpoint_path: ", fsdp_checkpoint_path)
print("shard_files: ", shard_files)
if not shard_files:
raise ValueError(f"No checkpoint files found in {fsdp_checkpoint_path}")

pattern = re.compile(r"model_world_size_(\d+)_rank_(\d+)\.pt")
world_sizes = set()
for file in shard_files:
match = pattern.match(file.name)
if match:
world_sizes.add(int(match.group(1)))

if len(world_sizes) != 1:
raise ValueError(
f"Inconsistent world_size found in checkpoint files: {world_sizes}"
)

world_size = world_sizes.pop()
print(f"Found checkpoints with world_size = {world_size}")

for rank in range(world_size):
filepath = checkpoint_dir / f"model_world_size_{world_size}_rank_{rank}.pt"
if not filepath.exists():
raise ValueError(f"Missing shard file: {filepath}")

print(f"Loading shard: {filepath}")
shard_dict = torch.load(filepath)

for key, value in shard_dict.items():
if hasattr(value, "to_local"):
value = value.to_local()
state_dict[key].append(value)

consolidated_state_dict = {}
for key in state_dict:
try:
consolidated_state_dict[key] = torch.cat(state_dict[key], dim=0)
except (RuntimeError, TypeError):
consolidated_state_dict[key] = state_dict[key][0]
print(
f"Parameter '{key}' does not need concatenation, using first shard value"
)

return consolidated_state_dict
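
A minimal standalone sketch of using the new helper outside the trainer (the checkpoint path is reused from scripts/main_generation.sh above; the loop is just a sanity check):

from verl.utils.fsdp_utils import load_sharded_model

ckpt_dir = "/checkpoints/grpo-countdown-qwen2.5-3b/global_step_200/actor"
state_dict = load_sharded_model(ckpt_dir)

# Peek at a few consolidated entries to confirm shapes and dtypes look right.
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape), tensor.dtype)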
14 changes: 11 additions & 3 deletions verl/workers/fsdp_workers.py
@@ -15,15 +15,15 @@
The main entry point to run the PPO algorithm
"""

import logging
import os
import logging
import warnings

import torch
import torch.distributed
from torch.distributed.device_mesh import init_device_mesh
import verl.utils.torch_functional as verl_F
from omegaconf import DictConfig, open_dict
from torch.distributed.device_mesh import init_device_mesh
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
@@ -32,7 +32,7 @@
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_model_to_cpu, load_fsdp_optimizer, \
load_fsdp_model_to_gpu
load_fsdp_model_to_gpu, load_sharded_model
from verl.utils.import_utils import import_external_libs
from verl.utils.model import compute_position_id_with_mask
from verl.utils.flops_counter import FlopsCounter
@@ -137,6 +137,7 @@ def __init__(self, config: DictConfig, role: str):

def _build_model_optimizer(self,
model_path,
fsdp_model_path,
fsdp_config,
optim_config,
override_model_config,
@@ -244,6 +245,12 @@ def _build_model_optimizer(self,
# We force reference policy to use CPUOffload to save memory.
# We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation
cpu_offload = None if role == 'actor' else CPUOffload(offload_params=True)

if fsdp_model_path:
print("loading fsdp_model_path")
consolidated_state_dict = load_sharded_model(fsdp_model_path)
actor_module.load_state_dict(consolidated_state_dict)

actor_module_fsdp = FSDP(
actor_module,
cpu_offload=cpu_offload,
@@ -348,6 +355,7 @@ def init_model(self):
fsdp_config = OmegaConf.create()
self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer(
model_path=self.config.model.path,
fsdp_model_path=self.config.model.get("fsdp_model_path", None),
fsdp_config=fsdp_config,
optim_config=optim_config,
override_model_config=override_model_config,
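
Condensed, the ordering the worker change relies on is: materialize the consolidated weights on the unwrapped Hugging Face module first, then hand the module to FSDP, which shards the now-correct weights across ranks. A minimal single-process sketch of that ordering (the process-group setup here is illustrative only; verl builds a device mesh and wrap policy instead):

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM
from verl.utils.fsdp_utils import load_sharded_model

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)

module = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B")
# 1) Load the consolidated full state dict into the *unwrapped* module.
module.load_state_dict(
    load_sharded_model("/checkpoints/grpo-countdown-qwen2.5-3b/global_step_200/actor"))
# 2) Only then wrap with FSDP, so shards are carved from the correct weights.
fsdp_module = FSDP(module, device_id=torch.cuda.current_device())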
