Skip to content

Commit

Permalink
πŸͺ› Moving to SQLAlchemy 2.0 (#540)
Browse files Browse the repository at this point in the history
* πŸͺ› Added support for SQLAlchemy 2.0
* Added common and dialects packages to handle
the new SQLAlchemy 2.0+
* πŸͺ² Fix specific asyncpg oriented test
---------

Co-authored-by: ansipunk <ansipunk@use.startmail.com>
  • Loading branch information
tarsil and ansipunk authored Feb 21, 2024
1 parent c2e4c5b commit 1e40ad1
Show file tree
Hide file tree
Showing 19 changed files with 394 additions and 275 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:

strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]

services:
mysql:
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ values = [
]
await database.execute_many(query=query, values=values)

#Β Run a database query.
# Run a database query.
query = "SELECT * FROM HighScores"
rows = await database.fetch_all(query=query)
print('High Scores:', rows)
Expand Down Expand Up @@ -115,4 +115,4 @@ for examples of how to start using databases together with SQLAlchemy core expre
[quart]: https://gitlab.com/pgjones/quart
[aiohttp]: https://github.com/aio-libs/aiohttp
[tornado]: https://github.com/tornadoweb/tornado
[fastapi]: https://github.com/tiangolo/fastapi
[fastapi]: https://github.com/tiangolo/fastapi
69 changes: 44 additions & 25 deletions databases/backends/aiopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@
import uuid

import aiopg
from aiopg.sa.engine import APGCompiler_psycopg2
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.row import Row
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement

from databases.core import DatabaseURL
from databases.backends.common.records import Record, Row, create_column_maps
from databases.backends.compilers.psycopg import PGCompiler_psycopg
from databases.backends.dialects.psycopg import PGDialect_psycopg
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Record,
Record as RecordInterface,
TransactionBackend,
)

Expand All @@ -34,10 +35,10 @@ def __init__(
self._pool: typing.Union[aiopg.Pool, None] = None

def _get_dialect(self) -> Dialect:
dialect = PGDialect_psycopg2(
dialect = PGDialect_psycopg(
json_serializer=json.dumps, json_deserializer=lambda x: x
)
dialect.statement_compiler = APGCompiler_psycopg2
dialect.statement_compiler = PGCompiler_psycopg
dialect.implicit_returning = True
dialect.supports_native_enum = True
dialect.supports_smallserial = True # 9.2+
Expand Down Expand Up @@ -117,50 +118,55 @@ async def release(self) -> None:
await self._database._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect

cursor = await self._connection.cursor()
try:
await cursor.execute(query_str, args)
rows = await cursor.fetchall()
metadata = CursorResultMetaData(context, cursor.description)
return [
rows = [
Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
for row in rows
]
return [Record(row, result_columns, dialect, column_maps) for row in rows]
finally:
cursor.close()

async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]:
async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect
cursor = await self._connection.cursor()
try:
await cursor.execute(query_str, args)
row = await cursor.fetchone()
if row is None:
return None
metadata = CursorResultMetaData(context, cursor.description)
return Row(
row = Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
return Record(row, result_columns, dialect, column_maps)
finally:
cursor.close()

async def execute(self, query: ClauseElement) -> typing.Any:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, _, _ = self._compile(query)
cursor = await self._connection.cursor()
try:
await cursor.execute(query_str, args)
Expand All @@ -173,7 +179,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
cursor = await self._connection.cursor()
try:
for single_query in queries:
single_query, args, context = self._compile(single_query)
single_query, args, _, _ = self._compile(single_query)
await cursor.execute(single_query, args)
finally:
cursor.close()
Expand All @@ -182,36 +188,37 @@ async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect
cursor = await self._connection.cursor()
try:
await cursor.execute(query_str, args)
metadata = CursorResultMetaData(context, cursor.description)
async for row in cursor:
yield Row(
record = Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
yield Record(record, result_columns, dialect, column_maps)
finally:
cursor.close()

def transaction(self) -> TransactionBackend:
return AiopgTransaction(self)

def _compile(
self, query: ClauseElement
) -> typing.Tuple[str, dict, CompilationContext]:
def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
compiled = query.compile(
dialect=self._dialect, compile_kwargs={"render_postcompile": True}
)

execution_context = self._dialect.execution_ctx_cls()
execution_context.dialect = self._dialect

if not isinstance(query, DDLElement):
compiled_params = sorted(compiled.params.items())

args = compiled.construct_params()
for key, val in args.items():
if key in compiled._bind_processors:
Expand All @@ -224,11 +231,23 @@ def _compile(
compiled._ad_hoc_textual,
compiled._loose_column_name_matching,
)

mapping = {
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
}
compiled_query = compiled.string % mapping
result_map = compiled._result_columns

else:
args = {}
result_map = None
compiled_query = compiled.string

logger.debug("Query: %s\nArgs: %s", compiled.string, args)
return compiled.string, args, CompilationContext(execution_context)
query_message = compiled_query.replace(" \n", " ").replace("\n", " ")
logger.debug(
"Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA
)
return compiled.string, args, result_map, CompilationContext(execution_context)

@property
def raw_connection(self) -> aiopg.connection.Connection:
Expand Down
63 changes: 41 additions & 22 deletions databases/backends/asyncmy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from sqlalchemy.dialects.mysql import pymysql
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.row import Row
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement

from databases.backends.common.records import Record, Row, create_column_maps
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Record,
Record as RecordInterface,
TransactionBackend,
)

Expand Down Expand Up @@ -108,50 +108,57 @@ async def release(self) -> None:
await self._database._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect

async with self._connection.cursor() as cursor:
try:
await cursor.execute(query_str, args)
rows = await cursor.fetchall()
metadata = CursorResultMetaData(context, cursor.description)
return [
rows = [
Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
for row in rows
]
return [
Record(row, result_columns, dialect, column_maps) for row in rows
]
finally:
await cursor.close()

async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]:
async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect
async with self._connection.cursor() as cursor:
try:
await cursor.execute(query_str, args)
row = await cursor.fetchone()
if row is None:
return None
metadata = CursorResultMetaData(context, cursor.description)
return Row(
row = Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
return Record(row, result_columns, dialect, column_maps)
finally:
await cursor.close()

async def execute(self, query: ClauseElement) -> typing.Any:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, _, _ = self._compile(query)
async with self._connection.cursor() as cursor:
try:
await cursor.execute(query_str, args)
Expand All @@ -166,7 +173,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
async with self._connection.cursor() as cursor:
try:
for single_query in queries:
single_query, args, context = self._compile(single_query)
single_query, args, _, _ = self._compile(single_query)
await cursor.execute(single_query, args)
finally:
await cursor.close()
Expand All @@ -175,36 +182,37 @@ async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect
async with self._connection.cursor() as cursor:
try:
await cursor.execute(query_str, args)
metadata = CursorResultMetaData(context, cursor.description)
async for row in cursor:
yield Row(
record = Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
yield Record(record, result_columns, dialect, column_maps)
finally:
await cursor.close()

def transaction(self) -> TransactionBackend:
return AsyncMyTransaction(self)

def _compile(
self, query: ClauseElement
) -> typing.Tuple[str, dict, CompilationContext]:
def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
compiled = query.compile(
dialect=self._dialect, compile_kwargs={"render_postcompile": True}
)

execution_context = self._dialect.execution_ctx_cls()
execution_context.dialect = self._dialect

if not isinstance(query, DDLElement):
compiled_params = sorted(compiled.params.items())

args = compiled.construct_params()
for key, val in args.items():
if key in compiled._bind_processors:
Expand All @@ -217,12 +225,23 @@ def _compile(
compiled._ad_hoc_textual,
compiled._loose_column_name_matching,
)

mapping = {
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
}
compiled_query = compiled.string % mapping
result_map = compiled._result_columns

else:
args = {}
result_map = None
compiled_query = compiled.string

query_message = compiled.string.replace(" \n", " ").replace("\n", " ")
logger.debug("Query: %s Args: %s", query_message, repr(args), extra=LOG_EXTRA)
return compiled.string, args, CompilationContext(execution_context)
query_message = compiled_query.replace(" \n", " ").replace("\n", " ")
logger.debug(
"Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA
)
return compiled.string, args, result_map, CompilationContext(execution_context)

@property
def raw_connection(self) -> asyncmy.connection.Connection:
Expand Down
Empty file.
Loading

0 comments on commit 1e40ad1

Please sign in to comment.