diff --git a/.changes/unreleased/Features-20230222-130632.yaml b/.changes/unreleased/Features-20230222-130632.yaml
new file mode 100644
index 00000000000..008052b284a
--- /dev/null
+++ b/.changes/unreleased/Features-20230222-130632.yaml
@@ -0,0 +1,6 @@
+kind: Features
+body: get_column_schema_from_query_macro
+time: 2023-02-22T13:06:32.583743-05:00
+custom:
+  Author: jtcohen6 michelleark
+  Issue: "6751"
diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py
index 8234f90910c..97e8ac13369 100644
--- a/core/dbt/adapters/base/impl.py
+++ b/core/dbt/adapters/base/impl.py
@@ -17,6 +17,7 @@
     Iterator,
     Set,
 )
+
 import agate
 import pytz
 
@@ -37,10 +38,7 @@
     UnexpectedNonTimestampError,
 )
 
-from dbt.adapters.protocol import (
-    AdapterConfig,
-    ConnectionManagerProtocol,
-)
+from dbt.adapters.protocol import AdapterConfig, ConnectionManagerProtocol
 from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
 from dbt.clients.jinja import MacroGenerator
 from dbt.contracts.graph.manifest import Manifest, MacroManifest
@@ -176,6 +174,7 @@ class BaseAdapter(metaclass=AdapterMeta):
         - truncate_relation
         - rename_relation
         - get_columns_in_relation
+        - get_column_schema_from_query
         - expand_column_types
         - list_relations_without_caching
         - is_cancelable
@@ -268,6 +267,19 @@ def execute(
         """
         return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch)
 
+    @available.parse(lambda *a, **k: [])
+    def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]:
+        """Get a list of the Columns with names and data types from the given sql."""
+        _, cursor = self.connections.add_select_query(sql)
+        columns = [
+            self.Column.create(
+                column_name, self.connections.data_type_code_to_name(column_type_code)
+            )
+            # https://peps.python.org/pep-0249/#description
+            for column_name, column_type_code, *_ in cursor.description
+        ]
+        return columns
+
     @available.parse(lambda *a, **k: ("", empty_table()))
     def get_partitions_metadata(self, table: str) -> Tuple[agate.Table]:
         """Obtain partitions metadata for a BigQuery partitioned table.
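For context on the new `BaseAdapter.get_column_schema_from_query` method above: it relies on the PEP 249 `cursor.description` attribute, where each entry is a 7-item sequence whose first two items are the column name and type code. A minimal sketch of how those tuples unpack (not dbt code; the fake description tuples and the OID-to-name dict are illustrative assumptions):

```python
# Minimal sketch of consuming a PEP 249 cursor.description, as
# get_column_schema_from_query does above. Only the first two fields of
# each entry (name, type_code) are used; the remaining fields are ignored.
from typing import Any, Callable, List, Sequence, Tuple


def columns_from_description(
    description: Sequence[Sequence[Any]],
    type_code_to_name: Callable[[Any], str],
) -> List[Tuple[str, str]]:
    return [(name, type_code_to_name(type_code)) for name, type_code, *_ in description]


# Hypothetical psycopg2-style description (type codes shown as Postgres OIDs):
description = [
    ("id", 23, None, None, None, None, None),
    ("color", 25, None, None, None, None, None),
]
print(columns_from_description(description, {23: "integer", 25: "text"}.get))
# [('id', 'integer'), ('color', 'text')]
```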
diff --git a/core/dbt/adapters/sql/connections.py b/core/dbt/adapters/sql/connections.py
index e13cf12e319..88e4a30d0b6 100644
--- a/core/dbt/adapters/sql/connections.py
+++ b/core/dbt/adapters/sql/connections.py
@@ -1,6 +1,6 @@
 import abc
 import time
-from typing import List, Optional, Tuple, Any, Iterable, Dict
+from typing import List, Optional, Tuple, Any, Iterable, Dict, Union
 
 import agate
 
@@ -52,6 +52,7 @@ def add_query(
         bindings: Optional[Any] = None,
         abridge_sql_log: bool = False,
     ) -> Tuple[Connection, Any]:
+
         connection = self.get_thread_connection()
         if auto_begin and connection.transaction_open is False:
             self.begin()
@@ -128,6 +129,14 @@ def get_result_from_cursor(cls, cursor: Any) -> agate.Table:
 
         return dbt.clients.agate_helper.table_from_data_flat(data, column_names)
 
+    @classmethod
+    def data_type_code_to_name(cls, type_code: Union[int, str]) -> str:
+        """Get the string representation of the data type from the type_code."""
+        # https://peps.python.org/pep-0249/#type-objects
+        raise dbt.exceptions.NotImplementedError(
+            "`data_type_code_to_name` is not implemented for this adapter!"
+        )
+
     def execute(
         self, sql: str, auto_begin: bool = False, fetch: bool = False
     ) -> Tuple[AdapterResponse, agate.Table]:
@@ -146,6 +155,10 @@ def add_begin_query(self):
     def add_commit_query(self):
         return self.add_query("COMMIT", auto_begin=False)
 
+    def add_select_query(self, sql: str) -> Tuple[Connection, Any]:
+        sql = self._add_query_comment(sql)
+        return self.add_query(sql, auto_begin=False)
+
     def begin(self):
         connection = self.get_thread_connection()
         if connection.transaction_open is True:
diff --git a/core/dbt/docs/build/doctrees/environment.pickle b/core/dbt/docs/build/doctrees/environment.pickle
index e74fd454bfb..89511fea3ba 100644
Binary files a/core/dbt/docs/build/doctrees/environment.pickle and b/core/dbt/docs/build/doctrees/environment.pickle differ
diff --git a/core/dbt/include/global_project/macros/adapters/columns.sql b/core/dbt/include/global_project/macros/adapters/columns.sql
index 7f8302a2bc1..99cbf27529b 100644
--- a/core/dbt/include/global_project/macros/adapters/columns.sql
+++ b/core/dbt/include/global_project/macros/adapters/columns.sql
@@ -17,23 +17,55 @@
 {% endmacro %}
 
+{% macro get_empty_subquery_sql(select_sql) -%}
+  {{ return(adapter.dispatch('get_empty_subquery_sql', 'dbt')(select_sql)) }}
+{% endmacro %}
+
+{#
+  Builds a query that results in the same schema as the given select_sql statement, without necessitating a data scan.
+  Useful for running a query in a 'pre-flight' context, such as model contract enforcement (assert_columns_equivalent macro).
+#}
+{% macro default__get_empty_subquery_sql(select_sql) %}
+    select * from (
+        {{ select_sql }}
+    ) as __dbt_sbq
+    where false
+    limit 0
+{% endmacro %}
+
+
+{% macro get_empty_schema_sql(columns) -%}
+  {{ return(adapter.dispatch('get_empty_schema_sql', 'dbt')(columns)) }}
+{% endmacro %}
+
+{% macro default__get_empty_schema_sql(columns) %}
+    select
+    {% for i in columns %}
+      {%- set col = columns[i] -%}
+      cast(null as {{ col['data_type'] }}) as {{ col['name'] }}{{ ", " if not loop.last }}
+    {%- endfor -%}
+{% endmacro %}
+
+{% macro get_column_schema_from_query(select_sql) -%}
+    {% set columns = [] %}
+    {# -- Using an 'empty subquery' here to get the same schema as the given select_sql statement, without necessitating a data scan.#}
+    {% set sql = get_empty_subquery_sql(select_sql) %}
+    {% set column_schema = adapter.get_column_schema_from_query(sql) %}
+    {{ return(column_schema) }}
+{% endmacro %}
+
+-- here for back compat
 {% macro get_columns_in_query(select_sql) -%}
   {{ return(adapter.dispatch('get_columns_in_query', 'dbt')(select_sql)) }}
 {% endmacro %}
 
 {% macro default__get_columns_in_query(select_sql) %}
     {% call statement('get_columns_in_query', fetch_result=True, auto_begin=False) -%}
-        select * from (
-            {{ select_sql }}
-        ) as __dbt_sbq
-        where false
-        limit 0
+        {{ get_empty_subquery_sql(select_sql) }}
     {% endcall %}
-
     {{ return(load_result('get_columns_in_query').table.columns | map(attribute='name') | list) }}
 {% endmacro %}
 
-
 {% macro alter_column_type(relation, column_name, new_column_type) -%}
   {{ return(adapter.dispatch('alter_column_type', 'dbt')(relation, column_name, new_column_type)) }}
 {% endmacro %}
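To make the intent of `default__get_empty_schema_sql` concrete, here is a rough, self-contained illustration using plain `jinja2` (not dbt's macro machinery; the column dicts are made up) of the "empty schema" SQL it is meant to render from a model's `columns` config:

```python
# Rough illustration of the "empty schema" query shape: one
# cast(null as <data_type>) as <name> per configured column.
from jinja2 import Template

empty_schema_template = Template(
    "select {% for col in columns %}"
    "cast(null as {{ col['data_type'] }}) as {{ col['name'] }}"
    "{{ ', ' if not loop.last }}{% endfor %}"
)

columns = [
    {"name": "id", "data_type": "int"},
    {"name": "color", "data_type": "text"},
]
print(empty_schema_template.render(columns=columns))
# select cast(null as int) as id, cast(null as text) as color
```

Both this "empty schema" query and the `where false limit 0` empty subquery return zero rows, so the adapter can read column names and type codes from the cursor description without scanning any data.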
diff --git a/core/dbt/include/global_project/macros/materializations/models/table/columns_spec_ddl.sql b/core/dbt/include/global_project/macros/materializations/models/table/columns_spec_ddl.sql
index 7eea90a5fd9..7d929f5ff9c 100644
--- a/core/dbt/include/global_project/macros/materializations/models/table/columns_spec_ddl.sql
+++ b/core/dbt/include/global_project/macros/materializations/models/table/columns_spec_ddl.sql
@@ -27,25 +27,34 @@
   {{ return(assert_columns_equivalent(sql)) }}
 {%- endmacro %}
 
+{#
+  Compares the column schema provided by a model's sql file to the column schema provided by a model's schema file.
+  If any differences in name, data_type or order of columns exist between the two schemas, raises a compiler error
+#}
 {% macro assert_columns_equivalent(sql) %}
-  {#- loop through user_provided_columns to get column names -#}
-  {%- set user_provided_columns = model['columns'] -%}
-  {%- set column_names_config_only = [] -%}
-  {%- for i in user_provided_columns -%}
-    {%- set col = user_provided_columns[i] -%}
-    {%- set col_name = col['name'] -%}
-    {%- set column_names_config_only = column_names_config_only.append(col_name) -%}
-  {%- endfor -%}
-  {%- set sql_file_provided_columns = get_columns_in_query(sql) -%}
-
-  {#- uppercase both schema and sql file columns -#}
-  {%- set column_names_config_upper= column_names_config_only|map('upper')|join(',') -%}
-  {%- set column_names_config_formatted = column_names_config_upper.split(',') -%}
-  {%- set sql_file_provided_columns_upper = sql_file_provided_columns|map('upper')|join(',') -%}
-  {%- set sql_file_provided_columns_formatted = sql_file_provided_columns_upper.split(',') -%}
-
-  {%- if column_names_config_formatted != sql_file_provided_columns_formatted -%}
-    {%- do exceptions.raise_compiler_error('Please ensure the name, order, and number of columns in your `yml` file match the columns in your SQL file.\nSchema File Columns: ' ~ column_names_config_formatted ~ '\nSQL File Columns: ' ~ sql_file_provided_columns_formatted ~ ' ' ) %}
-  {%- endif -%}
+  {#-- Obtain the column schema provided by sql file. #}
+  {%- set sql_file_provided_columns = get_column_schema_from_query(sql) -%}
+  {#-- Obtain the column schema provided by the schema file by generating an 'empty schema' query from the model's columns. #}
+  {%- set schema_file_provided_columns = get_column_schema_from_query(get_empty_schema_sql(model['columns'])) -%}
+
+  {%- set sql_file_provided_columns_formatted = format_columns(sql_file_provided_columns) -%}
+  {%- set schema_file_provided_columns_formatted = format_columns(schema_file_provided_columns) -%}
+
+  {%- if sql_file_provided_columns_formatted != schema_file_provided_columns_formatted -%}
+    {%- do exceptions.raise_compiler_error('Please ensure the name, data_type, order, and number of columns in your `yml` file match the columns in your SQL file.\nSchema File Columns: ' ~ (schema_file_provided_columns_formatted|trim) ~ '\n\nSQL File Columns: ' ~ (sql_file_provided_columns_formatted|trim) ~ ' ' ) %}
+  {%- endif -%}
 {% endmacro %}
+
+{% macro format_columns(columns) %}
+  {% set formatted_columns = [] %}
+  {% for column in columns %}
+    {%- set formatted_column = adapter.dispatch('format_column', 'dbt')(column) -%}
+    {%- do formatted_columns.append(formatted_column) -%}
+  {% endfor %}
+  {{ return(formatted_columns|join(', ')) }}
+{%- endmacro -%}
+
+{% macro default__format_column(column) -%}
+  {{ return(column.column.lower() ~ " " ~ column.dtype) }}
+{%- endmacro -%}
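A simplified Python sketch of the comparison `assert_columns_equivalent` now performs (an assumption-level model of the logic, not the Jinja implementation): both schemas are flattened to an ordered `name data_type` string via `format_columns`, so any difference in name, data type, order, or column count surfaces as a plain string mismatch:

```python
# Both column lists are reduced to "name dtype, name dtype, ..." and compared
# for exact equality, mirroring format_columns / default__format_column.
def format_columns(columns):
    return ", ".join(f"{name.lower()} {dtype}" for name, dtype in columns)


sql_file_columns = [("color", "text"), ("id", "int4"), ("date_day", "date")]
schema_file_columns = [("id", "int4"), ("color", "text"), ("date_day", "date")]

if format_columns(sql_file_columns) != format_columns(schema_file_columns):
    print(
        "Please ensure the name, data_type, order, and number of columns in your "
        "`yml` file match the columns in your SQL file.\n"
        f"Schema File Columns: {format_columns(schema_file_columns)}\n\n"
        f"SQL File Columns: {format_columns(sql_file_columns)}"
    )
```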
diff --git a/plugins/postgres/dbt/adapters/postgres/connections.py b/plugins/postgres/dbt/adapters/postgres/connections.py
index afa74a46339..cbbdd33fb38 100644
--- a/plugins/postgres/dbt/adapters/postgres/connections.py
+++ b/plugins/postgres/dbt/adapters/postgres/connections.py
@@ -1,6 +1,7 @@
 from contextlib import contextmanager
 
 import psycopg2
+from psycopg2.extensions import string_types
 
 import dbt.exceptions
 from dbt.adapters.base import Credentials
@@ -190,3 +191,7 @@ def get_response(cls, cursor) -> AdapterResponse:
         status_messsage_strings = [part for part in status_message_parts if not part.isdigit()]
         code = " ".join(status_messsage_strings)
         return AdapterResponse(_message=message, code=code, rows_affected=rows)
+
+    @classmethod
+    def data_type_code_to_name(cls, type_code: int) -> str:
+        return string_types[type_code].name
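The Postgres override above leans on psycopg2's typecaster registry: `psycopg2.extensions.string_types` maps a column's type OID to a type object whose `.name` gives a printable data type. A small exploratory snippet (requires psycopg2; the exact names and OIDs vary by build, so treat the output as illustrative):

```python
# Print a few (oid, type name) pairs from psycopg2's typecaster registry,
# the same lookup that data_type_code_to_name performs above.
from psycopg2.extensions import string_types

for oid in sorted(string_types)[:8]:
    print(oid, string_types[oid].name)
```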
diff --git a/tests/adapter/dbt/tests/adapter/constraints/fixtures.py b/tests/adapter/dbt/tests/adapter/constraints/fixtures.py
index 68744a24ef3..acf2c8b4bcd 100644
--- a/tests/adapter/dbt/tests/adapter/constraints/fixtures.py
+++ b/tests/adapter/dbt/tests/adapter/constraints/fixtures.py
@@ -19,8 +19,8 @@
 }}
 
 select
-  1 as color,
-  'blue' as id,
+  'blue' as color,
+  1 as id,
   cast('2019-01-01' as date) as date_day
 """
 
@@ -37,6 +37,17 @@
   cast('2019-01-01' as date) as date_day
 """
 
+my_model_data_type_sql = """
+{{{{
+  config(
+    materialized = "table"
+  )
+}}}}
+
+select
+  {sql_value} as wrong_data_type_column_name
+"""
+
 my_model_with_nulls_sql = """
 {{
   config(
@@ -117,3 +128,14 @@
     - name: date_day
       data_type: date
 """
+
+model_data_type_schema_yml = """
+version: 2
+models:
+  - name: my_model_data_type
+    config:
+      contract: true
+    columns:
+      - name: wrong_data_type_column_name
+        data_type: {data_type}
+"""
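One detail of the new fixtures that is easy to misread: `my_model_data_type_sql` uses quadruple braces so that `str.format()` (used in the tests below to inject each parametrized `sql_value`) leaves valid Jinja behind. A standalone sketch of that escaping:

```python
# str.format() collapses '{{{{' to '{{' and '}}}}' to '}}', so the rendered
# model file still contains the Jinja config block dbt expects, while
# {sql_value} is replaced with the parametrized SQL literal.
template = """
{{{{
  config(
    materialized = "table"
  )
}}}}

select
  {sql_value} as wrong_data_type_column_name
"""

print(template.format(sql_value="cast('2019-01-01' as date)"))
# {{
#   config(
#     materialized = "table"
#   )
# }}
#
# select
#   cast('2019-01-01' as date) as wrong_data_type_column_name
```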
diff --git a/tests/adapter/dbt/tests/adapter/constraints/test_constraints.py b/tests/adapter/dbt/tests/adapter/constraints/test_constraints.py
index bfb856de156..5f08fdc845d 100644
--- a/tests/adapter/dbt/tests/adapter/constraints/test_constraints.py
+++ b/tests/adapter/dbt/tests/adapter/constraints/test_constraints.py
@@ -14,6 +14,8 @@
     my_model_sql,
     my_model_wrong_order_sql,
     my_model_wrong_name_sql,
+    my_model_data_type_sql,
+    model_data_type_schema_yml,
     my_model_with_nulls_sql,
     model_schema_yml,
 )
@@ -32,7 +34,35 @@ def models(self):
             "constraints_schema.yml": model_schema_yml,
         }
 
-    def test__constraints_wrong_column_order(self, project):
+    @pytest.fixture
+    def string_type(self):
+        return "TEXT"
+
+    @pytest.fixture
+    def int_type(self):
+        return "INT"
+
+    @pytest.fixture
+    def schema_int_type(self, int_type):
+        return int_type
+
+    @pytest.fixture
+    def data_types(self, schema_int_type, int_type, string_type):
+        # sql_column_value, schema_data_type, error_data_type
+        return [
+            ["1", schema_int_type, int_type],
+            ["'1'", string_type, string_type],
+            ["cast('2019-01-01' as date)", "date", "DATE"],
+            ["true", "bool", "BOOL"],
+            ["'2013-11-03 00:00:00-07'::timestamptz", "timestamptz", "DATETIMETZ"],
+            ["'2013-11-03 00:00:00-07'::timestamp", "timestamp", "DATETIME"],
+            ["ARRAY['a','b','c']", "text[]", "STRINGARRAY"],
+            ["ARRAY[1,2,3]", "int[]", "INTEGERARRAY"],
+            ["'1'::numeric", "numeric", "DECIMAL"],
+            ["""'{"bar": "baz", "balance": 7.77, "active": false}'::json""", "json", "JSON"],
+        ]
+
+    def test__constraints_wrong_column_order(self, project, string_type, int_type):
         results, log_output = run_dbt_and_capture(
             ["run", "-s", "my_model_wrong_order"], expect_pass=False
         )
@@ -43,15 +73,20 @@ def test__constraints_wrong_column_order(self, project):
 
         assert contract_actual_config is True
 
-        expected_compile_error = "Please ensure the name, order, and number of columns in your `yml` file match the columns in your SQL file."
-        expected_schema_file_columns = "Schema File Columns: ['ID', 'COLOR', 'DATE_DAY']"
-        expected_sql_file_columns = "SQL File Columns: ['COLOR', 'ID', 'DATE_DAY']"
+        expected_compile_error = "Please ensure the name, data_type, order, and number of columns in your `yml` file match the columns in your SQL file."
+
+        expected_schema_file_columns = (
+            f"Schema File Columns: id {int_type}, color {string_type}, date_day DATE"
+        )
+        expected_sql_file_columns = (
+            f"SQL File Columns: color {string_type}, id {int_type}, date_day DATE"
+        )
 
         assert expected_compile_error in log_output
         assert expected_schema_file_columns in log_output
         assert expected_sql_file_columns in log_output
 
-    def test__constraints_wrong_column_names(self, project):
+    def test__constraints_wrong_column_names(self, project, string_type, int_type):
         results, log_output = run_dbt_and_capture(
             ["run", "-s", "my_model_wrong_name"], expect_pass=False
         )
@@ -62,14 +97,91 @@ def test__constraints_wrong_column_names(self, project):
 
         assert contract_actual_config is True
 
-        expected_compile_error = "Please ensure the name, order, and number of columns in your `yml` file match the columns in your SQL file."
-        expected_schema_file_columns = "Schema File Columns: ['ID', 'COLOR', 'DATE_DAY']"
-        expected_sql_file_columns = "SQL File Columns: ['ERROR', 'COLOR', 'DATE_DAY']"
+        expected_compile_error = "Please ensure the name, data_type, order, and number of columns in your `yml` file match the columns in your SQL file."
+        expected_schema_file_columns = (
+            f"Schema File Columns: id {int_type}, color {string_type}, date_day DATE"
+        )
+        expected_sql_file_columns = (
+            f"SQL File Columns: error {int_type}, color {string_type}, date_day DATE"
+        )
 
         assert expected_compile_error in log_output
         assert expected_schema_file_columns in log_output
         assert expected_sql_file_columns in log_output
 
+    def test__constraints_wrong_column_data_types(
+        self, project, string_type, int_type, schema_int_type, data_types
+    ):
+        for (sql_column_value, schema_data_type, error_data_type) in data_types:
+            # Write parametrized data_type to sql file
+            write_file(
+                my_model_data_type_sql.format(sql_value=sql_column_value),
+                "models",
+                "my_model_data_type.sql",
+            )
+
+            # Write wrong data_type to corresponding schema file
+            # Write integer type for all schema yaml values except when testing integer type itself
+            wrong_schema_data_type = (
+                schema_int_type
+                if schema_data_type.upper() != schema_int_type.upper()
+                else string_type
+            )
+            wrong_schema_error_data_type = (
+                int_type if schema_data_type.upper() != schema_int_type.upper() else string_type
+            )
+            write_file(
+                model_data_type_schema_yml.format(data_type=wrong_schema_data_type),
+                "models",
+                "constraints_schema.yml",
+            )
+
+            results, log_output = run_dbt_and_capture(
+                ["run", "-s", "my_model_data_type"], expect_pass=False
+            )
+            manifest = get_manifest(project.project_root)
+            model_id = "model.test.my_model_data_type"
+            my_model_config = manifest.nodes[model_id].config
+            contract_actual_config = my_model_config.contract
+
+            assert contract_actual_config is True
+
+            expected_compile_error = "Please ensure the name, data_type, order, and number of columns in your `yml` file match the columns in your SQL file."
+            expected_sql_file_columns = (
+                f"SQL File Columns: wrong_data_type_column_name {error_data_type}"
+            )
+            expected_schema_file_columns = (
+                f"Schema File Columns: wrong_data_type_column_name {wrong_schema_error_data_type}"
+            )
+
+            assert expected_compile_error in log_output
+            assert expected_schema_file_columns in log_output
+            assert expected_sql_file_columns in log_output
+
+    def test__constraints_correct_column_data_types(self, project, data_types):
+        for (sql_column_value, schema_data_type, _) in data_types:
+            # Write parametrized data_type to sql file
+            write_file(
+                my_model_data_type_sql.format(sql_value=sql_column_value),
+                "models",
+                "my_model_data_type.sql",
+            )
+            # Write correct data_type to corresponding schema file
+            write_file(
+                model_data_type_schema_yml.format(data_type=schema_data_type),
+                "models",
+                "constraints_schema.yml",
+            )
+
+            run_dbt(["run", "-s", "my_model_data_type"])
+
+            manifest = get_manifest(project.project_root)
+            model_id = "model.test.my_model_data_type"
+            my_model_config = manifest.nodes[model_id].config
+            contract_actual_config = my_model_config.contract
+
+            assert contract_actual_config is True
+
 
 # This is SUPER specific to Postgres, and will need replacing on other adapters
 # TODO: make more generic
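The conditional that picks the deliberately wrong schema type in `test__constraints_wrong_column_data_types` above can be hard to read inline. A pure-Python sketch of the same selection logic (the fixture values are copied from the test; this is a restatement, not additional behavior):

```python
# For each (sql_value, schema_data_type, error_data_type) triple, the schema
# file declares INT unless the case under test is itself the int case, in
# which case TEXT is used instead, so every row is guaranteed a mismatch.
schema_int_type, int_type, string_type = "INT", "INT", "TEXT"

data_types = [
    ["1", schema_int_type, int_type],
    ["'1'", string_type, string_type],
    ["cast('2019-01-01' as date)", "date", "DATE"],
]

for sql_value, schema_data_type, error_data_type in data_types:
    wrong_schema_data_type = (
        schema_int_type if schema_data_type.upper() != schema_int_type.upper() else string_type
    )
    print(f"{sql_value!r}: declare {wrong_schema_data_type}, expect mismatch vs {error_data_type}")
```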