diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 47d0e961d30..32c7a43b25f 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -273,6 +273,9 @@

Bug fixes 🐛

+* Fixes `qml.math.expand_matrix` for qutrit and arbitrary qudit operators. + [(#6398)](https://github.com/PennyLaneAI/pennylane/pull/6398/) + * `MeasurementValue` now raises an error when it is used as a boolean. [(#6386)](https://github.com/PennyLaneAI/pennylane/pull/6386) diff --git a/pennylane/math/matrix_manipulation.py b/pennylane/math/matrix_manipulation.py index 6e2d487ac19..98bea694eae 100644 --- a/pennylane/math/matrix_manipulation.py +++ b/pennylane/math/matrix_manipulation.py @@ -32,7 +32,7 @@ def expand_matrix(mat, wires, wire_order=None, sparse_format="csr"): Args: mat (tensor_like): matrix to expand wires (Iterable): wires determining the subspace that ``mat`` acts on; a matrix of - dimension :math:`2^n` acts on a subspace of :math:`n` wires + dimension :math:`D^n` acts on a subspace of :math:`n` wires, where :math:`D` is the qudit dimension (2). wire_order (Iterable): global wire order, which has to contain all wire labels in ``wires``, but can also contain additional labels sparse_format (str): if ``mat`` is a SciPy sparse matrix then this is the string representing the @@ -101,11 +101,18 @@ def expand_matrix(mat, wires, wire_order=None, sparse_format="csr"): [0., 0., 1., 0.]]) """ + wires = Wires(wires) + + if wires: + float_dim = qml.math.shape(mat)[-1] ** (1 / (len(wires))) + qudit_dim = int(qml.math.round(float_dim)) + else: + qudit_dim = 2 # if no wires, just assume qubit if (wire_order is None) or (wire_order == wires): return mat - if not wires and qml.math.shape(mat) == (2, 2): + if not wires and qml.math.shape(mat) == (qudit_dim, qudit_dim): # global phase wires = wire_order[0:1] @@ -118,8 +125,8 @@ def expand_matrix(mat, wires, wire_order=None, sparse_format="csr"): def eye_interface(dim): if interface == "scipy": - return eye(2**dim, format="coo") - return qml.math.cast_like(qml.math.eye(2**dim, like=interface), mat) + return eye(qudit_dim**dim, format="coo") + return qml.math.cast_like(qml.math.eye(qudit_dim**dim, like=interface), mat) def kron_interface(mat1, mat2): if interface == "scipy": @@ -154,7 +161,9 @@ def kron_interface(mat1, mat2): if interface == "scipy": mat = _permute_sparse_matrix(mat, expanded_wires, subset_wire_order) else: - mat = _permute_dense_matrix(mat, expanded_wires, subset_wire_order, batch_dim) + mat = _permute_dense_matrix( + mat, expanded_wires, subset_wire_order, batch_dim, qudit_dim=qudit_dim + ) # expand the matrix even further if needed if len(expanded_wires) < len(wire_order): @@ -201,7 +210,7 @@ def _permute_sparse_matrix(matrix, wires, wire_order): return matrix -def _permute_dense_matrix(matrix, wires, wire_order, batch_dim): +def _permute_dense_matrix(matrix, wires, wire_order, batch_dim, qudit_dim: int = 2): """Permute the matrix to match the wires given in `wire_order`. Args: @@ -228,12 +237,14 @@ def _permute_dense_matrix(matrix, wires, wire_order, batch_dim): # reshape matrix to match wire values e.g. mat[0, 0, 0, 0] = <00|mat|00> # with this reshape we can easily swap wires - shape = [batch_dim] + [2] * (num_wires * 2) if batch_dim else [2] * (num_wires * 2) + shape = ( + [batch_dim] + [qudit_dim] * (num_wires * 2) if batch_dim else [qudit_dim] * (num_wires * 2) + ) matrix = qml.math.reshape(matrix, shape) # transpose matrix matrix = qml.math.transpose(matrix, axes=perm) # reshape back - shape = [batch_dim] + [2**num_wires] * 2 if batch_dim else [2**num_wires] * 2 + shape = [batch_dim] + [qudit_dim**num_wires] * 2 if batch_dim else [qudit_dim**num_wires] * 2 return qml.math.reshape(matrix, shape) diff --git a/tests/math/test_matrix_manipulation.py b/tests/math/test_matrix_manipulation.py index 5c8be485909..d6b4029a0ca 100644 --- a/tests/math/test_matrix_manipulation.py +++ b/tests/math/test_matrix_manipulation.py @@ -22,6 +22,7 @@ import pennylane as qml from pennylane import numpy as pnp +from pennylane.math import expand_matrix # Define a list of dtypes to test dtypes = ["complex64", "complex128"] @@ -545,6 +546,70 @@ def compute_matrix(): assert np.allclose(op.matrix(wire_order=[0, 1, 2]), expanded_matrix, atol=tol) +class TestExpandMatrixQutrit: + + def test_add_wire_at_end(self): + """Test that expand_matrix works on qutrit matrices when an additional wire is added at the end.""" + + mat = np.reshape(np.arange(9), (3, 3)) + + new_mat = expand_matrix(mat, 0, (0, 1)) + assert qml.math.allclose(new_mat, np.kron(mat, np.eye(3))) + + def test_add_wire_at_start(self): + """Test that expand_matrix works on qutrit matrices when an additional wire is added at the start.""" + + mat = np.reshape(np.arange(9), (3, 3)) + new_mat = expand_matrix(mat, 0, (1, 0)) + assert qml.math.allclose(new_mat, np.kron(np.eye(3), mat)) + + def test_wire_permutation(self): + """Test that wires can be permuted.""" + m1 = np.reshape(np.arange(81), (9, 9)) + m2 = expand_matrix(m1, (0, 1), (1, 0)) + + # states across row are 00, 01, 02, 10, 11, 12, 20, 21, 22 + # extract out right qubit state with mod + m1_wire_zero = m1 % 3 + m2_wire_zero = m2 % 3 + + # extract out left qubit state with floor then mod + m1_wire_one = np.floor(m1 / 3) % 3 + m2_wire_one = np.floor(m2 / 3) % 3 + + assert qml.math.allclose(m1_wire_zero, m2_wire_one) + assert qml.math.allclose(m1_wire_one, m2_wire_zero) + + # check columns also switched + # now matrix numbers indicate row number + m1p = np.floor(m1 / 9) + m2p = np.floor(m2 / 9) + + # states across column are 00, 01, 02, 10, 11, 12, 20, 21, 22 + # extract out right qubit state with mod + m1_wire_zerop = m1p % 3 + m2_wire_zerop = m2p % 3 + + # extract out left qubit state with floor then mod + m1_wire_onep = np.floor(m1p / 3) % 3 + m2_wire_onep = np.floor(m2p / 3) % 3 + + assert qml.math.allclose(m1_wire_zerop, m2_wire_onep) + assert qml.math.allclose(m1_wire_onep, m2_wire_zerop) + + def test_adding_wire_in_middle(self): + """Test that expand_matrix can add an identity wire in the middle of a two qutrit matrix.""" + + m1 = np.reshape(np.arange(9), (3, 3)) + m2 = np.reshape(np.arange(9, 18), (3, 3)) + m3 = np.kron(m1, m2) + + m3_added_wire = expand_matrix(m3, (0, 1), (1, 2, 0)) + m3_kron = np.kron(np.kron(m2, np.eye(3)), m1) + + assert qml.math.allclose(m3_added_wire, m3_kron) + + class TestExpandMatrixSparse: """Tests for the _sparse_expand_matrix function."""