diff --git a/dbt_automation/assets/operations.template.yml b/dbt_automation/assets/operations.template.yml index b9d9463..39f49b6 100644 --- a/dbt_automation/assets/operations.template.yml +++ b/dbt_automation/assets/operations.template.yml @@ -383,6 +383,50 @@ operations: - ... dest_schema: output_name: + + - type: pivot + config: + input: + input_type: <"source" or "model" of table1> + input_name: + source_name: + source_columns: + - + - + - + pivot_column_name: + pivot_column_values: + - + - + - + dest_schema: + output_name: + + - type: unpivot + config: + input: + input_type: <"source" or "model" of table1> + input_name: + source_name: + source_columns: + - + - + - + exclude_columns: + - + - + - + unpivot_columns: + - + - + - + unpivot_field_name: + unpivot_value_name: + cast_to: + dest_schema: + output_name: + + - type: mergeoperations config: dest_schema: @@ -644,4 +688,32 @@ operations: - - - - - ... \ No newline at end of file + - ... + - type: pivot + config: + source_columns: + - + - + - + pivot_column_name: + pivot_column_values: + - + - + - + - type: unpivot + config: + source_columns: + - + - + - + exclude_columns: + - + - + - + unpivot_columns: + - + - + - + cast_to: + unpivot_field_name: + unpivot_value_name: \ No newline at end of file diff --git a/dbt_automation/assets/unpivot.sql b/dbt_automation/assets/unpivot.sql new file mode 100644 index 0000000..d7bcd4f --- /dev/null +++ b/dbt_automation/assets/unpivot.sql @@ -0,0 +1,67 @@ +{# +Pivot values from columns to rows. Similar to pandas DataFrame melt() function. + +Example Usage: {{ unpivot(relation=ref('users'), cast_to='integer', exclude=['id','created_at']) }} + +Arguments: + relation: Relation object, required. + cast_to: The datatype to cast all unpivoted columns to. Default is varchar. + exclude: A list of columns to keep but exclude from the unpivot operation. Default is none. + remove: A list of columns to remove from the resulting table. Default is none. + field_name: Destination table column name for the source table column names. 
+ value_name: Destination table column name for the pivoted values +#} + +{% macro unpivot(relation=none, cast_to='varchar', exclude=none, remove=none, field_name='field_name', value_name='value', quote_identifiers=True) -%} + {{ return(adapter.dispatch('unpivot', 'dbt_utils')(relation, cast_to, exclude, remove, field_name, value_name, quote_identifiers)) }} +{% endmacro %} + +{% macro default__unpivot(relation=none, cast_to='varchar', exclude=none, remove=none, field_name='field_name', value_name='value', quote_identifiers=True) -%} + + {% if not relation %} + {{ exceptions.raise_compiler_error("Error: argument `relation` is required for `unpivot` macro.") }} + {% endif %} + + {%- set exclude = exclude if exclude is not none else [] %} + {%- set remove = remove if remove is not none else [] %} + + {%- set include_cols = [] %} + + {%- set table_columns = {} %} + + {%- do table_columns.update({relation: []}) %} + + {%- do dbt_utils._is_relation(relation, 'unpivot') -%} + {%- do dbt_utils._is_ephemeral(relation, 'unpivot') -%} + {%- set cols = adapter.get_columns_in_relation(relation) %} + + {%- for col in cols -%} + {%- if col.column.lower() not in remove|map('lower') and col.column.lower() not in exclude|map('lower') -%} + {% do include_cols.append(col) %} + {%- endif %} + {%- endfor %} + + + {%- for col in include_cols -%} + {%- set current_col_name = adapter.quote(col.column) if quote_identifiers else col.column -%} + select + {%- for exclude_col in exclude %} + {{ adapter.quote(exclude_col) if quote_identifiers else exclude_col }}, + {%- endfor %} + + cast('{{ col.column }}' as {{ dbt.type_string() }}) as {{ adapter.quote(field_name) if quote_identifiers else field_name }}, + cast( {% if col.data_type == 'boolean' %} + {{ dbt.cast_bool_to_text(current_col_name) }} + {% else %} + {{ current_col_name }} + {% endif %} + as {{ cast_to }}) as {{ adapter.quote(value_name) if quote_identifiers else value_name }} + + from {{ relation }} + + {% if not loop.last -%} + union all + {% endif -%} + {%- endfor -%} + +{%- endmacro %} \ No newline at end of file diff --git a/dbt_automation/operations/mergeoperations.py b/dbt_automation/operations/mergeoperations.py index 76939b3..77260cc 100644 --- a/dbt_automation/operations/mergeoperations.py +++ b/dbt_automation/operations/mergeoperations.py @@ -22,7 +22,9 @@ from dbt_automation.operations.aggregate import aggregate_dbt_sql from dbt_automation.operations.casewhen import casewhen_dbt_sql from dbt_automation.operations.flattenjson import flattenjson_dbt_sql -from dbt_automation.operations.mergetables import union_tables, union_tables_sql +from dbt_automation.operations.mergetables import union_tables_sql +from dbt_automation.operations.pivot import pivot_dbt_sql +from dbt_automation.operations.unpivot import unpivot_dbt_sql def merge_operations_sql( @@ -118,6 +120,14 @@ def merge_operations_sql( op_select_statement, out_cols = union_tables_sql( operation["config"], warehouse ) + elif operation["type"] == "pivot": + op_select_statement, out_cols = pivot_dbt_sql( + operation["config"], warehouse + ) + elif operation["type"] == "unpivot": + op_select_statement, out_cols = unpivot_dbt_sql( + operation["config"], warehouse + ) elif operation["type"] == "generic": op_select_statement, out_cols = generic_function_dbt_sql( operation["config"], warehouse diff --git a/dbt_automation/operations/pivot.py b/dbt_automation/operations/pivot.py new file mode 100644 index 0000000..1a8a093 --- /dev/null +++ b/dbt_automation/operations/pivot.py @@ -0,0 +1,96 @@ +""" 
+Generates a dbt model for pivot
+"""
+
+from logging import basicConfig, getLogger, INFO
+
+from dbt_automation.utils.dbtproject import dbtProject
+from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface
+from dbt_automation.utils.columnutils import quote_columnname, quote_constvalue
+from dbt_automation.utils.tableutils import source_or_ref
+
+basicConfig(level=INFO)
+logger = getLogger()
+
+
+def select_from(input_table: dict):
+    """generates the correct FROM clause for the input table"""
+    selectfrom = source_or_ref(**input_table)
+    if input_table["input_type"] == "cte":
+        return f"FROM {selectfrom}\n"
+    return f"FROM {{{{{selectfrom}}}}}\n"
+
+
+# pylint:disable=unused-argument,logging-fstring-interpolation
+def pivot_dbt_sql(
+    config: dict,
+    warehouse: WarehouseInterface,
+):
+    """
+    Generate SQL code for the pivot operation.
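+
+    For illustration only (table and values are hypothetical): a config with
+    source_columns=["SPOC"], pivot_column_name="NGO" and
+    pivot_column_values=["IMAGE", "FDSR"] on a postgres warehouse renders a
+    select statement roughly like:
+
+        SELECT
+        "SPOC",
+        {{ dbt_utils.pivot('"NGO"', ['IMAGE','FDSR'])}}
+        FROM {{ref('sheet2')}}
+        GROUP BY "SPOC"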
+    """
+    source_columns = config.get("source_columns", [])
+    pivot_column_values = config.get("pivot_column_values", [])
+    pivot_column_name = config.get("pivot_column_name", None)
+    input_table = config["input"]
+
+    if not pivot_column_name:
+        raise ValueError("Pivot column name not provided")
+
+    dbt_code = "SELECT\n"
+
+    if len(source_columns) > 0:
+        dbt_code += ",\n".join(
+            [quote_columnname(col_name, warehouse.name) for col_name in source_columns]
+        )
+        dbt_code += ",\n"
+
+    dbt_code += "{{ dbt_utils.pivot("
+    dbt_code += quote_constvalue(
+        quote_columnname(pivot_column_name, warehouse.name), warehouse.name
+    )
+    dbt_code += ", "
+    dbt_code += (
+        "["
+        + ",".join(
+            [
+                quote_constvalue(pivot_val, warehouse.name)
+                for pivot_val in pivot_column_values
+            ]
+        )
+        + "]"
+    )
+    dbt_code += ")}}\n"
+
+    dbt_code += select_from(input_table)
+    if len(source_columns) > 0:
+        dbt_code += "GROUP BY "
+        dbt_code += ",".join(
+            [quote_columnname(col_name, warehouse.name) for col_name in source_columns]
+        )
+
+    return dbt_code, source_columns + pivot_column_values
+
+
+def pivot(config: dict, warehouse: WarehouseInterface, project_dir: str):
+    """
+    Perform the pivot operation and generate a dbt model.
+    """
+    dbt_sql = ""
+    if config["input"]["input_type"] != "cte":
+        dbt_sql = (
+            "{{ config(materialized='table', schema='" + config["dest_schema"] + "') }}"
+        )
+
+    select_statement, output_cols = pivot_dbt_sql(config, warehouse)
+    dbt_sql += "\n" + select_statement
+
+    dbt_project = dbtProject(project_dir)
+    dbt_project.ensure_models_dir(config["dest_schema"])
+
+    output_name = config["output_name"]
+    dest_schema = config["dest_schema"]
+    model_sql_path = dbt_project.write_model(dest_schema, output_name, dbt_sql)
+
+    return model_sql_path, output_cols
diff --git a/dbt_automation/operations/scaffold.py b/dbt_automation/operations/scaffold.py
index 3d8dcdc..42c42ef 100644
--- a/dbt_automation/operations/scaffold.py
+++ b/dbt_automation/operations/scaffold.py
@@ -1,5 +1,6 @@
 """setup the dbt project"""
 
+import glob
 import os, shutil, yaml
 from pathlib import Path
 from string import Template
@@ -44,14 +45,20 @@ def scaffold(config: dict, warehouse: WarehouseInterface, project_dir: str):
     (Path(project_dir) / "models" / "staging").mkdir()
     (Path(project_dir) / "models" / "intermediate").mkdir()
 
-    flatten_json_target = Path(project_dir) / "macros" / "flatten_json.sql"
-    custom_schema_target = Path(project_dir) / "macros" / "generate_schema_name.sql"
-    logger.info("created %s", flatten_json_target)
-    source_schema_name_macro_path = os.path.abspath(
-        os.path.join(os.path.abspath(assets.__file__), "..", "generate_schema_name.sql")
-    )
-    shutil.copy(source_schema_name_macro_path, custom_schema_target)
-    logger.info("created %s", custom_schema_target)
+    # copy all .sql files from assets/ to project_dir/macros,
+    # overwriting any existing copies
+    assets_dir = assets.__path__[0]
+
+    # loop over all macros with the .sql extension
+    for sql_file_path in glob.glob(os.path.join(assets_dir, "*.sql")):
+        # target path in the project_dir/macros directory
+        target_path = Path(project_dir) / "macros" / Path(sql_file_path).name
+
+        # copy the .sql file to the target path
+        shutil.copy(sql_file_path, target_path)
+
+        # log the creation of the file
+        logger.info("created %s", target_path)
 
     dbtproject_filename = Path(project_dir) / "dbt_project.yml"
     PROJECT_TEMPLATE = Template(
diff --git a/dbt_automation/operations/unpivot.py b/dbt_automation/operations/unpivot.py
new file mode 100644
index 0000000..f99e717
--- /dev/null
+++ b/dbt_automation/operations/unpivot.py
@@ -0,0 +1,89 @@
+"""
+Generates a dbt model for unpivot
+This operation will only work in a mergeoperations chain if it is the first
+step, since the unpivot macro reads column metadata from an actual relation
+and cannot run against a CTE
+"""
+
+from logging import basicConfig, getLogger, INFO
+
+from dbt_automation.utils.dbtproject import dbtProject
+from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface
+from dbt_automation.utils.columnutils import quote_columnname, quote_constvalue
+from dbt_automation.utils.tableutils import source_or_ref
+
+basicConfig(level=INFO)
+logger = getLogger()
+
+
+# pylint:disable=unused-argument,logging-fstring-interpolation
+def unpivot_dbt_sql(
+    config: dict,
+    warehouse: WarehouseInterface,
+):
+    """
+    Generate SQL code for the unpivot operation.
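+
+    For illustration only (table and values are hypothetical): a config with
+    source_columns=["NGO", "SPOC", "measure1"], unpivot_columns=["NGO", "SPOC"],
+    exclude_columns=[] and the default field/value names renders a call to the
+    unpivot macro roughly like:
+
+        {{ unpivot(ref('sheet2'), exclude=[] ,cast_to='varchar',
+        remove=['measure1'] ,field_name='field_name', value_name='value')}}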
+    """
+    source_columns = config.get("source_columns", [])  # all columns
+    exclude_columns = config.get(
+        "exclude_columns", []
+    )  # exclude from unpivot but keep in the resulting table
+    unpivot_on_columns = config.get("unpivot_columns", [])  # columns to unpivot
+    input_table = config["input"]
+    field_name = config.get("unpivot_field_name", "field_name")
+    value_name = config.get("unpivot_value_name", "value")
+    cast_datatype_to = config.get("cast_to")
+    if not cast_datatype_to:
+        cast_datatype_to = "STRING" if warehouse.name == "bigquery" else "varchar"
+
+    if len(unpivot_on_columns) == 0:
+        raise ValueError("No columns specified for unpivot")
+
+    output_columns = list(set(exclude_columns) | set(unpivot_on_columns))  # union
+    remove_columns = list(set(source_columns) - set(output_columns))
+
+    dbt_code = "{{ unpivot("
+    dbt_code += source_or_ref(**input_table)
+    dbt_code += ", exclude="
+    dbt_code += (
+        "["
+        + ",".join(
+            [quote_constvalue(col_name, warehouse.name) for col_name in exclude_columns]
+        )
+        + "] ,"
+    )
+    dbt_code += f"cast_to={quote_constvalue(cast_datatype_to, warehouse.name)}, "
+    dbt_code += "remove="
+    dbt_code += (
+        "["
+        + ",".join(
+            [quote_constvalue(col_name, warehouse.name) for col_name in remove_columns]
+        )
+        + "] ,"
+    )
+    dbt_code += f"field_name={quote_constvalue(field_name, warehouse.name)}, value_name={quote_constvalue(value_name, warehouse.name)}"
+    dbt_code += ")}}\n"
+
+    return dbt_code, output_columns
+
+
+def unpivot(config: dict, warehouse: WarehouseInterface, project_dir: str):
+    """
+    Perform the unpivot operation and generate a dbt model.
+    """
+    dbt_sql = ""
+    if config["input"]["input_type"] != "cte":
+        dbt_sql = (
+            "{{ config(materialized='table', schema='" + config["dest_schema"] + "') }}"
+        )
+
+    select_statement, output_cols = unpivot_dbt_sql(config, warehouse)
+    dbt_sql += "\n" + select_statement
+
+    dbt_project = dbtProject(project_dir)
+    dbt_project.ensure_models_dir(config["dest_schema"])
+
+    output_name = config["output_name"]
+    dest_schema = config["dest_schema"]
+    model_sql_path = dbt_project.write_model(dest_schema, output_name, dbt_sql)
+
+    return model_sql_path, output_cols
diff --git a/scripts/main.py b/scripts/main.py
index 5356164..bc7ee79 100644
--- a/scripts/main.py
+++ b/scripts/main.py
@@ -28,6 +28,8 @@
 from dbt_automation.operations.groupby import groupby
 from dbt_automation.operations.aggregate import aggregate
 from dbt_automation.operations.casewhen import casewhen
+from dbt_automation.operations.pivot import pivot
+from dbt_automation.operations.unpivot import unpivot
 
 OPERATIONS_DICT = {
     "flatten": flatten_operation,
@@ -49,6 +51,8 @@
     "groupby": groupby,
     "aggregate": aggregate,
     "casewhen": casewhen,
+    "pivot": pivot,
+    "unpivot": unpivot,
     "generic": generic_function
 }
diff --git a/tests/warehouse/test_bigquery_ops.py b/tests/warehouse/test_bigquery_ops.py
index 3e357e8..1fb38f1 100644
--- a/tests/warehouse/test_bigquery_ops.py
+++ b/tests/warehouse/test_bigquery_ops.py
@@ -25,6 +25,8 @@
 from dbt_automation.operations.mergetables import union_tables
 from dbt_automation.operations.aggregate import aggregate
 from dbt_automation.operations.casewhen import casewhen
+from dbt_automation.operations.pivot import pivot
+from dbt_automation.operations.unpivot import unpivot
 
 basicConfig(level=INFO)
 
@@ -673,6 +675,44 @@ def test_casewhen(self):
         assert "SPOC B" not in spoc_values
         assert "SPOC C" not in spoc_values
 
+    def test_pivot(self):
+        """test pivot operation"""
+        wc_client = TestBigqueryOperations.wc_client
+        output_name = 
"pivot_op" + + config = { + "input": { + "input_type": "model", + "input_name": "_airbyte_raw_Sheet2", + "source_name": None, + }, + "dest_schema": "pytest_intermediate", + "output_name": output_name, + "source_columns": ["SPOC"], + "pivot_column_name": "NGO", + "pivot_column_values": ["IMAGE", "FDSR", "CRC", "BAMANEH", "JTS"], + } + + pivot( + config, + wc_client, + TestBigqueryOperations.test_project_dir, + ) + + TestBigqueryOperations.execute_dbt("run", output_name) + + cols = [ + col_dict["name"] + for col_dict in wc_client.get_table_columns( + "pytest_intermediate", output_name + ) + ] + assert sorted(cols) == sorted( + config["pivot_column_values"] + config["source_columns"] + ) + table_data = wc_client.get_table_data("pytest_intermediate", output_name, 10) + assert len(table_data) == 3 + def test_mergetables(self): """test merge tables""" wc_client = TestBigqueryOperations.wc_client @@ -725,6 +765,53 @@ def test_mergetables(self): assert len(table_data1) + len(table_data2) == len(table_data_union) + def test_unpivot(self): + """test unpivot operation""" + wc_client = TestBigqueryOperations.wc_client + output_name = "unpivot_op" + + config = { + "input": { + "input_type": "model", + "input_name": "_airbyte_raw_Sheet2", + "source_name": None, + }, + "dest_schema": "pytest_intermediate", + "output_name": output_name, + "source_columns": [ + "NGO", + "SPOC", + "Month", + "measure1", + "_airbyte_ab_id", + "measure2", + "Indicator", + ], + "exclude_columns": [], + "unpivot_columns": ["NGO", "SPOC"], + "unpivot_field_name": "col_field", + "unpivot_value_name": "col_val", + } + + unpivot( + config, + wc_client, + TestBigqueryOperations.test_project_dir, + ) + + TestBigqueryOperations.execute_dbt("run", output_name) + + cols = [ + col_dict["name"] + for col_dict in wc_client.get_table_columns( + "pytest_intermediate", output_name + ) + ] + assert len(cols) == 2 + assert sorted(cols) == sorted( + [config["unpivot_field_name"], config["unpivot_value_name"]] + ) + def test_merge_operation(self): """test merge_operation""" wc_client = TestBigqueryOperations.wc_client @@ -968,7 +1055,14 @@ def test_flattenjson(self): "_airbyte_emitted_at", ], "json_column": "_airbyte_data", - "json_columns_to_copy": ["NGO", "SPOC", "Month", "measure1", "measure2", "Indicator"], + "json_columns_to_copy": [ + "NGO", + "SPOC", + "Month", + "measure1", + "measure2", + "Indicator", + ], } flattenjson(config, wc_client, TestBigqueryOperations.test_project_dir) diff --git a/tests/warehouse/test_postgres_ops.py b/tests/warehouse/test_postgres_ops.py index eb0d0b6..0eb01e0 100644 --- a/tests/warehouse/test_postgres_ops.py +++ b/tests/warehouse/test_postgres_ops.py @@ -23,6 +23,8 @@ from dbt_automation.operations.mergetables import union_tables from dbt_automation.operations.aggregate import aggregate from dbt_automation.operations.casewhen import casewhen +from dbt_automation.operations.pivot import pivot +from dbt_automation.operations.unpivot import unpivot basicConfig(level=INFO) @@ -684,6 +686,91 @@ def test_casewhen(self): assert "SPOC B" not in spoc_values assert "SPOC C" not in spoc_values + def test_pivot(self): + """test pivot operation""" + wc_client = TestPostgresOperations.wc_client + output_name = "pivot_op" + + config = { + "input": { + "input_type": "model", + "input_name": "_airbyte_raw_Sheet2", + "source_name": None, + }, + "dest_schema": "pytest_intermediate", + "output_name": output_name, + "source_columns": ["SPOC"], + "pivot_column_name": "NGO", + "pivot_column_values": ["IMAGE", "FDSR", "CRC", 
"BAMANEH", "JTS"], + } + + pivot( + config, + wc_client, + TestPostgresOperations.test_project_dir, + ) + + TestPostgresOperations.execute_dbt("run", output_name) + + cols = [ + col_dict["name"] + for col_dict in wc_client.get_table_columns( + "pytest_intermediate", output_name + ) + ] + assert sorted(cols) == sorted( + config["pivot_column_values"] + config["source_columns"] + ) + table_data = wc_client.get_table_data("pytest_intermediate", output_name, 10) + assert len(table_data) == 3 + + def test_unpivot(self): + """test unpivot operation""" + wc_client = TestPostgresOperations.wc_client + output_name = "unpivot_op" + + config = { + "input": { + "input_type": "model", + "input_name": "_airbyte_raw_Sheet2", + "source_name": None, + }, + "dest_schema": "pytest_intermediate", + "output_name": output_name, + "source_columns": [ + "NGO", + "SPOC", + "Month", + "measure1", + "_airbyte_ab_id", + "measure2", + "Indicator", + ], + "exclude_columns": [], + "unpivot_columns": ["NGO", "SPOC"], + "unpivot_field_name": "col_field", + "unpivot_value_name": "col_val", + } + + unpivot( + config, + wc_client, + TestPostgresOperations.test_project_dir, + ) + + TestPostgresOperations.execute_dbt("run", output_name) + + cols = [ + col_dict["name"] + for col_dict in wc_client.get_table_columns( + "pytest_intermediate", output_name + ) + ] + assert len(cols) == 2 + assert sorted(cols) == sorted( + [config["unpivot_field_name"], config["unpivot_value_name"]] + ) + def test_mergetables(self): """test merge tables""" wc_client = TestPostgresOperations.wc_client @@ -760,7 +847,13 @@ def test_flattenjson(self): "_airbyte_emitted_at", ], "json_column": "_airbyte_data", - "json_columns_to_copy": ["NGO", "Month", "measure1", "measure2", "Indicator"], + "json_columns_to_copy": [ + "NGO", + "Month", + "measure1", + "measure2", + "Indicator", + ], } flattenjson(config, wc_client, TestPostgresOperations.test_project_dir)