Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Handle metadata sequences for most object types. #92

Merged
merged 7 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@

## 0.15

### 0.15.2

fix: Handle trigger metadata sequence.
fix: Handle procedure metadata sequence.
fix: Handle function metadata sequence.
fix: Handle grant metadata sequence.
fix: Handle role metadata sequence.
fix: Handle schema metadata sequence.

### 0.15.1

- fix: Accept more generic sequence to roles.

### 0.15.0

- fix: Add role name coercion to postgres default grant `to` argument.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sqlalchemy-declarative-extensions"
version = "0.15.1"
version = "0.15.2"
authors = ["Dan Cardin <ddcardin@gmail.com>"]

description = "Library to declare additional kinds of objects not natively supported by SQLAlchemy/Alembic."
Expand Down
11 changes: 7 additions & 4 deletions src/sqlalchemy_declarative_extensions/alembic/function.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from alembic.autogenerate.api import AutogenContext
from alembic.autogenerate.compare import comparators
from alembic.autogenerate.render import renderers

from sqlalchemy_declarative_extensions.function.base import Functions
from sqlalchemy_declarative_extensions.function.compare import (
CreateFunctionOp,
DropFunctionOp,
Expand All @@ -12,13 +15,13 @@


@comparators.dispatch_for("schema")
def _compare_functions(autogen_context, upgrade_ops, _):
metadata = autogen_context.metadata
functions = metadata.info.get("functions")
def _compare_functions(autogen_context: AutogenContext, upgrade_ops, _):
functions: Functions | None = Functions.extract(autogen_context.metadata)
if not functions:
return

result = compare_functions(autogen_context.connection, functions, metadata)
assert autogen_context.connection
result = compare_functions(autogen_context.connection, functions)
upgrade_ops.ops.extend(result)


Expand Down
15 changes: 6 additions & 9 deletions src/sqlalchemy_declarative_extensions/alembic/grant.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from __future__ import annotations

from alembic.autogenerate.api import AutogenContext
from alembic.autogenerate.compare import comparators
Expand All @@ -9,6 +9,7 @@
from sqlalchemy_declarative_extensions.grant.base import Grants
from sqlalchemy_declarative_extensions.grant.compare import (
GrantPrivilegesOp,
Operation,
RevokePrivilegesOp,
)
from sqlalchemy_declarative_extensions.role.base import Roles
Expand All @@ -18,14 +19,14 @@

@comparators.dispatch_for("schema")
def compare_grants(autogen_context: AutogenContext, upgrade_ops: UpgradeOps, _):
if autogen_context.metadata is None or autogen_context.connection is None:
if autogen_context.connection is None:
return # pragma: no cover

grants: Optional[Grants] = autogen_context.metadata.info.get("grants")
grants: Grants | None = Grants.extract(autogen_context.metadata)
if not grants:
return

roles: Optional[Roles] = autogen_context.metadata.info.get("roles")
roles: Roles | None = Roles.extract(autogen_context.metadata)

result = compare.compare_grants(autogen_context.connection, grants, roles=roles)
if not result:
Expand All @@ -43,10 +44,6 @@ def compare_grants(autogen_context: AutogenContext, upgrade_ops: UpgradeOps, _):


@renderers.dispatch_for(GrantPrivilegesOp)
def render_grant(_, op: GrantPrivilegesOp):
return f'op.execute(sa.text("""{op.to_sql()}"""))'


@renderers.dispatch_for(RevokePrivilegesOp)
def render_revoke(_, op: RevokePrivilegesOp):
def render_grant(_, op: Operation):
return f'op.execute(sa.text("""{op.to_sql()}"""))'
11 changes: 7 additions & 4 deletions src/sqlalchemy_declarative_extensions/alembic/procedure.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from alembic.autogenerate.api import AutogenContext
from alembic.autogenerate.compare import comparators
from alembic.autogenerate.render import renderers

from sqlalchemy_declarative_extensions.procedure.base import Procedures
from sqlalchemy_declarative_extensions.procedure.compare import (
CreateProcedureOp,
DropProcedureOp,
Expand All @@ -12,13 +15,13 @@


@comparators.dispatch_for("schema")
def _compare_procedures(autogen_context, upgrade_ops, _):
metadata = autogen_context.metadata
procedures = metadata.info.get("procedures")
def _compare_procedures(autogen_context: AutogenContext, upgrade_ops, _):
procedures: Procedures | None = Procedures.extract(autogen_context.metadata)
if not procedures:
return

result = compare_procedures(autogen_context.connection, procedures, metadata)
assert autogen_context.connection
result = compare_procedures(autogen_context.connection, procedures)
upgrade_ops.ops.extend(result)


Expand Down
8 changes: 6 additions & 2 deletions src/sqlalchemy_declarative_extensions/alembic/role.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from alembic.autogenerate.api import AutogenContext
from alembic.autogenerate.compare import comparators
from alembic.autogenerate.render import renderers
Expand All @@ -6,6 +8,7 @@
from sqlalchemy_declarative_extensions.role.compare import (
CreateRoleOp,
DropRoleOp,
Roles,
UpdateRoleOp,
compare_roles,
)
Expand All @@ -16,11 +19,12 @@


@comparators.dispatch_for("schema")
def _compare_roles(autogen_context, upgrade_ops, _):
roles = autogen_context.metadata.info.get("roles")
def _compare_roles(autogen_context: AutogenContext, upgrade_ops, _):
roles: Roles | None = Roles.extract(autogen_context.metadata)
if not roles:
return

assert autogen_context.connection
result = compare_roles(autogen_context.connection, roles)
upgrade_ops.ops[0:0] = result

Expand Down
3 changes: 1 addition & 2 deletions src/sqlalchemy_declarative_extensions/alembic/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

@comparators.dispatch_for("schema")
def compare_schemas(autogen_context: AutogenContext, upgrade_ops, _):
assert autogen_context.metadata
schemas: Schemas | None = autogen_context.metadata.info.get("schemas")
schemas: Schemas | None = Schemas.extract(autogen_context.metadata)
if not schemas:
return

Expand Down
11 changes: 7 additions & 4 deletions src/sqlalchemy_declarative_extensions/alembic/trigger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from alembic.autogenerate.api import AutogenContext
from alembic.autogenerate.compare import comparators
from alembic.autogenerate.render import renderers

from sqlalchemy_declarative_extensions.trigger.base import Triggers
from sqlalchemy_declarative_extensions.trigger.compare import (
CreateTriggerOp,
DropTriggerOp,
Expand All @@ -11,13 +14,13 @@


@comparators.dispatch_for("schema")
def _compare_triggers(autogen_context, upgrade_ops, _):
metadata = autogen_context.metadata
triggers = metadata.info.get("triggers")
def _compare_triggers(autogen_context: AutogenContext, upgrade_ops, _):
triggers: Triggers | None = Triggers.extract(autogen_context.metadata)
if not triggers:
return

result = compare_triggers(autogen_context.connection, triggers, metadata)
assert autogen_context.connection
result = compare_triggers(autogen_context.connection, triggers)
upgrade_ops.ops.extend(result)


Expand Down
32 changes: 31 additions & 1 deletion src/sqlalchemy_declarative_extensions/function/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass, field, replace
from typing import Iterable
from typing import Iterable, Sequence

from sqlalchemy import MetaData
from typing_extensions import Self
Expand Down Expand Up @@ -92,6 +92,36 @@ def coerce_from_unknown(

return None

@classmethod
def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | None:
if not isinstance(metadata, Sequence):
metadata = [metadata]

instances: list[Self] = [
m.info["functions"] for m in metadata if m and m.info.get("functions")
]

instance_count = len(instances)
if instance_count == 0:
return None

if instance_count == 1:
return instances[0]

if not all(
x.ignore_unspecified == instances[0].ignore_unspecified for x in instances
):
raise ValueError(
"All combined `Functions` instances must agree on the set of settings: ignore_unspecified"
)

functions = [s for instance in instances for s in instance.functions]
ignore = [s for instance in instances for s in instance.ignore]
ignore_unspecified = instances[0].ignore_unspecified
return cls(
functions=functions, ignore_unspecified=ignore_unspecified, ignore=ignore
)

def append(self, function: Function):
self.functions.append(function)

Expand Down
7 changes: 1 addition & 6 deletions src/sqlalchemy_declarative_extensions/function/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from dataclasses import dataclass
from typing import Sequence, Union

from sqlalchemy import MetaData
from sqlalchemy.engine import Connection

from sqlalchemy_declarative_extensions.dialects import get_function_cls, get_functions
Expand Down Expand Up @@ -48,11 +47,7 @@ def to_sql(self) -> list[str]:
Operation = Union[CreateFunctionOp, UpdateFunctionOp, DropFunctionOp]


def compare_functions(
connection: Connection,
functions: Functions,
metadata: MetaData,
) -> list[Operation]:
def compare_functions(connection: Connection, functions: Functions) -> list[Operation]:
result: list[Operation] = []

functions_by_name = {f.qualified_name: f for f in functions.functions}
Expand Down
2 changes: 1 addition & 1 deletion src/sqlalchemy_declarative_extensions/function/ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

def function_ddl(functions: Functions, function_filter: list[str] | None = None):
def after_create(metadata: MetaData, connection: Connection, **_):
result = compare_functions(connection, functions, metadata)
result = compare_functions(connection, functions)
for op in result:
if not match_name(op.function.qualified_name, function_filter):
continue
Expand Down
48 changes: 47 additions & 1 deletion src/sqlalchemy_declarative_extensions/grant/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass, field, replace
from typing import Iterable, Union
from typing import Iterable, Sequence, Union

from sqlalchemy import MetaData
from typing_extensions import Self

from sqlalchemy_declarative_extensions.dialects import postgresql

Expand Down Expand Up @@ -60,6 +63,49 @@ def coerce_from_unknown(cls, unknown: None | Iterable[G] | Grants) -> Grants | N

return None

@classmethod
def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | None:
if not isinstance(metadata, Sequence):
metadata = [metadata]

instances: list[Self] = [
m.info["grants"] for m in metadata if m and m.info.get("grants")
]

instance_count = len(instances)
if instance_count == 0:
return None

if instance_count == 1:
return instances[0]

if not all(
x.ignore_unspecified == instances[0].ignore_unspecified
and x.ignore_self_grants == instances[0].ignore_self_grants
and x.only_defined_roles == instances[0].only_defined_roles
and x.default_grants_imply_grants
== instances[0].default_grants_imply_grants
for x in instances
):
raise ValueError(
"All combined `Grants` instances must agree on the set of settings: "
"ignore_unspecified, ignore_self_grants, only_defined_roles, default_grants_imply_grants"
)

grants = [s for instance in instances for s in instance.grants]

ignore_unspecified = instances[0].ignore_unspecified
ignore_self_grants = instances[0].ignore_self_grants
only_defined_roles = instances[0].only_defined_roles
default_grants_imply_grants = instances[0].default_grants_imply_grants
return cls(
grants=grants,
ignore_unspecified=ignore_unspecified,
ignore_self_grants=ignore_self_grants,
only_defined_roles=only_defined_roles,
default_grants_imply_grants=default_grants_imply_grants,
)

def __iter__(self):
yield from self.grants

Expand Down
32 changes: 31 additions & 1 deletion src/sqlalchemy_declarative_extensions/procedure/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass, field, replace
from typing import Iterable
from typing import Iterable, Sequence

from sqlalchemy import MetaData
from typing_extensions import Self
Expand Down Expand Up @@ -87,6 +87,36 @@ def coerce_from_unknown(

return None

@classmethod
def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | None:
if not isinstance(metadata, Sequence):
metadata = [metadata]

instances: list[Self] = [
m.info["procedures"] for m in metadata if m and m.info.get("procedures")
]

instance_count = len(instances)
if instance_count == 0:
return None

if instance_count == 1:
return instances[0]

if not all(
x.ignore_unspecified == instances[0].ignore_unspecified for x in instances
):
raise ValueError(
"All combined `Procedures` instances must agree on the set of settings: ignore_unspecified"
)

procedures = [s for instance in instances for s in instance.procedures]
ignore = [s for instance in instances for s in instance.ignore]
ignore_unspecified = instances[0].ignore_unspecified
return cls(
procedures=procedures, ignore_unspecified=ignore_unspecified, ignore=ignore
)

def append(self, procedure: Procedure):
self.procedures.append(procedure)

Expand Down
Loading
Loading