diff --git a/scripts/eval_on_data.sh b/scripts/eval_on_data.sh
new file mode 100644
index 00000000..6bd8373b
--- /dev/null
+++ b/scripts/eval_on_data.sh
@@ -0,0 +1,15 @@
+export HYDRA_FULL_ERROR=1
+export HF_HUB_OFFLINE=1
+export HF_DATASETS_OFFLINE=1
+export CUDA_VISIBLE_DEVICES=1
+
+python3 -m verl.trainer.eval_on_data \
+    data.val_files='/data/countdown/test.parquet' \
+    actor_rollout_ref.rollout.micro_batch_size=1 \
+    actor_rollout_ref.rollout.do_sample=False \
+    actor_rollout_ref.rollout.response_length=1024 \
+    actor_rollout_ref.rollout.top_p=1 \
+    actor_rollout_ref.rollout.top_k=0 \
+    actor_rollout_ref.rollout.temperature=0 \
+    actor_rollout_ref.model.hf_model_path='Qwen/Qwen2.5-3B' \
+    actor_rollout_ref.model.fsdp_model_path='/checkpoints/grpo-countdown-qwen2.5-3b/global_step_200/actor'
diff --git a/scripts/main_generation.sh b/scripts/main_generation.sh
new file mode 100644
index 00000000..43ddccfe
--- /dev/null
+++ b/scripts/main_generation.sh
@@ -0,0 +1,19 @@
+export HYDRA_FULL_ERROR=1
+export NCCL_DEBUG='WARN'
+export TOKENIZERS_PARALLELISM='true'
+
+export HF_HUB_OFFLINE=1
+export HF_DATASETS_OFFLINE=1
+
+export CUDA_VISIBLE_DEVICES=0
+
+python3 -m verl.trainer.main_generation \
+    trainer.n_gpus_per_node=1 \
+    data.path='/data/countdown/test.parquet' \
+    data.output_path='out_test.parquet' \
+    data.batch_size=8 \
+    data.n_samples=1 \
+    model.path='Qwen/Qwen2.5-3B' \
+    model.fsdp_model_path='/checkpoints/grpo-countdown-qwen2.5-3b/global_step_200/actor' \
+    rollout.do_sample=False \
+    rollout.response_length=1024
diff --git a/verl/trainer/config/eval_on_data.yaml b/verl/trainer/config/eval_on_data.yaml
new file mode 100644
index 00000000..d7f0ef35
--- /dev/null
+++ b/verl/trainer/config/eval_on_data.yaml
@@ -0,0 +1,19 @@
+data:
+  val_files: "/data/countdown/test.parquet"
+  max_prompt_length: 512
+  batch_size: 8
+  shuffle: True
+  drop_last: True
+
+actor_rollout_ref:
+  model:
+    hf_model_path: "Qwen/Qwen2.5-3B"
+    fsdp_model_path: null
+  rollout:
+    micro_batch_size: 1
+    do_sample: False
+    response_length: 1024
+    top_p: 1
+    top_k: 0
+    temperature: 0
+
diff --git a/verl/trainer/config/generation.yaml b/verl/trainer/config/generation.yaml
index 14cd2e5d..13805608 100644
--- a/verl/trainer/config/generation.yaml
+++ b/verl/trainer/config/generation.yaml
@@ -11,7 +11,9 @@ data:
 
 model:
   path: ~/models/Qwen2-7B-Instruct
+  fsdp_model_path: null
   external_lib: null
+
 rollout:
   name: vllm
   temperature: 1.0
@@ -19,6 +21,7 @@ rollout:
   top_p: 0.7
   prompt_length: 1536
   response_length: 512
+
   # for vllm rollout
   dtype: bfloat16 # should align with FSDP
   gpu_memory_utilization: 0.5
@@ -31,11 +34,13 @@ rollout:
   max_num_seqs: 1024
   log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
   log_prob_micro_batch_size_per_gpu: 8
+
   # for hf rollout
   do_sample: True
   disable_log_stats: True
   enable_chunked_prefill: True
   n: 1
+
 actor:
   strategy: fsdp  # This is for backward-compatibility
   ppo_mini_batch_size: 256
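# Note (not part of the patch): in both config files `fsdp_model_path` defaults to
# null, so the sharded-checkpoint branch is skipped unless a path is supplied on the
# command line. A minimal sketch of how the default and the CLI override used in
# scripts/eval_on_data.sh resolve, assuming OmegaConf is installed and the repo root
# is the working directory:
from omegaconf import OmegaConf

cfg = OmegaConf.load("verl/trainer/config/eval_on_data.yaml")
assert cfg.actor_rollout_ref.model.fsdp_model_path is None  # null -> plain HF weights only

override = OmegaConf.from_dotlist([
    "actor_rollout_ref.model.fsdp_model_path=/checkpoints/grpo-countdown-qwen2.5-3b/global_step_200/actor",
])
cfg = OmegaConf.merge(cfg, override)
print(cfg.actor_rollout_ref.model.fsdp_model_path)  # now points at the sharded actor checkpoint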
diff --git a/verl/trainer/eval_on_data.py b/verl/trainer/eval_on_data.py
new file mode 100644
index 00000000..234ad200
--- /dev/null
+++ b/verl/trainer/eval_on_data.py
@@ -0,0 +1,198 @@
+import hydra
+import numpy as np
+import re
+import torch
+import torch.distributed
+from pathlib import Path
+from torch.utils.data import DataLoader
+from transformers import AutoModelForCausalLM, AutoConfig
+from collections import defaultdict
+
+from verl import DataProto
+from verl.utils import hf_tokenizer
+from verl.utils.fs import copy_local_path_from_hdfs
+from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
+from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
+from verl.workers.reward_manager import NaiveRewardManager
+from verl.workers.rollout.hf_rollout import HFRollout
+
+
+def load_sharded_model(fsdp_checkpoint_path):
+    state_dict = defaultdict(list)
+    checkpoint_dir = Path(fsdp_checkpoint_path)
+
+    shard_files = list(checkpoint_dir.glob("model_world_size_*_rank_*.pt"))
+    if not shard_files:
+        raise ValueError(f"No checkpoint files found in {fsdp_checkpoint_path}")
+
+    pattern = re.compile(r"model_world_size_(\d+)_rank_(\d+)\.pt")
+    world_sizes = set()
+    for file in shard_files:
+        match = pattern.match(file.name)
+        if match:
+            world_sizes.add(int(match.group(1)))
+
+    if len(world_sizes) != 1:
+        raise ValueError(
+            f"Inconsistent world_size found in checkpoint files: {world_sizes}"
+        )
+
+    world_size = world_sizes.pop()
+    print(f"Found checkpoints with world_size = {world_size}")
+
+    for rank in range(world_size):
+        filepath = checkpoint_dir / f"model_world_size_{world_size}_rank_{rank}.pt"
+        if not filepath.exists():
+            raise ValueError(f"Missing shard file: {filepath}")
+
+        print(f"Loading shard: {filepath}")
+        shard_dict = torch.load(filepath)
+
+        for key, value in shard_dict.items():
+            if hasattr(value, "to_local"):
+                value = value.to_local()
+            state_dict[key].append(value)
+
+    consolidated_state_dict = {}
+    for key in state_dict:
+        try:
+            consolidated_state_dict[key] = torch.cat(state_dict[key], dim=0)
+        except (RuntimeError, TypeError):
+            consolidated_state_dict[key] = state_dict[key][0]
+            print(
+                f"Parameter '{key}' does not need concatenation, using first shard value"
+            )
+
+    return consolidated_state_dict
+
+
+def initialize_model_and_tokenizer(
+    model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
+):
+    local_path = copy_local_path_from_hdfs(model_path)
+    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
+
+    actor_model_config = AutoConfig.from_pretrained(
+        local_path, trust_remote_code=trust_remote_code
+    )
+    actor_module = AutoModelForCausalLM.from_pretrained(
+        pretrained_model_name_or_path=local_path,
+        torch_dtype=torch_dtype,
+        config=actor_model_config,
+        attn_implementation="flash_attention_2",
+        trust_remote_code=trust_remote_code,
+    )
+
+    return tokenizer, actor_module
+
+
+@hydra.main(config_path='config', config_name='eval_on_data', version_base=None)
+def main(config):
+    # Load the HuggingFace-style checkpoint, e.g. "Qwen/Qwen2.5-3B" or a local checkpoint path
+    model_path = config.actor_rollout_ref.model.hf_model_path
+    tokenizer, actor_module = initialize_model_and_tokenizer(model_path)
+
+    # Optionally load a sharded FSDP checkpoint on top of the already-initialized actor_module
+    fsdp_model_path = config.actor_rollout_ref.model.get("fsdp_model_path", None)
+    if fsdp_model_path is not None:
+        state_dict = load_sharded_model(fsdp_model_path)
+        actor_module.load_state_dict(state_dict)
+
+    actor_module.to(torch.bfloat16)
+    actor_module.to("cuda:0")
+
+    val_files = config.data.val_files
+    val_dataset = RLHFDataset(
+        parquet_files=val_files,
+        tokenizer=tokenizer,
+        prompt_key="prompt",
+        max_prompt_length=config.data.max_prompt_length,
+        filter_prompts=True,
+        return_raw_chat=False,
+        truncation="error",
+    )
+    val_dataloader = DataLoader(
+        dataset=val_dataset,
+        batch_size=config.data.batch_size,
+        shuffle=config.data.shuffle,
+        drop_last=config.data.drop_last,
+        collate_fn=collate_fn,
+    )
+
+    assert len(val_dataloader) >= 1
+
+    val_reward_fn = NaiveRewardManager(
+        tokenizer=tokenizer, num_examine=1, compute_score=None
+    )
+
+    hfrollout = HFRollout(module=actor_module, config=config.actor_rollout_ref.rollout)
+
+    sample_inputs = []
+    sample_outputs = []
+    sample_scores = []
+    reward_tensor_lst = []
+    data_source_lst = []
+
+    for data in val_dataloader:
+        test_batch = DataProto.from_single_dict(data)
+        test_batch = test_batch.to("cuda")
+        input_ids = test_batch.batch["input_ids"]
+        input_texts = [
+            tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids
+        ]
+        sample_inputs.extend(input_texts)
+
+        test_gen_batch = test_batch.pop(["input_ids", "attention_mask", "position_ids"])
+        test_gen_batch.meta_info = {
+            "eos_token_id": tokenizer.eos_token_id,
+            "pad_token_id": tokenizer.pad_token_id,
+            "recompute_log_prob": False,
+            "do_sample": False,
+            "validate": True,
+        }
+
+        test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, 1)
+        test_output_gen_batch_padded = hfrollout.generate_sequences(
+            test_gen_batch_padded
+        )
+        test_output_gen_batch = unpad_dataproto(
+            test_output_gen_batch_padded, pad_size=pad_size
+        )
+        print("validation generation end")
+
+        output_ids = test_output_gen_batch.batch["responses"]
+        output_texts = [
+            tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids
+        ]
+        sample_outputs.extend(output_texts)
+        test_batch = test_batch.union(test_output_gen_batch)
+
+        reward_tensor = val_reward_fn(test_batch)
+        scores = reward_tensor.sum(-1).cpu().tolist()
+        sample_scores.extend(scores)
+        reward_tensor_lst.append(reward_tensor)
+        data_source_lst.append(
+            test_batch.non_tensor_batch.get(
+                "data_source", ["unknown"] * reward_tensor.shape[0]
+            )
+        )
+
+    reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu()
+    data_sources = np.concatenate(data_source_lst, axis=0)
+
+    data_source_reward = {}
+    for i in range(reward_tensor.shape[0]):
+        data_source = data_sources[i]
+        if data_source not in data_source_reward:
+            data_source_reward[data_source] = []
+        data_source_reward[data_source].append(reward_tensor[i].item())
+
+    metric_dict = {}
+    for data_source, rewards in data_source_reward.items():
+        metric_dict[f"val/test_score/{data_source}"] = np.mean(rewards)
+
+    print(metric_dict)
+
+
+if __name__ == "__main__":
+    main()
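# Note (not part of the patch): a minimal, self-contained sketch of the metric
# aggregation at the end of main() above. NaiveRewardManager returns a token-level
# reward tensor of shape [batch_size, response_length]; summing over the last dim
# gives one score per sample, which is then averaged per data source. Shapes and
# values below are made up purely for illustration.
import numpy as np
import torch

reward_tensor = torch.zeros(4, 8)           # [batch_size, response_length]
reward_tensor[[0, 2], -1] = 1.0             # pretend two of four answers were scored correct
data_sources = np.array(["countdown"] * 4)

scores = reward_tensor.sum(-1).numpy()      # one scalar score per sample
metric_dict = {}
for source in np.unique(data_sources):
    metric_dict[f"val/test_score/{source}"] = float(scores[data_sources == source].mean())
print(metric_dict)                          # {'val/test_score/countdown': 0.5}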
diff --git a/verl/trainer/main_generation.py b/verl/trainer/main_generation.py
index 044c6e4d..c52e4ed9 100644
--- a/verl/trainer/main_generation.py
+++ b/verl/trainer/main_generation.py
@@ -14,25 +14,16 @@
 """
 Generate responses given a dataset of prompts
 """
+import os
 import ray
-import numpy as np
 import hydra
-import os
-
-os.environ['NCCL_DEBUG'] = 'WARN'
-os.environ['TOKENIZERS_PARALLELISM'] = 'true'
-# os.environ['TORCH_COMPILE_DISABLE'] = '1'
-
-from verl.utils.model import compute_position_id_with_mask
-
+import numpy as np
 import pandas as pd
-
-from transformers import AutoTokenizer
-
 from verl import DataProto
+from verl.utils.hdfs_io import makedirs
 from verl.utils.fs import copy_local_path_from_hdfs
+from verl.utils.model import compute_position_id_with_mask
 from verl.workers.fsdp_workers import ActorRolloutRefWorker
-from verl.utils.hdfs_io import makedirs
 from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
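# Note (not part of the patch): the NCCL_DEBUG / TOKENIZERS_PARALLELISM settings
# removed from this module now live in scripts/main_generation.sh. If that script is
# not used, the same effect can be had by exporting the variables before launch, or
# from Python before any worker is created; a sketch:
import os

os.environ.setdefault("NCCL_DEBUG", "WARN")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")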
diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py
index 26b7dbd5..c82c69f2 100644
--- a/verl/utils/fsdp_utils.py
+++ b/verl/utils/fsdp_utils.py
@@ -18,6 +18,9 @@
 import math
 import itertools
 import os
+import re
+from pathlib import Path
+from collections import defaultdict
 from contextlib import contextmanager
 from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -336,3 +339,54 @@ def init_fn(sub_mod: torch.nn.Module, recurse: bool = True):
         return sub_mod
 
     return init_fn
+
+
+def load_sharded_model(fsdp_checkpoint_path):
+    state_dict = defaultdict(list)
+    checkpoint_dir = Path(fsdp_checkpoint_path)
+
+    shard_files = list(checkpoint_dir.glob("model_world_size_*_rank_*.pt"))
+    print("fsdp_checkpoint_path: ", fsdp_checkpoint_path)
+    print("shard_files: ", shard_files)
+    if not shard_files:
+        raise ValueError(f"No checkpoint files found in {fsdp_checkpoint_path}")
+
+    pattern = re.compile(r"model_world_size_(\d+)_rank_(\d+)\.pt")
+    world_sizes = set()
+    for file in shard_files:
+        match = pattern.match(file.name)
+        if match:
+            world_sizes.add(int(match.group(1)))
+
+    if len(world_sizes) != 1:
+        raise ValueError(
+            f"Inconsistent world_size found in checkpoint files: {world_sizes}"
+        )
+
+    world_size = world_sizes.pop()
+    print(f"Found checkpoints with world_size = {world_size}")
+
+    for rank in range(world_size):
+        filepath = checkpoint_dir / f"model_world_size_{world_size}_rank_{rank}.pt"
+        if not filepath.exists():
+            raise ValueError(f"Missing shard file: {filepath}")
+
+        print(f"Loading shard: {filepath}")
+        shard_dict = torch.load(filepath)
+
+        for key, value in shard_dict.items():
+            if hasattr(value, "to_local"):
+                value = value.to_local()
+            state_dict[key].append(value)
+
+    consolidated_state_dict = {}
+    for key in state_dict:
+        try:
+            consolidated_state_dict[key] = torch.cat(state_dict[key], dim=0)
+        except (RuntimeError, TypeError):
+            consolidated_state_dict[key] = state_dict[key][0]
+            print(
+                f"Parameter '{key}' does not need concatenation, using first shard value"
+            )
+
+    return consolidated_state_dict
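# Note (not part of the patch): load_sharded_model() expects a directory containing
# one model_world_size_<N>_rank_<R>.pt file per training rank and concatenates the
# rank shards along dim 0. A usage sketch that consolidates such a checkpoint and
# re-exports it in plain HuggingFace format; paths are placeholders, and it assumes
# the consolidated state dict matches the HF module's parameter shapes, as produced
# by verl's FSDP checkpointing:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from verl.utils.fsdp_utils import load_sharded_model

state_dict = load_sharded_model("/checkpoints/grpo-countdown-qwen2.5-3b/global_step_200/actor")

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B", torch_dtype=torch.bfloat16)
model.load_state_dict(state_dict)
model.save_pretrained("/checkpoints/grpo-countdown-qwen2.5-3b/global_step_200/actor_hf")

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
tokenizer.save_pretrained("/checkpoints/grpo-countdown-qwen2.5-3b/global_step_200/actor_hf")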
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 1660bd06..cdba8cd9 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -15,15 +15,15 @@
 The main entry point to run the PPO algorithm
 """
 
-import logging
 import os
+import logging
 import warnings
 
 import torch
 import torch.distributed
-from torch.distributed.device_mesh import init_device_mesh
 import verl.utils.torch_functional as verl_F
 from omegaconf import DictConfig, open_dict
+from torch.distributed.device_mesh import init_device_mesh
 from verl import DataProto
 from verl.single_controller.base import Worker
 from verl.single_controller.base.decorator import register, Dispatch
@@ -32,7 +32,7 @@ from verl.utils.fs import copy_local_path_from_hdfs
 from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
 from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_model_to_cpu, load_fsdp_optimizer, \
-    load_fsdp_model_to_gpu
+    load_fsdp_model_to_gpu, load_sharded_model
 from verl.utils.import_utils import import_external_libs
 from verl.utils.model import compute_position_id_with_mask
 from verl.utils.flops_counter import FlopsCounter
@@ -137,6 +137,7 @@ def __init__(self, config: DictConfig, role: str):
 
     def _build_model_optimizer(self,
                                model_path,
+                               fsdp_model_path,
                                fsdp_config,
                                optim_config,
                                override_model_config,
@@ -244,6 +245,12 @@
             # We force reference policy to use CPUOffload to save memory.
             # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation
             cpu_offload = None if role == 'actor' else CPUOffload(offload_params=True)
+
+            if fsdp_model_path:
+                print(f"Loading sharded FSDP checkpoint from {fsdp_model_path}")
+                consolidated_state_dict = load_sharded_model(fsdp_model_path)
+                actor_module.load_state_dict(consolidated_state_dict)
+
             actor_module_fsdp = FSDP(
                 actor_module,
                 cpu_offload=cpu_offload,
@@ -348,6 +355,7 @@ def init_model(self):
             fsdp_config = OmegaConf.create()
         self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer(
             model_path=self.config.model.path,
+            fsdp_model_path=self.config.model.get("fsdp_model_path", None),
             fsdp_config=fsdp_config,
             optim_config=optim_config,
             override_model_config=override_model_config,
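# Note (not part of the patch): the worker reads the new knob with .get(), so existing
# configs that do not define `fsdp_model_path` keep working unchanged. A small sketch of
# that behaviour, using throwaway OmegaConf configs:
from omegaconf import OmegaConf

old_cfg = OmegaConf.create({"model": {"path": "Qwen/Qwen2.5-3B"}})
new_cfg = OmegaConf.create({
    "model": {
        "path": "Qwen/Qwen2.5-3B",
        "fsdp_model_path": "/checkpoints/grpo-countdown-qwen2.5-3b/global_step_200/actor",
    }
})

print(old_cfg.model.get("fsdp_model_path", None))  # None -> shard loading is skipped
print(new_cfg.model.get("fsdp_model_path", None))  # path -> load_sharded_model() is called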