Skip to content

Commit

Permalink
fixed typing for python >=3.8, optimized with immutables, added unit …
Browse files Browse the repository at this point in the history
…tests for index configs
  • Loading branch information
axellpadilla committed Oct 30, 2024
1 parent 7006ae6 commit 5f85ce2
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 15 deletions.
36 changes: 23 additions & 13 deletions dbt/adapters/sqlserver/relation_configs/index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Optional
from typing import FrozenSet, Optional, Set, Tuple

import agate
from dbt.adapters.exceptions import IndexConfigError, IndexConfigNotDictError
Expand All @@ -16,6 +16,11 @@
from dbt_common.utils import encoding as dbt_encoding


# Thin wrapper around ``datetime.now`` so tests can patch a module-level name.
def datetime_now(tz: Optional[timezone] = timezone.utc) -> datetime:
    """Return the current date and time for the given timezone.

    Args:
        tz: Timezone forwarded to ``datetime.now``; defaults to UTC.

    Returns:
        The ``datetime`` produced by ``datetime.now(tz)``.
    """
    current = datetime.now(tz)
    return current


# ALTERED FROM:
# github.com/dbt-labs/dbt-postgres/blob/main/dbt/adapters/postgres/relation_configs/index.py
class SQLServerIndexType(StrEnum):
Expand All @@ -32,7 +37,7 @@ def default(cls) -> "SQLServerIndexType":

@classmethod
def valid_types(cls):
return list(cls)
return tuple(cls)


@dataclass(frozen=True, eq=True, unsafe_hash=True)
Expand All @@ -52,17 +57,19 @@ class SQLServerIndexConfig(RelationConfigBase, RelationConfigValidationMixin, db
"""

name: str = field(default="", hash=False, compare=False)
columns: list[str] = field(default_factory=list, hash=True) # Keeping order is important
columns: Tuple[str, ...] = field(
default_factory=tuple, hash=True
) # Keeping order is important
unique: bool = field(
default=False, hash=True
) # Uniqueness can be a property of both clustered and nonclustered indexes.
type: SQLServerIndexType = field(default=SQLServerIndexType.default(), hash=True)
included_columns: frozenset[str] = field(
included_columns: FrozenSet[str] = field(
default_factory=frozenset, hash=True
) # Keeping order is not important

@property
def validation_rules(self) -> set[RelationConfigValidationRule]:
def validation_rules(self) -> Set[RelationConfigValidationRule]:
return {
RelationConfigValidationRule(
validation_check=True if self.columns else False,
Expand Down Expand Up @@ -102,7 +109,7 @@ def validation_rules(self) -> set[RelationConfigValidationRule]:
def from_dict(cls, config_dict) -> "SQLServerIndexConfig":
kwargs_dict = {
"name": config_dict.get("name"),
"columns": list(column for column in config_dict.get("columns", list())),
"columns": tuple(column for column in config_dict.get("columns", tuple())),
"unique": config_dict.get("unique"),
"type": config_dict.get("type"),
"included_columns": frozenset(
Expand All @@ -115,7 +122,7 @@ def from_dict(cls, config_dict) -> "SQLServerIndexConfig":
@classmethod
def parse_model_node(cls, model_node_entry: dict) -> dict:
config_dict = {
"columns": list(model_node_entry.get("columns", list())),
"columns": tuple(model_node_entry.get("columns", tuple())),
"unique": model_node_entry.get("unique"),
"type": model_node_entry.get("type"),
"included_columns": frozenset(model_node_entry.get("included_columns", set())),
Expand All @@ -126,7 +133,7 @@ def parse_model_node(cls, model_node_entry: dict) -> dict:
def parse_relation_results(cls, relation_results_entry: agate.Row) -> dict:
config_dict = {
"name": relation_results_entry.get("name"),
"columns": list(relation_results_entry.get("columns", "").split(",")),
"columns": tuple(relation_results_entry.get("columns", "").split(",")),
"unique": relation_results_entry.get("unique"),
"type": relation_results_entry.get("type"),
"included_columns": set(relation_results_entry.get("included_columns", "").split(",")),
Expand All @@ -139,10 +146,10 @@ def as_node_config(self) -> dict:
Returns: a dictionary that can be passed into `get_create_index_sql()`
"""
node_config = {
"columns": list(self.columns),
"columns": tuple(self.columns),
"unique": self.unique,
"type": self.type.value,
"included_columns": list(self.included_columns),
"included_columns": frozenset(self.included_columns),
}
return node_config

Expand All @@ -152,16 +159,19 @@ def render(self, relation):
# https://github.com/dbt-labs/dbt-core/issues/1945#issuecomment-576714925
# for an explanation.

now = datetime.now(timezone.utc).isoformat()
inputs = self.columns + [relation.render(), str(self.unique), str(self.type), now]
now = datetime_now(tz=timezone.utc).isoformat()
inputs = self.columns + tuple((relation.render(), str(self.unique), str(self.type), now))
string = "_".join(inputs)
print(f"Actual string before MD5: {string}")
return dbt_encoding.md5(string)

@classmethod
def parse(cls, raw_index) -> Optional["SQLServerIndexConfig"]:
if raw_index is None:
return None
try:
if not isinstance(raw_index, dict):
raise IndexConfigNotDictError(raw_index)
cls.validate(raw_index)
return cls.from_dict(raw_index)
except ValidationError as exc:
Expand Down Expand Up @@ -202,7 +212,7 @@ def requires_full_refresh(self) -> bool:
return False

@property
def validation_rules(self) -> set[RelationConfigValidationRule]:
def validation_rules(self) -> Set[RelationConfigValidationRule]:
return {
RelationConfigValidationRule(
validation_check=self.action
Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/sqlserver/sqlserver_configs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple

from dbt.adapters.fabric import FabricConfigs

Expand All @@ -8,4 +8,4 @@

@dataclass
class SQLServerConfigs(FabricConfigs):
    """Fabric adapter configs extended with an optional collection of index configs."""

    # ``Tuple[X, ...]`` is a tuple of any length; a bare ``Tuple[X]`` would
    # declare a tuple holding exactly one index config.
    indexes: Optional[Tuple[SQLServerIndexConfig, ...]] = None
174 changes: 174 additions & 0 deletions tests/unit/adapters/mssql/test_index_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch

import pytest
from dbt.adapters.exceptions import IndexConfigError, IndexConfigNotDictError
from dbt.exceptions import DbtRuntimeError
from dbt_common.utils import encoding as dbt_encoding

from dbt.adapters.sqlserver.relation_configs.index import SQLServerIndexConfig, SQLServerIndexType


def test_sqlserver_index_type_default():
    """The default index type is ``nonclustered``."""
    default_type = SQLServerIndexType.default()
    assert default_type == SQLServerIndexType.nonclustered


def test_sqlserver_index_type_valid_types():
    """``valid_types`` returns a non-empty tuple."""
    index_types = SQLServerIndexType.valid_types()
    assert isinstance(index_types, tuple)
    assert len(index_types) > 0


def test_sqlserver_index_config_creation():
    """Constructor arguments are stored unchanged on the resulting config."""
    key_columns = ("col1", "col2")
    extra_columns = frozenset(["col3", "col4"])

    index = SQLServerIndexConfig(
        columns=key_columns,
        unique=True,
        type=SQLServerIndexType.nonclustered,
        included_columns=extra_columns,
    )

    assert index.columns == ("col1", "col2")
    assert index.unique is True
    assert index.type == SQLServerIndexType.nonclustered
    assert index.included_columns == frozenset(["col3", "col4"])


def test_sqlserver_index_config_from_dict():
    """``from_dict`` turns list inputs into a tuple of columns and a frozenset of includes."""
    index = SQLServerIndexConfig.from_dict(
        {
            "columns": ["col1", "col2"],
            "unique": True,
            "type": "nonclustered",
            "included_columns": ["col3", "col4"],
        }
    )

    assert index.columns == ("col1", "col2")
    assert index.unique is True
    assert index.type == SQLServerIndexType.nonclustered
    assert index.included_columns == frozenset(["col3", "col4"])


def test_sqlserver_index_config_validation_rules():
    """A valid config satisfies all four rules; each invalid combination raises."""
    # Valid configuration: four rules, every one of them passing.
    ok_config = SQLServerIndexConfig(
        columns=("col1", "col2"),
        unique=True,
        type=SQLServerIndexType.nonclustered,
        included_columns=frozenset(["col3", "col4"]),
    )
    assert len(ok_config.validation_rules) == 4
    assert all(rule.validation_check is True for rule in ok_config.validation_rules)

    # Invalid configuration: no key columns at all.
    missing_columns_msg = "'columns' is a required property"
    with pytest.raises(DbtRuntimeError, match=missing_columns_msg):
        SQLServerIndexConfig(columns=())

    # Invalid configuration: included columns on a clustered index.
    include_msg = "Non-clustered indexes are the only index types that can include extra columns"
    with pytest.raises(DbtRuntimeError, match=include_msg):
        SQLServerIndexConfig(
            columns=("col1",),
            type=SQLServerIndexType.clustered,
            included_columns=frozenset(["col2"]),
        )

    # Invalid configuration: uniqueness requested on a columnstore index.
    unique_msg = "Clustered and nonclustered indexes are the only types that can be unique"
    with pytest.raises(DbtRuntimeError, match=unique_msg):
        SQLServerIndexConfig(columns=("col1",), unique=True, type=SQLServerIndexType.columnstore)


def test_sqlserver_index_config_parse_model_node():
    """``parse_model_node`` returns immutable containers for both column fields."""
    node_entry = {
        "columns": ["col1", "col2"],
        "unique": True,
        "type": "nonclustered",
        "included_columns": ["col3", "col4"],
    }
    expected = {
        "columns": ("col1", "col2"),
        "unique": True,
        "type": "nonclustered",
        "included_columns": frozenset(["col3", "col4"]),
    }

    assert SQLServerIndexConfig.parse_model_node(node_entry) == expected


def test_sqlserver_index_config_parse_relation_results():
    """``parse_relation_results`` splits comma-separated column strings."""
    # A plain dict stands in for the row here; only ``.get`` is exercised.
    row = {
        "name": "index_name",
        "columns": "col1,col2",
        "unique": True,
        "type": "nonclustered",
        "included_columns": "col3,col4",
    }
    expected = {
        "name": "index_name",
        "columns": ("col1", "col2"),
        "unique": True,
        "type": "nonclustered",
        "included_columns": {"col3", "col4"},
    }

    assert SQLServerIndexConfig.parse_relation_results(row) == expected


def test_sqlserver_index_config_as_node_config():
    """``as_node_config`` exposes the config fields as a plain dictionary."""
    index = SQLServerIndexConfig(
        columns=("col1", "col2"),
        unique=True,
        type=SQLServerIndexType.nonclustered,
        included_columns=frozenset(["col3", "col4"]),
    )
    expected = {
        "columns": ("col1", "col2"),
        "unique": True,
        "type": "nonclustered",
        "included_columns": frozenset(["col3", "col4"]),
    }

    assert index.as_node_config == expected


# Fixed UTC instant substituted for "now" in every test of this module.
FAKE_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)


@pytest.fixture(autouse=True)
def patch_datetime_now():
    """Patch the index module's ``datetime_now`` so it returns ``FAKE_NOW``.

    Yields:
        The mock that replaces ``datetime_now`` while the test runs.
    """
    target = "dbt.adapters.sqlserver.relation_configs.index.datetime_now"
    with patch(target, return_value=FAKE_NOW) as fake_clock:
        yield fake_clock


def test_sqlserver_index_config_render():
    """``render`` returns the MD5 of columns, relation, unique, type and timestamp.

    The parts are joined with underscores. The timestamp in the expected
    string is the ISO form of the instant that ``datetime_now`` is patched
    to return in this module.
    """
    config = SQLServerIndexConfig(
        columns=("col1", "col2"), unique=True, type=SQLServerIndexType.nonclustered
    )
    relation = MagicMock()
    relation.render.return_value = "test_relation"

    result = config.render(relation)

    # No debug prints: on failure pytest already reports both sides of the
    # comparison.
    expected_string = "col1_col2_test_relation_True_nonclustered_2023-01-01T00:00:00+00:00"
    assert result == dbt_encoding.md5(expected_string)


def test_sqlserver_index_config_parse():
    """``parse`` builds a config from a dict, passes ``None`` through, and rejects bad input."""
    # A valid dictionary is turned into a config instance.
    parsed = SQLServerIndexConfig.parse(
        {"columns": ["col1", "col2"], "unique": True, "type": "nonclustered"}
    )
    assert isinstance(parsed, SQLServerIndexConfig)
    assert parsed.columns == ("col1", "col2")
    assert parsed.unique is True
    assert parsed.type == SQLServerIndexType.nonclustered

    # ``None`` is returned as-is instead of raising.
    assert SQLServerIndexConfig.parse(None) is None

    # A dictionary without the required keys is a config error.
    with pytest.raises(IndexConfigError):
        SQLServerIndexConfig.parse({"invalid": "config"})

    # Anything that is not a dictionary is rejected with its own error type.
    with pytest.raises(IndexConfigNotDictError):
        SQLServerIndexConfig.parse("not a dict")

0 comments on commit 5f85ce2

Please sign in to comment.