diff --git a/mesa/experimental/UserParam.py b/mesa/experimental/UserParam.py new file mode 100644 index 00000000000..3c239599796 --- /dev/null +++ b/mesa/experimental/UserParam.py @@ -0,0 +1,48 @@ +class UserParam: + _ERROR_MESSAGE = "Missing or malformed inputs for '{}' Option '{}'" + + def maybe_raise_error(self, param_type, valid): + if valid: + return + msg = self._ERROR_MESSAGE.format(param_type, self.label) + raise ValueError(msg) + + +class Slider(UserParam): + """ + A number-based slider input with settable increment. + + Example: + + slider_option = Slider("My Slider", value=123, min=10, max=200, step=0.1) + """ + + def __init__( + self, + label="", + value=None, + min=None, + max=None, + step=1, + dtype=None, + ): + self.label = label + self.value = value + self.min = min + self.max = max + self.step = step + + # Validate option type to make sure values are supplied properly + valid = not (self.value is None or self.min is None or self.max is None) + self.maybe_raise_error("slider", valid) + + if dtype is None: + self.is_float_slider = self._check_values_are_float(value, min, max, step) + else: + self.is_float_slider = dtype == float + + def _check_values_are_float(self, value, min, max, step): + return any(isinstance(n, float) for n in (value, min, max, step)) + + def get(self, attr): + return getattr(self, attr) diff --git a/mesa/experimental/__init__.py b/mesa/experimental/__init__.py index 964dc5d19a3..c04f22589b3 100644 --- a/mesa/experimental/__init__.py +++ b/mesa/experimental/__init__.py @@ -1 +1 @@ -from .jupyter_viz import JupyterViz, make_text # noqa +from .jupyter_viz import JupyterViz, make_text, Slider # noqa diff --git a/mesa/experimental/jupyter_viz.py b/mesa/experimental/jupyter_viz.py index 2e5f066409f..4104cb64641 100644 --- a/mesa/experimental/jupyter_viz.py +++ b/mesa/experimental/jupyter_viz.py @@ -7,6 +7,7 @@ from solara.alias import rv import mesa.experimental.components.matplotlib as components_matplotlib +from mesa.experimental.UserParam import Slider # Avoid interactive backend plt.switch_backend("agg") @@ -40,7 +41,7 @@ def JupyterViz( # 1. Set up model parameters user_params, fixed_params = split_model_params(model_params) model_parameters, set_model_parameters = solara.use_state( - {**fixed_params, **{k: v["value"] for k, v in user_params.items()}} + {**fixed_params, **{k: v.get("value") for k, v in user_params.items()}} ) # 2. Set up Model @@ -251,6 +252,8 @@ def split_model_params(model_params): def check_param_is_fixed(param): + if isinstance(param, Slider): + return False if not isinstance(param, dict): return True if "type" not in param: @@ -270,13 +273,26 @@ def UserInputs(user_params, on_change=None): """ for name, options in user_params.items(): - # label for the input is "label" from options or name - label = options.get("label", name) - input_type = options.get("type") def change_handler(value, name=name): on_change(name, value) + if isinstance(options, Slider): + slider_class = ( + solara.SliderFloat if options.is_float_slider else solara.SliderInt + ) + slider_class( + options.label, + on_value=change_handler, + min=options.min, + max=options.max, + step=options.step, + ) + continue + + # label for the input is "label" from options or name + label = options.get("label", name) + input_type = options.get("type") if input_type == "SliderInt": solara.SliderInt( label, diff --git a/tests/test_jupyter_viz.py b/tests/test_jupyter_viz.py index 6f88c7d14dc..0ac3e0e2413 100644 --- a/tests/test_jupyter_viz.py +++ b/tests/test_jupyter_viz.py @@ -4,7 +4,7 @@ import ipyvuetify as vw import solara -from mesa.experimental.jupyter_viz import JupyterViz, UserInputs +from mesa.experimental.jupyter_viz import JupyterViz, Slider, UserInputs class TestMakeUserInput(unittest.TestCase): @@ -132,3 +132,17 @@ def test_call_space_drawer(self, mock_space_matplotlib): altspace_drawer.assert_called_with( mock_model_class.return_value, agent_portrayal ) + + +def test_slider(): + slider_float = Slider("Agent density", 0.8, 0.1, 1.0, 0.1) + assert slider_float.is_float_slider + assert slider_float.value == 0.8 + assert slider_float.get("value") == 0.8 + assert slider_float.min == 0.1 + assert slider_float.max == 1.0 + assert slider_float.step == 0.1 + slider_int = Slider("Homophily", 3, 0, 8, 1) + assert not slider_int.is_float_slider + slider_dtype_float = Slider("Homophily", 3, 0, 8, 1, dtype=float) + assert slider_dtype_float.is_float_slider