Skip to content

Commit

Permalink
feat: Add support for Snowflake schemas.
Browse files Browse the repository at this point in the history
  • Loading branch information
DanCardin committed Feb 16, 2024
1 parent 42b861a commit 235eb57
Show file tree
Hide file tree
Showing 18 changed files with 999 additions and 274 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ jobs:
- name: Install dependencies
run: poetry install -E parse

- name: Install snowflake
- if: ${{ matrix.sqlalchemy-version == '1.4.0' }}
run: |
poetry run pip install 'snowflake-sqlalchemy'
- name: Install specific sqlalchemy version
run: |
poetry run pip install 'sqlalchemy~=${{ matrix.sqlalchemy-version }}'
Expand Down
225 changes: 126 additions & 99 deletions README.md

Large diffs are not rendered by default.

842 changes: 706 additions & 136 deletions poetry.lock

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sqlalchemy-declarative-extensions"
version = "0.6.10"
version = "0.7.0"
authors = ["Dan Cardin <ddcardin@gmail.com>"]

description = "Library to declare additional kinds of objects not natively supported by SQLAlchemy/Alembic."
Expand Down Expand Up @@ -48,14 +48,15 @@ psycopg = "*"
alembic-utils = "0.8.1"
black = ">=22.3.0"
coverage = ">=5"
ruff = "0.1.15"
fakesnow = {version = "*", python = ">=3.9"}
mypy = "1.4.1"
pymysql = {version = "*", extras = ["rsa"]}
pytest = ">=7"
pytest-xdist = "*"
pytest-alembic = "*"
pytest-mock-resources = { version = ">=2.6.13", extras = ["docker"] }
pytest-xdist = "*"
ruff = "0.1.15"
sqlalchemy = {version = ">=1.4", extras = ["mypy"]}
pymysql = {version = "*", extras = ["rsa"]}

[tool.poetry.extras]
alembic = ["alembic"]
Expand Down Expand Up @@ -92,6 +93,7 @@ filterwarnings = [
'error',
'ignore:invalid escape sequence.*',
'ignore:distutils Version classes are deprecated. Use packaging.version instead.:DeprecationWarning',
'ignore:_SixMetaPathImporter.find_spec.*:ImportWarning'
]
pytester_example_dir = "tests/examples"
markers = [
Expand Down
25 changes: 11 additions & 14 deletions src/sqlalchemy_declarative_extensions/alembic/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from alembic.autogenerate.api import AutogenContext
from alembic.autogenerate.compare import comparators
from alembic.autogenerate.render import renderers
from alembic.operations import Operations
from sqlalchemy.schema import CreateSchema, DropSchema

from sqlalchemy_declarative_extensions import schema
from sqlalchemy_declarative_extensions.schema.base import Schemas
Expand All @@ -15,30 +15,27 @@


@comparators.dispatch_for("schema")
def compare_schemas(autogen_context, upgrade_ops, schemas: Schemas):
def compare_schemas(autogen_context: AutogenContext, upgrade_ops, schemas: Schemas):
assert autogen_context.metadata
schemas = autogen_context.metadata.info.get("schemas")
if not schemas:
return

assert autogen_context.connection
result = schema.compare.compare_schemas(autogen_context.connection, schemas)
upgrade_ops.ops[0:0] = result


@renderers.dispatch_for(CreateSchemaOp)
def render_create_schema(_, op: CreateSchemaOp):
return f"op.create_schema('{op.schema.name}')"


@renderers.dispatch_for(DropSchemaOp)
def render_drop_schema(_, op: DropSchemaOp):
return f"op.drop_schema('{op.schema.name}')"
def render_create_schema(autogen_context: AutogenContext, op: CreateSchemaOp):
statement = op.to_sql()
cls_name = statement.__class__.__name__
autogen_context.imports.add(f"from sqlalchemy.sql.ddl import {cls_name}")
return f'op.execute({cls_name}("{statement.element}"))'


@Operations.implementation_for(CreateSchemaOp)
def create_schema(operations, operation: CreateSchemaOp):
operations.execute(CreateSchema(operation.schema.name))


@Operations.implementation_for(DropSchemaOp)
def drop_schema(operations, operation: DropSchemaOp):
operations.execute(DropSchema(operation.schema.name))
def create_schema(operations, operation: CreateSchemaOp):
operations.execute(operation.to_sql())
11 changes: 5 additions & 6 deletions src/sqlalchemy_declarative_extensions/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,11 @@ def register_sqlalchemy_events(
concrete_rows = metadata.info.get("rows")

if concrete_schemas and schemas:
for schema in concrete_schemas:
event.listen(
metadata,
"before_create",
schema_ddl(schema),
)
event.listen(
metadata,
"before_create",
schema_ddl,
)

if concrete_roles and roles:
event.listen(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_schemas_postgresql(connection: Connection):

def check_schema_exists_postgresql(connection: Connection, name: str) -> bool:
row = connection.execute(schema_exists_query, {"schema": name}).scalar()
return not bool(row)
return bool(row)


def get_objects_postgresql(connection: Connection):
Expand Down
8 changes: 8 additions & 0 deletions src/sqlalchemy_declarative_extensions/dialects/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,28 @@
get_view_postgresql,
get_views_postgresql,
)
from sqlalchemy_declarative_extensions.dialects.snowflake.query import (
check_schema_exists_snowflake,
get_schemas_snowflake,
)
from sqlalchemy_declarative_extensions.dialects.sqlite.query import (
check_schema_exists_sqlite,
get_schemas_sqlite,
get_views_sqlite,
)
from sqlalchemy_declarative_extensions.sqlalchemy import dialect_dispatch, select

get_schemas = dialect_dispatch(
postgresql=get_schemas_postgresql,
sqlite=get_schemas_sqlite,
snowflake=get_schemas_snowflake,
)

check_schema_exists = dialect_dispatch(
postgresql=check_schema_exists_postgresql,
sqlite=check_schema_exists_sqlite,
mysql=check_schema_exists_mysql,
snowflake=check_schema_exists_snowflake,
)

get_objects = dialect_dispatch(
Expand Down
Empty file.
28 changes: 28 additions & 0 deletions src/sqlalchemy_declarative_extensions/dialects/snowflake/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

from sqlalchemy.engine.base import Connection
from sqlalchemy.sql.expression import text


def get_schemas_snowflake(connection: Connection):
from sqlalchemy_declarative_extensions.schema.base import Schema

schemas_query = text(
"SELECT schema_name"
" FROM information_schema.schemata"
" WHERE schema_name NOT IN ('information_schema', 'pg_catalog', 'main')"
)

return {
Schema(schema) for schema, *_ in connection.execute(schemas_query).fetchall()
}


def check_schema_exists_snowflake(connection: Connection, name: str) -> bool:
schema_exists_query = text(
"SELECT schema_name"
" FROM information_schema.schemata"
" WHERE schema_name = :schema"
)
row = connection.execute(schema_exists_query, {"schema": name}).scalar()
return bool(row)
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
from sqlalchemy_declarative_extensions.view.base import View


def get_schemas_sqlite(connection: Connection):
from sqlalchemy_declarative_extensions.schema.base import Schema

schemas = connection.execute(text("PRAGMA database_list")).fetchall()
return {Schema(schema) for _, schema, *_ in schemas if schema not in {"main"}}


def check_schema_exists_sqlite(connection: Connection, name: str) -> bool:
"""Check whether the given schema exists.
Expand Down
2 changes: 2 additions & 0 deletions src/sqlalchemy_declarative_extensions/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from sqlalchemy_declarative_extensions.schema import compare
from sqlalchemy_declarative_extensions.schema.base import Schema, Schemas

__all__ = [
"Schema",
"Schemas",
"compare",
]
7 changes: 7 additions & 0 deletions src/sqlalchemy_declarative_extensions/schema/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Union

from sqlalchemy.engine.base import Connection
from sqlalchemy.sql.ddl import CreateSchema, DropSchema

from sqlalchemy_declarative_extensions.dialects import get_schemas
from sqlalchemy_declarative_extensions.schema.base import Schema, Schemas
Expand All @@ -21,6 +22,9 @@ def create_schema(cls, operations, schema, **kwargs):
def reverse(self):
return DropSchemaOp(self.schema)

def to_sql(self):
return CreateSchema(self.schema.name)


@dataclass
class DropSchemaOp:
Expand All @@ -34,6 +38,9 @@ def drop_schema(cls, operations, schema, **kwargs):
def reverse(self):
return CreateSchemaOp(self.schema)

def to_sql(self):
return DropSchema(self.schema.name)


SchemaOp = Union[CreateSchemaOp, DropSchemaOp]

Expand Down
22 changes: 13 additions & 9 deletions src/sqlalchemy_declarative_extensions/schema/ddl.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from sqlalchemy.schema import CreateSchema
from __future__ import annotations

from sqlalchemy_declarative_extensions.dialects import check_schema_exists
from sqlalchemy_declarative_extensions.schema import Schema
from sqlalchemy import MetaData
from sqlalchemy.engine import Connection

from sqlalchemy_declarative_extensions.schema import Schemas
from sqlalchemy_declarative_extensions.schema.compare import compare_schemas

def schema_ddl(schema: Schema):
ddl = CreateSchema(schema.name)
return ddl.execute_if(callable_=check_schema)

def schema_ddl(metadata: MetaData, connection: Connection, **_):
roles: Schemas | None = metadata.info.get("schemas")
if not roles: # pragma: no cover
return

def check_schema(ddl, target, connection, **_):
schema = ddl.element
return check_schema_exists(connection, name=schema)
result = compare_schemas(connection, roles)
for op in result:
statements = op.to_sql()
connection.execute(statements)
3 changes: 2 additions & 1 deletion src/sqlalchemy_declarative_extensions/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ class HasMetaData(Protocol):
metadata: MetaData


def dialect_dispatch(postgresql=None, sqlite=None, mysql=None):
def dialect_dispatch(postgresql=None, sqlite=None, mysql=None, snowflake=None):
dispatchers = {
"postgresql": postgresql,
"sqlite": sqlite,
"mysql": mysql,
"snowflake": snowflake,
}

def dispatch(connection: Connection, *args, **kwargs):
Expand Down
17 changes: 17 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,23 @@ def pmr_mysql_config():
return MysqlConfig(image="mysql:8", port=None, ci_port=None)


@pytest.fixture
def snowflake():
try:
import snowflake.sqlalchemy
except ImportError:
pytest.skip("Snowflake not installed")

import fakesnow
from sqlalchemy.engine.create import create_engine

with fakesnow.patch(
create_database_on_connect=True,
create_schema_on_connect=False,
):
yield create_engine("snowflake://test/test/information_schema")


@pytest.fixture(autouse=True)
def clear_registry():
"""Clear out state accumulated by importing alembic modules in env.pys.
Expand Down
45 changes: 45 additions & 0 deletions tests/schema/test_drop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from pytest_mock_resources import create_postgres_fixture, create_sqlite_fixture
from sqlalchemy.sql.ddl import CreateSchema

from sqlalchemy_declarative_extensions import (
Schemas,
declarative_database,
register_sqlalchemy_events,
)
from sqlalchemy_declarative_extensions.dialects import check_schema_exists
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base

_Base = declarative_base()


@declarative_database
class Base(_Base): # type: ignore
__abstract__ = True

schemas = Schemas()


register_sqlalchemy_events(Base.metadata, schemas=True)

pg = create_postgres_fixture(scope="function", engine_kwargs={"echo": True})
sqlite = create_sqlite_fixture()


def test_createall_schema_pg(pg):
with pg.begin() as conn:
conn.execute(CreateSchema("foo"))

Base.metadata.create_all(bind=pg)

with pg.connect() as conn:
assert check_schema_exists(conn, "foo") is False


def test_createall_schema_snowflake(snowflake):
with snowflake.begin() as conn:
conn.execute(CreateSchema("foo"))

Base.metadata.create_all(bind=snowflake)

with snowflake.connect() as conn:
assert check_schema_exists(conn, "foo") is False
14 changes: 10 additions & 4 deletions tests/schema/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
)
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base

()


_Base = declarative_base()


Expand All @@ -25,7 +22,9 @@ class Foo(Base):
__tablename__ = "foo"
__table_args__ = {"schema": "fooschema"}

id = sqlalchemy.Column(sqlalchemy.types.Integer(), primary_key=True)
id = sqlalchemy.Column(
sqlalchemy.types.Integer(), primary_key=True, autoincrement=False
)


register_sqlalchemy_events(Base.metadata, schemas=True)
Expand All @@ -46,3 +45,10 @@ def test_createall_schema_sqlite(sqlite):
with sqlite.connect() as conn:
result = conn.execute(Foo.__table__.select()).fetchall()
assert result == []


def test_createall_schema_snowflake(snowflake):
Base.metadata.create_all(bind=snowflake, checkfirst=False)
with snowflake.connect() as conn:
result = conn.execute(Foo.__table__.select()).fetchall()
assert result == []

0 comments on commit 235eb57

Please sign in to comment.