diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 14884f3d6..c70167082 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,4 +1,6 @@ # Next Release + +- [#548](https://github.com/IAMconsortium/pyam/pull/548) Add a `unit_mapping` attribute to show a variable-unit dictionary - [#546](https://github.com/IAMconsortium/pyam/pull/546) Fixed logging for recursive aggregation # Release v0.12.0 diff --git a/pyam/core.py b/pyam/core.py index 0775b7aa4..6b03d6314 100755 --- a/pyam/core.py +++ b/pyam/core.py @@ -366,6 +366,27 @@ def unit(self): """Return the list of (unique) units""" return get_index_levels(self._data, "unit") + @property + def unit_mapping(self): + """Return a dictionary of variables to (list of) corresponding units""" + + def list_or_str(x): + x = list(x.drop_duplicates()) + return x if len(x) > 1 else x[0] + + return ( + pd.DataFrame( + zip( + self._data.index.get_level_values("variable"), + self._data.index.get_level_values("unit"), + ), + columns=["variable", "unit"], + ) + .groupby("variable") + .apply(lambda u: list_or_str(u.unit)) + .to_dict() + ) + @property def data(self): """Return the timeseries data as a long :class:`pandas.DataFrame`""" @@ -456,6 +477,7 @@ def variables(self, include_units=False): return pd.Series(get_index_levels(self._data, _var), name=_var) # else construct dataframe from variable and unit levels + deprecation_warning("Use the attribute `unit_mapping` instead.") return ( pd.DataFrame( zip( diff --git a/tests/test_core.py b/tests/test_core.py index 9ee6106ca..38b4ab2ad 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -316,6 +316,14 @@ def test_index_attributes_extra_col(test_pd_df): assert df.subannual == ["summer", "winter"] +def test_unit_mapping(test_pd_df): + """assert that the `unit_mapping` returns the expected dictionary""" + test_pd_df.loc[2, "unit"] = "foo" # replace unit of one row of Primary Energy data + obs = IamDataFrame(test_pd_df).unit_mapping + + assert obs == {"Primary Energy": ["EJ/yr", 
"foo"], "Primary Energy|Coal": "EJ/yr"} + + def test_model(test_df): exp = pd.Series(data=["model_a"], name="model") pd.testing.assert_series_equal(test_df.models(), exp) diff --git a/tests/test_unfccc.py b/tests/test_unfccc.py index 7b8ef5754..6f7b82ab2 100644 --- a/tests/test_unfccc.py +++ b/tests/test_unfccc.py @@ -4,7 +4,7 @@ UNFCCC_DF = pd.DataFrame( - [[1990, 1738.137558], [1991, 1537.282312], [1992, 1499.067572]], + [[1990, 1609.25345], [1991, 1434.21149], [1992, 1398.38269]], columns=["year", "value"], )