Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove duplicate code #38

Merged
merged 23 commits into from
Aug 26, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ repos:
types: [python]
args:
- "--max-line-length=88"
- "--disable=C0116,R0801,R0912,R0913,R0914,R0915,R1710,W0511,W0719"
- "--disable=C0116,R0912,R0913,R0914,R0915,R1710,W0511,W0719"
- repo: local
hooks:
- id: flake8
Expand Down
11 changes: 6 additions & 5 deletions engine/perturb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from util.click_util import CommaSeperatedInts, CommaSeperatedStrings, cli_help
from util.log_handler import logger
from util.netcdf_io import nc4_get_copy
from util.utils import get_seed_from_member_num
from util.utils import get_seed_from_member_num, process_member_num


def create_perturb_files(in_path, in_files, out_path, copy_all_files=False):
Expand Down Expand Up @@ -101,10 +101,11 @@ def perturb(
perturb_amplitude,
copy_all_files,
): # pylint: disable=unused-argument
if len(member_num) == 1:
member_num = list(range(1, member_num[0] + 1))
for m_num in member_num:
m_id = str(m_num)

processed_member_num = process_member_num(member_num)

for m_num, m_id in processed_member_num:

if member_type:
m_id = member_type + "_" + m_id
perturbed_model_input_dir_member_id = perturbed_model_input_dir.format(
Expand Down
10 changes: 5 additions & 5 deletions engine/run_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from util.click_util import CommaSeperatedInts, CommaSeperatedStrings, cli_help
from util.log_handler import logger
from util.utils import get_seed_from_member_num
from util.utils import get_seed_from_member_num, process_member_num


def is_float(string):
Expand Down Expand Up @@ -231,10 +231,10 @@ def run_ensemble(
append_job(job, job_list, parallel)

# run the ensemble
if len(member_num) == 1:
member_num = list(range(1, member_num[0] + 1))
for m_num in member_num:
m_id = str(m_num)
processed_member_num = process_member_num(member_num)

for m_num, m_id in processed_member_num:

Path(perturbed_run_dir.format(member_id=m_id)).mkdir(
exist_ok=True, parents=True
)
Expand Down
37 changes: 37 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,40 @@ def stats_file_set(tmp_dir):
yield files
if os.path.exists(files["tol"]):
os.remove(files["tol"])


@pytest.fixture(name="setup_csv_files")
def fixture_setup_csv_files(tmp_path):
# Create sample CSV files for testing
tolerance_data = pd.DataFrame(
{"A": [0.1, 0.2], "B": [0.3, 0.4]},
index=pd.MultiIndex.from_tuples(
[("a", "b"), ("c", "d")], names=["col1", "col2"]
),
)
ref_data = pd.DataFrame(
{"A": [1, 2], "B": [3, 4]},
index=pd.MultiIndex.from_tuples(
[("a", "b", "c"), ("d", "e", "f")], names=["col1", "col2", "col3"]
),
)
cur_data = pd.DataFrame(
{"A": [2, 3], "B": [4, 5]},
index=pd.MultiIndex.from_tuples(
[("a", "b", "c"), ("d", "e", "f")], names=["col1", "col2", "col3"]
),
)

tolerance_file = tmp_path / "tolerance_test.csv"
ref_file = tmp_path / "input_ref_test.csv"
cur_file = tmp_path / "input_cur_test.csv"

tolerance_data.to_csv(tolerance_file)
ref_data.to_csv(ref_file)
cur_data.to_csv(cur_file)

return {
"tolerance_file": tolerance_file,
"ref_file": ref_file,
"cur_file": cur_file,
}
171 changes: 64 additions & 107 deletions tests/engine/test_perturb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
and amplitude of the perturbations.
"""

import os
import shutil
import unittest

import numpy as np
import pytest
from matplotlib import pyplot as plt
from netCDF4 import Dataset # pylint: disable=no-name-in-module

Expand All @@ -20,106 +17,66 @@
ARRAY_DIM = 100


class TestPerturb(unittest.TestCase):
"""
Unit tests for verifying the functionality of perturbation methods applied
to arrays and NetCDF datasets.
This class uses the `unittest` framework to ensure the correctness of
perturbation operations on both in-memory arrays and NetCDF4 datasets.
It checks that perturbation functions produce the expected results and
handle different data types correctly.
"""

def __init__(self, methodName="runTest"):
super().__init__(methodName=methodName)
self.data = [None] * 2

@classmethod
def setUpClass(cls):
test_path = os.path.realpath("tests/tmp")
cls.test_path = test_path
# create test directory (remake if it exists)
if os.path.exists(test_path):
shutil.rmtree(test_path)
os.mkdir(test_path)

def test_perturb_array(self):
# create two arrays, perturb one.
# This is to make sure that we get a copy of the input from perturb
x1 = np.ones((ARRAY_DIM, ARRAY_DIM), dtype=atype)
x2 = np.ones((ARRAY_DIM, ARRAY_DIM), dtype=atype)
x_perturbed = perturb_array(x2, 10, AMPLITUDE)

# compute some stats and do assertions
diff1 = np.abs(x1 - x_perturbed) / x1
diff2 = np.abs(x2 - x_perturbed) / x2
mean_diff1 = np.mean(diff1)
mean_diff2 = np.mean(diff2)
self.assertLess(
np.max(diff2), AMPLITUDE, msg="perturbation is larger than amplitude!"
)
self.assertAlmostEqual(
mean_diff1,
AMPLITUDE * 0.5,
msg="perturbation is most likely too small!",
delta=AMPLITUDE * 1e-2,
)
self.assertAlmostEqual(
mean_diff2,
AMPLITUDE * 0.5,
msg="perturbation did not return a copy of input!",
delta=AMPLITUDE * 1e-2,
)

def test_perturb_nc(self):

# create two dummy netcdf4 files, one with single, one with double precision
for i, dt in enumerate([np.float32, np.float64]):
self.data[i] = Dataset(f"dummy{i}.nc", "w")
self.data[i].createDimension("x", size=ARRAY_DIM)
self.data[i].createDimension("y", size=ARRAY_DIM)
self.data[i].createVariable("z", dt, dimensions=("x", "y"))
self.data[i].variables["z"][:] = np.ones((ARRAY_DIM, ARRAY_DIM))

# perturb the double precision file
zd = self.data[1].variables["z"][:]
zd_perturb = perturb_array(zd, 10, AMPLITUDE)
self.data[1].variables["z"][:] = zd_perturb

# close the data
for i in range(2):
self.data[i].close()

# reopen the files to make sure we get the values form disk
for i in range(2):
self.data[i] = Dataset(f"dummy{i}.nc", "r")
zf = self.data[0].variables["z"][:]
zd_perturb = self.data[1].variables["z"][:]

# compute some stats and do assertions
diff = zf - zd_perturb
mean_diff = np.mean(np.abs(diff))
self.assertTrue(zf.dtype == "float32", msg="zf is not float32!")
self.assertTrue(zd.dtype == "float64", msg="zf is not float64!")
self.assertLess(
np.max(diff), AMPLITUDE, msg="perturbation is larger than amplitude!"
)
self.assertAlmostEqual(
mean_diff,
AMPLITUDE * 0.5,
msg="perturbation is most likely too small!",
delta=AMPLITUDE * 1e-2,
)

# to really make sure:
# create a plot of the difference to be read out manually if you are scared
fig, ax = plt.subplots(1, 1)
xx, yy = np.meshgrid(np.linspace(0, 1, ARRAY_DIM), np.linspace(0, 1, ARRAY_DIM))
cs = ax.contourf(xx, yy, diff)
fig.colorbar(cs, ax=ax)
fig.savefig(f"{self.test_path}/diff_figure.pdf")


if __name__ == "__main__":
unittest.main()
@pytest.fixture(name="create_nc_files")
def fixture_create_nc_files(tmp_dir):
data = [None] * 2
for i, dt in enumerate([np.float32, np.float64]):
data[i] = Dataset(f"{tmp_dir}/dummy{i}.nc", "w")
data[i].createDimension("x", size=ARRAY_DIM)
data[i].createDimension("y", size=ARRAY_DIM)
data[i].createVariable("z", dt, dimensions=("x", "y"))
data[i].variables["z"][:] = np.ones((ARRAY_DIM, ARRAY_DIM))
yield data
for d in data:
d.close()


def test_perturb_array():
# create two arrays, perturb one.
x1 = np.ones((ARRAY_DIM, ARRAY_DIM), dtype=atype)
x2 = np.ones((ARRAY_DIM, ARRAY_DIM), dtype=atype)
x_perturbed = perturb_array(x2, 10, AMPLITUDE)

# compute some stats and do assertions
diff1 = np.abs(x1 - x_perturbed) / x1
diff2 = np.abs(x2 - x_perturbed) / x2
mean_diff1 = np.mean(diff1)
mean_diff2 = np.mean(diff2)
assert np.max(diff2) < AMPLITUDE, "perturbation is larger than amplitude!"
assert np.isclose(
mean_diff1, AMPLITUDE * 0.5, atol=AMPLITUDE * 1e-2
), "perturbation is most likely too small!"
assert np.isclose(
mean_diff2, AMPLITUDE * 0.5, atol=AMPLITUDE * 1e-2
), "perturbation did not return a copy of input!"


def test_perturb_nc(tmp_dir, create_nc_files):

data = create_nc_files
zd = data[1].variables["z"][:]
zd_perturb = perturb_array(zd, 10, AMPLITUDE)
data[1].variables["z"][:] = zd_perturb

# reopen the files to make sure we get the values from disk
for i in range(2):
data[i] = Dataset(f"{tmp_dir}/dummy{i}.nc", "r")
zf = data[0].variables["z"][:]
zd_perturb = data[1].variables["z"][:]

# compute some stats and do assertions
diff = zf - zd_perturb
mean_diff = np.mean(np.abs(diff))
assert zf.dtype == "float32", "zf is not float32!"
assert zd.dtype == "float64", "zf is not float64!"
assert np.max(diff) < AMPLITUDE, "perturbation is larger than amplitude!"
assert np.isclose(
mean_diff, AMPLITUDE * 0.5, atol=AMPLITUDE * 1e-2
), "perturbation is most likely too small!"

# create a plot of the difference to be read out manually if you are scared
fig, ax = plt.subplots(1, 1)
xx, yy = np.meshgrid(np.linspace(0, 1, ARRAY_DIM), np.linspace(0, 1, ARRAY_DIM))
cs = ax.contourf(xx, yy, diff)
fig.colorbar(cs, ax=ax)
fig.savefig(f"{tmp_dir}/diff_figure.pdf")
Loading
Loading