From b43fc704c2ccfd84a9365fd9dc26f28f9862a0c3 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 24 Sep 2024 16:55:44 -0400 Subject: [PATCH] feat: multicorn2 (Postgres FDW) backend (#397) * feat: multicorn2 (Postgres FDW) backend * Adding tests * Adding tests * Optimizing SELECT * Fix tests * Write API * Query cost * Add docs * Add integration test * Different strategy * Another approach * Rebase * Fix docker * Remove entrypoint * Fix tests --- .github/workflows/python-integration.yml | 11 + .gitignore | 2 + CHANGELOG.rst | 6 +- README.rst | 19 ++ docs/index.rst | 1 + docs/postgres.rst | 23 ++ examples/postgres.py | 31 ++ postgres/Dockerfile | 38 +++ postgres/docker-compose.yml | 19 ++ postgres/init.sql | 1 + requirements/base.txt | 7 + requirements/test.txt | 4 + setup.cfg | 10 + src/shillelagh/backends/apsw/db.py | 5 +- src/shillelagh/backends/apsw/dialects/base.py | 15 +- src/shillelagh/backends/apsw/vt.py | 23 +- src/shillelagh/backends/multicorn/__init__.py | 0 src/shillelagh/backends/multicorn/db.py | 279 ++++++++++++++++++ .../backends/multicorn/dialects/__init__.py | 0 .../backends/multicorn/dialects/base.py | 108 +++++++ src/shillelagh/backends/multicorn/fdw.py | 161 ++++++++++ src/shillelagh/lib.py | 20 ++ .../adapters/api/gsheets/integration_test.py | 57 ++++ tests/backends/apsw/db_test.py | 3 +- tests/backends/apsw/dialects/base_test.py | 8 + tests/backends/multicorn/__init__.py | 0 tests/backends/multicorn/db_test.py | 236 +++++++++++++++ tests/backends/multicorn/dialects/__init__.py | 0 .../backends/multicorn/dialects/base_test.py | 82 +++++ tests/backends/multicorn/fdw_test.py | 212 +++++++++++++ 30 files changed, 1352 insertions(+), 29 deletions(-) create mode 100644 docs/postgres.rst create mode 100644 examples/postgres.py create mode 100644 postgres/Dockerfile create mode 100644 postgres/docker-compose.yml create mode 100644 postgres/init.sql create mode 100644 src/shillelagh/backends/multicorn/__init__.py create mode 100644 
src/shillelagh/backends/multicorn/db.py create mode 100644 src/shillelagh/backends/multicorn/dialects/__init__.py create mode 100644 src/shillelagh/backends/multicorn/dialects/base.py create mode 100644 src/shillelagh/backends/multicorn/fdw.py create mode 100644 tests/backends/multicorn/__init__.py create mode 100644 tests/backends/multicorn/db_test.py create mode 100644 tests/backends/multicorn/dialects/__init__.py create mode 100644 tests/backends/multicorn/dialects/base_test.py create mode 100644 tests/backends/multicorn/fdw_test.py diff --git a/.github/workflows/python-integration.yml b/.github/workflows/python-integration.yml index 8f048852..37a68d8a 100644 --- a/.github/workflows/python-integration.yml +++ b/.github/workflows/python-integration.yml @@ -31,8 +31,19 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements/test.txt + - name: Start the Postgres service + run: | + docker compose -f postgres/docker-compose.yml up --build -d + - name: Wait for Postgres to become available + run: | + until docker run --network container:postgres-postgres-1 postgres-postgres pg_isready -h postgres -p 5432 -U shillelagh --timeout=90; do sleep 10; done - name: Test with pytest env: SHILLELAGH_ADAPTER_KWARGS: ${{ secrets.SHILLELAGH_ADAPTER_KWARGS }} run: | pytest --cov-fail-under=100 --cov=src/shillelagh -vv tests/ --doctest-modules src/shillelagh --with-integration --with-slow-integration + - name: Stop the Postgres service + if: always() + run: | + docker logs postgres-postgres-1 + docker compose -f postgres/docker-compose.yml down diff --git a/.gitignore b/.gitignore index 22e91ac2..64569892 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,5 @@ ENV/ *.sqlite *.db *.swp + +multicorn2 diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d8362541..f0bdabbb 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,8 +2,10 @@ Changelog ========= -Next -==== +Next (1.3.0) +============ + +- New Postgres backend based on multicorn2 (#397) Version 1.2.28 
- 2024-09-11 =========================== diff --git a/README.rst b/README.rst index d1a6893c..68883d98 100644 --- a/README.rst +++ b/README.rst @@ -52,6 +52,25 @@ And a command-line utility: $ shillelagh sql> SELECT * FROM a_table +There is also an `experimental backend <https://shillelagh.readthedocs.io/en/latest/postgres.html>`__ that uses Postgres with the `Multicorn2 <http://multicorn2.org/>`__ extension: + +.. code-block:: python + + from shillelagh.backends.multicorn.db import connect + + connection = connect( + user="username", + password="password", + host="localhost", + port=5432, + database="examples", + ) + +.. code-block:: python + + from sqlalchemy import create_engine + engine = create_engine("shillelagh+multicorn2://username:password@localhost:5432/examples") + Why SQL? ======== diff --git a/docs/index.rst b/docs/index.rst index a46ebe53..3466f4cb 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -10,6 +10,7 @@ Contents Usage Adapters Creating a new adapter + Postgres backend <postgres> License Authors Changelog diff --git a/docs/postgres.rst b/docs/postgres.rst new file mode 100644 index 00000000..d76d57eb --- /dev/null +++ b/docs/postgres.rst @@ -0,0 +1,23 @@ +.. _postgres: + +================ +Postgres backend +================ + +Since version 1.3 Shillelagh ships with an experimental backend that uses Postgres instead of SQLite. The backend implements a custom `psycopg2 <https://pypi.org/project/psycopg2/>`__ cursor that automatically registers a foreign data wrapper (FDW) whenever a supported table is accessed. It's based on the `multicorn2 <http://multicorn2.org/>`__ extension and Python package. + +To use the backend you need to: + +1. Install the `Multicorn2 <http://multicorn2.org/>`__ extension. +2. Install the multicorn2 Python package on the machine running Postgres. Note that this is not the "multicorn" package available on PyPI. You need to download the source and install it manually. +3. Install Shillelagh on the machine running Postgres. 
+ +Note that you need to install Python packages in a way that they are available to the process running Postgres. You can either install them globally, or install them in a virtual environment and have it activated in the process that starts Postgres. + +The ``postgres/`` directory has a Docker configuration that can be used to test the backend, or as a basis for installation. To run it, execute: + +.. code-block:: bash + + docker compose -f postgres/docker-compose.yml up + +You should then be able to run the example script in ``examples/postgres.py`` to test that everything works. diff --git a/examples/postgres.py b/examples/postgres.py new file mode 100644 index 00000000..534b5484 --- /dev/null +++ b/examples/postgres.py @@ -0,0 +1,31 @@ +""" +Simple multicorn2 test. + +Multicorn2 is an extension for PostgreSQL that allows you to create foreign data wrappers +in Python. To use it, you need to install on the machine running Postgres the extension, +the multicorn2 package (not on PyPI), and the shillelagh package. + +If you want to play with it, Shillelagh has a `docker-compose.yml` file that will run +Postgres with the extension and the Python packages. Just run: + + $ cd postgres/ + $ docker compose up --build -d + +Then you can run this script. 
+""" + +from sqlalchemy import create_engine + +# the backend uses psycopg2 under the hood, so any valid connection string for it will +# work; just replace the scheme with `shillelagh+multicorn2` +engine = create_engine( + "shillelagh+multicorn2://shillelagh:shillelagh123@localhost:5432/shillelagh", +) +connection = engine.connect() + +SQL = ( + 'SELECT * FROM "https://docs.google.com/spreadsheets/d/' + '1LcWZMsdCl92g7nA-D6qGRqg1T5TiHyuKJUY1u9XAnsk/edit#gid=0"' +) +for row in connection.execute(SQL): + print(row) diff --git a/postgres/Dockerfile b/postgres/Dockerfile new file mode 100644 index 00000000..267c7786 --- /dev/null +++ b/postgres/Dockerfile @@ -0,0 +1,38 @@ +# Use the official Postgres image as a base +FROM postgres:13 + +WORKDIR /code +COPY . /code + +# Use root for package installation +USER root + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + build-essential \ + git \ + postgresql-server-dev-13 \ + python3 \ + python3-dev \ + python3-pip \ + python3-venv \ + wget + +# Download, build, and install multicorn2 +RUN wget https://github.com/pgsql-io/multicorn2/archive/refs/tags/v2.5.tar.gz && \ + tar -xvf v2.5.tar.gz && \ + cd multicorn2-2.5 && \ + make && \ + make install + + +# Create a virtual environment and install dependencies +RUN python3 -m venv /code/venv && \ + /code/venv/bin/pip install --upgrade pip && \ + /code/venv/bin/pip install -e '.[all]' + +# Set environment variable for PostgreSQL to use the virtual environment +ENV PATH="/code/venv/bin:$PATH" + +# Switch back to the default postgres user +USER postgres diff --git a/postgres/docker-compose.yml b/postgres/docker-compose.yml new file mode 100644 index 00000000..5d905388 --- /dev/null +++ b/postgres/docker-compose.yml @@ -0,0 +1,19 @@ +version: '3.8' + +services: + postgres: + build: + context: .. 
+ dockerfile: postgres/Dockerfile + environment: + POSTGRES_PASSWORD: shillelagh123 + POSTGRES_USER: shillelagh + POSTGRES_DB: shillelagh + volumes: + - db_data:/var/lib/postgresql/data + - ./init.sql:/docker-entrypoint-initdb.d/init.sql:ro + ports: + - "5432:5432" + +volumes: + db_data: diff --git a/postgres/init.sql b/postgres/init.sql new file mode 100644 index 00000000..6844895f --- /dev/null +++ b/postgres/init.sql @@ -0,0 +1 @@ +CREATE EXTENSION IF NOT EXISTS multicorn; diff --git a/requirements/base.txt b/requirements/base.txt index d537e320..943dff0b 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -19,6 +19,8 @@ certifi==2022.6.15 # via requests charset-normalizer==2.1.0 # via requests +exceptiongroup==1.1.3 + # via cattrs greenlet==2.0.2 # via # shillelagh @@ -45,9 +47,14 @@ sqlalchemy==1.4.39 # via shillelagh typing-extensions==4.3.0 # via shillelagh + # via + # cattrs + # shillelagh url-normalize==1.4.3 # via requests-cache urllib3==1.26.10 # via # requests # requests-cache +zipp==3.15.0 + # via importlib-metadata diff --git a/requirements/test.txt b/requirements/test.txt index 5a71df88..9e136f6e 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -95,6 +95,8 @@ lazy-object-proxy==1.7.1 # via astroid mccabe==0.7.0 # via pylint +multicorn @ git+https://github.com/pgsql-io/multicorn2.git@v2.5 + # via shillelagh multidict==6.0.2 # via # aiohttp @@ -137,6 +139,8 @@ psutil==5.9.1 # via shillelagh pyarrow==16.0.0 # via shillelagh +psycopg2-binary==2.9.9 + # via shillelagh pyasn1==0.4.8 # via # pyasn1-modules diff --git a/setup.cfg b/setup.cfg index e276358f..00329003 100644 --- a/setup.cfg +++ b/setup.cfg @@ -80,6 +80,8 @@ testing = google-auth>=1.23.0 holidays>=0.23 html5lib>=1.1 + jsonpath-python>=1.0.5 + multicorn @ git+https://github.com/pgsql-io/multicorn2.git@v2.5 pandas>=1.2.2 pip-tools>=6.4.0 pre-commit>=2.13.0 @@ -87,6 +89,7 @@ testing = prison>=0.2.1 prompt_toolkit>=3 psutil>=5.8.0 + psycopg2-binary>=2.9.9 
pyarrow>=14.0.1 pyfakefs>=4.3.3 pygments>=2.8 @@ -111,10 +114,13 @@ all = google-auth>=1.23.0 holidays>=0.23 html5lib>=1.1 + jsonpath-python>=1.0.5 + multicorn @ git+https://github.com/pgsql-io/multicorn2.git@v2.5 pandas>=1.2.2 prison>=0.2.1 prompt_toolkit>=3 psutil>=5.8.0 + psycopg2-binary>=2.9.9 pyarrow>=14.0.1 pygments>=2.8 python-graphql-client>=0.4.3 @@ -153,6 +159,9 @@ htmltableapi = beautifulsoup4>=4.11.1 html5lib>=1.1 pandas>=1.2.2 +multicorn = + multicorn @ git+https://github.com/pgsql-io/multicorn2.git@v2.5 + psycopg2-binary>=2.9.9 pandasmemory = pandas>=1.2.2 s3selectapi = @@ -184,6 +193,7 @@ sqlalchemy.dialects = shillelagh.apsw = shillelagh.backends.apsw.dialects.base:APSWDialect shillelagh.safe = shillelagh.backends.apsw.dialects.safe:APSWSafeDialect gsheets = shillelagh.backends.apsw.dialects.gsheets:APSWGSheetsDialect + shillelagh.multicorn2 = shillelagh.backends.multicorn.dialects.base:Multicorn2Dialect console_scripts = shillelagh = shillelagh.console:main # For example: diff --git a/src/shillelagh/backends/apsw/db.py b/src/shillelagh/backends/apsw/db.py index ab7b8901..be9cd007 100644 --- a/src/shillelagh/backends/apsw/db.py +++ b/src/shillelagh/backends/apsw/db.py @@ -286,9 +286,10 @@ def _drop_table_uri(self, operation: str) -> Optional[str]: operation = "\n".join( line for line in operation.split("\n") if not line.strip().startswith("--") ) + schema = re.escape(self.schema) regexp = re.compile( - rf"^\s*DROP\s+TABLE\s+(IF\s+EXISTS\s+)?" - rf'({self.schema}\.)?(?P<uri>(.*?)|(".*?"))\s*;?\s*$', + r"^\s*DROP\s+TABLE\s+(IF\s+EXISTS\s+)?" + rf'({schema}\.)?(?P<uri>(.*?)|(".*?"))\s*;?\s*$', re.IGNORECASE, ) if match := regexp.match(operation): diff --git a/src/shillelagh/backends/apsw/dialects/base.py b/src/shillelagh/backends/apsw/dialects/base.py index 73d15b60..56302020 100644 --- a/src/shillelagh/backends/apsw/dialects/base.py +++ b/src/shillelagh/backends/apsw/dialects/base.py @@ -1,5 +1,8 @@ +""" +A SQLAlchemy dialect. 
+""" + # pylint: disable=protected-access, abstract-method -"""A SQLALchemy dialect.""" from typing import Any, Dict, List, Optional, Tuple, cast @@ -102,7 +105,6 @@ def has_table( # pylint: disable=unused-argument connection: _ConnectionFairy, table_name: str, schema: Optional[str] = None, - info_cache: Optional[Dict[Any, Any]] = None, **kwargs: Any, ) -> bool: """ @@ -111,7 +113,14 @@ def has_table( # pylint: disable=unused-argument try: get_adapter_for_table_name(connection, table_name) except ProgrammingError: - return False + return bool( + super().has_table( + connection, + table_name, + schema, + **kwargs, # pylint: disable=unused-argument + ), + ) return True # needed for SQLAlchemy diff --git a/src/shillelagh/backends/apsw/vt.py b/src/shillelagh/backends/apsw/vt.py index 771f1c39..217b9def 100644 --- a/src/shillelagh/backends/apsw/vt.py +++ b/src/shillelagh/backends/apsw/vt.py @@ -42,8 +42,8 @@ StringDuration, StringInteger, ) -from shillelagh.filters import Filter, Operator -from shillelagh.lib import best_index_object_available, deserialize +from shillelagh.filters import Operator +from shillelagh.lib import best_index_object_available, deserialize, get_bounds from shillelagh.typing import ( Constraint, Index, @@ -245,25 +245,6 @@ def get_order( ] -def get_bounds( - columns: Dict[str, Field], - all_bounds: DefaultDict[str, Set[Tuple[Operator, Any]]], -) -> Dict[str, Filter]: - """ - Combine all filters that apply to each column. - """ - bounds: Dict[str, Filter] = {} - for column_name, operations in all_bounds.items(): - column_type = columns[column_name] - operators = {operation[0] for operation in operations} - for class_ in column_type.filters: - if all(operator in class_.operators for operator in operators): - bounds[column_name] = class_.build(operations) - break - - return bounds - - class VTModule: # pylint: disable=too-few-public-methods """ A module used to create SQLite virtual tables. 
diff --git a/src/shillelagh/backends/multicorn/__init__.py b/src/shillelagh/backends/multicorn/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/shillelagh/backends/multicorn/db.py b/src/shillelagh/backends/multicorn/db.py new file mode 100644 index 00000000..df826292 --- /dev/null +++ b/src/shillelagh/backends/multicorn/db.py @@ -0,0 +1,279 @@ +# pylint: disable=invalid-name, c-extension-no-member, unused-import +""" +A DB API 2.0 wrapper. +""" + +import logging +import re +from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast +from uuid import uuid4 + +import psycopg2 +from psycopg2 import extensions + +from shillelagh.adapters.base import Adapter +from shillelagh.adapters.registry import registry +from shillelagh.exceptions import ( # nopycln: import; pylint: disable=redefined-builtin + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + OperationalError, + ProgrammingError, + Warning, +) +from shillelagh.lib import ( + combine_args_kwargs, + escape_identifier, + find_adapter, + serialize, +) +from shillelagh.types import ( + BINARY, + DATETIME, + NUMBER, + ROWID, + STRING, + Binary, + Date, + DateFromTicks, + Time, + TimeFromTicks, + Timestamp, + TimestampFromTicks, +) + +__all__ = [ + "DatabaseError", + "DataError", + "Error", + "IntegrityError", + "InterfaceError", + "InternalError", + "OperationalError", + "BINARY", + "DATETIME", + "NUMBER", + "ROWID", + "STRING", + "Binary", + "Date", + "DateFromTicks", + "Time", + "TimeFromTicks", + "Timestamp", + "TimestampFromTicks", + "Warning", +] + +apilevel = "2.0" +threadsafety = 2 +paramstyle = "pyformat" + +NO_SUCH_TABLE = re.compile('relation "(.*?)" does not exist') +DEFAULT_SCHEMA = "main" + +_logger = logging.getLogger(__name__) + + +class Cursor(extensions.cursor): # pylint: disable=too-few-public-methods + """ + A cursor that registers FDWs. 
+ """ + + def __init__( + self, + *args: Any, + adapters: Dict[str, Type[Adapter]], + adapter_kwargs: Dict[str, Dict[str, Any]], + schema: str, + **kwargs: Any, + ): + super().__init__(*args, **kwargs) + + self._adapters = list(adapters.values()) + self._adapter_map = {v: k for k, v in adapters.items()} + self._adapter_kwargs = adapter_kwargs + self.schema = schema + + def execute( + self, + operation: str, + parameters: Optional[Tuple[Any, ...]] = None, + ) -> Union["Cursor", extensions.cursor]: + """ + Execute a query, automatically registering FDWs if necessary. + """ + # which cursor should be returned + cursor: Union["Cursor", extensions.cursor] = self + + while True: + savepoint = uuid4() + super().execute(f'SAVEPOINT "{savepoint}"') + + try: + cursor = cast(extensions.cursor, super().execute(operation, parameters)) + break + except psycopg2.errors.UndefinedTable as ex: # pylint: disable=no-member + message = ex.args[0] + match = NO_SUCH_TABLE.match(message) + if not match: + raise ProgrammingError(message) from ex + + # Postgres truncates the table name in the error message, so we need to + # find it in the original query + fragment = match.group(1) + uri = self._get_table_uri(fragment, operation) + if not uri: + raise ProgrammingError("Could not determine table name") from ex + + super().execute(f'ROLLBACK TO SAVEPOINT "{savepoint}"') + self._create_table(uri) + + if uri := self._drop_table_uri(operation): + adapter, args, kwargs = find_adapter( + uri, + self._adapter_kwargs, + self._adapters, + ) + instance = adapter(*args, **kwargs) + instance.drop_table() + + return cursor + + def _get_table_uri(self, fragment: str, operation: str) -> Optional[str]: + """ + Extract the table name from a query. 
+ """ + schema = re.escape(self.schema) + fragment = re.escape(fragment) + regexp = re.compile( + rf'\b(FROM|INTO)\s+({schema}\.)?(?P<uri>"{fragment}.*?")', + re.IGNORECASE, + ) + if match := regexp.search(operation): + return match.groupdict()["uri"].strip('"') + + return None + + def _drop_table_uri(self, operation: str) -> Optional[str]: + """ + Extract the table URI from a ``DROP TABLE`` statement, if any. + """ + schema = re.escape(self.schema) + regexp = re.compile( + r"^\s*DROP\s+TABLE\s+(IF\s+EXISTS\s+)?" + rf'({schema}\.)?(?P<uri>(.*?)|(".*?"))\s*;?\s*$', + re.IGNORECASE, + ) + if match := regexp.match(operation): + return match.groupdict()["uri"].strip('"') + + return None + + def _create_table(self, uri: str) -> None: + """ + Register a FDW. + """ + adapter, args, kwargs = find_adapter(uri, self._adapter_kwargs, self._adapters) + formatted_args = serialize(combine_args_kwargs(adapter, *args, **kwargs)) + entrypoint = self._adapter_map[adapter] + + table_name = escape_identifier(uri) + + columns = adapter(*args, **kwargs).get_columns() + if not columns: + raise ProgrammingError(f"Virtual table {table_name} has no columns") + + quoted_columns = {k.replace('"', '""'): v for k, v in columns.items()} + formatted_columns = ", ".join( + f'"{k}" {v.type}' for (k, v) in quoted_columns.items() + ) + + super().execute( + """ +CREATE SERVER shillelagh foreign data wrapper multicorn options ( + wrapper 'shillelagh.backends.multicorn.fdw.MulticornForeignDataWrapper' +); + """, + ) + super().execute( + f""" +CREATE FOREIGN TABLE "{table_name}" ( + {formatted_columns} +) server shillelagh options ( + adapter '{entrypoint}', + args '{formatted_args}' +); + """, + ) + + +class CursorFactory: # pylint: disable=too-few-public-methods + """ + Custom cursor factory. + + This returns a custom cursor that will auto register FDWs for the user. 
+ """ + + def __init__( + self, + adapters: Dict[str, Type[Adapter]], + adapter_kwargs: Dict[str, Dict[str, Any]], + schema: str, + ): + self.schema = schema + self._adapters = adapters + self._adapter_kwargs = adapter_kwargs + + def __call__(self, *args, **kwargs) -> Cursor: + """ + Create a new cursor. + """ + return Cursor( + *args, + adapters=self._adapters, + adapter_kwargs=self._adapter_kwargs, + schema=self.schema, + **kwargs, + ) + + +def connect( # pylint: disable=too-many-arguments + dsn: Optional[str] = None, + adapters: Optional[List[str]] = None, + adapter_kwargs: Optional[Dict[str, Dict[str, Any]]] = None, + schema: str = DEFAULT_SCHEMA, + **psycopg2_connection_kwargs: Any, +) -> extensions.connection: + """ + Constructor for creating a connection to the database. + + Only safe adapters can be loaded. If no adapters are specified, all safe adapters are + loaded. + """ + adapter_kwargs = adapter_kwargs or {} + enabled_adapters = { + name: adapter + for name, adapter in registry.load_all(adapters, safe=False).items() + if adapter.safe + } + + # replace entry point names with class names + mapping = { + name: adapter.__name__.lower() for name, adapter in enabled_adapters.items() + } + adapter_kwargs = {mapping[k]: v for k, v in adapter_kwargs.items() if k in mapping} + + cursor_factory = CursorFactory( + enabled_adapters, + adapter_kwargs, + schema, + ) + return psycopg2.connect( + dsn, + cursor_factory=cursor_factory, + **psycopg2_connection_kwargs, + ) diff --git a/src/shillelagh/backends/multicorn/dialects/__init__.py b/src/shillelagh/backends/multicorn/dialects/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/shillelagh/backends/multicorn/dialects/base.py b/src/shillelagh/backends/multicorn/dialects/base.py new file mode 100644 index 00000000..e3ea9872 --- /dev/null +++ b/src/shillelagh/backends/multicorn/dialects/base.py @@ -0,0 +1,108 @@ +""" +A SQLAlchemy dialect based on psycopg2 and multicorn2. 
+""" + +# pylint: disable=protected-access, abstract-method + +from typing import Any, Dict, List, Optional, Tuple, cast + +from psycopg2 import extensions +from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 +from sqlalchemy.engine.url import URL +from sqlalchemy.pool.base import _ConnectionFairy + +from shillelagh.adapters.base import Adapter +from shillelagh.backends.multicorn import db +from shillelagh.exceptions import ProgrammingError +from shillelagh.lib import find_adapter + + +class Multicorn2Dialect(PGDialect_psycopg2): + """ + A SQLAlchemy dialect for Shillelagh based on psycopg2 and multicorn2. + """ + + name = "shillelagh" + driver = "multicorn2" + + supports_statement_cache = True + + @classmethod + def dbapi(cls): + """ + Return the DB API module. + """ + return db + + @classmethod + def import_dbapi(cls): + """ + New version of the ``dbapi`` method. + """ + return db + + def __init__( + self, + adapters: Optional[List[str]] = None, + adapter_kwargs: Optional[Dict[str, Dict[str, Any]]] = None, + **kwargs: Any, + ): + super().__init__(**kwargs) + self._adapters = adapters + self._adapter_kwargs = adapter_kwargs or {} + + def create_connect_args( + self, + url: URL, + ) -> Tuple[List[Any], Dict[str, Any]]: + args, kwargs = super().create_connect_args(url) + kwargs.update( + { + "adapters": self._adapters, + "adapter_kwargs": self._adapter_kwargs, + }, + ) + return args, kwargs + + def has_table( + self, + connection: _ConnectionFairy, + table_name: str, + schema: Optional[str] = None, + **kwargs: Any, + ) -> bool: + """ + Return true if a given table exists. + """ + try: + get_adapter_for_table_name(connection, table_name) + except ProgrammingError: + return bool( + super().has_table( + connection, + table_name, + schema, + **kwargs, + ), + ) + return True + + +def get_adapter_for_table_name( + connection: _ConnectionFairy, + table_name: str, +) -> Adapter: + """ + Return an adapter associated with a connection. 
+ + This function instantiates the adapter responsible for a given table name, + using the connection to properly pass any adapter kwargs. + """ + raw_connection = cast(extensions.connection, connection.engine.raw_connection()) + cursor = raw_connection.cursor() + adapter, args, kwargs = find_adapter( + table_name, + cursor._adapter_kwargs, + cursor._adapters, + ) + return adapter(*args, **kwargs) diff --git a/src/shillelagh/backends/multicorn/fdw.py b/src/shillelagh/backends/multicorn/fdw.py new file mode 100644 index 00000000..def86730 --- /dev/null +++ b/src/shillelagh/backends/multicorn/fdw.py @@ -0,0 +1,161 @@ +""" +An FDW. +""" + +from collections import defaultdict +from typing import ( + Any, + DefaultDict, + Dict, + Iterator, + List, + Optional, + Set, + Tuple, + TypedDict, +) + +from multicorn import ForeignDataWrapper, Qual, SortKey + +from shillelagh.adapters.registry import registry +from shillelagh.fields import Order +from shillelagh.filters import Operator +from shillelagh.lib import deserialize, get_bounds +from shillelagh.typing import RequestedOrder, Row + +operator_map = { + "=": Operator.EQ, + ">": Operator.GT, + "<": Operator.LT, + ">=": Operator.GE, + "<=": Operator.LE, +} + + +def get_all_bounds(quals: List[Qual]) -> DefaultDict[str, Set[Tuple[Operator, Any]]]: + """ + Convert list of ``Qual`` into a set of operators for each column. + """ + all_bounds: DefaultDict[str, Set[Tuple[Operator, Any]]] = defaultdict(set) + for qual in quals: + if operator := operator_map.get(qual.operator): + all_bounds[qual.field_name].add((operator, qual.value)) + + return all_bounds + + +class OptionsType(TypedDict): + """ + Type for OPTIONS. + """ + + adapter: str + args: str + + +class MulticornForeignDataWrapper(ForeignDataWrapper): + """ + A FDW that dispatches queries to adapters. 
+ """ + + def __init__(self, options: OptionsType, columns: Dict[str, str]): + super().__init__(options, columns) + + deserialized_args = deserialize(options["args"]) + self.adapter = registry.load(options["adapter"])(*deserialized_args) + self.columns = self.adapter.get_columns() + + def execute( + self, + quals: List[Qual], + columns: List[str], + sortkeys: Optional[List[SortKey]] = None, + ) -> Iterator[Row]: + """ + Execute a query. + """ + all_bounds = get_all_bounds(quals) + bounds = get_bounds(self.columns, all_bounds) + + order: List[Tuple[str, RequestedOrder]] = [ + (key.attname, Order.DESCENDING if key.is_reversed else Order.ASCENDING) + for key in sortkeys or [] + ] + + kwargs = ( + {"requested_columns": columns} + if self.adapter.supports_requested_columns + else {} + ) + + return self.adapter.get_rows(bounds, order, **kwargs) + + def can_sort(self, sortkeys: List[SortKey]) -> List[SortKey]: + """ + Return a list of sorts the adapter can perform. + """ + + def is_sortable(key: SortKey) -> bool: + """ + Return if a given sort key can be enforced by the adapter. + """ + if key.attname not in self.columns: + return False + + order = self.columns[key.attname].order + return ( + order == Order.ANY + or (order == Order.ASCENDING and not key.is_reversed) + or (order == Order.DESCENDING and key.is_reversed) + ) + + return [key for key in sortkeys if is_sortable(key)] + + def insert(self, values: Row) -> Row: + rowid = self.adapter.insert_row(values) + values["rowid"] = rowid + return values + + def delete(self, oldvalues: Row) -> None: + rowid = oldvalues["rowid"] + self.adapter.delete_row(rowid) + + def update(self, oldvalues: Row, newvalues: Row) -> Row: + rowid = newvalues["rowid"] + self.adapter.update_row(rowid, newvalues) + return newvalues + + @property + def rowid_column(self): + return "rowid" + + def get_rel_size(self, quals: List[Qual], columns: List[str]) -> Tuple[int, int]: + """ + Estimate query cost. 
+ """ + all_bounds = get_all_bounds(quals) + filtered_columns = [ + (column, operator[0]) + for column, operators in all_bounds.items() + for operator in operators + ] + + # the adapter returns an arbitrary cost that takes in consideration filtering and + # sorting; let's use that as an approximation for rows + rows = int(self.adapter.get_cost(filtered_columns, [])) + + # same assumption as the parent class + row_width = len(columns) * 100 + + return (rows, row_width) + + @classmethod + def import_schema( # pylint: disable=too-many-arguments + cls, + schema: str, + srv_options: Dict[str, str], + options: Dict[str, str], + restriction_type: Optional[str], + restricts: List[str], + ): + return [] diff --git a/src/shillelagh/lib.py b/src/shillelagh/lib.py index 409efc37..d35a0dbf 100644 --- a/src/shillelagh/lib.py +++ b/src/shillelagh/lib.py @@ -11,6 +11,7 @@ from typing import ( Any, Callable, + DefaultDict, Dict, Iterator, List, @@ -641,3 +642,22 @@ def get_session( session.headers.update(request_headers) return session + + +def get_bounds( + columns: Dict[str, Field], + all_bounds: DefaultDict[str, Set[Tuple[Operator, Any]]], +) -> Dict[str, Filter]: + """ + Combine all filters that apply to each column. 
+ """ + bounds: Dict[str, Filter] = {} + for column_name, operations in all_bounds.items(): + column_type = columns[column_name] + operators = {operation[0] for operation in operations} + for class_ in column_type.filters: + if all(operator in class_.operators for operator in operators): + bounds[column_name] = class_.build(operations) + break + + return bounds diff --git a/tests/adapters/api/gsheets/integration_test.py b/tests/adapters/api/gsheets/integration_test.py index aa20de0c..4653293d 100644 --- a/tests/adapters/api/gsheets/integration_test.py +++ b/tests/adapters/api/gsheets/integration_test.py @@ -15,6 +15,7 @@ from shillelagh.adapters.api.gsheets.types import SyncMode from shillelagh.backends.apsw.db import connect +from shillelagh.backends.multicorn.db import connect as connect_multicorn @pytest.mark.skip("Credentials no longer valid") @@ -727,3 +728,59 @@ def test_weird_symbols(adapter_kwargs: Dict[str, Any]) -> None: assert cursor.fetchall() == [(1.0, "a", 45.0), (2.0, "b", 1999.0)] assert cursor.description is not None assert [column[0] for column in cursor.description] == ['foo"', '"bar', 'a"b'] + + +@pytest.mark.slow_integration_test +def test_public_sheet_apsw() -> None: + """ + Test reading values from a public sheet with APSW. + """ + table = ( + '"https://docs.google.com/spreadsheets/d/' + '1LcWZMsdCl92g7nA-D6qGRqg1T5TiHyuKJUY1u9XAnsk/edit#gid=0"' + ) + + connection = connect(":memory:") + cursor = connection.cursor() + sql = f"SELECT * FROM {table}" + cursor.execute(sql) + assert cursor.fetchall() == [ + ("BR", 2), + ("BR", 4), + ("ZA", 7), + ("CR", 11), + ("CR", 11), + ("FR", 100), + ("AR", 42), + ] + + +@pytest.mark.slow_integration_test +def test_public_sheet_multicorn() -> None: + """ + Test reading values from a public sheet with Multicorn2. 
+ """ + table = ( + '"https://docs.google.com/spreadsheets/d/' + '1LcWZMsdCl92g7nA-D6qGRqg1T5TiHyuKJUY1u9XAnsk/edit#gid=0"' + ) + + connection = connect_multicorn( + user="shillelagh", + password="shillelagh123", + host="localhost", + port=5432, + database="shillelagh", + ) + cursor = connection.cursor() + sql = f"SELECT * FROM {table}" + cursor.execute(sql) + assert cursor.fetchall() == [ + ("BR", 2), + ("BR", 4), + ("ZA", 7), + ("CR", 11), + ("CR", 11), + ("FR", 100), + ("AR", 42), + ] diff --git a/tests/backends/apsw/db_test.py b/tests/backends/apsw/db_test.py index 662bc14c..206c0bad 100644 --- a/tests/backends/apsw/db_test.py +++ b/tests/backends/apsw/db_test.py @@ -1,8 +1,9 @@ -# pylint: disable=protected-access, c-extension-no-member, too-few-public-methods """ Tests for shillelagh.backends.apsw.db. """ +# pylint: disable=protected-access, c-extension-no-member, too-few-public-methods + import datetime from typing import Any, List, Tuple from unittest import mock diff --git a/tests/backends/apsw/dialects/base_test.py b/tests/backends/apsw/dialects/base_test.py index a578c2fd..37586d0a 100644 --- a/tests/backends/apsw/dialects/base_test.py +++ b/tests/backends/apsw/dialects/base_test.py @@ -8,12 +8,20 @@ from sqlalchemy import MetaData, Table, create_engine, func, inspect, select from shillelagh.adapters.registry import AdapterLoader +from shillelagh.backends.apsw import db from shillelagh.backends.apsw.dialects.base import APSWDialect from shillelagh.exceptions import ProgrammingError from ....fakes import FakeAdapter +def test_dbapi() -> None: + """ + Test the ``dbapi`` and ``import_dbapi`` methods. + """ + assert APSWDialect.dbapi() == APSWDialect.import_dbapi() == db + + def test_create_engine(registry: AdapterLoader) -> None: """ Test ``create_engine``. 
diff --git a/tests/backends/multicorn/__init__.py b/tests/backends/multicorn/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/backends/multicorn/db_test.py b/tests/backends/multicorn/db_test.py new file mode 100644 index 00000000..41f989cc --- /dev/null +++ b/tests/backends/multicorn/db_test.py @@ -0,0 +1,236 @@ +""" +Tests for the Multicorn2 DB API 2.0 wrapper. +""" + +# pylint: disable=invalid-name, redefined-outer-name, no-member, redefined-builtin + +import psycopg2 +import pytest +from pytest_mock import MockerFixture + +from shillelagh.adapters.registry import AdapterLoader +from shillelagh.backends.multicorn.db import Cursor, CursorFactory, connect +from shillelagh.exceptions import ProgrammingError + +from ...fakes import FakeAdapter + + +def test_connect(mocker: MockerFixture, registry: AdapterLoader) -> None: + """ + Test the ``connect`` function. + """ + psycopg2 = mocker.patch("shillelagh.backends.multicorn.db.psycopg2") + CursorFactory = mocker.patch("shillelagh.backends.multicorn.db.CursorFactory") + + registry.add("dummy", FakeAdapter) + + connect( + None, + ["dummy"], + user="username", + password="password", + host="host", + port=1234, + database="database", + ) + psycopg2.connect.assert_called_with( + None, + cursor_factory=CursorFactory( + {"dummy": FakeAdapter}, + {}, + "main", + ), + user="username", + password="password", + host="host", + port=1234, + database="database", + ) + + +def test_cursor_factory(mocker: MockerFixture) -> None: + """ + Test the ``CursorFactory`` class. 
+ """ + Cursor = mocker.patch("shillelagh.backends.multicorn.db.Cursor") + + cursor_factory = CursorFactory( + {"dummy": FakeAdapter}, + {}, + "main", + ) + assert cursor_factory( + user="username", + password="password", + host="host", + port=1234, + database="database", + ) == Cursor( + adapters=["dummy"], + adapter_kwargs={}, + schema="main", + user="username", + password="password", + host="host", + port=1234, + database="database", + ) + + +def test_cursor(mocker: MockerFixture) -> None: + """ + Test the ``Cursor`` class. + """ + mocker.patch("shillelagh.backends.multicorn.db.uuid4", return_value="uuid") + super = mocker.patch("shillelagh.backends.multicorn.db.super", create=True) + execute = mocker.MagicMock(name="execute") + super.return_value.execute = execute + + cursor = Cursor( + adapters={"dummy": FakeAdapter}, + adapter_kwargs={}, + schema="main", + ) + + cursor.execute("SELECT 1") + execute.assert_has_calls( + [ + mocker.call('SAVEPOINT "uuid"'), + mocker.call("SELECT 1", None), + ], + ) + + execute.reset_mock() + execute.side_effect = [ + True, # SAVEPOINT + psycopg2.errors.UndefinedTable('relation "dummy://" does not exist'), + True, # ROLLBACK + True, # CREATE SERVER + True, # CREATE FOREIGN TABLE + True, # SAVEPOINT + True, # successful query + ] + + cursor.execute('SELECT * FROM "dummy://"') + execute.assert_has_calls( + [ + mocker.call('SAVEPOINT "uuid"'), + mocker.call('SELECT * FROM "dummy://"', None), + mocker.call('ROLLBACK TO SAVEPOINT "uuid"'), + mocker.call( + """ +CREATE SERVER shillelagh foreign data wrapper multicorn options ( + wrapper 'shillelagh.backends.multicorn.fdw.MulticornForeignDataWrapper' +); + """, + ), + mocker.call( + """ +CREATE FOREIGN TABLE "dummy://" ( + "age" REAL, "name" TEXT, "pets" INTEGER +) server shillelagh options ( + adapter \'dummy\', + args \'qQA=\' +); + """, + ), + mocker.call('SAVEPOINT "uuid"'), + mocker.call('SELECT * FROM "dummy://"', None), + ], + ) + + +def test_cursor_no_table_match(mocker: 
MockerFixture) -> None: + """ + Test an edge case where ``UndefinedTable`` is raised with a different message. + """ + super = mocker.patch("shillelagh.backends.multicorn.db.super", create=True) + execute = mocker.MagicMock(name="execute") + super.return_value.execute = execute + + execute.side_effect = [ + True, # SAVEPOINT + psycopg2.errors.UndefinedTable("An unexpected error occurred"), + ] + + cursor = Cursor( + adapters={"dummy": FakeAdapter}, + adapter_kwargs={}, + schema="main", + ) + + with pytest.raises(ProgrammingError) as excinfo: + cursor.execute('SELECT * FROM "dummy://"') + assert str(excinfo.value) == "An unexpected error occurred" + + +def test_cursor_no_table_name(mocker: MockerFixture) -> None: + """ + Test an edge case where we can't determine the table name from the exception. + """ + super = mocker.patch("shillelagh.backends.multicorn.db.super", create=True) + execute = mocker.MagicMock(name="execute") + super.return_value.execute = execute + + execute.side_effect = [ + True, # SAVEPOINT + psycopg2.errors.UndefinedTable('relation "invalid://" does not exist'), + ] + + cursor = Cursor( + adapters={"dummy": FakeAdapter}, + adapter_kwargs={}, + schema="main", + ) + + with pytest.raises(ProgrammingError) as excinfo: + cursor.execute('SELECT * FROM "dummy://"') + assert str(excinfo.value) == "Could not determine table name" + + +def test_drop_table(mocker: MockerFixture) -> None: + """ + Test the ``drop_table`` method. 
+ """ + super = mocker.patch("shillelagh.backends.multicorn.db.super", create=True) + execute = mocker.MagicMock(name="execute") + super.return_value.execute = execute + adapter = mocker.MagicMock(name="adapter") + mocker.patch( + "shillelagh.backends.multicorn.db.find_adapter", + return_value=(adapter, ["one"], {"two": 2}), + ) + + cursor = Cursor( + adapters={"dummy": FakeAdapter}, + adapter_kwargs={}, + schema="main", + ) + + cursor.execute('DROP TABLE "dummy://"') + adapter.assert_called_with("one", two=2) + adapter().drop_table.assert_called() + + +def test_table_without_columns(mocker: MockerFixture) -> None: + """ + Test an edge case where a virtual table has no columns. + """ + super = mocker.patch("shillelagh.backends.multicorn.db.super", create=True) + execute = mocker.MagicMock(name="execute") + super.return_value.execute = execute + adapter = mocker.MagicMock(name="adapter") + adapter().get_columns.return_value = [] + mocker.patch( + "shillelagh.backends.multicorn.db.find_adapter", + return_value=(adapter, ["one"], {"two": 2}), + ) + + cursor = Cursor( + adapters={"dummy": adapter}, + adapter_kwargs={}, + schema="main", + ) + with pytest.raises(ProgrammingError) as excinfo: + cursor._create_table("dummy://") # pylint: disable=protected-access + assert str(excinfo.value) == "Virtual table dummy:// has no columns" diff --git a/tests/backends/multicorn/dialects/__init__.py b/tests/backends/multicorn/dialects/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/backends/multicorn/dialects/base_test.py b/tests/backends/multicorn/dialects/base_test.py new file mode 100644 index 00000000..1d5610fe --- /dev/null +++ b/tests/backends/multicorn/dialects/base_test.py @@ -0,0 +1,82 @@ +""" +Tests for the multicorn dialect. 
+""" + +from pytest_mock import MockerFixture +from sqlalchemy.engine.url import make_url + +from shillelagh.backends.multicorn import db +from shillelagh.backends.multicorn.db import Cursor +from shillelagh.backends.multicorn.dialects.base import ( + Multicorn2Dialect, + get_adapter_for_table_name, +) +from shillelagh.exceptions import ProgrammingError + +from ....fakes import FakeAdapter + + +def test_dbapi() -> None: + """ + Test the ``dbapi`` and ``import_dbapi`` methods. + """ + assert Multicorn2Dialect.dbapi() == Multicorn2Dialect.import_dbapi() == db + + +def test_create_connect_args() -> None: + """ + Test ``create_connect_args``. + """ + dialect = Multicorn2Dialect(["dummy"], {}) + assert dialect.create_connect_args( + make_url( + "shillelagh+multicorn2://shillelagh:shillelagh123@localhost:12345/shillelagh", + ), + ) == ( + [], + { + "adapter_kwargs": {}, + "adapters": ["dummy"], + "user": "shillelagh", + "password": "shillelagh123", + "host": "localhost", + "port": 12345, + "database": "shillelagh", + }, + ) + + +def test_has_table(mocker: MockerFixture) -> None: + """ + Test ``has_table``. + """ + super = mocker.patch( # pylint: disable=redefined-builtin + "shillelagh.backends.multicorn.dialects.base.super", + create=True, + ) + has_table = mocker.MagicMock(name="has_table", return_value=False) + super.return_value.has_table = has_table + mocker.patch( + "shillelagh.backends.multicorn.dialects.base.get_adapter_for_table_name", + side_effect=[True, ProgrammingError('No adapter for table "dummy://".')], + ) + connection = mocker.MagicMock() + + dialect = Multicorn2Dialect(["dummy"], {}) + assert dialect.has_table(connection, "dummy://") is True + assert dialect.has_table(connection, "my_table") is False + + +def test_get_adapter_for_table_name(mocker: MockerFixture) -> None: + """ + Test the ``get_adapter_for_table_name`` function. 
+ """ + mocker.patch("shillelagh.backends.multicorn.db.super", create=True) + connection = mocker.MagicMock() + connection.engine.raw_connection().cursor.return_value = Cursor( + adapters={"dummy": FakeAdapter}, + adapter_kwargs={}, + schema="main", + ) + + assert isinstance(get_adapter_for_table_name(connection, "dummy://"), FakeAdapter) diff --git a/tests/backends/multicorn/fdw_test.py b/tests/backends/multicorn/fdw_test.py new file mode 100644 index 00000000..f4518664 --- /dev/null +++ b/tests/backends/multicorn/fdw_test.py @@ -0,0 +1,212 @@ +""" +Tests for the Multicorn2 FDW. +""" + +# pylint: disable=invalid-name, redefined-outer-name, no-member, redefined-builtin + +from collections import defaultdict + +from multicorn import Qual, SortKey +from pytest_mock import MockerFixture + +from shillelagh.adapters.registry import AdapterLoader +from shillelagh.backends.multicorn.fdw import ( + MulticornForeignDataWrapper, + get_all_bounds, +) +from shillelagh.filters import Operator + +from ...fakes import FakeAdapter + + +def test_fdw(mocker: MockerFixture, registry: AdapterLoader) -> None: + """ + Test the ``MulticornForeignDataWrapper`` class. 
+ """ + mocker.patch("shillelagh.backends.multicorn.fdw.registry", registry) + + registry.add("dummy", FakeAdapter) + + assert ( + MulticornForeignDataWrapper.import_schema("schema", {}, {}, "limit", []) == [] + ) + wrapper = MulticornForeignDataWrapper( + {"adapter": "dummy", "args": "qQA="}, + {}, + ) + assert wrapper.rowid_column == "rowid" + + assert list(wrapper.execute([], ["rowid", "name", "age", "pets"])) == [ + {"rowid": 0, "name": "Alice", "age": 20, "pets": 0}, + {"rowid": 1, "name": "Bob", "age": 23, "pets": 3}, + ] + + assert list( + wrapper.execute( + [Qual("age", ">", 21)], + ["rowid", "name", "age", "pets"], + [], + ), + ) == [ + {"rowid": 1, "name": "Bob", "age": 23, "pets": 3}, + ] + + assert list( + wrapper.execute( + [], + ["rowid", "name", "age", "pets"], + [ + SortKey( + attname="age", + attnum=2, + is_reversed=True, + nulls_first=True, + collate=None, + ), + ], + ), + ) == [ + {"rowid": 1, "name": "Bob", "age": 23, "pets": 3}, + {"rowid": 0, "name": "Alice", "age": 20, "pets": 0}, + ] + + +def test_get_all_bounds() -> None: + """ + Test ``get_all_bounds``. + """ + quals = [ + Qual("column1", "=", 3), + Qual("column2", "LIKE", "test%"), + Qual("column3", ">", 10), + ] + + assert get_all_bounds([]) == defaultdict(set) + assert get_all_bounds([quals[0]]) == {"column1": {(Operator.EQ, 3)}} + assert get_all_bounds(quals) == { + "column1": {(Operator.EQ, 3)}, + "column3": {(Operator.GT, 10)}, + } + assert get_all_bounds([Qual("column4", "unsupported_operator", 1)]) == defaultdict( + set, + ) + + +def test_can_sort(mocker: MockerFixture, registry: AdapterLoader) -> None: + """ + Test the ``can_sort`` method. 
+ """ + mocker.patch("shillelagh.backends.multicorn.fdw.registry", registry) + + registry.add("dummy", FakeAdapter) + + wrapper = MulticornForeignDataWrapper( + {"adapter": "dummy", "args": "qQA="}, + {}, + ) + assert wrapper.can_sort([]) == [] + assert wrapper.can_sort( + [ + SortKey( + attname="age", + attnum=2, + is_reversed=True, + nulls_first=True, + collate=None, + ), + SortKey( + attname="foobar", + attnum=1, + is_reversed=True, + nulls_first=True, + collate=None, + ), + ], + ) == [ + SortKey( + attname="age", + attnum=2, + is_reversed=True, + nulls_first=True, + collate=None, + ), + ] + + +def test_insert(mocker: MockerFixture, registry: AdapterLoader) -> None: + """ + Test the ``insert`` method. + """ + mocker.patch("shillelagh.backends.multicorn.fdw.registry", registry) + + registry.add("dummy", FakeAdapter) + + wrapper = MulticornForeignDataWrapper( + {"adapter": "dummy", "args": "qQA="}, + {}, + ) + + wrapper.insert({"rowid": 2, "name": "Charlie", "age": 6, "pets": 1}) + assert list(wrapper.execute([], ["rowid", "name", "age", "pets"])) == [ + {"rowid": 0, "name": "Alice", "age": 20, "pets": 0}, + {"rowid": 1, "name": "Bob", "age": 23, "pets": 3}, + {"rowid": 2, "name": "Charlie", "age": 6, "pets": 1}, + ] + + +def test_delete(mocker: MockerFixture, registry: AdapterLoader) -> None: + """ + Test the ``delete`` method. + """ + mocker.patch("shillelagh.backends.multicorn.fdw.registry", registry) + + registry.add("dummy", FakeAdapter) + + wrapper = MulticornForeignDataWrapper( + {"adapter": "dummy", "args": "qQA="}, + {}, + ) + + wrapper.delete({"rowid": 1, "name": "Bob", "age": 23, "pets": 3}) + assert list(wrapper.execute([], ["rowid", "name", "age", "pets"])) == [ + {"rowid": 0, "name": "Alice", "age": 20, "pets": 0}, + ] + + +def test_update(mocker: MockerFixture, registry: AdapterLoader) -> None: + """ + Test the ``update`` method. 
+ """ + mocker.patch("shillelagh.backends.multicorn.fdw.registry", registry) + + registry.add("dummy", FakeAdapter) + + wrapper = MulticornForeignDataWrapper( + {"adapter": "dummy", "args": "qQA="}, + {}, + ) + + wrapper.update( + {"rowid": 0, "name": "Alice", "age": 20, "pets": 0}, + {"rowid": 0, "name": "Alice", "age": 20, "pets": 1}, + ) + assert list(wrapper.execute([], ["rowid", "name", "age", "pets"])) == [ + {"rowid": 1, "name": "Bob", "age": 23, "pets": 3}, + {"rowid": 0, "name": "Alice", "age": 20, "pets": 1}, + ] + + +def test_get_rel_Size(mocker: MockerFixture, registry: AdapterLoader) -> None: + """ + Test the ``get_rel_size`` method. + """ + mocker.patch("shillelagh.backends.multicorn.fdw.registry", registry) + + registry.add("dummy", FakeAdapter) + + wrapper = MulticornForeignDataWrapper( + {"adapter": "dummy", "args": "qQA="}, + {}, + ) + + assert wrapper.get_rel_size([Qual("age", ">", 21)], ["name", "age"]) == (666, 200)