Skip to content

Commit

Permalink
fix preds in (MESMER-group#510)
Browse files Browse the repository at this point in the history
  • Loading branch information
mathause authored Aug 28, 2024
1 parent df670b4 commit 0e15d4e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 22 deletions.
11 changes: 6 additions & 5 deletions mesmer/stats/_harmonic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,8 @@ def func(coeffs, yearly_predictor, mon_target):

coeffs = minimize_result.x
mse = np.mean(minimize_result.fun**2)
preds = _generate_fourier_series_np(
yearly_predictor=yearly_predictor, coeffs=coeffs
)

return coeffs, preds, mse
return coeffs, mse


def _calculate_bic(n_samples, order, mse):
Expand Down Expand Up @@ -206,7 +203,7 @@ def _fit_fourier_order_np(yearly_predictor, monthly_target, max_order):

for i_order in range(1, max_order + 1):

coeffs, predictions, mse = _fit_fourier_coeffs_np(
coeffs, mse = _fit_fourier_coeffs_np(
yearly_predictor,
monthly_target,
# use coeffs from last iteration as first guess
Expand All @@ -221,6 +218,10 @@ def _fit_fourier_order_np(yearly_predictor, monthly_target, max_order):
else:
break

predictions = _generate_fourier_series_np(
yearly_predictor=yearly_predictor, coeffs=last_coeffs
)

# need the coeff array to be the same size for all orders
coeffs = np.full(max_order * 4, fill_value=np.nan)
coeffs[: selected_order * 4] = last_coeffs
Expand Down
37 changes: 20 additions & 17 deletions tests/unit/test_harmonic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def test_fit_harmonic_model():
# test if the model can recover the underlying cycle with noise on top of monthly target
rng = np.random.default_rng(0)
noisy_monthly_target = monthly_target + rng.normal(
loc=0, scale=0.1, size=monthly_target.values.shape
loc=0, scale=0.1, size=monthly_target.shape
)

result = mesmer.stats.fit_harmonic_model(yearly_predictor, noisy_monthly_target)
Expand All @@ -171,28 +171,31 @@ def test_fit_harmonic_model():
# compare numerically one cell of one year
expected = np.array(
[
9.975936,
9.968497,
7.32234,
2.750445,
-2.520796,
-7.081546,
-9.713699,
-9.71333,
-7.077949,
-2.509761,
2.76855,
7.340076,
9.970548,
9.966644,
7.325875,
2.755833,
-2.518943,
-7.085081,
-9.719088,
-9.715184,
-7.074415,
-2.504373,
2.770403,
7.336541,
]
)

result_comp = result.predictions.isel(cells=0, time=slice(0, 12)).values
np.testing.assert_allclose(
result_comp,
expected,
atol=1e-6,
np.testing.assert_allclose(result_comp, expected, atol=1e-6)

# ensure coeffs and predictions are consistent
expected = mesmer.stats.predict_harmonic_model(
yearly_predictor, result.coeffs, result.time
)

xr.testing.assert_equal(expected, result.predictions)


def test_fit_harmonic_model_checks():
yearly_predictor = trend_data_2D(n_timesteps=10, n_lat=3, n_lon=2)
Expand Down

0 comments on commit 0e15d4e

Please sign in to comment.