
GH-38325: [Python] Implement PyCapsule interface for Device data in PyArrow #40717

Merged
python/pyarrow/array.pxi — 91 changes: 90 additions & 1 deletion

@@ -129,7 +129,8 @@ def array(object obj, type=None, mask=None, size=None, from_pandas=None,
        If both type and size are specified may be a single use iterable. If
        not strongly-typed, Arrow type will be inferred for resulting array.
        Any Arrow-compatible array that implements the Arrow PyCapsule Protocol
-       (has an ``__arrow_c_array__`` method) can be passed as well.
+       (has an ``__arrow_c_array__`` or ``__arrow_c_device_array__`` method)
+       can be passed as well.
    type : pyarrow.DataType
        Explicit type to attempt to coerce to, otherwise will be inferred from
        the data.
@@ -245,6 +246,18 @@ def array(object obj, type=None, mask=None, size=None, from_pandas=None,

    if hasattr(obj, '__arrow_array__'):
        return _handle_arrow_array_protocol(obj, type, mask, size)
    elif hasattr(obj, '__arrow_c_device_array__'):
        if type is not None:
            requested_type = type.__arrow_c_schema__()
        else:
            requested_type = None
        schema_capsule, array_capsule = obj.__arrow_c_device_array__(requested_type)
        out_array = Array._import_from_c_device_capsule(schema_capsule, array_capsule)
        if type is not None and out_array.type != type:
            # PyCapsule interface type coercion is best effort, so we need to
            # check the type of the returned array and cast if necessary
            out_array = out_array.cast(type, safe=safe, memory_pool=memory_pool)
        return out_array
    elif hasattr(obj, '__arrow_c_array__'):
        if type is not None:
            requested_type = type.__arrow_c_schema__()
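
For illustration, a minimal consumer-side sketch of this new branch, using a hypothetical wrapper that implements only ``__arrow_c_device_array__`` (it mirrors the ArrayDeviceWrapper helper added to test_array.py further below):

import pyarrow as pa


class DeviceArrayWrapper:
    # Hypothetical producer exposing only the device-aware protocol;
    # it delegates the actual export to a wrapped pyarrow.Array.
    def __init__(self, data):
        self.data = data

    def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
        return self.data.__arrow_c_device_array__(requested_schema, **kwargs)


wrapped = DeviceArrayWrapper(pa.array([1, 2, 3], type=pa.int64()))
result = pa.array(wrapped)  # consumed through the elif branch above
assert result == wrapped.data

# Requesting a type exports best effort; pa.array() casts afterwards if needed.
result32 = pa.array(wrapped, type=pa.int32())
assert result32.type == pa.int32()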
@@ -1847,6 +1860,82 @@ cdef class Array(_PandasConvertible):
        )
        return pyarrow_wrap_array(c_array)

    def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
        """
        Get a pair of PyCapsules containing a C ArrowDeviceArray representation
        of the object.

        Parameters
        ----------
        requested_schema : PyCapsule | None
            A PyCapsule containing a C ArrowSchema representation of a requested
            schema. PyArrow will attempt to cast the array to this data type.
            If None, the array will be returned as-is, with a type matching the
            one returned by :meth:`__arrow_c_schema__()`.

        Returns
        -------
        Tuple[PyCapsule, PyCapsule]
            A pair of PyCapsules containing a C ArrowSchema and ArrowDeviceArray,
            respectively.
        """
        cdef:
            ArrowDeviceArray* c_array
            ArrowSchema* c_schema
            shared_ptr[CArray] inner_array

        non_default_kwargs = [
            name for name, value in kwargs.items() if value is not None
        ]
        if non_default_kwargs:
            raise NotImplementedError(
                f"Received unsupported keyword argument(s): {non_default_kwargs}"
            )

        if requested_schema is not None:
            target_type = DataType._import_from_c_capsule(requested_schema)

            if target_type != self.type:
                # TODO should protect from trying to cast non-CPU data
Review comment (Member): Is this check easy to do? (If the failure mode is a crash, maybe this would be good to do?)

Reply (Member, author): Yes, we actually expose a device_type on Array and RecordBatch, so we can easily validate this and raise an informative error when trying to cast to requested_schema for non-CPU data.

                try:
                    casted_array = _pc().cast(self, target_type, safe=True)
                    inner_array = pyarrow_unwrap_array(casted_array)
                except ArrowInvalid as e:
                    raise ValueError(
                        f"Could not cast {self.type} to requested type {target_type}: {e}"
                    )
            else:
                inner_array = self.sp_array
        else:
            inner_array = self.sp_array
Review comment (Member, author): This part is a bit repetitive with the non-device version. I could factor that out into a shared helper function.

        schema_capsule = alloc_c_schema(&c_schema)
        array_capsule = alloc_c_device_array(&c_array)

        with nogil:
            check_status(ExportDeviceArray(
                deref(inner_array), <shared_ptr[CSyncEvent]>NULL,
                c_array, c_schema))

        return schema_capsule, array_capsule

    @staticmethod
    def _import_from_c_device_capsule(schema_capsule, array_capsule):
        cdef:
            ArrowSchema* c_schema
            ArrowDeviceArray* c_array
            shared_ptr[CArray] array

        c_schema = <ArrowSchema*> PyCapsule_GetPointer(schema_capsule, 'arrow_schema')
        c_array = <ArrowDeviceArray*> PyCapsule_GetPointer(
            array_capsule, 'arrow_device_array'
        )

        with nogil:
            array = GetResultValue(ImportDeviceArray(c_array, c_schema))

        return pyarrow_wrap_array(array)

    def __dlpack__(self, stream=None):
        """Export a primitive array as a DLPack capsule.

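Putting the export and import halves together, a minimal sketch of the round trip these two methods enable (the same pattern the new test in test_cffi.py below exercises):

import pyarrow as pa

arr = pa.array([1, 2, 3], type=pa.int64())

# Export: a pair of PyCapsules (ArrowSchema, ArrowDeviceArray).
schema_capsule, array_capsule = arr.__arrow_c_device_array__()

# Import: rebuild an Array from the capsules (this consumes them).
roundtripped = pa.Array._import_from_c_device_capsule(schema_capsule, array_capsule)
assert roundtripped.equals(arr)

# A requested schema asks the producer to cast before exporting.
schema_capsule, array_capsule = arr.__arrow_c_device_array__(
    pa.int32().__arrow_c_schema__())
out = pa.Array._import_from_c_device_capsule(schema_capsule, array_capsule)
assert out.type == pa.int32()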
python/pyarrow/table.pxi — 121 changes: 111 additions & 10 deletions

@@ -3619,7 +3619,7 @@ cdef class RecordBatch(_Tabular):
        requested_schema : PyCapsule | None
            A PyCapsule containing a C ArrowSchema representation of a requested
            schema. PyArrow will attempt to cast the batch to this schema.
-           If None, the schema will be returned as-is, with a schema matching the
+           If None, the batch will be returned as-is, with a schema matching the
            one returned by :meth:`__arrow_c_schema__()`.

        Returns

@@ -3637,9 +3637,7 @@ cdef class RecordBatch(_Tabular):

            if target_schema != self.schema:
                try:
-                   # We don't expose .cast() on RecordBatch, only on Table.
-                   casted_batch = Table.from_batches([self]).cast(
-                       target_schema, safe=True).to_batches()[0]
+                   casted_batch = self.cast(target_schema, safe=True)
                    inner_batch = pyarrow_unwrap_batch(casted_batch)
                except ArrowInvalid as e:
                    raise ValueError(
@@ -3680,8 +3678,8 @@ cdef class RecordBatch(_Tabular):
    @staticmethod
    def _import_from_c_capsule(schema_capsule, array_capsule):
        """
-       Import RecordBatch from a pair of PyCapsules containing a C ArrowArray
-       and ArrowSchema, respectively.
+       Import RecordBatch from a pair of PyCapsules containing a C ArrowSchema
+       and ArrowArray, respectively.

        Parameters
        ----------
@@ -3772,6 +3770,96 @@ cdef class RecordBatch(_Tabular):
                c_device_array, c_schema))
        return pyarrow_wrap_batch(c_batch)

    def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
        """
        Get a pair of PyCapsules containing a C ArrowDeviceArray representation
        of the object.

        Parameters
        ----------
        requested_schema : PyCapsule | None
            A PyCapsule containing a C ArrowSchema representation of a requested
            schema. PyArrow will attempt to cast the batch to this data type.
            If None, the batch will be returned as-is, with a type matching the
            one returned by :meth:`__arrow_c_schema__()`.

        Returns
        -------
        Tuple[PyCapsule, PyCapsule]
            A pair of PyCapsules containing a C ArrowSchema and ArrowDeviceArray,
            respectively.
        """
        cdef:
            ArrowDeviceArray* c_array
            ArrowSchema* c_schema
            shared_ptr[CRecordBatch] inner_batch

        non_default_kwargs = [
            name for name, value in kwargs.items() if value is not None
        ]
        if non_default_kwargs:
            raise NotImplementedError(
                f"Received unsupported keyword argument(s): {non_default_kwargs}"
            )

        if requested_schema is not None:
            target_schema = Schema._import_from_c_capsule(requested_schema)

            if target_schema != self.schema:
                # TODO should protect from trying to cast non-CPU data
                try:
                    casted_batch = self.cast(target_schema, safe=True)
                    inner_batch = pyarrow_unwrap_batch(casted_batch)
                except ArrowInvalid as e:
                    raise ValueError(
                        f"Could not cast {self.schema} to requested schema {target_schema}: {e}"
                    )
            else:
                inner_batch = self.sp_batch
        else:
            inner_batch = self.sp_batch

        schema_capsule = alloc_c_schema(&c_schema)
        array_capsule = alloc_c_device_array(&c_array)

        with nogil:
            check_status(ExportDeviceRecordBatch(
                deref(inner_batch), <shared_ptr[CSyncEvent]>NULL, c_array, c_schema))

        return schema_capsule, array_capsule

    @staticmethod
    def _import_from_c_device_capsule(schema_capsule, array_capsule):
        """
        Import RecordBatch from a pair of PyCapsules containing a
        C ArrowSchema and ArrowDeviceArray, respectively.

        Parameters
        ----------
        schema_capsule : PyCapsule
            A PyCapsule containing a C ArrowSchema representation of the schema.
        array_capsule : PyCapsule
            A PyCapsule containing a C ArrowDeviceArray representation of the array.

        Returns
        -------
        pyarrow.RecordBatch
        """
        cdef:
            ArrowSchema* c_schema
            ArrowDeviceArray* c_array
            shared_ptr[CRecordBatch] batch

        c_schema = <ArrowSchema*> PyCapsule_GetPointer(schema_capsule, 'arrow_schema')
        c_array = <ArrowDeviceArray*> PyCapsule_GetPointer(
            array_capsule, 'arrow_device_array'
        )

        with nogil:
            batch = GetResultValue(ImportDeviceRecordBatch(c_array, c_schema))

        return pyarrow_wrap_batch(batch)
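
As with Array, a minimal sketch of the RecordBatch round trip (matching the parametrized test added to test_cffi.py below):

import pyarrow as pa

batch = pa.record_batch([pa.array(['a', 'b', 'c'])], names=['x'])

schema_capsule, array_capsule = batch.__arrow_c_device_array__()
out = pa.RecordBatch._import_from_c_device_capsule(schema_capsule, array_capsule)
assert out.equals(batch)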


def _reconstruct_record_batch(columns, schema):
"""
@@ -5636,7 +5724,8 @@ def record_batch(data, names=None, schema=None, metadata=None):
    data : dict, list, pandas.DataFrame, Arrow-compatible table
        A mapping of strings to Arrays or Python lists, a list of Arrays,
        a pandas DataFrame, or any tabular object implementing the
-       Arrow PyCapsule Protocol (has an ``__arrow_c_array__`` method).
+       Arrow PyCapsule Protocol (has an ``__arrow_c_array__`` or
+       ``__arrow_c_device_array__`` method).
    names : list, default None
        Column names if list of arrays passed as data. Mutually exclusive with
        'schema' argument.
@@ -5770,6 +5859,18 @@ def record_batch(data, names=None, schema=None, metadata=None):
            raise ValueError(
                "The 'names' argument is not valid when passing a dictionary")
        return RecordBatch.from_pydict(data, schema=schema, metadata=metadata)
    elif hasattr(data, "__arrow_c_device_array__"):
        if schema is not None:
            requested_schema = schema.__arrow_c_schema__()
        else:
            requested_schema = None
        schema_capsule, array_capsule = data.__arrow_c_device_array__(requested_schema)
        batch = RecordBatch._import_from_c_device_capsule(schema_capsule, array_capsule)
        if schema is not None and batch.schema != schema:
            # __arrow_c_device_array__ coerces schema with best effort, so we might
            # need to cast it if the producer wasn't able to cast to exact schema.
            batch = batch.cast(schema)
        return batch
    elif hasattr(data, "__arrow_c_array__"):
        if schema is not None:
            requested_schema = schema.__arrow_c_schema__()
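
A consumer-side sketch of this branch, using a hypothetical wrapper that exposes only the device protocol; the same object also satisfies the updated check in pa.table() shown below:

import pyarrow as pa


class BatchDeviceWrapper:
    # Hypothetical producer exposing only the device-aware protocol.
    def __init__(self, batch):
        self.batch = batch

    def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
        return self.batch.__arrow_c_device_array__(requested_schema, **kwargs)


data = BatchDeviceWrapper(pa.record_batch({'x': [1, 2, 3]}))

batch = pa.record_batch(data)  # consumed via __arrow_c_device_array__
assert batch.schema.names == ['x']

# Passing a schema goes through the requested_schema / best-effort cast path.
batch32 = pa.record_batch(data, schema=pa.schema({'x': pa.int32()}))
assert batch32.schema.field('x').type == pa.int32()

tbl = pa.table(data)  # also accepted by pa.table() after this change
assert tbl.column_names == ['x']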
@@ -5799,8 +5900,8 @@ def table(data, names=None, schema=None, metadata=None, nthreads=None):
    data : dict, list, pandas.DataFrame, Arrow-compatible table
        A mapping of strings to Arrays or Python lists, a list of arrays or
        chunked arrays, a pandas DataFrame, or any tabular object implementing
-       the Arrow PyCapsule Protocol (has an ``__arrow_c_array__`` or
-       ``__arrow_c_stream__`` method).
+       the Arrow PyCapsule Protocol (has an ``__arrow_c_array__``,
+       ``__arrow_c_device_array__`` or ``__arrow_c_stream__`` method).
    names : list, default None
        Column names if list of arrays passed as data. Mutually exclusive with
        'schema' argument.
@@ -5940,7 +6041,7 @@ def table(data, names=None, schema=None, metadata=None, nthreads=None):
            # need to cast it if the producer wasn't able to cast to exact schema.
            table = table.cast(schema)
        return table
-   elif hasattr(data, "__arrow_c_array__"):
+   elif hasattr(data, "__arrow_c_array__") or hasattr(data, "__arrow_c_device_array__"):
        if names is not None or metadata is not None:
            raise ValueError(
                "The 'names' and 'metadata' arguments are not valid when "
python/pyarrow/tests/test_array.py — 39 changes: 32 additions & 7 deletions

@@ -3505,16 +3505,27 @@ def __arrow_array__(self, type=None):
    assert result.equals(expected)


-def test_c_array_protocol():
-    class ArrayWrapper:
-        def __init__(self, data):
-            self.data = data
+class ArrayWrapper:
+    def __init__(self, data):
+        self.data = data

-        def __arrow_c_array__(self, requested_schema=None):
-            return self.data.__arrow_c_array__(requested_schema)
+    def __arrow_c_array__(self, requested_schema=None):
+        return self.data.__arrow_c_array__(requested_schema)
+
+
+class ArrayDeviceWrapper:
+    def __init__(self, data):
+        self.data = data
+
+    def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
+        return self.data.__arrow_c_device_array__(requested_schema, **kwargs)
+
+
+@pytest.mark.parametrize("wrapper_class", [ArrayWrapper, ArrayDeviceWrapper])
+def test_c_array_protocol(wrapper_class):

    # Can roundtrip through the C array protocol
-   arr = ArrayWrapper(pa.array([1, 2, 3], type=pa.int64()))
+   arr = wrapper_class(pa.array([1, 2, 3], type=pa.int64()))
    result = pa.array(arr)
    assert result == arr.data

@@ -3523,6 +3534,20 @@ def __arrow_c_array__(self, requested_schema=None):
    assert result == pa.array([1, 2, 3], type=pa.int32())


def test_c_array_protocol_device_unsupported_keyword():
    # For the device-aware version, we raise a specific error for unsupported keywords
    arr = pa.array([1, 2, 3], type=pa.int64())

    with pytest.raises(
        NotImplementedError,
        match=r"Received unsupported keyword argument\(s\): \['other'\]"
    ):
        arr.__arrow_c_device_array__(other="not-none")

    # but with None value it is ignored
    _ = arr.__arrow_c_device_array__(other=None)


def test_concat_array():
    concatenated = pa.concat_arrays(
        [pa.array([1, 2]), pa.array([3, 4])])
python/pyarrow/tests/test_cffi.py — 42 changes: 42 additions & 0 deletions

@@ -603,6 +603,48 @@ def test_roundtrip_array_capsule(arr, schema_accessor, bad_type, good_type):
    assert schema_accessor(arr_out) == good_type


@pytest.mark.parametrize('arr,schema_accessor,bad_type,good_type', [
    (pa.array(['a', 'b', 'c']), lambda x: x.type, pa.int32(), pa.string()),
    (
        pa.record_batch([pa.array(['a', 'b', 'c'])], names=['x']),
        lambda x: x.schema,
        pa.schema({'x': pa.int32()}),
        pa.schema({'x': pa.string()})
    ),
], ids=['array', 'record_batch'])
def test_roundtrip_device_array_capsule(arr, schema_accessor, bad_type, good_type):
    gc.collect()  # Make sure no Arrow data dangles in a ref cycle
    old_allocated = pa.total_allocated_bytes()

    import_array = type(arr)._import_from_c_device_capsule

    schema_capsule, capsule = arr.__arrow_c_device_array__()
    assert PyCapsule_IsValid(schema_capsule, b"arrow_schema") == 1
    assert PyCapsule_IsValid(capsule, b"arrow_device_array") == 1
    arr_out = import_array(schema_capsule, capsule)
    assert arr_out.equals(arr)

    assert pa.total_allocated_bytes() > old_allocated
    del arr_out

    assert pa.total_allocated_bytes() == old_allocated

    capsule = arr.__arrow_c_array__()

    assert pa.total_allocated_bytes() > old_allocated
    del capsule
    assert pa.total_allocated_bytes() == old_allocated

    with pytest.raises(ValueError,
                       match=r"Could not cast.* string to requested .* int32"):
        arr.__arrow_c_device_array__(bad_type.__arrow_c_schema__())

    schema_capsule, array_capsule = arr.__arrow_c_device_array__(
        good_type.__arrow_c_schema__())
    arr_out = import_array(schema_capsule, array_capsule)
    assert schema_accessor(arr_out) == good_type


# TODO: implement requested_schema for stream
@pytest.mark.parametrize('constructor', [
    pa.RecordBatchReader.from_batches,