Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Raise DuplicateError if given a pyarrow Table object with duplicate column names #20624

Merged
merged 2 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions py-polars/polars/_utils/construction/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
from collections import Counter
from collections.abc import Generator, Mapping, Sequence
from datetime import date, datetime, time, timedelta
from functools import singledispatch
Expand Down Expand Up @@ -52,7 +53,7 @@
from polars.dependencies import numpy as np
from polars.dependencies import pandas as pd
from polars.dependencies import pyarrow as pa
from polars.exceptions import DataOrientationWarning, ShapeError
from polars.exceptions import DataOrientationWarning, DuplicateError, ShapeError
from polars.meta import thread_pool_size

with contextlib.suppress(ImportError): # Module not available when building docs
Expand Down Expand Up @@ -209,7 +210,7 @@ def _parse_schema_overrides(

schema_overrides = _parse_schema_overrides(schema_overrides)

# Fast path for empty schema
# fast path for empty schema
if not schema:
columns = (
[f"column_{i}" for i in range(n_expected)] if n_expected is not None else []
Expand Down Expand Up @@ -1163,14 +1164,19 @@ def arrow_to_pydf(
column_names, schema_overrides = _unpack_schema(
(schema or data.schema.names), schema_overrides=schema_overrides
)

try:
if column_names != data.schema.names:
data = data.rename_columns(column_names)
except pa.lib.ArrowInvalid as e:
msg = "dimensions of columns arg must match data dimensions"
raise ValueError(msg) from e

# arrow tables allow duplicate names; we don't
if len(column_names) != len(set(column_names)):
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
col_name, col_count = Counter(column_names).most_common(1)[0]
msg = f"column {col_name!r} appears {col_count} times; names must be unique"
raise DuplicateError(msg)

batches: list[pa.RecordBatch]
if isinstance(data, pa.RecordBatch):
batches = [data]
Expand Down
17 changes: 16 additions & 1 deletion py-polars/tests/unit/constructors/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from polars._utils.construction.utils import try_get_type_hints
from polars.datatypes import numpy_char_code_to_dtype
from polars.dependencies import dataclasses, pydantic
from polars.exceptions import ShapeError
from polars.exceptions import DuplicateError, ShapeError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
Expand Down Expand Up @@ -723,6 +723,21 @@ def test_init_arrow() -> None:
pl.DataFrame(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}), schema=["c", "d", "e"])


def test_init_arrow_dupes() -> None:
    """Check that an arrow Table with repeated column names is rejected.

    Builds a two-column int32 pyarrow Table in which both columns are named
    "col", then asserts that passing it to `pl.DataFrame` raises
    `DuplicateError` with a message naming the column and its count.
    """
    int32 = pa.int32()
    # both fields deliberately share the same name
    dupe_schema = pa.schema([("col", int32), ("col", int32)])
    column_data = [
        pa.array([1, 2, 3], type=int32),
        pa.array([4, 5, 6], type=int32),
    ]
    tbl = pa.Table.from_arrays(arrays=column_data, schema=dupe_schema)

    expected_msg = "column 'col' appears 2 times; names must be unique"
    with pytest.raises(DuplicateError, match=expected_msg):
        pl.DataFrame(tbl)


def test_init_from_frame() -> None:
df1 = pl.DataFrame({"id": [0, 1], "misc": ["a", "b"], "val": [-10, 10]})
assert_frame_equal(df1, pl.DataFrame(df1))
Expand Down
Loading