Skip to content

Commit

Permalink
Reuse lambda function (MESMER-group#475)
Browse files Browse the repository at this point in the history
* reuse lambda function in get_lambdas_from_covariates_xr and refactor for that

* adjust tests

* add docstring

* [pre-commit.ci] auto fixes from pre-commit.com hooks

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
1 parent 6a30c27 commit c7b736e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
47 changes: 41 additions & 6 deletions mesmer/mesmer_m/power_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,33 @@
from sklearn.preprocessing import PowerTransformer, StandardScaler


def lambda_function(coeffs, local_yearly_T):
return 2 / (1 + coeffs[0] * np.exp(local_yearly_T * coeffs[1]))
def lambda_function(xi_0, xi_1, local_yearly_T):
r"""Use logistic function to calculate lambda depending on the local yearly
temperature. The function is defined as
.. math::
\lambda = \frac{2}{\xi_0 + e^{\xi_1 \cdot T_y}}
It ranges between 0 and 2.
Parameters
----------
xi_0 : float
First coefficient of the logistic function (controlling the intercept).
xi_1 : float
Second coefficient of the logistic function (controlling the slope).
local_yearly_T : ndarray of shape (n_years,)
yearly temperature values of one gridcell and month used as predictor
for lambda.
Returns
-------
lambdas : ndarray of float of shape (n_years,)
The parameters of the power transformation for each gridcell and month
"""
return 2 / (1 + xi_0 * np.exp(local_yearly_T * xi_1))


class PowerTransformerVariableLambda(PowerTransformer):
Expand Down Expand Up @@ -113,7 +138,7 @@ def _yeo_johnson_optimize_lambda(self, local_monthly_residuals, local_yearly_T):
def _neg_log_likelihood(coeffs):
"""Return the negative log likelihood of the observed local monthly
residual temperatures as a function of lambda."""
lambdas = lambda_function(coeffs, local_yearly_T)
lambdas = lambda_function(coeffs[0], coeffs[1], local_yearly_T)
# version with sklearn yeo johnson transform
# x_trans = np.zeros_like(x)
# for i, lmbda in enumerate(lambdas):
Expand Down Expand Up @@ -224,7 +249,7 @@ def _get_yeo_johnson_lambdas(self, yearly_T):
gridcell = 0
# TODO: sure yearly_T.T gives local yearly T?
for coeffs, local_yearly_T in zip(self.coeffs_, yearly_T.T):
lambdas[:, gridcell] = lambda_function(coeffs, local_yearly_T)
lambdas[:, gridcell] = lambda_function(coeffs[0], coeffs[1], local_yearly_T)
gridcell += 1

lambdas = np.where(lambdas < 0, 0, lambdas)
Expand Down Expand Up @@ -374,7 +399,7 @@ def _neg_log_likelihood(coeffs):
"""Return the negative log likelihood of the observed local monthly residual
temperatures as a function of lambda.
"""
lambdas = lambda_function(coeffs, yearly_pred)
lambdas = lambda_function(coeffs[0], coeffs[1], yearly_pred)

# version with own power transform
transformed_resids = _yeo_johnson_transform_np(monthly_residuals, lambdas)
Expand Down Expand Up @@ -427,7 +452,17 @@ def get_lambdas_from_covariates_xr(coeffs, yearly_pred):
if not isinstance(yearly_pred, xr.DataArray):
raise TypeError(f"Expected a `xr.DataArray`, got {type(yearly_pred)}")

lambdas = 2 / (1 + coeffs.xi_0 * np.exp(yearly_pred * coeffs.xi_1))
lambdas = xr.apply_ufunc(
lambda_function,
coeffs.xi_0,
coeffs.xi_1,
yearly_pred,
input_core_dims=[[], [], []],
output_core_dims=[[]],
vectorize=True,
dask="parallelized",
output_dtypes=[float],
)

return lambdas.rename("lambdas")

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_power_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
def test_lambda_function(coeffs, t, expected):

result = lambda_function(coeffs, t)
result = lambda_function(coeffs[0], coeffs[1], t)
np.testing.assert_allclose(result, expected)


Expand Down Expand Up @@ -75,7 +75,7 @@ def test_yeo_johnson_optimize_lambda(skew, bounds):

pt = PowerTransformerVariableLambda(standardize=False)
pt.coeffs_ = pt._yeo_johnson_optimize_lambda(local_monthly_residuals, yearly_T)
lmbda = lambda_function(pt.coeffs_, yearly_T)
lmbda = lambda_function(pt.coeffs_[0], pt.coeffs_[1], yearly_T)
transformed = pt._yeo_johnson_transform(local_monthly_residuals, lmbda)

assert (lmbda >= bounds[0]).all() & (lmbda <= bounds[1]).all()
Expand Down

0 comments on commit c7b736e

Please sign in to comment.