diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 4b5d8799594..237b52b1942 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -335,6 +335,9 @@

Bug fixes 🐛

+* Fixes unnecessary call of `eigvals` in `qml.ops.op_math.decompositions.two_qubit_unitary.py` that was causing an error in VJP. Raises warnings to users if this essentially nondifferentiable module is used. + [(#6437)](https://github.com/PennyLaneAI/pennylane/pull/6437) + * Patches the `math` module to function with autoray 0.7.0. [(#6429)](https://github.com/PennyLaneAI/pennylane/pull/6429) @@ -417,6 +420,7 @@ Guillermo Alonso, Utkarsh Azad, Oleksandr Borysenko, Astral Cai, +Yushao Chen, Isaac De Vlugt, Diksha Dhawan, Lillian M. A. Frederiksen, diff --git a/pennylane/ops/op_math/decompositions/two_qubit_unitary.py b/pennylane/ops/op_math/decompositions/two_qubit_unitary.py index 46e43dd2168..de223d5f3ef 100644 --- a/pennylane/ops/op_math/decompositions/two_qubit_unitary.py +++ b/pennylane/ops/op_math/decompositions/two_qubit_unitary.py @@ -14,6 +14,8 @@ """Contains transforms and helpers functions for decomposing arbitrary two-qubit unitary operations into elementary gates. """ +import warnings + import numpy as np import pennylane as qml @@ -43,6 +45,22 @@ ################################################################################### +def _check_differentiability_warning(U): + """Check conditions that may lead to non-differentiability and raise appropriate warnings. + + Args: + U (tensor_like): Input unitary matrix to check. + """ + + if qml.math.requires_grad(U): + warnings.warn( + "The two-qubit decomposition may not be differentiable when the input " + "unitary depends on trainable parameters.", + RuntimeWarning, + stacklevel=2, + ) + + # This gate E is called the "magic basis". It can be used to convert between # SO(4) and SU(2) x SU(2). For A in SO(4), E A E^\dag is in SU(2) x SU(2). E = np.array([[1, 1j, 0, 0], [0, 0, 1j, 1], [0, 0, 1j, -1], [1, -1j, 0, 0]]) / np.sqrt(2) @@ -114,6 +132,8 @@ def _compute_num_cnots(U): u = math.dot(Edag, math.dot(U, E)) gammaU = math.dot(u, math.T(u)) trace = math.trace(gammaU) + gU2 = math.dot(gammaU, gammaU) + id4 = math.eye(4) # Case: 0 CNOTs (tensor product), the trace is +/- 4 # We need a tolerance of around 1e-7 here in order to work with the case where U @@ -121,15 +141,9 @@ def _compute_num_cnots(U): if math.allclose(trace, 4, atol=1e-7) or math.allclose(trace, -4, atol=1e-7): return 0 - # To distinguish between 1/2 CNOT cases, we need to look at the eigenvalues - evs = math.linalg.eigvals(gammaU) - - sorted_evs = math.sort(math.imag(evs)) - # Case: 1 CNOT, the trace is 0, and the eigenvalues of gammaU are [-1j, -1j, 1j, 1j] - # Checking the eigenvalues is needed because of some special 2-CNOT cases that yield - # a trace 0. - if math.allclose(trace, 0j, atol=1e-7) and math.allclose(sorted_evs, [-1, -1, 1, 1]): + # Try gammaU^2 + I = 0 along with zero trace + if math.allclose(trace, 0j, atol=1e-7) and math.allclose(gU2 + id4, 0): return 1 # Case: 2 CNOTs, the trace has only a real part (or is 0) @@ -604,6 +618,7 @@ def two_qubit_decomposition(U, wires): Rot(tensor(-3.78673588, requires_grad=True), tensor(2.03936812, requires_grad=True), tensor(-2.46956972, requires_grad=True), wires=[0])] """ + _check_differentiability_warning(U) # First, we note that this method works only for SU(4) gates, meaning that # we need to rescale the matrix by its determinant. U = _convert_to_su4(U) diff --git a/tests/ops/op_math/test_decompositions.py b/tests/ops/op_math/test_decompositions.py index acf0454bb5b..42573654899 100644 --- a/tests/ops/op_math/test_decompositions.py +++ b/tests/ops/op_math/test_decompositions.py @@ -141,6 +141,7 @@ def test_zyz_decomposition(self, U, expected_params): def test_zyz_decomposition_torch(self, U, expected_params): """Test that a one-qubit operation in Torch is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import torch U = torch.tensor(U, dtype=torch.complex128) @@ -151,6 +152,7 @@ def test_zyz_decomposition_torch(self, U, expected_params): def test_zyz_decomposition_tf(self, U, expected_params): """Test that a one-qubit operation in Tensorflow is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import tensorflow as tf U = tf.Variable(U, dtype=tf.complex128) @@ -161,6 +163,7 @@ def test_zyz_decomposition_tf(self, U, expected_params): def test_zyz_decomposition_jax(self, U, expected_params): """Test that a one-qubit operation in JAX is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import jax U = jax.numpy.array(U, dtype=jax.numpy.complex128) @@ -216,6 +219,7 @@ def test_xyx_decomposition(self, U, expected_params): def test_xyx_decomposition_torch(self, U, expected_params): """Test that a one-qubit operation in Torch is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import torch U = torch.tensor(U, dtype=torch.complex128) @@ -226,6 +230,7 @@ def test_xyx_decomposition_torch(self, U, expected_params): def test_xyx_decomposition_tf(self, U, expected_params): """Test that a one-qubit operation in Tensorflow is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import tensorflow as tf U = tf.Variable(U, dtype=tf.complex128) @@ -236,6 +241,7 @@ def test_xyx_decomposition_tf(self, U, expected_params): def test_xyx_decomposition_jax(self, U, expected_params): """Test that a one-qubit operation in JAX is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import jax U = jax.numpy.array(U, dtype=jax.numpy.complex128) @@ -288,6 +294,7 @@ def test_xzx_decomposition(self, U, expected_params): def test_xzx_decomposition_torch(self, U, expected_params): """Test that a one-qubit operation in Torch is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import torch U = torch.tensor(U, dtype=torch.complex128) @@ -298,6 +305,7 @@ def test_xzx_decomposition_torch(self, U, expected_params): def test_xzx_decomposition_tf(self, U, expected_params): """Test that a one-qubit operation in Tensorflow is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import tensorflow as tf U = tf.Variable(U, dtype=tf.complex128) @@ -308,6 +316,7 @@ def test_xzx_decomposition_tf(self, U, expected_params): def test_xzx_decomposition_jax(self, U, expected_params): """Test that a one-qubit operation in JAX is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import jax U = jax.numpy.array(U, dtype=jax.numpy.complex128) @@ -374,6 +383,7 @@ def test_zxz_decomposition(self, U, expected_params): def test_zxz_decomposition_torch(self, U, expected_params): """Test that a one-qubit operation in Torch is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import torch U = torch.tensor(U, dtype=torch.complex128) @@ -384,6 +394,7 @@ def test_zxz_decomposition_torch(self, U, expected_params): def test_zxz_decomposition_tf(self, U, expected_params): """Test that a one-qubit operation in Tensorflow is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import tensorflow as tf U = tf.Variable(U, dtype=tf.complex128) @@ -394,6 +405,7 @@ def test_zxz_decomposition_tf(self, U, expected_params): def test_zxz_decomposition_jax(self, U, expected_params): """Test that a one-qubit operation in JAX is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import jax U = jax.numpy.array(U, dtype=jax.numpy.complex128) @@ -438,6 +450,7 @@ def test_one_qubit_decomposition_rot(self, U, expected_gates, expected_params): def test_rot_decomposition_torch(self, U, expected_gates, expected_params): """Test that a one-qubit operation in Torch is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import torch U = torch.tensor(U, dtype=torch.complex128) @@ -448,6 +461,7 @@ def test_rot_decomposition_torch(self, U, expected_gates, expected_params): def test_rot_decomposition_tf(self, U, expected_gates, expected_params): """Test that a one-qubit operation in Tensorflow is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import tensorflow as tf U = tf.Variable(U, dtype=tf.complex128) @@ -458,6 +472,7 @@ def test_rot_decomposition_tf(self, U, expected_gates, expected_params): def test_rot_decomposition_jax(self, U, expected_gates, expected_params): """Test that a one-qubit operation in JAX is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import jax U = jax.numpy.array(U, dtype=jax.numpy.complex128) @@ -1046,6 +1061,7 @@ class TestTwoQubitUnitaryDecompositionInterfaces: def test_two_qubit_decomposition_torch(self, U, wires): """Test that a two-qubit operation in Torch is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import torch U = torch.tensor(U, dtype=torch.complex128) @@ -1067,6 +1083,7 @@ def test_two_qubit_decomposition_torch(self, U, wires): def test_two_qubit_decomposition_tensor_products_torch(self, U_pair, wires): """Test that a two-qubit tensor product in Torch is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import torch U1 = torch.tensor(U_pair[0], dtype=torch.complex128) @@ -1090,6 +1107,7 @@ def test_two_qubit_decomposition_tensor_products_torch(self, U_pair, wires): def test_two_qubit_decomposition_tf(self, U, wires): """Test that a two-qubit operation in Tensorflow is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import tensorflow as tf U = tf.Variable(U, dtype=tf.complex128) @@ -1111,6 +1129,7 @@ def test_two_qubit_decomposition_tf(self, U, wires): def test_two_qubit_decomposition_tensor_products_tf(self, U_pair, wires): """Test that a two-qubit tensor product in Tensorflow is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import tensorflow as tf U1 = tf.Variable(U_pair[0], dtype=tf.complex128) @@ -1134,6 +1153,7 @@ def test_two_qubit_decomposition_tensor_products_tf(self, U_pair, wires): def test_two_qubit_decomposition_jax(self, U, wires): """Test that a two-qubit operation in JAX is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import jax U = jax.numpy.array(U, dtype=jax.numpy.complex128) @@ -1155,6 +1175,7 @@ def test_two_qubit_decomposition_jax(self, U, wires): def test_two_qubit_decomposition_tensor_products_jax(self, U_pair, wires): """Test that a two-qubit tensor product in JAX is correctly decomposed.""" + # pylint: disable=import-outside-toplevel import jax U1 = jax.numpy.array(U_pair[0], dtype=jax.numpy.complex128) @@ -1178,6 +1199,7 @@ def test_two_qubit_decomposition_tensor_products_jax(self, U_pair, wires): def test_two_qubit_decomposition_jax_jit(self, U, wires): """Test that a two-qubit operation is correctly decomposed with JAX-JIT .""" + # pylint: disable=import-outside-toplevel import jax U = jax.numpy.array(U, dtype=jax.numpy.complex128) @@ -1206,6 +1228,7 @@ def wrapped_decomposition(U): def test_two_qubit_decomposition_tensor_products_jax_jit(self, U_pair, wires): """Test that a two-qubit tensor product is correctly decomposed with JAX-JIT.""" + # pylint: disable=import-outside-toplevel import jax U1 = jax.numpy.array(U_pair[0], dtype=jax.numpy.complex128) @@ -1263,3 +1286,128 @@ def expm(val): mat = make_unitary(np.pi / 2) decomp_mat = qml.matrix(two_qubit_decomposition, wire_order=(0, 1))(mat, wires=(0, 1)) assert qml.math.allclose(mat, decomp_mat) + + +class TestTwoQubitDecompositionWarnings: + """Test suite for warning generation in two_qubit_decomposition""" + + def test_warning_parameterized_autograd(self): + """Test warning is raised for parameterized matrix with autograd""" + dev = qml.device("default.qubit", wires=2) + + def my_qfunc(params): + U = qml.numpy.array(np.eye(4, dtype=np.complex128), requires_grad=True) * params + ops = qml.ops.two_qubit_decomposition(U, wires=[0, 1]) + for op in ops: + qml.apply(op) + return qml.expval(qml.PauliZ(0)) + + qnode = qml.QNode(my_qfunc, dev, interface="autograd") + + with pytest.warns( + RuntimeWarning, match="The two-qubit decomposition may not be differentiable" + ): + qnode(1.0) + + @pytest.mark.torch + def test_warning_parameterized_torch(self): + """Test warning is raised for parameterized matrix with PyTorch""" + try: + # pylint: disable=import-outside-toplevel + import torch + except ImportError: + pytest.skip("PyTorch not installed") + + dev = qml.device("default.qubit", wires=2) + + def my_qfunc(params): + U = torch.eye(4, dtype=torch.complex128) * params + U.requires_grad_(True) + ops = qml.ops.two_qubit_decomposition(U, wires=[0, 1]) + for op in ops: + qml.apply(op) + return qml.expval(qml.PauliZ(0)) + + qnode = qml.QNode(my_qfunc, dev, interface="torch") + + with pytest.warns( + RuntimeWarning, match="The two-qubit decomposition may not be differentiable" + ): + qnode(torch.tensor(1.0, dtype=torch.complex128, requires_grad=True)) + + @pytest.mark.tf + def test_warning_parameterized_tf(self): + """Test warning is raised for parameterized matrix with TensorFlow""" + try: + # pylint: disable=import-outside-toplevel + import tensorflow as tf + except ImportError: + pytest.skip("TensorFlow not installed") + + dev = qml.device("default.qubit", wires=2) + + def my_qfunc(params): + params = tf.cast(params, tf.complex128) + U = tf.eye(4, dtype=tf.complex128) * params # Create tensor without Variable + with tf.GradientTape() as tape: + tape.watch(U) # Explicitly watch U + ops = qml.ops.two_qubit_decomposition(U, wires=[0, 1]) + for op in ops: + qml.apply(op) + return qml.expval(qml.PauliZ(0)) + + qnode = qml.QNode(my_qfunc, dev, interface="tf") + + with pytest.warns( + RuntimeWarning, match="The two-qubit decomposition may not be differentiable" + ): + qnode(tf.constant(1.0)) + + @pytest.mark.jax + def test_warning_parameterized_jax(self): + """Test warning is raised for parameterized matrix with JAX""" + try: + # pylint: disable=import-outside-toplevel + import jax + import jax.numpy as jnp + except ImportError: + pytest.skip("JAX not installed") + + dev = qml.device("default.qubit", wires=2) + + def my_qfunc(params): + U = jnp.array(np.eye(4, dtype=np.complex128)) * params + ops = qml.ops.two_qubit_decomposition(U, wires=[0, 1]) + for op in ops: + qml.apply(op) + return qml.expval(qml.PauliZ(0)) + + qnode = qml.QNode(my_qfunc, dev, interface="jax") + + # Convert function to one that JAX can differentiate + def cost(x): + return jnp.real(qnode(x)) + + with pytest.warns( + RuntimeWarning, match="The two-qubit decomposition may not be differentiable" + ): + # Use JAX's grad to create a Tracer + jax.grad(cost)(1.0) + + def test_warning_complex_input(self): + """Test warning is raised with complex input parameters""" + dev = qml.device("default.qubit", wires=2) + + def my_qfunc(params): + U = qml.numpy.array(np.eye(4, dtype=np.complex128), requires_grad=True) * params + ops = qml.ops.two_qubit_decomposition(U, wires=[0, 1]) + for op in ops: + qml.apply(op) + return qml.expval(qml.PauliZ(0)) + + qnode = qml.QNode(my_qfunc, dev, interface="autograd") + + with pytest.warns( + RuntimeWarning, match="The two-qubit decomposition may not be differentiable" + ): + qnode(1.0 + 0.5j)