diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index acfd956d..56c5c90d 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -22,7 +22,8 @@ TypeVar, Union, ) -from typing import get_args as typing_get_args, no_type_check +from typing import get_args as typing_get_args +from typing import no_type_check from more_itertools import ichunked from redis.commands.json.path import Path @@ -112,6 +113,9 @@ class Operators(Enum): NOT_IN = 11 LIKE = 12 ALL = 13 + STARTSWITH = 14 + ENDSWITH = 15 + CONTAINS = 16 def __str__(self): return str(self.name) @@ -346,6 +350,21 @@ def __rshift__(self, other: Any) -> Expression: left=self.field, op=Operators.NOT_IN, right=other, parents=self.parents ) + def startswith(self, other: Any) -> Expression: + return Expression( + left=self.field, op=Operators.STARTSWITH, right=other, parents=self.parents + ) + + def endswith(self, other: Any) -> Expression: + return Expression( + left=self.field, op=Operators.ENDSWITH, right=other, parents=self.parents + ) + + def contains(self, other: Any) -> Expression: + return Expression( + left=self.field, op=Operators.CONTAINS, right=other, parents=self.parents + ) + def __getattr__(self, item): if item.startswith("__"): raise AttributeError("cannot invoke __getattr__ with reserved field") @@ -691,6 +710,21 @@ def resolve_value( result += "-(@{field_name}:{{{expanded_value}}})".format( field_name=field_name, expanded_value=expanded_value ) + elif op is Operators.STARTSWITH: + expanded_value = cls.expand_tag_value(value) + result += "(@{field_name}:{{{expanded_value}*}})".format( + field_name=field_name, expanded_value=expanded_value + ) + elif op is Operators.ENDSWITH: + expanded_value = cls.expand_tag_value(value) + result += "(@{field_name}:{{*{expanded_value}}})".format( + field_name=field_name, expanded_value=expanded_value + ) + elif op is Operators.CONTAINS: + expanded_value = cls.expand_tag_value(value) + result += "(@{field_name}:{{*{expanded_value}*}})".format( + field_name=field_name, expanded_value=expanded_value + ) return result diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 05d98724..185005f6 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -852,3 +852,26 @@ class TypeWithUuid(HashModel): item = TypeWithUuid(uuid=uuid.uuid4()) await item.save() + + +@py_test_mark_asyncio +async def test_xfix_queries(members, m): + member1, member2, member3 = members + + result = await m.Member.find(m.Member.first_name.startswith("And")).first() + assert result.first_name == "Andrew" + + result = await m.Member.find(m.Member.last_name.endswith("ins")).first() + assert result.first_name == "Andrew" + + result = await m.Member.find(m.Member.last_name.contains("ook")).first() + assert result.first_name == "Andrew" + + result = await m.Member.find(m.Member.bio % "great*").first() + assert result.first_name == "Andrew" + + result = await m.Member.find(m.Member.bio % "*rty").first() + assert result.first_name == "Andrew" + + result = await m.Member.find(m.Member.bio % "*eat*").first() + assert result.first_name == "Andrew" diff --git a/tests/test_json_model.py b/tests/test_json_model.py index e7c6ca61..d5744858 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -934,3 +934,40 @@ class TypeWithUuid(JsonModel): item = TypeWithUuid(uuid=uuid.uuid4()) await item.save() + + +@py_test_mark_asyncio +async def test_xfix_queries(m): + await m.Member( + first_name="Steve", + last_name="Lorello", + email="s@example.com", + join_date=today, + bio="Steve is a two-bit hacker who loves Redis.", + address=m.Address( + address_line_1="42 foo bar lane", + city="Satellite Beach", + state="FL", + country="USA", + postal_code="32999", + ), + age=34, + ).save() + + result = await m.Member.find(m.Member.first_name.startswith("Ste")).first() + assert result.first_name == "Steve" + + result = await m.Member.find(m.Member.last_name.endswith("llo")).first() + assert result.first_name == "Steve" + + result = await m.Member.find(m.Member.address.city.contains("llite")).first() + assert result.first_name == "Steve" + + result = await m.Member.find(m.Member.bio % "tw*").first() + assert result.first_name == "Steve" + + result = await m.Member.find(m.Member.bio % "*cker").first() + assert result.first_name == "Steve" + + result = await m.Member.find(m.Member.bio % "*ack*").first() + assert result.first_name == "Steve"