-
Notifications
You must be signed in to change notification settings - Fork 615
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
…6399) **Context:** Right now, `QNode.get_gradient_fn` and `QNode.get_best_method`, and `QNode.best_method_str` all exist as public or static methods and aren't used internally (aside from inside the `QNode` class) and they don't make much sense externally. To address this, a new user-friendly function should be added. **Description of the Change:** Implemented functions `get_best_diff_method` and `_get_gradient_fn` to `qml.workflow`. They behave the same as the existing methods in `QNode` but are designed with a simpler interface in mind 😄. The former is intended to be more user-facing and so has a simplifed interface. The latter is more for internal development and will be used to deprecate the internal `QNode` methods mentioned earlier. For `get_best_diff_method` the user can provide the `QNode` they are trying to execute and it will return a string with the "best" differentiation method. For `_get_gradient_fn`, the intention was to extract this from `QNode` and use it later on once we begin the deprecation and removal of those methods outlined earlier.4 Example _`qml.workflow.get_best_diff_method`_ ```python >>> dev = qml.device("default.qubit") >>> qn = qml.QNode(lambda: None, dev) >>> qml.workflow.get_best_diff_method(qn)() 'backprop' >>> dev_shots = qml.device("default.qubit", shots=45) >>> qn_shots = qml.QNode(lambda: None, dev_shots) >>> qml.workflow.get_best_diff_method(qn_shots)() 'parameter-shift' ``` _`qml.workflow._get_gradient_fn`_ ```python >>> qml.workflow._get_gradient_fn(dev_shots, diff_method='parameter-shift') <transform: param_shift> ``` **Benefits:** Improves `QNode` organization and structure. **Possible Drawbacks:** None [sc-72157] --------- Co-authored-by: Christina Lee <christina@xanadu.ai>
- Loading branch information
Showing
6 changed files
with
473 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright 2018-2024 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Contains a function for getting the best differentiation method for a given QNode. | ||
""" | ||
|
||
from functools import wraps | ||
|
||
import pennylane as qml | ||
from pennylane.workflow.qnode import QNode, _make_execution_config | ||
|
||
|
||
def get_best_diff_method(qnode: QNode): | ||
"""Returns a function that computes the 'best' differentiation method | ||
for a particular QNode. | ||
This method prioritizes differentiation methods in the following order (SPSA-based and Hadamard-based gradients | ||
are not included here): | ||
* ``"device"`` | ||
* ``"backprop"`` | ||
* ``"parameter-shift"`` | ||
.. note:: | ||
The first differentiation method that is supported (from top to bottom) | ||
will be returned. The order is designed to maximize efficiency, generality, | ||
and stability. | ||
.. seealso:: | ||
For a detailed comparison of the backpropagation and parameter-shift methods, | ||
refer to the :doc:`quantum gradients with backpropagation example <demo:demos/tutorial_backprop>`. | ||
Args: | ||
qnode (.QNode): the qnode to get the 'best' differentiation method for. | ||
Returns: | ||
str: the gradient transform. | ||
""" | ||
|
||
def handle_return(transform): | ||
"""Helper function to manage the return""" | ||
if transform in (qml.gradients.param_shift, qml.gradients.param_shift_cv): | ||
return "parameter-shift" | ||
return transform | ||
|
||
@wraps(qnode) | ||
def wrapper(*args, **kwargs): | ||
device = qnode.device | ||
tape = qml.workflow.construct_tape(qnode)(*args, **kwargs) | ||
|
||
config = _make_execution_config(None, "best") | ||
|
||
if device.supports_derivatives(config, circuit=tape): | ||
new_config = device.preprocess(config)[1] | ||
transform = new_config.gradient_method | ||
return handle_return(transform) | ||
|
||
if tape and any(isinstance(o, qml.operation.CV) for o in tape): | ||
transform = qml.gradients.param_shift_cv | ||
return handle_return(transform) | ||
|
||
transform = qml.gradients.param_shift | ||
return handle_return(transform) | ||
|
||
return wrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright 2018-2024 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Contains a function for retrieving the gradient function for a given device or tape. | ||
""" | ||
|
||
from typing import Optional, get_args | ||
|
||
import pennylane as qml | ||
from pennylane.transforms.core import TransformDispatcher | ||
from pennylane.workflow.qnode import ( | ||
SupportedDeviceAPIs, | ||
SupportedDiffMethods, | ||
_make_execution_config, | ||
) | ||
|
||
|
||
# pylint: disable=too-many-return-statements, unsupported-binary-operation | ||
def _get_gradient_fn( | ||
device: SupportedDeviceAPIs, | ||
diff_method: "TransformDispatcher | SupportedDiffMethods" = "best", | ||
tape: Optional["qml.tape.QuantumTape"] = None, | ||
): | ||
"""Determines the differentiation method for a given device and diff method. | ||
Args: | ||
device (:class:`~.devices.Device`): PennyLane device | ||
diff_method (str or :class:`~.TransformDispatcher`): The requested method of differentiation. Defaults to ``"best"``. | ||
If a string, allowed options are ``"best"``, ``"backprop"``, ``"adjoint"``, | ||
``"device"``, ``"parameter-shift"``, ``"hadamard"``, ``"finite-diff"``, or ``"spsa"``. | ||
Alternatively, a gradient transform can be provided. | ||
tape (Optional[.QuantumTape]): the circuit that will be differentiated. Should include shots information. | ||
Returns: | ||
str or :class:`~.TransformDispatcher` (the ``gradient_fn``) | ||
""" | ||
|
||
if diff_method is None: | ||
return None | ||
|
||
config = _make_execution_config(None, diff_method) | ||
|
||
if device.supports_derivatives(config, circuit=tape): | ||
new_config = device.preprocess(config)[1] | ||
return new_config.gradient_method | ||
|
||
if diff_method in {"backprop", "adjoint", "device"}: # device-only derivatives | ||
raise qml.QuantumFunctionError( | ||
f"Device {device} does not support {diff_method} with requested circuit." | ||
) | ||
|
||
if diff_method == "best": | ||
qn = qml.QNode(lambda: None, device, diff_method=None) | ||
return qml.workflow.get_best_diff_method(qn)() | ||
|
||
if diff_method == "parameter-shift": | ||
if tape and any(isinstance(o, qml.operation.CV) and o.name != "Identity" for o in tape): | ||
return qml.gradients.param_shift_cv | ||
return qml.gradients.param_shift | ||
|
||
gradient_transform_map = { | ||
"finite-diff": qml.gradients.finite_diff, | ||
"spsa": qml.gradients.spsa_grad, | ||
"hadamard": qml.gradients.hadamard_grad, | ||
} | ||
|
||
if diff_method in gradient_transform_map: | ||
return gradient_transform_map[diff_method] | ||
|
||
if isinstance(diff_method, str): | ||
raise qml.QuantumFunctionError( | ||
f"Differentiation method {diff_method} not recognized. Allowed " | ||
f"options are {tuple(get_args(SupportedDiffMethods))}." | ||
) | ||
|
||
if isinstance(diff_method, TransformDispatcher): | ||
return diff_method | ||
|
||
raise qml.QuantumFunctionError( | ||
f"Differentiation method {diff_method} must be a gradient transform or a string." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# Copyright 2018-2024 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Unit tests for the `get_best_diff_method` function""" | ||
|
||
import pytest | ||
|
||
import pennylane as qml | ||
from pennylane.workflow import get_best_diff_method | ||
|
||
|
||
def dummy_cv_func(x): | ||
"""A dummy CV function with continuous-variable operations.""" | ||
qml.Displacement(x, 0.1, wires=0) | ||
return qml.expval(qml.X(0)) | ||
|
||
|
||
def dummyfunc(): | ||
"""dummy func.""" | ||
return None | ||
|
||
|
||
# pylint: disable=unused-argument | ||
class CustomDevice(qml.devices.Device): | ||
"""A null device that just returns 0.""" | ||
|
||
def __repr__(self): | ||
return "CustomDevice" | ||
|
||
def execute(self, circuits, execution_config=None): | ||
return (0,) | ||
|
||
|
||
class CustomDeviceWithDiffMethod(qml.devices.Device): | ||
"""A device that defines a derivative.""" | ||
|
||
def execute(self, circuits, execution_config=None): | ||
return 0 | ||
|
||
def compute_derivatives(self, circuits, execution_config=None): | ||
"""Device defines its own method to compute derivatives""" | ||
return 0 | ||
|
||
|
||
class TestValidation: | ||
"""Tests for QNode creation and validation""" | ||
|
||
@pytest.mark.autograd | ||
def test_best_method_is_device(self): | ||
"""Test that the method for determining the best diff method | ||
for a device that is a child of qml.devices.Device and has a | ||
compute_derivatives method defined returns 'device'""" | ||
|
||
dev = CustomDeviceWithDiffMethod() | ||
qn_jax = qml.QNode(dummyfunc, dev, "jax") | ||
qn_none = qml.QNode(dummyfunc, dev, None) | ||
|
||
res = get_best_diff_method(qn_jax)() | ||
assert res == "device" | ||
|
||
res = get_best_diff_method(qn_none)() | ||
assert res == "device" | ||
|
||
@pytest.mark.parametrize("interface", ["jax", "tensorflow", "torch", "autograd"]) | ||
def test_best_method_is_backprop(self, interface): | ||
"""Test that the method for determining the best diff method | ||
for the default.qubit device and a valid interface returns back-propagation""" | ||
|
||
dev = qml.device("default.qubit", wires=1) | ||
qn = qml.QNode(dummyfunc, dev, interface) | ||
|
||
# backprop is returned when the interface is an allowed interface for the device and Jacobian is not provided | ||
res = get_best_diff_method(qn)() | ||
assert res == "backprop" | ||
|
||
def test_best_method_is_param_shift(self): | ||
"""Test that the method for determining the best diff method | ||
for a given device and interface returns the parameter shift rule if | ||
'device' and 'backprop' don't work""" | ||
|
||
# null device has no info - fall back on parameter-shift | ||
dev = CustomDevice() | ||
qn = qml.QNode(dummyfunc, dev) | ||
|
||
res = get_best_diff_method(qn)() | ||
assert res == "parameter-shift" | ||
|
||
# no interface - fall back on parameter-shift | ||
dev2 = qml.device("default.qubit", wires=1) | ||
qn = qml.QNode(dummyfunc, dev2) | ||
res2 = get_best_diff_method(qn)(shots=50) | ||
assert res2 == "parameter-shift" | ||
|
||
def test_best_method_is_param_shift_cv(self): | ||
"""Tests that the method returns 'parameter-shift' when CV operations are in the QNode.""" | ||
|
||
dev = qml.device("default.gaussian", wires=1) | ||
qn = qml.QNode(dummy_cv_func, dev, interface=None) | ||
|
||
res = get_best_diff_method(qn)(0.5) | ||
assert res == "parameter-shift" |
Oops, something went wrong.