From 003cbe35260bb93231640054aa3abf7f7b8492ba Mon Sep 17 00:00:00 2001 From: rht Date: Fri, 26 Jan 2024 11:36:21 -0500 Subject: [PATCH] refactor: Remove dependence on model.schedule, add clock to Model (#1942) * refactor: Remove dependence on model.schedule * model: Implement internal clock * time: Call self.model.advance_time() in step() This ensures that the scheduler's clock and the model's clock are updated and are in sync. * Ensure advance_time call in schedulers happen only once in a model step * Turn model steps and time to be private attribute * Rename advance_time to _advance_time * Annotate model._steps * Remove _advance_time from tests This is because schedule.step already calls _advance_time under the hood. * model: Rename _time to time_ * Rename _steps to steps_ * Revert applying _advance_time in schedulers step * feat: Automatically call _advance_time right after model step() Solution drafted by and partially attributed to ChatGPT: https://chat.openai.com/share/d9b9c6c6-17d0-4eb9-9eae-484402bed756 * fix: Make sure agent removes itself in schedule.remove * fix: Do step() wrapping in scheduler instead of model * fix: JupyterViz: replace model.steps with model.steps_ * Rename steps_ -> _steps, time_ -> _time * agent_records: Use model.agents only when model has no scheduler --------- Co-authored-by: Ewout ter Hoeven --- mesa/datacollection.py | 17 +++++++++-------- mesa/experimental/jupyter_viz.py | 2 +- mesa/model.py | 12 +++++++++++- mesa/time.py | 7 +++++++ 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index 88f1863f457..3fcd205c9c5 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -14,8 +14,7 @@ When the collect() method is called, each model-level function is called, with the model as the argument, and the results associated with the relevant -variable. Then the agent-level functions are called on each agent in the model -scheduler. +variable. Then the agent-level functions are called on each agent. Additionally, other objects can write directly to tables by passing in an appropriate dictionary object for a table row. @@ -30,8 +29,7 @@ Finally, DataCollector can create a pandas DataFrame from each collection. The default DataCollector here makes several assumptions: - * The model has a schedule object called 'schedule' - * The schedule has an agent list called agents + * The model has an agent list called agents * For collecting agent-level variables, agents must have a unique_id """ import contextlib @@ -67,7 +65,7 @@ def __init__( Model reporters can take four types of arguments: 1. Lambda function: - {"agent_count": lambda m: m.schedule.get_agent_count()} + {"agent_count": lambda m: len(m.agents)} 2. Method of a class/instance: {"agent_count": self.get_agent_count} # self here is a class instance {"agent_count": Model.get_agent_count} # Model here is a class @@ -180,11 +178,14 @@ def _record_agents(self, model): rep_funcs = self.agent_reporters.values() def get_reports(agent): - _prefix = (agent.model.schedule.steps, agent.unique_id) + _prefix = (agent.model._steps, agent.unique_id) reports = tuple(rep(agent) for rep in rep_funcs) return _prefix + reports - agent_records = map(get_reports, model.schedule.agents) + agent_records = map( + get_reports, + model.schedule.agents if hasattr(model, "schedule") else model.agents, + ) return agent_records def collect(self, model): @@ -207,7 +208,7 @@ def collect(self, model): if self.agent_reporters: agent_records = self._record_agents(model) - self._agent_records[model.schedule.steps] = list(agent_records) + self._agent_records[model._steps] = list(agent_records) def add_table_row(self, table_name, row, ignore_missing=False): """Add a row dictionary to a specific table. diff --git a/mesa/experimental/jupyter_viz.py b/mesa/experimental/jupyter_viz.py index 7a586239058..eb20ca5f3bd 100644 --- a/mesa/experimental/jupyter_viz.py +++ b/mesa/experimental/jupyter_viz.py @@ -181,7 +181,7 @@ def on_value_play(change): def do_step(): model.step() previous_step.value = current_step.value - current_step.value += 1 + current_step.value = model._steps def do_play(): model.running = True diff --git a/mesa/model.py b/mesa/model.py index 1bd9b3a7548..b2b12340ad9 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -13,11 +13,13 @@ from collections import defaultdict # mypy -from typing import Any +from typing import Any, Union from mesa.agent import Agent, AgentSet from mesa.datacollection import DataCollector +TimeT = Union[float, int] + class Model: """Base class for models in the Mesa ABM library. @@ -68,6 +70,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.current_id = 0 self.agents_: defaultdict[type, dict] = defaultdict(dict) + self._steps: int = 0 + self._time: TimeT = 0 # the model's clock + # Warning flags for current experimental features. These make sure a warning is only printed once per model. self.agentset_experimental_warning_given = False @@ -112,6 +117,11 @@ def run_model(self) -> None: def step(self) -> None: """A single step. Fill in here.""" + def _advance_time(self, deltat: TimeT = 1): + """Increment the model's steps counter and clock.""" + self._steps += 1 + self._time += deltat + def next_id(self) -> int: """Return the next unique ID for agents, increment current_id""" self.current_id += 1 diff --git a/mesa/time.py b/mesa/time.py index 87a68949cdd..3dcd1708f84 100644 --- a/mesa/time.py +++ b/mesa/time.py @@ -73,6 +73,8 @@ def __init__(self, model: Model, agents: Iterable[Agent] | None = None) -> None: self.model = model self.steps = 0 self.time: TimeT = 0 + self._original_step = self.step + self.step = self._wrapped_step if agents is None: agents = [] @@ -115,6 +117,11 @@ def step(self) -> None: self.steps += 1 self.time += 1 + def _wrapped_step(self): + """Wrapper for the step method to include time and step updating.""" + self._original_step() + self.model._advance_time() + def get_agent_count(self) -> int: """Returns the current number of agents in the queue.""" return len(self._agents)