Skip to content

Commit

Permalink
Add epsilon argument check
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 26, 2021
1 parent 62bd386 commit 572cab6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
7 changes: 7 additions & 0 deletions pymc3/distributions/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,13 @@ def check_simulator_args(cls, simulator, *args, **kwargs):
f"`pm.SimulatorRV` but got {type(simulator)}"
)

n_params = len(args) + len(kwargs.get("params", [])) + ("epsilon" in kwargs)
if n_params != len(simulator.ndims_params):
raise ValueError(
f"`Simulator` expected {len(simulator.ndims_params)} parameters"
f"but got {n_params}. Did you forget to specify `epsilon`?"
)

if "distance" in kwargs:
raise ValueError(
"distance is no longer defined when calling `pm.Simulator`. It"
Expand Down
5 changes: 5 additions & 0 deletions pymc3/tests/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,11 @@ def test_simulator_error_msg(self):
with pytest.raises(ValueError, match=msg):
pm.Simulator("sim", lambda: None, 0, 1, epsilon=1.0)

msg = "Did you forget to specify `epsilon`?"
with pm.Model() as m:
with pytest.raises(ValueError, match=msg):
pm.Simulator("sim", NormalSimRV1(), 0, 1)

msg = "distance is no longer defined when calling `pm.Simulator`"
with pm.Model() as m:
with pytest.raises(ValueError, match=msg):
Expand Down

0 comments on commit 572cab6

Please sign in to comment.