Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove duplicate code #38

Merged
merged 23 commits into from
Aug 26, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ repos:
types: [python]
args:
- "--max-line-length=88"
- "--disable=C0116,R0801,R0912,R0913,R0914,R0915,R1710,W0511,W0719"
- "--disable=C0116,R0912,R0913,R0914,R0915,R1710,W0511,W0719"
- repo: local
hooks:
- id: flake8
Expand Down
11 changes: 6 additions & 5 deletions engine/perturb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from util.click_util import CommaSeperatedInts, CommaSeperatedStrings, cli_help
from util.log_handler import logger
from util.netcdf_io import nc4_get_copy
from util.utils import get_seed_from_member_num
from util.utils import get_seed_from_member_num, process_member_num


def create_perturb_files(in_path, in_files, out_path, copy_all_files=False):
Expand Down Expand Up @@ -101,10 +101,11 @@ def perturb(
perturb_amplitude,
copy_all_files,
): # pylint: disable=unused-argument
if len(member_num) == 1:
member_num = list(range(1, member_num[0] + 1))
for m_num in member_num:
m_id = str(m_num)

processed_member_num = process_member_num(member_num)

for m_num, m_id in processed_member_num:

if member_type:
m_id = member_type + "_" + m_id
perturbed_model_input_dir_member_id = perturbed_model_input_dir.format(
Expand Down
10 changes: 5 additions & 5 deletions engine/run_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from util.click_util import CommaSeperatedInts, CommaSeperatedStrings, cli_help
from util.log_handler import logger
from util.utils import get_seed_from_member_num
from util.utils import get_seed_from_member_num, process_member_num


def is_float(string):
Expand Down Expand Up @@ -231,10 +231,10 @@ def run_ensemble(
append_job(job, job_list, parallel)

# run the ensemble
if len(member_num) == 1:
member_num = list(range(1, member_num[0] + 1))
for m_num in member_num:
m_id = str(m_num)
processed_member_num = process_member_num(member_num)

for m_num, m_id in processed_member_num:

Path(perturbed_run_dir.format(member_id=m_id)).mkdir(
exist_ok=True, parents=True
)
Expand Down
10 changes: 2 additions & 8 deletions tests/engine/test_perturb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
and amplitude of the perturbations.
"""

import os
import shutil
import unittest

import numpy as np
from matplotlib import pyplot as plt
from netCDF4 import Dataset # pylint: disable=no-name-in-module

from engine.perturb import perturb_array
from tests.helpers import setup_test_directory

atype = np.float32
AMPLITUDE = atype(1e-14)
Expand All @@ -37,12 +36,7 @@ def __init__(self, methodName="runTest"):

@classmethod
def setUpClass(cls):
test_path = os.path.realpath("tests/tmp")
cls.test_path = test_path
# create test directory (remake if it exists)
if os.path.exists(test_path):
shutil.rmtree(test_path)
os.mkdir(test_path)
cls.test_path = setup_test_directory("tests/tmp")

def test_perturb_array(self):
# create two arrays, perturb one.
Expand Down
20 changes: 20 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,26 @@
from engine.tolerance import tolerance


def setup_test_directory(directory_path):
"""
Creates a test directory, deleting any existing directory at the same path.

Args:
directory_path (str): The path to the directory to be created or cleaned.

Returns:
str: The absolute path to the created directory.
"""
test_path = os.path.realpath(directory_path)

if os.path.exists(test_path):
shutil.rmtree(test_path)

os.mkdir(test_path)

return test_path


def load_netcdf(path):
return xr.load_dataset(path)

Expand Down
80 changes: 80 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
This module contains unit tests for the `helpers.py` module.
"""

import os
import shutil

from tests.helpers import setup_test_directory


def test_setup_test_directory_creates_directory():
test_dir = "test_dir"

# Call the function to create the directory
created_dir = setup_test_directory(test_dir)

try:
# Assert the directory was created
assert os.path.exists(created_dir), f"Directory {created_dir} was not created."
assert os.path.isdir(created_dir), f"{created_dir} is not a directory."

# Test cleanup behavior
setup_test_directory(test_dir)

assert os.path.exists(
created_dir
), f"Directory {created_dir} was not recreated."
assert os.path.isdir(created_dir), f"{created_dir} is not a directory."
finally:
# Clean up
if os.path.exists(test_dir):
shutil.rmtree(test_dir)


def test_setup_test_directory_removes_existing_directory():
test_dir = "test_dir"

try:
# Create an existing directory with a file in it
os.mkdir(test_dir)
with open(os.path.join(test_dir, "test_file.txt"), "w", encoding="utf-8") as f:
f.write("This is a test file.")

# Assert the file exists
assert os.path.exists(
os.path.join(test_dir, "test_file.txt")
), "Test file was not created."

# Call the function to recreate the directory
setup_test_directory(test_dir)

# Assert the directory was recreated and the file was removed
assert not os.path.exists(
os.path.join(test_dir, "test_file.txt")
), "Test file was not removed."
assert os.path.exists(test_dir), f"Directory {test_dir} was not recreated."
assert os.path.isdir(test_dir), f"{test_dir} is not a directory."

finally:
# Clean up
if os.path.exists(test_dir):
shutil.rmtree(test_dir)


def test_setup_test_directory_returns_absolute_path():
test_dir = "test_dir"

try:
# Call the function to create the directory
created_dir = setup_test_directory(test_dir)

# Assert the returned path is absolute
assert os.path.isabs(
created_dir
), f"Returned path {created_dir} is not absolute."

finally:
# Clean up
if os.path.exists(test_dir):
shutil.rmtree(test_dir)
21 changes: 13 additions & 8 deletions tests/util/icon/test_timing_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
growing trees with new nodes, and adding trees together.
"""

import os
import shutil
import unittest
from datetime import datetime

import numpy as np
import pandas as pd

from tests.helpers import setup_test_directory
from util.icon.extract_timings import read_logfile
from util.tree import TimingTree

Expand Down Expand Up @@ -43,12 +43,7 @@ class TestTimingTree(unittest.TestCase):

@classmethod
def setUpClass(cls):
test_path = os.path.realpath("tests/tmp")
cls.test_path = test_path
# create test directory (remake if it exists)
if os.path.exists(test_path):
shutil.rmtree(test_path)
os.mkdir(test_path)
cls.test_path = setup_test_directory("tests/tmp")

def setUp(self):
return
Expand Down Expand Up @@ -190,6 +185,16 @@ def test_add(self):

self.assert_trees_equal(tt1, tt_added)

def test_get_sorted_finish_times(self):
tt_json = TimingTree.from_json(JSON_REFERENCE)

dates = tt_json.get_sorted_finish_times()
print(dates)
self.assertTrue(
dates == [datetime(2022, 6, 26, 20, 11, 23)],
msg="sorted finish time does not match reference",
)


if __name__ == "__main__":
unittest.main()
102 changes: 102 additions & 0 deletions tests/util/test_dataframe_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""
This module contains unit tests for the `dataframe_ops.py` module.
"""

from unittest.mock import patch

import pandas as pd
import pytest

from util.dataframe_ops import parse_check


@pytest.fixture(name="setup_csv_files")
def fixture_setup_csv_files(tmp_path):
# Create sample CSV files for testing
tolerance_data = pd.DataFrame(
{"A": [0.1, 0.2], "B": [0.3, 0.4]},
index=pd.MultiIndex.from_tuples(
[("a", "b"), ("c", "d")], names=["col1", "col2"]
),
)
ref_data = pd.DataFrame(
{"A": [1, 2], "B": [3, 4]},
index=pd.MultiIndex.from_tuples(
[("a", "b", "c"), ("d", "e", "f")], names=["col1", "col2", "col3"]
),
)
cur_data = pd.DataFrame(
{"A": [2, 3], "B": [4, 5]},
index=pd.MultiIndex.from_tuples(
[("a", "b", "c"), ("d", "e", "f")], names=["col1", "col2", "col3"]
),
)

tolerance_file = tmp_path / "tolerance_test.csv"
ref_file = tmp_path / "input_ref_test.csv"
cur_file = tmp_path / "input_cur_test.csv"

tolerance_data.to_csv(tolerance_file)
ref_data.to_csv(ref_file)
cur_data.to_csv(cur_file)

return {
"tolerance_file": tolerance_file,
"ref_file": ref_file,
"cur_file": cur_file,
}


@patch("util.dataframe_ops.parse_probtest_csv")
@patch("util.dataframe_ops.logger")
def test_parse_check(mock_logger, mock_parse_probtest_csv, setup_csv_files):
# Mock the return value of parse_probtest_csv
mock_parse_probtest_csv.side_effect = lambda file, index_col: pd.read_csv(
file, index_col=index_col
)

factor = 2.0

df_tol, df_ref, df_cur = parse_check(
setup_csv_files["tolerance_file"],
setup_csv_files["ref_file"],
setup_csv_files["cur_file"],
factor,
)

# Check that the tolerance DataFrame has been scaled
expected_tol = pd.DataFrame(
{"A": [0.2, 0.4], "B": [0.6, 0.8]},
index=pd.MultiIndex.from_tuples(
[("a", "b"), ("c", "d")], names=["col1", "col2"]
),
)

pd.testing.assert_frame_equal(df_tol, expected_tol)

# Check that the reference and current DataFrames are read correctly
expected_ref = pd.DataFrame(
{"A": [1, 2], "B": [3, 4]},
index=pd.MultiIndex.from_tuples(
[("a", "b", "c"), ("d", "e", "f")], names=["col1", "col2", "col3"]
),
)

expected_cur = pd.DataFrame(
{"A": [2, 3], "B": [4, 5]},
index=pd.MultiIndex.from_tuples(
[("a", "b", "c"), ("d", "e", "f")], names=["col1", "col2", "col3"]
),
)

pd.testing.assert_frame_equal(df_ref, expected_ref)
pd.testing.assert_frame_equal(df_cur, expected_cur)

# Check logging
mock_logger.info.assert_any_call("applying a factor of %s to the spread", factor)
mock_logger.info.assert_any_call(
"checking %s against %s using tolerances from %s",
setup_csv_files["cur_file"],
setup_csv_files["ref_file"],
setup_csv_files["tolerance_file"],
)
45 changes: 45 additions & 0 deletions tests/util/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
This module contains unit tests for the `utils.py` module.
"""

from util.utils import process_member_num


def test_process_member_num_single_element():
"""
Test case for a single element in the input list.
It should generate a list from 1 to that number and convert each to a string.
"""
input_data = [5]
expected_output = [(1, "1"), (2, "2"), (3, "3"), (4, "4"), (5, "5")]
assert process_member_num(input_data) == expected_output


def test_process_member_num_multiple_elements():
"""
Test case for multiple elements in the input list.
It should convert each number to a string.
"""
input_data = [2, 3, 4]
expected_output = [(2, "2"), (3, "3"), (4, "4")]
assert process_member_num(input_data) == expected_output


def test_process_member_num_empty_list():
"""
Test case for an empty input list.
It should return an empty list.
"""
input_data = []
expected_output = []
assert process_member_num(input_data) == expected_output


def test_process_member_num_single_element_zero():
"""
Test case for a single element in the input list being zero.
It should return an empty list.
"""
input_data = [0]
expected_output = []
assert process_member_num(input_data) == expected_output
Loading
Loading