Implement CloudComposerDAGRunSensor (apache#40088)
MaksYermak authored and romsharon98 committed Jul 26, 2024
1 parent 4a651ee commit bc5875e
Showing 6 changed files with 447 additions and 10 deletions.
173 changes: 171 additions & 2 deletions airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -19,13 +19,24 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Sequence
import json
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Iterable, Sequence

from dateutil import parser
from deprecated import deprecated
from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerExecutionTrigger
from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook
from airflow.providers.google.cloud.triggers.cloud_composer import (
CloudComposerDAGRunTrigger,
CloudComposerExecutionTrigger,
)
from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
from airflow.sensors.base import BaseSensorOperator
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -117,3 +128,161 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)


class CloudComposerDAGRunSensor(BaseSensorOperator):
"""
Check if a DAG run has completed.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param environment_id: The name of the Composer environment.
:param composer_dag_id: The ID of the DAG whose runs are checked.
:param allowed_states: Iterable of allowed states, default is ``['success']``.
:param execution_range: The time range in which DAG runs are checked; the sensor only considers
DAG runs started within this range. To cover the day before the logical date, use a positive
``datetime.timedelta(days=1)``; to cover the day after it, use a negative
``datetime.timedelta(days=-1)``. For an explicit window, pass a list of datetimes, e.g.
``[datetime(2024, 3, 22, 11, 0, 0), datetime(2024, 3, 22, 12, 0, 0)]``. A single-element list
such as ``[datetime(2024, 3, 22, 0, 0, 0)]`` checks states from that moment until the run's
logical date. Defaults to ``datetime.timedelta(days=1)``.
:param gcp_conn_id: The connection ID to use when fetching connection info.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param poll_interval: Optional. Polling interval in seconds used to check the result when the
sensor runs in deferrable mode.
:param deferrable: Run sensor in deferrable mode.
"""

template_fields = (
"project_id",
"region",
"environment_id",
"composer_dag_id",
"impersonation_chain",
)

def __init__(
self,
*,
project_id: str,
region: str,
environment_id: str,
composer_dag_id: str,
allowed_states: Iterable[str] | None = None,
execution_range: timedelta | list[datetime] | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poll_interval: int = 10,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.project_id = project_id
self.region = region
self.environment_id = environment_id
self.composer_dag_id = composer_dag_id
self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS.value]
self.execution_range = execution_range
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.poll_interval = poll_interval

def _get_execution_dates(self, context) -> tuple[datetime, datetime]:
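# Resulting window for a hypothetical logical_date of 2024-03-22 12:00:
#   timedelta(days=1)   -> (2024-03-21 12:00, 2024-03-22 12:00), the preceding day;
#   timedelta(days=-1)  -> (2024-03-22 12:00, 2024-03-23 12:00), the following day;
#   [datetime(2024, 3, 22, 0, 0, 0)] -> from that moment until logical_date.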
if isinstance(self.execution_range, timedelta):
if self.execution_range < timedelta(0):
return context["logical_date"], context["logical_date"] - self.execution_range
else:
return context["logical_date"] - self.execution_range, context["logical_date"]
elif isinstance(self.execution_range, list) and len(self.execution_range) > 0:
return self.execution_range[0], (
self.execution_range[1] if len(self.execution_range) > 1 else context["logical_date"]
)
else:
return context["logical_date"] - timedelta(1), context["logical_date"]

def poke(self, context: Context) -> bool:
start_date, end_date = self._get_execution_dates(context)

if datetime.now(end_date.tzinfo) < end_date:
return False

dag_runs = self._pull_dag_runs()

self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
allowed_states_status = self._check_dag_runs_states(
dag_runs=dag_runs,
start_date=start_date,
end_date=end_date,
)

return allowed_states_status

def _pull_dag_runs(self) -> list[dict]:
"""Pull the list of dag runs."""
hook = CloudComposerHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
dag_runs_cmd = hook.execute_airflow_command(
project_id=self.project_id,
region=self.region,
environment_id=self.environment_id,
command="dags",
subcommand="list-runs",
parameters=["-d", self.composer_dag_id, "-o", "json"],
)
cmd_result = hook.wait_command_execution_result(
project_id=self.project_id,
region=self.region,
environment_id=self.environment_id,
execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
)
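# The JSON list of DAG runs is the content of the first output line of the command result.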
dag_runs = json.loads(cmd_result["output"][0]["content"])
return dag_runs

def _check_dag_runs_states(
self,
dag_runs: list[dict],
start_date: datetime,
end_date: datetime,
) -> bool:
for dag_run in dag_runs:
if (
start_date.timestamp()
< parser.parse(dag_run["execution_date"]).timestamp()
< end_date.timestamp()
) and dag_run["state"] not in self.allowed_states:
return False
return True

def execute(self, context: Context) -> None:
if self.deferrable:
start_date, end_date = self._get_execution_dates(context)
self.defer(
trigger=CloudComposerDAGRunTrigger(
project_id=self.project_id,
region=self.region,
environment_id=self.environment_id,
composer_dag_id=self.composer_dag_id,
start_date=start_date,
end_date=end_date,
allowed_states=self.allowed_states,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
poll_interval=self.poll_interval,
),
method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
)
super().execute(context)

def execute_complete(self, context: Context, event: dict):
if event and event["status"] == "error":
raise AirflowException(event["message"])
self.log.info("DAG %s has executed successfully.", self.composer_dag_id)
115 changes: 115 additions & 0 deletions airflow/providers/google/cloud/triggers/cloud_composer.py
@@ -19,8 +19,13 @@
from __future__ import annotations

import asyncio
import json
from datetime import datetime
from typing import Any, Sequence

from dateutil import parser
from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -146,3 +151,113 @@ async def run(self):
}
)
return


class CloudComposerDAGRunTrigger(BaseTrigger):
"""The trigger wait for the DAG run completion."""

def __init__(
self,
project_id: str,
region: str,
environment_id: str,
composer_dag_id: str,
start_date: datetime,
end_date: datetime,
allowed_states: list[str],
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
poll_interval: int = 10,
):
super().__init__()
self.project_id = project_id
self.region = region
self.environment_id = environment_id
self.composer_dag_id = composer_dag_id
self.start_date = start_date
self.end_date = end_date
self.allowed_states = allowed_states
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.poll_interval = poll_interval

self.gcp_hook = CloudComposerAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
"airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerDAGRunTrigger",
{
"project_id": self.project_id,
"region": self.region,
"environment_id": self.environment_id,
"composer_dag_id": self.composer_dag_id,
"start_date": self.start_date,
"end_date": self.end_date,
"allowed_states": self.allowed_states,
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"poll_interval": self.poll_interval,
},
)

async def _pull_dag_runs(self) -> list[dict]:
"""Pull the list of dag runs."""
dag_runs_cmd = await self.gcp_hook.execute_airflow_command(
project_id=self.project_id,
region=self.region,
environment_id=self.environment_id,
command="dags",
subcommand="list-runs",
parameters=["-d", self.composer_dag_id, "-o", "json"],
)
cmd_result = await self.gcp_hook.wait_command_execution_result(
project_id=self.project_id,
region=self.region,
environment_id=self.environment_id,
execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
)
dag_runs = json.loads(cmd_result["output"][0]["content"])
return dag_runs

def _check_dag_runs_states(
self,
dag_runs: list[dict],
start_date: datetime,
end_date: datetime,
) -> bool:
for dag_run in dag_runs:
if (
start_date.timestamp()
< parser.parse(dag_run["execution_date"]).timestamp()
< end_date.timestamp()
) and dag_run["state"] not in self.allowed_states:
return False
return True

async def run(self):
try:
while True:
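# Only inspect DAG run states once the end of the execution window has passed.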
if datetime.now(self.end_date.tzinfo).timestamp() > self.end_date.timestamp():
dag_runs = await self._pull_dag_runs()

self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
if self._check_dag_runs_states(
dag_runs=dag_runs,
start_date=self.start_date,
end_date=self.end_date,
):
yield TriggerEvent({"status": "success"})
return
self.log.info("Sleeping for %s seconds.", self.poll_interval)
await asyncio.sleep(self.poll_interval)
except AirflowException as ex:
yield TriggerEvent(
{
"status": "error",
"message": str(ex),
}
)
return
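
To make the window check concrete, here is a standalone sketch of the same predicate that _check_dag_runs_states applies, run against a hand-written result list; the field names come from the code above, while the sample values are hypothetical:

from datetime import datetime, timezone

from dateutil import parser

allowed_states = ["success"]
start_date = datetime(2024, 3, 22, 0, 0, tzinfo=timezone.utc)
end_date = datetime(2024, 3, 23, 0, 0, tzinfo=timezone.utc)

# Hypothetical parsed output of `dags list-runs -d <dag_id> -o json`.
dag_runs = [
    {"execution_date": "2024-03-22T06:00:00+00:00", "state": "success"},
    {"execution_date": "2024-03-22T18:00:00+00:00", "state": "running"},
]


def runs_in_allowed_states(runs: list) -> bool:
    for run in runs:
        started_in_window = (
            start_date.timestamp()
            < parser.parse(run["execution_date"]).timestamp()
            < end_date.timestamp()
        )
        # A run inside the window that is not in an allowed state blocks the sensor.
        if started_in_window and run["state"] not in allowed_states:
            return False
    return True


print(runs_in_allowed_states(dag_runs))  # False: the second run is still "running"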
@@ -177,3 +177,23 @@
:dedent: 4
:start-after: [START howto_operator_run_airflow_cli_command_deferrable_mode]
:end-before: [END howto_operator_run_airflow_cli_command_deferrable_mode]

Check if a DAG run has completed
--------------------------------

To check whether a DAG run in your environment has completed, use the sensor:
:class:`~airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerDAGRunSensor`

.. exampleinclude:: /../../tests/system/providers/google/cloud/composer/example_cloud_composer.py
:language: python
:dedent: 4
:start-after: [START howto_sensor_dag_run]
:end-before: [END howto_sensor_dag_run]

or you can define the same sensor in the deferrable mode:

.. exampleinclude:: /../../tests/system/providers/google/cloud/composer/example_cloud_composer.py
:language: python
:dedent: 4
:start-after: [START howto_sensor_dag_run_deferrable_mode]
:end-before: [END howto_sensor_dag_run_deferrable_mode]