From 7e50d72a5d31f667a249816a5ed6aeb3c0e53d81 Mon Sep 17 00:00:00 2001 From: Will Date: Wed, 11 Sep 2024 14:33:06 -0400 Subject: [PATCH 1/8] first draft of JIT implementation --- pennylane/math/__init__.py | 2 + pennylane/math/multi_dispatch.py | 30 +++++++++++++ pennylane/ops/qubit/observables.py | 3 ++ pennylane/ops/qutrit/observables.py | 17 ++++--- .../state_preparations/basis_qutrit.py | 16 ++++--- .../test_qutrit_basis_state_prep.py | 44 ++++++++++--------- 6 files changed, 80 insertions(+), 32 deletions(-) diff --git a/pennylane/math/__init__.py b/pennylane/math/__init__.py index 731979daac9..55c7c0d69e0 100644 --- a/pennylane/math/__init__.py +++ b/pennylane/math/__init__.py @@ -64,6 +64,8 @@ tensordot, unwrap, where, + matrix_power, + eigh, ) from .quantum import ( cov_matrix, diff --git a/pennylane/math/multi_dispatch.py b/pennylane/math/multi_dispatch.py index 2c4684571f7..703ea678973 100644 --- a/pennylane/math/multi_dispatch.py +++ b/pennylane/math/multi_dispatch.py @@ -314,6 +314,36 @@ def matmul(tensor1, tensor2, like=None): tensor2 = cast_like(tensor2, tensor1) # pylint: disable=arguments-out-of-order return ar.numpy.matmul(tensor1, tensor2, like=like) +@multi_dispatch(argnum=[0, 1]) +def matrix_power(tensor1, tensor2, like=None): + """Raise a tensor to the power of a tensor.""" + if like == "jax": + import jax + + def matrix_power_while_inner(val, M): + k, cur_val = val + return k - 1, M @ cur_val + + def matrix_power_while(M, k): + cond_fun = lambda val: val[0] >= 0 + init_val = (k - 1, jax.numpy.eye(M.shape[0])) + body_fun = lambda val: matrix_power_while_inner(val, M) + + result = jax.lax.while_loop(cond_fun, body_fun, init_val) + return result[1] + return matrix_power_while(tensor1, tensor2) + + return np.linalg.matrix_power(tensor1, tensor2) + +@multi_dispatch(argnum=[0]) +def eigh(tensor, like=None): + """Retruns the eigenvalues of a Hermitian matrix.""" + if like == "jax": + import jax + return jax.numpy.linalg.eigh(tensor) + + return np.linalg.eigh(tensor) + @multi_dispatch(argnum=[0, 1]) def dot(tensor1, tensor2, like=None): diff --git a/pennylane/ops/qubit/observables.py b/pennylane/ops/qubit/observables.py index 8f992c81bc2..77c0b5ae98e 100644 --- a/pennylane/ops/qubit/observables.py +++ b/pennylane/ops/qubit/observables.py @@ -91,6 +91,9 @@ def __init__(self, A: TensorLike, wires: WiresLike, id: Optional[str] = None): @staticmethod def _validate_input(A: TensorLike, expected_mx_shape: Optional[int] = None): """Validate the input matrix.""" + if qml.math.is_abstract(A): + return + if len(A.shape) != 2 or A.shape[0] != A.shape[1]: raise ValueError("Observable must be a square matrix.") diff --git a/pennylane/ops/qutrit/observables.py b/pennylane/ops/qutrit/observables.py index e8ed7cf02b2..2e781e95e4b 100644 --- a/pennylane/ops/qutrit/observables.py +++ b/pennylane/ops/qutrit/observables.py @@ -110,13 +110,18 @@ def eigendecomposition(self): Hermitian observable """ Hmat = self.matrix() - Hmat = qml.math.to_numpy(Hmat) - Hkey = tuple(Hmat.flatten().tolist()) - if Hkey not in THermitian._eigs: - w, U = np.linalg.eigh(Hmat) - THermitian._eigs[Hkey] = {"eigvec": U, "eigval": w} - return THermitian._eigs[Hkey] + if not qml.math.is_abstract(Hmat): + Hmat = qml.math.to_numpy(Hmat) + Hkey = tuple(Hmat.flatten().tolist()) + if Hkey not in THermitian._eigs: + w, U = qml.math.eigh(Hmat) + THermitian._eigs[Hkey] = {"eigvec": U, "eigval": w} + + return THermitian._eigs[Hkey] + + w, U = qml.math.eigh(Hmat) + return {"eigvec": U, "eigval": w} @staticmethod def compute_diagonalizing_gates(eigenvectors, wires): # pylint: disable=arguments-differ diff --git a/pennylane/templates/state_preparations/basis_qutrit.py b/pennylane/templates/state_preparations/basis_qutrit.py index 8568d0141d4..5d79a9db02f 100644 --- a/pennylane/templates/state_preparations/basis_qutrit.py +++ b/pennylane/templates/state_preparations/basis_qutrit.py @@ -77,10 +77,11 @@ def __init__(self, basis_state, wires, id=None): f"Basis states must be of length {len(wires)}; state {i} has length {n_bits}." ) - if any(bit not in [0, 1, 2] for bit in state): - raise ValueError( - f"Basis states must only consist of 0s, 1s, and 2s; state {i} is {state}" - ) + if not qml.math.is_abstract(basis_state): + if any(bit not in [0, 1, 2] for bit in state): + raise ValueError( + f"Basis states must only consist of 0s, 1s, and 2s; state {i} is {state}" + ) # TODO: basis_state should be a hyperparameter, not a trainable parameter. # However, this breaks a test that ensures compatibility with batch_transform. @@ -112,7 +113,10 @@ def compute_decomposition(basis_state, wires): # pylint: disable=arguments-diff """ op_list = [] + tshift = qml.math.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) for wire, state in zip(wires, basis_state): - for _ in range(0, state): - op_list.append(qml.TShift(wire)) + mat = qml.math.matrix_power(tshift, state) + op = qml.ops.QutritUnitary(mat, wires=wire) + op_list.append(op) + return op_list diff --git a/tests/templates/test_state_preparations/test_qutrit_basis_state_prep.py b/tests/templates/test_state_preparations/test_qutrit_basis_state_prep.py index 98755079512..3c8c15ccf22 100644 --- a/tests/templates/test_state_preparations/test_qutrit_basis_state_prep.py +++ b/tests/templates/test_state_preparations/test_qutrit_basis_state_prep.py @@ -36,33 +36,37 @@ class TestDecomposition: """Tests that the template defines the correct decomposition.""" # fmt: off - @pytest.mark.parametrize("basis_state,wires,target_wires", [ - ([0], [0], []), - ([0], [1], []), - ([1], [0], [0]), - ([2], [1], [1, 1]), - ([0, 1], [0, 1], [1]), - ([2, 0], [1, 4], [1, 1]), - ([1, 0], [4, 5], [4]), - ([0, 2], [4, 5], [5, 5]), - ([1, 2], [0, 2], [0, 2, 2]), - ([0, 0, 1, 0], [1, 2, 3, 4], [3]), - ([2, 0, 0, 0], [1, 2, 3, 4], [1, 1]), - ([1, 1, 1, 0], [1, 2, 6, 8], [1, 2, 6]), - ([0, 2, 1, 2], [1, 2, 6, 8], [2, 2, 6, 8, 8]), - ([1, 0, 1, 1], [1, 2, 6, 8], [1, 6, 8]), - ([2, 1, 0, 2], [1, 2, 6, 8], [1, 1, 2, 8, 8]), + tshift0 = np.eye(3, dtype=int) + tshift1 = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) + tshift2 = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) + @pytest.mark.parametrize("basis_state,wires,target_unitary", [ + ([0], [0], [tshift0]), + ([0], [1], [tshift0]), + ([1], [0], [tshift1]), + ([2], [1], [tshift2]), + ([0, 1], [0, 1], [tshift0, tshift1]), + ([2, 0], [1, 4], [tshift2, tshift0]), + ([1, 0], [4, 5], [tshift1, tshift0]), + ([0, 2], [4, 5], [tshift0, tshift2]), + ([1, 2], [0, 2], [tshift1, tshift2]), + ([0, 0, 1, 0], [1, 2, 3, 4], [tshift0, tshift0, tshift1, tshift0]), + ([2, 0, 0, 0], [1, 2, 3, 4], [tshift2, tshift0, tshift0, tshift0]), + ([1, 1, 1, 0], [1, 2, 6, 8], [tshift1, tshift1, tshift1, tshift0]), + ([0, 2, 1, 2], [1, 2, 6, 8], [tshift0, tshift2, tshift1, tshift2]), + ([1, 0, 1, 1], [1, 2, 6, 8], [tshift1, tshift0, tshift1, tshift1]), + ([2, 1, 0, 2], [1, 2, 6, 8], [tshift2, tshift1, tshift0, tshift2]), ]) # fmt: on - def test_correct_pl_gates(self, basis_state, wires, target_wires): + def test_correct_pl_gates(self, basis_state, wires, target_unitary): """Tests queue for simple cases.""" op = qml.QutritBasisStatePreparation(basis_state, wires) queue = op.decomposition() for id, gate in enumerate(queue): - assert gate.name == "TShift" - assert gate.wires.tolist() == [target_wires[id]] + assert gate.name == "QutritUnitary" + assert gate.wires.tolist() == [wires[id]] + assert np.array_equal(gate.matrix(), target_unitary[id]) # fmt: off @pytest.mark.parametrize("basis_state,wires,target_state", [ @@ -110,7 +114,7 @@ def circuit(obs): ([1, 0, 1], [2, 0, 1], [0, 1, 1]), ], ) - @pytest.mark.xfail(reason="JIT comptability not yet implemented") + def test_state_preparation_jax_jit( self, tol, qutrit_device_3_wires, basis_state, wires, target_state ): From 2711b8b34e5a05a337dede6e255110b7a6a3ede4 Mon Sep 17 00:00:00 2001 From: Will Date: Thu, 12 Sep 2024 11:36:18 -0400 Subject: [PATCH 2/8] fixing fable template and adding test --- pennylane/templates/subroutines/fable.py | 6 ++++- .../templates/test_subroutines/test_fable.py | 23 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/pennylane/templates/subroutines/fable.py b/pennylane/templates/subroutines/fable.py index 16f1160ddd0..31f59105146 100644 --- a/pennylane/templates/subroutines/fable.py +++ b/pennylane/templates/subroutines/fable.py @@ -166,7 +166,11 @@ def compute_decomposition(input_matrix, wires, tol=0): # pylint:disable=argumen for c_wire in nots: op_list.append(qml.CNOT(wires=[c_wire] + ancilla)) op_list.append(qml.RY(2 * theta, wires=ancilla)) - nots[wire_map[control_index]] = 1 + nots = {} + if wire_map[control_index] in nots: + del nots[wire_map[control_index]] + else: + nots[wire_map[control_index]] = 1 else: if abs(2 * theta) > tol: for c_wire in nots: diff --git a/tests/templates/test_subroutines/test_fable.py b/tests/templates/test_subroutines/test_fable.py index 8649fe71748..3b8600fd400 100644 --- a/tests/templates/test_subroutines/test_fable.py +++ b/tests/templates/test_subroutines/test_fable.py @@ -285,6 +285,29 @@ def circuit_jax(input_matrix): gradient_jax = grad_fn(input_matrix_jax) assert np.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001) + @pytest.mark.jax + def test_jit_result(self): + """Test that the value returned in JIT mode equals the value returned without JIT.""" + import jax + + def fable(input_matrix): + qml.FABLE(input_matrix, wires=range(5), tol=0) + return qml.expval(qml.PauliZ(wires=0)) + + input_matrix = np.array( + [ + [-0.5, -0.4, 0.6, 0.7], + [0.9, 0.9, 0.8, 0.9], + [0.8, 0.7, 0.9, 0.8], + [0.9, 0.7, 0.8, 0.3], + ] + ) + + device = qml.device("default.qubit", wires=5) + interpreted_fn = qml.QNode(fable, device) + jitted_fn = jax.jit(interpreted_fn) + assert np.allclose(jitted_fn(input_matrix), interpreted_fn(input_matrix)) + @pytest.mark.jax def test_fable_grad_jax_jit_error(self, input_matrix): """Test that FABLE is differentiable when using jax.""" From f8cb805e0ca51bdc19653eb97d11494e51676c44 Mon Sep 17 00:00:00 2001 From: Will Date: Thu, 12 Sep 2024 11:44:34 -0400 Subject: [PATCH 3/8] fixing PR --- pennylane/math/__init__.py | 2 - pennylane/math/multi_dispatch.py | 30 ------------- pennylane/ops/qubit/observables.py | 3 -- pennylane/ops/qutrit/observables.py | 17 +++---- .../state_preparations/basis_qutrit.py | 16 +++---- .../test_qutrit_basis_state_prep.py | 44 +++++++++---------- 6 files changed, 32 insertions(+), 80 deletions(-) diff --git a/pennylane/math/__init__.py b/pennylane/math/__init__.py index 55c7c0d69e0..731979daac9 100644 --- a/pennylane/math/__init__.py +++ b/pennylane/math/__init__.py @@ -64,8 +64,6 @@ tensordot, unwrap, where, - matrix_power, - eigh, ) from .quantum import ( cov_matrix, diff --git a/pennylane/math/multi_dispatch.py b/pennylane/math/multi_dispatch.py index 703ea678973..2c4684571f7 100644 --- a/pennylane/math/multi_dispatch.py +++ b/pennylane/math/multi_dispatch.py @@ -314,36 +314,6 @@ def matmul(tensor1, tensor2, like=None): tensor2 = cast_like(tensor2, tensor1) # pylint: disable=arguments-out-of-order return ar.numpy.matmul(tensor1, tensor2, like=like) -@multi_dispatch(argnum=[0, 1]) -def matrix_power(tensor1, tensor2, like=None): - """Raise a tensor to the power of a tensor.""" - if like == "jax": - import jax - - def matrix_power_while_inner(val, M): - k, cur_val = val - return k - 1, M @ cur_val - - def matrix_power_while(M, k): - cond_fun = lambda val: val[0] >= 0 - init_val = (k - 1, jax.numpy.eye(M.shape[0])) - body_fun = lambda val: matrix_power_while_inner(val, M) - - result = jax.lax.while_loop(cond_fun, body_fun, init_val) - return result[1] - return matrix_power_while(tensor1, tensor2) - - return np.linalg.matrix_power(tensor1, tensor2) - -@multi_dispatch(argnum=[0]) -def eigh(tensor, like=None): - """Retruns the eigenvalues of a Hermitian matrix.""" - if like == "jax": - import jax - return jax.numpy.linalg.eigh(tensor) - - return np.linalg.eigh(tensor) - @multi_dispatch(argnum=[0, 1]) def dot(tensor1, tensor2, like=None): diff --git a/pennylane/ops/qubit/observables.py b/pennylane/ops/qubit/observables.py index 77c0b5ae98e..8f992c81bc2 100644 --- a/pennylane/ops/qubit/observables.py +++ b/pennylane/ops/qubit/observables.py @@ -91,9 +91,6 @@ def __init__(self, A: TensorLike, wires: WiresLike, id: Optional[str] = None): @staticmethod def _validate_input(A: TensorLike, expected_mx_shape: Optional[int] = None): """Validate the input matrix.""" - if qml.math.is_abstract(A): - return - if len(A.shape) != 2 or A.shape[0] != A.shape[1]: raise ValueError("Observable must be a square matrix.") diff --git a/pennylane/ops/qutrit/observables.py b/pennylane/ops/qutrit/observables.py index 2e781e95e4b..e8ed7cf02b2 100644 --- a/pennylane/ops/qutrit/observables.py +++ b/pennylane/ops/qutrit/observables.py @@ -110,18 +110,13 @@ def eigendecomposition(self): Hermitian observable """ Hmat = self.matrix() + Hmat = qml.math.to_numpy(Hmat) + Hkey = tuple(Hmat.flatten().tolist()) + if Hkey not in THermitian._eigs: + w, U = np.linalg.eigh(Hmat) + THermitian._eigs[Hkey] = {"eigvec": U, "eigval": w} - if not qml.math.is_abstract(Hmat): - Hmat = qml.math.to_numpy(Hmat) - Hkey = tuple(Hmat.flatten().tolist()) - if Hkey not in THermitian._eigs: - w, U = qml.math.eigh(Hmat) - THermitian._eigs[Hkey] = {"eigvec": U, "eigval": w} - - return THermitian._eigs[Hkey] - - w, U = qml.math.eigh(Hmat) - return {"eigvec": U, "eigval": w} + return THermitian._eigs[Hkey] @staticmethod def compute_diagonalizing_gates(eigenvectors, wires): # pylint: disable=arguments-differ diff --git a/pennylane/templates/state_preparations/basis_qutrit.py b/pennylane/templates/state_preparations/basis_qutrit.py index 5d79a9db02f..8568d0141d4 100644 --- a/pennylane/templates/state_preparations/basis_qutrit.py +++ b/pennylane/templates/state_preparations/basis_qutrit.py @@ -77,11 +77,10 @@ def __init__(self, basis_state, wires, id=None): f"Basis states must be of length {len(wires)}; state {i} has length {n_bits}." ) - if not qml.math.is_abstract(basis_state): - if any(bit not in [0, 1, 2] for bit in state): - raise ValueError( - f"Basis states must only consist of 0s, 1s, and 2s; state {i} is {state}" - ) + if any(bit not in [0, 1, 2] for bit in state): + raise ValueError( + f"Basis states must only consist of 0s, 1s, and 2s; state {i} is {state}" + ) # TODO: basis_state should be a hyperparameter, not a trainable parameter. # However, this breaks a test that ensures compatibility with batch_transform. @@ -113,10 +112,7 @@ def compute_decomposition(basis_state, wires): # pylint: disable=arguments-diff """ op_list = [] - tshift = qml.math.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) for wire, state in zip(wires, basis_state): - mat = qml.math.matrix_power(tshift, state) - op = qml.ops.QutritUnitary(mat, wires=wire) - op_list.append(op) - + for _ in range(0, state): + op_list.append(qml.TShift(wire)) return op_list diff --git a/tests/templates/test_state_preparations/test_qutrit_basis_state_prep.py b/tests/templates/test_state_preparations/test_qutrit_basis_state_prep.py index 3c8c15ccf22..98755079512 100644 --- a/tests/templates/test_state_preparations/test_qutrit_basis_state_prep.py +++ b/tests/templates/test_state_preparations/test_qutrit_basis_state_prep.py @@ -36,37 +36,33 @@ class TestDecomposition: """Tests that the template defines the correct decomposition.""" # fmt: off - tshift0 = np.eye(3, dtype=int) - tshift1 = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) - tshift2 = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) - @pytest.mark.parametrize("basis_state,wires,target_unitary", [ - ([0], [0], [tshift0]), - ([0], [1], [tshift0]), - ([1], [0], [tshift1]), - ([2], [1], [tshift2]), - ([0, 1], [0, 1], [tshift0, tshift1]), - ([2, 0], [1, 4], [tshift2, tshift0]), - ([1, 0], [4, 5], [tshift1, tshift0]), - ([0, 2], [4, 5], [tshift0, tshift2]), - ([1, 2], [0, 2], [tshift1, tshift2]), - ([0, 0, 1, 0], [1, 2, 3, 4], [tshift0, tshift0, tshift1, tshift0]), - ([2, 0, 0, 0], [1, 2, 3, 4], [tshift2, tshift0, tshift0, tshift0]), - ([1, 1, 1, 0], [1, 2, 6, 8], [tshift1, tshift1, tshift1, tshift0]), - ([0, 2, 1, 2], [1, 2, 6, 8], [tshift0, tshift2, tshift1, tshift2]), - ([1, 0, 1, 1], [1, 2, 6, 8], [tshift1, tshift0, tshift1, tshift1]), - ([2, 1, 0, 2], [1, 2, 6, 8], [tshift2, tshift1, tshift0, tshift2]), + @pytest.mark.parametrize("basis_state,wires,target_wires", [ + ([0], [0], []), + ([0], [1], []), + ([1], [0], [0]), + ([2], [1], [1, 1]), + ([0, 1], [0, 1], [1]), + ([2, 0], [1, 4], [1, 1]), + ([1, 0], [4, 5], [4]), + ([0, 2], [4, 5], [5, 5]), + ([1, 2], [0, 2], [0, 2, 2]), + ([0, 0, 1, 0], [1, 2, 3, 4], [3]), + ([2, 0, 0, 0], [1, 2, 3, 4], [1, 1]), + ([1, 1, 1, 0], [1, 2, 6, 8], [1, 2, 6]), + ([0, 2, 1, 2], [1, 2, 6, 8], [2, 2, 6, 8, 8]), + ([1, 0, 1, 1], [1, 2, 6, 8], [1, 6, 8]), + ([2, 1, 0, 2], [1, 2, 6, 8], [1, 1, 2, 8, 8]), ]) # fmt: on - def test_correct_pl_gates(self, basis_state, wires, target_unitary): + def test_correct_pl_gates(self, basis_state, wires, target_wires): """Tests queue for simple cases.""" op = qml.QutritBasisStatePreparation(basis_state, wires) queue = op.decomposition() for id, gate in enumerate(queue): - assert gate.name == "QutritUnitary" - assert gate.wires.tolist() == [wires[id]] - assert np.array_equal(gate.matrix(), target_unitary[id]) + assert gate.name == "TShift" + assert gate.wires.tolist() == [target_wires[id]] # fmt: off @pytest.mark.parametrize("basis_state,wires,target_state", [ @@ -114,7 +110,7 @@ def circuit(obs): ([1, 0, 1], [2, 0, 1], [0, 1, 1]), ], ) - + @pytest.mark.xfail(reason="JIT comptability not yet implemented") def test_state_preparation_jax_jit( self, tol, qutrit_device_3_wires, basis_state, wires, target_state ): From db85b5d12be30d372ee694e295828f8c0e7a0488 Mon Sep 17 00:00:00 2001 From: Will Date: Thu, 12 Sep 2024 11:47:25 -0400 Subject: [PATCH 4/8] update changelog --- doc/releases/changelog-dev.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 7ac1ce219bb..9d9c063f020 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -98,6 +98,9 @@ * The ``qml.Qubitization`` template now orders the ``control`` wires first and the ``hamiltonian`` wires second, which is the expected according to other templates. [(#6229)](https://github.com/PennyLaneAI/pennylane/pull/6229) +* The ``qml.FABLE`` template now returns the correct value when JIT is enabled. + [(#6263)](https://github.com/PennyLaneAI/pennylane/pull/6263) + *

Contributors ✍️

This release contains contributions from (in alphabetical order): From 254ed814d884c073d1ab9218780d670da11a08a5 Mon Sep 17 00:00:00 2001 From: Will Date: Thu, 12 Sep 2024 12:03:04 -0400 Subject: [PATCH 5/8] removing unnecessary branch --- pennylane/templates/subroutines/fable.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/pennylane/templates/subroutines/fable.py b/pennylane/templates/subroutines/fable.py index 31f59105146..290201200bc 100644 --- a/pennylane/templates/subroutines/fable.py +++ b/pennylane/templates/subroutines/fable.py @@ -167,20 +167,18 @@ def compute_decomposition(input_matrix, wires, tol=0): # pylint:disable=argumen op_list.append(qml.CNOT(wires=[c_wire] + ancilla)) op_list.append(qml.RY(2 * theta, wires=ancilla)) nots = {} - if wire_map[control_index] in nots: - del nots[wire_map[control_index]] - else: - nots[wire_map[control_index]] = 1 + nots[wire_map[control_index]] = 1 + continue + + if abs(2 * theta) > tol: + for c_wire in nots: + op_list.append(qml.CNOT(wires=[c_wire] + ancilla)) + op_list.append(qml.RY(2 * theta, wires=ancilla)) + nots = {} + if wire_map[control_index] in nots: + del nots[wire_map[control_index]] else: - if abs(2 * theta) > tol: - for c_wire in nots: - op_list.append(qml.CNOT(wires=[c_wire] + ancilla)) - op_list.append(qml.RY(2 * theta, wires=ancilla)) - nots = {} - if wire_map[control_index] in nots: - del nots[wire_map[control_index]] - else: - nots[wire_map[control_index]] = 1 + nots[wire_map[control_index]] = 1 for c_wire in nots: op_list.append(qml.CNOT([c_wire] + ancilla)) From 9c8f536242bb837912648e12971971e422fca93d Mon Sep 17 00:00:00 2001 From: Will Date: Thu, 12 Sep 2024 13:58:49 -0400 Subject: [PATCH 6/8] fixing gradient test --- tests/templates/test_subroutines/test_fable.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/templates/test_subroutines/test_fable.py b/tests/templates/test_subroutines/test_fable.py index 3b8600fd400..22c1db71d2b 100644 --- a/tests/templates/test_subroutines/test_fable.py +++ b/tests/templates/test_subroutines/test_fable.py @@ -272,13 +272,14 @@ def test_fable_grad_jax_jit(self, input_matrix): input_jax_negative_delta = jnp.array(input_negative_delta) input_matrix_jax = jnp.array(input_matrix) - @jax.jit @qml.qnode(dev, diff_method="backprop") def circuit_jax(input_matrix): qml.FABLE(input_matrix, wires=range(5), tol=0) return qml.expval(qml.PauliZ(wires=0)) - grad_fn = jax.grad(circuit_jax) + jitted_fn = jax.jit(circuit_jax) + + grad_fn = jax.grad(jitted_fn) gradient_numeric = ( circuit_jax(input_jax_positive_delta) - circuit_jax(input_jax_negative_delta) ) / (2 * delta) @@ -286,7 +287,7 @@ def circuit_jax(input_matrix): assert np.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001) @pytest.mark.jax - def test_jit_result(self): + def test_fable_jax_jit(self): """Test that the value returned in JIT mode equals the value returned without JIT.""" import jax From 9fe1828072ec600c69456ab6eae0d15fe6597467 Mon Sep 17 00:00:00 2001 From: Will Date: Thu, 12 Sep 2024 16:12:28 -0400 Subject: [PATCH 7/8] minor changes --- pennylane/templates/subroutines/fable.py | 2 +- .../templates/test_subroutines/test_fable.py | 27 +++---------------- 2 files changed, 4 insertions(+), 25 deletions(-) diff --git a/pennylane/templates/subroutines/fable.py b/pennylane/templates/subroutines/fable.py index 290201200bc..d9637676738 100644 --- a/pennylane/templates/subroutines/fable.py +++ b/pennylane/templates/subroutines/fable.py @@ -170,7 +170,7 @@ def compute_decomposition(input_matrix, wires, tol=0): # pylint:disable=argumen nots[wire_map[control_index]] = 1 continue - if abs(2 * theta) > tol: + if qml.math.abs(2 * theta) > tol: for c_wire in nots: op_list.append(qml.CNOT(wires=[c_wire] + ancilla)) op_list.append(qml.RY(2 * theta, wires=ancilla)) diff --git a/tests/templates/test_subroutines/test_fable.py b/tests/templates/test_subroutines/test_fable.py index 22c1db71d2b..ceebfc8735c 100644 --- a/tests/templates/test_subroutines/test_fable.py +++ b/tests/templates/test_subroutines/test_fable.py @@ -235,7 +235,7 @@ def circuit_jax(input_matrix): assert np.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001) @pytest.mark.jax - def test_fable_grad_jax_jit(self, input_matrix): + def test_fable_jax_jit(self, input_matrix): """Test that FABLE is differentiable when using jax.""" import jax import jax.numpy as jnp @@ -284,30 +284,9 @@ def circuit_jax(input_matrix): circuit_jax(input_jax_positive_delta) - circuit_jax(input_jax_negative_delta) ) / (2 * delta) gradient_jax = grad_fn(input_matrix_jax) - assert np.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001) - @pytest.mark.jax - def test_fable_jax_jit(self): - """Test that the value returned in JIT mode equals the value returned without JIT.""" - import jax - - def fable(input_matrix): - qml.FABLE(input_matrix, wires=range(5), tol=0) - return qml.expval(qml.PauliZ(wires=0)) - - input_matrix = np.array( - [ - [-0.5, -0.4, 0.6, 0.7], - [0.9, 0.9, 0.8, 0.9], - [0.8, 0.7, 0.9, 0.8], - [0.9, 0.7, 0.8, 0.3], - ] - ) - - device = qml.device("default.qubit", wires=5) - interpreted_fn = qml.QNode(fable, device) - jitted_fn = jax.jit(interpreted_fn) - assert np.allclose(jitted_fn(input_matrix), interpreted_fn(input_matrix)) + assert np.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001) + assert np.allclose(jitted_fn(input_matrix), circuit_jax(input_matrix)) @pytest.mark.jax def test_fable_grad_jax_jit_error(self, input_matrix): From 403c7d4fa8fd6b54035f1fa5ef42a495ca1e69ae Mon Sep 17 00:00:00 2001 From: Will Date: Fri, 13 Sep 2024 11:27:55 -0400 Subject: [PATCH 8/8] using qml allclose --- tests/templates/test_subroutines/test_fable.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/templates/test_subroutines/test_fable.py b/tests/templates/test_subroutines/test_fable.py index ceebfc8735c..d2ba5f2496a 100644 --- a/tests/templates/test_subroutines/test_fable.py +++ b/tests/templates/test_subroutines/test_fable.py @@ -285,8 +285,8 @@ def circuit_jax(input_matrix): ) / (2 * delta) gradient_jax = grad_fn(input_matrix_jax) - assert np.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001) - assert np.allclose(jitted_fn(input_matrix), circuit_jax(input_matrix)) + assert qml.math.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001) + assert qml.math.allclose(jitted_fn(input_matrix), circuit_jax(input_matrix)) @pytest.mark.jax def test_fable_grad_jax_jit_error(self, input_matrix):