diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 4b5d8799594..237b52b1942 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -335,6 +335,9 @@
Bug fixes 🐛
+* Fixes unnecessary call of `eigvals` in `qml.ops.op_math.decompositions.two_qubit_unitary.py` that was causing an error in VJP. Raises warnings to users if this essentially nondifferentiable module is used.
+ [(#6437)](https://github.com/PennyLaneAI/pennylane/pull/6437)
+
* Patches the `math` module to function with autoray 0.7.0.
[(#6429)](https://github.com/PennyLaneAI/pennylane/pull/6429)
@@ -417,6 +420,7 @@ Guillermo Alonso,
Utkarsh Azad,
Oleksandr Borysenko,
Astral Cai,
+Yushao Chen,
Isaac De Vlugt,
Diksha Dhawan,
Lillian M. A. Frederiksen,
diff --git a/pennylane/ops/op_math/decompositions/two_qubit_unitary.py b/pennylane/ops/op_math/decompositions/two_qubit_unitary.py
index 46e43dd2168..de223d5f3ef 100644
--- a/pennylane/ops/op_math/decompositions/two_qubit_unitary.py
+++ b/pennylane/ops/op_math/decompositions/two_qubit_unitary.py
@@ -14,6 +14,8 @@
"""Contains transforms and helpers functions for decomposing arbitrary two-qubit
unitary operations into elementary gates.
"""
+import warnings
+
import numpy as np
import pennylane as qml
@@ -43,6 +45,22 @@
###################################################################################
+def _check_differentiability_warning(U):
+ """Check conditions that may lead to non-differentiability and raise appropriate warnings.
+
+ Args:
+ U (tensor_like): Input unitary matrix to check.
+ """
+
+ if qml.math.requires_grad(U):
+ warnings.warn(
+ "The two-qubit decomposition may not be differentiable when the input "
+ "unitary depends on trainable parameters.",
+ RuntimeWarning,
+ stacklevel=2,
+ )
+
+
# This gate E is called the "magic basis". It can be used to convert between
# SO(4) and SU(2) x SU(2). For A in SO(4), E A E^\dag is in SU(2) x SU(2).
E = np.array([[1, 1j, 0, 0], [0, 0, 1j, 1], [0, 0, 1j, -1], [1, -1j, 0, 0]]) / np.sqrt(2)
@@ -114,6 +132,8 @@ def _compute_num_cnots(U):
u = math.dot(Edag, math.dot(U, E))
gammaU = math.dot(u, math.T(u))
trace = math.trace(gammaU)
+ gU2 = math.dot(gammaU, gammaU)
+ id4 = math.eye(4)
# Case: 0 CNOTs (tensor product), the trace is +/- 4
# We need a tolerance of around 1e-7 here in order to work with the case where U
@@ -121,15 +141,9 @@ def _compute_num_cnots(U):
if math.allclose(trace, 4, atol=1e-7) or math.allclose(trace, -4, atol=1e-7):
return 0
- # To distinguish between 1/2 CNOT cases, we need to look at the eigenvalues
- evs = math.linalg.eigvals(gammaU)
-
- sorted_evs = math.sort(math.imag(evs))
-
# Case: 1 CNOT, the trace is 0, and the eigenvalues of gammaU are [-1j, -1j, 1j, 1j]
- # Checking the eigenvalues is needed because of some special 2-CNOT cases that yield
- # a trace 0.
- if math.allclose(trace, 0j, atol=1e-7) and math.allclose(sorted_evs, [-1, -1, 1, 1]):
+ # Try gammaU^2 + I = 0 along with zero trace
+ if math.allclose(trace, 0j, atol=1e-7) and math.allclose(gU2 + id4, 0):
return 1
# Case: 2 CNOTs, the trace has only a real part (or is 0)
@@ -604,6 +618,7 @@ def two_qubit_decomposition(U, wires):
Rot(tensor(-3.78673588, requires_grad=True), tensor(2.03936812, requires_grad=True), tensor(-2.46956972, requires_grad=True), wires=[0])]
"""
+ _check_differentiability_warning(U)
# First, we note that this method works only for SU(4) gates, meaning that
# we need to rescale the matrix by its determinant.
U = _convert_to_su4(U)
diff --git a/tests/ops/op_math/test_decompositions.py b/tests/ops/op_math/test_decompositions.py
index acf0454bb5b..42573654899 100644
--- a/tests/ops/op_math/test_decompositions.py
+++ b/tests/ops/op_math/test_decompositions.py
@@ -141,6 +141,7 @@ def test_zyz_decomposition(self, U, expected_params):
def test_zyz_decomposition_torch(self, U, expected_params):
"""Test that a one-qubit operation in Torch is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import torch
U = torch.tensor(U, dtype=torch.complex128)
@@ -151,6 +152,7 @@ def test_zyz_decomposition_torch(self, U, expected_params):
def test_zyz_decomposition_tf(self, U, expected_params):
"""Test that a one-qubit operation in Tensorflow is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import tensorflow as tf
U = tf.Variable(U, dtype=tf.complex128)
@@ -161,6 +163,7 @@ def test_zyz_decomposition_tf(self, U, expected_params):
def test_zyz_decomposition_jax(self, U, expected_params):
"""Test that a one-qubit operation in JAX is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import jax
U = jax.numpy.array(U, dtype=jax.numpy.complex128)
@@ -216,6 +219,7 @@ def test_xyx_decomposition(self, U, expected_params):
def test_xyx_decomposition_torch(self, U, expected_params):
"""Test that a one-qubit operation in Torch is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import torch
U = torch.tensor(U, dtype=torch.complex128)
@@ -226,6 +230,7 @@ def test_xyx_decomposition_torch(self, U, expected_params):
def test_xyx_decomposition_tf(self, U, expected_params):
"""Test that a one-qubit operation in Tensorflow is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import tensorflow as tf
U = tf.Variable(U, dtype=tf.complex128)
@@ -236,6 +241,7 @@ def test_xyx_decomposition_tf(self, U, expected_params):
def test_xyx_decomposition_jax(self, U, expected_params):
"""Test that a one-qubit operation in JAX is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import jax
U = jax.numpy.array(U, dtype=jax.numpy.complex128)
@@ -288,6 +294,7 @@ def test_xzx_decomposition(self, U, expected_params):
def test_xzx_decomposition_torch(self, U, expected_params):
"""Test that a one-qubit operation in Torch is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import torch
U = torch.tensor(U, dtype=torch.complex128)
@@ -298,6 +305,7 @@ def test_xzx_decomposition_torch(self, U, expected_params):
def test_xzx_decomposition_tf(self, U, expected_params):
"""Test that a one-qubit operation in Tensorflow is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import tensorflow as tf
U = tf.Variable(U, dtype=tf.complex128)
@@ -308,6 +316,7 @@ def test_xzx_decomposition_tf(self, U, expected_params):
def test_xzx_decomposition_jax(self, U, expected_params):
"""Test that a one-qubit operation in JAX is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import jax
U = jax.numpy.array(U, dtype=jax.numpy.complex128)
@@ -374,6 +383,7 @@ def test_zxz_decomposition(self, U, expected_params):
def test_zxz_decomposition_torch(self, U, expected_params):
"""Test that a one-qubit operation in Torch is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import torch
U = torch.tensor(U, dtype=torch.complex128)
@@ -384,6 +394,7 @@ def test_zxz_decomposition_torch(self, U, expected_params):
def test_zxz_decomposition_tf(self, U, expected_params):
"""Test that a one-qubit operation in Tensorflow is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import tensorflow as tf
U = tf.Variable(U, dtype=tf.complex128)
@@ -394,6 +405,7 @@ def test_zxz_decomposition_tf(self, U, expected_params):
def test_zxz_decomposition_jax(self, U, expected_params):
"""Test that a one-qubit operation in JAX is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import jax
U = jax.numpy.array(U, dtype=jax.numpy.complex128)
@@ -438,6 +450,7 @@ def test_one_qubit_decomposition_rot(self, U, expected_gates, expected_params):
def test_rot_decomposition_torch(self, U, expected_gates, expected_params):
"""Test that a one-qubit operation in Torch is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import torch
U = torch.tensor(U, dtype=torch.complex128)
@@ -448,6 +461,7 @@ def test_rot_decomposition_torch(self, U, expected_gates, expected_params):
def test_rot_decomposition_tf(self, U, expected_gates, expected_params):
"""Test that a one-qubit operation in Tensorflow is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import tensorflow as tf
U = tf.Variable(U, dtype=tf.complex128)
@@ -458,6 +472,7 @@ def test_rot_decomposition_tf(self, U, expected_gates, expected_params):
def test_rot_decomposition_jax(self, U, expected_gates, expected_params):
"""Test that a one-qubit operation in JAX is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import jax
U = jax.numpy.array(U, dtype=jax.numpy.complex128)
@@ -1046,6 +1061,7 @@ class TestTwoQubitUnitaryDecompositionInterfaces:
def test_two_qubit_decomposition_torch(self, U, wires):
"""Test that a two-qubit operation in Torch is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import torch
U = torch.tensor(U, dtype=torch.complex128)
@@ -1067,6 +1083,7 @@ def test_two_qubit_decomposition_torch(self, U, wires):
def test_two_qubit_decomposition_tensor_products_torch(self, U_pair, wires):
"""Test that a two-qubit tensor product in Torch is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import torch
U1 = torch.tensor(U_pair[0], dtype=torch.complex128)
@@ -1090,6 +1107,7 @@ def test_two_qubit_decomposition_tensor_products_torch(self, U_pair, wires):
def test_two_qubit_decomposition_tf(self, U, wires):
"""Test that a two-qubit operation in Tensorflow is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import tensorflow as tf
U = tf.Variable(U, dtype=tf.complex128)
@@ -1111,6 +1129,7 @@ def test_two_qubit_decomposition_tf(self, U, wires):
def test_two_qubit_decomposition_tensor_products_tf(self, U_pair, wires):
"""Test that a two-qubit tensor product in Tensorflow is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import tensorflow as tf
U1 = tf.Variable(U_pair[0], dtype=tf.complex128)
@@ -1134,6 +1153,7 @@ def test_two_qubit_decomposition_tensor_products_tf(self, U_pair, wires):
def test_two_qubit_decomposition_jax(self, U, wires):
"""Test that a two-qubit operation in JAX is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import jax
U = jax.numpy.array(U, dtype=jax.numpy.complex128)
@@ -1155,6 +1175,7 @@ def test_two_qubit_decomposition_jax(self, U, wires):
def test_two_qubit_decomposition_tensor_products_jax(self, U_pair, wires):
"""Test that a two-qubit tensor product in JAX is correctly decomposed."""
+ # pylint: disable=import-outside-toplevel
import jax
U1 = jax.numpy.array(U_pair[0], dtype=jax.numpy.complex128)
@@ -1178,6 +1199,7 @@ def test_two_qubit_decomposition_tensor_products_jax(self, U_pair, wires):
def test_two_qubit_decomposition_jax_jit(self, U, wires):
"""Test that a two-qubit operation is correctly decomposed with JAX-JIT ."""
+ # pylint: disable=import-outside-toplevel
import jax
U = jax.numpy.array(U, dtype=jax.numpy.complex128)
@@ -1206,6 +1228,7 @@ def wrapped_decomposition(U):
def test_two_qubit_decomposition_tensor_products_jax_jit(self, U_pair, wires):
"""Test that a two-qubit tensor product is correctly decomposed with JAX-JIT."""
+ # pylint: disable=import-outside-toplevel
import jax
U1 = jax.numpy.array(U_pair[0], dtype=jax.numpy.complex128)
@@ -1263,3 +1286,128 @@ def expm(val):
mat = make_unitary(np.pi / 2)
decomp_mat = qml.matrix(two_qubit_decomposition, wire_order=(0, 1))(mat, wires=(0, 1))
assert qml.math.allclose(mat, decomp_mat)
+
+
+class TestTwoQubitDecompositionWarnings:
+ """Test suite for warning generation in two_qubit_decomposition"""
+
+ def test_warning_parameterized_autograd(self):
+ """Test warning is raised for parameterized matrix with autograd"""
+ dev = qml.device("default.qubit", wires=2)
+
+ def my_qfunc(params):
+ U = qml.numpy.array(np.eye(4, dtype=np.complex128), requires_grad=True) * params
+ ops = qml.ops.two_qubit_decomposition(U, wires=[0, 1])
+ for op in ops:
+ qml.apply(op)
+ return qml.expval(qml.PauliZ(0))
+
+ qnode = qml.QNode(my_qfunc, dev, interface="autograd")
+
+ with pytest.warns(
+ RuntimeWarning, match="The two-qubit decomposition may not be differentiable"
+ ):
+ qnode(1.0)
+
+ @pytest.mark.torch
+ def test_warning_parameterized_torch(self):
+ """Test warning is raised for parameterized matrix with PyTorch"""
+ try:
+ # pylint: disable=import-outside-toplevel
+ import torch
+ except ImportError:
+ pytest.skip("PyTorch not installed")
+
+ dev = qml.device("default.qubit", wires=2)
+
+ def my_qfunc(params):
+ U = torch.eye(4, dtype=torch.complex128) * params
+ U.requires_grad_(True)
+ ops = qml.ops.two_qubit_decomposition(U, wires=[0, 1])
+ for op in ops:
+ qml.apply(op)
+ return qml.expval(qml.PauliZ(0))
+
+ qnode = qml.QNode(my_qfunc, dev, interface="torch")
+
+ with pytest.warns(
+ RuntimeWarning, match="The two-qubit decomposition may not be differentiable"
+ ):
+ qnode(torch.tensor(1.0, dtype=torch.complex128, requires_grad=True))
+
+ @pytest.mark.tf
+ def test_warning_parameterized_tf(self):
+ """Test warning is raised for parameterized matrix with TensorFlow"""
+ try:
+ # pylint: disable=import-outside-toplevel
+ import tensorflow as tf
+ except ImportError:
+ pytest.skip("TensorFlow not installed")
+
+ dev = qml.device("default.qubit", wires=2)
+
+ def my_qfunc(params):
+ params = tf.cast(params, tf.complex128)
+ U = tf.eye(4, dtype=tf.complex128) * params # Create tensor without Variable
+ with tf.GradientTape() as tape:
+ tape.watch(U) # Explicitly watch U
+ ops = qml.ops.two_qubit_decomposition(U, wires=[0, 1])
+ for op in ops:
+ qml.apply(op)
+ return qml.expval(qml.PauliZ(0))
+
+ qnode = qml.QNode(my_qfunc, dev, interface="tf")
+
+ with pytest.warns(
+ RuntimeWarning, match="The two-qubit decomposition may not be differentiable"
+ ):
+ qnode(tf.constant(1.0))
+
+ @pytest.mark.jax
+ def test_warning_parameterized_jax(self):
+ """Test warning is raised for parameterized matrix with JAX"""
+ try:
+ # pylint: disable=import-outside-toplevel
+ import jax
+ import jax.numpy as jnp
+ except ImportError:
+ pytest.skip("JAX not installed")
+
+ dev = qml.device("default.qubit", wires=2)
+
+ def my_qfunc(params):
+ U = jnp.array(np.eye(4, dtype=np.complex128)) * params
+ ops = qml.ops.two_qubit_decomposition(U, wires=[0, 1])
+ for op in ops:
+ qml.apply(op)
+ return qml.expval(qml.PauliZ(0))
+
+ qnode = qml.QNode(my_qfunc, dev, interface="jax")
+
+ # Convert function to one that JAX can differentiate
+ def cost(x):
+ return jnp.real(qnode(x))
+
+ with pytest.warns(
+ RuntimeWarning, match="The two-qubit decomposition may not be differentiable"
+ ):
+ # Use JAX's grad to create a Tracer
+ jax.grad(cost)(1.0)
+
+ def test_warning_complex_input(self):
+ """Test warning is raised with complex input parameters"""
+ dev = qml.device("default.qubit", wires=2)
+
+ def my_qfunc(params):
+ U = qml.numpy.array(np.eye(4, dtype=np.complex128), requires_grad=True) * params
+ ops = qml.ops.two_qubit_decomposition(U, wires=[0, 1])
+ for op in ops:
+ qml.apply(op)
+ return qml.expval(qml.PauliZ(0))
+
+ qnode = qml.QNode(my_qfunc, dev, interface="autograd")
+
+ with pytest.warns(
+ RuntimeWarning, match="The two-qubit decomposition may not be differentiable"
+ ):
+ qnode(1.0 + 0.5j)