Skip to content

Commit

Permalink
Add default values and missing value handling to agentset.get (proj…
Browse files Browse the repository at this point in the history
  • Loading branch information
quaquel authored and EwoutH committed Sep 24, 2024
1 parent 00acfc5 commit 7353b2b
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 13 deletions.
55 changes: 42 additions & 13 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from random import Random

# mypy
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Literal

if TYPE_CHECKING:
# We ensure that these are not imported during runtime to prevent cyclic
Expand Down Expand Up @@ -333,29 +333,58 @@ def agg(self, attribute: str, func: Callable) -> Any:
values = self.get(attribute)
return func(values)

def get(self, attr_names: str | list[str]) -> list[Any]:
def get(
self,
attr_names: str | list[str],
handle_missing: Literal["error", "default"] = "error",
default_value: Any = None,
) -> list[Any] | list[list[Any]]:
"""
Retrieve the specified attribute(s) from each agent in the AgentSet.
Args:
attr_names (str | list[str]): The name(s) of the attribute(s) to retrieve from each agent.
handle_missing (str, optional): How to handle missing attributes. Can be:
- 'error' (default): raises an AttributeError if attribute is missing.
- 'default': returns the specified default_value.
default_value (Any, optional): The default value to return if 'handle_missing' is set to 'default'
and the agent does not have the attribute.
Returns:
list[Any]: A list with the attribute value for each agent in the set if attr_names is a str
list[list[Any]]: A list with a list of attribute values for each agent in the set if attr_names is a list of str
list[Any]: A list with the attribute value for each agent if attr_names is a str.
list[list[Any]]: A list with a lists of attribute values for each agent if attr_names is a list of str.
Raises:
AttributeError if an agent does not have the specified attribute(s)
"""
AttributeError: If 'handle_missing' is 'error' and the agent does not have the specified attribute(s).
ValueError: If an unknown 'handle_missing' option is provided.
"""
is_single_attr = isinstance(attr_names, str)

if handle_missing == "error":
if is_single_attr:
return [getattr(agent, attr_names) for agent in self._agents]
else:
return [
[getattr(agent, attr) for attr in attr_names]
for agent in self._agents
]

elif handle_missing == "default":
if is_single_attr:
return [
getattr(agent, attr_names, default_value) for agent in self._agents
]
else:
return [
[getattr(agent, attr, default_value) for attr in attr_names]
for agent in self._agents
]

if isinstance(attr_names, str):
return [getattr(agent, attr_names) for agent in self._agents]
else:
return [
[getattr(agent, attr_name) for attr_name in attr_names]
for agent in self._agents
]
raise ValueError(
f"Unknown handle_missing option: {handle_missing}, "
"should be one of 'error' or 'default'"
)

def set(self, attr_name: str, value: Any) -> AgentSet:
"""
Expand Down
45 changes: 45 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,51 @@ def remove_function(agent):
assert len(agentset) == 0


def test_agentset_get():
model = Model()
_ = [TestAgent(i, model) for i in range(10)]

agentset = model.agents

agentset.set("a", 5)
agentset.set("b", 6)

# Case 1: Normal retrieval of existing attributes
values = agentset.get(["a", "b"])
assert all((a == 5) & (b == 6) for a, b in values)

# Case 2: Raise AttributeError when attribute doesn't exist
with pytest.raises(AttributeError):
agentset.get("unknown_attribute")

# Case 3: Use default value when attribute is missing
results = agentset.get(
"unknown_attribute", handle_missing="default", default_value=True
)
assert all(results) is True

# Case 4: Retrieve mixed attributes with default value for missing ones
values = agentset.get(
["a", "unknown_attribute"], handle_missing="default", default_value=True
)
assert all((a == 5) & (unknown is True) for a, unknown in values)

# Case 5: Invalid handle_missing value raises ValueError
with pytest.raises(ValueError):
agentset.get("unknown_attribute", handle_missing="some nonsense value")

# Case 6: Retrieve multiple attributes with mixed existence and 'default' handling
values = agentset.get(
["a", "b", "unknown_attribute"], handle_missing="default", default_value=0
)
assert all((a == 5) & (b == 6) & (unknown == 0) for a, b, unknown in values)

# Case 7: 'default' handling when one attribute is completely missing from some agents
agentset.select(at_most=0.5).set("c", 8) # Only some agents have attribute 'c'
values = agentset.get(["a", "c"], handle_missing="default", default_value=-1)
assert all((a == 5) & (c in [8, -1]) for a, c in values)


def test_agentset_agg():
model = Model()
agents = [TestAgent(i, model) for i in range(10)]
Expand Down

0 comments on commit 7353b2b

Please sign in to comment.