Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wavelet transform fixes #259

Merged
merged 11 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions clouddrift/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,15 @@ def morse_wavelet_transform(
If the input signal is real as specificied by ``complex=False``:

wtx : np.ndarray
Time-domain wavelet transform of input ``x``. The axes of ``wtx`` will be organized as (x axes), orders, frequencies, time
Time-domain wavelet transform of input ``x``. The axes of ``wtx`` will be organized as (x axes), time, frequencies, orders
unless ``time_axis`` is different from last (-1) in which case it will be moved back to its original position within the axes of ``x``.

If the input signal is complex as specificied by ``complex=True``:
If the input signal is complex as specificied by ``complex=True``, a tuple is returned:

wtx_p: np.array
Time-domain positive wavelet transform of input ``x``.
Time-domain positive wavelet transform of input ``x``, with axes organized as in the ``complex=False`` case.
wtx_n: np.array
Time-domain negative wavelet transform of input ``x``.
Time-domain negative wavelet transform of input ``x``, with axes organized as in the ``complex=False`` case.

Examples
--------
Expand Down Expand Up @@ -181,10 +181,15 @@ def morse_wavelet_transform(
)
wtx = wtx_p, wtx_n

else:
elif ~complex:
# real case
wtx = wavelet_transform(x, wavelet, boundary=boundary, time_axis=time_axis)

else:
raise ValueError(
"`complex` optional argument must be boolean 'True' or 'False'"
)

return wtx


Expand Down Expand Up @@ -224,7 +229,7 @@ def wavelet_transform(
Returns
-------
wtx : np.ndarray
Time-domain wavelet transform of input ``x``. The axes of ``wtx`` will be organized as (x axes), orders, frequencies, time
Time-domain wavelet transform of input ``x``. The axes of ``wtx`` will be organized as (x axes), time, frequencies, orders
unless ``time_axis`` is different from last (-1) in which case it will be moved back to its original position within the axes of ``x``.

Examples
Expand Down Expand Up @@ -261,7 +266,7 @@ def wavelet_transform(
)
# Positions and time arrays must have the same shape.
if x.shape[time_axis] != wavelet.shape[-1]:
raise ValueError("x and wave time axes must have the same length.")
raise ValueError("x and wavelet time axes must have the same length.")

wavelet_ = np.moveaxis(wavelet, [freq_axis, order_axis], [-2, -3])

Expand Down Expand Up @@ -319,10 +324,19 @@ def wavelet_transform(
complex_dtype = np.cdouble if x.dtype == np.single else np.csingle
wtx = np.fft.ifft(X_ * np.conj(_wavelet_fft)).astype(complex_dtype)
wtx = wtx[..., index]
# remove extra dimensions

# reorder as ((shape of x),length, freq_axis, order_axis) = ((shape of x),-3,-2,-1)
wtx = np.moveaxis(wtx, [-1], [-3]) # move length to -3
wtx = np.moveaxis(wtx, [-2], [-1]) # move order to -1

# reposition the time axis if needed from axis -3
if time_axis != -1:
wtx = np.moveaxis(wtx, -3, time_axis)
else:
pass
selipot marked this conversation as resolved.
Show resolved Hide resolved

# remove extra dimensions if needed
wtx = np.squeeze(wtx)
# reposition the time axis: should I add a condition to do so only if time_axis!=-1? works anyway
wtx = np.moveaxis(wtx, -1, time_axis)

return wtx

Expand Down
24 changes: 19 additions & 5 deletions tests/wavelet_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,23 @@ def test_wavelet_transform_size(self):
gamma = 3
beta = 4
x = np.random.random((m, m * 2, length))
wave, _ = morse_wavelet(length, gamma, beta, radian_frequency, order=order)
w = wavelet_transform(x, wave)
self.assertTrue(np.shape(w) == (m, m * 2, order, len(radian_frequency), length))
wavelet, _ = morse_wavelet(length, gamma, beta, radian_frequency, order=order)
wtx = wavelet_transform(x, wavelet)
self.assertTrue(
np.shape(wtx) == (m, m * 2, length, len(radian_frequency), order)
)
x = np.random.random((length, m, m * 2))
wavelet, _ = morse_wavelet(length, gamma, beta, radian_frequency, order=order)
wtx = wavelet_transform(x, wavelet, time_axis=0)
self.assertTrue(
np.shape(wtx) == (length, m, m * 2, len(radian_frequency), order)
)
x = np.random.random((m, length, m * 2))
wavelet, _ = morse_wavelet(length, gamma, beta, radian_frequency, order=order)
wtx = wavelet_transform(x, wavelet, time_axis=1)
self.assertTrue(
np.shape(wtx) == (m, length, m * 2, len(radian_frequency), order)
)

def test_wavelet_transform_size_axis(self):
length = 1024
Expand All @@ -157,7 +171,7 @@ def test_wavelet_transform_size_axis(self):
x = np.random.random((length, m, int(m / 2)))
wave, _ = morse_wavelet(length, gamma, beta, radian_frequency, order=order)
w = wavelet_transform(x, wave, time_axis=0)
self.assertTrue(np.shape(w) == (length, m, m / 2, order, len(radian_frequency)))
self.assertTrue(np.shape(w) == (length, m, m / 2, len(radian_frequency), order))

def test_wavelet_transform_centered(self):
J = 10
Expand All @@ -166,7 +180,7 @@ def test_wavelet_transform_centered(self):
wave, _ = morse_wavelet(len(x), 2, 4, ao, order=1)
x[2**9] = 1
y = wavelet_transform(x, wave)
m = np.argmax(np.abs(y), axis=-1)
m = np.argmax(np.abs(y), axis=-2)
self.assertTrue(np.allclose(m, 2**9))

def test_wavelet_transform_data_real(self):
Expand Down