diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index b317efb9a..002aff4da 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -1571,6 +1571,8 @@ def std( ) else: needs_time_conversion = True + if axis == 1: + numeric_dd = numeric_dd.astype(f"datetime64[{meta.array.unit}]") for col in time_cols: numeric_dd[col] = _convert_to_numeric(numeric_dd[col], skipna) else: @@ -1583,6 +1585,11 @@ def std( units = [getattr(self._meta[c].array, "unit", None) for c in time_cols] if axis == 1: + _kwargs = ( + {} + if not needs_time_conversion + else {"unit": meta.array.unit, "dtype": meta.dtype} + ) return numeric_dd.map_partitions( M.std if not needs_time_conversion else _sqrt_and_convert_to_timedelta, meta=meta, @@ -1591,6 +1598,7 @@ def std( ddof=ddof, enforce_metadata=False, numeric_only=numeric_only, + **_kwargs, ) result = numeric_dd.var( diff --git a/dask_expr/_describe.py b/dask_expr/_describe.py index 5df15a3a7..a1878b4aa 100644 --- a/dask_expr/_describe.py +++ b/dask_expr/_describe.py @@ -52,21 +52,25 @@ def _lower(self): frame.max(split_every=self.split_every), ] return DescribeNumericAggregate( - self.frame._meta.name, is_td_col, is_dt_col, *stats + self.frame._meta.name, + is_td_col, + is_dt_col, + getattr(self.frame._meta.array, "unit", None), + *stats, ) class DescribeNumericAggregate(Blockwise): - _parameters = ["name", "is_timedelta_col", "is_datetime_col"] + _parameters = ["name", "is_timedelta_col", "is_datetime_col", "unit"] _defaults = {"is_timedelta_col": False, "is_datetime_col": False} def _broadcast_dep(self, dep): return dep.npartitions == 1 @staticmethod - def operation(name, is_timedelta_col, is_datetime_col, *stats): + def operation(name, is_timedelta_col, is_datetime_col, unit, *stats): return describe_numeric_aggregate( - stats, name, is_timedelta_col, is_datetime_col + stats, name, is_timedelta_col, is_datetime_col, unit )