Skip to content

Commit

Permalink
Merge pull request #82 from hacarus/issue/80/init_dict_ksvd
Browse files Browse the repository at this point in the history
Issue/80/init dict ksvd
  • Loading branch information
y-iwao authored Feb 6, 2020
2 parents 2dd6a4a + 79bb584 commit dad0665
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 83 deletions.
139 changes: 62 additions & 77 deletions examples/ksvd_inpainting.ipynb

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion spmimage/decomposition/dict_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@ def sparse_encode_with_mask(X, dictionary, mask, **kwargs):
X : array-like, shape (n_samples, n_features)
Training vector, where n_samples is the number of samples
and n_features is the number of features.
dictionary : array of shape (n_components, n_features),
The dictionary factor
mask : array-like, shape (n_samples, n_features),
a value other than 1 at (i,j) in mask indicates that the value at (i,j) in X is missing
verbose : bool
Degree of output the procedure will print.
**kwargs :
**kwargs :
algorithm : {'lasso_lars', 'lasso_cd', 'lars', 'omp', 'threshold'}
lars: uses the least angle regression method (linear_model.lars_path)
lasso_lars: uses Lars to compute the Lasso solution
Expand Down
45 changes: 40 additions & 5 deletions spmimage/decomposition/ksvd.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,38 @@ def _ksvd(Y: np.ndarray, n_components: int, n_nonzero_coefs: int, max_iter: int,
Y : array-like, shape (n_samples, n_features)
Training vector, where n_samples is the number of samples
and n_features is the number of features.
n_components : int,
number of dictionary elements to extract
n_nonzero_coefs : int,
number of non-zero elements of sparse coding
max_iter : int,
maximum number of iterations to perform
tol : float,
tolerance for numerical error
dict_init : array of shape (n_components, n_features),
initial values for the dictionary, for warm restart
mask : array-like, shape (n_samples, n_features),
a value other than 1 at (i,j) in mask indicates that the value at (i,j) in Y is missing
n_jobs : int, optional
Number of parallel jobs to run.
Returns:
---------
code : array of shape (n_samples, n_components)
The sparse code factor in the matrix factorization.
dictionary : array of shape (n_components, n_features),
The dictionary factor in the matrix factorization.
errors : array
Vector of errors at each iteration.
n_iter : int
Number of iterations run. Returned only if `return_n_iter` is
set to True.
Expand Down Expand Up @@ -110,12 +120,16 @@ class KSVD(BaseEstimator, SparseCodingMixin):
----------
n_components : int,
number of dictionary elements to extract
max_iter : int,
maximum number of iterations to perform
tol : float,
tolerance for numerical error
missing_value : float,
missing value in the data
transform_algorithm : {'lasso_lars', 'lasso_cd', 'lars', 'omp', 'threshold'}
Algorithm used to transform the data
lars: uses the least angle regression method (linear_model.lars_path)
Expand All @@ -128,10 +142,12 @@ class KSVD(BaseEstimator, SparseCodingMixin):
the projection ``dictionary * X'``
.. versionadded:: 0.17
*lasso_cd* coordinate descent method to improve speed.
transform_n_nonzero_coefs : int, ``0.1 * n_features`` by default
Number of nonzero coefficients to target in each column of the
solution. This is only used by `algorithm='lars'` and `algorithm='omp'`
and is overridden by `alpha` in the `omp` case.
transform_alpha : float, 1. by default
If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the
penalty applied to the L1 norm.
Expand All @@ -140,27 +156,34 @@ class KSVD(BaseEstimator, SparseCodingMixin):
If `algorithm='omp'`, `alpha` is the tolerance parameter: the value of
the reconstruction error targeted. In this case, it overrides
`n_nonzero_coefs`.
n_jobs : int,
number of parallel jobs to run
split_sign : bool, False by default
Whether to split the sparse feature vector into the concatenation of
its negative part and its positive part. This can improve the
performance of downstream classifiers.
random_state : int, RandomState instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
If RandomState instance, random_state is the random number generator;
If None, the random number generator is the RandomState instance used
by `np.random`.
method : {'approximate': Approximate KSVD, 'normal': normal KSVD}, 'approximate' by default
Attributes
----------
components_ : array, [n_components, n_features]
dictionary atoms extracted from the data
error_ : array
vector of errors at each iteration
n_iter_ : int
Number of iterations run.
**References:**
Elad, Michael, and Michal Aharon.
"Image denoising via sparse and redundant representations over learned dictionaries."
Expand All @@ -173,7 +196,7 @@ def __init__(self, n_components=None, max_iter=1000, tol=1e-8,
missing_value=None, transform_algorithm='omp',
transform_n_nonzero_coefs=None,
transform_alpha=None, n_jobs=1,
split_sign=False, random_state=None, method='approximate'):
split_sign=False, random_state=None, method='approximate', dict_init=None):
self._set_sparse_coding_params(n_components, transform_algorithm,
transform_n_nonzero_coefs,
transform_alpha, split_sign, n_jobs)
Expand All @@ -182,7 +205,7 @@ def __init__(self, n_components=None, max_iter=1000, tol=1e-8,
self.missing_value = missing_value
self.random_state = random_state
self.method = method
self.components_ = None
self.components_ = dict_init

def fit(self, X, y=None):
"""Fit the model from data in X.
Expand All @@ -191,7 +214,9 @@ def fit(self, X, y=None):
X : array-like, shape (n_samples, n_features)
Training vector, where n_samples is the number of samples
and n_features is the number of features.
y : Ignored
Returns
-------
self : object
Expand All @@ -213,9 +238,17 @@ def fit(self, X, y=None):

# initialize dictionary
dict_init = None
if self.components_ is not None and self.components_.shape == (n_components, n_features):
# Warm Start
dict_init = self.components_
if self.components_ is not None:
if self.components_.shape[1] != n_features:
raise ValueError("Found input variables with inconsistent numbers of n_features")
elif self.components_.shape[0] != n_components:
raise ValueError("Found input variables with inconsistent numbers of n_components")
else:
# Warm Start
logger.info("KSVD fit - warm start")
dict_init = self.components_
else:
logger.info("KSVD fit - cold start")

code, self.components_, self.error_, self.n_iter_ = _ksvd(
X, n_components, self.transform_n_nonzero_coefs,
Expand All @@ -228,11 +261,13 @@ def transform(self, X):
"""Encode the data as a sparse combination of the dictionary atoms.
Coding method is determined by the object parameter
`transform_algorithm`.
Parameters
----------
X : array of shape (n_samples, n_features)
Test data to be transformed, must have the same number of
features as the data used to train the model.
Returns
-------
code : array, shape (n_samples, n_components)
Expand Down
17 changes: 17 additions & 0 deletions tests/test_decomposition_ksvd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from spmimage.decomposition import KSVD

import numpy as np
import numpy.testing as npt

from tests.utils import generate_dictionary_and_samples

Expand Down Expand Up @@ -73,6 +74,22 @@ def test_ksvd_warm_start(self):
self.assertTrue(model.error_[-1] <= prev_error)
prev_error = model.error_[-1]

def test_ksvd_dict_init(self):
    """Verify that ``dict_init`` seeds ``components_`` and that ``fit`` rejects shape mismatches."""
    initial_dictionary = np.random.rand(10, 100)
    estimator = KSVD(
        n_components=10,
        transform_n_nonzero_coefs=5,
        max_iter=1,
        method='normal',
        dict_init=initial_dictionary,
    )
    # The constructor exposes the supplied dictionary as components_ unchanged.
    npt.assert_array_equal(estimator.components_, initial_dictionary)

    # Samples with 200 features cannot be fitted against a dictionary with 100 features.
    wide_samples = np.random.rand(20, 200)
    self.assertRaises(ValueError, estimator.fit, wide_samples)

    # The feature count matches here, but n_components=20 disagrees with the 10 dictionary rows.
    samples = np.random.rand(20, 100)
    estimator = KSVD(
        n_components=20,
        transform_n_nonzero_coefs=5,
        max_iter=1,
        method='normal',
        dict_init=initial_dictionary,
    )
    self.assertRaises(ValueError, estimator.fit, samples)

def test_approximate_ksvd(self):
n_nonzero_coefs = 5
n_samples = 128
Expand Down

0 comments on commit dad0665

Please sign in to comment.