Skip to content

Commit

Permalink
BUG: Interchange object data buffer has the wrong dtype / from_datafr…
Browse files Browse the repository at this point in the history
…ame incorrect (pandas-dev#57570)

string
  • Loading branch information
MarcoGorelli authored and pmhatre1 committed May 7, 2024
1 parent bf60b12 commit dd9fda3
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 11 deletions.
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,11 @@ repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-ast
- id: check-case-conflict
- id: check-toml
- id: check-xml
- id: check-yaml
exclude: ^ci/meta.yaml$
- id: debug-statements
- id: end-of-file-fixer
exclude: \.txt$
- id: mixed-line-ending
Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ Other
- Bug in :func:`tseries.api.guess_datetime_format` would fail to infer time format when "%Y" == "%H%M" (:issue:`57452`)
- Bug in :meth:`DataFrame.sort_index` when passing ``axis="columns"`` and ``ignore_index=True`` and ``ascending=False`` not returning a :class:`RangeIndex` columns (:issue:`57293`)
- Bug in :meth:`DataFrame.where` where using a non-bool type array in the function would return a ``ValueError`` instead of a ``TypeError`` (:issue:`56330`)
- Bug in Dataframe Interchange Protocol implementation was returning incorrect results for data buffers' associated dtype, for string and datetime columns (:issue:`54781`)

.. ***DO NOT USE THIS SECTION***
Expand Down
31 changes: 22 additions & 9 deletions pandas/core/interchange/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,20 +278,28 @@ def _get_data_buffer(
"""
Return the buffer containing the data and the buffer's associated dtype.
"""
if self.dtype[0] in (
DtypeKind.INT,
DtypeKind.UINT,
DtypeKind.FLOAT,
DtypeKind.BOOL,
DtypeKind.DATETIME,
):
if self.dtype[0] == DtypeKind.DATETIME:
# self.dtype[2] is an ArrowCTypes.TIMESTAMP where the tz will make
# it longer than 4 characters
if self.dtype[0] == DtypeKind.DATETIME and len(self.dtype[2]) > 4:
if len(self.dtype[2]) > 4:
np_arr = self._col.dt.tz_convert(None).to_numpy()
else:
np_arr = self._col.to_numpy()
buffer = PandasBuffer(np_arr, allow_copy=self._allow_copy)
dtype = (
DtypeKind.INT,
64,
ArrowCTypes.INT64,
Endianness.NATIVE,
)
elif self.dtype[0] in (
DtypeKind.INT,
DtypeKind.UINT,
DtypeKind.FLOAT,
DtypeKind.BOOL,
):
np_arr = self._col.to_numpy()
buffer = PandasBuffer(np_arr, allow_copy=self._allow_copy)
dtype = self.dtype
elif self.dtype[0] == DtypeKind.CATEGORICAL:
codes = self._col.values._codes
Expand All @@ -314,7 +322,12 @@ def _get_data_buffer(
# Define the dtype for the returned buffer
# TODO: this will need correcting
# https://github.com/pandas-dev/pandas/issues/54781
dtype = self.dtype
dtype = (
DtypeKind.UINT,
8,
ArrowCTypes.UINT8,
Endianness.NATIVE,
) # note: currently only support native endianness
else:
raise NotImplementedError(f"Data type {self._col.dtype} not handled yet")

Expand Down
45 changes: 45 additions & 0 deletions pandas/tests/interchange/test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,3 +435,48 @@ def test_empty_dataframe():
result = pd.api.interchange.from_dataframe(dfi, allow_copy=False)
expected = pd.DataFrame({"a": []}, dtype="int8")
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize(
("data", "expected_dtype", "expected_buffer_dtype"),
[
(
pd.Series(["a", "b", "a"], dtype="category"),
(DtypeKind.CATEGORICAL, 8, "c", "="),
(DtypeKind.INT, 8, "c", "|"),
),
(
pd.Series(
[datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2022, 1, 3)]
),
(DtypeKind.DATETIME, 64, "tsn:", "="),
(DtypeKind.INT, 64, ArrowCTypes.INT64, "="),
),
(
pd.Series(["a", "bc", None]),
(DtypeKind.STRING, 8, ArrowCTypes.STRING, "="),
(DtypeKind.UINT, 8, ArrowCTypes.UINT8, "="),
),
(
pd.Series([1, 2, 3]),
(DtypeKind.INT, 64, ArrowCTypes.INT64, "="),
(DtypeKind.INT, 64, ArrowCTypes.INT64, "="),
),
(
pd.Series([1.5, 2, 3]),
(DtypeKind.FLOAT, 64, ArrowCTypes.FLOAT64, "="),
(DtypeKind.FLOAT, 64, ArrowCTypes.FLOAT64, "="),
),
],
)
def test_buffer_dtype_categorical(
data: pd.Series,
expected_dtype: tuple[DtypeKind, int, str, str],
expected_buffer_dtype: tuple[DtypeKind, int, str, str],
) -> None:
# https://github.com/pandas-dev/pandas/issues/54781
df = pd.DataFrame({"data": data})
dfi = df.__dataframe__()
col = dfi.get_column_by_name("data")
assert col.dtype == expected_dtype
assert col.get_buffers()["data"][1] == expected_buffer_dtype

0 comments on commit dd9fda3

Please sign in to comment.