diff --git a/pymc3/distributions/simulator.py b/pymc3/distributions/simulator.py index a6e14ba17d3..f340c4e39f0 100644 --- a/pymc3/distributions/simulator.py +++ b/pymc3/distributions/simulator.py @@ -307,6 +307,13 @@ def check_simulator_args(cls, simulator, *args, **kwargs): f"`pm.SimulatorRV` but got {type(simulator)}" ) + n_params = len(args) + len(kwargs.get("params", [])) + ("epsilon" in kwargs) + if n_params != len(simulator.ndims_params): + raise ValueError( + f"`Simulator` expected {len(simulator.ndims_params)} parameters" + f"but got {n_params}. Did you forget to specify `epsilon`?" + ) + if "distance" in kwargs: raise ValueError( "distance is no longer defined when calling `pm.Simulator`. It" diff --git a/pymc3/tests/test_smc.py b/pymc3/tests/test_smc.py index a8fc19160fd..a76828d22cf 100644 --- a/pymc3/tests/test_smc.py +++ b/pymc3/tests/test_smc.py @@ -445,6 +445,11 @@ def test_simulator_error_msg(self): with pytest.raises(ValueError, match=msg): pm.Simulator("sim", lambda: None, 0, 1, epsilon=1.0) + msg = "Did you forget to specify `epsilon`?" + with pm.Model() as m: + with pytest.raises(ValueError, match=msg): + pm.Simulator("sim", NormalSimRV1(), 0, 1) + msg = "distance is no longer defined when calling `pm.Simulator`" with pm.Model() as m: with pytest.raises(ValueError, match=msg):