diff --git a/doc/development/deprecations.rst b/doc/development/deprecations.rst
index 10487b67eeb..9e2d701c582 100644
--- a/doc/development/deprecations.rst
+++ b/doc/development/deprecations.rst
@@ -9,6 +9,12 @@ deprecations are listed below.
Pending deprecations
--------------------
+* The ``QNode.get_best_method`` and ``QNode.best_method_str`` methods have been deprecated.
+ Instead, use the ``qml.workflow.get_best_diff_method``.
+
+ - Deprecated in v0.40
+ - Will be removed in v0.41
+
* The ``gradient_fn`` keyword argument to ``qml.execute`` has been renamed ``diff_method``.
- Deprecated in v0.40
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 81105b333c4..7c3aabdfa2a 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -89,6 +89,10 @@
Deprecations 👋
+* The `QNode.get_best_method` and `QNode.best_method_str` methods have been deprecated.
+ Instead, use the `qml.workflow.get_best_diff_method` function.
+ [(#6418)](https://github.com/PennyLaneAI/pennylane/pull/6418)
+
* The `qml.execute` `gradient_fn` keyword argument has been renamed `diff_method`,
to better align with the termionology used by the `QNode`.
`gradient_fn` will be removed in v0.41.
diff --git a/pennylane/workflow/get_gradient_fn.py b/pennylane/workflow/get_gradient_fn.py
index 2359ead6424..15a448c8a1e 100644
--- a/pennylane/workflow/get_gradient_fn.py
+++ b/pennylane/workflow/get_gradient_fn.py
@@ -18,6 +18,7 @@
from typing import Optional, get_args
import pennylane as qml
+from pennylane.logging import debug_logger
from pennylane.transforms.core import TransformDispatcher
from pennylane.workflow.qnode import (
SupportedDeviceAPIs,
@@ -27,6 +28,7 @@
# pylint: disable=too-many-return-statements, unsupported-binary-operation
+@debug_logger
def _get_gradient_fn(
device: SupportedDeviceAPIs,
diff_method: "TransformDispatcher | SupportedDiffMethods" = "best",
@@ -61,8 +63,10 @@ def _get_gradient_fn(
)
if diff_method == "best":
- qn = qml.QNode(lambda: None, device, diff_method=None)
- return qml.workflow.get_best_diff_method(qn)()
+ if tape and any(isinstance(o, qml.operation.CV) for o in tape):
+ return qml.gradients.param_shift_cv
+
+ return qml.gradients.param_shift
if diff_method == "parameter-shift":
if tape and any(isinstance(o, qml.operation.CV) and o.name != "Identity" for o in tape):
diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py
index 69e65200a40..1e6a28e08db 100644
--- a/pennylane/workflow/qnode.py
+++ b/pennylane/workflow/qnode.py
@@ -680,7 +680,10 @@ def get_gradient_fn(
)
if diff_method == "best":
- return QNode.get_best_method(device, interface, tape=tape)
+ if tape and any(isinstance(o, qml.operation.CV) for o in tape):
+ return qml.gradients.param_shift_cv, {"dev": device}, device
+
+ return qml.gradients.param_shift, {}, device
if diff_method == "parameter-shift":
if tape and any(isinstance(o, qml.operation.CV) and o.name != "Identity" for o in tape):
@@ -720,7 +723,13 @@ def get_best_method(
dict[str, Any],
SupportedDeviceAPIs,
]:
- """Returns the 'best' differentiation method
+ """
+ .. warning::
+
+ This method is deprecated in v0.40 and will be removed in v0.41.
+ Instead, use the :func:`qml.workflow.get_best_diff_method <.workflow.get_best_diff_method>` function.
+
+ Returns the 'best' differentiation method
for a particular device and interface combination.
This method attempts to determine support for differentiation
@@ -744,6 +753,13 @@ def get_best_method(
tuple[str or .TransformDispatcher, dict, .device.Device: Tuple containing the ``gradient_fn``,
``gradient_kwargs``, and the device to use when calling the execute function.
"""
+
+ warnings.warn(
+ "QNode.get_best_method is deprecated and will be removed in v0.41. "
+ "Instead, use the qml.workflow.get_best_diff_method function.",
+ qml.PennyLaneDeprecationWarning,
+ )
+
if not isinstance(device, qml.devices.Device):
device = qml.devices.LegacyDeviceFacade(device)
@@ -761,7 +777,14 @@ def get_best_method(
@staticmethod
@debug_logger
def best_method_str(device: SupportedDeviceAPIs, interface: SupportedInterfaceUserInput) -> str:
- """Similar to :meth:`~.get_best_method`, except return the
+ """
+ .. warning::
+
+ This method is deprecated in v0.40 and will be removed in v0.41.
+ Instead, use the :func:`qml.workflow.get_best_diff_method <.workflow.get_best_diff_method>` function.
+
+
+ Similar to :meth:`~.get_best_method`, except return the
'best' differentiation method in human-readable format.
This method attempts to determine support for differentiation
@@ -786,9 +809,19 @@ def best_method_str(device: SupportedDeviceAPIs, interface: SupportedInterfaceUs
Returns:
str: The gradient function to use in human-readable format.
"""
+
+ warnings.warn(
+ "QNode.best_method_str is deprecated and will be removed in v0.41. "
+ "Instead, use the qml.workflow.get_best_diff_method function.",
+ qml.PennyLaneDeprecationWarning,
+ )
+
if not isinstance(device, qml.devices.Device):
device = qml.devices.LegacyDeviceFacade(device)
+ warnings.filterwarnings(
+ "ignore", "QNode.get_best_method is deprecated", qml.PennyLaneDeprecationWarning
+ )
transform = QNode.get_best_method(device, interface)[0]
if transform is qml.gradients.finite_diff:
diff --git a/tests/test_qnode.py b/tests/test_qnode.py
index f56cdfd036c..2a7825611cb 100644
--- a/tests/test_qnode.py
+++ b/tests/test_qnode.py
@@ -37,6 +37,20 @@ def dummyfunc():
return None
+def test_get_best_method_is_deprecated():
+ """Test that is deprecated."""
+ with pytest.warns(qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"):
+ dev = qml.device("default.qubit", wires=2)
+ _ = QNode.get_best_method(dev, "jax")
+
+
+def test_best_method_str_is_deprecated():
+ """Test that is deprecated."""
+ with pytest.warns(qml.PennyLaneDeprecationWarning, match="QNode.best_method_str is deprecated"):
+ dev = qml.device("default.qubit", wires=2)
+ _ = QNode.best_method_str(dev, "jax")
+
+
# pylint: disable=unused-argument
class CustomDevice(qml.devices.Device):
"""A null device that just returns 0."""
@@ -177,11 +191,17 @@ def test_best_method_is_device(self):
dev = CustomDeviceWithDiffMethod()
- res = QNode.get_best_method(dev, "jax")
- assert res == ("device", {}, dev)
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
+ ):
+ res = QNode.get_best_method(dev, "jax")
+ assert res == ("device", {}, dev)
- res = QNode.get_best_method(dev, None)
- assert res == ("device", {}, dev)
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
+ ):
+ res = QNode.get_best_method(dev, None)
+ assert res == ("device", {}, dev)
# pylint: disable=protected-access
@pytest.mark.parametrize("interface", ["jax", "tensorflow", "torch", "autograd"])
@@ -192,8 +212,11 @@ def test_best_method_is_backprop(self, interface):
dev = qml.device("default.qubit", wires=1)
# backprop is returned when the interface is an allowed interface for the device and Jacobian is not provided
- res = QNode.get_best_method(dev, interface)
- assert res == ("backprop", {}, dev)
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
+ ):
+ res = QNode.get_best_method(dev, interface)
+ assert res == ("backprop", {}, dev)
# pylint: disable=protected-access
def test_best_method_is_param_shift(self):
@@ -203,14 +226,20 @@ def test_best_method_is_param_shift(self):
# null device has no info - fall back on parameter-shift
dev = CustomDevice()
- res = QNode.get_best_method(dev, None)
- assert res == (qml.gradients.param_shift, {}, dev)
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
+ ):
+ res = QNode.get_best_method(dev, None)
+ assert res == (qml.gradients.param_shift, {}, dev)
# no interface - fall back on parameter-shift
dev2 = qml.device("default.qubit", wires=1)
tape = qml.tape.QuantumScript([], [], shots=50)
- res2 = QNode.get_best_method(dev2, None, tape=tape)
- assert res2 == (qml.gradients.param_shift, {}, dev2)
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
+ ):
+ res2 = QNode.get_best_method(dev2, None, tape=tape)
+ assert res2 == (qml.gradients.param_shift, {}, dev2)
# pylint: disable=protected-access
@pytest.mark.xfail(
@@ -230,8 +259,11 @@ def capabilities(cls):
# finite differences is the fallback when we know nothing about the device
monkeypatch.setattr(qml.devices.DefaultMixed, "capabilities", capabilities)
- res = QNode.get_best_method(dev, "another_interface")
- assert res == (qml.gradients.finite_diff, {}, dev)
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match=r"QNode.get_best_method is deprecated"
+ ):
+ res = QNode.get_best_method(dev, "another_interface")
+ assert res == (qml.gradients.finite_diff, {}, dev)
# pylint: disable=protected-access, too-many-statements
def test_diff_method(self):
@@ -1462,6 +1494,13 @@ def test_get_gradient_fn_custom_device(self):
assert not kwargs
assert new_dev is self.dev
+ def test_get_gradient_fn_with_best_method_and_cv_ops(self):
+ """Test that get_gradient_fn returns 'parameter-shift-cv' when CV operations are present on tape"""
+ dev = qml.device("default.gaussian", wires=1)
+ tape = qml.tape.QuantumScript([qml.Displacement(0.5, 0.0, wires=0)])
+ res = qml.QNode.get_gradient_fn(dev, interface="autograd", diff_method="best", tape=tape)
+ assert res == (qml.gradients.param_shift_cv, {"dev": dev}, dev)
+
def test_get_gradient_fn_default_qubit(self):
"""Tests the get_gradient_fn is backprop for best for default qubit2."""
dev = qml.devices.DefaultQubit()
diff --git a/tests/test_qnode_legacy.py b/tests/test_qnode_legacy.py
index aa9c8e75f72..f620f51b9e6 100644
--- a/tests/test_qnode_legacy.py
+++ b/tests/test_qnode_legacy.py
@@ -117,7 +117,10 @@ def test_best_method_wraps_legacy_device_correctly(self, mocker):
spy = mocker.spy(qml.devices.LegacyDeviceFacade, "__init__")
- QNode.get_best_method(dev_legacy, "some_interface")
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
+ ):
+ QNode.get_best_method(dev_legacy, "some_interface")
spy.assert_called_once()
@@ -133,11 +136,17 @@ def test_best_method_is_device(self, monkeypatch):
monkeypatch.setitem(dev._capabilities, "provides_jacobian", True)
# basic check if the device provides a Jacobian
- res = QNode.get_best_method(dev, "another_interface")
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
+ ):
+ res = QNode.get_best_method(dev, "another_interface")
assert res == ("device", {}, dev)
# device is returned even if backpropagation is possible
- res = QNode.get_best_method(dev, "some_interface")
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
+ ):
+ res = QNode.get_best_method(dev, "some_interface")
assert res == ("device", {}, dev)
# pylint: disable=protected-access
@@ -148,7 +157,10 @@ def test_best_method_is_backprop(self, interface):
dev = qml.device("default.mixed", wires=1)
# backprop is returned when the interface is an allowed interface for the device and Jacobian is not provided
- res = QNode.get_best_method(dev, interface)
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
+ ):
+ res = QNode.get_best_method(dev, interface)
assert res == ("backprop", {}, dev)
# pylint: disable=protected-access
@@ -158,7 +170,10 @@ def test_best_method_is_param_shift(self):
dev = qml.device("default.mixed", wires=1)
tape = qml.tape.QuantumScript([], [], shots=50)
- res = QNode.get_best_method(dev, None, tape=tape)
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
+ ):
+ res = QNode.get_best_method(dev, None, tape=tape)
assert res == (qml.gradients.param_shift, {}, dev)
# pylint: disable=protected-access
@@ -178,8 +193,10 @@ def capabilities(cls):
dev = qml.device("default.mixed", wires=1)
monkeypatch.setitem(dev._capabilities, "passthru_interface", "some_interface")
monkeypatch.setitem(dev._capabilities, "provides_jacobian", False)
-
- res = QNode.get_best_method(dev, "another_interface")
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.get_best_method is deprecated"
+ ):
+ res = QNode.get_best_method(dev, "another_interface")
assert res == (qml.gradients.finite_diff, {}, dev)
# pylint: disable=protected-access
@@ -191,11 +208,17 @@ def test_best_method_str_is_device(self, monkeypatch):
monkeypatch.setitem(dev._capabilities, "provides_jacobian", True)
# basic check if the device provides a Jacobian
- res = QNode.best_method_str(dev, "another_interface")
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.best_method_str is deprecated"
+ ):
+ res = QNode.best_method_str(dev, "another_interface")
assert res == "device"
# device is returned even if backpropagation is possible
- res = QNode.best_method_str(dev, "some_interface")
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.best_method_str is deprecated"
+ ):
+ res = QNode.best_method_str(dev, "some_interface")
assert res == "device"
# pylint: disable=protected-access
@@ -207,7 +230,10 @@ def test_best_method_str_is_backprop(self, monkeypatch):
monkeypatch.setitem(dev._capabilities, "provides_jacobian", False)
# backprop is returned when the interfaces match and Jacobian is not provided
- res = QNode.best_method_str(dev, "some_interface")
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.best_method_str is deprecated"
+ ):
+ res = QNode.best_method_str(dev, "some_interface")
assert res == "backprop"
def test_best_method_str_wraps_legacy_device_correctly(self, mocker):
@@ -215,7 +241,10 @@ def test_best_method_str_wraps_legacy_device_correctly(self, mocker):
spy = mocker.spy(qml.devices.LegacyDeviceFacade, "__init__")
- QNode.best_method_str(dev_legacy, "some_interface")
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.best_method_str is deprecated"
+ ):
+ QNode.best_method_str(dev_legacy, "some_interface")
spy.assert_called_once()
@@ -227,7 +256,10 @@ def test_best_method_str_is_param_shift(self):
# parameter shift is returned when Jacobian is not provided and
# the backprop interfaces do not match
- res = QNode.best_method_str(dev, "another_interface")
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.best_method_str is deprecated"
+ ):
+ res = QNode.best_method_str(dev, "another_interface")
assert res == "parameter-shift"
# pylint: disable=protected-access
@@ -238,7 +270,10 @@ def test_best_method_str_is_finite_diff(self, mocker):
mocker.patch.object(QNode, "get_best_method", return_value=[qml.gradients.finite_diff])
- res = QNode.best_method_str(dev, "another_interface")
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="QNode.best_method_str is deprecated"
+ ):
+ res = QNode.best_method_str(dev, "another_interface")
assert res == "finite-diff"
diff --git a/tests/workflow/test_get_gradient_fn.py b/tests/workflow/test_get_gradient_fn.py
index 914020efd1e..c595b97a7ad 100644
--- a/tests/workflow/test_get_gradient_fn.py
+++ b/tests/workflow/test_get_gradient_fn.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Unit tests for the `qml.workflow.get_gradient_fn`"""
+"""Unit tests for the `qml.workflow._get_gradient_fn`"""
import pytest
@@ -79,7 +79,7 @@ class TestCustomDeviceIntegration:
def test_get_gradient_fn_custom_device(self):
"""Test get_gradient_fn is parameter for best for null device."""
gradient_fn = _get_gradient_fn(self.dev, "best")
- assert gradient_fn == "parameter-shift"
+ assert gradient_fn is qml.gradients.param_shift
def test_get_gradient_fn_custom_dev_adjoint(self):
"""Test that an error is raised if adjoint is requested for a device that does not support it."""
@@ -150,32 +150,39 @@ def test_finite_diff_method(self):
"""Test that get_gradient_fn returns 'finite-diff' for the 'finite-diff' method"""
dev = qml.device("default.qubit", wires=1)
gradient_fn = _get_gradient_fn(dev, diff_method="finite-diff")
- assert gradient_fn == qml.gradients.finite_diff
+ assert gradient_fn is qml.gradients.finite_diff
def test_spsa_method(self):
"""Test that get_gradient_fn returns 'spsa' for the 'spsa' method"""
dev = qml.device("default.qubit", wires=1)
gradient_fn = _get_gradient_fn(dev, diff_method="spsa")
- assert gradient_fn == qml.gradients.spsa_grad
+ assert gradient_fn is qml.gradients.spsa_grad
def test_hadamard_method(self):
"""Test that get_gradient_fn returns 'hadamard' for the 'hadamard' method"""
dev = qml.device("default.qubit", wires=1)
gradient_fn = _get_gradient_fn(dev, diff_method="hadamard")
- assert gradient_fn == qml.gradients.hadamard_grad
+ assert gradient_fn is qml.gradients.hadamard_grad
def test_param_shift_method(self):
"""Test that get_gradient_fn returns 'parameter-shift' for the 'parameter-shift' method"""
dev = qml.device("default.qubit", wires=1)
gradient_fn = _get_gradient_fn(dev, diff_method="parameter-shift")
- assert gradient_fn == qml.gradients.param_shift
+ assert gradient_fn is qml.gradients.param_shift
def test_param_shift_method_with_cv_ops(self):
"""Test that get_gradient_fn returns 'parameter-shift-cv' when CV operations are present on tape"""
dev = qml.device("default.gaussian", wires=1)
tape = qml.tape.QuantumScript([qml.Displacement(0.5, 0.0, wires=0)])
gradient_fn = _get_gradient_fn(dev, diff_method="parameter-shift", tape=tape)
- assert gradient_fn == qml.gradients.param_shift_cv
+ assert gradient_fn is qml.gradients.param_shift_cv
+
+ def test_best_method_with_cv_ops(self):
+ """Test that get_gradient_fn returns 'parameter-shift-cv' when CV operations are present on tape"""
+ dev = qml.device("default.gaussian", wires=1)
+ tape = qml.tape.QuantumScript([qml.Displacement(0.5, 0.0, wires=0)])
+ gradient_fn = _get_gradient_fn(dev, diff_method="best", tape=tape)
+ assert gradient_fn is qml.gradients.param_shift_cv
def test_invalid_diff_method(self):
"""Test that get_gradient_fn raises an error for invalid diff method"""