How to load trained ckpt and eval with it on specific dataset? #298

Open
jankinf opened this issue Feb 18, 2025 · 5 comments
Labels: enhancement, good first issue

jankinf commented Feb 18, 2025

Great repo! Could you please provide some guidance on how to load a trained checkpoint and evaluate it on a specific dataset? The documentation seems unclear about this workflow.

jankinf (Author) commented Feb 19, 2025

Additional Information:
The save_checkpoint function in https://github.com/Jiayi-Pan/TinyZero/blob/main/verl/workers/fsdp_workers.py#L478 saves checkpoints in Hugging Face format (I'm not sure whether the verl code in that repository corresponds to a specific historical version of verl). Such checkpoints can therefore be loaded directly with AutoModelForCausalLM.from_pretrained, which is convenient and works well. The output directory structure is as follows:

.
└── actor
    └── global_step_100
        ├── added_tokens.json
        ├── config.json
        ├── generation_config.json
        ├── merges.txt
        ├── model-00001-of-00003.safetensors
        ├── model-00002-of-00003.safetensors
        ├── model-00003-of-00003.safetensors
        ├── model.safetensors.index.json
        ├── special_tokens_map.json
        ├── tokenizer_config.json
        ├── tokenizer.json
        └── vocab.json
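
Because this is a standard sharded safetensors layout, a quick sanity check before calling from_pretrained is to confirm that every shard listed in model.safetensors.index.json is actually present. The stdlib-only sketch below uses a hypothetical helper, missing_safetensors_shards, and assumes the usual index layout in which "weight_map" maps parameter names to shard file names.

```python
import json
from pathlib import Path


def missing_safetensors_shards(checkpoint_dir):
    """List shard files referenced by model.safetensors.index.json that are not on disk.

    Assumes the usual Hugging Face index layout:
    {"metadata": {...}, "weight_map": {"<param name>": "<shard file name>", ...}}
    """
    checkpoint_dir = Path(checkpoint_dir)
    index_path = checkpoint_dir / "model.safetensors.index.json"
    index = json.loads(index_path.read_text())
    # Many parameters map to the same shard file, so deduplicate before checking.
    referenced = set(index["weight_map"].values())
    return sorted(name for name in referenced if not (checkpoint_dir / name).is_file())
```

An empty list means every weight file the index refers to exists in the directory.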

However, the save_checkpoint used in this repository (https://github.com/volcengine/verl/blob/main/verl/utils/checkpoint/fsdp_checkpoint_manager.py#L134) writes one shard file per rank, which can be loaded with the provided load_checkpoint. If I first call AutoModelForCausalLM.from_pretrained and then load the state dict from one of the rank files (for example, model_world_size_2_rank_0.pt), the code throws an error. The output directory structure is as follows:

.
├── global_step_100
│   ├── actor
│   │   ├── extra_state_world_size_2_rank_0.pt
│   │   ├── extra_state_world_size_2_rank_1.pt
│   │   ├── huggingface
│   │   │   ├── added_tokens.json
│   │   │   ├── config.json
│   │   │   ├── merges.txt
│   │   │   ├── special_tokens_map.json
│   │   │   ├── tokenizer_config.json
│   │   │   ├── tokenizer.json
│   │   │   └── vocab.json
│   │   ├── model_world_size_2_rank_0.pt
│   │   ├── model_world_size_2_rank_1.pt
│   │   ├── optim_world_size_2_rank_0.pt
│   │   └── optim_world_size_2_rank_1.pt
│   └── data.pt
└── latest_checkpointed_iteration.txt
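
The two layouts can be told apart from file names alone: the Hugging Face directory has config.json next to *.safetensors weights, while the verl actor directory holds model_world_size_{W}_rank_{R}.pt shards and keeps config.json only under huggingface/, with no weights there. Below is a stdlib-only sketch with a hypothetical detect_checkpoint_format helper, assuming the file naming shown in the listing above.

```python
import re
from pathlib import Path

# verl's per-rank model shard names, e.g. model_world_size_2_rank_0.pt
MODEL_SHARD_RE = re.compile(r"model_world_size_(\d+)_rank_(\d+)\.pt")


def detect_checkpoint_format(checkpoint_dir):
    """Return ("hf", None) or ("fsdp", world_size) for a checkpoint directory."""
    checkpoint_dir = Path(checkpoint_dir)
    names = {p.name for p in checkpoint_dir.iterdir() if p.is_file()}

    world_sizes, ranks = set(), set()
    for name in names:
        match = MODEL_SHARD_RE.fullmatch(name)
        if match:  # optim_* and extra_state_* files do not match and are ignored
            world_sizes.add(int(match.group(1)))
            ranks.add(int(match.group(2)))

    if world_sizes:
        if len(world_sizes) != 1:
            raise ValueError(f"inconsistent world sizes: {sorted(world_sizes)}")
        world_size = world_sizes.pop()
        if ranks != set(range(world_size)):
            raise ValueError(f"expected ranks 0..{world_size - 1}, found {sorted(ranks)}")
        return "fsdp", world_size

    # Hugging Face layout: a config plus safetensors weights in the same directory
    if "config.json" in names and any(n.endswith(".safetensors") for n in names):
        return "hf", None
    raise ValueError(f"unrecognized checkpoint layout: {checkpoint_dir}")
```

A script could call this on the actor directory and choose between from_pretrained and merging the rank shards.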

The content of model_world_size_2_rank_0.pt looks like this:

model_state['model.embed_tokens.weight']
DTensor(local_tensor=tensor([[ 0.0395,  0.0141, -0.0154,  ..., -0.0335,  0.0153,  0.0266],
        [ 0.0109,  0.0142, -0.0122,  ..., -0.0061, -0.0134,  0.0432],
        [-0.0272, -0.0251,  0.0173,  ..., -0.0285,  0.0032,  0.0245],
        ...,
        [-0.0771, -0.0359, -0.0449,  ...,  0.0003, -0.0025,  0.0280],
        [ 0.0208, -0.0186, -0.0026,  ..., -0.0045, -0.0166, -0.0067],
        [-0.0476,  0.0198, -0.0198,  ...,  0.0417, -0.0193,  0.0383]]), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('fsdp',)), placements=(Shard(dim=0),))
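
The placements=(Shard(dim=0),) on a two-device mesh means each rank stores a contiguous block of rows of the full matrix, so the full tensor is recovered by concatenating the local blocks in rank order along dim 0. The stdlib sketch below simulates this with Python lists standing in for tensor rows. shard_rows and unshard_rows are hypothetical helpers, and the ceil-division split assumes torch.chunk-style sharding, where the last rank may receive fewer rows.

```python
import math


def shard_rows(rows, world_size):
    """Split rows into contiguous per-rank blocks, torch.chunk style (ceil division)."""
    chunk = math.ceil(len(rows) / world_size)
    return [rows[rank * chunk:(rank + 1) * chunk] for rank in range(world_size)]


def unshard_rows(local_shards):
    """Rebuild the full row list by concatenating local shards in rank order."""
    full = []
    for shard in local_shards:  # order matters: rank 0 first
        full.extend(shard)
    return full


rows = [[r, r * 10] for r in range(5)]  # a toy 5x2 "weight"
shards = shard_rows(rows, world_size=2)  # rank 0 gets 3 rows, rank 1 gets 2
print(unshard_rows(shards) == rows)
```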

The values are DTensors, and loading them results in an error:

RuntimeError: Error(s) in loading state_dict for Qwen2ForCausalLM:
        While copying the parameter named "model.embed_tokens.weight", whose dimensions in the model are torch.Size([151936, 2048]) and whose dimensions in the checkpoint are torch.Size([151936, 2048]), an exception occurred : ('aten.copy_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!',).
        While copying the parameter named "model.layers.0.self_attn.q_proj.weight", whose dimensions in the model are torch.Size([2048, 2048]) and whose dimensions in the checkpoint are torch.Size([2048, 2048]), an exception occurred : ('aten.copy_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!',).
        While copying the parameter named "model.layers.0.self_attn.q_proj.bias", whose dimensions in the model are torch.Size([2048]) and whose dimensions in the checkpoint are torch.Size([2048]), an exception occurred : ('aten.copy_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!',). 
...
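
The traceback reports the same size, torch.Size([151936, 2048]), on both sides because a DTensor reports its global shape. The rank-0 file, however, only holds that rank's local rows. So even after calling .to_local() on every value, a single rank file cannot populate the whole model, and all world_size files have to be merged. The arithmetic is sketched below with a hypothetical local_shard_shape helper, again assuming torch.chunk-style ceil division along dim 0.

```python
import math


def local_shard_shape(global_shape, world_size, rank):
    """Shape of the block that `rank` holds when dim 0 is sharded across world_size ranks."""
    rows, *rest = global_shape
    chunk = math.ceil(rows / world_size)
    start = min(rank * chunk, rows)
    stop = min(start + chunk, rows)
    return (stop - start, *rest)


# Embedding shape from the traceback above, saved with world_size=2
print(local_shard_shape((151936, 2048), world_size=2, rank=0))  # → (75968, 2048)
```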

Here is my code so the results can be easily reproduced:

import hydra
import numpy as np

import torch
import torch.distributed
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoConfig

from verl import DataProto
from verl.utils import hf_tokenizer
from verl.utils.model import get_generation_config
from verl.utils.fs import copy_local_path_from_hdfs
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.workers.reward_manager import NaiveRewardManager
from verl.workers.rollout.hf_rollout import HFRollout

model_path = "checkpoints/TinyZero/countdown-qwen2.5-3b-grpo/actor/global_step_400"
# model_path = "Qwen/Qwen2.5-3B"
# model_path = "Qwen/Qwen2.5-3B-Instruct"
local_path = copy_local_path_from_hdfs(model_path)

trust_remote_code = True
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)

torch_dtype = torch.bfloat16

actor_model_config = AutoConfig.from_pretrained(
    local_path, trust_remote_code=trust_remote_code
)

actor_module = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=local_path,
    torch_dtype=torch_dtype,
    config=actor_model_config,
    attn_implementation="flash_attention_2",
    trust_remote_code=trust_remote_code,
)

# ckpt_new_version = "checkpoints/TinyZero/grpo-countdown-qwen2.5-3b-v2/global_step_100/actor/model_world_size_2_rank_0.pt"
# model_state = torch.load(ckpt_new_version, map_location="cpu")
# actor_module.load_state_dict(model_state)

actor_module.to(torch_dtype)
actor_module.to("cuda:0")

generation_config = get_generation_config(
    local_path, trust_remote_code=trust_remote_code
)

val_reward_fn = NaiveRewardManager(
    tokenizer=tokenizer, num_examine=1, compute_score=None
)


@hydra.main()
def main(config):

    val_dataset = RLHFDataset(
        parquet_files="data/countdown/test.parquet",
        tokenizer=tokenizer,
        prompt_key="prompt",
        max_prompt_length=512,
        filter_prompts=True,
        return_raw_chat=False,
        truncation="error",
    )
    val_dataloader = DataLoader(
        dataset=val_dataset,
        batch_size=3,
        shuffle=True,
        drop_last=True,
        collate_fn=collate_fn,
    )

    assert len(val_dataloader) >= 1
    sample_inputs = []
    sample_outputs = []
    sample_scores = []
    reward_tensor_lst = []
    data_source_lst = []

    hfrollout = HFRollout(module=actor_module, config=config)
    for data in val_dataloader:
        test_batch = DataProto.from_single_dict(data)
        test_batch = test_batch.to("cuda")
        input_ids = test_batch.batch["input_ids"]
        input_texts = [
            tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids
        ]
        sample_inputs.extend(input_texts)

        test_gen_batch = test_batch.pop(["input_ids", "attention_mask", "position_ids"])
        test_gen_batch.meta_info = {
            "eos_token_id": tokenizer.eos_token_id,
            "pad_token_id": tokenizer.pad_token_id,
            "recompute_log_prob": False,
            "do_sample": False,
            "validate": True,
        }

        # pad to be divisible by dp_size
        # test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, hfrollout.world_size)
        test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, 1)
        test_output_gen_batch_padded = hfrollout.generate_sequences(
            test_gen_batch_padded
        )
        # unpad
        test_output_gen_batch = unpad_dataproto(
            test_output_gen_batch_padded, pad_size=pad_size
        )
        print("validation generation end")

        # Store generated outputs
        output_ids = test_output_gen_batch.batch["responses"]
        output_texts = [
            tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids
        ]
        sample_outputs.extend(output_texts)
        test_batch = test_batch.union(test_output_gen_batch)

        # evaluate using reward_function
        reward_tensor = val_reward_fn(test_batch)
        # Store scores
        scores = reward_tensor.sum(-1).cpu().tolist()
        sample_scores.extend(scores)
        reward_tensor_lst.append(reward_tensor)
        data_source_lst.append(
            test_batch.non_tensor_batch.get(
                "data_source", ["unknown"] * reward_tensor.shape[0]
            )
        )

    reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu()  # (batch_size,)
    data_sources = np.concatenate(data_source_lst, axis=0)

    # evaluate test_score based on data source
    data_source_reward = {}
    for i in range(reward_tensor.shape[0]):
        data_source = data_sources[i]
        if data_source not in data_source_reward:
            data_source_reward[data_source] = []
        data_source_reward[data_source].append(reward_tensor[i].item())

    metric_dict = {}
    for data_source, rewards in data_source_reward.items():
        metric_dict[f"val/test_score/{data_source}"] = np.mean(rewards)

    print(metric_dict)


if __name__ == "__main__":
    main()
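
Regarding question 2, the following is a small readability change rather than a speedup. The per-data-source averaging at the end of the script can use collections.defaultdict instead of the manual membership check. This stdlib sketch works on plain lists, and the function name mean_reward_by_source is hypothetical.

```python
from collections import defaultdict
from statistics import fmean


def mean_reward_by_source(scores, sources):
    """Average per-sample scores per data source, using the script's metric key format."""
    if len(scores) != len(sources):
        raise ValueError("scores and sources must have the same length")
    grouped = defaultdict(list)
    for score, source in zip(scores, sources):
        grouped[source].append(score)
    return {f"val/test_score/{source}": fmean(values) for source, values in grouped.items()}
```

In the script, scores would be reward_tensor.tolist() and sources would be data_sources.tolist().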

The code above loads a checkpoint saved by TinyZero's version of verl and uses it to evaluate on the Countdown dataset, but I don't know how to load checkpoints saved by the latest version of verl (the commented-out part of the code).

To summarize my questions:

  1. How can I load models saved by the latest verl code (as in my code above) and run evaluation?
  2. Is there room to optimize or speed up the evaluation code above?

Your guidance would be appreciated! Thank you!

Cppowboy (Contributor) commented Feb 19, 2025

I had the same problem and solved it with the following script, which converts the FSDP checkpoint into a Hugging Face checkpoint. The evaluation results seem correct, but I don't know whether this is an elegant solution.

#!/usr/bin/env python
# encoding: utf-8
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import torch
import fire
from glob import glob
from collections import defaultdict


def main(fsdp_checkpoint_path, huggingface_model_path, output_path):
    state_dict = defaultdict(list)

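    world_size = 32  # must match the world_size in the shard file names
    for rank in range(world_size):
        filepath = f"{fsdp_checkpoint_path}/model_world_size_{world_size}_rank_{rank}.pt"
        print('loading', filepath)
        # Load on CPU; weights_only=False because the shards pickle DTensor objects
        this_state_dict = torch.load(filepath, map_location="cpu", weights_only=False)
        for key, value in this_state_dict.items():
            # Each value is a DTensor sharded along dim 0; keep this rank's local block
            state_dict[key].append(value.to_local())

    # Concatenate the local blocks in rank order to rebuild the full tensors
    for key in state_dict:
        state_dict[key] = torch.cat(state_dict[key], dim=0)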

    config = AutoConfig.from_pretrained(huggingface_model_path)
    model = AutoModelForCausalLM.from_config(config)
    model.load_state_dict(state_dict)

    #for filepath in glob(f'{fsdp_checkpoint_path}/model_*.pt'):
    #    part_state_dict = torch.load(filepath)
    #    model.load_state_dict(part_state_dict)

    model.save_pretrained(output_path, max_shard_size="10GB")

    tokenizer = AutoTokenizer.from_pretrained(huggingface_model_path)
    tokenizer.save_pretrained(output_path)


if __name__ == "__main__":
    fire.Fire(main)

Usage:

python convert_fsdp_to_hf.py checkpoints/global_step_20/actor/ checkpoints/global_step_20/actor/huggingface/ qwen2.5_7b_rl_step20
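
The script iterates over range(world_size), which guarantees that shards are concatenated in rank order. If the files were instead collected with glob and sorted by name, a plain lexicographic sort would put rank_10 before rank_2 and scramble the rows. Here is a minimal stdlib sketch of numeric ordering, using a hypothetical ordered_shard_files helper.

```python
import re

SHARD_RE = re.compile(r"model_world_size_(\d+)_rank_(\d+)\.pt")


def ordered_shard_files(filenames):
    """Return verl model shard names sorted by numeric rank, checking completeness."""
    ranked = {}
    world_sizes = set()
    for name in filenames:
        match = SHARD_RE.fullmatch(name)
        if match:  # ignore optim_* and extra_state_* files
            world_sizes.add(int(match.group(1)))
            ranked[int(match.group(2))] = name
    if len(world_sizes) != 1:
        raise ValueError(f"expected one world_size, found {sorted(world_sizes)}")
    world_size = world_sizes.pop()
    if sorted(ranked) != list(range(world_size)):
        raise ValueError(f"missing ranks for world_size={world_size}")
    return [ranked[r] for r in range(world_size)]
```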

jankinf (Author) commented Feb 20, 2025

@Cppowboy Thank you for your suggestion. I've successfully implemented your solution and would like to share the improvements:

The code has been updated to support both:

  • Hugging Face-style checkpoints
  • FSDP checkpoints

This enhancement provides more flexibility for users working with different checkpoint formats.

import hydra
import numpy as np
import re
import torch
import torch.distributed
from pathlib import Path
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoConfig
from collections import defaultdict

from verl import DataProto
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_local_path_from_hdfs
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.workers.reward_manager import NaiveRewardManager
from verl.workers.rollout.hf_rollout import HFRollout


def load_sharded_model(fsdp_checkpoint_path):
    state_dict = defaultdict(list)
    checkpoint_dir = Path(fsdp_checkpoint_path)

    shard_files = list(checkpoint_dir.glob("model_world_size_*_rank_*.pt"))
    if not shard_files:
        raise ValueError(f"No checkpoint files found in {fsdp_checkpoint_path}")

    pattern = re.compile(r"model_world_size_(\d+)_rank_(\d+)\.pt")
    world_sizes = set()
    for file in shard_files:
        match = pattern.match(file.name)
        if match:
            world_sizes.add(int(match.group(1)))

    if len(world_sizes) != 1:
        raise ValueError(
            f"Inconsistent world_size found in checkpoint files: {world_sizes}"
        )

    world_size = world_sizes.pop()
    print(f"Found checkpoints with world_size = {world_size}")

    for rank in range(world_size):
        filepath = checkpoint_dir / f"model_world_size_{world_size}_rank_{rank}.pt"
        if not filepath.exists():
            raise ValueError(f"Missing shard file: {filepath}")

        print(f"Loading shard: {filepath}")
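        # Load on CPU; weights_only=False because the shards pickle DTensor objects
        shard_dict = torch.load(filepath, map_location="cpu", weights_only=False)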

        for key, value in shard_dict.items():
            if hasattr(value, "to_local"):
                value = value.to_local()
            state_dict[key].append(value)

    consolidated_state_dict = {}
    for key in state_dict:
        try:
            consolidated_state_dict[key] = torch.cat(state_dict[key], dim=0)
        except (RuntimeError, TypeError):
            consolidated_state_dict[key] = state_dict[key][0]
            print(
                f"Parameter '{key}' does not need concatenation, using first shard value"
            )

    return consolidated_state_dict


def initialize_model_and_tokenizer(
    model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
):
    local_path = copy_local_path_from_hdfs(model_path)
    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)

    actor_model_config = AutoConfig.from_pretrained(
        local_path, trust_remote_code=trust_remote_code
    )
    actor_module = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=local_path,
        torch_dtype=torch_dtype,
        config=actor_model_config,
        attn_implementation="flash_attention_2",
        trust_remote_code=trust_remote_code,
    )

    return tokenizer, actor_module


@hydra.main()
def main(config):
    # Loading huggingface-style checkpoint
    model_path = "/data/projects/TinyZero/checkpoints/TinyZero/countdown-qwen2.5-3b-grpo/actor/global_step_400"
    # model_path = "Qwen/Qwen2.5-3B"
    # model_path = "Qwen/Qwen2.5-3B-Instruct"

    tokenizer, actor_module = initialize_model_and_tokenizer(model_path)

    # Loading FSDP checkpoint (optional: these three lines can be skipped. Prerequisite: actor_module must be preloaded)
    fsdp_checkpoint_path = "/data/projects/TinyZero/checkpoints/TinyZero/grpo-countdown-qwen2.5-3b-v2/global_step_200/actor"
    state_dict = load_sharded_model(fsdp_checkpoint_path)
    actor_module.load_state_dict(state_dict)

    actor_module.to(torch.bfloat16)
    actor_module.to("cuda:0")

    val_dataset = RLHFDataset(
        parquet_files="/home/projects/TinyZero/data/countdown/test.parquet",
        tokenizer=tokenizer,
        prompt_key="prompt",
        max_prompt_length=512,
        filter_prompts=True,
        return_raw_chat=False,
        truncation="error",
    )
    val_dataloader = DataLoader(
        dataset=val_dataset,
        batch_size=3,
        shuffle=True,
        drop_last=True,
        collate_fn=collate_fn,
    )

    assert len(val_dataloader) >= 1

    val_reward_fn = NaiveRewardManager(
        tokenizer=tokenizer, num_examine=1, compute_score=None
    )

    hfrollout = HFRollout(module=actor_module, config=config)

    sample_inputs = []
    sample_outputs = []
    sample_scores = []
    reward_tensor_lst = []
    data_source_lst = []

    for data in val_dataloader:
        test_batch = DataProto.from_single_dict(data)
        test_batch = test_batch.to("cuda")
        input_ids = test_batch.batch["input_ids"]
        input_texts = [
            tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids
        ]
        sample_inputs.extend(input_texts)

        test_gen_batch = test_batch.pop(["input_ids", "attention_mask", "position_ids"])
        test_gen_batch.meta_info = {
            "eos_token_id": tokenizer.eos_token_id,
            "pad_token_id": tokenizer.pad_token_id,
            "recompute_log_prob": False,
            "do_sample": False,
            "validate": True,
        }

        test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, 1)
        test_output_gen_batch_padded = hfrollout.generate_sequences(
            test_gen_batch_padded
        )
        test_output_gen_batch = unpad_dataproto(
            test_output_gen_batch_padded, pad_size=pad_size
        )
        print("validation generation end")

        output_ids = test_output_gen_batch.batch["responses"]
        output_texts = [
            tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids
        ]
        sample_outputs.extend(output_texts)
        test_batch = test_batch.union(test_output_gen_batch)

        reward_tensor = val_reward_fn(test_batch)
        scores = reward_tensor.sum(-1).cpu().tolist()
        sample_scores.extend(scores)
        reward_tensor_lst.append(reward_tensor)
        data_source_lst.append(
            test_batch.non_tensor_batch.get(
                "data_source", ["unknown"] * reward_tensor.shape[0]
            )
        )

    reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu()
    data_sources = np.concatenate(data_source_lst, axis=0)

    data_source_reward = {}
    for i in range(reward_tensor.shape[0]):
        data_source = data_sources[i]
        if data_source not in data_source_reward:
            data_source_reward[data_source] = []
        data_source_reward[data_source].append(reward_tensor[i].item())

    metric_dict = {}
    for data_source, rewards in data_source_reward.items():
        metric_dict[f"val/test_score/{data_source}"] = np.mean(rewards)

    print(metric_dict)


if __name__ == "__main__":
    main()
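
A caveat about the fallback in load_sharded_model: torch.cat succeeds whenever the per-rank pieces have compatible shapes. A parameter that every rank held in full (a replicated placement) would therefore be silently concatenated world_size times instead of falling back to the first shard. verl's FSDP shards appear to be split along dim 0, as in the Shard(dim=0) output earlier in this issue, so this may not matter in practice. Recording each DTensor's placement (value.placements) before calling to_local() would make the choice explicit. The sketch below models that decision with plain lists. merge_shards and the "shard0"/"replicate" strings are hypothetical stand-ins for DTensor placements.

```python
def merge_shards(local_shards, placement):
    """Combine per-rank pieces of one parameter according to how it was distributed.

    local_shards: per-rank row lists, in rank order.
    placement: "shard0" (split along dim 0) or "replicate" (every rank holds it all).
    """
    if placement == "shard0":
        merged = []
        for shard in local_shards:
            merged.extend(shard)
        return merged
    if placement == "replicate":
        first = local_shards[0]
        if any(shard != first for shard in local_shards[1:]):
            raise ValueError("replicated parameter differs across ranks")
        return first
    raise ValueError(f"unsupported placement: {placement!r}")
```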

PeterSH6 (Collaborator) commented

@Cppowboy @jankinf Thanks for your great work!

@jankinf Would you like to open a PR adding your evaluation script to the verl/scripts directory? It would be even better if the script let users specify the checkpoint path and dataset via arguments.
@Cppowboy Would you like to open a PR adding your checkpoint converter to the verl/scripts directory? Taking its inputs as arguments would also make it more user-friendly.

Thanks in advance!
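
Here is a rough sketch of the argument handling suggested above, applied to the evaluation script. The flag names are hypothetical and not an existing verl interface. Because that script is decorated with @hydra.main, which parses the command line itself, this argparse version assumes the Hydra entry point is replaced. Hydra config overrides would be the alternative.

```python
import argparse


def build_parser():
    """Hypothetical CLI for the evaluation script; defaults mirror its hard-coded values."""
    parser = argparse.ArgumentParser(description="Evaluate a checkpoint on a parquet dataset")
    parser.add_argument("--model_path", required=True,
                        help="Hugging Face-format model directory or hub id")
    parser.add_argument("--fsdp_checkpoint_path", default=None,
                        help="optional directory with model_world_size_*_rank_*.pt shards")
    parser.add_argument("--data_path", required=True, help="parquet file to evaluate on")
    parser.add_argument("--batch_size", type=int, default=3)
    return parser
```

The script's entry point would then call build_parser().parse_args() and pass the values where the paths are currently hard-coded.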

PeterSH6 added the enhancement and good first issue labels on Feb 21, 2025.

jankinf (Author) commented Feb 24, 2025

As requested, I've created PR #359 to address this issue. Looking forward to your feedback. @PeterSH6
