diff --git a/mesa/agent.py b/mesa/agent.py index 5badc825b14..2d098c3549f 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -15,7 +15,7 @@ import warnings import weakref from collections import defaultdict -from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence +from collections.abc import Callable, Hashable, Iterable, Iterator, MutableSet, Sequence from random import Random # mypy @@ -611,6 +611,29 @@ def do(self, method: Callable | str, *args, **kwargs) -> GroupBy: return self + def count(self) -> dict[Any, int]: + """Return the count of agents in each group. + + Returns: + dict: A dictionary mapping group names to the number of agents in each group. + """ + return {k: len(v) for k, v in self.groups.items()} + + def agg(self, attr_name: str, func: Callable) -> dict[Hashable, Any]: + """Aggregate the values of a specific attribute across each group using the provided function. + + Args: + attr_name (str): The name of the attribute to aggregate. + func (Callable): The function to apply (e.g., sum, min, max, mean). + + Returns: + dict[Hashable, Any]: A dictionary mapping group names to the result of applying the aggregation function. + """ + return { + group_name: func([getattr(agent, attr_name) for agent in group]) + for group_name, group in self.groups.items() + } + def __iter__(self): # noqa: D105 return iter(self.groups.items()) diff --git a/tests/test_agent.py b/tests/test_agent.py index 769be4ec5fc..f43a80c84ec 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -524,6 +524,7 @@ class TestAgent(Agent): def __init__(self, model): super().__init__(model) self.even = self.unique_id % 2 == 0 + self.value = self.unique_id * 10 def get_unique_identifier(self): return self.unique_id @@ -560,6 +561,37 @@ def get_unique_identifier(self): another_ref_to_groups = groups.do(lambda x: x.do("step")) assert groups == another_ref_to_groups + # New tests for count() method + groups = agentset.groupby("even") + count_result = groups.count() + assert count_result == {True: 5, False: 5} + + # New tests for agg() method + groups = agentset.groupby("even") + sum_result = groups.agg("value", sum) + assert sum_result[True] == sum(agent.value for agent in agents if agent.even) + assert sum_result[False] == sum(agent.value for agent in agents if not agent.even) + + max_result = groups.agg("value", max) + assert max_result[True] == max(agent.value for agent in agents if agent.even) + assert max_result[False] == max(agent.value for agent in agents if not agent.even) + + min_result = groups.agg("value", min) + assert min_result[True] == min(agent.value for agent in agents if agent.even) + assert min_result[False] == min(agent.value for agent in agents if not agent.even) + + # Test with a custom aggregation function + def custom_agg(values): + return sum(values) / len(values) if values else 0 + + custom_result = groups.agg("value", custom_agg) + assert custom_result[True] == custom_agg( + [agent.value for agent in agents if agent.even] + ) + assert custom_result[False] == custom_agg( + [agent.value for agent in agents if not agent.even] + ) + def test_oldstyle_agent_instantiation(): """Old behavior of Agent creation with unique_id and model as positional arguments.