Skip to content

Commit

Permalink
[MRG] Fixed NMF IndexError (scikit-learn#11667)
Browse files Browse the repository at this point in the history
  • Loading branch information
zjpoh authored and jnothman committed Feb 12, 2019
1 parent 5486fd5 commit 42073c2
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 14 deletions.
10 changes: 10 additions & 0 deletions doc/whats_new/v0.21.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ Support for Python 3.4 and below has been officially dropped.
the default value is used.
:issue:`12988` by :user:`Zijie (ZJ) Poh <zjpoh>`.

:mod:`sklearn.decomposition`
............................

- |Fix| Fixed a bug in :class:`decomposition.NMF` where `init = 'nndsvd'`,
`init = 'nndsvda'`, and `init = 'nndsvdar'` were allowed whenever
`n_components < n_features`, instead of only when
`n_components <= min(n_samples, n_features)`.
:issue:`11650` by :user:`Hossein Pourbozorg <hossein-pourbozorg>` and
:user:`Zijie (ZJ) Poh <zjpoh>`.

:mod:`sklearn.discriminant_analysis`
....................................

Expand Down
14 changes: 11 additions & 3 deletions sklearn/decomposition/nmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ def _initialize_nmf(X, n_components, init=None, eps=1e-6,
Default: None.
Valid options:
- None: 'nndsvd' if n_components < n_features, otherwise 'random'.
- None: 'nndsvd' if n_components <= min(n_samples, n_features),
otherwise 'random'.
- 'random': non-negative random matrices, scaled with:
sqrt(X.mean() / n_components)
Expand Down Expand Up @@ -304,8 +305,14 @@ def _initialize_nmf(X, n_components, init=None, eps=1e-6,
check_non_negative(X, "NMF initialization")
n_samples, n_features = X.shape

if (init is not None and init != 'random'
and n_components > min(n_samples, n_features)):
raise ValueError("init = '{}' can only be used when "
"n_components <= min(n_samples, n_features)"
.format(init))

if init is None:
if n_components < n_features:
if n_components <= min(n_samples, n_features):
init = 'nndsvd'
else:
init = 'random'
Expand Down Expand Up @@ -1104,7 +1111,8 @@ class NMF(BaseEstimator, TransformerMixin):
Default: None.
Valid options:
- None: 'nndsvd' if n_components < n_features, otherwise random.
- None: 'nndsvd' if n_components <= min(n_samples, n_features),
otherwise 'random'.
- 'random': non-negative random matrices, scaled with:
sqrt(X.mean() / n_components)
Expand Down
34 changes: 23 additions & 11 deletions sklearn/decomposition/tests/test_nmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ def test_parameter_checking():
clf = NMF(2, tol=0.1).fit(A)
assert_raise_message(ValueError, msg, clf.transform, -A)

for init in ['nndsvd', 'nndsvda', 'nndsvdar']:
msg = ("init = '{}' can only be used when "
"n_components <= min(n_samples, n_features)"
.format(init))
assert_raise_message(ValueError, msg, NMF(3, init).fit, A)
assert_raise_message(ValueError, msg, nmf._initialize_nmf, A,
3, init)


def test_initialize_close():
# Test NNDSVD error
Expand Down Expand Up @@ -197,17 +205,21 @@ def test_non_negative_factorization_consistency():
A = np.abs(rng.randn(10, 10))
A[:, 2 * np.arange(5)] = 0

for solver in ('cd', 'mu'):
W_nmf, H, _ = non_negative_factorization(
A, solver=solver, random_state=1, tol=1e-2)
W_nmf_2, _, _ = non_negative_factorization(
A, H=H, update_H=False, solver=solver, random_state=1, tol=1e-2)

model_class = NMF(solver=solver, random_state=1, tol=1e-2)
W_cls = model_class.fit_transform(A)
W_cls_2 = model_class.transform(A)
assert_array_almost_equal(W_nmf, W_cls, decimal=10)
assert_array_almost_equal(W_nmf_2, W_cls_2, decimal=10)
for init in ['random', 'nndsvd']:
for solver in ('cd', 'mu'):
W_nmf, H, _ = non_negative_factorization(
A, init=init, solver=solver, random_state=1, tol=1e-2)
W_nmf_2, _, _ = non_negative_factorization(
A, H=H, update_H=False, init=init, solver=solver,
random_state=1, tol=1e-2)

model_class = NMF(init=init, solver=solver, random_state=1,
tol=1e-2)
W_cls = model_class.fit_transform(A)
W_cls_2 = model_class.transform(A)

assert_array_almost_equal(W_nmf, W_cls, decimal=10)
assert_array_almost_equal(W_nmf_2, W_cls_2, decimal=10)


def test_non_negative_factorization_checking():
Expand Down

0 comments on commit 42073c2

Please sign in to comment.