diff --git a/docs/overview.md b/docs/overview.md index 92f7be06907..38f0ef9d2d0 100644 --- a/docs/overview.md +++ b/docs/overview.md @@ -177,20 +177,20 @@ def agent_portrayal(agent): model_params = { "N": { - "type": "SliderInt", - "value": 50, - "label": "Number of agents:", - "min": 10, - "max": 100, - "step": 1, + "type": "SliderInt", + "value": 50, + "label": "Number of agents:", + "min": 10, + "max": 100, + "step": 1, } } page = SolaraViz( MyModel, [ - make_space_component(agent_portrayal), - make_plot_component("mean_age") + make_space_component(agent_portrayal), + make_plot_component("mean_age") ], model_params=model_params ) diff --git a/mesa/examples/basic/boltzmann_wealth_model/app.py b/mesa/examples/basic/boltzmann_wealth_model/app.py index ddb8933049f..ff329ab9667 100644 --- a/mesa/examples/basic/boltzmann_wealth_model/app.py +++ b/mesa/examples/basic/boltzmann_wealth_model/app.py @@ -1,18 +1,10 @@ from mesa.examples.basic.boltzmann_wealth_model.model import BoltzmannWealthModel -from mesa.visualization import ( - SolaraViz, - make_plot_component, - make_space_component, -) +from mesa.visualization import SolaraViz, make_plot_component, make_space_component def agent_portrayal(agent): - size = 10 - color = "tab:red" - if agent.wealth > 0: - size = 50 - color = "tab:blue" - return {"size": size, "color": color} + color = agent.wealth # we are using a colormap to translate wealth to color + return {"color": color} model_params = { @@ -28,6 +20,11 @@ def agent_portrayal(agent): "height": 10, } + +def post_process(ax): + ax.get_figure().colorbar(ax.collections[0], label="wealth", ax=ax) + + # Create initial model instance model1 = BoltzmannWealthModel(50, 10, 10) @@ -36,7 +33,10 @@ def agent_portrayal(agent): # Under the hood these are just classes that receive the model instance. # You can also author your own visualization elements, which can also be functions # that receive the model instance and return a valid solara component. -SpaceGraph = make_space_component(agent_portrayal) + +SpaceGraph = make_space_component( + agent_portrayal, cmap="viridis", vmin=0, vmax=10, post_process=post_process +) GiniPlot = make_plot_component("Gini") # Create the SolaraViz page. This will automatically create a server and display the diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index 09b281a3e17..2bda984f775 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -309,6 +309,7 @@ def draw_orthogonal_grid( agent_portrayal: Callable, ax: Axes | None = None, draw_grid: bool = True, + **kwargs, ): """Visualize a orthogonal grid. @@ -317,6 +318,7 @@ def draw_orthogonal_grid( agent_portrayal: a callable that is called with the agent and returns a dict ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots draw_grid: whether to draw the grid + kwargs: additional keyword arguments passed to ax.scatter Returns: Returns the Axes object with the plot drawn onto it. @@ -333,7 +335,7 @@ def draw_orthogonal_grid( arguments = collect_agent_data(space, agent_portrayal, size=s_default) # plot the agents - _scatter(ax, arguments) + _scatter(ax, arguments, **kwargs) # further styling ax.set_xlim(-0.5, space.width - 0.5) @@ -354,6 +356,7 @@ def draw_hex_grid( agent_portrayal: Callable, ax: Axes | None = None, draw_grid: bool = True, + **kwargs, ): """Visualize a hex grid. @@ -362,6 +365,7 @@ def draw_hex_grid( agent_portrayal: a callable that is called with the agent and returns a dict ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots draw_grid: whether to draw the grid + kwargs: additional keyword arguments passed to ax.scatter Returns: Returns the Axes object with the plot drawn onto it. @@ -394,7 +398,7 @@ def draw_hex_grid( arguments["loc"] = loc # plot the agents - _scatter(ax, arguments) + _scatter(ax, arguments, **kwargs) # further styling and adding of grid ax.set_xlim(-1, space.width + 0.5) @@ -443,6 +447,7 @@ def draw_network( draw_grid: bool = True, layout_alg=nx.spring_layout, layout_kwargs=None, + **kwargs, ): """Visualize a network space. @@ -453,6 +458,7 @@ def draw_network( draw_grid: whether to draw the grid layout_alg: a networkx layout algorithm or other callable with the same behavior layout_kwargs: a dictionary of keyword arguments for the layout algorithm + kwargs: additional keyword arguments passed to ax.scatter Returns: Returns the Axes object with the plot drawn onto it. @@ -488,7 +494,7 @@ def draw_network( arguments["loc"] = pos[arguments["loc"]] # plot the agents - _scatter(ax, arguments) + _scatter(ax, arguments, **kwargs) # further styling ax.set_axis_off() @@ -506,7 +512,7 @@ def draw_network( def draw_continuous_space( - space: ContinuousSpace, agent_portrayal: Callable, ax: Axes | None = None + space: ContinuousSpace, agent_portrayal: Callable, ax: Axes | None = None, **kwargs ): """Visualize a continuous space. @@ -514,6 +520,7 @@ def draw_continuous_space( space: the space to visualize agent_portrayal: a callable that is called with the agent and returns a dict ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots + kwargs: additional keyword arguments passed to ax.scatter Returns: Returns the Axes object with the plot drawn onto it. @@ -536,7 +543,7 @@ def draw_continuous_space( arguments = collect_agent_data(space, agent_portrayal, size=s_default) # plot the agents - _scatter(ax, arguments) + _scatter(ax, arguments, **kwargs) # further visual styling border_style = "solid" if not space.torus else (0, (5, 10)) @@ -552,7 +559,7 @@ def draw_continuous_space( def draw_voroinoi_grid( - space: VoronoiGrid, agent_portrayal: Callable, ax: Axes | None = None + space: VoronoiGrid, agent_portrayal: Callable, ax: Axes | None = None, **kwargs ): """Visualize a voronoi grid. @@ -560,6 +567,7 @@ def draw_voroinoi_grid( space: the space to visualize agent_portrayal: a callable that is called with the agent and returns a dict ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots + kwargs: additional keyword arguments passed to ax.scatter Returns: Returns the Axes object with the plot drawn onto it. @@ -589,7 +597,7 @@ def draw_voroinoi_grid( ax.set_xlim(x_min - x_padding, x_max + x_padding) ax.set_ylim(y_min - y_padding, y_max + y_padding) - _scatter(ax, arguments) + _scatter(ax, arguments, **kwargs) for cell in space.all_cells: polygon = cell.properties["polygon"] @@ -604,8 +612,15 @@ def draw_voroinoi_grid( return ax -def _scatter(ax: Axes, arguments): - """Helper function for plotting the agents.""" +def _scatter(ax: Axes, arguments, **kwargs): + """Helper function for plotting the agents. + + Args: + ax: a Matplotlib Axes instance + arguments: the agents specific arguments for platting + kwargs: additional keyword arguments for ax.scatter + + """ loc = arguments.pop("loc") x = loc[:, 0] @@ -624,6 +639,7 @@ def _scatter(ax: Axes, arguments): marker=mark, zorder=z_order, **{k: v[logical] for k, v in arguments.items()}, + **kwargs, )