Commit

…o 107-generic-function-operation
Abhishek-N committed Apr 11, 2024
2 parents 1a16a20 + ef626df commit ff57b62
Showing 9 changed files with 544 additions and 12 deletions.
74 changes: 73 additions & 1 deletion dbt_automation/assets/operations.template.yml
@@ -383,6 +383,50 @@ operations:
        - ...
      dest_schema: <destination schema>
      output_name: <name of the output model>

  - type: pivot
    config:
      input:
        input_type: <"source" or "model">
        input_name: <name of the source table or ref model>
        source_name: <name of the source defined in source.yml; null for input_type "model">
      source_columns:
        - <column name>
        - <column name>
        - <column name>
      pivot_column_name: <column name>
      pivot_column_values:
        - <pivot col value1>
        - <pivot col value2>
        - <pivot col value3>
      dest_schema: <destination schema>
      output_name: <name of the output model>
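For concreteness, here is a hypothetical pivot config in the Python dict form the operation functions receive once the YAML is parsed; every name (store_staging, sales, status, and so on) is invented for illustration:

# Hypothetical pivot config; all table and column names are invented.
pivot_config = {
    "input": {
        "input_type": "source",          # read from a source table
        "input_name": "sales",
        "source_name": "store_staging",  # would be None for input_type "model"
    },
    "source_columns": ["store_id", "month"],  # passed through and grouped by
    "pivot_column_name": "status",            # values of this column become new columns
    "pivot_column_values": ["paid", "pending", "refunded"],
    "dest_schema": "intermediate",
    "output_name": "sales_by_status",
}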

  - type: unpivot
    config:
      input:
        input_type: <"source" or "model">
        input_name: <name of the source table or ref model>
        source_name: <name of the source defined in source.yml; null for input_type "model">
      source_columns:
        - <column name>
        - <column name>
        - <column name>
      exclude_columns:
        - <column name>
        - <column name>
        - <column name>
      unpivot_columns:
        - <column name>
        - <column name>
        - <column name>
      unpivot_field_name: <by default - "field_name">
      unpivot_value_name: <by default - "value">
      cast_to: <data type to cast values to - "varchar" for postgres & "STRING" for bigquery>
      dest_schema: <destination schema>
      output_name: <name of the output model>
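And a matching hypothetical unpivot config, again shown as the parsed Python dict with invented names:

# Hypothetical unpivot config; all names are invented.
unpivot_config = {
    "input": {
        "input_type": "model",  # read from an upstream dbt model
        "input_name": "monthly_kpis",
        "source_name": None,    # null for input_type "model"
    },
    "source_columns": ["org_id", "jan", "feb", "mar"],
    "exclude_columns": ["org_id"],             # kept on every output row, not melted
    "unpivot_columns": ["jan", "feb", "mar"],  # melted into (field_name, value) rows
    "unpivot_field_name": "month",
    "unpivot_value_name": "value",
    "cast_to": "varchar",  # "STRING" on BigQuery
    "dest_schema": "intermediate",
    "output_name": "kpis_long",
}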


  - type: mergeoperations
    config:
      dest_schema: <destination_schema>
@@ -644,4 +688,32 @@ operations:
              - <column name>
              - <column name>
              - <column name>
              - ...
        - ...
        - type: pivot
          config:
            source_columns:
              - <column name>
              - <column name>
              - <column name>
            pivot_column_name: <column name>
            pivot_column_values:
              - <pivot col value1>
              - <pivot col value2>
              - <pivot col value3>
        - type: unpivot
          config:
            source_columns:
              - <column name>
              - <column name>
              - <column name>
            exclude_columns:
              - <column name>
              - <column name>
              - <column name>
            unpivot_columns:
              - <column name>
              - <column name>
              - <column name>
            cast_to: <data type to cast values to - "varchar" for postgres & "STRING" for bigquery>
            unpivot_field_name: <by default - "field_name">
            unpivot_value_name: <by default - "value">
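Note that under mergeoperations each step omits input, dest_schema and output_name; judging by the "cte" input_type handled in pivot.py below, the steps appear to be chained as CTEs. A rough sketch of the parsed form, where the top-level input and operations keys are assumptions based on this template and all values are invented:

# Sketch of a parsed mergeoperations config; the top-level "input" and
# "operations" keys are assumed from the template, values invented.
merge_config = {
    "dest_schema": "intermediate",
    "output_name": "pivoted_report",
    "input": {"input_type": "source", "input_name": "raw_kpis", "source_name": "staging"},
    "operations": [
        {
            "type": "unpivot",
            "config": {
                "source_columns": ["org_id", "jan", "feb"],
                "exclude_columns": ["org_id"],
                "unpivot_columns": ["jan", "feb"],
                "unpivot_field_name": "month",
                "unpivot_value_name": "value",
                "cast_to": "varchar",
            },
        },
        {
            "type": "pivot",
            "config": {
                "source_columns": ["org_id"],
                "pivot_column_name": "month",
                "pivot_column_values": ["jan", "feb"],
            },
        },
    ],
}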
67 changes: 67 additions & 0 deletions dbt_automation/assets/unpivot.sql
@@ -0,0 +1,67 @@
{#
Pivot values from columns to rows, similar to the pandas DataFrame melt() function.

Example Usage: {{ unpivot(relation=ref('users'), cast_to='integer', exclude=['id','created_at']) }}

Arguments:
relation: Relation object, required.
cast_to: The datatype to cast all unpivoted columns to. Default is varchar.
exclude: A list of columns to keep but exclude from the unpivot operation. Default is none.
remove: A list of columns to remove from the resulting table. Default is none.
field_name: Destination table column name for the source table column names.
value_name: Destination table column name for the unpivoted values.
#}

{% macro unpivot(relation=none, cast_to='varchar', exclude=none, remove=none, field_name='field_name', value_name='value', quote_identifiers=True) -%}
{{ return(adapter.dispatch('unpivot', 'dbt_utils')(relation, cast_to, exclude, remove, field_name, value_name, quote_identifiers)) }}
{% endmacro %}

{% macro default__unpivot(relation=none, cast_to='varchar', exclude=none, remove=none, field_name='field_name', value_name='value', quote_identifiers=True) -%}

{% if not relation %}
{{ exceptions.raise_compiler_error("Error: argument `relation` is required for `unpivot` macro.") }}
{% endif %}

{%- set exclude = exclude if exclude is not none else [] %}
{%- set remove = remove if remove is not none else [] %}

{%- set include_cols = [] %}

{%- set table_columns = {} %}

{%- do table_columns.update({relation: []}) %}

{%- do dbt_utils._is_relation(relation, 'unpivot') -%}
{%- do dbt_utils._is_ephemeral(relation, 'unpivot') -%}
{%- set cols = adapter.get_columns_in_relation(relation) %}

{%- for col in cols -%}
{%- if col.column.lower() not in remove|map('lower') and col.column.lower() not in exclude|map('lower') -%}
{% do include_cols.append(col) %}
{%- endif %}
{%- endfor %}


{%- for col in include_cols -%}
{%- set current_col_name = adapter.quote(col.column) if quote_identifiers else col.column -%}
select
{%- for exclude_col in exclude %}
{{ adapter.quote(exclude_col) if quote_identifiers else exclude_col }},
{%- endfor %}

cast('{{ col.column }}' as {{ dbt.type_string() }}) as {{ adapter.quote(field_name) if quote_identifiers else field_name }},
cast( {% if col.data_type == 'boolean' %}
{{ dbt.cast_bool_to_text(current_col_name) }}
{% else %}
{{ current_col_name }}
{% endif %}
as {{ cast_to }}) as {{ adapter.quote(value_name) if quote_identifiers else value_name }}

from {{ relation }}

{% if not loop.last -%}
union all
{% endif -%}
{%- endfor -%}

{%- endmacro %}
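To make the macro's output concrete: for a relation with columns id, jan and feb, called with exclude=['id'], the loop emits one SELECT per remaining column, joined by union all. A hand-written sketch of the resulting shape follows (quoting style and the result of dbt.type_string() vary by warehouse, so treat this as approximate):

# Hand-written sketch of what default__unpivot expands to for a relation
# with columns (id, jan, feb) and exclude=['id']. Quoting style and the
# type produced by dbt.type_string() differ per warehouse; approximate only.
UNPIVOT_SQL_SHAPE = """
select
    "id",
    cast('jan' as TEXT) as "field_name",
    cast("jan" as varchar) as "value"
from analytics.mytable
union all
select
    "id",
    cast('feb' as TEXT) as "field_name",
    cast("feb" as varchar) as "value"
from analytics.mytable
"""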
12 changes: 11 additions & 1 deletion dbt_automation/operations/mergeoperations.py
@@ -22,7 +22,9 @@
from dbt_automation.operations.aggregate import aggregate_dbt_sql
from dbt_automation.operations.casewhen import casewhen_dbt_sql
from dbt_automation.operations.flattenjson import flattenjson_dbt_sql
-from dbt_automation.operations.mergetables import union_tables, union_tables_sql
+from dbt_automation.operations.mergetables import union_tables_sql
+from dbt_automation.operations.pivot import pivot_dbt_sql
+from dbt_automation.operations.unpivot import unpivot_dbt_sql


def merge_operations_sql(
@@ -118,6 +120,14 @@ def merge_operations_sql(
op_select_statement, out_cols = union_tables_sql(
operation["config"], warehouse
)
elif operation["type"] == "pivot":
op_select_statement, out_cols = pivot_dbt_sql(
operation["config"], warehouse
)
elif operation["type"] == "unpivot":
op_select_statement, out_cols = unpivot_dbt_sql(
operation["config"], warehouse
)
elif operation["type"] == "generic":
op_select_statement, out_cols = generic_function_dbt_sql(
operation["config"], warehouse
96 changes: 96 additions & 0 deletions dbt_automation/operations/pivot.py
@@ -0,0 +1,96 @@
"""
Generates a dbt model for pivot
"""

from logging import basicConfig, getLogger, INFO

from dbt_automation.utils.dbtproject import dbtProject
from dbt_automation.utils.columnutils import quote_columnname, quote_constvalue
from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface
from dbt_automation.utils.tableutils import source_or_ref

basicConfig(level=INFO)
logger = getLogger()


def select_from(input_table: dict):
"""generates the correct FROM clause for the input table"""
selectfrom = source_or_ref(**input_table)
if input_table["input_type"] == "cte":
return f"FROM {selectfrom}\n"
return f"FROM {{{{{selectfrom}}}}}\n"


# pylint:disable=unused-argument,logging-fstring-interpolation
def pivot_dbt_sql(
config: dict,
warehouse: WarehouseInterface,
):
"""
    Generate SQL code for the pivot operation.
"""
source_columns = config.get("source_columns", [])
pivot_column_values = config.get("pivot_column_values", [])
pivot_column_name = config.get("pivot_column_name", None)
input_table = config["input"]

if not pivot_column_name:
raise ValueError("Pivot column name not provided")

dbt_code = "SELECT\n"

if len(source_columns) > 0:
dbt_code += ",\n".join(
[quote_columnname(col_name, warehouse.name) for col_name in source_columns]
)
dbt_code += ",\n"

dbt_code += "{{ dbt_utils.pivot("
dbt_code += quote_constvalue(
quote_columnname(pivot_column_name, warehouse.name), warehouse.name
)
dbt_code += ", "
dbt_code += (
"["
+ ",".join(
[
quote_constvalue(pivot_val, warehouse.name)
for pivot_val in pivot_column_values
]
)
+ "]"
)
dbt_code += ")}}\n"

dbt_code += select_from(input_table)
if len(source_columns) > 0:
dbt_code += "GROUP BY "
dbt_code += ",".join(
[quote_columnname(col_name, warehouse.name) for col_name in source_columns]
)

return dbt_code, source_columns + pivot_column_values


def pivot(config: dict, warehouse: WarehouseInterface, project_dir: str):
"""
    Perform the pivot operation and generate a dbt model.
"""
dbt_sql = ""
if config["input"]["input_type"] != "cte":
dbt_sql = (
"{{ config(materialized='table', schema='" + config["dest_schema"] + "') }}"
)

select_statement, output_cols = pivot_dbt_sql(config, warehouse)
dbt_sql += "\n" + select_statement

dbt_project = dbtProject(project_dir)
dbt_project.ensure_models_dir(config["dest_schema"])

output_name = config["output_name"]
dest_schema = config["dest_schema"]
model_sql_path = dbt_project.write_model(dest_schema, output_name, dbt_sql)

return model_sql_path, output_cols
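A minimal sketch of calling pivot_dbt_sql directly; the warehouse argument is assumed to be any WarehouseInterface implementation whose .name is, say, "postgres", and all table and column names are invented:

# Minimal usage sketch. `warehouse` is assumed to be a WarehouseInterface
# implementation with .name == "postgres"; names below are invented.
from dbt_automation.operations.pivot import pivot_dbt_sql

config = {
    "input": {"input_type": "source", "input_name": "sales", "source_name": "staging"},
    "source_columns": ["store_id"],
    "pivot_column_name": "status",
    "pivot_column_values": ["paid", "pending"],
}
sql, out_cols = pivot_dbt_sql(config, warehouse)
# sql should resemble:
#   SELECT "store_id",
#   {{ dbt_utils.pivot('"status"', ['paid','pending']) }}
#   FROM {{ source('staging', 'sales') }}
#   GROUP BY "store_id"
# and out_cols == ["store_id", "paid", "pending"]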
23 changes: 15 additions & 8 deletions dbt_automation/operations/scaffold.py
@@ -1,5 +1,6 @@
"""setup the dbt project"""

+import glob
import os, shutil, yaml
from pathlib import Path
from string import Template
@@ -44,14 +45,20 @@ def scaffold(config: dict, warehouse: WarehouseInterface, project_dir: str):
(Path(project_dir) / "models" / "staging").mkdir()
(Path(project_dir) / "models" / "intermediate").mkdir()

-    flatten_json_target = Path(project_dir) / "macros" / "flatten_json.sql"
-    custom_schema_target = Path(project_dir) / "macros" / "generate_schema_name.sql"
-    logger.info("created %s", flatten_json_target)
-    source_schema_name_macro_path = os.path.abspath(
-        os.path.join(os.path.abspath(assets.__file__), "..", "generate_schema_name.sql")
-    )
-    shutil.copy(source_schema_name_macro_path, custom_schema_target)
-    logger.info("created %s", custom_schema_target)
+    # copy all .sql files from assets/ to project_dir/macros,
+    # overwriting any copy already present there
+    assets_dir = assets.__path__[0]
+
+    # loop over all macro files with a .sql extension
+    for sql_file_path in glob.glob(os.path.join(assets_dir, "*.sql")):
+        # Get the target path in the project_dir/macros directory
+        target_path = Path(project_dir) / "macros" / Path(sql_file_path).name
+
+        # Copy the .sql file to the target path
+        shutil.copy(sql_file_path, target_path)
+
+        # Log the creation of the file
+        logger.info("created %s", target_path)

dbtproject_filename = Path(project_dir) / "dbt_project.yml"
PROJECT_TEMPLATE = Template(