Skip to content

Commit

Permalink
Accept token in trainer.push_to_hub() (#30093)
Browse files Browse the repository at this point in the history
* pass token to trainer.push_to_hub

* fmt

* Update src/transformers/trainer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* pass token to create_repo, update_folder

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
mapmeld and amyeroberts authored Apr 8, 2024
1 parent 0201f64 commit 08c8443
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3909,7 +3909,7 @@ def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
else:
return 0

def init_hf_repo(self):
def init_hf_repo(self, token: Optional[str] = None):
"""
Initializes a git repo in `self.args.hub_model_id`.
"""
Expand All @@ -3922,7 +3922,8 @@ def init_hf_repo(self):
else:
repo_name = self.args.hub_model_id

repo_url = create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True)
token = token if token is not None else self.args.hub_token
repo_url = create_repo(repo_name, token=token, private=self.args.hub_private_repo, exist_ok=True)
self.hub_model_id = repo_url.repo_id
self.push_in_progress = None

Expand Down Expand Up @@ -4067,7 +4068,13 @@ def _finish_current_push(self):
logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.")
self.push_in_progress.wait_until_done()

def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
def push_to_hub(
self,
commit_message: Optional[str] = "End of training",
blocking: bool = True,
token: Optional[str] = None,
**kwargs,
) -> str:
"""
Upload `self.model` and `self.tokenizer` or `self.image_processor` to the 🤗 model hub on the repo `self.args.hub_model_id`.
Expand All @@ -4076,6 +4083,8 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
Message to commit while pushing.
blocking (`bool`, *optional*, defaults to `True`):
Whether the function should return only when the `git push` has finished.
token (`str`, *optional*, defaults to `None`):
Token with write permission to overwrite Trainer's original args.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to [`~Trainer.create_model_card`].
Expand All @@ -4089,10 +4098,11 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
model_name = Path(self.args.output_dir).name
else:
model_name = self.args.hub_model_id.split("/")[-1]
token = token if token is not None else self.args.hub_token

# In case the user calls this method with args.push_to_hub = False
if self.hub_model_id is None:
self.init_hf_repo()
self.init_hf_repo(token=token)

# Needs to be executed on all processes for TPU training, but will only save on the processed determined by
# self.args.should_save.
Expand Down Expand Up @@ -4125,7 +4135,7 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
repo_id=self.hub_model_id,
folder_path=self.args.output_dir,
commit_message=commit_message,
token=self.args.hub_token,
token=token,
run_as_future=not blocking,
ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
)
Expand Down

0 comments on commit 08c8443

Please sign in to comment.