diff --git a/mesa/datacollection.py b/mesa/datacollection.py index 88f1863f457..3fcd205c9c5 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -14,8 +14,7 @@ When the collect() method is called, each model-level function is called, with the model as the argument, and the results associated with the relevant -variable. Then the agent-level functions are called on each agent in the model -scheduler. +variable. Then the agent-level functions are called on each agent. Additionally, other objects can write directly to tables by passing in an appropriate dictionary object for a table row. @@ -30,8 +29,7 @@ Finally, DataCollector can create a pandas DataFrame from each collection. The default DataCollector here makes several assumptions: - * The model has a schedule object called 'schedule' - * The schedule has an agent list called agents + * The model has an agent list called agents * For collecting agent-level variables, agents must have a unique_id """ import contextlib @@ -67,7 +65,7 @@ def __init__( Model reporters can take four types of arguments: 1. Lambda function: - {"agent_count": lambda m: m.schedule.get_agent_count()} + {"agent_count": lambda m: len(m.agents)} 2. Method of a class/instance: {"agent_count": self.get_agent_count} # self here is a class instance {"agent_count": Model.get_agent_count} # Model here is a class @@ -180,11 +178,14 @@ def _record_agents(self, model): rep_funcs = self.agent_reporters.values() def get_reports(agent): - _prefix = (agent.model.schedule.steps, agent.unique_id) + _prefix = (agent.model._steps, agent.unique_id) reports = tuple(rep(agent) for rep in rep_funcs) return _prefix + reports - agent_records = map(get_reports, model.schedule.agents) + agent_records = map( + get_reports, + model.schedule.agents if hasattr(model, "schedule") else model.agents, + ) return agent_records def collect(self, model): @@ -207,7 +208,7 @@ def collect(self, model): if self.agent_reporters: agent_records = self._record_agents(model) - self._agent_records[model.schedule.steps] = list(agent_records) + self._agent_records[model._steps] = list(agent_records) def add_table_row(self, table_name, row, ignore_missing=False): """Add a row dictionary to a specific table. diff --git a/mesa/experimental/jupyter_viz.py b/mesa/experimental/jupyter_viz.py index 7a586239058..eb20ca5f3bd 100644 --- a/mesa/experimental/jupyter_viz.py +++ b/mesa/experimental/jupyter_viz.py @@ -181,7 +181,7 @@ def on_value_play(change): def do_step(): model.step() previous_step.value = current_step.value - current_step.value += 1 + current_step.value = model._steps def do_play(): model.running = True diff --git a/mesa/model.py b/mesa/model.py index 1bd9b3a7548..b2b12340ad9 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -13,11 +13,13 @@ from collections import defaultdict # mypy -from typing import Any +from typing import Any, Union from mesa.agent import Agent, AgentSet from mesa.datacollection import DataCollector +TimeT = Union[float, int] + class Model: """Base class for models in the Mesa ABM library. @@ -68,6 +70,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.current_id = 0 self.agents_: defaultdict[type, dict] = defaultdict(dict) + self._steps: int = 0 + self._time: TimeT = 0 # the model's clock + # Warning flags for current experimental features. These make sure a warning is only printed once per model. self.agentset_experimental_warning_given = False @@ -112,6 +117,11 @@ def run_model(self) -> None: def step(self) -> None: """A single step. Fill in here.""" + def _advance_time(self, deltat: TimeT = 1): + """Increment the model's steps counter and clock.""" + self._steps += 1 + self._time += deltat + def next_id(self) -> int: """Return the next unique ID for agents, increment current_id""" self.current_id += 1 diff --git a/mesa/time.py b/mesa/time.py index 87a68949cdd..3dcd1708f84 100644 --- a/mesa/time.py +++ b/mesa/time.py @@ -73,6 +73,8 @@ def __init__(self, model: Model, agents: Iterable[Agent] | None = None) -> None: self.model = model self.steps = 0 self.time: TimeT = 0 + self._original_step = self.step + self.step = self._wrapped_step if agents is None: agents = [] @@ -115,6 +117,11 @@ def step(self) -> None: self.steps += 1 self.time += 1 + def _wrapped_step(self): + """Wrapper for the step method to include time and step updating.""" + self._original_step() + self.model._advance_time() + def get_agent_count(self) -> int: """Returns the current number of agents in the queue.""" return len(self._agents)