diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 1460a7ab..078080d9 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -34,6 +34,7 @@ from pydantic.utils import Representation from typing_extensions import Protocol, get_args, get_origin from ulid import ULID +from more_itertools import ichunked from ..checks import has_redis_json, has_redisearch from ..connections import get_redis_connection @@ -1114,14 +1115,16 @@ def key(self): pk = getattr(self, self._meta.primary_key.field.name) return self.make_primary_key(pk) + @classmethod + async def _delete(cls, db, *pks): + return await db.delete(*pks) + @classmethod async def delete(cls, pk: Any, pipeline: Optional[Pipeline] = None) -> int: """Delete data at this key.""" - if pipeline is None: - db = cls.db() - else: - db = pipeline - return await db.delete(cls.make_primary_key(pk)) + db = cls._get_db(pipeline) + + return await cls._delete(db, cls.make_primary_key(pk)) @classmethod async def get(cls, pk: Any) -> "RedisModel": @@ -1135,10 +1138,7 @@ async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel": raise NotImplementedError async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None): - if pipeline is None: - db = self.db() - else: - db = pipeline + db = self._get_db(pipeline) # TODO: Wrap any Redis response errors in a custom exception? await db.expire(self.make_primary_key(self.pk), num_seconds) @@ -1248,16 +1248,7 @@ async def add( pipeline: Optional[Pipeline] = None, pipeline_verifier: Callable[..., Any] = verify_pipeline_response, ) -> Sequence["RedisModel"]: - if pipeline is None: - # By default, send commands in a pipeline. Saving each model will - # be atomic, but Redis may process other commands in between - # these saves. - db = cls.db().pipeline(transaction=False) - else: - # If the user gave us a pipeline, add our commands to that. The user - # will be responsible for executing the pipeline after they've accumulated - # the commands they want to send. - db = pipeline + db = cls._get_db(pipeline, bulk=True) for model in models: # save() just returns the model, we don't need that here. @@ -1272,25 +1263,25 @@ async def add( return models @classmethod - async def delete_all( + def _get_db(self, pipeline: Optional[Pipeline]=None, bulk: bool=False): + if pipeline is not None: + return pipeline + elif bulk: + return self.db().pipeline(transaction=False) + else: + return self.db() + + @classmethod + async def delete_many( cls, models: Sequence["RedisModel"], pipeline: Optional[Pipeline] = None, - pipeline_verifier: Callable[..., Any] = verify_pipeline_response, ) -> int: - if pipeline is None: - db = cls.db().pipeline(transaction=False) - else: - db = pipeline - - for model in models: - await model.delete(model.pk, pipeline=db) + db = cls._get_db(pipeline) - # If the user didn't give us a pipeline, then we need to execute - # the one we just created. - if pipeline is None: - result = await db.execute() - pipeline_verifier(result, expected_responses=len(models)) + for chunk in ichunked(models, 100): + pks = [cls.make_primary_key(model.pk) for model in chunk] + await cls._delete(db, *pks) return len(models) @@ -1330,10 +1321,8 @@ def __init_subclass__(cls, **kwargs): async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel": self.check() - if pipeline is None: - db = self.db() - else: - db = pipeline + db = self._get_db(pipeline) + document = jsonable_encoder(self.dict()) # TODO: Wrap any Redis response errors in a custom exception? await db.hset(self.key(), mapping=document) @@ -1502,10 +1491,8 @@ def __init__(self, *args, **kwargs): async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel": self.check() - if pipeline is None: - db = self.db() - else: - db = pipeline + db = self._get_db(pipeline) + # TODO: Wrap response errors in a custom exception? await db.execute_command("JSON.SET", self.key(), ".", self.json()) return self diff --git a/pyproject.toml b/pyproject.toml index 3a6ff9d2..b79ad7bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ python-ulid = "^1.0.3" cleo = "1.0.0a4" typing-extensions = "^4.0.0" hiredis = "^2.0.0" +more-itertools = "^8.13.0" [tool.poetry.dev-dependencies] mypy = "^0.950" diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index bc33f82d..29aeda28 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -574,7 +574,7 @@ async def test_delete_many(m): members = [member1, member2] result = await m.Member.add(members) assert result == [member1, member2] - result = await m.Member.delete_all(members) + result = await m.Member.delete_many(members) assert result == 2 with pytest.raises(NotFoundError): await m.Member.get(pk=member1.pk) diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 8f84234d..6d424eaa 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -318,38 +318,12 @@ async def test_delete_many_implicit_pipeline(address, m): members = [member1, member2] result = await m.Member.add(members) assert result == [member1, member2] - result = await m.Member.delete_all(members) + result = await m.Member.delete_many(members) assert result == 2 with pytest.raises(NotFoundError): await m.Member.get(pk=member2.pk) -@py_test_mark_asyncio -async def test_delete_many_explicit_transaction(address, m): - member1 = m.Member( - first_name="Andrew", - last_name="Brookins", - email="a@example.com", - join_date=today, - address=address, - age=38, - ) - member2 = m.Member( - first_name="Kim", - last_name="Brookins", - email="k@example.com", - join_date=today, - address=address, - age=34, - ) - members = [member1, member2] - result = await m.Member.add(members) - assert result == [member1, member2] - async with m.Member.db().pipeline(transaction=True) as pipeline: - await m.Member.delete_all(members, pipeline=pipeline) - assert await pipeline.execute() == [1, 1] - - async def save(members): for m in members: await m.save()