diff --git a/superset/migrations/versions/2023-03-27_12-30_7e67aecbf3f1_chart_ds_constraint.py b/superset/migrations/versions/2023-03-27_12-30_7e67aecbf3f1_chart_ds_constraint.py index 4c6941fe8416d..f00fa9c357214 100644 --- a/superset/migrations/versions/2023-03-27_12-30_7e67aecbf3f1_chart_ds_constraint.py +++ b/superset/migrations/versions/2023-03-27_12-30_7e67aecbf3f1_chart_ds_constraint.py @@ -34,6 +34,7 @@ from sqlalchemy.ext.declarative import declarative_base from superset import db +from superset.utils.core import generic_find_check_constraint_exists Base = declarative_base() @@ -84,6 +85,7 @@ def upgrade_slc(slc: Slice) -> None: def upgrade(): bind = op.get_bind() session = db.Session(bind=bind) + insp = sa.engine.reflection.Inspector.from_engine(bind) with op.batch_alter_table("slices") as batch_op: for slc in session.query(Slice).filter(Slice.datasource_type != "table").all(): if slc.datasource_type == "query": @@ -100,13 +102,16 @@ def upgrade(): session.commit() with op.batch_alter_table("slices") as batch_op: - batch_op.create_check_constraint( - "ck_chart_datasource", "datasource_type in ('table')" - ) + if not generic_find_check_constraint_exists( + "slices", "ck_chart_datasource", insp + ): + batch_op.create_check_constraint( + "ck_chart_datasource", "datasource_type in ('table')" + ) session.commit() session.close() def downgrade(): - op.drop_constraint("ck_chart_datasource", "slices", type_="check") + op.execute("ALTER TABLE slices DROP CONSTRAINT IF EXISTS ck_chart_datasource") diff --git a/superset/utils/core.py b/superset/utils/core.py index 7cce8795ec6c9..097ca4dd698f2 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -67,7 +67,7 @@ from flask_babel.speaklater import LazyString from pandas.api.types import infer_dtype from pandas.core.dtypes.common import is_numeric_dtype -from sqlalchemy import event, exc, inspect, select, Text +from sqlalchemy import CheckConstraint, event, exc, inspect, select, Text from sqlalchemy.dialects.mysql import MEDIUMTEXT from sqlalchemy.engine import Connection, Engine from sqlalchemy.engine.reflection import Inspector @@ -733,6 +733,22 @@ def generic_find_uq_constraint_name( return None +def generic_find_check_constraint_exists( + table_name: str, constraint_name: str, insp: Inspector +) -> bool: + """ + Check if a constraint exists in a table. + """ + + # Get the table object + check_constraints = insp.get_check_constraints(table_name) + + # Check if the table has the specified constraint + return any( + constraint["name"] == constraint_name for constraint in check_constraints + ) + + def get_datasource_full_name( database_name: str, datasource_name: str, schema: str | None = None ) -> str: diff --git a/tests/unit_tests/utils/test_core.py b/tests/unit_tests/utils/test_core.py index 562ebe582e57a..b1f2fa2114a83 100644 --- a/tests/unit_tests/utils/test_core.py +++ b/tests/unit_tests/utils/test_core.py @@ -16,13 +16,18 @@ # under the License. import os from typing import Any, Optional +from unittest.mock import MagicMock, patch import pandas as pd import pytest +from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table from superset.utils.core import ( cast_to_boolean, DateColumn, + generic_find_check_constraint_exists, + generic_find_constraint_name, + generic_find_fk_constraint_name, is_test, normalize_dttm_col, parse_boolean_string, @@ -201,3 +206,142 @@ def test_normalize_dttm_col() -> None: normalize_dttm_col(df, dttm_cols) assert df["__time"].astype(str).tolist() == ["2017-07-01"] + + +def test_constraint_exists(): + insp_mock = MagicMock() + table_name = "my_table" + constraint_name = "my_constraint" + constraint_mock = {"name": constraint_name} + + # Configure the Inspector mock to return the mock table + insp_mock.get_check_constraints.return_value = [constraint_mock] + + result = generic_find_check_constraint_exists( + table_name, "my_constraint", insp_mock + ) + assert result is True + + +def test_constraint_does_not_exist(): + insp_mock = MagicMock() + table_name = "my_table" + constraint_name = "my_constraint" + constraint_mock = {"name": constraint_name} + + # Configure the Inspector mock to return the mock table + insp_mock.get_check_constraints.return_value = [constraint_mock] + result = generic_find_check_constraint_exists( + table_name, "non_existent_constraint", insp_mock + ) + assert result is False + + +def test_generic_constraint_name_exists(): + # Create a mock SQLAlchemy database object + database_mock = MagicMock() + + # Define the table name and constraint details + table_name = "my_table" + columns = {"column1", "column2"} + referenced_table_name = "other_table" + constraint_name = "my_constraint" + + # Create a mock table object with the same structure + table_mock = MagicMock() + table_mock.name = table_name + table_mock.columns = [MagicMock(name=col) for col in columns] + + # Create a mock for the referred_table with a name attribute + referred_table_mock = MagicMock() + referred_table_mock.name = referenced_table_name + + # Create a mock for the foreign key constraint with a name attribute + foreign_key_constraint_mock = MagicMock() + foreign_key_constraint_mock.name = constraint_name + foreign_key_constraint_mock.referred_table = referred_table_mock + foreign_key_constraint_mock.column_keys = list(columns) + + # Set the foreign key constraint mock as part of the table's constraints + table_mock.foreign_key_constraints = [foreign_key_constraint_mock] + + # Configure the autoload behavior for the database mock + database_mock.metadata = MagicMock() + database_mock.metadata.tables = {table_name: table_mock} + + # Mock the sa.Table creation with autoload + with patch("superset.utils.core.sa.Table") as table_creation_mock: + table_creation_mock.return_value = table_mock + + result = generic_find_constraint_name( + table_name, columns, referenced_table_name, database_mock + ) + + assert result == constraint_name + + +def test_generic_constraint_name_not_found(): + # Create a mock SQLAlchemy database object + database_mock = MagicMock() + + # Define the table name and constraint details + table_name = "my_table" + columns = {"column1", "column2"} + referenced_table_name = "other_table" + constraint_name = "my_constraint" + + # Create a mock table object with the same structure but no matching constraint + table_mock = MagicMock() + table_mock.name = table_name + table_mock.columns = [MagicMock(name=col) for col in columns] + table_mock.foreign_key_constraints = [] + + # Configure the autoload behavior for the database mock + database_mock.metadata = MagicMock() + database_mock.metadata.tables = {table_name: table_mock} + + result = generic_find_constraint_name( + table_name, columns, referenced_table_name, database_mock + ) + + assert result is None + + +def test_generic_find_fk_constraint_exists(): + insp_mock = MagicMock() + table_name = "my_table" + columns = {"column1", "column2"} + referenced_table_name = "other_table" + constraint_name = "my_constraint" + + # Create a mock for the foreign key constraint as a dictionary + constraint_mock = { + "name": constraint_name, + "referred_table": referenced_table_name, + "referred_columns": list(columns), + } + + # Configure the Inspector mock to return the list of foreign key constraints + insp_mock.get_foreign_keys.return_value = [constraint_mock] + + result = generic_find_fk_constraint_name( + table_name, columns, referenced_table_name, insp_mock + ) + + assert result == constraint_name + + +def test_generic_find_fk_constraint_none_exist(): + insp_mock = MagicMock() + table_name = "my_table" + columns = {"column1", "column2"} + referenced_table_name = "other_table" + + # Configure the Inspector mock to return the list of foreign key constraints + insp_mock.get_foreign_keys.return_value = [] + + result = generic_find_fk_constraint_name( + table_name, columns, referenced_table_name, insp_mock + ) + + assert result is None