diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 97a052093652..e0a49ee5795e 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -4461,6 +4461,7 @@ def push_to_hub( commit_message: Optional[str] = "End of training", blocking: bool = True, token: Optional[str] = None, + revision: Optional[str] = None, **kwargs, ) -> str: """ @@ -4473,6 +4474,8 @@ def push_to_hub( Whether the function should return only when the `git push` has finished. token (`str`, *optional*, defaults to `None`): Token with write permission to overwrite Trainer's original args. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the "main" branch. kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments passed along to [`~Trainer.create_model_card`]. @@ -4526,6 +4529,7 @@ def push_to_hub( token=token, run_as_future=not blocking, ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"], + revision=revision, ) # diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 791486ec8374..c5f8b6169fcf 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -31,7 +31,7 @@ from unittest.mock import Mock, patch import numpy as np -from huggingface_hub import HfFolder, ModelCard, delete_repo, list_repo_commits, list_repo_files +from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files from parameterized import parameterized from requests.exceptions import HTTPError @@ -3933,6 +3933,25 @@ def test_push_to_hub_tags(self): model_card = ModelCard.load(repo_name) self.assertTrue("test-trainer-tags" in model_card.data.tags) + def test_push_to_hub_with_revision(self): + # Checks if `trainer.push_to_hub()` works correctly by adding revision + with tempfile.TemporaryDirectory() as tmp_dir: + trainer = get_regression_trainer( + output_dir=os.path.join(tmp_dir, "test-trainer-revision"), + push_to_hub=True, + hub_token=self._token, + ) + branch = "v1.0" + create_branch(repo_id=trainer.hub_model_id, branch=branch, token=self._token, exist_ok=True) + url = trainer.push_to_hub(revision=branch) + + # Extract branch from the url + re_search = re.search(r"tree/([^/]+)/", url) + self.assertIsNotNone(re_search) + + branch_name = re_search.groups()[0] + self.assertEqual(branch_name, branch) + @require_torch @require_optuna