Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GroupBy: Add count and agg methods #2290

Merged
merged 4 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import warnings
import weakref
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence
from collections.abc import Callable, Hashable, Iterable, Iterator, MutableSet, Sequence
from random import Random

# mypy
Expand Down Expand Up @@ -611,6 +611,29 @@ def do(self, method: Callable | str, *args, **kwargs) -> GroupBy:

return self

def count(self) -> dict[Any, int]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although it can be expressed with agg, I am in favour of having a separate method for count. It just makes code more readable and easier to write.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be dict[Hashable, int] as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, missed this one.

Aren't all dicts by design required to have Hashable as a key? Seems a bit redundant to define that every time.

Like, the int part adds information. The Hashable doesn't (over Any).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please "Unresolve conversation" if you have additional comments by the way.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Hashable adds information to the type checker (in an IDE) to constraint user code. Arguably, any dict[Any, *] should be dict[Hashable, *]. I don't have a citation on why this is better. ChatGPT (the fallback version after my quota of GPT-4o has run out) seems to agree.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that sounds about right.

A smart type checker should know you can’t use non-hashable types as dict keys.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried Mypy:

from typing import Hashable, Any

a: dict[Hashable, int] = {[1]: 1}
a: dict[Any, int] = {[1]: 1}

It caught the first one, but doesn't consider the 2nd one to be a type error.

"""Return the count of agents in each group.

Returns:
dict: A dictionary mapping group names to the number of agents in each group.
EwoutH marked this conversation as resolved.
Show resolved Hide resolved
"""
return {k: len(v) for k, v in self.groups.items()}

def agg(self, attr_name: str, func: Callable) -> dict[Hashable, Any]:
"""Aggregate the values of a specific attribute across each group using the provided function.

Args:
attr_name (str): The name of the attribute to aggregate.
func (Callable): The function to apply (e.g., sum, min, max, mean).

Returns:
dict[Hashable, Any]: A dictionary mapping group names to the result of applying the aggregation function.
"""
return {
group_name: func([getattr(agent, attr_name) for agent in group])
for group_name, group in self.groups.items()
}

def __iter__(self): # noqa: D105
return iter(self.groups.items())

Expand Down
32 changes: 32 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ class TestAgent(Agent):
def __init__(self, model):
super().__init__(model)
self.even = self.unique_id % 2 == 0
self.value = self.unique_id * 10

def get_unique_identifier(self):
return self.unique_id
Expand Down Expand Up @@ -560,6 +561,37 @@ def get_unique_identifier(self):
another_ref_to_groups = groups.do(lambda x: x.do("step"))
assert groups == another_ref_to_groups

# New tests for count() method
groups = agentset.groupby("even")
count_result = groups.count()
assert count_result == {True: 5, False: 5}

# New tests for agg() method
groups = agentset.groupby("even")
sum_result = groups.agg("value", sum)
assert sum_result[True] == sum(agent.value for agent in agents if agent.even)
assert sum_result[False] == sum(agent.value for agent in agents if not agent.even)

max_result = groups.agg("value", max)
assert max_result[True] == max(agent.value for agent in agents if agent.even)
assert max_result[False] == max(agent.value for agent in agents if not agent.even)

min_result = groups.agg("value", min)
assert min_result[True] == min(agent.value for agent in agents if agent.even)
assert min_result[False] == min(agent.value for agent in agents if not agent.even)

# Test with a custom aggregation function
def custom_agg(values):
return sum(values) / len(values) if values else 0

custom_result = groups.agg("value", custom_agg)
assert custom_result[True] == custom_agg(
[agent.value for agent in agents if agent.even]
)
assert custom_result[False] == custom_agg(
[agent.value for agent in agents if not agent.even]
)


def test_oldstyle_agent_instantiation():
"""Old behavior of Agent creation with unique_id and model as positional arguments.
Expand Down
Loading