Skip to content

Commit

Permalink
fix: Handle new columns during Row comparison evaluation.
Browse files Browse the repository at this point in the history
  • Loading branch information
DanCardin committed Feb 13, 2024
1 parent f9b4ea8 commit d676717
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 37 deletions.
112 changes: 78 additions & 34 deletions src/sqlalchemy_declarative_extensions/row/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from dataclasses import dataclass
from typing import Any, Union

from sqlalchemy import MetaData, Table, and_, not_, or_
from sqlalchemy.engine.base import Connection
from sqlalchemy.sql.expression import and_, not_, or_, select
from sqlalchemy.sql.schema import MetaData, Table

from sqlalchemy_declarative_extensions.dialects import (
check_table_exists,
get_current_schema,
)
from sqlalchemy_declarative_extensions.row.base import Row, Rows
from sqlalchemy_declarative_extensions.sql import split_schema
Expand All @@ -27,6 +27,7 @@ def insert_table_row(cls, operations, table, values):

def execute(self, conn: Connection):
table = get_table(conn, self.table)

query = table.insert().values(self.values)
conn.execute(query)

Expand Down Expand Up @@ -123,21 +124,20 @@ def get_table(conn: Connection, tablename: str):


def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list[RowOp]:
result: list[RowOp] = []
assert metadata.tables

current_schema = get_current_schema(connection)
result: list[RowOp] = []

existing_tables = resolve_existing_tables(connection, rows, current_schema)
existing_metadata, existing_tables = resolve_existing_tables(connection, rows)
assert existing_metadata.tables

# Collects table-specific primary keys so that we can efficiently compare rows
# further down by the pk
pk_to_row: dict[Table, dict[tuple[Any, ...], Row]] = {}
pk_to_row: dict[str, dict[tuple[Any, ...], Row]] = {}

# Collects the ongoing filter required to select all referenced records (by their pk)
filters_by_table: dict[Table, list] = {}
for row in rows:
row = row.qualify(current_schema)

table = metadata.tables.get(row.qualified_name)
if table is None:
raise ValueError(f"Unknown table: {row.qualified_name}")
Expand All @@ -150,41 +150,64 @@ def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list
)

pk = tuple([row.column_values[c.name] for c in table.primary_key.columns])
pk_to_row.setdefault(table, {})[pk] = row
filters_by_table.setdefault(table, []).append(
and_(*[c == row.column_values[c.name] for c in table.primary_key.columns])
)
pk_to_row.setdefault(table.fullname, {})[pk] = row

existing_table = existing_metadata.tables.get(row.qualified_name)
if existing_table is None:
continue

if existing_table.primary_key.columns:
filters_by_table.setdefault(existing_table, []).append(
and_(
*[
c == row.column_values[c.name]
for c in existing_table.primary_key.columns
]
)
)

existing_rows_by_table = collect_existing_record_data(
connection, filters_by_table, existing_tables
connection, existing_metadata, filters_by_table, existing_tables
)

table_row_inserts: dict[Table, list[dict[str, Any]]] = {}
table_row_updates: dict[
Table, tuple[list[dict[str, Any]], list[dict[str, Any]]]
] = {}
for table, pks in pk_to_row.items():
row_inserts = table_row_inserts.setdefault(table, [])
row_updates = table_row_updates.setdefault(table, ([], []))
current_table: Table | None = existing_metadata.tables.get(table)
dest_table = metadata.tables[table]

row_inserts = table_row_inserts.setdefault(dest_table, [])

existing_rows = existing_rows_by_table[table]
existing_rows = existing_rows_by_table.get(table, {})

stub_keys = {
key: None for row in pks.values() for key in row.column_values.keys()
}

for pk, row in pks.items():
if pk in existing_rows:
assert current_table is not None

existing_row = existing_rows[pk]
row_keys = row.column_values.keys()
record_dict = {k: v for k, v in existing_row.items() if k in row_keys}
if row.column_values == record_dict:

column_values = filter_column_data(current_table, row.column_values)
if column_values == record_dict:
continue

row_updates = table_row_updates.setdefault(current_table, ([], []))
row_updates[0].append(record_dict)
row_updates[1].append(row.column_values)
row_updates[1].append(column_values)
else:
row_inserts.append({**stub_keys, **row.column_values})
insert_values = {**stub_keys, **row.column_values}

insert_table: Table = (
current_table if current_table is not None else dest_table
)
row_inserts.append(filter_column_data(insert_table, insert_values))

# Deletes should get inserted first, so as to avoid foreign key constraint errors.
if not rows.ignore_unspecified:
Expand All @@ -197,11 +220,11 @@ def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list
if not table_exists:
continue

select = table.select()
statement = select(*table.primary_key)
if filter:
select = select.where(not_(or_(*filter)))
statement = statement.where(not_(or_(*filter)))

to_delete = connection.execute(select).fetchall()
to_delete = connection.execute(statement).fetchall()

if not to_delete:
continue
Expand Down Expand Up @@ -235,22 +258,30 @@ def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list


def resolve_existing_tables(
connection: Connection, rows: Rows, current_schema: str | None = None
) -> dict[str, bool]:
connection: Connection, rows: Rows
) -> tuple[MetaData, dict[str, bool]]:
"""Collect a map of referenced tables, to whether or not they exist."""
existing_metadata = MetaData()
assert existing_metadata.tables

result = {}
for row in rows:
row = row.qualify(current_schema)
if row.qualified_name in result:
continue

# If the table doesn't exist yet, we can likely assume it's being autogenerated
# in the current revision and as such, will just emit insert statements.
result[row.qualified_name] = check_table_exists(
table_exists = check_table_exists(
connection,
row.tablename,
schema=row.schema,
)
result[row.qualified_name] = table_exists

if table_exists and row.qualified_name not in existing_metadata.tables:
existing_metadata.reflect(
bind=connection, schema=row.schema, only=[row.tablename]
)

for fq_tablename in rows.included_tables:
schema, tablename = split_schema(fq_tablename)
Expand All @@ -260,25 +291,38 @@ def resolve_existing_tables(
schema=schema,
)

return result
return existing_metadata, result


def collect_existing_record_data(
connection: Connection, filters_by_table, existing_tables: dict[str, bool]
) -> dict[Table, dict[tuple[Any, ...], dict[str, Any]]]:
result = {}
connection: Connection,
metadata: MetaData,
filters_by_table,
existing_tables: dict[str, bool],
) -> dict[str, dict[tuple[Any, ...], dict[str, Any]]]:
assert metadata.tables

result: dict[str, dict[tuple[Any, ...], dict[str, Any]]] = {}
for table, filters in filters_by_table.items():
table_exists = existing_tables[table.fullname]
if not table_exists:
result[table] = {}
result[table.fullname] = {}
continue

primary_key_columns = [c.name for c in table.primary_key.columns]

records = connection.execute(table.select().where(or_(*filters))).fetchall()
existing_rows = result.setdefault(table, {})
for record in records:
current_table = metadata.tables[table.fullname]
records = connection.execute(select(current_table).where(or_(*filters)))
assert records

existing_rows = result.setdefault(table.fullname, {})
for record in records.fetchall():
record_dict = row_to_dict(record)
pk = tuple([record_dict[c] for c in primary_key_columns])
existing_rows[pk] = record_dict

return result


def filter_column_data(table: Table, row: dict):
    """Return a copy of ``row`` restricted to keys that are columns on ``table``.

    Used so that declared row values referencing columns which do not (yet)
    exist on the live table are dropped before comparison/insertion.
    """
    filtered = {}
    for column_name, value in row.items():
        if column_name in table.columns:
            filtered[column_name] = value
    return filtered
60 changes: 60 additions & 0 deletions tests/row/test_column_added.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from pytest_mock_resources import create_postgres_fixture
from sqlalchemy import Column, text, types

from sqlalchemy_declarative_extensions import (
Row,
Rows,
declarative_database,
register_sqlalchemy_events,
)
from sqlalchemy_declarative_extensions.row.compare import (
compare_rows,
)
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base

# Plain declarative base; decorated below via @declarative_database to attach
# the declared Rows to the metadata.
_Base = declarative_base()


@declarative_database
class Base(_Base): # type: ignore
    __abstract__ = True

    # Declared row data: table "foo" must contain this row. With
    # ignore_unspecified=True, rows not listed here are left in place
    # (no deletes are generated by comparison).
    rows = Rows(ignore_unspecified=True).are(
        Row("foo", id=1, name="qwer"),
    )


class Foo(Base):
    """Model for table "foo": integer primary key plus a nullable name column."""

    __tablename__ = "foo"

    id = Column(types.Integer(), primary_key=True)
    name = Column(types.Unicode(), nullable=True)


# Hook row reconciliation into metadata events (so create_all also applies rows).
register_sqlalchemy_events(Base.metadata, rows=True)

# Fresh postgres per test; echo=True logs emitted SQL for easier debugging.
pg = create_postgres_fixture(
    scope="function", engine_kwargs={"echo": True}, session=True
)


def test_insert_missing(pg):
    """Row comparison must tolerate a column added after the table was created."""
    # Start from a pre-existing table that lacks the "name" column entirely.
    pg.execute(text("CREATE TABLE foo (id SERIAL PRIMARY KEY)"))
    pg.commit()

    # create_all triggers the registered rows hook; only "id" can be populated.
    Base.metadata.create_all(bind=pg.connection())
    pg.commit()

    before = pg.execute(text("SELECT * FROM foo")).fetchall()
    assert before == [(1,)]

    # The column now exists; a fresh comparison should backfill its value.
    pg.execute(text("ALTER TABLE foo ADD name VARCHAR"))
    pg.commit()

    ops = compare_rows(pg.connection(), Base.metadata, Base.rows)
    for op in ops:
        op.execute(pg.connection())
    pg.commit()

    after = pg.execute(text("SELECT * FROM foo")).fetchall()
    assert after == [(1, "qwer")]
3 changes: 0 additions & 3 deletions tests/view/test_add_constraint_to_existing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
)
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base

()


_Base = declarative_base()


Expand Down

0 comments on commit d676717

Please sign in to comment.