ENH: Enable pytables to round-trip with StringDtype (#60663)
Co-authored-by: William Ayd <william.ayd@icloud.com>
rhshadrach and WillAyd authored Jan 23, 2025
1 parent 4c3b968 commit 60325b8
Showing 3 changed files with 87 additions and 20 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.3.0.rst
@@ -35,6 +35,7 @@ Other enhancements
- The semantics for the ``copy`` keyword in ``__array__`` methods (i.e. called
when using ``np.array()`` or ``np.asarray()`` on pandas objects) has been
updated to work correctly with NumPy >= 2 (:issue:`57739`)
- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`)
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)

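To illustrate the whatsnew entry above: a minimal round-trip sketch, assuming a pandas build that includes this change and a working PyTables installation; the file name and key are only illustrative.

import pandas as pd

ser = pd.Series(["x", pd.NA, "y"], dtype=pd.StringDtype())
ser.to_hdf("strings.h5", key="ser", mode="w")
roundtripped = pd.read_hdf("strings.h5", key="ser")
# Before this change the values came back as object dtype; now the
# string dtype is restored on read.
print(roundtripped.dtype)

Per the new tests further down, both the NaN-backed and NA-backed variants of StringDtype come back as a string dtype rather than falling back to object.
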
36 changes: 30 additions & 6 deletions pandas/io/pytables.py
@@ -86,12 +86,16 @@
PeriodArray,
)
from pandas.core.arrays.datetimes import tz_to_dtype
from pandas.core.arrays.string_ import BaseStringArray
import pandas.core.common as com
from pandas.core.computation.pytables import (
PyTablesExpr,
maybe_expression,
)
from pandas.core.construction import extract_array
from pandas.core.construction import (
array as pd_array,
extract_array,
)
from pandas.core.indexes.api import ensure_index

from pandas.io.common import stringify_path
@@ -3023,6 +3027,9 @@ def read_array(self, key: str, start: int | None = None, stop: int | None = None

if isinstance(node, tables.VLArray):
ret = node[0][start:stop]
dtype = getattr(attrs, "value_type", None)
if dtype is not None:
ret = pd_array(ret, dtype=dtype)
else:
dtype = getattr(attrs, "value_type", None)
shape = getattr(attrs, "shape", None)
@@ -3262,6 +3269,11 @@ def write_array(
elif lib.is_np_dtype(value.dtype, "m"):
self._handle.create_array(self.group, key, value.view("i8"))
getattr(self.group, key)._v_attrs.value_type = "timedelta64"
elif isinstance(value, BaseStringArray):
vlarr = self._handle.create_vlarray(self.group, key, _tables().ObjectAtom())
vlarr.append(value.to_numpy())
node = getattr(self.group, key)
node._v_attrs.value_type = str(value.dtype)
elif empty_array:
self.write_array_empty(key, value)
else:
@@ -3294,7 +3306,11 @@ def read(
index = self.read_index("index", start=start, stop=stop)
values = self.read_array("values", start=start, stop=stop)
result = Series(values, index=index, name=self.name, copy=False)
if using_string_dtype() and is_string_array(values, skipna=True):
if (
using_string_dtype()
and isinstance(values, np.ndarray)
and is_string_array(values, skipna=True)
):
result = result.astype(StringDtype(na_value=np.nan))
return result

@@ -3363,7 +3379,11 @@ def read(

columns = items[items.get_indexer(blk_items)]
df = DataFrame(values.T, columns=columns, index=axes[1], copy=False)
if using_string_dtype() and is_string_array(values, skipna=True):
if (
using_string_dtype()
and isinstance(values, np.ndarray)
and is_string_array(values, skipna=True)
):
df = df.astype(StringDtype(na_value=np.nan))
dfs.append(df)

@@ -4737,9 +4757,13 @@ def read(
df = DataFrame._from_arrays([values], columns=cols_, index=index_)
if not (using_string_dtype() and values.dtype.kind == "O"):
assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype)
if using_string_dtype() and is_string_array(
values, # type: ignore[arg-type]
skipna=True,
if (
using_string_dtype()
and isinstance(values, np.ndarray)
and is_string_array(
values,
skipna=True,
)
):
df = df.astype(StringDtype(na_value=np.nan))
frames.append(df)
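The core of the change above sits in write_array and read_array: a BaseStringArray is written as a single VLArray row of ObjectAtom, with the pandas dtype string recorded in the node's value_type attribute, and read_array uses that attribute to rebuild the extension array. A standalone sketch of that scheme at the PyTables level, with illustrative file and node names:

import pandas as pd
import tables

arr = pd.array(["x", pd.NA, "y"], dtype=pd.StringDtype())

with tables.open_file("demo.h5", mode="w") as h5:
    # Mirror the new BaseStringArray branch in write_array: store the whole
    # object ndarray as one VLArray row and record the dtype string.
    vlarr = h5.create_vlarray("/", "values", tables.ObjectAtom())
    vlarr.append(arr.to_numpy())
    vlarr._v_attrs.value_type = str(arr.dtype)

with tables.open_file("demo.h5", mode="r") as h5:
    node = h5.root.values
    raw = node[0]
    # Mirror the new read_array branch: rebuild the extension array from the
    # recorded dtype instead of returning a plain object ndarray.
    restored = pd.array(raw, dtype=node._v_attrs.value_type)

print(restored.dtype)
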
70 changes: 56 additions & 14 deletions pandas/tests/io/pytables/test_put.py
@@ -3,8 +3,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas._libs.tslibs import Timestamp

import pandas as pd
@@ -26,7 +24,6 @@

pytestmark = [
pytest.mark.single_cpu,
pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
]


@@ -54,8 +51,8 @@ def test_api_default_format(tmp_path, setup_path):
with ensure_clean_store(setup_path) as store:
df = DataFrame(
1.1 * np.arange(120).reshape((30, 4)),
columns=Index(list("ABCD"), dtype=object),
index=Index([f"i-{i}" for i in range(30)], dtype=object),
columns=Index(list("ABCD")),
index=Index([f"i-{i}" for i in range(30)]),
)

with pd.option_context("io.hdf.default_format", "fixed"):
@@ -79,8 +76,8 @@ def test_api_default_format(tmp_path, setup_path):
path = tmp_path / setup_path
df = DataFrame(
1.1 * np.arange(120).reshape((30, 4)),
columns=Index(list("ABCD"), dtype=object),
index=Index([f"i-{i}" for i in range(30)], dtype=object),
columns=Index(list("ABCD")),
index=Index([f"i-{i}" for i in range(30)]),
)

with pd.option_context("io.hdf.default_format", "fixed"):
@@ -106,7 +103,7 @@ def test_put(setup_path):
)
df = DataFrame(
np.random.default_rng(2).standard_normal((20, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=20, freq="B"),
)
store["a"] = ts
@@ -166,7 +163,7 @@ def test_put_compression(setup_path):
with ensure_clean_store(setup_path) as store:
df = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=10, freq="B"),
)

@@ -183,7 +180,7 @@ def test_put_compression_blosc(setup_path):
def test_put_compression_blosc(setup_path):
df = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=10, freq="B"),
)

@@ -197,10 +194,20 @@ def test_put_compression_blosc(setup_path):
tm.assert_frame_equal(store["c"], df)


def test_put_mixed_type(setup_path, performance_warning):
def test_put_datetime_ser(setup_path, performance_warning, using_infer_string):
# https://github.com/pandas-dev/pandas/pull/60663
ser = Series(3 * [Timestamp("20010102").as_unit("ns")])
with ensure_clean_store(setup_path) as store:
store.put("ser", ser)
expected = ser.copy()
result = store.get("ser")
tm.assert_series_equal(result, expected)


def test_put_mixed_type(setup_path, performance_warning, using_infer_string):
df = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=10, freq="B"),
)
df["obj1"] = "foo"
@@ -220,13 +227,42 @@ def test_put_mixed_type(setup_path, performance_warning):
with ensure_clean_store(setup_path) as store:
_maybe_remove(store, "df")

with tm.assert_produces_warning(performance_warning):
warning = None if using_infer_string else performance_warning
with tm.assert_produces_warning(warning):
store.put("df", df)

expected = store.get("df")
tm.assert_frame_equal(expected, df)


def test_put_str_frame(setup_path, performance_warning, string_dtype_arguments):
# https://github.com/pandas-dev/pandas/pull/60663
dtype = pd.StringDtype(*string_dtype_arguments)
df = DataFrame({"a": pd.array(["x", pd.NA, "y"], dtype=dtype)})
with ensure_clean_store(setup_path) as store:
_maybe_remove(store, "df")

store.put("df", df)
expected_dtype = "str" if dtype.na_value is np.nan else "string"
expected = df.astype(expected_dtype)
result = store.get("df")
tm.assert_frame_equal(result, expected)


def test_put_str_series(setup_path, performance_warning, string_dtype_arguments):
# https://github.com/pandas-dev/pandas/pull/60663
dtype = pd.StringDtype(*string_dtype_arguments)
ser = Series(["x", pd.NA, "y"], dtype=dtype)
with ensure_clean_store(setup_path) as store:
_maybe_remove(store, "df")

store.put("ser", ser)
expected_dtype = "str" if dtype.na_value is np.nan else "string"
expected = ser.astype(expected_dtype)
result = store.get("ser")
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("format", ["table", "fixed"])
@pytest.mark.parametrize(
"index",
@@ -253,7 +289,7 @@ def test_store_index_types(setup_path, format, index):
tm.assert_frame_equal(df, store["df"])


def test_column_multiindex(setup_path):
def test_column_multiindex(setup_path, using_infer_string):
# GH 4710
# recreate multi-indexes properly

@@ -264,6 +300,12 @@ def test_column_multiindex(setup_path):
expected = df.set_axis(df.index.to_numpy())

with ensure_clean_store(setup_path) as store:
if using_infer_string:
# TODO(infer_string) make this work for string dtype
msg = "Saving a MultiIndex with an extension dtype is not supported."
with pytest.raises(NotImplementedError, match=msg):
store.put("df", df)
return
store.put("df", df)
tm.assert_frame_equal(
store["df"], expected, check_index_type=True, check_column_type=True
