From fe28bc41a563837804efefe78bc5b99fbb0f4cf6 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Fri, 6 Sep 2024 12:47:17 +0200 Subject: [PATCH 1/3] Add AgentSet.shuffle_do() --- mesa/agent.py | 87 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 74 insertions(+), 13 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 225e89cda79..c386c0706f2 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -20,7 +20,7 @@ from random import Random # mypy -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal if TYPE_CHECKING: # We ensure that these are not imported during runtime to prevent cyclic @@ -302,6 +302,19 @@ def do(self, method: str | Callable, *args, **kwargs) -> AgentSet: return self + def shuffle_do(self, method: str | Callable, *args, **kwargs) -> AgentSet: + agents = list(self._agents.keys()) + self.random.shuffle(agents) + + if isinstance(method, str): + for agent in agents: + getattr(agent, method)(*args, **kwargs) + else: + for agent in agents: + method(agent, *args, **kwargs) + + return self + def map(self, method: str | Callable, *args, **kwargs) -> list[Any]: """ Invoke a method or function on each agent in the AgentSet and return the results. @@ -348,29 +361,77 @@ def agg(self, attribute: str, func: Callable) -> Any: values = self.get(attribute) return func(values) - def get(self, attr_names: str | list[str]) -> list[Any]: + def get( + self, + attr_names: str | list[str], + handle_missing: Literal["error", "skip", "default"] = "error", + default_value: Any = None, + ) -> list[Any] | list[list[Any]]: """ Retrieve the specified attribute(s) from each agent in the AgentSet. Args: attr_names (str | list[str]): The name(s) of the attribute(s) to retrieve from each agent. + handle_missing (str, optional): How to handle missing attributes. Can be: + - 'error' (default): raises an AttributeError if attribute is missing. + - 'skip': skips the agents missing the attribute. + - 'default': returns the specified default_value. + default_value (Any, optional): The default value to return if 'handle_missing' is set to 'default' + and the agent does not have the attribute. Returns: - list[Any]: A list with the attribute value for each agent in the set if attr_names is a str - list[list[Any]]: A list with a list of attribute values for each agent in the set if attr_names is a list of str + list[Any]: A list of attribute values for each agent if attr_names is a str. + list[list[Any]]: A list of lists of attribute values for each agent if attr_names is a list of str. Raises: - AttributeError if an agent does not have the specified attribute(s) - - """ + AttributeError: If 'handle_missing' is 'error' and the agent does not have the specified attribute(s). + ValueError: If an unknown 'handle_missing' option is provided. + """ + is_single_attr = isinstance(attr_names, str) + + if handle_missing == "error": + if is_single_attr: + return [getattr(agent, attr_names) for agent in self._agents] + else: + return [ + [getattr(agent, attr) for attr in attr_names] + for agent in self._agents + ] + + elif handle_missing == "default": + if is_single_attr: + return [ + getattr(agent, attr_names, default_value) for agent in self._agents + ] + else: + return [ + [getattr(agent, attr, default_value) for attr in attr_names] + for agent in self._agents + ] + + elif handle_missing == "skip": + if is_single_attr: + return [ + getattr(agent, attr_names) + for agent in self._agents + if hasattr(agent, attr_names) + ] + else: + return [ + [ + getattr(agent, attr) + for attr in attr_names + if hasattr(agent, attr) + ] + for agent in self._agents + if any(hasattr(agent, attr) for attr in attr_names) + ] - if isinstance(attr_names, str): - return [getattr(agent, attr_names) for agent in self._agents] else: - return [ - [getattr(agent, attr_name) for attr_name in attr_names] - for agent in self._agents - ] + raise ValueError( + f"Unknown handle_missing option: {handle_missing}, " + "should be one of 'error', 'skip', or 'default'" + ) def set(self, attr_name: str, value: Any) -> AgentSet: """ From 168fd454089d26eb6919f574fe43962eb6fe4fa9 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Fri, 6 Sep 2024 13:18:20 +0200 Subject: [PATCH 2/3] AgentSet: Refactor shuffle (overhaul) --- mesa/agent.py | 56 +++++++++++++-------------------------------------- 1 file changed, 14 insertions(+), 42 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index c386c0706f2..1bd1f4ffeea 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -127,7 +127,8 @@ def __init__(self, agents: Iterable[Agent], model: Model): """ self.model = model - self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents}) + self._agents_list: list[weakref.ref] = [weakref.ref(agent) for agent in agents] + self._agents_set: set[int] = set(id(agent) for agent in agents) def __len__(self) -> int: """Return the number of agents in the AgentSet.""" @@ -213,16 +214,14 @@ def shuffle(self, inplace: bool = False) -> AgentSet: Using inplace = True is more performant """ - weakrefs = list(self._agents.keyrefs()) - self.random.shuffle(weakrefs) - if inplace: - self._agents.data = {entry: None for entry in weakrefs} + self.random.shuffle(self._agents_list) return self else: - return AgentSet( - (agent for ref in weakrefs if (agent := ref()) is not None), self.model - ) + new_agentset = AgentSet([], self.model) + new_agentset._agents_list = self.random.sample(self._agents_list, len(self._agents_list)) + new_agentset._agents_set = self._agents_set.copy() + return new_agentset def sort( self, @@ -276,43 +275,14 @@ def do(self, method: str | Callable, *args, **kwargs) -> AgentSet: Returns: AgentSet | list[Any]: The results of the callable calls if return_results is True, otherwise the AgentSet itself. """ - try: - return_results = kwargs.pop("return_results") - except KeyError: - return_results = False - else: - warnings.warn( - "Using return_results is deprecated. Use AgenSet.do in case of return_results=False, and " - "AgentSet.map in case of return_results=True", - stacklevel=2, - ) - - if return_results: - return self.map(method, *args, **kwargs) - - # we iterate over the actual weakref keys and check if weakref is alive before calling the method if isinstance(method, str): - for agentref in self._agents.keyrefs(): - if (agent := agentref()) is not None: + for ref in self._agents_list: + if (agent := ref()) is not None: getattr(agent, method)(*args, **kwargs) else: - for agentref in self._agents.keyrefs(): - if (agent := agentref()) is not None: + for ref in self._agents_list: + if (agent := ref()) is not None: method(agent, *args, **kwargs) - - return self - - def shuffle_do(self, method: str | Callable, *args, **kwargs) -> AgentSet: - agents = list(self._agents.keys()) - self.random.shuffle(agents) - - if isinstance(method, str): - for agent in agents: - getattr(agent, method)(*args, **kwargs) - else: - for agent in agents: - method(agent, *args, **kwargs) - return self def map(self, method: str | Callable, *args, **kwargs) -> list[Any]: @@ -470,7 +440,9 @@ def add(self, agent: Agent): Note: This method is an implementation of the abstract method from MutableSet. """ - self._agents[agent] = None + if id(agent) not in self._agents_set: + self._agents_list.append(weakref.ref(agent)) + self._agents_set.add(id(agent)) def discard(self, agent: Agent): """ From 68c50b5dbaad674a976b0aa7181d8b10e763a430 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Sep 2024 11:29:17 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa/agent.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 1bd1f4ffeea..9cc57864962 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -128,7 +128,7 @@ def __init__(self, agents: Iterable[Agent], model: Model): self.model = model self._agents_list: list[weakref.ref] = [weakref.ref(agent) for agent in agents] - self._agents_set: set[int] = set(id(agent) for agent in agents) + self._agents_set: set[int] = {id(agent) for agent in agents} def __len__(self) -> int: """Return the number of agents in the AgentSet.""" @@ -219,7 +219,9 @@ def shuffle(self, inplace: bool = False) -> AgentSet: return self else: new_agentset = AgentSet([], self.model) - new_agentset._agents_list = self.random.sample(self._agents_list, len(self._agents_list)) + new_agentset._agents_list = self.random.sample( + self._agents_list, len(self._agents_list) + ) new_agentset._agents_set = self._agents_set.copy() return new_agentset