Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warnings for two_qubit_unitary #6437

Merged
merged 10 commits into from
Oct 25, 2024
31 changes: 23 additions & 8 deletions pennylane/ops/op_math/decompositions/two_qubit_unitary.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"""Contains transforms and helpers functions for decomposing arbitrary two-qubit
unitary operations into elementary gates.
"""
import warnings

import numpy as np

import pennylane as qml
Expand Down Expand Up @@ -43,6 +45,22 @@
###################################################################################


def _check_differentiability_warning(U):
"""Check conditions that may lead to non-differentiability and raise appropriate warnings.

Args:
U (tensor): Input unitary matrix to check
JerryChen97 marked this conversation as resolved.
Show resolved Hide resolved
"""

if qml.math.requires_grad(U):
warnings.warn(
"The two-qubit decomposition may not be differentiable when the input "
"unitary depends on trainable parameters.",
RuntimeWarning,
stacklevel=2,
)


# This gate E is called the "magic basis". It can be used to convert between
# SO(4) and SU(2) x SU(2). For A in SO(4), E A E^\dag is in SU(2) x SU(2).
E = np.array([[1, 1j, 0, 0], [0, 0, 1j, 1], [0, 0, 1j, -1], [1, -1j, 0, 0]]) / np.sqrt(2)
Expand Down Expand Up @@ -114,22 +132,18 @@ def _compute_num_cnots(U):
u = math.dot(Edag, math.dot(U, E))
gammaU = math.dot(u, math.T(u))
trace = math.trace(gammaU)
gU2 = math.dot(gammaU, gammaU)
id4 = math.eye(4)

# Case: 0 CNOTs (tensor product), the trace is +/- 4
# We need a tolerance of around 1e-7 here in order to work with the case where U
# is specified with 8 decimal places.
if math.allclose(trace, 4, atol=1e-7) or math.allclose(trace, -4, atol=1e-7):
return 0

# To distinguish between 1/2 CNOT cases, we need to look at the eigenvalues
evs = math.linalg.eigvals(gammaU)

sorted_evs = math.sort(math.imag(evs))

# Case: 1 CNOT, the trace is 0, and the eigenvalues of gammaU are [-1j, -1j, 1j, 1j]
# Checking the eigenvalues is needed because of some special 2-CNOT cases that yield
# a trace 0.
if math.allclose(trace, 0j, atol=1e-7) and math.allclose(sorted_evs, [-1, -1, 1, 1]):
# Try gammaU^2 + I = 0 along with zero trace
if math.allclose(trace, 0j, atol=1e-7) and math.allclose(gU2 + id4, 0):
return 1

# Case: 2 CNOTs, the trace has only a real part (or is 0)
Expand Down Expand Up @@ -604,6 +618,7 @@ def two_qubit_decomposition(U, wires):
Rot(tensor(-3.78673588, requires_grad=True), tensor(2.03936812, requires_grad=True), tensor(-2.46956972, requires_grad=True), wires=[0])]

"""
_check_differentiability_warning(U)
# First, we note that this method works only for SU(4) gates, meaning that
# we need to rescale the matrix by its determinant.
U = _convert_to_su4(U)
Expand Down
148 changes: 148 additions & 0 deletions tests/ops/op_math/test_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def test_zyz_decomposition(self, U, expected_params):
def test_zyz_decomposition_torch(self, U, expected_params):
"""Test that a one-qubit operation in Torch is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import torch

U = torch.tensor(U, dtype=torch.complex128)
Expand All @@ -151,6 +152,7 @@ def test_zyz_decomposition_torch(self, U, expected_params):
def test_zyz_decomposition_tf(self, U, expected_params):
"""Test that a one-qubit operation in Tensorflow is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import tensorflow as tf

U = tf.Variable(U, dtype=tf.complex128)
Expand All @@ -161,6 +163,7 @@ def test_zyz_decomposition_tf(self, U, expected_params):
def test_zyz_decomposition_jax(self, U, expected_params):
"""Test that a one-qubit operation in JAX is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import jax

U = jax.numpy.array(U, dtype=jax.numpy.complex128)
Expand Down Expand Up @@ -216,6 +219,7 @@ def test_xyx_decomposition(self, U, expected_params):
def test_xyx_decomposition_torch(self, U, expected_params):
"""Test that a one-qubit operation in Torch is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import torch

U = torch.tensor(U, dtype=torch.complex128)
Expand All @@ -226,6 +230,7 @@ def test_xyx_decomposition_torch(self, U, expected_params):
def test_xyx_decomposition_tf(self, U, expected_params):
"""Test that a one-qubit operation in Tensorflow is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import tensorflow as tf

U = tf.Variable(U, dtype=tf.complex128)
Expand All @@ -236,6 +241,7 @@ def test_xyx_decomposition_tf(self, U, expected_params):
def test_xyx_decomposition_jax(self, U, expected_params):
"""Test that a one-qubit operation in JAX is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import jax

U = jax.numpy.array(U, dtype=jax.numpy.complex128)
Expand Down Expand Up @@ -288,6 +294,7 @@ def test_xzx_decomposition(self, U, expected_params):
def test_xzx_decomposition_torch(self, U, expected_params):
"""Test that a one-qubit operation in Torch is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import torch

U = torch.tensor(U, dtype=torch.complex128)
Expand All @@ -298,6 +305,7 @@ def test_xzx_decomposition_torch(self, U, expected_params):
def test_xzx_decomposition_tf(self, U, expected_params):
"""Test that a one-qubit operation in Tensorflow is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import tensorflow as tf

U = tf.Variable(U, dtype=tf.complex128)
Expand All @@ -308,6 +316,7 @@ def test_xzx_decomposition_tf(self, U, expected_params):
def test_xzx_decomposition_jax(self, U, expected_params):
"""Test that a one-qubit operation in JAX is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import jax

U = jax.numpy.array(U, dtype=jax.numpy.complex128)
Expand Down Expand Up @@ -374,6 +383,7 @@ def test_zxz_decomposition(self, U, expected_params):
def test_zxz_decomposition_torch(self, U, expected_params):
"""Test that a one-qubit operation in Torch is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import torch

U = torch.tensor(U, dtype=torch.complex128)
Expand All @@ -384,6 +394,7 @@ def test_zxz_decomposition_torch(self, U, expected_params):
def test_zxz_decomposition_tf(self, U, expected_params):
"""Test that a one-qubit operation in Tensorflow is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import tensorflow as tf

U = tf.Variable(U, dtype=tf.complex128)
Expand All @@ -394,6 +405,7 @@ def test_zxz_decomposition_tf(self, U, expected_params):
def test_zxz_decomposition_jax(self, U, expected_params):
"""Test that a one-qubit operation in JAX is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import jax

U = jax.numpy.array(U, dtype=jax.numpy.complex128)
Expand Down Expand Up @@ -438,6 +450,7 @@ def test_one_qubit_decomposition_rot(self, U, expected_gates, expected_params):
def test_rot_decomposition_torch(self, U, expected_gates, expected_params):
"""Test that a one-qubit operation in Torch is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import torch

U = torch.tensor(U, dtype=torch.complex128)
Expand All @@ -448,6 +461,7 @@ def test_rot_decomposition_torch(self, U, expected_gates, expected_params):
def test_rot_decomposition_tf(self, U, expected_gates, expected_params):
"""Test that a one-qubit operation in Tensorflow is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import tensorflow as tf

U = tf.Variable(U, dtype=tf.complex128)
Expand All @@ -458,6 +472,7 @@ def test_rot_decomposition_tf(self, U, expected_gates, expected_params):
def test_rot_decomposition_jax(self, U, expected_gates, expected_params):
"""Test that a one-qubit operation in JAX is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import jax

U = jax.numpy.array(U, dtype=jax.numpy.complex128)
Expand Down Expand Up @@ -1046,6 +1061,7 @@ class TestTwoQubitUnitaryDecompositionInterfaces:
def test_two_qubit_decomposition_torch(self, U, wires):
"""Test that a two-qubit operation in Torch is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import torch

U = torch.tensor(U, dtype=torch.complex128)
Expand All @@ -1067,6 +1083,7 @@ def test_two_qubit_decomposition_torch(self, U, wires):
def test_two_qubit_decomposition_tensor_products_torch(self, U_pair, wires):
"""Test that a two-qubit tensor product in Torch is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import torch

U1 = torch.tensor(U_pair[0], dtype=torch.complex128)
Expand All @@ -1090,6 +1107,7 @@ def test_two_qubit_decomposition_tensor_products_torch(self, U_pair, wires):
def test_two_qubit_decomposition_tf(self, U, wires):
"""Test that a two-qubit operation in Tensorflow is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import tensorflow as tf

U = tf.Variable(U, dtype=tf.complex128)
Expand All @@ -1111,6 +1129,7 @@ def test_two_qubit_decomposition_tf(self, U, wires):
def test_two_qubit_decomposition_tensor_products_tf(self, U_pair, wires):
"""Test that a two-qubit tensor product in Tensorflow is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import tensorflow as tf

U1 = tf.Variable(U_pair[0], dtype=tf.complex128)
Expand All @@ -1134,6 +1153,7 @@ def test_two_qubit_decomposition_tensor_products_tf(self, U_pair, wires):
def test_two_qubit_decomposition_jax(self, U, wires):
"""Test that a two-qubit operation in JAX is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import jax

U = jax.numpy.array(U, dtype=jax.numpy.complex128)
Expand All @@ -1155,6 +1175,7 @@ def test_two_qubit_decomposition_jax(self, U, wires):
def test_two_qubit_decomposition_tensor_products_jax(self, U_pair, wires):
"""Test that a two-qubit tensor product in JAX is correctly decomposed."""

# pylint: disable=import-outside-toplevel
import jax

U1 = jax.numpy.array(U_pair[0], dtype=jax.numpy.complex128)
Expand All @@ -1178,6 +1199,7 @@ def test_two_qubit_decomposition_tensor_products_jax(self, U_pair, wires):
def test_two_qubit_decomposition_jax_jit(self, U, wires):
"""Test that a two-qubit operation is correctly decomposed with JAX-JIT ."""

# pylint: disable=import-outside-toplevel
import jax

U = jax.numpy.array(U, dtype=jax.numpy.complex128)
Expand Down Expand Up @@ -1206,6 +1228,7 @@ def wrapped_decomposition(U):
def test_two_qubit_decomposition_tensor_products_jax_jit(self, U_pair, wires):
"""Test that a two-qubit tensor product is correctly decomposed with JAX-JIT."""

# pylint: disable=import-outside-toplevel
import jax

U1 = jax.numpy.array(U_pair[0], dtype=jax.numpy.complex128)
Expand Down Expand Up @@ -1263,3 +1286,128 @@ def expm(val):
mat = make_unitary(np.pi / 2)
decomp_mat = qml.matrix(two_qubit_decomposition, wire_order=(0, 1))(mat, wires=(0, 1))
assert qml.math.allclose(mat, decomp_mat)


class TestTwoQubitDecompositionWarnings:
"""Test suite for warning generation in two_qubit_decomposition"""

def test_warning_parameterized_autograd(self):
"""Test warning is raised for parameterized matrix with autograd"""
dev = qml.device("default.qubit", wires=2)

def my_qfunc(params):
U = qml.numpy.array(np.eye(4, dtype=np.complex128), requires_grad=True) * params
ops = qml.ops.two_qubit_decomposition(U, wires=[0, 1])
for op in ops:
qml.apply(op)
return qml.expval(qml.PauliZ(0))

qnode = qml.QNode(my_qfunc, dev, interface="autograd")

with pytest.warns(
RuntimeWarning, match="The two-qubit decomposition may not be differentiable"
):
qnode(1.0)

@pytest.mark.torch
def test_warning_parameterized_torch(self):
"""Test warning is raised for parameterized matrix with PyTorch"""
try:
# pylint: disable=import-outside-toplevel
import torch
except ImportError:
pytest.skip("PyTorch not installed")

dev = qml.device("default.qubit", wires=2)

def my_qfunc(params):
U = torch.eye(4, dtype=torch.complex128) * params
U.requires_grad_(True)
ops = qml.ops.two_qubit_decomposition(U, wires=[0, 1])
for op in ops:
qml.apply(op)
return qml.expval(qml.PauliZ(0))

qnode = qml.QNode(my_qfunc, dev, interface="torch")

with pytest.warns(
RuntimeWarning, match="The two-qubit decomposition may not be differentiable"
):
qnode(torch.tensor(1.0, dtype=torch.complex128, requires_grad=True))

@pytest.mark.tf
def test_warning_parameterized_tf(self):
"""Test warning is raised for parameterized matrix with TensorFlow"""
try:
# pylint: disable=import-outside-toplevel
import tensorflow as tf
except ImportError:
pytest.skip("TensorFlow not installed")

dev = qml.device("default.qubit", wires=2)

def my_qfunc(params):
params = tf.cast(params, tf.complex128)
U = tf.eye(4, dtype=tf.complex128) * params # Create tensor without Variable
with tf.GradientTape() as tape:
tape.watch(U) # Explicitly watch U
ops = qml.ops.two_qubit_decomposition(U, wires=[0, 1])
for op in ops:
qml.apply(op)
return qml.expval(qml.PauliZ(0))

qnode = qml.QNode(my_qfunc, dev, interface="tf")

with pytest.warns(
RuntimeWarning, match="The two-qubit decomposition may not be differentiable"
):
qnode(tf.constant(1.0))

@pytest.mark.jax
def test_warning_parameterized_jax(self):
"""Test warning is raised for parameterized matrix with JAX"""
try:
# pylint: disable=import-outside-toplevel
import jax
import jax.numpy as jnp
except ImportError:
pytest.skip("JAX not installed")

dev = qml.device("default.qubit", wires=2)

def my_qfunc(params):
U = jnp.array(np.eye(4, dtype=np.complex128)) * params
ops = qml.ops.two_qubit_decomposition(U, wires=[0, 1])
for op in ops:
qml.apply(op)
return qml.expval(qml.PauliZ(0))

qnode = qml.QNode(my_qfunc, dev, interface="jax")

# Convert function to one that JAX can differentiate
def cost(x):
return jnp.real(qnode(x))

with pytest.warns(
RuntimeWarning, match="The two-qubit decomposition may not be differentiable"
):
# Use JAX's grad to create a Tracer
jax.grad(cost)(1.0)

def test_warning_complex_input(self):
"""Test warning is raised with complex input parameters"""
dev = qml.device("default.qubit", wires=2)

def my_qfunc(params):
U = qml.numpy.array(np.eye(4, dtype=np.complex128), requires_grad=True) * params
ops = qml.ops.two_qubit_decomposition(U, wires=[0, 1])
for op in ops:
qml.apply(op)
return qml.expval(qml.PauliZ(0))

qnode = qml.QNode(my_qfunc, dev, interface="autograd")

with pytest.warns(
RuntimeWarning, match="The two-qubit decomposition may not be differentiable"
):
qnode(1.0 + 0.5j)