Skip to content

Commit

Permalink
feat: Detect and perform bulk queries for insert/delete ops.
Browse files Browse the repository at this point in the history
This is helpful for databases (like Snowflake) that have really slow
individual-query performance.
  • Loading branch information
DanCardin committed Feb 13, 2024
1 parent 176e42c commit f9b4ea8
Show file tree
Hide file tree
Showing 18 changed files with 386 additions and 96 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sqlalchemy-declarative-extensions"
version = "0.6.9"
version = "0.6.10"
authors = ["Dan Cardin <ddcardin@gmail.com>"]

description = "Library to declare additional kinds of objects not natively supported by SQLAlchemy/Alembic."
Expand Down
3 changes: 1 addition & 2 deletions src/sqlalchemy_declarative_extensions/alembic/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,4 @@ def execute_row(
operation: Union[InsertRowOp, UpdateRowOp, DeleteRowOp],
):
conn = operations.get_bind()
query = operation.render(conn)
conn.execute(query)
operation.execute(conn)
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
GrantOptions,
GrantTypes,
)
from sqlalchemy_declarative_extensions.sql import split_schema
from sqlalchemy_declarative_extensions.typing import Protocol, runtime_checkable


Expand Down Expand Up @@ -345,10 +346,7 @@ def _render_to_or_from(grant: Grant) -> str:


def _quote_table_name(name: str):
if "." in name:
schema, name = name.split(".")
else:
schema = None
schema, name = split_schema(name)

if schema:
return f'"{schema}"."{name}"'
Expand Down
8 changes: 3 additions & 5 deletions src/sqlalchemy_declarative_extensions/row/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass, field, replace
from typing import Any, Iterable

from sqlalchemy_declarative_extensions.sql import split_schema


@dataclass
class Rows:
Expand Down Expand Up @@ -34,11 +36,7 @@ class Row:
column_values: dict[str, Any]

def __init__(self, tablename, *, schema: str | None = None, **column_values):
schema = schema
try:
schema, table = tablename.split(".", 1)
except ValueError:
table = tablename
schema, table = split_schema(tablename, schema=schema)

self.schema = schema
self.tablename = table
Expand Down
248 changes: 171 additions & 77 deletions src/sqlalchemy_declarative_extensions/row/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,32 @@
from dataclasses import dataclass
from typing import Any, Union

from sqlalchemy import MetaData, Table, tuple_
from sqlalchemy import MetaData, Table, and_, not_, or_
from sqlalchemy.engine.base import Connection

from sqlalchemy_declarative_extensions.dialects import (
check_table_exists,
get_current_schema,
)
from sqlalchemy_declarative_extensions.row.base import Row, Rows
from sqlalchemy_declarative_extensions.sql import split_schema
from sqlalchemy_declarative_extensions.sqlalchemy import row_to_dict


@dataclass
class InsertRowOp:
    """Alembic operation that inserts one or more rows into a table.

    ``values`` may be a single column/value mapping or a list of such
    mappings; a list is rendered as one bulk INSERT statement.
    """

    table: str
    values: dict[str, Any] | list[dict[str, Any]]

    @classmethod
    def insert_table_row(cls, operations, table, values):
        # Entry point invoked through alembic's operations plugin machinery.
        return operations.invoke(cls(table, values))

    def execute(self, conn: Connection):
        # Resolve the (possibly schema-qualified) table name, then emit a
        # single INSERT covering every supplied row.
        target = get_table(conn, self.table)
        conn.execute(target.insert().values(self.values))

    def reverse(self):
        # The inverse of inserting rows is deleting those same rows.
        return DeleteRowOp(self.table, self.values)
Expand All @@ -35,27 +37,31 @@ def reverse(self):
@dataclass
class UpdateRowOp:
    """Alembic operation that updates existing rows, matched by primary key.

    ``from_values``/``to_values`` may each be a single column/value mapping
    or a list of mappings; one UPDATE statement is emitted per target row.
    """

    table: str
    from_values: dict[str, Any] | list[dict[str, Any]]
    to_values: dict[str, Any] | list[dict[str, Any]]

    @classmethod
    def update_table_row(cls, operations, table, from_values, to_values):
        # Entry point invoked through alembic's operations plugin machinery.
        return operations.invoke(cls(table, from_values, to_values))

    def execute(self, conn: Connection):
        target = get_table(conn, self.table)
        pk_names = [col.name for col in target.primary_key.columns]

        # Normalize the single-row form into the bulk (list) form.
        rows: list[dict[str, Any]] = (
            [self.to_values] if isinstance(self.to_values, dict) else self.to_values
        )

        for row in rows:
            # Primary-key columns select the row; the remaining columns
            # carry the new values.
            criteria = [
                target.c[name] == value
                for name, value in row.items()
                if name in pk_names
            ]
            updates = {
                name: value for name, value in row.items() if name not in pk_names
            }
            conn.execute(target.update().where(*criteria).values(**updates))

    def reverse(self):
        # Swapping from/to values yields the inverse migration.
        return UpdateRowOp(self.table, self.to_values, self.from_values)
Expand All @@ -64,21 +70,37 @@ def reverse(self):
@dataclass
class DeleteRowOp:
    """Alembic operation that deletes rows, matched by primary key.

    ``values`` may be a single column/value mapping or a list of such
    mappings; a list is rendered as one bulk DELETE whose WHERE clause ORs
    together the per-row primary-key matches.
    """

    table: str
    values: dict[str, Any] | list[dict[str, Any]]

    @classmethod
    def delete_table_row(cls, operations, table, values):
        # Entry point invoked through alembic's operations plugin machinery.
        op = cls(table, values)
        return operations.invoke(op)

    def execute(self, conn: Connection):
        table = get_table(conn, self.table)

        # Normalize the single-row form into the bulk (list) form.
        if isinstance(self.values, dict):
            rows_values: list[dict[str, Any]] = [self.values]
        else:
            rows_values = self.values

        # Guard: an empty row list would render ``or_()`` with no clauses,
        # which SQLAlchemy deprecates and which makes the DELETE's intent
        # ambiguous. There is nothing to delete, so do nothing.
        if not rows_values:
            return

        primary_key_columns = {c.name for c in table.primary_key.columns}

        where = or_(
            *[
                and_(
                    *[
                        table.c[c] == v
                        for c, v in row_values.items()
                        if c in primary_key_columns
                    ]
                )
                for row_values in rows_values
            ]
        )
        query = table.delete().where(where)
        conn.execute(query)

    def reverse(self):
        # The inverse of deleting rows is re-inserting those same rows.
        return InsertRowOp(self.table, self.values)
Expand All @@ -105,86 +127,158 @@ def compare_rows(connection: Connection, metadata: MetaData, rows: Rows) -> list

current_schema = get_current_schema(connection)

rows_by_table: dict[Table, list[Row]] = {}
existing_tables = resolve_existing_tables(connection, rows, current_schema)

# Collects table-specific primary keys so that we can efficiently compare rows
# further down by the pk
pk_to_row: dict[Table, dict[tuple[Any, ...], Row]] = {}

# Collects the ongoing filter required to select all referenced records (by their pk)
filters_by_table: dict[Table, list] = {}
for row in rows:
row = row.qualify(current_schema)

table = metadata.tables.get(row.qualified_name)
if table is None:
raise ValueError(f"Unknown table: {row.qualified_name}")

rows_by_table.setdefault(table, []).append(row)

primary_key_columns = [c.name for c in table.primary_key.columns]

if set(primary_key_columns) - row.column_values.keys():
raise ValueError(
f"Row is missing primary key values required to declaratively specify: {row}"
)

column_filters = [
c == row.column_values[c.name] for c in table.primary_key.columns
]

# If the table doesn't exist yet, we can likely assume it's being autogenerated
# in the current revision and as such, will just emit insert statements.
table_exists = check_table_exists(
connection,
table.name,
schema=table.schema or current_schema,
pk = tuple([row.column_values[c.name] for c in table.primary_key.columns])
pk_to_row.setdefault(table, {})[pk] = row
filters_by_table.setdefault(table, []).append(
and_(*[c == row.column_values[c.name] for c in table.primary_key.columns])
)

record = None
if table_exists:
record = connection.execute(
table.select().where(*column_filters).limit(1)
).first()

if record:
row_keys = row.column_values.keys()
record_dict = {
k: v for k, v in row_to_dict(record).items() if k in row_keys
}
if row.column_values == record_dict:
continue
existing_rows_by_table = collect_existing_record_data(
connection, filters_by_table, existing_tables
)

result.append(
UpdateRowOp(
row.qualified_name,
from_values=record_dict,
to_values=row.column_values,
)
)
else:
result.append(InsertRowOp(row.qualified_name, values=row.column_values))
table_row_inserts: dict[Table, list[dict[str, Any]]] = {}
table_row_updates: dict[
Table, tuple[list[dict[str, Any]], list[dict[str, Any]]]
] = {}
for table, pks in pk_to_row.items():
row_inserts = table_row_inserts.setdefault(table, [])
row_updates = table_row_updates.setdefault(table, ([], []))

existing_rows = existing_rows_by_table[table]

stub_keys = {
key: None for row in pks.values() for key in row.column_values.keys()
}

for pk, row in pks.items():
if pk in existing_rows:
existing_row = existing_rows[pk]
row_keys = row.column_values.keys()
record_dict = {k: v for k, v in existing_row.items() if k in row_keys}
if row.column_values == record_dict:
continue

row_updates[0].append(record_dict)
row_updates[1].append(row.column_values)
else:
row_inserts.append({**stub_keys, **row.column_values})

# Deletes should get inserted first, so as to avoid foreign key constraint errors.
if not rows.ignore_unspecified:
for table_name in rows.included_tables:
table = metadata.tables[table_name]
rows_by_table.setdefault(table, [])
filters_by_table.setdefault(table, [])

for table, row_list in rows_by_table.items():
table_exists = check_table_exists(
connection,
table.name,
schema=table.schema or current_schema,
)
for table, filter in filters_by_table.items():
table_exists = existing_tables[table.fullname]
if not table_exists:
continue

primary_key_columns = [c.name for c in table.primary_key.columns]
primary_key_values = [
tuple(row.column_values[c] for c in primary_key_columns)
for row in row_list
]
to_delete = connection.execute(
table.select().where(
tuple_(*table.primary_key.columns).notin_(primary_key_values)
select = table.select()
if filter:
select = select.where(not_(or_(*filter)))

to_delete = connection.execute(select).fetchall()

if not to_delete:
continue

result.append(
DeleteRowOp(
table.fullname, [row_to_dict(record) for record in to_delete]
)
).fetchall()
)

for table, row_updates in table_row_updates.items():
old_rows, new_rows = row_updates
if not new_rows:
continue

result.append(
UpdateRowOp(
table.fullname,
from_values=old_rows,
to_values=new_rows,
)
)

for table, row_inserts in table_row_inserts.items():
if not row_inserts:
continue

result.append(InsertRowOp(table.fullname, values=row_inserts))

return result


def resolve_existing_tables(
    connection: Connection, rows: Rows, current_schema: str | None = None
) -> dict[str, bool]:
    """Collect a map of referenced tables, to whether or not they exist.

    If a table doesn't exist yet, we can likely assume it's being
    autogenerated in the current revision and as such, will just emit
    insert statements.
    """
    exists: dict[str, bool] = {}

    for row in rows:
        qualified = row.qualify(current_schema)
        # Only probe the database once per distinct table name.
        if qualified.qualified_name not in exists:
            exists[qualified.qualified_name] = check_table_exists(
                connection,
                qualified.tablename,
                schema=qualified.schema,
            )

    # Explicitly included tables are consulted too (used further up for
    # unspecified-row deletion), so record their existence as well.
    for fq_tablename in rows.included_tables:
        schema, tablename = split_schema(fq_tablename)
        exists[fq_tablename] = check_table_exists(
            connection,
            tablename,
            schema=schema,
        )

    return exists


def collect_existing_record_data(
    connection: Connection, filters_by_table, existing_tables: dict[str, bool]
) -> dict[Table, dict[tuple[Any, ...], dict[str, Any]]]:
    """Fetch the current database rows selected by each table's filters.

    Returns, per table, a mapping from primary-key tuple to the row's full
    column/value dict. Tables recorded as non-existent in
    ``existing_tables`` yield an empty mapping.
    """
    by_table: dict[Table, dict[tuple[Any, ...], dict[str, Any]]] = {}

    for table, filters in filters_by_table.items():
        if not existing_tables[table.fullname]:
            # Table presumably gets created later in this revision; there
            # is nothing to read yet.
            by_table[table] = {}
            continue

        pk_names = [column.name for column in table.primary_key.columns]
        existing_rows = by_table.setdefault(table, {})

        records = connection.execute(table.select().where(or_(*filters))).fetchall()
        for record in records:
            as_dict = row_to_dict(record)
            existing_rows[tuple(as_dict[name] for name in pk_names)] = as_dict

    return by_table
Loading

0 comments on commit f9b4ea8

Please sign in to comment.