diff --git a/CHANGELOG.md b/CHANGELOG.md index 70274f7..d00dd6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,19 @@ ## 0.15 +### 0.15.2 + +fix: Handle trigger metadata sequence. +fix: Handle procedure metadata sequence. +fix: Handle function metadata sequence. +fix: Handle grant metadata sequence. +fix: Handle role metadata sequence. +fix: Handle schema metadata sequence. + +### 0.15.1 + +- fix: Accept more generic sequence to roles. + ### 0.15.0 - fix: Add role name coercion to postgres default grant `to` argument. diff --git a/pyproject.toml b/pyproject.toml index f7c7e3b..de970e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sqlalchemy-declarative-extensions" -version = "0.15.1" +version = "0.15.2" authors = ["Dan Cardin "] description = "Library to declare additional kinds of objects not natively supported by SQLAlchemy/Alembic." diff --git a/src/sqlalchemy_declarative_extensions/alembic/function.py b/src/sqlalchemy_declarative_extensions/alembic/function.py index 5f3adaf..8b0e25e 100644 --- a/src/sqlalchemy_declarative_extensions/alembic/function.py +++ b/src/sqlalchemy_declarative_extensions/alembic/function.py @@ -1,7 +1,10 @@ +from __future__ import annotations + from alembic.autogenerate.api import AutogenContext from alembic.autogenerate.compare import comparators from alembic.autogenerate.render import renderers +from sqlalchemy_declarative_extensions.function.base import Functions from sqlalchemy_declarative_extensions.function.compare import ( CreateFunctionOp, DropFunctionOp, @@ -12,13 +15,13 @@ @comparators.dispatch_for("schema") -def _compare_functions(autogen_context, upgrade_ops, _): - metadata = autogen_context.metadata - functions = metadata.info.get("functions") +def _compare_functions(autogen_context: AutogenContext, upgrade_ops, _): + functions: Functions | None = Functions.extract(autogen_context.metadata) if not functions: return - result = compare_functions(autogen_context.connection, functions, metadata) + assert autogen_context.connection + result = compare_functions(autogen_context.connection, functions) upgrade_ops.ops.extend(result) diff --git a/src/sqlalchemy_declarative_extensions/alembic/grant.py b/src/sqlalchemy_declarative_extensions/alembic/grant.py index 306f0f4..fa36958 100644 --- a/src/sqlalchemy_declarative_extensions/alembic/grant.py +++ b/src/sqlalchemy_declarative_extensions/alembic/grant.py @@ -1,4 +1,4 @@ -from typing import Optional +from __future__ import annotations from alembic.autogenerate.api import AutogenContext from alembic.autogenerate.compare import comparators @@ -9,6 +9,7 @@ from sqlalchemy_declarative_extensions.grant.base import Grants from sqlalchemy_declarative_extensions.grant.compare import ( GrantPrivilegesOp, + Operation, RevokePrivilegesOp, ) from sqlalchemy_declarative_extensions.role.base import Roles @@ -18,14 +19,14 @@ @comparators.dispatch_for("schema") def compare_grants(autogen_context: AutogenContext, upgrade_ops: UpgradeOps, _): - if autogen_context.metadata is None or autogen_context.connection is None: + if autogen_context.connection is None: return # pragma: no cover - grants: Optional[Grants] = autogen_context.metadata.info.get("grants") + grants: Grants | None = Grants.extract(autogen_context.metadata) if not grants: return - roles: Optional[Roles] = autogen_context.metadata.info.get("roles") + roles: Roles | None = Roles.extract(autogen_context.metadata) result = compare.compare_grants(autogen_context.connection, grants, roles=roles) if not result: @@ -43,10 +44,6 @@ def compare_grants(autogen_context: AutogenContext, upgrade_ops: UpgradeOps, _): @renderers.dispatch_for(GrantPrivilegesOp) -def render_grant(_, op: GrantPrivilegesOp): - return f'op.execute(sa.text("""{op.to_sql()}"""))' - - @renderers.dispatch_for(RevokePrivilegesOp) -def render_revoke(_, op: RevokePrivilegesOp): +def render_grant(_, op: Operation): return f'op.execute(sa.text("""{op.to_sql()}"""))' diff --git a/src/sqlalchemy_declarative_extensions/alembic/procedure.py b/src/sqlalchemy_declarative_extensions/alembic/procedure.py index 8e42cf3..500024c 100644 --- a/src/sqlalchemy_declarative_extensions/alembic/procedure.py +++ b/src/sqlalchemy_declarative_extensions/alembic/procedure.py @@ -1,7 +1,10 @@ +from __future__ import annotations + from alembic.autogenerate.api import AutogenContext from alembic.autogenerate.compare import comparators from alembic.autogenerate.render import renderers +from sqlalchemy_declarative_extensions.procedure.base import Procedures from sqlalchemy_declarative_extensions.procedure.compare import ( CreateProcedureOp, DropProcedureOp, @@ -12,13 +15,13 @@ @comparators.dispatch_for("schema") -def _compare_procedures(autogen_context, upgrade_ops, _): - metadata = autogen_context.metadata - procedures = metadata.info.get("procedures") +def _compare_procedures(autogen_context: AutogenContext, upgrade_ops, _): + procedures: Procedures | None = Procedures.extract(autogen_context.metadata) if not procedures: return - result = compare_procedures(autogen_context.connection, procedures, metadata) + assert autogen_context.connection + result = compare_procedures(autogen_context.connection, procedures) upgrade_ops.ops.extend(result) diff --git a/src/sqlalchemy_declarative_extensions/alembic/role.py b/src/sqlalchemy_declarative_extensions/alembic/role.py index 38717ae..310dc46 100644 --- a/src/sqlalchemy_declarative_extensions/alembic/role.py +++ b/src/sqlalchemy_declarative_extensions/alembic/role.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from alembic.autogenerate.api import AutogenContext from alembic.autogenerate.compare import comparators from alembic.autogenerate.render import renderers @@ -6,6 +8,7 @@ from sqlalchemy_declarative_extensions.role.compare import ( CreateRoleOp, DropRoleOp, + Roles, UpdateRoleOp, compare_roles, ) @@ -16,11 +19,12 @@ @comparators.dispatch_for("schema") -def _compare_roles(autogen_context, upgrade_ops, _): - roles = autogen_context.metadata.info.get("roles") +def _compare_roles(autogen_context: AutogenContext, upgrade_ops, _): + roles: Roles | None = Roles.extract(autogen_context.metadata) if not roles: return + assert autogen_context.connection result = compare_roles(autogen_context.connection, roles) upgrade_ops.ops[0:0] = result diff --git a/src/sqlalchemy_declarative_extensions/alembic/schema.py b/src/sqlalchemy_declarative_extensions/alembic/schema.py index 55b9f15..f77947d 100644 --- a/src/sqlalchemy_declarative_extensions/alembic/schema.py +++ b/src/sqlalchemy_declarative_extensions/alembic/schema.py @@ -19,8 +19,7 @@ @comparators.dispatch_for("schema") def compare_schemas(autogen_context: AutogenContext, upgrade_ops, _): - assert autogen_context.metadata - schemas: Schemas | None = autogen_context.metadata.info.get("schemas") + schemas: Schemas | None = Schemas.extract(autogen_context.metadata) if not schemas: return diff --git a/src/sqlalchemy_declarative_extensions/alembic/trigger.py b/src/sqlalchemy_declarative_extensions/alembic/trigger.py index 306941f..aa8b830 100644 --- a/src/sqlalchemy_declarative_extensions/alembic/trigger.py +++ b/src/sqlalchemy_declarative_extensions/alembic/trigger.py @@ -1,7 +1,10 @@ +from __future__ import annotations + from alembic.autogenerate.api import AutogenContext from alembic.autogenerate.compare import comparators from alembic.autogenerate.render import renderers +from sqlalchemy_declarative_extensions.trigger.base import Triggers from sqlalchemy_declarative_extensions.trigger.compare import ( CreateTriggerOp, DropTriggerOp, @@ -11,13 +14,13 @@ @comparators.dispatch_for("schema") -def _compare_triggers(autogen_context, upgrade_ops, _): - metadata = autogen_context.metadata - triggers = metadata.info.get("triggers") +def _compare_triggers(autogen_context: AutogenContext, upgrade_ops, _): + triggers: Triggers | None = Triggers.extract(autogen_context.metadata) if not triggers: return - result = compare_triggers(autogen_context.connection, triggers, metadata) + assert autogen_context.connection + result = compare_triggers(autogen_context.connection, triggers) upgrade_ops.ops.extend(result) diff --git a/src/sqlalchemy_declarative_extensions/function/base.py b/src/sqlalchemy_declarative_extensions/function/base.py index 0843d09..d797f4d 100644 --- a/src/sqlalchemy_declarative_extensions/function/base.py +++ b/src/sqlalchemy_declarative_extensions/function/base.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field, replace -from typing import Iterable +from typing import Iterable, Sequence from sqlalchemy import MetaData from typing_extensions import Self @@ -92,6 +92,36 @@ def coerce_from_unknown( return None + @classmethod + def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | None: + if not isinstance(metadata, Sequence): + metadata = [metadata] + + instances: list[Self] = [ + m.info["functions"] for m in metadata if m and m.info.get("functions") + ] + + instance_count = len(instances) + if instance_count == 0: + return None + + if instance_count == 1: + return instances[0] + + if not all( + x.ignore_unspecified == instances[0].ignore_unspecified for x in instances + ): + raise ValueError( + "All combined `Functions` instances must agree on the set of settings: ignore_unspecified" + ) + + functions = [s for instance in instances for s in instance.functions] + ignore = [s for instance in instances for s in instance.ignore] + ignore_unspecified = instances[0].ignore_unspecified + return cls( + functions=functions, ignore_unspecified=ignore_unspecified, ignore=ignore + ) + def append(self, function: Function): self.functions.append(function) diff --git a/src/sqlalchemy_declarative_extensions/function/compare.py b/src/sqlalchemy_declarative_extensions/function/compare.py index 888219e..385f187 100644 --- a/src/sqlalchemy_declarative_extensions/function/compare.py +++ b/src/sqlalchemy_declarative_extensions/function/compare.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from typing import Sequence, Union -from sqlalchemy import MetaData from sqlalchemy.engine import Connection from sqlalchemy_declarative_extensions.dialects import get_function_cls, get_functions @@ -48,11 +47,7 @@ def to_sql(self) -> list[str]: Operation = Union[CreateFunctionOp, UpdateFunctionOp, DropFunctionOp] -def compare_functions( - connection: Connection, - functions: Functions, - metadata: MetaData, -) -> list[Operation]: +def compare_functions(connection: Connection, functions: Functions) -> list[Operation]: result: list[Operation] = [] functions_by_name = {f.qualified_name: f for f in functions.functions} diff --git a/src/sqlalchemy_declarative_extensions/function/ddl.py b/src/sqlalchemy_declarative_extensions/function/ddl.py index a2775cb..d17cbe2 100644 --- a/src/sqlalchemy_declarative_extensions/function/ddl.py +++ b/src/sqlalchemy_declarative_extensions/function/ddl.py @@ -10,7 +10,7 @@ def function_ddl(functions: Functions, function_filter: list[str] | None = None): def after_create(metadata: MetaData, connection: Connection, **_): - result = compare_functions(connection, functions, metadata) + result = compare_functions(connection, functions) for op in result: if not match_name(op.function.qualified_name, function_filter): continue diff --git a/src/sqlalchemy_declarative_extensions/grant/base.py b/src/sqlalchemy_declarative_extensions/grant/base.py index 53abd83..ce5de65 100644 --- a/src/sqlalchemy_declarative_extensions/grant/base.py +++ b/src/sqlalchemy_declarative_extensions/grant/base.py @@ -1,7 +1,10 @@ from __future__ import annotations from dataclasses import dataclass, field, replace -from typing import Iterable, Union +from typing import Iterable, Sequence, Union + +from sqlalchemy import MetaData +from typing_extensions import Self from sqlalchemy_declarative_extensions.dialects import postgresql @@ -60,6 +63,49 @@ def coerce_from_unknown(cls, unknown: None | Iterable[G] | Grants) -> Grants | N return None + @classmethod + def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | None: + if not isinstance(metadata, Sequence): + metadata = [metadata] + + instances: list[Self] = [ + m.info["grants"] for m in metadata if m and m.info.get("grants") + ] + + instance_count = len(instances) + if instance_count == 0: + return None + + if instance_count == 1: + return instances[0] + + if not all( + x.ignore_unspecified == instances[0].ignore_unspecified + and x.ignore_self_grants == instances[0].ignore_self_grants + and x.only_defined_roles == instances[0].only_defined_roles + and x.default_grants_imply_grants + == instances[0].default_grants_imply_grants + for x in instances + ): + raise ValueError( + "All combined `Grants` instances must agree on the set of settings: " + "ignore_unspecified, ignore_self_grants, only_defined_roles, default_grants_imply_grants" + ) + + grants = [s for instance in instances for s in instance.grants] + + ignore_unspecified = instances[0].ignore_unspecified + ignore_self_grants = instances[0].ignore_self_grants + only_defined_roles = instances[0].only_defined_roles + default_grants_imply_grants = instances[0].default_grants_imply_grants + return cls( + grants=grants, + ignore_unspecified=ignore_unspecified, + ignore_self_grants=ignore_self_grants, + only_defined_roles=only_defined_roles, + default_grants_imply_grants=default_grants_imply_grants, + ) + def __iter__(self): yield from self.grants diff --git a/src/sqlalchemy_declarative_extensions/procedure/base.py b/src/sqlalchemy_declarative_extensions/procedure/base.py index 620fad6..1961142 100644 --- a/src/sqlalchemy_declarative_extensions/procedure/base.py +++ b/src/sqlalchemy_declarative_extensions/procedure/base.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field, replace -from typing import Iterable +from typing import Iterable, Sequence from sqlalchemy import MetaData from typing_extensions import Self @@ -87,6 +87,36 @@ def coerce_from_unknown( return None + @classmethod + def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | None: + if not isinstance(metadata, Sequence): + metadata = [metadata] + + instances: list[Self] = [ + m.info["procedures"] for m in metadata if m and m.info.get("procedures") + ] + + instance_count = len(instances) + if instance_count == 0: + return None + + if instance_count == 1: + return instances[0] + + if not all( + x.ignore_unspecified == instances[0].ignore_unspecified for x in instances + ): + raise ValueError( + "All combined `Procedures` instances must agree on the set of settings: ignore_unspecified" + ) + + procedures = [s for instance in instances for s in instance.procedures] + ignore = [s for instance in instances for s in instance.ignore] + ignore_unspecified = instances[0].ignore_unspecified + return cls( + procedures=procedures, ignore_unspecified=ignore_unspecified, ignore=ignore + ) + def append(self, procedure: Procedure): self.procedures.append(procedure) diff --git a/src/sqlalchemy_declarative_extensions/procedure/compare.py b/src/sqlalchemy_declarative_extensions/procedure/compare.py index eb53341..1c4579b 100644 --- a/src/sqlalchemy_declarative_extensions/procedure/compare.py +++ b/src/sqlalchemy_declarative_extensions/procedure/compare.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from typing import Sequence, Union -from sqlalchemy import MetaData from sqlalchemy.engine import Connection from sqlalchemy_declarative_extensions.dialects import get_procedure_cls, get_procedures @@ -49,9 +48,7 @@ def to_sql(self) -> list[str]: def compare_procedures( - connection: Connection, - procedures: Procedures, - metadata: MetaData, + connection: Connection, procedures: Procedures ) -> list[Operation]: result: list[Operation] = [] diff --git a/src/sqlalchemy_declarative_extensions/procedure/ddl.py b/src/sqlalchemy_declarative_extensions/procedure/ddl.py index ec4bb9b..88e05f3 100644 --- a/src/sqlalchemy_declarative_extensions/procedure/ddl.py +++ b/src/sqlalchemy_declarative_extensions/procedure/ddl.py @@ -10,7 +10,7 @@ def procedure_ddl(procedures: Procedures, procedure_filter: list[str] | None = None): def after_create(metadata: MetaData, connection: Connection, **_): - result = compare_procedures(connection, procedures, metadata) + result = compare_procedures(connection, procedures) for op in result: if not match_name(op.procedure.qualified_name, procedure_filter): continue diff --git a/src/sqlalchemy_declarative_extensions/role/base.py b/src/sqlalchemy_declarative_extensions/role/base.py index 5136f61..cd56ba4 100644 --- a/src/sqlalchemy_declarative_extensions/role/base.py +++ b/src/sqlalchemy_declarative_extensions/role/base.py @@ -3,6 +3,9 @@ from dataclasses import dataclass, field, replace from typing import Generator, Iterable, Sequence +from sqlalchemy import MetaData +from typing_extensions import Self + from sqlalchemy_declarative_extensions.dialects import postgresql from sqlalchemy_declarative_extensions.role import generic @@ -26,6 +29,39 @@ def coerce_from_unknown( return None + @classmethod + def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | None: + if not isinstance(metadata, Sequence): + metadata = [metadata] + + instances: list[Self] = [ + m.info["roles"] for m in metadata if m and m.info.get("roles") + ] + + instance_count = len(instances) + if instance_count == 0: + return None + + if instance_count == 1: + return instances[0] + + if not all( + x.ignore_unspecified == instances[0].ignore_unspecified for x in instances + ): + raise ValueError( + "All combined `Roles` instances must agree on the set of settings: ignore_unspecified" + ) + + roles = tuple(s for instance in instances for s in instance.roles) + ignore_roles = [s for instance in instances for s in instance.ignore_roles] + + ignore_unspecified = instances[0].ignore_unspecified + return cls( + roles=roles, + ignore_unspecified=ignore_unspecified, + ignore_roles=ignore_roles, + ) + def __iter__(self) -> Generator[postgresql.Role | generic.Role, None, None]: yield from self.roles diff --git a/src/sqlalchemy_declarative_extensions/schema/base.py b/src/sqlalchemy_declarative_extensions/schema/base.py index c94f0c6..beca450 100644 --- a/src/sqlalchemy_declarative_extensions/schema/base.py +++ b/src/sqlalchemy_declarative_extensions/schema/base.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, replace from typing import TYPE_CHECKING, Iterable, Sequence +from sqlalchemy import MetaData from sqlalchemy.sql.base import Executable from sqlalchemy.sql.ddl import CreateSchema, DropSchema from typing_extensions import Self @@ -52,6 +53,33 @@ def coerce_from_unknown( return None + @classmethod + def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | None: + if not isinstance(metadata, Sequence): + metadata = [metadata] + + instances: list[Self] = [ + m.info["schemas"] for m in metadata if m and m.info.get("schemas") + ] + + instance_count = len(instances) + if instance_count == 0: + return None + + if instance_count == 1: + return instances[0] + + if not all( + x.ignore_unspecified == instances[0].ignore_unspecified for x in instances + ): + raise ValueError( + "All combined `Schemas` instances must agree on the set of settings: ignore_unspecified" + ) + + schemas = tuple(s for instance in instances for s in instance.schemas) + ignore_unspecified = instances[0].ignore_unspecified + return cls(schemas=schemas, ignore_unspecified=ignore_unspecified) + def __iter__(self): yield from self.schemas diff --git a/src/sqlalchemy_declarative_extensions/trigger/base.py b/src/sqlalchemy_declarative_extensions/trigger/base.py index 65359ea..b557f88 100644 --- a/src/sqlalchemy_declarative_extensions/trigger/base.py +++ b/src/sqlalchemy_declarative_extensions/trigger/base.py @@ -1,10 +1,11 @@ from __future__ import annotations from dataclasses import dataclass, field, replace -from typing import Iterable +from typing import Iterable, Sequence from sqlalchemy import MetaData from sqlalchemy.engine import Connection +from typing_extensions import Self from sqlalchemy_declarative_extensions.sqlalchemy import HasMetaData @@ -51,6 +52,33 @@ def coerce_from_unknown( return None + @classmethod + def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | None: + if not isinstance(metadata, Sequence): + metadata = [metadata] + + instances: list[Self] = [ + m.info["triggers"] for m in metadata if m and m.info.get("triggers") + ] + + instance_count = len(instances) + if instance_count == 0: + return None + + if instance_count == 1: + return instances[0] + + if not all( + x.ignore_unspecified == instances[0].ignore_unspecified for x in instances + ): + raise ValueError( + "All combined `Triggers` instances must agree on the set of settings: ignore_unspecified" + ) + + triggers = [s for instance in instances for s in instance.triggers] + ignore_unspecified = instances[0].ignore_unspecified + return cls(triggers=triggers, ignore_unspecified=ignore_unspecified) + def append(self, trigger: Trigger): self.triggers.append(trigger) diff --git a/src/sqlalchemy_declarative_extensions/trigger/compare.py b/src/sqlalchemy_declarative_extensions/trigger/compare.py index 5c691c0..3b846e0 100644 --- a/src/sqlalchemy_declarative_extensions/trigger/compare.py +++ b/src/sqlalchemy_declarative_extensions/trigger/compare.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from typing import Union -from sqlalchemy import MetaData from sqlalchemy.engine import Connection from sqlalchemy_declarative_extensions.dialects import get_triggers @@ -47,9 +46,7 @@ def to_sql(self, _) -> list[str]: Operation = Union[CreateTriggerOp, UpdateTriggerOp, DropTriggerOp] -def compare_triggers( - connection: Connection, triggers: Triggers, metadata: MetaData -) -> list[Operation]: +def compare_triggers(connection: Connection, triggers: Triggers) -> list[Operation]: result: list[Operation] = [] triggers_by_name = {r.name: r for r in triggers.triggers} diff --git a/src/sqlalchemy_declarative_extensions/trigger/ddl.py b/src/sqlalchemy_declarative_extensions/trigger/ddl.py index 75a7194..0a5d3bb 100644 --- a/src/sqlalchemy_declarative_extensions/trigger/ddl.py +++ b/src/sqlalchemy_declarative_extensions/trigger/ddl.py @@ -10,7 +10,7 @@ def trigger_ddl(triggers: Triggers, trigger_filter: list[str] | None = None): def after_create(metadata: MetaData, connection: Connection, **_): - result = compare_triggers(connection, triggers, metadata) + result = compare_triggers(connection, triggers) for op in result: if not match_name(op.trigger.name, trigger_filter): continue diff --git a/tests/examples/test_function_leading_whitespace/test_migrations.py b/tests/examples/test_function_leading_whitespace/test_migrations.py index 1f4d272..98eae28 100644 --- a/tests/examples/test_function_leading_whitespace/test_migrations.py +++ b/tests/examples/test_function_leading_whitespace/test_migrations.py @@ -20,9 +20,5 @@ def test_apply_autogenerated_revision(alembic_runner: MigrationContext, alembic_ alembic_runner.migrate_up_one() with alembic_engine.connect() as conn: - result = compare_functions( - conn, - Base.metadata.info["functions"], - Base.metadata, - ) + result = compare_functions(conn, Base.metadata.info["functions"]) assert result == [] diff --git a/tests/function/test_create.py b/tests/function/test_create.py index df34d6b..8268d41 100644 --- a/tests/function/test_create.py +++ b/tests/function/test_create.py @@ -48,5 +48,5 @@ def test_create(pg): assert result == 2 connection = pg.connection() - diff = compare_functions(connection, Base.metadata.info["functions"], Base.metadata) + diff = compare_functions(connection, Base.metadata.info["functions"]) assert diff == [] diff --git a/tests/function/test_create_mysql.py b/tests/function/test_create_mysql.py index 39ec86d..bf4c4cd 100644 --- a/tests/function/test_create_mysql.py +++ b/tests/function/test_create_mysql.py @@ -55,5 +55,5 @@ def test_create(db): assert result == 2 connection = db.connection() - diff = compare_functions(connection, Base.metadata.info["functions"], Base.metadata) + diff = compare_functions(connection, Base.metadata.info["functions"]) assert diff == [] diff --git a/tests/function/test_create_requires_normalization.py b/tests/function/test_create_requires_normalization.py index 9998b5d..dd0867c 100644 --- a/tests/function/test_create_requires_normalization.py +++ b/tests/function/test_create_requires_normalization.py @@ -55,5 +55,5 @@ def test_create_with_complex_function_requiring_normalization(pg): assert result == 1 connection = pg.connection() - diff = compare_functions(connection, Base.metadata.info["functions"], Base.metadata) + diff = compare_functions(connection, Base.metadata.info["functions"]) assert diff == [] diff --git a/tests/function/test_metadata_sequence.py b/tests/function/test_metadata_sequence.py new file mode 100644 index 0000000..0f07e5f --- /dev/null +++ b/tests/function/test_metadata_sequence.py @@ -0,0 +1,37 @@ +import pytest +import sqlalchemy + +from sqlalchemy_declarative_extensions import ( + Function, + Functions, + declare_database, +) + +metadata1 = sqlalchemy.MetaData() +metadata2 = sqlalchemy.MetaData() +metadata3 = sqlalchemy.MetaData() + +declare_database( + metadata1, functions=Functions(ignore=["one"]).are(Function("foo", "")) +) +declare_database( + metadata2, functions=Functions(ignore=["two"]).are(Function("bar", "")) +) +declare_database(metadata3, functions=Functions(ignore_unspecified=True)) + + +def test_invalid_combination(): + with pytest.raises(ValueError): + Functions.extract([metadata1, metadata3]) + + +def test_valid_combination(): + functions = Functions.extract([metadata1, metadata2]) + assert functions == Functions( + functions=[Function("foo", ""), Function("bar", "")], ignore=["one", "two"] + ) + + +def test_single(): + functions = Functions.extract(metadata1) + assert functions is metadata1.info["functions"] diff --git a/tests/function/test_pg_security_definer.py b/tests/function/test_pg_security_definer.py index d381673..c849d88 100644 --- a/tests/function/test_pg_security_definer.py +++ b/tests/function/test_pg_security_definer.py @@ -101,5 +101,5 @@ def test_function_security(pg): assert result == 3 connection = pg.connection() - diff = compare_functions(connection, Base.metadata.info["functions"], Base.metadata) + diff = compare_functions(connection, Base.metadata.info["functions"]) assert diff == [] diff --git a/tests/grant/test_metadata_sequence.py b/tests/grant/test_metadata_sequence.py new file mode 100644 index 0000000..0e80b31 --- /dev/null +++ b/tests/grant/test_metadata_sequence.py @@ -0,0 +1,75 @@ +import pytest +import sqlalchemy + +from sqlalchemy_declarative_extensions import ( + Grants, + declare_database, +) +from sqlalchemy_declarative_extensions.dialects.postgresql import Grant + +metadata1 = sqlalchemy.MetaData() +metadata2 = sqlalchemy.MetaData() +metadata3 = sqlalchemy.MetaData() +metadata4 = sqlalchemy.MetaData() +metadata5 = sqlalchemy.MetaData() +metadata6 = sqlalchemy.MetaData() + +declare_database( + metadata1, grants=Grants().are(Grant.new("select", to="foo").on_tables()) +) +declare_database( + metadata2, grants=Grants().are(Grant.new("select", to="bar").on_tables()) +) +declare_database( + metadata3, + grants=Grants(ignore_unspecified=True).are( + Grant.new("select", to="baz").on_tables() + ), +) +declare_database( + metadata4, + grants=Grants(ignore_self_grants=False).are( + Grant.new("select", to="baz").on_tables() + ), +) +declare_database( + metadata5, + grants=Grants(only_defined_roles=False).are( + Grant.new("select", to="baz").on_tables() + ), +) +declare_database( + metadata6, + grants=Grants(default_grants_imply_grants=False).are( + Grant.new("select", to="baz").on_tables() + ), +) + + +def test_invalid_combination(): + with pytest.raises(ValueError): + Grants.extract([metadata1, metadata3]) + + with pytest.raises(ValueError): + Grants.extract([metadata1, metadata4]) + + with pytest.raises(ValueError): + Grants.extract([metadata1, metadata5]) + + with pytest.raises(ValueError): + Grants.extract([metadata1, metadata6]) + + +def test_valid_combination(): + schemas = Grants.extract([metadata1, metadata2]) + assert schemas == Grants( + grants=[ + Grant.new("select", to="foo").on_tables(), + Grant.new("select", to="bar").on_tables(), + ] + ) + + +def test_single(): + grants = Grants.extract(metadata1) + assert grants is metadata1.info["grants"] diff --git a/tests/procedure/test_create.py b/tests/procedure/test_create.py index c8ae4e1..d17a2df 100644 --- a/tests/procedure/test_create.py +++ b/tests/procedure/test_create.py @@ -47,7 +47,5 @@ def test_create(pg): assert result == 2 connection = pg.connection() - diff = compare_procedures( - connection, Base.metadata.info["procedures"], Base.metadata - ) + diff = compare_procedures(connection, Base.metadata.info["procedures"]) assert diff == [] diff --git a/tests/procedure/test_create_mysql.py b/tests/procedure/test_create_mysql.py index 5ef980c..d1cd70a 100644 --- a/tests/procedure/test_create_mysql.py +++ b/tests/procedure/test_create_mysql.py @@ -70,7 +70,5 @@ def test_create(db): assert result == 6 connection = db.connection() - diff = compare_procedures( - connection, Base.metadata.info["procedures"], Base.metadata - ) + diff = compare_procedures(connection, Base.metadata.info["procedures"]) assert diff == [] diff --git a/tests/procedure/test_create_requires_normalization.py b/tests/procedure/test_create_requires_normalization.py index 59a3ebf..549a952 100644 --- a/tests/procedure/test_create_requires_normalization.py +++ b/tests/procedure/test_create_requires_normalization.py @@ -50,7 +50,5 @@ def test_create_with_complex_procedure_requiring_normalization(pg): assert result == 1 connection = pg.connection() - diff = compare_procedures( - connection, Base.metadata.info["procedures"], Base.metadata - ) + diff = compare_procedures(connection, Base.metadata.info["procedures"]) assert diff == [] diff --git a/tests/procedure/test_metadata_sequence.py b/tests/procedure/test_metadata_sequence.py new file mode 100644 index 0000000..cbfc1fe --- /dev/null +++ b/tests/procedure/test_metadata_sequence.py @@ -0,0 +1,37 @@ +import pytest +import sqlalchemy + +from sqlalchemy_declarative_extensions import ( + Procedure, + Procedures, + declare_database, +) + +metadata1 = sqlalchemy.MetaData() +metadata2 = sqlalchemy.MetaData() +metadata3 = sqlalchemy.MetaData() + +declare_database( + metadata1, procedures=Procedures(ignore=["one"]).are(Procedure("foo", "")) +) +declare_database( + metadata2, procedures=Procedures(ignore=["two"]).are(Procedure("bar", "")) +) +declare_database(metadata3, procedures=Procedures(ignore_unspecified=True)) + + +def test_invalid_combination(): + with pytest.raises(ValueError): + Procedures.extract([metadata1, metadata3]) + + +def test_valid_combination(): + procedures = Procedures.extract([metadata1, metadata2]) + assert procedures == Procedures( + procedures=[Procedure("foo", ""), Procedure("bar", "")], ignore=["one", "two"] + ) + + +def test_single(): + procedures = Procedures.extract(metadata1) + assert procedures is metadata1.info["procedures"] diff --git a/tests/procedure/test_pg_security_definer.py b/tests/procedure/test_pg_security_definer.py index 5975f55..d1c619e 100644 --- a/tests/procedure/test_pg_security_definer.py +++ b/tests/procedure/test_pg_security_definer.py @@ -101,7 +101,5 @@ def test_procedure_security(pg): assert result == 3 connection = pg.connection() - diff = compare_procedures( - connection, Base.metadata.info["procedures"], Base.metadata - ) + diff = compare_procedures(connection, Base.metadata.info["procedures"]) assert diff == [] diff --git a/tests/role/test_metadata_sequence.py b/tests/role/test_metadata_sequence.py new file mode 100644 index 0000000..aad6af1 --- /dev/null +++ b/tests/role/test_metadata_sequence.py @@ -0,0 +1,33 @@ +import pytest +import sqlalchemy + +from sqlalchemy_declarative_extensions import ( + Role, + Roles, + declare_database, +) + +metadata1 = sqlalchemy.MetaData() +metadata2 = sqlalchemy.MetaData() +metadata3 = sqlalchemy.MetaData() + +declare_database(metadata1, roles=Roles(ignore_roles=["one"]).are("foo")) +declare_database(metadata2, roles=Roles(ignore_roles=["two"]).are("bar")) +declare_database(metadata3, roles=Roles(ignore_unspecified=True).are("baz")) + + +def test_invalid_combination(): + with pytest.raises(ValueError): + Roles.extract([metadata1, metadata3]) + + +def test_valid_combination(): + schemas = Roles.extract([metadata1, metadata2]) + assert schemas == Roles( + roles=(Role("foo"), Role("bar")), ignore_roles=["one", "two"] + ) + + +def test_single(): + roles = Roles.extract(metadata1) + assert roles is metadata1.info["roles"] diff --git a/tests/schema/test_metadata_sequence.py b/tests/schema/test_metadata_sequence.py new file mode 100644 index 0000000..ace29a4 --- /dev/null +++ b/tests/schema/test_metadata_sequence.py @@ -0,0 +1,31 @@ +import pytest +import sqlalchemy + +from sqlalchemy_declarative_extensions import ( + Schema, + Schemas, + declare_database, +) + +metadata1 = sqlalchemy.MetaData() +metadata2 = sqlalchemy.MetaData() +metadata3 = sqlalchemy.MetaData() + +declare_database(metadata1, schemas=Schemas().are("foo")) +declare_database(metadata2, schemas=Schemas().are("bar")) +declare_database(metadata3, schemas=Schemas(ignore_unspecified=True).are("baz")) + + +def test_invalid_combination(): + with pytest.raises(ValueError): + Schemas.extract([metadata1, metadata3]) + + +def test_valid_combination(): + schemas = Schemas.extract([metadata1, metadata2]) + assert schemas == Schemas(schemas=(Schema("foo"), Schema("bar"))) + + +def test_single(): + schemas = Schemas.extract(metadata1) + assert schemas is metadata1.info["schemas"] diff --git a/tests/trigger/test_arguments.py b/tests/trigger/test_arguments.py index b1cc32e..95eb370 100644 --- a/tests/trigger/test_arguments.py +++ b/tests/trigger/test_arguments.py @@ -73,5 +73,5 @@ def test_create(pg): assert result == [(5, 3), (7, 3)] connection = pg.connection() - diff = compare_triggers(connection, Base.metadata.info["triggers"], Base.metadata) + diff = compare_triggers(connection, Base.metadata.info["triggers"]) assert diff == [] diff --git a/tests/trigger/test_create_mysql.py b/tests/trigger/test_create_mysql.py index e8309eb..c1e36fa 100644 --- a/tests/trigger/test_create_mysql.py +++ b/tests/trigger/test_create_mysql.py @@ -48,5 +48,5 @@ def test_create(pg): assert result == [10, 12, 14] connection = pg.connection() - diff = compare_triggers(connection, Base.metadata.info["triggers"], Base.metadata) + diff = compare_triggers(connection, Base.metadata.info["triggers"]) assert diff == [] diff --git a/tests/trigger/test_create_postgres.py b/tests/trigger/test_create_postgres.py index 9d42ec9..9dab1fe 100644 --- a/tests/trigger/test_create_postgres.py +++ b/tests/trigger/test_create_postgres.py @@ -62,5 +62,5 @@ def test_create(pg): assert result == [5, 6] connection = pg.connection() - diff = compare_triggers(connection, Base.metadata.info["triggers"], Base.metadata) + diff = compare_triggers(connection, Base.metadata.info["triggers"]) assert diff == [] diff --git a/tests/trigger/test_drop_mysql.py b/tests/trigger/test_drop_mysql.py index 541b35a..de295e2 100644 --- a/tests/trigger/test_drop_mysql.py +++ b/tests/trigger/test_drop_mysql.py @@ -58,5 +58,5 @@ def test_drop(pg): assert result == [5, 6] connection = pg.connection() - diff = compare_triggers(connection, Base.metadata.info["triggers"], Base.metadata) + diff = compare_triggers(connection, Base.metadata.info["triggers"]) assert diff == [] diff --git a/tests/trigger/test_drop_postgres.py b/tests/trigger/test_drop_postgres.py index 729e7ed..2f3e9fa 100644 --- a/tests/trigger/test_drop_postgres.py +++ b/tests/trigger/test_drop_postgres.py @@ -61,5 +61,5 @@ def test_drop(pg): assert result == [5] connection = pg.connection() - diff = compare_triggers(connection, Base.metadata.info["triggers"], Base.metadata) + diff = compare_triggers(connection, Base.metadata.info["triggers"]) assert diff == [] diff --git a/tests/trigger/test_ignore_unspecified_postgres.py b/tests/trigger/test_ignore_unspecified_postgres.py index ff77dfa..f3d2697 100644 --- a/tests/trigger/test_ignore_unspecified_postgres.py +++ b/tests/trigger/test_ignore_unspecified_postgres.py @@ -62,5 +62,5 @@ def test_drop(pg): assert result == [5, 6] connection = pg.connection() - diff = compare_triggers(connection, Base.metadata.info["triggers"], Base.metadata) + diff = compare_triggers(connection, Base.metadata.info["triggers"]) assert diff == [] diff --git a/tests/trigger/test_metadata_sequence.py b/tests/trigger/test_metadata_sequence.py new file mode 100644 index 0000000..5809a1e --- /dev/null +++ b/tests/trigger/test_metadata_sequence.py @@ -0,0 +1,33 @@ +import pytest +import sqlalchemy + +from sqlalchemy_declarative_extensions import ( + Trigger, + Triggers, + declare_database, +) + +metadata1 = sqlalchemy.MetaData() +metadata2 = sqlalchemy.MetaData() +metadata3 = sqlalchemy.MetaData() + +declare_database(metadata1, triggers=Triggers().are(Trigger("foo", "", ""))) +declare_database(metadata2, triggers=Triggers().are(Trigger("bar", "", ""))) +declare_database(metadata3, triggers=Triggers(ignore_unspecified=True)) + + +def test_invalid_combination(): + with pytest.raises(ValueError): + Triggers.extract([metadata1, metadata3]) + + +def test_valid_combination(): + procedures = Triggers.extract([metadata1, metadata2]) + assert procedures == Triggers( + triggers=[Trigger("foo", "", ""), Trigger("bar", "", "")] + ) + + +def test_single(): + triggers = Triggers.extract(metadata1) + assert triggers is metadata1.info["triggers"] diff --git a/tests/trigger/test_missing_when_postgres.py b/tests/trigger/test_missing_when_postgres.py index 71967c6..ee9c328 100644 --- a/tests/trigger/test_missing_when_postgres.py +++ b/tests/trigger/test_missing_when_postgres.py @@ -61,5 +61,5 @@ def test_create(pg): assert result == [6] connection = pg.connection() - diff = compare_triggers(connection, Base.metadata.info["triggers"], Base.metadata) + diff = compare_triggers(connection, Base.metadata.info["triggers"]) assert diff == [] diff --git a/tests/trigger/test_update_mysql.py b/tests/trigger/test_update_mysql.py index 876564a..0856e6d 100644 --- a/tests/trigger/test_update_mysql.py +++ b/tests/trigger/test_update_mysql.py @@ -62,5 +62,5 @@ def test_update(pg): assert result == [6, 15] connection = pg.connection() - diff = compare_triggers(connection, Base.metadata.info["triggers"], Base.metadata) + diff = compare_triggers(connection, Base.metadata.info["triggers"]) assert diff == [] diff --git a/tests/trigger/test_update_postgres.py b/tests/trigger/test_update_postgres.py index d7cdd92..b1bc294 100644 --- a/tests/trigger/test_update_postgres.py +++ b/tests/trigger/test_update_postgres.py @@ -80,5 +80,5 @@ def test_update(pg): assert result == [5, 6, 10, 11] connection = pg.connection() - diff = compare_triggers(connection, Base.metadata.info["triggers"], Base.metadata) + diff = compare_triggers(connection, Base.metadata.info["triggers"]) assert diff == [] diff --git a/tests/trigger/test_when_references_new_postgres.py b/tests/trigger/test_when_references_new_postgres.py index 41f3b5a..14fe26b 100644 --- a/tests/trigger/test_when_references_new_postgres.py +++ b/tests/trigger/test_when_references_new_postgres.py @@ -69,5 +69,5 @@ def test_create(pg): assert result == [6] connection = pg.connection() - diff = compare_triggers(connection, Base.metadata.info["triggers"], Base.metadata) + diff = compare_triggers(connection, Base.metadata.info["triggers"]) assert diff == [] diff --git a/tests/trigger/test_wrapped_parens_postgres.py b/tests/trigger/test_wrapped_parens_postgres.py index 218bf90..47d191f 100644 --- a/tests/trigger/test_wrapped_parens_postgres.py +++ b/tests/trigger/test_wrapped_parens_postgres.py @@ -62,5 +62,5 @@ def test_create(pg): assert result == [6] connection = pg.connection() - diff = compare_triggers(connection, Base.metadata.info["triggers"], Base.metadata) + diff = compare_triggers(connection, Base.metadata.info["triggers"]) assert diff == []