Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-level dynamic decompositions #6881

Draft
wants to merge 10 commits into
base: cond_dynamic_decomp
Choose a base branch
from
130 changes: 36 additions & 94 deletions pennylane/transforms/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import pennylane as qml
from pennylane.transforms.core import transform

from pennylane.capture.autograph.transformer import run_autograph


def null_postprocessing(results):
"""A postprocessing function returned by a transform that only converts the batch of results
Expand Down Expand Up @@ -65,9 +67,13 @@ def _operator_decomposition_gen(
def _get_plxpr_decompose(): # pylint: disable=missing-docstring
try:
# pylint: disable=import-outside-toplevel
from jax import make_jaxpr
import jax

from pennylane.capture.primitives import ctrl_transform_prim
from pennylane.capture.primitives import (
ctrl_transform_prim,
AbstractMeasurement,
AbstractOperator,
)
except ImportError: # pragma: no cover
return None, None

Expand All @@ -78,8 +84,9 @@ class DecomposeInterpreter(qml.capture.PlxprInterpreter):
when program capture is enabled.
"""

def __init__(self, gate_set=None, max_expansion=None):
def __init__(self, gate_set=None, max_expansion=None, dynamic_decomposition=False):
self.max_expansion = max_expansion
self.dynamic_decomposition = dynamic_decomposition

if gate_set is None:
gate_set = set(qml.ops.__all__)
Expand Down Expand Up @@ -136,10 +143,13 @@ def decompose_operation(self, op: qml.operation.Operator):
return self.interpret_operation(op)

qml.capture.disable()

try:
decomposition = list(
_operator_decomposition_gen(
op, self.stopping_condition, max_expansion=self.max_expansion
op,
self.stopping_condition,
max_expansion=self.max_expansion,
)
)
finally:
Expand All @@ -159,8 +169,27 @@ def interpret_operation_eqn(self, eqn):
invals = (self.read(invar) for invar in eqn.invars)
with qml.QueuingManager.stop_recording():
op = eqn.primitive.impl(*invals, **eqn.params)
if eqn.outvars[0].__class__.__name__ == "DropVar":
return self.decompose_operation(op)

if isinstance(eqn.outvars[0], jax.core.DropVar):

if not self.dynamic_decomposition:

return self.decompose_operation(op)

else:

if self.gate_set(op):
return self.interpret_operation(op)

args = (*op.parameters, *op.wires)
return_ops_decomposition = run_autograph(op.compute_decomposition)(
*args, **op.hyperparameters
)

if return_ops_decomposition is not None:
for sub_op in return_ops_decomposition:
self.interpret_operation(sub_op)

return op

# pylint: disable=unused-variable,missing-function-docstring
Expand All @@ -179,101 +208,14 @@ def decompose_plxpr_to_plxpr(
def wrapper(*inner_args):
return decomposer.eval(jaxpr, consts, *inner_args)

return make_jaxpr(wrapper)(*args)
return jax.make_jaxpr(wrapper)(*args)

return DecomposeInterpreter, decompose_plxpr_to_plxpr


DecomposeInterpreter, decompose_plxpr_to_plxpr = _get_plxpr_decompose()


@lru_cache
def _get_plxpr_dynamic_decompose(): # pylint: disable=missing-docstring
try:
# pylint: disable=import-outside-toplevel
# pylint: disable=unused-import
import jax

from pennylane.capture.primitives import AbstractMeasurement, AbstractOperator
except ImportError: # pragma: no cover
return None, None

# pylint: disable=redefined-outer-name

class DynamicDecomposeInterpreter(qml.capture.PlxprInterpreter):
"""
Experimental Plxpr Interpreter for applying a dynamic decomposition to operations when program capture is enabled.
"""

def eval_dynamic_decomposition(self, jaxpr_decomp: "jax.core.Jaxpr", consts, *args):
"""
Evaluate a dynamic decomposition of a Jaxpr.

Args:
jaxpr_decomp (jax.core.Jaxpr): the Jaxpr to evaluate
*args: the arguments to use in the evaluation
"""

for arg, invar in zip(args, jaxpr_decomp.invars, strict=True):
self._env[invar] = arg

for const, constvar in zip(consts, jaxpr_decomp.constvars, strict=True):
self._env[constvar] = const

for inner_eqn in jaxpr_decomp.eqns:

custom_handler = self._primitive_registrations.get(inner_eqn.primitive, None)

if custom_handler:
invals = [self.read(invar) for invar in inner_eqn.invars]
outvals = custom_handler(self, *invals, **inner_eqn.params)

elif isinstance(inner_eqn.outvars[0].aval, AbstractOperator):
# This does not currently support nested decompositions
outvals = super().interpret_operation_eqn(inner_eqn)
elif isinstance(inner_eqn.outvars[0].aval, AbstractMeasurement):
outvals = super().interpret_measurement_eqn(inner_eqn)
else:
invals = [self.read(invar) for invar in inner_eqn.invars]
outvals = inner_eqn.primitive.bind(*invals, **inner_eqn.params)

if not inner_eqn.primitive.multiple_results:
outvals = [outvals]

for inner_outvar, inner_outval in zip(inner_eqn.outvars, outvals, strict=True):
self._env[inner_outvar] = inner_outval

def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"):
"""
Interpret an equation corresponding to an operator.

Args:
eqn (jax.core.JaxprEqn): a jax equation for an operator.
"""

invals = (self.read(invar) for invar in eqn.invars)
with qml.QueuingManager.stop_recording():
op = eqn.primitive.impl(*invals, **eqn.params)

if isinstance(eqn.outvars[0], jax.core.DropVar):

if op._has_plxpr_decomposition:
jaxpr_decomp = op._plxpr_decomposition()
args = (*op.parameters, *op.wires)
return self.eval_dynamic_decomposition(
jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args
)

return super().interpret_operation(op)

return op

return DynamicDecomposeInterpreter


DynamicDecomposeInterpreter = _get_plxpr_dynamic_decompose()


@partial(transform, plxpr_transform=decompose_plxpr_to_plxpr)
def decompose(tape, gate_set=None, max_expansion=None):
"""Decomposes a quantum circuit into a user-specified gate set.
Expand Down
Loading