Fix weights_only #28725

Merged · 1 commit · Jan 26, 2024
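Context on the fix (a gloss, not the PR's own description): `torch.load` only gained the `weights_only` parameter in PyTorch 1.13, so the previous pattern of always passing `weights_only=is_torch_greater_or_equal_than_1_13` raised `TypeError: load() got an unexpected keyword argument 'weights_only'` on older versions, even though the value would have been `False` there. Every file below applies the same fix: build the keyword argument conditionally and splat it into the call. A minimal sketch of the pattern, where the version check only approximates the `is_torch_greater_or_equal_than_1_13` flag `transformers` defines in `pytorch_utils.py` and the checkpoint path is illustrative:

```python
# Minimal sketch of the pattern applied throughout this PR, not the exact
# transformers code. The version check approximates transformers'
# is_torch_greater_or_equal_than_1_13 flag; "checkpoint.bin" is illustrative.
import torch
from packaging import version

is_torch_greater_or_equal_than_1_13 = version.parse(torch.__version__).release >= (1, 13)

# Before: the keyword was always passed, raising
#   TypeError: load() got an unexpected keyword argument 'weights_only'
# on torch < 1.13, even with a False value.
# After: the keyword is only present where torch.load supports it.
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
state_dict = torch.load("checkpoint.bin", map_location="cpu", **weights_only_kwarg)
```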
3 changes: 2 additions & 1 deletion src/transformers/convert_pytorch_checkpoint_to_tf2.py
@@ -330,10 +330,11 @@ def convert_pt_checkpoint_to_tf(
     if compare_with_pt_model:
         tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network

+        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
         state_dict = torch.load(
             pytorch_checkpoint_path,
             map_location="cpu",
-            weights_only=is_torch_greater_or_equal_than_1_13,
+            **weights_only_kwarg,
         )
         pt_model = pt_model_class.from_pretrained(
             pretrained_model_name_or_path=None, config=config, state_dict=state_dict
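As an aside on why the keyword is worth threading through at all (a gloss, not from the PR): with `weights_only=True`, `torch.load` restricts unpickling to tensors and a small allowlist of types, so a checkpoint smuggling an arbitrary pickled object is rejected instead of executed. A hedged demonstration, assuming torch >= 1.13 and using throwaway file names:

```python
# Hedged demonstration of what weights_only=True protects against;
# file names are throwaway and torch >= 1.13 is assumed.
import torch

torch.save({"w": torch.zeros(2)}, "ok.pt")
print(torch.load("ok.pt", weights_only=True))  # tensors only: loads fine

torch.save({"fn": print}, "bad.pt")  # a pickled callable, not a tensor
try:
    torch.load("bad.pt", weights_only=True)
except Exception as err:  # raises rather than unpickling the callable
    print("rejected:", err)
```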
6 changes: 4 additions & 2 deletions src/transformers/modeling_flax_pytorch_utils.py
@@ -74,7 +74,8 @@ def load_pytorch_checkpoint_in_flax_state_dict(
         )
         raise

-    pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
+    weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+    pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
     logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

     flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
@@ -252,7 +253,8 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
     flax_state_dict = {}
     for shard_file in shard_filenames:
         # load using msgpack utils
-        pt_state_dict = torch.load(shard_file, weights_only=is_torch_greater_or_equal_than_1_13)
+        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+        pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
         pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

         model_prefix = flax_model.base_model_prefix
3 changes: 2 additions & 1 deletion src/transformers/modeling_tf_pytorch_utils.py
@@ -188,7 +188,8 @@ def load_pytorch_checkpoint_in_tf2_model(
         if pt_path.endswith(".safetensors"):
             state_dict = safe_load_file(pt_path)
         else:
-            state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
+            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+            state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)

         pt_state_dict.update(state_dict)

10 changes: 4 additions & 6 deletions src/transformers/modeling_utils.py
@@ -482,11 +482,8 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
         error_message += f"\nMissing key(s): {str_unexpected_keys}."
         raise RuntimeError(error_message)

-    loader = (
-        safe_load_file
-        if load_safe
-        else partial(torch.load, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
-    )
+    weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+    loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)

     for shard_file in shard_files:
         state_dict = loader(os.path.join(folder, shard_file))
@@ -530,10 +527,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             and is_zipfile(checkpoint_file)
         ):
             extra_args = {"mmap": True}
+        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
         return torch.load(
             checkpoint_file,
             map_location=map_location,
-            weights_only=is_torch_greater_or_equal_than_1_13,
+            **weights_only_kwarg,
             **extra_args,
         )
     except Exception as e:
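Two details in this file are easy to miss: `functools.partial` freezes `torch.load`'s keyword arguments so that `safe_load_file` and `torch.load` can be called interchangeably through a one-argument `loader`, and in `load_state_dict` the optional `mmap` and `weights_only` kwargs are splatted independently so each is passed only when applicable. A sketch of the same loader shape under stated assumptions (the shard names, directory, and `load_safe` flag are illustrative, not from the PR):

```python
# Sketch of the uniform-loader idea used above; not the transformers
# implementation. Shard names, directory, and load_safe are illustrative.
import os
from functools import partial

import torch
from packaging import version
from safetensors.torch import load_file as safe_load_file

is_torch_greater_or_equal_than_1_13 = version.parse(torch.__version__).release >= (1, 13)
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}

load_safe = False  # True would route loading through safetensors instead
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)

for shard_file in ["pytorch_model-00001-of-00002.bin"]:
    state_dict = loader(os.path.join("checkpoint_dir", shard_file))
```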
3 changes: 2 additions & 1 deletion src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -1334,10 +1334,11 @@ def load_adapter(self, target_lang: str, force_load=True, **kwargs):
                 cache_dir=cache_dir,
             )

+            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
             state_dict = torch.load(
                 weight_path,
                 map_location="cpu",
-                weights_only=is_torch_greater_or_equal_than_1_13,
+                **weights_only_kwarg,
             )

         except EnvironmentError:
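For context, `load_adapter` is the Wav2Vec2/MMS API for swapping per-language adapter weights at runtime; the `torch.load` call patched above is what reads the downloaded adapter checkpoint. A hedged usage sketch (the checkpoint id and language codes are illustrative, and the download requires network access):

```python
# Hedged usage sketch of the method patched above; checkpoint id and
# language codes are illustrative, and downloading requires network access.
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all", target_lang="eng")
model.load_adapter("fra")  # reaches the torch.load call shown above
```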
10 changes: 6 additions & 4 deletions src/transformers/trainer.py
@@ -2091,6 +2091,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
             )

         if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
+            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
             # If the model is on the GPU, it still works!
             if is_sagemaker_mp_enabled():
                 if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
@@ -2109,7 +2110,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
                     state_dict = torch.load(
                         weights_file,
                         map_location="cpu",
-                        weights_only=is_torch_greater_or_equal_than_1_13,
+                        **weights_only_kwarg,
                     )
                     # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                     state_dict["_smp_is_partial"] = False
@@ -2126,7 +2127,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
                 state_dict = torch.load(
                     weights_file,
                     map_location="cpu",
-                    weights_only=is_torch_greater_or_equal_than_1_13,
+                    **weights_only_kwarg,
                 )

             # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
@@ -2179,6 +2180,7 @@ def _load_best_model(self):
             or os.path.exists(best_safe_adapter_model_path)
         ):
             has_been_loaded = True
+            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
             if is_sagemaker_mp_enabled():
                 if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                     # If the 'user_content.pt' file exists, load with the new smp api.
@@ -2198,7 +2200,7 @@ def _load_best_model(self):
                     state_dict = torch.load(
                         best_model_path,
                         map_location="cpu",
-                        weights_only=is_torch_greater_or_equal_than_1_13,
+                        **weights_only_kwarg,
                     )

                     state_dict["_smp_is_partial"] = False
@@ -2231,7 +2233,7 @@ def _load_best_model(self):
                 state_dict = torch.load(
                     best_model_path,
                     map_location="cpu",
-                    weights_only=is_torch_greater_or_equal_than_1_13,
+                    **weights_only_kwarg,
                 )

                 # If the model is on the GPU, it still works!
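Both `Trainer` call sites above run during checkpoint handling: `_load_from_checkpoint` when training resumes, and `_load_best_model` at the end of training when `load_best_model_at_end=True`. A hedged sketch of calls that reach them (the model, datasets, and output directory are placeholders, not from the PR):

```python
# Hedged sketch of how the two patched Trainer paths are reached; model,
# train_ds, and eval_ds are placeholders the caller must define.
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",              # placeholder
    save_strategy="epoch",
    evaluation_strategy="epoch",
    load_best_model_at_end=True,   # ends training in _load_best_model
)
trainer = Trainer(model=model, args=args, train_dataset=train_ds, eval_dataset=eval_ds)
trainer.train(resume_from_checkpoint=True)  # enters _load_from_checkpoint
```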