From d3a14dec5563796f4b88c95f89e99bbefd433853 Mon Sep 17 00:00:00 2001
From: barak1412
Date: Mon, 9 Sep 2024 09:34:56 +0300
Subject: [PATCH] fix(rust): Indicative error in `list.gather` when wrong indices type is supplied (#18611)

---
Reviewer notes (placed after the separator, so not part of the commit message):

* namespace.rs: the `List(_)` arm of `match idx.dtype()` gains a guard and now
  only matches list indices whose inner dtype satisfies `is_integer()`.
* test_list.py: the new test gathers with list indices of every integer width
  (Int8..Int64, UInt8..UInt64) and expects a ComputeError matching
  "cannot use dtype `list[str]` as an index" when the indices are list[str].
* NOTE(review): this patch was recovered from a copy whose line breaks and
  indentation had been collapsed. Leading whitespace and the position of blank
  context lines were rebuilt from the hunk line counts; confirm them against
  the target tree (or apply with whitespace tolerance) before relying on it.

 .../src/chunked_array/list/namespace.rs       |  2 +-
 .../operations/namespaces/list/test_list.py   | 71 +++++++++++++++++++
 2 files changed, 72 insertions(+), 1 deletion(-)

diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs
index 02dc0fe3e68c..0c7a0975488c 100644
--- a/crates/polars-ops/src/chunked_array/list/namespace.rs
+++ b/crates/polars-ops/src/chunked_array/list/namespace.rs
@@ -447,7 +447,7 @@ pub trait ListNameSpaceImpl: AsList {

         use DataType::*;
         match idx.dtype() {
-            List(_) => {
+            List(boxed_dt) if boxed_dt.is_integer() => {
                 let idx_ca = idx.list().unwrap();
                 let mut out = {
                     list_ca
diff --git a/py-polars/tests/unit/operations/namespaces/list/test_list.py b/py-polars/tests/unit/operations/namespaces/list/test_list.py
index 77ed41f5bba3..f306bbff5d7b 100644
--- a/py-polars/tests/unit/operations/namespaces/list/test_list.py
+++ b/py-polars/tests/unit/operations/namespaces/list/test_list.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import re
 from datetime import date, datetime

 import numpy as np
@@ -159,6 +160,76 @@ def test_list_categorical_get() -> None:
     )


+def test_list_gather_wrong_indices_list_type() -> None:
+    a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]])
+    expected = pl.Series("a", [[1, 2], [4], [6, 9]])
+
+    # int8
+    indices_series = pl.Series("indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.Int8))
+    result = a.list.gather(indices=indices_series)
+    assert_series_equal(result, expected)
+
+    # int16
+    indices_series = pl.Series(
+        "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.Int16)
+    )
+    result = a.list.gather(indices=indices_series)
+    assert_series_equal(result, expected)
+
+    # int32
+    indices_series = pl.Series(
+        "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.Int32)
+    )
+    result = a.list.gather(indices=indices_series)
+    assert_series_equal(result, expected)
+
+    # int64
+    indices_series = pl.Series(
+        "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.Int64)
+    )
+    result = a.list.gather(indices=indices_series)
+    assert_series_equal(result, expected)
+
+    # uint8
+    indices_series = pl.Series(
+        "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.UInt8)
+    )
+    result = a.list.gather(indices=indices_series)
+    assert_series_equal(result, expected)
+
+    # uint16
+    indices_series = pl.Series(
+        "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.UInt16)
+    )
+    result = a.list.gather(indices=indices_series)
+    assert_series_equal(result, expected)
+
+    # uint32
+    indices_series = pl.Series(
+        "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.UInt32)
+    )
+    result = a.list.gather(indices=indices_series)
+    assert_series_equal(result, expected)
+
+    # uint64
+    indices_series = pl.Series(
+        "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.UInt64)
+    )
+    result = a.list.gather(indices=indices_series)
+    assert_series_equal(result, expected)
+
+    df = pl.DataFrame(
+        {
+            "index": [["2"], ["2"], ["2"]],
+            "lists": [[3, 4, 5], [4, 5, 6], [7, 8, 9, 4]],
+        }
+    )
+    with pytest.raises(
+        ComputeError, match=re.escape("cannot use dtype `list[str]` as an index")
+    ):
+        df.select(pl.col("lists").list.gather(pl.col("index")))
+
+
 def test_contains() -> None:
     a = pl.Series("a", [[1, 2, 3], [2, 5], [6, 7, 8, 9]])
     out = a.list.contains(2)