fix(agg): use set aggregation in first stage instead of list
f4t4nt committed Feb 5, 2025
1 parent 5771eb7 commit 01d0df1
Showing 4 changed files with 93 additions and 41 deletions.
2 changes: 1 addition & 1 deletion src/daft-physical-plan/src/physical_planner/translate.rs
@@ -1100,7 +1100,7 @@ pub fn populate_aggregation_stages(
}
AggExpr::Set(e) => {
let list_agg_id =
add_to_stage(AggExpr::List, e.clone(), schema, &mut first_stage_aggs);
add_to_stage(AggExpr::Set, e.clone(), schema, &mut first_stage_aggs);
let list_concat_id = add_to_stage(
AggExpr::Concat,
col(list_agg_id.clone()),
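
The hunk above switches the first-stage aggregation for AggExpr::Set from List to Set, so each partition's partial result is already deduplicated before the per-partition results are concatenated in the later stage. A minimal Python sketch of that two-stage idea (illustrative only; this is not the Daft planner, and the helper names are hypothetical):

# Illustrative sketch, not part of this commit: dedupe within each partition
# first, then merge the per-partition results and dedupe again globally.
def first_stage_set(partition):
    # Roughly what a first-stage Set aggregation achieves for one partition.
    return list(dict.fromkeys(partition))  # dedupe, preserving first-seen order

def second_stage_merge(partials):
    # Concatenate the per-partition results, then dedupe across partitions.
    merged = []
    for part in partials:
        merged.extend(part)
    return list(dict.fromkeys(merged))

partitions = [[1, 1, 1], [1, 2]]
assert second_stage_merge(first_stage_set(p) for p in partitions) == [1, 2]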
23 changes: 14 additions & 9 deletions tests/cookbook/test_aggregations.py
@@ -10,6 +10,7 @@
from daft.expressions import col
from daft.udf import udf
from tests.conftest import assert_df_equals
from tests.dataframe.test_aggregations import _assert_all_hashable


def test_sum(daft_df, service_requests_csv_pd_df, repartition_nparts, with_morsel_size):
@@ -88,15 +89,19 @@ def test_list(daft_df, service_requests_csv_pd_df, repartition_nparts, with_morsel_size):
assert set(result_list[0]) == set(unique_key_list)


def test_set(daft_df, service_requests_csv_pd_df, repartition_nparts, with_morsel_size):
    """Set agg a column for entire table to get unique values."""
    daft_df = daft_df.repartition(repartition_nparts).agg_set(col("Unique Key").alias("unique_key_set")).collect()
    unique_key_list = service_requests_csv_pd_df["Unique Key"].drop_duplicates().to_list()

    result_list = daft_df.to_pydict()["unique_key_set"]
    assert len(result_list) == 1
    assert len(result_list[0]) == len(set(result_list[0])), "Result should contain no duplicates"
    assert set(result_list[0]) == set(unique_key_list), "Sets should contain same elements"
@pytest.mark.parametrize("npartitions", [1, 2])
def test_set(make_df, npartitions, with_morsel_size):
    df = make_df(
        {
            "values": [1, 2, 2, 3, 3, 3],
        },
        repartition=npartitions,
    )
    df = df.agg([col("values").agg_set().alias("set")])
    df.collect()
    result = df.to_pydict()["set"][0]
    _assert_all_hashable(result, "test_set")
    assert set(result) == {1, 2, 3}


def test_global_agg(daft_df, service_requests_csv_pd_df, repartition_nparts, with_morsel_size):
102 changes: 71 additions & 31 deletions tests/dataframe/test_aggregations.py
@@ -11,11 +11,37 @@
from daft.context import get_context
from daft.datatype import DataType
from daft.errors import ExpressionTypeError
from daft.exceptions import DaftTypeError
from daft.utils import freeze
from tests.utils import sort_arrow_table


def _assert_all_hashable(values, test_name=""):
    """Helper function to check if all elements in an iterable are hashable.

    Args:
        values: Iterable of values to check
        test_name: Name of the test for better error messages

    Raises:
        AssertionError: If any elements are not hashable, with a descriptive message showing all unhashable values
    """
    unhashable = []
    for val in values:
        if val is not None:  # Skip None values as they are always hashable
            try:
                hash(val)
            except TypeError:
                unhashable.append((type(val), val))

    if unhashable:
        details = "\n".join(f" - {t.__name__}: {v}" for t, v in unhashable)
        raise AssertionError(
            f"{test_name}: Found {len(unhashable)} unhashable value(s) in values: {values}\n"
            f"Unhashable values:\n{details}\n"
            "Set aggregation requires all elements to be hashable."
        )
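
For reference, a minimal usage sketch of the helper above (illustrative only, not part of the commit): hashable values pass silently, while an unhashable element such as a list raises the descriptive AssertionError.

# Illustrative usage of _assert_all_hashable; not part of this commit.
_assert_all_hashable([1, "a", None], "example")  # passes: ints, strings, and None are hashable
try:
    _assert_all_hashable([1, [2, 3]], "example")  # a list element is unhashable
except AssertionError as err:
    print(err)  # the message names the unhashable value and its type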


@pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
def test_agg_global(make_df, repartition_nparts, with_morsel_size):
daft_df = make_df(
@@ -62,6 +88,7 @@ def test_agg_global(make_df, repartition_nparts, with_morsel_size):

# Check set agg without nulls
assert len(res_set) == 1
_assert_all_hashable(res_set[0], "test_agg_global")
assert len(res_set[0]) == len(set(x for x in res_set[0] if x is not None)), "Result should contain no duplicates"
assert set(x for x in res_set[0] if x is not None) == set(
x for x in exp_set[0] if x is not None
@@ -224,6 +251,7 @@ def test_agg_groupby(make_df, repartition_nparts, with_morsel_size):
sorted_res = [res_set[i] for i in arg_sort]
sorted_exp = exp_set
for res, exp in zip(sorted_res, sorted_exp):
_assert_all_hashable(res, "test_agg_groupby")
assert len(res) == len(set(x for x in res if x is not None)), "Result should contain no duplicates"
assert set(x for x in res if x is not None) == set(x for x in exp if x is not None), "Sets should match"

@@ -361,6 +389,7 @@ def test_all_null_groupby_keys(make_df, repartition_nparts, with_morsel_size):

# Check set without nulls (should be same as with nulls since no nulls in values)
assert len(daft_cols["set"]) == 1
_assert_all_hashable(daft_cols["set"][0], "test_all_null_groupby_keys")
assert len(daft_cols["set"][0]) == 3, "Should contain all unique non-null values"
assert set(daft_cols["set"][0]) == {1, 2, 3}, "Should contain all unique values"

@@ -479,6 +508,7 @@ def test_agg_groupby_with_alias(make_df, repartition_nparts, with_morsel_size):
sorted_res = [res_set[i] for i in arg_sort]
sorted_exp = exp_set
for res, exp in zip(sorted_res, sorted_exp):
_assert_all_hashable(res, "test_agg_groupby_with_alias")
assert len(res) == len(set(x for x in res if x is not None)), "Result should contain no duplicates"
assert set(x for x in res if x is not None) == set(
x for x in exp if x is not None
@@ -511,19 +541,6 @@ def test_agg_pyobjects_list():
assert set(result["list"][0]) == set(objects)


def test_agg_pyobjects_set():
    objects = [CustomObject(val=0), None, CustomObject(val=1)]
    df = daft.from_pydict({"objs": objects})
    df = df.into_partitions(2)
    df = df.agg(
        [
            col("objs").agg_set().alias("set"),
        ]
    )
    with pytest.raises(DaftTypeError, match="Expected list input, got Python"):
        df.collect()


def test_groupby_agg_pyobjects_list():
objects = [CustomObject(val=0), CustomObject(val=1), None, None, CustomObject(val=2)]
df = daft.from_pydict({"objects": objects, "groups": [1, 2, 1, 2, 1]})
@@ -547,23 +564,6 @@ def test_groupby_agg_pyobjects_list():
assert set(res["list"][1]) == set([objects[1], objects[3]])


def test_groupby_agg_pyobjects_set():
    objects = [CustomObject(val=0), CustomObject(val=1), None, None, CustomObject(val=2)]
    df = daft.from_pydict({"objects": objects, "groups": [1, 2, 1, 2, 1]})
    df = df.into_partitions(2)
    df = (
        df.groupby(col("groups"))
        .agg(
            [
                col("objects").agg_set().alias("set"),
            ]
        )
        .sort(col("groups"))
    )
    with pytest.raises(DaftTypeError, match="Expected list input, got Python"):
        df.collect()


@pytest.mark.parametrize("shuffle_aggregation_default_partitions", [None, 20])
def test_groupby_result_partitions_smaller_than_input(shuffle_aggregation_default_partitions, with_morsel_size):
if shuffle_aggregation_default_partitions is None:
@@ -752,3 +752,43 @@ def test_agg_with_groupby_key_in_agg(make_df, repartition_nparts, with_morsel_size):
"group_plus_1": [2, 3, 4],
"id_plus_group": [7, 11, 15],
}


@pytest.mark.parametrize("repartition_nparts", [2, 3])
def test_agg_set_duplicates_across_partitions(make_df, repartition_nparts, with_morsel_size):
    """Test that set aggregation correctly maintains uniqueness across partitions.

    This test verifies that when we have duplicates across different partitions,
    the set aggregation still maintains uniqueness in the final result. For example,
    if partition 1 has [1, 1, 1] and partition 2 has [1, 2], the final result should
    be [1, 2] and not [1, 1, 2].
    """
    # Create a DataFrame with duplicates that will be distributed across partitions
    daft_df = make_df(
        {
            "group": [1, 1, 1, 1, 1],
            "values": [1, 1, 1, 1, 2],  # Multiple 1s to ensure duplicates across partitions
        },
        repartition=repartition_nparts,
    )

    # Test both global and groupby aggregations
    # Global aggregation
    global_result = daft_df.agg([col("values").agg_set().alias("set")])
    global_result.collect()
    global_set = global_result.to_pydict()["set"][0]

    # The result should be [1, 2] or [2, 1], order doesn't matter
    _assert_all_hashable(global_set, "test_agg_set_duplicates_across_partitions (global)")
    assert len(global_set) == 2, f"Expected 2 unique values, got {len(global_set)} values: {global_set}"
    assert set(global_set) == {1, 2}, f"Expected set {{1, 2}}, got set {set(global_set)}"

    # Groupby aggregation
    group_result = daft_df.groupby("group").agg([col("values").agg_set().alias("set")])
    group_result.collect()
    group_set = group_result.to_pydict()["set"][0]

    # The result should be [1, 2] or [2, 1], order doesn't matter
    _assert_all_hashable(group_set, "test_agg_set_duplicates_across_partitions (group)")
    assert len(group_set) == 2, f"Expected 2 unique values, got {len(group_set)} values: {group_set}"
    assert set(group_set) == {1, 2}, f"Expected set {{1, 2}}, got set {set(group_set)}"
7 changes: 7 additions & 0 deletions tests/table/test_table_aggs.py
@@ -11,6 +11,7 @@
from daft.logical.schema import Schema
from daft.series import Series
from daft.table import MicroPartition
from tests.dataframe.test_aggregations import _assert_all_hashable
from tests.table import (
daft_comparable_types,
daft_floating_types,
@@ -411,6 +412,7 @@ def test_table_agg_global(case) -> None:

# Check set without nulls
assert len(res_set) == 1
_assert_all_hashable(res_set[0], "test_table_agg_global")
assert len(res_set[0]) == len(set(x for x in res_set[0] if x is not None)), "Result should contain no duplicates"
assert set(x for x in res_set[0] if x is not None) == set(
x for x in exp_set[0] if x is not None
@@ -515,6 +517,7 @@ def test_table_agg_groupby(case) -> None:
# Compare set columns by converting to sets
assert len(result[key]) == len(expected[key]), f"Length mismatch in column {key}"
for res, exp in zip(result[key], expected[key]):
_assert_all_hashable(res, "test_table_agg_groupby")
assert set(res) == exp, f"Set mismatch in column {key}"
else:
assert result[key] == expected[key], f"Mismatch in column {key}"
@@ -1087,6 +1090,7 @@ def test_global_set_aggs(dtype) -> None:
assert result.get_column("set").datatype() == DataType.list(dtype)
expected = [x for x in set(input) if x is not None]
result_set = result.to_pydict()["set"][0]
_assert_all_hashable(result_set, "test_global_set_aggs")
# Check length
assert len(result_set) == len(expected)
# Convert both to sets to ignore order
@@ -1131,6 +1135,9 @@ def test_grouped_set_aggs(dtype) -> None:
result_dict = result.to_pydict()
assert sorted(result_dict["groups"]) == [1, 2, 3]

for i, group_set in enumerate(result_dict["set"]):
_assert_all_hashable(group_set, f"test_grouped_set_aggs (group {result_dict['groups'][i]})")

group1_set = set(result_dict["set"][0])
group2_set = set(result_dict["set"][1])
group3_set = set(result_dict["set"][2])