From be26611223f7546946b30e160d83fd34d61b1882 Mon Sep 17 00:00:00 2001 From: masui Date: Fri, 7 Feb 2020 15:03:32 +0900 Subject: [PATCH] #84 fix dct implementation --- spmimage/decomposition/dct.py | 97 +++++++++++++++++++++++++++++------ 1 file changed, 80 insertions(+), 17 deletions(-) diff --git a/spmimage/decomposition/dct.py b/spmimage/decomposition/dct.py index e8b8ab7..9ddd19f 100644 --- a/spmimage/decomposition/dct.py +++ b/spmimage/decomposition/dct.py @@ -1,32 +1,95 @@ +import itertools import numpy as np -def generate_dct_dictionary(patch_size: int, sqrt_dict_size: int) -> np.ndarray: +def zig_zag_index(k: int, n: int): + """ + get k-th index i and j on (n, n)-matrix according to zig-zag scan. + + Parameters: + ----------- + k : int + a ranking of element, which we want to know the index i and j + + n : int + a size of square matrix + + Returns: + ----------- + (i, j) : Tuple[int, int] + the tuple which represents the height and width index of k-th elements + + Reference + ---------- + https://medium.com/100-days-of-algorithms/day-63-zig-zag-51a41127f31 + """ + # upper side of interval + if k >= n * (n + 1) // 2: + i, j = zig_zag_index(n * n - 1 - k, n) + return n - 1 - i, n - 1 - j + + # lower side of interval + i = int((np.sqrt(1 + 8 * k) - 1) / 2) + j = k - i * (i + 1) // 2 + return (j, i - j) if i & 1 else (i - j, j) + +def generate_dct_atom(u, v, n): + """ + generate an (u, v)-th atom of DCT dictionary with size n by n. + + Parameters: + ----------- + u : int + an index for height + + v : int + an index for width + + n : int + a size of DCT + + Returns: + ----------- + atom : np.ndarray + (n, n) matrix which represents (u,v)-th atom of DCT dictionary + """ + atom = np.empty((n, n)) + for i, j in itertools.product(range(n), range(n)): + atom[i, j] = np.cos(((i+0.5)*u*np.pi)/n) * np.cos(((j+0.5)*v*np.pi)/n) + return atom + +def generate_dct_dictionary(n_components: int, patch_size: int) -> np.ndarray: """generate_dct_dictionary Generate a DCT dictionary. - An atom is a (patch_size, patch_size) image, and total number of atoms is - sqrt_dict_size * sqrt_dict_size. - The result D is a matrix whose shape is (sqrt_dict_size^2, patch_size^2). + An atom is a (patch_size, patch_size) image, and total number of atoms is + n_components. + + The result D is a matrix whose shape is (n_components, patch_size ** 2). Note that, a row of the result D shows an atom (flatten). Parameters: ------------ - patch_size : int - height and width of an atom of DCT dictionary + n_components: int + a number of atom, where n_components <= patch_size ** 2. - sqrt_dict_size : int - Total number of DCT atoms is a square number. - This parameter fix the number of atoms in the Dictionary. + patch_size : int + size of atom of DCT dictionary Returns: ------------ - D : np.ndarray, shape (sqrt_dict_size^2, patch_size^2) + D : np.ndarray, shape (n_components, patch_size ** 2) DCT dictionary """ - D1 = np.zeros((sqrt_dict_size, patch_size)) - for k in np.arange(sqrt_dict_size): - for i in np.arange(patch_size): - D1[k, i] = np.cos(i * k * np.pi / float(sqrt_dict_size)) - if k != 0: - D1[k, :] -= D1[k, :].mean() - return np.kron(D1, D1) + D = np.empty((n_components, patch_size ** 2)) + + if n_components > patch_size ** 2: + raise ValueError("n_components must be smaller than patch_size ** 2") + + elif n_components == patch_size ** 2: + for i, j in itertools.product(range(patch_size), range(patch_size)): + D[i*patch_size + j] = generate_dct_atom(i, j, patch_size).flatten() + else: + for k in range(n_components): + i, j = zig_zag_index(k, patch_size) + D[k, :] = generate_dct_atom(i, j, patch_size).flatten() + return D