From 649f409e39b033494ea2c1877e9111399db960b5 Mon Sep 17 00:00:00 2001
From: Pankaj Koti
Date: Fri, 26 Jul 2024 18:46:59 +0530
Subject: [PATCH] Revert "Fix named parameters templating in Databricks
 operators (#40864)"

This reverts commit cfe1d53ed041ea903292e3789e1a5238db5b5031.
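
With this revert, template_fields on DatabricksCreateJobsOperator,
DatabricksSubmitRunOperator and DatabricksRunNowOperator goes back to
("json", "databricks_conn_id"), so Jinja expressions passed through the
other named constructor parameters are no longer rendered. A minimal
illustrative sketch of the resulting behaviour (the DAG and task names
below are placeholders, not part of this patch):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator

    with DAG("example_dag", start_date=datetime(2024, 1, 1), schedule=None):
        # Stays literal after this revert: "job_id" is not a template field.
        DatabricksRunNowOperator(task_id="run_literal", job_id="42-{{ ds }}")

        # Still rendered: "json" remains a template field and is templated
        # recursively, so nested values keep working.
        DatabricksRunNowOperator(task_id="run_templated", json={"job_id": "42-{{ ds }}"})

Templating via the top-level json argument (or a .json-tpl file, given the
operators' template_ext) continues to work as before.
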
"existing_cluster_id", - "libraries", - "run_name", - "timeout_seconds", - "idempotency_token", - "access_control_list", - "git_source", - ) + template_fields: Sequence[str] = ("json", "databricks_conn_id") template_ext: Sequence[str] = (".json-tpl",) # Databricks brand color (blue) under white text ui_color = "#1CB1C2" @@ -564,21 +516,23 @@ def __init__( self.databricks_retry_args = databricks_retry_args self.wait_for_termination = wait_for_termination self.deferrable = deferrable - self.tasks = tasks - self.spark_jar_task = spark_jar_task - self.notebook_task = notebook_task - self.spark_python_task = spark_python_task - self.spark_submit_task = spark_submit_task - self.pipeline_task = pipeline_task - self.dbt_task = dbt_task - self.new_cluster = new_cluster - self.existing_cluster_id = existing_cluster_id - self.libraries = libraries - self.run_name = run_name - self.timeout_seconds = timeout_seconds - self.idempotency_token = idempotency_token - self.access_control_list = access_control_list - self.git_source = git_source + self.overridden_json_params = { + "tasks": tasks, + "spark_jar_task": spark_jar_task, + "notebook_task": notebook_task, + "spark_python_task": spark_python_task, + "spark_submit_task": spark_submit_task, + "pipeline_task": pipeline_task, + "dbt_task": dbt_task, + "new_cluster": new_cluster, + "existing_cluster_id": existing_cluster_id, + "libraries": libraries, + "run_name": run_name, + "timeout_seconds": timeout_seconds, + "idempotency_token": idempotency_token, + "access_control_list": access_control_list, + "git_source": git_source, + } # This variable will be used in case our task gets killed. self.run_id: int | None = None @@ -598,24 +552,6 @@ def _get_hook(self, caller: str) -> DatabricksHook: ) def _setup_and_validate_json(self): - self.overridden_json_params = { - "tasks": self.tasks, - "spark_jar_task": self.spark_jar_task, - "notebook_task": self.notebook_task, - "spark_python_task": self.spark_python_task, - "spark_submit_task": self.spark_submit_task, - "pipeline_task": self.pipeline_task, - "dbt_task": self.dbt_task, - "new_cluster": self.new_cluster, - "existing_cluster_id": self.existing_cluster_id, - "libraries": self.libraries, - "run_name": self.run_name, - "timeout_seconds": self.timeout_seconds, - "idempotency_token": self.idempotency_token, - "access_control_list": self.access_control_list, - "git_source": self.git_source, - } - _handle_overridden_json_params(self) if "run_name" not in self.json or self.json["run_name"] is None: @@ -836,18 +772,7 @@ class DatabricksRunNowOperator(BaseOperator): """ # Used in airflow.models.BaseOperator - template_fields: Sequence[str] = ( - "json", - "databricks_conn_id", - "job_id", - "job_name", - "notebook_params", - "python_params", - "python_named_params", - "jar_params", - "spark_submit_params", - "idempotency_token", - ) + template_fields: Sequence[str] = ("json", "databricks_conn_id") template_ext: Sequence[str] = (".json-tpl",) # Databricks brand color (blue) under white text ui_color = "#1CB1C2" @@ -890,14 +815,16 @@ def __init__( self.deferrable = deferrable self.repair_run = repair_run self.cancel_previous_runs = cancel_previous_runs - self.job_id = job_id - self.job_name = job_name - self.notebook_params = notebook_params - self.python_params = python_params - self.python_named_params = python_named_params - self.jar_params = jar_params - self.spark_submit_params = spark_submit_params - self.idempotency_token = idempotency_token + self.overridden_json_params = { + "job_id": job_id, + "job_name": 
job_name, + "notebook_params": notebook_params, + "python_params": python_params, + "python_named_params": python_named_params, + "jar_params": jar_params, + "spark_submit_params": spark_submit_params, + "idempotency_token": idempotency_token, + } # This variable will be used in case our task gets killed. self.run_id: int | None = None self.do_xcom_push = do_xcom_push @@ -916,16 +843,6 @@ def _get_hook(self, caller: str) -> DatabricksHook: ) def _setup_and_validate_json(self): - self.overridden_json_params = { - "job_id": self.job_id, - "job_name": self.job_name, - "notebook_params": self.notebook_params, - "python_params": self.python_params, - "python_named_params": self.python_named_params, - "jar_params": self.jar_params, - "spark_submit_params": self.spark_submit_params, - "idempotency_token": self.idempotency_token, - } _handle_overridden_json_params(self) if "job_id" in self.json and "job_name" in self.json: diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index a7337669047cb..ae2bb4976669c 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -66,11 +66,7 @@ RUN_ID = 1 RUN_PAGE_URL = "run-page-url" JOB_ID = "42" -TEMPLATED_JOB_ID = "job-id-{{ ds }}" -RENDERED_TEMPLATED_JOB_ID = f"job-id-{DATE}" JOB_NAME = "job-name" -TEMPLATED_JOB_NAME = "job-name-{{ ds }}" -RENDERED_TEMPLATED_JOB_NAME = f"job-name-{DATE}" JOB_DESCRIPTION = "job-description" NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"} JAR_PARAMS = ["param1", "param2"] @@ -487,68 +483,6 @@ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): db_mock.create_job.assert_called_once_with(expected) assert JOB_ID == tis[TASK_ID].xcom_pull() - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_validate_json_with_templated_named_param(self, db_mock_class, dag_maker): - json = "{{ ti.xcom_pull(task_ids='push_json') }}" - with dag_maker("test_templated", render_template_as_native_obj=True): - push_json = PythonOperator( - task_id="push_json", - python_callable=lambda: { - "description": JOB_DESCRIPTION, - "tags": TAGS, - "tasks": TASKS, - "job_clusters": JOB_CLUSTERS, - "email_notifications": EMAIL_NOTIFICATIONS, - "webhook_notifications": WEBHOOK_NOTIFICATIONS, - "notification_settings": NOTIFICATION_SETTINGS, - "timeout_seconds": TIMEOUT_SECONDS, - "schedule": SCHEDULE, - "max_concurrent_runs": MAX_CONCURRENT_RUNS, - "git_source": GIT_SOURCE, - "access_control_list": ACCESS_CONTROL_LIST, - }, - ) - op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json, name=TEMPLATED_JOB_NAME) - push_json >> op - - db_mock = db_mock_class.return_value - db_mock.create_job.return_value = JOB_ID - - db_mock.find_job_id_by_name.return_value = None - - dagrun = dag_maker.create_dagrun(execution_date=datetime.strptime(DATE, "%Y-%m-%d")) - tis = {ti.task_id: ti for ti in dagrun.task_instances} - tis["push_json"].run() - tis[TASK_ID].run() - - expected = utils._normalise_json_content( - { - "name": RENDERED_TEMPLATED_JOB_NAME, - "description": JOB_DESCRIPTION, - "tags": TAGS, - "tasks": TASKS, - "job_clusters": JOB_CLUSTERS, - "email_notifications": EMAIL_NOTIFICATIONS, - "webhook_notifications": WEBHOOK_NOTIFICATIONS, - "notification_settings": NOTIFICATION_SETTINGS, - "timeout_seconds": TIMEOUT_SECONDS, - "schedule": SCHEDULE, - "max_concurrent_runs": MAX_CONCURRENT_RUNS, - "git_source": GIT_SOURCE, - 
"access_control_list": ACCESS_CONTROL_LIST, - } - ) - db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay, - retry_args=None, - caller="DatabricksCreateJobsOperator", - ) - - db_mock.create_job.assert_called_once_with(expected) - assert JOB_ID == tis[TASK_ID].xcom_pull() - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker): with dag_maker("test_xcomarg", render_template_as_native_obj=True): @@ -1081,50 +1015,6 @@ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run.assert_called_once_with(RUN_ID) - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_validate_json_with_templated_named_params(self, db_mock_class, dag_maker): - json = "{{ ti.xcom_pull(task_ids='push_json') }}" - with dag_maker("test_templated", render_template_as_native_obj=True): - push_json = PythonOperator( - task_id="push_json", - python_callable=lambda: { - "new_cluster": NEW_CLUSTER, - }, - ) - op = DatabricksSubmitRunOperator( - task_id=TASK_ID, json=json, notebook_task=TEMPLATED_NOTEBOOK_TASK - ) - push_json >> op - - db_mock = db_mock_class.return_value - db_mock.submit_run.return_value = RUN_ID - db_mock.get_run_page_url.return_value = RUN_PAGE_URL - db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") - - dagrun = dag_maker.create_dagrun(execution_date=datetime.strptime(DATE, "%Y-%m-%d")) - tis = {ti.task_id: ti for ti in dagrun.task_instances} - tis["push_json"].run() - tis[TASK_ID].run() - - expected = utils._normalise_json_content( - { - "new_cluster": NEW_CLUSTER, - "notebook_task": RENDERED_TEMPLATED_NOTEBOOK_TASK, - "run_name": TASK_ID, - } - ) - db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay, - retry_args=None, - caller="DatabricksSubmitRunOperator", - ) - - db_mock.submit_run.assert_called_once_with(expected) - db_mock.get_run_page_url.assert_called_once_with(RUN_ID) - db_mock.get_run.assert_called_once_with(RUN_ID) - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker): with dag_maker("test_xcomarg", render_template_as_native_obj=True): @@ -1644,52 +1534,6 @@ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run.assert_called_once_with(RUN_ID) - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_validate_json_with_templated_named_params(self, db_mock_class, dag_maker): - json = "{{ ti.xcom_pull(task_ids='push_json') }}" - with dag_maker("test_templated", render_template_as_native_obj=True): - push_json = PythonOperator( - task_id="push_json", - python_callable=lambda: { - "notebook_params": NOTEBOOK_PARAMS, - "notebook_task": NOTEBOOK_TASK, - }, - ) - op = DatabricksRunNowOperator( - task_id=TASK_ID, job_id=TEMPLATED_JOB_ID, jar_params=TEMPLATED_JAR_PARAMS, json=json - ) - push_json >> op - - db_mock = db_mock_class.return_value - db_mock.run_now.return_value = RUN_ID - db_mock.get_run_page_url.return_value = RUN_PAGE_URL - db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") - - dagrun = 
dag_maker.create_dagrun(execution_date=datetime.strptime(DATE, "%Y-%m-%d")) - tis = {ti.task_id: ti for ti in dagrun.task_instances} - tis["push_json"].run() - tis[TASK_ID].run() - - expected = utils._normalise_json_content( - { - "notebook_params": NOTEBOOK_PARAMS, - "notebook_task": NOTEBOOK_TASK, - "jar_params": RENDERED_TEMPLATED_JAR_PARAMS, - "job_id": RENDERED_TEMPLATED_JOB_ID, - } - ) - - db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay, - retry_args=None, - caller="DatabricksRunNowOperator", - ) - db_mock.run_now.assert_called_once_with(expected) - db_mock.get_run_page_url.assert_called_once_with(RUN_ID) - db_mock.get_run.assert_called_once_with(RUN_ID) - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker): with dag_maker("test_xcomarg", render_template_as_native_obj=True):