Skip to content

Commit

Permalink
#84 fix dct implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
masui committed Feb 7, 2020
1 parent dad0665 commit be26611
Showing 1 changed file with 80 additions and 17 deletions.
97 changes: 80 additions & 17 deletions spmimage/decomposition/dct.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,95 @@
import itertools
import numpy as np


def generate_dct_dictionary(patch_size: int, sqrt_dict_size: int) -> np.ndarray:
def zig_zag_index(k: int, n: int):
"""
get k-th index i and j on (n, n)-matrix according to zig-zag scan.
Parameters:
-----------
k : int
a ranking of element, which we want to know the index i and j
n : int
a size of square matrix
Returns:
-----------
(i, j) : Tuple[int, int]
the tuple which represents the height and width index of k-th elements
Reference
----------
https://medium.com/100-days-of-algorithms/day-63-zig-zag-51a41127f31
"""
# upper side of interval
if k >= n * (n + 1) // 2:
i, j = zig_zag_index(n * n - 1 - k, n)
return n - 1 - i, n - 1 - j

# lower side of interval
i = int((np.sqrt(1 + 8 * k) - 1) / 2)
j = k - i * (i + 1) // 2
return (j, i - j) if i & 1 else (i - j, j)

def generate_dct_atom(u, v, n):
"""
generate an (u, v)-th atom of DCT dictionary with size n by n.
Parameters:
-----------
u : int
an index for height
v : int
an index for width
n : int
a size of DCT
Returns:
-----------
atom : np.ndarray
(n, n) matrix which represents (u,v)-th atom of DCT dictionary
"""
atom = np.empty((n, n))
for i, j in itertools.product(range(n), range(n)):
atom[i, j] = np.cos(((i+0.5)*u*np.pi)/n) * np.cos(((j+0.5)*v*np.pi)/n)
return atom

def generate_dct_dictionary(n_components: int, patch_size: int) -> np.ndarray:
"""generate_dct_dictionary
Generate a DCT dictionary.
An atom is a (patch_size, patch_size) image, and total number of atoms is
sqrt_dict_size * sqrt_dict_size.
The result D is a matrix whose shape is (sqrt_dict_size^2, patch_size^2).
An atom is a (patch_size, patch_size) image, and total number of atoms is
n_components.
The result D is a matrix whose shape is (n_components, patch_size ** 2).
Note that, a row of the result D shows an atom (flatten).
Parameters:
------------
patch_size : int
height and width of an atom of DCT dictionary
n_components: int
a number of atom, where n_components <= patch_size ** 2.
sqrt_dict_size : int
Total number of DCT atoms is a square number.
This parameter fix the number of atoms in the Dictionary.
patch_size : int
size of atom of DCT dictionary
Returns:
------------
D : np.ndarray, shape (sqrt_dict_size^2, patch_size^2)
D : np.ndarray, shape (n_components, patch_size ** 2)
DCT dictionary
"""
D1 = np.zeros((sqrt_dict_size, patch_size))
for k in np.arange(sqrt_dict_size):
for i in np.arange(patch_size):
D1[k, i] = np.cos(i * k * np.pi / float(sqrt_dict_size))
if k != 0:
D1[k, :] -= D1[k, :].mean()
return np.kron(D1, D1)
D = np.empty((n_components, patch_size ** 2))

if n_components > patch_size ** 2:
raise ValueError("n_components must be smaller than patch_size ** 2")

elif n_components == patch_size ** 2:
for i, j in itertools.product(range(patch_size), range(patch_size)):
D[i*patch_size + j] = generate_dct_atom(i, j, patch_size).flatten()
else:
for k in range(n_components):
i, j = zig_zag_index(k, patch_size)
D[k, :] = generate_dct_atom(i, j, patch_size).flatten()
return D

0 comments on commit be26611

Please sign in to comment.