diff --git a/proto/__init__.py b/proto/__init__.py index 577eac40..0319199d 100644 --- a/proto/__init__.py +++ b/proto/__init__.py @@ -17,6 +17,7 @@ from .fields import RepeatedField from .marshal import Marshal from .message import Message +from .modules import define_module as module from .primitives import ProtoType @@ -45,6 +46,7 @@ 'RepeatedField', 'Marshal', 'Message', + 'module', # Expose the types directly. 'DOUBLE', diff --git a/proto/message.py b/proto/message.py index 815bf88c..5ac26ae3 100644 --- a/proto/message.py +++ b/proto/message.py @@ -17,6 +17,7 @@ import copy import inspect import re +import sys import uuid from typing import List, Type @@ -42,12 +43,16 @@ def __new__(mcls, name, bases, attrs): if not bases: return super().__new__(mcls, name, bases, attrs) + # Pull a reference to the module where this class is being + # declared. + module = sys.modules.get(attrs.get('__module__')) + # Pop metadata off the attrs. - Meta = attrs.pop('Meta', object()) + proto_module = getattr(module, '__protobuf__', object()) # A package and full name should be present. - package = getattr(Meta, 'package', '') - marshal = Marshal(name=getattr(Meta, 'marshal', package)) + package = getattr(proto_module, 'package', '') + marshal = Marshal(name=getattr(proto_module, 'marshal', package)) local_path = tuple(attrs.get('__qualname__', name).split('.')) # Sanity check: We get the wrong full name if a class is declared @@ -57,9 +62,7 @@ def __new__(mcls, name, bases, attrs): local_path = local_path[:ix - 1] + local_path[ix + 1:] # Determine the full name in protocol buffers. - full_name = getattr(Meta, 'full_name', - '.'.join((package,) + local_path).lstrip('.'), - ) + full_name = '.'.join((package,) + local_path).lstrip('.') # Special case: Maps. Map fields are special; they are essentially # shorthand for a nested message and a repeated field of that message. @@ -92,10 +95,7 @@ def __new__(mcls, name, bases, attrs): prefix=attrs.get('__qualname__', name), name=msg_name, ), - 'Meta': type('Meta', (object,), { - 'options': descriptor_pb2.MessageOptions(map_entry=True), - 'package': package, - }), + '_pb_options': {'map_entry': True}, }) entry_attrs['key'] = Field(field.map_key_type, number=1) entry_attrs['value'] = Field(field.proto_type, number=2, @@ -196,7 +196,7 @@ def __new__(mcls, name, bases, attrs): file_info.descriptor.dependency.append(proto_import) # Retrieve any message options. - opts = getattr(Meta, 'options', descriptor_pb2.MessageOptions()) + opts = descriptor_pb2.MessageOptions(**attrs.pop('_pb_options', {})) # Create the underlying proto descriptor. desc = descriptor_pb2.DescriptorProto( @@ -626,13 +626,17 @@ def ready(self, new_class): if field.message not in self.messages: return False - # If the module in which this class is defined provides an __all__, - # do not generate the file descriptor until every member of __all__ - # has been populated. + # If the module in which this class is defined provides a + # __protobuf__ property, it may have a manifest. + # + # Do not generate the file descriptor until every member of the + # manifest has been populated. module = inspect.getmodule(new_class) - manifest = set(getattr(module, '__all__', ())).difference( - {new_class.__name__}, - ) + manifest = frozenset() + if hasattr(module, '__protobuf__'): + manifest = module.__protobuf__.manifest.difference( + {new_class.__name__}, + ) if not all([hasattr(module, i) for i in manifest]): return False diff --git a/proto/modules.py b/proto/modules.py new file mode 100644 index 00000000..a5ba94ff --- /dev/null +++ b/proto/modules.py @@ -0,0 +1,50 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Set +import collections + + +_ProtoModule = collections.namedtuple('ProtoModule', + ['package', 'marshal', 'manifest'], +) + + +def define_module(*, package: str, marshal: str = None, + manifest: Set[str] = frozenset()) -> _ProtoModule: + """Define a protocol buffers module. + + The settings defined here are used for all protobuf messages + declared in the module of the given name. + + Args: + package (str): The proto package name. + marshal (str): The name of the marshal to use. It is recommended + to use one marshal per Python library (e.g. package on PyPI). + manifest (Tuple[str]): A tuple of classes to be created. Setting + this adds a slight efficiency in piecing together proto + descriptors under the hood. + """ + if not marshal: + marshal = package + return _ProtoModule( + package=package, + marshal=marshal, + manifest=frozenset(manifest), + ) + + +__all__ = ( + 'define_module', +) diff --git a/tests/test_fields_composite_string_ref.py b/tests/test_fields_composite_string_ref.py index 865bc0da..f8df8344 100644 --- a/tests/test_fields_composite_string_ref.py +++ b/tests/test_fields_composite_string_ref.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys + import proto @@ -31,24 +33,19 @@ class Foo(proto.Message): def test_composite_forward_ref_with_package(): - class Spam(proto.Message): - foo = proto.Field('Foo', number=1) - - class Meta: - package = 'abc.def' - - class Eggs(proto.Message): - foo = proto.Field('abc.def.Foo', number=1) - - class Meta: - package = 'abc.def' - - class Foo(proto.Message): - bar = proto.Field(proto.STRING, number=1) - baz = proto.Field(proto.INT64, number=2) - - class Meta: - package = 'abc.def' + sys.modules[__name__].__protobuf__ = proto.module(package='abc.def') + try: + class Spam(proto.Message): + foo = proto.Field('Foo', number=1) + + class Eggs(proto.Message): + foo = proto.Field('abc.def.Foo', number=1) + + class Foo(proto.Message): + bar = proto.Field(proto.STRING, number=1) + baz = proto.Field(proto.INT64, number=2) + finally: + del sys.modules[__name__].__protobuf__ spam = Spam(foo=Foo(bar='str', baz=42)) eggs = Eggs(foo=Foo(bar='rts', baz=24)) diff --git a/tests/test_message_all.py b/tests/test_modules.py similarity index 65% rename from tests/test_message_all.py rename to tests/test_modules.py index a2ee1798..0c78eab8 100644 --- a/tests/test_message_all.py +++ b/tests/test_modules.py @@ -12,16 +12,51 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect from unittest import mock +import inspect +import sys from google.protobuf import wrappers_pb2 import proto -def test_message_creation_all(): - __all__ = ('Foo', 'Bar', 'Baz') # noqa: F841 +def test_module_package(): + sys.modules[__name__].__protobuf__ = proto.module(package='spam.eggs.v1') + try: + class Foo(proto.Message): + bar = proto.Field(proto.INT32, number=1) + + marshal = proto.Marshal(name='spam.eggs.v1') + + assert Foo.meta.package == 'spam.eggs.v1' + assert Foo.pb() in marshal._rules + finally: + del sys.modules[__name__].__protobuf__ + + +def test_module_package_explicit_marshal(): + sys.modules[__name__].__protobuf__ = proto.module( + package='spam.eggs.v1', + marshal='foo', + ) + try: + class Foo(proto.Message): + bar = proto.Field(proto.INT32, number=1) + + marshal = proto.Marshal(name='foo') + + assert Foo.meta.package == 'spam.eggs.v1' + assert Foo.pb() in marshal._rules + finally: + del sys.modules[__name__].__protobuf__ + + +def test_module_manifest(): + __protobuf__ = proto.module( + manifest={'Foo', 'Bar', 'Baz'}, + package='spam.eggs.v1', + ) # We want to fake a module, but modules have attribute access, and # `frame.f_locals` is a dictionary. Since we only actually care about