Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

qml.wire.Wires accepts JAX arrays #6312

Merged
merged 26 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
ab9c274
E.C. [ci skip]
PietropaoloFrisoni Sep 30, 2024
63caa8e
First test with CI
PietropaoloFrisoni Sep 30, 2024
4d311cb
Intercepting tracers
PietropaoloFrisoni Sep 30, 2024
29348f8
Disregarding tracers for now
PietropaoloFrisoni Sep 30, 2024
7f7dde0
FIxing case without jax installed
PietropaoloFrisoni Oct 1, 2024
449a128
Hopefully this works
PietropaoloFrisoni Oct 1, 2024
573fb86
Adding pragma no cover. TODO: add tests
PietropaoloFrisoni Oct 1, 2024
bd4d54a
Test with tests
PietropaoloFrisoni Oct 1, 2024
99ac118
Solving errors for tests with JAX not installed (?)
PietropaoloFrisoni Oct 1, 2024
a70ea4b
Moving tests in a separate file
PietropaoloFrisoni Oct 1, 2024
ecfb9dc
Updating changelog and adding tests
PietropaoloFrisoni Oct 1, 2024
352d988
Move check outside the _process function
PietropaoloFrisoni Oct 1, 2024
2ab99a2
Using `qml.math` to detect JAX interface
PietropaoloFrisoni Oct 2, 2024
e18af4b
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Oct 2, 2024
48d92fb
Importorskip inside the class
PietropaoloFrisoni Oct 3, 2024
8e5cce6
Merge branch 'master' into Wires_accept_JAX_array
PietropaoloFrisoni Oct 3, 2024
d25eed8
Other solution
PietropaoloFrisoni Oct 3, 2024
cdc9b80
Merge branch 'Wires_accept_JAX_array' of https://github.com/PennyLane…
PietropaoloFrisoni Oct 3, 2024
ff61eab
Other solution
PietropaoloFrisoni Oct 3, 2024
e6674eb
Other solution
PietropaoloFrisoni Oct 3, 2024
f08baa6
Final solution
PietropaoloFrisoni Oct 3, 2024
2fb858f
Repetition in test
PietropaoloFrisoni Oct 3, 2024
45f7cf0
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Oct 3, 2024
4c61871
Creating one unique PR
PietropaoloFrisoni Oct 7, 2024
94a776f
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Oct 7, 2024
0b6346e
Merge branch 'master' into Wires_accept_JAX_array
PietropaoloFrisoni Oct 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@

<h4>Capturing and representing hybrid programs</h4>

* `qml.wires.Wires` now accepts JAX arrays as input.
[(#6312)](https://github.com/PennyLaneAI/pennylane/pull/6312)

* Differentiation of hybrid programs via `qml.grad` and `qml.jacobian` can now be captured
into plxpr. When evaluating a captured `qml.grad` (`qml.jacobian`) instruction, it will
dispatch to `jax.grad` (`jax.jacobian`), which differs from the Autograd implementation
Expand Down
12 changes: 12 additions & 0 deletions pennylane/wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import functools
import itertools
from collections.abc import Hashable, Iterable, Sequence
from importlib import import_module, util
from typing import Union

import numpy as np
Expand Down Expand Up @@ -50,6 +51,17 @@ def _process(wires):
# of considering the elements of iterables as wire labels.
wires = [wires]

if util.find_spec("jax") is not None:
jax = import_module("jax")

if isinstance(wires, jax.numpy.ndarray) and not isinstance(wires, jax.core.Tracer):
wires = tuple(wires.tolist() if wires.ndim > 0 else (wires.item(),))
PietropaoloFrisoni marked this conversation as resolved.
Show resolved Hide resolved
else:
if "jax" in str(type(wires)):
raise ImportError( # pragma: no cover
"JAX is required to process wires that are JAX arrays. "
"You can install it using: pip install jax jaxlib"
)
PietropaoloFrisoni marked this conversation as resolved.
Show resolved Hide resolved
try:
# Use tuple conversion as a check for whether `wires` can be iterated over.
# Note, this is not the same as `isinstance(wires, Iterable)` which would
Expand Down
85 changes: 85 additions & 0 deletions tests/capture/test_wires_jax.py
PietropaoloFrisoni marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests the support for JAX arrays in the ``Wires`` class.
"""
import pytest

from pennylane.wires import WireError, Wires

jax = pytest.importorskip("jax")

pytestmark = pytest.mark.jax


class TestWiresJax:
"""Tests the support for JAX arrays in the ``Wires`` class."""

@pytest.mark.parametrize(
"iterable, expected",
[
(jax.numpy.array([0, 1, 2]), (0, 1, 2)),
(jax.numpy.array([0]), (0,)),
(jax.numpy.array(0), (0,)),
(jax.numpy.array([]), ()),
],
)
def test_creation_from_jax_array(self, iterable, expected):
"""Tests that a Wires object can be created from a JAX array."""

wires = Wires(iterable)
assert wires.labels == expected

@pytest.mark.parametrize(
"input",
[
[jax.numpy.array([0, 1, 2]), jax.numpy.array([3, 4])],
[jax.numpy.array([0, 1, 2]), 3],
jax.numpy.array([[0, 1, 2]]),
],
)
def test_error_for_incorrect_jax_arrays(self, input):
"""Tests that a Wires object cannot be created from incorrect JAX arrays."""

with pytest.raises(WireError, match="Wires must be hashable"):
Wires(input)

@pytest.mark.parametrize("iterable", [jax.numpy.array([4, 1, 1, 3]), jax.numpy.array([0, 0])])
def test_error_for_repeated_wires_jax(self, iterable):
"""Tests that a Wires object cannot be created from a JAX array with repeated indices."""

with pytest.raises(WireError, match="Wires must be unique"):
Wires(iterable)

def test_array_representation_jax(self):
"""Tests that Wires object has an array representation with JAX."""

wires = Wires([4, 0, 1])
array = jax.numpy.array(wires.labels)
assert isinstance(array, jax.numpy.ndarray)
assert array.shape == (3,)
for w1, w2 in zip(array, jax.numpy.array([4, 0, 1])):
assert w1 == w2

@pytest.mark.parametrize(
"source", [jax.numpy.array([0, 1, 2]), jax.numpy.array([0]), jax.numpy.array(0)]
)
def test_jax_wires_pytree(self, source):
"""Test that Wires class supports the PyTree flattening interface with JAX arrays."""

wires = Wires(source)
wires_flat, tree = jax.tree_util.tree_flatten(wires)
wires2 = jax.tree_util.tree_unflatten(tree, wires_flat)
assert isinstance(wires2, Wires), f"{wires2} is not Wires"
assert wires == wires2, f"{wires} != {wires2}"
Loading