From 53a493f0023f56501aa82cc0576eca3204e16cf2 Mon Sep 17 00:00:00 2001
From: Alexander Beedie
Date: Thu, 9 Jan 2025 16:12:56 +0400
Subject: [PATCH] feat(python): Raise `DuplicateError` if given a pyarrow Table object with duplicate column names (#20624)

---
 .../polars/_utils/construction/dataframe.py | 12 +++++++++---
 .../unit/constructors/test_constructors.py  | 17 ++++++++++++++++-
 2 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/py-polars/polars/_utils/construction/dataframe.py b/py-polars/polars/_utils/construction/dataframe.py
index b7e1bfc9ad84..0aa27457a76b 100644
--- a/py-polars/polars/_utils/construction/dataframe.py
+++ b/py-polars/polars/_utils/construction/dataframe.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import contextlib
+from collections import Counter
 from collections.abc import Generator, Mapping, Sequence
 from datetime import date, datetime, time, timedelta
 from functools import singledispatch
@@ -52,7 +53,7 @@
 from polars.dependencies import numpy as np
 from polars.dependencies import pandas as pd
 from polars.dependencies import pyarrow as pa
-from polars.exceptions import DataOrientationWarning, ShapeError
+from polars.exceptions import DataOrientationWarning, DuplicateError, ShapeError
 from polars.meta import thread_pool_size
 
 with contextlib.suppress(ImportError):  # Module not available when building docs
@@ -209,7 +210,7 @@ def _parse_schema_overrides(
 
     schema_overrides = _parse_schema_overrides(schema_overrides)
 
-    # Fast path for empty schema
+    # fast path for empty schema
     if not schema:
         columns = (
             [f"column_{i}" for i in range(n_expected)] if n_expected is not None else []
@@ -1163,7 +1164,6 @@ def arrow_to_pydf(
     column_names, schema_overrides = _unpack_schema(
         (schema or data.schema.names), schema_overrides=schema_overrides
     )
-
     try:
         if column_names != data.schema.names:
             data = data.rename_columns(column_names)
@@ -1180,6 +1180,12 @@ def arrow_to_pydf(
     # supply the arrow schema so the metadata is intact
     pydf = PyDataFrame.from_arrow_record_batches(batches, data.schema)
 
+    # arrow tables allow duplicate names; we don't
+    if len(data.columns) != pydf.width():
+        col_name, _ = Counter(column_names).most_common(1)[0]
+        msg = f"column {col_name!r} appears more than once; names must be unique"
+        raise DuplicateError(msg)
+
     if rechunk:
         pydf = pydf.rechunk()
 
diff --git a/py-polars/tests/unit/constructors/test_constructors.py b/py-polars/tests/unit/constructors/test_constructors.py
index dcff7219d8a4..0e169fae4fde 100644
--- a/py-polars/tests/unit/constructors/test_constructors.py
+++ b/py-polars/tests/unit/constructors/test_constructors.py
@@ -18,7 +18,7 @@
 from polars._utils.construction.utils import try_get_type_hints
 from polars.datatypes import numpy_char_code_to_dtype
 from polars.dependencies import dataclasses, pydantic
-from polars.exceptions import ShapeError
+from polars.exceptions import DuplicateError, ShapeError
 from polars.testing import assert_frame_equal, assert_series_equal
 
 if TYPE_CHECKING:
@@ -723,6 +723,21 @@ def test_init_arrow() -> None:
         pl.DataFrame(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}), schema=["c", "d", "e"])
 
 
+def test_init_arrow_dupes() -> None:
+    tbl = pa.Table.from_arrays(
+        arrays=[
+            pa.array([1, 2, 3], type=pa.int32()),
+            pa.array([4, 5, 6], type=pa.int32()),
+        ],
+        schema=pa.schema([("col", pa.int32()), ("col", pa.int32())]),
+    )
+    with pytest.raises(
+        DuplicateError,
+        match="column 'col' appears more than once; names must be unique",
+    ):
+        pl.DataFrame(tbl)
+
+
 def test_init_from_frame() -> None:
     df1 = pl.DataFrame({"id": [0, 1], "misc": ["a", "b"], "val": [-10, 10]})
     assert_frame_equal(df1, pl.DataFrame(df1))