From 1fd62547fdd5bb728730216a6ca44650f5843373 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Tue, 10 Sep 2024 09:51:50 +0200 Subject: [PATCH 1/4] GroupBy: Add count and agg methods Added two new methods to the `GroupBy` class to enhance aggregation and group operations: - `count`: Returns the count of agents in each group. - `agg`: Performs aggregation on a specific attribute across groups, applying a function like `sum`, `min`, `max`, or `mean`. These methods improve flexibility in applying both group-level and attribute-specific operations. --- mesa/agent.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/mesa/agent.py b/mesa/agent.py index 5badc825b14..93008e8a35c 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -611,6 +611,31 @@ def do(self, method: Callable | str, *args, **kwargs) -> GroupBy: return self + def count(self) -> dict[Any, int]: + """ + Return the count of agents in each group. + + Returns: + dict: A dictionary mapping group names to the number of agents in each group. + """ + return {k: len(v) for k, v in self.groups.items()} + + def agg(self, attr_name: str, func: Callable) -> dict[Any, Any]: + """ + Aggregate the values of a specific attribute across each group using the provided function. + + Args: + attr_name (str): The name of the attribute to aggregate. + func (Callable): The function to apply (e.g., sum, min, max, mean). + + Returns: + dict: A dictionary mapping group names to the result of applying the aggregation function. + """ + return { + group_name: func([getattr(agent, attr_name) for agent in group]) + for group_name, group in self.groups.items() + } + def __iter__(self): # noqa: D105 return iter(self.groups.items()) From d77c5b62d2c4e820330d1f421211c19e75e2bc9f Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 19 Sep 2024 19:26:08 +0200 Subject: [PATCH 2/4] Add tests for new GroupBy methods: count() and agg() - Extend test_agentset_groupby() function - Include tests for count() method to verify group sizes - Add tests for agg() method with sum, max, min, and custom functions - Ensure proper handling of grouped data and attribute aggregation --- tests/test_agent.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_agent.py b/tests/test_agent.py index 769be4ec5fc..f43a80c84ec 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -524,6 +524,7 @@ class TestAgent(Agent): def __init__(self, model): super().__init__(model) self.even = self.unique_id % 2 == 0 + self.value = self.unique_id * 10 def get_unique_identifier(self): return self.unique_id @@ -560,6 +561,37 @@ def get_unique_identifier(self): another_ref_to_groups = groups.do(lambda x: x.do("step")) assert groups == another_ref_to_groups + # New tests for count() method + groups = agentset.groupby("even") + count_result = groups.count() + assert count_result == {True: 5, False: 5} + + # New tests for agg() method + groups = agentset.groupby("even") + sum_result = groups.agg("value", sum) + assert sum_result[True] == sum(agent.value for agent in agents if agent.even) + assert sum_result[False] == sum(agent.value for agent in agents if not agent.even) + + max_result = groups.agg("value", max) + assert max_result[True] == max(agent.value for agent in agents if agent.even) + assert max_result[False] == max(agent.value for agent in agents if not agent.even) + + min_result = groups.agg("value", min) + assert min_result[True] == min(agent.value for agent in agents if agent.even) + assert min_result[False] == min(agent.value for agent in agents if not agent.even) + + # Test with a custom aggregation function + def custom_agg(values): + return sum(values) / len(values) if values else 0 + + custom_result = groups.agg("value", custom_agg) + assert custom_result[True] == custom_agg( + [agent.value for agent in agents if agent.even] + ) + assert custom_result[False] == custom_agg( + [agent.value for agent in agents if not agent.even] + ) + def test_oldstyle_agent_instantiation(): """Old behavior of Agent creation with unique_id and model as positional arguments. From 817b532f8fc997ee3d06ac4d86e1291afacbf08d Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 19 Sep 2024 19:30:23 +0200 Subject: [PATCH 3/4] GroupBy.agg: Return dict[Hashable, Any] Change return type annotation of GroupBy.agg() from dict[Any, Any] to dict[Hashable, Any] --- mesa/agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 93008e8a35c..fa34f9dd901 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -15,7 +15,7 @@ import warnings import weakref from collections import defaultdict -from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence +from collections.abc import Callable, Hashable, Iterable, Iterator, MutableSet, Sequence from random import Random # mypy @@ -620,7 +620,7 @@ def count(self) -> dict[Any, int]: """ return {k: len(v) for k, v in self.groups.items()} - def agg(self, attr_name: str, func: Callable) -> dict[Any, Any]: + def agg(self, attr_name: str, func: Callable) -> dict[Hashable, Any]: """ Aggregate the values of a specific attribute across each group using the provided function. @@ -629,7 +629,7 @@ def agg(self, attr_name: str, func: Callable) -> dict[Any, Any]: func (Callable): The function to apply (e.g., sum, min, max, mean). Returns: - dict: A dictionary mapping group names to the result of applying the aggregation function. + dict[Hashable, Any]: A dictionary mapping group names to the result of applying the aggregation function. """ return { group_name: func([getattr(agent, attr_name) for agent in group]) From 2caa41c51e1ebc876bc45fe70705a937c4a99de2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Sep 2024 17:35:01 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa/agent.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index fa34f9dd901..2d098c3549f 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -612,8 +612,7 @@ def do(self, method: Callable | str, *args, **kwargs) -> GroupBy: return self def count(self) -> dict[Any, int]: - """ - Return the count of agents in each group. + """Return the count of agents in each group. Returns: dict: A dictionary mapping group names to the number of agents in each group. @@ -621,8 +620,7 @@ def count(self) -> dict[Any, int]: return {k: len(v) for k, v in self.groups.items()} def agg(self, attr_name: str, func: Callable) -> dict[Hashable, Any]: - """ - Aggregate the values of a specific attribute across each group using the provided function. + """Aggregate the values of a specific attribute across each group using the provided function. Args: attr_name (str): The name of the attribute to aggregate.