Merge branch 'master' into pytorch-2.4
BenjaminBossan committed Sep 19, 2024
2 parents eb2963d + e724424 commit b115637
Showing 4 changed files with 161 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .github/workflows/testing.yml
@@ -31,10 +31,12 @@ jobs:
with:
python-version: ${{ matrix.python_version }}
- name: Install dependencies
# TODO remove numpy version constraint once we no longer support PyTorch < 2.3
run: |
python -m pip install --upgrade pip
python -m pip install -r requirements-dev.txt
python -m pip install -r requirements.txt
python -m pip install --force-reinstall -U "numpy<2.0.0"
python -m pip install pytest-pretty
python -m pip install torch==${{ matrix.torch_version }} -f https://download.pytorch.org/whl/torch
python -m pip list
36 changes: 35 additions & 1 deletion skorch/net.py
@@ -46,6 +46,7 @@
from skorch.utils import to_device
from skorch.utils import to_numpy
from skorch.utils import to_tensor
from skorch.utils import get_default_torch_load_kwargs


# pylint: disable=too-many-instance-attributes
@@ -235,6 +236,33 @@ class NeuralNet:
callbacks.
Implementation note: It is the job of the callbacks to honor this setting.
torch_load_kwargs : dict or None (default=None)
Additional arguments that will be passed to torch.load when loading pickled
parameters.
In particular, this is important because PyTorch will switch (probably in
version 2.6.0) to only allowing weights to be loaded for security reasons
(i.e. weights_only switches from False to True). As a consequence, loading
pickled parameters may raise an error after upgrading torch because some
types are used that are considered insecure. In skorch, we will also make
that switch at the same time. To resolve the error, follow the instructions
in the torch error message to designate the offending types as secure. Only
do this if you trust the source of the file.
If you want to keep loading non-weight types the same way as before, please
pass:
torch_load_kwargs={'weights_only': False}
You should be aware that this is considered insecure and should only be used
if you trust the source of the file. However, this does not introduce new
insecurities; rather, it corresponds to the status quo from before torch
made the switch.
Another way to avoid this issue is to pass use_safetensors=True when calling
save_params and load_params. This avoids using pickle in favor of the
safetensors format, which is secure by design (a short usage sketch follows
after this hunk).
Attributes
----------
prefixes_ : list of str
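
For illustration, a minimal usage sketch of the two options described in the
docstring above. MyModule is a hypothetical stand-in for your own
torch.nn.Module, and the file names are made up; torch_load_kwargs and
use_safetensors are the parameters documented in this hunk.

    from skorch import NeuralNet
    from torch import nn

    net = NeuralNet(
        MyModule,  # hypothetical: your own torch.nn.Module subclass
        criterion=nn.CrossEntropyLoss,
        # keep the pre-2.6 pickle behavior; only do this if you trust the file
        torch_load_kwargs={'weights_only': False},
    )
    net.initialize()
    net.save_params(f_params='model.pt')
    net.load_params(f_params='model.pt')  # torch.load(..., weights_only=False)

    # alternative that avoids pickle entirely: the safetensors format
    net.save_params(f_params='model.safetensors', use_safetensors=True)
    net.load_params(f_params='model.safetensors', use_safetensors=True)
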
@@ -311,6 +339,7 @@ def __init__(
device='cpu',
compile=False,
use_caching='auto',
torch_load_kwargs=None,
**kwargs
):
self.module = module
@@ -330,6 +359,7 @@
self.device = device
self.compile = compile
self.use_caching = use_caching
self.torch_load_kwargs = torch_load_kwargs

self._check_deprecated_params(**kwargs)
history = kwargs.pop('history', None)
@@ -2620,10 +2650,14 @@ def _get_state_dict(f_name):

return state_dict
else:
torch_load_kwargs = self.torch_load_kwargs
if torch_load_kwargs is None:
torch_load_kwargs = get_default_torch_load_kwargs()

def _get_state_dict(f_name):
map_location = get_map_location(self.device)
self.device = self._check_device(self.device, map_location)
return torch.load(f_name, map_location=map_location)
return torch.load(f_name, map_location=map_location, **torch_load_kwargs)

kwargs_full = {}
if checkpoint is not None:
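
Taken together, the hunk above resolves the kwargs in two steps: an explicitly
set torch_load_kwargs wins, otherwise the version-dependent default from
skorch.utils is used. A standalone sketch of that resolution, assuming an
initialized `net` and a previously saved 'weights.pt' (names mirror the diff,
not a public API):

    import torch
    from skorch.utils import get_default_torch_load_kwargs

    torch_load_kwargs = net.torch_load_kwargs  # None unless set by the user
    if torch_load_kwargs is None:
        # version-dependent default, e.g. {'weights_only': False} on torch < 2.6.0
        torch_load_kwargs = get_default_torch_load_kwargs()

    state_dict = torch.load('weights.pt', map_location='cpu', **torch_load_kwargs)
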
108 changes: 108 additions & 0 deletions skorch/tests/test_net.py
@@ -16,6 +16,7 @@
from unittest.mock import patch
import sys
import time
import warnings
from contextlib import ExitStack

from flaky import flaky
@@ -30,6 +31,7 @@
import torch
from torch import nn

import skorch
from skorch.tests.conftest import INFERENCE_METHODS
from skorch.utils import flatten
from skorch.utils import to_numpy
@@ -561,6 +563,17 @@ def test_load_params_unknown_attribute_raises(self, net_fit):
with pytest.raises(AttributeError, match=msg):
net_fit.load_params(f_unknown='some-file.pt')

def test_load_params_no_warning(self, net_fit, tmp_path, recwarn):
# See discussion in 1063
# Ensure that there is no FutureWarning (and DeprecationWarning for good
# measure) caused by torch.load.
net_fit.save_params(f_params=tmp_path / 'weights.pt')
net_fit.load_params(f_params=tmp_path / 'weights.pt')
assert not any(
isinstance(warning.message, (DeprecationWarning, FutureWarning))
for warning in recwarn.list
)

@pytest.mark.parametrize('use_safetensors', [False, True])
def test_save_load_state_dict_file(
self, net_cls, module_cls, net_fit, data, tmpdir, use_safetensors):
@@ -2983,6 +2996,101 @@ def test_save_load_state_dict_custom_module(
weights_loaded = net_new.custom_.state_dict()['sequential.3.weight']
assert (weights_before == weights_loaded).all()

def test_torch_load_kwargs_auto_weights_only_false_when_load_params(
self, net_cls, module_cls, monkeypatch, tmp_path
):
# Here we assume that the torch version is low enough that weights_only
# defaults to False. Check that when no argument is set in skorch, the
# right default is used.
# See discussion in 1063
net = net_cls(module_cls).initialize()
net.save_params(f_params=tmp_path / 'params.pkl')
state_dict = net.module_.state_dict()
expected_kwargs = {"weights_only": False}

mock_torch_load = Mock(return_value=state_dict)
monkeypatch.setattr(torch, "load", mock_torch_load)
monkeypatch.setattr(
skorch.net, "get_default_torch_load_kwargs", lambda: expected_kwargs
)

net.load_params(f_params=tmp_path / 'params.pkl')

call_kwargs = mock_torch_load.call_args_list[0].kwargs
del call_kwargs['map_location'] # we're not interested in that
assert call_kwargs == expected_kwargs

def test_torch_load_kwargs_auto_weights_only_true_when_load_params(
self, net_cls, module_cls, monkeypatch, tmp_path
):
# Here we assume that the torch version is high enough that weights_only
# defaults to True. Check that when no argument is set in skorch, the
# right default is used.
# See discussion in 1063
net = net_cls(module_cls).initialize()
net.save_params(f_params=tmp_path / 'params.pkl')
state_dict = net.module_.state_dict()
expected_kwargs = {"weights_only": True}

mock_torch_load = Mock(return_value=state_dict)
monkeypatch.setattr(torch, "load", mock_torch_load)
monkeypatch.setattr(
skorch.net, "get_default_torch_load_kwargs", lambda: expected_kwargs
)

net.load_params(f_params=tmp_path / 'params.pkl')

call_kwargs = mock_torch_load.call_args_list[0].kwargs
del call_kwargs['map_location'] # we're not interested in that
assert call_kwargs == expected_kwargs

def test_torch_load_kwargs_forwarded_to_torch_load(
self, net_cls, module_cls, monkeypatch, tmp_path
):
# Here we check that custom set torch load args are forwarded to
# torch.load.
# See discussion in 1063
expected_kwargs = {'weights_only': 123, 'foo': 'bar'}
net = net_cls(module_cls, torch_load_kwargs=expected_kwargs).initialize()
net.save_params(f_params=tmp_path / 'params.pkl')
state_dict = net.module_.state_dict()

mock_torch_load = Mock(return_value=state_dict)
monkeypatch.setattr(torch, "load", mock_torch_load)

net.load_params(f_params=tmp_path / 'params.pkl')

call_kwargs = mock_torch_load.call_args_list[0].kwargs
del call_kwargs['map_location'] # we're not interested in that
assert call_kwargs == expected_kwargs

def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6(
self, net_cls, module_cls, monkeypatch, tmp_path
):
# Same test as
# test_torch_load_kwargs_auto_weights_only_false_when_load_params but
# without monkeypatching get_default_torch_load_kwargs. There is no
# corresponding test for >= 2.6.0 since it's not clear yet if the switch
# will be made in that version.
# See discussion in 1063.
from skorch._version import Version

if Version(torch.__version__) >= Version('2.6.0'):
pytest.skip("Test only for torch < v2.6.0")

net = net_cls(module_cls).initialize()
net.save_params(f_params=tmp_path / 'params.pkl')
state_dict = net.module_.state_dict()
expected_kwargs = {"weights_only": False}

mock_torch_load = Mock(return_value=state_dict)
monkeypatch.setattr(torch, "load", mock_torch_load)
net.load_params(f_params=tmp_path / 'params.pkl')

call_kwargs = mock_torch_load.call_args_list[0].kwargs
del call_kwargs['map_location'] # we're not interested in that
assert call_kwargs == expected_kwargs

def test_custom_module_params_passed_to_optimizer(
self, net_custom_module_cls, module_cls):
# custom module parameters should automatically be passed to the optimizer
16 changes: 16 additions & 0 deletions skorch/utils.py
@@ -28,6 +28,7 @@

from skorch.exceptions import DeviceWarning
from skorch.exceptions import NotInitializedError
from ._version import Version

try:
import torch_geometric
@@ -768,3 +769,18 @@ def _check_f_arguments(caller_name, **kwargs):
key = 'module_' if key == 'f_params' else key[2:] + '_'
kwargs_module[key] = val
return kwargs_module, kwargs_other


def get_default_torch_load_kwargs():
"""Returns the kwargs passed to torch.load that correspond to the current
torch version.
The plan is to switch from weights_only=False to True in PyTorch version
2.6.0, but depending on what happens, this may require updating.
"""
version_torch = Version(torch.__version__)
version_default_switch = Version('2.6.0')
if version_torch >= version_default_switch:
return {"weights_only": True}
return {"weights_only": False}
