diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 47d0e961d30..32c7a43b25f 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -273,6 +273,9 @@
Bug fixes 🐛
+* Fixes `qml.math.expand_matrix` for qutrit and arbitrary qudit operators.
+ [(#6398)](https://github.com/PennyLaneAI/pennylane/pull/6398/)
+
* `MeasurementValue` now raises an error when it is used as a boolean.
[(#6386)](https://github.com/PennyLaneAI/pennylane/pull/6386)
diff --git a/pennylane/math/matrix_manipulation.py b/pennylane/math/matrix_manipulation.py
index 6e2d487ac19..98bea694eae 100644
--- a/pennylane/math/matrix_manipulation.py
+++ b/pennylane/math/matrix_manipulation.py
@@ -32,7 +32,7 @@ def expand_matrix(mat, wires, wire_order=None, sparse_format="csr"):
Args:
mat (tensor_like): matrix to expand
wires (Iterable): wires determining the subspace that ``mat`` acts on; a matrix of
- dimension :math:`2^n` acts on a subspace of :math:`n` wires
+ dimension :math:`D^n` acts on a subspace of :math:`n` wires, where :math:`D` is the qudit dimension (2).
wire_order (Iterable): global wire order, which has to contain all wire labels in ``wires``, but can also
contain additional labels
sparse_format (str): if ``mat`` is a SciPy sparse matrix then this is the string representing the
@@ -101,11 +101,18 @@ def expand_matrix(mat, wires, wire_order=None, sparse_format="csr"):
[0., 0., 1., 0.]])
"""
+ wires = Wires(wires)
+
+ if wires:
+ float_dim = qml.math.shape(mat)[-1] ** (1 / (len(wires)))
+ qudit_dim = int(qml.math.round(float_dim))
+ else:
+ qudit_dim = 2 # if no wires, just assume qubit
if (wire_order is None) or (wire_order == wires):
return mat
- if not wires and qml.math.shape(mat) == (2, 2):
+ if not wires and qml.math.shape(mat) == (qudit_dim, qudit_dim):
# global phase
wires = wire_order[0:1]
@@ -118,8 +125,8 @@ def expand_matrix(mat, wires, wire_order=None, sparse_format="csr"):
def eye_interface(dim):
if interface == "scipy":
- return eye(2**dim, format="coo")
- return qml.math.cast_like(qml.math.eye(2**dim, like=interface), mat)
+ return eye(qudit_dim**dim, format="coo")
+ return qml.math.cast_like(qml.math.eye(qudit_dim**dim, like=interface), mat)
def kron_interface(mat1, mat2):
if interface == "scipy":
@@ -154,7 +161,9 @@ def kron_interface(mat1, mat2):
if interface == "scipy":
mat = _permute_sparse_matrix(mat, expanded_wires, subset_wire_order)
else:
- mat = _permute_dense_matrix(mat, expanded_wires, subset_wire_order, batch_dim)
+ mat = _permute_dense_matrix(
+ mat, expanded_wires, subset_wire_order, batch_dim, qudit_dim=qudit_dim
+ )
# expand the matrix even further if needed
if len(expanded_wires) < len(wire_order):
@@ -201,7 +210,7 @@ def _permute_sparse_matrix(matrix, wires, wire_order):
return matrix
-def _permute_dense_matrix(matrix, wires, wire_order, batch_dim):
+def _permute_dense_matrix(matrix, wires, wire_order, batch_dim, qudit_dim: int = 2):
"""Permute the matrix to match the wires given in `wire_order`.
Args:
@@ -228,12 +237,14 @@ def _permute_dense_matrix(matrix, wires, wire_order, batch_dim):
# reshape matrix to match wire values e.g. mat[0, 0, 0, 0] = <00|mat|00>
# with this reshape we can easily swap wires
- shape = [batch_dim] + [2] * (num_wires * 2) if batch_dim else [2] * (num_wires * 2)
+ shape = (
+ [batch_dim] + [qudit_dim] * (num_wires * 2) if batch_dim else [qudit_dim] * (num_wires * 2)
+ )
matrix = qml.math.reshape(matrix, shape)
# transpose matrix
matrix = qml.math.transpose(matrix, axes=perm)
# reshape back
- shape = [batch_dim] + [2**num_wires] * 2 if batch_dim else [2**num_wires] * 2
+ shape = [batch_dim] + [qudit_dim**num_wires] * 2 if batch_dim else [qudit_dim**num_wires] * 2
return qml.math.reshape(matrix, shape)
diff --git a/tests/math/test_matrix_manipulation.py b/tests/math/test_matrix_manipulation.py
index 5c8be485909..d6b4029a0ca 100644
--- a/tests/math/test_matrix_manipulation.py
+++ b/tests/math/test_matrix_manipulation.py
@@ -22,6 +22,7 @@
import pennylane as qml
from pennylane import numpy as pnp
+from pennylane.math import expand_matrix
# Define a list of dtypes to test
dtypes = ["complex64", "complex128"]
@@ -545,6 +546,70 @@ def compute_matrix():
assert np.allclose(op.matrix(wire_order=[0, 1, 2]), expanded_matrix, atol=tol)
+class TestExpandMatrixQutrit:
+
+ def test_add_wire_at_end(self):
+ """Test that expand_matrix works on qutrit matrices when an additional wire is added at the end."""
+
+ mat = np.reshape(np.arange(9), (3, 3))
+
+ new_mat = expand_matrix(mat, 0, (0, 1))
+ assert qml.math.allclose(new_mat, np.kron(mat, np.eye(3)))
+
+ def test_add_wire_at_start(self):
+ """Test that expand_matrix works on qutrit matrices when an additional wire is added at the start."""
+
+ mat = np.reshape(np.arange(9), (3, 3))
+ new_mat = expand_matrix(mat, 0, (1, 0))
+ assert qml.math.allclose(new_mat, np.kron(np.eye(3), mat))
+
+ def test_wire_permutation(self):
+ """Test that wires can be permuted."""
+ m1 = np.reshape(np.arange(81), (9, 9))
+ m2 = expand_matrix(m1, (0, 1), (1, 0))
+
+ # states across row are 00, 01, 02, 10, 11, 12, 20, 21, 22
+ # extract out right qubit state with mod
+ m1_wire_zero = m1 % 3
+ m2_wire_zero = m2 % 3
+
+ # extract out left qubit state with floor then mod
+ m1_wire_one = np.floor(m1 / 3) % 3
+ m2_wire_one = np.floor(m2 / 3) % 3
+
+ assert qml.math.allclose(m1_wire_zero, m2_wire_one)
+ assert qml.math.allclose(m1_wire_one, m2_wire_zero)
+
+ # check columns also switched
+ # now matrix numbers indicate row number
+ m1p = np.floor(m1 / 9)
+ m2p = np.floor(m2 / 9)
+
+ # states across column are 00, 01, 02, 10, 11, 12, 20, 21, 22
+ # extract out right qubit state with mod
+ m1_wire_zerop = m1p % 3
+ m2_wire_zerop = m2p % 3
+
+ # extract out left qubit state with floor then mod
+ m1_wire_onep = np.floor(m1p / 3) % 3
+ m2_wire_onep = np.floor(m2p / 3) % 3
+
+ assert qml.math.allclose(m1_wire_zerop, m2_wire_onep)
+ assert qml.math.allclose(m1_wire_onep, m2_wire_zerop)
+
+ def test_adding_wire_in_middle(self):
+ """Test that expand_matrix can add an identity wire in the middle of a two qutrit matrix."""
+
+ m1 = np.reshape(np.arange(9), (3, 3))
+ m2 = np.reshape(np.arange(9, 18), (3, 3))
+ m3 = np.kron(m1, m2)
+
+ m3_added_wire = expand_matrix(m3, (0, 1), (1, 2, 0))
+ m3_kron = np.kron(np.kron(m2, np.eye(3)), m1)
+
+ assert qml.math.allclose(m3_added_wire, m3_kron)
+
+
class TestExpandMatrixSparse:
"""Tests for the _sparse_expand_matrix function."""