diff --git a/.coveragerc b/.coveragerc index 89353bb9..55fda8f0 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,7 +5,7 @@ branch = True fail_under = 100 show_missing = True omit = - proto/marshal/containers.py + proto/marshal/compat.py exclude_lines = # Re-enable the standard pragma pragma: NO COVER diff --git a/proto/__init__.py b/proto/__init__.py index 8f69ef55..577eac40 100644 --- a/proto/__init__.py +++ b/proto/__init__.py @@ -15,7 +15,7 @@ from .fields import Field from .fields import MapField from .fields import RepeatedField -from .marshal.marshal import marshal +from .marshal import Marshal from .message import Message from .primitives import ProtoType @@ -43,7 +43,7 @@ 'Field', 'MapField', 'RepeatedField', - 'marshal', + 'Marshal', 'Message', # Expose the types directly. diff --git a/proto/marshal/__init__.py b/proto/marshal/__init__.py index b0c7da3d..71f1c721 100644 --- a/proto/marshal/__init__.py +++ b/proto/marshal/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .marshal import Marshal + + +__all__ = ( + 'Marshal', +) diff --git a/proto/marshal/containers.py b/proto/marshal/compat.py similarity index 100% rename from proto/marshal/containers.py rename to proto/marshal/compat.py diff --git a/proto/marshal/marshal.py b/proto/marshal/marshal.py index d2ed7a14..4bd6cb56 100644 --- a/proto/marshal/marshal.py +++ b/proto/marshal/marshal.py @@ -19,12 +19,12 @@ from google.protobuf import timestamp_pb2 from google.protobuf import wrappers_pb2 -from proto.marshal import containers +from proto.marshal import compat from proto.marshal.collections import MapComposite from proto.marshal.collections import Repeated from proto.marshal.collections import RepeatedComposite -from proto.marshal.types import dates -from proto.marshal.types import wrappers +from proto.marshal.rules import dates +from proto.marshal.rules import wrappers class Rule(abc.ABC): @@ -37,8 +37,8 @@ def __subclasshook__(cls, C): return NotImplemented -class MarshalRegistry: - """A class to translate between protocol buffers and Python classes. +class BaseMarshal: + """The base class to translate between protobuf and Python classes. Protocol buffers defines many common types (e.g. Timestamp, Duration) which also exist in the Python standard library. The marshal essentially @@ -52,14 +52,12 @@ class MarshalRegistry: the declared field type is still used. This means that, if appropriate, multiple protocol buffer types may use the same Python type. - The marshal is intended to be a singleton; this module instantiates - and exports one marshal, which is imported throughout the rest of this - library. This allows for an advanced case where user code registers - additional types to be marshaled. + The primary implementation of this is :class:`Marshal`, which should + usually be used instead of this class directly. """ def __init__(self): - self._registry = {} - self._noop = NoopMarshal() + self._rules = {} + self._noop = NoopRule() self.reset() def register(self, proto_type: type, rule: Rule = None): @@ -73,7 +71,7 @@ def register(self, proto_type: type, rule: Rule = None): This function can also be used as a decorator:: @marshal.register(timestamp_pb2.Timestamp) - class TimestampMarshal: + class TimestampRule: ... In this case, the class will be initialized for you with zero @@ -97,7 +95,7 @@ class TimestampMarshal: '`to_proto` and `to_python` methods.') # Register the rule. - self._registry[proto_type] = rule + self._rules[proto_type] = rule return # Create an inner function that will register an instance of the @@ -109,43 +107,43 @@ def register_rule_class(rule_class: type): '`to_proto` and `to_python` methods.') # Register the rule class. - self._registry[proto_type] = rule_class() + self._rules[proto_type] = rule_class() return rule_class return register_rule_class def reset(self): """Reset the registry to its initial state.""" - self._registry.clear() + self._rules.clear() # Register date and time wrappers. - self.register(timestamp_pb2.Timestamp, dates.TimestampMarshal()) - self.register(duration_pb2.Duration, dates.DurationMarshal()) + self.register(timestamp_pb2.Timestamp, dates.TimestampRule()) + self.register(duration_pb2.Duration, dates.DurationRule()) # Register nullable primitive wrappers. - self.register(wrappers_pb2.BoolValue, wrappers.BoolValueMarshal()) - self.register(wrappers_pb2.BytesValue, wrappers.BytesValueMarshal()) - self.register(wrappers_pb2.DoubleValue, wrappers.DoubleValueMarshal()) - self.register(wrappers_pb2.FloatValue, wrappers.FloatValueMarshal()) - self.register(wrappers_pb2.Int32Value, wrappers.Int32ValueMarshal()) - self.register(wrappers_pb2.Int64Value, wrappers.Int64ValueMarshal()) - self.register(wrappers_pb2.StringValue, wrappers.StringValueMarshal()) - self.register(wrappers_pb2.UInt32Value, wrappers.UInt32ValueMarshal()) - self.register(wrappers_pb2.UInt64Value, wrappers.UInt64ValueMarshal()) + self.register(wrappers_pb2.BoolValue, wrappers.BoolValueRule()) + self.register(wrappers_pb2.BytesValue, wrappers.BytesValueRule()) + self.register(wrappers_pb2.DoubleValue, wrappers.DoubleValueRule()) + self.register(wrappers_pb2.FloatValue, wrappers.FloatValueRule()) + self.register(wrappers_pb2.Int32Value, wrappers.Int32ValueRule()) + self.register(wrappers_pb2.Int64Value, wrappers.Int64ValueRule()) + self.register(wrappers_pb2.StringValue, wrappers.StringValueRule()) + self.register(wrappers_pb2.UInt32Value, wrappers.UInt32ValueRule()) + self.register(wrappers_pb2.UInt64Value, wrappers.UInt64ValueRule()) def to_python(self, proto_type, value, *, absent: bool = None): # Internal protobuf has its own special type for lists of values. # Return a view around it that implements MutableSequence. - if isinstance(value, containers.repeated_composite_types): + if isinstance(value, compat.repeated_composite_types): return RepeatedComposite(value, marshal=self) - if isinstance(value, containers.repeated_scalar_types): + if isinstance(value, compat.repeated_scalar_types): return Repeated(value, marshal=self) # Same thing for maps of messages. - if isinstance(value, containers.map_composite_types): + if isinstance(value, compat.map_composite_types): return MapComposite(value, marshal=self) # Convert ordinary values. - rule = self._registry.get(proto_type, self._noop) + rule = self._rules.get(proto_type, self._noop) return rule.to_python(value, absent=absent) def to_proto(self, proto_type, value, *, strict: bool = False): @@ -172,7 +170,7 @@ def to_proto(self, proto_type, value, *, strict: bool = False): for k, v in value.items()} # Convert ordinary values. - rule = self._registry.get(proto_type, self._noop) + rule = self._rules.get(proto_type, self._noop) pb_value = rule.to_proto(value) # Sanity check: If we are in strict mode, did we get the value we want? @@ -189,8 +187,42 @@ def to_proto(self, proto_type, value, *, strict: bool = False): return pb_value -class NoopMarshal: - """A catch-all marshal that does nothing.""" +class Marshal(BaseMarshal): + """The translator between protocol buffer and Python instances. + + The bulk of the implementation is in :class:`BaseMarshal`. This class + adds identity tracking: multiple instantiations of :class:`Marshal` with + the same name will provide the same instance. + """ + _instances = {} + + def __new__(cls, *, name: str): + """Create a marshal instance. + + Args: + name (str): The name of the marshal. Instantiating multiple + marshals with the same ``name`` argument will provide the + same marshal each time. + """ + if name not in cls._instances: + cls._instances[name] = super().__new__(cls) + return cls._instances[name] + + def __init__(self, *, name: str): + """Instantiate a marshal. + + Args: + name (str): The name of the marshal. Instantiating multiple + marshals with the same ``name`` argument will provide the + same marshal each time. + """ + self._name = name + if not hasattr(self, '_rules'): + super().__init__() + + +class NoopRule: + """A catch-all rule that does nothing.""" def to_python(self, pb_value, *, absent: bool = None): return pb_value @@ -199,8 +231,6 @@ def to_proto(self, value): return value -marshal = MarshalRegistry() - __all__ = ( - 'marshal', + 'Marshal', ) diff --git a/proto/marshal/types/__init__.py b/proto/marshal/rules/__init__.py similarity index 100% rename from proto/marshal/types/__init__.py rename to proto/marshal/rules/__init__.py diff --git a/proto/marshal/types/dates.py b/proto/marshal/rules/dates.py similarity index 98% rename from proto/marshal/types/dates.py rename to proto/marshal/rules/dates.py index 2f54d274..b365fdf2 100644 --- a/proto/marshal/types/dates.py +++ b/proto/marshal/rules/dates.py @@ -20,7 +20,7 @@ from google.protobuf import timestamp_pb2 -class TimestampMarshal: +class TimestampRule: """A marshal between Python datetimes and protobuf timestamps. Note: Python datetimes are less precise than protobuf datetimes @@ -47,7 +47,7 @@ def to_proto(self, value) -> timestamp_pb2.Timestamp: return value -class DurationMarshal: +class DurationRule: """A marshal between Python timedeltas and protobuf durations. Note: Python timedeltas are less precise than protobuf durations diff --git a/proto/marshal/types/message.py b/proto/marshal/rules/message.py similarity index 98% rename from proto/marshal/types/message.py rename to proto/marshal/rules/message.py index c74d496d..e5ecf17b 100644 --- a/proto/marshal/types/message.py +++ b/proto/marshal/rules/message.py @@ -13,7 +13,7 @@ # limitations under the License. -class MessageMarshal: +class MessageRule: """A marshal for converting between a descriptor and proto.Message.""" def __init__(self, descriptor: type, wrapper: type): diff --git a/proto/marshal/types/wrappers.py b/proto/marshal/rules/wrappers.py similarity index 83% rename from proto/marshal/types/wrappers.py rename to proto/marshal/rules/wrappers.py index 5d2d3658..bfd4b78e 100644 --- a/proto/marshal/types/wrappers.py +++ b/proto/marshal/rules/wrappers.py @@ -15,7 +15,7 @@ from google.protobuf import wrappers_pb2 -class WrapperMarshal: +class WrapperRule: """A marshal for converting the protobuf wrapper classes to Python. This class converts between ``google.protobuf.BoolValue``, @@ -38,46 +38,46 @@ def to_proto(self, value): return value -class DoubleValueMarshal(WrapperMarshal): +class DoubleValueRule(WrapperRule): _proto_type = wrappers_pb2.DoubleValue _python_type = float -class FloatValueMarshal(WrapperMarshal): +class FloatValueRule(WrapperRule): _proto_type = wrappers_pb2.FloatValue _python_type = float -class Int64ValueMarshal(WrapperMarshal): +class Int64ValueRule(WrapperRule): _proto_type = wrappers_pb2.Int64Value _python_type = int -class UInt64ValueMarshal(WrapperMarshal): +class UInt64ValueRule(WrapperRule): _proto_type = wrappers_pb2.UInt64Value _python_type = int -class Int32ValueMarshal(WrapperMarshal): +class Int32ValueRule(WrapperRule): _proto_type = wrappers_pb2.Int32Value _python_type = int -class UInt32ValueMarshal(WrapperMarshal): +class UInt32ValueRule(WrapperRule): _proto_type = wrappers_pb2.UInt32Value _python_type = int -class BoolValueMarshal(WrapperMarshal): +class BoolValueRule(WrapperRule): _proto_type = wrappers_pb2.BoolValue _python_type = bool -class StringValueMarshal(WrapperMarshal): +class StringValueRule(WrapperRule): _proto_type = wrappers_pb2.StringValue _python_type = str -class BytesValueMarshal(WrapperMarshal): +class BytesValueRule(WrapperRule): _proto_type = wrappers_pb2.BytesValue _python_type = bytes diff --git a/proto/message.py b/proto/message.py index 03e49d39..e755b585 100644 --- a/proto/message.py +++ b/proto/message.py @@ -29,8 +29,8 @@ from proto.fields import Field from proto.fields import MapField from proto.fields import RepeatedField -from proto.marshal.marshal import marshal -from proto.marshal.types.message import MessageMarshal +from proto.marshal import Marshal +from proto.marshal.rules.message import MessageRule from proto.primitives import ProtoType @@ -47,6 +47,7 @@ def __new__(mcls, name, bases, attrs): # A package and full name should be present. package = getattr(Meta, 'package', '') + marshal = Marshal(name=getattr(Meta, 'marshal', package)) local_path = tuple(attrs.get('__qualname__', name).split('.')) # Sanity check: We get the wrong full name if a class is declared @@ -225,6 +226,7 @@ def __new__(mcls, name, bases, attrs): attrs['_meta'] = _MessageInfo( fields=fields, full_name=full_name, + marshal=marshal, options=opts, package=package, ) @@ -345,7 +347,7 @@ def __init__(self, mapping=None, **kwargs): # passed in. # # The `__wrap_original` argument is private API to override - # this behavior, because `MessageMarshal` actually does want to + # this behavior, because `MessageRule` actually does want to # wrap the original argument it was given. The `wrap` method # on the metaclass is the public API for this behavior. if not kwargs.pop('__wrap_original', False): @@ -364,6 +366,7 @@ def __init__(self, mapping=None, **kwargs): # Update the mapping to address any values that need to be # coerced. + marshal = self._meta.marshal for key, value in copy.copy(mapping).items(): pb_type = self._meta.fields[key].pb_type pb_value = marshal.to_proto(pb_type, value) @@ -464,6 +467,7 @@ def __getattr__(self, key): """ pb_type = self._meta.fields[key].pb_type pb_value = getattr(self._pb, key) + marshal = self._meta.marshal return marshal.to_python(pb_type, pb_value, absent=key not in self) def __ne__(self, other): @@ -481,6 +485,7 @@ def __setattr__(self, key, value): """ if key.startswith('_'): return super().__setattr__(key, value) + marshal = self._meta.marshal pb_type = self._meta.fields[key].pb_type pb_value = marshal.to_proto(pb_type, value) @@ -503,11 +508,15 @@ class _MessageInfo: full_name (str): The full name of the message. file_info (~._FileInfo): The file descriptor and messages for the file containing this message. + marshal (~.Marshal): The marshal instance to which this message was + automatically registered. options (~.descriptor_pb2.MessageOptions): Any options that were set on the message. """ - def __init__(self, *, fields: List[Field], package: str, full_name: str, - options: descriptor_pb2.MessageOptions) -> None: + def __init__(self, *, + fields: List[Field], package: str, full_name: str, + marshal: Marshal, options: descriptor_pb2.MessageOptions + ) -> None: self.package = package self.full_name = full_name self.options = options @@ -517,6 +526,7 @@ def __init__(self, *, fields: List[Field], package: str, full_name: str, self.fields_by_number = collections.OrderedDict([ (i.number, i) for i in fields ]) + self.marshal = marshal self._pb = None @property @@ -573,9 +583,9 @@ def generate_file_pb(self): # Register the message with the marshal so it is wrapped # appropriately. proto_plus_message._meta._pb = pb_message - marshal.register( + proto_plus_message._meta.marshal.register( pb_message, - MessageMarshal(pb_message, proto_plus_message) + MessageRule(pb_message, proto_plus_message) ) # Iterate over any fields on the message and, if their type diff --git a/tests/conftest.py b/tests/conftest.py index f4bf1084..9e8a59c3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,8 +20,8 @@ from google.protobuf import reflection from google.protobuf import symbol_database -import proto -from proto.marshal import types +from proto.marshal import Marshal +from proto.marshal import rules from proto.message import _FileInfo @@ -73,11 +73,11 @@ def pytest_runtest_setup(item): # If the marshal had previously registered the old message classes, # then reload the appropriate modules so the marshal is using the new ones. if 'wrappers_pb2' in reloaded: - imp.reload(types.wrappers) + imp.reload(rules.wrappers) if reloaded.intersection({'timestamp_pb2', 'duration_pb2'}): - imp.reload(types.dates) - proto.marshal.reset() + imp.reload(rules.dates) def pytest_runtest_teardown(item): + Marshal._instances.clear() [i.stop() for i in item._mocks] diff --git a/tests/test_marshal_register.py b/tests/test_marshal_register.py index cb36f3ee..673355e1 100644 --- a/tests/test_marshal_register.py +++ b/tests/test_marshal_register.py @@ -16,24 +16,27 @@ from google.protobuf import empty_pb2 -import proto +from proto.marshal.marshal import BaseMarshal def test_registration(): - @proto.marshal.register(empty_pb2.Empty) - class Marshal: + marshal = BaseMarshal() + + @marshal.register(empty_pb2.Empty) + class Rule: def to_proto(self, value): return value def to_python(self, value, *, absent=None): return value - assert isinstance(proto.marshal._registry[empty_pb2.Empty], Marshal) + assert isinstance(marshal._rules[empty_pb2.Empty], Rule) def test_invalid_target_registration(): + marshal = BaseMarshal() with pytest.raises(TypeError): - @proto.marshal.register(object) - class Marshal: + @marshal.register(object) + class Rule: def to_proto(self, value): return value @@ -42,12 +45,14 @@ def to_python(self, value, *, absent=None): def test_invalid_marshal_class(): + marshal = BaseMarshal() with pytest.raises(TypeError): - @proto.marshal.register(empty_pb2.Empty) + @marshal.register(empty_pb2.Empty) class Marshal: pass def test_invalid_marshal_rule(): + marshal = BaseMarshal() with pytest.raises(TypeError): - proto.marshal.register(empty_pb2.Empty, rule=object()) + marshal.register(empty_pb2.Empty, rule=object()) diff --git a/tests/test_marshal_types_dates.py b/tests/test_marshal_types_dates.py index 0a35ebf9..22264c54 100644 --- a/tests/test_marshal_types_dates.py +++ b/tests/test_marshal_types_dates.py @@ -20,6 +20,7 @@ from google.protobuf import timestamp_pb2 import proto +from proto.marshal.marshal import BaseMarshal def test_timestamp_read(): @@ -212,7 +213,7 @@ def test_timestamp_to_python_idempotent(): # # However, we test idempotency for consistency with `to_proto` and # general resiliency. - marshal = proto.marshal + marshal = BaseMarshal() py_value = datetime(2012, 4, 21, 15, tzinfo=timezone.utc) assert marshal.to_python(timestamp_pb2.Timestamp, py_value) is py_value @@ -223,6 +224,6 @@ def test_duration_to_python_idempotent(): # # However, we test idempotency for consistency with `to_proto` and # general resiliency. - marshal = proto.marshal + marshal = BaseMarshal() py_value = timedelta(seconds=240) assert marshal.to_python(duration_pb2.Duration, py_value) is py_value diff --git a/tests/test_marshal_types_message.py b/tests/test_marshal_types_message.py index d1a252aa..49d48e4d 100644 --- a/tests/test_marshal_types_message.py +++ b/tests/test_marshal_types_message.py @@ -13,17 +13,17 @@ # limitations under the License. import proto -from proto.marshal.types.message import MessageMarshal +from proto.marshal.rules.message import MessageRule def test_to_proto(): class Foo(proto.Message): bar = proto.Field(proto.INT32, number=1) - message_marshal = MessageMarshal(Foo.pb(), Foo) - foo_pb2_a = message_marshal.to_proto(Foo(bar=42)) - foo_pb2_b = message_marshal.to_proto(Foo.pb()(bar=42)) - foo_pb2_c = message_marshal.to_proto({'bar': 42}) + message_rule = MessageRule(Foo.pb(), Foo) + foo_pb2_a = message_rule.to_proto(Foo(bar=42)) + foo_pb2_b = message_rule.to_proto(Foo.pb()(bar=42)) + foo_pb2_c = message_rule.to_proto({'bar': 42}) assert foo_pb2_a == foo_pb2_b == foo_pb2_c @@ -31,7 +31,7 @@ def test_to_python(): class Foo(proto.Message): bar = proto.Field(proto.INT32, number=1) - message_marshal = MessageMarshal(Foo.pb(), Foo) - foo_a = message_marshal.to_python(Foo(bar=42)) - foo_b = message_marshal.to_python(Foo.pb()(bar=42)) + message_rule = MessageRule(Foo.pb(), Foo) + foo_a = message_rule.to_python(Foo(bar=42)) + foo_b = message_rule.to_python(Foo.pb()(bar=42)) assert foo_a == foo_b diff --git a/tests/test_marshal_types_wrappers_bool.py b/tests/test_marshal_types_wrappers_bool.py index ad3d1cf2..0a25a48c 100644 --- a/tests/test_marshal_types_wrappers_bool.py +++ b/tests/test_marshal_types_wrappers_bool.py @@ -15,6 +15,7 @@ from google.protobuf import wrappers_pb2 import proto +from proto.marshal.marshal import BaseMarshal def test_bool_value_init(): @@ -103,7 +104,7 @@ def test_bool_value_to_python(): # # However, we test idempotency for consistency with `to_proto` and # general resiliency. - marshal = proto.marshal + marshal = BaseMarshal() assert marshal.to_python(wrappers_pb2.BoolValue, True) is True assert marshal.to_python(wrappers_pb2.BoolValue, False) is False assert marshal.to_python(wrappers_pb2.BoolValue, None) is None