Skip to content

Commit

Permalink
fix expand_matrix for qudit matrices (#6398)
Browse files Browse the repository at this point in the history
**Context:**

We were incorrectly manipulating qutrit matrices.

**Description of the Change:**

Allow `expand_matrix` to work on higher qudit matrices.

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-75501] Fixes #6368
  • Loading branch information
albi3ro authored and mudit2812 committed Nov 11, 2024
1 parent b48a8be commit 4803e11
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 8 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,9 @@

<h3>Bug fixes 🐛</h3>

* Fixes `qml.math.expand_matrix` for qutrit and arbitrary qudit operators.
[(#6398)](https://github.com/PennyLaneAI/pennylane/pull/6398/)

* `MeasurementValue` now raises an error when it is used as a boolean.
[(#6386)](https://github.com/PennyLaneAI/pennylane/pull/6386)

Expand Down
27 changes: 19 additions & 8 deletions pennylane/math/matrix_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def expand_matrix(mat, wires, wire_order=None, sparse_format="csr"):
Args:
mat (tensor_like): matrix to expand
wires (Iterable): wires determining the subspace that ``mat`` acts on; a matrix of
dimension :math:`2^n` acts on a subspace of :math:`n` wires
dimension :math:`D^n` acts on a subspace of :math:`n` wires, where :math:`D` is the qudit dimension (2).
wire_order (Iterable): global wire order, which has to contain all wire labels in ``wires``, but can also
contain additional labels
sparse_format (str): if ``mat`` is a SciPy sparse matrix then this is the string representing the
Expand Down Expand Up @@ -101,11 +101,18 @@ def expand_matrix(mat, wires, wire_order=None, sparse_format="csr"):
[0., 0., 1., 0.]])
"""
wires = Wires(wires)

if wires:
float_dim = qml.math.shape(mat)[-1] ** (1 / (len(wires)))
qudit_dim = int(qml.math.round(float_dim))
else:
qudit_dim = 2 # if no wires, just assume qubit

if (wire_order is None) or (wire_order == wires):
return mat

if not wires and qml.math.shape(mat) == (2, 2):
if not wires and qml.math.shape(mat) == (qudit_dim, qudit_dim):
# global phase
wires = wire_order[0:1]

Expand All @@ -118,8 +125,8 @@ def expand_matrix(mat, wires, wire_order=None, sparse_format="csr"):

def eye_interface(dim):
if interface == "scipy":
return eye(2**dim, format="coo")
return qml.math.cast_like(qml.math.eye(2**dim, like=interface), mat)
return eye(qudit_dim**dim, format="coo")
return qml.math.cast_like(qml.math.eye(qudit_dim**dim, like=interface), mat)

def kron_interface(mat1, mat2):
if interface == "scipy":
Expand Down Expand Up @@ -154,7 +161,9 @@ def kron_interface(mat1, mat2):
if interface == "scipy":
mat = _permute_sparse_matrix(mat, expanded_wires, subset_wire_order)
else:
mat = _permute_dense_matrix(mat, expanded_wires, subset_wire_order, batch_dim)
mat = _permute_dense_matrix(
mat, expanded_wires, subset_wire_order, batch_dim, qudit_dim=qudit_dim
)

# expand the matrix even further if needed
if len(expanded_wires) < len(wire_order):
Expand Down Expand Up @@ -201,7 +210,7 @@ def _permute_sparse_matrix(matrix, wires, wire_order):
return matrix


def _permute_dense_matrix(matrix, wires, wire_order, batch_dim):
def _permute_dense_matrix(matrix, wires, wire_order, batch_dim, qudit_dim: int = 2):
"""Permute the matrix to match the wires given in `wire_order`.
Args:
Expand All @@ -228,12 +237,14 @@ def _permute_dense_matrix(matrix, wires, wire_order, batch_dim):

# reshape matrix to match wire values e.g. mat[0, 0, 0, 0] = <00|mat|00>
# with this reshape we can easily swap wires
shape = [batch_dim] + [2] * (num_wires * 2) if batch_dim else [2] * (num_wires * 2)
shape = (
[batch_dim] + [qudit_dim] * (num_wires * 2) if batch_dim else [qudit_dim] * (num_wires * 2)
)
matrix = qml.math.reshape(matrix, shape)
# transpose matrix
matrix = qml.math.transpose(matrix, axes=perm)
# reshape back
shape = [batch_dim] + [2**num_wires] * 2 if batch_dim else [2**num_wires] * 2
shape = [batch_dim] + [qudit_dim**num_wires] * 2 if batch_dim else [qudit_dim**num_wires] * 2
return qml.math.reshape(matrix, shape)


Expand Down
65 changes: 65 additions & 0 deletions tests/math/test_matrix_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pennylane as qml
from pennylane import numpy as pnp
from pennylane.math import expand_matrix

# Define a list of dtypes to test
dtypes = ["complex64", "complex128"]
Expand Down Expand Up @@ -545,6 +546,70 @@ def compute_matrix():
assert np.allclose(op.matrix(wire_order=[0, 1, 2]), expanded_matrix, atol=tol)


class TestExpandMatrixQutrit:

def test_add_wire_at_end(self):
"""Test that expand_matrix works on qutrit matrices when an additional wire is added at the end."""

mat = np.reshape(np.arange(9), (3, 3))

new_mat = expand_matrix(mat, 0, (0, 1))
assert qml.math.allclose(new_mat, np.kron(mat, np.eye(3)))

def test_add_wire_at_start(self):
"""Test that expand_matrix works on qutrit matrices when an additional wire is added at the start."""

mat = np.reshape(np.arange(9), (3, 3))
new_mat = expand_matrix(mat, 0, (1, 0))
assert qml.math.allclose(new_mat, np.kron(np.eye(3), mat))

def test_wire_permutation(self):
"""Test that wires can be permuted."""
m1 = np.reshape(np.arange(81), (9, 9))
m2 = expand_matrix(m1, (0, 1), (1, 0))

# states across row are 00, 01, 02, 10, 11, 12, 20, 21, 22
# extract out right qubit state with mod
m1_wire_zero = m1 % 3
m2_wire_zero = m2 % 3

# extract out left qubit state with floor then mod
m1_wire_one = np.floor(m1 / 3) % 3
m2_wire_one = np.floor(m2 / 3) % 3

assert qml.math.allclose(m1_wire_zero, m2_wire_one)
assert qml.math.allclose(m1_wire_one, m2_wire_zero)

# check columns also switched
# now matrix numbers indicate row number
m1p = np.floor(m1 / 9)
m2p = np.floor(m2 / 9)

# states across column are 00, 01, 02, 10, 11, 12, 20, 21, 22
# extract out right qubit state with mod
m1_wire_zerop = m1p % 3
m2_wire_zerop = m2p % 3

# extract out left qubit state with floor then mod
m1_wire_onep = np.floor(m1p / 3) % 3
m2_wire_onep = np.floor(m2p / 3) % 3

assert qml.math.allclose(m1_wire_zerop, m2_wire_onep)
assert qml.math.allclose(m1_wire_onep, m2_wire_zerop)

def test_adding_wire_in_middle(self):
"""Test that expand_matrix can add an identity wire in the middle of a two qutrit matrix."""

m1 = np.reshape(np.arange(9), (3, 3))
m2 = np.reshape(np.arange(9, 18), (3, 3))
m3 = np.kron(m1, m2)

m3_added_wire = expand_matrix(m3, (0, 1), (1, 2, 0))
m3_kron = np.kron(np.kron(m2, np.eye(3)), m1)

assert qml.math.allclose(m3_added_wire, m3_kron)


class TestExpandMatrixSparse:
"""Tests for the _sparse_expand_matrix function."""

Expand Down

0 comments on commit 4803e11

Please sign in to comment.