Skip to content

Commit

Permalink
remove global state dependency from UserInputs
Browse files Browse the repository at this point in the history
  • Loading branch information
Corvince committed Sep 7, 2023
1 parent 30341cb commit 90e1243
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions mesa/experimental/jupyter_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
plt.switch_backend("agg")


model_parameters = solara.reactive(None)


@solara.component
def JupyterViz(
model_class,
Expand Down Expand Up @@ -44,24 +41,24 @@ def JupyterViz(

# 1. Set up model parameters
user_params, fixed_params = split_model_params(model_params)
if model_parameters.value is None:
model_parameters.value = fixed_params | {
k: v["value"] for k, v in user_params.items()
}
model_parameters, set_model_parameters = solara.use_state(
fixed_params | {k: v["value"] for k, v in user_params.items()}
)

# 2. Set up Model
def make_model():
model = model_class(**model_parameters.value)
model = model_class(**model_parameters)
set_current_step(0)
return model

model = solara.use_memo(
make_model, dependencies=list(model_parameters.value.values())
)
model = solara.use_memo(make_model, dependencies=list(model_parameters.values()))

def handle_change_model_params(name, change):
set_model_parameters(model_parameters | {name: change})

# 3. Set up UI
solara.Markdown(name)
UserInputs(user_params)
UserInputs(user_params, on_change=handle_change_model_params)
ModelController(model, play_interval, current_step, set_current_step)

with solara.GridFixed(columns=2):
Expand Down Expand Up @@ -165,19 +162,20 @@ def check_param_is_fixed(param):


@solara.component
def UserInputs(user_params):
def UserInputs(user_params, on_change=None):
"""Initialize user inputs for configurable model parameters.
Currently supports :class:`solara.SliderInt`, :class:`solara.SliderFloat`,
and :class:`solara.Select`.
Props:
user_params: dictionary with options for the input, including label,
min and max values, and other fields specific to the input type.
on_change: function to be called when the value of an input changes.
"""

def handle_change_value(name):
def change_value(change):
model_parameters.value = model_parameters.value | {name: change}
on_change(name, change)

return change_value

Expand All @@ -188,7 +186,7 @@ def change_value(change):
if input_type == "SliderInt":
solara.SliderInt(
label,
value=model_parameters.value[name],
value=options.get("value"),
on_value=handle_change_value(name),
min=options.get("min"),
max=options.get("max"),
Expand All @@ -197,7 +195,7 @@ def change_value(change):
elif input_type == "SliderFloat":
solara.SliderFloat(
label,
value=model_parameters.value[name],
value=options.get("value"),
on_value=handle_change_value(name),
min=options.get("min"),
max=options.get("max"),
Expand All @@ -206,7 +204,7 @@ def change_value(change):
elif input_type == "Select":
solara.Select(
label,
value=model_parameters.value[name],
value=options.get("value"),
values=options.get("values"),
)
else:
Expand Down

0 comments on commit 90e1243

Please sign in to comment.