Rename SparkSubmitOperator/SparkJDBCOperator fields' names to comply with templated fields validation #38051

Merged: 1 commit, Mar 16, 2024
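Background on the change: Airflow renders every name listed in an operator's `template_fields` by reading that attribute off the task instance and writing the rendered value back, and the templated-field validation check (apache/airflow#36484, referenced in the pre-commit config below) expects each listed name to be assigned in `__init__` under exactly that name. The snippet below is a minimal, simplified sketch of that round trip, not Airflow's actual rendering code; it only illustrates why the private `_application`-style attributes are renamed to public `application`-style ones in this PR.

```python
# Illustrative sketch only, NOT Airflow's real rendering code.
# It shows why every name in template_fields must be an attribute that
# __init__ assigns under the same name: rendering is a getattr/setattr round trip.

class MiniOperator:
    template_fields = ("application", "name")

    def __init__(self, application: str, name: str) -> None:
        # After this PR the Spark operators store these under the same
        # public names that template_fields declares.
        self.application = application
        self.name = name


def render_template_fields(task, context: dict) -> None:
    """Render each declared field in place using a naive '{{ key }}' substitution."""
    for field in task.template_fields:
        value = getattr(task, field)  # the declared name must be a real attribute on the task
        if isinstance(value, str):
            for key, replacement in context.items():
                value = value.replace("{{ " + key + " }}", str(replacement))
            setattr(task, field, value)


task = MiniOperator(application="/apps/{{ ds }}/job.py", name="job-{{ ds }}")
render_template_fields(task, {"ds": "2024-03-16"})
assert task.application == "/apps/2024-03-16/job.py"
assert task.name == "job-2024-03-16"
```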
4 changes: 1 addition & 3 deletions .pre-commit-config.yaml
@@ -334,9 +334,7 @@ repos:
# https://github.com/apache/airflow/issues/36484
exclude: |
(?x)^(
^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service.py$|
^airflow\/providers\/apache\/spark\/operators\/spark_submit.py\.py$|
^airflow\/providers\/apache\/spark\/operators\/spark_submit\.py$|
^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service.py$
)$
- id: ruff
name: Run 'ruff' for extremely fast Python linting
34 changes: 4 additions & 30 deletions airflow/providers/apache/spark/operators/spark_jdbc.py
@@ -44,14 +44,6 @@ class SparkJDBCOperator(SparkSubmitOperator):
:param spark_files: Additional files to upload to the container running the job
:param spark_jars: Additional jars to upload and add to the driver and
executor classpath
:param num_executors: number of executor to run. This should be set so as to manage
the number of connections made with the JDBC database
:param executor_cores: Number of cores per executor
:param executor_memory: Memory per executor (e.g. 1000M, 2G)
:param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G)
:param verbose: Whether to pass the verbose flag to spark-submit for debugging
:param keytab: Full path to the file that contains the keytab
:param principal: The name of the kerberos principal used for keytab
:param cmd_type: Which way the data should flow. 2 possible values:
spark_to_jdbc: data written by spark from metastore to jdbc
jdbc_to_spark: data written by spark from jdbc to metastore
@@ -60,7 +52,7 @@ class SparkJDBCOperator(SparkSubmitOperator):
:param jdbc_driver: Name of the JDBC driver to use for the JDBC connection. This
driver (usually a jar) should be passed in the 'jars' parameter
:param metastore_table: The name of the metastore table,
:param jdbc_truncate: (spark_to_jdbc only) Whether or not Spark should truncate or
:param jdbc_truncate: (spark_to_jdbc only) Whether Spark should truncate or
drop and recreate the JDBC table. This only takes effect if
'save_mode' is set to Overwrite. Also, if the schema is
different, Spark cannot truncate, and will drop and recreate
@@ -91,9 +83,7 @@ class SparkJDBCOperator(SparkSubmitOperator):
(e.g: "name CHAR(64), comments VARCHAR(1024)").
The specified types should be valid spark sql data
types.
:param use_krb5ccache: if True, configure spark to use ticket cache instead of relying
on keytab for Kerberos login

:param kwargs: kwargs passed to SparkSubmitOperator.
"""

def __init__(
@@ -105,13 +95,6 @@ def __init__(
spark_py_files: str | None = None,
spark_files: str | None = None,
spark_jars: str | None = None,
num_executors: int | None = None,
executor_cores: int | None = None,
executor_memory: str | None = None,
driver_memory: str | None = None,
verbose: bool = False,
principal: str | None = None,
keytab: str | None = None,
cmd_type: str = "spark_to_jdbc",
jdbc_table: str | None = None,
jdbc_conn_id: str = "jdbc-default",
@@ -127,7 +110,6 @@ def __init__(
lower_bound: str | None = None,
upper_bound: str | None = None,
create_table_column_types: str | None = None,
use_krb5ccache: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -137,13 +119,6 @@ def __init__(
self._spark_py_files = spark_py_files
self._spark_files = spark_files
self._spark_jars = spark_jars
self._num_executors = num_executors
self._executor_cores = executor_cores
self._executor_memory = executor_memory
self._driver_memory = driver_memory
self._verbose = verbose
self._keytab = keytab
self._principal = principal
self._cmd_type = cmd_type
self._jdbc_table = jdbc_table
self._jdbc_conn_id = jdbc_conn_id
@@ -160,7 +135,6 @@ def __init__(
self._upper_bound = upper_bound
self._create_table_column_types = create_table_column_types
self._hook: SparkJDBCHook | None = None
self._use_krb5ccache = use_krb5ccache

def execute(self, context: Context) -> None:
"""Call the SparkSubmitHook to run the provided spark job."""
@@ -186,8 +160,8 @@ def _get_hook(self) -> SparkJDBCHook:
executor_memory=self._executor_memory,
driver_memory=self._driver_memory,
verbose=self._verbose,
keytab=self._keytab,
principal=self._principal,
keytab=self.keytab,
principal=self.principal,
cmd_type=self._cmd_type,
jdbc_table=self._jdbc_table,
jdbc_conn_id=self._jdbc_conn_id,
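For reference, a minimal usage sketch of `SparkJDBCOperator` after the rename: the `keytab`, `principal`, and `name` values now land on the public templated attributes inherited from `SparkSubmitOperator` via `**kwargs`. The DAG id, connection ids, table names, paths, and Airflow Variables used here are placeholders, not values taken from this PR.

```python
# Hypothetical DAG snippet; all ids, paths, and table names are assumptions.
from airflow.models.dag import DAG
from airflow.providers.apache.spark.operators.spark_jdbc import SparkJDBCOperator
from airflow.utils import timezone

with DAG(
    dag_id="spark_jdbc_templating_example",
    start_date=timezone.datetime(2024, 3, 16),
    schedule=None,
) as dag:
    metastore_to_jdbc = SparkJDBCOperator(
        task_id="metastore_to_jdbc",
        cmd_type="spark_to_jdbc",
        jdbc_conn_id="jdbc-default",
        metastore_table="analytics.events",
        jdbc_table="events",
        # Passed through **kwargs to SparkSubmitOperator and rendered via the
        # renamed public template fields "keytab", "principal", and "name".
        keytab="/security/{{ var.value.spark_keytab }}",
        principal="{{ var.value.spark_principal }}",
        name="jdbc-load-{{ ds }}",
    )
```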
90 changes: 45 additions & 45 deletions airflow/providers/apache/spark/operators/spark_submit.py
@@ -81,21 +81,21 @@ class SparkSubmitOperator(BaseOperator):
"""

template_fields: Sequence[str] = (
"_application",
"_conf",
"_files",
"_py_files",
"_jars",
"_driver_class_path",
"_packages",
"_exclude_packages",
"_keytab",
"_principal",
"_proxy_user",
"_name",
"_application_args",
"_env_vars",
"_properties_file",
"application",
"conf",
"files",
"py_files",
"jars",
"driver_class_path",
"packages",
"exclude_packages",
"keytab",
"principal",
"proxy_user",
"name",
"application_args",
"env_vars",
"properties_file",
)
ui_color = WEB_COLORS["LIGHTORANGE"]

@@ -135,32 +135,32 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self._application = application
self._conf = conf
self._files = files
self._py_files = py_files
self.application = application
self.conf = conf
self.files = files
self.py_files = py_files
self._archives = archives
self._driver_class_path = driver_class_path
self._jars = jars
self.driver_class_path = driver_class_path
self.jars = jars
self._java_class = java_class
self._packages = packages
self._exclude_packages = exclude_packages
self.packages = packages
self.exclude_packages = exclude_packages
self._repositories = repositories
self._total_executor_cores = total_executor_cores
self._executor_cores = executor_cores
self._executor_memory = executor_memory
self._driver_memory = driver_memory
self._keytab = keytab
self._principal = principal
self._proxy_user = proxy_user
self._name = name
self.keytab = keytab
self.principal = principal
self.proxy_user = proxy_user
self.name = name
self._num_executors = num_executors
self._status_poll_interval = status_poll_interval
self._application_args = application_args
self._env_vars = env_vars
self.application_args = application_args
self.env_vars = env_vars
self._verbose = verbose
self._spark_binary = spark_binary
self._properties_file = properties_file
self.properties_file = properties_file
self._queue = queue
self._deploy_mode = deploy_mode
self._hook: SparkSubmitHook | None = None
@@ -171,7 +171,7 @@ def execute(self, context: Context) -> None:
"""Call the SparkSubmitHook to run the provided spark job."""
if self._hook is None:
self._hook = self._get_hook()
self._hook.submit(self._application)
self._hook.submit(self.application)

def on_kill(self) -> None:
if self._hook is None:
@@ -180,32 +180,32 @@ def on_kill(self) -> None:

def _get_hook(self) -> SparkSubmitHook:
return SparkSubmitHook(
conf=self._conf,
conf=self.conf,
conn_id=self._conn_id,
files=self._files,
py_files=self._py_files,
files=self.files,
py_files=self.py_files,
archives=self._archives,
driver_class_path=self._driver_class_path,
jars=self._jars,
driver_class_path=self.driver_class_path,
jars=self.jars,
java_class=self._java_class,
packages=self._packages,
exclude_packages=self._exclude_packages,
packages=self.packages,
exclude_packages=self.exclude_packages,
repositories=self._repositories,
total_executor_cores=self._total_executor_cores,
executor_cores=self._executor_cores,
executor_memory=self._executor_memory,
driver_memory=self._driver_memory,
keytab=self._keytab,
principal=self._principal,
proxy_user=self._proxy_user,
name=self._name,
keytab=self.keytab,
principal=self.principal,
proxy_user=self.proxy_user,
name=self.name,
num_executors=self._num_executors,
status_poll_interval=self._status_poll_interval,
application_args=self._application_args,
env_vars=self._env_vars,
application_args=self.application_args,
env_vars=self.env_vars,
verbose=self._verbose,
spark_binary=self._spark_binary,
properties_file=self._properties_file,
properties_file=self.properties_file,
queue=self._queue,
deploy_mode=self._deploy_mode,
use_krb5ccache=self._use_krb5ccache,
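Similarly, a minimal sketch of templating the base `SparkSubmitOperator` through the renamed public fields; the application path, connection id, and argument values are placeholders.

```python
# Hypothetical usage; the application path and conn_id are assumptions.
from airflow.models.dag import DAG
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator
from airflow.utils import timezone

with DAG(
    dag_id="spark_submit_templating_example",
    start_date=timezone.datetime(2024, 3, 16),
    schedule=None,
) as dag:
    submit_etl = SparkSubmitOperator(
        task_id="submit_etl",
        conn_id="spark_default",
        application="/jobs/{{ ds_nodash }}/etl.py",  # rendered through the public "application" field
        name="etl-{{ ds }}",
        application_args=["--run-date", "{{ ds }}"],
        env_vars={"RUN_DATE": "{{ ds }}"},
    )
```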
49 changes: 47 additions & 2 deletions tests/providers/apache/spark/operators/test_spark_jdbc.py
@@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations

import pytest

from airflow.models.dag import DAG
from airflow.providers.apache.spark.operators.spark_jdbc import SparkJDBCOperator
from airflow.utils import timezone
@@ -111,8 +113,8 @@ def test_execute(self):
assert expected_dict["executor_memory"] == operator._executor_memory
assert expected_dict["driver_memory"] == operator._driver_memory
assert expected_dict["verbose"] == operator._verbose
assert expected_dict["keytab"] == operator._keytab
assert expected_dict["principal"] == operator._principal
assert expected_dict["keytab"] == operator.keytab
assert expected_dict["principal"] == operator.principal
assert expected_dict["cmd_type"] == operator._cmd_type
assert expected_dict["jdbc_table"] == operator._jdbc_table
assert expected_dict["jdbc_driver"] == operator._jdbc_driver
@@ -128,3 +130,46 @@ def test_execute(self):
assert expected_dict["upper_bound"] == operator._upper_bound
assert expected_dict["create_table_column_types"] == operator._create_table_column_types
assert expected_dict["use_krb5ccache"] == operator._use_krb5ccache

@pytest.mark.db_test
def test_templating_with_create_task_instance_of_operator(self, create_task_instance_of_operator):
ti = create_task_instance_of_operator(
SparkJDBCOperator,
# Templated fields
application="{{ 'application' }}",
conf="{{ 'conf' }}",
files="{{ 'files' }}",
py_files="{{ 'py-files' }}",
jars="{{ 'jars' }}",
driver_class_path="{{ 'driver_class_path' }}",
packages="{{ 'packages' }}",
exclude_packages="{{ 'exclude_packages' }}",
keytab="{{ 'keytab' }}",
principal="{{ 'principal' }}",
proxy_user="{{ 'proxy_user' }}",
name="{{ 'name' }}",
application_args="{{ 'application_args' }}",
env_vars="{{ 'env_vars' }}",
properties_file="{{ 'properties_file' }}",
# Other parameters
dag_id="test_template_body_templating_dag",
task_id="test_template_body_templating_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
ti.render_templates()
task: SparkJDBCOperator = ti.task
assert task.application == "application"
assert task.conf == "conf"
assert task.files == "files"
assert task.py_files == "py-files"
assert task.jars == "jars"
assert task.driver_class_path == "driver_class_path"
assert task.packages == "packages"
assert task.exclude_packages == "exclude_packages"
assert task.keytab == "keytab"
assert task.principal == "principal"
assert task.proxy_user == "proxy_user"
assert task.name == "name"
assert task.application_args == "application_args"
assert task.env_vars == "env_vars"
assert task.properties_file == "properties_file"