Skip to content

Commit

Permalink
Simplify solara code (#1786)
Browse files Browse the repository at this point in the history
* Simplify solara code

* re-introduced backend switch and removed self
  • Loading branch information
Corvince authored Sep 1, 2023
1 parent ea4b213 commit fb81c1a
Showing 1 changed file with 166 additions and 182 deletions.
348 changes: 166 additions & 182 deletions mesa/experimental/jupyter_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,59 +13,162 @@
plt.switch_backend("agg")


class JupyterContainer:
def __init__(
self,
model_class,
model_params,
measures=None,
name="Mesa Model",
agent_portrayal=None,
):
self.model_class = model_class
self.split_model_params(model_params)
self.measures = measures
self.name = name
self.agent_portrayal = agent_portrayal
self.thread = None

def split_model_params(self, model_params):
self.model_params_input = {}
self.model_params_fixed = {}
for k, v in model_params.items():
if self.check_param_is_fixed(v):
self.model_params_fixed[k] = v
@solara.component
def JupyterViz(
model_class,
model_params,
measures=None,
name="Mesa Model",
agent_portrayal=None,
space_drawer=None,
play_interval=400,
):
current_step, set_current_step = solara.use_state(0)

solara.Markdown(name)

# 0. Split model params
model_params_input, model_params_fixed = split_model_params(model_params)

# 1. User inputs
user_inputs = {}
for k, v in model_params_input.items():
user_input = solara.use_reactive(v["value"])
user_inputs[k] = user_input.value
make_user_input(user_input, k, v)

# 2. Model
def make_model():
return model_class(**user_inputs, **model_params_fixed)

model = solara.use_memo(make_model, dependencies=list(user_inputs.values()))

# 3. Buttons
ModelController(model, play_interval, current_step, set_current_step)

with solara.GridFixed(columns=2):
# 4. Space
if space_drawer is None:
make_space(model, agent_portrayal)
else:
space_drawer(model, agent_portrayal)
# 5. Plots
for measure in measures:
if callable(measure):
# Is a custom object
measure(model)
else:
self.model_params_input[k] = v
make_plot(model, measure)


def check_param_is_fixed(self, param):
if not isinstance(param, dict):
return True
if "type" not in param:
return True
@solara.component
def ModelController(model, play_interval, current_step, set_current_step):
playing = solara.use_reactive(False)
thread = solara.use_reactive(None)

def on_value_play(change):
if model.running:
do_step()
else:
playing.value = False

def do_step(self):
self.model.step()
self.set_df(self.model.datacollector.get_model_vars_dataframe())
def do_step():
model.step()
set_current_step(model.schedule.steps)

def do_play(self):
self.model.running = True
while self.model.running:
self.do_step()
def do_play():
model.running = True
while model.running:
do_step()

def threaded_do_play(self):
if self.thread is not None and self.thread.is_alive():
def threaded_do_play():
if thread is not None and thread.is_alive():
return
self.thread = threading.Thread(target=self.do_play)
self.thread.start()
thread.value = threading.Thread(target=do_play)
thread.start()

def do_pause(self):
if (self.thread is None) or (not self.thread.is_alive()):
def do_pause():
if (thread is None) or (not thread.is_alive()):
return
self.model.running = False
self.thread.join()
model.running = False
thread.join()

def portray(self, g):
with solara.Row():
solara.Button(label="Step", color="primary", on_click=do_step)
# This style is necessary so that the play widget has almost the same
# height as typical Solara buttons.
solara.Style(
"""
.widget-play {
height: 30px;
}
"""
)
widgets.Play(
value=0,
interval=play_interval,
repeat=True,
show_repeat=False,
on_value=on_value_play,
playing=playing.value,
on_playing=playing.set,
)
solara.Markdown(md_text=f"**Step:** {current_step}")
# threaded_do_play is not used for now because it
# doesn't work in Google colab. We use
# ipywidgets.Play until it is fixed. The threading
# version is definite a much better implementation,
# if it works.
# solara.Button(label="▶", color="primary", on_click=viz.threaded_do_play)
# solara.Button(label="⏸︎", color="primary", on_click=viz.do_pause)
# solara.Button(label="Reset", color="primary", on_click=do_reset)


def split_model_params(model_params):
model_params_input = {}
model_params_fixed = {}
for k, v in model_params.items():
if check_param_is_fixed(v):
model_params_fixed[k] = v
else:
model_params_input[k] = v
return model_params_input, model_params_fixed


def check_param_is_fixed(param):
if not isinstance(param, dict):
return True
if "type" not in param:
return True


def make_user_input(user_input, k, v):
if v["type"] == "SliderInt":
solara.SliderInt(
v.get("label", "label"),
value=user_input,
min=v.get("min"),
max=v.get("max"),
step=v.get("step"),
)
elif v["type"] == "SliderFloat":
solara.SliderFloat(
v.get("label", "label"),
value=user_input,
min=v.get("min"),
max=v.get("max"),
step=v.get("step"),
)
elif v["type"] == "Select":
solara.Select(
v.get("label", "label"),
value=v.get("value"),
values=v.get("values"),
)


def make_space(model, agent_portrayal):
def portray(g):
x = []
y = []
s = [] # size
Expand All @@ -79,7 +182,7 @@ def portray(self, g):
# Is a single grid
content = [content]
for agent in content:
data = self.agent_portrayal(agent)
data = agent_portrayal(agent)
x.append(i)
y.append(j)
if "size" in data:
Expand All @@ -93,159 +196,40 @@ def portray(self, g):
out["c"] = c
return out

space_fig = Figure()
space_ax = space_fig.subplots()
if isinstance(model.grid, mesa.space.NetworkGrid):
_draw_network_grid(model, space_ax, agent_portrayal)
else:
space_ax.scatter(**portray(model.grid))
space_ax.set_axis_off()
solara.FigureMatplotlib(space_fig)


def _draw_network_grid(viz, space_ax):
graph = viz.model.grid.G
def _draw_network_grid(model, space_ax, agent_portrayal):
graph = model.grid.G
pos = nx.spring_layout(graph, seed=0)
nx.draw(
graph,
ax=space_ax,
pos=pos,
**viz.agent_portrayal(graph),
**agent_portrayal(graph),
)


def make_space(viz):
space_fig = Figure()
space_ax = space_fig.subplots()
if isinstance(viz.model.grid, mesa.space.NetworkGrid):
_draw_network_grid(viz, space_ax)
else:
space_ax.scatter(**viz.portray(viz.model.grid))
space_ax.set_axis_off()
solara.FigureMatplotlib(space_fig, dependencies=[viz.model, viz.df])


def make_plot(viz, measure):
def make_plot(model, measure):
fig = Figure()
ax = fig.subplots()
ax.plot(viz.df.loc[:, measure])
df = model.datacollector.get_model_vars_dataframe()
ax.plot(df.loc[:, measure])
ax.set_ylabel(measure)
# Set integer x axis
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
solara.FigureMatplotlib(fig, dependencies=[viz.model, viz.df])
solara.FigureMatplotlib(fig)


def make_text(renderer):
def function(viz):
solara.Markdown(renderer(viz.model))
def function(model):
solara.Markdown(renderer(model))

return function


def make_user_input(user_input, k, v):
if v["type"] == "SliderInt":
solara.SliderInt(
v.get("label", "label"),
value=user_input,
min=v.get("min"),
max=v.get("max"),
step=v.get("step"),
)
elif v["type"] == "SliderFloat":
solara.SliderFloat(
v.get("label", "label"),
value=user_input,
min=v.get("min"),
max=v.get("max"),
step=v.get("step"),
)
elif v["type"] == "Select":
solara.Select(
v.get("label", "label"),
value=v.get("value"),
values=v.get("values"),
)


@solara.component
def MesaComponent(viz, space_drawer=None, play_interval=400):
solara.Markdown(viz.name)

# 1. User inputs
user_inputs = {}
for k, v in viz.model_params_input.items():
user_input = solara.use_reactive(v["value"])
user_inputs[k] = user_input.value
make_user_input(user_input, k, v)

# 2. Model
def make_model():
return viz.model_class(**user_inputs, **viz.model_params_fixed)

viz.model = solara.use_memo(make_model, dependencies=list(user_inputs.values()))
viz.df, viz.set_df = solara.use_state(
viz.model.datacollector.get_model_vars_dataframe()
)

# 3. Buttons
playing = solara.use_reactive(False)

def on_value_play(change):
if viz.model.running:
viz.do_step()
else:
playing.value = False

with solara.Row():
solara.Button(label="Step", color="primary", on_click=viz.do_step)
# This style is necessary so that the play widget has almost the same
# height as typical Solara buttons.
solara.Style(
"""
.widget-play {
height: 30px;
}
"""
)
widgets.Play(
value=0,
interval=play_interval,
repeat=True,
show_repeat=False,
on_value=on_value_play,
playing=playing.value,
on_playing=playing.set,
)
solara.Markdown(md_text=f"**Step:** {viz.model.schedule.steps}")
# threaded_do_play is not used for now because it
# doesn't work in Google colab. We use
# ipywidgets.Play until it is fixed. The threading
# version is definite a much better implementation,
# if it works.
# solara.Button(label="▶", color="primary", on_click=viz.threaded_do_play)
# solara.Button(label="⏸︎", color="primary", on_click=viz.do_pause)
# solara.Button(label="Reset", color="primary", on_click=do_reset)

with solara.GridFixed(columns=2):
# 4. Space
if space_drawer is None:
make_space(viz)
else:
space_drawer(viz)
# 5. Plots
for measure in viz.measures:
if callable(measure):
# Is a custom object
measure(viz)
else:
make_plot(viz, measure)


# JupyterViz has to be a Solara component, so that each browser tabs runs in
# their own, separate simulation thread. See https://github.com/projectmesa/mesa/issues/856.
@solara.component
def JupyterViz(
model_class,
model_params,
measures=None,
name="Mesa Model",
agent_portrayal=None,
space_drawer=None,
play_interval=400,
):
return MesaComponent(
JupyterContainer(model_class, model_params, measures, name, agent_portrayal),
space_drawer=space_drawer,
play_interval=play_interval,
)

0 comments on commit fb81c1a

Please sign in to comment.