From 8995e81c8b8a479b084161b2fb0fccbef4e3a20e Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Wed, 8 Jan 2025 17:07:21 +0400 Subject: [PATCH 1/2] feat: Raise `DuplicateError` if given a pyarrow Table object with duplicate column names --- .../polars/_utils/construction/dataframe.py | 12 +++++++++--- .../unit/constructors/test_constructors.py | 17 ++++++++++++++++- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/py-polars/polars/_utils/construction/dataframe.py b/py-polars/polars/_utils/construction/dataframe.py index b7e1bfc9ad84..9efefb4f69b4 100644 --- a/py-polars/polars/_utils/construction/dataframe.py +++ b/py-polars/polars/_utils/construction/dataframe.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +from collections import Counter from collections.abc import Generator, Mapping, Sequence from datetime import date, datetime, time, timedelta from functools import singledispatch @@ -52,7 +53,7 @@ from polars.dependencies import numpy as np from polars.dependencies import pandas as pd from polars.dependencies import pyarrow as pa -from polars.exceptions import DataOrientationWarning, ShapeError +from polars.exceptions import DataOrientationWarning, DuplicateError, ShapeError from polars.meta import thread_pool_size with contextlib.suppress(ImportError): # Module not available when building docs @@ -209,7 +210,7 @@ def _parse_schema_overrides( schema_overrides = _parse_schema_overrides(schema_overrides) - # Fast path for empty schema + # fast path for empty schema if not schema: columns = ( [f"column_{i}" for i in range(n_expected)] if n_expected is not None else [] @@ -1163,7 +1164,6 @@ def arrow_to_pydf( column_names, schema_overrides = _unpack_schema( (schema or data.schema.names), schema_overrides=schema_overrides ) - try: if column_names != data.schema.names: data = data.rename_columns(column_names) @@ -1171,6 +1171,12 @@ def arrow_to_pydf( msg = "dimensions of columns arg must match data dimensions" raise ValueError(msg) from e + # arrow tables allow duplicate names; we don't + if len(column_names) != len(set(column_names)): + col_name, col_count = Counter(column_names).most_common(1)[0] + msg = f"column {col_name!r} appears {col_count} times; names must be unique" + raise DuplicateError(msg) + batches: list[pa.RecordBatch] if isinstance(data, pa.RecordBatch): batches = [data] diff --git a/py-polars/tests/unit/constructors/test_constructors.py b/py-polars/tests/unit/constructors/test_constructors.py index dcff7219d8a4..30caf2930acd 100644 --- a/py-polars/tests/unit/constructors/test_constructors.py +++ b/py-polars/tests/unit/constructors/test_constructors.py @@ -18,7 +18,7 @@ from polars._utils.construction.utils import try_get_type_hints from polars.datatypes import numpy_char_code_to_dtype from polars.dependencies import dataclasses, pydantic -from polars.exceptions import ShapeError +from polars.exceptions import DuplicateError, ShapeError from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: @@ -723,6 +723,21 @@ def test_init_arrow() -> None: pl.DataFrame(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}), schema=["c", "d", "e"]) +def test_init_arrow_dupes() -> None: + tbl = pa.Table.from_arrays( + arrays=[ + pa.array([1, 2, 3], type=pa.int32()), + pa.array([4, 5, 6], type=pa.int32()), + ], + schema=pa.schema([("col", pa.int32()), ("col", pa.int32())]), + ) + with pytest.raises( + DuplicateError, + match="column 'col' appears 2 times; names must be unique", + ): + pl.DataFrame(tbl) + + def test_init_from_frame() -> None: df1 = pl.DataFrame({"id": [0, 1], "misc": ["a", "b"], "val": [-10, 10]}) assert_frame_equal(df1, pl.DataFrame(df1)) From 9d1cba4b81d9fbb87fc2b401c2658b4cb8b05c07 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Thu, 9 Jan 2025 12:34:09 +0400 Subject: [PATCH 2/2] make duplicate column check optimistic --- py-polars/polars/_utils/construction/dataframe.py | 12 ++++++------ .../tests/unit/constructors/test_constructors.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/py-polars/polars/_utils/construction/dataframe.py b/py-polars/polars/_utils/construction/dataframe.py index 9efefb4f69b4..0aa27457a76b 100644 --- a/py-polars/polars/_utils/construction/dataframe.py +++ b/py-polars/polars/_utils/construction/dataframe.py @@ -1171,12 +1171,6 @@ def arrow_to_pydf( msg = "dimensions of columns arg must match data dimensions" raise ValueError(msg) from e - # arrow tables allow duplicate names; we don't - if len(column_names) != len(set(column_names)): - col_name, col_count = Counter(column_names).most_common(1)[0] - msg = f"column {col_name!r} appears {col_count} times; names must be unique" - raise DuplicateError(msg) - batches: list[pa.RecordBatch] if isinstance(data, pa.RecordBatch): batches = [data] @@ -1186,6 +1180,12 @@ def arrow_to_pydf( # supply the arrow schema so the metadata is intact pydf = PyDataFrame.from_arrow_record_batches(batches, data.schema) + # arrow tables allow duplicate names; we don't + if len(data.columns) != pydf.width(): + col_name, _ = Counter(column_names).most_common(1)[0] + msg = f"column {col_name!r} appears more than once; names must be unique" + raise DuplicateError(msg) + if rechunk: pydf = pydf.rechunk() diff --git a/py-polars/tests/unit/constructors/test_constructors.py b/py-polars/tests/unit/constructors/test_constructors.py index 30caf2930acd..0e169fae4fde 100644 --- a/py-polars/tests/unit/constructors/test_constructors.py +++ b/py-polars/tests/unit/constructors/test_constructors.py @@ -733,7 +733,7 @@ def test_init_arrow_dupes() -> None: ) with pytest.raises( DuplicateError, - match="column 'col' appears 2 times; names must be unique", + match="column 'col' appears more than once; names must be unique", ): pl.DataFrame(tbl)