Skip to content

Commit

Permalink
Merge pull request #41 from simonsobs/dev
Browse files Browse the repository at this point in the history
Refactor pagination logic in compose_statement()
  • Loading branch information
TaiSakuma authored Feb 21, 2024
2 parents 41b5cab + 0bfe1bc commit 063e6ec
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 42 deletions.
136 changes: 96 additions & 40 deletions src/nextline_rdb/pagination.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import NamedTuple, Optional, Type, TypeVar, cast
from collections.abc import Sequence
from typing import Any, NamedTuple, Optional, Type, TypeVar

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, aliased, selectinload
from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.selectable import Select

# import sqlparse
@@ -22,10 +22,12 @@ class SortField(NamedTuple):

_Id = TypeVar('_Id')

T = TypeVar('T', bound=DeclarativeBase)


async def load_models(
session: AsyncSession,
Model: Type[DeclarativeBase],
Model: Type[T],
id_field: str,
*,
sort: Optional[Sort] = None,
@@ -34,10 +36,24 @@ async def load_models(
first: Optional[int] = None,
last: Optional[int] = None,
):
# TODO: Make this an argument so that the caller can add a `where` clause
select_model = select(Model)

sort = sort or []

if id_field not in {s.field for s in sort}:
sort.append(SortField(id_field))

order_by = [
f.desc() if d else f
for f, d in [(getattr(Model, s.field), s.desc) for s in sort]
]

stmt = compose_statement(
Model,
id_field,
sort=sort,
select_model=select_model,
order_by=order_by,
before=before,
after=after,
first=first,
@@ -52,65 +68,105 @@ async def load_models(


def compose_statement(
Model: Type[DeclarativeBase],
Model: Type[T],
id_field: str,
*,
sort: Optional[Sort] = None,
select_model: Optional[Select[tuple[T]]] = None,
order_by: Optional[Sequence[Any]] = None,
before: Optional[_Id] = None,
after: Optional[_Id] = None,
first: Optional[int] = None,
last: Optional[int] = None,
) -> Select:
'''Return a SELECT statement object to be given to session.scalars'''
) -> Select[tuple[T]]:
'''Return a SQL select statement for pagination.
Parameters
----------
Model :
The class of the ORM model to query.
id_field :
The name of the primary key field, e.g., 'id'.
select_model : optional
E.g., `select(Model).where(...)`. If not provided, `select(Model)` is used.
order_by : optional
The arguments to `row_number().over(order_by=...)` and `order_by()`. If
not provided, the primary key field is used, e.g., `[Model.id]`.
before : optional
As in the GraphQL Cursor Connections Specification [1].
after : optional
As in the GraphQL Cursor Connections Specification [1].
first : optional
As in the GraphQL Cursor Connections Specification [1].
last : optional
As in the GraphQL Cursor Connections Specification [1].
Returns
-------
stmt
The composed select statement for pagination.
Raises
------
ValueError
If `after` or `first` is given together with `before` or `last`.
References
----------
.. [1] https://relay.dev/graphql/connections.htm
'''

forward = (after is not None) or (first is not None)
backward = (before is not None) or (last is not None)

if forward and backward:
raise ValueError('Only either after/first or before/last is allowed')

sort = sort or []

if id_field not in [s.field for s in sort]:
sort.append(SortField(id_field))
if select_model is None:
select_model = select(Model)

def sorting_fields(Model, reverse=False):
return [
f.desc() if reverse ^ d else f
for f, d in [(getattr(Model, s.field), s.desc) for s in sort]
]
if not order_by:
# E.g., [Model.id]
order_by = [getattr(Model, id_field)]

if not (forward or backward):
return select(Model).order_by(*sorting_fields(Model))
return select_model.order_by(*order_by)

cursor = after if forward else before
limit = first if forward else last

if cursor is None:
stmt = select(Model).order_by(*sorting_fields(Model, reverse=backward))
else:
cte = select(
Model,
func.row_number()
.over(order_by=sorting_fields(Model, reverse=backward))
.label('row_number'),
).cte()

subq = select(cte.c.row_number.label('cursor'))
subq = subq.where(getattr(cte.c, id_field) == cursor)
subq = cast(Select[tuple], subq.subquery())

Alias = aliased(Model, cte) # type: ignore
stmt = select(Alias).select_from(cte)
stmt = stmt.join(subq, literal(True)) # type: ignore # cartesian product
stmt = stmt.order_by(*sorting_fields(Alias, reverse=backward))
stmt = stmt.where(cte.c.row_number > subq.c.cursor)
# A CTE (Common Table Expression) with a row_number column
cte = select_model.add_columns(
func.row_number().over(order_by=order_by).label('row_number')
).cte()

Alias = aliased(Model, cte)
stmt = select(Alias, cte.c.row_number).select_from(cte)

if cursor is not None:
# A subquery to find the row_number at the cursor
stmt_subq = select(cte.c.row_number.label('at_cursor'))
stmt_subq = stmt_subq.where(getattr(cte.c, id_field) == cursor)
subq = stmt_subq.subquery()

# Select rows after or before (if backward) the cursor
stmt = stmt.select_from(subq)
if backward:
stmt = stmt.where(cte.c.row_number < subq.c.at_cursor)
else:
stmt = stmt.where(cte.c.row_number > subq.c.at_cursor)

if limit is not None:
# Specify the maximum number of rows to return
if backward:
stmt = stmt.order_by(cte.c.row_number.desc())
else:
stmt = stmt.order_by(cte.c.row_number)
stmt = stmt.limit(limit)

if backward:
Alias = aliased(Model, stmt.subquery())
stmt = select(Alias).order_by(*sorting_fields(Alias))
# Select only the model (not the row_number) and ensure the order
cte = stmt.cte()
Alias = aliased(Model, cte)
stmt = select(Alias).select_from(cte)
stmt = stmt.order_by(cte.c.row_number)

return stmt
4 changes: 2 additions & 2 deletions tests/pagination/funcs.py
Original file line number Diff line number Diff line change
@@ -12,8 +12,8 @@
def st_entity() -> st.SearchStrategy[Entity]:
return st.builds(
Entity,
num=st_none_or(st_graphql_ints()),
txt=st_none_or(st.text()),
num=st_none_or(st_graphql_ints(min_value=0, max_value=5)),
txt=st_none_or(st.text(alphabet='ABCDE', min_size=1, max_size=1)),
)


Expand Down

0 comments on commit 063e6ec

Please sign in to comment.