Skip to content

Commit

Permalink
FIX added assertion for ValueError when cv iterator is empty (scikit-…
Browse files Browse the repository at this point in the history
…learn#12961)

* added assertion for ValueError when cv iterator is empty

* test value error message

* change to use match= instead of message= for pytest.raises

* fix regex escape char for flake8

* Added test case for unmached cv result length

* fix travis-ci issue

* use global X, y and n_split=3

* use true global X, y
  • Loading branch information
esvhd authored and adrinjalali committed Feb 5, 2019
1 parent c2d56a6 commit 851a4b8
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
11 changes: 11 additions & 0 deletions sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,17 @@ def evaluate_candidates(candidate_params):
in product(candidate_params,
cv.split(X, y, groups)))

if len(out) < 1:
raise ValueError('No fits were performed. '
'Was the CV iterator empty? '
'Were there no candidates?')
elif len(out) != n_candidates * n_splits:
raise ValueError('cv.split and cv.get_n_splits returned '
'inconsistent results. Expected {} '
'splits, got {}'
.format(n_splits,
len(out) // n_candidates))

all_candidate_params.extend(candidate_params)
all_out.extend(out)

Expand Down
44 changes: 44 additions & 0 deletions sklearn/model_selection/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1725,3 +1725,47 @@ def test_deprecated_grid_search_iid():
grid = GridSearchCV(SVC(gamma='scale'), param_grid={'C': [1]}, cv=KFold(2))
# no warning because no stratification and 54 % 2 == 0
assert_no_warnings(grid.fit, X, y)


def test_empty_cv_iterator_error():
# Use global X, y

# create cv
cv = KFold(n_splits=3).split(X)

# pop all of it, this should cause the expected ValueError
[u for u in cv]
# cv is empty now

train_size = 100
ridge = RandomizedSearchCV(Ridge(), {'alpha': [1e-3, 1e-2, 1e-1]},
cv=cv, n_jobs=-1)

# assert that this raises an error
with pytest.raises(ValueError,
match='No fits were performed. '
'Was the CV iterator empty\\? '
'Were there no candidates\\?'):
ridge.fit(X[:train_size], y[:train_size])


def test_random_search_bad_cv():
# Use global X, y

class BrokenKFold(KFold):
def get_n_splits(self, *args, **kw):
return 1

# create bad cv
cv = BrokenKFold(n_splits=3)

train_size = 100
ridge = RandomizedSearchCV(Ridge(), {'alpha': [1e-3, 1e-2, 1e-1]},
cv=cv, n_jobs=-1)

# assert that this raises an error
with pytest.raises(ValueError,
match='cv.split and cv.get_n_splits returned '
'inconsistent results. Expected \\d+ '
'splits, got \\d+'):
ridge.fit(X[:train_size], y[:train_size])

0 comments on commit 851a4b8

Please sign in to comment.