Skip to content

Commit

Permalink
Add get_data_column(), refactor filtering by the time domain (#562)
Browse files Browse the repository at this point in the history
* Add a utility function `get_data_column()`

* Refactor the `filter()` function to avoid casting to `data`

* Add to release notes

* Improve the docstring

* Use new function in `unit_mapping`

* Implement changes per discussion with @gidden, improve the test
  • Loading branch information
danielhuppmann authored Jul 30, 2021
1 parent 666da23 commit a8c60d9
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 16 deletions.
1 change: 1 addition & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ the attribute `_LONG_IDX` as deprecated. Please use `dimensions` instead.

- [#564](https://github.com/IAMconsortium/pyam/pull/564) Add an example with a secondary axis to the plotting gallery
- [#563](https://github.com/IAMconsortium/pyam/pull/563) Enable `colors` keyword argument as list in `plot.pie()`
- [#562](https://github.com/IAMconsortium/pyam/pull/562) Add `get_data_column()`, refactor filtering by the time domain
- [#560](https://github.com/IAMconsortium/pyam/pull/560) Add a feature to `swap_year_for_time()`
- [#559](https://github.com/IAMconsortium/pyam/pull/559) Add attribute `dimensions`, fix compatibility with pandas v1.3
- [#557](https://github.com/IAMconsortium/pyam/pull/557) Swap time for year keeping subannual resolution
Expand Down
44 changes: 29 additions & 15 deletions pyam/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def __getitem__(self, key):
if set(_key_check).issubset(self.meta.columns):
return self.meta.__getitem__(key)
else:
return self.data.__getitem__(key)
return self.get_data_column(key)

def __len__(self):
return len(self._data)
Expand Down Expand Up @@ -365,10 +365,7 @@ def list_or_str(x):

return (
pd.DataFrame(
zip(
self._data.index.get_level_values("variable"),
self._data.index.get_level_values("unit"),
),
zip(self.get_data_column("variable"), self.get_data_column("unit")),
columns=["variable", "unit"],
)
.groupby("variable")
Expand All @@ -383,6 +380,22 @@ def data(self):
return pd.DataFrame([], columns=self.dimensions + ["value"])
return self._data.reset_index()

def get_data_column(self, column):
"""Return a `column` from the timeseries data in long format
Equivalent to :meth:`IamDataFrame.data[column] <IamDataFrame.data>`.
Parameters
----------
column : str
The column name.
Returns
-------
pd.Series
"""
return pd.Series(self._data.index.get_level_values(column), name=column)

@property
def dimensions(self):
"""Return the list of `data` columns (index names & data coordinates)"""
Expand Down Expand Up @@ -1721,16 +1734,15 @@ def _apply_filters(self, level=None, **filters):
cat_idx = self.meta[matches].index
keep_col = _make_index(self._data, unique=False).isin(cat_idx)
elif col == "year":
_data = (
self.data[col]
if self.time_col != "time"
else self.data["time"].apply(lambda x: x.year)
)
if self.time_col == "year":
_data = self.get_data_column(col)
else:
_data = self.get_data_column("time").apply(lambda x: x.year)
keep_col = years_match(_data, values)

elif col == "month" and self.time_col == "time":
keep_col = month_match(
self.data["time"].apply(lambda x: x.month), values
self.get_data_column("time").apply(lambda x: x.month), values
)

elif col == "day" and self.time_col == "time":
Expand All @@ -1742,17 +1754,19 @@ def _apply_filters(self, level=None, **filters):
wday = False

if wday:
days = self.data["time"].apply(lambda x: x.weekday())
days = self.get_data_column("time").apply(lambda x: x.weekday())
else: # ints or list of ints
days = self.data["time"].apply(lambda x: x.day)
days = self.get_data_column("time").apply(lambda x: x.day)

keep_col = day_match(days, values)

elif col == "hour" and self.time_col == "time":
keep_col = hour_match(self.data["time"].apply(lambda x: x.hour), values)
keep_col = hour_match(
self.get_data_column("time").apply(lambda x: x.hour), values
)

elif col == "time" and self.time_col == "time":
keep_col = datetime_match(self.data[col], values)
keep_col = datetime_match(self.get_data_column("time"), values)

elif col in self.dimensions:
lvl_index, lvl_codes = get_index_levels_codes(self._data, col)
Expand Down
14 changes: 13 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,9 @@ def test_equals_raises(test_pd_df):


def test_get_item(test_df):
assert test_df["model"].unique() == ["model_a"]
"""Assert that getting a column from `data` via the direct getter works"""
pdt.assert_series_equal(test_df["model"], test_df.data["model"])
pdt.assert_series_equal(test_df["variable"], test_df.data["variable"])


def test_index(test_df_year):
Expand Down Expand Up @@ -330,6 +332,16 @@ def test_dimensions(test_df):
assert test_df._LONG_IDX == IAMC_IDX + [test_df.time_col]


def test_get_data_column(test_df):
"""Assert that getting a column from the `data` dataframe works"""

obs = test_df.get_data_column("model")
pdt.assert_series_equal(obs, pd.Series(["model_a"] * 6, name="model"))

obs = test_df.get_data_column(test_df.time_col)
pdt.assert_series_equal(obs, test_df.data[test_df.time_col])


def test_filter_empty_df():
# test for issue seen in #254
df = IamDataFrame(data=df_empty)
Expand Down

0 comments on commit a8c60d9

Please sign in to comment.