Skip to content

Commit

Permalink
feat: allow preloading of extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
Mause committed Aug 19, 2022
1 parent 2e200bb commit 13a92e1
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 6 deletions.
36 changes: 31 additions & 5 deletions duckdb_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import duckdb
from sqlalchemy import pool
from sqlalchemy import String, pool
from sqlalchemy import types as sqltypes
from sqlalchemy import util
from sqlalchemy.dialects.postgresql.base import PGInspector, PGTypeCompiler
Expand All @@ -10,6 +10,7 @@
from sqlalchemy.ext.compiler import compiles

from . import datatypes
from .config import get_core_config

__version__ = "0.5.0"

Expand All @@ -33,6 +34,9 @@ class DBAPI:
# this is being fixed upstream to add a proper exception hierarchy
Error = getattr(duckdb, "Error", RuntimeError)

IOException = getattr(duckdb, "IOException", RuntimeError)
CatalogException = getattr(duckdb, "CatalogException", RuntimeError)

@staticmethod
def Binary(x: Any) -> Any:
return x
Expand Down Expand Up @@ -144,12 +148,34 @@ class Dialect(PGDialect_psycopg2):
},
)

def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(self, **kwargs: Any) -> None:
kwargs["use_native_hstore"] = False
super().__init__(*args, **kwargs)
super().__init__(**kwargs)

def connect(self, *cargs: Any, **cparams: Any) -> "Connection":
return ConnectionWrapper(duckdb.connect(*cargs, **cparams))

core_keys = get_core_config()
preload_extensions = cparams.pop("preload_extensions", [])
config = cparams.get("config", {})

ext = {k: config.pop(k) for k in list(config) if k not in core_keys}

conn = duckdb.connect(*cargs, **cparams)

for extension in preload_extensions:
try:
conn.execute(f"LOAD {extension}")
except self.dbapi().IOException:
pass

for k, v in ext.items():
v = String().literal_processor(dialect=self)(v)
try:
conn.execute(f"SET {k} = {v}")
except self.dbapi().CatalogException:
pass

return ConnectionWrapper(conn)

def on_connect(self) -> None:
pass
Expand Down Expand Up @@ -189,7 +215,7 @@ def get_view_names(
connection: Any,
schema: Optional[Any] = ...,
include: Any = ...,
**kw: Any
**kw: Any,
) -> Any:
s = "SELECT name FROM sqlite_master WHERE type='view' ORDER BY name"
rs = connection.exec_driver_sql(s)
Expand Down
14 changes: 14 additions & 0 deletions duckdb_engine/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from functools import lru_cache
from typing import Set

import duckdb


@lru_cache()
def get_core_config() -> Set[str]:
rows = (
duckdb.connect(":memory:")
.execute("SELECT name FROM duckdb_settings()")
.fetchall()
)
return {name for name, in rows}
17 changes: 16 additions & 1 deletion duckdb_engine/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import duckdb
from hypothesis import assume, given, settings
from hypothesis.strategies import text as text_strat
from pytest import LogCaptureFixture, fixture, importorskip, mark, raises
from pytest import LogCaptureFixture, fixture, importorskip, mark, raises, skip
from sqlalchemy import (
Column,
ForeignKey,
Expand Down Expand Up @@ -151,6 +151,21 @@ def test_get_views(engine: Engine) -> None:
assert views == ["test"]


def test_preload_extension() -> None:
try:
duckdb.default_connection.execute("INSTALL https")
except Exception as e:
skip(str(e))
engine = create_engine(
"duckdb:///",
connect_args={
"preload_extensions": ["httpfs"],
"config": {"s3_region": "ap-southeast-2"},
},
)
engine.connect()


@fixture
def inspector(engine: Engine, session: Session) -> Inspector:
session.execute(text("create table test (id int);"))
Expand Down

0 comments on commit 13a92e1

Please sign in to comment.