Skip to content

Commit

Permalink
local salt marker
Browse files Browse the repository at this point in the history
  • Loading branch information
astralcai committed Oct 24, 2024
1 parent 21b4629 commit 4fd9637
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 7 deletions.
37 changes: 31 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,6 @@ class DummyDevice(DefaultGaussian):
_operation_map["Kerr"] = lambda *x, **y: np.identity(2)


@pytest.fixture(autouse=True)
def set_numpy_seed():
np.random.seed(9872653)
yield


@pytest.fixture(scope="session")
def tol():
"""Numerical tolerance for equality tests."""
Expand Down Expand Up @@ -184,6 +178,37 @@ def legacy_opmath_only():
pytest.skip("This test exclusively tests legacy opmath")


#######################################################################


@pytest.fixture
def seed(request):
"""An integer random number generator seed
This fixture overrides the ``seed`` fixture provided by pytest-rng, adding the flexibility
of locally getting a new seed for a test case by applying the ``local_salt`` marker. This is
useful when the seed from pytest-rng happens to be a bad seed that causes your test to fail.
.. code_block:: python
@pytest.mark.local_salt(42)
def test_something(seed):
...
The value passed to ``local_salt`` needs to be an integer.
"""

fixture_manager = request._fixturemanager # pylint:disable=protected-access
fixture_defs = fixture_manager.getfixturedefs("seed", request.node)
original_fixture_def = fixture_defs[0] # the original seed fixture provided by pytest-rng
original_seed = original_fixture_def.func(request)
marker = request.node.get_closest_marker("local_salt")
if marker and marker.args:
return original_seed + marker.args[0]
return original_seed


#######################################################################

try:
Expand Down
2 changes: 1 addition & 1 deletion tests/devices/qubit/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,14 +1559,14 @@ def test_tree_traversal_combine_measurements(self, measurements, expected):
else:
assert qml.math.allclose(combined_measurement, expected)

@pytest.mark.local_salt(2)
@pytest.mark.parametrize("ml_framework", ml_frameworks_list)
@pytest.mark.parametrize(
"postselect_mode", [None, "hw-like", "pad-invalid-samples", "fill-shots"]
)
def test_simulate_one_shot_native_mcm(self, ml_framework, postselect_mode, seed):
"""Unit tests for simulate_one_shot_native_mcm"""

seed = seed + 2
with qml.queuing.AnnotatedQueue() as q:
qml.RX(np.pi / 4, wires=0)
m = qml.measure(wires=0, postselect=0)
Expand Down
1 change: 1 addition & 0 deletions tests/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ markers =
logging: marks tests for pennylane logging
external: marks tests that require external packages such as matplotlib and PyZX
catalyst: marks tests for catalyst testing (select with '-m "catalyst"')
local_salt(salt): adds a salt to the seed provided by the pytest-rng fixture
filterwarnings =
ignore::DeprecationWarning:autograd.numpy.numpy_wrapper
ignore:Casting complex values to real::autograd.numpy.numpy_wrapper
Expand Down

0 comments on commit 4fd9637

Please sign in to comment.