diff --git a/tests/cluster_policies/__init__.py b/tests/cluster_policies/__init__.py index bab74cc5d7a5f..b54667647c5cf 100644 --- a/tests/cluster_policies/__init__.py +++ b/tests/cluster_policies/__init__.py @@ -62,7 +62,7 @@ def _check_task_rules(current_task: BaseOperator): ) -def cluster_policy(task: BaseOperator): +def example_task_policy(task: BaseOperator): """Ensure Tasks have non-default owners.""" _check_task_rules(task) @@ -81,11 +81,11 @@ def dag_policy(dag: DAG): # [END example_dag_cluster_policy] +# [START example_task_cluster_policy] class TimedOperator(BaseOperator, ABC): timeout: timedelta -# [START example_task_cluster_policy] def task_policy(task: TimedOperator): if task.task_type == 'HivePartitionSensor': task.queue = "sensor_queue" diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index 4554389f950e0..c37cc55b38655 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -955,7 +955,7 @@ def test_collect_dags_from_db(self): assert serialized_dag.dag_id == dag.dag_id assert set(serialized_dag.task_dict) == set(dag.task_dict) - @patch("airflow.settings.task_policy", cluster_policies.cluster_policy) + @patch("airflow.settings.task_policy", cluster_policies.example_task_policy) def test_task_cluster_policy_violation(self): """ test that file processing results in import error when task does not @@ -974,7 +974,7 @@ def test_task_cluster_policy_violation(self): } assert expected_import_errors == dagbag.import_errors - @patch("airflow.settings.task_policy", cluster_policies.cluster_policy) + @patch("airflow.settings.task_policy", cluster_policies.example_task_policy) def test_task_cluster_policy_nonstring_owner(self): """ test that file processing results in import error when task does not @@ -994,7 +994,7 @@ def test_task_cluster_policy_nonstring_owner(self): } assert expected_import_errors == dagbag.import_errors - @patch("airflow.settings.task_policy", cluster_policies.cluster_policy) + @patch("airflow.settings.task_policy", cluster_policies.example_task_policy) def test_task_cluster_policy_obeyed(self): """ test that dag successfully imported without import errors when tasks