Create dataproc serverless spark batches operator #19248

Merged
49 changes: 49 additions & 0 deletions airflow/providers/google/cloud/example_dags/example_dataproc.py
@@ -26,10 +26,14 @@
from airflow import models
from airflow.providers.google.cloud.operators.dataproc import (
ClusterGenerator,
DataprocCreateBatchOperator,
DataprocCreateClusterOperator,
DataprocCreateWorkflowTemplateOperator,
DataprocDeleteBatchOperator,
DataprocDeleteClusterOperator,
DataprocGetBatchOperator,
DataprocInstantiateWorkflowTemplateOperator,
DataprocListBatchesOperator,
DataprocSubmitJobOperator,
DataprocUpdateClusterOperator,
)
@@ -174,6 +178,13 @@
},
"jobs": [{"step_id": "pig_job_1", "pig_job": PIG_JOB["pig_job"]}],
}
BATCH_ID = "test-batch-id"
BATCH_CONFIG = {
"spark_batch": {
"jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
"main_class": "org.apache.spark.examples.SparkPi",
},
}
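# The Batch message accepts other workload types besides `spark_batch`, e.g.
# `pyspark_batch`. A minimal sketch, not part of this PR (the GCS URI below is
# illustrative):
# BATCH_CONFIG_PYSPARK = {
#     "pyspark_batch": {
#         "main_python_file_uri": "gs://<your-bucket>/spark_job.py",
#     },
# }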


with models.DAG(
@@ -282,3 +293,41 @@

# Task dependency created via `XComArgs`:
# spark_task_async >> spark_task_async_sensor

with models.DAG(
"example_gcp_batch_dataproc",
schedule_interval='@once',
start_date=datetime(2021, 1, 1),
catchup=False,
) as dag_batch:
# [START how_to_cloud_dataproc_create_batch_operator]
create_batch = DataprocCreateBatchOperator(
task_id="create_batch",
project_id=PROJECT_ID,
region=REGION,
batch=BATCH_CONFIG,
batch_id=BATCH_ID,
)
# [END how_to_cloud_dataproc_create_batch_operator]

# [START how_to_cloud_dataproc_get_batch_operator]
get_batch = DataprocGetBatchOperator(
task_id="get_batch", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID
)
# [END how_to_cloud_dataproc_get_batch_operator]

# [START how_to_cloud_dataproc_list_batches_operator]
list_batches = DataprocListBatchesOperator(
task_id="list_batches",
project_id=PROJECT_ID,
region=REGION,
)
# [END how_to_cloud_dataproc_list_batches_operator]

# [START how_to_cloud_dataproc_delete_batch_operator]
delete_batch = DataprocDeleteBatchOperator(
task_id="delete_batch", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID
)
# [END how_to_cloud_dataproc_delete_batch_operator]

create_batch >> get_batch >> list_batches >> delete_batch
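# A hedged sketch (not part of this example DAG) of polling a batch's terminal
# state directly through the new hook; `Batch.State` is the enum shipped with
# `google.cloud.dataproc_v1`:
#
# from google.cloud.dataproc_v1 import Batch
# from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
#
# def batch_is_done(batch_id: str) -> bool:
#     batch = DataprocHook().get_batch(
#         batch_id=batch_id, region=REGION, project_id=PROJECT_ID
#     )
#     return batch.state in (
#         Batch.State.SUCCEEDED, Batch.State.FAILED, Batch.State.CANCELLED
#     )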
219 changes: 219 additions & 0 deletions airflow/providers/google/cloud/hooks/dataproc.py
@@ -24,8 +24,11 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

from google.api_core.exceptions import ServerError
from google.api_core.operation import Operation
from google.api_core.retry import Retry
from google.cloud.dataproc_v1 import (
Batch,
BatchControllerClient,
Cluster,
ClusterControllerClient,
Job,
@@ -267,6 +270,34 @@ def get_job_client(
credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
)

def get_batch_client(
self, region: Optional[str] = None, location: Optional[str] = None
) -> BatchControllerClient:
"""Returns BatchControllerClient"""
if location is not None:
warnings.warn(
"Parameter `location` will be deprecated. "
"Please provide value through `region` parameter instead.",
DeprecationWarning,
stacklevel=2,
)
region = location
client_options = None
if region and region != 'global':
client_options = {'api_endpoint': f'{region}-dataproc.googleapis.com:443'}

return BatchControllerClient(
credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
)
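# Example of the endpoint resolution above (a sketch, assuming default
# credentials): get_batch_client(region="us-central1") returns a client bound
# to "us-central1-dataproc.googleapis.com:443", while region="global" or no
# region at all falls back to the library's default endpoint.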

def wait_for_operation(self, timeout: float, operation: Operation):
"""Waits for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@GoogleBaseHook.fallback_to_default_project_id
def create_cluster(
self,
@@ -1030,3 +1061,191 @@ def cancel_job(
metadata=metadata,
)
return job

@GoogleBaseHook.fallback_to_default_project_id
def create_batch(
self,
region: str,
project_id: str,
batch: Union[Dict, Batch],
batch_id: Optional[str] = None,
request_id: Optional[str] = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = None,
):
"""
Creates a batch workload.

:param project_id: Required. The ID of the Google Cloud project that the batch belongs to.
:type project_id: str
:param region: Required. The Cloud Dataproc region in which to handle the request.
:type region: str
:param batch: Required. The batch to create.
:type batch: google.cloud.dataproc_v1.types.Batch
:param batch_id: Optional. The ID to use for the batch, which will become the final component
of the batch's resource name.
This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/.
:type batch_id: str
:param request_id: Optional. A unique id used to identify the request. If the server receives two
``CreateBatchRequest`` requests with the same id, then the second request will be ignored and
the first ``google.longrunning.Operation`` created and stored in the backend is returned.
:type request_id: str
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:type retry: google.api_core.retry.Retry
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:type timeout: float
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_batch_client(region)
parent = f'projects/{project_id}/regions/{region}'

result = client.create_batch(
request={
'parent': parent,
'batch': batch,
'batch_id': batch_id,
'request_id': request_id,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)
return result
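# Usage sketch (illustrative, not part of this module): `create_batch` returns
# a `google.api_core.operation.Operation`, so a caller can block on the result
# with `wait_for_operation` defined above:
#
# hook = DataprocHook(gcp_conn_id="google_cloud_default")
# operation = hook.create_batch(
#     region="us-central1",          # illustrative region
#     project_id="my-project",       # illustrative project id
#     batch=BATCH_CONFIG,            # a dict or google.cloud.dataproc_v1.Batch
#     batch_id="my-batch",
# )
# batch = hook.wait_for_operation(timeout=600.0, operation=operation)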

@GoogleBaseHook.fallback_to_default_project_id
def delete_batch(
self,
batch_id: str,
region: str,
project_id: str,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = None,
):
"""
Deletes the batch workload resource.

:param batch_id: Required. The ID of the batch, which is the final component
of the batch's resource name.
This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/.
:type batch_id: str
:param project_id: Required. The ID of the Google Cloud project that the batch belongs to.
:type project_id: str
:param region: Required. The Cloud Dataproc region in which to handle the request.
:type region: str
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:type retry: google.api_core.retry.Retry
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:type timeout: float
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_batch_client(region)
name = f"projects/{project_id}/regions/{region}/batches/{batch_id}"

result = client.delete_batch(
request={
'name': name,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)
return result
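# Defensive-deletion sketch (an assumption about client behavior: the
# generated client raises `google.api_core.exceptions.NotFound` when the
# batch does not exist):
#
# from google.api_core.exceptions import NotFound
# try:
#     hook.delete_batch(batch_id="my-batch", region="us-central1", project_id="my-project")
# except NotFound:
#     pass  # already deleted; nothing to do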

@GoogleBaseHook.fallback_to_default_project_id
def get_batch(
self,
batch_id: str,
region: str,
project_id: str,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = None,
):
"""
Gets the batch workload resource representation.

:param batch_id: Required. The ID of the batch, which is the final component
of the batch's resource name.
This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/.
:type batch_id: str
:param project_id: Required. The ID of the Google Cloud project that the batch belongs to.
:type project_id: str
:param region: Required. The Cloud Dataproc region in which to handle the request.
:type region: str
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:type retry: google.api_core.retry.Retry
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:type timeout: float
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_batch_client(region)
name = f"projects/{project_id}/regions/{region}/batches/{batch_id}"

result = client.get_batch(
request={
'name': name,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)
return result

@GoogleBaseHook.fallback_to_default_project_id
def list_batches(
self,
region: str,
project_id: str,
page_size: Optional[int] = None,
page_token: Optional[str] = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = None,
):
"""
Lists batch workloads.

:param project_id: Required. The ID of the Google Cloud project that the batch belongs to.
:type project_id: str
:param region: Required. The Cloud Dataproc region in which to handle the request.
:type region: str
:param page_size: Optional. The maximum number of batches to return in each response. The service may
return fewer than this value. The default page size is 20; the maximum page size is 1000.
:type page_size: int
:param page_token: Optional. A page token received from a previous ``ListBatches`` call.
Provide this token to retrieve the subsequent page.
:type page_token: str
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:type retry: google.api_core.retry.Retry
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:type timeout: float
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_batch_client(region)
parent = f'projects/{project_id}/regions/{region}'

result = client.list_batches(
request={
'parent': parent,
'page_size': page_size,
'page_token': page_token,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)
return result
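# `list_batches` returns the client's pager, which fetches follow-up pages
# transparently, so callers can iterate over every batch without threading
# `page_token` through by hand (a sketch):
#
# for batch in hook.list_batches(region="us-central1", project_id="my-project"):
#     print(batch.name, batch.state)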