Skip to content

Commit

Permalink
chore: simplify testing logic
Browse files Browse the repository at this point in the history
  • Loading branch information
cako committed Aug 5, 2024
1 parent 45b5fbc commit b9618f5
Showing 1 changed file with 47 additions and 39 deletions.
86 changes: 47 additions & 39 deletions pytests/test_dtcwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,67 +20,75 @@ def sequential_array(shape):
@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_input1D(par):
"""Test for DTCWT with 1D input"""
if int(np_version[0]) < 2:
t = sequential_array((par["ny"],))
if int(np_version[0]) >= 2:
return

for level in range(1, 10):
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x
t = sequential_array((par["ny"],))

np.testing.assert_allclose(t, y)
for level in range(1, 10):
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x

np.testing.assert_allclose(t, y)


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_input2D(par):
"""Test for DTCWT with 2D input (forward-inverse pair)"""
if int(np_version[0]) < 2:
t = sequential_array(
(
par["ny"],
par["ny"],
)
if int(np_version[0]) >= 2:
return

t = sequential_array(
(
par["ny"],
par["ny"],
)
)

for level in range(1, 10):
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x
for level in range(1, 10):
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x

np.testing.assert_allclose(t, y)
np.testing.assert_allclose(t, y)


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_input3D(par):
"""Test for DTCWT with 3D input (forward-inverse pair)"""
if int(np_version[0]) < 2:
t = sequential_array((par["ny"], par["ny"], par["ny"]))
if int(np_version[0]) >= 2:
return

t = sequential_array((par["ny"], par["ny"], par["ny"]))

for level in range(1, 10):
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x
for level in range(1, 10):
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x

np.testing.assert_allclose(t, y)
np.testing.assert_allclose(t, y)


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_birot(par):
"""Test for DTCWT birot (forward-inverse pair)"""
if int(np_version[0]) < 2:
birots = ["antonini", "legall", "near_sym_a", "near_sym_b"]

t = sequential_array(
(
par["ny"],
par["ny"],
)
if int(np_version[0]) >= 2:
return

birots = ["antonini", "legall", "near_sym_a", "near_sym_b"]

t = sequential_array(
(
par["ny"],
par["ny"],
)
)

for _b in birots:
print(f"birot {_b}")
Dtcwt = DTCWT(dims=t.shape, biort=_b, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x
for _b in birots:
print(f"birot {_b}")
Dtcwt = DTCWT(dims=t.shape, biort=_b, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x

np.testing.assert_allclose(t, y)
np.testing.assert_allclose(t, y)

0 comments on commit b9618f5

Please sign in to comment.