Skip to content

Commit

Permalink
feat: Implement Slider class for JupyterViz
Browse files Browse the repository at this point in the history
This provides API parity with the previous Tornado viz.
  • Loading branch information
rht committed Jan 19, 2024
1 parent f9798eb commit a1ddac1
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 6 deletions.
48 changes: 48 additions & 0 deletions mesa/experimental/UserParam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
class UserParam:
_ERROR_MESSAGE = "Missing or malformed inputs for '{}' Option '{}'"

def maybe_raise_error(self, param_type, valid):
if valid:
return
msg = self._ERROR_MESSAGE.format(param_type, self.label)
raise ValueError(msg)


class Slider(UserParam):
"""
A number-based slider input with settable increment.
Example:
slider_option = Slider("My Slider", value=123, min=10, max=200, step=0.1)
"""

def __init__(
self,
label="",
value=None,
min=None,
max=None,
step=1,
dtype=None,
):
self.label = label
self.value = value
self.min = min
self.max = max
self.step = step

# Validate option type to make sure values are supplied properly
valid = not (self.value is None or self.min is None or self.max is None)
self.maybe_raise_error("slider", valid)

if dtype is None:
self.is_float_slider = self._check_values_are_float(value, min, max, step)
else:
self.is_float_slider = dtype == float

def _check_values_are_float(self, value, min, max, step):
return any(isinstance(n, float) for n in (value, min, max, step))

def get(self, attr):
return getattr(self, attr)
2 changes: 1 addition & 1 deletion mesa/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .jupyter_viz import JupyterViz, make_text # noqa
from .jupyter_viz import JupyterViz, make_text, Slider # noqa
24 changes: 20 additions & 4 deletions mesa/experimental/jupyter_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from solara.alias import rv

import mesa.experimental.components.matplotlib as components_matplotlib
from mesa.experimental.UserParam import Slider

# Avoid interactive backend
plt.switch_backend("agg")
Expand Down Expand Up @@ -40,7 +41,7 @@ def JupyterViz(
# 1. Set up model parameters
user_params, fixed_params = split_model_params(model_params)
model_parameters, set_model_parameters = solara.use_state(
{**fixed_params, **{k: v["value"] for k, v in user_params.items()}}
{**fixed_params, **{k: v.get("value") for k, v in user_params.items()}}
)

# 2. Set up Model
Expand Down Expand Up @@ -251,6 +252,8 @@ def split_model_params(model_params):


def check_param_is_fixed(param):
if isinstance(param, Slider):
return False
if not isinstance(param, dict):
return True
if "type" not in param:
Expand All @@ -270,13 +273,26 @@ def UserInputs(user_params, on_change=None):
"""

for name, options in user_params.items():
# label for the input is "label" from options or name
label = options.get("label", name)
input_type = options.get("type")

def change_handler(value, name=name):
on_change(name, value)

if isinstance(options, Slider):
slider_class = (
solara.SliderFloat if options.is_float_slider else solara.SliderInt
)
slider_class(
options.label,
on_value=change_handler,
min=options.min,
max=options.max,
step=options.step,
)
continue

# label for the input is "label" from options or name
label = options.get("label", name)
input_type = options.get("type")
if input_type == "SliderInt":
solara.SliderInt(
label,
Expand Down
16 changes: 15 additions & 1 deletion tests/test_jupyter_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ipyvuetify as vw
import solara

from mesa.experimental.jupyter_viz import JupyterViz, UserInputs
from mesa.experimental.jupyter_viz import JupyterViz, Slider, UserInputs


class TestMakeUserInput(unittest.TestCase):
Expand Down Expand Up @@ -132,3 +132,17 @@ def test_call_space_drawer(self, mock_space_matplotlib):
altspace_drawer.assert_called_with(
mock_model_class.return_value, agent_portrayal
)


def test_slider():
slider_float = Slider("Agent density", 0.8, 0.1, 1.0, 0.1)
assert slider_float.is_float_slider
assert slider_float.value == 0.8
assert slider_float.get("value") == 0.8
assert slider_float.min == 0.1
assert slider_float.max == 1.0
assert slider_float.step == 0.1
slider_int = Slider("Homophily", 3, 0, 8, 1)
assert not slider_int.is_float_slider
slider_dtype_float = Slider("Homophily", 3, 0, 8, 1, dtype=float)
assert slider_dtype_float.is_float_slider

0 comments on commit a1ddac1

Please sign in to comment.