Skip to content

Commit

Permalink
Fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
dbernstein committed Dec 9, 2024
1 parent 34c2189 commit 1cc625a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
16 changes: 9 additions & 7 deletions src/palace/manager/celery/tasks/opds_odl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
def remove_expired_holds_for_collection(
db: Session,
collection_id: int,
) -> tuple[int, dict[str, Any]]:
) -> tuple[int, list[dict[str, Any]]]:
"""
Remove expired holds from the database for this collection.
"""
Expand Down Expand Up @@ -52,7 +52,7 @@ def remove_expired_holds_for_collection(
)

expired_holds = db.scalars(select_query).all()
expired_hold_events: [dict[str, Any]] = []
expired_hold_events: list[dict[str, Any]] = []
for hold in expired_holds:
expired_hold_events.append(
dict(
Expand Down Expand Up @@ -80,7 +80,7 @@ def remove_expired_holds_for_collection(
# a rowcount, but the sqlalchemy docs swear it will in the case of
# a delete statement.
# https://docs.sqlalchemy.org/en/20/tutorial/data_update.html#getting-affected-row-count-from-update-delete
return result.rowcount, expired_hold_events # type: ignore[attr-defined,no-any-return]
return result.rowcount, expired_hold_events # type: ignore


def licensepool_ids_with_holds(
Expand Down Expand Up @@ -119,7 +119,7 @@ def lock_licenses(license_pool: LicensePool) -> None:
def recalculate_holds_for_licensepool(
license_pool: LicensePool,
reservation_period: datetime.timedelta,
) -> tuple[int, dict[str, Any]]:
) -> tuple[int, list[dict[str, Any]]]:
# We take out row level locks on all the licenses and holds for this license pool, so that
# everything is in a consistent state while we update the hold queue. This means we should be
# quickly committing the transaction, to avoid contention or deadlocks.
Expand All @@ -133,7 +133,7 @@ def recalculate_holds_for_licensepool(
waiting = holds[reserved:]
updated = 0

events: [dict[str, Any]] = []
events: list[dict[str, Any]] = []

# These holds have a copy reserved for them.
for hold in ready:
Expand Down Expand Up @@ -175,8 +175,10 @@ def remove_expired_holds_for_collection_task(task: Task, collection_id: int) ->
session,
collection_id,
)

collection_name = None if not collection else collection.name
task.log.info(
f"Removed {removed} expired holds for collection {collection.name} ({collection_id})."
f"Removed {removed} expired holds for collection {collection_name} ({collection_id})."
)

# publish events only after successful commit
Expand All @@ -198,7 +200,7 @@ def remove_expired_holds(task: Task) -> None:
if collection.id is not None
]
for collection_id, collection_name in collections:
remove_expired_holds_for_collection.delay(collection_id)
remove_expired_holds_for_collection_task.delay(collection_id)


@shared_task(queue=QueueNames.default, bind=True)
Expand Down
6 changes: 4 additions & 2 deletions tests/manager/celery/tasks/test_opds_odl.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def test_recalculate_holds_for_licensepool(
assert event["event_type"] == CirculationEvent.CM_HOLD_READY_FOR_CHECKOUT


def test_remove_expired_holds_for_collection(
def test_remove_expired_holds_for_collection_task(
celery_fixture: CeleryFixture,
db: DatabaseTransactionFixture,
opds_task_fixture: OpdsTaskFixture,
Expand Down Expand Up @@ -305,7 +305,9 @@ def test_remove_expired_holds(
collection2 = db.collection(protocol=OPDS2WithODLApi)
decoy_collection = db.collection(protocol=OverdriveAPI)

with patch.object(opds_odl, "remove_expired_holds_for_collection") as mock_remove:
with patch.object(
opds_odl, "remove_expired_holds_for_collection_task"
) as mock_remove:
remove_expired_holds.delay().wait()

assert mock_remove.delay.call_count == 2
Expand Down

0 comments on commit 1cc625a

Please sign in to comment.