From 7353b2b48e471d0143a7ed9f251f30d74b6e02d9 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Fri, 6 Sep 2024 17:33:08 +0200 Subject: [PATCH] Add default values and missing value handling to `agentset.get` (#2279) --- mesa/agent.py | 55 ++++++++++++++++++++++++++++++++++----------- tests/test_agent.py | 45 +++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 13 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index f0bbcaf2430..396d3f9eb37 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -18,7 +18,7 @@ from random import Random # mypy -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Literal if TYPE_CHECKING: # We ensure that these are not imported during runtime to prevent cyclic @@ -333,29 +333,58 @@ def agg(self, attribute: str, func: Callable) -> Any: values = self.get(attribute) return func(values) - def get(self, attr_names: str | list[str]) -> list[Any]: + def get( + self, + attr_names: str | list[str], + handle_missing: Literal["error", "default"] = "error", + default_value: Any = None, + ) -> list[Any] | list[list[Any]]: """ Retrieve the specified attribute(s) from each agent in the AgentSet. Args: attr_names (str | list[str]): The name(s) of the attribute(s) to retrieve from each agent. + handle_missing (str, optional): How to handle missing attributes. Can be: + - 'error' (default): raises an AttributeError if attribute is missing. + - 'default': returns the specified default_value. + default_value (Any, optional): The default value to return if 'handle_missing' is set to 'default' + and the agent does not have the attribute. Returns: - list[Any]: A list with the attribute value for each agent in the set if attr_names is a str - list[list[Any]]: A list with a list of attribute values for each agent in the set if attr_names is a list of str + list[Any]: A list with the attribute value for each agent if attr_names is a str. + list[list[Any]]: A list with a lists of attribute values for each agent if attr_names is a list of str. Raises: - AttributeError if an agent does not have the specified attribute(s) - - """ + AttributeError: If 'handle_missing' is 'error' and the agent does not have the specified attribute(s). + ValueError: If an unknown 'handle_missing' option is provided. + """ + is_single_attr = isinstance(attr_names, str) + + if handle_missing == "error": + if is_single_attr: + return [getattr(agent, attr_names) for agent in self._agents] + else: + return [ + [getattr(agent, attr) for attr in attr_names] + for agent in self._agents + ] + + elif handle_missing == "default": + if is_single_attr: + return [ + getattr(agent, attr_names, default_value) for agent in self._agents + ] + else: + return [ + [getattr(agent, attr, default_value) for attr in attr_names] + for agent in self._agents + ] - if isinstance(attr_names, str): - return [getattr(agent, attr_names) for agent in self._agents] else: - return [ - [getattr(agent, attr_name) for attr_name in attr_names] - for agent in self._agents - ] + raise ValueError( + f"Unknown handle_missing option: {handle_missing}, " + "should be one of 'error' or 'default'" + ) def set(self, attr_name: str, value: Any) -> AgentSet: """ diff --git a/tests/test_agent.py b/tests/test_agent.py index 67ceab1a9f9..aa6a7b88cc4 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -277,6 +277,51 @@ def remove_function(agent): assert len(agentset) == 0 +def test_agentset_get(): + model = Model() + _ = [TestAgent(i, model) for i in range(10)] + + agentset = model.agents + + agentset.set("a", 5) + agentset.set("b", 6) + + # Case 1: Normal retrieval of existing attributes + values = agentset.get(["a", "b"]) + assert all((a == 5) & (b == 6) for a, b in values) + + # Case 2: Raise AttributeError when attribute doesn't exist + with pytest.raises(AttributeError): + agentset.get("unknown_attribute") + + # Case 3: Use default value when attribute is missing + results = agentset.get( + "unknown_attribute", handle_missing="default", default_value=True + ) + assert all(results) is True + + # Case 4: Retrieve mixed attributes with default value for missing ones + values = agentset.get( + ["a", "unknown_attribute"], handle_missing="default", default_value=True + ) + assert all((a == 5) & (unknown is True) for a, unknown in values) + + # Case 5: Invalid handle_missing value raises ValueError + with pytest.raises(ValueError): + agentset.get("unknown_attribute", handle_missing="some nonsense value") + + # Case 6: Retrieve multiple attributes with mixed existence and 'default' handling + values = agentset.get( + ["a", "b", "unknown_attribute"], handle_missing="default", default_value=0 + ) + assert all((a == 5) & (b == 6) & (unknown == 0) for a, b, unknown in values) + + # Case 7: 'default' handling when one attribute is completely missing from some agents + agentset.select(at_most=0.5).set("c", 8) # Only some agents have attribute 'c' + values = agentset.get(["a", "c"], handle_missing="default", default_value=-1) + assert all((a == 5) & (c in [8, -1]) for a, c in values) + + def test_agentset_agg(): model = Model() agents = [TestAgent(i, model) for i in range(10)]