From 60325b86e28edf40cb02444367efbc8deb2b5231 Mon Sep 17 00:00:00 2001 From: Richard Shadrach <45562402+rhshadrach@users.noreply.github.com> Date: Thu, 23 Jan 2025 02:38:26 -0500 Subject: [PATCH] ENH: Enable pytables to round-trip with StringDtype (#60663) Co-authored-by: William Ayd --- doc/source/whatsnew/v2.3.0.rst | 1 + pandas/io/pytables.py | 36 +++++++++++--- pandas/tests/io/pytables/test_put.py | 70 ++++++++++++++++++++++------ 3 files changed, 87 insertions(+), 20 deletions(-) diff --git a/doc/source/whatsnew/v2.3.0.rst b/doc/source/whatsnew/v2.3.0.rst index de1118b56dc81..108ee62d88409 100644 --- a/doc/source/whatsnew/v2.3.0.rst +++ b/doc/source/whatsnew/v2.3.0.rst @@ -35,6 +35,7 @@ Other enhancements - The semantics for the ``copy`` keyword in ``__array__`` methods (i.e. called when using ``np.array()`` or ``np.asarray()`` on pandas objects) has been updated to work correctly with NumPy >= 2 (:issue:`57739`) +- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`) - The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`) - The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`) diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index b75dc6c3a43b4..2f8096746318b 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -86,12 +86,16 @@ PeriodArray, ) from pandas.core.arrays.datetimes import tz_to_dtype +from pandas.core.arrays.string_ import BaseStringArray import pandas.core.common as com from pandas.core.computation.pytables import ( PyTablesExpr, maybe_expression, ) -from pandas.core.construction import extract_array +from pandas.core.construction import ( + array as pd_array, + extract_array, +) from pandas.core.indexes.api import ensure_index from pandas.io.common import stringify_path @@ -3023,6 +3027,9 @@ def read_array(self, key: str, start: int | None = None, stop: int | None = None if isinstance(node, tables.VLArray): ret = node[0][start:stop] + dtype = getattr(attrs, "value_type", None) + if dtype is not None: + ret = pd_array(ret, dtype=dtype) else: dtype = getattr(attrs, "value_type", None) shape = getattr(attrs, "shape", None) @@ -3262,6 +3269,11 @@ def write_array( elif lib.is_np_dtype(value.dtype, "m"): self._handle.create_array(self.group, key, value.view("i8")) getattr(self.group, key)._v_attrs.value_type = "timedelta64" + elif isinstance(value, BaseStringArray): + vlarr = self._handle.create_vlarray(self.group, key, _tables().ObjectAtom()) + vlarr.append(value.to_numpy()) + node = getattr(self.group, key) + node._v_attrs.value_type = str(value.dtype) elif empty_array: self.write_array_empty(key, value) else: @@ -3294,7 +3306,11 @@ def read( index = self.read_index("index", start=start, stop=stop) values = self.read_array("values", start=start, stop=stop) result = Series(values, index=index, name=self.name, copy=False) - if using_string_dtype() and is_string_array(values, skipna=True): + if ( + using_string_dtype() + and isinstance(values, np.ndarray) + and is_string_array(values, skipna=True) + ): result = result.astype(StringDtype(na_value=np.nan)) return result @@ -3363,7 +3379,11 @@ def read( columns = items[items.get_indexer(blk_items)] df = DataFrame(values.T, columns=columns, index=axes[1], copy=False) - if using_string_dtype() and is_string_array(values, skipna=True): + if ( + using_string_dtype() + and isinstance(values, np.ndarray) + and is_string_array(values, skipna=True) + ): df = df.astype(StringDtype(na_value=np.nan)) dfs.append(df) @@ -4737,9 +4757,13 @@ def read( df = DataFrame._from_arrays([values], columns=cols_, index=index_) if not (using_string_dtype() and values.dtype.kind == "O"): assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype) - if using_string_dtype() and is_string_array( - values, # type: ignore[arg-type] - skipna=True, + if ( + using_string_dtype() + and isinstance(values, np.ndarray) + and is_string_array( + values, + skipna=True, + ) ): df = df.astype(StringDtype(na_value=np.nan)) frames.append(df) diff --git a/pandas/tests/io/pytables/test_put.py b/pandas/tests/io/pytables/test_put.py index a4257b54dd6db..66596f1138b96 100644 --- a/pandas/tests/io/pytables/test_put.py +++ b/pandas/tests/io/pytables/test_put.py @@ -3,8 +3,6 @@ import numpy as np import pytest -from pandas._config import using_string_dtype - from pandas._libs.tslibs import Timestamp import pandas as pd @@ -26,7 +24,6 @@ pytestmark = [ pytest.mark.single_cpu, - pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False), ] @@ -54,8 +51,8 @@ def test_api_default_format(tmp_path, setup_path): with ensure_clean_store(setup_path) as store: df = DataFrame( 1.1 * np.arange(120).reshape((30, 4)), - columns=Index(list("ABCD"), dtype=object), - index=Index([f"i-{i}" for i in range(30)], dtype=object), + columns=Index(list("ABCD")), + index=Index([f"i-{i}" for i in range(30)]), ) with pd.option_context("io.hdf.default_format", "fixed"): @@ -79,8 +76,8 @@ def test_api_default_format(tmp_path, setup_path): path = tmp_path / setup_path df = DataFrame( 1.1 * np.arange(120).reshape((30, 4)), - columns=Index(list("ABCD"), dtype=object), - index=Index([f"i-{i}" for i in range(30)], dtype=object), + columns=Index(list("ABCD")), + index=Index([f"i-{i}" for i in range(30)]), ) with pd.option_context("io.hdf.default_format", "fixed"): @@ -106,7 +103,7 @@ def test_put(setup_path): ) df = DataFrame( np.random.default_rng(2).standard_normal((20, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=20, freq="B"), ) store["a"] = ts @@ -166,7 +163,7 @@ def test_put_compression(setup_path): with ensure_clean_store(setup_path) as store: df = DataFrame( np.random.default_rng(2).standard_normal((10, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=10, freq="B"), ) @@ -183,7 +180,7 @@ def test_put_compression(setup_path): def test_put_compression_blosc(setup_path): df = DataFrame( np.random.default_rng(2).standard_normal((10, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=10, freq="B"), ) @@ -197,10 +194,20 @@ def test_put_compression_blosc(setup_path): tm.assert_frame_equal(store["c"], df) -def test_put_mixed_type(setup_path, performance_warning): +def test_put_datetime_ser(setup_path, performance_warning, using_infer_string): + # https://github.com/pandas-dev/pandas/pull/60663 + ser = Series(3 * [Timestamp("20010102").as_unit("ns")]) + with ensure_clean_store(setup_path) as store: + store.put("ser", ser) + expected = ser.copy() + result = store.get("ser") + tm.assert_series_equal(result, expected) + + +def test_put_mixed_type(setup_path, performance_warning, using_infer_string): df = DataFrame( np.random.default_rng(2).standard_normal((10, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=10, freq="B"), ) df["obj1"] = "foo" @@ -220,13 +227,42 @@ def test_put_mixed_type(setup_path, performance_warning): with ensure_clean_store(setup_path) as store: _maybe_remove(store, "df") - with tm.assert_produces_warning(performance_warning): + warning = None if using_infer_string else performance_warning + with tm.assert_produces_warning(warning): store.put("df", df) expected = store.get("df") tm.assert_frame_equal(expected, df) +def test_put_str_frame(setup_path, performance_warning, string_dtype_arguments): + # https://github.com/pandas-dev/pandas/pull/60663 + dtype = pd.StringDtype(*string_dtype_arguments) + df = DataFrame({"a": pd.array(["x", pd.NA, "y"], dtype=dtype)}) + with ensure_clean_store(setup_path) as store: + _maybe_remove(store, "df") + + store.put("df", df) + expected_dtype = "str" if dtype.na_value is np.nan else "string" + expected = df.astype(expected_dtype) + result = store.get("df") + tm.assert_frame_equal(result, expected) + + +def test_put_str_series(setup_path, performance_warning, string_dtype_arguments): + # https://github.com/pandas-dev/pandas/pull/60663 + dtype = pd.StringDtype(*string_dtype_arguments) + ser = Series(["x", pd.NA, "y"], dtype=dtype) + with ensure_clean_store(setup_path) as store: + _maybe_remove(store, "df") + + store.put("ser", ser) + expected_dtype = "str" if dtype.na_value is np.nan else "string" + expected = ser.astype(expected_dtype) + result = store.get("ser") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("format", ["table", "fixed"]) @pytest.mark.parametrize( "index", @@ -253,7 +289,7 @@ def test_store_index_types(setup_path, format, index): tm.assert_frame_equal(df, store["df"]) -def test_column_multiindex(setup_path): +def test_column_multiindex(setup_path, using_infer_string): # GH 4710 # recreate multi-indexes properly @@ -264,6 +300,12 @@ def test_column_multiindex(setup_path): expected = df.set_axis(df.index.to_numpy()) with ensure_clean_store(setup_path) as store: + if using_infer_string: + # TODO(infer_string) make this work for string dtype + msg = "Saving a MultiIndex with an extension dtype is not supported." + with pytest.raises(NotImplementedError, match=msg): + store.put("df", df) + return store.put("df", df) tm.assert_frame_equal( store["df"], expected, check_index_type=True, check_column_type=True