diff --git a/airflow/providers/google/cloud/example_dags/example_dataproc.py b/airflow/providers/google/cloud/example_dags/example_dataproc.py index c5cf524e82e10..30bb0db60a231 100644 --- a/airflow/providers/google/cloud/example_dags/example_dataproc.py +++ b/airflow/providers/google/cloud/example_dags/example_dataproc.py @@ -26,10 +26,14 @@ from airflow import models from airflow.providers.google.cloud.operators.dataproc import ( ClusterGenerator, + DataprocCreateBatchOperator, DataprocCreateClusterOperator, DataprocCreateWorkflowTemplateOperator, + DataprocDeleteBatchOperator, DataprocDeleteClusterOperator, + DataprocGetBatchOperator, DataprocInstantiateWorkflowTemplateOperator, + DataprocListBatchesOperator, DataprocSubmitJobOperator, DataprocUpdateClusterOperator, ) @@ -174,6 +178,13 @@ }, "jobs": [{"step_id": "pig_job_1", "pig_job": PIG_JOB["pig_job"]}], } +BATCH_ID = "test-batch-id" +BATCH_CONFIG = { + "spark_batch": { + "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "main_class": "org.apache.spark.examples.SparkPi", + }, +} with models.DAG( @@ -282,3 +293,41 @@ # Task dependency created via `XComArgs`: # spark_task_async >> spark_task_async_sensor + +with models.DAG( + "example_gcp_batch_dataproc", + schedule_interval='@once', + start_date=datetime(2021, 1, 1), + catchup=False, +) as dag_batch: + # [START how_to_cloud_dataproc_create_batch_operator] + create_batch = DataprocCreateBatchOperator( + task_id="create_batch", + project_id=PROJECT_ID, + region=REGION, + batch=BATCH_CONFIG, + batch_id=BATCH_ID, + ) + # [END how_to_cloud_dataproc_create_batch_operator] + + # [START how_to_cloud_dataproc_get_batch_operator] + get_batch = DataprocGetBatchOperator( + task_id="get_batch", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID + ) + # [END how_to_cloud_dataproc_get_batch_operator] + + # [START how_to_cloud_dataproc_list_batches_operator] + list_batches = DataprocListBatchesOperator( + task_id="list_batches", + project_id=PROJECT_ID, + region=REGION, + ) + # [END how_to_cloud_dataproc_list_batches_operator] + + # [START how_to_cloud_dataproc_delete_batch_operator] + delete_batch = DataprocDeleteBatchOperator( + task_id="delete_batch", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID + ) + # [END how_to_cloud_dataproc_delete_batch_operator] + + create_batch >> get_batch >> list_batches >> delete_batch diff --git a/airflow/providers/google/cloud/hooks/dataproc.py b/airflow/providers/google/cloud/hooks/dataproc.py index b63ef6d4a4567..1ac35c05a9605 100644 --- a/airflow/providers/google/cloud/hooks/dataproc.py +++ b/airflow/providers/google/cloud/hooks/dataproc.py @@ -24,8 +24,11 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union from google.api_core.exceptions import ServerError +from google.api_core.operation import Operation from google.api_core.retry import Retry from google.cloud.dataproc_v1 import ( + Batch, + BatchControllerClient, Cluster, ClusterControllerClient, Job, @@ -267,6 +270,34 @@ def get_job_client( credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options ) + def get_batch_client( + self, region: Optional[str] = None, location: Optional[str] = None + ) -> BatchControllerClient: + """Returns BatchControllerClient""" + if location is not None: + warnings.warn( + "Parameter `location` will be deprecated. 
" + "Please provide value through `region` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + region = location + client_options = None + if region and region != 'global': + client_options = {'api_endpoint': f'{region}-dataproc.googleapis.com:443'} + + return BatchControllerClient( + credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options + ) + + def wait_for_operation(self, timeout: float, operation: Operation): + """Waits for long-lasting operation to complete.""" + try: + return operation.result(timeout=timeout) + except Exception: + error = operation.exception(timeout=timeout) + raise AirflowException(error) + @GoogleBaseHook.fallback_to_default_project_id def create_cluster( self, @@ -1030,3 +1061,191 @@ def cancel_job( metadata=metadata, ) return job + + @GoogleBaseHook.fallback_to_default_project_id + def create_batch( + self, + region: str, + project_id: str, + batch: Union[Dict, Batch], + batch_id: Optional[str] = None, + request_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = "", + ): + """ + Creates a batch workload. + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param batch: Required. The batch to create. + :type batch: google.cloud.dataproc_v1.types.Batch + :param batch_id: Optional. The ID to use for the batch, which will become the final component + of the batch's resource name. + This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/. + :type batch_id: str + :param request_id: Optional. A unique id used to identify the request. If the server receives two + ``CreateBatchRequest`` requests with the same id, then the second request will be ignored and + the first ``google.longrunning.Operation`` created and stored in the backend is returned. + :type request_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_batch_client(region) + parent = f'projects/{project_id}/regions/{region}' + + result = client.create_batch( + request={ + 'parent': parent, + 'batch': batch, + 'batch_id': batch_id, + 'request_id': request_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_batch( + self, + batch_id: str, + region: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Deletes the batch workload resource. + + :param batch_id: Required. The ID to use for the batch, which will become the final component + of the batch's resource name. + This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/. + :type batch_id: str + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :type project_id: str + :param region: Required. 
The Cloud Dataproc region in which to handle the request. + :type region: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_batch_client(region) + name = f"projects/{project_id}/regions/{region}/batches/{batch_id}" + + result = client.delete_batch( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_batch( + self, + batch_id: str, + region: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Gets the batch workload resource representation. + + :param batch_id: Required. The ID to use for the batch, which will become the final component + of the batch's resource name. + This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/. + :type batch_id: str + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_batch_client(region) + name = f"projects/{project_id}/regions/{region}/batches/{batch_id}" + + result = client.get_batch( + request={ + 'name': name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_batches( + self, + region: str, + project_id: str, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Lists batch workloads. + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param page_size: Optional. The maximum number of batches to return in each response. The service may + return fewer than this value. The default page size is 20; the maximum page size is 1000. + :type page_size: int + :param page_token: Optional. A page token received from a previous ``ListBatches`` call. + Provide this token to retrieve the subsequent page. + :type page_token: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. 
Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_batch_client(region) + parent = f'projects/{project_id}/regions/{region}' + + result = client.list_batches( + request={ + 'parent': parent, + 'page_size': page_size, + 'page_token': page_token, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index 53be7332a6343..a0b2453fdcf92 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -28,9 +28,10 @@ from datetime import datetime, timedelta from typing import Dict, List, Optional, Sequence, Set, Tuple, Union +from google.api_core import operation # type: ignore from google.api_core.exceptions import AlreadyExists, NotFound from google.api_core.retry import Retry, exponential_sleep_generator -from google.cloud.dataproc_v1 import Cluster +from google.cloud.dataproc_v1 import Batch, Cluster from google.protobuf.duration_pb2 import Duration from google.protobuf.field_mask_pb2 import FieldMask @@ -2159,3 +2160,332 @@ def execute(self, context: Dict): ) operation.result() self.log.info("Updated %s cluster.", self.cluster_name) + + +class DataprocCreateBatchOperator(BaseOperator): + """ + Creates a batch workload. + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param batch: Required. The batch to create. + :type batch: google.cloud.dataproc_v1.types.Batch + :param batch_id: Optional. The ID to use for the batch, which will become the final component + of the batch's resource name. + This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/. + :type batch_id: str + :param request_id: Optional. A unique id used to identify the request. If the server receives two + ``CreateBatchRequest`` requests with the same id, then the second request will be ignored and + the first ``google.longrunning.Operation`` created and stored in the backend is returned. + :type request_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + 'project_id', + 'batch_id', + 'region', + 'impersonation_chain', + ) + + def __init__( + self, + *, + region: str = None, + project_id: str, + batch: Union[Dict, Batch], + batch_id: Optional[str] = None, + request_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = "", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.batch = batch + self.batch_id = batch_id + self.request_id = request_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.operation: Optional[operation.Operation] = None + + def execute(self, context): + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + self.log.info("Creating batch") + try: + self.operation = hook.create_batch( + region=self.region, + project_id=self.project_id, + batch=self.batch, + batch_id=self.batch_id, + request_id=self.request_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = hook.wait_for_operation(self.timeout, self.operation) + self.log.info("Batch %s created", self.batch_id) + except AlreadyExists: + self.log.info("Batch with given id already exists") + result = hook.get_batch( + batch_id=self.batch_id, + region=self.region, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Batch.to_dict(result) + + def on_kill(self): + if self.operation: + self.operation.cancel() + + +class DataprocDeleteBatchOperator(BaseOperator): + """ + Deletes the batch workload resource. + + :param batch_id: Required. The ID to use for the batch, which will become the final component + of the batch's resource name. + This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/. + :type batch_id: str + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ("batch_id", "region", "project_id", "impersonation_chain") + + def __init__( + self, + *, + batch_id: str, + region: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = "", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.batch_id = batch_id + self.region = region + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + self.log.info("Deleting batch: %s", self.batch_id) + hook.delete_batch( + batch_id=self.batch_id, + region=self.region, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Batch deleted.") + + +class DataprocGetBatchOperator(BaseOperator): + """ + Gets the batch workload resource representation. + + :param batch_id: Required. The ID to use for the batch, which will become the final component + of the batch's resource name. + This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/. + :type batch_id: str + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ("batch_id", "region", "project_id", "impersonation_chain") + + def __init__( + self, + *, + batch_id: str, + region: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = "", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.batch_id = batch_id + self.region = region + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + self.log.info("Getting batch: %s", self.batch_id) + batch = hook.get_batch( + batch_id=self.batch_id, + region=self.region, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Batch.to_dict(batch) + + +class DataprocListBatchesOperator(BaseOperator): + """ + Lists batch workloads. + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param page_size: Optional. The maximum number of batches to return in each response. The service may + return fewer than this value. The default page size is 20; the maximum page size is 1000. + :type page_size: int + :param page_token: Optional. A page token received from a previous ``ListBatches`` call. + Provide this token to retrieve the subsequent page. + :type page_token: str + :param retry: Optional, a retry object used to retry requests. If `None` is specified, requests + will not be retried. + :type retry: Optional[Retry] + :param timeout: Optional, the amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Optional, additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: Optional, the connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: Optional[str] + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: List[dict] + """ + + template_fields = ("region", "project_id", "impersonation_chain") + + def __init__( + self, + *, + region: str, + project_id: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = "", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.region = region + self.project_id = project_id + self.page_size = page_size + self.page_token = page_token + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + results = hook.list_batches( + region=self.region, + project_id=self.project_id, + page_size=self.page_size, + page_token=self.page_token, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return [Batch.to_dict(result) for result in results] diff --git a/docs/apache-airflow-providers-google/index.rst b/docs/apache-airflow-providers-google/index.rst index df67decebeb6d..8ae865e019499 100644 --- a/docs/apache-airflow-providers-google/index.rst +++ b/docs/apache-airflow-providers-google/index.rst @@ -102,7 +102,7 @@ PIP package Version required ``google-cloud-build`` ``>=3.0.0,<4.0.0`` ``google-cloud-container`` ``>=0.1.1,<2.0.0`` ``google-cloud-datacatalog`` ``>=3.0.0,<4.0.0`` -``google-cloud-dataproc`` ``>=2.2.0,<4.0.0`` +``google-cloud-dataproc`` ``>=3.1.0,<4.0.0`` ``google-cloud-dlp`` ``>=0.11.0,<2.0.0`` ``google-cloud-kms`` ``>=2.0.0,<3.0.0`` ``google-cloud-language`` ``>=1.1.1,<2.0.0`` diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst index 3d506d0f5aeb2..f93643a42531f 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst @@ -212,6 +212,55 @@ Once a workflow is created users can trigger it using :start-after: [START how_to_cloud_dataproc_trigger_workflow_template] :end-before: [END how_to_cloud_dataproc_trigger_workflow_template] +Create a Batch +-------------- + +Dataproc supports creating a batch workload. + +A batch can be created using: +:class: ``~airflow.providers.google.cloud.operators.dataproc.DataprocCreateBatchOperator``. + +.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_dataproc.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_dataproc_create_batch_operator] + :end-before: [END how_to_cloud_dataproc_create_batch_operator] + +Get a Batch +----------- + +To get a batch you can use: +:class: ``~airflow.providers.google.cloud.operators.dataproc.DataprocGetBatchOperator``. + +.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_dataproc.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_dataproc_get_batch_operator] + :end-before: [END how_to_cloud_dataproc_get_batch_operator] + +List a Batch +------------ + +To get a list of exists batches you can use: +:class: ``~airflow.providers.google.cloud.operators.dataproc.DataprocListBatchesOperator``. + +.. 
exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_dataproc.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_dataproc_list_batches_operator] + :end-before: [END how_to_cloud_dataproc_list_batches_operator] + +Delete a Batch +-------------- + +To delete a batch you can use: +:class: ``~airflow.providers.google.cloud.operators.dataproc.DataprocDeleteBatchOperator``. + +.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_dataproc.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_dataproc_delete_batch_operator] + :end-before: [END how_to_cloud_dataproc_delete_batch_operator] References ^^^^^^^^^^ diff --git a/setup.py b/setup.py index 2fee3e10be409..0490e5d6e3aa1 100644 --- a/setup.py +++ b/setup.py @@ -306,7 +306,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version 'google-cloud-build>=3.0.0,<4.0.0', 'google-cloud-container>=0.1.1,<2.0.0', 'google-cloud-datacatalog>=3.0.0,<4.0.0', - 'google-cloud-dataproc>=2.2.0,<4.0.0', + 'google-cloud-dataproc>=3.1.0,<4.0.0', 'google-cloud-dataproc-metastore>=1.2.0,<2.0.0', 'google-cloud-dlp>=0.11.0,<2.0.0', 'google-cloud-kms>=2.0.0,<3.0.0', diff --git a/tests/providers/google/cloud/hooks/test_dataproc.py b/tests/providers/google/cloud/hooks/test_dataproc.py index c81c39cc6eb49..598bb910785a0 100644 --- a/tests/providers/google/cloud/hooks/test_dataproc.py +++ b/tests/providers/google/cloud/hooks/test_dataproc.py @@ -42,6 +42,10 @@ "labels": LABELS, "project_id": GCP_PROJECT, } +BATCH = {"batch": "test-batch"} +BATCH_ID = "batch-id" +BATCH_NAME = "projects/{}/regions/{}/batches/{}" +PARENT = "projects/{}/regions/{}" BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" DATAPROC_STRING = "airflow.providers.google.cloud.hooks.dataproc.{}" @@ -179,6 +183,47 @@ def test_get_job_client_region_deprecation_warning( ) assert warning_message == str(warnings[0].message) + @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials")) + @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock) + @mock.patch(DATAPROC_STRING.format("BatchControllerClient")) + def test_get_batch_client(self, mock_client, mock_client_info, mock_get_credentials): + self.hook.get_batch_client(region=GCP_LOCATION) + mock_client.assert_called_once_with( + credentials=mock_get_credentials.return_value, + client_info=mock_client_info.return_value, + client_options=None, + ) + + @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials")) + @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock) + @mock.patch(DATAPROC_STRING.format("BatchControllerClient")) + def test_get_batch_client_region(self, mock_client, mock_client_info, mock_get_credentials): + self.hook.get_batch_client(region='region1') + mock_client.assert_called_once_with( + credentials=mock_get_credentials.return_value, + client_info=mock_client_info.return_value, + client_options={'api_endpoint': 'region1-dataproc.googleapis.com:443'}, + ) + + @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials")) + @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock) + @mock.patch(DATAPROC_STRING.format("BatchControllerClient")) + def test_get_batch_client_region_deprecation_warning( + self, mock_client, mock_client_info, mock_get_credentials + ): + warning_message = ( + "Parameter `location` will be deprecated. 
" + "Please provide value through `region` parameter instead." + ) + with pytest.warns(DeprecationWarning) as warnings: + self.hook.get_batch_client(location='region1') + mock_client.assert_called_once_with( + credentials=mock_get_credentials.return_value, + client_info=mock_client_info.return_value, + client_options={'api_endpoint': 'region1-dataproc.googleapis.com:443'}, + ) + assert warning_message == str(warnings[0].message) + @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client")) def test_create_cluster(self, mock_client): self.hook.create_cluster( @@ -615,6 +660,79 @@ def test_cancel_job_deprecation_warning_param_rename(self, mock_client): ) assert warning_message == str(warnings[0].message) + @mock.patch(DATAPROC_STRING.format("DataprocHook.get_batch_client")) + def test_create_batch(self, mock_client): + self.hook.create_batch( + project_id=GCP_PROJECT, + region=GCP_LOCATION, + batch=BATCH, + batch_id=BATCH_ID, + ) + mock_client.assert_called_once_with(GCP_LOCATION) + mock_client.return_value.create_batch.assert_called_once_with( + request=dict( + parent=PARENT.format(GCP_PROJECT, GCP_LOCATION), + batch=BATCH, + batch_id=BATCH_ID, + request_id=None, + ), + metadata="", + retry=None, + timeout=None, + ) + + @mock.patch(DATAPROC_STRING.format("DataprocHook.get_batch_client")) + def test_delete_batch(self, mock_client): + self.hook.delete_batch( + batch_id=BATCH_ID, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + ) + mock_client.assert_called_once_with(GCP_LOCATION) + mock_client.return_value.delete_batch.assert_called_once_with( + request=dict( + name=BATCH_NAME.format(GCP_PROJECT, GCP_LOCATION, BATCH_ID), + ), + metadata=None, + retry=None, + timeout=None, + ) + + @mock.patch(DATAPROC_STRING.format("DataprocHook.get_batch_client")) + def test_get_batch(self, mock_client): + self.hook.get_batch( + batch_id=BATCH_ID, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + ) + mock_client.assert_called_once_with(GCP_LOCATION) + mock_client.return_value.get_batch.assert_called_once_with( + request=dict( + name=BATCH_NAME.format(GCP_PROJECT, GCP_LOCATION, BATCH_ID), + ), + metadata=None, + retry=None, + timeout=None, + ) + + @mock.patch(DATAPROC_STRING.format("DataprocHook.get_batch_client")) + def test_list_batches(self, mock_client): + self.hook.list_batches( + project_id=GCP_PROJECT, + region=GCP_LOCATION, + ) + mock_client.assert_called_once_with(GCP_LOCATION) + mock_client.return_value.list_batches.assert_called_once_with( + request=dict( + parent=PARENT.format(GCP_PROJECT, GCP_LOCATION), + page_size=None, + page_token=None, + ), + metadata=None, + retry=None, + timeout=None, + ) + class TestDataProcJobBuilder(unittest.TestCase): def setUp(self) -> None: diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index f8500aa9b0080..34e63537f0c55 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -29,12 +29,16 @@ from airflow.providers.google.cloud.operators.dataproc import ( ClusterGenerator, DataprocClusterLink, + DataprocCreateBatchOperator, DataprocCreateClusterOperator, DataprocCreateWorkflowTemplateOperator, + DataprocDeleteBatchOperator, DataprocDeleteClusterOperator, + DataprocGetBatchOperator, DataprocInstantiateInlineWorkflowTemplateOperator, DataprocInstantiateWorkflowTemplateOperator, DataprocJobLink, + DataprocListBatchesOperator, DataprocScaleClusterOperator, DataprocSubmitHadoopJobOperator, 
DataprocSubmitHiveJobOperator, @@ -199,6 +203,13 @@ "region": GCP_LOCATION, "project_id": GCP_PROJECT, } +BATCH_ID = "test-batch-id" +BATCH = { + "spark_batch": { + "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "main_class": "org.apache.spark.examples.SparkPi", + }, +} def assert_warning(msg: str, warnings): @@ -1661,3 +1672,118 @@ def test_location_deprecation_warning(self, mock_hook): template=WORKFLOW_TEMPLATE, ) op.execute(context={}) + + +class TestDataprocCreateBatchOperator: + @mock.patch(DATAPROC_PATH.format("Batch.to_dict")) + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = DataprocCreateBatchOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + batch=BATCH, + batch_id=BATCH_ID, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.create_batch.assert_called_once_with( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + batch=BATCH, + batch_id=BATCH_ID, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestDataprocDeleteBatchOperator: + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute(self, mock_hook): + op = DataprocDeleteBatchOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + project_id=GCP_PROJECT, + region=GCP_LOCATION, + batch_id=BATCH_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.delete_batch.assert_called_once_with( + project_id=GCP_PROJECT, + region=GCP_LOCATION, + batch_id=BATCH_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestDataprocGetBatchOperator: + @mock.patch(DATAPROC_PATH.format("Batch.to_dict")) + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = DataprocGetBatchOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + project_id=GCP_PROJECT, + region=GCP_LOCATION, + batch_id=BATCH_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.get_batch.assert_called_once_with( + project_id=GCP_PROJECT, + region=GCP_LOCATION, + batch_id=BATCH_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestDataprocListBatchesOperator: + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute(self, mock_hook): + page_token = "page_token" + page_size = 42 + + op = DataprocListBatchesOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + page_size=page_size, + page_token=page_token, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.list_batches.assert_called_once_with( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + page_size=page_size, + page_token=page_token, + 
retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) diff --git a/tests/providers/google/cloud/operators/test_dataproc_system.py b/tests/providers/google/cloud/operators/test_dataproc_system.py index 568af28f53fa0..30f9a35d9a4c4 100644 --- a/tests/providers/google/cloud/operators/test_dataproc_system.py +++ b/tests/providers/google/cloud/operators/test_dataproc_system.py @@ -63,3 +63,7 @@ def tearDown(self): @provide_gcp_context(GCP_DATAPROC_KEY) def test_run_example_dag(self): self.run_dag(dag_id="example_gcp_dataproc", dag_folder=CLOUD_DAG_FOLDER) + + @provide_gcp_context(GCP_DATAPROC_KEY) + def test_run_batch_example_dag(self): + self.run_dag(dag_id="example_gcp_batch_dataproc", dag_folder=CLOUD_DAG_FOLDER)
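
For quick reference, below is a minimal, hypothetical sketch (not part of the patch itself) of how the new DataprocHook batch methods introduced above could be driven directly, for example from a one-off script or a PythonOperator callable. The project id, region, GCS URI, and batch id are placeholders, and the pyspark_batch payload is an assumed alternative to the spark_batch config used in the example DAG; the operator-based flow in example_dataproc.py remains the documented usage.

    # Hypothetical sketch exercising the new hook methods; all identifiers below
    # (project, region, bucket, batch id) are placeholders, not values from the patch.
    from airflow.providers.google.cloud.hooks.dataproc import DataprocHook

    hook = DataprocHook(gcp_conn_id="google_cloud_default")

    # create_batch returns a long-running google.api_core.operation.Operation;
    # wait_for_operation blocks until it finishes (or raises AirflowException on failure).
    operation = hook.create_batch(
        project_id="my-project",
        region="europe-west1",
        batch_id="example-pyspark-batch",
        batch={
            # Assumed pyspark_batch payload; the example DAG uses spark_batch instead.
            "pyspark_batch": {"main_python_file_uri": "gs://my-bucket/jobs/pi.py"},
        },
    )
    finished = hook.wait_for_operation(timeout=600, operation=operation)  # the completed Batch

    # Inspect and clean up with the other methods added in this change.
    print(hook.get_batch(batch_id="example-pyspark-batch", region="europe-west1", project_id="my-project"))
    hook.delete_batch(batch_id="example-pyspark-batch", region="europe-west1", project_id="my-project")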