From 01d0df1f4dedefbad8c0d6d388e380eaedd47854 Mon Sep 17 00:00:00 2001
From: Nishant Bhakar
Date: Tue, 4 Feb 2025 17:06:47 -0800
Subject: [PATCH] fix(agg): use set aggregation in first stage instead of list

The first aggregation stage previously collected values with a list
aggregation, so duplicates survived partition-local work and were only
dropped after the cross-partition merge. Deduplicate in the first stage
with a set aggregation instead, and update the tests to check that
results are hashable and unique across partitions.
---
 .../src/physical_planner/translate.rs |   2 +-
 tests/cookbook/test_aggregations.py   |  23 ++--
 tests/dataframe/test_aggregations.py  | 102 ++++++++++++------
 tests/table/test_table_aggs.py        |   7 ++
 4 files changed, 93 insertions(+), 41 deletions(-)

diff --git a/src/daft-physical-plan/src/physical_planner/translate.rs b/src/daft-physical-plan/src/physical_planner/translate.rs
index c3994fa4bd..801efcd2dc 100644
--- a/src/daft-physical-plan/src/physical_planner/translate.rs
+++ b/src/daft-physical-plan/src/physical_planner/translate.rs
@@ -1100,7 +1100,7 @@ pub fn populate_aggregation_stages(
         }
         AggExpr::Set(e) => {
             let list_agg_id =
-                add_to_stage(AggExpr::List, e.clone(), schema, &mut first_stage_aggs);
+                add_to_stage(AggExpr::Set, e.clone(), schema, &mut first_stage_aggs);
             let list_concat_id = add_to_stage(
                 AggExpr::Concat,
                 col(list_agg_id.clone()),
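Context for the one-line change: populate_aggregation_stages splits a global set aggregation into a partition-local first stage and a cross-partition merge (the Concat just below the changed line). With AggExpr::List in the first stage, every duplicate travelled through the shuffle; with AggExpr::Set, each partition deduplicates before shipping its partial result. A minimal pure-Python sketch of the staged flow, assuming (as the new tests imply) a final dedup pass after the concat; the function below is illustrative, not Daft code:

    def simulate_set_agg(partitions):
        # First stage, per partition: dedup locally. This is the line the
        # patch changes from AggExpr::List to AggExpr::Set.
        first_stage = [list(dict.fromkeys(p)) for p in partitions]
        # Second stage: concatenate the partial lists from every partition.
        concatenated = [x for part in first_stage for x in part]
        # Final pass: global dedup, so cross-partition duplicates collapse too.
        return list(dict.fromkeys(concatenated))

    # Partition 1 holds [1, 1, 1], partition 2 holds [1, 2]; the first stage
    # now ships [1] and [1, 2] instead of the full lists.
    assert simulate_set_agg([[1, 1, 1], [1, 2]]) == [1, 2]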
diff --git a/tests/cookbook/test_aggregations.py b/tests/cookbook/test_aggregations.py
index a581be3bdc..086cea6739 100644
--- a/tests/cookbook/test_aggregations.py
+++ b/tests/cookbook/test_aggregations.py
@@ -10,6 +10,7 @@
 from daft.expressions import col
 from daft.udf import udf
 from tests.conftest import assert_df_equals
+from tests.dataframe.test_aggregations import _assert_all_hashable
 
 
 def test_sum(daft_df, service_requests_csv_pd_df, repartition_nparts, with_morsel_size):
@@ -88,15 +89,19 @@ def test_list(daft_df, service_requests_csv_pd_df, repartition_nparts, with_mors
     assert set(result_list[0]) == set(unique_key_list)
 
 
-def test_set(daft_df, service_requests_csv_pd_df, repartition_nparts, with_morsel_size):
-    """Set agg a column for entire table to get unique values."""
-    daft_df = daft_df.repartition(repartition_nparts).agg_set(col("Unique Key").alias("unique_key_set")).collect()
-    unique_key_list = service_requests_csv_pd_df["Unique Key"].drop_duplicates().to_list()
-
-    result_list = daft_df.to_pydict()["unique_key_set"]
-    assert len(result_list) == 1
-    assert len(result_list[0]) == len(set(result_list[0])), "Result should contain no duplicates"
-    assert set(result_list[0]) == set(unique_key_list), "Sets should contain same elements"
+@pytest.mark.parametrize("npartitions", [1, 2])
+def test_set(make_df, npartitions, with_morsel_size):
+    df = make_df(
+        {
+            "values": [1, 2, 2, 3, 3, 3],
+        },
+        repartition=npartitions,
+    )
+    df = df.agg([col("values").agg_set().alias("set")])
+    df.collect()
+    result = df.to_pydict()["set"][0]
+    _assert_all_hashable(result, "test_set")
+    assert set(result) == {1, 2, 3}
 
 
 def test_global_agg(daft_df, service_requests_csv_pd_df, repartition_nparts, with_morsel_size):
diff --git a/tests/dataframe/test_aggregations.py b/tests/dataframe/test_aggregations.py
index d153b9c133..57f6228638 100644
--- a/tests/dataframe/test_aggregations.py
+++ b/tests/dataframe/test_aggregations.py
@@ -11,11 +11,37 @@
 from daft.context import get_context
 from daft.datatype import DataType
 from daft.errors import ExpressionTypeError
-from daft.exceptions import DaftTypeError
 from daft.utils import freeze
 from tests.utils import sort_arrow_table
 
 
+def _assert_all_hashable(values, test_name=""):
+    """Assert that every element of an iterable is hashable.
+
+    Args:
+        values: Iterable of values to check.
+        test_name: Name of the calling test, used for clearer error messages.
+
+    Raises:
+        AssertionError: If any element is not hashable; the message lists every unhashable value.
+    """
+    unhashable = []
+    for val in values:
+        if val is not None:  # None is always hashable, so it never needs checking.
+            try:
+                hash(val)
+            except TypeError:
+                unhashable.append((type(val), val))
+
+    if unhashable:
+        details = "\n".join(f"  - {t.__name__}: {v}" for t, v in unhashable)
+        raise AssertionError(
+            f"{test_name}: Found {len(unhashable)} unhashable value(s) in values: {values}\n"
+            f"Unhashable values:\n{details}\n"
+            "Set aggregation requires all elements to be hashable."
+        )
+
+
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
 def test_agg_global(make_df, repartition_nparts, with_morsel_size):
     daft_df = make_df(
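The tests below lean on this helper: set aggregation is only meaningful when every returned element is hashable, and the helper fails with a descriptive message instead of letting an unhashable element slip into later set() comparisons. A hypothetical usage sketch (the values are made up, and it assumes _assert_all_hashable is in scope):

    import pytest

    def test_helper_flags_unhashable_values():
        # hash([2, 3]) raises TypeError, so the helper should raise an
        # AssertionError naming the offending element and its type.
        with pytest.raises(AssertionError, match="unhashable value"):
            _assert_all_hashable([1, [2, 3]], "example")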
@@ -62,6 +88,7 @@ def test_agg_global(make_df, repartition_nparts, with_morsel_size):
 
     # Check set agg without nulls
     assert len(res_set) == 1
+    _assert_all_hashable(res_set[0], "test_agg_global")
     assert len(res_set[0]) == len(set(x for x in res_set[0] if x is not None)), "Result should contain no duplicates"
     assert set(x for x in res_set[0] if x is not None) == set(
         x for x in exp_set[0] if x is not None
@@ -224,6 +251,7 @@ def test_agg_groupby(make_df, repartition_nparts, with_morsel_size):
     sorted_res = [res_set[i] for i in arg_sort]
     sorted_exp = exp_set
     for res, exp in zip(sorted_res, sorted_exp):
+        _assert_all_hashable(res, "test_agg_groupby")
         assert len(res) == len(set(x for x in res if x is not None)), "Result should contain no duplicates"
         assert set(x for x in res if x is not None) == set(x for x in exp if x is not None), "Sets should match"
 
@@ -361,6 +389,7 @@ def test_all_null_groupby_keys(make_df, repartition_nparts, with_morsel_size):
 
     # Check set without nulls (should be same as with nulls since no nulls in values)
     assert len(daft_cols["set"]) == 1
+    _assert_all_hashable(daft_cols["set"][0], "test_all_null_groupby_keys")
     assert len(daft_cols["set"][0]) == 3, "Should contain all unique non-null values"
     assert set(daft_cols["set"][0]) == {1, 2, 3}, "Should contain all unique values"
 
@@ -479,6 +508,7 @@ def test_agg_groupby_with_alias(make_df, repartition_nparts, with_morsel_size):
     sorted_res = [res_set[i] for i in arg_sort]
     sorted_exp = exp_set
     for res, exp in zip(sorted_res, sorted_exp):
+        _assert_all_hashable(res, "test_agg_groupby_with_alias")
         assert len(res) == len(set(x for x in res if x is not None)), "Result should contain no duplicates"
         assert set(x for x in res if x is not None) == set(
             x for x in exp if x is not None
@@ -511,19 +541,6 @@ def test_agg_pyobjects_list():
     assert set(result["list"][0]) == set(objects)
 
 
-def test_agg_pyobjects_set():
-    objects = [CustomObject(val=0), None, CustomObject(val=1)]
-    df = daft.from_pydict({"objs": objects})
-    df = df.into_partitions(2)
-    df = df.agg(
-        [
-            col("objs").agg_set().alias("set"),
-        ]
-    )
-    with pytest.raises(DaftTypeError, match="Expected list input, got Python"):
-        df.collect()
-
-
 def test_groupby_agg_pyobjects_list():
     objects = [CustomObject(val=0), CustomObject(val=1), None, None, CustomObject(val=2)]
     df = daft.from_pydict({"objects": objects, "groups": [1, 2, 1, 2, 1]})
@@ -546,23 +563,6 @@ def test_groupby_agg_pyobjects_list():
     assert set(res["list"][1]) == set([objects[1], objects[3]])
 
 
-def test_groupby_agg_pyobjects_set():
-    objects = [CustomObject(val=0), CustomObject(val=1), None, None, CustomObject(val=2)]
-    df = daft.from_pydict({"objects": objects, "groups": [1, 2, 1, 2, 1]})
-    df = df.into_partitions(2)
-    df = (
-        df.groupby(col("groups"))
-        .agg(
-            [
-                col("objects").agg_set().alias("set"),
-            ]
-        )
-        .sort(col("groups"))
-    )
-    with pytest.raises(DaftTypeError, match="Expected list input, got Python"):
-        df.collect()
-
-
 @pytest.mark.parametrize("shuffle_aggregation_default_partitions", [None, 20])
 def test_groupby_result_partitions_smaller_than_input(shuffle_aggregation_default_partitions, with_morsel_size):
     if shuffle_aggregation_default_partitions is None:
@@ -752,3 +752,43 @@ def test_agg_with_groupby_key_in_agg(make_df, repartition_nparts, with_morsel_si
         "group_plus_1": [2, 3, 4],
         "id_plus_group": [7, 11, 15],
     }
+
+
+@pytest.mark.parametrize("repartition_nparts", [2, 3])
+def test_agg_set_duplicates_across_partitions(make_df, repartition_nparts, with_morsel_size):
+    """Test that set aggregation maintains uniqueness across partitions.
+
+    This test verifies that when duplicates appear in different partitions,
+    the set aggregation still maintains uniqueness in the final result. For example,
+    if partition 1 has [1, 1, 1] and partition 2 has [1, 2], the final result should
+    be [1, 2] and not [1, 1, 2].
+    """
+    # Create a DataFrame with duplicates that will be distributed across partitions
+    daft_df = make_df(
+        {
+            "group": [1, 1, 1, 1, 1],
+            "values": [1, 1, 1, 1, 2],  # Multiple 1s to ensure duplicates across partitions
+        },
+        repartition=repartition_nparts,
+    )
+
+    # Test both global and groupby aggregations
+    # Global aggregation
+    global_result = daft_df.agg([col("values").agg_set().alias("set")])
+    global_result.collect()
+    global_set = global_result.to_pydict()["set"][0]
+
+    # The result should be [1, 2] or [2, 1]; order doesn't matter
+    _assert_all_hashable(global_set, "test_agg_set_duplicates_across_partitions (global)")
+    assert len(global_set) == 2, f"Expected 2 unique values, got {len(global_set)} values: {global_set}"
+    assert set(global_set) == {1, 2}, f"Expected set {{1, 2}}, got set {set(global_set)}"
+
+    # Groupby aggregation
+    group_result = daft_df.groupby("group").agg([col("values").agg_set().alias("set")])
+    group_result.collect()
+    group_set = group_result.to_pydict()["set"][0]
+
+    # The result should be [1, 2] or [2, 1]; order doesn't matter
+    _assert_all_hashable(group_set, "test_agg_set_duplicates_across_partitions (group)")
+    assert len(group_set) == 2, f"Expected 2 unique values, got {len(group_set)} values: {group_set}"
+    assert set(group_set) == {1, 2}, f"Expected set {{1, 2}}, got set {set(group_set)}"
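The new test pins the cross-partition semantics; the motivation for moving dedup into the first stage is the size of each partition's partial result. A small illustrative comparison in plain Python (not Daft code):

    # Old first stage (AggExpr::List): each partition ships every value.
    partition = [1] * 1000 + [2]
    as_list = list(partition)
    # New first stage (AggExpr::Set): each partition ships unique values only.
    as_set = list(dict.fromkeys(partition))
    assert len(as_list) == 1001
    assert len(as_set) == 2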
{key}" for res, exp in zip(result[key], expected[key]): + _assert_all_hashable(res, "test_table_agg_groupby") assert set(res) == exp, f"Set mismatch in column {key}" else: assert result[key] == expected[key], f"Mismatch in column {key}" @@ -1087,6 +1090,7 @@ def test_global_set_aggs(dtype) -> None: assert result.get_column("set").datatype() == DataType.list(dtype) expected = [x for x in set(input) if x is not None] result_set = result.to_pydict()["set"][0] + _assert_all_hashable(result_set, "test_global_set_aggs") # Check length assert len(result_set) == len(expected) # Convert both to sets to ignore order @@ -1131,6 +1135,9 @@ def test_grouped_set_aggs(dtype) -> None: result_dict = result.to_pydict() assert sorted(result_dict["groups"]) == [1, 2, 3] + for i, group_set in enumerate(result_dict["set"]): + _assert_all_hashable(group_set, f"test_grouped_set_aggs (group {result_dict['groups'][i]})") + group1_set = set(result_dict["set"][0]) group2_set = set(result_dict["set"][1]) group3_set = set(result_dict["set"][2])