
get_column_schema_from_query macro #6986

Merged · Mar 3, 2023 · 19 commits
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230222-130632.yaml
@@ -0,0 +1,6 @@
kind: Features
body: get_column_schema_from_query_macro
time: 2023-02-22T13:06:32.583743-05:00
custom:
Author: jtcohen6 michelleark
Issue: "6751"
11 changes: 11 additions & 0 deletions core/dbt/adapters/base/impl.py
@@ -176,6 +176,7 @@ class BaseAdapter(metaclass=AdapterMeta):
- truncate_relation
- rename_relation
- get_columns_in_relation
- get_column_schema_from_query
- expand_column_types
- list_relations_without_caching
- is_cancelable
@@ -268,6 +269,16 @@ def execute(
"""
return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch)

@available.parse(lambda *a, **k: [])
def get_column_schema_from_query(self, sql: str) -> List[Tuple[str, Any]]:
"""Get a list of the column names and data types from the given sql.

:param str sql: The sql to execute.
:return: A list of (column_name, data_type) tuples, each of which can be used to construct a Column object.
:rtype: List[Tuple[str, Any]]
"""
return self.connections.get_column_schema_from_query(sql=sql)
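
A minimal sketch of the return shape (the FakeConnections class and the literal type names below are illustrative assumptions, not dbt code):

```python
from typing import Any, List, Tuple

class FakeConnections:
    # Stand-in for a ConnectionManager; the adapter method above is a thin
    # pass-through to its connections object.
    def get_column_schema_from_query(self, sql: str) -> List[Tuple[str, Any]]:
        # Pretend the warehouse described two columns for this query.
        return [("id", "INTEGER"), ("color", "TEXT")]

schema = FakeConnections().get_column_schema_from_query("select 1 as id, 'blue' as color")
assert schema == [("id", "INTEGER"), ("color", "TEXT")]
```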

@available.parse(lambda *a, **k: ("", empty_table()))
def get_partitions_metadata(self, table: str) -> Tuple[agate.Table]:
"""Obtain partitions metadata for a BigQuery partitioned table.
32 changes: 31 additions & 1 deletion core/dbt/adapters/sql/connections.py
@@ -1,6 +1,6 @@
import abc
import time
from typing import List, Optional, Tuple, Any, Iterable, Dict
from typing import List, Optional, Tuple, Any, Iterable, Dict, Union

import agate

@@ -128,6 +128,31 @@ def get_result_from_cursor(cls, cursor: Any) -> agate.Table:

return dbt.clients.agate_helper.table_from_data_flat(data, column_names)

@classmethod
def data_type_code_to_name(cls, type_code: Union[int, str]) -> str:
"""Get the string representation of the data type from the type_code."""
# https://peps.python.org/pep-0249/#type-objects
raise dbt.exceptions.NotImplementedError(
"`data_type_code_to_name` is not implemented for this adapter!"
)

Comment on lines +132 to +139 (Contributor):
I was interested in seeing a tangible example of how data type codes map to a string representation for a database connector.

This table from the Snowflake docs was useful to me:
https://docs.snowflake.com/en/user-guide/python-connector-api#label-python-connector-type-codes

| type_code | String Representation | Data Type |
|-----------|-----------------------|-----------|
| 0 | FIXED | NUMBER/INT |
| 1 | REAL | REAL |
| 2 | TEXT | VARCHAR/STRING |
| 3 | DATE | DATE |
| 4 | TIMESTAMP | TIMESTAMP |
| 5 | VARIANT | VARIANT |
| 6 | TIMESTAMP_LTZ | TIMESTAMP_LTZ |
| 7 | TIMESTAMP_TZ | TIMESTAMP_TZ |
| 8 | TIMESTAMP_NTZ | TIMESTAMP_TZ |
| 9 | OBJECT | OBJECT |
| 10 | ARRAY | ARRAY |
| 11 | BINARY | BINARY |
| 12 | TIME | TIME |
| 13 | BOOLEAN | BOOLEAN |

(Side note: I suspect there is a typo for code 8 and TIMESTAMP_TZ there should be TIMESTAMP_NTZ instead.)
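
For illustration, a sketch of how an adapter could implement `data_type_code_to_name` over a fixed table like the one above. The class name and mapping literal are assumptions built from the quoted docs, not dbt-snowflake source:

```python
from typing import Union

# String representations per the Snowflake connector table quoted above.
# Illustrative assumption, not dbt-snowflake source code.
_SNOWFLAKE_TYPE_NAMES = {
    0: "FIXED", 1: "REAL", 2: "TEXT", 3: "DATE", 4: "TIMESTAMP",
    5: "VARIANT", 6: "TIMESTAMP_LTZ", 7: "TIMESTAMP_TZ", 8: "TIMESTAMP_NTZ",
    9: "OBJECT", 10: "ARRAY", 11: "BINARY", 12: "TIME", 13: "BOOLEAN",
}

class SketchedSnowflakeConnectionManager:
    @classmethod
    def data_type_code_to_name(cls, type_code: Union[int, str]) -> str:
        # Snowflake cursors report integer type codes; unknown codes raise KeyError.
        return _SNOWFLAKE_TYPE_NAMES[int(type_code)]

assert SketchedSnowflakeConnectionManager.data_type_code_to_name(2) == "TEXT"
```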

@classmethod
def get_column_schema_from_cursor(cls, cursor: Any) -> List[Tuple[str, str]]:
# (column_name, data_type)
columns: List[Tuple[str, str]] = []

if cursor.description is not None:
# https://peps.python.org/pep-0249/#description
columns = [
# TODO: ignoring size, precision, scale for now
# (though it is part of DB-API standard, and our Column class does have these attributes)
# IMO user-defined contracts shouldn't have to match an exact size/precision/scale
(col[0], cls.data_type_code_to_name(col[1]))
for col in cursor.description
]

return columns
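
This method only needs cursor.description, which PEP 249 drivers populate even for a zero-row result. A self-contained sketch using sqlite3 (chosen only because it ships with Python; its driver leaves type_code as None, so a real adapter would map col[1] through data_type_code_to_name):

```python
import sqlite3

# The empty-subquery trick: match zero rows, still learn the column schema.
cursor = sqlite3.connect(":memory:").cursor()
cursor.execute(
    "select * from (select 1 as id, 'blue' as color) as __dbt_sbq where 0 limit 0"
)

# PEP 249: cursor.description is a sequence of 7-item tuples,
# with name at index 0 and type_code at index 1.
columns = [(col[0], col[1]) for col in cursor.description]
print(columns)  # [('id', None), ('color', None)] -- sqlite3 reports no type codes
```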

def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
) -> Tuple[AdapterResponse, agate.Table]:
@@ -140,6 +165,11 @@ def execute(
table = dbt.clients.agate_helper.empty_table()
return response, table

def get_column_schema_from_query(self, sql: str) -> List[Tuple[str, Any]]:
sql = self._add_query_comment(sql)
_, cursor = self.add_query(sql)
return self.get_column_schema_from_cursor(cursor)

def add_begin_query(self):
return self.add_query("BEGIN", auto_begin=False)

51 changes: 44 additions & 7 deletions core/dbt/include/global_project/macros/adapters/columns.sql
@@ -17,23 +17,60 @@
{% endmacro %}


{% macro get_empty_subquery_sql(select_sql) -%}
{{ return(adapter.dispatch('get_empty_subquery_sql', 'dbt')(select_sql)) }}
{% endmacro %}

{% macro default__get_empty_subquery_sql(select_sql) %}
select * from (
{{ select_sql }}
) as __dbt_sbq
where false
limit 0
{% endmacro %}


{% macro get_empty_schema_sql(columns) -%}
{{ return(adapter.dispatch('get_empty_schema_sql', 'dbt')(columns)) }}
{% endmacro %}

{% macro default__get_empty_schema_sql(columns) %}
select
{% for i in columns %}
{%- set col = columns[i] -%}
cast(null as {{ col['data_type'] }}) as {{ col['name'] }}{{ ", " if not loop.last }}

Comment (VersusFacit, Contributor, Feb 28, 2023):

We talked about this. This can possibly lead to some weird type resolution outcomes... we think. This is so far the "best" option and looks promising. I just know SQL's typing mechanisms can get a mind of their own for the worse.

Comment (Contributor):

@VersusFacit I had the same concern, and discussed it with Michelle synchronously. On the plus side, this approach will automatically account for any new types which appear, so if it works in practice it will be a lot easier than trying to maintain our own list of type aliases. I like that. I was also reassured that this code path will only affect people using contracts, so there isn't much regression risk and we should hear pretty quickly if there are databases/drivers that this approach doesn't work for.

Comment (MichelleArk, Contributor and Author, Feb 28, 2023):

One more thing this approach has going for it is that there's a non-opaque definition of what data_type values should be - it's 'the value you'd write in SQL when casting to the desired type', as opposed to 'the value returned by mapping the connection cursor's type code to a string'.

Comment (Contributor):

This discussion should be seen as a context dump for posterity and is not blocking.

{%- endfor -%}
{% endmacro %}
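
To make the generated probe query concrete, a plain-Python mirror of the Jinja loop above (column names and types here are made-up examples):

```python
# Illustrative only: reproduces default__get_empty_schema_sql in Python.
columns = {
    "id": {"name": "id", "data_type": "integer"},
    "color": {"name": "color", "data_type": "text"},
}

select_list = ", ".join(
    f"cast(null as {col['data_type']}) as {col['name']}" for col in columns.values()
)
print(f"select {select_list}")
# select cast(null as integer) as id, cast(null as text) as color
```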

{% macro get_column_schema_from_query(select_sql) -%}
{{ return(adapter.dispatch('get_column_schema_from_query', 'dbt')(select_sql)) }}
{% endmacro %}

{% macro default__get_column_schema_from_query(select_sql) %}
{% set columns = [] %}
{% set sql = get_empty_subquery_sql(select_sql) %}
{% set column_schema = adapter.get_column_schema_from_query(sql) %}
{% for col in column_schema %}
-- api.Column.create includes a step for translating data type
-- TODO: could include size, precision, scale here
{% set column = api.Column.create(col[0], col[1]) %}
{% do columns.append(column) %}
{% endfor %}
{{ return(columns) }}
{% endmacro %}
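
Roughly, the macro hands back a list of Column objects built from the (name, data_type) pairs. A sketch with a stand-in dataclass (SimpleColumn is hypothetical; dbt's api.Column.create additionally translates driver-specific type names):

```python
from dataclasses import dataclass

@dataclass
class SimpleColumn:
    # Hypothetical stand-in for dbt's Column class.
    column: str  # column name
    dtype: str   # data type as reported by the adapter

schema = [("id", "INTEGER"), ("color", "TEXT")]  # as from get_column_schema_from_query
columns = [SimpleColumn(column=name, dtype=dtype) for name, dtype in schema]
print(columns[0].column, columns[0].dtype)  # id INTEGER
```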

-- here for back compat
{% macro get_columns_in_query(select_sql) -%}
{{ return(adapter.dispatch('get_columns_in_query', 'dbt')(select_sql)) }}
{% endmacro %}

{% macro default__get_columns_in_query(select_sql) %}
{% call statement('get_columns_in_query', fetch_result=True, auto_begin=False) -%}
select * from (
{{ select_sql }}
) as __dbt_sbq
where false
limit 0
{{ get_empty_subquery_sql(select_sql) }}
{% endcall %}

{{ return(load_result('get_columns_in_query').table.columns | map(attribute='name') | list) }}
{% endmacro %}


{% macro alter_column_type(relation, column_name, new_column_type) -%}
{{ return(adapter.dispatch('alter_column_type', 'dbt')(relation, column_name, new_column_type)) }}
{% endmacro %}
@@ -28,24 +28,27 @@
{%- endmacro %}

{% macro assert_columns_equivalent(sql) %}
{#- loop through user_provided_columns to get column names -#}
{%- set user_provided_columns = model['columns'] -%}
{%- set column_names_config_only = [] -%}
{%- for i in user_provided_columns -%}
{%- set col = user_provided_columns[i] -%}
{%- set col_name = col['name'] -%}
{%- set column_names_config_only = column_names_config_only.append(col_name) -%}
{%- endfor -%}
{%- set sql_file_provided_columns = get_columns_in_query(sql) -%}

{#- uppercase both schema and sql file columns -#}
{%- set column_names_config_upper= column_names_config_only|map('upper')|join(',') -%}
{%- set column_names_config_formatted = column_names_config_upper.split(',') -%}
{%- set sql_file_provided_columns_upper = sql_file_provided_columns|map('upper')|join(',') -%}
{%- set sql_file_provided_columns_formatted = sql_file_provided_columns_upper.split(',') -%}

{%- if column_names_config_formatted != sql_file_provided_columns_formatted -%}
{%- do exceptions.raise_compiler_error('Please ensure the name, order, and number of columns in your `yml` file match the columns in your SQL file.\nSchema File Columns: ' ~ column_names_config_formatted ~ '\nSQL File Columns: ' ~ sql_file_provided_columns_formatted ~ ' ' ) %}
{%- endif -%}
{%- set sql_file_provided_columns = get_column_schema_from_query(sql) -%}
{%- set schema_file_provided_columns = get_column_schema_from_query(get_empty_schema_sql(model['columns'])) -%}

{%- set sql_file_provided_columns_formatted = format_columns(sql_file_provided_columns) -%}
{%- set schema_file_provided_columns_formatted = format_columns(schema_file_provided_columns) -%}

{%- if sql_file_provided_columns_formatted != schema_file_provided_columns_formatted -%}
{%- do exceptions.raise_compiler_error('Please ensure the name, data_type, order, and number of columns in your `yml` file match the columns in your SQL file.\nSchema File Columns: ' ~ (schema_file_provided_columns_formatted|trim) ~ '\n\nSQL File Columns: ' ~ (sql_file_provided_columns_formatted|trim) ~ ' ' ) %}
{%- endif -%}

{% endmacro %}

{% macro format_columns(columns) %}
{% set formatted_columns = [] %}
{% for column in columns %}
{%- set formatted_column = adapter.dispatch('format_column', 'dbt')(column) -%}
{%- do formatted_columns.append(formatted_column) -%}
{% endfor %}
{{ return(formatted_columns|join(', ')) }}
{%- endmacro -%}

{% macro default__format_column(column) -%}
{{ return(column.column.lower() ~ " " ~ column.dtype) }}
{%- endmacro -%}
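
A quick Python mirror of what the equivalence check ends up comparing (hypothetical columns; the real macros operate on Column objects):

```python
# Mirrors format_columns / default__format_column in plain Python.
def format_column(name: str, dtype: str) -> str:
    return f"{name.lower()} {dtype}"

def format_columns(cols):
    return ", ".join(format_column(name, dtype) for name, dtype in cols)

schema_file = [("id", "INTEGER"), ("color", "TEXT"), ("date_day", "DATE")]
sql_file = [("color", "TEXT"), ("id", "INTEGER"), ("date_day", "DATE")]

# Same columns in a different order produce different strings, so the
# contract check raises a compiler error.
assert format_columns(schema_file) != format_columns(sql_file)
print(format_columns(schema_file))  # id INTEGER, color TEXT, date_day DATE
```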
5 changes: 5 additions & 0 deletions plugins/postgres/dbt/adapters/postgres/connections.py
@@ -1,6 +1,7 @@
from contextlib import contextmanager

import psycopg2
from psycopg2.extensions import string_types

import dbt.exceptions
from dbt.adapters.base import Credentials
@@ -190,3 +191,7 @@ def get_response(cls, cursor) -> AdapterResponse:
status_message_strings = [part for part in status_message_parts if not part.isdigit()]
code = " ".join(status_message_strings)
return AdapterResponse(_message=message, code=code, rows_affected=rows)

@classmethod
def data_type_code_to_name(cls, type_code: int) -> str:
return string_types[type_code].name
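
For a feel of what this lookup does (a sketch; the OID and the printed name are assumptions that depend on the installed psycopg2 build):

```python
from psycopg2.extensions import string_types

# string_types maps Postgres type OIDs to psycopg2 typecaster objects,
# each carrying a human-readable .name.
# OID 23 is int4; a typical build prints something like "INTEGER".
print(string_types[23].name)
```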
35 changes: 33 additions & 2 deletions tests/adapter/dbt/tests/adapter/constraints/fixtures.py
@@ -19,8 +19,8 @@
}}

select
1 as color,
'blue' as id,
'blue' as color,
1 as id,
cast('2019-01-01' as date) as date_day
"""

@@ -37,6 +37,20 @@
cast('2019-01-01' as date) as date_day
"""

my_model_wrong_data_type_sql = """
{{
config(
materialized = "table"
)
}}

select
'1' as id,
'blue' as color,
cast('2019-01-01' as date) as date_day,
ARRAY['a', 'b', 'c'] as num_array
"""

my_model_with_nulls_sql = """
{{
config(
@@ -116,4 +130,21 @@
data_type: text
- name: date_day
data_type: date
- name: my_model_wrong_data_type
config:
contract: true
columns:
- name: id
data_type: integer
description: hello
constraints: ['not null','primary key']
constraints_check: (id > 0)
tests:
- unique
- name: color
data_type: text
- name: date_day
data_type: date
- name: num_array
data_type: int[]
"""
64 changes: 56 additions & 8 deletions tests/adapter/dbt/tests/adapter/constraints/test_constraints.py
@@ -14,6 +14,7 @@
my_model_sql,
my_model_wrong_order_sql,
my_model_wrong_name_sql,
my_model_wrong_data_type_sql,
my_model_with_nulls_sql,
model_schema_yml,
)
@@ -29,10 +30,27 @@ def models(self):
return {
"my_model_wrong_order.sql": my_model_wrong_order_sql,
"my_model_wrong_name.sql": my_model_wrong_name_sql,
"my_model_wrong_data_type.sql": my_model_wrong_data_type_sql,
"constraints_schema.yml": model_schema_yml,
}

def test__constraints_wrong_column_order(self, project):
@pytest.fixture
def string_type(self):
return "TEXT"

@pytest.fixture
def int_type(self):
return "INT"

@pytest.fixture
def int_array_type(self):
return "INTEGERARRAY"

@pytest.fixture
def string_array_type(self):
return "STRINGARRAY"

def test__constraints_wrong_column_order(self, project, string_type, int_type):
results, log_output = run_dbt_and_capture(
["run", "-s", "my_model_wrong_order"], expect_pass=False
)
@@ -43,15 +61,20 @@

assert contract_actual_config is True

expected_compile_error = "Please ensure the name, order, and number of columns in your `yml` file match the columns in your SQL file."
expected_schema_file_columns = "Schema File Columns: ['ID', 'COLOR', 'DATE_DAY']"
expected_sql_file_columns = "SQL File Columns: ['COLOR', 'ID', 'DATE_DAY']"
expected_compile_error = "Please ensure the name, data_type, order, and number of columns in your `yml` file match the columns in your SQL file."

expected_schema_file_columns = (
f"Schema File Columns: id {int_type}, color {string_type}, date_day DATE"
)
expected_sql_file_columns = (
f"SQL File Columns: color {string_type}, id {int_type}, date_day DATE"
)

assert expected_compile_error in log_output
assert expected_schema_file_columns in log_output
assert expected_sql_file_columns in log_output

def test__constraints_wrong_column_names(self, project):
def test__constraints_wrong_column_names(self, project, string_type, int_type):
results, log_output = run_dbt_and_capture(
["run", "-s", "my_model_wrong_name"], expect_pass=False
)
@@ -62,9 +85,34 @@ def test__constraints_wrong_column_names(self, project):

assert contract_actual_config is True

expected_compile_error = "Please ensure the name, order, and number of columns in your `yml` file match the columns in your SQL file."
expected_schema_file_columns = "Schema File Columns: ['ID', 'COLOR', 'DATE_DAY']"
expected_sql_file_columns = "SQL File Columns: ['ERROR', 'COLOR', 'DATE_DAY']"
expected_compile_error = "Please ensure the name, data_type, order, and number of columns in your `yml` file match the columns in your SQL file."
expected_schema_file_columns = (
f"Schema File Columns: id {int_type}, color {string_type}, date_day DATE"
)
expected_sql_file_columns = (
f"SQL File Columns: error {int_type}, color {string_type}, date_day DATE"
)

assert expected_compile_error in log_output
assert expected_schema_file_columns in log_output
assert expected_sql_file_columns in log_output

def test__constraints_wrong_column_data_types(
self, project, string_type, int_type, int_array_type, string_array_type
):
results, log_output = run_dbt_and_capture(
["run", "-s", "my_model_wrong_data_type"], expect_pass=False
)
manifest = get_manifest(project.project_root)
model_id = "model.test.my_model_wrong_data_type"
my_model_config = manifest.nodes[model_id].config
contract_actual_config = my_model_config.contract

assert contract_actual_config is True

expected_compile_error = "Please ensure the name, data_type, order, and number of columns in your `yml` file match the columns in your SQL file."
expected_schema_file_columns = f"Schema File Columns: id {int_type}, color {string_type}, date_day DATE, num_array {int_array_type}"
expected_sql_file_columns = f"SQL File Columns: id {string_type}, color {string_type}, date_day DATE, num_array {string_array_type}"

assert expected_compile_error in log_output
assert expected_schema_file_columns in log_output
assert expected_sql_file_columns in log_output