From 3bb8dde35b19c27c8b4e9ee47607cbce48357244 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Tue, 19 Dec 2023 10:57:55 +0100 Subject: [PATCH 01/25] Introduce AgentSet class What's really nice it that you can now do things like if agent in model.agents This commit introduces the `AgentSet` class in the Mesa agent-based modeling framework, along with significant changes to the agent management process. The `AgentSet` class is designed to encapsulate and manage a collection of agents, providing methods for efficient selection, sorting, shuffling, and applying actions to groups of agents. This addition aims to enhance the framework's scalability and flexibility in handling agent operations. Key changes include: - **Agent Class Modifications**: Updated the `Agent` class to directly manage agent registration and removal within the model's `_agents` attribute. This simplification removes the need for separate registration and removal methods, maintaining the encapsulation of agent management logic within the `Agent` class itself. - **Model Class Enhancements**: Refactored the `Model` class to utilize the `AgentSet` class. The `agents` property now returns an `AgentSet` instance, representing all agents in the model. This change streamlines agent access and manipulation, aligning with the object-oriented design of the framework. - **AgentSet Functionality**: The new `AgentSet` class includes methods like `select`, `shuffle`, `sort`, and `do_each` to enable more intuitive and powerful operations on agent collections. These methods support a range of common tasks in agent-based modeling, such as filtering agents based on criteria, randomizing their order, or applying actions to each agent. This implementation significantly refactors agent management in the Mesa framework, aiming to provide a more robust and user-friendly interface for modeling complex systems with diverse agent interactions. The addition of `AgentSet` aligns with the framework's goal of facilitating efficient and effective agent-based modeling. --- mesa/agent.py | 49 +++++++++++++++++++++++++++++++++++++++++---- mesa/model.py | 15 ++++++++++++-- tests/test_agent.py | 4 ++-- tests/test_model.py | 2 +- 4 files changed, 61 insertions(+), 9 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index e29b233cffc..238e701ba98 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -10,7 +10,7 @@ from random import Random # mypy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Iterator if TYPE_CHECKING: # We ensure that these are not imported during runtime to prevent cyclic @@ -41,12 +41,15 @@ def __init__(self, unique_id: int, model: Model) -> None: self.model = model self.pos: Position | None = None - # Register the agent with the model using defaultdict - self.model.agents[type(self)][self] = None + # Directly register the agent with the model + if type(self) not in self.model._agents: + self.model._agents[type(self)] = set() + self.model._agents[type(self)].add(self) def remove(self) -> None: """Remove and delete the agent from the model.""" - self.model.agents[type(self)].pop(self) + if type(self) in self.model._agents: + self.model._agents[type(self)].discard(self) def step(self) -> None: """A single step of the agent.""" @@ -57,3 +60,41 @@ def advance(self) -> None: @property def random(self) -> Random: return self.model.random + + +class AgentSet: + def __init__(self, agents: set[Agent], model: Model): + self._agents = agents + self.model = model + + def __len__(self): + return len(self._agents) + + def __iter__(self) -> Iterator[Agent]: + return iter(self._agents) + + def __contains__(self, agent: Agent) -> bool: + """Check if an agent is in the AgentSet.""" + return agent in self._agents + + def select(self, filter_func: Callable[[Agent], bool] | None = None) -> AgentSet: + if filter_func is None: + return AgentSet(set(self._agents), self.model) + return AgentSet( + {agent for agent in self._agents if filter_func(agent)}, self.model + ) + + def shuffle(self) -> AgentSet: + shuffled_agents = list(self._agents) + self.model.random.shuffle(shuffled_agents) + return AgentSet(set(shuffled_agents), self.model) + + def sort(self, key: Callable[[Agent], Any], reverse: bool = False) -> AgentSet: + sorted_agents = sorted(self._agents, key=key, reverse=reverse) + return AgentSet(set(sorted_agents), self.model) + + def do_each(self, method_name: str): + for agent in self._agents: + getattr(agent, method_name)() + + # Additional methods like union, intersection, difference, etc., can be added as needed. diff --git a/mesa/model.py b/mesa/model.py index 96398b446f8..ddaa6aa22fb 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -13,6 +13,7 @@ # mypy from typing import Any +from mesa.agent import AgentSet from mesa.datacollection import DataCollector @@ -51,12 +52,22 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.running = True self.schedule = None self.current_id = 0 - self.agents: defaultdict[type, dict] = defaultdict(dict) + self._agents: defaultdict[type, dict] = defaultdict(dict) + + @property + def agents(self) -> AgentSet: + all_agents = set() + for agent_type in self._agents: + all_agents.update(self._agents[agent_type]) + return AgentSet(all_agents, self) @property def agent_types(self) -> list: """Return a list of different agent types.""" - return list(self.agents.keys()) + return list(self._agents.keys()) + + def select_agents(self, *args, **kwargs) -> AgentSet: + return self.agents.select(*args, **kwargs) def run_model(self) -> None: """Run the model until the end condition is reached. Overload as diff --git a/tests/test_agent.py b/tests/test_agent.py index 561208f25af..fe4c9bb6d7a 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -9,8 +9,8 @@ class TestAgent(Agent): model = Model() agent = TestAgent(model.next_id(), model) # Check if the agent is added - assert agent in model.agents[type(agent)] + assert agent in model.agents agent.remove() # Check if the agent is removed - assert agent not in model.agents[type(agent)] + assert agent not in model.agents diff --git a/tests/test_model.py b/tests/test_model.py index c54634d8352..874d45f935f 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -49,5 +49,5 @@ class TestAgent(Agent): model = Model() test_agent = TestAgent(model.next_id(), model) - assert test_agent in model.agents[type(test_agent)] + assert test_agent in model.agents assert type(test_agent) in model.agent_types From 51d293b1fc6c168a320ee5349baac7f9e3d84d1c Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Tue, 19 Dec 2023 16:31:18 +0100 Subject: [PATCH 02/25] move to WeakKeyDictionary, add inplace boolean, do_each can return results _agent is now a WeakKeyDictionary, and for arguments sake, I added the inplace keyword argument --- mesa/agent.py | 73 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 238e701ba98..36e9154525e 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -7,10 +7,11 @@ # Remove this __future__ import once the oldest supported Python is 3.10 from __future__ import annotations +import weakref from random import Random # mypy -from typing import TYPE_CHECKING, Any, Callable, Iterator +from typing import TYPE_CHECKING, Any, Callable, Iterator, Iterable if TYPE_CHECKING: # We ensure that these are not imported during runtime to prevent cyclic @@ -63,38 +64,70 @@ def random(self) -> Random: class AgentSet: - def __init__(self, agents: set[Agent], model: Model): - self._agents = agents + def __init__(self, agents: Iterable[Agent], model: Model): + self._agents = weakref.WeakKeyDictionary() + + for agent in agents: + self._agents[agent] = None + self.model = model - def __len__(self): + def __len__(self) -> int: return len(self._agents) def __iter__(self) -> Iterator[Agent]: - return iter(self._agents) + return iter(self._agents.keys()) def __contains__(self, agent: Agent) -> bool: """Check if an agent is in the AgentSet.""" return agent in self._agents - def select(self, filter_func: Callable[[Agent], bool] | None = None) -> AgentSet: + def select(self, filter_func: Callable[[Agent], bool] | None = None, inplace: bool = False) -> AgentSet: if filter_func is None: - return AgentSet(set(self._agents), self.model) - return AgentSet( - {agent for agent in self._agents if filter_func(agent)}, self.model - ) - - def shuffle(self) -> AgentSet: - shuffled_agents = list(self._agents) + if inplace: + return self + else: + return AgentSet(list(self._agents.keys()), self.model) + else: + agents = [agent for agent in self._agents.keys() if filter_func(agent)] + + if inplace: + self._reorder(agents) + return self + else: + return AgentSet( + agents, + self.model + ) + + def shuffle(self, inplace: bool = False) -> AgentSet: + shuffled_agents = list(self._agents.keys()) self.model.random.shuffle(shuffled_agents) - return AgentSet(set(shuffled_agents), self.model) - def sort(self, key: Callable[[Agent], Any], reverse: bool = False) -> AgentSet: - sorted_agents = sorted(self._agents, key=key, reverse=reverse) - return AgentSet(set(sorted_agents), self.model) + if inplace: + self._reorder(shuffled_agents) + return self + else: + return AgentSet(shuffled_agents, self.model) + + def sort(self, key: Callable[[Agent], Any], reverse: bool = False, inplace: bool = False) -> AgentSet: + sorted_agents = sorted(list(self._agents.keys()), key=key, reverse=reverse) + + if inplace: + self._reorder(sorted_agents) + return self + else: + return AgentSet(sorted_agents, self.model) + + def _reorder(self, agents: Iterable[Agent]): + _agents = weakref.WeakKeyDictionary() + for agent in agents: + _agents[agent] = None + self._agents = _agents + + def do_each(self, method_name: str, *args, **kwargs) -> list[Any]: + return [getattr(agent, method_name)(*args, **kwargs) for agent in self._agents] + - def do_each(self, method_name: str): - for agent in self._agents: - getattr(agent, method_name)() # Additional methods like union, intersection, difference, etc., can be added as needed. From e397f9b59c2ab7c8d318cb28d1a4a46cc542d3c6 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Tue, 19 Dec 2023 21:44:48 +0100 Subject: [PATCH 03/25] adds __getitem__, get_each, and code cleanup --- mesa/agent.py | 63 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 40 insertions(+), 23 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 36e9154525e..cc197efe102 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -64,11 +64,14 @@ def random(self) -> Random: class AgentSet: + """Ordered set of agents""" + def __init__(self, agents: Iterable[Agent], model: Model): self._agents = weakref.WeakKeyDictionary() - + self._indices = [] for agent in agents: self._agents[agent] = None + self._indices.append(weakref.ref(agent)) self.model = model @@ -76,31 +79,28 @@ def __len__(self) -> int: return len(self._agents) def __iter__(self) -> Iterator[Agent]: - return iter(self._agents.keys()) + return self._agents.keys() def __contains__(self, agent: Agent) -> bool: """Check if an agent is in the AgentSet.""" return agent in self._agents - def select(self, filter_func: Callable[[Agent], bool] | None = None, inplace: bool = False) -> AgentSet: + def select(self, filter_func: Callable[[Agent], bool] | None = None, inplace: bool = False): if filter_func is None: - if inplace: - return self - else: - return AgentSet(list(self._agents.keys()), self.model) + return self if inplace else AgentSet(list(self._agents.keys()), self.model) + + agents = [agent for agent in self._agents.keys() if filter_func(agent)] + + if inplace: + self._reorder(agents) + return self else: - agents = [agent for agent in self._agents.keys() if filter_func(agent)] - - if inplace: - self._reorder(agents) - return self - else: - return AgentSet( - agents, - self.model - ) - - def shuffle(self, inplace: bool = False) -> AgentSet: + return AgentSet( + agents, + self.model + ) + + def shuffle(self, inplace: bool = False): shuffled_agents = list(self._agents.keys()) self.model.random.shuffle(shuffled_agents) @@ -110,8 +110,8 @@ def shuffle(self, inplace: bool = False) -> AgentSet: else: return AgentSet(shuffled_agents, self.model) - def sort(self, key: Callable[[Agent], Any], reverse: bool = False, inplace: bool = False) -> AgentSet: - sorted_agents = sorted(list(self._agents.keys()), key=key, reverse=reverse) + def sort(self, key: Callable[[Agent], Any], reverse: bool = False, inplace: bool = False): + sorted_agents = sorted(self._agents.keys(), key=key, reverse=reverse) if inplace: self._reorder(sorted_agents) @@ -121,13 +121,30 @@ def sort(self, key: Callable[[Agent], Any], reverse: bool = False, inplace: bool def _reorder(self, agents: Iterable[Agent]): _agents = weakref.WeakKeyDictionary() + _indices = [] for agent in agents: _agents[agent] = None + _indices.append(weakref.ref(agent)) self._agents = _agents + self._indices = _indices def do_each(self, method_name: str, *args, **kwargs) -> list[Any]: + """invoke method on each agent""" return [getattr(agent, method_name)(*args, **kwargs) for agent in self._agents] + def get_each(self, attr_name: str) -> list[Any]: + """get attribute value on each agent""" + return [getattr(agent, attr_name) for agent in self._agents] + + def __getitem__(self, item: int) -> Agent: + # TODO:: + # TBD:: it is a bit tricky to make this work + # part of the problem is that there is no weakreflist + agent = self._indices[item]() + if agent is None: + # the agent has been garbage collected + return None + else: + return self._agents[agent] - - # Additional methods like union, intersection, difference, etc., can be added as needed. + # Additional methods like union, intersection, difference, etc., can be added as needed. \ No newline at end of file From 8d31c3b13502bc100b4f803cf6719fba1a3e199f Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Wed, 20 Dec 2023 11:32:58 +0100 Subject: [PATCH 04/25] removal of set updates Agent and Model to not rely on sets when creating AgentSet. also removes unnecessary check in agent if key exists. Because of defaultdict this is not needed. Likewise, deregistring is done, and error is caught (It is better to ask forgiveness than permission). some comments in indexing --- mesa/agent.py | 16 ++++++++++------ mesa/model.py | 12 +++++++----- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index cc197efe102..7374a3cf457 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -42,15 +42,16 @@ def __init__(self, unique_id: int, model: Model) -> None: self.model = model self.pos: Position | None = None - # Directly register the agent with the model - if type(self) not in self.model._agents: - self.model._agents[type(self)] = set() - self.model._agents[type(self)].add(self) + # register agent + self.model._agents[type(self)][self] = None def remove(self) -> None: """Remove and delete the agent from the model.""" - if type(self) in self.model._agents: - self.model._agents[type(self)].discard(self) + try: + # remove agent + self.model._agents[type(self)].pop(self) + except KeyError: + pass def step(self) -> None: """A single step of the agent.""" @@ -140,6 +141,9 @@ def __getitem__(self, item: int) -> Agent: # TODO:: # TBD:: it is a bit tricky to make this work # part of the problem is that there is no weakreflist + # might also be fixable through weakref.finalize + # TODO:: make slice also work + # item can be int or slice agent = self._indices[item]() if agent is None: # the agent has been garbage collected diff --git a/mesa/model.py b/mesa/model.py index ddaa6aa22fb..9d1f7bafde2 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -8,6 +8,7 @@ from __future__ import annotations import random +import itertools from collections import defaultdict # mypy @@ -56,9 +57,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: @property def agents(self) -> AgentSet: - all_agents = set() - for agent_type in self._agents: - all_agents.update(self._agents[agent_type]) + all_agents = itertools.chain(*[agents_by_type.values() for agents_by_type in self._agents.values()]) return AgentSet(all_agents, self) @property @@ -66,8 +65,11 @@ def agent_types(self) -> list: """Return a list of different agent types.""" return list(self._agents.keys()) - def select_agents(self, *args, **kwargs) -> AgentSet: - return self.agents.select(*args, **kwargs) + def get_agents_of_type(self, agenttype:type) -> AgentSet: + return AgentSet(self._agents[agenttype].values(), self) + + # def select_agents(self, *args, **kwargs) -> AgentSet: + # return self.agents.select(*args, **kwargs) def run_model(self) -> None: """Run the model until the end condition is reached. Overload as From ef763097b4fbf5d04f9fdd45b90358581de4e337 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Wed, 20 Dec 2023 17:18:44 +0100 Subject: [PATCH 05/25] AgentSet subclasses abc.MutableSet and supports indexing and slicing --- mesa/agent.py | 48 +++++++++++++++++++++++++----------------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 7374a3cf457..39c5cd1388e 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -8,6 +8,7 @@ from __future__ import annotations import weakref +from collections.abc import MutableSet from random import Random # mypy @@ -64,15 +65,13 @@ def random(self) -> Random: return self.model.random -class AgentSet: +class AgentSet(MutableSet): """Ordered set of agents""" def __init__(self, agents: Iterable[Agent], model: Model): self._agents = weakref.WeakKeyDictionary() - self._indices = [] for agent in agents: self._agents[agent] = None - self._indices.append(weakref.ref(agent)) self.model = model @@ -86,7 +85,7 @@ def __contains__(self, agent: Agent) -> bool: """Check if an agent is in the AgentSet.""" return agent in self._agents - def select(self, filter_func: Callable[[Agent], bool] | None = None, inplace: bool = False): + def select(self, filter_func: Callable[[Agent], bool] | None = None, inplace: bool = False) -> AgentSet: if filter_func is None: return self if inplace else AgentSet(list(self._agents.keys()), self.model) @@ -101,7 +100,7 @@ def select(self, filter_func: Callable[[Agent], bool] | None = None, inplace: bo self.model ) - def shuffle(self, inplace: bool = False): + def shuffle(self, inplace: bool = False)-> AgentSet: shuffled_agents = list(self._agents.keys()) self.model.random.shuffle(shuffled_agents) @@ -111,7 +110,7 @@ def shuffle(self, inplace: bool = False): else: return AgentSet(shuffled_agents, self.model) - def sort(self, key: Callable[[Agent], Any], reverse: bool = False, inplace: bool = False): + def sort(self, key: Callable[[Agent], Any], reverse: bool = False, inplace: bool = False)-> AgentSet: sorted_agents = sorted(self._agents.keys(), key=key, reverse=reverse) if inplace: @@ -122,12 +121,10 @@ def sort(self, key: Callable[[Agent], Any], reverse: bool = False, inplace: bool def _reorder(self, agents: Iterable[Agent]): _agents = weakref.WeakKeyDictionary() - _indices = [] for agent in agents: _agents[agent] = None - _indices.append(weakref.ref(agent)) + self._agents = _agents - self._indices = _indices def do_each(self, method_name: str, *args, **kwargs) -> list[Any]: """invoke method on each agent""" @@ -137,18 +134,23 @@ def get_each(self, attr_name: str) -> list[Any]: """get attribute value on each agent""" return [getattr(agent, attr_name) for agent in self._agents] - def __getitem__(self, item: int) -> Agent: - # TODO:: - # TBD:: it is a bit tricky to make this work - # part of the problem is that there is no weakreflist - # might also be fixable through weakref.finalize - # TODO:: make slice also work - # item can be int or slice - agent = self._indices[item]() - if agent is None: - # the agent has been garbage collected - return None - else: - return self._agents[agent] + def __getitem__(self, item: int | slice) -> Agent: + return list(self._agents.keys())[item] + + def add(self, agent: Agent): + # abstract method from MutableSet + self._agents[agent] = None + + def discard(self, agent: Agent): + # abstract method from MutableSet + # discard should not raise an error when + # item is not in set + try: + del self._agents[agent] + except KeyError: + pass - # Additional methods like union, intersection, difference, etc., can be added as needed. \ No newline at end of file + def remove(self, agent: Agent): + # remove should raise an error when + # item is not in set + del self._agents[agent] \ No newline at end of file From 7f7ca589b9ac9659117f6aa04a4740fa101730ef Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Wed, 20 Dec 2023 20:53:44 +0100 Subject: [PATCH 06/25] unittests for AgentSet --- mesa/model.py | 2 +- tests/test_agent.py | 54 +++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/mesa/model.py b/mesa/model.py index 9d1f7bafde2..b8449fb86d5 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -57,7 +57,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: @property def agents(self) -> AgentSet: - all_agents = itertools.chain(*[agents_by_type.values() for agents_by_type in self._agents.values()]) + all_agents = itertools.chain(*[agents_by_type.keys() for agents_by_type in self._agents.values()]) return AgentSet(all_agents, self) @property diff --git a/tests/test_agent.py b/tests/test_agent.py index fe4c9bb6d7a..c9112107a0d 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,11 +1,16 @@ -from mesa.agent import Agent +import pytest + +from mesa.agent import Agent, AgentSet from mesa.model import Model -def test_agent_removal(): - class TestAgent(Agent): - pass +class TestAgent(Agent): + + def get_unique_identifier(self): + return self.unique_id + +def test_agent_removal(): model = Model() agent = TestAgent(model.next_id(), model) # Check if the agent is added @@ -14,3 +19,44 @@ class TestAgent(Agent): agent.remove() # Check if the agent is removed assert agent not in model.agents + + +def test_agentset(): + # create agentset + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + + agentset = AgentSet(agents, model) + + assert agents[0] in agentset + assert len(agentset) == len(agents) + assert all(a1 == a2 for a1, a2 in zip(agentset[0:5], agents[0:5])) + + for a1, a2 in zip(agentset, agents): + assert a1 == a2 + + def test_function(agent): + return agent.unique_id > 5 + + assert len(agentset.select(test_function)) == 5 + assert len(agentset.select(test_function, inplace=True)) == 5 + assert agentset.select(inplace=True) == agentset + assert all(a1 == a2 for a1, a2 in zip(agentset.select(), agentset)) + + def test_function(agent): + return agent.unique_id + + assert all(a1 == a2 for a1, a2 in zip(agentset.sort(test_function, reverse=True), agentset[::-1])) + + assert all(a1 == a2.unique_id for a1, a2 in zip(agentset.get_each("unique_id"), agentset)) + assert all(a1 == a2.unique_id for a1, a2 in zip(agentset.do_each("get_unique_identifier"), agentset)) + + agentset.discard(agents[0]) + assert agents[0] not in agentset + agentset.discard(agents[0]) # check if no error is raised on discard + + with pytest.raises(KeyError): + agentset.remove(agents[0]) + + agentset.add(agents[0]) + assert agents[0] in agentset From ec36b9f5194b35bafd0774999020fc5367468200 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Wed, 20 Dec 2023 21:07:28 +0100 Subject: [PATCH 07/25] Update mesa/agent.py Co-authored-by: Corvince --- mesa/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesa/agent.py b/mesa/agent.py index 39c5cd1388e..575f8a4e611 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -101,7 +101,7 @@ def select(self, filter_func: Callable[[Agent], bool] | None = None, inplace: bo ) def shuffle(self, inplace: bool = False)-> AgentSet: - shuffled_agents = list(self._agents.keys()) + shuffled_agents = list(self) self.model.random.shuffle(shuffled_agents) if inplace: From 1531ec59e9eb81532e79521f49418fa5df54dac8 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Wed, 20 Dec 2023 21:08:52 +0100 Subject: [PATCH 08/25] Update mesa/agent.py Co-authored-by: Corvince --- mesa/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesa/agent.py b/mesa/agent.py index 575f8a4e611..11c4a48246c 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -87,7 +87,7 @@ def __contains__(self, agent: Agent) -> bool: def select(self, filter_func: Callable[[Agent], bool] | None = None, inplace: bool = False) -> AgentSet: if filter_func is None: - return self if inplace else AgentSet(list(self._agents.keys()), self.model) + return self if inplace else AgentSet(iter(self), self.model) agents = [agent for agent in self._agents.keys() if filter_func(agent)] From d34eef979cf5dc5eee83205778e52033d8984c41 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Wed, 20 Dec 2023 21:45:41 +0100 Subject: [PATCH 09/25] make AgentSet pickle-able --- mesa/agent.py | 9 ++++++++- tests/test_agent.py | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mesa/agent.py b/mesa/agent.py index 11c4a48246c..719d3ec2b75 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -153,4 +153,11 @@ def discard(self, agent: Agent): def remove(self, agent: Agent): # remove should raise an error when # item is not in set - del self._agents[agent] \ No newline at end of file + del self._agents[agent] + + def __getstate__(self): + return dict(agents=list(self._agents.keys()), model=self.model) + + def __setstate__(self, state): + self.model = model + self._reorder(agents) \ No newline at end of file diff --git a/tests/test_agent.py b/tests/test_agent.py index c9112107a0d..a9beee2f3cd 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,4 +1,5 @@ import pytest +import pickle from mesa.agent import Agent, AgentSet from mesa.model import Model @@ -60,3 +61,6 @@ def test_function(agent): agentset.add(agents[0]) assert agents[0] in agentset + + anotherset = pickle.loads(pickle.dumps(agents)) + assert all(a1.unique_id==a2.unique_id for a1, a2 in zip(anotherset, agents)) \ No newline at end of file From 1e3b4d971b71e52d797ddf3879711f955d8a6096 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Wed, 20 Dec 2023 22:27:57 +0100 Subject: [PATCH 10/25] minor change to __setstate__ for pickleability --- mesa/agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 719d3ec2b75..c9c7dea56fc 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -159,5 +159,5 @@ def __getstate__(self): return dict(agents=list(self._agents.keys()), model=self.model) def __setstate__(self, state): - self.model = model - self._reorder(agents) \ No newline at end of file + self.model = state['model'] + self._reorder(state['agents']) From 5c279b4eb62a4cc944e50b02112cb6e4d26a19a8 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Wed, 20 Dec 2023 22:46:57 +0100 Subject: [PATCH 11/25] fix for pickle test --- tests/test_agent.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_agent.py b/tests/test_agent.py index a9beee2f3cd..33c52456e08 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -62,5 +62,7 @@ def test_function(agent): agentset.add(agents[0]) assert agents[0] in agentset - anotherset = pickle.loads(pickle.dumps(agents)) - assert all(a1.unique_id==a2.unique_id for a1, a2 in zip(anotherset, agents)) \ No newline at end of file + # because AgentSet uses weakrefs, we need hard refs as well.... + other_agents, another_set = pickle.loads(pickle.dumps([agents, AgentSet(agents, model)])) + assert all(a1.unique_id==a2.unique_id for a1, a2 in zip(another_set, other_agents)) + assert len(another_set) == len(other_agents) \ No newline at end of file From c1e64f68cd1b9c8f468eb6d0f105429d61af90f2 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Thu, 21 Dec 2023 09:57:13 +0100 Subject: [PATCH 12/25] additional keyword arguments for sort, select, renaming of some other method --- mesa/agent.py | 89 +++++++++++++++++++++++++++++---------------- tests/test_agent.py | 11 +++++- 2 files changed, 66 insertions(+), 34 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index c9c7dea56fc..86e73d71076 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -7,8 +7,9 @@ # Remove this __future__ import once the oldest supported Python is 3.10 from __future__ import annotations +import operator import weakref -from collections.abc import MutableSet +from collections.abc import MutableSet, Sequence from random import Random # mypy @@ -65,7 +66,7 @@ def random(self) -> Random: return self.model.random -class AgentSet(MutableSet): +class AgentSet(MutableSet, Sequence): """Ordered set of agents""" def __init__(self, agents: Iterable[Agent], model: Model): @@ -85,53 +86,72 @@ def __contains__(self, agent: Agent) -> bool: """Check if an agent is in the AgentSet.""" return agent in self._agents - def select(self, filter_func: Callable[[Agent], bool] | None = None, inplace: bool = False) -> AgentSet: - if filter_func is None: - return self if inplace else AgentSet(iter(self), self.model) + def select(self, filter_func: Callable[[Agent], bool] | None = None, n: int = 0, inplace: bool = False) -> AgentSet: + """select agents from AgentSet + + Args: + filter_func (Callable[[Agent]]): function to filter agents. Function should return True if agent is to be + included, false otherwise + n (int, optional): number of agents to return. Defaults to 0, meaning all agents are returned + inplace (bool, optional): updates agentset inplace if True, else return new Agentset. Defaults to False. - agents = [agent for agent in self._agents.keys() if filter_func(agent)] - if inplace: - self._reorder(agents) - return self + """ + if filter_func is not None: + agents = [agent for agent in self._agents.keys() if filter_func(agent)] else: - return AgentSet( - agents, - self.model - ) + agents = list(self._agents.keys()) + + if n: + agents = agents[:n] + + return AgentSet(agents, self.model) if not inplace else self._update(agents) - def shuffle(self, inplace: bool = False)-> AgentSet: + def shuffle(self, inplace: bool = False) -> AgentSet: shuffled_agents = list(self) self.model.random.shuffle(shuffled_agents) - if inplace: - self._reorder(shuffled_agents) - return self - else: - return AgentSet(shuffled_agents, self.model) + return AgentSet(shuffled_agents, self.model) if not inplace else self._update(shuffled_agents) + + def sort(self, key: Callable[[Agent], Any]|str, reverse: bool = False, inplace: bool = False) -> AgentSet: + if isinstance(key, str): + key = operator.attrgetter(key) - def sort(self, key: Callable[[Agent], Any], reverse: bool = False, inplace: bool = False)-> AgentSet: sorted_agents = sorted(self._agents.keys(), key=key, reverse=reverse) - if inplace: - self._reorder(sorted_agents) - return self - else: - return AgentSet(sorted_agents, self.model) + return AgentSet(sorted_agents, self.model) if not inplace else self._update(sorted_agents) - def _reorder(self, agents: Iterable[Agent]): + def _update(self, agents: Iterable[Agent]): _agents = weakref.WeakKeyDictionary() for agent in agents: _agents[agent] = None self._agents = _agents + return self + + def do(self, method_name: str, *args, return_results: bool = False, **kwargs) -> AgentSet | list[Any]: + """invoke method on each agent + + Args: + method_name (str): name of the method to call on each agent + return_results (bool): whether to return the result from the method call or + return the AgentSet itself. Defaults to False, so you can + continue method chaining + + Additional arguments and keyword arguments will be passed to the method being called + + """ + res = [getattr(agent, method_name)(*args, **kwargs) for agent in self._agents] + + return res if return_results else self - def do_each(self, method_name: str, *args, **kwargs) -> list[Any]: - """invoke method on each agent""" - return [getattr(agent, method_name)(*args, **kwargs) for agent in self._agents] + def get(self, attr_name: str) -> list[Any]: + """get attribute value on each agent - def get_each(self, attr_name: str) -> list[Any]: - """get attribute value on each agent""" + Args: + attr_name (str): name of the attribute to get from eahc agent in the set + + """ return [getattr(agent, attr_name) for agent in self._agents] def __getitem__(self, item: int | slice) -> Agent: @@ -160,4 +180,9 @@ def __getstate__(self): def __setstate__(self, state): self.model = state['model'] - self._reorder(state['agents']) + self._update(state['agents']) + + +# consider adding for performance reasons +# for Sequence: __reversed__, index, and count +# for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__ \ No newline at end of file diff --git a/tests/test_agent.py b/tests/test_agent.py index 33c52456e08..3f2cab29bca 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -40,17 +40,24 @@ def test_function(agent): return agent.unique_id > 5 assert len(agentset.select(test_function)) == 5 + assert len(agentset.select(test_function, n=2)) == 2 assert len(agentset.select(test_function, inplace=True)) == 5 assert agentset.select(inplace=True) == agentset assert all(a1 == a2 for a1, a2 in zip(agentset.select(), agentset)) + assert all(a1 == a2 for a1, a2 in zip(agentset.select(n=5), agentset[:5])) + + assert len(agentset.shuffle().select(n=5)) == 5 def test_function(agent): return agent.unique_id assert all(a1 == a2 for a1, a2 in zip(agentset.sort(test_function, reverse=True), agentset[::-1])) + assert all(a1 == a2 for a1, a2 in zip(agentset.sort("unique_id", reverse=True), agentset[::-1])) + + assert all(a1 == a2.unique_id for a1, a2 in zip(agentset.get("unique_id"), agentset)) + assert all(a1 == a2.unique_id for a1, a2 in zip(agentset.do("get_unique_identifier", return_results=True), agentset)) + assert agentset == agentset.do("get_unique_identifier") - assert all(a1 == a2.unique_id for a1, a2 in zip(agentset.get_each("unique_id"), agentset)) - assert all(a1 == a2.unique_id for a1, a2 in zip(agentset.do_each("get_unique_identifier"), agentset)) agentset.discard(agents[0]) assert agents[0] not in agentset From d26c483f64eabb828a66e3b778fd895e5b3c8294 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Thu, 21 Dec 2023 14:18:02 +0100 Subject: [PATCH 13/25] mimic how Agents handles random --- mesa/agent.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mesa/agent.py b/mesa/agent.py index 86e73d71076..7d10937a050 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -109,7 +109,7 @@ def select(self, filter_func: Callable[[Agent], bool] | None = None, n: int = 0, def shuffle(self, inplace: bool = False) -> AgentSet: shuffled_agents = list(self) - self.model.random.shuffle(shuffled_agents) + self.random.shuffle(shuffled_agents) return AgentSet(shuffled_agents, self.model) if not inplace else self._update(shuffled_agents) @@ -182,6 +182,10 @@ def __setstate__(self, state): self.model = state['model'] self._update(state['agents']) + @property + def random(self) -> Random: + return self.model.random + # consider adding for performance reasons # for Sequence: __reversed__, index, and count From 2436e5b2cdc803484153beaa399dca2d08174187 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Thu, 21 Dec 2023 14:27:10 +0100 Subject: [PATCH 14/25] fix for typo in docstring --- mesa/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesa/agent.py b/mesa/agent.py index 7d10937a050..297dd9e071d 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -149,7 +149,7 @@ def get(self, attr_name: str) -> list[Any]: """get attribute value on each agent Args: - attr_name (str): name of the attribute to get from eahc agent in the set + attr_name (str): name of the attribute to get from each agent in the set """ return [getattr(agent, attr_name) for agent in self._agents] From 503e05acc5d8cc4f656d4296f3e4e1696fafadb3 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 21 Dec 2023 15:30:39 +0100 Subject: [PATCH 15/25] Black formatting --- mesa/agent.py | 38 +++++++++++++++++++++++++++++--------- mesa/model.py | 6 ++++-- tests/test_agent.py | 35 +++++++++++++++++++++++++---------- 3 files changed, 58 insertions(+), 21 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 297dd9e071d..ad3dc919bbf 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -86,7 +86,12 @@ def __contains__(self, agent: Agent) -> bool: """Check if an agent is in the AgentSet.""" return agent in self._agents - def select(self, filter_func: Callable[[Agent], bool] | None = None, n: int = 0, inplace: bool = False) -> AgentSet: + def select( + self, + filter_func: Callable[[Agent], bool] | None = None, + n: int = 0, + inplace: bool = False, + ) -> AgentSet: """select agents from AgentSet Args: @@ -111,15 +116,28 @@ def shuffle(self, inplace: bool = False) -> AgentSet: shuffled_agents = list(self) self.random.shuffle(shuffled_agents) - return AgentSet(shuffled_agents, self.model) if not inplace else self._update(shuffled_agents) - - def sort(self, key: Callable[[Agent], Any]|str, reverse: bool = False, inplace: bool = False) -> AgentSet: + return ( + AgentSet(shuffled_agents, self.model) + if not inplace + else self._update(shuffled_agents) + ) + + def sort( + self, + key: Callable[[Agent], Any] | str, + reverse: bool = False, + inplace: bool = False, + ) -> AgentSet: if isinstance(key, str): key = operator.attrgetter(key) sorted_agents = sorted(self._agents.keys(), key=key, reverse=reverse) - return AgentSet(sorted_agents, self.model) if not inplace else self._update(sorted_agents) + return ( + AgentSet(sorted_agents, self.model) + if not inplace + else self._update(sorted_agents) + ) def _update(self, agents: Iterable[Agent]): _agents = weakref.WeakKeyDictionary() @@ -129,7 +147,9 @@ def _update(self, agents: Iterable[Agent]): self._agents = _agents return self - def do(self, method_name: str, *args, return_results: bool = False, **kwargs) -> AgentSet | list[Any]: + def do( + self, method_name: str, *args, return_results: bool = False, **kwargs + ) -> AgentSet | list[Any]: """invoke method on each agent Args: @@ -179,8 +199,8 @@ def __getstate__(self): return dict(agents=list(self._agents.keys()), model=self.model) def __setstate__(self, state): - self.model = state['model'] - self._update(state['agents']) + self.model = state["model"] + self._update(state["agents"]) @property def random(self) -> Random: @@ -189,4 +209,4 @@ def random(self) -> Random: # consider adding for performance reasons # for Sequence: __reversed__, index, and count -# for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__ \ No newline at end of file +# for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__ diff --git a/mesa/model.py b/mesa/model.py index b8449fb86d5..3356643c05a 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -57,7 +57,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: @property def agents(self) -> AgentSet: - all_agents = itertools.chain(*[agents_by_type.keys() for agents_by_type in self._agents.values()]) + all_agents = itertools.chain( + *[agents_by_type.keys() for agents_by_type in self._agents.values()] + ) return AgentSet(all_agents, self) @property @@ -65,7 +67,7 @@ def agent_types(self) -> list: """Return a list of different agent types.""" return list(self._agents.keys()) - def get_agents_of_type(self, agenttype:type) -> AgentSet: + def get_agents_of_type(self, agenttype: type) -> AgentSet: return AgentSet(self._agents[agenttype].values(), self) # def select_agents(self, *args, **kwargs) -> AgentSet: diff --git a/tests/test_agent.py b/tests/test_agent.py index 3f2cab29bca..97efb04fc85 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -6,7 +6,6 @@ class TestAgent(Agent): - def get_unique_identifier(self): return self.unique_id @@ -51,14 +50,26 @@ def test_function(agent): def test_function(agent): return agent.unique_id - assert all(a1 == a2 for a1, a2 in zip(agentset.sort(test_function, reverse=True), agentset[::-1])) - assert all(a1 == a2 for a1, a2 in zip(agentset.sort("unique_id", reverse=True), agentset[::-1])) - - assert all(a1 == a2.unique_id for a1, a2 in zip(agentset.get("unique_id"), agentset)) - assert all(a1 == a2.unique_id for a1, a2 in zip(agentset.do("get_unique_identifier", return_results=True), agentset)) + assert all( + a1 == a2 + for a1, a2 in zip(agentset.sort(test_function, reverse=True), agentset[::-1]) + ) + assert all( + a1 == a2 + for a1, a2 in zip(agentset.sort("unique_id", reverse=True), agentset[::-1]) + ) + + assert all( + a1 == a2.unique_id for a1, a2 in zip(agentset.get("unique_id"), agentset) + ) + assert all( + a1 == a2.unique_id + for a1, a2 in zip( + agentset.do("get_unique_identifier", return_results=True), agentset + ) + ) assert agentset == agentset.do("get_unique_identifier") - agentset.discard(agents[0]) assert agents[0] not in agentset agentset.discard(agents[0]) # check if no error is raised on discard @@ -70,6 +81,10 @@ def test_function(agent): assert agents[0] in agentset # because AgentSet uses weakrefs, we need hard refs as well.... - other_agents, another_set = pickle.loads(pickle.dumps([agents, AgentSet(agents, model)])) - assert all(a1.unique_id==a2.unique_id for a1, a2 in zip(another_set, other_agents)) - assert len(another_set) == len(other_agents) \ No newline at end of file + other_agents, another_set = pickle.loads( + pickle.dumps([agents, AgentSet(agents, model)]) + ) + assert all( + a1.unique_id == a2.unique_id for a1, a2 in zip(another_set, other_agents) + ) + assert len(another_set) == len(other_agents) From 3c42e50f05e03adf7711225e816eac0cd631e9a7 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 21 Dec 2023 15:31:43 +0100 Subject: [PATCH 16/25] Ruff fixes --- mesa/agent.py | 11 +++++------ mesa/model.py | 2 +- tests/test_agent.py | 3 ++- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index ad3dc919bbf..23eddaf01bd 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -7,13 +7,14 @@ # Remove this __future__ import once the oldest supported Python is 3.10 from __future__ import annotations +import contextlib import operator import weakref from collections.abc import MutableSet, Sequence from random import Random # mypy -from typing import TYPE_CHECKING, Any, Callable, Iterator, Iterable +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator if TYPE_CHECKING: # We ensure that these are not imported during runtime to prevent cyclic @@ -103,7 +104,7 @@ def select( """ if filter_func is not None: - agents = [agent for agent in self._agents.keys() if filter_func(agent)] + agents = [agent for agent in self._agents if filter_func(agent)] else: agents = list(self._agents.keys()) @@ -185,10 +186,8 @@ def discard(self, agent: Agent): # abstract method from MutableSet # discard should not raise an error when # item is not in set - try: + with contextlib.suppress(KeyError): del self._agents[agent] - except KeyError: - pass def remove(self, agent: Agent): # remove should raise an error when @@ -196,7 +195,7 @@ def remove(self, agent: Agent): del self._agents[agent] def __getstate__(self): - return dict(agents=list(self._agents.keys()), model=self.model) + return {"agents": list(self._agents.keys()), "model": self.model} def __setstate__(self, state): self.model = state["model"] diff --git a/mesa/model.py b/mesa/model.py index 3356643c05a..5bd7c44a2a8 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -7,8 +7,8 @@ # Remove this __future__ import once the oldest supported Python is 3.10 from __future__ import annotations -import random import itertools +import random from collections import defaultdict # mypy diff --git a/tests/test_agent.py b/tests/test_agent.py index 97efb04fc85..88edf26f829 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,6 +1,7 @@ -import pytest import pickle +import pytest + from mesa.agent import Agent, AgentSet from mesa.model import Model From a513d74daab00047810edde7122b804f28ffe525 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 21 Dec 2023 15:40:37 +0100 Subject: [PATCH 17/25] Fix last ruff errors -Use with contextlib.suppress(KeyError) - Supress pickling warning in tests --- mesa/agent.py | 5 +---- tests/test_agent.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 23eddaf01bd..8354b6bcea0 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -50,11 +50,8 @@ def __init__(self, unique_id: int, model: Model) -> None: def remove(self) -> None: """Remove and delete the agent from the model.""" - try: - # remove agent + with contextlib.suppress(KeyError): self.model._agents[type(self)].pop(self) - except KeyError: - pass def step(self) -> None: """A single step of the agent.""" diff --git a/tests/test_agent.py b/tests/test_agent.py index 88edf26f829..58c8edc6210 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -82,7 +82,7 @@ def test_function(agent): assert agents[0] in agentset # because AgentSet uses weakrefs, we need hard refs as well.... - other_agents, another_set = pickle.loads( + other_agents, another_set = pickle.loads( # noqa: S301 pickle.dumps([agents, AgentSet(agents, model)]) ) assert all( From f298b7d86f75ca87e6cda0c7bd83f3c478df9ad8 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 21 Dec 2023 15:50:28 +0100 Subject: [PATCH 18/25] Model: Update docstring --- mesa/model.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/mesa/model.py b/mesa/model.py index 5bd7c44a2a8..d1e89aeac4e 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -29,8 +29,20 @@ class Model: running: A boolean indicating if the model should continue running. schedule: An object to manage the order and execution of agent steps. current_id: A counter for assigning unique IDs to agents. - agents: A defaultdict mapping each agent type to a dict of its instances. - Agent instances are saved in the nested dict keys, with the values being None. + _agents: A defaultdict mapping each agent type to a dict of its instances. + This private attribute is used internally to manage agents. + + Properties: + agents: An AgentSet containing all agents in the model, generated from the _agents attribute. + agent_types: A list of different agent types present in the model. + + Methods: + get_agents_of_type: Returns an AgentSet of agents of the specified type. + run_model: Runs the model's simulation until a defined end condition is reached. + step: Executes a single step of the model's simulation process. + next_id: Generates and returns the next unique identifier for an agent. + reset_randomizer: Resets the model's random number generator with a new or existing seed. + initialize_data_collector: Sets up the data collector for the model, requiring an initialized scheduler and agents. """ def __new__(cls, *args: Any, **kwargs: Any) -> Any: @@ -57,6 +69,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: @property def agents(self) -> AgentSet: + """Provides an AgentSet of all agents in the model, combining agents from all types.""" all_agents = itertools.chain( *[agents_by_type.keys() for agents_by_type in self._agents.values()] ) @@ -68,11 +81,9 @@ def agent_types(self) -> list: return list(self._agents.keys()) def get_agents_of_type(self, agenttype: type) -> AgentSet: + """Retrieves an AgentSet containing all agents of the specified type.""" return AgentSet(self._agents[agenttype].values(), self) - # def select_agents(self, *args, **kwargs) -> AgentSet: - # return self.agents.select(*args, **kwargs) - def run_model(self) -> None: """Run the model until the end condition is reached. Overload as needed. From 5c770a6481805b39de0f46859fa8e7fdfe2f5ce2 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 21 Dec 2023 16:01:28 +0100 Subject: [PATCH 19/25] AgentSet: Add docstring --- mesa/agent.py | 153 ++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 130 insertions(+), 23 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 8354b6bcea0..f3a68c63d0b 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -65,9 +65,33 @@ def random(self) -> Random: class AgentSet(MutableSet, Sequence): - """Ordered set of agents""" + """ + A collection class that represents an ordered set of agents within an agent-based model (ABM). This class + extends both MutableSet and Sequence, providing set-like functionality with order preservation and + sequence operations. + + Attributes: + model (Model): The ABM model instance to which this AgentSet belongs. + + Methods: + __len__, __iter__, __contains__, select, shuffle, sort, _update, do, get, __getitem__, + add, discard, remove, __getstate__, __setstate__, random + + Note: + The AgentSet maintains weak references to agents, allowing for efficient management of agent lifecycles + without preventing garbage collection. It is associated with a specific model instance, enabling + interactions with the model's environment and other agents.The implementation uses a WeakKeyDictionary to store agents, + which means that agents not referenced elsewhere in the program may be automatically removed from the AgentSet. + """ def __init__(self, agents: Iterable[Agent], model: Model): + """ + Initializes the AgentSet with a collection of agents and a reference to the model. + + Args: + agents (Iterable[Agent]): An iterable of Agent objects to be included in the set. + model (Model): The ABM model instance to which this AgentSet belongs. + """ self._agents = weakref.WeakKeyDictionary() for agent in agents: self._agents[agent] = None @@ -75,13 +99,15 @@ def __init__(self, agents: Iterable[Agent], model: Model): self.model = model def __len__(self) -> int: + """Return the number of agents in the AgentSet.""" return len(self._agents) def __iter__(self) -> Iterator[Agent]: + """Provide an iterator over the agents in the AgentSet.""" return self._agents.keys() def __contains__(self, agent: Agent) -> bool: - """Check if an agent is in the AgentSet.""" + """Check if an agent is in the AgentSet. Can be used like `agent in agentset`.""" return agent in self._agents def select( @@ -90,15 +116,17 @@ def select( n: int = 0, inplace: bool = False, ) -> AgentSet: - """select agents from AgentSet + """ + Select a subset of agents from the AgentSet based on a filter function and/or quantity limit. Args: - filter_func (Callable[[Agent]]): function to filter agents. Function should return True if agent is to be - included, false otherwise - n (int, optional): number of agents to return. Defaults to 0, meaning all agents are returned - inplace (bool, optional): updates agentset inplace if True, else return new Agentset. Defaults to False. - + filter_func (Callable[[Agent], bool], optional): A function that takes an Agent and returns True if the + agent should be included in the result. Defaults to None, meaning no filtering is applied. + n (int, optional): The number of agents to select. If 0, all matching agents are selected. Defaults to 0. + inplace (bool, optional): If True, modifies the current AgentSet; otherwise, returns a new AgentSet. Defaults to False. + Returns: + AgentSet: A new AgentSet containing the selected agents, unless inplace is True, in which case the current AgentSet is updated. """ if filter_func is not None: agents = [agent for agent in self._agents if filter_func(agent)] @@ -111,6 +139,15 @@ def select( return AgentSet(agents, self.model) if not inplace else self._update(agents) def shuffle(self, inplace: bool = False) -> AgentSet: + """ + Randomly shuffle the order of agents in the AgentSet. + + Args: + inplace (bool, optional): If True, shuffles the agents in the current AgentSet; otherwise, returns a new shuffled AgentSet. Defaults to False. + + Returns: + AgentSet: A shuffled AgentSet. Returns the current AgentSet if inplace is True. + """ shuffled_agents = list(self) self.random.shuffle(shuffled_agents) @@ -126,6 +163,17 @@ def sort( reverse: bool = False, inplace: bool = False, ) -> AgentSet: + """ + Sort the agents in the AgentSet based on a specified attribute or custom function. + + Args: + key (Callable[[Agent], Any] | str): A function or attribute name based on which the agents are sorted. + reverse (bool, optional): If True, the agents are sorted in descending order. Defaults to False. + inplace (bool, optional): If True, sorts the agents in the current AgentSet; otherwise, returns a new sorted AgentSet. Defaults to False. + + Returns: + AgentSet: A sorted AgentSet. Returns the current AgentSet if inplace is True. + """ if isinstance(key, str): key = operator.attrgetter(key) @@ -138,6 +186,9 @@ def sort( ) def _update(self, agents: Iterable[Agent]): + """Update the AgentSet with a new set of agents. + This is a private method primarily used internally by other methods like select, shuffle, and sort. + """ _agents = weakref.WeakKeyDictionary() for agent in agents: _agents[agent] = None @@ -148,58 +199,114 @@ def _update(self, agents: Iterable[Agent]): def do( self, method_name: str, *args, return_results: bool = False, **kwargs ) -> AgentSet | list[Any]: - """invoke method on each agent + """ + Invoke a method on each agent in the AgentSet. Args: - method_name (str): name of the method to call on each agent - return_results (bool): whether to return the result from the method call or - return the AgentSet itself. Defaults to False, so you can - continue method chaining - - Additional arguments and keyword arguments will be passed to the method being called + method_name (str): The name of the method to call on each agent. + return_results (bool, optional): If True, returns the results of the method calls; otherwise, returns the AgentSet itself. Defaults to False, so you can chain method calls. + *args: Variable length argument list passed to the method being called. + **kwargs: Arbitrary keyword arguments passed to the method being called. + Returns: + AgentSet | list[Any]: The results of the method calls if return_results is True, otherwise the AgentSet itself. """ res = [getattr(agent, method_name)(*args, **kwargs) for agent in self._agents] return res if return_results else self def get(self, attr_name: str) -> list[Any]: - """get attribute value on each agent + """ + Retrieve a specified attribute from each agent in the AgentSet. Args: - attr_name (str): name of the attribute to get from each agent in the set + attr_name (str): The name of the attribute to retrieve from each agent. + Returns: + list[Any]: A list of attribute values from each agent in the set. """ return [getattr(agent, attr_name) for agent in self._agents] def __getitem__(self, item: int | slice) -> Agent: + """ + Retrieve an agent or a slice of agents from the AgentSet. + + Args: + item (int | slice): The index or slice for selecting agents. + + Returns: + Agent | list[Agent]: The selected agent or list of agents based on the index or slice provided. + """ return list(self._agents.keys())[item] def add(self, agent: Agent): - # abstract method from MutableSet + """ + Add an agent to the AgentSet. + + Args: + agent (Agent): The agent to add to the set. + + Note: + This method is an implementation of the abstract method from MutableSet. + """ self._agents[agent] = None def discard(self, agent: Agent): - # abstract method from MutableSet - # discard should not raise an error when - # item is not in set + """ + Remove an agent from the AgentSet if it exists. + + This method does not raise an error if the agent is not present. + + Args: + agent (Agent): The agent to remove from the set. + + Note: + This method is an implementation of the abstract method from MutableSet. + """ with contextlib.suppress(KeyError): del self._agents[agent] def remove(self, agent: Agent): - # remove should raise an error when - # item is not in set + """ + Remove an agent from the AgentSet. + + This method raises an error if the agent is not present. + + Args: + agent (Agent): The agent to remove from the set. + + Note: + This method is an implementation of the abstract method from MutableSet. + """ del self._agents[agent] def __getstate__(self): + """ + Retrieve the state of the AgentSet for serialization. + + Returns: + dict: A dictionary representing the state of the AgentSet. + """ return {"agents": list(self._agents.keys()), "model": self.model} def __setstate__(self, state): + """ + Set the state of the AgentSet during deserialization. + + Args: + state (dict): A dictionary representing the state to restore. + """ self.model = state["model"] self._update(state["agents"]) @property def random(self) -> Random: + """ + Provide access to the model's random number generator. + + Returns: + Random: The random number generator associated with the model. + """ return self.model.random From 175e2c45cbbcbe12098fe942f696a05939311bf2 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 21 Dec 2023 16:22:55 +0100 Subject: [PATCH 20/25] tests: Add more tests for AgentSet Probably have to refactor the AgentSet tests to have a proper setup class etc. --- tests/test_agent.py | 86 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/tests/test_agent.py b/tests/test_agent.py index 58c8edc6210..d32ae93ba54 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -89,3 +89,89 @@ def test_function(agent): a1.unique_id == a2.unique_id for a1, a2 in zip(another_set, other_agents) ) assert len(another_set) == len(other_agents) + + +def test_agentset_initialization(): + model = Model() + empty_agentset = AgentSet([], model) + assert len(empty_agentset) == 0 + + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + agentset = AgentSet(agents, model) + assert len(agentset) == 10 + + +def test_agentset_serialization(): + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(5)] + agentset = AgentSet(agents, model) + + serialized = pickle.dumps(agentset) + deserialized = pickle.loads(serialized) # noqa: S301 + + original_ids = [agent.unique_id for agent in agents] + deserialized_ids = [agent.unique_id for agent in deserialized] + + assert deserialized_ids == original_ids + + +def test_agent_membership(): + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(5)] + agentset = AgentSet(agents, model) + + assert agents[0] in agentset + assert TestAgent(model.next_id(), model) not in agentset + + +def test_agent_add_remove_discard(): + model = Model() + agent = TestAgent(model.next_id(), model) + agentset = AgentSet([], model) + + agentset.add(agent) + assert agent in agentset + + agentset.remove(agent) + assert agent not in agentset + + agentset.add(agent) + agentset.discard(agent) + assert agent not in agentset + + with pytest.raises(KeyError): + agentset.remove(agent) + + +def test_agentset_get_item(): + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + agentset = AgentSet(agents, model) + + assert agentset[0] == agents[0] + assert agentset[-1] == agents[-1] + assert agentset[1:3] == agents[1:3] + + with pytest.raises(IndexError): + _ = agentset[20] + + +def test_agentset_do_method(): + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + agentset = AgentSet(agents, model) + + with pytest.raises(AttributeError): + agentset.do("non_existing_method") + + +def test_agentset_get_attribute(): + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + agentset = AgentSet(agents, model) + + unique_ids = agentset.get("unique_id") + assert unique_ids == [agent.unique_id for agent in agents] + + with pytest.raises(AttributeError): + agentset.get("non_existing_attribute") From daf8a23abb01da74e30aac7e31539edc3176a075 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 21 Dec 2023 17:10:02 +0100 Subject: [PATCH 21/25] AgentSet: Add agent_type argument to select() method Allow selection by type. Can be a superclass of multiple types. --- mesa/agent.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index f3a68c63d0b..cd1142f45e2 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -115,6 +115,7 @@ def select( filter_func: Callable[[Agent], bool] | None = None, n: int = 0, inplace: bool = False, + agent_type: type[Agent] | None = None, ) -> AgentSet: """ Select a subset of agents from the AgentSet based on a filter function and/or quantity limit. @@ -124,14 +125,18 @@ def select( agent should be included in the result. Defaults to None, meaning no filtering is applied. n (int, optional): The number of agents to select. If 0, all matching agents are selected. Defaults to 0. inplace (bool, optional): If True, modifies the current AgentSet; otherwise, returns a new AgentSet. Defaults to False. + agent_type (type[Agent], optional): The class type of the agents to select. Defaults to None, meaning no type filtering is applied. Returns: AgentSet: A new AgentSet containing the selected agents, unless inplace is True, in which case the current AgentSet is updated. """ + agents = list(self._agents.keys()) + if filter_func is not None: - agents = [agent for agent in self._agents if filter_func(agent)] - else: - agents = list(self._agents.keys()) + agents = [agent for agent in agents if filter_func(agent)] + + if agent_type is not None: + agents = [agent for agent in agents if isinstance(agent, agent_type)] if n: agents = agents[:n] From a26e7c6c8e6e8159fc71fb0f0da0b9607abddaa4 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 21 Dec 2023 17:14:19 +0100 Subject: [PATCH 22/25] AgentSet: Rename reverse to ascending in sort --- mesa/agent.py | 6 +++--- tests/test_agent.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index cd1142f45e2..def80510db8 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -165,7 +165,7 @@ def shuffle(self, inplace: bool = False) -> AgentSet: def sort( self, key: Callable[[Agent], Any] | str, - reverse: bool = False, + ascending: bool = False, inplace: bool = False, ) -> AgentSet: """ @@ -173,7 +173,7 @@ def sort( Args: key (Callable[[Agent], Any] | str): A function or attribute name based on which the agents are sorted. - reverse (bool, optional): If True, the agents are sorted in descending order. Defaults to False. + ascending (bool, optional): If True, the agents are sorted in ascending order. Defaults to False. inplace (bool, optional): If True, sorts the agents in the current AgentSet; otherwise, returns a new sorted AgentSet. Defaults to False. Returns: @@ -182,7 +182,7 @@ def sort( if isinstance(key, str): key = operator.attrgetter(key) - sorted_agents = sorted(self._agents.keys(), key=key, reverse=reverse) + sorted_agents = sorted(self._agents.keys(), key=key, reverse=not ascending) return ( AgentSet(sorted_agents, self.model) diff --git a/tests/test_agent.py b/tests/test_agent.py index d32ae93ba54..7f083decf8a 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -53,11 +53,11 @@ def test_function(agent): assert all( a1 == a2 - for a1, a2 in zip(agentset.sort(test_function, reverse=True), agentset[::-1]) + for a1, a2 in zip(agentset.sort(test_function, ascending=False), agentset[::-1]) ) assert all( a1 == a2 - for a1, a2 in zip(agentset.sort("unique_id", reverse=True), agentset[::-1]) + for a1, a2 in zip(agentset.sort("unique_id", ascending=False), agentset[::-1]) ) assert all( From 8d2a9d2ff2d88a19b52bd210a520c3b520a5f2c0 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 21 Dec 2023 17:22:19 +0100 Subject: [PATCH 23/25] Add tests for selecting by type --- tests/test_agent.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/test_agent.py b/tests/test_agent.py index 7f083decf8a..3cc8f9adc4e 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -175,3 +175,34 @@ def test_agentset_get_attribute(): with pytest.raises(AttributeError): agentset.get("non_existing_attribute") + + +class OtherAgentType(Agent): + def get_unique_identifier(self): + return self.unique_id + + +def test_agentset_select_by_type(): + model = Model() + # Create a mix of agents of two different types + test_agents = [TestAgent(model.next_id(), model) for _ in range(4)] + other_agents = [OtherAgentType(model.next_id(), model) for _ in range(6)] + + # Combine the two types of agents + mixed_agents = test_agents + other_agents + agentset = AgentSet(mixed_agents, model) + + # Test selection by type + selected_test_agents = agentset.select(agent_type=TestAgent) + assert len(selected_test_agents) == len(test_agents) + assert all(isinstance(agent, TestAgent) for agent in selected_test_agents) + assert len(selected_test_agents) == 4 + + selected_other_agents = agentset.select(agent_type=OtherAgentType) + assert len(selected_other_agents) == len(other_agents) + assert all(isinstance(agent, OtherAgentType) for agent in selected_other_agents) + assert len(selected_other_agents) == 6 + + # Test with no type specified (should select all agents) + all_agents = agentset.select() + assert len(all_agents) == len(mixed_agents) From 78880077e08864847e08680c7d433ebe75101b5b Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 21 Dec 2023 17:38:25 +0100 Subject: [PATCH 24/25] Add experimental warning for AgentSet --- mesa/agent.py | 20 ++++++++++++++++++-- mesa/model.py | 3 +++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index def80510db8..dbe4132f9ac 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -9,6 +9,7 @@ import contextlib import operator +import warnings import weakref from collections.abc import MutableSet, Sequence from random import Random @@ -66,6 +67,12 @@ def random(self) -> Random: class AgentSet(MutableSet, Sequence): """ + .. warning:: + The AgentSet is experimental. It may be changed or removed in any and all future releases, including + patch releases. + We would love to hear what you think about this new feature. If you have any thoughts, share them with + us here: https://github.com/projectmesa/mesa/discussions/1919 + A collection class that represents an ordered set of agents within an agent-based model (ABM). This class extends both MutableSet and Sequence, providing set-like functionality with order preservation and sequence operations. @@ -92,12 +99,21 @@ def __init__(self, agents: Iterable[Agent], model: Model): agents (Iterable[Agent]): An iterable of Agent objects to be included in the set. model (Model): The ABM model instance to which this AgentSet belongs. """ + self.model = model + + if not self.model.agentset_experimental_warning_given: + self.model.agentset_experimental_warning_given = True + warnings.warn( + "The AgentSet is experimental. It may be changed or removed in any and all future releases, including patch releases.\n" + "We would love to hear what you think about this new feature. If you have any thoughts, share them with us here: https://github.com/projectmesa/mesa/discussions/1919", + FutureWarning, + stacklevel=2, + ) + self._agents = weakref.WeakKeyDictionary() for agent in agents: self._agents[agent] = None - self.model = model - def __len__(self) -> int: """Return the number of agents in the AgentSet.""" return len(self._agents) diff --git a/mesa/model.py b/mesa/model.py index d1e89aeac4e..121732d2bbe 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -67,6 +67,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.current_id = 0 self._agents: defaultdict[type, dict] = defaultdict(dict) + # Warning flags for current experimental features. These make sure a warning is only printed once per model. + self.agentset_experimental_warning_given = False + @property def agents(self) -> AgentSet: """Provides an AgentSet of all agents in the model, combining agents from all types.""" From 22b99cdaeb26d9d1382a38c3e88462ff83ef4e89 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Fri, 22 Dec 2023 09:02:07 +0100 Subject: [PATCH 25/25] requested fixes --- mesa/agent.py | 23 ++++++++++++++--------- mesa/model.py | 4 ++-- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index dbe4132f9ac..0bb2f21875a 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -146,16 +146,21 @@ def select( Returns: AgentSet: A new AgentSet containing the selected agents, unless inplace is True, in which case the current AgentSet is updated. """ - agents = list(self._agents.keys()) - if filter_func is not None: - agents = [agent for agent in agents if filter_func(agent)] - - if agent_type is not None: - agents = [agent for agent in agents if isinstance(agent, agent_type)] - - if n: - agents = agents[:n] + def agent_generator(): + count = 0 + for agent in self: + if filter_func and not filter_func(agent): + continue + if agent_type and not isinstance(agent, agent_type): + continue + yield agent + count += 1 + # default of n is zero, zo evaluates to False + if n and count >= n: + break + + agents = agent_generator() return AgentSet(agents, self.model) if not inplace else self._update(agents) diff --git a/mesa/model.py b/mesa/model.py index 121732d2bbe..22e654483dd 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -74,12 +74,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def agents(self) -> AgentSet: """Provides an AgentSet of all agents in the model, combining agents from all types.""" all_agents = itertools.chain( - *[agents_by_type.keys() for agents_by_type in self._agents.values()] + *(agents_by_type.keys() for agents_by_type in self._agents.values()) ) return AgentSet(all_agents, self) @property - def agent_types(self) -> list: + def agent_types(self) -> list[type]: """Return a list of different agent types.""" return list(self._agents.keys())