Skip to content

Commit

Permalink
REF: simplify ohlc (pandas-dev#41091)
Browse files: browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and JulianWgs committed Jul 3, 2021
1 parent 77e90ac commit 77feba1
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 30 deletions.
31 changes: 5 additions & 26 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,20 +363,10 @@ def _cython_agg_general(
result = self.grouper._cython_operation(
"aggregate", obj._values, how, axis=0, min_count=min_count
)

if how == "ohlc":
# e.g. ohlc
agg_names = ["open", "high", "low", "close"]
assert len(agg_names) == result.shape[1]
for result_column, result_name in zip(result.T, agg_names):
key = base.OutputKey(label=result_name, position=idx)
output[key] = result_column
idx += 1
else:
assert result.ndim == 1
key = base.OutputKey(label=name, position=idx)
output[key] = result
idx += 1
assert result.ndim == 1
key = base.OutputKey(label=name, position=idx)
output[key] = result
idx += 1

if not output:
raise DataError("No numeric types to aggregate")
Expand Down Expand Up @@ -942,10 +932,6 @@ def count(self) -> Series:
)
return self._reindex_output(result, fill_value=0)

def _apply_to_column_groupbys(self, func):
""" return a pass thru """
return func(self)

def pct_change(self, periods=1, fill_method="pad", limit=None, freq=None):
"""Calculate pct_change of each value to previous entry in group"""
# TODO: Remove this conditional when #23918 is fixed
Expand Down Expand Up @@ -1137,6 +1123,7 @@ def _cython_agg_general(
def _cython_agg_manager(
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
) -> Manager2D:
# Note: we never get here with how="ohlc"; that goes through SeriesGroupBy

data: Manager2D = self._get_data_to_aggregate()

Expand Down Expand Up @@ -1227,21 +1214,13 @@ def array_func(values: ArrayLike) -> ArrayLike:
# generally if we have numeric_only=False
# and non-applicable functions
# try to python agg

if alt is None:
# we cannot perform the operation
# in an alternate way, exclude the block
assert how == "ohlc"
raise

result = py_fallback(values)

return cast_agg_result(result, values, how)
return result

# TypeError -> we may have an exception in trying to aggregate
# continue and exclude the block
# NotImplementedError -> "ohlc" with wrong dtype
new_mgr = data.grouped_reduce(array_func, ignore_failures=True)

if not len(new_mgr):
Expand Down
20 changes: 19 additions & 1 deletion pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1796,7 +1796,25 @@ def ohlc(self) -> DataFrame:
DataFrame
Open, high, low and close values within each group.
"""
return self._apply_to_column_groupbys(lambda x: x._cython_agg_general("ohlc"))
if self.obj.ndim == 1:
# self._iterate_slices() yields only self._selected_obj
obj = self._selected_obj

is_numeric = is_numeric_dtype(obj.dtype)
if not is_numeric:
raise DataError("No numeric types to aggregate")

res_values = self.grouper._cython_operation(
"aggregate", obj._values, "ohlc", axis=0, min_count=-1
)

agg_names = ["open", "high", "low", "close"]
result = self.obj._constructor_expanddim(
res_values, index=self.grouper.result_index, columns=agg_names
)
return self._reindex_output(result)

return self._apply_to_column_groupbys(lambda x: x.ohlc())

@final
@doc(DataFrame.describe)
Expand Down
8 changes: 5 additions & 3 deletions pandas/tests/resample/test_datetime_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,16 @@ def test_custom_grouper(index):
g = s.groupby(b)

# check all cython functions work
funcs = ["add", "mean", "prod", "ohlc", "min", "max", "var"]
g.ohlc() # doesn't use _cython_agg_general
funcs = ["add", "mean", "prod", "min", "max", "var"]
for f in funcs:
g._cython_agg_general(f)

b = Grouper(freq=Minute(5), closed="right", label="right")
g = s.groupby(b)
# check all cython functions work
funcs = ["add", "mean", "prod", "ohlc", "min", "max", "var"]
g.ohlc() # doesn't use _cython_agg_general
funcs = ["add", "mean", "prod", "min", "max", "var"]
for f in funcs:
g._cython_agg_general(f)

Expand All @@ -79,7 +81,7 @@ def test_custom_grouper(index):
idx = DatetimeIndex(idx, freq="5T")
expect = Series(arr, index=idx)

# GH2763 - return in put dtype if we can
# GH2763 - return input dtype if we can
result = g.agg(np.sum)
tm.assert_series_equal(result, expect)

Expand Down

0 comments on commit 77feba1

Please sign in to comment.