Skip to content

Commit

Permalink
expand harmonic model xarray test (MESMER-group#458)
Browse files Browse the repository at this point in the history
* add more sophisticated xarray test

* delete Shrutis test file

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>
  • Loading branch information
3 people authored May 28, 2024
1 parent c5df229 commit dc60db6
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 268 deletions.
253 changes: 0 additions & 253 deletions mesmer/mesmer_m/tests_harmonic_model.py

This file was deleted.

68 changes: 53 additions & 15 deletions tests/unit/test_harmonic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,23 +172,51 @@ def test_fit_to_bic_numerical_stability():
np.testing.assert_allclose(predictions, expected_predictions)


@pytest.mark.parametrize(
"coefficients",
[
np.array([0, -1, 0, -2]),
np.array([1, 2, 3, 4, 5, 6, 7, 8]),
np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
],
)
def test_fit_to_bic_xr(coefficients):
yearly_predictor = trend_data_2D(n_timesteps=10, n_lat=3, n_lon=2)
def get_2D_coefficients(order_per_cell, n_lat=3, n_lon=2):
n_cells = n_lat * n_lon
max_order = 6

# generate coefficients that resemble real ones
# generate rapidly decreasing coefficients for increasing orders
trend = np.repeat(np.linspace(1.2, 0.2, max_order) ** 2, 4)
# the first coefficients are rather small (scaling of seasonal variability with temperature change)
# while the second ones are large (constant distance of each month from the yearly mean)
scale = np.tile([0.01, 5.0], (n_cells, max_order * 2))
# generate some variability so not all coefficients are exactly the same
rng = np.random.default_rng(0)
variability = rng.normal(loc=0, scale=0.1, size=(n_cells, max_order * 4))
# put it together
coeffs = trend * scale + variability
coeffs = np.round(coeffs, 1)

# replace superfluous orders with nans
for cell, order in enumerate(order_per_cell):
coeffs[cell, order * 4 :] = np.nan

LON, LAT = np.meshgrid(np.arange(n_lon), np.arange(n_lat))

coords = {
"lon": ("cells", LON.flatten()),
"lat": ("cells", LAT.flatten()),
}

return xr.DataArray(coeffs, dims=("cells", "coeff"), coords=coords)


def test_fit_to_bic_xr():
n_ts = 10
orders = [1, 2, 3, 4, 5, 6]

coefficients = get_2D_coefficients(order_per_cell=orders, n_lat=3, n_lon=2)

yearly_predictor = trend_data_2D(n_timesteps=n_ts, n_lat=3, n_lon=2)

freq = "AS" if Version(pd.__version__) < Version("2.2") else "YS"
yearly_predictor["time"] = xr.cftime_range(
start="2000-01-01", periods=10, freq=freq
start="2000-01-01", periods=n_ts, freq=freq
)

time = xr.cftime_range(start="2000-01-01", periods=10 * 12, freq="MS")
time = xr.cftime_range(start="2000-01-01", periods=n_ts * 12, freq="MS")
monthly_time = xr.DataArray(
time,
dims=["time"],
Expand All @@ -200,17 +228,27 @@ def test_fit_to_bic_xr(coefficients):
monthly_target = xr.apply_ufunc(
generate_fourier_series_np,
upsampled_yearly_predictor,
input_core_dims=[["time"]],
coefficients,
input_core_dims=[["time"], ["coeff"]],
output_core_dims=[["time"]],
vectorize=True,
output_dtypes=[float],
kwargs={"coeffs": coefficients, "months": months},
kwargs={"months": months},
)

# test if the model can recover the monthly target from perfect fourier series
result = fit_to_bic_xr(yearly_predictor, monthly_target)

np.testing.assert_equal(result.n_sel.values, orders)
xr.testing.assert_allclose(result["predictions"], monthly_target, atol=0.1)

# test if the model can recover the underlying cycle with noise on top of monthly target
rng = np.random.default_rng(0)
noisy_monthly_target = monthly_target + rng.normal(
loc=0, scale=0.1, size=monthly_target.values.shape
)
result = fit_to_bic_xr(yearly_predictor, noisy_monthly_target)
xr.testing.assert_allclose(result["predictions"], monthly_target, atol=0.2)


def test_fit_to_bix_xr_instance_checks():
yearly_predictor = trend_data_2D(n_timesteps=10, n_lat=3, n_lon=2)
Expand Down

0 comments on commit dc60db6

Please sign in to comment.