diff --git a/src/transformers/convert_pytorch_checkpoint_to_tf2.py b/src/transformers/convert_pytorch_checkpoint_to_tf2.py index c10dd44ed853..26b19a4e81f4 100755 --- a/src/transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/src/transformers/convert_pytorch_checkpoint_to_tf2.py @@ -330,10 +330,11 @@ def convert_pt_checkpoint_to_tf( if compare_with_pt_model: tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network + weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} state_dict = torch.load( pytorch_checkpoint_path, map_location="cpu", - weights_only=is_torch_greater_or_equal_than_1_13, + **weights_only_kwarg, ) pt_model = pt_model_class.from_pretrained( pretrained_model_name_or_path=None, config=config, state_dict=state_dict diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index 87701c50f071..6c13ba9619f4 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -74,7 +74,8 @@ def load_pytorch_checkpoint_in_flax_state_dict( ) raise - pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13) + weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} + pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg) logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.") flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model) @@ -252,7 +253,8 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model): flax_state_dict = {} for shard_file in shard_filenames: # load using msgpack utils - pt_state_dict = torch.load(shard_file, weights_only=is_torch_greater_or_equal_than_1_13) + weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} + pt_state_dict = torch.load(shard_file, **weights_only_kwarg) pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} model_prefix = flax_model.base_model_prefix diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index e68b02bc7ab4..a96481e06283 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -188,7 +188,8 @@ def load_pytorch_checkpoint_in_tf2_model( if pt_path.endswith(".safetensors"): state_dict = safe_load_file(pt_path) else: - state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13) + weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} + state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg) pt_state_dict.update(state_dict) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 585b1ac00523..3b4887fb5cef 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -482,11 +482,8 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True): error_message += f"\nMissing key(s): {str_unexpected_keys}." raise RuntimeError(error_message) - loader = ( - safe_load_file - if load_safe - else partial(torch.load, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13) - ) + weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} + loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg) for shard_file in shard_files: state_dict = loader(os.path.join(folder, shard_file)) @@ -530,10 +527,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): and is_zipfile(checkpoint_file) ): extra_args = {"mmap": True} + weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} return torch.load( checkpoint_file, map_location=map_location, - weights_only=is_torch_greater_or_equal_than_1_13, + **weights_only_kwarg, **extra_args, ) except Exception as e: diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 6d4501ce97f5..4062fadbc8cc 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1334,10 +1334,11 @@ def load_adapter(self, target_lang: str, force_load=True, **kwargs): cache_dir=cache_dir, ) + weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} state_dict = torch.load( weight_path, map_location="cpu", - weights_only=is_torch_greater_or_equal_than_1_13, + **weights_only_kwarg, ) except EnvironmentError: diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 65c9c2fdda3e..15ef9b989f09 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2091,6 +2091,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None): ) if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt: + weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} # If the model is on the GPU, it still works! if is_sagemaker_mp_enabled(): if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")): @@ -2109,7 +2110,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None): state_dict = torch.load( weights_file, map_location="cpu", - weights_only=is_torch_greater_or_equal_than_1_13, + **weights_only_kwarg, ) # Required for smp to not auto-translate state_dict from hf to smp (is already smp). state_dict["_smp_is_partial"] = False @@ -2126,7 +2127,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None): state_dict = torch.load( weights_file, map_location="cpu", - weights_only=is_torch_greater_or_equal_than_1_13, + **weights_only_kwarg, ) # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 @@ -2179,6 +2180,7 @@ def _load_best_model(self): or os.path.exists(best_safe_adapter_model_path) ): has_been_loaded = True + weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} if is_sagemaker_mp_enabled(): if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")): # If the 'user_content.pt' file exists, load with the new smp api. @@ -2198,7 +2200,7 @@ def _load_best_model(self): state_dict = torch.load( best_model_path, map_location="cpu", - weights_only=is_torch_greater_or_equal_than_1_13, + **weights_only_kwarg, ) state_dict["_smp_is_partial"] = False @@ -2231,7 +2233,7 @@ def _load_best_model(self): state_dict = torch.load( best_model_path, map_location="cpu", - weights_only=is_torch_greater_or_equal_than_1_13, + **weights_only_kwarg, ) # If the model is on the GPU, it still works!