From e255a2d552a1ce4f6fe57f7bf908a6acb9c078e1 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Sun, 13 Oct 2024 15:32:52 +0200 Subject: [PATCH] replace model with random in AgentSet init (#2350) * replace model with random in AgentSet `__init__` also closing #2323 --- mesa/agent.py | 31 +++++++++++--------------- mesa/model.py | 11 +++++---- tests/test_agent.py | 54 ++++++++++++++++++++++----------------------- 3 files changed, 47 insertions(+), 49 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 4d28f4742b1..2b6d2f99343 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -99,14 +99,18 @@ class AgentSet(MutableSet, Sequence): which means that agents not referenced elsewhere in the program may be automatically removed from the AgentSet. """ - def __init__(self, agents: Iterable[Agent], model: Model): + def __init__(self, agents: Iterable[Agent], random: Random | None = None): """Initializes the AgentSet with a collection of agents and a reference to the model. Args: agents (Iterable[Agent]): An iterable of Agent objects to be included in the set. - model (Model): The ABM model instance to which this AgentSet belongs. + random (Random): the random number generator """ - self.model = model + if random is None: + random = ( + Random() + ) # FIXME see issue 1981, how to get the central rng from model + self.random = random self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents}) def __len__(self) -> int: @@ -177,7 +181,7 @@ def agent_generator(filter_func, agent_type, at_most): agents = agent_generator(filter_func, agent_type, at_most) - return AgentSet(agents, self.model) if not inplace else self._update(agents) + return AgentSet(agents, self.random) if not inplace else self._update(agents) def shuffle(self, inplace: bool = False) -> AgentSet: """Randomly shuffle the order of agents in the AgentSet. @@ -200,7 +204,7 @@ def shuffle(self, inplace: bool = False) -> AgentSet: return self else: return AgentSet( - (agent for ref in weakrefs if (agent := ref()) is not None), self.model + (agent for ref in weakrefs if (agent := ref()) is not None), self.random ) def sort( @@ -225,7 +229,7 @@ def sort( sorted_agents = sorted(self._agents.keys(), key=key, reverse=not ascending) return ( - AgentSet(sorted_agents, self.model) + AgentSet(sorted_agents, self.random) if not inplace else self._update(sorted_agents) ) @@ -477,7 +481,7 @@ def __getstate__(self): Returns: dict: A dictionary representing the state of the AgentSet. """ - return {"agents": list(self._agents.keys()), "model": self.model} + return {"agents": list(self._agents.keys()), "random": self.random} def __setstate__(self, state): """Set the state of the AgentSet during deserialization. @@ -485,18 +489,9 @@ def __setstate__(self, state): Args: state (dict): A dictionary representing the state to restore. """ - self.model = state["model"] + self.random = state["random"] self._update(state["agents"]) - @property - def random(self) -> Random: - """Provide access to the model's random number generator. - - Returns: - Random: The random number generator associated with the model. - """ - return self.model.random - def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy: """Group agents by the specified attribute or return from the callable. @@ -529,7 +524,7 @@ def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy: if result_type == "agentset": return GroupBy( - {k: AgentSet(v, model=self.model) for k, v in groups.items()} + {k: AgentSet(v, random=self.random) for k, v in groups.items()} ) else: return GroupBy(groups) diff --git a/mesa/model.py b/mesa/model.py index e4f70793c0e..b6500eb4132 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -52,8 +52,6 @@ def __init__(self, *args: Any, seed: float | None = None, **kwargs: Any) -> None self.running = True self.steps: int = 0 - self._setup_agent_registration() - self._seed = seed if self._seed is None: # We explicitly specify the seed here so that we know its value in @@ -65,6 +63,9 @@ def __init__(self, *args: Any, seed: float | None = None, **kwargs: Any) -> None self._user_step = self.step self.step = self._wrapped_step + # setup agent registration data structures + self._setup_agent_registration() + def _wrapped_step(self, *args: Any, **kwargs: Any) -> None: """Automatically increments time and steps after calling the user's step method.""" # Automatically increment time and step counters @@ -119,7 +120,9 @@ def _setup_agent_registration(self): self._agents_by_type: dict[ type[Agent], AgentSet ] = {} # a dict with an agentset for each class of agents - self._all_agents = AgentSet([], self) # an agenset with all agents + self._all_agents = AgentSet( + [], random=self.random + ) # an agenset with all agents def register_agent(self, agent): """Register the agent with the model. @@ -153,7 +156,7 @@ def register_agent(self, agent): [ agent, ], - self, + random=self.random, ) self._all_agents.add(agent) diff --git a/tests/test_agent.py b/tests/test_agent.py index 73aab897799..aa5ff9350ff 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -62,7 +62,7 @@ def test_agentset(): model = Model() agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) assert agents[0] in agentset assert len(agentset) == len(agents) @@ -118,7 +118,7 @@ def test_function(agent): # because AgentSet uses weakrefs, we need hard refs as well.... other_agents, another_set = pickle.loads( # noqa: S301 - pickle.dumps([agents, AgentSet(agents, model)]) + pickle.dumps([agents, AgentSet(agents, random=model.random)]) ) assert all( a1.unique_id == a2.unique_id for a1, a2 in zip(another_set, other_agents) @@ -129,11 +129,11 @@ def test_function(agent): def test_agentset_initialization(): """Test agentset initialization.""" model = Model() - empty_agentset = AgentSet([], model) + empty_agentset = AgentSet([], random=model.random) assert len(empty_agentset) == 0 agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) assert len(agentset) == 10 @@ -141,7 +141,7 @@ def test_agentset_serialization(): """Test pickleability of agentset.""" model = Model() agents = [AgentTest(model) for _ in range(5)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) serialized = pickle.dumps(agentset) deserialized = pickle.loads(serialized) # noqa: S301 @@ -156,7 +156,7 @@ def test_agent_membership(): """Test agent membership in AgentSet.""" model = Model() agents = [AgentTest(model) for _ in range(5)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) assert agents[0] in agentset assert AgentTest(model) not in agentset @@ -166,7 +166,7 @@ def test_agent_add_remove_discard(): """Test adding, removing and discarding agents from AgentSet.""" model = Model() agent = AgentTest(model) - agentset = AgentSet([], model) + agentset = AgentSet([], random=model.random) agentset.add(agent) assert agent in agentset @@ -186,7 +186,7 @@ def test_agentset_get_item(): """Test integer based access to AgentSet.""" model = Model() agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) assert agentset[0] == agents[0] assert agentset[-1] == agents[-1] @@ -200,7 +200,7 @@ def test_agentset_do_str(): """Test AgentSet.do with str.""" model = Model() agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) with pytest.raises(AttributeError): agentset.do("non_existing_method") @@ -213,7 +213,7 @@ def test_agentset_do_str(): n = 10 model = Model() agents = [AgentDoTest(model) for _ in range(n)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) for agent in agents: agent.agent_set = agentset @@ -223,7 +223,7 @@ def test_agentset_do_str(): # setup model = Model() agents = [AgentDoTest(model) for _ in range(10)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) for agent in agents: agent.agent_set = agentset @@ -235,7 +235,7 @@ def test_agentset_do_callable(): """Test AgentSet.do with callable.""" model = Model() agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) # Test callable with non-existent function with pytest.raises(AttributeError): @@ -249,7 +249,7 @@ def test_agentset_do_callable(): n = 10 model = Model() agents = [AgentDoTest(model) for _ in range(n)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) for agent in agents: agent.agent_set = agentset @@ -260,7 +260,7 @@ def test_agentset_do_callable(): # setup again for lambda function tests model = Model() agents = [AgentDoTest(model) for _ in range(10)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) for agent in agents: agent.agent_set = agentset @@ -278,7 +278,7 @@ def remove_function(agent): # setup again for actual function tests model = Model() agents = [AgentDoTest(model) for _ in range(n)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) for agent in agents: agent.agent_set = agentset @@ -289,7 +289,7 @@ def remove_function(agent): # setup again for actual function tests model = Model() agents = [AgentDoTest(model) for _ in range(10)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) for agent in agents: agent.agent_set = agentset @@ -354,7 +354,7 @@ def test_agentset_agg(): agent.energy = i + 1 agent.wealth = 10 * (i + 1) - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) # Test min aggregation min_energy = agentset.agg("energy", min) @@ -391,7 +391,7 @@ def __init__(self, model, age=None): model = Model() agents = [TestAgentWithAttribute(model, age=i) for i in range(5)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) # Set a new attribute "health" and an existing attribute "age" for all agents agentset.set("health", 100).set("age", 50).set("status", "active") @@ -410,7 +410,7 @@ def test_agentset_map_str(): """Test AgentSet.map with strings.""" model = Model() agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) with pytest.raises(AttributeError): agentset.do("non_existing_method") @@ -423,7 +423,7 @@ def test_agentset_map_callable(): """Test AgentSet.map with callable.""" model = Model() agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) # Test callable with non-existent function with pytest.raises(AttributeError): @@ -450,7 +450,7 @@ def test_method(self): self.called = True agents = [TestAgentShuffleDo(model) for _ in range(100)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) # Test shuffle_do with a string method name agentset.shuffle_do("test_method") @@ -477,7 +477,7 @@ def test_agentset_get_attribute(): """Test AgentSet.get for attributes.""" model = Model() agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) unique_ids = agentset.get("unique_id") assert unique_ids == [agent.unique_id for agent in agents] @@ -491,7 +491,7 @@ def test_agentset_get_attribute(): agent = AgentTest(model) agent.i = i**2 agents.append(agent) - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) values = agentset.get(["unique_id", "i"]) @@ -521,7 +521,7 @@ def test_agentset_select_by_type(): # Combine the two types of agents mixed_agents = test_agents + other_agents - agentset = AgentSet(mixed_agents, model) + agentset = AgentSet(mixed_agents, random=model.random) # Test selection by type selected_test_agents = agentset.select(agent_type=AgentTest) @@ -544,11 +544,11 @@ def test_agentset_shuffle(): model = Model() test_agents = [AgentTest(model) for _ in range(12)] - agentset = AgentSet(test_agents, model=model) + agentset = AgentSet(test_agents, random=model.random) agentset = agentset.shuffle() assert not all(a1 == a2 for a1, a2 in zip(test_agents, agentset)) - agentset = AgentSet(test_agents, model=model) + agentset = AgentSet(test_agents, random=model.random) agentset.shuffle(inplace=True) assert not all(a1 == a2 for a1, a2 in zip(test_agents, agentset)) @@ -567,7 +567,7 @@ def get_unique_identifier(self): model = Model() agents = [TestAgent(model) for _ in range(10)] - agentset = AgentSet(agents, model) + agentset = AgentSet(agents, random=model.random) groups = agentset.groupby("even") assert len(groups.groups[True]) == 5