
Fix trainer saving safetensors: metadata is None #28219

Merged 2 commits into huggingface:main on Jan 2, 2024

Conversation

hiyouga
Contributor

@hiyouga hiyouga commented Dec 23, 2023

What does this PR do?

Fixes hiyouga/LLaMA-Factory#1959

If we use Trainer to train a model that is not a subclass of PreTrainedModel, such as PreTrainedModelWithValueHead from the TRL library, the trainer does not write the safetensors metadata. This causes an error when reading the metadata while loading the model with AutoModelForCausalLM.from_pretrained.

The trainer currently saves without any metadata:

safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))

while from_pretrained expects a "format" key in the metadata:

with safe_open(resolved_archive_file, framework="pt") as f:
    metadata = f.metadata()
    if metadata.get("format") == "pt":
        pass

Although it may sound strange to load a model that does not belong to the PreTrainedModel class using AutoModelForCausalLM.from_pretrained, this approach benefits model loading by utilizing features such as low_cpu_mem_usage if the model checkpoints share the same structure.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr @pacman100

Contributor

@pacman100 pacman100 left a comment


Thank you @hiyouga for these changes, which are in line with what save_pretrained does:

safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})

LGTM!

Collaborator

@amyeroberts amyeroberts left a comment


Thanks for fixing!

@amyeroberts amyeroberts merged commit 502a10a into huggingface:main Jan 2, 2024
21 checks passed
Saibo-creator pushed a commit to epfl-dlab/transformers-GCD-PR that referenced this pull request Jan 4, 2024

Successfully merging this pull request may close these issues.

Prediction error after full-parameter training of a Reward Model with Baichuan2-13B
3 participants