From d676717d6a66a529bf849c2b0e11b9ac7da79a64 Mon Sep 17 00:00:00 2001 From: DanCardin Date: Tue, 13 Feb 2024 17:15:14 -0500 Subject: [PATCH] fix: Handle new columns during Row comparison evaluation. --- .../row/compare.py | 112 ++++++++++++------ tests/row/test_column_added.py | 60 ++++++++++ tests/view/test_add_constraint_to_existing.py | 3 - 3 files changed, 138 insertions(+), 37 deletions(-) create mode 100644 tests/row/test_column_added.py diff --git a/src/sqlalchemy_declarative_extensions/row/compare.py b/src/sqlalchemy_declarative_extensions/row/compare.py index 431a67f..32af712 100644 --- a/src/sqlalchemy_declarative_extensions/row/compare.py +++ b/src/sqlalchemy_declarative_extensions/row/compare.py @@ -3,12 +3,12 @@ from dataclasses import dataclass from typing import Any, Union -from sqlalchemy import MetaData, Table, and_, not_, or_ from sqlalchemy.engine.base import Connection +from sqlalchemy.sql.expression import and_, not_, or_, select +from sqlalchemy.sql.schema import MetaData, Table from sqlalchemy_declarative_extensions.dialects import ( check_table_exists, - get_current_schema, ) from sqlalchemy_declarative_extensions.row.base import Row, Rows from sqlalchemy_declarative_extensions.sql import split_schema @@ -27,6 +27,7 @@ def insert_table_row(cls, operations, table, values): def execute(self, conn: Connection): table = get_table(conn, self.table) + query = table.insert().values(self.values) conn.execute(query) @@ -123,21 +124,20 @@ def get_table(conn: Connection, tablename: str): def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list[RowOp]: - result: list[RowOp] = [] + assert metadata.tables - current_schema = get_current_schema(connection) + result: list[RowOp] = [] - existing_tables = resolve_existing_tables(connection, rows, current_schema) + existing_metadata, existing_tables = resolve_existing_tables(connection, rows) + assert existing_metadata.tables # Collects table-specific primary keys so that we can efficiently compare rows # further down by the pk - pk_to_row: dict[Table, dict[tuple[Any, ...], Row]] = {} + pk_to_row: dict[str, dict[tuple[Any, ...], Row]] = {} # Collects the ongoing filter required to select all referenced records (by their pk) filters_by_table: dict[Table, list] = {} for row in rows: - row = row.qualify(current_schema) - table = metadata.tables.get(row.qualified_name) if table is None: raise ValueError(f"Unknown table: {row.qualified_name}") @@ -150,13 +150,24 @@ def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list ) pk = tuple([row.column_values[c.name] for c in table.primary_key.columns]) - pk_to_row.setdefault(table, {})[pk] = row - filters_by_table.setdefault(table, []).append( - and_(*[c == row.column_values[c.name] for c in table.primary_key.columns]) - ) + pk_to_row.setdefault(table.fullname, {})[pk] = row + + existing_table = existing_metadata.tables.get(row.qualified_name) + if existing_table is None: + continue + + if existing_table.primary_key.columns: + filters_by_table.setdefault(existing_table, []).append( + and_( + *[ + c == row.column_values[c.name] + for c in existing_table.primary_key.columns + ] + ) + ) existing_rows_by_table = collect_existing_record_data( - connection, filters_by_table, existing_tables + connection, existing_metadata, filters_by_table, existing_tables ) table_row_inserts: dict[Table, list[dict[str, Any]]] = {} @@ -164,10 +175,12 @@ def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list Table, tuple[list[dict[str, Any]], list[dict[str, Any]]] ] = {} for table, pks in pk_to_row.items(): - row_inserts = table_row_inserts.setdefault(table, []) - row_updates = table_row_updates.setdefault(table, ([], [])) + current_table: Table | None = existing_metadata.tables.get(table) + dest_table = metadata.tables[table] + + row_inserts = table_row_inserts.setdefault(dest_table, []) - existing_rows = existing_rows_by_table[table] + existing_rows = existing_rows_by_table.get(table, {}) stub_keys = { key: None for row in pks.values() for key in row.column_values.keys() @@ -175,16 +188,26 @@ def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list for pk, row in pks.items(): if pk in existing_rows: + assert current_table is not None + existing_row = existing_rows[pk] row_keys = row.column_values.keys() record_dict = {k: v for k, v in existing_row.items() if k in row_keys} - if row.column_values == record_dict: + + column_values = filter_column_data(current_table, row.column_values) + if column_values == record_dict: continue + row_updates = table_row_updates.setdefault(current_table, ([], [])) row_updates[0].append(record_dict) - row_updates[1].append(row.column_values) + row_updates[1].append(column_values) else: - row_inserts.append({**stub_keys, **row.column_values}) + insert_values = {**stub_keys, **row.column_values} + + insert_table: Table = ( + current_table if current_table is not None else dest_table + ) + row_inserts.append(filter_column_data(insert_table, insert_values)) # Deletes should get inserted first, so as to avoid foreign key constraint errors. if not rows.ignore_unspecified: @@ -197,11 +220,11 @@ def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list if not table_exists: continue - select = table.select() + statement = select(*table.primary_key) if filter: - select = select.where(not_(or_(*filter))) + statement = statement.where(not_(or_(*filter))) - to_delete = connection.execute(select).fetchall() + to_delete = connection.execute(statement).fetchall() if not to_delete: continue @@ -235,22 +258,30 @@ def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list def resolve_existing_tables( - connection: Connection, rows: Rows, current_schema: str | None = None -) -> dict[str, bool]: + connection: Connection, rows: Rows +) -> tuple[MetaData, dict[str, bool]]: """Collect a map of referenced tables, to whether or not they exist.""" + existing_metadata = MetaData() + assert existing_metadata.tables + result = {} for row in rows: - row = row.qualify(current_schema) if row.qualified_name in result: continue # If the table doesn't exist yet, we can likely assume it's being autogenerated # in the current revision and as such, will just emit insert statements. - result[row.qualified_name] = check_table_exists( + table_exists = check_table_exists( connection, row.tablename, schema=row.schema, ) + result[row.qualified_name] = table_exists + + if table_exists and row.qualified_name not in existing_metadata.tables: + existing_metadata.reflect( + bind=connection, schema=row.schema, only=[row.tablename] + ) for fq_tablename in rows.included_tables: schema, tablename = split_schema(fq_tablename) @@ -260,25 +291,38 @@ def resolve_existing_tables( schema=schema, ) - return result + return existing_metadata, result def collect_existing_record_data( - connection: Connection, filters_by_table, existing_tables: dict[str, bool] -) -> dict[Table, dict[tuple[Any, ...], dict[str, Any]]]: - result = {} + connection: Connection, + metadata: MetaData, + filters_by_table, + existing_tables: dict[str, bool], +) -> dict[str, dict[tuple[Any, ...], dict[str, Any]]]: + assert metadata.tables + + result: dict[str, dict[tuple[Any, ...], dict[str, Any]]] = {} for table, filters in filters_by_table.items(): table_exists = existing_tables[table.fullname] if not table_exists: - result[table] = {} + result[table.fullname] = {} continue primary_key_columns = [c.name for c in table.primary_key.columns] - records = connection.execute(table.select().where(or_(*filters))).fetchall() - existing_rows = result.setdefault(table, {}) - for record in records: + current_table = metadata.tables[table.fullname] + records = connection.execute(select(current_table).where(or_(*filters))) + assert records + + existing_rows = result.setdefault(table.fullname, {}) + for record in records.fetchall(): record_dict = row_to_dict(record) pk = tuple([record_dict[c] for c in primary_key_columns]) existing_rows[pk] = record_dict + return result + + +def filter_column_data(table: Table, row: dict): + return {c: v for c, v in row.items() if c in table.columns} diff --git a/tests/row/test_column_added.py b/tests/row/test_column_added.py new file mode 100644 index 0000000..f0b1a79 --- /dev/null +++ b/tests/row/test_column_added.py @@ -0,0 +1,60 @@ +from pytest_mock_resources import create_postgres_fixture +from sqlalchemy import Column, text, types + +from sqlalchemy_declarative_extensions import ( + Row, + Rows, + declarative_database, + register_sqlalchemy_events, +) +from sqlalchemy_declarative_extensions.row.compare import ( + compare_rows, +) +from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base + +_Base = declarative_base() + + +@declarative_database +class Base(_Base): # type: ignore + __abstract__ = True + + rows = Rows(ignore_unspecified=True).are( + Row("foo", id=1, name="qwer"), + ) + + +class Foo(Base): + __tablename__ = "foo" + + id = Column(types.Integer(), primary_key=True) + name = Column(types.Unicode(), nullable=True) + + +register_sqlalchemy_events(Base.metadata, rows=True) + +pg = create_postgres_fixture( + scope="function", engine_kwargs={"echo": True}, session=True +) + + +def test_insert_missing(pg): + pg.execute(text("CREATE TABLE foo (id SERIAL PRIMARY KEY)")) + pg.commit() + + Base.metadata.create_all(bind=pg.connection()) + pg.commit() + + result = pg.execute(text("SELECT * FROM foo")).fetchall() + assert result == [(1,)] + + pg.execute(text("ALTER TABLE foo ADD name VARCHAR")) + pg.commit() + + result = compare_rows(pg.connection(), Base.metadata, Base.rows) + for op in result: + op.execute(pg.connection()) + pg.commit() + + result = pg.execute(text("SELECT * FROM foo")).fetchall() + assert result == [(1, "qwer")] diff --git a/tests/view/test_add_constraint_to_existing.py b/tests/view/test_add_constraint_to_existing.py index 7aa0b6b..5f53874 100644 --- a/tests/view/test_add_constraint_to_existing.py +++ b/tests/view/test_add_constraint_to_existing.py @@ -11,9 +11,6 @@ ) from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base -() - - _Base = declarative_base()