Skip to content

Commit

Permalink
Deprecate get_best_method and best_method_str static methods in `…
Browse files Browse the repository at this point in the history
…QNode` (#6418)

**Context:**

#6399 added helper functions `get_best_diff_method` to `qml.workflow`.
This means we are free to deprecate the old static methods out of
`QNode`.

**Description of the Change:** 

Deprecation of the methods.

Tweaked the new `_get_gradient_fn` a bit.

**Benefits:** Cleaning up `QNode` and making it more user-friendly.

**Possible Drawbacks:** None

[sc-76083]

---------

Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Co-authored-by: Christina Lee <christina@xanadu.ai>
  • Loading branch information
3 people authored Nov 13, 2024
1 parent fd2097e commit eefe6ef
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 37 deletions.
6 changes: 6 additions & 0 deletions doc/development/deprecations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ deprecations are listed below.
Pending deprecations
--------------------

* The ``QNode.get_best_method`` and ``QNode.best_method_str`` methods have been deprecated.
Instead, use the ``qml.workflow.get_best_diff_method``.

- Deprecated in v0.40
- Will be removed in v0.41

* The ``gradient_fn`` keyword argument to ``qml.execute`` has been renamed ``diff_method``.

- Deprecated in v0.40
Expand Down
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@

<h3>Deprecations 👋</h3>

* The `QNode.get_best_method` and `QNode.best_method_str` methods have been deprecated.
Instead, use the `qml.workflow.get_best_diff_method` function.
[(#6418)](https://github.com/PennyLaneAI/pennylane/pull/6418)

* The `qml.execute` `gradient_fn` keyword argument has been renamed `diff_method`,
to better align with the termionology used by the `QNode`.
`gradient_fn` will be removed in v0.41.
Expand Down
8 changes: 6 additions & 2 deletions pennylane/workflow/get_gradient_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Optional, get_args

import pennylane as qml
from pennylane.logging import debug_logger
from pennylane.transforms.core import TransformDispatcher
from pennylane.workflow.qnode import (
SupportedDeviceAPIs,
Expand All @@ -27,6 +28,7 @@


# pylint: disable=too-many-return-statements, unsupported-binary-operation
@debug_logger
def _get_gradient_fn(
device: SupportedDeviceAPIs,
diff_method: "TransformDispatcher | SupportedDiffMethods" = "best",
Expand Down Expand Up @@ -61,8 +63,10 @@ def _get_gradient_fn(
)

if diff_method == "best":
qn = qml.QNode(lambda: None, device, diff_method=None)
return qml.workflow.get_best_diff_method(qn)()
if tape and any(isinstance(o, qml.operation.CV) for o in tape):
return qml.gradients.param_shift_cv

return qml.gradients.param_shift

if diff_method == "parameter-shift":
if tape and any(isinstance(o, qml.operation.CV) and o.name != "Identity" for o in tape):
Expand Down
39 changes: 36 additions & 3 deletions pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,10 @@ def get_gradient_fn(
)

if diff_method == "best":
return QNode.get_best_method(device, interface, tape=tape)
if tape and any(isinstance(o, qml.operation.CV) for o in tape):
return qml.gradients.param_shift_cv, {"dev": device}, device

return qml.gradients.param_shift, {}, device

if diff_method == "parameter-shift":
if tape and any(isinstance(o, qml.operation.CV) and o.name != "Identity" for o in tape):
Expand Down Expand Up @@ -720,7 +723,13 @@ def get_best_method(
dict[str, Any],
SupportedDeviceAPIs,
]:
"""Returns the 'best' differentiation method
"""
.. warning::
This method is deprecated in v0.40 and will be removed in v0.41.
Instead, use the :func:`qml.workflow.get_best_diff_method <.workflow.get_best_diff_method>` function.
Returns the 'best' differentiation method
for a particular device and interface combination.
This method attempts to determine support for differentiation
Expand All @@ -744,6 +753,13 @@ def get_best_method(
tuple[str or .TransformDispatcher, dict, .device.Device: Tuple containing the ``gradient_fn``,
``gradient_kwargs``, and the device to use when calling the execute function.
"""

warnings.warn(
"QNode.get_best_method is deprecated and will be removed in v0.41. "
"Instead, use the qml.workflow.get_best_diff_method function.",
qml.PennyLaneDeprecationWarning,
)

if not isinstance(device, qml.devices.Device):
device = qml.devices.LegacyDeviceFacade(device)

Expand All @@ -761,7 +777,14 @@ def get_best_method(
@staticmethod
@debug_logger
def best_method_str(device: SupportedDeviceAPIs, interface: SupportedInterfaceUserInput) -> str:
"""Similar to :meth:`~.get_best_method`, except return the
"""
.. warning::
This method is deprecated in v0.40 and will be removed in v0.41.
Instead, use the :func:`qml.workflow.get_best_diff_method <.workflow.get_best_diff_method>` function.
Similar to :meth:`~.get_best_method`, except return the
'best' differentiation method in human-readable format.
This method attempts to determine support for differentiation
Expand All @@ -786,9 +809,19 @@ def best_method_str(device: SupportedDeviceAPIs, interface: SupportedInterfaceUs
Returns:
str: The gradient function to use in human-readable format.
"""

warnings.warn(
"QNode.best_method_str is deprecated and will be removed in v0.41. "
"Instead, use the qml.workflow.get_best_diff_method function.",
qml.PennyLaneDeprecationWarning,
)

if not isinstance(device, qml.devices.Device):
device = qml.devices.LegacyDeviceFacade(device)

warnings.filterwarnings(
"ignore", "QNode.get_best_method is deprecated", qml.PennyLaneDeprecationWarning
)
transform = QNode.get_best_method(device, interface)[0]

if transform is qml.gradients.finite_diff:
Expand Down
63 changes: 51 additions & 12 deletions tests/test_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@ def dummyfunc():
return None


def test_get_best_method_is_deprecated():
"""Test that is deprecated."""
with pytest.warns(qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"):
dev = qml.device("default.qubit", wires=2)
_ = QNode.get_best_method(dev, "jax")


def test_best_method_str_is_deprecated():
"""Test that is deprecated."""
with pytest.warns(qml.PennyLaneDeprecationWarning, match="QNode.best_method_str is deprecated"):
dev = qml.device("default.qubit", wires=2)
_ = QNode.best_method_str(dev, "jax")


# pylint: disable=unused-argument
class CustomDevice(qml.devices.Device):
"""A null device that just returns 0."""
Expand Down Expand Up @@ -177,11 +191,17 @@ def test_best_method_is_device(self):

dev = CustomDeviceWithDiffMethod()

res = QNode.get_best_method(dev, "jax")
assert res == ("device", {}, dev)
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
):
res = QNode.get_best_method(dev, "jax")
assert res == ("device", {}, dev)

res = QNode.get_best_method(dev, None)
assert res == ("device", {}, dev)
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
):
res = QNode.get_best_method(dev, None)
assert res == ("device", {}, dev)

# pylint: disable=protected-access
@pytest.mark.parametrize("interface", ["jax", "tensorflow", "torch", "autograd"])
Expand All @@ -192,8 +212,11 @@ def test_best_method_is_backprop(self, interface):
dev = qml.device("default.qubit", wires=1)

# backprop is returned when the interface is an allowed interface for the device and Jacobian is not provided
res = QNode.get_best_method(dev, interface)
assert res == ("backprop", {}, dev)
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
):
res = QNode.get_best_method(dev, interface)
assert res == ("backprop", {}, dev)

# pylint: disable=protected-access
def test_best_method_is_param_shift(self):
Expand All @@ -203,14 +226,20 @@ def test_best_method_is_param_shift(self):

# null device has no info - fall back on parameter-shift
dev = CustomDevice()
res = QNode.get_best_method(dev, None)
assert res == (qml.gradients.param_shift, {}, dev)
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
):
res = QNode.get_best_method(dev, None)
assert res == (qml.gradients.param_shift, {}, dev)

# no interface - fall back on parameter-shift
dev2 = qml.device("default.qubit", wires=1)
tape = qml.tape.QuantumScript([], [], shots=50)
res2 = QNode.get_best_method(dev2, None, tape=tape)
assert res2 == (qml.gradients.param_shift, {}, dev2)
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
):
res2 = QNode.get_best_method(dev2, None, tape=tape)
assert res2 == (qml.gradients.param_shift, {}, dev2)

# pylint: disable=protected-access
@pytest.mark.xfail(
Expand All @@ -230,8 +259,11 @@ def capabilities(cls):

# finite differences is the fallback when we know nothing about the device
monkeypatch.setattr(qml.devices.DefaultMixed, "capabilities", capabilities)
res = QNode.get_best_method(dev, "another_interface")
assert res == (qml.gradients.finite_diff, {}, dev)
with pytest.warns(
qml.PennyLaneDeprecationWarning, match=r"QNode.get_best_method is deprecated"
):
res = QNode.get_best_method(dev, "another_interface")
assert res == (qml.gradients.finite_diff, {}, dev)

# pylint: disable=protected-access, too-many-statements
def test_diff_method(self):
Expand Down Expand Up @@ -1462,6 +1494,13 @@ def test_get_gradient_fn_custom_device(self):
assert not kwargs
assert new_dev is self.dev

def test_get_gradient_fn_with_best_method_and_cv_ops(self):
"""Test that get_gradient_fn returns 'parameter-shift-cv' when CV operations are present on tape"""
dev = qml.device("default.gaussian", wires=1)
tape = qml.tape.QuantumScript([qml.Displacement(0.5, 0.0, wires=0)])
res = qml.QNode.get_gradient_fn(dev, interface="autograd", diff_method="best", tape=tape)
assert res == (qml.gradients.param_shift_cv, {"dev": dev}, dev)

def test_get_gradient_fn_default_qubit(self):
"""Tests the get_gradient_fn is backprop for best for default qubit2."""
dev = qml.devices.DefaultQubit()
Expand Down
61 changes: 48 additions & 13 deletions tests/test_qnode_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,10 @@ def test_best_method_wraps_legacy_device_correctly(self, mocker):

spy = mocker.spy(qml.devices.LegacyDeviceFacade, "__init__")

QNode.get_best_method(dev_legacy, "some_interface")
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
):
QNode.get_best_method(dev_legacy, "some_interface")

spy.assert_called_once()

Expand All @@ -133,11 +136,17 @@ def test_best_method_is_device(self, monkeypatch):
monkeypatch.setitem(dev._capabilities, "provides_jacobian", True)

# basic check if the device provides a Jacobian
res = QNode.get_best_method(dev, "another_interface")
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
):
res = QNode.get_best_method(dev, "another_interface")
assert res == ("device", {}, dev)

# device is returned even if backpropagation is possible
res = QNode.get_best_method(dev, "some_interface")
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
):
res = QNode.get_best_method(dev, "some_interface")
assert res == ("device", {}, dev)

# pylint: disable=protected-access
Expand All @@ -148,7 +157,10 @@ def test_best_method_is_backprop(self, interface):
dev = qml.device("default.mixed", wires=1)

# backprop is returned when the interface is an allowed interface for the device and Jacobian is not provided
res = QNode.get_best_method(dev, interface)
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
):
res = QNode.get_best_method(dev, interface)
assert res == ("backprop", {}, dev)

# pylint: disable=protected-access
Expand All @@ -158,7 +170,10 @@ def test_best_method_is_param_shift(self):
dev = qml.device("default.mixed", wires=1)

tape = qml.tape.QuantumScript([], [], shots=50)
res = QNode.get_best_method(dev, None, tape=tape)
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
):
res = QNode.get_best_method(dev, None, tape=tape)
assert res == (qml.gradients.param_shift, {}, dev)

# pylint: disable=protected-access
Expand All @@ -178,8 +193,10 @@ def capabilities(cls):
dev = qml.device("default.mixed", wires=1)
monkeypatch.setitem(dev._capabilities, "passthru_interface", "some_interface")
monkeypatch.setitem(dev._capabilities, "provides_jacobian", False)

res = QNode.get_best_method(dev, "another_interface")
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
):
res = QNode.get_best_method(dev, "another_interface")
assert res == (qml.gradients.finite_diff, {}, dev)

# pylint: disable=protected-access
Expand All @@ -191,11 +208,17 @@ def test_best_method_str_is_device(self, monkeypatch):
monkeypatch.setitem(dev._capabilities, "provides_jacobian", True)

# basic check if the device provides a Jacobian
res = QNode.best_method_str(dev, "another_interface")
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.best_method_str is deprecated"
):
res = QNode.best_method_str(dev, "another_interface")
assert res == "device"

# device is returned even if backpropagation is possible
res = QNode.best_method_str(dev, "some_interface")
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.best_method_str is deprecated"
):
res = QNode.best_method_str(dev, "some_interface")
assert res == "device"

# pylint: disable=protected-access
Expand All @@ -207,15 +230,21 @@ def test_best_method_str_is_backprop(self, monkeypatch):
monkeypatch.setitem(dev._capabilities, "provides_jacobian", False)

# backprop is returned when the interfaces match and Jacobian is not provided
res = QNode.best_method_str(dev, "some_interface")
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.best_method_str is deprecated"
):
res = QNode.best_method_str(dev, "some_interface")
assert res == "backprop"

def test_best_method_str_wraps_legacy_device_correctly(self, mocker):
dev_legacy = DefaultQubitLegacy(wires=2)

spy = mocker.spy(qml.devices.LegacyDeviceFacade, "__init__")

QNode.best_method_str(dev_legacy, "some_interface")
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.best_method_str is deprecated"
):
QNode.best_method_str(dev_legacy, "some_interface")

spy.assert_called_once()

Expand All @@ -227,7 +256,10 @@ def test_best_method_str_is_param_shift(self):

# parameter shift is returned when Jacobian is not provided and
# the backprop interfaces do not match
res = QNode.best_method_str(dev, "another_interface")
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.best_method_str is deprecated"
):
res = QNode.best_method_str(dev, "another_interface")
assert res == "parameter-shift"

# pylint: disable=protected-access
Expand All @@ -238,7 +270,10 @@ def test_best_method_str_is_finite_diff(self, mocker):

mocker.patch.object(QNode, "get_best_method", return_value=[qml.gradients.finite_diff])

res = QNode.best_method_str(dev, "another_interface")
with pytest.warns(
qml.PennyLaneDeprecationWarning, match="QNode.best_method_str is deprecated"
):
res = QNode.best_method_str(dev, "another_interface")

assert res == "finite-diff"

Expand Down
Loading

0 comments on commit eefe6ef

Please sign in to comment.