Skip to content

Commit

Permalink
feat: Emit row migration ops as raw sql statements.
Browse files Browse the repository at this point in the history
  • Loading branch information
DanCardin committed Feb 14, 2024
1 parent 46e9aa3 commit 3bef997
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 66 deletions.
27 changes: 15 additions & 12 deletions src/sqlalchemy_declarative_extensions/alembic/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,23 @@ def compare_rows(autogen_context: AutogenContext, upgrade_ops: UpgradeOps, _):


@renderers.dispatch_for(InsertRowOp)
def render_insert_table_row(_, op: InsertRowOp):
return f"op.insert_table_row('{op.table}', {op.values})"


@renderers.dispatch_for(UpdateRowOp)
def render_update_table_row(_, op: UpdateRowOp):
return "op.update_table_row('{}', from_values={}, to_values={})".format(
op.table, op.from_values, op.to_values
)


@renderers.dispatch_for(DeleteRowOp)
def render_delete_table_row(_, op: DeleteRowOp):
return f"op.delete_table_row('{op.table}', {op.values})"
def render_insert_table_row(
autogen_context: AutogenContext, op: InsertRowOp | UpdateRowOp | DeleteRowOp
):
metadata = autogen_context.metadata
conn = autogen_context.connection
assert metadata
assert conn

result = []
for query in op.render(metadata):
query_str = query.compile(
dialect=conn.dialect, compile_kwargs={"literal_binds": True}
)
result.append(f'op.execute("{query_str}")')
return result


@Operations.implementation_for(InsertRowOp)
Expand Down
107 changes: 60 additions & 47 deletions src/sqlalchemy_declarative_extensions/row/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Union

from sqlalchemy.engine.base import Connection
from sqlalchemy.sql.expression import and_, not_, or_
from sqlalchemy.sql.expression import and_, not_, or_, text
from sqlalchemy.sql.schema import MetaData, Table

from sqlalchemy_declarative_extensions.dialects import (
Expand All @@ -25,11 +25,16 @@ def insert_table_row(cls, operations, table, values):
op = cls(table, values)
return operations.invoke(op)

def render(self, metadata: MetaData):
assert metadata.tables is not None
table = metadata.tables[self.table]
return [table.insert().values(self.values)]

def execute(self, conn: Connection):
table = get_table(conn, self.table)
metadata = get_metadata(conn, self.table)

query = table.insert().values(self.values)
conn.execute(query)
for query in self.render(metadata):
conn.execute(query)

def reverse(self):
return DeleteRowOp(self.table, self.values)
Expand All @@ -46,8 +51,9 @@ def update_table_row(cls, operations, table, from_values, to_values):
op = cls(table, from_values, to_values)
return operations.invoke(op)

def execute(self, conn: Connection):
table = get_table(conn, self.table)
def render(self, metadata: MetaData):
assert metadata.tables is not None
table = metadata.tables[self.table]

primary_key_columns = [c.name for c in table.primary_key.columns]

Expand All @@ -56,12 +62,20 @@ def execute(self, conn: Connection):
else:
to_values = self.to_values

result = []
for to_value in to_values:
where = [
table.c[c] == v for c, v in to_value.items() if c in primary_key_columns
]
values = {c: v for c, v in to_value.items() if c not in primary_key_columns}
query = table.update().where(*where).values(**values)
result.append(query)
return result

def execute(self, conn: Connection):
metadata = get_metadata(conn, self.table)

for query in self.render(metadata):
conn.execute(query)

def reverse(self):
Expand All @@ -78,8 +92,9 @@ def delete_table_row(cls, operations, table, values):
op = cls(table, values)
return operations.invoke(op)

def execute(self, conn: Connection):
table = get_table(conn, self.table)
def render(self, metadata: MetaData):
assert metadata.tables is not None
table = metadata.tables[self.table]

if isinstance(self.values, dict):
rows_values: list[dict[str, Any]] = [self.values]
Expand All @@ -100,14 +115,19 @@ def execute(self, conn: Connection):
for row_values in rows_values
]
)
query = table.delete().where(where)
conn.execute(query)
return [table.delete().where(where)]

def execute(self, conn: Connection):
metadata = get_metadata(conn, self.table)

for query in self.render(metadata):
conn.execute(query)

def reverse(self):
return InsertRowOp(self.table, self.values)


def get_table(conn: Connection, tablename: str):
def get_metadata(conn: Connection, tablename: str):
m = MetaData()

try:
Expand All @@ -117,16 +137,18 @@ def get_table(conn: Connection, tablename: str):
schema = None

m.reflect(conn, schema=schema, only=[table])
return m.tables[tablename]
return m


RowOp = Union[InsertRowOp, UpdateRowOp, DeleteRowOp]


def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list[RowOp]:
assert metadata.tables is not None

result: list[RowOp] = []

existing_metadata, existing_tables = resolve_existing_tables(connection, rows)
existing_tables = resolve_existing_tables(connection, rows)

# Collects table-specific primary keys so that we can efficiently compare rows
# further down by the pk
Expand All @@ -146,33 +168,32 @@ def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list
f"Row is missing primary key values required to declaratively specify: {row}"
)

table = metadata.tables.get(row.qualified_name)
if table is None:
continue

pk = tuple([row.column_values[c.name] for c in table.primary_key.columns])
pk_to_row.setdefault(table.fullname, {})[pk] = row

existing_table = existing_metadata.tables.get(row.qualified_name)
if existing_table is None:
continue

if existing_table.primary_key.columns:
filters_by_table.setdefault(existing_table, []).append(
if table.primary_key.columns:
filters_by_table.setdefault(table, []).append(
and_(
*[
c == row.column_values[c.name]
for c in existing_table.primary_key.columns
]
*[c == row.column_values[c.name] for c in table.primary_key.columns]
)
)

existing_rows_by_table = collect_existing_record_data(
connection, existing_metadata, filters_by_table, existing_tables
connection, filters_by_table, existing_tables
)

existing_metadata = MetaData()
assert existing_metadata.tables is not None

table_row_inserts: dict[Table, list[dict[str, Any]]] = {}
table_row_updates: dict[
Table, tuple[list[dict[str, Any]], list[dict[str, Any]]]
] = {}
for table, pks in pk_to_row.items():
current_table: Table | None = existing_metadata.tables.get(table)
dest_table = metadata.tables[table]

row_inserts = table_row_inserts.setdefault(dest_table, [])
Expand All @@ -185,7 +206,11 @@ def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list

for pk, row in pks.items():
if pk in existing_rows:
assert current_table is not None
if row.qualified_name not in existing_metadata.tables:
existing_metadata.reflect(
bind=connection, schema=row.schema, only=[row.tablename]
)
current_table: Table = existing_metadata.tables[table]

existing_row = existing_rows[pk]
row_keys = row.column_values.keys()
Expand All @@ -200,11 +225,7 @@ def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list
row_updates[1].append(column_values)
else:
insert_values = {**stub_keys, **row.column_values}

insert_table: Table = (
current_table if current_table is not None else dest_table
)
row_inserts.append(filter_column_data(insert_table, insert_values))
row_inserts.append(filter_column_data(dest_table, insert_values))

# Deletes should get inserted first, so as to avoid foreign key constraint errors.
if not rows.ignore_unspecified:
Expand Down Expand Up @@ -254,31 +275,20 @@ def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list
return result


def resolve_existing_tables(
connection: Connection, rows: Rows
) -> tuple[MetaData, dict[str, bool]]:
def resolve_existing_tables(connection: Connection, rows: Rows) -> dict[str, bool]:
"""Collect a map of referenced tables, to whether or not they exist."""
existing_metadata = MetaData()

result = {}
for row in rows:
if row.qualified_name in result:
continue

# If the table doesn't exist yet, we can likely assume it's being autogenerated
# in the current revision and as such, will just emit insert statements.
table_exists = check_table_exists(
connection,
row.tablename,
schema=row.schema,
)
result[row.qualified_name] = table_exists

if table_exists and row.qualified_name not in existing_metadata.tables:
existing_metadata.reflect(
bind=connection, schema=row.schema, only=[row.tablename]
)

for fq_tablename in rows.included_tables:
schema, tablename = split_schema(fq_tablename)
result[fq_tablename] = check_table_exists(
Expand All @@ -287,12 +297,11 @@ def resolve_existing_tables(
schema=schema,
)

return existing_metadata, result
return result


def collect_existing_record_data(
connection: Connection,
metadata: MetaData,
filters_by_table,
existing_tables: dict[str, bool],
) -> dict[str, dict[tuple[Any, ...], dict[str, Any]]]:
Expand All @@ -305,8 +314,12 @@ def collect_existing_record_data(

primary_key_columns = [c.name for c in table.primary_key.columns]

current_table = metadata.tables[table.fullname]
records = connection.execute(select(current_table).where(or_(*filters)))
filter_str = or_(*filters).compile(
dialect=connection.dialect, compile_kwargs={"literal_binds": True}
)
records = connection.execute(
text(f"SELECT * FROM {table.fullname} WHERE {filter_str}") # noqa: S608
)
assert records

existing_rows = result.setdefault(table.fullname, {})
Expand Down
12 changes: 8 additions & 4 deletions tests/row/test_column_added.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Base(_Base): # type: ignore
__abstract__ = True

rows = Rows(ignore_unspecified=True).are(
Row("foo", id=1, name="qwer"),
Row("foo", id=1),
)


Expand All @@ -38,7 +38,7 @@ class Foo(Base):
)


def test_insert_missing(pg):
def test_column_added(pg):
pg.execute(text("CREATE TABLE foo (id SERIAL PRIMARY KEY)"))
pg.commit()

Expand All @@ -51,9 +51,13 @@ def test_insert_missing(pg):
pg.execute(text("ALTER TABLE foo ADD name VARCHAR"))
pg.commit()

result = compare_rows(pg.connection(), Base.metadata, Base.rows)
rows = Rows(ignore_unspecified=True).are(
Row("foo", id=1, name="qwer"),
)
result = compare_rows(pg.connection(), Base.metadata, rows)
for op in result:
op.execute(pg.connection())
for query in op.render(Base.metadata):
pg.execute(query)
pg.commit()

result = pg.execute(text("SELECT * FROM foo")).fetchall()
Expand Down
3 changes: 0 additions & 3 deletions tests/row/test_missing_primary_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
)
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base

()


_Base = declarative_base()


Expand Down
52 changes: 52 additions & 0 deletions tests/row/test_schema_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from pytest_mock_resources import create_postgres_fixture
from sqlalchemy import Column, text, types

from sqlalchemy_declarative_extensions import (
Row,
Rows,
declarative_database,
register_sqlalchemy_events,
)
from sqlalchemy_declarative_extensions.row.compare import compare_rows
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base

_Base = declarative_base()


@declarative_database
class Base(_Base): # type: ignore
__abstract__ = True


class Foo(Base):
__tablename__ = "foo"

id = Column(types.Integer(), primary_key=True)
name = Column(types.Unicode(), nullable=True)


register_sqlalchemy_events(Base.metadata, rows=True)

pg = create_postgres_fixture(
scope="function", engine_kwargs={"echo": True}, session=True
)


def test_insert_missing(pg):
pg.execute(text("CREATE SCHEMA foo"))
pg.execute(text("SET SEARCH_PATH=foo"))

Base.metadata.create_all(bind=pg.connection())
pg.commit()

rows = Rows(ignore_unspecified=True).are(
Row("foo", id=1, name="qwer"),
)
result = compare_rows(pg.connection(), Base.metadata, rows)
for op in result:
for query in op.render(Base.metadata):
pg.execute(query)
pg.commit()

result = pg.execute(text("SELECT * FROM foo.foo")).fetchall()
assert result == [(1, "qwer")]

0 comments on commit 3bef997

Please sign in to comment.