Skip to content

Commit

Permalink
Update tests/unit/utils.py per dbt-core#4212
Browse files Browse the repository at this point in the history
  • Loading branch information
jtcohen6 committed Nov 8, 2021
1 parent 448d8d7 commit af5f1b4
Showing 1 changed file with 212 additions and 12 deletions.
224 changes: 212 additions & 12 deletions tests/unit/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
"""Unit test utility functions.
Note that all imports should be inside the functions to avoid import/mocking
issues.
"""
import string
import os
from unittest import mock
from unittest import TestCase

from hologram import ValidationError
import agate
import pytest
from dbt.dataclass_schema import ValidationError
from dbt.config.project import PartialProject


def normalize(path):
"""On windows, neither is enough on its own:
>>> normcase('C:\\documents/ALL CAPS/subdir\\..')
'c:\\documents\\all caps\\subdir\\..'
>>> normpath('C:\\documents/ALL CAPS/subdir\\..')
Expand All @@ -28,9 +30,10 @@ class Obj:
single_threaded = False


def mock_connection(name):
def mock_connection(name, state='open'):
conn = mock.MagicMock()
conn.name = name
conn.state = state
return conn


Expand All @@ -42,7 +45,7 @@ def profile_from_dict(profile, profile_name, cli_vars='{}'):
if not isinstance(cli_vars, dict):
cli_vars = parse_cli_vars(cli_vars)

renderer = ProfileRenderer(generate_base_context(cli_vars))
renderer = ProfileRenderer(cli_vars)
return Profile.from_raw_profile_info(
profile,
profile_name,
Expand All @@ -58,13 +61,18 @@ def project_from_dict(project, profile, packages=None, selectors=None, cli_vars=
if not isinstance(cli_vars, dict):
cli_vars = parse_cli_vars(cli_vars)

renderer = DbtProjectYamlRenderer(generate_target_context(profile, cli_vars))
renderer = DbtProjectYamlRenderer(profile, cli_vars)

project_root = project.pop('project-root', os.getcwd())

return Project.render_from_dict(
project_root, project, packages, selectors, renderer
)
partial = PartialProject.from_dicts(
project_root=project_root,
project_dict=project,
packages_dict=packages,
selectors_dict=selectors,
)
return partial.render(renderer)



def config_from_parts_or_dicts(project, profile, packages=None, selectors=None, cli_vars='{}'):
Expand Down Expand Up @@ -108,6 +116,14 @@ def inject_plugin(plugin):
FACTORY.plugins[key] = plugin


def inject_plugin_for(config):
# from dbt.adapters.postgres import Plugin, PostgresAdapter
from dbt.adapters.factory import FACTORY
FACTORY.load_plugin(config.credentials.type)
adapter = FACTORY.get_adapter(config)
return adapter


def inject_adapter(value, plugin):
"""Inject the given adapter into the adapter factory, so your hand-crafted
artisanal adapter will be available from get_adapter() as if dbt loaded it.
Expand All @@ -118,6 +134,13 @@ def inject_adapter(value, plugin):
FACTORY.adapters[key] = value


def clear_plugin(plugin):
from dbt.adapters.factory import FACTORY
key = plugin.adapter.type()
FACTORY.plugins.pop(key, None)
FACTORY.adapters.pop(key, None)


class ContractTestCase(TestCase):
ContractType = None

Expand All @@ -126,11 +149,12 @@ def setUp(self):
super().setUp()

def assert_to_dict(self, obj, dct):
self.assertEqual(obj.to_dict(), dct)
self.assertEqual(obj.to_dict(omit_none=True), dct)

def assert_from_dict(self, obj, dct, cls=None):
if cls is None:
cls = self.ContractType
cls.validate(dct)
self.assertEqual(cls.from_dict(dct), obj)

def assert_symmetric(self, obj, dct, cls=None):
Expand All @@ -142,9 +166,59 @@ def assert_fails_validation(self, dct, cls=None):
cls = self.ContractType

with self.assertRaises(ValidationError):
cls.validate(dct)
cls.from_dict(dct)


def compare_dicts(dict1, dict2):
first_set = set(dict1.keys())
second_set = set(dict2.keys())
print(f"--- Difference between first and second keys: {first_set.difference(second_set)}")
print(f"--- Difference between second and first keys: {second_set.difference(first_set)}")
common_keys = set(first_set).intersection(set(second_set))
found_differences = False
for key in common_keys:
if dict1[key] != dict2[key] :
print(f"--- --- first dict: {key}: {str(dict1[key])}")
print(f"--- --- second dict: {key}: {str(dict2[key])}")
found_differences = True
if found_differences:
print("--- Found differences in dictionaries")
else:
print("--- Found no differences in dictionaries")


def assert_from_dict(obj, dct, cls=None):
if cls is None:
cls = obj.__class__
cls.validate(dct)
obj_from_dict = cls.from_dict(dct)
if hasattr(obj, 'created_at'):
obj_from_dict.created_at = 1
obj.created_at = 1
assert obj_from_dict == obj


def assert_to_dict(obj, dct):
obj_to_dict = obj.to_dict(omit_none=True)
if 'created_at' in obj_to_dict:
obj_to_dict['created_at'] = 1
if 'created_at' in dct:
dct['created_at'] = 1
assert obj_to_dict == dct


def assert_symmetric(obj, dct, cls=None):
assert_to_dict(obj, dct)
assert_from_dict(obj, dct, cls)


def assert_fails_validation(dct, cls):
with pytest.raises(ValidationError):
cls.validate(dct)
cls.from_dict(dct)


def generate_name_macros(package):
from dbt.contracts.graph.parsed import ParsedMacro
from dbt.node_types import NodeType
Expand All @@ -158,7 +232,6 @@ def generate_name_macros(package):
sql = f'{{% macro {name}(value, node) %}} {{% if value %}} {{{{ value }}}} {{% else %}} {{{{ {source} }}}} {{% endif %}} {{% endmacro %}}'
name_sql[name] = sql

all_sql = '\n'.join(name_sql.values())
for name, sql in name_sql.items():
pm = ParsedMacro(
name=name,
Expand All @@ -168,7 +241,134 @@ def generate_name_macros(package):
original_file_path=normalize('macros/macro.sql'),
root_path='./dbt_packages/root',
path=normalize('macros/macro.sql'),
raw_sql=all_sql,
macro_sql=sql,
)
yield pm


class TestAdapterConversions(TestCase):
def _get_tester_for(self, column_type):
from dbt.clients import agate_helper
if column_type is agate.TimeDelta: # dbt never makes this!
return agate.TimeDelta()

for instance in agate_helper.DEFAULT_TYPE_TESTER._possible_types:
if type(instance) is column_type:
return instance

raise ValueError(f'no tester for {column_type}')

def _make_table_of(self, rows, column_types):
column_names = list(string.ascii_letters[:len(rows[0])])
if isinstance(column_types, type):
column_types = [self._get_tester_for(column_types) for _ in column_names]
else:
column_types = [self._get_tester_for(typ) for typ in column_types]
table = agate.Table(rows, column_names=column_names, column_types=column_types)
return table


def MockMacro(package, name='my_macro', **kwargs):
from dbt.contracts.graph.parsed import ParsedMacro
from dbt.node_types import NodeType

mock_kwargs = dict(
resource_type=NodeType.Macro,
package_name=package,
unique_id=f'macro.{package}.{name}',
original_file_path='/dev/null',
)

mock_kwargs.update(kwargs)

macro = mock.MagicMock(
spec=ParsedMacro,
**mock_kwargs
)
macro.name = name
return macro


def MockMaterialization(package, name='my_materialization', adapter_type=None, **kwargs):
if adapter_type is None:
adapter_type = 'default'
kwargs['adapter_type'] = adapter_type
return MockMacro(package, f'materialization_{name}_{adapter_type}', **kwargs)


def MockGenerateMacro(package, component='some_component', **kwargs):
name = f'generate_{component}_name'
return MockMacro(package, name=name, **kwargs)


def MockSource(package, source_name, name, **kwargs):
from dbt.node_types import NodeType
from dbt.contracts.graph.parsed import ParsedSourceDefinition
src = mock.MagicMock(
__class__=ParsedSourceDefinition,
resource_type=NodeType.Source,
source_name=source_name,
package_name=package,
unique_id=f'source.{package}.{source_name}.{name}',
search_name=f'{source_name}.{name}',
**kwargs
)
src.name = name
return src


def MockNode(package, name, resource_type=None, **kwargs):
from dbt.node_types import NodeType
from dbt.contracts.graph.parsed import ParsedModelNode, ParsedSeedNode
if resource_type is None:
resource_type = NodeType.Model
if resource_type == NodeType.Model:
cls = ParsedModelNode
elif resource_type == NodeType.Seed:
cls = ParsedSeedNode
else:
raise ValueError(f'I do not know how to handle {resource_type}')
node = mock.MagicMock(
__class__=cls,
resource_type=resource_type,
package_name=package,
unique_id=f'{str(resource_type)}.{package}.{name}',
search_name=name,
**kwargs
)
node.name = name
return node


def MockDocumentation(package, name, **kwargs):
from dbt.node_types import NodeType
from dbt.contracts.graph.parsed import ParsedDocumentation
doc = mock.MagicMock(
__class__=ParsedDocumentation,
resource_type=NodeType.Documentation,
package_name=package,
search_name=name,
unique_id=f'{package}.{name}',
**kwargs
)
doc.name = name
return doc


def load_internal_manifest_macros(config, macro_hook = lambda m: None):
from dbt.parser.manifest import ManifestLoader
return ManifestLoader.load_macros(config, macro_hook)



def dict_replace(dct, **kwargs):
dct = dct.copy()
dct.update(kwargs)
return dct


def replace_config(n, **kwargs):
return n.replace(
config=n.config.replace(**kwargs),
unrendered_config=dict_replace(n.unrendered_config, **kwargs),
)

0 comments on commit af5f1b4

Please sign in to comment.