diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index bc271a65..eb083dbf 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -11,10 +11,9 @@ jobs: tests: name: "Python ${{ matrix.python-version }}" runs-on: "ubuntu-latest" - strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] services: mysql: @@ -38,11 +37,31 @@ jobs: - 5432:5432 options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 + mssql: + image: mcr.microsoft.com/mssql/server:2019-GA-ubuntu-16.04 + env: + MSSQL_SA_PASSWORD: Mssql123mssql- + ACCEPT_EULA: "Y" + MSSQL_PID: Developer + ports: + - 1433:1433 + options: >- + --health-cmd "/opt/mssql-tools/bin/sqlcmd -U sa -P Mssql123mssql- -Q 'select 1' -b -o /dev/null" + --health-interval 60s + --health-timeout 30s + --health-start-period 20s + --health-retries 3 + steps: - uses: "actions/checkout@v3" - uses: "actions/setup-python@v4" with: python-version: "${{ matrix.python-version }}" + - name: "Install drivers" + run: | + sudo ACCEPT_EULA=Y apt-get install -y msodbcsql17 + sudo apt-get install -y unixodbc-dev + - name: "Install dependencies" run: "scripts/install" - name: "Run linting checks" @@ -59,5 +78,8 @@ jobs: mysql+asyncmy://username:password@localhost:3306/testsuite, postgresql://username:password@localhost:5432/testsuite, postgresql+aiopg://username:password@127.0.0.1:5432/testsuite, - postgresql+asyncpg://username:password@localhost:5432/testsuite + postgresql+asyncpg://username:password@localhost:5432/testsuite, + mssql://sa:Mssql123mssql-@localhost:1433/master?driver=ODBC+Driver+17+for+SQL+Server, + mssql+pyodbc://sa:Mssql123mssql-@localhost:1433/master?driver=ODBC+Driver+17+for+SQL+Server, + mssql+aioodbc://sa:Mssql123mssql-@localhost:1433/master?driver=ODBC+Driver+17+for+SQL+Server run: "scripts/test" diff --git a/CHANGELOG.md b/CHANGELOG.md index 4816bc16..36d59393 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +### Added + +- Support for SQLAlchemy 2.0+ +- Added internal support for the new psycopg dialect. + ## 0.7.0 (Dec 18th, 2022) ### Fixed diff --git a/README.md b/README.md index ba16a104..93eb3b52 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ Database drivers supported are: * [aiomysql][aiomysql] * [asyncmy][asyncmy] * [aiosqlite][aiosqlite] +* [aioodbc][aioodbc] You can install the required database drivers with: @@ -45,9 +46,10 @@ $ pip install databases[aiopg] $ pip install databases[aiomysql] $ pip install databases[asyncmy] $ pip install databases[aiosqlite] +$ pip install databases[aioodbc] ``` -Note that if you are using any synchronous SQLAlchemy functions such as `engine.create_all()` or [alembic][alembic] migrations then you still have to install a synchronous DB driver: [psycopg2][psycopg2] for PostgreSQL and [pymysql][pymysql] for MySQL. +Note that if you are using any synchronous SQLAlchemy functions such as `engine.create_all()` or [alembic][alembic] migrations then you still have to install a synchronous DB driver: [psycopg2][psycopg2] for PostgreSQL, [pymysql][pymysql] for MySQL and [pyodbc][pyodbc] for SQL Server. --- @@ -85,7 +87,7 @@ values = [ ] await database.execute_many(query=query, values=values) -# Run a database query. +# Run a database query. 
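# The same queries run unchanged against the new SQL Server backend, e.g.
# Database("mssql+aioodbc://sa:password@localhost:1433/master?driver=ODBC+Driver+17+for+SQL+Server"),
# where the credentials and ODBC driver name are placeholders (they mirror the CI matrix above).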
query = "SELECT * FROM HighScores" rows = await database.fetch_all(query=query) print('High Scores:', rows) @@ -103,11 +105,13 @@ for examples of how to start using databases together with SQLAlchemy core expre [alembic]: https://alembic.sqlalchemy.org/en/latest/ [psycopg2]: https://www.psycopg.org/ [pymysql]: https://github.com/PyMySQL/PyMySQL +[pyodbc]: https://github.com/mkleehammer/pyodbc [asyncpg]: https://github.com/MagicStack/asyncpg [aiopg]: https://github.com/aio-libs/aiopg [aiomysql]: https://github.com/aio-libs/aiomysql [asyncmy]: https://github.com/long2ice/asyncmy [aiosqlite]: https://github.com/omnilib/aiosqlite +[aioodbc]: https://aioodbc.readthedocs.io/en/latest/ [starlette]: https://github.com/encode/starlette [sanic]: https://github.com/huge-success/sanic @@ -115,4 +119,4 @@ for examples of how to start using databases together with SQLAlchemy core expre [quart]: https://gitlab.com/pgjones/quart [aiohttp]: https://github.com/aio-libs/aiohttp [tornado]: https://github.com/tornadoweb/tornado -[fastapi]: https://github.com/tiangolo/fastapi +[fastapi]: https://github.com/tiangolo/fastapi \ No newline at end of file diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 8668b2b9..0b2a6e89 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -5,19 +5,20 @@ import uuid import aiopg -from aiopg.sa.engine import APGCompiler_psycopg2 -from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 from sqlalchemy.engine.cursor import CursorResultMetaData from sqlalchemy.engine.interfaces import Dialect, ExecutionContext from sqlalchemy.engine.row import Row from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement -from databases.core import DatabaseURL +from databases.backends.common.records import Record, Row, create_column_maps +from databases.backends.compilers.psycopg import PGCompiler_psycopg +from databases.backends.dialects.psycopg import PGDialect_psycopg +from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ( ConnectionBackend, DatabaseBackend, - Record, + Record as RecordInterface, TransactionBackend, ) @@ -34,10 +35,10 @@ def __init__( self._pool: typing.Union[aiopg.Pool, None] = None def _get_dialect(self) -> Dialect: - dialect = PGDialect_psycopg2( + dialect = PGDialect_psycopg( json_serializer=json.dumps, json_deserializer=lambda x: x ) - dialect.statement_compiler = APGCompiler_psycopg2 + dialect.statement_compiler = PGCompiler_psycopg dialect.implicit_returning = True dialect.supports_native_enum = True dialect.supports_smallserial = True # 9.2+ @@ -117,15 +118,18 @@ async def release(self) -> None: await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: + async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect + cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) rows = await cursor.fetchall() metadata = CursorResultMetaData(context, cursor.description) - return [ + rows = [ Row( metadata, metadata._processors, @@ -135,12 +139,15 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: ) for row in rows ] + return [Record(row, 
result_columns, dialect, column_maps) for row in rows] finally: cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) @@ -148,19 +155,20 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: if row is None: return None metadata = CursorResultMetaData(context, cursor.description) - return Row( + row = Row( metadata, metadata._processors, metadata._keymap, Row._default_key_style, row, ) + return Record(row, result_columns, dialect, column_maps) finally: cursor.close() async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, _, _ = self._compile(query) cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) @@ -173,7 +181,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: cursor = await self._connection.cursor() try: for single_query in queries: - single_query, args, context = self._compile(single_query) + single_query, args, _, _ = self._compile(single_query) await cursor.execute(single_query, args) finally: cursor.close() @@ -182,36 +190,38 @@ async def iterate( self, query: ClauseElement ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) metadata = CursorResultMetaData(context, cursor.description) async for row in cursor: - yield Row( + record = Row( metadata, metadata._processors, metadata._keymap, Row._default_key_style, row, ) + yield Record(record, result_columns, dialect, column_maps) finally: cursor.close() def transaction(self) -> TransactionBackend: return AiopgTransaction(self) - def _compile( - self, query: ClauseElement - ) -> typing.Tuple[str, dict, CompilationContext]: + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: compiled = query.compile( dialect=self._dialect, compile_kwargs={"render_postcompile": True} ) - execution_context = self._dialect.execution_ctx_cls() execution_context.dialect = self._dialect if not isinstance(query, DDLElement): + compiled_params = sorted(compiled.params.items()) + args = compiled.construct_params() for key, val in args.items(): if key in compiled._bind_processors: @@ -224,11 +234,23 @@ def _compile( compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) + + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) + } + compiled_query = compiled.string % mapping + result_map = compiled._result_columns + else: args = {} + result_map = None + compiled_query = compiled.string - logger.debug("Query: %s\nArgs: %s", compiled.string, args) - return compiled.string, args, CompilationContext(execution_context) + query_message = compiled_query.replace(" \n", " ").replace("\n", " 
") + logger.debug( + "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA + ) + return compiled.string, args, result_map, CompilationContext(execution_context) @property def raw_connection(self) -> aiopg.connection.Connection: diff --git a/databases/backends/asyncmy.py b/databases/backends/asyncmy.py index 749e5afe..f224c7cc 100644 --- a/databases/backends/asyncmy.py +++ b/databases/backends/asyncmy.py @@ -7,15 +7,15 @@ from sqlalchemy.dialects.mysql import pymysql from sqlalchemy.engine.cursor import CursorResultMetaData from sqlalchemy.engine.interfaces import Dialect, ExecutionContext -from sqlalchemy.engine.row import Row from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement +from databases.backends.common.records import Record, Row, create_column_maps from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ( ConnectionBackend, DatabaseBackend, - Record, + Record as RecordInterface, TransactionBackend, ) @@ -105,15 +105,18 @@ async def release(self) -> None: await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: + async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect + async with self._connection.cursor() as cursor: try: await cursor.execute(query_str, args) rows = await cursor.fetchall() metadata = CursorResultMetaData(context, cursor.description) - return [ + rows = [ Row( metadata, metadata._processors, @@ -123,12 +126,17 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: ) for row in rows ] + return [ + Record(row, result_columns, dialect, column_maps) for row in rows + ] finally: await cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect async with self._connection.cursor() as cursor: try: await cursor.execute(query_str, args) @@ -136,19 +144,20 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: if row is None: return None metadata = CursorResultMetaData(context, cursor.description) - return Row( + row = Row( metadata, metadata._processors, metadata._keymap, Row._default_key_style, row, ) + return Record(row, result_columns, dialect, column_maps) finally: await cursor.close() async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, _, _ = self._compile(query) async with self._connection.cursor() as cursor: try: await cursor.execute(query_str, args) @@ -163,7 +172,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: async with self._connection.cursor() as cursor: try: for single_query in queries: - single_query, args, context = self._compile(single_query) + single_query, args, _, _ = self._compile(single_query) await 
cursor.execute(single_query, args) finally: await cursor.close() @@ -172,36 +181,38 @@ async def iterate( self, query: ClauseElement ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect async with self._connection.cursor() as cursor: try: await cursor.execute(query_str, args) metadata = CursorResultMetaData(context, cursor.description) async for row in cursor: - yield Row( + record = Row( metadata, metadata._processors, metadata._keymap, Row._default_key_style, row, ) + yield Record(record, result_columns, dialect, column_maps) finally: await cursor.close() def transaction(self) -> TransactionBackend: return AsyncMyTransaction(self) - def _compile( - self, query: ClauseElement - ) -> typing.Tuple[str, dict, CompilationContext]: + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: compiled = query.compile( dialect=self._dialect, compile_kwargs={"render_postcompile": True} ) - execution_context = self._dialect.execution_ctx_cls() execution_context.dialect = self._dialect if not isinstance(query, DDLElement): + compiled_params = sorted(compiled.params.items()) + args = compiled.construct_params() for key, val in args.items(): if key in compiled._bind_processors: @@ -214,12 +225,23 @@ def _compile( compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) + + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) + } + compiled_query = compiled.string % mapping + result_map = compiled._result_columns + else: args = {} + result_map = None + compiled_query = compiled.string - query_message = compiled.string.replace(" \n", " ").replace("\n", " ") - logger.debug("Query: %s Args: %s", query_message, repr(args), extra=LOG_EXTRA) - return compiled.string, args, CompilationContext(execution_context) + query_message = compiled_query.replace(" \n", " ").replace("\n", " ") + logger.debug( + "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA + ) + return compiled.string, args, result_map, CompilationContext(execution_context) @property def raw_connection(self) -> asyncmy.connection.Connection: diff --git a/databases/backends/common/__init__.py b/databases/backends/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/databases/backends/common/records.py b/databases/backends/common/records.py new file mode 100644 index 00000000..77a4d8fa --- /dev/null +++ b/databases/backends/common/records.py @@ -0,0 +1,142 @@ +import json +import typing +from datetime import date, datetime + +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.engine.row import Row as SQLRow +from sqlalchemy.sql.compiler import _CompileLabel +from sqlalchemy.sql.schema import Column +from sqlalchemy.types import TypeEngine + +from databases.interfaces import Record as RecordInterface + +DIALECT_EXCLUDE = {"postgresql"} + + +class Record(RecordInterface): + __slots__ = ( + "_row", + "_result_columns", + "_dialect", + "_column_map", + "_column_map_int", + "_column_map_full", + ) + + def __init__( + self, + row: typing.Any, + result_columns: tuple, + dialect: Dialect, + column_maps: typing.Tuple[ + typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], + typing.Mapping[int, typing.Tuple[int, TypeEngine]], + typing.Mapping[str, typing.Tuple[int, TypeEngine]], + ], + ) 
-> None: + self._row = row + self._result_columns = result_columns + self._dialect = dialect + self._column_map, self._column_map_int, self._column_map_full = column_maps + + @property + def _mapping(self) -> typing.Mapping: + return self._row + + def keys(self) -> typing.KeysView: + return self._mapping.keys() + + def values(self) -> typing.ValuesView: + return self._mapping.values() + + def __getitem__(self, key: typing.Any) -> typing.Any: + if len(self._column_map) == 0: + return self._row[key] + elif isinstance(key, Column): + idx, datatype = self._column_map_full[str(key)] + elif isinstance(key, int): + idx, datatype = self._column_map_int[key] + else: + idx, datatype = self._column_map[key] + + raw = self._row[idx] + processor = datatype._cached_result_processor(self._dialect, None) + + if self._dialect.name not in DIALECT_EXCLUDE: + if isinstance(raw, dict): + raw = json.dumps(raw) + + if processor is not None and (not isinstance(raw, (datetime, date))): + return processor(raw) + return raw + + def __iter__(self) -> typing.Iterator: + return iter(self._row.keys()) + + def __len__(self) -> int: + return len(self._row) + + def __getattr__(self, name: str) -> typing.Any: + try: + return self.__getitem__(name) + except KeyError as e: + raise AttributeError(e.args[0]) from e + + +class Row(SQLRow): + def __getitem__(self, key: typing.Any) -> typing.Any: + """ + A SQLAlchemy Row exposes its fields through the Row._fields + tuple and its values through the Row._mapping accessor. + """ + if isinstance(key, int): + field = self._fields[key] + return self._mapping[field] + return self._mapping[key] + + def keys(self): + return self._mapping.keys() + + def values(self): + return self._mapping.values() + + def __getattr__(self, name: str) -> typing.Any: + try: + return self.__getitem__(name) + except KeyError as e: + raise AttributeError(e.args[0]) from e + + +def create_column_maps( + result_columns: typing.Any, +) -> typing.Tuple[ + typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], + typing.Mapping[int, typing.Tuple[int, TypeEngine]], + typing.Mapping[str, typing.Tuple[int, TypeEngine]], +]: + """ + Generate column -> datatype mappings from the column definitions. + + These mappings are used throughout the backend connection methods + to initialize Records. The underlying DB driver does not do type + conversion for us, so we have to wrap the returned rows. + + :return: Three mappings from different ways to address a column to \ + corresponding column indexes and datatypes: \ + 1. by column identifier; \ + 2. by column index; \ + 3. by column name in Column sqlalchemy objects. 
+ """ + column_map, column_map_int, column_map_full = {}, {}, {} + for idx, (column_name, _, column, datatype) in enumerate(result_columns): + column_map[column_name] = (idx, datatype) + column_map_int[idx] = (idx, datatype) + + # Added in SQLA 2.0 and _CompileLabels do not have _annotations + # When this happens, the mapping is on the second position + if isinstance(column[0], _CompileLabel): + column_map_full[str(column[2])] = (idx, datatype) + else: + column_map_full[str(column[0])] = (idx, datatype) + return column_map, column_map_int, column_map_full diff --git a/databases/backends/compilers/__init__.py b/databases/backends/compilers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/databases/backends/compilers/psycopg.py b/databases/backends/compilers/psycopg.py new file mode 100644 index 00000000..654c22a1 --- /dev/null +++ b/databases/backends/compilers/psycopg.py @@ -0,0 +1,17 @@ +from sqlalchemy.dialects.postgresql.psycopg import PGCompiler_psycopg + + +class APGCompiler_psycopg2(PGCompiler_psycopg): + def construct_params(self, *args, **kwargs): + pd = super().construct_params(*args, **kwargs) + + for column in self.prefetch: + pd[column.key] = self._exec_default(column.default) + + return pd + + def _exec_default(self, default): + if default.is_callable: + return default.arg(self.dialect) + else: + return default.arg diff --git a/databases/backends/dialects/__init__.py b/databases/backends/dialects/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/databases/backends/dialects/psycopg.py b/databases/backends/dialects/psycopg.py new file mode 100644 index 00000000..07bd1880 --- /dev/null +++ b/databases/backends/dialects/psycopg.py @@ -0,0 +1,46 @@ +""" +All the unique changes for the databases package +with the custom Numeric as the deprecated pypostgresql +for backwards compatibility and to make sure the +package can go to SQLAlchemy 2.0+. +""" + +import typing + +from sqlalchemy import types, util +from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext +from sqlalchemy.engine import processors +from sqlalchemy.types import Float, Numeric + + +class PGExecutionContext_psycopg(PGExecutionContext): + ... 
+ + +class PGNumeric(Numeric): + def bind_processor( + self, dialect: typing.Any + ) -> typing.Union[str, None]: # pragma: no cover + return processors.to_str + + def result_processor( + self, dialect: typing.Any, coltype: typing.Any + ) -> typing.Union[float, None]: # pragma: no cover + if self.asdecimal: + return None + else: + return processors.to_float + + +class PGDialect_psycopg(PGDialect): + colspecs = util.update_copy( + PGDialect.colspecs, + { + types.Numeric: PGNumeric, + types.Float: Float, + }, + ) + execution_ctx_cls = PGExecutionContext_psycopg + + +dialect = PGDialect_psycopg diff --git a/databases/backends/mssql.py b/databases/backends/mssql.py new file mode 100644 index 00000000..0764e774 --- /dev/null +++ b/databases/backends/mssql.py @@ -0,0 +1,313 @@ +import getpass +import logging +import typing +import uuid + +import aioodbc +from sqlalchemy.dialects.mssql import pyodbc +from sqlalchemy.engine.cursor import CursorResultMetaData +from sqlalchemy.engine.interfaces import Dialect, ExecutionContext +from sqlalchemy.sql import ClauseElement +from sqlalchemy.sql.ddl import DDLElement + +from databases.backends.common.records import Record, Row, create_column_maps +from databases.core import LOG_EXTRA, DatabaseURL +from databases.interfaces import ( + ConnectionBackend, + DatabaseBackend, + Record as RecordInterface, + TransactionBackend, +) + +logger = logging.getLogger("databases") + + +class MSSQLBackend(DatabaseBackend): + def __init__( + self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any + ) -> None: + self._database_url = DatabaseURL(database_url) + self._options = options + self._dialect = pyodbc.dialect(paramstyle="pyformat") + self._dialect.supports_native_decimal = True + self._pool: aioodbc.Pool = None + + def _get_connection_kwargs(self) -> dict: + url_options = self._database_url.options + + kwargs = {} + min_size = url_options.get("min_size") + max_size = url_options.get("max_size") + pool_recycle = url_options.get("pool_recycle") + ssl = url_options.get("ssl") + driver = url_options.get("driver") + timeout = url_options.get("connection_timeout", 30) + trusted_connection = url_options.get("trusted_connection", "no") + + assert driver is not None, "The driver must be specified" + + if min_size is not None: + kwargs["minsize"] = int(min_size) + if max_size is not None: + kwargs["maxsize"] = int(max_size) + if pool_recycle is not None: + kwargs["pool_recycle"] = int(pool_recycle) + if ssl is not None: + kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()] + + kwargs["trusted_connection"] = trusted_connection.lower() + kwargs["driver"] = driver + kwargs["timeout"] = timeout + + for key, value in self._options.items(): + # Coerce 'min_size' and 'max_size' for consistency. 
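+            # e.g. Database("mssql+aioodbc://host/db?min_size=1&max_size=20")
+            # or Database(url, min_size=1, max_size=20); both spellings are
+            # renamed here to the minsize/maxsize names that aioodbc expects.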
+ if key == "min_size": + key = "minsize" + elif key == "max_size": + key = "maxsize" + kwargs[key] = value + + return kwargs + + async def connect(self) -> None: + assert self._pool is None, "DatabaseBackend is already running" + kwargs = self._get_connection_kwargs() + + driver = kwargs["driver"] + database = self._database_url.database + hostname = self._database_url.hostname + port = self._database_url.port or 1433 + user = self._database_url.username or getpass.getuser() + password = self._database_url.password + timeout = kwargs.pop("timeout") + + # `port` always has a value here (it defaults to 1433 above) + dsn = f"Driver={driver};Database={database};Server={hostname},{port};UID={user};PWD={password};Connection+Timeout={timeout}" + + self._pool = await aioodbc.create_pool( + dsn=dsn, + autocommit=True, + **kwargs, + ) + + async def disconnect(self) -> None: + assert self._pool is not None, "DatabaseBackend is not running" + self._pool.close() + await self._pool.wait_closed() + self._pool = None + + def connection(self) -> "MSSQLConnection": + return MSSQLConnection(self, self._dialect) + + +class CompilationContext: + def __init__(self, context: ExecutionContext): + self.context = context + + +class MSSQLConnection(ConnectionBackend): + def __init__(self, database: MSSQLBackend, dialect: Dialect) -> None: + self._database = database + self._dialect = dialect + self._connection: typing.Optional[aioodbc.Connection] = None + + async def acquire(self) -> None: + assert self._connection is None, "Connection is already acquired" + assert self._database._pool is not None, "DatabaseBackend is not running" + self._connection = await self._database._pool.acquire() + + async def release(self) -> None: + assert self._connection is not None, "Connection is not acquired" + assert self._database._pool is not None, "DatabaseBackend is not running" + await self._database._pool.release(self._connection) + self._connection = None + + async def fetch_all(self, query: ClauseElement) -> typing.List["RecordInterface"]: + assert self._connection is not None, "Connection is not acquired" + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect + cursor = await self._connection.cursor() + try: + await cursor.execute(query_str, args) + rows = await cursor.fetchall() + metadata = CursorResultMetaData(context, cursor.description) + rows = [ + Row( + metadata, + metadata._processors, + metadata._keymap, + Row._default_key_style, + row, + ) + for row in rows + ] + return [Record(row, result_columns, dialect, column_maps) for row in rows] + finally: + await cursor.close() + + async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: + assert self._connection is not None, "Connection is not acquired" + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect + cursor = await self._connection.cursor() + try: + await cursor.execute(query_str, args) + row = await cursor.fetchone() + if row is None: + return None + metadata = CursorResultMetaData(context, cursor.description) + row = Row( + metadata, + metadata._processors, + metadata._keymap, + Row._default_key_style, + row, + ) + return Record(row, result_columns, dialect, column_maps) + finally: + await cursor.close() + + async def execute(self, query: ClauseElement) -> typing.Any: + assert 
self._connection is not None, "Connection is not acquired" + query_str, args, _, _ = self._compile(query) + cursor = await self._connection.cursor() + try: + values = await cursor.execute(query_str, args) + try: + values = await values.fetchone() + return values[0] + except Exception: + ... + finally: + await cursor.close() + + async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + assert self._connection is not None, "Connection is not acquired" + cursor = await self._connection.cursor() + try: + for single_query in queries: + single_query, args, _, _ = self._compile(single_query) + await cursor.execute(single_query, args) + finally: + await cursor.close() + + async def iterate( + self, query: ClauseElement + ) -> typing.AsyncGenerator[typing.Any, None]: + assert self._connection is not None, "Connection is not acquired" + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect + cursor = await self._connection.cursor() + try: + await cursor.execute(query_str, args) + metadata = CursorResultMetaData(context, cursor.description) + async for row in cursor: + record = Row( + metadata, + metadata._processors, + metadata._keymap, + Row._default_key_style, + row, + ) + yield Record(record, result_columns, dialect, column_maps) + finally: + await cursor.close() + + def transaction(self) -> TransactionBackend: + return MSSQLTransaction(self) + + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: + compiled = query.compile( + dialect=self._dialect, compile_kwargs={"render_postcompile": True} + ) + + execution_context = self._dialect.execution_ctx_cls() + execution_context.dialect = self._dialect + + if not isinstance(query, DDLElement): + compiled_params = compiled.params.items() + + mapping = {key: "?" 
for _, (key, _) in enumerate(compiled_params, start=1)} + compiled_query = compiled.string % mapping + + processors = compiled._bind_processors + args = [ + processors[key](val) if key in processors else val + for key, val in compiled_params + ] + + execution_context.result_column_struct = ( + compiled._result_columns, + compiled._ordered_columns, + compiled._textual_ordered_columns, + compiled._ad_hoc_textual, + compiled._loose_column_name_matching, + ) + + result_map = compiled._result_columns + else: + compiled_query = compiled.string + args = [] + result_map = None + + query_message = compiled_query.replace(" \n", " ").replace("\n", " ") + logger.debug( + "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA + ) + return compiled_query, args, result_map, CompilationContext(execution_context) + + @property + def raw_connection(self) -> aioodbc.connection.Connection: + assert self._connection is not None, "Connection is not acquired" + return self._connection + + +class MSSQLTransaction(TransactionBackend): + def __init__(self, connection: MSSQLConnection): + self._connection = connection + self._is_root = False + self._savepoint_name = "" + + async def start( + self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any] + ) -> None: + assert self._connection._connection is not None, "Connection is not acquired" + self._is_root = is_root + cursor = await self._connection._connection.cursor() + try: + if self._is_root: + await cursor.execute("BEGIN TRANSACTION") + else: + id = str(uuid.uuid4()).replace("-", "_")[:12] + self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}" + await cursor.execute(f"SAVE TRANSACTION {self._savepoint_name}") + finally: + await cursor.close() + + async def commit(self) -> None: + assert self._connection._connection is not None, "Connection is not acquired" + cursor = await self._connection._connection.cursor() + try: + if self._is_root: + await cursor.execute("COMMIT TRANSACTION") + else: + await cursor.execute(f"COMMIT TRANSACTION {self._savepoint_name}") + finally: + await cursor.close() + + async def rollback(self) -> None: + assert self._connection._connection is not None, "Connection is not acquired" + cursor = await self._connection._connection.cursor() + try: + if self._is_root: + await cursor.execute("ROLLBACK TRANSACTION") + else: + await cursor.execute(f"ROLLBACK TRANSACTION {self._savepoint_name}") + finally: + await cursor.close() diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index 6b86042f..d93b4a7f 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -7,15 +7,15 @@ from sqlalchemy.dialects.mysql import pymysql from sqlalchemy.engine.cursor import CursorResultMetaData from sqlalchemy.engine.interfaces import Dialect, ExecutionContext -from sqlalchemy.engine.row import Row from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement +from databases.backends.common.records import Record, Row, create_column_maps from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ( ConnectionBackend, DatabaseBackend, - Record, + Record as RecordInterface, TransactionBackend, ) @@ -105,15 +105,17 @@ async def release(self) -> None: await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: + async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: assert self._connection is 
not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) rows = await cursor.fetchall() metadata = CursorResultMetaData(context, cursor.description) - return [ + rows = [ Row( metadata, metadata._processors, @@ -123,12 +125,15 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: ) for row in rows ] + return [Record(row, result_columns, dialect, column_maps) for row in rows] finally: await cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) @@ -136,19 +141,20 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: if row is None: return None metadata = CursorResultMetaData(context, cursor.description) - return Row( + row = Row( metadata, metadata._processors, metadata._keymap, Row._default_key_style, row, ) + return Record(row, result_columns, dialect, column_maps) finally: await cursor.close() async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, _, _ = self._compile(query) cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) @@ -163,7 +169,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: cursor = await self._connection.cursor() try: for single_query in queries: - single_query, args, context = self._compile(single_query) + single_query, args, _, _ = self._compile(single_query) await cursor.execute(single_query, args) finally: await cursor.close() @@ -172,36 +178,38 @@ async def iterate( self, query: ClauseElement ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) metadata = CursorResultMetaData(context, cursor.description) async for row in cursor: - yield Row( + record = Row( metadata, metadata._processors, metadata._keymap, Row._default_key_style, row, ) + yield Record(record, result_columns, dialect, column_maps) finally: await cursor.close() def transaction(self) -> TransactionBackend: return MySQLTransaction(self) - def _compile( - self, query: ClauseElement - ) -> typing.Tuple[str, dict, CompilationContext]: + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: compiled = query.compile( dialect=self._dialect, compile_kwargs={"render_postcompile": True} ) - execution_context = self._dialect.execution_ctx_cls() execution_context.dialect = self._dialect if not isinstance(query, DDLElement): + compiled_params = sorted(compiled.params.items()) + args = compiled.construct_params() for key, val in 
args.items(): if key in compiled._bind_processors: @@ -214,12 +222,23 @@ def _compile( compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) + + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) + } + compiled_query = compiled.string % mapping + result_map = compiled._result_columns + else: args = {} + result_map = None + compiled_query = compiled.string - query_message = compiled.string.replace(" \n", " ").replace("\n", " ") - logger.debug("Query: %s Args: %s", query_message, repr(args), extra=LOG_EXTRA) - return compiled.string, args, CompilationContext(execution_context) + query_message = compiled_query.replace(" \n", " ").replace("\n", " ") + logger.debug( + "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA + ) + return compiled.string, args, result_map, CompilationContext(execution_context) @property def raw_connection(self) -> aiomysql.connection.Connection: diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index e30c12d7..917d4e1f 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -2,13 +2,12 @@ import typing import asyncpg -from sqlalchemy.dialects.postgresql import pypostgresql from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement -from sqlalchemy.sql.schema import Column -from sqlalchemy.types import TypeEngine +from databases.backends.common.records import Record, create_column_maps +from databases.backends.dialects.psycopg import dialect as psycopg_dialect from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ( ConnectionBackend, @@ -30,7 +29,7 @@ def __init__( self._pool = None def _get_dialect(self) -> Dialect: - dialect = pypostgresql.dialect(paramstyle="pyformat") + dialect = psycopg_dialect(paramstyle="pyformat") dialect.implicit_returning = True dialect.supports_native_enum = True @@ -82,82 +81,6 @@ def connection(self) -> "PostgresConnection": return PostgresConnection(self, self._dialect) -class Record(RecordInterface): - __slots__ = ( - "_row", - "_result_columns", - "_dialect", - "_column_map", - "_column_map_int", - "_column_map_full", - ) - - def __init__( - self, - row: asyncpg.Record, - result_columns: tuple, - dialect: Dialect, - column_maps: typing.Tuple[ - typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], - typing.Mapping[int, typing.Tuple[int, TypeEngine]], - typing.Mapping[str, typing.Tuple[int, TypeEngine]], - ], - ) -> None: - self._row = row - self._result_columns = result_columns - self._dialect = dialect - self._column_map, self._column_map_int, self._column_map_full = column_maps - - @property - def _mapping(self) -> typing.Mapping: - return self._row - - def keys(self) -> typing.KeysView: - import warnings - - warnings.warn( - "The `Row.keys()` method is deprecated to mimic SQLAlchemy behaviour, " - "use `Row._mapping.keys()` instead.", - DeprecationWarning, - ) - return self._mapping.keys() - - def values(self) -> typing.ValuesView: - import warnings - - warnings.warn( - "The `Row.values()` method is deprecated to mimic SQLAlchemy behaviour, " - "use `Row._mapping.values()` instead.", - DeprecationWarning, - ) - return self._mapping.values() - - def __getitem__(self, key: typing.Any) -> typing.Any: - if len(self._column_map) == 0: # raw query - return self._row[key] - elif isinstance(key, Column): - idx, datatype = self._column_map_full[str(key)] - elif isinstance(key, int): - idx, datatype = 
self._column_map_int[key] - else: - idx, datatype = self._column_map[key] - raw = self._row[idx] - processor = datatype._cached_result_processor(self._dialect, None) - - if processor is not None: - return processor(raw) - return raw - - def __iter__(self) -> typing.Iterator: - return iter(self._row.keys()) - - def __len__(self) -> int: - return len(self._row) - - def __getattr__(self, name: str) -> typing.Any: - return self._mapping.get(name) - - class PostgresConnection(ConnectionBackend): def __init__(self, database: PostgresBackend, dialect: Dialect): self._database = database @@ -180,7 +103,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: query_str, args, result_columns = self._compile(query) rows = await self._connection.fetch(query_str, *args) dialect = self._dialect - column_maps = self._create_column_maps(result_columns) + column_maps = create_column_maps(result_columns) return [Record(row, result_columns, dialect, column_maps) for row in rows] async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: @@ -193,7 +116,7 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterfa row, result_columns, self._dialect, - self._create_column_maps(result_columns), + create_column_maps(result_columns), ) async def fetch_val( @@ -213,7 +136,7 @@ async def fetch_val( async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, result_columns = self._compile(query) + query_str, args, _ = self._compile(query) return await self._connection.fetchval(query_str, *args) async def execute_many(self, queries: typing.List[ClauseElement]) -> None: @@ -222,7 +145,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: # loop through multiple executes here, which should all end up # using the same prepared statement. for single_query in queries: - single_query, args, result_columns = self._compile(single_query) + single_query, args, _ = self._compile(single_query) await self._connection.execute(single_query, *args) async def iterate( @@ -230,7 +153,7 @@ async def iterate( ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns = self._compile(query) - column_maps = self._create_column_maps(result_columns) + column_maps = create_column_maps(result_columns) async for row in self._connection.cursor(query_str, *args): yield Record(row, result_columns, self._dialect, column_maps) @@ -255,7 +178,6 @@ def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: processors[key](val) if key in processors else val for key, val in compiled_params ] - result_map = compiled._result_columns else: compiled_query = compiled.string @@ -268,34 +190,6 @@ def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: ) return compiled_query, args, result_map - @staticmethod - def _create_column_maps( - result_columns: tuple, - ) -> typing.Tuple[ - typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], - typing.Mapping[int, typing.Tuple[int, TypeEngine]], - typing.Mapping[str, typing.Tuple[int, TypeEngine]], - ]: - """ - Generate column -> datatype mappings from the column definitions. - - These mappings are used throughout PostgresConnection methods - to initialize Record-s. The underlying DB driver does not do type - conversion for us so we have wrap the returned asyncpg.Record-s. 
- - :return: Three mappings from different ways to address a column to \ - corresponding column indexes and datatypes: \ - 1. by column identifier; \ - 2. by column index; \ - 3. by column name in Column sqlalchemy objects. - """ - column_map, column_map_int, column_map_full = {}, {}, {} - for idx, (column_name, _, column, datatype) in enumerate(result_columns): - column_map[column_name] = (idx, datatype) - column_map_int[idx] = (idx, datatype) - column_map_full[str(column[0])] = (idx, datatype) - return column_map, column_map_int, column_map_full - @property def raw_connection(self) -> asyncpg.connection.Connection: assert self._connection is not None, "Connection is not acquired" diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index 19464627..f0732d6f 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -6,17 +6,12 @@ from sqlalchemy.dialects.sqlite import pysqlite from sqlalchemy.engine.cursor import CursorResultMetaData from sqlalchemy.engine.interfaces import Dialect, ExecutionContext -from sqlalchemy.engine.row import Row from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement +from databases.backends.common.records import Record, Row, create_column_maps from databases.core import LOG_EXTRA, DatabaseURL -from databases.interfaces import ( - ConnectionBackend, - DatabaseBackend, - Record, - TransactionBackend, -) +from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend logger = logging.getLogger("databases") @@ -33,23 +28,10 @@ def __init__( self._pool = SQLitePool(self._database_url, **self._options) async def connect(self) -> None: - pass - # assert self._pool is None, "DatabaseBackend is already running" - # self._pool = await aiomysql.create_pool( - # host=self._database_url.hostname, - # port=self._database_url.port or 3306, - # user=self._database_url.username or getpass.getuser(), - # password=self._database_url.password, - # db=self._database_url.database, - # autocommit=True, - # ) + ... async def disconnect(self) -> None: - pass - # assert self._pool is not None, "DatabaseBackend is not running" - # self._pool.close() - # await self._pool.wait_closed() - # self._pool = None + ... 
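Since every backend, including this SQLite one, now hands rows back through the shared `Record` wrapper, a minimal usage sketch may help; the database URL, table, and data are hypothetical, and the table is assumed to already exist:

```python
# Sketch: Records support lookup by column name, by integer index, and by
# SQLAlchemy Column object, all resolved through create_column_maps().
import asyncio

import sqlalchemy
from databases import Database

metadata = sqlalchemy.MetaData()
notes = sqlalchemy.Table(
    "notes",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("text", sqlalchemy.String(100)),
)


async def main() -> None:
    database = Database("sqlite+aiosqlite:///./example.db")
    await database.connect()
    row = await database.fetch_one(notes.select())
    if row is not None:
        # the three addressing styles all resolve to the same column
        assert row["text"] == row[1] == row[notes.c.text]
    await database.disconnect()


asyncio.run(main())
```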
def connection(self) -> "SQLiteConnection": return SQLiteConnection(self._pool, self._dialect) @@ -93,12 +75,14 @@ async def release(self) -> None: async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect async with self._connection.execute(query_str, args) as cursor: rows = await cursor.fetchall() metadata = CursorResultMetaData(context, cursor.description) - return [ + rows = [ Row( metadata, metadata._processors, @@ -108,27 +92,31 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: ) for row in rows ] + return [Record(row, result_columns, dialect, column_maps) for row in rows] async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect async with self._connection.execute(query_str, args) as cursor: row = await cursor.fetchone() if row is None: return None metadata = CursorResultMetaData(context, cursor.description) - return Row( + row = Row( metadata, metadata._processors, metadata._keymap, Row._default_key_style, row, ) + return Record(row, result_columns, dialect, column_maps) async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) async with self._connection.cursor() as cursor: await cursor.execute(query_str, args) if cursor.lastrowid == 0: @@ -144,34 +132,38 @@ async def iterate( self, query: ClauseElement ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect + async with self._connection.execute(query_str, args) as cursor: metadata = CursorResultMetaData(context, cursor.description) async for row in cursor: - yield Row( + record = Row( metadata, metadata._processors, metadata._keymap, Row._default_key_style, row, ) + yield Record(record, result_columns, dialect, column_maps) def transaction(self) -> TransactionBackend: return SQLiteTransaction(self) - def _compile( - self, query: ClauseElement - ) -> typing.Tuple[str, list, CompilationContext]: + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: compiled = query.compile( dialect=self._dialect, compile_kwargs={"render_postcompile": True} ) - execution_context = self._dialect.execution_ctx_cls() execution_context.dialect = self._dialect args = [] + result_map = None if not isinstance(query, DDLElement): + compiled_params = sorted(compiled.params.items()) + params = compiled.construct_params() for key in compiled.positiontup: raw_val = params[key] @@ -189,11 +181,20 @@ def _compile( compiled._loose_column_name_matching, ) - query_message = compiled.string.replace(" \n", " ").replace("\n", " ") + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) + } + compiled_query = compiled.string 
% mapping + result_map = compiled._result_columns + + else: + compiled_query = compiled.string + + query_message = compiled_query.replace(" \n", " ").replace("\n", " ") logger.debug( "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA ) - return compiled.string, args, CompilationContext(execution_context) + return compiled.string, args, result_map, CompilationContext(execution_context) @property def raw_connection(self) -> aiosqlite.core.Connection: diff --git a/databases/core.py b/databases/core.py index 8394ab5c..5931121b 100644 --- a/databases/core.py +++ b/databases/core.py @@ -42,6 +42,9 @@ class Database: "postgres": "databases.backends.postgres:PostgresBackend", "mysql": "databases.backends.mysql:MySQLBackend", "mysql+asyncmy": "databases.backends.asyncmy:AsyncMyBackend", + "mssql": "databases.backends.mssql:MSSQLBackend", + "mssql+pyodbc": "databases.backends.mssql:MSSQLBackend", + "mssql+aioodbc": "databases.backends.mssql:MSSQLBackend", "sqlite": "databases.backends.sqlite:SQLiteBackend", } @@ -326,7 +329,7 @@ def _build_query( return query.bindparams(**values) if values is not None else query elif values: - return query.values(**values) + return query.values(**values) # type: ignore return query diff --git a/docs/index.md b/docs/index.md index b18de817..08d8779b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -34,6 +34,7 @@ Database drivers supported are: * [aiomysql][aiomysql] * [asyncmy][asyncmy] * [aiosqlite][aiosqlite] +* [aioodbc][aioodbc] You can install the required database drivers with: @@ -43,9 +44,10 @@ $ pip install databases[aiopg] $ pip install databases[aiomysql] $ pip install databases[asyncmy] $ pip install databases[aiosqlite] +$ pip install databases[aioodbc] ``` -Note that if you are using any synchronous SQLAlchemy functions such as `engine.create_all()` or [alembic][alembic] migrations then you still have to install a synchronous DB driver: [psycopg2][psycopg2] for PostgreSQL and [pymysql][pymysql] for MySQL. +Note that if you are using any synchronous SQLAlchemy functions such as `engine.create_all()` or [alembic][alembic] migrations then you still have to install a synchronous DB driver: [psycopg2][psycopg2] for PostgreSQL, [pymysql][pymysql] for MySQL and [pyodbc][pyodbc] for SQL Server. --- @@ -83,7 +85,7 @@ values = [ ] await database.execute_many(query=query, values=values) -# Run a database query. +# Run a database query. 
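# Creating the table above still goes through a synchronous engine, e.g.
#   engine = sqlalchemy.create_engine("sqlite:///example.db")  # placeholder URL
#   metadata.create_all(engine)  # assumes your tables live on a MetaData object
# backed by one of the synchronous drivers noted above (psycopg2, pymysql, pyodbc).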
query = "SELECT * FROM HighScores" rows = await database.fetch_all(query=query) print('High Scores:', rows) @@ -101,11 +103,13 @@ for examples of how to start using databases together with SQLAlchemy core expre [alembic]: https://alembic.sqlalchemy.org/en/latest/ [psycopg2]: https://www.psycopg.org/ [pymysql]: https://github.com/PyMySQL/PyMySQL +[pyodbc]: https://github.com/mkleehammer/pyodbc [asyncpg]: https://github.com/MagicStack/asyncpg [aiopg]: https://github.com/aio-libs/aiopg [aiomysql]: https://github.com/aio-libs/aiomysql [asyncmy]: https://github.com/long2ice/asyncmy [aiosqlite]: https://github.com/omnilib/aiosqlite +[aioodbc]: https://aioodbc.readthedocs.io/en/latest/ [starlette]: https://github.com/encode/starlette [sanic]: https://github.com/huge-success/sanic @@ -113,4 +117,4 @@ for examples of how to start using databases together with SQLAlchemy core expre [quart]: https://gitlab.com/pgjones/quart [aiohttp]: https://github.com/aio-libs/aiohttp [tornado]: https://github.com/tornadoweb/tornado -[fastapi]: https://github.com/tiangolo/fastapi +[fastapi]: https://github.com/tiangolo/fastapi \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 0699d3cc..34ff1638 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,17 @@ -e . # Async database drivers -asyncmy==0.2.5 +asyncmy==0.2.7 +aiopg==1.4.0 aiomysql==0.1.1 -aiopg==1.3.4 aiosqlite==0.17.0 asyncpg==0.26.0 +aioodbc==0.4.0 # Sync database drivers for standard tooling around setup/teardown/migrations. -psycopg2-binary==2.9.3 +psycopg2-binary==2.9.5 pymysql==1.0.2 +pyodbc==4.0.35 # Testing autoflake==1.4 diff --git a/scripts/clean b/scripts/clean index f01cc831..d7388629 100755 --- a/scripts/clean +++ b/scripts/clean @@ -9,6 +9,12 @@ fi if [ -d 'databases.egg-info' ] ; then rm -r databases.egg-info fi +if [ -d '.mypy_cache' ] ; then + rm -r .mypy_cache +fi +if [ -d '.pytest_cache' ] ; then + rm -r .pytest_cache +fi find databases -type f -name "*.py[co]" -delete find databases -type d -name __pycache__ -delete diff --git a/setup.cfg b/setup.cfg index da1831fd..b4182c83 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,6 +2,11 @@ disallow_untyped_defs = True ignore_missing_imports = True no_implicit_optional = True +disallow_any_generics = false +disallow_untyped_decorators = true +implicit_reexport = true +disallow_incomplete_defs = true +exclude = databases/backends [tool:isort] profile = black diff --git a/setup.py b/setup.py index 3725cab9..cb76540b 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def get_packages(package): author_email="tom@tomchristie.com", packages=get_packages("databases"), package_data={"databases": ["py.typed"]}, - install_requires=["sqlalchemy>=1.4.42,<1.5"], + install_requires=["sqlalchemy>=2.0.7"], extras_require={ "postgresql": ["asyncpg"], "asyncpg": ["asyncpg"], @@ -57,6 +57,7 @@ def get_packages(package): "asyncmy": ["asyncmy"], "sqlite": ["aiosqlite"], "aiosqlite": ["aiosqlite"], + "aioodbc": ["aioodbc"], }, classifiers=[ "Development Status :: 3 - Alpha", @@ -70,6 +71,7 @@ def get_packages(package): "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3 :: Only", ], zip_safe=False, diff --git a/tests/test_connection_options.py b/tests/test_connection_options.py index e6fe6849..8fdc0d24 100644 --- a/tests/test_connection_options.py +++ b/tests/test_connection_options.py @@ -6,6 +6,8 @@ import pytest from 
databases.backends.aiopg import AiopgBackend +from databases.backends.mssql import MSSQLBackend from databases.backends.postgres import PostgresBackend from databases.core import DatabaseURL from tests.test_databases import DATABASE_URLS, async_adapter @@ -168,3 +170,38 @@ def test_aiopg_explicit_ssl(): backend = AiopgBackend("postgresql+aiopg://localhost/database", ssl=True) kwargs = backend._get_connection_kwargs() assert kwargs == {"ssl": True} + + +@pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") +def test_mssql_pool_size(): + backend = MSSQLBackend( + "mssql+pyodbc://localhost/database?driver=SQL+Server&min_size=1&max_size=20" + ) + kwargs = backend._get_connection_kwargs() + assert kwargs["minsize"] == 1 + assert kwargs["maxsize"] == 20 + + +@pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") +def test_mssql_explicit_pool_size(): + backend = MSSQLBackend( + "mssql+pyodbc://localhost/database?driver=SQL+Server", min_size=1, max_size=20 + ) + kwargs = backend._get_connection_kwargs() + assert kwargs["minsize"] == 1 + assert kwargs["maxsize"] == 20 + + +@pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") +def test_mssql_ssl(): + backend = MSSQLBackend("mssql+pyodbc://localhost/database?driver=SQL+Server&ssl=true") + kwargs = backend._get_connection_kwargs() + assert kwargs["ssl"] is True + + +@pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") +def test_mssql_explicit_ssl(): + backend = MSSQLBackend("mssql+pyodbc://localhost/database?driver=SQL+Server", ssl=True) + kwargs = backend._get_connection_kwargs() + assert kwargs["ssl"] is True + + +@pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") +def test_mssql_pool_recycle(): + backend = MSSQLBackend("mssql+pyodbc://localhost/database?driver=SQL+Server&pool_recycle=20") + kwargs = backend._get_connection_kwargs() + assert kwargs["pool_recycle"] == 20 diff --git a/tests/test_databases.py b/tests/test_databases.py index a7545e31..fb7d0ead 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -3,7 +3,6 @@ import decimal import functools import os -import re from unittest.mock import MagicMock, patch import pytest @@ -89,6 +88,8 @@ def create_test_database(): "postgresql+aiopg", "sqlite+aiosqlite", "postgresql+asyncpg", + "mssql+pyodbc", + "mssql+aioodbc", ]: url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) @@ -106,6 +107,8 @@ def create_test_database(): "postgresql+aiopg", "sqlite+aiosqlite", "postgresql+asyncpg", + "mssql+pyodbc", + "mssql+aioodbc", ]: url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) @@ -167,24 +170,24 @@ async def test_queries(database_url): assert result["completed"] == True # fetch_val() - query = sqlalchemy.sql.select([notes.c.text]) + query = sqlalchemy.sql.select(*[notes.c.text]) result = await database.fetch_val(query=query) assert result == "example1" # fetch_val() with no rows - query = sqlalchemy.sql.select([notes.c.text]).where( + query = sqlalchemy.sql.select(*[notes.c.text]).where( notes.c.text == "impossible" ) result = await database.fetch_val(query=query) assert result is None # fetch_val() with a different column - query = sqlalchemy.sql.select([notes.c.id, notes.c.text]) + query = sqlalchemy.sql.select(*[notes.c.id, notes.c.text]) result = await database.fetch_val(query=query, column=1) assert result == "example1" # row access (needed to maintain test coverage for Record.__getitem__ in postgres backend) - query = 
diff --git a/tests/test_databases.py b/tests/test_databases.py
index a7545e31..fb7d0ead 100644
--- a/tests/test_databases.py
+++ b/tests/test_databases.py
@@ -3,7 +3,6 @@
 import decimal
 import functools
 import os
-import re
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -89,6 +88,8 @@ def create_test_database():
         "postgresql+aiopg",
         "sqlite+aiosqlite",
         "postgresql+asyncpg",
+        "mssql+pyodbc",
+        "mssql+aioodbc",
     ]:
         url = str(database_url.replace(driver=None))
         engine = sqlalchemy.create_engine(url)
@@ -106,6 +107,8 @@ def create_test_database():
         "postgresql+aiopg",
         "sqlite+aiosqlite",
         "postgresql+asyncpg",
+        "mssql+pyodbc",
+        "mssql+aioodbc",
     ]:
         url = str(database_url.replace(driver=None))
         engine = sqlalchemy.create_engine(url)
@@ -167,24 +170,24 @@ async def test_queries(database_url):
     assert result["completed"] == True
 
     # fetch_val()
-    query = sqlalchemy.sql.select([notes.c.text])
+    query = sqlalchemy.sql.select(*[notes.c.text])
     result = await database.fetch_val(query=query)
     assert result == "example1"
 
     # fetch_val() with no rows
-    query = sqlalchemy.sql.select([notes.c.text]).where(
+    query = sqlalchemy.sql.select(*[notes.c.text]).where(
         notes.c.text == "impossible"
     )
     result = await database.fetch_val(query=query)
     assert result is None
 
     # fetch_val() with a different column
-    query = sqlalchemy.sql.select([notes.c.id, notes.c.text])
+    query = sqlalchemy.sql.select(*[notes.c.id, notes.c.text])
     result = await database.fetch_val(query=query, column=1)
     assert result == "example1"
 
     # row access (needed to maintain test coverage for Record.__getitem__ in postgres backend)
-    query = sqlalchemy.sql.select([notes.c.text])
+    query = sqlalchemy.sql.select(*[notes.c.text])
     result = await database.fetch_one(query=query)
     assert result["text"] == "example1"
     assert result[0] == "example1"
@@ -244,6 +247,7 @@ async def test_queries_raw(database_url):
     query = "SELECT completed FROM notes WHERE text = :text"
     result = await database.fetch_val(query=query, values={"text": "example1"})
     assert result == True
+
     query = "SELECT * FROM notes WHERE text = :text"
     result = await database.fetch_val(
         query=query, values={"text": "example1"}, column="completed"
@@ -354,7 +358,7 @@ async def test_results_support_column_reference(database_url):
     await database.execute(query, values)
 
     # fetch_all()
-    query = sqlalchemy.select([articles, custom_date])
+    query = sqlalchemy.select(*[articles, custom_date])
     results = await database.fetch_all(query=query)
     assert len(results) == 1
     assert results[0][articles.c.title] == "Hello, world Article"
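The select(*[...]) edits in the hunks above are the mechanical fallout of SQLAlchemy 2.0 removing the 1.x list-style select(); unpacking the old list keeps each call site a one-character change. A small illustration, with a table definition invented for the example:

# SQLAlchemy 2.0 removed select([col, ...]); columns are now passed positionally.
import sqlalchemy

metadata = sqlalchemy.MetaData()
notes = sqlalchemy.Table(
    "notes",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("text", sqlalchemy.String(length=100)),
)

# query = sqlalchemy.select([notes.c.id, notes.c.text])  # 1.x only, removed in 2.0
query = sqlalchemy.select(notes.c.id, notes.c.text)      # 2.0 style
assert str(query).startswith("SELECT notes.id, notes.text")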
"mssql+aioodbc"]: + cursor = await raw_connection.cursor() + try: + await cursor.execute(select_query) + result = await cursor.fetchone() + finally: + await cursor.close() else: cursor = await raw_connection.cursor() await cursor.execute(select_query) @@ -1016,7 +1045,10 @@ async def test_iterate_outside_transaction_with_values(database_url): pytest.skip("MySQL does not support `FROM (VALUES ...)` (F641)") async with Database(database_url) as database: - query = "SELECT * FROM (VALUES (1), (2), (3), (4), (5)) as t" + if database_url.dialect == "mssql": + query = "SELECT * FROM (VALUES (1), (2), (3), (4), (5)) as X(t)" + else: + query = "SELECT * FROM (VALUES (1), (2), (3), (4), (5)) as t" iterate_results = [] async for result in database.iterate(query=query): @@ -1038,13 +1070,24 @@ async def test_iterate_outside_transaction_with_temp_table(database_url): pytest.skip("SQLite interface does not work with temporary tables.") async with Database(database_url) as database: - query = "CREATE TEMPORARY TABLE no_transac(num INTEGER)" - await database.execute(query) + if database_url.dialect == "mssql": + query = "CREATE TABLE ##no_transac(num INTEGER)" + await database.execute(query) - query = "INSERT INTO no_transac(num) VALUES (1), (2), (3), (4), (5)" - await database.execute(query) + query = "INSERT INTO ##no_transac VALUES (1), (2), (3), (4), (5)" + await database.execute(query) + + query = "SELECT * FROM ##no_transac" + + else: + query = "CREATE TEMPORARY TABLE no_transac(num INTEGER)" + await database.execute(query) + + query = "INSERT INTO no_transac(num) VALUES (1), (2), (3), (4), (5)" + await database.execute(query) + + query = "SELECT * FROM no_transac" - query = "SELECT * FROM no_transac" iterate_results = [] async for result in database.iterate(query=query): @@ -1075,52 +1118,6 @@ async def test_column_names(database_url, select_query): assert results[0]["completed"] == True -@pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter -async def test_posgres_interface(database_url): - """ - Since SQLAlchemy 1.4, `Row.values()` is removed and `Row.keys()` is deprecated. - Custom postgres interface mimics more or less this behaviour by deprecating those - two methods - """ - database_url = DatabaseURL(database_url) - - if database_url.scheme not in ["postgresql", "postgresql+asyncpg"]: - pytest.skip("Test is only for asyncpg") - - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - query = notes.insert() - values = {"text": "example1", "completed": True} - await database.execute(query, values) - - query = notes.select() - result = await database.fetch_one(query=query) - - with pytest.warns( - DeprecationWarning, - match=re.escape( - "The `Row.keys()` method is deprecated to mimic SQLAlchemy behaviour, " - "use `Row._mapping.keys()` instead." - ), - ): - assert ( - list(result.keys()) - == [k for k in result] - == ["id", "text", "completed"] - ) - - with pytest.warns( - DeprecationWarning, - match=re.escape( - "The `Row.values()` method is deprecated to mimic SQLAlchemy behaviour, " - "use `Row._mapping.values()` instead." 
@@ -1016,7 +1045,10 @@ async def test_iterate_outside_transaction_with_values(database_url):
         pytest.skip("MySQL does not support `FROM (VALUES ...)` (F641)")
 
     async with Database(database_url) as database:
-        query = "SELECT * FROM (VALUES (1), (2), (3), (4), (5)) as t"
+        if database_url.dialect == "mssql":
+            query = "SELECT * FROM (VALUES (1), (2), (3), (4), (5)) as X(t)"
+        else:
+            query = "SELECT * FROM (VALUES (1), (2), (3), (4), (5)) as t"
 
         iterate_results = []
         async for result in database.iterate(query=query):
@@ -1038,13 +1070,24 @@ async def test_iterate_outside_transaction_with_temp_table(database_url):
         pytest.skip("SQLite interface does not work with temporary tables.")
 
     async with Database(database_url) as database:
-        query = "CREATE TEMPORARY TABLE no_transac(num INTEGER)"
-        await database.execute(query)
+        if database_url.dialect == "mssql":
+            query = "CREATE TABLE ##no_transac(num INTEGER)"
+            await database.execute(query)
 
-        query = "INSERT INTO no_transac(num) VALUES (1), (2), (3), (4), (5)"
-        await database.execute(query)
+            query = "INSERT INTO ##no_transac VALUES (1), (2), (3), (4), (5)"
+            await database.execute(query)
+
+            query = "SELECT * FROM ##no_transac"
+
+        else:
+            query = "CREATE TEMPORARY TABLE no_transac(num INTEGER)"
+            await database.execute(query)
+
+            query = "INSERT INTO no_transac(num) VALUES (1), (2), (3), (4), (5)"
+            await database.execute(query)
+
+            query = "SELECT * FROM no_transac"
 
-        query = "SELECT * FROM no_transac"
         iterate_results = []
         async for result in database.iterate(query=query):
@@ -1075,52 +1118,6 @@ async def test_column_names(database_url, select_query):
     assert results[0]["completed"] == True
 
 
-@pytest.mark.parametrize("database_url", DATABASE_URLS)
-@async_adapter
-async def test_posgres_interface(database_url):
-    """
-    Since SQLAlchemy 1.4, `Row.values()` is removed and `Row.keys()` is deprecated.
-    Custom postgres interface mimics more or less this behaviour by deprecating those
-    two methods
-    """
-    database_url = DatabaseURL(database_url)
-
-    if database_url.scheme not in ["postgresql", "postgresql+asyncpg"]:
-        pytest.skip("Test is only for asyncpg")
-
-    async with Database(database_url) as database:
-        async with database.transaction(force_rollback=True):
-            query = notes.insert()
-            values = {"text": "example1", "completed": True}
-            await database.execute(query, values)
-
-            query = notes.select()
-            result = await database.fetch_one(query=query)
-
-            with pytest.warns(
-                DeprecationWarning,
-                match=re.escape(
-                    "The `Row.keys()` method is deprecated to mimic SQLAlchemy behaviour, "
-                    "use `Row._mapping.keys()` instead."
-                ),
-            ):
-                assert (
-                    list(result.keys())
-                    == [k for k in result]
-                    == ["id", "text", "completed"]
-                )
-
-            with pytest.warns(
-                DeprecationWarning,
-                match=re.escape(
-                    "The `Row.values()` method is deprecated to mimic SQLAlchemy behaviour, "
-                    "use `Row._mapping.values()` instead."
-                ),
-            ):
-                # avoid checking `id` at index 0 since it may change depending on the launched tests
-                assert list(result.values())[1:] == ["example1", True]
-
-
 @pytest.mark.parametrize("database_url", DATABASE_URLS)
 @async_adapter
 async def test_postcompile_queries(database_url):
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 139f8ffe..25a47287 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -31,6 +31,13 @@ def create_test_database():
         "postgresql+asyncpg",
     ]:
         url = str(database_url.replace(driver=None))
+    elif database_url.scheme in [
+        "mssql",
+        "mssql+pyodbc",
+        "mssql+aioodbc",
+        "mssql+pymssql",
+    ]:
+        url = str(database_url.replace(driver="pyodbc"))
     engine = sqlalchemy.create_engine(url)
     metadata.create_all(engine)
 
@@ -47,6 +54,13 @@ def create_test_database():
         "postgresql+asyncpg",
     ]:
         url = str(database_url.replace(driver=None))
+    elif database_url.scheme in [
+        "mssql",
+        "mssql+pyodbc",
+        "mssql+aioodbc",
+        "mssql+pymssql",
+    ]:
+        url = str(database_url.replace(driver="pyodbc"))
     engine = sqlalchemy.create_engine(url)
     metadata.drop_all(engine)
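The replace(driver="pyodbc") calls above are needed because schema setup and teardown go through SQLAlchemy's synchronous create_engine(), which cannot load the async aioodbc driver. A sketch of the rewrite, with an invented URL; it requires the pyodbc package but no live server, since engine creation is lazy:

# Swap the async driver for its sync counterpart before building a sync engine.
import sqlalchemy

from databases import DatabaseURL

async_url = DatabaseURL(
    "mssql+aioodbc://sa:password@localhost:1433/master"
    "?driver=ODBC+Driver+17+for+SQL+Server"
)
sync_url = str(async_url.replace(driver="pyodbc"))
assert sync_url.startswith("mssql+pyodbc://")

engine = sqlalchemy.create_engine(sync_url)  # connects only on first use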