diff --git a/airflow/providers/apache/beam/provider.yaml b/airflow/providers/apache/beam/provider.yaml index 4000776d102bd..bf23c51d00deb 100644 --- a/airflow/providers/apache/beam/provider.yaml +++ b/airflow/providers/apache/beam/provider.yaml @@ -43,7 +43,7 @@ versions: dependencies: - apache-airflow>=2.4.0 - - apache-beam>=2.33.0 + - apache-beam>=2.47.0 integrations: - integration-name: Apache Beam diff --git a/airflow/providers/google/ads/hooks/ads.py b/airflow/providers/google/ads/hooks/ads.py index eadc9cacbab3c..f55f01a7a81cf 100644 --- a/airflow/providers/google/ads/hooks/ads.py +++ b/airflow/providers/google/ads/hooks/ads.py @@ -21,26 +21,18 @@ from tempfile import NamedTemporaryFile from typing import IO, Any +from google.ads.googleads.client import GoogleAdsClient +from google.ads.googleads.errors import GoogleAdsException +from google.ads.googleads.v12.services.services.customer_service import CustomerServiceClient +from google.ads.googleads.v12.services.services.google_ads_service import GoogleAdsServiceClient +from google.ads.googleads.v12.services.types.google_ads_service import GoogleAdsRow +from google.api_core.page_iterator import GRPCIterator from google.auth.exceptions import GoogleAuthError from airflow import AirflowException from airflow.compat.functools import cached_property from airflow.hooks.base import BaseHook from airflow.providers.google.common.hooks.base_google import get_field -from airflow.providers.google_vendor.googleads.client import GoogleAdsClient -from airflow.providers.google_vendor.googleads.errors import GoogleAdsException -from airflow.providers.google_vendor.googleads.v12.services.services.customer_service import ( - CustomerServiceClient, -) -from airflow.providers.google_vendor.googleads.v12.services.services.google_ads_service import ( - GoogleAdsServiceClient, -) -from airflow.providers.google_vendor.googleads.v12.services.services.google_ads_service.pagers import ( - SearchPager, -) -from airflow.providers.google_vendor.googleads.v12.services.types.google_ads_service import ( - GoogleAdsRow, -) class GoogleAdsHook(BaseHook): @@ -238,7 +230,7 @@ def _search( return self._extract_rows(iterators) - def _extract_rows(self, iterators: list[SearchPager]) -> list[GoogleAdsRow]: + def _extract_rows(self, iterators: list[GRPCIterator]) -> list[GoogleAdsRow]: """ Convert Google Page Iterator (GRPCIterator) objects to Google Ads Rows diff --git a/airflow/providers/google/cloud/_internal_client/secret_manager_client.py b/airflow/providers/google/cloud/_internal_client/secret_manager_client.py index 9ea72e63bbba9..0de1abfdd90ac 100644 --- a/airflow/providers/google/cloud/_internal_client/secret_manager_client.py +++ b/airflow/providers/google/cloud/_internal_client/secret_manager_client.py @@ -71,7 +71,7 @@ def get_secret(self, secret_id: str, project_id: str, secret_version: str = "lat """ name = self.client.secret_version_path(project_id, secret_id, secret_version) try: - response = self.client.access_secret_version(name) + response = self.client.access_secret_version(request={"name": name}) value = response.payload.data.decode("UTF-8") return value except NotFound: diff --git a/airflow/providers/google/cloud/example_dags/example_compute.py b/airflow/providers/google/cloud/example_dags/example_compute.py deleted file mode 100644 index e42cb6d962ed2..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_compute.py +++ /dev/null @@ -1,107 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor 
license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example Airflow DAG that starts, stops and sets the machine type of a Google Compute -Engine instance. - -This DAG relies on the following OS environment variables - -* GCP_PROJECT_ID - Google Cloud project where the Compute Engine instance exists. -* GCE_ZONE - Google Cloud zone where the instance exists. -* GCE_INSTANCE - Name of the Compute Engine instance. -* GCE_SHORT_MACHINE_TYPE_NAME - Machine type resource name to set, e.g. 'n1-standard-1'. - See https://cloud.google.com/compute/docs/machine-types -""" -from __future__ import annotations - -import os -from datetime import datetime - -from airflow import models -from airflow.models.baseoperator import chain -from airflow.providers.google.cloud.operators.compute import ( - ComputeEngineSetMachineTypeOperator, - ComputeEngineStartInstanceOperator, - ComputeEngineStopInstanceOperator, -) - -# [START howto_operator_gce_args_common] -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") -GCE_ZONE = os.environ.get("GCE_ZONE", "europe-west1-b") -GCE_INSTANCE = os.environ.get("GCE_INSTANCE", "testinstance") -# [END howto_operator_gce_args_common] - - -GCE_SHORT_MACHINE_TYPE_NAME = os.environ.get("GCE_SHORT_MACHINE_TYPE_NAME", "n1-standard-1") - - -with models.DAG( - "example_gcp_compute", - start_date=datetime(2021, 1, 1), - catchup=False, - tags=["example"], -) as dag: - # [START howto_operator_gce_start] - gce_instance_start = ComputeEngineStartInstanceOperator( - project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id="gcp_compute_start_task" - ) - # [END howto_operator_gce_start] - # Duplicate start for idempotence testing - # [START howto_operator_gce_start_no_project_id] - gce_instance_start2 = ComputeEngineStartInstanceOperator( - zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id="gcp_compute_start_task2" - ) - # [END howto_operator_gce_start_no_project_id] - # [START howto_operator_gce_stop] - gce_instance_stop = ComputeEngineStopInstanceOperator( - project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id="gcp_compute_stop_task" - ) - # [END howto_operator_gce_stop] - # Duplicate stop for idempotence testing - # [START howto_operator_gce_stop_no_project_id] - gce_instance_stop2 = ComputeEngineStopInstanceOperator( - zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id="gcp_compute_stop_task2" - ) - # [END howto_operator_gce_stop_no_project_id] - # [START howto_operator_gce_set_machine_type] - gce_set_machine_type = ComputeEngineSetMachineTypeOperator( - project_id=GCP_PROJECT_ID, - zone=GCE_ZONE, - resource_id=GCE_INSTANCE, - body={"machineType": f"zones/{GCE_ZONE}/machineTypes/{GCE_SHORT_MACHINE_TYPE_NAME}"}, - task_id="gcp_compute_set_machine_type", - ) - # [END howto_operator_gce_set_machine_type] - # Duplicate set machine type for idempotence testing - # [START 
howto_operator_gce_set_machine_type_no_project_id] - gce_set_machine_type2 = ComputeEngineSetMachineTypeOperator( - zone=GCE_ZONE, - resource_id=GCE_INSTANCE, - body={"machineType": f"zones/{GCE_ZONE}/machineTypes/{GCE_SHORT_MACHINE_TYPE_NAME}"}, - task_id="gcp_compute_set_machine_type2", - ) - # [END howto_operator_gce_set_machine_type_no_project_id] - - chain( - gce_instance_start, - gce_instance_start2, - gce_instance_stop, - gce_instance_stop2, - gce_set_machine_type, - gce_set_machine_type2, - ) diff --git a/airflow/providers/google/cloud/example_dags/example_compute_ssh.py b/airflow/providers/google/cloud/example_dags/example_compute_ssh.py deleted file mode 100644 index 044789aec2a7e..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_compute_ssh.py +++ /dev/null @@ -1,90 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.hooks.compute_ssh import ComputeEngineSSHHook -from airflow.providers.ssh.operators.ssh import SSHOperator - -# [START howto_operator_gce_args_common] -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") -GCE_ZONE = os.environ.get("GCE_ZONE", "europe-west2-a") -GCE_INSTANCE = os.environ.get("GCE_INSTANCE", "target-instance") -# [END howto_operator_gce_args_common] - -with models.DAG( - "example_compute_ssh", - start_date=datetime(2021, 1, 1), - catchup=False, - tags=["example"], -) as dag: - # # [START howto_execute_command_on_remote1] - os_login_without_iap_tunnel = SSHOperator( - task_id="os_login_without_iap_tunnel", - ssh_hook=ComputeEngineSSHHook( - instance_name=GCE_INSTANCE, - zone=GCE_ZONE, - project_id=GCP_PROJECT_ID, - use_oslogin=True, - use_iap_tunnel=False, - ), - command="echo os_login_without_iap_tunnel", - ) - # # [END howto_execute_command_on_remote1] - - # # [START howto_execute_command_on_remote2] - metadata_without_iap_tunnel = SSHOperator( - task_id="metadata_without_iap_tunnel", - ssh_hook=ComputeEngineSSHHook( - instance_name=GCE_INSTANCE, - zone=GCE_ZONE, - use_oslogin=False, - use_iap_tunnel=False, - ), - command="echo metadata_without_iap_tunnel", - ) - # # [END howto_execute_command_on_remote2] - - os_login_with_iap_tunnel = SSHOperator( - task_id="os_login_with_iap_tunnel", - ssh_hook=ComputeEngineSSHHook( - instance_name=GCE_INSTANCE, - zone=GCE_ZONE, - use_oslogin=True, - use_iap_tunnel=True, - ), - command="echo os_login_with_iap_tunnel", - ) - - metadata_with_iap_tunnel = SSHOperator( - task_id="metadata_with_iap_tunnel", - ssh_hook=ComputeEngineSSHHook( - instance_name=GCE_INSTANCE, - zone=GCE_ZONE, - use_oslogin=False, - use_iap_tunnel=True, - ), - command="echo metadata_with_iap_tunnel", - ) - - 
os_login_with_iap_tunnel >> os_login_without_iap_tunnel - metadata_with_iap_tunnel >> metadata_without_iap_tunnel - - os_login_without_iap_tunnel >> metadata_with_iap_tunnel diff --git a/airflow/providers/google/cloud/hooks/automl.py b/airflow/providers/google/cloud/hooks/automl.py index 6e7e5aa005746..e14bb9044fcbb 100644 --- a/airflow/providers/google/cloud/hooks/automl.py +++ b/airflow/providers/google/cloud/hooks/automl.py @@ -48,6 +48,7 @@ ) from google.protobuf.field_mask_pb2 import FieldMask +from airflow import AirflowException from airflow.compat.functools import cached_property from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook @@ -93,6 +94,14 @@ def get_conn(self) -> AutoMlClient: self._client = AutoMlClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) return self._client + def wait_for_operation(self, operation: Operation, timeout: float | None = None): + """Waits for long-lasting operation to complete.""" + try: + return operation.result(timeout=timeout) + except Exception: + error = operation.exception(timeout=timeout) + raise AirflowException(error) + @cached_property def prediction_client(self) -> PredictionServiceClient: """ diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index a091dd73fe200..1d6a460831fea 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -45,6 +45,7 @@ LoadJob, QueryJob, SchemaField, + UnknownJob, ) from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference from google.cloud.bigquery.table import EncryptionConfiguration, Row, RowIterator, Table, TableReference @@ -319,7 +320,7 @@ def create_empty_table( view: dict | None = None, materialized_view: dict | None = None, encryption_configuration: dict | None = None, - retry: Retry | None = DEFAULT_RETRY, + retry: Retry = DEFAULT_RETRY, location: str | None = None, exists_ok: bool = True, ) -> Table: @@ -1062,7 +1063,9 @@ def get_datasets_list( # If iterator is requested, we cannot perform a list() on it to log the number # of datasets because we will have started iteration if return_iterator: - return iterator + # The iterator returned by list_datasets() is an HTTPIterator but annotated + # as Iterator + return iterator # type: ignore datasets_list = list(iterator) self.log.info("Datasets List: %s", len(datasets_list)) @@ -1294,9 +1297,9 @@ def list_rows( selected_fields = selected_fields.split(",") if selected_fields: - selected_fields = [SchemaField(n, "") for n in selected_fields] + selected_fields_sequence = [SchemaField(n, "") for n in selected_fields] else: - selected_fields = None + selected_fields_sequence = None table = self._resolve_table_reference( table_resource={}, @@ -1307,7 +1310,7 @@ iterator = self.get_client(project_id=project_id, location=location).list_rows( table=Table.from_api_repr(table), - selected_fields=selected_fields, + selected_fields=selected_fields_sequence, max_results=max_results, page_token=page_token, start_index=start_index, @@ -1503,17 +1506,17 @@ @GoogleBaseHook.fallback_to_default_project_id def get_job( self, - job_id: str | None = None, + job_id: str, project_id: str | None = None, location: str | None = None, - ) -> CopyJob | QueryJob | LoadJob | ExtractJob: + ) -> CopyJob | QueryJob | LoadJob | ExtractJob | UnknownJob: """ Retrieves a BigQuery 
job. For more information see: https://cloud.google.com/bigquery/docs/reference/v2/jobs :param job_id: The ID of the job. The ID must contain only letters (a-z, A-Z), numbers (0-9), underscores (_), or dashes (-). The maximum length is 1,024 - characters. If not provided then uuid will be generated. + characters. :param project_id: Google Cloud Project where the job is running :param location: location the job is running """ @@ -1570,14 +1573,14 @@ def insert_job( "jobReference": {"jobId": job_id, "projectId": project_id, "location": location}, } - supported_jobs = { + supported_jobs: dict[str, type[CopyJob] | type[QueryJob] | type[LoadJob] | type[ExtractJob]] = { LoadJob._JOB_TYPE: LoadJob, CopyJob._JOB_TYPE: CopyJob, ExtractJob._JOB_TYPE: ExtractJob, QueryJob._JOB_TYPE: QueryJob, } - job = None + job: type[CopyJob] | type[QueryJob] | type[LoadJob] | type[ExtractJob] | None = None for job_type, job_object in supported_jobs.items(): if job_type in configuration: job = job_object @@ -1585,15 +1588,15 @@ def insert_job( if not job: raise AirflowException(f"Unknown job type. Supported types: {supported_jobs.keys()}") - job = job.from_api_repr(job_data, client) - self.log.info("Inserting job %s", job.job_id) + job_api_repr = job.from_api_repr(job_data, client) + self.log.info("Inserting job %s", job_api_repr.job_id) if nowait: # Initiate the job and don't wait for it to complete. - job._begin() + job_api_repr._begin() else: # Start the job and wait for it to complete and get the result. - job.result(timeout=timeout, retry=retry) - return job + job_api_repr.result(timeout=timeout, retry=retry) + return job_api_repr def run_with_configuration(self, configuration: dict) -> str: """ @@ -2527,7 +2530,7 @@ def get_datasets_list(self, *args, **kwargs) -> list | HTTPIterator: ) return self.hook.get_datasets_list(*args, **kwargs) - def get_dataset(self, *args, **kwargs) -> dict: + def get_dataset(self, *args, **kwargs) -> Dataset: """ This method is deprecated. Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset` @@ -2671,7 +2674,7 @@ def run_copy(self, *args, **kwargs) -> str: ) return self.hook.run_copy(*args, **kwargs) - def run_extract(self, *args, **kwargs) -> str: + def run_extract(self, *args, **kwargs) -> str | BigQueryJob: """ This method is deprecated. Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_extract` diff --git a/airflow/providers/google/cloud/hooks/bigtable.py b/airflow/providers/google/cloud/hooks/bigtable.py index 999d141547555..c1e573a85a738 100644 --- a/airflow/providers/google/cloud/hooks/bigtable.py +++ b/airflow/providers/google/cloud/hooks/bigtable.py @@ -263,6 +263,8 @@ def update_cluster(instance: Instance, cluster_id: str, nodes: int) -> None: :param nodes: The desired number of nodes. """ cluster = Cluster(cluster_id, instance) + # "reload" is required to set location_id attribute on cluster. 
+ cluster.reload() cluster.serve_nodes = nodes cluster.update() diff --git a/airflow/providers/google/cloud/hooks/cloud_sql.py b/airflow/providers/google/cloud/hooks/cloud_sql.py index 4b430fa02ed11..69c35dc0f1333 100644 --- a/airflow/providers/google/cloud/hooks/cloud_sql.py +++ b/airflow/providers/google/cloud/hooks/cloud_sql.py @@ -199,7 +199,11 @@ def delete_instance(self, instance: str, project_id: str) -> None: .execute(num_retries=self.num_retries) ) operation_name = response["name"] - self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name) + # For some delete instance operations, the operation stops being available ~9 seconds after + # completion, so we need a shorter sleep time to make sure we don't miss the DONE status. + self._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name, time_to_sleep=5 + ) @GoogleBaseHook.fallback_to_default_project_id def get_database(self, instance: str, database: str, project_id: str) -> dict: @@ -355,7 +359,7 @@ def clone_instance(self, instance: str, body: dict, project_id: str) -> None: :param instance: Database instance ID to be used for the clone. This does not include the project ID. :param body: The request body, as described in - https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1/instances/clone + https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/instances/clone :param project_id: Project ID of the project that contains the instance. If set to None or missing, the default project_id from the Google Cloud connection is used. :return: None @@ -372,13 +376,16 @@ def clone_instance(self, instance: str, body: dict, project_id: str) -> None: except HttpError as ex: raise AirflowException(f"Cloning of instance {instance} failed: {ex.content}") - def _wait_for_operation_to_complete(self, project_id: str, operation_name: str) -> None: + def _wait_for_operation_to_complete( + self, project_id: str, operation_name: str, time_to_sleep: int = TIME_TO_SLEEP_IN_SECONDS + ) -> None: """ Waits for the named operation to complete - checks status of the asynchronous call. :param project_id: Project ID of the project that contains the instance. :param operation_name: Name of the operation. + :param time_to_sleep: Time to sleep between active checks of the operation results. :return: None """ service = self.get_conn() @@ -396,7 +403,7 @@ def _wait_for_operation_to_complete(self, project_id: str, operation_name: str) raise AirflowException(error_msg) # No meaningful info to return from the response in case of success return - time.sleep(TIME_TO_SLEEP_IN_SECONDS) + time.sleep(time_to_sleep) CLOUD_SQL_PROXY_DOWNLOAD_URL = "https://dl.google.com/cloudsql/cloud_sql_proxy.{}.{}" diff --git a/airflow/providers/google/cloud/hooks/compute.py b/airflow/providers/google/cloud/hooks/compute.py index 8b839c595a49f..a6cde83425226 100644 --- a/airflow/providers/google/cloud/hooks/compute.py +++ b/airflow/providers/google/cloud/hooks/compute.py @@ -122,7 +122,7 @@ def insert_instance_template( :param metadata: Additional metadata that is provided to the method. """ client = self.get_compute_instance_template_client() - client.insert( + operation = client.insert( # Calling method insert() on client to create Instance Template. 
# This method accepts request object as an argument and should be of type # Union[google.cloud.compute_v1.types.InsertInstanceTemplateRequest, dict] to construct a request @@ -143,6 +143,7 @@ timeout=timeout, metadata=metadata, ) + self._wait_for_operation_to_complete(operation_name=operation.name, project_id=project_id) @GoogleBaseHook.fallback_to_default_project_id def delete_instance_template( @@ -174,7 +175,7 @@ :param metadata: Additional metadata that is provided to the method. """ client = self.get_compute_instance_template_client() - client.delete( + operation = client.delete( # Calling method delete() on client to delete Instance Template. # This method accepts request object as an argument and should be of type # Union[google.cloud.compute_v1.types.DeleteInstanceTemplateRequest, dict] to @@ -195,6 +196,7 @@ timeout=timeout, metadata=metadata, ) + self._wait_for_operation_to_complete(operation_name=operation.name, project_id=project_id) @GoogleBaseHook.fallback_to_default_project_id def get_instance_template( @@ -222,7 +224,7 @@ :rtype: object """ client = self.get_compute_instance_template_client() - instance_template_obj = client.get( + instance_template = client.get( # Calling method get() on client to get the specified Instance Template. # This method accepts request object as an argument and should be of type # Union[google.cloud.compute_v1.types.GetInstanceTemplateRequest, dict] to construct a request @@ -240,7 +242,7 @@ timeout=timeout, metadata=metadata, ) - return instance_template_obj + return instance_template @GoogleBaseHook.fallback_to_default_project_id def insert_instance( @@ -271,9 +273,8 @@ :param source_instance_template: Existing Instance Template that will be used as a base while creating new Instance. When specified, only name of new Instance should be provided as input arguments in 'body' - parameter when creating new Instance. All other parameters, such as machine_type, disks - and network_interfaces and etc will be passed to Instance as they are specified - in the Instance Template. + parameter when creating new Instance. All other parameters will be passed to Instance as they + are specified in the Instance Template. Full or partial URL and can be represented as examples below: 1. "https://www.googleapis.com/compute/v1/projects/your-project/global/instanceTemplates/temp" 2. "projects/your-project/global/instanceTemplates/temp" @@ -289,7 +290,7 @@ :param metadata: Additional metadata that is provided to the method. """ client = self.get_compute_instance_client() - client.insert( + operation = client.insert( # Calling method insert() on client to create Instance. # This method accepts request object as an argument and should be of type # Union[google.cloud.compute_v1.types.InsertInstanceRequest, dict] to construct a request @@ -316,6 +317,7 @@ timeout=timeout, metadata=metadata, ) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation.name, zone=zone) @GoogleBaseHook.fallback_to_default_project_id def get_instance( @@ -345,7 +347,7 @@ :rtype: object """ client = self.get_compute_instance_client() - instance_obj = client.get( + instance = client.get( # Calling method get() on client to get the specified Instance. 
# This method accepts request object as an argument and should be of type # Union[google.cloud.compute_v1.types.GetInstanceRequest, dict] to construct a request @@ -366,7 +368,7 @@ def get_instance( timeout=timeout, metadata=metadata, ) - return instance_obj + return instance @GoogleBaseHook.fallback_to_default_project_id def delete_instance( @@ -400,7 +402,7 @@ def delete_instance( :param metadata: Additional metadata that is provided to the method. """ client = self.get_compute_instance_client() - client.delete( + operation = client.delete( # Calling method delete() on client to delete Instance. # This method accepts request object as an argument and should be of type # Union[google.cloud.compute_v1.types.DeleteInstanceRequest, dict] to construct a request @@ -424,6 +426,7 @@ def delete_instance( timeout=timeout, metadata=metadata, ) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation.name, zone=zone) @GoogleBaseHook.fallback_to_default_project_id def start_instance(self, zone: str, resource_id: str, project_id: str) -> None: @@ -538,7 +541,7 @@ def insert_instance_group_manager( :param metadata: Additional metadata that is provided to the method. """ client = self.get_compute_instance_group_managers_client() - client.insert( + operation = client.insert( # Calling method insert() on client to create the specified Instance Group Managers. # This method accepts request object as an argument and should be of type # Union[google.cloud.compute_v1.types.InsertInstanceGroupManagerRequest, dict] to construct @@ -562,6 +565,7 @@ def insert_instance_group_manager( timeout=timeout, metadata=metadata, ) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation.name, zone=zone) @GoogleBaseHook.fallback_to_default_project_id def get_instance_group_manager( @@ -591,7 +595,7 @@ def get_instance_group_manager( :rtype: object """ client = self.get_compute_instance_group_managers_client() - instance_group_manager_obj = client.get( + instance_group_manager = client.get( # Calling method get() on client to get the specified Instance Group Manager. # This method accepts request object as an argument and should be of type # Union[google.cloud.compute_v1.types.GetInstanceGroupManagerRequest, dict] to construct a @@ -612,7 +616,7 @@ def get_instance_group_manager( timeout=timeout, metadata=metadata, ) - return instance_group_manager_obj + return instance_group_manager @GoogleBaseHook.fallback_to_default_project_id def delete_instance_group_manager( @@ -645,7 +649,7 @@ def delete_instance_group_manager( :param metadata: Additional metadata that is provided to the method. """ client = self.get_compute_instance_group_managers_client() - client.delete( + operation = client.delete( # Calling method delete() on client to delete Instance Group Managers. 
# This method accepts request object as an argument and should be of type # Union[google.cloud.compute_v1.types.DeleteInstanceGroupManagerRequest, dict] to construct a @@ -669,6 +673,7 @@ def delete_instance_group_manager( timeout=timeout, metadata=metadata, ) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation.name, zone=zone) @GoogleBaseHook.fallback_to_default_project_id def patch_instance_group_manager( @@ -723,10 +728,12 @@ def _wait_for_operation_to_complete( :param operation_name: name of the operation :param zone: optional region of the request (might be None for global operations) + :param project_id: Google Cloud project ID where the Compute Engine Instance exists. :return: None """ service = self.get_conn() while True: + self.log.info("Waiting for Operation to complete...") if zone is None: operation_response = self._check_global_operation_status( service=service, @@ -745,6 +752,7 @@ def _wait_for_operation_to_complete( msg = operation_response.get("httpErrorMessage") # Extracting the errors list as string and trimming square braces error_msg = str(error.get("errors"))[1:-1] + raise AirflowException(f"{code} {msg}: " + error_msg) break time.sleep(TIME_TO_SLEEP_IN_SECONDS) diff --git a/airflow/providers/google/cloud/hooks/compute_ssh.py b/airflow/providers/google/cloud/hooks/compute_ssh.py index 7b474070aee37..d2b2971c6546e 100644 --- a/airflow/providers/google/cloud/hooks/compute_ssh.py +++ b/airflow/providers/google/cloud/hooks/compute_ssh.py @@ -28,12 +28,15 @@ from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook from airflow.providers.google.cloud.hooks.os_login import OSLoginHook from airflow.providers.ssh.hooks.ssh import SSHHook +from airflow.utils.types import NOTSET, ArgNotSet # Paramiko should be imported after airflow.providers.ssh. 
Then the import will fail with # cannot import "airflow.providers.ssh" and will be correctly discovered as optional feature # TODO:(potiuk) We should add test harness detecting such cases shortly import paramiko # isort:skip +CMD_TIMEOUT = 10 + class _GCloudAuthorizedSSHClient(paramiko.SSHClient): """SSH Client that maintains the context for gcloud authorization during the connection""" @@ -105,6 +108,7 @@ def __init__( use_iap_tunnel: bool = False, use_oslogin: bool = True, expire_time: int = 300, + cmd_timeout: int | ArgNotSet = NOTSET, **kwargs, ) -> None: if kwargs.get("delegate_to") is not None: @@ -124,6 +128,7 @@ def __init__( self.use_oslogin = use_oslogin self.expire_time = expire_time self.gcp_conn_id = gcp_conn_id + self.cmd_timeout = cmd_timeout self._conn: Any | None = None @cached_property @@ -175,6 +180,17 @@ def intify(key, value, default): self.expire_time, ) + if conn.extra is not None: + extra_options = conn.extra_dejson + if "cmd_timeout" in extra_options and self.cmd_timeout is NOTSET: + if extra_options["cmd_timeout"]: + self.cmd_timeout = int(extra_options["cmd_timeout"]) + else: + self.cmd_timeout = None + + if self.cmd_timeout is NOTSET: + self.cmd_timeout = CMD_TIMEOUT + def get_conn(self) -> paramiko.SSHClient: """Return SSH connection.""" self._load_connection_config() diff --git a/airflow/providers/google/cloud/hooks/mlengine.py b/airflow/providers/google/cloud/hooks/mlengine.py index 583c1fe0a4e63..83af18ecf4607 100644 --- a/airflow/providers/google/cloud/hooks/mlengine.py +++ b/airflow/providers/google/cloud/hooks/mlengine.py @@ -21,7 +21,7 @@ import logging import random import time -from typing import Callable, cast +from typing import Callable from aiohttp import ClientSession from gcloud.aio.auth import AioSession, Token @@ -187,7 +187,6 @@ def create_job_without_waiting_result( hook = self.get_conn() self._append_label(body) - request = hook.projects().jobs().create(parent=f"projects/{project_id}", body=body) job_id = body["jobId"] request.execute(num_retries=self.num_retries) @@ -391,6 +390,7 @@ def delete_version( belongs to. (templated) :param project_id: The Google Cloud project name to which MLEngine model belongs. + :param version_name: A name to use for the version being operated upon. (templated) :return: If the version was deleted successfully, returns the operation. Otherwise raises an error. 
""" @@ -538,9 +538,10 @@ def _append_label(self, model: dict) -> None: class MLEngineAsyncHook(GoogleBaseAsyncHook): - """Uses gcloud-aio library to retrieve Job details""" + """Class to get asynchronous hook for MLEngine""" sync_hook_class = MLEngineHook + scopes = ["https://www.googleapis.com/auth/cloud-platform"] def _check_fileds( self, @@ -553,16 +554,17 @@ def _check_fileds( raise AirflowException("An unique job id is required for Google MLEngine training job.") async def _get_link(self, url: str, session: Session): - s = AioSession(session) - t = Token(scopes=["https://www.googleapis.com/auth/cloud-platform"]) - headers = { - "Authorization": f"Bearer {t.get()}", - "accept": "application/json", - "accept-encoding": "gzip, deflate", - "user-agent": "(gzip)", - "x-goog-api-client": "gdcl/1.12.11 gl-python/3.8.15", - } - return await s.get(url=url, headers=headers) + async with Token(scopes=self.scopes) as token: + session_aio = AioSession(session) + headers = { + "Authorization": f"Bearer {await token.get()}", + } + try: + job = await session_aio.get(url=url, headers=headers) + except AirflowException: + pass # Because the job may not be visible in system yet + + return job async def get_job(self, job_id: str, session: Session, project_id: str | None = None): """Get the specified job resource by job ID and project ID.""" @@ -583,18 +585,17 @@ async def get_job_status( Exception means that Job finished with errors """ self._check_fileds(project_id=project_id, job_id=job_id) - - async with ClientSession() as s: + async with ClientSession() as session: try: - job_response = await self.get_job( - project_id=project_id, job_id=job_id, session=cast(Session, s) + job = await self.get_job( + project_id=project_id, job_id=job_id, session=session # type: ignore ) - json_response = await job_response.json() - self.log.info("Retrieving json_response: %s", json_response) + job = await job.json(content_type=None) + self.log.info("Retrieving json_response: %s", job) - if json_response["state"] in ["SUCCEEDED", "FAILED", "CANCELLED"]: + if job["state"] in ["SUCCEEDED", "FAILED", "CANCELLED"]: job_status = "success" - elif json_response["state"] in ["PREPARING", "RUNNING"]: + elif job["state"] in ["PREPARING", "RUNNING"]: job_status = "pending" except OSError: job_status = "pending" diff --git a/airflow/providers/google/cloud/hooks/natural_language.py b/airflow/providers/google/cloud/hooks/natural_language.py index db0e0528ec58e..ee6f25965cf6c 100644 --- a/airflow/providers/google/cloud/hooks/natural_language.py +++ b/airflow/providers/google/cloud/hooks/natural_language.py @@ -22,7 +22,7 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry -from google.cloud.language_v1 import LanguageServiceClient, enums +from google.cloud.language_v1 import EncodingType, LanguageServiceClient from google.cloud.language_v1.types import ( AnalyzeEntitiesResponse, AnalyzeEntitySentimentResponse, @@ -68,7 +68,7 @@ def __init__( gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, ) - self._conn = None + self._conn: LanguageServiceClient | None = None def get_conn(self) -> LanguageServiceClient: """ @@ -84,7 +84,7 @@ def get_conn(self) -> LanguageServiceClient: def analyze_entities( self, document: dict | Document, - encoding_type: enums.EncodingType | None = None, + encoding_type: EncodingType | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: Sequence[tuple[str, str]] = (), @@ -104,6 +104,8 @@ def 
analyze_entities( """ client = self.get_conn() + if isinstance(document, dict): + document = Document(document) return client.analyze_entities( document=document, encoding_type=encoding_type, retry=retry, timeout=timeout, metadata=metadata ) @@ -112,7 +114,7 @@ def analyze_entities( def analyze_entity_sentiment( self, document: dict | Document, - encoding_type: enums.EncodingType | None = None, + encoding_type: EncodingType | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: Sequence[tuple[str, str]] = (), @@ -132,6 +134,8 @@ def analyze_entity_sentiment( """ client = self.get_conn() + if isinstance(document, dict): + document = Document(document) return client.analyze_entity_sentiment( document=document, encoding_type=encoding_type, retry=retry, timeout=timeout, metadata=metadata ) @@ -140,7 +144,7 @@ def analyze_entity_sentiment( def analyze_sentiment( self, document: dict | Document, - encoding_type: enums.EncodingType | None = None, + encoding_type: EncodingType | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: Sequence[tuple[str, str]] = (), @@ -159,6 +163,8 @@ def analyze_sentiment( """ client = self.get_conn() + if isinstance(document, dict): + document = Document(document) return client.analyze_sentiment( document=document, encoding_type=encoding_type, retry=retry, timeout=timeout, metadata=metadata ) @@ -167,7 +173,7 @@ def analyze_sentiment( def analyze_syntax( self, document: dict | Document, - encoding_type: enums.EncodingType | None = None, + encoding_type: EncodingType | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: Sequence[tuple[str, str]] = (), @@ -187,6 +193,8 @@ def analyze_syntax( """ client = self.get_conn() + if isinstance(document, dict): + document = Document(document) return client.analyze_syntax( document=document, encoding_type=encoding_type, retry=retry, timeout=timeout, metadata=metadata ) @@ -196,7 +204,7 @@ def annotate_text( self, document: dict | Document, features: dict | AnnotateTextRequest.Features, - encoding_type: enums.EncodingType = None, + encoding_type: EncodingType | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: Sequence[tuple[str, str]] = (), @@ -218,6 +226,11 @@ def annotate_text( """ client = self.get_conn() + if isinstance(document, dict): + document = Document(document) + if isinstance(features, dict): + features = AnnotateTextRequest.Features(features) + return client.annotate_text( document=document, features=features, @@ -248,4 +261,6 @@ def classify_text( """ client = self.get_conn() + if isinstance(document, dict): + document = Document(document) return client.classify_text(document=document, retry=retry, timeout=timeout, metadata=metadata) diff --git a/airflow/providers/google/cloud/hooks/spanner.py b/airflow/providers/google/cloud/hooks/spanner.py index 60da28e9bc323..d83cfc598ad8d 100644 --- a/airflow/providers/google/cloud/hooks/spanner.py +++ b/airflow/providers/google/cloud/hooks/spanner.py @@ -55,7 +55,7 @@ def __init__( gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, ) - self._client = None + self._client: Client | None = None def _get_client(self, project_id: str) -> Client: """ @@ -75,7 +75,7 @@ def get_instance( self, instance_id: str, project_id: str, - ) -> Instance: + ) -> Instance | None: """ Gets information about a particular instance. 
diff --git a/airflow/providers/google/cloud/hooks/speech_to_text.py b/airflow/providers/google/cloud/hooks/speech_to_text.py index e21bb9677c390..4dc0568b2e7d0 100644 --- a/airflow/providers/google/cloud/hooks/speech_to_text.py +++ b/airflow/providers/google/cloud/hooks/speech_to_text.py @@ -59,7 +59,7 @@ def __init__( gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, ) - self._client = None + self._client: SpeechClient | None = None def get_conn(self) -> SpeechClient: """ @@ -92,6 +92,11 @@ def recognize_speech( Note that if retry is specified, the timeout applies to each individual attempt. """ client = self.get_conn() + if isinstance(config, dict): + config = RecognitionConfig(config) + if isinstance(audio, dict): + audio = RecognitionAudio(audio) + response = client.recognize(config=config, audio=audio, retry=retry, timeout=timeout) self.log.info("Recognised speech: %s", response) return response diff --git a/airflow/providers/google/cloud/hooks/text_to_speech.py b/airflow/providers/google/cloud/hooks/text_to_speech.py index 8cd44fc64e9b9..4e530f18c2e62 100644 --- a/airflow/providers/google/cloud/hooks/text_to_speech.py +++ b/airflow/providers/google/cloud/hooks/text_to_speech.py @@ -107,8 +107,15 @@ def synthesize_speech( https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.SynthesizeSpeechResponse """ client = self.get_conn() + + if isinstance(input_data, dict): + input_data = SynthesisInput(input_data) + if isinstance(voice, dict): + voice = VoiceSelectionParams(voice) + if isinstance(audio_config, dict): + audio_config = AudioConfig(audio_config) self.log.info("Synthesizing input: %s", input_data) return client.synthesize_speech( - input_=input_data, voice=voice, audio_config=audio_config, retry=retry, timeout=timeout + input=input_data, voice=voice, audio_config=audio_config, retry=retry, timeout=timeout ) diff --git a/airflow/providers/google/cloud/hooks/video_intelligence.py b/airflow/providers/google/cloud/hooks/video_intelligence.py index 103025c77f96b..fa1d2f56e5410 100644 --- a/airflow/providers/google/cloud/hooks/video_intelligence.py +++ b/airflow/providers/google/cloud/hooks/video_intelligence.py @@ -23,8 +23,11 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.operation import Operation from google.api_core.retry import Retry -from google.cloud.videointelligence_v1 import VideoIntelligenceServiceClient -from google.cloud.videointelligence_v1.types import VideoContext +from google.cloud.videointelligence_v1 import ( + Feature, + VideoContext, + VideoIntelligenceServiceClient, +) from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.hooks.base_google import GoogleBaseHook @@ -63,7 +66,7 @@ def __init__( gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, ) - self._conn = None + self._conn: VideoIntelligenceServiceClient | None = None def get_conn(self) -> VideoIntelligenceServiceClient: """Returns Gcp Video Intelligence Service client""" @@ -78,8 +81,8 @@ def annotate_video( self, input_uri: str | None = None, input_content: bytes | None = None, - features: list[VideoIntelligenceServiceClient.enums.Feature] | None = None, - video_context: dict | VideoContext = None, + features: Sequence[Feature] | None = None, + video_context: dict | VideoContext | None = None, output_uri: str | None = None, location: str | None = None, retry: Retry | _MethodDefault = DEFAULT, @@ -109,13 +112,16 @@ def 
annotate_video( :param metadata: Optional, Additional metadata that is provided to the method. """ client = self.get_conn() + return client.annotate_video( - input_uri=input_uri, - input_content=input_content, - features=features, - video_context=video_context, - output_uri=output_uri, - location_id=location, + request={ + "input_uri": input_uri, + "features": features, + "input_content": input_content, + "video_context": video_context, + "output_uri": output_uri, + "location_id": location, + }, retry=retry, timeout=timeout, metadata=metadata, diff --git a/airflow/providers/google/cloud/hooks/vision.py b/airflow/providers/google/cloud/hooks/vision.py index dace3619cff4f..20c0c81f75431 100644 --- a/airflow/providers/google/cloud/hooks/vision.py +++ b/airflow/providers/google/cloud/hooks/vision.py @@ -23,15 +23,16 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry -from google.cloud.vision_v1 import ImageAnnotatorClient, ProductSearchClient -from google.cloud.vision_v1.types import ( +from google.cloud.vision_v1 import ( AnnotateImageRequest, - FieldMask, Image, + ImageAnnotatorClient, Product, + ProductSearchClient, ProductSet, ReferenceImage, ) +from google.protobuf import field_mask_pb2 from google.protobuf.json_format import MessageToDict from airflow.compat.functools import cached_property @@ -118,6 +119,7 @@ class CloudVisionHook(GoogleBaseHook): keyword arguments rather than positional. """ + _client: ProductSearchClient | None product_name_determiner = NameDeterminer("Product", "product_id", ProductSearchClient.product_path) product_set_name_determiner = NameDeterminer( "ProductSet", "productset_id", ProductSearchClient.product_set_path @@ -168,7 +170,7 @@ def _check_for_error(response: dict) -> None: def create_product_set( self, location: str, - product_set: dict | ProductSet, + product_set: ProductSet | None, project_id: str = PROVIDE_PROJECT_ID, product_set_id: str | None = None, retry: Retry | _MethodDefault = DEFAULT, @@ -180,7 +182,7 @@ def create_product_set( :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductSetOperator` """ client = self.get_conn() - parent = ProductSearchClient.location_path(project_id, location) + parent = f"projects/{project_id}/locations/{location}" self.log.info("Creating a new ProductSet under the parent: %s", parent) response = client.create_product_set( parent=parent, @@ -220,7 +222,7 @@ def get_product_set( response = client.get_product_set(name=name, retry=retry, timeout=timeout, metadata=metadata) self.log.info("ProductSet retrieved.") self.log.debug("ProductSet retrieved:\n%s", response) - return MessageToDict(response) + return MessageToDict(response._pb) @GoogleBaseHook.fallback_to_default_project_id def update_product_set( @@ -229,7 +231,7 @@ def update_product_set( project_id: str = PROVIDE_PROJECT_ID, location: str | None = None, product_set_id: str | None = None, - update_mask: dict | FieldMask = None, + update_mask: dict | field_mask_pb2.FieldMask | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: Sequence[tuple[str, str]] = (), @@ -239,16 +241,23 @@ def update_product_set( :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductSetOperator` """ client = self.get_conn() + product_set = self.product_set_name_determiner.get_entity_with_name( product_set, product_set_id, location, project_id ) + if isinstance(product_set, dict): + product_set = ProductSet(product_set) 
self.log.info("Updating ProductSet: %s", product_set.name) response = client.update_product_set( - product_set=product_set, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata + product_set=product_set, + update_mask=update_mask, # type: ignore + retry=retry, + timeout=timeout, + metadata=metadata, ) self.log.info("ProductSet updated: %s", response.name if response else "") self.log.debug("ProductSet updated:\n%s", response) - return MessageToDict(response) + return MessageToDict(response._pb) @GoogleBaseHook.fallback_to_default_project_id def delete_product_set( @@ -286,8 +295,11 @@ def create_product( :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductOperator` """ client = self.get_conn() - parent = ProductSearchClient.location_path(project_id, location) + parent = f"projects/{project_id}/locations/{location}" self.log.info("Creating a new Product under the parent: %s", parent) + + if isinstance(product, dict): + product = Product(product) response = client.create_product( parent=parent, product=product, @@ -326,7 +338,7 @@ def get_product( response = client.get_product(name=name, retry=retry, timeout=timeout, metadata=metadata) self.log.info("Product retrieved.") self.log.debug("Product retrieved:\n%s", response) - return MessageToDict(response) + return MessageToDict(response._pb) @GoogleBaseHook.fallback_to_default_project_id def update_product( @@ -335,7 +347,7 @@ def update_product( project_id: str = PROVIDE_PROJECT_ID, location: str | None = None, product_id: str | None = None, - update_mask: dict[str, FieldMask] | None = None, + update_mask: dict | field_mask_pb2.FieldMask | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: Sequence[tuple[str, str]] = (), @@ -345,14 +357,21 @@ def update_product( :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductOperator` """ client = self.get_conn() + product = self.product_name_determiner.get_entity_with_name(product, product_id, location, project_id) + if isinstance(product, dict): + product = Product(product) self.log.info("Updating ProductSet: %s", product.name) response = client.update_product( - product=product, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata + product=product, + update_mask=update_mask, # type: ignore + retry=retry, + timeout=timeout, + metadata=metadata, ) self.log.info("Product updated: %s", response.name if response else "") self.log.debug("Product updated:\n%s", response) - return MessageToDict(response) + return MessageToDict(response._pb) @GoogleBaseHook.fallback_to_default_project_id def delete_product( @@ -394,6 +413,8 @@ def create_reference_image( self.log.info("Creating ReferenceImage") parent = ProductSearchClient.product_path(project=project_id, location=location, product=product_id) + if isinstance(reference_image, dict): + reference_image = ReferenceImage(reference_image) response = client.create_reference_image( parent=parent, reference_image=reference_image, @@ -451,7 +472,7 @@ def add_product_to_product_set( product_set_id: str, product_id: str, project_id: str, - location: str | None = None, + location: str, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: Sequence[tuple[str, str]] = (), @@ -479,7 +500,7 @@ def remove_product_from_product_set( product_set_id: str, product_id: str, project_id: str, - location: str | None = None, + location: str, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: 
Sequence[tuple[str, str]] = (), @@ -519,7 +540,7 @@ def annotate_image( self.log.info("Image annotated") - return MessageToDict(response) + return MessageToDict(response._pb) @GoogleBaseHook.quota_retry() def batch_annotate_images( @@ -536,11 +557,12 @@ def batch_annotate_images( self.log.info("Annotating images") + requests = list(map(AnnotateImageRequest, requests)) response = client.batch_annotate_images(requests=requests, retry=retry, timeout=timeout) self.log.info("Images annotated") - return MessageToDict(response) + return MessageToDict(response._pb) @GoogleBaseHook.quota_retry() def text_detection( @@ -565,7 +587,7 @@ def text_detection( response = client.text_detection( image=image, max_results=max_results, retry=retry, timeout=timeout, **additional_properties ) - response = MessageToDict(response) + response = MessageToDict(response._pb) self._check_for_error(response) self.log.info("Text detection finished") @@ -595,7 +617,7 @@ def document_text_detection( response = client.document_text_detection( image=image, max_results=max_results, retry=retry, timeout=timeout, **additional_properties ) - response = MessageToDict(response) + response = MessageToDict(response._pb) self._check_for_error(response) self.log.info("Document text detection finished") @@ -625,7 +647,7 @@ def label_detection( response = client.label_detection( image=image, max_results=max_results, retry=retry, timeout=timeout, **additional_properties ) - response = MessageToDict(response) + response = MessageToDict(response._pb) self._check_for_error(response) self.log.info("Labels detection finished") @@ -655,7 +677,7 @@ def safe_search_detection( response = client.safe_search_detection( image=image, max_results=max_results, retry=retry, timeout=timeout, **additional_properties ) - response = MessageToDict(response) + response = MessageToDict(response._pb) self._check_for_error(response) self.log.info("Safe search detection finished") diff --git a/airflow/providers/google/cloud/log/stackdriver_task_handler.py b/airflow/providers/google/cloud/log/stackdriver_task_handler.py index 5190fbad760b5..ca3ad9bad35e1 100644 --- a/airflow/providers/google/cloud/log/stackdriver_task_handler.py +++ b/airflow/providers/google/cloud/log/stackdriver_task_handler.py @@ -315,10 +315,10 @@ def _read_single_logs_page(self, log_filter: str, page_token: str | None = None) ) response = self._logging_service_client.list_log_entries(request=request) page: ListLogEntriesResponse = next(response.pages) - messages = [] + messages: list[str] = [] for entry in page.entries: if "message" in (entry.json_payload or {}): - messages.append(entry.json_payload["message"]) + messages.append(entry.json_payload["message"]) # type: ignore elif entry.text_payload: messages.append(entry.text_payload) return "\n".join(messages), page.next_page_token diff --git a/airflow/providers/google/cloud/operators/automl.py b/airflow/providers/google/cloud/operators/automl.py index 4c53702ea5ac1..0503e6bffefef 100644 --- a/airflow/providers/google/cloud/operators/automl.py +++ b/airflow/providers/google/cloud/operators/automl.py @@ -116,7 +116,7 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - self.log.info("Creating model.") + self.log.info("Creating model %s...", self.model["display_name"]) operation = hook.create_model( model=self.model, location=self.location, @@ -128,9 +128,10 @@ def execute(self, context: Context): project_id = self.project_id or hook.project_id if project_id: 
AutoMLModelTrainLink.persist(context=context, task_instance=self, project_id=project_id) - result = Model.to_dict(operation.result()) + operation_result = hook.wait_for_operation(timeout=self.timeout, operation=operation) + result = Model.to_dict(operation_result) model_id = hook.extract_object_id(result) - self.log.info("Model created: %s", model_id) + self.log.info("Model is created, model_id: %s", model_id) self.xcom_push(context, key="model_id", value=model_id) if project_id: @@ -332,8 +333,9 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - result = BatchPredictResult.to_dict(operation.result()) - self.log.info("Batch prediction ready.") + operation_result = hook.wait_for_operation(timeout=self.timeout, operation=operation) + result = BatchPredictResult.to_dict(operation_result) + self.log.info("Batch prediction is ready.") project_id = self.project_id or hook.project_id if project_id: AutoMLModelPredictLink.persist( @@ -412,7 +414,7 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - self.log.info("Creating dataset") + self.log.info("Creating dataset %s...", self.dataset) result = hook.create_dataset( dataset=self.dataset, location=self.location, @@ -508,7 +510,7 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - self.log.info("Importing dataset") + self.log.info("Importing data to dataset...") operation = hook.import_data( dataset_id=self.dataset_id, input_config=self.input_config, @@ -518,8 +520,8 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - operation.result() - self.log.info("Import completed") + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Import is completed") project_id = self.project_id or hook.project_id if project_id: AutoMLDatasetLink.persist( @@ -887,7 +889,8 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - operation.result() + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Deletion is completed") class AutoMLDeployModelOperator(GoogleCloudBaseOperator): @@ -976,8 +979,8 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - operation.result() - self.log.info("Model deployed.") + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Model was deployed successfully.") class AutoMLTablesListTableSpecsOperator(GoogleCloudBaseOperator): diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 67798ee77f2c8..8cf3489ccf2ae 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -27,6 +27,7 @@ from google.api_core.exceptions import Conflict from google.api_core.retry import Retry from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, QueryJob +from google.cloud.bigquery.table import RowIterator from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException from airflow.models import BaseOperator, BaseOperatorLink @@ -52,6 +53,8 @@ ) if TYPE_CHECKING: + from google.cloud.bigquery import UnknownJob + from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.context import Context @@ -926,6 +929,10 @@ def execute(self, context: Context): 
project_id=self.project_id, ) + if isinstance(rows, RowIterator): + raise TypeError( + "BigQueryHook.list_rows() returns iterator when return_iterator is False (default)" + ) self.log.info("Total extracted rows: %s", len(rows)) if self.as_dict: @@ -1952,12 +1959,12 @@ def execute(self, context: Context): self.log.info("Start getting dataset: %s:%s", self.project_id, self.dataset_id) dataset = bq_hook.get_dataset(dataset_id=self.dataset_id, project_id=self.project_id) - dataset = dataset.to_api_repr() + dataset_api_repr = dataset.to_api_repr() BigQueryDatasetLink.persist( context=context, task_instance=self, - dataset_id=dataset["datasetReference"]["datasetId"], - project_id=dataset["datasetReference"]["projectId"], + dataset_id=dataset_api_repr["datasetReference"]["datasetId"], + project_id=dataset_api_repr["datasetReference"]["projectId"], ) return dataset @@ -2249,12 +2256,12 @@ def execute(self, context: Context): fields=fields, ) - dataset = dataset.to_api_repr() + dataset_api_repr = dataset.to_api_repr() BigQueryDatasetLink.persist( context=context, task_instance=self, - dataset_id=dataset["datasetReference"]["datasetId"], - project_id=dataset["datasetReference"]["projectId"], + dataset_id=dataset_api_repr["datasetReference"]["datasetId"], + project_id=dataset_api_repr["datasetReference"]["projectId"], ) return dataset @@ -2622,7 +2629,7 @@ def _submit_job( ) @staticmethod - def _handle_job_error(job: BigQueryJob) -> None: + def _handle_job_error(job: BigQueryJob | UnknownJob) -> None: if job.error_result: raise AirflowException(f"BigQuery job {job.job_id} failed: {job.error_result}") @@ -2644,7 +2651,7 @@ def execute(self, context: Any): try: self.log.info("Executing: %s'", self.configuration) - job = self._submit_job(hook, job_id) + job: BigQueryJob | UnknownJob = self._submit_job(hook, job_id) except Conflict: # If the job already exists retrieve it job = hook.get_job( diff --git a/airflow/providers/google/cloud/operators/compute.py b/airflow/providers/google/cloud/operators/compute.py index 5abc72b8d7577..db8a98e443484 100644 --- a/airflow/providers/google/cloud/operators/compute.py +++ b/airflow/providers/google/cloud/operators/compute.py @@ -944,6 +944,7 @@ def execute(self, context: Context) -> dict: ) self._validate_all_body_fields() self.check_body_fields() + self._field_sanitizer.sanitize(self.body) try: # Idempotence check (sort of) - we want to check if the new Template # is already created and if is, then we assume it was created by previous run @@ -1093,7 +1094,7 @@ def execute(self, context: Context) -> None: project_id=self.project_id, request_id=self.request_id, ) - self.log.info("Successfully deleted Instance template") + self.log.info("Successfully deleted Instance template %s", self.resource_id) except exceptions.NotFound as e: # Expecting 404 Error in case if Instance template doesn't exist. 
if e.code == 404: @@ -1246,7 +1247,7 @@ def execute(self, context: Context) -> dict: new_body = merge(new_body, self.body_patch) self.log.info("Calling insert instance template with updated body: %s", new_body) hook.insert_instance_template(body=new_body, request_id=self.request_id, project_id=self.project_id) - instance_template = hook.get_instance_template( + new_instance_tmp = hook.get_instance_template( resource_id=self.body_patch["name"], project_id=self.project_id ) ComputeInstanceTemplateDetailsLink.persist( @@ -1255,7 +1256,7 @@ def execute(self, context: Context) -> dict: resource_id=self.body_patch["name"], project_id=self.project_id or hook.project_id, ) - return InstanceTemplate.to_dict(instance_template) + return InstanceTemplate.to_dict(new_instance_tmp) class ComputeEngineInstanceGroupUpdateManagerTemplateOperator(ComputeEngineBaseOperator): diff --git a/airflow/providers/google/cloud/operators/natural_language.py b/airflow/providers/google/cloud/operators/natural_language.py index 257c7ea905aa2..2665f9ced60ab 100644 --- a/airflow/providers/google/cloud/operators/natural_language.py +++ b/airflow/providers/google/cloud/operators/natural_language.py @@ -22,8 +22,7 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry -from google.cloud.language_v1 import enums -from google.cloud.language_v1.types import Document +from google.cloud.language_v1.types import Document, EncodingType from google.protobuf.json_format import MessageToDict from airflow.providers.google.cloud.hooks.natural_language import CloudNaturalLanguageHook @@ -76,7 +75,7 @@ def __init__( self, *, document: dict | Document, - encoding_type: enums.EncodingType | None = None, + encoding_type: EncodingType | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: MetaData = (), @@ -105,7 +104,7 @@ def execute(self, context: Context): ) self.log.info("Finished analyzing entities") - return MessageToDict(response) + return MessageToDict(response._pb) class CloudNaturalLanguageAnalyzeEntitySentimentOperator(GoogleCloudBaseOperator): @@ -149,7 +148,7 @@ def __init__( self, *, document: dict | Document, - encoding_type: enums.EncodingType | None = None, + encoding_type: EncodingType | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: MetaData = (), @@ -182,7 +181,7 @@ def execute(self, context: Context): ) self.log.info("Finished entity sentiment analyze") - return MessageToDict(response) + return MessageToDict(response._pb) class CloudNaturalLanguageAnalyzeSentimentOperator(GoogleCloudBaseOperator): @@ -225,7 +224,7 @@ def __init__( self, *, document: dict | Document, - encoding_type: enums.EncodingType | None = None, + encoding_type: EncodingType | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: MetaData = (), @@ -254,7 +253,7 @@ def execute(self, context: Context): ) self.log.info("Finished sentiment analyze") - return MessageToDict(response) + return MessageToDict(response._pb) class CloudNaturalLanguageClassifyTextOperator(GoogleCloudBaseOperator): @@ -322,4 +321,4 @@ def execute(self, context: Context): ) self.log.info("Finished text classify") - return MessageToDict(response) + return MessageToDict(response._pb) diff --git a/airflow/providers/google/cloud/operators/speech_to_text.py b/airflow/providers/google/cloud/operators/speech_to_text.py index 6d5a97343b537..8d4d92b30b50b 100644 --- 
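Editor's note on the recurring MessageToDict(response._pb) change in the natural-language, speech and translate operators: the regenerated clients return proto-plus wrapper objects, while google.protobuf.json_format.MessageToDict only accepts raw protobuf messages, which the wrapper exposes via ._pb. A minimal illustration with the Natural Language client (the text is arbitrary):

from google.cloud import language_v1
from google.protobuf.json_format import MessageToDict

client = language_v1.LanguageServiceClient()
document = language_v1.Document(
    content="Apache Airflow orchestrates workflows.",
    type_=language_v1.Document.Type.PLAIN_TEXT,
)
response = client.analyze_sentiment(document=document)
# The response is a proto-plus wrapper; pass the underlying pb2 message.
as_dict = MessageToDict(response._pb)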
a/airflow/providers/google/cloud/operators/speech_to_text.py +++ b/airflow/providers/google/cloud/operators/speech_to_text.py @@ -122,4 +122,4 @@ def execute(self, context: Context): response = hook.recognize_speech( config=self.config, audio=self.audio, retry=self.retry, timeout=self.timeout ) - return MessageToDict(response) + return MessageToDict(response._pb) diff --git a/airflow/providers/google/cloud/operators/translate_speech.py b/airflow/providers/google/cloud/operators/translate_speech.py index 0e9881eac292a..9295d77f1ddc0 100644 --- a/airflow/providers/google/cloud/operators/translate_speech.py +++ b/airflow/providers/google/cloud/operators/translate_speech.py @@ -151,7 +151,7 @@ def execute(self, context: Context) -> dict: ) recognize_result = speech_to_text_hook.recognize_speech(config=self.config, audio=self.audio) - recognize_dict = MessageToDict(recognize_result) + recognize_dict = MessageToDict(recognize_result._pb) self.log.info("Recognition operation finished") diff --git a/airflow/providers/google/cloud/operators/video_intelligence.py b/airflow/providers/google/cloud/operators/video_intelligence.py index 7b8d57baec8f5..2069513d063da 100644 --- a/airflow/providers/google/cloud/operators/video_intelligence.py +++ b/airflow/providers/google/cloud/operators/video_intelligence.py @@ -22,8 +22,7 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry -from google.cloud.videointelligence_v1 import enums -from google.cloud.videointelligence_v1.types import VideoContext +from google.cloud.videointelligence_v1 import Feature, VideoContext from google.protobuf.json_format import MessageToDict from airflow.providers.google.cloud.hooks.video_intelligence import CloudVideoIntelligenceHook @@ -84,7 +83,7 @@ def __init__( input_uri: str, input_content: bytes | None = None, output_uri: str | None = None, - video_context: dict | VideoContext = None, + video_context: dict | VideoContext | None = None, location: str | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, @@ -114,11 +113,11 @@ def execute(self, context: Context): video_context=self.video_context, location=self.location, retry=self.retry, - features=[enums.Feature.LABEL_DETECTION], + features=[Feature.LABEL_DETECTION], timeout=self.timeout, ) self.log.info("Processing video for label annotations") - result = MessageToDict(operation.result()) + result = MessageToDict(operation.result()._pb) self.log.info("Finished processing.") return result @@ -174,7 +173,7 @@ def __init__( input_uri: str, output_uri: str | None = None, input_content: bytes | None = None, - video_context: dict | VideoContext = None, + video_context: dict | VideoContext | None = None, location: str | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, @@ -204,11 +203,11 @@ def execute(self, context: Context): video_context=self.video_context, location=self.location, retry=self.retry, - features=[enums.Feature.EXPLICIT_CONTENT_DETECTION], + features=[Feature.EXPLICIT_CONTENT_DETECTION], timeout=self.timeout, ) self.log.info("Processing video for explicit content annotations") - result = MessageToDict(operation.result()) + result = MessageToDict(operation.result()._pb) self.log.info("Finished processing.") return result @@ -264,7 +263,7 @@ def __init__( input_uri: str, output_uri: str | None = None, input_content: bytes | None = None, - video_context: dict | VideoContext = None, + video_context: dict | VideoContext | None = None, location: str | 
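Editor's note: with google-cloud-videointelligence 2.x the enums module is gone and Feature is imported from the package root, as the hunks above and below show. A hedged sketch of the new calling style (bucket path and timeout are placeholders):

from google.cloud import videointelligence_v1
from google.cloud.videointelligence_v1 import Feature
from google.protobuf.json_format import MessageToDict

client = videointelligence_v1.VideoIntelligenceServiceClient()
operation = client.annotate_video(
    request={
        "input_uri": "gs://my-bucket/my-video.mp4",
        "features": [Feature.LABEL_DETECTION],
    }
)
# The operation result is also a proto-plus message, hence ._pb again.
result = MessageToDict(operation.result(timeout=600)._pb)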
None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, @@ -294,10 +293,10 @@ def execute(self, context: Context): video_context=self.video_context, location=self.location, retry=self.retry, - features=[enums.Feature.SHOT_CHANGE_DETECTION], + features=[Feature.SHOT_CHANGE_DETECTION], timeout=self.timeout, ) self.log.info("Processing video for video shots annotations") - result = MessageToDict(operation.result()) + result = MessageToDict(operation.result()._pb) self.log.info("Finished processing.") return result diff --git a/airflow/providers/google/cloud/operators/vision.py b/airflow/providers/google/cloud/operators/vision.py index f2c4443a12c76..76eabe242060e 100644 --- a/airflow/providers/google/cloud/operators/vision.py +++ b/airflow/providers/google/cloud/operators/vision.py @@ -24,14 +24,14 @@ from google.api_core.exceptions import AlreadyExists from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry -from google.cloud.vision_v1.types import ( +from google.cloud.vision_v1 import ( AnnotateImageRequest, - FieldMask, Image, Product, ProductSet, ReferenceImage, ) +from google.protobuf.field_mask_pb2 import FieldMask # type: ignore from airflow.providers.google.cloud.hooks.vision import CloudVisionHook from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator @@ -278,7 +278,7 @@ def __init__( location: str | None = None, product_set_id: str | None = None, project_id: str | None = None, - update_mask: dict | FieldMask = None, + update_mask: dict | FieldMask | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: MetaData = (), @@ -303,6 +303,9 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) + + if isinstance(self.product_set, dict): + self.product_set = ProductSet(self.product_set) return hook.update_product_set( location=self.location, product_set_id=self.product_set_id, @@ -650,7 +653,7 @@ def __init__( location: str | None = None, product_id: str | None = None, project_id: str | None = None, - update_mask: dict | FieldMask = None, + update_mask: dict | FieldMask | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: MetaData = (), @@ -675,12 +678,13 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) + return hook.update_product( product=self.product, location=self.location, product_id=self.product_id, project_id=self.project_id, - update_mask=self.update_mask, + update_mask=self.update_mask, # type: ignore retry=self.retry, timeout=self.timeout, metadata=self.metadata, @@ -923,6 +927,9 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) + + if isinstance(self.reference_image, dict): + self.reference_image = ReferenceImage(self.reference_image) return hook.create_reference_image( location=self.location, product_id=self.product_id, diff --git a/airflow/providers/google/cloud/operators/workflows.py b/airflow/providers/google/cloud/operators/workflows.py index 065072967c9f9..ad7aaf6a0416d 100644 --- a/airflow/providers/google/cloud/operators/workflows.py +++ b/airflow/providers/google/cloud/operators/workflows.py @@ -664,7 +664,11 @@ def execute(self, context: Context): project_id=self.project_id or hook.project_id, ) - return [Execution.to_dict(e) for e in execution_iter if e.start_time > 
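Editor's note on the vision changes above: FieldMask is now imported from protobuf itself, and plain dicts passed by DAG authors are promoted to proto-plus messages before the hook call. A small sketch of both patterns (field values are illustrative):

from google.cloud.vision_v1 import ProductSet
from google.protobuf.field_mask_pb2 import FieldMask

product_set = {"display_name": "shoes"}
# proto-plus constructors accept a plain mapping, so operator arguments that
# arrive as dicts can be converted up front.
if isinstance(product_set, dict):
    product_set = ProductSet(product_set)

# FieldMask now comes from protobuf rather than google.cloud.vision_v1.types.
update_mask = FieldMask(paths=["display_name"])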
self.start_date_filter] + return [ + Execution.to_dict(e) + for e in execution_iter + if e.start_time.ToDatetime(tzinfo=pytz.UTC) > self.start_date_filter + ] class WorkflowsGetExecutionOperator(GoogleCloudBaseOperator): diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index 68be7e215b8bb..7ec62db9bfbd0 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -22,7 +22,7 @@ from google.api_core.exceptions import Conflict from google.api_core.retry import Retry -from google.cloud.bigquery import DEFAULT_RETRY, ExtractJob +from google.cloud.bigquery import DEFAULT_RETRY, UnknownJob from airflow import AirflowException from airflow.models import BaseOperator @@ -138,7 +138,7 @@ def __init__( self.deferrable = deferrable @staticmethod - def _handle_job_error(job: ExtractJob) -> None: + def _handle_job_error(job: BigQueryJob | UnknownJob) -> None: if job.error_result: raise AirflowException(f"BigQuery job {job.job_id} failed: {job.error_result}") @@ -216,7 +216,9 @@ def execute(self, context: Context): try: self.log.info("Executing: %s", configuration) - job: ExtractJob = self._submit_job(hook=hook, job_id=job_id, configuration=configuration) + job: BigQueryJob | UnknownJob = self._submit_job( + hook=hook, job_id=job_id, configuration=configuration + ) except Conflict: # If the job already exists retrieve it job = hook.get_job( diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py index f009325facfbf..8a66c1aac890c 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -31,6 +31,7 @@ LoadJob, QueryJob, SchemaField, + UnknownJob, ) from google.cloud.bigquery.table import EncryptionConfiguration, Table, TableReference @@ -307,7 +308,7 @@ def _submit_job( ) @staticmethod - def _handle_job_error(job: BigQueryJob) -> None: + def _handle_job_error(job: BigQueryJob | UnknownJob) -> None: if job.error_result: raise AirflowException(f"BigQuery job {job.job_id} failed: {job.error_result}") @@ -374,7 +375,7 @@ def execute(self, context: Context): try: self.log.info("Executing: %s", self.configuration) - job = self._submit_job(self.hook, job_id) + job: BigQueryJob | UnknownJob = self._submit_job(self.hook, job_id) except Conflict: # If the job already exists retrieve it job = self.hook.get_job( diff --git a/airflow/providers/google/cloud/utils/bigquery_get_data.py b/airflow/providers/google/cloud/utils/bigquery_get_data.py index 39ab1ef35b989..8c5c38add7544 100644 --- a/airflow/providers/google/cloud/utils/bigquery_get_data.py +++ b/airflow/providers/google/cloud/utils/bigquery_get_data.py @@ -19,7 +19,7 @@ from collections.abc import Iterator from logging import Logger -from google.cloud.bigquery.table import Row +from google.cloud.bigquery.table import Row, RowIterator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook @@ -37,7 +37,7 @@ def bigquery_get_data( i = 0 while True: - rows: list[Row] = big_query_hook.list_rows( + rows: list[Row] | RowIterator = big_query_hook.list_rows( dataset_id=dataset_id, table_id=table_id, max_results=batch_size, @@ -45,6 +45,9 @@ def bigquery_get_data( start_index=i * batch_size, ) + if isinstance(rows, RowIterator): + raise TypeError("BigQueryHook.list_rows() returns iterator when return_iterator=False (default)") + 
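Editor's note: the workflows filter above now converts each execution's protobuf Timestamp to an aware datetime before comparing it with start_date_filter. A minimal reproduction of that comparison (requires a protobuf release where ToDatetime accepts tzinfo):

import datetime

import pytz
from google.protobuf.timestamp_pb2 import Timestamp

start_date_filter = datetime.datetime(2023, 1, 1, tzinfo=pytz.UTC)

ts = Timestamp()
ts.GetCurrentTime()
# ToDatetime() returns a naive datetime by default; passing tzinfo keeps the
# comparison against the timezone-aware filter valid.
is_after_filter = ts.ToDatetime(tzinfo=pytz.UTC) > start_date_filter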
if len(rows) == 0: logger.info("Job Finished") return diff --git a/airflow/providers/google/cloud/utils/mlengine_operator_utils.py b/airflow/providers/google/cloud/utils/mlengine_operator_utils.py index 1d6dc5d437ca9..e4845bd7f6ba5 100644 --- a/airflow/providers/google/cloud/utils/mlengine_operator_utils.py +++ b/airflow/providers/google/cloud/utils/mlengine_operator_utils.py @@ -29,6 +29,7 @@ from airflow import DAG from airflow.exceptions import AirflowException from airflow.operators.python import PythonOperator +from airflow.providers.apache.beam.hooks.beam import BeamRunnerType from airflow.providers.apache.beam.operators.beam import BeamRunPythonPipelineOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.cloud.operators.mlengine import MLEngineStartBatchPredictionJobOperator @@ -227,6 +228,7 @@ def validate_err_and_count(summary): metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True)).decode() evaluate_summary = BeamRunPythonPipelineOperator( task_id=(task_prefix + "-summary"), + runner=BeamRunnerType.DataflowRunner, py_file=os.path.join(os.path.dirname(__file__), "mlengine_prediction_summary.py"), default_pipeline_options=dataflow_options, pipeline_options={ @@ -235,7 +237,7 @@ def validate_err_and_count(summary): "metric_keys": ",".join(metric_keys), }, py_interpreter=py_interpreter, - py_requirements=["apache-beam[gcp]>=2.14.0"], + py_requirements=["apache-beam[gcp]>=2.46.0"], dag=dag, ) evaluate_summary.set_upstream(evaluate_prediction) diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 35b1a15e33f20..6a81e61b09662 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -68,73 +68,58 @@ versions: dependencies: - apache-airflow>=2.4.0 - apache-airflow-providers-common-sql>=1.3.1 - # Google has very clear rules on what dependencies should be used. All the limits below - # follow strict guidelines of Google Libraries as quoted here: - # While this issue is open, dependents of google-api-core, google-cloud-core. and google-auth - # should preserve >1, <3 pins on these packages. - # https://github.com/googleapis/google-cloud-python/issues/10566 - # Some of Google Packages are limited to <2.0.0 because 2.0.0 releases of the libraries - # Introduced breaking changes across the board. 
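Editor's note on the mlengine_operator_utils hunk above: the summary step now pins the Beam runner explicitly, since BeamRunPythonPipelineOperator otherwise defaults to DirectRunner. A hedged usage sketch (task_id, bucket path and interpreter are placeholders):

from airflow.providers.apache.beam.hooks.beam import BeamRunnerType
from airflow.providers.apache.beam.operators.beam import BeamRunPythonPipelineOperator

evaluate_summary = BeamRunPythonPipelineOperator(
    task_id="evaluate-summary",
    runner=BeamRunnerType.DataflowRunner,
    py_file="gs://my-bucket/mlengine_prediction_summary.py",
    py_interpreter="python3",
    py_requirements=["apache-beam[gcp]>=2.46.0"],
)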
Those libraries should be upgraded soon - # TODO: Upgrade all Google libraries that are limited to <2.0.0 - - PyOpenSSL - asgiref>=3.5.2 - gcloud-aio-auth>=4.0.0,<5.0.0 - gcloud-aio-bigquery>=6.1.2 - gcloud-aio-storage - # needed by vendored-in google-ads - - googleapis-common-protos<2.0.0,>=1.5.8 - - google-api-core==2.8.2 - - google-auth-oauthlib<1.0.0,>=0.3.0 - - grpcio<2.0.0,>=1.38.1 - - grpcio-status<2.0.0,>=1.38.1 - - PyYAML<7.0,>=5.1 - - proto-plus==1.19.6 - - protobuf!=3.18.*,!=3.19.*,<=3.20.0,>=3.12.0 - # Temporary commented out until we have google-ads vendored-in - # - google-ads>=15.1.1 - # - google-api-core>=2.7.0,<3.0.0 - - google-api-python-client>=1.6.0,<2.0.0 + - google-ads>=20.0.0 + - google-api-core>=2.11.0 + - google-api-python-client>=1.6.0 - google-auth>=1.0.0 - google-auth-httplib2>=0.0.1 - - google-cloud-aiplatform>=1.13.1,<2.0.0 - - google-cloud-automl>=2.1.0 - - google-cloud-bigquery-datatransfer>=3.0.0 - - google-cloud-bigtable>=2.0.0,<3.0.0 - - google-cloud-build>=3.0.0 - - google-cloud-compute>=0.1.0,<2.0.0 - - google-cloud-container>=2.2.0,<3.0.0 - - google-cloud-dataflow-client>=0.5.2 - - google-cloud-dataform>=0.2.0 - - google-cloud-datacatalog>=3.0.0 - - google-cloud-dataplex>=0.1.0 - - google-cloud-dataproc>=3.1.0 - - google-cloud-dataproc-metastore>=1.2.0,<2.0.0 - - google-cloud-dlp>=3.0.0 - - google-cloud-kms>=2.0.0 - - google-cloud-language>=1.1.1,<2.0.0 - - google-cloud-logging>=2.1.1 - - google-cloud-memcache>=0.2.0 - - google-cloud-monitoring>=2.0.0 - - google-cloud-os-login>=2.0.0 - - google-cloud-orchestration-airflow>=1.0.0,<2.0.0 - - google-cloud-pubsub>=2.0.0 - - google-cloud-redis>=2.0.0 - - google-cloud-secret-manager>=0.2.0,<2.0.0 - - google-cloud-spanner>=1.10.0,<2.0.0 - - google-cloud-speech>=0.36.3,<2.0.0 - - google-cloud-storage>=1.30,<3.0.0 - - google-cloud-tasks>=2.0.0 - - google-cloud-texttospeech>=0.4.0,<2.0.0 - - google-cloud-translate>=1.5.0,<2.0.0 - - google-cloud-videointelligence>=1.7.0,<2.0.0 - - google-cloud-vision>=0.35.2,<2.0.0 - - google-cloud-workflows>=0.1.0,<2.0.0 + - google-cloud-aiplatform>=1.22.1 + - google-cloud-automl>=2.11.0 + - google-cloud-bigquery-datatransfer>=3.11.0 + - google-cloud-bigtable>=2.17.0 + - google-cloud-build>=3.13.0 + - google-cloud-compute>=1.10.0 + - google-cloud-container>=2.17.4 + - google-cloud-datacatalog>=3.11.1 + - google-cloud-dataflow-client>=0.8.2 + - google-cloud-dataform>=0.5.0 + - google-cloud-dataplex>=1.4.2 + - google-cloud-dataproc>=5.4.0 + - google-cloud-dataproc-metastore>=1.10.0 + - google-cloud-dlp>=3.12.0 + - google-cloud-kms>=2.15.0 + - google-cloud-language>=2.9.0 + - google-cloud-logging>=3.5.0 + - google-cloud-memcache>=1.7.0 + - google-cloud-monitoring>=2.14.1 + - google-cloud-orchestration-airflow>=1.7.0 + - google-cloud-os-login>=2.9.1 + - google-cloud-pubsub>=2.15.0 + - google-cloud-redis>=2.12.0 + - google-cloud-secret-manager>=2.16.0 + - google-cloud-spanner>=3.11.1 + - google-cloud-speech>=2.18.0 + - google-cloud-storage>=2.7.0 + - google-cloud-tasks>=2.13.0 + - google-cloud-texttospeech>=2.14.1 + - google-cloud-translate>=3.11.0 + - google-cloud-videointelligence>=2.11.0 + - google-cloud-vision>=3.4.0 + - google-cloud-workflows>=1.10.0 - grpcio-gcp>=0.2.2 - httpx - json-merge-patch>=0.2 - looker-sdk>=22.2.0 - pandas-gbq - pandas>=0.17.1 + # A transient dependency of google-cloud-bigquery-datatransfer, but we + # further constrain it since older versions are buggy. 
+ - proto-plus>=1.19.6 + - PyOpenSSL - sqlalchemy-bigquery>=1.2.1 integrations: diff --git a/docs/apache-airflow-providers-google/operators/cloud/compute_ssh.rst b/docs/apache-airflow-providers-google/operators/cloud/compute_ssh.rst index 7ccf3cb23a543..8e50f01566aff 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/compute_ssh.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/compute_ssh.rst @@ -45,20 +45,20 @@ Please note that the target instance must allow tcp traffic on port 22. Below is the code to create the operator: -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_compute_ssh.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/compute/example_compute_ssh.py :language: python :dedent: 4 - :start-after: [START howto_execute_command_on_remote1] - :end-before: [END howto_execute_command_on_remote1] + :start-after: [START howto_execute_command_on_remote_1] + :end-before: [END howto_execute_command_on_remote_1] You can also create the hook without project id - project id will be retrieved from the Google credentials used: -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_compute_ssh.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/compute/example_compute_ssh.py :language: python :dedent: 4 - :start-after: [START howto_execute_command_on_remote2] - :end-before: [END howto_execute_command_on_remote2] + :start-after: [START howto_execute_command_on_remote_2] + :end-before: [END howto_execute_command_on_remote_2] More information """""""""""""""" diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst index f7775c6bfe870..5cc75b7f90ccf 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst @@ -104,7 +104,7 @@ For more information on updateMask and other parameters take a look at `Dataproc An example of a new cluster config and the updateMask: -.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_update.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_update.py :language: python :dedent: 0 :start-after: [START how_to_cloud_dataproc_updatemask_cluster_operator] @@ -113,7 +113,7 @@ An example of a new cluster config and the updateMask: To update a cluster you can use: :class:`~airflow.providers.google.cloud.operators.dataproc.DataprocUpdateClusterOperator` -.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_update.py +.. 
exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_update.py :language: python :dedent: 4 :start-after: [START how_to_cloud_dataproc_update_cluster_operator] diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index e7f8f73027109..0d836769f4d61 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -46,7 +46,7 @@ "apache.beam": { "deps": [ "apache-airflow>=2.4.0", - "apache-beam>=2.33.0" + "apache-beam>=2.47.0" ], "cross-providers-deps": [ "google" @@ -333,62 +333,57 @@ "google": { "deps": [ "PyOpenSSL", - "PyYAML<7.0,>=5.1", "apache-airflow-providers-common-sql>=1.3.1", "apache-airflow>=2.4.0", "asgiref>=3.5.2", "gcloud-aio-auth>=4.0.0,<5.0.0", "gcloud-aio-bigquery>=6.1.2", "gcloud-aio-storage", - "google-api-core==2.8.2", - "google-api-python-client>=1.6.0,<2.0.0", + "google-ads>=20.0.0", + "google-api-core>=2.11.0", + "google-api-python-client>=1.6.0", "google-auth-httplib2>=0.0.1", - "google-auth-oauthlib<1.0.0,>=0.3.0", "google-auth>=1.0.0", - "google-cloud-aiplatform>=1.13.1,<2.0.0", - "google-cloud-automl>=2.1.0", - "google-cloud-bigquery-datatransfer>=3.0.0", - "google-cloud-bigtable>=2.0.0,<3.0.0", - "google-cloud-build>=3.0.0", - "google-cloud-compute>=0.1.0,<2.0.0", - "google-cloud-container>=2.2.0,<3.0.0", - "google-cloud-datacatalog>=3.0.0", - "google-cloud-dataflow-client>=0.5.2", - "google-cloud-dataform>=0.2.0", - "google-cloud-dataplex>=0.1.0", - "google-cloud-dataproc-metastore>=1.2.0,<2.0.0", - "google-cloud-dataproc>=3.1.0", - "google-cloud-dlp>=3.0.0", - "google-cloud-kms>=2.0.0", - "google-cloud-language>=1.1.1,<2.0.0", - "google-cloud-logging>=2.1.1", - "google-cloud-memcache>=0.2.0", - "google-cloud-monitoring>=2.0.0", - "google-cloud-orchestration-airflow>=1.0.0,<2.0.0", - "google-cloud-os-login>=2.0.0", - "google-cloud-pubsub>=2.0.0", - "google-cloud-redis>=2.0.0", - "google-cloud-secret-manager>=0.2.0,<2.0.0", - "google-cloud-spanner>=1.10.0,<2.0.0", - "google-cloud-speech>=0.36.3,<2.0.0", - "google-cloud-storage>=1.30,<3.0.0", - "google-cloud-tasks>=2.0.0", - "google-cloud-texttospeech>=0.4.0,<2.0.0", - "google-cloud-translate>=1.5.0,<2.0.0", - "google-cloud-videointelligence>=1.7.0,<2.0.0", - "google-cloud-vision>=0.35.2,<2.0.0", - "google-cloud-workflows>=0.1.0,<2.0.0", - "googleapis-common-protos<2.0.0,>=1.5.8", + "google-cloud-aiplatform>=1.22.1", + "google-cloud-automl>=2.11.0", + "google-cloud-bigquery-datatransfer>=3.11.0", + "google-cloud-bigtable>=2.17.0", + "google-cloud-build>=3.13.0", + "google-cloud-compute>=1.10.0", + "google-cloud-container>=2.17.4", + "google-cloud-datacatalog>=3.11.1", + "google-cloud-dataflow-client>=0.8.2", + "google-cloud-dataform>=0.5.0", + "google-cloud-dataplex>=1.4.2", + "google-cloud-dataproc-metastore>=1.10.0", + "google-cloud-dataproc>=5.4.0", + "google-cloud-dlp>=3.12.0", + "google-cloud-kms>=2.15.0", + "google-cloud-language>=2.9.0", + "google-cloud-logging>=3.5.0", + "google-cloud-memcache>=1.7.0", + "google-cloud-monitoring>=2.14.1", + "google-cloud-orchestration-airflow>=1.7.0", + "google-cloud-os-login>=2.9.1", + "google-cloud-pubsub>=2.15.0", + "google-cloud-redis>=2.12.0", + "google-cloud-secret-manager>=2.16.0", + "google-cloud-spanner>=3.11.1", + "google-cloud-speech>=2.18.0", + "google-cloud-storage>=2.7.0", + "google-cloud-tasks>=2.13.0", + "google-cloud-texttospeech>=2.14.1", + "google-cloud-translate>=3.11.0", + "google-cloud-videointelligence>=2.11.0", + "google-cloud-vision>=3.4.0", + 
"google-cloud-workflows>=1.10.0", "grpcio-gcp>=0.2.2", - "grpcio-status<2.0.0,>=1.38.1", - "grpcio<2.0.0,>=1.38.1", "httpx", "json-merge-patch>=0.2", "looker-sdk>=22.2.0", "pandas-gbq", "pandas>=0.17.1", - "proto-plus==1.19.6", - "protobuf!=3.18.*,!=3.19.*,<=3.20.0,>=3.12.0", + "proto-plus>=1.19.6", "sqlalchemy-bigquery>=1.2.1" ], "cross-providers-deps": [ diff --git a/tests/providers/google/cloud/_internal_client/test_secret_manager_client.py b/tests/providers/google/cloud/_internal_client/test_secret_manager_client.py index 2094766cb7d4d..86748dc9a0fcb 100644 --- a/tests/providers/google/cloud/_internal_client/test_secret_manager_client.py +++ b/tests/providers/google/cloud/_internal_client/test_secret_manager_client.py @@ -47,7 +47,7 @@ def test_get_non_existing_key(self, mock_secrets_client): secret = secrets_client.get_secret(secret_id="missing", project_id="project_id") mock_client.secret_version_path.assert_called_once_with("project_id", "missing", "latest") assert secret is None - mock_client.access_secret_version.assert_called_once_with("full-path") + mock_client.access_secret_version.assert_called_once_with(request={"name": "full-path"}) @mock.patch(INTERNAL_CLIENT_MODULE + ".SecretManagerServiceClient") def test_get_no_permissions(self, mock_secrets_client): @@ -60,7 +60,7 @@ def test_get_no_permissions(self, mock_secrets_client): secret = secrets_client.get_secret(secret_id="missing", project_id="project_id") mock_client.secret_version_path.assert_called_once_with("project_id", "missing", "latest") assert secret is None - mock_client.access_secret_version.assert_called_once_with("full-path") + mock_client.access_secret_version.assert_called_once_with(request={"name": "full-path"}) @mock.patch(INTERNAL_CLIENT_MODULE + ".SecretManagerServiceClient") def test_get_invalid_id(self, mock_secrets_client): @@ -73,7 +73,7 @@ def test_get_invalid_id(self, mock_secrets_client): secret = secrets_client.get_secret(secret_id="not.allow", project_id="project_id") mock_client.secret_version_path.assert_called_once_with("project_id", "not.allow", "latest") assert secret is None - mock_client.access_secret_version.assert_called_once_with("full-path") + mock_client.access_secret_version.assert_called_once_with(request={"name": "full-path"}) @mock.patch(INTERNAL_CLIENT_MODULE + ".SecretManagerServiceClient") def test_get_existing_key(self, mock_secrets_client): @@ -87,7 +87,7 @@ def test_get_existing_key(self, mock_secrets_client): secret = secrets_client.get_secret(secret_id="existing", project_id="project_id") mock_client.secret_version_path.assert_called_once_with("project_id", "existing", "latest") assert "result" == secret - mock_client.access_secret_version.assert_called_once_with("full-path") + mock_client.access_secret_version.assert_called_once_with(request={"name": "full-path"}) @mock.patch(INTERNAL_CLIENT_MODULE + ".SecretManagerServiceClient") def test_get_existing_key_with_version(self, mock_secrets_client): @@ -103,4 +103,4 @@ def test_get_existing_key_with_version(self, mock_secrets_client): ) mock_client.secret_version_path.assert_called_once_with("project_id", "existing", "test-version") assert "result" == secret - mock_client.access_secret_version.assert_called_once_with("full-path") + mock_client.access_secret_version.assert_called_once_with(request={"name": "full-path"}) diff --git a/tests/providers/google/cloud/hooks/test_bigtable.py b/tests/providers/google/cloud/hooks/test_bigtable.py index 231198e5c7cfe..4dda2fb009e76 100644 --- 
a/tests/providers/google/cloud/hooks/test_bigtable.py +++ b/tests/providers/google/cloud/hooks/test_bigtable.py @@ -503,8 +503,9 @@ def test_create_table(self, get_client, create): create.assert_called_once_with([], {}) @mock.patch("google.cloud.bigtable.cluster.Cluster.update") + @mock.patch("google.cloud.bigtable.cluster.Cluster.reload") @mock.patch("airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client") - def test_update_cluster(self, get_client, update): + def test_update_cluster(self, get_client, reload, update): instance_method = get_client.return_value.instance instance_exists_method = instance_method.return_value.exists instance_exists_method.return_value = True @@ -514,6 +515,7 @@ def test_update_cluster(self, get_client, update): instance=instance, cluster_id=CBT_CLUSTER, nodes=4 ) get_client.assert_not_called() + reload.assert_called_once_with() update.assert_called_once_with() @mock.patch("google.cloud.bigtable.table.Table.list_column_families") diff --git a/tests/providers/google/cloud/hooks/test_cloud_sql.py b/tests/providers/google/cloud/hooks/test_cloud_sql.py index 64bbf8dfe5223..27eb176da4018 100644 --- a/tests/providers/google/cloud/hooks/test_cloud_sql.py +++ b/tests/providers/google/cloud/hooks/test_cloud_sql.py @@ -128,13 +128,7 @@ def test_instance_export_with_in_progress_retry(self, wait_for_operation_to_comp execute_method = export_method.return_value.execute execute_method.side_effect = [ HttpError( - resp=type( - "", - (object,), - { - "status": 429, - }, - )(), + resp=httplib2.Response({"status": 429}), content=b"Internal Server Error", ), {"name": "operation_id"}, @@ -200,13 +194,7 @@ def test_create_instance_with_in_progress_retry( execute_method = insert_method.return_value.execute execute_method.side_effect = [ HttpError( - resp=type( - "", - (object,), - { - "status": 429, - }, - )(), + resp=httplib2.Response({"status": 429}), content=b"Internal Server Error", ), {"name": "operation_id"}, @@ -234,13 +222,7 @@ def test_patch_instance_with_in_progress_retry( execute_method = patch_method.return_value.execute execute_method.side_effect = [ HttpError( - resp=type( - "", - (object,), - { - "status": 429, - }, - )(), + resp=httplib2.Response({"status": 429}), content=b"Internal Server Error", ), {"name": "operation_id"}, @@ -291,7 +273,7 @@ def test_delete_instance(self, wait_for_operation_to_complete, get_conn, mock_ge delete_method.assert_called_once_with(instance="instance", project="example-project") execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_called_once_with( - operation_name="operation_id", project_id="example-project" + operation_name="operation_id", project_id="example-project", time_to_sleep=5 ) assert 1 == mock_get_credentials.call_count @@ -308,13 +290,7 @@ def test_delete_instance_with_in_progress_retry( execute_method = delete_method.return_value.execute execute_method.side_effect = [ HttpError( - resp=type( - "", - (object,), - { - "status": 429, - }, - )(), + resp=httplib2.Response({"status": 429}), content=b"Internal Server Error", ), {"name": "operation_id"}, @@ -326,7 +302,7 @@ def test_delete_instance_with_in_progress_retry( assert 2 == delete_method.call_count assert 2 == execute_method.call_count wait_for_operation_to_complete.assert_called_once_with( - operation_name="operation_id", project_id="example-project" + operation_name="operation_id", project_id="example-project", time_to_sleep=5 ) @mock.patch( @@ -409,13 +385,7 @@ def test_create_database_with_in_progress_retry( 
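Editor's note on the test changes above and below: the ad-hoc type("", (object,), {...})() response stubs are replaced with real httplib2.Response objects, which is what googleapiclient.errors.HttpError actually inspects. Building one for a 429 looks like this:

import httplib2
from googleapiclient.errors import HttpError

# httplib2.Response is a dict subclass; it picks up .status from the mapping,
# which HttpError then reads when formatting the error.
error_429 = HttpError(
    resp=httplib2.Response({"status": 429}),
    content=b"Internal Server Error",
)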
execute_method = insert_method.return_value.execute execute_method.side_effect = [ HttpError( - resp=type( - "", - (object,), - { - "status": 429, - }, - )(), + resp=httplib2.Response({"status": 429}), content=b"Internal Server Error", ), {"name": "operation_id"}, @@ -465,13 +435,7 @@ def test_patch_database_with_in_progress_retry( execute_method = patch_method.return_value.execute execute_method.side_effect = [ HttpError( - resp=type( - "", - (object,), - { - "status": 429, - }, - )(), + resp=httplib2.Response({"status": 429}), content=b"Internal Server Error", ), {"name": "operation_id"}, @@ -521,13 +485,7 @@ def test_delete_database_with_in_progress_retry( execute_method = delete_method.return_value.execute execute_method.side_effect = [ HttpError( - resp=type( - "", - (object,), - { - "status": 429, - }, - )(), + resp=httplib2.Response({"status": 429}), content=b"Internal Server Error", ), {"name": "operation_id"}, @@ -684,7 +642,7 @@ def test_delete_instance_overridden_project_id( delete_method.assert_called_once_with(instance="instance", project="example-project") execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_called_once_with( - operation_name="operation_id", project_id="example-project" + operation_name="operation_id", project_id="example-project", time_to_sleep=5 ) @mock.patch( diff --git a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py index d8cd3015989c0..15683186ae054 100644 --- a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py +++ b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py @@ -23,6 +23,7 @@ from unittest import mock from unittest.mock import MagicMock, PropertyMock +import httplib2 import pytest from googleapiclient.errors import HttpError @@ -101,11 +102,6 @@ def _with_name(body, job_name): return obj -class GCPRequestMock: - - status = TEST_HTTP_ERR_CODE - - class TestGCPTransferServiceHookWithPassedName: def test_delegate_to_runtime_error(self): with pytest.raises(RuntimeError): @@ -143,7 +139,9 @@ def test_pass_name_on_create_job( enable_transfer_job: MagicMock, ): body = _with_name(TEST_BODY, TEST_CLEAR_JOB_NAME) - get_conn.side_effect = HttpError(GCPRequestMock(), TEST_HTTP_ERR_CONTENT) + get_conn.side_effect = HttpError( + httplib2.Response({"status": TEST_HTTP_ERR_CODE}), TEST_HTTP_ERR_CONTENT + ) with pytest.raises(HttpError): diff --git a/tests/providers/google/cloud/hooks/test_compute.py b/tests/providers/google/cloud/hooks/test_compute.py index 0bba949a372c7..f4d5da5414ac0 100644 --- a/tests/providers/google/cloud/hooks/test_compute.py +++ b/tests/providers/google/cloud/hooks/test_compute.py @@ -70,8 +70,17 @@ def setup_method(self): impersonation_chain=IMPERSONATION_CHAIN, ) + @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook._wait_for_operation_to_complete")) + @mock.patch( + "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id", + new_callable=PropertyMock, + return_value="mocked-google", + ) @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook.get_compute_instance_template_client")) - def test_insert_template_should_execute_successfully(self, mock_client): + def test_insert_template_should_execute_successfully( + self, mock_client, mocked_project_id, wait_for_operation_to_complete + ): + wait_for_operation_to_complete.return_value = None self.hook.insert_instance_template( project_id=PROJECT_ID, body=BODY, @@ 
-90,13 +99,17 @@ def test_insert_template_should_execute_successfully(self, mock_client): metadata=METADATA, ) + @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook._wait_for_operation_to_complete")) @mock.patch( "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id", new_callable=PropertyMock, return_value="mocked-google", ) @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook.get_compute_instance_template_client")) - def test_insert_template_should_not_throw_ex_when_project_id_none(self, mock_client, mocked_project_id): + def test_insert_template_should_not_throw_ex_when_project_id_none( + self, mock_client, mocked_project_id, wait_for_operation_to_complete + ): + wait_for_operation_to_complete.return_value = None self.hook.insert_instance_template( body=BODY, retry=RETRY, @@ -114,8 +127,17 @@ def test_insert_template_should_not_throw_ex_when_project_id_none(self, mock_cli metadata=METADATA, ) + @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook._wait_for_operation_to_complete")) + @mock.patch( + "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id", + new_callable=PropertyMock, + return_value="mocked-google", + ) @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook.get_compute_instance_template_client")) - def test_delete_template_should_execute_successfully(self, mock_client): + def test_delete_template_should_execute_successfully( + self, mock_client, mocked_project_id, wait_for_operation_to_complete + ): + wait_for_operation_to_complete.return_value = None self.hook.delete_instance_template( project_id=PROJECT_ID, resource_id=RESOURCE_ID, @@ -134,13 +156,17 @@ def test_delete_template_should_execute_successfully(self, mock_client): metadata=METADATA, ) + @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook._wait_for_operation_to_complete")) @mock.patch( "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id", new_callable=PropertyMock, return_value="mocked-google", ) @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook.get_compute_instance_template_client")) - def test_delete_template_should_not_throw_ex_when_project_id_none(self, mock_client, mocked_project_id): + def test_delete_template_should_not_throw_ex_when_project_id_none( + self, mock_client, mocked_project_id, wait_for_operation_to_complete + ): + wait_for_operation_to_complete.return_value = None self.hook.delete_instance_template( resource_id=RESOURCE_ID, retry=RETRY, @@ -200,8 +226,10 @@ def test_get_template_should_not_throw_ex_when_project_id_none(self, mock_client metadata=METADATA, ) + @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook._wait_for_operation_to_complete")) @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook.get_compute_instance_client")) - def test_insert_instance_should_execute_successfully(self, mock_client): + def test_insert_instance_should_execute_successfully(self, mock_client, wait_for_operation_to_complete): + wait_for_operation_to_complete.return_value = None self.hook.insert_instance( project_id=PROJECT_ID, body=BODY, @@ -224,13 +252,17 @@ def test_insert_instance_should_execute_successfully(self, mock_client): metadata=METADATA, ) + @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook._wait_for_operation_to_complete")) @mock.patch( "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id", new_callable=PropertyMock, return_value="mocked-google", ) 
@mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook.get_compute_instance_client")) - def test_insert_instance_should_not_throw_ex_when_project_id_none(self, mock_client, mocked_project_id): + def test_insert_instance_should_not_throw_ex_when_project_id_none( + self, mock_client, mocked_project_id, wait_for_operation_to_complete + ): + wait_for_operation_to_complete.return_value = None self.hook.insert_instance( body=BODY, zone=ZONE, @@ -298,8 +330,10 @@ def test_get_instance_should_not_throw_ex_when_project_id_none(self, mock_client metadata=METADATA, ) + @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook._wait_for_operation_to_complete")) @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook.get_compute_instance_client")) - def test_delete_instance_should_execute_successfully(self, mock_client): + def test_delete_instance_should_execute_successfully(self, mock_client, wait_for_operation_to_complete): + wait_for_operation_to_complete.return_value = None self.hook.delete_instance( resource_id=RESOURCE_ID, zone=ZONE, @@ -320,13 +354,17 @@ def test_delete_instance_should_execute_successfully(self, mock_client): metadata=METADATA, ) + @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook._wait_for_operation_to_complete")) @mock.patch( "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id", new_callable=PropertyMock, return_value="mocked-google", ) @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook.get_compute_instance_client")) - def test_delete_instance_should_not_throw_ex_when_project_id_none(self, mock_client, mocked_project_id): + def test_delete_instance_should_not_throw_ex_when_project_id_none( + self, mock_client, mocked_project_id, wait_for_operation_to_complete + ): + wait_for_operation_to_complete.return_value = None self.hook.delete_instance( resource_id=RESOURCE_ID, zone=ZONE, @@ -346,10 +384,14 @@ def test_delete_instance_should_not_throw_ex_when_project_id_none(self, mock_cli metadata=METADATA, ) + @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook._wait_for_operation_to_complete")) @mock.patch( COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook.get_compute_instance_group_managers_client") ) - def test_insert_instance_group_manager_should_execute_successfully(self, mock_client): + def test_insert_instance_group_manager_should_execute_successfully( + self, mock_client, wait_for_operation_to_complete + ): + wait_for_operation_to_complete.return_value = None self.hook.insert_instance_group_manager( body=BODY, zone=ZONE, @@ -370,6 +412,7 @@ def test_insert_instance_group_manager_should_execute_successfully(self, mock_cl metadata=METADATA, ) + @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook._wait_for_operation_to_complete")) @mock.patch( "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id", new_callable=PropertyMock, @@ -379,8 +422,9 @@ def test_insert_instance_group_manager_should_execute_successfully(self, mock_cl COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook.get_compute_instance_group_managers_client") ) def test_insert_instance_group_manager_should_not_throw_ex_when_project_id_none( - self, mock_client, mocked_project_id + self, mock_client, mocked_project_id, wait_for_operation_to_complete ): + wait_for_operation_to_complete.return_value = None self.hook.insert_instance_group_manager( body=BODY, zone=ZONE, @@ -452,10 +496,14 @@ def test_get_instance_group_manager_should_not_throw_ex_when_project_id_none( metadata=METADATA, ) + 
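Editor's note for reviewers of the test_compute changes: the extra @mock.patch for _wait_for_operation_to_complete is stacked on top, and stacked patch decorators inject mocks bottom-up, so it arrives as the last test argument. Illustrative sketch only, with the patch targets assumed to match what COMPUTE_ENGINE_HOOK_PATH expands to:

from unittest import mock

HOOK = "airflow.providers.google.cloud.hooks.compute.ComputeEngineHook"


@mock.patch(f"{HOOK}._wait_for_operation_to_complete")
@mock.patch(f"{HOOK}.get_compute_instance_client")
def test_delete_instance(mock_client, wait_for_operation_to_complete):
    # The lowest decorator provides the first argument; the newly added patch
    # therefore lands last, and its return value is neutralized up front.
    wait_for_operation_to_complete.return_value = None
    assert mock_client is not None  # the hook call under test would go here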
@mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook._wait_for_operation_to_complete")) @mock.patch( COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook.get_compute_instance_group_managers_client") ) - def test_delete_instance_group_manager_should_execute_successfully(self, mock_client): + def test_delete_instance_group_manager_should_execute_successfully( + self, mock_client, wait_for_operation_to_complete + ): + wait_for_operation_to_complete.return_value = None self.hook.delete_instance_group_manager( resource_id=RESOURCE_ID, zone=ZONE, @@ -476,6 +524,7 @@ def test_delete_instance_group_manager_should_execute_successfully(self, mock_cl metadata=METADATA, ) + @mock.patch(COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook._wait_for_operation_to_complete")) @mock.patch( "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id", new_callable=PropertyMock, @@ -485,8 +534,9 @@ def test_delete_instance_group_manager_should_execute_successfully(self, mock_cl COMPUTE_ENGINE_HOOK_PATH.format("ComputeEngineHook.get_compute_instance_group_managers_client") ) def test_delete_instance_group_manager_should_not_throw_ex_when_project_id_none( - self, mock_client, mocked_project_id + self, mock_client, mocked_project_id, wait_for_operation_to_complete ): + wait_for_operation_to_complete.return_value = None self.hook.delete_instance_group_manager( resource_id=RESOURCE_ID, zone=ZONE, diff --git a/tests/providers/google/cloud/hooks/test_natural_language.py b/tests/providers/google/cloud/hooks/test_natural_language.py index ea61a8e75b21d..22ba3d87c10f5 100644 --- a/tests/providers/google/cloud/hooks/test_natural_language.py +++ b/tests/providers/google/cloud/hooks/test_natural_language.py @@ -22,7 +22,7 @@ import pytest from google.api_core.gapic_v1.method import DEFAULT -from google.cloud.language_v1.proto.language_service_pb2 import Document +from google.cloud.language_v1 import Document from airflow.providers.google.cloud.hooks.natural_language import CloudNaturalLanguageHook from airflow.providers.google.common.consts import CLIENT_INFO diff --git a/tests/providers/google/cloud/hooks/test_pubsub.py b/tests/providers/google/cloud/hooks/test_pubsub.py index 72e08163a6e07..cad0451bc68e0 100644 --- a/tests/providers/google/cloud/hooks/test_pubsub.py +++ b/tests/providers/google/cloud/hooks/test_pubsub.py @@ -20,6 +20,7 @@ from unittest import mock from uuid import UUID +import httplib2 import pytest from google.api_core.exceptions import AlreadyExists, GoogleAPICallError from google.api_core.gapic_v1.method import DEFAULT @@ -432,7 +433,9 @@ def test_pull_no_messages(self, mock_service): @pytest.mark.parametrize( "exception", [ - pytest.param(HttpError(resp={"status": "404"}, content=EMPTY_CONTENT), id="http-error-404"), + pytest.param( + HttpError(resp=httplib2.Response({"status": 404}), content=EMPTY_CONTENT), id="http-error-404" + ), pytest.param(GoogleAPICallError("API Call Error"), id="google-api-call-error"), ], ) @@ -514,7 +517,9 @@ def test_acknowledge_fails_on_method_args_validation(self, mock_service, ack_ids @pytest.mark.parametrize( "exception", [ - pytest.param(HttpError(resp={"status": "404"}, content=EMPTY_CONTENT), id="http-error-404"), + pytest.param( + HttpError(resp=httplib2.Response({"status": 404}), content=EMPTY_CONTENT), id="http-error-404" + ), pytest.param(GoogleAPICallError("API Call Error"), id="google-api-call-error"), ], ) diff --git a/tests/providers/google/cloud/hooks/test_secret_manager.py b/tests/providers/google/cloud/hooks/test_secret_manager.py 
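Editor's note on the secret-manager assertions in the surrounding test changes: google-cloud-secret-manager >= 2.0 clients take a single request mapping rather than flat positional arguments. The call the hook is now expected to make looks roughly like this (project and secret names are placeholders):

from google.cloud.secretmanager_v1 import SecretManagerServiceClient

client = SecretManagerServiceClient()
name = client.secret_version_path("my-project", "my-secret", "latest")
# The v2 surface accepts one request object (or mapping) per RPC.
response = client.access_secret_version(request={"name": name})
value = response.payload.data.decode("UTF-8")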
index 84e4b4334ef87..6f0f1a5339aca 100644 --- a/tests/providers/google/cloud/hooks/test_secret_manager.py +++ b/tests/providers/google/cloud/hooks/test_secret_manager.py @@ -21,7 +21,7 @@ import pytest from google.api_core.exceptions import NotFound -from google.cloud.secretmanager_v1.proto.service_pb2 import AccessSecretVersionResponse +from google.cloud.secretmanager_v1.types.service import AccessSecretVersionResponse from airflow.providers.google.cloud.hooks.secret_manager import SecretsManagerHook from tests.providers.google.cloud.utils.base_gcp_mock import ( @@ -52,7 +52,7 @@ def test_get_missing_key(self, mock_get_credentials, mock_client): mock_get_credentials.assert_called_once_with() secret = secrets_manager_hook.get_secret(secret_id="secret") mock_client.secret_version_path.assert_called_once_with("example-project", "secret", "latest") - mock_client.access_secret_version.assert_called_once_with("full-path") + mock_client.access_secret_version.assert_called_once_with(request={"name": "full-path"}) assert secret is None @patch(INTERNAL_CLIENT_PACKAGE + "._SecretManagerClient.client", return_value=MagicMock()) @@ -70,5 +70,5 @@ def test_get_existing_key(self, mock_get_credentials, mock_client): mock_get_credentials.assert_called_once_with() secret = secrets_manager_hook.get_secret(secret_id="secret") mock_client.secret_version_path.assert_called_once_with("example-project", "secret", "latest") - mock_client.access_secret_version.assert_called_once_with("full-path") + mock_client.access_secret_version.assert_called_once_with(request={"name": "full-path"}) assert "result" == secret diff --git a/tests/providers/google/cloud/hooks/test_speech_to_text.py b/tests/providers/google/cloud/hooks/test_speech_to_text.py index 3f66a6662181b..8cda071976245 100644 --- a/tests/providers/google/cloud/hooks/test_speech_to_text.py +++ b/tests/providers/google/cloud/hooks/test_speech_to_text.py @@ -21,17 +21,18 @@ import pytest from google.api_core.gapic_v1.method import DEFAULT +from google.cloud.speech_v1.types import RecognitionAudio, RecognitionConfig from airflow.providers.google.cloud.hooks.speech_to_text import CloudSpeechToTextHook from airflow.providers.google.common.consts import CLIENT_INFO from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id PROJECT_ID = "project-id" -CONFIG = {"encryption": "LINEAR16"} +CONFIG = {"encoding": "LINEAR16"} AUDIO = {"uri": "gs://bucket/object"} -class TestTextToSpeechOperator: +class TestCloudSpeechToTextHook: def test_delegate_to_runtime_error(self): with pytest.raises(RuntimeError): CloudSpeechToTextHook(gcp_conn_id="GCP_CONN_ID", delegate_to="delegate_to") @@ -56,4 +57,6 @@ def test_synthesize_speech(self, get_conn): recognize_method = get_conn.return_value.recognize recognize_method.return_value = None self.gcp_speech_to_text_hook.recognize_speech(config=CONFIG, audio=AUDIO) - recognize_method.assert_called_once_with(config=CONFIG, audio=AUDIO, retry=DEFAULT, timeout=None) + recognize_method.assert_called_once_with( + config=RecognitionConfig(CONFIG), audio=RecognitionAudio(AUDIO), retry=DEFAULT, timeout=None + ) diff --git a/tests/providers/google/cloud/hooks/test_text_to_speech.py b/tests/providers/google/cloud/hooks/test_text_to_speech.py index 784a9c8c8fd98..573c56e098f58 100644 --- a/tests/providers/google/cloud/hooks/test_text_to_speech.py +++ b/tests/providers/google/cloud/hooks/test_text_to_speech.py @@ -21,6 +21,11 @@ import pytest from google.api_core.gapic_v1.method import DEFAULT +from 
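Editor's note: besides fixing the config key above (encoding, not encryption), the speech test now expects plain dicts to be promoted to proto-plus request types before the client call. A minimal sketch (the URI is a placeholder):

from google.cloud.speech_v1 import SpeechClient
from google.cloud.speech_v1.types import RecognitionAudio, RecognitionConfig

client = SpeechClient()
# Dicts are accepted by the proto-plus constructors, which is what the updated
# assert_called_once_with(...) checks.
config = RecognitionConfig({"encoding": "LINEAR16"})
audio = RecognitionAudio({"uri": "gs://bucket/object"})
response = client.recognize(config=config, audio=audio)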
google.cloud.texttospeech_v1.types import ( + AudioConfig, + SynthesisInput, + VoiceSelectionParams, +) from airflow.providers.google.cloud.hooks.text_to_speech import CloudTextToSpeechHook from airflow.providers.google.common.consts import CLIENT_INFO @@ -59,5 +64,9 @@ def test_synthesize_speech(self, get_conn): input_data=INPUT, voice=VOICE, audio_config=AUDIO_CONFIG ) synthesize_method.assert_called_once_with( - input_=INPUT, voice=VOICE, audio_config=AUDIO_CONFIG, retry=DEFAULT, timeout=None + input=SynthesisInput(INPUT), + voice=VoiceSelectionParams(VOICE), + audio_config=AudioConfig(AUDIO_CONFIG), + retry=DEFAULT, + timeout=None, ) diff --git a/tests/providers/google/cloud/hooks/test_video_intelligence.py b/tests/providers/google/cloud/hooks/test_video_intelligence.py index 46017de585d01..75bb9b92f71c8 100644 --- a/tests/providers/google/cloud/hooks/test_video_intelligence.py +++ b/tests/providers/google/cloud/hooks/test_video_intelligence.py @@ -21,7 +21,7 @@ import pytest from google.api_core.gapic_v1.method import DEFAULT -from google.cloud.videointelligence_v1 import enums +from google.cloud.videointelligence_v1 import Feature from airflow.providers.google.cloud.hooks.video_intelligence import CloudVideoIntelligenceHook from airflow.providers.google.common.consts import CLIENT_INFO @@ -30,7 +30,7 @@ INPUT_URI = "gs://bucket-name/input-file" OUTPUT_URI = "gs://bucket-name/output-file" -FEATURES = [enums.Feature.LABEL_DETECTION] +FEATURES = [Feature.LABEL_DETECTION] ANNOTATE_VIDEO_RESPONSE = {"test": "test"} @@ -69,12 +69,14 @@ def test_annotate_video(self, get_conn): # Then assert result is ANNOTATE_VIDEO_RESPONSE annotate_video_method.assert_called_once_with( - input_uri=INPUT_URI, - input_content=None, - features=FEATURES, - video_context=None, - output_uri=None, - location_id=None, + request={ + "input_uri": INPUT_URI, + "input_content": None, + "features": FEATURES, + "video_context": None, + "output_uri": None, + "location_id": None, + }, retry=DEFAULT, timeout=None, metadata=(), @@ -92,12 +94,14 @@ def test_annotate_video_with_output_uri(self, get_conn): # Then assert result is ANNOTATE_VIDEO_RESPONSE annotate_video_method.assert_called_once_with( - input_uri=INPUT_URI, - output_uri=OUTPUT_URI, - input_content=None, - features=FEATURES, - video_context=None, - location_id=None, + request={ + "input_uri": INPUT_URI, + "output_uri": OUTPUT_URI, + "input_content": None, + "features": FEATURES, + "video_context": None, + "location_id": None, + }, retry=DEFAULT, timeout=None, metadata=(), diff --git a/tests/providers/google/cloud/hooks/test_vision.py b/tests/providers/google/cloud/hooks/test_vision.py index dbdc4b2a66519..52bbd906801fc 100644 --- a/tests/providers/google/cloud/hooks/test_vision.py +++ b/tests/providers/google/cloud/hooks/test_vision.py @@ -21,14 +21,17 @@ import pytest from google.api_core.gapic_v1.method import DEFAULT -from google.cloud.vision import enums -from google.cloud.vision_v1 import ProductSearchClient -from google.cloud.vision_v1.proto.image_annotator_pb2 import ( +from google.cloud.vision_v1 import ( + AnnotateImageRequest, AnnotateImageResponse, EntityAnnotation, + Feature, + Product, + ProductSearchClient, + ProductSet, + ReferenceImage, SafeSearchAnnotation, ) -from google.cloud.vision_v1.proto.product_search_service_pb2 import Product, ProductSet, ReferenceImage from google.protobuf.json_format import MessageToDict from airflow.exceptions import AirflowException @@ -51,16 +54,16 @@ REFERENCE_IMAGE_GEN_ID_TEST = "ri-id" ANNOTATE_IMAGE_REQUEST = { 
"image": {"source": {"image_uri": "gs://bucket-name/object-name"}}, - "features": [{"type": enums.Feature.Type.LOGO_DETECTION}], + "features": [{"type": Feature.Type.LOGO_DETECTION}], } BATCH_ANNOTATE_IMAGE_REQUEST = [ { "image": {"source": {"image_uri": "gs://bucket-name/object-name"}}, - "features": [{"type": enums.Feature.Type.LOGO_DETECTION}], + "features": [{"type_": Feature.Type.LOGO_DETECTION}], }, { "image": {"source": {"image_uri": "gs://bucket-name/object-name"}}, - "features": [{"type": enums.Feature.Type.LOGO_DETECTION}], + "features": [{"type_": Feature.Type.LOGO_DETECTION}], }, ] REFERENCE_IMAGE_NAME_TEST = ( @@ -108,7 +111,7 @@ def test_create_productset_explicit_id(self, get_conn): # Given create_product_set_method = get_conn.return_value.create_product_set create_product_set_method.return_value = None - parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST) + parent = f"projects/{PROJECT_ID_TEST}/locations/{LOC_ID_TEST}" product_set = ProductSet() # When result = self.hook.create_product_set( @@ -142,7 +145,7 @@ def test_create_productset_autogenerated_id(self, get_conn): ) create_product_set_method = get_conn.return_value.create_product_set create_product_set_method.return_value = response_product_set - parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST) + parent = f"projects/{PROJECT_ID_TEST}/locations/{LOC_ID_TEST}" product_set = ProductSet() # When result = self.hook.create_product_set( @@ -167,7 +170,7 @@ def test_create_productset_autogenerated_id_wrong_api_response(self, get_conn): response_product_set = None create_product_set_method = get_conn.return_value.create_product_set create_product_set_method.return_value = response_product_set - parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST) + parent = f"projects/{PROJECT_ID_TEST}/locations/{LOC_ID_TEST}" product_set = ProductSet() # When with pytest.raises(AirflowException) as ctx: @@ -206,7 +209,7 @@ def test_get_productset(self, get_conn): ) # Then assert response - assert response == MessageToDict(response_product_set) + assert response == MessageToDict(response_product_set._pb) get_product_set_method.assert_called_once_with(name=name, retry=DEFAULT, timeout=None, metadata=()) @mock.patch("airflow.providers.google.cloud.hooks.vision.CloudVisionHook.get_conn") @@ -230,7 +233,7 @@ def test_update_productset_no_explicit_name(self, get_conn): metadata=(), ) # Then - assert result == MessageToDict(product_set) + assert result == MessageToDict(product_set._pb) update_product_set_method.assert_called_once_with( product_set=ProductSet(name=productset_name), metadata=(), @@ -292,7 +295,7 @@ def test_update_productset_explicit_name_missing_params_for_constructed_name( metadata=(), ) # Then - assert result == MessageToDict(product_set) + assert result == MessageToDict(product_set._pb) update_product_set_method.assert_called_once_with( product_set=ProductSet(name=explicit_ps_name), metadata=(), @@ -469,7 +472,9 @@ def test_batch_annotate_images(self, annotator_client_mock): # Then # Product ID was provided explicitly in the method call above, should be returned from the method batch_annotate_images_method.assert_called_once_with( - requests=BATCH_ANNOTATE_IMAGE_REQUEST, retry=DEFAULT, timeout=None + requests=list(map(AnnotateImageRequest, BATCH_ANNOTATE_IMAGE_REQUEST)), + retry=DEFAULT, + timeout=None, ) @mock.patch("airflow.providers.google.cloud.hooks.vision.CloudVisionHook.get_conn") @@ -477,7 +482,7 @@ def test_create_product_explicit_id(self, get_conn): # Given 
create_product_method = get_conn.return_value.create_product create_product_method.return_value = None - parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST) + parent = f"projects/{PROJECT_ID_TEST}/locations/{LOC_ID_TEST}" product = Product() # When result = self.hook.create_product( @@ -504,7 +509,7 @@ def test_create_product_autogenerated_id(self, get_conn): ) create_product_method = get_conn.return_value.create_product create_product_method.return_value = response_product - parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST) + parent = f"projects/{PROJECT_ID_TEST}/locations/{LOC_ID_TEST}" product = Product() # When result = self.hook.create_product( @@ -525,7 +530,7 @@ def test_create_product_autogenerated_id_wrong_name_in_response(self, get_conn): response_product = Product(name=wrong_name) create_product_method = get_conn.return_value.create_product create_product_method.return_value = response_product - parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST) + parent = f"projects/{PROJECT_ID_TEST}/locations/{LOC_ID_TEST}" product = Product() # When with pytest.raises(AirflowException) as ctx: @@ -546,7 +551,7 @@ def test_create_product_autogenerated_id_wrong_api_response(self, get_conn): response_product = None create_product_method = get_conn.return_value.create_product create_product_method.return_value = response_product - parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST) + parent = f"projects/{PROJECT_ID_TEST}/locations/{LOC_ID_TEST}" product = Product() # When with pytest.raises(AirflowException) as ctx: @@ -580,7 +585,7 @@ def test_update_product_no_explicit_name(self, get_conn): metadata=(), ) # Then - assert result == MessageToDict(product) + assert result == MessageToDict(product._pb) update_product_method.assert_called_once_with( product=Product(name=product_name), metadata=(), retry=DEFAULT, timeout=None, update_mask=None ) @@ -635,7 +640,7 @@ def test_update_product_explicit_name_missing_params_for_constructed_name( metadata=(), ) # Then - assert result == MessageToDict(product) + assert result == MessageToDict(product._pb) update_product_method.assert_called_once_with( product=Product(name=explicit_p_name), metadata=(), retry=DEFAULT, timeout=None, update_mask=None ) diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py index ed99a0a472c43..2c9872f4450bb 100644 --- a/tests/providers/google/cloud/operators/test_automl.py +++ b/tests/providers/google/cloud/operators/test_automl.py @@ -71,6 +71,7 @@ class TestAutoMLTrainModelOperator: def test_execute(self, mock_hook): mock_hook.return_value.create_model.return_value.result.return_value = Model(name=MODEL_PATH) mock_hook.return_value.extract_object_id = extract_object_id + mock_hook.return_value.wait_for_operation.return_value = Model() op = AutoMLTrainModelOperator( model=MODEL, location=GCP_LOCATION, @@ -93,6 +94,7 @@ class TestAutoMLBatchPredictOperator: def test_execute(self, mock_hook): mock_hook.return_value.batch_predict.return_value.result.return_value = BatchPredictResult() mock_hook.return_value.extract_object_id = extract_object_id + mock_hook.return_value.wait_for_operation.return_value = BatchPredictResult() op = AutoMLBatchPredictOperator( model_id=MODEL_ID, diff --git a/tests/providers/google/cloud/operators/test_functions.py b/tests/providers/google/cloud/operators/test_functions.py index 5d520fa2b211c..18f2364c918d0 100644 --- 
a/tests/providers/google/cloud/operators/test_functions.py +++ b/tests/providers/google/cloud/operators/test_functions.py @@ -20,6 +20,7 @@ from copy import deepcopy from unittest import mock +import httplib2 import pytest from googleapiclient.errors import HttpError @@ -33,7 +34,7 @@ from airflow.version import version EMPTY_CONTENT = b"" -MOCK_RESP_404 = type("", (object,), {"status": 404})() +MOCK_RESP_404 = httplib2.Response({"status": 404}) GCP_PROJECT_ID = "test_project_id" GCP_LOCATION = "test_region" @@ -657,7 +658,7 @@ def test_gcf_error_silenced_when_function_doesnt_exist(self, mock_hook): @mock.patch("airflow.providers.google.cloud.operators.functions.CloudFunctionsHook") def test_non_404_gcf_error_bubbled_up(self, mock_hook): op = CloudFunctionDeleteFunctionOperator(name=self._FUNCTION_NAME, task_id="id") - resp = type("", (object,), {"status": 500})() + resp = httplib2.Response({"status": 500}) mock_hook.return_value.delete_function.side_effect = mock.Mock( side_effect=HttpError(resp=resp, content=b"error") ) diff --git a/tests/providers/google/cloud/operators/test_natural_language.py b/tests/providers/google/cloud/operators/test_natural_language.py index 132cb2443f10e..e77f77d80deea 100644 --- a/tests/providers/google/cloud/operators/test_natural_language.py +++ b/tests/providers/google/cloud/operators/test_natural_language.py @@ -19,7 +19,7 @@ from unittest.mock import patch -from google.cloud.language_v1.proto.language_service_pb2 import ( +from google.cloud.language_v1 import ( AnalyzeEntitiesResponse, AnalyzeEntitySentimentResponse, AnalyzeSentimentResponse, diff --git a/tests/providers/google/cloud/operators/test_speech_to_text.py b/tests/providers/google/cloud/operators/test_speech_to_text.py index a693b35f06b6c..51dd6dd8db7c0 100644 --- a/tests/providers/google/cloud/operators/test_speech_to_text.py +++ b/tests/providers/google/cloud/operators/test_speech_to_text.py @@ -21,6 +21,7 @@ import pytest from google.api_core.gapic_v1.method import DEFAULT +from google.cloud.speech_v1 import RecognizeResponse from airflow.exceptions import AirflowException from airflow.providers.google.cloud.operators.speech_to_text import CloudSpeechToTextRecognizeSpeechOperator @@ -32,10 +33,10 @@ AUDIO = {"uri": "gs://bucket/object"} -class TestCloudSql: +class TestCloudSpeechToTextRecognizeSpeechOperator: @patch("airflow.providers.google.cloud.operators.speech_to_text.CloudSpeechToTextHook") def test_recognize_speech_green_path(self, mock_hook): - mock_hook.return_value.recognize_speech.return_value = MagicMock() + mock_hook.return_value.recognize_speech.return_value = RecognizeResponse() CloudSpeechToTextRecognizeSpeechOperator( project_id=PROJECT_ID, diff --git a/tests/providers/google/cloud/operators/test_translate_speech.py b/tests/providers/google/cloud/operators/test_translate_speech.py index f46d411ede2c3..6dd000504cef5 100644 --- a/tests/providers/google/cloud/operators/test_translate_speech.py +++ b/tests/providers/google/cloud/operators/test_translate_speech.py @@ -20,7 +20,7 @@ from unittest import mock import pytest -from google.cloud.speech_v1.proto.cloud_speech_pb2 import ( +from google.cloud.speech_v1 import ( RecognizeResponse, SpeechRecognitionAlternative, SpeechRecognitionResult, diff --git a/tests/providers/google/cloud/operators/test_video_intelligence.py b/tests/providers/google/cloud/operators/test_video_intelligence.py index fc97c5ecdad3d..1688f7a4149d9 100644 --- a/tests/providers/google/cloud/operators/test_video_intelligence.py +++ 
b/tests/providers/google/cloud/operators/test_video_intelligence.py @@ -20,8 +20,7 @@ from unittest import mock from google.api_core.gapic_v1.method import DEFAULT -from google.cloud.videointelligence_v1 import enums -from google.cloud.videointelligence_v1.proto.video_intelligence_pb2 import AnnotateVideoResponse +from google.cloud.videointelligence_v1 import AnnotateVideoResponse, Feature from airflow.providers.google.cloud.operators.video_intelligence import ( CloudVideoIntelligenceDetectVideoExplicitContentOperator, @@ -59,7 +58,7 @@ def test_detect_video_labels_green_path(self, mock_hook): ) mock_hook.return_value.annotate_video.assert_called_once_with( input_uri=INPUT_URI, - features=[enums.Feature.LABEL_DETECTION], + features=[Feature.LABEL_DETECTION], input_content=None, video_context=None, location=None, @@ -86,7 +85,7 @@ def test_detect_video_explicit_content_green_path(self, mock_hook): ) mock_hook.return_value.annotate_video.assert_called_once_with( input_uri=INPUT_URI, - features=[enums.Feature.EXPLICIT_CONTENT_DETECTION], + features=[Feature.EXPLICIT_CONTENT_DETECTION], input_content=None, video_context=None, location=None, @@ -113,7 +112,7 @@ def test_detect_video_shots_green_path(self, mock_hook): ) mock_hook.return_value.annotate_video.assert_called_once_with( input_uri=INPUT_URI, - features=[enums.Feature.SHOT_CHANGE_DETECTION], + features=[Feature.SHOT_CHANGE_DETECTION], input_content=None, video_context=None, location=None, diff --git a/tests/providers/google/cloud/operators/test_vision.py b/tests/providers/google/cloud/operators/test_vision.py index 7a9ee186a7d81..38eb74facad0e 100644 --- a/tests/providers/google/cloud/operators/test_vision.py +++ b/tests/providers/google/cloud/operators/test_vision.py @@ -21,7 +21,7 @@ from google.api_core.exceptions import AlreadyExists from google.api_core.gapic_v1.method import DEFAULT -from google.cloud.vision_v1.types import Product, ProductSet, ReferenceImage +from google.cloud.vision_v1 import Product, ProductSet, ReferenceImage from airflow.providers.google.cloud.operators.vision import ( CloudVisionAddProductToProductSetOperator, diff --git a/tests/providers/google/cloud/operators/test_workflows.py b/tests/providers/google/cloud/operators/test_workflows.py index 33f16ea0a6a33..ad202fa5929d4 100644 --- a/tests/providers/google/cloud/operators/test_workflows.py +++ b/tests/providers/google/cloud/operators/test_workflows.py @@ -20,6 +20,7 @@ from unittest import mock import pytz +from google.protobuf.timestamp_pb2 import Timestamp from airflow.providers.google.cloud.operators.workflows import ( WorkflowsCancelExecutionOperator, @@ -169,8 +170,10 @@ class TestWorkflowsListWorkflowsOperator: @mock.patch(BASE_PATH.format("Workflow")) @mock.patch(BASE_PATH.format("WorkflowsHook")) def test_execute(self, mock_hook, mock_object): + timestamp = Timestamp() + timestamp.FromDatetime(datetime.datetime.now(tz=pytz.UTC) + datetime.timedelta(minutes=5)) workflow_mock = mock.MagicMock() - workflow_mock.start_time = datetime.datetime.now(tz=pytz.UTC) + datetime.timedelta(minutes=5) + workflow_mock.start_time = timestamp mock_hook.return_value.list_workflows.return_value = [workflow_mock] op = WorkflowsListWorkflowsOperator( @@ -330,8 +333,10 @@ class TestWorkflowExecutionsListExecutionsOperator: @mock.patch(BASE_PATH.format("Execution")) @mock.patch(BASE_PATH.format("WorkflowsHook")) def test_execute(self, mock_hook, mock_object): + timestamp = Timestamp() + timestamp.FromDatetime(datetime.datetime.now(tz=pytz.UTC) + 
datetime.timedelta(minutes=5)) execution_mock = mock.MagicMock() - execution_mock.start_time = datetime.datetime.now(tz=pytz.UTC) + datetime.timedelta(minutes=5) + execution_mock.start_time = timestamp mock_hook.return_value.list_executions.return_value = [execution_mock] op = WorkflowsListExecutionsOperator( diff --git a/tests/providers/google/cloud/triggers/test_cloud_build.py b/tests/providers/google/cloud/triggers/test_cloud_build.py index 8687263f8e65e..f8191d032345a 100644 --- a/tests/providers/google/cloud/triggers/test_cloud_build.py +++ b/tests/providers/google/cloud/triggers/test_cloud_build.py @@ -61,6 +61,9 @@ "volumes": [], "status": 0, "script": "", + "allow_failure": False, + "exit_code": 0, + "allow_exit_codes": [], } ], name="", diff --git a/tests/system/providers/google/cloud/automl/example_automl_dataset.py b/tests/system/providers/google/cloud/automl/example_automl_dataset.py index 909da1593a404..e3c5abf67122a 100644 --- a/tests/system/providers/google/cloud/automl/example_automl_dataset.py +++ b/tests/system/providers/google/cloud/automl/example_automl_dataset.py @@ -52,7 +52,7 @@ DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" RESOURCE_DATA_BUCKET = "system-tests-resources" -DATASET_NAME = f"ds_{DAG_ID}_{ENV_ID}" +DATASET_NAME = "test_dataset_tabular" DATASET = { "display_name": DATASET_NAME, "tables_dataset_metadata": {"target_column_spec_id": ""}, @@ -164,7 +164,7 @@ def get_target_column_spec(columns_specs: list[dict], column_name: str) -> str: # [START howto_operator_delete_dataset] delete_dataset = AutoMLDeleteDatasetOperator( task_id="delete_dataset", - dataset_id="{{ task_instance.xcom_pull('list_datasets_task', key='dataset_id_list') | list }}", + dataset_id=dataset_id, location=GCP_AUTOML_LOCATION, project_id=GCP_PROJECT_ID, ) diff --git a/tests/system/providers/google/cloud/automl/example_automl_model.py b/tests/system/providers/google/cloud/automl/example_automl_model.py index f08b46b0249e6..5e7560a022698 100644 --- a/tests/system/providers/google/cloud/automl/example_automl_model.py +++ b/tests/system/providers/google/cloud/automl/example_automl_model.py @@ -51,7 +51,7 @@ from airflow.utils.trigger_rule import TriggerRule ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") -DAG_ID = "automl_model" +DAG_ID = "example_automl_model" GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") GCP_AUTOML_LOCATION = "us-central1" @@ -59,18 +59,18 @@ DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" RESOURCE_DATA_BUCKET = "system-tests-resources" -DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}" +DATASET_NAME = "test_dataset_model" DATASET = { "display_name": DATASET_NAME, "tables_dataset_metadata": {"target_column_spec_id": ""}, } -AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/bank-marketing.csv" +AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl-model/bank-marketing.csv" IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}} IMPORT_OUTPUT_CONFIG = { - "gcs_destination": {"output_uri_prefix": f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl"} + "gcs_destination": {"output_uri_prefix": f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl-model"} } -MODEL_NAME = f"model_{DAG_ID}_{ENV_ID}" +MODEL_NAME = "test_model" MODEL = { "display_name": MODEL_NAME, "tables_model_metadata": {"train_budget_milli_node_hours": 1000}, @@ -130,9 +130,9 @@ def get_target_column_spec(columns_specs: list[dict], column_name: str) -> str: move_dataset_file = GCSSynchronizeBucketsOperator( task_id="move_data_to_bucket", 
source_bucket=RESOURCE_DATA_BUCKET, - source_object="automl", + source_object="automl-model", destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME, - destination_object="automl", + destination_object="automl-model", recursive=True, ) @@ -255,7 +255,9 @@ def get_target_column_spec(columns_specs: list[dict], column_name: str) -> str: ( # TEST SETUP - [create_bucket >> move_dataset_file, create_dataset] + create_bucket + >> move_dataset_file + >> create_dataset >> import_dataset >> list_tables_spec >> list_columns_spec diff --git a/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py b/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py index 95dd549d8cef4..fd6a4fd6a76b2 100644 --- a/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py +++ b/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py @@ -34,32 +34,33 @@ AutoMLImportDataOperator, AutoMLTrainModelOperator, ) +from airflow.providers.google.cloud.operators.gcs import ( + GCSCreateBucketOperator, + GCSDeleteBucketOperator, + GCSSynchronizeBucketsOperator, +) from airflow.utils.trigger_rule import TriggerRule ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "example_automl_text" +GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") -GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") -GCP_AUTOML_TEXT_BUCKET = os.environ.get( - "GCP_AUTOML_TEXT_BUCKET", "gs://INVALID BUCKET NAME/NL-entity/dataset.csv" -) +GCP_AUTOML_LOCATION = "us-central1" -# Example values -DATASET_ID = "" +DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" +RESOURCE_DATA_BUCKET = "system-tests-resources" -# Example model +DATASET_NAME = "test_entity_extr" +DATASET = {"display_name": DATASET_NAME, "text_extraction_dataset_metadata": {}} +AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl-text/dataset.csv" +IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}} + +MODEL_NAME = "entity_extr_test_model" MODEL = { - "display_name": "auto_model_1", - "dataset_id": DATASET_ID, + "display_name": MODEL_NAME, "text_extraction_model_metadata": {}, } -# Example dataset -DATASET = {"display_name": "test_text_dataset", "text_extraction_dataset_metadata": {}} - -IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_TEXT_BUCKET]}} - extract_object_id = CloudAutoMLHook.extract_object_id # Example DAG for AutoML Natural Language Entities Extraction @@ -71,23 +72,46 @@ user_defined_macros={"extract_object_id": extract_object_id}, tags=["example", "automl"], ) as dag: + create_bucket = GCSCreateBucketOperator( + task_id="create_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + storage_class="REGIONAL", + location=GCP_AUTOML_LOCATION, + ) + + move_dataset_file = GCSSynchronizeBucketsOperator( + task_id="move_data_to_bucket", + source_bucket=RESOURCE_DATA_BUCKET, + source_object="automl-text", + destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME, + destination_object="automl-text", + recursive=True, + ) + create_dataset_task = AutoMLCreateDatasetOperator( - task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION + task_id="create_dataset_task", + dataset=DATASET, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, ) dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id")) - + MODEL["dataset_id"] = dataset_id import_dataset_task = AutoMLImportDataOperator( 
task_id="import_dataset_task", dataset_id=dataset_id, location=GCP_AUTOML_LOCATION, input_config=IMPORT_INPUT_CONFIG, + project_id=GCP_PROJECT_ID, ) - MODEL["dataset_id"] = dataset_id - create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - + create_model = AutoMLTrainModelOperator( + task_id="create_model", + model=MODEL, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) model_id = cast(str, XComArg(create_model, key="model_id")) delete_model_task = AutoMLDeleteModelOperator( @@ -105,15 +129,24 @@ trigger_rule=TriggerRule.ALL_DONE, ) + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + trigger_rule=TriggerRule.ALL_DONE, + ) + ( # TEST SETUP - create_dataset_task - # TEST BODY + create_bucket + >> move_dataset_file + >> create_dataset_task >> import_dataset_task + # TEST BODY >> create_model - >> delete_model_task # TEST TEARDOWN + >> delete_model_task >> delete_datasets_task + >> delete_bucket ) from tests.system.utils.watcher import watcher diff --git a/tests/system/providers/google/cloud/automl/example_automl_vision_classification.py b/tests/system/providers/google/cloud/automl/example_automl_vision_classification.py index cd27eed40b699..5d6ea7b932c52 100644 --- a/tests/system/providers/google/cloud/automl/example_automl_vision_classification.py +++ b/tests/system/providers/google/cloud/automl/example_automl_vision_classification.py @@ -34,32 +34,35 @@ AutoMLImportDataOperator, AutoMLTrainModelOperator, ) +from airflow.providers.google.cloud.operators.gcs import ( + GCSCreateBucketOperator, + GCSDeleteBucketOperator, + GCSSynchronizeBucketsOperator, +) from airflow.utils.trigger_rule import TriggerRule ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "example_automl_vision" +GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") -GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") -GCP_AUTOML_VISION_BUCKET = os.environ.get("GCP_AUTOML_VISION_BUCKET", "gs://INVALID BUCKET NAME") - -# Example values -DATASET_ID = "ICN123455678" +GCP_AUTOML_LOCATION = "us-central1" -# Example model -MODEL = { - "display_name": "auto_model_2", - "dataset_id": DATASET_ID, - "image_classification_model_metadata": {"train_budget": 1}, -} +DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" +RESOURCE_DATA_BUCKET = "system-tests-resources" -# Example dataset +DATASET_NAME = "test_dataset_vision" DATASET = { - "display_name": "test_vision_dataset", + "display_name": DATASET_NAME, "image_classification_dataset_metadata": {"classification_type": "MULTILABEL"}, } +AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl-vision/data.csv" +IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}} -IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_VISION_BUCKET]}} +MODEL_NAME = "test_model" +MODEL = { + "display_name": MODEL_NAME, + "image_classification_model_metadata": {"train_budget": 1}, +} extract_object_id = CloudAutoMLHook.extract_object_id @@ -72,12 +75,29 @@ user_defined_macros={"extract_object_id": extract_object_id}, tags=["example", "automl"], ) as dag: + create_bucket = GCSCreateBucketOperator( + task_id="create_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + storage_class="REGIONAL", + location=GCP_AUTOML_LOCATION, + ) + + move_dataset_file = GCSSynchronizeBucketsOperator( + task_id="move_data_to_bucket", + 
source_bucket=RESOURCE_DATA_BUCKET, + source_object="automl-vision", + destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME, + destination_object="automl-vision", + recursive=True, + ) + create_dataset_task = AutoMLCreateDatasetOperator( - task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION + task_id="create_dataset_task", + dataset=DATASET, + location=GCP_AUTOML_LOCATION, ) dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id")) - import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", dataset_id=dataset_id, @@ -88,7 +108,6 @@ MODEL["dataset_id"] = dataset_id create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - model_id = cast(str, XComArg(create_model, key="model_id")) delete_model_task = AutoMLDeleteModelOperator( @@ -107,15 +126,24 @@ trigger_rule=TriggerRule.ALL_DONE, ) + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + trigger_rule=TriggerRule.ALL_DONE, + ) + ( # TEST SETUP - create_dataset_task + create_bucket + >> move_dataset_file + >> create_dataset_task >> import_dataset_task # TEST BODY >> create_model # TEST TEARDOWN >> delete_model_task >> delete_datasets_task + >> delete_bucket ) from tests.system.utils.watcher import watcher diff --git a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py index d5c3f449fd4aa..3ce1bc2801a77 100644 --- a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py +++ b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py @@ -242,6 +242,7 @@ execute_insert_query >> get_data >> get_data_result >> delete_dataset execute_insert_query >> execute_query_save >> bigquery_execute_multi_query >> delete_dataset execute_insert_query >> [check_count, check_value, check_interval] >> delete_dataset + execute_insert_query >> [column_check, table_check] >> delete_dataset from tests.system.utils.watcher import watcher diff --git a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py index 334c2b4026dd3..38828dcb04b7f 100644 --- a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py +++ b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py @@ -147,7 +147,7 @@ task_id="select_query_job", configuration={ "query": { - "query": "{% include 'example_bigquery_query.sql' %}", + "query": "{% include 'resources/example_bigquery_query.sql' %}", "useLegacySql": False, } }, diff --git a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql.py b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql.py index 33a999b823415..c0d4e277d4ba4 100644 --- a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql.py +++ b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql.py @@ -19,11 +19,6 @@ Example Airflow DAG that creates, patches and deletes a Cloud SQL instance, and also creates, patches and deletes a database inside the instance, in Google Cloud. -This DAG relies on the following OS environment variables -https://airflow.apache.org/concepts.html#variables -* GCP_PROJECT_ID - Google Cloud project for the Cloud SQL instance. -* INSTANCE_NAME - Name of the Cloud SQL instance. -* DB_NAME - Name of the database inside a Cloud SQL instance. 
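The docstring removed above listed per-DAG OS environment variables (GCP_PROJECT_ID, INSTANCE_NAME, DB_NAME). As a hedged, illustrative sketch only, and not part of this diff, the snippet below shows the shared SYSTEM_TESTS_* convention that the reworked system-test DAGs elsewhere in this change rely on instead; the DAG id and resource names are placeholders.

import os

# Hypothetical illustration of the SYSTEM_TESTS_* convention; these exact
# names are not asserted anywhere in this patch.
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
DAG_ID = "cloudsql"
INSTANCE_NAME = f"{DAG_ID}-instance-{ENV_ID}"
DB_NAME = f"{DAG_ID}-db-{ENV_ID}"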
""" from __future__ import annotations diff --git a/tests/system/providers/google/cloud/composer/example_cloud_composer.py b/tests/system/providers/google/cloud/composer/example_cloud_composer.py index 8ffa8774d42c6..a6de4bc26be55 100644 --- a/tests/system/providers/google/cloud/composer/example_cloud_composer.py +++ b/tests/system/providers/google/cloud/composer/example_cloud_composer.py @@ -41,7 +41,7 @@ # [START howto_operator_composer_simple_environment] -ENVIRONMENT_ID = f"test-{DAG_ID}-{ENV_ID}" +ENVIRONMENT_ID = f"test-{DAG_ID}-{ENV_ID}".replace("_", "-") ENVIRONMENT = { "config": { diff --git a/tests/system/providers/google/cloud/composer/example_cloud_composer_deferrable.py b/tests/system/providers/google/cloud/composer/example_cloud_composer_deferrable.py index c7cfd2a090ad2..7ccafe98e997e 100644 --- a/tests/system/providers/google/cloud/composer/example_cloud_composer_deferrable.py +++ b/tests/system/providers/google/cloud/composer/example_cloud_composer_deferrable.py @@ -37,7 +37,7 @@ REGION = "us-central1" -ENVIRONMENT_ID = f"test-{DAG_ID}-{ENV_ID}" +ENVIRONMENT_ID = f"test-{DAG_ID}-{ENV_ID}".replace("_", "-") # [START howto_operator_composer_simple_environment] ENVIRONMENT = { "config": { diff --git a/tests/system/providers/google/cloud/compute/example_compute.py b/tests/system/providers/google/cloud/compute/example_compute.py index 59c7e63a2f90d..d5fbf347c09be 100644 --- a/tests/system/providers/google/cloud/compute/example_compute.py +++ b/tests/system/providers/google/cloud/compute/example_compute.py @@ -29,7 +29,6 @@ from airflow import models from airflow.models.baseoperator import chain -from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.compute import ( ComputeEngineDeleteInstanceOperator, ComputeEngineDeleteInstanceTemplateOperator, @@ -49,7 +48,7 @@ LOCATION = "europe-west1-b" REGION = "europe-west1" -GCE_INSTANCE_NAME = "instance-1" +GCE_INSTANCE_NAME = "instance-compute-test" SHORT_MACHINE_TYPE_NAME = "n1-standard-1" TEMPLATE_NAME = "instance-template" @@ -244,20 +243,12 @@ # [END howto_operator_gce_delete_new_template_no_project_id] gce_instance_template_delete.trigger_rule = TriggerRule.ALL_DONE - bash_wait_operator = BashOperator(task_id="delay_bash_task", bash_command="sleep 3m") - - bash_wait_operator2 = BashOperator(task_id="delay_bash_task2", bash_command="sleep 3m") - - bash_wait_operator3 = BashOperator(task_id="delay_bash_task3", bash_command="sleep 3m") - chain( gce_instance_insert, gce_instance_insert2, - bash_wait_operator, gce_instance_delete, gce_instance_template_insert, gce_instance_template_insert2, - bash_wait_operator2, gce_instance_insert_from_template, gce_instance_insert_from_template2, gce_instance_start, @@ -268,7 +259,6 @@ gce_set_machine_type2, gce_instance_delete2, gce_instance_template_delete, - bash_wait_operator3, ) # ### Everything below this line is not part of example ### diff --git a/tests/system/providers/google/cloud/compute/example_compute_igm.py b/tests/system/providers/google/cloud/compute/example_compute_igm.py index 54dc4e3711884..13ada15044b9f 100644 --- a/tests/system/providers/google/cloud/compute/example_compute_igm.py +++ b/tests/system/providers/google/cloud/compute/example_compute_igm.py @@ -29,7 +29,6 @@ from airflow import models from airflow.models.baseoperator import chain -from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.compute import ( ComputeEngineCopyInstanceTemplateOperator, 
ComputeEngineDeleteInstanceGroupManagerOperator, @@ -168,8 +167,6 @@ ) # [END howto_operator_gce_insert_igm_no_project_id] - bash_wait_operator = BashOperator(task_id="delay_bash_task", bash_command="sleep 3m") - # [START howto_operator_gce_igm_update_template] gce_instance_group_manager_update_template = ComputeEngineInstanceGroupUpdateManagerTemplateOperator( task_id="gcp_compute_igm_group_manager_update_template", @@ -193,8 +190,6 @@ ) # [END howto_operator_gce_igm_update_template_no_project_id] - bash_wait_operator1 = BashOperator(task_id="delay_bash_task_1", bash_command="sleep 3m") - # [START howto_operator_gce_delete_old_template_no_project_id] gce_instance_template_old_delete = ComputeEngineDeleteInstanceTemplateOperator( task_id="gcp_compute_delete_old_template_task", @@ -220,8 +215,6 @@ # [END howto_operator_gce_delete_igm_no_project_id] gce_igm_delete.trigger_rule = TriggerRule.ALL_DONE - bash_wait_operator2 = BashOperator(task_id="delay_bash_task_2", bash_command="sleep 3m") - chain( gce_instance_template_insert, gce_instance_template_insert2, @@ -229,14 +222,11 @@ gce_instance_template_copy2, gce_igm_insert, gce_igm_insert2, - bash_wait_operator, gce_instance_group_manager_update_template, gce_instance_group_manager_update_template2, - bash_wait_operator1, gce_igm_delete, gce_instance_template_old_delete, gce_instance_template_new_delete, - bash_wait_operator2, ) # ### Everything below this line is not part of example ### diff --git a/tests/system/providers/google/cloud/compute/example_compute_ssh.py b/tests/system/providers/google/cloud/compute/example_compute_ssh.py index 48107d8ae0f42..40129972f54b0 100644 --- a/tests/system/providers/google/cloud/compute/example_compute_ssh.py +++ b/tests/system/providers/google/cloud/compute/example_compute_ssh.py @@ -43,7 +43,7 @@ DAG_ID = "cloud_compute_ssh" LOCATION = "europe-west1-b" REGION = "europe-west1" -GCE_INSTANCE_NAME = "instance-1" +GCE_INSTANCE_NAME = "instance-ssh-test" SHORT_MACHINE_TYPE_NAME = "n1-standard-1" GCE_INSTANCE_BODY = { "name": GCE_INSTANCE_NAME, @@ -85,19 +85,36 @@ ) # [END howto_operator_gce_insert] - # [START howto_execute_command_on_remote] - metadata_without_iap_tunnel = SSHOperator( - task_id="metadata_without_iap_tunnel", + # [START howto_execute_command_on_remote_1] + metadata_without_iap_tunnel1 = SSHOperator( + task_id="metadata_without_iap_tunnel1", ssh_hook=ComputeEngineSSHHook( user="username", instance_name=GCE_INSTANCE_NAME, zone=LOCATION, + project_id=PROJECT_ID, use_oslogin=False, use_iap_tunnel=False, + cmd_timeout=100, ), - command="echo metadata_without_iap_tunnel", + command="echo metadata_without_iap_tunnel1", ) - # [END howto_execute_command_on_remote] + # [END howto_execute_command_on_remote_1] + + # [START howto_execute_command_on_remote_2] + metadata_without_iap_tunnel2 = SSHOperator( + task_id="metadata_without_iap_tunnel2", + ssh_hook=ComputeEngineSSHHook( + user="username", + instance_name=GCE_INSTANCE_NAME, + zone=LOCATION, + use_oslogin=False, + use_iap_tunnel=False, + cmd_timeout=100, + ), + command="echo metadata_without_iap_tunnel2", + ) + # [END howto_execute_command_on_remote_2] # [START howto_operator_gce_delete_no_project_id] gce_instance_delete = ComputeEngineDeleteInstanceOperator( @@ -110,7 +127,8 @@ chain( gce_instance_insert, - metadata_without_iap_tunnel, + metadata_without_iap_tunnel1, + metadata_without_iap_tunnel2, gce_instance_delete, ) diff --git a/tests/system/providers/google/cloud/dataflow/example_dataflow_go.py 
b/tests/system/providers/google/cloud/dataflow/example_dataflow_go.py new file mode 100644 index 0000000000000..548a6902b7197 --- /dev/null +++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_go.py @@ -0,0 +1,154 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for Apache Beam operators + +Requirements: + This test requires the gcloud and go commands to run. +""" +from __future__ import annotations + +import os +from datetime import datetime +from pathlib import Path + +from airflow import models +from airflow.providers.apache.beam.hooks.beam import BeamRunnerType +from airflow.providers.apache.beam.operators.beam import BeamRunGoPipelineOperator +from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus +from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration +from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator +from airflow.providers.google.cloud.sensors.dataflow import ( + DataflowJobAutoScalingEventsSensor, + DataflowJobMessagesSensor, + DataflowJobStatusSensor, +) +from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator +from airflow.utils.trigger_rule import TriggerRule + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = "dataflow_native_go_async" +BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" + +GCS_TMP = f"gs://{BUCKET_NAME}/temp/" +GCS_STAGING = f"gs://{BUCKET_NAME}/staging/" +GCS_OUTPUT = f"gs://{BUCKET_NAME}/output" + +GO_FILE_NAME = "wordcount.go" +GO_FILE_LOCAL_PATH = str(Path(__file__).parent / "resources" / GO_FILE_NAME) +GCS_GO = f"gs://{BUCKET_NAME}/{GO_FILE_NAME}" +LOCATION = "europe-west3" + +default_args = { + "dataflow_default_options": { + "tempLocation": GCS_TMP, + "stagingLocation": GCS_STAGING, + } +} + +with models.DAG( + "example_beam_native_go", + start_date=datetime(2021, 1, 1), + schedule="@once", + catchup=False, + default_args=default_args, + tags=["example"], +) as dag: + create_bucket = GCSCreateBucketOperator(task_id="create_bucket", bucket_name=BUCKET_NAME) + + upload_file = LocalFilesystemToGCSOperator( + task_id="upload_file_to_bucket", + src=GO_FILE_LOCAL_PATH, + dst=GO_FILE_NAME, + bucket=BUCKET_NAME, + ) + + start_go_pipeline_dataflow_runner = BeamRunGoPipelineOperator( + task_id="start_go_pipeline_dataflow_runner", + runner=BeamRunnerType.DataflowRunner, + go_file=GCS_GO, + pipeline_options={ + "tempLocation": GCS_TMP, + "stagingLocation": GCS_STAGING, + "output": GCS_OUTPUT, + "WorkerHarnessContainerImage": "apache/beam_go_sdk:2.46.0", + }, + dataflow_config=DataflowConfiguration(job_name="start_go_job", location=LOCATION), + ) + + wait_for_go_job_async_done = DataflowJobStatusSensor( + task_id="wait_for_go_job_async_done", + 
job_id="{{task_instance.xcom_pull('start_go_pipeline_dataflow_runner')['dataflow_job_id']}}", + expected_statuses={DataflowJobStatus.JOB_STATE_DONE}, + location=LOCATION, + ) + + def check_message(messages: list[dict]) -> bool: + """Check message""" + for message in messages: + if "Adding workflow start and stop steps." in message.get("messageText", ""): + return True + return False + + wait_for_go_job_async_message = DataflowJobMessagesSensor( + task_id="wait_for_go_job_async_message", + job_id="{{task_instance.xcom_pull('start_go_pipeline_dataflow_runner')['dataflow_job_id']}}", + location=LOCATION, + callback=check_message, + fail_on_terminal_state=False, + ) + + def check_autoscaling_event(autoscaling_events: list[dict]) -> bool: + """Check autoscaling event""" + for autoscaling_event in autoscaling_events: + if "Worker pool started." in autoscaling_event.get("description", {}).get("messageText", ""): + return True + return False + + wait_for_go_job_async_autoscaling_event = DataflowJobAutoScalingEventsSensor( + task_id="wait_for_go_job_async_autoscaling_event", + job_id="{{task_instance.xcom_pull('start_go_pipeline_dataflow_runner')['dataflow_job_id']}}", + location=LOCATION, + callback=check_autoscaling_event, + fail_on_terminal_state=False, + ) + + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", bucket_name=BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE + ) + + ( + # TEST SETUP + create_bucket + >> upload_file + # TEST BODY + >> start_go_pipeline_dataflow_runner + >> [ + wait_for_go_job_async_done, + wait_for_go_job_async_message, + wait_for_go_job_async_autoscaling_event, + ] + # TEST TEARDOWN + >> delete_bucket + ) + + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py b/tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py index 41065519146e9..e3dbe69217090 100644 --- a/tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py +++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py @@ -24,6 +24,12 @@ or is not compatible with the Java version used in the test, the source code for this test can be downloaded from here (https://beam.apache.org/get-started/wordcount-example) and needs to be compiled manually in order to work. + + You can follow the instructions on how to pack a self-executing jar here: + https://beam.apache.org/documentation/runners/dataflow/ + +Requirements: + These operators require the gcloud command and Java's JRE to run. 
""" from __future__ import annotations @@ -31,6 +37,7 @@ from datetime import datetime from airflow import models +from airflow.providers.apache.beam.hooks.beam import BeamRunnerType from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator from airflow.providers.google.cloud.operators.dataflow import CheckJobRunning from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator @@ -83,6 +90,7 @@ # [START howto_operator_start_java_job_jar_on_gcs] start_java_job = BeamRunJavaPipelineOperator( + runner=BeamRunnerType.DataflowRunner, task_id="start-java-job", jar=GCS_JAR, pipeline_options={ diff --git a/tests/system/providers/google/cloud/dataflow/example_dataflow_native_python.py b/tests/system/providers/google/cloud/dataflow/example_dataflow_native_python.py index c53fde5505a1c..6e80f8664c4a2 100644 --- a/tests/system/providers/google/cloud/dataflow/example_dataflow_native_python.py +++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_native_python.py @@ -26,6 +26,7 @@ from pathlib import Path from airflow import models +from airflow.providers.apache.beam.hooks.beam import BeamRunnerType from airflow.providers.apache.beam.operators.beam import BeamRunPythonPipelineOperator from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator @@ -70,13 +71,14 @@ # [START howto_operator_start_python_job] start_python_job = BeamRunPythonPipelineOperator( + runner=BeamRunnerType.DataflowRunner, task_id="start_python_job", py_file=GCS_PYTHON_SCRIPT, py_options=[], pipeline_options={ "output": GCS_OUTPUT, }, - py_requirements=["apache-beam[gcp]==2.36.0"], + py_requirements=["apache-beam[gcp]==2.46.0"], py_interpreter="python3", py_system_site_packages=False, dataflow_config={"location": LOCATION}, @@ -90,7 +92,7 @@ pipeline_options={ "output": GCS_OUTPUT, }, - py_requirements=["apache-beam[gcp]==2.36.0"], + py_requirements=["apache-beam[gcp]==2.46.0"], py_interpreter="python3", py_system_site_packages=False, ) diff --git a/tests/system/providers/google/cloud/dataflow/example_dataflow_native_python_async.py b/tests/system/providers/google/cloud/dataflow/example_dataflow_native_python_async.py index 8b8ce5b6477d4..9d03e851a3cdc 100644 --- a/tests/system/providers/google/cloud/dataflow/example_dataflow_native_python_async.py +++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_native_python_async.py @@ -87,7 +87,7 @@ pipeline_options={ "output": GCS_OUTPUT, }, - py_requirements=["apache-beam[gcp]==2.36.0"], + py_requirements=["apache-beam[gcp]==2.46.0"], py_interpreter="python3", py_system_site_packages=False, dataflow_config={ diff --git a/tests/system/providers/google/cloud/dataflow/resources/wordcount.go b/tests/system/providers/google/cloud/dataflow/resources/wordcount.go new file mode 100644 index 0000000000000..4ad9dcd4f098e --- /dev/null +++ b/tests/system/providers/google/cloud/dataflow/resources/wordcount.go @@ -0,0 +1,230 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// wordcount is an example that counts words in Shakespeare and demonstrates +// Beam best practices. +// +// This example is the second in a series of four successively more detailed +// 'word count' examples. You may first want to take a look at +// minimal_wordcount. After you've looked at this example, see the +// debugging_wordcount pipeline for introduction of additional concepts. +// +// For a detailed walkthrough of this example, see +// +// https://beam.apache.org/get-started/wordcount-example/ +// +// Basic concepts, also in the minimal_wordcount example: reading text files; +// counting a PCollection; writing to text files. +// +// New concepts: +// +// 1. Executing a pipeline both locally and using the selected runner +// 2. Defining your own pipeline options +// 3. Using ParDo with static DoFns defined out-of-line +// 4. Building a composite transform +// +// Concept #1: You can execute this pipeline either locally or by +// selecting another runner. These are now command-line options added by +// the 'beamx' package and not hard-coded as they were in the minimal_wordcount +// example. The 'beamx' package also registers all included runners and +// filesystems as a convenience. +// +// To change the runner, specify: +// +// --runner=YOUR_SELECTED_RUNNER +// +// To execute this pipeline, specify a local output file (if using the +// 'direct' runner) or a remote file on a supported distributed file system. +// +// --output=[YOUR_LOCAL_FILE | YOUR_REMOTE_FILE] +// +// The input file defaults to a public data set containing the text of King +// Lear by William Shakespeare. You can override it and choose your own input +// with --input. +package main + +// beam-playground: +// name: WordCount +// description: An example that counts words in Shakespeare's works. +// multifile: false +// pipeline_options: --output output.txt +// context_line: 120 +// categories: +// - Combiners +// - Options +// - Quickstart +// complexity: MEDIUM +// tags: +// - count +// - io +// - strings + +import ( + "context" + "flag" + "fmt" + "log" + "regexp" + "strings" + + "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/io/textio" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" + "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/stats" + "github.com/apache/beam/sdks/v2/go/pkg/beam/x/beamx" +) + +// Concept #2: Defining your own configuration options. Pipeline options can +// be standard Go flags, or they can be obtained any other way. Defining and +// configuring the pipeline is normal Go code. +var ( + // By default, this example reads from a public dataset containing the text of + // King Lear. Set this option to choose a different input file or glob. + input = flag.String("input", "gs://apache-beam-samples/shakespeare/kinglear.txt", "File(s) to read.") + + // Set this required option to specify where to write the output. + output = flag.String("output", "", "Output file (required).") +) + +// Concept #3: You can make your pipeline assembly code less verbose by +// defining your DoFns statically out-of-line. 
A DoFn can be defined as a Go +// function and is conventionally suffixed "Fn". Using named function +// transforms allows for easy reuse, modular testing, and an improved monitoring +// experience. The argument and return types of a function dictate the pipeline +// shape when used in a ParDo. For example, +// +// func formatFn(w string, c int) string +// +// indicates that the function operates on a PCollection of type KV, +// representing key value pairs of strings and ints, and outputs a PCollection +// of type string. Beam typechecks the pipeline before running it. +// +// DoFns that potentially output zero or multiple elements can also be Go +// functions, but have a different signature. For example, +// +// func extractFn(w string, emit func(string)) +// +// uses an "emit" function argument instead of a string return type to allow it +// to output any number of elements. It operates on a PCollection of type string +// and returns a PCollection of type string. +// +// DoFns must be registered with Beam in order to be executed in ParDos. This is +// done automatically by the starcgen code generator, or it can be done manually +// by calling beam.RegisterFunction in an init() call. +func init() { + // register.DoFnXxY registers a struct DoFn so that it can be correctly + // serialized and does some optimization to avoid runtime reflection. Since + // extractFn has 3 inputs and 0 outputs, we use register.DoFn3x0 and provide + // its input types as its constraints (if it had any outputs, we would add + // those as constraints as well). Struct DoFns must be registered for a + // pipeline to run. + register.DoFn3x0[context.Context, string, func(string)](&extractFn{}) + // register.FunctionXxY registers a functional DoFn to optimize execution at + // runtime. formatFn has 2 inputs and 1 output, so we use + // register.Function2x1. + register.Function2x1(formatFn) + // register.EmitterX is optional and will provide some optimization to make + // things run faster. Any emitters (functions that produce output for the next + // step) should be registered. Here we register all emitters with the + // signature func(string). + register.Emitter1[string]() +} + +var ( + wordRE = regexp.MustCompile(`[a-zA-Z]+('[a-z])?`) + empty = beam.NewCounter("extract", "emptyLines") + smallWordLength = flag.Int("small_word_length", 9, "length of small words (default: 9)") + smallWords = beam.NewCounter("extract", "smallWords") + lineLen = beam.NewDistribution("extract", "lineLenDistro") +) + +// extractFn is a structural DoFn that emits the words in a given line and keeps +// a count for small words. Its ProcessElement function will be invoked on each +// element in the input PCollection. +type extractFn struct { + SmallWordLength int `json:"smallWordLength"` +} + +func (f *extractFn) ProcessElement(ctx context.Context, line string, emit func(string)) { + lineLen.Update(ctx, int64(len(line))) + if len(strings.TrimSpace(line)) == 0 { + empty.Inc(ctx, 1) + } + for _, word := range wordRE.FindAllString(line, -1) { + // increment the counter for small words if length of words is + // less than small_word_length + if len(word) < f.SmallWordLength { + smallWords.Inc(ctx, 1) + } + emit(word) + } +} + +// formatFn is a functional DoFn that formats a word and its count as a string. +func formatFn(w string, c int) string { + return fmt.Sprintf("%s: %v", w, c) +} + +// Concept #4: A composite PTransform is a Go function that adds +// transformations to a given pipeline. 
It is run at construction time and +// works on PCollections as values. For monitoring purposes, the pipeline +// allows scoped naming for composite transforms. The difference between a +// composite transform and a construction helper function is solely in whether +// a scoped name is used. +// +// For example, the CountWords function is a custom composite transform that +// bundles two transforms (ParDo and Count) as a reusable function. + +// CountWords is a composite transform that counts the words of a PCollection +// of lines. It expects a PCollection of type string and returns a PCollection +// of type KV. The Beam type checker enforces these constraints +// during pipeline construction. +func CountWords(s beam.Scope, lines beam.PCollection) beam.PCollection { + s = s.Scope("CountWords") + + // Convert lines of text into individual words. + col := beam.ParDo(s, &extractFn{SmallWordLength: *smallWordLength}, lines) + + // Count the number of times each word occurs. + return stats.Count(s, col) +} + +func main() { + // If beamx or Go flags are used, flags must be parsed first. + flag.Parse() + // beam.Init() is an initialization hook that must be called on startup. On + // distributed runners, it is used to intercept control. + beam.Init() + + // Input validation is done as usual. Note that it must be after Init(). + if *output == "" { + log.Fatal("No output provided") + } + + // Concepts #3 and #4: The pipeline uses the named transform and DoFn. + p := beam.NewPipeline() + s := p.Root() + + lines := textio.Read(s, *input) + counted := CountWords(s, lines) + formatted := beam.ParDo(s, formatFn, counted) + textio.Write(s, *output, formatted) + + // Concept #1: The beamx.Run convenience wrapper allows a number of + // pre-defined runners to be used via the --runner flag. 
+ if err := beamx.Run(context.Background(), p); err != nil { + log.Fatalf("Failed to execute job: %v", err) + } +} diff --git a/tests/system/providers/google/cloud/dataprep/example_dataprep.py b/tests/system/providers/google/cloud/dataprep/example_dataprep.py index 9f478a5f0be5b..a192f7fec5149 100644 --- a/tests/system/providers/google/cloud/dataprep/example_dataprep.py +++ b/tests/system/providers/google/cloud/dataprep/example_dataprep.py @@ -39,10 +39,10 @@ DAG_ID = "example_dataprep" GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") -GCS_BUCKET_NAME = f"dataprep-bucket-heorhi-{DAG_ID}-{ENV_ID}" +GCS_BUCKET_NAME = f"dataprep-bucket-{DAG_ID}-{ENV_ID}" GCS_BUCKET_PATH = f"gs://{GCS_BUCKET_NAME}/task_results/" -FLOW_ID = os.environ.get("FLOW_ID", "") +FLOW_ID = os.environ.get("FLOW_ID") RECIPE_ID = os.environ.get("RECIPE_ID") RECIPE_NAME = os.environ.get("RECIPE_NAME") WRITE_SETTINGS = ( diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_batch.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_batch.py index 8ca87c59fa41d..160422d197cd7 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_batch.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_batch.py @@ -37,7 +37,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "dataproc_batch" -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") REGION = "europe-west1" BATCH_ID = f"test-batch-id-{ENV_ID}" @@ -130,7 +130,7 @@ task_id="cancel_operation", project_id=PROJECT_ID, region=REGION, - operation_name="{{ task_instance.xcom_pull('create_batch') }}", + operation_name="{{ task_instance.xcom_pull('create_batch_4') }}", ) # [END how_to_cloud_dataproc_cancel_operation_operator] diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_batch_persistent.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_batch_persistent.py index 55c6b2fafaed1..011b91d24594d 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_batch_persistent.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_batch_persistent.py @@ -34,7 +34,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "dataproc_batch_ps" -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" REGION = "europe-west1" CLUSTER_NAME = f"dataproc-cluster-ps-{ENV_ID}" diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py index 8f35404fc5266..1128c4cce7e61 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py @@ -32,10 +32,10 @@ from airflow.utils.trigger_rule import TriggerRule ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") -DAG_ID = "dataproc_update" -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "") +DAG_ID = "dataproc_cluster_def" +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") -CLUSTER_NAME = f"cluster-dataproc-update-{ENV_ID}" +CLUSTER_NAME = f"cluster-dataproc-def-{ENV_ID}" REGION = "europe-west1" ZONE = "europe-west1-b" diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_generator.py 
b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_generator.py index 54afd48aa2935..be25251e16954 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_generator.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_generator.py @@ -37,7 +37,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "dataproc_cluster_generation" -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" CLUSTER_NAME = f"dataproc-cluster-gen-{ENV_ID}" diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_update.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_update.py similarity index 98% rename from tests/system/providers/google/cloud/dataproc/example_dataproc_update.py rename to tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_update.py index 8cfe7dae36c13..1607e714f5e76 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_update.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_update.py @@ -33,7 +33,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "dataproc_update" -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") CLUSTER_NAME = f"cluster-dataproc-update-{ENV_ID}" REGION = "europe-west1" diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_gke.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_gke.py index 499e76bf8cb5d..63e58d2583694 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_gke.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_gke.py @@ -17,6 +17,13 @@ # under the License. """ Example Airflow DAG that show how to create a Dataproc cluster in Google Kubernetes Engine. + +Required environment variables: +GKE_NAMESPACE = os.environ.get("GKE_NAMESPACE", f"{CLUSTER_NAME}") +A GKE cluster can support multiple DP clusters running in different namespaces. +Define a namespace or assign a default one. 
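The docstring being added here describes how the GKE namespace and the Workload Identity pool are expected to line up for a Dataproc-on-GKE virtual cluster. The sketch below, assuming placeholder project and environment values, only illustrates how those pieces relate; it mirrors the names used by the DAG but is not part of the diff itself.

import os

# Illustrative only: one GKE cluster can host several Dataproc virtual
# clusters as long as each one gets its own Kubernetes namespace, and that
# namespace must match the kubernetes_namespace in VIRTUAL_CLUSTER_CONFIG.
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "dev")
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "example-project")
CLUSTER_NAME = f"cluster-test-build-in-gke{ENV_ID}"
GKE_NAMESPACE = os.environ.get("GKE_NAMESPACE", f"{CLUSTER_NAME}")
WORKLOAD_POOL = f"{PROJECT_ID}.svc.id.goog"

# The Dataproc agent service account in that namespace also needs the
# Workload Identity binding created later in the DAG, conceptually:
#   gcloud projects add-iam-policy-binding <PROJECT_ID> \
#     --member="serviceAccount:<PROJECT_ID>.svc.id.goog[<GKE_NAMESPACE>/agent]" \
#     --role=roles/iam.workloadIdentityUser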
+Notice: optional kubernetes_namespace parameter in VIRTUAL_CLUSTER_CONFIG should be the same as GKE_NAMESPACE + """ from __future__ import annotations @@ -24,6 +31,7 @@ from datetime import datetime from airflow import models +from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.dataproc import ( DataprocCreateClusterOperator, DataprocDeleteClusterOperator, @@ -41,14 +49,15 @@ REGION = "us-central1" CLUSTER_NAME = f"cluster-test-build-in-gke{ENV_ID}" GKE_CLUSTER_NAME = f"test-dataproc-gke-cluster-{ENV_ID}" +WORKLOAD_POOL = f"{PROJECT_ID}.svc.id.goog" GKE_CLUSTER_CONFIG = { "name": GKE_CLUSTER_NAME, "workload_identity_config": { - "workload_pool": f"{PROJECT_ID}.svc.id.goog", + "workload_pool": WORKLOAD_POOL, }, "initial_node_count": 1, } - +GKE_NAMESPACE = os.environ.get("GKE_NAMESPACE", f"{CLUSTER_NAME}") # [START how_to_cloud_dataproc_create_cluster_in_gke_config] VIRTUAL_CLUSTER_CONFIG = { @@ -84,6 +93,13 @@ body=GKE_CLUSTER_CONFIG, ) + add_iam_policy_binding = BashOperator( + task_id="add_iam_policy_binding", + bash_command=f"gcloud projects add-iam-policy-binding {PROJECT_ID} " + f"--member=serviceAccount:{WORKLOAD_POOL}[{GKE_NAMESPACE}/agent] " + "--role=roles/iam.workloadIdentityUser", + ) + # [START how_to_cloud_dataproc_create_cluster_operator_in_gke] create_cluster_in_gke = DataprocCreateClusterOperator( task_id="create_cluster_in_gke", @@ -110,7 +126,13 @@ trigger_rule=TriggerRule.ALL_DONE, ) - create_gke_cluster >> create_cluster_in_gke >> [delete_dataproc_cluster, delete_gke_cluster] + ( + create_gke_cluster + >> add_iam_policy_binding + >> create_cluster_in_gke + >> delete_gke_cluster + >> delete_dataproc_cluster + ) from tests.system.utils.watcher import watcher diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_hadoop.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_hadoop.py index a937316ebe05a..1eb0307178845 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_hadoop.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_hadoop.py @@ -35,7 +35,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "dataproc_hadoop" -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" CLUSTER_NAME = f"dataproc-hadoop-{ENV_ID}" diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_hive.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_hive.py index cf9c99ed3a1fc..ea98725b29931 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_hive.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_hive.py @@ -33,7 +33,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "dataproc_hive" -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") CLUSTER_NAME = f"cluster-dataproc-hive-{ENV_ID}" REGION = "europe-west1" diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_pig.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_pig.py index e3562a753f3c0..4510ab09b5007 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_pig.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_pig.py @@ -33,7 +33,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "dataproc_pig" -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "") +PROJECT_ID = 
os.environ.get("SYSTEM_TESTS_GCP_PROJECT") CLUSTER_NAME = f"cluster-dataproc-pig-{ENV_ID}" REGION = "europe-west1" diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_pyspark.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_pyspark.py index c8543ac2becbb..b23482f5ac8a4 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_pyspark.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_pyspark.py @@ -36,7 +36,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "dataproc_pyspark" -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" CLUSTER_NAME = f"cluster-dataproc-pyspark-{ENV_ID}" diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_spark.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_spark.py index f3af5a7503a77..761d018166239 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_spark.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_spark.py @@ -33,7 +33,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "dataproc_spark" -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") CLUSTER_NAME = f"cluster-dataproc-spark-{ENV_ID}" REGION = "europe-west1" diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_async.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_async.py index 201158571e8c3..edb704f6bcc95 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_async.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_async.py @@ -34,7 +34,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "dataproc_spark_async" -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") CLUSTER_NAME = f"dataproc-spark-async-{ENV_ID}" REGION = "europe-west1" diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_deferrable.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_deferrable.py index d3a9a0d6ef67e..36e333c2ebb46 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_deferrable.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_deferrable.py @@ -34,7 +34,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "dataproc_spark_deferrable" -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") CLUSTER_NAME = f"cluster-dataproc-spark-{ENV_ID}" REGION = "europe-west1" diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_sql.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_sql.py index 5f31381c2e5fe..ff742aef33bec 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_sql.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_sql.py @@ -33,7 +33,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "dataproc_spark_sql" -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") CLUSTER_NAME = f"dataproc-spark-sql-{ENV_ID}" REGION = "europe-west1" diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_sparkr.py 
b/tests/system/providers/google/cloud/dataproc/example_dataproc_sparkr.py index ce68eb574f4d9..80d4bca96f5cf 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_sparkr.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_sparkr.py @@ -36,7 +36,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "dataproc_sparkr" -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" CLUSTER_NAME = f"dataproc-sparkr-{ENV_ID}" diff --git a/tests/system/providers/google/cloud/dataproc_metastore/example_dataproc_metastore.py b/tests/system/providers/google/cloud/dataproc_metastore/example_dataproc_metastore.py index 1b6341a73a7c9..e3c6d2a3b5a43 100644 --- a/tests/system/providers/google/cloud/dataproc_metastore/example_dataproc_metastore.py +++ b/tests/system/providers/google/cloud/dataproc_metastore/example_dataproc_metastore.py @@ -116,7 +116,7 @@ # [END how_to_cloud_dataproc_metastore_create_service_operator] # [START how_to_cloud_dataproc_metastore_get_service_operator] - get_service_details = DataprocMetastoreGetServiceOperator( + get_service = DataprocMetastoreGetServiceOperator( task_id="get_service", region=REGION, project_id=PROJECT_ID, @@ -138,7 +138,7 @@ # [START how_to_cloud_dataproc_metastore_create_metadata_import_operator] import_metadata = DataprocMetastoreCreateMetadataImportOperator( - task_id="create_metadata_import", + task_id="import_metadata", project_id=PROJECT_ID, region=REGION, service_id=SERVICE_ID, @@ -176,8 +176,9 @@ ( create_bucket + >> upload_file >> create_service - >> get_service_details + >> get_service >> update_service >> import_metadata >> export_metadata diff --git a/tests/system/providers/google/cloud/gcs/example_firestore.py b/tests/system/providers/google/cloud/gcs/example_firestore.py index 9be3b8dd8dcf3..12953e168dfec 100644 --- a/tests/system/providers/google/cloud/gcs/example_firestore.py +++ b/tests/system/providers/google/cloud/gcs/example_firestore.py @@ -45,7 +45,6 @@ import os from datetime import datetime -from urllib.parse import urlsplit from airflow import models from airflow.providers.google.cloud.operators.bigquery import ( @@ -59,15 +58,18 @@ from airflow.utils.trigger_rule import TriggerRule ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") -DAG_ID = "example_google_firestore" +DAG_ID = "example_gcp_firestore" GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-gcp-project") FIRESTORE_PROJECT_ID = os.environ.get("G_FIRESTORE_PROJECT_ID", "example-firebase-project") BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}" -EXPORT_DESTINATION_URL = os.environ.get("GCP_FIRESTORE_ARCHIVE_URL", "gs://INVALID BUCKET NAME/namespace/") -EXPORT_PREFIX = urlsplit(EXPORT_DESTINATION_URL).path +EXPORT_DESTINATION_URL = os.environ.get("GCP_FIRESTORE_ARCHIVE_URL", f"gs://{BUCKET_NAME}/namespace/") EXPORT_COLLECTION_ID = os.environ.get("GCP_FIRESTORE_COLLECTION_ID", "firestore_collection_id") +EXTERNAL_TABLE_SOURCE_URI = ( + f"{EXPORT_DESTINATION_URL}/all_namespaces/kind_{EXPORT_COLLECTION_ID}" + f"/all_namespaces_kind_{EXPORT_COLLECTION_ID}.export_metadata" +) DATASET_LOCATION = os.environ.get("GCP_FIRESTORE_DATASET_LOCATION", "EU") if BUCKET_NAME is None: @@ -80,7 +82,9 @@ catchup=False, tags=["example", "firestore"], ) as dag: - create_bucket = GCSCreateBucketOperator(task_id="create_bucket", bucket_name=BUCKET_NAME) + create_bucket = GCSCreateBucketOperator( + 
task_id="create_bucket", bucket_name=BUCKET_NAME, location=DATASET_LOCATION + ) create_dataset = BigQueryCreateEmptyDatasetOperator( task_id="create_dataset", @@ -107,16 +111,10 @@ "datasetId": DATASET_NAME, "tableId": "firestore_data", }, - "schema": { - "fields": [ - {"name": "name", "type": "STRING"}, - {"name": "post_abbr", "type": "STRING"}, - ] - }, "externalDataConfiguration": { "sourceFormat": "DATASTORE_BACKUP", "compression": "NONE", - "csvOptions": {"skipLeadingRows": 1}, + "sourceUris": [EXTERNAL_TABLE_SOURCE_URI], }, }, ) @@ -146,15 +144,13 @@ ( # TEST SETUP - create_bucket - >> create_dataset + [create_bucket, create_dataset] # TEST BODY >> export_database_to_gcs >> create_external_table_multiple_types >> read_data_from_gcs_multiple_types # TEST TEARDOWN - >> delete_dataset - >> delete_bucket + >> [delete_dataset, delete_bucket] ) from tests.system.utils.watcher import watcher diff --git a/tests/system/providers/google/cloud/gcs/example_gcs_sensor.py b/tests/system/providers/google/cloud/gcs/example_gcs_sensor.py index 988d989767d6b..5523f2ac2c2c6 100644 --- a/tests/system/providers/google/cloud/gcs/example_gcs_sensor.py +++ b/tests/system/providers/google/cloud/gcs/example_gcs_sensor.py @@ -26,7 +26,6 @@ from airflow import models from airflow.models.baseoperator import chain -from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator from airflow.providers.google.cloud.sensors.gcs import ( GCSObjectExistenceAsyncSensor, @@ -157,12 +156,9 @@ def mode_setter(self, value): task_id="delete_bucket", bucket_name=BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE ) - sleep = BashOperator(task_id="sleep", bash_command="sleep 5") - chain( # TEST SETUP create_bucket, - sleep, upload_file, # TEST BODY [ diff --git a/tests/system/providers/google/cloud/gcs/example_gcs_to_bigquery_async.py b/tests/system/providers/google/cloud/gcs/example_gcs_to_bigquery_async.py index 5888056456bc3..0bda9be4c3f3e 100644 --- a/tests/system/providers/google/cloud/gcs/example_gcs_to_bigquery_async.py +++ b/tests/system/providers/google/cloud/gcs/example_gcs_to_bigquery_async.py @@ -80,7 +80,7 @@ write_disposition="WRITE_TRUNCATE", external_table=False, autodetect=True, - max_id_key=MAX_ID_STR, + max_id_key="string_field_0", deferrable=True, ) diff --git a/tests/system/providers/google/cloud/gcs/example_gcs_to_gcs.py b/tests/system/providers/google/cloud/gcs/example_gcs_to_gcs.py index b377be2e55add..7931295d23e79 100644 --- a/tests/system/providers/google/cloud/gcs/example_gcs_to_gcs.py +++ b/tests/system/providers/google/cloud/gcs/example_gcs_to_gcs.py @@ -43,7 +43,7 @@ BUCKET_NAME_SRC = f"bucket_{DAG_ID}_{ENV_ID}" BUCKET_NAME_DST = f"bucket_dst_{DAG_ID}_{ENV_ID}" -RANDOM_FILE_NAME = OBJECT_1 = OBJECT_2 = "random.bin" +RANDOM_FILE_NAME = OBJECT_1 = OBJECT_2 = "/tmp/random.bin" with models.DAG( diff --git a/tests/system/providers/google/cloud/gcs/example_s3_to_gcs.py b/tests/system/providers/google/cloud/gcs/example_s3_to_gcs.py index 063d5a67430b8..f735b3ea55a3f 100644 --- a/tests/system/providers/google/cloud/gcs/example_s3_to_gcs.py +++ b/tests/system/providers/google/cloud/gcs/example_s3_to_gcs.py @@ -18,6 +18,7 @@ import os from datetime import datetime +from pathlib import Path from airflow import models from airflow.decorators import task @@ -31,9 +32,10 @@ GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") DAG_ID = "example_s3_to_gcs" -BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" +BUCKET_NAME = 
f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-") GCS_BUCKET_URL = f"gs://{BUCKET_NAME}/" -UPLOAD_FILE = "/tmp/example-file.txt" +FILE_NAME = "example_upload.txt" +UPLOAD_FILE = str(Path(__file__).parent / "resources" / FILE_NAME) PREFIX = "TESTS" diff --git a/tests/system/providers/google/cloud/gcs/example_trino_to_gcs.py b/tests/system/providers/google/cloud/gcs/example_trino_to_gcs.py index d5dc4f72c670e..96690f71471a7 100644 --- a/tests/system/providers/google/cloud/gcs/example_trino_to_gcs.py +++ b/tests/system/providers/google/cloud/gcs/example_trino_to_gcs.py @@ -31,6 +31,7 @@ BigQueryDeleteDatasetOperator, BigQueryInsertJobOperator, ) +from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator from airflow.providers.google.cloud.transfers.trino_to_gcs import TrinoToGCSOperator from airflow.utils.trigger_rule import TriggerRule @@ -41,7 +42,7 @@ GCS_BUCKET = f"bucket_{DAG_ID}_{ENV_ID}" DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}" -SOURCE_MULTIPLE_TYPES = "memory.default.test_multiple_types" +SOURCE_SCHEMA_COLUMNS = "memory.information_schema.columns" SOURCE_CUSTOMER_TABLE = "tpch.sf1.customer" @@ -68,22 +69,26 @@ def safe_name(s: str) -> str: trigger_rule=TriggerRule.ALL_DONE, ) + create_bucket = GCSCreateBucketOperator(task_id="create_bucket", bucket_name=GCS_BUCKET) + + delete_bucket = GCSDeleteBucketOperator(task_id="delete_bucket", bucket_name=GCS_BUCKET) + # [START howto_operator_trino_to_gcs_basic] trino_to_gcs_basic = TrinoToGCSOperator( task_id="trino_to_gcs_basic", - sql=f"select * from {SOURCE_MULTIPLE_TYPES}", + sql=f"select * from {SOURCE_SCHEMA_COLUMNS}", bucket=GCS_BUCKET, - filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}.{{}}.json", + filename=f"{safe_name(SOURCE_SCHEMA_COLUMNS)}.{{}}.json", ) # [END howto_operator_trino_to_gcs_basic] # [START howto_operator_trino_to_gcs_multiple_types] trino_to_gcs_multiple_types = TrinoToGCSOperator( task_id="trino_to_gcs_multiple_types", - sql=f"select * from {SOURCE_MULTIPLE_TYPES}", + sql=f"select * from {SOURCE_SCHEMA_COLUMNS}", bucket=GCS_BUCKET, - filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}.{{}}.json", - schema_filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}-schema.json", + filename=f"{safe_name(SOURCE_SCHEMA_COLUMNS)}.{{}}.json", + schema_filename=f"{safe_name(SOURCE_SCHEMA_COLUMNS)}-schema.json", gzip=False, ) # [END howto_operator_trino_to_gcs_multiple_types] @@ -96,22 +101,28 @@ def safe_name(s: str) -> str: "tableReference": { "projectId": GCP_PROJECT_ID, "datasetId": DATASET_NAME, - "tableId": f"{safe_name(SOURCE_MULTIPLE_TYPES)}", + "tableId": f"{safe_name(SOURCE_SCHEMA_COLUMNS)}", }, "schema": { "fields": [ - {"name": "name", "type": "STRING"}, - {"name": "post_abbr", "type": "STRING"}, - ] + {"name": "table_catalog", "type": "STRING"}, + {"name": "table_schema", "type": "STRING"}, + {"name": "table_name", "type": "STRING"}, + {"name": "column_name", "type": "STRING"}, + {"name": "ordinal_position", "type": "INT64"}, + {"name": "column_default", "type": "STRING"}, + {"name": "is_nullable", "type": "STRING"}, + {"name": "data_type", "type": "STRING"}, + ], }, "externalDataConfiguration": { "sourceFormat": "NEWLINE_DELIMITED_JSON", "compression": "NONE", - "csvOptions": {"skipLeadingRows": 1}, + "sourceUris": [f"gs://{GCS_BUCKET}/{safe_name(SOURCE_SCHEMA_COLUMNS)}.*.json"], }, }, - source_objects=[f"{safe_name(SOURCE_MULTIPLE_TYPES)}.*.json"], - schema_object=f"{safe_name(SOURCE_MULTIPLE_TYPES)}-schema.json", + source_objects=[f"{safe_name(SOURCE_SCHEMA_COLUMNS)}.*.json"], + 
schema_object=f"{safe_name(SOURCE_SCHEMA_COLUMNS)}-schema.json", ) # [END howto_operator_create_external_table_multiple_types] @@ -120,7 +131,7 @@ def safe_name(s: str) -> str: configuration={ "query": { "query": f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}." - f"{safe_name(SOURCE_MULTIPLE_TYPES)}`", + f"{safe_name(SOURCE_SCHEMA_COLUMNS)}`", "useLegacySql": False, } }, @@ -149,14 +160,20 @@ def safe_name(s: str) -> str: }, "schema": { "fields": [ + {"name": "custkey", "type": "INT64"}, {"name": "name", "type": "STRING"}, - {"name": "post_abbr", "type": "STRING"}, + {"name": "address", "type": "STRING"}, + {"name": "nationkey", "type": "INT64"}, + {"name": "phone", "type": "STRING"}, + {"name": "acctbal", "type": "FLOAT64"}, + {"name": "mktsegment", "type": "STRING"}, + {"name": "comment", "type": "STRING"}, ] }, "externalDataConfiguration": { "sourceFormat": "NEWLINE_DELIMITED_JSON", "compression": "NONE", - "csvOptions": {"skipLeadingRows": 1}, + "sourceUris": [f"gs://{GCS_BUCKET}/{safe_name(SOURCE_CUSTOMER_TABLE)}.*.json"], }, }, source_objects=[f"{safe_name(SOURCE_CUSTOMER_TABLE)}.*.json"], @@ -179,17 +196,17 @@ def safe_name(s: str) -> str: # [START howto_operator_trino_to_gcs_csv] trino_to_gcs_csv = TrinoToGCSOperator( task_id="trino_to_gcs_csv", - sql=f"select * from {SOURCE_MULTIPLE_TYPES}", + sql=f"select * from {SOURCE_SCHEMA_COLUMNS}", bucket=GCS_BUCKET, - filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}.{{}}.csv", - schema_filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}-schema.json", + filename=f"{safe_name(SOURCE_SCHEMA_COLUMNS)}.{{}}.csv", + schema_filename=f"{safe_name(SOURCE_SCHEMA_COLUMNS)}-schema.json", export_format="csv", ) # [END howto_operator_trino_to_gcs_csv] ( # TEST SETUP - create_dataset + [create_dataset, create_bucket] # TEST BODY >> trino_to_gcs_basic >> trino_to_gcs_multiple_types @@ -200,7 +217,7 @@ def safe_name(s: str) -> str: >> read_data_from_gcs_multiple_types >> read_data_from_gcs_many_chunks # TEST TEARDOWN - >> delete_dataset + >> [delete_dataset, delete_bucket] ) from tests.system.utils.watcher import watcher diff --git a/tests/system/providers/google/cloud/gcs/resources/us-states.csv b/tests/system/providers/google/cloud/gcs/resources/us-states.csv new file mode 100644 index 0000000000000..52cc001884b5a --- /dev/null +++ b/tests/system/providers/google/cloud/gcs/resources/us-states.csv @@ -0,0 +1,51 @@ +name,post_abbr +Alabama,AL +Alaska,AK +Arizona,AZ +Arkansas,AR +California,CA +Colorado,CO +Connecticut,CT +Delaware,DE +Florida,FL +Georgia,GA +Hawaii,HI +Idaho,ID +Illinois,IL +Indiana,IN +Iowa,IA +Kansas,KS +Kentucky,KY +Louisiana,LA +Maine,ME +Maryland,MD +Massachusetts,MA +Michigan,MI +Minnesota,MN +Mississippi,MS +Missouri,MO +Montana,MT +Nebraska,NE +Nevada,NV +New Hampshire,NH +New Jersey,NJ +New Mexico,NM +New York,NY +North Carolina,NC +North Dakota,ND +Ohio,OH +Oklahoma,OK +Oregon,OR +Pennsylvania,PA +Rhode Island,RI +South Carolina,SC +South Dakota,SD +Tennessee,TN +Texas,TX +Utah,UT +Vermont,VT +Virginia,VA +Washington,WA +West Virginia,WV +Wisconsin,WI +Wyoming,WY diff --git a/tests/system/providers/google/cloud/ml_engine/example_mlengine.py b/tests/system/providers/google/cloud/ml_engine/example_mlengine.py index f61e35298443d..491ce66520d28 100644 --- a/tests/system/providers/google/cloud/ml_engine/example_mlengine.py +++ b/tests/system/providers/google/cloud/ml_engine/example_mlengine.py @@ -49,14 +49,14 @@ DAG_ID = "example_gcp_mlengine" PREDICT_FILE_NAME = "predict.json" -MODEL_NAME = f"example_mlengine_model_{ENV_ID}" 
-BUCKET_NAME = f"example_mlengine_bucket_{ENV_ID}" +MODEL_NAME = f"example_ml_model_{ENV_ID}" +BUCKET_NAME = f"example_ml_bucket_{ENV_ID}" BUCKET_PATH = f"gs://{BUCKET_NAME}" JOB_DIR = f"{BUCKET_PATH}/job-dir" SAVED_MODEL_PATH = f"{JOB_DIR}/" PREDICTION_INPUT = f"{BUCKET_PATH}/{PREDICT_FILE_NAME}" PREDICTION_OUTPUT = f"{BUCKET_PATH}/prediction_output/" -TRAINER_URI = "gs://system-tests-resources/example_gcp_mlengine/trainer-0.1.tar.gz" +TRAINER_URI = "gs://system-tests-resources/example_gcp_mlengine/trainer-0.2.tar.gz" TRAINER_PY_MODULE = "trainer.task" SUMMARY_TMP = f"{BUCKET_PATH}/tmp/" SUMMARY_STAGING = f"{BUCKET_PATH}/staging/" @@ -66,7 +66,7 @@ def generate_model_predict_input_data() -> list[int]: - return [i for i in range(0, 201, 10)] + return [1, 4, 9, 16, 25, 36] with models.DAG( @@ -104,7 +104,7 @@ def write_predict_file(path_to_file: str): project_id=PROJECT_ID, region="us-central1", job_id="training-job-{{ ts_nodash }}-{{ params.model_name }}", - runtime_version="1.15", + runtime_version="2.1", python_version="3.7", job_dir=JOB_DIR, package_uris=[TRAINER_URI], @@ -148,7 +148,7 @@ def write_predict_file(path_to_file: str): "name": "v1", "description": "First-version", "deployment_uri": JOB_DIR, - "runtime_version": "1.15", + "runtime_version": "2.1", "machineType": "mls1-c1-m2", "framework": "TENSORFLOW", "pythonVersion": "3.7", @@ -165,7 +165,7 @@ def write_predict_file(path_to_file: str): "name": "v2", "description": "Second version", "deployment_uri": JOB_DIR, - "runtime_version": "1.15", + "runtime_version": "2.1", "machineType": "mls1-c1-m2", "framework": "TENSORFLOW", "pythonVersion": "3.7", @@ -252,7 +252,7 @@ def get_metric_fn_and_keys(): """ def normalize_value(inst: dict): - val = float(inst["output_layer"][0]) + val = int(inst["output_layer"][0]) return tuple([val]) # returns a tuple. return normalize_value, ["val"] # key order must match. 
diff --git a/tests/system/providers/google/cloud/ml_engine/example_mlengine_async.py b/tests/system/providers/google/cloud/ml_engine/example_mlengine_async.py index b870754b0c16a..ef5dbee2cc35e 100644 --- a/tests/system/providers/google/cloud/ml_engine/example_mlengine_async.py +++ b/tests/system/providers/google/cloud/ml_engine/example_mlengine_async.py @@ -48,15 +48,15 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "async_example_gcp_mlengine" -PREDICT_FILE_NAME = "predict.json" -MODEL_NAME = f"example_mlengine_model_{ENV_ID}" -BUCKET_NAME = f"example_mlengine_bucket_{ENV_ID}" +PREDICT_FILE_NAME = "async_predict.json" +MODEL_NAME = f"example_async_ml_model_{ENV_ID}" +BUCKET_NAME = f"example_async_ml_bucket_{ENV_ID}" BUCKET_PATH = f"gs://{BUCKET_NAME}" JOB_DIR = f"{BUCKET_PATH}/job-dir" SAVED_MODEL_PATH = f"{JOB_DIR}/" PREDICTION_INPUT = f"{BUCKET_PATH}/{PREDICT_FILE_NAME}" PREDICTION_OUTPUT = f"{BUCKET_PATH}/prediction_output/" -TRAINER_URI = "gs://system-tests-resources/example_gcp_mlengine/trainer-0.1.tar.gz" +TRAINER_URI = "gs://system-tests-resources/example_gcp_mlengine/async-trainer-0.2.tar.gz" TRAINER_PY_MODULE = "trainer.task" SUMMARY_TMP = f"{BUCKET_PATH}/tmp/" SUMMARY_STAGING = f"{BUCKET_PATH}/staging/" @@ -66,7 +66,7 @@ def generate_model_predict_input_data() -> list[int]: - return [i for i in range(0, 201, 10)] + return [1, 4, 9, 16, 25, 36] with models.DAG( @@ -104,7 +104,7 @@ def write_predict_file(path_to_file: str): project_id=PROJECT_ID, region="us-central1", job_id="async_training-job-{{ ts_nodash }}-{{ params.model_name }}", - runtime_version="1.15", + runtime_version="2.1", python_version="3.7", job_dir=JOB_DIR, package_uris=[TRAINER_URI], @@ -149,7 +149,7 @@ def write_predict_file(path_to_file: str): "name": "v1", "description": "First-version", "deployment_uri": JOB_DIR, - "runtime_version": "1.15", + "runtime_version": "2.1", "machineType": "mls1-c1-m2", "framework": "TENSORFLOW", "pythonVersion": "3.7", @@ -166,7 +166,7 @@ def write_predict_file(path_to_file: str): "name": "v2", "description": "Second version", "deployment_uri": JOB_DIR, - "runtime_version": "1.15", + "runtime_version": "2.1", "machineType": "mls1-c1-m2", "framework": "TENSORFLOW", "pythonVersion": "3.7", @@ -202,7 +202,7 @@ def write_predict_file(path_to_file: str): prediction = MLEngineStartBatchPredictionJobOperator( task_id="prediction", project_id=PROJECT_ID, - job_id="prediction-{{ ts_nodash }}-{{ params.model_name }}", + job_id="async-prediction-{{ ts_nodash }}-{{ params.model_name }}", region="us-central1", model_name=MODEL_NAME, data_format="TEXT", @@ -253,7 +253,7 @@ def get_metric_fn_and_keys(): """ def normalize_value(inst: dict): - val = float(inst["output_layer"][0]) + val = int(inst["output_layer"][0]) return tuple([val]) # returns a tuple. return normalize_value, ["val"] # key order must match. 
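The async variant above feeds the batch-prediction job the same TEXT-format input file; only the names and values change. As a rough sketch of what the (not shown) write_predict_file helper might produce: TEXT-format batch prediction reads newline-delimited JSON, one instance per line. The "input" key and the instance shape are assumptions, since the trainer's serving signature is not part of the diff; only the literal values come from generate_model_predict_input_data.

```python
from __future__ import annotations

import json


def generate_model_predict_input_data() -> list[int]:
    # Same values as in the updated examples.
    return [1, 4, 9, 16, 25, 36]


def write_predict_file(path_to_file: str) -> None:
    # One JSON instance per line, as expected for data_format="TEXT".
    # "input" is a placeholder for the model's real input tensor name.
    with open(path_to_file, "w") as file:
        for value in generate_model_predict_input_data():
            file.write(json.dumps({"input": [value]}) + "\n")
```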
@@ -284,7 +284,7 @@ def validate_err_and_count(summary: dict) -> dict: prediction_path=PREDICTION_OUTPUT, metric_fn_and_keys=get_metric_fn_and_keys(), validate_fn=validate_err_and_count, - batch_prediction_job_id="evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}", + batch_prediction_job_id="async-evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}", project_id=PROJECT_ID, region="us-central1", dataflow_options={ diff --git a/tests/system/providers/google/cloud/natural_language/example_natural_language.py b/tests/system/providers/google/cloud/natural_language/example_natural_language.py index 34591dde40a3d..9c9cae26b86d6 100644 --- a/tests/system/providers/google/cloud/natural_language/example_natural_language.py +++ b/tests/system/providers/google/cloud/natural_language/example_natural_language.py @@ -23,7 +23,7 @@ import os from datetime import datetime -from google.cloud.language_v1.proto.language_service_pb2 import Document +from google.cloud.language_v1 import Document from airflow import models from airflow.operators.bash import BashOperator diff --git a/tests/system/providers/google/cloud/sql_to_sheets/example_sql_to_sheets.py b/tests/system/providers/google/cloud/sql_to_sheets/example_sql_to_sheets.py index 6a11b27b7b5a0..4ff49ec9aab74 100644 --- a/tests/system/providers/google/cloud/sql_to_sheets/example_sql_to_sheets.py +++ b/tests/system/providers/google/cloud/sql_to_sheets/example_sql_to_sheets.py @@ -15,6 +15,55 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +""" +Required environment variables: +``` +DB_CONNECTION = os.environ.get("DB_CONNECTION") +SPREADSHEET_ID = os.environ.get("SPREADSHEET_ID", "test-id") +``` + +First, you need a db instance that is accessible from the Airflow environment. +You can, for example, create a Cloud SQL instance and connect to it from +within breeze with Cloud SQL proxy: +https://cloud.google.com/sql/docs/postgres/connect-instance-auth-proxy + +# DB setup +Create db: +``` +CREATE DATABASE test_db; +``` + +Switch to db: +``` +\c test_db +``` + +Create a table and insert some rows: +``` +CREATE TABLE test_table (col1 INT, col2 INT); +INSERT INTO test_table VALUES (1,2), (3,4), (5,6), (7,8); +``` + +# Setup connections +db connection: +In the Airflow UI, set up one db connection, for example "postgres_default", +and make sure the "Test" at the bottom succeeds (an environment-variable alternative is sketched just below). + +google cloud connection: +We need additional scopes for this test: +https://www.googleapis.com/auth/spreadsheets, https://www.googleapis.com/auth/cloud-platform + +# Sheet +Finally, you need a Google Sheet you have access to; for testing you can +create a public sheet and get its ID. 
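As referenced in the setup notes above, the db connection does not have to be created through the UI. A minimal sketch of the environment-variable alternative, assuming a Postgres instance reachable from the Airflow environment (user, password, host, port and database name are placeholders):

```python
import os

# Airflow resolves connections defined as AIRFLOW_CONN_<CONN_ID> URIs.
# "postgres_default" matches the connection id suggested in the notes above.
os.environ["AIRFLOW_CONN_POSTGRES_DEFAULT"] = "postgres://airflow_user:airflow_pw@127.0.0.1:5432/test_db"

# The DAG then picks it up through the DB_CONNECTION environment variable:
# DB_CONNECTION=postgres_default
```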
+ +# Tear Down +You can delete the db with +``` +DROP DATABASE test_db; +``` +""" from __future__ import annotations import os @@ -24,10 +73,11 @@ from airflow.providers.google.suite.transfers.sql_to_sheets import SQLToGoogleSheetsOperator ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") -DAG_ID = "example_sql_to_sheets" +DB_CONNECTION = os.environ.get("DB_CONNECTION") +SPREADSHEET_ID = os.environ.get("SPREADSHEET_ID", "test-id") -SQL = "select 1 as my_col" -NEW_SPREADSHEET_ID = os.environ.get("NEW_SPREADSHEET_ID", "123") +DAG_ID = "example_sql_to_sheets" +SQL = "select col2 from test_table" with models.DAG( DAG_ID, @@ -40,8 +90,9 @@ upload_gcs_to_sheet = SQLToGoogleSheetsOperator( task_id="upload_sql_to_sheet", sql=SQL, - sql_conn_id="database_conn_id", - spreadsheet_id=NEW_SPREADSHEET_ID, + sql_conn_id=DB_CONNECTION, + database="test_db", + spreadsheet_id=SPREADSHEET_ID, ) # [END upload_sql_to_sheets] diff --git a/tests/system/providers/google/cloud/stackdriver/example_stackdriver.py b/tests/system/providers/google/cloud/stackdriver/example_stackdriver.py index db8e0054c7cc0..12c68a4afd392 100644 --- a/tests/system/providers/google/cloud/stackdriver/example_stackdriver.py +++ b/tests/system/providers/google/cloud/stackdriver/example_stackdriver.py @@ -112,15 +112,15 @@ TEST_NOTIFICATION_CHANNEL_1 = { "display_name": CHANNEL_1_NAME, "enabled": True, - "labels": {"auth_token": "top-secret", "channel_name": "#channel"}, - "type_": "slack", + "labels": {"topic": f"projects/{PROJECT_ID}/topics/notificationTopic"}, + "type": "pubsub", } TEST_NOTIFICATION_CHANNEL_2 = { "display_name": CHANNEL_2_NAME, "enabled": False, - "labels": {"auth_token": "top-secret", "channel_name": "#channel"}, - "type_": "slack", + "labels": {"topic": f"projects/{PROJECT_ID}/topics/notificationTopic2"}, + "type": "pubsub", } with models.DAG( @@ -139,7 +139,7 @@ # [START howto_operator_gcp_stackdriver_enable_notification_channel] enable_notification_channel = StackdriverEnableNotificationChannelsOperator( - task_id="enable-notification-channel", filter_='type="slack"' + task_id="enable-notification-channel", filter_='type="pubsub"' ) # [END howto_operator_gcp_stackdriver_enable_notification_channel] @@ -151,7 +151,7 @@ # [START howto_operator_gcp_stackdriver_list_notification_channel] list_notification_channel = StackdriverListNotificationChannelsOperator( - task_id="list-notification-channel", filter_='type="slack"' + task_id="list-notification-channel", filter_='type="pubsub"' ) # [END howto_operator_gcp_stackdriver_list_notification_channel] diff --git a/tests/system/providers/google/cloud/tasks/example_queue.py b/tests/system/providers/google/cloud/tasks/example_queue.py index b54a217be9142..6fd2647b02b4c 100644 --- a/tests/system/providers/google/cloud/tasks/example_queue.py +++ b/tests/system/providers/google/cloud/tasks/example_queue.py @@ -18,6 +18,9 @@ """ Example Airflow DAG that creates, gets, lists, updates, purges, pauses, resumes and deletes Queues in the Google Cloud Tasks service in the Google Cloud. + +Required setup: +- GCP_APP_ENGINE_LOCATION: GCP Project's App Engine location `gcloud app describe | grep locationId`. 
""" from __future__ import annotations @@ -47,7 +50,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "cloud_tasks_queue" -LOCATION = "europe-west2" +LOCATION = os.environ.get("GCP_APP_ENGINE_LOCATION", "europe-west2") QUEUE_ID = f"queue-{ENV_ID}-{DAG_ID.replace('_', '-')}" diff --git a/tests/system/providers/google/cloud/tasks/example_tasks.py b/tests/system/providers/google/cloud/tasks/example_tasks.py index 6916852a4b999..22eb6561a9ea0 100644 --- a/tests/system/providers/google/cloud/tasks/example_tasks.py +++ b/tests/system/providers/google/cloud/tasks/example_tasks.py @@ -18,6 +18,9 @@ """ Example Airflow DAG that creates and deletes Queues and creates, gets, lists, runs and deletes Tasks in the Google Cloud Tasks service in the Google Cloud. + +Required setup: +- GCP_APP_ENGINE_LOCATION: GCP Project's App Engine location `gcloud app describe | grep locationId`. """ from __future__ import annotations @@ -48,7 +51,7 @@ timestamp = timestamp_pb2.Timestamp() timestamp.FromDatetime(datetime.now() + timedelta(hours=12)) -LOCATION = "europe-west2" +LOCATION = os.environ.get("GCP_APP_ENGINE_LOCATION", "europe-west2") # queue cannot use recent names even if queue was removed QUEUE_ID = f"queue-{ENV_ID}-{DAG_ID.replace('_', '-')}" TASK_NAME = "task-to-run" @@ -127,6 +130,7 @@ def generate_random_string(): location=LOCATION, queue_name=QUEUE_ID + "{{ task_instance.xcom_pull(task_ids='random_string') }}", task_name=TASK_NAME + "{{ task_instance.xcom_pull(task_ids='random_string') }}", + retry=Retry(maximum=10.0), task_id="run_task", ) # [END run_task] diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py index 42b915f20f307..4ca7c42897584 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py @@ -56,9 +56,9 @@ CUSTOM_CONTAINER_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" DATA_SAMPLE_GCS_OBJECT_NAME = "vertex-ai/california_housing_train.csv" -CSV_FILE_LOCAL_PATH = "/custom-job/california_housing_train.csv" +CSV_FILE_LOCAL_PATH = "/custom-job-container/california_housing_train.csv" RESOURCES_PATH = Path(__file__).parent / "resources" -CSV_ZIP_FILE_LOCAL_PATH = str(RESOURCES_PATH / "California-housing.zip") +CSV_ZIP_FILE_LOCAL_PATH = str(RESOURCES_PATH / "California-housing-custom-container.zip") TABULAR_DATASET = lambda bucket_name: { "display_name": f"tabular-dataset-{ENV_ID}", @@ -96,7 +96,8 @@ ) unzip_file = BashOperator( task_id="unzip_csv_data_file", - bash_command=f"mkdir -p /custom-job/ && unzip {CSV_ZIP_FILE_LOCAL_PATH} -d /custom-job/", + bash_command=f"mkdir -p /custom-job-container/ && " + f"unzip {CSV_ZIP_FILE_LOCAL_PATH} -d /custom-job-container/", ) upload_files = LocalFilesystemToGCSOperator( task_id="upload_file_to_bucket", @@ -158,7 +159,7 @@ ) clear_folder = BashOperator( task_id="clear_folder", - bash_command="rm -r /custom-job/*", + bash_command="rm -r /custom-job-container/*", ) ( diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py index aaa289d9a0c49..2f1fb2ae6df86 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py @@ -56,7 +56,8 @@ CUSTOM_GCS_BUCKET_NAME = 
f"bucket_{DAG_ID}_{ENV_ID}" DATA_SAMPLE_GCS_OBJECT_NAME = "vertex-ai/california_housing_train.csv" -CSV_ZIP_FILE_LOCAL_PATH = str(Path(__file__).parent / "resources" / "California-housing.zip") +RESOURCES_PATH = Path(__file__).parent / "resources" +CSV_ZIP_FILE_LOCAL_PATH = str(RESOURCES_PATH / "California-housing-custom-job.zip") CSV_FILE_LOCAL_PATH = "/custom-job/california_housing_train.csv" TABULAR_DATASET = lambda bucket_name: { diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job_python_package.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job_python_package.py index c3a2a27c73ba1..4529f6af155ad 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job_python_package.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job_python_package.py @@ -57,9 +57,9 @@ DATA_SAMPLE_GCS_OBJECT_NAME = "vertex-ai/california_housing_train.csv" RESOURCES_PATH = Path(__file__).parent / "resources" -CSV_ZIP_FILE_LOCAL_PATH = str(RESOURCES_PATH / "California-housing.zip") -CSV_FILE_LOCAL_PATH = "/custom-job/california_housing_train.csv" -TAR_FILE_LOCAL_PATH = "/custom-job/custom_trainer_script-0.1.tar" +CSV_ZIP_FILE_LOCAL_PATH = str(RESOURCES_PATH / "California-housing-python-package.zip") +CSV_FILE_LOCAL_PATH = "/custom-job-python/california_housing_train.csv" +TAR_FILE_LOCAL_PATH = "/custom-job-python/custom_trainer_script-0.1.tar" FILES_TO_UPLOAD = [ CSV_FILE_LOCAL_PATH, TAR_FILE_LOCAL_PATH, @@ -103,7 +103,7 @@ ) unzip_file = BashOperator( task_id="unzip_csv_data_file", - bash_command=f"mkdir -p /custom-job && unzip {CSV_ZIP_FILE_LOCAL_PATH} -d /custom-job/", + bash_command=f"mkdir -p /custom-job-python && unzip {CSV_ZIP_FILE_LOCAL_PATH} -d /custom-job-python/", ) upload_files = LocalFilesystemToGCSOperator( task_id="upload_file_to_bucket", @@ -166,7 +166,7 @@ ) clear_folder = BashOperator( task_id="clear_folder", - bash_command="rm -r /custom-job/*", + bash_command="rm -r /custom-job-python/*", ) ( diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_dataset.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_dataset.py index 270103b7f1e23..71992673a8898 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_dataset.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_dataset.py @@ -114,7 +114,9 @@ "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml" ), "gcs_source": { - "uris": ["gs://ucaip-test-us-central1/dataset/salads_oid_ml_use_public_unassigned.jsonl"] + "uris": [ + "gs://system-tests-resources/vertex-ai/dataset/salads_oid_ml_use_public_unassigned.jsonl" + ] }, }, ] diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_endpoint.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_endpoint.py index da18746cf407a..5d40393fd0175 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_endpoint.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_endpoint.py @@ -28,7 +28,6 @@ from datetime import datetime from pathlib import Path -from google.cloud import aiplatform from google.cloud.aiplatform import schema from google.protobuf.struct_pb2 import Value @@ -63,7 +62,8 @@ DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" DATA_SAMPLE_GCS_OBJECT_NAME = "vertex-ai/image-dataset.csv" -IMAGE_ZIP_CSV_FILE_LOCAL_PATH = str(Path(__file__).parent / "resources" / 
"image-dataset.csv.zip") +RESOURCES_PATH = Path(__file__).parent / "resources" +IMAGE_ZIP_CSV_FILE_LOCAL_PATH = str(RESOURCES_PATH / "image-dataset.csv.zip") IMAGE_CSV_FILE_LOCAL_PATH = "/endpoint/image-dataset.csv" IMAGE_DATASET = { @@ -143,12 +143,7 @@ # format: 'projects/{project}/locations/{location}/models/{model}' "model": "{{ti.xcom_pull('auto_ml_image_task')['name']}}", "display_name": f"temp_endpoint_test_{ENV_ID}", - "dedicated_resources": { - "machine_spec": { - "machine_type": "n1-standard-2", - "accelerator_type": aiplatform.gapic.AcceleratorType.NVIDIA_TESLA_K80, - "accelerator_count": 1, - }, + "automatic_resources": { "min_replica_count": 1, "max_replica_count": 1, }, diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py index 98e0c8d20de4d..fc690d36b7bc0 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py @@ -65,7 +65,7 @@ DATA_SAMPLE_GCS_OBJECT_NAME = "vertex-ai/california_housing_train.csv" CSV_FILE_LOCAL_PATH = "/model_service/california_housing_train.csv" RESOURCES_PATH = Path(__file__).parent / "resources" -CSV_ZIP_FILE_LOCAL_PATH = str(RESOURCES_PATH / "California-housing.zip") +CSV_ZIP_FILE_LOCAL_PATH = str(RESOURCES_PATH / "California-housing-ai-model.zip") TABULAR_DATASET = { "display_name": f"tabular-dataset-{ENV_ID}", diff --git a/tests/system/providers/google/cloud/vertex_ai/resources/California-housing.zip b/tests/system/providers/google/cloud/vertex_ai/resources/California-housing-ai-model.zip similarity index 100% rename from tests/system/providers/google/cloud/vertex_ai/resources/California-housing.zip rename to tests/system/providers/google/cloud/vertex_ai/resources/California-housing-ai-model.zip diff --git a/tests/system/providers/google/cloud/vertex_ai/resources/California-housing-custom-container.zip b/tests/system/providers/google/cloud/vertex_ai/resources/California-housing-custom-container.zip new file mode 100644 index 0000000000000..1ac6fc83a4c08 Binary files /dev/null and b/tests/system/providers/google/cloud/vertex_ai/resources/California-housing-custom-container.zip differ diff --git a/tests/system/providers/google/cloud/vertex_ai/resources/California-housing-custom-job.zip b/tests/system/providers/google/cloud/vertex_ai/resources/California-housing-custom-job.zip new file mode 100644 index 0000000000000..1ac6fc83a4c08 Binary files /dev/null and b/tests/system/providers/google/cloud/vertex_ai/resources/California-housing-custom-job.zip differ diff --git a/tests/system/providers/google/cloud/vertex_ai/resources/California-housing-python-package.zip b/tests/system/providers/google/cloud/vertex_ai/resources/California-housing-python-package.zip new file mode 100644 index 0000000000000..1ac6fc83a4c08 Binary files /dev/null and b/tests/system/providers/google/cloud/vertex_ai/resources/California-housing-python-package.zip differ diff --git a/tests/system/providers/google/cloud/vision/example_vision_annotate_image.py b/tests/system/providers/google/cloud/vision/example_vision_annotate_image.py index b65199098bb2b..61d9fb2eca49c 100644 --- a/tests/system/providers/google/cloud/vision/example_vision_annotate_image.py +++ b/tests/system/providers/google/cloud/vision/example_vision_annotate_image.py @@ -40,7 +40,7 @@ # [END howto_operator_vision_retry_import] # [START howto_operator_vision_enums_import] 
-from google.cloud.vision import enums # isort:skip +from google.cloud.vision_v1 import Feature # isort:skip # [END howto_operator_vision_enums_import] @@ -59,7 +59,7 @@ # [START howto_operator_vision_annotate_image_request] annotate_image_request = { "image": {"source": {"image_uri": GCP_VISION_ANNOTATE_IMAGE_URL}}, - "features": [{"type": enums.Feature.Type.LOGO_DETECTION}], + "features": [{"type_": Feature.Type.LOGO_DETECTION}], } # [END howto_operator_vision_annotate_image_request] @@ -89,7 +89,7 @@ copy_single_file = GCSToGCSOperator( task_id="copy_single_gcs_file", source_bucket=BUCKET_NAME_SRC, - source_object=PATH_SRC, + source_object=[PATH_SRC], destination_bucket=BUCKET_NAME, destination_object=FILE_NAME, ) diff --git a/tests/system/providers/google/cloud/vision/example_vision_autogenerated.py b/tests/system/providers/google/cloud/vision/example_vision_autogenerated.py index 12820cfeb218b..e16f8dacbed37 100644 --- a/tests/system/providers/google/cloud/vision/example_vision_autogenerated.py +++ b/tests/system/providers/google/cloud/vision/example_vision_autogenerated.py @@ -58,7 +58,7 @@ # [END howto_operator_vision_reference_image_import] # [START howto_operator_vision_enums_import] -from google.cloud.vision import enums # isort:skip +from google.cloud.vision_v1 import Feature # isort:skip # [END howto_operator_vision_enums_import] @@ -93,7 +93,7 @@ # [START howto_operator_vision_annotate_image_request] annotate_image_request = { "image": {"source": {"image_uri": VISION_IMAGE_URL}}, - "features": [{"type": enums.Feature.Type.LOGO_DETECTION}], + "features": [{"type_": Feature.Type.LOGO_DETECTION}], } # [END howto_operator_vision_annotate_image_request] @@ -121,7 +121,7 @@ copy_single_file = GCSToGCSOperator( task_id="copy_single_gcs_file", source_bucket=BUCKET_NAME_SRC, - source_object=PATH_SRC, + source_object=[PATH_SRC], destination_bucket=BUCKET_NAME, destination_object=FILE_NAME, ) diff --git a/tests/system/providers/google/cloud/vision/example_vision_explicit.py b/tests/system/providers/google/cloud/vision/example_vision_explicit.py index d386b7933d42f..3ac729457fa7e 100644 --- a/tests/system/providers/google/cloud/vision/example_vision_explicit.py +++ b/tests/system/providers/google/cloud/vision/example_vision_explicit.py @@ -107,7 +107,7 @@ copy_single_file = GCSToGCSOperator( task_id="copy_single_gcs_file", source_bucket=BUCKET_NAME_SRC, - source_object=PATH_SRC, + source_object=[PATH_SRC], destination_bucket=BUCKET_NAME, destination_object=FILE_NAME, ) @@ -257,6 +257,8 @@ ) chain( + create_bucket, + copy_single_file, product_set_create_2, product_set_get_2, product_set_update_2,