Skip to content

Commit

Permalink
Fixup remaining upstream failures (#1111)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored Jul 25, 2024
1 parent 184a22c commit 1c64671
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
8 changes: 8 additions & 0 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1571,6 +1571,8 @@ def std(
  )
  else:
  needs_time_conversion = True
+ if axis == 1:
+ numeric_dd = numeric_dd.astype(f"datetime64[{meta.array.unit}]")
  for col in time_cols:
  numeric_dd[col] = _convert_to_numeric(numeric_dd[col], skipna)
  else:
Expand All @@ -1583,6 +1585,11 @@ def std(
  units = [getattr(self._meta[c].array, "unit", None) for c in time_cols]

  if axis == 1:
+ _kwargs = (
+ {}
+ if not needs_time_conversion
+ else {"unit": meta.array.unit, "dtype": meta.dtype}
+ )
  return numeric_dd.map_partitions(
  M.std if not needs_time_conversion else _sqrt_and_convert_to_timedelta,
  meta=meta,
Expand All @@ -1591,6 +1598,7 @@ def std(
  ddof=ddof,
  enforce_metadata=False,
  numeric_only=numeric_only,
+ **_kwargs,
  )

  result = numeric_dd.var(
Expand Down
12 changes: 8 additions & 4 deletions dask_expr/_describe.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,25 @@ def _lower(self):
  frame.max(split_every=self.split_every),
  ]
  return DescribeNumericAggregate(
- self.frame._meta.name, is_td_col, is_dt_col, *stats
+ self.frame._meta.name,
+ is_td_col,
+ is_dt_col,
+ getattr(self.frame._meta.array, "unit", None),
+ *stats,
  )


  class DescribeNumericAggregate(Blockwise):
- _parameters = ["name", "is_timedelta_col", "is_datetime_col"]
+ _parameters = ["name", "is_timedelta_col", "is_datetime_col", "unit"]
  _defaults = {"is_timedelta_col": False, "is_datetime_col": False}

  def _broadcast_dep(self, dep):
  return dep.npartitions == 1

  @staticmethod
- def operation(name, is_timedelta_col, is_datetime_col, *stats):
+ def operation(name, is_timedelta_col, is_datetime_col, unit, *stats):
  return describe_numeric_aggregate(
- stats, name, is_timedelta_col, is_datetime_col
+ stats, name, is_timedelta_col, is_datetime_col, unit
  )


Expand Down

0 comments on commit 1c64671

Please sign in to comment.