From f1ed5b2e3e41f09245d1991b5d479555750c7b2b Mon Sep 17 00:00:00 2001 From: Steve Lorello <42971704+slorello89@users.noreply.github.com> Date: Thu, 2 May 2024 10:34:17 -0400 Subject: [PATCH] Refactoring Redis OM Python to use pydantic 2.0 types and validators (#603) * refactoring model to use pydantic 2.0 types and validators * adding tests for #591 * adding tests with uuid * readme fixes * fixing typo in NOT_IN --- README.md | 18 +- aredis_om/_compat.py | 90 +++++++++- aredis_om/model/encoders.py | 2 +- aredis_om/model/model.py | 293 ++++++++++++++++++++++++------- pyproject.toml | 6 +- tests/_compat.py | 7 +- tests/test_hash_model.py | 43 ++++- tests/test_json_model.py | 102 +++++++++-- tests/test_oss_redis_features.py | 4 +- 9 files changed, 468 insertions(+), 97 deletions(-) diff --git a/README.md b/README.md index 568b937e..c8456a83 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ Check out this example of modeling customer data with Redis OM. First, we create import datetime from typing import Optional -from pydantic.v1 import EmailStr +from pydantic import EmailStr from redis_om import HashModel @@ -104,7 +104,7 @@ class Customer(HashModel): email: EmailStr join_date: datetime.date age: int - bio: Optional[str] + bio: Optional[str] = None ``` Now that we have a `Customer` model, let's use it to save customer data to Redis. @@ -113,7 +113,7 @@ Now that we have a `Customer` model, let's use it to save customer data to Redis import datetime from typing import Optional -from pydantic.v1 import EmailStr +from pydantic import EmailStr from redis_om import HashModel @@ -124,7 +124,7 @@ class Customer(HashModel): email: EmailStr join_date: datetime.date age: int - bio: Optional[str] + bio: Optional[str] = None # First, we create a new `Customer` object: @@ -168,7 +168,7 @@ For example, because we used the `EmailStr` type for the `email` field, we'll ge import datetime from typing import Optional -from pydantic.v1 import EmailStr, ValidationError +from pydantic import EmailStr, ValidationError from redis_om import HashModel @@ -179,7 +179,7 @@ class Customer(HashModel): email: EmailStr join_date: datetime.date age: int - bio: Optional[str] + bio: Optional[str] = None try: @@ -222,7 +222,7 @@ To show how this works, we'll make a small change to the `Customer` model we def import datetime from typing import Optional -from pydantic.v1 import EmailStr +from pydantic import EmailStr from redis_om import ( Field, @@ -237,7 +237,7 @@ class Customer(HashModel): email: EmailStr join_date: datetime.date age: int = Field(index=True) - bio: Optional[str] + bio: Optional[str] = None # Now, if we use this model with a Redis deployment that has the @@ -287,7 +287,7 @@ from redis_om import ( class Address(EmbeddedJsonModel): address_line_1: str - address_line_2: Optional[str] + address_line_2: Optional[str] = None city: str = Field(index=True) state: str = Field(index=True) country: str diff --git a/aredis_om/_compat.py b/aredis_om/_compat.py index 0246e4f8..07dc2824 100644 --- a/aredis_om/_compat.py +++ b/aredis_om/_compat.py @@ -1,15 +1,92 @@ +from dataclasses import dataclass, is_dataclass +from typing import ( + Any, + Callable, + Deque, + Dict, + FrozenSet, + List, + Mapping, + Sequence, + Set, + Tuple, + Type, + Union, +) + from pydantic.version import VERSION as PYDANTIC_VERSION +from typing_extensions import Annotated, Literal, get_args, get_origin PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") if PYDANTIC_V2: - from pydantic.v1 import BaseModel, validator - from pydantic.v1.fields import FieldInfo, ModelField, Undefined, UndefinedType - from pydantic.v1.json import ENCODERS_BY_TYPE - from pydantic.v1.main import ModelMetaclass, validate_model + + def use_pydantic_2_plus(): + return True + + from pydantic import BaseModel, TypeAdapter + from pydantic import ValidationError as ValidationError + from pydantic import validator + from pydantic._internal._model_construction import ModelMetaclass + from pydantic._internal._repr import Representation + from pydantic.deprecated.json import ENCODERS_BY_TYPE + from pydantic.fields import FieldInfo + from pydantic.v1.main import validate_model from pydantic.v1.typing import NoArgAnyCallable - from pydantic.v1.utils import Representation + from pydantic_core import PydanticUndefined as Undefined + from pydantic_core import PydanticUndefinedType as UndefinedType + + @dataclass + class ModelField: + field_info: FieldInfo + name: str + mode: Literal["validation", "serialization"] = "validation" + + @property + def alias(self) -> str: + a = self.field_info.alias + return a if a is not None else self.name + + @property + def required(self) -> bool: + return self.field_info.is_required() + + @property + def default(self) -> Any: + return self.get_default() + + @property + def type_(self) -> Any: + return self.field_info.annotation + + def __post_init__(self) -> None: + self._type_adapter: TypeAdapter[Any] = TypeAdapter( + Annotated[self.field_info.annotation, self.field_info] + ) + + def get_default(self) -> Any: + if self.field_info.is_required(): + return Undefined + return self.field_info.get_default(call_default_factory=True) + + def validate( + self, + value: Any, + values: Dict[str, Any] = {}, # noqa: B006 + *, + loc: Tuple[Union[int, str], ...] = (), + ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: + return ( + self._type_adapter.validate_python(value, from_attributes=True), + None, + ) + + def __hash__(self) -> int: + # Each ModelField is unique for our purposes, to allow making a dict from + # ModelField to its JSON Schema. + return id(self) + else: from pydantic import BaseModel, validator from pydantic.fields import FieldInfo, ModelField, Undefined, UndefinedType @@ -17,3 +94,6 @@ from pydantic.main import ModelMetaclass, validate_model from pydantic.typing import NoArgAnyCallable from pydantic.utils import Representation + + def use_pydantic_2_plus(): + return False diff --git a/aredis_om/model/encoders.py b/aredis_om/model/encoders.py index 2f90e481..f097a35d 100644 --- a/aredis_om/model/encoders.py +++ b/aredis_om/model/encoders.py @@ -68,7 +68,7 @@ def jsonable_encoder( if exclude is not None and not isinstance(exclude, (set, dict)): exclude = set(exclude) - if isinstance(obj, BaseModel): + if isinstance(obj, BaseModel) and hasattr(obj, "__config__"): encoder = getattr(obj.__config__, "json_encoders", {}) if custom_encoder: encoder.update(custom_encoder) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index a90b3971..31c42bdb 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -21,8 +21,8 @@ Type, TypeVar, Union, - no_type_check, ) +from typing import get_args as typing_get_args, no_type_check from more_itertools import ichunked from redis.commands.json.path import Path @@ -75,6 +75,17 @@ ERRORS_URL = "https://github.com/redis/redis-om-python/blob/main/docs/errors.md" +def get_outer_type(field): + if hasattr(field, "outer_type_"): + return field.outer_type_ + elif isinstance(field.annotation, type) or is_supported_container_type( + field.annotation + ): + return field.annotation + else: + return field.annotation.__args__[0] + + class RedisModelError(Exception): """Raised when a problem exists in the definition of a RedisModel.""" @@ -106,7 +117,9 @@ def __str__(self): return str(self.name) -ExpressionOrModelField = Union["Expression", "NegatedExpression", ModelField] +ExpressionOrModelField = Union[ + "Expression", "NegatedExpression", ModelField, PydanticFieldInfo +] def embedded(cls): @@ -130,6 +143,9 @@ def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any if "__" in field_name: obj = model for sub_field in field_name.split("__"): + if not isinstance(obj, ModelMeta) and hasattr(obj, "field"): + obj = getattr(obj, "field").annotation + if not hasattr(obj, sub_field): raise QuerySyntaxError( f"The update path {field_name} contains a field that does not " @@ -138,7 +154,7 @@ def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any obj = getattr(obj, sub_field) return - if field_name not in model.__fields__: + if field_name not in model.__fields__: # type: ignore raise QuerySyntaxError( f"The field {field_name} does not exist on the model {model.__name__}" ) @@ -331,8 +347,11 @@ def __rshift__(self, other: Any) -> Expression: ) def __getattr__(self, item): - if is_supported_container_type(self.field.outer_type_): - embedded_cls = get_args(self.field.outer_type_) + if item.startswith("__"): + raise AttributeError("cannot invoke __getattr__ with reserved field") + outer_type = outer_type_or_annotation(self.field) + if is_supported_container_type(outer_type): + embedded_cls = get_args(outer_type) if not embedded_cls: raise QuerySyntaxError( "In order to query on a list field, you must define " @@ -342,9 +361,9 @@ def __getattr__(self, item): embedded_cls = embedded_cls[0] attr = getattr(embedded_cls, item) else: - attr = getattr(self.field.outer_type_, item) + attr = getattr(outer_type, item) if isinstance(attr, self.__class__): - new_parent = (self.field.name, self.field.outer_type_) + new_parent = (self.field.alias, outer_type) if new_parent not in attr.parents: attr.parents.append(new_parent) new_parents = list(set(self.parents) - set(attr.parents)) @@ -474,13 +493,20 @@ def validate_sort_fields(self, sort_fields: List[str]): field_name = sort_field.lstrip("-") if self.knn and field_name == self.knn.score_field: continue - if field_name not in self.model.__fields__: + if field_name not in self.model.__fields__: # type: ignore raise QueryNotSupportedError( f"You tried sort by {field_name}, but that field " f"does not exist on the model {self.model}" ) field_proxy = getattr(self.model, field_name) - if not getattr(field_proxy.field.field_info, "sortable", False): + if isinstance(field_proxy.field, FieldInfo) or isinstance( + field_proxy.field, PydanticFieldInfo + ): + field_info = field_proxy.field + else: + field_info = field_proxy.field.field_info + + if not getattr(field_info, "sortable", False): raise QueryNotSupportedError( f"You tried sort by {field_name}, but {self.model} does " f"not define that field as sortable. Docs: {ERRORS_URL}#E2" @@ -488,20 +514,31 @@ def validate_sort_fields(self, sort_fields: List[str]): return sort_fields @staticmethod - def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes: - if getattr(field.field_info, "primary_key", None) is True: + def resolve_field_type( + field: Union[ModelField, PydanticFieldInfo], op: Operators + ) -> RediSearchFieldTypes: + field_info: Union[FieldInfo, ModelField, PydanticFieldInfo] + + if not hasattr(field, "field_info"): + field_info = field + else: + field_info = field.field_info + if getattr(field_info, "primary_key", None) is True: return RediSearchFieldTypes.TAG elif op is Operators.LIKE: - fts = getattr(field.field_info, "full_text_search", None) + fts = getattr(field_info, "full_text_search", None) if fts is not True: # Could be PydanticUndefined raise QuerySyntaxError( - f"You tried to do a full-text search on the field '{field.name}', " + f"You tried to do a full-text search on the field '{field.alias}', " f"but the field is not indexed for full-text search. Use the " f"full_text_search=True option. Docs: {ERRORS_URL}#E3" ) return RediSearchFieldTypes.TEXT - field_type = field.outer_type_ + field_type = outer_type_or_annotation(field) + + if not isinstance(field_type, type): + field_type = field_type.__origin__ # TODO: GEO fields container_type = get_origin(field_type) @@ -651,7 +688,7 @@ def resolve_value( elif op is Operators.NOT_IN: # TODO: Implement NOT_IN, test this... expanded_value = cls.expand_tag_value(value) - result += "-(@{field_name}):{{{expanded_value}}}".format( + result += "-(@{field_name}:{{{expanded_value}}})".format( field_name=field_name, expanded_value=expanded_value ) @@ -729,6 +766,15 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: f"You tried to query by a field ({field_name}) " f"that isn't indexed. Docs: {ERRORS_URL}#E6" ) + elif isinstance(expression.left, FieldInfo): + field_type = cls.resolve_field_type(expression.left, expression.op) + field_name = expression.left.alias + field_info = expression.left + if not field_info or not getattr(field_info, "index", None): + raise QueryNotSupportedError( + f"You tried to query by a field ({field_name}) " + f"that isn't indexed. Docs: {ERRORS_URL}#E6" + ) else: raise QueryNotSupportedError( "A query expression should start with either a field " @@ -1100,27 +1146,27 @@ def Field( default: Any = Undefined, *, default_factory: Optional[NoArgAnyCallable] = None, - alias: str = None, - title: str = None, - description: str = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, exclude: Union[ AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any ] = None, include: Union[ AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any ] = None, - const: bool = None, - gt: float = None, - ge: float = None, - lt: float = None, - le: float = None, - multiple_of: float = None, - min_items: int = None, - max_items: int = None, - min_length: int = None, - max_length: int = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, allow_mutation: bool = True, - regex: str = None, + regex: Optional[str] = None, primary_key: bool = False, sortable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, @@ -1156,7 +1202,6 @@ def Field( vector_options=vector_options, **current_schema_extra, ) - field_info._validate() return field_info @@ -1232,6 +1277,18 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 # Create proxies for each model field so that we can use the field # in queries, like Model.get(Model.field_name == 1) for field_name, field in new_class.__fields__.items(): + if not isinstance(field, FieldInfo): + for base_candidate in bases: + if hasattr(base_candidate, field_name): + inner_field = getattr(base_candidate, field_name) + if hasattr(inner_field, "field") and isinstance( + getattr(inner_field, "field"), FieldInfo + ): + field.metadata.append(getattr(inner_field, "field")) + field = getattr(inner_field, "field") + + if not field.alias: + field.alias = field_name setattr(new_class, field_name, ExpressionProxy(field, [])) annotation = new_class.get_annotations().get(field_name) if annotation: @@ -1241,12 +1298,21 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 else: new_class.__annotations__[field_name] = ExpressionProxy # Check if this is our FieldInfo version with extended ORM metadata. - if isinstance(field.field_info, FieldInfo): - if field.field_info.primary_key: + field_info = None + if hasattr(field, "field_info") and isinstance(field.field_info, FieldInfo): + field_info = field.field_info + elif field_name in attrs and isinstance( + attrs.__getitem__(field_name), FieldInfo + ): + field_info = attrs.__getitem__(field_name) + field.field_info = field_info + + if field_info is not None: + if field_info.primary_key: new_class._meta.primary_key = PrimaryKey( name=field_name, field=field ) - if field.field_info.vector_options: + if field_info.vector_options: score_attr = f"_{field_name}_score" setattr(new_class, score_attr, None) new_class.__annotations__[score_attr] = Union[float, None] @@ -1290,6 +1356,17 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 return new_class +def outer_type_or_annotation(field): + if hasattr(field, "outer_type_"): + return field.outer_type_ + elif not hasattr(field.annotation, "__args__"): + if not isinstance(field.annotation, type): + raise AttributeError(f"could not extract outer type from field {field}") + return field.annotation + else: + return field.annotation.__args__[0] + + class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): pk: Optional[str] = Field(default=None, primary_key=True) @@ -1310,7 +1387,10 @@ def __lt__(self, other): def key(self): """Return the Redis key for this model.""" - pk = getattr(self, self._meta.primary_key.field.name) + if hasattr(self._meta.primary_key.field, "name"): + pk = getattr(self, self._meta.primary_key.field.name) + else: + pk = getattr(self, self._meta.primary_key.name) return self.make_primary_key(pk) @classmethod @@ -1349,7 +1429,7 @@ async def expire( @validator("pk", always=True, allow_reuse=True) def validate_pk(cls, v): - if not v: + if not v or isinstance(v, ExpressionProxy): v = cls._meta.primary_key_creator_cls().create_pk() return v @@ -1358,7 +1438,20 @@ def validate_primary_key(cls): """Check for a primary key. We need one (and only one).""" primary_keys = 0 for name, field in cls.__fields__.items(): - if getattr(field.field_info, "primary_key", None): + if not hasattr(field, "field_info"): + if ( + not isinstance(field, FieldInfo) + and hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): + field_info = field.metadata[0] + else: + field_info = field + else: + field_info = field.field_info + + if getattr(field_info, "primary_key", None): primary_keys += 1 if primary_keys == 0: raise RedisModelError("You must define a primary key for the model") @@ -1490,17 +1583,42 @@ def redisearch_schema(cls): def check(self): """Run all validations.""" - *_, validation_error = validate_model(self.__class__, self.__dict__) - if validation_error: - raise validation_error + from pydantic.version import VERSION as PYDANTIC_VERSION + + PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") + if not PYDANTIC_V2: + *_, validation_error = validate_model(self.__class__, self.__dict__) + if validation_error: + raise validation_error class HashModel(RedisModel, abc.ABC): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) + if hasattr(cls, "__annotations__"): + for name, field_type in cls.__annotations__.items(): + origin = get_origin(field_type) + for typ in (Set, Mapping, List): + if isinstance(origin, type) and issubclass(origin, typ): + raise RedisModelError( + f"HashModels cannot index set, list," + f" or mapping fields. Field: {name}" + ) + if isinstance(field_type, type) and issubclass(field_type, RedisModel): + raise RedisModelError( + f"HashModels cannot index embedded model fields. Field: {name}" + ) + elif isinstance(field_type, type) and dataclasses.is_dataclass( + field_type + ): + raise RedisModelError( + f"HashModels cannot index dataclass fields. Field: {name}" + ) + for name, field in cls.__fields__.items(): - origin = get_origin(field.outer_type_) + outer_type = outer_type_or_annotation(field) + origin = get_origin(outer_type) if origin: for typ in (Set, Mapping, List): if issubclass(origin, typ): @@ -1509,11 +1627,11 @@ def __init_subclass__(cls, **kwargs): f" or mapping fields. Field: {name}" ) - if issubclass(field.outer_type_, RedisModel): + if issubclass(outer_type, RedisModel): raise RedisModelError( f"HashModels cannot index embedded model fields. Field: {name}" ) - elif dataclasses.is_dataclass(field.outer_type_): + elif dataclasses.is_dataclass(outer_type): raise RedisModelError( f"HashModels cannot index dataclass fields. Field: {name}" ) @@ -1523,7 +1641,6 @@ async def save( ) -> "Model": self.check() db = self._get_db(pipeline) - document = jsonable_encoder(self.dict()) # TODO: Wrap any Redis response errors in a custom exception? await db.hset(self.key(), mapping=document) @@ -1594,21 +1711,32 @@ def schema_for_fields(cls): for name, field in cls.__fields__.items(): # TODO: Merge this code with schema_for_type()? - _type = field.outer_type_ + _type = outer_type_or_annotation(field) is_subscripted_type = get_origin(_type) - if getattr(field.field_info, "primary_key", None): + if ( + not isinstance(field, FieldInfo) + and hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): + field = field.metadata[0] + + if not hasattr(field, "field_info"): + field_info = field + else: + field_info = field.field_info + + if getattr(field_info, "primary_key", None): if issubclass(_type, str): redisearch_field = ( f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" ) else: - redisearch_field = cls.schema_for_type( - name, _type, field.field_info - ) + redisearch_field = cls.schema_for_type(name, _type, field_info) schema_parts.append(redisearch_field) - elif getattr(field.field_info, "index", None) is True: - schema_parts.append(cls.schema_for_type(name, _type, field.field_info)) + elif getattr(field_info, "index", None) is True: + schema_parts.append(cls.schema_for_type(name, _type, field_info)) elif is_subscripted_type: # Ignore subscripted types (usually containers!) that we don't # support, for the purposes of indexing. @@ -1621,11 +1749,9 @@ def schema_for_fields(cls): log.warning("Model %s defined an empty list field: %s", cls, name) continue embedded_cls = embedded_cls[0] - schema_parts.append( - cls.schema_for_type(name, embedded_cls, field.field_info) - ) + schema_parts.append(cls.schema_for_type(name, embedded_cls, field_info)) elif issubclass(_type, RedisModel): - schema_parts.append(cls.schema_for_type(name, _type, field.field_info)) + schema_parts.append(cls.schema_for_type(name, _type, field_info)) return schema_parts @classmethod @@ -1760,11 +1886,42 @@ def redisearch_schema(cls): def schema_for_fields(cls): schema_parts = [] json_path = "$" - + fields = dict() for name, field in cls.__fields__.items(): - _type = field.outer_type_ + fields[name] = field + for name, field in cls.__dict__.items(): + if isinstance(field, FieldInfo): + if not field.annotation: + field.annotation = cls.__annotations__.get(name) + fields[name] = field + for name, field in cls.__annotations__.items(): + if name in fields: + continue + fields[name] = PydanticFieldInfo.from_annotation(field) + + for name, field in fields.items(): + _type = get_outer_type(field) + if ( + not isinstance(field, FieldInfo) + and hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): + field = field.metadata[0] + + if hasattr(field, "field_info"): + field_info = field.field_info + else: + field_info = field + if getattr(field_info, "primary_key", None): + if issubclass(_type, str): + redisearch_field = f"$.{name} AS {name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" + else: + redisearch_field = cls.schema_for_type(name, _type, field_info) + schema_parts.append(redisearch_field) + continue schema_parts.append( - cls.schema_for_type(json_path, name, "", _type, field.field_info) + cls.schema_for_type(json_path, name, "", _type, field_info) ) return schema_parts @@ -1843,6 +2000,17 @@ def schema_for_type( name_prefix = f"{name_prefix}_{name}" if name_prefix else name sub_fields = [] for embedded_name, field in typ.__fields__.items(): + if hasattr(field, "field_info"): + field_info = field.field_info + elif ( + hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): + field_info = field.metadata[0] + else: + field_info = field + if parent_is_container_type: # We'll store this value either as a JavaScript array, so # the correct JSONPath expression is to refer directly to @@ -1859,8 +2027,9 @@ def schema_for_type( path, embedded_name, name_prefix, - field.outer_type_, - field.field_info, + # field.annotation, + get_outer_type(field), + field_info, parent_type=typ, ) ) @@ -1884,6 +2053,12 @@ def schema_for_type( "See docs: TODO" ) + # For more complicated compound validators (e.g. PositiveInt), we might get a _GenericAlias rather than + # a proper type, we can pull the type information from the origin of the first argument. + if not isinstance(typ, type): + type_args = typing_get_args(field_info.annotation) + typ = type_args[0].__origin__ + # TODO: GEO field if is_vector and vector_options: schema = f"{path} AS {index_field_name} {vector_options.schema}" diff --git a/pyproject.toml b/pyproject.toml index c92dae15..b9ed08ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "redis-om" -version = "0.2.2" +version = "0.3.0" description = "Object mappings, and more, for Redis." authors = ["Redis OSS "] maintainers = ["Redis OSS "] @@ -37,7 +37,7 @@ include=[ [tool.poetry.dependencies] python = ">=3.8,<4.0" redis = ">=3.5.3,<6.0.0" -pydantic = ">=1.10.2,<2.5.0" +pydantic = ">=1.10.2,<3.0.0" click = "^8.0.1" types-redis = ">=3.5.9,<5.0.0" python-ulid = "^1.0.3" @@ -47,7 +47,7 @@ more-itertools = ">=8.14,<11.0" setuptools = {version = "^69.2.0", markers = "python_version >= '3.12'"} [tool.poetry.dev-dependencies] -mypy = "^0.982" +mypy = "^1.9.0" pytest = "^8.0.2" ipdb = "^0.13.9" black = "^24.2" diff --git a/tests/_compat.py b/tests/_compat.py index c21b47d2..1cd55bf2 100644 --- a/tests/_compat.py +++ b/tests/_compat.py @@ -1,7 +1,10 @@ -from aredis_om._compat import PYDANTIC_V2 +from aredis_om._compat import PYDANTIC_V2, use_pydantic_2_plus -if PYDANTIC_V2: +if not use_pydantic_2_plus() and PYDANTIC_V2: from pydantic.v1 import EmailStr, ValidationError +elif PYDANTIC_V2: + from pydantic import EmailStr, PositiveInt, ValidationError + else: from pydantic import EmailStr, ValidationError diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 38ca18e2..f7aee626 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -4,8 +4,9 @@ import dataclasses import datetime import decimal +import uuid from collections import namedtuple -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Union from unittest import mock import pytest @@ -388,7 +389,10 @@ def test_validates_required_fields(m): # Raises ValidationError: last_name is required # TODO: Test the error value with pytest.raises(ValidationError): - m.Member(id=0, first_name="Andrew", zipcode="97086", join_date=today) + try: + m.Member(id=0, first_name="Andrew", zipcode="97086", join_date=today) + except Exception as e: + raise e def test_validates_field(m): @@ -581,6 +585,7 @@ class Address(m.BaseHashModel): with pytest.raises(RedisModelError): class InvalidMember(m.BaseHashModel): + name: str = Field(index=True) address: Address @@ -728,7 +733,6 @@ class Address(m.BaseHashModel): # We need to build the key prefix because it will differ based on whether # these tests were copied into the tests_sync folder and unasynce'd. key_prefix = Address.make_key(Address._meta.primary_key_pattern.format(pk="")) - assert ( Address.redisearch_schema() == f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | a_string TAG SEPARATOR | a_full_text_string TAG SEPARATOR | a_full_text_string AS a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE a_float NUMERIC" @@ -804,3 +808,36 @@ async def test_count(members, m): m.Member.first_name == "Kim", m.Member.last_name == "Brookins" ).count() assert actual_count == 1 + + +@py_test_mark_asyncio +async def test_type_with_union(members, m): + class TypeWithUnion(m.BaseHashModel): + field: Union[str, int] + + twu_str = TypeWithUnion(field="hello world") + res = await twu_str.save() + assert res.pk == twu_str.pk + twu_str_rematerialized = await TypeWithUnion.get(twu_str.pk) + assert ( + isinstance(twu_str_rematerialized.field, str) + and twu_str_rematerialized.pk == twu_str.pk + ) + + twu_int = TypeWithUnion(field=42) + await twu_int.save() + twu_int_rematerialized = await TypeWithUnion.get(twu_int.pk) + + # Note - we will not be able to automatically serialize an int back to this union type, + # since as far as we know from Redis this item is a string + assert twu_int_rematerialized.pk == twu_int.pk + + +@py_test_mark_asyncio +async def test_type_with_uuid(): + class TypeWithUuid(HashModel): + uuid: uuid.UUID + + item = TypeWithUuid(uuid=uuid.uuid4()) + + await item.save() diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 8fec6c0a..55e8b0fa 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -4,8 +4,9 @@ import dataclasses import datetime import decimal +import uuid from collections import namedtuple -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Union from unittest import mock import pytest @@ -24,7 +25,7 @@ # We need to run this check as sync code (during tests) even in async mode # because we call it in the top-level module scope. from redis_om import has_redis_json -from tests._compat import ValidationError +from tests._compat import EmailStr, PositiveInt, ValidationError from .conftest import py_test_mark_asyncio @@ -50,12 +51,12 @@ class Note(EmbeddedJsonModel): class Address(EmbeddedJsonModel): address_line_1: str - address_line_2: Optional[str] + address_line_2: Optional[str] = None city: str = Field(index=True) state: str country: str postal_code: str = Field(index=True) - note: Optional[Note] + note: Optional[Note] = None class Item(EmbeddedJsonModel): price: decimal.Decimal @@ -68,16 +69,16 @@ class Order(EmbeddedJsonModel): class Member(BaseJsonModel): first_name: str = Field(index=True) last_name: str = Field(index=True) - email: str = Field(index=True) + email: Optional[EmailStr] = Field(index=True, default=None) join_date: datetime.date - age: int = Field(index=True) + age: Optional[PositiveInt] = Field(index=True, default=None) bio: Optional[str] = Field(index=True, full_text_search=True, default="") # Creates an embedded model. address: Address # Creates an embedded list of models. - orders: Optional[List[Order]] + orders: Optional[List[Order]] = None await Migrator().run() @@ -88,13 +89,16 @@ class Member(BaseJsonModel): @pytest.fixture() def address(m): - yield m.Address( - address_line_1="1 Main St.", - city="Portland", - state="OR", - country="USA", - postal_code=11111, - ) + try: + yield m.Address( + address_line_1="1 Main St.", + city="Portland", + state="OR", + country="USA", + postal_code="11111", + ) + except Exception as e: + raise e @pytest_asyncio.fixture() @@ -133,6 +137,34 @@ async def members(address, m): yield member1, member2, member3 +@py_test_mark_asyncio +async def test_validate_bad_email(address, m): + # Raises ValidationError as email is malformed + with pytest.raises(ValidationError): + m.Member( + first_name="Andrew", + last_name="Brookins", + zipcode="97086", + join_date=today, + email="foobarbaz", + ) + + +@py_test_mark_asyncio +async def test_validate_bad_age(address, m): + # Raises ValidationError as email is malformed + with pytest.raises(ValidationError): + m.Member( + first_name="Andrew", + last_name="Brookins", + zipcode="97086", + join_date=today, + email="foo@bar.com", + address=address, + age=-5, + ) + + @py_test_mark_asyncio async def test_validates_required_fields(address, m): # Raises ValidationError address is required @@ -422,6 +454,15 @@ async def test_in_query(members, m): ) assert actual == [member2, member1, member3] +@py_test_mark_asyncio +async def test_not_in_query(members, m): + member1, member2, member3 = members + actual = await ( + m.Member.find(m.Member.pk >> [member2.pk, member3.pk]) + .sort_by("age") + .all() + ) + assert actual == [ member1] @py_test_mark_asyncio async def test_update_query(members, m): @@ -849,3 +890,36 @@ async def test_count(members, m): m.Member.first_name == "Kim", m.Member.last_name == "Brookins" ).count() assert actual_count == 1 + + +@py_test_mark_asyncio +async def test_type_with_union(members, m): + class TypeWithUnion(m.BaseJsonModel): + field: Union[str, int] + + twu_str = TypeWithUnion(field="hello world") + res = await twu_str.save() + assert res.pk == twu_str.pk + twu_str_rematerialized = await TypeWithUnion.get(twu_str.pk) + assert ( + isinstance(twu_str_rematerialized.field, str) + and twu_str_rematerialized.pk == twu_str.pk + ) + + twu_int = TypeWithUnion(field=42) + await twu_int.save() + twu_int_rematerialized = await TypeWithUnion.get(twu_int.pk) + assert ( + isinstance(twu_int_rematerialized.field, int) + and twu_int_rematerialized.pk == twu_int.pk + ) + + +@py_test_mark_asyncio +async def test_type_with_uuid(): + class TypeWithUuid(JsonModel): + uuid: uuid.UUID + + item = TypeWithUuid(uuid=uuid.uuid4()) + + await item.save() diff --git a/tests/test_oss_redis_features.py b/tests/test_oss_redis_features.py index 4d5b0913..47ebe47f 100644 --- a/tests/test_oss_redis_features.py +++ b/tests/test_oss_redis_features.py @@ -166,7 +166,9 @@ async def test_saves_many(m): result = await m.Member.add(members) assert result == [member1, member2] - assert await m.Member.get(pk=member1.pk) == member1 + m1_rematerialized = await m.Member.get(pk=member1.pk) + + assert m1_rematerialized == member1 assert await m.Member.get(pk=member2.pk) == member2