From ab9c2747487c12270a5b051537b440113a27dd57 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 30 Sep 2024 14:58:11 -0400 Subject: [PATCH 01/20] E.C. [ci skip] From 63caa8e46d41fad37aa0e048984c9d8e630ab67a Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 30 Sep 2024 18:49:02 -0400 Subject: [PATCH 02/20] First test with CI --- pennylane/wires.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pennylane/wires.py b/pennylane/wires.py index 75e3354e89e..e9ccc5e4bf1 100644 --- a/pennylane/wires.py +++ b/pennylane/wires.py @@ -50,6 +50,20 @@ def _process(wires): # of considering the elements of iterables as wire labels. wires = [wires] + if "jax" in str(type(wires)): + try: + # pylint: disable=import-outside-toplevel + import jax + + if isinstance(wires, jax.numpy.ndarray): + wires = tuple(wires.tolist()) + + except ImportError as exc: + raise ImportError( + "JAX is required to process this input. " + "You can install jax via: pip install jax jaxlib" + ) from exc + try: # Use tuple conversion as a check for whether `wires` can be iterated over. # Note, this is not the same as `isinstance(wires, Iterable)` which would From 4d311cb478727e6d6e7fa086561365596094b466 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 30 Sep 2024 19:14:21 -0400 Subject: [PATCH 03/20] Intercepting tracers --- pennylane/wires.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/wires.py b/pennylane/wires.py index e9ccc5e4bf1..6598edeca3e 100644 --- a/pennylane/wires.py +++ b/pennylane/wires.py @@ -56,7 +56,7 @@ def _process(wires): import jax if isinstance(wires, jax.numpy.ndarray): - wires = tuple(wires.tolist()) + wires = wires if isinstance(wires, jax.core.Tracer) else tuple(wires.tolist()) except ImportError as exc: raise ImportError( From 29348f8d37fc1cb7c737e36775ceeba02797a1f1 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 30 Sep 2024 19:54:22 -0400 Subject: [PATCH 04/20] Disregarding tracers for now --- pennylane/wires.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/pennylane/wires.py b/pennylane/wires.py index 6598edeca3e..1a40a3bcda8 100644 --- a/pennylane/wires.py +++ b/pennylane/wires.py @@ -50,19 +50,18 @@ def _process(wires): # of considering the elements of iterables as wire labels. wires = [wires] - if "jax" in str(type(wires)): - try: - # pylint: disable=import-outside-toplevel - import jax + try: + # pylint: disable=import-outside-toplevel + import jax - if isinstance(wires, jax.numpy.ndarray): - wires = wires if isinstance(wires, jax.core.Tracer) else tuple(wires.tolist()) + if isinstance(wires, jax.numpy.ndarray) and not isinstance(wires, jax.core.Tracer): + wires = tuple(wires.tolist() if wires.ndim > 0 else (wires.item(),)) - except ImportError as exc: - raise ImportError( - "JAX is required to process this input. " - "You can install jax via: pip install jax jaxlib" - ) from exc + except ImportError as exc: + raise ImportError( + "JAX is required to process this input. " + "Please install it via: pip install jax jaxlib" + ) from exc try: # Use tuple conversion as a check for whether `wires` can be iterated over. From 7f7dde0a54b6572ada849474186969b6c119677f Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 30 Sep 2024 20:12:09 -0400 Subject: [PATCH 05/20] FIxing case without jax installed --- pennylane/wires.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pennylane/wires.py b/pennylane/wires.py index 1a40a3bcda8..cd58e854ed0 100644 --- a/pennylane/wires.py +++ b/pennylane/wires.py @@ -18,6 +18,7 @@ import itertools from collections.abc import Hashable, Iterable, Sequence from typing import Union +from importlib import import_module, util import numpy as np @@ -50,19 +51,18 @@ def _process(wires): # of considering the elements of iterables as wire labels. wires = [wires] - try: - # pylint: disable=import-outside-toplevel - import jax + jax_spec = util.find_spec("jax") + if jax_spec is not None: + jax = import_module("jax") if isinstance(wires, jax.numpy.ndarray) and not isinstance(wires, jax.core.Tracer): wires = tuple(wires.tolist() if wires.ndim > 0 else (wires.item(),)) - - except ImportError as exc: - raise ImportError( - "JAX is required to process this input. " - "Please install it via: pip install jax jaxlib" - ) from exc - + # TODO: something like qml.wires.Wires(jax.numpy.array(2)) should not work since it is not hashable + else: + if isinstance(wires, jax.numpy.ndarray): + raise ImportError( + "JAX is required to process this input. Please install it via: pip install jax jaxlib" + ) try: # Use tuple conversion as a check for whether `wires` can be iterated over. # Note, this is not the same as `isinstance(wires, Iterable)` which would From 449a12837a78c2520247b6a747c205f20417e286 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 30 Sep 2024 20:26:41 -0400 Subject: [PATCH 06/20] Hopefully this works --- pennylane/wires.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pennylane/wires.py b/pennylane/wires.py index cd58e854ed0..9e39939cf7e 100644 --- a/pennylane/wires.py +++ b/pennylane/wires.py @@ -17,8 +17,8 @@ import functools import itertools from collections.abc import Hashable, Iterable, Sequence -from typing import Union from importlib import import_module, util +from typing import Union import numpy as np @@ -51,17 +51,16 @@ def _process(wires): # of considering the elements of iterables as wire labels. wires = [wires] - jax_spec = util.find_spec("jax") - if jax_spec is not None: + if util.find_spec("jax") is not None: jax = import_module("jax") if isinstance(wires, jax.numpy.ndarray) and not isinstance(wires, jax.core.Tracer): wires = tuple(wires.tolist() if wires.ndim > 0 else (wires.item(),)) - # TODO: something like qml.wires.Wires(jax.numpy.array(2)) should not work since it is not hashable else: - if isinstance(wires, jax.numpy.ndarray): + if "jax" in str(type(wires)): raise ImportError( - "JAX is required to process this input. Please install it via: pip install jax jaxlib" + "JAX is required to process wires that are JAX arrays. " + "You can install it using: pip install jax jaxlib" ) try: # Use tuple conversion as a check for whether `wires` can be iterated over. From 573fb86810418667f258d154c68f2c7b03d994fe Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Tue, 1 Oct 2024 11:08:08 -0400 Subject: [PATCH 07/20] Adding pragma no cover. TODO: add tests [ci skip] --- pennylane/wires.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/wires.py b/pennylane/wires.py index 9e39939cf7e..349c3e976af 100644 --- a/pennylane/wires.py +++ b/pennylane/wires.py @@ -58,7 +58,7 @@ def _process(wires): wires = tuple(wires.tolist() if wires.ndim > 0 else (wires.item(),)) else: if "jax" in str(type(wires)): - raise ImportError( + raise ImportError( # pragma: no cover "JAX is required to process wires that are JAX arrays. " "You can install it using: pip install jax jaxlib" ) From bd4d54adaf8f7323122f0459630475a6f441e76c Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Tue, 1 Oct 2024 12:41:42 -0400 Subject: [PATCH 08/20] Test with tests --- tests/test_wires.py | 46 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/test_wires.py b/tests/test_wires.py index 5ceb6475bb9..c4f84094d27 100644 --- a/tests/test_wires.py +++ b/tests/test_wires.py @@ -487,3 +487,49 @@ def test_complex_operation(self): expected = Wires([0, 1, 2, 3, 4, 5, 6, 7]) assert result == expected + + +@pytest.mark.jax +class TestWiresJax: + """Tests the support for JAX arrays in the ``Wires`` class.""" + + import jax + + @pytest.mark.parametrize( + "iterable, expected", + [ + (jax.numpy.array([0, 1, 2]), (0, 1, 2)), + (jax.numpy.array([0]), (0,)), + (jax.numpy.array([]), ()), + ], + ) + def test_creation_from_jax_array(self, iterable, expected): + """Tests that a Wires object can be created from a JAX array.""" + + wires = Wires(iterable) + assert wires.labels == expected + + @pytest.mark.parametrize("input", [[jax.numpy.array([0, 1, 2]), jax.numpy.array([3, 4])]]) + def test_error_for_incorrect_wire_types(self, input): + """Tests that a Wires object cannot be created from a list of JAX arrays.""" + + with pytest.raises(WireError, match="Wires must be hashable"): + Wires(input) + + @pytest.mark.parametrize("iterable", [jax.numpy.array([4, 1, 1, 3]), jax.numpy.array([0, 0])]) + def test_error_for_repeated_wires_jax(self, iterable): + """Tests that a Wires object cannot be created from a JAX array with repeated indices.""" + + with pytest.raises(WireError, match="Wires must be unique"): + Wires(iterable) + + def test_array_representation_jax(self): + """Tests that Wires object has an array representation with JAX.""" + import jax + + wires = Wires([4, 0, 1]) + array = jax.numpy.array(wires.labels) + assert isinstance(array, jax.numpy.ndarray) + assert array.shape == (3,) + for w1, w2 in zip(array, jax.numpy.array([4, 0, 1])): + assert w1 == w2 From 99ac118b714806630a905fa5ef34ec32e59f5773 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Tue, 1 Oct 2024 14:54:53 -0400 Subject: [PATCH 09/20] Solving errors for tests with JAX not installed (?) --- tests/test_wires.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/test_wires.py b/tests/test_wires.py index c4f84094d27..1f7ff99ac58 100644 --- a/tests/test_wires.py +++ b/tests/test_wires.py @@ -493,8 +493,12 @@ def test_complex_operation(self): class TestWiresJax: """Tests the support for JAX arrays in the ``Wires`` class.""" - import jax + try: + import jax + except ImportError: + pytest.skip(reason="JAX not installed") + @pytest.mark.jax @pytest.mark.parametrize( "iterable, expected", [ @@ -509,6 +513,7 @@ def test_creation_from_jax_array(self, iterable, expected): wires = Wires(iterable) assert wires.labels == expected + @pytest.mark.jax @pytest.mark.parametrize("input", [[jax.numpy.array([0, 1, 2]), jax.numpy.array([3, 4])]]) def test_error_for_incorrect_wire_types(self, input): """Tests that a Wires object cannot be created from a list of JAX arrays.""" @@ -516,6 +521,7 @@ def test_error_for_incorrect_wire_types(self, input): with pytest.raises(WireError, match="Wires must be hashable"): Wires(input) + @pytest.mark.jax @pytest.mark.parametrize("iterable", [jax.numpy.array([4, 1, 1, 3]), jax.numpy.array([0, 0])]) def test_error_for_repeated_wires_jax(self, iterable): """Tests that a Wires object cannot be created from a JAX array with repeated indices.""" @@ -523,6 +529,7 @@ def test_error_for_repeated_wires_jax(self, iterable): with pytest.raises(WireError, match="Wires must be unique"): Wires(iterable) + @pytest.mark.jax def test_array_representation_jax(self): """Tests that Wires object has an array representation with JAX.""" import jax From a70ea4b4f546ccdc46a37e108b0f1e47e0266d07 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Tue, 1 Oct 2024 15:10:51 -0400 Subject: [PATCH 10/20] Moving tests in a separate file --- tests/capture/test_wires_jax.py | 65 +++++++++++++++++++++++++++++++++ tests/test_wires.py | 53 --------------------------- 2 files changed, 65 insertions(+), 53 deletions(-) create mode 100644 tests/capture/test_wires_jax.py diff --git a/tests/capture/test_wires_jax.py b/tests/capture/test_wires_jax.py new file mode 100644 index 00000000000..4216da249d3 --- /dev/null +++ b/tests/capture/test_wires_jax.py @@ -0,0 +1,65 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests the support for JAX arrays in the ``Wires`` class. +""" +import pytest + +from pennylane.wires import WireError, Wires + +jax = pytest.importorskip("jax") + +pytestmark = pytest.mark.jax + + +class TestWiresJax: + """Tests the support for JAX arrays in the ``Wires`` class.""" + + @pytest.mark.parametrize( + "iterable, expected", + [ + (jax.numpy.array([0, 1, 2]), (0, 1, 2)), + (jax.numpy.array([0]), (0,)), + (jax.numpy.array([]), ()), + ], + ) + def test_creation_from_jax_array(self, iterable, expected): + """Tests that a Wires object can be created from a JAX array.""" + + wires = Wires(iterable) + assert wires.labels == expected + + @pytest.mark.parametrize("input", [[jax.numpy.array([0, 1, 2]), jax.numpy.array([3, 4])]]) + def test_error_for_incorrect_wire_types(self, input): + """Tests that a Wires object cannot be created from a list of JAX arrays.""" + + with pytest.raises(WireError, match="Wires must be hashable"): + Wires(input) + + @pytest.mark.parametrize("iterable", [jax.numpy.array([4, 1, 1, 3]), jax.numpy.array([0, 0])]) + def test_error_for_repeated_wires_jax(self, iterable): + """Tests that a Wires object cannot be created from a JAX array with repeated indices.""" + + with pytest.raises(WireError, match="Wires must be unique"): + Wires(iterable) + + def test_array_representation_jax(self): + """Tests that Wires object has an array representation with JAX.""" + + wires = Wires([4, 0, 1]) + array = jax.numpy.array(wires.labels) + assert isinstance(array, jax.numpy.ndarray) + assert array.shape == (3,) + for w1, w2 in zip(array, jax.numpy.array([4, 0, 1])): + assert w1 == w2 diff --git a/tests/test_wires.py b/tests/test_wires.py index 1f7ff99ac58..5ceb6475bb9 100644 --- a/tests/test_wires.py +++ b/tests/test_wires.py @@ -487,56 +487,3 @@ def test_complex_operation(self): expected = Wires([0, 1, 2, 3, 4, 5, 6, 7]) assert result == expected - - -@pytest.mark.jax -class TestWiresJax: - """Tests the support for JAX arrays in the ``Wires`` class.""" - - try: - import jax - except ImportError: - pytest.skip(reason="JAX not installed") - - @pytest.mark.jax - @pytest.mark.parametrize( - "iterable, expected", - [ - (jax.numpy.array([0, 1, 2]), (0, 1, 2)), - (jax.numpy.array([0]), (0,)), - (jax.numpy.array([]), ()), - ], - ) - def test_creation_from_jax_array(self, iterable, expected): - """Tests that a Wires object can be created from a JAX array.""" - - wires = Wires(iterable) - assert wires.labels == expected - - @pytest.mark.jax - @pytest.mark.parametrize("input", [[jax.numpy.array([0, 1, 2]), jax.numpy.array([3, 4])]]) - def test_error_for_incorrect_wire_types(self, input): - """Tests that a Wires object cannot be created from a list of JAX arrays.""" - - with pytest.raises(WireError, match="Wires must be hashable"): - Wires(input) - - @pytest.mark.jax - @pytest.mark.parametrize("iterable", [jax.numpy.array([4, 1, 1, 3]), jax.numpy.array([0, 0])]) - def test_error_for_repeated_wires_jax(self, iterable): - """Tests that a Wires object cannot be created from a JAX array with repeated indices.""" - - with pytest.raises(WireError, match="Wires must be unique"): - Wires(iterable) - - @pytest.mark.jax - def test_array_representation_jax(self): - """Tests that Wires object has an array representation with JAX.""" - import jax - - wires = Wires([4, 0, 1]) - array = jax.numpy.array(wires.labels) - assert isinstance(array, jax.numpy.ndarray) - assert array.shape == (3,) - for w1, w2 in zip(array, jax.numpy.array([4, 0, 1])): - assert w1 == w2 From ecfb9dc017e23d87da472869f4040918aa9caf87 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Tue, 1 Oct 2024 16:09:30 -0400 Subject: [PATCH 11/20] Updating changelog and adding tests --- doc/releases/changelog-dev.md | 3 +++ tests/capture/test_wires_jax.py | 26 +++++++++++++++++++++++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index b9229c68dee..d401d025950 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -45,6 +45,9 @@

Capturing and representing hybrid programs

+* `qml.wires.Wires` now accepts JAX arrays as input. + [(#6312)](https://github.com/PennyLaneAI/pennylane/pull/6312) + * Differentiation of hybrid programs via `qml.grad` and `qml.jacobian` can now be captured into plxpr. When evaluating a captured `qml.grad` (`qml.jacobian`) instruction, it will dispatch to `jax.grad` (`jax.jacobian`), which differs from the Autograd implementation diff --git a/tests/capture/test_wires_jax.py b/tests/capture/test_wires_jax.py index 4216da249d3..40f08093c48 100644 --- a/tests/capture/test_wires_jax.py +++ b/tests/capture/test_wires_jax.py @@ -31,6 +31,7 @@ class TestWiresJax: [ (jax.numpy.array([0, 1, 2]), (0, 1, 2)), (jax.numpy.array([0]), (0,)), + (jax.numpy.array(0), (0,)), (jax.numpy.array([]), ()), ], ) @@ -40,9 +41,16 @@ def test_creation_from_jax_array(self, iterable, expected): wires = Wires(iterable) assert wires.labels == expected - @pytest.mark.parametrize("input", [[jax.numpy.array([0, 1, 2]), jax.numpy.array([3, 4])]]) - def test_error_for_incorrect_wire_types(self, input): - """Tests that a Wires object cannot be created from a list of JAX arrays.""" + @pytest.mark.parametrize( + "input", + [ + [jax.numpy.array([0, 1, 2]), jax.numpy.array([3, 4])], + [jax.numpy.array([0, 1, 2]), 3], + jax.numpy.array([[0, 1, 2]]), + ], + ) + def test_error_for_incorrect_jax_arrays(self, input): + """Tests that a Wires object cannot be created from incorrect JAX arrays.""" with pytest.raises(WireError, match="Wires must be hashable"): Wires(input) @@ -63,3 +71,15 @@ def test_array_representation_jax(self): assert array.shape == (3,) for w1, w2 in zip(array, jax.numpy.array([4, 0, 1])): assert w1 == w2 + + @pytest.mark.parametrize( + "source", [jax.numpy.array([0, 1, 2]), jax.numpy.array([0]), jax.numpy.array(0)] + ) + def test_jax_wires_pytree(self, source): + """Test that Wires class supports the PyTree flattening interface with JAX arrays.""" + + wires = Wires(source) + wires_flat, tree = jax.tree_util.tree_flatten(wires) + wires2 = jax.tree_util.tree_unflatten(tree, wires_flat) + assert isinstance(wires2, Wires), f"{wires2} is not Wires" + assert wires == wires2, f"{wires} != {wires2}" From 352d988ee5e1c598e40bc1611916d23891cfd96e Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Tue, 1 Oct 2024 18:37:11 -0400 Subject: [PATCH 12/20] Move check outside the _process function --- pennylane/wires.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/pennylane/wires.py b/pennylane/wires.py index 349c3e976af..127d252aca4 100644 --- a/pennylane/wires.py +++ b/pennylane/wires.py @@ -29,6 +29,14 @@ class WireError(Exception): """Exception raised by a :class:`~.pennylane.wires.Wire` object when it is unable to process wires.""" +if util.find_spec("jax") is not None: + jax = import_module("jax") + jax_available = True +else: + jax_available = False + jax = None + + def _process(wires): """Converts the input to a tuple of wire labels. @@ -51,9 +59,7 @@ def _process(wires): # of considering the elements of iterables as wire labels. wires = [wires] - if util.find_spec("jax") is not None: - jax = import_module("jax") - + if jax_available: if isinstance(wires, jax.numpy.ndarray) and not isinstance(wires, jax.core.Tracer): wires = tuple(wires.tolist() if wires.ndim > 0 else (wires.item(),)) else: From 2ab99a2a4d818f02597dba43935fb430ffc9fd7c Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Wed, 2 Oct 2024 10:11:56 -0400 Subject: [PATCH 13/20] Using `qml.math` to detect JAX interface --- pennylane/wires.py | 22 ++++------------------ tests/capture/test_wires_jax.py | 2 ++ 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/pennylane/wires.py b/pennylane/wires.py index 127d252aca4..7fbccb0c8fe 100644 --- a/pennylane/wires.py +++ b/pennylane/wires.py @@ -17,11 +17,11 @@ import functools import itertools from collections.abc import Hashable, Iterable, Sequence -from importlib import import_module, util from typing import Union import numpy as np +import pennylane as qml from pennylane.pytrees import register_pytree @@ -29,14 +29,6 @@ class WireError(Exception): """Exception raised by a :class:`~.pennylane.wires.Wire` object when it is unable to process wires.""" -if util.find_spec("jax") is not None: - jax = import_module("jax") - jax_available = True -else: - jax_available = False - jax = None - - def _process(wires): """Converts the input to a tuple of wire labels. @@ -59,15 +51,9 @@ def _process(wires): # of considering the elements of iterables as wire labels. wires = [wires] - if jax_available: - if isinstance(wires, jax.numpy.ndarray) and not isinstance(wires, jax.core.Tracer): - wires = tuple(wires.tolist() if wires.ndim > 0 else (wires.item(),)) - else: - if "jax" in str(type(wires)): - raise ImportError( # pragma: no cover - "JAX is required to process wires that are JAX arrays. " - "You can install it using: pip install jax jaxlib" - ) + if qml.math.get_interface(wires) == "jax" and not qml.math.is_abstract(wires): + wires = tuple(wires.tolist() if wires.ndim > 0 else (wires.item(),)) + try: # Use tuple conversion as a check for whether `wires` can be iterated over. # Note, this is not the same as `isinstance(wires, Iterable)` which would diff --git a/tests/capture/test_wires_jax.py b/tests/capture/test_wires_jax.py index 40f08093c48..ab46b1ec843 100644 --- a/tests/capture/test_wires_jax.py +++ b/tests/capture/test_wires_jax.py @@ -47,6 +47,8 @@ def test_creation_from_jax_array(self, iterable, expected): [jax.numpy.array([0, 1, 2]), jax.numpy.array([3, 4])], [jax.numpy.array([0, 1, 2]), 3], jax.numpy.array([[0, 1, 2]]), + jax.numpy.array([[[0, 1], [2, 3]]]), + jax.numpy.array([[[[0]]]]), ], ) def test_error_for_incorrect_jax_arrays(self, input): From 48d92fb8342d368f93042f2003c86b74b5139b78 Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Wed, 2 Oct 2024 22:26:01 -0400 Subject: [PATCH 14/20] Importorskip inside the class --- tests/capture/test_wires_jax.py | 87 --------------------------------- tests/test_wires.py | 74 ++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 87 deletions(-) delete mode 100644 tests/capture/test_wires_jax.py diff --git a/tests/capture/test_wires_jax.py b/tests/capture/test_wires_jax.py deleted file mode 100644 index ab46b1ec843..00000000000 --- a/tests/capture/test_wires_jax.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2024 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Tests the support for JAX arrays in the ``Wires`` class. -""" -import pytest - -from pennylane.wires import WireError, Wires - -jax = pytest.importorskip("jax") - -pytestmark = pytest.mark.jax - - -class TestWiresJax: - """Tests the support for JAX arrays in the ``Wires`` class.""" - - @pytest.mark.parametrize( - "iterable, expected", - [ - (jax.numpy.array([0, 1, 2]), (0, 1, 2)), - (jax.numpy.array([0]), (0,)), - (jax.numpy.array(0), (0,)), - (jax.numpy.array([]), ()), - ], - ) - def test_creation_from_jax_array(self, iterable, expected): - """Tests that a Wires object can be created from a JAX array.""" - - wires = Wires(iterable) - assert wires.labels == expected - - @pytest.mark.parametrize( - "input", - [ - [jax.numpy.array([0, 1, 2]), jax.numpy.array([3, 4])], - [jax.numpy.array([0, 1, 2]), 3], - jax.numpy.array([[0, 1, 2]]), - jax.numpy.array([[[0, 1], [2, 3]]]), - jax.numpy.array([[[[0]]]]), - ], - ) - def test_error_for_incorrect_jax_arrays(self, input): - """Tests that a Wires object cannot be created from incorrect JAX arrays.""" - - with pytest.raises(WireError, match="Wires must be hashable"): - Wires(input) - - @pytest.mark.parametrize("iterable", [jax.numpy.array([4, 1, 1, 3]), jax.numpy.array([0, 0])]) - def test_error_for_repeated_wires_jax(self, iterable): - """Tests that a Wires object cannot be created from a JAX array with repeated indices.""" - - with pytest.raises(WireError, match="Wires must be unique"): - Wires(iterable) - - def test_array_representation_jax(self): - """Tests that Wires object has an array representation with JAX.""" - - wires = Wires([4, 0, 1]) - array = jax.numpy.array(wires.labels) - assert isinstance(array, jax.numpy.ndarray) - assert array.shape == (3,) - for w1, w2 in zip(array, jax.numpy.array([4, 0, 1])): - assert w1 == w2 - - @pytest.mark.parametrize( - "source", [jax.numpy.array([0, 1, 2]), jax.numpy.array([0]), jax.numpy.array(0)] - ) - def test_jax_wires_pytree(self, source): - """Test that Wires class supports the PyTree flattening interface with JAX arrays.""" - - wires = Wires(source) - wires_flat, tree = jax.tree_util.tree_flatten(wires) - wires2 = jax.tree_util.tree_unflatten(tree, wires_flat) - assert isinstance(wires2, Wires), f"{wires2} is not Wires" - assert wires == wires2, f"{wires} != {wires2}" diff --git a/tests/test_wires.py b/tests/test_wires.py index 5ceb6475bb9..af82b9b4076 100644 --- a/tests/test_wires.py +++ b/tests/test_wires.py @@ -487,3 +487,77 @@ def test_complex_operation(self): expected = Wires([0, 1, 2, 3, 4, 5, 6, 7]) assert result == expected + + +class TestWiresJax: + """Tests the support for JAX arrays in the ``Wires`` class.""" + + jax = pytest.importorskip("jax") + + @pytest.mark.jax + @pytest.mark.parametrize( + "iterable, expected", + [ + (jax.numpy.array([0, 1, 2]), (0, 1, 2)), + (jax.numpy.array([0]), (0,)), + (jax.numpy.array(0), (0,)), + (jax.numpy.array([]), ()), + ], + ) + def test_creation_from_jax_array(self, iterable, expected): + """Tests that a Wires object can be created from a JAX array.""" + wires = Wires(iterable) + assert wires.labels == expected + + @pytest.mark.jax + @pytest.mark.parametrize( + "input", + [ + [jax.numpy.array([0, 1, 2]), jax.numpy.array([3, 4])], + [jax.numpy.array([0, 1, 2]), 3], + jax.numpy.array([[0, 1, 2]]), + jax.numpy.array([[[0, 1], [2, 3]]]), + jax.numpy.array([[[[0]]]]), + ], + ) + def test_error_for_incorrect_jax_arrays(self, input): + """Tests that a Wires object cannot be created from incorrect JAX arrays.""" + with pytest.raises(WireError, match="Wires must be hashable"): + Wires(input) + + @pytest.mark.jax + @pytest.mark.parametrize("iterable", [jax.numpy.array([4, 1, 1, 3]), jax.numpy.array([0, 0])]) + def test_error_for_repeated_wires_jax(self, iterable): + """Tests that a Wires object cannot be created from a JAX array with repeated indices.""" + with pytest.raises(WireError, match="Wires must be unique"): + Wires(iterable) + + @pytest.mark.jax + def test_array_representation_jax(self): + """Tests that Wires object has an array representation with JAX.""" + + # pylint: disable=import-outside-toplevel + import jax + + wires = Wires([4, 0, 1]) + array = jax.numpy.array(wires.labels) + assert isinstance(array, jax.numpy.ndarray) + assert array.shape == (3,) + for w1, w2 in zip(array, jax.numpy.array([4, 0, 1])): + assert w1 == w2 + + @pytest.mark.jax + @pytest.mark.parametrize( + "source", [jax.numpy.array([0, 1, 2]), jax.numpy.array([0]), jax.numpy.array(0)] + ) + def test_jax_wires_pytree(self, source): + """Test that Wires class supports the PyTree flattening interface with JAX arrays.""" + + # pylint: disable=import-outside-toplevel + import jax + + wires = Wires(source) + wires_flat, tree = jax.tree_util.tree_flatten(wires) + wires2 = jax.tree_util.tree_unflatten(tree, wires_flat) + assert isinstance(wires2, Wires), f"{wires2} is not Wires" + assert wires == wires2, f"{wires} != {wires2}" From d25eed884a108c6f13c160073ddec6571a5f81ba Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Wed, 2 Oct 2024 23:13:46 -0400 Subject: [PATCH 15/20] Other solution --- tests/test_wires.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/test_wires.py b/tests/test_wires.py index af82b9b4076..e45108006c4 100644 --- a/tests/test_wires.py +++ b/tests/test_wires.py @@ -21,6 +21,14 @@ from pennylane.wires import WireError, Wires +try: + import jax + + has_jax = True +except ImportError: + has_jax = False + + # pylint: disable=too-many-public-methods class TestWires: """Tests for the ``Wires`` class.""" @@ -489,11 +497,10 @@ def test_complex_operation(self): assert result == expected +@pytest.mark.skipif(not has_jax, reason="JAX is not installed") class TestWiresJax: """Tests the support for JAX arrays in the ``Wires`` class.""" - jax = pytest.importorskip("jax") - @pytest.mark.jax @pytest.mark.parametrize( "iterable, expected", @@ -536,9 +543,6 @@ def test_error_for_repeated_wires_jax(self, iterable): def test_array_representation_jax(self): """Tests that Wires object has an array representation with JAX.""" - # pylint: disable=import-outside-toplevel - import jax - wires = Wires([4, 0, 1]) array = jax.numpy.array(wires.labels) assert isinstance(array, jax.numpy.ndarray) @@ -553,9 +557,6 @@ def test_array_representation_jax(self): def test_jax_wires_pytree(self, source): """Test that Wires class supports the PyTree flattening interface with JAX arrays.""" - # pylint: disable=import-outside-toplevel - import jax - wires = Wires(source) wires_flat, tree = jax.tree_util.tree_flatten(wires) wires2 = jax.tree_util.tree_unflatten(tree, wires_flat) From ff61eabcad3bf71bd0dd16602c69d1777265d3c9 Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Wed, 2 Oct 2024 23:26:21 -0400 Subject: [PATCH 16/20] Other solution --- tests/test_wires.py | 54 +++++++++++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/tests/test_wires.py b/tests/test_wires.py index e45108006c4..e9a39ca4595 100644 --- a/tests/test_wires.py +++ b/tests/test_wires.py @@ -497,52 +497,68 @@ def test_complex_operation(self): assert result == expected -@pytest.mark.skipif(not has_jax, reason="JAX is not installed") class TestWiresJax: """Tests the support for JAX arrays in the ``Wires`` class.""" - @pytest.mark.jax + @pytest.mark.skipif(not has_jax, reason="JAX is not installed") @pytest.mark.parametrize( "iterable, expected", [ - (jax.numpy.array([0, 1, 2]), (0, 1, 2)), - (jax.numpy.array([0]), (0,)), - (jax.numpy.array(0), (0,)), - (jax.numpy.array([]), ()), + (jax.numpy.array([0, 1, 2]) if has_jax else None, (0, 1, 2)), + (jax.numpy.array([0]) if has_jax else None, (0,)), + (jax.numpy.array(0) if has_jax else None, (0,)), + (jax.numpy.array([]) if has_jax else None, ()), ], ) def test_creation_from_jax_array(self, iterable, expected): """Tests that a Wires object can be created from a JAX array.""" + + if not has_jax: + pytest.skip("Skipping test since JAX is not installed.") + wires = Wires(iterable) assert wires.labels == expected - @pytest.mark.jax + @pytest.mark.skipif(not has_jax, reason="JAX is not installed") @pytest.mark.parametrize( "input", [ - [jax.numpy.array([0, 1, 2]), jax.numpy.array([3, 4])], - [jax.numpy.array([0, 1, 2]), 3], - jax.numpy.array([[0, 1, 2]]), - jax.numpy.array([[[0, 1], [2, 3]]]), - jax.numpy.array([[[[0]]]]), + [jax.numpy.array([0, 1, 2]), jax.numpy.array([3, 4])] if has_jax else None, + [jax.numpy.array([0, 1, 2]), 3] if has_jax else None, + jax.numpy.array([[0, 1, 2]]) if has_jax else None, + jax.numpy.array([[[0, 1], [2, 3]]]) if has_jax else None, + jax.numpy.array([[[[0]]]]) if has_jax else None, ], ) def test_error_for_incorrect_jax_arrays(self, input): """Tests that a Wires object cannot be created from incorrect JAX arrays.""" + + if not has_jax: + pytest.skip("Skipping test since JAX is not installed.") + with pytest.raises(WireError, match="Wires must be hashable"): Wires(input) - @pytest.mark.jax - @pytest.mark.parametrize("iterable", [jax.numpy.array([4, 1, 1, 3]), jax.numpy.array([0, 0])]) + @pytest.mark.skipif(not has_jax, reason="JAX is not installed") + @pytest.mark.parametrize( + "iterable", [jax.numpy.array([4, 1, 1, 3]), jax.numpy.array([0, 0])] if has_jax else None + ) def test_error_for_repeated_wires_jax(self, iterable): """Tests that a Wires object cannot be created from a JAX array with repeated indices.""" + + if not has_jax: + pytest.skip("Skipping test since JAX is not installed.") + with pytest.raises(WireError, match="Wires must be unique"): Wires(iterable) - @pytest.mark.jax + @pytest.mark.skipif(not has_jax, reason="JAX is not installed") def test_array_representation_jax(self): """Tests that Wires object has an array representation with JAX.""" + if not has_jax: + pytest.skip("Skipping test since JAX is not installed.") + wires = Wires([4, 0, 1]) array = jax.numpy.array(wires.labels) assert isinstance(array, jax.numpy.ndarray) @@ -550,13 +566,17 @@ def test_array_representation_jax(self): for w1, w2 in zip(array, jax.numpy.array([4, 0, 1])): assert w1 == w2 - @pytest.mark.jax + @pytest.mark.skipif(not has_jax, reason="JAX is not installed") @pytest.mark.parametrize( - "source", [jax.numpy.array([0, 1, 2]), jax.numpy.array([0]), jax.numpy.array(0)] + "source", + [jax.numpy.array([0, 1, 2]), jax.numpy.array([0]), jax.numpy.array(0)] if has_jax else None, ) def test_jax_wires_pytree(self, source): """Test that Wires class supports the PyTree flattening interface with JAX arrays.""" + if not has_jax: + pytest.skip("Skipping test since JAX is not installed.") + wires = Wires(source) wires_flat, tree = jax.tree_util.tree_flatten(wires) wires2 = jax.tree_util.tree_unflatten(tree, wires_flat) From e6674eb67a8d0e205ff78fe86992a37709a9d57c Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Wed, 2 Oct 2024 23:55:13 -0400 Subject: [PATCH 17/20] Other solution --- tests/test_wires.py | 84 +++++++++++++++++++++------------------------ 1 file changed, 40 insertions(+), 44 deletions(-) diff --git a/tests/test_wires.py b/tests/test_wires.py index e9a39ca4595..16e96a93d90 100644 --- a/tests/test_wires.py +++ b/tests/test_wires.py @@ -14,6 +14,7 @@ """ Unit tests for :mod:`pennylane.wires`. """ +from importlib import import_module, util import numpy as np import pytest @@ -21,12 +22,12 @@ from pennylane.wires import WireError, Wires -try: - import jax - - has_jax = True -except ImportError: - has_jax = False +if util.find_spec("jax") is not None: + jax = import_module("jax") + jax_available = True +else: + jax_available = False + jax = None # pylint: disable=too-many-public-methods @@ -500,65 +501,59 @@ def test_complex_operation(self): class TestWiresJax: """Tests the support for JAX arrays in the ``Wires`` class.""" - @pytest.mark.skipif(not has_jax, reason="JAX is not installed") + @pytest.mark.jax @pytest.mark.parametrize( "iterable, expected", - [ - (jax.numpy.array([0, 1, 2]) if has_jax else None, (0, 1, 2)), - (jax.numpy.array([0]) if has_jax else None, (0,)), - (jax.numpy.array(0) if has_jax else None, (0,)), - (jax.numpy.array([]) if has_jax else None, ()), - ], + ( + [ + (jax.numpy.array([0, 1, 2]), (0, 1, 2)), + (jax.numpy.array([0]), (0,)), + (jax.numpy.array(0), (0,)), + (jax.numpy.array([]), ()), + ] + if jax_available + else [] + ), ) def test_creation_from_jax_array(self, iterable, expected): """Tests that a Wires object can be created from a JAX array.""" - - if not has_jax: - pytest.skip("Skipping test since JAX is not installed.") - wires = Wires(iterable) assert wires.labels == expected - @pytest.mark.skipif(not has_jax, reason="JAX is not installed") + @pytest.mark.jax @pytest.mark.parametrize( "input", - [ - [jax.numpy.array([0, 1, 2]), jax.numpy.array([3, 4])] if has_jax else None, - [jax.numpy.array([0, 1, 2]), 3] if has_jax else None, - jax.numpy.array([[0, 1, 2]]) if has_jax else None, - jax.numpy.array([[[0, 1], [2, 3]]]) if has_jax else None, - jax.numpy.array([[[[0]]]]) if has_jax else None, - ], + ( + [ + [jax.numpy.array([0, 1, 2]), jax.numpy.array([3, 4])], + [jax.numpy.array([0, 1, 2]), 3], + jax.numpy.array([[0, 1, 2]]), + jax.numpy.array([[[0, 1], [2, 3]]]), + jax.numpy.array([[[[0]]]]) if jax_available else [], + ] + if jax_available + else [] + ), ) def test_error_for_incorrect_jax_arrays(self, input): """Tests that a Wires object cannot be created from incorrect JAX arrays.""" - - if not has_jax: - pytest.skip("Skipping test since JAX is not installed.") - with pytest.raises(WireError, match="Wires must be hashable"): Wires(input) - @pytest.mark.skipif(not has_jax, reason="JAX is not installed") + @pytest.mark.jax @pytest.mark.parametrize( - "iterable", [jax.numpy.array([4, 1, 1, 3]), jax.numpy.array([0, 0])] if has_jax else None + "iterable", + [jax.numpy.array([4, 1, 1, 3]), jax.numpy.array([0, 0])] if jax_available else [], ) def test_error_for_repeated_wires_jax(self, iterable): """Tests that a Wires object cannot be created from a JAX array with repeated indices.""" - - if not has_jax: - pytest.skip("Skipping test since JAX is not installed.") - with pytest.raises(WireError, match="Wires must be unique"): Wires(iterable) - @pytest.mark.skipif(not has_jax, reason="JAX is not installed") + @pytest.mark.jax def test_array_representation_jax(self): """Tests that Wires object has an array representation with JAX.""" - if not has_jax: - pytest.skip("Skipping test since JAX is not installed.") - wires = Wires([4, 0, 1]) array = jax.numpy.array(wires.labels) assert isinstance(array, jax.numpy.ndarray) @@ -566,17 +561,18 @@ def test_array_representation_jax(self): for w1, w2 in zip(array, jax.numpy.array([4, 0, 1])): assert w1 == w2 - @pytest.mark.skipif(not has_jax, reason="JAX is not installed") + @pytest.mark.jax @pytest.mark.parametrize( "source", - [jax.numpy.array([0, 1, 2]), jax.numpy.array([0]), jax.numpy.array(0)] if has_jax else None, + ( + [jax.numpy.array([0, 1, 2]), jax.numpy.array([0]), jax.numpy.array(0)] + if jax_available + else [] + ), ) def test_jax_wires_pytree(self, source): """Test that Wires class supports the PyTree flattening interface with JAX arrays.""" - if not has_jax: - pytest.skip("Skipping test since JAX is not installed.") - wires = Wires(source) wires_flat, tree = jax.tree_util.tree_flatten(wires) wires2 = jax.tree_util.tree_unflatten(tree, wires_flat) From f08baa62fc270cc347cf2efde02d19c187e622c3 Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Thu, 3 Oct 2024 00:12:36 -0400 Subject: [PATCH 18/20] Final solution --- tests/test_wires.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_wires.py b/tests/test_wires.py index 16e96a93d90..ec4f0eb361a 100644 --- a/tests/test_wires.py +++ b/tests/test_wires.py @@ -15,13 +15,13 @@ Unit tests for :mod:`pennylane.wires`. """ from importlib import import_module, util + import numpy as np import pytest import pennylane as qml from pennylane.wires import WireError, Wires - if util.find_spec("jax") is not None: jax = import_module("jax") jax_available = True @@ -498,10 +498,10 @@ def test_complex_operation(self): assert result == expected +@pytest.mark.jax class TestWiresJax: """Tests the support for JAX arrays in the ``Wires`` class.""" - @pytest.mark.jax @pytest.mark.parametrize( "iterable, expected", ( @@ -520,7 +520,6 @@ def test_creation_from_jax_array(self, iterable, expected): wires = Wires(iterable) assert wires.labels == expected - @pytest.mark.jax @pytest.mark.parametrize( "input", ( @@ -540,7 +539,6 @@ def test_error_for_incorrect_jax_arrays(self, input): with pytest.raises(WireError, match="Wires must be hashable"): Wires(input) - @pytest.mark.jax @pytest.mark.parametrize( "iterable", [jax.numpy.array([4, 1, 1, 3]), jax.numpy.array([0, 0])] if jax_available else [], @@ -550,7 +548,6 @@ def test_error_for_repeated_wires_jax(self, iterable): with pytest.raises(WireError, match="Wires must be unique"): Wires(iterable) - @pytest.mark.jax def test_array_representation_jax(self): """Tests that Wires object has an array representation with JAX.""" @@ -561,7 +558,6 @@ def test_array_representation_jax(self): for w1, w2 in zip(array, jax.numpy.array([4, 0, 1])): assert w1 == w2 - @pytest.mark.jax @pytest.mark.parametrize( "source", ( From 2fb858f3341ceca8d6af64b43a3946fb07e6628d Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Thu, 3 Oct 2024 13:24:07 -0400 Subject: [PATCH 19/20] Repetition in test --- tests/test_wires.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_wires.py b/tests/test_wires.py index ec4f0eb361a..0f4f4c17eb7 100644 --- a/tests/test_wires.py +++ b/tests/test_wires.py @@ -528,7 +528,7 @@ def test_creation_from_jax_array(self, iterable, expected): [jax.numpy.array([0, 1, 2]), 3], jax.numpy.array([[0, 1, 2]]), jax.numpy.array([[[0, 1], [2, 3]]]), - jax.numpy.array([[[[0]]]]) if jax_available else [], + jax.numpy.array([[[[0]]]]), ] if jax_available else [] From 4c61871cc9525edc6fc8a4c890f9688edc9ea75e Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 7 Oct 2024 12:02:28 -0400 Subject: [PATCH 20/20] Creating one unique PR --- doc/releases/changelog-dev.md | 3 ++- pennylane/wires.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 25f59a5a0a0..89b1c5a6c3d 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -58,7 +58,8 @@

Capturing and representing hybrid programs

-* `qml.wires.Wires` now accepts JAX arrays as input. +* `qml.wires.Wires` now accepts JAX arrays as input. Furthermore, a `FutureWarning` is no longer raised in `JAX 0.4.30+` + when providing JAX tracers as input to `qml.wires.Wires`. [(#6312)](https://github.com/PennyLaneAI/pennylane/pull/6312) * Differentiation of hybrid programs via `qml.grad` and `qml.jacobian` can now be captured diff --git a/pennylane/wires.py b/pennylane/wires.py index 7fbccb0c8fe..4e1af20f419 100644 --- a/pennylane/wires.py +++ b/pennylane/wires.py @@ -17,6 +17,7 @@ import functools import itertools from collections.abc import Hashable, Iterable, Sequence +from importlib import import_module, util from typing import Union import numpy as np @@ -24,6 +25,17 @@ import pennylane as qml from pennylane.pytrees import register_pytree +if util.find_spec("jax") is not None: + jax = import_module("jax") + jax_available = True +else: + jax_available = False + jax = None + +if jax_available: + # pylint: disable=unnecessary-lambda + setattr(jax.interpreters.partial_eval.DynamicJaxprTracer, "__hash__", lambda x: id(x)) + class WireError(Exception): """Exception raised by a :class:`~.pennylane.wires.Wire` object when it is unable to process wires."""