Skip to content

Commit

Permalink
Merge pull request #94 from DanCardin/dc/view-metadata-sequence
Browse files Browse the repository at this point in the history
fix: Handle view metadata sequence.
  • Loading branch information
DanCardin authored Oct 3, 2024
2 parents 2c329d8 + aec9ee4 commit ca32253
Show file tree
Hide file tree
Showing 16 changed files with 233 additions and 42 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sqlalchemy-declarative-extensions"
version = "0.15.2"
version = "0.15.3"
authors = ["Dan Cardin <ddcardin@gmail.com>"]

description = "Library to declare additional kinds of objects not natively supported by SQLAlchemy/Alembic."
Expand Down
15 changes: 6 additions & 9 deletions src/sqlalchemy_declarative_extensions/alembic/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from alembic.autogenerate.api import AutogenContext
from alembic.operations import Operations
from alembic.operations.ops import UpgradeOps
from sqlalchemy import MetaData

from sqlalchemy_declarative_extensions import row
from sqlalchemy_declarative_extensions.alembic.base import (
Expand All @@ -20,18 +21,14 @@


def compare_rows(autogen_context: AutogenContext, upgrade_ops: UpgradeOps, _):
if (
autogen_context.metadata is None or autogen_context.connection is None
): # pragma: no cover
optional_rows: tuple[Rows, MetaData] | None = Rows.extract(autogen_context.metadata)
if not optional_rows:
return

rows: Rows | None = autogen_context.metadata.info.get("rows")
if not rows:
return
rows, metadata = optional_rows

result = row.compare.compare_rows(
autogen_context.connection, autogen_context.metadata, rows
)
assert autogen_context.connection
result = row.compare.compare_rows(autogen_context.connection, metadata, rows)
upgrade_ops.ops.extend(result) # type: ignore


Expand Down
11 changes: 7 additions & 4 deletions src/sqlalchemy_declarative_extensions/alembic/view.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

from alembic.autogenerate.api import AutogenContext

from sqlalchemy_declarative_extensions.alembic.base import (
register_comparator_dispatcher,
register_renderer_dispatcher,
register_rewriter_dispatcher,
)
from sqlalchemy_declarative_extensions.view.base import Views
from sqlalchemy_declarative_extensions.view.compare import (
CreateViewOp,
DropViewOp,
Expand All @@ -14,13 +17,13 @@
)


def _compare_views(autogen_context, upgrade_ops, _):
metadata = autogen_context.metadata
views = metadata.info.get("views")
def _compare_views(autogen_context: AutogenContext, upgrade_ops, _):
views: Views | None = Views.extract(autogen_context.metadata)
if not views:
return

result = compare_views(autogen_context.connection, views, metadata)
assert autogen_context.connection
result = compare_views(autogen_context.connection, views)
upgrade_ops.ops.extend(result)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from dataclasses import dataclass, field, replace
from typing import Any, Literal

from sqlalchemy import MetaData
from sqlalchemy.engine import Connection, Dialect
from typing_extensions import override

Expand Down Expand Up @@ -121,9 +120,12 @@ def to_sql_create(self, dialect: Dialect) -> list[str]:
return result

def normalize(
self, conn: Connection, metadata: MetaData, using_connection: bool = True
self,
conn: Connection,
naming_convention: base.NamingConvention | None,
using_connection: bool = True,
) -> View:
instance = super().normalize(conn, metadata, using_connection)
instance = super().normalize(conn, naming_convention, using_connection)
return replace(
instance,
materialized=MaterializedOptions.from_value(self.materialized),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from dataclasses import dataclass, replace
from typing import Any

from sqlalchemy import MetaData
from sqlalchemy.engine import Connection, Dialect
from typing_extensions import override

Expand Down Expand Up @@ -52,9 +51,12 @@ def to_sql_create(self, dialect: Dialect) -> list[str]:
return result

def normalize(
self, conn: Connection, metadata: MetaData, using_connection: bool = True
self,
conn: Connection,
naming_convention: base.NamingConvention | None,
using_connection: bool = True,
) -> View:
result = super().normalize(conn, metadata, using_connection)
result = super().normalize(conn, naming_convention, using_connection)
return replace(
result,
schema=self.schema.upper() if self.schema else None,
Expand Down
29 changes: 28 additions & 1 deletion src/sqlalchemy_declarative_extensions/row/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass, field, replace
from typing import Any, Iterable
from typing import Any, Iterable, Sequence

from sqlalchemy import MetaData
from typing_extensions import Self

from sqlalchemy_declarative_extensions.sql import split_schema

Expand All @@ -22,6 +25,30 @@ def coerce_from_unknown(cls, unknown: None | Iterable[Row] | Rows) -> Rows | Non

return None

@classmethod
def extract(
cls, metadata: MetaData | list[MetaData | None] | None
) -> tuple[Self, MetaData] | None:
if not isinstance(metadata, Sequence):
metadata = [metadata]

instances: list[Self] = [
m.info["rows"] for m in metadata if m and m.info.get("rows")
]

instance_count = len(instances)
if instance_count == 0:
return None

if instance_count == 1:
metadata = metadata[0]
assert metadata
return instances[0], metadata

raise NotImplementedError(
"Rows is currently only supported on a single instance of MetaData. File an issue if this affects you!"
)

def __iter__(self):
yield from self.rows

Expand Down
79 changes: 71 additions & 8 deletions src/sqlalchemy_declarative_extensions/view/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,18 @@
import uuid
import warnings
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, TypeVar, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
TypeVar,
cast,
)

from sqlalchemy import Index, MetaData, UniqueConstraint, text
from sqlalchemy.engine import Connection, Dialect
Expand All @@ -29,6 +40,7 @@

T = TypeVar("T")
ViewType = TypeVar("ViewType", "View", "DeclarativeView")
NamingConvention = Dict[str, Any]


def view(
Expand Down Expand Up @@ -288,12 +300,15 @@ def render_constraints(self, *, create):
return result

def normalize(
self, conn: Connection, metadata: MetaData, using_connection: bool = True
self,
conn: Connection,
naming_convention: NamingConvention | None,
using_connection: bool = True,
) -> Self:
constraints = None
if self.constraints:
constraints = [
ViewIndex.from_unknown(c, self, conn.dialect, metadata)
ViewIndex.from_unknown(c, self, conn.dialect, naming_convention)
for c in self.constraints
]

Expand Down Expand Up @@ -365,7 +380,7 @@ def from_unknown(
index: ViewIndex | Index | UniqueConstraint,
source_view: View,
dialect: Dialect,
metadata: MetaData,
naming_convention: NamingConvention | None,
):
if isinstance(index, ViewIndex):
convention = "uq" if index.unique else "ix"
Expand All @@ -390,13 +405,13 @@ def from_unknown(
if instance.name:
return instance

naming_convention = metadata.naming_convention or DEFAULT_NAMING_CONVENTION
naming_convention = naming_convention or DEFAULT_NAMING_CONVENTION # type: ignore
assert naming_convention
assert "ix" in naming_convention
template = cast(
str, naming_convention.get(convention) or naming_convention["ix"]
)
cd = ConventionDict(
_ViewIndexAdapter(instance), source_view, metadata.naming_convention
)
cd = ConventionDict(_ViewIndexAdapter(instance), source_view, naming_convention)
conventionalized_name = conv(template % cd)

try:
Expand Down Expand Up @@ -504,6 +519,7 @@ class Views:

ignore: Iterable[str] = field(default_factory=set)
ignore_views: Iterable[str] = field(default_factory=set)
naming_convention: NamingConvention | None = None

@classmethod
def coerce_from_unknown(
Expand All @@ -517,6 +533,53 @@ def coerce_from_unknown(

return None

@classmethod
def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | None:
if not isinstance(metadata, Sequence):
metadata = [metadata]

naming_conventions = [m.naming_convention for m in metadata if m]
instances: list[Self] = [
m.info["views"] for m in metadata if m and m.info.get("views")
]

instance_count = len(instances)
if instance_count == 0:
return None

if instance_count == 1:
return instances[0]

if not all(
x.ignore_unspecified == instances[0].ignore_unspecified
and x.naming_convention == instances[0].naming_convention
for x in instances
):
raise ValueError(
"All combined `Views` instances must agree on the set of settings: ignore_unspecified, naming_convention"
)

views = [s for instance in instances for s in instance.views]
ignore = [s for instance in instances for s in instance.ignore]
ignore_views = [s for instance in instances for s in instance.ignore_views]

ignore_unspecified = instances[0].ignore_unspecified
naming_convention: NamingConvention = instances[0].naming_convention # type: ignore

if not naming_convention:
if not all(n == naming_conventions[0] for n in naming_conventions):
raise ValueError("All MetaData `naming_convention`s must agree")

naming_convention = naming_conventions[0] # type: ignore

return cls(
views=views,
ignore_unspecified=ignore_unspecified,
ignore=ignore,
ignore_views=ignore_views,
naming_convention=naming_convention,
)

def append(self, view: View | DeclarativeView):
self.views.append(view)

Expand Down
14 changes: 9 additions & 5 deletions src/sqlalchemy_declarative_extensions/view/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from fnmatch import fnmatch
from typing import Union

from sqlalchemy import MetaData
from sqlalchemy.engine import Connection, Dialect

from sqlalchemy_declarative_extensions.dialects import get_view_cls, get_views
Expand Down Expand Up @@ -52,7 +51,6 @@ def to_sql(self, dialect: Dialect) -> list[str]:
def compare_views(
connection: Connection,
views: Views,
metadata: MetaData,
normalize_with_connection: bool = True,
) -> list[Operation]:
if views.ignore_views:
Expand All @@ -79,7 +77,9 @@ def compare_views(
removed_view_names = existing_view_names - expected_view_names

for view in concrete_defined_views:
normalized_view = view.normalize(connection, metadata, using_connection=False)
normalized_view = view.normalize(
connection, views.naming_convention, using_connection=False
)

view_name = normalized_view.qualified_name

Expand All @@ -95,12 +95,16 @@ def compare_views(
result.append(CreateViewOp(normalized_view))
else:
normalized_view = normalized_view.normalize(
connection, metadata, using_connection=normalize_with_connection
connection,
views.naming_convention,
using_connection=normalize_with_connection,
)

existing_view = existing_views_by_name[view_name]
normalized_existing_view = existing_view.normalize(
connection, metadata, using_connection=normalize_with_connection
connection,
views.naming_convention,
using_connection=normalize_with_connection,
)

if normalized_existing_view != normalized_view:
Expand Down
4 changes: 1 addition & 3 deletions src/sqlalchemy_declarative_extensions/view/ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@

def view_ddl(views: Views, view_filter: list[str] | None = None):
def after_create(metadata: MetaData, connection: Connection, **_):
result = compare_views(
connection, views, metadata, normalize_with_connection=False
)
result = compare_views(connection, views, normalize_with_connection=False)
for op in result:
if not match_name(op.view.qualified_name, view_filter):
continue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ def test_apply_autogenerated_revision(alembic_runner: MigrationContext, alembic_

# Now a comparison should yield no results, because the view def has not changed.
with alembic_engine.connect() as conn:
result = compare_views(conn, views=Base.views, metadata=Base.metadata)
result = compare_views(conn, views=Base.views)
assert result == []
32 changes: 32 additions & 0 deletions tests/row/test_metadata_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
import sqlalchemy

from sqlalchemy_declarative_extensions import (
Row,
Rows,
declare_database,
)

metadata1 = sqlalchemy.MetaData()
metadata2 = sqlalchemy.MetaData()
metadata3 = sqlalchemy.MetaData()

declare_database(metadata1, rows=Rows().are(Row("foo")))
declare_database(metadata2, rows=Rows().are(Row("bar")))


def test_invalid_combination():
with pytest.raises(NotImplementedError):
Rows.extract([metadata1, metadata1])


def test_single():
rows = Rows.extract(metadata1)
assert rows
assert rows[0] is metadata1.info["rows"]
assert rows[1] is metadata1

rows = Rows.extract([metadata1, metadata3])
assert rows
assert rows[0] is metadata1.info["rows"]
assert rows[1] is metadata1
Loading

0 comments on commit ca32253

Please sign in to comment.