diff --git a/pyproject.toml b/pyproject.toml index dbed4ff..3c936ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sqlalchemy-declarative-extensions" -version = "0.8.2" +version = "0.8.3" authors = ["Dan Cardin "] description = "Library to declare additional kinds of objects not natively supported by SQLAlchemy/Alembic." diff --git a/src/sqlalchemy_declarative_extensions/dialects/sqlite/query.py b/src/sqlalchemy_declarative_extensions/dialects/sqlite/query.py index 4df14ca..bb62fca 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/sqlite/query.py +++ b/src/sqlalchemy_declarative_extensions/dialects/sqlite/query.py @@ -26,7 +26,9 @@ def check_schema_exists_sqlite(connection: Connection, name: str) -> bool: def get_views_sqlite(connection: Connection): + schemas = get_schemas_sqlite(connection) return [ View(v.name, v.definition, schema=v.schema) - for v in connection.execute(views_query()).fetchall() + for schema in [*schemas, None] + for v in connection.execute(views_query(schema and schema.name)).fetchall() ] diff --git a/src/sqlalchemy_declarative_extensions/dialects/sqlite/schema.py b/src/sqlalchemy_declarative_extensions/dialects/sqlite/schema.py index a00c8a2..07273cd 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/sqlite/schema.py +++ b/src/sqlalchemy_declarative_extensions/dialects/sqlite/schema.py @@ -1,27 +1,19 @@ from typing import Optional -from sqlalchemy import column, literal, table +from sqlalchemy import bindparam, text -from sqlalchemy_declarative_extensions.sqlalchemy import select - -def make_sqlite_schema(schema: Optional[str] = None): - tablename = "sqlite_schema" +def views_query(schema: Optional[str] = None): + tablename = "sqlite_master" if schema: tablename = f"{schema}.{tablename}" - return table( - tablename, - column("type"), - column("name"), - column("sql"), - ) - - -def views_query(schema: Optional[str] = None): - sqlite_schema = make_sqlite_schema(schema) - return select( - literal(None), - sqlite_schema.c.name.label("name"), - sqlite_schema.c.sql.label("definition"), - ).where(sqlite_schema.c.type == "view") + return text( + "SELECT" # noqa: S608 + " :schema AS schema," + " name AS name," + " sql AS definition," + " false as materialized" + f" FROM {tablename}" + " WHERE type == 'view'", + ).bindparams(bindparam("schema", schema)) diff --git a/src/sqlalchemy_declarative_extensions/view/base.py b/src/sqlalchemy_declarative_extensions/view/base.py index 2504720..78b70aa 100644 --- a/src/sqlalchemy_declarative_extensions/view/base.py +++ b/src/sqlalchemy_declarative_extensions/view/base.py @@ -22,9 +22,7 @@ T = TypeVar("T") -def view( - base: T, materialized: bool = False, register_as_model=False -) -> Callable[[type], T]: +def view(base, materialized: bool = False, register_as_model=False) -> Callable[[T], T]: """Decorate a class or declarative base model in order to register a View. Given some object with the attributes: `__tablename__`, (optionally for schema) `__table_args__`, @@ -212,6 +210,11 @@ def render_definition(self, conn: Connection, using_connection: bool = True): dialect_name_map = {"postgresql": "postgres"} dialect_name = dialect_name_map.get(dialect.name, dialect.name) + + # aiosqlite, pmrsqlite, etc + if "sqlite" in dialect_name: + dialect_name = "sqlite" + return ( escape_params( normalize( diff --git a/tests/view/test_sqlalchemy.py b/tests/view/test_sqlalchemy.py index 1423c05..c0cfa33 100644 --- a/tests/view/test_sqlalchemy.py +++ b/tests/view/test_sqlalchemy.py @@ -74,6 +74,11 @@ def test_create_view_mysql(mysql): run_test(mysql) +@skip_sqlalchemy13 +def test_create_view_sqlite(sqlite): + run_test(sqlite) + + def run_test(session): Base.metadata.create_all(bind=session.connection()) session.commit() diff --git a/tests/view/test_update.py b/tests/view/test_update.py new file mode 100644 index 0000000..ac7cb45 --- /dev/null +++ b/tests/view/test_update.py @@ -0,0 +1,69 @@ +from pytest_mock_resources import ( + create_postgres_fixture, + create_sqlite_fixture, +) +from sqlalchemy import Column, text, types + +from sqlalchemy_declarative_extensions import ( + Row, + Rows, + View, + declarative_database, + register_sqlalchemy_events, + register_view, +) +from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base + +_Base = declarative_base() + + +@declarative_database +class Base(_Base): # type: ignore + __abstract__ = True + + rows = Rows().are( + Row("foo", id=1), + Row("foo", id=2), + Row("foo", id=12), + Row("foo", id=13), + ) + + +class Foo(Base): + __tablename__ = "foo" + + id = Column(types.Integer(), primary_key=True) + + +view = View("bar", "select id from foo where id < 10") +register_view(Base.metadata, view) + + +register_sqlalchemy_events(Base.metadata, schemas=True, views=True, rows=True) + +pg = create_postgres_fixture( + scope="function", engine_kwargs={"echo": True}, session=True +) +sqlite = create_sqlite_fixture(scope="function", session=True) + + +def test_create_view_postgresql(pg): + run_test(pg) + + +def test_create_view_sqlite(sqlite): + run_test(sqlite) + + +def run_test(session): + session.execute(text("CREATE TABLE foo (id integer)")) + session.execute(text("CREATE VIEW bar AS SELECT id FROM foo WHERE id = 1")) + session.execute(text("INSERT INTO foo (id) VALUES (1), (2), (12), (13)")) + + result = [f.id for f in session.execute(text("SELECT id from bar")).fetchall()] + assert result == [1] + + Base.metadata.create_all(bind=session.connection()) + + result = [f.id for f in session.execute(text("SELECT id from bar")).fetchall()] + assert result == [1, 2] diff --git a/tests/view/test_view_in_schema.py b/tests/view/test_view_in_schema.py new file mode 100644 index 0000000..b3c9e0e --- /dev/null +++ b/tests/view/test_view_in_schema.py @@ -0,0 +1,87 @@ +from pytest_mock_resources import ( + create_postgres_fixture, + create_sqlite_fixture, +) +from sqlalchemy import Column, text, types + +from sqlalchemy_declarative_extensions import ( + Row, + Rows, + Schemas, + View, + declarative_database, + register_sqlalchemy_events, + register_view, +) +from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base + +_Base = declarative_base() + + +@declarative_database +class Base(_Base): # type: ignore + __abstract__ = True + + schemas = Schemas().are("fooschema") + rows = Rows().are( + Row("fooschema.foo", id=1), + Row("fooschema.foo", id=2), + Row("fooschema.foo", id=12), + Row("fooschema.foo", id=13), + ) + + +class Foo(Base): + __tablename__ = "foo" + __table_args__ = {"schema": "fooschema"} + + id = Column(types.Integer(), primary_key=True) + + +# Register imperitively +view = View( + "bar", + "select id from fooschema.foo where id < 10", + schema="fooschema", +) + +register_view(Base.metadata, view) + + +register_sqlalchemy_events(Base.metadata, schemas=True, views=True, rows=True) + +pg = create_postgres_fixture( + scope="function", engine_kwargs={"echo": True}, session=True +) +sqlite = create_sqlite_fixture(scope="function", session=True) + + +def test_create_view_postgresql(pg): + pg.execute(text("CREATE SCHEMA fooschema")) + run_test(pg) + + +def test_create_view_sqlite(sqlite): + sqlite.execute(text("ATTACH DATABASE ':memory:' AS fooschema")) + run_test(sqlite) + + +def run_test(session): + session.execute(text("CREATE TABLE fooschema.foo (id integer)")) + session.execute( + text("CREATE VIEW fooschema.bar AS SELECT id FROM fooschema.foo WHERE id = 1") + ) + session.execute(text("INSERT INTO fooschema.foo (id) VALUES (1), (2), (12), (13)")) + session.commit() + + result = [ + f.id for f in session.execute(text("SELECT id from fooschema.bar")).fetchall() + ] + assert result == [1] + + Base.metadata.create_all(bind=session.connection()) + + result = [ + f.id for f in session.execute(text("SELECT id from fooschema.bar")).fetchall() + ] + assert result == [1, 2]