Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add revision to trainer push_to_hub #33482

Merged
merged 7 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4461,6 +4461,7 @@ def push_to_hub(
commit_message: Optional[str] = "End of training",
blocking: bool = True,
token: Optional[str] = None,
revision: Optional[str] = None,
**kwargs,
) -> str:
"""
Expand All @@ -4473,6 +4474,8 @@ def push_to_hub(
Whether the function should return only when the `git push` has finished.
token (`str`, *optional*, defaults to `None`):
Token with write permission to overwrite Trainer's original args.
revision (`str`, *optional*):
The git revision to commit from. Defaults to the head of the "main" branch.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it the branch to commit to - rather than the commit to commit from?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this description comes from the hub upload_folder api. I was only going to use it for branches and tested with that, so not sure if it works for other commit hashes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK!

kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to [`~Trainer.create_model_card`].

Expand Down Expand Up @@ -4526,6 +4529,7 @@ def push_to_hub(
token=token,
run_as_future=not blocking,
ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
revision=revision,
)

#
Expand Down
21 changes: 20 additions & 1 deletion tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from unittest.mock import Mock, patch

import numpy as np
from huggingface_hub import HfFolder, ModelCard, delete_repo, list_repo_commits, list_repo_files
from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files
from parameterized import parameterized
from requests.exceptions import HTTPError

Expand Down Expand Up @@ -3933,6 +3933,25 @@ def test_push_to_hub_tags(self):
model_card = ModelCard.load(repo_name)
self.assertTrue("test-trainer-tags" in model_card.data.tags)

def test_push_to_hub_with_revision(self):
# Checks if `trainer.push_to_hub()` works correctly by adding revision
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer-revision"),
push_to_hub=True,
hub_token=self._token,
)
branch = "v1.0"
create_branch(repo_id=trainer.hub_model_id, branch=branch, token=self._token, exist_ok=True)
url = trainer.push_to_hub(revision=branch)

# Extract branch from the url
re_search = re.search(r"tree/([^/]+)/", url)
self.assertIsNotNone(re_search)

branch_name = re_search.groups()[0]
self.assertEqual(branch_name, branch)


@require_torch
@require_optuna
Expand Down
Loading