Skip to content

Commit

Permalink
fix review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dvora-h committed Aug 8, 2022
1 parent ba7d552 commit ee80f7c
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 69 deletions.
69 changes: 28 additions & 41 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from pydantic.utils import Representation
from typing_extensions import Protocol, get_args, get_origin
from ulid import ULID
from more_itertools import ichunked

from ..checks import has_redis_json, has_redisearch
from ..connections import get_redis_connection
Expand Down Expand Up @@ -1114,14 +1115,16 @@ def key(self):
pk = getattr(self, self._meta.primary_key.field.name)
return self.make_primary_key(pk)

@classmethod
async def _delete(cls, db, *pks):
return await db.delete(*pks)

@classmethod
async def delete(cls, pk: Any, pipeline: Optional[Pipeline] = None) -> int:
"""Delete data at this key."""
if pipeline is None:
db = cls.db()
else:
db = pipeline
return await db.delete(cls.make_primary_key(pk))
db = cls._get_db(pipeline)

return await cls._delete(db, cls.make_primary_key(pk))

@classmethod
async def get(cls, pk: Any) -> "RedisModel":
Expand All @@ -1135,10 +1138,7 @@ async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
raise NotImplementedError

async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None):
if pipeline is None:
db = self.db()
else:
db = pipeline
db = self._get_db(pipeline)

# TODO: Wrap any Redis response errors in a custom exception?
await db.expire(self.make_primary_key(self.pk), num_seconds)
Expand Down Expand Up @@ -1248,16 +1248,7 @@ async def add(
pipeline: Optional[Pipeline] = None,
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
) -> Sequence["RedisModel"]:
if pipeline is None:
# By default, send commands in a pipeline. Saving each model will
# be atomic, but Redis may process other commands in between
# these saves.
db = cls.db().pipeline(transaction=False)
else:
# If the user gave us a pipeline, add our commands to that. The user
# will be responsible for executing the pipeline after they've accumulated
# the commands they want to send.
db = pipeline
db = cls._get_db(pipeline, bulk=True)

for model in models:
# save() just returns the model, we don't need that here.
Expand All @@ -1272,25 +1263,25 @@ async def add(
return models

@classmethod
async def delete_all(
def _get_db(self, pipeline: Optional[Pipeline]=None, bulk: bool=False):
if pipeline is not None:
return pipeline
elif bulk:
return self.db().pipeline(transaction=False)
else:
return self.db()

@classmethod
async def delete_many(
cls,
models: Sequence["RedisModel"],
pipeline: Optional[Pipeline] = None,
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
) -> int:
if pipeline is None:
db = cls.db().pipeline(transaction=False)
else:
db = pipeline

for model in models:
await model.delete(model.pk, pipeline=db)
db = cls._get_db(pipeline)

# If the user didn't give us a pipeline, then we need to execute
# the one we just created.
if pipeline is None:
result = await db.execute()
pipeline_verifier(result, expected_responses=len(models))
for chunk in ichunked(models, 100):
pks = [cls.make_primary_key(model.pk) for model in chunk]
await cls._delete(db, *pks)

return len(models)

Expand Down Expand Up @@ -1330,10 +1321,8 @@ def __init_subclass__(cls, **kwargs):

async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
self.check()
if pipeline is None:
db = self.db()
else:
db = pipeline
db = self._get_db(pipeline)

document = jsonable_encoder(self.dict())
# TODO: Wrap any Redis response errors in a custom exception?
await db.hset(self.key(), mapping=document)
Expand Down Expand Up @@ -1502,10 +1491,8 @@ def __init__(self, *args, **kwargs):

async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
self.check()
if pipeline is None:
db = self.db()
else:
db = pipeline
db = self._get_db(pipeline)

# TODO: Wrap response errors in a custom exception?
await db.execute_command("JSON.SET", self.key(), ".", self.json())
return self
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ python-ulid = "^1.0.3"
cleo = "1.0.0a4"
typing-extensions = "^4.0.0"
hiredis = "^2.0.0"
more-itertools = "^8.13.0"

[tool.poetry.dev-dependencies]
mypy = "^0.950"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ async def test_delete_many(m):
members = [member1, member2]
result = await m.Member.add(members)
assert result == [member1, member2]
result = await m.Member.delete_all(members)
result = await m.Member.delete_many(members)
assert result == 2
with pytest.raises(NotFoundError):
await m.Member.get(pk=member1.pk)
Expand Down
28 changes: 1 addition & 27 deletions tests/test_json_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,38 +318,12 @@ async def test_delete_many_implicit_pipeline(address, m):
members = [member1, member2]
result = await m.Member.add(members)
assert result == [member1, member2]
result = await m.Member.delete_all(members)
result = await m.Member.delete_many(members)
assert result == 2
with pytest.raises(NotFoundError):
await m.Member.get(pk=member2.pk)


@py_test_mark_asyncio
async def test_delete_many_explicit_transaction(address, m):
member1 = m.Member(
first_name="Andrew",
last_name="Brookins",
email="a@example.com",
join_date=today,
address=address,
age=38,
)
member2 = m.Member(
first_name="Kim",
last_name="Brookins",
email="k@example.com",
join_date=today,
address=address,
age=34,
)
members = [member1, member2]
result = await m.Member.add(members)
assert result == [member1, member2]
async with m.Member.db().pipeline(transaction=True) as pipeline:
await m.Member.delete_all(members, pipeline=pipeline)
assert await pipeline.execute() == [1, 1]


async def save(members):
for m in members:
await m.save()
Expand Down

0 comments on commit ee80f7c

Please sign in to comment.