Skip to content

Commit

Permalink
remove duplicate code (#38)
Browse files Browse the repository at this point in the history
This PR removes duplicated code and enables pylint's duplicate-code check (R0801) so that newly introduced duplication is detected.
  • Loading branch information
huppd authored Aug 26, 2024
1 parent 45ef591 commit 3b7d668
Show file tree
Hide file tree
Showing 14 changed files with 418 additions and 310 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ repos:
types: [python]
args:
- "--max-line-length=88"
- "--disable=C0116,R0801,R0912,R0913,R0914,R0915,R1710,W0511,W0719"
- "--disable=C0116,R0912,R0913,R0914,R0915,R1710,W0511,W0719"
- repo: local
hooks:
- id: flake8
Expand Down
11 changes: 6 additions & 5 deletions engine/perturb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from util.click_util import CommaSeperatedInts, CommaSeperatedStrings, cli_help
from util.log_handler import logger
from util.netcdf_io import nc4_get_copy
from util.utils import get_seed_from_member_num
from util.utils import get_seed_from_member_num, process_member_num


def create_perturb_files(in_path, in_files, out_path, copy_all_files=False):
Expand Down Expand Up @@ -101,10 +101,11 @@ def perturb(
perturb_amplitude,
copy_all_files,
): # pylint: disable=unused-argument
if len(member_num) == 1:
member_num = list(range(1, member_num[0] + 1))
for m_num in member_num:
m_id = str(m_num)

processed_member_num = process_member_num(member_num)

for m_num, m_id in processed_member_num:

if member_type:
m_id = member_type + "_" + m_id
perturbed_model_input_dir_member_id = perturbed_model_input_dir.format(
Expand Down
10 changes: 5 additions & 5 deletions engine/run_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from util.click_util import CommaSeperatedInts, CommaSeperatedStrings, cli_help
from util.log_handler import logger
from util.utils import get_seed_from_member_num
from util.utils import get_seed_from_member_num, process_member_num


def is_float(string):
Expand Down Expand Up @@ -231,10 +231,10 @@ def run_ensemble(
append_job(job, job_list, parallel)

# run the ensemble
if len(member_num) == 1:
member_num = list(range(1, member_num[0] + 1))
for m_num in member_num:
m_id = str(m_num)
processed_member_num = process_member_num(member_num)

for m_num, m_id in processed_member_num:

Path(perturbed_run_dir.format(member_id=m_id)).mkdir(
exist_ok=True, parents=True
)
Expand Down
37 changes: 37 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,40 @@ def stats_file_set(tmp_dir):
yield files
if os.path.exists(files["tol"]):
os.remove(files["tol"])


@pytest.fixture(name="setup_csv_files")
def fixture_setup_csv_files(tmp_path):
    """
    Write three small sample CSV files (tolerance, reference, current) into
    ``tmp_path``.

    The tolerance frame is indexed by a two-level MultiIndex (col1, col2);
    the reference and current frames are indexed by a three-level MultiIndex
    (col1, col2, col3). All frames have the columns "A" and "B".

    Returns:
        dict mapping "tolerance_file", "ref_file" and "cur_file" to the
        paths of the written CSV files.
    """

    def three_level_index():
        # index layout shared by the reference and the current data
        return pd.MultiIndex.from_tuples(
            [("a", "b", "c"), ("d", "e", "f")], names=["col1", "col2", "col3"]
        )

    two_level_index = pd.MultiIndex.from_tuples(
        [("a", "b"), ("c", "d")], names=["col1", "col2"]
    )

    # key in the returned dict -> (file name, content of the file)
    samples = {
        "tolerance_file": (
            "tolerance_test.csv",
            pd.DataFrame({"A": [0.1, 0.2], "B": [0.3, 0.4]}, index=two_level_index),
        ),
        "ref_file": (
            "input_ref_test.csv",
            pd.DataFrame({"A": [1, 2], "B": [3, 4]}, index=three_level_index()),
        ),
        "cur_file": (
            "input_cur_test.csv",
            pd.DataFrame({"A": [2, 3], "B": [4, 5]}, index=three_level_index()),
        ),
    }

    written = {}
    for key, (file_name, frame) in samples.items():
        target = tmp_path / file_name
        frame.to_csv(target)
        written[key] = target

    return written
171 changes: 64 additions & 107 deletions tests/engine/test_perturb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
and amplitude of the perturbations.
"""

import os
import shutil
import unittest

import numpy as np
import pytest
from matplotlib import pyplot as plt
from netCDF4 import Dataset # pylint: disable=no-name-in-module

Expand All @@ -20,106 +17,66 @@
ARRAY_DIM = 100


class TestPerturb(unittest.TestCase):
    """
    Unit tests for verifying the functionality of perturbation methods applied
    to arrays and NetCDF datasets.
    This class uses the `unittest` framework to ensure the correctness of
    perturbation operations on both in-memory arrays and NetCDF4 datasets.
    It checks that perturbation functions produce the expected results and
    handle different data types correctly.
    """

    def __init__(self, methodName="runTest"):
        super().__init__(methodName=methodName)
        # handles of the single (index 0) and double (index 1) precision files
        self.data = [None] * 2

    @classmethod
    def setUpClass(cls):
        """Create an empty ``tests/tmp`` directory (remade if it exists)."""
        test_path = os.path.realpath("tests/tmp")
        cls.test_path = test_path
        if os.path.exists(test_path):
            shutil.rmtree(test_path)
        # makedirs instead of mkdir: also works if "tests" does not exist yet
        os.makedirs(test_path)

    def test_perturb_array(self):
        """Check the size of the perturbation and that the input is untouched."""
        # create two arrays, perturb one.
        # This is to make sure that we get a copy of the input from perturb
        x1 = np.ones((ARRAY_DIM, ARRAY_DIM), dtype=atype)
        x2 = np.ones((ARRAY_DIM, ARRAY_DIM), dtype=atype)
        x_perturbed = perturb_array(x2, 10, AMPLITUDE)

        # compute some stats and do assertions
        diff1 = np.abs(x1 - x_perturbed) / x1
        diff2 = np.abs(x2 - x_perturbed) / x2
        mean_diff1 = np.mean(diff1)
        mean_diff2 = np.mean(diff2)
        self.assertLess(
            np.max(diff2), AMPLITUDE, msg="perturbation is larger than amplitude!"
        )
        self.assertAlmostEqual(
            mean_diff1,
            AMPLITUDE * 0.5,
            msg="perturbation is most likely too small!",
            delta=AMPLITUDE * 1e-2,
        )
        self.assertAlmostEqual(
            mean_diff2,
            AMPLITUDE * 0.5,
            msg="perturbation did not return a copy of input!",
            delta=AMPLITUDE * 1e-2,
        )

    def test_perturb_nc(self):
        """Perturb a NetCDF variable and verify the values read back from disk."""
        # keep the dummy files inside the test directory instead of the
        # current working directory
        file_names = [os.path.join(self.test_path, f"dummy{i}.nc") for i in range(2)]

        # create two dummy netcdf4 files, one with single, one with double precision
        for i, dt in enumerate([np.float32, np.float64]):
            self.data[i] = Dataset(file_names[i], "w")
            self.data[i].createDimension("x", size=ARRAY_DIM)
            self.data[i].createDimension("y", size=ARRAY_DIM)
            self.data[i].createVariable("z", dt, dimensions=("x", "y"))
            self.data[i].variables["z"][:] = np.ones((ARRAY_DIM, ARRAY_DIM))

        # perturb the double precision file
        zd = self.data[1].variables["z"][:]
        zd_perturb = perturb_array(zd, 10, AMPLITUDE)
        self.data[1].variables["z"][:] = zd_perturb

        # close the data
        for i in range(2):
            self.data[i].close()

        # reopen the files to make sure we get the values from disk
        for i in range(2):
            self.data[i] = Dataset(file_names[i], "r")
            # the read handles are released even if an assertion below fails
            self.addCleanup(self.data[i].close)
        zf = self.data[0].variables["z"][:]
        zd_perturb = self.data[1].variables["z"][:]

        # compute some stats and do assertions
        diff = zf - zd_perturb
        mean_diff = np.mean(np.abs(diff))
        self.assertTrue(zf.dtype == "float32", msg="zf is not float32!")
        self.assertTrue(zd.dtype == "float64", msg="zd is not float64!")
        # use the absolute value so that deviations of both signs are bounded
        self.assertLess(
            np.max(np.abs(diff)),
            AMPLITUDE,
            msg="perturbation is larger than amplitude!",
        )
        self.assertAlmostEqual(
            mean_diff,
            AMPLITUDE * 0.5,
            msg="perturbation is most likely too small!",
            delta=AMPLITUDE * 1e-2,
        )

        # to really make sure:
        # create a plot of the difference to be read out manually if you are scared
        fig, ax = plt.subplots(1, 1)
        xx, yy = np.meshgrid(np.linspace(0, 1, ARRAY_DIM), np.linspace(0, 1, ARRAY_DIM))
        cs = ax.contourf(xx, yy, diff)
        fig.colorbar(cs, ax=ax)
        fig.savefig(f"{self.test_path}/diff_figure.pdf")
        # release the figure, pyplot keeps it alive otherwise
        plt.close(fig)


# Run the unittest test cases of this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
@pytest.fixture(name="create_nc_files")
def fixture_create_nc_files(tmp_dir):
    """
    Create two dummy NetCDF files in ``tmp_dir`` and yield their open handles.

    ``dummy0.nc`` holds a single precision and ``dummy1.nc`` a double
    precision variable "z" of shape (ARRAY_DIM, ARRAY_DIM) filled with ones.

    Yields:
        list with the two datasets opened in write mode. A test may replace
        entries of this list (e.g. with handles reopened for reading); such
        replacements are closed on teardown as well.
    """
    created = []
    data = []
    try:
        for i, dt in enumerate([np.float32, np.float64]):
            dataset = Dataset(f"{tmp_dir}/dummy{i}.nc", "w")
            # remember the handle right away so that it is closed even if
            # one of the following setup steps fails
            created.append(dataset)
            dataset.createDimension("x", size=ARRAY_DIM)
            dataset.createDimension("y", size=ARRAY_DIM)
            dataset.createVariable("z", dt, dimensions=("x", "y"))
            dataset.variables["z"][:] = np.ones((ARRAY_DIM, ARRAY_DIM))
            data.append(dataset)
        yield data
    finally:
        # close the handles created here and whatever the test put into the
        # list; handles that were already closed are skipped
        for dataset in [*created, *data]:
            if dataset is not None and dataset.isopen():
                dataset.close()


def test_perturb_array():
    """
    Check the size of the perturbation produced by ``perturb_array`` and that
    the array passed to it is not modified.
    """
    shape = (ARRAY_DIM, ARRAY_DIM)
    # two identical arrays: only `passed_in` is handed to perturb_array,
    # `reference` is never touched and serves as comparison
    reference = np.ones(shape, dtype=atype)
    passed_in = np.ones(shape, dtype=atype)
    perturbed = perturb_array(passed_in, 10, AMPLITUDE)

    # relative deviation of the result from both arrays
    rel_dev_reference = np.abs(reference - perturbed) / reference
    rel_dev_passed_in = np.abs(passed_in - perturbed) / passed_in

    expected_mean = AMPLITUDE * 0.5
    tolerance = AMPLITUDE * 1e-2

    assert (
        np.max(rel_dev_passed_in) < AMPLITUDE
    ), "perturbation is larger than amplitude!"
    assert np.isclose(
        np.mean(rel_dev_reference), expected_mean, atol=tolerance
    ), "perturbation is most likely too small!"
    # if the input had been perturbed in place, this deviation would vanish
    assert np.isclose(
        np.mean(rel_dev_passed_in), expected_mean, atol=tolerance
    ), "perturbation did not return a copy of input!"


def test_perturb_nc(tmp_dir, create_nc_files):
    """
    Perturb the double precision dummy file and compare the values read back
    from disk with the unperturbed single precision file.

    Args:
        tmp_dir: directory holding the dummy files and receiving the
            diagnostic plot (fixture defined outside this module).
        create_nc_files: list with the two dummy datasets opened for writing.
    """
    data = create_nc_files

    # perturb the double precision file
    zd = data[1].variables["z"][:]
    zd_perturb = perturb_array(zd, 10, AMPLITUDE)
    data[1].variables["z"][:] = zd_perturb

    # close the write handles: this flushes the values to disk and avoids
    # leaking the handles once the list entries are replaced below
    for dataset in data:
        dataset.close()

    # reopen the files to make sure we get the values from disk
    for i in range(2):
        data[i] = Dataset(f"{tmp_dir}/dummy{i}.nc", "r")
    zf = data[0].variables["z"][:]
    zd_perturb = data[1].variables["z"][:]

    # compute some stats and do assertions
    diff = zf - zd_perturb
    abs_diff = np.abs(diff)
    mean_diff = np.mean(abs_diff)
    assert zf.dtype == "float32", "zf is not float32!"
    assert zd.dtype == "float64", "zd is not float64!"
    # use the absolute value so that deviations of both signs are bounded
    assert np.max(abs_diff) < AMPLITUDE, "perturbation is larger than amplitude!"
    assert np.isclose(
        mean_diff, AMPLITUDE * 0.5, atol=AMPLITUDE * 1e-2
    ), "perturbation is most likely too small!"

    # create a plot of the difference to be read out manually if you are scared
    fig, ax = plt.subplots(1, 1)
    xx, yy = np.meshgrid(np.linspace(0, 1, ARRAY_DIM), np.linspace(0, 1, ARRAY_DIM))
    cs = ax.contourf(xx, yy, diff)
    fig.colorbar(cs, ax=ax)
    fig.savefig(f"{tmp_dir}/diff_figure.pdf")
    # release the figure, pyplot keeps it alive otherwise
    plt.close(fig)
Loading

0 comments on commit 3b7d668

Please sign in to comment.