Skip to content

Commit

Permalink
Merge pull request scikit-learn#7216 from b0noI/master
Browse files Browse the repository at this point in the history
[MRG+1] New text for the ValueError that is thrown by _check_param_grid method
  • Loading branch information
ogrisel authored Aug 31, 2016
2 parents c931bf0 + f906b9f commit db56cf4
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
9 changes: 5 additions & 4 deletions sklearn/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,17 +326,18 @@ def _check_param_grid(param_grid):
param_grid = [param_grid]

for p in param_grid:
for v in p.values():
for name, v in p.items():
if isinstance(v, np.ndarray) and v.ndim > 1:
raise ValueError("Parameter array should be one-dimensional.")

check = [isinstance(v, k) for k in (list, tuple, np.ndarray)]
if True not in check:
raise ValueError("Parameter values should be a list.")
raise ValueError("Parameter values for parameter ({0}) need "
"to be a sequence.".format(name))

if len(v) == 0:
raise ValueError("Parameter values should be a non-empty "
"list.")
raise ValueError("Parameter values for parameter ({0}) need "
"to be a non-empty sequence.".format(name))


class _CVScoreTuple (namedtuple('_CVScoreTuple',
Expand Down
9 changes: 5 additions & 4 deletions sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,17 +328,18 @@ def _check_param_grid(param_grid):
param_grid = [param_grid]

for p in param_grid:
for v in p.values():
for name, v in p.items():
if isinstance(v, np.ndarray) and v.ndim > 1:
raise ValueError("Parameter array should be one-dimensional.")

check = [isinstance(v, k) for k in (list, tuple, np.ndarray)]
if True not in check:
raise ValueError("Parameter values should be a list.")
raise ValueError("Parameter values for parameter ({0}) need "
"to be a sequence.".format(name))

if len(v) == 0:
raise ValueError("Parameter values should be a non-empty "
"list.")
raise ValueError("Parameter values for parameter ({0}) need "
"to be a non-empty sequence.".format(name))


# XXX Remove in 0.20
Expand Down
16 changes: 16 additions & 0 deletions sklearn/model_selection/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,22 @@ def test_grid_search():
assert_raises(ValueError, grid_search.fit, X, y)


def test_grid_search_incorrect_param_grid():
clf = MockClassifier()
assert_raise_message(
ValueError,
"Parameter values for parameter (C) need to be a sequence.",
GridSearchCV, clf, {'C': 1})


def test_grid_search_param_grid_includes_sequence_of_a_zero_length():
clf = MockClassifier()
assert_raise_message(
ValueError,
"Parameter values for parameter (C) need to be a non-empty sequence.",
GridSearchCV, clf, {'C': []})


@ignore_warnings
def test_grid_search_no_score():
# Test grid-search on classifier that has no score function.
Expand Down

0 comments on commit db56cf4

Please sign in to comment.