How to load trained ckpt and eval with it on specific dataset? #298
Additional information:
However, the save_checkpoint used in this repository writes one shard per distributed rank (https://github.com/volcengine/verl/blob/main/verl/utils/checkpoint/fsdp_checkpoint_manager.py#L134), producing files of the form model_world_size_{world_size}_rank_{rank}.pt rather than a single HuggingFace-style checkpoint. These shards can be loaded back by the provided checkpoint manager during training, but it is not obvious how to load them directly for standalone evaluation.
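For illustration, here is a minimal sketch (the checkpoint path is hypothetical; the shard file-name pattern matches the scripts shared below) that lists the per-rank shards of such a checkpoint:
from pathlib import Path

# Hypothetical directory produced by save_checkpoint; adjust to your own run.
ckpt_dir = Path("checkpoints/TinyZero/countdown-qwen2.5-3b-grpo/actor/global_step_400")

# Each training rank writes its own shard, e.g. model_world_size_2_rank_0.pt,
# model_world_size_2_rank_1.pt, ...
for shard in sorted(ckpt_dir.glob("model_world_size_*_rank_*.pt")):
    print(shard.name)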
I provide my code as follows for easy reproduction of the results:
import hydra
import numpy as np
import torch
import torch.distributed
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoConfig
from verl import DataProto
from verl.utils import hf_tokenizer
from verl.utils.model import get_generation_config
from verl.utils.fs import copy_local_path_from_hdfs
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.workers.reward_manager import NaiveRewardManager
from verl.workers.rollout.hf_rollout import HFRollout
model_path = "checkpoints/TinyZero/countdown-qwen2.5-3b-grpo/actor/global_step_400"
# model_path = "Qwen/Qwen2.5-3B"
# model_path = "Qwen/Qwen2.5-3B-Instruct"
local_path = copy_local_path_from_hdfs(model_path)
trust_remote_code = True
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
torch_dtype = torch.bfloat16
actor_model_config = AutoConfig.from_pretrained(
local_path, trust_remote_code=trust_remote_code
)
actor_module = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=actor_model_config,
attn_implementation="flash_attention_2",
trust_remote_code=trust_remote_code,
)
# ckpt_new_version = "checkpoints/TinyZero/grpo-countdown-qwen2.5-3b-v2/global_step_100/actor/model_world_size_2_rank_0.pt"
# model_state = torch.load(ckpt_new_version, map_location="cpu")
# actor_module.load_state_dict(model_state)
actor_module.to(torch_dtype)
actor_module.to("cuda:0")
generation_config = get_generation_config(
local_path, trust_remote_code=trust_remote_code
)
val_reward_fn = NaiveRewardManager(
tokenizer=tokenizer, num_examine=1, compute_score=None
)
@hydra.main()
def main(config):
val_dataset = RLHFDataset(
parquet_files="data/countdown/test.parquet",
tokenizer=tokenizer,
prompt_key="prompt",
max_prompt_length=512,
filter_prompts=True,
return_raw_chat=False,
truncation="error",
)
val_dataloader = DataLoader(
dataset=val_dataset,
batch_size=3,
shuffle=True,
drop_last=True,
collate_fn=collate_fn,
)
assert len(val_dataloader) >= 1
sample_inputs = []
sample_outputs = []
sample_scores = []
reward_tensor_lst = []
data_source_lst = []
hfrollout = HFRollout(module=actor_module, config=config)
for data in val_dataloader:
test_batch = DataProto.from_single_dict(data)
test_batch = test_batch.to("cuda")
input_ids = test_batch.batch["input_ids"]
input_texts = [
tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids
]
sample_inputs.extend(input_texts)
test_gen_batch = test_batch.pop(["input_ids", "attention_mask", "position_ids"])
test_gen_batch.meta_info = {
"eos_token_id": tokenizer.eos_token_id,
"pad_token_id": tokenizer.pad_token_id,
"recompute_log_prob": False,
"do_sample": False,
"validate": True,
}
# pad to be divisible by dp_size
# test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, hfrollout.world_size)
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, 1)
test_output_gen_batch_padded = hfrollout.generate_sequences(
test_gen_batch_padded
)
# unpad
test_output_gen_batch = unpad_dataproto(
test_output_gen_batch_padded, pad_size=pad_size
)
print("validation generation end")
# Store generated outputs
output_ids = test_output_gen_batch.batch["responses"]
output_texts = [
tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids
]
sample_outputs.extend(output_texts)
test_batch = test_batch.union(test_output_gen_batch)
# evaluate using reward_function
reward_tensor = val_reward_fn(test_batch)
# Store scores
scores = reward_tensor.sum(-1).cpu().tolist()
sample_scores.extend(scores)
reward_tensor_lst.append(reward_tensor)
data_source_lst.append(
test_batch.non_tensor_batch.get(
"data_source", ["unknown"] * reward_tensor.shape[0]
)
)
reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu() # (batch_size,)
data_sources = np.concatenate(data_source_lst, axis=0)
# evaluate test_score based on data source
data_source_reward = {}
for i in range(reward_tensor.shape[0]):
data_source = data_sources[i]
if data_source not in data_source_reward:
data_source_reward[data_source] = []
data_source_reward[data_source].append(reward_tensor[i].item())
metric_dict = {}
for data_source, rewards in data_source_reward.items():
metric_dict[f"val/test_score/{data_source}"] = np.mean(rewards)
print(metric_dict)
if __name__ == "__main__":
main()
To summarize my questions: how should the sharded checkpoint produced by save_checkpoint be loaded for evaluation, and is the script above the intended way to evaluate a trained model on a specific dataset?
Your guidance would be appreciated! Thank you!
I have the same problem, and I solved it by using the following script to convert the FSDP checkpoint into a HuggingFace checkpoint. The evaluation results seem correct, but I don't know if this is an elegant solution.
#!/usr/bin/env python
# encoding: utf-8
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import torch
import fire
from glob import glob
from collections import defaultdict
def main(fsdp_checkpoint_path, huggingface_model_path, output_path):
state_dict = defaultdict(list)
world_size = 32
for rank in range(world_size):
filepath = f"{fsdp_checkpoint_path}/model_world_size_{world_size}_rank_{rank}.pt"
print('loading', filepath)
this_state_dict = torch.load(filepath)
for key, value in this_state_dict.items():
state_dict[key].append(value.to_local())
for key in state_dict:
state_dict[key] = torch.cat(state_dict[key], dim=0)
config = AutoConfig.from_pretrained(huggingface_model_path)
model = AutoModelForCausalLM.from_config(config)
model.load_state_dict(state_dict)
#for filepath in glob(f'{fsdp_checkpoint_path}/model_*.pt'):
# part_state_dict = torch.load(filepath)
# model.load_state_dict(part_state_dict)
model.save_pretrained(output_path, max_shard_size="10GB")
tokenizer = AutoTokenizer.from_pretrained(huggingface_model_path)
tokenizer.save_pretrained(output_path)
if __name__ == "__main__":
fire.Fire(main)
Example usage:
python convert_fsdp_to_hf.py checkpoints/global_step_20/actor/ checkpoints/global_step_20/actor/huggingface/ qwen2.5_7b_rl_step20
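Once converted, the output directory can be loaded like any other HuggingFace checkpoint. Below is a minimal sanity-check sketch (the path mirrors the example command above; the prompt and generation settings are illustrative, not taken from the training config):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Directory written by the conversion script (the output_path argument above).
converted_path = "qwen2.5_7b_rl_step20"

tokenizer = AutoTokenizer.from_pretrained(converted_path)
model = AutoModelForCausalLM.from_pretrained(converted_path, torch_dtype=torch.bfloat16).to("cuda")
model.eval()

# Quick sanity check: greedily decode one illustrative prompt.
prompt = "Using the numbers [2, 3, 7], create an equation that equals 17."
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))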
@Cppowboy Thank you for your suggestion. I've successfully implemented your solution and would like to share the improvements. The code has been updated to support both HuggingFace-style checkpoints and sharded FSDP (per-rank) checkpoints.
This enhancement provides more flexibility for users working with different checkpoint formats.
import hydra
import numpy as np
import re
import torch
import torch.distributed
from pathlib import Path
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoConfig
from collections import defaultdict
from verl import DataProto
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_local_path_from_hdfs
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.workers.reward_manager import NaiveRewardManager
from verl.workers.rollout.hf_rollout import HFRollout
def load_sharded_model(fsdp_checkpoint_path):
state_dict = defaultdict(list)
checkpoint_dir = Path(fsdp_checkpoint_path)
shard_files = list(checkpoint_dir.glob("model_world_size_*_rank_*.pt"))
if not shard_files:
raise ValueError(f"No checkpoint files found in {fsdp_checkpoint_path}")
pattern = re.compile(r"model_world_size_(\d+)_rank_(\d+)\.pt")
world_sizes = set()
for file in shard_files:
match = pattern.match(file.name)
if match:
world_sizes.add(int(match.group(1)))
if len(world_sizes) != 1:
raise ValueError(
f"Inconsistent world_size found in checkpoint files: {world_sizes}"
)
world_size = world_sizes.pop()
print(f"Found checkpoints with world_size = {world_size}")
for rank in range(world_size):
filepath = checkpoint_dir / f"model_world_size_{world_size}_rank_{rank}.pt"
if not filepath.exists():
raise ValueError(f"Missing shard file: {filepath}")
print(f"Loading shard: {filepath}")
shard_dict = torch.load(filepath)
for key, value in shard_dict.items():
if hasattr(value, "to_local"):
value = value.to_local()
state_dict[key].append(value)
consolidated_state_dict = {}
for key in state_dict:
try:
consolidated_state_dict[key] = torch.cat(state_dict[key], dim=0)
except (RuntimeError, TypeError):
consolidated_state_dict[key] = state_dict[key][0]
print(
f"Parameter '{key}' does not need concatenation, using first shard value"
)
return consolidated_state_dict
def initialize_model_and_tokenizer(
model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
):
local_path = copy_local_path_from_hdfs(model_path)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
actor_model_config = AutoConfig.from_pretrained(
local_path, trust_remote_code=trust_remote_code
)
actor_module = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=actor_model_config,
attn_implementation="flash_attention_2",
trust_remote_code=trust_remote_code,
)
return tokenizer, actor_module
@hydra.main()
def main(config):
# Loading huggingface-style checkpoint
model_path = "/data/projects/TinyZero/checkpoints/TinyZero/countdown-qwen2.5-3b-grpo/actor/global_step_400"
# model_path = "Qwen/Qwen2.5-3B"
# model_path = "Qwen/Qwen2.5-3B-Instruct"
tokenizer, actor_module = initialize_model_and_tokenizer(model_path)
# Loading FSDP checkpoint (optional: these three lines can be skipped. Prerequisite: actor_module must be preloaded)
fsdp_checkpoint_path = "/data/projects/TinyZero/checkpoints/TinyZero/grpo-countdown-qwen2.5-3b-v2/global_step_200/actor"
state_dict = load_sharded_model(fsdp_checkpoint_path)
actor_module.load_state_dict(state_dict)
actor_module.to(torch.bfloat16)
actor_module.to("cuda:0")
val_dataset = RLHFDataset(
parquet_files="/home/projects/TinyZero/data/countdown/test.parquet",
tokenizer=tokenizer,
prompt_key="prompt",
max_prompt_length=512,
filter_prompts=True,
return_raw_chat=False,
truncation="error",
)
val_dataloader = DataLoader(
dataset=val_dataset,
batch_size=3,
shuffle=True,
drop_last=True,
collate_fn=collate_fn,
)
assert len(val_dataloader) >= 1
val_reward_fn = NaiveRewardManager(
tokenizer=tokenizer, num_examine=1, compute_score=None
)
hfrollout = HFRollout(module=actor_module, config=config)
sample_inputs = []
sample_outputs = []
sample_scores = []
reward_tensor_lst = []
data_source_lst = []
for data in val_dataloader:
test_batch = DataProto.from_single_dict(data)
test_batch = test_batch.to("cuda")
input_ids = test_batch.batch["input_ids"]
input_texts = [
tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids
]
sample_inputs.extend(input_texts)
test_gen_batch = test_batch.pop(["input_ids", "attention_mask", "position_ids"])
test_gen_batch.meta_info = {
"eos_token_id": tokenizer.eos_token_id,
"pad_token_id": tokenizer.pad_token_id,
"recompute_log_prob": False,
"do_sample": False,
"validate": True,
}
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, 1)
test_output_gen_batch_padded = hfrollout.generate_sequences(
test_gen_batch_padded
)
test_output_gen_batch = unpad_dataproto(
test_output_gen_batch_padded, pad_size=pad_size
)
print("validation generation end")
output_ids = test_output_gen_batch.batch["responses"]
output_texts = [
tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids
]
sample_outputs.extend(output_texts)
test_batch = test_batch.union(test_output_gen_batch)
reward_tensor = val_reward_fn(test_batch)
scores = reward_tensor.sum(-1).cpu().tolist()
sample_scores.extend(scores)
reward_tensor_lst.append(reward_tensor)
data_source_lst.append(
test_batch.non_tensor_batch.get(
"data_source", ["unknown"] * reward_tensor.shape[0]
)
)
reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu()
data_sources = np.concatenate(data_source_lst, axis=0)
data_source_reward = {}
for i in range(reward_tensor.shape[0]):
data_source = data_sources[i]
if data_source not in data_source_reward:
data_source_reward[data_source] = []
data_source_reward[data_source].append(reward_tensor[i].item())
metric_dict = {}
for data_source, rewards in data_source_reward.items():
metric_dict[f"val/test_score/{data_source}"] = np.mean(rewards)
print(metric_dict)
if __name__ == "__main__":
main()
@Cppowboy @jankinf Thanks for your great work! @jankinf Would you like to make a PR to upload your evaluation script? Thanks in advance!
Great repo! Could you please provide some guidance on how to load a trained checkpoint and perform evaluation on a specific dataset? The documentation seems unclear about this workflow.