Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement Slider class for JupyterViz #1972

Merged
merged 1 commit into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions mesa/experimental/UserParam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
class UserParam:
_ERROR_MESSAGE = "Missing or malformed inputs for '{}' Option '{}'"

def maybe_raise_error(self, param_type, valid):
if valid:
return
msg = self._ERROR_MESSAGE.format(param_type, self.label)
raise ValueError(msg)


class Slider(UserParam):
"""
A number-based slider input with settable increment.

Example:

slider_option = Slider("My Slider", value=123, min=10, max=200, step=0.1)

Args:
label: The displayed label in the UI
value: The initial value of the slider
min: The minimum possible value of the slider
max: The maximum possible value of the slider
step: The step between min and max for a range of possible values
dtype: either int or float
"""

def __init__(
self,
label="",
value=None,
min=None,
max=None,
step=1,
dtype=None,
):
self.label = label
self.value = value
self.min = min
self.max = max
self.step = step

# Validate option type to make sure values are supplied properly
valid = not (self.value is None or self.min is None or self.max is None)
self.maybe_raise_error("slider", valid)

if dtype is None:
self.is_float_slider = self._check_values_are_float(value, min, max, step)
else:
self.is_float_slider = dtype == float

def _check_values_are_float(self, value, min, max, step):
return any(isinstance(n, float) for n in (value, min, max, step))

def get(self, attr):
return getattr(self, attr)
2 changes: 1 addition & 1 deletion mesa/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .jupyter_viz import JupyterViz, make_text # noqa
from .jupyter_viz import JupyterViz, make_text, Slider # noqa
24 changes: 20 additions & 4 deletions mesa/experimental/jupyter_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from solara.alias import rv

import mesa.experimental.components.matplotlib as components_matplotlib
from mesa.experimental.UserParam import Slider

# Avoid interactive backend
plt.switch_backend("agg")
Expand Down Expand Up @@ -40,7 +41,7 @@ def JupyterViz(
# 1. Set up model parameters
user_params, fixed_params = split_model_params(model_params)
model_parameters, set_model_parameters = solara.use_state(
{**fixed_params, **{k: v["value"] for k, v in user_params.items()}}
{**fixed_params, **{k: v.get("value") for k, v in user_params.items()}}
)

# 2. Set up Model
Expand Down Expand Up @@ -251,6 +252,8 @@ def split_model_params(model_params):


def check_param_is_fixed(param):
if isinstance(param, Slider):
return False
if not isinstance(param, dict):
return True
if "type" not in param:
Expand All @@ -270,13 +273,26 @@ def UserInputs(user_params, on_change=None):
"""

for name, options in user_params.items():
# label for the input is "label" from options or name
label = options.get("label", name)
input_type = options.get("type")

def change_handler(value, name=name):
on_change(name, value)

if isinstance(options, Slider):
slider_class = (
solara.SliderFloat if options.is_float_slider else solara.SliderInt
)
slider_class(
options.label,
on_value=change_handler,
min=options.min,
max=options.max,
step=options.step,
)
continue

# label for the input is "label" from options or name
label = options.get("label", name)
input_type = options.get("type")
if input_type == "SliderInt":
solara.SliderInt(
label,
Expand Down
16 changes: 15 additions & 1 deletion tests/test_jupyter_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ipyvuetify as vw
import solara

from mesa.experimental.jupyter_viz import JupyterViz, UserInputs
from mesa.experimental.jupyter_viz import JupyterViz, Slider, UserInputs


class TestMakeUserInput(unittest.TestCase):
Expand Down Expand Up @@ -132,3 +132,17 @@ def test_call_space_drawer(self, mock_space_matplotlib):
altspace_drawer.assert_called_with(
mock_model_class.return_value, agent_portrayal
)


def test_slider():
slider_float = Slider("Agent density", 0.8, 0.1, 1.0, 0.1)
assert slider_float.is_float_slider
assert slider_float.value == 0.8
assert slider_float.get("value") == 0.8
assert slider_float.min == 0.1
assert slider_float.max == 1.0
assert slider_float.step == 0.1
slider_int = Slider("Homophily", 3, 0, 8, 1)
assert not slider_int.is_float_slider
slider_dtype_float = Slider("Homophily", 3, 0, 8, 1, dtype=float)
assert slider_dtype_float.is_float_slider