From bc5875ed586055d28744780b2f4168d7553765da Mon Sep 17 00:00:00 2001
From: Maksim
Date: Fri, 7 Jun 2024 06:51:46 -0700
Subject: [PATCH] Implement CloudComposerDAGRunSensor (#40088)

---
 .../google/cloud/sensors/cloud_composer.py    | 173 +++++++++++++++++-
 .../google/cloud/triggers/cloud_composer.py   | 115 ++++++++++++
 .../operators/cloud/cloud_composer.rst        |  20 ++
 .../cloud/sensors/test_cloud_composer.py      |  63 ++++++-
 .../cloud/triggers/test_cloud_composer.py     |  61 +++++-
 .../cloud/composer/example_cloud_composer.py  |  25 +++
 6 files changed, 447 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/google/cloud/sensors/cloud_composer.py b/airflow/providers/google/cloud/sensors/cloud_composer.py
index 22d16e8f33ad2..0301466eac0ae 100644
--- a/airflow/providers/google/cloud/sensors/cloud_composer.py
+++ b/airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -19,13 +19,24 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Sequence
+import json
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, Any, Iterable, Sequence
 
+from dateutil import parser
 from deprecated import deprecated
+from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
-from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerExecutionTrigger
+from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook
+from airflow.providers.google.cloud.triggers.cloud_composer import (
+    CloudComposerDAGRunTrigger,
+    CloudComposerExecutionTrigger,
+)
+from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 from airflow.sensors.base import BaseSensorOperator
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -117,3 +128,161 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None
         if self.soft_fail:
             raise AirflowSkipException(message)
         raise AirflowException(message)
+
+
+class CloudComposerDAGRunSensor(BaseSensorOperator):
+    """
+    Check if a DAG run has completed.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param environment_id: The name of the Composer environment.
+    :param composer_dag_id: The ID of the DAG to check.
+    :param allowed_states: Iterable of allowed states, default is ``['success']``.
+    :param execution_range: The time range in which DAG run states are checked. The sensor considers
+        only DAG runs started within this range. For the previous day, use a positive
+        ``datetime.timedelta(days=1)``; for the next day, use a negative
+        ``datetime.timedelta(days=-1)``. For a specific window, pass a list of two datetimes, e.g.
+        ``[datetime(2024, 3, 22, 11, 0, 0), datetime(2024, 3, 22, 12, 0, 0)]``. A single-element list,
+        e.g. ``[datetime(2024, 3, 22, 0, 0, 0)]``, checks states from that moment in the past up to
+        the current logical date. Defaults to ``datetime.timedelta(days=1)``.
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param poll_interval: Optional. Polling period, in seconds, used to check for the result when
+        the sensor runs in deferrable mode.
+    :param deferrable: Run sensor in deferrable mode.
+    """
+
+    template_fields = (
+        "project_id",
+        "region",
+        "environment_id",
+        "composer_dag_id",
+        "impersonation_chain",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        composer_dag_id: str,
+        allowed_states: Iterable[str] | None = None,
+        execution_range: timedelta | list[datetime] | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poll_interval: int = 10,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.environment_id = environment_id
+        self.composer_dag_id = composer_dag_id
+        self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS.value]
+        self.execution_range = execution_range
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
+        self.poll_interval = poll_interval
+
+    def _get_execution_dates(self, context) -> tuple[datetime, datetime]:
+        if isinstance(self.execution_range, timedelta):
+            if self.execution_range < timedelta(0):
+                return context["logical_date"], context["logical_date"] - self.execution_range
+            else:
+                return context["logical_date"] - self.execution_range, context["logical_date"]
+        elif isinstance(self.execution_range, list) and len(self.execution_range) > 0:
+            return (
+                self.execution_range[0],
+                self.execution_range[1] if len(self.execution_range) > 1 else context["logical_date"],
+            )
+        else:
+            return context["logical_date"] - timedelta(1), context["logical_date"]
+
+    def poke(self, context: Context) -> bool:
+        start_date, end_date = self._get_execution_dates(context)
+
+        if datetime.now(end_date.tzinfo) < end_date:
+            return False
+
+        dag_runs = self._pull_dag_runs()
+
+        self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
+        allowed_states_status = self._check_dag_runs_states(
+            dag_runs=dag_runs,
+            start_date=start_date,
+            end_date=end_date,
+        )
+
+        return allowed_states_status
+
+    def _pull_dag_runs(self) -> list[dict]:
+        """Pull the list of DAG runs."""
+        hook = CloudComposerHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        dag_runs_cmd = hook.execute_airflow_command(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+            command="dags",
+            subcommand="list-runs",
+            parameters=["-d", self.composer_dag_id, "-o", "json"],
+        )
+        cmd_result = hook.wait_command_execution_result(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+            execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
+        )
+        dag_runs = json.loads(cmd_result["output"][0]["content"])
+        return dag_runs
+
+    def _check_dag_runs_states(
+        self,
+        dag_runs: list[dict],
+        start_date: datetime,
+        end_date: datetime,
+    ) -> bool:
+        for dag_run in dag_runs:
+            if (
+                start_date.timestamp()
+                < parser.parse(dag_run["execution_date"]).timestamp()
+                < end_date.timestamp()
+            ) and dag_run["state"] not in self.allowed_states:
+                return False
+        return True
+
+    def execute(self, context: Context) -> None:
+        if self.deferrable:
+            start_date, end_date = self._get_execution_dates(context)
+            self.defer(
+                trigger=CloudComposerDAGRunTrigger(
+                    project_id=self.project_id,
+                    region=self.region,
+                    environment_id=self.environment_id,
+                    composer_dag_id=self.composer_dag_id,
+                    start_date=start_date,
+                    end_date=end_date,
+                    allowed_states=self.allowed_states,
+                    gcp_conn_id=self.gcp_conn_id,
+                    impersonation_chain=self.impersonation_chain,
+                    poll_interval=self.poll_interval,
+                ),
+                method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
+            )
+        super().execute(context)
+
+    def execute_complete(self, context: Context, event: dict):
+        if event and event["status"] == "error":
+            raise AirflowException(event["message"])
+        self.log.info("DAG %s has executed successfully.", self.composer_dag_id)
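
For illustration, a minimal sketch (not part of the patch; the logical_date value is a made-up example) of the windows that _get_execution_dates produces for each documented form of execution_range. Note that poke() also returns False until end_date has passed.

    from datetime import datetime, timedelta, timezone

    logical_date = datetime(2024, 3, 23, tzinfo=timezone.utc)  # hypothetical context["logical_date"]

    # execution_range=timedelta(days=1): look back one day from the logical date.
    window = (logical_date - timedelta(days=1), logical_date)

    # execution_range=timedelta(days=-1): look one day ahead of the logical date.
    window = (logical_date, logical_date - timedelta(days=-1))

    # execution_range=[start, end]: the window is taken verbatim.
    window = (datetime(2024, 3, 22, 11, 0, 0), datetime(2024, 3, 22, 12, 0, 0))

    # execution_range=[start]: from that datetime up to the logical date.
    window = (datetime(2024, 3, 22, 0, 0, 0), logical_date)

    # execution_range=None: defaults to the one-day look-back.
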
diff --git a/airflow/providers/google/cloud/triggers/cloud_composer.py b/airflow/providers/google/cloud/triggers/cloud_composer.py
index ac5a00c60f4a1..2334d038e62f7 100644
--- a/airflow/providers/google/cloud/triggers/cloud_composer.py
+++ b/airflow/providers/google/cloud/triggers/cloud_composer.py
@@ -19,8 +19,13 @@
 from __future__ import annotations
 
 import asyncio
+import json
+from datetime import datetime
 from typing import Any, Sequence
 
+from dateutil import parser
+from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse
+
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -146,3 +151,113 @@ async def run(self):
                 }
             )
             return
+
+
+class CloudComposerDAGRunTrigger(BaseTrigger):
+    """The trigger that waits for a DAG run to complete."""
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        composer_dag_id: str,
+        start_date: datetime,
+        end_date: datetime,
+        allowed_states: list[str],
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        poll_interval: int = 10,
+    ):
+        super().__init__()
+        self.project_id = project_id
+        self.region = region
+        self.environment_id = environment_id
+        self.composer_dag_id = composer_dag_id
+        self.start_date = start_date
+        self.end_date = end_date
+        self.allowed_states = allowed_states
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.poll_interval = poll_interval
+
+        self.gcp_hook = CloudComposerAsyncHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerDAGRunTrigger",
+            {
+                "project_id": self.project_id,
+                "region": self.region,
+                "environment_id": self.environment_id,
+                "composer_dag_id": self.composer_dag_id,
+                "start_date": self.start_date,
+                "end_date": self.end_date,
+                "allowed_states": self.allowed_states,
+                "gcp_conn_id": self.gcp_conn_id,
+                "impersonation_chain": self.impersonation_chain,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def _pull_dag_runs(self) -> list[dict]:
+        """Pull the list of DAG runs."""
+        dag_runs_cmd = await self.gcp_hook.execute_airflow_command(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+            command="dags",
+            subcommand="list-runs",
+            parameters=["-d", self.composer_dag_id, "-o", "json"],
+        )
+        cmd_result = await self.gcp_hook.wait_command_execution_result(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+            execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
+        )
+        dag_runs = json.loads(cmd_result["output"][0]["content"])
+        return dag_runs
+
+    def _check_dag_runs_states(
+        self,
+        dag_runs: list[dict],
+        start_date: datetime,
+        end_date: datetime,
+    ) -> bool:
+        for dag_run in dag_runs:
+            if (
+                start_date.timestamp()
+                < parser.parse(dag_run["execution_date"]).timestamp()
+                < end_date.timestamp()
+            ) and dag_run["state"] not in self.allowed_states:
+                return False
+        return True
+
+    async def run(self):
+        try:
+            while True:
+                if datetime.now(self.end_date.tzinfo).timestamp() > self.end_date.timestamp():
+                    dag_runs = await self._pull_dag_runs()
+
+                    self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
+                    if self._check_dag_runs_states(
+                        dag_runs=dag_runs,
+                        start_date=self.start_date,
+                        end_date=self.end_date,
+                    ):
+                        yield TriggerEvent({"status": "success"})
+                        return
+                self.log.info("Sleeping for %s seconds.", self.poll_interval)
+                await asyncio.sleep(self.poll_interval)
+        except AirflowException as ex:
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": str(ex),
+                }
+            )
+            return
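
Both _pull_dag_runs implementations above parse the first output line of `dags list-runs -d <dag_id> -o json` as a JSON array. A sketch of the payload they expect, mirroring the fixture used in the sensor tests below (field values are illustrative); only "execution_date" and "state" are consulted by _check_dag_runs_states:

    import json

    content = json.dumps(
        [
            {
                "dag_id": "airflow_monitoring",
                "run_id": "scheduled__2024-05-22T11:10:00+00:00",
                "state": "success",
                "execution_date": "2024-05-22T11:10:00+00:00",
                "start_date": "2024-05-22T11:20:01.531988+00:00",
                "end_date": "2024-05-22T11:20:11.997479+00:00",
            }
        ]
    )
    cmd_result = {"output": [{"line_number": 1, "content": content}]}
    dag_runs = json.loads(cmd_result["output"][0]["content"])
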
diff --git a/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst b/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
index cdb9cb2931325..f8f00fbe6c54a 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
@@ -177,3 +177,23 @@ or you can define the same operator in the deferrable mode:
     :dedent: 4
     :start-after: [START howto_operator_run_airflow_cli_command_deferrable_mode]
     :end-before: [END howto_operator_run_airflow_cli_command_deferrable_mode]
+
+Check if a DAG run has completed
+--------------------------------
+
+To check if a DAG run has completed in your environment, use the sensor
+:class:`~airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerDAGRunSensor`:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/composer/example_cloud_composer.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_dag_run]
+    :end-before: [END howto_sensor_dag_run]
+
+or you can define the same sensor in the deferrable mode:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/composer/example_cloud_composer.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_dag_run_deferrable_mode]
+    :end-before: [END howto_sensor_dag_run_deferrable_mode]
diff --git a/tests/providers/google/cloud/sensors/test_cloud_composer.py b/tests/providers/google/cloud/sensors/test_cloud_composer.py
index 5241ff551e634..c22eb90fdeaa4 100644
--- a/tests/providers/google/cloud/sensors/test_cloud_composer.py
+++ b/tests/providers/google/cloud/sensors/test_cloud_composer.py
@@ -17,17 +17,42 @@
 
 from __future__ import annotations
 
+import json
+from datetime import datetime
 from unittest import mock
 
 import pytest
 
 from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred
-from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerEnvironmentSensor
-from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerExecutionTrigger
+from airflow.providers.google.cloud.sensors.cloud_composer import (
+    CloudComposerDAGRunSensor,
+    CloudComposerEnvironmentSensor,
+)
+from airflow.providers.google.cloud.triggers.cloud_composer import (
+    CloudComposerExecutionTrigger,
+)
 
 TEST_PROJECT_ID = "test_project_id"
 TEST_OPERATION_NAME = "test_operation_name"
 TEST_REGION = "region"
+TEST_ENVIRONMENT_ID = "test_env_id"
+TEST_JSON_RESULT = lambda state: json.dumps(
+    [
+        {
+            "dag_id": "test_dag_id",
+            "run_id": "scheduled__2024-05-22T11:10:00+00:00",
+            "state": state,
+            "execution_date": "2024-05-22T11:10:00+00:00",
+            "start_date": "2024-05-22T11:20:01.531988+00:00",
+            "end_date": "2024-05-22T11:20:11.997479+00:00",
+        }
+    ]
+)
+TEST_EXEC_RESULT = lambda state: {
+    "output": [{"line_number": 1, "content": TEST_JSON_RESULT(state)}],
+    "output_end": True,
+    "exit_info": {"exit_code": 0, "error": ""},
+}
 
 
 class TestCloudComposerEnvironmentSensor:
@@ -76,3 +101,37 @@ def test_cloud_composer_existence_sensor_async_execute_complete(self):
             task.execute_complete(
                 context={}, event={"operation_done": True, "operation_name": TEST_OPERATION_NAME}
             )
+
+
+class TestCloudComposerDAGRunSensor:
+    @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
+    @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+    def test_wait_ready(self, mock_hook, to_dict_mock):
+        mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT("success")
+
+        task = CloudComposerDAGRunSensor(
+            task_id="task-id",
+            project_id=TEST_PROJECT_ID,
+            region=TEST_REGION,
+            environment_id=TEST_ENVIRONMENT_ID,
+            composer_dag_id="test_dag_id",
+            allowed_states=["success"],
+        )
+
+        assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})
+
+    @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
+    @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+    def test_wait_not_ready(self, mock_hook, to_dict_mock):
+        mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT("running")
+
+        task = CloudComposerDAGRunSensor(
+            task_id="task-id",
+            project_id=TEST_PROJECT_ID,
+            region=TEST_REGION,
+            environment_id=TEST_ENVIRONMENT_ID,
+            composer_dag_id="test_dag_id",
+            allowed_states=["success"],
+        )
+
+        assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})
diff --git a/tests/providers/google/cloud/triggers/test_cloud_composer.py b/tests/providers/google/cloud/triggers/test_cloud_composer.py
index 99daaf83bdc26..00d109ed975a1 100644
--- a/tests/providers/google/cloud/triggers/test_cloud_composer.py
+++ b/tests/providers/google/cloud/triggers/test_cloud_composer.py
@@ -17,12 +17,16 @@
 
 from __future__ import annotations
 
+from datetime import datetime
 from unittest import mock
 
 import pytest
 
 from airflow.models import Connection
-from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerAirflowCLICommandTrigger
+from airflow.providers.google.cloud.triggers.cloud_composer import (
+    CloudComposerAirflowCLICommandTrigger,
+    CloudComposerDAGRunTrigger,
+)
 from airflow.triggers.base import TriggerEvent
 
 TEST_PROJECT_ID = "test-project-id"
@@ -34,6 +38,10 @@
     "pod_namespace": "test_namespace",
     "error": "test_error",
 }
+TEST_COMPOSER_DAG_ID = "test_dag_id"
+TEST_START_DATE = datetime(2024, 3, 22, 11, 0, 0)
+TEST_END_DATE = datetime(2024, 3, 22, 12, 0, 0)
+TEST_STATES = ["success"]
 TEST_GCP_CONN_ID = "test_gcp_conn_id"
 TEST_POLL_INTERVAL = 10
 TEST_IMPERSONATION_CHAIN = "test_impersonation_chain"
@@ -49,7 +57,7 @@
     "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
     return_value=Connection(conn_id="test_conn"),
 )
-def trigger(mock_conn):
+def cli_command_trigger(mock_conn):
     return CloudComposerAirflowCLICommandTrigger(
         project_id=TEST_PROJECT_ID,
         region=TEST_LOCATION,
@@ -61,9 +69,29 @@ def trigger(mock_conn):
     )
 
 
+@pytest.fixture
+@mock.patch(
+    "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
+    return_value=Connection(conn_id="test_conn"),
+)
+def dag_run_trigger(mock_conn):
+    return CloudComposerDAGRunTrigger(
+        project_id=TEST_PROJECT_ID,
+        region=TEST_LOCATION,
+        environment_id=TEST_ENVIRONMENT_ID,
+        composer_dag_id=TEST_COMPOSER_DAG_ID,
+        start_date=TEST_START_DATE,
+        end_date=TEST_END_DATE,
+        allowed_states=TEST_STATES,
+        gcp_conn_id=TEST_GCP_CONN_ID,
+        impersonation_chain=TEST_IMPERSONATION_CHAIN,
+        poll_interval=TEST_POLL_INTERVAL,
+    )
+
+
 class TestCloudComposerAirflowCLICommandTrigger:
-    def test_serialize(self, trigger):
-        actual_data = trigger.serialize()
+    def test_serialize(self, cli_command_trigger):
+        actual_data = cli_command_trigger.serialize()
         expected_data = (
             "airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerAirflowCLICommandTrigger",
             {
@@ -82,7 +110,7 @@ def test_serialize(self, trigger):
     @mock.patch(
         "airflow.providers.google.cloud.hooks.cloud_composer.CloudComposerAsyncHook.wait_command_execution_result"
     )
-    async def test_run(self, mock_exec_result, trigger):
+    async def test_run(self, mock_exec_result, cli_command_trigger):
         mock_exec_result.return_value = TEST_EXEC_RESULT
 
         expected_event = TriggerEvent(
@@ -91,6 +119,27 @@ async def test_run(self, mock_exec_result, trigger):
                 "result": TEST_EXEC_RESULT,
             }
         )
-        actual_event = await trigger.run().asend(None)
+        actual_event = await cli_command_trigger.run().asend(None)
 
         assert actual_event == expected_event
+
+
+class TestCloudComposerDAGRunTrigger:
+    def test_serialize(self, dag_run_trigger):
+        actual_data = dag_run_trigger.serialize()
+        expected_data = (
+            "airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerDAGRunTrigger",
+            {
+                "project_id": TEST_PROJECT_ID,
+                "region": TEST_LOCATION,
+                "environment_id": TEST_ENVIRONMENT_ID,
+                "composer_dag_id": TEST_COMPOSER_DAG_ID,
+                "start_date": TEST_START_DATE,
+                "end_date": TEST_END_DATE,
+                "allowed_states": TEST_STATES,
+                "gcp_conn_id": TEST_GCP_CONN_ID,
+                "impersonation_chain": TEST_IMPERSONATION_CHAIN,
+                "poll_interval": TEST_POLL_INTERVAL,
+            },
+        )
+        assert actual_data == expected_data
diff --git a/tests/system/providers/google/cloud/composer/example_cloud_composer.py b/tests/system/providers/google/cloud/composer/example_cloud_composer.py
index fe60c56ddf812..52404fa375394 100644
--- a/tests/system/providers/google/cloud/composer/example_cloud_composer.py
+++ b/tests/system/providers/google/cloud/composer/example_cloud_composer.py
@@ -31,6 +31,7 @@
     CloudComposerRunAirflowCLICommandOperator,
     CloudComposerUpdateEnvironmentOperator,
 )
+from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerDAGRunSensor
 from airflow.utils.trigger_rule import TriggerRule
 
 ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
@@ -158,6 +159,29 @@
     )
     # [END howto_operator_run_airflow_cli_command_deferrable_mode]
 
+    # [START howto_sensor_dag_run]
+    dag_run_sensor = CloudComposerDAGRunSensor(
+        task_id="dag_run_sensor",
+        project_id=PROJECT_ID,
+        region=REGION,
+        environment_id=ENVIRONMENT_ID,
+        composer_dag_id="airflow_monitoring",
+        allowed_states=["success"],
+    )
+    # [END howto_sensor_dag_run]
+
+    # [START howto_sensor_dag_run_deferrable_mode]
+    defer_dag_run_sensor = CloudComposerDAGRunSensor(
+        task_id="defer_dag_run_sensor",
+        project_id=PROJECT_ID,
+        region=REGION,
+        environment_id=ENVIRONMENT_ID_ASYNC,
+        composer_dag_id="airflow_monitoring",
+        allowed_states=["success"],
+        deferrable=True,
+    )
+    # [END howto_sensor_dag_run_deferrable_mode]
+
     # [START howto_operator_delete_composer_environment]
     delete_env = CloudComposerDeleteEnvironmentOperator(
         task_id="delete_env",
@@ -186,6 +210,7 @@
         get_env,
         [update_env, defer_update_env],
         [run_airflow_cli_cmd, defer_run_airflow_cli_cmd],
+        [dag_run_sensor, defer_dag_run_sensor],
         [delete_env, defer_delete_env],
     )
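
The bundled system test does not exercise execution_range, so here is a minimal standalone sketch of using the sensor with an explicit look-back window in deferrable mode. It is not part of the patch; the project, region, and environment values are hypothetical placeholders.

    from datetime import datetime, timedelta

    from airflow.decorators import dag
    from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerDAGRunSensor


    @dag(schedule="@daily", start_date=datetime(2024, 6, 1), catchup=False)
    def watch_composer_dag():
        # Wait until every run of "airflow_monitoring" started in the 12 hours
        # before this task's logical date has reached the "success" state.
        CloudComposerDAGRunSensor(
            task_id="wait_for_monitoring_dag",
            project_id="my-project",          # hypothetical project ID
            region="us-central1",             # hypothetical region
            environment_id="my-environment",  # hypothetical environment name
            composer_dag_id="airflow_monitoring",
            allowed_states=["success"],
            execution_range=timedelta(hours=12),
            deferrable=True,
        )


    watch_composer_dag()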