diff --git a/README.md b/README.md index e1a6f2d..59cbf03 100644 --- a/README.md +++ b/README.md @@ -222,4 +222,36 @@ print(OrderUserSchema.from_orm(user).json(ident=4)) "id": 1, "email": "" } -``` \ No newline at end of file +``` + + +### Include from annotations + +By default, a Schema without Config.include or Config.exclude defined will include all fields of the Config.model class. + +If you want to limit included fields to the annotations of the Schema without defining Config.include, use `Config.include = "__annotations__"`. + + +```python +class ProfileSchema(ModelSchema): + website: str + + class Config: + model = Profile + include = "__annotations__" + + assert ProfileSchema.schema() == { + "title": "ProfileSchema", + "description": "A user's profile.", + "type": "object", + "properties": { + "website": { + "title": "Website", + "type": "string" + } + }, + "required": [ + "website" + ] + } +``` diff --git a/djantic/main.py b/djantic/main.py index 94b57ed..aaeeb60 100644 --- a/djantic/main.py +++ b/djantic/main.py @@ -27,6 +27,13 @@ def default(self, obj): # pragma: nocover return super().default(obj) +def get_field_name(field) -> str: + if issubclass(field.__class__, ForeignObjectRel) and not issubclass(field.__class__, OneToOneRel): + return getattr(field, "related_name", None) or f"{field.name}_set" + else: + return getattr(field, "name", field) + + class ModelSchemaMetaclass(ModelMetaclass): @no_type_check def __new__( @@ -49,7 +56,7 @@ def __new__( raise ConfigError( f"{exc} (Is `Config` class defined?)" ) - + include = getattr(config, "include", None) exclude = getattr(config, "exclude", None) @@ -68,17 +75,18 @@ def __new__( f"{exc} (Is `Config.model` a valid Django model class?)" ) - if include is None and exclude is None: - cls.__config__.include = [f.name for f in fields] + if include == '__annotations__': + include = list(annotations.keys()) + cls.__config__.include = include + elif include is None and exclude is None: + include = list(annotations.keys()) + [get_field_name(f) for f in fields] + cls.__config__.include = include field_values = {} _seen = set() for field in chain(fields, annotations.copy()): - if issubclass(field.__class__, ForeignObjectRel) and not issubclass(field.__class__, OneToOneRel): - field_name = getattr(field, "related_name", None) or f"{field.name}_set" - else: - field_name = getattr(field, "name", field) + field_name = get_field_name(field) if ( field_name in _seen @@ -114,6 +122,8 @@ def __new__( cls.__doc__ = namespace.get("__doc__", config.model.__doc__) cls.__fields__ = {} + cls.__alias_map__ = {getattr(model_field[1], 'alias', None) or field_name: field_name + for field_name, model_field in field_values.items()} model_schema = create_model( name, __base__=cls, __module__=cls.__module__, **field_values ) @@ -129,14 +139,14 @@ def __init__(self, obj: Any, schema_class): self.schema_class = schema_class def get(self, key: Any, default: Any = None) -> Any: + alias = self.schema_class.__alias_map__[key] + outer_type_ = self.schema_class.__fields__[alias].outer_type_ if "__" in key: # Allow double underscores aliases: `first_name: str = Field(alias="user__first_name")` keys_map = key.split("__") attr = reduce(lambda a, b: getattr(a, b, default), keys_map, self._obj) - outer_type_ = self.schema_class.__fields__["user"].outer_type_ else: - attr = getattr(self._obj, key) - outer_type_ = self.schema_class.__fields__[key].outer_type_ + attr = getattr(self._obj, key, None) is_manager = issubclass(attr.__class__, Manager) diff --git a/tests/test_multiple_level_relations.py b/tests/test_multiple_level_relations.py index d6c24cb..7b4240a 100644 --- a/tests/test_multiple_level_relations.py +++ b/tests/test_multiple_level_relations.py @@ -1,8 +1,9 @@ from decimal import Decimal -from typing import List +from typing import List, Optional import pytest +from pydantic import validator from testapp.order import Order, OrderItem, OrderItemDetail, OrderUser, OrderUserFactory, OrderUserProfile from djantic import ModelSchema @@ -34,9 +35,23 @@ class Config: class OrderUserSchema(ModelSchema): orders: List[OrderSchema] profile: OrderUserProfileSchema + user_cache: Optional[dict] class Config: model = OrderUser + include = ('id', + 'first_name', + 'last_name', + 'email', + 'profile', + 'orders', + 'user_cache') + + @validator('user_cache', pre=True, always=True) + def get_user_cache(cls, _): + return { + 'has_order': True + } user = OrderUserFactory.create() @@ -45,6 +60,7 @@ class Config: 'first_name': '', 'last_name': None, 'email': '', + 'user_cache': {'has_order': True}, 'profile': { 'id': 1, 'address': '', @@ -195,6 +211,10 @@ class Config: "description": "email", "maxLength": 254, "type": "string" + }, + "user_cache": { + "title": "User Cache", + "type": "object" } }, "required": [ diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 7218b52..eb8cd68 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -380,3 +380,32 @@ class Config: ] }""" assert ConfigurationSchema.schema_json(indent=2) == expected + + +@pytest.mark.django_db +def test_include_from_annotations(): + """ + Test include="__annotations__" config. + """ + + class ProfileSchema(ModelSchema): + website: str + + class Config: + model = Profile + include = "__annotations__" + + assert ProfileSchema.schema() == { + "title": "ProfileSchema", + "description": "A user's profile.", + "type": "object", + "properties": { + "website": { + "title": "Website", + "type": "string" + } + }, + "required": [ + "website" + ] + }