PyTorch 2.6 argument weights_only breaks mace_mp API #809

Open
PythonFZ opened this issue Jan 30, 2025 · 4 comments

@PythonFZ

Describe the bug
PyTorch 2.6 changed the default value of torch.load's weights_only argument from False to True.
See https://github.com/pytorch/pytorch/releases

This causes the following error:

zndraw-utils-1  | ╭───────────────────── Traceback (most recent call last) ──────────────────────╮
zndraw-utils-1  | │ /usr/local/lib/python3.11/site-packages/zndraw_utils/cli.py:32 in            │
zndraw-utils-1  | │ zndraw_register                                                              │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │   29 │   ),                                                                  │
zndraw-utils-1  | │   30 │   public: bool = typer.Option(True),                                  │
zndraw-utils-1  | │   31 ):                                                                      │
zndraw-utils-1  | │ ❱ 32 │   from mace.calculators import mace_mp                                │
zndraw-utils-1  | │   33 │   if names == ["all"]:                                                │
zndraw-utils-1  | │   34 │   │   names = [Methods.md, Methods.relax, Methods.smiles, Methods.sol │
zndraw-utils-1  | │   35                                                                         │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │ ╭────────────────────── locals ──────────────────────╮                       │
zndraw-utils-1  | │ │ auth_token = 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'    │                       │
zndraw-utils-1  | │ │      names = [<Methods.all: 'all'>]                │                       │
zndraw-utils-1  | │ │     public = True                                  │                       │
zndraw-utils-1  | │ │      token = None                                  │                       │
zndraw-utils-1  | │ │        url = 'https://zndraw.icp.uni-stuttgart.de' │                       │
zndraw-utils-1  | │ ╰────────────────────────────────────────────────────╯                       │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │ /usr/local/lib/python3.11/site-packages/mace/calculators/__init__.py:1 in    │
zndraw-utils-1  | │ <module>                                                                     │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │ ❱  1 from .foundations_models import mace_anicc, mace_mp, mace_off           │
zndraw-utils-1  | │    2 from .lammps_mace import LAMMPS_MACE                                    │
zndraw-utils-1  | │    3 from .mace import MACECalculator                                        │
zndraw-utils-1  | │    4                                                                         │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │ /usr/local/lib/python3.11/site-packages/mace/calculators/foundations_models. │
zndraw-utils-1  | │ py:10 in <module>                                                            │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │     7 from ase import units                                                  │
zndraw-utils-1  | │     8 from ase.calculators.mixing import SumCalculator                       │
zndraw-utils-1  | │     9                                                                        │
zndraw-utils-1  | │ ❱  10 from .mace import MACECalculator                                       │
zndraw-utils-1  | │    11                                                                        │
zndraw-utils-1  | │    12 module_dir = os.path.dirname(__file__)                                 │
zndraw-utils-1  | │    13 local_model_path = os.path.join(                                       │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │ ╭───────────────────────────────── locals ─────────────────────────────────╮ │
zndraw-utils-1  | │ │     os = <module 'os' (frozen)>                                          │ │
zndraw-utils-1  | │ │  torch = <module 'torch' from                                            │ │
zndraw-utils-1  | │ │          '/usr/local/lib/python3.11/site-packages/torch/__init__.py'>    │ │
zndraw-utils-1  | │ │  Union = typing.Union                                                    │ │
zndraw-utils-1  | │ │  units = <module 'ase.units' from                                        │ │
zndraw-utils-1  | │ │          '/usr/local/lib/python3.11/site-packages/ase/units.py'>         │ │
zndraw-utils-1  | │ │ urllib = <module 'urllib' from                                           │ │
zndraw-utils-1  | │ │          '/usr/local/lib/python3.11/urllib/__init__.py'>                 │ │
zndraw-utils-1  | │ ╰──────────────────────────────────────────────────────────────────────────╯ │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │ /usr/local/lib/python3.11/site-packages/mace/calculators/mace.py:17 in       │
zndraw-utils-1  | │ <module>                                                                     │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │    14 import torch                                                           │
zndraw-utils-1  | │    15 from ase.calculators.calculator import Calculator, all_changes         │
zndraw-utils-1  | │    16 from ase.stress import full_3x3_to_voigt_6_stress                      │
zndraw-utils-1  | │ ❱  17 from e3nn import o3                                                    │
zndraw-utils-1  | │    18                                                                        │
zndraw-utils-1  | │    19 from mace import data                                                  │
zndraw-utils-1  | │    20 from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq         │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │ ╭───────────────────────────────── locals ─────────────────────────────────╮ │
zndraw-utils-1  | │ │ all_changes = [                                                          │ │
zndraw-utils-1  | │ │               │   'positions',                                           │ │
zndraw-utils-1  | │ │               │   'numbers',                                             │ │
zndraw-utils-1  | │ │               │   'cell',                                                │ │
zndraw-utils-1  | │ │               │   'pbc',                                                 │ │
zndraw-utils-1  | │ │               │   'initial_charges',                                     │ │
zndraw-utils-1  | │ │               │   'initial_magmoms'                                      │ │
zndraw-utils-1  | │ │               ]                                                          │ │
zndraw-utils-1  | │ │     logging = <module 'logging' from                                     │ │
zndraw-utils-1  | │ │               '/usr/local/lib/python3.11/logging/__init__.py'>           │ │
zndraw-utils-1  | │ │          np = <module 'numpy' from                                       │ │
zndraw-utils-1  | │ │               '/usr/local/lib/python3.11/site-packages/numpy/__init__.p… │ │
zndraw-utils-1  | │ │       torch = <module 'torch' from                                       │ │
zndraw-utils-1  | │ │               '/usr/local/lib/python3.11/site-packages/torch/__init__.p… │ │
zndraw-utils-1  | │ │       Union = typing.Union                                               │ │
zndraw-utils-1  | │ ╰──────────────────────────────────────────────────────────────────────────╯ │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │ /usr/local/lib/python3.11/site-packages/e3nn/o3/__init__.py:31 in <module>   │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │    28 │   angles_to_xyz,                                                     │
zndraw-utils-1  | │    29 │   xyz_to_angles,                                                     │
zndraw-utils-1  | │    30 )                                                                      │
zndraw-utils-1  | │ ❱  31 from ._wigner import wigner_D, wigner_3j                               │
zndraw-utils-1  | │    32 from ._irreps import Irrep, Irreps                                     │
zndraw-utils-1  | │    33 from ._tensor_product import (                                         │
zndraw-utils-1  | │    34 │   Instruction,                                                       │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │ ╭───────────────────────────────── locals ─────────────────────────────────╮ │
zndraw-utils-1  | │ │ _rotation = <module 'e3nn.o3._rotation' from                             │ │
zndraw-utils-1  | │ │             '/usr/local/lib/python3.11/site-packages/e3nn/o3/_rotation.… │ │
zndraw-utils-1  | │ ╰──────────────────────────────────────────────────────────────────────────╯ │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │ /usr/local/lib/python3.11/site-packages/e3nn/o3/_wigner.py:10 in <module>    │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │     7 from e3nn import o3                                                    │
zndraw-utils-1  | │     8 from e3nn.util import explicit_default_types                           │
zndraw-utils-1  | │     9                                                                        │
zndraw-utils-1  | │ ❱  10 _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname │
zndraw-utils-1  | │    11 # _Jd is a list of tensors of shape (2l+1, 2l+1)                       │
zndraw-utils-1  | │    12 # _W3j_flat is a flatten version of W3j symbols                        │
zndraw-utils-1  | │    13 # _W3j_indices is a dict from (l1, l2, l3) -> slice(i, j) to index the │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │ ╭───────────────────────────────── locals ─────────────────────────────────╮ │
zndraw-utils-1  | │ │    o3 = <module 'e3nn.o3' from                                           │ │
zndraw-utils-1  | │ │         '/usr/local/lib/python3.11/site-packages/e3nn/o3/__init__.py'>   │ │
zndraw-utils-1  | │ │    os = <module 'os' (frozen)>                                           │ │
zndraw-utils-1  | │ │ torch = <module 'torch' from                                             │ │
zndraw-utils-1  | │ │         '/usr/local/lib/python3.11/site-packages/torch/__init__.py'>     │ │
zndraw-utils-1  | │ ╰──────────────────────────────────────────────────────────────────────────╯ │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │ /usr/local/lib/python3.11/site-packages/torch/serialization.py:1470 in load  │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │   1467 │   │   │   │   │   │   │   **pickle_load_args,                       │
zndraw-utils-1  | │   1468 │   │   │   │   │   │   )                                             │
zndraw-utils-1  | │   1469 │   │   │   │   │   except pickle.UnpicklingError as e:               │
zndraw-utils-1  | │ ❱ 1470 │   │   │   │   │   │   raise pickle.UnpicklingError(_get_wo_message( │
zndraw-utils-1  | │   1471 │   │   │   │   return _load(                                         │
zndraw-utils-1  | │   1472 │   │   │   │   │   opened_zipfile,                                   │
zndraw-utils-1  | │   1473 │   │   │   │   │   map_location,                                     │
zndraw-utils-1  | │                                                                              │
zndraw-utils-1  | │ ╭───────────────────────────────── locals ─────────────────────────────────╮ │
zndraw-utils-1  | │ │               DOCS_MESSAGE = '\n\nCheck the documentation of torch.load  │ │
zndraw-utils-1  | │ │                              to learn more about types accepted by       │ │
zndraw-utils-1  | │ │                              de'+82                                      │ │
zndraw-utils-1  | │ │                          f = '/usr/local/lib/python3.11/site-packages/e… │ │
zndraw-utils-1  | │ │ force_no_weights_only_load = False                                       │ │
zndraw-utils-1  | │ │    force_weights_only_load = False                                       │ │
zndraw-utils-1  | │ │               map_location = None                                        │ │
zndraw-utils-1  | │ │                       mmap = False                                       │ │
zndraw-utils-1  | │ │                opened_file = <_io.BufferedReader                         │ │
zndraw-utils-1  | │ │                              name='/usr/local/lib/python3.11/site-packa… │ │
zndraw-utils-1  | │ │             opened_zipfile = <torch.PyTorchFileReader object at          │ │
zndraw-utils-1  | │ │                              0x77af0c531170>                             │ │
zndraw-utils-1  | │ │              orig_position = 0                                           │ │
zndraw-utils-1  | │ │            overall_storage = None                                        │ │
zndraw-utils-1  | │ │           pickle_load_args = {'encoding': 'utf-8'}                       │ │
zndraw-utils-1  | │ │              pickle_module = None                                        │ │
zndraw-utils-1  | │ │                  skip_data = False                                       │ │
zndraw-utils-1  | │ │                true_values = ['1', 'y', 'yes', 'true']                   │ │
zndraw-utils-1  | │ │               weights_only = True                                        │ │
zndraw-utils-1  | │ │       weights_only_not_set = True                                        │ │
zndraw-utils-1  | │ ╰──────────────────────────────────────────────────────────────────────────╯ │
zndraw-utils-1  | ╰──────────────────────────────────────────────────────────────────────────────╯
zndraw-utils-1  | UnpicklingError: Weights only load failed. This file can still be loaded, to do 
zndraw-utils-1  | so you have two options, do those steps only if you trust the source of the 
zndraw-utils-1  | checkpoint. 
zndraw-utils-1  |         (1) In PyTorch 2.6, we changed the default value of the `weights_only` 
zndraw-utils-1  | argument in `torch.load` from `False` to `True`. Re-running `torch.load` with 
zndraw-utils-1  | `weights_only` set to `False` will likely succeed, but it can result in 
zndraw-utils-1  | arbitrary code execution. Do it only if you got the file from a trusted source.
zndraw-utils-1  |         (2) Alternatively, to load with `weights_only=True` please check the 
zndraw-utils-1  | recommended steps in the following error message.
zndraw-utils-1  |         WeightsUnpickler error: Unsupported global: GLOBAL builtins.slice was 
zndraw-utils-1  | not an allowed global by default. Please use 
zndraw-utils-1  | `torch.serialization.add_safe_globals([slice])` or the 
zndraw-utils-1  | `torch.serialization.safe_globals([slice])` context manager to allowlist this 
zndraw-utils-1  | global if you trust this class/function.
zndraw-utils-1  | 
zndraw-utils-1  | Check the documentation of torch.load to learn more about types accepted by 
zndraw-utils-1  | default with weights_only 
zndraw-utils-1  | https://pytorch.org/docs/stable/generated/torch.load.html.
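The "Unsupported global: GLOBAL builtins.slice" part of the message can be illustrated with a small standard-library analogy. The sketch below is not PyTorch's implementation: `AllowlistUnpickler` and `restricted_loads` are hypothetical names, and the technique is the plain `pickle.Unpickler.find_class` override described in the Python `pickle` documentation. It only shows why a dict of `(l1, l2, l3) -> slice(i, j)`, like the `_W3j_indices` object mentioned in the traceback, is rejected until `slice` is allowlisted.

```python
import io
import pickle


class AllowlistUnpickler(pickle.Unpickler):
    """Hypothetical stand-in for a restricted unpickler.

    Only globals that were explicitly allowlisted can be resolved;
    everything else raises UnpicklingError.
    """

    def __init__(self, file, allowed=()):
        super().__init__(file)
        # Map (module, qualified name) -> object for every allowed global.
        self._allowed = {(obj.__module__, obj.__qualname__): obj for obj in allowed}

    def find_class(self, module, name):
        try:
            return self._allowed[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(
                f"Unsupported global: {module}.{name}"
            ) from None


def restricted_loads(data, allowed=()):
    """Unpickle `data`, resolving only the globals listed in `allowed`."""
    return AllowlistUnpickler(io.BytesIO(data), allowed=allowed).load()


# Same shape as the index table in the traceback: (l1, l2, l3) -> slice(i, j).
payload = pickle.dumps({(1, 1, 0): slice(0, 3)})

# Without an allowlist, the slice object cannot be reconstructed.
try:
    restricted_loads(payload)
    rejected = False
except pickle.UnpicklingError as err:
    rejected = True
    message = str(err)

# With slice allowlisted, the same payload loads.
indices = restricted_loads(payload, allowed=[slice])
```

In PyTorch itself, the error message above names `torch.serialization.add_safe_globals([slice])` and the `torch.serialization.safe_globals([slice])` context manager as the corresponding allowlisting mechanism.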
@ilyes319 changed the title from "PyTorch 6.0 argument weights_only breaks mace_mp API" to "PyTorch 2.6 argument weights_only breaks mace_mp API" on Jan 30, 2025
@ilyes319
Contributor

I think we need to set this in some init: TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1

@janosh
Contributor

janosh commented Feb 3, 2025

I'm not sure what the best long-term solution that plays well with future e3nn versions looks like. But since MACE already pins e3nn==0.4.4, it would be enough for torch==2.6 support, as a stop-gap measure, to insert weights_only=False here:

self.models = [
    torch.load(f=model_path, map_location=device)
    for model_path in model_paths
]

@ilyes319
Contributor

ilyes319 commented Feb 4, 2025

That should be fixed now!

@PythonFZ
Author

PythonFZ commented Feb 4, 2025

Do you want to keep this issue open while we look for a more secure solution than:

os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"

or can we close this now?
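One possible way to narrow the scope of the environment-variable workaround is to set it only around the import or load that needs it and restore the previous value afterwards. The sketch below is an assumption, not what MACE does: `temporary_env` is a hypothetical helper, and it relies on the variable being read when `torch.load` is called, which the traceback suggests because `force_no_weights_only_load` appears among the locals of `load`. It still performs a full, unrestricted unpickle for anything loaded inside the block.

```python
import os
from contextlib import contextmanager

ENV_VAR = "TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"


@contextmanager
def temporary_env(name, value):
    """Set an environment variable for the duration of the block, then restore it."""
    previous = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if previous is None:
            # The variable was unset before; remove it again.
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous


# Usage sketch (commented out because it requires e3nn and torch):
# with temporary_env(ENV_VAR, "1"):
#     from e3nn import o3  # the torch.load call in the traceback runs at import time
```

Compared with a process-wide `os.environ[...] = "1"`, this leaves later `torch.load` calls in the same process on the PyTorch 2.6 default.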
