Skip to content

Commit

Permalink
add checks on migration
Browse files Browse the repository at this point in the history
  • Loading branch information
eschutho committed Sep 12, 2023
1 parent 37742c4 commit 1d2dbf2
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sqlalchemy.ext.declarative import declarative_base

from superset import db
from superset.utils.core import generic_find_check_constraint_exists

Base = declarative_base()

Expand Down Expand Up @@ -84,6 +85,7 @@ def upgrade_slc(slc: Slice) -> None:
def upgrade():
bind = op.get_bind()
session = db.Session(bind=bind)
insp = sa.engine.reflection.Inspector.from_engine(bind)
with op.batch_alter_table("slices") as batch_op:
for slc in session.query(Slice).filter(Slice.datasource_type != "table").all():
if slc.datasource_type == "query":
Expand All @@ -100,13 +102,16 @@ def upgrade():
session.commit()

with op.batch_alter_table("slices") as batch_op:
batch_op.create_check_constraint(
"ck_chart_datasource", "datasource_type in ('table')"
)
if not generic_find_check_constraint_exists(
"slices", "ck_chart_datasource", insp
):
batch_op.create_check_constraint(
"ck_chart_datasource", "datasource_type in ('table')"
)

session.commit()
session.close()


def downgrade():
op.drop_constraint("ck_chart_datasource", "slices", type_="check")
op.execute("ALTER TABLE slices DROP CONSTRAINT IF EXISTS ck_chart_datasource")
18 changes: 17 additions & 1 deletion superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
from flask_babel.speaklater import LazyString
from pandas.api.types import infer_dtype
from pandas.core.dtypes.common import is_numeric_dtype
from sqlalchemy import event, exc, inspect, select, Text
from sqlalchemy import CheckConstraint, event, exc, inspect, select, Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.reflection import Inspector
Expand Down Expand Up @@ -733,6 +733,22 @@ def generic_find_uq_constraint_name(
return None


def generic_find_check_constraint_exists(
table_name: str, constraint_name: str, insp: Inspector
) -> bool:
"""
Check if a constraint exists in a table.
"""

# Get the table object
check_constraints = insp.get_check_constraints(table_name)

# Check if the table has the specified constraint
return any(
constraint["name"] == constraint_name for constraint in check_constraints
)


def get_datasource_full_name(
database_name: str, datasource_name: str, schema: str | None = None
) -> str:
Expand Down
144 changes: 144 additions & 0 deletions tests/unit_tests/utils/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@
# under the License.
import os
from typing import Any, Optional
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest
from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table

from superset.utils.core import (
cast_to_boolean,
DateColumn,
generic_find_check_constraint_exists,
generic_find_constraint_name,
generic_find_fk_constraint_name,
is_test,
normalize_dttm_col,
parse_boolean_string,
Expand Down Expand Up @@ -201,3 +206,142 @@ def test_normalize_dttm_col() -> None:
normalize_dttm_col(df, dttm_cols)

assert df["__time"].astype(str).tolist() == ["2017-07-01"]


def test_constraint_exists():
insp_mock = MagicMock()
table_name = "my_table"
constraint_name = "my_constraint"
constraint_mock = {"name": constraint_name}

# Configure the Inspector mock to return the mock table
insp_mock.get_check_constraints.return_value = [constraint_mock]

result = generic_find_check_constraint_exists(
table_name, "my_constraint", insp_mock
)
assert result is True


def test_constraint_does_not_exist():
insp_mock = MagicMock()
table_name = "my_table"
constraint_name = "my_constraint"
constraint_mock = {"name": constraint_name}

# Configure the Inspector mock to return the mock table
insp_mock.get_check_constraints.return_value = [constraint_mock]
result = generic_find_check_constraint_exists(
table_name, "non_existent_constraint", insp_mock
)
assert result is False


def test_generic_constraint_name_exists():
# Create a mock SQLAlchemy database object
database_mock = MagicMock()

# Define the table name and constraint details
table_name = "my_table"
columns = {"column1", "column2"}
referenced_table_name = "other_table"
constraint_name = "my_constraint"

# Create a mock table object with the same structure
table_mock = MagicMock()
table_mock.name = table_name
table_mock.columns = [MagicMock(name=col) for col in columns]

# Create a mock for the referred_table with a name attribute
referred_table_mock = MagicMock()
referred_table_mock.name = referenced_table_name

# Create a mock for the foreign key constraint with a name attribute
foreign_key_constraint_mock = MagicMock()
foreign_key_constraint_mock.name = constraint_name
foreign_key_constraint_mock.referred_table = referred_table_mock
foreign_key_constraint_mock.column_keys = list(columns)

# Set the foreign key constraint mock as part of the table's constraints
table_mock.foreign_key_constraints = [foreign_key_constraint_mock]

# Configure the autoload behavior for the database mock
database_mock.metadata = MagicMock()
database_mock.metadata.tables = {table_name: table_mock}

# Mock the sa.Table creation with autoload
with patch("superset.utils.core.sa.Table") as table_creation_mock:
table_creation_mock.return_value = table_mock

result = generic_find_constraint_name(
table_name, columns, referenced_table_name, database_mock
)

assert result == constraint_name


def test_generic_constraint_name_not_found():
# Create a mock SQLAlchemy database object
database_mock = MagicMock()

# Define the table name and constraint details
table_name = "my_table"
columns = {"column1", "column2"}
referenced_table_name = "other_table"
constraint_name = "my_constraint"

# Create a mock table object with the same structure but no matching constraint
table_mock = MagicMock()
table_mock.name = table_name
table_mock.columns = [MagicMock(name=col) for col in columns]
table_mock.foreign_key_constraints = []

# Configure the autoload behavior for the database mock
database_mock.metadata = MagicMock()
database_mock.metadata.tables = {table_name: table_mock}

result = generic_find_constraint_name(
table_name, columns, referenced_table_name, database_mock
)

assert result is None


def test_generic_find_fk_constraint_exists():
insp_mock = MagicMock()
table_name = "my_table"
columns = {"column1", "column2"}
referenced_table_name = "other_table"
constraint_name = "my_constraint"

# Create a mock for the foreign key constraint as a dictionary
constraint_mock = {
"name": constraint_name,
"referred_table": referenced_table_name,
"referred_columns": list(columns),
}

# Configure the Inspector mock to return the list of foreign key constraints
insp_mock.get_foreign_keys.return_value = [constraint_mock]

result = generic_find_fk_constraint_name(
table_name, columns, referenced_table_name, insp_mock
)

assert result == constraint_name


def test_generic_find_fk_constraint_none_exist():
insp_mock = MagicMock()
table_name = "my_table"
columns = {"column1", "column2"}
referenced_table_name = "other_table"

# Configure the Inspector mock to return the list of foreign key constraints
insp_mock.get_foreign_keys.return_value = []

result = generic_find_fk_constraint_name(
table_name, columns, referenced_table_name, insp_mock
)

assert result is None

0 comments on commit 1d2dbf2

Please sign in to comment.