Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for time dependent expressions in transformations. #704

Merged
merged 17 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changes.d/709.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Add support for time dependent expressions for arithmetics with atomic pulse templates i.e. ParallelChannelPT and
ArithmeticPT support time dependent expressions if used with atomic pulse templates.
Rename `ParallelConstantChannelPT` to `ParallelChannelPT` to reflect this change.
130 changes: 101 additions & 29 deletions qupulse/_program/transformation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from typing import Any, Mapping, Set, Tuple, Sequence, AbstractSet, Union, TYPE_CHECKING
from typing import Any, Mapping, Set, Tuple, Sequence, AbstractSet, Union, TYPE_CHECKING, Hashable
from abc import abstractmethod
from numbers import Real

import numpy as np

from qupulse import ChannelID
from qupulse.comparable import Comparable
from qupulse.utils.types import SingletonABCMeta
from qupulse.utils.types import SingletonABCMeta, frozendict
from qupulse.expressions import ExpressionScalar


_TrafoValue = Union[Real, ExpressionScalar]


class Transformation(Comparable):
Expand Down Expand Up @@ -44,6 +48,9 @@ def is_constant_invariant(self):
"""Signals if the transformation always maps constants to constants."""
return False

def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]:
return frozenset()


class IdentityTransformation(Transformation, metaclass=SingletonABCMeta):
def __call__(self, time: Union[np.ndarray, float],
Expand All @@ -70,6 +77,9 @@ def is_constant_invariant(self):
"""Signals if the transformation always maps constants to constants."""
return True

def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]:
return input_channels


class ChainedTransformation(Transformation):
def __init__(self, *transformations: Transformation):
Expand Down Expand Up @@ -103,12 +113,17 @@ def chain(self, next_transformation) -> Transformation:
return chain_transformations(*self.transformations, next_transformation)

def __repr__(self):
return 'ChainedTransformation%r' % (self._transformations,)
return f'{type(self).__name__}{self._transformations!r}'

def is_constant_invariant(self):
"""Signals if the transformation always maps constants to constants."""
return all(trafo.is_constant_invariant() for trafo in self._transformations)

def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]:
for trafo in self._transformations:
input_channels = trafo.get_constant_output_channels(input_channels)
return input_channels


class LinearTransformation(Transformation):
def __init__(self,
Expand Down Expand Up @@ -192,21 +207,26 @@ def is_constant_invariant(self):
"""Signals if the transformation always maps constants to constants."""
return True

def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]:
return input_channels


class OffsetTransformation(Transformation):
def __init__(self, offsets: Mapping[ChannelID, Real]):
def __init__(self, offsets: Mapping[ChannelID, _TrafoValue]):
"""Adds an offset to each channel specified in offsets.

Channels not in offsets are forewarded

Args:
offsets: Channel -> offset mapping
"""
self._offsets = dict(offsets.items())
self._offsets = frozendict(offsets)
assert _are_valid_transformation_expressions(self._offsets)

def __call__(self, time: Union[np.ndarray, float],
data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]:
return {channel: channel_values + self._offsets[channel] if channel in self._offsets else channel_values
offsets = _instantiate_expression_dict(time, self._offsets)
return {channel: channel_values + offsets[channel] if channel in offsets else channel_values
for channel, channel_values in data.items()}

def get_input_channels(self, output_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]:
Expand All @@ -216,24 +236,29 @@ def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> Abstrac
return input_channels

@property
def compare_key(self) -> frozenset:
return frozenset(self._offsets.items())
def compare_key(self) -> Hashable:
return self._offsets

def __repr__(self):
return 'OffsetTransformation(%r)' % self._offsets
return f'{type(self).__name__}({dict(self._offsets)!r})'

def is_constant_invariant(self):
"""Signals if the transformation always maps constants to constants."""
return True
return not _has_time_dependent_values(self._offsets)

def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]:
return _get_constant_output_channels(self._offsets, input_channels)


class ScalingTransformation(Transformation):
def __init__(self, factors: Mapping[ChannelID, Real]):
self._factors = dict(factors.items())
def __init__(self, factors: Mapping[ChannelID, _TrafoValue]):
self._factors = frozendict(factors)
assert _are_valid_transformation_expressions(self._factors)

def __call__(self, time: Union[np.ndarray, float],
data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]:
return {channel: channel_values * self._factors[channel] if channel in self._factors else channel_values
factors = _instantiate_expression_dict(time, self._factors)
return {channel: channel_values * factors[channel] if channel in factors else channel_values
for channel, channel_values in data.items()}

def get_input_channels(self, output_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]:
Expand All @@ -243,15 +268,18 @@ def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> Abstrac
return input_channels

@property
def compare_key(self) -> frozenset:
return frozenset(self._factors.items())
def compare_key(self) -> Hashable:
return self._factors

def __repr__(self):
return 'ScalingTransformation(%r)' % self._factors
return f'{type(self).__name__}({dict(self._factors)!r})'

def is_constant_invariant(self):
"""Signals if the transformation always maps constants to constants."""
return True
return not _has_time_dependent_values(self._factors)

def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]:
return _get_constant_output_channels(self._factors, input_channels)


try:
Expand All @@ -277,25 +305,30 @@ def linear_transformation_from_pandas(transformation: PandasDataFrameType) -> Li
pass


class ParallelConstantChannelTransformation(Transformation):
def __init__(self, channels: Mapping[ChannelID, Real]):
"""Set channel values to given values regardless their former existence
class ParallelChannelTransformation(Transformation):
def __init__(self, channels: Mapping[ChannelID, _TrafoValue]):
"""Set channel values to given values regardless their former existence. The values can be time dependent
expressions.

Args:
channels: Channels present in this map are set to the given value.
"""
self._channels = {channel: float(value)
for channel, value in channels.items()}
self._channels: Mapping[ChannelID, _TrafoValue] = frozendict(channels.items())
assert _are_valid_transformation_expressions(self._channels)

def __call__(self, time: Union[np.ndarray, float],
data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]:
overwritten = {channel: np.full_like(time, fill_value=value, dtype=float)
for channel, value in self._channels.items()}
overwritten = self._instantiated_values(time)
return {**data, **overwritten}

def _instantiated_values(self, time):
scope = {'t': time}
return {channel: value.evaluate_in_scope(scope) if hasattr(value, 'evaluate_in_scope') else np.full_like(time, fill_value=value, dtype=float)
for channel, value in self._channels.items()}

@property
def compare_key(self) -> Tuple[Tuple[ChannelID, float], ...]:
return tuple(sorted(self._channels.items()))
def compare_key(self) -> Hashable:
return self._channels

def get_input_channels(self, output_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]:
return output_channels - self._channels.keys()
Expand All @@ -304,11 +337,21 @@ def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> Abstrac
return input_channels | self._channels.keys()

def __repr__(self):
return 'ParallelConstantChannelTransformation(%r)' % self._channels
return f'{type(self).__name__}({dict(self._channels)!r})'

def is_constant_invariant(self):
"""Signals if the transformation always maps constants to constants."""
return True
return not _has_time_dependent_values(self._channels)

def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]:
output_channels = set(input_channels)
for ch, value in self._channels.items():
if hasattr(value, 'variables'):
output_channels.discard(ch)
else:
output_channels.add(ch)

return output_channels


def chain_transformations(*transformations: Transformation) -> Transformation:
Expand All @@ -325,4 +368,33 @@ def chain_transformations(*transformations: Transformation) -> Transformation:
elif len(parsed_transformations) == 1:
return parsed_transformations[0]
else:
return ChainedTransformation(*parsed_transformations)
return ChainedTransformation(*parsed_transformations)


def _instantiate_expression_dict(time, expressions: Mapping[str, _TrafoValue]) -> Mapping[str, Union[Real, np.ndarray]]:
scope = {'t': time}
modified_expressions = {}
for name, value in expressions.items():
if hasattr(value, 'evaluate_in_scope'):
modified_expressions[name] = value.evaluate_in_scope(scope)
if modified_expressions:
return {**expressions, **modified_expressions}
else:
return expressions


def _has_time_dependent_values(expressions: Mapping[ChannelID, _TrafoValue]) -> bool:
return any(hasattr(value, 'variables')
for value in expressions.values())


def _get_constant_output_channels(expressions: Mapping[ChannelID, _TrafoValue],
constant_input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]:
return {ch
for ch in constant_input_channels
if not hasattr(expressions.get(ch, None), 'variables')}

def _are_valid_transformation_expressions(expressions: Mapping[ChannelID, _TrafoValue]) -> bool:
return all(expr.variables == ('t',)
for expr in expressions.values()
if hasattr(expr, 'variables'))
6 changes: 3 additions & 3 deletions qupulse/pulses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from qupulse.pulses.function_pulse_template import FunctionPulseTemplate as FunctionPT
from qupulse.pulses.loop_pulse_template import ForLoopPulseTemplate as ForLoopPT
from qupulse.pulses.multi_channel_pulse_template import AtomicMultiChannelPulseTemplate as AtomicMultiChannelPT,\
ParallelConstantChannelPulseTemplate as ParallelConstantChannelPT
ParallelConstantChannelPulseTemplate as ParallelConstantChannelPT,\
ParallelChannelPulseTemplate as ParallelChannelPT
from qupulse.pulses.mapping_pulse_template import MappingPulseTemplate as MappingPT
from qupulse.pulses.repetition_pulse_template import RepetitionPulseTemplate as RepetitionPT
from qupulse.pulses.sequence_pulse_template import SequencePulseTemplate as SequencePT
Expand All @@ -30,5 +31,4 @@

__all__ = ["FunctionPT", "ForLoopPT", "AtomicMultiChannelPT", "MappingPT", "RepetitionPT", "SequencePT", "TablePT",
"PointPT", "ConstantPT", "AbstractPT", "ParallelConstantChannelPT", "ArithmeticPT", "ArithmeticAtomicPT",
"TimeReversalPT"]

"TimeReversalPT", "ParallelChannelPT"]
51 changes: 37 additions & 14 deletions qupulse/pulses/arithmetic_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,24 +198,31 @@ def __init__(self,
rhs: Union[PulseTemplate, ExpressionLike, Mapping[ChannelID, ExpressionLike]],
*,
identifier: Optional[str] = None):
"""Allowed operations
"""Implements the arithmetics between an aribrary pulse template and scalar values. The values can be the same
for all channels, channel specific or only for a subset of the inner pulse templates defined channels.
The expression may be time dependent if the pulse template is atomic.

scalar + pulse_template
scalar - pulse_template
scalar * pulse_template
pulse_template + scalar
pulse_template - scalar
pulse_template * scalar
pulse_template / scalar
A channel dependent scalar is represented by a mapping of ChannelID -> Expression.

The allowed operations are:
scalar + pulse_template
scalar - pulse_template
scalar * pulse_template
pulse_template + scalar
pulse_template - scalar
pulse_template * scalar
pulse_template / scalar

Args:
lhs: Left hand side operand
arithmetic_operator: String representation of the operator
rhs: Right hand side operand
identifier:
identifier: Identifier used for serialization

Raises:
TypeError if both or none of the operands are pulse templates
TypeError: If both or none of the operands are pulse templates or if there is a time dependent expression
and a composite pulse template.
ValueError: If the scalar is a mapping and contains channels that are not defined on the pulse template.
"""
PulseTemplate.__init__(self, identifier=identifier)

Expand Down Expand Up @@ -243,11 +250,15 @@ def __init__(self,
self._lhs = lhs
self._rhs = rhs

self._pulse_template = pulse_template
self._pulse_template: PulseTemplate = pulse_template
self._scalar = scalar

self._arithmetic_operator = arithmetic_operator

if not self._pulse_template._is_atomic() and _is_time_dependent(self._scalar):
raise TypeError("A time dependent ArithmeticPulseTemplate scalar operand currently requires an atomic "
"pulse template as the other operand.", self)

@staticmethod
def _parse_operand(operand: Union[ExpressionLike, Mapping[ChannelID, ExpressionLike]],
channels: Set[ChannelID]) -> Union[ExpressionScalar, Mapping[ChannelID, ExpressionScalar]]:
Expand Down Expand Up @@ -298,7 +309,7 @@ def _get_scalar_value(self,
if channel_mapping[channel]}

else:
return {channel_mapping[channel]: value.evaluate_in_scope(parameters)
return {channel_mapping[channel]: value.evaluate_symbolic(parameters) if 't' in value.variables else value.evaluate_in_scope(parameters)
for channel, value in self._scalar.items()
if channel_mapping[channel]}

Expand Down Expand Up @@ -479,9 +490,11 @@ def measurement_names(self) -> Set[str]:
@cached_property
def _scalar_operand_parameters(self) -> FrozenSet[str]:
if isinstance(self._scalar, dict):
return frozenset(*(value.variables for value in self._scalar.values()))
return frozenset(variable
for value in self._scalar.values()
for variable in value.variables) - {'t'}
else:
return frozenset(self._scalar.variables)
return frozenset(self._scalar.variables) - {'t'}

@property
def parameter_names(self) -> Set[str]:
Expand All @@ -499,6 +512,9 @@ def get_measurement_windows(self,
measurement_mapping=measurement_mapping))
return measurements

def _is_atomic(self):
return self._pulse_template._is_atomic()


def try_operation(lhs: Union[PulseTemplate, ExpressionLike, Mapping[ChannelID, ExpressionLike]],
op: str,
Expand Down Expand Up @@ -531,6 +547,13 @@ def try_operation(lhs: Union[PulseTemplate, ExpressionLike, Mapping[ChannelID, E
return NotImplemented


def _is_time_dependent(scalar: Union[ExpressionScalar, Dict[str, ExpressionScalar]]) -> bool:
if isinstance(scalar, dict):
return any('t' in value.variables for value in scalar.values())
else:
return 't' in scalar.variables


class UnequalDurationWarningInArithmeticPT(RuntimeWarning):
"""Signals that an ArithmeticAtomicPulseTemplate was constructed from operands with unequal duration. This is a
separate class to allow easy silencing."""
Expand Down
5 changes: 4 additions & 1 deletion qupulse/pulses/mapping_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(self, template: PulseTemplate, *,
for k, v in template.channel_mapping.items()}
template = template.template

self.__template = template
self.__template: PulseTemplate = template
self.__parameter_mapping = FrozenDict(parameter_mapping)
self.__external_parameters = set(itertools.chain(*(expr.variables for expr in self.__parameter_mapping.values())))
self.__external_parameters |= self.constrained_parameters
Expand Down Expand Up @@ -230,6 +230,9 @@ def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[

return data

def _is_atomic(self):
return self.__template._is_atomic()

@classmethod
def deserialize(cls,
serializer: Optional[Serializer]=None, # compatibility to old serialization routines, deprecated
Expand Down
Loading