Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use LinearRegression in train_gt_ic_OLSVOLC #145

Merged
merged 3 commits into from
May 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ Internal Changes
- Renamed the ``interpolation`` keyword of ``np.quantile`` to ``method`` changed in
numpy v1.22.0 (`#137 <https://github.com/MESMER-group/mesmer/pull/137>`_).
By `Mathias Hauser <https://github.com/mathause>`_.
- Make use of :py:class:`mesmer.core.linear_regression.LinearRegression` in
:py:func:`mesmer.calibrate_mesmer.train_gt_ic_OLSVOLC` (`#145 <https://github.com/MESMER-group/mesmer/pull/145>`_).
By `Mathias Hauser <https://github.com/mathause>`_.

v0.8.3 - 2021-12-23
-------------------
Expand Down
46 changes: 26 additions & 20 deletions mesmer/calibrate_mesmer/train_gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@

import joblib
import numpy as np
from sklearn.linear_model import LinearRegression
import xarray as xr
from statsmodels.nonparametric.smoothers_lowess import lowess

from mesmer.core.linear_regression import LinearRegression
from mesmer.io import load_strat_aod


Expand Down Expand Up @@ -237,41 +238,46 @@ def train_gt_ic_OLSVOLC(var, gt_lowess, time, cfg):

# account for volcanic eruptions in historical time period
# load in observed stratospheric aerosol optical depth
aod_obs = load_strat_aod(time, dir_obs).reshape(
-1, 1
) # bring in correct format for sklearn linear regression
aod_obs_all = np.tile(
aod_obs, (nr_runs, 1)
) # repeat aod time series as many times as runs available
nr_aod_obs = len(aod_obs)
aod_obs = load_strat_aod(time, dir_obs)
# drop "year" coords - aod_obs does not have coords (currently)
aod_obs = aod_obs.drop_vars("year")

# repeat aod time series as many times as runs available
aod_obs_all = xr.concat([aod_obs] * nr_runs, dim="year")

nr_aod_obs = aod_obs.shape[0]
if nr_ts != nr_aod_obs:
raise ValueError(
f"The number of time steps of the variable ({nr_ts}) and the saod "
"({nr_aod_obs}) do not match."
f"({nr_aod_obs}) do not match."
)

# extract global variability (which still includes volc eruptions) by removing
# smooth trend from Tglob in historic period
gv_all_for_aod = np.zeros(nr_runs * nr_aod_obs)
i = 0
for run in np.arange(nr_runs):
gv_all_for_aod[i : i + nr_aod_obs] = var[run] - gt_lowess
i += nr_aod_obs
# fit linear regression of gv to aod (because some ESMs react very strongly to
# (should broadcast, and flatten the correct way - hopefully)
gv_all_for_aod = (var - gt_lowess).ravel()

gv_all_for_aod = xr.DataArray(gv_all_for_aod, dims="year").expand_dims("x")

lr = LinearRegression()

# fit linear regression of gt to aod (because some ESMs react very strongly to
# volcanoes)
# no intercept to not artifically move the ts
linreg_gv_volc = LinearRegression(fit_intercept=False).fit(
aod_obs_all, gv_all_for_aod
lr.fit(
predictors={"aod_obs": aod_obs_all},
target=gv_all_for_aod,
dim="year",
fit_intercept=False,
)

# extract the saod coefficient
coef_saod = linreg_gv_volc.coef_[0]
coef_saod = lr.params["aod_obs"].values

# apply linear regression model to obtain volcanic spikes
contrib_volc = linreg_gv_volc.predict(aod_obs)
contrib_volc = lr.predict(predictors={"aod_obs": aod_obs_all})

# merge the lowess trend wit the volc contribution
gt = gt_lowess + contrib_volc
gt = gt_lowess + contrib_volc.values.squeeze()

return coef_saod, gt
17 changes: 8 additions & 9 deletions mesmer/io/load_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,20 +182,19 @@ def load_strat_aod(time, dir_obs):
"""

path_file = dir_obs + "aerosols/isaod_gl.dat"
ts = pd.read_csv(
path_file, delim_whitespace=True, skiprows=11, names=("year", "month", "AOD")
df = pd.read_csv(
path_file,
delim_whitespace=True,
skiprows=11,
names=("year", "month", "AOD"),
parse_dates=[["year", "month"]],
)

beg = str(ts["year"].iloc[0]) + "-" + str(ts["month"].iloc[0])
end = str(ts["year"].iloc[-1]) + "-" + str(ts["month"].iloc[-1])
range = pd.to_datetime([beg, end]) + pd.offsets.MonthEnd()
date_range = pd.date_range(*range, freq="m")

aod_obs = xr.DataArray(
ts["AOD"].values, dims=("time",), coords=dict(time=("time", date_range))
df["AOD"], dims=("time",), coords=dict(time=("time", df["year_month"]))
)

aod_obs = aod_obs.groupby("time.year").mean("time")
aod_obs = aod_obs.sel(year=slice(str(time[0]), str(time[-1]))).values
aod_obs = aod_obs.sel(year=slice(time[0], time[-1]))

return aod_obs