Skip to content

Commit

Permalink
Implement cumulative aggregation (dask#433)
Browse files Browse the repository at this point in the history
Co-authored-by: crusaderky <crusaderky@gmail.com>
  • Loading branch information
phofl and crusaderky authored Dec 12, 2023
1 parent 9f76576 commit fa515d2
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 0 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ API Coverage
- `combine_first`
- `copy`
- `count`
- `cummax`
- `cummin`
- `cumprod`
- `cumsum`
- `dask`
- `div`
- `divide`
Expand Down Expand Up @@ -169,6 +173,10 @@ API Coverage
- `combine_first`
- `copy`
- `count`
- `cummax`
- `cummin`
- `cumprod`
- `cumsum`
- `dask`
- `div`
- `divide`
Expand Down Expand Up @@ -204,6 +212,7 @@ API Coverage
- `partitions`
- `pow`
- `prod`
- `product`
- `radd`
- `rdiv`
- `rename`
Expand Down Expand Up @@ -295,6 +304,7 @@ API Coverage
- `mean`
- `median`
- `min`
- `nunique`
- `prod`
- `shift`
- `size`
Expand Down
12 changes: 12 additions & 0 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,18 @@ def align(self, other, join="outer", fill_value=None):
def nunique_approx(self):
return new_collection(self.expr.nunique_approx())

def cumsum(self, skipna=True):
return new_collection(self.expr.cumsum(skipna=skipna))

def cumprod(self, skipna=True):
return new_collection(self.expr.cumprod(skipna=skipna))

def cummax(self, skipna=True):
return new_collection(self.expr.cummax(skipna=skipna))

def cummin(self, skipna=True):
return new_collection(self.expr.cummin(skipna=skipna))

def memory_usage_per_partition(self, index=True, deep=False):
return new_collection(self.expr.memory_usage_per_partition(index, deep))

Expand Down
116 changes: 116 additions & 0 deletions dask_expr/_cumulative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import functools

from dask.dataframe import methods
from dask.utils import M

from dask_expr._expr import Blockwise, Expr, Projection


class CumulativeAggregations(Expr):
_parameters = ["frame", "axis", "skipna"]
_defaults = {"axis": None}

chunk_operation = None
aggregate_operation = None

def _divisions(self):
return self.frame._divisions()

@functools.cached_property
def _meta(self):
return self.frame._meta

def _lower(self):
chunks = CumulativeBlockwise(
self.frame, self.axis, self.skipna, self.chunk_operation
)
chunks_last = TakeLast(chunks, self.skipna)
return CumulativeFinalize(chunks, chunks_last, self.aggregate_operation)

def _simplify_up(self, parent):
if isinstance(parent, Projection):
return type(self)(self.frame[parent.operand("columns")], *self.operands[1:])


class CumulativeBlockwise(Blockwise):
_parameters = ["frame", "axis", "skipna", "operation"]
_defaults = {"skipna": True, "axis": None}
_projection_passthrough = True

@functools.cached_property
def _meta(self):
return self.frame._meta

@functools.cached_property
def operation(self):
return self.operand("operation")

@functools.cached_property
def _args(self) -> list:
return self.operands[:-1]


class TakeLast(Blockwise):
_parameters = ["frame", "skipna"]
_projection_passthrough = True

@staticmethod
def operation(a, skipna=True):
if skipna:
a = a.bfill()
return a.tail(n=1).squeeze()


class CumulativeFinalize(Expr):
_parameters = ["frame", "previous_partitions", "aggregator"]

def _divisions(self):
return self.frame._divisions()

@functools.cached_property
def _meta(self):
return self.frame._meta

def _layer(self) -> dict:
dsk = {}
frame, previous_partitions = self.frame, self.previous_partitions
dsk[(self._name, 0)] = (frame._name, 0)

intermediate_name = self._name + "-intermediate"
for i in range(1, self.frame.npartitions):
if i == 1:
dsk[(intermediate_name, i)] = (previous_partitions._name, i - 1)
else:
# aggregate with previous cumulation results
dsk[(intermediate_name, i)] = (
methods._cum_aggregate_apply,
self.aggregator,
(intermediate_name, i - 1),
(previous_partitions._name, i - 1),
)
dsk[(self._name, i)] = (
self.aggregator,
(self.frame._name, i),
(intermediate_name, i),
)
return dsk


class CumSum(CumulativeAggregations):
chunk_operation = M.cumsum
aggregate_operation = staticmethod(methods.cumsum_aggregate)


class CumProd(CumulativeAggregations):
chunk_operation = M.cumprod
aggregate_operation = staticmethod(methods.cumprod_aggregate)


class CumMax(CumulativeAggregations):
chunk_operation = M.cummax
aggregate_operation = staticmethod(methods.cummax_aggregate)


class CumMin(CumulativeAggregations):
chunk_operation = M.cummin
aggregate_operation = staticmethod(methods.cummin_aggregate)
20 changes: 20 additions & 0 deletions dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,26 @@ def min(self, skipna=True, numeric_only=False, min_count=0):
def count(self, numeric_only=False):
return Count(self, numeric_only)

def cumsum(self, skipna=True):
from dask_expr._cumulative import CumSum

return CumSum(self, skipna=skipna)

def cumprod(self, skipna=True):
from dask_expr._cumulative import CumProd

return CumProd(self, skipna=skipna)

def cummax(self, skipna=True):
from dask_expr._cumulative import CumMax

return CumMax(self, skipna=skipna)

def cummin(self, skipna=True):
from dask_expr._cumulative import CumMin

return CumMin(self, skipna=skipna)

def abs(self):
return Abs(self)

Expand Down
18 changes: 18 additions & 0 deletions dask_expr/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,24 @@ def test_std_kwargs(axis, skipna, ddof):
)


@pytest.mark.parametrize("func", ["cumsum", "cumprod", "cummin", "cummax"])
def test_cumulative_methods(df, pdf, func):
assert_eq(getattr(df, func)(), getattr(pdf, func)(), check_dtype=False)
assert_eq(getattr(df.x, func)(), getattr(pdf.x, func)())

q = getattr(df, func)()["x"]
assert q.simplify()._name == getattr(df.x, func)()

pdf.loc[slice(None, None, 2), "x"] = np.nan
df = from_pandas(pdf, npartitions=10)
assert_eq(
getattr(df, func)(skipna=False),
getattr(pdf, func)(skipna=False),
check_dtype=False,
)
assert_eq(getattr(df.x, func)(skipna=False), getattr(pdf.x, func)(skipna=False))


@xfail_gpu("nbytes not supported by cudf")
def test_nbytes(pdf, df):
with pytest.raises(NotImplementedError, match="nbytes is not implemented"):
Expand Down

0 comments on commit fa515d2

Please sign in to comment.