Skip to content

Commit

Permalink
Add get_best_diff_method and _get_gradient_fn to qml.workflow (#…
Browse files Browse the repository at this point in the history
…6399)

**Context:**

Right now, `QNode.get_gradient_fn` and `QNode.get_best_method`, and
`QNode.best_method_str` all exist as public or static methods and aren't
used internally (aside from inside the `QNode` class) and they don't
make much sense externally. To address this, a new user-friendly
function should be added.

**Description of the Change:**

Implemented functions `get_best_diff_method` and `_get_gradient_fn` to
`qml.workflow`. They behave the same as the existing methods in `QNode`
but are designed with a simpler interface in mind 😄. The former is
intended to be more user-facing and so has a simplifed interface. The
latter is more for internal development and will be used to deprecate
the internal `QNode` methods mentioned earlier.

For `get_best_diff_method` the user can provide the `QNode` they are
trying to execute and it will return a string with the "best"
differentiation method.

For `_get_gradient_fn`, the intention was to extract this from `QNode`
and use it later on once we begin the deprecation and removal of those
methods outlined earlier.4

Example 

_`qml.workflow.get_best_diff_method`_

```python
>>> dev = qml.device("default.qubit")
>>> qn = qml.QNode(lambda: None, dev)
>>> qml.workflow.get_best_diff_method(qn)()
'backprop'
>>> dev_shots = qml.device("default.qubit", shots=45)
>>> qn_shots = qml.QNode(lambda: None, dev_shots)
>>> qml.workflow.get_best_diff_method(qn_shots)()
'parameter-shift'
```

_`qml.workflow._get_gradient_fn`_

```python
>>> qml.workflow._get_gradient_fn(dev_shots, diff_method='parameter-shift')
<transform: param_shift>
```

**Benefits:** Improves `QNode` organization and structure. 

**Possible Drawbacks:** None

[sc-72157]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
  • Loading branch information
2 people authored and mudit2812 committed Nov 11, 2024
1 parent 964e3d5 commit 6cc3429
Show file tree
Hide file tree
Showing 6 changed files with 473 additions and 0 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

<h3>New features since last release</h3>

* Added functions `get_best_diff_method` to `qml.workflow`.
[(#6399)](https://github.com/PennyLaneAI/pennylane/pull/6399)

* Add `qml.workflow.construct_tape` as a method for users to construct single tapes from a `QNode`.
[(#6419)](https://github.com/PennyLaneAI/pennylane/pull/6419)

Expand Down
3 changes: 3 additions & 0 deletions pennylane/workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
~workflow.construct_tape
~workflow.construct_batch
~workflow.get_transform_program
~workflow.get_best_diff_method
Supported interfaces
~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -55,6 +56,8 @@
.. include:: ../../pennylane/workflow/return_types_spec.rst
"""
from .get_best_diff_method import get_best_diff_method
from .get_gradient_fn import _get_gradient_fn
from .construct_batch import construct_batch, get_transform_program
from .construct_tape import construct_tape
from .execution import INTERFACE_MAP, SUPPORTED_INTERFACE_NAMES, execute
Expand Down
78 changes: 78 additions & 0 deletions pennylane/workflow/get_best_diff_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains a function for getting the best differentiation method for a given QNode.
"""

from functools import wraps

import pennylane as qml
from pennylane.workflow.qnode import QNode, _make_execution_config


def get_best_diff_method(qnode: QNode):
"""Returns a function that computes the 'best' differentiation method
for a particular QNode.
This method prioritizes differentiation methods in the following order (SPSA-based and Hadamard-based gradients
are not included here):
* ``"device"``
* ``"backprop"``
* ``"parameter-shift"``
.. note::
The first differentiation method that is supported (from top to bottom)
will be returned. The order is designed to maximize efficiency, generality,
and stability.
.. seealso::
For a detailed comparison of the backpropagation and parameter-shift methods,
refer to the :doc:`quantum gradients with backpropagation example <demo:demos/tutorial_backprop>`.
Args:
qnode (.QNode): the qnode to get the 'best' differentiation method for.
Returns:
str: the gradient transform.
"""

def handle_return(transform):
"""Helper function to manage the return"""
if transform in (qml.gradients.param_shift, qml.gradients.param_shift_cv):
return "parameter-shift"
return transform

@wraps(qnode)
def wrapper(*args, **kwargs):
device = qnode.device
tape = qml.workflow.construct_tape(qnode)(*args, **kwargs)

config = _make_execution_config(None, "best")

if device.supports_derivatives(config, circuit=tape):
new_config = device.preprocess(config)[1]
transform = new_config.gradient_method
return handle_return(transform)

if tape and any(isinstance(o, qml.operation.CV) for o in tape):
transform = qml.gradients.param_shift_cv
return handle_return(transform)

transform = qml.gradients.param_shift
return handle_return(transform)

return wrapper
92 changes: 92 additions & 0 deletions pennylane/workflow/get_gradient_fn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains a function for retrieving the gradient function for a given device or tape.
"""

from typing import Optional, get_args

import pennylane as qml
from pennylane.transforms.core import TransformDispatcher
from pennylane.workflow.qnode import (
SupportedDeviceAPIs,
SupportedDiffMethods,
_make_execution_config,
)


# pylint: disable=too-many-return-statements, unsupported-binary-operation
def _get_gradient_fn(
device: SupportedDeviceAPIs,
diff_method: "TransformDispatcher | SupportedDiffMethods" = "best",
tape: Optional["qml.tape.QuantumTape"] = None,
):
"""Determines the differentiation method for a given device and diff method.
Args:
device (:class:`~.devices.Device`): PennyLane device
diff_method (str or :class:`~.TransformDispatcher`): The requested method of differentiation. Defaults to ``"best"``.
If a string, allowed options are ``"best"``, ``"backprop"``, ``"adjoint"``,
``"device"``, ``"parameter-shift"``, ``"hadamard"``, ``"finite-diff"``, or ``"spsa"``.
Alternatively, a gradient transform can be provided.
tape (Optional[.QuantumTape]): the circuit that will be differentiated. Should include shots information.
Returns:
str or :class:`~.TransformDispatcher` (the ``gradient_fn``)
"""

if diff_method is None:
return None

config = _make_execution_config(None, diff_method)

if device.supports_derivatives(config, circuit=tape):
new_config = device.preprocess(config)[1]
return new_config.gradient_method

if diff_method in {"backprop", "adjoint", "device"}: # device-only derivatives
raise qml.QuantumFunctionError(
f"Device {device} does not support {diff_method} with requested circuit."
)

if diff_method == "best":
qn = qml.QNode(lambda: None, device, diff_method=None)
return qml.workflow.get_best_diff_method(qn)()

if diff_method == "parameter-shift":
if tape and any(isinstance(o, qml.operation.CV) and o.name != "Identity" for o in tape):
return qml.gradients.param_shift_cv
return qml.gradients.param_shift

gradient_transform_map = {
"finite-diff": qml.gradients.finite_diff,
"spsa": qml.gradients.spsa_grad,
"hadamard": qml.gradients.hadamard_grad,
}

if diff_method in gradient_transform_map:
return gradient_transform_map[diff_method]

if isinstance(diff_method, str):
raise qml.QuantumFunctionError(
f"Differentiation method {diff_method} not recognized. Allowed "
f"options are {tuple(get_args(SupportedDiffMethods))}."
)

if isinstance(diff_method, TransformDispatcher):
return diff_method

raise qml.QuantumFunctionError(
f"Differentiation method {diff_method} must be a gradient transform or a string."
)
111 changes: 111 additions & 0 deletions tests/workflow/test_get_best_diff_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the `get_best_diff_method` function"""

import pytest

import pennylane as qml
from pennylane.workflow import get_best_diff_method


def dummy_cv_func(x):
"""A dummy CV function with continuous-variable operations."""
qml.Displacement(x, 0.1, wires=0)
return qml.expval(qml.X(0))


def dummyfunc():
"""dummy func."""
return None


# pylint: disable=unused-argument
class CustomDevice(qml.devices.Device):
"""A null device that just returns 0."""

def __repr__(self):
return "CustomDevice"

def execute(self, circuits, execution_config=None):
return (0,)


class CustomDeviceWithDiffMethod(qml.devices.Device):
"""A device that defines a derivative."""

def execute(self, circuits, execution_config=None):
return 0

def compute_derivatives(self, circuits, execution_config=None):
"""Device defines its own method to compute derivatives"""
return 0


class TestValidation:
"""Tests for QNode creation and validation"""

@pytest.mark.autograd
def test_best_method_is_device(self):
"""Test that the method for determining the best diff method
for a device that is a child of qml.devices.Device and has a
compute_derivatives method defined returns 'device'"""

dev = CustomDeviceWithDiffMethod()
qn_jax = qml.QNode(dummyfunc, dev, "jax")
qn_none = qml.QNode(dummyfunc, dev, None)

res = get_best_diff_method(qn_jax)()
assert res == "device"

res = get_best_diff_method(qn_none)()
assert res == "device"

@pytest.mark.parametrize("interface", ["jax", "tensorflow", "torch", "autograd"])
def test_best_method_is_backprop(self, interface):
"""Test that the method for determining the best diff method
for the default.qubit device and a valid interface returns back-propagation"""

dev = qml.device("default.qubit", wires=1)
qn = qml.QNode(dummyfunc, dev, interface)

# backprop is returned when the interface is an allowed interface for the device and Jacobian is not provided
res = get_best_diff_method(qn)()
assert res == "backprop"

def test_best_method_is_param_shift(self):
"""Test that the method for determining the best diff method
for a given device and interface returns the parameter shift rule if
'device' and 'backprop' don't work"""

# null device has no info - fall back on parameter-shift
dev = CustomDevice()
qn = qml.QNode(dummyfunc, dev)

res = get_best_diff_method(qn)()
assert res == "parameter-shift"

# no interface - fall back on parameter-shift
dev2 = qml.device("default.qubit", wires=1)
qn = qml.QNode(dummyfunc, dev2)
res2 = get_best_diff_method(qn)(shots=50)
assert res2 == "parameter-shift"

def test_best_method_is_param_shift_cv(self):
"""Tests that the method returns 'parameter-shift' when CV operations are in the QNode."""

dev = qml.device("default.gaussian", wires=1)
qn = qml.QNode(dummy_cv_func, dev, interface=None)

res = get_best_diff_method(qn)(0.5)
assert res == "parameter-shift"
Loading

0 comments on commit 6cc3429

Please sign in to comment.