GH-38325: [Python] Implement PyCapsule interface for Device data in PyArrow #40717

Merged
98 changes: 97 additions & 1 deletion python/pyarrow/array.pxi
@@ -129,7 +129,8 @@ def array(object obj, type=None, mask=None, size=None, from_pandas=None,
If both type and size are specified, may be a single-use iterable. If
not strongly-typed, Arrow type will be inferred for resulting array.
Any Arrow-compatible array that implements the Arrow PyCapsule Protocol
(has an ``__arrow_c_array__`` method) can be passed as well.
(has an ``__arrow_c_array__`` or ``__arrow_c_device_array__`` method)
can be passed as well.
type : pyarrow.DataType
Explicit type to attempt to coerce to, otherwise will be inferred from
the data.
@@ -245,6 +246,18 @@ def array(object obj, type=None, mask=None, size=None, from_pandas=None,

if hasattr(obj, '__arrow_array__'):
return _handle_arrow_array_protocol(obj, type, mask, size)
elif hasattr(obj, '__arrow_c_device_array__'):
if type is not None:
requested_type = type.__arrow_c_schema__()
else:
requested_type = None
schema_capsule, array_capsule = obj.__arrow_c_device_array__(requested_type)
out_array = Array._import_from_c_device_capsule(schema_capsule, array_capsule)
if type is not None and out_array.type != type:
# PyCapsule interface type coercion is best effort, so we need to
# check the type of the returned array and cast if necessary
out_array = out_array.cast(type, safe=safe, memory_pool=memory_pool)
return out_array
elif hasattr(obj, '__arrow_c_array__'):
if type is not None:
requested_type = type.__arrow_c_schema__()
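
To make the new consumer path concrete, here is a minimal sketch (illustrative only; the `DeviceArrayWrapper` class is hypothetical and mirrors the test wrapper added at the bottom of this PR):

```python
import pyarrow as pa

class DeviceArrayWrapper:
    """Hypothetical producer exposing only the device-aware protocol."""

    def __init__(self, data):
        self.data = data

    def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
        # delegate to a pyarrow array, which itself implements the protocol
        return self.data.__arrow_c_device_array__(requested_schema, **kwargs)

wrapper = DeviceArrayWrapper(pa.array([1, 2, 3], type=pa.int64()))

# pa.array() now detects __arrow_c_device_array__ and imports via capsules
result = pa.array(wrapper)
assert result == pa.array([1, 2, 3], type=pa.int64())

# a requested type is forwarded to the producer, which casts if it can
result32 = pa.array(wrapper, type=pa.int32())
assert result32.type == pa.int32()
```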
@@ -1879,6 +1892,89 @@ cdef class Array(_PandasConvertible):
)
return pyarrow_wrap_array(c_array)

def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
"""
Get a pair of PyCapsules containing a C ArrowDeviceArray representation
of the object.

Parameters
----------
requested_schema : PyCapsule | None
A PyCapsule containing a C ArrowSchema representation of a requested
schema. PyArrow will attempt to cast the array to this data type.
If None, the array will be returned as-is, with a type matching the
one returned by :meth:`__arrow_c_schema__()`.
kwargs
Currently no additional keyword arguments are supported, but
this method will accept any keyword with a value of ``None``
for compatibility with future keywords.

Returns
-------
Tuple[PyCapsule, PyCapsule]
A pair of PyCapsules containing a C ArrowSchema and ArrowDeviceArray,
respectively.
"""
cdef:
ArrowDeviceArray* c_array
ArrowSchema* c_schema
shared_ptr[CArray] inner_array

non_default_kwargs = [
name for name, value in kwargs.items() if value is not None
]
if non_default_kwargs:
raise NotImplementedError(
f"Received unsupported keyword argument(s): {non_default_kwargs}"
)

if requested_schema is not None:
target_type = DataType._import_from_c_capsule(requested_schema)

if target_type != self.type:
if not self.is_cpu:
raise NotImplementedError(
"Casting to a requested schema is only supported for CPU data"
)
try:
casted_array = _pc().cast(self, target_type, safe=True)
inner_array = pyarrow_unwrap_array(casted_array)
except ArrowInvalid as e:
raise ValueError(
f"Could not cast {self.type} to requested type {target_type}: {e}"
)
else:
inner_array = self.sp_array
else:
inner_array = self.sp_array

schema_capsule = alloc_c_schema(&c_schema)
array_capsule = alloc_c_device_array(&c_array)

with nogil:
check_status(ExportDeviceArray(
deref(inner_array), <shared_ptr[CSyncEvent]>NULL,
c_array, c_schema))

return schema_capsule, array_capsule

@staticmethod
def _import_from_c_device_capsule(schema_capsule, array_capsule):
cdef:
ArrowSchema* c_schema
ArrowDeviceArray* c_array
shared_ptr[CArray] array

c_schema = <ArrowSchema*> PyCapsule_GetPointer(schema_capsule, 'arrow_schema')
c_array = <ArrowDeviceArray*> PyCapsule_GetPointer(
array_capsule, 'arrow_device_array'
)

with nogil:
array = GetResultValue(ImportDeviceArray(c_array, c_schema))

return pyarrow_wrap_array(array)

def __dlpack__(self, stream=None):
"""Export a primitive array as a DLPack capsule.

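Taken together, the export and import halves above round-trip an array through the device protocol. A short sketch, using the private `_import_from_c_device_capsule` exactly as defined in this diff:

```python
import pyarrow as pa

arr = pa.array([1.0, 2.0, 3.0])

# Export yields a pair of (ArrowSchema, ArrowDeviceArray) PyCapsules.
# Unknown keywords must be passed with a value of None, otherwise the
# method raises NotImplementedError (forward compatibility).
schema_capsule, array_capsule = arr.__arrow_c_device_array__()

# Import consumes the capsules; this is the private counterpart that
# pa.array() calls internally for device-protocol objects.
roundtripped = pa.Array._import_from_c_device_capsule(schema_capsule, array_capsule)
assert roundtripped.equals(arr)
```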
2 changes: 2 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
@@ -1016,6 +1016,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
int num_columns()
int64_t num_rows()

CDeviceAllocationType device_type()

CStatus Validate() const
CStatus ValidateFull() const

146 changes: 136 additions & 10 deletions python/pyarrow/table.pxi
@@ -3619,7 +3619,7 @@ cdef class RecordBatch(_Tabular):
requested_schema : PyCapsule | None
A PyCapsule containing a C ArrowSchema representation of a requested
schema. PyArrow will attempt to cast the batch to this schema.
If None, the schema will be returned as-is, with a schema matching the
If None, the batch will be returned as-is, with a schema matching the
one returned by :meth:`__arrow_c_schema__()`.

Returns
@@ -3637,9 +3637,7 @@ cdef class RecordBatch(_Tabular):

if target_schema != self.schema:
try:
# We don't expose .cast() on RecordBatch, only on Table.
casted_batch = Table.from_batches([self]).cast(
target_schema, safe=True).to_batches()[0]
casted_batch = self.cast(target_schema, safe=True)
inner_batch = pyarrow_unwrap_batch(casted_batch)
except ArrowInvalid as e:
raise ValueError(
@@ -3680,8 +3678,8 @@ cdef class RecordBatch(_Tabular):
@staticmethod
def _import_from_c_capsule(schema_capsule, array_capsule):
"""
Import RecordBatch from a pair of PyCapsules containing a C ArrowArray
and ArrowSchema, respectively.
Import RecordBatch from a pair of PyCapsules containing a C ArrowSchema
and ArrowArray, respectively.

Parameters
----------
@@ -3772,6 +3770,121 @@ cdef class RecordBatch(_Tabular):
c_device_array, c_schema))
return pyarrow_wrap_batch(c_batch)

def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
"""
Get a pair of PyCapsules containing a C ArrowDeviceArray representation
of the object.

Parameters
----------
requested_schema : PyCapsule | None
A PyCapsule containing a C ArrowSchema representation of a requested
schema. PyArrow will attempt to cast the batch to this schema.
If None, the batch will be returned as-is, with a schema matching the
one returned by :meth:`__arrow_c_schema__()`.
kwargs
Currently no additional keyword arguments are supported, but
this method will accept any keyword with a value of ``None``
for compatibility with future keywords.

Returns
-------
Tuple[PyCapsule, PyCapsule]
A pair of PyCapsules containing a C ArrowSchema and ArrowDeviceArray,
respectively.
"""
cdef:
ArrowDeviceArray* c_array
ArrowSchema* c_schema
shared_ptr[CRecordBatch] inner_batch

non_default_kwargs = [
name for name, value in kwargs.items() if value is not None
]
if non_default_kwargs:
raise NotImplementedError(
f"Received unsupported keyword argument(s): {non_default_kwargs}"
)

if requested_schema is not None:
target_schema = Schema._import_from_c_capsule(requested_schema)

if target_schema != self.schema:
if not self.is_cpu:
raise NotImplementedError(
"Casting to a requested schema is only supported for CPU data"
)
try:
casted_batch = self.cast(target_schema, safe=True)
inner_batch = pyarrow_unwrap_batch(casted_batch)
except ArrowInvalid as e:
raise ValueError(
f"Could not cast {self.schema} to requested schema {target_schema}: {e}"
)
else:
inner_batch = self.sp_batch
else:
inner_batch = self.sp_batch

schema_capsule = alloc_c_schema(&c_schema)
array_capsule = alloc_c_device_array(&c_array)

with nogil:
check_status(ExportDeviceRecordBatch(
deref(inner_batch), <shared_ptr[CSyncEvent]>NULL, c_array, c_schema))

return schema_capsule, array_capsule

@staticmethod
def _import_from_c_device_capsule(schema_capsule, array_capsule):
"""
Import RecordBatch from a pair of PyCapsules containing a
C ArrowSchema and ArrowDeviceArray, respectively.

Parameters
----------
schema_capsule : PyCapsule
A PyCapsule containing a C ArrowSchema representation of the schema.
array_capsule : PyCapsule
A PyCapsule containing a C ArrowDeviceArray representation of the array.

Returns
-------
pyarrow.RecordBatch
"""
cdef:
ArrowSchema* c_schema
ArrowDeviceArray* c_array
shared_ptr[CRecordBatch] batch

c_schema = <ArrowSchema*> PyCapsule_GetPointer(schema_capsule, 'arrow_schema')
c_array = <ArrowDeviceArray*> PyCapsule_GetPointer(
array_capsule, 'arrow_device_array'
)

with nogil:
batch = GetResultValue(ImportDeviceRecordBatch(c_array, c_schema))

return pyarrow_wrap_batch(batch)

@property
def device_type(self):
"""
The device type where the arrays in the RecordBatch reside.

Returns
-------
DeviceAllocationType
"""
return _wrap_device_allocation_type(self.sp_batch.get().device_type())

@property
def is_cpu(self):
"""
Whether the RecordBatch's arrays are CPU-accessible.
"""
return self.device_type == DeviceAllocationType.CPU
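
A quick sketch of the two new properties on CPU-backed data, assuming `DeviceAllocationType` is re-exported at the top level as `pa.DeviceAllocationType` (as in recent releases):

```python
import pyarrow as pa

batch = pa.record_batch({"a": [1, 2, 3]})

# CPU-resident data reports DeviceAllocationType.CPU; is_cpu is
# derived directly from device_type, as shown in the diff above.
assert batch.device_type == pa.DeviceAllocationType.CPU
assert batch.is_cpu
```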


def _reconstruct_record_batch(columns, schema):
"""
@@ -5636,7 +5749,8 @@ def record_batch(data, names=None, schema=None, metadata=None):
data : dict, list, pandas.DataFrame, Arrow-compatible table
A mapping of strings to Arrays or Python lists, a list of Arrays,
a pandas DataFrame, or any tabular object implementing the
Arrow PyCapsule Protocol (has an ``__arrow_c_array__`` method).
Arrow PyCapsule Protocol (has an ``__arrow_c_array__`` or
``__arrow_c_device_array__`` method).
names : list, default None
Column names if list of arrays passed as data. Mutually exclusive with
'schema' argument.
@@ -5770,6 +5884,18 @@ def record_batch(data, names=None, schema=None, metadata=None):
raise ValueError(
"The 'names' argument is not valid when passing a dictionary")
return RecordBatch.from_pydict(data, schema=schema, metadata=metadata)
elif hasattr(data, "__arrow_c_device_array__"):
if schema is not None:
requested_schema = schema.__arrow_c_schema__()
else:
requested_schema = None
schema_capsule, array_capsule = data.__arrow_c_device_array__(requested_schema)
batch = RecordBatch._import_from_c_device_capsule(schema_capsule, array_capsule)
if schema is not None and batch.schema != schema:
# __arrow_c_device_array__ coerces the schema with best effort, so we might
# need to cast it if the producer wasn't able to cast to the exact schema.
batch = batch.cast(schema)
return batch
elif hasattr(data, "__arrow_c_array__"):
if schema is not None:
requested_schema = schema.__arrow_c_schema__()
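
The best-effort comment in the device branch above is worth a sketch: if a producer ignores `requested_schema`, the consumer casts the imported batch itself (the `NaiveBatchWrapper` class is hypothetical):

```python
import pyarrow as pa

class NaiveBatchWrapper:
    """Hypothetical producer that never honors requested_schema."""

    def __init__(self, batch):
        self.batch = batch

    def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
        # always export as-is, ignoring the requested schema
        return self.batch.__arrow_c_device_array__(None, **kwargs)

data = NaiveBatchWrapper(pa.record_batch({"a": [1, 2, 3]}))
target = pa.schema([("a", pa.int32())])

# The producer returns int64 data, so record_batch() performs the
# fallback cast to the requested schema on the consumer side.
result = pa.record_batch(data, schema=target)
assert result.schema == target
```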
@@ -5799,8 +5925,8 @@ def table(data, names=None, schema=None, metadata=None, nthreads=None):
data : dict, list, pandas.DataFrame, Arrow-compatible table
A mapping of strings to Arrays or Python lists, a list of arrays or
chunked arrays, a pandas DataFrame, or any tabular object implementing
the Arrow PyCapsule Protocol (has an ``__arrow_c_array__`` or
``__arrow_c_stream__`` method).
the Arrow PyCapsule Protocol (has an ``__arrow_c_array__``,
``__arrow_c_device_array__`` or ``__arrow_c_stream__`` method).
names : list, default None
Column names if list of arrays passed as data. Mutually exclusive with
'schema' argument.
@@ -5940,7 +6066,7 @@ def table(data, names=None, schema=None, metadata=None, nthreads=None):
# need to cast it if the producer wasn't able to cast to the exact schema.
table = table.cast(schema)
return table
elif hasattr(data, "__arrow_c_array__"):
elif hasattr(data, "__arrow_c_array__") or hasattr(data, "__arrow_c_device_array__"):
if names is not None or metadata is not None:
raise ValueError(
"The 'names' and 'metadata' arguments are not valid when "
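With both entry points updated, any tabular object exposing `__arrow_c_device_array__` can feed `record_batch()` and `table()`. A sketch; the `BatchDeviceWrapper` class is hypothetical, following the same pattern as the test wrappers in this PR:

```python
import pyarrow as pa

class BatchDeviceWrapper:
    """Hypothetical tabular producer exposing only the device protocol."""

    def __init__(self, batch):
        self.batch = batch

    def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
        return self.batch.__arrow_c_device_array__(requested_schema, **kwargs)

data = BatchDeviceWrapper(pa.record_batch({"a": [1, 2, 3]}))

batch = pa.record_batch(data)   # imported via _import_from_c_device_capsule
table = pa.table(data)          # table() routes through the same branch
assert batch.num_rows == table.num_rows == 3
```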
39 changes: 32 additions & 7 deletions python/pyarrow/tests/test_array.py
@@ -3505,16 +3505,27 @@ def __arrow_array__(self, type=None):
assert result.equals(expected)


def test_c_array_protocol():
class ArrayWrapper:
def __init__(self, data):
self.data = data
class ArrayWrapper:
def __init__(self, data):
self.data = data

def __arrow_c_array__(self, requested_schema=None):
return self.data.__arrow_c_array__(requested_schema)


class ArrayDeviceWrapper:
def __init__(self, data):
self.data = data

def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
return self.data.__arrow_c_device_array__(requested_schema, **kwargs)

def __arrow_c_array__(self, requested_schema=None):
return self.data.__arrow_c_array__(requested_schema)

@pytest.mark.parametrize("wrapper_class", [ArrayWrapper, ArrayDeviceWrapper])
def test_c_array_protocol(wrapper_class):

# Can roundtrip through the C array protocol
arr = ArrayWrapper(pa.array([1, 2, 3], type=pa.int64()))
arr = wrapper_class(pa.array([1, 2, 3], type=pa.int64()))
result = pa.array(arr)
assert result == arr.data

@@ -3523,6 +3534,20 @@ def __arrow_c_array__(self, requested_schema=None):
assert result == pa.array([1, 2, 3], type=pa.int32())


def test_c_array_protocol_device_unsupported_keyword():
# For the device-aware version, we raise a specific error for unsupported keywords
arr = pa.array([1, 2, 3], type=pa.int64())

with pytest.raises(
NotImplementedError,
match=r"Received unsupported keyword argument\(s\): \['other'\]"
):
arr.__arrow_c_device_array__(other="not-none")

# but with None value it is ignored
_ = arr.__arrow_c_device_array__(other=None)


def test_concat_array():
concatenated = pa.concat_arrays(
[pa.array([1, 2]), pa.array([3, 4])])