Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT: Update CI, add Python 3.11, fix sklearn issues #990

Merged
merged 11 commits into from
Jul 6, 2023
20 changes: 16 additions & 4 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,40 @@ on:
pull_request:
branches: [ master ]

# print more columns
env:
COLUMNS: 120

jobs:
test:

runs-on: ${{ matrix.os }}
strategy:
fail-fast: false # don't cancel all jobs when one fails
matrix:
python_version: ['3.8', '3.9', '3.10']
torch_version: ['1.11.0+cpu', '1.12.1+cpu', '1.13.1+cpu', '2.0.0+cpu']
python_version: ['3.8', '3.9', '3.10', '3.11']
torch_version: ['1.11.0+cpu', '1.12.1+cpu', '1.13.1+cpu', '2.0.1+cpu']
os: [ubuntu-latest]
exclude:
- python_version: '3.11'
torch_version: '1.11.0+cpu'
- python_version: '3.11'
torch_version: '1.12.1+cpu'
- python_version: '3.11'
torch_version: '2.0.1+cpu'

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python_version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -r requirements-dev.txt
python -m pip install -r requirements.txt
python -m pip install pytest-pretty
python -m pip install torch==${{ matrix.torch_version }} -f https://download.pytorch.org/whl/torch_stable.html
python -m pip list
- name: Install skorch
Expand Down
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ minor PyTorch versions, which currently are:
- 1.11.0
- 1.12.1
- 1.13.1
- 2.0.0
- 2.0.1

However, that doesn't mean that older versions don't work, just that
they aren't tested. Since skorch mostly relies on the stable part of
Expand Down
2 changes: 1 addition & 1 deletion docs/user/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ minor PyTorch versions, which currently are:
- 1.11.0
- 1.12.1
- 1.13.1
- 2.0.0
- 2.0.1

However, that doesn't mean that older versions don't work, just that
they aren't tested. Since skorch mostly relies on the stable part of
Expand Down
12 changes: 10 additions & 2 deletions skorch/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from skorch.callbacks import EpochScoring
from skorch.callbacks import PassthroughScoring
from skorch.dataset import ValidSplit
from skorch.utils import get_dim, to_numpy
from skorch.utils import is_dataset
from skorch.utils import data_from_dataset, is_dataset, get_dim, to_numpy

neural_net_clf_doc_start = """NeuralNet for classification tasks

Expand Down Expand Up @@ -114,6 +113,15 @@ def check_data(self, X, y):
"``iterator_train`` and ``iterator_valid`` parameters "
"respectively.")
raise ValueError(msg)

if (y is None) and is_dataset(X):
try:
_, y_ds = data_from_dataset(X)
self.classes_inferred_ = np.unique(to_numpy(y_ds))
except AttributeError:
# If this fails, we might still be good to go, so don't raise
pass

if y is not None:
# pylint: disable=attribute-defined-outside-init
self.classes_inferred_ = np.unique(to_numpy(y))
Expand Down
14 changes: 8 additions & 6 deletions skorch/tests/callbacks/test_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,11 @@ def test_with_score_nonexisting_string(
max_epochs=2,
train_split=train_split,
)
with pytest.raises(ValueError) as exc:
# don't check precise error message, because it stems from sklearn and
# may change in the future
msg = "'does-not-exist'"
with pytest.raises(ValueError, match=msg):
net.fit(*data)
msg = "'does-not-exist' is not a valid scoring value."
assert exc.value.args[0].startswith(msg)

def test_with_score_as_custom_func(
self, net_cls, module_cls, scoring_cls, train_split, data, score55,
Expand Down Expand Up @@ -788,10 +789,11 @@ def test_with_score_nonexisting_string(
max_epochs=2,
train_split=train_split,
)
with pytest.raises(ValueError) as exc:
# don't check precise error message, because it stems from sklearn and
# may change in the future
msg = "'does-not-exist'"
with pytest.raises(ValueError, match=msg):
net.fit(*data)
msg = "'does-not-exist' is not a valid scoring value."
assert exc.value.args[0].startswith(msg)

def test_with_score_as_custom_func(
self, net_cls, module_cls, scoring_cls, train_split, data, score55,
Expand Down
24 changes: 24 additions & 0 deletions skorch/tests/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,30 @@ def test_pass_classes_explicitly_overrides(self, net_cls, module_cls, data):
net = net_cls(module_cls, max_epochs=0, classes=['foo', 'bar']).fit(*data)
assert net.classes_ == ['foo', 'bar']

def test_classes_are_set_with_tensordataset_explicit_y(
self, net_cls, module_cls, data
):
# see 990
X = torch.from_numpy(data[0])
y = torch.arange(len(X)) % 10
dataset = torch.utils.data.TensorDataset(X, y)
net = net_cls(module_cls, max_epochs=0).fit(dataset, y)
assert (net.classes_ == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).all()

def test_classes_are_set_with_tensordataset_implicit_y(
self, net_cls, module_cls, data
):
# see 990
from skorch.dataset import ValidSplit

X = torch.from_numpy(data[0])
y = torch.arange(len(X)) % 10
dataset = torch.utils.data.TensorDataset(X, y)
net = net_cls(
module_cls, max_epochs=0, train_split=ValidSplit(3, stratified=False)
).fit(dataset, None)
assert (net.classes_ == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).all()

@pytest.mark.parametrize('classes', [[], np.array([])])
def test_pass_empty_classes_raises(
self, net_cls, module_cls, data, classes):
Expand Down
23 changes: 23 additions & 0 deletions skorch/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,11 @@ def data(self):
y = np.array([1, 3, 0, 2])
return X, y

@pytest.fixture
def tensors(self, data):
X, y = data
return torch.from_numpy(X), torch.from_numpy(y)

@pytest.fixture
def skorch_ds(self, data):
from skorch.dataset import Dataset
Expand Down Expand Up @@ -501,6 +506,24 @@ def test_subset_with_y_none(self, data_from_dataset, data, subset):
assert (X == data[0][[1, 3]]).all()
assert y is None

def test_with_tensordataset_2_vals(self, data_from_dataset, tensors):
dataset = torch.utils.data.dataset.TensorDataset(*tensors)
X, y = data_from_dataset(dataset)
assert (X == tensors[0]).all()
assert (y == tensors[1]).all()

def test_with_tensordataset_1_val_raises(self, data_from_dataset, tensors):
dataset = torch.utils.data.dataset.TensorDataset(tensors[0])
msg = "Could not access X and y from dataset."
with pytest.raises(AttributeError, match=msg):
data_from_dataset(dataset)

def test_with_tensordataset_3_vals_raises(self, data_from_dataset, tensors):
dataset = torch.utils.data.dataset.TensorDataset(*tensors, tensors[0])
msg = "Could not access X and y from dataset."
with pytest.raises(AttributeError, match=msg):
data_from_dataset(dataset)


class TestMultiIndexing:
@pytest.fixture
Expand Down
8 changes: 8 additions & 0 deletions skorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,11 @@ def data_from_dataset(dataset, X_indexing=None, y_indexing=None):
If not None, use this function for indexing into the y data. If
None, try to automatically determine how to index data.

Raises
------
AttributeError
If X and y could not be accessed from the dataset.

"""
X, y = _none, _none

Expand All @@ -450,6 +455,9 @@ def data_from_dataset(dataset, X_indexing=None, y_indexing=None):
y = multi_indexing(y, dataset.indices, indexing=y_indexing)
elif hasattr(dataset, 'X') and hasattr(dataset, 'y'):
X, y = dataset.X, dataset.y
elif isinstance(dataset, torch.utils.data.dataset.TensorDataset):
if len(items := dataset.tensors) == 2:
X, y = items

if (X is _none) or (y is _none):
raise AttributeError("Could not access X and y from dataset.")
Expand Down