Fix weights_only (#28725)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2 people authored and amyeroberts committed Jan 26, 2024
1 parent 56ee444 commit 711bed1
Showing 6 changed files with 20 additions and 15 deletions.
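
Every call site gets the same fix: the code previously passed `weights_only=is_torch_greater_or_equal_than_1_13` straight into `torch.load`, but `torch.load` only accepts a `weights_only` argument from PyTorch 1.13 onward, so on older versions the call failed on the unrecognized keyword regardless of the flag's value. The commit instead builds the keyword argument conditionally and splats it in: pre-1.13 versions see a plain call, while 1.13+ opts into the restricted, safer unpickling path. A minimal sketch of the pattern (the version check approximates the library's `is_torch_greater_or_equal_than_1_13` flag, and `checkpoint.bin` is a placeholder path):

```python
from packaging import version

import torch

# Approximation of transformers' is_torch_greater_or_equal_than_1_13 flag.
is_torch_greater_or_equal_than_1_13 = version.parse(torch.__version__) >= version.parse("1.13")

# torch.load only accepts `weights_only` from 1.13 on, so pass the kwarg
# conditionally: older versions get a plain torch.load(..., map_location="cpu").
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
state_dict = torch.load("checkpoint.bin", map_location="cpu", **weights_only_kwarg)
```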
3 changes: 2 additions & 1 deletion src/transformers/convert_pytorch_checkpoint_to_tf2.py
@@ -330,10 +330,11 @@ def convert_pt_checkpoint_to_tf(
     if compare_with_pt_model:
         tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network

+        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
         state_dict = torch.load(
             pytorch_checkpoint_path,
             map_location="cpu",
-            weights_only=is_torch_greater_or_equal_than_1_13,
+            **weights_only_kwarg,
         )
         pt_model = pt_model_class.from_pretrained(
             pretrained_model_name_or_path=None, config=config, state_dict=state_dict
6 changes: 4 additions & 2 deletions src/transformers/modeling_flax_pytorch_utils.py
@@ -74,7 +74,8 @@ def load_pytorch_checkpoint_in_flax_state_dict(
             )
             raise

-        pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
+        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+        pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
         logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

         flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
@@ -252,7 +253,8 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
     flax_state_dict = {}
     for shard_file in shard_filenames:
         # load using msgpack utils
-        pt_state_dict = torch.load(shard_file, weights_only=is_torch_greater_or_equal_than_1_13)
+        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+        pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
         pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

         model_prefix = flax_model.base_model_prefix
3 changes: 2 additions & 1 deletion src/transformers/modeling_tf_pytorch_utils.py
@@ -188,7 +188,8 @@ def load_pytorch_checkpoint_in_tf2_model(
         if pt_path.endswith(".safetensors"):
             state_dict = safe_load_file(pt_path)
         else:
-            state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
+            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+            state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)

         pt_state_dict.update(state_dict)
10 changes: 4 additions & 6 deletions src/transformers/modeling_utils.py
@@ -482,11 +482,8 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
         error_message += f"\nMissing key(s): {str_unexpected_keys}."
         raise RuntimeError(error_message)

-    loader = (
-        safe_load_file
-        if load_safe
-        else partial(torch.load, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
-    )
+    weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+    loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)

     for shard_file in shard_files:
         state_dict = loader(os.path.join(folder, shard_file))
@@ -530,10 +527,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             and is_zipfile(checkpoint_file)
         ):
             extra_args = {"mmap": True}
+        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
         return torch.load(
             checkpoint_file,
             map_location=map_location,
-            weights_only=is_torch_greater_or_equal_than_1_13,
+            **weights_only_kwarg,
             **extra_args,
         )
     except Exception as e:
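
The `load_sharded_checkpoint` change doubles as a small cleanup: instead of a multi-line conditional expression, `functools.partial` bakes `map_location` and the optional `weights_only` kwarg into a single loader callable that the shard loop applies uniformly. A sketch of that shape, with illustrative stand-ins for `load_safe` and a hypothetical shard filename:

```python
from functools import partial

import torch
from safetensors.torch import load_file as safe_load_file

# Illustrative stand-ins for values computed earlier in load_sharded_checkpoint.
load_safe = False
weights_only_kwarg = {"weights_only": True}  # would be {} on torch < 1.13

# One callable covers both formats; the keyword defaults are bound once.
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)

# Hypothetical shard name for illustration.
state_dict = loader("pytorch_model-00001-of-00002.bin")
```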
3 changes: 2 additions & 1 deletion src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -1334,10 +1334,11 @@ def load_adapter(self, target_lang: str, force_load=True, **kwargs):
                     cache_dir=cache_dir,
                 )

+                weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
                 state_dict = torch.load(
                     weight_path,
                     map_location="cpu",
-                    weights_only=is_torch_greater_or_equal_than_1_13,
+                    **weights_only_kwarg,
                 )

             except EnvironmentError:
10 changes: 6 additions & 4 deletions src/transformers/trainer.py
@@ -2088,6 +2088,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
             )

         if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
+            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
             # If the model is on the GPU, it still works!
             if is_sagemaker_mp_enabled():
                 if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
@@ -2106,7 +2107,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
                     state_dict = torch.load(
                         weights_file,
                         map_location="cpu",
-                        weights_only=is_torch_greater_or_equal_than_1_13,
+                        **weights_only_kwarg,
                     )
                     # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                     state_dict["_smp_is_partial"] = False
@@ -2123,7 +2124,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
                 state_dict = torch.load(
                     weights_file,
                     map_location="cpu",
-                    weights_only=is_torch_greater_or_equal_than_1_13,
+                    **weights_only_kwarg,
                 )

                 # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
@@ -2176,6 +2177,7 @@ def _load_best_model(self):
             or os.path.exists(best_safe_adapter_model_path)
         ):
             has_been_loaded = True
+            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
             if is_sagemaker_mp_enabled():
                 if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                     # If the 'user_content.pt' file exists, load with the new smp api.
@@ -2195,7 +2197,7 @@ def _load_best_model(self):
                     state_dict = torch.load(
                         best_model_path,
                         map_location="cpu",
-                        weights_only=is_torch_greater_or_equal_than_1_13,
+                        **weights_only_kwarg,
                     )

                     state_dict["_smp_is_partial"] = False
@@ -2228,7 +2230,7 @@ def _load_best_model(self):
                 state_dict = torch.load(
                     best_model_path,
                     map_location="cpu",
-                    weights_only=is_torch_greater_or_equal_than_1_13,
+                    **weights_only_kwarg,
                 )

                 # If the model is on the GPU, it still works!
