Skip to content

Commit

Permalink
TST add tests for nd grid-search can train_test_split
Browse files Browse the repository at this point in the history
  • Loading branch information
amueller authored and jnothman committed Dec 25, 2014
1 parent 56ca8f9 commit 3d1ede9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
31 changes: 23 additions & 8 deletions sklearn/tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,22 +652,37 @@ def test_train_test_split_errors():
def test_train_test_split():
X = np.arange(100).reshape((10, 10))
X_s = coo_matrix(X)
y = range(10)
split = cval.train_test_split(X, X_s, y, allow_lists=False)
X_train, X_test, X_s_train, X_s_test, y_train, y_test = split
assert_array_equal(X_train, X_s_train.toarray())
assert_array_equal(X_test, X_s_test.toarray())
assert_array_equal(X_train[:, 0], y_train * 10)
assert_array_equal(X_test[:, 0], y_test * 10)
y = np.arange(10)

# simple test
split = cval.train_test_split(X, y, test_size=None, train_size=.5)
X_train, X_test, y_train, y_test = split
assert_equal(len(y_test), len(y_train))
# test correspondence of X and y
assert_array_equal(X_train[:, 0], y_train * 10)
assert_array_equal(X_test[:, 0], y_test * 10)

# conversion of lists to arrays (deprecated?)
split = cval.train_test_split(X, X_s, y.tolist(), force_arrays=True)
X_train, X_test, X_s_train, X_s_test, y_train, y_test = split
assert_array_equal(X_train, X_s_train.toarray())
assert_array_equal(X_test, X_s_test.toarray())

split = cval.train_test_split(X, X_s, y)
# don't convert lists to anything else by default
split = cval.train_test_split(X, X_s, y.tolist())
X_train, X_test, X_s_train, X_s_test, y_train, y_test = split
assert_true(isinstance(y_train, list))
assert_true(isinstance(y_test, list))

# allow nd-arrays
X_4d = np.arange(10 * 5 * 3 * 2).reshape(10, 5, 3, 2)
y_3d = np.arange(10 * 7 * 11).reshape(10, 7, 11)
split = cval.train_test_split(X_4d, y_3d)
assert_equal(split[0].shape, (7, 5, 3, 2))
assert_equal(split[1].shape, (3, 5, 3, 2))
assert_equal(split[2].shape, (7, 7, 11))
assert_equal(split[3].shape, (3, 7, 11))


def train_test_split_pandas():
# check cross_val_score doesn't destroy pandas dataframe
Expand Down
12 changes: 12 additions & 0 deletions sklearn/tests/test_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,18 @@ def test_refit():
clf.fit(X, y)


def test_gridsearch_nd():
"""Pass X as list in GridSearchCV"""
X_4d = np.arange(10 * 5 * 3 * 2).reshape(10, 5, 3, 2)
y_3d = np.arange(10 * 7 * 11).reshape(10, 7, 11)
check_X = lambda x: x.shape[1:] == (5, 3, 2)
check_y = lambda x: x.shape[1:] == (7, 11)
clf = CheckingClassifier(check_X=check_X, check_y=check_y)
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]})
grid_search.fit(X_4d, y_3d).score(X, y)
assert_true(hasattr(grid_search, "grid_scores_"))


def test_X_as_list():
"""Pass X as list in GridSearchCV"""
X = np.arange(100).reshape(10, 10)
Expand Down

0 comments on commit 3d1ede9

Please sign in to comment.