diff --git a/pyproject.toml b/pyproject.toml index de970e9..a81e96d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sqlalchemy-declarative-extensions" -version = "0.15.2" +version = "0.15.3" authors = ["Dan Cardin "] description = "Library to declare additional kinds of objects not natively supported by SQLAlchemy/Alembic." diff --git a/src/sqlalchemy_declarative_extensions/alembic/row.py b/src/sqlalchemy_declarative_extensions/alembic/row.py index 2f5c7eb..1e35a64 100644 --- a/src/sqlalchemy_declarative_extensions/alembic/row.py +++ b/src/sqlalchemy_declarative_extensions/alembic/row.py @@ -3,6 +3,7 @@ from alembic.autogenerate.api import AutogenContext from alembic.operations import Operations from alembic.operations.ops import UpgradeOps +from sqlalchemy import MetaData from sqlalchemy_declarative_extensions import row from sqlalchemy_declarative_extensions.alembic.base import ( @@ -20,18 +21,14 @@ def compare_rows(autogen_context: AutogenContext, upgrade_ops: UpgradeOps, _): - if ( - autogen_context.metadata is None or autogen_context.connection is None - ): # pragma: no cover + optional_rows: tuple[Rows, MetaData] | None = Rows.extract(autogen_context.metadata) + if not optional_rows: return - rows: Rows | None = autogen_context.metadata.info.get("rows") - if not rows: - return + rows, metadata = optional_rows - result = row.compare.compare_rows( - autogen_context.connection, autogen_context.metadata, rows - ) + assert autogen_context.connection + result = row.compare.compare_rows(autogen_context.connection, metadata, rows) upgrade_ops.ops.extend(result) # type: ignore diff --git a/src/sqlalchemy_declarative_extensions/alembic/view.py b/src/sqlalchemy_declarative_extensions/alembic/view.py index 79aae9f..9b6512c 100644 --- a/src/sqlalchemy_declarative_extensions/alembic/view.py +++ b/src/sqlalchemy_declarative_extensions/alembic/view.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from alembic.autogenerate.api import AutogenContext from sqlalchemy_declarative_extensions.alembic.base import ( @@ -5,6 +7,7 @@ register_renderer_dispatcher, register_rewriter_dispatcher, ) +from sqlalchemy_declarative_extensions.view.base import Views from sqlalchemy_declarative_extensions.view.compare import ( CreateViewOp, DropViewOp, @@ -14,13 +17,13 @@ ) -def _compare_views(autogen_context, upgrade_ops, _): - metadata = autogen_context.metadata - views = metadata.info.get("views") +def _compare_views(autogen_context: AutogenContext, upgrade_ops, _): + views: Views | None = Views.extract(autogen_context.metadata) if not views: return - result = compare_views(autogen_context.connection, views, metadata) + assert autogen_context.connection + result = compare_views(autogen_context.connection, views) upgrade_ops.ops.extend(result) diff --git a/src/sqlalchemy_declarative_extensions/dialects/postgresql/view.py b/src/sqlalchemy_declarative_extensions/dialects/postgresql/view.py index 059e8b0..31b1a67 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/postgresql/view.py +++ b/src/sqlalchemy_declarative_extensions/dialects/postgresql/view.py @@ -3,7 +3,6 @@ from dataclasses import dataclass, field, replace from typing import Any, Literal -from sqlalchemy import MetaData from sqlalchemy.engine import Connection, Dialect from typing_extensions import override @@ -121,9 +120,12 @@ def to_sql_create(self, dialect: Dialect) -> list[str]: return result def normalize( - self, conn: Connection, metadata: MetaData, using_connection: bool = True + self, + conn: Connection, + naming_convention: base.NamingConvention | None, + using_connection: bool = True, ) -> View: - instance = super().normalize(conn, metadata, using_connection) + instance = super().normalize(conn, naming_convention, using_connection) return replace( instance, materialized=MaterializedOptions.from_value(self.materialized), diff --git a/src/sqlalchemy_declarative_extensions/dialects/snowflake/view.py b/src/sqlalchemy_declarative_extensions/dialects/snowflake/view.py index fc8c25f..be81b6a 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/snowflake/view.py +++ b/src/sqlalchemy_declarative_extensions/dialects/snowflake/view.py @@ -3,7 +3,6 @@ from dataclasses import dataclass, replace from typing import Any -from sqlalchemy import MetaData from sqlalchemy.engine import Connection, Dialect from typing_extensions import override @@ -52,9 +51,12 @@ def to_sql_create(self, dialect: Dialect) -> list[str]: return result def normalize( - self, conn: Connection, metadata: MetaData, using_connection: bool = True + self, + conn: Connection, + naming_convention: base.NamingConvention | None, + using_connection: bool = True, ) -> View: - result = super().normalize(conn, metadata, using_connection) + result = super().normalize(conn, naming_convention, using_connection) return replace( result, schema=self.schema.upper() if self.schema else None, diff --git a/src/sqlalchemy_declarative_extensions/row/base.py b/src/sqlalchemy_declarative_extensions/row/base.py index e3d90ba..6ef4677 100644 --- a/src/sqlalchemy_declarative_extensions/row/base.py +++ b/src/sqlalchemy_declarative_extensions/row/base.py @@ -1,7 +1,10 @@ from __future__ import annotations from dataclasses import dataclass, field, replace -from typing import Any, Iterable +from typing import Any, Iterable, Sequence + +from sqlalchemy import MetaData +from typing_extensions import Self from sqlalchemy_declarative_extensions.sql import split_schema @@ -22,6 +25,30 @@ def coerce_from_unknown(cls, unknown: None | Iterable[Row] | Rows) -> Rows | Non return None + @classmethod + def extract( + cls, metadata: MetaData | list[MetaData | None] | None + ) -> tuple[Self, MetaData] | None: + if not isinstance(metadata, Sequence): + metadata = [metadata] + + instances: list[Self] = [ + m.info["rows"] for m in metadata if m and m.info.get("rows") + ] + + instance_count = len(instances) + if instance_count == 0: + return None + + if instance_count == 1: + metadata = metadata[0] + assert metadata + return instances[0], metadata + + raise NotImplementedError( + "Rows is currently only supported on a single instance of MetaData. File an issue if this affects you!" + ) + def __iter__(self): yield from self.rows diff --git a/src/sqlalchemy_declarative_extensions/view/base.py b/src/sqlalchemy_declarative_extensions/view/base.py index 7873653..69aac15 100644 --- a/src/sqlalchemy_declarative_extensions/view/base.py +++ b/src/sqlalchemy_declarative_extensions/view/base.py @@ -4,7 +4,18 @@ import uuid import warnings from dataclasses import dataclass, field, replace -from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, TypeVar, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + TypeVar, + cast, +) from sqlalchemy import Index, MetaData, UniqueConstraint, text from sqlalchemy.engine import Connection, Dialect @@ -29,6 +40,7 @@ T = TypeVar("T") ViewType = TypeVar("ViewType", "View", "DeclarativeView") +NamingConvention = Dict[str, Any] def view( @@ -288,12 +300,15 @@ def render_constraints(self, *, create): return result def normalize( - self, conn: Connection, metadata: MetaData, using_connection: bool = True + self, + conn: Connection, + naming_convention: NamingConvention | None, + using_connection: bool = True, ) -> Self: constraints = None if self.constraints: constraints = [ - ViewIndex.from_unknown(c, self, conn.dialect, metadata) + ViewIndex.from_unknown(c, self, conn.dialect, naming_convention) for c in self.constraints ] @@ -365,7 +380,7 @@ def from_unknown( index: ViewIndex | Index | UniqueConstraint, source_view: View, dialect: Dialect, - metadata: MetaData, + naming_convention: NamingConvention | None, ): if isinstance(index, ViewIndex): convention = "uq" if index.unique else "ix" @@ -390,13 +405,13 @@ def from_unknown( if instance.name: return instance - naming_convention = metadata.naming_convention or DEFAULT_NAMING_CONVENTION + naming_convention = naming_convention or DEFAULT_NAMING_CONVENTION # type: ignore + assert naming_convention + assert "ix" in naming_convention template = cast( str, naming_convention.get(convention) or naming_convention["ix"] ) - cd = ConventionDict( - _ViewIndexAdapter(instance), source_view, metadata.naming_convention - ) + cd = ConventionDict(_ViewIndexAdapter(instance), source_view, naming_convention) conventionalized_name = conv(template % cd) try: @@ -504,6 +519,7 @@ class Views: ignore: Iterable[str] = field(default_factory=set) ignore_views: Iterable[str] = field(default_factory=set) + naming_convention: NamingConvention | None = None @classmethod def coerce_from_unknown( @@ -517,6 +533,53 @@ def coerce_from_unknown( return None + @classmethod + def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | None: + if not isinstance(metadata, Sequence): + metadata = [metadata] + + naming_conventions = [m.naming_convention for m in metadata if m] + instances: list[Self] = [ + m.info["views"] for m in metadata if m and m.info.get("views") + ] + + instance_count = len(instances) + if instance_count == 0: + return None + + if instance_count == 1: + return instances[0] + + if not all( + x.ignore_unspecified == instances[0].ignore_unspecified + and x.naming_convention == instances[0].naming_convention + for x in instances + ): + raise ValueError( + "All combined `Views` instances must agree on the set of settings: ignore_unspecified, naming_convention" + ) + + views = [s for instance in instances for s in instance.views] + ignore = [s for instance in instances for s in instance.ignore] + ignore_views = [s for instance in instances for s in instance.ignore_views] + + ignore_unspecified = instances[0].ignore_unspecified + naming_convention: NamingConvention = instances[0].naming_convention # type: ignore + + if not naming_convention: + if not all(n == naming_conventions[0] for n in naming_conventions): + raise ValueError("All MetaData `naming_convention`s must agree") + + naming_convention = naming_conventions[0] # type: ignore + + return cls( + views=views, + ignore_unspecified=ignore_unspecified, + ignore=ignore, + ignore_views=ignore_views, + naming_convention=naming_convention, + ) + def append(self, view: View | DeclarativeView): self.views.append(view) diff --git a/src/sqlalchemy_declarative_extensions/view/compare.py b/src/sqlalchemy_declarative_extensions/view/compare.py index 4385d4a..1232c98 100644 --- a/src/sqlalchemy_declarative_extensions/view/compare.py +++ b/src/sqlalchemy_declarative_extensions/view/compare.py @@ -5,7 +5,6 @@ from fnmatch import fnmatch from typing import Union -from sqlalchemy import MetaData from sqlalchemy.engine import Connection, Dialect from sqlalchemy_declarative_extensions.dialects import get_view_cls, get_views @@ -52,7 +51,6 @@ def to_sql(self, dialect: Dialect) -> list[str]: def compare_views( connection: Connection, views: Views, - metadata: MetaData, normalize_with_connection: bool = True, ) -> list[Operation]: if views.ignore_views: @@ -79,7 +77,9 @@ def compare_views( removed_view_names = existing_view_names - expected_view_names for view in concrete_defined_views: - normalized_view = view.normalize(connection, metadata, using_connection=False) + normalized_view = view.normalize( + connection, views.naming_convention, using_connection=False + ) view_name = normalized_view.qualified_name @@ -95,12 +95,16 @@ def compare_views( result.append(CreateViewOp(normalized_view)) else: normalized_view = normalized_view.normalize( - connection, metadata, using_connection=normalize_with_connection + connection, + views.naming_convention, + using_connection=normalize_with_connection, ) existing_view = existing_views_by_name[view_name] normalized_existing_view = existing_view.normalize( - connection, metadata, using_connection=normalize_with_connection + connection, + views.naming_convention, + using_connection=normalize_with_connection, ) if normalized_existing_view != normalized_view: diff --git a/src/sqlalchemy_declarative_extensions/view/ddl.py b/src/sqlalchemy_declarative_extensions/view/ddl.py index 005cfc7..dc0f43c 100644 --- a/src/sqlalchemy_declarative_extensions/view/ddl.py +++ b/src/sqlalchemy_declarative_extensions/view/ddl.py @@ -10,9 +10,7 @@ def view_ddl(views: Views, view_filter: list[str] | None = None): def after_create(metadata: MetaData, connection: Connection, **_): - result = compare_views( - connection, views, metadata, normalize_with_connection=False - ) + result = compare_views(connection, views, normalize_with_connection=False) for op in result: if not match_name(op.view.qualified_name, view_filter): continue diff --git a/tests/examples/test_view_complex_comparison_pg/test_migrations.py b/tests/examples/test_view_complex_comparison_pg/test_migrations.py index ec4f8ad..2d1bb78 100644 --- a/tests/examples/test_view_complex_comparison_pg/test_migrations.py +++ b/tests/examples/test_view_complex_comparison_pg/test_migrations.py @@ -22,5 +22,5 @@ def test_apply_autogenerated_revision(alembic_runner: MigrationContext, alembic_ # Now a comparison should yield no results, because the view def has not changed. with alembic_engine.connect() as conn: - result = compare_views(conn, views=Base.views, metadata=Base.metadata) + result = compare_views(conn, views=Base.views) assert result == [] diff --git a/tests/row/test_metadata_sequence.py b/tests/row/test_metadata_sequence.py new file mode 100644 index 0000000..92f3fd5 --- /dev/null +++ b/tests/row/test_metadata_sequence.py @@ -0,0 +1,32 @@ +import pytest +import sqlalchemy + +from sqlalchemy_declarative_extensions import ( + Row, + Rows, + declare_database, +) + +metadata1 = sqlalchemy.MetaData() +metadata2 = sqlalchemy.MetaData() +metadata3 = sqlalchemy.MetaData() + +declare_database(metadata1, rows=Rows().are(Row("foo"))) +declare_database(metadata2, rows=Rows().are(Row("bar"))) + + +def test_invalid_combination(): + with pytest.raises(NotImplementedError): + Rows.extract([metadata1, metadata1]) + + +def test_single(): + rows = Rows.extract(metadata1) + assert rows + assert rows[0] is metadata1.info["rows"] + assert rows[1] is metadata1 + + rows = Rows.extract([metadata1, metadata3]) + assert rows + assert rows[0] is metadata1.info["rows"] + assert rows[1] is metadata1 diff --git a/tests/view/test_constraint_drops_first.py b/tests/view/test_constraint_drops_first.py index 8671dcc..3c22f37 100644 --- a/tests/view/test_constraint_drops_first.py +++ b/tests/view/test_constraint_drops_first.py @@ -50,7 +50,7 @@ def test_constraint_changes(pg): views = Base.metadata.info["views"] connection = pg.connection() - result = compare_views(connection, views, Base.metadata) + result = compare_views(connection, views) sql_statements = result[0].to_sql(connection.dialect) assert len(sql_statements) == 2 diff --git a/tests/view/test_ignore_unspecified.py b/tests/view/test_ignore_unspecified.py index 46e480d..766c56e 100644 --- a/tests/view/test_ignore_unspecified.py +++ b/tests/view/test_ignore_unspecified.py @@ -42,5 +42,5 @@ def test_ignore_views(pg): # Verify this no longer sees changes to make! Failing here would imply the autogenerate # is not fully normalizing the difference. - result = compare_views(pg.connection(), views=Base.views, metadata=Base.metadata) + result = compare_views(pg.connection(), views=Base.views) assert result == [] diff --git a/tests/view/test_metadata_sequence.py b/tests/view/test_metadata_sequence.py new file mode 100644 index 0000000..3e7f682 --- /dev/null +++ b/tests/view/test_metadata_sequence.py @@ -0,0 +1,63 @@ +import pytest +import sqlalchemy +from sqlalchemy.sql.schema import DEFAULT_NAMING_CONVENTION + +from sqlalchemy_declarative_extensions import ( + View, + Views, + declare_database, +) + +metadata1 = sqlalchemy.MetaData() +metadata2 = sqlalchemy.MetaData() +metadata3 = sqlalchemy.MetaData() +metadata4 = sqlalchemy.MetaData() +metadata5 = sqlalchemy.MetaData(naming_convention={"ix": "asdf"}) + +declare_database(metadata1, views=Views(ignore=["one"]).are(View("foo", ""))) +declare_database(metadata2, views=Views(ignore_views=["two"]).are(View("bar", ""))) +declare_database(metadata3, views=Views(ignore_unspecified=True).are(View("baz", ""))) +declare_database( + metadata4, + views=Views(naming_convention={"ix": "asdf"}).are(View("baz", "")), +) +declare_database(metadata5, views=Views().are(View("bax", ""))) + + +def test_invalid_combination(): + with pytest.raises(ValueError): + Views.extract([metadata1, metadata3]) + + with pytest.raises(ValueError): + Views.extract([metadata1, metadata4]) + + with pytest.raises(ValueError): + Views.extract([metadata1, metadata5]) + + +def test_valid_combination(): + views = Views.extract([metadata1, metadata2]) + assert views == Views( + views=[View("foo", ""), View("bar", "")], + ignore=["one"], + ignore_views=["two"], + naming_convention=DEFAULT_NAMING_CONVENTION, + ) + + +def test_single(): + views = Views.extract(metadata5) + assert views + assert views is metadata5.info["views"] + assert views.naming_convention is None + + +def test_naming_convention_fallback(): + metadatat1 = sqlalchemy.MetaData(naming_convention={"ix": "asdf"}) + metadatat2 = sqlalchemy.MetaData(naming_convention={"ix": "asdf"}) + declare_database(metadatat1, views=Views()) + declare_database(metadatat2, views=Views()) + + views = Views.extract([metadatat1, metadatat2]) + assert views + assert views.naming_convention == {"ix": "asdf"} diff --git a/tests/view/test_only_constraint_changes.py b/tests/view/test_only_constraint_changes.py index 46c7700..145b83e 100644 --- a/tests/view/test_only_constraint_changes.py +++ b/tests/view/test_only_constraint_changes.py @@ -66,7 +66,7 @@ def test_constraint_changes(pg): views = Base.metadata.info["views"] connection = pg.connection() - result = compare_views(connection, views, Base.metadata) + result = compare_views(connection, views) assert len(result) == 1 sql_statements = result[0].to_sql(connection.dialect) diff --git a/tests/view/test_only_different_constraints_drop.py b/tests/view/test_only_different_constraints_drop.py index 350f5c5..90ad403 100644 --- a/tests/view/test_only_different_constraints_drop.py +++ b/tests/view/test_only_different_constraints_drop.py @@ -70,7 +70,7 @@ def test_constraint_changes(pg): views = Base.metadata.info["views"] connection = pg.connection() - result = compare_views(connection, views, Base.metadata) + result = compare_views(connection, views) assert len(result) == 1 sql_statements = result[0].to_sql(connection.dialect)