Remove dead TF loading code (#28926)
Remove dead code
Rocketknight1 authored Feb 8, 2024
1 parent 115ac94 commit 693667b
Showing 1 changed file with 0 additions and 50 deletions.
src/transformers/modeling_tf_utils.py
@@ -32,7 +32,6 @@
 import h5py
 import numpy as np
 import tensorflow as tf
-from huggingface_hub import Repository, list_repo_files
 from packaging.version import parse
 
 from . import DataCollatorWithPadding, DefaultDataCollator
@@ -1356,55 +1355,6 @@ def _save_checkpoint(self, checkpoint_dir, epoch):
         with open(extra_data_path, "wb") as f:
             pickle.dump(extra_data, f)
 
-    def load_repo_checkpoint(self, repo_path_or_name):
-        """
-        Loads a saved checkpoint (model weights and optimizer state) from a repo. Returns the current epoch count when
-        the checkpoint was made.
-
-        Args:
-            repo_path_or_name (`str`):
-                Can either be a repository name for your {object} in the Hub or a path to a local folder (in which case
-                the repository will have the name of that local folder).
-
-        Returns:
-            `dict`: A dictionary of extra metadata from the checkpoint, most commonly an "epoch" count.
-        """
-        if getattr(self, "optimizer", None) is None:
-            raise RuntimeError(
-                "Checkpoint loading failed as no optimizer is attached to the model. "
-                "This is most likely caused by the model not being compiled."
-            )
-        if os.path.isdir(repo_path_or_name):
-            local_dir = repo_path_or_name
-        else:
-            # If this isn't a local path, check that the remote repo exists and has a checkpoint in it
-            repo_files = list_repo_files(repo_path_or_name)
-            for file in ("checkpoint/weights.h5", "checkpoint/extra_data.pickle"):
-                if file not in repo_files:
-                    raise FileNotFoundError(f"Repo {repo_path_or_name} does not contain checkpoint file {file}!")
-            repo = Repository(repo_path_or_name.split("/")[-1], clone_from=repo_path_or_name)
-            local_dir = repo.local_dir
-
-        # Now make sure the repo actually has a checkpoint in it.
-        checkpoint_dir = os.path.join(local_dir, "checkpoint")
-        weights_file = os.path.join(checkpoint_dir, "weights.h5")
-        if not os.path.isfile(weights_file):
-            raise FileNotFoundError(f"Could not find checkpoint file weights.h5 in repo {repo_path_or_name}!")
-        extra_data_file = os.path.join(checkpoint_dir, "extra_data.pickle")
-        if not os.path.isfile(extra_data_file):
-            raise FileNotFoundError(f"Could not find checkpoint file extra_data.pickle in repo {repo_path_or_name}!")
-
-        # Assuming the repo is real and we got a checkpoint, load the weights and the optimizer state into the model.
-        # The optimizer state includes the iteration count, so learning rate schedules should resume as normal too.
-        self.load_weights(weights_file)
-        with open(extra_data_file, "rb") as f:
-            extra_data = pickle.load(f)
-        self.optimizer.set_weights(extra_data["optimizer_state"])
-
-        # Finally, return the epoch number from the checkpoint. This isn't a property of the model, so we can't
-        # set it directly, but the user can pass it to fit().
-        return {"epoch": extra_data["epoch"]}
-
     def prepare_tf_dataset(
         self,
         dataset: "datasets.Dataset",  # noqa:F821
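For anyone who was calling the removed load_repo_checkpoint, the same behavior is straightforward to reproduce in user code. The sketch below is an illustration, not part of this commit: the helper name is hypothetical, it assumes a Hub repo laid out the way _save_checkpoint writes it (a checkpoint/ folder containing weights.h5 and extra_data.pickle), it swaps the removed Repository clone for huggingface_hub.snapshot_download, and it relies on the optimizer exposing set_weights, just as the deleted code did.

import os
import pickle

from huggingface_hub import snapshot_download


def load_checkpoint_from_repo(model, repo_id):
    # Hypothetical replacement for the removed method. Fetch only the
    # checkpoint folder (or reuse the local cache) instead of cloning the
    # whole repo with the now-removed Repository class.
    local_dir = snapshot_download(repo_id, allow_patterns=["checkpoint/*"])
    checkpoint_dir = os.path.join(local_dir, "checkpoint")

    # The model must be compiled so an optimizer is attached before loading.
    if getattr(model, "optimizer", None) is None:
        raise RuntimeError("Compile the model before loading a checkpoint.")

    # Restore the weights, then the optimizer state. The optimizer state
    # includes the iteration count, so learning rate schedules resume too.
    model.load_weights(os.path.join(checkpoint_dir, "weights.h5"))
    with open(os.path.join(checkpoint_dir, "extra_data.pickle"), "rb") as f:
        extra_data = pickle.load(f)
    model.optimizer.set_weights(extra_data["optimizer_state"])

    return {"epoch": extra_data["epoch"]}

As the deleted comment noted, the epoch count is not a property of the model, so pass it to fit() yourself when resuming, e.g. model.fit(train_data, epochs=10, initial_epoch=load_checkpoint_from_repo(model, "my-user/my-model")["epoch"]).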
