diff --git a/src/nextline_rdb/utils/sa.py b/src/nextline_rdb/utils/sa.py index 26ece98..ed55e98 100644 --- a/src/nextline_rdb/utils/sa.py +++ b/src/nextline_rdb/utils/sa.py @@ -1,3 +1,4 @@ +from logging import getLogger from typing import Any, TypeVar from sqlalchemy import Select, inspect, select @@ -8,8 +9,15 @@ T = TypeVar('T', bound=DeclarativeBase) +# NOTE: Consider make this configurable. +DEFAULT_UNTIL_SCALAR_ONE_TIMEOUT = 60 # seconds -async def until_scalar_one(session: AsyncSession, stmt: Select[tuple[T]]) -> T: + +async def until_scalar_one( + session: AsyncSession, + stmt: Select[tuple[T]], + timeout: float = DEFAULT_UNTIL_SCALAR_ONE_TIMEOUT, +) -> T: '''Execute the statement until it returns exactly one row. The statement is repeatedly executed while it returns no rows. An exception @@ -19,7 +27,12 @@ async def until_scalar_one(session: AsyncSession, stmt: Select[tuple[T]]) -> T: async def _f() -> T | None: return (await session.execute(stmt)).scalar_one_or_none() - return await until_not_none(_f) + try: + return await until_not_none(_f, timeout=timeout) + except Exception: + logger = getLogger(__name__) + logger.exception('') + raise async def load_all(session: AsyncSession, model_base_class: type[T]) -> list[T]: