Skip to content

Commit

Permalink
Support IntervalDtype(subtype=None) (#18017)
Browse files Browse the repository at this point in the history
closes #17997

This will help unblock #17978, where we will need to interpret `dtype="interval"` as "interval without a subtype" instead of "interval with a float64 subtype".

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #18017
  • Loading branch information
mroeschke authored Feb 25, 2025
1 parent 8c7eecf commit 5f71f76
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 59 deletions.
2 changes: 2 additions & 0 deletions docs/cudf/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,8 @@ def on_missing_reference(app, env, node, contnode):
("py:class", "pyarrow.lib.ChunkedArray"),
("py:class", "pyarrow.lib.Array"),
("py:class", "ColumnLike"),
("py:class", "DtypeObj"),
("py:class", "pa.StructType"),
# TODO: Remove this when we figure out why typing_extensions doesn't seem
# to map types correctly for intersphinx
("py:class", "typing_extensions.Self"),
Expand Down
8 changes: 3 additions & 5 deletions python/cudf/cudf/core/column/interval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018-2024, NVIDIA CORPORATION.
# Copyright (c) 2018-2025, NVIDIA CORPORATION.
from __future__ import annotations

from typing import TYPE_CHECKING, Literal
Expand Down Expand Up @@ -105,9 +105,7 @@ def copy(self, deep: bool = True) -> Self:
return IntervalColumn( # type: ignore[return-value]
data=None,
size=struct_copy.size,
dtype=IntervalDtype(
struct_copy.dtype.fields["left"], self.dtype.closed
),
dtype=IntervalDtype(self.dtype.subtype, self.dtype.closed),
mask=struct_copy.base_mask,
offset=struct_copy.offset,
null_count=struct_copy.null_count,
Expand Down Expand Up @@ -163,7 +161,7 @@ def set_closed(
return IntervalColumn( # type: ignore[return-value]
data=None,
size=self.size,
dtype=IntervalDtype(self.dtype.fields["left"], closed),
dtype=IntervalDtype(self.dtype.subtype, closed),
mask=self.base_mask,
offset=self.offset,
null_count=self.null_count,
Expand Down
131 changes: 79 additions & 52 deletions python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import textwrap
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
import pandas as pd
Expand All @@ -19,7 +19,11 @@
from cudf.core._compat import PANDAS_GE_210, PANDAS_LT_300
from cudf.core.abc import Serializable
from cudf.utils.docutils import doc_apply
from cudf.utils.dtypes import CUDF_STRING_DTYPE, cudf_dtype_from_pa_type
from cudf.utils.dtypes import (
CUDF_STRING_DTYPE,
cudf_dtype_from_pa_type,
cudf_dtype_to_pa_type,
)

if PANDAS_GE_210:
PANDAS_NUMPY_DTYPE = pd.core.dtypes.dtypes.NumpyEADtype
Expand All @@ -29,7 +33,9 @@
if TYPE_CHECKING:
from collections.abc import Callable

from cudf._typing import Dtype
from typing_extensions import Self

from cudf._typing import Dtype, DtypeObj
from cudf.core.buffer import Buffer


Expand Down Expand Up @@ -573,15 +579,11 @@ class StructDtype(_BaseDtype):

name = "struct"

def __init__(self, fields):
pa_fields = {
k: cudf.utils.dtypes.cudf_dtype_to_pa_type(cudf.dtype(v))
for k, v in fields.items()
}
self._typ = pa.struct(pa_fields)
def __init__(self, fields: dict[str, Dtype]) -> None:
self._fields = {k: cudf.dtype(v) for k, v in fields.items()}

@property
def fields(self):
def fields(self) -> dict[str, DtypeObj]:
"""
Returns an ordered dict of column name and dtype key-value.
Expand All @@ -594,10 +596,7 @@ def fields(self):
>>> struct_dtype.fields
{'a': dtype('int64'), 'b': dtype('O')}
"""
return {
field.name: cudf.utils.dtypes.cudf_dtype_from_pa_type(field.type)
for field in self._typ
}
return self._fields

@property
def type(self):
Expand All @@ -606,7 +605,7 @@ def type(self):
return dict

@classmethod
def from_arrow(cls, typ):
def from_arrow(cls, typ: pa.StructType) -> Self:
"""
Convert a ``pyarrow.StructType`` to ``StructDtype``.
Expand All @@ -620,11 +619,19 @@ def from_arrow(cls, typ):
>>> cudf.StructDtype.from_arrow(pa_struct_type)
StructDtype({'x': dtype('int32'), 'y': dtype('O')})
"""
obj = object.__new__(cls)
obj._typ = typ
return obj
return cls(
{
typ.field(i).name: cudf_dtype_from_pa_type(typ.field(i).type)
for i in range(typ.num_fields)
}
# Once pyarrow 18 is the min version, replace with this version
# {
# field.name: cudf_dtype_from_pa_type(field.type)
# for field in typ.fields
# }
)

def to_arrow(self):
def to_arrow(self) -> pa.StructType:
"""
Convert a ``StructDtype`` to a ``pyarrow.StructType``.
Expand All @@ -637,20 +644,25 @@ def to_arrow(self):
>>> struct_type.to_arrow()
StructType(struct<x: int32, y: string>)
"""
return self._typ
return pa.struct(
{
k: cudf_dtype_to_pa_type(dtype)
for k, dtype in self.fields.items()
}
)

def __eq__(self, other):
def __eq__(self, other) -> bool:
if isinstance(other, str):
return other == self.name
if not isinstance(other, StructDtype):
return False
return self._typ.equals(other._typ)
return self.to_arrow().equals(other.to_arrow())

def __repr__(self):
def __repr__(self) -> str:
return f"{type(self).__name__}({self.fields})"

def __hash__(self):
return hash(self._typ)
def __hash__(self) -> int:
return hash(self.to_arrow())

def serialize(self) -> tuple[dict, list]:
header: dict[str, Any] = {}
Expand All @@ -674,7 +686,7 @@ def serialize(self) -> tuple[dict, list]:
return header, frames

@classmethod
def deserialize(cls, header: dict, frames: list):
def deserialize(cls, header: dict, frames: list) -> Self:
_check_type(cls, header, frames)
fields = {}
for k, dtype in header["fields"].items():
Expand All @@ -689,11 +701,8 @@ def deserialize(cls, header: dict, frames: list):
return cls(fields)

@cached_property
def itemsize(self):
return sum(
cudf.utils.dtypes.cudf_dtype_from_pa_type(field.type).itemsize
for field in self._typ
)
def itemsize(self) -> int:
return sum(field.itemsize for field in self.fields.values())

def _recursively_replace_fields(self, result: dict) -> dict:
"""
Expand Down Expand Up @@ -926,6 +935,10 @@ class Decimal128Dtype(DecimalDtype):

class IntervalDtype(StructDtype):
"""
A data type for Interval data.
Parameters
----------
subtype: str, np.dtype, optional
The dtype of the Interval bounds. If None (the default), the
interval dtype has no subtype.
closed: {'right', 'left', 'both', 'neither'}, default 'right'
Expand All @@ -935,43 +948,55 @@ class IntervalDtype(StructDtype):

name = "interval"

def __init__(self, subtype, closed="right"):
super().__init__(fields={"left": subtype, "right": subtype})

if closed is None:
closed = "right"
if closed in ["left", "right", "neither", "both"]:
def __init__(
self,
subtype: None | Dtype = None,
closed: Literal["left", "right", "neither", "both"] = "right",
) -> None:
if closed in {"left", "right", "neither", "both"}:
self.closed = closed
else:
raise ValueError("closed value is not valid")
raise ValueError(f"{closed=} is not valid")
if subtype is None:
self._subtype = None
dtypes = {}
else:
self._subtype = cudf.dtype(subtype)
dtypes = {"left": self._subtype, "right": self._subtype}
super().__init__(dtypes)

@property
def subtype(self):
return self.fields["left"]
def subtype(self) -> DtypeObj | None:
return self._subtype

def __repr__(self) -> str:
if self.subtype is None:
return "interval"
return f"interval[{self.subtype}, {self.closed}]"

def __str__(self) -> str:
return self.__repr__()
return repr(self)

@classmethod
def from_arrow(cls, typ):
return IntervalDtype(typ.subtype.to_pandas_dtype(), typ.closed)
def from_arrow(cls, typ: ArrowIntervalType) -> Self:
return cls(typ.subtype.to_pandas_dtype(), typ.closed)

def to_arrow(self):
def to_arrow(self) -> ArrowIntervalType:
return ArrowIntervalType(
pa.from_numpy_dtype(self.subtype), self.closed
cudf_dtype_to_pa_type(self.subtype), self.closed
)

@classmethod
def from_pandas(cls, pd_dtype: pd.IntervalDtype) -> "IntervalDtype":
return cls(subtype=pd_dtype.subtype, closed=pd_dtype.closed)
def from_pandas(cls, pd_dtype: pd.IntervalDtype) -> Self:
return cls(
subtype=pd_dtype.subtype,
closed="right" if pd_dtype.closed is None else pd_dtype.closed,
)

def to_pandas(self) -> pd.IntervalDtype:
return pd.IntervalDtype(subtype=self.subtype, closed=self.closed)

def __eq__(self, other):
def __eq__(self, other) -> bool:
if isinstance(other, str):
# This means equality isn't transitive but mimics pandas
return other in (self.name, str(self))
Expand All @@ -981,21 +1006,23 @@ def __eq__(self, other):
and self.closed == other.closed
)

def __hash__(self):
def __hash__(self) -> int:
return hash((self.subtype, self.closed))

def serialize(self) -> tuple[dict, list]:
header = {
"fields": (self.subtype.str, self.closed),
"fields": (
self.subtype.str if self.subtype is not None else self.subtype,
self.closed,
),
"frame_count": 0,
}
return header, []

@classmethod
def deserialize(cls, header: dict, frames: list):
def deserialize(cls, header: dict, frames: list) -> Self:
_check_type(cls, header, frames)
subtype, closed = header["fields"]
subtype = np.dtype(subtype)
return cls(subtype, closed=closed)


Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3517,7 +3517,7 @@ def _from_column(
def from_breaks(
cls,
breaks,
closed: Literal["left", "right", "neither", "both"] | None = "right",
closed: Literal["left", "right", "neither", "both"] = "right",
name=None,
copy: bool = False,
dtype=None,
Expand Down
11 changes: 10 additions & 1 deletion python/cudf/cudf/tests/test_interval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
# Copyright (c) 2020-2025, NVIDIA CORPORATION.


import numpy as np
Expand Down Expand Up @@ -210,3 +210,12 @@ def test_reduction_return_interval_pandas_compatible():
result = cudf_ii.min()
expected = ii.min()
assert result == expected


def test_empty_intervaldtype():
# "older pandas" supported closed=None, cudf chooses not to support that
pd_id = pd.IntervalDtype(closed="right")
cudf_id = cudf.IntervalDtype()

assert str(pd_id) == str(cudf_id)
assert pd_id.subtype == cudf_id.subtype

0 comments on commit 5f71f76

Please sign in to comment.