From 5f85ce26af2b88a77751d8c9eefd90bc089780ec Mon Sep 17 00:00:00 2001 From: Axell Padilla <68310020+axellpadilla@users.noreply.github.com> Date: Wed, 30 Oct 2024 04:16:42 +0000 Subject: [PATCH] fixed typing for python >=3.8, optimized with immutables, added unit tests for index configs --- .../sqlserver/relation_configs/index.py | 36 ++-- dbt/adapters/sqlserver/sqlserver_configs.py | 4 +- .../unit/adapters/mssql/test_index_configs.py | 174 ++++++++++++++++++ 3 files changed, 199 insertions(+), 15 deletions(-) create mode 100644 tests/unit/adapters/mssql/test_index_configs.py diff --git a/dbt/adapters/sqlserver/relation_configs/index.py b/dbt/adapters/sqlserver/relation_configs/index.py index 9213247b..029448c0 100644 --- a/dbt/adapters/sqlserver/relation_configs/index.py +++ b/dbt/adapters/sqlserver/relation_configs/index.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Optional +from typing import FrozenSet, Optional, Set, Tuple import agate from dbt.adapters.exceptions import IndexConfigError, IndexConfigNotDictError @@ -16,6 +16,11 @@ from dbt_common.utils import encoding as dbt_encoding +# Wrapper around datetime.now() so tests can patch the current time. +def datetime_now(tz: Optional[timezone] = timezone.utc) -> datetime: + return datetime.now(tz) + + # ALTERED FROM: # github.com/dbt-labs/dbt-postgres/blob/main/dbt/adapters/postgres/relation_configs/index.py class SQLServerIndexType(StrEnum): @@ -32,7 +37,7 @@ def default(cls) -> "SQLServerIndexType": @classmethod def valid_types(cls): - return list(cls) + return tuple(cls) @dataclass(frozen=True, eq=True, unsafe_hash=True) @@ -52,17 +57,19 @@ class SQLServerIndexConfig(RelationConfigBase, RelationConfigValidationMixin, db """ name: str = field(default="", hash=False, compare=False) - columns: list[str] = field(default_factory=list, hash=True) # Keeping order is important + columns: Tuple[str, ...] 
= field( + default_factory=tuple, hash=True + ) # Keeping order is important unique: bool = field( default=False, hash=True ) # Uniqueness can be a property of both clustered and nonclustered indexes. type: SQLServerIndexType = field(default=SQLServerIndexType.default(), hash=True) - included_columns: frozenset[str] = field( + included_columns: FrozenSet[str] = field( default_factory=frozenset, hash=True ) # Keeping order is not important @property - def validation_rules(self) -> set[RelationConfigValidationRule]: + def validation_rules(self) -> Set[RelationConfigValidationRule]: return { RelationConfigValidationRule( validation_check=True if self.columns else False, @@ -102,7 +109,7 @@ def validation_rules(self) -> set[RelationConfigValidationRule]: def from_dict(cls, config_dict) -> "SQLServerIndexConfig": kwargs_dict = { "name": config_dict.get("name"), - "columns": list(column for column in config_dict.get("columns", list())), + "columns": tuple(column for column in config_dict.get("columns", tuple())), "unique": config_dict.get("unique"), "type": config_dict.get("type"), "included_columns": frozenset( @@ -115,7 +122,7 @@ def from_dict(cls, config_dict) -> "SQLServerIndexConfig": @classmethod def parse_model_node(cls, model_node_entry: dict) -> dict: config_dict = { - "columns": list(model_node_entry.get("columns", list())), + "columns": tuple(model_node_entry.get("columns", tuple())), "unique": model_node_entry.get("unique"), "type": model_node_entry.get("type"), "included_columns": frozenset(model_node_entry.get("included_columns", set())), @@ -126,7 +133,7 @@ def parse_model_node(cls, model_node_entry: dict) -> dict: def parse_relation_results(cls, relation_results_entry: agate.Row) -> dict: config_dict = { "name": relation_results_entry.get("name"), - "columns": list(relation_results_entry.get("columns", "").split(",")), + "columns": tuple(relation_results_entry.get("columns", "").split(",")), "unique": relation_results_entry.get("unique"), "type": 
relation_results_entry.get("type"), "included_columns": set(relation_results_entry.get("included_columns", "").split(",")), @@ -139,10 +146,10 @@ def as_node_config(self) -> dict: Returns: a dictionary that can be passed into `get_create_index_sql()` """ node_config = { - "columns": list(self.columns), + "columns": tuple(self.columns), "unique": self.unique, "type": self.type.value, - "included_columns": list(self.included_columns), + "included_columns": frozenset(self.included_columns), } return node_config @@ -152,9 +159,10 @@ def render(self, relation): # https://github.com/dbt-labs/dbt-core/issues/1945#issuecomment-576714925 # for an explanation. - now = datetime.now(timezone.utc).isoformat() - inputs = self.columns + [relation.render(), str(self.unique), str(self.type), now] + now = datetime_now(tz=timezone.utc).isoformat() + inputs = self.columns + tuple((relation.render(), str(self.unique), str(self.type), now)) string = "_".join(inputs) + print(f"Actual string before MD5: {string}") return dbt_encoding.md5(string) @classmethod @@ -162,6 +170,8 @@ def parse(cls, raw_index) -> Optional["SQLServerIndexConfig"]: if raw_index is None: return None try: + if not isinstance(raw_index, dict): + raise IndexConfigNotDictError(raw_index) cls.validate(raw_index) return cls.from_dict(raw_index) except ValidationError as exc: @@ -202,7 +212,7 @@ def requires_full_refresh(self) -> bool: return False @property - def validation_rules(self) -> set[RelationConfigValidationRule]: + def validation_rules(self) -> Set[RelationConfigValidationRule]: return { RelationConfigValidationRule( validation_check=self.action diff --git a/dbt/adapters/sqlserver/sqlserver_configs.py b/dbt/adapters/sqlserver/sqlserver_configs.py index 32347c31..1cdbc8cd 100644 --- a/dbt/adapters/sqlserver/sqlserver_configs.py +++ b/dbt/adapters/sqlserver/sqlserver_configs.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional +from typing import Optional, Tuple from 
dbt.adapters.fabric import FabricConfigs @@ -8,4 +8,4 @@ @dataclass class SQLServerConfigs(FabricConfigs): - indexes: Optional[list[SQLServerIndexConfig]] = None + indexes: Optional[Tuple[SQLServerIndexConfig]] = None diff --git a/tests/unit/adapters/mssql/test_index_configs.py b/tests/unit/adapters/mssql/test_index_configs.py new file mode 100644 index 00000000..278ba8e7 --- /dev/null +++ b/tests/unit/adapters/mssql/test_index_configs.py @@ -0,0 +1,174 @@ +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest +from dbt.adapters.exceptions import IndexConfigError, IndexConfigNotDictError +from dbt.exceptions import DbtRuntimeError +from dbt_common.utils import encoding as dbt_encoding + +from dbt.adapters.sqlserver.relation_configs.index import SQLServerIndexConfig, SQLServerIndexType + + +def test_sqlserver_index_type_default(): + assert SQLServerIndexType.default() == SQLServerIndexType.nonclustered + + +def test_sqlserver_index_type_valid_types(): + valid_types = SQLServerIndexType.valid_types() + assert isinstance(valid_types, tuple) + assert len(valid_types) > 0 + + +def test_sqlserver_index_config_creation(): + config = SQLServerIndexConfig( + columns=("col1", "col2"), + unique=True, + type=SQLServerIndexType.nonclustered, + included_columns=frozenset(["col3", "col4"]), + ) + assert config.columns == ("col1", "col2") + assert config.unique is True + assert config.type == SQLServerIndexType.nonclustered + assert config.included_columns == frozenset(["col3", "col4"]) + + +def test_sqlserver_index_config_from_dict(): + config_dict = { + "columns": ["col1", "col2"], + "unique": True, + "type": "nonclustered", + "included_columns": ["col3", "col4"], + } + config = SQLServerIndexConfig.from_dict(config_dict) + assert config.columns == ("col1", "col2") + assert config.unique is True + assert config.type == SQLServerIndexType.nonclustered + assert config.included_columns == frozenset(["col3", "col4"]) + + +def 
test_sqlserver_index_config_validation_rules(): + # Test valid configuration + valid_config = SQLServerIndexConfig( + columns=("col1", "col2"), + unique=True, + type=SQLServerIndexType.nonclustered, + included_columns=frozenset(["col3", "col4"]), + ) + assert len(valid_config.validation_rules) == 4 + for rule in valid_config.validation_rules: + assert rule.validation_check is True + + # Test invalid configurations + with pytest.raises(DbtRuntimeError, match="'columns' is a required property"): + SQLServerIndexConfig(columns=()) + + with pytest.raises( + DbtRuntimeError, + match="Non-clustered indexes are the only index types that can include extra columns", + ): + SQLServerIndexConfig( + columns=("col1",), + type=SQLServerIndexType.clustered, + included_columns=frozenset(["col2"]), + ) + + with pytest.raises( + DbtRuntimeError, + match="Clustered and nonclustered indexes are the only types that can be unique", + ): + SQLServerIndexConfig(columns=("col1",), unique=True, type=SQLServerIndexType.columnstore) + + +def test_sqlserver_index_config_parse_model_node(): + model_node_entry = { + "columns": ["col1", "col2"], + "unique": True, + "type": "nonclustered", + "included_columns": ["col3", "col4"], + } + parsed_dict = SQLServerIndexConfig.parse_model_node(model_node_entry) + assert parsed_dict == { + "columns": ("col1", "col2"), + "unique": True, + "type": "nonclustered", + "included_columns": frozenset(["col3", "col4"]), + } + + +def test_sqlserver_index_config_parse_relation_results(): + relation_results_entry = { + "name": "index_name", + "columns": "col1,col2", + "unique": True, + "type": "nonclustered", + "included_columns": "col3,col4", + } + parsed_dict = SQLServerIndexConfig.parse_relation_results(relation_results_entry) + assert parsed_dict == { + "name": "index_name", + "columns": ("col1", "col2"), + "unique": True, + "type": "nonclustered", + "included_columns": {"col3", "col4"}, + } + + +def test_sqlserver_index_config_as_node_config(): + config = 
SQLServerIndexConfig( + columns=("col1", "col2"), + unique=True, + type=SQLServerIndexType.nonclustered, + included_columns=frozenset(["col3", "col4"]), + ) + node_config = config.as_node_config + assert node_config == { + "columns": ("col1", "col2"), + "unique": True, + "type": "nonclustered", + "included_columns": frozenset(["col3", "col4"]), + } + + +FAKE_NOW = datetime(2023, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + + +@pytest.fixture(autouse=True) +def patch_datetime_now(): + with patch("dbt.adapters.sqlserver.relation_configs.index.datetime_now") as mocked_datetime: + mocked_datetime.return_value = FAKE_NOW + yield mocked_datetime + + +def test_sqlserver_index_config_render(): + config = SQLServerIndexConfig( + columns=("col1", "col2"), unique=True, type=SQLServerIndexType.nonclustered + ) + relation = MagicMock() + relation.render.return_value = "test_relation" + + result = config.render(relation) + + expected_string = "col1_col2_test_relation_True_nonclustered_2023-01-01T00:00:00+00:00" + + print(f"Expected string: {expected_string}") + print(f"Actual result (MD5): {result}") + print(f"Expected result (MD5): {dbt_encoding.md5(expected_string)}") + + assert result == dbt_encoding.md5(expected_string) + + +def test_sqlserver_index_config_parse(): + valid_raw_index = {"columns": ["col1", "col2"], "unique": True, "type": "nonclustered"} + result = SQLServerIndexConfig.parse(valid_raw_index) + assert isinstance(result, SQLServerIndexConfig) + assert result.columns == ("col1", "col2") + assert result.unique is True + assert result.type == SQLServerIndexType.nonclustered + + assert SQLServerIndexConfig.parse(None) is None + + with pytest.raises(IndexConfigError): + SQLServerIndexConfig.parse({"invalid": "config"}) + + with pytest.raises(IndexConfigNotDictError): + SQLServerIndexConfig.parse("not a dict")