From d26bda46200051f514500644049bf6b040111258 Mon Sep 17 00:00:00 2001 From: Shahar Epstein Date: Mon, 11 Mar 2024 21:47:22 +0200 Subject: [PATCH] Rename `SparkSubmitOperator`'s fields' names to comply with templated fields validation --- .pre-commit-config.yaml | 4 +- .../apache/spark/operators/spark_jdbc.py | 34 +------ .../apache/spark/operators/spark_submit.py | 90 +++++++++---------- .../apache/spark/operators/test_spark_jdbc.py | 49 +++++++++- .../spark/operators/test_spark_submit.py | 75 ++++++++++++---- 5 files changed, 156 insertions(+), 96 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0ffa1f47a5e4e..6eebeacbed4e0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -334,9 +334,7 @@ repos: # https://github.com/apache/airflow/issues/36484 exclude: | (?x)^( - ^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service.py$| - ^airflow\/providers\/apache\/spark\/operators\/spark_submit.py\.py$| - ^airflow\/providers\/apache\/spark\/operators\/spark_submit\.py$| + ^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service.py$ )$ - id: ruff name: Run 'ruff' for extremely fast Python linting diff --git a/airflow/providers/apache/spark/operators/spark_jdbc.py b/airflow/providers/apache/spark/operators/spark_jdbc.py index e5ff5f9c65a44..465e23ab53c4d 100644 --- a/airflow/providers/apache/spark/operators/spark_jdbc.py +++ b/airflow/providers/apache/spark/operators/spark_jdbc.py @@ -44,14 +44,6 @@ class SparkJDBCOperator(SparkSubmitOperator): :param spark_files: Additional files to upload to the container running the job :param spark_jars: Additional jars to upload and add to the driver and executor classpath - :param num_executors: number of executor to run. This should be set so as to manage - the number of connections made with the JDBC database - :param executor_cores: Number of cores per executor - :param executor_memory: Memory per executor (e.g. 1000M, 2G) - :param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G) - :param verbose: Whether to pass the verbose flag to spark-submit for debugging - :param keytab: Full path to the file that contains the keytab - :param principal: The name of the kerberos principal used for keytab :param cmd_type: Which way the data should flow. 2 possible values: spark_to_jdbc: data written by spark from metastore to jdbc jdbc_to_spark: data written by spark from jdbc to metastore @@ -60,7 +52,7 @@ class SparkJDBCOperator(SparkSubmitOperator): :param jdbc_driver: Name of the JDBC driver to use for the JDBC connection. This driver (usually a jar) should be passed in the 'jars' parameter :param metastore_table: The name of the metastore table, - :param jdbc_truncate: (spark_to_jdbc only) Whether or not Spark should truncate or + :param jdbc_truncate: (spark_to_jdbc only) Whether Spark should truncate or drop and recreate the JDBC table. This only takes effect if 'save_mode' is set to Overwrite. Also, if the schema is different, Spark cannot truncate, and will drop and recreate @@ -91,9 +83,7 @@ class SparkJDBCOperator(SparkSubmitOperator): (e.g: "name CHAR(64), comments VARCHAR(1024)"). The specified types should be valid spark sql data types. - :param use_krb5ccache: if True, configure spark to use ticket cache instead of relying - on keytab for Kerberos login - + :param kwargs: kwargs passed to SparkSubmitOperator. 
""" def __init__( @@ -105,13 +95,6 @@ def __init__( spark_py_files: str | None = None, spark_files: str | None = None, spark_jars: str | None = None, - num_executors: int | None = None, - executor_cores: int | None = None, - executor_memory: str | None = None, - driver_memory: str | None = None, - verbose: bool = False, - principal: str | None = None, - keytab: str | None = None, cmd_type: str = "spark_to_jdbc", jdbc_table: str | None = None, jdbc_conn_id: str = "jdbc-default", @@ -127,7 +110,6 @@ def __init__( lower_bound: str | None = None, upper_bound: str | None = None, create_table_column_types: str | None = None, - use_krb5ccache: bool = False, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -137,13 +119,6 @@ def __init__( self._spark_py_files = spark_py_files self._spark_files = spark_files self._spark_jars = spark_jars - self._num_executors = num_executors - self._executor_cores = executor_cores - self._executor_memory = executor_memory - self._driver_memory = driver_memory - self._verbose = verbose - self._keytab = keytab - self._principal = principal self._cmd_type = cmd_type self._jdbc_table = jdbc_table self._jdbc_conn_id = jdbc_conn_id @@ -160,7 +135,6 @@ def __init__( self._upper_bound = upper_bound self._create_table_column_types = create_table_column_types self._hook: SparkJDBCHook | None = None - self._use_krb5ccache = use_krb5ccache def execute(self, context: Context) -> None: """Call the SparkSubmitHook to run the provided spark job.""" @@ -186,8 +160,8 @@ def _get_hook(self) -> SparkJDBCHook: executor_memory=self._executor_memory, driver_memory=self._driver_memory, verbose=self._verbose, - keytab=self._keytab, - principal=self._principal, + keytab=self.keytab, + principal=self.principal, cmd_type=self._cmd_type, jdbc_table=self._jdbc_table, jdbc_conn_id=self._jdbc_conn_id, diff --git a/airflow/providers/apache/spark/operators/spark_submit.py b/airflow/providers/apache/spark/operators/spark_submit.py index bd8480b8151ff..62f7918fcf993 100644 --- a/airflow/providers/apache/spark/operators/spark_submit.py +++ b/airflow/providers/apache/spark/operators/spark_submit.py @@ -81,21 +81,21 @@ class SparkSubmitOperator(BaseOperator): """ template_fields: Sequence[str] = ( - "_application", - "_conf", - "_files", - "_py_files", - "_jars", - "_driver_class_path", - "_packages", - "_exclude_packages", - "_keytab", - "_principal", - "_proxy_user", - "_name", - "_application_args", - "_env_vars", - "_properties_file", + "application", + "conf", + "files", + "py_files", + "jars", + "driver_class_path", + "packages", + "exclude_packages", + "keytab", + "principal", + "proxy_user", + "name", + "application_args", + "env_vars", + "properties_file", ) ui_color = WEB_COLORS["LIGHTORANGE"] @@ -135,32 +135,32 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) - self._application = application - self._conf = conf - self._files = files - self._py_files = py_files + self.application = application + self.conf = conf + self.files = files + self.py_files = py_files self._archives = archives - self._driver_class_path = driver_class_path - self._jars = jars + self.driver_class_path = driver_class_path + self.jars = jars self._java_class = java_class - self._packages = packages - self._exclude_packages = exclude_packages + self.packages = packages + self.exclude_packages = exclude_packages self._repositories = repositories self._total_executor_cores = total_executor_cores self._executor_cores = executor_cores self._executor_memory = executor_memory self._driver_memory = 
driver_memory - self._keytab = keytab - self._principal = principal - self._proxy_user = proxy_user - self._name = name + self.keytab = keytab + self.principal = principal + self.proxy_user = proxy_user + self.name = name self._num_executors = num_executors self._status_poll_interval = status_poll_interval - self._application_args = application_args - self._env_vars = env_vars + self.application_args = application_args + self.env_vars = env_vars self._verbose = verbose self._spark_binary = spark_binary - self._properties_file = properties_file + self.properties_file = properties_file self._queue = queue self._deploy_mode = deploy_mode self._hook: SparkSubmitHook | None = None @@ -171,7 +171,7 @@ def execute(self, context: Context) -> None: """Call the SparkSubmitHook to run the provided spark job.""" if self._hook is None: self._hook = self._get_hook() - self._hook.submit(self._application) + self._hook.submit(self.application) def on_kill(self) -> None: if self._hook is None: @@ -180,32 +180,32 @@ def on_kill(self) -> None: def _get_hook(self) -> SparkSubmitHook: return SparkSubmitHook( - conf=self._conf, + conf=self.conf, conn_id=self._conn_id, - files=self._files, - py_files=self._py_files, + files=self.files, + py_files=self.py_files, archives=self._archives, - driver_class_path=self._driver_class_path, - jars=self._jars, + driver_class_path=self.driver_class_path, + jars=self.jars, java_class=self._java_class, - packages=self._packages, - exclude_packages=self._exclude_packages, + packages=self.packages, + exclude_packages=self.exclude_packages, repositories=self._repositories, total_executor_cores=self._total_executor_cores, executor_cores=self._executor_cores, executor_memory=self._executor_memory, driver_memory=self._driver_memory, - keytab=self._keytab, - principal=self._principal, - proxy_user=self._proxy_user, - name=self._name, + keytab=self.keytab, + principal=self.principal, + proxy_user=self.proxy_user, + name=self.name, num_executors=self._num_executors, status_poll_interval=self._status_poll_interval, - application_args=self._application_args, - env_vars=self._env_vars, + application_args=self.application_args, + env_vars=self.env_vars, verbose=self._verbose, spark_binary=self._spark_binary, - properties_file=self._properties_file, + properties_file=self.properties_file, queue=self._queue, deploy_mode=self._deploy_mode, use_krb5ccache=self._use_krb5ccache, diff --git a/tests/providers/apache/spark/operators/test_spark_jdbc.py b/tests/providers/apache/spark/operators/test_spark_jdbc.py index bccc9e536753d..7ca8c22ab98fe 100644 --- a/tests/providers/apache/spark/operators/test_spark_jdbc.py +++ b/tests/providers/apache/spark/operators/test_spark_jdbc.py @@ -17,6 +17,8 @@ # under the License. 
from __future__ import annotations +import pytest + from airflow.models.dag import DAG from airflow.providers.apache.spark.operators.spark_jdbc import SparkJDBCOperator from airflow.utils import timezone @@ -111,8 +113,8 @@ def test_execute(self): assert expected_dict["executor_memory"] == operator._executor_memory assert expected_dict["driver_memory"] == operator._driver_memory assert expected_dict["verbose"] == operator._verbose - assert expected_dict["keytab"] == operator._keytab - assert expected_dict["principal"] == operator._principal + assert expected_dict["keytab"] == operator.keytab + assert expected_dict["principal"] == operator.principal assert expected_dict["cmd_type"] == operator._cmd_type assert expected_dict["jdbc_table"] == operator._jdbc_table assert expected_dict["jdbc_driver"] == operator._jdbc_driver @@ -128,3 +130,46 @@ def test_execute(self): assert expected_dict["upper_bound"] == operator._upper_bound assert expected_dict["create_table_column_types"] == operator._create_table_column_types assert expected_dict["use_krb5ccache"] == operator._use_krb5ccache + + @pytest.mark.db_test + def test_templating_with_create_task_instance_of_operator(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + SparkJDBCOperator, + # Templated fields + application="{{ 'application' }}", + conf="{{ 'conf' }}", + files="{{ 'files' }}", + py_files="{{ 'py-files' }}", + jars="{{ 'jars' }}", + driver_class_path="{{ 'driver_class_path' }}", + packages="{{ 'packages' }}", + exclude_packages="{{ 'exclude_packages' }}", + keytab="{{ 'keytab' }}", + principal="{{ 'principal' }}", + proxy_user="{{ 'proxy_user' }}", + name="{{ 'name' }}", + application_args="{{ 'application_args' }}", + env_vars="{{ 'env_vars' }}", + properties_file="{{ 'properties_file' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: SparkJDBCOperator = ti.task + assert task.application == "application" + assert task.conf == "conf" + assert task.files == "files" + assert task.py_files == "py-files" + assert task.jars == "jars" + assert task.driver_class_path == "driver_class_path" + assert task.packages == "packages" + assert task.exclude_packages == "exclude_packages" + assert task.keytab == "keytab" + assert task.principal == "principal" + assert task.proxy_user == "proxy_user" + assert task.name == "name" + assert task.application_args == "application_args" + assert task.env_vars == "env_vars" + assert task.properties_file == "properties_file" diff --git a/tests/providers/apache/spark/operators/test_spark_submit.py b/tests/providers/apache/spark/operators/test_spark_submit.py index 3330ca88ddf64..4f8cb7d5486de 100644 --- a/tests/providers/apache/spark/operators/test_spark_submit.py +++ b/tests/providers/apache/spark/operators/test_spark_submit.py @@ -129,33 +129,33 @@ def test_execute(self): } assert conn_id == operator._conn_id - assert expected_dict["application"] == operator._application - assert expected_dict["conf"] == operator._conf - assert expected_dict["files"] == operator._files - assert expected_dict["py_files"] == operator._py_files + assert expected_dict["application"] == operator.application + assert expected_dict["conf"] == operator.conf + assert expected_dict["files"] == operator.files + assert expected_dict["py_files"] == operator.py_files assert expected_dict["archives"] == operator._archives - assert 
expected_dict["driver_class_path"] == operator._driver_class_path - assert expected_dict["jars"] == operator._jars - assert expected_dict["packages"] == operator._packages - assert expected_dict["exclude_packages"] == operator._exclude_packages + assert expected_dict["driver_class_path"] == operator.driver_class_path + assert expected_dict["jars"] == operator.jars + assert expected_dict["packages"] == operator.packages + assert expected_dict["exclude_packages"] == operator.exclude_packages assert expected_dict["repositories"] == operator._repositories assert expected_dict["total_executor_cores"] == operator._total_executor_cores assert expected_dict["executor_cores"] == operator._executor_cores assert expected_dict["executor_memory"] == operator._executor_memory - assert expected_dict["keytab"] == operator._keytab - assert expected_dict["principal"] == operator._principal - assert expected_dict["proxy_user"] == operator._proxy_user - assert expected_dict["name"] == operator._name + assert expected_dict["keytab"] == operator.keytab + assert expected_dict["principal"] == operator.principal + assert expected_dict["proxy_user"] == operator.proxy_user + assert expected_dict["name"] == operator.name assert expected_dict["num_executors"] == operator._num_executors assert expected_dict["status_poll_interval"] == operator._status_poll_interval assert expected_dict["verbose"] == operator._verbose assert expected_dict["java_class"] == operator._java_class assert expected_dict["driver_memory"] == operator._driver_memory - assert expected_dict["application_args"] == operator._application_args + assert expected_dict["application_args"] == operator.application_args assert expected_dict["spark_binary"] == operator._spark_binary assert expected_dict["queue"] == operator._queue assert expected_dict["deploy_mode"] == operator._deploy_mode - assert expected_dict["properties_file"] == operator._properties_file + assert expected_dict["properties_file"] == operator.properties_file assert expected_dict["use_krb5ccache"] == operator._use_krb5ccache @pytest.mark.db_test @@ -205,5 +205,48 @@ def test_render_template(self): "args should keep embedded spaces", ] expected_name = "spark_submit_job" - assert expected_application_args == getattr(operator, "_application_args") - assert expected_name == getattr(operator, "_name") + assert expected_application_args == getattr(operator, "application_args") + assert expected_name == getattr(operator, "name") + + @pytest.mark.db_test + def test_templating_with_create_task_instance_of_operator(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + SparkSubmitOperator, + # Templated fields + application="{{ 'application' }}", + conf="{{ 'conf' }}", + files="{{ 'files' }}", + py_files="{{ 'py-files' }}", + jars="{{ 'jars' }}", + driver_class_path="{{ 'driver_class_path' }}", + packages="{{ 'packages' }}", + exclude_packages="{{ 'exclude_packages' }}", + keytab="{{ 'keytab' }}", + principal="{{ 'principal' }}", + proxy_user="{{ 'proxy_user' }}", + name="{{ 'name' }}", + application_args="{{ 'application_args' }}", + env_vars="{{ 'env_vars' }}", + properties_file="{{ 'properties_file' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: SparkSubmitOperator = ti.task + assert task.application == "application" + assert task.conf == "conf" + assert task.files == "files" + assert 
task.py_files == "py-files" + assert task.jars == "jars" + assert task.driver_class_path == "driver_class_path" + assert task.packages == "packages" + assert task.exclude_packages == "exclude_packages" + assert task.keytab == "keytab" + assert task.principal == "principal" + assert task.proxy_user == "proxy_user" + assert task.name == "name" + assert task.application_args == "application_args" + assert task.env_vars == "env_vars" + assert task.properties_file == "properties_file"