diff --git a/changes.d/710.feature b/changes.d/710.feature new file mode 100644 index 00000000..650482cc --- /dev/null +++ b/changes.d/710.feature @@ -0,0 +1,2 @@ +Add `with_` family of helper methods to `PulseTemplate` to allow convinient and easily discoverable pulse template +combination. diff --git a/qupulse/pulses/multi_channel_pulse_template.py b/qupulse/pulses/multi_channel_pulse_template.py index 757dc84c..73ae4c0b 100644 --- a/qupulse/pulses/multi_channel_pulse_template.py +++ b/qupulse/pulses/multi_channel_pulse_template.py @@ -92,6 +92,20 @@ def __init__(self, self._register(registry=registry) + def with_parallel_atomic(self, *parallel: 'AtomicPulseTemplate') -> 'AtomicPulseTemplate': + from qupulse.pulses import AtomicMultiChannelPT + if parallel: + if self.identifier: + return AtomicMultiChannelPT(self, *parallel) + else: + return AtomicMultiChannelPT( + *self._subtemplates, *parallel, + measurements=self.measurement_declarations, + parameter_constraints=self.parameter_constraints, + ) + else: + return self + @property def duration(self) -> ExpressionScalar: if self._duration is None: @@ -334,6 +348,15 @@ def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[ data['overwritten_channels'] = self._overwritten_channels return data + def with_parallel_channels(self, values: Mapping[ChannelID, ExpressionLike]) -> 'PulseTemplate': + if self.identifier: + return super().with_parallel_channels(values) + else: + return ParallelConstantChannelPulseTemplate( + self._template, + {**self._overwritten_channels, **values}, + ) + def _is_atomic(self) -> bool: return self._template._is_atomic() diff --git a/qupulse/pulses/parameters.py b/qupulse/pulses/parameters.py index df873041..ea5de0e1 100644 --- a/qupulse/pulses/parameters.py +++ b/qupulse/pulses/parameters.py @@ -239,10 +239,13 @@ def get_serialization_data(self) -> str: return str(self) +ConstraintLike = Union[sympy.Expr, str, ParameterConstraint] + + class ParameterConstrainer: """A class that implements the testing of parameter constraints. It is used by the subclassing pulse templates.""" def __init__(self, *, - parameter_constraints: Optional[Iterable[Union[str, ParameterConstraint]]]) -> None: + parameter_constraints: Optional[Iterable[ConstraintLike]]) -> None: if parameter_constraints is None: self._parameter_constraints = [] else: diff --git a/qupulse/pulses/pulse_template.py b/qupulse/pulses/pulse_template.py index 580d701a..9afc8164 100644 --- a/qupulse/pulses/pulse_template.py +++ b/qupulse/pulses/pulse_template.py @@ -251,6 +251,116 @@ def _create_program(self, *, global_transformation=global_transformation, parent_loop=parent_loop) + def with_parallel_channels(self, values: Mapping[ChannelID, ExpressionLike]) -> 'PulseTemplate': + """Create a new pulse template that sets the given channels to the corresponding values. + + See :class:`~qupulse.pulses.ParallelChannelPulseTemplate` for implementation details and restictions. + + Examples: + >>> from qupulse.pulses import FunctionPT + ... fpt = FunctionPT('sin(0.1 * t)', duration_expression=10) + ... fpt_and_marker = fpt.with_parallel_channels({'marker': 1}) + + Args: + values: Values to be set for each channel. + + Returns: + A newly created pulse template. + """ + from qupulse.pulses.multi_channel_pulse_template import ParallelChannelPulseTemplate + return ParallelChannelPulseTemplate( + self, + values + ) + + def with_repetition(self, repetition_count: ExpressionLike) -> 'PulseTemplate': + """Repeat this pulse template `repetition_count` times via a :class:`~qupulse.pulses.RepetitionPulseTemplate`. + + Examples: + >>> from qupulse.pulses import FunctionPT + ... fpt = FunctionPT('sin(0.1 * t)', duration_expression=10) + ... repeated = fpt.with_repetition('n_periods') + + Args: + repetition_count: Amount of times this pulse template is repeated in the return value. + + Returns: + A newly created pulse template. + """ + from qupulse.pulses.repetition_pulse_template import RepetitionPulseTemplate + return RepetitionPulseTemplate(self, repetition_count) + + def with_mapping(self, *mapping_tuple_args: Mapping, **mapping_kwargs: Mapping) -> 'PulseTemplate': + """Map parameters / channel names / measurement names. You may either specify the mappings as positional + arguments XOR as keyword arguments. Positional arguments are forwarded to + :func:`~qupulse.pulses.MappingPT.from_tuple` which automatically determines the "type" of the mappings. + Keyword arguments must be one of the keyword arguments of :class:`~qupulse.pulses.MappingPT`. + + Args: + *mapping_tuple_args: Mappings for parameters / channel names / measurement names + **mapping_kwargs: Mappings for parameters / channel names / measurement names + + Examples: + Equivalent ways to rename a channel and map a parameter value + >>> from qupulse.pulses import FunctionPT + ... fpt = FunctionPT('sin(f * t)', duration_expression=10, channel='A') + ... mapped = fpt.with_mapping({'f': 0.1}, {'A': 'B'}) + ... mapped.defined_channels + {'B'} + + >>> from qupulse.pulses import FunctionPT + ... fpt = FunctionPT('sin(f * t)', duration_expression=10, channel='A') + ... mapped = fpt.with_mapping(parameter_mapping={'f': 0.1}, channel_mapping={'A': 'B'}) + ... mapped.defined_channels + {'B'} + + Returns: + A newly created mapping pulse template + """ + from qupulse.pulses import MappingPT + + if mapping_kwargs and mapping_tuple_args: + raise ValueError("Only positional argument (auto detection of mapping type) " + "xor keyword arguments are allowed.") + if mapping_tuple_args: + return MappingPT.from_tuple((self, *mapping_tuple_args)) + else: + return MappingPT(self, **mapping_kwargs) + + def with_iteration(self, loop_idx: str, loop_range) -> 'PulseTemplate': + """Create a :class:`~qupulse.pulses.ForLoopPT` with the given index and range. + + Examples: + >>> from qupulse.pulses import ConstantPT + ... const = ConstantPT('t_hold', {'x': 'start_x + i_x * step_x', 'y': 'start_y + i_y * step_y'}) + ... scan_2d = const.with_iteration('i_x', 'n_x').with_iteration('i_y', 'n_y') + """ + from qupulse.pulses import ForLoopPT + return ForLoopPT(self, loop_idx, loop_range) + + def with_time_reversal(self) -> 'PulseTemplate': + """Reverse this pulse template by creating a :class:`~qupulse.pulses.TimeReversalPT`. + + Examples: + >>> from qupulse.pulses import FunctionPT + ... forward = FunctionPT('sin(f * t)', duration_expression=10, channel='A') + ... backward = fpt.with_time_reversal() + ... forward_and_backward = forward @ backward + """ + from qupulse.pulses import TimeReversalPT + return TimeReversalPT(self) + + def with_appended(self, *appended: 'PulseTemplate'): + """Create a :class:`~qupulse.pulses.SequencePT` that represents a sequence of this pulse template and `appended` + + You can also use the `@` operator to do this or call :func:`qupulse.pulses.SequencePT.concatenate` directly. + """ + from qupulse.pulses import SequencePT + if appended: + return SequencePT.concatenate(self, *appended) + else: + return self + def __format__(self, format_spec: str): if format_spec == '': format_spec = self._DEFAULT_FORMAT_SPEC @@ -322,6 +432,13 @@ def __init__(self, *, PulseTemplate.__init__(self, identifier=identifier) MeasurementDefiner.__init__(self, measurements=measurements) + def with_parallel_atomic(self, *parallel: 'AtomicPulseTemplate') -> 'AtomicPulseTemplate': + from qupulse.pulses import AtomicMultiChannelPT + if parallel: + return AtomicMultiChannelPT(self, *parallel) + else: + return self + @property def atomicity(self) -> bool: warnings.warn("Deprecated since neither maintained nor properly designed.", category=DeprecationWarning) diff --git a/qupulse/pulses/repetition_pulse_template.py b/qupulse/pulses/repetition_pulse_template.py index f367b6f6..8df278a5 100644 --- a/qupulse/pulses/repetition_pulse_template.py +++ b/qupulse/pulses/repetition_pulse_template.py @@ -70,6 +70,17 @@ def __init__(self, self._register(registry=registry) + def with_repetition(self, repetition_count: Union[int, str, ExpressionScalar]) -> 'PulseTemplate': + if self.identifier: + return RepetitionPulseTemplate(self, repetition_count) + else: + return RepetitionPulseTemplate( + self.body, + self.repetition_count * repetition_count, + parameter_constraints=self.parameter_constraints, + measurements=self.measurement_declarations + ) + @property def repetition_count(self) -> ExpressionScalar: """The amount of repetitions. Either a constant integer or a ParameterDeclaration object.""" diff --git a/qupulse/pulses/sequence_pulse_template.py b/qupulse/pulses/sequence_pulse_template.py index 9b1f64d1..1ef1bed0 100644 --- a/qupulse/pulses/sequence_pulse_template.py +++ b/qupulse/pulses/sequence_pulse_template.py @@ -2,7 +2,7 @@ combines several other PulseTemplate objects for sequential execution.""" import numpy as np -from typing import Dict, List, Set, Optional, Any, AbstractSet, Union, Callable, cast +from typing import Dict, List, Set, Optional, Any, AbstractSet, Union, Callable, cast, Iterable from numbers import Real import functools import warnings @@ -13,7 +13,7 @@ from qupulse.utils import cached_property from qupulse.utils.types import MeasurementWindow, ChannelID, TimeType from qupulse.pulses.pulse_template import PulseTemplate, AtomicPulseTemplate -from qupulse.pulses.parameters import Parameter, ParameterConstrainer, ParameterNotProvidedException +from qupulse.pulses.parameters import ConstraintLike, ParameterConstrainer from qupulse.pulses.mapping_pulse_template import MappingPulseTemplate, MappingTuple from qupulse._program.waveforms import SequenceWaveform from qupulse.pulses.measurement import MeasurementDeclaration, MeasurementDefiner @@ -38,7 +38,7 @@ class SequencePulseTemplate(PulseTemplate, ParameterConstrainer, MeasurementDefi def __init__(self, *subtemplates: Union[PulseTemplate, MappingTuple], identifier: Optional[str]=None, - parameter_constraints: Optional[List[Union[str, Expression]]]=None, + parameter_constraints: Optional[Iterable[ConstraintLike]]=None, measurements: Optional[List[MeasurementDeclaration]]=None, registry: PulseRegistryType=None) -> None: """Create a new SequencePulseTemplate instance. diff --git a/qupulse/pulses/time_reversal_pulse_template.py b/qupulse/pulses/time_reversal_pulse_template.py index a5415f01..a4758d1a 100644 --- a/qupulse/pulses/time_reversal_pulse_template.py +++ b/qupulse/pulses/time_reversal_pulse_template.py @@ -19,6 +19,13 @@ def __init__(self, inner: PulseTemplate, self._inner = inner self._register(registry=registry) + def with_time_reversal(self) -> 'PulseTemplate': + from qupulse.pulses import TimeReversalPT + if self.identifier: + return TimeReversalPT(self) + else: + return self._inner + @property def parameter_names(self) -> Set[str]: return self._inner.parameter_names diff --git a/tests/pulses/pulse_template_tests.py b/tests/pulses/pulse_template_tests.py index 0e8210bf..0d840865 100644 --- a/tests/pulses/pulse_template_tests.py +++ b/tests/pulses/pulse_template_tests.py @@ -8,6 +8,8 @@ from qupulse.parameter_scope import Scope, DictScope from qupulse.utils.types import ChannelID from qupulse.expressions import Expression, ExpressionScalar +from qupulse.pulses import ConstantPT, FunctionPT, RepetitionPT, ForLoopPT, ParallelChannelPT, MappingPT,\ + TimeReversalPT, AtomicMultiChannelPT from qupulse.pulses.pulse_template import AtomicPulseTemplate, PulseTemplate from qupulse.pulses.parameters import Parameter, ConstantParameter, ParameterNotProvidedException from qupulse.pulses.multi_channel_pulse_template import MultiChannelWaveform @@ -363,6 +365,52 @@ def test_format(self): "{:identifier;duration}".format(a)) +class WithMethodTests(unittest.TestCase): + def setUp(self) -> None: + self.fpt = FunctionPT(1.4, 'sin(f*t)', 'X') + self.cpt = ConstantPT(1.4, {'Y': 'start + idx * step'}) + def test_parallel_channels(self): + expected = ParallelChannelPT(self.fpt, {'K': 'k'}) + actual = self.fpt.with_parallel_channels({'K': 'k'}) + self.assertEqual(expected, actual) + + def test_parallel_channels_optimization(self): + expected = ParallelChannelPT(self.fpt, {'K': 'k', 'C': 'c'}) + actual = self.fpt.with_parallel_channels({'K': 'k'}).with_parallel_channels({'C': 'c'}) + self.assertEqual(expected, actual) + + def test_iteration(self): + expected = ForLoopPT(self.cpt, 'idx', 'n_steps') + actual = self.cpt.with_iteration('idx', 'n_steps') + self.assertEqual(expected, actual) + + def test_appended(self): + expected = self.fpt @ self.fpt.with_time_reversal() + actual = self.fpt.with_appended(self.fpt.with_time_reversal()) + self.assertEqual(expected, actual) + + def test_repetition(self): + expected = RepetitionPT(self.fpt, 6) + actual = self.fpt.with_repetition(6) + self.assertEqual(expected, actual) + + def test_repetition_optimization(self): + # unstable test due to flimsy expression equality :( + expected = RepetitionPT(self.fpt, ExpressionScalar(6) * 2) + actual = self.fpt.with_repetition(6).with_repetition(2) + self.assertEqual(expected, actual) + + def test_time_reversal(self): + expected = TimeReversalPT(self.fpt) + actual = self.fpt.with_time_reversal() + self.assertEqual(expected, actual) + + def test_parallel_atomic(self): + expected = AtomicMultiChannelPT(self.fpt, self.cpt) + actual = self.fpt.with_parallel_atomic(self.cpt) + self.assertEqual(expected, actual) + + class AtomicPulseTemplateTests(unittest.TestCase): def test_internal_create_program(self) -> None: