-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
masui
committed
Feb 7, 2020
1 parent
dad0665
commit be26611
Showing
1 changed file
with
80 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,95 @@ | ||
import itertools | ||
import numpy as np | ||
|
||
|
||
def generate_dct_dictionary(patch_size: int, sqrt_dict_size: int) -> np.ndarray: | ||
def zig_zag_index(k: int, n: int): | ||
""" | ||
get k-th index i and j on (n, n)-matrix according to zig-zag scan. | ||
Parameters: | ||
----------- | ||
k : int | ||
a ranking of element, which we want to know the index i and j | ||
n : int | ||
a size of square matrix | ||
Returns: | ||
----------- | ||
(i, j) : Tuple[int, int] | ||
the tuple which represents the height and width index of k-th elements | ||
Reference | ||
---------- | ||
https://medium.com/100-days-of-algorithms/day-63-zig-zag-51a41127f31 | ||
""" | ||
# upper side of interval | ||
if k >= n * (n + 1) // 2: | ||
i, j = zig_zag_index(n * n - 1 - k, n) | ||
return n - 1 - i, n - 1 - j | ||
|
||
# lower side of interval | ||
i = int((np.sqrt(1 + 8 * k) - 1) / 2) | ||
j = k - i * (i + 1) // 2 | ||
return (j, i - j) if i & 1 else (i - j, j) | ||
|
||
def generate_dct_atom(u, v, n): | ||
""" | ||
generate an (u, v)-th atom of DCT dictionary with size n by n. | ||
Parameters: | ||
----------- | ||
u : int | ||
an index for height | ||
v : int | ||
an index for width | ||
n : int | ||
a size of DCT | ||
Returns: | ||
----------- | ||
atom : np.ndarray | ||
(n, n) matrix which represents (u,v)-th atom of DCT dictionary | ||
""" | ||
atom = np.empty((n, n)) | ||
for i, j in itertools.product(range(n), range(n)): | ||
atom[i, j] = np.cos(((i+0.5)*u*np.pi)/n) * np.cos(((j+0.5)*v*np.pi)/n) | ||
return atom | ||
|
||
def generate_dct_dictionary(n_components: int, patch_size: int) -> np.ndarray: | ||
"""generate_dct_dictionary | ||
Generate a DCT dictionary. | ||
An atom is a (patch_size, patch_size) image, and total number of atoms is | ||
sqrt_dict_size * sqrt_dict_size. | ||
The result D is a matrix whose shape is (sqrt_dict_size^2, patch_size^2). | ||
An atom is a (patch_size, patch_size) image, and total number of atoms is | ||
n_components. | ||
The result D is a matrix whose shape is (n_components, patch_size ** 2). | ||
Note that, a row of the result D shows an atom (flatten). | ||
Parameters: | ||
------------ | ||
patch_size : int | ||
height and width of an atom of DCT dictionary | ||
n_components: int | ||
a number of atom, where n_components <= patch_size ** 2. | ||
sqrt_dict_size : int | ||
Total number of DCT atoms is a square number. | ||
This parameter fix the number of atoms in the Dictionary. | ||
patch_size : int | ||
size of atom of DCT dictionary | ||
Returns: | ||
------------ | ||
D : np.ndarray, shape (sqrt_dict_size^2, patch_size^2) | ||
D : np.ndarray, shape (n_components, patch_size ** 2) | ||
DCT dictionary | ||
""" | ||
D1 = np.zeros((sqrt_dict_size, patch_size)) | ||
for k in np.arange(sqrt_dict_size): | ||
for i in np.arange(patch_size): | ||
D1[k, i] = np.cos(i * k * np.pi / float(sqrt_dict_size)) | ||
if k != 0: | ||
D1[k, :] -= D1[k, :].mean() | ||
return np.kron(D1, D1) | ||
D = np.empty((n_components, patch_size ** 2)) | ||
|
||
if n_components > patch_size ** 2: | ||
raise ValueError("n_components must be smaller than patch_size ** 2") | ||
|
||
elif n_components == patch_size ** 2: | ||
for i, j in itertools.product(range(patch_size), range(patch_size)): | ||
D[i*patch_size + j] = generate_dct_atom(i, j, patch_size).flatten() | ||
else: | ||
for k in range(n_components): | ||
i, j = zig_zag_index(k, patch_size) | ||
D[k, :] = generate_dct_atom(i, j, patch_size).flatten() | ||
return D |