Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AgentSet: Refactor shuffle (overhaul) #2284

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 76 additions & 41 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from random import Random

# mypy
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

if TYPE_CHECKING:
# We ensure that these are not imported during runtime to prevent cyclic
Expand Down Expand Up @@ -127,7 +127,8 @@ def __init__(self, agents: Iterable[Agent], model: Model):
"""

self.model = model
self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})
self._agents_list: list[weakref.ref] = [weakref.ref(agent) for agent in agents]
self._agents_set: set[int] = {id(agent) for agent in agents}

def __len__(self) -> int:
"""Return the number of agents in the AgentSet."""
Expand Down Expand Up @@ -213,16 +214,16 @@ def shuffle(self, inplace: bool = False) -> AgentSet:
Using inplace = True is more performant

"""
weakrefs = list(self._agents.keyrefs())
self.random.shuffle(weakrefs)

if inplace:
self._agents.data = {entry: None for entry in weakrefs}
self.random.shuffle(self._agents_list)
return self
else:
return AgentSet(
(agent for ref in weakrefs if (agent := ref()) is not None), self.model
new_agentset = AgentSet([], self.model)
new_agentset._agents_list = self.random.sample(
self._agents_list, len(self._agents_list)
)
new_agentset._agents_set = self._agents_set.copy()
return new_agentset

def sort(
self,
Expand Down Expand Up @@ -276,30 +277,14 @@ def do(self, method: str | Callable, *args, **kwargs) -> AgentSet:
Returns:
AgentSet | list[Any]: The results of the callable calls if return_results is True, otherwise the AgentSet itself.
"""
try:
return_results = kwargs.pop("return_results")
except KeyError:
return_results = False
else:
warnings.warn(
"Using return_results is deprecated. Use AgenSet.do in case of return_results=False, and "
"AgentSet.map in case of return_results=True",
stacklevel=2,
)

if return_results:
return self.map(method, *args, **kwargs)

# we iterate over the actual weakref keys and check if weakref is alive before calling the method
if isinstance(method, str):
for agentref in self._agents.keyrefs():
if (agent := agentref()) is not None:
for ref in self._agents_list:
if (agent := ref()) is not None:
getattr(agent, method)(*args, **kwargs)
else:
for agentref in self._agents.keyrefs():
if (agent := agentref()) is not None:
for ref in self._agents_list:
if (agent := ref()) is not None:
method(agent, *args, **kwargs)

return self

def map(self, method: str | Callable, *args, **kwargs) -> list[Any]:
Expand Down Expand Up @@ -348,29 +333,77 @@ def agg(self, attribute: str, func: Callable) -> Any:
values = self.get(attribute)
return func(values)

def get(self, attr_names: str | list[str]) -> list[Any]:
def get(
self,
attr_names: str | list[str],
handle_missing: Literal["error", "skip", "default"] = "error",
default_value: Any = None,
) -> list[Any] | list[list[Any]]:
"""
Retrieve the specified attribute(s) from each agent in the AgentSet.

Args:
attr_names (str | list[str]): The name(s) of the attribute(s) to retrieve from each agent.
handle_missing (str, optional): How to handle missing attributes. Can be:
- 'error' (default): raises an AttributeError if attribute is missing.
- 'skip': skips the agents missing the attribute.
- 'default': returns the specified default_value.
default_value (Any, optional): The default value to return if 'handle_missing' is set to 'default'
and the agent does not have the attribute.

Returns:
list[Any]: A list with the attribute value for each agent in the set if attr_names is a str
list[list[Any]]: A list with a list of attribute values for each agent in the set if attr_names is a list of str
list[Any]: A list of attribute values for each agent if attr_names is a str.
list[list[Any]]: A list of lists of attribute values for each agent if attr_names is a list of str.

Raises:
AttributeError if an agent does not have the specified attribute(s)
AttributeError: If 'handle_missing' is 'error' and the agent does not have the specified attribute(s).
ValueError: If an unknown 'handle_missing' option is provided.
"""
is_single_attr = isinstance(attr_names, str)

if handle_missing == "error":
if is_single_attr:
return [getattr(agent, attr_names) for agent in self._agents]
else:
return [
[getattr(agent, attr) for attr in attr_names]
for agent in self._agents
]

elif handle_missing == "default":
if is_single_attr:
return [
getattr(agent, attr_names, default_value) for agent in self._agents
]
else:
return [
[getattr(agent, attr, default_value) for attr in attr_names]
for agent in self._agents
]

elif handle_missing == "skip":
if is_single_attr:
return [
getattr(agent, attr_names)
for agent in self._agents
if hasattr(agent, attr_names)
]
else:
return [
[
getattr(agent, attr)
for attr in attr_names
if hasattr(agent, attr)
]
for agent in self._agents
if any(hasattr(agent, attr) for attr in attr_names)
]

"""

if isinstance(attr_names, str):
return [getattr(agent, attr_names) for agent in self._agents]
else:
return [
[getattr(agent, attr_name) for attr_name in attr_names]
for agent in self._agents
]
raise ValueError(
f"Unknown handle_missing option: {handle_missing}, "
"should be one of 'error', 'skip', or 'default'"
)

def set(self, attr_name: str, value: Any) -> AgentSet:
"""
Expand Down Expand Up @@ -409,7 +442,9 @@ def add(self, agent: Agent):
Note:
This method is an implementation of the abstract method from MutableSet.
"""
self._agents[agent] = None
if id(agent) not in self._agents_set:
self._agents_list.append(weakref.ref(agent))
self._agents_set.add(id(agent))

def discard(self, agent: Agent):
"""
Expand Down
Loading