From 98bdaa2eac94cce066ceb901212d319fd37d17c6 Mon Sep 17 00:00:00 2001
From: humit0
Date: Wed, 4 May 2022 11:45:26 +0900
Subject: [PATCH 1/2] Rename cluster_policy to task_policy

---
 .../concepts/cluster-policies.rst  | 15 +++++++++-----
 tests/cluster_policies/__init__.py | 20 +------------------
 tests/models/test_dagbag.py        |  6 +++---
 3 files changed, 14 insertions(+), 27 deletions(-)

diff --git a/docs/apache-airflow/concepts/cluster-policies.rst b/docs/apache-airflow/concepts/cluster-policies.rst
index e1c664660a4d6..96b523a610f4e 100644
--- a/docs/apache-airflow/concepts/cluster-policies.rst
+++ b/docs/apache-airflow/concepts/cluster-policies.rst
@@ -57,12 +57,17 @@ This policy checks if each DAG has at least one tag defined:
 
 Task policies
 -------------
 
-Here's an example of enforcing a maximum timeout policy on every task:
+Here's an example of enforcing a maximum timeout policy on every task::
 
-.. literalinclude:: /../../tests/cluster_policies/__init__.py
-    :language: python
-    :start-after: [START example_task_cluster_policy]
-    :end-before: [END example_task_cluster_policy]
+    class TimedOperator(BaseOperator, ABC):
+        timeout: timedelta
+
+
+    def task_policy(task: TimedOperator):
+        if task.task_type == 'HivePartitionSensor':
+            task.queue = "sensor_queue"
+        if task.timeout > timedelta(hours=48):
+            task.timeout = timedelta(hours=48)
 
 You could also implement to protect against common errors, rather than as technical security controls. For example, don't run tasks without airflow owners:
diff --git a/tests/cluster_policies/__init__.py b/tests/cluster_policies/__init__.py
index bab74cc5d7a5f..b4901be2adb98 100644
--- a/tests/cluster_policies/__init__.py
+++ b/tests/cluster_policies/__init__.py
@@ -15,8 +15,6 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from abc import ABC
-from datetime import timedelta
 from typing import Callable, List
 
 from airflow.configuration import conf
@@ -62,7 +60,7 @@ def _check_task_rules(current_task: BaseOperator):
         )
 
 
-def cluster_policy(task: BaseOperator):
+def task_policy(task: BaseOperator):
     """Ensure Tasks have non-default owners."""
     _check_task_rules(task)
 
@@ -80,22 +78,6 @@ def dag_policy(dag: DAG):
 
 
 # [END example_dag_cluster_policy]
-
-class TimedOperator(BaseOperator, ABC):
-    timeout: timedelta
-
-
-# [START example_task_cluster_policy]
-def task_policy(task: TimedOperator):
-    if task.task_type == 'HivePartitionSensor':
-        task.queue = "sensor_queue"
-    if task.timeout > timedelta(hours=48):
-        task.timeout = timedelta(hours=48)
-
-
-# [END example_task_cluster_policy]
-
-
 # [START example_task_mutation_hook]
 def task_instance_mutation_hook(task_instance: TaskInstance):
     if task_instance.try_number >= 1:
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index 4554389f950e0..432c55c83f075 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -955,7 +955,7 @@ def test_collect_dags_from_db(self):
             assert serialized_dag.dag_id == dag.dag_id
             assert set(serialized_dag.task_dict) == set(dag.task_dict)
 
-    @patch("airflow.settings.task_policy", cluster_policies.cluster_policy)
+    @patch("airflow.settings.task_policy", cluster_policies.task_policy)
     def test_task_cluster_policy_violation(self):
         """
         test that file processing results in import error when task does not
@@ -974,7 +974,7 @@ def test_task_cluster_policy_violation(self):
         }
         assert expected_import_errors == dagbag.import_errors
 
-    @patch("airflow.settings.task_policy", cluster_policies.cluster_policy)
+    @patch("airflow.settings.task_policy", cluster_policies.task_policy)
     def test_task_cluster_policy_nonstring_owner(self):
         """
         test that file processing results in import error when task does not
@@ -994,7 +994,7 @@ def test_task_cluster_policy_nonstring_owner(self):
         }
         assert expected_import_errors == dagbag.import_errors
 
-    @patch("airflow.settings.task_policy", cluster_policies.cluster_policy)
+    @patch("airflow.settings.task_policy", cluster_policies.task_policy)
     def test_task_cluster_policy_obeyed(self):
         """
         test that dag successfully imported without import errors when tasks

From 2b2e8be1701947d19b449c1a657f707e22993c03 Mon Sep 17 00:00:00 2001
From: humit0
Date: Wed, 11 May 2022 12:07:38 +0900
Subject: [PATCH 2/2] rename task_policy as example_task_policy.

---
 .../concepts/cluster-policies.rst  | 15 +++++---------
 tests/cluster_policies/__init__.py | 20 ++++++++++++++++++-
 tests/models/test_dagbag.py        |  6 +++---
 3 files changed, 27 insertions(+), 14 deletions(-)

diff --git a/docs/apache-airflow/concepts/cluster-policies.rst b/docs/apache-airflow/concepts/cluster-policies.rst
index 96b523a610f4e..e1c664660a4d6 100644
--- a/docs/apache-airflow/concepts/cluster-policies.rst
+++ b/docs/apache-airflow/concepts/cluster-policies.rst
@@ -57,17 +57,12 @@ This policy checks if each DAG has at least one tag defined:
 
 Task policies
 -------------
 
-Here's an example of enforcing a maximum timeout policy on every task::
+Here's an example of enforcing a maximum timeout policy on every task:
 
-    class TimedOperator(BaseOperator, ABC):
-        timeout: timedelta
-
-
-    def task_policy(task: TimedOperator):
-        if task.task_type == 'HivePartitionSensor':
-            task.queue = "sensor_queue"
-        if task.timeout > timedelta(hours=48):
-            task.timeout = timedelta(hours=48)
+.. literalinclude:: /../../tests/cluster_policies/__init__.py
+    :language: python
+    :start-after: [START example_task_cluster_policy]
+    :end-before: [END example_task_cluster_policy]
 
 You could also implement to protect against common errors, rather than as technical security controls. For example, don't run tasks without airflow owners:
diff --git a/tests/cluster_policies/__init__.py b/tests/cluster_policies/__init__.py
index b4901be2adb98..b54667647c5cf 100644
--- a/tests/cluster_policies/__init__.py
+++ b/tests/cluster_policies/__init__.py
@@ -15,6 +15,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from abc import ABC
+from datetime import timedelta
 from typing import Callable, List
 
 from airflow.configuration import conf
@@ -60,7 +62,7 @@ def _check_task_rules(current_task: BaseOperator):
         )
 
 
-def task_policy(task: BaseOperator):
+def example_task_policy(task: BaseOperator):
     """Ensure Tasks have non-default owners."""
     _check_task_rules(task)
 
@@ -78,6 +80,22 @@ def dag_policy(dag: DAG):
 
 
 # [END example_dag_cluster_policy]
+
+# [START example_task_cluster_policy]
+class TimedOperator(BaseOperator, ABC):
+    timeout: timedelta
+
+
+def task_policy(task: TimedOperator):
+    if task.task_type == 'HivePartitionSensor':
+        task.queue = "sensor_queue"
+    if task.timeout > timedelta(hours=48):
+        task.timeout = timedelta(hours=48)
+
+
+# [END example_task_cluster_policy]
+
+
 # [START example_task_mutation_hook]
 def task_instance_mutation_hook(task_instance: TaskInstance):
     if task_instance.try_number >= 1:
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index 432c55c83f075..c37cc55b38655 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -955,7 +955,7 @@ def test_collect_dags_from_db(self):
             assert serialized_dag.dag_id == dag.dag_id
             assert set(serialized_dag.task_dict) == set(dag.task_dict)
 
-    @patch("airflow.settings.task_policy", cluster_policies.task_policy)
+    @patch("airflow.settings.task_policy", cluster_policies.example_task_policy)
     def test_task_cluster_policy_violation(self):
         """
         test that file processing results in import error when task does not
@@ -974,7 +974,7 @@ def test_task_cluster_policy_violation(self):
         }
         assert expected_import_errors == dagbag.import_errors
 
-    @patch("airflow.settings.task_policy", cluster_policies.task_policy)
+    @patch("airflow.settings.task_policy", cluster_policies.example_task_policy)
     def test_task_cluster_policy_nonstring_owner(self):
         """
         test that file processing results in import error when task does not
@@ -994,7 +994,7 @@ def test_task_cluster_policy_nonstring_owner(self):
         }
         assert expected_import_errors == dagbag.import_errors
 
-    @patch("airflow.settings.task_policy", cluster_policies.task_policy)
+    @patch("airflow.settings.task_policy", cluster_policies.example_task_policy)
     def test_task_cluster_policy_obeyed(self):
         """
         test that dag successfully imported without import errors when tasks
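
For context, Airflow loads policy callables like the ones renamed here from an ``airflow_local_settings.py`` module found on the ``PYTHONPATH`` (for example under ``$AIRFLOW_HOME/config``). Below is a minimal sketch of such a module, assuming Airflow 2.x; it combines the queue/timeout rule from the docs example with the non-default-owner rule from the test module, and it uses the real ``BaseOperator.execution_timeout`` attribute in place of the illustrative ``TimedOperator.timeout`` field, so it is an assumption-laden sketch rather than part of either patch::

    # airflow_local_settings.py -- assumed deployment-side module, discovered by
    # Airflow at startup; not part of the patches above.
    from datetime import timedelta

    from airflow.configuration import conf
    from airflow.exceptions import AirflowClusterPolicyViolation
    from airflow.models.baseoperator import BaseOperator


    def task_policy(task: BaseOperator) -> None:
        """Applied to every task when its DAG file is parsed."""
        # Route Hive partition sensors to a dedicated queue, as in the docs example.
        if task.task_type == "HivePartitionSensor":
            task.queue = "sensor_queue"
        # Cap execution_timeout at 48 hours (stands in for the doc example's
        # illustrative TimedOperator.timeout attribute).
        if task.execution_timeout is None or task.execution_timeout > timedelta(hours=48):
            task.execution_timeout = timedelta(hours=48)
        # Reject tasks that keep the default owner, mirroring example_task_policy
        # in tests/cluster_policies/__init__.py.
        if task.owner == conf.get("operators", "default_owner"):
            raise AirflowClusterPolicyViolation(
                f"Task {task.task_id} must declare a non-default owner."
            )

A policy raising ``AirflowClusterPolicyViolation`` surfaces as an import error for the offending DAG file, which is exactly what the ``test_task_cluster_policy_*`` tests patched above assert against.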