diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index b45f93776..bf607c379 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -60,6 +60,7 @@ jobs:
          python -m pip install pre-commit
          pre-commit --version
          python -m pip install mypy==0.942
+          python -m pip install types-requests
          mypy --version
          python -m pip install -r requirements.txt
          python -m pip install -r dev-requirements.txt
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 28f7e138b..d015a26c7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,8 @@
 ## dbt-spark 1.3.0b1 (Release TBD)
 
+### Features
+- Support Python models submitted as notebooks; currently supported materializations are `table` and `incremental` ([#377](https://github.com/dbt-labs/dbt-spark/pull/377))
+
 ### Fixes
 - Pin `pyodbc` to version 4.0.32 to prevent overwriting `libodbc.so` and `libltdl.so` on Linux ([#397](https://github.com/dbt-labs/dbt-spark/issues/397/), [#398](https://github.com/dbt-labs/dbt-spark/pull/398/))
diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py
index 3fb9978d8..12c42ab98 100644
--- a/dbt/adapters/spark/impl.py
+++ b/dbt/adapters/spark/impl.py
@@ -1,4 +1,7 @@
 import re
+import requests
+import time
+import base64
 from concurrent.futures import Future
 from dataclasses import dataclass
 from typing import Any, Dict, Iterable, List, Optional, Union
@@ -11,7 +14,8 @@
 import dbt.exceptions
 
 from dbt.adapters.base import AdapterConfig
-from dbt.adapters.base.impl import catch_as_completed
+from dbt.adapters.base.impl import catch_as_completed, log_code_execution
+from dbt.adapters.base.meta import available
 from dbt.adapters.sql import SQLAdapter
 from dbt.adapters.spark import SparkConnectionManager
 from dbt.adapters.spark import SparkRelation
@@ -159,11 +163,9 @@ def list_relations_without_caching(
 
         return relations
 
-    def get_relation(
-        self, database: Optional[str], schema: str, identifier: str
-    ) -> Optional[BaseRelation]:
+    def get_relation(self, database: str, schema: str, identifier: str) -> Optional[BaseRelation]:
         if not self.Relation.include_policy.database:
-            database = None
+            database = None  # type: ignore
 
         return super().get_relation(database, schema, identifier)
 
@@ -296,7 +298,12 @@ def get_catalog(self, manifest):
             for schema in schemas:
                 futures.append(
                     tpe.submit_connected(
-                        self, schema, self._get_one_catalog, info, [schema], manifest
+                        self,
+                        schema,
+                        self._get_one_catalog,
+                        info,
+                        [schema],
+                        manifest,
                     )
                 )
             catalogs, exceptions = catch_as_completed(futures)
@@ -380,6 +387,114 @@ def run_sql_for_tests(self, sql, fetch, conn):
         finally:
             conn.transaction_open = False
 
+    @available.parse_none
+    @log_code_execution
+    def submit_python_job(self, parsed_model: dict, compiled_code: str, timeout=None):
+        # TODO improve the typing here.  N.B. Jinja returns a `jinja2.runtime.Undefined` instead
+        # of `None` which evaluates to True!
+
+        # TODO limit this function to run only when doing the materialization of python nodes
+
+        # assuming that for a python job running over 1 day the user would manually overwrite this
+        schema = getattr(parsed_model, "schema", self.config.credentials.schema)
+        identifier = parsed_model["alias"]
+        if not timeout:
+            timeout = 60 * 60 * 24
+        if timeout <= 0:
+            raise ValueError("Timeout must be larger than 0")
+
+        auth_header = {"Authorization": f"Bearer {self.connections.profile.credentials.token}"}
+
+        # create new dir
+        if not self.connections.profile.credentials.user:
+            raise ValueError("Need to supply user in profile to submit python job")
+        # it is safe to call mkdirs even if the dir already exists and has content inside
+        work_dir = f"/Users/{self.connections.profile.credentials.user}/{schema}"
+        response = requests.post(
+            f"https://{self.connections.profile.credentials.host}/api/2.0/workspace/mkdirs",
+            headers=auth_header,
+            json={
+                "path": work_dir,
+            },
+        )
+        if response.status_code != 200:
+            raise dbt.exceptions.RuntimeException(
+                f"Error creating work_dir for python notebooks\n {response.content!r}"
+            )
+
+        # add notebook
+        b64_encoded_content = base64.b64encode(compiled_code.encode()).decode()
+        response = requests.post(
+            f"https://{self.connections.profile.credentials.host}/api/2.0/workspace/import",
+            headers=auth_header,
+            json={
+                "path": f"{work_dir}/{identifier}",
+                "content": b64_encoded_content,
+                "language": "PYTHON",
+                "overwrite": True,
+                "format": "SOURCE",
+            },
+        )
+        if response.status_code != 200:
+            raise dbt.exceptions.RuntimeException(
+                f"Error creating python notebook.\n {response.content!r}"
+            )
+
+        # submit job
+        submit_response = requests.post(
+            f"https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/submit",
+            headers=auth_header,
+            json={
+                "run_name": "debug task",
+                "existing_cluster_id": self.connections.profile.credentials.cluster,
+                "notebook_task": {
+                    "notebook_path": f"{work_dir}/{identifier}",
+                },
+            },
+        )
+        if submit_response.status_code != 200:
+            raise dbt.exceptions.RuntimeException(
+                f"Error creating python run.\n {submit_response.content!r}"
+            )
+
+        # poll until the job finishes
+        state = None
+        start = time.time()
+        run_id = submit_response.json()["run_id"]
+        terminal_states = ["TERMINATED", "SKIPPED", "INTERNAL_ERROR"]
+        while state not in terminal_states and time.time() - start < timeout:
+            time.sleep(1)
+            resp = requests.get(
+                f"https://{self.connections.profile.credentials.host}"
+                f"/api/2.1/jobs/runs/get?run_id={run_id}",
+                headers=auth_header,
+            )
+            json_resp = resp.json()
+            state = json_resp["state"]["life_cycle_state"]
+            # logger.debug(f"Polling.... in state: {state}")
+        if state != "TERMINATED":
+            raise dbt.exceptions.RuntimeException(
+                "Python model run ended in state "
+                f"{state} with state_message\n{json_resp['state']['state_message']}"
+            )
+
+        # get end state to return to user
+        run_output = requests.get(
+            f"https://{self.connections.profile.credentials.host}"
+            f"/api/2.1/jobs/runs/get-output?run_id={run_id}",
+            headers=auth_header,
+        )
+        json_run_output = run_output.json()
+        result_state = json_run_output["metadata"]["state"]["result_state"]
+        if result_state != "SUCCESS":
+            raise dbt.exceptions.RuntimeException(
+                "Python model failed with traceback as:\n"
+                "(Note that the line number here does not "
+                "match the line number in your code due to dbt templating)\n"
+                f"{json_run_output['error_trace']}"
+            )
+        return self.connections.get_response(None)
+
     def standardize_grants_dict(self, grants_table: agate.Table) -> dict:
         grants_dict: Dict[str, List[str]] = {}
         for row in grants_table:
diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql
index abdeacb7f..05630ede5 100644
--- a/dbt/include/spark/macros/adapters.sql
+++ b/dbt/include/spark/macros/adapters.sql
@@ -117,35 +117,46 @@
 {%- endmacro %}
 
-{% macro create_temporary_view(relation, sql) -%}
-  {{ return(adapter.dispatch('create_temporary_view', 'dbt')(relation, sql)) }}
+{% macro create_temporary_view(relation, compiled_code) -%}
+  {{ return(adapter.dispatch('create_temporary_view', 'dbt')(relation, compiled_code)) }}
 {%- endmacro -%}
 
-{#-- We can't use temporary tables with `create ... as ()` syntax #}
-{% macro spark__create_temporary_view(relation, sql) -%}
-  create temporary view {{ relation.include(schema=false) }} as
-    {{ sql }}
-{% endmacro %}
+{#-- We can't use temporary tables with `create ... as ()` syntax --#}
+{% macro spark__create_temporary_view(relation, compiled_code) -%}
+  create temporary view {{ relation.include(schema=false) }} as
+    {{ compiled_code }}
+{%- endmacro -%}
 
-{% macro spark__create_table_as(temporary, relation, sql) -%}
-  {% if temporary -%}
-    {{ create_temporary_view(relation, sql) }}
-  {%- else -%}
-    {% if config.get('file_format', validator=validation.any[basestring]) == 'delta' %}
-      create or replace table {{ relation }}
-    {% else %}
-      create table {{ relation }}
-    {% endif %}
-    {{ file_format_clause() }}
-    {{ options_clause() }}
-    {{ partition_cols(label="partitioned by") }}
-    {{ clustered_cols(label="clustered by") }}
-    {{ location_clause() }}
-    {{ comment_clause() }}
-    as
-      {{ sql }}
-  {%- endif %}
+{%- macro spark__create_table_as(temporary, relation, compiled_code, language='sql') -%}
+  {%- if language == 'sql' -%}
+    {%- if temporary -%}
+      {{ create_temporary_view(relation, compiled_code) }}
+    {%- else -%}
+      {% if config.get('file_format', validator=validation.any[basestring]) == 'delta' %}
+        create or replace table {{ relation }}
+      {% else %}
+        create table {{ relation }}
+      {% endif %}
+      {{ file_format_clause() }}
+      {{ options_clause() }}
+      {{ partition_cols(label="partitioned by") }}
+      {{ clustered_cols(label="clustered by") }}
+      {{ location_clause() }}
+      {{ comment_clause() }}
+      as
+      {{ compiled_code }}
+    {%- endif -%}
+  {%- elif language == 'python' -%}
+    {#--
+    N.B. Python models _can_ write to temp views HOWEVER they use a different session
+    and have already expired by the time they need to be used (i.e. in merges for incremental models)
+
+    TODO: Deep dive into spark sessions to see if we can reuse a single session for an entire
+    dbt invocation.
+ --#} + {{ py_write_table(compiled_code=compiled_code, target_relation=relation) }} + {%- endif -%} {%- endmacro -%} diff --git a/dbt/include/spark/macros/materializations/incremental/incremental.sql b/dbt/include/spark/macros/materializations/incremental/incremental.sql index 1ca2c149a..91cba9e5f 100644 --- a/dbt/include/spark/macros/materializations/incremental/incremental.sql +++ b/dbt/include/spark/macros/materializations/incremental/incremental.sql @@ -1,5 +1,4 @@ {% materialization incremental, adapter='spark' -%} - {#-- Validate early so we don't run SQL if the file_format + strategy combo is invalid --#} {%- set raw_file_format = config.get('file_format', default='parquet') -%} {%- set raw_strategy = config.get('incremental_strategy') or 'append' -%} @@ -8,43 +7,63 @@ {%- set file_format = dbt_spark_validate_get_file_format(raw_file_format) -%} {%- set strategy = dbt_spark_validate_get_incremental_strategy(raw_strategy, file_format) -%} + {#-- Set vars --#} + {%- set unique_key = config.get('unique_key', none) -%} {%- set partition_by = config.get('partition_by', none) -%} - - {%- set full_refresh_mode = (should_full_refresh()) -%} - - {% set on_schema_change = incremental_validate_on_schema_change(config.get('on_schema_change'), default='ignore') %} - - {% set target_relation = this %} - {% set existing_relation = load_relation(this) %} - {% set tmp_relation = make_temp_relation(this) %} - - {% if strategy == 'insert_overwrite' and partition_by %} - {% call statement() %} + {%- set language = model['language'] -%} + {%- set on_schema_change = incremental_validate_on_schema_change(config.get('on_schema_change'), default='ignore') -%} + {%- set target_relation = this -%} + {%- set existing_relation = load_relation(this) -%} + {%- set tmp_relation = make_temp_relation(this) -%} + + {#-- Set Overwrite Mode --#} + {%- if strategy == 'insert_overwrite' and partition_by -%} + {%- call statement() -%} set spark.sql.sources.partitionOverwriteMode = DYNAMIC - {% endcall %} - {% endif %} + {%- endcall -%} + {%- endif -%} + {#-- Run pre-hooks --#} {{ run_hooks(pre_hooks) }} - {% set is_delta = (file_format == 'delta' and existing_relation.is_delta) %} - - {% if existing_relation is none %} - {% set build_sql = create_table_as(False, target_relation, sql) %} - {% elif existing_relation.is_view or full_refresh_mode %} + {#-- Incremental run logic --#} + {%- if existing_relation is none -%} + {#-- Relation must be created --#} + {%- call statement('main', language=language) -%} + {{ create_table_as(False, target_relation, compiled_code, language) }} + {%- endcall -%} + {%- elif existing_relation.is_view or should_full_refresh() -%} + {#-- Relation must be dropped & recreated --#} + {% set is_delta = (file_format == 'delta' and existing_relation.is_delta) %} {% if not is_delta %} {#-- If Delta, we will `create or replace` below, so no need to drop --#} {% do adapter.drop_relation(existing_relation) %} {% endif %} - {% set build_sql = create_table_as(False, target_relation, sql) %} - {% else %} - {% do run_query(create_table_as(True, tmp_relation, sql)) %} - {% do process_schema_changes(on_schema_change, tmp_relation, existing_relation) %} - {% set build_sql = dbt_spark_get_incremental_sql(strategy, tmp_relation, target_relation, unique_key) %} - {% endif %} - - {%- call statement('main') -%} - {{ build_sql }} - {%- endcall -%} + {%- call statement('main', language=language) -%} + {{ create_table_as(False, target_relation, compiled_code, language) }} + {%- endcall -%} + {%- else -%} + {#-- 
Relation must be merged --#} + {%- call statement('create_tmp_relation', language=language) -%} + {{ create_table_as(True, tmp_relation, compiled_code, language) }} + {%- endcall -%} + {%- do process_schema_changes(on_schema_change, tmp_relation, existing_relation) -%} + {%- call statement('main') -%} + {{ dbt_spark_get_incremental_sql(strategy, tmp_relation, target_relation, unique_key) }} + {%- endcall -%} + {%- if language == 'python' -%} + {#-- + This is yucky. + See note in dbt-spark/dbt/include/spark/macros/adapters.sql + re: python models and temporary views. + + Also, why doesn't either drop_relation or adapter.drop_relation work here?! + --#} + {% call statement('drop_relation') -%} + drop table if exists {{ tmp_relation }} + {%- endcall %} + {%- endif -%} + {%- endif -%} {% set should_revoke = should_revoke(existing_relation, full_refresh_mode) %} {% do apply_grants(target_relation, grant_config, should_revoke) %} diff --git a/dbt/include/spark/macros/materializations/snapshot.sql b/dbt/include/spark/macros/materializations/snapshot.sql index a5304682e..6cf2358fe 100644 --- a/dbt/include/spark/macros/materializations/snapshot.sql +++ b/dbt/include/spark/macros/materializations/snapshot.sql @@ -117,7 +117,7 @@ {% if not target_relation_exists %} - {% set build_sql = build_snapshot_table(strategy, model['compiled_sql']) %} + {% set build_sql = build_snapshot_table(strategy, model['compiled_code']) %} {% set final_sql = create_table_as(False, target_relation, build_sql) %} {% else %} diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index 3462d3332..6a02ea164 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -1,5 +1,5 @@ {% materialization table, adapter = 'spark' %} - + {%- set language = model['language'] -%} {%- set identifier = model['alias'] -%} {%- set grant_config = config.get('grants') -%} @@ -19,9 +19,10 @@ {%- endif %} -- build model - {% call statement('main') -%} - {{ create_table_as(False, target_relation, sql) }} - {%- endcall %} + + {%- call statement('main', language=language) -%} + {{ create_table_as(False, target_relation, compiled_code, language) }} + {%- endcall -%} {% set should_revoke = should_revoke(old_relation, full_refresh_mode=True) %} {% do apply_grants(target_relation, grant_config, should_revoke) %} @@ -33,3 +34,18 @@ {{ return({'relations': [target_relation]})}} {% endmaterialization %} + + +{% macro py_write_table(compiled_code, target_relation) %} +{{ compiled_code }} +# --- Autogenerated dbt materialization code. 
--- # +dbt = dbtObj(spark.table) +df = model(dbt, spark) +df.write.mode("overwrite").format("delta").saveAsTable("{{ target_relation }}") +{%- endmacro -%} + +{%macro py_script_comment()%} +# how to execute python model in notebook +# dbt = dbtObj(spark.table) +# df = model(dbt, spark) +{%endmacro%} diff --git a/dev-requirements.txt b/dev-requirements.txt index b94cb8b6b..5b29e5e9d 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -3,6 +3,8 @@ git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-core&subdirectory=core git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-tests-adapter&subdirectory=tests/adapter + + black==22.3.0 bumpversion click~=8.0.4 diff --git a/requirements.txt b/requirements.txt index c64512aeb..5d774e4f7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ PyHive[hive]>=0.6.0,<0.7.0 +requests[python]>=2.28.1 + pyodbc==4.0.32 sqlparams>=3.0.0 thrift>=0.13.0 diff --git a/tests/conftest.py b/tests/conftest.py index 0771566b7..2fa50d6c7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -60,6 +60,7 @@ def databricks_cluster_target(): "connect_retries": 3, "connect_timeout": 5, "retry_all": True, + "user": os.getenv('DBT_DATABRICKS_USER'), } @@ -91,6 +92,7 @@ def databricks_http_cluster_target(): "connect_retries": 5, "connect_timeout": 60, "retry_all": bool(os.getenv('DBT_DATABRICKS_RETRY_ALL', False)), + "user": os.getenv('DBT_DATABRICKS_USER'), } diff --git a/tests/functional/adapter/test_basic.py b/tests/functional/adapter/test_basic.py index e1a57fd3f..bdccf169d 100644 --- a/tests/functional/adapter/test_basic.py +++ b/tests/functional/adapter/test_basic.py @@ -79,7 +79,6 @@ def project_config_update(self): } } - @pytest.mark.skip_profile('spark_session') class TestBaseAdapterMethod(BaseAdapterMethod): pass diff --git a/tests/functional/adapter/test_python_model.py b/tests/functional/adapter/test_python_model.py new file mode 100644 index 000000000..059412f10 --- /dev/null +++ b/tests/functional/adapter/test_python_model.py @@ -0,0 +1,59 @@ +import os +import pytest +from dbt.tests.util import run_dbt, write_file, run_dbt_and_capture +from dbt.tests.adapter.python_model.test_python_model import BasePythonModelTests, BasePythonIncrementalTests + +@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint") +class TestPythonModelSpark(BasePythonModelTests): + pass + +@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint") +class TestPythonIncrementalModelSpark(BasePythonIncrementalTests): + @pytest.fixture(scope="class") + def project_config_update(self): + return {} + + +models__simple_python_model = """ +import pandas + +def model(dbt, spark): + dbt.config( + materialized='table', + ) + data = [[1,2]] * 10 + return spark.createDataFrame(data, schema=['test', 'test2']) +""" +models__simple_python_model_v2 = """ +import pandas + +def model(dbt, spark): + dbt.config( + materialized='table', + ) + data = [[1,2]] * 10 + return spark.createDataFrame(data, schema=['test1', 'test3']) +""" + + +@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint") +class TestChangingSchemaSpark: + @pytest.fixture(scope="class") + def models(self): + return {"simple_python_model.py": models__simple_python_model} + + def test_changing_schema_with_log_validation(self, project, logs_dir): + run_dbt(["run"]) + write_file( + models__simple_python_model_v2, + project.project_root + "/models", + "simple_python_model.py", + ) + run_dbt(["run"]) + log_file = os.path.join(logs_dir, 
"dbt.log") + with open(log_file, "r") as f: + log = f.read() + # validate #5510 log_code_execution works + assert "On model.test.simple_python_model:" in log + assert "spark.createDataFrame(data, schema=['test1', 'test3'])" in log + assert "Execution status: OK in" in log diff --git a/tests/integration/incremental_strategies/test_incremental_strategies.py b/tests/integration/incremental_strategies/test_incremental_strategies.py index 839f167e6..3848d11ae 100644 --- a/tests/integration/incremental_strategies/test_incremental_strategies.py +++ b/tests/integration/incremental_strategies/test_incremental_strategies.py @@ -60,6 +60,8 @@ def run_and_test(self): def test_insert_overwrite_apache_spark(self): self.run_and_test() + # This test requires settings on the test cluster + # more info at https://docs.getdbt.com/reference/resource-configs/spark-configs#the-insert_overwrite-strategy @use_profile("databricks_cluster") def test_insert_overwrite_databricks_cluster(self): self.run_and_test()