Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pulse template creation helpers #711

Merged
merged 9 commits into from
Dec 5, 2022
2 changes: 2 additions & 0 deletions changes.d/710.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add `with_` family of helper methods to `PulseTemplate` to allow convinient and easily discoverable pulse template
combination.
23 changes: 23 additions & 0 deletions qupulse/pulses/multi_channel_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@ def __init__(self,

self._register(registry=registry)

def with_parallel_atomic(self, *parallel: 'AtomicPulseTemplate') -> 'AtomicPulseTemplate':
from qupulse.pulses import AtomicMultiChannelPT
if parallel:
if self.identifier:
return AtomicMultiChannelPT(self, *parallel)
else:
return AtomicMultiChannelPT(
*self._subtemplates, *parallel,
measurements=self.measurement_declarations,
parameter_constraints=self.parameter_constraints,
)
else:
return self

@property
def duration(self) -> ExpressionScalar:
if self._duration is None:
Expand Down Expand Up @@ -334,6 +348,15 @@ def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[
data['overwritten_channels'] = self._overwritten_channels
return data

def with_parallel_channels(self, values: Mapping[ChannelID, ExpressionLike]) -> 'PulseTemplate':
if self.identifier:
return super().with_parallel_channels(values)
else:
return ParallelConstantChannelPulseTemplate(
self._template,
{**self._overwritten_channels, **values},
)

def _is_atomic(self) -> bool:
return self._template._is_atomic()

Expand Down
5 changes: 4 additions & 1 deletion qupulse/pulses/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,13 @@ def get_serialization_data(self) -> str:
return str(self)


ConstraintLike = Union[sympy.Expr, str, ParameterConstraint]


class ParameterConstrainer:
"""A class that implements the testing of parameter constraints. It is used by the subclassing pulse templates."""
def __init__(self, *,
parameter_constraints: Optional[Iterable[Union[str, ParameterConstraint]]]) -> None:
parameter_constraints: Optional[Iterable[ConstraintLike]]) -> None:
if parameter_constraints is None:
self._parameter_constraints = []
else:
Expand Down
117 changes: 117 additions & 0 deletions qupulse/pulses/pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,116 @@ def _create_program(self, *,
global_transformation=global_transformation,
parent_loop=parent_loop)

def with_parallel_channels(self, values: Mapping[ChannelID, ExpressionLike]) -> 'PulseTemplate':
"""Create a new pulse template that sets the given channels to the corresponding values.

See :class:`~qupulse.pulses.ParallelChannelPulseTemplate` for implementation details and restictions.

Examples:
>>> from qupulse.pulses import FunctionPT
... fpt = FunctionPT('sin(0.1 * t)', duration_expression=10)
... fpt_and_marker = fpt.with_parallel_channels({'marker': 1})

Args:
values: Values to be set for each channel.

Returns:
A newly created pulse template.
"""
from qupulse.pulses.multi_channel_pulse_template import ParallelChannelPulseTemplate
return ParallelChannelPulseTemplate(
self,
values
)

def with_repetition(self, repetition_count: ExpressionLike) -> 'PulseTemplate':
"""Repeat this pulse template `repetition_count` times via a :class:`~qupulse.pulses.RepetitionPulseTemplate`.

Examples:
>>> from qupulse.pulses import FunctionPT
... fpt = FunctionPT('sin(0.1 * t)', duration_expression=10)
... repeated = fpt.with_repetition('n_periods')

Args:
repetition_count: Amount of times this pulse template is repeated in the return value.

Returns:
A newly created pulse template.
"""
from qupulse.pulses.repetition_pulse_template import RepetitionPulseTemplate
return RepetitionPulseTemplate(self, repetition_count)

def with_mapping(self, *mapping_tuple_args: Mapping, **mapping_kwargs: Mapping) -> 'PulseTemplate':
"""Map parameters / channel names / measurement names. You may either specify the mappings as positional
arguments XOR as keyword arguments. Positional arguments are forwarded to
:func:`~qupulse.pulses.MappingPT.from_tuple` which automatically determines the "type" of the mappings.
Keyword arguments must be one of the keyword arguments of :class:`~qupulse.pulses.MappingPT`.

Args:
*mapping_tuple_args: Mappings for parameters / channel names / measurement names
**mapping_kwargs: Mappings for parameters / channel names / measurement names

Examples:
Equivalent ways to rename a channel and map a parameter value
>>> from qupulse.pulses import FunctionPT
... fpt = FunctionPT('sin(f * t)', duration_expression=10, channel='A')
... mapped = fpt.with_mapping({'f': 0.1}, {'A': 'B'})
... mapped.defined_channels
{'B'}

>>> from qupulse.pulses import FunctionPT
... fpt = FunctionPT('sin(f * t)', duration_expression=10, channel='A')
... mapped = fpt.with_mapping(parameter_mapping={'f': 0.1}, channel_mapping={'A': 'B'})
... mapped.defined_channels
{'B'}

Returns:
A newly created mapping pulse template
"""
from qupulse.pulses import MappingPT

if mapping_kwargs and mapping_tuple_args:
raise ValueError("Only positional argument (auto detection of mapping type) "
"xor keyword arguments are allowed.")
if mapping_tuple_args:
return MappingPT.from_tuple((self, *mapping_tuple_args))
else:
return MappingPT(self, **mapping_kwargs)

def with_iteration(self, loop_idx: str, loop_range) -> 'PulseTemplate':
"""Create a :class:`~qupulse.pulses.ForLoopPT` with the given index and range.

Examples:
>>> from qupulse.pulses import ConstantPT
... const = ConstantPT('t_hold', {'x': 'start_x + i_x * step_x', 'y': 'start_y + i_y * step_y'})
... scan_2d = const.with_iteration('i_x', 'n_x').with_iteration('i_y', 'n_y')
"""
from qupulse.pulses import ForLoopPT
return ForLoopPT(self, loop_idx, loop_range)

def with_time_reversal(self) -> 'PulseTemplate':
"""Reverse this pulse template by creating a :class:`~qupulse.pulses.TimeReversalPT`.

Examples:
>>> from qupulse.pulses import FunctionPT
... forward = FunctionPT('sin(f * t)', duration_expression=10, channel='A')
... backward = fpt.with_time_reversal()
... forward_and_backward = forward @ backward
"""
from qupulse.pulses import TimeReversalPT
return TimeReversalPT(self)

def with_appended(self, *appended: 'PulseTemplate'):
"""Create a :class:`~qupulse.pulses.SequencePT` that represents a sequence of this pulse template and `appended`

You can also use the `@` operator to do this or call :func:`qupulse.pulses.SequencePT.concatenate` directly.
"""
from qupulse.pulses import SequencePT
if appended:
return SequencePT.concatenate(self, *appended)
else:
return self

def __format__(self, format_spec: str):
if format_spec == '':
format_spec = self._DEFAULT_FORMAT_SPEC
Expand Down Expand Up @@ -322,6 +432,13 @@ def __init__(self, *,
PulseTemplate.__init__(self, identifier=identifier)
MeasurementDefiner.__init__(self, measurements=measurements)

def with_parallel_atomic(self, *parallel: 'AtomicPulseTemplate') -> 'AtomicPulseTemplate':
from qupulse.pulses import AtomicMultiChannelPT
if parallel:
return AtomicMultiChannelPT(self, *parallel)
else:
return self

@property
def atomicity(self) -> bool:
warnings.warn("Deprecated since neither maintained nor properly designed.", category=DeprecationWarning)
Expand Down
11 changes: 11 additions & 0 deletions qupulse/pulses/repetition_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,17 @@ def __init__(self,

self._register(registry=registry)

def with_repetition(self, repetition_count: Union[int, str, ExpressionScalar]) -> 'PulseTemplate':
if self.identifier:
return RepetitionPulseTemplate(self, repetition_count)
else:
return RepetitionPulseTemplate(
self.body,
self.repetition_count * repetition_count,
parameter_constraints=self.parameter_constraints,
measurements=self.measurement_declarations
)

@property
def repetition_count(self) -> ExpressionScalar:
"""The amount of repetitions. Either a constant integer or a ParameterDeclaration object."""
Expand Down
6 changes: 3 additions & 3 deletions qupulse/pulses/sequence_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
combines several other PulseTemplate objects for sequential execution."""

import numpy as np
from typing import Dict, List, Set, Optional, Any, AbstractSet, Union, Callable, cast
from typing import Dict, List, Set, Optional, Any, AbstractSet, Union, Callable, cast, Iterable
from numbers import Real
import functools
import warnings
Expand All @@ -13,7 +13,7 @@
from qupulse.utils import cached_property
from qupulse.utils.types import MeasurementWindow, ChannelID, TimeType
from qupulse.pulses.pulse_template import PulseTemplate, AtomicPulseTemplate
from qupulse.pulses.parameters import Parameter, ParameterConstrainer, ParameterNotProvidedException
from qupulse.pulses.parameters import ConstraintLike, ParameterConstrainer
from qupulse.pulses.mapping_pulse_template import MappingPulseTemplate, MappingTuple
from qupulse._program.waveforms import SequenceWaveform
from qupulse.pulses.measurement import MeasurementDeclaration, MeasurementDefiner
Expand All @@ -38,7 +38,7 @@ class SequencePulseTemplate(PulseTemplate, ParameterConstrainer, MeasurementDefi
def __init__(self,
*subtemplates: Union[PulseTemplate, MappingTuple],
identifier: Optional[str]=None,
parameter_constraints: Optional[List[Union[str, Expression]]]=None,
parameter_constraints: Optional[Iterable[ConstraintLike]]=None,
measurements: Optional[List[MeasurementDeclaration]]=None,
registry: PulseRegistryType=None) -> None:
"""Create a new SequencePulseTemplate instance.
Expand Down
7 changes: 7 additions & 0 deletions qupulse/pulses/time_reversal_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ def __init__(self, inner: PulseTemplate,
self._inner = inner
self._register(registry=registry)

def with_time_reversal(self) -> 'PulseTemplate':
from qupulse.pulses import TimeReversalPT
if self.identifier:
return TimeReversalPT(self)
else:
return self._inner

@property
def parameter_names(self) -> Set[str]:
return self._inner.parameter_names
Expand Down
48 changes: 48 additions & 0 deletions tests/pulses/pulse_template_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from qupulse.parameter_scope import Scope, DictScope
from qupulse.utils.types import ChannelID
from qupulse.expressions import Expression, ExpressionScalar
from qupulse.pulses import ConstantPT, FunctionPT, RepetitionPT, ForLoopPT, ParallelChannelPT, MappingPT,\
TimeReversalPT, AtomicMultiChannelPT
from qupulse.pulses.pulse_template import AtomicPulseTemplate, PulseTemplate
from qupulse.pulses.parameters import Parameter, ConstantParameter, ParameterNotProvidedException
from qupulse.pulses.multi_channel_pulse_template import MultiChannelWaveform
Expand Down Expand Up @@ -363,6 +365,52 @@ def test_format(self):
"{:identifier;duration}".format(a))


class WithMethodTests(unittest.TestCase):
def setUp(self) -> None:
self.fpt = FunctionPT(1.4, 'sin(f*t)', 'X')
self.cpt = ConstantPT(1.4, {'Y': 'start + idx * step'})
def test_parallel_channels(self):
expected = ParallelChannelPT(self.fpt, {'K': 'k'})
actual = self.fpt.with_parallel_channels({'K': 'k'})
self.assertEqual(expected, actual)

def test_parallel_channels_optimization(self):
expected = ParallelChannelPT(self.fpt, {'K': 'k', 'C': 'c'})
actual = self.fpt.with_parallel_channels({'K': 'k'}).with_parallel_channels({'C': 'c'})
self.assertEqual(expected, actual)

def test_iteration(self):
expected = ForLoopPT(self.cpt, 'idx', 'n_steps')
actual = self.cpt.with_iteration('idx', 'n_steps')
self.assertEqual(expected, actual)

def test_appended(self):
expected = self.fpt @ self.fpt.with_time_reversal()
actual = self.fpt.with_appended(self.fpt.with_time_reversal())
self.assertEqual(expected, actual)

def test_repetition(self):
expected = RepetitionPT(self.fpt, 6)
actual = self.fpt.with_repetition(6)
self.assertEqual(expected, actual)

def test_repetition_optimization(self):
# unstable test due to flimsy expression equality :(
expected = RepetitionPT(self.fpt, ExpressionScalar(6) * 2)
actual = self.fpt.with_repetition(6).with_repetition(2)
self.assertEqual(expected, actual)

def test_time_reversal(self):
expected = TimeReversalPT(self.fpt)
actual = self.fpt.with_time_reversal()
self.assertEqual(expected, actual)

def test_parallel_atomic(self):
expected = AtomicMultiChannelPT(self.fpt, self.cpt)
actual = self.fpt.with_parallel_atomic(self.cpt)
self.assertEqual(expected, actual)


class AtomicPulseTemplateTests(unittest.TestCase):

def test_internal_create_program(self) -> None:
Expand Down