diff --git a/changes.d/709.feature b/changes.d/709.feature new file mode 100644 index 00000000..9bdf45b1 --- /dev/null +++ b/changes.d/709.feature @@ -0,0 +1,3 @@ +Add support for time dependent expressions for arithmetics with atomic pulse templates i.e. ParallelChannelPT and +ArithmeticPT support time dependent expressions if used with atomic pulse templates. +Rename `ParallelConstantChannelPT` to `ParallelChannelPT` to reflect this change. diff --git a/qupulse/_program/transformation.py b/qupulse/_program/transformation.py index 66ccfc04..658263dd 100644 --- a/qupulse/_program/transformation.py +++ b/qupulse/_program/transformation.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping, Set, Tuple, Sequence, AbstractSet, Union, TYPE_CHECKING +from typing import Any, Mapping, Set, Tuple, Sequence, AbstractSet, Union, TYPE_CHECKING, Hashable from abc import abstractmethod from numbers import Real @@ -6,7 +6,11 @@ from qupulse import ChannelID from qupulse.comparable import Comparable -from qupulse.utils.types import SingletonABCMeta +from qupulse.utils.types import SingletonABCMeta, frozendict +from qupulse.expressions import ExpressionScalar + + +_TrafoValue = Union[Real, ExpressionScalar] class Transformation(Comparable): @@ -44,6 +48,9 @@ def is_constant_invariant(self): """Signals if the transformation always maps constants to constants.""" return False + def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: + return frozenset() + class IdentityTransformation(Transformation, metaclass=SingletonABCMeta): def __call__(self, time: Union[np.ndarray, float], @@ -70,6 +77,9 @@ def is_constant_invariant(self): """Signals if the transformation always maps constants to constants.""" return True + def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: + return input_channels + class ChainedTransformation(Transformation): def __init__(self, *transformations: Transformation): @@ -103,12 +113,17 @@ def chain(self, next_transformation) -> Transformation: return chain_transformations(*self.transformations, next_transformation) def __repr__(self): - return 'ChainedTransformation%r' % (self._transformations,) + return f'{type(self).__name__}{self._transformations!r}' def is_constant_invariant(self): """Signals if the transformation always maps constants to constants.""" return all(trafo.is_constant_invariant() for trafo in self._transformations) + def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: + for trafo in self._transformations: + input_channels = trafo.get_constant_output_channels(input_channels) + return input_channels + class LinearTransformation(Transformation): def __init__(self, @@ -192,9 +207,12 @@ def is_constant_invariant(self): """Signals if the transformation always maps constants to constants.""" return True + def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: + return input_channels + class OffsetTransformation(Transformation): - def __init__(self, offsets: Mapping[ChannelID, Real]): + def __init__(self, offsets: Mapping[ChannelID, _TrafoValue]): """Adds an offset to each channel specified in offsets. Channels not in offsets are forewarded @@ -202,11 +220,13 @@ def __init__(self, offsets: Mapping[ChannelID, Real]): Args: offsets: Channel -> offset mapping """ - self._offsets = dict(offsets.items()) + self._offsets = frozendict(offsets) + assert _are_valid_transformation_expressions(self._offsets) def __call__(self, time: Union[np.ndarray, float], data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: - return {channel: channel_values + self._offsets[channel] if channel in self._offsets else channel_values + offsets = _instantiate_expression_dict(time, self._offsets) + return {channel: channel_values + offsets[channel] if channel in offsets else channel_values for channel, channel_values in data.items()} def get_input_channels(self, output_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: @@ -216,24 +236,29 @@ def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> Abstrac return input_channels @property - def compare_key(self) -> frozenset: - return frozenset(self._offsets.items()) + def compare_key(self) -> Hashable: + return self._offsets def __repr__(self): - return 'OffsetTransformation(%r)' % self._offsets + return f'{type(self).__name__}({dict(self._offsets)!r})' def is_constant_invariant(self): """Signals if the transformation always maps constants to constants.""" - return True + return not _has_time_dependent_values(self._offsets) + + def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: + return _get_constant_output_channels(self._offsets, input_channels) class ScalingTransformation(Transformation): - def __init__(self, factors: Mapping[ChannelID, Real]): - self._factors = dict(factors.items()) + def __init__(self, factors: Mapping[ChannelID, _TrafoValue]): + self._factors = frozendict(factors) + assert _are_valid_transformation_expressions(self._factors) def __call__(self, time: Union[np.ndarray, float], data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: - return {channel: channel_values * self._factors[channel] if channel in self._factors else channel_values + factors = _instantiate_expression_dict(time, self._factors) + return {channel: channel_values * factors[channel] if channel in factors else channel_values for channel, channel_values in data.items()} def get_input_channels(self, output_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: @@ -243,15 +268,18 @@ def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> Abstrac return input_channels @property - def compare_key(self) -> frozenset: - return frozenset(self._factors.items()) + def compare_key(self) -> Hashable: + return self._factors def __repr__(self): - return 'ScalingTransformation(%r)' % self._factors + return f'{type(self).__name__}({dict(self._factors)!r})' def is_constant_invariant(self): """Signals if the transformation always maps constants to constants.""" - return True + return not _has_time_dependent_values(self._factors) + + def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: + return _get_constant_output_channels(self._factors, input_channels) try: @@ -277,25 +305,30 @@ def linear_transformation_from_pandas(transformation: PandasDataFrameType) -> Li pass -class ParallelConstantChannelTransformation(Transformation): - def __init__(self, channels: Mapping[ChannelID, Real]): - """Set channel values to given values regardless their former existence +class ParallelChannelTransformation(Transformation): + def __init__(self, channels: Mapping[ChannelID, _TrafoValue]): + """Set channel values to given values regardless their former existence. The values can be time dependent + expressions. Args: channels: Channels present in this map are set to the given value. """ - self._channels = {channel: float(value) - for channel, value in channels.items()} + self._channels: Mapping[ChannelID, _TrafoValue] = frozendict(channels.items()) + assert _are_valid_transformation_expressions(self._channels) def __call__(self, time: Union[np.ndarray, float], data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: - overwritten = {channel: np.full_like(time, fill_value=value, dtype=float) - for channel, value in self._channels.items()} + overwritten = self._instantiated_values(time) return {**data, **overwritten} + def _instantiated_values(self, time): + scope = {'t': time} + return {channel: value.evaluate_in_scope(scope) if hasattr(value, 'evaluate_in_scope') else np.full_like(time, fill_value=value, dtype=float) + for channel, value in self._channels.items()} + @property - def compare_key(self) -> Tuple[Tuple[ChannelID, float], ...]: - return tuple(sorted(self._channels.items())) + def compare_key(self) -> Hashable: + return self._channels def get_input_channels(self, output_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return output_channels - self._channels.keys() @@ -304,11 +337,21 @@ def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> Abstrac return input_channels | self._channels.keys() def __repr__(self): - return 'ParallelConstantChannelTransformation(%r)' % self._channels + return f'{type(self).__name__}({dict(self._channels)!r})' def is_constant_invariant(self): """Signals if the transformation always maps constants to constants.""" - return True + return not _has_time_dependent_values(self._channels) + + def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: + output_channels = set(input_channels) + for ch, value in self._channels.items(): + if hasattr(value, 'variables'): + output_channels.discard(ch) + else: + output_channels.add(ch) + + return output_channels def chain_transformations(*transformations: Transformation) -> Transformation: @@ -325,4 +368,33 @@ def chain_transformations(*transformations: Transformation) -> Transformation: elif len(parsed_transformations) == 1: return parsed_transformations[0] else: - return ChainedTransformation(*parsed_transformations) \ No newline at end of file + return ChainedTransformation(*parsed_transformations) + + +def _instantiate_expression_dict(time, expressions: Mapping[str, _TrafoValue]) -> Mapping[str, Union[Real, np.ndarray]]: + scope = {'t': time} + modified_expressions = {} + for name, value in expressions.items(): + if hasattr(value, 'evaluate_in_scope'): + modified_expressions[name] = value.evaluate_in_scope(scope) + if modified_expressions: + return {**expressions, **modified_expressions} + else: + return expressions + + +def _has_time_dependent_values(expressions: Mapping[ChannelID, _TrafoValue]) -> bool: + return any(hasattr(value, 'variables') + for value in expressions.values()) + + +def _get_constant_output_channels(expressions: Mapping[ChannelID, _TrafoValue], + constant_input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: + return {ch + for ch in constant_input_channels + if not hasattr(expressions.get(ch, None), 'variables')} + +def _are_valid_transformation_expressions(expressions: Mapping[ChannelID, _TrafoValue]) -> bool: + return all(expr.variables == ('t',) + for expr in expressions.values() + if hasattr(expr, 'variables')) diff --git a/qupulse/pulses/__init__.py b/qupulse/pulses/__init__.py index 8c8af596..4a8e1016 100644 --- a/qupulse/pulses/__init__.py +++ b/qupulse/pulses/__init__.py @@ -7,7 +7,8 @@ from qupulse.pulses.function_pulse_template import FunctionPulseTemplate as FunctionPT from qupulse.pulses.loop_pulse_template import ForLoopPulseTemplate as ForLoopPT from qupulse.pulses.multi_channel_pulse_template import AtomicMultiChannelPulseTemplate as AtomicMultiChannelPT,\ - ParallelConstantChannelPulseTemplate as ParallelConstantChannelPT + ParallelConstantChannelPulseTemplate as ParallelConstantChannelPT,\ + ParallelChannelPulseTemplate as ParallelChannelPT from qupulse.pulses.mapping_pulse_template import MappingPulseTemplate as MappingPT from qupulse.pulses.repetition_pulse_template import RepetitionPulseTemplate as RepetitionPT from qupulse.pulses.sequence_pulse_template import SequencePulseTemplate as SequencePT @@ -30,5 +31,4 @@ __all__ = ["FunctionPT", "ForLoopPT", "AtomicMultiChannelPT", "MappingPT", "RepetitionPT", "SequencePT", "TablePT", "PointPT", "ConstantPT", "AbstractPT", "ParallelConstantChannelPT", "ArithmeticPT", "ArithmeticAtomicPT", - "TimeReversalPT"] - + "TimeReversalPT", "ParallelChannelPT"] diff --git a/qupulse/pulses/arithmetic_pulse_template.py b/qupulse/pulses/arithmetic_pulse_template.py index 2e997472..a10777b8 100644 --- a/qupulse/pulses/arithmetic_pulse_template.py +++ b/qupulse/pulses/arithmetic_pulse_template.py @@ -198,24 +198,31 @@ def __init__(self, rhs: Union[PulseTemplate, ExpressionLike, Mapping[ChannelID, ExpressionLike]], *, identifier: Optional[str] = None): - """Allowed operations + """Implements the arithmetics between an aribrary pulse template and scalar values. The values can be the same + for all channels, channel specific or only for a subset of the inner pulse templates defined channels. + The expression may be time dependent if the pulse template is atomic. - scalar + pulse_template - scalar - pulse_template - scalar * pulse_template - pulse_template + scalar - pulse_template - scalar - pulse_template * scalar - pulse_template / scalar + A channel dependent scalar is represented by a mapping of ChannelID -> Expression. + + The allowed operations are: + scalar + pulse_template + scalar - pulse_template + scalar * pulse_template + pulse_template + scalar + pulse_template - scalar + pulse_template * scalar + pulse_template / scalar Args: lhs: Left hand side operand arithmetic_operator: String representation of the operator rhs: Right hand side operand - identifier: + identifier: Identifier used for serialization Raises: - TypeError if both or none of the operands are pulse templates + TypeError: If both or none of the operands are pulse templates or if there is a time dependent expression + and a composite pulse template. + ValueError: If the scalar is a mapping and contains channels that are not defined on the pulse template. """ PulseTemplate.__init__(self, identifier=identifier) @@ -243,11 +250,15 @@ def __init__(self, self._lhs = lhs self._rhs = rhs - self._pulse_template = pulse_template + self._pulse_template: PulseTemplate = pulse_template self._scalar = scalar self._arithmetic_operator = arithmetic_operator + if not self._pulse_template._is_atomic() and _is_time_dependent(self._scalar): + raise TypeError("A time dependent ArithmeticPulseTemplate scalar operand currently requires an atomic " + "pulse template as the other operand.", self) + @staticmethod def _parse_operand(operand: Union[ExpressionLike, Mapping[ChannelID, ExpressionLike]], channels: Set[ChannelID]) -> Union[ExpressionScalar, Mapping[ChannelID, ExpressionScalar]]: @@ -298,7 +309,7 @@ def _get_scalar_value(self, if channel_mapping[channel]} else: - return {channel_mapping[channel]: value.evaluate_in_scope(parameters) + return {channel_mapping[channel]: value.evaluate_symbolic(parameters) if 't' in value.variables else value.evaluate_in_scope(parameters) for channel, value in self._scalar.items() if channel_mapping[channel]} @@ -479,9 +490,11 @@ def measurement_names(self) -> Set[str]: @cached_property def _scalar_operand_parameters(self) -> FrozenSet[str]: if isinstance(self._scalar, dict): - return frozenset(*(value.variables for value in self._scalar.values())) + return frozenset(variable + for value in self._scalar.values() + for variable in value.variables) - {'t'} else: - return frozenset(self._scalar.variables) + return frozenset(self._scalar.variables) - {'t'} @property def parameter_names(self) -> Set[str]: @@ -499,6 +512,9 @@ def get_measurement_windows(self, measurement_mapping=measurement_mapping)) return measurements + def _is_atomic(self): + return self._pulse_template._is_atomic() + def try_operation(lhs: Union[PulseTemplate, ExpressionLike, Mapping[ChannelID, ExpressionLike]], op: str, @@ -531,6 +547,13 @@ def try_operation(lhs: Union[PulseTemplate, ExpressionLike, Mapping[ChannelID, E return NotImplemented +def _is_time_dependent(scalar: Union[ExpressionScalar, Dict[str, ExpressionScalar]]) -> bool: + if isinstance(scalar, dict): + return any('t' in value.variables for value in scalar.values()) + else: + return 't' in scalar.variables + + class UnequalDurationWarningInArithmeticPT(RuntimeWarning): """Signals that an ArithmeticAtomicPulseTemplate was constructed from operands with unequal duration. This is a separate class to allow easy silencing.""" diff --git a/qupulse/pulses/mapping_pulse_template.py b/qupulse/pulses/mapping_pulse_template.py index f235b28e..c5998da3 100644 --- a/qupulse/pulses/mapping_pulse_template.py +++ b/qupulse/pulses/mapping_pulse_template.py @@ -112,7 +112,7 @@ def __init__(self, template: PulseTemplate, *, for k, v in template.channel_mapping.items()} template = template.template - self.__template = template + self.__template: PulseTemplate = template self.__parameter_mapping = FrozenDict(parameter_mapping) self.__external_parameters = set(itertools.chain(*(expr.variables for expr in self.__parameter_mapping.values()))) self.__external_parameters |= self.constrained_parameters @@ -230,6 +230,9 @@ def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[ return data + def _is_atomic(self): + return self.__template._is_atomic() + @classmethod def deserialize(cls, serializer: Optional[Serializer]=None, # compatibility to old serialization routines, deprecated diff --git a/qupulse/pulses/multi_channel_pulse_template.py b/qupulse/pulses/multi_channel_pulse_template.py index 1249f476..757dc84c 100644 --- a/qupulse/pulses/multi_channel_pulse_template.py +++ b/qupulse/pulses/multi_channel_pulse_template.py @@ -4,7 +4,7 @@ Classes: - MultiChannelPulseTemplate: A pulse template defined for several channels by combining pulse templates - - MultiChannelWaveform: A waveform defined for several channels by combining waveforms + - ParallelChannelPulseTemplate: A pulse template to add channels to an existing pulse template. """ from typing import Dict, List, Optional, Any, AbstractSet, Union, Set, Sequence, Mapping @@ -18,7 +18,7 @@ from qupulse.utils.sympy import almost_equal, Sympifyable from qupulse.utils.types import ChannelID, TimeType from qupulse._program.waveforms import MultiChannelWaveform, Waveform, TransformingWaveform -from qupulse._program.transformation import ParallelConstantChannelTransformation, Transformation, chain_transformations +from qupulse._program.transformation import ParallelChannelTransformation, Transformation, chain_transformations from qupulse.pulses.pulse_template import PulseTemplate, AtomicPulseTemplate from qupulse.pulses.mapping_pulse_template import MappingPulseTemplate, MappingTuple from qupulse.pulses.parameters import Parameter, ParameterConstrainer @@ -29,7 +29,6 @@ class AtomicMultiChannelPulseTemplate(AtomicPulseTemplate, ParameterConstrainer): - """Combines multiple PulseTemplates that are defined on different channels into an AtomicPulseTemplate.""" def __init__(self, *subtemplates: Union[AtomicPulseTemplate, MappingTuple, MappingPulseTemplate], identifier: Optional[str] = None, @@ -37,9 +36,11 @@ def __init__(self, measurements: Optional[List[MeasurementDeclaration]] = None, registry: PulseRegistryType = None, duration: Optional[ExpressionLike] = None) -> None: - """Parallels multiple AtomicPulseTemplates of the same duration. If the duration keyword argument is given - it is enforced that the instantiated pulse template has this duration. If duration is None the duration of the - PT is the duration of the first subtemplate. There are probably changes to this behaviour in the future. + """Combines multiple AtomicPulseTemplates of the same duration that are defined on different channels into an + AtomicPulseTemplate. + If the duration keyword argument is given it is enforced that the instantiated pulse template has this duration. + If duration is None the duration of the PT is the duration of the first subtemplate. + There are probably changes to this behaviour in the future. Args: *subtemplates: Positional arguments are subtemplates to combine. @@ -209,18 +210,34 @@ def final_values(self) -> Dict[ChannelID, ExpressionScalar]: return values -class ParallelConstantChannelPulseTemplate(PulseTemplate): +class ParallelChannelPulseTemplate(PulseTemplate): def __init__(self, template: PulseTemplate, overwritten_channels: Mapping[ChannelID, Union[ExpressionScalar, Sympifyable]], *, identifier: Optional[str]=None, - registry: Optional[PulseRegistryType]=None): + registry: Optional[PulseRegistryType] = None): + """Pulse template to add new or overwrite existing channels of a contained pulse template. The channel values + may be time dependent if the contained pulse template is atomic. + + Args: + template: Inner pulse template where all channels that are not overwritten will stay the same. + overwritten_channels: Mapping of channels to values that this channel will have. This can overwrite existing + channels or add new ones. May be time dependent if template is atomic. + identifier: Name of the pulse template for serialization + registry: Pulse template gets registered here if not None. + """ super().__init__(identifier=identifier) self._template = template self._overwritten_channels = {channel: ExpressionScalar(value) for channel, value in overwritten_channels.items()} + if not template._is_atomic(): + for expr in self._overwritten_channels.values(): + if 't' in expr.variables: + raise TypeError(f"{type(self).__name__} currently only supports time dependent expressions if the " + f"pulse template is atomic.", self) + self._register(registry=registry) @property @@ -234,8 +251,10 @@ def overwritten_channels(self) -> Mapping[str, ExpressionScalar]: def _get_overwritten_channels_values(self, parameters: Mapping[str, Union[numbers.Real]], channel_mapping: Dict[ChannelID, Optional[ChannelID]] - ) -> Dict[str, numbers.Real]: - return {channel_mapping[name]: value.evaluate_in_scope(parameters) + ) -> Dict[str, Union[numbers.Real, ExpressionScalar]]: + """Return a dictionary of ChannelID to channel value mappings. The channel values can bei either numbers or time + dependent expressions.""" + return {channel_mapping[name]: value.evaluate_symbolic(parameters) if 't' in value.variables else value.evaluate_in_scope(parameters) for name, value in self.overwritten_channels.items() if channel_mapping[name] is not None} @@ -245,7 +264,7 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], **kwargs): overwritten_channels = self._get_overwritten_channels_values(parameters=scope, channel_mapping=channel_mapping) - transformation = ParallelConstantChannelTransformation(overwritten_channels) + transformation = ParallelChannelTransformation(overwritten_channels) if global_transformation is not None: transformation = chain_transformations(global_transformation, transformation) @@ -262,7 +281,7 @@ def build_waveform(self, parameters: Dict[str, numbers.Real], if inner_waveform: overwritten_channels = self._get_overwritten_channels_values(parameters=parameters, channel_mapping=channel_mapping) - transformation = ParallelConstantChannelTransformation(overwritten_channels) + transformation = ParallelChannelTransformation(overwritten_channels) return TransformingWaveform.from_transformation(inner_waveform, transformation) @property @@ -275,7 +294,7 @@ def measurement_names(self) -> AbstractSet[str]: @property def transformation_parameters(self) -> AbstractSet[str]: - return set().union(*(value.variables for value in self.overwritten_channels.values())) + return set().union(*(value.variables for value in self.overwritten_channels.values())) - {'t'} @property def parameter_names(self): @@ -315,6 +334,12 @@ def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[ data['overwritten_channels'] = self._overwritten_channels return data + def _is_atomic(self) -> bool: + return self._template._is_atomic() + + +ParallelConstantChannelPulseTemplate = ParallelChannelPulseTemplate + class ChannelMappingException(Exception): def __init__(self, obj1, obj2, intersect_set): diff --git a/qupulse/pulses/pulse_template.py b/qupulse/pulses/pulse_template.py index 1ed0786e..580d701a 100644 --- a/qupulse/pulses/pulse_template.py +++ b/qupulse/pulses/pulse_template.py @@ -6,6 +6,7 @@ - AtomicPulseTemplate: PulseTemplate that does imply any control flow disruptions and can be directly translated into a waveform. """ +import warnings from abc import abstractmethod from typing import Dict, Tuple, Set, Optional, Union, List, Callable, Any, Generic, TypeVar, Mapping import itertools @@ -79,6 +80,10 @@ def num_channels(self) -> int: """The number of channels this PulseTemplate defines""" return len(self.defined_channels) + def _is_atomic(self) -> bool: + """This is (currently a private) a check if this pulse template always is translated into a single waveform.""" + return False + def __matmul__(self, other: Union['PulseTemplate', MappingTuple]) -> 'SequencePulseTemplate': """This method enables using the @-operator (intended for matrix multiplication) for concatenating pulses. If one of the pulses is a SequencePulseTemplate the other pulse gets merged into it""" @@ -319,6 +324,10 @@ def __init__(self, *, @property def atomicity(self) -> bool: + warnings.warn("Deprecated since neither maintained nor properly designed.", category=DeprecationWarning) + return True + + def _is_atomic(self) -> bool: return True measurement_names = MeasurementDefiner.measurement_names diff --git a/qupulse/pulses/time_reversal_pulse_template.py b/qupulse/pulses/time_reversal_pulse_template.py index 83997477..a5415f01 100644 --- a/qupulse/pulses/time_reversal_pulse_template.py +++ b/qupulse/pulses/time_reversal_pulse_template.py @@ -58,3 +58,6 @@ def get_serialization_data(self, serializer=None): **super().get_serialization_data(), 'inner': self._inner } + + def _is_atomic(self) -> bool: + return self._inner._is_atomic() diff --git a/qupulse/utils/types.py b/qupulse/utils/types.py index 7b467b0d..7ec9d8c7 100644 --- a/qupulse/utils/types.py +++ b/qupulse/utils/types.py @@ -21,7 +21,7 @@ import qupulse.utils.numeric as qupulse_numeric __all__ = ["MeasurementWindow", "ChannelID", "HashableNumpyArray", "TimeType", "time_from_float", "DocStringABCMeta", - "SingletonABCMeta", "SequenceProxy"] + "SingletonABCMeta", "SequenceProxy", "frozendict"] MeasurementWindow = typing.Tuple[str, numbers.Real, numbers.Real] ChannelID = typing.Union[str, int] diff --git a/tests/_program/transformation_tests.py b/tests/_program/transformation_tests.py index e75e17dc..3f366482 100644 --- a/tests/_program/transformation_tests.py +++ b/tests/_program/transformation_tests.py @@ -3,8 +3,10 @@ import numpy as np + +from qupulse.expressions import ExpressionScalar from qupulse._program.transformation import LinearTransformation, Transformation, IdentityTransformation,\ - ChainedTransformation, ParallelConstantChannelTransformation, chain_transformations, OffsetTransformation,\ + ChainedTransformation, ParallelChannelTransformation, chain_transformations, OffsetTransformation,\ ScalingTransformation @@ -18,6 +20,9 @@ def get_output_channels(self, input_channels): def get_input_channels(self, output_channels): raise NotImplementedError() + def get_constant_output_channels(self, input_channels): + raise NotImplementedError() + @property def compare_key(self): return id(self) @@ -201,6 +206,8 @@ def test_scalar_trafo_works(self): def test_constant_propagation(self): self.assertTrue(IdentityTransformation().is_constant_invariant()) + chans = {'a', 'b'} + self.assertIs(chans, IdentityTransformation().get_constant_output_channels(chans)) class ChainedTransformationTests(unittest.TestCase): @@ -282,68 +289,89 @@ def test_repr(self): def test_constant_propagation(self): trafo = ChainedTransformation(ScalingTransformation({'a': 1.1}), OffsetTransformation({'b': 6.6})) self.assertTrue(trafo.is_constant_invariant()) + self.assertEqual({'a', 'b', 'c'}, trafo.get_constant_output_channels({'a', 'b', 'c'})) + trafo = ChainedTransformation(ScalingTransformation({'a': 1.1}), TransformationStub()) self.assertFalse(trafo.is_constant_invariant()) -class ParallelConstantChannelTransformationTests(unittest.TestCase): +class ParallelChannelTransformationTests(unittest.TestCase): def test_init(self): - channels = {'X': 2, 'Y': 4.4} + channels = {'X': 2, 'Y': 4.4, 'Z': ExpressionScalar('t')} - trafo = ParallelConstantChannelTransformation(channels) + trafo = ParallelChannelTransformation(channels) self.assertEqual(trafo._channels, channels) - self.assertTrue(all(isinstance(v, float) for v in trafo._channels.values())) - - self.assertEqual(trafo.compare_key, (('X', 2.), ('Y', 4.4))) self.assertEqual(trafo.get_input_channels(set()), set()) self.assertEqual(trafo.get_input_channels({'X'}), set()) - self.assertEqual(trafo.get_input_channels({'Z'}), {'Z'}) - self.assertEqual(trafo.get_input_channels({'X', 'Z'}), {'Z'}) + self.assertEqual(trafo.get_input_channels({'K'}), {'K'}) + self.assertEqual(trafo.get_input_channels({'X', 'Z', 'K'}), {'K'}) + + self.assertEqual(trafo.get_output_channels(set()), {'X', 'Y', 'Z'}) + self.assertEqual(trafo.get_output_channels({'X'}), {'X', 'Y', 'Z'}) + self.assertEqual(trafo.get_output_channels({'X', 'Z', 'K'}), {'X', 'Y', 'Z', 'K'}) - self.assertEqual(trafo.get_output_channels(set()), {'X', 'Y'}) - self.assertEqual(trafo.get_output_channels({'X'}), {'X', 'Y'}) - self.assertEqual(trafo.get_output_channels({'X', 'Z'}), {'X', 'Y', 'Z'}) + self.assertEqual(trafo.get_constant_output_channels({'X', 'Y', 'Z', 'K'}), {'X', 'Y', 'K'}) def test_trafo(self): - channels = {'X': 2, 'Y': 4.4} - trafo = ParallelConstantChannelTransformation(channels) + channels = {'X': 2, 'Y': 4.4, 'Z': ExpressionScalar('t')} + trafo = ParallelChannelTransformation(channels) n_points = 17 time = np.arange(17, dtype=float) expected_overwrites = {'X': np.full((n_points,), 2.), - 'Y': np.full((n_points,), 4.4)} + 'Y': np.full((n_points,), 4.4), + 'Z': time} empty_input_result = trafo(time, {}) np.testing.assert_equal(empty_input_result, expected_overwrites) - z_input_result = trafo(time, {'Z': np.sin(time)}) - np.testing.assert_equal(z_input_result, {'Z': np.sin(time), **expected_overwrites}) + k_input = {'K': np.sin(time)} + k_input_result = trafo(time, k_input) + np.testing.assert_equal(k_input_result, {**k_input, **expected_overwrites}) x_input_result = trafo(time, {'X': np.cos(time)}) - np.testing.assert_equal(empty_input_result, expected_overwrites) + np.testing.assert_equal(x_input_result, expected_overwrites) - x_z_input_result = trafo(time, {'X': np.cos(time), 'Z': np.sin(time)}) - np.testing.assert_equal(z_input_result, {'Z': np.sin(time), **expected_overwrites}) + x_k_input_result = trafo(time, {'X': np.cos(time), 'K': np.sin(time)}) + np.testing.assert_equal(x_k_input_result, {'K': np.sin(time), **expected_overwrites}) def test_repr(self): channels = {'X': 2, 'Y': 4.4} - trafo = ParallelConstantChannelTransformation(channels) + trafo = ParallelChannelTransformation(channels) self.assertEqual(trafo, eval(repr(trafo))) def test_scalar_trafo_works(self): channels = {'X': 2, 'Y': 4.4} - trafo = ParallelConstantChannelTransformation(channels) + trafo = ParallelChannelTransformation(channels) assert_scalar_trafo_works(self, trafo, {'a': 0., 'b': 0.3, 'c': 0.6}) def test_constant_propagation(self): channels = {'X': 2, 'Y': 4.4} - trafo = ParallelConstantChannelTransformation(channels) + trafo = ParallelChannelTransformation(channels) self.assertTrue(trafo.is_constant_invariant()) + def test_time_dependence(self): + channels = {'X': 2, 'Y': ExpressionScalar('sin(t)')} + trafo = ParallelChannelTransformation(channels) + self.assertEqual({'X', 'K'}, trafo.get_constant_output_channels({'X', 'Y', 'K'})) + + t = np.linspace(0., 1., num=50) + values = { + 'X': np.cos(t), + 'Y': 4. * np.ones_like(t), + 'K': 5. * np.ones_like(t) + } + transformed = trafo(t, values) + np.testing.assert_equal({ + 'X': np.ones_like(t) * 2, + 'Y': np.sin(t), + 'K': values['K'] + }, transformed) + class TestChaining(unittest.TestCase): def test_identity_result(self): @@ -376,10 +404,19 @@ def test_chaining(self): self.assertEqual(result, expected) + def test_constant_propagation(self): + chained = ChainedTransformation( + ScalingTransformation({'K': 1.1, 'X': ExpressionScalar('sin(t)')}), + OffsetTransformation({'K': 2.2, 'Y': ExpressionScalar('cos(t)')}), + ParallelChannelTransformation({'Z': ExpressionScalar('exp(t)')}) + ) + + self.assertEqual({'K', 'other'}, chained.get_constant_output_channels({'K', 'X', 'Y', 'Z', 'other'})) + class TestOffsetTransformation(unittest.TestCase): def setUp(self) -> None: - self.offsets = {'A': 1., 'B': 1.2} + self.offsets = {'A': 1., 'B': 1.2, 'C': ExpressionScalar('t')} def test_init(self): trafo = OffsetTransformation(self.offsets) @@ -390,26 +427,29 @@ def test_init(self): def test_get_input_channels(self): trafo = OffsetTransformation(self.offsets) - channels = {'A', 'C'} + channels = {'A', 'K'} self.assertIs(channels, trafo.get_input_channels(channels)) self.assertIs(channels, trafo.get_output_channels(channels)) def test_compare_key(self): trafo = OffsetTransformation(self.offsets) _ = hash(trafo) - self.assertEqual(frozenset([('A', 1.), ('B', 1.2)]), trafo.compare_key) + self.assertEqual(trafo, OffsetTransformation(self.offsets)) + self.assertEqual({trafo}, {OffsetTransformation(self.offsets), trafo}) def test_trafo(self): trafo = OffsetTransformation(self.offsets) time = np.asarray([.5, .6]) - in_data = {'A': np.asarray([.1, .2]), 'C': np.asarray([3., 4.])} + in_data = {'A': np.asarray([.1, .2]), + 'C': np.asarray([.5, .6]), + 'K': np.asarray([3., 4.])} - expected = {'A': np.asarray([1.1, 1.2]), 'C': in_data['C']} + expected = {'A': np.asarray([1.1, 1.2]), 'C': in_data['C'] + time, 'K': in_data['K']} out_data = trafo(time, in_data) - self.assertIs(expected['C'], out_data['C']) + self.assertIs(expected['K'], out_data['K']) np.testing.assert_equal(expected, out_data) def test_repr(self): @@ -422,12 +462,33 @@ def test_scalar_trafo_works(self): def test_constant_propagation(self): trafo = OffsetTransformation(self.offsets) - self.assertTrue(trafo.is_constant_invariant()) + self.assertFalse(trafo.is_constant_invariant()) + constant_trafo = OffsetTransformation({'a': 7, 'b': 8.}) + self.assertTrue(constant_trafo.is_constant_invariant()) + + def test_time_dependence(self): + channels = {'X': 2, 'Y': ExpressionScalar('sin(t)')} + trafo = OffsetTransformation(channels) + self.assertEqual({'X', 'K'}, trafo.get_constant_output_channels({'X', 'Y', 'K'})) + + t = np.linspace(0., 1., num=50) + values = { + 'X': np.cos(t), + 'Y': 4. * np.ones_like(t), + 'K': 5. * np.ones_like(t) + } + transformed = trafo(t, values) + np.testing.assert_equal({ + 'X': np.cos(t) + 2, + 'Y': np.sin(t) + 4., + 'K': values['K'] + }, transformed) class TestScalingTransformation(unittest.TestCase): def setUp(self) -> None: - self.scales = {'A': 1.5, 'B': 1.2} + self.constant_scales = {'A': 1.5, 'B': 1.2} + self.scales = {'A': 1.5, 'B': 1.2, 'C': ExpressionScalar('t')} def test_init(self): trafo = ScalingTransformation(self.scales) @@ -443,19 +504,23 @@ def test_get_input_channels(self): def test_compare_key(self): trafo = OffsetTransformation(self.scales) + const_trafo = OffsetTransformation(self.constant_scales) _ = hash(trafo) - self.assertEqual(frozenset([('A', 1.5), ('B', 1.2)]), trafo.compare_key) + self.assertEqual(trafo, trafo) + self.assertNotEqual(trafo, const_trafo) + self.assertEqual({trafo}, {trafo, OffsetTransformation(self.scales)}) + self.assertEqual({trafo, const_trafo}, {trafo, OffsetTransformation(self.constant_scales)}) def test_trafo(self): trafo = ScalingTransformation(self.scales) time = np.asarray([.5, .6]) - in_data = {'A': np.asarray([.1, .2]), 'C': np.asarray([3., 4.])} - expected = {'A': np.asarray([.1 * 1.5, .2 * 1.5]), 'C': in_data['C']} + in_data = {'A': np.asarray([.1, .2]), 'C': np.asarray([3., 4.]), 'K': np.asarray([5., 6.])} + expected = {'A': in_data['A'] * 1.5, 'C': in_data['C'] * time, 'K': in_data['K']} out_data = trafo(time, in_data) - self.assertIs(expected['C'], out_data['C']) + self.assertIs(expected['K'], out_data['K']) np.testing.assert_equal(expected, out_data) def test_repr(self): @@ -468,4 +533,27 @@ def test_scalar_trafo_works(self): def test_constant_propagation(self): trafo = ScalingTransformation(self.scales) - self.assertTrue(trafo.is_constant_invariant()) + const_trafo = ScalingTransformation(self.constant_scales) + self.assertFalse(trafo.is_constant_invariant()) + self.assertTrue(const_trafo.is_constant_invariant()) + + def test_time_dependence(self): + channels = {'X': 2, 'Y': ExpressionScalar('sin(t)'), 'Z': ExpressionScalar('exp(t)')} + trafo = ScalingTransformation(channels) + self.assertEqual({'X', 'K'}, + trafo.get_constant_output_channels({'X', 'Y', 'K'})) + + t = np.linspace(0., 1., num=50) + values = { + 'X': np.cos(t), + 'Y': 4. * np.ones_like(t), + 'Z': np.tan(t), + 'K': 5. * np.ones_like(t) + } + transformed = trafo(t, values) + np.testing.assert_equal({ + 'X': np.cos(t) * 2, + 'Y': np.sin(t) * 4., + 'Z': np.tan(t) * np.exp(t), + 'K': values['K'] + }, transformed) diff --git a/tests/pulses/arithmetic_pulse_template_tests.py b/tests/pulses/arithmetic_pulse_template_tests.py index 87c1f65b..9369a16b 100644 --- a/tests/pulses/arithmetic_pulse_template_tests.py +++ b/tests/pulses/arithmetic_pulse_template_tests.py @@ -7,7 +7,8 @@ from qupulse.parameter_scope import DictScope from qupulse.expressions import ExpressionScalar -from qupulse.pulses.parameters import ConstantParameter +from qupulse.pulses import MappingPT, ConstantPT, RepetitionPT +from qupulse.pulses.plotting import render from qupulse.pulses.arithmetic_pulse_template import ArithmeticAtomicPulseTemplate, ArithmeticPulseTemplate,\ ImplicitAtomicityInArithmeticPT, UnequalDurationWarningInArithmeticPT, try_operation from qupulse._program.waveforms import TransformingWaveform @@ -240,23 +241,31 @@ def test_init(self): with mock.patch.object(ArithmeticPulseTemplate, '_parse_operand', return_value=scalar) as parse_operand: - arith = ArithmeticPulseTemplate(lhs, '/', non_pt) - parse_operand.assert_called_once_with(non_pt, lhs.defined_channels) - self.assertEqual(lhs, arith.lhs) - self.assertEqual(scalar, arith.rhs) - self.assertEqual(lhs, arith._pulse_template) - self.assertEqual(scalar, arith._scalar) - self.assertEqual('/', arith._arithmetic_operator) + with mock.patch('qupulse.pulses.arithmetic_pulse_template._is_time_dependent', return_value=False): + arith = ArithmeticPulseTemplate(lhs, '/', non_pt) + parse_operand.assert_called_once_with(non_pt, lhs.defined_channels) + self.assertEqual(lhs, arith.lhs) + self.assertEqual(scalar, arith.rhs) + self.assertEqual(lhs, arith._pulse_template) + self.assertEqual(scalar, arith._scalar) + self.assertEqual('/', arith._arithmetic_operator) with mock.patch.object(ArithmeticPulseTemplate, '_parse_operand', return_value=scalar) as parse_operand: - arith = ArithmeticPulseTemplate(non_pt, '-', rhs) - parse_operand.assert_called_once_with(non_pt, rhs.defined_channels) - self.assertEqual(scalar, arith.lhs) - self.assertEqual(rhs, arith.rhs) - self.assertEqual(rhs, arith._pulse_template) - self.assertEqual(scalar, arith._scalar) - self.assertEqual('-', arith._arithmetic_operator) + with mock.patch('qupulse.pulses.arithmetic_pulse_template._is_time_dependent', return_value=False): + arith = ArithmeticPulseTemplate(non_pt, '-', rhs) + parse_operand.assert_called_once_with(non_pt, rhs.defined_channels) + self.assertEqual(scalar, arith.lhs) + self.assertEqual(rhs, arith.rhs) + self.assertEqual(rhs, arith._pulse_template) + self.assertEqual(scalar, arith._scalar) + self.assertEqual('-', arith._arithmetic_operator) + + with mock.patch.object(ArithmeticPulseTemplate, '_parse_operand', + return_value=scalar) as parse_operand: + with mock.patch('qupulse.pulses.arithmetic_pulse_template._is_time_dependent', return_value=True): + with self.assertRaises(TypeError): + ArithmeticPulseTemplate(non_pt, '-', RepetitionPT(rhs, 3)) def test_parse_operand(self): operand = {'a': 3, 'b': 'x'} @@ -362,6 +371,23 @@ def test_get_transformation(self): expected_trafo = ScalingTransformation(inv_scalar) self.assertEqual(expected_trafo, trafo) + def test_time_dependent_expression(self): + inner = FunctionPT('exp(-(t - t_duration/2)**2)', duration_expression='t_duration') + inner_iq = AtomicMultiChannelPT((inner, {'default': 'I'}), (inner, {'default': 'Q'})) + modulated = ArithmeticPulseTemplate(inner_iq, '*', {'I': 'sin(2*pi*f*t)', 'Q': 'cos(2*pi*f*t)'}) + program = modulated.create_program(parameters={'t_duration': 10, 'f': 1.}) + wf = program[0].waveform + self.assertEqual(1, len(program)) + time = np.linspace(0, 10) + + sampled_i = wf.get_sampled('I', time) + sampled_q = wf.get_sampled('Q', time) + + expected_sampled_i = np.sin(2*np.pi*time) * np.exp(-(time - 5)**2) + expected_sampled_q = np.cos(2*np.pi*time) * np.exp(-(time - 5)**2) + np.testing.assert_allclose(expected_sampled_i, sampled_i) + np.testing.assert_allclose(expected_sampled_q, sampled_q) + def test_internal_create_program(self): lhs = 'x + y' rhs = DummyPulseTemplate(defined_channels={'u', 'v', 'w'}) @@ -510,10 +536,12 @@ def test_parameter_names(self): self.assertEqual({'x', 'y', 'foo', 'bar'}, arith.parameter_names) pt = DummyPulseTemplate(defined_channels={'a'}, parameter_names={'foo', 'bar'}) - mapping = {'a': 'x', 'b': 'y'} self.assertEqual(frozenset({'x', 'y'}), arith._scalar_operand_parameters) self.assertEqual({'x', 'y', 'foo', 'bar'}, arith.parameter_names) + arith = ArithmeticPulseTemplate(pt, '+', scalar + '+t') + self.assertEqual({'x', 'y', 'foo', 'bar'}, arith.parameter_names) + def test_try_operation(self): apt = DummyPulseTemplate(duration=1, defined_channels={'a'}) npt = PulseTemplateStub(defined_channels={'a'}) @@ -560,6 +588,20 @@ def test_repr(self): arith = ArithmeticPulseTemplate(pt, '-', scalar, identifier='id') self.assertEqual(super(ArithmeticPulseTemplate, arith).__repr__(), repr(arith)) + def test_time_dependence(self): + inner = ConstantPT(1.4, {'a': ExpressionScalar('x'), 'b': 1.1}) + with self.assertRaises(TypeError): + ArithmeticPulseTemplate(RepetitionPT(inner, 3), '*', {'a': 'sin(t)', 'b': 'cos(t)'}) + + pc = ArithmeticPulseTemplate(inner, '*', {'a': 'sin(t)', 'b': 'cos(t)'}) + prog = pc.create_program(parameters={'x': -1}) + t, vals, _ = render(prog, sample_rate=10) + expected_values = { + 'a': -np.sin(t), + 'b': 1.1 * np.cos(t) + } + np.testing.assert_equal(expected_values, vals) + class ArithmeticUsageTests(unittest.TestCase): def setUp(self) -> None: diff --git a/tests/pulses/multi_channel_pulse_template_tests.py b/tests/pulses/multi_channel_pulse_template_tests.py index 84904a8d..76b4d141 100644 --- a/tests/pulses/multi_channel_pulse_template_tests.py +++ b/tests/pulses/multi_channel_pulse_template_tests.py @@ -2,11 +2,14 @@ from unittest import mock import numpy +import numpy as np from qupulse.parameter_scope import DictScope +from qupulse.pulses import RepetitionPT, ConstantPT +from qupulse.pulses.plotting import render from qupulse.pulses.multi_channel_pulse_template import MultiChannelWaveform, MappingPulseTemplate,\ - ChannelMappingException, AtomicMultiChannelPulseTemplate, ParallelConstantChannelPulseTemplate,\ - TransformingWaveform, ParallelConstantChannelTransformation + ChannelMappingException, AtomicMultiChannelPulseTemplate, ParallelChannelPulseTemplate,\ + TransformingWaveform, ParallelChannelTransformation from qupulse.pulses.parameters import ParameterConstraint, ParameterConstraintViolation, ConstantParameter from qupulse.expressions import ExpressionScalar, Expression from qupulse._program.transformation import LinearTransformation, chain_transformations @@ -342,14 +345,14 @@ def serialize_callback(obj) -> str: self.assertEqual(expected_data, data) -class ParallelConstantChannelPulseTemplateTests(unittest.TestCase): +class ParallelChannelPulseTemplateTests(unittest.TestCase): def test_init(self): template = DummyPulseTemplate(duration='t1', defined_channels={'X', 'Y'}, parameter_names={'a', 'b'}, measurement_names={'M'}) overwritten_channels = {'Y': 'c', 'Z': 'a'} expected_overwritten_channels = {'Y': ExpressionScalar('c'), 'Z': ExpressionScalar('a')} - pccpt = ParallelConstantChannelPulseTemplate(template, overwritten_channels) + pccpt = ParallelChannelPulseTemplate(template, overwritten_channels) self.assertIs(template, pccpt.template) self.assertEqual(expected_overwritten_channels, pccpt.overwritten_channels) @@ -359,8 +362,16 @@ def test_init(self): self.assertEqual({'a', 'c'}, pccpt.transformation_parameters) self.assertIs(template.duration, pccpt.duration) + non_atomic_pt = RepetitionPT(template, 5) + ParallelChannelPulseTemplate(non_atomic_pt, overwritten_channels) + with self.assertRaises(TypeError): + overwritten_channels['T'] = 'a * t' + ParallelChannelPulseTemplate(non_atomic_pt, overwritten_channels) + + ParallelChannelPulseTemplate(template, overwritten_channels) + def test_missing_implementations(self): - pccpt = ParallelConstantChannelPulseTemplate(DummyPulseTemplate(), {}) + pccpt = ParallelChannelPulseTemplate(DummyPulseTemplate(), {}) with self.assertRaises(NotImplementedError): pccpt.get_serialization_data(object()) @@ -369,7 +380,7 @@ def test_integral(self): measurement_names={'M'}, integrals={'X': ExpressionScalar('a'), 'Y': ExpressionScalar(4)}) overwritten_channels = {'Y': 'c', 'Z': 'a'} - pccpt = ParallelConstantChannelPulseTemplate(template, overwritten_channels) + pccpt = ParallelChannelPulseTemplate(template, overwritten_channels) expected_integral = {'X': ExpressionScalar('a'), 'Y': ExpressionScalar('c*t1'), @@ -378,12 +389,12 @@ def test_integral(self): def test_initial_values(self): dpt = DummyPulseTemplate(initial_values={'A': 'a', 'B': 'b'}) - par = ParallelConstantChannelPulseTemplate(dpt, {'B': 'b2', 'C': 'c'}) + par = ParallelChannelPulseTemplate(dpt, {'B': 'b2', 'C': 'c'}) self.assertEqual({'A': 'a', 'B': 'b2', 'C': 'c'}, par.initial_values) def test_final_values(self): dpt = DummyPulseTemplate(final_values={'A': 'a', 'B': 'b'}) - par = ParallelConstantChannelPulseTemplate(dpt, {'B': 'b2', 'C': 'c'}) + par = ParallelChannelPulseTemplate(dpt, {'B': 'b2', 'C': 'c'}) self.assertEqual({'A': 'a', 'B': 'b2', 'C': 'c'}, par.final_values) def test_get_overwritten_channels_values(self): @@ -393,7 +404,7 @@ def test_get_overwritten_channels_values(self): channel_mapping = {'X': 'X', 'Y': 'K', 'Z': 'Z', 'ToNone': None} expected_overwritten_channel_values = {'K': 1.2, 'Z': 3.4} - pccpt = ParallelConstantChannelPulseTemplate(template, overwritten_channels) + pccpt = ParallelChannelPulseTemplate(template, overwritten_channels) real_parameters = {'c': 1.2, 'a': 3.4} self.assertEqual(expected_overwritten_channel_values, pccpt._get_overwritten_channels_values(real_parameters, @@ -413,13 +424,13 @@ def test_internal_create_program(self): channel_mapping=channel_mapping, to_single_waveform=to_single_waveform, parent_loop=parent_loop) - pccpt = ParallelConstantChannelPulseTemplate(template, overwritten_channels) + pccpt = ParallelChannelPulseTemplate(template, overwritten_channels) scope = DictScope.from_kwargs(c=1.2, a=3.4) kwargs = {**other_kwargs, 'scope': scope, 'global_transformation': None} expected_overwritten_channels = {'O': 1.2, 'Z': 3.4} - expected_transformation = ParallelConstantChannelTransformation(expected_overwritten_channels) + expected_transformation = ParallelChannelTransformation(expected_overwritten_channels) expected_kwargs = {**kwargs, 'global_transformation': expected_transformation} with mock.patch.object(template, '_create_program', spec=template._create_program) as cp_mock: @@ -440,11 +451,11 @@ def test_build_waveform(self): measurement_names={'M'}, waveform=DummyWaveform()) overwritten_channels = {'Y': 'c', 'Z': 'a'} channel_mapping = {'X': 'X', 'Y': 'K', 'Z': 'Z'} - pccpt = ParallelConstantChannelPulseTemplate(template, overwritten_channels) + pccpt = ParallelChannelPulseTemplate(template, overwritten_channels) parameters = {'c': 1.2, 'a': 3.4} expected_overwritten_channels = {'K': 1.2, 'Z': 3.4} - expected_transformation = ParallelConstantChannelTransformation(expected_overwritten_channels) + expected_transformation = ParallelChannelTransformation(expected_overwritten_channels) expected_waveform = TransformingWaveform(template.waveform, expected_transformation) resulting_waveform = pccpt.build_waveform(parameters.copy(), channel_mapping.copy()) @@ -456,12 +467,32 @@ def test_build_waveform(self): resulting_waveform = pccpt.build_waveform(parameters.copy(), channel_mapping.copy()) self.assertEqual(None, resulting_waveform) self.assertEqual([(parameters, channel_mapping), (parameters, channel_mapping)], template.build_waveform_calls) + + def test_time_dependence(self): + inner = ConstantPT(1.4, {'a': ExpressionScalar('x'), 'b': 1.}) + with self.assertRaises(TypeError): + ParallelChannelPulseTemplate(RepetitionPT(inner, 3), {'c': 'sin(t)'}) + + pc = ParallelChannelPulseTemplate(inner, {'c': 'sin(t)'}) + prog = pc.create_program(parameters={'x': -1}) + t, vals, _ = render(prog, sample_rate=10) + expected_values = { + 'a': np.broadcast_to(-1, t.shape), + 'b': np.broadcast_to(1., t.shape), + 'c': np.sin(t) + } + np.testing.assert_equal(expected_values, vals) + + def test_parameter_names(self): + inner = ConstantPT(1.4, {'a': ExpressionScalar('x'), 'b': 1.}) + pc = ParallelChannelPulseTemplate(inner, {'c': 'sin(2*pi*f*t)', 'd': 'k'}) + self.assertEqual({'x', 'f', 'k'}, pc.parameter_names) -class ParallelConstantChannelPulseTemplateSerializationTests(SerializableTests, unittest.TestCase): +class ParallelChannelPulseTemplateSerializationTests(SerializableTests, unittest.TestCase): @property def class_to_test(self): - return ParallelConstantChannelPulseTemplate + return ParallelChannelPulseTemplate @staticmethod def make_kwargs(*args, **kwargs): @@ -470,9 +501,9 @@ def make_kwargs(*args, **kwargs): 'overwritten_channels': {'Y': 'c', 'Z': 'a'} } - def assert_equal_instance_except_id(self, lhs: ParallelConstantChannelPulseTemplate, rhs: ParallelConstantChannelPulseTemplate): - self.assertIsInstance(lhs, ParallelConstantChannelPulseTemplate) - self.assertIsInstance(rhs, ParallelConstantChannelPulseTemplate) + def assert_equal_instance_except_id(self, lhs: ParallelChannelPulseTemplate, rhs: ParallelChannelPulseTemplate): + self.assertIsInstance(lhs, ParallelChannelPulseTemplate) + self.assertIsInstance(rhs, ParallelChannelPulseTemplate) self.assertEqual(lhs.template, rhs.template) self.assertEqual(lhs.overwritten_channels, rhs.overwritten_channels)