From 9d6b21bb5e2bb90a6283af4294dbc55158f086cb Mon Sep 17 00:00:00 2001
From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com>
Date: Wed, 28 Feb 2024 19:30:43 +0000
Subject: [PATCH 01/11] fix pyarrow interchange

---
 doc/source/whatsnew/v2.2.2.rst            |   4 +-
 pandas/core/interchange/buffer.py         |  76 ++++++++++++
 pandas/core/interchange/column.py         |  75 ++++++++++--
 pandas/core/interchange/from_dataframe.py |  29 +++--
 pandas/tests/interchange/test_impl.py     | 141 +++++++++++++++++++---
 5 files changed, 279 insertions(+), 46 deletions(-)

diff --git a/doc/source/whatsnew/v2.2.2.rst b/doc/source/whatsnew/v2.2.2.rst
index 96f210ce6b7b9..54084abab7817 100644
--- a/doc/source/whatsnew/v2.2.2.rst
+++ b/doc/source/whatsnew/v2.2.2.rst
@@ -14,6 +14,7 @@ including other versions of pandas.
 Fixed regressions
 ~~~~~~~~~~~~~~~~~
 - :meth:`DataFrame.__dataframe__` was producing incorrect data buffers when the a column's type was a pandas nullable on with missing values (:issue:`56702`)
+- :meth:`DataFrame.__dataframe__` was producing incorrect data buffers when a column's type was a pyarrow nullable type with missing values (:issue:`57664`)
 -
 
 .. ---------------------------------------------------------------------------
@@ -21,7 +22,8 @@ Fixed regressions
 
 Bug fixes
 ~~~~~~~~~
--
+- :meth:`DataFrame.__dataframe__` was returning a bytemask instead of a bitmask for the ``'string[pyarrow]'`` validity buffer (:issue:`57762`)
+- :meth:`DataFrame.__dataframe__` was returning a non-null validity buffer (instead of ``None``) for ``'string[pyarrow]'`` columns without missing values (:issue:`57761`)
 
 .. ---------------------------------------------------------------------------
 .. _whatsnew_222.other:

diff --git a/pandas/core/interchange/buffer.py b/pandas/core/interchange/buffer.py
index 5c97fc17d7070..aba404b2c15c9 100644
--- a/pandas/core/interchange/buffer.py
+++ b/pandas/core/interchange/buffer.py
@@ -12,6 +12,7 @@
 
 if TYPE_CHECKING:
     import numpy as np
+    import pyarrow as pa
 
 
 class PandasBuffer(Buffer):
@@ -76,3 +77,78 @@ def __repr__(self) -> str:
             )
             + ")"
         )
+
+
+class PandasBufferPyarrow(Buffer):
+    """
+    Data in the buffer is guaranteed to be contiguous in memory.
+    """
+
+    def __init__(
+        self,
+        chunked_array: pa.ChunkedArray,
+        *,
+        is_validity: bool,
+        allow_copy: bool = True,
+    ) -> None:
+        """
+        Handle pyarrow chunked arrays.
+        """
+        if len(chunked_array.chunks) == 1:
+            arr = chunked_array.chunks[0]
+        else:
+            if not allow_copy:
+                raise RuntimeError(
+                    "Found multi-chunk pyarrow array, but `allow_copy` is False"
+                )
+            arr = chunked_array.combine_chunks()
+        if is_validity:
+            self._buffer = arr.buffers()[0]
+        else:
+            self._buffer = arr.buffers()[1]
+        self._length = len(arr)
+        self._dlpack = getattr(arr, "__dlpack__", None)
+        self._is_validity = is_validity
+
+    @property
+    def bufsize(self) -> int:
+        """
+        Buffer size in bytes.
+        """
+        return self._buffer.size
+
+    @property
+    def ptr(self) -> int:
+        """
+        Pointer to start of the buffer as an integer.
+        """
+        return self._buffer.address
+
+    def __dlpack__(self) -> Any:
+        """
+        Represent this structure as DLPack interface.
+        """
+        if self._dlpack is not None:
+            return self._dlpack()
+        raise NotImplementedError(
+            "pyarrow>=15.0.0 is required for DLPack support for pyarrow-backed buffers"
+        )
+
+    def __dlpack_device__(self) -> tuple[DlpackDeviceType, int | None]:
+        """
+        Device type and device ID for where the data in the buffer resides.
+ """ + return (DlpackDeviceType.CPU, None) + + def __repr__(self) -> str: + return ( + "PandasBuffer[pyarrow](" + + str( + { + "bufsize": self.bufsize, + "ptr": self.ptr, + "device": "CPU", + } + ) + + ")" + ) diff --git a/pandas/core/interchange/column.py b/pandas/core/interchange/column.py index bf20f0b5433cd..507efe1c74b19 100644 --- a/pandas/core/interchange/column.py +++ b/pandas/core/interchange/column.py @@ -1,6 +1,9 @@ from __future__ import annotations -from typing import Any +from typing import ( + TYPE_CHECKING, + Any, +) import numpy as np @@ -9,15 +12,18 @@ from pandas.errors import NoBufferPresent from pandas.util._decorators import cache_readonly -from pandas.core.dtypes.dtypes import ( +from pandas.core.dtypes.dtypes import BaseMaskedDtype + +import pandas as pd +from pandas import ( ArrowDtype, - BaseMaskedDtype, DatetimeTZDtype, ) - -import pandas as pd from pandas.api.types import is_string_dtype -from pandas.core.interchange.buffer import PandasBuffer +from pandas.core.interchange.buffer import ( + PandasBuffer, + PandasBufferPyarrow, +) from pandas.core.interchange.dataframe_protocol import ( Column, ColumnBuffers, @@ -30,6 +36,9 @@ dtype_to_arrow_c_fmt, ) +if TYPE_CHECKING: + from pandas.core.interchange.dataframe_protocol import Buffer + _NP_KINDS = { "i": DtypeKind.INT, "u": DtypeKind.UINT, @@ -157,6 +166,14 @@ def _dtype_from_pandasdtype(self, dtype) -> tuple[DtypeKind, int, str, str]: else: byteorder = dtype.byteorder + if dtype == "bool[pyarrow]": + return ( + kind, + dtype.itemsize, # pyright: ignore[reportAttributeAccessIssue] + ArrowCTypes.BOOL, + byteorder, + ) + return kind, dtype.itemsize * 8, dtype_to_arrow_c_fmt(dtype), byteorder @property @@ -194,6 +211,13 @@ def describe_null(self): column_null_dtype = ColumnNullType.USE_BYTEMASK null_value = 1 return column_null_dtype, null_value + if isinstance(self._col.dtype, ArrowDtype): + if all( + chunk.buffers()[0] is None + for chunk in self._col.array._pa_array.chunks # type: ignore[attr-defined] + ): + return ColumnNullType.NON_NULLABLE, None + return ColumnNullType.USE_BITMASK, 0 kind = self.dtype[0] try: null, value = _NULL_DESCRIPTION[kind] @@ -278,7 +302,7 @@ def get_buffers(self) -> ColumnBuffers: def _get_data_buffer( self, - ) -> tuple[PandasBuffer, Any]: # Any is for self.dtype tuple + ) -> tuple[Buffer, tuple[DtypeKind, int, str, str]]: """ Return the buffer containing the data and the buffer's associated dtype. 
""" @@ -289,7 +313,7 @@ def _get_data_buffer( np_arr = self._col.dt.tz_convert(None).to_numpy() else: np_arr = self._col.to_numpy() - buffer = PandasBuffer(np_arr, allow_copy=self._allow_copy) + buffer: Buffer = PandasBuffer(np_arr, allow_copy=self._allow_copy) dtype = ( DtypeKind.INT, 64, @@ -302,15 +326,27 @@ def _get_data_buffer( DtypeKind.FLOAT, DtypeKind.BOOL, ): + dtype = self.dtype arr = self._col.array + if isinstance(self._col.dtype, ArrowDtype): + buffer = PandasBufferPyarrow( + arr._pa_array, # type: ignore[attr-defined] + is_validity=False, + allow_copy=self._allow_copy, + ) + if self.dtype[0] == DtypeKind.BOOL: + dtype = ( + DtypeKind.BOOL, + 1, + ArrowCTypes.BOOL, + Endianness.NATIVE, + ) + return buffer, dtype if isinstance(self._col.dtype, BaseMaskedDtype): np_arr = arr._data # type: ignore[attr-defined] - elif isinstance(self._col.dtype, ArrowDtype): - raise NotImplementedError("ArrowDtype not handled yet") else: np_arr = arr._ndarray # type: ignore[attr-defined] buffer = PandasBuffer(np_arr, allow_copy=self._allow_copy) - dtype = self.dtype elif self.dtype[0] == DtypeKind.CATEGORICAL: codes = self._col.values._codes buffer = PandasBuffer(codes, allow_copy=self._allow_copy) @@ -343,7 +379,7 @@ def _get_data_buffer( return buffer, dtype - def _get_validity_buffer(self) -> tuple[PandasBuffer, Any]: + def _get_validity_buffer(self) -> tuple[Buffer, Any] | None: """ Return the buffer containing the mask values indicating missing data and the buffer's associated dtype. @@ -351,6 +387,21 @@ def _get_validity_buffer(self) -> tuple[PandasBuffer, Any]: """ null, invalid = self.describe_null + if isinstance(self._col.dtype, ArrowDtype): + arr = self._col.array + dtype = (DtypeKind.BOOL, 1, ArrowCTypes.BOOL, Endianness.NATIVE) + if all( + chunk.buffers()[0] is None + for chunk in arr._pa_array.chunks # type: ignore[attr-defined] + ): + return None + buffer: Buffer = PandasBufferPyarrow( + arr._pa_array, # type: ignore[attr-defined] + is_validity=True, + allow_copy=self._allow_copy, + ) + return buffer, dtype + if isinstance(self._col.dtype, BaseMaskedDtype): mask = self._col.array._mask # type: ignore[attr-defined] buffer = PandasBuffer(mask) diff --git a/pandas/core/interchange/from_dataframe.py b/pandas/core/interchange/from_dataframe.py index a952887d7eed2..22ddb12b2a4bb 100644 --- a/pandas/core/interchange/from_dataframe.py +++ b/pandas/core/interchange/from_dataframe.py @@ -298,13 +298,14 @@ def string_column_to_ndarray(col: Column) -> tuple[np.ndarray, Any]: null_pos = None if null_kind in (ColumnNullType.USE_BITMASK, ColumnNullType.USE_BYTEMASK): - assert buffers["validity"], "Validity buffers cannot be empty for masks" - valid_buff, valid_dtype = buffers["validity"] - null_pos = buffer_to_ndarray( - valid_buff, valid_dtype, offset=col.offset, length=col.size() - ) - if sentinel_val == 0: - null_pos = ~null_pos + validity = buffers["validity"] + if validity is not None: + valid_buff, valid_dtype = validity + null_pos = buffer_to_ndarray( + valid_buff, valid_dtype, offset=col.offset, length=col.size() + ) + if sentinel_val == 0: + null_pos = ~null_pos # Assemble the strings from the code units str_list: list[None | float | str] = [None] * col.size() @@ -516,19 +517,21 @@ def set_nulls( np.ndarray or pd.Series Data with the nulls being set. 
""" + if validity is None: + return data null_kind, sentinel_val = col.describe_null null_pos = None if null_kind == ColumnNullType.USE_SENTINEL: null_pos = pd.Series(data) == sentinel_val elif null_kind in (ColumnNullType.USE_BITMASK, ColumnNullType.USE_BYTEMASK): - assert validity, "Expected to have a validity buffer for the mask" valid_buff, valid_dtype = validity - null_pos = buffer_to_ndarray( - valid_buff, valid_dtype, offset=col.offset, length=col.size() - ) - if sentinel_val == 0: - null_pos = ~null_pos + if valid_buff is not None: + null_pos = buffer_to_ndarray( + valid_buff, valid_dtype, offset=col.offset, length=col.size() + ) + if sentinel_val == 0: + null_pos = ~null_pos elif null_kind in (ColumnNullType.NON_NULLABLE, ColumnNullType.USE_NAN): pass else: diff --git a/pandas/tests/interchange/test_impl.py b/pandas/tests/interchange/test_impl.py index 94b2da894ad0f..ee1efd13072c1 100644 --- a/pandas/tests/interchange/test_impl.py +++ b/pandas/tests/interchange/test_impl.py @@ -1,4 +1,7 @@ -from datetime import datetime +from datetime import ( + datetime, + timezone, +) import numpy as np import pytest @@ -416,42 +419,60 @@ def test_non_str_names_w_duplicates(): pd.api.interchange.from_dataframe(dfi, allow_copy=False) -def test_nullable_integers() -> None: - # https://github.com/pandas-dev/pandas/issues/55069 - df = pd.DataFrame({"a": [1]}, dtype="Int8") - expected = pd.DataFrame({"a": [1]}, dtype="int8") - result = pd.api.interchange.from_dataframe(df.__dataframe__()) - tm.assert_frame_equal(result, expected) - - -@pytest.mark.xfail(reason="https://github.com/pandas-dev/pandas/issues/57664") -def test_nullable_integers_pyarrow() -> None: - # https://github.com/pandas-dev/pandas/issues/55069 - df = pd.DataFrame({"a": [1]}, dtype="Int8[pyarrow]") - expected = pd.DataFrame({"a": [1]}, dtype="int8") - result = pd.api.interchange.from_dataframe(df.__dataframe__()) - tm.assert_frame_equal(result, expected) - - @pytest.mark.parametrize( ("data", "dtype", "expected_dtype"), [ ([1, 2, None], "Int64", "int64"), + ([1, 2, None], "Int64[pyarrow]", "int64"), + ([1, 2, None], "Int8", "int8"), + ([1, 2, None], "Int8[pyarrow]", "int8"), ( [1, 2, None], "UInt64", "uint64", ), + ( + [1, 2, None], + "UInt64[pyarrow]", + "uint64", + ), ([1.0, 2.25, None], "Float32", "float32"), + ([1.0, 2.25, None], "Float32[pyarrow]", "float32"), + ([True, False, None], "boolean[pyarrow]", "bool"), + (["much ado", "about", None], "string[pyarrow_numpy]", "large_string"), + (["much ado", "about", None], "string[pyarrow]", "large_string"), + ( + [datetime(2020, 1, 1), datetime(2020, 1, 2), None], + "timestamp[ns][pyarrow]", + "timestamp[ns]", + ), + ( + [datetime(2020, 1, 1), datetime(2020, 1, 2), None], + "timestamp[us][pyarrow]", + "timestamp[us]", + ), + ( + [ + datetime(2020, 1, 1, tzinfo=timezone.utc), + datetime(2020, 1, 2, tzinfo=timezone.utc), + None, + ], + "timestamp[us, Asia/Kathmandu][pyarrow]", + "timestamp[us, tz=Asia/Kathmandu]", + ), ], ) -def test_pandas_nullable_w_missing_values( +def test_pandas_nullable_with_missing_values( data: list, dtype: str, expected_dtype: str ) -> None: # https://github.com/pandas-dev/pandas/issues/57643 - pytest.importorskip("pyarrow", "11.0.0") + # https://github.com/pandas-dev/pandas/issues/57664 + pa = pytest.importorskip("pyarrow", "11.0.0") import pyarrow.interchange as pai + if expected_dtype == "timestamp[us, tz=Asia/Kathmandu]": + expected_dtype = pa.timestamp("us", "Asia/Kathmandu") + df = pd.DataFrame({"a": data}, dtype=dtype) result = 
pai.from_dataframe(df.__dataframe__())["a"]
     assert result.type == expected_dtype
@@ -460,6 +481,86 @@ def test_pandas_nullable_w_missing_values(
     assert result[2].as_py() is None
 
 
+@pytest.mark.parametrize(
+    ("data", "dtype", "expected_dtype"),
+    [
+        ([1, 2, 3], "Int64", "int64"),
+        ([1, 2, 3], "Int64[pyarrow]", "int64"),
+        ([1, 2, 3], "Int8", "int8"),
+        ([1, 2, 3], "Int8[pyarrow]", "int8"),
+        (
+            [1, 2, 3],
+            "UInt64",
+            "uint64",
+        ),
+        (
+            [1, 2, 3],
+            "UInt64[pyarrow]",
+            "uint64",
+        ),
+        ([1.0, 2.25, 5.0], "Float32", "float32"),
+        ([1.0, 2.25, 5.0], "Float32[pyarrow]", "float32"),
+        ([True, False, False], "boolean[pyarrow]", "bool"),
+        (["much ado", "about", "nothing"], "string[pyarrow_numpy]", "large_string"),
+        (["much ado", "about", "nothing"], "string[pyarrow]", "large_string"),
+        (
+            [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)],
+            "timestamp[ns][pyarrow]",
+            "timestamp[ns]",
+        ),
+        (
+            [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)],
+            "timestamp[us][pyarrow]",
+            "timestamp[us]",
+        ),
+        (
+            [
+                datetime(2020, 1, 1, tzinfo=timezone.utc),
+                datetime(2020, 1, 2, tzinfo=timezone.utc),
+                datetime(2020, 1, 3, tzinfo=timezone.utc),
+            ],
+            "timestamp[us, Asia/Kathmandu][pyarrow]",
+            "timestamp[us, tz=Asia/Kathmandu]",
+        ),
+    ],
+)
+def test_pandas_nullable_without_missing_values(
+    data: list, dtype: str, expected_dtype: str
+) -> None:
+    # https://github.com/pandas-dev/pandas/issues/57643
+    pa = pytest.importorskip("pyarrow", "11.0.0")
+    import pyarrow.interchange as pai
+
+    if expected_dtype == "timestamp[us, tz=Asia/Kathmandu]":
+        expected_dtype = pa.timestamp("us", "Asia/Kathmandu")
+
+    df = pd.DataFrame({"a": data}, dtype=dtype)
+    result = pai.from_dataframe(df.__dataframe__())["a"]
+    assert result.type == expected_dtype
+    assert result[0].as_py() == data[0]
+    assert result[1].as_py() == data[1]
+    assert result[2].as_py() == data[2]
+
+
+def test_string_validity_buffer() -> None:
+    # https://github.com/pandas-dev/pandas/issues/57761
+    pytest.importorskip("pyarrow", "11.0.0")
+    df = pd.DataFrame({"a": ["x"]}, dtype="large_string[pyarrow]")
+    result = df.__dataframe__().get_column_by_name("a").get_buffers()["validity"]
+    assert result is None
+
+
+def test_string_validity_buffer_with_missing() -> None:
+    # https://github.com/pandas-dev/pandas/issues/57762
+    pytest.importorskip("pyarrow", "11.0.0")
+    df = pd.DataFrame({"a": ["x", None]}, dtype="large_string[pyarrow]")
+    validity = df.__dataframe__().get_column_by_name("a").get_buffers()["validity"]
+    assert validity is not None
+    result = validity[1]
+    expected = (DtypeKind.BOOL, 1, ArrowCTypes.BOOL, "=")
+    assert result == expected
+
+
 def test_empty_dataframe():
     # https://github.com/pandas-dev/pandas/issues/56700
     df = pd.DataFrame({"a": []}, dtype="int8")

From 031d9aa1784a9c4d97a760c988ff1dd14364153b Mon Sep 17 00:00:00 2001
From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com>
Date: Wed, 13 Mar 2024 19:47:31 +0000
Subject: [PATCH 02/11] reduce diff

---
 pandas/core/interchange/from_dataframe.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/pandas/core/interchange/from_dataframe.py b/pandas/core/interchange/from_dataframe.py
index 22ddb12b2a4bb..4575837fb12fc 100644
--- a/pandas/core/interchange/from_dataframe.py
+++ b/pandas/core/interchange/from_dataframe.py
@@ -525,13 +525,13 @@ def set_nulls(
     if null_kind == ColumnNullType.USE_SENTINEL:
         null_pos = pd.Series(data) == sentinel_val
     elif null_kind in (ColumnNullType.USE_BITMASK,
ColumnNullType.USE_BYTEMASK): + assert validity, "Expected to have a validity buffer for the mask" valid_buff, valid_dtype = validity - if valid_buff is not None: - null_pos = buffer_to_ndarray( - valid_buff, valid_dtype, offset=col.offset, length=col.size() - ) - if sentinel_val == 0: - null_pos = ~null_pos + null_pos = buffer_to_ndarray( + valid_buff, valid_dtype, offset=col.offset, length=col.size() + ) + if sentinel_val == 0: + null_pos = ~null_pos elif null_kind in (ColumnNullType.NON_NULLABLE, ColumnNullType.USE_NAN): pass else: From cec4b4d7e11a38750a687d7c4a7609588f59b0f0 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 13 Mar 2024 19:48:03 +0000 Subject: [PATCH 03/11] reduce diff --- pandas/core/interchange/buffer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pandas/core/interchange/buffer.py b/pandas/core/interchange/buffer.py index aba404b2c15c9..f122d35dc4921 100644 --- a/pandas/core/interchange/buffer.py +++ b/pandas/core/interchange/buffer.py @@ -128,11 +128,11 @@ def __dlpack__(self) -> Any: """ Represent this structure as DLPack interface. """ - if self._dlpack is not None: - return self._dlpack() - raise NotImplementedError( - "pyarrow>=15.0.0 is required for DLPack support for pyarrow-backed buffers" - ) + if self._dlpack is None: + raise NotImplementedError( + "pyarrow>=15.0.0 required for DLPack support for pyarrow-backed buffers" + ) + return self._dlpack() def __dlpack_device__(self) -> tuple[DlpackDeviceType, int | None]: """ From 9adf45fd3ebb2a5acfdf1ea6d040a7c001b6260b Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:47:31 +0000 Subject: [PATCH 04/11] start simplifying --- pandas/core/interchange/buffer.py | 28 +++++----------------------- pandas/core/interchange/column.py | 18 ++++++++++++------ pandas/core/interchange/utils.py | 14 ++++++++++++++ 3 files changed, 31 insertions(+), 29 deletions(-) diff --git a/pandas/core/interchange/buffer.py b/pandas/core/interchange/buffer.py index f122d35dc4921..5d24325e67f62 100644 --- a/pandas/core/interchange/buffer.py +++ b/pandas/core/interchange/buffer.py @@ -86,29 +86,15 @@ class PandasBufferPyarrow(Buffer): def __init__( self, - chunked_array: pa.ChunkedArray, + buffer: pa.Buffer, *, - is_validity: bool, - allow_copy: bool = True, + length: int, ) -> None: """ Handle pyarrow chunked arrays. """ - if len(chunked_array.chunks) == 1: - arr = chunked_array.chunks[0] - else: - if not allow_copy: - raise RuntimeError( - "Found multi-chunk pyarrow array, but `allow_copy` is False" - ) - arr = chunked_array.combine_chunks() - if is_validity: - self._buffer = arr.buffers()[0] - else: - self._buffer = arr.buffers()[1] - self._length = len(arr) - self._dlpack = getattr(arr, "__dlpack__", None) - self._is_validity = is_validity + self._buffer = buffer + self._length = length @property def bufsize(self) -> int: @@ -128,11 +114,7 @@ def __dlpack__(self) -> Any: """ Represent this structure as DLPack interface. 
""" - if self._dlpack is None: - raise NotImplementedError( - "pyarrow>=15.0.0 required for DLPack support for pyarrow-backed buffers" - ) - return self._dlpack() + raise NotImplementedError() def __dlpack_device__(self) -> tuple[DlpackDeviceType, int | None]: """ diff --git a/pandas/core/interchange/column.py b/pandas/core/interchange/column.py index 507efe1c74b19..da4b67054e487 100644 --- a/pandas/core/interchange/column.py +++ b/pandas/core/interchange/column.py @@ -34,6 +34,7 @@ ArrowCTypes, Endianness, dtype_to_arrow_c_fmt, + maybe_rechunk, ) if TYPE_CHECKING: @@ -167,6 +168,8 @@ def _dtype_from_pandasdtype(self, dtype) -> tuple[DtypeKind, int, str, str]: byteorder = dtype.byteorder if dtype == "bool[pyarrow]": + # return early to avoid the `* 8` below, as this is a bitmask + # rather than a bytemask return ( kind, dtype.itemsize, # pyright: ignore[reportAttributeAccessIssue] @@ -329,10 +332,10 @@ def _get_data_buffer( dtype = self.dtype arr = self._col.array if isinstance(self._col.dtype, ArrowDtype): + arr = maybe_rechunk(arr._pa_array, allow_copy=self._allow_copy) buffer = PandasBufferPyarrow( - arr._pa_array, # type: ignore[attr-defined] - is_validity=False, - allow_copy=self._allow_copy, + arr.buffers()[1], # type: ignore[attr-defined] + length=len(arr), ) if self.dtype[0] == DtypeKind.BOOL: dtype = ( @@ -395,10 +398,13 @@ def _get_validity_buffer(self) -> tuple[Buffer, Any] | None: for chunk in arr._pa_array.chunks # type: ignore[attr-defined] ): return None + chunked_array = arr._pa_array + arr = maybe_rechunk(chunked_array, allow_copy=self._allow_copy) + if arr.buffers()[0] is None: + return None buffer: Buffer = PandasBufferPyarrow( - arr._pa_array, # type: ignore[attr-defined] - is_validity=True, - allow_copy=self._allow_copy, + arr.buffers()[0], # type: ignore[attr-defined] + length=len(arr), ) return buffer, dtype diff --git a/pandas/core/interchange/utils.py b/pandas/core/interchange/utils.py index 2e73e560e5740..0680c8d8cd525 100644 --- a/pandas/core/interchange/utils.py +++ b/pandas/core/interchange/utils.py @@ -17,6 +17,8 @@ ) if typing.TYPE_CHECKING: + import pyarrow as pa + from pandas._typing import DtypeObj @@ -145,3 +147,15 @@ def dtype_to_arrow_c_fmt(dtype: DtypeObj) -> str: raise NotImplementedError( f"Conversion of {dtype} to Arrow C format string is not implemented." 
) + + +def maybe_rechunk(chunked_array: pa.ChunkedArray, *, allow_copy: bool) -> pa.Array: + if len(chunked_array.chunks) == 1: + arr = chunked_array.chunks[0] + else: + if not allow_copy: + raise RuntimeError( + "Found multi-chunk pyarrow array, but `allow_copy` is False" + ) + arr = chunked_array.combine_chunks() + return arr From 080e54f6a0a024c6f9f555390073c0695c5d86a0 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 15 Mar 2024 16:05:03 +0000 Subject: [PATCH 05/11] simplify, remove is_validity arg --- pandas/core/interchange/column.py | 23 +++++++++-------------- pandas/core/interchange/dataframe.py | 5 +++++ pandas/core/interchange/utils.py | 22 +++++++++++----------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/pandas/core/interchange/column.py b/pandas/core/interchange/column.py index da4b67054e487..69a0c07ddb132 100644 --- a/pandas/core/interchange/column.py +++ b/pandas/core/interchange/column.py @@ -34,7 +34,6 @@ ArrowCTypes, Endianness, dtype_to_arrow_c_fmt, - maybe_rechunk, ) if TYPE_CHECKING: @@ -215,10 +214,9 @@ def describe_null(self): null_value = 1 return column_null_dtype, null_value if isinstance(self._col.dtype, ArrowDtype): - if all( - chunk.buffers()[0] is None - for chunk in self._col.array._pa_array.chunks # type: ignore[attr-defined] - ): + # We already rechunk (if necessary / allowed) upon initialization, so this + # is already single-chunk by the time we get here. + if self._col.array._pa_array.chunks[0].buffers()[0] is None: # type: ignore[attr-defined] return ColumnNullType.NON_NULLABLE, None return ColumnNullType.USE_BITMASK, 0 kind = self.dtype[0] @@ -332,7 +330,9 @@ def _get_data_buffer( dtype = self.dtype arr = self._col.array if isinstance(self._col.dtype, ArrowDtype): - arr = maybe_rechunk(arr._pa_array, allow_copy=self._allow_copy) + # We already rechunk (if necessary / allowed) upon initialization, so + # this is already single-chunk by the time we get here. + arr = arr._pa_array.chunks[0] buffer = PandasBufferPyarrow( arr.buffers()[1], # type: ignore[attr-defined] length=len(arr), @@ -391,15 +391,10 @@ def _get_validity_buffer(self) -> tuple[Buffer, Any] | None: null, invalid = self.describe_null if isinstance(self._col.dtype, ArrowDtype): - arr = self._col.array + # We already rechunk (if necessary / allowed) upon initialization, so this + # is already single-chunk by the time we get here. 
+ arr = self._col.array._pa_array.chunks[0] dtype = (DtypeKind.BOOL, 1, ArrowCTypes.BOOL, Endianness.NATIVE) - if all( - chunk.buffers()[0] is None - for chunk in arr._pa_array.chunks # type: ignore[attr-defined] - ): - return None - chunked_array = arr._pa_array - arr = maybe_rechunk(chunked_array, allow_copy=self._allow_copy) if arr.buffers()[0] is None: return None buffer: Buffer = PandasBufferPyarrow( diff --git a/pandas/core/interchange/dataframe.py b/pandas/core/interchange/dataframe.py index 1ffe0e8e8dbb0..93e608049e0b8 100644 --- a/pandas/core/interchange/dataframe.py +++ b/pandas/core/interchange/dataframe.py @@ -5,6 +5,7 @@ from pandas.core.interchange.column import PandasColumn from pandas.core.interchange.dataframe_protocol import DataFrame as DataFrameXchg +from pandas.core.interchange.utils import maybe_rechunk if TYPE_CHECKING: from collections.abc import ( @@ -34,6 +35,10 @@ def __init__(self, df: DataFrame, allow_copy: bool = True) -> None: """ self._df = df.rename(columns=str, copy=False) self._allow_copy = allow_copy + for i, col in enumerate(self._df.columns): + rechunked = maybe_rechunk(self._df.iloc[:, i], allow_copy=allow_copy) + if rechunked is not None: + self._df.isetitem(i, rechunked) def __dataframe__( self, nan_as_null: bool = False, allow_copy: bool = True diff --git a/pandas/core/interchange/utils.py b/pandas/core/interchange/utils.py index 0680c8d8cd525..2d353240c61f9 100644 --- a/pandas/core/interchange/utils.py +++ b/pandas/core/interchange/utils.py @@ -16,9 +16,9 @@ DatetimeTZDtype, ) -if typing.TYPE_CHECKING: - import pyarrow as pa +import pandas as pd +if typing.TYPE_CHECKING: from pandas._typing import DtypeObj @@ -149,13 +149,13 @@ def dtype_to_arrow_c_fmt(dtype: DtypeObj) -> str: ) -def maybe_rechunk(chunked_array: pa.ChunkedArray, *, allow_copy: bool) -> pa.Array: +def maybe_rechunk(series: pd.Series, *, allow_copy: bool) -> pd.Series | None: + if not isinstance(series.dtype, pd.ArrowDtype): + return None + chunked_array = series.array._pa_array if len(chunked_array.chunks) == 1: - arr = chunked_array.chunks[0] - else: - if not allow_copy: - raise RuntimeError( - "Found multi-chunk pyarrow array, but `allow_copy` is False" - ) - arr = chunked_array.combine_chunks() - return arr + return None + if not allow_copy: + raise RuntimeError("Found multi-chunk pyarrow array, but `allow_copy` is False") + arr = chunked_array.combine_chunks() + return pd.Series(arr, dtype=series.dtype, name=series.name, index=series.index) From 9344458a656a944084fb3b5cbe7e6e8a44da37b9 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 15 Mar 2024 16:11:10 +0000 Subject: [PATCH 06/11] remove unnecessary branch --- pandas/core/interchange/column.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pandas/core/interchange/column.py b/pandas/core/interchange/column.py index 69a0c07ddb132..003228f65464f 100644 --- a/pandas/core/interchange/column.py +++ b/pandas/core/interchange/column.py @@ -337,13 +337,6 @@ def _get_data_buffer( arr.buffers()[1], # type: ignore[attr-defined] length=len(arr), ) - if self.dtype[0] == DtypeKind.BOOL: - dtype = ( - DtypeKind.BOOL, - 1, - ArrowCTypes.BOOL, - Endianness.NATIVE, - ) return buffer, dtype if isinstance(self._col.dtype, BaseMaskedDtype): np_arr = arr._data # type: ignore[attr-defined] From c2f5bfa923a11b0674b8bb8be266a7ceb0a49245 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 15 Mar 2024 16:15:10 +0000 Subject: [PATCH 
07/11] doc maybe_rechunk

---
 pandas/core/interchange/utils.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/pandas/core/interchange/utils.py b/pandas/core/interchange/utils.py
index 2d353240c61f9..f77a0829659cb 100644
--- a/pandas/core/interchange/utils.py
+++ b/pandas/core/interchange/utils.py
@@ -150,6 +150,16 @@ def dtype_to_arrow_c_fmt(dtype: DtypeObj) -> str:
 
 
 def maybe_rechunk(series: pd.Series, *, allow_copy: bool) -> pd.Series | None:
+    """
+    Rechunk a multi-chunk pyarrow array into a single-chunk array, if necessary.
+
+    - Returns `None` if the input series is not backed by a multi-chunk pyarrow array
+      (and so doesn't need rechunking)
+    - Returns a single-chunk-backed Series if the input is backed by a multi-chunk
+      pyarrow array and `allow_copy` is `True`.
+    - Raises a `RuntimeError` if `allow_copy` is `False` and the input is
+      backed by a multi-chunk pyarrow array.
+    """
     if not isinstance(series.dtype, pd.ArrowDtype):
         return None
     chunked_array = series.array._pa_array

From e4531a0632fa9bb115d722a4e6665dde98436ce1 Mon Sep 17 00:00:00 2001
From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com>
Date: Fri, 15 Mar 2024 17:57:58 +0000
Subject: [PATCH 08/11] mypy

---
 pandas/core/interchange/column.py | 6 +++---
 pandas/core/interchange/utils.py  | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/pandas/core/interchange/column.py b/pandas/core/interchange/column.py
index 003228f65464f..8772b6eea62f3 100644
--- a/pandas/core/interchange/column.py
+++ b/pandas/core/interchange/column.py
@@ -332,7 +332,7 @@ def _get_data_buffer(
         if isinstance(self._col.dtype, ArrowDtype):
             # We already rechunk (if necessary / allowed) upon initialization, so
             # this is already single-chunk by the time we get here.
-            arr = arr._pa_array.chunks[0]
+            arr = arr._pa_array.chunks[0]  # type: ignore[attr-defined]
             buffer = PandasBufferPyarrow(
                 arr.buffers()[1],  # type: ignore[attr-defined]
                 length=len(arr),
@@ -386,12 +386,12 @@ def _get_validity_buffer(self) -> tuple[Buffer, Any] | None:
         if isinstance(self._col.dtype, ArrowDtype):
             # We already rechunk (if necessary / allowed) upon initialization, so this
             # is already single-chunk by the time we get here.
- arr = self._col.array._pa_array.chunks[0] + arr = self._col.array._pa_array.chunks[0] # type: ignore[attr-defined] dtype = (DtypeKind.BOOL, 1, ArrowCTypes.BOOL, Endianness.NATIVE) if arr.buffers()[0] is None: return None buffer: Buffer = PandasBufferPyarrow( - arr.buffers()[0], # type: ignore[attr-defined] + arr.buffers()[0], length=len(arr), ) return buffer, dtype diff --git a/pandas/core/interchange/utils.py b/pandas/core/interchange/utils.py index f77a0829659cb..77701df35fba3 100644 --- a/pandas/core/interchange/utils.py +++ b/pandas/core/interchange/utils.py @@ -162,7 +162,7 @@ def maybe_rechunk(series: pd.Series, *, allow_copy: bool) -> pd.Series | None: """ if not isinstance(series.dtype, pd.ArrowDtype): return None - chunked_array = series.array._pa_array + chunked_array = series.array._pa_array # type: ignore[attr-defined] if len(chunked_array.chunks) == 1: return None if not allow_copy: From 0d89d9710aefba7fc8a5ace3b356c158f204d5cd Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 15 Mar 2024 18:23:36 +0000 Subject: [PATCH 09/11] extra test --- pandas/core/interchange/utils.py | 6 +++++- pandas/tests/interchange/test_impl.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/pandas/core/interchange/utils.py b/pandas/core/interchange/utils.py index 77701df35fba3..2a19dd5046aa3 100644 --- a/pandas/core/interchange/utils.py +++ b/pandas/core/interchange/utils.py @@ -166,6 +166,10 @@ def maybe_rechunk(series: pd.Series, *, allow_copy: bool) -> pd.Series | None: if len(chunked_array.chunks) == 1: return None if not allow_copy: - raise RuntimeError("Found multi-chunk pyarrow array, but `allow_copy` is False") + raise RuntimeError( + "Found multi-chunk pyarrow array, but `allow_copy` is False. " + "Please rechunk the array before calling this function, or set " + "`allow_copy=True`." + ) arr = chunked_array.combine_chunks() return pd.Series(arr, dtype=series.dtype, name=series.name, index=series.index) diff --git a/pandas/tests/interchange/test_impl.py b/pandas/tests/interchange/test_impl.py index ee1efd13072c1..622a1eeee96ff 100644 --- a/pandas/tests/interchange/test_impl.py +++ b/pandas/tests/interchange/test_impl.py @@ -294,6 +294,21 @@ def test_multi_chunk_pyarrow() -> None: pd.api.interchange.from_dataframe(table, allow_copy=False) +def test_multi_chunk_column() -> None: + pytest.importorskip("pyarrow", "11.0.0") + ser = pd.Series([1, 2, None], dtype="Int64[pyarrow]") + df = pd.concat([ser, ser], ignore_index=True).to_frame("a") + with pytest.raises( + RuntimeError, match="Found multi-chunk pyarrow array, but `allow_copy` is False" + ): + pd.api.interchange.from_dataframe(df.__dataframe__(allow_copy=False)) + result = pd.api.interchange.from_dataframe(df.__dataframe__(allow_copy=True)) + # Interchange protocol defaults to creating numpy-backed columns, so currently this + # is 'float64'. 
+ expected = pd.DataFrame({"a": [1.0, 2.0, None, 1.0, 2.0, None]}, dtype="float64") + tm.assert_frame_equal(result, expected) + + def test_timestamp_ns_pyarrow(): # GH 56712 pytest.importorskip("pyarrow", "11.0.0") From d85c904a0e1120ba645006889a2b281a2ef9c36c Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 16 Mar 2024 07:46:13 +0000 Subject: [PATCH 10/11] mark _col unused, assert rechunking did not modify original df --- pandas/core/interchange/dataframe.py | 2 +- pandas/tests/interchange/test_impl.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pandas/core/interchange/dataframe.py b/pandas/core/interchange/dataframe.py index 93e608049e0b8..1abacddfc7e3b 100644 --- a/pandas/core/interchange/dataframe.py +++ b/pandas/core/interchange/dataframe.py @@ -35,7 +35,7 @@ def __init__(self, df: DataFrame, allow_copy: bool = True) -> None: """ self._df = df.rename(columns=str, copy=False) self._allow_copy = allow_copy - for i, col in enumerate(self._df.columns): + for i, _col in enumerate(self._df.columns): rechunked = maybe_rechunk(self._df.iloc[:, i], allow_copy=allow_copy) if rechunked is not None: self._df.isetitem(i, rechunked) diff --git a/pandas/tests/interchange/test_impl.py b/pandas/tests/interchange/test_impl.py index 622a1eeee96ff..83574e8630d6f 100644 --- a/pandas/tests/interchange/test_impl.py +++ b/pandas/tests/interchange/test_impl.py @@ -298,6 +298,7 @@ def test_multi_chunk_column() -> None: pytest.importorskip("pyarrow", "11.0.0") ser = pd.Series([1, 2, None], dtype="Int64[pyarrow]") df = pd.concat([ser, ser], ignore_index=True).to_frame("a") + df_orig = df.copy() with pytest.raises( RuntimeError, match="Found multi-chunk pyarrow array, but `allow_copy` is False" ): @@ -308,6 +309,11 @@ def test_multi_chunk_column() -> None: expected = pd.DataFrame({"a": [1.0, 2.0, None, 1.0, 2.0, None]}, dtype="float64") tm.assert_frame_equal(result, expected) + # Check that the rechunking we did didn't modify the original DataFrame. + tm.assert_frame_equal(df, df_orig) + assert len(df["a"].array._pa_array.chunks) == 2 + assert len(df_orig["a"].array._pa_array.chunks) == 2 + def test_timestamp_ns_pyarrow(): # GH 56712 From db0f40260b5741ec61dd9e01fd40ab7f1b7f9247 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 20 Mar 2024 18:17:12 +0000 Subject: [PATCH 11/11] declare buffer: Buffer outside of if/else branch --- pandas/core/interchange/column.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pandas/core/interchange/column.py b/pandas/core/interchange/column.py index 8772b6eea62f3..c27a9d8141712 100644 --- a/pandas/core/interchange/column.py +++ b/pandas/core/interchange/column.py @@ -307,6 +307,7 @@ def _get_data_buffer( """ Return the buffer containing the data and the buffer's associated dtype. """ + buffer: Buffer if self.dtype[0] == DtypeKind.DATETIME: # self.dtype[2] is an ArrowCTypes.TIMESTAMP where the tz will make # it longer than 4 characters @@ -314,7 +315,7 @@ def _get_data_buffer( np_arr = self._col.dt.tz_convert(None).to_numpy() else: np_arr = self._col.to_numpy() - buffer: Buffer = PandasBuffer(np_arr, allow_copy=self._allow_copy) + buffer = PandasBuffer(np_arr, allow_copy=self._allow_copy) dtype = ( DtypeKind.INT, 64, @@ -382,7 +383,7 @@ def _get_validity_buffer(self) -> tuple[Buffer, Any] | None: Raises NoBufferPresent if null representation is not a bit or byte mask. 
""" null, invalid = self.describe_null - + buffer: Buffer if isinstance(self._col.dtype, ArrowDtype): # We already rechunk (if necessary / allowed) upon initialization, so this # is already single-chunk by the time we get here. @@ -390,7 +391,7 @@ def _get_validity_buffer(self) -> tuple[Buffer, Any] | None: dtype = (DtypeKind.BOOL, 1, ArrowCTypes.BOOL, Endianness.NATIVE) if arr.buffers()[0] is None: return None - buffer: Buffer = PandasBufferPyarrow( + buffer = PandasBufferPyarrow( arr.buffers()[0], length=len(arr), )