diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 01519cee3..85fb31b33 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,7 +1,7 @@ # Next Release ## Individual Updates - +- [#370](https://github.com/IAMconsortium/pyam/pull/370) Allowed filter to work with np.int64 years and np.datetime64 dates. - [#361](https://github.com/IAMconsortium/pyam/pull/361) iam-units refactored from a Git submodule to a Python dependency of pyam. # Release v0.5.0 diff --git a/pyam/core.py b/pyam/core.py index d0597430b..a3371d028 100755 --- a/pyam/core.py +++ b/pyam/core.py @@ -1178,10 +1178,10 @@ def filter(self, keep=True, inplace=False, **kwargs): string or list of strings, where `*` can be used as a wildcard - 'level': the maximum "depth" of IAM variables (number of '|') (excluding the strings given in the 'variable' argument) - - 'year': takes an integer, a list of integers or a range - note that the last year of a range is not included, + - 'year': takes an integer (int/np.int64), a list of integers or + a range. Note that the last year of a range is not included, so `range(2010, 2015)` is interpreted as `[2010, ..., 2014]` - - arguments for filtering by `datetime.datetime` + - arguments for filtering by `datetime.datetime` or np.datetime64 ('month', 'hour', 'time') - 'regexp=True' disables pseudo-regexp syntax in `pattern_match()` """ diff --git a/pyam/utils.py b/pyam/utils.py index 6fc01405c..74dbbfd28 100644 --- a/pyam/utils.py +++ b/pyam/utils.py @@ -353,8 +353,10 @@ def _escape_regexp(s): def years_match(data, years): """Return rows where data matches year""" - years = [years] if isinstance(years, int) else years - dt = datetime.datetime + years = [years] if ( + isinstance(years, (int, np.int64)) + ) else years + dt = (datetime.datetime, np.datetime64) if isinstance(years, dt) or isinstance(years[0], dt): error_msg = "`year` can only be filtered with ints or lists of ints" raise TypeError(error_msg) @@ -423,9 +425,11 @@ def conv_strs(strs_to_convert, conv_codes, name): def datetime_match(data, dts): """Matching of datetimes in time columns for data filtering""" dts = dts if islistable(dts) else [dts] - if any([not isinstance(i, datetime.datetime) for i in dts]): + if any([not ( + isinstance(i, (datetime.datetime, np.datetime64)) + ) for i in dts]): error_msg = ( - "`time` can only be filtered by datetimes" + "`time` can only be filtered by datetimes and datetime64s" ) raise TypeError(error_msg) return data.isin(dts) diff --git a/tests/test_core.py b/tests/test_core.py index 6f0508208..4711b0eb4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -317,6 +317,15 @@ def test_filter_day(test_df, test_day): assert unique_time[0] == expected +def test_filter_with_numpy_64_date_vals(test_df): + dates = test_df[test_df.time_col].unique() + key = 'year' if test_df.time_col == "year" else 'time' + res_0 = test_df.filter(**{key: dates[0]}) + res = test_df.filter(**{key: dates}) + assert np.equal(res_0.data[res_0.time_col].values, dates[0]).all() + assert res.equals(test_df) + + @pytest.mark.parametrize("test_hour", [0, 12, [12, 13]]) def test_filter_hour(test_df, test_hour): if "year" in test_df.data.columns: