From d4d4ca38cc3a65f3d61bc6baa6a5eef4de740284 Mon Sep 17 00:00:00 2001 From: Shahar Epstein Date: Mon, 11 Mar 2024 21:13:02 +0200 Subject: [PATCH] Update tests/providers/google/cloud/operators/test_vertex_ai.py Co-authored-by: Andrey Anshin Rename `DeleteCustomTrainingJobOperator`'s fields' name to comply with templated fields validation --- .pre-commit-config.yaml | 1 - .../cloud/operators/vertex_ai/custom_job.py | 28 ++++++++++++++-- .../google/cloud/operators/test_vertex_ai.py | 32 ++++++++++++++++++- 3 files changed, 56 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ba2707de50fdb..68eea393f1221 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -335,7 +335,6 @@ repos: exclude: | (?x)^( ^airflow\/providers\/google\/cloud\/operators\/mlengine.py$| - ^airflow\/providers\/google\/cloud\/operators\/vertex_ai\/custom_job.py$| ^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service.py$| ^airflow\/providers\/apache\/spark\/operators\/spark_submit.py\.py$| ^airflow\/providers\/google\/cloud\/operators\/vertex_ai\/auto_ml\.py$| diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index 71c4526e9249b..dcd5acbcad271 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -20,12 +20,14 @@ from typing import TYPE_CHECKING, Sequence +from deprecated import deprecated from google.api_core.exceptions import NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.cloud.aiplatform.models import Model from google.cloud.aiplatform_v1.types.dataset import Dataset from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook from airflow.providers.google.cloud.links.vertex_ai import ( VertexAIModelLink, @@ -1328,7 +1330,7 @@ class DeleteCustomTrainingJobOperator(GoogleCloudBaseOperator): account from the list granting this role to the originating account (templated). """ - template_fields = ("training_pipeline", "custom_job", "region", "project_id", "impersonation_chain") + template_fields = ("training_pipeline_id", "custom_job_id", "region", "project_id", "impersonation_chain") def __init__( self, @@ -1345,8 +1347,8 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - self.training_pipeline = training_pipeline_id - self.custom_job = custom_job_id + self.training_pipeline_id = training_pipeline_id + self.custom_job_id = custom_job_id self.region = region self.project_id = project_id self.retry = retry @@ -1355,6 +1357,26 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + @property + @deprecated( + reason="`training_pipeline` is deprecated and will be removed in the future. " + "Please use `training_pipeline_id` instead.", + category=AirflowProviderDeprecationWarning, + ) + def training_pipeline(self): + """Alias for ``training_pipeline_id``, used for compatibility (deprecated).""" + return self.training_pipeline_id + + @property + @deprecated( + reason="`custom_job` is deprecated and will be removed in the future. " + "Please use `custom_job_id` instead.", + category=AirflowProviderDeprecationWarning, + ) + def custom_job(self): + """Alias for ``custom_job_id``, used for compatibility (deprecated).""" + return self.custom_job_id + def execute(self, context: Context): hook = CustomJobHook( gcp_conn_id=self.gcp_conn_id, diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index 57864a0a0241f..a092cebafb80a 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -23,7 +23,7 @@ from google.api_core.gapic_v1.method import DEFAULT from google.api_core.retry import Retry -from airflow.exceptions import AirflowException, TaskDeferred +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import ( CreateAutoMLForecastingTrainingJobOperator, CreateAutoMLImageTrainingJobOperator, @@ -84,6 +84,7 @@ ListPipelineJobOperator, RunPipelineJobOperator, ) +from airflow.utils import timezone VERTEX_AI_PATH = "airflow.providers.google.cloud.operators.vertex_ai.{}" VERTEX_AI_LINKS_PATH = "airflow.providers.google.cloud.links.vertex_ai.{}" @@ -477,6 +478,35 @@ def test_execute(self, mock_hook): metadata=METADATA, ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + DeleteCustomTrainingJobOperator, + # Templated fields + training_pipeline_id="{{ 'training-pipeline-id' }}", + custom_job_id="{{ 'custom_job_id' }}", + region="{{ 'region' }}", + project_id="{{ 'project_id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: DeleteCustomTrainingJobOperator = ti.task + assert task.training_pipeline_id == "training-pipeline-id" + assert task.custom_job_id == "custom_job_id" + assert task.region == "region" + assert task.project_id == "project_id" + assert task.impersonation_chain == "impersonation-chain" + + # Deprecated aliases + with pytest.warns(AirflowProviderDeprecationWarning): + assert task.training_pipeline == "training-pipeline-id" + with pytest.warns(AirflowProviderDeprecationWarning): + assert task.custom_job == "custom_job_id" + class TestVertexAIListCustomTrainingJobOperator: @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))