From 4a5ce39ada36cead975429f3892aa532a7430df7 Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Wed, 7 Oct 2020 20:53:24 +0100
Subject: [PATCH 1/2] Enable Black across the entire repo
---
.pre-commit-config.yaml | 5 ++---
setup.cfg | 6 +-----
2 files changed, 3 insertions(+), 8 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 43564d067e24c..be8573802024b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -159,8 +159,7 @@ repos:
rev: 20.8b1
hooks:
- id: black
- files: api_connexion/.*\.py|.*providers.*\.py|^chart/tests/.*\.py
- exclude: .*kubernetes_pod\.py|.*google/common/hooks/base_google\.py$
+ exclude: .*kubernetes_pod\.py|.*google/common/hooks/base_google\.py$|^airflow/configuration.py$
args: [--config=./pyproject.toml]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.3.0
@@ -203,7 +202,7 @@ repos:
name: Run isort to sort imports
types: [python]
# To keep consistent with the global isort skip config defined in setup.cfg
- exclude: ^build/.*$|^.tox/.*$|^venv/.*$|.*api_connexion/.*\.py|.*providers.*\.py
+ exclude: ^build/.*$|^.tox/.*$|^venv/.*$
- repo: https://github.com/pycqa/pydocstyle
rev: 5.1.1
hooks:
diff --git a/setup.cfg b/setup.cfg
index 9cb0fa9ddb617..3d07da425bb55 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -70,11 +70,7 @@ ignore_errors = True
line_length=110
combine_as_imports = true
default_section = THIRDPARTY
-include_trailing_comma = true
known_first_party=airflow,tests
-multi_line_output=5
# Need to be consistent with the exclude config defined in pre-commit-config.yaml
skip=build,.tox,venv
-# ToDo: Enable the below before Airflow 2.0
-# profile = "black"
-skip_glob=*/api_connexion/**/*.py,*/providers/**/*.py,provider_packages/**,chart/tests/*.py
+profile = black
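
Note on the configuration changes above: dropping the `files:` filter makes the black hook run on every Python file that is not matched by `exclude:`, and the isort section replaces its hand-tuned wrapping options (`include_trailing_comma`, `multi_line_output`, the provider/api_connexion `skip_glob`) with the Black-compatible `profile = black`, so the two tools stop disagreeing about import formatting. The second patch below is the mechanical result of running the hook repo-wide. As a rough before/after illustration of the kind of rewrite Black applies, mirroring the local_client.py hunk further down (excerpt for illustration only, not new code):

    # before Black: hand-wrapped call, continuation lines aligned to the open paren
    dag_run = trigger_dag.trigger_dag(dag_id=dag_id,
                                      run_id=run_id,
                                      conf=conf,
                                      execution_date=execution_date)

    # after Black: hanging indent within the line length configured in pyproject.toml
    dag_run = trigger_dag.trigger_dag(
        dag_id=dag_id, run_id=run_id, conf=conf, execution_date=execution_date
    )
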
From 9bdaf18f7f4c222603107ee73c198a1c87169bab Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Wed, 7 Oct 2020 20:59:21 +0100
Subject: [PATCH 2/2] Black Formatting
---
airflow/__init__.py | 3 +
airflow/api/__init__.py | 5 +-
airflow/api/auth/backend/basic_auth.py | 5 +-
airflow/api/auth/backend/default.py | 1 +
airflow/api/auth/backend/kerberos_auth.py | 5 +-
airflow/api/client/__init__.py | 2 +-
airflow/api/client/json_client.py | 30 +-
airflow/api/client/local_client.py | 7 +-
airflow/api/common/experimental/__init__.py | 8 +-
airflow/api/common/experimental/delete_dag.py | 11 +-
.../api/common/experimental/get_dag_runs.py | 22 +-
.../api/common/experimental/get_lineage.py | 10 +-
.../common/experimental/get_task_instance.py | 3 +-
airflow/api/common/experimental/mark_tasks.py | 73 +-
.../api/common/experimental/trigger_dag.py | 19 +-
.../endpoints/dag_run_endpoint.py | 2 +-
.../endpoints/dag_source_endpoint.py | 1 -
.../endpoints/task_instance_endpoint.py | 20 +-
.../api_connexion/schemas/sla_miss_schema.py | 1 +
.../schemas/task_instance_schema.py | 6 +-
airflow/cli/cli_parser.py | 943 +++++------
airflow/cli/commands/cheat_sheet_command.py | 1 +
airflow/cli/commands/config_command.py | 4 +-
airflow/cli/commands/connection_command.py | 75 +-
airflow/cli/commands/dag_command.py | 102 +-
airflow/cli/commands/db_command.py | 10 +-
airflow/cli/commands/info_command.py | 10 +-
airflow/cli/commands/legacy_commands.py | 2 +-
airflow/cli/commands/pool_command.py | 13 +-
airflow/cli/commands/role_command.py | 4 +-
.../cli/commands/rotate_fernet_key_command.py | 3 +-
airflow/cli/commands/scheduler_command.py | 11 +-
airflow/cli/commands/sync_perm_command.py | 4 +-
airflow/cli/commands/task_command.py | 82 +-
airflow/cli/commands/user_command.py | 63 +-
airflow/cli/commands/variable_command.py | 11 +-
airflow/cli/commands/webserver_command.py | 106 +-
.../airflow_local_settings.py | 43 +-
airflow/config_templates/default_celery.py | 49 +-
airflow/configuration.py | 1 +
airflow/contrib/__init__.py | 4 +-
airflow/contrib/hooks/aws_dynamodb_hook.py | 3 +-
.../contrib/hooks/aws_glue_catalog_hook.py | 3 +-
airflow/contrib/hooks/aws_hook.py | 6 +-
airflow/contrib/hooks/aws_logs_hook.py | 3 +-
.../hooks/azure_container_instance_hook.py | 3 +-
.../hooks/azure_container_registry_hook.py | 3 +-
.../hooks/azure_container_volume_hook.py | 3 +-
airflow/contrib/hooks/azure_cosmos_hook.py | 3 +-
airflow/contrib/hooks/azure_data_lake_hook.py | 3 +-
airflow/contrib/hooks/azure_fileshare_hook.py | 3 +-
airflow/contrib/hooks/bigquery_hook.py | 9 +-
airflow/contrib/hooks/cassandra_hook.py | 3 +-
airflow/contrib/hooks/cloudant_hook.py | 3 +-
airflow/contrib/hooks/databricks_hook.py | 17 +-
airflow/contrib/hooks/datadog_hook.py | 3 +-
airflow/contrib/hooks/datastore_hook.py | 3 +-
airflow/contrib/hooks/dingding_hook.py | 3 +-
airflow/contrib/hooks/discord_webhook_hook.py | 3 +-
airflow/contrib/hooks/emr_hook.py | 3 +-
airflow/contrib/hooks/fs_hook.py | 3 +-
airflow/contrib/hooks/ftp_hook.py | 3 +-
airflow/contrib/hooks/gcp_api_base_hook.py | 6 +-
airflow/contrib/hooks/gcp_bigtable_hook.py | 3 +-
airflow/contrib/hooks/gcp_cloud_build_hook.py | 3 +-
airflow/contrib/hooks/gcp_compute_hook.py | 6 +-
airflow/contrib/hooks/gcp_container_hook.py | 6 +-
airflow/contrib/hooks/gcp_dataflow_hook.py | 6 +-
airflow/contrib/hooks/gcp_dataproc_hook.py | 6 +-
airflow/contrib/hooks/gcp_dlp_hook.py | 3 +-
airflow/contrib/hooks/gcp_function_hook.py | 6 +-
airflow/contrib/hooks/gcp_kms_hook.py | 6 +-
airflow/contrib/hooks/gcp_mlengine_hook.py | 3 +-
.../hooks/gcp_natural_language_hook.py | 3 +-
airflow/contrib/hooks/gcp_pubsub_hook.py | 3 +-
.../contrib/hooks/gcp_speech_to_text_hook.py | 6 +-
airflow/contrib/hooks/gcp_tasks_hook.py | 3 +-
.../contrib/hooks/gcp_text_to_speech_hook.py | 6 +-
airflow/contrib/hooks/gcp_transfer_hook.py | 6 +-
airflow/contrib/hooks/gcp_translate_hook.py | 3 +-
.../hooks/gcp_video_intelligence_hook.py | 3 +-
airflow/contrib/hooks/gcp_vision_hook.py | 3 +-
airflow/contrib/hooks/gcs_hook.py | 6 +-
airflow/contrib/hooks/gdrive_hook.py | 3 +-
airflow/contrib/hooks/grpc_hook.py | 3 +-
airflow/contrib/hooks/imap_hook.py | 3 +-
airflow/contrib/hooks/jenkins_hook.py | 3 +-
airflow/contrib/hooks/mongo_hook.py | 3 +-
airflow/contrib/hooks/openfaas_hook.py | 3 +-
airflow/contrib/hooks/opsgenie_alert_hook.py | 3 +-
airflow/contrib/hooks/pagerduty_hook.py | 3 +-
airflow/contrib/hooks/pinot_hook.py | 3 +-
airflow/contrib/hooks/qubole_check_hook.py | 3 +-
airflow/contrib/hooks/qubole_hook.py | 3 +-
airflow/contrib/hooks/redis_hook.py | 3 +-
airflow/contrib/hooks/sagemaker_hook.py | 9 +-
airflow/contrib/hooks/salesforce_hook.py | 3 +-
airflow/contrib/hooks/segment_hook.py | 3 +-
airflow/contrib/hooks/slack_webhook_hook.py | 3 +-
airflow/contrib/hooks/spark_jdbc_hook.py | 3 +-
airflow/contrib/hooks/spark_sql_hook.py | 3 +-
airflow/contrib/hooks/spark_submit_hook.py | 3 +-
airflow/contrib/hooks/sqoop_hook.py | 3 +-
airflow/contrib/hooks/ssh_hook.py | 3 +-
airflow/contrib/hooks/vertica_hook.py | 3 +-
airflow/contrib/hooks/wasb_hook.py | 3 +-
airflow/contrib/hooks/winrm_hook.py | 3 +-
.../contrib/operators/adls_list_operator.py | 3 +-
airflow/contrib/operators/adls_to_gcs.py | 6 +-
.../azure_container_instances_operator.py | 3 +-
.../operators/azure_cosmos_operator.py | 3 +-
.../operators/bigquery_check_operator.py | 7 +-
.../contrib/operators/bigquery_get_data.py | 3 +-
.../contrib/operators/bigquery_operator.py | 12 +-
.../contrib/operators/bigquery_to_bigquery.py | 3 +-
airflow/contrib/operators/bigquery_to_gcs.py | 6 +-
.../operators/bigquery_to_mysql_operator.py | 3 +-
airflow/contrib/operators/cassandra_to_gcs.py | 6 +-
.../contrib/operators/databricks_operator.py | 6 +-
.../contrib/operators/dataflow_operator.py | 16 +-
.../contrib/operators/dataproc_operator.py | 53 +-
.../operators/datastore_export_operator.py | 6 +-
.../operators/datastore_import_operator.py | 6 +-
.../contrib/operators/dingding_operator.py | 3 +-
.../operators/discord_webhook_operator.py | 3 +-
.../operators/docker_swarm_operator.py | 3 +-
airflow/contrib/operators/druid_operator.py | 3 +-
airflow/contrib/operators/dynamodb_to_s3.py | 3 +-
airflow/contrib/operators/ecs_operator.py | 3 +-
.../operators/emr_add_steps_operator.py | 3 +-
.../operators/emr_create_job_flow_operator.py | 3 +-
.../emr_terminate_job_flow_operator.py | 3 +-
airflow/contrib/operators/file_to_gcs.py | 6 +-
airflow/contrib/operators/file_to_wasb.py | 3 +-
.../operators/gcp_bigtable_operator.py | 28 +-
.../operators/gcp_cloud_build_operator.py | 3 +-
.../contrib/operators/gcp_compute_operator.py | 30 +-
.../operators/gcp_container_operator.py | 16 +-
airflow/contrib/operators/gcp_dlp_operator.py | 58 +-
.../operators/gcp_function_operator.py | 12 +-
.../gcp_natural_language_operator.py | 21 +-
.../contrib/operators/gcp_spanner_operator.py | 33 +-
.../operators/gcp_speech_to_text_operator.py | 6 +-
airflow/contrib/operators/gcp_sql_operator.py | 15 +-
.../contrib/operators/gcp_tasks_operator.py | 19 +-
.../operators/gcp_text_to_speech_operator.py | 6 +-
.../operators/gcp_transfer_operator.py | 48 +-
.../operators/gcp_translate_operator.py | 3 +-
.../gcp_translate_speech_operator.py | 6 +-
.../gcp_video_intelligence_operator.py | 6 +-
.../contrib/operators/gcp_vision_operator.py | 59 +-
airflow/contrib/operators/gcs_acl_operator.py | 12 +-
.../contrib/operators/gcs_delete_operator.py | 6 +-
.../operators/gcs_download_operator.py | 6 +-
.../contrib/operators/gcs_list_operator.py | 6 +-
airflow/contrib/operators/gcs_operator.py | 6 +-
airflow/contrib/operators/gcs_to_bq.py | 6 +-
airflow/contrib/operators/gcs_to_gcs.py | 6 +-
.../operators/gcs_to_gcs_transfer_operator.py | 3 +-
.../operators/gcs_to_gdrive_operator.py | 6 +-
airflow/contrib/operators/gcs_to_s3.py | 6 +-
airflow/contrib/operators/grpc_operator.py | 3 +-
airflow/contrib/operators/hive_to_dynamodb.py | 3 +-
.../imap_attachment_to_s3_operator.py | 3 +-
.../operators/jenkins_job_trigger_operator.py | 3 +-
airflow/contrib/operators/jira_operator.py | 3 +-
.../operators/kubernetes_pod_operator.py | 3 +-
.../contrib/operators/mlengine_operator.py | 19 +-
airflow/contrib/operators/mongo_to_s3.py | 3 +-
airflow/contrib/operators/mssql_to_gcs.py | 6 +-
airflow/contrib/operators/mysql_to_gcs.py | 6 +-
.../operators/opsgenie_alert_operator.py | 3 +-
.../oracle_to_azure_data_lake_transfer.py | 3 +-
.../operators/oracle_to_oracle_transfer.py | 6 +-
.../operators/postgres_to_gcs_operator.py | 6 +-
airflow/contrib/operators/pubsub_operator.py | 7 +-
.../operators/qubole_check_operator.py | 6 +-
airflow/contrib/operators/qubole_operator.py | 3 +-
.../operators/redis_publish_operator.py | 3 +-
.../operators/s3_copy_object_operator.py | 3 +-
.../operators/s3_delete_objects_operator.py | 3 +-
airflow/contrib/operators/s3_list_operator.py | 3 +-
.../contrib/operators/s3_to_gcs_operator.py | 3 +-
.../operators/s3_to_gcs_transfer_operator.py | 7 +-
.../contrib/operators/s3_to_sftp_operator.py | 3 +-
.../operators/sagemaker_base_operator.py | 3 +-
.../sagemaker_endpoint_config_operator.py | 3 +-
.../operators/sagemaker_endpoint_operator.py | 3 +-
.../operators/sagemaker_model_operator.py | 3 +-
.../operators/sagemaker_training_operator.py | 3 +-
.../operators/sagemaker_transform_operator.py | 3 +-
.../operators/sagemaker_tuning_operator.py | 3 +-
.../operators/segment_track_event_operator.py | 3 +-
.../contrib/operators/sftp_to_s3_operator.py | 3 +-
.../operators/slack_webhook_operator.py | 3 +-
.../contrib/operators/snowflake_operator.py | 3 +-
.../contrib/operators/spark_jdbc_operator.py | 3 +-
.../contrib/operators/spark_sql_operator.py | 3 +-
.../operators/spark_submit_operator.py | 3 +-
airflow/contrib/operators/sql_to_gcs.py | 6 +-
airflow/contrib/operators/sqoop_operator.py | 3 +-
airflow/contrib/operators/ssh_operator.py | 3 +-
airflow/contrib/operators/vertica_operator.py | 3 +-
airflow/contrib/operators/vertica_to_hive.py | 6 +-
airflow/contrib/operators/vertica_to_mysql.py | 6 +-
.../operators/wasb_delete_blob_operator.py | 3 +-
airflow/contrib/operators/winrm_operator.py | 3 +-
.../contrib/secrets/gcp_secrets_manager.py | 3 +-
.../aws_glue_catalog_partition_sensor.py | 3 +-
.../contrib/sensors/azure_cosmos_sensor.py | 3 +-
airflow/contrib/sensors/bash_sensor.py | 3 +-
.../contrib/sensors/celery_queue_sensor.py | 3 +-
airflow/contrib/sensors/datadog_sensor.py | 3 +-
airflow/contrib/sensors/emr_base_sensor.py | 3 +-
.../contrib/sensors/emr_job_flow_sensor.py | 3 +-
airflow/contrib/sensors/emr_step_sensor.py | 3 +-
airflow/contrib/sensors/file_sensor.py | 3 +-
airflow/contrib/sensors/ftp_sensor.py | 3 +-
.../contrib/sensors/gcp_transfer_sensor.py | 6 +-
airflow/contrib/sensors/gcs_sensor.py | 19 +-
airflow/contrib/sensors/hdfs_sensor.py | 9 +-
.../contrib/sensors/imap_attachment_sensor.py | 3 +-
airflow/contrib/sensors/mongo_sensor.py | 3 +-
airflow/contrib/sensors/pubsub_sensor.py | 3 +-
airflow/contrib/sensors/python_sensor.py | 3 +-
airflow/contrib/sensors/qubole_sensor.py | 7 +-
airflow/contrib/sensors/redis_key_sensor.py | 3 +-
.../contrib/sensors/redis_pub_sub_sensor.py | 3 +-
.../contrib/sensors/sagemaker_base_sensor.py | 3 +-
.../sensors/sagemaker_endpoint_sensor.py | 3 +-
.../sensors/sagemaker_training_sensor.py | 6 +-
.../sensors/sagemaker_transform_sensor.py | 3 +-
.../sensors/sagemaker_tuning_sensor.py | 3 +-
airflow/contrib/sensors/wasb_sensor.py | 3 +-
airflow/contrib/sensors/weekday_sensor.py | 3 +-
.../contrib/task_runner/cgroup_task_runner.py | 3 +-
airflow/contrib/utils/__init__.py | 5 +-
airflow/contrib/utils/gcp_field_sanitizer.py | 6 +-
airflow/contrib/utils/gcp_field_validator.py | 7 +-
airflow/contrib/utils/log/__init__.py | 5 +-
.../log/task_handler_with_custom_formatter.py | 3 +-
.../contrib/utils/mlengine_operator_utils.py | 3 +-
.../utils/mlengine_prediction_summary.py | 3 +-
airflow/contrib/utils/sendgrid.py | 3 +-
airflow/contrib/utils/weekday.py | 3 +-
airflow/example_dags/example_bash_operator.py | 2 +-
.../example_dags/example_branch_operator.py | 2 +-
.../example_branch_python_dop_operator_3.py | 9 +-
airflow/example_dags/example_dag_decorator.py | 11 +-
.../example_external_task_marker_dag.py | 24 +-
.../example_kubernetes_executor.py | 38 +-
.../example_kubernetes_executor_config.py | 63 +-
airflow/example_dags/example_latest_only.py | 2 +-
.../example_latest_only_with_trigger.py | 2 +-
.../example_dags/example_nested_branch_dag.py | 5 +-
...example_passing_params_via_test_command.py | 16 +-
.../example_dags/example_python_operator.py | 7 +-
.../example_dags/example_subdag_operator.py | 6 +-
.../example_trigger_controller_dag.py | 2 +-
.../example_trigger_target_dag.py | 2 +-
airflow/example_dags/example_xcom.py | 2 +-
airflow/example_dags/subdags/subdag.py | 2 +
airflow/example_dags/tutorial.py | 1 +
.../tutorial_decorated_etl_dag.py | 3 +
airflow/example_dags/tutorial_etl_dag.py | 4 +
airflow/executors/base_executor.py | 62 +-
airflow/executors/celery_executor.py | 65 +-
.../executors/celery_kubernetes_executor.py | 40 +-
airflow/executors/dask_executor.py | 12 +-
airflow/executors/debug_executor.py | 10 +-
airflow/executors/executor_loader.py | 9 +-
airflow/executors/kubernetes_executor.py | 311 ++--
airflow/executors/local_executor.py | 52 +-
airflow/executors/sequential_executor.py | 12 +-
airflow/hooks/base_hook.py | 2 +-
airflow/hooks/dbapi_hook.py | 31 +-
airflow/hooks/docker_hook.py | 3 +-
airflow/hooks/druid_hook.py | 3 +-
airflow/hooks/hdfs_hook.py | 3 +-
airflow/hooks/hive_hooks.py | 8 +-
airflow/hooks/http_hook.py | 3 +-
airflow/hooks/jdbc_hook.py | 3 +-
airflow/hooks/mssql_hook.py | 3 +-
airflow/hooks/mysql_hook.py | 3 +-
airflow/hooks/oracle_hook.py | 3 +-
airflow/hooks/pig_hook.py | 3 +-
airflow/hooks/postgres_hook.py | 3 +-
airflow/hooks/presto_hook.py | 3 +-
airflow/hooks/samba_hook.py | 3 +-
airflow/hooks/slack_hook.py | 3 +-
airflow/hooks/sqlite_hook.py | 3 +-
airflow/hooks/webhdfs_hook.py | 3 +-
airflow/hooks/zendesk_hook.py | 3 +-
airflow/jobs/backfill_job.py | 297 ++--
airflow/jobs/base_job.py | 30 +-
airflow/jobs/local_task_job.py | 77 +-
airflow/jobs/scheduler_job.py | 431 +++--
airflow/kubernetes/kube_client.py | 22 +-
airflow/kubernetes/pod_generator.py | 81 +-
.../kubernetes/pod_generator_deprecated.py | 43 +-
airflow/kubernetes/pod_launcher.py | 108 +-
airflow/kubernetes/pod_runtime_info_env.py | 6 +-
airflow/kubernetes/refresh_config.py | 6 +-
airflow/kubernetes/secret.py | 39 +-
airflow/lineage/__init__.py | 67 +-
airflow/lineage/entities.py | 2 +-
airflow/logging_config.py | 15 +-
airflow/macros/hive.py | 16 +-
airflow/migrations/env.py | 5 +-
.../versions/03bc53e68815_add_sm_dag_index.py | 4 +-
.../versions/05f30312d566_merge_heads.py | 4 +-
.../0a2a5b66e19d_add_task_reschedule_table.py | 20 +-
.../0e2a74e0fc9f_add_time_zone_awareness.py | 188 +--
...add_dag_id_state_index_on_dag_run_table.py | 4 +-
.../13eb55f81627_for_compatibility.py | 4 +-
.../1507a7289a2f_create_is_encrypted.py | 17 +-
...e3_add_is_encrypted_column_to_variable_.py | 4 +-
.../versions/1b38cef5b76e_add_dagrun.py | 30 +-
.../211e584da130_add_ti_state_index.py | 4 +-
...24_add_executor_config_to_task_instance.py | 4 +-
.../versions/2e541a1dcfed_task_duration.py | 14 +-
.../2e82aab8ef20_rename_user_table.py | 4 +-
...0f54d61_more_logging_into_task_isntance.py | 4 +-
...4_add_kubernetes_resource_checkpointing.py | 15 +-
.../versions/40e67319e3a9_dagrun_config.py | 4 +-
.../41f5f12752f8_add_superuser_field.py | 4 +-
.../versions/4446e08588_dagrun_start_end.py | 4 +-
..._add_fractional_seconds_to_mysql_tables.py | 202 +--
.../502898887f84_adding_extra_to_log.py | 4 +-
..._mssql_exec_date_rendered_task_instance.py | 4 +-
.../versions/52d714495f0_job_id_indices.py | 7 +-
...61833c1c74b_add_password_column_to_user.py | 4 +-
...de9cddf6c9_add_task_fails_journal_table.py | 4 +-
...4a4_make_taskinstance_pool_not_nullable.py | 14 +-
...hange_datetime_to_datetime2_6_on_mssql_.py | 155 +-
.../7939bcff74ba_add_dagtags_table.py | 7 +-
.../849da589634d_prefix_dag_permissions.py | 10 +-
.../8504051e801b_xcom_dag_task_indices.py | 3 +-
...add_rendered_task_instance_fields_table.py | 4 +-
.../856955da8476_fix_sqlite_foreign_key.py | 43 +-
...5c0_add_kubernetes_scheduler_uniqueness.py | 23 +-
...3f6d53_add_unique_constraint_to_conn_id.py | 22 +-
...c8_task_reschedule_fk_on_cascade_delete.py | 18 +-
.../947454bf1dff_add_ti_job_id_index.py | 4 +-
.../952da73b5eff_add_dag_code_table.py | 18 +-
.../versions/9635ae0956e7_index_faskfail.py | 10 +-
..._add_scheduling_decision_to_dagrun_and_.py | 1 +
...b_add_pool_slots_field_to_task_instance.py | 4 +-
.../a56c9515abdc_remove_dag_stat_table.py | 14 +-
...dd_precision_to_execution_date_in_mysql.py | 10 +-
.../versions/b0125267960b_merge_heads.py | 4 +-
...6_add_a_column_to_track_the_encryption_.py | 7 +-
...dd_notification_sent_column_to_sla_miss.py | 4 +-
...6_make_xcom_value_column_a_large_binary.py | 4 +-
...4f3d11e8b_drop_kuberesourceversion_and_.py | 30 +-
.../bf00311e1990_add_index_to_taskinstance.py | 11 +-
.../c8ffec048a3b_add_fields_to_dag.py | 4 +-
...7_add_max_tries_column_to_task_instance.py | 24 +-
.../cf5dc11e79ad_drop_user_and_chart.py | 17 +-
...ae31099d61_increase_text_size_for_mysql.py | 4 +-
.../d38e04c12aa2_add_serialized_dag_table.py | 25 +-
..._add_dag_hash_column_to_serialized_dag_.py | 3 +-
.../versions/dd25f486b8ea_add_idx_log_dag.py | 4 +-
...4ecb8fbee3_add_schedule_interval_to_dag.py | 4 +-
...e357a868_update_schema_for_smart_sensor.py | 15 +-
.../versions/e3a246e0dc1_current_schema.py | 67 +-
...433877c24_fix_mysql_not_null_constraint.py | 4 +-
.../f2ca10b85618_add_dag_stats_table.py | 22 +-
...increase_length_for_connection_password.py | 20 +-
airflow/models/base.py | 4 +-
airflow/models/baseoperator.py | 293 ++--
airflow/models/connection.py | 80 +-
airflow/models/crypto.py | 11 +-
airflow/models/dag.py | 546 +++---
airflow/models/dagbag.py | 115 +-
airflow/models/dagcode.py | 52 +-
airflow/models/dagrun.py | 217 +--
airflow/models/log.py | 4 +-
airflow/models/pool.py | 16 +-
airflow/models/renderedtifields.py | 80 +-
airflow/models/sensorinstance.py | 28 +-
airflow/models/serialized_dag.py | 46 +-
airflow/models/skipmixin.py | 32 +-
airflow/models/slamiss.py | 7 +-
airflow/models/taskfail.py | 5 +-
airflow/models/taskinstance.py | 452 ++---
airflow/models/taskreschedule.py | 29 +-
airflow/models/variable.py | 13 +-
airflow/models/xcom.py | 108 +-
airflow/models/xcom_arg.py | 13 +-
airflow/operators/bash.py | 28 +-
airflow/operators/bash_operator.py | 3 +-
airflow/operators/check_operator.py | 20 +-
airflow/operators/dagrun_operator.py | 7 +-
airflow/operators/docker_operator.py | 3 +-
airflow/operators/druid_check_operator.py | 3 +-
airflow/operators/email.py | 35 +-
airflow/operators/email_operator.py | 3 +-
airflow/operators/gcs_to_s3.py | 3 +-
airflow/operators/generic_transfer.py | 22 +-
.../operators/google_api_to_s3_transfer.py | 13 +-
airflow/operators/hive_operator.py | 3 +-
airflow/operators/hive_stats_operator.py | 3 +-
airflow/operators/hive_to_druid.py | 6 +-
airflow/operators/hive_to_mysql.py | 6 +-
airflow/operators/hive_to_samba_operator.py | 3 +-
airflow/operators/http_operator.py | 3 +-
airflow/operators/jdbc_operator.py | 3 +-
airflow/operators/latest_only.py | 10 +-
airflow/operators/latest_only_operator.py | 3 +-
airflow/operators/mssql_operator.py | 3 +-
airflow/operators/mssql_to_hive.py | 6 +-
airflow/operators/mysql_operator.py | 3 +-
airflow/operators/mysql_to_hive.py | 6 +-
airflow/operators/oracle_operator.py | 3 +-
airflow/operators/papermill_operator.py | 3 +-
airflow/operators/pig_operator.py | 3 +-
airflow/operators/postgres_operator.py | 3 +-
airflow/operators/presto_check_operator.py | 12 +-
airflow/operators/presto_to_mysql.py | 6 +-
airflow/operators/python.py | 116 +-
airflow/operators/python_operator.py | 8 +-
airflow/operators/redshift_to_s3_operator.py | 6 +-
.../operators/s3_file_transform_operator.py | 3 +-
airflow/operators/s3_to_hive_operator.py | 6 +-
airflow/operators/s3_to_redshift_operator.py | 6 +-
airflow/operators/slack_operator.py | 3 +-
airflow/operators/sql.py | 75 +-
airflow/operators/sql_branch_operator.py | 6 +-
airflow/operators/sqlite_operator.py | 3 +-
airflow/operators/subdag_operator.py | 56 +-
airflow/plugins_manager.py | 60 +-
.../example_dags/example_glacier_to_gcs.py | 4 +-
.../providers/amazon/aws/hooks/base_aws.py | 2 +-
.../amazon/aws/hooks/cloud_formation.py | 2 +-
.../hooks/elasticache_replication_group.py | 3 +-
airflow/providers/amazon/aws/hooks/glacier.py | 1 +
.../amazon/aws/hooks/glue_catalog.py | 2 +-
.../providers/amazon/aws/hooks/sagemaker.py | 2 +-
.../amazon/aws/hooks/secrets_manager.py | 1 +
airflow/providers/amazon/aws/hooks/sns.py | 2 +-
.../providers/amazon/aws/operators/batch.py | 2 +-
.../amazon/aws/operators/datasync.py | 2 +-
airflow/providers/amazon/aws/operators/ecs.py | 2 +-
.../aws/operators/sagemaker_transform.py | 2 +-
.../providers/amazon/aws/sensors/emr_base.py | 2 +-
.../amazon/aws/sensors/sagemaker_training.py | 3 +-
.../amazon/aws/transfers/dynamodb_to_s3.py | 2 +-
.../amazon/aws/transfers/gcs_to_s3.py | 2 +-
.../amazon/aws/transfers/glacier_to_gcs.py | 2 +-
.../amazon/aws/transfers/hive_to_dynamodb.py | 2 +-
.../amazon/aws/transfers/mongo_to_s3.py | 2 +-
.../amazon/backport_provider_setup.py | 2 +-
.../cassandra/backport_provider_setup.py | 2 +-
.../apache/druid/backport_provider_setup.py | 2 +-
.../apache/hdfs/backport_provider_setup.py | 2 +-
.../apache/hive/backport_provider_setup.py | 2 +-
.../apache/kylin/backport_provider_setup.py | 2 +-
.../apache/livy/backport_provider_setup.py | 2 +-
.../apache/pig/backport_provider_setup.py | 2 +-
.../apache/pinot/backport_provider_setup.py | 2 +-
.../apache/spark/backport_provider_setup.py | 2 +-
.../apache/sqoop/backport_provider_setup.py | 2 +-
.../celery/backport_provider_setup.py | 2 +-
.../cloudant/backport_provider_setup.py | 2 +-
.../kubernetes/backport_provider_setup.py | 2 +-
.../cncf/kubernetes/hooks/kubernetes.py | 2 +-
.../kubernetes/operators/kubernetes_pod.py | 7 +-
.../databricks/backport_provider_setup.py | 2 +-
.../providers/databricks/hooks/databricks.py | 2 +-
.../databricks/operators/databricks.py | 2 +-
.../datadog/backport_provider_setup.py | 2 +-
.../dingding/backport_provider_setup.py | 2 +-
airflow/providers/dingding/hooks/dingding.py | 2 +-
.../providers/dingding/operators/dingding.py | 3 +-
.../discord/backport_provider_setup.py | 2 +-
.../docker/backport_provider_setup.py | 2 +-
.../elasticsearch/backport_provider_setup.py | 2 +-
.../exasol/backport_provider_setup.py | 2 +-
airflow/providers/exasol/hooks/exasol.py | 2 +-
.../facebook/backport_provider_setup.py | 2 +-
.../providers/ftp/backport_provider_setup.py | 2 +-
airflow/providers/ftp/hooks/ftp.py | 2 +-
.../google/backport_provider_setup.py | 2 +-
.../example_azure_fileshare_to_gcs.py | 2 +-
.../example_dags/example_cloud_memorystore.py | 6 +-
.../example_dags/example_mysql_to_gcs.py | 1 +
.../providers/google/cloud/hooks/bigquery.py | 4 +-
.../google/cloud/hooks/cloud_memorystore.py | 2 +-
.../providers/google/cloud/hooks/cloud_sql.py | 2 +-
.../hooks/cloud_storage_transfer_service.py | 2 +-
.../providers/google/cloud/hooks/datastore.py | 2 +-
airflow/providers/google/cloud/hooks/gcs.py | 2 +-
airflow/providers/google/cloud/hooks/gdm.py | 2 +-
.../providers/google/cloud/hooks/mlengine.py | 2 +-
.../cloud_storage_transfer_service.py | 2 +-
.../providers/google/cloud/operators/dlp.py | 2 +-
.../sensors/cloud_storage_transfer_service.py | 2 +-
.../google/cloud/sensors/dataproc.py | 2 +-
.../cloud/transfers/azure_fileshare_to_gcs.py | 4 +-
.../google/cloud/transfers/presto_to_gcs.py | 2 +-
.../cloud/utils/credentials_provider.py | 2 +-
.../marketing_platform/operators/analytics.py | 2 +-
.../providers/grpc/backport_provider_setup.py | 2 +-
.../hashicorp/backport_provider_setup.py | 2 +-
.../providers/http/backport_provider_setup.py | 2 +-
.../providers/imap/backport_provider_setup.py | 2 +-
.../providers/jdbc/backport_provider_setup.py | 2 +-
.../jenkins/backport_provider_setup.py | 2 +-
.../providers/jira/backport_provider_setup.py | 2 +-
.../azure/backport_provider_setup.py | 2 +-
.../example_dags/example_azure_blob_to_gcs.py | 4 +-
.../example_dags/example_local_to_adls.py | 1 +
.../microsoft/azure/hooks/azure_batch.py | 2 +-
.../microsoft/azure/hooks/azure_fileshare.py | 4 +-
.../microsoft/azure/log/wasb_task_handler.py | 2 +-
.../operators/azure_container_instances.py | 6 +-
.../azure/transfers/azure_blob_to_gcs.py | 2 +-
.../azure/transfers/local_to_adls.py | 3 +-
.../mssql/backport_provider_setup.py | 2 +-
.../winrm/backport_provider_setup.py | 2 +-
.../mongo/backport_provider_setup.py | 2 +-
.../mysql/backport_provider_setup.py | 2 +-
.../providers/odbc/backport_provider_setup.py | 2 +-
airflow/providers/odbc/hooks/odbc.py | 2 +-
.../openfaas/backport_provider_setup.py | 2 +-
.../opsgenie/backport_provider_setup.py | 2 +-
.../opsgenie/hooks/opsgenie_alert.py | 2 +-
.../opsgenie/operators/opsgenie_alert.py | 2 +-
.../oracle/backport_provider_setup.py | 2 +-
airflow/providers/oracle/hooks/oracle.py | 2 +-
.../pagerduty/backport_provider_setup.py | 2 +-
.../plexus/backport_provider_setup.py | 2 +-
airflow/providers/plexus/operators/job.py | 2 +-
.../postgres/backport_provider_setup.py | 2 +-
airflow/providers/postgres/hooks/postgres.py | 2 +-
.../presto/backport_provider_setup.py | 2 +-
airflow/providers/presto/hooks/presto.py | 2 +-
.../qubole/backport_provider_setup.py | 2 +-
airflow/providers/qubole/hooks/qubole.py | 4 +-
.../providers/qubole/hooks/qubole_check.py | 2 +-
.../qubole/operators/qubole_check.py | 2 +-
.../redis/backport_provider_setup.py | 2 +-
.../salesforce/backport_provider_setup.py | 2 +-
.../providers/salesforce/hooks/salesforce.py | 2 +-
airflow/providers/salesforce/hooks/tableau.py | 2 +-
.../samba/backport_provider_setup.py | 2 +-
.../segment/backport_provider_setup.py | 2 +-
.../providers/sftp/backport_provider_setup.py | 2 +-
.../singularity/backport_provider_setup.py | 2 +-
.../slack/backport_provider_setup.py | 2 +-
airflow/providers/slack/operators/slack.py | 2 +-
.../slack/operators/slack_webhook.py | 2 +-
.../snowflake/backport_provider_setup.py | 2 +-
.../providers/snowflake/hooks/snowflake.py | 2 +-
.../sqlite/backport_provider_setup.py | 2 +-
.../providers/ssh/backport_provider_setup.py | 2 +-
airflow/providers/ssh/hooks/ssh.py | 2 +-
.../vertica/backport_provider_setup.py | 2 +-
.../yandex/backport_provider_setup.py | 2 +-
airflow/providers/yandex/hooks/yandex.py | 2 +-
.../zendesk/backport_provider_setup.py | 2 +-
airflow/secrets/base_secrets.py | 3 +-
airflow/secrets/local_filesystem.py | 15 +-
airflow/secrets/metastore.py | 2 +
airflow/security/kerberos.py | 49 +-
airflow/sensors/base_sensor_operator.py | 98 +-
airflow/sensors/bash.py | 16 +-
airflow/sensors/date_time_sensor.py | 8 +-
airflow/sensors/external_task_sensor.py | 103 +-
airflow/sensors/filesystem.py | 5 +-
airflow/sensors/hdfs_sensor.py | 3 +-
airflow/sensors/hive_partition_sensor.py | 3 +-
airflow/sensors/http_sensor.py | 3 +-
airflow/sensors/metastore_partition_sensor.py | 3 +-
.../sensors/named_hive_partition_sensor.py | 3 +-
airflow/sensors/python.py | 14 +-
airflow/sensors/s3_key_sensor.py | 3 +-
airflow/sensors/s3_prefix_sensor.py | 3 +-
airflow/sensors/smart_sensor_operator.py | 169 +-
airflow/sensors/sql_sensor.py | 35 +-
airflow/sensors/web_hdfs_sensor.py | 3 +-
airflow/sensors/weekday_sensor.py | 15 +-
airflow/sentry.py | 21 +-
airflow/serialization/serialized_objects.py | 113 +-
airflow/settings.py | 61 +-
airflow/stats.py | 30 +-
airflow/task/task_runner/base_task_runner.py | 16 +-
.../task/task_runner/cgroup_task_runner.py | 49 +-
airflow/ti_deps/dep_context.py | 21 +-
airflow/ti_deps/dependencies_deps.py | 5 +-
airflow/ti_deps/deps/base_ti_dep.py | 9 +-
.../deps/dag_ti_slots_available_dep.py | 4 +-
airflow/ti_deps/deps/dag_unpaused_dep.py | 3 +-
airflow/ti_deps/deps/dagrun_exists_dep.py | 20 +-
airflow/ti_deps/deps/dagrun_id_dep.py | 7 +-
.../deps/exec_date_after_start_date_dep.py | 13 +-
.../ti_deps/deps/not_in_retry_period_dep.py | 12 +-
.../deps/not_previously_skipped_dep.py | 17 +-
.../ti_deps/deps/pool_slots_available_dep.py | 16 +-
airflow/ti_deps/deps/prev_dagrun_dep.py | 31 +-
airflow/ti_deps/deps/ready_to_reschedule.py | 18 +-
.../ti_deps/deps/runnable_exec_date_dep.py | 18 +-
airflow/ti_deps/deps/task_concurrency_dep.py | 6 +-
airflow/ti_deps/deps/task_not_running_dep.py | 3 +-
airflow/ti_deps/deps/trigger_rule_dep.py | 95 +-
airflow/ti_deps/deps/valid_state_dep.py | 10 +-
airflow/typing_compat.py | 4 +-
airflow/utils/callback_requests.py | 4 +-
airflow/utils/cli.py | 37 +-
airflow/utils/compression.py | 14 +-
airflow/utils/dag_processing.py | 240 ++-
airflow/utils/dates.py | 6 +-
airflow/utils/db.py | 150 +-
airflow/utils/decorators.py | 15 +-
airflow/utils/dot_renderer.py | 30 +-
airflow/utils/email.py | 35 +-
airflow/utils/file.py | 33 +-
airflow/utils/helpers.py | 31 +-
airflow/utils/json.py | 25 +-
airflow/utils/log/cloudwatch_task_handler.py | 3 +-
airflow/utils/log/colored_log.py | 12 +-
airflow/utils/log/es_task_handler.py | 3 +-
airflow/utils/log/file_processor_handler.py | 16 +-
airflow/utils/log/file_task_handler.py | 40 +-
airflow/utils/log/gcs_task_handler.py | 3 +-
airflow/utils/log/json_formatter.py | 3 +-
airflow/utils/log/log_reader.py | 13 +-
airflow/utils/log/logging_mixin.py | 11 +-
airflow/utils/log/s3_task_handler.py | 3 +-
airflow/utils/log/stackdriver_task_handler.py | 3 +-
airflow/utils/log/wasb_task_handler.py | 3 +-
airflow/utils/module_loading.py | 4 +-
airflow/utils/operator_helpers.py | 61 +-
airflow/utils/operator_resources.py | 16 +-
airflow/utils/orm_event_handlers.py | 21 +-
airflow/utils/process_utils.py | 24 +-
airflow/utils/python_virtualenv.py | 10 +-
airflow/utils/serve_logs.py | 6 +-
airflow/utils/session.py | 4 +-
airflow/utils/sqlalchemy.py | 26 +-
airflow/utils/state.py | 41 +-
airflow/utils/task_group.py | 18 +-
airflow/utils/timeout.py | 2 +-
airflow/utils/timezone.py | 13 +-
airflow/utils/weekday.py | 4 +-
airflow/www/api/experimental/endpoints.py | 61 +-
airflow/www/app.py | 6 +-
airflow/www/auth.py | 8 +-
airflow/www/extensions/init_jinja_globals.py | 2 +-
airflow/www/extensions/init_security.py | 5 +-
airflow/www/forms.py | 170 +-
airflow/www/security.py | 10 +-
airflow/www/utils.py | 168 +-
airflow/www/validators.py | 11 +-
airflow/www/views.py | 1472 ++++++++++-------
airflow/www/widgets.py | 5 +-
.../tests/test_celery_kubernetes_executor.py | 1 +
...est_celery_kubernetes_pod_launcher_role.py | 1 +
chart/tests/test_chart_quality.py | 3 +-
.../test_dags_persistent_volume_claim.py | 1 +
chart/tests/test_flower_authorization.py | 1 +
chart/tests/test_git_sync_scheduler.py | 1 +
chart/tests/test_git_sync_webserver.py | 1 +
chart/tests/test_git_sync_worker.py | 1 +
chart/tests/test_migrate_database_job.py | 1 +
chart/tests/test_pod_template_file.py | 3 +-
chart/tests/test_scheduler.py | 1 +
chart/tests/test_worker.py | 1 +
dags/test_dag.py | 10 +-
dev/airflow-github | 55 +-
dev/airflow-license | 26 +-
dev/send_email.py | 117 +-
docs/build_docs.py | 70 +-
docs/conf.py | 57 +-
docs/exts/docroles.py | 28 +-
docs/exts/exampleinclude.py | 10 +-
kubernetes_tests/test_kubernetes_executor.py | 115 +-
.../test_kubernetes_pod_operator.py | 326 ++--
metastore_browser/hive_metastore.py | 47 +-
.../import_all_provider_classes.py | 19 +-
.../prepare_provider_packages.py | 565 ++++---
.../refactor_provider_packages.py | 158 +-
provider_packages/remove_old_releases.py | 34 +-
.../pre_commit_check_order_setup.py | 20 +-
.../pre_commit_check_setup_installation.py | 19 +-
.../ci/pre_commit/pre_commit_yaml_to_cfg.py | 28 +-
.../update_quarantined_test_status.py | 64 +-
scripts/tools/list-integrations.py | 9 +-
setup.py | 92 +-
.../disable_checks_for_tests.py | 18 +-
tests/airflow_pylint/do_not_use_asserts.py | 5 +-
tests/always/test_example_dags.py | 4 +-
tests/always/test_project_structure.py | 53 +-
tests/api/auth/backend/test_basic_auth.py | 58 +-
tests/api/auth/test_client.py | 29 +-
tests/api/client/test_local_client.py | 60 +-
.../common/experimental/test_delete_dag.py | 56 +-
.../common/experimental/test_mark_tasks.py | 299 ++--
tests/api/common/experimental/test_pool.py | 82 +-
.../common/experimental/test_trigger_dag.py | 13 +-
.../endpoints/test_config_endpoint.py | 2 -
.../endpoints/test_task_instance_endpoint.py | 8 +-
.../endpoints/test_variable_endpoint.py | 2 +-
.../endpoints/test_xcom_endpoint.py | 2 +-
.../schemas/test_task_instance_schema.py | 4 +-
tests/build_provider_packages_dependencies.py | 27 +-
tests/cli/commands/test_celery_command.py | 32 +-
.../cli/commands/test_cheat_sheet_command.py | 3 +-
tests/cli/commands/test_config_command.py | 20 +-
tests/cli/commands/test_connection_command.py | 372 +++--
tests/cli/commands/test_dag_command.py | 303 ++--
tests/cli/commands/test_db_command.py | 44 +-
tests/cli/commands/test_info_command.py | 17 +-
tests/cli/commands/test_kubernetes_command.py | 15 +-
tests/cli/commands/test_legacy_commands.py | 44 +-
tests/cli/commands/test_pool_command.py | 20 +-
tests/cli/commands/test_role_command.py | 9 +-
tests/cli/commands/test_sync_perm_command.py | 21 +-
tests/cli/commands/test_task_command.py | 320 ++--
tests/cli/commands/test_user_command.py | 226 ++-
tests/cli/commands/test_variable_command.py | 89 +-
tests/cli/commands/test_webserver_command.py | 57 +-
tests/cli/test_cli_parser.py | 87 +-
tests/cluster_policies/__init__.py | 13 +-
tests/conftest.py | 136 +-
tests/core/test_config_templates.py | 28 +-
tests/core/test_configuration.py | 211 +--
tests/core/test_core.py | 205 +--
tests/core/test_core_to_contrib.py | 12 +-
tests/core/test_example_dags_system.py | 16 +-
tests/core/test_impersonation_tests.py | 66 +-
tests/core/test_local_settings.py | 7 +
tests/core/test_logging_config.py | 35 +-
tests/core/test_sentry.py | 1 +
tests/core/test_sqlalchemy_config.py | 46 +-
tests/core/test_stats.py | 103 +-
tests/dags/subdir2/test_dont_ignore_this.py | 5 +-
tests/dags/test_backfill_pooled_tasks.py | 3 +-
tests/dags/test_clear_subdag.py | 13 +-
tests/dags/test_cli_triggered_dags.py | 19 +-
tests/dags/test_default_impersonation.py | 5 +-
tests/dags/test_default_views.py | 12 +-
tests/dags/test_double_trigger.py | 4 +-
tests/dags/test_example_bash_operator.py | 25 +-
tests/dags/test_heartbeat_failed_fast.py | 5 +-
tests/dags/test_impersonation.py | 5 +-
tests/dags/test_impersonation_subdag.py | 28 +-
tests/dags/test_invalid_cron.py | 10 +-
tests/dags/test_issue_1225.py | 29 +-
tests/dags/test_latest_runs.py | 6 +-
tests/dags/test_logging_in_dag.py | 6 +-
tests/dags/test_mark_success.py | 6 +-
tests/dags/test_missing_owner.py | 4 +-
tests/dags/test_multiple_dags.py | 8 +-
tests/dags/test_no_impersonation.py | 3 +-
tests/dags/test_on_failure_callback.py | 4 +-
tests/dags/test_on_kill.py | 9 +-
tests/dags/test_prev_dagrun_dep.py | 14 +-
tests/dags/test_retry_handling_job.py | 5 +-
tests/dags/test_scheduler_dags.py | 25 +-
tests/dags/test_task_view_type_check.py | 5 +-
tests/dags/test_with_non_default_owner.py | 5 +-
.../test_impersonation_custom.py | 17 +-
tests/dags_with_system_exit/a_system_exit.py | 4 +-
.../b_test_scheduler_dags.py | 9 +-
tests/dags_with_system_exit/c_system_exit.py | 4 +-
tests/deprecated_classes.py | 116 +-
tests/executors/test_base_executor.py | 8 +-
tests/executors/test_celery_executor.py | 134 +-
.../test_celery_kubernetes_executor.py | 14 +-
tests/executors/test_dask_executor.py | 39 +-
tests/executors/test_executor_loader.py | 33 +-
tests/executors/test_kubernetes_executor.py | 124 +-
tests/executors/test_local_executor.py | 18 +-
tests/executors/test_sequential_executor.py | 9 +-
tests/hooks/test_dbapi_hook.py | 50 +-
tests/jobs/test_backfill_job.py | 510 +++---
tests/jobs/test_base_job.py | 5 +-
tests/jobs/test_local_task_job.py | 146 +-
tests/jobs/test_scheduler_job.py | 975 ++++++-----
tests/kubernetes/models/test_secret.py | 167 +-
tests/kubernetes/test_client.py | 1 -
tests/kubernetes/test_pod_generator.py | 377 ++---
tests/kubernetes/test_pod_launcher.py | 162 +-
tests/kubernetes/test_refresh_config.py | 1 -
tests/lineage/test_lineage.py | 43 +-
tests/models/test_baseoperator.py | 77 +-
tests/models/test_cleartasks.py | 61 +-
tests/models/test_connection.py | 97 +-
tests/models/test_dag.py | 575 +++----
tests/models/test_dagbag.py | 168 +-
tests/models/test_dagcode.py | 20 +-
tests/models/test_dagparam.py | 4 +-
tests/models/test_dagrun.py | 301 ++--
tests/models/test_pool.py | 39 +-
tests/models/test_renderedtifields.py | 135 +-
tests/models/test_sensorinstance.py | 14 +-
tests/models/test_serialized_dag.py | 10 +-
tests/models/test_skipmixin.py | 18 +-
tests/models/test_taskinstance.py | 666 ++++----
tests/models/test_timestamp.py | 3 +-
tests/models/test_variable.py | 15 +-
tests/models/test_xcom.py | 146 +-
tests/models/test_xcom_arg.py | 16 +-
tests/operators/test_bash.py | 83 +-
tests/operators/test_branch_operator.py | 24 +-
tests/operators/test_dagrun_operator.py | 24 +-
tests/operators/test_email.py | 15 +-
tests/operators/test_generic_transfer.py | 24 +-
tests/operators/test_latest_only_operator.py | 162 +-
tests/operators/test_python.py | 360 ++--
tests/operators/test_sql.py | 102 +-
tests/operators/test_subdag_operator.py | 76 +-
tests/plugins/test_plugin.py | 33 +-
tests/plugins/test_plugins_manager.py | 55 +-
.../amazon/aws/hooks/test_glacier.py | 2 +-
tests/providers/amazon/aws/hooks/test_glue.py | 1 -
.../amazon/aws/hooks/test_sagemaker.py | 2 +-
.../amazon/aws/hooks/test_secrets_manager.py | 2 +-
.../amazon/aws/operators/test_athena.py | 1 -
.../amazon/aws/operators/test_batch.py | 1 -
.../amazon/aws/operators/test_ecs.py | 2 +-
.../amazon/aws/operators/test_glacier.py | 8 +-
.../aws/operators/test_glacier_system.py | 1 -
.../amazon/aws/operators/test_glue.py | 1 -
.../amazon/aws/operators/test_s3_bucket.py | 2 +-
.../amazon/aws/operators/test_s3_list.py | 1 -
.../aws/operators/test_sagemaker_endpoint.py | 2 +-
.../test_sagemaker_endpoint_config.py | 1 -
.../aws/operators/test_sagemaker_model.py | 1 -
.../operators/test_sagemaker_processing.py | 2 +-
.../aws/operators/test_sagemaker_training.py | 1 -
.../aws/operators/test_sagemaker_transform.py | 1 -
.../aws/operators/test_sagemaker_tuning.py | 1 -
.../amazon/aws/sensors/test_athena.py | 1 -
.../amazon/aws/sensors/test_glacier.py | 1 -
.../providers/amazon/aws/sensors/test_glue.py | 1 -
.../sensors/test_glue_catalog_partition.py | 1 -
.../aws/sensors/test_sagemaker_endpoint.py | 1 -
.../aws/sensors/test_sagemaker_training.py | 1 -
.../aws/sensors/test_sagemaker_transform.py | 1 -
.../aws/sensors/test_sagemaker_tuning.py | 1 -
.../amazon/aws/transfers/test_gcs_to_s3.py | 1 -
.../aws/transfers/test_glacier_to_gcs.py | 4 +-
.../amazon/aws/transfers/test_mongo_to_s3.py | 1 -
.../apache/cassandra/hooks/test_cassandra.py | 2 +-
.../druid/operators/test_druid_check.py | 1 -
.../apache/hive/transfers/test_s3_to_hive.py | 1 -
tests/providers/apache/pig/hooks/test_pig.py | 1 -
.../apache/pig/operators/test_pig.py | 1 -
.../spark/hooks/test_spark_jdbc_script.py | 1 +
.../providers/cloudant/hooks/test_cloudant.py | 1 -
.../cncf/kubernetes/hooks/test_kubernetes.py | 3 +-
.../databricks/hooks/test_databricks.py | 2 +-
.../databricks/operators/test_databricks.py | 1 -
.../providers/datadog/sensors/test_datadog.py | 1 -
tests/providers/docker/hooks/test_docker.py | 1 -
.../providers/docker/operators/test_docker.py | 1 -
.../docker/operators/test_docker_swarm.py | 2 +-
tests/providers/exasol/hooks/test_exasol.py | 1 -
.../providers/exasol/operators/test_exasol.py | 1 -
.../providers/facebook/ads/hooks/test_ads.py | 1 +
tests/providers/google/ads/hooks/test_ads.py | 1 +
.../google/cloud/hooks/test_automl.py | 2 +-
.../google/cloud/hooks/test_bigquery_dts.py | 2 +-
.../google/cloud/hooks/test_cloud_build.py | 1 -
.../test_cloud_storage_transfer_service.py | 2 +-
.../google/cloud/hooks/test_compute.py | 1 -
.../google/cloud/hooks/test_dataflow.py | 2 +-
.../google/cloud/hooks/test_dataproc.py | 2 +-
.../google/cloud/hooks/test_datastore.py | 3 +-
.../providers/google/cloud/hooks/test_dlp.py | 2 +-
.../google/cloud/hooks/test_functions.py | 1 -
.../providers/google/cloud/hooks/test_kms.py | 1 -
.../cloud/hooks/test_kubernetes_engine.py | 2 +-
.../google/cloud/hooks/test_life_sciences.py | 1 -
.../cloud/hooks/test_natural_language.py | 2 +-
.../google/cloud/hooks/test_pubsub.py | 2 +-
.../google/cloud/hooks/test_spanner.py | 1 -
.../google/cloud/hooks/test_speech_to_text.py | 1 -
.../google/cloud/hooks/test_stackdriver.py | 2 +-
.../google/cloud/hooks/test_tasks.py | 2 +-
.../google/cloud/hooks/test_text_to_speech.py | 1 -
.../google/cloud/hooks/test_translate.py | 1 -
.../cloud/hooks/test_video_intelligence.py | 2 +-
.../google/cloud/hooks/test_vision.py | 2 +-
.../google/cloud/operators/test_automl.py | 2 +-
.../google/cloud/operators/test_bigquery.py | 2 +-
.../cloud/operators/test_bigquery_dts.py | 1 -
.../cloud/operators/test_cloud_build.py | 3 +-
.../cloud/operators/test_cloud_memorystore.py | 6 +-
.../google/cloud/operators/test_cloud_sql.py | 2 +-
.../test_cloud_storage_transfer_service.py | 2 +-
.../google/cloud/operators/test_dataflow.py | 3 +-
.../cloud/operators/test_dataproc_system.py | 2 +-
.../google/cloud/operators/test_dlp.py | 1 -
.../google/cloud/operators/test_dlp_system.py | 2 +-
.../google/cloud/operators/test_functions.py | 2 +-
.../google/cloud/operators/test_gcs.py | 1 -
.../cloud/operators/test_kubernetes_engine.py | 2 +-
.../cloud/operators/test_life_sciences.py | 1 -
.../cloud/operators/test_mlengine_utils.py | 3 +-
.../google/cloud/operators/test_pubsub.py | 2 +-
.../google/cloud/operators/test_spanner.py | 2 +-
.../cloud/operators/test_speech_to_text.py | 1 -
.../cloud/operators/test_stackdriver.py | 2 +-
.../google/cloud/operators/test_tasks.py | 2 +-
.../cloud/operators/test_text_to_speech.py | 2 +-
.../google/cloud/operators/test_translate.py | 1 -
.../cloud/operators/test_translate_speech.py | 2 +-
.../operators/test_video_intelligence.py | 2 +-
.../google/cloud/operators/test_vision.py | 2 +-
.../google/cloud/sensors/test_bigquery_dts.py | 1 -
.../test_cloud_storage_transfer_service.py | 2 +-
.../google/cloud/sensors/test_dataproc.py | 2 +-
.../google/cloud/sensors/test_pubsub.py | 2 +-
.../cloud/transfers/test_adls_to_gcs.py | 1 -
.../transfers/test_azure_fileshare_to_gcs.py | 1 -
.../test_azure_fileshare_to_gcs_system.py | 9 +-
.../transfers/test_bigquery_to_bigquery.py | 1 -
.../cloud/transfers/test_bigquery_to_gcs.py | 1 -
.../cloud/transfers/test_bigquery_to_mysql.py | 1 -
.../cloud/transfers/test_cassandra_to_gcs.py | 1 -
.../cloud/transfers/test_gcs_to_bigquery.py | 1 -
.../google/cloud/transfers/test_gcs_to_gcs.py | 1 -
.../cloud/transfers/test_gcs_to_local.py | 1 -
.../transfers/test_gcs_to_local_system.py | 2 +-
.../cloud/transfers/test_gcs_to_sftp.py | 1 -
.../cloud/transfers/test_local_to_gcs.py | 2 +-
.../cloud/transfers/test_mssql_to_gcs.py | 1 -
.../cloud/transfers/test_mysql_to_gcs.py | 2 +-
.../transfers/test_mysql_to_gcs_system.py | 4 +-
.../google/cloud/transfers/test_s3_to_gcs.py | 1 -
.../cloud/transfers/test_s3_to_gcs_system.py | 3 +-
.../cloud/transfers/test_salesforce_to_gcs.py | 1 -
.../test_salesforce_to_gcs_system.py | 1 +
.../cloud/transfers/test_sftp_to_gcs.py | 1 -
.../google/cloud/transfers/test_sql_to_gcs.py | 2 +-
.../cloud/utils/test_credentials_provider.py | 2 +-
.../google/firebase/hooks/test_firestore.py | 1 -
.../google/suite/hooks/test_drive.py | 1 -
.../google/suite/hooks/test_sheets.py | 1 -
tests/providers/grpc/hooks/test_grpc.py | 1 -
tests/providers/grpc/operators/test_grpc.py | 1 -
tests/providers/http/hooks/test_http.py | 2 +-
tests/providers/http/operators/test_http.py | 2 +-
tests/providers/http/sensors/test_http.py | 2 +-
.../microsoft/azure/hooks/test_azure_batch.py | 2 +-
.../azure/hooks/test_azure_cosmos.py | 2 +-
.../azure/hooks/test_azure_data_lake.py | 1 -
.../azure/hooks/test_azure_fileshare.py | 4 +-
.../microsoft/azure/hooks/test_wasb.py | 1 -
.../azure/operators/test_adls_list.py | 1 -
.../azure/operators/test_azure_batch.py | 1 -
.../test_azure_container_instances.py | 2 +-
.../azure/operators/test_azure_cosmos.py | 1 -
.../azure/operators/test_wasb_delete_blob.py | 1 -
.../microsoft/azure/sensors/test_wasb.py | 1 -
.../azure/transfers/test_azure_blob_to_gcs.py | 5 +-
.../azure/transfers/test_file_to_wasb.py | 1 -
.../azure/transfers/test_local_to_adls.py | 1 -
.../test_oracle_to_azure_data_lake.py | 2 +-
.../microsoft/mssql/hooks/test_mssql.py | 1 -
.../microsoft/mssql/operators/test_mssql.py | 1 -
.../providers/openfaas/hooks/test_openfaas.py | 2 +-
tests/providers/oracle/hooks/test_oracle.py | 2 +-
.../providers/oracle/operators/test_oracle.py | 1 -
.../oracle/transfers/test_oracle_to_oracle.py | 1 -
.../pagerduty/hooks/test_pagerduty.py | 1 +
tests/providers/plexus/hooks/test_plexus.py | 1 +
tests/providers/plexus/operators/test_job.py | 2 +
.../qubole/operators/test_qubole_check.py | 2 +-
tests/providers/samba/hooks/test_samba.py | 2 +-
.../providers/sendgrid/utils/test_emailer.py | 1 -
.../singularity/operators/test_singularity.py | 2 +-
tests/providers/slack/hooks/test_slack.py | 2 +-
tests/providers/slack/operators/test_slack.py | 1 -
.../snowflake/operators/test_snowflake.py | 1 -
tests/providers/ssh/hooks/test_ssh.py | 6 +-
tests/secrets/test_local_filesystem.py | 55 +-
tests/secrets/test_secrets.py | 61 +-
tests/secrets/test_secrets_backends.py | 30 +-
tests/security/test_kerberos.py | 17 +-
tests/sensors/test_base_sensor.py | 162 +-
tests/sensors/test_bash.py | 9 +-
tests/sensors/test_date_time_sensor.py | 10 +-
tests/sensors/test_external_task_sensor.py | 309 ++--
tests/sensors/test_filesystem.py | 36 +-
tests/sensors/test_python.py | 54 +-
tests/sensors/test_smart_sensor_operator.py | 55 +-
tests/sensors/test_sql_sensor.py | 45 +-
tests/sensors/test_time_sensor.py | 4 +-
tests/sensors/test_timedelta_sensor.py | 8 +-
tests/sensors/test_timeout_sensor.py | 15 +-
tests/sensors/test_weekday_sensor.py | 61 +-
tests/serialization/test_dag_serialization.py | 472 +++---
.../task_runner/test_cgroup_task_runner.py | 1 -
.../task_runner/test_standard_task_runner.py | 40 +-
tests/task/task_runner/test_task_runner.py | 11 +-
tests/test_utils/amazon_system_helpers.py | 38 +-
tests/test_utils/api_connexion_utils.py | 2 +-
tests/test_utils/asserts.py | 20 +-
tests/test_utils/azure_system_helpers.py | 1 -
tests/test_utils/db.py | 16 +-
tests/test_utils/gcp_system_helpers.py | 62 +-
tests/test_utils/hdfs_utils.py | 245 +--
tests/test_utils/logging_command_executor.py | 17 +-
tests/test_utils/mock_hooks.py | 6 +-
tests/test_utils/mock_operators.py | 18 +-
tests/test_utils/mock_process.py | 12 +-
tests/test_utils/perf/dags/elastic_dag.py | 27 +-
tests/test_utils/perf/dags/perf_dag_1.py | 13 +-
tests/test_utils/perf/dags/perf_dag_2.py | 13 +-
tests/test_utils/perf/perf_kit/memory.py | 1 +
.../perf/perf_kit/repeat_and_time.py | 2 +
tests/test_utils/perf/perf_kit/sqlalchemy.py | 75 +-
.../perf/scheduler_dag_execution_timing.py | 53 +-
.../test_utils/perf/scheduler_ops_metrics.py | 101 +-
tests/test_utils/perf/sql_queries.py | 6 +-
.../remote_user_api_auth_backend.py | 6 +-
.../test_remote_user_api_auth_backend.py | 4 +-
tests/ti_deps/deps/fake_models.py | 4 -
.../deps/test_dag_ti_slots_available_dep.py | 1 -
tests/ti_deps/deps/test_dag_unpaused_dep.py | 1 -
tests/ti_deps/deps/test_dagrun_exists_dep.py | 1 -
.../deps/test_not_in_retry_period_dep.py | 10 +-
.../deps/test_not_previously_skipped_dep.py | 16 +-
tests/ti_deps/deps/test_prev_dagrun_dep.py | 81 +-
.../deps/test_ready_to_reschedule_dep.py | 18 +-
.../deps/test_runnable_exec_date_dep.py | 24 +-
tests/ti_deps/deps/test_task_concurrency.py | 1 -
.../ti_deps/deps/test_task_not_running_dep.py | 1 -
tests/ti_deps/deps/test_trigger_rule_dep.py | 656 ++++----
tests/ti_deps/deps/test_valid_state_dep.py | 1 -
.../utils/log/test_file_processor_handler.py | 14 +-
tests/utils/log/test_json_formatter.py | 6 +-
tests/utils/log/test_log_reader.py | 46 +-
tests/utils/test_cli_util.py | 26 +-
tests/utils/test_compression.py | 31 +-
tests/utils/test_dag_cycle.py | 40 +-
tests/utils/test_dag_processing.py | 89 +-
tests/utils/test_dates.py | 25 +-
tests/utils/test_db.py | 70 +-
tests/utils/test_decorators.py | 4 +-
tests/utils/test_docs.py | 14 +-
tests/utils/test_email.py | 50 +-
tests/utils/test_helpers.py | 44 +-
tests/utils/test_json.py | 46 +-
tests/utils/test_log_handlers.py | 30 +-
tests/utils/test_logging_mixin.py | 12 +-
tests/utils/test_net.py | 3 +-
tests/utils/test_operator_helpers.py | 24 +-
tests/utils/test_process_utils.py | 15 +-
tests/utils/test_python_virtualenv.py | 21 +-
tests/utils/test_sqlalchemy.py | 64 +-
...test_task_handler_with_custom_formatter.py | 5 +-
tests/utils/test_timezone.py | 9 +-
tests/utils/test_trigger_rule.py | 1 -
tests/utils/test_weight_rule.py | 1 -
.../experimental/test_dag_runs_endpoint.py | 10 +-
tests/www/api/experimental/test_endpoints.py | 166 +-
.../experimental/test_kerberos_endpoints.py | 16 +-
tests/www/test_app.py | 90 +-
tests/www/test_security.py | 33 +-
tests/www/test_utils.py | 78 +-
tests/www/test_validators.py | 5 +-
tests/www/test_views.py | 1095 ++++++------
1068 files changed, 15410 insertions(+), 15132 deletions(-)
diff --git a/airflow/__init__.py b/airflow/__init__.py
index 486bd5fbbe148..35e755ece0e04 100644
--- a/airflow/__init__.py
+++ b/airflow/__init__.py
@@ -55,15 +55,18 @@ def __getattr__(name):
# PEP-562: Lazy loaded attributes on python modules
if name == "DAG":
from airflow.models.dag import DAG # pylint: disable=redefined-outer-name
+
return DAG
if name == "AirflowException":
from airflow.exceptions import AirflowException # pylint: disable=redefined-outer-name
+
return AirflowException
raise AttributeError(f"module {__name__} has no attribute {name}")
if not settings.LAZY_LOAD_PLUGINS:
from airflow import plugins_manager
+
plugins_manager.ensure_plugins_loaded()
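
In the hunk above, Black's only change is to insert a blank line after each locally scoped import; the surrounding structure is the module-level `__getattr__` that the comment refers to ("PEP-562: Lazy loaded attributes on python modules"). A minimal, self-contained sketch of that pattern, using an illustrative package name rather than Airflow's:

    # mypkg/__init__.py -- PEP 562 lazy module attribute, minimal sketch
    def __getattr__(name):
        if name == "HeavyClient":
            # imported only on first attribute access, keeping `import mypkg` cheap
            from mypkg._heavy import HeavyClient

            return HeavyClient
        raise AttributeError(f"module {__name__} has no attribute {name}")
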
diff --git a/airflow/api/__init__.py b/airflow/api/__init__.py
index 63dbfcff93283..c869252858a5b 100644
--- a/airflow/api/__init__.py
+++ b/airflow/api/__init__.py
@@ -38,8 +38,5 @@ def load_auth():
log.info("Loaded API auth backend: %s", auth_backend)
return auth_backend
except ImportError as err:
- log.critical(
- "Cannot import %s for API authentication due to: %s",
- auth_backend, err
- )
+ log.critical("Cannot import %s for API authentication due to: %s", auth_backend, err)
raise AirflowException(err)
diff --git a/airflow/api/auth/backend/basic_auth.py b/airflow/api/auth/backend/basic_auth.py
index bd42708661f6e..9b0e32087b6ba 100644
--- a/airflow/api/auth/backend/basic_auth.py
+++ b/airflow/api/auth/backend/basic_auth.py
@@ -53,13 +53,12 @@ def auth_current_user() -> Optional[User]:
def requires_authentication(function: T):
"""Decorator for functions that require authentication"""
+
@wraps(function)
def decorated(*args, **kwargs):
if auth_current_user() is not None:
return function(*args, **kwargs)
else:
- return Response(
- "Unauthorized", 401, {"WWW-Authenticate": "Basic"}
- )
+ return Response("Unauthorized", 401, {"WWW-Authenticate": "Basic"})
return cast(T, decorated)
diff --git a/airflow/api/auth/backend/default.py b/airflow/api/auth/backend/default.py
index 70ae82b2d6c60..bfa96dfb0e9f7 100644
--- a/airflow/api/auth/backend/default.py
+++ b/airflow/api/auth/backend/default.py
@@ -33,6 +33,7 @@ def init_app(_):
def requires_authentication(function: T):
"""Decorator for functions that require authentication"""
+
@wraps(function)
def decorated(*args, **kwargs):
return function(*args, **kwargs)
diff --git a/airflow/api/auth/backend/kerberos_auth.py b/airflow/api/auth/backend/kerberos_auth.py
index 41042dec32c01..08a0fecace7cd 100644
--- a/airflow/api/auth/backend/kerberos_auth.py
+++ b/airflow/api/auth/backend/kerberos_auth.py
@@ -132,6 +132,7 @@ def _gssapi_authenticate(token):
def requires_authentication(function: T):
"""Decorator for functions that require authentication with Kerberos"""
+
@wraps(function)
def decorated(*args, **kwargs):
header = request.headers.get("Authorization")
@@ -144,11 +145,11 @@ def decorated(*args, **kwargs):
response = function(*args, **kwargs)
response = make_response(response)
if ctx.kerberos_token is not None:
- response.headers['WWW-Authenticate'] = ' '.join(['negotiate',
- ctx.kerberos_token])
+ response.headers['WWW-Authenticate'] = ' '.join(['negotiate', ctx.kerberos_token])
return response
if return_code != kerberos.AUTH_GSS_CONTINUE:
return _forbidden()
return _unauthorized()
+
return cast(T, decorated)
diff --git a/airflow/api/client/__init__.py b/airflow/api/client/__init__.py
index 7431dfa4cfc34..53c85062e63e9 100644
--- a/airflow/api/client/__init__.py
+++ b/airflow/api/client/__init__.py
@@ -35,6 +35,6 @@ def get_current_api_client() -> Client:
api_client = api_module.Client(
api_base_url=conf.get('cli', 'endpoint_url'),
auth=getattr(auth_backend, 'CLIENT_AUTH', None),
- session=session
+ session=session,
)
return api_client
diff --git a/airflow/api/client/json_client.py b/airflow/api/client/json_client.py
index c17307f3f1508..1ffe7fd88cd5b 100644
--- a/airflow/api/client/json_client.py
+++ b/airflow/api/client/json_client.py
@@ -45,12 +45,15 @@ def _request(self, url, method='GET', json=None):
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
endpoint = f'/api/experimental/dags/{dag_id}/dag_runs'
url = urljoin(self._api_base_url, endpoint)
- data = self._request(url, method='POST',
- json={
- "run_id": run_id,
- "conf": conf,
- "execution_date": execution_date,
- })
+ data = self._request(
+ url,
+ method='POST',
+ json={
+ "run_id": run_id,
+ "conf": conf,
+ "execution_date": execution_date,
+ },
+ )
return data['message']
def delete_dag(self, dag_id):
@@ -74,12 +77,15 @@ def get_pools(self):
def create_pool(self, name, slots, description):
endpoint = '/api/experimental/pools'
url = urljoin(self._api_base_url, endpoint)
- pool = self._request(url, method='POST',
- json={
- 'name': name,
- 'slots': slots,
- 'description': description,
- })
+ pool = self._request(
+ url,
+ method='POST',
+ json={
+ 'name': name,
+ 'slots': slots,
+ 'description': description,
+ },
+ )
return pool['pool'], pool['slots'], pool['description']
def delete_pool(self, name):
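
Both hunks in json_client.py follow the same pattern: the dict passed as `json=` already ends with a trailing comma, and with the pinned release above (20.8b1) Black treats that "magic trailing comma" as a request to keep the literal expanded one element per line while re-wrapping the enclosing `self._request(...)` call. A small self-contained illustration of this behavior, using made-up values rather than code from this patch:

    # without a trailing comma, Black may join the literal onto one line if it fits
    payload = {"name": "default_pool", "slots": 128}

    # with a trailing comma, Black keeps one element per line
    payload = {
        "name": "default_pool",
        "slots": 128,
    }
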
diff --git a/airflow/api/client/local_client.py b/airflow/api/client/local_client.py
index 5c08f1ab3eb13..7ce0d1655da6e 100644
--- a/airflow/api/client/local_client.py
+++ b/airflow/api/client/local_client.py
@@ -26,10 +26,9 @@ class Client(api_client.Client):
"""Local API client implementation."""
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
- dag_run = trigger_dag.trigger_dag(dag_id=dag_id,
- run_id=run_id,
- conf=conf,
- execution_date=execution_date)
+ dag_run = trigger_dag.trigger_dag(
+ dag_id=dag_id, run_id=run_id, conf=conf, execution_date=execution_date
+ )
return f"Created {dag_run}"
def delete_dag(self, dag_id):
diff --git a/airflow/api/common/experimental/__init__.py b/airflow/api/common/experimental/__init__.py
index ebaea5e658322..b161e04346358 100644
--- a/airflow/api/common/experimental/__init__.py
+++ b/airflow/api/common/experimental/__init__.py
@@ -29,10 +29,7 @@ def check_and_get_dag(dag_id: str, task_id: Optional[str] = None) -> DagModel:
if dag_model is None:
raise DagNotFound(f"Dag id {dag_id} not found in DagModel")
- dagbag = DagBag(
- dag_folder=dag_model.fileloc,
- read_dags_from_db=True
- )
+ dagbag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
dag = dagbag.get_dag(dag_id)
if not dag:
error_message = f"Dag id {dag_id} not found"
@@ -47,7 +44,6 @@ def check_and_get_dagrun(dag: DagModel, execution_date: datetime) -> DagRun:
"""Get DagRun object and check that it exists"""
dagrun = dag.get_dagrun(execution_date=execution_date)
if not dagrun:
- error_message = ('Dag Run for date {} not found in dag {}'
- .format(execution_date, dag.dag_id))
+ error_message = f'Dag Run for date {execution_date} not found in dag {dag.dag_id}'
raise DagRunNotFound(error_message)
return dagrun
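
The switch from str.format to an f-string in check_and_get_dagrun is not something Black performs; it reads like a small manual cleanup folded into the same commit. The two spellings render identically, as a quick check with hypothetical values shows:

# Hypothetical values, used only to compare the two spellings.
execution_date, dag_id = "2020-10-07T00:00:00+00:00", "example_dag"
old = 'Dag Run for date {} not found in dag {}'.format(execution_date, dag_id)
new = f'Dag Run for date {execution_date} not found in dag {dag_id}'
assert old == new
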
diff --git a/airflow/api/common/experimental/delete_dag.py b/airflow/api/common/experimental/delete_dag.py
index ee50c7d60090e..d27c21f18b2fa 100644
--- a/airflow/api/common/experimental/delete_dag.py
+++ b/airflow/api/common/experimental/delete_dag.py
@@ -60,13 +60,14 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> i
if dag.is_subdag:
parent_dag_id, task_id = dag_id.rsplit(".", 1)
for model in TaskFail, models.TaskInstance:
- count += session.query(model).filter(model.dag_id == parent_dag_id,
- model.task_id == task_id).delete()
+ count += (
+ session.query(model).filter(model.dag_id == parent_dag_id, model.task_id == task_id).delete()
+ )
# Delete entries in Import Errors table for a deleted DAG
# This handles the case when the dag_id is changed in the file
- session.query(models.ImportError).filter(
- models.ImportError.filename == dag.fileloc
- ).delete(synchronize_session='fetch')
+ session.query(models.ImportError).filter(models.ImportError.filename == dag.fileloc).delete(
+ synchronize_session='fetch'
+ )
return count
diff --git a/airflow/api/common/experimental/get_dag_runs.py b/airflow/api/common/experimental/get_dag_runs.py
index 8b85c9919ac8d..cd939676bb3ac 100644
--- a/airflow/api/common/experimental/get_dag_runs.py
+++ b/airflow/api/common/experimental/get_dag_runs.py
@@ -38,16 +38,16 @@ def get_dag_runs(dag_id: str, state: Optional[str] = None) -> List[Dict[str, Any
dag_runs = []
state = state.lower() if state else None
for run in DagRun.find(dag_id=dag_id, state=state):
- dag_runs.append({
- 'id': run.id,
- 'run_id': run.run_id,
- 'state': run.state,
- 'dag_id': run.dag_id,
- 'execution_date': run.execution_date.isoformat(),
- 'start_date': ((run.start_date or '') and
- run.start_date.isoformat()),
- 'dag_run_url': url_for('Airflow.graph', dag_id=run.dag_id,
- execution_date=run.execution_date)
- })
+ dag_runs.append(
+ {
+ 'id': run.id,
+ 'run_id': run.run_id,
+ 'state': run.state,
+ 'dag_id': run.dag_id,
+ 'execution_date': run.execution_date.isoformat(),
+ 'start_date': ((run.start_date or '') and run.start_date.isoformat()),
+ 'dag_run_url': url_for('Airflow.graph', dag_id=run.dag_id, execution_date=run.execution_date),
+ }
+ )
return dag_runs
diff --git a/airflow/api/common/experimental/get_lineage.py b/airflow/api/common/experimental/get_lineage.py
index 2d0e97dcdaf8d..c6857a65fb279 100644
--- a/airflow/api/common/experimental/get_lineage.py
+++ b/airflow/api/common/experimental/get_lineage.py
@@ -31,10 +31,12 @@ def get_lineage(dag_id: str, execution_date: datetime.datetime, session=None) ->
dag = check_and_get_dag(dag_id)
check_and_get_dagrun(dag, execution_date)
- inlets: List[XCom] = XCom.get_many(dag_ids=dag_id, execution_date=execution_date,
- key=PIPELINE_INLETS, session=session).all()
- outlets: List[XCom] = XCom.get_many(dag_ids=dag_id, execution_date=execution_date,
- key=PIPELINE_OUTLETS, session=session).all()
+ inlets: List[XCom] = XCom.get_many(
+ dag_ids=dag_id, execution_date=execution_date, key=PIPELINE_INLETS, session=session
+ ).all()
+ outlets: List[XCom] = XCom.get_many(
+ dag_ids=dag_id, execution_date=execution_date, key=PIPELINE_OUTLETS, session=session
+ ).all()
lineage: Dict[str, Dict[str, Any]] = {}
for meta in inlets:
diff --git a/airflow/api/common/experimental/get_task_instance.py b/airflow/api/common/experimental/get_task_instance.py
index 0b1dc8e18d8a0..07f41cbc400c1 100644
--- a/airflow/api/common/experimental/get_task_instance.py
+++ b/airflow/api/common/experimental/get_task_instance.py
@@ -31,8 +31,7 @@ def get_task_instance(dag_id: str, task_id: str, execution_date: datetime) -> Ta
# Get task instance object and check that it exists
task_instance = dagrun.get_task_instance(task_id)
if not task_instance:
- error_message = ('Task {} instance for date {} not found'
- .format(task_id, execution_date))
+ error_message = f'Task {task_id} instance for date {execution_date} not found'
raise TaskInstanceNotFound(error_message)
return task_instance
diff --git a/airflow/api/common/experimental/mark_tasks.py b/airflow/api/common/experimental/mark_tasks.py
index 3f1cb8363e486..852038dfe88a9 100644
--- a/airflow/api/common/experimental/mark_tasks.py
+++ b/airflow/api/common/experimental/mark_tasks.py
@@ -69,7 +69,7 @@ def set_state(
past: bool = False,
state: str = State.SUCCESS,
commit: bool = False,
- session=None
+ session=None,
): # pylint: disable=too-many-arguments,too-many-locals
"""
Set the state of a task instance and if needed its relatives. Can set state
@@ -137,33 +137,24 @@ def set_state(
# Flake and pylint disagree about correct indents here
def all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates): # noqa: E123
"""Get *all* tasks of the sub dags"""
- qry_sub_dag = session.query(TaskInstance). \
- filter(
- TaskInstance.dag_id.in_(sub_dag_run_ids),
- TaskInstance.execution_date.in_(confirmed_dates)
- ). \
- filter(
- or_(
- TaskInstance.state.is_(None),
- TaskInstance.state != state
- )
+ qry_sub_dag = (
+ session.query(TaskInstance)
+ .filter(TaskInstance.dag_id.in_(sub_dag_run_ids), TaskInstance.execution_date.in_(confirmed_dates))
+ .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
) # noqa: E123
return qry_sub_dag
def get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates):
"""Get all tasks of the main dag that will be affected by a state change"""
- qry_dag = session.query(TaskInstance). \
- filter(
- TaskInstance.dag_id == dag.dag_id,
- TaskInstance.execution_date.in_(confirmed_dates),
- TaskInstance.task_id.in_(task_ids) # noqa: E123
- ). \
- filter(
- or_(
- TaskInstance.state.is_(None),
- TaskInstance.state != state
+ qry_dag = (
+ session.query(TaskInstance)
+ .filter(
+ TaskInstance.dag_id == dag.dag_id,
+ TaskInstance.execution_date.in_(confirmed_dates),
+ TaskInstance.task_id.in_(task_ids), # noqa: E123
)
+ .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
)
return qry_dag
@@ -186,10 +177,12 @@ def get_subdag_runs(dag, session, state, task_ids, commit, confirmed_dates):
# this works as a kind of integrity check
# it creates missing dag runs for subdag operators,
# maybe this should be moved to dagrun.verify_integrity
- dag_runs = _create_dagruns(current_task.subdag,
- execution_dates=confirmed_dates,
- state=State.RUNNING,
- run_type=DagRunType.BACKFILL_JOB)
+ dag_runs = _create_dagruns(
+ current_task.subdag,
+ execution_dates=confirmed_dates,
+ state=State.RUNNING,
+ run_type=DagRunType.BACKFILL_JOB,
+ )
verify_dagruns(dag_runs, commit, state, session, current_task)
@@ -279,10 +272,9 @@ def _set_dag_run_state(dag_id, execution_date, state, session=None):
:param state: target state
:param session: database session
"""
- dag_run = session.query(DagRun).filter(
- DagRun.dag_id == dag_id,
- DagRun.execution_date == execution_date
- ).one()
+ dag_run = (
+ session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date).one()
+ )
dag_run.state = state
if state == State.RUNNING:
dag_run.start_date = timezone.utcnow()
@@ -316,8 +308,9 @@ def set_dag_run_state_to_success(dag, execution_date, commit=False, session=None
# Mark all task instances of the dag run to success.
for task in dag.tasks:
task.dag = dag
- return set_state(tasks=dag.tasks, execution_date=execution_date,
- state=State.SUCCESS, commit=commit, session=session)
+ return set_state(
+ tasks=dag.tasks, execution_date=execution_date, state=State.SUCCESS, commit=commit, session=session
+ )
@provide_session
@@ -343,10 +336,15 @@ def set_dag_run_state_to_failed(dag, execution_date, commit=False, session=None)
# Mark only RUNNING task instances.
task_ids = [task.task_id for task in dag.tasks]
- tis = session.query(TaskInstance).filter(
- TaskInstance.dag_id == dag.dag_id,
- TaskInstance.execution_date == execution_date,
- TaskInstance.task_id.in_(task_ids)).filter(TaskInstance.state == State.RUNNING)
+ tis = (
+ session.query(TaskInstance)
+ .filter(
+ TaskInstance.dag_id == dag.dag_id,
+ TaskInstance.execution_date == execution_date,
+ TaskInstance.task_id.in_(task_ids),
+ )
+ .filter(TaskInstance.state == State.RUNNING)
+ )
task_ids_of_running_tis = [task_instance.task_id for task_instance in tis]
tasks = []
@@ -356,8 +354,9 @@ def set_dag_run_state_to_failed(dag, execution_date, commit=False, session=None)
task.dag = dag
tasks.append(task)
- return set_state(tasks=tasks, execution_date=execution_date,
- state=State.FAILED, commit=commit, session=session)
+ return set_state(
+ tasks=tasks, execution_date=execution_date, state=State.FAILED, commit=commit, session=session
+ )
@provide_session
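
Throughout mark_tasks.py Black wraps the long SQLAlchemy chains in parentheses and puts each .filter() on its own line; the expressions are unchanged, only the layout moves. A self-contained toy stand-in (no SQLAlchemy required, the clause strings are made up) showing the same chained style:

class Query:
    """Tiny fluent builder that mimics the shape of a SQLAlchemy query."""

    def __init__(self, clauses=()):
        self._clauses = tuple(clauses)

    def filter(self, *clauses):
        # Each call returns a new Query, so calls can be chained.
        return Query(self._clauses + clauses)

    def all(self):
        return list(self._clauses)


tis = (
    Query()
    .filter("dag_id == 'example_dag'", "execution_date IN confirmed_dates")
    .filter("state == 'running'")
    .all()
)
print(tis)
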
diff --git a/airflow/api/common/experimental/trigger_dag.py b/airflow/api/common/experimental/trigger_dag.py
index e1d3ceebff242..519079e0d30b4 100644
--- a/airflow/api/common/experimental/trigger_dag.py
+++ b/airflow/api/common/experimental/trigger_dag.py
@@ -63,16 +63,15 @@ def _trigger_dag(
if min_dag_start_date and execution_date < min_dag_start_date:
raise ValueError(
"The execution_date [{}] should be >= start_date [{}] from DAG's default_args".format(
- execution_date.isoformat(),
- min_dag_start_date.isoformat()))
+ execution_date.isoformat(), min_dag_start_date.isoformat()
+ )
+ )
run_id = run_id or DagRun.generate_run_id(DagRunType.MANUAL, execution_date)
dag_run = DagRun.find(dag_id=dag_id, run_id=run_id)
if dag_run:
- raise DagRunAlreadyExists(
- f"Run id {run_id} already exists for dag id {dag_id}"
- )
+ raise DagRunAlreadyExists(f"Run id {run_id} already exists for dag id {dag_id}")
run_conf = None
if conf:
@@ -95,11 +94,11 @@ def _trigger_dag(
def trigger_dag(
- dag_id: str,
- run_id: Optional[str] = None,
- conf: Optional[Union[dict, str]] = None,
- execution_date: Optional[datetime] = None,
- replace_microseconds: bool = True,
+ dag_id: str,
+ run_id: Optional[str] = None,
+ conf: Optional[Union[dict, str]] = None,
+ execution_date: Optional[datetime] = None,
+ replace_microseconds: bool = True,
) -> Optional[DagRun]:
"""Triggers execution of DAG specified by dag_id
diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py
index bac2ad215bfe4..d6a155546efd4 100644
--- a/airflow/api_connexion/endpoints/dag_run_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from connexion import NoContent
-from flask import g, request, current_app
+from flask import current_app, g, request
from marshmallow import ValidationError
from airflow.api_connexion import security
diff --git a/airflow/api_connexion/endpoints/dag_source_endpoint.py b/airflow/api_connexion/endpoints/dag_source_endpoint.py
index b023ae60a32a6..0281b2250fcae 100644
--- a/airflow/api_connexion/endpoints/dag_source_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_source_endpoint.py
@@ -25,7 +25,6 @@
from airflow.models.dagcode import DagCode
from airflow.security import permissions
-
log = logging.getLogger(__name__)
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index 37e9b02cdc4f6..b84c59ad9f896 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -16,31 +16,31 @@
# under the License.
from typing import Any, List, Optional, Tuple
-from flask import request, current_app
+from flask import current_app, request
from marshmallow import ValidationError
from sqlalchemy import and_, func
from airflow.api.common.experimental.mark_tasks import set_state
from airflow.api_connexion import security
-from airflow.api_connexion.exceptions import NotFound, BadRequest
+from airflow.api_connexion.exceptions import BadRequest, NotFound
from airflow.api_connexion.parameters import format_datetime, format_parameters
from airflow.api_connexion.schemas.task_instance_schema import (
- clear_task_instance_form,
TaskInstanceCollection,
- task_instance_collection_schema,
- task_instance_schema,
- task_instance_batch_form,
- task_instance_reference_collection_schema,
TaskInstanceReferenceCollection,
+ clear_task_instance_form,
set_task_instance_state_form,
+ task_instance_batch_form,
+ task_instance_collection_schema,
+ task_instance_reference_collection_schema,
+ task_instance_schema,
)
from airflow.exceptions import SerializedDagNotFound
-from airflow.models.dagrun import DagRun as DR
-from airflow.models.taskinstance import clear_task_instances, TaskInstance as TI
from airflow.models import SlaMiss
+from airflow.models.dagrun import DagRun as DR
+from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances
from airflow.security import permissions
-from airflow.utils.state import State
from airflow.utils.session import provide_session
+from airflow.utils.state import State
@security.requires_access(
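
The import shuffles in this file come from isort rather than Black: with the Black profile, isort alphabetises the names inside each `from ... import` and orders the import lines, which is why pairs such as `NotFound, BadRequest` become `BadRequest, NotFound`. A minimal sketch, assuming isort>=5 is installed (default section classification is used here, so the grouping may differ slightly from the repository's own configuration):

import isort

src = (
    "from flask import request, current_app\n"
    "from airflow.api_connexion.exceptions import NotFound, BadRequest\n"
)
print(isort.code(src, profile="black"), end="")
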
diff --git a/airflow/api_connexion/schemas/sla_miss_schema.py b/airflow/api_connexion/schemas/sla_miss_schema.py
index 341b8aa2a7c88..9413e37cbde21 100644
--- a/airflow/api_connexion/schemas/sla_miss_schema.py
+++ b/airflow/api_connexion/schemas/sla_miss_schema.py
@@ -16,6 +16,7 @@
# under the License.
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
+
from airflow.models import SlaMiss
diff --git a/airflow/api_connexion/schemas/task_instance_schema.py b/airflow/api_connexion/schemas/task_instance_schema.py
index e825665f9528a..27bf4a6cd891e 100644
--- a/airflow/api_connexion/schemas/task_instance_schema.py
+++ b/airflow/api_connexion/schemas/task_instance_schema.py
@@ -15,14 +15,14 @@
# specific language governing permissions and limitations
# under the License.
-from typing import List, NamedTuple, Tuple, Optional
+from typing import List, NamedTuple, Optional, Tuple
-from marshmallow import Schema, fields, ValidationError, validates_schema, validate
+from marshmallow import Schema, ValidationError, fields, validate, validates_schema
from marshmallow.utils import get_value
from airflow.api_connexion.schemas.enum_schemas import TaskInstanceStateField
from airflow.api_connexion.schemas.sla_miss_schema import SlaMissSchema
-from airflow.models import TaskInstance, SlaMiss
+from airflow.models import SlaMiss, TaskInstance
from airflow.utils.state import State
diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py
index 84f1349f1121d..c9dbfec837ca8 100644
--- a/airflow/cli/cli_parser.py
+++ b/airflow/cli/cli_parser.py
@@ -92,8 +92,18 @@ class Arg:
"""Class to keep information about command line argument"""
# pylint: disable=redefined-builtin,unused-argument
- def __init__(self, flags=_UNSET, help=_UNSET, action=_UNSET, default=_UNSET, nargs=_UNSET, type=_UNSET,
- choices=_UNSET, required=_UNSET, metavar=_UNSET):
+ def __init__(
+ self,
+ flags=_UNSET,
+ help=_UNSET,
+ action=_UNSET,
+ default=_UNSET,
+ nargs=_UNSET,
+ type=_UNSET,
+ choices=_UNSET,
+ required=_UNSET,
+ metavar=_UNSET,
+ ):
self.flags = flags
self.kwargs = {}
for k, v in locals().items():
@@ -103,6 +113,7 @@ def __init__(self, flags=_UNSET, help=_UNSET, action=_UNSET, default=_UNSET, nar
continue
self.kwargs[k] = v
+
# pylint: enable=redefined-builtin,unused-argument
def add_to_parser(self, parser: argparse.ArgumentParser):
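
The hunk above only reflows Arg.__init__ to one parameter per line. For readers skimming the rest of this file, a condensed, hypothetical sketch of what that constructor does (the real class takes more options; _UNSET is a module-level sentinel):

_UNSET = object()  # sentinel meaning "option not supplied"


class Arg:
    # pylint: disable=redefined-builtin
    def __init__(self, flags=_UNSET, help=_UNSET, action=_UNSET, default=_UNSET):
        self.flags = flags
        self.kwargs = {}
        for k, v in locals().items():
            if k in ("self", "flags") or v is _UNSET:
                continue
            self.kwargs[k] = v  # only explicitly supplied options are kept


print(Arg(("-v", "--verbose"), help="Make logging output more verbose", action="store_true").kwargs)
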
@@ -122,65 +133,47 @@ def positive_int(value):
# Shared
-ARG_DAG_ID = Arg(
- ("dag_id",),
- help="The id of the dag")
-ARG_TASK_ID = Arg(
- ("task_id",),
- help="The id of the task")
-ARG_EXECUTION_DATE = Arg(
- ("execution_date",),
- help="The execution date of the DAG",
- type=parsedate)
+ARG_DAG_ID = Arg(("dag_id",), help="The id of the dag")
+ARG_TASK_ID = Arg(("task_id",), help="The id of the task")
+ARG_EXECUTION_DATE = Arg(("execution_date",), help="The execution date of the DAG", type=parsedate)
ARG_TASK_REGEX = Arg(
- ("-t", "--task-regex"),
- help="The regex to filter specific task_ids to backfill (optional)")
+ ("-t", "--task-regex"), help="The regex to filter specific task_ids to backfill (optional)"
+)
ARG_SUBDIR = Arg(
("-S", "--subdir"),
help=(
"File location or directory from which to look for the dag. "
"Defaults to '[AIRFLOW_HOME]/dags' where [AIRFLOW_HOME] is the "
- "value you set for 'AIRFLOW_HOME' config you set in 'airflow.cfg' "),
- default='[AIRFLOW_HOME]/dags' if BUILD_DOCS else settings.DAGS_FOLDER)
-ARG_START_DATE = Arg(
- ("-s", "--start-date"),
- help="Override start_date YYYY-MM-DD",
- type=parsedate)
-ARG_END_DATE = Arg(
- ("-e", "--end-date"),
- help="Override end_date YYYY-MM-DD",
- type=parsedate)
+ "value you set for 'AIRFLOW_HOME' config you set in 'airflow.cfg' "
+ ),
+ default='[AIRFLOW_HOME]/dags' if BUILD_DOCS else settings.DAGS_FOLDER,
+)
+ARG_START_DATE = Arg(("-s", "--start-date"), help="Override start_date YYYY-MM-DD", type=parsedate)
+ARG_END_DATE = Arg(("-e", "--end-date"), help="Override end_date YYYY-MM-DD", type=parsedate)
ARG_OUTPUT_PATH = Arg(
- ("-o", "--output-path",),
+ (
+ "-o",
+ "--output-path",
+ ),
help="The output for generated yaml files",
type=str,
- default="[CWD]" if BUILD_DOCS else os.getcwd())
+ default="[CWD]" if BUILD_DOCS else os.getcwd(),
+)
ARG_DRY_RUN = Arg(
("-n", "--dry-run"),
help="Perform a dry run for each task. Only renders Template Fields for each task, nothing else",
- action="store_true")
-ARG_PID = Arg(
- ("--pid",),
- help="PID file location",
- nargs='?')
+ action="store_true",
+)
+ARG_PID = Arg(("--pid",), help="PID file location", nargs='?')
ARG_DAEMON = Arg(
- ("-D", "--daemon"),
- help="Daemonize instead of running in the foreground",
- action="store_true")
-ARG_STDERR = Arg(
- ("--stderr",),
- help="Redirect stderr to this file")
-ARG_STDOUT = Arg(
- ("--stdout",),
- help="Redirect stdout to this file")
-ARG_LOG_FILE = Arg(
- ("-l", "--log-file"),
- help="Location of the log file")
+ ("-D", "--daemon"), help="Daemonize instead of running in the foreground", action="store_true"
+)
+ARG_STDERR = Arg(("--stderr",), help="Redirect stderr to this file")
+ARG_STDOUT = Arg(("--stdout",), help="Redirect stdout to this file")
+ARG_LOG_FILE = Arg(("-l", "--log-file"), help="Location of the log file")
ARG_YES = Arg(
- ("-y", "--yes"),
- help="Do not prompt to confirm reset. Use with care!",
- action="store_true",
- default=False)
+ ("-y", "--yes"), help="Do not prompt to confirm reset. Use with care!", action="store_true", default=False
+)
ARG_OUTPUT = Arg(
("--output",),
help=(
@@ -189,238 +182,181 @@ def positive_int(value):
),
metavar="FORMAT",
choices=tabulate_formats,
- default="plain")
+ default="plain",
+)
ARG_COLOR = Arg(
('--color',),
help="Do emit colored output (default: auto)",
choices={ColorMode.ON, ColorMode.OFF, ColorMode.AUTO},
- default=ColorMode.AUTO)
+ default=ColorMode.AUTO,
+)
# list_dag_runs
-ARG_DAG_ID_OPT = Arg(
- ("-d", "--dag-id"),
- help="The id of the dag"
-)
+ARG_DAG_ID_OPT = Arg(("-d", "--dag-id"), help="The id of the dag")
ARG_NO_BACKFILL = Arg(
- ("--no-backfill",),
- help="filter all the backfill dagruns given the dag id",
- action="store_true")
-ARG_STATE = Arg(
- ("--state",),
- help="Only list the dag runs corresponding to the state")
+ ("--no-backfill",), help="filter all the backfill dagruns given the dag id", action="store_true"
+)
+ARG_STATE = Arg(("--state",), help="Only list the dag runs corresponding to the state")
# list_jobs
-ARG_LIMIT = Arg(
- ("--limit",),
- help="Return a limited number of records")
+ARG_LIMIT = Arg(("--limit",), help="Return a limited number of records")
# next_execution
ARG_NUM_EXECUTIONS = Arg(
("-n", "--num-executions"),
default=1,
type=positive_int,
- help="The number of next execution datetimes to show")
+ help="The number of next execution datetimes to show",
+)
# backfill
ARG_MARK_SUCCESS = Arg(
- ("-m", "--mark-success"),
- help="Mark jobs as succeeded without running them",
- action="store_true")
-ARG_VERBOSE = Arg(
- ("-v", "--verbose"),
- help="Make logging output more verbose",
- action="store_true")
-ARG_LOCAL = Arg(
- ("-l", "--local"),
- help="Run the task using the LocalExecutor",
- action="store_true")
+ ("-m", "--mark-success"), help="Mark jobs as succeeded without running them", action="store_true"
+)
+ARG_VERBOSE = Arg(("-v", "--verbose"), help="Make logging output more verbose", action="store_true")
+ARG_LOCAL = Arg(("-l", "--local"), help="Run the task using the LocalExecutor", action="store_true")
ARG_DONOT_PICKLE = Arg(
("-x", "--donot-pickle"),
help=(
"Do not attempt to pickle the DAG object to send over "
"to the workers, just tell the workers to run their version "
- "of the code"),
- action="store_true")
+ "of the code"
+ ),
+ action="store_true",
+)
ARG_BF_IGNORE_DEPENDENCIES = Arg(
("-i", "--ignore-dependencies"),
help=(
"Skip upstream tasks, run only the tasks "
"matching the regexp. Only works in conjunction "
- "with task_regex"),
- action="store_true")
+ "with task_regex"
+ ),
+ action="store_true",
+)
ARG_BF_IGNORE_FIRST_DEPENDS_ON_PAST = Arg(
("-I", "--ignore-first-depends-on-past"),
help=(
"Ignores depends_on_past dependencies for the first "
"set of tasks only (subsequent executions in the backfill "
- "DO respect depends_on_past)"),
- action="store_true")
+ "DO respect depends_on_past)"
+ ),
+ action="store_true",
+)
ARG_POOL = Arg(("--pool",), "Resource pool to use")
ARG_DELAY_ON_LIMIT = Arg(
("--delay-on-limit",),
- help=("Amount of time in seconds to wait when the limit "
- "on maximum active dag runs (max_active_runs) has "
- "been reached before trying to execute a dag run "
- "again"),
+ help=(
+ "Amount of time in seconds to wait when the limit "
+ "on maximum active dag runs (max_active_runs) has "
+ "been reached before trying to execute a dag run "
+ "again"
+ ),
type=float,
- default=1.0)
+ default=1.0,
+)
ARG_RESET_DAG_RUN = Arg(
("--reset-dagruns",),
help=(
"if set, the backfill will delete existing "
"backfill-related DAG runs and start "
- "anew with fresh, running DAG runs"),
- action="store_true")
+ "anew with fresh, running DAG runs"
+ ),
+ action="store_true",
+)
ARG_RERUN_FAILED_TASKS = Arg(
("--rerun-failed-tasks",),
help=(
"if set, the backfill will auto-rerun "
"all the failed tasks for the backfill date range "
- "instead of throwing exceptions"),
- action="store_true")
+ "instead of throwing exceptions"
+ ),
+ action="store_true",
+)
ARG_RUN_BACKWARDS = Arg(
- ("-B", "--run-backwards",),
+ (
+ "-B",
+ "--run-backwards",
+ ),
help=(
"if set, the backfill will run tasks from the most "
"recent day first. if there are tasks that depend_on_past "
- "this option will throw an exception"),
- action="store_true")
+ "this option will throw an exception"
+ ),
+ action="store_true",
+)
# test_dag
ARG_SHOW_DAGRUN = Arg(
- ("--show-dagrun", ),
+ ("--show-dagrun",),
help=(
"After completing the backfill, shows the diagram for current DAG Run.\n"
"\n"
- "The diagram is in DOT language\n"),
- action='store_true')
+ "The diagram is in DOT language\n"
+ ),
+ action='store_true',
+)
ARG_IMGCAT_DAGRUN = Arg(
- ("--imgcat-dagrun", ),
+ ("--imgcat-dagrun",),
help=(
"After completing the dag run, prints a diagram on the screen for the "
"current DAG Run using the imgcat tool.\n"
),
- action='store_true')
+ action='store_true',
+)
ARG_SAVE_DAGRUN = Arg(
- ("--save-dagrun", ),
+ ("--save-dagrun",),
help=(
- "After completing the backfill, saves the diagram for current DAG Run to the indicated file.\n"
- "\n"
- ))
+ "After completing the backfill, saves the diagram for current DAG Run to the indicated file.\n" "\n"
+ ),
+)
# list_tasks
-ARG_TREE = Arg(
- ("-t", "--tree"),
- help="Tree view",
- action="store_true")
+ARG_TREE = Arg(("-t", "--tree"), help="Tree view", action="store_true")
# clear
-ARG_UPSTREAM = Arg(
- ("-u", "--upstream"),
- help="Include upstream tasks",
- action="store_true")
-ARG_ONLY_FAILED = Arg(
- ("-f", "--only-failed"),
- help="Only failed jobs",
- action="store_true")
-ARG_ONLY_RUNNING = Arg(
- ("-r", "--only-running"),
- help="Only running jobs",
- action="store_true")
-ARG_DOWNSTREAM = Arg(
- ("-d", "--downstream"),
- help="Include downstream tasks",
- action="store_true")
-ARG_EXCLUDE_SUBDAGS = Arg(
- ("-x", "--exclude-subdags"),
- help="Exclude subdags",
- action="store_true")
+ARG_UPSTREAM = Arg(("-u", "--upstream"), help="Include upstream tasks", action="store_true")
+ARG_ONLY_FAILED = Arg(("-f", "--only-failed"), help="Only failed jobs", action="store_true")
+ARG_ONLY_RUNNING = Arg(("-r", "--only-running"), help="Only running jobs", action="store_true")
+ARG_DOWNSTREAM = Arg(("-d", "--downstream"), help="Include downstream tasks", action="store_true")
+ARG_EXCLUDE_SUBDAGS = Arg(("-x", "--exclude-subdags"), help="Exclude subdags", action="store_true")
ARG_EXCLUDE_PARENTDAG = Arg(
("-X", "--exclude-parentdag"),
help="Exclude ParentDAGS if the task cleared is a part of a SubDAG",
- action="store_true")
+ action="store_true",
+)
ARG_DAG_REGEX = Arg(
- ("-R", "--dag-regex"),
- help="Search dag_id as regex instead of exact string",
- action="store_true")
+ ("-R", "--dag-regex"), help="Search dag_id as regex instead of exact string", action="store_true"
+)
# show_dag
-ARG_SAVE = Arg(
- ("-s", "--save"),
- help="Saves the result to the indicated file.")
+ARG_SAVE = Arg(("-s", "--save"), help="Saves the result to the indicated file.")
-ARG_IMGCAT = Arg(
- ("--imgcat", ),
- help=(
- "Displays graph using the imgcat tool."),
- action='store_true')
+ARG_IMGCAT = Arg(("--imgcat",), help=("Displays graph using the imgcat tool."), action='store_true')
# trigger_dag
-ARG_RUN_ID = Arg(
- ("-r", "--run-id"),
- help="Helps to identify this run")
-ARG_CONF = Arg(
- ('-c', '--conf'),
- help="JSON string that gets pickled into the DagRun's conf attribute")
-ARG_EXEC_DATE = Arg(
- ("-e", "--exec-date"),
- help="The execution date of the DAG",
- type=parsedate)
+ARG_RUN_ID = Arg(("-r", "--run-id"), help="Helps to identify this run")
+ARG_CONF = Arg(('-c', '--conf'), help="JSON string that gets pickled into the DagRun's conf attribute")
+ARG_EXEC_DATE = Arg(("-e", "--exec-date"), help="The execution date of the DAG", type=parsedate)
# pool
-ARG_POOL_NAME = Arg(
- ("pool",),
- metavar='NAME',
- help="Pool name")
-ARG_POOL_SLOTS = Arg(
- ("slots",),
- type=int,
- help="Pool slots")
-ARG_POOL_DESCRIPTION = Arg(
- ("description",),
- help="Pool description")
-ARG_POOL_IMPORT = Arg(
- ("file",),
- metavar="FILEPATH",
- help="Import pools from JSON file")
-ARG_POOL_EXPORT = Arg(
- ("file",),
- metavar="FILEPATH",
- help="Export all pools to JSON file")
+ARG_POOL_NAME = Arg(("pool",), metavar='NAME', help="Pool name")
+ARG_POOL_SLOTS = Arg(("slots",), type=int, help="Pool slots")
+ARG_POOL_DESCRIPTION = Arg(("description",), help="Pool description")
+ARG_POOL_IMPORT = Arg(("file",), metavar="FILEPATH", help="Import pools from JSON file")
+ARG_POOL_EXPORT = Arg(("file",), metavar="FILEPATH", help="Export all pools to JSON file")
# variables
-ARG_VAR = Arg(
- ("key",),
- help="Variable key")
-ARG_VAR_VALUE = Arg(
- ("value",),
- metavar='VALUE',
- help="Variable value")
+ARG_VAR = Arg(("key",), help="Variable key")
+ARG_VAR_VALUE = Arg(("value",), metavar='VALUE', help="Variable value")
ARG_DEFAULT = Arg(
- ("-d", "--default"),
- metavar="VAL",
- default=None,
- help="Default value returned if variable does not exist")
-ARG_JSON = Arg(
- ("-j", "--json"),
- help="Deserialize JSON variable",
- action="store_true")
-ARG_VAR_IMPORT = Arg(
- ("file",),
- help="Import variables from JSON file")
-ARG_VAR_EXPORT = Arg(
- ("file",),
- help="Export all variables to JSON file")
+ ("-d", "--default"), metavar="VAL", default=None, help="Default value returned if variable does not exist"
+)
+ARG_JSON = Arg(("-j", "--json"), help="Deserialize JSON variable", action="store_true")
+ARG_VAR_IMPORT = Arg(("file",), help="Import variables from JSON file")
+ARG_VAR_EXPORT = Arg(("file",), help="Export all variables to JSON file")
# kerberos
-ARG_PRINCIPAL = Arg(
- ("principal",),
- help="kerberos principal",
- nargs='?')
-ARG_KEYTAB = Arg(
- ("-k", "--keytab"),
- help="keytab",
- nargs='?',
- default=conf.get('kerberos', 'keytab'))
+ARG_PRINCIPAL = Arg(("principal",), help="kerberos principal", nargs='?')
+ARG_KEYTAB = Arg(("-k", "--keytab"), help="keytab", nargs='?', default=conf.get('kerberos', 'keytab'))
# run
# TODO(aoen): "force" is a poor choice of name here since it implies it overrides
# all dependencies (not just past success), e.g. the ignore_depends_on_past
@@ -429,18 +365,20 @@ def positive_int(value):
# instead.
ARG_INTERACTIVE = Arg(
('-N', '--interactive'),
- help='Do not capture standard output and error streams '
- '(useful for interactive debugging)',
- action='store_true')
+ help='Do not capture standard output and error streams ' '(useful for interactive debugging)',
+ action='store_true',
+)
ARG_FORCE = Arg(
("-f", "--force"),
help="Ignore previous task instance state, rerun regardless if task already succeeded/failed",
- action="store_true")
+ action="store_true",
+)
ARG_RAW = Arg(("-r", "--raw"), argparse.SUPPRESS, "store_true")
ARG_IGNORE_ALL_DEPENDENCIES = Arg(
("-A", "--ignore-all-dependencies"),
help="Ignores all non-critical dependencies, including ignore_ti_state and ignore_task_deps",
- action="store_true")
+ action="store_true",
+)
# TODO(aoen): ignore_dependencies is a poor choice of name here because it is too
# vague (e.g. a task being in the appropriate state to be run is also a dependency
# but is not ignored by this flag), the name 'ignore_task_dependencies' is
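
Several help strings above end up as two adjacent literals on one line (ARG_INTERACTIVE, for example). Black never merges adjacent string literals, it only reflows them, and Python concatenates them at compile time, so the help text the user sees is unchanged:

joined = 'Do not capture standard output and error streams ' '(useful for interactive debugging)'
assert joined == 'Do not capture standard output and error streams (useful for interactive debugging)'
print(joined)
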
@@ -449,24 +387,19 @@ def positive_int(value):
ARG_IGNORE_DEPENDENCIES = Arg(
("-i", "--ignore-dependencies"),
help="Ignore task-specific dependencies, e.g. upstream, depends_on_past, and retry delay dependencies",
- action="store_true")
+ action="store_true",
+)
ARG_IGNORE_DEPENDS_ON_PAST = Arg(
("-I", "--ignore-depends-on-past"),
help="Ignore depends_on_past dependencies (but respect upstream dependencies)",
- action="store_true")
+ action="store_true",
+)
ARG_SHIP_DAG = Arg(
- ("--ship-dag",),
- help="Pickles (serializes) the DAG and ships it to the worker",
- action="store_true")
-ARG_PICKLE = Arg(
- ("-p", "--pickle"),
- help="Serialized pickle object of the entire dag (used internally)")
-ARG_JOB_ID = Arg(
- ("-j", "--job-id"),
- help=argparse.SUPPRESS)
-ARG_CFG_PATH = Arg(
- ("--cfg-path",),
- help="Path to config file to use instead of airflow.cfg")
+ ("--ship-dag",), help="Pickles (serializes) the DAG and ships it to the worker", action="store_true"
+)
+ARG_PICKLE = Arg(("-p", "--pickle"), help="Serialized pickle object of the entire dag (used internally)")
+ARG_JOB_ID = Arg(("-j", "--job-id"), help=argparse.SUPPRESS)
+ARG_CFG_PATH = Arg(("--cfg-path",), help="Path to config file to use instead of airflow.cfg")
ARG_MIGRATION_TIMEOUT = Arg(
("-t", "--migration-wait-timeout"),
help="timeout to wait for db to migrate ",
@@ -479,220 +412,194 @@ def positive_int(value):
("-p", "--port"),
default=conf.get('webserver', 'WEB_SERVER_PORT'),
type=int,
- help="The port on which to run the server")
+ help="The port on which to run the server",
+)
ARG_SSL_CERT = Arg(
("--ssl-cert",),
default=conf.get('webserver', 'WEB_SERVER_SSL_CERT'),
- help="Path to the SSL certificate for the webserver")
+ help="Path to the SSL certificate for the webserver",
+)
ARG_SSL_KEY = Arg(
("--ssl-key",),
default=conf.get('webserver', 'WEB_SERVER_SSL_KEY'),
- help="Path to the key to use with the SSL certificate")
+ help="Path to the key to use with the SSL certificate",
+)
ARG_WORKERS = Arg(
("-w", "--workers"),
default=conf.get('webserver', 'WORKERS'),
type=int,
- help="Number of workers to run the webserver on")
+ help="Number of workers to run the webserver on",
+)
ARG_WORKERCLASS = Arg(
("-k", "--workerclass"),
default=conf.get('webserver', 'WORKER_CLASS'),
choices=['sync', 'eventlet', 'gevent', 'tornado'],
- help="The worker class to use for Gunicorn")
+ help="The worker class to use for Gunicorn",
+)
ARG_WORKER_TIMEOUT = Arg(
("-t", "--worker-timeout"),
default=conf.get('webserver', 'WEB_SERVER_WORKER_TIMEOUT'),
type=int,
- help="The timeout for waiting on webserver workers")
+ help="The timeout for waiting on webserver workers",
+)
ARG_HOSTNAME = Arg(
("-H", "--hostname"),
default=conf.get('webserver', 'WEB_SERVER_HOST'),
- help="Set the hostname on which to run the web server")
+ help="Set the hostname on which to run the web server",
+)
ARG_DEBUG = Arg(
- ("-d", "--debug"),
- help="Use the server that ships with Flask in debug mode",
- action="store_true")
+ ("-d", "--debug"), help="Use the server that ships with Flask in debug mode", action="store_true"
+)
ARG_ACCESS_LOGFILE = Arg(
("-A", "--access-logfile"),
default=conf.get('webserver', 'ACCESS_LOGFILE'),
- help="The logfile to store the webserver access log. Use '-' to print to "
- "stderr")
+ help="The logfile to store the webserver access log. Use '-' to print to " "stderr",
+)
ARG_ERROR_LOGFILE = Arg(
("-E", "--error-logfile"),
default=conf.get('webserver', 'ERROR_LOGFILE'),
- help="The logfile to store the webserver error log. Use '-' to print to "
- "stderr")
+ help="The logfile to store the webserver error log. Use '-' to print to " "stderr",
+)
# scheduler
ARG_NUM_RUNS = Arg(
("-n", "--num-runs"),
default=conf.getint('scheduler', 'num_runs'),
type=int,
- help="Set the number of runs to execute before exiting")
+ help="Set the number of runs to execute before exiting",
+)
ARG_DO_PICKLE = Arg(
("-p", "--do-pickle"),
default=False,
help=(
"Attempt to pickle the DAG object to send over "
"to the workers, instead of letting workers run their version "
- "of the code"),
- action="store_true")
+ "of the code"
+ ),
+ action="store_true",
+)
# worker
ARG_QUEUES = Arg(
("-q", "--queues"),
help="Comma delimited list of queues to serve",
- default=conf.get('celery', 'DEFAULT_QUEUE'))
+ default=conf.get('celery', 'DEFAULT_QUEUE'),
+)
ARG_CONCURRENCY = Arg(
("-c", "--concurrency"),
type=int,
help="The number of worker processes",
- default=conf.get('celery', 'worker_concurrency'))
+ default=conf.get('celery', 'worker_concurrency'),
+)
ARG_CELERY_HOSTNAME = Arg(
("-H", "--celery-hostname"),
- help=("Set the hostname of celery worker "
- "if you have multiple workers on a single machine"))
+ help=("Set the hostname of celery worker " "if you have multiple workers on a single machine"),
+)
ARG_UMASK = Arg(
("-u", "--umask"),
help="Set the umask of celery worker in daemon mode",
- default=conf.get('celery', 'worker_umask'))
+ default=conf.get('celery', 'worker_umask'),
+)
# flower
ARG_BROKER_API = Arg(("-a", "--broker-api"), help="Broker API")
ARG_FLOWER_HOSTNAME = Arg(
("-H", "--hostname"),
default=conf.get('celery', 'FLOWER_HOST'),
- help="Set the hostname on which to run the server")
+ help="Set the hostname on which to run the server",
+)
ARG_FLOWER_PORT = Arg(
("-p", "--port"),
default=conf.get('celery', 'FLOWER_PORT'),
type=int,
- help="The port on which to run the server")
-ARG_FLOWER_CONF = Arg(
- ("-c", "--flower-conf"),
- help="Configuration file for flower")
+ help="The port on which to run the server",
+)
+ARG_FLOWER_CONF = Arg(("-c", "--flower-conf"), help="Configuration file for flower")
ARG_FLOWER_URL_PREFIX = Arg(
- ("-u", "--url-prefix"),
- default=conf.get('celery', 'FLOWER_URL_PREFIX'),
- help="URL prefix for Flower")
+ ("-u", "--url-prefix"), default=conf.get('celery', 'FLOWER_URL_PREFIX'), help="URL prefix for Flower"
+)
ARG_FLOWER_BASIC_AUTH = Arg(
("-A", "--basic-auth"),
default=conf.get('celery', 'FLOWER_BASIC_AUTH'),
- help=("Securing Flower with Basic Authentication. "
- "Accepts user:password pairs separated by a comma. "
- "Example: flower_basic_auth = user1:password1,user2:password2"))
-ARG_TASK_PARAMS = Arg(
- ("-t", "--task-params"),
- help="Sends a JSON params dict to the task")
+ help=(
+ "Securing Flower with Basic Authentication. "
+ "Accepts user:password pairs separated by a comma. "
+ "Example: flower_basic_auth = user1:password1,user2:password2"
+ ),
+)
+ARG_TASK_PARAMS = Arg(("-t", "--task-params"), help="Sends a JSON params dict to the task")
ARG_POST_MORTEM = Arg(
- ("-m", "--post-mortem"),
- action="store_true",
- help="Open debugger on uncaught exception")
+ ("-m", "--post-mortem"), action="store_true", help="Open debugger on uncaught exception"
+)
ARG_ENV_VARS = Arg(
- ("--env-vars", ),
+ ("--env-vars",),
help="Set env var in both parsing time and runtime for each of entry supplied in a JSON dict",
- type=json.loads)
+ type=json.loads,
+)
# connections
-ARG_CONN_ID = Arg(
- ('conn_id',),
- help='Connection id, required to get/add/delete a connection',
- type=str)
+ARG_CONN_ID = Arg(('conn_id',), help='Connection id, required to get/add/delete a connection', type=str)
ARG_CONN_ID_FILTER = Arg(
- ('--conn-id',),
- help='If passed, only items with the specified connection ID will be displayed',
- type=str)
+ ('--conn-id',), help='If passed, only items with the specified connection ID will be displayed', type=str
+)
ARG_CONN_URI = Arg(
- ('--conn-uri',),
- help='Connection URI, required to add a connection without conn_type',
- type=str)
+ ('--conn-uri',), help='Connection URI, required to add a connection without conn_type', type=str
+)
ARG_CONN_TYPE = Arg(
- ('--conn-type',),
- help='Connection type, required to add a connection without conn_uri',
- type=str)
-ARG_CONN_HOST = Arg(
- ('--conn-host',),
- help='Connection host, optional when adding a connection',
- type=str)
-ARG_CONN_LOGIN = Arg(
- ('--conn-login',),
- help='Connection login, optional when adding a connection',
- type=str)
+ ('--conn-type',), help='Connection type, required to add a connection without conn_uri', type=str
+)
+ARG_CONN_HOST = Arg(('--conn-host',), help='Connection host, optional when adding a connection', type=str)
+ARG_CONN_LOGIN = Arg(('--conn-login',), help='Connection login, optional when adding a connection', type=str)
ARG_CONN_PASSWORD = Arg(
- ('--conn-password',),
- help='Connection password, optional when adding a connection',
- type=str)
+ ('--conn-password',), help='Connection password, optional when adding a connection', type=str
+)
ARG_CONN_SCHEMA = Arg(
- ('--conn-schema',),
- help='Connection schema, optional when adding a connection',
- type=str)
-ARG_CONN_PORT = Arg(
- ('--conn-port',),
- help='Connection port, optional when adding a connection',
- type=str)
+ ('--conn-schema',), help='Connection schema, optional when adding a connection', type=str
+)
+ARG_CONN_PORT = Arg(('--conn-port',), help='Connection port, optional when adding a connection', type=str)
ARG_CONN_EXTRA = Arg(
- ('--conn-extra',),
- help='Connection `Extra` field, optional when adding a connection',
- type=str)
+ ('--conn-extra',), help='Connection `Extra` field, optional when adding a connection', type=str
+)
ARG_CONN_EXPORT = Arg(
('file',),
help='Output file path for exporting the connections',
- type=argparse.FileType('w', encoding='UTF-8'))
+ type=argparse.FileType('w', encoding='UTF-8'),
+)
ARG_CONN_EXPORT_FORMAT = Arg(
- ('--format',),
- help='Format of the connections data in file',
- type=str,
- choices=['json', 'yaml', 'env'])
+ ('--format',), help='Format of the connections data in file', type=str, choices=['json', 'yaml', 'env']
+)
# users
-ARG_USERNAME = Arg(
- ('-u', '--username'),
- help='Username of the user',
- required=True,
- type=str)
-ARG_USERNAME_OPTIONAL = Arg(
- ('-u', '--username'),
- help='Username of the user',
- type=str)
-ARG_FIRSTNAME = Arg(
- ('-f', '--firstname'),
- help='First name of the user',
- required=True,
- type=str)
-ARG_LASTNAME = Arg(
- ('-l', '--lastname'),
- help='Last name of the user',
- required=True,
- type=str)
+ARG_USERNAME = Arg(('-u', '--username'), help='Username of the user', required=True, type=str)
+ARG_USERNAME_OPTIONAL = Arg(('-u', '--username'), help='Username of the user', type=str)
+ARG_FIRSTNAME = Arg(('-f', '--firstname'), help='First name of the user', required=True, type=str)
+ARG_LASTNAME = Arg(('-l', '--lastname'), help='Last name of the user', required=True, type=str)
ARG_ROLE = Arg(
('-r', '--role'),
- help='Role of the user. Existing roles include Admin, '
- 'User, Op, Viewer, and Public',
- required=True,
- type=str,)
-ARG_EMAIL = Arg(
- ('-e', '--email'),
- help='Email of the user',
+ help='Role of the user. Existing roles include Admin, ' 'User, Op, Viewer, and Public',
required=True,
- type=str)
-ARG_EMAIL_OPTIONAL = Arg(
- ('-e', '--email'),
- help='Email of the user',
- type=str)
+ type=str,
+)
+ARG_EMAIL = Arg(('-e', '--email'), help='Email of the user', required=True, type=str)
+ARG_EMAIL_OPTIONAL = Arg(('-e', '--email'), help='Email of the user', type=str)
ARG_PASSWORD = Arg(
('-p', '--password'),
- help='Password of the user, required to create a user '
- 'without --use-random-password',
- type=str)
+ help='Password of the user, required to create a user ' 'without --use-random-password',
+ type=str,
+)
ARG_USE_RANDOM_PASSWORD = Arg(
('--use-random-password',),
help='Do not prompt for password. Use random string instead.'
- ' Required to create a user without --password ',
+ ' Required to create a user without --password ',
default=False,
- action='store_true')
+ action='store_true',
+)
ARG_USER_IMPORT = Arg(
("import",),
metavar="FILEPATH",
- help="Import users from JSON file. Example format::\n" +
- textwrap.indent(textwrap.dedent('''
+ help="Import users from JSON file. Example format::\n"
+ + textwrap.indent(
+ textwrap.dedent(
+ '''
[
{
"email": "foo@bar.org",
@@ -701,49 +608,33 @@ def positive_int(value):
"roles": ["Public"],
"username": "jondoe"
}
- ]'''), " " * 4))
-ARG_USER_EXPORT = Arg(
- ("export",),
- metavar="FILEPATH",
- help="Export all users to JSON file")
+ ]'''
+ ),
+ " " * 4,
+ ),
+)
+ARG_USER_EXPORT = Arg(("export",), metavar="FILEPATH", help="Export all users to JSON file")
# roles
-ARG_CREATE_ROLE = Arg(
- ('-c', '--create'),
- help='Create a new role',
- action='store_true')
-ARG_LIST_ROLES = Arg(
- ('-l', '--list'),
- help='List roles',
- action='store_true')
-ARG_ROLES = Arg(
- ('role',),
- help='The name of a role',
- nargs='*')
-ARG_AUTOSCALE = Arg(
- ('-a', '--autoscale'),
- help="Minimum and Maximum number of worker to autoscale")
+ARG_CREATE_ROLE = Arg(('-c', '--create'), help='Create a new role', action='store_true')
+ARG_LIST_ROLES = Arg(('-l', '--list'), help='List roles', action='store_true')
+ARG_ROLES = Arg(('role',), help='The name of a role', nargs='*')
+ARG_AUTOSCALE = Arg(('-a', '--autoscale'), help="Minimum and Maximum number of worker to autoscale")
ARG_SKIP_SERVE_LOGS = Arg(
("-s", "--skip-serve-logs"),
default=False,
help="Don't start the serve logs process along with the workers",
- action="store_true")
+ action="store_true",
+)
# info
ARG_ANONYMIZE = Arg(
('--anonymize',),
- help=(
- 'Minimize any personal identifiable information. '
- 'Use it when sharing output with others.'
- ),
- action='store_true'
+ help=('Minimize any personal identifiable information. ' 'Use it when sharing output with others.'),
+ action='store_true',
)
ARG_FILE_IO = Arg(
- ('--file-io',),
- help=(
- 'Send output to file.io service and returns link.'
- ),
- action='store_true'
+ ('--file-io',), help=('Send output to file.io service and returns link.'), action='store_true'
)
# config
@@ -764,7 +655,12 @@ def positive_int(value):
)
ALTERNATIVE_CONN_SPECS_ARGS = [
- ARG_CONN_TYPE, ARG_CONN_HOST, ARG_CONN_LOGIN, ARG_CONN_PASSWORD, ARG_CONN_SCHEMA, ARG_CONN_PORT
+ ARG_CONN_TYPE,
+ ARG_CONN_HOST,
+ ARG_CONN_LOGIN,
+ ARG_CONN_PASSWORD,
+ ARG_CONN_SCHEMA,
+ ARG_CONN_PORT,
]
@@ -867,23 +763,30 @@ class GroupCommand(NamedTuple):
ActionCommand(
name='show',
help="Displays DAG's tasks with their dependencies",
- description=("The --imgcat option only works in iTerm.\n"
- "\n"
- "For more information, see: https://www.iterm2.com/documentation-images.html\n"
- "\n"
- "The --save option saves the result to the indicated file.\n"
- "\n"
- "The file format is determined by the file extension. "
- "For more information about supported "
- "format, see: https://www.graphviz.org/doc/info/output.html\n"
- "\n"
- "If you want to create a PNG file then you should execute the following command:\n"
- "airflow dags show --save output.png\n"
- "\n"
- "If you want to create a DOT file then you should execute the following command:\n"
- "airflow dags show --save output.dot\n"),
+ description=(
+ "The --imgcat option only works in iTerm.\n"
+ "\n"
+ "For more information, see: https://www.iterm2.com/documentation-images.html\n"
+ "\n"
+ "The --save option saves the result to the indicated file.\n"
+ "\n"
+ "The file format is determined by the file extension. "
+ "For more information about supported "
+ "format, see: https://www.graphviz.org/doc/info/output.html\n"
+ "\n"
+ "If you want to create a PNG file then you should execute the following command:\n"
+ "airflow dags show --save output.png\n"
+ "\n"
+ "If you want to create a DOT file then you should execute the following command:\n"
+ "airflow dags show --save output.dot\n"
+ ),
func=lazy_load_command('airflow.cli.commands.dag_command.dag_show'),
- args=(ARG_DAG_ID, ARG_SUBDIR, ARG_SAVE, ARG_IMGCAT,),
+ args=(
+ ARG_DAG_ID,
+ ARG_SUBDIR,
+ ARG_SAVE,
+ ARG_IMGCAT,
+ ),
),
ActionCommand(
name='backfill',
@@ -896,36 +799,58 @@ class GroupCommand(NamedTuple):
),
func=lazy_load_command('airflow.cli.commands.dag_command.dag_backfill'),
args=(
- ARG_DAG_ID, ARG_TASK_REGEX, ARG_START_DATE, ARG_END_DATE, ARG_MARK_SUCCESS, ARG_LOCAL,
- ARG_DONOT_PICKLE, ARG_YES, ARG_BF_IGNORE_DEPENDENCIES, ARG_BF_IGNORE_FIRST_DEPENDS_ON_PAST,
- ARG_SUBDIR, ARG_POOL, ARG_DELAY_ON_LIMIT, ARG_DRY_RUN, ARG_VERBOSE, ARG_CONF,
- ARG_RESET_DAG_RUN, ARG_RERUN_FAILED_TASKS, ARG_RUN_BACKWARDS
+ ARG_DAG_ID,
+ ARG_TASK_REGEX,
+ ARG_START_DATE,
+ ARG_END_DATE,
+ ARG_MARK_SUCCESS,
+ ARG_LOCAL,
+ ARG_DONOT_PICKLE,
+ ARG_YES,
+ ARG_BF_IGNORE_DEPENDENCIES,
+ ARG_BF_IGNORE_FIRST_DEPENDS_ON_PAST,
+ ARG_SUBDIR,
+ ARG_POOL,
+ ARG_DELAY_ON_LIMIT,
+ ARG_DRY_RUN,
+ ARG_VERBOSE,
+ ARG_CONF,
+ ARG_RESET_DAG_RUN,
+ ARG_RERUN_FAILED_TASKS,
+ ARG_RUN_BACKWARDS,
),
),
ActionCommand(
name='test',
help="Execute one single DagRun",
- description=("Execute one single DagRun for a given DAG and execution date, "
- "using the DebugExecutor.\n"
- "\n"
- "The --imgcat-dagrun option only works in iTerm.\n"
- "\n"
- "For more information, see: https://www.iterm2.com/documentation-images.html\n"
- "\n"
- "If --save-dagrun is used, then, after completing the backfill, saves the diagram "
- "for current DAG Run to the indicated file.\n"
- "The file format is determined by the file extension. "
- "For more information about supported format, "
- "see: https://www.graphviz.org/doc/info/output.html\n"
- "\n"
- "If you want to create a PNG file then you should execute the following command:\n"
- "airflow dags test --save-dagrun output.png\n"
- "\n"
- "If you want to create a DOT file then you should execute the following command:\n"
- "airflow dags test --save-dagrun output.dot\n"),
+ description=(
+ "Execute one single DagRun for a given DAG and execution date, "
+ "using the DebugExecutor.\n"
+ "\n"
+ "The --imgcat-dagrun option only works in iTerm.\n"
+ "\n"
+ "For more information, see: https://www.iterm2.com/documentation-images.html\n"
+ "\n"
+ "If --save-dagrun is used, then, after completing the backfill, saves the diagram "
+ "for current DAG Run to the indicated file.\n"
+ "The file format is determined by the file extension. "
+ "For more information about supported format, "
+ "see: https://www.graphviz.org/doc/info/output.html\n"
+ "\n"
+ "If you want to create a PNG file then you should execute the following command:\n"
+ "airflow dags test --save-dagrun output.png\n"
+ "\n"
+ "If you want to create a DOT file then you should execute the following command:\n"
+ "airflow dags test --save-dagrun output.dot\n"
+ ),
func=lazy_load_command('airflow.cli.commands.dag_command.dag_test'),
args=(
- ARG_DAG_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_SHOW_DAGRUN, ARG_IMGCAT_DAGRUN, ARG_SAVE_DAGRUN
+ ARG_DAG_ID,
+ ARG_EXECUTION_DATE,
+ ARG_SUBDIR,
+ ARG_SHOW_DAGRUN,
+ ARG_IMGCAT_DAGRUN,
+ ARG_SAVE_DAGRUN,
),
),
)
@@ -941,9 +866,19 @@ class GroupCommand(NamedTuple):
help="Clear a set of task instance, as if they never ran",
func=lazy_load_command('airflow.cli.commands.task_command.task_clear'),
args=(
- ARG_DAG_ID, ARG_TASK_REGEX, ARG_START_DATE, ARG_END_DATE, ARG_SUBDIR, ARG_UPSTREAM,
- ARG_DOWNSTREAM, ARG_YES, ARG_ONLY_FAILED, ARG_ONLY_RUNNING, ARG_EXCLUDE_SUBDAGS,
- ARG_EXCLUDE_PARENTDAG, ARG_DAG_REGEX
+ ARG_DAG_ID,
+ ARG_TASK_REGEX,
+ ARG_START_DATE,
+ ARG_END_DATE,
+ ARG_SUBDIR,
+ ARG_UPSTREAM,
+ ARG_DOWNSTREAM,
+ ARG_YES,
+ ARG_ONLY_FAILED,
+ ARG_ONLY_RUNNING,
+ ARG_EXCLUDE_SUBDAGS,
+ ARG_EXCLUDE_PARENTDAG,
+ ARG_DAG_REGEX,
),
),
ActionCommand(
@@ -974,9 +909,22 @@ class GroupCommand(NamedTuple):
help="Run a single task instance",
func=lazy_load_command('airflow.cli.commands.task_command.task_run'),
args=(
- ARG_DAG_ID, ARG_TASK_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_MARK_SUCCESS, ARG_FORCE,
- ARG_POOL, ARG_CFG_PATH, ARG_LOCAL, ARG_RAW, ARG_IGNORE_ALL_DEPENDENCIES,
- ARG_IGNORE_DEPENDENCIES, ARG_IGNORE_DEPENDS_ON_PAST, ARG_SHIP_DAG, ARG_PICKLE, ARG_JOB_ID,
+ ARG_DAG_ID,
+ ARG_TASK_ID,
+ ARG_EXECUTION_DATE,
+ ARG_SUBDIR,
+ ARG_MARK_SUCCESS,
+ ARG_FORCE,
+ ARG_POOL,
+ ARG_CFG_PATH,
+ ARG_LOCAL,
+ ARG_RAW,
+ ARG_IGNORE_ALL_DEPENDENCIES,
+ ARG_IGNORE_DEPENDENCIES,
+ ARG_IGNORE_DEPENDS_ON_PAST,
+ ARG_SHIP_DAG,
+ ARG_PICKLE,
+ ARG_JOB_ID,
ARG_INTERACTIVE,
),
),
@@ -989,8 +937,14 @@ class GroupCommand(NamedTuple):
),
func=lazy_load_command('airflow.cli.commands.task_command.task_test'),
args=(
- ARG_DAG_ID, ARG_TASK_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_DRY_RUN,
- ARG_TASK_PARAMS, ARG_POST_MORTEM, ARG_ENV_VARS
+ ARG_DAG_ID,
+ ARG_TASK_ID,
+ ARG_EXECUTION_DATE,
+ ARG_SUBDIR,
+ ARG_DRY_RUN,
+ ARG_TASK_PARAMS,
+ ARG_POST_MORTEM,
+ ARG_ENV_VARS,
),
),
ActionCommand(
@@ -1011,31 +965,48 @@ class GroupCommand(NamedTuple):
name='get',
help='Get pool size',
func=lazy_load_command('airflow.cli.commands.pool_command.pool_get'),
- args=(ARG_POOL_NAME, ARG_OUTPUT,),
+ args=(
+ ARG_POOL_NAME,
+ ARG_OUTPUT,
+ ),
),
ActionCommand(
name='set',
help='Configure pool',
func=lazy_load_command('airflow.cli.commands.pool_command.pool_set'),
- args=(ARG_POOL_NAME, ARG_POOL_SLOTS, ARG_POOL_DESCRIPTION, ARG_OUTPUT,),
+ args=(
+ ARG_POOL_NAME,
+ ARG_POOL_SLOTS,
+ ARG_POOL_DESCRIPTION,
+ ARG_OUTPUT,
+ ),
),
ActionCommand(
name='delete',
help='Delete pool',
func=lazy_load_command('airflow.cli.commands.pool_command.pool_delete'),
- args=(ARG_POOL_NAME, ARG_OUTPUT,),
+ args=(
+ ARG_POOL_NAME,
+ ARG_OUTPUT,
+ ),
),
ActionCommand(
name='import',
help='Import pools',
func=lazy_load_command('airflow.cli.commands.pool_command.pool_import'),
- args=(ARG_POOL_IMPORT, ARG_OUTPUT,),
+ args=(
+ ARG_POOL_IMPORT,
+ ARG_OUTPUT,
+ ),
),
ActionCommand(
name='export',
help='Export all pools',
func=lazy_load_command('airflow.cli.commands.pool_command.pool_export'),
- args=(ARG_POOL_EXPORT, ARG_OUTPUT,),
+ args=(
+ ARG_POOL_EXPORT,
+ ARG_OUTPUT,
+ ),
),
)
VARIABLES_COMMANDS = (
@@ -1086,9 +1057,7 @@ class GroupCommand(NamedTuple):
ActionCommand(
name="check-migrations",
help="Check if migration have finished",
- description=(
- "Check if migration have finished (or continually check until timeout)"
- ),
+ description=("Check if migration have finished (or continually check until timeout)"),
func=lazy_load_command('airflow.cli.commands.db_command.check_migrations'),
args=(ARG_MIGRATION_TIMEOUT,),
),
@@ -1145,18 +1114,23 @@ class GroupCommand(NamedTuple):
ActionCommand(
name='export',
help='Export all connections',
- description=("All connections can be exported in STDOUT using the following command:\n"
- "airflow connections export -\n"
- "The file format can be determined by the provided file extension. eg, The following "
- "command will export the connections in JSON format:\n"
- "airflow connections export /tmp/connections.json\n"
- "The --format parameter can be used to mention the connections format. eg, "
- "the default format is JSON in STDOUT mode, which can be overridden using: \n"
- "airflow connections export - --format yaml\n"
- "The --format parameter can also be used for the files, for example:\n"
- "airflow connections export /tmp/connections --format json\n"),
+ description=(
+ "All connections can be exported in STDOUT using the following command:\n"
+ "airflow connections export -\n"
+ "The file format can be determined by the provided file extension. eg, The following "
+ "command will export the connections in JSON format:\n"
+ "airflow connections export /tmp/connections.json\n"
+ "The --format parameter can be used to mention the connections format. eg, "
+ "the default format is JSON in STDOUT mode, which can be overridden using: \n"
+ "airflow connections export - --format yaml\n"
+ "The --format parameter can also be used for the files, for example:\n"
+ "airflow connections export /tmp/connections --format json\n"
+ ),
func=lazy_load_command('airflow.cli.commands.connection_command.connections_export'),
- args=(ARG_CONN_EXPORT, ARG_CONN_EXPORT_FORMAT,),
+ args=(
+ ARG_CONN_EXPORT,
+ ARG_CONN_EXPORT_FORMAT,
+ ),
),
)
USERS_COMMANDS = (
@@ -1171,8 +1145,13 @@ class GroupCommand(NamedTuple):
help='Create a user',
func=lazy_load_command('airflow.cli.commands.user_command.users_create'),
args=(
- ARG_ROLE, ARG_USERNAME, ARG_EMAIL, ARG_FIRSTNAME, ARG_LASTNAME, ARG_PASSWORD,
- ARG_USE_RANDOM_PASSWORD
+ ARG_ROLE,
+ ARG_USERNAME,
+ ARG_EMAIL,
+ ARG_FIRSTNAME,
+ ARG_LASTNAME,
+ ARG_PASSWORD,
+ ARG_USE_RANDOM_PASSWORD,
),
epilog=(
'examples:\n'
@@ -1184,7 +1163,7 @@ class GroupCommand(NamedTuple):
' --lastname LAST_NAME \\\n'
' --role Admin \\\n'
' --email admin@example.org'
- )
+ ),
),
ActionCommand(
name='delete',
@@ -1238,8 +1217,17 @@ class GroupCommand(NamedTuple):
help="Start a Celery worker node",
func=lazy_load_command('airflow.cli.commands.celery_command.worker'),
args=(
- ARG_QUEUES, ARG_CONCURRENCY, ARG_CELERY_HOSTNAME, ARG_PID, ARG_DAEMON,
- ARG_UMASK, ARG_STDOUT, ARG_STDERR, ARG_LOG_FILE, ARG_AUTOSCALE, ARG_SKIP_SERVE_LOGS
+ ARG_QUEUES,
+ ARG_CONCURRENCY,
+ ARG_CELERY_HOSTNAME,
+ ARG_PID,
+ ARG_DAEMON,
+ ARG_UMASK,
+ ARG_STDOUT,
+ ARG_STDERR,
+ ARG_LOG_FILE,
+ ARG_AUTOSCALE,
+ ARG_SKIP_SERVE_LOGS,
),
),
ActionCommand(
@@ -1247,9 +1235,17 @@ class GroupCommand(NamedTuple):
help="Start a Celery Flower",
func=lazy_load_command('airflow.cli.commands.celery_command.flower'),
args=(
- ARG_FLOWER_HOSTNAME, ARG_FLOWER_PORT, ARG_FLOWER_CONF, ARG_FLOWER_URL_PREFIX,
- ARG_FLOWER_BASIC_AUTH, ARG_BROKER_API, ARG_PID, ARG_DAEMON, ARG_STDOUT, ARG_STDERR,
- ARG_LOG_FILE
+ ARG_FLOWER_HOSTNAME,
+ ARG_FLOWER_PORT,
+ ARG_FLOWER_CONF,
+ ARG_FLOWER_URL_PREFIX,
+ ARG_FLOWER_BASIC_AUTH,
+ ARG_BROKER_API,
+ ARG_PID,
+ ARG_DAEMON,
+ ARG_STDOUT,
+ ARG_STDERR,
+ ARG_LOG_FILE,
),
),
ActionCommand(
@@ -1257,7 +1253,7 @@ class GroupCommand(NamedTuple):
help="Stop the Celery worker gracefully",
func=lazy_load_command('airflow.cli.commands.celery_command.stop_worker'),
args=(),
- )
+ ),
)
CONFIG_COMMANDS = (
@@ -1265,13 +1261,16 @@ class GroupCommand(NamedTuple):
name='get-value',
help='Print the value of the configuration',
func=lazy_load_command('airflow.cli.commands.config_command.get_value'),
- args=(ARG_SECTION, ARG_OPTION, ),
+ args=(
+ ARG_SECTION,
+ ARG_OPTION,
+ ),
),
ActionCommand(
name='list',
help='List options for the configuration',
func=lazy_load_command('airflow.cli.commands.config_command.show_config'),
- args=(ARG_COLOR, ),
+ args=(ARG_COLOR,),
),
)
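
In the args tuples above, Black strips the stray space before the closing parenthesis but keeps the trailing comma, because `(ARG_COLOR,)` is a one-element tuple while `(ARG_COLOR)` would just be the bare value; the parsed result is identical either way:

arg_color = object()  # stand-in for the real Arg instance
assert (arg_color, ) == (arg_color,)    # the spacing is purely cosmetic
assert isinstance((arg_color,), tuple)  # the comma is what makes it a tuple
assert (arg_color) is arg_color         # without it, the parentheses do nothing
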
@@ -1280,12 +1279,12 @@ class GroupCommand(NamedTuple):
name='cleanup-pods',
help="Clean up Kubernetes pods in evicted/failed/succeeded states",
func=lazy_load_command('airflow.cli.commands.kubernetes_command.cleanup_pods'),
- args=(ARG_NAMESPACE, ),
+ args=(ARG_NAMESPACE,),
),
ActionCommand(
name='generate-dag-yaml',
help="Generate YAML files for all tasks in DAG. Useful for debugging tasks without "
- "launching into a cluster",
+ "launching into a cluster",
func=lazy_load_command('airflow.cli.commands.kubernetes_command.generate_pod_yaml'),
args=(ARG_DAG_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_OUTPUT_PATH),
),
@@ -1298,9 +1297,7 @@ class GroupCommand(NamedTuple):
subcommands=DAGS_COMMANDS,
),
GroupCommand(
- name="kubernetes",
- help='Tools to help run the KubernetesExecutor',
- subcommands=KUBERNETES_COMMANDS
+ name="kubernetes", help='Tools to help run the KubernetesExecutor', subcommands=KUBERNETES_COMMANDS
),
GroupCommand(
name='tasks',
@@ -1333,9 +1330,21 @@ class GroupCommand(NamedTuple):
help="Start a Airflow webserver instance",
func=lazy_load_command('airflow.cli.commands.webserver_command.webserver'),
args=(
- ARG_PORT, ARG_WORKERS, ARG_WORKERCLASS, ARG_WORKER_TIMEOUT, ARG_HOSTNAME, ARG_PID,
- ARG_DAEMON, ARG_STDOUT, ARG_STDERR, ARG_ACCESS_LOGFILE, ARG_ERROR_LOGFILE, ARG_LOG_FILE,
- ARG_SSL_CERT, ARG_SSL_KEY, ARG_DEBUG
+ ARG_PORT,
+ ARG_WORKERS,
+ ARG_WORKERCLASS,
+ ARG_WORKER_TIMEOUT,
+ ARG_HOSTNAME,
+ ARG_PID,
+ ARG_DAEMON,
+ ARG_STDOUT,
+ ARG_STDERR,
+ ARG_ACCESS_LOGFILE,
+ ARG_ERROR_LOGFILE,
+ ARG_LOG_FILE,
+ ARG_SSL_CERT,
+ ARG_SSL_KEY,
+ ARG_DEBUG,
),
),
ActionCommand(
@@ -1343,8 +1352,14 @@ class GroupCommand(NamedTuple):
help="Start a scheduler instance",
func=lazy_load_command('airflow.cli.commands.scheduler_command.scheduler'),
args=(
- ARG_SUBDIR, ARG_NUM_RUNS, ARG_DO_PICKLE, ARG_PID, ARG_DAEMON, ARG_STDOUT,
- ARG_STDERR, ARG_LOG_FILE
+ ARG_SUBDIR,
+ ARG_NUM_RUNS,
+ ARG_DO_PICKLE,
+ ARG_PID,
+ ARG_DAEMON,
+ ARG_STDOUT,
+ ARG_STDERR,
+ ARG_LOG_FILE,
),
),
ActionCommand(
@@ -1391,16 +1406,15 @@ class GroupCommand(NamedTuple):
),
args=(),
),
- GroupCommand(
- name="config",
- help='View configuration',
- subcommands=CONFIG_COMMANDS
- ),
+ GroupCommand(name="config", help='View configuration', subcommands=CONFIG_COMMANDS),
ActionCommand(
name='info',
help='Show information about current Airflow and environment',
func=lazy_load_command('airflow.cli.commands.info_command.show_info'),
- args=(ARG_ANONYMIZE, ARG_FILE_IO, ),
+ args=(
+ ARG_ANONYMIZE,
+ ARG_FILE_IO,
+ ),
),
ActionCommand(
name='plugins',
@@ -1415,13 +1429,11 @@ class GroupCommand(NamedTuple):
'Start celery components. Works only when using CeleryExecutor. For more information, see '
'https://airflow.readthedocs.io/en/stable/executor/celery.html'
),
- subcommands=CELERY_COMMANDS
- )
+ subcommands=CELERY_COMMANDS,
+ ),
]
ALL_COMMANDS_DICT: Dict[str, CLICommand] = {sp.name: sp for sp in airflow_commands}
-DAG_CLI_COMMANDS: Set[str] = {
- 'list_tasks', 'backfill', 'test', 'run', 'pause', 'unpause', 'list_dag_runs'
-}
+DAG_CLI_COMMANDS: Set[str] = {'list_tasks', 'backfill', 'test', 'run', 'pause', 'unpause', 'list_dag_runs'}
class AirflowHelpFormatter(argparse.HelpFormatter):
@@ -1483,17 +1495,18 @@ def get_parser(dag_parser: bool = False) -> argparse.ArgumentParser:
def _sort_args(args: Iterable[Arg]) -> Iterable[Arg]:
"""Sort subcommand optional args, keep positional args"""
+
def get_long_option(arg: Arg):
"""Get long option from Arg.flags"""
return arg.flags[0] if len(arg.flags) == 1 else arg.flags[1]
+
positional, optional = partition(lambda x: x.flags[0].startswith("-"), args)
yield from positional
yield from sorted(optional, key=lambda x: get_long_option(x).lower())
def _add_command(
- subparsers: argparse._SubParsersAction, # pylint: disable=protected-access
- sub: CLICommand
+ subparsers: argparse._SubParsersAction, sub: CLICommand # pylint: disable=protected-access
) -> None:
sub_proc = subparsers.add_parser(
sub.name, help=sub.help, description=sub.description or sub.help, epilog=sub.epilog
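
An aside on the pattern that dominates the cli_parser.py hunks above: the one-argument-per-line explosions of the args=(...) tuples largely come from Black's "magic trailing comma", where a pre-existing trailing comma forces one element per line even when the collapsed form would fit. A minimal sketch, not part of the patch, assuming a Black release around 20.8b1 and a 110-character line limit (both assumptions inferred from the resulting formatting):

import black

# A trailing comma inside the parentheses tells Black to keep one element per
# line, even though the collapsed tuple would easily fit within the line limit.
src = "args = (ARG_SECTION, ARG_OPTION,)\n"
print(black.format_str(src, mode=black.Mode(line_length=110)))
# Prints:
# args = (
#     ARG_SECTION,
#     ARG_OPTION,
# )
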
diff --git a/airflow/cli/commands/cheat_sheet_command.py b/airflow/cli/commands/cheat_sheet_command.py
index 24a86d319d99c..0e9abaa3009af 100644
--- a/airflow/cli/commands/cheat_sheet_command.py
+++ b/airflow/cli/commands/cheat_sheet_command.py
@@ -41,6 +41,7 @@ def cheat_sheet(args):
def display_commands_index():
"""Display list of all commands."""
+
def display_recursive(prefix: List[str], commands: Iterable[Union[GroupCommand, ActionCommand]]):
actions: List[ActionCommand]
groups: List[GroupCommand]
diff --git a/airflow/cli/commands/config_command.py b/airflow/cli/commands/config_command.py
index 0ec624cffb4f3..2eedfa5cbc373 100644
--- a/airflow/cli/commands/config_command.py
+++ b/airflow/cli/commands/config_command.py
@@ -32,9 +32,7 @@ def show_config(args):
conf.write(output)
code = output.getvalue()
if should_use_colors(args):
- code = pygments.highlight(
- code=code, formatter=get_terminal_formatter(), lexer=IniLexer()
- )
+ code = pygments.highlight(code=code, formatter=get_terminal_formatter(), lexer=IniLexer())
print(code)
diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py
index 3a8d31cd5d155..6708806165692 100644
--- a/airflow/cli/commands/connection_command.py
+++ b/airflow/cli/commands/connection_command.py
@@ -47,7 +47,8 @@ def _tabulate_connection(conns: List[Connection], tablefmt: str):
'Is Encrypted': conn.is_encrypted,
'Is Extra Encrypted': conn.is_encrypted,
'Extra': conn.extra,
- } for conn in conns
+ }
+ for conn in conns
]
msg = tabulate(tabulate_data, tablefmt=tablefmt, headers='keys')
@@ -67,7 +68,7 @@ def _yamulate_connection(conn: Connection):
'Is Encrypted': conn.is_encrypted,
'Is Extra Encrypted': conn.is_encrypted,
'Extra': conn.extra_dejson,
- 'URI': conn.get_uri()
+ 'URI': conn.get_uri(),
}
return yaml.safe_dump(yaml_data, sort_keys=False)
@@ -150,8 +151,10 @@ def connections_export(args):
_, filetype = os.path.splitext(args.file.name)
filetype = filetype.lower()
if filetype not in allowed_formats:
- msg = f"Unsupported file format. " \
- f"The file must have the extension {', '.join(allowed_formats)}"
+ msg = (
+ f"Unsupported file format. "
+ f"The file must have the extension {', '.join(allowed_formats)}"
+ )
raise SystemExit(msg)
connections = session.query(Connection).order_by(Connection.conn_id).all()
@@ -164,8 +167,7 @@ def connections_export(args):
print(f"Connections successfully exported to {args.file.name}")
-alternative_conn_specs = ['conn_type', 'conn_host',
- 'conn_login', 'conn_password', 'conn_schema', 'conn_port']
+alternative_conn_specs = ['conn_type', 'conn_host', 'conn_login', 'conn_password', 'conn_schema', 'conn_port']
@cli_utils.action_logging
@@ -181,42 +183,51 @@ def connections_add(args):
elif not args.conn_type:
missing_args.append('conn-uri or conn-type')
if missing_args:
- msg = ('The following args are required to add a connection:' +
- f' {missing_args!r}')
+ msg = 'The following args are required to add a connection:' + f' {missing_args!r}'
raise SystemExit(msg)
if invalid_args:
- msg = ('The following args are not compatible with the ' +
- 'add flag and --conn-uri flag: {invalid!r}')
+ msg = 'The following args are not compatible with the ' + 'add flag and --conn-uri flag: {invalid!r}'
msg = msg.format(invalid=invalid_args)
raise SystemExit(msg)
if args.conn_uri:
new_conn = Connection(conn_id=args.conn_id, uri=args.conn_uri)
else:
- new_conn = Connection(conn_id=args.conn_id,
- conn_type=args.conn_type,
- host=args.conn_host,
- login=args.conn_login,
- password=args.conn_password,
- schema=args.conn_schema,
- port=args.conn_port)
+ new_conn = Connection(
+ conn_id=args.conn_id,
+ conn_type=args.conn_type,
+ host=args.conn_host,
+ login=args.conn_login,
+ password=args.conn_password,
+ schema=args.conn_schema,
+ port=args.conn_port,
+ )
if args.conn_extra is not None:
new_conn.set_extra(args.conn_extra)
with create_session() as session:
- if not (session.query(Connection)
- .filter(Connection.conn_id == new_conn.conn_id).first()):
+ if not session.query(Connection).filter(Connection.conn_id == new_conn.conn_id).first():
session.add(new_conn)
msg = '\n\tSuccessfully added `conn_id`={conn_id} : {uri}\n'
- msg = msg.format(conn_id=new_conn.conn_id,
- uri=args.conn_uri or
- urlunparse((args.conn_type,
- '{login}:{password}@{host}:{port}'
- .format(login=args.conn_login or '',
- password='******' if args.conn_password else '',
- host=args.conn_host or '',
- port=args.conn_port or ''),
- args.conn_schema or '', '', '', '')))
+ msg = msg.format(
+ conn_id=new_conn.conn_id,
+ uri=args.conn_uri
+ or urlunparse(
+ (
+ args.conn_type,
+ '{login}:{password}@{host}:{port}'.format(
+ login=args.conn_login or '',
+ password='******' if args.conn_password else '',
+ host=args.conn_host or '',
+ port=args.conn_port or '',
+ ),
+ args.conn_schema or '',
+ '',
+ '',
+ '',
+ )
+ ),
+ )
print(msg)
else:
msg = '\n\tA connection with `conn_id`={conn_id} already exists\n'
@@ -229,18 +240,14 @@ def connections_delete(args):
"""Deletes connection from DB"""
with create_session() as session:
try:
- to_delete = (session
- .query(Connection)
- .filter(Connection.conn_id == args.conn_id)
- .one())
+ to_delete = session.query(Connection).filter(Connection.conn_id == args.conn_id).one()
except exc.NoResultFound:
msg = '\n\tDid not find a connection with `conn_id`={conn_id}\n'
msg = msg.format(conn_id=args.conn_id)
print(msg)
return
except exc.MultipleResultsFound:
- msg = ('\n\tFound more than one connection with ' +
- '`conn_id`={conn_id}\n')
+ msg = '\n\tFound more than one connection with ' + '`conn_id`={conn_id}\n'
msg = msg.format(conn_id=args.conn_id)
print(msg)
return
diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py
index 456b5b9838ddf..a3d1fe1bfbff5 100644
--- a/airflow/cli/commands/dag_command.py
+++ b/airflow/cli/commands/dag_command.py
@@ -52,12 +52,10 @@ def _tabulate_dag_runs(dag_runs: List[DagRun], tablefmt: str = "fancy_grid") ->
'Execution date': dag_run.execution_date.isoformat(),
'Start date': dag_run.start_date.isoformat() if dag_run.start_date else '',
'End date': dag_run.end_date.isoformat() if dag_run.end_date else '',
- } for dag_run in dag_runs
- )
- return tabulate(
- tabular_data=tabulate_data,
- tablefmt=tablefmt
+ }
+ for dag_run in dag_runs
)
+ return tabulate(tabular_data=tabulate_data, tablefmt=tablefmt)
def _tabulate_dags(dags: List[DAG], tablefmt: str = "fancy_grid") -> str:
@@ -66,27 +64,25 @@ def _tabulate_dags(dags: List[DAG], tablefmt: str = "fancy_grid") -> str:
'DAG ID': dag.dag_id,
'Filepath': dag.filepath,
'Owner': dag.owner,
- } for dag in sorted(dags, key=lambda d: d.dag_id)
- )
- return tabulate(
- tabular_data=tabulate_data,
- tablefmt=tablefmt,
- headers='keys'
+ }
+ for dag in sorted(dags, key=lambda d: d.dag_id)
)
+ return tabulate(tabular_data=tabulate_data, tablefmt=tablefmt, headers='keys')
@cli_utils.action_logging
def dag_backfill(args, dag=None):
"""Creates backfill job or dry run for a DAG"""
- logging.basicConfig(
- level=settings.LOGGING_LEVEL,
- format=settings.SIMPLE_LOG_FORMAT)
+ logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)
signal.signal(signal.SIGTERM, sigint_handler)
import warnings
- warnings.warn('--ignore-first-depends-on-past is deprecated as the value is always set to True',
- category=PendingDeprecationWarning)
+
+ warnings.warn(
+ '--ignore-first-depends-on-past is deprecated as the value is always set to True',
+ category=PendingDeprecationWarning,
+ )
if args.ignore_first_depends_on_past is False:
args.ignore_first_depends_on_past = True
@@ -102,16 +98,15 @@ def dag_backfill(args, dag=None):
if args.task_regex:
dag = dag.partial_subset(
- task_ids_or_regex=args.task_regex,
- include_upstream=not args.ignore_dependencies)
+ task_ids_or_regex=args.task_regex, include_upstream=not args.ignore_dependencies
+ )
run_conf = None
if args.conf:
run_conf = json.loads(args.conf)
if args.dry_run:
- print("Dry run of DAG {} on {}".format(args.dag_id,
- args.start_date))
+ print(f"Dry run of DAG {args.dag_id} on {args.start_date}")
for task in dag.tasks:
print(f"Task {task.task_id}")
ti = TaskInstance(task, args.start_date)
@@ -132,8 +127,7 @@ def dag_backfill(args, dag=None):
end_date=args.end_date,
mark_success=args.mark_success,
local=args.local,
- donot_pickle=(args.donot_pickle or
- conf.getboolean('core', 'donot_pickle')),
+ donot_pickle=(args.donot_pickle or conf.getboolean('core', 'donot_pickle')),
ignore_first_depends_on_past=args.ignore_first_depends_on_past,
ignore_task_deps=args.ignore_dependencies,
pool=args.pool,
@@ -141,7 +135,7 @@ def dag_backfill(args, dag=None):
verbose=args.verbose,
conf=run_conf,
rerun_failed_tasks=args.rerun_failed_tasks,
- run_backwards=args.run_backwards
+ run_backwards=args.run_backwards,
)
@@ -150,10 +144,9 @@ def dag_trigger(args):
"""Creates a dag run for the specified dag"""
api_client = get_current_api_client()
try:
- message = api_client.trigger_dag(dag_id=args.dag_id,
- run_id=args.run_id,
- conf=args.conf,
- execution_date=args.exec_date)
+ message = api_client.trigger_dag(
+ dag_id=args.dag_id, run_id=args.run_id, conf=args.conf, execution_date=args.exec_date
+ )
print(message)
except OSError as err:
raise AirflowException(err)
@@ -163,9 +156,13 @@ def dag_trigger(args):
def dag_delete(args):
"""Deletes all DB records related to the specified dag"""
api_client = get_current_api_client()
- if args.yes or input(
- "This will drop all existing records related to the specified DAG. "
- "Proceed? (y/n)").upper() == "Y":
+ if (
+ args.yes
+ or input(
+ "This will drop all existing records related to the specified DAG. " "Proceed? (y/n)"
+ ).upper()
+ == "Y"
+ ):
try:
message = api_client.delete_dag(dag_id=args.dag_id)
print(message)
@@ -207,7 +204,7 @@ def dag_show(args):
print(
"Option --save and --imgcat are mutually exclusive. "
"Please remove one option to execute the command.",
- file=sys.stderr
+ file=sys.stderr,
)
sys.exit(1)
elif filename:
@@ -280,8 +277,11 @@ def dag_next_execution(args):
next_execution_dttm = dag.following_schedule(latest_execution_date)
if next_execution_dttm is None:
- print("[WARN] No following schedule can be found. " +
- "This DAG may have schedule interval '@once' or `None`.", file=sys.stderr)
+ print(
+ "[WARN] No following schedule can be found. "
+ + "This DAG may have schedule interval '@once' or `None`.",
+ file=sys.stderr,
+ )
print(None)
else:
print(next_execution_dttm)
@@ -327,17 +327,18 @@ def dag_list_jobs(args, dag=None):
queries.append(BaseJob.state == args.state)
with create_session() as session:
- all_jobs = (session
- .query(BaseJob)
- .filter(*queries)
- .order_by(BaseJob.start_date.desc())
- .limit(args.limit)
- .all())
+ all_jobs = (
+ session.query(BaseJob)
+ .filter(*queries)
+ .order_by(BaseJob.start_date.desc())
+ .limit(args.limit)
+ .all()
+ )
fields = ['dag_id', 'state', 'job_type', 'start_date', 'end_date']
all_jobs = [[job.__getattribute__(field) for field in fields] for job in all_jobs]
- msg = tabulate(all_jobs,
- [field.capitalize().replace('_', ' ') for field in fields],
- tablefmt=args.output)
+ msg = tabulate(
+ all_jobs, [field.capitalize().replace('_', ' ') for field in fields], tablefmt=args.output
+ )
print(msg)
@@ -367,10 +368,7 @@ def dag_list_dag_runs(args, dag=None):
return
dag_runs.sort(key=lambda x: x.execution_date, reverse=True)
- table = _tabulate_dag_runs(
- dag_runs,
- tablefmt=args.output
- )
+ table = _tabulate_dag_runs(dag_runs, tablefmt=args.output)
print(table)
@@ -389,10 +387,14 @@ def dag_test(args, session=None):
imgcat = args.imgcat_dagrun
filename = args.save_dagrun
if show_dagrun or imgcat or filename:
- tis = session.query(TaskInstance).filter(
- TaskInstance.dag_id == args.dag_id,
- TaskInstance.execution_date == args.execution_date,
- ).all()
+ tis = (
+ session.query(TaskInstance)
+ .filter(
+ TaskInstance.dag_id == args.dag_id,
+ TaskInstance.execution_date == args.execution_date,
+ )
+ .all()
+ )
dot_graph = render_dag(dag, tis=tis)
print()
diff --git a/airflow/cli/commands/db_command.py b/airflow/cli/commands/db_command.py
index 6ba5b375f3eb7..1b05bd0cce6af 100644
--- a/airflow/cli/commands/db_command.py
+++ b/airflow/cli/commands/db_command.py
@@ -35,9 +35,7 @@ def initdb(args):
def resetdb(args):
"""Resets the metadata database"""
print("DB: " + repr(settings.engine.url))
- if args.yes or input("This will drop existing tables "
- "if they exist. Proceed? "
- "(y/n)").upper() == "Y":
+ if args.yes or input("This will drop existing tables " "if they exist. Proceed? " "(y/n)").upper() == "Y":
db.resetdb()
else:
print("Cancelled")
@@ -63,14 +61,16 @@ def shell(args):
if url.get_backend_name() == 'mysql':
with NamedTemporaryFile(suffix="my.cnf") as f:
- content = textwrap.dedent(f"""
+ content = textwrap.dedent(
+ f"""
[client]
host = {url.host}
user = {url.username}
password = {url.password or ""}
port = {url.port or ""}
database = {url.database}
- """).strip()
+ """
+ ).strip()
f.write(content.encode())
f.flush()
execute_interactive(["mysql", f"--defaults-extra-file={f.name}"])
diff --git a/airflow/cli/commands/info_command.py b/airflow/cli/commands/info_command.py
index 3fcad451464a0..a60a8c2cfd48a 100644
--- a/airflow/cli/commands/info_command.py
+++ b/airflow/cli/commands/info_command.py
@@ -303,19 +303,17 @@ def __init__(self, anonymizer: Anonymizer):
@property
def task_logging_handler(self):
"""Returns task logging handler."""
+
def get_fullname(o):
module = o.__class__.__module__
if module is None or module == str.__class__.__module__:
return o.__class__.__name__ # Avoid reporting __builtin__
else:
return module + '.' + o.__class__.__name__
+
try:
- handler_names = [
- get_fullname(handler) for handler in logging.getLogger('airflow.task').handlers
- ]
- return ", ".join(
- handler_names
- )
+ handler_names = [get_fullname(handler) for handler in logging.getLogger('airflow.task').handlers]
+ return ", ".join(handler_names)
except Exception: # noqa pylint: disable=broad-except
return "NOT AVAILABLE"
diff --git a/airflow/cli/commands/legacy_commands.py b/airflow/cli/commands/legacy_commands.py
index b66cc6f575371..94f9b690327b2 100644
--- a/airflow/cli/commands/legacy_commands.py
+++ b/airflow/cli/commands/legacy_commands.py
@@ -44,7 +44,7 @@
"pool": "pools",
"list_users": "users list",
"create_user": "users create",
- "delete_user": "users delete"
+ "delete_user": "users delete",
}
diff --git a/airflow/cli/commands/pool_command.py b/airflow/cli/commands/pool_command.py
index d824bd40ee16b..4984618d8e026 100644
--- a/airflow/cli/commands/pool_command.py
+++ b/airflow/cli/commands/pool_command.py
@@ -27,8 +27,7 @@
def _tabulate_pools(pools, tablefmt="fancy_grid"):
- return "\n%s" % tabulate(pools, ['Pool', 'Slots', 'Description'],
- tablefmt=tablefmt)
+ return "\n%s" % tabulate(pools, ['Pool', 'Slots', 'Description'], tablefmt=tablefmt)
def pool_list(args):
@@ -49,9 +48,7 @@ def pool_get(args):
def pool_set(args):
"""Creates new pool with a given name and slots"""
api_client = get_current_api_client()
- pools = [api_client.create_pool(name=args.pool,
- slots=args.slots,
- description=args.description)]
+ pools = [api_client.create_pool(name=args.pool, slots=args.slots, description=args.description)]
print(_tabulate_pools(pools=pools, tablefmt=args.output))
@@ -97,9 +94,9 @@ def pool_import_helper(filepath):
counter = 0
for k, v in pools_json.items():
if isinstance(v, dict) and len(v) == 2:
- pools.append(api_client.create_pool(name=k,
- slots=v["slots"],
- description=v["description"]))
+ pools.append(
+ api_client.create_pool(name=k, slots=v["slots"], description=v["description"])
+ )
counter += 1
else:
pass
diff --git a/airflow/cli/commands/role_command.py b/airflow/cli/commands/role_command.py
index b4e6f591a1048..fb34d8c9a6003 100644
--- a/airflow/cli/commands/role_command.py
+++ b/airflow/cli/commands/role_command.py
@@ -29,9 +29,7 @@ def roles_list(args):
roles = appbuilder.sm.get_all_roles()
print("Existing roles:\n")
role_names = sorted([[r.name] for r in roles])
- msg = tabulate(role_names,
- headers=['Role'],
- tablefmt=args.output)
+ msg = tabulate(role_names, headers=['Role'], tablefmt=args.output)
print(msg)
diff --git a/airflow/cli/commands/rotate_fernet_key_command.py b/airflow/cli/commands/rotate_fernet_key_command.py
index 784eae86572c7..10c32d2a1fbaa 100644
--- a/airflow/cli/commands/rotate_fernet_key_command.py
+++ b/airflow/cli/commands/rotate_fernet_key_command.py
@@ -24,8 +24,7 @@
def rotate_fernet_key(args):
"""Rotates all encrypted connection credentials and variables"""
with create_session() as session:
- for conn in session.query(Connection).filter(
- Connection.is_encrypted | Connection.is_extra_encrypted):
+ for conn in session.query(Connection).filter(Connection.is_encrypted | Connection.is_extra_encrypted):
conn.rotate_fernet_key()
for var in session.query(Variable).filter(Variable.is_encrypted):
var.rotate_fernet_key()
diff --git a/airflow/cli/commands/scheduler_command.py b/airflow/cli/commands/scheduler_command.py
index f0f019a58ac12..100a0f1d0f775 100644
--- a/airflow/cli/commands/scheduler_command.py
+++ b/airflow/cli/commands/scheduler_command.py
@@ -34,14 +34,13 @@ def scheduler(args):
job = SchedulerJob(
subdir=process_subdir(args.subdir),
num_runs=args.num_runs,
- do_pickle=args.do_pickle)
+ do_pickle=args.do_pickle,
+ )
if args.daemon:
- pid, stdout, stderr, log_file = setup_locations("scheduler",
- args.pid,
- args.stdout,
- args.stderr,
- args.log_file)
+ pid, stdout, stderr, log_file = setup_locations(
+ "scheduler", args.pid, args.stdout, args.stderr, args.log_file
+ )
handle = setup_logging(log_file)
stdout = open(stdout, 'w+')
stderr = open(stderr, 'w+')
diff --git a/airflow/cli/commands/sync_perm_command.py b/airflow/cli/commands/sync_perm_command.py
index 3a31ea0df3ee3..b072e8de1016a 100644
--- a/airflow/cli/commands/sync_perm_command.py
+++ b/airflow/cli/commands/sync_perm_command.py
@@ -30,6 +30,4 @@ def sync_perm(args):
print('Updating permission on all DAG views')
dags = DagBag(read_dags_from_db=True).dags.values()
for dag in dags:
- appbuilder.sm.sync_perm_for_dag(
- dag.dag_id,
- dag.access_control)
+ appbuilder.sm.sync_perm_for_dag(dag.dag_id, dag.access_control)
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 46f670b5742c0..3fa1d1fb5a0fa 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -77,8 +77,7 @@ def _run_task_by_executor(args, dag, ti):
session.add(pickle)
pickle_id = pickle.id
# TODO: This should be written to a log
- print('Pickled dag {dag} as pickle_id: {pickle_id}'.format(
- dag=dag, pickle_id=pickle_id))
+ print(f'Pickled dag {dag} as pickle_id: {pickle_id}')
except Exception as e:
print('Could not pickle the DAG')
print(e)
@@ -94,7 +93,8 @@ def _run_task_by_executor(args, dag, ti):
ignore_depends_on_past=args.ignore_depends_on_past,
ignore_task_deps=args.ignore_dependencies,
ignore_ti_state=args.force,
- pool=args.pool)
+ pool=args.pool,
+ )
executor.heartbeat()
executor.end()
@@ -109,12 +109,16 @@ def _run_task_by_local_task_job(args, ti):
ignore_depends_on_past=args.ignore_depends_on_past,
ignore_task_deps=args.ignore_dependencies,
ignore_ti_state=args.force,
- pool=args.pool)
+ pool=args.pool,
+ )
run_job.run()
RAW_TASK_UNSUPPORTED_OPTION = [
- "ignore_all_dependencies", "ignore_depends_on_past", "ignore_dependencies", "force"
+ "ignore_all_dependencies",
+ "ignore_depends_on_past",
+ "ignore_dependencies",
+ "force",
]
@@ -182,8 +186,9 @@ def task_run(args, dag=None):
_run_task_by_selected_method(args, dag, ti)
else:
if settings.DONOT_MODIFY_HANDLERS:
- with redirect_stdout(StreamLogWriter(ti.log, logging.INFO)), \
- redirect_stderr(StreamLogWriter(ti.log, logging.WARN)):
+ with redirect_stdout(StreamLogWriter(ti.log, logging.INFO)), redirect_stderr(
+ StreamLogWriter(ti.log, logging.WARN)
+ ):
_run_task_by_selected_method(args, dag, ti)
else:
# Get all the Handlers from 'airflow.task' logger
@@ -201,8 +206,9 @@ def task_run(args, dag=None):
root_logger.addHandler(handler)
root_logger.setLevel(logging.getLogger('airflow.task').level)
- with redirect_stdout(StreamLogWriter(ti.log, logging.INFO)), \
- redirect_stderr(StreamLogWriter(ti.log, logging.WARN)):
+ with redirect_stdout(StreamLogWriter(ti.log, logging.INFO)), redirect_stderr(
+ StreamLogWriter(ti.log, logging.WARN)
+ ):
_run_task_by_selected_method(args, dag, ti)
# We need to restore the handlers to the loggers as celery worker process
@@ -300,15 +306,18 @@ def task_states_for_dag_run(args):
"""Get the status of all task instances in a DagRun"""
session = settings.Session()
- tis = session.query(
- TaskInstance.dag_id,
- TaskInstance.execution_date,
- TaskInstance.task_id,
- TaskInstance.state,
- TaskInstance.start_date,
- TaskInstance.end_date).filter(
- TaskInstance.dag_id == args.dag_id,
- TaskInstance.execution_date == args.execution_date).all()
+ tis = (
+ session.query(
+ TaskInstance.dag_id,
+ TaskInstance.execution_date,
+ TaskInstance.task_id,
+ TaskInstance.state,
+ TaskInstance.start_date,
+ TaskInstance.end_date,
+ )
+ .filter(TaskInstance.dag_id == args.dag_id, TaskInstance.execution_date == args.execution_date)
+ .all()
+ )
if len(tis) == 0:
raise AirflowException("DagRun does not exist.")
@@ -316,18 +325,18 @@ def task_states_for_dag_run(args):
formatted_rows = []
for ti in tis:
- formatted_rows.append((ti.dag_id,
- ti.execution_date,
- ti.task_id,
- ti.state,
- ti.start_date,
- ti.end_date))
+ formatted_rows.append(
+ (ti.dag_id, ti.execution_date, ti.task_id, ti.state, ti.start_date, ti.end_date)
+ )
print(
- "\n%s" %
- tabulate(
- formatted_rows, [
- 'dag', 'exec_date', 'task', 'state', 'start_date', 'end_date'], tablefmt=args.output))
+ "\n%s"
+ % tabulate(
+ formatted_rows,
+ ['dag', 'exec_date', 'task', 'state', 'start_date', 'end_date'],
+ tablefmt=args.output,
+ )
+ )
session.close()
@@ -387,20 +396,24 @@ def task_render(args):
ti = TaskInstance(task, args.execution_date)
ti.render_templates()
for attr in task.__class__.template_fields:
- print(textwrap.dedent("""\
+ print(
+ textwrap.dedent(
+ """\
# ----------------------------------------------------------
# property: {}
# ----------------------------------------------------------
{}
- """.format(attr, getattr(task, attr))))
+ """.format(
+ attr, getattr(task, attr)
+ )
+ )
+ )
@cli_utils.action_logging
def task_clear(args):
"""Clears all task instances or only those matched by regex for a DAG(s)"""
- logging.basicConfig(
- level=settings.LOGGING_LEVEL,
- format=settings.SIMPLE_LOG_FORMAT)
+ logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)
if args.dag_id and not args.subdir and not args.dag_regex and not args.task_regex:
dags = get_dag_by_file_location(args.dag_id)
@@ -413,7 +426,8 @@ def task_clear(args):
dags[idx] = dag.partial_subset(
task_ids_or_regex=args.task_regex,
include_downstream=args.downstream,
- include_upstream=args.upstream)
+ include_upstream=args.upstream,
+ )
DAG.clear_dags(
dags,
diff --git a/airflow/cli/commands/user_command.py b/airflow/cli/commands/user_command.py
index 90a8f2a1f35eb..7536ce6928260 100644
--- a/airflow/cli/commands/user_command.py
+++ b/airflow/cli/commands/user_command.py
@@ -36,8 +36,7 @@ def users_list(args):
users = appbuilder.sm.get_all_users()
fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles']
users = [[user.__getattribute__(field) for field in fields] for user in users]
- msg = tabulate(users, [field.capitalize().replace('_', ' ') for field in fields],
- tablefmt=args.output)
+ msg = tabulate(users, [field.capitalize().replace('_', ' ') for field in fields], tablefmt=args.output)
print(msg)
@@ -63,8 +62,7 @@ def users_create(args):
if appbuilder.sm.find_user(args.username):
print(f'{args.username} already exists in the db')
return
- user = appbuilder.sm.add_user(args.username, args.firstname, args.lastname,
- args.email, role, password)
+ user = appbuilder.sm.add_user(args.username, args.firstname, args.lastname, args.email, role, password)
if user:
print(f'{args.role} user {args.username} created.')
else:
@@ -77,8 +75,7 @@ def users_delete(args):
appbuilder = cached_app().appbuilder # pylint: disable=no-member
try:
- user = next(u for u in appbuilder.sm.get_all_users()
- if u.username == args.username)
+ user = next(u for u in appbuilder.sm.get_all_users() if u.username == args.username)
except StopIteration:
raise SystemExit(f'{args.username} is not a valid user.')
@@ -95,15 +92,12 @@ def users_manage_role(args, remove=False):
raise SystemExit('Missing args: must supply one of --username or --email')
if args.username and args.email:
- raise SystemExit('Conflicting args: must supply either --username'
- ' or --email, but not both')
+ raise SystemExit('Conflicting args: must supply either --username' ' or --email, but not both')
appbuilder = cached_app().appbuilder # pylint: disable=no-member
- user = (appbuilder.sm.find_user(username=args.username) or
- appbuilder.sm.find_user(email=args.email))
+ user = appbuilder.sm.find_user(username=args.username) or appbuilder.sm.find_user(email=args.email)
if not user:
- raise SystemExit('User "{}" does not exist'.format(
- args.username or args.email))
+ raise SystemExit('User "{}" does not exist'.format(args.username or args.email))
role = appbuilder.sm.find_role(args.role)
if not role:
@@ -114,24 +108,16 @@ def users_manage_role(args, remove=False):
if role in user.roles:
user.roles = [r for r in user.roles if r != role]
appbuilder.sm.update_user(user)
- print('User "{}" removed from role "{}".'.format(
- user,
- args.role))
+ print(f'User "{user}" removed from role "{args.role}".')
else:
- raise SystemExit('User "{}" is not a member of role "{}".'.format(
- user,
- args.role))
+ raise SystemExit(f'User "{user}" is not a member of role "{args.role}".')
else:
if role in user.roles:
- raise SystemExit('User "{}" is already a member of role "{}".'.format(
- user,
- args.role))
+ raise SystemExit(f'User "{user}" is already a member of role "{args.role}".')
else:
user.roles.append(role)
appbuilder.sm.update_user(user)
- print('User "{}" added to role "{}".'.format(
- user,
- args.role))
+ print(f'User "{user}" added to role "{args.role}".')
def users_export(args):
@@ -146,9 +132,12 @@ def remove_underscores(s):
return re.sub("_", "", s)
users = [
- {remove_underscores(field): user.__getattribute__(field)
- if field != 'roles' else [r.name for r in user.roles]
- for field in fields}
+ {
+ remove_underscores(field): user.__getattribute__(field)
+ if field != 'roles'
+ else [r.name for r in user.roles]
+ for field in fields
+ }
for user in users
]
@@ -175,12 +164,10 @@ def users_import(args):
users_created, users_updated = _import_users(users_list)
if users_created:
- print("Created the following users:\n\t{}".format(
- "\n\t".join(users_created)))
+ print("Created the following users:\n\t{}".format("\n\t".join(users_created)))
if users_updated:
- print("Updated the following users:\n\t{}".format(
- "\n\t".join(users_updated)))
+ print("Updated the following users:\n\t{}".format("\n\t".join(users_updated)))
def _import_users(users_list): # pylint: disable=redefined-outer-name
@@ -199,12 +186,10 @@ def _import_users(users_list): # pylint: disable=redefined-outer-name
else:
roles.append(role)
- required_fields = ['username', 'firstname', 'lastname',
- 'email', 'roles']
+ required_fields = ['username', 'firstname', 'lastname', 'email', 'roles']
for field in required_fields:
if not user.get(field):
- print("Error: '{}' is a required field, but was not "
- "specified".format(field))
+ print("Error: '{}' is a required field, but was not " "specified".format(field))
sys.exit(1)
existing_user = appbuilder.sm.find_user(email=user['email'])
@@ -215,9 +200,11 @@ def _import_users(users_list): # pylint: disable=redefined-outer-name
existing_user.last_name = user['lastname']
if existing_user.username != user['username']:
- print("Error: Changing the username is not allowed - "
- "please delete and recreate the user with "
- "email '{}'".format(user['email']))
+ print(
+ "Error: Changing the username is not allowed - "
+ "please delete and recreate the user with "
+ "email '{}'".format(user['email'])
+ )
sys.exit(1)
appbuilder.sm.update_user(existing_user)
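
A second aside, on the seemingly odd single-line results in the user_command.py hunks above such as 'Conflicting args: must supply either --username' ' or --email, but not both': Black re-wraps the physical lines of implicitly concatenated string literals but never merges the adjacent literals themselves. A minimal sketch, not part of the patch, under the same assumptions (Black around 20.8b1, 110-character lines):

import black

# Joined onto one line the call fits within 110 characters, so Black collapses
# the old two-line form, but the two adjacent string literals stay un-merged.
src = (
    'raise SystemExit("Conflicting args: must supply either --username"\n'
    '                 " or --email, but not both")\n'
)
print(black.format_str(src, mode=black.Mode(line_length=110)))
# Prints:
# raise SystemExit("Conflicting args: must supply either --username" " or --email, but not both")
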
diff --git a/airflow/cli/commands/variable_command.py b/airflow/cli/commands/variable_command.py
index 1781b3667625a..e76dc800eac35 100644
--- a/airflow/cli/commands/variable_command.py
+++ b/airflow/cli/commands/variable_command.py
@@ -37,17 +37,10 @@ def variables_get(args):
"""Displays variable by a given name"""
try:
if args.default is None:
- var = Variable.get(
- args.key,
- deserialize_json=args.json
- )
+ var = Variable.get(args.key, deserialize_json=args.json)
print(var)
else:
- var = Variable.get(
- args.key,
- deserialize_json=args.json,
- default_var=args.default
- )
+ var = Variable.get(args.key, deserialize_json=args.json, default_var=args.default)
print(var)
except (ValueError, KeyError) as e:
print(str(e), file=sys.stderr)
diff --git a/airflow/cli/commands/webserver_command.py b/airflow/cli/commands/webserver_command.py
index 5a4f7e24f986b..1b3f91308220c 100644
--- a/airflow/cli/commands/webserver_command.py
+++ b/airflow/cli/commands/webserver_command.py
@@ -87,7 +87,7 @@ def __init__(
master_timeout: int,
worker_refresh_interval: int,
worker_refresh_batch_size: int,
- reload_on_plugin_change: bool
+ reload_on_plugin_change: bool,
):
super().__init__()
self.gunicorn_master_proc = psutil.Process(gunicorn_master_pid)
@@ -152,9 +152,7 @@ def _wait_until_true(self, fn, timeout: int = 0) -> None:
start_time = time.time()
while not fn():
if 0 < timeout <= time.time() - start_time:
- raise AirflowWebServerTimeout(
- f"No response from gunicorn master within {timeout} seconds"
- )
+ raise AirflowWebServerTimeout(f"No response from gunicorn master within {timeout} seconds")
sleep(0.1)
def _spawn_new_workers(self, count: int) -> None:
@@ -170,7 +168,7 @@ def _spawn_new_workers(self, count: int) -> None:
excess += 1
self._wait_until_true(
lambda: self.num_workers_expected + excess == self._get_num_workers_running(),
- timeout=self.master_timeout
+ timeout=self.master_timeout,
)
def _kill_old_workers(self, count: int) -> None:
@@ -185,7 +183,8 @@ def _kill_old_workers(self, count: int) -> None:
self.gunicorn_master_proc.send_signal(signal.SIGTTOU)
self._wait_until_true(
lambda: self.num_workers_expected + count == self._get_num_workers_running(),
- timeout=self.master_timeout)
+ timeout=self.master_timeout,
+ )
def _reload_gunicorn(self) -> None:
"""
@@ -197,8 +196,7 @@ def _reload_gunicorn(self) -> None:
self.gunicorn_master_proc.send_signal(signal.SIGHUP)
sleep(1)
self._wait_until_true(
- lambda: self.num_workers_expected == self._get_num_workers_running(),
- timeout=self.master_timeout
+ lambda: self.num_workers_expected == self._get_num_workers_running(), timeout=self.master_timeout
)
def start(self) -> NoReturn:
@@ -206,7 +204,7 @@ def start(self) -> NoReturn:
try: # pylint: disable=too-many-nested-blocks
self._wait_until_true(
lambda: self.num_workers_expected == self._get_num_workers_running(),
- timeout=self.master_timeout
+ timeout=self.master_timeout,
)
while True:
if not self.gunicorn_master_proc.is_running():
@@ -232,7 +230,8 @@ def _check_workers(self) -> None:
if num_ready_workers_running < num_workers_running:
self.log.debug(
'[%d / %d] Some workers are starting up, waiting...',
- num_ready_workers_running, num_workers_running
+ num_ready_workers_running,
+ num_workers_running,
)
sleep(1)
return
@@ -251,9 +250,9 @@ def _check_workers(self) -> None:
# to increase number of workers
if num_workers_running < self.num_workers_expected:
self.log.error(
- "[%d / %d] Some workers seem to have died and gunicorn did not restart "
- "them as expected",
- num_ready_workers_running, num_workers_running
+ "[%d / %d] Some workers seem to have died and gunicorn did not restart " "them as expected",
+ num_ready_workers_running,
+ num_workers_running,
)
sleep(10)
num_workers_running = self._get_num_workers_running()
@@ -263,7 +262,9 @@ def _check_workers(self) -> None:
)
self.log.debug(
'[%d / %d] Spawning %d workers',
- num_ready_workers_running, num_workers_running, new_worker_count
+ num_ready_workers_running,
+ num_workers_running,
+ new_worker_count,
)
self._spawn_new_workers(num_workers_running)
return
@@ -273,12 +274,14 @@ def _check_workers(self) -> None:
# If workers should be restarted periodically.
if self.worker_refresh_interval > 0 and self._last_refresh_time:
# and we refreshed the workers a long time ago, refresh the workers
- last_refresh_diff = (time.time() - self._last_refresh_time)
+ last_refresh_diff = time.time() - self._last_refresh_time
if self.worker_refresh_interval < last_refresh_diff:
num_new_workers = self.worker_refresh_batch_size
self.log.debug(
'[%d / %d] Starting doing a refresh. Starting %d workers.',
- num_ready_workers_running, num_workers_running, num_new_workers
+ num_ready_workers_running,
+ num_workers_running,
+ num_new_workers,
)
self._spawn_new_workers(num_new_workers)
self._last_refresh_time = time.time()
@@ -293,14 +296,16 @@ def _check_workers(self) -> None:
self.log.debug(
'[%d / %d] Plugins folder changed. The gunicorn will be restarted the next time the '
'plugin directory is checked, if there is no change in it.',
- num_ready_workers_running, num_workers_running
+ num_ready_workers_running,
+ num_workers_running,
)
self._restart_on_next_plugin_check = True
self._last_plugin_state = new_state
elif self._restart_on_next_plugin_check:
self.log.debug(
'[%d / %d] Starts reloading the gunicorn configuration.',
- num_ready_workers_running, num_workers_running
+ num_ready_workers_running,
+ num_workers_running,
)
self._restart_on_next_plugin_check = False
self._last_refresh_time = time.time()
@@ -315,25 +320,24 @@ def webserver(args):
access_logfile = args.access_logfile or conf.get('webserver', 'access_logfile')
error_logfile = args.error_logfile or conf.get('webserver', 'error_logfile')
num_workers = args.workers or conf.get('webserver', 'workers')
- worker_timeout = (args.worker_timeout or
- conf.get('webserver', 'web_server_worker_timeout'))
+ worker_timeout = args.worker_timeout or conf.get('webserver', 'web_server_worker_timeout')
ssl_cert = args.ssl_cert or conf.get('webserver', 'web_server_ssl_cert')
ssl_key = args.ssl_key or conf.get('webserver', 'web_server_ssl_key')
if not ssl_cert and ssl_key:
- raise AirflowException(
- 'An SSL certificate must also be provided for use with ' + ssl_key)
+ raise AirflowException('An SSL certificate must also be provided for use with ' + ssl_key)
if ssl_cert and not ssl_key:
- raise AirflowException(
- 'An SSL key must also be provided for use with ' + ssl_cert)
+ raise AirflowException('An SSL key must also be provided for use with ' + ssl_cert)
if args.debug:
- print(
- "Starting the web server on port {} and host {}.".format(
- args.port, args.hostname))
+ print(f"Starting the web server on port {args.port} and host {args.hostname}.")
app = create_app(testing=conf.getboolean('core', 'unit_test_mode'))
- app.run(debug=True, use_reloader=not app.config['TESTING'],
- port=args.port, host=args.hostname,
- ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None)
+ app.run(
+ debug=True,
+ use_reloader=not app.config['TESTING'],
+ port=args.port,
+ host=args.hostname,
+ ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None,
+ )
else:
# This pre-warms the cache, and makes possible errors
# get reported earlier (i.e. before demonization)
@@ -342,33 +346,49 @@ def webserver(args):
os.environ.pop('SKIP_DAGS_PARSING')
pid_file, stdout, stderr, log_file = setup_locations(
- "webserver", args.pid, args.stdout, args.stderr, args.log_file)
+ "webserver", args.pid, args.stdout, args.stderr, args.log_file
+ )
# Check if webserver is already running if not, remove old pidfile
check_if_pidfile_process_is_running(pid_file=pid_file, process_name="webserver")
print(
- textwrap.dedent('''\
+ textwrap.dedent(
+ '''\
Running the Gunicorn Server with:
Workers: {num_workers} {workerclass}
Host: {hostname}:{port}
Timeout: {worker_timeout}
Logfiles: {access_logfile} {error_logfile}
=================================================================\
- '''.format(num_workers=num_workers, workerclass=args.workerclass,
- hostname=args.hostname, port=args.port,
- worker_timeout=worker_timeout, access_logfile=access_logfile,
- error_logfile=error_logfile)))
+ '''.format(
+ num_workers=num_workers,
+ workerclass=args.workerclass,
+ hostname=args.hostname,
+ port=args.port,
+ worker_timeout=worker_timeout,
+ access_logfile=access_logfile,
+ error_logfile=error_logfile,
+ )
+ )
+ )
run_args = [
'gunicorn',
- '--workers', str(num_workers),
- '--worker-class', str(args.workerclass),
- '--timeout', str(worker_timeout),
- '--bind', args.hostname + ':' + str(args.port),
- '--name', 'airflow-webserver',
- '--pid', pid_file,
- '--config', 'python:airflow.www.gunicorn_config',
+ '--workers',
+ str(num_workers),
+ '--worker-class',
+ str(args.workerclass),
+ '--timeout',
+ str(worker_timeout),
+ '--bind',
+ args.hostname + ':' + str(args.port),
+ '--name',
+ 'airflow-webserver',
+ '--pid',
+ pid_file,
+ '--config',
+ 'python:airflow.www.gunicorn_config',
]
if args.access_logfile:
diff --git a/airflow/config_templates/airflow_local_settings.py b/airflow/config_templates/airflow_local_settings.py
index 67c54447a9ef6..16b372cc3302c 100644
--- a/airflow/config_templates/airflow_local_settings.py
+++ b/airflow/config_templates/airflow_local_settings.py
@@ -58,19 +58,17 @@
'version': 1,
'disable_existing_loggers': False,
'formatters': {
- 'airflow': {
- 'format': LOG_FORMAT
- },
+ 'airflow': {'format': LOG_FORMAT},
'airflow_coloured': {
'format': COLORED_LOG_FORMAT if COLORED_LOG else LOG_FORMAT,
- 'class': COLORED_FORMATTER_CLASS if COLORED_LOG else 'logging.Formatter'
+ 'class': COLORED_FORMATTER_CLASS if COLORED_LOG else 'logging.Formatter',
},
},
'handlers': {
'console': {
'class': 'airflow.utils.log.logging_mixin.RedirectStdHandler',
'formatter': 'airflow_coloured',
- 'stream': 'sys.stdout'
+ 'stream': 'sys.stdout',
},
'task': {
'class': 'airflow.utils.log.file_task_handler.FileTaskHandler',
@@ -83,7 +81,7 @@
'formatter': 'airflow',
'base_log_folder': os.path.expanduser(PROCESSOR_LOG_FOLDER),
'filename_template': PROCESSOR_FILENAME_TEMPLATE,
- }
+ },
},
'loggers': {
'airflow.processor': {
@@ -100,19 +98,22 @@
'handler': ['console'],
'level': FAB_LOG_LEVEL,
'propagate': True,
- }
+ },
},
'root': {
'handlers': ['console'],
'level': LOG_LEVEL,
- }
+ },
}
EXTRA_LOGGER_NAMES: str = conf.get('logging', 'EXTRA_LOGGER_NAMES', fallback=None)
if EXTRA_LOGGER_NAMES:
new_loggers = {
logger_name.strip(): {
- 'handler': ['console'], 'level': LOG_LEVEL, 'propagate': True, }
+ 'handler': ['console'],
+ 'level': LOG_LEVEL,
+ 'propagate': True,
+ }
for logger_name in EXTRA_LOGGER_NAMES.split(",")
}
DEFAULT_LOGGING_CONFIG['loggers'].update(new_loggers)
@@ -125,7 +126,7 @@
'filename': DAG_PROCESSOR_MANAGER_LOG_LOCATION,
'mode': 'a',
'maxBytes': 104857600, # 100MB
- 'backupCount': 5
+ 'backupCount': 5,
}
},
'loggers': {
@@ -134,22 +135,21 @@
'level': LOG_LEVEL,
'propagate': False,
}
- }
+ },
}
# Only update the handlers and loggers when CONFIG_PROCESSOR_MANAGER_LOGGER is set.
# This is to avoid exceptions when initializing RotatingFileHandler multiple times
# in multiple processes.
if os.environ.get('CONFIG_PROCESSOR_MANAGER_LOGGER') == 'True':
- DEFAULT_LOGGING_CONFIG['handlers'] \
- .update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers'])
- DEFAULT_LOGGING_CONFIG['loggers'] \
- .update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['loggers'])
+ DEFAULT_LOGGING_CONFIG['handlers'].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers'])
+ DEFAULT_LOGGING_CONFIG['loggers'].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['loggers'])
# Manually create log directory for processor_manager handler as RotatingFileHandler
# will only create file but not the directory.
- processor_manager_handler_config: Dict[str, Any] = \
- DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers']['processor_manager']
+ processor_manager_handler_config: Dict[str, Any] = DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers'][
+ 'processor_manager'
+ ]
directory: str = os.path.dirname(processor_manager_handler_config['filename'])
Path(directory).mkdir(parents=True, exist_ok=True, mode=0o755)
@@ -204,7 +204,7 @@
'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)),
'gcs_log_folder': REMOTE_BASE_LOG_FOLDER,
'filename_template': FILENAME_TEMPLATE,
- 'gcp_key_path': key_path
+ 'gcp_key_path': key_path,
},
}
@@ -232,7 +232,7 @@
'class': 'airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverTaskHandler',
'formatter': 'airflow',
'name': log_name,
- 'gcp_key_path': key_path
+ 'gcp_key_path': key_path,
}
}
@@ -257,7 +257,7 @@
'frontend': ELASTICSEARCH_FRONTEND,
'write_stdout': ELASTICSEARCH_WRITE_STDOUT,
'json_format': ELASTICSEARCH_JSON_FORMAT,
- 'json_fields': ELASTICSEARCH_JSON_FIELDS
+ 'json_fields': ELASTICSEARCH_JSON_FIELDS,
},
}
@@ -266,4 +266,5 @@
raise AirflowException(
"Incorrect remote log configuration. Please check the configuration of option 'host' in "
"section 'elasticsearch' if you are using Elasticsearch. In the other case, "
- "'remote_base_log_folder' option in 'core' section.")
+ "'remote_base_log_folder' option in 'core' section."
+ )
diff --git a/airflow/config_templates/default_celery.py b/airflow/config_templates/default_celery.py
index b12bf14e13e10..14d4c0f2fc597 100644
--- a/airflow/config_templates/default_celery.py
+++ b/airflow/config_templates/default_celery.py
@@ -59,30 +59,43 @@ def _broker_supports_visibility_timeout(url):
try:
if celery_ssl_active:
if 'amqp://' in broker_url:
- broker_use_ssl = {'keyfile': conf.get('celery', 'SSL_KEY'),
- 'certfile': conf.get('celery', 'SSL_CERT'),
- 'ca_certs': conf.get('celery', 'SSL_CACERT'),
- 'cert_reqs': ssl.CERT_REQUIRED}
+ broker_use_ssl = {
+ 'keyfile': conf.get('celery', 'SSL_KEY'),
+ 'certfile': conf.get('celery', 'SSL_CERT'),
+ 'ca_certs': conf.get('celery', 'SSL_CACERT'),
+ 'cert_reqs': ssl.CERT_REQUIRED,
+ }
elif 'redis://' in broker_url:
- broker_use_ssl = {'ssl_keyfile': conf.get('celery', 'SSL_KEY'),
- 'ssl_certfile': conf.get('celery', 'SSL_CERT'),
- 'ssl_ca_certs': conf.get('celery', 'SSL_CACERT'),
- 'ssl_cert_reqs': ssl.CERT_REQUIRED}
+ broker_use_ssl = {
+ 'ssl_keyfile': conf.get('celery', 'SSL_KEY'),
+ 'ssl_certfile': conf.get('celery', 'SSL_CERT'),
+ 'ssl_ca_certs': conf.get('celery', 'SSL_CACERT'),
+ 'ssl_cert_reqs': ssl.CERT_REQUIRED,
+ }
else:
- raise AirflowException('The broker you configured does not support SSL_ACTIVE to be True. '
- 'Please use RabbitMQ or Redis if you would like to use SSL for broker.')
+ raise AirflowException(
+ 'The broker you configured does not support SSL_ACTIVE to be True. '
+ 'Please use RabbitMQ or Redis if you would like to use SSL for broker.'
+ )
DEFAULT_CELERY_CONFIG['broker_use_ssl'] = broker_use_ssl
except AirflowConfigException:
- raise AirflowException('AirflowConfigException: SSL_ACTIVE is True, '
- 'please ensure SSL_KEY, '
- 'SSL_CERT and SSL_CACERT are set')
+ raise AirflowException(
+ 'AirflowConfigException: SSL_ACTIVE is True, '
+ 'please ensure SSL_KEY, '
+ 'SSL_CERT and SSL_CACERT are set'
+ )
except Exception as e:
- raise AirflowException('Exception: There was an unknown Celery SSL Error. '
- 'Please ensure you want to use '
- 'SSL and/or have all necessary certs and key ({}).'.format(e))
+ raise AirflowException(
+ 'Exception: There was an unknown Celery SSL Error. '
+ 'Please ensure you want to use '
+ 'SSL and/or have all necessary certs and key ({}).'.format(e)
+ )
result_backend = DEFAULT_CELERY_CONFIG['result_backend']
if 'amqp://' in result_backend or 'redis://' in result_backend or 'rpc://' in result_backend:
- log.warning("You have configured a result_backend of %s, it is highly recommended "
- "to use an alternative result_backend (i.e. a database).", result_backend)
+ log.warning(
+ "You have configured a result_backend of %s, it is highly recommended "
+ "to use an alternative result_backend (i.e. a database).",
+ result_backend,
+ )
diff --git a/airflow/configuration.py b/airflow/configuration.py
index bc4b041a3e50e..b5d1aa2ec2265 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -28,6 +28,7 @@
import warnings
from base64 import b64encode
from collections import OrderedDict
+
# Ignored Mypy on configparser because it thinks the configparser module has no _UNSET attribute
from configparser import _UNSET, ConfigParser, NoOptionError, NoSectionError # type: ignore
from json.decoder import JSONDecodeError
diff --git a/airflow/contrib/__init__.py b/airflow/contrib/__init__.py
index e852a5918ebb4..3a89862127ebc 100644
--- a/airflow/contrib/__init__.py
+++ b/airflow/contrib/__init__.py
@@ -19,6 +19,4 @@
import warnings
-warnings.warn(
- "This module is deprecated.", DeprecationWarning, stacklevel=2
-)
+warnings.warn("This module is deprecated.", DeprecationWarning, stacklevel=2)
diff --git a/airflow/contrib/hooks/aws_dynamodb_hook.py b/airflow/contrib/hooks/aws_dynamodb_hook.py
index 9ea6e1d29121a..fc6bd1b79c4cc 100644
--- a/airflow/contrib/hooks/aws_dynamodb_hook.py
+++ b/airflow/contrib/hooks/aws_dynamodb_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.dynamodb`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/aws_glue_catalog_hook.py b/airflow/contrib/hooks/aws_glue_catalog_hook.py
index f36820e2196e4..59733ea44b320 100644
--- a/airflow/contrib/hooks/aws_glue_catalog_hook.py
+++ b/airflow/contrib/hooks/aws_glue_catalog_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.glue_catalog`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py
index 55bd95139e70e..c18577656fcbd 100644
--- a/airflow/contrib/hooks/aws_hook.py
+++ b/airflow/contrib/hooks/aws_hook.py
@@ -24,7 +24,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.base_aws`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -34,6 +35,7 @@ class AwsHook(AwsBaseHook):
def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. Please use `airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
super().__init__(*args, **kwargs)
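
One more aside, covering the warnings.warn() hunks repeated through these airflow/contrib shim modules: when the collapsed call would exceed the line limit, Black keeps it split, gives each argument its own line, and appends a trailing comma, which is the shape of every deprecation warning in this part of the patch. A minimal sketch under the same assumptions:

import black

# Joined onto one line this call is well over 110 characters, so Black explodes
# the argument list one per line and adds the magic trailing comma.
src = (
    "warnings.warn(\n"
    '    "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.base_aws`.",\n'
    "    DeprecationWarning, stacklevel=2\n"
    ")\n"
)
print(black.format_str(src, mode=black.Mode(line_length=110)))
# Prints:
# warnings.warn(
#     "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.base_aws`.",
#     DeprecationWarning,
#     stacklevel=2,
# )
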
diff --git a/airflow/contrib/hooks/aws_logs_hook.py b/airflow/contrib/hooks/aws_logs_hook.py
index d2b3446aad3a8..ab27562f55624 100644
--- a/airflow/contrib/hooks/aws_logs_hook.py
+++ b/airflow/contrib/hooks/aws_logs_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.logs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/azure_container_instance_hook.py b/airflow/contrib/hooks/azure_container_instance_hook.py
index eaa0ff6cfd054..eaa9f091752b4 100644
--- a/airflow/contrib/hooks/azure_container_instance_hook.py
+++ b/airflow/contrib/hooks/azure_container_instance_hook.py
@@ -30,5 +30,6 @@
warnings.warn(
"This module is deprecated. "
"Please use `airflow.providers.microsoft.azure.hooks.azure_container_instance`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/azure_container_registry_hook.py b/airflow/contrib/hooks/azure_container_registry_hook.py
index 22260c49bf31a..fe9cfdc488e70 100644
--- a/airflow/contrib/hooks/azure_container_registry_hook.py
+++ b/airflow/contrib/hooks/azure_container_registry_hook.py
@@ -30,5 +30,6 @@
warnings.warn(
"This module is deprecated. "
"Please use `airflow.providers.microsoft.azure.hooks.azure_container_registry`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/azure_container_volume_hook.py b/airflow/contrib/hooks/azure_container_volume_hook.py
index 2547a5fca3fdb..4c747cadd3472 100644
--- a/airflow/contrib/hooks/azure_container_volume_hook.py
+++ b/airflow/contrib/hooks/azure_container_volume_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.azure_container_volume`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/azure_cosmos_hook.py b/airflow/contrib/hooks/azure_cosmos_hook.py
index 35ef4aea403c1..d39766935f7e1 100644
--- a/airflow/contrib/hooks/azure_cosmos_hook.py
+++ b/airflow/contrib/hooks/azure_cosmos_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.azure_cosmos`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/azure_data_lake_hook.py b/airflow/contrib/hooks/azure_data_lake_hook.py
index d11836f490178..0e21185e8eda6 100644
--- a/airflow/contrib/hooks/azure_data_lake_hook.py
+++ b/airflow/contrib/hooks/azure_data_lake_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.azure_data_lake`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/azure_fileshare_hook.py b/airflow/contrib/hooks/azure_fileshare_hook.py
index e166aa995bfe9..b4a9f99e39d34 100644
--- a/airflow/contrib/hooks/azure_fileshare_hook.py
+++ b/airflow/contrib/hooks/azure_fileshare_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.azure_fileshare`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py
index 6c9a95baa6524..6b95898c76d63 100644
--- a/airflow/contrib/hooks/bigquery_hook.py
+++ b/airflow/contrib/hooks/bigquery_hook.py
@@ -21,11 +21,16 @@
# pylint: disable=unused-import
from airflow.providers.google.cloud.hooks.bigquery import ( # noqa
- BigQueryBaseCursor, BigQueryConnection, BigQueryCursor, BigQueryHook, BigQueryPandasConnector,
+ BigQueryBaseCursor,
+ BigQueryConnection,
+ BigQueryCursor,
+ BigQueryHook,
+ BigQueryPandasConnector,
GbqConnector,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.bigquery`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/cassandra_hook.py b/airflow/contrib/hooks/cassandra_hook.py
index c0002862ae264..223f15b237223 100644
--- a/airflow/contrib/hooks/cassandra_hook.py
+++ b/airflow/contrib/hooks/cassandra_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.cassandra.hooks.cassandra`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/cloudant_hook.py b/airflow/contrib/hooks/cloudant_hook.py
index 453dd36a99abd..4d1da6c987ed5 100644
--- a/airflow/contrib/hooks/cloudant_hook.py
+++ b/airflow/contrib/hooks/cloudant_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.cloudant.hooks.cloudant`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py
index 3c610e316c21a..b2f177983294b 100644
--- a/airflow/contrib/hooks/databricks_hook.py
+++ b/airflow/contrib/hooks/databricks_hook.py
@@ -21,12 +21,21 @@
# pylint: disable=unused-import
from airflow.providers.databricks.hooks.databricks import ( # noqa
- CANCEL_RUN_ENDPOINT, GET_RUN_ENDPOINT, RESTART_CLUSTER_ENDPOINT, RUN_LIFE_CYCLE_STATES, RUN_NOW_ENDPOINT,
- START_CLUSTER_ENDPOINT, SUBMIT_RUN_ENDPOINT, TERMINATE_CLUSTER_ENDPOINT, USER_AGENT_HEADER,
- DatabricksHook, RunState,
+ CANCEL_RUN_ENDPOINT,
+ GET_RUN_ENDPOINT,
+ RESTART_CLUSTER_ENDPOINT,
+ RUN_LIFE_CYCLE_STATES,
+ RUN_NOW_ENDPOINT,
+ START_CLUSTER_ENDPOINT,
+ SUBMIT_RUN_ENDPOINT,
+ TERMINATE_CLUSTER_ENDPOINT,
+ USER_AGENT_HEADER,
+ DatabricksHook,
+ RunState,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.databricks.hooks.databricks`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/datadog_hook.py b/airflow/contrib/hooks/datadog_hook.py
index 4746724e0d2ce..3751127db8393 100644
--- a/airflow/contrib/hooks/datadog_hook.py
+++ b/airflow/contrib/hooks/datadog_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.datadog.hooks.datadog`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/datastore_hook.py b/airflow/contrib/hooks/datastore_hook.py
index 2d395e10951fc..3a072eca543c0 100644
--- a/airflow/contrib/hooks/datastore_hook.py
+++ b/airflow/contrib/hooks/datastore_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.datastore`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/dingding_hook.py b/airflow/contrib/hooks/dingding_hook.py
index 5a4aaf78ed140..66f9c2954aa60 100644
--- a/airflow/contrib/hooks/dingding_hook.py
+++ b/airflow/contrib/hooks/dingding_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.dingding.hooks.dingding`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/discord_webhook_hook.py b/airflow/contrib/hooks/discord_webhook_hook.py
index 8de11e0016a18..e6c3cd00913f0 100644
--- a/airflow/contrib/hooks/discord_webhook_hook.py
+++ b/airflow/contrib/hooks/discord_webhook_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.discord.hooks.discord_webhook`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/emr_hook.py b/airflow/contrib/hooks/emr_hook.py
index a8cadcb6f727a..0ff71d4af9fba 100644
--- a/airflow/contrib/hooks/emr_hook.py
+++ b/airflow/contrib/hooks/emr_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.emr`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/fs_hook.py b/airflow/contrib/hooks/fs_hook.py
index 93adafda16209..73d409136a2e8 100644
--- a/airflow/contrib/hooks/fs_hook.py
+++ b/airflow/contrib/hooks/fs_hook.py
@@ -23,6 +23,5 @@
from airflow.hooks.filesystem import FSHook # noqa
warnings.warn(
- "This module is deprecated. Please use `airflow.hooks.filesystem`.",
- DeprecationWarning, stacklevel=2
+ "This module is deprecated. Please use `airflow.hooks.filesystem`.", DeprecationWarning, stacklevel=2
)
diff --git a/airflow/contrib/hooks/ftp_hook.py b/airflow/contrib/hooks/ftp_hook.py
index afe236bed01a8..bdb47b68c3c1d 100644
--- a/airflow/contrib/hooks/ftp_hook.py
+++ b/airflow/contrib/hooks/ftp_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.ftp.hooks.ftp`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/gcp_api_base_hook.py b/airflow/contrib/hooks/gcp_api_base_hook.py
index ca173f44955bb..468d7a3699b70 100644
--- a/airflow/contrib/hooks/gcp_api_base_hook.py
+++ b/airflow/contrib/hooks/gcp_api_base_hook.py
@@ -22,7 +22,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.common.hooks.base_google`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -36,6 +37,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. Please use "
"`airflow.providers.google.common.hooks.base_google.GoogleBaseHook`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_bigtable_hook.py b/airflow/contrib/hooks/gcp_bigtable_hook.py
index 2926f5f0d4f4c..372293070a9eb 100644
--- a/airflow/contrib/hooks/gcp_bigtable_hook.py
+++ b/airflow/contrib/hooks/gcp_bigtable_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.bigtable`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/gcp_cloud_build_hook.py b/airflow/contrib/hooks/gcp_cloud_build_hook.py
index 79009cd3ebc42..34535622a139e 100644
--- a/airflow/contrib/hooks/gcp_cloud_build_hook.py
+++ b/airflow/contrib/hooks/gcp_cloud_build_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.cloud_build`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/gcp_compute_hook.py b/airflow/contrib/hooks/gcp_compute_hook.py
index 84e2e4f7f59e7..f24f0bd4be6c4 100644
--- a/airflow/contrib/hooks/gcp_compute_hook.py
+++ b/airflow/contrib/hooks/gcp_compute_hook.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use airflow.providers.google.cloud.hooks.compute`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -36,7 +37,8 @@ class GceHook(ComputeEngineHook):
def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. Please use `airflow.providers.google.cloud.hooks.compute`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_container_hook.py b/airflow/contrib/hooks/gcp_container_hook.py
index 5d8d900a90d4d..d014a440eaa28 100644
--- a/airflow/contrib/hooks/gcp_container_hook.py
+++ b/airflow/contrib/hooks/gcp_container_hook.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.kubernetes_engine`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -33,6 +34,7 @@ class GKEClusterHook(GKEHook):
def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. Please use `airflow.providers.google.cloud.hooks.container.GKEHook`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_dataflow_hook.py b/airflow/contrib/hooks/gcp_dataflow_hook.py
index 54080c513fafa..8e1502325a705 100644
--- a/airflow/contrib/hooks/gcp_dataflow_hook.py
+++ b/airflow/contrib/hooks/gcp_dataflow_hook.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.dataflow`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -34,6 +35,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. "
"Please use `airflow.providers.google.cloud.hooks.dataflow.DataflowHook`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_dataproc_hook.py b/airflow/contrib/hooks/gcp_dataproc_hook.py
index 75c6df88a9d33..54f618e41a80a 100644
--- a/airflow/contrib/hooks/gcp_dataproc_hook.py
+++ b/airflow/contrib/hooks/gcp_dataproc_hook.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.dataproc`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,7 +38,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.hooks.dataproc.DataprocHook`.""",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_dlp_hook.py b/airflow/contrib/hooks/gcp_dlp_hook.py
index cb1fbfc74637b..bf496e4f93a98 100644
--- a/airflow/contrib/hooks/gcp_dlp_hook.py
+++ b/airflow/contrib/hooks/gcp_dlp_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.dlp`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/gcp_function_hook.py b/airflow/contrib/hooks/gcp_function_hook.py
index d684d8d451658..d728fc98dfc81 100644
--- a/airflow/contrib/hooks/gcp_function_hook.py
+++ b/airflow/contrib/hooks/gcp_function_hook.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.functions`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,7 +38,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. "
"Please use `airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_kms_hook.py b/airflow/contrib/hooks/gcp_kms_hook.py
index 6bd57399b4b38..7fce040de5031 100644
--- a/airflow/contrib/hooks/gcp_kms_hook.py
+++ b/airflow/contrib/hooks/gcp_kms_hook.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.kms`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -33,6 +34,7 @@ class GoogleCloudKMSHook(CloudKMSHook):
def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. Please use `airflow.providers.google.cloud.hooks.kms.CloudKMSHook`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_mlengine_hook.py b/airflow/contrib/hooks/gcp_mlengine_hook.py
index 07687c0b189ca..cec3f31de694b 100644
--- a/airflow/contrib/hooks/gcp_mlengine_hook.py
+++ b/airflow/contrib/hooks/gcp_mlengine_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.mlengine`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/gcp_natural_language_hook.py b/airflow/contrib/hooks/gcp_natural_language_hook.py
index 70f9405c01e2a..0243ff6c7a7ee 100644
--- a/airflow/contrib/hooks/gcp_natural_language_hook.py
+++ b/airflow/contrib/hooks/gcp_natural_language_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.natural_language`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/gcp_pubsub_hook.py b/airflow/contrib/hooks/gcp_pubsub_hook.py
index e04f3cd6712e1..23f7ab0868329 100644
--- a/airflow/contrib/hooks/gcp_pubsub_hook.py
+++ b/airflow/contrib/hooks/gcp_pubsub_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.pubsub`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/gcp_speech_to_text_hook.py b/airflow/contrib/hooks/gcp_speech_to_text_hook.py
index dedece29b11c3..5b27c779c8b6b 100644
--- a/airflow/contrib/hooks/gcp_speech_to_text_hook.py
+++ b/airflow/contrib/hooks/gcp_speech_to_text_hook.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.speech_to_text`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,7 +38,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. "
"Please use `airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_tasks_hook.py b/airflow/contrib/hooks/gcp_tasks_hook.py
index e577de66e6130..837c210c8b1b9 100644
--- a/airflow/contrib/hooks/gcp_tasks_hook.py
+++ b/airflow/contrib/hooks/gcp_tasks_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.tasks`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/gcp_text_to_speech_hook.py b/airflow/contrib/hooks/gcp_text_to_speech_hook.py
index 593502f9ab7a1..8a7c8fb23bf77 100644
--- a/airflow/contrib/hooks/gcp_text_to_speech_hook.py
+++ b/airflow/contrib/hooks/gcp_text_to_speech_hook.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.text_to_speech`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,7 +38,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. "
"Please use `airflow.providers.google.cloud.hooks.text_to_speech.CloudTextToSpeechHook`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_transfer_hook.py b/airflow/contrib/hooks/gcp_transfer_hook.py
index 535c24fe2cf06..57c098d52b7f9 100644
--- a/airflow/contrib/hooks/gcp_transfer_hook.py
+++ b/airflow/contrib/hooks/gcp_transfer_hook.py
@@ -27,7 +27,8 @@
warnings.warn(
"This module is deprecated. "
"Please use `airflow.providers.google.cloud.hooks.cloud_storage_transfer_service`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -43,6 +44,7 @@ def __init__(self, *args, **kwargs):
Please use
`airflow.providers.google.cloud.hooks.cloud_storage_transfer_service
.CloudDataTransferServiceHook`.""",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_translate_hook.py b/airflow/contrib/hooks/gcp_translate_hook.py
index 016f8efc6175c..a0b1788ec8dc0 100644
--- a/airflow/contrib/hooks/gcp_translate_hook.py
+++ b/airflow/contrib/hooks/gcp_translate_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.translate`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/gcp_video_intelligence_hook.py b/airflow/contrib/hooks/gcp_video_intelligence_hook.py
index 76ad51ef68ceb..9a99748ea9452 100644
--- a/airflow/contrib/hooks/gcp_video_intelligence_hook.py
+++ b/airflow/contrib/hooks/gcp_video_intelligence_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.video_intelligence`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/gcp_vision_hook.py b/airflow/contrib/hooks/gcp_vision_hook.py
index 05ac86a5e86f2..98ac270ba1144 100644
--- a/airflow/contrib/hooks/gcp_vision_hook.py
+++ b/airflow/contrib/hooks/gcp_vision_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.vision`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/gcs_hook.py b/airflow/contrib/hooks/gcs_hook.py
index 625e1aa3313c7..7134f327e5f6a 100644
--- a/airflow/contrib/hooks/gcs_hook.py
+++ b/airflow/contrib/hooks/gcs_hook.py
@@ -22,7 +22,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.hooks.gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -32,6 +33,7 @@ class GoogleCloudStorageHook(GCSHook):
def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. Please use `airflow.providers.google.cloud.hooks.gcs.GCSHook`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gdrive_hook.py b/airflow/contrib/hooks/gdrive_hook.py
index b2deb96529a83..c6d8172e7bb49 100644
--- a/airflow/contrib/hooks/gdrive_hook.py
+++ b/airflow/contrib/hooks/gdrive_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.suite.hooks.drive`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/grpc_hook.py b/airflow/contrib/hooks/grpc_hook.py
index 7f89d401169ed..631b47d418b21 100644
--- a/airflow/contrib/hooks/grpc_hook.py
+++ b/airflow/contrib/hooks/grpc_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.grpc.hooks.grpc`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/imap_hook.py b/airflow/contrib/hooks/imap_hook.py
index 12f0ab1306c98..c868dfa1ac968 100644
--- a/airflow/contrib/hooks/imap_hook.py
+++ b/airflow/contrib/hooks/imap_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.imap.hooks.imap`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/jenkins_hook.py b/airflow/contrib/hooks/jenkins_hook.py
index 72dabb6fbb38e..1e260e4595cbe 100644
--- a/airflow/contrib/hooks/jenkins_hook.py
+++ b/airflow/contrib/hooks/jenkins_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.jenkins.hooks.jenkins`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/mongo_hook.py b/airflow/contrib/hooks/mongo_hook.py
index db25b1131f580..ac5a3a0b753b0 100644
--- a/airflow/contrib/hooks/mongo_hook.py
+++ b/airflow/contrib/hooks/mongo_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.mongo.hooks.mongo`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/openfaas_hook.py b/airflow/contrib/hooks/openfaas_hook.py
index 4ea87488189a7..32ab77ced95d6 100644
--- a/airflow/contrib/hooks/openfaas_hook.py
+++ b/airflow/contrib/hooks/openfaas_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.openfaas.hooks.openfaas`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/opsgenie_alert_hook.py b/airflow/contrib/hooks/opsgenie_alert_hook.py
index 5a6b338cc996f..9bf46057a6ec9 100644
--- a/airflow/contrib/hooks/opsgenie_alert_hook.py
+++ b/airflow/contrib/hooks/opsgenie_alert_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.opsgenie.hooks.opsgenie_alert`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/pagerduty_hook.py b/airflow/contrib/hooks/pagerduty_hook.py
index cfce5e86160b6..0f1bbdd0943e2 100644
--- a/airflow/contrib/hooks/pagerduty_hook.py
+++ b/airflow/contrib/hooks/pagerduty_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.pagerduty.hooks.pagerduty`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/pinot_hook.py b/airflow/contrib/hooks/pinot_hook.py
index 5ecfbfa35dfc3..d3b9816c3ffbb 100644
--- a/airflow/contrib/hooks/pinot_hook.py
+++ b/airflow/contrib/hooks/pinot_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.pinot.hooks.pinot`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/qubole_check_hook.py b/airflow/contrib/hooks/qubole_check_hook.py
index 03fed8c1e6da6..8df277d17c641 100644
--- a/airflow/contrib/hooks/qubole_check_hook.py
+++ b/airflow/contrib/hooks/qubole_check_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.qubole.hooks.qubole_check`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/qubole_hook.py b/airflow/contrib/hooks/qubole_hook.py
index ab2e6cd8b04e6..aee57ba675ada 100644
--- a/airflow/contrib/hooks/qubole_hook.py
+++ b/airflow/contrib/hooks/qubole_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.qubole.hooks.qubole`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/redis_hook.py b/airflow/contrib/hooks/redis_hook.py
index 7938f57a4ac2b..d2bd8aba1b3af 100644
--- a/airflow/contrib/hooks/redis_hook.py
+++ b/airflow/contrib/hooks/redis_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.redis.hooks.redis`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/sagemaker_hook.py b/airflow/contrib/hooks/sagemaker_hook.py
index c2f60125a815f..8d33c27359c23 100644
--- a/airflow/contrib/hooks/sagemaker_hook.py
+++ b/airflow/contrib/hooks/sagemaker_hook.py
@@ -21,11 +21,16 @@
# pylint: disable=unused-import
from airflow.providers.amazon.aws.hooks.sagemaker import ( # noqa
- LogState, Position, SageMakerHook, argmin, secondary_training_status_changed,
+ LogState,
+ Position,
+ SageMakerHook,
+ argmin,
+ secondary_training_status_changed,
secondary_training_status_message,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.sagemaker`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/salesforce_hook.py b/airflow/contrib/hooks/salesforce_hook.py
index 5ea627a0b0674..d1462e1c490b8 100644
--- a/airflow/contrib/hooks/salesforce_hook.py
+++ b/airflow/contrib/hooks/salesforce_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.salesforce.hooks.salesforce`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/segment_hook.py b/airflow/contrib/hooks/segment_hook.py
index c286fbe405d6b..7e3247c1018b2 100644
--- a/airflow/contrib/hooks/segment_hook.py
+++ b/airflow/contrib/hooks/segment_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.segment.hooks.segment`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/slack_webhook_hook.py b/airflow/contrib/hooks/slack_webhook_hook.py
index 420a34bf7a90a..7e4ec92d73157 100644
--- a/airflow/contrib/hooks/slack_webhook_hook.py
+++ b/airflow/contrib/hooks/slack_webhook_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.slack.hooks.slack_webhook`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/spark_jdbc_hook.py b/airflow/contrib/hooks/spark_jdbc_hook.py
index af3ed2c0149bf..189e5bf65c6cb 100644
--- a/airflow/contrib/hooks/spark_jdbc_hook.py
+++ b/airflow/contrib/hooks/spark_jdbc_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.spark.hooks.spark_jdbc`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/spark_sql_hook.py b/airflow/contrib/hooks/spark_sql_hook.py
index 760c0e7808e82..639fae1392c9f 100644
--- a/airflow/contrib/hooks/spark_sql_hook.py
+++ b/airflow/contrib/hooks/spark_sql_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.spark.hooks.spark_sql`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py
index d067ba7f186ad..5c69e48c552fe 100644
--- a/airflow/contrib/hooks/spark_submit_hook.py
+++ b/airflow/contrib/hooks/spark_submit_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.spark.hooks.spark_submit`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/sqoop_hook.py b/airflow/contrib/hooks/sqoop_hook.py
index ccf1df3948929..ce5c737371a54 100644
--- a/airflow/contrib/hooks/sqoop_hook.py
+++ b/airflow/contrib/hooks/sqoop_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.sqoop.hooks.sqoop`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/ssh_hook.py b/airflow/contrib/hooks/ssh_hook.py
index 2f1b02e7c5295..d374339390f3d 100644
--- a/airflow/contrib/hooks/ssh_hook.py
+++ b/airflow/contrib/hooks/ssh_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.ssh.hooks.ssh`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/vertica_hook.py b/airflow/contrib/hooks/vertica_hook.py
index 9a38a496d2aba..592fccc2b3dd2 100644
--- a/airflow/contrib/hooks/vertica_hook.py
+++ b/airflow/contrib/hooks/vertica_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.vertica.hooks.vertica`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/wasb_hook.py b/airflow/contrib/hooks/wasb_hook.py
index 6cb1f66ad6dba..a1b62ad539bb4 100644
--- a/airflow/contrib/hooks/wasb_hook.py
+++ b/airflow/contrib/hooks/wasb_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.wasb`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/hooks/winrm_hook.py b/airflow/contrib/hooks/winrm_hook.py
index c45c825404d04..24e5abec28536 100644
--- a/airflow/contrib/hooks/winrm_hook.py
+++ b/airflow/contrib/hooks/winrm_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.winrm.hooks.winrm`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/adls_list_operator.py b/airflow/contrib/operators/adls_list_operator.py
index f6643ecdf65c3..1572a22d72d22 100644
--- a/airflow/contrib/operators/adls_list_operator.py
+++ b/airflow/contrib/operators/adls_list_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.adls_list`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/adls_to_gcs.py b/airflow/contrib/operators/adls_to_gcs.py
index 0b273ca93a275..7839b4750728b 100644
--- a/airflow/contrib/operators/adls_to_gcs.py
+++ b/airflow/contrib/operators/adls_to_gcs.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.transfers.adls_to_gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.transfers.adls_to_gcs.ADLSToGCSOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/azure_container_instances_operator.py b/airflow/contrib/operators/azure_container_instances_operator.py
index 56ea84fb7bf27..7efd0084cccf3 100644
--- a/airflow/contrib/operators/azure_container_instances_operator.py
+++ b/airflow/contrib/operators/azure_container_instances_operator.py
@@ -29,5 +29,6 @@
warnings.warn(
"This module is deprecated. "
"Please use `airflow.providers.microsoft.azure.operators.azure_container_instances`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/azure_cosmos_operator.py b/airflow/contrib/operators/azure_cosmos_operator.py
index 48eb2c2932dba..f08b8ddf1f5e9 100644
--- a/airflow/contrib/operators/azure_cosmos_operator.py
+++ b/airflow/contrib/operators/azure_cosmos_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.azure_cosmos`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/bigquery_check_operator.py b/airflow/contrib/operators/bigquery_check_operator.py
index 29069fd55aebe..c6813946c1af1 100644
--- a/airflow/contrib/operators/bigquery_check_operator.py
+++ b/airflow/contrib/operators/bigquery_check_operator.py
@@ -21,10 +21,13 @@
# pylint: disable=unused-import
from airflow.providers.google.cloud.operators.bigquery import ( # noqa
- BigQueryCheckOperator, BigQueryIntervalCheckOperator, BigQueryValueCheckOperator,
+ BigQueryCheckOperator,
+ BigQueryIntervalCheckOperator,
+ BigQueryValueCheckOperator,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigquery`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/bigquery_get_data.py b/airflow/contrib/operators/bigquery_get_data.py
index 9fe6be369907a..0a5a176f65031 100644
--- a/airflow/contrib/operators/bigquery_get_data.py
+++ b/airflow/contrib/operators/bigquery_get_data.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigquery`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/bigquery_operator.py b/airflow/contrib/operators/bigquery_operator.py
index c01c10366f538..32746264fa8e5 100644
--- a/airflow/contrib/operators/bigquery_operator.py
+++ b/airflow/contrib/operators/bigquery_operator.py
@@ -21,9 +21,15 @@
# pylint: disable=unused-import
from airflow.providers.google.cloud.operators.bigquery import ( # noqa; noqa; noqa; noqa; noqa
- BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryCreateExternalTableOperator,
- BigQueryDeleteDatasetOperator, BigQueryExecuteQueryOperator, BigQueryGetDatasetOperator,
- BigQueryGetDatasetTablesOperator, BigQueryPatchDatasetOperator, BigQueryUpdateDatasetOperator,
+ BigQueryCreateEmptyDatasetOperator,
+ BigQueryCreateEmptyTableOperator,
+ BigQueryCreateExternalTableOperator,
+ BigQueryDeleteDatasetOperator,
+ BigQueryExecuteQueryOperator,
+ BigQueryGetDatasetOperator,
+ BigQueryGetDatasetTablesOperator,
+ BigQueryPatchDatasetOperator,
+ BigQueryUpdateDatasetOperator,
BigQueryUpsertTableOperator,
)
diff --git a/airflow/contrib/operators/bigquery_to_bigquery.py b/airflow/contrib/operators/bigquery_to_bigquery.py
index e585ccdf04a00..74df81e70ec91 100644
--- a/airflow/contrib/operators/bigquery_to_bigquery.py
+++ b/airflow/contrib/operators/bigquery_to_bigquery.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.transfers.bigquery_to_bigquery`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/bigquery_to_gcs.py b/airflow/contrib/operators/bigquery_to_gcs.py
index d51397da64ef2..243bc4343ee68 100644
--- a/airflow/contrib/operators/bigquery_to_gcs.py
+++ b/airflow/contrib/operators/bigquery_to_gcs.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.transfers.bigquery_to_gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryToGCSOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/bigquery_to_mysql_operator.py b/airflow/contrib/operators/bigquery_to_mysql_operator.py
index 37fec73ac92b3..bebb8edc3de92 100644
--- a/airflow/contrib/operators/bigquery_to_mysql_operator.py
+++ b/airflow/contrib/operators/bigquery_to_mysql_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.transfers.bigquery_to_mysql`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/cassandra_to_gcs.py b/airflow/contrib/operators/cassandra_to_gcs.py
index 9ac70f8f1da46..3d889e686f351 100644
--- a/airflow/contrib/operators/cassandra_to_gcs.py
+++ b/airflow/contrib/operators/cassandra_to_gcs.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.transfers.cassandra_to_gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py
index a0b8e495d9d16..6f7c950b366dc 100644
--- a/airflow/contrib/operators/databricks_operator.py
+++ b/airflow/contrib/operators/databricks_operator.py
@@ -21,10 +21,12 @@
# pylint: disable=unused-import
from airflow.providers.databricks.operators.databricks import ( # noqa
- DatabricksRunNowOperator, DatabricksSubmitRunOperator,
+ DatabricksRunNowOperator,
+ DatabricksSubmitRunOperator,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.databricks.operators.databricks`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py
index 1f1f79470eeab..3d03a8bf5f779 100644
--- a/airflow/contrib/operators/dataflow_operator.py
+++ b/airflow/contrib/operators/dataflow_operator.py
@@ -20,12 +20,15 @@
import warnings
from airflow.providers.google.cloud.operators.dataflow import (
- DataflowCreateJavaJobOperator, DataflowCreatePythonJobOperator, DataflowTemplatedJobStartOperator,
+ DataflowCreateJavaJobOperator,
+ DataflowCreatePythonJobOperator,
+ DataflowTemplatedJobStartOperator,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.dataflow`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -39,7 +42,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.dataflow.DataflowCreateJavaJobOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -55,7 +59,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.dataflow.DataflowCreatePythonJobOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -71,6 +76,7 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/dataproc_operator.py b/airflow/contrib/operators/dataproc_operator.py
index 0337b58eeff46..feac6274a9db0 100644
--- a/airflow/contrib/operators/dataproc_operator.py
+++ b/airflow/contrib/operators/dataproc_operator.py
@@ -20,11 +20,18 @@
import warnings
from airflow.providers.google.cloud.operators.dataproc import (
- DataprocCreateClusterOperator, DataprocDeleteClusterOperator,
- DataprocInstantiateInlineWorkflowTemplateOperator, DataprocInstantiateWorkflowTemplateOperator,
- DataprocJobBaseOperator, DataprocScaleClusterOperator, DataprocSubmitHadoopJobOperator,
- DataprocSubmitHiveJobOperator, DataprocSubmitPigJobOperator, DataprocSubmitPySparkJobOperator,
- DataprocSubmitSparkJobOperator, DataprocSubmitSparkSqlJobOperator,
+ DataprocCreateClusterOperator,
+ DataprocDeleteClusterOperator,
+ DataprocInstantiateInlineWorkflowTemplateOperator,
+ DataprocInstantiateWorkflowTemplateOperator,
+ DataprocJobBaseOperator,
+ DataprocScaleClusterOperator,
+ DataprocSubmitHadoopJobOperator,
+ DataprocSubmitHiveJobOperator,
+ DataprocSubmitPigJobOperator,
+ DataprocSubmitPySparkJobOperator,
+ DataprocSubmitSparkJobOperator,
+ DataprocSubmitSparkSqlJobOperator,
)
warnings.warn(
@@ -44,7 +51,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -59,7 +67,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -74,7 +83,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -90,7 +100,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHadoopJobOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -106,7 +117,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHiveJobOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -121,7 +133,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -136,7 +149,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPigJobOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -152,7 +166,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPySparkJobOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -168,7 +183,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkJobOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -184,7 +200,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkSqlJobOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -202,7 +219,8 @@ def __init__(self, *args, **kwargs):
Please use
`airflow.providers.google.cloud.operators.dataproc
.DataprocInstantiateInlineWorkflowTemplateOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -220,6 +238,7 @@ def __init__(self, *args, **kwargs):
Please use
`airflow.providers.google.cloud.operators.dataproc
.DataprocInstantiateWorkflowTemplateOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/datastore_export_operator.py b/airflow/contrib/operators/datastore_export_operator.py
index e2aca08c8df1f..bc88e6e47e988 100644
--- a/airflow/contrib/operators/datastore_export_operator.py
+++ b/airflow/contrib/operators/datastore_export_operator.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.datastore`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -38,6 +39,7 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.l
Please use
`airflow.providers.google.cloud.operators.datastore.CloudDatastoreExportEntitiesOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/datastore_import_operator.py b/airflow/contrib/operators/datastore_import_operator.py
index 2bd8b2493b0d8..4375336588e9e 100644
--- a/airflow/contrib/operators/datastore_import_operator.py
+++ b/airflow/contrib/operators/datastore_import_operator.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.datastore`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -38,6 +39,7 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.datastore.CloudDatastoreImportEntitiesOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/dingding_operator.py b/airflow/contrib/operators/dingding_operator.py
index 7f911d46ed133..0fb52aacff631 100644
--- a/airflow/contrib/operators/dingding_operator.py
+++ b/airflow/contrib/operators/dingding_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.dingding.operators.dingding`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/discord_webhook_operator.py b/airflow/contrib/operators/discord_webhook_operator.py
index 51c12b840ad8c..d6392bc9c5a69 100644
--- a/airflow/contrib/operators/discord_webhook_operator.py
+++ b/airflow/contrib/operators/discord_webhook_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.discord.operators.discord_webhook`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/docker_swarm_operator.py b/airflow/contrib/operators/docker_swarm_operator.py
index 3ac517fd80b6c..796a60d0eb8f9 100644
--- a/airflow/contrib/operators/docker_swarm_operator.py
+++ b/airflow/contrib/operators/docker_swarm_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.docker.operators.docker_swarm`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/druid_operator.py b/airflow/contrib/operators/druid_operator.py
index 912aea975c9ee..ea3d1fb705177 100644
--- a/airflow/contrib/operators/druid_operator.py
+++ b/airflow/contrib/operators/druid_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.druid.operators.druid`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/dynamodb_to_s3.py b/airflow/contrib/operators/dynamodb_to_s3.py
index 748a69dc92263..28ade7edc50b3 100644
--- a/airflow/contrib/operators/dynamodb_to_s3.py
+++ b/airflow/contrib/operators/dynamodb_to_s3.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.dynamodb_to_s3`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/ecs_operator.py b/airflow/contrib/operators/ecs_operator.py
index 37fc32d8c59d9..8c8573847b7ca 100644
--- a/airflow/contrib/operators/ecs_operator.py
+++ b/airflow/contrib/operators/ecs_operator.py
@@ -25,7 +25,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ecs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/emr_add_steps_operator.py b/airflow/contrib/operators/emr_add_steps_operator.py
index 19f2db0b119b3..421d4dcd3395b 100644
--- a/airflow/contrib/operators/emr_add_steps_operator.py
+++ b/airflow/contrib/operators/emr_add_steps_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_add_steps`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/emr_create_job_flow_operator.py b/airflow/contrib/operators/emr_create_job_flow_operator.py
index eafa3111c0e02..bf59df5879360 100644
--- a/airflow/contrib/operators/emr_create_job_flow_operator.py
+++ b/airflow/contrib/operators/emr_create_job_flow_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_create_job_flow`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/emr_terminate_job_flow_operator.py b/airflow/contrib/operators/emr_terminate_job_flow_operator.py
index 281f64cc80ae9..0158399413427 100644
--- a/airflow/contrib/operators/emr_terminate_job_flow_operator.py
+++ b/airflow/contrib/operators/emr_terminate_job_flow_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_terminate_job_flow`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/file_to_gcs.py b/airflow/contrib/operators/file_to_gcs.py
index 7b7b259d33f15..4e696a56e91c4 100644
--- a/airflow/contrib/operators/file_to_gcs.py
+++ b/airflow/contrib/operators/file_to_gcs.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.transfers.local_to_gcs`,",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -38,6 +39,7 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.transfers.local_to_gcs.LocalFilesystemToGCSOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/file_to_wasb.py b/airflow/contrib/operators/file_to_wasb.py
index 8ed0da71336ca..3c9866cbe664c 100644
--- a/airflow/contrib/operators/file_to_wasb.py
+++ b/airflow/contrib/operators/file_to_wasb.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.azure.transfers.file_to_wasb`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/gcp_bigtable_operator.py b/airflow/contrib/operators/gcp_bigtable_operator.py
index 507a55f79dfce..7256d64052d14 100644
--- a/airflow/contrib/operators/gcp_bigtable_operator.py
+++ b/airflow/contrib/operators/gcp_bigtable_operator.py
@@ -23,15 +23,19 @@
import warnings
from airflow.providers.google.cloud.operators.bigtable import (
- BigtableCreateInstanceOperator, BigtableCreateTableOperator, BigtableDeleteInstanceOperator,
- BigtableDeleteTableOperator, BigtableUpdateClusterOperator,
+ BigtableCreateInstanceOperator,
+ BigtableCreateTableOperator,
+ BigtableDeleteInstanceOperator,
+ BigtableDeleteTableOperator,
+ BigtableUpdateClusterOperator,
)
from airflow.providers.google.cloud.sensors.bigtable import BigtableTableReplicationCompletedSensor
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigtable`"
" or `airflow.providers.google.cloud.sensors.bigtable`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -45,7 +49,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.bigtable.BigtableUpdateClusterOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -60,7 +65,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.bigtable.BigtableCreateInstanceOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -75,7 +81,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.bigtable.BigtableDeleteInstanceOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -90,7 +97,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.bigtable.BigtableCreateTableOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -105,7 +113,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.bigtable.BigtableDeleteTableOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -122,6 +131,7 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.sensors.bigtable.BigtableTableReplicationCompletedSensor`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_cloud_build_operator.py b/airflow/contrib/operators/gcp_cloud_build_operator.py
index 318ca765c0aaa..ea585089b6440 100644
--- a/airflow/contrib/operators/gcp_cloud_build_operator.py
+++ b/airflow/contrib/operators/gcp_cloud_build_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.cloud_build`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/gcp_compute_operator.py b/airflow/contrib/operators/gcp_compute_operator.py
index 968b4b2df1299..efd513e193ab7 100644
--- a/airflow/contrib/operators/gcp_compute_operator.py
+++ b/airflow/contrib/operators/gcp_compute_operator.py
@@ -20,14 +20,18 @@
import warnings
from airflow.providers.google.cloud.operators.compute import (
- ComputeEngineBaseOperator, ComputeEngineCopyInstanceTemplateOperator,
- ComputeEngineInstanceGroupUpdateManagerTemplateOperator, ComputeEngineSetMachineTypeOperator,
- ComputeEngineStartInstanceOperator, ComputeEngineStopInstanceOperator,
+ ComputeEngineBaseOperator,
+ ComputeEngineCopyInstanceTemplateOperator,
+ ComputeEngineInstanceGroupUpdateManagerTemplateOperator,
+ ComputeEngineSetMachineTypeOperator,
+ ComputeEngineStartInstanceOperator,
+ ComputeEngineStopInstanceOperator,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.compute`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -41,7 +45,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.compute.ComputeEngineBaseOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -58,7 +63,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated. Please use
`airflow.providers.google.cloud.operators.compute
.ComputeEngineInstanceGroupUpdateManagerTemplateOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -75,7 +81,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.compute
.ComputeEngineStartInstanceOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -91,7 +98,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.compute
.ComputeEngineStopInstanceOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -107,7 +115,8 @@ def __init__(self, *args, **kwargs):
""""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.compute
.ComputeEngineCopyInstanceTemplateOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -123,6 +132,7 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.compute
.ComputeEngineSetMachineTypeOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_container_operator.py b/airflow/contrib/operators/gcp_container_operator.py
index 886571d3d0dd9..449bfbef62493 100644
--- a/airflow/contrib/operators/gcp_container_operator.py
+++ b/airflow/contrib/operators/gcp_container_operator.py
@@ -20,12 +20,15 @@
import warnings
from airflow.providers.google.cloud.operators.kubernetes_engine import (
- GKECreateClusterOperator, GKEDeleteClusterOperator, GKEStartPodOperator,
+ GKECreateClusterOperator,
+ GKEDeleteClusterOperator,
+ GKEStartPodOperator,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.kubernetes_engine`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -39,7 +42,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.container.GKECreateClusterOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -54,7 +58,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.container.GKEDeleteClusterOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -69,6 +74,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.container.GKEStartPodOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_dlp_operator.py b/airflow/contrib/operators/gcp_dlp_operator.py
index cbf5194089ef3..99d373fd49488 100644
--- a/airflow/contrib/operators/gcp_dlp_operator.py
+++ b/airflow/contrib/operators/gcp_dlp_operator.py
@@ -21,23 +21,42 @@
# pylint: disable=unused-import
from airflow.providers.google.cloud.operators.dlp import ( # noqa
- CloudDLPCancelDLPJobOperator, CloudDLPCreateDeidentifyTemplateOperator, CloudDLPCreateDLPJobOperator,
- CloudDLPCreateInspectTemplateOperator, CloudDLPCreateJobTriggerOperator,
- CloudDLPCreateStoredInfoTypeOperator, CloudDLPDeidentifyContentOperator,
- CloudDLPDeleteDeidentifyTemplateOperator, CloudDLPDeleteDLPJobOperator,
- CloudDLPDeleteInspectTemplateOperator, CloudDLPDeleteJobTriggerOperator,
- CloudDLPDeleteStoredInfoTypeOperator, CloudDLPGetDeidentifyTemplateOperator, CloudDLPGetDLPJobOperator,
- CloudDLPGetDLPJobTriggerOperator, CloudDLPGetInspectTemplateOperator, CloudDLPGetStoredInfoTypeOperator,
- CloudDLPInspectContentOperator, CloudDLPListDeidentifyTemplatesOperator, CloudDLPListDLPJobsOperator,
- CloudDLPListInfoTypesOperator, CloudDLPListInspectTemplatesOperator, CloudDLPListJobTriggersOperator,
- CloudDLPListStoredInfoTypesOperator, CloudDLPRedactImageOperator, CloudDLPReidentifyContentOperator,
- CloudDLPUpdateDeidentifyTemplateOperator, CloudDLPUpdateInspectTemplateOperator,
- CloudDLPUpdateJobTriggerOperator, CloudDLPUpdateStoredInfoTypeOperator,
+ CloudDLPCancelDLPJobOperator,
+ CloudDLPCreateDeidentifyTemplateOperator,
+ CloudDLPCreateDLPJobOperator,
+ CloudDLPCreateInspectTemplateOperator,
+ CloudDLPCreateJobTriggerOperator,
+ CloudDLPCreateStoredInfoTypeOperator,
+ CloudDLPDeidentifyContentOperator,
+ CloudDLPDeleteDeidentifyTemplateOperator,
+ CloudDLPDeleteDLPJobOperator,
+ CloudDLPDeleteInspectTemplateOperator,
+ CloudDLPDeleteJobTriggerOperator,
+ CloudDLPDeleteStoredInfoTypeOperator,
+ CloudDLPGetDeidentifyTemplateOperator,
+ CloudDLPGetDLPJobOperator,
+ CloudDLPGetDLPJobTriggerOperator,
+ CloudDLPGetInspectTemplateOperator,
+ CloudDLPGetStoredInfoTypeOperator,
+ CloudDLPInspectContentOperator,
+ CloudDLPListDeidentifyTemplatesOperator,
+ CloudDLPListDLPJobsOperator,
+ CloudDLPListInfoTypesOperator,
+ CloudDLPListInspectTemplatesOperator,
+ CloudDLPListJobTriggersOperator,
+ CloudDLPListStoredInfoTypesOperator,
+ CloudDLPRedactImageOperator,
+ CloudDLPReidentifyContentOperator,
+ CloudDLPUpdateDeidentifyTemplateOperator,
+ CloudDLPUpdateInspectTemplateOperator,
+ CloudDLPUpdateJobTriggerOperator,
+ CloudDLPUpdateStoredInfoTypeOperator,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.dlp`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -51,8 +70,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteDLPJobOperator`.""",
-
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -67,7 +86,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -82,7 +102,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobTriggerOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -97,6 +118,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPListDLPJobsOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_function_operator.py b/airflow/contrib/operators/gcp_function_operator.py
index 664f5e5ed8274..cebf4e70619a5 100644
--- a/airflow/contrib/operators/gcp_function_operator.py
+++ b/airflow/contrib/operators/gcp_function_operator.py
@@ -20,12 +20,14 @@
import warnings
from airflow.providers.google.cloud.operators.functions import (
- CloudFunctionDeleteFunctionOperator, CloudFunctionDeployFunctionOperator,
+ CloudFunctionDeleteFunctionOperator,
+ CloudFunctionDeployFunctionOperator,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.functions`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -40,7 +42,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.function.CloudFunctionDeleteFunctionOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -56,6 +59,7 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.function.CloudFunctionDeployFunctionOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_natural_language_operator.py b/airflow/contrib/operators/gcp_natural_language_operator.py
index daac37eca33eb..e9ec096b5771e 100644
--- a/airflow/contrib/operators/gcp_natural_language_operator.py
+++ b/airflow/contrib/operators/gcp_natural_language_operator.py
@@ -20,15 +20,18 @@
import warnings
from airflow.providers.google.cloud.operators.natural_language import (
- CloudNaturalLanguageAnalyzeEntitiesOperator, CloudNaturalLanguageAnalyzeEntitySentimentOperator,
- CloudNaturalLanguageAnalyzeSentimentOperator, CloudNaturalLanguageClassifyTextOperator,
+ CloudNaturalLanguageAnalyzeEntitiesOperator,
+ CloudNaturalLanguageAnalyzeEntitySentimentOperator,
+ CloudNaturalLanguageAnalyzeSentimentOperator,
+ CloudNaturalLanguageClassifyTextOperator,
)
warnings.warn(
"""This module is deprecated.
Please use `airflow.providers.google.cloud.operators.natural_language`
""",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -46,7 +49,8 @@ def __init__(self, *args, **kwargs):
`airflow.providers.google.cloud.operators.natural_language
.CloudNaturalLanguageAnalyzeEntitiesOperator`.
""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -66,7 +70,8 @@ def __init__(self, *args, **kwargs):
`airflow.providers.google.cloud.operators.natural_language
.CloudNaturalLanguageAnalyzeEntitySentimentOperator`.
""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -84,7 +89,8 @@ def __init__(self, *args, **kwargs):
Please use `airflow.providers.google.cloud.operators.natural_language
.CloudNaturalLanguageAnalyzeSentimentOperator`.
""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -102,6 +108,7 @@ def __init__(self, *args, **kwargs):
Please use `airflow.providers.google.cloud.operators.natural_language
.CloudNaturalLanguageClassifyTextOperator`.
""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_spanner_operator.py b/airflow/contrib/operators/gcp_spanner_operator.py
index bd5f4fadd3433..d15c7c0dc430c 100644
--- a/airflow/contrib/operators/gcp_spanner_operator.py
+++ b/airflow/contrib/operators/gcp_spanner_operator.py
@@ -20,9 +20,12 @@
import warnings
from airflow.providers.google.cloud.operators.spanner import (
- SpannerDeleteDatabaseInstanceOperator, SpannerDeleteInstanceOperator,
- SpannerDeployDatabaseInstanceOperator, SpannerDeployInstanceOperator,
- SpannerQueryDatabaseInstanceOperator, SpannerUpdateDatabaseInstanceOperator,
+ SpannerDeleteDatabaseInstanceOperator,
+ SpannerDeleteInstanceOperator,
+ SpannerDeployDatabaseInstanceOperator,
+ SpannerDeployInstanceOperator,
+ SpannerQueryDatabaseInstanceOperator,
+ SpannerUpdateDatabaseInstanceOperator,
)
warnings.warn(
@@ -40,7 +43,9 @@ class CloudSpannerInstanceDatabaseDeleteOperator(SpannerDeleteDatabaseInstanceOp
def __init__(self, *args, **kwargs):
warnings.warn(
- self.__doc__, DeprecationWarning, stacklevel=3,
+ self.__doc__,
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -53,7 +58,9 @@ class CloudSpannerInstanceDatabaseDeployOperator(SpannerDeployDatabaseInstanceOp
def __init__(self, *args, **kwargs):
warnings.warn(
- self.__doc__, DeprecationWarning, stacklevel=3,
+ self.__doc__,
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -66,7 +73,9 @@ class CloudSpannerInstanceDatabaseQueryOperator(SpannerQueryDatabaseInstanceOper
def __init__(self, *args, **kwargs):
warnings.warn(
- self.__doc__, DeprecationWarning, stacklevel=3,
+ self.__doc__,
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -79,7 +88,9 @@ class CloudSpannerInstanceDatabaseUpdateOperator(SpannerUpdateDatabaseInstanceOp
def __init__(self, *args, **kwargs):
warnings.warn(
- self.__doc__, DeprecationWarning, stacklevel=3,
+ self.__doc__,
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -92,7 +103,9 @@ class CloudSpannerInstanceDeleteOperator(SpannerDeleteInstanceOperator):
def __init__(self, *args, **kwargs):
warnings.warn(
- self.__doc__, DeprecationWarning, stacklevel=3,
+ self.__doc__,
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -105,6 +118,8 @@ class CloudSpannerInstanceDeployOperator(SpannerDeployInstanceOperator):
def __init__(self, *args, **kwargs):
warnings.warn(
- self.__doc__, DeprecationWarning, stacklevel=3,
+ self.__doc__,
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_speech_to_text_operator.py b/airflow/contrib/operators/gcp_speech_to_text_operator.py
index 4d2c644096a00..c24a027cdcaac 100644
--- a/airflow/contrib/operators/gcp_speech_to_text_operator.py
+++ b/airflow/contrib/operators/gcp_speech_to_text_operator.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.speech_to_text`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -40,6 +41,7 @@ def __init__(self, *args, **kwargs):
Please use
`airflow.providers.google.cloud.operators.speech_to_text
.CloudSpeechToTextRecognizeSpeechOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_sql_operator.py b/airflow/contrib/operators/gcp_sql_operator.py
index 00a5b6edb399d..9ec00f8fd1048 100644
--- a/airflow/contrib/operators/gcp_sql_operator.py
+++ b/airflow/contrib/operators/gcp_sql_operator.py
@@ -20,15 +20,22 @@
import warnings
from airflow.providers.google.cloud.operators.cloud_sql import (
- CloudSQLBaseOperator, CloudSQLCreateInstanceDatabaseOperator, CloudSQLCreateInstanceOperator,
- CloudSQLDeleteInstanceDatabaseOperator, CloudSQLDeleteInstanceOperator, CloudSQLExecuteQueryOperator,
- CloudSQLExportInstanceOperator, CloudSQLImportInstanceOperator, CloudSQLInstancePatchOperator,
+ CloudSQLBaseOperator,
+ CloudSQLCreateInstanceDatabaseOperator,
+ CloudSQLCreateInstanceOperator,
+ CloudSQLDeleteInstanceDatabaseOperator,
+ CloudSQLDeleteInstanceOperator,
+ CloudSQLExecuteQueryOperator,
+ CloudSQLExportInstanceOperator,
+ CloudSQLImportInstanceOperator,
+ CloudSQLInstancePatchOperator,
CloudSQLPatchInstanceDatabaseOperator,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.cloud_sql`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/gcp_tasks_operator.py b/airflow/contrib/operators/gcp_tasks_operator.py
index 1e39bdd0f7a70..e1a53242cb43d 100644
--- a/airflow/contrib/operators/gcp_tasks_operator.py
+++ b/airflow/contrib/operators/gcp_tasks_operator.py
@@ -21,14 +21,23 @@
# pylint: disable=unused-import
from airflow.providers.google.cloud.operators.tasks import ( # noqa
- CloudTasksQueueCreateOperator, CloudTasksQueueDeleteOperator, CloudTasksQueueGetOperator,
- CloudTasksQueuePauseOperator, CloudTasksQueuePurgeOperator, CloudTasksQueueResumeOperator,
- CloudTasksQueuesListOperator, CloudTasksQueueUpdateOperator, CloudTasksTaskCreateOperator,
- CloudTasksTaskDeleteOperator, CloudTasksTaskGetOperator, CloudTasksTaskRunOperator,
+ CloudTasksQueueCreateOperator,
+ CloudTasksQueueDeleteOperator,
+ CloudTasksQueueGetOperator,
+ CloudTasksQueuePauseOperator,
+ CloudTasksQueuePurgeOperator,
+ CloudTasksQueueResumeOperator,
+ CloudTasksQueuesListOperator,
+ CloudTasksQueueUpdateOperator,
+ CloudTasksTaskCreateOperator,
+ CloudTasksTaskDeleteOperator,
+ CloudTasksTaskGetOperator,
+ CloudTasksTaskRunOperator,
CloudTasksTasksListOperator,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.tasks`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/gcp_text_to_speech_operator.py b/airflow/contrib/operators/gcp_text_to_speech_operator.py
index 9a43b7cee0a0f..03c2bcde9cae2 100644
--- a/airflow/contrib/operators/gcp_text_to_speech_operator.py
+++ b/airflow/contrib/operators/gcp_text_to_speech_operator.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.text_to_speech`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -38,6 +39,7 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.text_to_speech.CloudTextToSpeechSynthesizeOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_transfer_operator.py b/airflow/contrib/operators/gcp_transfer_operator.py
index 2f5f0a1cfd7da..a939fdc9e5c82 100644
--- a/airflow/contrib/operators/gcp_transfer_operator.py
+++ b/airflow/contrib/operators/gcp_transfer_operator.py
@@ -23,17 +23,23 @@
import warnings
from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import (
- CloudDataTransferServiceCancelOperationOperator, CloudDataTransferServiceCreateJobOperator,
- CloudDataTransferServiceDeleteJobOperator, CloudDataTransferServiceGCSToGCSOperator,
- CloudDataTransferServiceGetOperationOperator, CloudDataTransferServiceListOperationsOperator,
- CloudDataTransferServicePauseOperationOperator, CloudDataTransferServiceResumeOperationOperator,
- CloudDataTransferServiceS3ToGCSOperator, CloudDataTransferServiceUpdateJobOperator,
+ CloudDataTransferServiceCancelOperationOperator,
+ CloudDataTransferServiceCreateJobOperator,
+ CloudDataTransferServiceDeleteJobOperator,
+ CloudDataTransferServiceGCSToGCSOperator,
+ CloudDataTransferServiceGetOperationOperator,
+ CloudDataTransferServiceListOperationsOperator,
+ CloudDataTransferServicePauseOperationOperator,
+ CloudDataTransferServiceResumeOperationOperator,
+ CloudDataTransferServiceS3ToGCSOperator,
+ CloudDataTransferServiceUpdateJobOperator,
)
warnings.warn(
"This module is deprecated. "
"Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -49,7 +55,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.data_transfer
.CloudDataTransferServiceCreateJobOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -66,7 +73,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.data_transfer
.CloudDataTransferServiceDeleteJobOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -83,7 +91,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.data_transfer
.CloudDataTransferServiceUpdateJobOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -101,7 +110,8 @@ def __init__(self, *args, **kwargs):
Please use `airflow.providers.google.cloud.operators.data_transfer
.CloudDataTransferServiceCancelOperationOperator`.
""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -118,7 +128,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.data_transfer
.CloudDataTransferServiceGetOperationOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -136,7 +147,8 @@ def __init__(self, *args, **kwargs):
Please use `airflow.providers.google.cloud.operators.data_transfer
.CloudDataTransferServicePauseOperationOperator`.
""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -154,7 +166,8 @@ def __init__(self, *args, **kwargs):
Please use `airflow.providers.google.cloud.operators.data_transfer
.CloudDataTransferServiceResumeOperationOperator`.
""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -172,7 +185,8 @@ def __init__(self, *args, **kwargs):
Please use `airflow.providers.google.cloud.operators.data_transfer
.CloudDataTransferServiceListOperationsOperator`.
""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -190,7 +204,8 @@ def __init__(self, *args, **kwargs):
Please use `airflow.providers.google.cloud.operators.data_transfer
.CloudDataTransferServiceGCSToGCSOperator`.
""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -208,6 +223,7 @@ def __init__(self, *args, **kwargs):
Please use `airflow.providers.google.cloud.operators.data_transfer
.CloudDataTransferServiceS3ToGCSOperator`.
""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_translate_operator.py b/airflow/contrib/operators/gcp_translate_operator.py
index 8b1c939d20c34..22f302d150515 100644
--- a/airflow/contrib/operators/gcp_translate_operator.py
+++ b/airflow/contrib/operators/gcp_translate_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.translate`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/gcp_translate_speech_operator.py b/airflow/contrib/operators/gcp_translate_speech_operator.py
index 121158d6e552e..724697852639d 100644
--- a/airflow/contrib/operators/gcp_translate_speech_operator.py
+++ b/airflow/contrib/operators/gcp_translate_speech_operator.py
@@ -24,7 +24,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.translate_speech`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -40,6 +41,7 @@ def __init__(self, *args, **kwargs):
Please use
`airflow.providers.google.cloud.operators.translate_speech.CloudTranslateSpeechOperator`.
""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_video_intelligence_operator.py b/airflow/contrib/operators/gcp_video_intelligence_operator.py
index 401c70d577c56..3e08eb6d0e2db 100644
--- a/airflow/contrib/operators/gcp_video_intelligence_operator.py
+++ b/airflow/contrib/operators/gcp_video_intelligence_operator.py
@@ -21,11 +21,13 @@
# pylint: disable=unused-import
from airflow.providers.google.cloud.operators.video_intelligence import ( # noqa
- CloudVideoIntelligenceDetectVideoExplicitContentOperator, CloudVideoIntelligenceDetectVideoLabelsOperator,
+ CloudVideoIntelligenceDetectVideoExplicitContentOperator,
+ CloudVideoIntelligenceDetectVideoLabelsOperator,
CloudVideoIntelligenceDetectVideoShotsOperator,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.video_intelligence`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/gcp_vision_operator.py b/airflow/contrib/operators/gcp_vision_operator.py
index 99f89737882fc..3975f775ea577 100644
--- a/airflow/contrib/operators/gcp_vision_operator.py
+++ b/airflow/contrib/operators/gcp_vision_operator.py
@@ -20,18 +20,28 @@
import warnings
from airflow.providers.google.cloud.operators.vision import ( # noqa # pylint: disable=unused-import
- CloudVisionAddProductToProductSetOperator, CloudVisionCreateProductOperator,
- CloudVisionCreateProductSetOperator, CloudVisionCreateReferenceImageOperator,
- CloudVisionDeleteProductOperator, CloudVisionDeleteProductSetOperator,
- CloudVisionDetectImageLabelsOperator, CloudVisionDetectImageSafeSearchOperator,
- CloudVisionDetectTextOperator, CloudVisionGetProductOperator, CloudVisionGetProductSetOperator,
- CloudVisionImageAnnotateOperator, CloudVisionRemoveProductFromProductSetOperator,
- CloudVisionTextDetectOperator, CloudVisionUpdateProductOperator, CloudVisionUpdateProductSetOperator,
+ CloudVisionAddProductToProductSetOperator,
+ CloudVisionCreateProductOperator,
+ CloudVisionCreateProductSetOperator,
+ CloudVisionCreateReferenceImageOperator,
+ CloudVisionDeleteProductOperator,
+ CloudVisionDeleteProductSetOperator,
+ CloudVisionDetectImageLabelsOperator,
+ CloudVisionDetectImageSafeSearchOperator,
+ CloudVisionDetectTextOperator,
+ CloudVisionGetProductOperator,
+ CloudVisionGetProductSetOperator,
+ CloudVisionImageAnnotateOperator,
+ CloudVisionRemoveProductFromProductSetOperator,
+ CloudVisionTextDetectOperator,
+ CloudVisionUpdateProductOperator,
+ CloudVisionUpdateProductSetOperator,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.vision`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -45,7 +55,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.vision.CloudVisionImageAnnotateOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -60,7 +71,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.vision.CloudVisionTextDetectOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -75,7 +87,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -90,7 +103,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -105,7 +119,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.vision.CloudVisionGetProductOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -121,7 +136,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductSetOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -137,7 +153,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductSetOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -153,7 +170,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.vision.CloudVisionGetProductSetOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -169,7 +187,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductSetOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -185,7 +204,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -201,6 +221,7 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.vision.CloudVisionCreateReferenceImageOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_acl_operator.py b/airflow/contrib/operators/gcs_acl_operator.py
index 4803c994d7dbe..efd81d9f225f9 100644
--- a/airflow/contrib/operators/gcs_acl_operator.py
+++ b/airflow/contrib/operators/gcs_acl_operator.py
@@ -20,12 +20,14 @@
import warnings
from airflow.providers.google.cloud.operators.gcs import (
- GCSBucketCreateAclEntryOperator, GCSObjectCreateAclEntryOperator,
+ GCSBucketCreateAclEntryOperator,
+ GCSObjectCreateAclEntryOperator,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -39,7 +41,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.gcs.GCSBucketCreateAclEntryOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -54,6 +57,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.gcs.GCSObjectCreateAclEntryOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_delete_operator.py b/airflow/contrib/operators/gcs_delete_operator.py
index e02926f7a9890..4ea63fe3e9a75 100644
--- a/airflow/contrib/operators/gcs_delete_operator.py
+++ b/airflow/contrib/operators/gcs_delete_operator.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.gcs.GCSDeleteObjectsOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_download_operator.py b/airflow/contrib/operators/gcs_download_operator.py
index 43f959d529f7c..bfc421ac83840 100644
--- a/airflow/contrib/operators/gcs_download_operator.py
+++ b/airflow/contrib/operators/gcs_download_operator.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.gcs.GCSToLocalFilesystemOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_list_operator.py b/airflow/contrib/operators/gcs_list_operator.py
index de17ce6d03daa..5ac4a9525935c 100644
--- a/airflow/contrib/operators/gcs_list_operator.py
+++ b/airflow/contrib/operators/gcs_list_operator.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.gcs.GCSListObjectsOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_operator.py b/airflow/contrib/operators/gcs_operator.py
index bef41c5fc4dab..72975bc0f2ad8 100644
--- a/airflow/contrib/operators/gcs_operator.py
+++ b/airflow/contrib/operators/gcs_operator.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.gcs.GCSCreateBucketOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py
index 27af7b7953a74..5a9c4546fd52e 100644
--- a/airflow/contrib/operators/gcs_to_bq.py
+++ b/airflow/contrib/operators/gcs_to_bq.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.transfers.gcs_to_bigquery`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.transfers.gcs_to_bq.GCSToBigQueryOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_to_gcs.py b/airflow/contrib/operators/gcs_to_gcs.py
index 4737fa0bf06f3..ca02151657e8d 100644
--- a/airflow/contrib/operators/gcs_to_gcs.py
+++ b/airflow/contrib/operators/gcs_to_gcs.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.transfers.gcs_to_gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSToGCSOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py b/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py
index fa1ec22012fac..99d7ca2d805cd 100644
--- a/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py
+++ b/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py
@@ -27,5 +27,6 @@
warnings.warn(
"This module is deprecated. "
"Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/gcs_to_gdrive_operator.py b/airflow/contrib/operators/gcs_to_gdrive_operator.py
index 6ec7462d950cf..67f312d6ba39b 100644
--- a/airflow/contrib/operators/gcs_to_gdrive_operator.py
+++ b/airflow/contrib/operators/gcs_to_gdrive_operator.py
@@ -23,7 +23,7 @@
from airflow.providers.google.suite.transfers.gcs_to_gdrive import GCSToGoogleDriveOperator # noqa
warnings.warn(
- "This module is deprecated. "
- "Please use `airflow.providers.google.suite.transfers.gcs_to_gdrive.",
- DeprecationWarning, stacklevel=2
+ "This module is deprecated. " "Please use `airflow.providers.google.suite.transfers.gcs_to_gdrive.",
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/gcs_to_s3.py b/airflow/contrib/operators/gcs_to_s3.py
index dd358ee44cc87..ed7fef6fb767f 100644
--- a/airflow/contrib/operators/gcs_to_s3.py
+++ b/airflow/contrib/operators/gcs_to_s3.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.gcs_to_s3`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. "
"Please use `airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSToS3Operator`.",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/grpc_operator.py b/airflow/contrib/operators/grpc_operator.py
index 03a9b6e249dae..abda1207a14ad 100644
--- a/airflow/contrib/operators/grpc_operator.py
+++ b/airflow/contrib/operators/grpc_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.grpc.operators.grpc`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/hive_to_dynamodb.py b/airflow/contrib/operators/hive_to_dynamodb.py
index 6784680272273..48959601c0eea 100644
--- a/airflow/contrib/operators/hive_to_dynamodb.py
+++ b/airflow/contrib/operators/hive_to_dynamodb.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.hive_to_dynamodb`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/imap_attachment_to_s3_operator.py b/airflow/contrib/operators/imap_attachment_to_s3_operator.py
index 597d6beba0cd3..316da8a2d8e6a 100644
--- a/airflow/contrib/operators/imap_attachment_to_s3_operator.py
+++ b/airflow/contrib/operators/imap_attachment_to_s3_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.imap_attachment_to_s3`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/jenkins_job_trigger_operator.py b/airflow/contrib/operators/jenkins_job_trigger_operator.py
index 13f68ae3b686b..538a0a2a8025c 100644
--- a/airflow/contrib/operators/jenkins_job_trigger_operator.py
+++ b/airflow/contrib/operators/jenkins_job_trigger_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.jenkins.operators.jenkins_job_trigger`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/jira_operator.py b/airflow/contrib/operators/jira_operator.py
index d7e977f6138a1..d84662972adf5 100644
--- a/airflow/contrib/operators/jira_operator.py
+++ b/airflow/contrib/operators/jira_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.jira.operators.jira`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py
index 2663ceddf2f63..a6587350e7101 100644
--- a/airflow/contrib/operators/kubernetes_pod_operator.py
+++ b/airflow/contrib/operators/kubernetes_pod_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.cncf.kubernetes.operators.kubernetes_pod`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/mlengine_operator.py b/airflow/contrib/operators/mlengine_operator.py
index fd6b530bb8bfd..149ff82180177 100644
--- a/airflow/contrib/operators/mlengine_operator.py
+++ b/airflow/contrib/operators/mlengine_operator.py
@@ -20,13 +20,16 @@
import warnings
from airflow.providers.google.cloud.operators.mlengine import (
- MLEngineManageModelOperator, MLEngineManageVersionOperator, MLEngineStartBatchPredictionJobOperator,
+ MLEngineManageModelOperator,
+ MLEngineManageVersionOperator,
+ MLEngineStartBatchPredictionJobOperator,
MLEngineStartTrainingJobOperator,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.operators.mlengine`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -41,7 +44,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -56,7 +60,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.operators.mlengine.MLEngineManageModelOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -72,7 +77,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -88,6 +94,7 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.operators.mlengine.MLEngineManageVersionOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/mongo_to_s3.py b/airflow/contrib/operators/mongo_to_s3.py
index 82449ee5d7a8a..fe8bd1d3fef2a 100644
--- a/airflow/contrib/operators/mongo_to_s3.py
+++ b/airflow/contrib/operators/mongo_to_s3.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.mongo_to_s3`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/mssql_to_gcs.py b/airflow/contrib/operators/mssql_to_gcs.py
index b094c27f59f11..3527327c5c34b 100644
--- a/airflow/contrib/operators/mssql_to_gcs.py
+++ b/airflow/contrib/operators/mssql_to_gcs.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.transfers.mssql_to_gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.transfers.mssql_to_gcs.MSSQLToGCSOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/mysql_to_gcs.py b/airflow/contrib/operators/mysql_to_gcs.py
index dbb5cdfccb804..73a137527e9a0 100644
--- a/airflow/contrib/operators/mysql_to_gcs.py
+++ b/airflow/contrib/operators/mysql_to_gcs.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.transfers.mysql_to_gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.transfers.mysql_to_gcs.MySQLToGCSOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/opsgenie_alert_operator.py b/airflow/contrib/operators/opsgenie_alert_operator.py
index 877b87affe786..acf036aba7a34 100644
--- a/airflow/contrib/operators/opsgenie_alert_operator.py
+++ b/airflow/contrib/operators/opsgenie_alert_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.opsgenie.operators.opsgenie_alert`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py b/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py
index 2cdca434975dd..c27d94826fb39 100644
--- a/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py
+++ b/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py
@@ -30,5 +30,6 @@
warnings.warn(
"This module is deprecated. "
"Please use `airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/oracle_to_oracle_transfer.py b/airflow/contrib/operators/oracle_to_oracle_transfer.py
index e522951963015..28a39c61e73f8 100644
--- a/airflow/contrib/operators/oracle_to_oracle_transfer.py
+++ b/airflow/contrib/operators/oracle_to_oracle_transfer.py
@@ -27,7 +27,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.oracle.transfers.oracle_to_oracle`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -43,6 +44,7 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.oracle.transfers.oracle_to_oracle.OracleToOracleOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/postgres_to_gcs_operator.py b/airflow/contrib/operators/postgres_to_gcs_operator.py
index 0cbdca34fa140..62f2d06e202a7 100644
--- a/airflow/contrib/operators/postgres_to_gcs_operator.py
+++ b/airflow/contrib/operators/postgres_to_gcs_operator.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.transfers.postgres_to_gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.transfers.postgres_to_gcs.PostgresToGCSOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/pubsub_operator.py b/airflow/contrib/operators/pubsub_operator.py
index d2bab4fa1e91f..6527c8f2f5a1b 100644
--- a/airflow/contrib/operators/pubsub_operator.py
+++ b/airflow/contrib/operators/pubsub_operator.py
@@ -23,8 +23,11 @@
import warnings
from airflow.providers.google.cloud.operators.pubsub import (
- PubSubCreateSubscriptionOperator, PubSubCreateTopicOperator, PubSubDeleteSubscriptionOperator,
- PubSubDeleteTopicOperator, PubSubPublishMessageOperator,
+ PubSubCreateSubscriptionOperator,
+ PubSubCreateTopicOperator,
+ PubSubDeleteSubscriptionOperator,
+ PubSubDeleteTopicOperator,
+ PubSubPublishMessageOperator,
)
warnings.warn(
diff --git a/airflow/contrib/operators/qubole_check_operator.py b/airflow/contrib/operators/qubole_check_operator.py
index 7436d67061a27..6eefe0bc8c199 100644
--- a/airflow/contrib/operators/qubole_check_operator.py
+++ b/airflow/contrib/operators/qubole_check_operator.py
@@ -21,10 +21,12 @@
# pylint: disable=unused-import
from airflow.providers.qubole.operators.qubole_check import ( # noqa
- QuboleCheckOperator, QuboleValueCheckOperator,
+ QuboleCheckOperator,
+ QuboleValueCheckOperator,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.qubole.operators.qubole_check`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/qubole_operator.py b/airflow/contrib/operators/qubole_operator.py
index 86cc465d284d2..521c1dd0f199b 100644
--- a/airflow/contrib/operators/qubole_operator.py
+++ b/airflow/contrib/operators/qubole_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.qubole.operators.qubole`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/redis_publish_operator.py b/airflow/contrib/operators/redis_publish_operator.py
index f84aaed477813..adee8f12c749d 100644
--- a/airflow/contrib/operators/redis_publish_operator.py
+++ b/airflow/contrib/operators/redis_publish_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.redis.operators.redis_publish`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/s3_copy_object_operator.py b/airflow/contrib/operators/s3_copy_object_operator.py
index 44d34c3dbe71d..81a92154560ac 100644
--- a/airflow/contrib/operators/s3_copy_object_operator.py
+++ b/airflow/contrib/operators/s3_copy_object_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_copy_object`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/s3_delete_objects_operator.py b/airflow/contrib/operators/s3_delete_objects_operator.py
index 211524e338499..39ab997e71f4f 100644
--- a/airflow/contrib/operators/s3_delete_objects_operator.py
+++ b/airflow/contrib/operators/s3_delete_objects_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_delete_objects`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/s3_list_operator.py b/airflow/contrib/operators/s3_list_operator.py
index 56a1d9f840a97..e6ae59aea1a2c 100644
--- a/airflow/contrib/operators/s3_list_operator.py
+++ b/airflow/contrib/operators/s3_list_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_list`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/s3_to_gcs_operator.py b/airflow/contrib/operators/s3_to_gcs_operator.py
index 7a43f761c5402..d6c1f04b5aaf6 100644
--- a/airflow/contrib/operators/s3_to_gcs_operator.py
+++ b/airflow/contrib/operators/s3_to_gcs_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.transfers.s3_to_gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/s3_to_gcs_transfer_operator.py b/airflow/contrib/operators/s3_to_gcs_transfer_operator.py
index 84915a8d93748..d82657b774bc5 100644
--- a/airflow/contrib/operators/s3_to_gcs_transfer_operator.py
+++ b/airflow/contrib/operators/s3_to_gcs_transfer_operator.py
@@ -22,10 +22,13 @@
import warnings
# pylint: disable=unused-import,line-too-long
-from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import CloudDataTransferServiceS3ToGCSOperator # noqa isort:skip
+from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import ( # noqa isort:skip
+ CloudDataTransferServiceS3ToGCSOperator,
+)
warnings.warn(
"This module is deprecated. "
"Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/s3_to_sftp_operator.py b/airflow/contrib/operators/s3_to_sftp_operator.py
index b247ce53816e2..1e5a934264e2c 100644
--- a/airflow/contrib/operators/s3_to_sftp_operator.py
+++ b/airflow/contrib/operators/s3_to_sftp_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.s3_to_sftp`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/sagemaker_base_operator.py b/airflow/contrib/operators/sagemaker_base_operator.py
index c44a1cd189dc8..006424a4f492f 100644
--- a/airflow/contrib/operators/sagemaker_base_operator.py
+++ b/airflow/contrib/operators/sagemaker_base_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_base`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/sagemaker_endpoint_config_operator.py b/airflow/contrib/operators/sagemaker_endpoint_config_operator.py
index cd8b8c8c3995a..cf828d4ac7317 100644
--- a/airflow/contrib/operators/sagemaker_endpoint_config_operator.py
+++ b/airflow/contrib/operators/sagemaker_endpoint_config_operator.py
@@ -30,5 +30,6 @@
warnings.warn(
"This module is deprecated. "
"Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint_config`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/sagemaker_endpoint_operator.py b/airflow/contrib/operators/sagemaker_endpoint_operator.py
index 6cc6abb370aed..cb1de3f5d65d4 100644
--- a/airflow/contrib/operators/sagemaker_endpoint_operator.py
+++ b/airflow/contrib/operators/sagemaker_endpoint_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/sagemaker_model_operator.py b/airflow/contrib/operators/sagemaker_model_operator.py
index 174e63851477b..c03799dfee397 100644
--- a/airflow/contrib/operators/sagemaker_model_operator.py
+++ b/airflow/contrib/operators/sagemaker_model_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_model`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/sagemaker_training_operator.py b/airflow/contrib/operators/sagemaker_training_operator.py
index 17f2b137768fe..f6b1a0ecf71b8 100644
--- a/airflow/contrib/operators/sagemaker_training_operator.py
+++ b/airflow/contrib/operators/sagemaker_training_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_training`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/sagemaker_transform_operator.py b/airflow/contrib/operators/sagemaker_transform_operator.py
index 490a0c2749def..e4f8460d6ee52 100644
--- a/airflow/contrib/operators/sagemaker_transform_operator.py
+++ b/airflow/contrib/operators/sagemaker_transform_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_transform`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/sagemaker_tuning_operator.py b/airflow/contrib/operators/sagemaker_tuning_operator.py
index 24cde0fe5bca5..2113eab2a1df6 100644
--- a/airflow/contrib/operators/sagemaker_tuning_operator.py
+++ b/airflow/contrib/operators/sagemaker_tuning_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_tuning`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/segment_track_event_operator.py b/airflow/contrib/operators/segment_track_event_operator.py
index e1a17f37450f7..49fc35e7efea0 100644
--- a/airflow/contrib/operators/segment_track_event_operator.py
+++ b/airflow/contrib/operators/segment_track_event_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.segment.operators.segment_track_event`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/sftp_to_s3_operator.py b/airflow/contrib/operators/sftp_to_s3_operator.py
index 817880c14c706..d52fab2f6c057 100644
--- a/airflow/contrib/operators/sftp_to_s3_operator.py
+++ b/airflow/contrib/operators/sftp_to_s3_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.sftp_to_s3`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/slack_webhook_operator.py b/airflow/contrib/operators/slack_webhook_operator.py
index bdf517424a5c4..627ca143a9cd5 100644
--- a/airflow/contrib/operators/slack_webhook_operator.py
+++ b/airflow/contrib/operators/slack_webhook_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.slack.operators.slack_webhook`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/snowflake_operator.py b/airflow/contrib/operators/snowflake_operator.py
index de14627669133..41fe2d0a44a95 100644
--- a/airflow/contrib/operators/snowflake_operator.py
+++ b/airflow/contrib/operators/snowflake_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.snowflake.operators.snowflake`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/spark_jdbc_operator.py b/airflow/contrib/operators/spark_jdbc_operator.py
index 178663e04c99e..55c0fc1c4cce7 100644
--- a/airflow/contrib/operators/spark_jdbc_operator.py
+++ b/airflow/contrib/operators/spark_jdbc_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.spark.operators.spark_jdbc`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/spark_sql_operator.py b/airflow/contrib/operators/spark_sql_operator.py
index 37dfda92269a3..b6d6a5f152a77 100644
--- a/airflow/contrib/operators/spark_sql_operator.py
+++ b/airflow/contrib/operators/spark_sql_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.spark.operators.spark_sql`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/spark_submit_operator.py b/airflow/contrib/operators/spark_submit_operator.py
index 4b0c469f331c6..5c3253be5b79b 100644
--- a/airflow/contrib/operators/spark_submit_operator.py
+++ b/airflow/contrib/operators/spark_submit_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.spark.operators.spark_submit`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/sql_to_gcs.py b/airflow/contrib/operators/sql_to_gcs.py
index cfeaa9b5395c0..01ba7dffe961c 100644
--- a/airflow/contrib/operators/sql_to_gcs.py
+++ b/airflow/contrib/operators/sql_to_gcs.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.transfers.sql_to_gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.transfers.sql_to_gcs.BaseSQLToGCSOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/sqoop_operator.py b/airflow/contrib/operators/sqoop_operator.py
index d0daa2b8da938..7464a8c1cb16a 100644
--- a/airflow/contrib/operators/sqoop_operator.py
+++ b/airflow/contrib/operators/sqoop_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.sqoop.operators.sqoop`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/ssh_operator.py b/airflow/contrib/operators/ssh_operator.py
index cc29431b8c30d..82e88e2155356 100644
--- a/airflow/contrib/operators/ssh_operator.py
+++ b/airflow/contrib/operators/ssh_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.ssh.operators.ssh`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/vertica_operator.py b/airflow/contrib/operators/vertica_operator.py
index 2504516a98396..71ef3d15d30ed 100644
--- a/airflow/contrib/operators/vertica_operator.py
+++ b/airflow/contrib/operators/vertica_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.vertica.operators.vertica`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/vertica_to_hive.py b/airflow/contrib/operators/vertica_to_hive.py
index 84484ee69d94f..a1ed0f65ea51d 100644
--- a/airflow/contrib/operators/vertica_to_hive.py
+++ b/airflow/contrib/operators/vertica_to_hive.py
@@ -26,7 +26,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hive.transfers.vertica_to_hive`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -42,6 +43,7 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.apache.hive.transfers.vertica_to_hive.VerticaToHiveOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/vertica_to_mysql.py b/airflow/contrib/operators/vertica_to_mysql.py
index 7a1d174ebe9bc..acef39e9cf3bb 100644
--- a/airflow/contrib/operators/vertica_to_mysql.py
+++ b/airflow/contrib/operators/vertica_to_mysql.py
@@ -27,7 +27,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.mysql.transfers.vertica_to_mysql`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -43,6 +44,7 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.mysql.transfers.vertica_to_mysql.VerticaToMySqlOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/wasb_delete_blob_operator.py b/airflow/contrib/operators/wasb_delete_blob_operator.py
index 44fff265b724b..f204e8b174a8a 100644
--- a/airflow/contrib/operators/wasb_delete_blob_operator.py
+++ b/airflow/contrib/operators/wasb_delete_blob_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.wasb_delete_blob`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/operators/winrm_operator.py b/airflow/contrib/operators/winrm_operator.py
index 9d24e774b7e34..593ae63a58534 100644
--- a/airflow/contrib/operators/winrm_operator.py
+++ b/airflow/contrib/operators/winrm_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.winrm.operators.winrm`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/secrets/gcp_secrets_manager.py b/airflow/contrib/secrets/gcp_secrets_manager.py
index 6142a7ffc5928..0adcbbdc7bc41 100644
--- a/airflow/contrib/secrets/gcp_secrets_manager.py
+++ b/airflow/contrib/secrets/gcp_secrets_manager.py
@@ -40,6 +40,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend`.""",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py b/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py
index 99698a7cccbc3..97bff996e9098 100644
--- a/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py
+++ b/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.glue_catalog_partition`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/azure_cosmos_sensor.py b/airflow/contrib/sensors/azure_cosmos_sensor.py
index beae4a9978dd8..3fdce37d55673 100644
--- a/airflow/contrib/sensors/azure_cosmos_sensor.py
+++ b/airflow/contrib/sensors/azure_cosmos_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.azure.sensors.azure_cosmos`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/bash_sensor.py b/airflow/contrib/sensors/bash_sensor.py
index 136f78a13a132..c294fb1e8d1d4 100644
--- a/airflow/contrib/sensors/bash_sensor.py
+++ b/airflow/contrib/sensors/bash_sensor.py
@@ -23,6 +23,5 @@
from airflow.sensors.bash import STDOUT, BashSensor, Popen, TemporaryDirectory, gettempdir # noqa
warnings.warn(
- "This module is deprecated. Please use `airflow.sensors.bash`.",
- DeprecationWarning, stacklevel=2
+ "This module is deprecated. Please use `airflow.sensors.bash`.", DeprecationWarning, stacklevel=2
)
diff --git a/airflow/contrib/sensors/celery_queue_sensor.py b/airflow/contrib/sensors/celery_queue_sensor.py
index 61b2ea6959afe..5516eb468a979 100644
--- a/airflow/contrib/sensors/celery_queue_sensor.py
+++ b/airflow/contrib/sensors/celery_queue_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.celery.sensors.celery_queue`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/datadog_sensor.py b/airflow/contrib/sensors/datadog_sensor.py
index d32445068c641..d4bdaf40d9669 100644
--- a/airflow/contrib/sensors/datadog_sensor.py
+++ b/airflow/contrib/sensors/datadog_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.datadog.sensors.datadog`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/emr_base_sensor.py b/airflow/contrib/sensors/emr_base_sensor.py
index 2519440529def..595f84782a703 100644
--- a/airflow/contrib/sensors/emr_base_sensor.py
+++ b/airflow/contrib/sensors/emr_base_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_base`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/emr_job_flow_sensor.py b/airflow/contrib/sensors/emr_job_flow_sensor.py
index dc2a098933250..1e8f62a669bac 100644
--- a/airflow/contrib/sensors/emr_job_flow_sensor.py
+++ b/airflow/contrib/sensors/emr_job_flow_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_job_flow`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/emr_step_sensor.py b/airflow/contrib/sensors/emr_step_sensor.py
index 598919824583b..d64b20e6b0c99 100644
--- a/airflow/contrib/sensors/emr_step_sensor.py
+++ b/airflow/contrib/sensors/emr_step_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_step`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/file_sensor.py b/airflow/contrib/sensors/file_sensor.py
index eced61c89302d..97d3c458a6440 100644
--- a/airflow/contrib/sensors/file_sensor.py
+++ b/airflow/contrib/sensors/file_sensor.py
@@ -23,6 +23,5 @@
from airflow.sensors.filesystem import FileSensor # noqa
warnings.warn(
- "This module is deprecated. Please use `airflow.sensors.filesystem`.",
- DeprecationWarning, stacklevel=2
+ "This module is deprecated. Please use `airflow.sensors.filesystem`.", DeprecationWarning, stacklevel=2
)
diff --git a/airflow/contrib/sensors/ftp_sensor.py b/airflow/contrib/sensors/ftp_sensor.py
index 548efe90f4b7b..b8f3f6d766f06 100644
--- a/airflow/contrib/sensors/ftp_sensor.py
+++ b/airflow/contrib/sensors/ftp_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.ftp.sensors.ftp`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/gcp_transfer_sensor.py b/airflow/contrib/sensors/gcp_transfer_sensor.py
index 7aa9ec68164cd..6adb234e8432e 100644
--- a/airflow/contrib/sensors/gcp_transfer_sensor.py
+++ b/airflow/contrib/sensors/gcp_transfer_sensor.py
@@ -29,7 +29,8 @@
warnings.warn(
"This module is deprecated. "
"Please use `airflow.providers.google.cloud.sensors.cloud_storage_transfer_service`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -44,6 +45,7 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.google.cloud.sensors.transfer.CloudDataTransferServiceJobStatusSensor`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/sensors/gcs_sensor.py b/airflow/contrib/sensors/gcs_sensor.py
index 42182c7e6bfcc..2198886ba9ad9 100644
--- a/airflow/contrib/sensors/gcs_sensor.py
+++ b/airflow/contrib/sensors/gcs_sensor.py
@@ -20,13 +20,16 @@
import warnings
from airflow.providers.google.cloud.sensors.gcs import (
- GCSObjectExistenceSensor, GCSObjectsWtihPrefixExistenceSensor, GCSObjectUpdateSensor,
+ GCSObjectExistenceSensor,
+ GCSObjectsWtihPrefixExistenceSensor,
+ GCSObjectUpdateSensor,
GCSUploadSessionCompleteSensor,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.sensors.gcs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -40,7 +43,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSensor`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -55,7 +59,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectUpdateSensor`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -70,7 +75,8 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectsWtihPrefixExistenceSensor`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -85,6 +91,7 @@ def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.google.cloud.sensors.gcs.GCSUploadSessionCompleteSensor`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/sensors/hdfs_sensor.py b/airflow/contrib/sensors/hdfs_sensor.py
index 744c1af83c880..192314684272d 100644
--- a/airflow/contrib/sensors/hdfs_sensor.py
+++ b/airflow/contrib/sensors/hdfs_sensor.py
@@ -27,7 +27,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hdfs.sensors.hdfs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -44,7 +45,8 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.apache.hdfs.sensors.hdfs.HdfsFolderSensor`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
@@ -62,6 +64,7 @@ def __init__(self, *args, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.apache.hdfs.sensors.hdfs.HdfsRegexSensor`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/sensors/imap_attachment_sensor.py b/airflow/contrib/sensors/imap_attachment_sensor.py
index 3c1fc676c598f..03d0dfe7caea4 100644
--- a/airflow/contrib/sensors/imap_attachment_sensor.py
+++ b/airflow/contrib/sensors/imap_attachment_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.imap.sensors.imap_attachment`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/mongo_sensor.py b/airflow/contrib/sensors/mongo_sensor.py
index d817dee9df49b..28e158676d3fe 100644
--- a/airflow/contrib/sensors/mongo_sensor.py
+++ b/airflow/contrib/sensors/mongo_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.mongo.sensors.mongo`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/pubsub_sensor.py b/airflow/contrib/sensors/pubsub_sensor.py
index 9fe3354f4ce5f..2a9cb520e4687 100644
--- a/airflow/contrib/sensors/pubsub_sensor.py
+++ b/airflow/contrib/sensors/pubsub_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.sensors.pubsub`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/python_sensor.py b/airflow/contrib/sensors/python_sensor.py
index 6d9f745ef83f5..7e39a2a573fea 100644
--- a/airflow/contrib/sensors/python_sensor.py
+++ b/airflow/contrib/sensors/python_sensor.py
@@ -23,6 +23,5 @@
from airflow.sensors.python import PythonSensor # noqa
warnings.warn(
- "This module is deprecated. Please use `airflow.sensors.python`.",
- DeprecationWarning, stacklevel=2
+ "This module is deprecated. Please use `airflow.sensors.python`.", DeprecationWarning, stacklevel=2
)
diff --git a/airflow/contrib/sensors/qubole_sensor.py b/airflow/contrib/sensors/qubole_sensor.py
index a503c3ea5c898..ec39a40fd9f1b 100644
--- a/airflow/contrib/sensors/qubole_sensor.py
+++ b/airflow/contrib/sensors/qubole_sensor.py
@@ -21,10 +21,13 @@
# pylint: disable=unused-import
from airflow.providers.qubole.sensors.qubole import ( # noqa
- QuboleFileSensor, QubolePartitionSensor, QuboleSensor,
+ QuboleFileSensor,
+ QubolePartitionSensor,
+ QuboleSensor,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.qubole.sensors.qubole`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/redis_key_sensor.py b/airflow/contrib/sensors/redis_key_sensor.py
index 77dd1a431a3bd..f9762de724cf6 100644
--- a/airflow/contrib/sensors/redis_key_sensor.py
+++ b/airflow/contrib/sensors/redis_key_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.redis.sensors.redis_key`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/redis_pub_sub_sensor.py b/airflow/contrib/sensors/redis_pub_sub_sensor.py
index 121a46559354a..3b541de5dbb9c 100644
--- a/airflow/contrib/sensors/redis_pub_sub_sensor.py
+++ b/airflow/contrib/sensors/redis_pub_sub_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.redis.sensors.redis_pub_sub`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/sagemaker_base_sensor.py b/airflow/contrib/sensors/sagemaker_base_sensor.py
index 1e5e60ad7f45f..54b8323d1f798 100644
--- a/airflow/contrib/sensors/sagemaker_base_sensor.py
+++ b/airflow/contrib/sensors/sagemaker_base_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_base`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/sagemaker_endpoint_sensor.py b/airflow/contrib/sensors/sagemaker_endpoint_sensor.py
index afda91af082e1..f6ebe815d8e28 100644
--- a/airflow/contrib/sensors/sagemaker_endpoint_sensor.py
+++ b/airflow/contrib/sensors/sagemaker_endpoint_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_endpoint`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/sagemaker_training_sensor.py b/airflow/contrib/sensors/sagemaker_training_sensor.py
index 6c1ce75ee5483..61f8d432a4820 100644
--- a/airflow/contrib/sensors/sagemaker_training_sensor.py
+++ b/airflow/contrib/sensors/sagemaker_training_sensor.py
@@ -21,10 +21,12 @@
# pylint: disable=unused-import
from airflow.providers.amazon.aws.sensors.sagemaker_training import ( # noqa
- SageMakerHook, SageMakerTrainingSensor,
+ SageMakerHook,
+ SageMakerTrainingSensor,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_training`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/sagemaker_transform_sensor.py b/airflow/contrib/sensors/sagemaker_transform_sensor.py
index f2bdc73ed10ee..daf06add899eb 100644
--- a/airflow/contrib/sensors/sagemaker_transform_sensor.py
+++ b/airflow/contrib/sensors/sagemaker_transform_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_transform`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/sagemaker_tuning_sensor.py b/airflow/contrib/sensors/sagemaker_tuning_sensor.py
index 3cbbf1d902e43..a1886516ad6b5 100644
--- a/airflow/contrib/sensors/sagemaker_tuning_sensor.py
+++ b/airflow/contrib/sensors/sagemaker_tuning_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_tuning`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/wasb_sensor.py b/airflow/contrib/sensors/wasb_sensor.py
index bfe0ec115c4ff..22584473399e7 100644
--- a/airflow/contrib/sensors/wasb_sensor.py
+++ b/airflow/contrib/sensors/wasb_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.azure.sensors.wasb`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/sensors/weekday_sensor.py b/airflow/contrib/sensors/weekday_sensor.py
index 890d07378ae7d..8f462bbb6df42 100644
--- a/airflow/contrib/sensors/weekday_sensor.py
+++ b/airflow/contrib/sensors/weekday_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.sensors.weekday_sensor`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/task_runner/cgroup_task_runner.py b/airflow/contrib/task_runner/cgroup_task_runner.py
index 0f03fed792b0a..9970da6e4674b 100644
--- a/airflow/contrib/task_runner/cgroup_task_runner.py
+++ b/airflow/contrib/task_runner/cgroup_task_runner.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.task.task_runner.cgroup_task_runner`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/utils/__init__.py b/airflow/contrib/utils/__init__.py
index 9404bc6de998f..4e6cbb7d51e23 100644
--- a/airflow/contrib/utils/__init__.py
+++ b/airflow/contrib/utils/__init__.py
@@ -19,7 +19,4 @@
import warnings
-warnings.warn(
- "This module is deprecated. Please use `airflow.utils`.",
- DeprecationWarning, stacklevel=2
-)
+warnings.warn("This module is deprecated. Please use `airflow.utils`.", DeprecationWarning, stacklevel=2)
diff --git a/airflow/contrib/utils/gcp_field_sanitizer.py b/airflow/contrib/utils/gcp_field_sanitizer.py
index 1e39bb287bf42..a7dca19c2f37d 100644
--- a/airflow/contrib/utils/gcp_field_sanitizer.py
+++ b/airflow/contrib/utils/gcp_field_sanitizer.py
@@ -21,10 +21,12 @@
# pylint: disable=unused-import
from airflow.providers.google.cloud.utils.field_sanitizer import ( # noqa
- GcpBodyFieldSanitizer, GcpFieldSanitizerException,
+ GcpBodyFieldSanitizer,
+ GcpFieldSanitizerException,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.utils.field_sanitizer`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/utils/gcp_field_validator.py b/airflow/contrib/utils/gcp_field_validator.py
index 08d677c2c17e0..e623d37ccdb6c 100644
--- a/airflow/contrib/utils/gcp_field_validator.py
+++ b/airflow/contrib/utils/gcp_field_validator.py
@@ -21,10 +21,13 @@
# pylint: disable=unused-import
from airflow.providers.google.cloud.utils.field_validator import ( # noqa
- GcpBodyFieldValidator, GcpFieldValidationException, GcpValidationSpecificationException,
+ GcpBodyFieldValidator,
+ GcpFieldValidationException,
+ GcpValidationSpecificationException,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.utils.field_validator`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/utils/log/__init__.py b/airflow/contrib/utils/log/__init__.py
index aecb6b83be87f..ecc47d3d85122 100644
--- a/airflow/contrib/utils/log/__init__.py
+++ b/airflow/contrib/utils/log/__init__.py
@@ -18,7 +18,4 @@
import warnings
-warnings.warn(
- "This module is deprecated. Please use `airflow.utils.log`.",
- DeprecationWarning, stacklevel=2
-)
+warnings.warn("This module is deprecated. Please use `airflow.utils.log`.", DeprecationWarning, stacklevel=2)
diff --git a/airflow/contrib/utils/log/task_handler_with_custom_formatter.py b/airflow/contrib/utils/log/task_handler_with_custom_formatter.py
index a6ebd66483c85..5045c90ad3f62 100644
--- a/airflow/contrib/utils/log/task_handler_with_custom_formatter.py
+++ b/airflow/contrib/utils/log/task_handler_with_custom_formatter.py
@@ -23,5 +23,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.utils.log.task_handler_with_custom_formatter`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/utils/mlengine_operator_utils.py b/airflow/contrib/utils/mlengine_operator_utils.py
index 80db88d62016b..0c9439874e4ad 100644
--- a/airflow/contrib/utils/mlengine_operator_utils.py
+++ b/airflow/contrib/utils/mlengine_operator_utils.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.utils.mlengine_operator_utils`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/utils/mlengine_prediction_summary.py b/airflow/contrib/utils/mlengine_prediction_summary.py
index c245d4a3de0ca..2ac126f9cca48 100644
--- a/airflow/contrib/utils/mlengine_prediction_summary.py
+++ b/airflow/contrib/utils/mlengine_prediction_summary.py
@@ -28,5 +28,6 @@
warnings.warn(
"This module is deprecated. "
"Please use `airflow.providers.google.cloud.utils.mlengine_prediction_summary`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/contrib/utils/sendgrid.py b/airflow/contrib/utils/sendgrid.py
index 8801e76cd402f..16408f92d3a5a 100644
--- a/airflow/contrib/utils/sendgrid.py
+++ b/airflow/contrib/utils/sendgrid.py
@@ -30,6 +30,7 @@ def send_email(*args, **kwargs):
"""This function is deprecated. Please use `airflow.providers.sendgrid.utils.emailer.send_email`."""
warnings.warn(
"This function is deprecated. Please use `airflow.providers.sendgrid.utils.emailer.send_email`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
return import_string('airflow.providers.sendgrid.utils.emailer.send_email')(*args, **kwargs)
diff --git a/airflow/contrib/utils/weekday.py b/airflow/contrib/utils/weekday.py
index 85e85cceed5ea..d2548aa391ad3 100644
--- a/airflow/contrib/utils/weekday.py
+++ b/airflow/contrib/utils/weekday.py
@@ -21,6 +21,5 @@
from airflow.utils.weekday import WeekDay # noqa
warnings.warn(
- "This module is deprecated. Please use `airflow.utils.weekday`.",
- DeprecationWarning, stacklevel=2
+ "This module is deprecated. Please use `airflow.utils.weekday`.", DeprecationWarning, stacklevel=2
)
diff --git a/airflow/example_dags/example_bash_operator.py b/airflow/example_dags/example_bash_operator.py
index 0bda87a9a314b..aa3004f65a380 100644
--- a/airflow/example_dags/example_bash_operator.py
+++ b/airflow/example_dags/example_bash_operator.py
@@ -36,7 +36,7 @@
start_date=days_ago(2),
dagrun_timeout=timedelta(minutes=60),
tags=['example', 'example2'],
- params={"example_key": "example_value"}
+ params={"example_key": "example_value"},
)
run_this_last = DummyOperator(
diff --git a/airflow/example_dags/example_branch_operator.py b/airflow/example_dags/example_branch_operator.py
index 6efacc8808138..97256c65dfb3c 100644
--- a/airflow/example_dags/example_branch_operator.py
+++ b/airflow/example_dags/example_branch_operator.py
@@ -34,7 +34,7 @@
default_args=args,
start_date=days_ago(2),
schedule_interval="@daily",
- tags=['example', 'example2']
+ tags=['example', 'example2'],
)
run_this_first = DummyOperator(
diff --git a/airflow/example_dags/example_branch_python_dop_operator_3.py b/airflow/example_dags/example_branch_python_dop_operator_3.py
index 0a45e6093ebe9..9d60fa3eb2b58 100644
--- a/airflow/example_dags/example_branch_python_dop_operator_3.py
+++ b/airflow/example_dags/example_branch_python_dop_operator_3.py
@@ -36,7 +36,7 @@
schedule_interval='*/1 * * * *',
start_date=days_ago(2),
default_args=args,
- tags=['example']
+ tags=['example'],
)
@@ -48,8 +48,11 @@ def should_run(**kwargs):
:return: Id of the task to run
:rtype: str
"""
- print('------------- exec dttm = {} and minute = {}'.
- format(kwargs['execution_date'], kwargs['execution_date'].minute))
+ print(
+ '------------- exec dttm = {} and minute = {}'.format(
+ kwargs['execution_date'], kwargs['execution_date'].minute
+ )
+ )
if kwargs['execution_date'].minute % 2 == 0:
return "dummy_task_1"
else:
diff --git a/airflow/example_dags/example_dag_decorator.py b/airflow/example_dags/example_dag_decorator.py
index 6f8f38229aace..04c87ac399cd5 100644
--- a/airflow/example_dags/example_dag_decorator.py
+++ b/airflow/example_dags/example_dag_decorator.py
@@ -40,25 +40,20 @@ def example_dag_decorator(email: str = 'example@example.com'):
:type email: str
"""
# Using default connection as it's set to httpbin.org by default
- get_ip = SimpleHttpOperator(
- task_id='get_ip', endpoint='get', method='GET'
- )
+ get_ip = SimpleHttpOperator(task_id='get_ip', endpoint='get', method='GET')
@task(multiple_outputs=True)
def prepare_email(raw_json: str) -> Dict[str, str]:
external_ip = json.loads(raw_json)['origin']
return {
'subject': f'Server connected from {external_ip}',
-            'body': f'Seems like today your server executing Airflow is connected from IP {external_ip}<br>'
+            'body': f'Seems like today your server executing Airflow is connected from IP {external_ip}<br>',
}
email_info = prepare_email(get_ip.output)
EmailOperator(
- task_id='send_email',
- to=email,
- subject=email_info['subject'],
- html_content=email_info['body']
+ task_id='send_email', to=email, subject=email_info['subject'], html_content=email_info['body']
)
diff --git a/airflow/example_dags/example_external_task_marker_dag.py b/airflow/example_dags/example_external_task_marker_dag.py
index 13dab0dcd015f..b31240f9226ef 100644
--- a/airflow/example_dags/example_external_task_marker_dag.py
+++ b/airflow/example_dags/example_external_task_marker_dag.py
@@ -52,9 +52,11 @@
tags=['example2'],
) as parent_dag:
# [START howto_operator_external_task_marker]
- parent_task = ExternalTaskMarker(task_id="parent_task",
- external_dag_id="example_external_task_marker_child",
- external_task_id="child_task1")
+ parent_task = ExternalTaskMarker(
+ task_id="parent_task",
+ external_dag_id="example_external_task_marker_child",
+ external_task_id="child_task1",
+ )
# [END howto_operator_external_task_marker]
with DAG(
@@ -64,13 +66,15 @@
tags=['example2'],
) as child_dag:
# [START howto_operator_external_task_sensor]
- child_task1 = ExternalTaskSensor(task_id="child_task1",
- external_dag_id=parent_dag.dag_id,
- external_task_id=parent_task.task_id,
- timeout=600,
- allowed_states=['success'],
- failed_states=['failed', 'skipped'],
- mode="reschedule")
+ child_task1 = ExternalTaskSensor(
+ task_id="child_task1",
+ external_dag_id=parent_dag.dag_id,
+ external_task_id=parent_task.task_id,
+ timeout=600,
+ allowed_states=['success'],
+ failed_states=['failed', 'skipped'],
+ mode="reschedule",
+ )
# [END howto_operator_external_task_sensor]
child_task2 = DummyOperator(task_id="child_task2")
child_task1 >> child_task2
diff --git a/airflow/example_dags/example_kubernetes_executor.py b/airflow/example_dags/example_kubernetes_executor.py
index aa4fc32f1b2c1..a782efcf77e4a 100644
--- a/airflow/example_dags/example_kubernetes_executor.py
+++ b/airflow/example_dags/example_kubernetes_executor.py
@@ -43,24 +43,14 @@
{
'topologyKey': 'kubernetes.io/hostname',
'labelSelector': {
- 'matchExpressions': [
- {
- 'key': 'app',
- 'operator': 'In',
- 'values': ['airflow']
- }
- ]
- }
+ 'matchExpressions': [{'key': 'app', 'operator': 'In', 'values': ['airflow']}]
+ },
}
]
}
}
- tolerations = [{
- 'key': 'dedicated',
- 'operator': 'Equal',
- 'value': 'airflow'
- }]
+ tolerations = [{'key': 'dedicated', 'operator': 'Equal', 'value': 'airflow'}]
def use_zip_binary():
"""
@@ -74,23 +64,20 @@ def use_zip_binary():
raise SystemError("The zip binary is missing")
# You don't have to use any special KubernetesExecutor configuration if you don't want to
- start_task = PythonOperator(
- task_id="start_task",
- python_callable=print_stuff
- )
+ start_task = PythonOperator(task_id="start_task", python_callable=print_stuff)
# But you can if you want to
one_task = PythonOperator(
task_id="one_task",
python_callable=print_stuff,
- executor_config={"KubernetesExecutor": {"image": "airflow/ci:latest"}}
+ executor_config={"KubernetesExecutor": {"image": "airflow/ci:latest"}},
)
# Use the zip binary, which is only found in this special docker image
two_task = PythonOperator(
task_id="two_task",
python_callable=use_zip_binary,
- executor_config={"KubernetesExecutor": {"image": "airflow/ci_zip:latest"}}
+ executor_config={"KubernetesExecutor": {"image": "airflow/ci_zip:latest"}},
)
# Limit resources on this operator/task with node affinity & tolerations
@@ -98,17 +85,20 @@ def use_zip_binary():
task_id="three_task",
python_callable=print_stuff,
executor_config={
- "KubernetesExecutor": {"request_memory": "128Mi",
- "limit_memory": "128Mi",
- "tolerations": tolerations,
- "affinity": affinity}}
+ "KubernetesExecutor": {
+ "request_memory": "128Mi",
+ "limit_memory": "128Mi",
+ "tolerations": tolerations,
+ "affinity": affinity,
+ }
+ },
)
# Add arbitrary labels to worker pods
four_task = PythonOperator(
task_id="four_task",
python_callable=print_stuff,
- executor_config={"KubernetesExecutor": {"labels": {"foo": "bar"}}}
+ executor_config={"KubernetesExecutor": {"labels": {"foo": "bar"}}},
)
start_task >> [one_task, two_task, three_task, four_task]
diff --git a/airflow/example_dags/example_kubernetes_executor_config.py b/airflow/example_dags/example_kubernetes_executor_config.py
index d361227a4e7d1..57b2c4a8435e0 100644
--- a/airflow/example_dags/example_kubernetes_executor_config.py
+++ b/airflow/example_dags/example_kubernetes_executor_config.py
@@ -67,12 +67,9 @@ def test_volume_mount():
start_task = PythonOperator(
task_id="start_task",
python_callable=print_stuff,
- executor_config={"pod_override": k8s.V1Pod(
- metadata=k8s.V1ObjectMeta(
- annotations={"test": "annotation"}
- )
- )
- }
+ executor_config={
+ "pod_override": k8s.V1Pod(metadata=k8s.V1ObjectMeta(annotations={"test": "annotation"}))
+ },
)
# [START task_with_volume]
@@ -86,24 +83,19 @@ def test_volume_mount():
k8s.V1Container(
name="base",
volume_mounts=[
- k8s.V1VolumeMount(
- mount_path="/foo/",
- name="example-kubernetes-test-volume"
- )
- ]
+ k8s.V1VolumeMount(mount_path="/foo/", name="example-kubernetes-test-volume")
+ ],
)
],
volumes=[
k8s.V1Volume(
name="example-kubernetes-test-volume",
- host_path=k8s.V1HostPathVolumeSource(
- path="/tmp/"
- )
+ host_path=k8s.V1HostPathVolumeSource(path="/tmp/"),
)
- ]
+ ],
)
),
- }
+ },
)
# [END task_with_volume]
@@ -117,31 +109,22 @@ def test_volume_mount():
containers=[
k8s.V1Container(
name="base",
- volume_mounts=[k8s.V1VolumeMount(
- mount_path="/shared/",
- name="shared-empty-dir"
- )]
+ volume_mounts=[k8s.V1VolumeMount(mount_path="/shared/", name="shared-empty-dir")],
),
k8s.V1Container(
name="sidecar",
image="ubuntu",
args=["echo \"retrieved from mount\" > /shared/test.txt"],
command=["bash", "-cx"],
- volume_mounts=[k8s.V1VolumeMount(
- mount_path="/shared/",
- name="shared-empty-dir"
- )]
- )
+ volume_mounts=[k8s.V1VolumeMount(mount_path="/shared/", name="shared-empty-dir")],
+ ),
],
volumes=[
- k8s.V1Volume(
- name="shared-empty-dir",
- empty_dir=k8s.V1EmptyDirVolumeSource()
- ),
- ]
+ k8s.V1Volume(name="shared-empty-dir", empty_dir=k8s.V1EmptyDirVolumeSource()),
+ ],
)
),
- }
+ },
)
# [END task_with_sidecar]
@@ -149,27 +132,15 @@ def test_volume_mount():
third_task = PythonOperator(
task_id="non_root_task",
python_callable=print_stuff,
- executor_config={"pod_override": k8s.V1Pod(
- metadata=k8s.V1ObjectMeta(
- labels={
- "release": "stable"
- }
- )
- )
- }
+ executor_config={"pod_override": k8s.V1Pod(metadata=k8s.V1ObjectMeta(labels={"release": "stable"}))},
)
other_ns_task = PythonOperator(
task_id="other_namespace_task",
python_callable=print_stuff,
executor_config={
- "KubernetesExecutor": {
- "namespace": "test-namespace",
- "labels": {
- "release": "stable"
- }
- }
- }
+ "KubernetesExecutor": {"namespace": "test-namespace", "labels": {"release": "stable"}}
+ },
)
start_task >> volume_task >> third_task
diff --git a/airflow/example_dags/example_latest_only.py b/airflow/example_dags/example_latest_only.py
index fda514b8fd8fb..9be8a82958170 100644
--- a/airflow/example_dags/example_latest_only.py
+++ b/airflow/example_dags/example_latest_only.py
@@ -29,7 +29,7 @@
dag_id='latest_only',
schedule_interval=dt.timedelta(hours=4),
start_date=days_ago(2),
- tags=['example2', 'example3']
+ tags=['example2', 'example3'],
)
latest_only = LatestOnlyOperator(task_id='latest_only', dag=dag)
diff --git a/airflow/example_dags/example_latest_only_with_trigger.py b/airflow/example_dags/example_latest_only_with_trigger.py
index 2568139ec1bb4..e9d136063b89c 100644
--- a/airflow/example_dags/example_latest_only_with_trigger.py
+++ b/airflow/example_dags/example_latest_only_with_trigger.py
@@ -32,7 +32,7 @@
dag_id='latest_only_with_trigger',
schedule_interval=dt.timedelta(hours=4),
start_date=days_ago(2),
- tags=['example3']
+ tags=['example3'],
)
latest_only = LatestOnlyOperator(task_id='latest_only', dag=dag)
diff --git a/airflow/example_dags/example_nested_branch_dag.py b/airflow/example_dags/example_nested_branch_dag.py
index 0050d9a8df862..ffc1ad5cf70a0 100644
--- a/airflow/example_dags/example_nested_branch_dag.py
+++ b/airflow/example_dags/example_nested_branch_dag.py
@@ -28,10 +28,7 @@
from airflow.utils.dates import days_ago
with DAG(
- dag_id="example_nested_branch_dag",
- start_date=days_ago(2),
- schedule_interval="@daily",
- tags=["example"]
+ dag_id="example_nested_branch_dag", start_date=days_ago(2), schedule_interval="@daily", tags=["example"]
) as dag:
branch_1 = BranchPythonOperator(task_id="branch_1", python_callable=lambda: "true_1")
join_1 = DummyOperator(task_id="join_1", trigger_rule="none_failed_or_skipped")
diff --git a/airflow/example_dags/example_passing_params_via_test_command.py b/airflow/example_dags/example_passing_params_via_test_command.py
index bcd5318c65fe1..8eaadd7eb245f 100644
--- a/airflow/example_dags/example_passing_params_via_test_command.py
+++ b/airflow/example_dags/example_passing_params_via_test_command.py
@@ -34,7 +34,7 @@
schedule_interval='*/1 * * * *',
start_date=days_ago(1),
dagrun_timeout=timedelta(minutes=4),
- tags=['example']
+ tags=['example'],
)
@@ -45,8 +45,12 @@ def my_py_command(test_mode, params):
-t '{"foo":"bar"}'`
"""
if test_mode:
- print(" 'foo' was passed in via test={} command : kwargs[params][foo] \
- = {}".format(test_mode, params["foo"]))
+ print(
+ " 'foo' was passed in via test={} command : kwargs[params][foo] \
+ = {}".format(
+ test_mode, params["foo"]
+ )
+ )
# Print out the value of "miff", passed in below via the Python Operator
print(" 'miff' was passed in via task params = {}".format(params["miff"]))
return 1
@@ -83,10 +87,6 @@ def print_env_vars(test_mode):
print("AIRFLOW_TEST_MODE={}".format(os.environ.get('AIRFLOW_TEST_MODE')))
-env_var_test_task = PythonOperator(
- task_id='env_var_test_task',
- python_callable=print_env_vars,
- dag=dag
-)
+env_var_test_task = PythonOperator(task_id='env_var_test_task', python_callable=print_env_vars, dag=dag)
run_this >> also_run_this
diff --git a/airflow/example_dags/example_python_operator.py b/airflow/example_dags/example_python_operator.py
index 5b6d7b54d0b74..d5e16a55e95b8 100644
--- a/airflow/example_dags/example_python_operator.py
+++ b/airflow/example_dags/example_python_operator.py
@@ -33,7 +33,7 @@
default_args=args,
schedule_interval=None,
start_date=days_ago(2),
- tags=['example']
+ tags=['example'],
)
@@ -83,6 +83,7 @@ def callable_virtualenv():
from time import sleep
from colorama import Back, Fore, Style
+
print(Fore.RED + 'some red text')
print(Back.GREEN + 'and with a green background')
print(Style.DIM + 'and in dim text')
@@ -96,9 +97,7 @@ def callable_virtualenv():
virtualenv_task = PythonVirtualenvOperator(
task_id="virtualenv_python",
python_callable=callable_virtualenv,
- requirements=[
- "colorama==0.4.0"
- ],
+ requirements=["colorama==0.4.0"],
system_site_packages=False,
dag=dag,
)
diff --git a/airflow/example_dags/example_subdag_operator.py b/airflow/example_dags/example_subdag_operator.py
index f21e3d4db1577..de2853d3eb842 100644
--- a/airflow/example_dags/example_subdag_operator.py
+++ b/airflow/example_dags/example_subdag_operator.py
@@ -32,11 +32,7 @@
}
dag = DAG(
- dag_id=DAG_NAME,
- default_args=args,
- start_date=days_ago(2),
- schedule_interval="@once",
- tags=['example']
+ dag_id=DAG_NAME, default_args=args, start_date=days_ago(2), schedule_interval="@once", tags=['example']
)
start = DummyOperator(
diff --git a/airflow/example_dags/example_trigger_controller_dag.py b/airflow/example_dags/example_trigger_controller_dag.py
index f8fd5d6610b98..39bc766f90b3f 100644
--- a/airflow/example_dags/example_trigger_controller_dag.py
+++ b/airflow/example_dags/example_trigger_controller_dag.py
@@ -30,7 +30,7 @@
default_args={"owner": "airflow"},
start_date=days_ago(2),
schedule_interval="@once",
- tags=['example']
+ tags=['example'],
)
trigger = TriggerDagRunOperator(
diff --git a/airflow/example_dags/example_trigger_target_dag.py b/airflow/example_dags/example_trigger_target_dag.py
index 3f4cfd0f3c856..035527546d289 100644
--- a/airflow/example_dags/example_trigger_target_dag.py
+++ b/airflow/example_dags/example_trigger_target_dag.py
@@ -32,7 +32,7 @@
default_args={"owner": "airflow"},
start_date=days_ago(2),
schedule_interval=None,
- tags=['example']
+ tags=['example'],
)
diff --git a/airflow/example_dags/example_xcom.py b/airflow/example_dags/example_xcom.py
index b3956822a3f64..779e392c70083 100644
--- a/airflow/example_dags/example_xcom.py
+++ b/airflow/example_dags/example_xcom.py
@@ -26,7 +26,7 @@
schedule_interval="@once",
start_date=days_ago(2),
default_args={'owner': 'airflow'},
- tags=['example']
+ tags=['example'],
)
value_1 = [1, 2, 3]
diff --git a/airflow/example_dags/subdags/subdag.py b/airflow/example_dags/subdags/subdag.py
index e65a5c98eaa49..9815af33f93d5 100644
--- a/airflow/example_dags/subdags/subdag.py
+++ b/airflow/example_dags/subdags/subdag.py
@@ -49,4 +49,6 @@ def subdag(parent_dag_name, child_dag_name, args):
)
return dag_subdag
+
+
# [END subdag]
diff --git a/airflow/example_dags/tutorial.py b/airflow/example_dags/tutorial.py
index 39f779c905b67..a00051c43abe2 100644
--- a/airflow/example_dags/tutorial.py
+++ b/airflow/example_dags/tutorial.py
@@ -27,6 +27,7 @@
# The DAG object; we'll need this to instantiate a DAG
from airflow import DAG
+
# Operators; we need this to operate!
from airflow.operators.bash import BashOperator
from airflow.utils.dates import days_ago
diff --git a/airflow/example_dags/tutorial_decorated_etl_dag.py b/airflow/example_dags/tutorial_decorated_etl_dag.py
index 0f78940d824e3..b351e7318ad6f 100644
--- a/airflow/example_dags/tutorial_decorated_etl_dag.py
+++ b/airflow/example_dags/tutorial_decorated_etl_dag.py
@@ -71,6 +71,7 @@ def extract():
order_data_dict = json.loads(data_string)
return order_data_dict
+
# [END extract]
# [START transform]
@@ -87,6 +88,7 @@ def transform(order_data_dict: dict):
total_order_value += value
return {"total_order_value": total_order_value}
+
# [END transform]
# [START load]
@@ -99,6 +101,7 @@ def load(total_order_value: float):
"""
print("Total order value is: %.2f" % total_order_value)
+
# [END load]
# [START main_flow]
diff --git a/airflow/example_dags/tutorial_etl_dag.py b/airflow/example_dags/tutorial_etl_dag.py
index 4a4405e0d2c18..48b519b5e59eb 100644
--- a/airflow/example_dags/tutorial_etl_dag.py
+++ b/airflow/example_dags/tutorial_etl_dag.py
@@ -30,6 +30,7 @@
# The DAG object; we'll need this to instantiate a DAG
from airflow import DAG
+
# Operators; we need this to operate!
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
@@ -63,6 +64,7 @@ def extract(**kwargs):
ti = kwargs['ti']
data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
ti.xcom_push('order_data', data_string)
+
# [END extract_function]
# [START transform_function]
@@ -78,6 +80,7 @@ def transform(**kwargs):
total_value = {"total_order_value": total_order_value}
total_value_json_string = json.dumps(total_value)
ti.xcom_push('total_order_value', total_value_json_string)
+
# [END transform_function]
# [START load_function]
@@ -87,6 +90,7 @@ def load(**kwargs):
total_order_value = json.loads(total_value_string)
print(total_order_value)
+
# [END load_function]
# [START main_flow]
diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 8477019de614d..8ae4c327bd63a 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -60,19 +60,20 @@ class BaseExecutor(LoggingMixin):
def __init__(self, parallelism: int = PARALLELISM):
super().__init__()
self.parallelism: int = parallelism
- self.queued_tasks: OrderedDict[TaskInstanceKey, QueuedTaskInstanceType] \
- = OrderedDict()
+ self.queued_tasks: OrderedDict[TaskInstanceKey, QueuedTaskInstanceType] = OrderedDict()
self.running: Set[TaskInstanceKey] = set()
self.event_buffer: Dict[TaskInstanceKey, EventBufferValueType] = {}
def start(self): # pragma: no cover
"""Executors may need to get things started."""
- def queue_command(self,
- task_instance: TaskInstance,
- command: CommandType,
- priority: int = 1,
- queue: Optional[str] = None):
+ def queue_command(
+ self,
+ task_instance: TaskInstance,
+ command: CommandType,
+ priority: int = 1,
+ queue: Optional[str] = None,
+ ):
"""Queues command to task"""
if task_instance.key not in self.queued_tasks and task_instance.key not in self.running:
self.log.info("Adding to queue: %s", command)
@@ -81,16 +82,17 @@ def queue_command(self,
self.log.error("could not queue task %s", task_instance.key)
def queue_task_instance(
- self,
- task_instance: TaskInstance,
- mark_success: bool = False,
- pickle_id: Optional[str] = None,
- ignore_all_deps: bool = False,
- ignore_depends_on_past: bool = False,
- ignore_task_deps: bool = False,
- ignore_ti_state: bool = False,
- pool: Optional[str] = None,
- cfg_path: Optional[str] = None) -> None:
+ self,
+ task_instance: TaskInstance,
+ mark_success: bool = False,
+ pickle_id: Optional[str] = None,
+ ignore_all_deps: bool = False,
+ ignore_depends_on_past: bool = False,
+ ignore_task_deps: bool = False,
+ ignore_ti_state: bool = False,
+ pool: Optional[str] = None,
+ cfg_path: Optional[str] = None,
+ ) -> None:
"""Queues task instance."""
pool = pool or task_instance.pool
@@ -107,13 +109,15 @@ def queue_task_instance(
ignore_ti_state=ignore_ti_state,
pool=pool,
pickle_id=pickle_id,
- cfg_path=cfg_path)
+ cfg_path=cfg_path,
+ )
self.log.debug("created command %s", command_list_to_run)
self.queue_command(
task_instance,
command_list_to_run,
priority=task_instance.task.priority_weight_total,
- queue=task_instance.task.queue)
+ queue=task_instance.task.queue,
+ )
def has_task(self, task_instance: TaskInstance) -> bool:
"""
@@ -163,7 +167,8 @@ def order_queued_tasks_by_priority(self) -> List[Tuple[TaskInstanceKey, QueuedTa
return sorted(
[(k, v) for k, v in self.queued_tasks.items()], # pylint: disable=unnecessary-comprehension
key=lambda x: x[1][1],
- reverse=True)
+ reverse=True,
+ )
def trigger_tasks(self, open_slots: int) -> None:
"""
@@ -177,10 +182,7 @@ def trigger_tasks(self, open_slots: int) -> None:
key, (command, _, _, ti) = sorted_queue.pop(0)
self.queued_tasks.pop(key)
self.running.add(key)
- self.execute_async(key=key,
- command=command,
- queue=None,
- executor_config=ti.executor_config)
+ self.execute_async(key=key, command=command, queue=None, executor_config=ti.executor_config)
def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
"""
@@ -235,11 +237,13 @@ def get_event_buffer(self, dag_ids=None) -> Dict[TaskInstanceKey, EventBufferVal
return cleared_events
- def execute_async(self,
- key: TaskInstanceKey,
- command: CommandType,
- queue: Optional[str] = None,
- executor_config: Optional[Any] = None) -> None: # pragma: no cover
+ def execute_async(
+ self,
+ key: TaskInstanceKey,
+ command: CommandType,
+ queue: Optional[str] = None,
+ executor_config: Optional[Any] = None,
+ ) -> None: # pragma: no cover
"""
This method will execute the command asynchronously.
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index 716d1ed973873..86d76d86cf0c7 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -71,9 +71,7 @@
else:
celery_configuration = DEFAULT_CELERY_CONFIG
-app = Celery(
- conf.get('celery', 'CELERY_APP_NAME'),
- config_source=celery_configuration)
+app = Celery(conf.get('celery', 'CELERY_APP_NAME'), config_source=celery_configuration)
@app.task
@@ -103,6 +101,7 @@ def _execute_in_fork(command_to_exec: CommandType) -> None:
ret = 1
try:
from airflow.cli.cli_parser import get_parser
+
parser = get_parser()
# [1:] - remove "airflow" from the start of the command
args = parser.parse_args(command_to_exec[1:])
@@ -123,8 +122,7 @@ def _execute_in_subprocess(command_to_exec: CommandType) -> None:
env = os.environ.copy()
try:
# pylint: disable=unexpected-keyword-arg
- subprocess.check_output(command_to_exec, stderr=subprocess.STDOUT,
- close_fds=True, env=env)
+ subprocess.check_output(command_to_exec, stderr=subprocess.STDOUT, close_fds=True, env=env)
# pylint: disable=unexpected-keyword-arg
except subprocess.CalledProcessError as e:
log.exception('execute_command encountered a CalledProcessError')
@@ -153,8 +151,9 @@ def __init__(self, exception: Exception, exception_traceback: str):
TaskInstanceInCelery = Tuple[TaskInstanceKey, SimpleTaskInstance, CommandType, Optional[str], Task]
-def send_task_to_executor(task_tuple: TaskInstanceInCelery) \
- -> Tuple[TaskInstanceKey, CommandType, Union[AsyncResult, ExceptionWithTraceback]]:
+def send_task_to_executor(
+ task_tuple: TaskInstanceInCelery,
+) -> Tuple[TaskInstanceKey, CommandType, Union[AsyncResult, ExceptionWithTraceback]]:
"""Sends task to executor."""
key, _, command, queue, task_to_run = task_tuple
try:
@@ -185,10 +184,13 @@ def on_celery_import_modules(*args, **kwargs):
import airflow.operators.bash
import airflow.operators.python
import airflow.operators.subdag_operator # noqa: F401
+
try:
import kubernetes.client # noqa: F401
except ImportError:
pass
+
+
# pylint: enable=unused-import
@@ -220,10 +222,7 @@ def __init__(self):
)
def start(self) -> None:
- self.log.debug(
- 'Starting Celery Executor using %s processes for syncing',
- self._sync_parallelism
- )
+ self.log.debug('Starting Celery Executor using %s processes for syncing', self._sync_parallelism)
def _num_tasks_per_send_process(self, to_send_count: int) -> int:
"""
@@ -232,8 +231,7 @@ def _num_tasks_per_send_process(self, to_send_count: int) -> int:
:return: Number of tasks that should be sent per process
:rtype: int
"""
- return max(1,
- int(math.ceil(1.0 * to_send_count / self._sync_parallelism)))
+ return max(1, int(math.ceil(1.0 * to_send_count / self._sync_parallelism)))
def trigger_tasks(self, open_slots: int) -> None:
"""
@@ -297,14 +295,14 @@ def reset_signals():
# Since we are run from inside the SchedulerJob, we don't to
# inherit the signal handlers that we registered there.
import signal
+
signal.signal(signal.SIGINT, signal.SIG_DFL)
signal.signal(signal.SIGTERM, signal.SIG_DFL)
with Pool(processes=num_processes, initializer=reset_signals) as send_pool:
key_and_async_results = send_pool.map(
- send_task_to_executor,
- task_tuples_to_send,
- chunksize=chunksize)
+ send_task_to_executor, task_tuples_to_send, chunksize=chunksize
+ )
return key_and_async_results
def sync(self) -> None:
@@ -348,7 +346,7 @@ def _check_for_stalled_adopted_tasks(self):
"Adopted tasks were still pending after %s, assuming they never made it to celery and "
"clearing:\n\t%s",
self.task_adoption_timeout,
- "\n\t".join([repr(x) for x in timedout_keys])
+ "\n\t".join([repr(x) for x in timedout_keys]),
)
for key in timedout_keys:
self.event_buffer[key] = (State.FAILED, None)
@@ -394,11 +392,13 @@ def end(self, synchronous: bool = False) -> None:
time.sleep(5)
self.sync()
- def execute_async(self,
- key: TaskInstanceKey,
- command: CommandType,
- queue: Optional[str] = None,
- executor_config: Optional[Any] = None):
+ def execute_async(
+ self,
+ key: TaskInstanceKey,
+ command: CommandType,
+ queue: Optional[str] = None,
+ executor_config: Optional[Any] = None,
+ ):
"""Do not allow async execution for Celery executor."""
raise AirflowException("No Async execution for Celery executor.")
@@ -456,14 +456,14 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance
if adopted:
task_instance_str = '\n\t'.join(adopted)
- self.log.info("Adopted the following %d tasks from a dead executor\n\t%s",
- len(adopted), task_instance_str)
+ self.log.info(
+ "Adopted the following %d tasks from a dead executor\n\t%s", len(adopted), task_instance_str
+ )
return not_adopted_tis
-def fetch_celery_task_state(async_result: AsyncResult) -> \
- Tuple[str, Union[str, ExceptionWithTraceback], Any]:
+def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, ExceptionWithTraceback], Any]:
"""
Fetch and return the state of the given celery task. The scope of this function is
global so that it can be called by subprocesses in the pool.
@@ -535,8 +535,9 @@ def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, EventBufferValu
return self._prepare_state_and_info_by_task_dict(task_ids, task_results_by_task_id)
@staticmethod
- def _prepare_state_and_info_by_task_dict(task_ids,
- task_results_by_task_id) -> Mapping[str, EventBufferValueType]:
+ def _prepare_state_and_info_by_task_dict(
+ task_ids, task_results_by_task_id
+ ) -> Mapping[str, EventBufferValueType]:
state_info: MutableMapping[str, EventBufferValueType] = {}
for task_id in task_ids:
task_result = task_results_by_task_id.get(task_id)
@@ -556,16 +557,16 @@ def _get_many_using_multiprocessing(self, async_results) -> Mapping[str, EventBu
chunksize = max(1, math.floor(math.ceil(1.0 * len(async_results) / self._sync_parallelism)))
task_id_to_states_and_info = sync_pool.map(
- fetch_celery_task_state,
- async_results,
- chunksize=chunksize)
+ fetch_celery_task_state, async_results, chunksize=chunksize
+ )
states_and_info_by_task_id: MutableMapping[str, EventBufferValueType] = {}
for task_id, state_or_exception, info in task_id_to_states_and_info:
if isinstance(state_or_exception, ExceptionWithTraceback):
self.log.error( # pylint: disable=logging-not-lazy
CELERY_FETCH_ERR_MSG_HEADER + ":%s\n%s\n",
- state_or_exception.exception, state_or_exception.traceback
+ state_or_exception.exception,
+ state_or_exception.traceback,
)
else:
states_and_info_by_task_id[task_id] = state_or_exception, info
diff --git a/airflow/executors/celery_kubernetes_executor.py b/airflow/executors/celery_kubernetes_executor.py
index 0c400f0910079..b58d579ee2df6 100644
--- a/airflow/executors/celery_kubernetes_executor.py
+++ b/airflow/executors/celery_kubernetes_executor.py
@@ -63,31 +63,30 @@ def queue_command(
task_instance: TaskInstance,
command: CommandType,
priority: int = 1,
- queue: Optional[str] = None
+ queue: Optional[str] = None,
):
"""Queues command via celery or kubernetes executor"""
executor = self._router(task_instance)
- self.log.debug(
- "Using executor: %s for %s", executor.__class__.__name__, task_instance.key
- )
+ self.log.debug("Using executor: %s for %s", executor.__class__.__name__, task_instance.key)
executor.queue_command(task_instance, command, priority, queue)
def queue_task_instance(
- self,
- task_instance: TaskInstance,
- mark_success: bool = False,
- pickle_id: Optional[str] = None,
- ignore_all_deps: bool = False,
- ignore_depends_on_past: bool = False,
- ignore_task_deps: bool = False,
- ignore_ti_state: bool = False,
- pool: Optional[str] = None,
- cfg_path: Optional[str] = None) -> None:
+ self,
+ task_instance: TaskInstance,
+ mark_success: bool = False,
+ pickle_id: Optional[str] = None,
+ ignore_all_deps: bool = False,
+ ignore_depends_on_past: bool = False,
+ ignore_task_deps: bool = False,
+ ignore_ti_state: bool = False,
+ pool: Optional[str] = None,
+ cfg_path: Optional[str] = None,
+ ) -> None:
"""Queues task instance via celery or kubernetes executor"""
executor = self._router(SimpleTaskInstance(task_instance))
- self.log.debug("Using executor: %s to queue_task_instance for %s",
- executor.__class__.__name__, task_instance.key
- )
+ self.log.debug(
+ "Using executor: %s to queue_task_instance for %s", executor.__class__.__name__, task_instance.key
+ )
executor.queue_task_instance(
task_instance,
mark_success,
@@ -97,7 +96,7 @@ def queue_task_instance(
ignore_task_deps,
ignore_ti_state,
pool,
- cfg_path
+ cfg_path,
)
def has_task(self, task_instance: TaskInstance) -> bool:
@@ -107,8 +106,9 @@ def has_task(self, task_instance: TaskInstance) -> bool:
:param task_instance: TaskInstance
:return: True if the task is known to this executor
"""
- return self.celery_executor.has_task(task_instance) \
- or self.kubernetes_executor.has_task(task_instance)
+ return self.celery_executor.has_task(task_instance) or self.kubernetes_executor.has_task(
+ task_instance
+ )
def heartbeat(self) -> None:
"""Heartbeat sent to trigger new jobs in celery and kubernetes executor"""
diff --git a/airflow/executors/dask_executor.py b/airflow/executors/dask_executor.py
index e2cb64fda9961..f26ad0ee7d4c2 100644
--- a/airflow/executors/dask_executor.py
+++ b/airflow/executors/dask_executor.py
@@ -65,11 +65,13 @@ def start(self) -> None:
self.client = Client(self.cluster_address, security=security)
self.futures = {}
- def execute_async(self,
- key: TaskInstanceKey,
- command: CommandType,
- queue: Optional[str] = None,
- executor_config: Optional[Any] = None) -> None:
+ def execute_async(
+ self,
+ key: TaskInstanceKey,
+ command: CommandType,
+ queue: Optional[str] = None,
+ executor_config: Optional[Any] = None,
+ ) -> None:
self.validate_command(command)
diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py
index 391b7c701e53a..580dc653957de 100644
--- a/airflow/executors/debug_executor.py
+++ b/airflow/executors/debug_executor.py
@@ -49,7 +49,7 @@ def __init__(self):
self.tasks_params: Dict[TaskInstanceKey, Dict[str, Any]] = {}
self.fail_fast = conf.getboolean("debug", "fail_fast")
- def execute_async(self, *args, **kwargs) -> None: # pylint: disable=signature-differs
+ def execute_async(self, *args, **kwargs) -> None: # pylint: disable=signature-differs
"""The method is replaced by custom trigger_task implementation."""
def sync(self) -> None:
@@ -63,9 +63,7 @@ def sync(self) -> None:
continue
if self._terminated.is_set():
- self.log.info(
- "Executor is terminated! Stopping %s to %s", ti.key, State.FAILED
- )
+ self.log.info("Executor is terminated! Stopping %s to %s", ti.key, State.FAILED)
ti.set_state(State.FAILED)
self.change_state(ti.key, State.FAILED)
continue
@@ -77,9 +75,7 @@ def _run_task(self, ti: TaskInstance) -> bool:
key = ti.key
try:
params = self.tasks_params.pop(ti.key, {})
- ti._run_raw_task( # pylint: disable=protected-access
- job_id=ti.job_id, **params
- )
+ ti._run_raw_task(job_id=ti.job_id, **params) # pylint: disable=protected-access
self.change_state(key, State.SUCCESS)
return True
except Exception as e: # pylint: disable=broad-except
diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py
index 6039a7adcf598..5ae3e66f80804 100644
--- a/airflow/executors/executor_loader.py
+++ b/airflow/executors/executor_loader.py
@@ -45,7 +45,7 @@ class ExecutorLoader:
CELERY_KUBERNETES_EXECUTOR: 'airflow.executors.celery_kubernetes_executor.CeleryKubernetesExecutor',
DASK_EXECUTOR: 'airflow.executors.dask_executor.DaskExecutor',
KUBERNETES_EXECUTOR: 'airflow.executors.kubernetes_executor.KubernetesExecutor',
- DEBUG_EXECUTOR: 'airflow.executors.debug_executor.DebugExecutor'
+ DEBUG_EXECUTOR: 'airflow.executors.debug_executor.DebugExecutor',
}
@classmethod
@@ -55,6 +55,7 @@ def get_default_executor(cls) -> BaseExecutor:
return cls._default_executor
from airflow.configuration import conf
+
executor_name = conf.get('core', 'EXECUTOR')
cls._default_executor = cls.load_executor(executor_name)
@@ -83,12 +84,14 @@ def load_executor(cls, executor_name: str) -> BaseExecutor:
if executor_name.count(".") == 1:
log.debug(
"The executor name looks like the plugin path (executor_name=%s). Trying to load a "
- "executor from a plugin", executor_name
+ "executor from a plugin",
+ executor_name,
)
with suppress(ImportError), suppress(AttributeError):
# Load plugins here for executors as at that time the plugins might not have been
# initialized yet
from airflow import plugins_manager
+
plugins_manager.integrate_executor_plugins()
return import_string(f"airflow.executors.{executor_name}")()
@@ -118,5 +121,5 @@ def __load_celery_kubernetes_executor(cls) -> BaseExecutor:
UNPICKLEABLE_EXECUTORS = (
ExecutorLoader.LOCAL_EXECUTOR,
ExecutorLoader.SEQUENTIAL_EXECUTOR,
- ExecutorLoader.DASK_EXECUTOR
+ ExecutorLoader.DASK_EXECUTOR,
)
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index 79b75b38b18a2..44d42a7ad846b 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -85,20 +85,18 @@ def __init__(self): # pylint: disable=too-many-statements
self.airflow_home = settings.AIRFLOW_HOME
self.dags_folder = conf.get(self.core_section, 'dags_folder')
self.parallelism = conf.getint(self.core_section, 'parallelism')
- self.pod_template_file = conf.get(self.kubernetes_section, 'pod_template_file',
- fallback=None)
+ self.pod_template_file = conf.get(self.kubernetes_section, 'pod_template_file', fallback=None)
- self.delete_worker_pods = conf.getboolean(
- self.kubernetes_section, 'delete_worker_pods')
+ self.delete_worker_pods = conf.getboolean(self.kubernetes_section, 'delete_worker_pods')
self.delete_worker_pods_on_failure = conf.getboolean(
- self.kubernetes_section, 'delete_worker_pods_on_failure')
+ self.kubernetes_section, 'delete_worker_pods_on_failure'
+ )
self.worker_pods_creation_batch_size = conf.getint(
- self.kubernetes_section, 'worker_pods_creation_batch_size')
+ self.kubernetes_section, 'worker_pods_creation_batch_size'
+ )
- self.worker_container_repository = conf.get(
- self.kubernetes_section, 'worker_container_repository')
- self.worker_container_tag = conf.get(
- self.kubernetes_section, 'worker_container_tag')
+ self.worker_container_repository = conf.get(self.kubernetes_section, 'worker_container_repository')
+ self.worker_container_tag = conf.get(self.kubernetes_section, 'worker_container_tag')
self.kube_image = f'{self.worker_container_repository}:{self.worker_container_tag}'
# The Kubernetes Namespace in which the Scheduler and Webserver reside. Note
@@ -116,10 +114,12 @@ def __init__(self): # pylint: disable=too-many-statements
kube_client_request_args = conf.get(self.kubernetes_section, 'kube_client_request_args')
if kube_client_request_args:
self.kube_client_request_args = json.loads(kube_client_request_args)
- if self.kube_client_request_args['_request_timeout'] and \
- isinstance(self.kube_client_request_args['_request_timeout'], list):
- self.kube_client_request_args['_request_timeout'] = \
- tuple(self.kube_client_request_args['_request_timeout'])
+ if self.kube_client_request_args['_request_timeout'] and isinstance(
+ self.kube_client_request_args['_request_timeout'], list
+ ):
+ self.kube_client_request_args['_request_timeout'] = tuple(
+ self.kube_client_request_args['_request_timeout']
+ )
else:
self.kube_client_request_args = {}
delete_option_kwargs = conf.get(self.kubernetes_section, 'delete_option_kwargs')
@@ -141,13 +141,15 @@ def _get_security_context_val(self, scontext: str) -> Union[str, int]:
class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin):
"""Watches for Kubernetes jobs"""
- def __init__(self,
- namespace: Optional[str],
- multi_namespace_mode: bool,
- watcher_queue: 'Queue[KubernetesWatchType]',
- resource_version: Optional[str],
- scheduler_job_id: Optional[str],
- kube_config: Configuration):
+ def __init__(
+ self,
+ namespace: Optional[str],
+ multi_namespace_mode: bool,
+ watcher_queue: 'Queue[KubernetesWatchType]',
+ resource_version: Optional[str],
+ scheduler_job_id: Optional[str],
+ kube_config: Configuration,
+ ):
super().__init__()
self.namespace = namespace
self.multi_namespace_mode = multi_namespace_mode
@@ -163,28 +165,31 @@ def run(self) -> None:
raise AirflowException(NOT_STARTED_MESSAGE)
while True:
try:
- self.resource_version = self._run(kube_client, self.resource_version,
- self.scheduler_job_id, self.kube_config)
+ self.resource_version = self._run(
+ kube_client, self.resource_version, self.scheduler_job_id, self.kube_config
+ )
except ReadTimeoutError:
- self.log.warning("There was a timeout error accessing the Kube API. "
- "Retrying request.", exc_info=True)
+ self.log.warning(
+ "There was a timeout error accessing the Kube API. " "Retrying request.", exc_info=True
+ )
time.sleep(1)
except Exception:
self.log.exception('Unknown error in KubernetesJobWatcher. Failing')
raise
else:
- self.log.warning('Watch died gracefully, starting back up with: '
- 'last resource_version: %s', self.resource_version)
-
- def _run(self,
- kube_client: client.CoreV1Api,
- resource_version: Optional[str],
- scheduler_job_id: str,
- kube_config: Any) -> Optional[str]:
- self.log.info(
- 'Event: and now my watch begins starting at resource_version: %s',
- resource_version
- )
+ self.log.warning(
+ 'Watch died gracefully, starting back up with: ' 'last resource_version: %s',
+ self.resource_version,
+ )
+
+ def _run(
+ self,
+ kube_client: client.CoreV1Api,
+ resource_version: Optional[str],
+ scheduler_job_id: str,
+ kube_config: Any,
+ ) -> Optional[str]:
+ self.log.info('Event: and now my watch begins starting at resource_version: %s', resource_version)
watcher = watch.Watch()
kwargs = {'label_selector': f'airflow-worker={scheduler_job_id}'}
@@ -196,20 +201,16 @@ def _run(self,
last_resource_version: Optional[str] = None
if self.multi_namespace_mode:
- list_worker_pods = functools.partial(watcher.stream,
- kube_client.list_pod_for_all_namespaces,
- **kwargs)
+ list_worker_pods = functools.partial(
+ watcher.stream, kube_client.list_pod_for_all_namespaces, **kwargs
+ )
else:
- list_worker_pods = functools.partial(watcher.stream,
- kube_client.list_namespaced_pod,
- self.namespace,
- **kwargs)
+ list_worker_pods = functools.partial(
+ watcher.stream, kube_client.list_namespaced_pod, self.namespace, **kwargs
+ )
for event in list_worker_pods():
task = event['object']
- self.log.info(
- 'Event: %s had an event of type %s',
- task.metadata.name, event['type']
- )
+ self.log.info('Event: %s had an event of type %s', task.metadata.name, event['type'])
if event['type'] == 'ERROR':
return self.process_error(event)
annotations = task.metadata.annotations
@@ -234,29 +235,28 @@ def _run(self,
def process_error(self, event: Any) -> str:
"""Process error response"""
- self.log.error(
- 'Encountered Error response from k8s list namespaced pod stream => %s',
- event
- )
+ self.log.error('Encountered Error response from k8s list namespaced pod stream => %s', event)
raw_object = event['raw_object']
if raw_object['code'] == 410:
self.log.info(
- 'Kubernetes resource version is too old, must reset to 0 => %s',
- (raw_object['message'],)
+ 'Kubernetes resource version is too old, must reset to 0 => %s', (raw_object['message'],)
)
# Return resource version 0
return '0'
raise AirflowException(
- 'Kubernetes failure for %s with code %s and message: %s' %
- (raw_object['reason'], raw_object['code'], raw_object['message'])
+ 'Kubernetes failure for %s with code %s and message: %s'
+ % (raw_object['reason'], raw_object['code'], raw_object['message'])
)
- def process_status(self, pod_id: str,
- namespace: str,
- status: str,
- annotations: Dict[str, str],
- resource_version: str,
- event: Any) -> None:
+ def process_status(
+ self,
+ pod_id: str,
+ namespace: str,
+ status: str,
+ annotations: Dict[str, str],
+ resource_version: str,
+ event: Any,
+ ) -> None:
"""Process status response"""
if status == 'Pending':
if event['type'] == 'DELETED':
@@ -277,19 +277,26 @@ def process_status(self, pod_id: str,
else:
self.log.warning(
'Event: Invalid state: %s on pod: %s in namespace %s with annotations: %s with '
- 'resource_version: %s', status, pod_id, namespace, annotations, resource_version
+ 'resource_version: %s',
+ status,
+ pod_id,
+ namespace,
+ annotations,
+ resource_version,
)
class AirflowKubernetesScheduler(LoggingMixin):
"""Airflow Scheduler for Kubernetes"""
- def __init__(self,
- kube_config: Any,
- task_queue: 'Queue[KubernetesJobType]',
- result_queue: 'Queue[KubernetesResultsType]',
- kube_client: client.CoreV1Api,
- scheduler_job_id: str):
+ def __init__(
+ self,
+ kube_config: Any,
+ task_queue: 'Queue[KubernetesJobType]',
+ result_queue: 'Queue[KubernetesResultsType]',
+ kube_client: client.CoreV1Api,
+ scheduler_job_id: str,
+ ):
super().__init__()
self.log.debug("Creating Kubernetes executor")
self.kube_config = kube_config
@@ -306,12 +313,14 @@ def __init__(self,
def _make_kube_watcher(self) -> KubernetesJobWatcher:
resource_version = ResourceVersion().resource_version
- watcher = KubernetesJobWatcher(watcher_queue=self.watcher_queue,
- namespace=self.kube_config.kube_namespace,
- multi_namespace_mode=self.kube_config.multi_namespace_mode,
- resource_version=resource_version,
- scheduler_job_id=self.scheduler_job_id,
- kube_config=self.kube_config)
+ watcher = KubernetesJobWatcher(
+ watcher_queue=self.watcher_queue,
+ namespace=self.kube_config.kube_namespace,
+ multi_namespace_mode=self.kube_config.multi_namespace_mode,
+ resource_version=resource_version,
+ scheduler_job_id=self.scheduler_job_id,
+ kube_config=self.kube_config,
+ )
watcher.start()
return watcher
@@ -320,8 +329,8 @@ def _health_check_kube_watcher(self):
self.log.debug("KubeJobWatcher alive, continuing")
else:
self.log.error(
- 'Error while health checking kube watcher process. '
- 'Process died for unknown reasons')
+ 'Error while health checking kube watcher process. ' 'Process died for unknown reasons'
+ )
self.kube_watcher = self._make_kube_watcher()
def run_next(self, next_job: KubernetesJobType) -> None:
@@ -340,8 +349,9 @@ def run_next(self, next_job: KubernetesJobType) -> None:
base_worker_pod = PodGenerator.deserialize_model_file(self.kube_config.pod_template_file)
if not base_worker_pod:
- raise AirflowException("could not find a valid worker template yaml at {}"
- .format(self.kube_config.pod_template_file))
+ raise AirflowException(
+ f"could not find a valid worker template yaml at {self.kube_config.pod_template_file}"
+ )
pod = PodGenerator.construct_pod(
namespace=self.namespace,
@@ -354,7 +364,7 @@ def run_next(self, next_job: KubernetesJobType) -> None:
date=execution_date,
command=command,
pod_override_object=kube_executor_config,
- base_worker_pod=base_worker_pod
+ base_worker_pod=base_worker_pod,
)
# Reconcile the pod generated by the Operator and the Pod
# generated by the .cfg file
@@ -370,8 +380,11 @@ def delete_pod(self, pod_id: str, namespace: str) -> None:
try:
self.log.debug("Deleting pod %s in namespace %s", pod_id, namespace)
self.kube_client.delete_namespaced_pod(
- pod_id, namespace, body=client.V1DeleteOptions(**self.kube_config.delete_option_kwargs),
- **self.kube_config.kube_client_request_args)
+ pod_id,
+ namespace,
+ body=client.V1DeleteOptions(**self.kube_config.delete_option_kwargs),
+ **self.kube_config.kube_client_request_args,
+ )
except ApiException as e:
# If the pod is already deleted
if e.status != 404:
@@ -403,8 +416,7 @@ def process_watcher_task(self, task: KubernetesWatchType) -> None:
"""Process the task by watcher."""
pod_id, namespace, state, annotations, resource_version = task
self.log.info(
- 'Attempting to finish pod; pod_id: %s; state: %s; annotations: %s',
- pod_id, state, annotations
+ 'Attempting to finish pod; pod_id: %s; state: %s; annotations: %s', pod_id, state, annotations
)
key = self._annotations_to_key(annotations=annotations)
if key:
@@ -434,7 +446,7 @@ def _make_safe_pod_id(safe_dag_id: str, safe_task_id: str, safe_uuid: str) -> st
"""
safe_key = safe_dag_id + safe_task_id
- safe_pod_id = safe_key[:MAX_POD_ID_LEN - len(safe_uuid) - 1] + "-" + safe_uuid
+ safe_pod_id = safe_key[: MAX_POD_ID_LEN - len(safe_uuid) - 1] + "-" + safe_uuid
return safe_pod_id
@@ -526,67 +538,60 @@ def clear_not_launched_queued_tasks(self, session=None) -> None:
self.log.debug("Clearing tasks that have not been launched")
if not self.kube_client:
raise AirflowException(NOT_STARTED_MESSAGE)
- queued_tasks = session \
- .query(TaskInstance) \
- .filter(TaskInstance.state == State.QUEUED).all()
- self.log.info(
- 'When executor started up, found %s queued task instances',
- len(queued_tasks)
- )
+ queued_tasks = session.query(TaskInstance).filter(TaskInstance.state == State.QUEUED).all()
+ self.log.info('When executor started up, found %s queued task instances', len(queued_tasks))
for task in queued_tasks:
# pylint: disable=protected-access
self.log.debug("Checking task %s", task)
- dict_string = (
- "dag_id={},task_id={},execution_date={},airflow-worker={}".format(
- pod_generator.make_safe_label_value(task.dag_id),
- pod_generator.make_safe_label_value(task.task_id),
- pod_generator.datetime_to_label_safe_datestring(
- task.execution_date
- ),
- self.scheduler_job_id
- )
+ dict_string = "dag_id={},task_id={},execution_date={},airflow-worker={}".format(
+ pod_generator.make_safe_label_value(task.dag_id),
+ pod_generator.make_safe_label_value(task.task_id),
+ pod_generator.datetime_to_label_safe_datestring(task.execution_date),
+ self.scheduler_job_id,
)
# pylint: enable=protected-access
kwargs = dict(label_selector=dict_string)
if self.kube_config.kube_client_request_args:
for key, value in self.kube_config.kube_client_request_args.items():
kwargs[key] = value
- pod_list = self.kube_client.list_namespaced_pod(
- self.kube_config.kube_namespace, **kwargs)
+ pod_list = self.kube_client.list_namespaced_pod(self.kube_config.kube_namespace, **kwargs)
if not pod_list.items:
self.log.info(
- 'TaskInstance: %s found in queued state but was not launched, '
- 'rescheduling', task
+ 'TaskInstance: %s found in queued state but was not launched, ' 'rescheduling', task
)
session.query(TaskInstance).filter(
TaskInstance.dag_id == task.dag_id,
TaskInstance.task_id == task.task_id,
- TaskInstance.execution_date == task.execution_date
+ TaskInstance.execution_date == task.execution_date,
).update({TaskInstance.state: State.NONE})
def _inject_secrets(self) -> None:
def _create_or_update_secret(secret_name, secret_path):
try:
return self.kube_client.create_namespaced_secret(
- self.kube_config.executor_namespace, kubernetes.client.V1Secret(
- data={
- 'key.json': base64.b64encode(open(secret_path).read())},
- metadata=kubernetes.client.V1ObjectMeta(name=secret_name)),
- **self.kube_config.kube_client_request_args)
+ self.kube_config.executor_namespace,
+ kubernetes.client.V1Secret(
+ data={'key.json': base64.b64encode(open(secret_path).read())},
+ metadata=kubernetes.client.V1ObjectMeta(name=secret_name),
+ ),
+ **self.kube_config.kube_client_request_args,
+ )
except ApiException as e:
if e.status == 409:
return self.kube_client.replace_namespaced_secret(
- secret_name, self.kube_config.executor_namespace,
+ secret_name,
+ self.kube_config.executor_namespace,
kubernetes.client.V1Secret(
- data={'key.json': base64.b64encode(
- open(secret_path).read())},
- metadata=kubernetes.client.V1ObjectMeta(name=secret_name)),
- **self.kube_config.kube_client_request_args)
+ data={'key.json': base64.b64encode(open(secret_path).read())},
+ metadata=kubernetes.client.V1ObjectMeta(name=secret_name),
+ ),
+ **self.kube_config.kube_client_request_args,
+ )
self.log.exception(
- 'Exception while trying to inject secret. '
- 'Secret name: %s, error details: %s',
- secret_name, e
+ 'Exception while trying to inject secret. ' 'Secret name: %s, error details: %s',
+ secret_name,
+ e,
)
raise
@@ -599,22 +604,20 @@ def start(self) -> None:
self.log.debug('Start with scheduler_job_id: %s', self.scheduler_job_id)
self.kube_client = get_kube_client()
self.kube_scheduler = AirflowKubernetesScheduler(
- self.kube_config, self.task_queue, self.result_queue,
- self.kube_client, self.scheduler_job_id
+ self.kube_config, self.task_queue, self.result_queue, self.kube_client, self.scheduler_job_id
)
self._inject_secrets()
self.clear_not_launched_queued_tasks()
- def execute_async(self,
- key: TaskInstanceKey,
- command: CommandType,
- queue: Optional[str] = None,
- executor_config: Optional[Any] = None) -> None:
+ def execute_async(
+ self,
+ key: TaskInstanceKey,
+ command: CommandType,
+ queue: Optional[str] = None,
+ executor_config: Optional[Any] = None,
+ ) -> None:
"""Executes task asynchronously"""
- self.log.info(
- 'Add task %s with command %s with executor_config %s',
- key, command, executor_config
- )
+ self.log.info('Add task %s with command %s with executor_config %s', key, command, executor_config)
kube_executor_config = PodGenerator.from_obj(executor_config)
if not self.task_queue:
raise AirflowException(NOT_STARTED_MESSAGE)
@@ -652,7 +655,9 @@ def sync(self) -> None:
except Exception as e: # pylint: disable=broad-except
self.log.exception(
"Exception: %s when attempting to change state of %s to %s, re-queueing.",
- e, results, state
+ e,
+ results,
+ state,
)
self.result_queue.put(results)
finally:
@@ -675,8 +680,10 @@ def sync(self) -> None:
key, _, _ = task
self.change_state(key, State.FAILED, e)
else:
- self.log.warning('ApiException when attempting to run task, re-queueing. '
- 'Message: %s', json.loads(e.body)['message'])
+ self.log.warning(
+ 'ApiException when attempting to run task, re-queueing. ' 'Message: %s',
+ json.loads(e.body)['message'],
+ )
self.task_queue.put(task)
finally:
self.task_queue.task_done()
@@ -684,11 +691,7 @@ def sync(self) -> None:
break
# pylint: enable=too-many-nested-blocks
- def _change_state(self,
- key: TaskInstanceKey,
- state: Optional[str],
- pod_id: str,
- namespace: str) -> None:
+ def _change_state(self, key: TaskInstanceKey, state: Optional[str], pod_id: str, namespace: str) -> None:
if state != State.RUNNING:
if self.kube_config.delete_worker_pods:
if not self.kube_scheduler:
@@ -706,18 +709,12 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance
tis_to_flush = [ti for ti in tis if not ti.external_executor_id]
scheduler_job_ids = [ti.external_executor_id for ti in tis]
pod_ids = {
- create_pod_id(dag_id=ti.dag_id, task_id=ti.task_id): ti
- for ti in tis if ti.external_executor_id
+ create_pod_id(dag_id=ti.dag_id, task_id=ti.task_id): ti for ti in tis if ti.external_executor_id
}
kube_client: client.CoreV1Api = self.kube_client
for scheduler_job_id in scheduler_job_ids:
- kwargs = {
- 'label_selector': f'airflow-worker={scheduler_job_id}'
- }
- pod_list = kube_client.list_namespaced_pod(
- namespace=self.kube_config.kube_namespace,
- **kwargs
- )
+ kwargs = {'label_selector': f'airflow-worker={scheduler_job_id}'}
+ pod_list = kube_client.list_namespaced_pod(namespace=self.kube_config.kube_namespace, **kwargs)
for pod in pod_list.items:
self.adopt_launched_task(kube_client, pod, pod_ids)
self._adopt_completed_pods(kube_client)
@@ -738,8 +735,11 @@ def adopt_launched_task(self, kube_client, pod, pod_ids: dict):
task_id = pod.metadata.labels['task_id']
pod_id = create_pod_id(dag_id=dag_id, task_id=task_id)
if pod_id not in pod_ids:
- self.log.error("attempting to adopt task %s in dag %s"
- " which was not specified by database", task_id, dag_id)
+ self.log.error(
+ "attempting to adopt task %s in dag %s" " which was not specified by database",
+ task_id,
+ dag_id,
+ )
else:
try:
kube_client.patch_namespaced_pod(
@@ -798,13 +798,18 @@ def _flush_result_queue(self) -> None:
self.log.warning('Executor shutting down, flushing results=%s', results)
try:
key, state, pod_id, namespace, resource_version = results
- self.log.info('Changing state of %s to %s : resource_version=%d', results, state,
- resource_version)
+ self.log.info(
+ 'Changing state of %s to %s : resource_version=%d', results, state, resource_version
+ )
try:
self._change_state(key, state, pod_id, namespace)
except Exception as e: # pylint: disable=broad-except
- self.log.exception('Ignoring exception: %s when attempting to change state of %s '
- 'to %s.', e, results, state)
+ self.log.exception(
+ 'Ignoring exception: %s when attempting to change state of %s ' 'to %s.',
+ e,
+ results,
+ state,
+ )
finally:
self.result_queue.task_done()
except Empty:
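A note on the joined string pairs above, such as 'There was a timeout error accessing the Kube API. ' 'Retrying request.' (an aside, not part of the patch): Black re-wraps the call but does not merge adjacent string literals, so two literals that used to sit on separate lines can end up side by side. Python concatenates adjacent literals at compile time, so the logged message is unchanged. A minimal sketch:

    wrapped = (
        "There was a timeout error accessing the Kube API. "
        "Retrying request."
    )
    joined = "There was a timeout error accessing the Kube API. " "Retrying request."
    assert wrapped == joined  # identical string; only the source layout differs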
diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py
index 4ad4033666276..729a45e8b1e43 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -36,7 +36,8 @@
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import NOT_STARTED_MESSAGE, PARALLELISM, BaseExecutor, CommandType
from airflow.models.taskinstance import ( # pylint: disable=unused-import # noqa: F401
- TaskInstanceKey, TaskInstanceStateType,
+ TaskInstanceKey,
+ TaskInstanceStateType,
)
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
@@ -100,6 +101,7 @@ def _execute_work_in_fork(self, command: CommandType) -> str:
return State.SUCCESS if ret == 0 else State.FAILED
from airflow.sentry import Sentry
+
ret = 1
try:
import signal
@@ -140,10 +142,9 @@ class LocalWorker(LocalWorkerBase):
:param command: Command to execute
"""
- def __init__(self,
- result_queue: 'Queue[TaskInstanceStateType]',
- key: TaskInstanceKey,
- command: CommandType):
+ def __init__(
+ self, result_queue: 'Queue[TaskInstanceStateType]', key: TaskInstanceKey, command: CommandType
+ ):
super().__init__(result_queue)
self.key: TaskInstanceKey = key
self.command: CommandType = command
@@ -162,9 +163,7 @@ class QueuedLocalWorker(LocalWorkerBase):
:param result_queue: queue where worker puts results after finishing tasks
"""
- def __init__(self,
- task_queue: 'Queue[ExecutorWorkType]',
- result_queue: 'Queue[TaskInstanceStateType]'):
+ def __init__(self, task_queue: 'Queue[ExecutorWorkType]', result_queue: 'Queue[TaskInstanceStateType]'):
super().__init__(result_queue=result_queue)
self.task_queue = task_queue
@@ -196,8 +195,9 @@ def __init__(self, parallelism: int = PARALLELISM):
self.workers: List[QueuedLocalWorker] = []
self.workers_used: int = 0
self.workers_active: int = 0
- self.impl: Optional[Union['LocalExecutor.UnlimitedParallelism',
- 'LocalExecutor.LimitedParallelism']] = None
+ self.impl: Optional[
+ Union['LocalExecutor.UnlimitedParallelism', 'LocalExecutor.LimitedParallelism']
+ ] = None
class UnlimitedParallelism:
"""
@@ -216,11 +216,13 @@ def start(self) -> None:
self.executor.workers_active = 0
# pylint: disable=unused-argument # pragma: no cover
- def execute_async(self,
- key: TaskInstanceKey,
- command: CommandType,
- queue: Optional[str] = None,
- executor_config: Optional[Any] = None) -> None:
+ def execute_async(
+ self,
+ key: TaskInstanceKey,
+ command: CommandType,
+ queue: Optional[str] = None,
+ executor_config: Optional[Any] = None,
+ ) -> None:
"""
Executes task asynchronously.
@@ -289,7 +291,7 @@ def execute_async(
key: TaskInstanceKey,
command: CommandType,
queue: Optional[str] = None, # pylint: disable=unused-argument
- executor_config: Optional[Any] = None # pylint: disable=unused-argument
+ executor_config: Optional[Any] = None, # pylint: disable=unused-argument
) -> None:
"""
Executes task asynchronously.
@@ -331,15 +333,21 @@ def start(self) -> None:
self.workers = []
self.workers_used = 0
self.workers_active = 0
- self.impl = (LocalExecutor.UnlimitedParallelism(self) if self.parallelism == 0
- else LocalExecutor.LimitedParallelism(self))
+ self.impl = (
+ LocalExecutor.UnlimitedParallelism(self)
+ if self.parallelism == 0
+ else LocalExecutor.LimitedParallelism(self)
+ )
self.impl.start()
- def execute_async(self, key: TaskInstanceKey,
- command: CommandType,
- queue: Optional[str] = None,
- executor_config: Optional[Any] = None) -> None:
+ def execute_async(
+ self,
+ key: TaskInstanceKey,
+ command: CommandType,
+ queue: Optional[str] = None,
+ executor_config: Optional[Any] = None,
+ ) -> None:
"""Execute asynchronously."""
if not self.impl:
raise AirflowException(NOT_STARTED_MESSAGE)
diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py
index 18a4747790c45..456e3e9893e8b 100644
--- a/airflow/executors/sequential_executor.py
+++ b/airflow/executors/sequential_executor.py
@@ -44,11 +44,13 @@ def __init__(self):
super().__init__()
self.commands_to_run = []
- def execute_async(self,
- key: TaskInstanceKey,
- command: CommandType,
- queue: Optional[str] = None,
- executor_config: Optional[Any] = None) -> None:
+ def execute_async(
+ self,
+ key: TaskInstanceKey,
+ command: CommandType,
+ queue: Optional[str] = None,
+ executor_config: Optional[Any] = None,
+ ) -> None:
self.validate_command(command)
self.commands_to_run.append((key, command))
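The repeated execute_async rewrites above show the shape Black falls back to when a signature cannot fit on one wrapped line: each parameter on its own line, a trailing comma after the last one, and the closing parenthesis dedented back to the def. The sketch below is a hand-written illustration of that shape with toy types, not output of the tool (Black would collapse a signature this short back onto a single line):

    from typing import Any, Optional

    def execute_async(
        key: str,
        command: str,
        queue: Optional[str] = None,
        executor_config: Optional[Any] = None,
    ) -> None:
        """Toy stand-in for the executors' method; it only echoes its arguments."""
        print(key, command, queue, executor_config)

    execute_async("ti_key", "airflow tasks run ...")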
diff --git a/airflow/hooks/base_hook.py b/airflow/hooks/base_hook.py
index 9a54077519ea1..cb9f6dcab4345 100644
--- a/airflow/hooks/base_hook.py
+++ b/airflow/hooks/base_hook.py
@@ -64,7 +64,7 @@ def get_connection(cls, conn_id: str) -> Connection:
conn.schema,
conn.login,
"XXXXXXXX" if conn.password else None,
- "XXXXXXXX" if conn.extra_dejson else None
+ "XXXXXXXX" if conn.extra_dejson else None,
)
return conn
diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py
index fde2463409df0..17a4c81e42f8b 100644
--- a/airflow/hooks/dbapi_hook.py
+++ b/airflow/hooks/dbapi_hook.py
@@ -67,11 +67,7 @@ def __init__(self, *args, **kwargs):
def get_conn(self):
"""Returns a connection object"""
db = self.get_connection(getattr(self, self.conn_name_attr))
- return self.connector.connect(
- host=db.host,
- port=db.port,
- username=db.login,
- schema=db.schema)
+ return self.connector.connect(host=db.host, port=db.port, username=db.login, schema=db.schema)
def get_uri(self) -> str:
"""
@@ -86,8 +82,7 @@ def get_uri(self) -> str:
host = conn.host
if conn.port is not None:
host += f':{conn.port}'
- uri = '{conn.conn_type}://{login}{host}/'.format(
- conn=conn, login=login, host=host)
+ uri = f'{conn.conn_type}://{login}{host}/'
if conn.schema:
uri += conn.schema
return uri
@@ -199,7 +194,7 @@ def set_autocommit(self, conn, autocommit):
if not self.supports_autocommit and autocommit:
self.log.warning(
"%s connection doesn't support autocommit but autocommit activated.",
- getattr(self, self.conn_name_attr)
+ getattr(self, self.conn_name_attr),
)
conn.autocommit = autocommit
@@ -238,7 +233,9 @@ def _generate_insert_sql(table, values, target_fields, replace, **kwargs):
:return: The generated INSERT or REPLACE SQL statement
:rtype: str
"""
- placeholders = ["%s", ] * len(values)
+ placeholders = [
+ "%s",
+ ] * len(values)
if target_fields:
target_fields = ", ".join(target_fields)
@@ -250,14 +247,10 @@ def _generate_insert_sql(table, values, target_fields, replace, **kwargs):
sql = "INSERT INTO "
else:
sql = "REPLACE INTO "
- sql += "{} {} VALUES ({})".format(
- table,
- target_fields,
- ",".join(placeholders))
+ sql += "{} {} VALUES ({})".format(table, target_fields, ",".join(placeholders))
return sql
- def insert_rows(self, table, rows, target_fields=None, commit_every=1000,
- replace=False, **kwargs):
+ def insert_rows(self, table, rows, target_fields=None, commit_every=1000, replace=False, **kwargs):
"""
A generic way to insert a set of tuples into a table,
a new transaction is created every commit_every rows
@@ -287,15 +280,11 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=1000,
for cell in row:
lst.append(self._serialize_cell(cell, conn))
values = tuple(lst)
- sql = self._generate_insert_sql(
- table, values, target_fields, replace, **kwargs
- )
+ sql = self._generate_insert_sql(table, values, target_fields, replace, **kwargs)
cur.execute(sql, values)
if commit_every and i % commit_every == 0:
conn.commit()
- self.log.info(
- "Loaded %s rows into %s so far", i, table
- )
+ self.log.info("Loaded %s rows into %s so far", i, table)
conn.commit()
self.log.info("Done loading. Loaded a total of %s rows", i)
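One quirk worth flagging in the _generate_insert_sql hunk above (an aside, not part of the patch): the old ["%s", ] carried a trailing comma, and Black 20.8b1 treats a pre-existing trailing comma as a request to keep the collection exploded, hence the three-line list. Both spellings build the same list, and dropping the comma would let Black keep it on one line:

    values = ("a", "b", "c")  # toy stand-in for a row being inserted
    exploded = [
        "%s",
    ] * len(values)
    compact = ["%s"] * len(values)
    assert exploded == compact == ["%s", "%s", "%s"]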
diff --git a/airflow/hooks/docker_hook.py b/airflow/hooks/docker_hook.py
index b7da340770da6..ff02d3c1978c9 100644
--- a/airflow/hooks/docker_hook.py
+++ b/airflow/hooks/docker_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.docker.hooks.docker`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/druid_hook.py b/airflow/hooks/druid_hook.py
index 11f7d6c4429b3..9ae87b01ef352 100644
--- a/airflow/hooks/druid_hook.py
+++ b/airflow/hooks/druid_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.druid.hooks.druid`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/hdfs_hook.py b/airflow/hooks/hdfs_hook.py
index 10e81c5c6d39e..9b0cdb66d311b 100644
--- a/airflow/hooks/hdfs_hook.py
+++ b/airflow/hooks/hdfs_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hdfs.hooks.hdfs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py
index 990ee470430a6..5009647593e49 100644
--- a/airflow/hooks/hive_hooks.py
+++ b/airflow/hooks/hive_hooks.py
@@ -21,10 +21,14 @@
# pylint: disable=unused-import
from airflow.providers.apache.hive.hooks.hive import ( # noqa
- HIVE_QUEUE_PRIORITIES, HiveCliHook, HiveMetastoreHook, HiveServer2Hook,
+ HIVE_QUEUE_PRIORITIES,
+ HiveCliHook,
+ HiveMetastoreHook,
+ HiveServer2Hook,
)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hive.hooks.hive`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/http_hook.py b/airflow/hooks/http_hook.py
index 0bc5203996b6e..80fed8f0d94e3 100644
--- a/airflow/hooks/http_hook.py
+++ b/airflow/hooks/http_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.http.hooks.http`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/jdbc_hook.py b/airflow/hooks/jdbc_hook.py
index 4ae5efca20162..1f7e127b98844 100644
--- a/airflow/hooks/jdbc_hook.py
+++ b/airflow/hooks/jdbc_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.jdbc.hooks.jdbc`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/mssql_hook.py b/airflow/hooks/mssql_hook.py
index 2584ea76ee73f..cc9bc848bcf72 100644
--- a/airflow/hooks/mssql_hook.py
+++ b/airflow/hooks/mssql_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.mssql.hooks.mssql`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/mysql_hook.py b/airflow/hooks/mysql_hook.py
index a077a535a9aa1..97e3669613cc3 100644
--- a/airflow/hooks/mysql_hook.py
+++ b/airflow/hooks/mysql_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.mysql.hooks.mysql`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/oracle_hook.py b/airflow/hooks/oracle_hook.py
index e820ee3bc1aa6..202e9fe9f4f0f 100644
--- a/airflow/hooks/oracle_hook.py
+++ b/airflow/hooks/oracle_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.oracle.hooks.oracle`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/pig_hook.py b/airflow/hooks/pig_hook.py
index c750c87615f1f..9035a262263cc 100644
--- a/airflow/hooks/pig_hook.py
+++ b/airflow/hooks/pig_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.pig.hooks.pig`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/postgres_hook.py b/airflow/hooks/postgres_hook.py
index e72d517da2c2b..1202a3f144d40 100644
--- a/airflow/hooks/postgres_hook.py
+++ b/airflow/hooks/postgres_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.postgres.hooks.postgres`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/presto_hook.py b/airflow/hooks/presto_hook.py
index 75358eecb4b02..2daca69966954 100644
--- a/airflow/hooks/presto_hook.py
+++ b/airflow/hooks/presto_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.presto.hooks.presto`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/samba_hook.py b/airflow/hooks/samba_hook.py
index 72c1af508804c..715655dca0ac6 100644
--- a/airflow/hooks/samba_hook.py
+++ b/airflow/hooks/samba_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.samba.hooks.samba`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/slack_hook.py b/airflow/hooks/slack_hook.py
index de2b6720f71da..c0c9dadbc0f8f 100644
--- a/airflow/hooks/slack_hook.py
+++ b/airflow/hooks/slack_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.slack.hooks.slack`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/sqlite_hook.py b/airflow/hooks/sqlite_hook.py
index 9b770c77aabcd..e8bdd64d47013 100644
--- a/airflow/hooks/sqlite_hook.py
+++ b/airflow/hooks/sqlite_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.sqlite.hooks.sqlite`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/webhdfs_hook.py b/airflow/hooks/webhdfs_hook.py
index 6c134215c6987..3d4291a2ab60f 100644
--- a/airflow/hooks/webhdfs_hook.py
+++ b/airflow/hooks/webhdfs_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hdfs.hooks.webhdfs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/hooks/zendesk_hook.py b/airflow/hooks/zendesk_hook.py
index db20fe0b03711..d5a5e659f5770 100644
--- a/airflow/hooks/zendesk_hook.py
+++ b/airflow/hooks/zendesk_hook.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.zendesk.hooks.zendesk`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py
index ad7ba436e5f6b..63a711aa22e84 100644
--- a/airflow/jobs/backfill_job.py
+++ b/airflow/jobs/backfill_job.py
@@ -28,7 +28,11 @@
from airflow import models
from airflow.exceptions import (
- AirflowException, BackfillUnfinished, DagConcurrencyLimitReached, NoAvailablePoolSlot, PoolNotFound,
+ AirflowException,
+ BackfillUnfinished,
+ DagConcurrencyLimitReached,
+ NoAvailablePoolSlot,
+ PoolNotFound,
TaskConcurrencyLimitReached,
)
from airflow.executors.executor_loader import ExecutorLoader
@@ -54,9 +58,7 @@ class BackfillJob(BaseJob):
STATES_COUNT_AS_RUNNING = (State.RUNNING, State.QUEUED)
- __mapper_args__ = {
- 'polymorphic_identity': 'BackfillJob'
- }
+ __mapper_args__ = {'polymorphic_identity': 'BackfillJob'}
class _DagRunTaskStatus:
"""
@@ -93,19 +95,20 @@ class _DagRunTaskStatus:
"""
# TODO(edgarRd): AIRFLOW-1444: Add consistency check on counts
- def __init__(self, # pylint: disable=too-many-arguments
- to_run=None,
- running=None,
- skipped=None,
- succeeded=None,
- failed=None,
- not_ready=None,
- deadlocked=None,
- active_runs=None,
- executed_dag_run_dates=None,
- finished_runs=0,
- total_runs=0,
- ):
+ def __init__( # pylint: disable=too-many-arguments
+ self,
+ to_run=None,
+ running=None,
+ skipped=None,
+ succeeded=None,
+ failed=None,
+ not_ready=None,
+ deadlocked=None,
+ active_runs=None,
+ executed_dag_run_dates=None,
+ finished_runs=0,
+ total_runs=0,
+ ):
self.to_run = to_run or OrderedDict()
self.running = running or {}
self.skipped = skipped or set()
@@ -119,21 +122,23 @@ def __init__(self, # pylint: disable=too-many-arguments
self.total_runs = total_runs
def __init__( # pylint: disable=too-many-arguments
- self,
- dag,
- start_date=None,
- end_date=None,
- mark_success=False,
- donot_pickle=False,
- ignore_first_depends_on_past=False,
- ignore_task_deps=False,
- pool=None,
- delay_on_limit_secs=1.0,
- verbose=False,
- conf=None,
- rerun_failed_tasks=False,
- run_backwards=False,
- *args, **kwargs):
+ self,
+ dag,
+ start_date=None,
+ end_date=None,
+ mark_success=False,
+ donot_pickle=False,
+ ignore_first_depends_on_past=False,
+ ignore_task_deps=False,
+ pool=None,
+ delay_on_limit_secs=1.0,
+ verbose=False,
+ conf=None,
+ rerun_failed_tasks=False,
+ run_backwards=False,
+ *args,
+ **kwargs,
+ ):
"""
:param dag: DAG object.
:type dag: airflow.models.DAG
@@ -234,7 +239,7 @@ def _update_counters(self, ti_status, session=None):
self.log.warning(
"FIXME: task instance %s state was set to none externally or "
"reaching concurrency limits. Re-adding task to queue.",
- ti
+ ti,
)
tis_to_be_scheduled.append(ti)
ti_status.running.pop(key)
@@ -260,10 +265,7 @@ def _manage_executor_state(self, running):
for key, value in list(executor.get_event_buffer().items()):
state, info = value
if key not in running:
- self.log.warning(
- "%s state %s not in running=%s",
- key, state, running.values()
- )
+ self.log.warning("%s state %s not in running=%s", key, state, running.values())
continue
ti = running[key]
@@ -272,9 +274,11 @@ def _manage_executor_state(self, running):
self.log.debug("Executor state: %s task %s", state, ti)
if state in (State.FAILED, State.SUCCESS) and ti.state in self.STATES_COUNT_AS_RUNNING:
- msg = ("Executor reports task instance {} finished ({}) "
- "although the task says its {}. Was the task "
- "killed externally? Info: {}".format(ti, state, ti.state, info))
+ msg = (
+ "Executor reports task instance {} finished ({}) "
+ "although the task says its {}. Was the task "
+ "killed externally? Info: {}".format(ti, state, ti.state, info)
+ )
self.log.error(msg)
ti.handle_failure(msg)
@@ -297,11 +301,7 @@ def _get_dag_run(self, run_date: datetime, dag: DAG, session: Session = None):
# check if we are scheduling on top of a already existing dag_run
# we could find a "scheduled" run instead of a "backfill"
- runs = DagRun.find(
- dag_id=dag.dag_id,
- execution_date=run_date,
- session=session
- )
+ runs = DagRun.find(dag_id=dag.dag_id, execution_date=run_date, session=session)
run: Optional[DagRun]
if runs:
run = runs[0]
@@ -312,8 +312,7 @@ def _get_dag_run(self, run_date: datetime, dag: DAG, session: Session = None):
# enforce max_active_runs limit for dag, special cases already
# handled by respect_dag_max_active_limit
- if (respect_dag_max_active_limit and
- current_active_dag_count >= dag.max_active_runs):
+ if respect_dag_max_active_limit and current_active_dag_count >= dag.max_active_runs:
return None
run = run or dag.create_dagrun(
@@ -378,22 +377,28 @@ def _log_progress(self, ti_status):
self.log.info(
'[backfill progress] | finished run %s of %s | tasks waiting: %s | succeeded: %s | '
'running: %s | failed: %s | skipped: %s | deadlocked: %s | not ready: %s',
- ti_status.finished_runs, ti_status.total_runs, len(ti_status.to_run), len(ti_status.succeeded),
- len(ti_status.running), len(ti_status.failed), len(ti_status.skipped), len(ti_status.deadlocked),
- len(ti_status.not_ready)
+ ti_status.finished_runs,
+ ti_status.total_runs,
+ len(ti_status.to_run),
+ len(ti_status.succeeded),
+ len(ti_status.running),
+ len(ti_status.failed),
+ len(ti_status.skipped),
+ len(ti_status.deadlocked),
+ len(ti_status.not_ready),
)
- self.log.debug(
- "Finished dag run loop iteration. Remaining tasks %s",
- ti_status.to_run.values()
- )
+ self.log.debug("Finished dag run loop iteration. Remaining tasks %s", ti_status.to_run.values())
@provide_session
- def _process_backfill_task_instances(self, # pylint: disable=too-many-statements
- ti_status,
- executor,
- pickle_id,
- start_date=None, session=None):
+ def _process_backfill_task_instances( # pylint: disable=too-many-statements
+ self,
+ ti_status,
+ executor,
+ pickle_id,
+ start_date=None,
+ session=None,
+ ):
"""
Process a set of task instances from a set of dag runs. Special handling is done
to account for different task instance states that could be present when running
@@ -414,8 +419,7 @@ def _process_backfill_task_instances(self, # pylint: disable=too-many-statement
"""
executed_run_dates = []
- while ((len(ti_status.to_run) > 0 or len(ti_status.running) > 0) and
- len(ti_status.deadlocked) == 0):
+ while (len(ti_status.to_run) > 0 or len(ti_status.running) > 0) and len(ti_status.deadlocked) == 0:
self.log.debug("*** Clearing out not_ready list ***")
ti_status.not_ready.clear()
@@ -430,11 +434,10 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return
task = self.dag.get_task(ti.task_id, include_subdags=True)
ti.task = task
- ignore_depends_on_past = (
- self.ignore_first_depends_on_past and
- ti.execution_date == (start_date or ti.start_date))
- self.log.debug(
- "Task instance to run %s state %s", ti, ti.state)
+ ignore_depends_on_past = self.ignore_first_depends_on_past and ti.execution_date == (
+ start_date or ti.start_date
+ )
+ self.log.debug("Task instance to run %s state %s", ti, ti.state)
# The task was already marked successful or skipped by a
# different Job. Don't rerun it.
@@ -457,8 +460,7 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return
# in case max concurrency has been reached at task runtime
elif ti.state == State.NONE:
self.log.warning(
- "FIXME: task instance {} state was set to None "
- "externally. This should not happen"
+ "FIXME: task instance {} state was set to None " "externally. This should not happen"
)
ti.set_state(State.SCHEDULED, session=session)
if self.rerun_failed_tasks:
@@ -483,19 +485,17 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return
deps=BACKFILL_QUEUED_DEPS,
ignore_depends_on_past=ignore_depends_on_past,
ignore_task_deps=self.ignore_task_deps,
- flag_upstream_failed=True)
+ flag_upstream_failed=True,
+ )
# Is the task runnable? -- then run it
# the dependency checker can change states of tis
if ti.are_dependencies_met(
- dep_context=backfill_context,
- session=session,
- verbose=self.verbose):
+ dep_context=backfill_context, session=session, verbose=self.verbose
+ ):
if executor.has_task(ti):
self.log.debug(
- "Task Instance %s already in executor "
- "waiting for queue to clear",
- ti
+ "Task Instance %s already in executor " "waiting for queue to clear", ti
)
else:
self.log.debug('Sending %s to executor', ti)
@@ -507,7 +507,8 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return
cfg_path = None
if self.executor_class in (
- ExecutorLoader.LOCAL_EXECUTOR, ExecutorLoader.SEQUENTIAL_EXECUTOR
+ ExecutorLoader.LOCAL_EXECUTOR,
+ ExecutorLoader.SEQUENTIAL_EXECUTOR,
):
cfg_path = tmp_configuration_copy()
@@ -518,7 +519,8 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return
ignore_task_deps=self.ignore_task_deps,
ignore_depends_on_past=ignore_depends_on_past,
pool=self.pool,
- cfg_path=cfg_path)
+ cfg_path=cfg_path,
+ )
ti_status.running[key] = ti
ti_status.to_run.pop(key)
session.commit()
@@ -534,9 +536,7 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return
# special case
if ti.state == State.UP_FOR_RETRY:
- self.log.debug(
- "Task instance %s retry period not "
- "expired yet", ti)
+ self.log.debug("Task instance %s retry period not " "expired yet", ti)
if key in ti_status.running:
ti_status.running.pop(key)
ti_status.to_run[key] = ti
@@ -544,9 +544,7 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return
# special case
if ti.state == State.UP_FOR_RESCHEDULE:
- self.log.debug(
- "Task instance %s reschedule period not "
- "expired yet", ti)
+ self.log.debug("Task instance %s reschedule period not " "expired yet", ti)
if key in ti_status.running:
ti_status.running.pop(key)
ti_status.to_run[key] = ti
@@ -562,9 +560,7 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return
if task.task_id != ti.task_id:
continue
- pool = session.query(models.Pool) \
- .filter(models.Pool.pool == task.pool) \
- .first()
+ pool = session.query(models.Pool).filter(models.Pool.pool == task.pool).first()
if not pool:
raise PoolNotFound(f'Unknown pool: {task.pool}')
@@ -572,8 +568,8 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return
if open_slots <= 0:
raise NoAvailablePoolSlot(
"Not scheduling since there are "
- "{} open slots in pool {}".format(
- open_slots, task.pool))
+ "{} open slots in pool {}".format(open_slots, task.pool)
+ )
num_running_task_instances_in_dag = DAG.get_num_task_instances(
self.dag_id,
@@ -582,8 +578,7 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return
if num_running_task_instances_in_dag >= self.dag.concurrency:
raise DagConcurrencyLimitReached(
- "Not scheduling since DAG concurrency limit "
- "is reached."
+ "Not scheduling since DAG concurrency limit " "is reached."
)
if task.task_concurrency:
@@ -595,8 +590,7 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return
if num_running_task_instances_in_task >= task.task_concurrency:
raise TaskConcurrencyLimitReached(
- "Not scheduling since Task concurrency limit "
- "is reached."
+ "Not scheduling since Task concurrency limit " "is reached."
)
_per_task_process(key, ti)
@@ -610,13 +604,12 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return
# If the set of tasks that aren't ready ever equals the set of
# tasks to run and there are no running tasks then the backfill
# is deadlocked
- if (ti_status.not_ready and
- ti_status.not_ready == set(ti_status.to_run) and
- len(ti_status.running) == 0):
- self.log.warning(
- "Deadlock discovered for ti_status.to_run=%s",
- ti_status.to_run.values()
- )
+ if (
+ ti_status.not_ready
+ and ti_status.not_ready == set(ti_status.to_run)
+ and len(ti_status.running) == 0
+ ):
+ self.log.warning("Deadlock discovered for ti_status.to_run=%s", ti_status.to_run.values())
ti_status.deadlocked.update(ti_status.to_run.values())
ti_status.to_run.clear()
@@ -645,19 +638,17 @@ def _collect_errors(self, ti_status, session=None):
def tabulate_ti_keys_set(set_ti_keys: Set[TaskInstanceKey]) -> str:
# Sorting by execution date first
sorted_ti_keys = sorted(
- set_ti_keys, key=lambda ti_key:
- (ti_key.execution_date, ti_key.dag_id, ti_key.task_id, ti_key.try_number)
+ set_ti_keys,
+ key=lambda ti_key: (ti_key.execution_date, ti_key.dag_id, ti_key.task_id, ti_key.try_number),
)
return tabulate(sorted_ti_keys, headers=["DAG ID", "Task ID", "Execution date", "Try number"])
def tabulate_tis_set(set_tis: Set[TaskInstance]) -> str:
# Sorting by execution date first
sorted_tis = sorted(
- set_tis, key=lambda ti: (ti.execution_date, ti.dag_id, ti.task_id, ti.try_number))
- tis_values = (
- (ti.dag_id, ti.task_id, ti.execution_date, ti.try_number)
- for ti in sorted_tis
+ set_tis, key=lambda ti: (ti.execution_date, ti.dag_id, ti.task_id, ti.try_number)
)
+ tis_values = ((ti.dag_id, ti.task_id, ti.execution_date, ti.try_number) for ti in sorted_tis)
return tabulate(tis_values, headers=["DAG ID", "Task ID", "Execution date", "Try number"])
err = ''
@@ -670,19 +661,21 @@ def tabulate_tis_set(set_tis: Set[TaskInstance]) -> str:
t.are_dependencies_met(
dep_context=DepContext(ignore_depends_on_past=False),
session=session,
- verbose=self.verbose) !=
- t.are_dependencies_met(
- dep_context=DepContext(ignore_depends_on_past=True),
- session=session,
- verbose=self.verbose)
- for t in ti_status.deadlocked)
+ verbose=self.verbose,
+ )
+ != t.are_dependencies_met(
+ dep_context=DepContext(ignore_depends_on_past=True), session=session, verbose=self.verbose
+ )
+ for t in ti_status.deadlocked
+ )
if deadlocked_depends_on_past:
err += (
'Some of the deadlocked tasks were unable to run because '
'of "depends_on_past" relationships. Try running the '
'backfill with the option '
'"ignore_first_depends_on_past=True" or passing "-I" at '
- 'the command line.')
+ 'the command line.'
+ )
err += '\nThese tasks have succeeded:\n'
err += tabulate_ti_keys_set(ti_status.succeeded)
err += '\n\nThese tasks are running:\n'
@@ -697,8 +690,7 @@ def tabulate_tis_set(set_tis: Set[TaskInstance]) -> str:
return err
@provide_session
- def _execute_for_run_dates(self, run_dates, ti_status, executor, pickle_id,
- start_date, session=None):
+ def _execute_for_run_dates(self, run_dates, ti_status, executor, pickle_id, start_date, session=None):
"""
Computes the dag runs and their respective task instances for
the given run dates and executes the task instances.
@@ -720,8 +712,7 @@ def _execute_for_run_dates(self, run_dates, ti_status, executor, pickle_id,
for next_run_date in run_dates:
for dag in [self.dag] + self.dag.subdags:
dag_run = self._get_dag_run(next_run_date, dag, session=session)
- tis_map = self._task_instances_for_dag_run(dag_run,
- session=session)
+ tis_map = self._task_instances_for_dag_run(dag_run, session=session)
if dag_run is None:
continue
@@ -733,7 +724,8 @@ def _execute_for_run_dates(self, run_dates, ti_status, executor, pickle_id,
executor=executor,
pickle_id=pickle_id,
start_date=start_date,
- session=session)
+ session=session,
+ )
ti_status.executed_dag_run_dates.update(processed_dag_run_dates)
@@ -764,14 +756,15 @@ def _execute(self, session=None):
start_date = self.bf_start_date
# Get intervals between the start/end dates, which will turn into dag runs
- run_dates = self.dag.get_run_dates(start_date=start_date,
- end_date=self.bf_end_date)
+ run_dates = self.dag.get_run_dates(start_date=start_date, end_date=self.bf_end_date)
if self.run_backwards:
tasks_that_depend_on_past = [t.task_id for t in self.dag.task_dict.values() if t.depends_on_past]
if tasks_that_depend_on_past:
raise AirflowException(
'You cannot backfill backwards because one or more tasks depend_on_past: {}'.format(
- ",".join(tasks_that_depend_on_past)))
+ ",".join(tasks_that_depend_on_past)
+ )
+ )
run_dates = run_dates[::-1]
if len(run_dates) == 0:
@@ -782,7 +775,9 @@ def _execute(self, session=None):
pickle_id = None
if not self.donot_pickle and self.executor_class not in (
- ExecutorLoader.LOCAL_EXECUTOR, ExecutorLoader.SEQUENTIAL_EXECUTOR, ExecutorLoader.DASK_EXECUTOR,
+ ExecutorLoader.LOCAL_EXECUTOR,
+ ExecutorLoader.SEQUENTIAL_EXECUTOR,
+ ExecutorLoader.DASK_EXECUTOR,
):
pickle = DagPickle(self.dag)
session.add(pickle)
@@ -797,19 +792,20 @@ def _execute(self, session=None):
try: # pylint: disable=too-many-nested-blocks
remaining_dates = ti_status.total_runs
while remaining_dates > 0:
- dates_to_process = [run_date for run_date in run_dates if run_date not in
- ti_status.executed_dag_run_dates]
-
- self._execute_for_run_dates(run_dates=dates_to_process,
- ti_status=ti_status,
- executor=executor,
- pickle_id=pickle_id,
- start_date=start_date,
- session=session)
-
- remaining_dates = (
- ti_status.total_runs - len(ti_status.executed_dag_run_dates)
+ dates_to_process = [
+ run_date for run_date in run_dates if run_date not in ti_status.executed_dag_run_dates
+ ]
+
+ self._execute_for_run_dates(
+ run_dates=dates_to_process,
+ ti_status=ti_status,
+ executor=executor,
+ pickle_id=pickle_id,
+ start_date=start_date,
+ session=session,
)
+
+ remaining_dates = ti_status.total_runs - len(ti_status.executed_dag_run_dates)
err = self._collect_errors(ti_status=ti_status, session=session)
if err:
raise BackfillUnfinished(err, ti_status)
@@ -818,7 +814,7 @@ def _execute(self, session=None):
self.log.info(
"max_active_runs limit for dag %s has been reached "
" - waiting for other dag runs to finish",
- self.dag_id
+ self.dag_id,
)
time.sleep(self.delay_on_limit_secs)
except (KeyboardInterrupt, SystemExit):
@@ -854,21 +850,23 @@ def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None):
resettable_states = [State.SCHEDULED, State.QUEUED]
if filter_by_dag_run is None:
resettable_tis = (
- session
- .query(TaskInstance)
+ session.query(TaskInstance)
.join(
DagRun,
and_(
TaskInstance.dag_id == DagRun.dag_id,
- TaskInstance.execution_date == DagRun.execution_date))
+ TaskInstance.execution_date == DagRun.execution_date,
+ ),
+ )
.filter(
# pylint: disable=comparison-with-callable
DagRun.state == State.RUNNING,
DagRun.run_type != DagRunType.BACKFILL_JOB,
- TaskInstance.state.in_(resettable_states))).all()
+ TaskInstance.state.in_(resettable_states),
+ )
+ ).all()
else:
- resettable_tis = filter_by_dag_run.get_task_instances(state=resettable_states,
- session=session)
+ resettable_tis = filter_by_dag_run.get_task_instances(state=resettable_states, session=session)
tis_to_reset = []
# Can't use an update here since it doesn't support joins
for ti in resettable_tis:
@@ -883,9 +881,12 @@ def query(result, items):
return result
filter_for_tis = TaskInstance.filter_for_tis(items)
- reset_tis = session.query(TaskInstance).filter(
- filter_for_tis, TaskInstance.state.in_(resettable_states)
- ).with_for_update().all()
+ reset_tis = (
+ session.query(TaskInstance)
+ .filter(filter_for_tis, TaskInstance.state.in_(resettable_states))
+ .with_for_update()
+ .all()
+ )
for ti in reset_tis:
ti.state = State.NONE
@@ -893,16 +894,10 @@ def query(result, items):
return result + reset_tis
- reset_tis = helpers.reduce_in_chunks(query,
- tis_to_reset,
- [],
- self.max_tis_per_query)
+ reset_tis = helpers.reduce_in_chunks(query, tis_to_reset, [], self.max_tis_per_query)
task_instance_str = '\n\t'.join([repr(x) for x in reset_tis])
session.commit()
- self.log.info(
- "Reset the following %s TaskInstances:\n\t%s",
- len(reset_tis), task_instance_str
- )
+ self.log.info("Reset the following %s TaskInstances:\n\t%s", len(reset_tis), task_instance_str)
return len(reset_tis)
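A brief aside on the query chains above (not part of the patch): the backslash continuations in the old session.query(...) chains disappear because Black works on whole logical lines and re-emits them without backslashes, folding a chain onto one line wherever it fits the configured line length and otherwise keeping the parenthesized, one-call-per-line form. A minimal sketch with a plain string standing in for the session and query objects:

    text = "  State.QUEUED  "
    chained = text \
        .strip() \
        .lower()
    assert chained == text.strip().lower() == "state.queued"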
diff --git a/airflow/jobs/base_job.py b/airflow/jobs/base_job.py
index caff8c52279f3..1ca0aaa19404b 100644
--- a/airflow/jobs/base_job.py
+++ b/airflow/jobs/base_job.py
@@ -53,7 +53,9 @@ class BaseJob(Base, LoggingMixin):
__tablename__ = "job"
id = Column(Integer, primary_key=True)
- dag_id = Column(String(ID_LEN),)
+ dag_id = Column(
+ String(ID_LEN),
+ )
state = Column(String(20))
job_type = Column(String(30))
start_date = Column(UtcDateTime())
@@ -63,10 +65,7 @@ class BaseJob(Base, LoggingMixin):
hostname = Column(String(500))
unixname = Column(String(1000))
- __mapper_args__ = {
- 'polymorphic_on': job_type,
- 'polymorphic_identity': 'BaseJob'
- }
+ __mapper_args__ = {'polymorphic_on': job_type, 'polymorphic_identity': 'BaseJob'}
__table_args__ = (
Index('job_type_heart', job_type, latest_heartbeat),
@@ -95,11 +94,7 @@ class BaseJob(Base, LoggingMixin):
heartrate = conf.getfloat('scheduler', 'JOB_HEARTBEAT_SEC')
- def __init__(
- self,
- executor=None,
- heartrate=None,
- *args, **kwargs):
+ def __init__(self, executor=None, heartrate=None, *args, **kwargs):
self.hostname = get_hostname()
self.executor = executor or ExecutorLoader.get_default_executor()
self.executor_class = self.executor.__class__.__name__
@@ -139,8 +134,9 @@ def is_alive(self, grace_multiplier=2.1):
:rtype: boolean
"""
return (
- self.state == State.RUNNING and
- (timezone.utcnow() - self.latest_heartbeat).total_seconds() < self.heartrate * grace_multiplier
+ self.state == State.RUNNING
+ and (timezone.utcnow() - self.latest_heartbeat).total_seconds()
+ < self.heartrate * grace_multiplier
)
@provide_session
@@ -206,9 +202,9 @@ def heartbeat(self, only_if_necessary: bool = False):
# Figure out how long to sleep for
sleep_for = 0
if self.latest_heartbeat:
- seconds_remaining = self.heartrate - \
- (timezone.utcnow() - self.latest_heartbeat)\
- .total_seconds()
+ seconds_remaining = (
+ self.heartrate - (timezone.utcnow() - self.latest_heartbeat).total_seconds()
+ )
sleep_for = max(0, seconds_remaining)
sleep(sleep_for)
@@ -224,9 +220,7 @@ def heartbeat(self, only_if_necessary: bool = False):
self.heartbeat_callback(session=session)
self.log.debug('[heartbeat]')
except OperationalError:
- Stats.incr(
- convert_camel_to_snake(self.__class__.__name__) + '_heartbeat_failure', 1,
- 1)
+ Stats.incr(convert_camel_to_snake(self.__class__.__name__) + '_heartbeat_failure', 1, 1)
self.log.exception("%s heartbeat got an exception", self.__class__.__name__)
# We didn't manage to heartbeat, so make sure that the timestamp isn't updated
self.latest_heartbeat = previous_heartbeat
diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index 0c3e86215c44e..f4d4ef0efcf86 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -36,21 +36,21 @@
class LocalTaskJob(BaseJob):
"""LocalTaskJob runs a single task instance."""
- __mapper_args__ = {
- 'polymorphic_identity': 'LocalTaskJob'
- }
+ __mapper_args__ = {'polymorphic_identity': 'LocalTaskJob'}
def __init__(
- self,
- task_instance: TaskInstance,
- ignore_all_deps: bool = False,
- ignore_depends_on_past: bool = False,
- ignore_task_deps: bool = False,
- ignore_ti_state: bool = False,
- mark_success: bool = False,
- pickle_id: Optional[str] = None,
- pool: Optional[str] = None,
- *args, **kwargs):
+ self,
+ task_instance: TaskInstance,
+ ignore_all_deps: bool = False,
+ ignore_depends_on_past: bool = False,
+ ignore_task_deps: bool = False,
+ ignore_ti_state: bool = False,
+ mark_success: bool = False,
+ pickle_id: Optional[str] = None,
+ pool: Optional[str] = None,
+ *args,
+ **kwargs,
+ ):
self.task_instance = task_instance
self.dag_id = task_instance.dag_id
self.ignore_all_deps = ignore_all_deps
@@ -82,13 +82,14 @@ def signal_handler(signum, frame):
signal.signal(signal.SIGTERM, signal_handler)
if not self.task_instance.check_and_change_state_before_execution(
- mark_success=self.mark_success,
- ignore_all_deps=self.ignore_all_deps,
- ignore_depends_on_past=self.ignore_depends_on_past,
- ignore_task_deps=self.ignore_task_deps,
- ignore_ti_state=self.ignore_ti_state,
- job_id=self.id,
- pool=self.pool):
+ mark_success=self.mark_success,
+ ignore_all_deps=self.ignore_all_deps,
+ ignore_depends_on_past=self.ignore_depends_on_past,
+ ignore_task_deps=self.ignore_task_deps,
+ ignore_ti_state=self.ignore_ti_state,
+ job_id=self.id,
+ pool=self.pool,
+ ):
self.log.info("Task is not able to be run")
return
@@ -104,10 +105,12 @@ def signal_handler(signum, frame):
max_wait_time = max(
0, # Make sure this value is never negative,
min(
- (heartbeat_time_limit -
- (timezone.utcnow() - self.latest_heartbeat).total_seconds() * 0.75),
+ (
+ heartbeat_time_limit
+ - (timezone.utcnow() - self.latest_heartbeat).total_seconds() * 0.75
+ ),
self.heartrate,
- )
+ ),
)
return_code = self.task_runner.return_code(timeout=max_wait_time)
@@ -124,10 +127,10 @@ def signal_handler(signum, frame):
if time_since_last_heartbeat > heartbeat_time_limit:
Stats.incr('local_task_job_prolonged_heartbeat_failure', 1, 1)
self.log.error("Heartbeat time limit exceeded!")
- raise AirflowException("Time since last heartbeat({:.2f}s) "
- "exceeded limit ({}s)."
- .format(time_since_last_heartbeat,
- heartbeat_time_limit))
+ raise AirflowException(
+ "Time since last heartbeat({:.2f}s) "
+ "exceeded limit ({}s).".format(time_since_last_heartbeat, heartbeat_time_limit)
+ )
finally:
self.on_kill()
@@ -150,25 +153,21 @@ def heartbeat_callback(self, session=None):
fqdn = get_hostname()
same_hostname = fqdn == ti.hostname
if not same_hostname:
- self.log.warning("The recorded hostname %s "
- "does not match this instance's hostname "
- "%s", ti.hostname, fqdn)
+ self.log.warning(
+ "The recorded hostname %s " "does not match this instance's hostname " "%s",
+ ti.hostname,
+ fqdn,
+ )
raise AirflowException("Hostname of job runner does not match")
current_pid = os.getpid()
same_process = ti.pid == current_pid
if not same_process:
- self.log.warning("Recorded pid %s does not match "
- "the current pid %s", ti.pid, current_pid)
+ self.log.warning("Recorded pid %s does not match " "the current pid %s", ti.pid, current_pid)
raise AirflowException("PID of job runner does not match")
- elif (
- self.task_runner.return_code() is None and
- hasattr(self.task_runner, 'process')
- ):
+ elif self.task_runner.return_code() is None and hasattr(self.task_runner, 'process'):
self.log.warning(
- "State of this instance has been externally set to %s. "
- "Terminating instance.",
- ti.state
+ "State of this instance has been externally set to %s. " "Terminating instance.", ti.state
)
if ti.state == State.FAILED and ti.task.on_failure_callback:
context = ti.get_template_context()
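A side effect of collapsing these log calls onto single lines shows up in heartbeat_callback above: messages that were split across several source lines now appear as adjacent string literals, for example "Recorded pid %s does not match " "the current pid %s". Adjacent literals are concatenated at compile time, so the logged text is unchanged. A small standalone demonstration (illustrative, not part of the patch):

import logging

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger("demo")

# Adjacent string literals are joined by the parser before the call runs.
msg = "Recorded pid %s does not match " "the current pid %s"
assert msg == "Recorded pid %s does not match the current pid %s"
log.warning(msg, 1234, 5678)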
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 38a79ffa51fb2..ceee033789aae 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -52,7 +52,10 @@
from airflow.ti_deps.dependencies_states import EXECUTION_STATES
from airflow.utils import timezone
from airflow.utils.callback_requests import (
- CallbackRequest, DagCallbackRequest, SlaCallbackRequest, TaskCallbackRequest,
+ CallbackRequest,
+ DagCallbackRequest,
+ SlaCallbackRequest,
+ TaskCallbackRequest,
)
from airflow.utils.dag_processing import AbstractDagFileProcessorProcess, DagFileProcessorAgent
from airflow.utils.email import get_email_address_list, send_email
@@ -184,9 +187,7 @@ def _run_file_processor(
)
result_channel.send(result)
end_time = time.time()
- log.info(
- "Processing %s took %.3f seconds", file_path, end_time - start_time
- )
+ log.info("Processing %s took %.3f seconds", file_path, end_time - start_time)
except Exception: # pylint: disable=broad-except
# Log exceptions through the logging framework.
log.exception("Got an exception! Propagating...")
@@ -213,9 +214,9 @@ def start(self) -> None:
self._pickle_dags,
self._dag_ids,
f"DagFileProcessor{self._instance_id}",
- self._callback_requests
+ self._callback_requests,
),
- name=f"DagFileProcessor{self._instance_id}-Process"
+ name=f"DagFileProcessor{self._instance_id}-Process",
)
self._process = process
self._start_time = timezone.utcnow()
@@ -400,28 +401,24 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None:
return
qry = (
- session
- .query(
- TI.task_id,
- func.max(TI.execution_date).label('max_ti')
- )
+ session.query(TI.task_id, func.max(TI.execution_date).label('max_ti'))
.with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql')
.filter(TI.dag_id == dag.dag_id)
- .filter(
- or_(
- TI.state == State.SUCCESS,
- TI.state == State.SKIPPED
- )
- )
+ .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED))
.filter(TI.task_id.in_(dag.task_ids))
- .group_by(TI.task_id).subquery('sq')
+ .group_by(TI.task_id)
+ .subquery('sq')
)
- max_tis: List[TI] = session.query(TI).filter(
- TI.dag_id == dag.dag_id,
- TI.task_id == qry.c.task_id,
- TI.execution_date == qry.c.max_ti,
- ).all()
+ max_tis: List[TI] = (
+ session.query(TI)
+ .filter(
+ TI.dag_id == dag.dag_id,
+ TI.task_id == qry.c.task_id,
+ TI.execution_date == qry.c.max_ti,
+ )
+ .all()
+ )
ts = timezone.utcnow()
for ti in max_tis:
@@ -433,31 +430,26 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None:
while dttm < timezone.utcnow():
following_schedule = dag.following_schedule(dttm)
if following_schedule + task.sla < timezone.utcnow():
- session.merge(SlaMiss(
- task_id=ti.task_id,
- dag_id=ti.dag_id,
- execution_date=dttm,
- timestamp=ts))
+ session.merge(
+ SlaMiss(task_id=ti.task_id, dag_id=ti.dag_id, execution_date=dttm, timestamp=ts)
+ )
dttm = dag.following_schedule(dttm)
session.commit()
+ # pylint: disable=singleton-comparison
slas: List[SlaMiss] = (
- session
- .query(SlaMiss)
- .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id) # noqa pylint: disable=singleton-comparison
+ session.query(SlaMiss)
+ .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id) # noqa
.all()
)
+ # pylint: enable=singleton-comparison
if slas: # pylint: disable=too-many-nested-blocks
sla_dates: List[datetime.datetime] = [sla.execution_date for sla in slas]
fetched_tis: List[TI] = (
- session
- .query(TI)
- .filter(
- TI.state != State.SUCCESS,
- TI.execution_date.in_(sla_dates),
- TI.dag_id == dag.dag_id
- ).all()
+ session.query(TI)
+ .filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id)
+ .all()
)
blocking_tis: List[TI] = []
for ti in fetched_tis:
@@ -468,12 +460,10 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None:
session.delete(ti)
session.commit()
- task_list = "\n".join([
- sla.task_id + ' on ' + sla.execution_date.isoformat()
- for sla in slas])
- blocking_task_list = "\n".join([
- ti.task_id + ' on ' + ti.execution_date.isoformat()
- for ti in blocking_tis])
+ task_list = "\n".join([sla.task_id + ' on ' + sla.execution_date.isoformat() for sla in slas])
+ blocking_task_list = "\n".join(
+ [ti.task_id + ' on ' + ti.execution_date.isoformat() for ti in blocking_tis]
+ )
# Track whether email or any alert notification sent
# We consider email or the alert callback as notifications
email_sent = False
@@ -482,8 +472,7 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None:
# Execute the alert callback
self.log.info('Calling SLA miss callback')
try:
- dag.sla_miss_callback(dag, task_list, blocking_task_list, slas,
- blocking_tis)
+ dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis)
notification_sent = True
except Exception: # pylint: disable=broad-except
self.log.exception("Could not call sla_miss_callback for DAG %s", dag.dag_id)
@@ -501,8 +490,8 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None:
except TaskNotFound:
# task already deleted from DAG, skip it
self.log.warning(
- "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.",
- sla.task_id)
+ "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.", sla.task_id
+ )
continue
tasks_missed_sla.append(task)
@@ -515,17 +504,12 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None:
emails |= set(task.email)
if emails:
try:
- send_email(
- emails,
- f"[airflow] SLA miss on DAG={dag.dag_id}",
- email_content
- )
+ send_email(emails, f"[airflow] SLA miss on DAG={dag.dag_id}", email_content)
email_sent = True
notification_sent = True
except Exception: # pylint: disable=broad-except
Stats.incr('sla_email_notification_failure')
- self.log.exception("Could not send SLA Miss email notification for"
- " DAG %s", dag.dag_id)
+ self.log.exception("Could not send SLA Miss email notification for" " DAG %s", dag.dag_id)
# If we sent any notification, update the sla_miss table
if notification_sent:
for sla in slas:
@@ -548,24 +532,18 @@ def update_import_errors(session: Session, dagbag: DagBag) -> None:
"""
# Clear the errors of the processed files
for dagbag_file in dagbag.file_last_changed:
- session.query(errors.ImportError).filter(
- errors.ImportError.filename == dagbag_file
- ).delete()
+ session.query(errors.ImportError).filter(errors.ImportError.filename == dagbag_file).delete()
# Add the errors of the processed files
for filename, stacktrace in dagbag.import_errors.items():
- session.add(errors.ImportError(
- filename=filename,
- timestamp=timezone.utcnow(),
- stacktrace=stacktrace))
+ session.add(
+ errors.ImportError(filename=filename, timestamp=timezone.utcnow(), stacktrace=stacktrace)
+ )
session.commit()
@provide_session
def execute_callbacks(
- self,
- dagbag: DagBag,
- callback_requests: List[CallbackRequest],
- session: Session = None
+ self, dagbag: DagBag, callback_requests: List[CallbackRequest], session: Session = None
) -> None:
"""
Execute on failure callbacks. These objects can come from SchedulerJob or from
@@ -588,7 +566,7 @@ def execute_callbacks(
self.log.exception(
"Error executing %s callback for file: %s",
request.__class__.__name__,
- request.full_filepath
+ request.full_filepath,
)
session.commit()
@@ -598,10 +576,7 @@ def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, se
dag = dagbag.dags[request.dag_id]
dag_run = dag.get_dagrun(execution_date=request.execution_date, session=session)
dag.handle_callback(
- dagrun=dag_run,
- success=not request.is_failure_callback,
- reason=request.msg,
- session=session
+ dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session
)
def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest):
@@ -627,7 +602,7 @@ def process_file(
file_path: str,
callback_requests: List[CallbackRequest],
pickle_dags: bool = False,
- session: Session = None
+ session: Session = None,
) -> Tuple[int, int]:
"""
Process a Python file containing Airflow DAGs.
@@ -728,20 +703,20 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes
:type do_pickle: bool
"""
- __mapper_args__ = {
- 'polymorphic_identity': 'SchedulerJob'
- }
+ __mapper_args__ = {'polymorphic_identity': 'SchedulerJob'}
heartrate: int = conf.getint('scheduler', 'SCHEDULER_HEARTBEAT_SEC')
def __init__(
- self,
- subdir: str = settings.DAGS_FOLDER,
- num_runs: int = conf.getint('scheduler', 'num_runs'),
- num_times_parse_dags: int = -1,
- processor_poll_interval: float = conf.getfloat('scheduler', 'processor_poll_interval'),
- do_pickle: bool = False,
- log: Any = None,
- *args, **kwargs):
+ self,
+ subdir: str = settings.DAGS_FOLDER,
+ num_runs: int = conf.getint('scheduler', 'num_runs'),
+ num_times_parse_dags: int = -1,
+ processor_poll_interval: float = conf.getfloat('scheduler', 'processor_poll_interval'),
+ do_pickle: bool = False,
+ log: Any = None,
+ *args,
+ **kwargs,
+ ):
self.subdir = subdir
self.num_runs = num_runs
@@ -796,16 +771,13 @@ def is_alive(self, grace_multiplier: Optional[float] = None) -> bool:
return super().is_alive(grace_multiplier=grace_multiplier)
scheduler_health_check_threshold: int = conf.getint('scheduler', 'scheduler_health_check_threshold')
return (
- self.state == State.RUNNING and
- (timezone.utcnow() - self.latest_heartbeat).total_seconds() < scheduler_health_check_threshold
+ self.state == State.RUNNING
+ and (timezone.utcnow() - self.latest_heartbeat).total_seconds() < scheduler_health_check_threshold
)
@provide_session
def _change_state_for_tis_without_dagrun(
- self,
- old_states: List[str],
- new_state: str,
- session: Session = None
+ self, old_states: List[str], new_state: str, session: Session = None
) -> None:
"""
For all DAG IDs in the DagBag, look for task instances in the
@@ -820,20 +792,24 @@ def _change_state_for_tis_without_dagrun(
:type new_state: airflow.utils.state.State
"""
tis_changed = 0
- query = session \
- .query(models.TaskInstance) \
- .outerjoin(models.TaskInstance.dag_run) \
- .filter(models.TaskInstance.dag_id.in_(list(self.dagbag.dag_ids))) \
- .filter(models.TaskInstance.state.in_(old_states)) \
- .filter(or_(
- # pylint: disable=comparison-with-callable
- models.DagRun.state != State.RUNNING,
- models.DagRun.state.is_(None))) # pylint: disable=no-member
+ query = (
+ session.query(models.TaskInstance)
+ .outerjoin(models.TaskInstance.dag_run)
+ .filter(models.TaskInstance.dag_id.in_(list(self.dagbag.dag_ids)))
+ .filter(models.TaskInstance.state.in_(old_states))
+ .filter(
+ or_(
+ # pylint: disable=comparison-with-callable
+ models.DagRun.state != State.RUNNING,
+ # pylint: disable=no-member
+ models.DagRun.state.is_(None),
+ )
+ )
+ )
# We need to do this for mysql as well because it can cause deadlocks
# as discussed in https://issues.apache.org/jira/browse/AIRFLOW-2516
if self.using_sqlite or self.using_mysql:
- tis_to_change: List[TI] = with_row_locks(query, of=TI,
- **skip_locked(session=session)).all()
+ tis_to_change: List[TI] = with_row_locks(query, of=TI, **skip_locked(session=session)).all()
for ti in tis_to_change:
ti.set_state(new_state, session=session)
tis_changed += 1
@@ -847,25 +823,29 @@ def _change_state_for_tis_without_dagrun(
# Only add end_date and duration if the new_state is 'success', 'failed' or 'skipped'
if new_state in State.finished:
- ti_prop_update.update({
- models.TaskInstance.end_date: current_time,
- models.TaskInstance.duration: 0,
- })
+ ti_prop_update.update(
+ {
+ models.TaskInstance.end_date: current_time,
+ models.TaskInstance.duration: 0,
+ }
+ )
- tis_changed = session \
- .query(models.TaskInstance) \
+ tis_changed = (
+ session.query(models.TaskInstance)
.filter(
models.TaskInstance.dag_id == subq.c.dag_id,
models.TaskInstance.task_id == subq.c.task_id,
- models.TaskInstance.execution_date ==
- subq.c.execution_date) \
+ models.TaskInstance.execution_date == subq.c.execution_date,
+ )
.update(ti_prop_update, synchronize_session=False)
+ )
session.flush()
if tis_changed > 0:
self.log.warning(
"Set %s task instances to state=%s as their associated DagRun was not in RUNNING state",
- tis_changed, new_state
+ tis_changed,
+ new_state,
)
Stats.gauge('scheduler.tasks.without_dagrun', tis_changed)
@@ -883,8 +863,7 @@ def __get_concurrency_maps(
:rtype: tuple[dict[str, int], dict[tuple[str, str], int]]
"""
ti_concurrency_query: List[Tuple[str, str, int]] = (
- session
- .query(TI.task_id, TI.dag_id, func.count('*'))
+ session.query(TI.task_id, TI.dag_id, func.count('*'))
.filter(TI.state.in_(states))
.group_by(TI.task_id, TI.dag_id)
).all()
@@ -927,11 +906,9 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
# DagRuns which are not backfilled, in the given states,
# and the dag is not paused
query = (
- session
- .query(TI)
+ session.query(TI)
.outerjoin(TI.dag_run)
- .filter(or_(DR.run_id.is_(None),
- DR.run_type != DagRunType.BACKFILL_JOB))
+ .filter(or_(DR.run_id.is_(None), DR.run_type != DagRunType.BACKFILL_JOB))
.join(TI.dag_model)
.filter(not_(DM.is_paused))
.filter(TI.state == State.SCHEDULED)
@@ -952,12 +929,8 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
return executable_tis
# Put one task instance on each line
- task_instance_str = "\n\t".join(
- [repr(x) for x in task_instances_to_examine])
- self.log.info(
- "%s tasks up for execution:\n\t%s", len(task_instances_to_examine),
- task_instance_str
- )
+ task_instance_str = "\n\t".join([repr(x) for x in task_instances_to_examine])
+ self.log.info("%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str)
pool_to_task_instances: DefaultDict[str, List[models.Pool]] = defaultdict(list)
for task_instance in task_instances_to_examine:
@@ -967,7 +940,8 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
dag_concurrency_map: DefaultDict[str, int]
task_concurrency_map: DefaultDict[Tuple[str, str], int]
dag_concurrency_map, task_concurrency_map = self.__get_concurrency_maps(
- states=list(EXECUTION_STATES), session=session)
+ states=list(EXECUTION_STATES), session=session
+ )
num_tasks_in_executor = 0
# Number of tasks that cannot be scheduled because of no open slot in pool
@@ -979,10 +953,7 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
for pool, task_instances in pool_to_task_instances.items():
pool_name = pool
if pool not in pools:
- self.log.warning(
- "Tasks using non-existent pool '%s' will not be scheduled",
- pool
- )
+ self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool)
continue
open_slots = pools[pool]["open"]
@@ -991,19 +962,19 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
self.log.info(
"Figuring out tasks to run in Pool(name=%s) with %s open slots "
"and %s task instances ready to be queued",
- pool, open_slots, num_ready
+ pool,
+ open_slots,
+ num_ready,
)
priority_sorted_task_instances = sorted(
- task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date))
+ task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date)
+ )
num_starving_tasks = 0
for current_index, task_instance in enumerate(priority_sorted_task_instances):
if open_slots <= 0:
- self.log.info(
- "Not scheduling since there are %s open slots in pool %s",
- open_slots, pool
- )
+ self.log.info("Not scheduling since there are %s open slots in pool %s", open_slots, pool)
# Can't schedule any more since there are no more open slots.
num_unhandled = len(priority_sorted_task_instances) - current_index
num_starving_tasks += num_unhandled
@@ -1018,13 +989,17 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
dag_concurrency_limit = task_instance.dag_model.concurrency
self.log.info(
"DAG %s has %s/%s running and queued tasks",
- dag_id, current_dag_concurrency, dag_concurrency_limit
+ dag_id,
+ current_dag_concurrency,
+ dag_concurrency_limit,
)
if current_dag_concurrency >= dag_concurrency_limit:
self.log.info(
"Not executing %s since the number of tasks running or queued "
"from DAG %s is >= to the DAG's task concurrency limit of %s",
- task_instance, dag_id, dag_concurrency_limit
+ task_instance,
+ dag_id,
+ dag_concurrency_limit,
)
continue
@@ -1035,7 +1010,8 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
serialized_dag = self.dagbag.get_dag(dag_id, session=session)
if serialized_dag.has_task(task_instance.task_id):
task_concurrency_limit = serialized_dag.get_task(
- task_instance.task_id).task_concurrency
+ task_instance.task_id
+ ).task_concurrency
if task_concurrency_limit is not None:
current_task_concurrency = task_concurrency_map[
@@ -1043,14 +1019,22 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
]
if current_task_concurrency >= task_concurrency_limit:
- self.log.info("Not executing %s since the task concurrency for"
- " this task has been reached.", task_instance)
+ self.log.info(
+ "Not executing %s since the task concurrency for"
+ " this task has been reached.",
+ task_instance,
+ )
continue
if task_instance.pool_slots > open_slots:
- self.log.info("Not executing %s since it requires %s slots "
- "but there are %s open slots in the pool %s.",
- task_instance, task_instance.pool_slots, open_slots, pool)
+ self.log.info(
+ "Not executing %s since it requires %s slots "
+ "but there are %s open slots in the pool %s.",
+ task_instance,
+ task_instance.pool_slots,
+ open_slots,
+ pool,
+ )
num_starving_tasks += 1
num_starving_tasks_total += 1
# Though we can execute tasks with lower priority if there's enough room
@@ -1067,10 +1051,8 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
Stats.gauge('scheduler.tasks.running', num_tasks_in_executor)
Stats.gauge('scheduler.tasks.executable', len(executable_tis))
- task_instance_str = "\n\t".join(
- [repr(x) for x in executable_tis])
- self.log.info(
- "Setting the following tasks to queued state:\n\t%s", task_instance_str)
+ task_instance_str = "\n\t".join([repr(x) for x in executable_tis])
+ self.log.info("Setting the following tasks to queued state:\n\t%s", task_instance_str)
# set TIs to queued state
filter_for_tis = TI.filter_for_tis(executable_tis)
@@ -1078,17 +1060,14 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
# TODO[ha]: should we use func.now()? How does that work with DB timezone on mysql when it's not
# UTC?
{TI.state: State.QUEUED, TI.queued_dttm: timezone.utcnow(), TI.queued_by_job_id: self.id},
- synchronize_session=False
+ synchronize_session=False,
)
for ti in executable_tis:
make_transient(ti)
return executable_tis
- def _enqueue_task_instances_with_queued_state(
- self,
- task_instances: List[TI]
- ) -> None:
+ def _enqueue_task_instances_with_queued_state(self, task_instances: List[TI]) -> None:
"""
Takes task_instances, which should have been set to queued, and enqueues them
with the executor.
@@ -1115,10 +1094,7 @@ def _enqueue_task_instances_with_queued_state(
priority = ti.priority_weight
queue = ti.queue
- self.log.info(
- "Sending %s to executor with priority %s and queue %s",
- ti.key, priority, queue
- )
+ self.log.info("Sending %s to executor with priority %s and queue %s", ti.key, priority, queue)
self.executor.queue_command(
ti,
@@ -1164,17 +1140,18 @@ def _change_state_for_tasks_failed_to_execute(self, session: Session = None):
if not self.executor.queued_tasks:
return
- filter_for_ti_state_change = (
- [and_(
+ filter_for_ti_state_change = [
+ and_(
TI.dag_id == dag_id,
TI.task_id == task_id,
TI.execution_date == execution_date,
# The TI.try_number will return raw try_number+1 since the
# ti is not running. And we need to -1 to match the DB record.
TI._try_number == try_number - 1, # pylint: disable=protected-access
- TI.state == State.QUEUED)
- for dag_id, task_id, execution_date, try_number
- in self.executor.queued_tasks.keys()])
+ TI.state == State.QUEUED,
+ )
+ for dag_id, task_id, execution_date, try_number in self.executor.queued_tasks.keys()
+ ]
ti_query = session.query(TI).filter(or_(*filter_for_ti_state_change))
tis_to_set_to_scheduled: List[TI] = with_row_locks(ti_query).all()
if not tis_to_set_to_scheduled:
@@ -1211,7 +1188,11 @@ def _process_executor_events(self, session: Session = None) -> int:
self.log.info(
"Executor reports execution of %s.%s execution_date=%s "
"exited with status %s for try_number %s",
- ti_key.dag_id, ti_key.task_id, ti_key.execution_date, state, ti_key.try_number
+ ti_key.dag_id,
+ ti_key.task_id,
+ ti_key.execution_date,
+ state,
+ ti_key.try_number,
)
if state in (State.FAILED, State.SUCCESS, State.QUEUED):
tis_with_right_state.append(ti_key)
@@ -1236,8 +1217,10 @@ def _process_executor_events(self, session: Session = None) -> int:
if ti.try_number == buffer_key.try_number and ti.state == State.QUEUED:
Stats.incr('scheduler.tasks.killed_externally')
- msg = "Executor reports task instance %s finished (%s) although the " \
- "task says its %s. (Info: %s) Was the task killed externally?"
+ msg = (
+ "Executor reports task instance %s finished (%s) although the "
+ "task says its %s. (Info: %s) Was the task killed externally?"
+ )
self.log.error(msg, ti, state, ti.state, info)
request = TaskCallbackRequest(
full_filepath=ti.dag_model.fileloc,
@@ -1297,8 +1280,7 @@ def _execute(self) -> None:
# deleted.
if self.processor_agent.all_files_processed:
self.log.info(
- "Deactivating DAGs that haven't been touched since %s",
- execute_start_time.isoformat()
+ "Deactivating DAGs that haven't been touched since %s", execute_start_time.isoformat()
)
models.DAG.deactivate_stale_dags(execute_start_time)
@@ -1316,14 +1298,11 @@ def _create_dag_file_processor(
file_path: str,
callback_requests: List[CallbackRequest],
dag_ids: Optional[List[str]],
- pickle_dags: bool
+ pickle_dags: bool,
) -> DagFileProcessorProcess:
"""Creates DagFileProcessorProcess instance."""
return DagFileProcessorProcess(
- file_path=file_path,
- pickle_dags=pickle_dags,
- dag_ids=dag_ids,
- callback_requests=callback_requests
+ file_path=file_path, pickle_dags=pickle_dags, dag_ids=dag_ids, callback_requests=callback_requests
)
def _run_scheduler_loop(self) -> None:
@@ -1384,14 +1363,16 @@ def _run_scheduler_loop(self) -> None:
if loop_count >= self.num_runs > 0:
self.log.info(
"Exiting scheduler loop as requested number of runs (%d - got to %d) has been reached",
- self.num_runs, loop_count,
+ self.num_runs,
+ loop_count,
)
break
if self.processor_agent.done:
self.log.info(
"Exiting scheduler loop as requested DAG parse count (%d) has been reached after %d "
" scheduler loops",
- self.num_times_parse_dags, loop_count,
+ self.num_times_parse_dags,
+ loop_count,
)
break
@@ -1455,13 +1436,17 @@ def _do_scheduling(self, session) -> int:
# dag_id -- only tasks from those runs will be scheduled.
active_runs_by_dag_id = defaultdict(set)
- query = session.query(
- TI.dag_id,
- TI.execution_date,
- ).filter(
- TI.dag_id.in_(list({dag_run.dag_id for dag_run in dag_runs})),
- TI.state.notin_(list(State.finished))
- ).group_by(TI.dag_id, TI.execution_date)
+ query = (
+ session.query(
+ TI.dag_id,
+ TI.execution_date,
+ )
+ .filter(
+ TI.dag_id.in_(list({dag_run.dag_id for dag_run in dag_runs})),
+ TI.state.notin_(list(State.finished)),
+ )
+ .group_by(TI.dag_id, TI.execution_date)
+ )
for dag_id, execution_date in query:
active_runs_by_dag_id[dag_id].add(execution_date)
@@ -1478,18 +1463,13 @@ def _do_scheduling(self, session) -> int:
# TODO[HA]: Do we need to do it every time?
try:
self._change_state_for_tis_without_dagrun(
- old_states=[State.UP_FOR_RETRY],
- new_state=State.FAILED,
- session=session
+ old_states=[State.UP_FOR_RETRY], new_state=State.FAILED, session=session
)
self._change_state_for_tis_without_dagrun(
- old_states=[State.QUEUED,
- State.SCHEDULED,
- State.UP_FOR_RESCHEDULE,
- State.SENSING],
+ old_states=[State.QUEUED, State.SCHEDULED, State.UP_FOR_RESCHEDULE, State.SENSING],
new_state=State.NONE,
- session=session
+ session=session,
)
guard.commit()
@@ -1560,11 +1540,16 @@ def _update_dag_next_dagruns(self, dag_models: Iterable[DagModel], session: Sess
"""
# Check max_active_runs, to see if we are _now_ at the limit for any of
# these dag? (we've just created a DagRun for them after all)
- active_runs_of_dags = dict(session.query(DagRun.dag_id, func.count('*')).filter(
- DagRun.dag_id.in_([o.dag_id for o in dag_models]),
- DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable
- DagRun.external_trigger.is_(False),
- ).group_by(DagRun.dag_id).all())
+ active_runs_of_dags = dict(
+ session.query(DagRun.dag_id, func.count('*'))
+ .filter(
+ DagRun.dag_id.in_([o.dag_id for o in dag_models]),
+ DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable
+ DagRun.external_trigger.is_(False),
+ )
+ .group_by(DagRun.dag_id)
+ .all()
+ )
for dag_model in dag_models:
dag = self.dagbag.get_dag(dag_model.dag_id, session=session)
@@ -1572,12 +1557,15 @@ def _update_dag_next_dagruns(self, dag_models: Iterable[DagModel], session: Sess
if dag.max_active_runs and active_runs_of_dag >= dag.max_active_runs:
self.log.info(
"DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs",
- dag.dag_id, active_runs_of_dag, dag.max_active_runs
+ dag.dag_id,
+ active_runs_of_dag,
+ dag.max_active_runs,
)
dag_model.next_dagrun_create_after = None
else:
- dag_model.next_dagrun, dag_model.next_dagrun_create_after = \
- dag.next_dagrun_info(dag_model.next_dagrun)
+ dag_model.next_dagrun, dag_model.next_dagrun_create_after = dag.next_dagrun_info(
+ dag_model.next_dagrun
+ )
def _schedule_dag_run(
self,
@@ -1598,14 +1586,13 @@ def _schedule_dag_run(
dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
if not dag:
- self.log.error(
- "Couldn't find dag %s in DagBag/DB!", dag_run.dag_id
- )
+ self.log.error("Couldn't find dag %s in DagBag/DB!", dag_run.dag_id)
return 0
if (
- dag_run.start_date and dag.dagrun_timeout and
- dag_run.start_date < timezone.utcnow() - dag.dagrun_timeout
+ dag_run.start_date
+ and dag.dagrun_timeout
+ and dag_run.start_date < timezone.utcnow() - dag.dagrun_timeout
):
dag_run.state = State.FAILED
dag_run.end_date = timezone.utcnow()
@@ -1620,7 +1607,7 @@ def _schedule_dag_run(
dag_id=dag.dag_id,
execution_date=dag_run.execution_date,
is_failure_callback=True,
- msg='timed_out'
+ msg='timed_out',
)
# Send SLA & DAG Success/Failure Callbacks to be executed
@@ -1629,15 +1616,14 @@ def _schedule_dag_run(
return 0
if dag_run.execution_date > timezone.utcnow() and not dag.allow_future_exec_dates:
- self.log.error(
- "Execution date is in future: %s",
- dag_run.execution_date
- )
+ self.log.error("Execution date is in future: %s", dag_run.execution_date)
return 0
if dag.max_active_runs:
- if len(currently_active_runs) >= dag.max_active_runs and \
- dag_run.execution_date not in currently_active_runs:
+ if (
+ len(currently_active_runs) >= dag.max_active_runs
+ and dag_run.execution_date not in currently_active_runs
+ ):
self.log.info(
"DAG %s already has %d active runs, not queuing any tasks for run %s",
dag.dag_id,
@@ -1675,9 +1661,7 @@ def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None):
dag_run.verify_integrity(session=session)
def _send_dag_callbacks_to_processor(
- self,
- dag_run: DagRun,
- callback: Optional[DagCallbackRequest] = None
+ self, dag_run: DagRun, callback: Optional[DagCallbackRequest] = None
):
if not self.processor_agent:
raise ValueError("Processor agent is not started.")
@@ -1700,8 +1684,7 @@ def _send_sla_callbacks_to_processor(self, dag: DAG):
raise ValueError("Processor agent is not started.")
self.processor_agent.send_sla_callback_request_to_execute(
- full_filepath=dag.fileloc,
- dag_id=dag.dag_id
+ full_filepath=dag.fileloc, dag_id=dag.dag_id
)
@provide_session
@@ -1727,10 +1710,14 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None):
"""
timeout = conf.getint('scheduler', 'scheduler_health_check_threshold')
- num_failed = session.query(SchedulerJob).filter(
- SchedulerJob.state == State.RUNNING,
- SchedulerJob.latest_heartbeat < (timezone.utcnow() - timedelta(seconds=timeout))
- ).update({"state": State.FAILED})
+ num_failed = (
+ session.query(SchedulerJob)
+ .filter(
+ SchedulerJob.state == State.RUNNING,
+ SchedulerJob.latest_heartbeat < (timezone.utcnow() - timedelta(seconds=timeout)),
+ )
+ .update({"state": State.FAILED})
+ )
if num_failed:
self.log.info("Marked %d SchedulerJob instances as failed", num_failed)
@@ -1738,7 +1725,8 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None):
resettable_states = [State.SCHEDULED, State.QUEUED, State.RUNNING]
query = (
- session.query(TI).filter(TI.state.in_(resettable_states))
+ session.query(TI)
+ .filter(TI.state.in_(resettable_states))
# outerjoin is because we didn't use to have queued_by_job
# set, so we need to pick up anything pre upgrade. This (and the
# "or queued_by_job_id IS NONE") can go as soon as scheduler HA is
@@ -1746,9 +1734,11 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None):
.outerjoin(TI.queued_by_job)
.filter(or_(TI.queued_by_job_id.is_(None), SchedulerJob.state != State.RUNNING))
.join(TI.dag_run)
- .filter(DagRun.run_type != DagRunType.BACKFILL_JOB,
- # pylint: disable=comparison-with-callable
- DagRun.state == State.RUNNING)
+ .filter(
+ DagRun.run_type != DagRunType.BACKFILL_JOB,
+ # pylint: disable=comparison-with-callable
+ DagRun.state == State.RUNNING,
+ )
.options(load_only(TI.dag_id, TI.task_id, TI.execution_date))
)
@@ -1770,8 +1760,9 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None):
if to_reset:
task_instance_str = '\n\t'.join(reset_tis_message)
- self.log.info("Reset the following %s orphaned TaskInstances:\n\t%s",
- len(to_reset), task_instance_str)
+ self.log.info(
+ "Reset the following %s orphaned TaskInstances:\n\t%s", len(to_reset), task_instance_str
+ )
# Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller
# decide when to commit
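Most of the scheduler_job.py hunks rewrite long SQLAlchemy chains the same way: a backslash-continued statement becomes a single parenthesized expression with one chained call (.query, .filter, .group_by, .all, .update) per line. The toy example below mirrors that shape against an in-memory SQLite database; the model and column names are made up for illustration and the code assumes SQLAlchemy 1.4 or newer.

import sqlalchemy as sa
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class TaskRow(Base):
    __tablename__ = "task_row"
    id = sa.Column(sa.Integer, primary_key=True)
    dag_id = sa.Column(sa.String)
    state = sa.Column(sa.String)

engine = sa.create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(
        [
            TaskRow(dag_id="example_dag", state="success"),
            TaskRow(dag_id="example_dag", state="failed"),
        ]
    )
    session.commit()
    # One chained call per line inside a single pair of parentheses,
    # the same layout the reformatted queries above use.
    counts = (
        session.query(TaskRow.dag_id, sa.func.count("*"))
        .filter(TaskRow.state == "success")
        .group_by(TaskRow.dag_id)
        .all()
    )
    print(counts)  # [('example_dag', 1)]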
diff --git a/airflow/kubernetes/kube_client.py b/airflow/kubernetes/kube_client.py
index e51274feec46f..7e8c5e83d7f3c 100644
--- a/airflow/kubernetes/kube_client.py
+++ b/airflow/kubernetes/kube_client.py
@@ -26,13 +26,15 @@
from kubernetes.client.rest import ApiException # pylint: disable=unused-import
from airflow.kubernetes.refresh_config import ( # pylint: disable=ungrouped-imports
- RefreshConfiguration, load_kube_config,
+ RefreshConfiguration,
+ load_kube_config,
)
+
has_kubernetes = True
- def _get_kube_config(in_cluster: bool,
- cluster_context: Optional[str],
- config_file: Optional[str]) -> Optional[Configuration]:
+ def _get_kube_config(
+ in_cluster: bool, cluster_context: Optional[str], config_file: Optional[str]
+ ) -> Optional[Configuration]:
if in_cluster:
# load_incluster_config set default configuration with config populated by k8s
config.load_incluster_config()
@@ -41,8 +43,7 @@ def _get_kube_config(in_cluster: bool,
# this block can be replaced with just config.load_kube_config once
# refresh_config module is replaced with upstream fix
cfg = RefreshConfiguration()
- load_kube_config(
- client_configuration=cfg, config_file=config_file, context=cluster_context)
+ load_kube_config(client_configuration=cfg, config_file=config_file, context=cluster_context)
return cfg
def _get_client_with_patched_configuration(cfg: Optional[Configuration]) -> client.CoreV1Api:
@@ -57,6 +58,7 @@ def _get_client_with_patched_configuration(cfg: Optional[Configuration]) -> clie
else:
return client.CoreV1Api()
+
except ImportError as e:
# We need an exception class to be able to use it in ``except`` elsewhere
# in the code base
@@ -92,9 +94,11 @@ def _enable_tcp_keepalive() -> None:
HTTPConnection.default_socket_options = HTTPConnection.default_socket_options + socket_options
-def get_kube_client(in_cluster: bool = conf.getboolean('kubernetes', 'in_cluster'),
- cluster_context: Optional[str] = None,
- config_file: Optional[str] = None) -> client.CoreV1Api:
+def get_kube_client(
+ in_cluster: bool = conf.getboolean('kubernetes', 'in_cluster'),
+ cluster_context: Optional[str] = None,
+ config_file: Optional[str] = None,
+) -> client.CoreV1Api:
"""
Retrieves Kubernetes client
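The get_kube_client and _get_kube_config signatures above are rewrapped so that each parameter gets its own line, the last parameter keeps a trailing comma, and the closing parenthesis returns to the indentation of the def. Assuming these hunks were produced by the Black formatter with a 110-character line limit (an assumption based on the style of the changes, not stated in this portion of the patch), the same shape can be reproduced with Black's Python API; string-quote normalisation in the output may differ depending on configuration.

# Hedged sketch: reproducing the one-parameter-per-line wrapping via Black's API.
# Requires the "black" package; format_str and FileMode are part of its public API.
import black

SRC = (
    "def get_kube_client(in_cluster: bool = conf.getboolean('kubernetes', 'in_cluster'),\n"
    "                    cluster_context: Optional[str] = None,\n"
    "                    config_file: Optional[str] = None) -> client.CoreV1Api:\n"
    "    ...\n"
)

print(black.format_str(SRC, mode=black.FileMode(line_length=110)))

Because the flattened signature would exceed the line limit, the formatter falls back to one parameter per line and appends the trailing comma.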
diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py
index bf54a6ac81f9d..c981780bddca2 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -50,14 +50,8 @@ class PodDefaults:
XCOM_MOUNT_PATH = '/airflow/xcom'
SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar'
XCOM_CMD = 'trap "exit 0" INT; while true; do sleep 30; done;'
- VOLUME_MOUNT = k8s.V1VolumeMount(
- name='xcom',
- mount_path=XCOM_MOUNT_PATH
- )
- VOLUME = k8s.V1Volume(
- name='xcom',
- empty_dir=k8s.V1EmptyDirVolumeSource()
- )
+ VOLUME_MOUNT = k8s.V1VolumeMount(name='xcom', mount_path=XCOM_MOUNT_PATH)
+ VOLUME = k8s.V1Volume(name='xcom', empty_dir=k8s.V1EmptyDirVolumeSource())
SIDECAR_CONTAINER = k8s.V1Container(
name=SIDECAR_CONTAINER_NAME,
command=['sh', '-c', XCOM_CMD],
@@ -85,7 +79,7 @@ def make_safe_label_value(string):
if len(safe_label) > MAX_LABEL_LEN or string != safe_label:
safe_hash = hashlib.md5(string.encode()).hexdigest()[:9]
- safe_label = safe_label[:MAX_LABEL_LEN - len(safe_hash) - 1] + "-" + safe_hash
+ safe_label = safe_label[: MAX_LABEL_LEN - len(safe_hash) - 1] + "-" + safe_hash
return safe_label
@@ -134,14 +128,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
self,
pod: Optional[k8s.V1Pod] = None,
pod_template_file: Optional[str] = None,
- extract_xcom: bool = True
+ extract_xcom: bool = True,
):
if not pod_template_file and not pod:
- raise AirflowConfigException("Podgenerator requires either a "
- "`pod` or a `pod_template_file` argument")
+ raise AirflowConfigException(
+ "Podgenerator requires either a " "`pod` or a `pod_template_file` argument"
+ )
if pod_template_file and pod:
- raise AirflowConfigException("Cannot pass both `pod` "
- "and `pod_template_file` arguments")
+ raise AirflowConfigException("Cannot pass both `pod` " "and `pod_template_file` arguments")
if pod_template_file:
self.ud_pod = self.deserialize_model_file(pod_template_file)
@@ -184,25 +178,29 @@ def from_obj(obj) -> Optional[Union[dict, k8s.V1Pod]]:
k8s_object = obj.get("pod_override", None)
if k8s_legacy_object and k8s_object:
- raise AirflowConfigException("Can not have both a legacy and new"
- "executor_config object. Please delete the KubernetesExecutor"
- "dict and only use the pod_override kubernetes.client.models.V1Pod"
- "object.")
+ raise AirflowConfigException(
+ "Can not have both a legacy and new"
+ "executor_config object. Please delete the KubernetesExecutor"
+ "dict and only use the pod_override kubernetes.client.models.V1Pod"
+ "object."
+ )
if not k8s_object and not k8s_legacy_object:
return None
if isinstance(k8s_object, k8s.V1Pod):
return k8s_object
elif isinstance(k8s_legacy_object, dict):
- warnings.warn('Using a dictionary for the executor_config is deprecated and will soon be removed.'
- 'please use a `kubernetes.client.models.V1Pod` class with a "pod_override" key'
- ' instead. ',
- category=DeprecationWarning)
+ warnings.warn(
+ 'Using a dictionary for the executor_config is deprecated and will soon be removed.'
+ 'please use a `kubernetes.client.models.V1Pod` class with a "pod_override" key'
+ ' instead. ',
+ category=DeprecationWarning,
+ )
return PodGenerator.from_legacy_obj(obj)
else:
raise TypeError(
- 'Cannot convert a non-kubernetes.client.models.V1Pod'
- 'object into a KubernetesExecutorConfig')
+ 'Cannot convert a non-kubernetes.client.models.V1Pod' 'object into a KubernetesExecutorConfig'
+ )
@staticmethod
def from_legacy_obj(obj) -> Optional[k8s.V1Pod]:
@@ -223,12 +221,12 @@ def from_legacy_obj(obj) -> Optional[k8s.V1Pod]:
requests = {
'cpu': namespaced.pop('request_cpu', None),
'memory': namespaced.pop('request_memory', None),
- 'ephemeral-storage': namespaced.get('ephemeral-storage') # We pop this one in limits
+ 'ephemeral-storage': namespaced.get('ephemeral-storage'), # We pop this one in limits
}
limits = {
'cpu': namespaced.pop('limit_cpu', None),
'memory': namespaced.pop('limit_memory', None),
- 'ephemeral-storage': namespaced.pop('ephemeral-storage', None)
+ 'ephemeral-storage': namespaced.pop('ephemeral-storage', None),
}
all_resources = list(requests.values()) + list(limits.values())
if all(r is None for r in all_resources):
@@ -237,10 +235,7 @@ def from_legacy_obj(obj) -> Optional[k8s.V1Pod]:
# remove None's so they don't become 0's
requests = {k: v for k, v in requests.items() if v is not None}
limits = {k: v for k, v in limits.items() if v is not None}
- resources = k8s.V1ResourceRequirements(
- requests=requests,
- limits=limits
- )
+ resources = k8s.V1ResourceRequirements(requests=requests, limits=limits)
namespaced['resources'] = resources
return PodGeneratorDeprecated(**namespaced).gen_pod()
@@ -292,8 +287,9 @@ def reconcile_metadata(base_meta, client_meta):
return None
@staticmethod
- def reconcile_specs(base_spec: Optional[k8s.V1PodSpec],
- client_spec: Optional[k8s.V1PodSpec]) -> Optional[k8s.V1PodSpec]:
+ def reconcile_specs(
+ base_spec: Optional[k8s.V1PodSpec], client_spec: Optional[k8s.V1PodSpec]
+ ) -> Optional[k8s.V1PodSpec]:
"""
:param base_spec: has the base attributes which are overwritten if they exist
in the client_spec and remain if they do not exist in the client_spec
@@ -316,8 +312,9 @@ def reconcile_specs(base_spec: Optional[k8s.V1PodSpec],
return None
@staticmethod
- def reconcile_containers(base_containers: List[k8s.V1Container],
- client_containers: List[k8s.V1Container]) -> List[k8s.V1Container]:
+ def reconcile_containers(
+ base_containers: List[k8s.V1Container], client_containers: List[k8s.V1Container]
+ ) -> List[k8s.V1Container]:
"""
:param base_containers: has the base attributes which are overwritten if they exist
in the client_containers and remain if they do not exist in the client_containers
@@ -358,7 +355,7 @@ def construct_pod( # pylint: disable=too-many-arguments
pod_override_object: Optional[k8s.V1Pod],
base_worker_pod: k8s.V1Pod,
namespace: str,
- scheduler_job_id: str
+ scheduler_job_id: str,
) -> k8s.V1Pod:
"""
Construct a pod by gathering and consolidating the configuration from 3 places:
@@ -391,7 +388,7 @@ def construct_pod( # pylint: disable=too-many-arguments
'try_number': str(try_number),
'airflow_version': airflow_version.replace('+', '-'),
'kubernetes_executor': 'True',
- }
+ },
),
spec=k8s.V1PodSpec(
containers=[
@@ -401,7 +398,7 @@ def construct_pod( # pylint: disable=too-many-arguments
image=image,
)
]
- )
+ ),
)
# Reconcile the pods starting with the first chronologically,
@@ -449,8 +446,7 @@ def deserialize_model_dict(pod_dict: dict) -> k8s.V1Pod:
@return:
"""
api_client = ApiClient()
- return api_client._ApiClient__deserialize_model( # pylint: disable=W0212
- pod_dict, k8s.V1Pod)
+ return api_client._ApiClient__deserialize_model(pod_dict, k8s.V1Pod) # pylint: disable=W0212
@staticmethod
def make_unique_pod_id(pod_id):
@@ -466,7 +462,7 @@ def make_unique_pod_id(pod_id):
return None
safe_uuid = uuid.uuid4().hex
- safe_pod_id = pod_id[:MAX_POD_ID_LEN - len(safe_uuid) - 1] + "-" + safe_uuid
+ safe_pod_id = pod_id[: MAX_POD_ID_LEN - len(safe_uuid) - 1] + "-" + safe_uuid
return safe_pod_id
@@ -513,8 +509,9 @@ def extend_object_field(base_obj, client_obj, field_name):
base_obj_field = getattr(base_obj, field_name, None)
client_obj_field = getattr(client_obj, field_name, None)
- if (not isinstance(base_obj_field, list) and base_obj_field is not None) or \
- (not isinstance(client_obj_field, list) and client_obj_field is not None):
+ if (not isinstance(base_obj_field, list) and base_obj_field is not None) or (
+ not isinstance(client_obj_field, list) and client_obj_field is not None
+ ):
raise ValueError("The chosen field must be a list.")
if not base_obj_field:
diff --git a/airflow/kubernetes/pod_generator_deprecated.py b/airflow/kubernetes/pod_generator_deprecated.py
index cdf9a9182b99c..79bdcb4406e98 100644
--- a/airflow/kubernetes/pod_generator_deprecated.py
+++ b/airflow/kubernetes/pod_generator_deprecated.py
@@ -39,14 +39,8 @@ class PodDefaults:
XCOM_MOUNT_PATH = '/airflow/xcom'
SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar'
XCOM_CMD = 'trap "exit 0" INT; while true; do sleep 30; done;'
- VOLUME_MOUNT = k8s.V1VolumeMount(
- name='xcom',
- mount_path=XCOM_MOUNT_PATH
- )
- VOLUME = k8s.V1Volume(
- name='xcom',
- empty_dir=k8s.V1EmptyDirVolumeSource()
- )
+ VOLUME_MOUNT = k8s.V1VolumeMount(name='xcom', mount_path=XCOM_MOUNT_PATH)
+ VOLUME = k8s.V1Volume(name='xcom', empty_dir=k8s.V1EmptyDirVolumeSource())
SIDECAR_CONTAINER = k8s.V1Container(
name=SIDECAR_CONTAINER_NAME,
command=['sh', '-c', XCOM_CMD],
@@ -74,7 +68,7 @@ def make_safe_label_value(string):
if len(safe_label) > MAX_LABEL_LEN or string != safe_label:
safe_hash = hashlib.md5(string.encode()).hexdigest()[:9]
- safe_label = safe_label[:MAX_LABEL_LEN - len(safe_hash) - 1] + "-" + safe_hash
+ safe_label = safe_label[: MAX_LABEL_LEN - len(safe_hash) - 1] + "-" + safe_hash
return safe_label
@@ -199,21 +193,16 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
if envs:
if isinstance(envs, dict):
for key, val in envs.items():
- self.container.env.append(k8s.V1EnvVar(
- name=key,
- value=val
- ))
+ self.container.env.append(k8s.V1EnvVar(name=key, value=val))
elif isinstance(envs, list):
self.container.env.extend(envs)
configmaps = configmaps or []
self.container.env_from = []
for configmap in configmaps:
- self.container.env_from.append(k8s.V1EnvFromSource(
- config_map_ref=k8s.V1ConfigMapEnvSource(
- name=configmap
- )
- ))
+ self.container.env_from.append(
+ k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name=configmap))
+ )
self.container.command = cmds or []
self.container.args = args or []
@@ -241,9 +230,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
if image_pull_secrets:
for image_pull_secret in image_pull_secrets.split(','):
- self.spec.image_pull_secrets.append(k8s.V1LocalObjectReference(
- name=image_pull_secret
- ))
+ self.spec.image_pull_secrets.append(k8s.V1LocalObjectReference(name=image_pull_secret))
# Attach sidecar
self.extract_xcom = extract_xcom
@@ -289,7 +276,8 @@ def from_obj(obj) -> Optional[k8s.V1Pod]:
if not isinstance(obj, dict):
raise TypeError(
'Cannot convert a non-dictionary or non-PodGenerator '
- 'object into a KubernetesExecutorConfig')
+ 'object into a KubernetesExecutorConfig'
+ )
# We do not want to extract constant here from ExecutorLoader because it is just
# A name in dictionary rather than executor selection mechanism and it causes cyclic import
@@ -304,21 +292,18 @@ def from_obj(obj) -> Optional[k8s.V1Pod]:
requests = {
'cpu': namespaced.get('request_cpu'),
'memory': namespaced.get('request_memory'),
- 'ephemeral-storage': namespaced.get('ephemeral-storage')
+ 'ephemeral-storage': namespaced.get('ephemeral-storage'),
}
limits = {
'cpu': namespaced.get('limit_cpu'),
'memory': namespaced.get('limit_memory'),
- 'ephemeral-storage': namespaced.get('ephemeral-storage')
+ 'ephemeral-storage': namespaced.get('ephemeral-storage'),
}
all_resources = list(requests.values()) + list(limits.values())
if all(r is None for r in all_resources):
resources = None
else:
- resources = k8s.V1ResourceRequirements(
- requests=requests,
- limits=limits
- )
+ resources = k8s.V1ResourceRequirements(requests=requests, limits=limits)
namespaced['resources'] = resources
return PodGenerator(**namespaced).gen_pod()
@@ -336,6 +321,6 @@ def make_unique_pod_id(dag_id):
return None
safe_uuid = uuid.uuid4().hex
- safe_pod_id = dag_id[:MAX_POD_ID_LEN - len(safe_uuid) - 1] + "-" + safe_uuid
+ safe_pod_id = dag_id[: MAX_POD_ID_LEN - len(safe_uuid) - 1] + "-" + safe_uuid
return safe_pod_id
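Both pod_generator.py and pod_generator_deprecated.py pick up the same one-character change in their slicing expressions: dag_id[:MAX_POD_ID_LEN - len(safe_uuid) - 1] becomes dag_id[: MAX_POD_ID_LEN - len(safe_uuid) - 1]. When a slice bound is a compound expression, the colon is treated like a binary operator and gets a space on its open side; the change is purely cosmetic. A small check (MAX_POD_ID_LEN below is an illustrative value, not necessarily the constant used in the module):

import uuid

MAX_POD_ID_LEN = 253  # illustrative limit for this sketch

def make_unique_pod_id(dag_id: str) -> str:
    safe_uuid = uuid.uuid4().hex
    # Spaced and unspaced slice colons compile to the same operation.
    assert dag_id[: MAX_POD_ID_LEN - len(safe_uuid) - 1] == dag_id[:MAX_POD_ID_LEN - len(safe_uuid) - 1]
    return dag_id[: MAX_POD_ID_LEN - len(safe_uuid) - 1] + "-" + safe_uuid

pod_id = make_unique_pod_id("example-dag")
assert pod_id.startswith("example-dag-") and len(pod_id) <= MAX_POD_ID_LEN
print(pod_id)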
diff --git a/airflow/kubernetes/pod_launcher.py b/airflow/kubernetes/pod_launcher.py
index e528e2c4ae3a0..08d0bf4f37de9 100644
--- a/airflow/kubernetes/pod_launcher.py
+++ b/airflow/kubernetes/pod_launcher.py
@@ -49,11 +49,13 @@ class PodStatus:
class PodLauncher(LoggingMixin):
"""Launches PODS"""
- def __init__(self,
- kube_client: client.CoreV1Api = None,
- in_cluster: bool = True,
- cluster_context: Optional[str] = None,
- extract_xcom: bool = False):
+ def __init__(
+ self,
+ kube_client: client.CoreV1Api = None,
+ in_cluster: bool = True,
+ cluster_context: Optional[str] = None,
+ extract_xcom: bool = False,
+ ):
"""
Creates the launcher.
@@ -63,8 +65,7 @@ def __init__(self,
:param extract_xcom: whether we should extract xcom
"""
super().__init__()
- self._client = kube_client or get_kube_client(in_cluster=in_cluster,
- cluster_context=cluster_context)
+ self._client = kube_client or get_kube_client(in_cluster=in_cluster, cluster_context=cluster_context)
self._watch = watch.Watch()
self.extract_xcom = extract_xcom
@@ -77,12 +78,12 @@ def run_pod_async(self, pod: V1Pod, **kwargs):
self.log.debug('Pod Creation Request: \n%s', json_pod)
try:
- resp = self._client.create_namespaced_pod(body=sanitized_pod,
- namespace=pod.metadata.namespace, **kwargs)
+ resp = self._client.create_namespaced_pod(
+ body=sanitized_pod, namespace=pod.metadata.namespace, **kwargs
+ )
self.log.debug('Pod Creation Response: %s', resp)
except Exception as e:
- self.log.exception('Exception when attempting '
- 'to create Namespaced Pod: %s', json_pod)
+ self.log.exception('Exception when attempting ' 'to create Namespaced Pod: %s', json_pod)
raise e
return resp
@@ -90,16 +91,14 @@ def delete_pod(self, pod: V1Pod):
"""Deletes POD"""
try:
self._client.delete_namespaced_pod(
- pod.metadata.name, pod.metadata.namespace, body=client.V1DeleteOptions())
+ pod.metadata.name, pod.metadata.namespace, body=client.V1DeleteOptions()
+ )
except ApiException as e:
# If the pod is already deleted
if e.status != 404:
raise
- def start_pod(
- self,
- pod: V1Pod,
- startup_timeout: int = 120):
+ def start_pod(self, pod: V1Pod, startup_timeout: int = 120):
"""
Launches the pod synchronously and waits for completion.
@@ -170,13 +169,11 @@ def parse_log_line(self, line: str) -> Tuple[str, str]:
if split_at == -1:
raise Exception(f'Log not in "{{timestamp}} {{log}}" format. Got: {line}')
timestamp = line[:split_at]
- message = line[split_at + 1:].rstrip()
+ message = line[split_at + 1 :].rstrip()
return timestamp, message
def _task_status(self, event):
- self.log.info(
- 'Event: %s had an event of type %s',
- event.metadata.name, event.status.phase)
+ self.log.info('Event: %s had an event of type %s', event.metadata.name, event.status.phase)
status = self.process_status(event.metadata.name, event.status.phase)
return status
@@ -193,22 +190,19 @@ def pod_is_running(self, pod: V1Pod):
def base_container_is_running(self, pod: V1Pod):
"""Tests if base container is running"""
event = self.read_pod(pod)
- status = next(iter(filter(lambda s: s.name == 'base',
- event.status.container_statuses)), None)
+ status = next(iter(filter(lambda s: s.name == 'base', event.status.container_statuses)), None)
if not status:
return False
return status.state.running is not None
- @tenacity.retry(
- stop=tenacity.stop_after_attempt(3),
- wait=tenacity.wait_exponential(),
- reraise=True
- )
- def read_pod_logs(self,
- pod: V1Pod,
- tail_lines: Optional[int] = None,
- timestamps: bool = False,
- since_seconds: Optional[int] = None):
+ @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
+ def read_pod_logs(
+ self,
+ pod: V1Pod,
+ tail_lines: Optional[int] = None,
+ timestamps: bool = False,
+ since_seconds: Optional[int] = None,
+ ):
"""Reads log from the POD"""
additional_kwargs = {}
if since_seconds:
@@ -225,54 +219,44 @@ def read_pod_logs(self,
follow=True,
timestamps=timestamps,
_preload_content=False,
- **additional_kwargs
+ **additional_kwargs,
)
except BaseHTTPError as e:
- raise AirflowException(
- f'There was an error reading the kubernetes API: {e}'
- )
+ raise AirflowException(f'There was an error reading the kubernetes API: {e}')
- @tenacity.retry(
- stop=tenacity.stop_after_attempt(3),
- wait=tenacity.wait_exponential(),
- reraise=True
- )
+ @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
def read_pod_events(self, pod):
"""Reads events from the POD"""
try:
return self._client.list_namespaced_event(
- namespace=pod.metadata.namespace,
- field_selector=f"involvedObject.name={pod.metadata.name}"
+ namespace=pod.metadata.namespace, field_selector=f"involvedObject.name={pod.metadata.name}"
)
except BaseHTTPError as e:
- raise AirflowException(
- f'There was an error reading the kubernetes API: {e}'
- )
+ raise AirflowException(f'There was an error reading the kubernetes API: {e}')
- @tenacity.retry(
- stop=tenacity.stop_after_attempt(3),
- wait=tenacity.wait_exponential(),
- reraise=True
- )
+ @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
def read_pod(self, pod: V1Pod):
"""Read POD information"""
try:
return self._client.read_namespaced_pod(pod.metadata.name, pod.metadata.namespace)
except BaseHTTPError as e:
- raise AirflowException(
- f'There was an error reading the kubernetes API: {e}'
- )
+ raise AirflowException(f'There was an error reading the kubernetes API: {e}')
def _extract_xcom(self, pod: V1Pod):
- resp = kubernetes_stream(self._client.connect_get_namespaced_pod_exec,
- pod.metadata.name, pod.metadata.namespace,
- container=PodDefaults.SIDECAR_CONTAINER_NAME,
- command=['/bin/sh'], stdin=True, stdout=True,
- stderr=True, tty=False,
- _preload_content=False)
+ resp = kubernetes_stream(
+ self._client.connect_get_namespaced_pod_exec,
+ pod.metadata.name,
+ pod.metadata.namespace,
+ container=PodDefaults.SIDECAR_CONTAINER_NAME,
+ command=['/bin/sh'],
+ stdin=True,
+ stdout=True,
+ stderr=True,
+ tty=False,
+ _preload_content=False,
+ )
try:
- result = self._exec_pod_command(
- resp, f'cat {PodDefaults.XCOM_MOUNT_PATH}/return.json')
+ result = self._exec_pod_command(resp, f'cat {PodDefaults.XCOM_MOUNT_PATH}/return.json')
self._exec_pod_command(resp, 'kill -s SIGINT 1')
finally:
resp.close()
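In read_pod_logs and _extract_xcom above, calls that no longer fit on one line are exploded to one argument per line, and the final argument keeps a trailing comma even when it is a keyword or a **-unpacking (for example **additional_kwargs,). That trailing comma is valid Python 3 call syntax and is what keeps the call in its exploded form on later formatter runs. A minimal standalone illustration with a stand-in function:

def call(*args, **kwargs):
    return args, kwargs

extra = {"_preload_content": False}
args, kwargs = call(
    "example-pod",
    "default",
    command=["/bin/sh"],
    stdin=True,
    tty=False,
    **extra,  # a trailing comma after ** unpacking is allowed in Python 3.6+
)
print(kwargs)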
diff --git a/airflow/kubernetes/pod_runtime_info_env.py b/airflow/kubernetes/pod_runtime_info_env.py
index 57ac8de4bd33d..fa841cdacd15b 100644
--- a/airflow/kubernetes/pod_runtime_info_env.py
+++ b/airflow/kubernetes/pod_runtime_info_env.py
@@ -42,11 +42,7 @@ def to_k8s_client_obj(self) -> k8s.V1EnvVar:
""":return: kubernetes.client.models.V1EnvVar"""
return k8s.V1EnvVar(
name=self.name,
- value_from=k8s.V1EnvVarSource(
- field_ref=k8s.V1ObjectFieldSelector(
- field_path=self.field_path
- )
- )
+ value_from=k8s.V1EnvVarSource(field_ref=k8s.V1ObjectFieldSelector(field_path=self.field_path)),
)
def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod:
diff --git a/airflow/kubernetes/refresh_config.py b/airflow/kubernetes/refresh_config.py
index 53b7decfd1207..9067cb15da71b 100644
--- a/airflow/kubernetes/refresh_config.py
+++ b/airflow/kubernetes/refresh_config.py
@@ -104,7 +104,8 @@ def _get_kube_config_loader_for_yaml_file(filename, **kwargs) -> Optional[Refres
return RefreshKubeConfigLoader(
config_dict=yaml.safe_load(f),
config_base_path=os.path.abspath(os.path.dirname(filename)),
- **kwargs)
+ **kwargs,
+ )
def load_kube_config(client_configuration, config_file=None, context=None):
@@ -117,6 +118,5 @@ def load_kube_config(client_configuration, config_file=None, context=None):
if config_file is None:
config_file = os.path.expanduser(KUBE_CONFIG_DEFAULT_LOCATION)
- loader = _get_kube_config_loader_for_yaml_file(
- config_file, active_context=context, config_persister=None)
+ loader = _get_kube_config_loader_for_yaml_file(config_file, active_context=context, config_persister=None)
loader.load_and_set(client_configuration)
diff --git a/airflow/kubernetes/secret.py b/airflow/kubernetes/secret.py
index 197464f562a78..20ed27b1ffb31 100644
--- a/airflow/kubernetes/secret.py
+++ b/airflow/kubernetes/secret.py
@@ -62,9 +62,7 @@ def __init__(self, deploy_type, deploy_target, secret, key=None, items=None):
self.deploy_target = deploy_target.upper()
if key is not None and deploy_target is None:
- raise AirflowConfigException(
- 'If `key` is set, `deploy_target` should not be None'
- )
+ raise AirflowConfigException('If `key` is set, `deploy_target` should not be None')
self.secret = secret
self.key = key
@@ -74,18 +72,13 @@ def to_env_secret(self) -> k8s.V1EnvVar:
return k8s.V1EnvVar(
name=self.deploy_target,
value_from=k8s.V1EnvVarSource(
- secret_key_ref=k8s.V1SecretKeySelector(
- name=self.secret,
- key=self.key
- )
- )
+ secret_key_ref=k8s.V1SecretKeySelector(name=self.secret, key=self.key)
+ ),
)
def to_env_from_secret(self) -> k8s.V1EnvFromSource:
"""Reads from environment to secret"""
- return k8s.V1EnvFromSource(
- secret_ref=k8s.V1SecretEnvSource(name=self.secret)
- )
+ return k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name=self.secret))
def to_volume_secret(self) -> Tuple[k8s.V1Volume, k8s.V1VolumeMount]:
"""Converts to volume secret"""
@@ -93,14 +86,7 @@ def to_volume_secret(self) -> Tuple[k8s.V1Volume, k8s.V1VolumeMount]:
volume = k8s.V1Volume(name=vol_id, secret=k8s.V1SecretVolumeSource(secret_name=self.secret))
if self.items:
volume.secret.items = self.items
- return (
- volume,
- k8s.V1VolumeMount(
- mount_path=self.deploy_target,
- name=vol_id,
- read_only=True
- )
- )
+ return (volume, k8s.V1VolumeMount(mount_path=self.deploy_target, name=vol_id, read_only=True))
def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod:
"""Attaches to pod"""
@@ -123,16 +109,11 @@ def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod:
def __eq__(self, other):
return (
- self.deploy_type == other.deploy_type and
- self.deploy_target == other.deploy_target and
- self.secret == other.secret and
- self.key == other.key
+ self.deploy_type == other.deploy_type
+ and self.deploy_target == other.deploy_target
+ and self.secret == other.secret
+ and self.key == other.key
)
def __repr__(self):
- return 'Secret({}, {}, {}, {})'.format(
- self.deploy_type,
- self.deploy_target,
- self.secret,
- self.key
- )
+ return f'Secret({self.deploy_type}, {self.deploy_target}, {self.secret}, {self.key})'
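The last hunk in secret.py goes slightly beyond re-wrapping: the __repr__ body switches from str.format to an f-string. The output is identical, so the change is behaviour-preserving, but it is a rewrite rather than a pure layout change. A quick equivalence check with illustrative values (the values below are made up, not taken from Airflow):

deploy_type, deploy_target, secret, key = "env", "SQL_CONN", "airflow-secrets", "sql_alchemy_conn"

old = 'Secret({}, {}, {}, {})'.format(deploy_type, deploy_target, secret, key)
new = f'Secret({deploy_type}, {deploy_target}, {secret}, {key})'
assert old == new
print(new)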
diff --git a/airflow/lineage/__init__.py b/airflow/lineage/__init__.py
index 7641f39489ede..0f2a961e6dd2d 100644
--- a/airflow/lineage/__init__.py
+++ b/airflow/lineage/__init__.py
@@ -54,9 +54,14 @@ def _get_instance(meta: Metadata):
def _render_object(obj: Any, context) -> Any:
"""Renders a attr annotated object. Will set non serializable attributes to none"""
- return structure(json.loads(ENV.from_string(
- json.dumps(unstructure(obj), default=lambda o: None)
- ).render(**context).encode('utf-8')), type(obj))
+ return structure(
+ json.loads(
+ ENV.from_string(json.dumps(unstructure(obj), default=lambda o: None))
+ .render(**context)
+ .encode('utf-8')
+ ),
+ type(obj),
+ )
def _to_dataset(obj: Any, source: str) -> Optional[Metadata]:
@@ -81,26 +86,21 @@ def apply_lineage(func: T) -> T:
@wraps(func)
def wrapper(self, context, *args, **kwargs):
- self.log.debug("Lineage called with inlets: %s, outlets: %s",
- self.inlets, self.outlets)
+ self.log.debug("Lineage called with inlets: %s, outlets: %s", self.inlets, self.outlets)
ret_val = func(self, context, *args, **kwargs)
- outlets = [unstructure(_to_dataset(x, f"{self.dag_id}.{self.task_id}"))
- for x in self.outlets]
- inlets = [unstructure(_to_dataset(x, None))
- for x in self.inlets]
+ outlets = [unstructure(_to_dataset(x, f"{self.dag_id}.{self.task_id}")) for x in self.outlets]
+ inlets = [unstructure(_to_dataset(x, None)) for x in self.inlets]
if self.outlets:
- self.xcom_push(context,
- key=PIPELINE_OUTLETS,
- value=outlets,
- execution_date=context['ti'].execution_date)
+ self.xcom_push(
+ context, key=PIPELINE_OUTLETS, value=outlets, execution_date=context['ti'].execution_date
+ )
if self.inlets:
- self.xcom_push(context,
- key=PIPELINE_INLETS,
- value=inlets,
- execution_date=context['ti'].execution_date)
+ self.xcom_push(
+ context, key=PIPELINE_INLETS, value=inlets, execution_date=context['ti'].execution_date
+ )
return ret_val
@@ -123,27 +123,28 @@ def wrapper(self, context, *args, **kwargs):
self.log.debug("Preparing lineage inlets and outlets")
if isinstance(self._inlets, (str, Operator)) or attr.has(self._inlets):
- self._inlets = [self._inlets, ]
+ self._inlets = [
+ self._inlets,
+ ]
if self._inlets and isinstance(self._inlets, list):
# get task_ids that are specified as parameter and make sure they are upstream
- task_ids = set(
- filter(lambda x: isinstance(x, str) and x.lower() != AUTO, self._inlets)
- ).union(
- map(lambda op: op.task_id,
- filter(lambda op: isinstance(op, Operator), self._inlets))
- ).intersection(self.get_flat_relative_ids(upstream=True))
+ task_ids = (
+ set(filter(lambda x: isinstance(x, str) and x.lower() != AUTO, self._inlets))
+ .union(map(lambda op: op.task_id, filter(lambda op: isinstance(op, Operator), self._inlets)))
+ .intersection(self.get_flat_relative_ids(upstream=True))
+ )
# pick up unique direct upstream task_ids if AUTO is specified
if AUTO.upper() in self._inlets or AUTO.lower() in self._inlets:
task_ids = task_ids.union(task_ids.symmetric_difference(self.upstream_task_ids))
- _inlets = self.xcom_pull(context, task_ids=task_ids,
- dag_id=self.dag_id, key=PIPELINE_OUTLETS)
+ _inlets = self.xcom_pull(context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS)
# re-instantiate the obtained inlets
- _inlets = [_get_instance(structure(item, Metadata))
- for sublist in _inlets if sublist for item in sublist]
+ _inlets = [
+ _get_instance(structure(item, Metadata)) for sublist in _inlets if sublist for item in sublist
+ ]
self.inlets.extend(_inlets)
self.inlets.extend(self._inlets)
@@ -152,16 +153,16 @@ def wrapper(self, context, *args, **kwargs):
raise AttributeError("inlets is not a list, operator, string or attr annotated object")
if not isinstance(self._outlets, list):
- self._outlets = [self._outlets, ]
+ self._outlets = [
+ self._outlets,
+ ]
self.outlets.extend(self._outlets)
# render inlets and outlets
- self.inlets = [_render_object(i, context)
- for i in self.inlets if attr.has(i)]
+ self.inlets = [_render_object(i, context) for i in self.inlets if attr.has(i)]
- self.outlets = [_render_object(i, context)
- for i in self.outlets if attr.has(i)]
+ self.outlets = [_render_object(i, context) for i in self.outlets if attr.has(i)]
self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets)
return func(self, context, *args, **kwargs)
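The oddest-looking change in this file, [self._inlets, ] turning into a three-line list, is Black's "magic trailing comma": a collection written with a trailing comma is kept exploded, one element per line, no matter how short it is. A sketch showing both behaviours side by side (the magic trailing comma dates to the 20.8 betas, so any release from then on behaves this way):

import black

MODE = black.FileMode(line_length=110)

# The pre-existing trailing comma forces the one-element-per-line layout seen in the hunk ...
print(black.format_str("self._inlets = [self._inlets, ]\n", mode=MODE))
# ... while the same list without the trailing comma stays on a single line.
print(black.format_str("self._inlets = [self._inlets]\n", mode=MODE))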
diff --git a/airflow/lineage/entities.py b/airflow/lineage/entities.py
index 9ded4ff67c935..f2bad75796197 100644
--- a/airflow/lineage/entities.py
+++ b/airflow/lineage/entities.py
@@ -63,7 +63,7 @@ class Column:
# https://github.com/python/mypy/issues/6136 is resolved, use
# `attr.converters.default_if_none(default=False)`
# pylint: disable=missing-docstring
-def default_if_none(arg: Optional[bool]) -> bool: # noqa: D103
+def default_if_none(arg: Optional[bool]) -> bool: # noqa: D103
return arg or False
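The single change here, repeated across nearly every migration file below, is comment spacing: Black pads inline comments to the two spaces PEP 8 asks for, which is why each '# noqa: D103' marker shifts one column to the right. A minimal check (sketch, with a simplified signature):

import black

SRC = "def default_if_none(arg): # noqa: D103\n    return arg or False\n"
# Black leaves the comment text alone and only normalizes the spacing before the '#'.
print(black.format_str(SRC, mode=black.FileMode(line_length=110)))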
diff --git a/airflow/logging_config.py b/airflow/logging_config.py
index e2b6adb2537da..e827273ce5a1b 100644
--- a/airflow/logging_config.py
+++ b/airflow/logging_config.py
@@ -43,19 +43,12 @@ def configure_logging():
if not isinstance(logging_config, dict):
raise ValueError("Logging Config should be of dict type")
- log.info(
- 'Successfully imported user-defined logging config from %s',
- logging_class_path
- )
+ log.info('Successfully imported user-defined logging config from %s', logging_class_path)
except Exception as err:
# Import default logging configurations.
- raise ImportError(
- 'Unable to load custom logging from {} due to {}'
- .format(logging_class_path, err)
- )
+ raise ImportError(f'Unable to load custom logging from {logging_class_path} due to {err}')
else:
- logging_class_path = 'airflow.config_templates.' \
- 'airflow_local_settings.DEFAULT_LOGGING_CONFIG'
+ logging_class_path = 'airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG'
logging_config = import_string(logging_class_path)
log.debug('Unable to load custom logging, using default config instead')
@@ -73,7 +66,7 @@ def configure_logging():
return logging_class_path
-def validate_logging_config(logging_config): # pylint: disable=unused-argument
+def validate_logging_config(logging_config): # pylint: disable=unused-argument
"""Validate the provided Logging Config"""
# Now lets validate the other logging-related settings
task_log_reader = conf.get('logging', 'task_log_reader')
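As in secret.py, the f-string in the ImportError message and the merged literal for the default logging path are manual touch-ups riding along with the reformat: Black neither converts .format() calls nor joins implicitly concatenated string literals. The message itself is unchanged, which a tiny check with hypothetical values confirms:

# Hypothetical values, only to show the two spellings of the error message are identical.
logging_class_path = 'my.pkg.LOGGING_CONFIG'
err = ImportError('boom')

old = 'Unable to load custom logging from {} due to {}'.format(logging_class_path, err)
new = f'Unable to load custom logging from {logging_class_path} due to {err}'
assert old == new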
diff --git a/airflow/macros/hive.py b/airflow/macros/hive.py
index c6e1e200fc919..39c66cdec1229 100644
--- a/airflow/macros/hive.py
+++ b/airflow/macros/hive.py
@@ -20,8 +20,8 @@
def max_partition(
- table, schema="default", field=None, filter_map=None,
- metastore_conn_id='metastore_default'):
+ table, schema="default", field=None, filter_map=None, metastore_conn_id='metastore_default'
+):
"""
Gets the max partition for a table.
@@ -47,11 +47,11 @@ def max_partition(
'2015-01-01'
"""
from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
+
if '.' in table:
schema, table = table.split('.')
hive_hook = HiveMetastoreHook(metastore_conn_id=metastore_conn_id)
- return hive_hook.max_partition(
- schema=schema, table_name=table, field=field, filter_map=filter_map)
+ return hive_hook.max_partition(schema=schema, table_name=table, field=field, filter_map=filter_map)
def _closest_date(target_dt, date_list, before_target=None):
@@ -79,9 +79,7 @@ def _closest_date(target_dt, date_list, before_target=None):
return min(date_list, key=time_after).date()
-def closest_ds_partition(
- table, ds, before=True, schema="default",
- metastore_conn_id='metastore_default'):
+def closest_ds_partition(table, ds, before=True, schema="default", metastore_conn_id='metastore_default'):
"""
This function finds the date in a list closest to the target date.
An optional parameter can be given to get the closest before or after.
@@ -104,6 +102,7 @@ def closest_ds_partition(
'2015-01-01'
"""
from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
+
if '.' in table:
schema, table = table.split('.')
hive_hook = HiveMetastoreHook(metastore_conn_id=metastore_conn_id)
@@ -114,8 +113,7 @@ def closest_ds_partition(
if ds in part_vals:
return ds
else:
- parts = [datetime.datetime.strptime(pv, '%Y-%m-%d')
- for pv in part_vals]
+ parts = [datetime.datetime.strptime(pv, '%Y-%m-%d') for pv in part_vals]
target_dt = datetime.datetime.strptime(ds, '%Y-%m-%d')
closest_ds = _closest_date(target_dt, parts, before_target=before)
return closest_ds.isoformat()
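Beyond joining the closest_ds_partition signature, two Black behaviours show up in this file: a parameter list that will not fit next to the 'def' within 110 characters is placed on a single indented continuation line, and an import block is always followed by a blank line, even inside a function body. Both can be reproduced on a reduced max_partition (a sketch; single quotes are kept because the untouched strings throughout the diff suggest the project disables Black's quote normalization):

import black

MODE = black.FileMode(line_length=110, string_normalization=False)
SRC = (
    "def max_partition(\n"
    "        table, schema='default', field=None, filter_map=None,\n"
    "        metastore_conn_id='metastore_default'):\n"
    "    from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook\n"
    "    if '.' in table:\n"
    "        schema, table = table.split('.')\n"
)
# The parameters land on one indented line under the def, and a blank line separates
# the function-local import from the code after it, matching the hunks above.
print(black.format_str(SRC, mode=MODE))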
diff --git a/airflow/migrations/env.py b/airflow/migrations/env.py
index 2230724eb1a45..459cb7c176662 100644
--- a/airflow/migrations/env.py
+++ b/airflow/migrations/env.py
@@ -72,7 +72,8 @@ def run_migrations_offline():
target_metadata=target_metadata,
literal_binds=True,
compare_type=COMPARE_TYPE,
- render_as_batch=True)
+ render_as_batch=True,
+ )
with context.begin_transaction():
context.run_migrations()
@@ -94,7 +95,7 @@ def run_migrations_online():
target_metadata=target_metadata,
compare_type=COMPARE_TYPE,
include_object=include_object,
- render_as_batch=True
+ render_as_batch=True,
)
with context.begin_transaction():
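The only substantive edits to env.py add a trailing comma after the last keyword argument of context.configure(...). Black appends that comma whenever a call stays split across lines, and it is harmless at runtime; a quick illustration with a hypothetical stand-in function:

# Trailing commas in call argument lists are plain Python, so the added comma changes nothing
# at runtime; it only keeps future one-argument additions to a single-line diff.
def configure(*, literal_binds=False, compare_type=False, render_as_batch=False):
    return (literal_binds, compare_type, render_as_batch)


assert configure(
    literal_binds=True,
    compare_type=True,
    render_as_batch=True,
) == (True, True, True)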
diff --git a/airflow/migrations/versions/03bc53e68815_add_sm_dag_index.py b/airflow/migrations/versions/03bc53e68815_add_sm_dag_index.py
index f3e7330ac11e3..d66bf5c96e3d5 100644
--- a/airflow/migrations/versions/03bc53e68815_add_sm_dag_index.py
+++ b/airflow/migrations/versions/03bc53e68815_add_sm_dag_index.py
@@ -32,9 +32,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.create_index('sm_dag', 'sla_miss', ['dag_id'], unique=False)
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_index('sm_dag', table_name='sla_miss')
diff --git a/airflow/migrations/versions/05f30312d566_merge_heads.py b/airflow/migrations/versions/05f30312d566_merge_heads.py
index 68e7dbd3a0943..ffe2330196270 100644
--- a/airflow/migrations/versions/05f30312d566_merge_heads.py
+++ b/airflow/migrations/versions/05f30312d566_merge_heads.py
@@ -30,9 +30,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
pass
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
pass
diff --git a/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py b/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py
index 378c62dfde65e..4c572f44ec5a2 100644
--- a/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py
+++ b/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py
@@ -41,19 +41,19 @@
# For Microsoft SQL Server, TIMESTAMP is a row-id type,
# having nothing to do with date-time. DateTime() will
# be sufficient.
-def mssql_timestamp(): # noqa: D103
+def mssql_timestamp(): # noqa: D103
return sa.DateTime()
-def mysql_timestamp(): # noqa: D103
+def mysql_timestamp(): # noqa: D103
return mysql.TIMESTAMP(fsp=6)
-def sa_timestamp(): # noqa: D103
+def sa_timestamp(): # noqa: D103
return sa.TIMESTAMP(timezone=True)
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
# See 0e2a74e0fc9f_add_time_zone_awareness
conn = op.get_bind()
if conn.dialect.name == 'mysql':
@@ -79,16 +79,12 @@ def upgrade(): # noqa: D103
sa.ForeignKeyConstraint(
['task_id', 'dag_id', 'execution_date'],
['task_instance.task_id', 'task_instance.dag_id', 'task_instance.execution_date'],
- name='task_reschedule_dag_task_date_fkey')
- )
- op.create_index(
- INDEX_NAME,
- TABLE_NAME,
- ['dag_id', 'task_id', 'execution_date'],
- unique=False
+ name='task_reschedule_dag_task_date_fkey',
+ ),
)
+ op.create_index(INDEX_NAME, TABLE_NAME, ['dag_id', 'task_id', 'execution_date'], unique=False)
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_index(INDEX_NAME, table_name=TABLE_NAME)
op.drop_table(TABLE_NAME)
diff --git a/airflow/migrations/versions/0e2a74e0fc9f_add_time_zone_awareness.py b/airflow/migrations/versions/0e2a74e0fc9f_add_time_zone_awareness.py
index daa6a36f1ea6d..f18809c994b10 100644
--- a/airflow/migrations/versions/0e2a74e0fc9f_add_time_zone_awareness.py
+++ b/airflow/migrations/versions/0e2a74e0fc9f_add_time_zone_awareness.py
@@ -34,16 +34,14 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
conn = op.get_bind()
if conn.dialect.name == "mysql":
conn.execute("SET time_zone = '+00:00'")
cur = conn.execute("SELECT @@explicit_defaults_for_timestamp")
res = cur.fetchall()
if res[0][0] == 0:
- raise Exception(
- "Global variable explicit_defaults_for_timestamp needs to be on (1) for mysql"
- )
+ raise Exception("Global variable explicit_defaults_for_timestamp needs to be on (1) for mysql")
op.alter_column(
table_name="chart",
@@ -56,12 +54,8 @@ def upgrade(): # noqa: D103
column_name="last_scheduler_run",
type_=mysql.TIMESTAMP(fsp=6),
)
- op.alter_column(
- table_name="dag", column_name="last_pickled", type_=mysql.TIMESTAMP(fsp=6)
- )
- op.alter_column(
- table_name="dag", column_name="last_expired", type_=mysql.TIMESTAMP(fsp=6)
- )
+ op.alter_column(table_name="dag", column_name="last_pickled", type_=mysql.TIMESTAMP(fsp=6))
+ op.alter_column(table_name="dag", column_name="last_expired", type_=mysql.TIMESTAMP(fsp=6))
op.alter_column(
table_name="dag_pickle",
@@ -74,12 +68,8 @@ def upgrade(): # noqa: D103
column_name="execution_date",
type_=mysql.TIMESTAMP(fsp=6),
)
- op.alter_column(
- table_name="dag_run", column_name="start_date", type_=mysql.TIMESTAMP(fsp=6)
- )
- op.alter_column(
- table_name="dag_run", column_name="end_date", type_=mysql.TIMESTAMP(fsp=6)
- )
+ op.alter_column(table_name="dag_run", column_name="start_date", type_=mysql.TIMESTAMP(fsp=6))
+ op.alter_column(table_name="dag_run", column_name="end_date", type_=mysql.TIMESTAMP(fsp=6))
op.alter_column(
table_name="import_error",
@@ -87,24 +77,16 @@ def upgrade(): # noqa: D103
type_=mysql.TIMESTAMP(fsp=6),
)
- op.alter_column(
- table_name="job", column_name="start_date", type_=mysql.TIMESTAMP(fsp=6)
- )
- op.alter_column(
- table_name="job", column_name="end_date", type_=mysql.TIMESTAMP(fsp=6)
- )
+ op.alter_column(table_name="job", column_name="start_date", type_=mysql.TIMESTAMP(fsp=6))
+ op.alter_column(table_name="job", column_name="end_date", type_=mysql.TIMESTAMP(fsp=6))
op.alter_column(
table_name="job",
column_name="latest_heartbeat",
type_=mysql.TIMESTAMP(fsp=6),
)
- op.alter_column(
- table_name="log", column_name="dttm", type_=mysql.TIMESTAMP(fsp=6)
- )
- op.alter_column(
- table_name="log", column_name="execution_date", type_=mysql.TIMESTAMP(fsp=6)
- )
+ op.alter_column(table_name="log", column_name="dttm", type_=mysql.TIMESTAMP(fsp=6))
+ op.alter_column(table_name="log", column_name="execution_date", type_=mysql.TIMESTAMP(fsp=6))
op.alter_column(
table_name="sla_miss",
@@ -112,9 +94,7 @@ def upgrade(): # noqa: D103
type_=mysql.TIMESTAMP(fsp=6),
nullable=False,
)
- op.alter_column(
- table_name="sla_miss", column_name="timestamp", type_=mysql.TIMESTAMP(fsp=6)
- )
+ op.alter_column(table_name="sla_miss", column_name="timestamp", type_=mysql.TIMESTAMP(fsp=6))
op.alter_column(
table_name="task_fail",
@@ -126,9 +106,7 @@ def upgrade(): # noqa: D103
column_name="start_date",
type_=mysql.TIMESTAMP(fsp=6),
)
- op.alter_column(
- table_name="task_fail", column_name="end_date", type_=mysql.TIMESTAMP(fsp=6)
- )
+ op.alter_column(table_name="task_fail", column_name="end_date", type_=mysql.TIMESTAMP(fsp=6))
op.alter_column(
table_name="task_instance",
@@ -152,9 +130,7 @@ def upgrade(): # noqa: D103
type_=mysql.TIMESTAMP(fsp=6),
)
- op.alter_column(
- table_name="xcom", column_name="timestamp", type_=mysql.TIMESTAMP(fsp=6)
- )
+ op.alter_column(table_name="xcom", column_name="timestamp", type_=mysql.TIMESTAMP(fsp=6))
op.alter_column(
table_name="xcom",
column_name="execution_date",
@@ -225,18 +201,14 @@ def upgrade(): # noqa: D103
column_name="start_date",
type_=sa.TIMESTAMP(timezone=True),
)
- op.alter_column(
- table_name="job", column_name="end_date", type_=sa.TIMESTAMP(timezone=True)
- )
+ op.alter_column(table_name="job", column_name="end_date", type_=sa.TIMESTAMP(timezone=True))
op.alter_column(
table_name="job",
column_name="latest_heartbeat",
type_=sa.TIMESTAMP(timezone=True),
)
- op.alter_column(
- table_name="log", column_name="dttm", type_=sa.TIMESTAMP(timezone=True)
- )
+ op.alter_column(table_name="log", column_name="dttm", type_=sa.TIMESTAMP(timezone=True))
op.alter_column(
table_name="log",
column_name="execution_date",
@@ -305,25 +277,19 @@ def upgrade(): # noqa: D103
)
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
conn = op.get_bind()
if conn.dialect.name == "mysql":
conn.execute("SET time_zone = '+00:00'")
- op.alter_column(
- table_name="chart", column_name="last_modified", type_=mysql.DATETIME(fsp=6)
- )
+ op.alter_column(table_name="chart", column_name="last_modified", type_=mysql.DATETIME(fsp=6))
op.alter_column(
table_name="dag",
column_name="last_scheduler_run",
type_=mysql.DATETIME(fsp=6),
)
- op.alter_column(
- table_name="dag", column_name="last_pickled", type_=mysql.DATETIME(fsp=6)
- )
- op.alter_column(
- table_name="dag", column_name="last_expired", type_=mysql.DATETIME(fsp=6)
- )
+ op.alter_column(table_name="dag", column_name="last_pickled", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="dag", column_name="last_expired", type_=mysql.DATETIME(fsp=6))
op.alter_column(
table_name="dag_pickle",
@@ -336,12 +302,8 @@ def downgrade(): # noqa: D103
column_name="execution_date",
type_=mysql.DATETIME(fsp=6),
)
- op.alter_column(
- table_name="dag_run", column_name="start_date", type_=mysql.DATETIME(fsp=6)
- )
- op.alter_column(
- table_name="dag_run", column_name="end_date", type_=mysql.DATETIME(fsp=6)
- )
+ op.alter_column(table_name="dag_run", column_name="start_date", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="dag_run", column_name="end_date", type_=mysql.DATETIME(fsp=6))
op.alter_column(
table_name="import_error",
@@ -349,24 +311,16 @@ def downgrade(): # noqa: D103
type_=mysql.DATETIME(fsp=6),
)
- op.alter_column(
- table_name="job", column_name="start_date", type_=mysql.DATETIME(fsp=6)
- )
- op.alter_column(
- table_name="job", column_name="end_date", type_=mysql.DATETIME(fsp=6)
- )
+ op.alter_column(table_name="job", column_name="start_date", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="job", column_name="end_date", type_=mysql.DATETIME(fsp=6))
op.alter_column(
table_name="job",
column_name="latest_heartbeat",
type_=mysql.DATETIME(fsp=6),
)
- op.alter_column(
- table_name="log", column_name="dttm", type_=mysql.DATETIME(fsp=6)
- )
- op.alter_column(
- table_name="log", column_name="execution_date", type_=mysql.DATETIME(fsp=6)
- )
+ op.alter_column(table_name="log", column_name="dttm", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="log", column_name="execution_date", type_=mysql.DATETIME(fsp=6))
op.alter_column(
table_name="sla_miss",
@@ -374,9 +328,7 @@ def downgrade(): # noqa: D103
type_=mysql.DATETIME(fsp=6),
nullable=False,
)
- op.alter_column(
- table_name="sla_miss", column_name="timestamp", type_=mysql.DATETIME(fsp=6)
- )
+ op.alter_column(table_name="sla_miss", column_name="timestamp", type_=mysql.DATETIME(fsp=6))
op.alter_column(
table_name="task_fail",
@@ -388,9 +340,7 @@ def downgrade(): # noqa: D103
column_name="start_date",
type_=mysql.DATETIME(fsp=6),
)
- op.alter_column(
- table_name="task_fail", column_name="end_date", type_=mysql.DATETIME(fsp=6)
- )
+ op.alter_column(table_name="task_fail", column_name="end_date", type_=mysql.DATETIME(fsp=6))
op.alter_column(
table_name="task_instance",
@@ -414,12 +364,8 @@ def downgrade(): # noqa: D103
type_=mysql.DATETIME(fsp=6),
)
- op.alter_column(
- table_name="xcom", column_name="timestamp", type_=mysql.DATETIME(fsp=6)
- )
- op.alter_column(
- table_name="xcom", column_name="execution_date", type_=mysql.DATETIME(fsp=6)
- )
+ op.alter_column(table_name="xcom", column_name="timestamp", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="xcom", column_name="execution_date", type_=mysql.DATETIME(fsp=6))
else:
if conn.dialect.name in ("sqlite", "mssql"):
return
@@ -429,48 +375,26 @@ def downgrade(): # noqa: D103
if conn.dialect.name == "postgresql":
conn.execute("set timezone=UTC")
- op.alter_column(
- table_name="chart", column_name="last_modified", type_=sa.DateTime()
- )
+ op.alter_column(table_name="chart", column_name="last_modified", type_=sa.DateTime())
- op.alter_column(
- table_name="dag", column_name="last_scheduler_run", type_=sa.DateTime()
- )
- op.alter_column(
- table_name="dag", column_name="last_pickled", type_=sa.DateTime()
- )
- op.alter_column(
- table_name="dag", column_name="last_expired", type_=sa.DateTime()
- )
+ op.alter_column(table_name="dag", column_name="last_scheduler_run", type_=sa.DateTime())
+ op.alter_column(table_name="dag", column_name="last_pickled", type_=sa.DateTime())
+ op.alter_column(table_name="dag", column_name="last_expired", type_=sa.DateTime())
- op.alter_column(
- table_name="dag_pickle", column_name="created_dttm", type_=sa.DateTime()
- )
+ op.alter_column(table_name="dag_pickle", column_name="created_dttm", type_=sa.DateTime())
- op.alter_column(
- table_name="dag_run", column_name="execution_date", type_=sa.DateTime()
- )
- op.alter_column(
- table_name="dag_run", column_name="start_date", type_=sa.DateTime()
- )
- op.alter_column(
- table_name="dag_run", column_name="end_date", type_=sa.DateTime()
- )
+ op.alter_column(table_name="dag_run", column_name="execution_date", type_=sa.DateTime())
+ op.alter_column(table_name="dag_run", column_name="start_date", type_=sa.DateTime())
+ op.alter_column(table_name="dag_run", column_name="end_date", type_=sa.DateTime())
- op.alter_column(
- table_name="import_error", column_name="timestamp", type_=sa.DateTime()
- )
+ op.alter_column(table_name="import_error", column_name="timestamp", type_=sa.DateTime())
op.alter_column(table_name="job", column_name="start_date", type_=sa.DateTime())
op.alter_column(table_name="job", column_name="end_date", type_=sa.DateTime())
- op.alter_column(
- table_name="job", column_name="latest_heartbeat", type_=sa.DateTime()
- )
+ op.alter_column(table_name="job", column_name="latest_heartbeat", type_=sa.DateTime())
op.alter_column(table_name="log", column_name="dttm", type_=sa.DateTime())
- op.alter_column(
- table_name="log", column_name="execution_date", type_=sa.DateTime()
- )
+ op.alter_column(table_name="log", column_name="execution_date", type_=sa.DateTime())
op.alter_column(
table_name="sla_miss",
@@ -478,19 +402,11 @@ def downgrade(): # noqa: D103
type_=sa.DateTime(),
nullable=False,
)
- op.alter_column(
- table_name="sla_miss", column_name="timestamp", type_=sa.DateTime()
- )
+ op.alter_column(table_name="sla_miss", column_name="timestamp", type_=sa.DateTime())
- op.alter_column(
- table_name="task_fail", column_name="execution_date", type_=sa.DateTime()
- )
- op.alter_column(
- table_name="task_fail", column_name="start_date", type_=sa.DateTime()
- )
- op.alter_column(
- table_name="task_fail", column_name="end_date", type_=sa.DateTime()
- )
+ op.alter_column(table_name="task_fail", column_name="execution_date", type_=sa.DateTime())
+ op.alter_column(table_name="task_fail", column_name="start_date", type_=sa.DateTime())
+ op.alter_column(table_name="task_fail", column_name="end_date", type_=sa.DateTime())
op.alter_column(
table_name="task_instance",
@@ -498,17 +414,9 @@ def downgrade(): # noqa: D103
type_=sa.DateTime(),
nullable=False,
)
- op.alter_column(
- table_name="task_instance", column_name="start_date", type_=sa.DateTime()
- )
- op.alter_column(
- table_name="task_instance", column_name="end_date", type_=sa.DateTime()
- )
- op.alter_column(
- table_name="task_instance", column_name="queued_dttm", type_=sa.DateTime()
- )
+ op.alter_column(table_name="task_instance", column_name="start_date", type_=sa.DateTime())
+ op.alter_column(table_name="task_instance", column_name="end_date", type_=sa.DateTime())
+ op.alter_column(table_name="task_instance", column_name="queued_dttm", type_=sa.DateTime())
op.alter_column(table_name="xcom", column_name="timestamp", type_=sa.DateTime())
- op.alter_column(
- table_name="xcom", column_name="execution_date", type_=sa.DateTime()
- )
+ op.alter_column(table_name="xcom", column_name="execution_date", type_=sa.DateTime())
diff --git a/airflow/migrations/versions/127d2bf2dfa7_add_dag_id_state_index_on_dag_run_table.py b/airflow/migrations/versions/127d2bf2dfa7_add_dag_id_state_index_on_dag_run_table.py
index df79a9512c965..855e55c5ee15d 100644
--- a/airflow/migrations/versions/127d2bf2dfa7_add_dag_id_state_index_on_dag_run_table.py
+++ b/airflow/migrations/versions/127d2bf2dfa7_add_dag_id_state_index_on_dag_run_table.py
@@ -32,9 +32,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.create_index('dag_id_state', 'dag_run', ['dag_id', 'state'], unique=False)
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_index('dag_id_state', table_name='dag_run')
diff --git a/airflow/migrations/versions/13eb55f81627_for_compatibility.py b/airflow/migrations/versions/13eb55f81627_for_compatibility.py
index ec43294be6111..538db1a49ffba 100644
--- a/airflow/migrations/versions/13eb55f81627_for_compatibility.py
+++ b/airflow/migrations/versions/13eb55f81627_for_compatibility.py
@@ -31,9 +31,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
pass
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
pass
diff --git a/airflow/migrations/versions/1507a7289a2f_create_is_encrypted.py b/airflow/migrations/versions/1507a7289a2f_create_is_encrypted.py
index 876f79f63a15b..7afdeb2c1831b 100644
--- a/airflow/migrations/versions/1507a7289a2f_create_is_encrypted.py
+++ b/airflow/migrations/versions/1507a7289a2f_create_is_encrypted.py
@@ -34,14 +34,11 @@
depends_on = None
connectionhelper = sa.Table(
- 'connection',
- sa.MetaData(),
- sa.Column('id', sa.Integer, primary_key=True),
- sa.Column('is_encrypted')
+ 'connection', sa.MetaData(), sa.Column('id', sa.Integer, primary_key=True), sa.Column('is_encrypted')
)
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
# first check if the user already has this done. This should only be
# true for users who are upgrading from a previous version of Airflow
# that predates Alembic integration
@@ -55,15 +52,11 @@ def upgrade(): # noqa: D103
if 'is_encrypted' in col_names:
return
- op.add_column(
- 'connection',
- sa.Column('is_encrypted', sa.Boolean, unique=False, default=False))
+ op.add_column('connection', sa.Column('is_encrypted', sa.Boolean, unique=False, default=False))
conn = op.get_bind()
- conn.execute(
- connectionhelper.update().values(is_encrypted=False)
- )
+ conn.execute(connectionhelper.update().values(is_encrypted=False))
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_column('connection', 'is_encrypted')
diff --git a/airflow/migrations/versions/1968acfc09e3_add_is_encrypted_column_to_variable_.py b/airflow/migrations/versions/1968acfc09e3_add_is_encrypted_column_to_variable_.py
index d9b4fe015932e..e880d77fee220 100644
--- a/airflow/migrations/versions/1968acfc09e3_add_is_encrypted_column_to_variable_.py
+++ b/airflow/migrations/versions/1968acfc09e3_add_is_encrypted_column_to_variable_.py
@@ -33,9 +33,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.add_column('variable', sa.Column('is_encrypted', sa.Boolean, default=False))
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_column('variable', 'is_encrypted')
diff --git a/airflow/migrations/versions/1b38cef5b76e_add_dagrun.py b/airflow/migrations/versions/1b38cef5b76e_add_dagrun.py
index 8edcbb4d59618..7edebfc4cf9d2 100644
--- a/airflow/migrations/versions/1b38cef5b76e_add_dagrun.py
+++ b/airflow/migrations/versions/1b38cef5b76e_add_dagrun.py
@@ -34,18 +34,20 @@
depends_on = None
-def upgrade(): # noqa: D103
- op.create_table('dag_run',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('dag_id', sa.String(length=250), nullable=True),
- sa.Column('execution_date', sa.DateTime(), nullable=True),
- sa.Column('state', sa.String(length=50), nullable=True),
- sa.Column('run_id', sa.String(length=250), nullable=True),
- sa.Column('external_trigger', sa.Boolean(), nullable=True),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('dag_id', 'execution_date'),
- sa.UniqueConstraint('dag_id', 'run_id'))
-
-
-def downgrade(): # noqa: D103
+def upgrade(): # noqa: D103
+ op.create_table(
+ 'dag_run',
+ sa.Column('id', sa.Integer(), nullable=False),
+ sa.Column('dag_id', sa.String(length=250), nullable=True),
+ sa.Column('execution_date', sa.DateTime(), nullable=True),
+ sa.Column('state', sa.String(length=50), nullable=True),
+ sa.Column('run_id', sa.String(length=250), nullable=True),
+ sa.Column('external_trigger', sa.Boolean(), nullable=True),
+ sa.PrimaryKeyConstraint('id'),
+ sa.UniqueConstraint('dag_id', 'execution_date'),
+ sa.UniqueConstraint('dag_id', 'run_id'),
+ )
+
+
+def downgrade(): # noqa: D103
op.drop_table('dag_run')
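The rewritten create_table call above also shows how Black lays out a call that cannot fit on one line: it never aligns continuation lines under the opening parenthesis, but puts one argument per line at a fixed four-space indent and appends a trailing comma. Reproducible on a cut-down version of the same call (sketch):

import black

MODE = black.FileMode(line_length=110, string_normalization=False)
SRC = (
    "op.create_table('dag_run', sa.Column('id', sa.Integer(), nullable=False), "
    "sa.Column('dag_id', sa.String(length=250), nullable=True), sa.PrimaryKeyConstraint('id'))\n"
)
# Too long even for a 110-character line, so the call is exploded: one argument per line,
# a fixed four-space indent, and a trailing comma after the last argument.
print(black.format_str(SRC, mode=MODE))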
diff --git a/airflow/migrations/versions/211e584da130_add_ti_state_index.py b/airflow/migrations/versions/211e584da130_add_ti_state_index.py
index a6f946332118b..7df1550733d7e 100644
--- a/airflow/migrations/versions/211e584da130_add_ti_state_index.py
+++ b/airflow/migrations/versions/211e584da130_add_ti_state_index.py
@@ -32,9 +32,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.create_index('ti_state', 'task_instance', ['state'], unique=False)
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_index('ti_state', table_name='task_instance')
diff --git a/airflow/migrations/versions/27c6a30d7c24_add_executor_config_to_task_instance.py b/airflow/migrations/versions/27c6a30d7c24_add_executor_config_to_task_instance.py
index 768a67fe529bd..d0853efb79a2f 100644
--- a/airflow/migrations/versions/27c6a30d7c24_add_executor_config_to_task_instance.py
+++ b/airflow/migrations/versions/27c6a30d7c24_add_executor_config_to_task_instance.py
@@ -38,9 +38,9 @@
NEW_COLUMN = "executor_config"
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.add_column(TASK_INSTANCE_TABLE, sa.Column(NEW_COLUMN, sa.PickleType(pickler=dill)))
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_column(TASK_INSTANCE_TABLE, NEW_COLUMN)
diff --git a/airflow/migrations/versions/2e541a1dcfed_task_duration.py b/airflow/migrations/versions/2e541a1dcfed_task_duration.py
index b071f1c4b8d1e..12d8e2e5a608d 100644
--- a/airflow/migrations/versions/2e541a1dcfed_task_duration.py
+++ b/airflow/migrations/versions/2e541a1dcfed_task_duration.py
@@ -35,14 +35,16 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
# use batch_alter_table to support SQLite workaround
with op.batch_alter_table("task_instance") as batch_op:
- batch_op.alter_column('duration',
- existing_type=mysql.INTEGER(display_width=11),
- type_=sa.Float(),
- existing_nullable=True)
+ batch_op.alter_column(
+ 'duration',
+ existing_type=mysql.INTEGER(display_width=11),
+ type_=sa.Float(),
+ existing_nullable=True,
+ )
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
pass
diff --git a/airflow/migrations/versions/2e82aab8ef20_rename_user_table.py b/airflow/migrations/versions/2e82aab8ef20_rename_user_table.py
index 0acc346d9abb4..3dcbe47460efa 100644
--- a/airflow/migrations/versions/2e82aab8ef20_rename_user_table.py
+++ b/airflow/migrations/versions/2e82aab8ef20_rename_user_table.py
@@ -32,9 +32,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.rename_table('user', 'users')
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.rename_table('users', 'user')
diff --git a/airflow/migrations/versions/338e90f54d61_more_logging_into_task_isntance.py b/airflow/migrations/versions/338e90f54d61_more_logging_into_task_isntance.py
index 50245a3a8cb83..60ed6628a77dd 100644
--- a/airflow/migrations/versions/338e90f54d61_more_logging_into_task_isntance.py
+++ b/airflow/migrations/versions/338e90f54d61_more_logging_into_task_isntance.py
@@ -33,11 +33,11 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.add_column('task_instance', sa.Column('operator', sa.String(length=1000), nullable=True))
op.add_column('task_instance', sa.Column('queued_dttm', sa.DateTime(), nullable=True))
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_column('task_instance', 'queued_dttm')
op.drop_column('task_instance', 'operator')
diff --git a/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py b/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py
index d22d87864a859..2f06756dcf054 100644
--- a/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py
+++ b/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py
@@ -35,10 +35,10 @@
RESOURCE_TABLE = "kube_resource_version"
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
columns_and_constraints = [
sa.Column("one_row_id", sa.Boolean, server_default=sa.true(), primary_key=True),
- sa.Column("resource_version", sa.String(255))
+ sa.Column("resource_version", sa.String(255)),
]
conn = op.get_bind()
@@ -53,15 +53,10 @@ def upgrade(): # noqa: D103
sa.CheckConstraint("one_row_id", name="kube_resource_version_one_row_id")
)
- table = op.create_table(
- RESOURCE_TABLE,
- *columns_and_constraints
- )
+ table = op.create_table(RESOURCE_TABLE, *columns_and_constraints)
- op.bulk_insert(table, [
- {"resource_version": ""}
- ])
+ op.bulk_insert(table, [{"resource_version": ""}])
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_table(RESOURCE_TABLE)
diff --git a/airflow/migrations/versions/40e67319e3a9_dagrun_config.py b/airflow/migrations/versions/40e67319e3a9_dagrun_config.py
index d123f8e383d55..96c211eebfba4 100644
--- a/airflow/migrations/versions/40e67319e3a9_dagrun_config.py
+++ b/airflow/migrations/versions/40e67319e3a9_dagrun_config.py
@@ -33,9 +33,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.add_column('dag_run', sa.Column('conf', sa.PickleType(), nullable=True))
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_column('dag_run', 'conf')
diff --git a/airflow/migrations/versions/41f5f12752f8_add_superuser_field.py b/airflow/migrations/versions/41f5f12752f8_add_superuser_field.py
index a7d5767f99bd1..572845b4d04d5 100644
--- a/airflow/migrations/versions/41f5f12752f8_add_superuser_field.py
+++ b/airflow/migrations/versions/41f5f12752f8_add_superuser_field.py
@@ -33,9 +33,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.add_column('users', sa.Column('superuser', sa.Boolean(), default=False))
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_column('users', 'superuser')
diff --git a/airflow/migrations/versions/4446e08588_dagrun_start_end.py b/airflow/migrations/versions/4446e08588_dagrun_start_end.py
index ec20c807b7a03..2ee527361d415 100644
--- a/airflow/migrations/versions/4446e08588_dagrun_start_end.py
+++ b/airflow/migrations/versions/4446e08588_dagrun_start_end.py
@@ -34,11 +34,11 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.add_column('dag_run', sa.Column('end_date', sa.DateTime(), nullable=True))
op.add_column('dag_run', sa.Column('start_date', sa.DateTime(), nullable=True))
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_column('dag_run', 'start_date')
op.drop_column('dag_run', 'end_date')
diff --git a/airflow/migrations/versions/4addfa1236f1_add_fractional_seconds_to_mysql_tables.py b/airflow/migrations/versions/4addfa1236f1_add_fractional_seconds_to_mysql_tables.py
index 20d77ce037286..919119fb8b88c 100644
--- a/airflow/migrations/versions/4addfa1236f1_add_fractional_seconds_to_mysql_tables.py
+++ b/airflow/migrations/versions/4addfa1236f1_add_fractional_seconds_to_mysql_tables.py
@@ -34,126 +34,86 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
if context.config.get_main_option('sqlalchemy.url').startswith('mysql'):
- op.alter_column(table_name='dag', column_name='last_scheduler_run',
- type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='dag', column_name='last_pickled',
- type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='dag', column_name='last_expired',
- type_=mysql.DATETIME(fsp=6))
-
- op.alter_column(table_name='dag_pickle', column_name='created_dttm',
- type_=mysql.DATETIME(fsp=6))
-
- op.alter_column(table_name='dag_run', column_name='execution_date',
- type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='dag_run', column_name='start_date',
- type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='dag_run', column_name='end_date',
- type_=mysql.DATETIME(fsp=6))
-
- op.alter_column(table_name='import_error', column_name='timestamp',
- type_=mysql.DATETIME(fsp=6))
-
- op.alter_column(table_name='job', column_name='start_date',
- type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='job', column_name='end_date',
- type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='job', column_name='latest_heartbeat',
- type_=mysql.DATETIME(fsp=6))
-
- op.alter_column(table_name='log', column_name='dttm',
- type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='log', column_name='execution_date',
- type_=mysql.DATETIME(fsp=6))
-
- op.alter_column(table_name='sla_miss', column_name='execution_date',
- type_=mysql.DATETIME(fsp=6),
- nullable=False)
- op.alter_column(table_name='sla_miss', column_name='timestamp',
- type_=mysql.DATETIME(fsp=6))
-
- op.alter_column(table_name='task_fail', column_name='execution_date',
- type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='task_fail', column_name='start_date',
- type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='task_fail', column_name='end_date',
- type_=mysql.DATETIME(fsp=6))
-
- op.alter_column(table_name='task_instance', column_name='execution_date',
- type_=mysql.DATETIME(fsp=6),
- nullable=False)
- op.alter_column(table_name='task_instance', column_name='start_date',
- type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='task_instance', column_name='end_date',
- type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='task_instance', column_name='queued_dttm',
- type_=mysql.DATETIME(fsp=6))
-
- op.alter_column(table_name='xcom', column_name='timestamp',
- type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='xcom', column_name='execution_date',
- type_=mysql.DATETIME(fsp=6))
-
-
-def downgrade(): # noqa: D103
+ op.alter_column(table_name='dag', column_name='last_scheduler_run', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name='dag', column_name='last_pickled', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name='dag', column_name='last_expired', type_=mysql.DATETIME(fsp=6))
+
+ op.alter_column(table_name='dag_pickle', column_name='created_dttm', type_=mysql.DATETIME(fsp=6))
+
+ op.alter_column(table_name='dag_run', column_name='execution_date', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name='dag_run', column_name='start_date', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name='dag_run', column_name='end_date', type_=mysql.DATETIME(fsp=6))
+
+ op.alter_column(table_name='import_error', column_name='timestamp', type_=mysql.DATETIME(fsp=6))
+
+ op.alter_column(table_name='job', column_name='start_date', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name='job', column_name='end_date', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name='job', column_name='latest_heartbeat', type_=mysql.DATETIME(fsp=6))
+
+ op.alter_column(table_name='log', column_name='dttm', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name='log', column_name='execution_date', type_=mysql.DATETIME(fsp=6))
+
+ op.alter_column(
+ table_name='sla_miss', column_name='execution_date', type_=mysql.DATETIME(fsp=6), nullable=False
+ )
+ op.alter_column(table_name='sla_miss', column_name='timestamp', type_=mysql.DATETIME(fsp=6))
+
+ op.alter_column(table_name='task_fail', column_name='execution_date', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name='task_fail', column_name='start_date', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name='task_fail', column_name='end_date', type_=mysql.DATETIME(fsp=6))
+
+ op.alter_column(
+ table_name='task_instance',
+ column_name='execution_date',
+ type_=mysql.DATETIME(fsp=6),
+ nullable=False,
+ )
+ op.alter_column(table_name='task_instance', column_name='start_date', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name='task_instance', column_name='end_date', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name='task_instance', column_name='queued_dttm', type_=mysql.DATETIME(fsp=6))
+
+ op.alter_column(table_name='xcom', column_name='timestamp', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name='xcom', column_name='execution_date', type_=mysql.DATETIME(fsp=6))
+
+
+def downgrade(): # noqa: D103
if context.config.get_main_option('sqlalchemy.url').startswith('mysql'):
- op.alter_column(table_name='dag', column_name='last_scheduler_run',
- type_=mysql.DATETIME())
- op.alter_column(table_name='dag', column_name='last_pickled',
- type_=mysql.DATETIME())
- op.alter_column(table_name='dag', column_name='last_expired',
- type_=mysql.DATETIME())
-
- op.alter_column(table_name='dag_pickle', column_name='created_dttm',
- type_=mysql.DATETIME())
-
- op.alter_column(table_name='dag_run', column_name='execution_date',
- type_=mysql.DATETIME())
- op.alter_column(table_name='dag_run', column_name='start_date',
- type_=mysql.DATETIME())
- op.alter_column(table_name='dag_run', column_name='end_date',
- type_=mysql.DATETIME())
-
- op.alter_column(table_name='import_error', column_name='timestamp',
- type_=mysql.DATETIME())
-
- op.alter_column(table_name='job', column_name='start_date',
- type_=mysql.DATETIME())
- op.alter_column(table_name='job', column_name='end_date',
- type_=mysql.DATETIME())
- op.alter_column(table_name='job', column_name='latest_heartbeat',
- type_=mysql.DATETIME())
-
- op.alter_column(table_name='log', column_name='dttm',
- type_=mysql.DATETIME())
- op.alter_column(table_name='log', column_name='execution_date',
- type_=mysql.DATETIME())
-
- op.alter_column(table_name='sla_miss', column_name='execution_date',
- type_=mysql.DATETIME(), nullable=False)
- op.alter_column(table_name='sla_miss', column_name='timestamp',
- type_=mysql.DATETIME())
-
- op.alter_column(table_name='task_fail', column_name='execution_date',
- type_=mysql.DATETIME())
- op.alter_column(table_name='task_fail', column_name='start_date',
- type_=mysql.DATETIME())
- op.alter_column(table_name='task_fail', column_name='end_date',
- type_=mysql.DATETIME())
-
- op.alter_column(table_name='task_instance', column_name='execution_date',
- type_=mysql.DATETIME(),
- nullable=False)
- op.alter_column(table_name='task_instance', column_name='start_date',
- type_=mysql.DATETIME())
- op.alter_column(table_name='task_instance', column_name='end_date',
- type_=mysql.DATETIME())
- op.alter_column(table_name='task_instance', column_name='queued_dttm',
- type_=mysql.DATETIME())
-
- op.alter_column(table_name='xcom', column_name='timestamp',
- type_=mysql.DATETIME())
- op.alter_column(table_name='xcom', column_name='execution_date',
- type_=mysql.DATETIME())
+ op.alter_column(table_name='dag', column_name='last_scheduler_run', type_=mysql.DATETIME())
+ op.alter_column(table_name='dag', column_name='last_pickled', type_=mysql.DATETIME())
+ op.alter_column(table_name='dag', column_name='last_expired', type_=mysql.DATETIME())
+
+ op.alter_column(table_name='dag_pickle', column_name='created_dttm', type_=mysql.DATETIME())
+
+ op.alter_column(table_name='dag_run', column_name='execution_date', type_=mysql.DATETIME())
+ op.alter_column(table_name='dag_run', column_name='start_date', type_=mysql.DATETIME())
+ op.alter_column(table_name='dag_run', column_name='end_date', type_=mysql.DATETIME())
+
+ op.alter_column(table_name='import_error', column_name='timestamp', type_=mysql.DATETIME())
+
+ op.alter_column(table_name='job', column_name='start_date', type_=mysql.DATETIME())
+ op.alter_column(table_name='job', column_name='end_date', type_=mysql.DATETIME())
+ op.alter_column(table_name='job', column_name='latest_heartbeat', type_=mysql.DATETIME())
+
+ op.alter_column(table_name='log', column_name='dttm', type_=mysql.DATETIME())
+ op.alter_column(table_name='log', column_name='execution_date', type_=mysql.DATETIME())
+
+ op.alter_column(
+ table_name='sla_miss', column_name='execution_date', type_=mysql.DATETIME(), nullable=False
+ )
+ op.alter_column(table_name='sla_miss', column_name='timestamp', type_=mysql.DATETIME())
+
+ op.alter_column(table_name='task_fail', column_name='execution_date', type_=mysql.DATETIME())
+ op.alter_column(table_name='task_fail', column_name='start_date', type_=mysql.DATETIME())
+ op.alter_column(table_name='task_fail', column_name='end_date', type_=mysql.DATETIME())
+
+ op.alter_column(
+ table_name='task_instance', column_name='execution_date', type_=mysql.DATETIME(), nullable=False
+ )
+ op.alter_column(table_name='task_instance', column_name='start_date', type_=mysql.DATETIME())
+ op.alter_column(table_name='task_instance', column_name='end_date', type_=mysql.DATETIME())
+ op.alter_column(table_name='task_instance', column_name='queued_dttm', type_=mysql.DATETIME())
+
+ op.alter_column(table_name='xcom', column_name='timestamp', type_=mysql.DATETIME())
+ op.alter_column(table_name='xcom', column_name='execution_date', type_=mysql.DATETIME())
diff --git a/airflow/migrations/versions/502898887f84_adding_extra_to_log.py b/airflow/migrations/versions/502898887f84_adding_extra_to_log.py
index 606b6b8b5a1be..0f00e110c1d5a 100644
--- a/airflow/migrations/versions/502898887f84_adding_extra_to_log.py
+++ b/airflow/migrations/versions/502898887f84_adding_extra_to_log.py
@@ -33,9 +33,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.add_column('log', sa.Column('extra', sa.Text(), nullable=True))
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_column('log', 'extra')
diff --git a/airflow/migrations/versions/52d53670a240_fix_mssql_exec_date_rendered_task_instance.py b/airflow/migrations/versions/52d53670a240_fix_mssql_exec_date_rendered_task_instance.py
index c2898b86423a1..84daac5013518 100644
--- a/airflow/migrations/versions/52d53670a240_fix_mssql_exec_date_rendered_task_instance.py
+++ b/airflow/migrations/versions/52d53670a240_fix_mssql_exec_date_rendered_task_instance.py
@@ -52,7 +52,7 @@ def upgrade():
sa.Column('task_id', sa.String(length=250), nullable=False),
sa.Column('execution_date', mssql.DATETIME2, nullable=False),
sa.Column('rendered_fields', json_type(), nullable=False),
- sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date')
+ sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date'),
)
@@ -72,5 +72,5 @@ def downgrade():
sa.Column('task_id', sa.String(length=250), nullable=False),
sa.Column('execution_date', sa.TIMESTAMP, nullable=False),
sa.Column('rendered_fields', json_type(), nullable=False),
- sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date')
+ sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date'),
)
diff --git a/airflow/migrations/versions/52d714495f0_job_id_indices.py b/airflow/migrations/versions/52d714495f0_job_id_indices.py
index 2006271a48cee..fc3ecad87f200 100644
--- a/airflow/migrations/versions/52d714495f0_job_id_indices.py
+++ b/airflow/migrations/versions/52d714495f0_job_id_indices.py
@@ -32,10 +32,9 @@
depends_on = None
-def upgrade(): # noqa: D103
- op.create_index('idx_job_state_heartbeat', 'job',
- ['state', 'latest_heartbeat'], unique=False)
+def upgrade(): # noqa: D103
+ op.create_index('idx_job_state_heartbeat', 'job', ['state', 'latest_heartbeat'], unique=False)
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_index('idx_job_state_heartbeat', table_name='job')
diff --git a/airflow/migrations/versions/561833c1c74b_add_password_column_to_user.py b/airflow/migrations/versions/561833c1c74b_add_password_column_to_user.py
index bbf835c5cca62..144259ef06f9c 100644
--- a/airflow/migrations/versions/561833c1c74b_add_password_column_to_user.py
+++ b/airflow/migrations/versions/561833c1c74b_add_password_column_to_user.py
@@ -33,9 +33,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.add_column('user', sa.Column('password', sa.String(255)))
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_column('user', 'password')
diff --git a/airflow/migrations/versions/64de9cddf6c9_add_task_fails_journal_table.py b/airflow/migrations/versions/64de9cddf6c9_add_task_fails_journal_table.py
index d798a7b67b36e..40dd9ddbfae19 100644
--- a/airflow/migrations/versions/64de9cddf6c9_add_task_fails_journal_table.py
+++ b/airflow/migrations/versions/64de9cddf6c9_add_task_fails_journal_table.py
@@ -35,7 +35,7 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.create_table(
'task_fail',
sa.Column('id', sa.Integer(), nullable=False),
@@ -49,5 +49,5 @@ def upgrade(): # noqa: D103
)
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_table('task_fail')
diff --git a/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py b/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py
index 8526a9a7446c6..c59c52225757f 100644
--- a/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py
+++ b/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py
@@ -84,10 +84,9 @@ class TaskInstance(Base): # type: ignore
def upgrade():
"""Make TaskInstance.pool field not nullable."""
with create_session() as session:
- session.query(TaskInstance) \
- .filter(TaskInstance.pool.is_(None)) \
- .update({TaskInstance.pool: 'default_pool'},
- synchronize_session=False) # Avoid select updated rows
+ session.query(TaskInstance).filter(TaskInstance.pool.is_(None)).update(
+ {TaskInstance.pool: 'default_pool'}, synchronize_session=False
+ ) # Avoid select updated rows
session.commit()
conn = op.get_bind()
@@ -124,8 +123,7 @@ def downgrade():
op.create_index('ti_pool', 'task_instance', ['pool', 'state', 'priority_weight'])
with create_session() as session:
- session.query(TaskInstance) \
- .filter(TaskInstance.pool == 'default_pool') \
- .update({TaskInstance.pool: None},
- synchronize_session=False) # Avoid select updated rows
+ session.query(TaskInstance).filter(TaskInstance.pool == 'default_pool').update(
+ {TaskInstance.pool: None}, synchronize_session=False
+ ) # Avoid select updated rows
session.commit()
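The query rewrites above keep the attribute chain intact and wrap only the arguments of the final update() call, which is how Black handles a fluent-style chain whose trailing call pushes past the limit. The chain still evaluates strictly left to right; a tiny stand-in builder (hypothetical, not SQLAlchemy) makes that easy to see:

class Query:
    """Toy query builder used only to illustrate the chained-call layout."""

    def __init__(self, rows):
        self.rows = rows

    def filter(self, predicate):
        return Query([row for row in self.rows if predicate(row)])

    def update(self, values, synchronize_session=False):
        for row in self.rows:
            row.update(values)
        return len(self.rows)


rows = [{'pool': None}, {'pool': 'default_pool'}]
updated = Query(rows).filter(lambda row: row['pool'] is None).update(
    {'pool': 'default_pool'}, synchronize_session=False
)
print(updated, rows)  # 1 [{'pool': 'default_pool'}, {'pool': 'default_pool'}]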
diff --git a/airflow/migrations/versions/74effc47d867_change_datetime_to_datetime2_6_on_mssql_.py b/airflow/migrations/versions/74effc47d867_change_datetime_to_datetime2_6_on_mssql_.py
index 3a6746d1bdba7..5505acd0e17e6 100644
--- a/airflow/migrations/versions/74effc47d867_change_datetime_to_datetime2_6_on_mssql_.py
+++ b/airflow/migrations/versions/74effc47d867_change_datetime_to_datetime2_6_on_mssql_.py
@@ -43,7 +43,8 @@ def upgrade():
result = conn.execute(
"""SELECT CASE WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion'))
like '8%' THEN '2000' WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion'))
- like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion""").fetchone()
+ like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion"""
+ ).fetchone()
mssql_version = result[0]
if mssql_version in ("2000", "2005"):
return
@@ -51,37 +52,49 @@ def upgrade():
with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op:
task_reschedule_batch_op.drop_index('idx_task_reschedule_dag_task_date')
task_reschedule_batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey')
- task_reschedule_batch_op.alter_column(column_name="execution_date",
- type_=mssql.DATETIME2(precision=6), nullable=False, )
- task_reschedule_batch_op.alter_column(column_name='start_date',
- type_=mssql.DATETIME2(precision=6))
+ task_reschedule_batch_op.alter_column(
+ column_name="execution_date",
+ type_=mssql.DATETIME2(precision=6),
+ nullable=False,
+ )
+ task_reschedule_batch_op.alter_column(
+ column_name='start_date', type_=mssql.DATETIME2(precision=6)
+ )
task_reschedule_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME2(precision=6))
- task_reschedule_batch_op.alter_column(column_name='reschedule_date',
- type_=mssql.DATETIME2(precision=6))
+ task_reschedule_batch_op.alter_column(
+ column_name='reschedule_date', type_=mssql.DATETIME2(precision=6)
+ )
with op.batch_alter_table('task_instance') as task_instance_batch_op:
task_instance_batch_op.drop_index('ti_state_lkp')
task_instance_batch_op.drop_index('ti_dag_date')
- modify_execution_date_with_constraint(conn, task_instance_batch_op, 'task_instance',
- mssql.DATETIME2(precision=6), False)
+ modify_execution_date_with_constraint(
+ conn, task_instance_batch_op, 'task_instance', mssql.DATETIME2(precision=6), False
+ )
task_instance_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME2(precision=6))
task_instance_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME2(precision=6))
task_instance_batch_op.alter_column(column_name='queued_dttm', type_=mssql.DATETIME2(precision=6))
- task_instance_batch_op.create_index('ti_state_lkp', ['dag_id', 'task_id', 'execution_date'],
- unique=False)
+ task_instance_batch_op.create_index(
+ 'ti_state_lkp', ['dag_id', 'task_id', 'execution_date'], unique=False
+ )
task_instance_batch_op.create_index('ti_dag_date', ['dag_id', 'execution_date'], unique=False)
with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op:
- task_reschedule_batch_op.create_foreign_key('task_reschedule_dag_task_date_fkey', 'task_instance',
- ['task_id', 'dag_id', 'execution_date'],
- ['task_id', 'dag_id', 'execution_date'],
- ondelete='CASCADE')
- task_reschedule_batch_op.create_index('idx_task_reschedule_dag_task_date',
- ['dag_id', 'task_id', 'execution_date'], unique=False)
+ task_reschedule_batch_op.create_foreign_key(
+ 'task_reschedule_dag_task_date_fkey',
+ 'task_instance',
+ ['task_id', 'dag_id', 'execution_date'],
+ ['task_id', 'dag_id', 'execution_date'],
+ ondelete='CASCADE',
+ )
+ task_reschedule_batch_op.create_index(
+ 'idx_task_reschedule_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False
+ )
with op.batch_alter_table('dag_run') as dag_run_batch_op:
- modify_execution_date_with_constraint(conn, dag_run_batch_op, 'dag_run',
- mssql.DATETIME2(precision=6), None)
+ modify_execution_date_with_constraint(
+ conn, dag_run_batch_op, 'dag_run', mssql.DATETIME2(precision=6), None
+ )
dag_run_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME2(precision=6))
dag_run_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME2(precision=6))
@@ -89,34 +102,41 @@ def upgrade():
op.alter_column(table_name='log', column_name='dttm', type_=mssql.DATETIME2(precision=6))
with op.batch_alter_table('sla_miss') as sla_miss_batch_op:
- modify_execution_date_with_constraint(conn, sla_miss_batch_op, 'sla_miss',
- mssql.DATETIME2(precision=6), False)
+ modify_execution_date_with_constraint(
+ conn, sla_miss_batch_op, 'sla_miss', mssql.DATETIME2(precision=6), False
+ )
sla_miss_batch_op.alter_column(column_name='timestamp', type_=mssql.DATETIME2(precision=6))
op.drop_index('idx_task_fail_dag_task_date', table_name='task_fail')
- op.alter_column(table_name="task_fail", column_name="execution_date",
- type_=mssql.DATETIME2(precision=6))
+ op.alter_column(
+ table_name="task_fail", column_name="execution_date", type_=mssql.DATETIME2(precision=6)
+ )
op.alter_column(table_name='task_fail', column_name='start_date', type_=mssql.DATETIME2(precision=6))
op.alter_column(table_name='task_fail', column_name='end_date', type_=mssql.DATETIME2(precision=6))
- op.create_index('idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'],
- unique=False)
+ op.create_index(
+ 'idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'], unique=False
+ )
op.drop_index('idx_xcom_dag_task_date', table_name='xcom')
op.alter_column(table_name="xcom", column_name="execution_date", type_=mssql.DATETIME2(precision=6))
op.alter_column(table_name='xcom', column_name='timestamp', type_=mssql.DATETIME2(precision=6))
- op.create_index('idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_id', 'execution_date'],
- unique=False)
+ op.create_index(
+ 'idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_id', 'execution_date'], unique=False
+ )
- op.alter_column(table_name='dag', column_name='last_scheduler_run',
- type_=mssql.DATETIME2(precision=6))
+ op.alter_column(
+ table_name='dag', column_name='last_scheduler_run', type_=mssql.DATETIME2(precision=6)
+ )
op.alter_column(table_name='dag', column_name='last_pickled', type_=mssql.DATETIME2(precision=6))
op.alter_column(table_name='dag', column_name='last_expired', type_=mssql.DATETIME2(precision=6))
- op.alter_column(table_name='dag_pickle', column_name='created_dttm',
- type_=mssql.DATETIME2(precision=6))
+ op.alter_column(
+ table_name='dag_pickle', column_name='created_dttm', type_=mssql.DATETIME2(precision=6)
+ )
- op.alter_column(table_name='import_error', column_name='timestamp',
- type_=mssql.DATETIME2(precision=6))
+ op.alter_column(
+ table_name='import_error', column_name='timestamp', type_=mssql.DATETIME2(precision=6)
+ )
op.drop_index('job_type_heart', table_name='job')
op.drop_index('idx_job_state_heartbeat', table_name='job')
@@ -134,7 +154,8 @@ def downgrade():
result = conn.execute(
"""SELECT CASE WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion'))
like '8%' THEN '2000' WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion'))
- like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion""").fetchone()
+ like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion"""
+ ).fetchone()
mssql_version = result[0]
if mssql_version in ("2000", "2005"):
return
@@ -142,8 +163,9 @@ def downgrade():
with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op:
task_reschedule_batch_op.drop_index('idx_task_reschedule_dag_task_date')
task_reschedule_batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey')
- task_reschedule_batch_op.alter_column(column_name="execution_date", type_=mssql.DATETIME,
- nullable=False)
+ task_reschedule_batch_op.alter_column(
+ column_name="execution_date", type_=mssql.DATETIME, nullable=False
+ )
task_reschedule_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME)
task_reschedule_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME)
task_reschedule_batch_op.alter_column(column_name='reschedule_date', type_=mssql.DATETIME)
@@ -151,23 +173,28 @@ def downgrade():
with op.batch_alter_table('task_instance') as task_instance_batch_op:
task_instance_batch_op.drop_index('ti_state_lkp')
task_instance_batch_op.drop_index('ti_dag_date')
- modify_execution_date_with_constraint(conn, task_instance_batch_op, 'task_instance',
- mssql.DATETIME, False)
+ modify_execution_date_with_constraint(
+ conn, task_instance_batch_op, 'task_instance', mssql.DATETIME, False
+ )
task_instance_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME)
task_instance_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME)
task_instance_batch_op.alter_column(column_name='queued_dttm', type_=mssql.DATETIME)
- task_instance_batch_op.create_index('ti_state_lkp',
- ['dag_id', 'task_id', 'execution_date'], unique=False)
- task_instance_batch_op.create_index('ti_dag_date',
- ['dag_id', 'execution_date'], unique=False)
+ task_instance_batch_op.create_index(
+ 'ti_state_lkp', ['dag_id', 'task_id', 'execution_date'], unique=False
+ )
+ task_instance_batch_op.create_index('ti_dag_date', ['dag_id', 'execution_date'], unique=False)
with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op:
- task_reschedule_batch_op.create_foreign_key('task_reschedule_dag_task_date_fkey', 'task_instance',
- ['task_id', 'dag_id', 'execution_date'],
- ['task_id', 'dag_id', 'execution_date'],
- ondelete='CASCADE')
- task_reschedule_batch_op.create_index('idx_task_reschedule_dag_task_date',
- ['dag_id', 'task_id', 'execution_date'], unique=False)
+ task_reschedule_batch_op.create_foreign_key(
+ 'task_reschedule_dag_task_date_fkey',
+ 'task_instance',
+ ['task_id', 'dag_id', 'execution_date'],
+ ['task_id', 'dag_id', 'execution_date'],
+ ondelete='CASCADE',
+ )
+ task_reschedule_batch_op.create_index(
+ 'idx_task_reschedule_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False
+ )
with op.batch_alter_table('dag_run') as dag_run_batch_op:
modify_execution_date_with_constraint(conn, dag_run_batch_op, 'dag_run', mssql.DATETIME, None)
@@ -185,14 +212,16 @@ def downgrade():
op.alter_column(table_name="task_fail", column_name="execution_date", type_=mssql.DATETIME)
op.alter_column(table_name='task_fail', column_name='start_date', type_=mssql.DATETIME)
op.alter_column(table_name='task_fail', column_name='end_date', type_=mssql.DATETIME)
- op.create_index('idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'],
- unique=False)
+ op.create_index(
+ 'idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'], unique=False
+ )
op.drop_index('idx_xcom_dag_task_date', table_name='xcom')
op.alter_column(table_name="xcom", column_name="execution_date", type_=mssql.DATETIME)
op.alter_column(table_name='xcom', column_name='timestamp', type_=mssql.DATETIME)
- op.create_index('idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_ild', 'execution_date'],
- unique=False)
+ op.create_index(
+ 'idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_ild', 'execution_date'], unique=False
+ )
op.alter_column(table_name='dag', column_name='last_scheduler_run', type_=mssql.DATETIME)
op.alter_column(table_name='dag', column_name='last_pickled', type_=mssql.DATETIME)
@@ -229,7 +258,9 @@ def get_table_constraints(conn, table_name):
JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE AS ccu ON ccu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME
WHERE tc.TABLE_NAME = '{table_name}' AND
(tc.CONSTRAINT_TYPE = 'PRIMARY KEY' or UPPER(tc.CONSTRAINT_TYPE) = 'UNIQUE')
- """.format(table_name=table_name)
+ """.format(
+ table_name=table_name
+ )
result = conn.execute(query).fetchall()
constraint_dict = defaultdict(list)
for constraint, constraint_type, column in result:
@@ -267,15 +298,9 @@ def drop_constraint(operator, constraint_dict):
for constraint, columns in constraint_dict.items():
if 'execution_date' in columns:
if constraint[1].lower().startswith("primary"):
- operator.drop_constraint(
- constraint[0],
- type_='primary'
- )
+ operator.drop_constraint(constraint[0], type_='primary')
elif constraint[1].lower().startswith("unique"):
- operator.drop_constraint(
- constraint[0],
- type_='unique'
- )
+ operator.drop_constraint(constraint[0], type_='unique')
def create_constraint(operator, constraint_dict):
@@ -288,14 +313,10 @@ def create_constraint(operator, constraint_dict):
for constraint, columns in constraint_dict.items():
if 'execution_date' in columns:
if constraint[1].lower().startswith("primary"):
- operator.create_primary_key(
- constraint_name=constraint[0],
- columns=reorder_columns(columns)
- )
+ operator.create_primary_key(constraint_name=constraint[0], columns=reorder_columns(columns))
elif constraint[1].lower().startswith("unique"):
operator.create_unique_constraint(
- constraint_name=constraint[0],
- columns=reorder_columns(columns)
+ constraint_name=constraint[0], columns=reorder_columns(columns)
)
diff --git a/airflow/migrations/versions/7939bcff74ba_add_dagtags_table.py b/airflow/migrations/versions/7939bcff74ba_add_dagtags_table.py
index 3bd49033e6e30..8b8b93c49f750 100644
--- a/airflow/migrations/versions/7939bcff74ba_add_dagtags_table.py
+++ b/airflow/migrations/versions/7939bcff74ba_add_dagtags_table.py
@@ -40,8 +40,11 @@ def upgrade():
'dag_tag',
sa.Column('name', sa.String(length=100), nullable=False),
sa.Column('dag_id', sa.String(length=250), nullable=False),
- sa.ForeignKeyConstraint(['dag_id'], ['dag.dag_id'], ),
- sa.PrimaryKeyConstraint('name', 'dag_id')
+ sa.ForeignKeyConstraint(
+ ['dag_id'],
+ ['dag.dag_id'],
+ ),
+ sa.PrimaryKeyConstraint('name', 'dag_id'),
)
diff --git a/airflow/migrations/versions/849da589634d_prefix_dag_permissions.py b/airflow/migrations/versions/849da589634d_prefix_dag_permissions.py
index 23fd96a21d289..240b39e1305fd 100644
--- a/airflow/migrations/versions/849da589634d_prefix_dag_permissions.py
+++ b/airflow/migrations/versions/849da589634d_prefix_dag_permissions.py
@@ -34,18 +34,16 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
permissions = ['can_dag_read', 'can_dag_edit']
view_menus = cached_app().appbuilder.sm.get_all_view_menu()
convert_permissions(permissions, view_menus, upgrade_action, upgrade_dag_id)
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
permissions = ['can_read', 'can_edit']
vms = cached_app().appbuilder.sm.get_all_view_menu()
- view_menus = [
- vm for vm in vms if (vm.name == permissions.RESOURCE_DAGS or vm.name.startswith('DAG:'))
- ]
+ view_menus = [vm for vm in vms if (vm.name == permissions.RESOURCE_DAGS or vm.name.startswith('DAG:'))]
convert_permissions(permissions, view_menus, downgrade_action, downgrade_dag_id)
@@ -63,7 +61,7 @@ def downgrade_dag_id(dag_id):
if dag_id == permissions.RESOURCE_DAGS:
return 'all_dags'
if dag_id.startswith("DAG:"):
- return dag_id[len("DAG:"):]
+ return dag_id[len("DAG:") :]
return dag_id
diff --git a/airflow/migrations/versions/8504051e801b_xcom_dag_task_indices.py b/airflow/migrations/versions/8504051e801b_xcom_dag_task_indices.py
index 7eadb56000137..6f60ac6b12335 100644
--- a/airflow/migrations/versions/8504051e801b_xcom_dag_task_indices.py
+++ b/airflow/migrations/versions/8504051e801b_xcom_dag_task_indices.py
@@ -35,8 +35,7 @@
def upgrade():
"""Create Index."""
- op.create_index('idx_xcom_dag_task_date', 'xcom',
- ['dag_id', 'task_id', 'execution_date'], unique=False)
+ op.create_index('idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_id', 'execution_date'], unique=False)
def downgrade():
diff --git a/airflow/migrations/versions/852ae6c715af_add_rendered_task_instance_fields_table.py b/airflow/migrations/versions/852ae6c715af_add_rendered_task_instance_fields_table.py
index 01357e7c6f706..282286d24e745 100644
--- a/airflow/migrations/versions/852ae6c715af_add_rendered_task_instance_fields_table.py
+++ b/airflow/migrations/versions/852ae6c715af_add_rendered_task_instance_fields_table.py
@@ -55,10 +55,10 @@ def upgrade():
sa.Column('task_id', sa.String(length=250), nullable=False),
sa.Column('execution_date', sa.TIMESTAMP(timezone=True), nullable=False),
sa.Column('rendered_fields', json_type(), nullable=False),
- sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date')
+ sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date'),
)
def downgrade():
"""Drop RenderedTaskInstanceFields table"""
- op.drop_table(TABLE_NAME) # pylint: disable=no-member
+ op.drop_table(TABLE_NAME) # pylint: disable=no-member
diff --git a/airflow/migrations/versions/856955da8476_fix_sqlite_foreign_key.py b/airflow/migrations/versions/856955da8476_fix_sqlite_foreign_key.py
index 9b804ce9c9243..fd8936c4be71e 100644
--- a/airflow/migrations/versions/856955da8476_fix_sqlite_foreign_key.py
+++ b/airflow/migrations/versions/856955da8476_fix_sqlite_foreign_key.py
@@ -43,29 +43,30 @@ def upgrade():
# which would fail because referenced user table doesn't exist.
#
# Use batch_alter_table to support SQLite workaround.
- chart_table = sa.Table('chart',
- sa.MetaData(),
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('label', sa.String(length=200), nullable=True),
- sa.Column('conn_id', sa.String(length=250), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=True),
- sa.Column('chart_type', sa.String(length=100), nullable=True),
- sa.Column('sql_layout', sa.String(length=50), nullable=True),
- sa.Column('sql', sa.Text(), nullable=True),
- sa.Column('y_log_scale', sa.Boolean(), nullable=True),
- sa.Column('show_datatable', sa.Boolean(), nullable=True),
- sa.Column('show_sql', sa.Boolean(), nullable=True),
- sa.Column('height', sa.Integer(), nullable=True),
- sa.Column('default_params', sa.String(length=5000), nullable=True),
- sa.Column('x_is_date', sa.Boolean(), nullable=True),
- sa.Column('iteration_no', sa.Integer(), nullable=True),
- sa.Column('last_modified', sa.DateTime(), nullable=True),
- sa.PrimaryKeyConstraint('id'))
+ chart_table = sa.Table(
+ 'chart',
+ sa.MetaData(),
+ sa.Column('id', sa.Integer(), nullable=False),
+ sa.Column('label', sa.String(length=200), nullable=True),
+ sa.Column('conn_id', sa.String(length=250), nullable=False),
+ sa.Column('user_id', sa.Integer(), nullable=True),
+ sa.Column('chart_type', sa.String(length=100), nullable=True),
+ sa.Column('sql_layout', sa.String(length=50), nullable=True),
+ sa.Column('sql', sa.Text(), nullable=True),
+ sa.Column('y_log_scale', sa.Boolean(), nullable=True),
+ sa.Column('show_datatable', sa.Boolean(), nullable=True),
+ sa.Column('show_sql', sa.Boolean(), nullable=True),
+ sa.Column('height', sa.Integer(), nullable=True),
+ sa.Column('default_params', sa.String(length=5000), nullable=True),
+ sa.Column('x_is_date', sa.Boolean(), nullable=True),
+ sa.Column('iteration_no', sa.Integer(), nullable=True),
+ sa.Column('last_modified', sa.DateTime(), nullable=True),
+ sa.PrimaryKeyConstraint('id'),
+ )
with op.batch_alter_table('chart', copy_from=chart_table) as batch_op:
- batch_op.create_foreign_key('chart_user_id_fkey', 'users',
- ['user_id'], ['id'])
+ batch_op.create_foreign_key('chart_user_id_fkey', 'users', ['user_id'], ['id'])
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
# Downgrade would fail because the broken FK constraint can't be re-created.
pass
diff --git a/airflow/migrations/versions/86770d1215c0_add_kubernetes_scheduler_uniqueness.py b/airflow/migrations/versions/86770d1215c0_add_kubernetes_scheduler_uniqueness.py
index b5928fdcd82ea..db3ccdc043bf4 100644
--- a/airflow/migrations/versions/86770d1215c0_add_kubernetes_scheduler_uniqueness.py
+++ b/airflow/migrations/versions/86770d1215c0_add_kubernetes_scheduler_uniqueness.py
@@ -35,34 +35,25 @@
RESOURCE_TABLE = "kube_worker_uuid"
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
columns_and_constraints = [
sa.Column("one_row_id", sa.Boolean, server_default=sa.true(), primary_key=True),
- sa.Column("worker_uuid", sa.String(255))
+ sa.Column("worker_uuid", sa.String(255)),
]
conn = op.get_bind()
# alembic creates an invalid SQL for mssql and mysql dialects
if conn.dialect.name in {"mysql"}:
- columns_and_constraints.append(
- sa.CheckConstraint("one_row_id<>0", name="kube_worker_one_row_id")
- )
+ columns_and_constraints.append(sa.CheckConstraint("one_row_id<>0", name="kube_worker_one_row_id"))
elif conn.dialect.name not in {"mssql"}:
- columns_and_constraints.append(
- sa.CheckConstraint("one_row_id", name="kube_worker_one_row_id")
- )
+ columns_and_constraints.append(sa.CheckConstraint("one_row_id", name="kube_worker_one_row_id"))
- table = op.create_table(
- RESOURCE_TABLE,
- *columns_and_constraints
- )
+ table = op.create_table(RESOURCE_TABLE, *columns_and_constraints)
- op.bulk_insert(table, [
- {"worker_uuid": ""}
- ])
+ op.bulk_insert(table, [{"worker_uuid": ""}])
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_table(RESOURCE_TABLE)
diff --git a/airflow/migrations/versions/8d48763f6d53_add_unique_constraint_to_conn_id.py b/airflow/migrations/versions/8d48763f6d53_add_unique_constraint_to_conn_id.py
index 34ce9e5aa5679..2c743e4813185 100644
--- a/airflow/migrations/versions/8d48763f6d53_add_unique_constraint_to_conn_id.py
+++ b/airflow/migrations/versions/8d48763f6d53_add_unique_constraint_to_conn_id.py
@@ -38,16 +38,9 @@ def upgrade():
"""Apply add unique constraint to conn_id and set it as non-nullable"""
try:
with op.batch_alter_table('connection') as batch_op:
- batch_op.create_unique_constraint(
- constraint_name="unique_conn_id",
- columns=["conn_id"]
- )
+ batch_op.create_unique_constraint(constraint_name="unique_conn_id", columns=["conn_id"])
- batch_op.alter_column(
- "conn_id",
- nullable=False,
- existing_type=sa.String(250)
- )
+ batch_op.alter_column("conn_id", nullable=False, existing_type=sa.String(250))
except sa.exc.IntegrityError:
raise Exception("Make sure there are no duplicate connections with the same conn_id or null values")
@@ -55,13 +48,6 @@ def upgrade():
def downgrade():
"""Unapply add unique constraint to conn_id and set it as non-nullable"""
with op.batch_alter_table('connection') as batch_op:
- batch_op.drop_constraint(
- constraint_name="unique_conn_id",
- type_="unique"
- )
+ batch_op.drop_constraint(constraint_name="unique_conn_id", type_="unique")
- batch_op.alter_column(
- "conn_id",
- nullable=True,
- existing_type=sa.String(250)
- )
+ batch_op.alter_column("conn_id", nullable=True, existing_type=sa.String(250))
diff --git a/airflow/migrations/versions/939bb1e647c8_task_reschedule_fk_on_cascade_delete.py b/airflow/migrations/versions/939bb1e647c8_task_reschedule_fk_on_cascade_delete.py
index 4465ba1ae5b8c..ffb61a39eaaa8 100644
--- a/airflow/migrations/versions/939bb1e647c8_task_reschedule_fk_on_cascade_delete.py
+++ b/airflow/migrations/versions/939bb1e647c8_task_reschedule_fk_on_cascade_delete.py
@@ -32,30 +32,24 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
with op.batch_alter_table('task_reschedule') as batch_op:
- batch_op.drop_constraint(
- 'task_reschedule_dag_task_date_fkey',
- type_='foreignkey'
- )
+ batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey')
batch_op.create_foreign_key(
'task_reschedule_dag_task_date_fkey',
'task_instance',
['task_id', 'dag_id', 'execution_date'],
['task_id', 'dag_id', 'execution_date'],
- ondelete='CASCADE'
+ ondelete='CASCADE',
)
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
with op.batch_alter_table('task_reschedule') as batch_op:
- batch_op.drop_constraint(
- 'task_reschedule_dag_task_date_fkey',
- type_='foreignkey'
- )
+ batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey')
batch_op.create_foreign_key(
'task_reschedule_dag_task_date_fkey',
'task_instance',
['task_id', 'dag_id', 'execution_date'],
- ['task_id', 'dag_id', 'execution_date']
+ ['task_id', 'dag_id', 'execution_date'],
)
diff --git a/airflow/migrations/versions/947454bf1dff_add_ti_job_id_index.py b/airflow/migrations/versions/947454bf1dff_add_ti_job_id_index.py
index b18224e6a6158..a1d8b8f6099ba 100644
--- a/airflow/migrations/versions/947454bf1dff_add_ti_job_id_index.py
+++ b/airflow/migrations/versions/947454bf1dff_add_ti_job_id_index.py
@@ -32,9 +32,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.create_index('ti_job_id', 'task_instance', ['job_id'], unique=False)
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_index('ti_job_id', table_name='task_instance')
diff --git a/airflow/migrations/versions/952da73b5eff_add_dag_code_table.py b/airflow/migrations/versions/952da73b5eff_add_dag_code_table.py
index a937b53055918..63fb68999e6ec 100644
--- a/airflow/migrations/versions/952da73b5eff_add_dag_code_table.py
+++ b/airflow/migrations/versions/952da73b5eff_add_dag_code_table.py
@@ -51,20 +51,22 @@ class SerializedDagModel(Base):
fileloc_hash = sa.Column(sa.BigInteger, nullable=False)
def upgrade():
"""Apply add source code table"""
- op.create_table('dag_code', # pylint: disable=no-member
- sa.Column('fileloc_hash', sa.BigInteger(),
- nullable=False, primary_key=True, autoincrement=False),
- sa.Column('fileloc', sa.String(length=2000), nullable=False),
- sa.Column('source_code', sa.UnicodeText(), nullable=False),
- sa.Column('last_updated', sa.TIMESTAMP(timezone=True), nullable=False))
+ op.create_table(
+ 'dag_code', # pylint: disable=no-member
+ sa.Column('fileloc_hash', sa.BigInteger(), nullable=False, primary_key=True, autoincrement=False),
+ sa.Column('fileloc', sa.String(length=2000), nullable=False),
+ sa.Column('source_code', sa.UnicodeText(), nullable=False),
+ sa.Column('last_updated', sa.TIMESTAMP(timezone=True), nullable=False),
+ )
conn = op.get_bind()
if conn.dialect.name != 'sqlite':
if conn.dialect.name == "mssql":
op.drop_index('idx_fileloc_hash', 'serialized_dag')
- op.alter_column(table_name='serialized_dag', column_name='fileloc_hash',
- type_=sa.BigInteger(), nullable=False)
+ op.alter_column(
+ table_name='serialized_dag', column_name='fileloc_hash', type_=sa.BigInteger(), nullable=False
+ )
if conn.dialect.name == "mssql":
op.create_index('idx_fileloc_hash', 'serialized_dag', ['fileloc_hash'])
diff --git a/airflow/migrations/versions/9635ae0956e7_index_faskfail.py b/airflow/migrations/versions/9635ae0956e7_index_faskfail.py
index 9508170dcf0a8..c924b3a06af2a 100644
--- a/airflow/migrations/versions/9635ae0956e7_index_faskfail.py
+++ b/airflow/migrations/versions/9635ae0956e7_index_faskfail.py
@@ -31,11 +31,11 @@
depends_on = None
-def upgrade(): # noqa: D103
- op.create_index('idx_task_fail_dag_task_date',
- 'task_fail',
- ['dag_id', 'task_id', 'execution_date'], unique=False)
+def upgrade(): # noqa: D103
+ op.create_index(
+ 'idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'], unique=False
+ )
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_index('idx_task_fail_dag_task_date', table_name='task_fail')
diff --git a/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py b/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py
index a1ed2344f6b42..ac14cfe3445cd 100644
--- a/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py
+++ b/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py
@@ -58,6 +58,7 @@ def upgrade():
try:
from airflow.configuration import conf
+
concurrency = conf.getint('core', 'dag_concurrency', fallback=16)
except: # noqa
concurrency = 16
diff --git a/airflow/migrations/versions/a4c2fd67d16b_add_pool_slots_field_to_task_instance.py b/airflow/migrations/versions/a4c2fd67d16b_add_pool_slots_field_to_task_instance.py
index 1b742e4245649..121c7fa5e4fef 100644
--- a/airflow/migrations/versions/a4c2fd67d16b_add_pool_slots_field_to_task_instance.py
+++ b/airflow/migrations/versions/a4c2fd67d16b_add_pool_slots_field_to_task_instance.py
@@ -34,9 +34,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.add_column('task_instance', sa.Column('pool_slots', sa.Integer, default=1))
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_column('task_instance', 'pool_slots')
diff --git a/airflow/migrations/versions/a56c9515abdc_remove_dag_stat_table.py b/airflow/migrations/versions/a56c9515abdc_remove_dag_stat_table.py
index 05adcae8b93ff..a9ea301e2737f 100644
--- a/airflow/migrations/versions/a56c9515abdc_remove_dag_stat_table.py
+++ b/airflow/migrations/versions/a56c9515abdc_remove_dag_stat_table.py
@@ -41,9 +41,11 @@ def upgrade():
def downgrade():
"""Create dag_stats table"""
- op.create_table('dag_stats',
- sa.Column('dag_id', sa.String(length=250), nullable=False),
- sa.Column('state', sa.String(length=50), nullable=False),
- sa.Column('count', sa.Integer(), nullable=False, default=0),
- sa.Column('dirty', sa.Boolean(), nullable=False, default=False),
- sa.PrimaryKeyConstraint('dag_id', 'state'))
+ op.create_table(
+ 'dag_stats',
+ sa.Column('dag_id', sa.String(length=250), nullable=False),
+ sa.Column('state', sa.String(length=50), nullable=False),
+ sa.Column('count', sa.Integer(), nullable=False, default=0),
+ sa.Column('dirty', sa.Boolean(), nullable=False, default=False),
+ sa.PrimaryKeyConstraint('dag_id', 'state'),
+ )
diff --git a/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py b/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py
index 7c122ef3df1c6..29cce5feee310 100644
--- a/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py
+++ b/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py
@@ -42,10 +42,7 @@ def upgrade():
conn = op.get_bind()
if conn.dialect.name == "mysql":
op.alter_column(
- table_name=TABLE_NAME,
- column_name=COLUMN_NAME,
- type_=mysql.TIMESTAMP(fsp=6),
- nullable=False
+ table_name=TABLE_NAME, column_name=COLUMN_NAME, type_=mysql.TIMESTAMP(fsp=6), nullable=False
)
@@ -54,8 +51,5 @@ def downgrade():
conn = op.get_bind()
if conn.dialect.name == "mysql":
op.alter_column(
- table_name=TABLE_NAME,
- column_name=COLUMN_NAME,
- type_=mysql.TIMESTAMP(),
- nullable=False
+ table_name=TABLE_NAME, column_name=COLUMN_NAME, type_=mysql.TIMESTAMP(), nullable=False
)
diff --git a/airflow/migrations/versions/b0125267960b_merge_heads.py b/airflow/migrations/versions/b0125267960b_merge_heads.py
index 007b99a2a3921..5c05dd78d3711 100644
--- a/airflow/migrations/versions/b0125267960b_merge_heads.py
+++ b/airflow/migrations/versions/b0125267960b_merge_heads.py
@@ -31,9 +31,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
pass
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
pass
diff --git a/airflow/migrations/versions/bba5a7cfc896_add_a_column_to_track_the_encryption_.py b/airflow/migrations/versions/bba5a7cfc896_add_a_column_to_track_the_encryption_.py
index b03b42c61d7cf..4b2cacd90f775 100644
--- a/airflow/migrations/versions/bba5a7cfc896_add_a_column_to_track_the_encryption_.py
+++ b/airflow/migrations/versions/bba5a7cfc896_add_a_column_to_track_the_encryption_.py
@@ -34,10 +34,9 @@
depends_on = None
-def upgrade(): # noqa: D103
- op.add_column('connection',
- sa.Column('is_extra_encrypted', sa.Boolean, default=False))
+def upgrade(): # noqa: D103
+ op.add_column('connection', sa.Column('is_extra_encrypted', sa.Boolean, default=False))
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_column('connection', 'is_extra_encrypted')
diff --git a/airflow/migrations/versions/bbc73705a13e_add_notification_sent_column_to_sla_miss.py b/airflow/migrations/versions/bbc73705a13e_add_notification_sent_column_to_sla_miss.py
index f5d94f208ace2..2e73d05890950 100644
--- a/airflow/migrations/versions/bbc73705a13e_add_notification_sent_column_to_sla_miss.py
+++ b/airflow/migrations/versions/bbc73705a13e_add_notification_sent_column_to_sla_miss.py
@@ -33,9 +33,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.add_column('sla_miss', sa.Column('notification_sent', sa.Boolean, default=False))
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_column('sla_miss', 'notification_sent')
diff --git a/airflow/migrations/versions/bdaa763e6c56_make_xcom_value_column_a_large_binary.py b/airflow/migrations/versions/bdaa763e6c56_make_xcom_value_column_a_large_binary.py
index 1bb0cdcd21dc9..cd4fa0d51c04a 100644
--- a/airflow/migrations/versions/bdaa763e6c56_make_xcom_value_column_a_large_binary.py
+++ b/airflow/migrations/versions/bdaa763e6c56_make_xcom_value_column_a_large_binary.py
@@ -34,7 +34,7 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
# There can be data truncation here as LargeBinary can be smaller than the pickle
# type.
# use batch_alter_table to support SQLite workaround
@@ -42,7 +42,7 @@ def upgrade(): # noqa: D103
batch_op.alter_column('value', type_=sa.LargeBinary())
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
# use batch_alter_table to support SQLite workaround
with op.batch_alter_table("xcom") as batch_op:
batch_op.alter_column('value', type_=sa.PickleType(pickler=dill))
diff --git a/airflow/migrations/versions/bef4f3d11e8b_drop_kuberesourceversion_and_.py b/airflow/migrations/versions/bef4f3d11e8b_drop_kuberesourceversion_and_.py
index daa00d611ccd5..d58d959ba1e16 100644
--- a/airflow/migrations/versions/bef4f3d11e8b_drop_kuberesourceversion_and_.py
+++ b/airflow/migrations/versions/bef4f3d11e8b_drop_kuberesourceversion_and_.py
@@ -53,35 +53,26 @@ def downgrade():
def _add_worker_uuid_table():
columns_and_constraints = [
sa.Column("one_row_id", sa.Boolean, server_default=sa.true(), primary_key=True),
- sa.Column("worker_uuid", sa.String(255))
+ sa.Column("worker_uuid", sa.String(255)),
]
conn = op.get_bind()
# alembic creates an invalid SQL for mssql and mysql dialects
if conn.dialect.name in {"mysql"}:
- columns_and_constraints.append(
- sa.CheckConstraint("one_row_id<>0", name="kube_worker_one_row_id")
- )
+ columns_and_constraints.append(sa.CheckConstraint("one_row_id<>0", name="kube_worker_one_row_id"))
elif conn.dialect.name not in {"mssql"}:
- columns_and_constraints.append(
- sa.CheckConstraint("one_row_id", name="kube_worker_one_row_id")
- )
+ columns_and_constraints.append(sa.CheckConstraint("one_row_id", name="kube_worker_one_row_id"))
- table = op.create_table(
- WORKER_RESOURCEVERSION_TABLE,
- *columns_and_constraints
- )
+ table = op.create_table(WORKER_RESOURCEVERSION_TABLE, *columns_and_constraints)
- op.bulk_insert(table, [
- {"worker_uuid": ""}
- ])
+ op.bulk_insert(table, [{"worker_uuid": ""}])
def _add_resource_table():
columns_and_constraints = [
sa.Column("one_row_id", sa.Boolean, server_default=sa.true(), primary_key=True),
- sa.Column("resource_version", sa.String(255))
+ sa.Column("resource_version", sa.String(255)),
]
conn = op.get_bind()
@@ -96,11 +87,6 @@ def _add_resource_table():
sa.CheckConstraint("one_row_id", name="kube_resource_version_one_row_id")
)
- table = op.create_table(
- WORKER_RESOURCEVERSION_TABLE,
- *columns_and_constraints
- )
+ table = op.create_table(WORKER_RESOURCEVERSION_TABLE, *columns_and_constraints)
- op.bulk_insert(table, [
- {"resource_version": ""}
- ])
+ op.bulk_insert(table, [{"resource_version": ""}])
diff --git a/airflow/migrations/versions/bf00311e1990_add_index_to_taskinstance.py b/airflow/migrations/versions/bf00311e1990_add_index_to_taskinstance.py
index d03868241ed36..845ce35a026c1 100644
--- a/airflow/migrations/versions/bf00311e1990_add_index_to_taskinstance.py
+++ b/airflow/migrations/versions/bf00311e1990_add_index_to_taskinstance.py
@@ -33,14 +33,9 @@
depends_on = None
-def upgrade(): # noqa: D103
- op.create_index(
- 'ti_dag_date',
- 'task_instance',
- ['dag_id', 'execution_date'],
- unique=False
- )
+def upgrade(): # noqa: D103
+ op.create_index('ti_dag_date', 'task_instance', ['dag_id', 'execution_date'], unique=False)
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_index('ti_dag_date', table_name='task_instance')
diff --git a/airflow/migrations/versions/c8ffec048a3b_add_fields_to_dag.py b/airflow/migrations/versions/c8ffec048a3b_add_fields_to_dag.py
index bd0453a07dd37..c620286f7f15e 100644
--- a/airflow/migrations/versions/c8ffec048a3b_add_fields_to_dag.py
+++ b/airflow/migrations/versions/c8ffec048a3b_add_fields_to_dag.py
@@ -34,11 +34,11 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.add_column('dag', sa.Column('description', sa.Text(), nullable=True))
op.add_column('dag', sa.Column('default_view', sa.String(25), nullable=True))
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_column('dag', 'description')
op.drop_column('dag', 'default_view')
diff --git a/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py b/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py
index 46f9178781b6b..e6169b2728f28 100644
--- a/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py
+++ b/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py
@@ -31,6 +31,7 @@
from airflow import settings
from airflow.models import DagBag
+
# revision identifiers, used by Alembic.
from airflow.models.base import COLLATION_ARGS
@@ -54,7 +55,7 @@ class TaskInstance(Base): # noqa: D101 # type: ignore
try_number = Column(Integer, default=0)
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.add_column('task_instance', sa.Column('max_tries', sa.Integer, server_default="-1"))
# Check if table task_instance exist before data migration. This check is
# needed for database that does not create table until migration finishes.
@@ -69,15 +70,11 @@ def upgrade(): # noqa: D103
sessionmaker = sa.orm.sessionmaker()
session = sessionmaker(bind=connection)
dagbag = DagBag(settings.DAGS_FOLDER)
- query = session.query(sa.func.count(TaskInstance.max_tries)).filter(
- TaskInstance.max_tries == -1
- )
+ query = session.query(sa.func.count(TaskInstance.max_tries)).filter(TaskInstance.max_tries == -1)
# Separate db query in batch to prevent loading entire table
# into memory and cause out of memory error.
while query.scalar():
- tis = session.query(TaskInstance).filter(
- TaskInstance.max_tries == -1
- ).limit(BATCH_SIZE).all()
+ tis = session.query(TaskInstance).filter(TaskInstance.max_tries == -1).limit(BATCH_SIZE).all()
for ti in tis:
dag = dagbag.get_dag(ti.dag_id)
if not dag or not dag.has_task(ti.task_id):
@@ -100,20 +97,16 @@ def upgrade(): # noqa: D103
session.commit()
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
engine = settings.engine
if engine.dialect.has_table(engine, 'task_instance'):
connection = op.get_bind()
sessionmaker = sa.orm.sessionmaker()
session = sessionmaker(bind=connection)
dagbag = DagBag(settings.DAGS_FOLDER)
- query = session.query(sa.func.count(TaskInstance.max_tries)).filter(
- TaskInstance.max_tries != -1
- )
+ query = session.query(sa.func.count(TaskInstance.max_tries)).filter(TaskInstance.max_tries != -1)
while query.scalar():
- tis = session.query(TaskInstance).filter(
- TaskInstance.max_tries != -1
- ).limit(BATCH_SIZE).all()
+ tis = session.query(TaskInstance).filter(TaskInstance.max_tries != -1).limit(BATCH_SIZE).all()
for ti in tis:
dag = dagbag.get_dag(ti.dag_id)
if not dag or not dag.has_task(ti.task_id):
@@ -124,8 +117,7 @@ def downgrade(): # noqa: D103
# left to retry by itself. So the current try_number should be
# max number of self retry (task.retries) minus number of
# times left for task instance to try the task.
- ti.try_number = max(0, task.retries - (ti.max_tries -
- ti.try_number))
+ ti.try_number = max(0, task.retries - (ti.max_tries - ti.try_number))
ti.max_tries = -1
session.merge(ti)
session.commit()
diff --git a/airflow/migrations/versions/cf5dc11e79ad_drop_user_and_chart.py b/airflow/migrations/versions/cf5dc11e79ad_drop_user_and_chart.py
index 86e976983a154..35b07a4c9e9df 100644
--- a/airflow/migrations/versions/cf5dc11e79ad_drop_user_and_chart.py
+++ b/airflow/migrations/versions/cf5dc11e79ad_drop_user_and_chart.py
@@ -34,7 +34,7 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
# We previously had a KnownEvent's table, but we deleted the table without
# a down migration to remove it (so we didn't delete anyone's data if they
# were happening to use the feature.
@@ -49,13 +49,15 @@ def upgrade(): # noqa: D103
op.drop_constraint('known_event_user_id_fkey', 'known_event')
if "chart" in tables:
- op.drop_table("chart", )
+ op.drop_table(
+ "chart",
+ )
if "users" in tables:
op.drop_table("users")
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
conn = op.get_bind()
op.create_table(
@@ -66,7 +68,7 @@ def downgrade(): # noqa: D103
sa.Column('password', sa.String(255)),
sa.Column('superuser', sa.Boolean(), default=False),
sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('username')
+ sa.UniqueConstraint('username'),
)
op.create_table(
@@ -86,8 +88,11 @@ def downgrade(): # noqa: D103
sa.Column('x_is_date', sa.Boolean(), nullable=True),
sa.Column('iteration_no', sa.Integer(), nullable=True),
sa.Column('last_modified', sa.DateTime(), nullable=True),
- sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
- sa.PrimaryKeyConstraint('id')
+ sa.ForeignKeyConstraint(
+ ['user_id'],
+ ['users.id'],
+ ),
+ sa.PrimaryKeyConstraint('id'),
)
if conn.dialect.name == 'mysql':
diff --git a/airflow/migrations/versions/d2ae31099d61_increase_text_size_for_mysql.py b/airflow/migrations/versions/d2ae31099d61_increase_text_size_for_mysql.py
index bb6f2a54f902b..e4f81569c4d61 100644
--- a/airflow/migrations/versions/d2ae31099d61_increase_text_size_for_mysql.py
+++ b/airflow/migrations/versions/d2ae31099d61_increase_text_size_for_mysql.py
@@ -33,11 +33,11 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
if context.config.get_main_option('sqlalchemy.url').startswith('mysql'):
op.alter_column(table_name='variable', column_name='val', type_=mysql.MEDIUMTEXT)
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
if context.config.get_main_option('sqlalchemy.url').startswith('mysql'):
op.alter_column(table_name='variable', column_name='val', type_=mysql.TEXT)
diff --git a/airflow/migrations/versions/d38e04c12aa2_add_serialized_dag_table.py b/airflow/migrations/versions/d38e04c12aa2_add_serialized_dag_table.py
index 2dd35a901e702..2a446e6a169f8 100644
--- a/airflow/migrations/versions/d38e04c12aa2_add_serialized_dag_table.py
+++ b/airflow/migrations/versions/d38e04c12aa2_add_serialized_dag_table.py
@@ -47,24 +47,23 @@ def upgrade():
except (sa.exc.OperationalError, sa.exc.ProgrammingError):
json_type = sa.Text
- op.create_table('serialized_dag', # pylint: disable=no-member
- sa.Column('dag_id', sa.String(length=250), nullable=False),
- sa.Column('fileloc', sa.String(length=2000), nullable=False),
- sa.Column('fileloc_hash', sa.Integer(), nullable=False),
- sa.Column('data', json_type(), nullable=False),
- sa.Column('last_updated', sa.DateTime(), nullable=False),
- sa.PrimaryKeyConstraint('dag_id'))
- op.create_index( # pylint: disable=no-member
- 'idx_fileloc_hash', 'serialized_dag', ['fileloc_hash'])
+ op.create_table(
+ 'serialized_dag', # pylint: disable=no-member
+ sa.Column('dag_id', sa.String(length=250), nullable=False),
+ sa.Column('fileloc', sa.String(length=2000), nullable=False),
+ sa.Column('fileloc_hash', sa.Integer(), nullable=False),
+ sa.Column('data', json_type(), nullable=False),
+ sa.Column('last_updated', sa.DateTime(), nullable=False),
+ sa.PrimaryKeyConstraint('dag_id'),
+ )
+ op.create_index('idx_fileloc_hash', 'serialized_dag', ['fileloc_hash']) # pylint: disable=no-member
if conn.dialect.name == "mysql":
conn.execute("SET time_zone = '+00:00'")
cur = conn.execute("SELECT @@explicit_defaults_for_timestamp")
res = cur.fetchall()
if res[0][0] == 0:
- raise Exception(
- "Global variable explicit_defaults_for_timestamp needs to be on (1) for mysql"
- )
+ raise Exception("Global variable explicit_defaults_for_timestamp needs to be on (1) for mysql")
op.alter_column( # pylint: disable=no-member
table_name="serialized_dag",
@@ -91,4 +90,4 @@ def upgrade():
def downgrade():
"""Downgrade version."""
- op.drop_table('serialized_dag') # pylint: disable=no-member
+ op.drop_table('serialized_dag') # pylint: disable=no-member
diff --git a/airflow/migrations/versions/da3f683c3a5a_add_dag_hash_column_to_serialized_dag_.py b/airflow/migrations/versions/da3f683c3a5a_add_dag_hash_column_to_serialized_dag_.py
index 4dbc77a4c5aa4..e757828731d76 100644
--- a/airflow/migrations/versions/da3f683c3a5a_add_dag_hash_column_to_serialized_dag_.py
+++ b/airflow/migrations/versions/da3f683c3a5a_add_dag_hash_column_to_serialized_dag_.py
@@ -38,7 +38,8 @@ def upgrade():
"""Apply Add dag_hash Column to serialized_dag table"""
op.add_column(
'serialized_dag',
- sa.Column('dag_hash', sa.String(32), nullable=False, server_default='Hash not calculated yet'))
+ sa.Column('dag_hash', sa.String(32), nullable=False, server_default='Hash not calculated yet'),
+ )
def downgrade():
diff --git a/airflow/migrations/versions/dd25f486b8ea_add_idx_log_dag.py b/airflow/migrations/versions/dd25f486b8ea_add_idx_log_dag.py
index c3530945d8269..220535ac7c5e8 100644
--- a/airflow/migrations/versions/dd25f486b8ea_add_idx_log_dag.py
+++ b/airflow/migrations/versions/dd25f486b8ea_add_idx_log_dag.py
@@ -31,9 +31,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.create_index('idx_log_dag', 'log', ['dag_id'], unique=False)
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_index('idx_log_dag', table_name='log')
diff --git a/airflow/migrations/versions/dd4ecb8fbee3_add_schedule_interval_to_dag.py b/airflow/migrations/versions/dd4ecb8fbee3_add_schedule_interval_to_dag.py
index ed335d2e3508b..b5fdc29b9a13d 100644
--- a/airflow/migrations/versions/dd4ecb8fbee3_add_schedule_interval_to_dag.py
+++ b/airflow/migrations/versions/dd4ecb8fbee3_add_schedule_interval_to_dag.py
@@ -34,9 +34,9 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
op.add_column('dag', sa.Column('schedule_interval', sa.Text(), nullable=True))
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_column('dag', 'schedule_interval')
diff --git a/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py b/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py
index 27227aee9377d..ed73a7e6ab92b 100644
--- a/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py
+++ b/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py
@@ -35,15 +35,15 @@
depends_on = None
-def mssql_timestamp(): # noqa: D103
+def mssql_timestamp(): # noqa: D103
return sa.DateTime()
-def mysql_timestamp(): # noqa: D103
+def mysql_timestamp(): # noqa: D103
return mysql.TIMESTAMP(fsp=6)
-def sa_timestamp(): # noqa: D103
+def sa_timestamp(): # noqa: D103
return sa.TIMESTAMP(timezone=True)
@@ -74,14 +74,9 @@ def upgrade(): # noqa: D103
sa.Column('execution_context', sa.Text(), nullable=True),
sa.Column('created_at', timestamp(), default=func.now(), nullable=False),
sa.Column('updated_at', timestamp(), default=func.now(), nullable=False),
- sa.PrimaryKeyConstraint('id')
- )
- op.create_index(
- 'ti_primary_key',
- 'sensor_instance',
- ['dag_id', 'task_id', 'execution_date'],
- unique=True
+ sa.PrimaryKeyConstraint('id'),
)
+ op.create_index('ti_primary_key', 'sensor_instance', ['dag_id', 'task_id', 'execution_date'], unique=True)
op.create_index('si_hashcode', 'sensor_instance', ['hashcode'], unique=False)
op.create_index('si_shardcode', 'sensor_instance', ['shardcode'], unique=False)
op.create_index('si_state_shard', 'sensor_instance', ['state', 'shardcode'], unique=False)
diff --git a/airflow/migrations/versions/e3a246e0dc1_current_schema.py b/airflow/migrations/versions/e3a246e0dc1_current_schema.py
index dfa4f5818c69f..60e6cdf3b3c29 100644
--- a/airflow/migrations/versions/e3a246e0dc1_current_schema.py
+++ b/airflow/migrations/versions/e3a246e0dc1_current_schema.py
@@ -38,7 +38,7 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
tables = inspector.get_table_names()
@@ -55,7 +55,7 @@ def upgrade(): # noqa: D103
sa.Column('password', sa.String(length=500), nullable=True),
sa.Column('port', sa.Integer(), nullable=True),
sa.Column('extra', sa.String(length=5000), nullable=True),
- sa.PrimaryKeyConstraint('id')
+ sa.PrimaryKeyConstraint('id'),
)
if 'dag' not in tables:
op.create_table(
@@ -71,7 +71,7 @@ def upgrade(): # noqa: D103
sa.Column('pickle_id', sa.Integer(), nullable=True),
sa.Column('fileloc', sa.String(length=2000), nullable=True),
sa.Column('owners', sa.String(length=2000), nullable=True),
- sa.PrimaryKeyConstraint('dag_id')
+ sa.PrimaryKeyConstraint('dag_id'),
)
if 'dag_pickle' not in tables:
op.create_table(
@@ -80,7 +80,7 @@ def upgrade(): # noqa: D103
sa.Column('pickle', sa.PickleType(), nullable=True),
sa.Column('created_dttm', sa.DateTime(), nullable=True),
sa.Column('pickle_hash', sa.BigInteger(), nullable=True),
- sa.PrimaryKeyConstraint('id')
+ sa.PrimaryKeyConstraint('id'),
)
if 'import_error' not in tables:
op.create_table(
@@ -89,7 +89,7 @@ def upgrade(): # noqa: D103
sa.Column('timestamp', sa.DateTime(), nullable=True),
sa.Column('filename', sa.String(length=1024), nullable=True),
sa.Column('stacktrace', sa.Text(), nullable=True),
- sa.PrimaryKeyConstraint('id')
+ sa.PrimaryKeyConstraint('id'),
)
if 'job' not in tables:
op.create_table(
@@ -104,14 +104,9 @@ def upgrade(): # noqa: D103
sa.Column('executor_class', sa.String(length=500), nullable=True),
sa.Column('hostname', sa.String(length=500), nullable=True),
sa.Column('unixname', sa.String(length=1000), nullable=True),
- sa.PrimaryKeyConstraint('id')
- )
- op.create_index(
- 'job_type_heart',
- 'job',
- ['job_type', 'latest_heartbeat'],
- unique=False
+ sa.PrimaryKeyConstraint('id'),
)
+ op.create_index('job_type_heart', 'job', ['job_type', 'latest_heartbeat'], unique=False)
if 'log' not in tables:
op.create_table(
'log',
@@ -122,7 +117,7 @@ def upgrade(): # noqa: D103
sa.Column('event', sa.String(length=30), nullable=True),
sa.Column('execution_date', sa.DateTime(), nullable=True),
sa.Column('owner', sa.String(length=500), nullable=True),
- sa.PrimaryKeyConstraint('id')
+ sa.PrimaryKeyConstraint('id'),
)
if 'sla_miss' not in tables:
op.create_table(
@@ -133,7 +128,7 @@ def upgrade(): # noqa: D103
sa.Column('email_sent', sa.Boolean(), nullable=True),
sa.Column('timestamp', sa.DateTime(), nullable=True),
sa.Column('description', sa.Text(), nullable=True),
- sa.PrimaryKeyConstraint('task_id', 'dag_id', 'execution_date')
+ sa.PrimaryKeyConstraint('task_id', 'dag_id', 'execution_date'),
)
if 'slot_pool' not in tables:
op.create_table(
@@ -143,7 +138,7 @@ def upgrade(): # noqa: D103
sa.Column('slots', sa.Integer(), nullable=True),
sa.Column('description', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('pool')
+ sa.UniqueConstraint('pool'),
)
if 'task_instance' not in tables:
op.create_table(
@@ -162,25 +157,12 @@ def upgrade(): # noqa: D103
sa.Column('pool', sa.String(length=50), nullable=True),
sa.Column('queue', sa.String(length=50), nullable=True),
sa.Column('priority_weight', sa.Integer(), nullable=True),
- sa.PrimaryKeyConstraint('task_id', 'dag_id', 'execution_date')
+ sa.PrimaryKeyConstraint('task_id', 'dag_id', 'execution_date'),
)
+ op.create_index('ti_dag_state', 'task_instance', ['dag_id', 'state'], unique=False)
+ op.create_index('ti_pool', 'task_instance', ['pool', 'state', 'priority_weight'], unique=False)
op.create_index(
- 'ti_dag_state',
- 'task_instance',
- ['dag_id', 'state'],
- unique=False
- )
- op.create_index(
- 'ti_pool',
- 'task_instance',
- ['pool', 'state', 'priority_weight'],
- unique=False
- )
- op.create_index(
- 'ti_state_lkp',
- 'task_instance',
- ['dag_id', 'task_id', 'execution_date', 'state'],
- unique=False
+ 'ti_state_lkp', 'task_instance', ['dag_id', 'task_id', 'execution_date', 'state'], unique=False
)
if 'user' not in tables:
@@ -190,7 +172,7 @@ def upgrade(): # noqa: D103
sa.Column('username', sa.String(length=250), nullable=True),
sa.Column('email', sa.String(length=500), nullable=True),
sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('username')
+ sa.UniqueConstraint('username'),
)
if 'variable' not in tables:
op.create_table(
@@ -199,7 +181,7 @@ def upgrade(): # noqa: D103
sa.Column('key', sa.String(length=250), nullable=True),
sa.Column('val', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('key')
+ sa.UniqueConstraint('key'),
)
if 'chart' not in tables:
op.create_table(
@@ -219,8 +201,11 @@ def upgrade(): # noqa: D103
sa.Column('x_is_date', sa.Boolean(), nullable=True),
sa.Column('iteration_no', sa.Integer(), nullable=True),
sa.Column('last_modified', sa.DateTime(), nullable=True),
- sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
- sa.PrimaryKeyConstraint('id')
+ sa.ForeignKeyConstraint(
+ ['user_id'],
+ ['user.id'],
+ ),
+ sa.PrimaryKeyConstraint('id'),
)
if 'xcom' not in tables:
op.create_table(
@@ -228,19 +213,15 @@ def upgrade(): # noqa: D103
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('key', sa.String(length=512, **COLLATION_ARGS), nullable=True),
sa.Column('value', sa.PickleType(), nullable=True),
- sa.Column(
- 'timestamp',
- sa.DateTime(),
- default=func.now(),
- nullable=False),
+ sa.Column('timestamp', sa.DateTime(), default=func.now(), nullable=False),
sa.Column('execution_date', sa.DateTime(), nullable=False),
sa.Column('task_id', sa.String(length=250, **COLLATION_ARGS), nullable=False),
sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False),
- sa.PrimaryKeyConstraint('id')
+ sa.PrimaryKeyConstraint('id'),
)
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
op.drop_table('chart')
op.drop_table('variable')
op.drop_table('user')
diff --git a/airflow/migrations/versions/f23433877c24_fix_mysql_not_null_constraint.py b/airflow/migrations/versions/f23433877c24_fix_mysql_not_null_constraint.py
index f9725d38e4118..7a0a3c8cc1706 100644
--- a/airflow/migrations/versions/f23433877c24_fix_mysql_not_null_constraint.py
+++ b/airflow/migrations/versions/f23433877c24_fix_mysql_not_null_constraint.py
@@ -32,7 +32,7 @@
depends_on = None
-def upgrade(): # noqa: D103
+def upgrade(): # noqa: D103
conn = op.get_bind()
if conn.dialect.name == 'mysql':
conn.execute("SET time_zone = '+00:00'")
@@ -41,7 +41,7 @@ def upgrade(): # noqa: D103
op.alter_column('xcom', 'timestamp', existing_type=mysql.TIMESTAMP(fsp=6), nullable=False)
-def downgrade(): # noqa: D103
+def downgrade(): # noqa: D103
conn = op.get_bind()
if conn.dialect.name == 'mysql':
conn.execute("SET time_zone = '+00:00'")
diff --git a/airflow/migrations/versions/f2ca10b85618_add_dag_stats_table.py b/airflow/migrations/versions/f2ca10b85618_add_dag_stats_table.py
index 7ab2fe943726f..1db0440cbd4a4 100644
--- a/airflow/migrations/versions/f2ca10b85618_add_dag_stats_table.py
+++ b/airflow/migrations/versions/f2ca10b85618_add_dag_stats_table.py
@@ -33,14 +33,16 @@
depends_on = None
-def upgrade(): # noqa: D103
- op.create_table('dag_stats',
- sa.Column('dag_id', sa.String(length=250), nullable=False),
- sa.Column('state', sa.String(length=50), nullable=False),
- sa.Column('count', sa.Integer(), nullable=False, default=0),
- sa.Column('dirty', sa.Boolean(), nullable=False, default=False),
- sa.PrimaryKeyConstraint('dag_id', 'state'))
-
-
-def downgrade(): # noqa: D103
+def upgrade(): # noqa: D103
+ op.create_table(
+ 'dag_stats',
+ sa.Column('dag_id', sa.String(length=250), nullable=False),
+ sa.Column('state', sa.String(length=50), nullable=False),
+ sa.Column('count', sa.Integer(), nullable=False, default=0),
+ sa.Column('dirty', sa.Boolean(), nullable=False, default=False),
+ sa.PrimaryKeyConstraint('dag_id', 'state'),
+ )
+
+
+def downgrade(): # noqa: D103
op.drop_table('dag_stats')
diff --git a/airflow/migrations/versions/fe461863935f_increase_length_for_connection_password.py b/airflow/migrations/versions/fe461863935f_increase_length_for_connection_password.py
index 4b14e8824dcdc..0e97630248a9b 100644
--- a/airflow/migrations/versions/fe461863935f_increase_length_for_connection_password.py
+++ b/airflow/migrations/versions/fe461863935f_increase_length_for_connection_password.py
@@ -37,16 +37,20 @@
def upgrade():
"""Apply increase_length_for_connection_password"""
with op.batch_alter_table('connection', schema=None) as batch_op:
- batch_op.alter_column('password',
- existing_type=sa.VARCHAR(length=500),
- type_=sa.String(length=5000),
- existing_nullable=True)
+ batch_op.alter_column(
+ 'password',
+ existing_type=sa.VARCHAR(length=500),
+ type_=sa.String(length=5000),
+ existing_nullable=True,
+ )
def downgrade():
"""Unapply increase_length_for_connection_password"""
with op.batch_alter_table('connection', schema=None) as batch_op:
- batch_op.alter_column('password',
- existing_type=sa.String(length=5000),
- type_=sa.VARCHAR(length=500),
- existing_nullable=True)
+ batch_op.alter_column(
+ 'password',
+ existing_type=sa.String(length=5000),
+ type_=sa.VARCHAR(length=500),
+ existing_nullable=True,
+ )
diff --git a/airflow/models/base.py b/airflow/models/base.py
index dc1953439ecf1..55584e4f2a27b 100644
--- a/airflow/models/base.py
+++ b/airflow/models/base.py
@@ -26,9 +26,7 @@
SQL_ALCHEMY_SCHEMA = conf.get("core", "SQL_ALCHEMY_SCHEMA")
metadata = (
- None
- if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace()
- else MetaData(schema=SQL_ALCHEMY_SCHEMA)
+ None if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace() else MetaData(schema=SQL_ALCHEMY_SCHEMA)
)
Base = declarative_base(metadata=metadata) # type: Any
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 5e61ad40f1b36..fa4a840fe81e8 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -25,8 +25,20 @@
from abc import ABCMeta, abstractmethod
from datetime import datetime, timedelta
from typing import (
- TYPE_CHECKING, Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Tuple,
- Type, Union,
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ ClassVar,
+ Dict,
+ FrozenSet,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ Union,
)
import attr
@@ -282,8 +294,12 @@ class derived from this one results in the creation of a task object,
pool = "" # type: str
# base list which includes all the attrs that don't need deep copy.
- _base_operator_shallow_copy_attrs: Tuple[str, ...] = \
- ('user_defined_macros', 'user_defined_filters', 'params', '_log',)
+ _base_operator_shallow_copy_attrs: Tuple[str, ...] = (
+ 'user_defined_macros',
+ 'user_defined_filters',
+ 'params',
+ '_log',
+ )
# each operator should override this class attr for shallow copy attrs.
shallow_copy_attrs: Tuple[str, ...] = ()
@@ -365,7 +381,7 @@ def __init__(
inlets: Optional[Any] = None,
outlets: Optional[Any] = None,
task_group: Optional["TaskGroup"] = None,
- **kwargs
+ **kwargs,
):
from airflow.models.dag import DagContext
from airflow.utils.task_group import TaskGroupContext
@@ -375,17 +391,15 @@ def __init__(
if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
raise AirflowException(
"Invalid arguments were passed to {c} (task_id: {t}). Invalid "
- "arguments were:\n**kwargs: {k}".format(
- c=self.__class__.__name__, k=kwargs, t=task_id),
+ "arguments were:\n**kwargs: {k}".format(c=self.__class__.__name__, k=kwargs, t=task_id),
)
warnings.warn(
'Invalid arguments were passed to {c} (task_id: {t}). '
'Support for passing such arguments will be dropped in '
'future. Invalid arguments were:'
- '\n**kwargs: {k}'.format(
- c=self.__class__.__name__, k=kwargs, t=task_id),
+ '\n**kwargs: {k}'.format(c=self.__class__.__name__, k=kwargs, t=task_id),
category=PendingDeprecationWarning,
- stacklevel=3
+ stacklevel=3,
)
validate_key(task_id)
self.task_id = task_id
@@ -412,9 +426,13 @@ def __init__(
if not TriggerRule.is_valid(trigger_rule):
raise AirflowException(
"The trigger_rule must be one of {all_triggers},"
- "'{d}.{t}'; received '{tr}'."
- .format(all_triggers=TriggerRule.all_triggers(),
- d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))
+ "'{d}.{t}'; received '{tr}'.".format(
+ all_triggers=TriggerRule.all_triggers(),
+ d=dag.dag_id if dag else "",
+ t=task_id,
+ tr=trigger_rule,
+ )
+ )
self.trigger_rule = trigger_rule
self.depends_on_past = depends_on_past
@@ -427,8 +445,7 @@ def __init__(
self.pool = Pool.DEFAULT_POOL_NAME if pool is None else pool
self.pool_slots = pool_slots
if self.pool_slots < 1:
- raise AirflowException("pool slots for %s in dag %s cannot be less than 1"
- % (self.task_id, dag.dag_id))
+ raise AirflowException(f"pool slots for {self.task_id} in dag {dag.dag_id} cannot be less than 1")
self.sla = sla
self.execution_timeout = execution_timeout
self.on_execute_callback = on_execute_callback
@@ -448,9 +465,13 @@ def __init__(
if not WeightRule.is_valid(weight_rule):
raise AirflowException(
"The weight_rule must be one of {all_weight_rules},"
- "'{d}.{t}'; received '{tr}'."
- .format(all_weight_rules=WeightRule.all_weight_rules,
- d=dag.dag_id if dag else "", t=task_id, tr=weight_rule))
+ "'{d}.{t}'; received '{tr}'.".format(
+ all_weight_rules=WeightRule.all_weight_rules,
+ d=dag.dag_id if dag else "",
+ t=task_id,
+ tr=weight_rule,
+ )
+ )
self.weight_rule = weight_rule
self.resources: Optional[Resources] = Resources(**resources) if resources else None
self.run_as_user = run_as_user
@@ -468,6 +489,7 @@ def __init__(
# subdag parameter is only set for SubDagOperator.
# Setting it to None by default as other Operators do not have that field
from airflow.models.dag import DAG
+
self.subdag: Optional[DAG] = None
self._log = logging.getLogger("airflow.task.operators")
@@ -480,10 +502,22 @@ def __init__(
self._outlets: List = []
if inlets:
- self._inlets = inlets if isinstance(inlets, list) else [inlets, ]
+ self._inlets = (
+ inlets
+ if isinstance(inlets, list)
+ else [
+ inlets,
+ ]
+ )
if outlets:
- self._outlets = outlets if isinstance(outlets, list) else [outlets, ]
+ self._outlets = (
+ outlets
+ if isinstance(outlets, list)
+ else [
+ outlets,
+ ]
+ )
def __eq__(self, other):
if type(self) is type(other) and self.task_id == other.task_id:
@@ -587,8 +621,7 @@ def dag(self) -> Any:
if self.has_dag():
return self._dag
else:
- raise AirflowException(
- f'Operator {self} has not been assigned to a DAG yet')
+ raise AirflowException(f'Operator {self} has not been assigned to a DAG yet')
@dag.setter
def dag(self, dag: Any):
@@ -597,15 +630,14 @@ def dag(self, dag: Any):
that same DAG are ok.
"""
from airflow.models.dag import DAG
+
if dag is None:
self._dag = None
return
if not isinstance(dag, DAG):
- raise TypeError(
- f'Expected DAG; received {dag.__class__.__name__}')
+ raise TypeError(f'Expected DAG; received {dag.__class__.__name__}')
elif self.has_dag() and self.dag is not dag:
- raise AirflowException(
- f"The DAG assigned to {self} can not be changed.")
+ raise AirflowException(f"The DAG assigned to {self} can not be changed.")
elif self.task_id not in dag.task_dict:
dag.add_task(self)
elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] is not self:
@@ -671,7 +703,7 @@ def set_xcomargs_dependencies(self) -> None:
"""
from airflow.models.xcom_arg import XComArg
- def apply_set_upstream(arg: Any): # noqa
+ def apply_set_upstream(arg: Any): # noqa
if isinstance(arg, XComArg):
self.set_upstream(arg.operator)
elif isinstance(arg, (tuple, set, list)):
@@ -712,10 +744,13 @@ def priority_weight_total(self) -> int:
if not self._dag:
return self.priority_weight
from airflow.models.dag import DAG
+
dag: DAG = self._dag
return self.priority_weight + sum(
- map(lambda task_id: dag.task_dict[task_id].priority_weight,
- self.get_flat_relative_ids(upstream=upstream))
+ map(
+ lambda task_id: dag.task_dict[task_id].priority_weight,
+ self.get_flat_relative_ids(upstream=upstream),
+ )
)
@cached_property
@@ -723,6 +758,7 @@ def operator_extra_link_dict(self) -> Dict[str, Any]:
"""Returns dictionary of all extra links for the operator"""
op_extra_links_from_plugin: Dict[str, Any] = {}
from airflow import plugins_manager
+
plugins_manager.initialize_extra_operators_links_plugins()
if plugins_manager.operator_extra_links is None:
raise AirflowException("Can't load operators")
@@ -730,9 +766,7 @@ def operator_extra_link_dict(self) -> Dict[str, Any]:
if ope.operators and self.__class__ in ope.operators:
op_extra_links_from_plugin.update({ope.name: ope})
- operator_extra_links_all = {
- link.name: link for link in self.operator_extra_links
- }
+ operator_extra_links_all = {link.name: link for link in self.operator_extra_links}
# Extra links defined in Plugins overrides operator links defined in operator
operator_extra_links_all.update(op_extra_links_from_plugin)
@@ -742,6 +776,7 @@ def operator_extra_link_dict(self) -> Dict[str, Any]:
def global_operator_extra_link_dict(self) -> Dict[str, Any]:
"""Returns dictionary of all global extra links"""
from airflow import plugins_manager
+
plugins_manager.initialize_extra_operators_links_plugins()
if plugins_manager.global_operator_extra_links is None:
raise AirflowException("Can't load operators")
@@ -786,8 +821,9 @@ def __deepcopy__(self, memo):
result = cls.__new__(cls)
memo[id(self)] = result
- shallow_copy = cls.shallow_copy_attrs + \
- cls._base_operator_shallow_copy_attrs # pylint: disable=protected-access
+ shallow_copy = (
+ cls.shallow_copy_attrs + cls._base_operator_shallow_copy_attrs
+ ) # pylint: disable=protected-access
for k, v in self.__dict__.items():
if k not in shallow_copy:
@@ -821,8 +857,12 @@ def render_template_fields(self, context: Dict, jinja_env: Optional[jinja2.Envir
self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
def _do_render_template_fields(
- self, parent: Any, template_fields: Iterable[str], context: Dict, jinja_env: jinja2.Environment,
- seen_oids: Set
+ self,
+ parent: Any,
+ template_fields: Iterable[str],
+ context: Dict,
+ jinja_env: jinja2.Environment,
+ seen_oids: Set,
) -> None:
for attr_name in template_fields:
content = getattr(parent, attr_name)
@@ -830,9 +870,12 @@ def _do_render_template_fields(
rendered_content = self.render_template(content, context, jinja_env, seen_oids)
setattr(parent, attr_name, rendered_content)
- def render_template( # pylint: disable=too-many-return-statements
- self, content: Any, context: Dict, jinja_env: Optional[jinja2.Environment] = None,
- seen_oids: Optional[Set] = None
+ def render_template( # pylint: disable=too-many-return-statements
+ self,
+ content: Any,
+ context: Dict,
+ jinja_env: Optional[jinja2.Environment] = None,
+ seen_oids: Optional[Set] = None,
) -> Any:
"""
Render a templated string. The content can be a collection holding multiple templated strings and will
@@ -921,8 +964,7 @@ def resolve_template_files(self) -> None:
content = getattr(self, field, None)
if content is None: # pylint: disable=no-else-continue
continue
- elif isinstance(content, str) and \
- any(content.endswith(ext) for ext in self.template_ext):
+ elif isinstance(content, str) and any(content.endswith(ext) for ext in self.template_ext):
env = self.get_template_env()
try:
setattr(self, field, env.loader.get_source(env, content)[0])
@@ -931,8 +973,9 @@ def resolve_template_files(self) -> None:
elif isinstance(content, list):
env = self.dag.get_template_env()
for i in range(len(content)): # pylint: disable=consider-using-enumerate
- if isinstance(content[i], str) and \
- any(content[i].endswith(ext) for ext in self.template_ext):
+ if isinstance(content[i], str) and any(
+ content[i].endswith(ext) for ext in self.template_ext
+ ):
try:
content[i] = env.loader.get_source(env, content[i])[0]
except Exception as e: # pylint: disable=broad-except
@@ -960,12 +1003,14 @@ def downstream_task_ids(self) -> Set[str]:
return self._downstream_task_ids
@provide_session
- def clear(self,
- start_date: Optional[datetime] = None,
- end_date: Optional[datetime] = None,
- upstream: bool = False,
- downstream: bool = False,
- session: Session = None):
+ def clear(
+ self,
+ start_date: Optional[datetime] = None,
+ end_date: Optional[datetime] = None,
+ upstream: bool = False,
+ downstream: bool = False,
+ session: Session = None,
+ ):
"""
Clears the state of task instances associated with the task, following
the parameters specified.
@@ -980,12 +1025,10 @@ def clear(self,
tasks = [self.task_id]
if upstream:
- tasks += [
- t.task_id for t in self.get_flat_relatives(upstream=True)]
+ tasks += [t.task_id for t in self.get_flat_relatives(upstream=True)]
if downstream:
- tasks += [
- t.task_id for t in self.get_flat_relatives(upstream=False)]
+ tasks += [t.task_id for t in self.get_flat_relatives(upstream=False)]
qry = qry.filter(TaskInstance.task_id.in_(tasks))
results = qry.all()
@@ -995,26 +1038,32 @@ def clear(self,
return count
@provide_session
- def get_task_instances(self, start_date: Optional[datetime] = None,
- end_date: Optional[datetime] = None,
- session: Session = None) -> List[TaskInstance]:
+ def get_task_instances(
+ self,
+ start_date: Optional[datetime] = None,
+ end_date: Optional[datetime] = None,
+ session: Session = None,
+ ) -> List[TaskInstance]:
"""
Get a set of task instance related to this task for a specific date
range.
"""
end_date = end_date or timezone.utcnow()
- return session.query(TaskInstance)\
- .filter(TaskInstance.dag_id == self.dag_id)\
- .filter(TaskInstance.task_id == self.task_id)\
- .filter(TaskInstance.execution_date >= start_date)\
- .filter(TaskInstance.execution_date <= end_date)\
- .order_by(TaskInstance.execution_date)\
+ return (
+ session.query(TaskInstance)
+ .filter(TaskInstance.dag_id == self.dag_id)
+ .filter(TaskInstance.task_id == self.task_id)
+ .filter(TaskInstance.execution_date >= start_date)
+ .filter(TaskInstance.execution_date <= end_date)
+ .order_by(TaskInstance.execution_date)
.all()
+ )
- def get_flat_relative_ids(self,
- upstream: bool = False,
- found_descendants: Optional[Set[str]] = None,
- ) -> Set[str]:
+ def get_flat_relative_ids(
+ self,
+ upstream: bool = False,
+ found_descendants: Optional[Set[str]] = None,
+ ) -> Set[str]:
"""Get a flat set of relatives' ids, either upstream or downstream."""
if not self._dag:
return set()
@@ -1036,17 +1085,18 @@ def get_flat_relatives(self, upstream: bool = False):
if not self._dag:
return set()
from airflow.models.dag import DAG
+
dag: DAG = self._dag
- return list(map(lambda task_id: dag.task_dict[task_id],
- self.get_flat_relative_ids(upstream)))
+ return list(map(lambda task_id: dag.task_dict[task_id], self.get_flat_relative_ids(upstream)))
def run(
- self,
- start_date: Optional[datetime] = None,
- end_date: Optional[datetime] = None,
- ignore_first_depends_on_past: bool = True,
- ignore_ti_state: bool = False,
- mark_success: bool = False) -> None:
+ self,
+ start_date: Optional[datetime] = None,
+ end_date: Optional[datetime] = None,
+ ignore_first_depends_on_past: bool = True,
+ ignore_ti_state: bool = False,
+ mark_success: bool = False,
+ ) -> None:
"""Run a set of task instances for a date range."""
start_date = start_date or self.start_date
end_date = end_date or self.end_date or timezone.utcnow()
@@ -1054,9 +1104,9 @@ def run(
for execution_date in self.dag.date_range(start_date, end_date=end_date):
TaskInstance(self, execution_date).run(
mark_success=mark_success,
- ignore_depends_on_past=(
- execution_date == start_date and ignore_first_depends_on_past),
- ignore_ti_state=ignore_ti_state)
+ ignore_depends_on_past=(execution_date == start_date and ignore_first_depends_on_past),
+ ignore_ti_state=ignore_ti_state,
+ )
def dry_run(self) -> None:
"""Performs dry run for the operator - just render template fields."""
@@ -1088,8 +1138,7 @@ def get_direct_relatives(self, upstream: bool = False) -> List["BaseOperator"]:
return self.downstream_list
def __repr__(self):
- return "".format(
- self=self)
+ return "".format(self=self)
@property
def task_type(self) -> str:
@@ -1099,8 +1148,7 @@ def task_type(self) -> str:
def add_only_new(self, item_set: Set[str], item: str) -> None:
"""Adds only new items to item set"""
if item in item_set:
- self.log.warning(
- 'Dependency %s, %s already registered', self, item)
+ self.log.warning('Dependency %s, %s already registered', self, item)
else:
item_set.add(item)
@@ -1133,25 +1181,29 @@ def _set_relatives(
if not isinstance(task, BaseOperator):
raise AirflowException(
"Relationships can only be set between "
- "Operators; received {}".format(task.__class__.__name__))
+ "Operators; received {}".format(task.__class__.__name__)
+ )
# relationships can only be set if the tasks share a single DAG. Tasks
# without a DAG are assigned to that DAG.
dags = {
task._dag.dag_id: task._dag # type: ignore # pylint: disable=protected-access,no-member
- for task in self.roots + task_list if task.has_dag()} # pylint: disable=no-member
+ for task in self.roots + task_list
+ if task.has_dag() # pylint: disable=no-member
+ }
if len(dags) > 1:
raise AirflowException(
- 'Tried to set relationships between tasks in '
- 'more than one DAG: {}'.format(dags.values()))
+ 'Tried to set relationships between tasks in ' 'more than one DAG: {}'.format(dags.values())
+ )
elif len(dags) == 1:
dag = dags.popitem()[1]
else:
raise AirflowException(
"Tried to create relationships between tasks that don't have "
"DAGs yet. Set the DAG for at least one "
- "task and try again: {}".format([self] + task_list))
+ "task and try again: {}".format([self] + task_list)
+ )
if dag and not self.has_dag():
self.dag = dag
@@ -1184,14 +1236,15 @@ def set_upstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]])
def output(self):
"""Returns reference to XCom pushed by current operator"""
from airflow.models.xcom_arg import XComArg
+
return XComArg(operator=self)
@staticmethod
def xcom_push(
- context: Any,
- key: str,
- value: Any,
- execution_date: Optional[datetime] = None,
+ context: Any,
+ key: str,
+ value: Any,
+ execution_date: Optional[datetime] = None,
) -> None:
"""
Make an XCom available for tasks to pull.
@@ -1208,18 +1261,15 @@ def xcom_push(
task on a future date without it being immediately visible.
:type execution_date: datetime
"""
- context['ti'].xcom_push(
- key=key,
- value=value,
- execution_date=execution_date)
+ context['ti'].xcom_push(key=key, value=value, execution_date=execution_date)
@staticmethod
def xcom_pull(
- context: Any,
- task_ids: Optional[List[str]] = None,
- dag_id: Optional[str] = None,
- key: str = XCOM_RETURN_KEY,
- include_prior_dates: Optional[bool] = None,
+ context: Any,
+ task_ids: Optional[List[str]] = None,
+ dag_id: Optional[str] = None,
+ key: str = XCOM_RETURN_KEY,
+ include_prior_dates: Optional[bool] = None,
) -> Any:
"""
Pull XComs that optionally meet certain criteria.
@@ -1253,16 +1303,15 @@ def xcom_pull(
:type include_prior_dates: bool
"""
return context['ti'].xcom_pull(
- key=key,
- task_ids=task_ids,
- dag_id=dag_id,
- include_prior_dates=include_prior_dates)
+ key=key, task_ids=task_ids, dag_id=dag_id, include_prior_dates=include_prior_dates
+ )
@cached_property
def extra_links(self) -> List[str]:
"""@property: extra links for the task"""
- return list(set(self.operator_extra_link_dict.keys())
- .union(self.global_operator_extra_link_dict.keys()))
+ return list(
+ set(self.operator_extra_link_dict.keys()).union(self.global_operator_extra_link_dict.keys())
+ )
def get_extra_links(self, dttm: datetime, link_name: str) -> Optional[Dict[str, Any]]:
"""
@@ -1288,11 +1337,25 @@ def get_serialized_fields(cls):
"""Stringified DAGs and operators contain exactly these fields."""
if not cls.__serialized_fields:
cls.__serialized_fields = frozenset(
- vars(BaseOperator(task_id='test')).keys() - {
- 'inlets', 'outlets', '_upstream_task_ids', 'default_args', 'dag', '_dag',
+ vars(BaseOperator(task_id='test')).keys()
+ - {
+ 'inlets',
+ 'outlets',
+ '_upstream_task_ids',
+ 'default_args',
+ 'dag',
+ '_dag',
'_BaseOperator__instantiated',
- } | {'_task_type', 'subdag', 'ui_color', 'ui_fgcolor',
- 'template_fields', 'template_fields_renderers'})
+ }
+ | {
+ '_task_type',
+ 'subdag',
+ 'ui_color',
+ 'ui_fgcolor',
+ 'template_fields',
+ 'template_fields_renderers',
+ }
+ )
return cls.__serialized_fields
@@ -1341,19 +1404,23 @@ def chain(*tasks: Union[BaseOperator, Sequence[BaseOperator]]):
if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
raise TypeError(
'Chain not supported between instances of {up_type} and {down_type}'.format(
- up_type=type(up_task), down_type=type(down_task)))
+ up_type=type(up_task), down_type=type(down_task)
+ )
+ )
up_task_list = up_task
down_task_list = down_task
if len(up_task_list) != len(down_task_list):
raise AirflowException(
f'Chain not supported different length Iterable '
- f'but get {len(up_task_list)} and {len(down_task_list)}')
+ f'but get {len(up_task_list)} and {len(down_task_list)}'
+ )
for up_t, down_t in zip(up_task_list, down_task_list):
up_t.set_downstream(down_t)
-def cross_downstream(from_tasks: Sequence[BaseOperator],
- to_tasks: Union[BaseOperator, Sequence[BaseOperator]]):
+def cross_downstream(
+ from_tasks: Sequence[BaseOperator], to_tasks: Union[BaseOperator, Sequence[BaseOperator]]
+):
r"""
Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
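For context on the two helpers reformatted in the hunks above: chain() links sibling lists element-wise, while cross_downstream() links every from-task to every to-task. The following is a minimal, self-contained toy sketch of those semantics in plain Python -- Node, toy_chain, and toy_cross_downstream are hypothetical names, not the Airflow API, and the real chain() also accepts single operators, which this toy omits.

from itertools import product
from typing import List, Sequence


class Node:
    """Stand-in for an operator: records the names of its downstream nodes."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.downstream: List[str] = []

    def set_downstream(self, other: "Node") -> None:
        self.downstream.append(other.name)


def toy_chain(*groups: Sequence[Node]) -> None:
    # Pairwise linking between consecutive groups: [a, b], [c, d] -> a->c, b->d.
    for ups, downs in zip(groups, groups[1:]):
        for up, down in zip(ups, downs):
            up.set_downstream(down)


def toy_cross_downstream(from_tasks: Sequence[Node], to_tasks: Sequence[Node]) -> None:
    # Cross-product linking: every from-task points at every to-task.
    for up, down in product(from_tasks, to_tasks):
        up.set_downstream(down)


a, b, c, d = (Node(n) for n in "abcd")
toy_chain([a, b], [c, d])
toy_cross_downstream([a, b], [c, d])
print(a.downstream)  # ['c', 'c', 'd']
print(b.downstream)  # ['d', 'c', 'd']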
diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index d724a3f18122c..9f50e04d14ab1 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -39,7 +39,7 @@
CONN_TYPE_TO_HOOK = {
"azure_batch": (
"airflow.providers.microsoft.azure.hooks.azure_batch.AzureBatchHook",
- "azure_batch_conn_id"
+ "azure_batch_conn_id",
),
"azure_cosmos": (
"airflow.providers.microsoft.azure.hooks.azure_cosmos.AzureCosmosDBHook",
@@ -55,7 +55,7 @@
"docker": ("airflow.providers.docker.hooks.docker.DockerHook", "docker_conn_id"),
"elasticsearch": (
"airflow.providers.elasticsearch.hooks.elasticsearch.ElasticsearchHook",
- "elasticsearch_conn_id"
+ "elasticsearch_conn_id",
),
"exasol": ("airflow.providers.exasol.hooks.exasol.ExasolHook", "exasol_conn_id"),
"gcpcloudsql": (
@@ -93,10 +93,7 @@
def parse_netloc_to_hostname(*args, **kwargs):
"""This method is deprecated."""
- warnings.warn(
- "This method is deprecated.",
- DeprecationWarning
- )
+ warnings.warn("This method is deprecated.", DeprecationWarning)
return _parse_netloc_to_hostname(*args, **kwargs)
@@ -170,7 +167,7 @@ def __init__(
schema: Optional[str] = None,
port: Optional[int] = None,
extra: Optional[str] = None,
- uri: Optional[str] = None
+ uri: Optional[str] = None,
):
super().__init__()
self.conn_id = conn_id
@@ -196,8 +193,7 @@ def __init__(
def parse_from_uri(self, **uri):
"""This method is deprecated. Please use uri parameter in constructor."""
warnings.warn(
- "This method is deprecated. Please use uri parameter in constructor.",
- DeprecationWarning
+ "This method is deprecated. Please use uri parameter in constructor.", DeprecationWarning
)
self._parse_from_uri(**uri)
@@ -212,10 +208,8 @@ def _parse_from_uri(self, uri: str):
self.host = _parse_netloc_to_hostname(uri_parts)
quoted_schema = uri_parts.path[1:]
self.schema = unquote(quoted_schema) if quoted_schema else quoted_schema
- self.login = unquote(uri_parts.username) \
- if uri_parts.username else uri_parts.username
- self.password = unquote(uri_parts.password) \
- if uri_parts.password else uri_parts.password
+ self.login = unquote(uri_parts.username) if uri_parts.username else uri_parts.username
+ self.password = unquote(uri_parts.password) if uri_parts.password else uri_parts.password
self.port = uri_parts.port
if uri_parts.query:
self.extra = json.dumps(dict(parse_qsl(uri_parts.query, keep_blank_values=True)))
@@ -263,7 +257,10 @@ def get_password(self) -> Optional[str]:
if not fernet.is_encrypted:
raise AirflowException(
"Can't decrypt encrypted password for login={}, \
- FERNET_KEY configuration is missing".format(self.login))
+ FERNET_KEY configuration is missing".format(
+ self.login
+ )
+ )
return fernet.decrypt(bytes(self._password, 'utf-8')).decode()
else:
return self._password
@@ -276,10 +273,9 @@ def set_password(self, value: Optional[str]):
self.is_encrypted = fernet.is_encrypted
@declared_attr
- def password(cls): # pylint: disable=no-self-argument
+ def password(cls): # pylint: disable=no-self-argument
"""Password. The value is decrypted/encrypted when reading/setting the value."""
- return synonym('_password',
- descriptor=property(cls.get_password, cls.set_password))
+ return synonym('_password', descriptor=property(cls.get_password, cls.set_password))
def get_extra(self) -> Dict:
"""Return encrypted extra-data."""
@@ -288,7 +284,10 @@ def get_extra(self) -> Dict:
if not fernet.is_encrypted:
raise AirflowException(
"Can't decrypt `extra` params for login={},\
- FERNET_KEY configuration is missing".format(self.login))
+ FERNET_KEY configuration is missing".format(
+ self.login
+ )
+ )
return fernet.decrypt(bytes(self._extra, 'utf-8')).decode()
else:
return self._extra
@@ -304,10 +303,9 @@ def set_extra(self, value: str):
self.is_extra_encrypted = False
@declared_attr
- def extra(cls): # pylint: disable=no-self-argument
+ def extra(cls): # pylint: disable=no-self-argument
"""Extra data. The value is decrypted/encrypted when reading/setting the value."""
- return synonym('_extra',
- descriptor=property(cls.get_extra, cls.set_extra))
+ return synonym('_extra', descriptor=property(cls.get_extra, cls.set_extra))
def rotate_fernet_key(self):
"""Encrypts data with a new key. See: :ref:`security/fernet`"""
@@ -337,17 +335,17 @@ def log_info(self):
"This method is deprecated. You can read each field individually or "
"use the default representation (__repr__).",
DeprecationWarning,
- stacklevel=2
+ stacklevel=2,
+ )
+ return "id: {}. Host: {}, Port: {}, Schema: {}, " "Login: {}, Password: {}, extra: {}".format(
+ self.conn_id,
+ self.host,
+ self.port,
+ self.schema,
+ self.login,
+ "XXXXXXXX" if self.password else None,
+ "XXXXXXXX" if self.extra_dejson else None,
)
- return ("id: {}. Host: {}, Port: {}, Schema: {}, "
- "Login: {}, Password: {}, extra: {}".
- format(self.conn_id,
- self.host,
- self.port,
- self.schema,
- self.login,
- "XXXXXXXX" if self.password else None,
- "XXXXXXXX" if self.extra_dejson else None))
def debug_info(self):
"""
@@ -358,17 +356,17 @@ def debug_info(self):
"This method is deprecated. You can read each field individually or "
"use the default representation (__repr__).",
DeprecationWarning,
- stacklevel=2
+ stacklevel=2,
+ )
+ return "id: {}. Host: {}, Port: {}, Schema: {}, " "Login: {}, Password: {}, extra: {}".format(
+ self.conn_id,
+ self.host,
+ self.port,
+ self.schema,
+ self.login,
+ "XXXXXXXX" if self.password else None,
+ self.extra_dejson,
)
- return ("id: {}. Host: {}, Port: {}, Schema: {}, "
- "Login: {}, Password: {}, extra: {}".
- format(self.conn_id,
- self.host,
- self.port,
- self.schema,
- self.login,
- "XXXXXXXX" if self.password else None,
- self.extra_dejson))
@property
def extra_dejson(self) -> Dict:
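The _parse_from_uri hunk above is formatting-only, but the pattern it keeps (split the URI, then unquote login and password only when present) is easy to see in isolation. Here is a hedged, standard-library-only sketch of that splitting logic; split_conn_uri is a hypothetical helper, not the Connection API, and it omits the IPv6 handling that _parse_netloc_to_hostname provides.

from urllib.parse import parse_qsl, unquote, urlsplit


def split_conn_uri(uri: str) -> dict:
    parts = urlsplit(uri)
    quoted_schema = parts.path[1:]
    return {
        "conn_type": parts.scheme,
        "host": parts.hostname,
        "port": parts.port,
        # Mirrors the `unquote(x) if x else x` one-liners Black produces above.
        "schema": unquote(quoted_schema) if quoted_schema else quoted_schema,
        "login": unquote(parts.username) if parts.username else parts.username,
        "password": unquote(parts.password) if parts.password else parts.password,
        "extra": dict(parse_qsl(parts.query, keep_blank_values=True)),
    }


if __name__ == "__main__":
    print(split_conn_uri("postgresql://user%40corp:s3cret@db.example.com:5432/mydb?sslmode=require"))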
diff --git a/airflow/models/crypto.py b/airflow/models/crypto.py
index 8b55448055a91..d6e0ee8341ef8 100644
--- a/airflow/models/crypto.py
+++ b/airflow/models/crypto.py
@@ -79,15 +79,12 @@ def get_fernet():
try:
fernet_key = conf.get('core', 'FERNET_KEY')
if not fernet_key:
- log.warning(
- "empty cryptography key - values will not be stored encrypted."
- )
+ log.warning("empty cryptography key - values will not be stored encrypted.")
_fernet = NullFernet()
else:
- _fernet = MultiFernet([
- Fernet(fernet_part.encode('utf-8'))
- for fernet_part in fernet_key.split(',')
- ])
+ _fernet = MultiFernet(
+ [Fernet(fernet_part.encode('utf-8')) for fernet_part in fernet_key.split(',')]
+ )
_fernet.is_encrypted = True
except (ValueError, TypeError) as value_error:
raise AirflowException(f"Could not create Fernet object: {value_error}")
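The get_fernet() change above only reflows the MultiFernet construction, but the underlying key-rotation behaviour is worth spelling out: a comma-separated FERNET_KEY builds a MultiFernet that encrypts with the first key and can decrypt with any key in the list. A hedged, standalone sketch of that behaviour follows; it assumes the cryptography package and the variable names are illustrative only.

from cryptography.fernet import Fernet, MultiFernet

old_key = Fernet.generate_key()
new_key = Fernet.generate_key()

# A token written while only the old key was configured.
token = MultiFernet([Fernet(old_key)]).encrypt(b"my connection password")

# FERNET_KEY="<new_key>,<old_key>": encrypt with the first key, decrypt with any.
rotated = MultiFernet([Fernet(k) for k in (new_key, old_key)])
assert rotated.decrypt(token) == b"my connection password"

# rotate() re-encrypts the token under the newest (first) key, which is what a
# rotate-fernet-key pass over stored connections relies on.
fresh_token = rotated.rotate(token)
assert rotated.decrypt(fresh_token) == b"my connection password"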
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 166553d9db415..78ad2479f4e16 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -29,7 +29,18 @@
from datetime import datetime, timedelta
from inspect import signature
from typing import (
- TYPE_CHECKING, Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type, Union,
+ TYPE_CHECKING,
+ Callable,
+ Collection,
+ Dict,
+ FrozenSet,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ Union,
cast,
)
@@ -253,7 +264,7 @@ def __init__(
access_control: Optional[Dict] = None,
is_paused_upon_creation: Optional[bool] = None,
jinja_environment_kwargs: Optional[Dict] = None,
- tags: Optional[List[str]] = None
+ tags: Optional[List[str]] = None,
):
from airflow.utils.task_group import TaskGroup
@@ -286,9 +297,7 @@ def __init__(
self.timezone = start_date.tzinfo
elif 'start_date' in self.default_args and self.default_args['start_date']:
if isinstance(self.default_args['start_date'], str):
- self.default_args['start_date'] = (
- timezone.parse(self.default_args['start_date'])
- )
+ self.default_args['start_date'] = timezone.parse(self.default_args['start_date'])
self.timezone = self.default_args['start_date'].tzinfo
if not hasattr(self, 'timezone') or not self.timezone:
@@ -297,8 +306,8 @@ def __init__(
# Apply the timezone we settled on to end_date if it wasn't supplied
if 'end_date' in self.default_args and self.default_args['end_date']:
if isinstance(self.default_args['end_date'], str):
- self.default_args['end_date'] = (
- timezone.parse(self.default_args['end_date'], timezone=self.timezone)
+ self.default_args['end_date'] = timezone.parse(
+ self.default_args['end_date'], timezone=self.timezone
)
self.start_date = timezone.convert_to_utc(start_date)
@@ -306,13 +315,9 @@ def __init__(
# also convert tasks
if 'start_date' in self.default_args:
- self.default_args['start_date'] = (
- timezone.convert_to_utc(self.default_args['start_date'])
- )
+ self.default_args['start_date'] = timezone.convert_to_utc(self.default_args['start_date'])
if 'end_date' in self.default_args:
- self.default_args['end_date'] = (
- timezone.convert_to_utc(self.default_args['end_date'])
- )
+ self.default_args['end_date'] = timezone.convert_to_utc(self.default_args['end_date'])
self.schedule_interval = schedule_interval
if isinstance(template_searchpath, str):
@@ -328,13 +333,17 @@ def __init__(
if default_view in DEFAULT_VIEW_PRESETS:
self._default_view: str = default_view
else:
- raise AirflowException(f'Invalid values of dag.default_view: only support '
- f'{DEFAULT_VIEW_PRESETS}, but get {default_view}')
+ raise AirflowException(
+ f'Invalid values of dag.default_view: only support '
+ f'{DEFAULT_VIEW_PRESETS}, but get {default_view}'
+ )
if orientation in ORIENTATION_PRESETS:
self.orientation = orientation
else:
- raise AirflowException(f'Invalid values of dag.orientation: only support '
- f'{ORIENTATION_PRESETS}, but get {orientation}')
+ raise AirflowException(
+ f'Invalid values of dag.orientation: only support '
+ f'{ORIENTATION_PRESETS}, but get {orientation}'
+ )
self.catchup = catchup
self.is_subdag = False # DagBag.bag_dag() will set this to True if appropriate
@@ -354,8 +363,7 @@ def __repr__(self):
return f""
def __eq__(self, other):
- if (type(self) == type(other) and
- self.dag_id == other.dag_id):
+ if type(self) == type(other) and self.dag_id == other.dag_id:
# Use getattr() instead of __dict__ as __dict__ doesn't return
# correct values for properties.
@@ -414,7 +422,8 @@ def _upgrade_outdated_dag_access_control(access_control=None):
warnings.warn(
"The 'can_dag_read' and 'can_dag_edit' permissions are deprecated. "
"Please use 'can_read' and 'can_edit', respectively.",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
return updated_access_control
@@ -428,8 +437,8 @@ def date_range(
if num is not None:
end_date = None
return utils_date_range(
- start_date=start_date, end_date=end_date,
- num=num, delta=self.normalized_schedule_interval)
+ start_date=start_date, end_date=end_date, num=num, delta=self.normalized_schedule_interval
+ )
def is_fixed_time_schedule(self):
"""
@@ -516,8 +525,9 @@ def next_dagrun_info(
"automated" DagRuns for this dag (scheduled or backfill, but not
manual)
"""
- if (self.schedule_interval == "@once" and date_last_automated_dagrun) or \
- self.schedule_interval is None:
+ if (
+ self.schedule_interval == "@once" and date_last_automated_dagrun
+ ) or self.schedule_interval is None:
# Manual trigger, or already created the run for @once, can short circuit
return (None, None)
next_execution_date = self.next_dagrun_after_date(date_last_automated_dagrun)
@@ -588,10 +598,7 @@ def next_dagrun_after_date(self, date_last_automated_dagrun: Optional[pendulum.D
if next_run_date == self.start_date:
next_run_date = self.normalize_schedule(self.start_date)
- self.log.debug(
- "Dag start date: %s. Next run date: %s",
- self.start_date, next_run_date
- )
+ self.log.debug("Dag start date: %s. Next run date: %s", self.start_date, next_run_date)
# Don't schedule a dag beyond its end_date (as specified by the dag param)
if next_run_date and self.end_date and next_run_date > self.end_date:
@@ -630,8 +637,7 @@ def get_run_dates(self, start_date, end_date=None):
# next run date for a subdag isn't relevant (schedule_interval for subdags
# is ignored) so we use the dag run's start date in the case of a subdag
- next_run_date = (self.normalize_schedule(using_start_date)
- if not self.is_subdag else using_start_date)
+ next_run_date = self.normalize_schedule(using_start_date) if not self.is_subdag else using_start_date
while next_run_date and next_run_date <= using_end_date:
run_dates.append(next_run_date)
@@ -653,8 +659,9 @@ def normalize_schedule(self, dttm):
@provide_session
def get_last_dagrun(self, session=None, include_externally_triggered=False):
- return get_last_dagrun(self.dag_id, session=session,
- include_externally_triggered=include_externally_triggered)
+ return get_last_dagrun(
+ self.dag_id, session=session, include_externally_triggered=include_externally_triggered
+ )
@property
def dag_id(self) -> str:
@@ -720,8 +727,7 @@ def tasks(self) -> List[BaseOperator]:
@tasks.setter
def tasks(self, val):
- raise AttributeError(
- 'DAG.tasks can not be modified. Use dag.add_task() instead.')
+ raise AttributeError('DAG.tasks can not be modified. Use dag.add_task() instead.')
@property
def task_ids(self) -> List[str]:
@@ -783,8 +789,7 @@ def concurrency_reached(self):
@provide_session
def get_is_paused(self, session=None):
"""Returns a boolean indicating whether this DAG is paused"""
- qry = session.query(DagModel).filter(
- DagModel.dag_id == self.dag_id)
+ qry = session.query(DagModel).filter(DagModel.dag_id == self.dag_id)
return qry.value(DagModel.is_paused)
@property
@@ -870,10 +875,11 @@ def get_num_active_runs(self, external_trigger=None, session=None):
:return: number greater than 0 for active dag runs
"""
# .count() is inefficient
- query = (session
- .query(func.count())
- .filter(DagRun.dag_id == self.dag_id)
- .filter(DagRun.state == State.RUNNING))
+ query = (
+ session.query(func.count())
+ .filter(DagRun.dag_id == self.dag_id)
+ .filter(DagRun.state == State.RUNNING)
+ )
if external_trigger is not None:
query = query.filter(DagRun.external_trigger == external_trigger)
@@ -892,10 +898,9 @@ def get_dagrun(self, execution_date, session=None):
"""
dagrun = (
session.query(DagRun)
- .filter(
- DagRun.dag_id == self.dag_id,
- DagRun.execution_date == execution_date)
- .first())
+ .filter(DagRun.dag_id == self.dag_id, DagRun.execution_date == execution_date)
+ .first()
+ )
return dagrun
@@ -914,17 +919,17 @@ def get_dagruns_between(self, start_date, end_date, session=None):
.filter(
DagRun.dag_id == self.dag_id,
DagRun.execution_date >= start_date,
- DagRun.execution_date <= end_date)
- .all())
+ DagRun.execution_date <= end_date,
+ )
+ .all()
+ )
return dagruns
@provide_session
def get_latest_execution_date(self, session=None):
"""Returns the latest date for which at least one dag run exists"""
- return session.query(func.max(DagRun.execution_date)).filter(
- DagRun.dag_id == self.dag_id
- ).scalar()
+ return session.query(func.max(DagRun.execution_date)).filter(DagRun.dag_id == self.dag_id).scalar()
@property
def latest_execution_date(self):
@@ -941,12 +946,16 @@ def subdags(self):
"""Returns a list of the subdag objects associated to this DAG"""
# Check SubDag for class but don't check class directly
from airflow.operators.subdag_operator import SubDagOperator
+
subdag_lst = []
for task in self.tasks:
- if (isinstance(task, SubDagOperator) or
- # TODO remove in Airflow 2.0
- type(task).__name__ == 'SubDagOperator' or
- task.task_type == 'SubDagOperator'):
+ if (
+ isinstance(task, SubDagOperator)
+ or
+ # TODO remove in Airflow 2.0
+ type(task).__name__ == 'SubDagOperator'
+ or task.task_type == 'SubDagOperator'
+ ):
subdag_lst.append(task.subdag)
subdag_lst += task.subdag.subdags
return subdag_lst
@@ -967,7 +976,7 @@ def get_template_env(self) -> jinja2.Environment:
'loader': jinja2.FileSystemLoader(searchpath),
'undefined': self.template_undefined,
'extensions': ["jinja2.ext.do"],
- 'cache_size': 0
+ 'cache_size': 0,
}
if self.jinja_environment_kwargs:
jinja_env_options.update(self.jinja_environment_kwargs)
@@ -988,16 +997,13 @@ def set_dependency(self, upstream_task_id, downstream_task_id):
Simple utility method to set dependency between two tasks that
already have been added to the DAG using add_task()
"""
- self.get_task(upstream_task_id).set_downstream(
- self.get_task(downstream_task_id))
+ self.get_task(upstream_task_id).set_downstream(self.get_task(downstream_task_id))
@provide_session
- def get_task_instances(
- self, start_date=None, end_date=None, state=None, session=None):
+ def get_task_instances(self, start_date=None, end_date=None, state=None, session=None):
if not start_date:
start_date = (timezone.utcnow() - timedelta(30)).date()
- start_date = timezone.make_aware(
- datetime.combine(start_date, datetime.min.time()))
+ start_date = timezone.make_aware(datetime.combine(start_date, datetime.min.time()))
tis = session.query(TaskInstance).filter(
TaskInstance.dag_id == self.dag_id,
@@ -1020,8 +1026,7 @@ def get_task_instances(
else:
not_none_state = [s for s in state if s]
tis = tis.filter(
- or_(TaskInstance.state.in_(not_none_state),
- TaskInstance.state.is_(None))
+ or_(TaskInstance.state.in_(not_none_state), TaskInstance.state.is_(None))
)
else:
tis = tis.filter(TaskInstance.state.in_(state))
@@ -1088,18 +1093,17 @@ def topological_sort(self, include_subdag_tasks: bool = False):
graph_sorted.extend(node.subdag.topological_sort(include_subdag_tasks=True))
if not acyclic:
- raise AirflowException("A cyclic dependency occurred in dag: {}"
- .format(self.dag_id))
+ raise AirflowException(f"A cyclic dependency occurred in dag: {self.dag_id}")
return tuple(graph_sorted)
@provide_session
def set_dag_runs_state(
- self,
- state: str = State.RUNNING,
- session: Session = None,
- start_date: Optional[datetime] = None,
- end_date: Optional[datetime] = None,
+ self,
+ state: str = State.RUNNING,
+ session: Session = None,
+ start_date: Optional[datetime] = None,
+ end_date: Optional[datetime] = None,
) -> None:
query = session.query(DagRun).filter_by(dag_id=self.dag_id)
if start_date:
@@ -1110,20 +1114,22 @@ def set_dag_runs_state(
@provide_session
def clear(
- self, start_date=None, end_date=None,
- only_failed=False,
- only_running=False,
- confirm_prompt=False,
- include_subdags=True,
- include_parentdag=True,
- dag_run_state: str = State.RUNNING,
- dry_run=False,
- session=None,
- get_tis=False,
- recursion_depth=0,
- max_recursion_depth=None,
- dag_bag=None,
- visited_external_tis=None,
+ self,
+ start_date=None,
+ end_date=None,
+ only_failed=False,
+ only_running=False,
+ confirm_prompt=False,
+ include_subdags=True,
+ include_parentdag=True,
+ dag_run_state: str = State.RUNNING,
+ dry_run=False,
+ session=None,
+ get_tis=False,
+ recursion_depth=0,
+ max_recursion_depth=None,
+ dag_bag=None,
+ visited_external_tis=None,
):
"""
Clears a set of task instances associated with the current dag for
@@ -1169,10 +1175,7 @@ def clear(
# Crafting the right filter for dag_id and task_ids combo
conditions = []
for dag in self.subdags + [self]:
- conditions.append(
- (TI.dag_id == dag.dag_id) &
- TI.task_id.in_(dag.task_ids)
- )
+ conditions.append((TI.dag_id == dag.dag_id) & TI.task_id.in_(dag.task_ids))
tis = tis.filter(or_(*conditions))
else:
tis = session.query(TI).filter(TI.dag_id == self.dag_id)
@@ -1182,32 +1185,34 @@ def clear(
p_dag = self.parent_dag.sub_dag(
task_ids_or_regex=r"^{}$".format(self.dag_id.split('.')[1]),
include_upstream=False,
- include_downstream=True)
+ include_downstream=True,
+ )
- tis = tis.union(p_dag.clear(
- start_date=start_date, end_date=end_date,
- only_failed=only_failed,
- only_running=only_running,
- confirm_prompt=confirm_prompt,
- include_subdags=include_subdags,
- include_parentdag=False,
- dag_run_state=dag_run_state,
- get_tis=True,
- session=session,
- recursion_depth=recursion_depth,
- max_recursion_depth=max_recursion_depth,
- dag_bag=dag_bag,
- visited_external_tis=visited_external_tis
- ))
+ tis = tis.union(
+ p_dag.clear(
+ start_date=start_date,
+ end_date=end_date,
+ only_failed=only_failed,
+ only_running=only_running,
+ confirm_prompt=confirm_prompt,
+ include_subdags=include_subdags,
+ include_parentdag=False,
+ dag_run_state=dag_run_state,
+ get_tis=True,
+ session=session,
+ recursion_depth=recursion_depth,
+ max_recursion_depth=max_recursion_depth,
+ dag_bag=dag_bag,
+ visited_external_tis=visited_external_tis,
+ )
+ )
if start_date:
tis = tis.filter(TI.execution_date >= start_date)
if end_date:
tis = tis.filter(TI.execution_date <= end_date)
if only_failed:
- tis = tis.filter(or_(
- TI.state == State.FAILED,
- TI.state == State.UPSTREAM_FAILED))
+ tis = tis.filter(or_(TI.state == State.FAILED, TI.state == State.UPSTREAM_FAILED))
if only_running:
tis = tis.filter(TI.state == State.RUNNING)
@@ -1224,8 +1229,9 @@ def clear(
if ti_key not in visited_external_tis:
# Only clear this ExternalTaskMarker if it's not already visited by the
# recursive calls to dag.clear().
- task: ExternalTaskMarker = cast(ExternalTaskMarker,
- copy.copy(self.get_task(ti.task_id)))
+ task: ExternalTaskMarker = cast(
+ ExternalTaskMarker, copy.copy(self.get_task(ti.task_id))
+ )
ti.task = task
if recursion_depth == 0:
@@ -1235,16 +1241,19 @@ def clear(
if recursion_depth + 1 > max_recursion_depth:
# Prevent cycles or accidents.
- raise AirflowException("Maximum recursion depth {} reached for {} {}. "
- "Attempted to clear too many tasks "
- "or there may be a cyclic dependency."
- .format(max_recursion_depth,
- ExternalTaskMarker.__name__, ti.task_id))
+ raise AirflowException(
+ "Maximum recursion depth {} reached for {} {}. "
+ "Attempted to clear too many tasks "
+ "or there may be a cyclic dependency.".format(
+ max_recursion_depth, ExternalTaskMarker.__name__, ti.task_id
+ )
+ )
ti.render_templates()
- external_tis = session.query(TI).filter(TI.dag_id == task.external_dag_id,
- TI.task_id == task.external_task_id,
- TI.execution_date ==
- pendulum.parse(task.execution_date))
+ external_tis = session.query(TI).filter(
+ TI.dag_id == task.external_dag_id,
+ TI.task_id == task.external_task_id,
+ TI.execution_date == pendulum.parse(task.execution_date),
+ )
for tii in external_tis:
if not dag_bag:
@@ -1255,22 +1264,26 @@ def clear(
downstream = external_dag.sub_dag(
task_ids_or_regex=fr"^{tii.task_id}$",
include_upstream=False,
- include_downstream=True
+ include_downstream=True,
+ )
+ tis = tis.union(
+ downstream.clear(
+ start_date=tii.execution_date,
+ end_date=tii.execution_date,
+ only_failed=only_failed,
+ only_running=only_running,
+ confirm_prompt=confirm_prompt,
+ include_subdags=include_subdags,
+ include_parentdag=False,
+ dag_run_state=dag_run_state,
+ get_tis=True,
+ session=session,
+ recursion_depth=recursion_depth + 1,
+ max_recursion_depth=max_recursion_depth,
+ dag_bag=dag_bag,
+ visited_external_tis=visited_external_tis,
+ )
)
- tis = tis.union(downstream.clear(start_date=tii.execution_date,
- end_date=tii.execution_date,
- only_failed=only_failed,
- only_running=only_running,
- confirm_prompt=confirm_prompt,
- include_subdags=include_subdags,
- include_parentdag=False,
- dag_run_state=dag_run_state,
- get_tis=True,
- session=session,
- recursion_depth=recursion_depth + 1,
- max_recursion_depth=max_recursion_depth,
- dag_bag=dag_bag,
- visited_external_tis=visited_external_tis))
visited_external_tis.add(ti_key)
if get_tis:
@@ -1291,9 +1304,8 @@ def clear(
if confirm_prompt:
ti_list = "\n".join([str(t) for t in tis])
question = (
- "You are about to delete these {count} tasks:\n"
- "{ti_list}\n\n"
- "Are you sure? (yes/no): ").format(count=count, ti_list=ti_list)
+ "You are about to delete these {count} tasks:\n" "{ti_list}\n\n" "Are you sure? (yes/no): "
+ ).format(count=count, ti_list=ti_list)
do_it = utils.helpers.ask_yesno(question)
if do_it:
@@ -1318,16 +1330,17 @@ def clear(
@classmethod
def clear_dags(
- cls, dags,
- start_date=None,
- end_date=None,
- only_failed=False,
- only_running=False,
- confirm_prompt=False,
- include_subdags=True,
- include_parentdag=False,
- dag_run_state=State.RUNNING,
- dry_run=False,
+ cls,
+ dags,
+ start_date=None,
+ end_date=None,
+ only_failed=False,
+ only_running=False,
+ confirm_prompt=False,
+ include_subdags=True,
+ include_parentdag=False,
+ dag_run_state=State.RUNNING,
+ dry_run=False,
):
all_tis = []
for dag in dags:
@@ -1340,7 +1353,8 @@ def clear_dags(
include_subdags=include_subdags,
include_parentdag=include_parentdag,
dag_run_state=dag_run_state,
- dry_run=True)
+ dry_run=True,
+ )
all_tis.extend(tis)
if dry_run:
@@ -1354,22 +1368,22 @@ def clear_dags(
if confirm_prompt:
ti_list = "\n".join([str(t) for t in all_tis])
question = (
- "You are about to delete these {} tasks:\n"
- "{}\n\n"
- "Are you sure? (yes/no): ").format(count, ti_list)
+ "You are about to delete these {} tasks:\n" "{}\n\n" "Are you sure? (yes/no): "
+ ).format(count, ti_list)
do_it = utils.helpers.ask_yesno(question)
if do_it:
for dag in dags:
- dag.clear(start_date=start_date,
- end_date=end_date,
- only_failed=only_failed,
- only_running=only_running,
- confirm_prompt=False,
- include_subdags=include_subdags,
- dag_run_state=dag_run_state,
- dry_run=False,
- )
+ dag.clear(
+ start_date=start_date,
+ end_date=end_date,
+ only_failed=only_failed,
+ only_running=only_running,
+ confirm_prompt=False,
+ include_subdags=include_subdags,
+ dag_run_state=dag_run_state,
+ dry_run=False,
+ )
else:
count = 0
print("Cancelled, nothing was cleared.")
@@ -1432,8 +1446,7 @@ def partial_subset(
self._task_group = task_group
if isinstance(task_ids_or_regex, (str, PatternType)):
- matched_tasks = [
- t for t in self.tasks if re.findall(task_ids_or_regex, t.task_id)]
+ matched_tasks = [t for t in self.tasks if re.findall(task_ids_or_regex, t.task_id)]
else:
matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex]
@@ -1448,8 +1461,10 @@ def partial_subset(
# Compiling the unique list of tasks that made the cut
# Make sure to not recursively deepcopy the dag while copying the task
- dag.task_dict = {t.task_id: copy.deepcopy(t, {id(t.dag): dag}) # type: ignore
- for t in matched_tasks + also_include}
+ dag.task_dict = {
+ t.task_id: copy.deepcopy(t, {id(t.dag): dag}) # type: ignore
+ for t in matched_tasks + also_include
+ }
def filter_task_group(group, parent_group):
"""Exclude tasks not included in the subdag from the given TaskGroup."""
@@ -1487,8 +1502,7 @@ def filter_task_group(group, parent_group):
# Removing upstream/downstream references to tasks that did not
# make the cut
t._upstream_task_ids = t.upstream_task_ids.intersection(dag.task_dict.keys())
- t._downstream_task_ids = t.downstream_task_ids.intersection(
- dag.task_dict.keys())
+ t._downstream_task_ids = t.downstream_task_ids.intersection(dag.task_dict.keys())
if len(dag.tasks) < len(self.tasks):
dag.partial = True
@@ -1523,12 +1537,10 @@ def pickle_info(self):
@provide_session
def pickle(self, session=None) -> DagPickle:
- dag = session.query(
- DagModel).filter(DagModel.dag_id == self.dag_id).first()
+ dag = session.query(DagModel).filter(DagModel.dag_id == self.dag_id).first()
dp = None
if dag and dag.pickle_id:
- dp = session.query(DagPickle).filter(
- DagPickle.id == dag.pickle_id).first()
+ dp = session.query(DagPickle).filter(DagPickle.id == dag.pickle_id).first()
if not dp or dp.pickle != self:
dp = DagPickle(dag=self)
session.add(dp)
@@ -1540,6 +1552,7 @@ def pickle(self, session=None) -> DagPickle:
def tree_view(self) -> None:
"""Print an ASCII tree representation of the DAG."""
+
def get_downstream(task, level=0):
print((" " * level * 4) + str(task))
level += 1
@@ -1552,6 +1565,7 @@ def get_downstream(task, level=0):
@property
def task(self):
from airflow.operators.python import task
+
return functools.partial(task, dag=self)
def add_task(self, task):
@@ -1579,10 +1593,10 @@ def add_task(self, task):
elif task.end_date and self.end_date:
task.end_date = min(task.end_date, self.end_date)
- if ((task.task_id in self.task_dict and self.task_dict[task.task_id] is not task)
- or task.task_id in self._task_group.used_group_ids):
- raise DuplicateTaskIdFound(
- f"Task id '{task.task_id}' has already been added to the DAG")
+ if (
+ task.task_id in self.task_dict and self.task_dict[task.task_id] is not task
+ ) or task.task_id in self._task_group.used_group_ids:
+ raise DuplicateTaskIdFound(f"Task id '{task.task_id}' has already been added to the DAG")
else:
self.task_dict[task.task_id] = task
task.dag = self
@@ -1602,21 +1616,21 @@ def add_tasks(self, tasks):
self.add_task(task)
def run(
- self,
- start_date=None,
- end_date=None,
- mark_success=False,
- local=False,
- executor=None,
- donot_pickle=conf.getboolean('core', 'donot_pickle'),
- ignore_task_deps=False,
- ignore_first_depends_on_past=True,
- pool=None,
- delay_on_limit_secs=1.0,
- verbose=False,
- conf=None,
- rerun_failed_tasks=False,
- run_backwards=False,
+ self,
+ start_date=None,
+ end_date=None,
+ mark_success=False,
+ local=False,
+ executor=None,
+ donot_pickle=conf.getboolean('core', 'donot_pickle'),
+ ignore_task_deps=False,
+ ignore_first_depends_on_past=True,
+ pool=None,
+ delay_on_limit_secs=1.0,
+ verbose=False,
+ conf=None,
+ rerun_failed_tasks=False,
+ run_backwards=False,
):
"""
Runs the DAG.
@@ -1654,11 +1668,14 @@ def run(
"""
from airflow.jobs.backfill_job import BackfillJob
+
if not executor and local:
from airflow.executors.local_executor import LocalExecutor
+
executor = LocalExecutor()
elif not executor:
from airflow.executors.executor_loader import ExecutorLoader
+
executor = ExecutorLoader.get_default_executor()
job = BackfillJob(
self,
@@ -1681,6 +1698,7 @@ def run(
def cli(self):
"""Exposes a CLI specific to this DAG"""
from airflow.cli import cli_parser
+
parser = cli_parser.get_parser(dag_parser=True)
args = parser.parse_args()
args.func(args, self)
@@ -1747,7 +1765,7 @@ def create_dagrun(
state=state,
run_type=run_type,
dag_hash=dag_hash,
- creating_job_id=creating_job_id
+ creating_job_id=creating_job_id,
)
session.add(run)
session.flush()
@@ -1791,10 +1809,10 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None):
dag_by_ids = {dag.dag_id: dag for dag in dags}
dag_ids = set(dag_by_ids.keys())
query = (
- session
- .query(DagModel)
+ session.query(DagModel)
.options(joinedload(DagModel.tags, innerjoin=False))
- .filter(DagModel.dag_id.in_(dag_ids)))
+ .filter(DagModel.dag_id.in_(dag_ids))
+ )
orm_dags = with_row_locks(query, of=DagModel).all()
existing_dag_ids = {orm_dag.dag_id for orm_dag in orm_dags}
@@ -1811,21 +1829,31 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None):
orm_dags.append(orm_dag)
# Get the latest dag run for each existing dag as a single query (avoid n+1 query)
- most_recent_dag_runs = dict(session.query(DagRun.dag_id, func.max_(DagRun.execution_date)).filter(
- DagRun.dag_id.in_(existing_dag_ids),
- or_(
- DagRun.run_type == DagRunType.BACKFILL_JOB,
- DagRun.run_type == DagRunType.SCHEDULED,
- DagRun.external_trigger.is_(True),
- ),
- ).group_by(DagRun.dag_id).all())
+ most_recent_dag_runs = dict(
+ session.query(DagRun.dag_id, func.max_(DagRun.execution_date))
+ .filter(
+ DagRun.dag_id.in_(existing_dag_ids),
+ or_(
+ DagRun.run_type == DagRunType.BACKFILL_JOB,
+ DagRun.run_type == DagRunType.SCHEDULED,
+ DagRun.external_trigger.is_(True),
+ ),
+ )
+ .group_by(DagRun.dag_id)
+ .all()
+ )
# Get number of active dagruns for all dags we are processing as a single query.
- num_active_runs = dict(session.query(DagRun.dag_id, func.count('*')).filter(
- DagRun.dag_id.in_(existing_dag_ids),
- DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable
- DagRun.external_trigger.is_(False)
- ).group_by(DagRun.dag_id).all())
+ num_active_runs = dict(
+ session.query(DagRun.dag_id, func.count('*'))
+ .filter(
+ DagRun.dag_id.in_(existing_dag_ids),
+ DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable
+ DagRun.external_trigger.is_(False),
+ )
+ .group_by(DagRun.dag_id)
+ .all()
+ )
for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id):
dag = dag_by_ids[orm_dag.dag_id]
@@ -1843,9 +1871,7 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None):
orm_dag.description = dag.description
orm_dag.schedule_interval = dag.schedule_interval
orm_dag.concurrency = dag.concurrency
- orm_dag.has_task_concurrency_limits = any(
- t.task_concurrency is not None for t in dag.tasks
- )
+ orm_dag.has_task_concurrency_limits = any(t.task_concurrency is not None for t in dag.tasks)
orm_dag.calculate_dagrun_date_fields(
dag,
@@ -1906,8 +1932,7 @@ def deactivate_unknown_dags(active_dag_ids, session=None):
"""
if len(active_dag_ids) == 0:
return
- for dag in session.query(
- DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all():
+ for dag in session.query(DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all():
dag.is_active = False
session.merge(dag)
session.commit()
@@ -1924,12 +1949,15 @@ def deactivate_stale_dags(expiration_date, session=None):
:type expiration_date: datetime
:return: None
"""
- for dag in session.query(
- DagModel).filter(DagModel.last_scheduler_run < expiration_date,
- DagModel.is_active).all():
+ for dag in (
+ session.query(DagModel)
+ .filter(DagModel.last_scheduler_run < expiration_date, DagModel.is_active)
+ .all()
+ ):
log.info(
"Deactivating DAG ID %s since it was last touched by the scheduler at %s",
- dag.dag_id, dag.last_scheduler_run.isoformat()
+ dag.dag_id,
+ dag.last_scheduler_run.isoformat(),
)
dag.is_active = False
session.merge(dag)
@@ -1965,9 +1993,9 @@ def get_num_task_instances(dag_id, task_ids=None, states=None, session=None):
qry = qry.filter(TaskInstance.state.is_(None))
else:
not_none_states = [state for state in states if state]
- qry = qry.filter(or_(
- TaskInstance.state.in_(not_none_states),
- TaskInstance.state.is_(None)))
+ qry = qry.filter(
+ or_(TaskInstance.state.in_(not_none_states), TaskInstance.state.is_(None))
+ )
else:
qry = qry.filter(TaskInstance.state.in_(states))
return qry.scalar()
@@ -1977,12 +2005,25 @@ def get_serialized_fields(cls):
"""Stringified DAGs and operators contain exactly these fields."""
if not cls.__serialized_fields:
cls.__serialized_fields = frozenset(vars(DAG(dag_id='test')).keys()) - {
- 'parent_dag', '_old_context_manager_dags', 'safe_dag_id', 'last_loaded',
- '_full_filepath', 'user_defined_filters', 'user_defined_macros',
- 'partial', '_old_context_manager_dags',
- '_pickle_id', '_log', 'is_subdag', 'task_dict', 'template_searchpath',
- 'sla_miss_callback', 'on_success_callback', 'on_failure_callback',
- 'template_undefined', 'jinja_environment_kwargs'
+ 'parent_dag',
+ '_old_context_manager_dags',
+ 'safe_dag_id',
+ 'last_loaded',
+ '_full_filepath',
+ 'user_defined_filters',
+ 'user_defined_macros',
+ 'partial',
+ '_old_context_manager_dags',
+ '_pickle_id',
+ '_log',
+ 'is_subdag',
+ 'task_dict',
+ 'template_searchpath',
+ 'sla_miss_callback',
+ 'on_success_callback',
+ 'on_failure_callback',
+ 'template_undefined',
+ 'jinja_environment_kwargs',
}
return cls.__serialized_fields
@@ -2009,9 +2050,7 @@ class DagModel(Base):
root_dag_id = Column(String(ID_LEN))
# A DAG can be paused from the UI / DB
# Set this default value of is_paused based on a configuration value!
- is_paused_at_creation = conf\
- .getboolean('core',
- 'dags_are_paused_at_creation')
+ is_paused_at_creation = conf.getboolean('core', 'dags_are_paused_at_creation')
is_paused = Column(Boolean, default=is_paused_at_creation)
# Whether the DAG is a subdag
is_subdag = Column(Boolean, default=False)
@@ -2058,11 +2097,7 @@ class DagModel(Base):
Index('idx_next_dagrun_create_after', next_dagrun_create_after, unique=False),
)
- NUM_DAGS_PER_DAGRUN_QUERY = conf.getint(
- 'scheduler',
- 'max_dagruns_to_create_per_loop',
- fallback=10
- )
+ NUM_DAGS_PER_DAGRUN_QUERY = conf.getint('scheduler', 'max_dagruns_to_create_per_loop', fallback=10)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -2091,8 +2126,9 @@ def get_current(cls, dag_id, session=None):
@provide_session
def get_last_dagrun(self, session=None, include_externally_triggered=False):
- return get_last_dagrun(self.dag_id, session=session,
- include_externally_triggered=include_externally_triggered)
+ return get_last_dagrun(
+ self.dag_id, session=session, include_externally_triggered=include_externally_triggered
+ )
@staticmethod
@provide_session
@@ -2127,10 +2163,7 @@ def safe_dag_id(self):
return self.dag_id.replace('.', '__dot__')
@provide_session
- def set_is_paused(self,
- is_paused: bool,
- including_subdags: bool = True,
- session=None) -> None:
+ def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session=None) -> None:
"""
Pause/Un-pause a DAG.
@@ -2142,12 +2175,8 @@ def set_is_paused(self,
DagModel.dag_id == self.dag_id,
]
if including_subdags:
- filter_query.append(
- DagModel.root_dag_id == self.dag_id
- )
- session.query(DagModel).filter(or_(
- *filter_query
- )).update(
+ filter_query.append(DagModel.root_dag_id == self.dag_id)
+ session.query(DagModel).filter(or_(*filter_query)).update(
{DagModel.is_paused: is_paused}, synchronize_session='fetch'
)
session.commit()
@@ -2162,8 +2191,7 @@ def deactivate_deleted_dags(cls, alive_dag_filelocs: List[str], session=None):
:param alive_dag_filelocs: file paths of alive DAGs
:param session: ORM Session
"""
- log.debug("Deactivating DAGs (for which DAG files are deleted) from %s table ",
- cls.__tablename__)
+ log.debug("Deactivating DAGs (for which DAG files are deleted) from %s table ", cls.__tablename__)
dag_models = session.query(cls).all()
try:
@@ -2195,21 +2223,21 @@ def dags_needing_dagruns(cls, session: Session):
# TODO[HA]: Bake this query, it is run _A lot_
# We limit so that _one_ scheduler doesn't try to do all the creation
# of dag runs
- query = session.query(cls).filter(
- cls.is_paused.is_(False),
- cls.is_active.is_(True),
- cls.next_dagrun_create_after <= func.now(),
- ).order_by(
- cls.next_dagrun_create_after
- ).limit(cls.NUM_DAGS_PER_DAGRUN_QUERY)
+ query = (
+ session.query(cls)
+ .filter(
+ cls.is_paused.is_(False),
+ cls.is_active.is_(True),
+ cls.next_dagrun_create_after <= func.now(),
+ )
+ .order_by(cls.next_dagrun_create_after)
+ .limit(cls.NUM_DAGS_PER_DAGRUN_QUERY)
+ )
return with_row_locks(query, of=cls, **skip_locked(session=session))
def calculate_dagrun_date_fields(
- self,
- dag: DAG,
- most_recent_dag_run: Optional[pendulum.DateTime],
- active_runs_of_dag: int
+ self, dag: DAG, most_recent_dag_run: Optional[pendulum.DateTime], active_runs_of_dag: int
) -> None:
"""
Calculate ``next_dagrun`` and `next_dagrun_create_after``
@@ -2224,7 +2252,9 @@ def calculate_dagrun_date_fields(
# Since this happens every time the dag is parsed it would be quite spammy at info
log.debug(
"DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs",
- dag.dag_id, active_runs_of_dag, dag.max_active_runs
+ dag.dag_id,
+ active_runs_of_dag,
+ dag.max_active_runs,
)
self.next_dagrun_create_after = None
@@ -2241,6 +2271,7 @@ def dag(*dag_args, **dag_kwargs):
:param dag_kwargs: Kwargs for DAG object.
:type dag_kwargs: dict
"""
+
def wrapper(f: Callable):
# Get dag initializer signature and bind it to validate that dag_args, and dag_kwargs are correct
dag_sig = signature(DAG.__init__)
@@ -2276,7 +2307,9 @@ def factory(*args, **kwargs):
# Return dag object such that it's accessible in Globals.
return dag_obj
+
return factory
+
return wrapper
@@ -2287,6 +2320,7 @@ def factory(*args, **kwargs):
from sqlalchemy.orm import relationship
from airflow.models.serialized_dag import SerializedDagModel
+
DagModel.serialized_dag = relationship(SerializedDagModel)
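The dag decorator reformatted above follows a validate-then-factory pattern: bind the decorator arguments against the target __init__ signature at definition time, then return a factory that builds the object when the wrapped function is called. Below is a generic, self-contained sketch of that pattern under assumed names -- Pipeline and pipeline are hypothetical stand-ins, not Airflow classes, and the real decorator does more (context management, return-value handling).

import functools
from inspect import signature


class Pipeline:
    """Stand-in for the object the factory builds."""

    def __init__(self, pipeline_id: str, schedule: str = "@daily") -> None:
        self.pipeline_id = pipeline_id
        self.schedule = schedule


def pipeline(*dec_args, **dec_kwargs):
    def wrapper(func):
        # Fail fast: bad decorator kwargs raise here, at definition time.
        sig = signature(Pipeline.__init__)
        sig.bind_partial(None, *dec_args, **dec_kwargs)

        @functools.wraps(func)
        def factory(*args, **kwargs):
            obj = Pipeline(*dec_args, **dec_kwargs)
            func(*args, **kwargs)  # the body would register work on `obj`
            return obj

        return factory

    return wrapper


@pipeline("example", schedule="@hourly")
def build():
    pass


print(build().schedule)  # -> @hourly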
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index 7f3bab2e9e587..004d6f7c22551 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -97,13 +97,16 @@ def __init__(
):
# Avoid circular import
from airflow.models.dag import DAG
+
super().__init__()
if store_serialized_dags:
warnings.warn(
"The store_serialized_dags parameter has been deprecated. "
"You should pass the read_dags_from_db parameter.",
- DeprecationWarning, stacklevel=2)
+ DeprecationWarning,
+ stacklevel=2,
+ )
read_dags_from_db = store_serialized_dags
dag_folder = dag_folder or settings.DAGS_FOLDER
@@ -125,7 +128,8 @@ def __init__(
dag_folder=dag_folder,
include_examples=include_examples,
include_smart_sensor=include_smart_sensor,
- safe_mode=safe_mode)
+ safe_mode=safe_mode,
+ )
def size(self) -> int:
""":return: the amount of dags contained in this dagbag"""
@@ -135,8 +139,9 @@ def size(self) -> int:
def store_serialized_dags(self) -> bool:
"""Whether or not to read dags from DB"""
warnings.warn(
- "The store_serialized_dags property has been deprecated. "
- "Use read_dags_from_db instead.", DeprecationWarning, stacklevel=2
+ "The store_serialized_dags property has been deprecated. " "Use read_dags_from_db instead.",
+ DeprecationWarning,
+ stacklevel=2,
)
return self.read_dags_from_db
@@ -158,6 +163,7 @@ def get_dag(self, dag_id, session: Session = None):
if self.read_dags_from_db:
# Import here so that serialized dag is only imported when serialization is enabled
from airflow.models.serialized_dag import SerializedDagModel
+
if dag_id not in self.dags:
# Load from DB if not (yet) in the bag
self._add_dag_from_db(dag_id=dag_id, session=session)
@@ -169,8 +175,8 @@ def get_dag(self, dag_id, session: Session = None):
# 3. if (2) is yes, fetch the Serialized DAG.
min_serialized_dag_fetch_secs = timedelta(seconds=settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL)
if (
- dag_id in self.dags_last_fetched and
- timezone.utcnow() > self.dags_last_fetched[dag_id] + min_serialized_dag_fetch_secs
+ dag_id in self.dags_last_fetched
+ and timezone.utcnow() > self.dags_last_fetched[dag_id] + min_serialized_dag_fetch_secs
):
sd_last_updated_datetime = SerializedDagModel.get_last_updated_datetime(
dag_id=dag_id,
@@ -196,11 +202,12 @@ def get_dag(self, dag_id, session: Session = None):
# If the dag corresponding to root_dag_id is absent or expired
is_missing = root_dag_id not in self.dags
- is_expired = (orm_dag.last_expired and dag and dag.last_loaded < orm_dag.last_expired)
+ is_expired = orm_dag.last_expired and dag and dag.last_loaded < orm_dag.last_expired
if is_missing or is_expired:
# Reprocess source file
found_dags = self.process_file(
- filepath=correct_maybe_zipped(orm_dag.fileloc), only_if_updated=False)
+ filepath=correct_maybe_zipped(orm_dag.fileloc), only_if_updated=False
+ )
# If the source file no longer exports `dag_id`, delete it from self.dags
if found_dags and dag_id in [found_dag.dag_id for found_dag in found_dags]:
@@ -212,6 +219,7 @@ def get_dag(self, dag_id, session: Session = None):
def _add_dag_from_db(self, dag_id: str, session: Session):
"""Add DAG to DagBag from DB"""
from airflow.models.serialized_dag import SerializedDagModel
+
row = SerializedDagModel.get(dag_id, session)
if not row:
raise SerializedDagNotFound(f"DAG '{dag_id}' not found in serialized_dag table")
@@ -240,9 +248,11 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
# This failed before in what may have been a git sync
# race condition
file_last_changed_on_disk = datetime.fromtimestamp(os.path.getmtime(filepath))
- if only_if_updated \
- and filepath in self.file_last_changed \
- and file_last_changed_on_disk == self.file_last_changed[filepath]:
+ if (
+ only_if_updated
+ and filepath in self.file_last_changed
+ and file_last_changed_on_disk == self.file_last_changed[filepath]
+ ):
return []
except Exception as e: # pylint: disable=broad-except
self.log.exception(e)
@@ -315,8 +325,7 @@ def _load_modules_from_zip(self, filepath, safe_mode):
if not self.has_logged or True:
self.has_logged = True
self.log.info(
- "File %s:%s assumed to contain no DAGs. Skipping.",
- filepath, zip_info.filename
+ "File %s:%s assumed to contain no DAGs. Skipping.", filepath, zip_info.filename
)
continue
@@ -341,12 +350,7 @@ def _process_modules(self, filepath, mods, file_last_changed_on_disk):
from airflow.models.dag import DAG # Avoid circular import
is_zipfile = zipfile.is_zipfile(filepath)
- top_level_dags = [
- o
- for m in mods
- for o in list(m.__dict__.values())
- if isinstance(o, DAG)
- ]
+ top_level_dags = [o for m in mods for o in list(m.__dict__.values()) if isinstance(o, DAG)]
found_dags = []
@@ -362,15 +366,11 @@ def _process_modules(self, filepath, mods, file_last_changed_on_disk):
self.bag_dag(dag=dag, root_dag=dag)
found_dags.append(dag)
found_dags += dag.subdags
- except (CroniterBadCronError,
- CroniterBadDateError,
- CroniterNotAlphaError) as cron_e:
+ except (CroniterBadCronError, CroniterBadDateError, CroniterNotAlphaError) as cron_e:
self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
self.import_errors[dag.full_filepath] = f"Invalid Cron expression: {cron_e}"
- self.file_last_changed[dag.full_filepath] = \
- file_last_changed_on_disk
- except (AirflowDagCycleException,
- AirflowClusterPolicyViolation) as exception:
+ self.file_last_changed[dag.full_filepath] = file_last_changed_on_disk
+ except (AirflowDagCycleException, AirflowClusterPolicyViolation) as exception:
self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
self.import_errors[dag.full_filepath] = str(exception)
self.file_last_changed[dag.full_filepath] = file_last_changed_on_disk
@@ -412,12 +412,13 @@ def bag_dag(self, dag, root_dag):
raise cycle_exception
def collect_dags(
- self,
- dag_folder=None,
- only_if_updated=True,
- include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'),
- include_smart_sensor=conf.getboolean('smart_sensor', 'USE_SMART_SENSOR'),
- safe_mode=conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE')):
+ self,
+ dag_folder=None,
+ only_if_updated=True,
+ include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'),
+ include_smart_sensor=conf.getboolean('smart_sensor', 'USE_SMART_SENSOR'),
+ safe_mode=conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
+ ):
"""
Given a file path or a folder, this method looks for python modules,
imports them and adds them to the dagbag collection.
@@ -440,26 +441,26 @@ def collect_dags(
stats = []
dag_folder = correct_maybe_zipped(dag_folder)
- for filepath in list_py_file_paths(dag_folder,
- safe_mode=safe_mode,
- include_examples=include_examples,
- include_smart_sensor=include_smart_sensor):
+ for filepath in list_py_file_paths(
+ dag_folder,
+ safe_mode=safe_mode,
+ include_examples=include_examples,
+ include_smart_sensor=include_smart_sensor,
+ ):
try:
file_parse_start_dttm = timezone.utcnow()
- found_dags = self.process_file(
- filepath,
- only_if_updated=only_if_updated,
- safe_mode=safe_mode
- )
+ found_dags = self.process_file(filepath, only_if_updated=only_if_updated, safe_mode=safe_mode)
file_parse_end_dttm = timezone.utcnow()
- stats.append(FileLoadStat(
- file=filepath.replace(settings.DAGS_FOLDER, ''),
- duration=file_parse_end_dttm - file_parse_start_dttm,
- dag_num=len(found_dags),
- task_num=sum([len(dag.tasks) for dag in found_dags]),
- dags=str([dag.dag_id for dag in found_dags]),
- ))
+ stats.append(
+ FileLoadStat(
+ file=filepath.replace(settings.DAGS_FOLDER, ''),
+ duration=file_parse_end_dttm - file_parse_start_dttm,
+ dag_num=len(found_dags),
+ task_num=sum([len(dag.tasks) for dag in found_dags]),
+ dags=str([dag.dag_id for dag in found_dags]),
+ )
+ )
except Exception as e: # pylint: disable=broad-except
self.log.exception(e)
@@ -468,19 +469,17 @@ def collect_dags(
Stats.gauge('collect_dags', durations, 1)
Stats.gauge('dagbag_size', len(self.dags), 1)
Stats.gauge('dagbag_import_errors', len(self.import_errors), 1)
- self.dagbag_stats = sorted(
- stats, key=lambda x: x.duration, reverse=True)
+ self.dagbag_stats = sorted(stats, key=lambda x: x.duration, reverse=True)
for file_stat in self.dagbag_stats:
# file_stat.file similar format: /subdir/dag_name.py
# TODO: Remove for Airflow 2.0
filename = file_stat.file.split('/')[-1].replace('.py', '')
- Stats.timing('dag.loading-duration.{}'.
- format(filename),
- file_stat.duration)
+ Stats.timing(f'dag.loading-duration.{filename}', file_stat.duration)
def collect_dags_from_db(self):
"""Collects DAGs from database."""
from airflow.models.serialized_dag import SerializedDagModel
+
start_dttm = timezone.utcnow()
self.log.info("Filling up the DagBag from database")
@@ -508,7 +507,8 @@ def dagbag_report(self):
task_num = sum([o.task_num for o in stats])
table = tabulate(stats, headers="keys")
- report = textwrap.dedent(f"""\n
+ report = textwrap.dedent(
+ f"""\n
-------------------------------------------------------------------
DagBag loading stats for {dag_folder}
-------------------------------------------------------------------
@@ -516,7 +516,8 @@ def dagbag_report(self):
Total task number: {task_num}
DagBag parsing time: {duration}
{table}
- """)
+ """
+ )
return report
@provide_session
@@ -534,13 +535,13 @@ def sync_to_db(self, session: Optional[Session] = None):
wait=tenacity.wait_random_exponential(multiplier=0.5, max=5),
stop=tenacity.stop_after_attempt(settings.MAX_DB_RETRIES),
before_sleep=tenacity.before_sleep_log(self.log, logging.DEBUG),
- reraise=True
+ reraise=True,
):
with attempt:
self.log.debug(
"Running dagbag.sync_to_db with retries. Try %d of %d",
attempt.retry_state.attempt_number,
- settings.MAX_DB_RETRIES
+ settings.MAX_DB_RETRIES,
)
self.log.debug("Calling the DAG.bulk_sync_to_db method")
try:
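
The dagbag.py hunks above illustrate Black's "magic trailing comma": once a call no longer fits on one line, every argument is placed on its own line, a comma is added after the last one, and the closing parenthesis is dedented. The trailing comma also pins that exploded layout in place on future runs. A minimal, Airflow-free sketch of the before/after shape (the function name `collect` is a hypothetical stand-in, not an Airflow API):

    def collect(dag_folder=None, include_examples=True, safe_mode=True):
        """Hypothetical stand-in for a call such as DagBag.collect_dags."""
        return dag_folder, include_examples, safe_mode

    # Hand-wrapped style replaced by this patch:
    #   collect(dag_folder="/tmp/dags", include_examples=False,
    #           safe_mode=True)
    # Black's layout: one argument per line plus a trailing comma.
    result = collect(
        dag_folder="/tmp/dags",
        include_examples=False,
        safe_mode=True,
    )
    print(result)
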
diff --git a/airflow/models/dagcode.py b/airflow/models/dagcode.py
index 5194c3dc7e5b1..67e12b7e2a196 100644
--- a/airflow/models/dagcode.py
+++ b/airflow/models/dagcode.py
@@ -46,8 +46,7 @@ class DagCode(Base):
__tablename__ = 'dag_code'
- fileloc_hash = Column(
- BigInteger, nullable=False, primary_key=True, autoincrement=False)
+ fileloc_hash = Column(BigInteger, nullable=False, primary_key=True, autoincrement=False)
fileloc = Column(String(2000), nullable=False)
# The max length of fileloc exceeds the limit of indexing.
last_updated = Column(UtcDateTime, nullable=False)
@@ -76,12 +75,9 @@ def bulk_sync_to_db(cls, filelocs: Iterable[str], session=None):
:param session: ORM Session
"""
filelocs = set(filelocs)
- filelocs_to_hashes = {
- fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs
- }
+ filelocs_to_hashes = {fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs}
existing_orm_dag_codes = (
- session
- .query(DagCode)
+ session.query(DagCode)
.filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values()))
.with_for_update(of=DagCode)
.all()
@@ -94,29 +90,20 @@ def bulk_sync_to_db(cls, filelocs: Iterable[str], session=None):
else:
existing_orm_dag_codes_map = {}
- existing_orm_dag_codes_by_fileloc_hashes = {
- orm.fileloc_hash: orm for orm in existing_orm_dag_codes
- }
- existing_orm_filelocs = {
- orm.fileloc for orm in existing_orm_dag_codes_by_fileloc_hashes.values()
- }
+ existing_orm_dag_codes_by_fileloc_hashes = {orm.fileloc_hash: orm for orm in existing_orm_dag_codes}
+ existing_orm_filelocs = {orm.fileloc for orm in existing_orm_dag_codes_by_fileloc_hashes.values()}
if not existing_orm_filelocs.issubset(filelocs):
conflicting_filelocs = existing_orm_filelocs.difference(filelocs)
- hashes_to_filelocs = {
- DagCode.dag_fileloc_hash(fileloc): fileloc for fileloc in filelocs
- }
+ hashes_to_filelocs = {DagCode.dag_fileloc_hash(fileloc): fileloc for fileloc in filelocs}
message = ""
for fileloc in conflicting_filelocs:
- message += ("Filename '{}' causes a hash collision in the " +
- "database with '{}'. Please rename the file.")\
- .format(
- hashes_to_filelocs[DagCode.dag_fileloc_hash(fileloc)],
- fileloc)
+ message += (
+ "Filename '{}' causes a hash collision in the "
+ + "database with '{}'. Please rename the file."
+ ).format(hashes_to_filelocs[DagCode.dag_fileloc_hash(fileloc)], fileloc)
raise AirflowException(message)
- existing_filelocs = {
- dag_code.fileloc for dag_code in existing_orm_dag_codes
- }
+ existing_filelocs = {dag_code.fileloc for dag_code in existing_orm_dag_codes}
missing_filelocs = filelocs.difference(existing_filelocs)
for fileloc in missing_filelocs:
@@ -143,14 +130,13 @@ def remove_deleted_code(cls, alive_dag_filelocs: List[str], session=None):
:param alive_dag_filelocs: file paths of alive DAGs
:param session: ORM Session
"""
- alive_fileloc_hashes = [
- cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]
+ alive_fileloc_hashes = [cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]
log.debug("Deleting code from %s table ", cls.__tablename__)
session.query(cls).filter(
- cls.fileloc_hash.notin_(alive_fileloc_hashes),
- cls.fileloc.notin_(alive_dag_filelocs)).delete(synchronize_session='fetch')
+ cls.fileloc_hash.notin_(alive_fileloc_hashes), cls.fileloc.notin_(alive_dag_filelocs)
+ ).delete(synchronize_session='fetch')
@classmethod
@provide_session
@@ -161,8 +147,7 @@ def has_dag(cls, fileloc: str, session=None) -> bool:
:param session: ORM Session
"""
fileloc_hash = cls.dag_fileloc_hash(fileloc)
- return session.query(exists().where(cls.fileloc_hash == fileloc_hash))\
- .scalar()
+ return session.query(exists().where(cls.fileloc_hash == fileloc_hash)).scalar()
@classmethod
def get_code_by_fileloc(cls, fileloc: str) -> str:
@@ -193,9 +178,7 @@ def _get_code_from_file(fileloc):
@classmethod
@provide_session
def _get_code_from_db(cls, fileloc, session=None):
- dag_code = session.query(cls) \
- .filter(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)) \
- .first()
+ dag_code = session.query(cls).filter(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)).first()
if not dag_code:
raise DagCodeNotFound()
else:
@@ -214,5 +197,4 @@ def dag_fileloc_hash(full_filepath: str) -> int:
import hashlib
# Only 7 bytes because MySQL BigInteger can hold only 8 bytes (signed).
- return struct.unpack('>Q', hashlib.sha1(
- full_filepath.encode('utf-8')).digest()[-8:])[0] >> 8
+ return struct.unpack('>Q', hashlib.sha1(full_filepath.encode('utf-8')).digest()[-8:])[0] >> 8
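
The `dag_fileloc_hash` expression that Black collapses onto one line packs a SHA-1 digest into a MySQL-friendly integer: the last 8 bytes of the digest are read as an unsigned big-endian value and shifted right by 8 bits, leaving at most 56 bits (7 bytes), which always fits in a signed BigInteger column. A standalone sketch of the same computation:

    import hashlib
    import struct


    def fileloc_hash(full_filepath: str) -> int:
        # Unsigned 64-bit value from the last 8 digest bytes, minus the low byte.
        digest = hashlib.sha1(full_filepath.encode('utf-8')).digest()
        return struct.unpack('>Q', digest[-8:])[0] >> 8


    value = fileloc_hash("/opt/airflow/dags/example_dag.py")
    print(value, value < 2 ** 63)  # the second value is always True
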
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 6ae17d143a6b7..ad03bad1033c2 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -19,7 +19,17 @@
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, Union
from sqlalchemy import (
- Boolean, Column, DateTime, Index, Integer, PickleType, String, UniqueConstraint, and_, func, or_,
+ Boolean,
+ Column,
+ DateTime,
+ Index,
+ Integer,
+ PickleType,
+ String,
+ UniqueConstraint,
+ and_,
+ func,
+ or_,
)
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.declarative import declared_attr
@@ -125,13 +135,13 @@ def __init__(
def __repr__(self):
return (
- '<DagRun {dag_id} @ {execution_date}: {run_id}, '
- 'externally triggered: {external_trigger}>'
+ '<DagRun {dag_id} @ {execution_date}: {run_id}, ' 'externally triggered: {external_trigger}>'
).format(
dag_id=self.dag_id,
execution_date=self.execution_date,
run_id=self.run_id,
- external_trigger=self.external_trigger)
+ external_trigger=self.external_trigger,
+ )
def get_state(self):
return self._state
@@ -157,11 +167,15 @@ def refresh_from_db(self, session: Session = None):
exec_date = func.cast(self.execution_date, DateTime)
- dr = session.query(DR).filter(
- DR.dag_id == self.dag_id,
- func.cast(DR.execution_date, DateTime) == exec_date,
- DR.run_id == self.run_id
- ).one()
+ dr = (
+ session.query(DR)
+ .filter(
+ DR.dag_id == self.dag_id,
+ func.cast(DR.execution_date, DateTime) == exec_date,
+ DR.run_id == self.run_id,
+ )
+ .one()
+ )
self.id = dr.id
self.state = dr.state
@@ -187,18 +201,21 @@ def next_dagruns_to_examine(
max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE
# TODO: Bake this query, it is run _A lot_
- query = session.query(cls).filter(
- cls.state == State.RUNNING,
- cls.run_type != DagRunType.BACKFILL_JOB
- ).join(
- DagModel,
- DagModel.dag_id == cls.dag_id,
- ).filter(
- DagModel.is_paused.is_(False),
- DagModel.is_active.is_(True),
- ).order_by(
- nulls_first(cls.last_scheduling_decision, session=session),
- cls.execution_date,
+ query = (
+ session.query(cls)
+ .filter(cls.state == State.RUNNING, cls.run_type != DagRunType.BACKFILL_JOB)
+ .join(
+ DagModel,
+ DagModel.dag_id == cls.dag_id,
+ )
+ .filter(
+ DagModel.is_paused.is_(False),
+ DagModel.is_active.is_(True),
+ )
+ .order_by(
+ nulls_first(cls.last_scheduling_decision, session=session),
+ cls.execution_date,
+ )
)
if not settings.ALLOW_FUTURE_EXEC_DATES:
@@ -218,7 +235,7 @@ def find(
run_type: Optional[DagRunType] = None,
session: Session = None,
execution_start_date: Optional[datetime] = None,
- execution_end_date: Optional[datetime] = None
+ execution_end_date: Optional[datetime] = None,
) -> List["DagRun"]:
"""
Returns a set of dag runs for the given search criteria.
@@ -300,10 +317,7 @@ def get_task_instances(self, state=None, session=None):
tis = tis.filter(TI.state.is_(None))
else:
not_none_state = [s for s in state if s]
- tis = tis.filter(
- or_(TI.state.in_(not_none_state),
- TI.state.is_(None))
- )
+ tis = tis.filter(or_(TI.state.in_(not_none_state), TI.state.is_(None)))
else:
tis = tis.filter(TI.state.in_(state))
@@ -321,11 +335,11 @@ def get_task_instance(self, task_id: str, session: Session = None):
:param session: Sqlalchemy ORM Session
:type session: Session
"""
- ti = session.query(TI).filter(
- TI.dag_id == self.dag_id,
- TI.execution_date == self.execution_date,
- TI.task_id == task_id
- ).first()
+ ti = (
+ session.query(TI)
+ .filter(TI.dag_id == self.dag_id, TI.execution_date == self.execution_date, TI.task_id == task_id)
+ .first()
+ )
return ti
@@ -349,27 +363,25 @@ def get_previous_dagrun(self, state: Optional[str] = None, session: Session = No
]
if state is not None:
filters.append(DagRun.state == state)
- return session.query(DagRun).filter(
- *filters
- ).order_by(
- DagRun.execution_date.desc()
- ).first()
+ return session.query(DagRun).filter(*filters).order_by(DagRun.execution_date.desc()).first()
@provide_session
def get_previous_scheduled_dagrun(self, session: Session = None):
"""The previous, SCHEDULED DagRun, if there is one"""
dag = self.get_dag()
- return session.query(DagRun).filter(
- DagRun.dag_id == self.dag_id,
- DagRun.execution_date == dag.previous_schedule(self.execution_date)
- ).first()
+ return (
+ session.query(DagRun)
+ .filter(
+ DagRun.dag_id == self.dag_id,
+ DagRun.execution_date == dag.previous_schedule(self.execution_date),
+ )
+ .first()
+ )
@provide_session
def update_state(
- self,
- session: Session = None,
- execute_callbacks: bool = True
+ self, session: Session = None, execute_callbacks: bool = True
) -> Tuple[List[TI], Optional[callback_requests.DagCallbackRequest]]:
"""
Determines the overall state of the DagRun based on the state
@@ -403,10 +415,13 @@ def update_state(
if unfinished_tasks and none_depends_on_past and none_task_concurrency:
# small speed up
- are_runnable_tasks = schedulable_tis or self._are_premature_tis(
- unfinished_tasks, finished_tasks, session) or changed_tis
+ are_runnable_tasks = (
+ schedulable_tis
+ or self._are_premature_tis(unfinished_tasks, finished_tasks, session)
+ or changed_tis
+ )
- duration = (timezone.utcnow() - start_dttm)
+ duration = timezone.utcnow() - start_dttm
Stats.timing(f"dagrun.dependency-check.{self.dag_id}", duration)
leaf_task_ids = {t.task_id for t in dag.leaves}
@@ -426,7 +441,7 @@ def update_state(
dag_id=self.dag_id,
execution_date=self.execution_date,
is_failure_callback=True,
- msg='task_failure'
+ msg='task_failure',
)
# if all leafs succeeded and no unfinished tasks, the run succeeded
@@ -443,12 +458,11 @@ def update_state(
dag_id=self.dag_id,
execution_date=self.execution_date,
is_failure_callback=False,
- msg='success'
+ msg='success',
)
# if *all tasks* are deadlocked, the run failed
- elif (unfinished_tasks and none_depends_on_past and
- none_task_concurrency and not are_runnable_tasks):
+ elif unfinished_tasks and none_depends_on_past and none_task_concurrency and not are_runnable_tasks:
self.log.error('Deadlock; marking run %s failed', self)
self.set_state(State.FAILED)
if execute_callbacks:
@@ -459,7 +473,7 @@ def update_state(
dag_id=self.dag_id,
execution_date=self.execution_date,
is_failure_callback=True,
- msg='all_tasks_deadlocked'
+ msg='all_tasks_deadlocked',
)
# finally, if the roots aren't done, the dag is still running
@@ -487,9 +501,7 @@ def task_instance_scheduling_decisions(self, session: Session = None) -> TISched
finished_tasks = [t for t in tis if t.state in State.finished]
if unfinished_tasks:
scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES]
- self.log.debug(
- "number of scheduleable tasks for %s: %s task(s)",
- self, len(scheduleable_tasks))
+ self.log.debug("number of scheduleable tasks for %s: %s task(s)", self, len(scheduleable_tasks))
schedulable_tis, changed_tis = self._get_ready_tis(scheduleable_tasks, finished_tasks, session)
return TISchedulingDecision(
@@ -517,10 +529,9 @@ def _get_ready_tis(
for st in scheduleable_tasks:
old_state = st.state
if st.are_dependencies_met(
- dep_context=DepContext(
- flag_upstream_failed=True,
- finished_tasks=finished_tasks),
- session=session):
+ dep_context=DepContext(flag_upstream_failed=True, finished_tasks=finished_tasks),
+ session=session,
+ ):
ready_tis.append(st)
else:
old_states[st.key] = old_state
@@ -547,8 +558,10 @@ def _are_premature_tis(
flag_upstream_failed=True,
ignore_in_retry_period=True,
ignore_in_reschedule_period=True,
- finished_tasks=finished_tasks),
- session=session):
+ finished_tasks=finished_tasks,
+ ),
+ session=session,
+ ):
return True
return False
@@ -556,7 +569,7 @@ def _emit_duration_stats_for_finished_state(self):
if self.state == State.RUNNING:
return
- duration = (self.end_date - self.start_date)
+ duration = self.end_date - self.start_date
if self.state is State.SUCCESS:
Stats.timing(f'dagrun.duration.success.{self.dag_id}', duration)
elif self.state == State.FAILED:
@@ -586,16 +599,15 @@ def verify_integrity(self, session: Session = None):
if ti.state == State.REMOVED:
pass # ti has already been removed, just ignore it
elif self.state is not State.RUNNING and not dag.partial:
- self.log.warning("Failed to get task '%s' for dag '%s'. "
- "Marking it as removed.", ti, dag)
- Stats.incr(
- f"task_removed_from_dag.{dag.dag_id}", 1, 1)
+ self.log.warning(
+ "Failed to get task '%s' for dag '%s'. " "Marking it as removed.", ti, dag
+ )
+ Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
ti.state = State.REMOVED
should_restore_task = (task is not None) and ti.state == State.REMOVED
if should_restore_task:
- self.log.info("Restoring task '%s' which was previously "
- "removed from DAG '%s'", ti, dag)
+ self.log.info("Restoring task '%s' which was previously " "removed from DAG '%s'", ti, dag)
Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
ti.state = State.NONE
session.merge(ti)
@@ -606,9 +618,7 @@ def verify_integrity(self, session: Session = None):
continue
if task.task_id not in task_ids:
- Stats.incr(
- f"task_instance_created-{task.task_type}",
- 1, 1)
+ Stats.incr(f"task_instance_created-{task.task_type}", 1, 1)
ti = TI(task, self.execution_date)
task_instance_mutation_hook(ti)
session.add(ti)
@@ -617,8 +627,9 @@ def verify_integrity(self, session: Session = None):
session.flush()
except IntegrityError as err:
self.log.info(str(err))
- self.log.info('Hit IntegrityError while creating the TIs for '
- f'{dag.dag_id} - {self.execution_date}.')
+ self.log.info(
+ 'Hit IntegrityError while creating the TIs for ' f'{dag.dag_id} - {self.execution_date}.'
+ )
self.log.info('Doing session rollback.')
# TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
session.rollback()
@@ -654,19 +665,16 @@ def is_backfill(self):
def get_latest_runs(cls, session=None):
"""Returns the latest DagRun for each DAG"""
subquery = (
- session
- .query(
- cls.dag_id,
- func.max(cls.execution_date).label('execution_date'))
+ session.query(cls.dag_id, func.max(cls.execution_date).label('execution_date'))
.group_by(cls.dag_id)
.subquery()
)
dagruns = (
- session
- .query(cls)
- .join(subquery,
- and_(cls.dag_id == subquery.c.dag_id,
- cls.execution_date == subquery.c.execution_date))
+ session.query(cls)
+ .join(
+ subquery,
+ and_(cls.dag_id == subquery.c.dag_id, cls.execution_date == subquery.c.execution_date),
+ )
.all()
)
return dagruns
@@ -686,9 +694,9 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = None) -
# Get list of TIs that do not need to executed, these are
# tasks using DummyOperator and without on_execute_callback / on_success_callback
dummy_tis = {
- ti for ti in schedulable_tis
- if
- (
+ ti
+ for ti in schedulable_tis
+ if (
ti.task.task_type == "DummyOperator"
and not ti.task.on_execute_callback
and not ti.task.on_success_callback
@@ -699,23 +707,34 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = None) -
count = 0
if schedulable_ti_ids:
- count += session.query(TI).filter(
- TI.dag_id == self.dag_id,
- TI.execution_date == self.execution_date,
- TI.task_id.in_(schedulable_ti_ids)
- ).update({TI.state: State.SCHEDULED}, synchronize_session=False)
+ count += (
+ session.query(TI)
+ .filter(
+ TI.dag_id == self.dag_id,
+ TI.execution_date == self.execution_date,
+ TI.task_id.in_(schedulable_ti_ids),
+ )
+ .update({TI.state: State.SCHEDULED}, synchronize_session=False)
+ )
# Tasks using DummyOperator should not be executed, mark them as success
if dummy_tis:
- count += session.query(TI).filter(
- TI.dag_id == self.dag_id,
- TI.execution_date == self.execution_date,
- TI.task_id.in_(ti.task_id for ti in dummy_tis)
- ).update({
- TI.state: State.SUCCESS,
- TI.start_date: timezone.utcnow(),
- TI.end_date: timezone.utcnow(),
- TI.duration: 0
- }, synchronize_session=False)
+ count += (
+ session.query(TI)
+ .filter(
+ TI.dag_id == self.dag_id,
+ TI.execution_date == self.execution_date,
+ TI.task_id.in_(ti.task_id for ti in dummy_tis),
+ )
+ .update(
+ {
+ TI.state: State.SUCCESS,
+ TI.start_date: timezone.utcnow(),
+ TI.end_date: timezone.utcnow(),
+ TI.duration: 0,
+ },
+ synchronize_session=False,
+ )
+ )
return count
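
Most of the dagrun.py changes follow a single pattern: a long SQLAlchemy method chain gains a surrounding pair of parentheses so that each `.filter()`, `.order_by()` or `.update()` step can sit on its own line instead of being continued with backslashes. A self-contained sketch of that layout against an in-memory SQLite database (the `Run` model below is illustrative, not Airflow's `DagRun`):

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker

    Base = declarative_base()


    class Run(Base):
        """Illustrative model, not Airflow's DagRun."""

        __tablename__ = 'run'
        id = Column(Integer, primary_key=True)
        dag_id = Column(String(50))
        state = Column(String(20))


    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()
    session.add_all([Run(dag_id='a', state='running'), Run(dag_id='b', state='queued')])
    session.commit()

    # Black-friendly layout: wrap the chain in parentheses, one step per line.
    count = (
        session.query(Run)
        .filter(Run.state == 'queued')
        .update({Run.state: 'scheduled'}, synchronize_session=False)
    )
    session.commit()
    print(count)  # 1
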
diff --git a/airflow/models/log.py b/airflow/models/log.py
index 8f12c63f514ad..7842f63483360 100644
--- a/airflow/models/log.py
+++ b/airflow/models/log.py
@@ -37,9 +37,7 @@ class Log(Base):
owner = Column(String(500))
extra = Column(Text)
- __table_args__ = (
- Index('idx_log_dag', dag_id),
- )
+ __table_args__ = (Index('idx_log_dag', dag_id),)
def __init__(self, event, task_instance, owner=None, extra=None, **kwargs):
self.dttm = timezone.utcnow()
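
The log.py change collapses `__table_args__` onto one line, and the comma before the closing parenthesis is load-bearing: it is what makes the value a one-element tuple rather than a plain parenthesised expression, which is why Black keeps it. A tiny illustration:

    not_a_tuple = ('idx_log_dag')   # parentheses alone do nothing
    one_tuple = ('idx_log_dag',)    # the trailing comma makes the tuple
    print(type(not_a_tuple).__name__, type(one_tuple).__name__, len(one_tuple))  # str tuple 1
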
diff --git a/airflow/models/pool.py b/airflow/models/pool.py
index 541a464c076cb..67cdee52098d2 100644
--- a/airflow/models/pool.py
+++ b/airflow/models/pool.py
@@ -53,7 +53,7 @@ class Pool(Base):
DEFAULT_POOL_NAME = 'default_pool'
def __repr__(self):
- return str(self.pool) # pylint: disable=E0012
+ return str(self.pool) # pylint: disable=E0012
@staticmethod
@provide_session
@@ -126,9 +126,7 @@ def slots_stats(
elif state == "queued":
stats_dict["queued"] = count
else:
- raise AirflowException(
- f"Unexpected state. Expected values: {EXECUTION_STATES}."
- )
+ raise AirflowException(f"Unexpected state. Expected values: {EXECUTION_STATES}.")
# calculate open metric
for pool_name, stats_dict in pools.items():
@@ -162,9 +160,9 @@ def occupied_slots(self, session: Session):
:return: the used number of slots
"""
from airflow.models.taskinstance import TaskInstance # Avoid circular import
+
return (
- session
- .query(func.sum(TaskInstance.pool_slots))
+ session.query(func.sum(TaskInstance.pool_slots))
.filter(TaskInstance.pool == self.pool)
.filter(TaskInstance.state.in_(list(EXECUTION_STATES)))
.scalar()
@@ -181,8 +179,7 @@ def running_slots(self, session: Session):
from airflow.models.taskinstance import TaskInstance # Avoid circular import
return (
- session
- .query(func.sum(TaskInstance.pool_slots))
+ session.query(func.sum(TaskInstance.pool_slots))
.filter(TaskInstance.pool == self.pool)
.filter(TaskInstance.state == State.RUNNING)
.scalar()
@@ -199,8 +196,7 @@ def queued_slots(self, session: Session):
from airflow.models.taskinstance import TaskInstance # Avoid circular import
return (
- session
- .query(func.sum(TaskInstance.pool_slots))
+ session.query(func.sum(TaskInstance.pool_slots))
.filter(TaskInstance.pool == self.pool)
.filter(TaskInstance.state == State.QUEUED)
.scalar()
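
Several hunks in pool.py (and in dagbag.py earlier) show Black adding a blank line after a function-local import. The imports are local in the first place to break circular dependencies: the other module is resolved only when the method runs, not when the file is imported. A hedged, generic sketch of the pattern; the names below are made up and only the import placement mirrors the code above:

    def occupied_slots_for(pool_name: str) -> int:
        # Local import: resolved at call time, so a module that imports this one
        # can itself be imported here without creating a cycle at import time.
        from collections import Counter  # stand-in for a heavyweight model import

        usage = Counter({'default_pool': 3, 'backfill': 1})
        return usage.get(pool_name, 0)


    print(occupied_slots_for('default_pool'))  # 3
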
diff --git a/airflow/models/renderedtifields.py b/airflow/models/renderedtifields.py
index 661b0a8ac219e..8102b01762847 100644
--- a/airflow/models/renderedtifields.py
+++ b/airflow/models/renderedtifields.py
@@ -50,9 +50,7 @@ def __init__(self, ti: TaskInstance, render_templates=True):
if render_templates:
ti.render_templates()
self.rendered_fields = {
- field: serialize_template_field(
- getattr(self.task, field)
- ) for field in self.task.template_fields
+ field: serialize_template_field(getattr(self.task, field)) for field in self.task.template_fields
}
def __repr__(self):
@@ -69,11 +67,13 @@ def get_templated_fields(cls, ti: TaskInstance, session: Session = None) -> Opti
:param session: SqlAlchemy Session
:return: Rendered Templated TI field
"""
- result = session.query(cls.rendered_fields).filter(
- cls.dag_id == ti.dag_id,
- cls.task_id == ti.task_id,
- cls.execution_date == ti.execution_date
- ).one_or_none()
+ result = (
+ session.query(cls.rendered_fields)
+ .filter(
+ cls.dag_id == ti.dag_id, cls.task_id == ti.task_id, cls.execution_date == ti.execution_date
+ )
+ .one_or_none()
+ )
if result:
rendered_fields = result.rendered_fields
@@ -92,9 +92,11 @@ def write(self, session: Session = None):
@classmethod
@provide_session
def delete_old_records(
- cls, task_id: str, dag_id: str,
+ cls,
+ task_id: str,
+ dag_id: str,
num_to_keep=conf.getint("core", "max_num_rendered_ti_fields_per_task", fallback=0),
- session: Session = None
+ session: Session = None,
):
"""
Keep only Last X (num_to_keep) number of records for a task by deleting others
@@ -107,22 +109,22 @@ def delete_old_records(
if num_to_keep <= 0:
return
- tis_to_keep_query = session \
- .query(cls.dag_id, cls.task_id, cls.execution_date) \
- .filter(cls.dag_id == dag_id, cls.task_id == task_id) \
- .order_by(cls.execution_date.desc()) \
+ tis_to_keep_query = (
+ session.query(cls.dag_id, cls.task_id, cls.execution_date)
+ .filter(cls.dag_id == dag_id, cls.task_id == task_id)
+ .order_by(cls.execution_date.desc())
.limit(num_to_keep)
+ )
if session.bind.dialect.name in ["postgresql", "sqlite"]:
# Fetch Top X records given dag_id & task_id ordered by Execution Date
subq1 = tis_to_keep_query.subquery('subq1')
- session.query(cls) \
- .filter(
- cls.dag_id == dag_id,
- cls.task_id == task_id,
- tuple_(cls.dag_id, cls.task_id, cls.execution_date).notin_(subq1)) \
- .delete(synchronize_session=False)
+ session.query(cls).filter(
+ cls.dag_id == dag_id,
+ cls.task_id == task_id,
+ tuple_(cls.dag_id, cls.task_id, cls.execution_date).notin_(subq1),
+ ).delete(synchronize_session=False)
elif session.bind.dialect.name in ["mysql"]:
# Fetch Top X records given dag_id & task_id ordered by Execution Date
subq1 = tis_to_keep_query.subquery('subq1')
@@ -131,28 +133,26 @@ def delete_old_records(
# Workaround for MySQL Limitation (https://stackoverflow.com/a/19344141/5691525)
# Limitation: This version of MySQL does not yet support
# LIMIT & IN/ALL/ANY/SOME subquery
- subq2 = (
- session
- .query(subq1.c.dag_id, subq1.c.task_id, subq1.c.execution_date)
- .subquery('subq2')
- )
+ subq2 = session.query(subq1.c.dag_id, subq1.c.task_id, subq1.c.execution_date).subquery('subq2')
- session.query(cls) \
- .filter(
- cls.dag_id == dag_id,
- cls.task_id == task_id,
- tuple_(cls.dag_id, cls.task_id, cls.execution_date).notin_(subq2)) \
- .delete(synchronize_session=False)
+ session.query(cls).filter(
+ cls.dag_id == dag_id,
+ cls.task_id == task_id,
+ tuple_(cls.dag_id, cls.task_id, cls.execution_date).notin_(subq2),
+ ).delete(synchronize_session=False)
else:
# Fetch Top X records given dag_id & task_id ordered by Execution Date
tis_to_keep = tis_to_keep_query.all()
- filter_tis = [not_(and_(
- cls.dag_id == ti.dag_id,
- cls.task_id == ti.task_id,
- cls.execution_date == ti.execution_date
- )) for ti in tis_to_keep]
-
- session.query(cls) \
- .filter(and_(*filter_tis)) \
- .delete(synchronize_session=False)
+ filter_tis = [
+ not_(
+ and_(
+ cls.dag_id == ti.dag_id,
+ cls.task_id == ti.task_id,
+ cls.execution_date == ti.execution_date,
+ )
+ )
+ for ti in tis_to_keep
+ ]
+
+ session.query(cls).filter(and_(*filter_tis)).delete(synchronize_session=False)
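
`delete_old_records` above keeps only the newest `num_to_keep` rendered-field rows per task: it selects the top-N `(dag_id, task_id, execution_date)` keys ordered by execution date and deletes every row whose key is not in that set; the separate MySQL branch exists only because MySQL cannot use LIMIT inside an IN subquery. A plain-Python sketch of the same keep-last-N idea, independent of the database:

    from datetime import datetime, timedelta

    # Fake rows keyed by (dag_id, task_id, execution_date).
    rows = [('dag', 'task', datetime(2020, 10, 1) + timedelta(days=i)) for i in range(10)]

    num_to_keep = 3
    # Keys to keep: the N most recent execution dates ...
    to_keep = set(sorted(rows, key=lambda row: row[2], reverse=True)[:num_to_keep])
    # ... everything else would be deleted.
    remaining = [row for row in rows if row in to_keep]
    print(len(remaining))  # 3
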
diff --git a/airflow/models/sensorinstance.py b/airflow/models/sensorinstance.py
index 3e0c417babe3b..f8e6ef82919c0 100644
--- a/airflow/models/sensorinstance.py
+++ b/airflow/models/sensorinstance.py
@@ -60,14 +60,10 @@ class SensorInstance(Base):
poke_context = Column(Text, nullable=False)
execution_context = Column(Text)
created_at = Column(UtcDateTime, default=timezone.utcnow(), nullable=False)
- updated_at = Column(UtcDateTime,
- default=timezone.utcnow(),
- onupdate=timezone.utcnow(),
- nullable=False)
+ updated_at = Column(UtcDateTime, default=timezone.utcnow(), onupdate=timezone.utcnow(), nullable=False)
__table_args__ = (
Index('ti_primary_key', dag_id, task_id, execution_date, unique=True),
-
Index('si_hashcode', hashcode),
Index('si_shardcode', shardcode),
Index('si_state_shard', state, shardcode),
@@ -118,11 +114,16 @@ def register(cls, ti, poke_context, execution_context, session=None):
encoded_poke = json.dumps(poke_context)
encoded_execution_context = json.dumps(execution_context)
- sensor = session.query(SensorInstance).filter(
- SensorInstance.dag_id == ti.dag_id,
- SensorInstance.task_id == ti.task_id,
- SensorInstance.execution_date == ti.execution_date
- ).with_for_update().first()
+ sensor = (
+ session.query(SensorInstance)
+ .filter(
+ SensorInstance.dag_id == ti.dag_id,
+ SensorInstance.task_id == ti.task_id,
+ SensorInstance.execution_date == ti.execution_date,
+ )
+ .with_for_update()
+ .first()
+ )
if sensor is None:
sensor = SensorInstance(ti=ti)
@@ -161,6 +162,7 @@ def try_number(self, value):
self._try_number = value
def __repr__(self):
- return "<{self.__class__.__name__}: id: {self.id} poke_context: {self.poke_context} " \
- "execution_context: {self.execution_context} state: {self.state}>".format(
- self=self)
+ return (
+ "<{self.__class__.__name__}: id: {self.id} poke_context: {self.poke_context} "
+ "execution_context: {self.execution_context} state: {self.state}>".format(self=self)
+ )
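
The sensorinstance.py `__repr__` above, like many log-message hunks in this patch, relies on implicit concatenation of adjacent string literals: Black moves the pieces onto one line but never merges them into a single literal, so a missing separator stays missing. The "Failed to register in sensor service." "Continue to run task..." warning in taskinstance.py further down is one such case; the two literals join with no space after the period:

    msg = "Failed to register in sensor service." "Continue to run task in non smart sensor mode."
    print(msg)  # ...sensor service.Continue to run task...
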
diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index 2eedd4759c6ab..dfd8b3a753c77 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -69,9 +69,7 @@ class SerializedDagModel(Base):
last_updated = Column(UtcDateTime, nullable=False)
dag_hash = Column(String(32), nullable=False)
- __table_args__ = (
- Index('idx_fileloc_hash', fileloc_hash, unique=False),
- )
+ __table_args__ = (Index('idx_fileloc_hash', fileloc_hash, unique=False),)
dag_runs = relationship(
DagRun,
@@ -115,16 +113,19 @@ def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session:
# If Yes, does nothing
# If No or the DAG does not exists, updates / writes Serialized DAG to DB
if min_update_interval is not None:
- if session.query(exists().where(
- and_(cls.dag_id == dag.dag_id,
- (timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated))
+ if session.query(
+ exists().where(
+ and_(
+ cls.dag_id == dag.dag_id,
+ (timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated,
+ )
+ )
).scalar():
return
log.debug("Checking if DAG (%s) changed", dag.dag_id)
new_serialized_dag = cls(dag)
- serialized_dag_hash_from_db = session.query(
- cls.dag_hash).filter(cls.dag_id == dag.dag_id).scalar()
+ serialized_dag_hash_from_db = session.query(cls.dag_hash).filter(cls.dag_id == dag.dag_id).scalar()
if serialized_dag_hash_from_db == new_serialized_dag.dag_hash:
log.debug("Serialized DAG (%s) is unchanged. Skipping writing to DB", dag.dag_id)
@@ -154,8 +155,10 @@ def read_all_dags(cls, session: Session = None) -> Dict[str, 'SerializedDAG']:
dags[row.dag_id] = dag
else:
log.warning(
- "dag_id Mismatch in DB: Row with dag_id '%s' has Serialised DAG "
- "with '%s' dag_id", row.dag_id, dag.dag_id)
+ "dag_id Mismatch in DB: Row with dag_id '%s' has Serialised DAG " "with '%s' dag_id",
+ row.dag_id,
+ dag.dag_id,
+ )
return dags
@property
@@ -186,16 +189,18 @@ def remove_deleted_dags(cls, alive_dag_filelocs: List[str], session=None):
:param alive_dag_filelocs: file paths of alive DAGs
:param session: ORM Session
"""
- alive_fileloc_hashes = [
- DagCode.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]
+ alive_fileloc_hashes = [DagCode.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]
- log.debug("Deleting Serialized DAGs (for which DAG files are deleted) "
- "from %s table ", cls.__tablename__)
+ log.debug(
+ "Deleting Serialized DAGs (for which DAG files are deleted) " "from %s table ", cls.__tablename__
+ )
# pylint: disable=no-member
- session.execute(cls.__table__.delete().where(
- and_(cls.fileloc_hash.notin_(alive_fileloc_hashes),
- cls.fileloc.notin_(alive_dag_filelocs))))
+ session.execute(
+ cls.__table__.delete().where(
+ and_(cls.fileloc_hash.notin_(alive_fileloc_hashes), cls.fileloc.notin_(alive_dag_filelocs))
+ )
+ )
@classmethod
@provide_session
@@ -224,8 +229,7 @@ def get(cls, dag_id: str, session: Session = None) -> Optional['SerializedDagMod
# If we didn't find a matching DAG id then ask the DAG table to find
# out the root dag
- root_dag_id = session.query(
- DagModel.root_dag_id).filter(DagModel.dag_id == dag_id).scalar()
+ root_dag_id = session.query(DagModel.root_dag_id).filter(DagModel.dag_id == dag_id).scalar()
return session.query(cls).filter(cls.dag_id == root_dag_id).one_or_none()
@@ -245,9 +249,7 @@ def bulk_sync_to_db(dags: List[DAG], session: Session = None):
for dag in dags:
if not dag.is_subdag:
SerializedDagModel.write_dag(
- dag,
- min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL,
- session=session
+ dag, min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL, session=session
)
@classmethod
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 43033310e1ae4..dc40329087f8c 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -70,7 +70,11 @@ def _set_state_to_skipped(self, dag_run, execution_date, tasks, session):
@provide_session
def skip(
- self, dag_run, execution_date, tasks, session=None,
+ self,
+ dag_run,
+ execution_date,
+ tasks,
+ session=None,
):
"""
Sets tasks instances to skipped from the same dag run.
@@ -105,12 +109,10 @@ def skip(
task_id=task_id,
dag_id=dag_run.dag_id,
execution_date=dag_run.execution_date,
- session=session
+ session=session,
)
- def skip_all_except(
- self, ti: TaskInstance, branch_task_ids: Union[str, Iterable[str]]
- ):
+ def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[str, Iterable[str]]):
"""
This method implements the logic for a branching operator; given a single
task ID or list of task IDs to follow, this skips all other tasks
@@ -145,25 +147,15 @@ def skip_all_except(
# task1
#
for branch_task_id in list(branch_task_ids):
- branch_task_ids.update(
- dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False)
- )
-
- skip_tasks = [
- t
- for t in downstream_tasks
- if t.task_id not in branch_task_ids
- ]
+ branch_task_ids.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))
+
+ skip_tasks = [t for t in downstream_tasks if t.task_id not in branch_task_ids]
follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_ids]
self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
with create_session() as session:
- self._set_state_to_skipped(
- dag_run, ti.execution_date, skip_tasks, session=session
- )
+ self._set_state_to_skipped(dag_run, ti.execution_date, skip_tasks, session=session)
# For some reason, session.commit() needs to happen before xcom_push.
# Otherwise the session is not committed.
session.commit()
- ti.xcom_push(
- key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}
- )
+ ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids})
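
`skip_all_except` above implements the branching logic: the chosen branch task ids are expanded with everything downstream of them, and any direct downstream task outside that expanded set is skipped while the rest are recorded as followed. A small, Airflow-free sketch of the set arithmetic:

    # Hypothetical graph: task id -> ids of its transitive downstream tasks.
    downstream_of = {
        'branch_a': {'a_1', 'a_2'},
        'branch_b': {'b_1'},
    }
    direct_downstream = ['branch_a', 'branch_b']

    chosen = {'branch_a'}
    follow = set(chosen)
    for task_id in list(chosen):
        # Follow the chosen branches plus everything below them.
        follow.update(downstream_of.get(task_id, set()))

    skip = [t for t in direct_downstream if t not in follow]
    follow_ids = [t for t in direct_downstream if t in follow]
    print(sorted(follow), skip, follow_ids)  # ['a_1', 'a_2', 'branch_a'] ['branch_b'] ['branch_a']
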
diff --git a/airflow/models/slamiss.py b/airflow/models/slamiss.py
index 7c7160bf07bd4..6c841e3ee6e05 100644
--- a/airflow/models/slamiss.py
+++ b/airflow/models/slamiss.py
@@ -39,10 +39,7 @@ class SlaMiss(Base):
description = Column(Text)
notification_sent = Column(Boolean, default=False)
- __table_args__ = (
- Index('sm_dag', dag_id, unique=False),
- )
+ __table_args__ = (Index('sm_dag', dag_id, unique=False),)
def __repr__(self):
- return str((
- self.dag_id, self.task_id, self.execution_date.isoformat()))
+ return str((self.dag_id, self.task_id, self.execution_date.isoformat()))
diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py
index 5b3979203d617..23d3086e2d013 100644
--- a/airflow/models/taskfail.py
+++ b/airflow/models/taskfail.py
@@ -35,10 +35,7 @@ class TaskFail(Base):
end_date = Column(UtcDateTime)
duration = Column(Integer)
- __table_args__ = (
- Index('idx_task_fail_dag_task_date', dag_id, task_id, execution_date,
- unique=False),
- )
+ __table_args__ = (Index('idx_task_fail_dag_task_date', dag_id, task_id, execution_date, unique=False),)
def __init__(self, task, execution_date, start_date, end_date):
self.dag_id = task.dag_id
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 872682d780179..6a73b40d674a6 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -42,8 +42,12 @@
from airflow import settings
from airflow.configuration import conf
from airflow.exceptions import (
- AirflowException, AirflowFailException, AirflowRescheduleException, AirflowSkipException,
- AirflowSmartSensorException, AirflowTaskTimeout,
+ AirflowException,
+ AirflowFailException,
+ AirflowRescheduleException,
+ AirflowSkipException,
+ AirflowSmartSensorException,
+ AirflowTaskTimeout,
)
from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
from airflow.models.log import Log
@@ -92,11 +96,12 @@ def set_current_context(context: Context):
)
-def clear_task_instances(tis,
- session,
- activate_dag_runs=True,
- dag=None,
- ):
+def clear_task_instances(
+ tis,
+ session,
+ activate_dag_runs=True,
+ dag=None,
+):
"""
Clears a set of task instances, but makes sure the running ones
get killed.
@@ -132,20 +137,26 @@ def clear_task_instances(tis,
TR.dag_id == ti.dag_id,
TR.task_id == ti.task_id,
TR.execution_date == ti.execution_date,
- TR.try_number == ti.try_number
+ TR.try_number == ti.try_number,
).delete()
if job_ids:
from airflow.jobs.base_job import BaseJob
+
for job in session.query(BaseJob).filter(BaseJob.id.in_(job_ids)).all(): # noqa
job.state = State.SHUTDOWN
if activate_dag_runs and tis:
from airflow.models.dagrun import DagRun # Avoid circular import
- drs = session.query(DagRun).filter(
- DagRun.dag_id.in_({ti.dag_id for ti in tis}),
- DagRun.execution_date.in_({ti.execution_date for ti in tis}),
- ).all()
+
+ drs = (
+ session.query(DagRun)
+ .filter(
+ DagRun.dag_id.in_({ti.dag_id for ti in tis}),
+ DagRun.execution_date.in_({ti.execution_date for ti in tis}),
+ )
+ .all()
+ )
for dr in drs:
dr.state = State.RUNNING
dr.start_date = timezone.utcnow()
@@ -167,18 +178,14 @@ def primary(self) -> Tuple[str, str, datetime]:
@property
def reduced(self) -> 'TaskInstanceKey':
"""Remake the key by subtracting 1 from try number to match in memory information"""
- return TaskInstanceKey(
- self.dag_id, self.task_id, self.execution_date, max(1, self.try_number - 1)
- )
+ return TaskInstanceKey(self.dag_id, self.task_id, self.execution_date, max(1, self.try_number - 1))
def with_try_number(self, try_number: int) -> 'TaskInstanceKey':
"""Returns TaskInstanceKey with provided ``try_number``"""
- return TaskInstanceKey(
- self.dag_id, self.task_id, self.execution_date, try_number
- )
+ return TaskInstanceKey(self.dag_id, self.task_id, self.execution_date, try_number)
-class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
+class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
"""
Task instances store the state of a task instance. This table is the
authority and single source of truth around what tasks have run and the
@@ -247,11 +254,12 @@ def __init__(self, task, execution_date: datetime, state: Optional[str] = None):
# make sure we have a localized execution_date stored in UTC
if execution_date and not timezone.is_localized(execution_date):
- self.log.warning("execution date %s has no timezone information. Using "
- "default from dag or system", execution_date)
+ self.log.warning(
+ "execution date %s has no timezone information. Using " "default from dag or system",
+ execution_date,
+ )
if self.task.has_dag():
- execution_date = timezone.make_aware(execution_date,
- self.task.dag.timezone)
+ execution_date = timezone.make_aware(execution_date, self.task.dag.timezone)
else:
execution_date = timezone.make_aware(execution_date)
@@ -314,19 +322,20 @@ def next_try_number(self):
"""Setting Next Try Number"""
return self._try_number + 1
- def command_as_list( # pylint: disable=too-many-arguments
- self,
- mark_success=False,
- ignore_all_deps=False,
- ignore_task_deps=False,
- ignore_depends_on_past=False,
- ignore_ti_state=False,
- local=False,
- pickle_id=None,
- raw=False,
- job_id=None,
- pool=None,
- cfg_path=None):
+ def command_as_list( # pylint: disable=too-many-arguments
+ self,
+ mark_success=False,
+ ignore_all_deps=False,
+ ignore_task_deps=False,
+ ignore_depends_on_past=False,
+ ignore_ti_state=False,
+ local=False,
+ pickle_id=None,
+ raw=False,
+ job_id=None,
+ pool=None,
+ cfg_path=None,
+ ):
"""
Returns a command that can be executed anywhere where airflow is
installed. This command is part of the message sent to executors by
@@ -357,25 +366,27 @@ def command_as_list( # pylint: disable=too-many-arguments
raw=raw,
job_id=job_id,
pool=pool,
- cfg_path=cfg_path)
+ cfg_path=cfg_path,
+ )
@staticmethod
- def generate_command(dag_id: str, # pylint: disable=too-many-arguments
- task_id: str,
- execution_date: datetime,
- mark_success: bool = False,
- ignore_all_deps: bool = False,
- ignore_depends_on_past: bool = False,
- ignore_task_deps: bool = False,
- ignore_ti_state: bool = False,
- local: bool = False,
- pickle_id: Optional[int] = None,
- file_path: Optional[str] = None,
- raw: bool = False,
- job_id: Optional[str] = None,
- pool: Optional[str] = None,
- cfg_path: Optional[str] = None
- ) -> List[str]:
+ def generate_command(
+ dag_id: str, # pylint: disable=too-many-arguments
+ task_id: str,
+ execution_date: datetime,
+ mark_success: bool = False,
+ ignore_all_deps: bool = False,
+ ignore_depends_on_past: bool = False,
+ ignore_task_deps: bool = False,
+ ignore_ti_state: bool = False,
+ local: bool = False,
+ pickle_id: Optional[int] = None,
+ file_path: Optional[str] = None,
+ raw: bool = False,
+ job_id: Optional[str] = None,
+ pool: Optional[str] = None,
+ cfg_path: Optional[str] = None,
+ ) -> List[str]:
"""
Generates the shell command required to execute this task instance.
@@ -457,10 +468,7 @@ def log_url(self):
iso = quote(self.execution_date.isoformat())
base_url = conf.get('webserver', 'BASE_URL')
return base_url + ( # noqa
- "/log?"
- "execution_date={iso}"
- "&task_id={task_id}"
- "&dag_id={dag_id}"
+ "/log?" "execution_date={iso}" "&task_id={task_id}" "&dag_id={dag_id}"
).format(iso=iso, task_id=self.task_id, dag_id=self.dag_id)
@property
@@ -487,11 +495,15 @@ def current_state(self, session=None) -> str:
:param session: SQLAlchemy ORM Session
:type session: Session
"""
- ti = session.query(TaskInstance).filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id == self.task_id,
- TaskInstance.execution_date == self.execution_date,
- ).all()
+ ti = (
+ session.query(TaskInstance)
+ .filter(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.task_id == self.task_id,
+ TaskInstance.execution_date == self.execution_date,
+ )
+ .all()
+ )
if ti:
state = ti[0].state
else:
@@ -528,7 +540,8 @@ def refresh_from_db(self, session=None, lock_for_update=False) -> None:
qry = session.query(TaskInstance).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
- TaskInstance.execution_date == self.execution_date)
+ TaskInstance.execution_date == self.execution_date,
+ )
if lock_for_update:
ti = qry.with_for_update().first()
@@ -542,7 +555,7 @@ def refresh_from_db(self, session=None, lock_for_update=False) -> None:
self.state = ti.state
# Get the raw value of try_number column, don't read through the
# accessor here otherwise it will be incremented by one already.
- self.try_number = ti._try_number # noqa pylint: disable=protected-access
+ self.try_number = ti._try_number # noqa pylint: disable=protected-access
self.max_tries = ti.max_tries
self.hostname = ti.hostname
self.unixname = ti.unixname
@@ -589,7 +602,7 @@ def clear_xcom_data(self, session=None):
session.query(XCom).filter(
XCom.dag_id == self.dag_id,
XCom.task_id == self.task_id,
- XCom.execution_date == self.execution_date
+ XCom.execution_date == self.execution_date,
).delete()
session.commit()
self.log.debug("XCom data cleared")
@@ -656,9 +669,7 @@ def are_dependents_done(self, session=None):
@provide_session
def get_previous_ti(
- self,
- state: Optional[str] = None,
- session: Session = None
+ self, state: Optional[str] = None, session: Session = None
) -> Optional['TaskInstance']:
"""
The task instance for the task that ran before this task instance.
@@ -745,9 +756,7 @@ def get_previous_execution_date(
@provide_session
def get_previous_start_date(
- self,
- state: Optional[str] = None,
- session: Session = None
+ self, state: Optional[str] = None, session: Session = None
) -> Optional[pendulum.DateTime]:
"""
The start date from property previous_ti_success.
@@ -776,11 +785,7 @@ def previous_start_date_success(self) -> Optional[pendulum.DateTime]:
return self.get_previous_start_date(state=State.SUCCESS)
@provide_session
- def are_dependencies_met(
- self,
- dep_context=None,
- session=None,
- verbose=False):
+ def are_dependencies_met(self, dep_context=None, session=None, verbose=False):
"""
Returns whether or not all the conditions are met for this task instance to be run
given the context for the dependencies (e.g. a task instance being force run from
@@ -798,14 +803,14 @@ def are_dependencies_met(
dep_context = dep_context or DepContext()
failed = False
verbose_aware_logger = self.log.info if verbose else self.log.debug
- for dep_status in self.get_failed_dep_statuses(
- dep_context=dep_context,
- session=session):
+ for dep_status in self.get_failed_dep_statuses(dep_context=dep_context, session=session):
failed = True
verbose_aware_logger(
"Dependencies not met for %s, dependency '%s' FAILED: %s",
- self, dep_status.dep_name, dep_status.reason
+ self,
+ dep_status.dep_name,
+ dep_status.reason,
)
if failed:
@@ -815,21 +820,18 @@ def are_dependencies_met(
return True
@provide_session
- def get_failed_dep_statuses(
- self,
- dep_context=None,
- session=None):
+ def get_failed_dep_statuses(self, dep_context=None, session=None):
"""Get failed Dependencies"""
dep_context = dep_context or DepContext()
for dep in dep_context.deps | self.task.deps:
- for dep_status in dep.get_dep_statuses(
- self,
- session,
- dep_context):
+ for dep_status in dep.get_dep_statuses(self, session, dep_context):
self.log.debug(
"%s dependency '%s' PASSED: %s, %s",
- self, dep_status.dep_name, dep_status.passed, dep_status.reason
+ self,
+ dep_status.dep_name,
+ dep_status.passed,
+ dep_status.reason,
)
if not dep_status.passed:
@@ -837,8 +839,7 @@ def get_failed_dep_statuses(
def __repr__(self):
return ( # noqa
- ""
+ ""
).format(ti=self)
def next_retry_datetime(self):
@@ -853,11 +854,14 @@ def next_retry_datetime(self):
# will occur in the modded_hash calculation.
min_backoff = int(math.ceil(delay.total_seconds() * (2 ** (self.try_number - 2))))
# deterministic per task instance
- ti_hash = int(hashlib.sha1("{}#{}#{}#{}".format(self.dag_id, # noqa
- self.task_id,
- self.execution_date,
- self.try_number)
- .encode('utf-8')).hexdigest(), 16)
+ ti_hash = int(
+ hashlib.sha1(
+ "{}#{}#{}#{}".format(
+ self.dag_id, self.task_id, self.execution_date, self.try_number # noqa
+ ).encode('utf-8')
+ ).hexdigest(),
+ 16,
+ )
# between 1 and 1.0 * delay * (2^retry_number)
modded_hash = min_backoff + ti_hash % min_backoff
# timedelta has a maximum representable value. The exponentiation
@@ -865,10 +869,7 @@ def next_retry_datetime(self):
# of tries (around 50 if the initial delay is 1s, even fewer if
# the delay is larger). Cap the value here before creating a
# timedelta object so the operation doesn't fail.
- delay_backoff_in_seconds = min(
- modded_hash,
- timedelta.max.total_seconds() - 1
- )
+ delay_backoff_in_seconds = min(modded_hash, timedelta.max.total_seconds() - 1)
delay = timedelta(seconds=delay_backoff_in_seconds)
if self.task.max_retry_delay:
delay = min(self.task.max_retry_delay, delay)
@@ -879,8 +880,7 @@ def ready_for_retry(self):
Checks on whether the task instance is in the right state and timeframe
to be retried.
"""
- return (self.state == State.UP_FOR_RETRY and
- self.next_retry_datetime() < timezone.utcnow())
+ return self.state == State.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow()
@provide_session
def get_dagrun(self, session: Session = None):
@@ -891,26 +891,29 @@ def get_dagrun(self, session: Session = None):
:return: DagRun
"""
from airflow.models.dagrun import DagRun # Avoid circular import
- dr = session.query(DagRun).filter(
- DagRun.dag_id == self.dag_id,
- DagRun.execution_date == self.execution_date
- ).first()
+
+ dr = (
+ session.query(DagRun)
+ .filter(DagRun.dag_id == self.dag_id, DagRun.execution_date == self.execution_date)
+ .first()
+ )
return dr
@provide_session
- def check_and_change_state_before_execution( # pylint: disable=too-many-arguments
- self,
- verbose: bool = True,
- ignore_all_deps: bool = False,
- ignore_depends_on_past: bool = False,
- ignore_task_deps: bool = False,
- ignore_ti_state: bool = False,
- mark_success: bool = False,
- test_mode: bool = False,
- job_id: Optional[str] = None,
- pool: Optional[str] = None,
- session=None) -> bool:
+ def check_and_change_state_before_execution( # pylint: disable=too-many-arguments
+ self,
+ verbose: bool = True,
+ ignore_all_deps: bool = False,
+ ignore_depends_on_past: bool = False,
+ ignore_task_deps: bool = False,
+ ignore_ti_state: bool = False,
+ mark_success: bool = False,
+ test_mode: bool = False,
+ job_id: Optional[str] = None,
+ pool: Optional[str] = None,
+ session=None,
+ ) -> bool:
"""
Checks dependencies and then sets state to RUNNING if they are met. Returns
True if and only if state is set to RUNNING, which implies that task should be
@@ -960,11 +963,11 @@ def check_and_change_state_before_execution( # pylint: disable=too-many-argum
ignore_all_deps=ignore_all_deps,
ignore_ti_state=ignore_ti_state,
ignore_depends_on_past=ignore_depends_on_past,
- ignore_task_deps=ignore_task_deps)
+ ignore_task_deps=ignore_task_deps,
+ )
if not self.are_dependencies_met(
- dep_context=non_requeueable_dep_context,
- session=session,
- verbose=True):
+ dep_context=non_requeueable_dep_context, session=session, verbose=True
+ ):
session.commit()
return False
@@ -987,17 +990,17 @@ def check_and_change_state_before_execution( # pylint: disable=too-many-argum
ignore_all_deps=ignore_all_deps,
ignore_depends_on_past=ignore_depends_on_past,
ignore_task_deps=ignore_task_deps,
- ignore_ti_state=ignore_ti_state)
- if not self.are_dependencies_met(
- dep_context=dep_context,
- session=session,
- verbose=True):
+ ignore_ti_state=ignore_ti_state,
+ )
+ if not self.are_dependencies_met(dep_context=dep_context, session=session, verbose=True):
self.state = State.NONE
self.log.warning(hr_line_break)
self.log.warning(
"Rescheduling due to concurrency limits reached "
"at task runtime. Attempt %s of "
- "%s. State set to NONE.", self.try_number, self.max_tries + 1
+ "%s. State set to NONE.",
+ self.try_number,
+ self.max_tries + 1,
)
self.log.warning(hr_line_break)
self.queued_dttm = timezone.utcnow()
@@ -1040,12 +1043,13 @@ def _date_or_empty(self, attr):
@provide_session
@Sentry.enrich_errors
def _run_raw_task(
- self,
- mark_success: bool = False,
- test_mode: bool = False,
- job_id: Optional[str] = None,
- pool: Optional[str] = None,
- session=None) -> None:
+ self,
+ mark_success: bool = False,
+ test_mode: bool = False,
+ job_id: Optional[str] = None,
+ pool: Optional[str] = None,
+ session=None,
+ ) -> None:
"""
Immediately runs the task (without checking or changing db state
before execution) and then sets the appropriate final state after
@@ -1130,7 +1134,7 @@ def _run_raw_task(
self.task_id,
self._date_or_empty('execution_date'),
self._date_or_empty('start_date'),
- self._date_or_empty('end_date')
+ self._date_or_empty('end_date'),
)
self.set_duration()
if not test_mode:
@@ -1149,10 +1153,12 @@ def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
try:
# Re-select the row with a lock
- dag_run = with_row_locks(session.query(DagRun).filter_by(
- dag_id=self.dag_id,
- execution_date=self.execution_date,
- )).one()
+ dag_run = with_row_locks(
+ session.query(DagRun).filter_by(
+ dag_id=self.dag_id,
+ execution_date=self.execution_date,
+ )
+ ).one()
# Get a partial dag with just the specific tasks we want to
# examine. In order for dep checks to work correctly, we
@@ -1174,9 +1180,7 @@ def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
if task_id not in self.task.downstream_task_ids
}
- schedulable_tis = [
- ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids
- ]
+ schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
for schedulable_ti in schedulable_tis:
if not hasattr(schedulable_ti, "task"):
schedulable_ti.task = self.task.dag.get_task(schedulable_ti.task_id)
@@ -1193,10 +1197,7 @@ def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
)
session.rollback()
- def _prepare_and_execute_task_with_callbacks(
- self,
- context,
- task):
+ def _prepare_and_execute_task_with_callbacks(self, context, task):
"""Prepare Task for Execution"""
from airflow.models.renderedtifields import RenderedTaskInstanceFields
@@ -1220,9 +1221,10 @@ def signal_handler(signum, frame): # pylint: disable=unused-argument
# Export context to make it available for operators to use.
airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
- self.log.info("Exporting the following env vars:\n%s",
- '\n'.join([f"{k}={v}"
- for k, v in airflow_context_vars.items()]))
+ self.log.info(
+ "Exporting the following env vars:\n%s",
+ '\n'.join([f"{k}={v}" for k, v in airflow_context_vars.items()]),
+ )
os.environ.update(airflow_context_vars)
@@ -1238,8 +1240,9 @@ def signal_handler(signum, frame): # pylint: disable=unused-argument
try:
registered = task_copy.register_in_sensor_service(self, context)
except Exception as e:
- self.log.warning("Failed to register in sensor service."
- "Continue to run task in non smart sensor mode.")
+ self.log.warning(
+ "Failed to register in sensor service." "Continue to run task in non smart sensor mode."
+ )
self.log.exception(e, exc_info=True)
if registered:
@@ -1255,9 +1258,10 @@ def signal_handler(signum, frame): # pylint: disable=unused-argument
end_time = time.time()
duration = timedelta(seconds=end_time - start_time)
- Stats.timing('dag.{dag_id}.{task_id}.duration'.format(dag_id=task_copy.dag_id,
- task_id=task_copy.task_id),
- duration)
+ Stats.timing(
+ f'dag.{task_copy.dag_id}.{task_copy.task_id}.duration',
+ duration,
+ )
Stats.incr(f'operator_successes_{self.task.task_type}', 1, 1)
Stats.incr('ti_successes')
@@ -1310,17 +1314,18 @@ def _run_execute_callback(self, context, task):
@provide_session
def run( # pylint: disable=too-many-arguments
- self,
- verbose: bool = True,
- ignore_all_deps: bool = False,
- ignore_depends_on_past: bool = False,
- ignore_task_deps: bool = False,
- ignore_ti_state: bool = False,
- mark_success: bool = False,
- test_mode: bool = False,
- job_id: Optional[str] = None,
- pool: Optional[str] = None,
- session=None) -> None:
+ self,
+ verbose: bool = True,
+ ignore_all_deps: bool = False,
+ ignore_depends_on_past: bool = False,
+ ignore_task_deps: bool = False,
+ ignore_ti_state: bool = False,
+ mark_success: bool = False,
+ test_mode: bool = False,
+ job_id: Optional[str] = None,
+ pool: Optional[str] = None,
+ session=None,
+ ) -> None:
"""Run TaskInstance"""
res = self.check_and_change_state_before_execution(
verbose=verbose,
@@ -1332,14 +1337,12 @@ def run( # pylint: disable=too-many-arguments
test_mode=test_mode,
job_id=job_id,
pool=pool,
- session=session)
+ session=session,
+ )
if res:
self._run_raw_task(
- mark_success=mark_success,
- test_mode=test_mode,
- job_id=job_id,
- pool=pool,
- session=session)
+ mark_success=mark_success, test_mode=test_mode, job_id=job_id, pool=pool, session=session
+ )
def dry_run(self):
"""Only Renders Templates for the TI"""
@@ -1351,11 +1354,7 @@ def dry_run(self):
task_copy.dry_run()
@provide_session
- def _handle_reschedule(self,
- actual_start_date,
- reschedule_exception,
- test_mode=False,
- session=None):
+ def _handle_reschedule(self, actual_start_date, reschedule_exception, test_mode=False, session=None):
# Don't record reschedule request in test mode
if test_mode:
return
@@ -1364,9 +1363,16 @@ def _handle_reschedule(self,
self.set_duration()
# Log reschedule request
- session.add(TaskReschedule(self.task, self.execution_date, self._try_number,
- actual_start_date, self.end_date,
- reschedule_exception.reschedule_date))
+ session.add(
+ TaskReschedule(
+ self.task,
+ self.execution_date,
+ self._try_number,
+ actual_start_date,
+ self.end_date,
+ reschedule_exception.reschedule_date,
+ )
+ )
# set state
self.state = State.UP_FOR_RESCHEDULE
@@ -1431,12 +1437,12 @@ def handle_failure(self, error, test_mode=None, context=None, force_fail=False,
self.task_id,
self._safe_date('execution_date', '%Y%m%dT%H%M%S'),
self._safe_date('start_date', '%Y%m%dT%H%M%S'),
- self._safe_date('end_date', '%Y%m%dT%H%M%S')
+ self._safe_date('end_date', '%Y%m%dT%H%M%S'),
)
if email_for_state and task.email:
try:
self.email_alert(error)
- except Exception as exec2: # pylint: disable=broad-except
+ except Exception as exec2: # pylint: disable=broad-except
self.log.error('Failed to send email to: %s', task.email)
self.log.exception(exec2)
@@ -1444,7 +1450,7 @@ def handle_failure(self, error, test_mode=None, context=None, force_fail=False,
if callback:
try:
callback(context)
- except Exception as exec3: # pylint: disable=broad-except
+ except Exception as exec3: # pylint: disable=broad-except
self.log.error("Failed at executing callback")
self.log.exception(exec3)
@@ -1475,11 +1481,10 @@ def get_template_context(self, session=None) -> Context: # pylint: disable=too-
if task.dag.params:
params.update(task.dag.params)
from airflow.models.dagrun import DagRun # Avoid circular import
+
dag_run = (
session.query(DagRun)
- .filter_by(
- dag_id=task.dag.dag_id,
- execution_date=self.execution_date)
+ .filter_by(dag_id=task.dag.dag_id, execution_date=self.execution_date)
.first()
)
run_id = dag_run.run_id if dag_run else None
@@ -1523,7 +1528,8 @@ def get_template_context(self, session=None) -> Context: # pylint: disable=too-
tomorrow_ds_nodash = tomorrow_ds.replace('-', '')
ti_key_str = "{dag_id}__{task_id}__{ds_nodash}".format(
- dag_id=task.dag_id, task_id=task.task_id, ds_nodash=ds_nodash)
+ dag_id=task.dag_id, task_id=task.task_id, ds_nodash=ds_nodash
+ )
if task.params:
params.update(task.params)
@@ -1584,7 +1590,7 @@ def __repr__(self):
def get(
item: str,
# pylint: disable=protected-access
- default_var: Any = Variable._Variable__NO_DEFAULT_SENTINEL, # noqa
+ default_var: Any = Variable._Variable__NO_DEFAULT_SENTINEL, # noqa
):
"""Get Airflow Variable after deserializing JSON value"""
return Variable.get(item, default_var=default_var, deserialize_json=True)
@@ -1607,9 +1613,11 @@ def get(
'prev_ds_nodash': prev_ds_nodash,
'prev_execution_date': prev_execution_date,
'prev_execution_date_success': lazy_object_proxy.Proxy(
- lambda: self.get_previous_execution_date(state=State.SUCCESS)),
+ lambda: self.get_previous_execution_date(state=State.SUCCESS)
+ ),
'prev_start_date_success': lazy_object_proxy.Proxy(
- lambda: self.get_previous_start_date(state=State.SUCCESS)),
+ lambda: self.get_previous_start_date(state=State.SUCCESS)
+ ),
'run_id': run_id,
'task': task,
'task_instance': self,
@@ -1632,6 +1640,7 @@ def get(
def get_rendered_template_fields(self):
"""Fetch rendered template fields from DB"""
from airflow.models.renderedtifields import RenderedTaskInstanceFields
+
rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self)
if rendered_task_instance_fields:
for field_name, rendered_value in rendered_task_instance_fields.items():
@@ -1693,15 +1702,18 @@ def get_email_subject_content(self, exception):
jinja_context = {'ti': self}
# This function is called after changing the state
# from State.RUNNING so need to subtract 1 from self.try_number.
- jinja_context.update(dict(
- exception=exception,
- exception_html=exception_html,
- try_number=self.try_number - 1,
- max_tries=self.max_tries))
+ jinja_context.update(
+ dict(
+ exception=exception,
+ exception_html=exception_html,
+ try_number=self.try_number - 1,
+ max_tries=self.max_tries,
+ )
+ )
jinja_env = jinja2.Environment(
- loader=jinja2.FileSystemLoader(os.path.dirname(__file__)),
- autoescape=True)
+ loader=jinja2.FileSystemLoader(os.path.dirname(__file__)), autoescape=True
+ )
subject = jinja_env.from_string(default_subject).render(**jinja_context)
html_content = jinja_env.from_string(default_html_content).render(**jinja_context)
html_content_err = jinja_env.from_string(default_html_content_err).render(**jinja_context)
@@ -1709,11 +1721,14 @@ def get_email_subject_content(self, exception):
else:
jinja_context = self.get_template_context()
- jinja_context.update(dict(
- exception=exception,
- exception_html=exception_html,
- try_number=self.try_number - 1,
- max_tries=self.max_tries))
+ jinja_context.update(
+ dict(
+ exception=exception,
+ exception_html=exception_html,
+ try_number=self.try_number - 1,
+ max_tries=self.max_tries,
+ )
+ )
jinja_env = self.task.get_template_env()
@@ -1772,8 +1787,8 @@ def xcom_push(
if execution_date and execution_date < self.execution_date:
raise ValueError(
'execution_date can not be in the past (current '
- 'execution_date is {}; received {})'.format(
- self.execution_date, execution_date))
+ 'execution_date is {}; received {})'.format(self.execution_date, execution_date)
+ )
XCom.set(
key=key,
@@ -1785,13 +1800,13 @@ def xcom_push(
)
@provide_session
- def xcom_pull( # pylint: disable=inconsistent-return-statements
+ def xcom_pull( # pylint: disable=inconsistent-return-statements
self,
task_ids: Optional[Union[str, Iterable[str]]] = None,
dag_id: Optional[str] = None,
key: str = XCOM_RETURN_KEY,
include_prior_dates: bool = False,
- session: Session = None
+ session: Session = None,
) -> Any:
"""
Pull XComs that optionally meet certain criteria.
@@ -1833,7 +1848,7 @@ def xcom_pull( # pylint: disable=inconsistent-return-statements
dag_ids=dag_id,
task_ids=task_ids,
include_prior_dates=include_prior_dates,
- session=session
+ session=session,
).with_entities(XCom.value)
# Since we're only fetching the values field, and not the
@@ -1851,11 +1866,15 @@ def xcom_pull( # pylint: disable=inconsistent-return-statements
def get_num_running_task_instances(self, session):
"""Return Number of running TIs from the DB"""
# .count() is inefficient
- return session.query(func.count()).filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id == self.task_id,
- TaskInstance.state == State.RUNNING
- ).scalar()
+ return (
+ session.query(func.count())
+ .filter(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.task_id == self.task_id,
+ TaskInstance.state == State.RUNNING,
+ )
+ .scalar()
+ )
def init_run_context(self, raw=False):
"""Sets the log context."""
@@ -1863,9 +1882,7 @@ def init_run_context(self, raw=False):
self._set_context(self)
@staticmethod
- def filter_for_tis(
- tis: Iterable[Union["TaskInstance", TaskInstanceKey]]
- ) -> Optional[BooleanClauseList]:
+ def filter_for_tis(tis: Iterable[Union["TaskInstance", TaskInstanceKey]]) -> Optional[BooleanClauseList]:
"""Returns SQLAlchemy filter to query selected task instances"""
if not tis:
return None
@@ -1895,7 +1912,8 @@ def filter_for_tis(
TaskInstance.dag_id == ti.dag_id,
TaskInstance.task_id == ti.task_id,
TaskInstance.execution_date == ti.execution_date,
- ) for ti in tis
+ )
+ for ti in tis
)
@@ -1993,7 +2011,8 @@ def construct_task_instance(self, session=None, lock_for_update=False) -> TaskIn
qry = session.query(TaskInstance).filter(
TaskInstance.dag_id == self._dag_id,
TaskInstance.task_id == self._task_id,
- TaskInstance.execution_date == self._execution_date)
+ TaskInstance.execution_date == self._execution_date,
+ )
if lock_for_update:
ti = qry.with_for_update().first()
@@ -2010,5 +2029,6 @@ def construct_task_instance(self, session=None, lock_for_update=False) -> TaskIn
from airflow.job.base_job import BaseJob
from airflow.models.dagrun import DagRun
+
TaskInstance.dag_run = relationship(DagRun)
TaskInstance.queued_by_job = relationship(BaseJob)
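Aside (not part of the patch): several hunks above join implicitly concatenated string literals onto a single line, which makes missing separator spaces easy to spot; the sensor-service warning, for instance, has no space between its two literals, so the logged message runs the sentences together. A standalone illustration in plain Python, with the literals copied from the hunk above:

# Adjacent string literals are concatenated at compile time, with no separator added.
warning_as_reformatted = (
    "Failed to register in sensor service."  # no trailing space here...
    "Continue to run task in non smart sensor mode."
)
warning_with_space = (
    "Failed to register in sensor service. "  # ...a trailing space keeps the sentences apart
    "Continue to run task in non smart sensor mode."
)

assert "service.Continue" in warning_as_reformatted
assert "service. Continue" in warning_with_space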
diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py
index 88e18cdea73c2..6ef99f87ce90f 100644
--- a/airflow/models/taskreschedule.py
+++ b/airflow/models/taskreschedule.py
@@ -39,17 +39,16 @@ class TaskReschedule(Base):
reschedule_date = Column(UtcDateTime, nullable=False)
__table_args__ = (
- Index('idx_task_reschedule_dag_task_date', dag_id, task_id, execution_date,
- unique=False),
- ForeignKeyConstraint([task_id, dag_id, execution_date],
- ['task_instance.task_id', 'task_instance.dag_id',
- 'task_instance.execution_date'],
- name='task_reschedule_dag_task_date_fkey',
- ondelete='CASCADE')
+ Index('idx_task_reschedule_dag_task_date', dag_id, task_id, execution_date, unique=False),
+ ForeignKeyConstraint(
+ [task_id, dag_id, execution_date],
+ ['task_instance.task_id', 'task_instance.dag_id', 'task_instance.execution_date'],
+ name='task_reschedule_dag_task_date_fkey',
+ ondelete='CASCADE',
+ ),
)
- def __init__(self, task, execution_date, try_number, start_date, end_date,
- reschedule_date):
+ def __init__(self, task, execution_date, try_number, start_date, end_date, reschedule_date):
self.dag_id = task.dag_id
self.task_id = task.task_id
self.execution_date = execution_date
@@ -73,13 +72,11 @@ def query_for_task_instance(task_instance, descending=False, session=None):
:type descending: bool
"""
TR = TaskReschedule
- qry = (
- session
- .query(TR)
- .filter(TR.dag_id == task_instance.dag_id,
- TR.task_id == task_instance.task_id,
- TR.execution_date == task_instance.execution_date,
- TR.try_number == task_instance.try_number)
+ qry = session.query(TR).filter(
+ TR.dag_id == task_instance.dag_id,
+ TR.task_id == task_instance.task_id,
+ TR.execution_date == task_instance.execution_date,
+ TR.try_number == task_instance.try_number,
)
if descending:
return qry.order_by(desc(TR.id))
diff --git a/airflow/models/variable.py b/airflow/models/variable.py
index a95e1c905492b..571f41658dc4a 100644
--- a/airflow/models/variable.py
+++ b/airflow/models/variable.py
@@ -77,7 +77,7 @@ def set_val(self, value):
self.is_encrypted = fernet.is_encrypted
@declared_attr
- def val(cls): # pylint: disable=no-self-argument
+ def val(cls): # pylint: disable=no-self-argument
"""Get Airflow Variable from Metadata DB and decode it using the Fernet Key"""
return synonym('_val', descriptor=property(cls.get_val, cls.set_val))
@@ -96,8 +96,7 @@ def setdefault(cls, key, default, deserialize_json=False):
and un-encode it when retrieving a value
:return: Mixed
"""
- obj = Variable.get(key, default_var=None,
- deserialize_json=deserialize_json)
+ obj = Variable.get(key, default_var=None, deserialize_json=deserialize_json)
if obj is None:
if default is not None:
Variable.set(key, default, serialize_json=deserialize_json)
@@ -135,13 +134,7 @@ def get(
@classmethod
@provide_session
- def set(
- cls,
- key: str,
- value: Any,
- serialize_json: bool = False,
- session: Session = None
- ):
+ def set(cls, key: str, value: Any, serialize_json: bool = False, session: Session = None):
"""
Sets a value for an Airflow Variable with a given Key
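Aside (not part of the patch): a short usage sketch of the Variable helpers whose signatures are collapsed above. It assumes a configured Airflow metadata database, and the key name is invented for illustration.

from airflow.models import Variable

# Store a dict as JSON and read it back; serialize_json/deserialize_json mirror
# the keyword arguments shown in the reformatted signatures above.
Variable.set("feature_flags", {"use_new_parser": True}, serialize_json=True)
flags = Variable.get("feature_flags", deserialize_json=True)
print(flags["use_new_parser"])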
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index 6300bb05a22b2..bd3ce168077ee 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -72,20 +72,12 @@ def init_on_load(self):
def __repr__(self):
        return '<XCom "{key}" ({task_id} @ {execution_date})>'.format(
- key=self.key,
- task_id=self.task_id,
- execution_date=self.execution_date)
+ key=self.key, task_id=self.task_id, execution_date=self.execution_date
+ )
@classmethod
@provide_session
- def set(
- cls,
- key,
- value,
- execution_date,
- task_id,
- dag_id,
- session=None):
+ def set(cls, key, value, execution_date, task_id, dag_id, session=None):
"""
Store an XCom value.
@@ -97,32 +89,27 @@ def set(
# remove any duplicate XComs
session.query(cls).filter(
- cls.key == key,
- cls.execution_date == execution_date,
- cls.task_id == task_id,
- cls.dag_id == dag_id).delete()
+ cls.key == key, cls.execution_date == execution_date, cls.task_id == task_id, cls.dag_id == dag_id
+ ).delete()
session.commit()
# insert new XCom
- session.add(XCom(
- key=key,
- value=value,
- execution_date=execution_date,
- task_id=task_id,
- dag_id=dag_id))
+ session.add(XCom(key=key, value=value, execution_date=execution_date, task_id=task_id, dag_id=dag_id))
session.commit()
@classmethod
@provide_session
- def get_one(cls,
- execution_date: pendulum.DateTime,
- key: Optional[str] = None,
- task_id: Optional[Union[str, Iterable[str]]] = None,
- dag_id: Optional[Union[str, Iterable[str]]] = None,
- include_prior_dates: bool = False,
- session: Session = None) -> Optional[Any]:
+ def get_one(
+ cls,
+ execution_date: pendulum.DateTime,
+ key: Optional[str] = None,
+ task_id: Optional[Union[str, Iterable[str]]] = None,
+ dag_id: Optional[Union[str, Iterable[str]]] = None,
+ include_prior_dates: bool = False,
+ session: Session = None,
+ ) -> Optional[Any]:
"""
Retrieve an XCom value, optionally meeting certain criteria. Returns None
        if there are no results.
@@ -145,26 +132,30 @@ def get_one(cls,
:param session: database session
:type session: sqlalchemy.orm.session.Session
"""
- result = cls.get_many(execution_date=execution_date,
- key=key,
- task_ids=task_id,
- dag_ids=dag_id,
- include_prior_dates=include_prior_dates,
- session=session).first()
+ result = cls.get_many(
+ execution_date=execution_date,
+ key=key,
+ task_ids=task_id,
+ dag_ids=dag_id,
+ include_prior_dates=include_prior_dates,
+ session=session,
+ ).first()
if result:
return result.value
return None
@classmethod
@provide_session
- def get_many(cls,
- execution_date: pendulum.DateTime,
- key: Optional[str] = None,
- task_ids: Optional[Union[str, Iterable[str]]] = None,
- dag_ids: Optional[Union[str, Iterable[str]]] = None,
- include_prior_dates: bool = False,
- limit: Optional[int] = None,
- session: Session = None) -> Query:
+ def get_many(
+ cls,
+ execution_date: pendulum.DateTime,
+ key: Optional[str] = None,
+ task_ids: Optional[Union[str, Iterable[str]]] = None,
+ dag_ids: Optional[Union[str, Iterable[str]]] = None,
+ include_prior_dates: bool = False,
+ limit: Optional[int] = None,
+ session: Session = None,
+ ) -> Query:
"""
Composes a query to get one or more values from the xcom table.
@@ -212,10 +203,11 @@ def get_many(cls,
else:
filters.append(cls.execution_date == execution_date)
- query = (session
- .query(cls)
- .filter(and_(*filters))
- .order_by(cls.execution_date.desc(), cls.timestamp.desc()))
+ query = (
+ session.query(cls)
+ .filter(and_(*filters))
+ .order_by(cls.execution_date.desc(), cls.timestamp.desc())
+ )
if limit:
return query.limit(limit)
@@ -230,9 +222,7 @@ def delete(cls, xcoms, session=None):
xcoms = [xcoms]
for xcom in xcoms:
if not isinstance(xcom, XCom):
- raise TypeError(
- f'Expected XCom; received {xcom.__class__.__name__}'
- )
+ raise TypeError(f'Expected XCom; received {xcom.__class__.__name__}')
session.delete(xcom)
session.commit()
@@ -244,10 +234,12 @@ def serialize_value(value: Any):
try:
return json.dumps(value).encode('UTF-8')
except (ValueError, TypeError):
- log.error("Could not serialize the XCOM value into JSON. "
- "If you are using pickles instead of JSON "
- "for XCOM, then you need to enable pickle "
- "support for XCOM in your airflow config.")
+ log.error(
+ "Could not serialize the XCOM value into JSON. "
+ "If you are using pickles instead of JSON "
+ "for XCOM, then you need to enable pickle "
+ "support for XCOM in your airflow config."
+ )
raise
@staticmethod
@@ -259,10 +251,12 @@ def deserialize_value(result) -> Any:
try:
return json.loads(result.value.decode('UTF-8'))
except JSONDecodeError:
- log.error("Could not deserialize the XCOM value from JSON. "
- "If you are using pickles instead of JSON "
- "for XCOM, then you need to enable pickle "
- "support for XCOM in your airflow config.")
+ log.error(
+ "Could not deserialize the XCOM value from JSON. "
+ "If you are using pickles instead of JSON "
+ "for XCOM, then you need to enable pickle "
+ "support for XCOM in your airflow config."
+ )
raise
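Aside (not part of the patch): a minimal sketch of the XCom classmethods reflowed above, using only the parameters visible in the hunks. It assumes an initialised Airflow metadata database; the dag/task ids and key are illustrative.

import pendulum

from airflow.models.xcom import XCom

when = pendulum.datetime(2020, 10, 7, tz="UTC")

# set() removes any duplicate row for the same key/date/task/dag and inserts the new value.
XCom.set(key="rows_loaded", value=42, execution_date=when, task_id="load", dag_id="example_dag")

# get_one() is a thin wrapper around get_many(...).first() and returns the value, or None.
print(XCom.get_one(execution_date=when, key="rows_loaded", task_id="load", dag_id="example_dag"))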
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index aa29864fbe351..d09de69fd90e2 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -63,8 +63,7 @@ def __init__(self, operator: BaseOperator, key: str = XCOM_RETURN_KEY):
self._key = key
def __eq__(self, other):
- return (self.operator == other.operator
- and self.key == other.key)
+ return self.operator == other.operator and self.key == other.key
def __getitem__(self, item):
"""Implements xcomresult['some_result_key']"""
@@ -80,9 +79,10 @@ def __str__(self):
:return:
"""
- xcom_pull_kwargs = [f"task_ids='{self.operator.task_id}'",
- f"dag_id='{self.operator.dag.dag_id}'",
- ]
+ xcom_pull_kwargs = [
+ f"task_ids='{self.operator.task_id}'",
+ f"dag_id='{self.operator.dag.dag_id}'",
+ ]
if self.key is not None:
xcom_pull_kwargs.append(f"key='{self.key}'")
@@ -132,7 +132,8 @@ def resolve(self, context: Dict) -> Any:
if not resolved_value:
raise AirflowException(
f'XComArg result from {self.operator.task_id} at {self.operator.dag.dag_id} '
- f'with key="{self.key}"" is not found!')
+ f'with key="{self.key}"" is not found!'
+ )
resolved_value = resolved_value[0]
return resolved_value
diff --git a/airflow/operators/bash.py b/airflow/operators/bash.py
index 7f99a609bb9cf..6522025a3f7fc 100644
--- a/airflow/operators/bash.py
+++ b/airflow/operators/bash.py
@@ -103,17 +103,21 @@ class BashOperator(BaseOperator):
template_fields = ('bash_command', 'env')
template_fields_renderers = {'bash_command': 'bash', 'env': 'json'}
- template_ext = ('.sh', '.bash',)
+ template_ext = (
+ '.sh',
+ '.bash',
+ )
ui_color = '#f0ede4'
@apply_defaults
def __init__(
- self,
- *,
- bash_command: str,
- env: Optional[Dict[str, str]] = None,
- output_encoding: str = 'utf-8',
- **kwargs) -> None:
+ self,
+ *,
+ bash_command: str,
+ env: Optional[Dict[str, str]] = None,
+ output_encoding: str = 'utf-8',
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.bash_command = bash_command
@@ -136,9 +140,10 @@ def execute(self, context):
env = os.environ.copy()
airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
- self.log.debug('Exporting the following env vars:\n%s',
- '\n'.join([f"{k}={v}"
- for k, v in airflow_context_vars.items()]))
+ self.log.debug(
+ 'Exporting the following env vars:\n%s',
+ '\n'.join([f"{k}={v}" for k, v in airflow_context_vars.items()]),
+ )
env.update(airflow_context_vars)
with TemporaryDirectory(prefix='airflowtmp') as tmp_dir:
@@ -158,7 +163,8 @@ def pre_exec():
stderr=STDOUT,
cwd=tmp_dir,
env=env,
- preexec_fn=pre_exec)
+ preexec_fn=pre_exec,
+ )
self.log.info('Output:')
line = ''
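Aside (not part of the patch): the exploded template_ext tuple above is Black's "magic trailing comma" at work rather than wasted vertical space. A small sketch of the rule, assuming default Black settings:

# With a trailing comma after the last element, Black keeps one element per line:
template_ext_exploded = (
    '.sh',
    '.bash',
)

# Without the trailing comma (and within the line length), Black collapses the tuple:
template_ext_collapsed = ('.sh', '.bash')

assert template_ext_exploded == template_ext_collapsed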
diff --git a/airflow/operators/bash_operator.py b/airflow/operators/bash_operator.py
index 328771b37aa8b..49cdecf81c14c 100644
--- a/airflow/operators/bash_operator.py
+++ b/airflow/operators/bash_operator.py
@@ -23,6 +23,5 @@
from airflow.operators.bash import STDOUT, BashOperator, Popen, gettempdir # noqa
warnings.warn(
- "This module is deprecated. Please use `airflow.operators.bash`.",
- DeprecationWarning, stacklevel=2
+ "This module is deprecated. Please use `airflow.operators.bash`.", DeprecationWarning, stacklevel=2
)
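Aside (not part of the patch): the deprecation shims above and below all follow the same warnings.warn pattern, so one hedged sketch covers them; stacklevel controls which stack frame the warning is attributed to. The function name below is invented for illustration.

import warnings

warnings.simplefilter("default")  # make sure DeprecationWarning is not filtered out


def legacy_entry_point():
    # Stand-in for a deprecated module body; the message is copied from the hunk above.
    warnings.warn(
        "This module is deprecated. Please use `airflow.operators.bash`.",
        DeprecationWarning,
        stacklevel=2,  # report the caller's line, not this warn() call
    )


legacy_entry_point()  # with stacklevel=2 the warning points at this call site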
diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py
index 6b7ea555c5fdd..d18a25c65765b 100644
--- a/airflow/operators/check_operator.py
+++ b/airflow/operators/check_operator.py
@@ -21,12 +21,14 @@
import warnings
from airflow.operators.sql import (
- SQLCheckOperator, SQLIntervalCheckOperator, SQLThresholdCheckOperator, SQLValueCheckOperator,
+ SQLCheckOperator,
+ SQLIntervalCheckOperator,
+ SQLThresholdCheckOperator,
+ SQLValueCheckOperator,
)
warnings.warn(
- "This module is deprecated. Please use `airflow.operators.sql`.",
- DeprecationWarning, stacklevel=2
+ "This module is deprecated. Please use `airflow.operators.sql`.", DeprecationWarning, stacklevel=2
)
@@ -40,7 +42,8 @@ def __init__(self, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.operators.sql.SQLCheckOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
@@ -55,7 +58,8 @@ def __init__(self, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.operators.sql.SQLIntervalCheckOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
@@ -70,7 +74,8 @@ def __init__(self, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.operators.sql.SQLThresholdCheckOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
@@ -85,6 +90,7 @@ def __init__(self, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.operators.sql.SQLValueCheckOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
diff --git a/airflow/operators/dagrun_operator.py b/airflow/operators/dagrun_operator.py
index b926081def5df..0ca0cf5335cc9 100644
--- a/airflow/operators/dagrun_operator.py
+++ b/airflow/operators/dagrun_operator.py
@@ -73,7 +73,7 @@ def __init__(
conf: Optional[Dict] = None,
execution_date: Optional[Union[str, datetime.datetime]] = None,
reset_dag_run: bool = False,
- **kwargs
+ **kwargs,
) -> None:
super().__init__(**kwargs)
self.trigger_dag_id = trigger_dag_id
@@ -119,10 +119,7 @@ def execute(self, context: Dict):
if dag_model is None:
raise DagNotFound(f"Dag id {self.trigger_dag_id} not found in DagModel")
- dag_bag = DagBag(
- dag_folder=dag_model.fileloc,
- read_dags_from_db=True
- )
+ dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
dag = dag_bag.get_dag(self.trigger_dag_id)
diff --git a/airflow/operators/docker_operator.py b/airflow/operators/docker_operator.py
index 245ee589c0196..eacdf9b7280d2 100644
--- a/airflow/operators/docker_operator.py
+++ b/airflow/operators/docker_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.docker.operators.docker`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/druid_check_operator.py b/airflow/operators/druid_check_operator.py
index aad52ee431650..d8f30c3727384 100644
--- a/airflow/operators/druid_check_operator.py
+++ b/airflow/operators/druid_check_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.druid.operators.druid_check`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/email.py b/airflow/operators/email.py
index 80d11310d1a13..c979eaff2066a 100644
--- a/airflow/operators/email.py
+++ b/airflow/operators/email.py
@@ -52,16 +52,18 @@ class EmailOperator(BaseOperator):
@apply_defaults
def __init__( # pylint: disable=invalid-name
- self, *,
- to: Union[List[str], str],
- subject: str,
- html_content: str,
- files: Optional[List] = None,
- cc: Optional[Union[List[str], str]] = None,
- bcc: Optional[Union[List[str], str]] = None,
- mime_subtype: str = 'mixed',
- mime_charset: str = 'utf-8',
- **kwargs) -> None:
+ self,
+ *,
+ to: Union[List[str], str],
+ subject: str,
+ html_content: str,
+ files: Optional[List] = None,
+ cc: Optional[Union[List[str], str]] = None,
+ bcc: Optional[Union[List[str], str]] = None,
+ mime_subtype: str = 'mixed',
+ mime_charset: str = 'utf-8',
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.to = to # pylint: disable=invalid-name
self.subject = subject
@@ -73,6 +75,13 @@ def __init__( # pylint: disable=invalid-name
self.mime_charset = mime_charset
def execute(self, context):
- send_email(self.to, self.subject, self.html_content,
- files=self.files, cc=self.cc, bcc=self.bcc,
- mime_subtype=self.mime_subtype, mime_charset=self.mime_charset)
+ send_email(
+ self.to,
+ self.subject,
+ self.html_content,
+ files=self.files,
+ cc=self.cc,
+ bcc=self.bcc,
+ mime_subtype=self.mime_subtype,
+ mime_charset=self.mime_charset,
+ )
diff --git a/airflow/operators/email_operator.py b/airflow/operators/email_operator.py
index 135eb68f13f0b..b3588941ec8fd 100644
--- a/airflow/operators/email_operator.py
+++ b/airflow/operators/email_operator.py
@@ -23,6 +23,5 @@
from airflow.operators.email import EmailOperator # noqa
warnings.warn(
- "This module is deprecated. Please use `airflow.operators.email`.",
- DeprecationWarning, stacklevel=2
+ "This module is deprecated. Please use `airflow.operators.email`.", DeprecationWarning, stacklevel=2
)
diff --git a/airflow/operators/gcs_to_s3.py b/airflow/operators/gcs_to_s3.py
index 19affc7c48553..5c9fb37f9f604 100644
--- a/airflow/operators/gcs_to_s3.py
+++ b/airflow/operators/gcs_to_s3.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.gcs_to_s3`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/generic_transfer.py b/airflow/operators/generic_transfer.py
index 59f8fdcd86c57..a166b3a7a47a8 100644
--- a/airflow/operators/generic_transfer.py
+++ b/airflow/operators/generic_transfer.py
@@ -45,19 +45,23 @@ class GenericTransfer(BaseOperator):
"""
template_fields = ('sql', 'destination_table', 'preoperator')
- template_ext = ('.sql', '.hql',)
+ template_ext = (
+ '.sql',
+ '.hql',
+ )
ui_color = '#b0f07c'
@apply_defaults
def __init__(
- self,
- *,
- sql: str,
- destination_table: str,
- source_conn_id: str,
- destination_conn_id: str,
- preoperator: Optional[Union[str, List[str]]] = None,
- **kwargs) -> None:
+ self,
+ *,
+ sql: str,
+ destination_table: str,
+ source_conn_id: str,
+ destination_conn_id: str,
+ preoperator: Optional[Union[str, List[str]]] = None,
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.sql = sql
self.destination_table = destination_table
diff --git a/airflow/operators/google_api_to_s3_transfer.py b/airflow/operators/google_api_to_s3_transfer.py
index 6f6d26500ef94..31b47fa0f08d4 100644
--- a/airflow/operators/google_api_to_s3_transfer.py
+++ b/airflow/operators/google_api_to_s3_transfer.py
@@ -25,9 +25,9 @@
from airflow.providers.amazon.aws.transfers.google_api_to_s3 import GoogleApiToS3Operator
warnings.warn(
- "This module is deprecated. "
- "Please use `airflow.providers.amazon.aws.transfers.google_api_to_s3`.",
- DeprecationWarning, stacklevel=2
+ "This module is deprecated. " "Please use `airflow.providers.amazon.aws.transfers.google_api_to_s3`.",
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -42,8 +42,9 @@ def __init__(self, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use
- `airflow.providers.amazon.aws.transfers.""" +
- "google_api_to_s3_transfer.GoogleApiToS3Operator`.",
- DeprecationWarning, stacklevel=3
+ `airflow.providers.amazon.aws.transfers."""
+ + "google_api_to_s3_transfer.GoogleApiToS3Operator`.",
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
diff --git a/airflow/operators/hive_operator.py b/airflow/operators/hive_operator.py
index 196a863de7831..7429733e77403 100644
--- a/airflow/operators/hive_operator.py
+++ b/airflow/operators/hive_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hive.operators.hive`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/hive_stats_operator.py b/airflow/operators/hive_stats_operator.py
index 4529d06cfeefb..737ccc4cd7ead 100644
--- a/airflow/operators/hive_stats_operator.py
+++ b/airflow/operators/hive_stats_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hive.operators.hive_stats`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/hive_to_druid.py b/airflow/operators/hive_to_druid.py
index 88a57cba0dc22..6d29f0dbed00a 100644
--- a/airflow/operators/hive_to_druid.py
+++ b/airflow/operators/hive_to_druid.py
@@ -27,7 +27,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.druid.transfers.hive_to_druid`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -43,6 +44,7 @@ def __init__(self, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.apache.druid.transfers.hive_to_druid.HiveToDruidOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
diff --git a/airflow/operators/hive_to_mysql.py b/airflow/operators/hive_to_mysql.py
index 6264e57e61caa..bc42d89ec036c 100644
--- a/airflow/operators/hive_to_mysql.py
+++ b/airflow/operators/hive_to_mysql.py
@@ -27,7 +27,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hive.transfers.hive_to_mysql`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -43,6 +44,7 @@ def __init__(self, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.apache.hive.transfers.hive_to_mysql.HiveToMySqlOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
diff --git a/airflow/operators/hive_to_samba_operator.py b/airflow/operators/hive_to_samba_operator.py
index d1d1e942e9ef1..3716bf90e9b3e 100644
--- a/airflow/operators/hive_to_samba_operator.py
+++ b/airflow/operators/hive_to_samba_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hive.transfers.hive_to_samba`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/http_operator.py b/airflow/operators/http_operator.py
index edecef3c345d3..34e938144f009 100644
--- a/airflow/operators/http_operator.py
+++ b/airflow/operators/http_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.http.operators.http`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/jdbc_operator.py b/airflow/operators/jdbc_operator.py
index 3361c4783824c..c0bf7021eab1a 100644
--- a/airflow/operators/jdbc_operator.py
+++ b/airflow/operators/jdbc_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.jdbc.operators.jdbc`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/latest_only.py b/airflow/operators/latest_only.py
index d19e1b047efaa..e458c72e28e3f 100644
--- a/airflow/operators/latest_only.py
+++ b/airflow/operators/latest_only.py
@@ -44,17 +44,17 @@ def choose_branch(self, context: Dict) -> Union[str, Iterable[str]]:
# If the DAG Run is externally triggered, then return without
# skipping downstream tasks
if context['dag_run'] and context['dag_run'].external_trigger:
- self.log.info(
- "Externally triggered DAG_Run: allowing execution to proceed.")
+ self.log.info("Externally triggered DAG_Run: allowing execution to proceed.")
return list(context['task'].get_direct_relative_ids(upstream=False))
now = pendulum.now('UTC')
- left_window = context['dag'].following_schedule(
- context['execution_date'])
+ left_window = context['dag'].following_schedule(context['execution_date'])
right_window = context['dag'].following_schedule(left_window)
self.log.info(
'Checking latest only with left_window: %s right_window: %s now: %s',
- left_window, right_window, now
+ left_window,
+ right_window,
+ now,
)
if not left_window < now <= right_window:
diff --git a/airflow/operators/latest_only_operator.py b/airflow/operators/latest_only_operator.py
index af6c90e19397e..cb1a531260305 100644
--- a/airflow/operators/latest_only_operator.py
+++ b/airflow/operators/latest_only_operator.py
@@ -22,6 +22,5 @@
from airflow.operators.latest_only import LatestOnlyOperator # noqa
warnings.warn(
- "This module is deprecated. Please use `airflow.operators.latest_only`.",
- DeprecationWarning, stacklevel=2
+ "This module is deprecated. Please use `airflow.operators.latest_only`.", DeprecationWarning, stacklevel=2
)
diff --git a/airflow/operators/mssql_operator.py b/airflow/operators/mssql_operator.py
index 07a20192a51a2..8b5e8cc9caee3 100644
--- a/airflow/operators/mssql_operator.py
+++ b/airflow/operators/mssql_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.mssql.operators.mssql`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/mssql_to_hive.py b/airflow/operators/mssql_to_hive.py
index 3fd6f7496280d..8749d709d9de6 100644
--- a/airflow/operators/mssql_to_hive.py
+++ b/airflow/operators/mssql_to_hive.py
@@ -26,7 +26,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hive.transfers.mssql_to_hive`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -42,6 +43,7 @@ def __init__(self, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.apache.hive.transfers.mssql_to_hive.MsSqlToHiveOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
diff --git a/airflow/operators/mysql_operator.py b/airflow/operators/mysql_operator.py
index 6b2b1849ecbb7..05ae818cc485a 100644
--- a/airflow/operators/mysql_operator.py
+++ b/airflow/operators/mysql_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.mysql.operators.mysql`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/mysql_to_hive.py b/airflow/operators/mysql_to_hive.py
index 47735bb8d679c..243980f55f187 100644
--- a/airflow/operators/mysql_to_hive.py
+++ b/airflow/operators/mysql_to_hive.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hive.transfers.mysql_to_hive`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.apache.hive.transfers.mysql_to_hive.MySqlToHiveOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
diff --git a/airflow/operators/oracle_operator.py b/airflow/operators/oracle_operator.py
index 7b5e5a1e4dc9c..1734d22d259a2 100644
--- a/airflow/operators/oracle_operator.py
+++ b/airflow/operators/oracle_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.oracle.operators.oracle`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/papermill_operator.py b/airflow/operators/papermill_operator.py
index f071866a12bad..ab2d4fcc316e4 100644
--- a/airflow/operators/papermill_operator.py
+++ b/airflow/operators/papermill_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.papermill.operators.papermill`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/pig_operator.py b/airflow/operators/pig_operator.py
index 82e49452b6df5..4eec846a1f3e6 100644
--- a/airflow/operators/pig_operator.py
+++ b/airflow/operators/pig_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.pig.operators.pig`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/postgres_operator.py b/airflow/operators/postgres_operator.py
index 815311c162d93..c6f91b6bc1455 100644
--- a/airflow/operators/postgres_operator.py
+++ b/airflow/operators/postgres_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.postgres.operators.postgres`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/presto_check_operator.py b/airflow/operators/presto_check_operator.py
index 07546c297a889..79cc12ba7e5ff 100644
--- a/airflow/operators/presto_check_operator.py
+++ b/airflow/operators/presto_check_operator.py
@@ -23,8 +23,7 @@
from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator # noqa
warnings.warn(
- "This module is deprecated. Please use `airflow.operators.sql`.",
- DeprecationWarning, stacklevel=2
+ "This module is deprecated. Please use `airflow.operators.sql`.", DeprecationWarning, stacklevel=2
)
@@ -38,7 +37,8 @@ def __init__(self, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.operators.sql.SQLCheckOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
@@ -55,7 +55,8 @@ def __init__(self, **kwargs):
            This class is deprecated.
Please use `airflow.operators.sql.SQLIntervalCheckOperator`.
""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
@@ -72,6 +73,7 @@ def __init__(self, **kwargs):
            This class is deprecated.
Please use `airflow.operators.sql.SQLValueCheckOperator`.
""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
diff --git a/airflow/operators/presto_to_mysql.py b/airflow/operators/presto_to_mysql.py
index 6a50d6bf6b694..f5ccd423067dd 100644
--- a/airflow/operators/presto_to_mysql.py
+++ b/airflow/operators/presto_to_mysql.py
@@ -27,7 +27,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.mysql.transfers.presto_to_mysql`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -43,6 +44,7 @@ def __init__(self, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.mysql.transfers.presto_to_mysql.PrestoToMySqlOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 344f0ecaceb4b..08546cabc6e6e 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -73,7 +73,10 @@ class PythonOperator(BaseOperator):
# since we won't mutate the arguments, we should just do the shallow copy
# there are some cases we can't deepcopy the objects(e.g protobuf).
- shallow_copy_attrs = ('python_callable', 'op_kwargs',)
+ shallow_copy_attrs = (
+ 'python_callable',
+ 'op_kwargs',
+ )
@apply_defaults
def __init__(
@@ -84,11 +87,14 @@ def __init__(
op_kwargs: Optional[Dict] = None,
templates_dict: Optional[Dict] = None,
templates_exts: Optional[List[str]] = None,
- **kwargs
+ **kwargs,
) -> None:
if kwargs.get("provide_context"):
- warnings.warn("provide_context is deprecated as of 2.0 and is no longer required",
- DeprecationWarning, stacklevel=2)
+ warnings.warn(
+ "provide_context is deprecated as of 2.0 and is no longer required",
+ DeprecationWarning,
+ stacklevel=2,
+ )
kwargs.pop('provide_context', None)
super().__init__(**kwargs)
if not callable(python_callable):
@@ -101,9 +107,7 @@ def __init__(
self.template_ext = templates_exts
@staticmethod
- def determine_op_kwargs(python_callable: Callable,
- context: Dict,
- num_op_args: int = 0) -> Dict:
+ def determine_op_kwargs(python_callable: Callable, context: Dict, num_op_args: int = 0) -> Dict:
"""
Function that will inspect the signature of a python_callable to determine which
values need to be passed to the function.
@@ -190,7 +194,7 @@ def __init__(
op_args: Tuple[Any],
op_kwargs: Dict[str, Any],
multiple_outputs: bool = False,
- **kwargs
+ **kwargs,
) -> None:
kwargs['task_id'] = self._get_unique_task_id(task_id, kwargs.get('dag'))
super().__init__(**kwargs)
@@ -221,9 +225,11 @@ def _get_unique_task_id(task_id: str, dag: Optional[DAG] = None) -> str:
return task_id
core = re.split(r'__\d+$', task_id)[0]
suffixes = sorted(
- [int(re.split(r'^.+__', task_id)[1])
- for task_id in dag.task_ids
- if re.match(rf'^{core}__\d+$', task_id)]
+ [
+ int(re.split(r'^.+__', task_id)[1])
+ for task_id in dag.task_ids
+ if re.match(rf'^{core}__\d+$', task_id)
+ ]
)
if not suffixes:
return f'{core}__1'
@@ -251,13 +257,16 @@ def execute(self, context: Dict):
if isinstance(return_value, dict):
for key in return_value.keys():
if not isinstance(key, str):
- raise AirflowException('Returned dictionary keys must be strings when using '
- f'multiple_outputs, found {key} ({type(key)}) instead')
+ raise AirflowException(
+ 'Returned dictionary keys must be strings when using '
+ f'multiple_outputs, found {key} ({type(key)}) instead'
+ )
for key, value in return_value.items():
self.xcom_push(context, key, value)
else:
- raise AirflowException(f'Returned output was type {type(return_value)} expected dictionary '
- 'for multiple_outputs')
+ raise AirflowException(
+ f'Returned output was type {type(return_value)} expected dictionary ' 'for multiple_outputs'
+ )
return return_value
@@ -265,9 +274,7 @@ def execute(self, context: Dict):
def task(
- python_callable: Optional[Callable] = None,
- multiple_outputs: bool = False,
- **kwargs
+ python_callable: Optional[Callable] = None, multiple_outputs: bool = False, **kwargs
) -> Callable[[T], T]:
"""
Python operator decorator. Wraps a function into an Airflow operator.
@@ -282,6 +289,7 @@ def task(
:type multiple_outputs: bool
"""
+
def wrapper(f: T):
"""
Python wrapper to generate PythonDecoratedOperator out of simple python functions.
@@ -292,12 +300,19 @@ def wrapper(f: T):
@functools.wraps(f)
def factory(*args, **f_kwargs):
- op = _PythonDecoratedOperator(python_callable=f, op_args=args, op_kwargs=f_kwargs,
- multiple_outputs=multiple_outputs, **kwargs)
+ op = _PythonDecoratedOperator(
+ python_callable=f,
+ op_args=args,
+ op_kwargs=f_kwargs,
+ multiple_outputs=multiple_outputs,
+ **kwargs,
+ )
if f.__doc__:
op.doc_md = f.__doc__
return XComArg(op)
+
return cast(T, factory)
+
if callable(python_callable):
return wrapper(python_callable)
elif python_callable is not None:
@@ -427,22 +442,16 @@ class PythonVirtualenvOperator(PythonOperator):
'ts_nodash',
'ts_nodash_with_tz',
'yesterday_ds',
- 'yesterday_ds_nodash'
+ 'yesterday_ds_nodash',
}
PENDULUM_SERIALIZABLE_CONTEXT_KEYS = {
'execution_date',
'next_execution_date',
'prev_execution_date',
'prev_execution_date_success',
- 'prev_start_date_success'
- }
- AIRFLOW_SERIALIZABLE_CONTEXT_KEYS = {
- 'macros',
- 'conf',
- 'dag',
- 'dag_run',
- 'task'
+ 'prev_start_date_success',
}
+ AIRFLOW_SERIALIZABLE_CONTEXT_KEYS = {'macros', 'conf', 'dag', 'dag_run', 'task'}
@apply_defaults
def __init__( # pylint: disable=too-many-arguments
@@ -458,26 +467,31 @@ def __init__( # pylint: disable=too-many-arguments
string_args: Optional[Iterable[str]] = None,
templates_dict: Optional[Dict] = None,
templates_exts: Optional[List[str]] = None,
- **kwargs
+ **kwargs,
):
if (
- not isinstance(python_callable, types.FunctionType) or
-            isinstance(python_callable, types.LambdaType) and python_callable.__name__ == "<lambda>"
+ not isinstance(python_callable, types.FunctionType)
+ or isinstance(python_callable, types.LambdaType)
+            and python_callable.__name__ == "<lambda>"
):
raise AirflowException('PythonVirtualenvOperator only supports functions for python_callable arg')
if (
- python_version and str(python_version)[0] != str(sys.version_info.major) and
- (op_args or op_kwargs)
+ python_version
+ and str(python_version)[0] != str(sys.version_info.major)
+ and (op_args or op_kwargs)
):
- raise AirflowException("Passing op_args or op_kwargs is not supported across different Python "
- "major versions for PythonVirtualenvOperator. Please use string_args.")
+ raise AirflowException(
+ "Passing op_args or op_kwargs is not supported across different Python "
+ "major versions for PythonVirtualenvOperator. Please use string_args."
+ )
super().__init__(
python_callable=python_callable,
op_args=op_args,
op_kwargs=op_kwargs,
templates_dict=templates_dict,
templates_exts=templates_exts,
- **kwargs)
+ **kwargs,
+ )
self.requirements = list(requirements or [])
self.string_args = string_args or []
self.python_version = python_version
@@ -505,7 +519,7 @@ def execute_callable(self):
venv_directory=tmp_dir,
python_bin=f'python{self.python_version}' if self.python_version else None,
system_site_packages=self.system_site_packages,
- requirements=self.requirements
+ requirements=self.requirements,
)
self._write_args(input_filename)
@@ -516,18 +530,20 @@ def execute_callable(self):
op_kwargs=self.op_kwargs,
pickling_library=self.pickling_library.__name__,
python_callable=self.python_callable.__name__,
- python_callable_source=dedent(inspect.getsource(self.python_callable))
+ python_callable_source=dedent(inspect.getsource(self.python_callable)),
),
- filename=script_filename
+ filename=script_filename,
)
- execute_in_subprocess(cmd=[
- f'{tmp_dir}/bin/python',
- script_filename,
- input_filename,
- output_filename,
- string_args_filename
- ])
+ execute_in_subprocess(
+ cmd=[
+ f'{tmp_dir}/bin/python',
+ script_filename,
+ input_filename,
+ output_filename,
+ string_args_filename,
+ ]
+ )
return self._read_result(output_filename)
@@ -561,8 +577,10 @@ def _read_result(self, filename):
try:
return self.pickling_library.load(file)
except ValueError:
- self.log.error("Error deserializing result. Note that result deserialization "
- "is not supported across major Python versions.")
+ self.log.error(
+ "Error deserializing result. Note that result deserialization "
+ "is not supported across major Python versions."
+ )
raise
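Aside (not part of the patch): a tiny sketch of the task decorator whose wrapper is reformatted above. Inside a DAG context each call to the decorated function builds a _PythonDecoratedOperator and returns an XComArg, and multiple_outputs=True requires a dict with string keys, as the execute() hunk above enforces. The DAG id and dates below are illustrative.

from datetime import datetime

from airflow import DAG
from airflow.operators.python import task


@task(multiple_outputs=True)
def split_name(full_name: str):
    first, last = full_name.split(" ", 1)
    return {"first": first, "last": last}  # each key is pushed as a separate XCom


with DAG("example_task_decorator", start_date=datetime(2020, 10, 7), schedule_interval=None) as dag:
    parts = split_name("Ada Lovelace")  # XComArg usable as input to downstream tasks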
diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py
index 5318c15cc08ed..7ae3d4fa78c80 100644
--- a/airflow/operators/python_operator.py
+++ b/airflow/operators/python_operator.py
@@ -21,10 +21,12 @@
# pylint: disable=unused-import
from airflow.operators.python import ( # noqa
- BranchPythonOperator, PythonOperator, PythonVirtualenvOperator, ShortCircuitOperator,
+ BranchPythonOperator,
+ PythonOperator,
+ PythonVirtualenvOperator,
+ ShortCircuitOperator,
)
warnings.warn(
- "This module is deprecated. Please use `airflow.operators.python`.",
- DeprecationWarning, stacklevel=2
+ "This module is deprecated. Please use `airflow.operators.python`.", DeprecationWarning, stacklevel=2
)
diff --git a/airflow/operators/redshift_to_s3_operator.py b/airflow/operators/redshift_to_s3_operator.py
index a1ebadb4a54cc..20b92306432f2 100644
--- a/airflow/operators/redshift_to_s3_operator.py
+++ b/airflow/operators/redshift_to_s3_operator.py
@@ -26,7 +26,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.redshift_to_s3`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -42,6 +43,7 @@ def __init__(self, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.amazon.aws.transfers.redshift_to_s3.RedshiftToS3Operator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
diff --git a/airflow/operators/s3_file_transform_operator.py b/airflow/operators/s3_file_transform_operator.py
index 8b175a8c7f688..6616778495225 100644
--- a/airflow/operators/s3_file_transform_operator.py
+++ b/airflow/operators/s3_file_transform_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_file_transform`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/s3_to_hive_operator.py b/airflow/operators/s3_to_hive_operator.py
index d7108cc072304..584a31f37c0ed 100644
--- a/airflow/operators/s3_to_hive_operator.py
+++ b/airflow/operators/s3_to_hive_operator.py
@@ -23,7 +23,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hive.transfers.s3_to_hive`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -37,6 +38,7 @@ def __init__(self, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.apache.hive.transfers.s3_to_hive.S3ToHiveOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
diff --git a/airflow/operators/s3_to_redshift_operator.py b/airflow/operators/s3_to_redshift_operator.py
index 2148fdb429b39..b885c58c3e6ce 100644
--- a/airflow/operators/s3_to_redshift_operator.py
+++ b/airflow/operators/s3_to_redshift_operator.py
@@ -26,7 +26,8 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.s3_to_redshift`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
@@ -42,6 +43,7 @@ def __init__(self, **kwargs):
"""This class is deprecated.
Please use
`airflow.providers.amazon.aws.transfers.s3_to_redshift.S3ToRedshiftOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
diff --git a/airflow/operators/slack_operator.py b/airflow/operators/slack_operator.py
index e154d9f40b576..dedd7699bc94f 100644
--- a/airflow/operators/slack_operator.py
+++ b/airflow/operators/slack_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.slack.operators.slack`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py
index 45cb07c2c131e..3210ff1ae0148 100644
--- a/airflow/operators/sql.py
+++ b/airflow/operators/sql.py
@@ -71,13 +71,14 @@ class SQLCheckOperator(BaseOperator):
"""
template_fields: Iterable[str] = ("sql",)
- template_ext: Iterable[str] = (".hql", ".sql",)
+ template_ext: Iterable[str] = (
+ ".hql",
+ ".sql",
+ )
ui_color = "#fff7e6"
@apply_defaults
- def __init__(
- self, *, sql: str, conn_id: Optional[str] = None, **kwargs
- ) -> None:
+ def __init__(self, *, sql: str, conn_id: Optional[str] = None, **kwargs) -> None:
super().__init__(**kwargs)
self.conn_id = conn_id
self.sql = sql
@@ -90,11 +91,7 @@ def execute(self, context=None):
if not records:
raise AirflowException("The query returned None")
elif not all(bool(r) for r in records):
- raise AirflowException(
- "Test failed.\nQuery:\n{query}\nResults:\n{records!s}".format(
- query=self.sql, records=records
- )
- )
+ raise AirflowException(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}")
self.log.info("Success.")
@@ -148,7 +145,8 @@ class SQLValueCheckOperator(BaseOperator):
@apply_defaults
def __init__(
- self, *,
+ self,
+ *,
sql: str,
pass_value: Any,
tolerance: Any = None,
@@ -191,9 +189,7 @@ def execute(self, context=None):
try:
numeric_records = self._to_float(records)
except (ValueError, TypeError):
- raise AirflowException(
- f"Converting a result to float failed.\n{error_msg}"
- )
+ raise AirflowException(f"Converting a result to float failed.\n{error_msg}")
tests = self._get_numeric_matches(numeric_records, pass_value_conv)
else:
tests = []
@@ -257,7 +253,10 @@ class SQLIntervalCheckOperator(BaseOperator):
__mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"}
template_fields: Iterable[str] = ("sql1", "sql2")
- template_ext: Iterable[str] = (".hql", ".sql",)
+ template_ext: Iterable[str] = (
+ ".hql",
+ ".sql",
+ )
ui_color = "#fff7e6"
ratio_formulas = {
@@ -267,7 +266,8 @@ class SQLIntervalCheckOperator(BaseOperator):
@apply_defaults
def __init__(
- self, *,
+ self,
+ *,
table: str,
metrics_thresholds: Dict[str, int],
date_filter_column: Optional[str] = "ds",
@@ -279,15 +279,10 @@ def __init__(
):
super().__init__(**kwargs)
if ratio_formula not in self.ratio_formulas:
- msg_template = (
- "Invalid diff_method: {diff_method}. "
- "Supported diff methods are: {diff_methods}"
- )
+ msg_template = "Invalid diff_method: {diff_method}. Supported diff methods are: {diff_methods}"
raise AirflowException(
- msg_template.format(
- diff_method=ratio_formula, diff_methods=self.ratio_formulas
- )
+ msg_template.format(diff_method=ratio_formula, diff_methods=self.ratio_formulas)
)
self.ratio_formula = ratio_formula
self.ignore_zero = ignore_zero
@@ -332,9 +327,7 @@ def execute(self, context=None):
ratios[metric] = None
test_results[metric] = self.ignore_zero
else:
- ratios[metric] = self.ratio_formulas[self.ratio_formula](
- current[metric], reference[metric]
- )
+ ratios[metric] = self.ratio_formulas[self.ratio_formula](current[metric], reference[metric])
test_results[metric] = ratios[metric] < threshold
self.log.info(
@@ -368,9 +361,7 @@ def execute(self, context=None):
self.metrics_thresholds[k],
)
raise AirflowException(
- "The following tests have failed:\n {}".format(
- ", ".join(sorted(failed_tests))
- )
+ "The following tests have failed:\n {}".format(", ".join(sorted(failed_tests)))
)
self.log.info("All tests have passed")
@@ -411,7 +402,8 @@ class SQLThresholdCheckOperator(BaseOperator):
@apply_defaults
def __init__(
- self, *,
+ self,
+ *,
sql: str,
min_threshold: Any,
max_threshold: Any,
@@ -500,7 +492,8 @@ class BranchSQLOperator(BaseOperator, SkipMixin):
@apply_defaults
def __init__(
- self, *,
+ self,
+ *,
sql: str,
follow_task_ids_if_true: List[str],
follow_task_ids_if_false: List[str],
@@ -525,7 +518,9 @@ def _get_hook(self):
if conn.conn_type not in ALLOWED_CONN_TYPE:
raise AirflowException(
"The connection type is not supported by BranchSQLOperator.\
- Supported connection types: {}".format(list(ALLOWED_CONN_TYPE))
+ Supported connection types: {}".format(
+ list(ALLOWED_CONN_TYPE)
+ )
)
if not self._hook:
@@ -540,22 +535,16 @@ def execute(self, context: Dict):
self._hook = self._get_hook()
if self._hook is None:
- raise AirflowException(
- "Failed to establish connection to '%s'" % self.conn_id
- )
+ raise AirflowException("Failed to establish connection to '%s'" % self.conn_id)
if self.sql is None:
raise AirflowException("Expected 'sql' parameter is missing.")
if self.follow_task_ids_if_true is None:
- raise AirflowException(
- "Expected 'follow_task_ids_if_true' parameter is missing."
- )
+ raise AirflowException("Expected 'follow_task_ids_if_true' parameter is missing.")
if self.follow_task_ids_if_false is None:
- raise AirflowException(
- "Expected 'follow_task_ids_if_false' parameter is missing."
- )
+ raise AirflowException("Expected 'follow_task_ids_if_false' parameter is missing.")
self.log.info(
"Executing: %s (with parameters %s) with connection: %s",
@@ -595,16 +584,14 @@ def execute(self, context: Dict):
follow_branch = self.follow_task_ids_if_true
else:
raise AirflowException(
- "Unexpected query return result '%s' type '%s'"
- % (query_result, type(query_result))
+ "Unexpected query return result '{}' type '{}'".format(query_result, type(query_result))
)
if follow_branch is None:
follow_branch = self.follow_task_ids_if_false
except ValueError:
raise AirflowException(
- "Unexpected query return result '%s' type '%s'"
- % (query_result, type(query_result))
+ "Unexpected query return result '{}' type '{}'".format(query_result, type(query_result))
)
self.skip_all_except(context["ti"], follow_branch)
diff --git a/airflow/operators/sql_branch_operator.py b/airflow/operators/sql_branch_operator.py
index 3c773b6109040..5dfa7b36ab2f6 100644
--- a/airflow/operators/sql_branch_operator.py
+++ b/airflow/operators/sql_branch_operator.py
@@ -20,8 +20,7 @@
from airflow.operators.sql import BranchSQLOperator
warnings.warn(
- """This module is deprecated. Please use `airflow.operators.sql`.""",
- DeprecationWarning, stacklevel=2
+ """This module is deprecated. Please use `airflow.operators.sql`.""", DeprecationWarning, stacklevel=2
)
@@ -35,6 +34,7 @@ def __init__(self, **kwargs):
warnings.warn(
"""This class is deprecated.
Please use `airflow.operators.sql.BranchSQLOperator`.""",
- DeprecationWarning, stacklevel=3
+ DeprecationWarning,
+ stacklevel=3,
)
super().__init__(**kwargs)
diff --git a/airflow/operators/sqlite_operator.py b/airflow/operators/sqlite_operator.py
index 6729852a428e7..daa496da70f2f 100644
--- a/airflow/operators/sqlite_operator.py
+++ b/airflow/operators/sqlite_operator.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.sqlite.operators.sqlite`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/operators/subdag_operator.py b/airflow/operators/subdag_operator.py
index 7ef0392de68b8..f5c147e940439 100644
--- a/airflow/operators/subdag_operator.py
+++ b/airflow/operators/subdag_operator.py
@@ -63,13 +63,15 @@ class SubDagOperator(BaseSensorOperator):
@provide_session
@apply_defaults
- def __init__(self,
- *,
- subdag: DAG,
- session: Optional[Session] = None,
- conf: Optional[Dict] = None,
- propagate_skipped_state: Optional[SkippedStatePropagationOptions] = None,
- **kwargs) -> None:
+ def __init__(
+ self,
+ *,
+ subdag: DAG,
+ session: Optional[Session] = None,
+ conf: Optional[Dict] = None,
+ propagate_skipped_state: Optional[SkippedStatePropagationOptions] = None,
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.subdag = subdag
self.conf = conf
@@ -88,7 +90,8 @@ def _validate_dag(self, kwargs):
raise AirflowException(
"The subdag's dag_id should have the form '{{parent_dag_id}}.{{this_task_id}}'. "
"Expected '{d}.{t}'; received '{rcvd}'.".format(
- d=dag.dag_id, t=kwargs['task_id'], rcvd=self.subdag.dag_id)
+ d=dag.dag_id, t=kwargs['task_id'], rcvd=self.subdag.dag_id
+ )
)
def _validate_pool(self, session):
@@ -96,11 +99,7 @@ def _validate_pool(self, session):
conflicts = [t for t in self.subdag.tasks if t.pool == self.pool]
if conflicts:
# only query for pool conflicts if one may exist
- pool = (session
- .query(Pool)
- .filter(Pool.slots == 1)
- .filter(Pool.pool == self.pool)
- .first())
+ pool = session.query(Pool).filter(Pool.slots == 1).filter(Pool.pool == self.pool).first()
if pool and any(t.pool == self.pool for t in self.subdag.tasks):
raise AirflowException(
'SubDagOperator {sd} and subdag task{plural} {t} both '
@@ -109,7 +108,7 @@ def _validate_pool(self, session):
sd=self.task_id,
plural=len(conflicts) > 1,
t=', '.join(t.task_id for t in conflicts),
- p=self.pool
+ p=self.pool,
)
)
@@ -132,11 +131,12 @@ def _reset_dag_run_and_task_instances(self, dag_run, execution_date):
with create_session() as session:
dag_run.state = State.RUNNING
session.merge(dag_run)
- failed_task_instances = (session
- .query(TaskInstance)
- .filter(TaskInstance.dag_id == self.subdag.dag_id)
- .filter(TaskInstance.execution_date == execution_date)
- .filter(TaskInstance.state.in_([State.FAILED, State.UPSTREAM_FAILED])))
+ failed_task_instances = (
+ session.query(TaskInstance)
+ .filter(TaskInstance.dag_id == self.subdag.dag_id)
+ .filter(TaskInstance.execution_date == execution_date)
+ .filter(TaskInstance.state.in_([State.FAILED, State.UPSTREAM_FAILED]))
+ )
for task_instance in failed_task_instances:
task_instance.state = State.NONE
@@ -172,9 +172,7 @@ def post_execute(self, context, result=None):
self.log.info("Execution finished. State is %s", dag_run.state)
if dag_run.state != State.SUCCESS:
- raise AirflowException(
- f"Expected state: SUCCESS. Actual state: {dag_run.state}"
- )
+ raise AirflowException(f"Expected state: SUCCESS. Actual state: {dag_run.state}")
if self.propagate_skipped_state and self._check_skipped_states(context):
self._skip_downstream_tasks(context)
@@ -187,16 +185,15 @@ def _check_skipped_states(self, context):
if self.propagate_skipped_state == SkippedStatePropagationOptions.ALL_LEAVES:
return all(ti.state == State.SKIPPED for ti in leaves_tis)
raise AirflowException(
- f'Unimplemented SkippedStatePropagationOptions {self.propagate_skipped_state} used.')
+ f'Unimplemented SkippedStatePropagationOptions {self.propagate_skipped_state} used.'
+ )
def _get_leaves_tis(self, execution_date):
leaves_tis = []
for leaf in self.subdag.leaves:
try:
ti = get_task_instance(
- dag_id=self.subdag.dag_id,
- task_id=leaf.task_id,
- execution_date=execution_date
+ dag_id=self.subdag.dag_id, task_id=leaf.task_id, execution_date=execution_date
)
leaves_tis.append(ti)
except TaskInstanceNotFound:
@@ -204,8 +201,11 @@ def _get_leaves_tis(self, execution_date):
return leaves_tis
def _skip_downstream_tasks(self, context):
- self.log.info('Skipping downstream tasks because propagate_skipped_state is set to %s '
- 'and skipped task(s) were found.', self.propagate_skipped_state)
+ self.log.info(
+ 'Skipping downstream tasks because propagate_skipped_state is set to %s '
+ 'and skipped task(s) were found.',
+ self.propagate_skipped_state,
+ )
downstream_tasks = context['task'].downstream_list
self.log.debug('Downstream task_ids %s', downstream_tasks)
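The query hunks above show how Black treats parenthesized method chains: a chain that still fits within 110 characters is collapsed onto one line (the Pool lookup), while one that does not is split at the dots, one call per line inside the parentheses (the TaskInstance query). A rough stand-in sketch, using a hypothetical Query class so the example runs without SQLAlchemy:

    class Query:
        def __init__(self, rows):
            self._rows = list(rows)

        def filter(self, predicate):
            # Each .filter() returns a new Query, so calls can be chained.
            return Query(r for r in self._rows if predicate(r))

        def all(self):
            return self._rows


    # Short chain: fits within the 110-character limit, so it stays on one line.
    small = Query(range(10)).filter(lambda r: r % 2 == 0).all()

    # Long chain: kept inside parentheses and split at the dots, matching the
    # fluent style produced for the TaskInstance query above.
    large = (
        Query(range(1000))
        .filter(lambda r: r % 2 == 0)
        .filter(lambda r: r % 3 == 0)
        .filter(lambda r: r > 100)
        .all()
    )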
diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index f6fd08b6c0e2e..99c05b9fce6b3 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -160,9 +160,9 @@ def is_valid_plugin(plugin_obj):
global plugins # pylint: disable=global-statement
if (
- inspect.isclass(plugin_obj) and
- issubclass(plugin_obj, AirflowPlugin) and
- (plugin_obj is not AirflowPlugin)
+ inspect.isclass(plugin_obj)
+ and issubclass(plugin_obj, AirflowPlugin)
+ and (plugin_obj is not AirflowPlugin)
):
plugin_obj.validate()
return plugin_obj not in plugins
@@ -202,8 +202,7 @@ def load_plugins_from_plugin_directory():
global plugins # pylint: disable=global-statement
log.debug("Loading plugins from directory: %s", settings.PLUGINS_FOLDER)
- for file_path in find_path_from_directory(
- settings.PLUGINS_FOLDER, ".airflowignore"):
+ for file_path in find_path_from_directory(settings.PLUGINS_FOLDER, ".airflowignore"):
if not os.path.isfile(file_path):
continue
@@ -239,9 +238,11 @@ def make_module(name: str, objects: List[Any]):
name = name.lower()
module = types.ModuleType(name)
module._name = name.split('.')[-1] # type: ignore
- module._objects = objects # type: ignore
+ module._objects = objects # type: ignore
module.__dict__.update((o.__name__, o) for o in objects)
return module
+
+
# pylint: enable=protected-access
@@ -277,9 +278,11 @@ def initialize_web_ui_plugins():
global flask_appbuilder_menu_links
# pylint: enable=global-statement
- if flask_blueprints is not None and \
- flask_appbuilder_views is not None and \
- flask_appbuilder_menu_links is not None:
+ if (
+ flask_blueprints is not None
+ and flask_appbuilder_views is not None
+ and flask_appbuilder_menu_links is not None
+ ):
return
ensure_plugins_loaded()
@@ -296,17 +299,15 @@ def initialize_web_ui_plugins():
for plugin in plugins:
flask_appbuilder_views.extend(plugin.appbuilder_views)
flask_appbuilder_menu_links.extend(plugin.appbuilder_menu_items)
- flask_blueprints.extend([{
- 'name': plugin.name,
- 'blueprint': bp
- } for bp in plugin.flask_blueprints])
+ flask_blueprints.extend([{'name': plugin.name, 'blueprint': bp} for bp in plugin.flask_blueprints])
if (plugin.admin_views and not plugin.appbuilder_views) or (
- plugin.menu_links and not plugin.appbuilder_menu_items):
+ plugin.menu_links and not plugin.appbuilder_menu_items
+ ):
log.warning(
"Plugin \'%s\' may not be compatible with the current Airflow version. "
"Please contact the author of the plugin.",
- plugin.name
+ plugin.name,
)
@@ -318,9 +319,11 @@ def initialize_extra_operators_links_plugins():
global registered_operator_link_classes
# pylint: enable=global-statement
- if global_operator_extra_links is not None and \
- operator_extra_links is not None and \
- registered_operator_link_classes is not None:
+ if (
+ global_operator_extra_links is not None
+ and operator_extra_links is not None
+ and registered_operator_link_classes is not None
+ ):
return
ensure_plugins_loaded()
@@ -338,11 +341,12 @@ def initialize_extra_operators_links_plugins():
global_operator_extra_links.extend(plugin.global_operator_extra_links)
operator_extra_links.extend(list(plugin.operator_extra_links))
- registered_operator_link_classes.update({
- "{}.{}".format(link.__class__.__module__,
- link.__class__.__name__): link.__class__
- for link in plugin.operator_extra_links
- })
+ registered_operator_link_classes.update(
+ {
+ f"{link.__class__.__module__}.{link.__class__.__name__}": link.__class__
+ for link in plugin.operator_extra_links
+ }
+ )
def integrate_executor_plugins() -> None:
@@ -384,10 +388,12 @@ def integrate_dag_plugins() -> None:
global macros_modules
# pylint: enable=global-statement
- if operators_modules is not None and \
- sensors_modules is not None and \
- hooks_modules is not None and \
- macros_modules is not None:
+ if (
+ operators_modules is not None
+ and sensors_modules is not None
+ and hooks_modules is not None
+ and macros_modules is not None
+ ):
return
ensure_plugins_loaded()
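Throughout plugins_manager.py the backslash-continued conditions become a parenthesized expression with each boolean operator starting its own line, which is how Black breaks a condition that no longer fits on one line. A minimal sketch with placeholder names (the *_modules identifiers here are just parameters, not the real module globals):

    def plugins_already_integrated(operators_modules, sensors_modules, hooks_modules, macros_modules) -> bool:
        # The condition exceeds the 110-character limit, so Black wraps it in
        # parentheses and puts each "and" at the start of its own line.
        return (
            operators_modules is not None
            and sensors_modules is not None
            and hooks_modules is not None
            and macros_modules is not None
        )


    print(plugins_already_integrated(None, None, None, None))  # -> False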
diff --git a/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py b/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py
index 9fd2a2034e0ba..87d6aa6221c69 100644
--- a/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py
+++ b/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py
@@ -17,9 +17,7 @@
import os
from airflow import models
-from airflow.providers.amazon.aws.operators.glacier import (
- GlacierCreateJobOperator,
-)
+from airflow.providers.amazon.aws.operators.glacier import GlacierCreateJobOperator
from airflow.providers.amazon.aws.sensors.glacier import GlacierJobOperationSensor
from airflow.providers.amazon.aws.transfers.glacier_to_gcs import GlacierToGCSOperator
from airflow.utils.dates import days_ago
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index b16ea435f7f01..12878377aed17 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -29,8 +29,8 @@
from typing import Any, Dict, Optional, Tuple, Union
import boto3
-from botocore.credentials import ReadOnlyCredentials
from botocore.config import Config
+from botocore.credentials import ReadOnlyCredentials
from cached_property import cached_property
from airflow.exceptions import AirflowException
diff --git a/airflow/providers/amazon/aws/hooks/cloud_formation.py b/airflow/providers/amazon/aws/hooks/cloud_formation.py
index 610cded6fec9b..ef20a3083610a 100644
--- a/airflow/providers/amazon/aws/hooks/cloud_formation.py
+++ b/airflow/providers/amazon/aws/hooks/cloud_formation.py
@@ -19,8 +19,8 @@
"""This module contains AWS CloudFormation Hook"""
from typing import Optional, Union
-from botocore.exceptions import ClientError
from boto3 import client, resource
+from botocore.exceptions import ClientError
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
diff --git a/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py b/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py
index 54305d5b86129..d6ddb9cb661c5 100644
--- a/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py
+++ b/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py
@@ -15,9 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Optional
-
from time import sleep
+from typing import Optional
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
diff --git a/airflow/providers/amazon/aws/hooks/glacier.py b/airflow/providers/amazon/aws/hooks/glacier.py
index 8b2f239f162e5..c2e0509ef2f8a 100644
--- a/airflow/providers/amazon/aws/hooks/glacier.py
+++ b/airflow/providers/amazon/aws/hooks/glacier.py
@@ -18,6 +18,7 @@
from typing import Any, Dict
+
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
diff --git a/airflow/providers/amazon/aws/hooks/glue_catalog.py b/airflow/providers/amazon/aws/hooks/glue_catalog.py
index 65d44666bc0f0..32c4d7c6635e4 100644
--- a/airflow/providers/amazon/aws/hooks/glue_catalog.py
+++ b/airflow/providers/amazon/aws/hooks/glue_catalog.py
@@ -17,7 +17,7 @@
# under the License.
"""This module contains AWS Glue Catalog Hook"""
-from typing import Set, Optional
+from typing import Optional, Set
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py
index 9190d27e1eb69..bd51813e25716 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -22,7 +22,7 @@
import time
import warnings
from functools import partial
-from typing import Dict, List, Optional, Set, Any, Callable, Generator
+from typing import Any, Callable, Dict, Generator, List, Optional, Set
from botocore.exceptions import ClientError
diff --git a/airflow/providers/amazon/aws/hooks/secrets_manager.py b/airflow/providers/amazon/aws/hooks/secrets_manager.py
index 117c83d9c7dba..3d0289a3a7eb5 100644
--- a/airflow/providers/amazon/aws/hooks/secrets_manager.py
+++ b/airflow/providers/amazon/aws/hooks/secrets_manager.py
@@ -20,6 +20,7 @@
import base64
import json
from typing import Union
+
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
diff --git a/airflow/providers/amazon/aws/hooks/sns.py b/airflow/providers/amazon/aws/hooks/sns.py
index 795e78a961723..af1ce6ec27594 100644
--- a/airflow/providers/amazon/aws/hooks/sns.py
+++ b/airflow/providers/amazon/aws/hooks/sns.py
@@ -18,7 +18,7 @@
"""This module contains AWS SNS hook"""
import json
-from typing import Optional, Union, Dict
+from typing import Dict, Optional, Union
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index 82868e6aeaa9a..07a5b4a4f77d0 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -26,7 +26,7 @@
- http://boto3.readthedocs.io/en/latest/reference/services/batch.html
- https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html
"""
-from typing import Dict, Optional, Any
+from typing import Any, Dict, Optional
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
diff --git a/airflow/providers/amazon/aws/operators/datasync.py b/airflow/providers/amazon/aws/operators/datasync.py
index 661054afd85b7..aa7a6b9455cee 100644
--- a/airflow/providers/amazon/aws/operators/datasync.py
+++ b/airflow/providers/amazon/aws/operators/datasync.py
@@ -19,7 +19,7 @@
import logging
import random
-from typing import Optional, List
+from typing import List, Optional
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py
index 2ec170667c284..9584a31d77271 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -18,7 +18,7 @@
import re
import sys
from datetime import datetime
-from typing import Optional, Dict
+from typing import Dict, Optional
from botocore.waiter import Waiter
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_transform.py b/airflow/providers/amazon/aws/operators/sagemaker_transform.py
index 1dadb3daae6ce..7caf9f1beee69 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_transform.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_transform.py
@@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Optional, List
+from typing import List, Optional
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
diff --git a/airflow/providers/amazon/aws/sensors/emr_base.py b/airflow/providers/amazon/aws/sensors/emr_base.py
index d862c6b4efc12..0fa0878457578 100644
--- a/airflow/providers/amazon/aws/sensors/emr_base.py
+++ b/airflow/providers/amazon/aws/sensors/emr_base.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Optional, Iterable
+from typing import Any, Dict, Iterable, Optional
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrHook
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_training.py b/airflow/providers/amazon/aws/sensors/sagemaker_training.py
index 9cd76688facdf..302069ef9c5be 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker_training.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker_training.py
@@ -15,9 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Optional
-
import time
+from typing import Optional
from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
diff --git a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
index 0d95ee7861d9c..47b6c1998d362 100644
--- a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -25,7 +25,7 @@
from copy import copy
from os.path import getsize
from tempfile import NamedTemporaryFile
-from typing import Any, Callable, Dict, Optional, IO
+from typing import IO, Any, Callable, Dict, Optional
from uuid import uuid4
from airflow.models import BaseOperator
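The import reorderings in these provider hunks come from isort rather than Black. Two patterns recur: straight "import x" statements precede "from x import y" statements within a section, with modules alphabetized, and the names inside a from-import are sorted with all-caps names ahead of the rest, which is why IO lands before Any above. A small sketch consistent with those hunks; the third-party and airflow lines are only there to show the blank-line section breaks:

    import time
    from typing import IO, Any, Callable, Dict, Optional  # all-caps first, then the rest alphabetized

    from botocore.exceptions import ClientError  # third-party section after stdlib

    from airflow.models import BaseOperator  # first-party airflow section comes last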
diff --git a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
index 48559f8588779..4b552e4a575e4 100644
--- a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
@@ -17,7 +17,7 @@
# under the License.
"""This module contains Google Cloud Storage to S3 operator."""
import warnings
-from typing import Iterable, Optional, Sequence, Union, Dict, List, cast
+from typing import Dict, Iterable, List, Optional, Sequence, Union, cast
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
diff --git a/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py b/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py
index 3506003181ba6..5ee9802f1a87e 100644
--- a/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py
+++ b/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
import tempfile
-from typing import Optional, Union, Sequence
+from typing import Optional, Sequence, Union
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.glacier import GlacierHook
diff --git a/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py b/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
index 16568dba9ddf8..2939be57a2096 100644
--- a/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
+++ b/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
@@ -19,7 +19,7 @@
"""This module contains operator to move data from Hive to DynamoDB."""
import json
-from typing import Optional, Callable
+from typing import Callable, Optional
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.dynamodb import AwsDynamoDBHook
diff --git a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
index b88087e291a51..0855fe99e3908 100644
--- a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
import json
-from typing import Optional, Any, Iterable, Union, cast
+from typing import Any, Iterable, Optional, Union, cast
from bson import json_util
diff --git a/airflow/providers/amazon/backport_provider_setup.py b/airflow/providers/amazon/backport_provider_setup.py
index 7c7054264c664..f21a080fc8829 100644
--- a/airflow/providers/amazon/backport_provider_setup.py
+++ b/airflow/providers/amazon/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/apache/cassandra/backport_provider_setup.py b/airflow/providers/apache/cassandra/backport_provider_setup.py
index 73632d72fdbfd..db239d0c4b6d3 100644
--- a/airflow/providers/apache/cassandra/backport_provider_setup.py
+++ b/airflow/providers/apache/cassandra/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/apache/druid/backport_provider_setup.py b/airflow/providers/apache/druid/backport_provider_setup.py
index cd9f547e65181..6674a2214890b 100644
--- a/airflow/providers/apache/druid/backport_provider_setup.py
+++ b/airflow/providers/apache/druid/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/apache/hdfs/backport_provider_setup.py b/airflow/providers/apache/hdfs/backport_provider_setup.py
index 9252805189dc7..2815d167df70e 100644
--- a/airflow/providers/apache/hdfs/backport_provider_setup.py
+++ b/airflow/providers/apache/hdfs/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/apache/hive/backport_provider_setup.py b/airflow/providers/apache/hive/backport_provider_setup.py
index 9fdd0f292bc5b..e817ba7bcbce3 100644
--- a/airflow/providers/apache/hive/backport_provider_setup.py
+++ b/airflow/providers/apache/hive/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/apache/kylin/backport_provider_setup.py b/airflow/providers/apache/kylin/backport_provider_setup.py
index 9b3e4ac996752..61a4bc7917f67 100644
--- a/airflow/providers/apache/kylin/backport_provider_setup.py
+++ b/airflow/providers/apache/kylin/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/apache/livy/backport_provider_setup.py b/airflow/providers/apache/livy/backport_provider_setup.py
index a5ee750127c73..27024e91dd367 100644
--- a/airflow/providers/apache/livy/backport_provider_setup.py
+++ b/airflow/providers/apache/livy/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/apache/pig/backport_provider_setup.py b/airflow/providers/apache/pig/backport_provider_setup.py
index 9fa0d6fdde6cc..35150daedb124 100644
--- a/airflow/providers/apache/pig/backport_provider_setup.py
+++ b/airflow/providers/apache/pig/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/apache/pinot/backport_provider_setup.py b/airflow/providers/apache/pinot/backport_provider_setup.py
index 6b0836b949ebf..b626b84c86767 100644
--- a/airflow/providers/apache/pinot/backport_provider_setup.py
+++ b/airflow/providers/apache/pinot/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/apache/spark/backport_provider_setup.py b/airflow/providers/apache/spark/backport_provider_setup.py
index ccb99e2f61af8..40eb4bb4184e7 100644
--- a/airflow/providers/apache/spark/backport_provider_setup.py
+++ b/airflow/providers/apache/spark/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/apache/sqoop/backport_provider_setup.py b/airflow/providers/apache/sqoop/backport_provider_setup.py
index ac933f5547714..ff4ebe2c9f7e7 100644
--- a/airflow/providers/apache/sqoop/backport_provider_setup.py
+++ b/airflow/providers/apache/sqoop/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/celery/backport_provider_setup.py b/airflow/providers/celery/backport_provider_setup.py
index 68fcc3b06aaaf..4b050fc0ceaa8 100644
--- a/airflow/providers/celery/backport_provider_setup.py
+++ b/airflow/providers/celery/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/cloudant/backport_provider_setup.py b/airflow/providers/cloudant/backport_provider_setup.py
index 643c7cbbe8266..92bfb7e1522b8 100644
--- a/airflow/providers/cloudant/backport_provider_setup.py
+++ b/airflow/providers/cloudant/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/cncf/kubernetes/backport_provider_setup.py b/airflow/providers/cncf/kubernetes/backport_provider_setup.py
index b9d647bacc93b..fe620b92def92 100644
--- a/airflow/providers/cncf/kubernetes/backport_provider_setup.py
+++ b/airflow/providers/cncf/kubernetes/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index 95f7962569863..7c203a8d87c60 100644
--- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tempfile
-from typing import Generator, Optional, Tuple, Union, Any
+from typing import Any, Generator, Optional, Tuple, Union
import yaml
from cached_property import cached_property
diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
index 4bf38a720556b..fed4b5bbf7210 100644
--- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -16,21 +16,20 @@
# under the License.
"""Executes task in a Kubernetes POD"""
import re
-from typing import Dict, Iterable, List, Optional, Tuple, Any
+from typing import Any, Dict, Iterable, List, Optional, Tuple
import yaml
-from kubernetes.client import CoreV1Api
-from kubernetes.client import models as k8s
+from kubernetes.client import CoreV1Api, models as k8s
from airflow.exceptions import AirflowException
from airflow.kubernetes import kube_client, pod_generator, pod_launcher
+from airflow.kubernetes.pod_generator import PodGenerator
from airflow.kubernetes.secret import Secret
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from airflow.utils.helpers import validate_key
from airflow.utils.state import State
from airflow.version import version as airflow_version
-from airflow.kubernetes.pod_generator import PodGenerator
class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
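The kubernetes_pod.py hunk above also shows imports from the same module being merged, including an aliased submodule, and the first-party PodGenerator import being moved up into the alphabetized airflow block instead of trailing after it. A one-line sketch of the merged form, assuming the kubernetes client package is installed:

    from kubernetes.client import CoreV1Api, models as k8s  # "models as k8s" keeps the alias the operator code expects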
diff --git a/airflow/providers/databricks/backport_provider_setup.py b/airflow/providers/databricks/backport_provider_setup.py
index 4b90c4c2a77ff..7dabf94d54b26 100644
--- a/airflow/providers/databricks/backport_provider_setup.py
+++ b/airflow/providers/databricks/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index 623502373f9c8..ef71c2c9bf8d9 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -26,7 +26,7 @@
from urllib.parse import urlparse
import requests
-from requests import exceptions as requests_exceptions, PreparedRequest
+from requests import PreparedRequest, exceptions as requests_exceptions
from requests.auth import AuthBase
from airflow import __version__
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 7602de526adda..fe753bc0921c9 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -19,7 +19,7 @@
"""This module contains Databricks operators."""
import time
-from typing import Union, Optional, Any, Dict, List
+from typing import Any, Dict, List, Optional, Union
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
diff --git a/airflow/providers/datadog/backport_provider_setup.py b/airflow/providers/datadog/backport_provider_setup.py
index 282bb26b2bd57..02e757c7689e4 100644
--- a/airflow/providers/datadog/backport_provider_setup.py
+++ b/airflow/providers/datadog/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/dingding/backport_provider_setup.py b/airflow/providers/dingding/backport_provider_setup.py
index efeec2daf495c..78482b2446d91 100644
--- a/airflow/providers/dingding/backport_provider_setup.py
+++ b/airflow/providers/dingding/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/dingding/hooks/dingding.py b/airflow/providers/dingding/hooks/dingding.py
index 45163d7be8769..4bd37e41328bb 100644
--- a/airflow/providers/dingding/hooks/dingding.py
+++ b/airflow/providers/dingding/hooks/dingding.py
@@ -17,7 +17,7 @@
# under the License.
import json
-from typing import Union, Optional, List
+from typing import List, Optional, Union
import requests
from requests import Session
diff --git a/airflow/providers/dingding/operators/dingding.py b/airflow/providers/dingding/operators/dingding.py
index 8b5f57a66c585..6be1bd3ad0040 100644
--- a/airflow/providers/dingding/operators/dingding.py
+++ b/airflow/providers/dingding/operators/dingding.py
@@ -15,10 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Union, Optional, List
+from typing import List, Optional, Union
from airflow.models import BaseOperator
-
from airflow.providers.dingding.hooks.dingding import DingdingHook
from airflow.utils.decorators import apply_defaults
diff --git a/airflow/providers/discord/backport_provider_setup.py b/airflow/providers/discord/backport_provider_setup.py
index 1563c75c58053..363d4d090ab9c 100644
--- a/airflow/providers/discord/backport_provider_setup.py
+++ b/airflow/providers/discord/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/docker/backport_provider_setup.py b/airflow/providers/docker/backport_provider_setup.py
index 4a88da8f3f484..42cedfb68e1c0 100644
--- a/airflow/providers/docker/backport_provider_setup.py
+++ b/airflow/providers/docker/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/elasticsearch/backport_provider_setup.py b/airflow/providers/elasticsearch/backport_provider_setup.py
index 013efe4985b14..27ba27cf84d74 100644
--- a/airflow/providers/elasticsearch/backport_provider_setup.py
+++ b/airflow/providers/elasticsearch/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/exasol/backport_provider_setup.py b/airflow/providers/exasol/backport_provider_setup.py
index 4dcef0e04ffd6..1be0b3ffaf81f 100644
--- a/airflow/providers/exasol/backport_provider_setup.py
+++ b/airflow/providers/exasol/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py
index 1b01222c8f306..935eb8b1b5fd9 100644
--- a/airflow/providers/exasol/hooks/exasol.py
+++ b/airflow/providers/exasol/hooks/exasol.py
@@ -17,7 +17,7 @@
# under the License.
from contextlib import closing
-from typing import Union, Optional, List, Tuple, Any
+from typing import Any, List, Optional, Tuple, Union
import pyexasol
from pyexasol import ExaConnection
diff --git a/airflow/providers/facebook/backport_provider_setup.py b/airflow/providers/facebook/backport_provider_setup.py
index b9a1d398333db..0262eb87d71c8 100644
--- a/airflow/providers/facebook/backport_provider_setup.py
+++ b/airflow/providers/facebook/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/ftp/backport_provider_setup.py b/airflow/providers/ftp/backport_provider_setup.py
index e116e07750a72..237f15cc3a750 100644
--- a/airflow/providers/ftp/backport_provider_setup.py
+++ b/airflow/providers/ftp/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/ftp/hooks/ftp.py b/airflow/providers/ftp/hooks/ftp.py
index 87ae53a7ca0c7..18711896c2e44 100644
--- a/airflow/providers/ftp/hooks/ftp.py
+++ b/airflow/providers/ftp/hooks/ftp.py
@@ -20,7 +20,7 @@
import datetime
import ftplib
import os.path
-from typing import List, Optional, Any
+from typing import Any, List, Optional
from airflow.hooks.base_hook import BaseHook
diff --git a/airflow/providers/google/backport_provider_setup.py b/airflow/providers/google/backport_provider_setup.py
index d150da5045691..d22578696510f 100644
--- a/airflow/providers/google/backport_provider_setup.py
+++ b/airflow/providers/google/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/google/cloud/example_dags/example_azure_fileshare_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_azure_fileshare_to_gcs.py
index da7de2b3eec6a..c260cb8616c0d 100644
--- a/airflow/providers/google/cloud/example_dags/example_azure_fileshare_to_gcs.py
+++ b/airflow/providers/google/cloud/example_dags/example_azure_fileshare_to_gcs.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import os
-from datetime import timedelta, datetime
+from datetime import datetime, timedelta
from airflow import DAG
from airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs import AzureFileShareToGCSOperator
diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py b/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py
index 70522a55ad433..441c165981cca 100644
--- a/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py
+++ b/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py
@@ -21,8 +21,8 @@
import os
from urllib.parse import urlparse
-from google.cloud.redis_v1.gapic.enums import FailoverInstanceRequest, Instance
from google.cloud.memcache_v1beta2.types import cloud_memcache
+from google.cloud.redis_v1.gapic.enums import FailoverInstanceRequest, Instance
from airflow import models
from airflow.operators.bash import BashOperator
@@ -36,8 +36,6 @@
CloudMemorystoreGetInstanceOperator,
CloudMemorystoreImportOperator,
CloudMemorystoreListInstancesOperator,
- CloudMemorystoreScaleInstanceOperator,
- CloudMemorystoreUpdateInstanceOperator,
CloudMemorystoreMemcachedApplyParametersOperator,
CloudMemorystoreMemcachedCreateInstanceOperator,
CloudMemorystoreMemcachedDeleteInstanceOperator,
@@ -45,6 +43,8 @@
CloudMemorystoreMemcachedListInstancesOperator,
CloudMemorystoreMemcachedUpdateInstanceOperator,
CloudMemorystoreMemcachedUpdateParametersOperator,
+ CloudMemorystoreScaleInstanceOperator,
+ CloudMemorystoreUpdateInstanceOperator,
)
from airflow.providers.google.cloud.operators.gcs import GCSBucketCreateAclEntryOperator
from airflow.utils import dates
diff --git a/airflow/providers/google/cloud/example_dags/example_mysql_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_mysql_to_gcs.py
index dbec5dcc87fa8..cdf97dc5b7da1 100644
--- a/airflow/providers/google/cloud/example_dags/example_mysql_to_gcs.py
+++ b/airflow/providers/google/cloud/example_dags/example_mysql_to_gcs.py
@@ -16,6 +16,7 @@
# under the License.
import os
+
from airflow import models
from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator
from airflow.utils import dates
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 96e40560aa439..3ee600809a0c6 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -26,7 +26,7 @@
import time
import warnings
from copy import deepcopy
-from datetime import timedelta, datetime
+from datetime import datetime, timedelta
from typing import Any, Dict, Iterable, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union
from google.api_core.retry import Retry
@@ -43,7 +43,7 @@
from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference
from google.cloud.bigquery.table import EncryptionConfiguration, Row, Table, TableReference
from google.cloud.exceptions import NotFound
-from googleapiclient.discovery import build, Resource
+from googleapiclient.discovery import Resource, build
from pandas import DataFrame
from pandas_gbq import read_gbq
from pandas_gbq.gbq import (
diff --git a/airflow/providers/google/cloud/hooks/cloud_memorystore.py b/airflow/providers/google/cloud/hooks/cloud_memorystore.py
index 693a82f77dc02..bfc01f94285df 100644
--- a/airflow/providers/google/cloud/hooks/cloud_memorystore.py
+++ b/airflow/providers/google/cloud/hooks/cloud_memorystore.py
@@ -18,8 +18,8 @@
"""Hooks for Cloud Memorystore service"""
from typing import Dict, Optional, Sequence, Tuple, Union
-from google.api_core.exceptions import NotFound
from google.api_core import path_template
+from google.api_core.exceptions import NotFound
from google.api_core.retry import Retry
from google.cloud.memcache_v1beta2 import CloudMemcacheClient
from google.cloud.memcache_v1beta2.types import cloud_memcache
diff --git a/airflow/providers/google/cloud/hooks/cloud_sql.py b/airflow/providers/google/cloud/hooks/cloud_sql.py
index 0443be78f3905..cc6d5725562ed 100644
--- a/airflow/providers/google/cloud/hooks/cloud_sql.py
+++ b/airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -37,7 +37,7 @@
from urllib.parse import quote_plus
import requests
-from googleapiclient.discovery import build, Resource
+from googleapiclient.discovery import Resource, build
from googleapiclient.errors import HttpError
from sqlalchemy.orm import Session
diff --git a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
index b7aae109e1f03..5600a29a8696a 100644
--- a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
+++ b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
@@ -25,7 +25,7 @@
from datetime import timedelta
from typing import List, Optional, Sequence, Set, Union
-from googleapiclient.discovery import build, Resource
+from googleapiclient.discovery import Resource, build
from googleapiclient.errors import HttpError
from airflow.exceptions import AirflowException
diff --git a/airflow/providers/google/cloud/hooks/datastore.py b/airflow/providers/google/cloud/hooks/datastore.py
index c8ca3de1cb7d9..d36c295a3919d 100644
--- a/airflow/providers/google/cloud/hooks/datastore.py
+++ b/airflow/providers/google/cloud/hooks/datastore.py
@@ -23,7 +23,7 @@
import warnings
from typing import Any, Dict, Optional, Sequence, Union
-from googleapiclient.discovery import build, Resource
+from googleapiclient.discovery import Resource, build
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py
index 6d2f3480fcec8..f59ff8ed9fa67 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -27,7 +27,7 @@
from io import BytesIO
from os import path
from tempfile import NamedTemporaryFile
-from typing import Callable, Optional, Sequence, Set, Tuple, TypeVar, Union, cast, List
+from typing import Callable, List, Optional, Sequence, Set, Tuple, TypeVar, Union, cast
from urllib.parse import urlparse
from google.api_core.exceptions import NotFound
diff --git a/airflow/providers/google/cloud/hooks/gdm.py b/airflow/providers/google/cloud/hooks/gdm.py
index e726d2f366711..f11a3903f3005 100644
--- a/airflow/providers/google/cloud/hooks/gdm.py
+++ b/airflow/providers/google/cloud/hooks/gdm.py
@@ -19,7 +19,7 @@
from typing import Any, Dict, List, Optional, Sequence, Union
-from googleapiclient.discovery import build, Resource
+from googleapiclient.discovery import Resource, build
from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
diff --git a/airflow/providers/google/cloud/hooks/mlengine.py b/airflow/providers/google/cloud/hooks/mlengine.py
index b17cc46fef257..e0182ce91c70a 100644
--- a/airflow/providers/google/cloud/hooks/mlengine.py
+++ b/airflow/providers/google/cloud/hooks/mlengine.py
@@ -21,7 +21,7 @@
import time
from typing import Callable, Dict, List, Optional
-from googleapiclient.discovery import build, Resource
+from googleapiclient.discovery import Resource, build
from googleapiclient.errors import HttpError
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
diff --git a/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
index 37b616574fe28..127e24e0155e3 100644
--- a/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
+++ b/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
@@ -19,7 +19,7 @@
"""This module contains Google Cloud Transfer operators."""
from copy import deepcopy
from datetime import date, time
-from typing import Dict, Optional, Sequence, Union, List
+from typing import Dict, List, Optional, Sequence, Union
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
diff --git a/airflow/providers/google/cloud/operators/dlp.py b/airflow/providers/google/cloud/operators/dlp.py
index 5d28d6639be8b..099feb2f1aed0 100644
--- a/airflow/providers/google/cloud/operators/dlp.py
+++ b/airflow/providers/google/cloud/operators/dlp.py
@@ -24,7 +24,7 @@
"""
from typing import Dict, Optional, Sequence, Tuple, Union
-from google.api_core.exceptions import AlreadyExists, NotFound, InvalidArgument
+from google.api_core.exceptions import AlreadyExists, InvalidArgument, NotFound
from google.api_core.retry import Retry
from google.cloud.dlp_v2.types import (
ByteContentItem,
diff --git a/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py
index dec6de1633159..02c2472fa4468 100644
--- a/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py
+++ b/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py
@@ -19,10 +19,10 @@
from typing import Optional, Sequence, Set, Union
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
- CloudDataTransferServiceHook,
COUNTERS,
METADATA,
NAME,
+ CloudDataTransferServiceHook,
)
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
diff --git a/airflow/providers/google/cloud/sensors/dataproc.py b/airflow/providers/google/cloud/sensors/dataproc.py
index fb8a990674445..9606fca6b4859 100644
--- a/airflow/providers/google/cloud/sensors/dataproc.py
+++ b/airflow/providers/google/cloud/sensors/dataproc.py
@@ -20,10 +20,10 @@
from google.cloud.dataproc_v1beta2.types import JobStatus
+from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
-from airflow.exceptions import AirflowException
class DataprocJobSensor(BaseSensorOperator):
diff --git a/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py b/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py
index d424838897229..913b9485e8f90 100644
--- a/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py
@@ -17,11 +17,11 @@
# under the License.
from tempfile import NamedTemporaryFile
-from typing import Optional, Union, Sequence, Iterable
+from typing import Iterable, Optional, Sequence, Union
from airflow import AirflowException
from airflow.models import BaseOperator
-from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url, GCSHook, gcs_object_is_directory
+from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url, gcs_object_is_directory
from airflow.providers.microsoft.azure.hooks.azure_fileshare import AzureFileShareHook
from airflow.utils.decorators import apply_defaults
diff --git a/airflow/providers/google/cloud/transfers/presto_to_gcs.py b/airflow/providers/google/cloud/transfers/presto_to_gcs.py
index 7543f83d4fc94..70a4062b0def0 100644
--- a/airflow/providers/google/cloud/transfers/presto_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/presto_to_gcs.py
@@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, List, Tuple, Dict
+from typing import Any, Dict, List, Tuple
from prestodb.client import PrestoResult
from prestodb.dbapi import Cursor as PrestoCursor
diff --git a/airflow/providers/google/cloud/utils/credentials_provider.py b/airflow/providers/google/cloud/utils/credentials_provider.py
index c06c5de833958..e39b0a814a7d0 100644
--- a/airflow/providers/google/cloud/utils/credentials_provider.py
+++ b/airflow/providers/google/cloud/utils/credentials_provider.py
@@ -23,7 +23,7 @@
import logging
import tempfile
from contextlib import ExitStack, contextmanager
-from typing import Collection, Dict, Optional, Sequence, Tuple, Union, Generator
+from typing import Collection, Dict, Generator, Optional, Sequence, Tuple, Union
from urllib.parse import urlencode
import google.auth
diff --git a/airflow/providers/google/marketing_platform/operators/analytics.py b/airflow/providers/google/marketing_platform/operators/analytics.py
index 803984db74f25..2a1f6a14c6f23 100644
--- a/airflow/providers/google/marketing_platform/operators/analytics.py
+++ b/airflow/providers/google/marketing_platform/operators/analytics.py
@@ -18,7 +18,7 @@
"""This module contains Google Analytics 360 operators."""
import csv
from tempfile import NamedTemporaryFile
-from typing import Dict, Optional, Sequence, Union, Any, List
+from typing import Any, Dict, List, Optional, Sequence, Union
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
diff --git a/airflow/providers/grpc/backport_provider_setup.py b/airflow/providers/grpc/backport_provider_setup.py
index 9094ac31cf2de..18974469c319a 100644
--- a/airflow/providers/grpc/backport_provider_setup.py
+++ b/airflow/providers/grpc/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/hashicorp/backport_provider_setup.py b/airflow/providers/hashicorp/backport_provider_setup.py
index ecb051985fb26..ea75f9fc70e68 100644
--- a/airflow/providers/hashicorp/backport_provider_setup.py
+++ b/airflow/providers/hashicorp/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/http/backport_provider_setup.py b/airflow/providers/http/backport_provider_setup.py
index ebae13e66b903..010bf50d37883 100644
--- a/airflow/providers/http/backport_provider_setup.py
+++ b/airflow/providers/http/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/imap/backport_provider_setup.py b/airflow/providers/imap/backport_provider_setup.py
index b515f9ee49e17..86c6a3ad5ed8d 100644
--- a/airflow/providers/imap/backport_provider_setup.py
+++ b/airflow/providers/imap/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/jdbc/backport_provider_setup.py b/airflow/providers/jdbc/backport_provider_setup.py
index 11d751b5b7b59..798c8a9145952 100644
--- a/airflow/providers/jdbc/backport_provider_setup.py
+++ b/airflow/providers/jdbc/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/jenkins/backport_provider_setup.py b/airflow/providers/jenkins/backport_provider_setup.py
index f22d0df9fe692..d35948529b3e6 100644
--- a/airflow/providers/jenkins/backport_provider_setup.py
+++ b/airflow/providers/jenkins/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/jira/backport_provider_setup.py b/airflow/providers/jira/backport_provider_setup.py
index a12a10efd172b..525a6272b585d 100644
--- a/airflow/providers/jira/backport_provider_setup.py
+++ b/airflow/providers/jira/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/microsoft/azure/backport_provider_setup.py b/airflow/providers/microsoft/azure/backport_provider_setup.py
index 93e31c4348fe9..6580d1106aee2 100644
--- a/airflow/providers/microsoft/azure/backport_provider_setup.py
+++ b/airflow/providers/microsoft/azure/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/microsoft/azure/example_dags/example_azure_blob_to_gcs.py b/airflow/providers/microsoft/azure/example_dags/example_azure_blob_to_gcs.py
index 7183972545c83..693286fe00858 100644
--- a/airflow/providers/microsoft/azure/example_dags/example_azure_blob_to_gcs.py
+++ b/airflow/providers/microsoft/azure/example_dags/example_azure_blob_to_gcs.py
@@ -19,9 +19,7 @@
from airflow import DAG
from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor
-from airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs import (
- AzureBlobStorageToGCSOperator,
-)
+from airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs import AzureBlobStorageToGCSOperator
from airflow.utils.dates import days_ago
BLOB_NAME = os.environ.get("AZURE_BLOB_NAME", "file.txt")
diff --git a/airflow/providers/microsoft/azure/example_dags/example_local_to_adls.py b/airflow/providers/microsoft/azure/example_dags/example_local_to_adls.py
index 78383195082b5..294b5f71b5538 100644
--- a/airflow/providers/microsoft/azure/example_dags/example_local_to_adls.py
+++ b/airflow/providers/microsoft/azure/example_dags/example_local_to_adls.py
@@ -16,6 +16,7 @@
# under the License.
import os
+
from airflow import models
from airflow.providers.microsoft.azure.transfers.local_to_adls import LocalToAzureDataLakeStorageOperator
from airflow.utils.dates import days_ago
diff --git a/airflow/providers/microsoft/azure/hooks/azure_batch.py b/airflow/providers/microsoft/azure/hooks/azure_batch.py
index 864233c8315cf..ab4e6e90ffaba 100644
--- a/airflow/providers/microsoft/azure/hooks/azure_batch.py
+++ b/airflow/providers/microsoft/azure/hooks/azure_batch.py
@@ -21,7 +21,7 @@
from typing import Optional, Set
from azure.batch import BatchServiceClient, batch_auth, models as batch_models
-from azure.batch.models import PoolAddParameter, JobAddParameter, TaskAddParameter
+from azure.batch.models import JobAddParameter, PoolAddParameter, TaskAddParameter
from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
diff --git a/airflow/providers/microsoft/azure/hooks/azure_fileshare.py b/airflow/providers/microsoft/azure/hooks/azure_fileshare.py
index 81d683960ad66..95b462c7dc25b 100644
--- a/airflow/providers/microsoft/azure/hooks/azure_fileshare.py
+++ b/airflow/providers/microsoft/azure/hooks/azure_fileshare.py
@@ -16,9 +16,9 @@
# specific language governing permissions and limitations
# under the License.
#
-from typing import Optional, List
+from typing import List, Optional
-from azure.storage.file import FileService, File
+from azure.storage.file import File, FileService
from airflow.hooks.base_hook import BaseHook
diff --git a/airflow/providers/microsoft/azure/log/wasb_task_handler.py b/airflow/providers/microsoft/azure/log/wasb_task_handler.py
index a8eb6db69a1d3..3f12135aff62e 100644
--- a/airflow/providers/microsoft/azure/log/wasb_task_handler.py
+++ b/airflow/providers/microsoft/azure/log/wasb_task_handler.py
@@ -17,7 +17,7 @@
# under the License.
import os
import shutil
-from typing import Optional, Tuple, Dict
+from typing import Dict, Optional, Tuple
from azure.common import AzureHttpError
from cached_property import cached_property
diff --git a/airflow/providers/microsoft/azure/operators/azure_container_instances.py b/airflow/providers/microsoft/azure/operators/azure_container_instances.py
index 3bf30d9e3d748..75c0ba532e710 100644
--- a/airflow/providers/microsoft/azure/operators/azure_container_instances.py
+++ b/airflow/providers/microsoft/azure/operators/azure_container_instances.py
@@ -19,17 +19,17 @@
import re
from collections import namedtuple
from time import sleep
-from typing import Any, List, Optional, Sequence, Union, Dict
+from typing import Any, Dict, List, Optional, Sequence, Union
from azure.mgmt.containerinstance.models import (
Container,
ContainerGroup,
+ ContainerPort,
EnvironmentVariable,
+ IpAddress,
ResourceRequests,
ResourceRequirements,
VolumeMount,
- IpAddress,
- ContainerPort,
)
from msrestazure.azure_exceptions import CloudError
diff --git a/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py b/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py
index a33a922138699..ccc7577a38f89 100644
--- a/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py
+++ b/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py
@@ -17,8 +17,8 @@
# under the License.
#
import tempfile
+from typing import Optional, Sequence, Union
-from typing import Optional, Union, Sequence
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
diff --git a/airflow/providers/microsoft/azure/transfers/local_to_adls.py b/airflow/providers/microsoft/azure/transfers/local_to_adls.py
index 755a171c3606a..bf6947b653de7 100644
--- a/airflow/providers/microsoft/azure/transfers/local_to_adls.py
+++ b/airflow/providers/microsoft/azure/transfers/local_to_adls.py
@@ -15,7 +15,8 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Dict, Any, Optional
+from typing import Any, Dict, Optional
+
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook
diff --git a/airflow/providers/microsoft/mssql/backport_provider_setup.py b/airflow/providers/microsoft/mssql/backport_provider_setup.py
index 6868ab659fe7d..2c314e03a2f08 100644
--- a/airflow/providers/microsoft/mssql/backport_provider_setup.py
+++ b/airflow/providers/microsoft/mssql/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/microsoft/winrm/backport_provider_setup.py b/airflow/providers/microsoft/winrm/backport_provider_setup.py
index 876d4fc4a99e1..b79a1b5f3d7c3 100644
--- a/airflow/providers/microsoft/winrm/backport_provider_setup.py
+++ b/airflow/providers/microsoft/winrm/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/mongo/backport_provider_setup.py b/airflow/providers/mongo/backport_provider_setup.py
index ffc5500562aef..da729e4f05bee 100644
--- a/airflow/providers/mongo/backport_provider_setup.py
+++ b/airflow/providers/mongo/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/mysql/backport_provider_setup.py b/airflow/providers/mysql/backport_provider_setup.py
index a43120e7b3c83..af8383b1fa41e 100644
--- a/airflow/providers/mysql/backport_provider_setup.py
+++ b/airflow/providers/mysql/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/odbc/backport_provider_setup.py b/airflow/providers/odbc/backport_provider_setup.py
index f27329190f770..0e5c5cf361232 100644
--- a/airflow/providers/odbc/backport_provider_setup.py
+++ b/airflow/providers/odbc/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/odbc/hooks/odbc.py b/airflow/providers/odbc/hooks/odbc.py
index 3eabdf2e63f66..fc12549422754 100644
--- a/airflow/providers/odbc/hooks/odbc.py
+++ b/airflow/providers/odbc/hooks/odbc.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains ODBC hook."""
-from typing import Optional, Any
+from typing import Any, Optional
from urllib.parse import quote_plus
import pyodbc
diff --git a/airflow/providers/openfaas/backport_provider_setup.py b/airflow/providers/openfaas/backport_provider_setup.py
index ca748ad95305d..63400ec278376 100644
--- a/airflow/providers/openfaas/backport_provider_setup.py
+++ b/airflow/providers/openfaas/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/opsgenie/backport_provider_setup.py b/airflow/providers/opsgenie/backport_provider_setup.py
index e5d21ae961099..9ffb68eaf2c03 100644
--- a/airflow/providers/opsgenie/backport_provider_setup.py
+++ b/airflow/providers/opsgenie/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/opsgenie/hooks/opsgenie_alert.py b/airflow/providers/opsgenie/hooks/opsgenie_alert.py
index 9e5aa6b6b8e08..a65b9a3ea615b 100644
--- a/airflow/providers/opsgenie/hooks/opsgenie_alert.py
+++ b/airflow/providers/opsgenie/hooks/opsgenie_alert.py
@@ -18,7 +18,7 @@
#
import json
-from typing import Optional, Any
+from typing import Any, Optional
import requests
diff --git a/airflow/providers/opsgenie/operators/opsgenie_alert.py b/airflow/providers/opsgenie/operators/opsgenie_alert.py
index faea2d720e3b4..f4511dd1920bf 100644
--- a/airflow/providers/opsgenie/operators/opsgenie_alert.py
+++ b/airflow/providers/opsgenie/operators/opsgenie_alert.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#
-from typing import Optional, List, Dict, Any
+from typing import Any, Dict, List, Optional
from airflow.models import BaseOperator
from airflow.providers.opsgenie.hooks.opsgenie_alert import OpsgenieAlertHook
diff --git a/airflow/providers/oracle/backport_provider_setup.py b/airflow/providers/oracle/backport_provider_setup.py
index 3302022d81a53..a2fe377e4fbff 100644
--- a/airflow/providers/oracle/backport_provider_setup.py
+++ b/airflow/providers/oracle/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/oracle/hooks/oracle.py b/airflow/providers/oracle/hooks/oracle.py
index 020267b914f7f..efb3fcc547f6f 100644
--- a/airflow/providers/oracle/hooks/oracle.py
+++ b/airflow/providers/oracle/hooks/oracle.py
@@ -17,7 +17,7 @@
# under the License.
from datetime import datetime
-from typing import Optional, List
+from typing import List, Optional
import cx_Oracle
import numpy
diff --git a/airflow/providers/pagerduty/backport_provider_setup.py b/airflow/providers/pagerduty/backport_provider_setup.py
index e28bd21907333..6b8d1ceec5d73 100644
--- a/airflow/providers/pagerduty/backport_provider_setup.py
+++ b/airflow/providers/pagerduty/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/plexus/backport_provider_setup.py b/airflow/providers/plexus/backport_provider_setup.py
index a7ffb31e3bbb1..f92d8dd8a9023 100644
--- a/airflow/providers/plexus/backport_provider_setup.py
+++ b/airflow/providers/plexus/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/plexus/operators/job.py b/airflow/providers/plexus/operators/job.py
index a1aec95cf37b5..ece5df6e5f318 100644
--- a/airflow/providers/plexus/operators/job.py
+++ b/airflow/providers/plexus/operators/job.py
@@ -17,7 +17,7 @@
import logging
import time
-from typing import Dict, Any, Optional
+from typing import Any, Dict, Optional
import requests
diff --git a/airflow/providers/postgres/backport_provider_setup.py b/airflow/providers/postgres/backport_provider_setup.py
index 8a1bffd0bde7d..f3c5d6ed1ed83 100644
--- a/airflow/providers/postgres/backport_provider_setup.py
+++ b/airflow/providers/postgres/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py
index 825a0d027036e..b207edbf651fe 100644
--- a/airflow/providers/postgres/hooks/postgres.py
+++ b/airflow/providers/postgres/hooks/postgres.py
@@ -24,7 +24,7 @@
import psycopg2.extensions
import psycopg2.extras
from psycopg2.extensions import connection
-from psycopg2.extras import DictCursor, RealDictCursor, NamedTupleCursor
+from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor
from airflow.hooks.dbapi_hook import DbApiHook
from airflow.models.connection import Connection
diff --git a/airflow/providers/presto/backport_provider_setup.py b/airflow/providers/presto/backport_provider_setup.py
index b24fff549507c..cf9f269a838f4 100644
--- a/airflow/providers/presto/backport_provider_setup.py
+++ b/airflow/providers/presto/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py
index 26575d3252a43..a5a0576e1b670 100644
--- a/airflow/providers/presto/hooks/presto.py
+++ b/airflow/providers/presto/hooks/presto.py
@@ -15,8 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Optional, Any, Iterable
import os
+from typing import Any, Iterable, Optional
import prestodb
from prestodb.exceptions import DatabaseError
diff --git a/airflow/providers/qubole/backport_provider_setup.py b/airflow/providers/qubole/backport_provider_setup.py
index 4d41bf8a3be95..8387e4a5d0556 100644
--- a/airflow/providers/qubole/backport_provider_setup.py
+++ b/airflow/providers/qubole/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/qubole/hooks/qubole.py b/airflow/providers/qubole/hooks/qubole.py
index eb3a374e91a77..56c0a5220a55e 100644
--- a/airflow/providers/qubole/hooks/qubole.py
+++ b/airflow/providers/qubole/hooks/qubole.py
@@ -22,7 +22,7 @@
import os
import pathlib
import time
-from typing import List, Dict, Tuple
+from typing import Dict, List, Tuple
from qds_sdk.commands import (
Command,
@@ -31,12 +31,12 @@
DbTapQueryCommand,
HadoopCommand,
HiveCommand,
+ JupyterNotebookCommand,
PigCommand,
PrestoCommand,
ShellCommand,
SparkCommand,
SqlCommand,
- JupyterNotebookCommand,
)
from qds_sdk.qubole import Qubole
diff --git a/airflow/providers/qubole/hooks/qubole_check.py b/airflow/providers/qubole/hooks/qubole_check.py
index 1c6bdf43e2eac..987ea59ae2021 100644
--- a/airflow/providers/qubole/hooks/qubole_check.py
+++ b/airflow/providers/qubole/hooks/qubole_check.py
@@ -18,7 +18,7 @@
#
import logging
from io import StringIO
-from typing import List, Union, Optional
+from typing import List, Optional, Union
from qds_sdk.commands import Command
diff --git a/airflow/providers/qubole/operators/qubole_check.py b/airflow/providers/qubole/operators/qubole_check.py
index fc1561c28eeb0..41a4bc7e138b4 100644
--- a/airflow/providers/qubole/operators/qubole_check.py
+++ b/airflow/providers/qubole/operators/qubole_check.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#
-from typing import Iterable, Union, Optional
+from typing import Iterable, Optional, Union
from airflow.exceptions import AirflowException
from airflow.operators.check_operator import CheckOperator, ValueCheckOperator
diff --git a/airflow/providers/redis/backport_provider_setup.py b/airflow/providers/redis/backport_provider_setup.py
index ce6f307c2ea50..bba8751ae7b12 100644
--- a/airflow/providers/redis/backport_provider_setup.py
+++ b/airflow/providers/redis/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/salesforce/backport_provider_setup.py b/airflow/providers/salesforce/backport_provider_setup.py
index 076714068eb1d..2743d4f88efe0 100644
--- a/airflow/providers/salesforce/backport_provider_setup.py
+++ b/airflow/providers/salesforce/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/salesforce/hooks/salesforce.py b/airflow/providers/salesforce/hooks/salesforce.py
index 1fea62139a48b..affe286a69da3 100644
--- a/airflow/providers/salesforce/hooks/salesforce.py
+++ b/airflow/providers/salesforce/hooks/salesforce.py
@@ -25,7 +25,7 @@
"""
import logging
import time
-from typing import Optional, List, Iterable
+from typing import Iterable, List, Optional
import pandas as pd
from simple_salesforce import Salesforce, api
diff --git a/airflow/providers/salesforce/hooks/tableau.py b/airflow/providers/salesforce/hooks/tableau.py
index 80c67640963c3..bd47d10b3e67d 100644
--- a/airflow/providers/salesforce/hooks/tableau.py
+++ b/airflow/providers/salesforce/hooks/tableau.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from enum import Enum
-from typing import Optional, Any
+from typing import Any, Optional
from tableauserverclient import Pager, PersonalAccessTokenAuth, Server, TableauAuth
from tableauserverclient.server import Auth
diff --git a/airflow/providers/samba/backport_provider_setup.py b/airflow/providers/samba/backport_provider_setup.py
index 6829dacf9c238..a426466494544 100644
--- a/airflow/providers/samba/backport_provider_setup.py
+++ b/airflow/providers/samba/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/segment/backport_provider_setup.py b/airflow/providers/segment/backport_provider_setup.py
index 3813732bca634..54a77898a277a 100644
--- a/airflow/providers/segment/backport_provider_setup.py
+++ b/airflow/providers/segment/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/sftp/backport_provider_setup.py b/airflow/providers/sftp/backport_provider_setup.py
index 3359f0fb798ac..a0664098df6dd 100644
--- a/airflow/providers/sftp/backport_provider_setup.py
+++ b/airflow/providers/sftp/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/singularity/backport_provider_setup.py b/airflow/providers/singularity/backport_provider_setup.py
index 26162133dd530..e9f51448cc4b1 100644
--- a/airflow/providers/singularity/backport_provider_setup.py
+++ b/airflow/providers/singularity/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/slack/backport_provider_setup.py b/airflow/providers/slack/backport_provider_setup.py
index e51f2073d02ef..98475e28ce5e3 100644
--- a/airflow/providers/slack/backport_provider_setup.py
+++ b/airflow/providers/slack/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/slack/operators/slack.py b/airflow/providers/slack/operators/slack.py
index 622bab8110638..0263466c40cb6 100644
--- a/airflow/providers/slack/operators/slack.py
+++ b/airflow/providers/slack/operators/slack.py
@@ -17,7 +17,7 @@
# under the License.
import json
-from typing import Dict, List, Optional, Any
+from typing import Any, Dict, List, Optional
from airflow.models import BaseOperator
from airflow.providers.slack.hooks.slack import SlackHook
diff --git a/airflow/providers/slack/operators/slack_webhook.py b/airflow/providers/slack/operators/slack_webhook.py
index 7899aa931bfac..1a7725a81c6ae 100644
--- a/airflow/providers/slack/operators/slack_webhook.py
+++ b/airflow/providers/slack/operators/slack_webhook.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#
-from typing import Optional, Dict, Any
+from typing import Any, Dict, Optional
from airflow.providers.http.operators.http import SimpleHttpOperator
from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook
diff --git a/airflow/providers/snowflake/backport_provider_setup.py b/airflow/providers/snowflake/backport_provider_setup.py
index 6e67634cdacde..47ccac1f39e27 100644
--- a/airflow/providers/snowflake/backport_provider_setup.py
+++ b/airflow/providers/snowflake/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py
index 00a1e2bcb35d6..24fef1132b5cc 100644
--- a/airflow/providers/snowflake/hooks/snowflake.py
+++ b/airflow/providers/snowflake/hooks/snowflake.py
@@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Dict, Optional, Tuple, Any
+from typing import Any, Dict, Optional, Tuple
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
diff --git a/airflow/providers/sqlite/backport_provider_setup.py b/airflow/providers/sqlite/backport_provider_setup.py
index 6c0fcd219b48a..da49f3006dd7c 100644
--- a/airflow/providers/sqlite/backport_provider_setup.py
+++ b/airflow/providers/sqlite/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/ssh/backport_provider_setup.py b/airflow/providers/ssh/backport_provider_setup.py
index 94137e01b1422..c073e129ada98 100644
--- a/airflow/providers/ssh/backport_provider_setup.py
+++ b/airflow/providers/ssh/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/ssh/hooks/ssh.py b/airflow/providers/ssh/hooks/ssh.py
index d453c468c72a5..e28e779f54254 100644
--- a/airflow/providers/ssh/hooks/ssh.py
+++ b/airflow/providers/ssh/hooks/ssh.py
@@ -20,7 +20,7 @@
import os
import warnings
from io import StringIO
-from typing import Optional, Union, Tuple
+from typing import Optional, Tuple, Union
import paramiko
from paramiko.config import SSH_PORT
diff --git a/airflow/providers/vertica/backport_provider_setup.py b/airflow/providers/vertica/backport_provider_setup.py
index 313311ee44383..620ac43f7ffd8 100644
--- a/airflow/providers/vertica/backport_provider_setup.py
+++ b/airflow/providers/vertica/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/yandex/backport_provider_setup.py b/airflow/providers/yandex/backport_provider_setup.py
index 6125d675de79d..2ce94832e5dd1 100644
--- a/airflow/providers/yandex/backport_provider_setup.py
+++ b/airflow/providers/yandex/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/providers/yandex/hooks/yandex.py b/airflow/providers/yandex/hooks/yandex.py
index 5bb71a706879f..8121f29294f7b 100644
--- a/airflow/providers/yandex/hooks/yandex.py
+++ b/airflow/providers/yandex/hooks/yandex.py
@@ -16,7 +16,7 @@
# under the License.
import json
-from typing import Optional, Dict, Any, Union
+from typing import Any, Dict, Optional, Union
import yandexcloud
diff --git a/airflow/providers/zendesk/backport_provider_setup.py b/airflow/providers/zendesk/backport_provider_setup.py
index 2e6a4336aae9e..42843864a7b4b 100644
--- a/airflow/providers/zendesk/backport_provider_setup.py
+++ b/airflow/providers/zendesk/backport_provider_setup.py
@@ -29,8 +29,8 @@
import logging
import os
import sys
-
from os.path import dirname
+
from setuptools import find_packages, setup
logger = logging.getLogger(__name__)
diff --git a/airflow/secrets/base_secrets.py b/airflow/secrets/base_secrets.py
index 3d2f2576194ab..63bc94ae01942 100644
--- a/airflow/secrets/base_secrets.py
+++ b/airflow/secrets/base_secrets.py
@@ -59,6 +59,7 @@ def get_connections(self, conn_id: str) -> List['Connection']:
:type conn_id: str
"""
from airflow.models.connection import Connection
+
conn_uri = self.get_conn_uri(conn_id=conn_id)
if not conn_uri:
return []
@@ -74,7 +75,7 @@ def get_variable(self, key: str) -> Optional[str]:
"""
raise NotImplementedError()
- def get_config(self, key: str) -> Optional[str]: # pylint: disable=unused-argument
+ def get_config(self, key: str) -> Optional[str]: # pylint: disable=unused-argument
"""
Return value for Airflow Config Key
diff --git a/airflow/secrets/local_filesystem.py b/airflow/secrets/local_filesystem.py
index 2b249cc7968bd..a7d9e4f2d0a24 100644
--- a/airflow/secrets/local_filesystem.py
+++ b/airflow/secrets/local_filesystem.py
@@ -28,7 +28,10 @@
import yaml
from airflow.exceptions import (
- AirflowException, AirflowFileParseException, ConnectionNotUnique, FileSyntaxError,
+ AirflowException,
+ AirflowFileParseException,
+ ConnectionNotUnique,
+ FileSyntaxError,
)
from airflow.secrets.base_secrets import BaseSecretsBackend
from airflow.utils.file import COMMENT_PATTERN
@@ -82,7 +85,12 @@ def _parse_env_file(file_path: str) -> Tuple[Dict[str, List[str]], List[FileSynt
key, value = var_parts
if not key:
- errors.append(FileSyntaxError(line_no=line_no, message="Invalid line format. Key is empty.",))
+ errors.append(
+ FileSyntaxError(
+ line_no=line_no,
+ message="Invalid line format. Key is empty.",
+ )
+ )
secrets[key].append(value)
return secrets, errors
@@ -236,7 +244,8 @@ def load_connections(file_path) -> Dict[str, List[Any]]:
"""This function is deprecated. Please use `airflow.secrets.local_filesystem.load_connections_dict`.","""
warnings.warn(
"This function is deprecated. Please use `airflow.secrets.local_filesystem.load_connections_dict`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
return {k: [v] for k, v in load_connections_dict(file_path).values()}
diff --git a/airflow/secrets/metastore.py b/airflow/secrets/metastore.py
index 497bc4476d66b..e896f082ee2cf 100644
--- a/airflow/secrets/metastore.py
+++ b/airflow/secrets/metastore.py
@@ -33,6 +33,7 @@ class MetastoreBackend(BaseSecretsBackend):
@provide_session
def get_connections(self, conn_id, session=None) -> List['Connection']:
from airflow.models.connection import Connection
+
conn_list = session.query(Connection).filter(Connection.conn_id == conn_id).all()
session.expunge_all()
return conn_list
@@ -46,6 +47,7 @@ def get_variable(self, key: str, session=None):
:return: Variable Value
"""
from airflow.models.variable import Variable
+
var_value = session.query(Variable).filter(Variable.key == key).first()
session.expunge_all()
if var_value:
diff --git a/airflow/security/kerberos.py b/airflow/security/kerberos.py
index 37c1d9ff7f745..87c269dbe91e6 100644
--- a/airflow/security/kerberos.py
+++ b/airflow/security/kerberos.py
@@ -58,33 +58,36 @@ def renew_from_kt(principal: str, keytab: str, exit_on_fail: bool = True):
# minutes to give ourselves a large renewal buffer.
renewal_lifetime = "%sm" % conf.getint('kerberos', 'reinit_frequency')
- cmd_principal = principal or conf.get('kerberos', 'principal').replace(
- "_HOST", socket.getfqdn()
- )
+ cmd_principal = principal or conf.get('kerberos', 'principal').replace("_HOST", socket.getfqdn())
cmdv = [
conf.get('kerberos', 'kinit_path'),
- "-r", renewal_lifetime,
+ "-r",
+ renewal_lifetime,
"-k", # host ticket
- "-t", keytab, # specify keytab
- "-c", conf.get('kerberos', 'ccache'), # specify credentials cache
- cmd_principal
+ "-t",
+ keytab, # specify keytab
+ "-c",
+ conf.get('kerberos', 'ccache'), # specify credentials cache
+ cmd_principal,
]
log.info("Re-initialising kerberos from keytab: %s", " ".join(cmdv))
- subp = subprocess.Popen(cmdv,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- close_fds=True,
- bufsize=-1,
- universal_newlines=True)
+ subp = subprocess.Popen(
+ cmdv,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ close_fds=True,
+ bufsize=-1,
+ universal_newlines=True,
+ )
subp.wait()
if subp.returncode != 0:
log.error(
"Couldn't reinit from keytab! `kinit' exited with %s.\n%s\n%s",
subp.returncode,
"\n".join(subp.stdout.readlines() if subp.stdout else []),
- "\n".join(subp.stderr.readlines() if subp.stderr else [])
+ "\n".join(subp.stderr.readlines() if subp.stderr else []),
)
if exit_on_fail:
sys.exit(subp.returncode)
@@ -113,13 +116,14 @@ def perform_krb181_workaround(principal: str):
:param principal: principal name
:return: None
"""
- cmdv = [conf.get('kerberos', 'kinit_path'),
- "-c", conf.get('kerberos', 'ccache'),
- "-R"] # Renew ticket_cache
+ cmdv = [
+ conf.get('kerberos', 'kinit_path'),
+ "-c",
+ conf.get('kerberos', 'ccache'),
+ "-R",
+ ] # Renew ticket_cache
- log.info(
- "Renewing kerberos ticket to work around kerberos 1.8.1: %s", " ".join(cmdv)
- )
+ log.info("Renewing kerberos ticket to work around kerberos 1.8.1: %s", " ".join(cmdv))
ret = subprocess.call(cmdv, close_fds=True)
@@ -132,7 +136,10 @@ def perform_krb181_workaround(principal: str):
"the ticket for '%s' is still renewable:\n $ kinit -f -c %s\nIf the 'renew until' date is the "
"same as the 'valid starting' date, the ticket cannot be renewed. Please check your KDC "
"configuration, and the ticket renewal policy (maxrenewlife) for the '%s' and `krbtgt' "
- "principals.", princ, ccache, princ
+ "principals.",
+ princ,
+ ccache,
+ princ,
)
return ret
diff --git a/airflow/sensors/base_sensor_operator.py b/airflow/sensors/base_sensor_operator.py
index d98fb97ea9b9a..cb639bc3176fa 100644
--- a/airflow/sensors/base_sensor_operator.py
+++ b/airflow/sensors/base_sensor_operator.py
@@ -25,7 +25,10 @@
from airflow.configuration import conf
from airflow.exceptions import (
- AirflowException, AirflowRescheduleException, AirflowSensorTimeout, AirflowSkipException,
+ AirflowException,
+ AirflowRescheduleException,
+ AirflowSensorTimeout,
+ AirflowSkipException,
)
from airflow.models import BaseOperator, SensorInstance
from airflow.models.skipmixin import SkipMixin
@@ -76,17 +79,27 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
# setup. Smart sensor serialize these attributes into a different DB column so
# that smart sensor service is able to handle corresponding execution details
# without breaking the sensor poking logic with dedup.
- execution_fields = ('poke_interval', 'retries', 'execution_timeout', 'timeout',
- 'email', 'email_on_retry', 'email_on_failure',)
+ execution_fields = (
+ 'poke_interval',
+ 'retries',
+ 'execution_timeout',
+ 'timeout',
+ 'email',
+ 'email_on_retry',
+ 'email_on_failure',
+ )
@apply_defaults
- def __init__(self, *,
- poke_interval: float = 60,
- timeout: float = 60 * 60 * 24 * 7,
- soft_fail: bool = False,
- mode: str = 'poke',
- exponential_backoff: bool = False,
- **kwargs) -> None:
+ def __init__(
+ self,
+ *,
+ poke_interval: float = 60,
+ timeout: float = 60 * 60 * 24 * 7,
+ soft_fail: bool = False,
+ mode: str = 'poke',
+ exponential_backoff: bool = False,
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.poke_interval = poke_interval
self.soft_fail = soft_fail
@@ -96,22 +109,24 @@ def __init__(self, *,
self._validate_input_values()
self.sensor_service_enabled = conf.getboolean('smart_sensor', 'use_smart_sensor')
self.sensors_support_sensor_service = set(
- map(lambda l: l.strip(), conf.get('smart_sensor', 'sensors_enabled').split(',')))
+ map(lambda l: l.strip(), conf.get('smart_sensor', 'sensors_enabled').split(','))
+ )
def _validate_input_values(self) -> None:
if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0:
- raise AirflowException(
- "The poke_interval must be a non-negative number")
+ raise AirflowException("The poke_interval must be a non-negative number")
if not isinstance(self.timeout, (int, float)) or self.timeout < 0:
- raise AirflowException(
- "The timeout must be a non-negative number")
+ raise AirflowException("The timeout must be a non-negative number")
if self.mode not in self.valid_modes:
raise AirflowException(
"The mode must be one of {valid_modes},"
- "'{d}.{t}'; received '{m}'."
- .format(valid_modes=self.valid_modes,
- d=self.dag.dag_id if self.dag else "",
- t=self.task_id, m=self.mode))
+ "'{d}.{t}'; received '{m}'.".format(
+ valid_modes=self.valid_modes,
+ d=self.dag.dag_id if self.dag else "",
+ t=self.task_id,
+ m=self.mode,
+ )
+ )
def poke(self, context: Dict) -> bool:
"""
@@ -121,10 +136,12 @@ def poke(self, context: Dict) -> bool:
raise AirflowException('Override me.')
def is_smart_sensor_compatible(self):
- check_list = [not self.sensor_service_enabled,
- self.on_success_callback,
- self.on_retry_callback,
- self.on_failure_callback]
+ check_list = [
+ not self.sensor_service_enabled,
+ self.on_success_callback,
+ self.on_retry_callback,
+ self.on_failure_callback,
+ ]
for status in check_list:
if status:
return False
@@ -195,14 +212,13 @@ def execute(self, context: Dict) -> Any:
# give it a chance and fail with timeout.
# This gives the ability to set up non-blocking AND soft-fail sensors.
if self.soft_fail and not context['ti'].is_eligible_to_retry():
- raise AirflowSkipException(
- f"Snap. Time is OUT. DAG id: {log_dag_id}")
+ raise AirflowSkipException(f"Snap. Time is OUT. DAG id: {log_dag_id}")
else:
- raise AirflowSensorTimeout(
- f"Snap. Time is OUT. DAG id: {log_dag_id}")
+ raise AirflowSensorTimeout(f"Snap. Time is OUT. DAG id: {log_dag_id}")
if self.reschedule:
reschedule_date = timezone.utcnow() + timedelta(
- seconds=self._get_next_poke_interval(started_at, try_number))
+ seconds=self._get_next_poke_interval(started_at, try_number)
+ )
raise AirflowRescheduleException(reschedule_date)
else:
sleep(self._get_next_poke_interval(started_at, try_number))
@@ -215,17 +231,18 @@ def _get_next_poke_interval(self, started_at, try_number):
min_backoff = int(self.poke_interval * (2 ** (try_number - 2)))
current_time = timezone.utcnow()
- run_hash = int(hashlib.sha1("{}#{}#{}#{}".format(
- self.dag_id, self.task_id, started_at, try_number
- ).encode("utf-8")).hexdigest(), 16)
+ run_hash = int(
+ hashlib.sha1(
+ f"{self.dag_id}#{self.task_id}#{started_at}#{try_number}".encode("utf-8")
+ ).hexdigest(),
+ 16,
+ )
modded_hash = min_backoff + run_hash % min_backoff
- delay_backoff_in_seconds = min(
- modded_hash,
- timedelta.max.total_seconds() - 1
+ delay_backoff_in_seconds = min(modded_hash, timedelta.max.total_seconds() - 1)
+ new_interval = min(
+ self.timeout - int((current_time - started_at).total_seconds()), delay_backoff_in_seconds
)
- new_interval = min(self.timeout - int((current_time - started_at).total_seconds()),
- delay_backoff_in_seconds)
self.log.info("new %s interval is %s", self.mode, new_interval)
return new_interval
else:
@@ -269,6 +286,7 @@ def poke_mode_only(cls):
:param cls: BaseSensor class to enforce methods only use 'poke' mode.
:type cls: type
"""
+
def decorate(cls_type):
def mode_getter(_):
return 'poke'
@@ -278,9 +296,11 @@ def mode_setter(_, value):
raise ValueError("cannot set mode to 'poke'.")
if not issubclass(cls_type, BaseSensorOperator):
- raise ValueError(f"poke_mode_only decorator should only be "
- f"applied to subclasses of BaseSensorOperator,"
- f" got:{cls_type}.")
+ raise ValueError(
+ f"poke_mode_only decorator should only be "
+ f"applied to subclasses of BaseSensorOperator,"
+ f" got:{cls_type}."
+ )
cls_type.mode = property(mode_getter, mode_setter)
diff --git a/airflow/sensors/bash.py b/airflow/sensors/bash.py
index 711a7d40d24e8..371edba6d064e 100644
--- a/airflow/sensors/bash.py
+++ b/airflow/sensors/bash.py
@@ -45,11 +45,7 @@ class BashSensor(BaseSensorOperator):
template_fields = ('bash_command', 'env')
@apply_defaults
- def __init__(self, *,
- bash_command,
- env=None,
- output_encoding='utf-8',
- **kwargs):
+ def __init__(self, *, bash_command, env=None, output_encoding='utf-8', **kwargs):
super().__init__(**kwargs)
self.bash_command = bash_command
self.env = env
@@ -72,9 +68,13 @@ def poke(self, context):
self.log.info("Running command: %s", bash_command)
resp = Popen( # pylint: disable=subprocess-popen-preexec-fn
['bash', fname],
- stdout=PIPE, stderr=STDOUT,
- close_fds=True, cwd=tmp_dir,
- env=self.env, preexec_fn=os.setsid)
+ stdout=PIPE,
+ stderr=STDOUT,
+ close_fds=True,
+ cwd=tmp_dir,
+ env=self.env,
+ preexec_fn=os.setsid,
+ )
self.log.info("Output:")
for line in iter(resp.stdout.readline, b''):
diff --git a/airflow/sensors/date_time_sensor.py b/airflow/sensors/date_time_sensor.py
index 0d479ed6394f3..028e505bdc15f 100644
--- a/airflow/sensors/date_time_sensor.py
+++ b/airflow/sensors/date_time_sensor.py
@@ -57,9 +57,7 @@ class DateTimeSensor(BaseSensorOperator):
template_fields = ("target_time",)
@apply_defaults
- def __init__(
- self, *, target_time: Union[str, datetime.datetime], **kwargs
- ) -> None:
+ def __init__(self, *, target_time: Union[str, datetime.datetime], **kwargs) -> None:
super().__init__(**kwargs)
if isinstance(target_time, datetime.datetime):
self.target_time = target_time.isoformat()
@@ -67,9 +65,7 @@ def __init__(
self.target_time = target_time
else:
raise TypeError(
- "Expected str or datetime.datetime type for target_time. Got {}".format(
- type(target_time)
- )
+ "Expected str or datetime.datetime type for target_time. Got {}".format(type(target_time))
)
def poke(self, context: Dict) -> bool:
diff --git a/airflow/sensors/external_task_sensor.py b/airflow/sensors/external_task_sensor.py
index c32939f3d167b..70cd971ce7a76 100644
--- a/airflow/sensors/external_task_sensor.py
+++ b/airflow/sensors/external_task_sensor.py
@@ -85,15 +85,18 @@ def operator_extra_links(self):
return [ExternalTaskSensorLink()]
@apply_defaults
- def __init__(self, *,
- external_dag_id,
- external_task_id=None,
- allowed_states=None,
- failed_states=None,
- execution_delta=None,
- execution_date_fn=None,
- check_existence=False,
- **kwargs):
+ def __init__(
+ self,
+ *,
+ external_dag_id,
+ external_task_id=None,
+ allowed_states=None,
+ failed_states=None,
+ execution_delta=None,
+ execution_date_fn=None,
+ check_existence=False,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.allowed_states = allowed_states or [State.SUCCESS]
self.failed_states = failed_states or []
@@ -102,9 +105,10 @@ def __init__(self, *,
total_states = set(total_states)
if set(self.failed_states).intersection(set(self.allowed_states)):
- raise AirflowException("Duplicate values provided as allowed "
- "`{}` and failed states `{}`"
- .format(self.allowed_states, self.failed_states))
+ raise AirflowException(
+ "Duplicate values provided as allowed "
+ "`{}` and failed states `{}`".format(self.allowed_states, self.failed_states)
+ )
if external_task_id:
if not total_states <= set(State.task_states):
@@ -121,7 +125,8 @@ def __init__(self, *,
if execution_delta is not None and execution_date_fn is not None:
raise ValueError(
'Only one of `execution_delta` or `execution_date_fn` may '
- 'be provided to ExternalTaskSensor; not both.')
+ 'be provided to ExternalTaskSensor; not both.'
+ )
self.execution_delta = execution_delta
self.execution_date_fn = execution_date_fn
@@ -141,20 +146,16 @@ def poke(self, context, session=None):
dttm = context['execution_date']
dttm_filter = dttm if isinstance(dttm, list) else [dttm]
- serialized_dttm_filter = ','.join(
- [datetime.isoformat() for datetime in dttm_filter])
+ serialized_dttm_filter = ','.join([datetime.isoformat() for datetime in dttm_filter])
self.log.info(
- 'Poking for %s.%s on %s ... ',
- self.external_dag_id, self.external_task_id, serialized_dttm_filter
+ 'Poking for %s.%s on %s ... ', self.external_dag_id, self.external_task_id, serialized_dttm_filter
)
DM = DagModel
# we only do the check for 1st time, no need for subsequent poke
if self.check_existence and not self.has_checked_existence:
- dag_to_wait = session.query(DM).filter(
- DM.dag_id == self.external_dag_id
- ).first()
+ dag_to_wait = session.query(DM).filter(DM.dag_id == self.external_dag_id).first()
if not dag_to_wait:
raise AirflowException(f'The external DAG {self.external_dag_id} does not exist.')
@@ -204,19 +205,27 @@ def get_count(self, dttm_filter, session, states):
if self.external_task_id:
# .count() is inefficient
- count = session.query(func.count()).filter(
- TI.dag_id == self.external_dag_id,
- TI.task_id == self.external_task_id,
- TI.state.in_(states), # pylint: disable=no-member
- TI.execution_date.in_(dttm_filter),
- ).scalar()
+ count = (
+ session.query(func.count())
+ .filter(
+ TI.dag_id == self.external_dag_id,
+ TI.task_id == self.external_task_id,
+ TI.state.in_(states), # pylint: disable=no-member
+ TI.execution_date.in_(dttm_filter),
+ )
+ .scalar()
+ )
else:
# .count() is inefficient
- count = session.query(func.count()).filter(
- DR.dag_id == self.external_dag_id,
- DR.state.in_(states), # pylint: disable=no-member
- DR.execution_date.in_(dttm_filter),
- ).scalar()
+ count = (
+ session.query(func.count())
+ .filter(
+ DR.dag_id == self.external_dag_id,
+ DR.state.in_(states), # pylint: disable=no-member
+ DR.execution_date.in_(dttm_filter),
+ )
+ .scalar()
+ )
return count
def _handle_execution_date_fn(self, context):
@@ -235,9 +244,7 @@ def _handle_execution_date_fn(self, context):
if num_fxn_params == 2:
return self.execution_date_fn(context['execution_date'], context)
- raise AirflowException(
- f'execution_date_fn passed {num_fxn_params} args but only allowed up to 2'
- )
+ raise AirflowException(f'execution_date_fn passed {num_fxn_params} args but only allowed up to 2')
class ExternalTaskMarker(DummyOperator):
@@ -266,12 +273,15 @@ class ExternalTaskMarker(DummyOperator):
__serialized_fields: Optional[FrozenSet[str]] = None
@apply_defaults
- def __init__(self, *,
- external_dag_id,
- external_task_id,
- execution_date: Optional[Union[str, datetime.datetime]] = "{{ execution_date.isoformat() }}",
- recursion_depth: int = 10,
- **kwargs):
+ def __init__(
+ self,
+ *,
+ external_dag_id,
+ external_task_id,
+ execution_date: Optional[Union[str, datetime.datetime]] = "{{ execution_date.isoformat() }}",
+ recursion_depth: int = 10,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.external_dag_id = external_dag_id
self.external_task_id = external_task_id
@@ -280,8 +290,11 @@ def __init__(self, *,
elif isinstance(execution_date, str):
self.execution_date = execution_date
else:
- raise TypeError('Expected str or datetime.datetime type for execution_date. Got {}'
- .format(type(execution_date)))
+ raise TypeError(
+ 'Expected str or datetime.datetime type for execution_date. Got {}'.format(
+ type(execution_date)
+ )
+ )
if recursion_depth <= 0:
raise ValueError("recursion_depth should be a positive integer")
self.recursion_depth = recursion_depth
@@ -290,9 +303,5 @@ def __init__(self, *,
def get_serialized_fields(cls):
"""Serialized ExternalTaskMarker contain exactly these fields + templated_fields ."""
if not cls.__serialized_fields:
- cls.__serialized_fields = frozenset(
- super().get_serialized_fields() | {
- "recursion_depth"
- }
- )
+ cls.__serialized_fields = frozenset(super().get_serialized_fields() | {"recursion_depth"})
return cls.__serialized_fields
diff --git a/airflow/sensors/filesystem.py b/airflow/sensors/filesystem.py
index 6270ba25bc438..328fd362edd6f 100644
--- a/airflow/sensors/filesystem.py
+++ b/airflow/sensors/filesystem.py
@@ -44,10 +44,7 @@ class FileSensor(BaseSensorOperator):
ui_color = '#91818a'
@apply_defaults
- def __init__(self, *,
- filepath,
- fs_conn_id='fs_default',
- **kwargs):
+ def __init__(self, *, filepath, fs_conn_id='fs_default', **kwargs):
super().__init__(**kwargs)
self.filepath = filepath
self.fs_conn_id = fs_conn_id
diff --git a/airflow/sensors/hdfs_sensor.py b/airflow/sensors/hdfs_sensor.py
index 0204a2f969185..ed8830e4bbaef 100644
--- a/airflow/sensors/hdfs_sensor.py
+++ b/airflow/sensors/hdfs_sensor.py
@@ -25,5 +25,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hdfs.sensors.hdfs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/sensors/hive_partition_sensor.py b/airflow/sensors/hive_partition_sensor.py
index 1768b3a226c07..028a9fd2dcace 100644
--- a/airflow/sensors/hive_partition_sensor.py
+++ b/airflow/sensors/hive_partition_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hive.sensors.hive_partition`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/sensors/http_sensor.py b/airflow/sensors/http_sensor.py
index 26755c2ba8a11..ad85b60447d00 100644
--- a/airflow/sensors/http_sensor.py
+++ b/airflow/sensors/http_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.http.sensors.http`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/sensors/metastore_partition_sensor.py b/airflow/sensors/metastore_partition_sensor.py
index e4db0ccd4bb04..5ab966f12057c 100644
--- a/airflow/sensors/metastore_partition_sensor.py
+++ b/airflow/sensors/metastore_partition_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hive.sensors.metastore_partition`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/sensors/named_hive_partition_sensor.py b/airflow/sensors/named_hive_partition_sensor.py
index 37d8eb31ae91b..eeb28e1656533 100644
--- a/airflow/sensors/named_hive_partition_sensor.py
+++ b/airflow/sensors/named_hive_partition_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hive.sensors.named_hive_partition`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/sensors/python.py b/airflow/sensors/python.py
index 183405188635b..68fcc66b94e7a 100644
--- a/airflow/sensors/python.py
+++ b/airflow/sensors/python.py
@@ -50,12 +50,14 @@ class PythonSensor(BaseSensorOperator):
@apply_defaults
def __init__(
- self, *,
- python_callable: Callable,
- op_args: Optional[List] = None,
- op_kwargs: Optional[Dict] = None,
- templates_dict: Optional[Dict] = None,
- **kwargs):
+ self,
+ *,
+ python_callable: Callable,
+ op_args: Optional[List] = None,
+ op_kwargs: Optional[Dict] = None,
+ templates_dict: Optional[Dict] = None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.python_callable = python_callable
self.op_args = op_args or []
diff --git a/airflow/sensors/s3_key_sensor.py b/airflow/sensors/s3_key_sensor.py
index d8bada859f867..26b973e8c6ab9 100644
--- a/airflow/sensors/s3_key_sensor.py
+++ b/airflow/sensors/s3_key_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3_key`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/sensors/s3_prefix_sensor.py b/airflow/sensors/s3_prefix_sensor.py
index 27ef2bd903d47..f09854e7f87c6 100644
--- a/airflow/sensors/s3_prefix_sensor.py
+++ b/airflow/sensors/s3_prefix_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3_prefix`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/sensors/smart_sensor_operator.py b/airflow/sensors/smart_sensor_operator.py
index 369002ca86b21..4917ee64cc0c7 100644
--- a/airflow/sensors/smart_sensor_operator.py
+++ b/airflow/sensors/smart_sensor_operator.py
@@ -94,10 +94,12 @@ def __eq__(self, other):
if not isinstance(other, SensorWork):
return NotImplemented
- return self.dag_id == other.dag_id and \
- self.task_id == other.task_id and \
- self.execution_date == other.execution_date and \
- self.try_number == other.try_number
+ return (
+ self.dag_id == other.dag_id
+ and self.task_id == other.task_id
+ and self.execution_date == other.execution_date
+ and self.try_number == other.try_number
+ )
@staticmethod
def create_new_task_handler():
@@ -117,10 +119,9 @@ def _get_sensor_logger(self, si):
# The created log_id is used inside of smart sensor as the key to fetch
# the corresponding in memory log handler.
si.raw = False # Otherwise set_context will fail
- log_id = "-".join([si.dag_id,
- si.task_id,
- si.execution_date.strftime("%Y_%m_%dT%H_%M_%S_%f"),
- str(si.try_number)])
+ log_id = "-".join(
+ [si.dag_id, si.task_id, si.execution_date.strftime("%Y_%m_%dT%H_%M_%S_%f"), str(si.try_number)]
+ )
logger = logging.getLogger('airflow.task' + '.' + log_id)
if len(logger.handlers) == 0:
@@ -128,10 +129,11 @@ def _get_sensor_logger(self, si):
logger.addHandler(handler)
set_context(logger, si)
- line_break = ("-" * 120)
+ line_break = "-" * 120
logger.info(line_break)
- logger.info("Processing sensor task %s in smart sensor service on host: %s",
- self.ti_key, get_hostname())
+ logger.info(
+ "Processing sensor task %s in smart sensor service on host: %s", self.ti_key, get_hostname()
+ )
logger.info(line_break)
return logger
@@ -200,10 +202,12 @@ class SensorExceptionInfo:
infra failure, give the task more chance to retry before fail it.
"""
- def __init__(self,
- exception_info,
- is_infra_failure=False,
- infra_failure_retry_window=datetime.timedelta(minutes=130)):
+ def __init__(
+ self,
+ exception_info,
+ is_infra_failure=False,
+ infra_failure_retry_window=datetime.timedelta(minutes=130),
+ ):
self._exception_info = exception_info
self._is_infra_failure = is_infra_failure
self._infra_failure_retry_window = infra_failure_retry_window
@@ -302,15 +306,17 @@ class SmartSensorOperator(BaseOperator, SkipMixin):
ui_color = '#e6f1f2'
@apply_defaults
- def __init__(self,
- poke_interval=180,
- smart_sensor_timeout=60 * 60 * 24 * 7,
- soft_fail=False,
- shard_min=0,
- shard_max=100000,
- poke_timeout=6.0,
- *args,
- **kwargs):
+ def __init__(
+ self,
+ poke_interval=180,
+ smart_sensor_timeout=60 * 60 * 24 * 7,
+ soft_fail=False,
+ shard_min=0,
+ shard_max=100000,
+ poke_timeout=6.0,
+ *args,
+ **kwargs,
+ ):
super().__init__(*args, **kwargs)
# super(SmartSensorOperator, self).__init__(*args, **kwargs)
self.poke_interval = poke_interval
@@ -330,11 +336,9 @@ def __init__(self,
def _validate_input_values(self):
if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0:
- raise AirflowException(
- "The poke_interval must be a non-negative number")
+ raise AirflowException("The poke_interval must be a non-negative number")
if not isinstance(self.timeout, (int, float)) or self.timeout < 0:
- raise AirflowException(
- "The timeout must be a non-negative number")
+ raise AirflowException("The timeout must be a non-negative number")
@provide_session
def _load_sensor_works(self, session=None):
@@ -345,10 +349,11 @@ def _load_sensor_works(self, session=None):
"""
SI = SensorInstance
start_query_time = time.time()
- query = session.query(SI) \
- .filter(SI.state == State.SENSING)\
- .filter(SI.shardcode < self.shard_max,
- SI.shardcode >= self.shard_min)
+ query = (
+ session.query(SI)
+ .filter(SI.state == State.SENSING)
+ .filter(SI.shardcode < self.shard_max, SI.shardcode >= self.shard_min)
+ )
tis = query.all()
self.log.info("Performance query %s tis, time: %s", len(tis), time.time() - start_query_time)
@@ -387,10 +392,11 @@ def _update_ti_hostname(self, sensor_works, session=None):
def update_ti_hostname_with_count(count, ti_keys):
# Using or_ instead of in_ here to prevent from full table scan.
- tis = session.query(TI) \
- .filter(or_(tuple_(TI.dag_id, TI.task_id, TI.execution_date) == ti_key
- for ti_key in ti_keys)) \
+ tis = (
+ session.query(TI)
+ .filter(or_(tuple_(TI.dag_id, TI.task_id, TI.execution_date) == ti_key for ti_key in ti_keys))
.all()
+ )
for ti in tis:
ti.hostname = self.hostname
@@ -414,6 +420,7 @@ def _mark_multi_state(self, operator, poke_hash, encoded_poke_context, state, se
:param state: Set multiple sensor tasks to this state.
:param session: The sqlalchemy session.
"""
+
def mark_state(ti, sensor_instance):
ti.state = state
sensor_instance.state = state
@@ -426,14 +433,22 @@ def mark_state(ti, sensor_instance):
count_marked = 0
try:
- query_result = session.query(TI, SI)\
- .join(TI, and_(TI.dag_id == SI.dag_id,
- TI.task_id == SI.task_id,
- TI.execution_date == SI.execution_date)) \
- .filter(SI.state == State.SENSING) \
- .filter(SI.hashcode == poke_hash) \
- .filter(SI.operator == operator) \
- .with_for_update().all()
+ query_result = (
+ session.query(TI, SI)
+ .join(
+ TI,
+ and_(
+ TI.dag_id == SI.dag_id,
+ TI.task_id == SI.task_id,
+ TI.execution_date == SI.execution_date,
+ ),
+ )
+ .filter(SI.state == State.SENSING)
+ .filter(SI.hashcode == poke_hash)
+ .filter(SI.operator == operator)
+ .with_for_update()
+ .all()
+ )
end_date = timezone.utcnow()
for ti, sensor_instance in query_result:
@@ -451,8 +466,7 @@ def mark_state(ti, sensor_instance):
session.commit()
except Exception as e: # pylint: disable=broad-except
- self.log.warning("Exception _mark_multi_state in smart sensor for hashcode %s",
- str(poke_hash))
+ self.log.warning("Exception _mark_multi_state in smart sensor for hashcode %s", str(poke_hash))
self.log.exception(e, exc_info=True)
self.log.info("Marked %s tasks out of %s to state %s", count_marked, len(query_result), state)
@@ -469,6 +483,7 @@ def _retry_or_fail_task(self, sensor_work, error, session=None):
:type error: str.
:param session: The sqlalchemy session.
"""
+
def email_alert(task_instance, error_info):
try:
subject, html_content, _ = task_instance.get_email_subject_content(error_info)
@@ -480,18 +495,19 @@ def email_alert(task_instance, error_info):
sensor_work.log.exception(e, exc_info=True)
def handle_failure(sensor_work, ti):
- if sensor_work.execution_context.get('retries') and \
- ti.try_number <= ti.max_tries:
+ if sensor_work.execution_context.get('retries') and ti.try_number <= ti.max_tries:
# retry
ti.state = State.UP_FOR_RETRY
- if sensor_work.execution_context.get('email_on_retry') and \
- sensor_work.execution_context.get('email'):
+ if sensor_work.execution_context.get('email_on_retry') and sensor_work.execution_context.get(
+ 'email'
+ ):
sensor_work.log.info("%s sending email alert for retry", sensor_work.ti_key)
email_alert(ti, error)
else:
ti.state = State.FAILED
- if sensor_work.execution_context.get('email_on_failure') and \
- sensor_work.execution_context.get('email'):
+ if sensor_work.execution_context.get(
+ 'email_on_failure'
+ ) and sensor_work.execution_context.get('email'):
sensor_work.log.info("%s sending email alert for failure", sensor_work.ti_key)
email_alert(ti, error)
@@ -499,23 +515,23 @@ def handle_failure(sensor_work, ti):
dag_id, task_id, execution_date = sensor_work.ti_key
TI = TaskInstance
SI = SensorInstance
- sensor_instance = session.query(SI).filter(
- SI.dag_id == dag_id,
- SI.task_id == task_id,
- SI.execution_date == execution_date) \
- .with_for_update() \
+ sensor_instance = (
+ session.query(SI)
+ .filter(SI.dag_id == dag_id, SI.task_id == task_id, SI.execution_date == execution_date)
+ .with_for_update()
.first()
+ )
if sensor_instance.hashcode != sensor_work.hashcode:
# Return without setting state
return
- ti = session.query(TI).filter(
- TI.dag_id == dag_id,
- TI.task_id == task_id,
- TI.execution_date == execution_date) \
- .with_for_update() \
+ ti = (
+ session.query(TI)
+ .filter(TI.dag_id == dag_id, TI.task_id == task_id, TI.execution_date == execution_date)
+ .with_for_update()
.first()
+ )
if ti:
if ti.state == State.SENSING:
@@ -531,8 +547,9 @@ def handle_failure(sensor_work, ti):
session.merge(ti)
session.commit()
- sensor_work.log.info("Task %s got an error: %s. Set the state to failed. Exit.",
- str(sensor_work.ti_key), error)
+ sensor_work.log.info(
+ "Task %s got an error: %s. Set the state to failed. Exit.", str(sensor_work.ti_key), error
+ )
sensor_work.close_sensor_logger()
except AirflowException as e:
@@ -568,8 +585,9 @@ def _handle_poke_exception(self, sensor_work):
if sensor_exception.fail_current_run:
if sensor_exception.is_infra_failure:
- sensor_work.log.exception("Task %s failed by infra failure in smart sensor.",
- sensor_work.ti_key)
+ sensor_work.log.exception(
+ "Task %s failed by infra failure in smart sensor.", sensor_work.ti_key
+ )
# There is a risk for sensor object cached in smart sensor keep throwing
# exception and cause an infra failure. To make sure the sensor tasks after
# retry will not fall into same object and have endless infra failure,
@@ -619,10 +637,12 @@ def _execute_sensor_work(self, sensor_work):
# Got a landed signal, mark all tasks waiting for this partition
cached_work.set_state(PokeState.LANDED)
- self._mark_multi_state(sensor_work.operator,
- sensor_work.hashcode,
- sensor_work.encoded_poke_context,
- State.SUCCESS)
+ self._mark_multi_state(
+ sensor_work.operator,
+ sensor_work.hashcode,
+ sensor_work.encoded_poke_context,
+ State.SUCCESS,
+ )
log.info("Task %s succeeded", str(ti_key))
sensor_work.close_sensor_logger()
@@ -642,12 +662,12 @@ def _execute_sensor_work(self, sensor_work):
if cache_key in self.cached_sensor_exceptions:
self.cached_sensor_exceptions[cache_key].set_latest_exception(
- exception_info,
- is_infra_failure=is_infra_failure)
+ exception_info, is_infra_failure=is_infra_failure
+ )
else:
self.cached_sensor_exceptions[cache_key] = SensorExceptionInfo(
- exception_info,
- is_infra_failure=is_infra_failure)
+ exception_info, is_infra_failure=is_infra_failure
+ )
self._handle_poke_exception(sensor_work)
@@ -671,8 +691,7 @@ def poke(self, sensor_work):
"""
cached_work = self.cached_dedup_works[sensor_work.cache_key]
if not cached_work.sensor_task:
- init_args = dict(list(sensor_work.poke_context.items())
- + [('task_id', sensor_work.task_id)])
+ init_args = dict(list(sensor_work.poke_context.items()) + [('task_id', sensor_work.task_id)])
operator_class = import_string(sensor_work.op_classpath)
cached_work.sensor_task = operator_class(**init_args)
diff --git a/airflow/sensors/sql_sensor.py b/airflow/sensors/sql_sensor.py
index b8dc1432538b5..aabff0d547c1c 100644
--- a/airflow/sensors/sql_sensor.py
+++ b/airflow/sensors/sql_sensor.py
@@ -52,12 +52,16 @@ class SqlSensor(BaseSensorOperator):
"""
template_fields: Iterable[str] = ('sql',)
- template_ext: Iterable[str] = ('.hql', '.sql',)
+ template_ext: Iterable[str] = (
+ '.hql',
+ '.sql',
+ )
ui_color = '#7c7287'
@apply_defaults
- def __init__(self, *, conn_id, sql, parameters=None, success=None, failure=None, fail_on_empty=False,
- **kwargs):
+ def __init__(
+ self, *, conn_id, sql, parameters=None, success=None, failure=None, fail_on_empty=False, **kwargs
+ ):
self.conn_id = conn_id
self.sql = sql
self.parameters = parameters
@@ -69,12 +73,24 @@ def __init__(self, *, conn_id, sql, parameters=None, success=None, failure=None,
def _get_hook(self):
conn = BaseHook.get_connection(self.conn_id)
- allowed_conn_type = {'google_cloud_platform', 'jdbc', 'mssql',
- 'mysql', 'odbc', 'oracle', 'postgres',
- 'presto', 'snowflake', 'sqlite', 'vertica'}
+ allowed_conn_type = {
+ 'google_cloud_platform',
+ 'jdbc',
+ 'mssql',
+ 'mysql',
+ 'odbc',
+ 'oracle',
+ 'postgres',
+ 'presto',
+ 'snowflake',
+ 'sqlite',
+ 'vertica',
+ }
if conn.conn_type not in allowed_conn_type:
- raise AirflowException("The connection type is not supported by SqlSensor. " +
- "Supported connection types: {}".format(list(allowed_conn_type)))
+ raise AirflowException(
+ "The connection type is not supported by SqlSensor. "
+ + "Supported connection types: {}".format(list(allowed_conn_type))
+ )
return conn.get_hook()
def poke(self, context):
@@ -91,8 +107,7 @@ def poke(self, context):
if self.failure is not None:
if callable(self.failure):
if self.failure(first_cell):
- raise AirflowException(
- f"Failure criteria met. self.failure({first_cell}) returned True")
+ raise AirflowException(f"Failure criteria met. self.failure({first_cell}) returned True")
else:
raise AirflowException(f"self.failure is present, but not callable -> {self.failure}")
if self.success is not None:
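As a quick illustration of the poke logic reformatted above, here is a standalone sketch (not SqlSensor itself) of how optional success/failure callables are applied to the first cell of the first returned row; evaluate_first_cell and the plain exception types are made up for the demo.

def evaluate_first_cell(first_cell, success=None, failure=None):
    # Failure criteria are checked first and abort the poke outright.
    if failure is not None:
        if not callable(failure):
            raise TypeError(f"failure is present, but not callable -> {failure}")
        if failure(first_cell):
            raise RuntimeError(f"Failure criteria met. failure({first_cell}) returned True")
    # Success criteria decide whether the sensor condition is met.
    if success is not None:
        if not callable(success):
            raise TypeError(f"success is present, but not callable -> {success}")
        return bool(success(first_cell))
    return bool(first_cell)


assert evaluate_first_cell(5, success=lambda cell: cell > 0)   # poke succeeds
assert not evaluate_first_cell(0)                              # empty/zero result keeps poking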
diff --git a/airflow/sensors/web_hdfs_sensor.py b/airflow/sensors/web_hdfs_sensor.py
index 7ad6cae83f17a..015d0ab684c2a 100644
--- a/airflow/sensors/web_hdfs_sensor.py
+++ b/airflow/sensors/web_hdfs_sensor.py
@@ -24,5 +24,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.hdfs.sensors.web_hdfs`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/sensors/weekday_sensor.py b/airflow/sensors/weekday_sensor.py
index 0340794664587..cadbeaa1a009d 100644
--- a/airflow/sensors/weekday_sensor.py
+++ b/airflow/sensors/weekday_sensor.py
@@ -73,9 +73,7 @@ class DayOfWeekSensor(BaseSensorOperator):
"""
@apply_defaults
- def __init__(self, *, week_day,
- use_task_execution_day=False,
- **kwargs):
+ def __init__(self, *, week_day, use_task_execution_day=False, **kwargs):
super().__init__(**kwargs)
self.week_day = week_day
self.use_task_execution_day = use_task_execution_day
@@ -91,12 +89,15 @@ def __init__(self, *, week_day,
else:
raise TypeError(
'Unsupported Type for week_day parameter: {}. It should be one of str'
- ', set or Weekday enum type'.format(type(week_day)))
+ ', set or Weekday enum type'.format(type(week_day))
+ )
def poke(self, context):
- self.log.info('Poking until weekday is in %s, Today is %s',
- self.week_day,
- WeekDay(timezone.utcnow().isoweekday()).name)
+ self.log.info(
+ 'Poking until weekday is in %s, Today is %s',
+ self.week_day,
+ WeekDay(timezone.utcnow().isoweekday()).name,
+ )
if self.use_task_execution_day:
return context['execution_date'].isoweekday() in self._week_day_num
else:
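A tiny standalone sketch of the weekday check above, using only the standard library; WEEK_DAY_NUM and the sample dates are hypothetical, and the real sensor resolves names through airflow.utils.weekday.WeekDay and a timezone-aware clock.

import datetime

WEEK_DAY_NUM = {1, 5}  # hypothetical targets: Monday and Friday (ISO numbering)


def poke(execution_date=None, use_task_execution_day=False):
    # Compare either the task's execution date or "today" against the target set.
    if use_task_execution_day and execution_date is not None:
        return execution_date.isoweekday() in WEEK_DAY_NUM
    return datetime.datetime.utcnow().isoweekday() in WEEK_DAY_NUM


print(poke(datetime.datetime(2020, 10, 7), use_task_execution_day=True))  # Wednesday -> False
print(poke(datetime.datetime(2020, 10, 9), use_task_execution_day=True))  # Friday -> True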
diff --git a/airflow/sentry.py b/airflow/sentry.py
index 253534383e941..f16b05722f85a 100644
--- a/airflow/sentry.py
+++ b/airflow/sentry.py
@@ -50,6 +50,7 @@ def flush(self):
Sentry: DummySentry = DummySentry()
if conf.getboolean("sentry", 'sentry_on', fallback=False):
import sentry_sdk
+
# Verify blinker installation
from blinker import signal # noqa: F401 pylint: disable=unused-import
from sentry_sdk.integrations.flask import FlaskIntegration
@@ -58,14 +59,19 @@ def flush(self):
class ConfiguredSentry(DummySentry):
"""Configure Sentry SDK."""
- SCOPE_TAGS = frozenset(
- ("task_id", "dag_id", "execution_date", "operator", "try_number")
- )
+ SCOPE_TAGS = frozenset(("task_id", "dag_id", "execution_date", "operator", "try_number"))
SCOPE_CRUMBS = frozenset(("task_id", "state", "operator", "duration"))
UNSUPPORTED_SENTRY_OPTIONS = frozenset(
- ("integrations", "in_app_include", "in_app_exclude", "ignore_errors",
- "before_breadcrumb", "before_send", "transport")
+ (
+ "integrations",
+ "in_app_include",
+ "in_app_exclude",
+ "ignore_errors",
+ "before_breadcrumb",
+ "before_send",
+ "transport",
+ )
)
def __init__(self):
@@ -94,12 +100,11 @@ def __init__(self):
# supports backward compatibility with the old-style dsn option
dsn = old_way_dsn or new_way_dsn
- unsupported_options = self.UNSUPPORTED_SENTRY_OPTIONS.intersection(
- sentry_config_opts.keys())
+ unsupported_options = self.UNSUPPORTED_SENTRY_OPTIONS.intersection(sentry_config_opts.keys())
if unsupported_options:
log.warning(
"There are unsupported options in [sentry] section: %s",
- ", ".join(unsupported_options)
+ ", ".join(unsupported_options),
)
if dsn:
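The frozenset in the sentry.py hunk above exists to strip options Airflow configures itself before handing the rest to sentry_sdk.init(). A hedged, standalone sketch of that filtering (the option values are invented, and no Sentry call is made here):

import logging

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger(__name__)

UNSUPPORTED_SENTRY_OPTIONS = frozenset(
    (
        "integrations",
        "in_app_include",
        "in_app_exclude",
        "ignore_errors",
        "before_breadcrumb",
        "before_send",
        "transport",
    )
)

sentry_config_opts = {"dsn": "https://example@sentry.invalid/1", "before_send": "my.hook"}
unsupported_options = UNSUPPORTED_SENTRY_OPTIONS.intersection(sentry_config_opts.keys())
if unsupported_options:
    log.warning("There are unsupported options in [sentry] section: %s", ", ".join(unsupported_options))

safe_opts = {k: v for k, v in sentry_config_opts.items() if k not in UNSUPPORTED_SENTRY_OPTIONS}
print(safe_opts)  # only the options Sentry is allowed to see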
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 5b0def99ce751..7ce5e1a1d7f3c 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -43,6 +43,7 @@
# isort: off
from kubernetes.client import models as k8s
from airflow.kubernetes.pod_generator import PodGenerator
+
# isort: on
HAS_KUBERNETES = True
except ImportError:
@@ -99,8 +100,9 @@ def from_json(cls, serialized_obj: str) -> Union['BaseSerialization', dict, list
return cls.from_dict(json.loads(serialized_obj))
@classmethod
- def from_dict(cls, serialized_obj: Dict[Encoding, Any]) -> \
- Union['BaseSerialization', dict, list, set, tuple]:
+ def from_dict(
+ cls, serialized_obj: Dict[Encoding, Any]
+ ) -> Union['BaseSerialization', dict, list, set, tuple]:
"""Deserializes a python dict stored with type decorators and
reconstructs all DAGs and operators it contains.
"""
@@ -138,14 +140,14 @@ def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool:
return True
return cls._value_is_hardcoded_default(attrname, var, instance)
- return (
- isinstance(var, cls._excluded_types) or
- cls._value_is_hardcoded_default(attrname, var, instance)
+ return isinstance(var, cls._excluded_types) or cls._value_is_hardcoded_default(
+ attrname, var, instance
)
@classmethod
- def serialize_to_json(cls, object_to_serialize: Union[BaseOperator, DAG], decorated_fields: Set) \
- -> Dict[str, Any]:
+ def serialize_to_json(
+ cls, object_to_serialize: Union[BaseOperator, DAG], decorated_fields: Set
+ ) -> Dict[str, Any]:
"""Serializes an object to json"""
serialized_object: Dict[str, Any] = {}
keys_to_serialize = object_to_serialize.get_serialized_fields()
@@ -184,10 +186,7 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r
return var.value
return var
elif isinstance(var, dict):
- return cls._encode(
- {str(k): cls._serialize(v) for k, v in var.items()},
- type_=DAT.DICT
- )
+ return cls._encode({str(k): cls._serialize(v) for k, v in var.items()}, type_=DAT.DICT)
elif isinstance(var, list):
return [cls._serialize(v) for v in var]
elif HAS_KUBERNETES and isinstance(var, k8s.V1Pod):
@@ -215,12 +214,10 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r
return str(get_python_source(var))
elif isinstance(var, set):
# FIXME: casts set to list in customized serialization in future.
- return cls._encode(
- [cls._serialize(v) for v in var], type_=DAT.SET)
+ return cls._encode([cls._serialize(v) for v in var], type_=DAT.SET)
elif isinstance(var, tuple):
# FIXME: casts tuple to list in customized serialization in future.
- return cls._encode(
- [cls._serialize(v) for v in var], type_=DAT.TUPLE)
+ return cls._encode([cls._serialize(v) for v in var], type_=DAT.TUPLE)
elif isinstance(var, TaskGroup):
return SerializedTaskGroup.serialize_task_group(var)
else:
@@ -256,9 +253,7 @@ def _deserialize(cls, encoded_var: Any) -> Any: # pylint: disable=too-many-retu
return pendulum.from_timestamp(var)
elif type_ == DAT.POD:
if not HAS_KUBERNETES:
- raise RuntimeError(
- "Cannot deserialize POD objects without kubernetes libraries installed!"
- )
+ raise RuntimeError("Cannot deserialize POD objects without kubernetes libraries installed!")
pod = PodGenerator.deserialize_model_dict(var)
return pod
elif type_ == DAT.TIMEDELTA:
@@ -306,8 +301,9 @@ def _value_is_hardcoded_default(cls, attrname: str, value: Any, instance: Any) -
``field = field or {}`` set.
"""
# pylint: disable=unused-argument
- if attrname in cls._CONSTRUCTOR_PARAMS and \
- (cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []])):
+ if attrname in cls._CONSTRUCTOR_PARAMS and (
+ cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []])
+ ):
return True
return False
@@ -322,7 +318,8 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
_decorated_fields = {'executor_config'}
_CONSTRUCTOR_PARAMS = {
- k: v.default for k, v in signature(BaseOperator.__init__).parameters.items()
+ k: v.default
+ for k, v in signature(BaseOperator.__init__).parameters.items()
if v.default is not v.empty
}
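The _CONSTRUCTOR_PARAMS dict comprehension above collects the default values of BaseOperator.__init__ so serialization can skip attributes that still hold their defaults. A self-contained illustration of the same inspect.signature pattern, using a hypothetical sample_operator function:

from inspect import signature


def sample_operator(task_id, retries=0, queue='default', executor_config=None):
    """Hypothetical stand-in for BaseOperator.__init__."""


CONSTRUCTOR_PARAMS = {
    k: v.default
    for k, v in signature(sample_operator).parameters.items()
    if v.default is not v.empty
}
print(CONSTRUCTOR_PARAMS)  # {'retries': 0, 'queue': 'default', 'executor_config': None}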
@@ -354,8 +351,9 @@ def serialize_operator(cls, op: BaseOperator) -> dict:
serialize_op['_task_type'] = op.__class__.__name__
serialize_op['_task_module'] = op.__class__.__module__
if op.operator_extra_links:
- serialize_op['_operator_extra_links'] = \
- cls._serialize_operator_extra_links(op.operator_extra_links)
+ serialize_op['_operator_extra_links'] = cls._serialize_operator_extra_links(
+ op.operator_extra_links
+ )
# Store all template_fields as they are if there are JSON Serializable
# If not, store them as strings
@@ -371,6 +369,7 @@ def serialize_operator(cls, op: BaseOperator) -> dict:
def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator:
"""Deserializes an operator from a JSON object."""
from airflow import plugins_manager
+
plugins_manager.initialize_extra_operators_links_plugins()
if plugins_manager.operator_extra_links is None:
@@ -386,8 +385,10 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator:
for ope in plugins_manager.operator_extra_links:
for operator in ope.operators:
- if operator.__name__ == encoded_op["_task_type"] and \
- operator.__module__ == encoded_op["_task_module"]:
+ if (
+ operator.__name__ == encoded_op["_task_type"]
+ and operator.__module__ == encoded_op["_task_module"]
+ ):
op_extra_links_from_plugin.update({ope.name: ope})
# If OperatorLinks are defined in Plugins but not in the Operator that is being Serialized
@@ -444,10 +445,7 @@ def _is_excluded(cls, var: Any, attrname: str, op: BaseOperator):
return super()._is_excluded(var, attrname, op)
@classmethod
- def _deserialize_operator_extra_links(
- cls,
- encoded_op_links: list
- ) -> Dict[str, BaseOperatorLink]:
+ def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> Dict[str, BaseOperatorLink]:
"""
Deserialize Operator Links if the Classes are registered in Airflow Plugins.
Error is raised if the OperatorLink is not found in Plugins too.
@@ -456,6 +454,7 @@ def _deserialize_operator_extra_links(
:return: De-Serialized Operator Link
"""
from airflow import plugins_manager
+
plugins_manager.initialize_extra_operators_links_plugins()
if plugins_manager.registered_operator_link_classes is None:
@@ -502,20 +501,14 @@ def _deserialize_operator_extra_links(
log.error("Operator Link class %r not registered", _operator_link_class_path)
return {}
- op_predefined_extra_link: BaseOperatorLink = cattr.structure(
- data, single_op_link_class)
+ op_predefined_extra_link: BaseOperatorLink = cattr.structure(data, single_op_link_class)
- op_predefined_extra_links.update(
- {op_predefined_extra_link.name: op_predefined_extra_link}
- )
+ op_predefined_extra_links.update({op_predefined_extra_link.name: op_predefined_extra_link})
return op_predefined_extra_links
@classmethod
- def _serialize_operator_extra_links(
- cls,
- operator_extra_links: Iterable[BaseOperatorLink]
- ):
+ def _serialize_operator_extra_links(cls, operator_extra_links: Iterable[BaseOperatorLink]):
"""
Serialize Operator Links. Store the import path of the OperatorLink and the arguments
passed to it. Example
@@ -531,8 +524,9 @@ def _serialize_operator_extra_links(
op_link_arguments = {}
serialize_operator_extra_links.append(
{
- "{}.{}".format(operator_extra_link.__class__.__module__,
- operator_extra_link.__class__.__name__): op_link_arguments
+ "{}.{}".format(
+ operator_extra_link.__class__.__module__, operator_extra_link.__class__.__name__
+ ): op_link_arguments
}
)
@@ -563,7 +557,8 @@ def __get_constructor_defaults(): # pylint: disable=no-method-argument
'access_control': '_access_control',
}
return {
- param_to_attr.get(k, k): v.default for k, v in signature(DAG.__init__).parameters.items()
+ param_to_attr.get(k, k): v.default
+ for k, v in signature(DAG.__init__).parameters.items()
if v.default is not v.empty
}
@@ -590,9 +585,7 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG':
if k == "_downstream_task_ids":
v = set(v)
elif k == "tasks":
- v = {
- task["task_id"]: SerializedBaseOperator.deserialize_operator(task) for task in v
- }
+ v = {task["task_id"]: SerializedBaseOperator.deserialize_operator(task) for task in v}
k = "task_dict"
elif k == "timezone":
v = cls._deserialize_timezone(v)
@@ -610,9 +603,7 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG':
# pylint: disable=protected-access
if "_task_group" in encoded_dag:
dag._task_group = SerializedTaskGroup.deserialize_task_group( # type: ignore
- encoded_dag["_task_group"],
- None,
- dag.task_dict
+ encoded_dag["_task_group"], None, dag.task_dict
)
else:
# This must be old data that had no task_group. Create a root TaskGroup and add
@@ -641,17 +632,15 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG':
for task_id in serializable_task.downstream_task_ids:
# Bypass set_upstream etc here - it does more than we want
- dag.task_dict[task_id]._upstream_task_ids.add(serializable_task.task_id) # noqa: E501 # pylint: disable=protected-access
+ # noqa: E501 # pylint: disable=protected-access
+ dag.task_dict[task_id]._upstream_task_ids.add(serializable_task.task_id)
return dag
@classmethod
def to_dict(cls, var: Any) -> dict:
"""Stringifies DAGs and operators contained by var and returns a dict of var."""
- json_dict = {
- "__version": cls.SERIALIZER_VERSION,
- "dag": cls.serialize_dag(var)
- }
+ json_dict = {"__version": cls.SERIALIZER_VERSION, "dag": cls.serialize_dag(var)}
# Validate Serialized DAG with Json Schema. Raises Error if it mismatches
cls.validate_schema(json_dict)
@@ -683,15 +672,14 @@ def serialize_task_group(cls, task_group: TaskGroup) -> Optional[Union[Dict[str,
"ui_fgcolor": task_group.ui_fgcolor,
"children": {
label: (DAT.OP, child.task_id)
- if isinstance(child, BaseOperator) else
- (DAT.TASK_GROUP, SerializedTaskGroup.serialize_task_group(child))
+ if isinstance(child, BaseOperator)
+ else (DAT.TASK_GROUP, SerializedTaskGroup.serialize_task_group(child))
for label, child in task_group.children.items()
},
"upstream_group_ids": cls._serialize(list(task_group.upstream_group_ids)),
"downstream_group_ids": cls._serialize(list(task_group.downstream_group_ids)),
"upstream_task_ids": cls._serialize(list(task_group.upstream_task_ids)),
"downstream_task_ids": cls._serialize(list(task_group.downstream_task_ids)),
-
}
return serialize_group
@@ -701,7 +689,7 @@ def deserialize_task_group(
cls,
encoded_group: Dict[str, Any],
parent_group: Optional[TaskGroup],
- task_dict: Dict[str, BaseOperator]
+ task_dict: Dict[str, BaseOperator],
) -> Optional[TaskGroup]:
"""Deserializes a TaskGroup from a JSON object."""
if not encoded_group:
@@ -712,15 +700,12 @@ def deserialize_task_group(
key: cls._deserialize(encoded_group[key])
for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"]
}
- group = SerializedTaskGroup(
- group_id=group_id,
- parent_group=parent_group,
- **kwargs
- )
+ group = SerializedTaskGroup(group_id=group_id, parent_group=parent_group, **kwargs)
group.children = {
- label: task_dict[val] if _type == DAT.OP # type: ignore
- else SerializedTaskGroup.deserialize_task_group(val, group, task_dict) for label, (_type, val)
- in encoded_group["children"].items()
+ label: task_dict[val]
+ if _type == DAT.OP # type: ignore
+ else SerializedTaskGroup.deserialize_task_group(val, group, task_dict)
+ for label, (_type, val) in encoded_group["children"].items()
}
group.upstream_group_ids = set(cls._deserialize(encoded_group["upstream_group_ids"]))
group.downstream_group_ids = set(cls._deserialize(encoded_group["downstream_group_ids"]))
diff --git a/airflow/settings.py b/airflow/settings.py
index 5c321095b3bae..1f38cfc1dad3b 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -49,13 +49,15 @@
log.info("Configured default timezone %s", TIMEZONE)
-HEADER = '\n'.join([
- r' ____________ _____________',
- r' ____ |__( )_________ __/__ /________ __',
- r'____ /| |_ /__ ___/_ /_ __ /_ __ \_ | /| / /',
- r'___ ___ | / _ / _ __/ _ / / /_/ /_ |/ |/ /',
- r' _/_/ |_/_/ /_/ /_/ /_/ \____/____/|__/',
-])
+HEADER = '\n'.join(
+ [
+ r' ____________ _____________',
+ r' ____ |__( )_________ __/__ /________ __',
+ r'____ /| |_ /__ ___/_ /_ __ /_ __ \_ | /| / /',
+ r'___ ___ | / _ / _ __/ _ / / /_/ /_ |/ |/ /',
+ r' _/_/ |_/_/ /_/ /_/ /_/ \____/____/|__/',
+ ]
+)
LOGGING_LEVEL = logging.INFO
@@ -146,11 +148,7 @@ def configure_vars():
SQL_ALCHEMY_CONN = conf.get('core', 'SQL_ALCHEMY_CONN')
DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
- PLUGINS_FOLDER = conf.get(
- 'core',
- 'plugins_folder',
- fallback=os.path.join(AIRFLOW_HOME, 'plugins')
- )
+ PLUGINS_FOLDER = conf.get('core', 'plugins_folder', fallback=os.path.join(AIRFLOW_HOME, 'plugins'))
def configure_orm(disable_connection_pool=False):
@@ -172,12 +170,14 @@ def configure_orm(disable_connection_pool=False):
engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args)
setup_event_handlers(engine)
- Session = scoped_session(sessionmaker(
- autocommit=False,
- autoflush=False,
- bind=engine,
- expire_on_commit=False,
- ))
+ Session = scoped_session(
+ sessionmaker(
+ autocommit=False,
+ autoflush=False,
+ bind=engine,
+ expire_on_commit=False,
+ )
+ )
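For context on the Session block just reformatted, here is a minimal sketch of the same scoped_session(sessionmaker(...)) wiring against an in-memory SQLite engine, assuming SQLAlchemy is installed; the real code also passes autocommit=False and binds to the configured SQL_ALCHEMY_CONN.

from sqlalchemy import create_engine, text
from sqlalchemy.orm import scoped_session, sessionmaker

engine = create_engine("sqlite://")  # stand-in for SQL_ALCHEMY_CONN
Session = scoped_session(
    sessionmaker(
        autoflush=False,
        bind=engine,
        expire_on_commit=False,
    )
)

session = Session()
print(session.execute(text("SELECT 1")).scalar())
Session.remove()  # scoped_session keeps one session per thread; remove() cleans it up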
def prepare_engine_args(disable_connection_pool=False):
@@ -218,8 +218,14 @@ def prepare_engine_args(disable_connection_pool=False):
# https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic
pool_pre_ping = conf.getboolean('core', 'SQL_ALCHEMY_POOL_PRE_PING', fallback=True)
- log.debug("settings.prepare_engine_args(): Using pool settings. pool_size=%d, max_overflow=%d, "
- "pool_recycle=%d, pid=%d", pool_size, max_overflow, pool_recycle, os.getpid())
+ log.debug(
+ "settings.prepare_engine_args(): Using pool settings. pool_size=%d, max_overflow=%d, "
+ "pool_recycle=%d, pid=%d",
+ pool_size,
+ max_overflow,
+ pool_recycle,
+ os.getpid(),
+ )
engine_args['pool_size'] = pool_size
engine_args['pool_recycle'] = pool_recycle
engine_args['pool_pre_ping'] = pool_pre_ping
@@ -244,18 +250,22 @@ def dispose_orm():
def configure_adapters():
"""Register Adapters and DB Converters"""
from pendulum import DateTime as Pendulum
+
try:
from sqlite3 import register_adapter
+
register_adapter(Pendulum, lambda val: val.isoformat(' '))
except ImportError:
pass
try:
import MySQLdb.converters
+
MySQLdb.converters.conversions[Pendulum] = MySQLdb.converters.DateTime2literal
except ImportError:
pass
try:
import pymysql.converters
+
pymysql.converters.conversions[Pendulum] = pymysql.converters.escape_datetime
except ImportError:
pass
@@ -334,6 +344,8 @@ def initialize():
# Ensure we close DB connections at scheduler and gunicorn worker terminations
atexit.register(dispose_orm)
+
+
# pylint: enable=global-statement
@@ -341,19 +353,16 @@ def initialize():
KILOBYTE = 1024
MEGABYTE = KILOBYTE * KILOBYTE
-WEB_COLORS = {'LIGHTBLUE': '#4d9de0',
- 'LIGHTORANGE': '#FF9933'}
+WEB_COLORS = {'LIGHTBLUE': '#4d9de0', 'LIGHTORANGE': '#FF9933'}
# Updating serialized DAG can not be faster than a minimum interval to reduce database
# write rate.
-MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint(
- 'core', 'min_serialized_dag_update_interval', fallback=30)
+MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint('core', 'min_serialized_dag_update_interval', fallback=30)
# Fetching serialized DAG can not be faster than a minimum interval to reduce database
# read rate. This config controls when your DAGs are updated in the Webserver
-MIN_SERIALIZED_DAG_FETCH_INTERVAL = conf.getint(
- 'core', 'min_serialized_dag_fetch_interval', fallback=10)
+MIN_SERIALIZED_DAG_FETCH_INTERVAL = conf.getint('core', 'min_serialized_dag_fetch_interval', fallback=10)
# Whether to persist DAG files code in DB. If set to True, Webserver reads file contents
# from DB instead of trying to access files in a DAG folder.
diff --git a/airflow/stats.py b/airflow/stats.py
index b7457da28cbe0..baa2cb5e15260 100644
--- a/airflow/stats.py
+++ b/airflow/stats.py
@@ -125,15 +125,26 @@ def stat_name_default_handler(stat_name, max_length=250) -> str:
if not isinstance(stat_name, str):
raise InvalidStatsNameException('The stat_name has to be a string')
if len(stat_name) > max_length:
- raise InvalidStatsNameException(textwrap.dedent("""\
+ raise InvalidStatsNameException(
+ textwrap.dedent(
+ """\
The stat_name ({stat_name}) has to be less than {max_length} characters.
- """.format(stat_name=stat_name, max_length=max_length)))
+ """.format(
+ stat_name=stat_name, max_length=max_length
+ )
+ )
+ )
if not all((c in ALLOWED_CHARACTERS) for c in stat_name):
- raise InvalidStatsNameException(textwrap.dedent("""\
+ raise InvalidStatsNameException(
+ textwrap.dedent(
+ """\
The stat name ({stat_name}) has to be composed with characters in
{allowed_characters}.
- """.format(stat_name=stat_name,
- allowed_characters=ALLOWED_CHARACTERS)))
+ """.format(
+ stat_name=stat_name, allowed_characters=ALLOWED_CHARACTERS
+ )
+ )
+ )
return stat_name
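A standalone approximation of the stat-name checks reformatted above: reject names that are not strings, are too long, or use characters outside an allowed set. ALLOWED_CHARACTERS here is illustrative, and a plain ValueError replaces InvalidStatsNameException.

import string

ALLOWED_CHARACTERS = set(string.ascii_letters + string.digits + '_.-')


def stat_name_default_handler(stat_name, max_length=250):
    if not isinstance(stat_name, str):
        raise ValueError('The stat_name has to be a string')
    if len(stat_name) > max_length:
        raise ValueError(f'The stat_name ({stat_name}) has to be less than {max_length} characters.')
    if not all(c in ALLOWED_CHARACTERS for c in stat_name):
        raise ValueError(f'The stat name ({stat_name}) has to be composed of allowed characters only.')
    return stat_name


print(stat_name_default_handler('dag_processing.total_parse_time'))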
@@ -149,6 +160,7 @@ def validate_stat(fn: T) -> T:
"""Check if stat name contains invalid characters.
Log and do not emit stats if the name is invalid
"""
+
@wraps(fn)
def wrapper(_self, stat, *args, **kwargs):
try:
@@ -315,7 +327,8 @@ def get_statsd_logger(cls):
statsd = stats_class(
host=conf.get('scheduler', 'statsd_host'),
port=conf.getint('scheduler', 'statsd_port'),
- prefix=conf.get('scheduler', 'statsd_prefix'))
+ prefix=conf.get('scheduler', 'statsd_prefix'),
+ )
allow_list_validator = AllowListValidator(conf.get('scheduler', 'statsd_allow_list', fallback=None))
return SafeStatsdLogger(statsd, allow_list_validator)
@@ -323,11 +336,13 @@ def get_statsd_logger(cls):
def get_dogstatsd_logger(cls):
"""Get DataDog statsd logger"""
from datadog import DogStatsd
+
dogstatsd = DogStatsd(
host=conf.get('scheduler', 'statsd_host'),
port=conf.getint('scheduler', 'statsd_port'),
namespace=conf.get('scheduler', 'statsd_prefix'),
- constant_tags=cls.get_constant_tags())
+ constant_tags=cls.get_constant_tags(),
+ )
dogstatsd_allow_list = conf.get('scheduler', 'statsd_allow_list', fallback=None)
allow_list_validator = AllowListValidator(dogstatsd_allow_list)
return SafeDogStatsdLogger(dogstatsd, allow_list_validator)
@@ -348,5 +363,6 @@ def get_constant_tags(cls):
if TYPE_CHECKING:
Stats: StatsLogger
else:
+
class Stats(metaclass=_Stats): # noqa: D101
"""Empty class for Stats - we use metaclass to inject the right one"""
diff --git a/airflow/task/task_runner/base_task_runner.py b/airflow/task/task_runner/base_task_runner.py
index 851d138620eda..743685ec54739 100644
--- a/airflow/task/task_runner/base_task_runner.py
+++ b/airflow/task/task_runner/base_task_runner.py
@@ -65,10 +65,7 @@ def __init__(self, local_task_job):
cfg_path = tmp_configuration_copy(chmod=0o600)
# Give ownership of file to user; only they can read and write
- subprocess.call(
- ['sudo', 'chown', self.run_as_user, cfg_path],
- close_fds=True
- )
+ subprocess.call(['sudo', 'chown', self.run_as_user, cfg_path], close_fds=True)
# propagate PYTHONPATH environment variable
pythonpath_value = os.environ.get(PYTHONPATH_VAR, '')
@@ -102,9 +99,12 @@ def _read_task_logs(self, stream):
line = line.decode('utf-8')
if not line:
break
- self.log.info('Job %s: Subtask %s %s',
- self._task_instance.job_id, self._task_instance.task_id,
- line.rstrip('\n'))
+ self.log.info(
+ 'Job %s: Subtask %s %s',
+ self._task_instance.job_id,
+ self._task_instance.task_id,
+ line.rstrip('\n'),
+ )
def run_command(self, run_with=None):
"""
@@ -128,7 +128,7 @@ def run_command(self, run_with=None):
universal_newlines=True,
close_fds=True,
env=os.environ.copy(),
- preexec_fn=os.setsid
+ preexec_fn=os.setsid,
)
# Start daemon thread to read subprocess logging output
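The run_command/_read_task_logs pair being reformatted above follows a common pattern: start the task in its own process session and stream its output from a daemon thread. A POSIX-only sketch with a placeholder command (not Airflow's actual task invocation):

import os
import subprocess
import sys
import threading


def _read_task_logs(stream):
    # Forward each subprocess line; an empty string means EOF in text mode.
    for line in iter(stream.readline, ''):
        print('Subtask:', line.rstrip('\n'))


proc = subprocess.Popen(
    [sys.executable, '-c', 'print("hello from the task process")'],
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    universal_newlines=True,
    close_fds=True,
    env=os.environ.copy(),
    preexec_fn=os.setsid,  # run the task in its own session (POSIX only)
)
reader = threading.Thread(target=_read_task_logs, args=(proc.stdout,), daemon=True)
reader.start()
proc.wait()
reader.join(timeout=5)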
diff --git a/airflow/task/task_runner/cgroup_task_runner.py b/airflow/task/task_runner/cgroup_task_runner.py
index ab1e88c16611a..6cd0ce6aa048b 100644
--- a/airflow/task/task_runner/cgroup_task_runner.py
+++ b/airflow/task/task_runner/cgroup_task_runner.py
@@ -92,8 +92,7 @@ def _create_cgroup(self, path):
node = node.create_cgroup(path_element)
else:
self.log.debug(
- "Not creating cgroup %s in %s since it already exists",
- path_element, node.path.decode()
+ "Not creating cgroup %s in %s since it already exists", path_element, node.path.decode()
)
node = name_to_node[path_element]
return node
@@ -122,20 +121,21 @@ def _delete_cgroup(self, path):
def start(self):
# Use bash if it's already in a cgroup
cgroups = self._get_cgroup_names()
- if ((cgroups.get("cpu") and cgroups.get("cpu") != "/") or
- (cgroups.get("memory") and cgroups.get("memory") != "/")):
+ if (cgroups.get("cpu") and cgroups.get("cpu") != "/") or (
+ cgroups.get("memory") and cgroups.get("memory") != "/"
+ ):
self.log.debug(
- "Already running in a cgroup (cpu: %s memory: %s) so not "
- "creating another one",
- cgroups.get("cpu"), cgroups.get("memory")
+ "Already running in a cgroup (cpu: %s memory: %s) so not creating another one",
+ cgroups.get("cpu"),
+ cgroups.get("memory"),
)
self.process = self.run_command()
return
# Create a unique cgroup name
- cgroup_name = "airflow/{}/{}".format(datetime.datetime.utcnow().
- strftime("%Y-%m-%d"),
- str(uuid.uuid4()))
+ cgroup_name = "airflow/{}/{}".format(
+ datetime.datetime.utcnow().strftime("%Y-%m-%d"), str(uuid.uuid4())
+ )
self.mem_cgroup_name = f"memory/{cgroup_name}"
self.cpu_cgroup_name = f"cpu/{cgroup_name}"
@@ -151,30 +151,19 @@ def start(self):
mem_cgroup_node = self._create_cgroup(self.mem_cgroup_name)
self._created_mem_cgroup = True
if self._mem_mb_limit > 0:
- self.log.debug(
- "Setting %s with %s MB of memory",
- self.mem_cgroup_name, self._mem_mb_limit
- )
+ self.log.debug("Setting %s with %s MB of memory", self.mem_cgroup_name, self._mem_mb_limit)
mem_cgroup_node.controller.limit_in_bytes = self._mem_mb_limit * 1024 * 1024
# Create the CPU cgroup
cpu_cgroup_node = self._create_cgroup(self.cpu_cgroup_name)
self._created_cpu_cgroup = True
if self._cpu_shares > 0:
- self.log.debug(
- "Setting %s with %s CPU shares",
- self.cpu_cgroup_name, self._cpu_shares
- )
+ self.log.debug("Setting %s with %s CPU shares", self.cpu_cgroup_name, self._cpu_shares)
cpu_cgroup_node.controller.shares = self._cpu_shares
# Start the process w/ cgroups
- self.log.debug(
- "Starting task process with cgroups cpu,memory: %s",
- cgroup_name
- )
- self.process = self.run_command(
- ['cgexec', '-g', f'cpu,memory:{cgroup_name}']
- )
+ self.log.debug("Starting task process with cgroups cpu,memory: %s", cgroup_name)
+ self.process = self.run_command(['cgexec', '-g', f'cpu,memory:{cgroup_name}'])
def return_code(self):
return_code = self.process.poll()
@@ -186,10 +175,12 @@ def return_code(self):
# I wasn't able to track down the root cause of the package install failures, but
# we might want to revisit that approach at some other point.
if return_code == 137:
- self.log.error("Task failed with return code of 137. This may indicate "
- "that it was killed due to excessive memory usage. "
- "Please consider optimizing your task or using the "
- "resources argument to reserve more memory for your task")
+ self.log.error(
+ "Task failed with return code of 137. This may indicate "
+ "that it was killed due to excessive memory usage. "
+ "Please consider optimizing your task or using the "
+ "resources argument to reserve more memory for your task"
+ )
return return_code
def terminate(self):
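To make the cgroup naming above easier to follow, this is a standalone sketch of how the per-task path is built and then handed to cgexec; nothing here touches real cgroups, and the printed command is only what the runner would prepend.

import datetime
import uuid

cgroup_name = "airflow/{}/{}".format(
    datetime.datetime.utcnow().strftime("%Y-%m-%d"), str(uuid.uuid4())
)
mem_cgroup_name = f"memory/{cgroup_name}"
cpu_cgroup_name = f"cpu/{cgroup_name}"

# The runner would then exec the task under both controllers:
run_with = ['cgexec', '-g', f'cpu,memory:{cgroup_name}']
print(mem_cgroup_name, cpu_cgroup_name, run_with)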
diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py
index 5534863f60550..a44d14bd7b48f 100644
--- a/airflow/ti_deps/dep_context.py
+++ b/airflow/ti_deps/dep_context.py
@@ -64,16 +64,17 @@ class DepContext:
"""
def __init__(
- self,
- deps=None,
- flag_upstream_failed: bool = False,
- ignore_all_deps: bool = False,
- ignore_depends_on_past: bool = False,
- ignore_in_retry_period: bool = False,
- ignore_in_reschedule_period: bool = False,
- ignore_task_deps: bool = False,
- ignore_ti_state: bool = False,
- finished_tasks=None):
+ self,
+ deps=None,
+ flag_upstream_failed: bool = False,
+ ignore_all_deps: bool = False,
+ ignore_depends_on_past: bool = False,
+ ignore_in_retry_period: bool = False,
+ ignore_in_reschedule_period: bool = False,
+ ignore_task_deps: bool = False,
+ ignore_ti_state: bool = False,
+ finished_tasks=None,
+ ):
self.deps = deps or set()
self.flag_upstream_failed = flag_upstream_failed
self.ignore_all_deps = ignore_all_deps
diff --git a/airflow/ti_deps/dependencies_deps.py b/airflow/ti_deps/dependencies_deps.py
index 3b37ffa31e1a9..7062995887be6 100644
--- a/airflow/ti_deps/dependencies_deps.py
+++ b/airflow/ti_deps/dependencies_deps.py
@@ -16,7 +16,10 @@
# under the License.
from airflow.ti_deps.dependencies_states import (
- BACKFILL_QUEUEABLE_STATES, QUEUEABLE_STATES, RUNNABLE_STATES, SCHEDULEABLE_STATES,
+ BACKFILL_QUEUEABLE_STATES,
+ QUEUEABLE_STATES,
+ RUNNABLE_STATES,
+ SCHEDULEABLE_STATES,
)
from airflow.ti_deps.deps.dag_ti_slots_available_dep import DagTISlotsAvailableDep
from airflow.ti_deps.deps.dag_unpaused_dep import DagUnpausedDep
diff --git a/airflow/ti_deps/deps/base_ti_dep.py b/airflow/ti_deps/deps/base_ti_dep.py
index 335c55bc3e5f2..b8810c4592770 100644
--- a/airflow/ti_deps/deps/base_ti_dep.py
+++ b/airflow/ti_deps/deps/base_ti_dep.py
@@ -91,13 +91,11 @@ def get_dep_statuses(self, ti, session, dep_context=None):
dep_context = DepContext()
if self.IGNOREABLE and dep_context.ignore_all_deps:
- yield self._passing_status(
- reason="Context specified all dependencies should be ignored.")
+ yield self._passing_status(reason="Context specified all dependencies should be ignored.")
return
if self.IS_TASK_DEP and dep_context.ignore_task_deps:
- yield self._passing_status(
- reason="Context specified all task dependencies should be ignored.")
+ yield self._passing_status(reason="Context specified all task dependencies should be ignored.")
return
yield from self._get_dep_statuses(ti, session, dep_context)
@@ -117,8 +115,7 @@ def is_met(self, ti, session, dep_context=None):
state that can be used by this dependency.
:type dep_context: BaseDepContext
"""
- return all(status.passed for status in
- self.get_dep_statuses(ti, session, dep_context))
+ return all(status.passed for status in self.get_dep_statuses(ti, session, dep_context))
@provide_session
def get_failure_reasons(self, ti, session, dep_context=None):
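The one-liner that is_met() collapses to above is easy to demonstrate in isolation: a dependency is met only if every yielded status passed. DepStatus and the sample statuses below are hypothetical stand-ins for Airflow's TIDepStatus objects.

from collections import namedtuple

DepStatus = namedtuple('DepStatus', ['dep_name', 'passed', 'reason'])


def get_dep_statuses():
    yield DepStatus('NotAlreadyRunning', True, 'Task is not running.')
    yield DepStatus('TriggerRule', False, 'Upstream tasks have not finished.')


def is_met(statuses):
    # One failing status is enough to block the task instance.
    return all(status.passed for status in statuses)


print(is_met(get_dep_statuses()))  # False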
diff --git a/airflow/ti_deps/deps/dag_ti_slots_available_dep.py b/airflow/ti_deps/deps/dag_ti_slots_available_dep.py
index e48de027ffee9..dfc96777472e1 100644
--- a/airflow/ti_deps/deps/dag_ti_slots_available_dep.py
+++ b/airflow/ti_deps/deps/dag_ti_slots_available_dep.py
@@ -31,5 +31,5 @@ def _get_dep_statuses(self, ti, session, dep_context):
if ti.task.dag.get_concurrency_reached(session):
yield self._failing_status(
reason="The maximum number of running tasks ({}) for this task's DAG "
- "'{}' has been reached.".format(ti.task.dag.concurrency,
- ti.dag_id))
+ "'{}' has been reached.".format(ti.task.dag.concurrency, ti.dag_id)
+ )
diff --git a/airflow/ti_deps/deps/dag_unpaused_dep.py b/airflow/ti_deps/deps/dag_unpaused_dep.py
index b65acc3f6cc26..7d33398c9adf4 100644
--- a/airflow/ti_deps/deps/dag_unpaused_dep.py
+++ b/airflow/ti_deps/deps/dag_unpaused_dep.py
@@ -29,5 +29,4 @@ class DagUnpausedDep(BaseTIDep):
@provide_session
def _get_dep_statuses(self, ti, session, dep_context):
if ti.task.dag.get_is_paused(session):
- yield self._failing_status(
- reason=f"Task's DAG '{ti.dag_id}' is paused.")
+ yield self._failing_status(reason=f"Task's DAG '{ti.dag_id}' is paused.")
diff --git a/airflow/ti_deps/deps/dagrun_exists_dep.py b/airflow/ti_deps/deps/dagrun_exists_dep.py
index bb1528170e99b..f244f15bdd05b 100644
--- a/airflow/ti_deps/deps/dagrun_exists_dep.py
+++ b/airflow/ti_deps/deps/dagrun_exists_dep.py
@@ -34,24 +34,22 @@ def _get_dep_statuses(self, ti, session, dep_context):
if not dagrun:
# The import is needed here to avoid a circular dependency
from airflow.models.dagrun import DagRun
+
running_dagruns = DagRun.find(
- dag_id=dag.dag_id,
- state=State.RUNNING,
- external_trigger=False,
- session=session
+ dag_id=dag.dag_id, state=State.RUNNING, external_trigger=False, session=session
)
if len(running_dagruns) >= dag.max_active_runs:
- reason = ("The maximum number of active dag runs ({}) for this task "
- "instance's DAG '{}' has been reached.".format(
- dag.max_active_runs,
- ti.dag_id))
+ reason = (
+ "The maximum number of active dag runs ({}) for this task "
+ "instance's DAG '{}' has been reached.".format(dag.max_active_runs, ti.dag_id)
+ )
else:
reason = "Unknown reason"
- yield self._failing_status(
- reason=f"Task instance's dagrun did not exist: {reason}.")
+ yield self._failing_status(reason=f"Task instance's dagrun did not exist: {reason}.")
else:
if dagrun.state != State.RUNNING:
yield self._failing_status(
reason="Task instance's dagrun was not in the 'running' state but in "
- "the state '{}'.".format(dagrun.state))
+ "the state '{}'.".format(dagrun.state)
+ )
diff --git a/airflow/ti_deps/deps/dagrun_id_dep.py b/airflow/ti_deps/deps/dagrun_id_dep.py
index e01b975ba685c..2ed84f1f362ab 100644
--- a/airflow/ti_deps/deps/dagrun_id_dep.py
+++ b/airflow/ti_deps/deps/dagrun_id_dep.py
@@ -47,8 +47,9 @@ def _get_dep_statuses(self, ti, session, dep_context=None):
if not dagrun or not dagrun.run_id or dagrun.run_type != DagRunType.BACKFILL_JOB:
yield self._passing_status(
reason=f"Task's DagRun doesn't exist or run_id is either NULL "
- f"or run_type is not {DagRunType.BACKFILL_JOB}")
+ f"or run_type is not {DagRunType.BACKFILL_JOB}"
+ )
else:
yield self._failing_status(
- reason=f"Task's DagRun run_id is not NULL "
- f"and run type is {DagRunType.BACKFILL_JOB}")
+ reason=f"Task's DagRun run_id is not NULL and run type is {DagRunType.BACKFILL_JOB}"
+ )
diff --git a/airflow/ti_deps/deps/exec_date_after_start_date_dep.py b/airflow/ti_deps/deps/exec_date_after_start_date_dep.py
index c1ca967eaac02..4e33683bbebe9 100644
--- a/airflow/ti_deps/deps/exec_date_after_start_date_dep.py
+++ b/airflow/ti_deps/deps/exec_date_after_start_date_dep.py
@@ -31,14 +31,13 @@ def _get_dep_statuses(self, ti, session, dep_context):
if ti.task.start_date and ti.execution_date < ti.task.start_date:
yield self._failing_status(
reason="The execution date is {} but this is before the task's start "
- "date {}.".format(
- ti.execution_date.isoformat(),
- ti.task.start_date.isoformat()))
+ "date {}.".format(ti.execution_date.isoformat(), ti.task.start_date.isoformat())
+ )
- if (ti.task.dag and ti.task.dag.start_date and
- ti.execution_date < ti.task.dag.start_date):
+ if ti.task.dag and ti.task.dag.start_date and ti.execution_date < ti.task.dag.start_date:
yield self._failing_status(
reason="The execution date is {} but this is before the task's "
"DAG's start date {}.".format(
- ti.execution_date.isoformat(),
- ti.task.dag.start_date.isoformat()))
+ ti.execution_date.isoformat(), ti.task.dag.start_date.isoformat()
+ )
+ )
diff --git a/airflow/ti_deps/deps/not_in_retry_period_dep.py b/airflow/ti_deps/deps/not_in_retry_period_dep.py
index d19e0c8117d7e..08aa8ecb529b3 100644
--- a/airflow/ti_deps/deps/not_in_retry_period_dep.py
+++ b/airflow/ti_deps/deps/not_in_retry_period_dep.py
@@ -33,12 +33,12 @@ class NotInRetryPeriodDep(BaseTIDep):
def _get_dep_statuses(self, ti, session, dep_context):
if dep_context.ignore_in_retry_period:
yield self._passing_status(
- reason="The context specified that being in a retry period was permitted.")
+ reason="The context specified that being in a retry period was permitted."
+ )
return
if ti.state != State.UP_FOR_RETRY:
- yield self._passing_status(
- reason="The task instance was not marked for retrying.")
+ yield self._passing_status(reason="The task instance was not marked for retrying.")
return
# Calculate the date first so that it is always smaller than the timestamp used by
@@ -48,6 +48,6 @@ def _get_dep_statuses(self, ti, session, dep_context):
if ti.is_premature:
yield self._failing_status(
reason="Task is not ready for retry yet but will be retried "
- "automatically. Current date is {} and task will be retried "
- "at {}.".format(cur_date.isoformat(),
- next_task_retry_date.isoformat()))
+ "automatically. Current date is {} and task will be retried "
+ "at {}.".format(cur_date.isoformat(), next_task_retry_date.isoformat())
+ )
diff --git a/airflow/ti_deps/deps/not_previously_skipped_dep.py b/airflow/ti_deps/deps/not_previously_skipped_dep.py
index 4ecef93ad847b..08413dfd571d6 100644
--- a/airflow/ti_deps/deps/not_previously_skipped_dep.py
+++ b/airflow/ti_deps/deps/not_previously_skipped_dep.py
@@ -28,19 +28,18 @@ class NotPreviouslySkippedDep(BaseTIDep):
IGNORABLE = True
IS_TASK_DEP = True
- def _get_dep_statuses(
- self, ti, session, dep_context
- ): # pylint: disable=signature-differs
+ def _get_dep_statuses(self, ti, session, dep_context): # pylint: disable=signature-differs
from airflow.models.skipmixin import (
- XCOM_SKIPMIXIN_FOLLOWED, XCOM_SKIPMIXIN_KEY, XCOM_SKIPMIXIN_SKIPPED, SkipMixin,
+ XCOM_SKIPMIXIN_FOLLOWED,
+ XCOM_SKIPMIXIN_KEY,
+ XCOM_SKIPMIXIN_SKIPPED,
+ SkipMixin,
)
from airflow.utils.state import State
upstream = ti.task.get_direct_relatives(upstream=True)
- finished_tasks = dep_context.ensure_finished_tasks(
- ti.task.dag, ti.execution_date, session
- )
+ finished_tasks = dep_context.ensure_finished_tasks(ti.task.dag, ti.execution_date, session)
finished_task_ids = {t.task_id for t in finished_tasks}
@@ -50,9 +49,7 @@ def _get_dep_statuses(
# This can happen if the parent task has not yet run.
continue
- prev_result = ti.xcom_pull(
- task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session
- )
+ prev_result = ti.xcom_pull(task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session)
if prev_result is None:
# This can happen if the parent task has not yet run.
diff --git a/airflow/ti_deps/deps/pool_slots_available_dep.py b/airflow/ti_deps/deps/pool_slots_available_dep.py
index 2c58013f33304..17a8e75dd3c51 100644
--- a/airflow/ti_deps/deps/pool_slots_available_dep.py
+++ b/airflow/ti_deps/deps/pool_slots_available_dep.py
@@ -49,8 +49,8 @@ def _get_dep_statuses(self, ti, session, dep_context=None):
pools = session.query(Pool).filter(Pool.pool == pool_name).all()
if not pools:
yield self._failing_status(
- reason=("Tasks using non-existent pool '%s' will not be scheduled",
- pool_name))
+ reason=("Tasks using non-existent pool '%s' will not be scheduled", pool_name)
+ )
return
else:
# Controlled by UNIQUE key in slot_pool table,
@@ -62,12 +62,14 @@ def _get_dep_statuses(self, ti, session, dep_context=None):
if open_slots <= (ti.pool_slots - 1):
yield self._failing_status(
- reason=("Not scheduling since there are %s open slots in pool %s "
- "and require %s pool slots",
- open_slots, pool_name, ti.pool_slots)
+ reason=(
+ "Not scheduling since there are %s open slots in pool %s and require %s pool slots",
+ open_slots,
+ pool_name,
+ ti.pool_slots,
+ )
)
else:
yield self._passing_status(
- reason=(
- "There are enough open slots in %s to execute the task", pool_name)
+ reason=("There are enough open slots in %s to execute the task", pool_name)
)
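A simplified, self-contained version of the pool check above: scheduling is refused when the pool's open slots cannot cover the slots the task instance requests. The function name and messages are illustrative only.

def can_schedule(open_slots, requested_slots):
    if open_slots <= requested_slots - 1:
        return False, f"Not scheduling: only {open_slots} open slots, task needs {requested_slots}"
    return True, "There are enough open slots to execute the task"


print(can_schedule(open_slots=1, requested_slots=2))  # (False, ...)
print(can_schedule(open_slots=3, requested_slots=1))  # (True, ...)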
diff --git a/airflow/ti_deps/deps/prev_dagrun_dep.py b/airflow/ti_deps/deps/prev_dagrun_dep.py
index bdd2bf622be6c..d5e717afb95d4 100644
--- a/airflow/ti_deps/deps/prev_dagrun_dep.py
+++ b/airflow/ti_deps/deps/prev_dagrun_dep.py
@@ -35,26 +35,24 @@ class PrevDagrunDep(BaseTIDep):
def _get_dep_statuses(self, ti, session, dep_context):
if dep_context.ignore_depends_on_past:
yield self._passing_status(
- reason="The context specified that the state of past DAGs could be "
- "ignored.")
+ reason="The context specified that the state of past DAGs could be ignored."
+ )
return
if not ti.task.depends_on_past:
- yield self._passing_status(
- reason="The task did not have depends_on_past set.")
+ yield self._passing_status(reason="The task did not have depends_on_past set.")
return
# Don't depend on the previous task instance if we are the first task
dag = ti.task.dag
if dag.catchup:
if dag.previous_schedule(ti.execution_date) is None:
- yield self._passing_status(
- reason="This task does not have a schedule or is @once"
- )
+ yield self._passing_status(reason="This task does not have a schedule or is @once")
return
if dag.previous_schedule(ti.execution_date) < ti.task.start_date:
yield self._passing_status(
- reason="This task instance was the first task instance for its task.")
+ reason="This task instance was the first task instance for its task."
+ )
return
else:
dr = ti.get_dagrun(session=session)
@@ -62,25 +60,28 @@ def _get_dep_statuses(self, ti, session, dep_context):
if not last_dagrun:
yield self._passing_status(
- reason="This task instance was the first task instance for its task.")
+ reason="This task instance was the first task instance for its task."
+ )
return
previous_ti = ti.get_previous_ti(session=session)
if not previous_ti:
yield self._failing_status(
reason="depends_on_past is true for this task's DAG, but the previous "
- "task instance has not run yet.")
+ "task instance has not run yet."
+ )
return
if previous_ti.state not in {State.SKIPPED, State.SUCCESS}:
yield self._failing_status(
reason="depends_on_past is true for this task, but the previous task "
- "instance {} is in the state '{}' which is not a successful "
- "state.".format(previous_ti, previous_ti.state))
+ "instance {} is in the state '{}' which is not a successful "
+ "state.".format(previous_ti, previous_ti.state)
+ )
previous_ti.task = ti.task
- if (ti.task.wait_for_downstream and
- not previous_ti.are_dependents_done(session=session)):
+ if ti.task.wait_for_downstream and not previous_ti.are_dependents_done(session=session):
yield self._failing_status(
reason="The tasks downstream of the previous task instance {} haven't "
- "completed (and wait_for_downstream is True).".format(previous_ti))
+ "completed (and wait_for_downstream is True).".format(previous_ti)
+ )
diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py b/airflow/ti_deps/deps/ready_to_reschedule.py
index bdce41fb46e37..a4396070eeada 100644
--- a/airflow/ti_deps/deps/ready_to_reschedule.py
+++ b/airflow/ti_deps/deps/ready_to_reschedule.py
@@ -42,13 +42,14 @@ def _get_dep_statuses(self, ti, session, dep_context):
"""
if dep_context.ignore_in_reschedule_period:
yield self._passing_status(
- reason="The context specified that being in a reschedule period was "
- "permitted.")
+ reason="The context specified that being in a reschedule period was permitted."
+ )
return
if ti.state not in self.RESCHEDULEABLE_STATES:
yield self._passing_status(
- reason="The task instance is not in State_UP_FOR_RESCHEDULE or NONE state.")
+ reason="The task instance is not in State_UP_FOR_RESCHEDULE or NONE state."
+ )
return
task_reschedule = (
@@ -57,18 +58,17 @@ def _get_dep_statuses(self, ti, session, dep_context):
.first()
)
if not task_reschedule:
- yield self._passing_status(
- reason="There is no reschedule request for this task instance.")
+ yield self._passing_status(reason="There is no reschedule request for this task instance.")
return
now = timezone.utcnow()
next_reschedule_date = task_reschedule.reschedule_date
if now >= next_reschedule_date:
- yield self._passing_status(
- reason="Task instance id ready for reschedule.")
+ yield self._passing_status(reason="Task instance is ready for reschedule.")
return
yield self._failing_status(
reason="Task is not ready for reschedule yet but will be rescheduled "
- "automatically. Current date is {} and task will be rescheduled "
- "at {}.".format(now.isoformat(), next_reschedule_date.isoformat()))
+ "automatically. Current date is {} and task will be rescheduled "
+ "at {}.".format(now.isoformat(), next_reschedule_date.isoformat())
+ )
diff --git a/airflow/ti_deps/deps/runnable_exec_date_dep.py b/airflow/ti_deps/deps/runnable_exec_date_dep.py
index 20cb8b6995e1f..30d07c5d52392 100644
--- a/airflow/ti_deps/deps/runnable_exec_date_dep.py
+++ b/airflow/ti_deps/deps/runnable_exec_date_dep.py
@@ -36,21 +36,17 @@ def _get_dep_statuses(self, ti, session, dep_context):
if ti.execution_date > cur_date and not ti.task.dag.allow_future_exec_dates:
yield self._failing_status(
reason="Execution date {} is in the future (the current "
- "date is {}).".format(ti.execution_date.isoformat(),
- cur_date.isoformat()))
+ "date is {}).".format(ti.execution_date.isoformat(), cur_date.isoformat())
+ )
if ti.task.end_date and ti.execution_date > ti.task.end_date:
yield self._failing_status(
reason="The execution date is {} but this is after the task's end date "
- "{}.".format(
- ti.execution_date.isoformat(),
- ti.task.end_date.isoformat()))
+ "{}.".format(ti.execution_date.isoformat(), ti.task.end_date.isoformat())
+ )
- if (ti.task.dag and
- ti.task.dag.end_date and
- ti.execution_date > ti.task.dag.end_date):
+ if ti.task.dag and ti.task.dag.end_date and ti.execution_date > ti.task.dag.end_date:
yield self._failing_status(
reason="The execution date is {} but this is after the task's DAG's "
- "end date {}.".format(
- ti.execution_date.isoformat(),
- ti.task.dag.end_date.isoformat()))
+ "end date {}.".format(ti.execution_date.isoformat(), ti.task.dag.end_date.isoformat())
+ )
diff --git a/airflow/ti_deps/deps/task_concurrency_dep.py b/airflow/ti_deps/deps/task_concurrency_dep.py
index d1456f9112ce8..2824301771d81 100644
--- a/airflow/ti_deps/deps/task_concurrency_dep.py
+++ b/airflow/ti_deps/deps/task_concurrency_dep.py
@@ -34,10 +34,8 @@ def _get_dep_statuses(self, ti, session, dep_context):
return
if ti.get_num_running_task_instances(session) >= ti.task.task_concurrency:
- yield self._failing_status(reason="The max task concurrency "
- "has been reached.")
+ yield self._failing_status(reason="The max task concurrency has been reached.")
return
else:
- yield self._passing_status(reason="The max task concurrency "
- "has not been reached.")
+ yield self._passing_status(reason="The max task concurrency has not been reached.")
return
diff --git a/airflow/ti_deps/deps/task_not_running_dep.py b/airflow/ti_deps/deps/task_not_running_dep.py
index 184e65d1474fe..b0c9f3979bbbf 100644
--- a/airflow/ti_deps/deps/task_not_running_dep.py
+++ b/airflow/ti_deps/deps/task_not_running_dep.py
@@ -40,5 +40,4 @@ def _get_dep_statuses(self, ti, session, dep_context=None):
yield self._passing_status(reason="Task is not in running state.")
return
- yield self._failing_status(
- reason='Task is in the running state')
+ yield self._failing_status(reason='Task is in the running state')
diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index 3544781e21601..e350863590278 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -46,15 +46,19 @@ def _get_states_count_upstream_ti(ti, finished_tasks):
:type finished_tasks: list[airflow.models.TaskInstance]
"""
counter = Counter(task.state for task in finished_tasks if task.task_id in ti.task.upstream_task_ids)
- return counter.get(State.SUCCESS, 0), counter.get(State.SKIPPED, 0), counter.get(State.FAILED, 0), \
- counter.get(State.UPSTREAM_FAILED, 0), sum(counter.values())
+ return (
+ counter.get(State.SUCCESS, 0),
+ counter.get(State.SKIPPED, 0),
+ counter.get(State.FAILED, 0),
+ counter.get(State.UPSTREAM_FAILED, 0),
+ sum(counter.values()),
+ )
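The Counter-based helper just reformatted is simple to reproduce outside Airflow. In this sketch, FinishedTask and the state strings are simplified stand-ins for TaskInstance rows and the State constants:

from collections import Counter, namedtuple

FinishedTask = namedtuple('FinishedTask', ['task_id', 'state'])

upstream_task_ids = {'extract', 'transform'}
finished_tasks = [
    FinishedTask('extract', 'success'),
    FinishedTask('transform', 'failed'),
    FinishedTask('unrelated', 'success'),  # not upstream, ignored by the filter
]

counter = Counter(t.state for t in finished_tasks if t.task_id in upstream_task_ids)
successes, skipped, failed, upstream_failed, done = (
    counter.get('success', 0),
    counter.get('skipped', 0),
    counter.get('failed', 0),
    counter.get('upstream_failed', 0),
    sum(counter.values()),
)
print(successes, skipped, failed, upstream_failed, done)  # 1 0 1 0 2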
@provide_session
def _get_dep_statuses(self, ti, session, dep_context):
# Checking that all upstream dependencies have succeeded
if not ti.task.upstream_list:
- yield self._passing_status(
- reason="The task instance did not have any upstream tasks.")
+ yield self._passing_status(reason="The task instance did not have any upstream tasks.")
return
if ti.task.trigger_rule == TR.DUMMY:
@@ -62,8 +66,8 @@ def _get_dep_statuses(self, ti, session, dep_context):
return
# see if the task name is in the task upstream for our task
successes, skipped, failed, upstream_failed, done = self._get_states_count_upstream_ti(
- ti=ti,
- finished_tasks=dep_context.ensure_finished_tasks(ti.task.dag, ti.execution_date, session))
+ ti=ti, finished_tasks=dep_context.ensure_finished_tasks(ti.task.dag, ti.execution_date, session)
+ )
yield from self._evaluate_trigger_rule(
ti=ti,
@@ -73,19 +77,13 @@ def _get_dep_statuses(self, ti, session, dep_context):
upstream_failed=upstream_failed,
done=done,
flag_upstream_failed=dep_context.flag_upstream_failed,
- session=session)
+ session=session,
+ )
@provide_session
def _evaluate_trigger_rule( # pylint: disable=too-many-branches
- self,
- ti,
- successes,
- skipped,
- failed,
- upstream_failed,
- done,
- flag_upstream_failed,
- session):
+ self, ti, successes, skipped, failed, upstream_failed, done, flag_upstream_failed, session
+ ):
"""
Yields a dependency status that indicate whether the given task instance's trigger
rule was met.
@@ -115,8 +113,12 @@ def _evaluate_trigger_rule( # pylint: disable=too-many-branches
trigger_rule = task.trigger_rule
upstream_done = done >= upstream
upstream_tasks_state = {
- "total": upstream, "successes": successes, "skipped": skipped,
- "failed": failed, "upstream_failed": upstream_failed, "done": done
+ "total": upstream,
+ "successes": successes,
+ "skipped": skipped,
+ "failed": failed,
+ "upstream_failed": upstream_failed,
+ "done": done,
}
# TODO(aoen): Ideally each individual trigger rules would be its own class, but
# this isn't very feasible at the moment since the database queries need to be
@@ -154,68 +156,77 @@ def _evaluate_trigger_rule( # pylint: disable=too-many-branches
yield self._failing_status(
reason="Task's trigger rule '{}' requires one upstream "
"task success, but none were found. "
- "upstream_tasks_state={}, upstream_task_ids={}"
- .format(trigger_rule, upstream_tasks_state, task.upstream_task_ids))
+ "upstream_tasks_state={}, upstream_task_ids={}".format(
+ trigger_rule, upstream_tasks_state, task.upstream_task_ids
+ )
+ )
elif trigger_rule == TR.ONE_FAILED:
if not failed and not upstream_failed:
yield self._failing_status(
reason="Task's trigger rule '{}' requires one upstream "
"task failure, but none were found. "
- "upstream_tasks_state={}, upstream_task_ids={}"
- .format(trigger_rule, upstream_tasks_state, task.upstream_task_ids))
+ "upstream_tasks_state={}, upstream_task_ids={}".format(
+ trigger_rule, upstream_tasks_state, task.upstream_task_ids
+ )
+ )
elif trigger_rule == TR.ALL_SUCCESS:
num_failures = upstream - successes
if num_failures > 0:
yield self._failing_status(
reason="Task's trigger rule '{}' requires all upstream "
"tasks to have succeeded, but found {} non-success(es). "
- "upstream_tasks_state={}, upstream_task_ids={}"
- .format(trigger_rule, num_failures, upstream_tasks_state,
- task.upstream_task_ids))
+ "upstream_tasks_state={}, upstream_task_ids={}".format(
+ trigger_rule, num_failures, upstream_tasks_state, task.upstream_task_ids
+ )
+ )
elif trigger_rule == TR.ALL_FAILED:
num_successes = upstream - failed - upstream_failed
if num_successes > 0:
yield self._failing_status(
reason="Task's trigger rule '{}' requires all upstream "
"tasks to have failed, but found {} non-failure(s). "
- "upstream_tasks_state={}, upstream_task_ids={}"
- .format(trigger_rule, num_successes, upstream_tasks_state,
- task.upstream_task_ids))
+ "upstream_tasks_state={}, upstream_task_ids={}".format(
+ trigger_rule, num_successes, upstream_tasks_state, task.upstream_task_ids
+ )
+ )
elif trigger_rule == TR.ALL_DONE:
if not upstream_done:
yield self._failing_status(
reason="Task's trigger rule '{}' requires all upstream "
"tasks to have completed, but found {} task(s) that "
"were not done. upstream_tasks_state={}, "
- "upstream_task_ids={}"
- .format(trigger_rule, upstream_done, upstream_tasks_state,
- task.upstream_task_ids))
+ "upstream_task_ids={}".format(
+ trigger_rule, upstream_done, upstream_tasks_state, task.upstream_task_ids
+ )
+ )
elif trigger_rule == TR.NONE_FAILED:
num_failures = upstream - successes - skipped
if num_failures > 0:
yield self._failing_status(
reason="Task's trigger rule '{}' requires all upstream "
"tasks to have succeeded or been skipped, but found {} non-success(es). "
- "upstream_tasks_state={}, upstream_task_ids={}"
- .format(trigger_rule, num_failures, upstream_tasks_state,
- task.upstream_task_ids))
+ "upstream_tasks_state={}, upstream_task_ids={}".format(
+ trigger_rule, num_failures, upstream_tasks_state, task.upstream_task_ids
+ )
+ )
elif trigger_rule == TR.NONE_FAILED_OR_SKIPPED:
num_failures = upstream - successes - skipped
if num_failures > 0:
yield self._failing_status(
reason="Task's trigger rule '{}' requires all upstream "
"tasks to have succeeded or been skipped, but found {} non-success(es). "
- "upstream_tasks_state={}, upstream_task_ids={}"
- .format(trigger_rule, num_failures, upstream_tasks_state,
- task.upstream_task_ids))
+ "upstream_tasks_state={}, upstream_task_ids={}".format(
+ trigger_rule, num_failures, upstream_tasks_state, task.upstream_task_ids
+ )
+ )
elif trigger_rule == TR.NONE_SKIPPED:
if not upstream_done or (skipped > 0):
yield self._failing_status(
reason="Task's trigger rule '{}' requires all upstream "
"tasks to not have been skipped, but found {} task(s) skipped. "
- "upstream_tasks_state={}, upstream_task_ids={}"
- .format(trigger_rule, skipped, upstream_tasks_state,
- task.upstream_task_ids))
+ "upstream_tasks_state={}, upstream_task_ids={}".format(
+ trigger_rule, skipped, upstream_tasks_state, task.upstream_task_ids
+ )
+ )
else:
- yield self._failing_status(
- reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.")
+ yield self._failing_status(reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.")
diff --git a/airflow/ti_deps/deps/valid_state_dep.py b/airflow/ti_deps/deps/valid_state_dep.py
index 1609afb819e0c..ea39ae772f383 100644
--- a/airflow/ti_deps/deps/valid_state_dep.py
+++ b/airflow/ti_deps/deps/valid_state_dep.py
@@ -38,8 +38,7 @@ def __init__(self, valid_states):
super().__init__()
if not valid_states:
- raise AirflowException(
- 'ValidStatesDep received an empty set of valid states.')
+ raise AirflowException('ValidStatesDep received an empty set of valid states.')
self._valid_states = valid_states
def __eq__(self, other):
@@ -51,8 +50,7 @@ def __hash__(self):
@provide_session
def _get_dep_statuses(self, ti, session, dep_context):
if dep_context.ignore_ti_state:
- yield self._passing_status(
- reason="Context specified that state should be ignored.")
+ yield self._passing_status(reason="Context specified that state should be ignored.")
return
if ti.state in self._valid_states:
@@ -61,5 +59,5 @@ def _get_dep_statuses(self, ti, session, dep_context):
yield self._failing_status(
reason="Task is in the '{}' state which is not a valid state for "
- "execution. The task must be cleared in order to be run.".format(
- ti.state))
+ "execution. The task must be cleared in order to be run.".format(ti.state)
+ )
diff --git a/airflow/typing_compat.py b/airflow/typing_compat.py
index 935dd0ea4de2b..6fd6d8c8252f2 100644
--- a/airflow/typing_compat.py
+++ b/airflow/typing_compat.py
@@ -26,7 +26,9 @@
# python 3.8 we can safely remove this shim import after Airflow drops
# support for <3.8
from typing import ( # type: ignore # noqa # pylint: disable=unused-import
- Protocol, TypedDict, runtime_checkable,
+ Protocol,
+ TypedDict,
+ runtime_checkable,
)
except ImportError:
from typing_extensions import Protocol, TypedDict, runtime_checkable # type: ignore # noqa
diff --git a/airflow/utils/callback_requests.py b/airflow/utils/callback_requests.py
index fe8017c721fb8..5561955dcf20a 100644
--- a/airflow/utils/callback_requests.py
+++ b/airflow/utils/callback_requests.py
@@ -56,7 +56,7 @@ def __init__(
full_filepath: str,
simple_task_instance: SimpleTaskInstance,
is_failure_callback: Optional[bool] = True,
- msg: Optional[str] = None
+ msg: Optional[str] = None,
):
super().__init__(full_filepath=full_filepath, msg=msg)
self.simple_task_instance = simple_task_instance
@@ -80,7 +80,7 @@ def __init__(
dag_id: str,
execution_date: datetime,
is_failure_callback: Optional[bool] = True,
- msg: Optional[str] = None
+ msg: Optional[str] = None,
):
super().__init__(full_filepath=full_filepath, msg=msg)
self.dag_id = dag_id
diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py
index d1fdf63d880d4..0578acb7597bf 100644
--- a/airflow/utils/cli.py
+++ b/airflow/utils/cli.py
@@ -63,6 +63,7 @@ def action_logging(f: T) -> T:
:param f: function instance
:return: wrapped function
"""
+
@functools.wraps(f)
def wrapper(*args, **kwargs):
"""
@@ -76,8 +77,9 @@ def wrapper(*args, **kwargs):
if not args:
raise ValueError("Args should be set")
if not isinstance(args[0], Namespace):
- raise ValueError("1st positional argument should be argparse.Namespace instance,"
- f"but is {type(args[0])}")
+ raise ValueError(
+ "1st positional argument should be argparse.Namespace instance, " f"but is {type(args[0])}"
+ )
metrics = _build_metrics(f.__name__, args[0])
cli_action_loggers.on_pre_execution(**metrics)
try:
@@ -115,12 +117,17 @@ def _build_metrics(func_name, namespace):
if command.startswith(f'{sensitive_field}='):
full_command[idx] = f'{sensitive_field}={"*" * 8}'
- metrics = {'sub_command': func_name, 'start_datetime': datetime.utcnow(),
- 'full_command': f'{full_command}', 'user': getpass.getuser()}
+ metrics = {
+ 'sub_command': func_name,
+ 'start_datetime': datetime.utcnow(),
+ 'full_command': f'{full_command}',
+ 'user': getpass.getuser(),
+ }
if not isinstance(namespace, Namespace):
- raise ValueError("namespace argument should be argparse.Namespace instance,"
- f"but is {type(namespace)}")
+ raise ValueError(
+ "namespace argument should be argparse.Namespace instance," f"but is {type(namespace)}"
+ )
tmp_dic = vars(namespace)
metrics['dag_id'] = tmp_dic.get('dag_id')
metrics['task_id'] = tmp_dic.get('task_id')
@@ -135,7 +142,8 @@ def _build_metrics(func_name, namespace):
extra=extra,
task_id=metrics.get('task_id'),
dag_id=metrics.get('dag_id'),
- execution_date=metrics.get('execution_date'))
+ execution_date=metrics.get('execution_date'),
+ )
metrics['log'] = log
return metrics
@@ -157,7 +165,8 @@ def get_dag_by_file_location(dag_id: str):
if dag_model is None:
raise AirflowException(
'dag_id could not be found: {}. Either the dag did not exist or it failed to '
- 'parse.'.format(dag_id))
+ 'parse.'.format(dag_id)
+ )
dagbag = DagBag(dag_folder=dag_model.fileloc)
return dagbag.dags[dag_id]
@@ -168,7 +177,8 @@ def get_dag(subdir: Optional[str], dag_id: str) -> DAG:
if dag_id not in dagbag.dags:
raise AirflowException(
'dag_id could not be found: {}. Either the dag did not exist or it failed to '
- 'parse.'.format(dag_id))
+ 'parse.'.format(dag_id)
+ )
return dagbag.dags[dag_id]
@@ -181,7 +191,8 @@ def get_dags(subdir: Optional[str], dag_id: str, use_regex: bool = False):
if not matched_dags:
raise AirflowException(
'dag_id could not be found with regex: {}. Either the dag did not exist '
- 'or it failed to parse.'.format(dag_id))
+ 'or it failed to parse.'.format(dag_id)
+ )
return matched_dags
@@ -238,11 +249,9 @@ def sigquit_handler(sig, frame): # pylint: disable=unused-argument
id_to_name = {th.ident: th.name for th in threading.enumerate()}
code = []
for thread_id, stack in sys._current_frames().items(): # pylint: disable=protected-access
- code.append("\n# Thread: {}({})"
- .format(id_to_name.get(thread_id, ""), thread_id))
+ code.append("\n# Thread: {}({})".format(id_to_name.get(thread_id, ""), thread_id))
for filename, line_number, name, line in traceback.extract_stack(stack):
- code.append('File: "{}", line {}, in {}'
- .format(filename, line_number, name))
+ code.append(f'File: "{filename}", line {line_number}, in {name}')
if line:
code.append(f" {line.strip()}")
print("\n".join(code))
diff --git a/airflow/utils/compression.py b/airflow/utils/compression.py
index dff8c8d3ed7ab..aa39b031dd715 100644
--- a/airflow/utils/compression.py
+++ b/airflow/utils/compression.py
@@ -25,16 +25,16 @@
def uncompress_file(input_file_name, file_extension, dest_dir):
"""Uncompress gz and bz2 files"""
if file_extension.lower() not in ('.gz', '.bz2'):
- raise NotImplementedError("Received {} format. Only gz and bz2 "
- "files can currently be uncompressed."
- .format(file_extension))
+ raise NotImplementedError(
+ "Received {} format. Only gz and bz2 "
+ "files can currently be uncompressed.".format(file_extension)
+ )
if file_extension.lower() == '.gz':
fmodule = gzip.GzipFile
elif file_extension.lower() == '.bz2':
fmodule = bz2.BZ2File
- with fmodule(input_file_name, mode='rb') as f_compressed,\
- NamedTemporaryFile(dir=dest_dir,
- mode='wb',
- delete=False) as f_uncompressed:
+ with fmodule(input_file_name, mode='rb') as f_compressed, NamedTemporaryFile(
+ dir=dest_dir, mode='wb', delete=False
+ ) as f_uncompressed:
shutil.copyfileobj(f_compressed, f_uncompressed)
return f_uncompressed.name
diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py
index e1b0358f379ee..793d211c57a4c 100644
--- a/airflow/utils/dag_processing.py
+++ b/airflow/utils/dag_processing.py
@@ -194,13 +194,12 @@ def __init__(
dag_directory: str,
max_runs: int,
processor_factory: Callable[
- [str, List[CallbackRequest], Optional[List[str]], bool],
- AbstractDagFileProcessorProcess
+ [str, List[CallbackRequest], Optional[List[str]], bool], AbstractDagFileProcessorProcess
],
processor_timeout: timedelta,
dag_ids: Optional[List[str]],
pickle_dags: bool,
- async_mode: bool
+ async_mode: bool,
):
super().__init__()
self._file_path_queue: List[str] = []
@@ -241,8 +240,8 @@ def start(self) -> None:
child_signal_conn,
self._dag_ids,
self._pickle_dags,
- self._async_mode
- )
+ self._async_mode,
+ ),
)
self._process = process
@@ -326,15 +325,12 @@ def wait_until_finished(self) -> None:
def _run_processor_manager(
dag_directory: str,
max_runs: int,
- processor_factory: Callable[
- [str, List[CallbackRequest]],
- AbstractDagFileProcessorProcess
- ],
+ processor_factory: Callable[[str, List[CallbackRequest]], AbstractDagFileProcessorProcess],
processor_timeout: timedelta,
signal_conn: MultiprocessingConnection,
dag_ids: Optional[List[str]],
pickle_dags: bool,
- async_mode: bool
+ async_mode: bool,
) -> None:
# Make this process start as a new process group - that makes it easy
@@ -355,14 +351,16 @@ def _run_processor_manager(
importlib.reload(airflow.settings)
airflow.settings.initialize()
del os.environ['CONFIG_PROCESSOR_MANAGER_LOGGER']
- processor_manager = DagFileProcessorManager(dag_directory,
- max_runs,
- processor_factory,
- processor_timeout,
- signal_conn,
- dag_ids,
- pickle_dags,
- async_mode)
+ processor_manager = DagFileProcessorManager(
+ dag_directory,
+ max_runs,
+ processor_factory,
+ processor_timeout,
+ signal_conn,
+ dag_ids,
+ pickle_dags,
+ async_mode,
+ )
processor_manager.start()
@@ -397,7 +395,8 @@ def _heartbeat_manager(self):
if not self.done:
self.log.warning(
"DagFileProcessorManager (PID=%d) exited with exit code %d - re-launching",
- self._process.pid, self._process.exitcode
+ self._process.pid,
+ self._process.exitcode,
)
self.start()
@@ -409,7 +408,9 @@ def _heartbeat_manager(self):
Stats.incr('dag_processing.manager_stalls')
self.log.error(
"DagFileProcessorManager (PID=%d) last sent a heartbeat %.2f seconds ago! Restarting it",
- self._process.pid, parsing_stat_age)
+ self._process.pid,
+ parsing_stat_age,
+ )
reap_process_group(self._process.pid, logger=self.log)
self.start()
@@ -486,18 +487,17 @@ class DagFileProcessorManager(LoggingMixin): # pylint: disable=too-many-instanc
:type async_mode: bool
"""
- def __init__(self,
- dag_directory: str,
- max_runs: int,
- processor_factory: Callable[
- [str, List[CallbackRequest]],
- AbstractDagFileProcessorProcess
- ],
- processor_timeout: timedelta,
- signal_conn: MultiprocessingConnection,
- dag_ids: Optional[List[str]],
- pickle_dags: bool,
- async_mode: bool = True):
+ def __init__(
+ self,
+ dag_directory: str,
+ max_runs: int,
+ processor_factory: Callable[[str, List[CallbackRequest]], AbstractDagFileProcessorProcess],
+ processor_timeout: timedelta,
+ signal_conn: MultiprocessingConnection,
+ dag_ids: Optional[List[str]],
+ pickle_dags: bool,
+ async_mode: bool = True,
+ ):
super().__init__()
self._file_paths: List[str] = []
self._file_path_queue: List[str] = []
@@ -514,20 +514,18 @@ def __init__(self,
if 'sqlite' in conf.get('core', 'sql_alchemy_conn') and self._parallelism > 1:
self.log.warning(
"Because we cannot use more than 1 thread (max_threads = "
- "%d ) when using sqlite. So we set parallelism to 1.", self._parallelism
+ "%d ) when using sqlite. So we set parallelism to 1.",
+ self._parallelism,
)
self._parallelism = 1
# Parse and schedule each file no faster than this interval.
- self._file_process_interval = conf.getint('scheduler',
- 'min_file_process_interval')
+ self._file_process_interval = conf.getint('scheduler', 'min_file_process_interval')
# How often to print out DAG file processing stats to the log. Default to
# 30 seconds.
- self.print_stats_interval = conf.getint('scheduler',
- 'print_stats_interval')
+ self.print_stats_interval = conf.getint('scheduler', 'print_stats_interval')
# How many seconds do we wait for tasks to heartbeat before mark them as zombies.
- self._zombie_threshold_secs = (
- conf.getint('scheduler', 'scheduler_zombie_task_threshold'))
+ self._zombie_threshold_secs = conf.getint('scheduler', 'scheduler_zombie_task_threshold')
# Should store dag file source in a database?
self.store_dag_code = STORE_DAG_CODE
@@ -641,8 +639,7 @@ def _run_parsing_loop(self):
# This shouldn't happen, as in sync mode poll should block for
# ever. Lets be defensive about that.
self.log.warning(
- "wait() unexpectedly returned nothing ready after infinite timeout (%r)!",
- poll_time
+ "wait() unexpectedly returned nothing ready after infinite timeout (%r)!", poll_time
)
continue
@@ -676,8 +673,7 @@ def _run_parsing_loop(self):
self._num_run += 1
if not self._async_mode:
- self.log.debug(
- "Waiting for processors to finish since we're using sqlite")
+ self.log.debug("Waiting for processors to finish since we're using sqlite")
# Wait until the running DAG processors are finished before
# sending a DagParsingStat message back. This means the Agent
# can tell we've got to the end of this iteration when it sees
@@ -692,15 +688,17 @@ def _run_parsing_loop(self):
all_files_processed = all(self.get_last_finish_time(x) is not None for x in self.file_paths)
max_runs_reached = self.max_runs_reached()
- dag_parsing_stat = DagParsingStat(self._file_paths,
- max_runs_reached,
- all_files_processed,
- )
+ dag_parsing_stat = DagParsingStat(
+ self._file_paths,
+ max_runs_reached,
+ all_files_processed,
+ )
self._signal_conn.send(dag_parsing_stat)
if max_runs_reached:
- self.log.info("Exiting dag parsing loop as all files "
- "have been processed %s times", self._max_runs)
+ self.log.info(
+ "Exiting dag parsing loop as all files " "have been processed %s times", self._max_runs
+ )
break
if self._async_mode:
@@ -740,12 +738,12 @@ def _refresh_dag_dir(self):
if self.store_dag_code:
from airflow.models.dagcode import DagCode
+
DagCode.remove_deleted_code(self._file_paths)
def _print_stat(self):
"""Occasionally print out stats about how fast the files are getting processed"""
- if 0 < self.print_stats_interval < (
- timezone.utcnow() - self.last_stat_print_time).total_seconds():
+ if 0 < self.print_stats_interval < (timezone.utcnow() - self.last_stat_print_time).total_seconds():
if self._file_paths:
self._log_file_processing_stats(self._file_paths)
self.last_stat_print_time = timezone.utcnow()
@@ -760,9 +758,7 @@ def clear_nonexistent_import_errors(self, session):
"""
query = session.query(errors.ImportError)
if self._file_paths:
- query = query.filter(
- ~errors.ImportError.filename.in_(self._file_paths)
- )
+ query = query.filter(~errors.ImportError.filename.in_(self._file_paths))
query.delete(synchronize_session='fetch')
session.commit()
@@ -783,13 +779,7 @@ def _log_file_processing_stats(self, known_file_paths):
# Last Runtime: If the process ran before, how long did it take to
# finish in seconds
# Last Run: When the file finished processing in the previous run.
- headers = ["File Path",
- "PID",
- "Runtime",
- "# DAGs",
- "# Errors",
- "Last Runtime",
- "Last Run"]
+ headers = ["File Path", "PID", "Runtime", "# DAGs", "# Errors", "Last Runtime", "Last Run"]
rows = []
now = timezone.utcnow()
@@ -802,7 +792,7 @@ def _log_file_processing_stats(self, known_file_paths):
processor_pid = self.get_pid(file_path)
processor_start_time = self.get_start_time(file_path)
- runtime = ((now - processor_start_time) if processor_start_time else None)
+ runtime = (now - processor_start_time) if processor_start_time else None
last_run = self.get_last_finish_time(file_path)
if last_run:
seconds_ago = (now - last_run).total_seconds()
@@ -812,34 +802,33 @@ def _log_file_processing_stats(self, known_file_paths):
# TODO: Remove before Airflow 2.0
Stats.timing(f'dag_processing.last_runtime.{file_name}', runtime)
- rows.append((file_path,
- processor_pid,
- runtime,
- num_dags,
- num_errors,
- last_runtime,
- last_run))
+ rows.append((file_path, processor_pid, runtime, num_dags, num_errors, last_runtime, last_run))
# Sort by longest last runtime. (Can't sort None values in python3)
rows = sorted(rows, key=lambda x: x[3] or 0.0)
formatted_rows = []
for file_path, pid, runtime, num_dags, num_errors, last_runtime, last_run in rows:
- formatted_rows.append((file_path,
- pid,
- f"{runtime.total_seconds():.2f}s" if runtime else None,
- num_dags,
- num_errors,
- f"{last_runtime:.2f}s" if last_runtime else None,
- last_run.strftime("%Y-%m-%dT%H:%M:%S") if last_run else None
- ))
- log_str = ("\n" +
- "=" * 80 +
- "\n" +
- "DAG File Processing Stats\n\n" +
- tabulate(formatted_rows, headers=headers) +
- "\n" +
- "=" * 80)
+ formatted_rows.append(
+ (
+ file_path,
+ pid,
+ f"{runtime.total_seconds():.2f}s" if runtime else None,
+ num_dags,
+ num_errors,
+ f"{last_runtime:.2f}s" if last_runtime else None,
+ last_run.strftime("%Y-%m-%dT%H:%M:%S") if last_run else None,
+ )
+ )
+ log_str = (
+ "\n"
+ + "=" * 80
+ + "\n"
+ + "DAG File Processing Stats\n\n"
+ + tabulate(formatted_rows, headers=headers)
+ + "\n"
+ + "=" * 80
+ )
self.log.info(log_str)
@@ -937,8 +926,7 @@ def set_file_paths(self, new_file_paths):
:return: None
"""
self._file_paths = new_file_paths
- self._file_path_queue = [x for x in self._file_path_queue
- if x in new_file_paths]
+ self._file_path_queue = [x for x in self._file_path_queue if x in new_file_paths]
# Stop processors that are working on deleted files
filtered_processors = {}
for file_path, processor in self._processors.items():
@@ -966,8 +954,7 @@ def _collect_results_from_processor(self, processor) -> None:
num_dags, count_import_errors = processor.result
else:
self.log.error(
- "Processor for %s exited with return code %s.",
- processor.file_path, processor.exit_code
+ "Processor for %s exited with return code %s.", processor.file_path, processor.exit_code
)
count_import_errors = -1
num_dags = 0
@@ -993,11 +980,9 @@ def collect_results(self) -> None:
self._processors.pop(processor.file_path)
self._collect_results_from_processor(processor)
- self.log.debug("%s/%s DAG parsing processes running",
- len(self._processors), self._parallelism)
+ self.log.debug("%s/%s DAG parsing processes running", len(self._processors), self._parallelism)
- self.log.debug("%s file paths queued for processing",
- len(self._file_path_queue))
+ self.log.debug("%s file paths queued for processing", len(self._file_path_queue))
def start_new_processes(self):
"""Start more processors if we have enough slots and files to process"""
@@ -1005,19 +990,14 @@ def start_new_processes(self):
file_path = self._file_path_queue.pop(0)
callback_to_execute_for_file = self._callback_to_execute[file_path]
processor = self._processor_factory(
- file_path,
- callback_to_execute_for_file,
- self._dag_ids,
- self._pickle_dags)
+ file_path, callback_to_execute_for_file, self._dag_ids, self._pickle_dags
+ )
del self._callback_to_execute[file_path]
Stats.incr('dag_processing.processes')
processor.start()
- self.log.debug(
- "Started a process (PID: %s) to generate tasks for %s",
- processor.pid, file_path
- )
+ self.log.debug("Started a process (PID: %s) to generate tasks for %s", processor.pid, file_path)
self._processors[file_path] = processor
self.waitables[processor.waitable_handle] = processor
@@ -1031,39 +1011,36 @@ def prepare_file_path_queue(self):
file_paths_recently_processed = []
for file_path in self._file_paths:
last_finish_time = self.get_last_finish_time(file_path)
- if (last_finish_time is not None and
- (now - last_finish_time).total_seconds() <
- self._file_process_interval):
+ if (
+ last_finish_time is not None
+ and (now - last_finish_time).total_seconds() < self._file_process_interval
+ ):
file_paths_recently_processed.append(file_path)
- files_paths_at_run_limit = [file_path
- for file_path, stat in self._file_stats.items()
- if stat.run_count == self._max_runs]
+ files_paths_at_run_limit = [
+ file_path for file_path, stat in self._file_stats.items() if stat.run_count == self._max_runs
+ ]
- files_paths_to_queue = list(set(self._file_paths) -
- set(file_paths_in_progress) -
- set(file_paths_recently_processed) -
- set(files_paths_at_run_limit))
+ files_paths_to_queue = list(
+ set(self._file_paths)
+ - set(file_paths_in_progress)
+ - set(file_paths_recently_processed)
+ - set(files_paths_at_run_limit)
+ )
for file_path, processor in self._processors.items():
self.log.debug(
"File path %s is still being processed (started: %s)",
- processor.file_path, processor.start_time.isoformat()
+ processor.file_path,
+ processor.start_time.isoformat(),
)
- self.log.debug(
- "Queuing the following files for processing:\n\t%s",
- "\n\t".join(files_paths_to_queue)
- )
+ self.log.debug("Queuing the following files for processing:\n\t%s", "\n\t".join(files_paths_to_queue))
for file_path in files_paths_to_queue:
if file_path not in self._file_stats:
self._file_stats[file_path] = DagFileStat(
- num_dags=0,
- import_errors=0,
- last_finish_time=None,
- last_duration=None,
- run_count=0
+ num_dags=0, import_errors=0, last_finish_time=None, last_duration=None, run_count=0
)
self._file_path_queue.extend(files_paths_to_queue)
@@ -1075,10 +1052,13 @@ def _find_zombies(self, session):
and update the current zombie list.
"""
now = timezone.utcnow()
- if not self._last_zombie_query_time or \
- (now - self._last_zombie_query_time).total_seconds() > self._zombie_query_interval:
+ if (
+ not self._last_zombie_query_time
+ or (now - self._last_zombie_query_time).total_seconds() > self._zombie_query_interval
+ ):
# to avoid circular imports
from airflow.jobs.local_task_job import LocalTaskJob as LJ
+
self.log.info("Finding 'running' jobs without a recent heartbeat")
TI = airflow.models.TaskInstance
DM = airflow.models.DagModel
@@ -1095,7 +1075,8 @@ def _find_zombies(self, session):
LJ.state != State.RUNNING,
LJ.latest_heartbeat < limit_dttm,
)
- ).all()
+ )
+ .all()
)
self._last_zombie_query_time = timezone.utcnow()
@@ -1116,9 +1097,11 @@ def _kill_timed_out_processors(self):
duration = now - processor.start_time
if duration > self._processor_timeout:
self.log.error(
- "Processor for %s with PID %s started at %s has timed out, "
- "killing it.",
- file_path, processor.pid, processor.start_time.isoformat())
+ "Processor for %s with PID %s started at %s has timed out, " "killing it.",
+ file_path,
+ processor.pid,
+ processor.start_time.isoformat(),
+ )
Stats.decr('dag_processing.processes')
Stats.incr('dag_processing.processor_timeouts')
# TODO: Remove after Airflow 2.0
@@ -1164,8 +1147,9 @@ def emit_metrics(self):
parse_time = (timezone.utcnow() - self._parsing_start_time).total_seconds()
Stats.gauge('dag_processing.total_parse_time', parse_time)
Stats.gauge('dagbag_size', sum(stat.num_dags for stat in self._file_stats.values()))
- Stats.gauge('dag_processing.import_errors',
- sum(stat.import_errors for stat in self._file_stats.values()))
+ Stats.gauge(
+ 'dag_processing.import_errors', sum(stat.import_errors for stat in self._file_stats.values())
+ )
# TODO: Remove before Airflow 2.0
Stats.gauge('collect_dags', parse_time)
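Several dag_processing.py hunks above replace backslash continuations and ad-hoc wrapping with Black's canonical shapes: an over-long call gets one argument per line plus a trailing comma, and an over-long boolean test becomes a single parenthesized expression with the operator leading each continuation line. A small self-contained sketch of the condition shape, with made-up values standing in for the scheduler settings:

    from datetime import datetime

    last_finish_time = datetime(2020, 10, 7, 20, 0, 0)  # illustrative values only
    now = datetime(2020, 10, 7, 20, 0, 30)
    file_process_interval = 60  # seconds

    # Long compound conditions come out as one parenthesized expression, one operand
    # per line with the boolean operator leading, instead of backslash continuations.
    if (
        last_finish_time is not None
        and (now - last_finish_time).total_seconds() < file_process_interval
    ):
        print("processed too recently, skipping")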
diff --git a/airflow/utils/dates.py b/airflow/utils/dates.py
index 2377e12cc2b1a..de5e52b668e57 100644
--- a/airflow/utils/dates.py
+++ b/airflow/utils/dates.py
@@ -244,11 +244,7 @@ def days_ago(n, hour=0, minute=0, second=0, microsecond=0):
Get a datetime object representing `n` days ago. By default the time is
set to midnight.
"""
- today = timezone.utcnow().replace(
- hour=hour,
- minute=minute,
- second=second,
- microsecond=microsecond)
+ today = timezone.utcnow().replace(hour=hour, minute=minute, second=second, microsecond=microsecond)
return today - timedelta(days=n)
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 479ced65cd5ca..92cb224a0ec2b 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -25,14 +25,34 @@
from airflow.configuration import conf
from airflow.jobs.base_job import BaseJob # noqa: F401 # pylint: disable=unused-import
from airflow.models import ( # noqa: F401 # pylint: disable=unused-import
- DAG, XCOM_RETURN_KEY, BaseOperator, BaseOperatorLink, Connection, DagBag, DagModel, DagPickle, DagRun,
- DagTag, Log, Pool, SkipMixin, SlaMiss, TaskFail, TaskInstance, TaskReschedule, Variable, XCom,
+ DAG,
+ XCOM_RETURN_KEY,
+ BaseOperator,
+ BaseOperatorLink,
+ Connection,
+ DagBag,
+ DagModel,
+ DagPickle,
+ DagRun,
+ DagTag,
+ Log,
+ Pool,
+ SkipMixin,
+ SlaMiss,
+ TaskFail,
+ TaskInstance,
+ TaskReschedule,
+ Variable,
+ XCom,
)
+
# We need to add this model manually to get reset working well
from airflow.models.serialized_dag import SerializedDagModel # noqa: F401 # pylint: disable=unused-import
+
# TODO: remove create_session once we decide to break backward compatibility
from airflow.utils.session import ( # noqa: F401 # pylint: disable=unused-import
- create_session, provide_session,
+ create_session,
+ provide_session,
)
log = logging.getLogger(__name__)
@@ -52,8 +72,7 @@ def add_default_pool_if_not_exists(session=None):
if not Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session):
default_pool = Pool(
pool=Pool.DEFAULT_POOL_NAME,
- slots=conf.getint(section='core', key='non_pooled_task_slot_count',
- fallback=128),
+ slots=conf.getint(section='core', key='non_pooled_task_slot_count', fallback=128),
description="Default pool",
)
session.add(default_pool)
@@ -72,14 +91,14 @@ def create_default_connections(session=None):
password="",
schema="airflow",
),
- session
+ session,
)
merge_conn(
Connection(
conn_id="aws_default",
conn_type="aws",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -87,7 +106,7 @@ def create_default_connections(session=None):
conn_type="azure_batch",
login="",
password="",
- extra='''{"account_url": ""}'''
+ extra='''{"account_url": ""}''',
)
)
merge_conn(
@@ -96,7 +115,7 @@ def create_default_connections(session=None):
conn_type="azure_container_instances",
extra='{"tenantId": "", "subscriptionId": "" }',
),
- session
+ session,
)
merge_conn(
Connection(
@@ -104,15 +123,16 @@ def create_default_connections(session=None):
conn_type="azure_cosmos",
extra='{"database_name": "", "collection_name": "" }',
),
- session
+ session,
)
merge_conn(
Connection(
- conn_id='azure_data_explorer_default', conn_type='azure_data_explorer',
+ conn_id='azure_data_explorer_default',
+ conn_type='azure_data_explorer',
host='https://.kusto.windows.net',
extra='''{"auth_method": "",
"tenant": "", "certificate": "",
- "thumbprint": ""}'''
+ "thumbprint": ""}''',
),
session,
)
@@ -122,7 +142,7 @@ def create_default_connections(session=None):
conn_type="azure_data_lake",
extra='{"tenant": "", "account_name": "" }',
),
- session
+ session,
)
merge_conn(
Connection(
@@ -131,7 +151,7 @@ def create_default_connections(session=None):
host="cassandra",
port=9042,
),
- session
+ session,
)
merge_conn(
Connection(
@@ -139,7 +159,7 @@ def create_default_connections(session=None):
conn_type="databricks",
host="localhost",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -148,7 +168,7 @@ def create_default_connections(session=None):
host="",
password="",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -158,7 +178,7 @@ def create_default_connections(session=None):
port=8082,
extra='{"endpoint": "druid/v2/sql"}',
),
- session
+ session,
)
merge_conn(
Connection(
@@ -168,7 +188,7 @@ def create_default_connections(session=None):
port=8081,
extra='{"endpoint": "druid/indexer/v1/task"}',
),
- session
+ session,
)
merge_conn(
Connection(
@@ -176,9 +196,9 @@ def create_default_connections(session=None):
conn_type="elasticsearch",
host="localhost",
schema="http",
- port=9200
+ port=9200,
),
- session
+ session,
)
merge_conn(
Connection(
@@ -229,7 +249,7 @@ def create_default_connections(session=None):
}
""",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -243,7 +263,7 @@ def create_default_connections(session=None):
}
""",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -259,7 +279,7 @@ def create_default_connections(session=None):
conn_type="google_cloud_platform",
schema="default",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -270,7 +290,7 @@ def create_default_connections(session=None):
extra='{"use_beeline": true, "auth": ""}',
schema="default",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -280,7 +300,7 @@ def create_default_connections(session=None):
schema="default",
port=10000,
),
- session
+ session,
)
merge_conn(
Connection(
@@ -288,14 +308,14 @@ def create_default_connections(session=None):
conn_type="http",
host="https://www.httpbin.org/",
),
- session
+ session,
)
merge_conn(
Connection(
conn_id='kubernetes_default',
conn_type='kubernetes',
),
- session
+ session,
)
merge_conn(
Connection(
@@ -304,19 +324,11 @@ def create_default_connections(session=None):
host='localhost',
port=7070,
login="ADMIN",
- password="KYLIN"
- ),
- session
- )
- merge_conn(
- Connection(
- conn_id="livy_default",
- conn_type="livy",
- host="livy",
- port=8998
+ password="KYLIN",
),
- session
+ session,
)
+ merge_conn(Connection(conn_id="livy_default", conn_type="livy", host="livy", port=8998), session)
merge_conn(
Connection(
conn_id="local_mysql",
@@ -326,7 +338,7 @@ def create_default_connections(session=None):
password="airflow",
schema="airflow",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -336,17 +348,9 @@ def create_default_connections(session=None):
extra='{"authMechanism": "PLAIN"}',
port=9083,
),
- session
- )
- merge_conn(
- Connection(
- conn_id="mongo_default",
- conn_type="mongo",
- host="mongo",
- port=27017
- ),
- session
+ session,
)
+ merge_conn(Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017), session)
merge_conn(
Connection(
conn_id="mssql_default",
@@ -354,7 +358,7 @@ def create_default_connections(session=None):
host="localhost",
port=1433,
),
- session
+ session,
)
merge_conn(
Connection(
@@ -364,7 +368,7 @@ def create_default_connections(session=None):
schema="airflow",
host="mysql",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -373,7 +377,7 @@ def create_default_connections(session=None):
host="",
password="",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -381,7 +385,7 @@ def create_default_connections(session=None):
conn_type="pig_cli",
schema="default",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -390,7 +394,7 @@ def create_default_connections(session=None):
host="localhost",
port=9000,
),
- session
+ session,
)
merge_conn(
Connection(
@@ -400,7 +404,7 @@ def create_default_connections(session=None):
port=9000,
extra='{"endpoint": "/query", "schema": "http"}',
),
- session
+ session,
)
merge_conn(
Connection(
@@ -411,7 +415,7 @@ def create_default_connections(session=None):
schema="airflow",
host="postgres",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -421,7 +425,7 @@ def create_default_connections(session=None):
schema="hive",
port=3400,
),
- session
+ session,
)
merge_conn(
Connection(
@@ -429,7 +433,7 @@ def create_default_connections(session=None):
conn_type="qubole",
host="localhost",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -439,7 +443,7 @@ def create_default_connections(session=None):
port=6379,
extra='{"db": 0}',
),
- session
+ session,
)
merge_conn(
Connection(
@@ -447,7 +451,7 @@ def create_default_connections(session=None):
conn_type="segment",
extra='{"write_key": "my-segment-write-key"}',
),
- session
+ session,
)
merge_conn(
Connection(
@@ -458,7 +462,7 @@ def create_default_connections(session=None):
login="airflow",
extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}',
),
- session
+ session,
)
merge_conn(
Connection(
@@ -467,7 +471,7 @@ def create_default_connections(session=None):
host="yarn",
extra='{"queue": "root.default"}',
),
- session
+ session,
)
merge_conn(
Connection(
@@ -475,7 +479,7 @@ def create_default_connections(session=None):
conn_type="sqlite",
host="/tmp/sqlite_default.db",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -483,7 +487,7 @@ def create_default_connections(session=None):
conn_type="sqoop",
host="rdbms",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -491,7 +495,7 @@ def create_default_connections(session=None):
conn_type="ssh",
host="localhost",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -502,7 +506,7 @@ def create_default_connections(session=None):
password="password",
extra='{"site_id": "my_site"}',
),
- session
+ session,
)
merge_conn(
Connection(
@@ -511,7 +515,7 @@ def create_default_connections(session=None):
host="localhost",
port=5433,
),
- session
+ session,
)
merge_conn(
Connection(
@@ -519,7 +523,7 @@ def create_default_connections(session=None):
conn_type="wasb",
extra='{"sas_token": null}',
),
- session
+ session,
)
merge_conn(
Connection(
@@ -528,7 +532,7 @@ def create_default_connections(session=None):
host="localhost",
port=50070,
),
- session
+ session,
)
merge_conn(
Connection(
@@ -536,7 +540,7 @@ def create_default_connections(session=None):
conn_type='yandexcloud',
schema='default',
),
- session
+ session,
)
@@ -555,6 +559,7 @@ def initdb():
DAG.deactivate_unknown_dags(dagbag.dags.keys())
from flask_appbuilder.models.sqla import Base
+
Base.metadata.create_all(settings.engine) # pylint: disable=no-member
@@ -590,8 +595,7 @@ def check_migrations(timeout):
if source_heads == db_heads:
break
if ticker >= timeout:
- raise TimeoutError("There are still unapplied migrations after {} "
- "seconds.".format(ticker))
+ raise TimeoutError("There are still unapplied migrations after {} " "seconds.".format(ticker))
ticker += 1
time.sleep(1)
log.info('Waiting for migrations... %s second(s)', ticker)
@@ -649,6 +653,7 @@ def drop_airflow_models(connection):
Base.metadata.remove(chart)
# alembic adds significant import time, so we import it lazily
from alembic.migration import MigrationContext # noqa
+
migration_ctx = MigrationContext.configure(connection)
version = migration_ctx._version # noqa pylint: disable=protected-access
if version.exists(connection):
@@ -662,6 +667,7 @@ def drop_flask_models(connection):
@return:
"""
from flask_appbuilder.models.sqla import Base
+
Base.metadata.drop_all(connection) # pylint: disable=no-member
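In the db.py hunk above, most merge_conn(...) calls only gain a trailing comma after `session`, because Black appends one whenever a call has to stay split across lines, while the livy_default and mongo_default entries were short enough to be joined onto a single line. A sketch of the two shapes with stand-in helpers (these stubs are illustrative, not the real airflow.utils.db API):

    def merge_conn(conn, session=None):  # stand-in for the real helper
        print(conn, session)

    class Connection:  # stand-in for airflow.models.Connection
        def __init__(self, **kwargs):
            self.kwargs = kwargs

        def __repr__(self):
            return f"Connection({self.kwargs})"

    session = object()

    # Fits within the line limit once joined, so Black keeps it on one line (no trailing comma).
    merge_conn(Connection(conn_id="livy_default", conn_type="livy", host="livy", port=8998), session)

    # Too long to join, so every argument stays on its own line and gets a trailing comma.
    merge_conn(
        Connection(
            conn_id="azure_container_instances_default",
            conn_type="azure_container_instances",
            extra='{"tenantId": "", "subscriptionId": "" }',
        ),
        session,
    )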
diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py
index e8238c971e95e..40fb9e0702f7b 100644
--- a/airflow/utils/decorators.py
+++ b/airflow/utils/decorators.py
@@ -46,17 +46,19 @@ def apply_defaults(func: T) -> T:
# have a different sig_cache.
sig_cache = signature(func)
non_optional_args = {
- name for (name, param) in sig_cache.parameters.items()
- if param.default == param.empty and
- param.name != 'self' and
- param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)}
+ name
+ for (name, param) in sig_cache.parameters.items()
+ if param.default == param.empty
+ and param.name != 'self'
+ and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
+ }
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
from airflow.models.dag import DagContext
+
if len(args) > 1:
- raise AirflowException(
- "Use keyword arguments when initializing operators")
+ raise AirflowException("Use keyword arguments when initializing operators")
dag_args: Dict[str, Any] = {}
dag_params: Dict[str, Any] = {}
@@ -91,6 +93,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
result = func(*args, **kwargs)
return result
+
return cast(T, wrapper)
diff --git a/airflow/utils/dot_renderer.py b/airflow/utils/dot_renderer.py
index 65e9a4a099fc2..990c7a7d126fb 100644
--- a/airflow/utils/dot_renderer.py
+++ b/airflow/utils/dot_renderer.py
@@ -55,8 +55,14 @@ def render_dag(dag: DAG, tis: Optional[List[TaskInstance]] = None) -> graphviz.D
:return: Graphviz object
:rtype: graphviz.Digraph
"""
- dot = graphviz.Digraph(dag.dag_id, graph_attr={"rankdir": dag.orientation if dag.orientation else "LR",
- "labelloc": "t", "label": dag.dag_id})
+ dot = graphviz.Digraph(
+ dag.dag_id,
+ graph_attr={
+ "rankdir": dag.orientation if dag.orientation else "LR",
+ "labelloc": "t",
+ "label": dag.dag_id,
+ },
+ )
states_by_task_id = None
if tis is not None:
states_by_task_id = {ti.task_id: ti.state for ti in tis}
@@ -66,16 +72,20 @@ def render_dag(dag: DAG, tis: Optional[List[TaskInstance]] = None) -> graphviz.D
"style": "filled,rounded",
}
if states_by_task_id is None:
- node_attrs.update({
- "color": _refine_color(task.ui_fgcolor),
- "fillcolor": _refine_color(task.ui_color),
- })
+ node_attrs.update(
+ {
+ "color": _refine_color(task.ui_fgcolor),
+ "fillcolor": _refine_color(task.ui_color),
+ }
+ )
else:
state = states_by_task_id.get(task.task_id, State.NONE)
- node_attrs.update({
- "color": State.color_fg(state),
- "fillcolor": State.color(state),
- })
+ node_attrs.update(
+ {
+ "color": State.color_fg(state),
+ "fillcolor": State.color(state),
+ }
+ )
dot.node(
task.task_id,
_attributes=node_attrs,
diff --git a/airflow/utils/email.py b/airflow/utils/email.py
index f32757d2633c6..17d3b6e97bd7c 100644
--- a/airflow/utils/email.py
+++ b/airflow/utils/email.py
@@ -32,17 +32,35 @@
log = logging.getLogger(__name__)
-def send_email(to: Union[List[str], Iterable[str]], subject: str, html_content: str,
- files=None, dryrun=False, cc=None, bcc=None,
- mime_subtype='mixed', mime_charset='utf-8', **kwargs):
+def send_email(
+ to: Union[List[str], Iterable[str]],
+ subject: str,
+ html_content: str,
+ files=None,
+ dryrun=False,
+ cc=None,
+ bcc=None,
+ mime_subtype='mixed',
+ mime_charset='utf-8',
+ **kwargs,
+):
"""Send email using backend specified in EMAIL_BACKEND."""
backend = conf.getimport('email', 'EMAIL_BACKEND')
to_list = get_email_address_list(to)
to_comma_separated = ", ".join(to_list)
- return backend(to_comma_separated, subject, html_content, files=files,
- dryrun=dryrun, cc=cc, bcc=bcc,
- mime_subtype=mime_subtype, mime_charset=mime_charset, **kwargs)
+ return backend(
+ to_comma_separated,
+ subject,
+ html_content,
+ files=files,
+ dryrun=dryrun,
+ cc=cc,
+ bcc=bcc,
+ mime_subtype=mime_subtype,
+ mime_charset=mime_charset,
+ **kwargs,
+ )
def send_email_smtp(
@@ -132,10 +150,7 @@ def build_mime_message(
for fname in files or []:
basename = os.path.basename(fname)
with open(fname, "rb") as file:
- part = MIMEApplication(
- file.read(),
- Name=basename
- )
+ part = MIMEApplication(file.read(), Name=basename)
part['Content-Disposition'] = f'attachment; filename="{basename}"'
part['Content-ID'] = f'<{basename}>'
msg.attach(part)
diff --git a/airflow/utils/file.py b/airflow/utils/file.py
index 405a6b029548e..553c506696e5b 100644
--- a/airflow/utils/file.py
+++ b/airflow/utils/file.py
@@ -31,9 +31,11 @@ def TemporaryDirectory(*args, **kwargs): # pylint: disable=invalid-name
"""This function is deprecated. Please use `tempfile.TemporaryDirectory`"""
import warnings
from tempfile import TemporaryDirectory as TmpDir
+
warnings.warn(
"This function is deprecated. Please use `tempfile.TemporaryDirectory`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
return TmpDir(*args, **kwargs)
@@ -49,9 +51,11 @@ def mkdirs(path, mode):
:type mode: int
"""
import warnings
+
warnings.warn(
f"This function is deprecated. Please use `pathlib.Path({path}).mkdir`",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
Path(path).mkdir(mode=mode, parents=True, exist_ok=True)
@@ -85,9 +89,7 @@ def open_maybe_zipped(fileloc, mode='r'):
return open(fileloc, mode=mode)
-def find_path_from_directory(
- base_dir_path: str,
- ignore_file_name: str) -> Generator[str, None, None]:
+def find_path_from_directory(base_dir_path: str, ignore_file_name: str) -> Generator[str, None, None]:
"""
Search the file and return the path of the file that should not be ignored.
:param base_dir_path: the base path to be searched for.
@@ -110,8 +112,9 @@ def find_path_from_directory(
dirs[:] = [
subdir
for subdir in dirs
- if not any(p.search(
- os.path.join(os.path.relpath(root, str(base_dir_path)), subdir)) for p in patterns)
+ if not any(
+ p.search(os.path.join(os.path.relpath(root, str(base_dir_path)), subdir)) for p in patterns
+ )
]
patterns_by_dir.update({os.path.join(root, sd): patterns.copy() for sd in dirs})
@@ -126,11 +129,12 @@ def find_path_from_directory(
yield str(abs_file_path)
-def list_py_file_paths(directory: str,
- safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE', fallback=True),
- include_examples: Optional[bool] = None,
- include_smart_sensor: Optional[bool] =
- conf.getboolean('smart_sensor', 'use_smart_sensor')):
+def list_py_file_paths(
+ directory: str,
+ safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE', fallback=True),
+ include_examples: Optional[bool] = None,
+ include_smart_sensor: Optional[bool] = conf.getboolean('smart_sensor', 'use_smart_sensor'),
+):
"""
Traverse a directory and look for Python files.
@@ -159,10 +163,12 @@ def list_py_file_paths(directory: str,
find_dag_file_paths(directory, file_paths, safe_mode)
if include_examples:
from airflow import example_dags
+
example_dag_folder = example_dags.__path__[0] # type: ignore
file_paths.extend(list_py_file_paths(example_dag_folder, safe_mode, False, False))
if include_smart_sensor:
from airflow import smart_sensor_dags
+
smart_sensor_dag_folder = smart_sensor_dags.__path__[0] # type: ignore
file_paths.extend(list_py_file_paths(smart_sensor_dag_folder, safe_mode, False, False))
return file_paths
@@ -170,8 +176,7 @@ def list_py_file_paths(directory: str,
def find_dag_file_paths(directory: str, file_paths: list, safe_mode: bool):
"""Finds file paths of all DAG files."""
- for file_path in find_path_from_directory(
- directory, ".airflowignore"):
+ for file_path in find_path_from_directory(directory, ".airflowignore"):
try:
if not os.path.isfile(file_path):
continue
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index 01f88fcbac790..5ccb618f58766 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -36,12 +36,12 @@ def validate_key(k, max_length=250):
if not isinstance(k, str):
raise TypeError("The key has to be a string")
elif len(k) > max_length:
- raise AirflowException(
- f"The key has to be less than {max_length} characters")
+ raise AirflowException(f"The key has to be less than {max_length} characters")
elif not KEY_REGEX.match(k):
raise AirflowException(
"The key ({k}) has to be made of alphanumeric characters, dashes, "
- "dots and underscores exclusively".format(k=k))
+ "dots and underscores exclusively".format(k=k)
+ )
else:
return True
@@ -101,15 +101,10 @@ def chunks(items: List[T], chunk_size: int) -> Generator[List[T], None, None]:
if chunk_size <= 0:
raise ValueError('Chunk size must be a positive integer')
for i in range(0, len(items), chunk_size):
- yield items[i:i + chunk_size]
+ yield items[i : i + chunk_size]
-def reduce_in_chunks(
- fn: Callable[[S, List[T]], S],
- iterable: List[T],
- initializer: S,
- chunk_size: int = 0
-):
+def reduce_in_chunks(fn: Callable[[S, List[T]], S], iterable: List[T], initializer: S, chunk_size: int = 0):
"""
Reduce the given list of items by splitting it into chunks
of the given size and passing each chunk through the reducer
@@ -155,10 +150,12 @@ def render_log_filename(ti, try_number, filename_template):
jinja_context['try_number'] = try_number
return filename_jinja_template.render(**jinja_context)
- return filename_template.format(dag_id=ti.dag_id,
- task_id=ti.task_id,
- execution_date=ti.execution_date.isoformat(),
- try_number=try_number)
+ return filename_template.format(
+ dag_id=ti.dag_id,
+ task_id=ti.task_id,
+ execution_date=ti.execution_date.isoformat(),
+ try_number=try_number,
+ )
def convert_camel_to_snake(camel_str):
@@ -191,7 +188,8 @@ def chain(*args, **kwargs):
"""This function is deprecated. Please use `airflow.models.baseoperator.chain`."""
warnings.warn(
"This function is deprecated. Please use `airflow.models.baseoperator.chain`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
return import_string('airflow.models.baseoperator.chain')(*args, **kwargs)
@@ -200,6 +198,7 @@ def cross_downstream(*args, **kwargs):
"""This function is deprecated. Please use `airflow.models.baseoperator.cross_downstream`."""
warnings.warn(
"This function is deprecated. Please use `airflow.models.baseoperator.cross_downstream`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
return import_string('airflow.models.baseoperator.cross_downstream')(*args, **kwargs)
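The chunks() hunk above shows Black's slice rule: when a slice bound is an expression rather than a bare name or literal, the colon is treated like a low-priority binary operator and gets a space on each side. A short runnable sketch of the two forms:

    items = list(range(10))
    i, chunk_size = 2, 3

    # A simple name as the bound keeps the tight form ...
    head = items[:chunk_size]

    # ... but an expression bound gets spaces around the colon on both sides.
    chunk = items[i : i + chunk_size]
    print(head, chunk)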
diff --git a/airflow/utils/json.py b/airflow/utils/json.py
index 6d1cc5f036d50..d388f11e39899 100644
--- a/airflow/utils/json.py
+++ b/airflow/utils/json.py
@@ -43,17 +43,32 @@ def _default(obj):
return obj.strftime('%Y-%m-%dT%H:%M:%SZ')
elif isinstance(obj, date):
return obj.strftime('%Y-%m-%d')
- elif isinstance(obj, (np.int_, np.intc, np.intp, np.int8, np.int16,
- np.int32, np.int64, np.uint8, np.uint16,
- np.uint32, np.uint64)):
+ elif isinstance(
+ obj,
+ (
+ np.int_,
+ np.intc,
+ np.intp,
+ np.int8,
+ np.int16,
+ np.int32,
+ np.int64,
+ np.uint8,
+ np.uint16,
+ np.uint32,
+ np.uint64,
+ ),
+ ):
return int(obj)
elif isinstance(obj, np.bool_):
return bool(obj)
- elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64,
- np.complex_, np.complex64, np.complex128)):
+ elif isinstance(
+ obj, (np.float_, np.float16, np.float32, np.float64, np.complex_, np.complex64, np.complex128)
+ ):
return float(obj)
elif k8s is not None and isinstance(obj, k8s.V1Pod):
from airflow.kubernetes.pod_generator import PodGenerator
+
return PodGenerator.serialize_pod(obj)
raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")
diff --git a/airflow/utils/log/cloudwatch_task_handler.py b/airflow/utils/log/cloudwatch_task_handler.py
index 1ba2586151698..cee5bfd87def8 100644
--- a/airflow/utils/log/cloudwatch_task_handler.py
+++ b/airflow/utils/log/cloudwatch_task_handler.py
@@ -23,5 +23,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.log.cloudwatch_task_handler`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
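This shim module, like the other deprecated log-handler paths further down, only has its warnings.warn(...) call re-wrapped with a trailing comma; the stacklevel=2 argument it keeps is what makes the DeprecationWarning point one frame up the stack, toward the code still using the legacy import path. A minimal runnable sketch of that mechanism, using a hypothetical helper function in place of module-level import code:

    import warnings

    def _emit_shim_warning(new_path):  # hypothetical helper, not part of Airflow
        warnings.warn(
            f"This module is deprecated. Please use `{new_path}`.",
            DeprecationWarning,
            stacklevel=2,  # report the warning one frame up, at this helper's call site
        )

    def legacy_import_site():
        _emit_shim_warning("airflow.providers.amazon.aws.log.cloudwatch_task_handler")  # warning reported here

    if __name__ == "__main__":
        warnings.simplefilter("always")
        legacy_import_site()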
diff --git a/airflow/utils/log/colored_log.py b/airflow/utils/log/colored_log.py
index 3d7a8f13e3182..59ca3a2f0465a 100644
--- a/airflow/utils/log/colored_log.py
+++ b/airflow/utils/log/colored_log.py
@@ -66,12 +66,10 @@ def _color_record_args(self, record: LogRecord) -> LogRecord:
elif isinstance(record.args, dict):
if self._count_number_of_arguments_in_message(record) > 1:
# Case of logging.debug("a %(a)d b %(b)s", {'a':1, 'b':2})
- record.args = {
- key: self._color_arg(value) for key, value in record.args.items()
- }
+ record.args = {key: self._color_arg(value) for key, value in record.args.items()}
else:
# Case of single dict passed to formatted string
- record.args = self._color_arg(record.args) # type: ignore
+ record.args = self._color_arg(record.args) # type: ignore
elif isinstance(record.args, str):
record.args = self._color_arg(record.args)
return record
@@ -84,8 +82,9 @@ def _color_record_traceback(self, record: LogRecord) -> LogRecord:
record.exc_text = self.formatException(record.exc_info)
if record.exc_text:
- record.exc_text = self.color(self.log_colors, record.levelname) + \
- record.exc_text + escape_codes['reset']
+ record.exc_text = (
+ self.color(self.log_colors, record.levelname) + record.exc_text + escape_codes['reset']
+ )
return record
@@ -96,4 +95,5 @@ def format(self, record: LogRecord) -> str:
return super().format(record)
except ValueError: # I/O operation on closed file
from logging import Formatter
+
return Formatter().format(record)
diff --git a/airflow/utils/log/es_task_handler.py b/airflow/utils/log/es_task_handler.py
index 019fa39de944b..64de08ee9671d 100644
--- a/airflow/utils/log/es_task_handler.py
+++ b/airflow/utils/log/es_task_handler.py
@@ -23,5 +23,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.elasticsearch.log.es_task_handler`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/utils/log/file_processor_handler.py b/airflow/utils/log/file_processor_handler.py
index 220787b6a3028..9d2768eb5da7a 100644
--- a/airflow/utils/log/file_processor_handler.py
+++ b/airflow/utils/log/file_processor_handler.py
@@ -40,8 +40,7 @@ def __init__(self, base_log_folder, filename_template):
self.handler = None
self.base_log_folder = base_log_folder
self.dag_dir = os.path.expanduser(settings.DAGS_FOLDER)
- self.filename_template, self.filename_jinja_template = \
- parse_template_string(filename_template)
+ self.filename_template, self.filename_jinja_template = parse_template_string(filename_template)
self._cur_date = datetime.today()
Path(self._get_log_directory()).mkdir(parents=True, exist_ok=True)
@@ -85,6 +84,7 @@ def _render_filename(self, filename):
# is always inside the log dir as other DAGs. To be differentiate with regular DAGs,
# their logs will be in the `log_dir/native_dags`.
import airflow
+
airflow_directory = airflow.__path__[0]
if filename.startswith(airflow_directory):
filename = os.path.join("native_dags", os.path.relpath(filename, airflow_directory))
@@ -119,17 +119,14 @@ def _symlink_latest_log_directory(self):
if os.readlink(latest_log_directory_path) != log_directory:
os.unlink(latest_log_directory_path)
os.symlink(log_directory, latest_log_directory_path)
- elif (os.path.isdir(latest_log_directory_path) or
- os.path.isfile(latest_log_directory_path)):
+ elif os.path.isdir(latest_log_directory_path) or os.path.isfile(latest_log_directory_path):
logging.warning(
- "%s already exists as a dir/file. Skip creating symlink.",
- latest_log_directory_path
+ "%s already exists as a dir/file. Skip creating symlink.", latest_log_directory_path
)
else:
os.symlink(log_directory, latest_log_directory_path)
except OSError:
- logging.warning("OSError while attempting to symlink "
- "the latest log directory")
+ logging.warning("OSError while attempting to symlink " "the latest log directory")
def _init_file(self, filename):
"""
@@ -138,8 +135,7 @@ def _init_file(self, filename):
:param filename: task instance object
:return: relative log path of the given task instance
"""
- relative_log_file_path = os.path.join(
- self._get_log_directory(), self._render_filename(filename))
+ relative_log_file_path = os.path.join(self._get_log_directory(), self._render_filename(filename))
log_file_path = os.path.abspath(relative_log_file_path)
directory = os.path.dirname(log_file_path)
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index b1170174c35aa..3a06ff3e93013 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -43,8 +43,7 @@ def __init__(self, base_log_folder: str, filename_template: str):
super().__init__()
self.handler = None # type: Optional[logging.FileHandler]
self.local_base = base_log_folder
- self.filename_template, self.filename_jinja_template = \
- parse_template_string(filename_template)
+ self.filename_template, self.filename_jinja_template = parse_template_string(filename_template)
def set_context(self, ti: TaskInstance):
"""
@@ -83,10 +82,12 @@ def _render_filename(self, ti, try_number):
}
return self.filename_jinja_template.render(**jinja_context)
- return self.filename_template.format(dag_id=ti.dag_id,
- task_id=ti.task_id,
- execution_date=ti.execution_date.isoformat(),
- try_number=try_number)
+ return self.filename_template.format(
+ dag_id=ti.dag_id,
+ task_id=ti.task_id,
+ execution_date=ti.execution_date.isoformat(),
+ try_number=try_number,
+ )
def _read_grouped_logs(self):
return False
@@ -118,7 +119,7 @@ def _read(self, ti, try_number, metadata=None): # pylint: disable=unused-argume
except Exception as e: # pylint: disable=broad-except
log = f"*** Failed to load local log file: {location}\n"
log += "*** {}\n".format(str(e))
- elif conf.get('core', 'executor') == 'KubernetesExecutor': # pylint: disable=too-many-nested-blocks
+ elif conf.get('core', 'executor') == 'KubernetesExecutor': # pylint: disable=too-many-nested-blocks
try:
from airflow.kubernetes.kube_client import get_kube_client
@@ -129,14 +130,18 @@ def _read(self, ti, try_number, metadata=None): # pylint: disable=unused-argume
# is returned for the fqdn to comply with the 63 character limit imposed by DNS standards
# on any label of a FQDN.
pod_list = kube_client.list_namespaced_pod(conf.get('kubernetes', 'namespace'))
- matches = [pod.metadata.name for pod in pod_list.items
- if pod.metadata.name.startswith(ti.hostname)]
+ matches = [
+ pod.metadata.name
+ for pod in pod_list.items
+ if pod.metadata.name.startswith(ti.hostname)
+ ]
if len(matches) == 1:
if len(matches[0]) > len(ti.hostname):
ti.hostname = matches[0]
- log += '*** Trying to get logs (last 100 lines) from worker pod {} ***\n\n'\
- .format(ti.hostname)
+ log += '*** Trying to get logs (last 100 lines) from worker pod {} ***\n\n'.format(
+ ti.hostname
+ )
res = kube_client.read_namespaced_pod_log(
name=ti.hostname,
@@ -144,22 +149,17 @@ def _read(self, ti, try_number, metadata=None): # pylint: disable=unused-argume
container='base',
follow=False,
tail_lines=100,
- _preload_content=False
+ _preload_content=False,
)
for line in res:
log += line.decode()
except Exception as f: # pylint: disable=broad-except
- log += '*** Unable to fetch logs from worker pod {} ***\n{}\n\n'.format(
- ti.hostname, str(f)
- )
+ log += '*** Unable to fetch logs from worker pod {} ***\n{}\n\n'.format(ti.hostname, str(f))
else:
- url = os.path.join(
- "http://{ti.hostname}:{worker_log_server_port}/log", log_relative_path
- ).format(
- ti=ti,
- worker_log_server_port=conf.get('celery', 'WORKER_LOG_SERVER_PORT')
+ url = os.path.join("http://{ti.hostname}:{worker_log_server_port}/log", log_relative_path).format(
+ ti=ti, worker_log_server_port=conf.get('celery', 'WORKER_LOG_SERVER_PORT')
)
log += f"*** Log file does not exist: {location}\n"
log += f"*** Fetching from: {url}\n"
diff --git a/airflow/utils/log/gcs_task_handler.py b/airflow/utils/log/gcs_task_handler.py
index 63251a4e91f07..bf8576b8906d5 100644
--- a/airflow/utils/log/gcs_task_handler.py
+++ b/airflow/utils/log/gcs_task_handler.py
@@ -23,5 +23,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.log.gcs_task_handler`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/utils/log/json_formatter.py b/airflow/utils/log/json_formatter.py
index 4d517deb1542b..73d461942186c 100644
--- a/airflow/utils/log/json_formatter.py
+++ b/airflow/utils/log/json_formatter.py
@@ -42,7 +42,6 @@ def usesTime(self):
def format(self, record):
super().format(record)
- record_dict = {label: getattr(record, label, None)
- for label in self.json_fields}
+ record_dict = {label: getattr(record, label, None) for label in self.json_fields}
merged_record = merge_dicts(record_dict, self.extras)
return json.dumps(merged_record)
diff --git a/airflow/utils/log/log_reader.py b/airflow/utils/log/log_reader.py
index 42d03d901b3cc..bef39a0b25af2 100644
--- a/airflow/utils/log/log_reader.py
+++ b/airflow/utils/log/log_reader.py
@@ -29,8 +29,9 @@
class TaskLogReader:
"""Task log reader"""
- def read_log_chunks(self, ti: TaskInstance, try_number: Optional[int],
- metadata) -> Tuple[List[str], Dict[str, Any]]:
+ def read_log_chunks(
+ self, ti: TaskInstance, try_number: Optional[int], metadata
+ ) -> Tuple[List[str], Dict[str, Any]]:
"""
Reads chunks of Task Instance logs.
@@ -58,8 +59,7 @@ def read_log_chunks(self, ti: TaskInstance, try_number: Optional[int],
metadata = metadatas[0]
return logs, metadata
- def read_log_stream(self, ti: TaskInstance, try_number: Optional[int],
- metadata: dict) -> Iterator[str]:
+ def read_log_stream(self, ti: TaskInstance, try_number: Optional[int], metadata: dict) -> Iterator[str]:
"""
Used to continuously read log to the end
@@ -115,7 +115,6 @@ def render_log_filename(self, ti: TaskInstance, try_number: Optional[int] = None
"""
filename_template = conf.get('logging', 'LOG_FILENAME_TEMPLATE')
attachment_filename = render_log_filename(
- ti=ti,
- try_number="all" if try_number is None else try_number,
- filename_template=filename_template)
+ ti=ti, try_number="all" if try_number is None else try_number, filename_template=filename_template
+ )
return attachment_filename
diff --git a/airflow/utils/log/logging_mixin.py b/airflow/utils/log/logging_mixin.py
index 624f6b931e937..12e949b8b2318 100644
--- a/airflow/utils/log/logging_mixin.py
+++ b/airflow/utils/log/logging_mixin.py
@@ -46,9 +46,7 @@ def log(self) -> Logger:
# FIXME: LoggingMixin should have a default _log field.
return self._log # type: ignore
except AttributeError:
- self._log = logging.getLogger(
- self.__class__.__module__ + '.' + self.__class__.__name__
- )
+ self._log = logging.getLogger(self.__class__.__module__ + '.' + self.__class__.__name__)
return self._log
def _set_context(self, context):
@@ -91,7 +89,7 @@ def close(self):
"""
@property
- def closed(self): # noqa: D402
+ def closed(self): # noqa: D402
"""
Returns False to indicate that the stream is not closed (as it will be
open for the duration of Airflow's lifecycle).
@@ -141,8 +139,9 @@ class RedirectStdHandler(StreamHandler):
# pylint: disable=super-init-not-called
def __init__(self, stream):
if not isinstance(stream, str):
- raise Exception("Cannot use file like objects. Use 'stdout' or 'stderr'"
- " as a str and without 'ext://'.")
+ raise Exception(
+ "Cannot use file like objects. Use 'stdout' or 'stderr'" " as a str and without 'ext://'."
+ )
self._use_stderr = True
if 'stdout' in stream:
diff --git a/airflow/utils/log/s3_task_handler.py b/airflow/utils/log/s3_task_handler.py
index 6bccff79f4e4d..cf8a7b7d5efbe 100644
--- a/airflow/utils/log/s3_task_handler.py
+++ b/airflow/utils/log/s3_task_handler.py
@@ -23,5 +23,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.log.s3_task_handler`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/utils/log/stackdriver_task_handler.py b/airflow/utils/log/stackdriver_task_handler.py
index 0b96380ad5cbc..182c6636a3bd8 100644
--- a/airflow/utils/log/stackdriver_task_handler.py
+++ b/airflow/utils/log/stackdriver_task_handler.py
@@ -22,5 +22,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.google.cloud.log.stackdriver_task_handler`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/utils/log/wasb_task_handler.py b/airflow/utils/log/wasb_task_handler.py
index 7b41933a1ecd2..263ef5641a71b 100644
--- a/airflow/utils/log/wasb_task_handler.py
+++ b/airflow/utils/log/wasb_task_handler.py
@@ -23,5 +23,6 @@
warnings.warn(
"This module is deprecated. Please use `airflow.providers.microsoft.azure.log.wasb_task_handler`.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
diff --git a/airflow/utils/module_loading.py b/airflow/utils/module_loading.py
index dbd5622d203f3..e863f8641894e 100644
--- a/airflow/utils/module_loading.py
+++ b/airflow/utils/module_loading.py
@@ -34,6 +34,4 @@ def import_string(dotted_path):
try:
return getattr(module, class_name)
except AttributeError:
- raise ImportError('Module "{}" does not define a "{}" attribute/class'.format(
- module_path, class_name)
- )
+ raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class')
diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py
index c2e72289255d2..72d837fb3431f 100644
--- a/airflow/utils/operator_helpers.py
+++ b/airflow/utils/operator_helpers.py
@@ -18,18 +18,24 @@
#
AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
- 'AIRFLOW_CONTEXT_DAG_ID': {'default': 'airflow.ctx.dag_id',
- 'env_var_format': 'AIRFLOW_CTX_DAG_ID'},
- 'AIRFLOW_CONTEXT_TASK_ID': {'default': 'airflow.ctx.task_id',
- 'env_var_format': 'AIRFLOW_CTX_TASK_ID'},
- 'AIRFLOW_CONTEXT_EXECUTION_DATE': {'default': 'airflow.ctx.execution_date',
- 'env_var_format': 'AIRFLOW_CTX_EXECUTION_DATE'},
- 'AIRFLOW_CONTEXT_DAG_RUN_ID': {'default': 'airflow.ctx.dag_run_id',
- 'env_var_format': 'AIRFLOW_CTX_DAG_RUN_ID'},
- 'AIRFLOW_CONTEXT_DAG_OWNER': {'default': 'airflow.ctx.dag_owner',
- 'env_var_format': 'AIRFLOW_CTX_DAG_OWNER'},
- 'AIRFLOW_CONTEXT_DAG_EMAIL': {'default': 'airflow.ctx.dag_email',
- 'env_var_format': 'AIRFLOW_CTX_DAG_EMAIL'},
+ 'AIRFLOW_CONTEXT_DAG_ID': {'default': 'airflow.ctx.dag_id', 'env_var_format': 'AIRFLOW_CTX_DAG_ID'},
+ 'AIRFLOW_CONTEXT_TASK_ID': {'default': 'airflow.ctx.task_id', 'env_var_format': 'AIRFLOW_CTX_TASK_ID'},
+ 'AIRFLOW_CONTEXT_EXECUTION_DATE': {
+ 'default': 'airflow.ctx.execution_date',
+ 'env_var_format': 'AIRFLOW_CTX_EXECUTION_DATE',
+ },
+ 'AIRFLOW_CONTEXT_DAG_RUN_ID': {
+ 'default': 'airflow.ctx.dag_run_id',
+ 'env_var_format': 'AIRFLOW_CTX_DAG_RUN_ID',
+ },
+ 'AIRFLOW_CONTEXT_DAG_OWNER': {
+ 'default': 'airflow.ctx.dag_owner',
+ 'env_var_format': 'AIRFLOW_CTX_DAG_OWNER',
+ },
+ 'AIRFLOW_CONTEXT_DAG_EMAIL': {
+ 'default': 'airflow.ctx.dag_email',
+ 'env_var_format': 'AIRFLOW_CTX_DAG_EMAIL',
+ },
}
@@ -54,33 +60,32 @@ def context_to_airflow_vars(context, in_env_var_format=False):
task = context.get('task')
if task and task.email:
if isinstance(task.email, str):
- params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_EMAIL'][
- name_format]] = task.email
+ params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_EMAIL'][name_format]] = task.email
elif isinstance(task.email, list):
# os env variable value needs to be string
- params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_EMAIL'][
- name_format]] = ','.join(task.email)
+ params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_EMAIL'][name_format]] = ','.join(
+ task.email
+ )
if task and task.owner:
if isinstance(task.owner, str):
- params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_OWNER'][
- name_format]] = task.owner
+ params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_OWNER'][name_format]] = task.owner
elif isinstance(task.owner, list):
# os env variable value needs to be string
- params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_OWNER'][
- name_format]] = ','.join(task.owner)
+ params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_OWNER'][name_format]] = ','.join(
+ task.owner
+ )
task_instance = context.get('task_instance')
if task_instance and task_instance.dag_id:
- params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID'][
- name_format]] = task_instance.dag_id
+ params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID'][name_format]] = task_instance.dag_id
if task_instance and task_instance.task_id:
- params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID'][
- name_format]] = task_instance.task_id
+ params[
+ AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID'][name_format]
+ ] = task_instance.task_id
if task_instance and task_instance.execution_date:
params[
- AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][
- name_format]] = task_instance.execution_date.isoformat()
+ AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][name_format]
+ ] = task_instance.execution_date.isoformat()
dag_run = context.get('dag_run')
if dag_run and dag_run.run_id:
- params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][
- name_format]] = dag_run.run_id
+ params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][name_format]] = dag_run.run_id
return params
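
The mapping above is keyed first by context name and then by output flavour ('default' or 'env_var_format'); a small sketch of the lookup that context_to_airflow_vars performs (the dag_id value is illustrative):

    from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING

    in_env_var_format = True
    name_format = 'env_var_format' if in_env_var_format else 'default'
    key = AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID'][name_format]
    print({key: 'example_dag'})  # {'AIRFLOW_CTX_DAG_ID': 'example_dag'}
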
diff --git a/airflow/utils/operator_resources.py b/airflow/utils/operator_resources.py
index b7f42ba2043ef..878102183ea32 100644
--- a/airflow/utils/operator_resources.py
+++ b/airflow/utils/operator_resources.py
@@ -45,7 +45,8 @@ def __init__(self, name, units_str, qty):
if qty < 0:
raise AirflowException(
'Received resource quantity {} for resource {} but resource quantity '
- 'must be non-negative.'.format(qty, name))
+ 'must be non-negative.'.format(qty, name)
+ )
self._name = name
self._units_str = units_str
@@ -119,12 +120,13 @@ class Resources:
:type gpus: long
"""
- def __init__(self,
- cpus=conf.getint('operators', 'default_cpus'),
- ram=conf.getint('operators', 'default_ram'),
- disk=conf.getint('operators', 'default_disk'),
- gpus=conf.getint('operators', 'default_gpus')
- ):
+ def __init__(
+ self,
+ cpus=conf.getint('operators', 'default_cpus'),
+ ram=conf.getint('operators', 'default_ram'),
+ disk=conf.getint('operators', 'default_disk'),
+ gpus=conf.getint('operators', 'default_gpus'),
+ ):
self.cpus = CpuResource(cpus)
self.ram = RamResource(ram)
self.disk = DiskResource(disk)
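
For context, each quantity is wrapped in a per-field resource object whose constructor rejects negatives; a hedged usage sketch (defaults come from the [operators] config section, so an initialized Airflow config is assumed):

    from airflow.exceptions import AirflowException
    from airflow.utils.operator_resources import Resources

    res = Resources(cpus=2, ram=1024)  # disk and gpus fall back to configured defaults
    try:
        Resources(gpus=-1)
    except AirflowException as err:
        print(err)  # resource quantity must be non-negative
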
diff --git a/airflow/utils/orm_event_handlers.py b/airflow/utils/orm_event_handlers.py
index ab4954739a715..4f7d2b3fdbef3 100644
--- a/airflow/utils/orm_event_handlers.py
+++ b/airflow/utils/orm_event_handlers.py
@@ -36,6 +36,7 @@ def connect(dbapi_connection, connection_record):
connection_record.info['pid'] = os.getpid()
if engine.dialect.name == "sqlite":
+
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
@@ -44,6 +45,7 @@ def set_sqlite_pragma(dbapi_connection, connection_record):
# this ensures sanity in mysql when storing datetimes (not required for postgres)
if engine.dialect.name == "mysql":
+
@event.listens_for(engine, "connect")
def set_mysql_timezone(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
@@ -59,7 +61,9 @@ def checkout(dbapi_connection, connection_record, connection_proxy):
"Connection record belongs to pid {}, "
"attempting to check out in pid {}".format(connection_record.info['pid'], pid)
)
+
if conf.getboolean('debug', 'sqlalchemy_stats', fallback=False):
+
@event.listens_for(engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
conn.info.setdefault('query_start_time', []).append(time.time())
@@ -68,12 +72,19 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
total = time.time() - conn.info['query_start_time'].pop()
file_name = [
- f"'{f.name}':{f.filename}:{f.lineno}" for f
- in traceback.extract_stack() if 'sqlalchemy' not in f.filename][-1]
+ f"'{f.name}':{f.filename}:{f.lineno}"
+ for f in traceback.extract_stack()
+ if 'sqlalchemy' not in f.filename
+ ][-1]
stack = [f for f in traceback.extract_stack() if 'sqlalchemy' not in f.filename]
stack_info = ">".join([f"{f.filename.rpartition('/')[-1]}:{f.name}" for f in stack][-3:])
conn.info.setdefault('query_start_time', []).append(time.monotonic())
- log.info("@SQLALCHEMY %s |$ %s |$ %s |$ %s ",
- total, file_name, stack_info, statement.replace("\n", " ")
- )
+ log.info(
+ "@SQLALCHEMY %s |$ %s |$ %s |$ %s ",
+ total,
+ file_name,
+ stack_info,
+ statement.replace("\n", " "),
+ )
+
# pylint: enable=unused-argument, unused-variable
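
The blank lines Black inserts above only separate the nested @event.listens_for callbacks from their enclosing if statements; the registration itself is plain SQLAlchemy, e.g. (engine URL and pragma are illustrative):

    from sqlalchemy import create_engine, event

    engine = create_engine("sqlite://")


    @event.listens_for(engine, "connect")
    def set_sqlite_pragma(dbapi_connection, connection_record):
        # Runs for every new DBAPI connection handed out by this engine.
        cursor = dbapi_connection.cursor()
        cursor.execute("PRAGMA foreign_keys=ON")
        cursor.close()
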
diff --git a/airflow/utils/process_utils.py b/airflow/utils/process_utils.py
index 226ec86453d97..1ee1ac60937c9 100644
--- a/airflow/utils/process_utils.py
+++ b/airflow/utils/process_utils.py
@@ -41,9 +41,7 @@
# When killing processes, time to wait after issuing a SIGTERM before issuing a
# SIGKILL.
-DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM = conf.getint(
- 'core', 'KILLED_TASK_CLEANUP_TIME'
-)
+DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM = conf.getint('core', 'KILLED_TASK_CLEANUP_TIME')
def reap_process_group(pgid, logger, sig=signal.SIGTERM, timeout=DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM):
@@ -130,13 +128,7 @@ def execute_in_subprocess(cmd: List[str]):
:type cmd: List[str]
"""
log.info("Executing cmd: %s", " ".join([shlex.quote(c) for c in cmd]))
- proc = subprocess.Popen(
- cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- bufsize=0,
- close_fds=True
- )
+ proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=0, close_fds=True)
log.info("Output:")
if proc.stdout:
with proc.stdout:
@@ -164,12 +156,7 @@ def execute_interactive(cmd: List[str], **kwargs):
try: # pylint: disable=too-many-nested-blocks
# use os.setsid() make it run in a new process group, or bash job control will not be enabled
proc = subprocess.Popen(
- cmd,
- stdin=slave_fd,
- stdout=slave_fd,
- stderr=slave_fd,
- universal_newlines=True,
- **kwargs
+ cmd, stdin=slave_fd, stdout=slave_fd, stderr=slave_fd, universal_newlines=True, **kwargs
)
while proc.poll() is None:
@@ -238,10 +225,7 @@ def patch_environ(new_env_variables: Dict[str, str]):
:param new_env_variables: Environment variables to set
"""
- current_env_state = {
- key: os.environ.get(key)
- for key in new_env_variables.keys()
- }
+ current_env_state = {key: os.environ.get(key) for key in new_env_variables.keys()}
os.environ.update(new_env_variables)
try: # pylint: disable=too-many-nested-blocks
yield
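
patch_environ snapshots only the keys it is about to override and restores that snapshot when the block exits; a minimal usage sketch (the variable name is illustrative):

    import os

    from airflow.utils.process_utils import patch_environ

    with patch_environ({'MY_FLAG': '1'}):
        assert os.environ['MY_FLAG'] == '1'  # visible only inside the block
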
diff --git a/airflow/utils/python_virtualenv.py b/airflow/utils/python_virtualenv.py
index 486371904647b..185a60154df5d 100644
--- a/airflow/utils/python_virtualenv.py
+++ b/airflow/utils/python_virtualenv.py
@@ -43,10 +43,7 @@ def _generate_pip_install_cmd(tmp_dir: str, requirements: List[str]) -> Optional
def prepare_virtualenv(
- venv_directory: str,
- python_bin: str,
- system_site_packages: bool,
- requirements: List[str]
+ venv_directory: str, python_bin: str, system_site_packages: bool, requirements: List[str]
) -> str:
"""
Creates a virtual environment and installs the additional python packages
@@ -83,9 +80,6 @@ def write_python_script(jinja_context: dict, filename: str):
:type filename: str
"""
template_loader = jinja2.FileSystemLoader(searchpath=os.path.dirname(__file__))
- template_env = jinja2.Environment(
- loader=template_loader,
- undefined=jinja2.StrictUndefined
- )
+ template_env = jinja2.Environment(loader=template_loader, undefined=jinja2.StrictUndefined)
template = template_env.get_template('python_virtualenv_script.jinja2')
template.stream(**jinja_context).dump(filename)
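
A hedged sketch of the collapsed prepare_virtualenv signature in use; the directory and requirement are illustrative and the call shells out to virtualenv and pip:

    import sys
    import tempfile

    from airflow.utils.python_virtualenv import prepare_virtualenv

    python_bin = prepare_virtualenv(
        venv_directory=tempfile.mkdtemp(prefix='airflow-venv-'),
        python_bin=sys.executable,
        system_site_packages=False,
        requirements=['dill'],
    )
    print(python_bin)  # interpreter inside the freshly created environment
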
diff --git a/airflow/utils/serve_logs.py b/airflow/utils/serve_logs.py
index 57717393d23e9..0fefa420b7d84 100644
--- a/airflow/utils/serve_logs.py
+++ b/airflow/utils/serve_logs.py
@@ -32,10 +32,8 @@ def serve_logs():
def serve_logs_view(filename): # pylint: disable=unused-variable
log_directory = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER'))
return flask.send_from_directory(
- log_directory,
- filename,
- mimetype="application/json",
- as_attachment=False)
+ log_directory, filename, mimetype="application/json", as_attachment=False
+ )
worker_log_server_port = conf.getint('celery', 'WORKER_LOG_SERVER_PORT')
flask_app.run(host='0.0.0.0', port=worker_log_server_port)
diff --git a/airflow/utils/session.py b/airflow/utils/session.py
index 979c23a2fad8c..ea526edb12288 100644
--- a/airflow/utils/session.py
+++ b/airflow/utils/session.py
@@ -46,13 +46,13 @@ def provide_session(func: Callable[..., RT]) -> Callable[..., RT]:
database transaction, you pass it to the function, if not this wrapper
will create one and close it for you.
"""
+
@wraps(func)
def wrapper(*args, **kwargs) -> RT:
arg_session = 'session'
func_params = func.__code__.co_varnames
- session_in_args = arg_session in func_params and \
- func_params.index(arg_session) < len(args)
+ session_in_args = arg_session in func_params and func_params.index(arg_session) < len(args)
session_in_kwargs = arg_session in kwargs
if session_in_kwargs or session_in_args:
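
A usage sketch of the decorator whose wrapper is reformatted above; session is injected (and cleaned up) by provide_session unless the caller passes one, and the query itself is illustrative and assumes an initialized metadata database:

    from airflow.models import DagModel
    from airflow.utils.session import provide_session


    @provide_session
    def count_dags(session=None):
        return session.query(DagModel).count()
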
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 8025b4304c3c8..a5bba18678361 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -59,8 +59,7 @@ class UtcDateTime(TypeDecorator):
def process_bind_param(self, value, dialect):
if value is not None:
if not isinstance(value, datetime.datetime):
- raise TypeError('expected datetime.datetime, not ' +
- repr(value))
+ raise TypeError('expected datetime.datetime, not ' + repr(value))
elif value.tzinfo is None:
raise ValueError('naive datetime is disallowed')
# For mysql we should store timestamps as naive values
@@ -70,6 +69,7 @@ def process_bind_param(self, value, dialect):
# See https://issues.apache.org/jira/browse/AIRFLOW-7001
if using_mysql:
from airflow.utils.timezone import make_naive
+
return make_naive(value, timezone=utc)
return value.astimezone(utc)
return None
@@ -99,17 +99,27 @@ class Interval(TypeDecorator):
attr_keys = {
datetime.timedelta: ('days', 'seconds', 'microseconds'),
relativedelta.relativedelta: (
- 'years', 'months', 'days', 'leapdays', 'hours', 'minutes', 'seconds', 'microseconds',
- 'year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond',
+ 'years',
+ 'months',
+ 'days',
+ 'leapdays',
+ 'hours',
+ 'minutes',
+ 'seconds',
+ 'microseconds',
+ 'year',
+ 'month',
+ 'day',
+ 'hour',
+ 'minute',
+ 'second',
+ 'microsecond',
),
}
def process_bind_param(self, value, dialect):
if isinstance(value, tuple(self.attr_keys)):
- attrs = {
- key: getattr(value, key)
- for key in self.attr_keys[type(value)]
- }
+ attrs = {key: getattr(value, key) for key in self.attr_keys[type(value)]}
return json.dumps({'type': type(value).__name__, 'attrs': attrs})
return json.dumps(value)
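
The dict comprehension above flattens a timedelta or relativedelta into its named attributes before JSON-encoding; the same shape can be reproduced with the standard library alone:

    import datetime
    import json

    value = datetime.timedelta(days=1, seconds=30)
    attrs = {key: getattr(value, key) for key in ('days', 'seconds', 'microseconds')}
    print(json.dumps({'type': type(value).__name__, 'attrs': attrs}))
    # {"type": "timedelta", "attrs": {"days": 1, "seconds": 30, "microseconds": 0}}
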
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index 81164ff00ca0a..3ead44fec792f 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -95,20 +95,19 @@ def color_fg(cls, state):
return 'white'
return 'black'
- running = frozenset([
- RUNNING,
- SENSING
- ])
+ running = frozenset([RUNNING, SENSING])
"""
A list of states indicating that a task is being executed.
"""
- finished = frozenset([
- SUCCESS,
- FAILED,
- SKIPPED,
- UPSTREAM_FAILED,
- ])
+ finished = frozenset(
+ [
+ SUCCESS,
+ FAILED,
+ SKIPPED,
+ UPSTREAM_FAILED,
+ ]
+ )
"""
A list of states indicating a task has reached a terminal state (i.e. it has "finished") and needs no
further action.
@@ -118,16 +117,18 @@ def color_fg(cls, state):
case, it is no longer running.
"""
- unfinished = frozenset([
- NONE,
- SCHEDULED,
- QUEUED,
- RUNNING,
- SENSING,
- SHUTDOWN,
- UP_FOR_RETRY,
- UP_FOR_RESCHEDULE,
- ])
+ unfinished = frozenset(
+ [
+ NONE,
+ SCHEDULED,
+ QUEUED,
+ RUNNING,
+ SENSING,
+ SHUTDOWN,
+ UP_FOR_RETRY,
+ UP_FOR_RESCHEDULE,
+ ]
+ )
"""
A list of states indicating that a task either has not completed
a run or has not even started.
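
Functionally nothing changes here: the collections remain frozensets, so membership checks keep working as before, e.g.:

    from airflow.utils.state import State

    assert State.SUCCESS in State.finished
    assert State.RUNNING in State.unfinished
    assert State.SUCCESS not in State.unfinished
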
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index ee77694269654..d3d29c2a99668 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -180,8 +180,10 @@ def update_relative(self, other: "TaskMixin", upstream=True) -> None:
# Handles setting relationship between a TaskGroup and a task
for task in other.roots:
if not isinstance(task, BaseOperator):
- raise AirflowException("Relationships can only be set between TaskGroup "
- f"or operators; received {task.__class__.__name__}")
+ raise AirflowException(
+ "Relationships can only be set between TaskGroup "
+ f"or operators; received {task.__class__.__name__}"
+ )
if upstream:
self.upstream_task_ids.add(task.task_id)
@@ -189,9 +191,7 @@ def update_relative(self, other: "TaskMixin", upstream=True) -> None:
self.downstream_task_ids.add(task.task_id)
def _set_relative(
- self,
- task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
- upstream: bool = False
+ self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]], upstream: bool = False
) -> None:
"""
Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.
@@ -210,15 +210,11 @@ def _set_relative(
for task_like in task_or_task_list:
self.update_relative(task_like, upstream)
- def set_downstream(
- self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]
- ) -> None:
+ def set_downstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]) -> None:
"""Set a TaskGroup/task/list of task downstream of this TaskGroup."""
self._set_relative(task_or_task_list, upstream=False)
- def set_upstream(
- self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]
- ) -> None:
+ def set_upstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]) -> None:
"""Set a TaskGroup/task/list of task upstream of this TaskGroup."""
self._set_relative(task_or_task_list, upstream=True)
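
A hedged sketch of the collapsed set_downstream/set_upstream signatures in use; the DAG, operator, and ids are illustrative:

    from airflow.models import DAG
    from airflow.operators.dummy_operator import DummyOperator
    from airflow.utils.dates import days_ago
    from airflow.utils.task_group import TaskGroup

    with DAG('example_task_group', start_date=days_ago(1)):
        with TaskGroup('section_1') as section_1:
            DummyOperator(task_id='task_1')
        end = DummyOperator(task_id='end')
        section_1.set_downstream(end)  # accepts a task, a TaskGroup, or a list of either
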
diff --git a/airflow/utils/timeout.py b/airflow/utils/timeout.py
index 8c5b2490786db..b94a7bb63893e 100644
--- a/airflow/utils/timeout.py
+++ b/airflow/utils/timeout.py
@@ -31,7 +31,7 @@ def __init__(self, seconds=1, error_message='Timeout'):
self.seconds = seconds
self.error_message = error_message + ', PID: ' + str(os.getpid())
- def handle_timeout(self, signum, frame): # pylint: disable=unused-argument
+ def handle_timeout(self, signum, frame): # pylint: disable=unused-argument
"""Logs information and raises AirflowTaskTimeout."""
self.log.error("Process timed out, PID: %s", str(os.getpid()))
raise AirflowTaskTimeout(self.error_message)
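
A usage sketch of the handler above; timeout is a signal-based context manager, so it only works in the main thread, and the sleep is illustrative:

    import time

    from airflow.exceptions import AirflowTaskTimeout
    from airflow.utils.timeout import timeout

    try:
        with timeout(seconds=1):
            time.sleep(5)
    except AirflowTaskTimeout as err:
        print(err)  # Timeout, PID: <pid>
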
diff --git a/airflow/utils/timezone.py b/airflow/utils/timezone.py
index 95a555f3624e5..d302cbe1a7c6c 100644
--- a/airflow/utils/timezone.py
+++ b/airflow/utils/timezone.py
@@ -109,8 +109,7 @@ def make_aware(value, timezone=None):
# Check that we won't overwrite the timezone of an aware datetime.
if is_localized(value):
- raise ValueError(
- "make_aware expects a naive datetime, got %s" % value)
+ raise ValueError("make_aware expects a naive datetime, got %s" % value)
if hasattr(value, 'fold'):
# In case of python 3.6 we want to do the same that pendulum does for python3.5
# i.e in case we move clock back we want to schedule the run at the time of the second
@@ -146,13 +145,9 @@ def make_naive(value, timezone=None):
date = value.astimezone(timezone)
# cross library compatibility
- naive = dt.datetime(date.year,
- date.month,
- date.day,
- date.hour,
- date.minute,
- date.second,
- date.microsecond)
+ naive = dt.datetime(
+ date.year, date.month, date.day, date.hour, date.minute, date.second, date.microsecond
+ )
return naive
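
For reference, the two helpers are inverses around the configured default timezone (UTC unless overridden); a small sketch:

    import datetime as dt

    from airflow.utils import timezone

    naive = dt.datetime(2020, 10, 7, 12, 0, 0)
    aware = timezone.make_aware(naive)  # attaches the default timezone
    back = timezone.make_naive(aware)   # converts, then drops tzinfo
    print(aware.tzinfo, back.tzinfo)    # e.g. Timezone('UTC') None
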
diff --git a/airflow/utils/weekday.py b/airflow/utils/weekday.py
index d83b134956099..698ec6067187f 100644
--- a/airflow/utils/weekday.py
+++ b/airflow/utils/weekday.py
@@ -42,8 +42,6 @@ def get_weekday_number(cls, week_day_str):
sanitized_week_day_str = week_day_str.upper()
if sanitized_week_day_str not in cls.__members__:
- raise AttributeError(
- f'Invalid Week Day passed: "{week_day_str}"'
- )
+ raise AttributeError(f'Invalid Week Day passed: "{week_day_str}"')
return cls[sanitized_week_day_str]
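
A usage sketch of the simplified error path; the lookup is case-insensitive and unknown names raise the AttributeError above:

    from airflow.utils.weekday import WeekDay

    print(WeekDay.get_weekday_number('monday'))  # WeekDay.MONDAY
    try:
        WeekDay.get_weekday_number('Funday')
    except AttributeError as err:
        print(err)  # Invalid Week Day passed: "Funday"
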
diff --git a/airflow/www/api/experimental/endpoints.py b/airflow/www/api/experimental/endpoints.py
index 60b5dc6102c7e..ada059b8575ef 100644
--- a/airflow/www/api/experimental/endpoints.py
+++ b/airflow/www/api/experimental/endpoints.py
@@ -42,6 +42,7 @@
def requires_authentication(function: T):
"""Decorator for functions that require authentication"""
+
@wraps(function)
def decorated(*args, **kwargs):
return current_app.api_auth.requires_authentication(function)(*args, **kwargs)
@@ -98,15 +99,15 @@ def trigger_dag(dag_id):
except ValueError:
error_message = (
'Given execution date, {}, could not be identified '
- 'as a date. Example date format: 2015-11-16T14:34:15+00:00'
- .format(execution_date))
+ 'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format(execution_date)
+ )
log.error(error_message)
response = jsonify({'error': error_message})
response.status_code = 400
return response
- replace_microseconds = (execution_date is None)
+ replace_microseconds = execution_date is None
if 'replace_microseconds' in data:
replace_microseconds = to_boolean(data['replace_microseconds'])
@@ -122,9 +123,7 @@ def trigger_dag(dag_id):
log.info("User %s created %s", g.user, dr)
response = jsonify(
- message=f"Created {dr}",
- execution_date=dr.execution_date.isoformat(),
- run_id=dr.run_id
+ message=f"Created {dr}", execution_date=dr.execution_date.isoformat(), run_id=dr.run_id
)
return response
@@ -206,9 +205,7 @@ def task_info(dag_id, task_id):
return response
# JSONify and return.
- fields = {k: str(v)
- for k, v in vars(t_info).items()
- if not k.startswith('_')}
+ fields = {k: str(v) for k, v in vars(t_info).items() if not k.startswith('_')}
return jsonify(fields)
@@ -236,8 +233,8 @@ def dag_is_paused(dag_id):
@api_experimental.route(
- '/dags/<string:dag_id>/dag_runs/<string:execution_date>/tasks/<string:task_id>',
- methods=['GET'])
+ '/dags/<string:dag_id>/dag_runs/<string:execution_date>/tasks/<string:task_id>', methods=['GET']
+)
@requires_authentication
def task_instance_info(dag_id, execution_date, task_id):
"""
@@ -252,8 +249,8 @@ def task_instance_info(dag_id, execution_date, task_id):
except ValueError:
error_message = (
'Given execution date, {}, could not be identified '
- 'as a date. Example date format: 2015-11-16T14:34:15+00:00'
- .format(execution_date))
+ 'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format(execution_date)
+ )
log.error(error_message)
response = jsonify({'error': error_message})
response.status_code = 400
@@ -269,15 +266,11 @@ def task_instance_info(dag_id, execution_date, task_id):
return response
# JSONify and return.
- fields = {k: str(v)
- for k, v in vars(ti_info).items()
- if not k.startswith('_')}
+ fields = {k: str(v) for k, v in vars(ti_info).items() if not k.startswith('_')}
return jsonify(fields)
-@api_experimental.route(
- '/dags/<string:dag_id>/dag_runs/<string:execution_date>',
- methods=['GET'])
+@api_experimental.route('/dags/<string:dag_id>/dag_runs/<string:execution_date>', methods=['GET'])
@requires_authentication
def dag_run_status(dag_id, execution_date):
"""
@@ -292,8 +285,8 @@ def dag_run_status(dag_id, execution_date):
except ValueError:
error_message = (
'Given execution date, {}, could not be identified '
- 'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format(
- execution_date))
+ 'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format(execution_date)
+ )
log.error(error_message)
response = jsonify({'error': error_message})
response.status_code = 400
@@ -316,18 +309,21 @@ def dag_run_status(dag_id, execution_date):
def latest_dag_runs():
"""Returns the latest DagRun for each DAG formatted for the UI"""
from airflow.models import DagRun
+
dagruns = DagRun.get_latest_runs()
payload = []
for dagrun in dagruns:
if dagrun.execution_date:
- payload.append({
- 'dag_id': dagrun.dag_id,
- 'execution_date': dagrun.execution_date.isoformat(),
- 'start_date': ((dagrun.start_date or '') and
- dagrun.start_date.isoformat()),
- 'dag_run_url': url_for('Airflow.graph', dag_id=dagrun.dag_id,
- execution_date=dagrun.execution_date)
- })
+ payload.append(
+ {
+ 'dag_id': dagrun.dag_id,
+ 'execution_date': dagrun.execution_date.isoformat(),
+ 'start_date': ((dagrun.start_date or '') and dagrun.start_date.isoformat()),
+ 'dag_run_url': url_for(
+ 'Airflow.graph', dag_id=dagrun.dag_id, execution_date=dagrun.execution_date
+ ),
+ }
+ )
return jsonify(items=payload) # old flask versions don't support jsonifying arrays
@@ -392,8 +388,7 @@ def delete_pool(name):
return jsonify(pool.to_json())
-@api_experimental.route('/lineage/<string:dag_id>/<string:execution_date>',
- methods=['GET'])
+@api_experimental.route('/lineage/<string:dag_id>/<string:execution_date>', methods=['GET'])
def get_lineage(dag_id: str, execution_date: str):
"""Get Lineage details for a DagRun"""
# Convert string datetime into actual datetime
@@ -402,8 +397,8 @@ def get_lineage(dag_id: str, execution_date: str):
except ValueError:
error_message = (
'Given execution date, {}, could not be identified '
- 'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format(
- execution_date))
+ 'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format(execution_date)
+ )
log.error(error_message)
response = jsonify({'error': error_message})
response.status_code = 400
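
For context, that error string is what the experimental REST endpoints return for an unparsable execution_date; a hedged client-side sketch (the localhost URL and dag_id are assumptions, and the experimental API must be enabled and reachable):

    import requests

    resp = requests.post(
        'http://localhost:8080/api/experimental/dags/example_dag/dag_runs',
        json={'execution_date': 'not-a-date'},
    )
    print(resp.status_code)      # 400
    print(resp.json()['error'])  # Given execution date, not-a-date, could not be identified ...
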
diff --git a/airflow/www/app.py b/airflow/www/app.py
index 5338dddcc17b5..04ce366425155 100644
--- a/airflow/www/app.py
+++ b/airflow/www/app.py
@@ -37,7 +37,11 @@
from airflow.www.extensions.init_security import init_api_experimental_auth, init_xframe_protection
from airflow.www.extensions.init_session import init_logout_timeout, init_permanent_session
from airflow.www.extensions.init_views import (
- init_api_connexion, init_api_experimental, init_appbuilder_views, init_error_handlers, init_flash_views,
+ init_api_connexion,
+ init_api_experimental,
+ init_appbuilder_views,
+ init_error_handlers,
+ init_flash_views,
init_plugins,
)
from airflow.www.extensions.init_wsgi_middlewares import init_wsgi_middleware
diff --git a/airflow/www/auth.py b/airflow/www/auth.py
index 41bdf8666f5fc..370f7c626d41f 100644
--- a/airflow/www/auth.py
+++ b/airflow/www/auth.py
@@ -25,6 +25,7 @@
def has_access(permissions: Optional[Sequence[Tuple[str, str]]] = None) -> Callable[[T], T]:
"""Factory for decorator that checks current user's permissions against required permissions."""
+
def requires_access_decorator(func: T):
@wraps(func)
def decorated(*args, **kwargs):
@@ -34,7 +35,12 @@ def decorated(*args, **kwargs):
else:
access_denied = "Access is Denied"
flash(access_denied, "danger")
- return redirect(url_for(appbuilder.sm.auth_view.__class__.__name__ + ".login", next=request.url,))
+ return redirect(
+ url_for(
+ appbuilder.sm.auth_view.__class__.__name__ + ".login",
+ next=request.url,
+ )
+ )
return cast(T, decorated)
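
A hedged sketch of the decorator applied to a Flask-AppBuilder view; the view class and route are illustrative, while the permission pair mirrors the one used in views.py further down:

    from flask_appbuilder import BaseView, expose

    from airflow.security import permissions
    from airflow.www import auth


    class HelloView(BaseView):
        route_base = '/hello'

        @expose('/')
        @auth.has_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE)])
        def index(self):
            return "ok"
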
diff --git a/airflow/www/extensions/init_jinja_globals.py b/airflow/www/extensions/init_jinja_globals.py
index 3554fff2f74fe..ff5a0a5899f39 100644
--- a/airflow/www/extensions/init_jinja_globals.py
+++ b/airflow/www/extensions/init_jinja_globals.py
@@ -64,7 +64,7 @@ def prepare_jinja_globals():
'log_animation_speed': conf.getint('webserver', 'log_animation_speed', fallback=1000),
'state_color_mapping': STATE_COLORS,
'airflow_version': airflow_version,
- 'git_version': git_version
+ 'git_version': git_version,
}
if 'analytics_tool' in conf.getsection('webserver'):
diff --git a/airflow/www/extensions/init_security.py b/airflow/www/extensions/init_security.py
index e2cc7bf6f8c78..544deebeb3af4 100644
--- a/airflow/www/extensions/init_security.py
+++ b/airflow/www/extensions/init_security.py
@@ -53,8 +53,5 @@ def init_api_experimental_auth(app):
app.api_auth = import_module(auth_backend)
app.api_auth.init_app(app)
except ImportError as err:
- log.critical(
- "Cannot import %s for API authentication due to: %s",
- auth_backend, err
- )
+ log.critical("Cannot import %s for API authentication due to: %s", auth_backend, err)
raise AirflowException(err)
diff --git a/airflow/www/forms.py b/airflow/www/forms.py
index 4c6cb90a11033..30abea5565747 100644
--- a/airflow/www/forms.py
+++ b/airflow/www/forms.py
@@ -22,14 +22,23 @@
import pendulum
from flask_appbuilder.fieldwidgets import (
- BS3PasswordFieldWidget, BS3TextAreaFieldWidget, BS3TextFieldWidget, Select2Widget,
+ BS3PasswordFieldWidget,
+ BS3TextAreaFieldWidget,
+ BS3TextFieldWidget,
+ Select2Widget,
)
from flask_appbuilder.forms import DynamicForm
from flask_babel import lazy_gettext
from flask_wtf import FlaskForm
from wtforms import widgets
from wtforms.fields import (
- BooleanField, Field, IntegerField, PasswordField, SelectField, StringField, TextAreaField,
+ BooleanField,
+ Field,
+ IntegerField,
+ PasswordField,
+ SelectField,
+ StringField,
+ TextAreaField,
)
from wtforms.validators import DataRequired, NumberRange, Optional
@@ -86,8 +95,7 @@ def _get_default_timezone(self):
class DateTimeForm(FlaskForm):
"""Date filter form needed for task views"""
- execution_date = DateTimeWithTimezoneField(
- "Execution date", widget=AirflowDateTimePickerWidget())
+ execution_date = DateTimeWithTimezoneField("Execution date", widget=AirflowDateTimePickerWidget())
class DateTimeWithNumRunsForm(FlaskForm):
@@ -97,14 +105,19 @@ class DateTimeWithNumRunsForm(FlaskForm):
"""
base_date = DateTimeWithTimezoneField(
- "Anchor date", widget=AirflowDateTimePickerWidget(), default=timezone.utcnow())
- num_runs = SelectField("Number of runs", default=25, choices=(
- (5, "5"),
- (25, "25"),
- (50, "50"),
- (100, "100"),
- (365, "365"),
- ))
+ "Anchor date", widget=AirflowDateTimePickerWidget(), default=timezone.utcnow()
+ )
+ num_runs = SelectField(
+ "Number of runs",
+ default=25,
+ choices=(
+ (5, "5"),
+ (25, "25"),
+ (50, "50"),
+ (100, "100"),
+ (365, "365"),
+ ),
+ )
class DateTimeWithNumRunsWithDagRunsForm(DateTimeWithNumRunsForm):
@@ -116,33 +129,26 @@ class DateTimeWithNumRunsWithDagRunsForm(DateTimeWithNumRunsForm):
class DagRunForm(DynamicForm):
"""Form for editing and adding DAG Run"""
- dag_id = StringField(
- lazy_gettext('Dag Id'),
- validators=[DataRequired()],
- widget=BS3TextFieldWidget())
- start_date = DateTimeWithTimezoneField(
- lazy_gettext('Start Date'),
- widget=AirflowDateTimePickerWidget())
- end_date = DateTimeWithTimezoneField(
- lazy_gettext('End Date'),
- widget=AirflowDateTimePickerWidget())
- run_id = StringField(
- lazy_gettext('Run Id'),
- validators=[DataRequired()],
- widget=BS3TextFieldWidget())
+ dag_id = StringField(lazy_gettext('Dag Id'), validators=[DataRequired()], widget=BS3TextFieldWidget())
+ start_date = DateTimeWithTimezoneField(lazy_gettext('Start Date'), widget=AirflowDateTimePickerWidget())
+ end_date = DateTimeWithTimezoneField(lazy_gettext('End Date'), widget=AirflowDateTimePickerWidget())
+ run_id = StringField(lazy_gettext('Run Id'), validators=[DataRequired()], widget=BS3TextFieldWidget())
state = SelectField(
lazy_gettext('State'),
- choices=(('success', 'success'), ('running', 'running'), ('failed', 'failed'),),
- widget=Select2Widget())
+ choices=(
+ ('success', 'success'),
+ ('running', 'running'),
+ ('failed', 'failed'),
+ ),
+ widget=Select2Widget(),
+ )
execution_date = DateTimeWithTimezoneField(
- lazy_gettext('Execution Date'),
- widget=AirflowDateTimePickerWidget())
- external_trigger = BooleanField(
- lazy_gettext('External Trigger'))
+ lazy_gettext('Execution Date'), widget=AirflowDateTimePickerWidget()
+ )
+ external_trigger = BooleanField(lazy_gettext('External Trigger'))
conf = TextAreaField(
- lazy_gettext('Conf'),
- validators=[ValidJson(), Optional()],
- widget=BS3TextAreaFieldWidget())
+ lazy_gettext('Conf'), validators=[ValidJson(), Optional()], widget=BS3TextAreaFieldWidget()
+ )
def populate_obj(self, item):
"""Populates the attributes of the passed obj with data from the form’s fields."""
@@ -214,83 +220,62 @@ def populate_obj(self, item):
class ConnectionForm(DynamicForm):
"""Form for editing and adding Connection"""
- conn_id = StringField(
- lazy_gettext('Conn Id'),
- widget=BS3TextFieldWidget())
+ conn_id = StringField(lazy_gettext('Conn Id'), widget=BS3TextFieldWidget())
conn_type = SelectField(
lazy_gettext('Conn Type'),
choices=sorted(_connection_types, key=itemgetter(1)), # pylint: disable=protected-access
- widget=Select2Widget())
- host = StringField(
- lazy_gettext('Host'),
- widget=BS3TextFieldWidget())
- schema = StringField(
- lazy_gettext('Schema'),
- widget=BS3TextFieldWidget())
- login = StringField(
- lazy_gettext('Login'),
- widget=BS3TextFieldWidget())
- password = PasswordField(
- lazy_gettext('Password'),
- widget=BS3PasswordFieldWidget())
- port = IntegerField(
- lazy_gettext('Port'),
- validators=[Optional()],
- widget=BS3TextFieldWidget())
- extra = TextAreaField(
- lazy_gettext('Extra'),
- widget=BS3TextAreaFieldWidget())
+ widget=Select2Widget(),
+ )
+ host = StringField(lazy_gettext('Host'), widget=BS3TextFieldWidget())
+ schema = StringField(lazy_gettext('Schema'), widget=BS3TextFieldWidget())
+ login = StringField(lazy_gettext('Login'), widget=BS3TextFieldWidget())
+ password = PasswordField(lazy_gettext('Password'), widget=BS3PasswordFieldWidget())
+ port = IntegerField(lazy_gettext('Port'), validators=[Optional()], widget=BS3TextFieldWidget())
+ extra = TextAreaField(lazy_gettext('Extra'), widget=BS3TextAreaFieldWidget())
# Used to customized the form, the forms elements get rendered
# and results are stored in the extra field as json. All of these
# need to be prefixed with extra__ and then the conn_type ___ as in
# extra__{conn_type}__name. You can also hide form elements and rename
# others from the connection_form.js file
- extra__jdbc__drv_path = StringField(
- lazy_gettext('Driver Path'),
- widget=BS3TextFieldWidget())
- extra__jdbc__drv_clsname = StringField(
- lazy_gettext('Driver Class'),
- widget=BS3TextFieldWidget())
+ extra__jdbc__drv_path = StringField(lazy_gettext('Driver Path'), widget=BS3TextFieldWidget())
+ extra__jdbc__drv_clsname = StringField(lazy_gettext('Driver Class'), widget=BS3TextFieldWidget())
extra__google_cloud_platform__project = StringField(
- lazy_gettext('Project Id'),
- widget=BS3TextFieldWidget())
+ lazy_gettext('Project Id'), widget=BS3TextFieldWidget()
+ )
extra__google_cloud_platform__key_path = StringField(
- lazy_gettext('Keyfile Path'),
- widget=BS3TextFieldWidget())
+ lazy_gettext('Keyfile Path'), widget=BS3TextFieldWidget()
+ )
extra__google_cloud_platform__keyfile_dict = PasswordField(
- lazy_gettext('Keyfile JSON'),
- widget=BS3PasswordFieldWidget())
+ lazy_gettext('Keyfile JSON'), widget=BS3PasswordFieldWidget()
+ )
extra__google_cloud_platform__scope = StringField(
- lazy_gettext('Scopes (comma separated)'),
- widget=BS3TextFieldWidget())
+ lazy_gettext('Scopes (comma separated)'), widget=BS3TextFieldWidget()
+ )
extra__google_cloud_platform__num_retries = IntegerField(
lazy_gettext('Number of Retries'),
validators=[NumberRange(min=0)],
widget=BS3TextFieldWidget(),
- default=5)
- extra__grpc__auth_type = StringField(
- lazy_gettext('Grpc Auth Type'),
- widget=BS3TextFieldWidget())
+ default=5,
+ )
+ extra__grpc__auth_type = StringField(lazy_gettext('Grpc Auth Type'), widget=BS3TextFieldWidget())
extra__grpc__credential_pem_file = StringField(
- lazy_gettext('Credential Keyfile Path'),
- widget=BS3TextFieldWidget())
- extra__grpc__scopes = StringField(
- lazy_gettext('Scopes (comma separated)'),
- widget=BS3TextFieldWidget())
+ lazy_gettext('Credential Keyfile Path'), widget=BS3TextFieldWidget()
+ )
+ extra__grpc__scopes = StringField(lazy_gettext('Scopes (comma separated)'), widget=BS3TextFieldWidget())
extra__yandexcloud__service_account_json = PasswordField(
lazy_gettext('Service account auth JSON'),
widget=BS3PasswordFieldWidget(),
description='Service account auth JSON. Looks like '
- '{"id", "...", "service_account_id": "...", "private_key": "..."}. '
- 'Will be used instead of OAuth token and SA JSON file path field if specified.',
+ '{"id", "...", "service_account_id": "...", "private_key": "..."}. '
+ 'Will be used instead of OAuth token and SA JSON file path field if specified.',
)
extra__yandexcloud__service_account_json_path = StringField(
lazy_gettext('Service account auth JSON file path'),
widget=BS3TextFieldWidget(),
description='Service account auth JSON file path. File content looks like '
- '{"id", "...", "service_account_id": "...", "private_key": "..."}. '
- 'Will be used instead of OAuth token if specified.',
+ '{"id", "...", "service_account_id": "...", "private_key": "..."}. '
+ 'Will be used instead of OAuth token if specified.',
)
extra__yandexcloud__oauth = PasswordField(
lazy_gettext('OAuth Token'),
@@ -308,14 +293,11 @@ class ConnectionForm(DynamicForm):
description='Optional. This key will be placed to all created Compute nodes'
'to let you have a root shell there',
)
- extra__kubernetes__in_cluster = BooleanField(
- lazy_gettext('In cluster configuration'))
+ extra__kubernetes__in_cluster = BooleanField(lazy_gettext('In cluster configuration'))
extra__kubernetes__kube_config_path = StringField(
- lazy_gettext('Kube config path'),
- widget=BS3TextFieldWidget())
+ lazy_gettext('Kube config path'), widget=BS3TextFieldWidget()
+ )
extra__kubernetes__kube_config = StringField(
- lazy_gettext('Kube config (JSON format)'),
- widget=BS3TextFieldWidget())
- extra__kubernetes__namespace = StringField(
- lazy_gettext('Namespace'),
- widget=BS3TextFieldWidget())
+ lazy_gettext('Kube config (JSON format)'), widget=BS3TextFieldWidget()
+ )
+ extra__kubernetes__namespace = StringField(lazy_gettext('Namespace'), widget=BS3TextFieldWidget())
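
The comment block above describes the extra__<conn_type>__<field> convention: values captured by these custom widgets are serialized into the connection's extra JSON. A hedged sketch of the resulting storage shape (conn_id and paths are illustrative):

    import json

    from airflow.models.connection import Connection

    conn = Connection(
        conn_id='my_jdbc',
        conn_type='jdbc',
        extra=json.dumps(
            {
                'extra__jdbc__drv_path': '/opt/drivers/example.jar',
                'extra__jdbc__drv_clsname': 'com.example.Driver',
            }
        ),
    )
    print(conn.extra_dejson['extra__jdbc__drv_path'])
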
diff --git a/airflow/www/security.py b/airflow/www/security.py
index be35ddbd65ef6..0654732597f7b 100644
--- a/airflow/www/security.py
+++ b/airflow/www/security.py
@@ -130,8 +130,14 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin):
ROLE_CONFIGS = [
{'role': 'Viewer', 'perms': VIEWER_PERMISSIONS},
- {'role': 'User', 'perms': VIEWER_PERMISSIONS + USER_PERMISSIONS,},
- {'role': 'Op', 'perms': VIEWER_PERMISSIONS + USER_PERMISSIONS + OP_PERMISSIONS,},
+ {
+ 'role': 'User',
+ 'perms': VIEWER_PERMISSIONS + USER_PERMISSIONS,
+ },
+ {
+ 'role': 'Op',
+ 'perms': VIEWER_PERMISSIONS + USER_PERMISSIONS + OP_PERMISSIONS,
+ },
{
'role': 'Admin',
'perms': VIEWER_PERMISSIONS + USER_PERMISSIONS + OP_PERMISSIONS + ADMIN_PERMISSIONS,
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 697f8806b1903..6095fb306ce5a 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -72,11 +72,7 @@ def get_params(**kwargs):
return urlencode({d: v for d, v in kwargs.items() if v is not None})
-def generate_pages(current_page,
- num_of_pages,
- search=None,
- status=None,
- window=7):
+def generate_pages(current_page, num_of_pages, search=None, status=None, window=7):
"""
Generates the HTML for a paging component using a similar logic to the paging
auto-generated by Flask managed views. The paging component defines a number of
@@ -97,43 +93,51 @@ def generate_pages(current_page,
:return: the HTML string of the paging component
"""
void_link = 'javascript:void(0)'
- first_node = Markup("""
+ first_node = Markup(
+ """
«
-""")
+"""
+ )
- previous_node = Markup("""
+ previous_node = Markup(
+ """
‹
-""")
+"""
+ )
- next_node = Markup("""
+ next_node = Markup(
+ """
›
-""")
+"""
+ )
- last_node = Markup("""
+ last_node = Markup(
+ """
»
-""")
+"""
+ )
- page_node = Markup("""
+ page_node = Markup(
+ """
{page_num}
-""")
+"""
+ )
output = [Markup('<ul class="pagination" style="margin-top:0;">')]
@@ -186,10 +191,8 @@ def epoch(dttm):
def json_response(obj):
"""Returns a json response from a json serializable python object"""
return Response(
- response=json.dumps(
- obj, indent=4, cls=AirflowJsonEncoder),
- status=200,
- mimetype="application/json")
+ response=json.dumps(obj, indent=4, cls=AirflowJsonEncoder), status=200, mimetype="application/json"
+ )
def make_cache_key(*args, **kwargs):
@@ -204,16 +207,10 @@ def task_instance_link(attr):
dag_id = attr.get('dag_id')
task_id = attr.get('task_id')
execution_date = attr.get('execution_date')
- url = url_for(
- 'Airflow.task',
- dag_id=dag_id,
- task_id=task_id,
- execution_date=execution_date.isoformat())
+ url = url_for('Airflow.task', dag_id=dag_id, task_id=task_id, execution_date=execution_date.isoformat())
url_root = url_for(
- 'Airflow.graph',
- dag_id=dag_id,
- root=task_id,
- execution_date=execution_date.isoformat())
+ 'Airflow.graph', dag_id=dag_id, root=task_id, execution_date=execution_date.isoformat()
+ )
return Markup( # noqa
"""
@@ -223,7 +220,8 @@ def task_instance_link(attr):
aria-hidden="true">filter_alt
- """).format(url=url, task_id=task_id, url_root=url_root)
+ """
+ ).format(url=url, task_id=task_id, url_root=url_root)
def state_token(state):
@@ -234,7 +232,8 @@ def state_token(state):
"""
{state}
- """).format(color=color, state=state, fg_color=fg_color)
+ """
+ ).format(color=color, state=state, fg_color=fg_color)
def state_f(attr):
@@ -245,14 +244,17 @@ def state_f(attr):
def nobr_f(attr_name):
"""Returns a formatted string with HTML with a Non-breaking Text element"""
+
def nobr(attr):
f = attr.get(attr_name)
return Markup("<nobr>{}</nobr>").format(f) # noqa
+
return nobr
def datetime_f(attr_name):
"""Returns a formatted string with HTML for given DataTime"""
+
def dt(attr): # pylint: disable=invalid-name
f = attr.get(attr_name)
as_iso = f.isoformat() if f else ''
@@ -263,16 +265,21 @@ def dt(attr): # pylint: disable=invalid-name
f = f[5:]
# The empty title will be replaced in JS code when non-UTC dates are displayed
return Markup('').format(as_iso, f) # noqa
+
return dt
+
+
# pylint: enable=invalid-name
def json_f(attr_name):
"""Returns a formatted string with HTML for given JSON serializable"""
+
def json_(attr):
f = attr.get(attr_name)
serialized = json.dumps(f)
return Markup('<nobr>{}</nobr>').format(serialized) # noqa
+
return json_
@@ -280,12 +287,8 @@ def dag_link(attr):
"""Generates a URL to the Graph View for a Dag."""
dag_id = attr.get('dag_id')
execution_date = attr.get('execution_date')
- url = url_for(
- 'Airflow.graph',
- dag_id=dag_id,
- execution_date=execution_date)
- return Markup( # noqa
- '<a href="{}">{}</a>').format(url, dag_id) # noqa
+ url = url_for('Airflow.graph', dag_id=dag_id, execution_date=execution_date)
+ return Markup('<a href="{}">{}</a>').format(url, dag_id) # noqa # noqa
def dag_run_link(attr):
@@ -293,13 +296,8 @@ def dag_run_link(attr):
dag_id = attr.get('dag_id')
run_id = attr.get('run_id')
execution_date = attr.get('execution_date')
- url = url_for(
- 'Airflow.graph',
- dag_id=dag_id,
- run_id=run_id,
- execution_date=execution_date)
- return Markup( # noqa
- '<a href="{url}">{run_id}</a>').format(url=url, run_id=run_id) # noqa
+ url = url_for('Airflow.graph', dag_id=dag_id, run_id=run_id, execution_date=execution_date)
+ return Markup('<a href="{url}">{run_id}</a>').format(url=url, run_id=run_id) # noqa # noqa
def pygment_html_render(s, lexer=lexers.TextLexer): # noqa pylint: disable=no-member
@@ -328,9 +326,7 @@ def wrapped_markdown(s, css_class=None):
if s is None:
return None
- return Markup(
- f'<div class="rich_doc {css_class}" >' + markdown.markdown(s) + "</div>"
- )
+ return Markup(f'<div class="rich_doc {css_class}" >' + markdown.markdown(s) + "</div>")
# pylint: disable=no-member
@@ -353,6 +349,8 @@ def get_attr_renderer():
'rst': lambda x: render(x, lexers.RstLexer),
'yaml': lambda x: render(x, lexers.YamlLexer),
}
+
+
# pylint: enable=no-member
@@ -397,12 +395,11 @@ class UtcAwareFilterConverter(fab_sqlafilters.SQLAFilterConverter): # noqa: D10
"""Retrieve conversion tables for UTC-Aware filters."""
conversion_table = (
- (('is_utcdatetime', [UtcAwareFilterEqual,
- UtcAwareFilterGreater,
- UtcAwareFilterSmaller,
- UtcAwareFilterNotEqual]),) +
- fab_sqlafilters.SQLAFilterConverter.conversion_table
- )
+ (
+ 'is_utcdatetime',
+ [UtcAwareFilterEqual, UtcAwareFilterGreater, UtcAwareFilterSmaller, UtcAwareFilterNotEqual],
+ ),
+ ) + fab_sqlafilters.SQLAFilterConverter.conversion_table
class CustomSQLAInterface(SQLAInterface):
@@ -418,11 +415,9 @@ def __init__(self, obj, session=None):
def clean_column_names():
if self.list_properties:
- self.list_properties = {
- k.lstrip('_'): v for k, v in self.list_properties.items()}
+ self.list_properties = {k.lstrip('_'): v for k, v in self.list_properties.items()}
if self.list_columns:
- self.list_columns = {
- k.lstrip('_'): v for k, v in self.list_columns.items()}
+ self.list_columns = {k.lstrip('_'): v for k, v in self.list_columns.items()}
clean_column_names()
@@ -432,9 +427,11 @@ def is_utcdatetime(self, col_name):
if col_name in self.list_columns:
obj = self.list_columns[col_name].type
- return isinstance(obj, UtcDateTime) or \
- isinstance(obj, sqla.types.TypeDecorator) and \
- isinstance(obj.impl, UtcDateTime)
+ return (
+ isinstance(obj, UtcDateTime)
+ or isinstance(obj, sqla.types.TypeDecorator)
+ and isinstance(obj.impl, UtcDateTime)
+ )
return False
filter_converter_class = UtcAwareFilterConverter
@@ -444,6 +441,5 @@ def is_utcdatetime(self, col_name):
# subclass) so we have no other option than to edit the conversion table in
# place
FieldConverter.conversion_table = (
- (('is_utcdatetime', DateTimeWithTimezoneField, AirflowDateTimePickerWidget),) +
- FieldConverter.conversion_table
-)
+ ('is_utcdatetime', DateTimeWithTimezoneField, AirflowDateTimePickerWidget),
+) + FieldConverter.conversion_table
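
One subtlety worth noting: Black only re-wraps the is_utcdatetime expression above, it does not change its meaning, because and binds tighter than or; a one-line check:

    a, b, c = False, True, True
    assert (a or b and c) == (a or (b and c))  # `and` binds tighter than `or`
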
diff --git a/airflow/www/validators.py b/airflow/www/validators.py
index 282789c15dbb0..9699a47e46b3e 100644
--- a/airflow/www/validators.py
+++ b/airflow/www/validators.py
@@ -37,17 +37,14 @@ def __call__(self, form, field):
try:
other = form[self.fieldname]
except KeyError:
- raise ValidationError(
- field.gettext("Invalid field name '%s'." % self.fieldname)
- )
+ raise ValidationError(field.gettext("Invalid field name '%s'." % self.fieldname))
if field.data is None or other.data is None:
return
if field.data < other.data:
message_args = {
- 'other_label':
- hasattr(other, 'label') and other.label.text or self.fieldname,
+ 'other_label': hasattr(other, 'label') and other.label.text or self.fieldname,
'other_name': self.fieldname,
}
message = self.message
@@ -77,6 +74,4 @@ def __call__(self, form, field):
json.loads(field.data)
except JSONDecodeError as ex:
message = self.message or f'JSON Validation Error: {ex}'
- raise ValidationError(
- message=field.gettext(message.format(field.data))
- )
+ raise ValidationError(message=field.gettext(message.format(field.data)))
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 3d16e77924bd0..fe818f26fb8f3 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -34,8 +34,19 @@
import nvd3
import sqlalchemy as sqla
from flask import (
- Markup, Response, current_app, escape, flash, g, jsonify, make_response, redirect, render_template,
- request, session as flask_session, url_for,
+ Markup,
+ Response,
+ current_app,
+ escape,
+ flash,
+ g,
+ jsonify,
+ make_response,
+ redirect,
+ render_template,
+ request,
+ session as flask_session,
+ url_for,
)
from flask_appbuilder import BaseView, ModelView, expose
from flask_appbuilder.actions import action
@@ -51,7 +62,8 @@
import airflow
from airflow import models, plugins_manager, settings
from airflow.api.common.experimental.mark_tasks import (
- set_dag_run_state_to_failed, set_dag_run_state_to_success,
+ set_dag_run_state_to_failed,
+ set_dag_run_state_to_success,
)
from airflow.configuration import AIRFLOW_CONFIG, conf
from airflow.exceptions import AirflowException
@@ -76,7 +88,11 @@
from airflow.www import auth, utils as wwwutils
from airflow.www.decorators import action_logging, gzipped
from airflow.www.forms import (
- ConnectionForm, DagRunForm, DateTimeForm, DateTimeWithNumRunsForm, DateTimeWithNumRunsWithDagRunsForm,
+ ConnectionForm,
+ DagRunForm,
+ DateTimeForm,
+ DateTimeWithNumRunsForm,
+ DateTimeWithNumRunsWithDagRunsForm,
)
from airflow.www.widgets import AirflowModelListWidget
@@ -120,9 +136,7 @@ def get_date_time_num_runs_dag_runs_form_data(www_request, session, dag):
drs = (
session.query(DagRun)
- .filter(
- DagRun.dag_id == dag.dag_id,
- DagRun.execution_date <= base_date)
+ .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date <= base_date)
.order_by(desc(DagRun.execution_date))
.limit(num_runs)
.all()
@@ -164,34 +178,39 @@ def task_group_to_dict(task_group):
'style': f"fill:{task_group.ui_color};",
'rx': 5,
'ry': 5,
- }
+ },
}
- children = [task_group_to_dict(child) for child in
- sorted(task_group.children.values(), key=lambda t: t.label)]
+ children = [
+ task_group_to_dict(child) for child in sorted(task_group.children.values(), key=lambda t: t.label)
+ ]
if task_group.upstream_group_ids or task_group.upstream_task_ids:
- children.append({
- 'id': task_group.upstream_join_id,
- 'value': {
- 'label': '',
- 'labelStyle': f"fill:{task_group.ui_fgcolor};",
- 'style': f"fill:{task_group.ui_color};",
- 'shape': 'circle',
+ children.append(
+ {
+ 'id': task_group.upstream_join_id,
+ 'value': {
+ 'label': '',
+ 'labelStyle': f"fill:{task_group.ui_fgcolor};",
+ 'style': f"fill:{task_group.ui_color};",
+ 'shape': 'circle',
+ },
}
- })
+ )
if task_group.downstream_group_ids or task_group.downstream_task_ids:
# This is the join node used to reduce the number of edges between two TaskGroup.
- children.append({
- 'id': task_group.downstream_join_id,
- 'value': {
- 'label': '',
- 'labelStyle': f"fill:{task_group.ui_fgcolor};",
- 'style': f"fill:{task_group.ui_color};",
- 'shape': 'circle',
+ children.append(
+ {
+ 'id': task_group.downstream_join_id,
+ 'value': {
+ 'label': '',
+ 'labelStyle': f"fill:{task_group.ui_fgcolor};",
+ 'style': f"fill:{task_group.ui_color};",
+ 'shape': 'circle',
+ },
}
- })
+ )
return {
"id": task_group.group_id,
@@ -204,7 +223,7 @@ def task_group_to_dict(task_group):
'clusterLabelPos': 'top',
},
'tooltip': task_group.tooltip,
- 'children': children
+ 'children': children,
}
@@ -299,39 +318,47 @@ def get_downstream(task):
for root in dag.roots:
get_downstream(root)
- return [{'source_id': source_id, 'target_id': target_id}
- for source_id, target_id
- in sorted(edges.union(edges_to_add) - edges_to_skip)]
+ return [
+ {'source_id': source_id, 'target_id': target_id}
+ for source_id, target_id in sorted(edges.union(edges_to_add) - edges_to_skip)
+ ]
######################################################################################
# Error handlers
######################################################################################
+
def circles(error): # pylint: disable=unused-argument
"""Show Circles on screen for any error in the Webserver"""
- return render_template(
- 'airflow/circles.html', hostname=socket.getfqdn() if conf.getboolean( # noqa
- 'webserver',
- 'EXPOSE_HOSTNAME',
- fallback=True) else 'redact'), 404
+ return (
+ render_template(
+ 'airflow/circles.html',
+ hostname=socket.getfqdn()
+ if conf.getboolean('webserver', 'EXPOSE_HOSTNAME', fallback=True) # noqa
+ else 'redact',
+ ),
+ 404,
+ )
def show_traceback(error): # pylint: disable=unused-argument
"""Show Traceback for a given error"""
- return render_template(
- 'airflow/traceback.html', # noqa
- python_version=sys.version.split(" ")[0],
- airflow_version=version,
- hostname=socket.getfqdn() if conf.getboolean(
- 'webserver',
- 'EXPOSE_HOSTNAME',
- fallback=True) else 'redact',
- info=traceback.format_exc() if conf.getboolean(
- 'webserver',
- 'EXPOSE_STACKTRACE',
- fallback=True) else 'Error! Please contact server admin.'
- ), 500
+ return (
+ render_template(
+ 'airflow/traceback.html', # noqa
+ python_version=sys.version.split(" ")[0],
+ airflow_version=version,
+ hostname=socket.getfqdn()
+ if conf.getboolean('webserver', 'EXPOSE_HOSTNAME', fallback=True)
+ else 'redact',
+ info=traceback.format_exc()
+ if conf.getboolean('webserver', 'EXPOSE_STACKTRACE', fallback=True)
+ else 'Error! Please contact server admin.',
+ ),
+ 500,
+ )
+
######################################################################################
# BaseViews
@@ -342,6 +369,7 @@ class AirflowBaseView(BaseView): # noqa: D101
"""Base View to set Airflow related properties"""
from airflow import macros
+
route_base = ''
# Make our macros available to our UI templates too.
@@ -354,7 +382,7 @@ def render_template(self, *args, **kwargs):
*args,
# Cache this at most once per request, not for the lifetime of the view instance
scheduler_job=lazy_object_proxy.Proxy(SchedulerJob.most_recent_job),
- **kwargs
+ **kwargs,
)
@@ -367,9 +395,7 @@ def health(self):
An endpoint helping check the health status of the Airflow instance,
including metadatabase and scheduler.
"""
- payload = {
- 'metadatabase': {'status': 'unhealthy'}
- }
+ payload = {'metadatabase': {'status': 'unhealthy'}}
latest_scheduler_heartbeat = None
scheduler_status = 'unhealthy'
@@ -384,19 +410,22 @@ def health(self):
except Exception: # noqa pylint: disable=broad-except
payload['metadatabase']['status'] = 'unhealthy'
- payload['scheduler'] = {'status': scheduler_status,
- 'latest_scheduler_heartbeat': latest_scheduler_heartbeat}
+ payload['scheduler'] = {
+ 'status': scheduler_status,
+ 'latest_scheduler_heartbeat': latest_scheduler_heartbeat,
+ }
return wwwutils.json_response(payload)
@expose('/home')
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE),
- ]) # pylint: disable=too-many-locals,too-many-statements
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE),
+ ]
+ ) # pylint: disable=too-many-locals,too-many-statements
def index(self):
"""Home view."""
- hide_paused_dags_by_default = conf.getboolean('webserver',
- 'hide_paused_dags_by_default')
+ hide_paused_dags_by_default = conf.getboolean('webserver', 'hide_paused_dags_by_default')
default_dag_run = conf.getint('webserver', 'default_dag_run_display_number')
num_runs = request.args.get('num_runs')
@@ -448,15 +477,13 @@ def get_int_arg(value, default=0):
with create_session() as session:
# read orm_dags from the db
- dags_query = session.query(DagModel).filter(
- ~DagModel.is_subdag, DagModel.is_active
- )
+ dags_query = session.query(DagModel).filter(~DagModel.is_subdag, DagModel.is_active)
# pylint: disable=no-member
if arg_search_query:
dags_query = dags_query.filter(
- DagModel.dag_id.ilike('%' + arg_search_query + '%') | # noqa
- DagModel.owners.ilike('%' + arg_search_query + '%') # noqa
+ DagModel.dag_id.ilike('%' + arg_search_query + '%')
+ | DagModel.owners.ilike('%' + arg_search_query + '%') # noqa # noqa
)
if arg_tags_filter:
@@ -472,7 +499,8 @@ def get_int_arg(value, default=0):
is_paused_count = dict(
all_dags.with_entities(DagModel.is_paused, func.count(DagModel.dag_id))
- .group_by(DagModel.is_paused).all()
+ .group_by(DagModel.is_paused)
+ .all()
)
status_count_active = is_paused_count.get(False, 0)
status_count_paused = is_paused_count.get(True, 0)
@@ -487,8 +515,13 @@ def get_int_arg(value, default=0):
current_dags = all_dags
num_of_all_dags = all_dags_count
- dags = current_dags.order_by(DagModel.dag_id).options(
- joinedload(DagModel.tags)).offset(start).limit(dags_per_page).all()
+ dags = (
+ current_dags.order_by(DagModel.dag_id)
+ .options(joinedload(DagModel.tags))
+ .offset(start)
+ .limit(dags_per_page)
+ .all()
+ )
dagtags = session.query(DagTag.name).distinct(DagTag.name).all()
tags = [
@@ -499,17 +532,15 @@ def get_int_arg(value, default=0):
import_errors = session.query(errors.ImportError).all()
for import_error in import_errors:
- flash(
- "Broken DAG: [{ie.filename}] {ie.stacktrace}".format(ie=import_error),
- "dag_import_error")
+ flash("Broken DAG: [{ie.filename}] {ie.stacktrace}".format(ie=import_error), "dag_import_error")
from airflow.plugins_manager import import_errors as plugin_import_errors
+
for filename, stacktrace in plugin_import_errors.items():
flash(
- "Broken plugin: [{filename}] {stacktrace}".format(
- stacktrace=stacktrace,
- filename=filename),
- "error")
+ f"Broken plugin: [{filename}] {stacktrace}",
+ "error",
+ )
num_of_pages = int(math.ceil(num_of_all_dags / float(dags_per_page)))
@@ -526,10 +557,12 @@ def get_int_arg(value, default=0):
num_dag_from=min(start + 1, num_of_all_dags),
num_dag_to=min(end, num_of_all_dags),
num_of_all_dags=num_of_all_dags,
- paging=wwwutils.generate_pages(current_page,
- num_of_pages,
- search=escape(arg_search_query) if arg_search_query else None,
- status=arg_status_filter if arg_status_filter else None),
+ paging=wwwutils.generate_pages(
+ current_page,
+ num_of_pages,
+ search=escape(arg_search_query) if arg_search_query else None,
+ status=arg_status_filter if arg_status_filter else None,
+ ),
num_runs=num_runs,
tags=tags,
state_color=state_color_mapping,
@@ -537,7 +570,8 @@ def get_int_arg(value, default=0):
status_count_all=all_dags_count,
status_count_active=status_count_active,
status_count_paused=status_count_paused,
- tags_filter=arg_tags_filter)
+ tags_filter=arg_tags_filter,
+ )
@expose('/dag_stats', methods=['POST'])
@auth.has_access(
@@ -555,13 +589,12 @@ def dag_stats(self, session=None):
if permissions.RESOURCE_DAG in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
- dag_state_stats = session.query(dr.dag_id, dr.state, sqla.func.count(dr.state))\
- .group_by(dr.dag_id, dr.state)
+ dag_state_stats = session.query(dr.dag_id, dr.state, sqla.func.count(dr.state)).group_by(
+ dr.dag_id, dr.state
+ )
# Filter by post parameters
- selected_dag_ids = {
- unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id
- }
+ selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id}
if selected_dag_ids:
filter_dag_ids = selected_dag_ids.intersection(allowed_dag_ids)
@@ -584,10 +617,7 @@ def dag_stats(self, session=None):
payload[dag_id] = []
for state in State.dag_states:
count = data.get(dag_id, {}).get(state, 0)
- payload[dag_id].append({
- 'state': state,
- 'count': count
- })
+ payload[dag_id].append({'state': state, 'count': count})
return wwwutils.json_response(payload)
@@ -611,9 +641,7 @@ def task_stats(self, session=None):
allowed_dag_ids = {dag_id for dag_id, in session.query(models.DagModel.dag_id)}
# Filter by post parameters
- selected_dag_ids = {
- unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id
- }
+ selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id}
if selected_dag_ids:
filter_dag_ids = selected_dag_ids.intersection(allowed_dag_ids)
@@ -623,39 +651,41 @@ def task_stats(self, session=None):
# pylint: disable=comparison-with-callable
running_dag_run_query_result = (
session.query(DagRun.dag_id, DagRun.execution_date)
- .join(DagModel, DagModel.dag_id == DagRun.dag_id)
- .filter(DagRun.state == State.RUNNING, DagModel.is_active)
+ .join(DagModel, DagModel.dag_id == DagRun.dag_id)
+ .filter(DagRun.state == State.RUNNING, DagModel.is_active)
)
# pylint: enable=comparison-with-callable
# pylint: disable=no-member
if selected_dag_ids:
- running_dag_run_query_result = \
- running_dag_run_query_result.filter(DagRun.dag_id.in_(filter_dag_ids))
+ running_dag_run_query_result = running_dag_run_query_result.filter(
+ DagRun.dag_id.in_(filter_dag_ids)
+ )
# pylint: enable=no-member
running_dag_run_query_result = running_dag_run_query_result.subquery('running_dag_run')
# pylint: disable=no-member
# Select all task_instances from active dag_runs.
- running_task_instance_query_result = (
- session.query(TaskInstance.dag_id.label('dag_id'), TaskInstance.state.label('state'))
- .join(running_dag_run_query_result,
- and_(running_dag_run_query_result.c.dag_id == TaskInstance.dag_id,
- running_dag_run_query_result.c.execution_date == TaskInstance.execution_date))
+ running_task_instance_query_result = session.query(
+ TaskInstance.dag_id.label('dag_id'), TaskInstance.state.label('state')
+ ).join(
+ running_dag_run_query_result,
+ and_(
+ running_dag_run_query_result.c.dag_id == TaskInstance.dag_id,
+ running_dag_run_query_result.c.execution_date == TaskInstance.execution_date,
+ ),
)
if selected_dag_ids:
- running_task_instance_query_result = \
- running_task_instance_query_result.filter(TaskInstance.dag_id.in_(filter_dag_ids))
+ running_task_instance_query_result = running_task_instance_query_result.filter(
+ TaskInstance.dag_id.in_(filter_dag_ids)
+ )
# pylint: enable=no-member
if conf.getboolean('webserver', 'SHOW_RECENT_STATS_FOR_COMPLETED_RUNS', fallback=True):
# pylint: disable=comparison-with-callable
last_dag_run = (
- session.query(
- DagRun.dag_id,
- sqla.func.max(DagRun.execution_date).label('execution_date')
- )
+ session.query(DagRun.dag_id, sqla.func.max(DagRun.execution_date).label('execution_date'))
.join(DagModel, DagModel.dag_id == DagRun.dag_id)
.filter(DagRun.state != State.RUNNING, DagModel.is_active)
.group_by(DagRun.dag_id)
@@ -669,30 +699,33 @@ def task_stats(self, session=None):
# Select all task_instances from active dag_runs.
# If no dag_run is active, return task instances from most recent dag_run.
- last_task_instance_query_result = (
- session.query(TaskInstance.dag_id.label('dag_id'), TaskInstance.state.label('state'))
- .join(last_dag_run,
- and_(last_dag_run.c.dag_id == TaskInstance.dag_id,
- last_dag_run.c.execution_date == TaskInstance.execution_date))
+ last_task_instance_query_result = session.query(
+ TaskInstance.dag_id.label('dag_id'), TaskInstance.state.label('state')
+ ).join(
+ last_dag_run,
+ and_(
+ last_dag_run.c.dag_id == TaskInstance.dag_id,
+ last_dag_run.c.execution_date == TaskInstance.execution_date,
+ ),
)
# pylint: disable=no-member
if selected_dag_ids:
- last_task_instance_query_result = \
- last_task_instance_query_result.filter(TaskInstance.dag_id.in_(filter_dag_ids))
+ last_task_instance_query_result = last_task_instance_query_result.filter(
+ TaskInstance.dag_id.in_(filter_dag_ids)
+ )
# pylint: enable=no-member
final_task_instance_query_result = union_all(
- last_task_instance_query_result,
- running_task_instance_query_result).alias('final_ti')
+ last_task_instance_query_result, running_task_instance_query_result
+ ).alias('final_ti')
else:
final_task_instance_query_result = running_task_instance_query_result.subquery('final_ti')
- qry = (
- session.query(final_task_instance_query_result.c.dag_id,
- final_task_instance_query_result.c.state, sqla.func.count())
- .group_by(final_task_instance_query_result.c.dag_id,
- final_task_instance_query_result.c.state)
- )
+ qry = session.query(
+ final_task_instance_query_result.c.dag_id,
+ final_task_instance_query_result.c.state,
+ sqla.func.count(),
+ ).group_by(final_task_instance_query_result.c.dag_id, final_task_instance_query_result.c.state)
data = {}
for dag_id, state, count in qry:
@@ -705,10 +738,7 @@ def task_stats(self, session=None):
payload[dag_id] = []
for state in State.task_states:
count = data.get(dag_id, {}).get(state, 0)
- payload[dag_id].append({
- 'state': state,
- 'count': count
- })
+ payload[dag_id].append({'state': state, 'count': count})
return wwwutils.json_response(payload)
@expose('/last_dagruns', methods=['POST'])
@@ -727,9 +757,7 @@ def last_dagruns(self, session=None):
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
# Filter by post parameters
- selected_dag_ids = {
- unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id
- }
+ selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id}
if selected_dag_ids:
filter_dag_ids = selected_dag_ids.intersection(allowed_dag_ids)
@@ -753,7 +781,8 @@ def last_dagruns(self, session=None):
'dag_id': r.dag_id,
'execution_date': r.execution_date.isoformat(),
'start_date': r.start_date.isoformat(),
- } for r in query
+ }
+ for r in query
}
return wwwutils.json_response(resp)
@@ -775,22 +804,30 @@ def code(self, session=None):
dag_id = request.args.get('dag_id')
dag_orm = DagModel.get_dagmodel(dag_id, session=session)
code = DagCode.get_code_by_fileloc(dag_orm.fileloc)
- html_code = Markup(highlight(
- code, lexers.PythonLexer(), HtmlFormatter(linenos=True))) # pylint: disable=no-member
+ html_code = Markup(
+ highlight(
+ code, lexers.PythonLexer(), HtmlFormatter(linenos=True) # pylint: disable=no-member
+ )
+ )
except Exception as e: # pylint: disable=broad-except
all_errors += (
- "Exception encountered during " +
- f"dag_id retrieval/dag retrieval fallback/code highlighting:\n\n{e}\n"
+ "Exception encountered during "
+ + f"dag_id retrieval/dag retrieval fallback/code highlighting:\n\n{e}\n"
)
html_code = Markup('<p>Failed to load file.</p><p>Details: {}</p>').format( # noqa
- escape(all_errors))
+ escape(all_errors)
+ )
return self.render_template(
- 'airflow/dag_code.html', html_code=html_code, dag=dag_orm, title=dag_id,
+ 'airflow/dag_code.html',
+ html_code=html_code,
+ dag=dag_orm,
+ title=dag_id,
root=request.args.get('root'),
demo_mode=conf.getboolean('webserver', 'demo_mode'),
- wrapped=conf.getboolean('webserver', 'default_wrap'))
+ wrapped=conf.getboolean('webserver', 'default_wrap'),
+ )
@expose('/dag_details')
@auth.has_access(
@@ -809,19 +846,14 @@ def dag_details(self, session=None):
states = (
session.query(TaskInstance.state, sqla.func.count(TaskInstance.dag_id))
- .filter(TaskInstance.dag_id == dag_id)
- .group_by(TaskInstance.state)
- .all()
+ .filter(TaskInstance.dag_id == dag_id)
+ .group_by(TaskInstance.state)
+ .all()
)
- active_runs = models.DagRun.find(
- dag_id=dag_id,
- state=State.RUNNING,
- external_trigger=False
- )
+ active_runs = models.DagRun.find(dag_id=dag_id, state=State.RUNNING, external_trigger=False)
- tags = session.query(models.DagTag).filter(
- models.DagTag.dag_id == dag_id).all()
+ tags = session.query(models.DagTag).filter(models.DagTag.dag_id == dag_id).all()
return self.render_template(
'airflow/dag_details.html',
@@ -831,7 +863,7 @@ def dag_details(self, session=None):
states=states,
State=State,
active_runs=active_runs,
- tags=tags
+ tags=tags,
)
@expose('/rendered')
@@ -877,8 +909,9 @@ def rendered(self):
content = json.dumps(content, sort_keys=True, indent=4)
html_dict[template_field] = renderers[renderer](content)
else:
- html_dict[template_field] = \
- Markup("{}
").format(pformat(content)) # noqa
+ html_dict[template_field] = Markup("{}
").format(
+ pformat(content)
+ ) # noqa
return self.render_template(
'airflow/ti_code.html',
@@ -888,14 +921,17 @@ def rendered(self):
execution_date=execution_date,
form=form,
root=root,
- title=title)
+ title=title,
+ )
@expose('/get_logs_with_metadata')
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG),
+ ]
+ )
@action_logging
@provide_session
def get_logs_with_metadata(self, session=None):
@@ -921,8 +957,8 @@ def get_logs_with_metadata(self, session=None):
except ValueError:
error_message = (
'Given execution date, {}, could not be identified '
- 'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format(
- execution_date))
+ 'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format(execution_date)
+ )
response = jsonify({'error': error_message})
response.status_code = 400
@@ -933,23 +969,24 @@ def get_logs_with_metadata(self, session=None):
return jsonify(
message="Task log handler does not support read logs.",
error=True,
- metadata={
- "end_of_log": True
- }
+ metadata={"end_of_log": True},
)
- ti = session.query(models.TaskInstance).filter(
- models.TaskInstance.dag_id == dag_id,
- models.TaskInstance.task_id == task_id,
- models.TaskInstance.execution_date == execution_date).first()
+ ti = (
+ session.query(models.TaskInstance)
+ .filter(
+ models.TaskInstance.dag_id == dag_id,
+ models.TaskInstance.task_id == task_id,
+ models.TaskInstance.execution_date == execution_date,
+ )
+ .first()
+ )
if ti is None:
return jsonify(
message="*** Task instance did not exist in the DB\n",
error=True,
- metadata={
- "end_of_log": True
- }
+ metadata={"end_of_log": True},
)
try:
@@ -968,13 +1005,10 @@ def get_logs_with_metadata(self, session=None):
return Response(
response=log_stream,
mimetype="text/plain",
- headers={
- "Content-Disposition": f"attachment; filename={attachment_filename}"
- })
+ headers={"Content-Disposition": f"attachment; filename={attachment_filename}"},
+ )
except AttributeError as e:
- error_message = [
- f"Task log handler does not support read logs.\n{str(e)}\n"
- ]
+ error_message = [f"Task log handler does not support read logs.\n{str(e)}\n"]
metadata['end_of_log'] = True
return jsonify(message=error_message, error=True, metadata=metadata)
@@ -997,10 +1031,15 @@ def log(self, session=None):
form = DateTimeForm(data={'execution_date': dttm})
dag_model = DagModel.get_dagmodel(dag_id)
- ti = session.query(models.TaskInstance).filter(
- models.TaskInstance.dag_id == dag_id,
- models.TaskInstance.task_id == task_id,
- models.TaskInstance.execution_date == dttm).first()
+ ti = (
+ session.query(models.TaskInstance)
+ .filter(
+ models.TaskInstance.dag_id == dag_id,
+ models.TaskInstance.task_id == task_id,
+ models.TaskInstance.execution_date == dttm,
+ )
+ .first()
+ )
num_logs = 0
if ti is not None:
@@ -1012,10 +1051,16 @@ def log(self, session=None):
root = request.args.get('root', '')
return self.render_template(
'airflow/ti_log.html',
- logs=logs, dag=dag_model, title="Log by attempts",
- dag_id=dag_id, task_id=task_id,
- execution_date=execution_date, form=form,
- root=root, wrapped=conf.getboolean('webserver', 'default_wrap'))
+ logs=logs,
+ dag=dag_model,
+ title="Log by attempts",
+ dag_id=dag_id,
+ task_id=task_id,
+ execution_date=execution_date,
+ form=form,
+ root=root,
+ wrapped=conf.getboolean('webserver', 'default_wrap'),
+ )
@expose('/redirect_to_external_log')
@auth.has_access(
@@ -1035,10 +1080,15 @@ def redirect_to_external_log(self, session=None):
dttm = timezone.parse(execution_date)
try_number = request.args.get('try_number', 1)
- ti = session.query(models.TaskInstance).filter(
- models.TaskInstance.dag_id == dag_id,
- models.TaskInstance.task_id == task_id,
- models.TaskInstance.execution_date == dttm).first()
+ ti = (
+ session.query(models.TaskInstance)
+ .filter(
+ models.TaskInstance.dag_id == dag_id,
+ models.TaskInstance.task_id == task_id,
+ models.TaskInstance.execution_date == dttm,
+ )
+ .first()
+ )
if not ti:
flash(f"Task [{dag_id}.{task_id}] does not exist", "error")
@@ -1074,10 +1124,7 @@ def task(self):
dag = current_app.dag_bag.get_dag(dag_id)
if not dag or task_id not in dag.task_ids:
- flash(
- "Task [{}.{}] doesn't seem to exist"
- " at the moment".format(dag_id, task_id),
- "error")
+ flash("Task [{}.{}] doesn't seem to exist" " at the moment".format(dag_id, task_id), "error")
return redirect(url_for('Airflow.index'))
task = copy.copy(dag.get_task(task_id))
task.resolve_template_files()
@@ -1096,8 +1143,7 @@ def task(self):
if not attr_name.startswith('_'):
attr = getattr(task, attr_name)
# pylint: disable=unidiomatic-typecheck
- if type(attr) != type(self.task) and \
- attr_name not in wwwutils.get_attr_renderer(): # noqa
+ if type(attr) != type(self.task) and attr_name not in wwwutils.get_attr_renderer(): # noqa
task_attrs.append((attr_name, str(attr)))
# pylint: enable=unidiomatic-typecheck
@@ -1106,24 +1152,29 @@ def task(self):
for attr_name in wwwutils.get_attr_renderer():
if hasattr(task, attr_name):
source = getattr(task, attr_name)
- special_attrs_rendered[attr_name] = \
- wwwutils.get_attr_renderer()[attr_name](source)
-
- no_failed_deps_result = [(
- "Unknown",
- "All dependencies are met but the task instance is not running. In most "
- "cases this just means that the task will probably be scheduled soon "
- "unless:
\n- The scheduler is down or under heavy load
\n{}\n"
- "
\nIf this task instance does not start soon please contact your "
- "Airflow administrator for assistance.".format(
- "- This task instance already ran and had it's state changed manually "
- "(e.g. cleared in the UI)
" if ti.state == State.NONE else ""))]
+ special_attrs_rendered[attr_name] = wwwutils.get_attr_renderer()[attr_name](source)
+
+ no_failed_deps_result = [
+ (
+ "Unknown",
+ "All dependencies are met but the task instance is not running. In most "
+ "cases this just means that the task will probably be scheduled soon "
+ "unless:
\n- The scheduler is down or under heavy load
\n{}\n"
+ "
\nIf this task instance does not start soon please contact your "
+ "Airflow administrator for assistance.".format(
+ "- This task instance already ran and had it's state changed manually "
+ "(e.g. cleared in the UI)
"
+ if ti.state == State.NONE
+ else ""
+ ),
+ )
+ ]
# Use the scheduler's context to figure out which dependencies are not met
dep_context = DepContext(SCHEDULER_QUEUED_DEPS)
- failed_dep_reasons = [(dep.dep_name, dep.reason) for dep in
- ti.get_failed_dep_statuses(
- dep_context=dep_context)]
+ failed_dep_reasons = [
+ (dep.dep_name, dep.reason) for dep in ti.get_failed_dep_statuses(dep_context=dep_context)
+ ]
title = "Task Instance Details"
return self.render_template(
@@ -1136,7 +1187,9 @@ def task(self):
special_attrs_rendered=special_attrs_rendered,
form=form,
root=root,
- dag=dag, title=title)
+ dag=dag,
+ title=title,
+ )
@expose('/xcom')
@auth.has_access(
@@ -1164,15 +1217,14 @@ def xcom(self, session=None):
ti = session.query(ti_db).filter(ti_db.dag_id == dag_id and ti_db.task_id == task_id).first()
if not ti:
- flash(
- "Task [{}.{}] doesn't seem to exist"
- " at the moment".format(dag_id, task_id),
- "error")
+ flash("Task [{}.{}] doesn't seem to exist" " at the moment".format(dag_id, task_id), "error")
return redirect(url_for('Airflow.index'))
- xcomlist = session.query(XCom).filter(
- XCom.dag_id == dag_id, XCom.task_id == task_id,
- XCom.execution_date == dttm).all()
+ xcomlist = (
+ session.query(XCom)
+ .filter(XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.execution_date == dttm)
+ .all()
+ )
attributes = []
for xcom in xcomlist:
@@ -1187,7 +1239,9 @@ def xcom(self, session=None):
execution_date=execution_date,
form=form,
root=root,
- dag=dag, title=title)
+ dag=dag,
+ title=title,
+ )
@expose('/run', methods=['POST'])
@auth.has_access(
@@ -1217,12 +1271,14 @@ def run(self):
try:
from airflow.executors.celery_executor import CeleryExecutor # noqa
+
valid_celery_config = isinstance(executor, CeleryExecutor)
except ImportError:
pass
try:
from airflow.executors.kubernetes_executor import KubernetesExecutor # noqa
+
valid_kubernetes_config = isinstance(executor, KubernetesExecutor)
except ImportError:
pass
@@ -1239,14 +1295,16 @@ def run(self):
deps=RUNNING_DEPS,
ignore_all_deps=ignore_all_deps,
ignore_task_deps=ignore_task_deps,
- ignore_ti_state=ignore_ti_state)
+ ignore_ti_state=ignore_ti_state,
+ )
failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context))
if failed_deps:
- failed_deps_str = ", ".join(
- [f"{dep.dep_name}: {dep.reason}" for dep in failed_deps])
- flash("Could not queue task instance for execution, dependencies not met: "
- "{}".format(failed_deps_str),
- "error")
+ failed_deps_str = ", ".join([f"{dep.dep_name}: {dep.reason}" for dep in failed_deps])
+ flash(
+ "Could not queue task instance for execution, dependencies not met: "
+ "{}".format(failed_deps_str),
+ "error",
+ )
return redirect(origin)
executor.start()
@@ -1254,11 +1312,10 @@ def run(self):
ti,
ignore_all_deps=ignore_all_deps,
ignore_task_deps=ignore_task_deps,
- ignore_ti_state=ignore_ti_state)
+ ignore_ti_state=ignore_ti_state,
+ )
executor.heartbeat()
- flash(
- "Sent {} to the message queue, "
- "it should start any moment now.".format(ti))
+ flash("Sent {} to the message queue, " "it should start any moment now.".format(ti))
return redirect(origin)
@expose('/delete', methods=['POST'])
@@ -1282,13 +1339,10 @@ def delete(self):
flash(f"DAG with id {dag_id} not found. Cannot delete", 'error')
return redirect(request.referrer)
except DagFileExists:
- flash("Dag id {} is still in DagBag. "
- "Remove the DAG file first.".format(dag_id),
- 'error')
+ flash("Dag id {} is still in DagBag. " "Remove the DAG file first.".format(dag_id), 'error')
return redirect(request.referrer)
- flash("Deleting DAG with id {}. May take a couple minutes to fully"
- " disappear.".format(dag_id))
+ flash("Deleting DAG with id {}. May take a couple minutes to fully" " disappear.".format(dag_id))
# Upon success return to origin.
return redirect(origin)
@@ -1320,10 +1374,7 @@ def trigger(self, session=None):
except TypeError:
flash("Could not pre-populate conf field due to non-JSON-serializable data-types")
return self.render_template(
- 'airflow/trigger.html',
- dag_id=dag_id,
- origin=origin,
- conf=default_conf
+ 'airflow/trigger.html', dag_id=dag_id, origin=origin, conf=default_conf
)
dag_orm = session.query(models.DagModel).filter(models.DagModel.dag_id == dag_id).first()
@@ -1345,10 +1396,7 @@ def trigger(self, session=None):
except json.decoder.JSONDecodeError:
flash("Invalid JSON configuration", "error")
return self.render_template(
- 'airflow/trigger.html',
- dag_id=dag_id,
- origin=origin,
- conf=request_conf
+ 'airflow/trigger.html', dag_id=dag_id, origin=origin, conf=request_conf
)
dag = current_app.dag_bag.get_dag(dag_id)
@@ -1361,13 +1409,12 @@ def trigger(self, session=None):
dag_hash=current_app.dag_bag.dags_hash.get(dag_id),
)
- flash(
- "Triggered {}, "
- "it should start any moment now.".format(dag_id))
+ flash("Triggered {}, " "it should start any moment now.".format(dag_id))
return redirect(origin)
- def _clear_dag_tis(self, dag, start_date, end_date, origin,
- recursive=False, confirmed=False, only_failed=False):
+ def _clear_dag_tis(
+ self, dag, start_date, end_date, origin, recursive=False, confirmed=False, only_failed=False
+ ):
if confirmed:
count = dag.clear(
start_date=start_date,
@@ -1401,9 +1448,9 @@ def _clear_dag_tis(self, dag, start_date, end_date, origin,
response = self.render_template(
'airflow/confirm.html',
- message=("Here's the list of task instances you are about "
- "to clear:"),
- details=details)
+ message=("Here's the list of task instances you are about " "to clear:"),
+ details=details,
+ )
return response
@@ -1435,19 +1482,29 @@ def clear(self):
dag = dag.sub_dag(
task_ids_or_regex=fr"^{task_id}$",
include_downstream=downstream,
- include_upstream=upstream)
+ include_upstream=upstream,
+ )
end_date = execution_date if not future else None
start_date = execution_date if not past else None
- return self._clear_dag_tis(dag, start_date, end_date, origin,
- recursive=recursive, confirmed=confirmed, only_failed=only_failed)
+ return self._clear_dag_tis(
+ dag,
+ start_date,
+ end_date,
+ origin,
+ recursive=recursive,
+ confirmed=confirmed,
+ only_failed=only_failed,
+ )
@expose('/dagrun_clear', methods=['POST'])
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_TASK_INSTANCE),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_TASK_INSTANCE),
+ ]
+ )
@action_logging
def dagrun_clear(self):
"""Clears the DagRun"""
@@ -1461,14 +1518,15 @@ def dagrun_clear(self):
start_date = execution_date
end_date = execution_date
- return self._clear_dag_tis(dag, start_date, end_date, origin,
- recursive=True, confirmed=confirmed)
+ return self._clear_dag_tis(dag, start_date, end_date, origin, recursive=True, confirmed=confirmed)
@expose('/blocked', methods=['POST'])
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
+ ]
+ )
@provide_session
def blocked(self, session=None):
"""Mark Dag Blocked."""
@@ -1478,9 +1536,7 @@ def blocked(self, session=None):
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
# Filter by post parameters
- selected_dag_ids = {
- unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id
- }
+ selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id}
if selected_dag_ids:
filter_dag_ids = selected_dag_ids.intersection(allowed_dag_ids)
@@ -1493,9 +1549,9 @@ def blocked(self, session=None):
# pylint: disable=comparison-with-callable
dags = (
session.query(DagRun.dag_id, sqla.func.count(DagRun.id))
- .filter(DagRun.state == State.RUNNING)
- .filter(DagRun.dag_id.in_(filter_dag_ids))
- .group_by(DagRun.dag_id)
+ .filter(DagRun.state == State.RUNNING)
+ .filter(DagRun.dag_id.in_(filter_dag_ids))
+ .group_by(DagRun.dag_id)
)
# pylint: enable=comparison-with-callable
@@ -1506,11 +1562,13 @@ def blocked(self, session=None):
if dag:
# TODO: Make max_active_runs a column so we can query for it directly
max_active_runs = dag.max_active_runs
- payload.append({
- 'dag_id': dag_id,
- 'active_dag_run': active_dag_runs,
- 'max_active_runs': max_active_runs,
- })
+ payload.append(
+ {
+ 'dag_id': dag_id,
+ 'active_dag_run': active_dag_runs,
+ 'max_active_runs': max_active_runs,
+ }
+ )
return wwwutils.json_response(payload)
def _mark_dagrun_state_as_failed(self, dag_id, execution_date, confirmed, origin):
@@ -1537,7 +1595,8 @@ def _mark_dagrun_state_as_failed(self, dag_id, execution_date, confirmed, origin
response = self.render_template(
'airflow/confirm.html',
message="Here's the list of task instances you are about to mark as failed",
- details=details)
+ details=details,
+ )
return response
@@ -1553,8 +1612,7 @@ def _mark_dagrun_state_as_success(self, dag_id, execution_date, confirmed, origi
flash(f'Cannot find DAG: {dag_id}', 'error')
return redirect(origin)
- new_dag_state = set_dag_run_state_to_success(dag, execution_date,
- commit=confirmed)
+ new_dag_state = set_dag_run_state_to_success(dag, execution_date, commit=confirmed)
if confirmed:
flash('Marked success on {} task instances'.format(len(new_dag_state)))
@@ -1566,15 +1624,18 @@ def _mark_dagrun_state_as_success(self, dag_id, execution_date, confirmed, origi
response = self.render_template(
'airflow/confirm.html',
message="Here's the list of task instances you are about to mark as success",
- details=details)
+ details=details,
+ )
return response
@expose('/dagrun_failed', methods=['POST'])
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
+ ]
+ )
@action_logging
def dagrun_failed(self):
"""Mark DagRun failed."""
@@ -1582,14 +1643,15 @@ def dagrun_failed(self):
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
origin = get_safe_url(request.form.get('origin'))
- return self._mark_dagrun_state_as_failed(dag_id, execution_date,
- confirmed, origin)
+ return self._mark_dagrun_state_as_failed(dag_id, execution_date, confirmed, origin)
@expose('/dagrun_success', methods=['POST'])
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
+ ]
+ )
@action_logging
def dagrun_success(self):
"""Mark DagRun success"""
@@ -1597,13 +1659,21 @@ def dagrun_success(self):
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
origin = get_safe_url(request.form.get('origin'))
- return self._mark_dagrun_state_as_success(dag_id, execution_date,
- confirmed, origin)
-
- def _mark_task_instance_state(self, # pylint: disable=too-many-arguments
- dag_id, task_id, origin, execution_date,
- confirmed, upstream, downstream,
- future, past, state):
+ return self._mark_dagrun_state_as_success(dag_id, execution_date, confirmed, origin)
+
+ def _mark_task_instance_state( # pylint: disable=too-many-arguments
+ self,
+ dag_id,
+ task_id,
+ origin,
+ execution_date,
+ confirmed,
+ upstream,
+ downstream,
+ future,
+ past,
+ state,
+ ):
dag = current_app.dag_bag.get_dag(dag_id)
task = dag.get_task(task_id)
task.dag = dag
@@ -1618,33 +1688,48 @@ def _mark_task_instance_state(self, # pylint: disable=too-many-arguments
from airflow.api.common.experimental.mark_tasks import set_state
if confirmed:
- altered = set_state(tasks=[task], execution_date=execution_date,
- upstream=upstream, downstream=downstream,
- future=future, past=past, state=state,
- commit=True)
+ altered = set_state(
+ tasks=[task],
+ execution_date=execution_date,
+ upstream=upstream,
+ downstream=downstream,
+ future=future,
+ past=past,
+ state=state,
+ commit=True,
+ )
flash("Marked {} on {} task instances".format(state, len(altered)))
return redirect(origin)
- to_be_altered = set_state(tasks=[task], execution_date=execution_date,
- upstream=upstream, downstream=downstream,
- future=future, past=past, state=state,
- commit=False)
+ to_be_altered = set_state(
+ tasks=[task],
+ execution_date=execution_date,
+ upstream=upstream,
+ downstream=downstream,
+ future=future,
+ past=past,
+ state=state,
+ commit=False,
+ )
details = "\n".join([str(t) for t in to_be_altered])
response = self.render_template(
"airflow/confirm.html",
message=(f"Here's the list of task instances you are about to mark as {state}:"),
- details=details)
+ details=details,
+ )
return response
@expose('/failed', methods=['POST'])
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE),
+ ]
+ )
@action_logging
def failed(self):
"""Mark task as failed."""
@@ -1659,15 +1744,26 @@ def failed(self):
future = request.form.get('failed_future') == "true"
past = request.form.get('failed_past') == "true"
- return self._mark_task_instance_state(dag_id, task_id, origin, execution_date,
- confirmed, upstream, downstream,
- future, past, State.FAILED)
+ return self._mark_task_instance_state(
+ dag_id,
+ task_id,
+ origin,
+ execution_date,
+ confirmed,
+ upstream,
+ downstream,
+ future,
+ past,
+ State.FAILED,
+ )
@expose('/success', methods=['POST'])
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE),
+ ]
+ )
@action_logging
def success(self):
"""Mark task as success."""
@@ -1682,16 +1778,27 @@ def success(self):
future = request.form.get('success_future') == "true"
past = request.form.get('success_past') == "true"
- return self._mark_task_instance_state(dag_id, task_id, origin, execution_date,
- confirmed, upstream, downstream,
- future, past, State.SUCCESS)
+ return self._mark_task_instance_state(
+ dag_id,
+ task_id,
+ origin,
+ execution_date,
+ confirmed,
+ upstream,
+ downstream,
+ future,
+ past,
+ State.SUCCESS,
+ )
@expose('/tree')
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG),
+ ]
+ )
@gzipped # pylint: disable=too-many-locals
@action_logging # pylint: disable=too-many-locals
def tree(self):
@@ -1705,10 +1812,7 @@ def tree(self):
root = request.args.get('root')
if root:
- dag = dag.sub_dag(
- task_ids_or_regex=root,
- include_downstream=False,
- include_upstream=True)
+ dag = dag.sub_dag(task_ids_or_regex=root, include_downstream=False, include_upstream=True)
base_date = request.args.get('base_date')
num_runs = request.args.get('num_runs')
@@ -1725,16 +1829,12 @@ def tree(self):
with create_session() as session:
dag_runs = (
session.query(DagRun)
- .filter(
- DagRun.dag_id == dag.dag_id,
- DagRun.execution_date <= base_date)
+ .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date <= base_date)
.order_by(DagRun.execution_date.desc())
.limit(num_runs)
.all()
)
- dag_runs = {
- dr.execution_date: alchemy_to_dict(dr) for dr in dag_runs
- }
+ dag_runs = {dr.execution_date: alchemy_to_dict(dr) for dr in dag_runs}
dates = sorted(list(dag_runs.keys()))
max_date = max(dates) if dates else None
@@ -1782,10 +1882,7 @@ def recurse_nodes(task, visited):
node = {
'name': task.task_id,
- 'instances': [
- encode_ti(task_instances.get((task_id, d)))
- for d in dates
- ],
+ 'instances': [encode_ti(task_instances.get((task_id, d))) for d in dates],
'num_dep': len(task.downstream_list),
'operator': task.task_type,
'retries': task.retries,
@@ -1795,8 +1892,10 @@ def recurse_nodes(task, visited):
if task.downstream_list:
children = [
- recurse_nodes(t, visited) for t in task.downstream_list
- if node_count < node_limit or t not in visited]
+ recurse_nodes(t, visited)
+ for t in task.downstream_list
+ if node_count < node_limit or t not in visited
+ ]
# D3 tree uses children vs _children to define what is
# expanded or not. The following block makes it such that
@@ -1823,14 +1922,10 @@ def recurse_nodes(task, visited):
data = {
'name': '[DAG]',
'children': [recurse_nodes(t, set()) for t in dag.roots],
- 'instances': [
- dag_runs.get(d) or {'execution_date': d.isoformat()}
- for d in dates
- ],
+ 'instances': [dag_runs.get(d) or {'execution_date': d.isoformat()} for d in dates],
}
- form = DateTimeWithNumRunsForm(data={'base_date': max_date,
- 'num_runs': num_runs})
+ form = DateTimeWithNumRunsForm(data={'base_date': max_date, 'num_runs': num_runs})
doc_md = wwwutils.wrapped_markdown(getattr(dag, 'doc_md', None), css_class='dag-doc')
@@ -1851,16 +1946,20 @@ def recurse_nodes(task, visited):
dag=dag,
doc_md=doc_md,
data=data,
- blur=blur, num_runs=num_runs,
+ blur=blur,
+ num_runs=num_runs,
show_external_log_redirect=task_log_reader.supports_external_link,
- external_log_name=external_log_name)
+ external_log_name=external_log_name,
+ )
@expose('/graph')
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG),
+ ]
+ )
@gzipped
@action_logging
@provide_session
@@ -1875,10 +1974,7 @@ def graph(self, session=None):
root = request.args.get('root')
if root:
- dag = dag.sub_dag(
- task_ids_or_regex=root,
- include_upstream=True,
- include_downstream=False)
+ dag = dag.sub_dag(task_ids_or_regex=root, include_upstream=True, include_downstream=False)
arrange = request.args.get('arrange', dag.orientation)
@@ -1892,26 +1988,28 @@ def graph(self, session=None):
class GraphForm(DateTimeWithNumRunsWithDagRunsForm):
"""Graph Form class."""
- arrange = SelectField("Layout", choices=(
- ('LR', "Left > Right"),
- ('RL', "Right > Left"),
- ('TB', "Top > Bottom"),
- ('BT', "Bottom > Top"),
- ))
+ arrange = SelectField(
+ "Layout",
+ choices=(
+ ('LR', "Left > Right"),
+ ('RL', "Right > Left"),
+ ('TB', "Top > Bottom"),
+ ('BT', "Bottom > Top"),
+ ),
+ )
form = GraphForm(data=dt_nr_dr_data)
form.execution_date.choices = dt_nr_dr_data['dr_choices']
- task_instances = {
- ti.task_id: alchemy_to_dict(ti)
- for ti in dag.get_task_instances(dttm, dttm)}
+ task_instances = {ti.task_id: alchemy_to_dict(ti) for ti in dag.get_task_instances(dttm, dttm)}
tasks = {
t.task_id: {
'dag_id': t.dag_id,
'task_type': t.task_type,
'extra_links': t.extra_links,
}
- for t in dag.tasks}
+ for t in dag.tasks
+ }
if not tasks:
flash("No tasks found", "error")
session.commit()
@@ -1942,13 +2040,16 @@ class GraphForm(DateTimeWithNumRunsWithDagRunsForm):
edges=edges,
show_external_log_redirect=task_log_reader.supports_external_link,
external_log_name=external_log_name,
- dag_run_state=dt_nr_dr_data['dr_state'])
+ dag_run_state=dt_nr_dr_data['dr_state'],
+ )
@expose('/duration')
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ ]
+ )
@action_logging # pylint: disable=too-many-locals
@provide_session # pylint: disable=too-many-locals
def duration(self, session=None):
@@ -1979,16 +2080,11 @@ def duration(self, session=None):
root = request.args.get('root')
if root:
- dag = dag.sub_dag(
- task_ids_or_regex=root,
- include_upstream=True,
- include_downstream=False)
+ dag = dag.sub_dag(task_ids_or_regex=root, include_upstream=True, include_downstream=False)
chart_height = wwwutils.get_chart_height(dag)
- chart = nvd3.lineChart(
- name="lineChart", x_is_date=True, height=chart_height, width="1200")
- cum_chart = nvd3.lineChart(
- name="cumLineChart", x_is_date=True, height=chart_height, width="1200")
+ chart = nvd3.lineChart(name="lineChart", x_is_date=True, height=chart_height, width="1200")
+ cum_chart = nvd3.lineChart(name="cumLineChart", x_is_date=True, height=chart_height, width="1200")
y_points = defaultdict(list)
x_points = defaultdict(list)
@@ -1997,17 +2093,22 @@ def duration(self, session=None):
task_instances = dag.get_task_instances(start_date=min_date, end_date=base_date)
ti_fails = (
session.query(TaskFail)
- .filter(TaskFail.dag_id == dag.dag_id,
- TaskFail.execution_date >= min_date,
- TaskFail.execution_date <= base_date,
- TaskFail.task_id.in_([t.task_id for t in dag.tasks]))
+ .filter(
+ TaskFail.dag_id == dag.dag_id,
+ TaskFail.execution_date >= min_date,
+ TaskFail.execution_date <= base_date,
+ TaskFail.task_id.in_([t.task_id for t in dag.tasks]),
+ )
.all()
)
fails_totals = defaultdict(int)
for failed_task_instance in ti_fails:
- dict_key = (failed_task_instance.dag_id, failed_task_instance.task_id,
- failed_task_instance.execution_date)
+ dict_key = (
+ failed_task_instance.dag_id,
+ failed_task_instance.task_id,
+ failed_task_instance.execution_date,
+ )
if failed_task_instance.duration:
fails_totals[dict_key] += failed_task_instance.duration
@@ -2025,34 +2126,38 @@ def duration(self, session=None):
y_unit = infer_time_unit([d for t in y_points.values() for d in t])
cum_y_unit = infer_time_unit([d for t in cumulative_y.values() for d in t])
# update the y Axis on both charts to have the correct time units
- chart.create_y_axis('yAxis', format='.02f', custom_format=False,
- label=f'Duration ({y_unit})')
+ chart.create_y_axis('yAxis', format='.02f', custom_format=False, label=f'Duration ({y_unit})')
chart.axislist['yAxis']['axisLabelDistance'] = '-15'
- cum_chart.create_y_axis('yAxis', format='.02f', custom_format=False,
- label=f'Duration ({cum_y_unit})')
+ cum_chart.create_y_axis('yAxis', format='.02f', custom_format=False, label=f'Duration ({cum_y_unit})')
cum_chart.axislist['yAxis']['axisLabelDistance'] = '-15'
for task in dag.tasks:
if x_points[task.task_id]:
- chart.add_serie(name=task.task_id, x=x_points[task.task_id],
- y=scale_time_units(y_points[task.task_id], y_unit))
- cum_chart.add_serie(name=task.task_id, x=x_points[task.task_id],
- y=scale_time_units(cumulative_y[task.task_id],
- cum_y_unit))
+ chart.add_serie(
+ name=task.task_id,
+ x=x_points[task.task_id],
+ y=scale_time_units(y_points[task.task_id], y_unit),
+ )
+ cum_chart.add_serie(
+ name=task.task_id,
+ x=x_points[task.task_id],
+ y=scale_time_units(cumulative_y[task.task_id], cum_y_unit),
+ )
dates = sorted(list({ti.execution_date for ti in task_instances}))
max_date = max([ti.execution_date for ti in task_instances]) if dates else None
session.commit()
- form = DateTimeWithNumRunsForm(data={'base_date': max_date,
- 'num_runs': num_runs})
+ form = DateTimeWithNumRunsForm(data={'base_date': max_date, 'num_runs': num_runs})
chart.buildcontent()
cum_chart.buildcontent()
s_index = cum_chart.htmlcontent.rfind('});')
- cum_chart.htmlcontent = (cum_chart.htmlcontent[:s_index] +
- "$( document ).trigger('chartload')" +
- cum_chart.htmlcontent[s_index:])
+ cum_chart.htmlcontent = (
+ cum_chart.htmlcontent[:s_index]
+ + "$( document ).trigger('chartload')"
+ + cum_chart.htmlcontent[s_index:]
+ )
return self.render_template(
'airflow/duration_chart.html',
@@ -2065,10 +2170,12 @@ def duration(self, session=None):
)
@expose('/tries')
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ ]
+ )
@action_logging
@provide_session
def tries(self, session=None):
@@ -2090,15 +2197,12 @@ def tries(self, session=None):
root = request.args.get('root')
if root:
- dag = dag.sub_dag(
- task_ids_or_regex=root,
- include_upstream=True,
- include_downstream=False)
+ dag = dag.sub_dag(task_ids_or_regex=root, include_upstream=True, include_downstream=False)
chart_height = wwwutils.get_chart_height(dag)
chart = nvd3.lineChart(
- name="lineChart", x_is_date=True, y_axis_format='d', height=chart_height,
- width="1200")
+ name="lineChart", x_is_date=True, y_axis_format='d', height=chart_height, width="1200"
+ )
for task in dag.tasks:
y_points = []
@@ -2119,8 +2223,7 @@ def tries(self, session=None):
session.commit()
- form = DateTimeWithNumRunsForm(data={'base_date': max_date,
- 'num_runs': num_runs})
+ form = DateTimeWithNumRunsForm(data={'base_date': max_date, 'num_runs': num_runs})
chart.buildcontent()
@@ -2135,10 +2238,12 @@ def tries(self, session=None):
)
@expose('/landing_times')
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ ]
+ )
@action_logging
@provide_session
def landing_times(self, session=None):
@@ -2160,14 +2265,10 @@ def landing_times(self, session=None):
root = request.args.get('root')
if root:
- dag = dag.sub_dag(
- task_ids_or_regex=root,
- include_upstream=True,
- include_downstream=False)
+ dag = dag.sub_dag(task_ids_or_regex=root, include_upstream=True, include_downstream=False)
chart_height = wwwutils.get_chart_height(dag)
- chart = nvd3.lineChart(
- name="lineChart", x_is_date=True, height=chart_height, width="1200")
+ chart = nvd3.lineChart(name="lineChart", x_is_date=True, height=chart_height, width="1200")
y_points = {}
x_points = {}
for task in dag.tasks:
@@ -2188,13 +2289,15 @@ def landing_times(self, session=None):
# for the DAG
y_unit = infer_time_unit([d for t in y_points.values() for d in t])
# update the y Axis to have the correct time units
- chart.create_y_axis('yAxis', format='.02f', custom_format=False,
- label=f'Landing Time ({y_unit})')
+ chart.create_y_axis('yAxis', format='.02f', custom_format=False, label=f'Landing Time ({y_unit})')
chart.axislist['yAxis']['axisLabelDistance'] = '-15'
for task in dag.tasks:
if x_points[task.task_id]:
- chart.add_serie(name=task.task_id, x=x_points[task.task_id],
- y=scale_time_units(y_points[task.task_id], y_unit))
+ chart.add_serie(
+ name=task.task_id,
+ x=x_points[task.task_id],
+ y=scale_time_units(y_points[task.task_id], y_unit),
+ )
tis = dag.get_task_instances(start_date=min_date, end_date=base_date)
dates = sorted(list({ti.execution_date for ti in tis}))
@@ -2202,8 +2305,7 @@ def landing_times(self, session=None):
session.commit()
- form = DateTimeWithNumRunsForm(data={'base_date': max_date,
- 'num_runs': num_runs})
+ form = DateTimeWithNumRunsForm(data={'base_date': max_date, 'num_runs': num_runs})
chart.buildcontent()
return self.render_template(
'airflow/chart.html',
@@ -2217,29 +2319,31 @@ def landing_times(self, session=None):
)
@expose('/paused', methods=['POST'])
- @auth.has_access([
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
+ ]
+ )
@action_logging
def paused(self):
"""Toggle paused."""
dag_id = request.args.get('dag_id')
is_paused = request.args.get('is_paused') == 'false'
- models.DagModel.get_dagmodel(dag_id).set_is_paused(
- is_paused=is_paused)
+ models.DagModel.get_dagmodel(dag_id).set_is_paused(is_paused=is_paused)
return "OK"
@expose('/refresh', methods=['POST'])
- @auth.has_access([
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
+ ]
+ )
@action_logging
@provide_session
def refresh(self, session=None):
"""Refresh DAG."""
dag_id = request.values.get('dag_id')
- orm_dag = session.query(
- DagModel).filter(DagModel.dag_id == dag_id).first()
+ orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
if orm_dag:
orm_dag.last_expired = timezone.utcnow()
@@ -2254,9 +2358,11 @@ def refresh(self, session=None):
return redirect(request.referrer)
@expose('/refresh_all', methods=['POST'])
- @auth.has_access([
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
+ ]
+ )
@action_logging
def refresh_all(self):
"""Refresh everything"""
@@ -2269,10 +2375,12 @@ def refresh_all(self):
return redirect(url_for('Airflow.index'))
@expose('/gantt')
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ ]
+ )
@action_logging
@provide_session
def gantt(self, session=None):
@@ -2283,10 +2391,7 @@ def gantt(self, session=None):
root = request.args.get('root')
if root:
- dag = dag.sub_dag(
- task_ids_or_regex=root,
- include_upstream=True,
- include_downstream=False)
+ dag = dag.sub_dag(task_ids_or_regex=root, include_upstream=True, include_downstream=False)
dt_nr_dr_data = get_date_time_num_runs_dag_runs_form_data(request, session, dag)
dttm = dt_nr_dr_data['dttm']
@@ -2294,18 +2399,24 @@ def gantt(self, session=None):
form = DateTimeWithNumRunsWithDagRunsForm(data=dt_nr_dr_data)
form.execution_date.choices = dt_nr_dr_data['dr_choices']
- tis = [
- ti for ti in dag.get_task_instances(dttm, dttm)
- if ti.start_date and ti.state]
+ tis = [ti for ti in dag.get_task_instances(dttm, dttm) if ti.start_date and ti.state]
tis = sorted(tis, key=lambda ti: ti.start_date)
- ti_fails = list(itertools.chain(*[(
- session
- .query(TaskFail)
- .filter(TaskFail.dag_id == ti.dag_id,
- TaskFail.task_id == ti.task_id,
- TaskFail.execution_date == ti.execution_date)
- .all()
- ) for ti in tis]))
+ ti_fails = list(
+ itertools.chain(
+ *[
+ (
+ session.query(TaskFail)
+ .filter(
+ TaskFail.dag_id == ti.dag_id,
+ TaskFail.task_id == ti.task_id,
+ TaskFail.execution_date == ti.execution_date,
+ )
+ .all()
+ )
+ for ti in tis
+ ]
+ )
+ )
# determine bars to show in the gantt chart
gantt_bar_items = []
@@ -2333,8 +2444,9 @@ def gantt(self, session=None):
else:
try_count = 1
prev_task_id = failed_task_instance.task_id
- gantt_bar_items.append((failed_task_instance.task_id, start_date, end_date, State.FAILED,
- try_count))
+ gantt_bar_items.append(
+ (failed_task_instance.task_id, start_date, end_date, State.FAILED, try_count)
+ )
tf_count += 1
task = dag.get_task(failed_task_instance.task_id)
task_dict = alchemy_to_dict(failed_task_instance)
@@ -2364,10 +2476,12 @@ def gantt(self, session=None):
)
@expose('/extra_links')
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ ]
+ )
@action_logging
def extra_links(self):
"""
@@ -2397,11 +2511,10 @@ def extra_links(self):
if not dag or task_id not in dag.task_ids:
response = jsonify(
- {'url': None,
- 'error': "can't find dag {dag} or task_id {task_id}".format(
- dag=dag,
- task_id=task_id
- )}
+ {
+ 'url': None,
+ 'error': f"can't find dag {dag} or task_id {task_id}",
+ }
)
response.status_code = 404
return response
@@ -2418,16 +2531,17 @@ def extra_links(self):
response.status_code = 200
return response
else:
- response = jsonify(
- {'url': None, 'error': f'No URL found for {link_name}'})
+ response = jsonify({'url': None, 'error': f'No URL found for {link_name}'})
response.status_code = 404
return response
@expose('/object/task_instances')
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ ]
+ )
@action_logging
def task_instances(self):
"""Shows task instances."""
@@ -2440,9 +2554,7 @@ def task_instances(self):
else:
return "Error: Invalid execution_date"
- task_instances = {
- ti.task_id: alchemy_to_dict(ti)
- for ti in dag.get_task_instances(dttm, dttm)}
+ task_instances = {ti.task_id: alchemy_to_dict(ti) for ti in dag.get_task_instances(dttm, dttm)}
return json.dumps(task_instances, cls=utils_json.AirflowJsonEncoder)
@@ -2453,12 +2565,17 @@ class ConfigurationView(AirflowBaseView):
default_view = 'conf'
class_permission_name = permissions.RESOURCE_CONFIG
- base_permissions = [permissions.ACTION_CAN_READ, permissions.ACTION_CAN_ACCESS_MENU,]
+ base_permissions = [
+ permissions.ACTION_CAN_READ,
+ permissions.ACTION_CAN_ACCESS_MENU,
+ ]
@expose('/configuration')
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG),
+ ]
+ )
def conf(self):
"""Shows configuration."""
raw = request.args.get('raw') == "true"
@@ -2468,31 +2585,36 @@ def conf(self):
if conf.getboolean("webserver", "expose_config"):
with open(AIRFLOW_CONFIG) as file:
config = file.read()
- table = [(section, key, value, source)
- for section, parameters in conf.as_dict(True, True).items()
- for key, (value, source) in parameters.items()]
+ table = [
+ (section, key, value, source)
+ for section, parameters in conf.as_dict(True, True).items()
+ for key, (value, source) in parameters.items()
+ ]
else:
config = (
"# Your Airflow administrator chose not to expose the "
- "configuration, most likely for security reasons.")
+ "configuration, most likely for security reasons."
+ )
table = None
if raw:
- return Response(
- response=config,
- status=200,
- mimetype="application/text")
+ return Response(response=config, status=200, mimetype="application/text")
else:
- code_html = Markup(highlight(
- config,
- lexers.IniLexer(), # Lexer call pylint: disable=no-member
- HtmlFormatter(noclasses=True))
+ code_html = Markup(
+ highlight(
+ config,
+ lexers.IniLexer(), # Lexer call pylint: disable=no-member
+ HtmlFormatter(noclasses=True),
+ )
)
return self.render_template(
'airflow/config.html',
pre_subtitle=settings.HEADER + " v" + airflow.__version__,
- code_html=code_html, title=title, subtitle=subtitle,
- table=table)
+ code_html=code_html,
+ title=title,
+ subtitle=subtitle,
+ table=table,
+ )
class RedocView(AirflowBaseView):
@@ -2511,10 +2633,11 @@ def redoc(self):
# ModelViews
######################################################################################
+
class DagFilter(BaseFilter):
"""Filter using DagIDs"""
- def apply(self, query, func): # noqa pylint: disable=redefined-outer-name,unused-argument
+ def apply(self, query, func): # noqa pylint: disable=redefined-outer-name,unused-argument
if current_app.appbuilder.sm.has_all_dags_access():
return query
filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
@@ -2595,8 +2718,7 @@ class XComModelView(AirflowModelView):
'dag_id': wwwutils.dag_link,
}
- @action('muldelete', 'Delete', "Are you sure you want to delete selected records?",
- single=False)
+ @action('muldelete', 'Delete', "Are you sure you want to delete selected records?", single=False)
def action_muldelete(self, items):
"""Multiple delete action."""
self.datamodel.delete_all(items)
@@ -2638,38 +2760,49 @@ class ConnectionModelView(AirflowModelView):
permissions.ACTION_CAN_ACCESS_MENU,
]
- extra_fields = ['extra__jdbc__drv_path', 'extra__jdbc__drv_clsname',
- 'extra__google_cloud_platform__project',
- 'extra__google_cloud_platform__key_path',
- 'extra__google_cloud_platform__keyfile_dict',
- 'extra__google_cloud_platform__scope',
- 'extra__google_cloud_platform__num_retries',
- 'extra__grpc__auth_type',
- 'extra__grpc__credential_pem_file',
- 'extra__grpc__scopes',
- 'extra__yandexcloud__service_account_json',
- 'extra__yandexcloud__service_account_json_path',
- 'extra__yandexcloud__oauth',
- 'extra__yandexcloud__public_ssh_key',
- 'extra__yandexcloud__folder_id',
- 'extra__kubernetes__in_cluster',
- 'extra__kubernetes__kube_config',
- 'extra__kubernetes__namespace']
- list_columns = ['conn_id', 'conn_type', 'host', 'port', 'is_encrypted',
- 'is_extra_encrypted']
- add_columns = edit_columns = ['conn_id', 'conn_type', 'host', 'schema',
- 'login', 'password', 'port', 'extra'] + extra_fields
+ extra_fields = [
+ 'extra__jdbc__drv_path',
+ 'extra__jdbc__drv_clsname',
+ 'extra__google_cloud_platform__project',
+ 'extra__google_cloud_platform__key_path',
+ 'extra__google_cloud_platform__keyfile_dict',
+ 'extra__google_cloud_platform__scope',
+ 'extra__google_cloud_platform__num_retries',
+ 'extra__grpc__auth_type',
+ 'extra__grpc__credential_pem_file',
+ 'extra__grpc__scopes',
+ 'extra__yandexcloud__service_account_json',
+ 'extra__yandexcloud__service_account_json_path',
+ 'extra__yandexcloud__oauth',
+ 'extra__yandexcloud__public_ssh_key',
+ 'extra__yandexcloud__folder_id',
+ 'extra__kubernetes__in_cluster',
+ 'extra__kubernetes__kube_config',
+ 'extra__kubernetes__namespace',
+ ]
+ list_columns = ['conn_id', 'conn_type', 'host', 'port', 'is_encrypted', 'is_extra_encrypted']
+ add_columns = edit_columns = [
+ 'conn_id',
+ 'conn_type',
+ 'host',
+ 'schema',
+ 'login',
+ 'password',
+ 'port',
+ 'extra',
+ ] + extra_fields
add_form = edit_form = ConnectionForm
add_template = 'airflow/conn_create.html'
edit_template = 'airflow/conn_edit.html'
base_order = ('conn_id', 'asc')
- @action('muldelete', 'Delete', 'Are you sure you want to delete selected records?',
- single=False)
- @auth.has_access([
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- ])
+ @action('muldelete', 'Delete', 'Are you sure you want to delete selected records?', single=False)
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
+ ]
+ )
def action_muldelete(self, items):
"""Multiple delete."""
self.datamodel.delete_all(items)
@@ -2680,9 +2813,7 @@ def process_form(self, form, is_created):
"""Process form data."""
formdata = form.data
if formdata['conn_type'] in ['jdbc', 'google_cloud_platform', 'grpc', 'yandexcloud', 'kubernetes']:
- extra = {
- key: formdata[key]
- for key in self.extra_fields if key in formdata}
+ extra = {key: formdata[key] for key in self.extra_fields if key in formdata}
form.extra.data = json.dumps(extra)
def prefill_form(self, form, pk):
@@ -2795,8 +2926,7 @@ class PoolModelView(AirflowModelView):
base_order = ('pool', 'asc')
- @action('muldelete', 'Delete', 'Are you sure you want to delete selected records?',
- single=False)
+ @action('muldelete', 'Delete', 'Are you sure you want to delete selected records?', single=False)
def action_muldelete(self, items):
"""Multiple delete."""
if any(item.pool == models.Pool.DEFAULT_POOL_NAME for item in items):
@@ -2823,8 +2953,8 @@ def frunning_slots(self):
if pool_id is not None and running_slots is not None:
url = url_for('TaskInstanceModelView.list', _flt_3_pool=pool_id, _flt_3_state='running')
return Markup("{running_slots}").format( # noqa
- url=url,
- running_slots=running_slots)
+ url=url, running_slots=running_slots
+ )
else:
return Markup('Invalid')
@@ -2835,20 +2965,14 @@ def fqueued_slots(self):
if pool_id is not None and queued_slots is not None:
url = url_for('TaskInstanceModelView.list', _flt_3_pool=pool_id, _flt_3_state='queued')
return Markup("{queued_slots}").format( # noqa
- url=url, queued_slots=queued_slots)
+ url=url, queued_slots=queued_slots
+ )
else:
return Markup('Invalid')
- formatters_columns = {
- 'pool': pool_link,
- 'running_slots': frunning_slots,
- 'queued_slots': fqueued_slots
- }
+ formatters_columns = {'pool': pool_link, 'running_slots': frunning_slots, 'queued_slots': fqueued_slots}
- validators_columns = {
- 'pool': [validators.DataRequired()],
- 'slots': [validators.NumberRange(min=-1)]
- }
+ validators_columns = {'pool': [validators.DataRequired()], 'slots': [validators.NumberRange(min=-1)]}
class VariableModelView(AirflowModelView):
@@ -2901,16 +3025,13 @@ def hidden_field_formatter(self):
'val': hidden_field_formatter,
}
- validators_columns = {
- 'key': [validators.DataRequired()]
- }
+ validators_columns = {'key': [validators.DataRequired()]}
def prefill_form(self, form, request_id): # pylint: disable=unused-argument
if wwwutils.should_hide_value_for_key(form.key.data):
form.val.data = '*' * 8
- @action('muldelete', 'Delete', 'Are you sure you want to delete selected records?',
- single=False)
+ @action('muldelete', 'Delete', 'Are you sure you want to delete selected records?', single=False)
def action_muldelete(self, items):
"""Multiple delete."""
self.datamodel.delete_all(items)
@@ -2976,14 +3097,35 @@ class JobModelView(AirflowModelView):
method_permission_name = {
'list': 'read',
}
- base_permissions = [permissions.ACTION_CAN_READ, permissions.ACTION_CAN_ACCESS_MENU,]
+ base_permissions = [
+ permissions.ACTION_CAN_READ,
+ permissions.ACTION_CAN_ACCESS_MENU,
+ ]
- list_columns = ['id', 'dag_id', 'state', 'job_type', 'start_date',
- 'end_date', 'latest_heartbeat',
- 'executor_class', 'hostname', 'unixname']
- search_columns = ['id', 'dag_id', 'state', 'job_type', 'start_date',
- 'end_date', 'latest_heartbeat', 'executor_class',
- 'hostname', 'unixname']
+ list_columns = [
+ 'id',
+ 'dag_id',
+ 'state',
+ 'job_type',
+ 'start_date',
+ 'end_date',
+ 'latest_heartbeat',
+ 'executor_class',
+ 'hostname',
+ 'unixname',
+ ]
+ search_columns = [
+ 'id',
+ 'dag_id',
+ 'state',
+ 'job_type',
+ 'start_date',
+ 'end_date',
+ 'latest_heartbeat',
+ 'executor_class',
+ 'hostname',
+ 'unixname',
+ ]
base_order = ('start_date', 'desc')
@@ -3041,8 +3183,7 @@ class DagRunModelView(AirflowModelView):
'conf': wwwutils.json_f('conf'),
}
- @action('muldelete', "Delete", "Are you sure you want to delete selected records?",
- single=False)
+ @action('muldelete', "Delete", "Are you sure you want to delete selected records?", single=False)
@provide_session
def action_muldelete(self, items, session=None): # noqa # pylint: disable=unused-argument
"""Multiple delete."""
@@ -3060,8 +3201,9 @@ def action_set_running(self, drs, session=None):
try:
count = 0
dirty_ids = []
- for dr in session.query(DagRun).filter(
- DagRun.id.in_([dagrun.id for dagrun in drs])).all(): # noqa pylint: disable=no-member
+ for dr in (
+ session.query(DagRun).filter(DagRun.id.in_([dagrun.id for dagrun in drs])).all()
+ ): # noqa pylint: disable=no-member
dirty_ids.append(dr.dag_id)
count += 1
dr.start_date = timezone.utcnow()
@@ -3073,9 +3215,12 @@ def action_set_running(self, drs, session=None):
flash('Failed to set state', 'error')
return redirect(self.get_default_url())
- @action('set_failed', "Set state to 'failed'",
- "All running task instances would also be marked as failed, are you sure?",
- single=False)
+ @action(
+ 'set_failed',
+ "Set state to 'failed'",
+ "All running task instances would also be marked as failed, are you sure?",
+ single=False,
+ )
@provide_session
def action_set_failed(self, drs, session=None):
"""Set state to failed."""
@@ -3083,26 +3228,29 @@ def action_set_failed(self, drs, session=None):
count = 0
dirty_ids = []
altered_tis = []
- for dr in session.query(DagRun).filter(
- DagRun.id.in_([dagrun.id for dagrun in drs])).all(): # noqa pylint: disable=no-member
+ for dr in (
+ session.query(DagRun).filter(DagRun.id.in_([dagrun.id for dagrun in drs])).all()
+ ): # noqa pylint: disable=no-member
dirty_ids.append(dr.dag_id)
count += 1
- altered_tis += \
- set_dag_run_state_to_failed(current_app.dag_bag.get_dag(dr.dag_id),
- dr.execution_date,
- commit=True,
- session=session)
+ altered_tis += set_dag_run_state_to_failed(
+ current_app.dag_bag.get_dag(dr.dag_id), dr.execution_date, commit=True, session=session
+ )
altered_ti_count = len(altered_tis)
flash(
"{count} dag runs and {altered_ti_count} task instances "
- "were set to failed".format(count=count, altered_ti_count=altered_ti_count))
+ "were set to failed".format(count=count, altered_ti_count=altered_ti_count)
+ )
except Exception: # noqa pylint: disable=broad-except
flash('Failed to set state', 'error')
return redirect(self.get_default_url())
- @action('set_success', "Set state to 'success'",
- "All task instances would also be marked as success, are you sure?",
- single=False)
+ @action(
+ 'set_success',
+ "Set state to 'success'",
+ "All task instances would also be marked as success, are you sure?",
+ single=False,
+ )
@provide_session
def action_set_success(self, drs, session=None):
"""Set state to success."""
@@ -3110,26 +3258,24 @@ def action_set_success(self, drs, session=None):
count = 0
dirty_ids = []
altered_tis = []
- for dr in session.query(DagRun).filter(
- DagRun.id.in_([dagrun.id for dagrun in drs])).all(): # noqa pylint: disable=no-member
+ for dr in (
+ session.query(DagRun).filter(DagRun.id.in_([dagrun.id for dagrun in drs])).all()
+ ): # noqa pylint: disable=no-member
dirty_ids.append(dr.dag_id)
count += 1
- altered_tis += \
- set_dag_run_state_to_success(current_app.dag_bag.get_dag(dr.dag_id),
- dr.execution_date,
- commit=True,
- session=session)
+ altered_tis += set_dag_run_state_to_success(
+ current_app.dag_bag.get_dag(dr.dag_id), dr.execution_date, commit=True, session=session
+ )
altered_ti_count = len(altered_tis)
flash(
"{count} dag runs and {altered_ti_count} task instances "
- "were set to success".format(count=count, altered_ti_count=altered_ti_count))
+ "were set to success".format(count=count, altered_ti_count=altered_ti_count)
+ )
except Exception: # noqa pylint: disable=broad-except
flash('Failed to set state', 'error')
return redirect(self.get_default_url())
- @action('clear', "Clear the state",
- "All task instances would be cleared, are you sure?",
- single=False)
+ @action('clear', "Clear the state", "All task instances would be cleared, are you sure?", single=False)
@provide_session
def action_clear(self, drs, session=None):
"""Clears the state."""
@@ -3147,8 +3293,10 @@ def action_clear(self, drs, session=None):
cleared_ti_count += len(tis)
models.clear_task_instances(tis, session, dag=dag)
- flash("{count} dag runs and {altered_ti_count} task instances "
- "were cleared".format(count=count, altered_ti_count=cleared_ti_count))
+ flash(
+ "{count} dag runs and {altered_ti_count} task instances "
+ "were cleared".format(count=count, altered_ti_count=cleared_ti_count)
+ )
except Exception: # noqa pylint: disable=broad-except
flash('Failed to clear state', 'error')
return redirect(self.get_default_url())
@@ -3165,10 +3313,12 @@ class LogModelView(AirflowModelView):
method_permission_name = {
'list': 'read',
}
- base_permissions = [permissions.ACTION_CAN_READ, permissions.ACTION_CAN_ACCESS_MENU,]
+ base_permissions = [
+ permissions.ACTION_CAN_READ,
+ permissions.ACTION_CAN_ACCESS_MENU,
+ ]
- list_columns = ['id', 'dttm', 'dag_id', 'task_id', 'event', 'execution_date',
- 'owner', 'extra']
+ list_columns = ['id', 'dttm', 'dag_id', 'task_id', 'event', 'execution_date', 'owner', 'extra']
search_columns = ['dag_id', 'task_id', 'event', 'execution_date', 'owner', 'extra']
base_order = ('dttm', 'desc')
@@ -3194,13 +3344,24 @@ class TaskRescheduleModelView(AirflowModelView):
'list': 'read',
}
- base_permissions = [permissions.ACTION_CAN_READ, permissions.ACTION_CAN_ACCESS_MENU,]
+ base_permissions = [
+ permissions.ACTION_CAN_READ,
+ permissions.ACTION_CAN_ACCESS_MENU,
+ ]
- list_columns = ['id', 'dag_id', 'task_id', 'execution_date', 'try_number',
- 'start_date', 'end_date', 'duration', 'reschedule_date']
+ list_columns = [
+ 'id',
+ 'dag_id',
+ 'task_id',
+ 'execution_date',
+ 'try_number',
+ 'start_date',
+ 'end_date',
+ 'duration',
+ 'reschedule_date',
+ ]
- search_columns = ['dag_id', 'task_id', 'execution_date', 'start_date', 'end_date',
- 'reschedule_date']
+ search_columns = ['dag_id', 'task_id', 'execution_date', 'start_date', 'end_date', 'reschedule_date']
base_order = ('id', 'desc')
@@ -3251,15 +3412,40 @@ class TaskInstanceModelView(AirflowModelView):
page_size = PAGE_SIZE
- list_columns = ['state', 'dag_id', 'task_id', 'execution_date', 'operator',
- 'start_date', 'end_date', 'duration', 'job_id', 'hostname',
- 'unixname', 'priority_weight', 'queue', 'queued_dttm', 'try_number',
- 'pool', 'log_url']
+ list_columns = [
+ 'state',
+ 'dag_id',
+ 'task_id',
+ 'execution_date',
+ 'operator',
+ 'start_date',
+ 'end_date',
+ 'duration',
+ 'job_id',
+ 'hostname',
+ 'unixname',
+ 'priority_weight',
+ 'queue',
+ 'queued_dttm',
+ 'try_number',
+ 'pool',
+ 'log_url',
+ ]
order_columns = [item for item in list_columns if item not in ['try_number', 'log_url']]
- search_columns = ['state', 'dag_id', 'task_id', 'execution_date', 'hostname',
- 'queue', 'pool', 'operator', 'start_date', 'end_date']
+ search_columns = [
+ 'state',
+ 'dag_id',
+ 'task_id',
+ 'execution_date',
+ 'hostname',
+ 'queue',
+ 'pool',
+ 'operator',
+ 'start_date',
+ 'end_date',
+ ]
base_order = ('job_id', 'asc')
@@ -3294,10 +3480,15 @@ def duration_f(self):
}
@provide_session
- @action('clear', lazy_gettext('Clear'),
- lazy_gettext('Are you sure you want to clear the state of the selected task'
- ' instance(s) and set their dagruns to the running state?'),
- single=False)
+ @action(
+ 'clear',
+ lazy_gettext('Clear'),
+ lazy_gettext(
+ 'Are you sure you want to clear the state of the selected task'
+ ' instance(s) and set their dagruns to the running state?'
+ ),
+ single=False,
+ )
def action_clear(self, task_instances, session=None):
"""Clears the action."""
try:
@@ -3326,15 +3517,20 @@ def set_task_instance_state(self, tis, target_state, session=None):
for ti in tis:
ti.set_state(target_state, session)
session.commit()
- flash("{count} task instances were set to '{target_state}'".format(
- count=count, target_state=target_state))
+ flash(
+ "{count} task instances were set to '{target_state}'".format(
+ count=count, target_state=target_state
+ )
+ )
except Exception: # noqa pylint: disable=broad-except
flash('Failed to set state', 'error')
@action('set_running', "Set state to 'running'", '', single=False)
- @auth.has_access([
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
+ ]
+ )
def action_set_running(self, tis):
"""Set state to 'running'"""
self.set_task_instance_state(tis, State.RUNNING)
@@ -3342,9 +3538,11 @@ def action_set_running(self, tis):
return redirect(self.get_redirect())
@action('set_failed', "Set state to 'failed'", '', single=False)
- @auth.has_access([
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
+ ]
+ )
def action_set_failed(self, tis):
"""Set state to 'failed'"""
self.set_task_instance_state(tis, State.FAILED)
@@ -3352,9 +3550,11 @@ def action_set_failed(self, tis):
return redirect(self.get_redirect())
@action('set_success', "Set state to 'success'", '', single=False)
- @auth.has_access([
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
+ ]
+ )
def action_set_success(self, tis):
"""Set state to 'success'"""
self.set_task_instance_state(tis, State.SUCCESS)
@@ -3362,9 +3562,11 @@ def action_set_success(self, tis):
return redirect(self.get_redirect())
@action('set_retry', "Set state to 'up_for_retry'", '', single=False)
- @auth.has_access([
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
+ ]
+ )
def action_set_retry(self, tis):
"""Set state to 'up_for_retry'"""
self.set_task_instance_state(tis, State.UP_FOR_RETRY)
@@ -3387,38 +3589,46 @@ class DagModelView(AirflowModelView):
base_permissions = [
permissions.ACTION_CAN_READ,
permissions.ACTION_CAN_EDIT,
- permissions.ACTION_CAN_DELETE
+ permissions.ACTION_CAN_DELETE,
]
- list_columns = ['dag_id', 'is_paused', 'last_scheduler_run',
- 'last_expired', 'scheduler_lock', 'fileloc', 'owners']
+ list_columns = [
+ 'dag_id',
+ 'is_paused',
+ 'last_scheduler_run',
+ 'last_expired',
+ 'scheduler_lock',
+ 'fileloc',
+ 'owners',
+ ]
- formatters_columns = {
- 'dag_id': wwwutils.dag_link
- }
+ formatters_columns = {'dag_id': wwwutils.dag_link}
base_filters = [['dag_id', DagFilter, lambda: []]]
def get_query(self):
"""Default filters for model"""
return (
- super().get_query() # noqa pylint: disable=no-member
- .filter(or_(models.DagModel.is_active,
- models.DagModel.is_paused))
+ super() # noqa pylint: disable=no-member
+ .get_query()
+ .filter(or_(models.DagModel.is_active, models.DagModel.is_paused))
.filter(~models.DagModel.is_subdag)
)
def get_count_query(self):
"""Default filters for model"""
return (
- super().get_count_query() # noqa pylint: disable=no-member
+ super() # noqa pylint: disable=no-member
+ .get_count_query()
.filter(models.DagModel.is_active)
.filter(~models.DagModel.is_subdag)
)
- @auth.has_access([
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- ])
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ ]
+ )
@provide_session
@expose('/autocomplete')
def autocomplete(self, session=None):
@@ -3430,12 +3640,12 @@ def autocomplete(self, session=None):
# Provide suggestions of dag_ids and owners
dag_ids_query = session.query(DagModel.dag_id.label('item')).filter( # pylint: disable=no-member
- ~DagModel.is_subdag, DagModel.is_active,
- DagModel.dag_id.ilike('%' + query + '%')) # noqa pylint: disable=no-member
+ ~DagModel.is_subdag, DagModel.is_active, DagModel.dag_id.ilike('%' + query + '%')
+ ) # noqa pylint: disable=no-member
owners_query = session.query(func.distinct(DagModel.owners).label('item')).filter(
- ~DagModel.is_subdag, DagModel.is_active,
- DagModel.owners.ilike('%' + query + '%')) # noqa pylint: disable=no-member
+ ~DagModel.is_subdag, DagModel.is_active, DagModel.owners.ilike('%' + query + '%')
+ ) # noqa pylint: disable=no-member
# Hide DAGs if not showing status: "all"
status = flask_session.get(FILTER_STATUS_COOKIE)
diff --git a/airflow/www/widgets.py b/airflow/www/widgets.py
index 199c274ac0e7c..172380c2a733c 100644
--- a/airflow/www/widgets.py
+++ b/airflow/www/widgets.py
@@ -36,7 +36,6 @@ class AirflowDateTimePickerWidget:
""
''
""
-
)
def __call__(self, field, **kwargs):
@@ -46,6 +45,4 @@ def __call__(self, field, **kwargs):
field.data = ""
template = self.data_template
- return Markup(
- template % {"text": html_params(type="text", value=field.data, **kwargs)}
- )
+ return Markup(template % {"text": html_params(type="text", value=field.data, **kwargs)})
diff --git a/chart/tests/test_celery_kubernetes_executor.py b/chart/tests/test_celery_kubernetes_executor.py
index 57c39809b9351..fb219297553f3 100644
--- a/chart/tests/test_celery_kubernetes_executor.py
+++ b/chart/tests/test_celery_kubernetes_executor.py
@@ -18,6 +18,7 @@
import unittest
import jmespath
+
from tests.helm_template_generator import render_chart
diff --git a/chart/tests/test_celery_kubernetes_pod_launcher_role.py b/chart/tests/test_celery_kubernetes_pod_launcher_role.py
index 535be11fafabc..952bc390b2786 100644
--- a/chart/tests/test_celery_kubernetes_pod_launcher_role.py
+++ b/chart/tests/test_celery_kubernetes_pod_launcher_role.py
@@ -18,6 +18,7 @@
import unittest
import jmespath
+
from tests.helm_template_generator import render_chart
diff --git a/chart/tests/test_chart_quality.py b/chart/tests/test_chart_quality.py
index 32237cb7aa43d..5f17165ea3dea 100644
--- a/chart/tests/test_chart_quality.py
+++ b/chart/tests/test_chart_quality.py
@@ -18,11 +18,10 @@
import json
import os
import unittest
-import yaml
+import yaml
from jsonschema import validate
-
CHART_FOLDER = os.path.dirname(os.path.dirname(__file__))
diff --git a/chart/tests/test_dags_persistent_volume_claim.py b/chart/tests/test_dags_persistent_volume_claim.py
index 069a0cd5998d2..946c40fe7c319 100644
--- a/chart/tests/test_dags_persistent_volume_claim.py
+++ b/chart/tests/test_dags_persistent_volume_claim.py
@@ -18,6 +18,7 @@
import unittest
import jmespath
+
from tests.helm_template_generator import render_chart
diff --git a/chart/tests/test_flower_authorization.py b/chart/tests/test_flower_authorization.py
index f0cc5b07b2af7..0520ddd28f273 100644
--- a/chart/tests/test_flower_authorization.py
+++ b/chart/tests/test_flower_authorization.py
@@ -18,6 +18,7 @@
import unittest
import jmespath
+
from tests.helm_template_generator import render_chart
diff --git a/chart/tests/test_git_sync_scheduler.py b/chart/tests/test_git_sync_scheduler.py
index a01c0f216e8f4..068f36c83cf71 100644
--- a/chart/tests/test_git_sync_scheduler.py
+++ b/chart/tests/test_git_sync_scheduler.py
@@ -18,6 +18,7 @@
import unittest
import jmespath
+
from tests.helm_template_generator import render_chart
diff --git a/chart/tests/test_git_sync_webserver.py b/chart/tests/test_git_sync_webserver.py
index 30c3f330a48a8..75ec51bf1fd83 100644
--- a/chart/tests/test_git_sync_webserver.py
+++ b/chart/tests/test_git_sync_webserver.py
@@ -18,6 +18,7 @@
import unittest
import jmespath
+
from tests.helm_template_generator import render_chart
diff --git a/chart/tests/test_git_sync_worker.py b/chart/tests/test_git_sync_worker.py
index a70d311a121b1..e5036d76a4384 100644
--- a/chart/tests/test_git_sync_worker.py
+++ b/chart/tests/test_git_sync_worker.py
@@ -18,6 +18,7 @@
import unittest
import jmespath
+
from tests.helm_template_generator import render_chart
diff --git a/chart/tests/test_migrate_database_job.py b/chart/tests/test_migrate_database_job.py
index 0524315d5e608..4b92acaf60100 100644
--- a/chart/tests/test_migrate_database_job.py
+++ b/chart/tests/test_migrate_database_job.py
@@ -18,6 +18,7 @@
import unittest
import jmespath
+
from tests.helm_template_generator import render_chart
diff --git a/chart/tests/test_pod_template_file.py b/chart/tests/test_pod_template_file.py
index 0673e08ba7b18..d9334deb07d31 100644
--- a/chart/tests/test_pod_template_file.py
+++ b/chart/tests/test_pod_template_file.py
@@ -17,10 +17,11 @@
import unittest
from os import remove
-from os.path import realpath, dirname
+from os.path import dirname, realpath
from shutil import copyfile
import jmespath
+
from tests.helm_template_generator import render_chart
ROOT_FOLDER = realpath(dirname(realpath(__file__)) + "/..")
diff --git a/chart/tests/test_scheduler.py b/chart/tests/test_scheduler.py
index 976984823d55d..eb5225e35c389 100644
--- a/chart/tests/test_scheduler.py
+++ b/chart/tests/test_scheduler.py
@@ -18,6 +18,7 @@
import unittest
import jmespath
+
from tests.helm_template_generator import render_chart
diff --git a/chart/tests/test_worker.py b/chart/tests/test_worker.py
index 2fc6d17c48162..9b3515ef81562 100644
--- a/chart/tests/test_worker.py
+++ b/chart/tests/test_worker.py
@@ -18,6 +18,7 @@
import unittest
import jmespath
+
from tests.helm_template_generator import render_chart
diff --git a/dags/test_dag.py b/dags/test_dag.py
index 8a1695f310e3f..75f279c5a1dd0 100644
--- a/dags/test_dag.py
+++ b/dags/test_dag.py
@@ -23,17 +23,11 @@
from airflow.utils.dates import days_ago
now = datetime.now()
-now_to_the_hour = (
- now - timedelta(0, 0, 0, 0, 0, 3)
-).replace(minute=0, second=0, microsecond=0)
+now_to_the_hour = (now - timedelta(0, 0, 0, 0, 0, 3)).replace(minute=0, second=0, microsecond=0)
START_DATE = now_to_the_hour
DAG_NAME = 'test_dag_v1'
-default_args = {
- 'owner': 'airflow',
- 'depends_on_past': True,
- 'start_date': days_ago(2)
-}
+default_args = {'owner': 'airflow', 'depends_on_past': True, 'start_date': days_ago(2)}
dag = DAG(DAG_NAME, schedule_interval='*/10 * * * *', default_args=default_args)
run_this_1 = DummyOperator(task_id='run_this_1', dag=dag)
diff --git a/dev/airflow-github b/dev/airflow-github
index 17490f7c55ff8..10cfdfa862133 100755
--- a/dev/airflow-github
+++ b/dev/airflow-github
@@ -29,8 +29,7 @@ from github import Github
from github.Issue import Issue
from github.PullRequest import PullRequest
-GIT_COMMIT_FIELDS = ['id', 'author_name',
- 'author_email', 'date', 'subject', 'body']
+GIT_COMMIT_FIELDS = ['id', 'author_name', 'author_email', 'date', 'subject', 'body']
GIT_LOG_FORMAT = '%x1f'.join(['%h', '%an', '%ae', '%ad', '%s', '%b']) + '%x1e'
pr_title_re = re.compile(r".*\((#[0-9]{1,6})\)$")
@@ -131,17 +130,20 @@ def cli():
@cli.command(short_help='Compare a GitHub target version against git merges')
@click.argument('target_version')
@click.argument('github-token', envvar='GITHUB_TOKEN')
-@click.option('--previous-version',
- 'previous_version',
- help="Specify the previous tag on the working branch to limit"
- " searching for few commits to find the cherry-picked commits")
+@click.option(
+ '--previous-version',
+ 'previous_version',
+ help="Specify the previous tag on the working branch to limit"
+ " searching for few commits to find the cherry-picked commits",
+)
@click.option('--unmerged', 'show_uncherrypicked_only', help="Show unmerged issues only", is_flag=True)
def compare(target_version, github_token, previous_version=None, show_uncherrypicked_only=False):
repo = git.Repo(".", search_parent_directories=True)
github_handler = Github(github_token)
- milestone_issues: List[Issue] = list(github_handler.search_issues(
- f"repo:apache/airflow milestone:\"Airflow {target_version}\""))
+ milestone_issues: List[Issue] = list(
+ github_handler.search_issues(f"repo:apache/airflow milestone:\"Airflow {target_version}\"")
+ )
num_cherrypicked = 0
num_uncherrypicked = Counter()
@@ -151,13 +153,16 @@ def compare(target_version, github_token, previous_version=None, show_uncherrypi
# !s forces as string
formatstr = "{id:<8}|{typ!s:<15}|{status!s}|{description:<83.83}|{merged:<6}|{commit:>9.7}"
- print(formatstr.format(
- id="ISSUE",
- typ="TYPE",
- status="STATUS".ljust(10),
- description="DESCRIPTION",
- merged="MERGED",
- commit="COMMIT"))
+ print(
+ formatstr.format(
+ id="ISSUE",
+ typ="TYPE",
+ status="STATUS".ljust(10),
+ description="DESCRIPTION",
+ merged="MERGED",
+ commit="COMMIT",
+ )
+ )
for issue in milestone_issues:
commit_in_master = get_commit_in_master_associated_with_pr(repo, issue)
@@ -180,13 +185,17 @@ def compare(target_version, github_token, previous_version=None, show_uncherrypi
description=issue.title,
)
- print(formatstr.format(
- **fields,
- merged=cherrypicked,
- commit=commit_in_master if commit_in_master else ""))
+ print(
+ formatstr.format(
+ **fields, merged=cherrypicked, commit=commit_in_master if commit_in_master else ""
+ )
+ )
- print("Commits on branch: {:d}, {:d} ({}) yet to be cherry-picked".format(
- num_cherrypicked, sum(num_uncherrypicked.values()), dict(num_uncherrypicked)))
+ print(
+ "Commits on branch: {:d}, {:d} ({}) yet to be cherry-picked".format(
+ num_cherrypicked, sum(num_uncherrypicked.values()), dict(num_uncherrypicked)
+ )
+ )
@cli.command(short_help='Build a CHANGELOG grouped by GitHub Issue type')
@@ -196,8 +205,7 @@ def compare(target_version, github_token, previous_version=None, show_uncherrypi
def changelog(previous_version, target_version, github_token):
repo = git.Repo(".", search_parent_directories=True)
# Get a list of issues/PRs that have been committed on the current branch.
- log_args = [
- f'--format={GIT_LOG_FORMAT}', previous_version + ".." + target_version]
+ log_args = [f'--format={GIT_LOG_FORMAT}', previous_version + ".." + target_version]
log = repo.git.log(*log_args)
log = log.strip('\n\x1e').split("\x1e")
@@ -221,6 +229,7 @@ def changelog(previous_version, target_version, github_token):
if __name__ == "__main__":
import doctest
+
(failure_count, test_count) = doctest.testmod()
if failure_count:
sys.exit(-1)
diff --git a/dev/airflow-license b/dev/airflow-license
index 3487309ea38d3..2aac7e59fd5f0 100755
--- a/dev/airflow-license
+++ b/dev/airflow-license
@@ -24,10 +24,22 @@ import string
import slugify
# order is important
-_licenses = {'MIT': ['Permission is hereby granted free of charge', 'The above copyright notice and this permission notice shall'],
- 'BSD-3': ['Redistributions of source code must retain the above copyright', 'Redistributions in binary form must reproduce the above copyright', 'specific prior written permission'],
- 'BSD-2': ['Redistributions of source code must retain the above copyright', 'Redistributions in binary form must reproduce the above copyright'],
- 'AL': ['http://www.apache.org/licenses/LICENSE-2.0']}
+_licenses = {
+ 'MIT': [
+ 'Permission is hereby granted free of charge',
+ 'The above copyright notice and this permission notice shall',
+ ],
+ 'BSD-3': [
+ 'Redistributions of source code must retain the above copyright',
+ 'Redistributions in binary form must reproduce the above copyright',
+ 'specific prior written permission',
+ ],
+ 'BSD-2': [
+ 'Redistributions of source code must retain the above copyright',
+ 'Redistributions in binary form must reproduce the above copyright',
+ ],
+ 'AL': ['http://www.apache.org/licenses/LICENSE-2.0'],
+}
def get_notices():
@@ -56,16 +68,14 @@ def parse_license_file(project_name):
if __name__ == "__main__":
- print("{:<30}|{:<50}||{:<20}||{:<10}"
- .format("PROJECT", "URL", "LICENSE TYPE DEFINED", "DETECTED"))
+ print("{:<30}|{:<50}||{:<20}||{:<10}".format("PROJECT", "URL", "LICENSE TYPE DEFINED", "DETECTED"))
notices = get_notices()
for notice in notices:
notice = notice[0]
license = parse_license_file(notice[1])
- print("{:<30}|{:<50}||{:<20}||{:<10}"
- .format(notice[1], notice[2][:50], notice[0], license))
+ print("{:<30}|{:<50}||{:<20}||{:<10}".format(notice[1], notice[2][:50], notice[0], license))
file_count = len([name for name in os.listdir("../licenses")])
print("Defined licenses: {} Files found: {}".format(len(notices), file_count))
diff --git a/dev/send_email.py b/dev/send_email.py
index 6d698f7ede669..0e9c8b7a45f02 100755
--- a/dev/send_email.py
+++ b/dev/send_email.py
@@ -37,10 +37,7 @@
SMTP_PORT = 587
SMTP_SERVER = "mail-relay.apache.org"
-MAILING_LIST = {
- "dev": "dev@airflow.apache.org",
- "users": "users@airflow.apache.org"
-}
+MAILING_LIST = {"dev": "dev@airflow.apache.org", "users": "users@airflow.apache.org"}
def string_comma_to_list(message: str) -> List[str]:
@@ -51,9 +48,13 @@ def string_comma_to_list(message: str) -> List[str]:
def send_email(
- smtp_server: str, smpt_port: int,
- username: str, password: str,
- sender_email: str, receiver_email: Union[str, List], message: str,
+ smtp_server: str,
+ smpt_port: int,
+ username: str,
+ password: str,
+ sender_email: str,
+ receiver_email: Union[str, List],
+ message: str,
):
"""
Send a simple text email (SMTP)
@@ -92,8 +93,7 @@ def show_message(entity: str, message: str):
def inter_send_email(
- username: str, password: str, sender_email: str, receiver_email: Union[str, List],
- message: str
+ username: str, password: str, sender_email: str, receiver_email: Union[str, List], message: str
):
"""
Send email using SMTP
@@ -104,7 +104,13 @@ def inter_send_email(
try:
send_email(
- SMTP_SERVER, SMTP_PORT, username, password, sender_email, receiver_email, message,
+ SMTP_SERVER,
+ SMTP_PORT,
+ username,
+ password,
+ sender_email,
+ receiver_email,
+ message,
)
click.secho("✅ Email sent successfully", fg="green")
except smtplib.SMTPAuthenticationError:
@@ -117,10 +123,8 @@ class BaseParameters:
"""
Base Class to send emails using Apache Creds and for Jinja templating
"""
- def __init__(
- self, name=None, email=None, username=None, password=None,
- version=None, version_rc=None
- ):
+
+ def __init__(self, name=None, email=None, username=None, password=None, version=None, version_rc=None):
self.name = name
self.email = email
self.username = username
@@ -136,42 +140,53 @@ def __repr__(self):
@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
@click.pass_context
@click.option(
- "-e", "--apache_email",
+ "-e",
+ "--apache_email",
prompt="Apache Email",
- envvar="APACHE_EMAIL", show_envvar=True,
+ envvar="APACHE_EMAIL",
+ show_envvar=True,
help="Your Apache email will be used for SMTP From",
- required=True
+ required=True,
)
@click.option(
- "-u", "--apache_username",
+ "-u",
+ "--apache_username",
prompt="Apache Username",
- envvar="APACHE_USERNAME", show_envvar=True,
+ envvar="APACHE_USERNAME",
+ show_envvar=True,
help="Your LDAP Apache username",
required=True,
)
@click.password_option( # type: ignore
- "-p", "--apache_password",
+ "-p",
+ "--apache_password",
prompt="Apache Password",
- envvar="APACHE_PASSWORD", show_envvar=True,
+ envvar="APACHE_PASSWORD",
+ show_envvar=True,
help="Your LDAP Apache password",
required=True,
)
@click.option(
- "-v", "--version",
+ "-v",
+ "--version",
prompt="Version",
- envvar="AIRFLOW_VERSION", show_envvar=True,
+ envvar="AIRFLOW_VERSION",
+ show_envvar=True,
help="Release Version",
required=True,
)
@click.option(
- "-rc", "--version_rc",
+ "-rc",
+ "--version_rc",
prompt="Version (with RC)",
- envvar="AIRFLOW_VERSION_RC", show_envvar=True,
+ envvar="AIRFLOW_VERSION_RC",
+ show_envvar=True,
help="Release Candidate Version",
required=True,
)
@click.option( # type: ignore
- "-n", "--name",
+ "-n",
+ "--name",
prompt="Your Name",
default=lambda: os.environ.get('USER', ''),
show_default="Current User",
@@ -180,9 +195,13 @@ def __repr__(self):
required=True,
)
def cli(
- ctx, apache_email: str,
- apache_username: str, apache_password: str, version: str, version_rc: str,
- name: str
+ ctx,
+ apache_email: str,
+ apache_username: str,
+ apache_password: str,
+ version: str,
+ version_rc: str,
+ name: str,
):
"""
🚀 CLI to send emails for the following:
@@ -232,7 +251,8 @@ def vote(base_parameters, receiver_email: str):
@cli.command("result")
@click.option(
- "-re", "--receiver_email",
+ "-re",
+ "--receiver_email",
default=MAILING_LIST.get("dev"),
type=click.STRING,
prompt="The receiver email (To:)",
@@ -258,25 +278,23 @@ def vote(base_parameters, receiver_email: str):
@click.pass_obj
def result(
base_parameters,
- receiver_email: str, vote_bindings: str, vote_nonbindings: str, vote_negatives: str,
+ receiver_email: str,
+ vote_bindings: str,
+ vote_nonbindings: str,
+ vote_negatives: str,
):
"""
Send email with results of voting on RC
"""
template_file = "templates/result_email.j2"
base_parameters.template_arguments["receiver_email"] = receiver_email
- base_parameters.template_arguments["vote_bindings"] = string_comma_to_list(
- vote_bindings
- )
- base_parameters.template_arguments["vote_nonbindings"] = string_comma_to_list(
- vote_nonbindings
- )
- base_parameters.template_arguments["vote_negatives"] = string_comma_to_list(
- vote_negatives
- )
+ base_parameters.template_arguments["vote_bindings"] = string_comma_to_list(vote_bindings)
+ base_parameters.template_arguments["vote_nonbindings"] = string_comma_to_list(vote_nonbindings)
+ base_parameters.template_arguments["vote_negatives"] = string_comma_to_list(vote_negatives)
message = render_template(template_file, **base_parameters.template_arguments)
inter_send_email(
- base_parameters.username, base_parameters.password,
+ base_parameters.username,
+ base_parameters.password,
base_parameters.template_arguments["sender_email"],
base_parameters.template_arguments["receiver_email"],
message,
@@ -288,12 +306,10 @@ def result(
"--receiver_email",
default=",".join(list(MAILING_LIST.values())),
prompt="The receiver email (To:)",
- help="Receiver's email address. If more than 1, separate them by comma"
+ help="Receiver's email address. If more than 1, separate them by comma",
)
@click.pass_obj
-def announce(
- base_parameters, receiver_email: str
-):
+def announce(base_parameters, receiver_email: str):
"""
Send email to announce release of the new version
"""
@@ -304,7 +320,8 @@ def announce(
message = render_template(template_file, **base_parameters.template_arguments)
inter_send_email(
- base_parameters.username, base_parameters.password,
+ base_parameters.username,
+ base_parameters.password,
base_parameters.template_arguments["sender_email"],
base_parameters.template_arguments["receiver_email"],
message,
@@ -312,14 +329,12 @@ def announce(
if click.confirm("Show Slack message for announcement?", default=True):
base_parameters.template_arguments["slack_rc"] = False
- slack_msg = render_template(
- "templates/slack.j2", **base_parameters.template_arguments)
+ slack_msg = render_template("templates/slack.j2", **base_parameters.template_arguments)
show_message("Slack", slack_msg)
if click.confirm("Show Twitter message for announcement?", default=True):
- twitter_msg = render_template(
- "templates/twitter.j2", **base_parameters.template_arguments)
+ twitter_msg = render_template("templates/twitter.j2", **base_parameters.template_arguments)
show_message("Twitter", twitter_msg)
if __name__ == '__main__':
- cli() # pylint: disable=no-value-for-parameter
+ cli() # pylint: disable=no-value-for-parameter
diff --git a/docs/build_docs.py b/docs/build_docs.py
index 8f565d8e1993f..2799a46169d78 100755
--- a/docs/build_docs.py
+++ b/docs/build_docs.py
@@ -87,8 +87,13 @@ def __lt__(self, other):
line_no_b = other.line_no or 0
context_line_a = self.context_line or ''
context_line_b = other.context_line or ''
- return (file_path_a, line_no_a, context_line_a, self.spelling, self.message) < \
- (file_path_b, line_no_b, context_line_b, other.spelling, other.message)
+ return (file_path_a, line_no_a, context_line_a, self.spelling, self.message) < (
+ file_path_b,
+ line_no_b,
+ context_line_b,
+ other.spelling,
+ other.message,
+ )
build_errors: List[DocBuildError] = []
@@ -221,7 +226,7 @@ def generate_build_error(path, line_no, operator_name):
f".. seealso::\n"
f" For more information on how to use this operator, take a look at the guide:\n"
f" :ref:`howto/operator:{operator_name}`\n"
- )
+ ),
)
# Extract operators for which there are existing .rst guides
@@ -261,9 +266,7 @@ def generate_build_error(path, line_no, operator_name):
if f":ref:`howto/operator:{existing_operator}`" in ast.get_docstring(class_def):
continue
- build_errors.append(
- generate_build_error(py_module_path, class_def.lineno, existing_operator)
- )
+ build_errors.append(generate_build_error(py_module_path, class_def.lineno, existing_operator))
def assert_file_not_contains(file_path: str, pattern: str, message: str) -> None:
@@ -325,8 +328,9 @@ def check_class_links_in_operators_and_hooks_ref() -> None:
airflow_modules = find_modules() - find_modules(deprecated_only=True)
airflow_modules = {
- o for o in airflow_modules if any(f".{d}." in o for d in
- ["operators", "hooks", "sensors", "transfers"])
+ o
+ for o in airflow_modules
+ if any(f".{d}." in o for d in ["operators", "hooks", "sensors", "transfers"])
}
missing_modules = airflow_modules - current_modules_in_file
@@ -359,20 +363,12 @@ def check_guide_links_in_operators_and_hooks_ref() -> None:
if "_partials" not in guide
]
# Remove partials and index
- all_guides = [
- guide
- for guide in all_guides
- if "/_partials/" not in guide and not guide.endswith("index")
- ]
+ all_guides = [guide for guide in all_guides if "/_partials/" not in guide and not guide.endswith("index")]
with open(os.path.join(DOCS_DIR, "operators-and-hooks-ref.rst")) as ref_file:
content = ref_file.read()
- missing_guides = [
- guide
- for guide in all_guides
- if guide not in content
- ]
+ missing_guides = [guide for guide in all_guides if guide not in content]
if missing_guides:
guide_text_list = "\n".join(f":doc:`How to use <{guide}>`" for guide in missing_guides)
@@ -403,7 +399,7 @@ def check_exampleinclude_for_example_dags():
message=(
"literalinclude directive is prohibited for example DAGs. \n"
"You should use the exampleinclude directive to include example DAGs."
- )
+ ),
)
@@ -418,7 +414,7 @@ def check_enforce_code_block():
message=(
"We recommend using the code-block directive instead of the code directive. "
"The code-block directive is more feature-full."
- )
+ ),
)
@@ -444,10 +440,12 @@ def check_google_guides():
doc_files = glob(f"{DOCS_DIR}/howto/operator/google/**/*.rst", recursive=True)
doc_names = {f.split("/")[-1].rsplit(".")[0] for f in doc_files}
- operators_files = chain(*[
- glob(f"{ROOT_PACKAGE_DIR}/providers/google/*/{resource_type}/*.py")
- for resource_type in ["operators", "sensors", "transfers"]
- ])
+ operators_files = chain(
+ *[
+ glob(f"{ROOT_PACKAGE_DIR}/providers/google/*/{resource_type}/*.py")
+ for resource_type in ["operators", "sensors", "transfers"]
+ ]
+ )
operators_files = (f for f in operators_files if not f.endswith("__init__.py"))
operator_names = {f.split("/")[-1].rsplit(".")[0] for f in operators_files}
@@ -495,6 +493,7 @@ def guess_lexer_for_filename(filename):
lexer = get_lexer_for_filename(filename)
except ClassNotFound:
from pygments.lexers.special import TextLexer
+
lexer = TextLexer()
return lexer
@@ -504,6 +503,7 @@ def guess_lexer_for_filename(filename):
with suppress(ImportError):
import pygments
from pygments.formatters.terminal import TerminalFormatter
+
code = pygments.highlight(
code=code, formatter=TerminalFormatter(), lexer=guess_lexer_for_filename(file_path)
)
@@ -541,9 +541,7 @@ def parse_sphinx_warnings(warning_text: str) -> List[DocBuildError]:
except Exception: # noqa pylint: disable=broad-except
# If an exception occurred while parsing the warning message, display the raw warning message.
sphinx_build_errors.append(
- DocBuildError(
- file_path=None, line_no=None, message=sphinx_warning
- )
+ DocBuildError(file_path=None, line_no=None, message=sphinx_warning)
)
else:
sphinx_build_errors.append(DocBuildError(file_path=None, line_no=None, message=sphinx_warning))
@@ -635,7 +633,7 @@ def check_spelling() -> None:
"-D", # override the extensions because one of them throws an error on the spelling builder
f"extensions={','.join(extensions_to_use)}",
".", # path to documentation source files
- "_build/spelling"
+ "_build/spelling",
]
print("Executing cmd: ", " ".join([shlex.quote(c) for c in build_cmd]))
@@ -643,8 +641,12 @@ def check_spelling() -> None:
if completed_proc.returncode != 0:
spelling_errors.append(
SpellingError(
- file_path=None, line_no=None, spelling=None, suggestion=None, context_line=None,
- message=f"Sphinx spellcheck returned non-zero exit status: {completed_proc.returncode}."
+ file_path=None,
+ line_no=None,
+ spelling=None,
+ suggestion=None,
+ context_line=None,
+ message=f"Sphinx spellcheck returned non-zero exit status: {completed_proc.returncode}.",
)
)
@@ -710,10 +712,10 @@ def print_build_errors_and_exit(message) -> None:
parser = argparse.ArgumentParser(description='Builds documentation and runs spell checking')
-parser.add_argument('--docs-only', dest='docs_only', action='store_true',
- help='Only build documentation')
-parser.add_argument('--spellcheck-only', dest='spellcheck_only', action='store_true',
- help='Only perform spellchecking')
+parser.add_argument('--docs-only', dest='docs_only', action='store_true', help='Only build documentation')
+parser.add_argument(
+ '--spellcheck-only', dest='spellcheck_only', action='store_true', help='Only perform spellchecking'
+)
args = parser.parse_args()
diff --git a/docs/conf.py b/docs/conf.py
index a035b596f878e..aab471a28f1a3 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -142,14 +142,9 @@
"sphinxcontrib.spelling",
]
-autodoc_default_options = {
- 'show-inheritance': True,
- 'members': True
-}
+autodoc_default_options = {'show-inheritance': True, 'members': True}
-jinja_contexts = {
- 'config_ctx': {"configs": default_config_yaml()}
-}
+jinja_contexts = {'config_ctx': {"configs": default_config_yaml()}}
viewcode_follow_imported_members = True
@@ -213,7 +208,7 @@
# Templates or partials
'autoapi_templates',
'howto/operator/google/_partials',
- 'howto/operator/microsoft/_partials'
+ 'howto/operator/microsoft/_partials',
]
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
@@ -225,7 +220,9 @@ def _get_rst_filepath_from_path(filepath: str):
elif os.path.isfile(filepath) and filepath.endswith('/__init__.py'):
result = filepath.rpartition("/")[0]
else:
- result = filepath.rpartition(".",)[0]
+ result = filepath.rpartition(
+ ".",
+ )[0]
result += "/index.rst"
result = f"_api/{os.path.relpath(result, ROOT_DIR)}"
@@ -252,8 +249,7 @@ def _get_rst_filepath_from_path(filepath: str):
}
providers_package_indexes = {
- f"_api/{os.path.relpath(name, ROOT_DIR)}/index.rst"
- for name in providers_packages_roots
+ f"_api/{os.path.relpath(name, ROOT_DIR)}/index.rst" for name in providers_packages_roots
}
exclude_patterns.extend(providers_package_indexes)
@@ -269,10 +265,7 @@ def _get_rst_filepath_from_path(filepath: str):
for p in excluded_packages_in_providers
for path in glob(f"{p}/**/*", recursive=True)
}
-excluded_files_in_providers |= {
- _get_rst_filepath_from_path(name)
- for name in excluded_packages_in_providers
-}
+excluded_files_in_providers |= {_get_rst_filepath_from_path(name) for name in excluded_packages_in_providers}
exclude_patterns.extend(excluded_files_in_providers)
@@ -437,10 +430,8 @@ def _get_rst_filepath_from_path(filepath: str):
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
# 'papersize': 'letterpaper',
-
# The font size ('10pt', '11pt' or '12pt').
# 'pointsize': '10pt',
-
# Additional stuff for the LaTeX preamble.
# 'preamble': '',
} # type: Dict[str,str]
@@ -449,8 +440,7 @@ def _get_rst_filepath_from_path(filepath: str):
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
- ('index', 'Airflow.tex', 'Airflow Documentation',
- 'Apache Airflow', 'manual'),
+ ('index', 'Airflow.tex', 'Airflow Documentation', 'Apache Airflow', 'manual'),
]
# The name of an image file (relative to this directory) to place at the top of
@@ -478,10 +468,7 @@ def _get_rst_filepath_from_path(filepath: str):
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
-man_pages = [
- ('index', 'airflow', 'Airflow Documentation',
- ['Apache Airflow'], 1)
-]
+man_pages = [('index', 'airflow', 'Airflow Documentation', ['Apache Airflow'], 1)]
# If true, show URL addresses after external links.
# man_show_urls = False
@@ -492,12 +479,17 @@ def _get_rst_filepath_from_path(filepath: str):
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
-texinfo_documents = [(
- 'index', 'Airflow', 'Airflow Documentation',
- 'Apache Airflow', 'Airflow',
- 'Airflow is a system to programmatically author, schedule and monitor data pipelines.',
- 'Miscellaneous'
-), ]
+texinfo_documents = [
+ (
+ 'index',
+ 'Airflow',
+ 'Airflow Documentation',
+ 'Apache Airflow',
+ 'Airflow',
+ 'Airflow is a system to programmatically author, schedule and monitor data pipelines.',
+ 'Miscellaneous',
+ ),
+]
# Documents to append as an appendix to all manuals.
# texinfo_appendices = []
@@ -546,10 +538,7 @@ def _get_rst_filepath_from_path(filepath: str):
redirects_file = 'redirects.txt'
# -- Options for redoc docs ----------------------------------
-OPENAPI_FILE = os.path.join(
- os.path.dirname(__file__),
- "..", "airflow", "api_connexion", "openapi", "v1.yaml"
-)
+OPENAPI_FILE = os.path.join(os.path.dirname(__file__), "..", "airflow", "api_connexion", "openapi", "v1.yaml")
redoc = [
{
'name': 'Airflow REST API',
@@ -558,7 +547,7 @@ def _get_rst_filepath_from_path(filepath: str):
'opts': {
'hide-hostname': True,
'no-auto-auth': True,
- }
+ },
},
]
diff --git a/docs/exts/docroles.py b/docs/exts/docroles.py
index 5dd71cd1bfc6d..a1ac0cb2ecbee 100644
--- a/docs/exts/docroles.py
+++ b/docs/exts/docroles.py
@@ -52,21 +52,21 @@ def get_template_field(env, fullname):
template_fields = getattr(clazz, "template_fields")
if not template_fields:
- raise RoleException(
- f"Could not find the template fields for {classname} class in {modname} module."
- )
+ raise RoleException(f"Could not find the template fields for {classname} class in {modname} module.")
return list(template_fields)
-def template_field_role(app,
- typ, # pylint: disable=unused-argument
- rawtext,
- text,
- lineno,
- inliner,
- options=None, # pylint: disable=unused-argument
- content=None): # pylint: disable=unused-argument
+def template_field_role(
+ app,
+ typ, # pylint: disable=unused-argument
+ rawtext,
+ text,
+ lineno,
+ inliner,
+ options=None, # pylint: disable=unused-argument
+ content=None,
+): # pylint: disable=unused-argument
"""
A role that allows you to include a list of template fields in the middle of the text. This is especially
useful when writing guides describing how to use the operator.
@@ -90,7 +90,10 @@ def template_field_role(app,
try:
template_fields = get_template_field(app.env, text)
except RoleException as e:
- msg = inliner.reporter.error(f"invalid class name {text} \n{e}", line=lineno)
+ msg = inliner.reporter.error(
+ f"invalid class name {text} \n{e}",
+ line=lineno,
+ )
prb = inliner.problematic(rawtext, rawtext, msg)
return [prb], [msg]
@@ -106,6 +109,7 @@ def template_field_role(app,
def setup(app):
"""Sets the extension up"""
from docutils.parsers.rst import roles # pylint: disable=wrong-import-order
+
roles.register_local_role("template-fields", partial(template_field_role, app))
return {"version": "builtin", "parallel_read_safe": True, "parallel_write_safe": True}
diff --git a/docs/exts/exampleinclude.py b/docs/exts/exampleinclude.py
index 774098b6cb32a..141ca8202bef6 100644
--- a/docs/exts/exampleinclude.py
+++ b/docs/exts/exampleinclude.py
@@ -34,6 +34,7 @@
try:
import sphinx_airflow_theme # pylint: disable=unused-import
+
airflow_theme_is_available = True
except ImportError:
airflow_theme_is_available = False
@@ -147,8 +148,9 @@ def register_source(app, env, modname):
try:
analyzer = ModuleAnalyzer.for_module(modname)
except Exception as ex: # pylint: disable=broad-except
- logger.info("Module \"%s\" could not be loaded. Full source will not be available. \"%s\"",
- modname, ex)
+ logger.info(
+ "Module \"%s\" could not be loaded. Full source will not be available. \"%s\"", modname, ex
+ )
env._viewcode_modules[modname] = False
return False
@@ -168,6 +170,8 @@ def register_source(app, env, modname):
env._viewcode_modules[modname] = entry
return True
+
+
# pylint: enable=protected-access
@@ -232,6 +236,8 @@ def doctree_read(app, doctree):
onlynode = create_node(env, relative_path, show_button)
objnode.replace_self(onlynode)
+
+
# pylint: enable=protected-access
diff --git a/kubernetes_tests/test_kubernetes_executor.py b/kubernetes_tests/test_kubernetes_executor.py
index de2726aca7f74..ba4900f01c180 100644
--- a/kubernetes_tests/test_kubernetes_executor.py
+++ b/kubernetes_tests/test_kubernetes_executor.py
@@ -27,7 +27,7 @@
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
-CLUSTER_FORWARDED_PORT = (os.environ.get('CLUSTER_FORWARDED_PORT') or "8080")
+CLUSTER_FORWARDED_PORT = os.environ.get('CLUSTER_FORWARDED_PORT') or "8080"
KUBERNETES_HOST_PORT = (os.environ.get('CLUSTER_HOST') or "localhost") + ":" + CLUSTER_FORWARDED_PORT
print()
@@ -36,7 +36,6 @@
class TestKubernetesExecutor(unittest.TestCase):
-
@staticmethod
def _describe_resources(namespace: str):
print("=" * 80)
@@ -93,8 +92,7 @@ def setUp(self):
def tearDown(self):
self.session.close()
- def monitor_task(self, host, execution_date, dag_id, task_id, expected_final_state,
- timeout):
+ def monitor_task(self, host, execution_date, dag_id, task_id, expected_final_state, timeout):
tries = 0
state = ''
max_tries = max(int(timeout / 5), 1)
@@ -103,9 +101,10 @@ def monitor_task(self, host, execution_date, dag_id, task_id, expected_final_sta
time.sleep(5)
# Trigger a new dagrun
try:
- get_string = \
- f'http://{host}/api/experimental/dags/{dag_id}/' \
+ get_string = (
+ f'http://{host}/api/experimental/dags/{dag_id}/'
f'dag_runs/{execution_date}/tasks/{task_id}'
+ )
print(f"Calling [monitor_task]#1 {get_string}")
result = self.session.get(get_string)
if result.status_code == 404:
@@ -129,18 +128,14 @@ def monitor_task(self, host, execution_date, dag_id, task_id, expected_final_sta
print(f"The expected state is wrong {state} != {expected_final_state} (expected)!")
self.assertEqual(state, expected_final_state)
- def ensure_dag_expected_state(self, host, execution_date, dag_id,
- expected_final_state,
- timeout):
+ def ensure_dag_expected_state(self, host, execution_date, dag_id, expected_final_state, timeout):
tries = 0
state = ''
max_tries = max(int(timeout / 5), 1)
# Wait some time for the operator to complete
while tries < max_tries:
time.sleep(5)
- get_string = \
- f'http://{host}/api/experimental/dags/{dag_id}/' \
- f'dag_runs/{execution_date}'
+ get_string = f'http://{host}/api/experimental/dags/{dag_id}/' f'dag_runs/{execution_date}'
print(f"Calling {get_string}")
# Trigger a new dagrun
result = self.session.get(get_string)
@@ -148,8 +143,7 @@ def ensure_dag_expected_state(self, host, execution_date, dag_id,
result_json = result.json()
print(f"Received: {result}")
state = result_json['state']
- check_call(
- ["echo", f"Attempt {tries}: Current state of dag is {state}"])
+ check_call(["echo", f"Attempt {tries}: Current state of dag is {state}"])
print(f"Attempt {tries}: Current state of dag is {state}")
if state == expected_final_state:
@@ -162,8 +156,7 @@ def ensure_dag_expected_state(self, host, execution_date, dag_id,
# Maybe check if we can retrieve the logs, but then we need to extend the API
def start_dag(self, dag_id, host):
- get_string = f'http://{host}/api/experimental/' \
- f'dags/{dag_id}/paused/false'
+ get_string = f'http://{host}/api/experimental/' f'dags/{dag_id}/paused/false'
print(f"Calling [start_dag]#1 {get_string}")
result = self.session.get(get_string)
try:
@@ -171,10 +164,8 @@ def start_dag(self, dag_id, host):
except ValueError:
result_json = str(result)
print(f"Received [start_dag]#1 {result_json}")
- self.assertEqual(result.status_code, 200, "Could not enable DAG: {result}"
- .format(result=result_json))
- post_string = f'http://{host}/api/experimental/' \
- f'dags/{dag_id}/dag_runs'
+ self.assertEqual(result.status_code, 200, f"Could not enable DAG: {result_json}")
+ post_string = f'http://{host}/api/experimental/' f'dags/{dag_id}/dag_runs'
print(f"Calling [start_dag]#2 {post_string}")
# Trigger a new dagrun
result = self.session.post(post_string, json={})
@@ -183,17 +174,18 @@ def start_dag(self, dag_id, host):
except ValueError:
result_json = str(result)
print(f"Received [start_dag]#2 {result_json}")
- self.assertEqual(result.status_code, 200, "Could not trigger a DAG-run: {result}"
- .format(result=result_json))
+ self.assertEqual(result.status_code, 200, f"Could not trigger a DAG-run: {result_json}")
time.sleep(1)
get_string = f'http://{host}/api/experimental/latest_runs'
print(f"Calling [start_dag]#3 {get_string}")
result = self.session.get(get_string)
- self.assertEqual(result.status_code, 200, "Could not get the latest DAG-run:"
- " {result}"
- .format(result=result.json()))
+ self.assertEqual(
+ result.status_code,
+ 200,
+ "Could not get the latest DAG-run:" " {result}".format(result=result.json()),
+ )
result_json = result.json()
print(f"Received: [start_dag]#3 {result_json}")
return result_json
@@ -217,16 +209,22 @@ def test_integration_run_dag(self):
print(f"Found the job with execution date {execution_date}")
# Wait some time for the operator to complete
- self.monitor_task(host=host,
- execution_date=execution_date,
- dag_id=dag_id,
- task_id='start_task',
- expected_final_state='success', timeout=300)
+ self.monitor_task(
+ host=host,
+ execution_date=execution_date,
+ dag_id=dag_id,
+ task_id='start_task',
+ expected_final_state='success',
+ timeout=300,
+ )
- self.ensure_dag_expected_state(host=host,
- execution_date=execution_date,
- dag_id=dag_id,
- expected_final_state='success', timeout=300)
+ self.ensure_dag_expected_state(
+ host=host,
+ execution_date=execution_date,
+ dag_id=dag_id,
+ expected_final_state='success',
+ timeout=300,
+ )
def test_integration_run_dag_with_scheduler_failure(self):
host = KUBERNETES_HOST_PORT
@@ -239,23 +237,32 @@ def test_integration_run_dag_with_scheduler_failure(self):
time.sleep(10) # give time for pod to restart
# Wait some time for the operator to complete
- self.monitor_task(host=host,
- execution_date=execution_date,
- dag_id=dag_id,
- task_id='start_task',
- expected_final_state='success', timeout=300)
-
- self.monitor_task(host=host,
- execution_date=execution_date,
- dag_id=dag_id,
- task_id='other_namespace_task',
- expected_final_state='success', timeout=300)
-
- self.ensure_dag_expected_state(host=host,
- execution_date=execution_date,
- dag_id=dag_id,
- expected_final_state='success', timeout=300)
-
- self.assertEqual(self._num_pods_in_namespace('test-namespace'),
- 0,
- "failed to delete pods in other namespace")
+ self.monitor_task(
+ host=host,
+ execution_date=execution_date,
+ dag_id=dag_id,
+ task_id='start_task',
+ expected_final_state='success',
+ timeout=300,
+ )
+
+ self.monitor_task(
+ host=host,
+ execution_date=execution_date,
+ dag_id=dag_id,
+ task_id='other_namespace_task',
+ expected_final_state='success',
+ timeout=300,
+ )
+
+ self.ensure_dag_expected_state(
+ host=host,
+ execution_date=execution_date,
+ dag_id=dag_id,
+ expected_final_state='success',
+ timeout=300,
+ )
+
+ self.assertEqual(
+ self._num_pods_in_namespace('test-namespace'), 0, "failed to delete pods in other namespace"
+ )
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index c78a821b3b30c..33eb553c06f26 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -46,8 +46,7 @@ def create_context(task):
dag = DAG(dag_id="dag")
tzinfo = pendulum.timezone("Europe/Amsterdam")
execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
- task_instance = TaskInstance(task=task,
- execution_date=execution_date)
+ task_instance = TaskInstance(task=task, execution_date=execution_date)
return {
"dag": dag,
"ts": execution_date.isoformat(),
@@ -57,7 +56,6 @@ def create_context(task):
class TestKubernetesPodOperatorSystem(unittest.TestCase):
-
def get_current_task_name(self):
# reverse test name to make pod name unique (it has limited length)
return "_" + unittest.TestCase.id(self).replace(".", "_")[::-1]
@@ -73,26 +71,30 @@ def setUp(self):
'name': ANY,
'annotations': {},
'labels': {
- 'foo': 'bar', 'kubernetes_pod_operator': 'True',
+ 'foo': 'bar',
+ 'kubernetes_pod_operator': 'True',
'airflow_version': airflow_version.replace('+', '-'),
'execution_date': '2016-01-01T0100000100-a2f50a31f',
'dag_id': 'dag',
'task_id': ANY,
- 'try_number': '1'},
+ 'try_number': '1',
+ },
},
'spec': {
'affinity': {},
- 'containers': [{
- 'image': 'ubuntu:16.04',
- 'args': ["echo 10"],
- 'command': ["bash", "-cx"],
- 'env': [],
- 'envFrom': [],
- 'resources': {},
- 'name': 'base',
- 'ports': [],
- 'volumeMounts': [],
- }],
+ 'containers': [
+ {
+ 'image': 'ubuntu:16.04',
+ 'args': ["echo 10"],
+ 'command': ["bash", "-cx"],
+ 'env': [],
+ 'envFrom': [],
+ 'resources': {},
+ 'name': 'base',
+ 'ports': [],
+ 'volumeMounts': [],
+ }
+ ],
'hostNetwork': False,
'imagePullSecrets': [],
'initContainers': [],
@@ -102,13 +104,14 @@ def setUp(self):
'serviceAccountName': 'default',
'tolerations': [],
'volumes': [],
- }
+ },
}
def tearDown(self) -> None:
client = kube_client.get_kube_client(in_cluster=False)
client.delete_collection_namespaced_pod(namespace="default")
import time
+
time.sleep(1)
def test_do_xcom_push_defaults_false(self):
@@ -222,7 +225,7 @@ def test_pod_dnspolicy(self):
in_cluster=False,
do_xcom_push=False,
hostnetwork=True,
- dnspolicy=dns_policy
+ dnspolicy=dns_policy,
)
context = create_context(k)
k.execute(context)
@@ -244,7 +247,7 @@ def test_pod_schedulername(self):
task_id="task" + self.get_current_task_name(),
in_cluster=False,
do_xcom_push=False,
- schedulername=scheduler_name
+ schedulername=scheduler_name,
)
context = create_context(k)
k.execute(context)
@@ -253,9 +256,7 @@ def test_pod_schedulername(self):
self.assertEqual(self.expected_pod, actual_pod)
def test_pod_node_selectors(self):
- node_selectors = {
- 'beta.kubernetes.io/os': 'linux'
- }
+ node_selectors = {'beta.kubernetes.io/os': 'linux'}
k = KubernetesPodOperator(
namespace='default',
image="ubuntu:16.04",
@@ -276,17 +277,8 @@ def test_pod_node_selectors(self):
def test_pod_resources(self):
resources = k8s.V1ResourceRequirements(
- requests={
- 'memory': '64Mi',
- 'cpu': '250m',
- 'ephemeral-storage': '1Gi'
- },
- limits={
- 'memory': '64Mi',
- 'cpu': 0.25,
- 'nvidia.com/gpu': None,
- 'ephemeral-storage': '2Gi'
- }
+ requests={'memory': '64Mi', 'cpu': '250m', 'ephemeral-storage': '1Gi'},
+ limits={'memory': '64Mi', 'cpu': 0.25, 'nvidia.com/gpu': None, 'ephemeral-storage': '2Gi'},
)
k = KubernetesPodOperator(
namespace='default',
@@ -304,17 +296,8 @@ def test_pod_resources(self):
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['resources'] = {
- 'requests': {
- 'memory': '64Mi',
- 'cpu': '250m',
- 'ephemeral-storage': '1Gi'
- },
- 'limits': {
- 'memory': '64Mi',
- 'cpu': 0.25,
- 'nvidia.com/gpu': None,
- 'ephemeral-storage': '2Gi'
- }
+ 'requests': {'memory': '64Mi', 'cpu': '250m', 'ephemeral-storage': '1Gi'},
+ 'limits': {'memory': '64Mi', 'cpu': 0.25, 'nvidia.com/gpu': None, 'ephemeral-storage': '2Gi'},
}
self.assertEqual(self.expected_pod, actual_pod)
@@ -325,11 +308,7 @@ def test_pod_affinity(self):
'nodeSelectorTerms': [
{
'matchExpressions': [
- {
- 'key': 'beta.kubernetes.io/os',
- 'operator': 'In',
- 'values': ['linux']
- }
+ {'key': 'beta.kubernetes.io/os', 'operator': 'In', 'values': ['linux']}
]
}
]
@@ -375,30 +354,24 @@ def test_port(self):
context = create_context(k)
k.execute(context=context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
- self.expected_pod['spec']['containers'][0]['ports'] = [{
- 'name': 'http',
- 'containerPort': 80
- }]
+ self.expected_pod['spec']['containers'][0]['ports'] = [{'name': 'http', 'containerPort': 80}]
self.assertEqual(self.expected_pod, actual_pod)
def test_volume_mount(self):
with mock.patch.object(PodLauncher, 'log') as mock_logger:
volume_mount = k8s.V1VolumeMount(
- name='test-volume',
- mount_path='/tmp/test_volume',
- sub_path=None,
- read_only=False
+ name='test-volume', mount_path='/tmp/test_volume', sub_path=None, read_only=False
)
volume = k8s.V1Volume(
name='test-volume',
- persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
- claim_name='test-volume'
- )
+ persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(claim_name='test-volume'),
)
- args = ["echo \"retrieved from mount\" > /tmp/test_volume/test.txt "
- "&& cat /tmp/test_volume/test.txt"]
+ args = [
+ "echo \"retrieved from mount\" > /tmp/test_volume/test.txt "
+ "&& cat /tmp/test_volume/test.txt"
+ ]
k = KubernetesPodOperator(
namespace='default',
image="ubuntu:16.04",
@@ -417,17 +390,12 @@ def test_volume_mount(self):
mock_logger.info.assert_any_call('retrieved from mount')
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['args'] = args
- self.expected_pod['spec']['containers'][0]['volumeMounts'] = [{
- 'name': 'test-volume',
- 'mountPath': '/tmp/test_volume',
- 'readOnly': False
- }]
- self.expected_pod['spec']['volumes'] = [{
- 'name': 'test-volume',
- 'persistentVolumeClaim': {
- 'claimName': 'test-volume'
- }
- }]
+ self.expected_pod['spec']['containers'][0]['volumeMounts'] = [
+ {'name': 'test-volume', 'mountPath': '/tmp/test_volume', 'readOnly': False}
+ ]
+ self.expected_pod['spec']['volumes'] = [
+ {'name': 'test-volume', 'persistentVolumeClaim': {'claimName': 'test-volume'}}
+ ]
self.assertEqual(self.expected_pod, actual_pod)
def test_run_as_user_root(self):
@@ -604,9 +572,7 @@ def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start):
from airflow.utils.state import State
configmap_name = "test-config-map"
- env_from = [k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(
- name=configmap_name
- ))]
+ env_from = [k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name=configmap_name))]
# WHEN
k = KubernetesPodOperator(
namespace='default',
@@ -618,15 +584,13 @@ def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start):
task_id="task" + self.get_current_task_name(),
in_cluster=False,
do_xcom_push=False,
- env_from=env_from
+ env_from=env_from,
)
# THEN
mock_monitor.return_value = (State.SUCCESS, None)
context = create_context(k)
k.execute(context)
- self.assertEqual(
- mock_start.call_args[0][0].spec.containers[0].env_from, env_from
- )
+ self.assertEqual(mock_start.call_args[0][0].spec.containers[0].env_from, env_from)
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@@ -634,6 +598,7 @@ def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start):
def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock):
# GIVEN
from airflow.utils.state import State
+
secret_ref = 'secret_name'
secrets = [Secret('env', None, secret_ref)]
# WHEN
@@ -655,29 +620,17 @@ def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock):
k.execute(context)
self.assertEqual(
start_mock.call_args[0][0].spec.containers[0].env_from,
- [k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(
- name=secret_ref
- ))]
+ [k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name=secret_ref))],
)
def test_env_vars(self):
# WHEN
env_vars = [
- k8s.V1EnvVar(
- name="ENV1",
- value="val1"
- ),
- k8s.V1EnvVar(
- name="ENV2",
- value="val2"
- ),
+ k8s.V1EnvVar(name="ENV1", value="val1"),
+ k8s.V1EnvVar(name="ENV2", value="val2"),
k8s.V1EnvVar(
name="ENV3",
- value_from=k8s.V1EnvVarSource(
- field_ref=k8s.V1ObjectFieldSelector(
- field_path="status.podIP"
- )
- )
+ value_from=k8s.V1EnvVarSource(field_ref=k8s.V1ObjectFieldSelector(field_path="status.podIP")),
),
]
@@ -702,14 +655,7 @@ def test_env_vars(self):
self.expected_pod['spec']['containers'][0]['env'] = [
{'name': 'ENV1', 'value': 'val1'},
{'name': 'ENV2', 'value': 'val2'},
- {
- 'name': 'ENV3',
- 'valueFrom': {
- 'fieldRef': {
- 'fieldPath': 'status.podIP'
- }
- }
- }
+ {'name': 'ENV3', 'valueFrom': {'fieldRef': {'fieldPath': 'status.podIP'}}},
]
self.assertEqual(self.expected_pod, actual_pod)
@@ -719,7 +665,7 @@ def test_pod_template_file_system(self):
task_id="task" + self.get_current_task_name(),
in_cluster=False,
pod_template_file=fixture,
- do_xcom_push=True
+ do_xcom_push=True,
)
context = create_context(k)
@@ -735,7 +681,7 @@ def test_pod_template_file_with_overrides_system(self):
env_vars=[k8s.V1EnvVar(name="env_name", value="value")],
in_cluster=False,
pod_template_file=fixture,
- do_xcom_push=True
+ do_xcom_push=True,
)
context = create_context(k)
@@ -747,20 +693,14 @@ def test_pod_template_file_with_overrides_system(self):
def test_init_container(self):
# GIVEN
- volume_mounts = [k8s.V1VolumeMount(
- mount_path='/etc/foo',
- name='test-volume',
- sub_path=None,
- read_only=True
- )]
-
- init_environments = [k8s.V1EnvVar(
- name='key1',
- value='value1'
- ), k8s.V1EnvVar(
- name='key2',
- value='value2'
- )]
+ volume_mounts = [
+ k8s.V1VolumeMount(mount_path='/etc/foo', name='test-volume', sub_path=None, read_only=True)
+ ]
+
+ init_environments = [
+ k8s.V1EnvVar(name='key1', value='value1'),
+ k8s.V1EnvVar(name='key2', value='value2'),
+ ]
init_container = k8s.V1Container(
name="init-container",
@@ -768,32 +708,20 @@ def test_init_container(self):
env=init_environments,
volume_mounts=volume_mounts,
command=["bash", "-cx"],
- args=["echo 10"]
+ args=["echo 10"],
)
volume = k8s.V1Volume(
name='test-volume',
- persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
- claim_name='test-volume'
- )
+ persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(claim_name='test-volume'),
)
expected_init_container = {
'name': 'init-container',
'image': 'ubuntu:16.04',
'command': ['bash', '-cx'],
'args': ['echo 10'],
- 'env': [{
- 'name': 'key1',
- 'value': 'value1'
- }, {
- 'name': 'key2',
- 'value': 'value2'
- }],
- 'volumeMounts': [{
- 'mountPath': '/etc/foo',
- 'name': 'test-volume',
- 'readOnly': True
- }],
+ 'env': [{'name': 'key1', 'value': 'value1'}, {'name': 'key2', 'value': 'value2'}],
+ 'volumeMounts': [{'mountPath': '/etc/foo', 'name': 'test-volume', 'readOnly': True}],
}
k = KubernetesPodOperator(
@@ -813,36 +741,30 @@ def test_init_container(self):
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['initContainers'] = [expected_init_container]
- self.expected_pod['spec']['volumes'] = [{
- 'name': 'test-volume',
- 'persistentVolumeClaim': {
- 'claimName': 'test-volume'
- }
- }]
+ self.expected_pod['spec']['volumes'] = [
+ {'name': 'test-volume', 'persistentVolumeClaim': {'claimName': 'test-volume'}}
+ ]
self.assertEqual(self.expected_pod, actual_pod)
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@mock.patch("airflow.kubernetes.kube_client.get_kube_client")
def test_pod_template_file(
- self,
- mock_client,
- monitor_mock,
- start_mock # pylint: disable=unused-argument
+ self, mock_client, monitor_mock, start_mock # pylint: disable=unused-argument
):
from airflow.utils.state import State
+
path = sys.path[0] + '/tests/kubernetes/pod.yaml'
k = KubernetesPodOperator(
- task_id="task" + self.get_current_task_name(),
- pod_template_file=path,
- do_xcom_push=True
+ task_id="task" + self.get_current_task_name(), pod_template_file=path, do_xcom_push=True
)
monitor_mock.return_value = (State.SUCCESS, None)
context = create_context(k)
with self.assertLogs(k.log, level=logging.DEBUG) as cm:
k.execute(context)
- expected_line = textwrap.dedent("""\
+ expected_line = textwrap.dedent(
+ """\
DEBUG:airflow.task.operators:Starting pod:
api_version: v1
kind: Pod
@@ -851,65 +773,57 @@ def test_pod_template_file(
cluster_name: null
creation_timestamp: null
deletion_grace_period_seconds: null\
- """).strip()
+ """
+ ).strip()
self.assertTrue(any(line.startswith(expected_line) for line in cm.output))
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
- expected_dict = {'apiVersion': 'v1',
- 'kind': 'Pod',
- 'metadata': {'annotations': {},
- 'labels': {},
- 'name': 'memory-demo',
- 'namespace': 'mem-example'},
- 'spec': {'affinity': {},
- 'containers': [{'args': ['--vm',
- '1',
- '--vm-bytes',
- '150M',
- '--vm-hang',
- '1'],
- 'command': ['stress'],
- 'env': [],
- 'envFrom': [],
- 'image': 'apache/airflow:stress-2020.07.10-1.0.4',
- 'name': 'base',
- 'ports': [],
- 'resources': {'limits': {'memory': '200Mi'},
- 'requests': {'memory': '100Mi'}},
- 'volumeMounts': [{'mountPath': '/airflow/xcom',
- 'name': 'xcom'}]},
- {'command': ['sh',
- '-c',
- 'trap "exit 0" INT; while true; do sleep '
- '30; done;'],
- 'image': 'alpine',
- 'name': 'airflow-xcom-sidecar',
- 'resources': {'requests': {'cpu': '1m'}},
- 'volumeMounts': [{'mountPath': '/airflow/xcom',
- 'name': 'xcom'}]}],
- 'hostNetwork': False,
- 'imagePullSecrets': [],
- 'initContainers': [],
- 'nodeSelector': {},
- 'restartPolicy': 'Never',
- 'securityContext': {},
- 'serviceAccountName': 'default',
- 'tolerations': [],
- 'volumes': [{'emptyDir': {}, 'name': 'xcom'}]}}
+ expected_dict = {
+ 'apiVersion': 'v1',
+ 'kind': 'Pod',
+ 'metadata': {'annotations': {}, 'labels': {}, 'name': 'memory-demo', 'namespace': 'mem-example'},
+ 'spec': {
+ 'affinity': {},
+ 'containers': [
+ {
+ 'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'],
+ 'command': ['stress'],
+ 'env': [],
+ 'envFrom': [],
+ 'image': 'apache/airflow:stress-2020.07.10-1.0.4',
+ 'name': 'base',
+ 'ports': [],
+ 'resources': {'limits': {'memory': '200Mi'}, 'requests': {'memory': '100Mi'}},
+ 'volumeMounts': [{'mountPath': '/airflow/xcom', 'name': 'xcom'}],
+ },
+ {
+ 'command': ['sh', '-c', 'trap "exit 0" INT; while true; do sleep 30; done;'],
+ 'image': 'alpine',
+ 'name': 'airflow-xcom-sidecar',
+ 'resources': {'requests': {'cpu': '1m'}},
+ 'volumeMounts': [{'mountPath': '/airflow/xcom', 'name': 'xcom'}],
+ },
+ ],
+ 'hostNetwork': False,
+ 'imagePullSecrets': [],
+ 'initContainers': [],
+ 'nodeSelector': {},
+ 'restartPolicy': 'Never',
+ 'securityContext': {},
+ 'serviceAccountName': 'default',
+ 'tolerations': [],
+ 'volumes': [{'emptyDir': {}, 'name': 'xcom'}],
+ },
+ }
self.assertEqual(expected_dict, actual_pod)
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@mock.patch("airflow.kubernetes.kube_client.get_kube_client")
def test_pod_priority_class_name(
- self,
- mock_client,
- monitor_mock,
- start_mock # pylint: disable=unused-argument
+ self, mock_client, monitor_mock, start_mock # pylint: disable=unused-argument
):
- """Test ability to assign priorityClassName to pod
-
- """
+ """Test ability to assign priorityClassName to pod"""
from airflow.utils.state import State
priority_class_name = "medium-test"
@@ -949,9 +863,9 @@ def test_pod_name(self):
)
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
- def test_on_kill(self,
- monitor_mock): # pylint: disable=unused-argument
+ def test_on_kill(self, monitor_mock): # pylint: disable=unused-argument
from airflow.utils.state import State
+
client = kube_client.get_kube_client(in_cluster=False)
name = "test"
namespace = "default"
@@ -979,6 +893,7 @@ def test_on_kill(self,
def test_reattach_failing_pod_once(self):
from airflow.utils.state import State
+
client = kube_client.get_kube_client(in_cluster=False)
name = "test"
namespace = "default"
@@ -1009,11 +924,14 @@ def test_reattach_failing_pod_once(self):
k.execute(context)
pod = client.read_namespaced_pod(name=name, namespace=namespace)
self.assertEqual(pod.metadata.labels["already_checked"], "True")
- with mock.patch("airflow.providers.cncf.kubernetes"
- ".operators.kubernetes_pod.KubernetesPodOperator"
- ".create_new_pod_for_operator") as create_mock:
+ with mock.patch(
+ "airflow.providers.cncf.kubernetes"
+ ".operators.kubernetes_pod.KubernetesPodOperator"
+ ".create_new_pod_for_operator"
+ ) as create_mock:
create_mock.return_value = ("success", {}, {})
k.execute(context)
create_mock.assert_called_once()
+
# pylint: enable=unused-argument
diff --git a/metastore_browser/hive_metastore.py b/metastore_browser/hive_metastore.py
index 6335891dba9d0..58d4da3ca62a3 100644
--- a/metastore_browser/hive_metastore.py
+++ b/metastore_browser/hive_metastore.py
@@ -63,16 +63,14 @@ def index(self):
"""
hook = MySqlHook(METASTORE_MYSQL_CONN_ID)
df = hook.get_pandas_df(sql)
- df.db = (
- '' + df.db + '')
+ df.db = '' + df.db + ''
table = df.to_html(
classes="table table-striped table-bordered table-hover",
index=False,
escape=False,
- na_rep='',)
- return self.render_template(
- "metastore_browser/dbs.html", table=Markup(table))
+ na_rep='',
+ )
+ return self.render_template("metastore_browser/dbs.html", table=Markup(table))
@expose('/table/')
def table(self):
@@ -81,8 +79,8 @@ def table(self):
metastore = HiveMetastoreHook(METASTORE_CONN_ID)
table = metastore.get_table(table_name)
return self.render_template(
- "metastore_browser/table.html",
- table=table, table_name=table_name, datetime=datetime, int=int)
+ "metastore_browser/table.html", table=table, table_name=table_name, datetime=datetime, int=int
+ )
@expose('/db/')
def db(self):
@@ -90,8 +88,7 @@ def db(self):
db = request.args.get("db")
metastore = HiveMetastoreHook(METASTORE_CONN_ID)
tables = sorted(metastore.get_tables(db=db), key=lambda x: x.tableName)
- return self.render_template(
- "metastore_browser/db.html", tables=tables, db=db)
+ return self.render_template("metastore_browser/db.html", tables=tables, db=db)
@gzipped
@expose('/partitions/')
@@ -114,13 +111,16 @@ def partitions(self):
b.TBL_NAME like '{table}' AND
d.NAME like '{schema}'
ORDER BY PART_NAME DESC
- """.format(table=table, schema=schema)
+ """.format(
+ table=table, schema=schema
+ )
hook = MySqlHook(METASTORE_MYSQL_CONN_ID)
df = hook.get_pandas_df(sql)
return df.to_html(
classes="table table-striped table-bordered table-hover",
index=False,
- na_rep='',)
+ na_rep='',
+ )
@gzipped
@expose('/objects/')
@@ -144,11 +144,11 @@ def objects(self):
b.NAME NOT LIKE '%temp%'
{where_clause}
LIMIT {LIMIT};
- """.format(where_clause=where_clause, LIMIT=TABLE_SELECTOR_LIMIT)
+ """.format(
+ where_clause=where_clause, LIMIT=TABLE_SELECTOR_LIMIT
+ )
hook = MySqlHook(METASTORE_MYSQL_CONN_ID)
- data = [
- {'id': row[0], 'text': row[0]}
- for row in hook.get_records(sql)]
+ data = [{'id': row[0], 'text': row[0]} for row in hook.get_records(sql)]
return json.dumps(data)
@gzipped
@@ -162,7 +162,8 @@ def data(self):
return df.to_html(
classes="table table-striped table-bordered table-hover",
index=False,
- na_rep='',)
+ na_rep='',
+ )
@expose('/ddl/')
def ddl(self):
@@ -175,10 +176,12 @@ def ddl(self):
# Creating a flask blueprint to integrate the templates and static folder
bp = Blueprint(
- "metastore_browser", __name__,
+ "metastore_browser",
+ __name__,
template_folder='templates',
static_folder='static',
- static_url_path='/static/metastore_browser')
+ static_url_path='/static/metastore_browser',
+)
class MetastoreBrowserPlugin(AirflowPlugin):
@@ -186,6 +189,6 @@ class MetastoreBrowserPlugin(AirflowPlugin):
name = "metastore_browser"
flask_blueprints = [bp]
- appbuilder_views = [{"name": "Hive Metadata Browser",
- "category": "Plugins",
- "view": MetastoreBrowserView()}]
+ appbuilder_views = [
+ {"name": "Hive Metadata Browser", "category": "Plugins", "view": MetastoreBrowserView()}
+ ]
diff --git a/provider_packages/import_all_provider_classes.py b/provider_packages/import_all_provider_classes.py
index 87e7f1ea41fb1..ff0ffbc1199eb 100755
--- a/provider_packages/import_all_provider_classes.py
+++ b/provider_packages/import_all_provider_classes.py
@@ -25,9 +25,9 @@
from typing import List
-def import_all_provider_classes(source_paths: str,
- provider_ids: List[str] = None,
- print_imports: bool = False) -> List[str]:
+def import_all_provider_classes(
+ source_paths: str, provider_ids: List[str] = None, print_imports: bool = False
+) -> List[str]:
"""
Imports all classes in providers packages. This method loads and imports
all the classes found in providers, so that we can find all the subclasses
@@ -61,8 +61,9 @@ def mk_prefix(provider_id):
provider_ids = ['']
for provider_id in provider_ids:
- for modinfo in pkgutil.walk_packages(mk_path(provider_id), prefix=mk_prefix(provider_id),
- onerror=onerror):
+ for modinfo in pkgutil.walk_packages(
+ mk_path(provider_id), prefix=mk_prefix(provider_id), onerror=onerror
+ ):
if print_imports:
print(f"Importing module: {modinfo.name}")
try:
@@ -78,9 +79,12 @@ def mk_prefix(provider_id):
exception_str = traceback.format_exc()
tracebacks.append(exception_str)
if tracebacks:
- print("""
+ print(
+ """
ERROR: There were some import errors
-""", file=sys.stderr)
+""",
+ file=sys.stderr,
+ )
for trace in tracebacks:
print("----------------------------------------", file=sys.stderr)
print(trace, file=sys.stderr)
@@ -93,6 +97,7 @@ def mk_prefix(provider_id):
if __name__ == '__main__':
try:
import airflow.providers
+
install_source_path = list(iter(airflow.providers.__path__))
except ImportError as e:
print("----------------------------------------", file=sys.stderr)
diff --git a/provider_packages/prepare_provider_packages.py b/provider_packages/prepare_provider_packages.py
index ef12b2aedfa7e..5eed2407ea1ff 100644
--- a/provider_packages/prepare_provider_packages.py
+++ b/provider_packages/prepare_provider_packages.py
@@ -36,7 +36,6 @@
from packaging.version import Version
-
PROVIDER_TEMPLATE_PREFIX = "PROVIDER_"
BACKPORT_PROVIDER_TEMPLATE_PREFIX = "BACKPORT_PROVIDER_"
@@ -183,8 +182,9 @@ def get_pip_package_name(provider_package_id: str, backport_packages: bool) -> s
:param backport_packages: whether to prepare regular (False) or backport (True) packages
:return: the name of pip package
"""
- return ("apache-airflow-backport-providers-" if backport_packages else "apache-airflow-providers-") \
- + provider_package_id.replace(".", "-")
+ return (
+ "apache-airflow-backport-providers-" if backport_packages else "apache-airflow-providers-"
+ ) + provider_package_id.replace(".", "-")
def get_long_description(provider_package_id: str, backport_packages: bool) -> str:
@@ -196,8 +196,9 @@ def get_long_description(provider_package_id: str, backport_packages: bool) -> s
:return: content of the description (BACKPORT_PROVIDER_README/README file)
"""
package_folder = get_target_providers_package_folder(provider_package_id)
- readme_file = os.path.join(package_folder,
- "BACKPORT_PROVIDER_README.md" if backport_packages else "README.md")
+ readme_file = os.path.join(
+ package_folder, "BACKPORT_PROVIDER_README.md" if backport_packages else "README.md"
+ )
if not os.path.exists(readme_file):
return ""
with open(readme_file, encoding='utf-8', mode="r") as file:
@@ -216,9 +217,9 @@ def get_long_description(provider_package_id: str, backport_packages: bool) -> s
return long_description
-def get_package_release_version(provider_package_id: str,
- backport_packages: bool,
- version_suffix: str = "") -> str:
+def get_package_release_version(
+ provider_package_id: str, backport_packages: bool, version_suffix: str = ""
+) -> str:
"""
Returns release version including optional suffix.
@@ -227,14 +228,15 @@ def get_package_release_version(provider_package_id: str,
:param version_suffix: optional suffix (rc1, rc2 etc).
:return:
"""
- return get_latest_release(
- get_package_path(provider_package_id=provider_package_id),
- backport_packages=backport_packages).release_version + version_suffix
+ return (
+ get_latest_release(
+ get_package_path(provider_package_id=provider_package_id), backport_packages=backport_packages
+ ).release_version
+ + version_suffix
+ )
-def get_install_requirements(
- provider_package_id: str,
- backport_packages: bool) -> List[str]:
+def get_install_requirements(provider_package_id: str, backport_packages: bool) -> List[str]:
"""
Returns install requirements for the package.
@@ -246,8 +248,11 @@ def get_install_requirements(
dependencies = PROVIDERS_REQUIREMENTS[provider_package_id]
if backport_packages:
- airflow_dependency = 'apache-airflow~=1.10' if provider_package_id != 'cncf.kubernetes' \
+ airflow_dependency = (
+ 'apache-airflow~=1.10'
+ if provider_package_id != 'cncf.kubernetes'
else 'apache-airflow>=1.10.12, <2.0.0'
+ )
else:
airflow_dependency = 'apache-airflow>=2.0.0a0'
install_requires = [airflow_dependency]
@@ -260,10 +265,7 @@ def get_setup_requirements() -> List[str]:
Returns setup requirements (common for all package for now).
:return: setup requirements
"""
- return [
- 'setuptools',
- 'wheel'
- ]
+ return ['setuptools', 'wheel']
def get_package_extras(provider_package_id: str, backport_packages: bool) -> Dict[str, List[str]]:
@@ -278,9 +280,14 @@ def get_package_extras(provider_package_id: str, backport_packages: bool) -> Dic
return {}
with open(DEPENDENCIES_JSON_FILE) as dependencies_file:
cross_provider_dependencies: Dict[str, List[str]] = json.load(dependencies_file)
- extras_dict = {module: [get_pip_package_name(module, backport_packages=backport_packages)]
- for module in cross_provider_dependencies[provider_package_id]} \
- if cross_provider_dependencies.get(provider_package_id) else {}
+ extras_dict = (
+ {
+ module: [get_pip_package_name(module, backport_packages=backport_packages)]
+ for module in cross_provider_dependencies[provider_package_id]
+ }
+ if cross_provider_dependencies.get(provider_package_id)
+ else {}
+ )
return extras_dict
@@ -360,6 +367,7 @@ def inherits_from(the_class: Type, expected_ancestor: Type) -> bool:
if expected_ancestor is None:
return False
import inspect
+
mro = inspect.getmro(the_class)
return the_class is not expected_ancestor and expected_ancestor in mro
@@ -371,6 +379,7 @@ def is_class(the_class: Type) -> bool:
:return: true if it is a class
"""
import inspect
+
return inspect.isclass(the_class)
@@ -386,14 +395,15 @@ def package_name_matches(the_class: Type, expected_pattern: Optional[str]) -> bo
def find_all_entities(
- imported_classes: List[str],
- base_package: str,
- ancestor_match: Type,
- sub_package_pattern_match: str,
- expected_class_name_pattern: str,
- unexpected_class_name_patterns: Set[str],
- exclude_class_type: Type = None,
- false_positive_class_names: Optional[Set[str]] = None) -> VerifiedEntities:
+ imported_classes: List[str],
+ base_package: str,
+ ancestor_match: Type,
+ sub_package_pattern_match: str,
+ expected_class_name_pattern: str,
+ unexpected_class_name_patterns: Set[str],
+ exclude_class_type: Type = None,
+ false_positive_class_names: Optional[Set[str]] = None,
+) -> VerifiedEntities:
"""
Returns set of entities containing all subclasses in package specified.
@@ -412,35 +422,44 @@ def find_all_entities(
for imported_name in imported_classes:
module, class_name = imported_name.rsplit(".", maxsplit=1)
the_class = getattr(importlib.import_module(module), class_name)
- if is_class(the_class=the_class) \
- and not is_example_dag(imported_name=imported_name) \
- and is_from_the_expected_base_package(the_class=the_class, expected_package=base_package) \
- and is_imported_from_same_module(the_class=the_class, imported_name=imported_name) \
- and inherits_from(the_class=the_class, expected_ancestor=ancestor_match) \
- and not inherits_from(the_class=the_class, expected_ancestor=exclude_class_type) \
- and package_name_matches(the_class=the_class, expected_pattern=sub_package_pattern_match):
+ if (
+ is_class(the_class=the_class)
+ and not is_example_dag(imported_name=imported_name)
+ and is_from_the_expected_base_package(the_class=the_class, expected_package=base_package)
+ and is_imported_from_same_module(the_class=the_class, imported_name=imported_name)
+ and inherits_from(the_class=the_class, expected_ancestor=ancestor_match)
+ and not inherits_from(the_class=the_class, expected_ancestor=exclude_class_type)
+ and package_name_matches(the_class=the_class, expected_pattern=sub_package_pattern_match)
+ ):
if not false_positive_class_names or class_name not in false_positive_class_names:
if not re.match(expected_class_name_pattern, class_name):
wrong_entities.append(
- (the_class, f"The class name {class_name} is wrong. "
- f"It should match {expected_class_name_pattern}"))
+ (
+ the_class,
+ f"The class name {class_name} is wrong. "
+ f"It should match {expected_class_name_pattern}",
+ )
+ )
continue
if unexpected_class_name_patterns:
for unexpected_class_name_pattern in unexpected_class_name_patterns:
if re.match(unexpected_class_name_pattern, class_name):
wrong_entities.append(
- (the_class,
- f"The class name {class_name} is wrong. "
- f"It should not match {unexpected_class_name_pattern}"))
+ (
+ the_class,
+ f"The class name {class_name} is wrong. "
+ f"It should not match {unexpected_class_name_pattern}",
+ )
+ )
continue
found_entities.add(imported_name)
return VerifiedEntities(all_entities=found_entities, wrong_entities=wrong_entities)
-def convert_new_classes_to_table(entity_type: EntityType,
- new_entities: List[str],
- full_package_name: str) -> str:
+def convert_new_classes_to_table(
+ entity_type: EntityType, new_entities: List[str], full_package_name: str
+) -> str:
"""
Converts new entities to a markdown table.
@@ -450,15 +469,15 @@ def convert_new_classes_to_table(entity_type: EntityType,
:return: table of new classes
"""
from tabulate import tabulate
+
headers = [f"New Airflow 2.0 {entity_type.value.lower()}: `{full_package_name}` package"]
- table = [(get_class_code_link(full_package_name, class_name, "master"),)
- for class_name in new_entities]
+ table = [(get_class_code_link(full_package_name, class_name, "master"),) for class_name in new_entities]
return tabulate(table, headers=headers, tablefmt="pipe")
-def convert_moved_classes_to_table(entity_type: EntityType,
- moved_entities: Dict[str, str],
- full_package_name: str) -> str:
+def convert_moved_classes_to_table(
+ entity_type: EntityType, moved_entities: Dict[str, str], full_package_name: str
+) -> str:
"""
Converts moved entities to a markdown table
:param entity_type: type of entities -> operators, sensors etc.
@@ -467,21 +486,27 @@ def convert_moved_classes_to_table(entity_type: EntityType,
:return: table of moved classes
"""
from tabulate import tabulate
- headers = [f"Airflow 2.0 {entity_type.value.lower()}: `{full_package_name}` package",
- "Airflow 1.10.* previous location (usually `airflow.contrib`)"]
+
+ headers = [
+ f"Airflow 2.0 {entity_type.value.lower()}: `{full_package_name}` package",
+ "Airflow 1.10.* previous location (usually `airflow.contrib`)",
+ ]
table = [
- (get_class_code_link(full_package_name, to_class, "master"),
- get_class_code_link("airflow", moved_entities[to_class], "v1-10-stable"))
+ (
+ get_class_code_link(full_package_name, to_class, "master"),
+ get_class_code_link("airflow", moved_entities[to_class], "v1-10-stable"),
+ )
for to_class in sorted(moved_entities.keys())
]
return tabulate(table, headers=headers, tablefmt="pipe")
def get_details_about_classes(
- entity_type: EntityType,
- entities: Set[str],
- wrong_entities: List[Tuple[type, str]],
- full_package_name: str) -> EntityTypeSummary:
+ entity_type: EntityType,
+ entities: Set[str],
+ wrong_entities: List[Tuple[type, str]],
+ full_package_name: str,
+) -> EntityTypeSummary:
"""
Splits the set of entities into new and moved, depending on their presence in the dict of objects
retrieved from the test_contrib_to_core. Updates all_entities with the split class.
@@ -518,7 +543,7 @@ def get_details_about_classes(
moved_entities=moved_entities,
full_package_name=full_package_name,
),
- wrong_entities=wrong_entities
+ wrong_entities=wrong_entities,
)
@@ -527,7 +552,7 @@ def strip_package_from_class(base_package: str, class_name: str) -> str:
Strips base package name from the class (if it starts with the package name).
"""
if class_name.startswith(base_package):
- return class_name[len(base_package) + 1:]
+ return class_name[len(base_package) + 1 :]
else:
return class_name
@@ -553,8 +578,10 @@ def get_class_code_link(base_package: str, class_name: str, git_tag: str) -> str
:return: URL to the class
"""
url_prefix = f'https://github.com/apache/airflow/blob/{git_tag}/'
- return f'[{strip_package_from_class(base_package, class_name)}]' \
- f'({convert_class_name_to_url(url_prefix, class_name)})'
+ return (
+ f'[{strip_package_from_class(base_package, class_name)}]'
+ f'({convert_class_name_to_url(url_prefix, class_name)})'
+ )
def print_wrong_naming(entity_type: EntityType, wrong_classes: List[Tuple[type, str]]):
@@ -569,8 +596,9 @@ def print_wrong_naming(entity_type: EntityType, wrong_classes: List[Tuple[type,
print(f"{entity_type}: {message}", file=sys.stderr)
-def get_package_class_summary(full_package_name: str, imported_classes: List[str]) \
- -> Dict[EntityType, EntityTypeSummary]:
+def get_package_class_summary(
+ full_package_name: str, imported_classes: List[str]
+) -> Dict[EntityType, EntityTypeSummary]:
"""
Gets summary of the package in the form of dictionary containing all types of entities
:param full_package_name: full package name
@@ -582,72 +610,80 @@ def get_package_class_summary(full_package_name: str, imported_classes: List[str
from airflow.secrets import BaseSecretsBackend
from airflow.sensors.base_sensor_operator import BaseSensorOperator
- all_verified_entities: Dict[EntityType, VerifiedEntities] = {EntityType.Operators: find_all_entities(
- imported_classes=imported_classes,
- base_package=full_package_name,
- sub_package_pattern_match=r".*\.operators\..*",
- ancestor_match=BaseOperator,
- expected_class_name_pattern=OPERATORS_PATTERN,
- unexpected_class_name_patterns=ALL_PATTERNS - {OPERATORS_PATTERN},
- exclude_class_type=BaseSensorOperator,
- false_positive_class_names={
- 'CloudVisionAddProductToProductSetOperator',
- 'CloudDataTransferServiceGCSToGCSOperator',
- 'CloudDataTransferServiceS3ToGCSOperator',
- 'BigQueryCreateDataTransferOperator',
- 'CloudTextToSpeechSynthesizeOperator',
- 'CloudSpeechToTextRecognizeSpeechOperator',
- }
- ), EntityType.Sensors: find_all_entities(
- imported_classes=imported_classes,
- base_package=full_package_name,
- sub_package_pattern_match=r".*\.sensors\..*",
- ancestor_match=BaseSensorOperator,
- expected_class_name_pattern=SENSORS_PATTERN,
- unexpected_class_name_patterns=ALL_PATTERNS - {OPERATORS_PATTERN, SENSORS_PATTERN}
- ), EntityType.Hooks: find_all_entities(
- imported_classes=imported_classes,
- base_package=full_package_name,
- sub_package_pattern_match=r".*\.hooks\..*",
- ancestor_match=BaseHook,
- expected_class_name_pattern=HOOKS_PATTERN,
- unexpected_class_name_patterns=ALL_PATTERNS - {HOOKS_PATTERN}
- ), EntityType.Secrets: find_all_entities(
- imported_classes=imported_classes,
- sub_package_pattern_match=r".*\.secrets\..*",
- base_package=full_package_name,
- ancestor_match=BaseSecretsBackend,
- expected_class_name_pattern=SECRETS_PATTERN,
- unexpected_class_name_patterns=ALL_PATTERNS - {SECRETS_PATTERN},
- ), EntityType.Transfers: find_all_entities(
- imported_classes=imported_classes,
- base_package=full_package_name,
- sub_package_pattern_match=r".*\.transfers\..*",
- ancestor_match=BaseOperator,
- expected_class_name_pattern=TRANSFERS_PATTERN,
- unexpected_class_name_patterns=ALL_PATTERNS - {OPERATORS_PATTERN, TRANSFERS_PATTERN},
- )}
+ all_verified_entities: Dict[EntityType, VerifiedEntities] = {
+ EntityType.Operators: find_all_entities(
+ imported_classes=imported_classes,
+ base_package=full_package_name,
+ sub_package_pattern_match=r".*\.operators\..*",
+ ancestor_match=BaseOperator,
+ expected_class_name_pattern=OPERATORS_PATTERN,
+ unexpected_class_name_patterns=ALL_PATTERNS - {OPERATORS_PATTERN},
+ exclude_class_type=BaseSensorOperator,
+ false_positive_class_names={
+ 'CloudVisionAddProductToProductSetOperator',
+ 'CloudDataTransferServiceGCSToGCSOperator',
+ 'CloudDataTransferServiceS3ToGCSOperator',
+ 'BigQueryCreateDataTransferOperator',
+ 'CloudTextToSpeechSynthesizeOperator',
+ 'CloudSpeechToTextRecognizeSpeechOperator',
+ },
+ ),
+ EntityType.Sensors: find_all_entities(
+ imported_classes=imported_classes,
+ base_package=full_package_name,
+ sub_package_pattern_match=r".*\.sensors\..*",
+ ancestor_match=BaseSensorOperator,
+ expected_class_name_pattern=SENSORS_PATTERN,
+ unexpected_class_name_patterns=ALL_PATTERNS - {OPERATORS_PATTERN, SENSORS_PATTERN},
+ ),
+ EntityType.Hooks: find_all_entities(
+ imported_classes=imported_classes,
+ base_package=full_package_name,
+ sub_package_pattern_match=r".*\.hooks\..*",
+ ancestor_match=BaseHook,
+ expected_class_name_pattern=HOOKS_PATTERN,
+ unexpected_class_name_patterns=ALL_PATTERNS - {HOOKS_PATTERN},
+ ),
+ EntityType.Secrets: find_all_entities(
+ imported_classes=imported_classes,
+ sub_package_pattern_match=r".*\.secrets\..*",
+ base_package=full_package_name,
+ ancestor_match=BaseSecretsBackend,
+ expected_class_name_pattern=SECRETS_PATTERN,
+ unexpected_class_name_patterns=ALL_PATTERNS - {SECRETS_PATTERN},
+ ),
+ EntityType.Transfers: find_all_entities(
+ imported_classes=imported_classes,
+ base_package=full_package_name,
+ sub_package_pattern_match=r".*\.transfers\..*",
+ ancestor_match=BaseOperator,
+ expected_class_name_pattern=TRANSFERS_PATTERN,
+ unexpected_class_name_patterns=ALL_PATTERNS - {OPERATORS_PATTERN, TRANSFERS_PATTERN},
+ ),
+ }
for entity in EntityType:
print_wrong_naming(entity, all_verified_entities[entity].wrong_entities)
- entities_summary: Dict[EntityType, EntityTypeSummary] = {} # noqa
+ entities_summary: Dict[EntityType, EntityTypeSummary] = {} # noqa
for entity_type in EntityType:
entities_summary[entity_type] = get_details_about_classes(
entity_type,
all_verified_entities[entity_type].all_entities,
all_verified_entities[entity_type].wrong_entities,
- full_package_name)
+ full_package_name,
+ )
return entities_summary
def render_template(
- template_name: str,
- context: Dict[str, Any],
- extension: str,
- autoescape: bool = True,
- keep_trailing_newline: bool = False) -> str:
+ template_name: str,
+ context: Dict[str, Any],
+ extension: str,
+ autoescape: bool = True,
+ keep_trailing_newline: bool = False,
+) -> str:
"""
Renders template based on its name. Reads the template from _TEMPLATE.md.jinja2 in the current dir.
:param template_name: name of the template to use
@@ -658,12 +694,13 @@ def render_template(
:return: rendered template
"""
import jinja2
+
template_loader = jinja2.FileSystemLoader(searchpath=MY_DIR_PATH)
template_env = jinja2.Environment(
loader=template_loader,
undefined=jinja2.StrictUndefined,
autoescape=autoescape,
- keep_trailing_newline=keep_trailing_newline
+ keep_trailing_newline=keep_trailing_newline,
)
template = template_env.get_template(f"{template_name}_TEMPLATE{extension}.jinja2")
content: str = template.render(context)
@@ -684,6 +721,7 @@ def convert_git_changes_to_table(changes: str, base_url: str) -> str:
:return: markdown-formatted table
"""
from tabulate import tabulate
+
lines = changes.split("\n")
headers = ["Commit", "Committed", "Subject"]
table_data = []
@@ -702,6 +740,7 @@ def convert_pip_requirements_to_table(requirements: Iterable[str]) -> str:
:return: markdown-formatted table
"""
from tabulate import tabulate
+
headers = ["PIP package", "Version required"]
table_data = []
for dependency in requirements:
@@ -715,8 +754,7 @@ def convert_pip_requirements_to_table(requirements: Iterable[str]) -> str:
return tabulate(table_data, headers=headers, tablefmt="pipe")
-def convert_cross_package_dependencies_to_table(
- cross_package_dependencies: List[str], base_url: str) -> str:
+def convert_cross_package_dependencies_to_table(cross_package_dependencies: List[str], base_url: str) -> str:
"""
Converts cross-package dependencies to a markdown table
:param cross_package_dependencies: list of cross-package dependencies
@@ -724,6 +762,7 @@ def convert_cross_package_dependencies_to_table(
:return: markdown-formatted table
"""
from tabulate import tabulate
+
headers = ["Dependent package", "Extra"]
table_data = []
for dependency in cross_package_dependencies:
@@ -757,8 +796,8 @@ def convert_cross_package_dependencies_to_table(
Keeps information about historical releases.
"""
ReleaseInfo = collections.namedtuple(
- "ReleaseInfo",
- "release_version release_version_no_leading_zeros last_commit_hash content file_name")
+ "ReleaseInfo", "release_version release_version_no_leading_zeros last_commit_hash content file_name"
+)
def strip_leading_zeros_in_calver(calver_version: str) -> str:
@@ -805,15 +844,19 @@ def get_all_releases(provider_package_path: str, backport_packages: bool) -> Lis
print("No commit found. This seems to be first time you run it", file=sys.stderr)
else:
last_commit_hash = found.group(1)
- release_version = file_name[len(changes_file_prefix):][:-3]
- release_version_no_leading_zeros = strip_leading_zeros_in_calver(release_version) \
- if backport_packages else release_version
+ release_version = file_name[len(changes_file_prefix) :][:-3]
+ release_version_no_leading_zeros = (
+ strip_leading_zeros_in_calver(release_version) if backport_packages else release_version
+ )
past_releases.append(
- ReleaseInfo(release_version=release_version,
- release_version_no_leading_zeros=release_version_no_leading_zeros,
- last_commit_hash=last_commit_hash,
- content=content,
- file_name=file_name))
+ ReleaseInfo(
+ release_version=release_version,
+ release_version_no_leading_zeros=release_version_no_leading_zeros,
+ last_commit_hash=last_commit_hash,
+ content=content,
+ file_name=file_name,
+ )
+ )
return past_releases
@@ -825,21 +868,24 @@ def get_latest_release(provider_package_path: str, backport_packages: bool) -> R
:param backport_packages: whether to prepare regular (False) or backport (True) packages
:return: latest release information
"""
- releases = get_all_releases(provider_package_path=provider_package_path,
- backport_packages=backport_packages)
+ releases = get_all_releases(
+ provider_package_path=provider_package_path, backport_packages=backport_packages
+ )
if len(releases) == 0:
- return ReleaseInfo(release_version="0.0.0",
- release_version_no_leading_zeros="0.0.0",
- last_commit_hash="no_hash",
- content="empty",
- file_name="no_file")
+ return ReleaseInfo(
+ release_version="0.0.0",
+ release_version_no_leading_zeros="0.0.0",
+ last_commit_hash="no_hash",
+ content="empty",
+ file_name="no_file",
+ )
else:
return releases[0]
-def get_previous_release_info(previous_release_version: str,
- past_releases: List[ReleaseInfo],
- current_release_version: str) -> Optional[str]:
+def get_previous_release_info(
+ previous_release_version: str, past_releases: List[ReleaseInfo], current_release_version: str
+) -> Optional[str]:
"""
Find previous release. In case we are re-running current release we assume that last release was
the previous one. This is needed so that we can generate list of changes since the previous release.
@@ -859,9 +905,8 @@ def get_previous_release_info(previous_release_version: str,
def check_if_release_version_ok(
- past_releases: List[ReleaseInfo],
- current_release_version: str,
- backport_packages: bool) -> Tuple[str, Optional[str]]:
+ past_releases: List[ReleaseInfo], current_release_version: str, backport_packages: bool
+) -> Tuple[str, Optional[str]]:
"""
Check if the release version passed is not earlier than the last release version
:param past_releases: all past releases (if there are any)
@@ -880,8 +925,11 @@ def check_if_release_version_ok(
current_release_version = "0.0.1" # TODO: replace with maintained version
if previous_release_version:
if Version(current_release_version) < Version(previous_release_version):
- print(f"The release {current_release_version} must be not less than "
- f"{previous_release_version} - last release for the package", file=sys.stderr)
+ print(
+ f"The release {current_release_version} must be not less than "
+ f"{previous_release_version} - last release for the package",
+ file=sys.stderr,
+ )
sys.exit(2)
return current_release_version, previous_release_version
@@ -907,17 +955,23 @@ def make_sure_remote_apache_exists_and_fetch():
:return:
"""
try:
- subprocess.check_call(["git", "remote", "add", "apache-https-for-providers",
- "https://github.com/apache/airflow.git"],
- stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+ subprocess.check_call(
+ ["git", "remote", "add", "apache-https-for-providers", "https://github.com/apache/airflow.git"],
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL,
+ )
except subprocess.CalledProcessError as e:
if e.returncode == 128:
- print("The remote `apache-https-for-providers` already exists. If you have trouble running "
- "git log delete the remote", file=sys.stderr)
+ print(
+ "The remote `apache-https-for-providers` already exists. If you have trouble running "
+ "git log delete the remote",
+ file=sys.stderr,
+ )
else:
raise
- subprocess.check_call(["git", "fetch", "apache-https-for-providers"],
- stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+ subprocess.check_call(
+ ["git", "fetch", "apache-https-for-providers"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
+ )
def get_git_command(base_commit: Optional[str]) -> List[str]:
@@ -926,18 +980,22 @@ def get_git_command(base_commit: Optional[str]) -> List[str]:
:param base_commit: if present - base commit from which to start the log from
:return: git command to run
"""
- git_cmd = ["git", "log", "apache-https-for-providers/master",
- "--pretty=format:%H %h %cd %s", "--date=short"]
+ git_cmd = [
+ "git",
+ "log",
+ "apache-https-for-providers/master",
+ "--pretty=format:%H %h %cd %s",
+ "--date=short",
+ ]
if base_commit:
git_cmd.append(f"{base_commit}...HEAD")
git_cmd.extend(['--', '.'])
return git_cmd
-def store_current_changes(provider_package_path: str,
- current_release_version: str,
- current_changes: str,
- backport_packages: bool) -> None:
+def store_current_changes(
+ provider_package_path: str, current_release_version: str, current_changes: str, backport_packages: bool
+) -> None:
"""
Stores current changes in the *_changes_YYYY.MM.DD.md file.
@@ -948,7 +1006,8 @@ def store_current_changes(provider_package_path: str,
"""
current_changes_file_path = os.path.join(
provider_package_path,
- get_provider_changes_prefix(backport_packages=backport_packages) + current_release_version + ".md")
+ get_provider_changes_prefix(backport_packages=backport_packages) + current_release_version + ".md",
+ )
with open(current_changes_file_path, "wt") as current_changes_file:
current_changes_file.write(current_changes)
current_changes_file.write("\n")
@@ -999,7 +1058,8 @@ def is_camel_case_with_acronyms(s: str):
def check_if_classes_are_properly_named(
- entity_summary: Dict[EntityType, EntityTypeSummary]) -> Tuple[int, int]:
+ entity_summary: Dict[EntityType, EntityTypeSummary]
+) -> Tuple[int, int]:
"""
Check if all entities in the dictionary are named properly. It prints names at the output
and returns the status of class names.
@@ -1014,12 +1074,16 @@ def check_if_classes_are_properly_named(
_, class_name = class_full_name.rsplit(".", maxsplit=1)
error_encountered = False
if not is_camel_case_with_acronyms(class_name):
- print(f"The class {class_full_name} is wrongly named. The "
- f"class name should be CamelCaseWithACRONYMS !")
+ print(
+ f"The class {class_full_name} is wrongly named. The "
+ f"class name should be CamelCaseWithACRONYMS !"
+ )
error_encountered = True
if not class_name.endswith(class_suffix):
- print(f"The class {class_full_name} is wrongly named. It is one of the {entity_type.value}"
- f" so it should end with {class_suffix}")
+ print(
+ f"The class {class_full_name} is wrongly named. It is one of the {entity_type.value}"
+ f" so it should end with {class_suffix}"
+ )
error_encountered = True
total_class_number += 1
if error_encountered:
@@ -1034,11 +1098,13 @@ def get_package_pip_name(provider_package_id: str, backport_packages: bool):
return f"apache-airflow-providers-{provider_package_id.replace('.', '-')}"
-def update_release_notes_for_package(provider_package_id: str,
- current_release_version: str,
- version_suffix: str,
- imported_classes: List[str],
- backport_packages: bool) -> Tuple[int, int]:
+def update_release_notes_for_package(
+ provider_package_id: str,
+ current_release_version: str,
+ version_suffix: str,
+ imported_classes: List[str],
+ backport_packages: bool,
+) -> Tuple[int, int]:
"""
Updates release notes (BACKPORT_PROVIDER_README.md/README.md) for the package.
Returns Tuple of total number of entities and badly named entities.
@@ -1055,27 +1121,35 @@ def update_release_notes_for_package(provider_package_id: str,
full_package_name = f"airflow.providers.{provider_package_id}"
provider_package_path = get_package_path(provider_package_id)
entity_summaries = get_package_class_summary(full_package_name, imported_classes)
- past_releases = get_all_releases(provider_package_path=provider_package_path,
- backport_packages=backport_packages)
+ past_releases = get_all_releases(
+ provider_package_path=provider_package_path, backport_packages=backport_packages
+ )
current_release_version, previous_release = check_if_release_version_ok(
- past_releases, current_release_version, backport_packages)
- cross_providers_dependencies = \
- get_cross_provider_dependent_packages(provider_package_id=provider_package_id)
- previous_release = get_previous_release_info(previous_release_version=previous_release,
- past_releases=past_releases,
- current_release_version=current_release_version)
+ past_releases, current_release_version, backport_packages
+ )
+ cross_providers_dependencies = get_cross_provider_dependent_packages(
+ provider_package_id=provider_package_id
+ )
+ previous_release = get_previous_release_info(
+ previous_release_version=previous_release,
+ past_releases=past_releases,
+ current_release_version=current_release_version,
+ )
git_cmd = get_git_command(previous_release)
changes = subprocess.check_output(git_cmd, cwd=provider_package_path, universal_newlines=True)
changes_table = convert_git_changes_to_table(
- changes,
- base_url="https://github.com/apache/airflow/commit/")
+ changes, base_url="https://github.com/apache/airflow/commit/"
+ )
pip_requirements_table = convert_pip_requirements_to_table(PROVIDERS_REQUIREMENTS[provider_package_id])
- cross_providers_dependencies_table = \
- convert_cross_package_dependencies_to_table(
- cross_providers_dependencies,
- base_url="https://github.com/apache/airflow/tree/master/airflow/providers/")
- release_version_no_leading_zeros = strip_leading_zeros_in_calver(current_release_version) \
- if backport_packages else current_release_version
+ cross_providers_dependencies_table = convert_cross_package_dependencies_to_table(
+ cross_providers_dependencies,
+ base_url="https://github.com/apache/airflow/tree/master/airflow/providers/",
+ )
+ release_version_no_leading_zeros = (
+ strip_leading_zeros_in_calver(current_release_version)
+ if backport_packages
+ else current_release_version
+ )
context: Dict[str, Any] = {
"ENTITY_TYPES": list(EntityType),
"README_FILE": "BACKPORT_PROVIDER_README.md" if backport_packages else "README.md",
@@ -1095,17 +1169,21 @@ def update_release_notes_for_package(provider_package_id: str,
"PROVIDER_TYPE": "Backport provider" if BACKPORT_PACKAGES else "Provider",
"PROVIDERS_FOLDER": "backport-providers" if BACKPORT_PACKAGES else "providers",
"INSTALL_REQUIREMENTS": get_install_requirements(
- provider_package_id=provider_package_id,
- backport_packages=backport_packages
+ provider_package_id=provider_package_id, backport_packages=backport_packages
),
"SETUP_REQUIREMENTS": get_setup_requirements(),
"EXTRAS_REQUIREMENTS": get_package_extras(
- provider_package_id=provider_package_id,
- backport_packages=backport_packages
- )
+ provider_package_id=provider_package_id, backport_packages=backport_packages
+ ),
}
- prepare_readme_and_changes_files(backport_packages, context, current_release_version, entity_summaries,
- provider_package_id, provider_package_path)
+ prepare_readme_and_changes_files(
+ backport_packages,
+ context,
+ current_release_version,
+ entity_summaries,
+ provider_package_id,
+ provider_package_path,
+ )
prepare_setup_py_file(context, provider_package_path)
prepare_setup_cfg_file(context, provider_package_path)
total, bad = check_if_classes_are_properly_named(entity_summaries)
@@ -1125,17 +1203,27 @@ def get_template_name(backport_packages: bool, template_suffix: str) -> str:
:param template_suffix: suffix to add
:return template name
"""
- return (BACKPORT_PROVIDER_TEMPLATE_PREFIX if backport_packages else PROVIDER_TEMPLATE_PREFIX) \
- + template_suffix
+ return (
+ BACKPORT_PROVIDER_TEMPLATE_PREFIX if backport_packages else PROVIDER_TEMPLATE_PREFIX
+ ) + template_suffix
-def prepare_readme_and_changes_files(backport_packages, context, current_release_version, entity_summaries,
- provider_package_id, provider_package_path):
+def prepare_readme_and_changes_files(
+ backport_packages,
+ context,
+ current_release_version,
+ entity_summaries,
+ provider_package_id,
+ provider_package_path,
+):
changes_template_name = get_template_name(backport_packages, "CHANGES")
current_changes = render_template(template_name=changes_template_name, context=context, extension='.md')
- store_current_changes(provider_package_path=provider_package_path,
- current_release_version=current_release_version,
- current_changes=current_changes, backport_packages=backport_packages)
+ store_current_changes(
+ provider_package_path=provider_package_path,
+ current_release_version=current_release_version,
+ current_changes=current_changes,
+ backport_packages=backport_packages,
+ )
context['ENTITIES'] = entity_summaries
context['ENTITY_NAMES'] = ENTITY_NAMES
all_releases = get_all_releases(provider_package_path, backport_packages=backport_packages)
@@ -1147,8 +1235,9 @@ def prepare_readme_and_changes_files(backport_packages, context, current_release
readme += render_template(template_name=classes_template_name, context=context, extension='.md')
for a_release in all_releases:
readme += a_release.content
- readme_file_path = os.path.join(provider_package_path,
- "BACKPORT_PROVIDER_README.md" if backport_packages else "README.md")
+ readme_file_path = os.path.join(
+ provider_package_path, "BACKPORT_PROVIDER_README.md" if backport_packages else "README.md"
+ )
old_text = ""
if os.path.isfile(readme_file_path):
with open(readme_file_path) as readme_file_read:
@@ -1174,10 +1263,7 @@ def prepare_setup_py_file(context, provider_package_path):
setup_target_prefix = BACKPORT_PROVIDER_PREFIX if BACKPORT_PACKAGES else ""
setup_file_path = os.path.abspath(os.path.join(provider_package_path, setup_target_prefix + "setup.py"))
setup_content = render_template(
- template_name=setup_template_name,
- context=context,
- extension='.py',
- autoescape=False
+ template_name=setup_template_name, context=context, extension='.py', autoescape=False
)
with open(setup_file_path, "wt") as setup_file:
setup_file.write(setup_content)
@@ -1194,16 +1280,15 @@ def prepare_setup_cfg_file(context, provider_package_path):
context=context,
extension='.cfg',
autoescape=False,
- keep_trailing_newline=True
+ keep_trailing_newline=True,
)
with open(setup_file_path, "wt") as setup_file:
setup_file.write(setup_content)
-def update_release_notes_for_packages(provider_ids: List[str],
- release_version: str,
- version_suffix: str,
- backport_packages: bool):
+def update_release_notes_for_packages(
+ provider_ids: List[str], release_version: str, version_suffix: str, backport_packages: bool
+):
"""
Updates release notes for the list of packages specified.
:param provider_ids: list of provider ids
@@ -1213,7 +1298,8 @@ def update_release_notes_for_packages(provider_ids: List[str],
:return:
"""
imported_classes = import_all_provider_classes(
- source_paths=[PROVIDERS_PATH], provider_ids=provider_ids, print_imports=False)
+ source_paths=[PROVIDERS_PATH], provider_ids=provider_ids, print_imports=False
+ )
make_sure_remote_apache_exists_and_fetch()
if len(provider_ids) == 0:
if backport_packages:
@@ -1231,11 +1317,8 @@ def update_release_notes_for_packages(provider_ids: List[str],
print()
for package in provider_ids:
inc_total, inc_bad = update_release_notes_for_package(
- package,
- release_version,
- version_suffix,
- imported_classes,
- backport_packages)
+ package, release_version, version_suffix, imported_classes, backport_packages
+ )
total += inc_total
bad += inc_bad
if bad == 0:
@@ -1287,12 +1370,12 @@ def verify_provider_package(package: str) -> None:
:return: None
"""
if package not in get_provider_packages():
- raise Exception(f"The package {package} is not a provider package. "
- f"Use one of {get_provider_packages()}")
+ raise Exception(
+ f"The package {package} is not a provider package. " f"Use one of {get_provider_packages()}"
+ )
-def copy_setup_py(provider_package_id: str,
- backport_packages: bool) -> None:
+def copy_setup_py(provider_package_id: str, backport_packages: bool) -> None:
"""
Copies setup.py to provider_package directory.
:param provider_package_id: package from which to copy the setup.py
@@ -1301,12 +1384,13 @@ def copy_setup_py(provider_package_id: str,
"""
setup_source_prefix = BACKPORT_PROVIDER_PREFIX if backport_packages else ""
provider_package_path = get_package_path(provider_package_id)
- copyfile(os.path.join(provider_package_path, setup_source_prefix + "setup.py"),
- os.path.join(MY_DIR_PATH, "setup.py"))
+ copyfile(
+ os.path.join(provider_package_path, setup_source_prefix + "setup.py"),
+ os.path.join(MY_DIR_PATH, "setup.py"),
+ )
-def copy_setup_cfg(provider_package_id: str,
- backport_packages: bool) -> None:
+def copy_setup_cfg(provider_package_id: str, backport_packages: bool) -> None:
"""
Copies setup.cfg to provider_package directory.
:param provider_package_id: package from which to copy the setup.cfg
@@ -1315,12 +1399,13 @@ def copy_setup_cfg(provider_package_id: str,
"""
setup_source_prefix = BACKPORT_PROVIDER_PREFIX if backport_packages else ""
provider_package_path = get_package_path(provider_package_id)
- copyfile(os.path.join(provider_package_path, setup_source_prefix + "setup.cfg"),
- os.path.join(MY_DIR_PATH, "setup.cfg"))
+ copyfile(
+ os.path.join(provider_package_path, setup_source_prefix + "setup.cfg"),
+ os.path.join(MY_DIR_PATH, "setup.cfg"),
+ )
-def copy_readme_and_changelog(provider_package_id: str,
- backport_packages: bool) -> None:
+def copy_readme_and_changelog(provider_package_id: str, backport_packages: bool) -> None:
"""
Copies the right README.md/CHANGELOG.txt to provider_package directory.
:param provider_package_id: package from which to copy the README.md/CHANGELOG.txt
@@ -1347,7 +1432,7 @@ def copy_readme_and_changelog(provider_package_id: str,
LIST_BACKPORTABLE_PACKAGES = "list-backportable-packages"
UPDATE_PACKAGE_RELEASE_NOTES = "update-package-release-notes"
- BACKPORT_PACKAGES = (os.getenv('BACKPORT_PACKAGES') == "true")
+ BACKPORT_PACKAGES = os.getenv('BACKPORT_PACKAGES') == "true"
suffix = ""
provider_names = get_provider_packages()
@@ -1356,16 +1441,22 @@ def copy_readme_and_changelog(provider_package_id: str,
possible_first_params.append(LIST_BACKPORTABLE_PACKAGES)
possible_first_params.append(UPDATE_PACKAGE_RELEASE_NOTES)
if len(sys.argv) == 1:
- print("""
+ print(
+ """
ERROR! Missing first param"
-""", file=sys.stderr)
+""",
+ file=sys.stderr,
+ )
usage()
sys.exit(1)
if sys.argv[1] == "--version-suffix":
if len(sys.argv) < 3:
- print("""
+ print(
+ """
ERROR! --version-suffix needs parameter!
-""", file=sys.stderr)
+""",
+ file=sys.stderr,
+ )
usage()
sys.exit(1)
suffix = sys.argv[2]
@@ -1375,9 +1466,12 @@ def copy_readme_and_changelog(provider_package_id: str,
sys.exit(0)
if sys.argv[1] not in possible_first_params:
- print(f"""
+ print(
+ f"""
ERROR! Wrong first param: {sys.argv[1]}
-""", file=sys.stderr)
+""",
+ file=sys.stderr,
+ )
usage()
print()
sys.exit(1)
@@ -1410,7 +1504,8 @@ def copy_readme_and_changelog(provider_package_id: str,
package_list,
release_version=release_ver,
version_suffix=suffix,
- backport_packages=BACKPORT_PACKAGES)
+ backport_packages=BACKPORT_PACKAGES,
+ )
sys.exit(0)
_provider_package = sys.argv[1]
diff --git a/provider_packages/refactor_provider_packages.py b/provider_packages/refactor_provider_packages.py
index 9ea80e9532967..ccacef21cfdc8 100755
--- a/provider_packages/refactor_provider_packages.py
+++ b/provider_packages/refactor_provider_packages.py
@@ -27,7 +27,9 @@
from fissix.pytree import Leaf
from provider_packages.prepare_provider_packages import (
- get_source_airflow_folder, get_source_providers_folder, get_target_providers_folder,
+ get_source_airflow_folder,
+ get_source_providers_folder,
+ get_target_providers_folder,
get_target_providers_package_folder,
)
@@ -36,6 +38,7 @@ def copy_provider_sources() -> None:
"""
Copies provider sources to directory where they will be refactored.
"""
+
def rm_build_dir() -> None:
"""
Removes build directory.
@@ -111,6 +114,7 @@ def remove_class(self, class_name) -> None:
:param class_name: name to remove
"""
+
def _remover(node: LN, capture: Capture, filename: Filename) -> None:
node.remove()
@@ -181,27 +185,23 @@ def add_provide_context_to_python_operators(self) -> None:
)
"""
+
def add_provide_context_to_python_operator(node: LN, capture: Capture, filename: Filename) -> None:
fn_args = capture['function_arguments'][0]
- if len(fn_args.children) > 0 and (not isinstance(fn_args.children[-1], Leaf)
- or fn_args.children[-1].type != token.COMMA):
+ if len(fn_args.children) > 0 and (
+ not isinstance(fn_args.children[-1], Leaf) or fn_args.children[-1].type != token.COMMA
+ ):
fn_args.append_child(Comma())
provide_context_arg = KeywordArg(Name('provide_context'), Name('True'))
provide_context_arg.prefix = fn_args.children[0].prefix
fn_args.append_child(provide_context_arg)
+ (self.qry.select_function("PythonOperator").is_call().modify(add_provide_context_to_python_operator))
(
- self.qry.
- select_function("PythonOperator").
- is_call().
- modify(add_provide_context_to_python_operator)
- )
- (
- self.qry.
- select_function("BranchPythonOperator").
- is_call().
- modify(add_provide_context_to_python_operator)
+ self.qry.select_function("BranchPythonOperator")
+ .is_call()
+ .modify(add_provide_context_to_python_operator)
)
def remove_super_init_call(self):
@@ -259,6 +259,7 @@ def __init__(self, context=None):
self.max_ingestion_time = max_ingestion_time
"""
+
def remove_super_init_call_modifier(node: LN, capture: Capture, filename: Filename) -> None:
for ch in node.post_order():
if isinstance(ch, Leaf) and ch.value == "super":
@@ -294,6 +295,7 @@ def remove_tags(self):
) as dag:
"""
+
def remove_tags_modifier(_: LN, capture: Capture, filename: Filename) -> None:
for node in capture['function_arguments'][0].post_order():
if isinstance(node, Leaf) and node.value == "tags" and node.type == TOKEN.NAME:
@@ -324,6 +326,7 @@ class GCSUploadSessionCompleteSensor(BaseSensorOperator):
Checks for changes in the number of objects at prefix in Google Cloud Storage
"""
+
def find_and_remove_poke_mode_only_import(node: LN):
for child in node.children:
if isinstance(child, Leaf) and child.type == 1 and child.value == 'poke_mode_only':
@@ -353,9 +356,14 @@ def find_root_remove_import(node: LN):
find_and_remove_poke_mode_only_import(current_node)
def is_poke_mode_only_decorator(node: LN) -> bool:
- return node.children and len(node.children) >= 2 and \
- isinstance(node.children[0], Leaf) and node.children[0].value == '@' and \
- isinstance(node.children[1], Leaf) and node.children[1].value == 'poke_mode_only'
+ return (
+ node.children
+ and len(node.children) >= 2
+ and isinstance(node.children[0], Leaf)
+ and node.children[0].value == '@'
+ and isinstance(node.children[1], Leaf)
+ and node.children[1].value == 'poke_mode_only'
+ )
def remove_poke_mode_only_modifier(node: LN, capture: Capture, filename: Filename) -> None:
for child in capture['node'].parent.children:
@@ -391,36 +399,37 @@ def refactor_amazon_package(self):
def amazon_package_filter(node: LN, capture: Capture, filename: Filename) -> bool:
return filename.startswith("./airflow/providers/amazon/")
- os.makedirs(os.path.join(get_target_providers_package_folder("amazon"), "common", "utils"),
- exist_ok=True)
+ os.makedirs(
+ os.path.join(get_target_providers_package_folder("amazon"), "common", "utils"), exist_ok=True
+ )
copyfile(
os.path.join(get_source_airflow_folder(), "airflow", "utils", "__init__.py"),
- os.path.join(get_target_providers_package_folder("amazon"), "common", "__init__.py")
+ os.path.join(get_target_providers_package_folder("amazon"), "common", "__init__.py"),
)
copyfile(
os.path.join(get_source_airflow_folder(), "airflow", "utils", "__init__.py"),
- os.path.join(get_target_providers_package_folder("amazon"), "common", "utils", "__init__.py")
+ os.path.join(get_target_providers_package_folder("amazon"), "common", "utils", "__init__.py"),
)
copyfile(
os.path.join(get_source_airflow_folder(), "airflow", "typing_compat.py"),
- os.path.join(get_target_providers_package_folder("amazon"), "common", "utils", "typing_compat.py")
+ os.path.join(
+ get_target_providers_package_folder("amazon"), "common", "utils", "typing_compat.py"
+ ),
)
(
- self.qry.
- select_module("airflow.typing_compat").
- filter(callback=amazon_package_filter).
- rename("airflow.providers.amazon.common.utils.typing_compat")
+ self.qry.select_module("airflow.typing_compat")
+ .filter(callback=amazon_package_filter)
+ .rename("airflow.providers.amazon.common.utils.typing_compat")
)
copyfile(
os.path.join(get_source_airflow_folder(), "airflow", "utils", "email.py"),
- os.path.join(get_target_providers_package_folder("amazon"), "common", "utils", "email.py")
+ os.path.join(get_target_providers_package_folder("amazon"), "common", "utils", "email.py"),
)
(
- self.qry.
- select_module("airflow.utils.email").
- filter(callback=amazon_package_filter).
- rename("airflow.providers.amazon.common.utils.email")
+ self.qry.select_module("airflow.utils.email")
+ .filter(callback=amazon_package_filter)
+ .rename("airflow.providers.amazon.common.utils.email")
)
def refactor_google_package(self):
@@ -517,6 +526,7 @@ def _generate_virtualenv_cmd(tmp_dir: str, python_bin: str, system_site_packages
KEY_REGEX = re.compile(r'^[\\w.-]+$')
"""
+
def google_package_filter(node: LN, capture: Capture, filename: Filename) -> bool:
return filename.startswith("./airflow/providers/google/")
@@ -524,64 +534,66 @@ def pure_airflow_models_filter(node: LN, capture: Capture, filename: Filename) -
"""Check if select is exactly [airflow, . , models]"""
return len(list(node.children[1].leaves())) == 3
- os.makedirs(os.path.join(get_target_providers_package_folder("google"), "common", "utils"),
- exist_ok=True)
+ os.makedirs(
+ os.path.join(get_target_providers_package_folder("google"), "common", "utils"), exist_ok=True
+ )
copyfile(
os.path.join(get_source_airflow_folder(), "airflow", "utils", "__init__.py"),
- os.path.join(get_target_providers_package_folder("google"), "common", "utils", "__init__.py")
+ os.path.join(get_target_providers_package_folder("google"), "common", "utils", "__init__.py"),
)
copyfile(
os.path.join(get_source_airflow_folder(), "airflow", "utils", "python_virtualenv.py"),
- os.path.join(get_target_providers_package_folder("google"), "common", "utils",
- "python_virtualenv.py")
+ os.path.join(
+ get_target_providers_package_folder("google"), "common", "utils", "python_virtualenv.py"
+ ),
)
- copy_helper_py_file(os.path.join(
- get_target_providers_package_folder("google"), "common", "utils", "helpers.py"))
+ copy_helper_py_file(
+ os.path.join(get_target_providers_package_folder("google"), "common", "utils", "helpers.py")
+ )
copyfile(
os.path.join(get_source_airflow_folder(), "airflow", "utils", "module_loading.py"),
- os.path.join(get_target_providers_package_folder("google"), "common", "utils",
- "module_loading.py")
+ os.path.join(
+ get_target_providers_package_folder("google"), "common", "utils", "module_loading.py"
+ ),
)
(
- self.qry.
- select_module("airflow.utils.python_virtualenv").
- filter(callback=google_package_filter).
- rename("airflow.providers.google.common.utils.python_virtualenv")
+ self.qry.select_module("airflow.utils.python_virtualenv")
+ .filter(callback=google_package_filter)
+ .rename("airflow.providers.google.common.utils.python_virtualenv")
)
copyfile(
os.path.join(get_source_airflow_folder(), "airflow", "utils", "process_utils.py"),
- os.path.join(get_target_providers_package_folder("google"), "common", "utils", "process_utils.py")
+ os.path.join(
+ get_target_providers_package_folder("google"), "common", "utils", "process_utils.py"
+ ),
)
(
- self.qry.
- select_module("airflow.utils.process_utils").
- filter(callback=google_package_filter).
- rename("airflow.providers.google.common.utils.process_utils")
+ self.qry.select_module("airflow.utils.process_utils")
+ .filter(callback=google_package_filter)
+ .rename("airflow.providers.google.common.utils.process_utils")
)
(
- self.qry.
- select_module("airflow.utils.helpers").
- filter(callback=google_package_filter).
- rename("airflow.providers.google.common.utils.helpers")
+ self.qry.select_module("airflow.utils.helpers")
+ .filter(callback=google_package_filter)
+ .rename("airflow.providers.google.common.utils.helpers")
)
(
- self.qry.
- select_module("airflow.utils.module_loading").
- filter(callback=google_package_filter).
- rename("airflow.providers.google.common.utils.module_loading")
+ self.qry.select_module("airflow.utils.module_loading")
+ .filter(callback=google_package_filter)
+ .rename("airflow.providers.google.common.utils.module_loading")
)
(
# Fix BaseOperatorLinks imports
- self.qry.select_module("airflow.models").
- is_filename(include=r"bigquery\.py|mlengine\.py").
- filter(callback=google_package_filter).
- filter(pure_airflow_models_filter).
- rename("airflow.models.baseoperator")
+ self.qry.select_module("airflow.models")
+ .is_filename(include=r"bigquery\.py|mlengine\.py")
+ .filter(callback=google_package_filter)
+ .filter(pure_airflow_models_filter)
+ .rename("airflow.models.baseoperator")
)
def refactor_odbc_package(self):
@@ -608,22 +620,21 @@ def refactor_odbc_package(self):
"""
+
def odbc_package_filter(node: LN, capture: Capture, filename: Filename) -> bool:
return filename.startswith("./airflow/providers/odbc/")
os.makedirs(os.path.join(get_target_providers_folder(), "odbc", "utils"), exist_ok=True)
copyfile(
os.path.join(get_source_airflow_folder(), "airflow", "utils", "__init__.py"),
- os.path.join(get_target_providers_package_folder("odbc"), "utils", "__init__.py")
+ os.path.join(get_target_providers_package_folder("odbc"), "utils", "__init__.py"),
)
- copy_helper_py_file(os.path.join(
- get_target_providers_package_folder("odbc"), "utils", "helpers.py"))
+ copy_helper_py_file(os.path.join(get_target_providers_package_folder("odbc"), "utils", "helpers.py"))
(
- self.qry.
- select_module("airflow.utils.helpers").
- filter(callback=odbc_package_filter).
- rename("airflow.providers.odbc.utils.helpers")
+ self.qry.select_module("airflow.utils.helpers")
+ .filter(callback=odbc_package_filter)
+ .rename("airflow.providers.odbc.utils.helpers")
)
def refactor_kubernetes_pod_operator(self):
@@ -631,11 +642,10 @@ def kubernetes_package_filter(node: LN, capture: Capture, filename: Filename) ->
return filename.startswith("./airflow/providers/cncf/kubernetes")
(
- self.qry.
- select_class("KubernetesPodOperator").
- select_method("add_xcom_sidecar").
- filter(callback=kubernetes_package_filter).
- rename("add_sidecar")
+ self.qry.select_class("KubernetesPodOperator")
+ .select_method("add_xcom_sidecar")
+ .filter(callback=kubernetes_package_filter)
+ .rename("add_sidecar")
)
def do_refactor(self, in_process: bool = False) -> None: # noqa
@@ -653,7 +663,7 @@ def do_refactor(self, in_process: bool = False) -> None: # noqa
if __name__ == '__main__':
- BACKPORT_PACKAGES = (os.getenv('BACKPORT_PACKAGES') == "true")
+ BACKPORT_PACKAGES = os.getenv('BACKPORT_PACKAGES') == "true"
in_process = False
if len(sys.argv) > 1:
if sys.argv[1] in ['--help', '-h']:
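# Illustrative sketch (not part of the patch): the refactor_provider_packages.py hunks above
# rewrite Bowler-style query chains from trailing-dot backslash-free continuations into Black's
# preferred shape, where the whole chain is wrapped in parentheses and each call starts on its
# own line with a leading dot. The _ToyQuery class below is hypothetical and only mimics that
# fluent-interface shape so the snippet runs on its own.
class _ToyQuery:
    def select_module(self, name):
        print(f"select_module({name!r})")
        return self

    def filter(self, callback=None):
        print(f"filter(callback={callback!r})")
        return self

    def rename(self, new_name):
        print(f"rename({new_name!r})")
        return self


(
    _ToyQuery()
    .select_module("airflow.utils.email")
    .filter(callback=None)
    .rename("airflow.providers.amazon.common.utils.email")
)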
diff --git a/provider_packages/remove_old_releases.py b/provider_packages/remove_old_releases.py
index 7383c0d548cff..fb8643d74853d 100644
--- a/provider_packages/remove_old_releases.py
+++ b/provider_packages/remove_old_releases.py
@@ -40,13 +40,15 @@ class VersionedFile(NamedTuple):
def split_version_and_suffix(file_name: str, suffix: str) -> VersionedFile:
- no_suffix_file = file_name[:-len(suffix)]
+ no_suffix_file = file_name[: -len(suffix)]
no_version_file, version = no_suffix_file.rsplit("-", 1)
- return VersionedFile(base=no_version_file + "-",
- version=version,
- suffix=suffix,
- type=no_version_file + "-" + suffix,
- comparable_version=LooseVersion(version))
+ return VersionedFile(
+ base=no_version_file + "-",
+ version=version,
+ suffix=suffix,
+ type=no_version_file + "-" + suffix,
+ comparable_version=LooseVersion(version),
+ )
def process_all_files(directory: str, suffix: str, execute: bool):
@@ -63,8 +65,10 @@ def process_all_files(directory: str, suffix: str, execute: bool):
for package_types in package_types_dicts.values():
if len(package_types) == 1:
versioned_file = package_types[0]
- print("Leaving the only version: "
- f"${versioned_file.base + versioned_file.version + versioned_file.suffix}")
+ print(
+ "Leaving the only version: "
+ f"${versioned_file.base + versioned_file.version + versioned_file.suffix}"
+ )
# Leave only last version from each type
for versioned_file in package_types[:-1]:
command = ["svn", "rm", versioned_file.base + versioned_file.version + versioned_file.suffix]
@@ -76,10 +80,16 @@ def process_all_files(directory: str, suffix: str, execute: bool):
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description='Removes old releases.')
- parser.add_argument('--directory', dest='directory', action='store', required=True,
- help='Directory to remove old releases in')
- parser.add_argument('--execute', dest='execute', action='store_true',
- help='Execute the removal rather than dry run')
+ parser.add_argument(
+ '--directory',
+ dest='directory',
+ action='store',
+ required=True,
+ help='Directory to remove old releases in',
+ )
+ parser.add_argument(
+ '--execute', dest='execute', action='store_true', help='Execute the removal rather than dry run'
+ )
return parser.parse_args()
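# Illustrative sketch (not part of the patch): Black's "magic trailing comma" explains why the
# first add_argument() call above ends up with one keyword per line while the second stays on a
# single line. With a trailing comma inside the parentheses, Black keeps the call exploded; with
# no trailing comma, it collapses the call onto one line whenever it fits the configured length.
import argparse

parser = argparse.ArgumentParser(description='Removes old releases.')
# Trailing comma after the last keyword -> Black keeps each argument on its own line.
parser.add_argument(
    '--directory',
    dest='directory',
    action='store',
    required=True,
    help='Directory to remove old releases in',
)
# No trailing comma -> Black joins the call onto one line when it fits.
parser.add_argument('--execute', dest='execute', action='store_true', help='Execute the removal')

args = parser.parse_args(['--directory', '/tmp'])
print(args)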
diff --git a/scripts/ci/pre_commit/pre_commit_check_order_setup.py b/scripts/ci/pre_commit/pre_commit_check_order_setup.py
index f255c2d0e34de..cfcb297e68e4b 100755
--- a/scripts/ci/pre_commit/pre_commit_check_order_setup.py
+++ b/scripts/ci/pre_commit/pre_commit_check_order_setup.py
@@ -38,8 +38,9 @@ def _check_list_sorted(the_list: List[str], message: str) -> None:
while sorted_list[i] == the_list[i]:
i += 1
print(f"{message} NOK")
- errors.append(f"ERROR in {message}. First wrongly sorted element"
- f" {the_list[i]}. Should be {sorted_list[i]}")
+ errors.append(
+ f"ERROR in {message}. First wrongly sorted element" f" {the_list[i]}. Should be {sorted_list[i]}"
+ )
def setup() -> str:
@@ -55,7 +56,8 @@ def check_main_dependent_group(setup_context: str) -> None:
'# Start dependencies group' and '# End dependencies group' in setup.py
"""
pattern_main_dependent_group = re.compile(
- '# Start dependencies group\n(.*)# End dependencies group', re.DOTALL)
+ '# Start dependencies group\n(.*)# End dependencies group', re.DOTALL
+ )
main_dependent_group = pattern_main_dependent_group.findall(setup_context)[0]
pattern_sub_dependent = re.compile(' = \\[.*?\\]\n', re.DOTALL)
@@ -76,8 +78,7 @@ def check_sub_dependent_group(setup_context: str) -> None:
pattern_dependent_version = re.compile('[~|><=;].*')
for group_name in dependent_group_names:
- pattern_sub_dependent = re.compile(
- f'{group_name} = \\[(.*?)\\]', re.DOTALL)
+ pattern_sub_dependent = re.compile(f'{group_name} = \\[(.*?)\\]', re.DOTALL)
sub_dependent = pattern_sub_dependent.findall(setup_context)[0]
pattern_dependent = re.compile('\'(.*?)\'')
dependent = pattern_dependent.findall(sub_dependent)
@@ -104,8 +105,7 @@ def check_install_and_setup_requires(setup_context: str) -> None:
Test for an order of dependencies in function do_setup section
install_requires and setup_requires in setup.py
"""
- pattern_install_and_setup_requires = re.compile(
- '(setup_requires) ?= ?\\[(.*?)\\]', re.DOTALL)
+ pattern_install_and_setup_requires = re.compile('(setup_requires) ?= ?\\[(.*?)\\]', re.DOTALL)
install_and_setup_requires = pattern_install_and_setup_requires.findall(setup_context)
for dependent_requires in install_and_setup_requires:
@@ -123,7 +123,8 @@ def check_extras_require(setup_context: str) -> None:
extras_require in setup.py
"""
pattern_extras_requires = re.compile(
- r'EXTRAS_REQUIREMENTS: Dict\[str, Iterable\[str\]] = {(.*?)}', re.DOTALL)
+ r'EXTRAS_REQUIREMENTS: Dict\[str, Iterable\[str\]] = {(.*?)}', re.DOTALL
+ )
extras_requires = pattern_extras_requires.findall(setup_context)[0]
pattern_dependent = re.compile('\'(.*?)\'')
@@ -137,7 +138,8 @@ def check_provider_requirements(setup_context: str) -> None:
providers_require in setup.py
"""
pattern_extras_requires = re.compile(
- r'PROVIDERS_REQUIREMENTS: Dict\[str, Iterable\[str\]\] = {(.*?)}', re.DOTALL)
+ r'PROVIDERS_REQUIREMENTS: Dict\[str, Iterable\[str\]\] = {(.*?)}', re.DOTALL
+ )
extras_requires = pattern_extras_requires.findall(setup_context)[0]
pattern_dependent = re.compile('"(.*?)"')
diff --git a/scripts/ci/pre_commit/pre_commit_check_setup_installation.py b/scripts/ci/pre_commit/pre_commit_check_setup_installation.py
index 7c30683c2c805..20bec5ed3959d 100755
--- a/scripts/ci/pre_commit/pre_commit_check_setup_installation.py
+++ b/scripts/ci/pre_commit/pre_commit_check_setup_installation.py
@@ -44,12 +44,12 @@ def get_extras_from_setup() -> Dict[str, List[str]]:
"""
setup_content = get_file_content(SETUP_PY_FILE)
- extras_section_regex = re.compile(
- r'^EXTRAS_REQUIREMENTS: Dict[^{]+{([^}]+)}', re.MULTILINE)
+ extras_section_regex = re.compile(r'^EXTRAS_REQUIREMENTS: Dict[^{]+{([^}]+)}', re.MULTILINE)
extras_section = extras_section_regex.findall(setup_content)[0]
extras_regex = re.compile(
- rf'^\s+[\"\']({PY_IDENTIFIER})[\"\']:\s*({PY_IDENTIFIER})[^#\n]*(#\s*TODO.*)?$', re.MULTILINE)
+ rf'^\s+[\"\']({PY_IDENTIFIER})[\"\']:\s*({PY_IDENTIFIER})[^#\n]*(#\s*TODO.*)?$', re.MULTILINE
+ )
extras_dict: Dict[str, List[str]] = {}
for extras in extras_regex.findall(extras_section):
@@ -67,8 +67,9 @@ def get_extras_from_docs() -> List[str]:
"""
docs_content = get_file_content('docs', DOCS_FILE)
- extras_section_regex = re.compile(rf'^\|[^|]+\|.*pip install .apache-airflow\[({PY_IDENTIFIER})\].',
- re.MULTILINE)
+ extras_section_regex = re.compile(
+ rf'^\|[^|]+\|.*pip install .apache-airflow\[({PY_IDENTIFIER})\].', re.MULTILINE
+ )
extras = extras_section_regex.findall(docs_content)
extras = list(filter(lambda entry: entry != 'all', extras))
@@ -90,10 +91,11 @@ def get_extras_from_docs() -> List[str]:
if f"'{extras}'" not in setup_packages_str:
output_table += "| {:20} | {:^10} | {:^10} |\n".format(extras, "", "V")
- if(output_table == ""):
+ if output_table == "":
exit(0)
- print(f"""
+ print(
+ f"""
ERROR
"EXTRAS_REQUIREMENTS" section in {SETUP_PY_FILE} should be synchronized
@@ -101,7 +103,8 @@ def get_extras_from_docs() -> List[str]:
here is a list of packages that are used but are not documented, or
documented although not used.
- """)
+ """
+ )
print(".{:_^22}.{:_^12}.{:_^12}.".format("NAME", "SETUP", "INSTALLATION"))
print(output_table)
diff --git a/scripts/ci/pre_commit/pre_commit_yaml_to_cfg.py b/scripts/ci/pre_commit/pre_commit_yaml_to_cfg.py
index 9b31cef30f1c6..6558901890a14 100755
--- a/scripts/ci/pre_commit/pre_commit_yaml_to_cfg.py
+++ b/scripts/ci/pre_commit/pre_commit_yaml_to_cfg.py
@@ -91,7 +91,8 @@ def _write_section(configfile, section):
section_description = None
if section["description"] is not None:
section_description = list(
- filter(lambda x: (x is not None) or x != "", section["description"].splitlines()))
+ filter(lambda x: (x is not None) or x != "", section["description"].splitlines())
+ )
if section_description:
configfile.write("\n")
for single_line_desc in section_description:
@@ -106,8 +107,7 @@ def _write_section(configfile, section):
def _write_option(configfile, idx, option):
option_description = None
if option["description"] is not None:
- option_description = list(
- filter(lambda x: x is not None, option["description"].splitlines()))
+ option_description = list(filter(lambda x: x is not None, option["description"].splitlines()))
if option_description:
if idx != 0:
@@ -141,26 +141,26 @@ def _write_option(configfile, idx, option):
if __name__ == '__main__':
airflow_config_dir = os.path.join(
- os.path.dirname(__file__),
- os.pardir, os.pardir, os.pardir,
- "airflow", "config_templates")
+ os.path.dirname(__file__), os.pardir, os.pardir, os.pardir, "airflow", "config_templates"
+ )
airflow_default_config_path = os.path.join(airflow_config_dir, "default_airflow.cfg")
airflow_config_yaml_file_path = os.path.join(airflow_config_dir, "config.yml")
write_config(
- yaml_config_file_path=airflow_config_yaml_file_path,
- default_cfg_file_path=airflow_default_config_path
+ yaml_config_file_path=airflow_config_yaml_file_path, default_cfg_file_path=airflow_default_config_path
)
providers_dir = os.path.join(
- os.path.dirname(__file__),
- os.pardir, os.pardir, os.pardir,
- "airflow", "providers")
+ os.path.dirname(__file__), os.pardir, os.pardir, os.pardir, "airflow", "providers"
+ )
for root, dir_names, file_names in os.walk(providers_dir):
for file_name in file_names:
- if root.endswith("config_templates") and file_name == 'config.yml' and \
- os.path.isfile(os.path.join(root, "default_config.cfg")):
+ if (
+ root.endswith("config_templates")
+ and file_name == 'config.yml'
+ and os.path.isfile(os.path.join(root, "default_config.cfg"))
+ ):
write_config(
yaml_config_file_path=os.path.join(root, "config.yml"),
- default_cfg_file_path=os.path.join(root, "default_config.cfg")
+ default_cfg_file_path=os.path.join(root, "default_config.cfg"),
)
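# Illustrative sketch (not part of the patch): the pre_commit_yaml_to_cfg.py hunk above replaces
# a backslash-continued `if` condition with a parenthesized expression, one clause per line.
# The check below is a self-contained stand-in using a temporary directory instead of the real
# providers tree; write_config() is only mentioned in the printed message, not called.
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    root = os.path.join(tmp, "config_templates")
    os.makedirs(root)
    open(os.path.join(root, "default_config.cfg"), "w").close()
    file_name = "config.yml"

    if (
        root.endswith("config_templates")
        and file_name == 'config.yml'
        and os.path.isfile(os.path.join(root, "default_config.cfg"))
    ):
        print("would call write_config() for", root)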
diff --git a/scripts/in_container/update_quarantined_test_status.py b/scripts/in_container/update_quarantined_test_status.py
index ba5bf3f3209c4..73cf962904c58 100755
--- a/scripts/in_container/update_quarantined_test_status.py
+++ b/scripts/in_container/update_quarantined_test_status.py
@@ -66,8 +66,10 @@ class TestHistory(NamedTuple):
def get_url(result: TestResult) -> str:
- return f"[{result.name}](https://github.com/{user}/{repo}/blob/" \
- f"master/{result.file}?test_id={result.test_id}#L{result.line})"
+ return (
+ f"[{result.name}](https://github.com/{user}/{repo}/blob/"
+ f"master/{result.file}?test_id={result.test_id}#L{result.line})"
+ )
def parse_state_history(history_string: str) -> List[bool]:
@@ -136,11 +138,7 @@ def update_test_history(history: TestHistory, last_status: bool):
def create_test_history(result: TestResult) -> TestHistory:
print(f"Creating test history {result}")
return TestHistory(
- test_id=result.test_id,
- name=result.name,
- url=get_url(result),
- states=[result.result],
- comment=""
+ test_id=result.test_id, name=result.name, url=get_url(result), states=[result.result], comment=""
)
@@ -151,9 +149,9 @@ def get_history_status(history: TestHistory):
return "Flaky"
if all(history.states):
return "Stable"
- if all(history.states[0:num_runs - 1]):
+ if all(history.states[0 : num_runs - 1]):
return "Just one more"
- if all(history.states[0:int(num_runs / 2)]):
+ if all(history.states[0 : int(num_runs / 2)]):
return "Almost there"
return "Flaky"
@@ -163,13 +161,15 @@ def get_table(history_map: Dict[str, TestHistory]) -> str:
the_table: List[List[str]] = []
for ordered_key in sorted(history_map.keys()):
history = history_map[ordered_key]
- the_table.append([
- history.url,
- "Succeeded" if history.states[0] else "Failed",
- " ".join([reverse_status_map[state] for state in history.states]),
- get_history_status(history),
- history.comment
- ])
+ the_table.append(
+ [
+ history.url,
+ "Succeeded" if history.states[0] else "Failed",
+ " ".join([reverse_status_map[state] for state in history.states]),
+ get_history_status(history),
+ history.comment,
+ ]
+ )
return tabulate(the_table, headers, tablefmt="github")
@@ -187,14 +187,16 @@ def get_table(history_map: Dict[str, TestHistory]) -> str:
if len(test.contents) > 0 and test.contents[0].name == 'skipped':
print(f"skipping {test['name']}")
continue
- test_results.append(TestResult(
- test_id=test['classname'] + "::" + test['name'],
- file=test['file'],
- line=test['line'],
- name=test['name'],
- classname=test['classname'],
- result=len(test.contents) == 0
- ))
+ test_results.append(
+ TestResult(
+ test_id=test['classname'] + "::" + test['name'],
+ file=test['file'],
+ line=test['line'],
+ name=test['name'],
+ classname=test['classname'],
+ result=len(test.contents) == 0,
+ )
+ )
token = os.environ.get("GITHUB_TOKEN")
print(f"Token: {token}")
@@ -221,8 +223,7 @@ def get_table(history_map: Dict[str, TestHistory]) -> str:
for test_result in test_results:
previous_results = parsed_test_map.get(test_result.test_id)
if previous_results:
- updated_results = update_test_history(
- previous_results, test_result.result)
+ updated_results = update_test_history(previous_results, test_result.result)
new_test_map[previous_results.test_id] = updated_results
else:
new_history = create_test_history(test_result)
@@ -234,8 +235,9 @@ def get_table(history_map: Dict[str, TestHistory]) -> str:
print(table)
print()
with open(join(dirname(realpath(__file__)), "quarantine_issue_header.md")) as f:
- header = jinja2.Template(f.read(), autoescape=True, undefined=StrictUndefined).\
- render(DATE_UTC_NOW=datetime.utcnow())
- quarantined_issue.edit(title=None,
- body=header + "\n\n" + str(table),
- state='open' if len(test_results) > 0 else 'closed')
+ header = jinja2.Template(f.read(), autoescape=True, undefined=StrictUndefined).render(
+ DATE_UTC_NOW=datetime.utcnow()
+ )
+ quarantined_issue.edit(
+ title=None, body=header + "\n\n" + str(table), state='open' if len(test_results) > 0 else 'closed'
+ )
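# Illustrative sketch (not part of the patch): one change above is purely Black's slice
# convention. When a slice bound is an expression rather than a bare name or literal, Black puts
# spaces around the colon (states[0 : num_runs - 1]); with simple bounds it keeps the colon tight.
num_runs = 5
states = [True, True, True, True, False]

print(all(states[0 : num_runs - 1]))  # complex bound -> spaces around ':'
print(all(states[0:2]))  # simple bounds -> no spaces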
diff --git a/scripts/tools/list-integrations.py b/scripts/tools/list-integrations.py
index 73c514c7c1ee3..ebfbc67be58fb 100755
--- a/scripts/tools/list-integrations.py
+++ b/scripts/tools/list-integrations.py
@@ -96,9 +96,7 @@ def _find_clazzes(directory, base_class):
"""
parser = argparse.ArgumentParser( # noqa
- description=HELP,
- formatter_class=argparse.RawTextHelpFormatter,
- epilog=EPILOG
+ description=HELP, formatter_class=argparse.RawTextHelpFormatter, epilog=EPILOG
)
# argparse handle `-h/--help/` internally
parser.parse_args()
@@ -111,8 +109,9 @@ def _find_clazzes(directory, base_class):
}
for integration_base_directory, integration_class in RESOURCE_TYPES.items():
- for integration_directory in glob(f"{AIRFLOW_ROOT}/airflow/**/{integration_base_directory}",
- recursive=True):
+ for integration_directory in glob(
+ f"{AIRFLOW_ROOT}/airflow/**/{integration_base_directory}", recursive=True
+ ):
if "contrib" in integration_directory:
continue
diff --git a/setup.py b/setup.py
index f0872b0c8d8bc..992ee1f438e78 100644
--- a/setup.py
+++ b/setup.py
@@ -131,6 +131,7 @@ def git_version(version_: str) -> str:
"""
try:
import git
+
try:
repo = git.Repo(os.path.join(*[my_dir, '.git']))
except git.NoSuchPathError:
@@ -208,10 +209,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
cloudant = [
'cloudant>=2.0',
]
-dask = [
- 'cloudpickle>=1.4.1, <1.5.0',
- 'distributed>=2.11.1, <2.20'
-]
+dask = ['cloudpickle>=1.4.1, <1.5.0', 'distributed>=2.11.1, <2.20']
databricks = [
'requests>=2.20.0, <3',
]
@@ -227,7 +225,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
'sphinx-rtd-theme>=0.1.6',
'sphinxcontrib-httpdomain>=1.7.0',
"sphinxcontrib-redoc>=1.6.0",
- "sphinxcontrib-spelling==5.2.1"
+ "sphinxcontrib-spelling==5.2.1",
]
docker = [
'docker~=3.0',
@@ -316,9 +314,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
'cryptography>=2.0.0',
'kubernetes>=3.0.0, <12.0.0',
]
-kylin = [
- 'kylinpy>=2.6'
-]
+kylin = ['kylinpy>=2.6']
ldap = [
'ldap3>=2.5.1',
]
@@ -359,9 +355,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
postgres = [
'psycopg2-binary>=2.7.4',
]
-presto = [
- 'presto-python-client>=0.7.0,<0.8'
-]
+presto = ['presto-python-client>=0.7.0,<0.8']
qds = [
'qds-sdk>=1.10.4',
]
@@ -429,8 +423,21 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
]
# End dependencies group
-all_dbs = (cassandra + cloudant + druid + exasol + hdfs + hive + mongo + mssql + mysql +
- pinot + postgres + presto + vertica)
+all_dbs = (
+ cassandra
+ + cloudant
+ + druid
+ + exasol
+ + hdfs
+ + hive
+ + mongo
+ + mssql
+ + mysql
+ + pinot
+ + postgres
+ + presto
+ + vertica
+)
############################################################################################################
# IMPORTANT NOTE!!!!!!!!!!!!!!!
@@ -456,7 +463,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
'jira',
'mongomock',
'moto==1.3.14', # TODO - fix Datasync issues to get higher version of moto:
- # See: https://github.com/apache/airflow/issues/10985
+ # See: https://github.com/apache/airflow/issues/10985
'parameterized',
'paramiko',
'pipdeptree',
@@ -599,7 +606,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
'jdbc': jdbc,
'jira': jira,
'kerberos': kerberos,
- 'kubernetes': kubernetes, # TODO: remove this in Airflow 2.1
+ 'kubernetes': kubernetes, # TODO: remove this in Airflow 2.1
'ldap': ldap,
"microsoft.azure": azure,
"microsoft.mssql": mssql,
@@ -639,16 +646,22 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
}
# Make devel_all contain all providers + extras + unique
-devel_all = list(set(devel +
- [req for req_list in EXTRAS_REQUIREMENTS.values() for req in req_list] +
- [req for req_list in PROVIDERS_REQUIREMENTS.values() for req in req_list]))
+devel_all = list(
+ set(
+ devel
+ + [req for req_list in EXTRAS_REQUIREMENTS.values() for req in req_list]
+ + [req for req_list in PROVIDERS_REQUIREMENTS.values() for req in req_list]
+ )
+)
PACKAGES_EXCLUDED_FOR_ALL = []
if PY3:
- PACKAGES_EXCLUDED_FOR_ALL.extend([
- 'snakebite',
- ])
+ PACKAGES_EXCLUDED_FOR_ALL.extend(
+ [
+ 'snakebite',
+ ]
+ )
# Those packages are excluded because they break tests (downgrading mock) and they are
# not needed to run our test suite.
@@ -668,13 +681,17 @@ def is_package_excluded(package: str, exclusion_list: List[str]):
return any(package.startswith(excluded_package) for excluded_package in exclusion_list)
-devel_all = [package for package in devel_all if not is_package_excluded(
- package=package,
- exclusion_list=PACKAGES_EXCLUDED_FOR_ALL)
+devel_all = [
+ package
+ for package in devel_all
+ if not is_package_excluded(package=package, exclusion_list=PACKAGES_EXCLUDED_FOR_ALL)
]
-devel_ci = [package for package in devel_all if not is_package_excluded(
- package=package,
- exclusion_list=PACKAGES_EXCLUDED_FOR_CI + PACKAGES_EXCLUDED_FOR_ALL)
+devel_ci = [
+ package
+ for package in devel_all
+ if not is_package_excluded(
+ package=package, exclusion_list=PACKAGES_EXCLUDED_FOR_CI + PACKAGES_EXCLUDED_FOR_ALL
+ )
]
EXTRAS_REQUIREMENTS.update(
@@ -748,9 +765,11 @@ def is_package_excluded(package: str, exclusion_list: List[str]):
def do_setup():
"""Perform the Airflow package setup."""
install_providers_from_sources = os.getenv('INSTALL_PROVIDERS_FROM_SOURCES')
- exclude_patterns = \
- [] if install_providers_from_sources and install_providers_from_sources == 'true' \
+ exclude_patterns = (
+ []
+ if install_providers_from_sources and install_providers_from_sources == 'true'
else ['airflow.providers', 'airflow.providers.*']
+ )
write_version()
setup(
name='apache-airflow',
@@ -759,13 +778,15 @@ def do_setup():
long_description_content_type='text/markdown',
license='Apache License 2.0',
version=version,
- packages=find_packages(
- include=['airflow*'],
- exclude=exclude_patterns),
+ packages=find_packages(include=['airflow*'], exclude=exclude_patterns),
package_data={
'airflow': ['py.typed'],
- '': ['airflow/alembic.ini', "airflow/git_version", "*.ipynb",
- "airflow/providers/cncf/kubernetes/example_dags/*.yaml"],
+ '': [
+ 'airflow/alembic.ini',
+ "airflow/git_version",
+ "*.ipynb",
+ "airflow/providers/cncf/kubernetes/example_dags/*.yaml",
+ ],
'airflow.api_connexion.openapi': ['*.yaml'],
'airflow.serialization': ["*.json"],
},
@@ -800,8 +821,7 @@ def do_setup():
author='Apache Software Foundation',
author_email='dev@airflow.apache.org',
url='http://airflow.apache.org/',
- download_url=(
- 'https://archive.apache.org/dist/airflow/' + version),
+ download_url=('https://archive.apache.org/dist/airflow/' + version),
cmdclass={
'extra_clean': CleanCommand,
'compile_assets': CompileAssets,
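# Illustrative sketch (not part of the patch): in the setup.py hunks above, Black collapses short
# requirement lists onto one line and, when an expression such as the all_dbs sum no longer fits
# (the real one has thirteen operands), wraps it in parentheses with one operand per line. The
# reduced example below only mirrors that shape; the names and versions are placeholders.
cassandra = ['cassandra-driver>=3.13.0']
cloudant = ['cloudant>=2.0']
presto = ['presto-python-client>=0.7.0,<0.8']

all_dbs = (
    cassandra
    + cloudant
    + presto
)
print(all_dbs)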
diff --git a/tests/airflow_pylint/disable_checks_for_tests.py b/tests/airflow_pylint/disable_checks_for_tests.py
index e7b23b2f9abe3..cbdf9b2ce936f 100644
--- a/tests/airflow_pylint/disable_checks_for_tests.py
+++ b/tests/airflow_pylint/disable_checks_for_tests.py
@@ -19,8 +19,9 @@
from astroid import MANAGER, scoped_nodes
from pylint.lint import PyLinter
-DISABLED_CHECKS_FOR_TESTS = \
+DISABLED_CHECKS_FOR_TESTS = (
"missing-docstring, no-self-use, too-many-public-methods, protected-access, do-not-use-asserts"
+)
def register(_: PyLinter):
@@ -42,18 +43,21 @@ def transform(mod):
:param mod: astroid module
:return: None
"""
- if mod.name.startswith("test_") or \
- mod.name.startswith("tests.") or \
- mod.name.startswith("kubernetes_tests.") or \
- mod.name.startswith("chart."):
+ if (
+ mod.name.startswith("test_")
+ or mod.name.startswith("tests.")
+ or mod.name.startswith("kubernetes_tests.")
+ or mod.name.startswith("chart.")
+ ):
decoded_lines = mod.stream().read().decode("utf-8").split("\n")
if decoded_lines[0].startswith("# pylint: disable="):
decoded_lines[0] = decoded_lines[0] + " " + DISABLED_CHECKS_FOR_TESTS
elif decoded_lines[0].startswith("#") or decoded_lines[0].strip() == "":
decoded_lines[0] = "# pylint: disable=" + DISABLED_CHECKS_FOR_TESTS
else:
- raise Exception(f"The first line of module {mod.name} is not a comment or empty. "
- f"Please make sure it is!")
+ raise Exception(
+ f"The first line of module {mod.name} is not a comment or empty. " f"Please make sure it is!"
+ )
# pylint will read from `.file_bytes` attribute later when tokenization
mod.file_bytes = "\n".join(decoded_lines).encode("utf-8")
diff --git a/tests/airflow_pylint/do_not_use_asserts.py b/tests/airflow_pylint/do_not_use_asserts.py
index e042c694bad38..47a0e208b3862 100644
--- a/tests/airflow_pylint/do_not_use_asserts.py
+++ b/tests/airflow_pylint/do_not_use_asserts.py
@@ -29,13 +29,14 @@ class DoNotUseAssertsChecker(BaseChecker):
'E7401': (
'Do not use asserts.',
'do-not-use-asserts',
- 'Asserts should not be used in the main Airflow code.'
+ 'Asserts should not be used in the main Airflow code.',
),
}
def visit_assert(self, node):
self.add_message(
- self.name, node=node,
+ self.name,
+ node=node,
)
diff --git a/tests/always/test_example_dags.py b/tests/always/test_example_dags.py
index 0261a4370c389..339a42a4c2f47 100644
--- a/tests/always/test_example_dags.py
+++ b/tests/always/test_example_dags.py
@@ -26,9 +26,7 @@
os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir)
)
-NO_DB_QUERY_EXCEPTION = [
- "/airflow/example_dags/example_subdag_operator.py"
-]
+NO_DB_QUERY_EXCEPTION = ["/airflow/example_dags/example_subdag_operator.py"]
class TestExampleDags(unittest.TestCase):
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 6da4a53649420..bcde9bb9afe1e 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -57,8 +57,12 @@ def test_providers_modules_should_have_tests(self):
Assert every module in /airflow/providers has a corresponding test_ file in tests/airflow/providers.
"""
# Deprecated modules that don't have corresponded test
- expected_missing_providers_modules = {('airflow/providers/amazon/aws/hooks/aws_dynamodb.py',
- 'tests/providers/amazon/aws/hooks/test_aws_dynamodb.py')}
+ expected_missing_providers_modules = {
+ (
+ 'airflow/providers/amazon/aws/hooks/aws_dynamodb.py',
+ 'tests/providers/amazon/aws/hooks/test_aws_dynamodb.py',
+ )
+ }
# TODO: Should we extend this test to cover other directories?
modules_files = glob.glob(f"{ROOT_FOLDER}/airflow/providers/**/*.py", recursive=True)
@@ -71,13 +75,13 @@ def test_providers_modules_should_have_tests(self):
modules_files = (f for f in modules_files if not f.endswith("__init__.py"))
# Change airflow/ to tests/
expected_test_files = (
- f'tests/{f.partition("/")[2]}'
- for f in modules_files if not f.endswith("__init__.py")
+ f'tests/{f.partition("/")[2]}' for f in modules_files if not f.endswith("__init__.py")
)
# Add test_ prefix to filename
expected_test_files = (
f'{f.rpartition("/")[0]}/test_{f.rpartition("/")[2]}'
- for f in expected_test_files if not f.endswith("__init__.py")
+ for f in expected_test_files
+ if not f.endswith("__init__.py")
)
current_test_files = glob.glob(f"{ROOT_FOLDER}/tests/providers/**/*.py", recursive=True)
@@ -126,9 +130,7 @@ def test_example_dags(self):
)
example_dags_files = self.find_resource_files(resource_type="example_dags")
# Generate tuple of department and service e.g. ('marketing_platform', 'display_video')
- operator_sets = [
- (f.split("/")[-3], f.split("/")[-1].rsplit(".")[0]) for f in operators_modules
- ]
+ operator_sets = [(f.split("/")[-3], f.split("/")[-1].rsplit(".")[0]) for f in operators_modules]
example_sets = [
(f.split("/")[-3], f.split("/")[-1].rsplit(".")[0].replace("example_", "", 1))
for f in example_dags_files
@@ -174,7 +176,10 @@ def has_example_dag(operator_set):
@parameterized.expand(
[
- (resource_type, suffix,)
+ (
+ resource_type,
+ suffix,
+ )
for suffix in ["_system.py", "_system_helper.py"]
for resource_type in ["operators", "sensors", "tranfers"]
]
@@ -186,12 +191,8 @@ def test_detect_invalid_system_tests(self, resource_type, filename_suffix):
files = {f for f in operators_tests if f.endswith(filename_suffix)}
expected_files = (f"tests/{f[8:]}" for f in operators_files)
- expected_files = (
- f.replace(".py", filename_suffix).replace("/test_", "/") for f in expected_files
- )
- expected_files = {
- f'{f.rpartition("/")[0]}/test_{f.rpartition("/")[2]}' for f in expected_files
- }
+ expected_files = (f.replace(".py", filename_suffix).replace("/test_", "/") for f in expected_files)
+ expected_files = {f'{f.rpartition("/")[0]}/test_{f.rpartition("/")[2]}' for f in expected_files}
self.assertEqual(set(), files - expected_files)
@@ -200,7 +201,7 @@ def find_resource_files(
top_level_directory: str = "airflow",
department: str = "*",
resource_type: str = "*",
- service: str = "*"
+ service: str = "*",
):
python_files = glob.glob(
f"{ROOT_FOLDER}/{top_level_directory}/providers/google/{department}/{resource_type}/{service}.py"
@@ -215,16 +216,14 @@ def find_resource_files(
class TestOperatorsHooks(unittest.TestCase):
def test_no_illegal_suffixes(self):
illegal_suffixes = ["_operator.py", "_hook.py", "_sensor.py"]
- files = itertools.chain(*[
- glob.glob(f"{ROOT_FOLDER}/{part}/providers/**/{resource_type}/*.py", recursive=True)
- for resource_type in ["operators", "hooks", "sensors", "example_dags"]
- for part in ["airflow", "tests"]
- ])
-
- invalid_files = [
- f
- for f in files
- if any(f.endswith(suffix) for suffix in illegal_suffixes)
- ]
+ files = itertools.chain(
+ *[
+ glob.glob(f"{ROOT_FOLDER}/{part}/providers/**/{resource_type}/*.py", recursive=True)
+ for resource_type in ["operators", "hooks", "sensors", "example_dags"]
+ for part in ["airflow", "tests"]
+ ]
+ )
+
+ invalid_files = [f for f in files if any(f.endswith(suffix) for suffix in illegal_suffixes)]
self.assertEqual([], invalid_files)
diff --git a/tests/api/auth/backend/test_basic_auth.py b/tests/api/auth/backend/test_basic_auth.py
index 06461347384a3..4cf427780e201 100644
--- a/tests/api/auth/backend/test_basic_auth.py
+++ b/tests/api/auth/backend/test_basic_auth.py
@@ -29,9 +29,7 @@
class TestBasicAuth(unittest.TestCase):
def setUp(self) -> None:
- with conf_vars(
- {("api", "auth_backend"): "airflow.api.auth.backend.basic_auth"}
- ):
+ with conf_vars({("api", "auth_backend"): "airflow.api.auth.backend.basic_auth"}):
self.app = create_app(testing=True)
self.appbuilder = self.app.appbuilder # pylint: disable=no-member
@@ -52,9 +50,7 @@ def test_success(self):
clear_db_pools()
with self.app.test_client() as test_client:
- response = test_client.get(
- "/api/v1/pools", headers={"Authorization": token}
- )
+ response = test_client.get("/api/v1/pools", headers={"Authorization": token})
assert current_user.email == "test@fab.org"
assert response.status_code == 200
@@ -72,37 +68,37 @@ def test_success(self):
"total_entries": 1,
}
- @parameterized.expand([
- ("basic",),
- ("basic ",),
- ("bearer",),
- ("test:test",),
- (b64encode(b"test:test").decode(),),
- ("bearer ",),
- ("basic: ",),
- ("basic 123",),
- ])
+ @parameterized.expand(
+ [
+ ("basic",),
+ ("basic ",),
+ ("bearer",),
+ ("test:test",),
+ (b64encode(b"test:test").decode(),),
+ ("bearer ",),
+ ("basic: ",),
+ ("basic 123",),
+ ]
+ )
def test_malformed_headers(self, token):
with self.app.test_client() as test_client:
- response = test_client.get(
- "/api/v1/pools", headers={"Authorization": token}
- )
+ response = test_client.get("/api/v1/pools", headers={"Authorization": token})
assert response.status_code == 401
assert response.headers["Content-Type"] == "application/problem+json"
assert response.headers["WWW-Authenticate"] == "Basic"
assert_401(response)
- @parameterized.expand([
- ("basic " + b64encode(b"test").decode(),),
- ("basic " + b64encode(b"test:").decode(),),
- ("basic " + b64encode(b"test:123").decode(),),
- ("basic " + b64encode(b"test test").decode(),),
- ])
+ @parameterized.expand(
+ [
+ ("basic " + b64encode(b"test").decode(),),
+ ("basic " + b64encode(b"test:").decode(),),
+ ("basic " + b64encode(b"test:123").decode(),),
+ ("basic " + b64encode(b"test test").decode(),),
+ ]
+ )
def test_invalid_auth_header(self, token):
with self.app.test_client() as test_client:
- response = test_client.get(
- "/api/v1/pools", headers={"Authorization": token}
- )
+ response = test_client.get("/api/v1/pools", headers={"Authorization": token})
assert response.status_code == 401
assert response.headers["Content-Type"] == "application/problem+json"
assert response.headers["WWW-Authenticate"] == "Basic"
@@ -110,9 +106,7 @@ def test_invalid_auth_header(self, token):
def test_experimental_api(self):
with self.app.test_client() as test_client:
- response = test_client.get(
- "/api/experimental/pools", headers={"Authorization": "Basic"}
- )
+ response = test_client.get("/api/experimental/pools", headers={"Authorization": "Basic"})
assert response.status_code == 401
assert response.headers["WWW-Authenticate"] == "Basic"
assert response.data == b'Unauthorized'
@@ -120,7 +114,7 @@ def test_experimental_api(self):
clear_db_pools()
response = test_client.get(
"/api/experimental/pools",
- headers={"Authorization": "Basic " + b64encode(b"test:test").decode()}
+ headers={"Authorization": "Basic " + b64encode(b"test:test").decode()},
)
assert response.status_code == 200
assert response.json[0]["pool"] == 'default_pool'
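# Illustrative sketch (not part of the patch): the test hunks above show how Black indents a list
# literal passed to a decorator such as @parameterized.expand. The _expand decorator below is a
# hypothetical stand-in (it just attaches the cases to the function) so the snippet runs without
# the parameterized package.
def _expand(cases):
    def wrapper(func):
        func.cases = cases
        return func

    return wrapper


@_expand(
    [
        ("basic",),
        ("bearer",),
        ("basic 123",),
    ]
)
def check_header(token):
    return token.lower().startswith("basic")


print([check_header(token) for (token,) in check_header.cases])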
diff --git a/tests/api/auth/test_client.py b/tests/api/auth/test_client.py
index 3275f6e58c03f..8652b12772b5f 100644
--- a/tests/api/auth/test_client.py
+++ b/tests/api/auth/test_client.py
@@ -23,14 +23,15 @@
class TestGetCurrentApiClient(unittest.TestCase):
-
@mock.patch("airflow.api.client.json_client.Client")
@mock.patch("airflow.api.auth.backend.default.CLIENT_AUTH", "CLIENT_AUTH")
- @conf_vars({
- ("api", 'auth_backend'): 'airflow.api.auth.backend.default',
- ("cli", 'api_client'): 'airflow.api.client.json_client',
- ("cli", 'endpoint_url'): 'http://localhost:1234',
- })
+ @conf_vars(
+ {
+ ("api", 'auth_backend'): 'airflow.api.auth.backend.default',
+ ("cli", 'api_client'): 'airflow.api.client.json_client',
+ ("cli", 'endpoint_url'): 'http://localhost:1234',
+ }
+ )
def test_should_create_client(self, mock_client):
result = get_current_api_client()
@@ -41,17 +42,17 @@ def test_should_create_client(self, mock_client):
@mock.patch("airflow.api.client.json_client.Client")
@mock.patch("airflow.providers.google.common.auth_backend.google_openid.create_client_session")
- @conf_vars({
- ("api", 'auth_backend'): 'airflow.providers.google.common.auth_backend.google_openid',
- ("cli", 'api_client'): 'airflow.api.client.json_client',
- ("cli", 'endpoint_url'): 'http://localhost:1234',
- })
+ @conf_vars(
+ {
+ ("api", 'auth_backend'): 'airflow.providers.google.common.auth_backend.google_openid',
+ ("cli", 'api_client'): 'airflow.api.client.json_client',
+ ("cli", 'endpoint_url'): 'http://localhost:1234',
+ }
+ )
def test_should_create_google_open_id_client(self, mock_create_client_session, mock_client):
result = get_current_api_client()
mock_client.assert_called_once_with(
- api_base_url='http://localhost:1234',
- auth=None,
- session=mock_create_client_session.return_value
+ api_base_url='http://localhost:1234', auth=None, session=mock_create_client_session.return_value
)
self.assertEqual(mock_client.return_value, result)
diff --git a/tests/api/client/test_local_client.py b/tests/api/client/test_local_client.py
index d3260a2cba016..d574615a243cf 100644
--- a/tests/api/client/test_local_client.py
+++ b/tests/api/client/test_local_client.py
@@ -38,7 +38,6 @@
class TestLocalClient(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
super().setUpClass()
@@ -67,44 +66,52 @@ def test_trigger_dag(self, mock):
with freeze_time(EXECDATE):
# no execution date, execution date should be set automatically
self.client.trigger_dag(dag_id=test_dag_id)
- mock.assert_called_once_with(run_id=run_id,
- execution_date=EXECDATE_NOFRACTIONS,
- state=State.RUNNING,
- conf=None,
- external_trigger=True,
- dag_hash=ANY)
+ mock.assert_called_once_with(
+ run_id=run_id,
+ execution_date=EXECDATE_NOFRACTIONS,
+ state=State.RUNNING,
+ conf=None,
+ external_trigger=True,
+ dag_hash=ANY,
+ )
mock.reset_mock()
# execution date with microseconds cutoff
self.client.trigger_dag(dag_id=test_dag_id, execution_date=EXECDATE)
- mock.assert_called_once_with(run_id=run_id,
- execution_date=EXECDATE_NOFRACTIONS,
- state=State.RUNNING,
- conf=None,
- external_trigger=True,
- dag_hash=ANY)
+ mock.assert_called_once_with(
+ run_id=run_id,
+ execution_date=EXECDATE_NOFRACTIONS,
+ state=State.RUNNING,
+ conf=None,
+ external_trigger=True,
+ dag_hash=ANY,
+ )
mock.reset_mock()
# run id
custom_run_id = "my_run_id"
self.client.trigger_dag(dag_id=test_dag_id, run_id=custom_run_id)
- mock.assert_called_once_with(run_id=custom_run_id,
- execution_date=EXECDATE_NOFRACTIONS,
- state=State.RUNNING,
- conf=None,
- external_trigger=True,
- dag_hash=ANY)
+ mock.assert_called_once_with(
+ run_id=custom_run_id,
+ execution_date=EXECDATE_NOFRACTIONS,
+ state=State.RUNNING,
+ conf=None,
+ external_trigger=True,
+ dag_hash=ANY,
+ )
mock.reset_mock()
# test conf
conf = '{"name": "John"}'
self.client.trigger_dag(dag_id=test_dag_id, conf=conf)
- mock.assert_called_once_with(run_id=run_id,
- execution_date=EXECDATE_NOFRACTIONS,
- state=State.RUNNING,
- conf=json.loads(conf),
- external_trigger=True,
- dag_hash=ANY)
+ mock.assert_called_once_with(
+ run_id=run_id,
+ execution_date=EXECDATE_NOFRACTIONS,
+ state=State.RUNNING,
+ conf=json.loads(conf),
+ external_trigger=True,
+ dag_hash=ANY,
+ )
mock.reset_mock()
def test_delete_dag(self):
@@ -129,8 +136,7 @@ def test_get_pools(self):
self.client.create_pool(name='foo1', slots=1, description='')
self.client.create_pool(name='foo2', slots=2, description='')
pools = sorted(self.client.get_pools(), key=lambda p: p[0])
- self.assertEqual(pools, [('default_pool', 128, 'Default pool'),
- ('foo1', 1, ''), ('foo2', 2, '')])
+ self.assertEqual(pools, [('default_pool', 128, 'Default pool'), ('foo1', 1, ''), ('foo2', 2, '')])
def test_create_pool(self):
pool = self.client.create_pool(name='foo', slots=1, description='')
diff --git a/tests/api/common/experimental/test_delete_dag.py b/tests/api/common/experimental/test_delete_dag.py
index 00058713c63a0..471138f9f5cd3 100644
--- a/tests/api/common/experimental/test_delete_dag.py
+++ b/tests/api/common/experimental/test_delete_dag.py
@@ -37,7 +37,6 @@
class TestDeleteDAGCatchError(unittest.TestCase):
-
def setUp(self):
self.dagbag = models.DagBag(include_examples=True)
self.dag_id = 'example_bash_operator'
@@ -59,30 +58,47 @@ def setup_dag_models(self, for_sub_dag=False):
if for_sub_dag:
self.key = "test_dag_id.test_subdag"
- task = DummyOperator(task_id='dummy',
- dag=models.DAG(dag_id=self.key,
- default_args={'start_date': days_ago(2)}),
- owner='airflow')
+ task = DummyOperator(
+ task_id='dummy',
+ dag=models.DAG(dag_id=self.key, default_args={'start_date': days_ago(2)}),
+ owner='airflow',
+ )
test_date = days_ago(1)
with create_session() as session:
session.add(DM(dag_id=self.key, fileloc=self.dag_file_path, is_subdag=for_sub_dag))
session.add(DR(dag_id=self.key, run_type=DagRunType.MANUAL))
- session.add(TI(task=task,
- execution_date=test_date,
- state=State.SUCCESS))
+ session.add(TI(task=task, execution_date=test_date, state=State.SUCCESS))
# flush to ensure task instance if written before
# task reschedule because of FK constraint
session.flush()
- session.add(LOG(dag_id=self.key, task_id=None, task_instance=None,
- execution_date=test_date, event="varimport"))
- session.add(TF(task=task, execution_date=test_date,
- start_date=test_date, end_date=test_date))
- session.add(TR(task=task, execution_date=test_date,
- start_date=test_date, end_date=test_date,
- try_number=1, reschedule_date=test_date))
- session.add(IE(timestamp=test_date, filename=self.dag_file_path,
- stacktrace="NameError: name 'airflow' is not defined"))
+ session.add(
+ LOG(
+ dag_id=self.key,
+ task_id=None,
+ task_instance=None,
+ execution_date=test_date,
+ event="varimport",
+ )
+ )
+ session.add(TF(task=task, execution_date=test_date, start_date=test_date, end_date=test_date))
+ session.add(
+ TR(
+ task=task,
+ execution_date=test_date,
+ start_date=test_date,
+ end_date=test_date,
+ try_number=1,
+ reschedule_date=test_date,
+ )
+ )
+ session.add(
+ IE(
+ timestamp=test_date,
+ filename=self.dag_file_path,
+ stacktrace="NameError: name 'airflow' is not defined",
+ )
+ )
def tearDown(self):
with create_session() as session:
@@ -102,8 +118,7 @@ def check_dag_models_exists(self):
self.assertEqual(session.query(TF).filter(TF.dag_id == self.key).count(), 1)
self.assertEqual(session.query(TR).filter(TR.dag_id == self.key).count(), 1)
self.assertEqual(session.query(LOG).filter(LOG.dag_id == self.key).count(), 1)
- self.assertEqual(
- session.query(IE).filter(IE.filename == self.dag_file_path).count(), 1)
+ self.assertEqual(session.query(IE).filter(IE.filename == self.dag_file_path).count(), 1)
def check_dag_models_removed(self, expect_logs=1):
with create_session() as session:
@@ -113,8 +128,7 @@ def check_dag_models_removed(self, expect_logs=1):
self.assertEqual(session.query(TF).filter(TF.dag_id == self.key).count(), 0)
self.assertEqual(session.query(TR).filter(TR.dag_id == self.key).count(), 0)
self.assertEqual(session.query(LOG).filter(LOG.dag_id == self.key).count(), expect_logs)
- self.assertEqual(
- session.query(IE).filter(IE.filename == self.dag_file_path).count(), 0)
+ self.assertEqual(session.query(IE).filter(IE.filename == self.dag_file_path).count(), 0)
def test_delete_dag_successful_delete(self):
self.setup_dag_models()
diff --git a/tests/api/common/experimental/test_mark_tasks.py b/tests/api/common/experimental/test_mark_tasks.py
index dde25f5b857b2..4ff093100c93f 100644
--- a/tests/api/common/experimental/test_mark_tasks.py
+++ b/tests/api/common/experimental/test_mark_tasks.py
@@ -24,7 +24,10 @@
from airflow import models
from airflow.api.common.experimental.mark_tasks import (
- _create_dagruns, set_dag_run_state_to_failed, set_dag_run_state_to_running, set_dag_run_state_to_success,
+ _create_dagruns,
+ set_dag_run_state_to_failed,
+ set_dag_run_state_to_running,
+ set_dag_run_state_to_success,
set_state,
)
from airflow.models import DagRun
@@ -39,7 +42,6 @@
class TestMarkTasks(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
models.DagBag(include_examples=True, read_dags_from_db=False).sync_to_db()
@@ -53,31 +55,32 @@ def setUpClass(cls):
cls.dag3.sync_to_db()
cls.execution_dates = [days_ago(2), days_ago(1)]
start_date3 = cls.dag3.start_date
- cls.dag3_execution_dates = [start_date3, start_date3 + timedelta(days=1),
- start_date3 + timedelta(days=2)]
+ cls.dag3_execution_dates = [
+ start_date3,
+ start_date3 + timedelta(days=1),
+ start_date3 + timedelta(days=2),
+ ]
def setUp(self):
clear_db_runs()
- drs = _create_dagruns(self.dag1, self.execution_dates,
- state=State.RUNNING,
- run_type=DagRunType.SCHEDULED)
+ drs = _create_dagruns(
+ self.dag1, self.execution_dates, state=State.RUNNING, run_type=DagRunType.SCHEDULED
+ )
for dr in drs:
dr.dag = self.dag1
dr.verify_integrity()
- drs = _create_dagruns(self.dag2,
- [self.dag2.start_date],
- state=State.RUNNING,
- run_type=DagRunType.SCHEDULED)
+ drs = _create_dagruns(
+ self.dag2, [self.dag2.start_date], state=State.RUNNING, run_type=DagRunType.SCHEDULED
+ )
for dr in drs:
dr.dag = self.dag2
dr.verify_integrity()
- drs = _create_dagruns(self.dag3,
- self.dag3_execution_dates,
- state=State.SUCCESS,
- run_type=DagRunType.MANUAL)
+ drs = _create_dagruns(
+ self.dag3, self.dag3_execution_dates, state=State.SUCCESS, run_type=DagRunType.MANUAL
+ )
for dr in drs:
dr.dag = self.dag3
dr.verify_integrity()
@@ -89,19 +92,17 @@ def tearDown(self):
def snapshot_state(dag, execution_dates):
TI = models.TaskInstance
with create_session() as session:
- return session.query(TI).filter(
- TI.dag_id == dag.dag_id,
- TI.execution_date.in_(execution_dates)
- ).all()
+ return (
+ session.query(TI)
+ .filter(TI.dag_id == dag.dag_id, TI.execution_date.in_(execution_dates))
+ .all()
+ )
@provide_session
def verify_state(self, dag, task_ids, execution_dates, state, old_tis, session=None):
TI = models.TaskInstance
- tis = session.query(TI).filter(
- TI.dag_id == dag.dag_id,
- TI.execution_date.in_(execution_dates)
- ).all()
+ tis = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.execution_date.in_(execution_dates)).all()
self.assertTrue(len(tis) > 0)
@@ -120,63 +121,97 @@ def test_mark_tasks_now(self):
# set one task to success but do not commit
snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates)
task = self.dag1.get_task("runme_1")
- altered = set_state(tasks=[task], execution_date=self.execution_dates[0],
- upstream=False, downstream=False, future=False,
- past=False, state=State.SUCCESS, commit=False)
+ altered = set_state(
+ tasks=[task],
+ execution_date=self.execution_dates[0],
+ upstream=False,
+ downstream=False,
+ future=False,
+ past=False,
+ state=State.SUCCESS,
+ commit=False,
+ )
self.assertEqual(len(altered), 1)
- self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
- None, snapshot)
+ self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], None, snapshot)
# set one and only one task to success
- altered = set_state(tasks=[task], execution_date=self.execution_dates[0],
- upstream=False, downstream=False, future=False,
- past=False, state=State.SUCCESS, commit=True)
+ altered = set_state(
+ tasks=[task],
+ execution_date=self.execution_dates[0],
+ upstream=False,
+ downstream=False,
+ future=False,
+ past=False,
+ state=State.SUCCESS,
+ commit=True,
+ )
self.assertEqual(len(altered), 1)
- self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
- State.SUCCESS, snapshot)
+ self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], State.SUCCESS, snapshot)
# set no tasks
- altered = set_state(tasks=[task], execution_date=self.execution_dates[0],
- upstream=False, downstream=False, future=False,
- past=False, state=State.SUCCESS, commit=True)
+ altered = set_state(
+ tasks=[task],
+ execution_date=self.execution_dates[0],
+ upstream=False,
+ downstream=False,
+ future=False,
+ past=False,
+ state=State.SUCCESS,
+ commit=True,
+ )
self.assertEqual(len(altered), 0)
- self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
- State.SUCCESS, snapshot)
+ self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], State.SUCCESS, snapshot)
# set task to other than success
- altered = set_state(tasks=[task], execution_date=self.execution_dates[0],
- upstream=False, downstream=False, future=False,
- past=False, state=State.FAILED, commit=True)
+ altered = set_state(
+ tasks=[task],
+ execution_date=self.execution_dates[0],
+ upstream=False,
+ downstream=False,
+ future=False,
+ past=False,
+ state=State.FAILED,
+ commit=True,
+ )
self.assertEqual(len(altered), 1)
- self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
- State.FAILED, snapshot)
+ self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], State.FAILED, snapshot)
# don't alter other tasks
snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates)
task = self.dag1.get_task("runme_0")
- altered = set_state(tasks=[task], execution_date=self.execution_dates[0],
- upstream=False, downstream=False, future=False,
- past=False, state=State.SUCCESS, commit=True)
+ altered = set_state(
+ tasks=[task],
+ execution_date=self.execution_dates[0],
+ upstream=False,
+ downstream=False,
+ future=False,
+ past=False,
+ state=State.SUCCESS,
+ commit=True,
+ )
self.assertEqual(len(altered), 1)
- self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
- State.SUCCESS, snapshot)
+ self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], State.SUCCESS, snapshot)
# set one task as FAILED. dag3 has schedule_interval None
snapshot = TestMarkTasks.snapshot_state(self.dag3, self.dag3_execution_dates)
task = self.dag3.get_task("run_this")
- altered = set_state(tasks=[task], execution_date=self.dag3_execution_dates[1],
- upstream=False, downstream=False, future=False,
- past=False, state=State.FAILED, commit=True)
+ altered = set_state(
+ tasks=[task],
+ execution_date=self.dag3_execution_dates[1],
+ upstream=False,
+ downstream=False,
+ future=False,
+ past=False,
+ state=State.FAILED,
+ commit=True,
+ )
# exactly one TaskInstance should have been altered
self.assertEqual(len(altered), 1)
# task should have been marked as failed
- self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[1]],
- State.FAILED, snapshot)
+ self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[1]], State.FAILED, snapshot)
# tasks on other days should be unchanged
- self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[0]],
- None, snapshot)
- self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[2]],
- None, snapshot)
+ self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[0]], None, snapshot)
+ self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[2]], None, snapshot)
def test_mark_downstream(self):
# test downstream
@@ -186,9 +221,16 @@ def test_mark_downstream(self):
task_ids = [t.task_id for t in relatives]
task_ids.append(task.task_id)
- altered = set_state(tasks=[task], execution_date=self.execution_dates[0],
- upstream=False, downstream=True, future=False,
- past=False, state=State.SUCCESS, commit=True)
+ altered = set_state(
+ tasks=[task],
+ execution_date=self.execution_dates[0],
+ upstream=False,
+ downstream=True,
+ future=False,
+ past=False,
+ state=State.SUCCESS,
+ commit=True,
+ )
self.assertEqual(len(altered), 3)
self.verify_state(self.dag1, task_ids, [self.execution_dates[0]], State.SUCCESS, snapshot)
@@ -200,28 +242,48 @@ def test_mark_upstream(self):
task_ids = [t.task_id for t in relatives]
task_ids.append(task.task_id)
- altered = set_state(tasks=[task], execution_date=self.execution_dates[0],
- upstream=True, downstream=False, future=False,
- past=False, state=State.SUCCESS, commit=True)
+ altered = set_state(
+ tasks=[task],
+ execution_date=self.execution_dates[0],
+ upstream=True,
+ downstream=False,
+ future=False,
+ past=False,
+ state=State.SUCCESS,
+ commit=True,
+ )
self.assertEqual(len(altered), 4)
- self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
- State.SUCCESS, snapshot)
+ self.verify_state(self.dag1, task_ids, [self.execution_dates[0]], State.SUCCESS, snapshot)
def test_mark_tasks_future(self):
# set one task to success towards end of scheduled dag runs
snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates)
task = self.dag1.get_task("runme_1")
- altered = set_state(tasks=[task], execution_date=self.execution_dates[0],
- upstream=False, downstream=False, future=True,
- past=False, state=State.SUCCESS, commit=True)
+ altered = set_state(
+ tasks=[task],
+ execution_date=self.execution_dates[0],
+ upstream=False,
+ downstream=False,
+ future=True,
+ past=False,
+ state=State.SUCCESS,
+ commit=True,
+ )
self.assertEqual(len(altered), 2)
self.verify_state(self.dag1, [task.task_id], self.execution_dates, State.SUCCESS, snapshot)
snapshot = TestMarkTasks.snapshot_state(self.dag3, self.dag3_execution_dates)
task = self.dag3.get_task("run_this")
- altered = set_state(tasks=[task], execution_date=self.dag3_execution_dates[1],
- upstream=False, downstream=False, future=True,
- past=False, state=State.FAILED, commit=True)
+ altered = set_state(
+ tasks=[task],
+ execution_date=self.dag3_execution_dates[1],
+ upstream=False,
+ downstream=False,
+ future=True,
+ past=False,
+ state=State.FAILED,
+ commit=True,
+ )
self.assertEqual(len(altered), 2)
self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[0]], None, snapshot)
self.verify_state(self.dag3, [task.task_id], self.dag3_execution_dates[1:], State.FAILED, snapshot)
@@ -230,17 +292,31 @@ def test_mark_tasks_past(self):
# set one task to success towards end of scheduled dag runs
snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates)
task = self.dag1.get_task("runme_1")
- altered = set_state(tasks=[task], execution_date=self.execution_dates[1],
- upstream=False, downstream=False, future=False,
- past=True, state=State.SUCCESS, commit=True)
+ altered = set_state(
+ tasks=[task],
+ execution_date=self.execution_dates[1],
+ upstream=False,
+ downstream=False,
+ future=False,
+ past=True,
+ state=State.SUCCESS,
+ commit=True,
+ )
self.assertEqual(len(altered), 2)
self.verify_state(self.dag1, [task.task_id], self.execution_dates, State.SUCCESS, snapshot)
snapshot = TestMarkTasks.snapshot_state(self.dag3, self.dag3_execution_dates)
task = self.dag3.get_task("run_this")
- altered = set_state(tasks=[task], execution_date=self.dag3_execution_dates[1],
- upstream=False, downstream=False, future=False,
- past=True, state=State.FAILED, commit=True)
+ altered = set_state(
+ tasks=[task],
+ execution_date=self.dag3_execution_dates[1],
+ upstream=False,
+ downstream=False,
+ future=False,
+ past=True,
+ state=State.FAILED,
+ commit=True,
+ )
self.assertEqual(len(altered), 2)
self.verify_state(self.dag3, [task.task_id], self.dag3_execution_dates[:2], State.FAILED, snapshot)
self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[2]], None, snapshot)
@@ -249,12 +325,20 @@ def test_mark_tasks_multiple(self):
# set multiple tasks to success
snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates)
tasks = [self.dag1.get_task("runme_1"), self.dag1.get_task("runme_2")]
- altered = set_state(tasks=tasks, execution_date=self.execution_dates[0],
- upstream=False, downstream=False, future=False,
- past=False, state=State.SUCCESS, commit=True)
+ altered = set_state(
+ tasks=tasks,
+ execution_date=self.execution_dates[0],
+ upstream=False,
+ downstream=False,
+ future=False,
+ past=False,
+ state=State.SUCCESS,
+ commit=True,
+ )
self.assertEqual(len(altered), 2)
- self.verify_state(self.dag1, [task.task_id for task in tasks], [self.execution_dates[0]],
- State.SUCCESS, snapshot)
+ self.verify_state(
+ self.dag1, [task.task_id for task in tasks], [self.execution_dates[0]], State.SUCCESS, snapshot
+ )
# TODO: this backend restriction should be removed once a proper fix is found
# We skip it here because this test case works with Postgres & SQLite
@@ -267,20 +351,25 @@ def test_mark_tasks_subdag(self):
task_ids = [t.task_id for t in relatives]
task_ids.append(task.task_id)
- altered = set_state(tasks=[task], execution_date=self.execution_dates[0],
- upstream=False, downstream=True, future=False,
- past=False, state=State.SUCCESS, commit=True)
+ altered = set_state(
+ tasks=[task],
+ execution_date=self.execution_dates[0],
+ upstream=False,
+ downstream=True,
+ future=False,
+ past=False,
+ state=State.SUCCESS,
+ commit=True,
+ )
self.assertEqual(len(altered), 14)
# cannot use snapshot here as that would require drilling down the
# subdag tree, essentially recreating the same code as in the
# tested logic.
- self.verify_state(self.dag2, task_ids, [self.execution_dates[0]],
- State.SUCCESS, [])
+ self.verify_state(self.dag2, task_ids, [self.execution_dates[0]], State.SUCCESS, [])
class TestMarkDAGRun(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
dagbag = models.DagBag(include_examples=True, read_dags_from_db=False)
@@ -318,17 +407,12 @@ def _verify_task_instance_states_remain_default(self, dr):
@provide_session
def _verify_task_instance_states(self, dag, date, state, session=None):
TI = models.TaskInstance
- tis = session.query(TI)\
- .filter(TI.dag_id == dag.dag_id, TI.execution_date == date)
+ tis = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.execution_date == date)
for ti in tis:
self.assertEqual(ti.state, state)
def _create_test_dag_run(self, state, date):
- return self.dag1.create_dagrun(
- run_type=DagRunType.MANUAL,
- state=state,
- execution_date=date
- )
+ return self.dag1.create_dagrun(run_type=DagRunType.MANUAL, state=state, execution_date=date)
def _verify_dag_run_state(self, dag, date, state):
drs = models.DagRun.find(dag_id=dag.dag_id, execution_date=date)
@@ -341,10 +425,7 @@ def _verify_dag_run_dates(self, dag, date, state, middle_time, session=None):
# When target state is RUNNING, we should set start_date,
# otherwise we should set end_date.
DR = DagRun
- dr = session.query(DR).filter(
- DR.dag_id == dag.dag_id,
- DR.execution_date == date
- ).one()
+ dr = session.query(DR).filter(DR.dag_id == dag.dag_id, DR.execution_date == date).one()
if state == State.RUNNING:
# Since the DAG is running, the start_date must be updated after creation
self.assertGreater(dr.start_date, middle_time)
@@ -515,19 +596,19 @@ def test_set_state_with_multiple_dagruns(self, session=None):
run_type=DagRunType.MANUAL,
state=State.FAILED,
execution_date=self.execution_dates[0],
- session=session
+ session=session,
)
self.dag2.create_dagrun(
run_type=DagRunType.MANUAL,
state=State.FAILED,
execution_date=self.execution_dates[1],
- session=session
+ session=session,
)
self.dag2.create_dagrun(
run_type=DagRunType.MANUAL,
state=State.RUNNING,
execution_date=self.execution_dates[2],
- session=session
+ session=session,
)
altered = set_dag_run_state_to_success(self.dag2, self.execution_dates[1], commit=True)
@@ -543,11 +624,9 @@ def count_dag_tasks(dag):
self._verify_dag_run_state(self.dag2, self.execution_dates[1], State.SUCCESS)
# Make sure the other dag runs' states are not changed
- models.DagRun.find(dag_id=self.dag2.dag_id,
- execution_date=self.execution_dates[0])
+ models.DagRun.find(dag_id=self.dag2.dag_id, execution_date=self.execution_dates[0])
self._verify_dag_run_state(self.dag2, self.execution_dates[0], State.FAILED)
- models.DagRun.find(dag_id=self.dag2.dag_id,
- execution_date=self.execution_dates[2])
+ models.DagRun.find(dag_id=self.dag2.dag_id, execution_date=self.execution_dates[2])
self._verify_dag_run_state(self.dag2, self.execution_dates[2], State.RUNNING)
def test_set_dag_run_state_edge_cases(self):
@@ -569,13 +648,13 @@ def test_set_dag_run_state_edge_cases(self):
# This will throw ValueError since the execution date passed in is
# naive (timezone.make_naive strips the timezone info).
- self.assertRaises(ValueError, set_dag_run_state_to_success, self.dag2,
- timezone.make_naive(self.execution_dates[0]))
+ self.assertRaises(
+ ValueError, set_dag_run_state_to_success, self.dag2, timezone.make_naive(self.execution_dates[0])
+ )
# altered = set_dag_run_state_to_success(self.dag1, self.execution_dates[0])
# DagRun does not exist
# This will throw ValueError since dag.latest_execution_date does not exist
- self.assertRaises(ValueError, set_dag_run_state_to_success,
- self.dag2, self.execution_dates[0])
+ self.assertRaises(ValueError, set_dag_run_state_to_success, self.dag2, self.execution_dates[0])
def test_set_dag_run_state_to_failed_no_running_tasks(self):
"""
diff --git a/tests/api/common/experimental/test_pool.py b/tests/api/common/experimental/test_pool.py
index d564ad2da9317..433823978a39b 100644
--- a/tests/api/common/experimental/test_pool.py
+++ b/tests/api/common/experimental/test_pool.py
@@ -52,28 +52,21 @@ def test_get_pool(self):
self.assertEqual(pool.pool, self.pools[0].pool)
def test_get_pool_non_existing(self):
- self.assertRaisesRegex(PoolNotFound,
- "^Pool 'test' doesn't exist$",
- pool_api.get_pool,
- name='test')
+ self.assertRaisesRegex(PoolNotFound, "^Pool 'test' doesn't exist$", pool_api.get_pool, name='test')
def test_get_pool_bad_name(self):
for name in ('', ' '):
- self.assertRaisesRegex(AirflowBadRequest,
- "^Pool name shouldn't be empty$",
- pool_api.get_pool,
- name=name)
+ self.assertRaisesRegex(
+ AirflowBadRequest, "^Pool name shouldn't be empty$", pool_api.get_pool, name=name
+ )
def test_get_pools(self):
- pools = sorted(pool_api.get_pools(),
- key=lambda p: p.pool)
+ pools = sorted(pool_api.get_pools(), key=lambda p: p.pool)
self.assertEqual(pools[0].pool, self.pools[0].pool)
self.assertEqual(pools[1].pool, self.pools[1].pool)
def test_create_pool(self):
- pool = pool_api.create_pool(name='foo',
- slots=5,
- description='')
+ pool = pool_api.create_pool(name='foo', slots=5, description='')
self.assertEqual(pool.pool, 'foo')
self.assertEqual(pool.slots, 5)
self.assertEqual(pool.description, '')
@@ -81,9 +74,7 @@ def test_create_pool(self):
self.assertEqual(session.query(models.Pool).count(), self.TOTAL_POOL_COUNT + 1)
def test_create_pool_existing(self):
- pool = pool_api.create_pool(name=self.pools[0].pool,
- slots=5,
- description='')
+ pool = pool_api.create_pool(name=self.pools[0].pool, slots=5, description='')
self.assertEqual(pool.pool, self.pools[0].pool)
self.assertEqual(pool.slots, 5)
self.assertEqual(pool.description, '')
@@ -92,30 +83,36 @@ def test_create_pool_existing(self):
def test_create_pool_bad_name(self):
for name in ('', ' '):
- self.assertRaisesRegex(AirflowBadRequest,
- "^Pool name shouldn't be empty$",
- pool_api.create_pool,
- name=name,
- slots=5,
- description='')
+ self.assertRaisesRegex(
+ AirflowBadRequest,
+ "^Pool name shouldn't be empty$",
+ pool_api.create_pool,
+ name=name,
+ slots=5,
+ description='',
+ )
def test_create_pool_name_too_long(self):
long_name = ''.join(random.choices(string.ascii_lowercase, k=300))
column_length = models.Pool.pool.property.columns[0].type.length
- self.assertRaisesRegex(AirflowBadRequest,
- "^Pool name can't be more than %d characters$" % column_length,
- pool_api.create_pool,
- name=long_name,
- slots=5,
- description='')
+ self.assertRaisesRegex(
+ AirflowBadRequest,
+ "^Pool name can't be more than %d characters$" % column_length,
+ pool_api.create_pool,
+ name=long_name,
+ slots=5,
+ description='',
+ )
def test_create_pool_bad_slots(self):
- self.assertRaisesRegex(AirflowBadRequest,
- "^Bad value for `slots`: foo$",
- pool_api.create_pool,
- name='foo',
- slots='foo',
- description='')
+ self.assertRaisesRegex(
+ AirflowBadRequest,
+ "^Bad value for `slots`: foo$",
+ pool_api.create_pool,
+ name='foo',
+ slots='foo',
+ description='',
+ )
def test_delete_pool(self):
pool = pool_api.delete_pool(name=self.pools[-1].pool)
@@ -124,19 +121,16 @@ def test_delete_pool(self):
self.assertEqual(session.query(models.Pool).count(), self.TOTAL_POOL_COUNT - 1)
def test_delete_pool_non_existing(self):
- self.assertRaisesRegex(pool_api.PoolNotFound,
- "^Pool 'test' doesn't exist$",
- pool_api.delete_pool,
- name='test')
+ self.assertRaisesRegex(
+ pool_api.PoolNotFound, "^Pool 'test' doesn't exist$", pool_api.delete_pool, name='test'
+ )
def test_delete_pool_bad_name(self):
for name in ('', ' '):
- self.assertRaisesRegex(AirflowBadRequest,
- "^Pool name shouldn't be empty$",
- pool_api.delete_pool,
- name=name)
+ self.assertRaisesRegex(
+ AirflowBadRequest, "^Pool name shouldn't be empty$", pool_api.delete_pool, name=name
+ )
def test_delete_default_pool_not_allowed(self):
- with self.assertRaisesRegex(AirflowBadRequest,
- "^default_pool cannot be deleted$"):
+ with self.assertRaisesRegex(AirflowBadRequest, "^default_pool cannot be deleted$"):
pool_api.delete_pool(Pool.DEFAULT_POOL_NAME)
diff --git a/tests/api/common/experimental/test_trigger_dag.py b/tests/api/common/experimental/test_trigger_dag.py
index 13ea2a32d2246..9fb772d3576f9 100644
--- a/tests/api/common/experimental/test_trigger_dag.py
+++ b/tests/api/common/experimental/test_trigger_dag.py
@@ -29,7 +29,6 @@
class TestTriggerDag(unittest.TestCase):
-
def setUp(self) -> None:
db.clear_db_runs()
@@ -107,11 +106,13 @@ def test_trigger_dag_with_valid_start_date(self, dag_bag_mock):
assert len(triggers) == 1
- @parameterized.expand([
- (None, {}),
- ({"foo": "bar"}, {"foo": "bar"}),
- ('{"foo": "bar"}', {"foo": "bar"}),
- ])
+ @parameterized.expand(
+ [
+ (None, {}),
+ ({"foo": "bar"}, {"foo": "bar"}),
+ ('{"foo": "bar"}', {"foo": "bar"}),
+ ]
+ )
@mock.patch('airflow.models.DagBag')
def test_trigger_dag_with_conf(self, conf, expected_conf, dag_bag_mock):
dag_id = "trigger_dag_with_conf"
diff --git a/tests/api_connexion/endpoints/test_config_endpoint.py b/tests/api_connexion/endpoints/test_config_endpoint.py
index 9acfaea2cf94f..5a6b14c9bec99 100644
--- a/tests/api_connexion/endpoints/test_config_endpoint.py
+++ b/tests/api_connexion/endpoints/test_config_endpoint.py
@@ -16,7 +16,6 @@
# under the License.
import textwrap
-
from unittest.mock import patch
from airflow.security import permissions
@@ -24,7 +23,6 @@
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.config import conf_vars
-
MOCK_CONF = {
'core': {
'parallelism': '1024',
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index 4bc78683b177b..34c5a388f264c 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -14,22 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import unittest
import datetime as dt
import getpass
+import unittest
from unittest import mock
from parameterized import parameterized
-from airflow.models import DagBag, DagRun, TaskInstance, SlaMiss
+from airflow.models import DagBag, DagRun, SlaMiss, TaskInstance
from airflow.security import permissions
-from airflow.utils.types import DagRunType
from airflow.utils.session import provide_session
from airflow.utils.state import State
from airflow.utils.timezone import datetime
+from airflow.utils.types import DagRunType
from airflow.www import app
+from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.config import conf_vars
-from tests.test_utils.api_connexion_utils import create_user, delete_user, assert_401
from tests.test_utils.db import clear_db_runs, clear_db_sla_miss
DEFAULT_DATETIME_1 = datetime(2020, 1, 1)
diff --git a/tests/api_connexion/endpoints/test_variable_endpoint.py b/tests/api_connexion/endpoints/test_variable_endpoint.py
index 9f768709236de..1c6cb657746de 100644
--- a/tests/api_connexion/endpoints/test_variable_endpoint.py
+++ b/tests/api_connexion/endpoints/test_variable_endpoint.py
@@ -18,9 +18,9 @@
from parameterized import parameterized
-from airflow.security import permissions
from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
from airflow.models import Variable
+from airflow.security import permissions
from airflow.www import app
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.config import conf_vars
diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py
index be1e0d2bc1ba5..92ef069719f59 100644
--- a/tests/api_connexion/endpoints/test_xcom_endpoint.py
+++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py
@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from datetime import timedelta
import unittest
+from datetime import timedelta
from parameterized import parameterized
diff --git a/tests/api_connexion/schemas/test_task_instance_schema.py b/tests/api_connexion/schemas/test_task_instance_schema.py
index c164aec8cdace..a2d1342e7228a 100644
--- a/tests/api_connexion/schemas/test_task_instance_schema.py
+++ b/tests/api_connexion/schemas/test_task_instance_schema.py
@@ -24,13 +24,13 @@
from airflow.api_connexion.schemas.task_instance_schema import (
clear_task_instance_form,
- task_instance_schema,
set_task_instance_state_form,
+ task_instance_schema,
)
from airflow.models import DAG, SlaMiss, TaskInstance as TI
from airflow.operators.dummy_operator import DummyOperator
-from airflow.utils.state import State
from airflow.utils.session import create_session, provide_session
+from airflow.utils.state import State
from airflow.utils.timezone import datetime
diff --git a/tests/build_provider_packages_dependencies.py b/tests/build_provider_packages_dependencies.py
index 9542759a73c92..3832bbc0e3000 100644
--- a/tests/build_provider_packages_dependencies.py
+++ b/tests/build_provider_packages_dependencies.py
@@ -67,8 +67,10 @@ def get_provider_from_file_name(file_name: str) -> Optional[str]:
:param file_name: name of the file
:return: provider name or None if no provider could be found
"""
- if AIRFLOW_PROVIDERS_FILE_PREFIX not in file_name and \
- AIRFLOW_TESTS_PROVIDERS_FILE_PREFIX not in file_name:
+ if (
+ AIRFLOW_PROVIDERS_FILE_PREFIX not in file_name
+ and AIRFLOW_TESTS_PROVIDERS_FILE_PREFIX not in file_name
+ ):
# We should only check files that belong to providers
errors.append(f"Wrong file not in the providers package = {file_name}")
return None
@@ -84,9 +86,9 @@ def get_provider_from_file_name(file_name: str) -> Optional[str]:
def get_file_suffix(file_name):
if AIRFLOW_PROVIDERS_FILE_PREFIX in file_name:
- return file_name[file_name.find(AIRFLOW_PROVIDERS_FILE_PREFIX):]
+ return file_name[file_name.find(AIRFLOW_PROVIDERS_FILE_PREFIX) :]
if AIRFLOW_TESTS_PROVIDERS_FILE_PREFIX in file_name:
- return file_name[file_name.find(AIRFLOW_TESTS_PROVIDERS_FILE_PREFIX):]
+ return file_name[file_name.find(AIRFLOW_TESTS_PROVIDERS_FILE_PREFIX) :]
return None
@@ -99,7 +101,7 @@ def get_provider_from_import(import_name: str) -> Optional[str]:
if AIRFLOW_PROVIDERS_IMPORT_PREFIX not in import_name:
# skip silently - we expect non-provider imports
return None
- suffix = import_name[import_name.find(AIRFLOW_PROVIDERS_IMPORT_PREFIX):]
+ suffix = import_name[import_name.find(AIRFLOW_PROVIDERS_IMPORT_PREFIX) :]
split_import = suffix.split(".")[2:]
provider = find_provider(split_import)
if not provider:
@@ -111,6 +113,7 @@ class ImportFinder(NodeVisitor):
"""
AST visitor that collects all imported names in its imports
"""
+
def __init__(self, filename):
self.imports: List[str] = []
self.filename = filename
@@ -174,12 +177,16 @@ def check_if_different_provider_used(file_name: str):
def parse_arguments():
import argparse
+
parser = argparse.ArgumentParser(
- description='Checks if dependencies between packages are handled correctly.')
- parser.add_argument("-f", "--provider-dependencies-file",
- help="Stores dependencies between providers in the file")
- parser.add_argument("-d", "--documentation-file",
- help="Updates package documentation in the file specified (.rst)")
+ description='Checks if dependencies between packages are handled correctly.'
+ )
+ parser.add_argument(
+ "-f", "--provider-dependencies-file", help="Stores dependencies between providers in the file"
+ )
+ parser.add_argument(
+ "-d", "--documentation-file", help="Updates package documentation in the file specified (.rst)"
+ )
parser.add_argument('files', nargs='*')
args = parser.parse_args()
diff --git a/tests/cli/commands/test_celery_command.py b/tests/cli/commands/test_celery_command.py
index f7187cef3a6d6..5bd2855b8b53a 100644
--- a/tests/cli/commands/test_celery_command.py
+++ b/tests/cli/commands/test_celery_command.py
@@ -64,7 +64,6 @@ def test_validate_session_dbapi_exception(self, mock_session):
@pytest.mark.integration("rabbitmq")
@pytest.mark.backend("mysql", "postgres")
class TestWorkerServeLogs(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
cls.parser = cli_parser.get_parser()
@@ -123,10 +122,7 @@ def test_if_right_pid_is_read(self, mock_process, mock_setup_locations):
@mock.patch("airflow.cli.commands.celery_command.setup_locations")
@conf_vars({("core", "executor"): "CeleryExecutor"})
def test_same_pid_file_is_used_in_start_and_stop(
- self,
- mock_setup_locations,
- mock_celery_worker,
- mock_read_pid_from_pidfile
+ self, mock_setup_locations, mock_celery_worker, mock_read_pid_from_pidfile
):
pid_file = "test_pid_file"
mock_setup_locations.return_value = (pid_file, None, None, None)
@@ -164,18 +160,20 @@ def test_worker_started_with_required_arguments(self, mock_worker, mock_popen, m
celery_hostname = "celery_hostname"
queues = "queue"
autoscale = "2,5"
- args = self.parser.parse_args([
- 'celery',
- 'worker',
- '--autoscale',
- autoscale,
- '--concurrency',
- concurrency,
- '--celery-hostname',
- celery_hostname,
- '--queues',
- queues
- ])
+ args = self.parser.parse_args(
+ [
+ 'celery',
+ 'worker',
+ '--autoscale',
+ autoscale,
+ '--concurrency',
+ concurrency,
+ '--celery-hostname',
+ celery_hostname,
+ '--queues',
+ queues,
+ ]
+ )
with mock.patch('celery.platforms.check_privileges') as mock_privil:
mock_privil.return_value = 0
diff --git a/tests/cli/commands/test_cheat_sheet_command.py b/tests/cli/commands/test_cheat_sheet_command.py
index c07259b4f9406..50239d6f95fc9 100644
--- a/tests/cli/commands/test_cheat_sheet_command.py
+++ b/tests/cli/commands/test_cheat_sheet_command.py
@@ -71,7 +71,7 @@ def noop():
help='Help text D',
func=noop,
args=(),
- )
+ ),
]
EXPECTED_OUTPUT = """\
@@ -92,7 +92,6 @@ def noop():
class TestCheatSheetCommand(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
cls.parser = cli_parser.get_parser()
diff --git a/tests/cli/commands/test_config_command.py b/tests/cli/commands/test_config_command.py
index c7c8925b9e1fe..56686545fe7da 100644
--- a/tests/cli/commands/test_config_command.py
+++ b/tests/cli/commands/test_config_command.py
@@ -35,9 +35,7 @@ def test_cli_show_config_should_write_data(self, mock_conf, mock_stringio):
config_command.show_config(self.parser.parse_args(['config', 'list', '--color', 'off']))
mock_conf.write.assert_called_once_with(mock_stringio.return_value.__enter__.return_value)
- @conf_vars({
- ('core', 'testkey'): 'test_value'
- })
+ @conf_vars({('core', 'testkey'): 'test_value'})
def test_cli_show_config_should_display_key(self):
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
config_command.show_config(self.parser.parse_args(['config', 'list', '--color', 'off']))
@@ -50,9 +48,7 @@ class TestCliConfigGetValue(unittest.TestCase):
def setUpClass(cls):
cls.parser = cli_parser.get_parser()
- @conf_vars({
- ('core', 'test_key'): 'test_value'
- })
+ @conf_vars({('core', 'test_key'): 'test_value'})
def test_should_display_value(self):
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
config_command.get_value(self.parser.parse_args(['config', 'get-value', 'core', 'test_key']))
@@ -65,9 +61,9 @@ def test_should_raise_exception_when_section_is_missing(self, mock_conf):
mock_conf.has_option.return_value = True
with contextlib.redirect_stderr(io.StringIO()) as temp_stderr, self.assertRaises(SystemExit) as cm:
- config_command.get_value(self.parser.parse_args(
- ['config', 'get-value', 'missing-section', 'dags_folder']
- ))
+ config_command.get_value(
+ self.parser.parse_args(['config', 'get-value', 'missing-section', 'dags_folder'])
+ )
self.assertEqual(1, cm.exception.code)
self.assertEqual(
"The section [missing-section] is not found in config.", temp_stderr.getvalue().strip()
@@ -79,9 +75,9 @@ def test_should_raise_exception_when_option_is_missing(self, mock_conf):
mock_conf.has_option.return_value = False
with contextlib.redirect_stderr(io.StringIO()) as temp_stderr, self.assertRaises(SystemExit) as cm:
- config_command.get_value(self.parser.parse_args(
- ['config', 'get-value', 'missing-section', 'dags_folder']
- ))
+ config_command.get_value(
+ self.parser.parse_args(['config', 'get-value', 'missing-section', 'dags_folder'])
+ )
self.assertEqual(1, cm.exception.code)
self.assertEqual(
"The option [missing-section/dags_folder] is not found in config.", temp_stderr.getvalue().strip()
diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py
index 9065f510302ce..d7112268ab623 100644
--- a/tests/cli/commands/test_connection_command.py
+++ b/tests/cli/commands/test_connection_command.py
@@ -42,36 +42,67 @@ def tearDown(self):
def test_cli_connection_get(self):
with redirect_stdout(io.StringIO()) as stdout:
- connection_command.connections_get(self.parser.parse_args(
- ["connections", "get", "google_cloud_default"]
- ))
+ connection_command.connections_get(
+ self.parser.parse_args(["connections", "get", "google_cloud_default"])
+ )
stdout = stdout.getvalue()
- self.assertIn(
- "URI: google-cloud-platform:///default",
- stdout
- )
+ self.assertIn("URI: google-cloud-platform:///default", stdout)
def test_cli_connection_get_invalid(self):
with self.assertRaisesRegex(SystemExit, re.escape("Connection not found.")):
- connection_command.connections_get(self.parser.parse_args(
- ["connections", "get", "INVALID"]
- ))
+ connection_command.connections_get(self.parser.parse_args(["connections", "get", "INVALID"]))
class TestCliListConnections(unittest.TestCase):
EXPECTED_CONS = [
- ('airflow_db', 'mysql', ),
- ('google_cloud_default', 'google_cloud_platform', ),
- ('http_default', 'http', ),
- ('local_mysql', 'mysql', ),
- ('mongo_default', 'mongo', ),
- ('mssql_default', 'mssql', ),
- ('mysql_default', 'mysql', ),
- ('pinot_broker_default', 'pinot', ),
- ('postgres_default', 'postgres', ),
- ('presto_default', 'presto', ),
- ('sqlite_default', 'sqlite', ),
- ('vertica_default', 'vertica', ),
+ (
+ 'airflow_db',
+ 'mysql',
+ ),
+ (
+ 'google_cloud_default',
+ 'google_cloud_platform',
+ ),
+ (
+ 'http_default',
+ 'http',
+ ),
+ (
+ 'local_mysql',
+ 'mysql',
+ ),
+ (
+ 'mongo_default',
+ 'mongo',
+ ),
+ (
+ 'mssql_default',
+ 'mssql',
+ ),
+ (
+ 'mysql_default',
+ 'mysql',
+ ),
+ (
+ 'pinot_broker_default',
+ 'pinot',
+ ),
+ (
+ 'postgres_default',
+ 'postgres',
+ ),
+ (
+ 'presto_default',
+ 'presto',
+ ),
+ (
+ 'sqlite_default',
+ 'sqlite',
+ ),
+ (
+ 'vertica_default',
+ 'vertica',
+ ),
]
def setUp(self):
@@ -126,7 +157,7 @@ def setUp(self, session=None):
password="plainpassword",
schema="airflow",
),
- session
+ session,
)
merge_conn(
Connection(
@@ -136,7 +167,7 @@ def setUp(self, session=None):
port=8082,
extra='{"endpoint": "druid/v2/sql"}',
),
- session
+ session,
)
self.parser = cli_parser.get_parser()
@@ -146,34 +177,32 @@ def tearDown(self):
def test_cli_connections_export_should_return_error_for_invalid_command(self):
with self.assertRaises(SystemExit):
- self.parser.parse_args([
- "connections",
- "export",
- ])
+ self.parser.parse_args(
+ [
+ "connections",
+ "export",
+ ]
+ )
def test_cli_connections_export_should_return_error_for_invalid_format(self):
with self.assertRaises(SystemExit):
- self.parser.parse_args([
- "connections",
- "export",
- "--format",
- "invalid",
- "/path/to/file"
- ])
+ self.parser.parse_args(["connections", "export", "--format", "invalid", "/path/to/file"])
@mock.patch('os.path.splitext')
@mock.patch('builtins.open', new_callable=mock.mock_open())
- def test_cli_connections_export_should_return_error_for_invalid_export_format(self,
- mock_file_open,
- mock_splittext):
+ def test_cli_connections_export_should_return_error_for_invalid_export_format(
+ self, mock_file_open, mock_splittext
+ ):
output_filepath = '/tmp/connections.invalid'
mock_splittext.return_value = (None, '.invalid')
- args = self.parser.parse_args([
- "connections",
- "export",
- output_filepath,
- ])
+ args = self.parser.parse_args(
+ [
+ "connections",
+ "export",
+ output_filepath,
+ ]
+ )
with self.assertRaisesRegex(
SystemExit, r"Unsupported file format. The file must have the extension .yaml, .json, .env"
):
@@ -186,21 +215,24 @@ def test_cli_connections_export_should_return_error_for_invalid_export_format(se
@mock.patch('os.path.splitext')
@mock.patch('builtins.open', new_callable=mock.mock_open())
@mock.patch.object(connection_command, 'create_session')
- def test_cli_connections_export_should_return_error_if_create_session_fails(self, mock_session,
- mock_file_open,
- mock_splittext):
+ def test_cli_connections_export_should_return_error_if_create_session_fails(
+ self, mock_session, mock_file_open, mock_splittext
+ ):
output_filepath = '/tmp/connections.json'
def my_side_effect():
raise Exception("dummy exception")
+
mock_session.side_effect = my_side_effect
mock_splittext.return_value = (None, '.json')
- args = self.parser.parse_args([
- "connections",
- "export",
- output_filepath,
- ])
+ args = self.parser.parse_args(
+ [
+ "connections",
+ "export",
+ output_filepath,
+ ]
+ )
with self.assertRaisesRegex(Exception, r"dummy exception"):
connection_command.connections_export(args)
@@ -211,22 +243,26 @@ def my_side_effect():
@mock.patch('os.path.splitext')
@mock.patch('builtins.open', new_callable=mock.mock_open())
@mock.patch.object(connection_command, 'create_session')
- def test_cli_connections_export_should_return_error_if_fetching_connections_fails(self, mock_session,
- mock_file_open,
- mock_splittext):
+ def test_cli_connections_export_should_return_error_if_fetching_connections_fails(
+ self, mock_session, mock_file_open, mock_splittext
+ ):
output_filepath = '/tmp/connections.json'
def my_side_effect(_):
raise Exception("dummy exception")
- mock_session.return_value.__enter__.return_value.query.return_value.order_by.side_effect = \
+
+ mock_session.return_value.__enter__.return_value.query.return_value.order_by.side_effect = (
my_side_effect
+ )
mock_splittext.return_value = (None, '.json')
- args = self.parser.parse_args([
- "connections",
- "export",
- output_filepath,
- ])
+ args = self.parser.parse_args(
+ [
+ "connections",
+ "export",
+ output_filepath,
+ ]
+ )
with self.assertRaisesRegex(Exception, r"dummy exception"):
connection_command.connections_export(args)
@@ -237,19 +273,21 @@ def my_side_effect(_):
@mock.patch('os.path.splitext')
@mock.patch('builtins.open', new_callable=mock.mock_open())
@mock.patch.object(connection_command, 'create_session')
- def test_cli_connections_export_should_not_return_error_if_connections_is_empty(self, mock_session,
- mock_file_open,
- mock_splittext):
+ def test_cli_connections_export_should_not_return_error_if_connections_is_empty(
+ self, mock_session, mock_file_open, mock_splittext
+ ):
output_filepath = '/tmp/connections.json'
mock_session.return_value.__enter__.return_value.query.return_value.all.return_value = []
mock_splittext.return_value = (None, '.json')
- args = self.parser.parse_args([
- "connections",
- "export",
- output_filepath,
- ])
+ args = self.parser.parse_args(
+ [
+ "connections",
+ "export",
+ output_filepath,
+ ]
+ )
connection_command.connections_export(args)
mock_splittext.assert_called_once()
@@ -262,33 +300,38 @@ def test_cli_connections_export_should_export_as_json(self, mock_file_open, mock
output_filepath = '/tmp/connections.json'
mock_splittext.return_value = (None, '.json')
- args = self.parser.parse_args([
- "connections",
- "export",
- output_filepath,
- ])
+ args = self.parser.parse_args(
+ [
+ "connections",
+ "export",
+ output_filepath,
+ ]
+ )
connection_command.connections_export(args)
- expected_connections = json.dumps({
- "airflow_db": {
- "conn_type": "mysql",
- "host": "mysql",
- "login": "root",
- "password": "plainpassword",
- "schema": "airflow",
- "port": None,
- "extra": None,
+ expected_connections = json.dumps(
+ {
+ "airflow_db": {
+ "conn_type": "mysql",
+ "host": "mysql",
+ "login": "root",
+ "password": "plainpassword",
+ "schema": "airflow",
+ "port": None,
+ "extra": None,
+ },
+ "druid_broker_default": {
+ "conn_type": "druid",
+ "host": "druid-broker",
+ "login": None,
+ "password": None,
+ "schema": None,
+ "port": 8082,
+ "extra": "{\"endpoint\": \"druid/v2/sql\"}",
+ },
},
- "druid_broker_default": {
- "conn_type": "druid",
- "host": "druid-broker",
- "login": None,
- "password": None,
- "schema": None,
- "port": 8082,
- "extra": "{\"endpoint\": \"druid/v2/sql\"}",
- }
- }, indent=2)
+ indent=2,
+ )
mock_splittext.assert_called_once()
mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None)
@@ -300,29 +343,33 @@ def test_cli_connections_export_should_export_as_yaml(self, mock_file_open, mock
output_filepath = '/tmp/connections.yaml'
mock_splittext.return_value = (None, '.yaml')
- args = self.parser.parse_args([
- "connections",
- "export",
- output_filepath,
- ])
+ args = self.parser.parse_args(
+ [
+ "connections",
+ "export",
+ output_filepath,
+ ]
+ )
connection_command.connections_export(args)
- expected_connections = ("airflow_db:\n"
- " conn_type: mysql\n"
- " extra: null\n"
- " host: mysql\n"
- " login: root\n"
- " password: plainpassword\n"
- " port: null\n"
- " schema: airflow\n"
- "druid_broker_default:\n"
- " conn_type: druid\n"
- " extra: \'{\"endpoint\": \"druid/v2/sql\"}\'\n"
- " host: druid-broker\n"
- " login: null\n"
- " password: null\n"
- " port: 8082\n"
- " schema: null\n")
+ expected_connections = (
+ "airflow_db:\n"
+ " conn_type: mysql\n"
+ " extra: null\n"
+ " host: mysql\n"
+ " login: root\n"
+ " password: plainpassword\n"
+ " port: null\n"
+ " schema: airflow\n"
+ "druid_broker_default:\n"
+ " conn_type: druid\n"
+ " extra: \'{\"endpoint\": \"druid/v2/sql\"}\'\n"
+ " host: druid-broker\n"
+ " login: null\n"
+ " password: null\n"
+ " port: 8082\n"
+ " schema: null\n"
+ )
mock_splittext.assert_called_once()
mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None)
mock_file_open.return_value.write.assert_called_once_with(expected_connections)
@@ -333,19 +380,20 @@ def test_cli_connections_export_should_export_as_env(self, mock_file_open, mock_
output_filepath = '/tmp/connections.env'
mock_splittext.return_value = (None, '.env')
- args = self.parser.parse_args([
- "connections",
- "export",
- output_filepath,
- ])
+ args = self.parser.parse_args(
+ [
+ "connections",
+ "export",
+ output_filepath,
+ ]
+ )
connection_command.connections_export(args)
expected_connections = [
"airflow_db=mysql://root:plainpassword@mysql/airflow\n"
"druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql\n",
-
"druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql\n"
- "airflow_db=mysql://root:plainpassword@mysql/airflow\n"
+ "airflow_db=mysql://root:plainpassword@mysql/airflow\n",
]
mock_splittext.assert_called_once()
@@ -355,24 +403,26 @@ def test_cli_connections_export_should_export_as_env(self, mock_file_open, mock_
@mock.patch('os.path.splitext')
@mock.patch('builtins.open', new_callable=mock.mock_open())
- def test_cli_connections_export_should_export_as_env_for_uppercase_file_extension(self, mock_file_open,
- mock_splittext):
+ def test_cli_connections_export_should_export_as_env_for_uppercase_file_extension(
+ self, mock_file_open, mock_splittext
+ ):
output_filepath = '/tmp/connections.ENV'
mock_splittext.return_value = (None, '.ENV')
- args = self.parser.parse_args([
- "connections",
- "export",
- output_filepath,
- ])
+ args = self.parser.parse_args(
+ [
+ "connections",
+ "export",
+ output_filepath,
+ ]
+ )
connection_command.connections_export(args)
expected_connections = [
"airflow_db=mysql://root:plainpassword@mysql/airflow\n"
"druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql\n",
-
"druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql\n"
- "airflow_db=mysql://root:plainpassword@mysql/airflow\n"
+ "airflow_db=mysql://root:plainpassword@mysql/airflow\n",
]
mock_splittext.assert_called_once()
@@ -382,39 +432,45 @@ def test_cli_connections_export_should_export_as_env_for_uppercase_file_extensio
@mock.patch('os.path.splitext')
@mock.patch('builtins.open', new_callable=mock.mock_open())
- def test_cli_connections_export_should_force_export_as_specified_format(self, mock_file_open,
- mock_splittext):
+ def test_cli_connections_export_should_force_export_as_specified_format(
+ self, mock_file_open, mock_splittext
+ ):
output_filepath = '/tmp/connections.yaml'
- args = self.parser.parse_args([
- "connections",
- "export",
- output_filepath,
- "--format",
- "json",
- ])
+ args = self.parser.parse_args(
+ [
+ "connections",
+ "export",
+ output_filepath,
+ "--format",
+ "json",
+ ]
+ )
connection_command.connections_export(args)
- expected_connections = json.dumps({
- "airflow_db": {
- "conn_type": "mysql",
- "host": "mysql",
- "login": "root",
- "password": "plainpassword",
- "schema": "airflow",
- "port": None,
- "extra": None,
+ expected_connections = json.dumps(
+ {
+ "airflow_db": {
+ "conn_type": "mysql",
+ "host": "mysql",
+ "login": "root",
+ "password": "plainpassword",
+ "schema": "airflow",
+ "port": None,
+ "extra": None,
+ },
+ "druid_broker_default": {
+ "conn_type": "druid",
+ "host": "druid-broker",
+ "login": None,
+ "password": None,
+ "schema": None,
+ "port": 8082,
+ "extra": "{\"endpoint\": \"druid/v2/sql\"}",
+ },
},
- "druid_broker_default": {
- "conn_type": "druid",
- "host": "druid-broker",
- "login": None,
- "password": None,
- "schema": None,
- "port": 8082,
- "extra": "{\"endpoint\": \"druid/v2/sql\"}",
- }
- }, indent=2)
+ indent=2,
+ )
mock_splittext.assert_not_called()
mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None)
mock_file_open.return_value.write.assert_called_once_with(expected_connections)
@@ -607,7 +663,7 @@ def test_cli_delete_connections(self, session=None):
Connection(
conn_id="new1", conn_type="mysql", host="mysql", login="root", password="", schema="airflow"
),
- session=session
+ session=session,
)
# Delete connections
with redirect_stdout(io.StringIO()) as stdout:
diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py
index 87b58320389a7..331bbf811292a 100644
--- a/tests/cli/commands/test_dag_command.py
+++ b/tests/cli/commands/test_dag_command.py
@@ -38,23 +38,16 @@
dag_folder_path = '/'.join(os.path.realpath(__file__).split('/')[:-1])
DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1))
-TEST_DAG_FOLDER = os.path.join(
- os.path.dirname(dag_folder_path), 'dags')
+TEST_DAG_FOLDER = os.path.join(os.path.dirname(dag_folder_path), 'dags')
TEST_DAG_ID = 'unit_tests'
EXAMPLE_DAGS_FOLDER = os.path.join(
- os.path.dirname(
- os.path.dirname(
- os.path.dirname(os.path.realpath(__file__))
- )
- ),
- "airflow/example_dags"
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), "airflow/example_dags"
)
class TestCliDags(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
cls.dagbag = DagBag(include_examples=True)
@@ -68,9 +61,11 @@ def tearDownClass(cls) -> None:
@mock.patch("airflow.cli.commands.dag_command.DAG.run")
def test_backfill(self, mock_run):
- dag_command.dag_backfill(self.parser.parse_args([
- 'dags', 'backfill', 'example_bash_operator',
- '--start-date', DEFAULT_DATE.isoformat()]))
+ dag_command.dag_backfill(
+ self.parser.parse_args(
+ ['dags', 'backfill', 'example_bash_operator', '--start-date', DEFAULT_DATE.isoformat()]
+ )
+ )
mock_run.assert_called_once_with(
start_date=DEFAULT_DATE,
@@ -91,9 +86,21 @@ def test_backfill(self, mock_run):
dag = self.dagbag.get_dag('example_bash_operator')
with contextlib.redirect_stdout(io.StringIO()) as stdout:
- dag_command.dag_backfill(self.parser.parse_args([
- 'dags', 'backfill', 'example_bash_operator', '--task-regex', 'runme_0', '--dry-run',
- '--start-date', DEFAULT_DATE.isoformat()]), dag=dag)
+ dag_command.dag_backfill(
+ self.parser.parse_args(
+ [
+ 'dags',
+ 'backfill',
+ 'example_bash_operator',
+ '--task-regex',
+ 'runme_0',
+ '--dry-run',
+ '--start-date',
+ DEFAULT_DATE.isoformat(),
+ ]
+ ),
+ dag=dag,
+ )
output = stdout.getvalue()
self.assertIn(f"Dry run of DAG example_bash_operator on {DEFAULT_DATE.isoformat()}\n", output)
@@ -101,15 +108,35 @@ def test_backfill(self, mock_run):
mock_run.assert_not_called() # Dry run shouldn't run the backfill
- dag_command.dag_backfill(self.parser.parse_args([
- 'dags', 'backfill', 'example_bash_operator', '--dry-run',
- '--start-date', DEFAULT_DATE.isoformat()]), dag=dag)
+ dag_command.dag_backfill(
+ self.parser.parse_args(
+ [
+ 'dags',
+ 'backfill',
+ 'example_bash_operator',
+ '--dry-run',
+ '--start-date',
+ DEFAULT_DATE.isoformat(),
+ ]
+ ),
+ dag=dag,
+ )
mock_run.assert_not_called() # Dry run shouldn't run the backfill
- dag_command.dag_backfill(self.parser.parse_args([
- 'dags', 'backfill', 'example_bash_operator', '--local',
- '--start-date', DEFAULT_DATE.isoformat()]), dag=dag)
+ dag_command.dag_backfill(
+ self.parser.parse_args(
+ [
+ 'dags',
+ 'backfill',
+ 'example_bash_operator',
+ '--local',
+ '--start-date',
+ DEFAULT_DATE.isoformat(),
+ ]
+ ),
+ dag=dag,
+ )
mock_run.assert_called_once_with(
start_date=DEFAULT_DATE,
@@ -130,8 +157,7 @@ def test_backfill(self, mock_run):
def test_show_dag_print(self):
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
- dag_command.dag_show(self.parser.parse_args([
- 'dags', 'show', 'example_bash_operator']))
+ dag_command.dag_show(self.parser.parse_args(['dags', 'show', 'example_bash_operator']))
out = temp_stdout.getvalue()
self.assertIn("label=example_bash_operator", out)
self.assertIn("graph [label=example_bash_operator labelloc=t rankdir=LR]", out)
@@ -140,9 +166,9 @@ def test_show_dag_print(self):
@mock.patch("airflow.cli.commands.dag_command.render_dag")
def test_show_dag_dave(self, mock_render_dag):
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
- dag_command.dag_show(self.parser.parse_args([
- 'dags', 'show', 'example_bash_operator', '--save', 'awesome.png']
- ))
+ dag_command.dag_show(
+ self.parser.parse_args(['dags', 'show', 'example_bash_operator', '--save', 'awesome.png'])
+ )
out = temp_stdout.getvalue()
mock_render_dag.return_value.render.assert_called_once_with(
cleanup=True, filename='awesome', format='png'
@@ -155,9 +181,9 @@ def test_show_dag_imgcat(self, mock_render_dag, mock_popen):
mock_render_dag.return_value.pipe.return_value = b"DOT_DATA"
mock_popen.return_value.communicate.return_value = (b"OUT", b"ERR")
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
- dag_command.dag_show(self.parser.parse_args([
- 'dags', 'show', 'example_bash_operator', '--imgcat']
- ))
+ dag_command.dag_show(
+ self.parser.parse_args(['dags', 'show', 'example_bash_operator', '--imgcat'])
+ )
out = temp_stdout.getvalue()
mock_render_dag.return_value.pipe.assert_called_once_with(format='png')
mock_popen.return_value.communicate.assert_called_once_with(b'DOT_DATA')
@@ -243,10 +269,12 @@ def test_cli_backfill_depends_on_past_backwards(self, mock_run):
)
def test_next_execution(self):
- dag_ids = ['example_bash_operator', # schedule_interval is '0 0 * * *'
- 'latest_only', # schedule_interval is timedelta(hours=4)
- 'example_python_operator', # schedule_interval=None
- 'example_xcom'] # schedule_interval="@once"
+ dag_ids = [
+ 'example_bash_operator', # schedule_interval is '0 0 * * *'
+ 'latest_only', # schedule_interval is timedelta(hours=4)
+ 'example_python_operator', # schedule_interval=None
+ 'example_xcom',
+ ] # schedule_interval="@once"
# Delete DagRuns
with create_session() as session:
@@ -254,9 +282,7 @@ def test_next_execution(self):
dr.delete(synchronize_session=False)
# Test None output
- args = self.parser.parse_args(['dags',
- 'next-execution',
- dag_ids[0]])
+ args = self.parser.parse_args(['dags', 'next-execution', dag_ids[0]])
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
dag_command.dag_next_execution(args)
out = temp_stdout.getvalue()
@@ -266,40 +292,30 @@ def test_next_execution(self):
# The details below are determined by the schedule_interval of the example DAGs
now = DEFAULT_DATE
- expected_output = [str(now + timedelta(days=1)),
- str(now + timedelta(hours=4)),
- "None",
- "None"]
- expected_output_2 = [str(now + timedelta(days=1)) + os.linesep + str(now + timedelta(days=2)),
- str(now + timedelta(hours=4)) + os.linesep + str(now + timedelta(hours=8)),
- "None",
- "None"]
+ expected_output = [str(now + timedelta(days=1)), str(now + timedelta(hours=4)), "None", "None"]
+ expected_output_2 = [
+ str(now + timedelta(days=1)) + os.linesep + str(now + timedelta(days=2)),
+ str(now + timedelta(hours=4)) + os.linesep + str(now + timedelta(hours=8)),
+ "None",
+ "None",
+ ]
for i, dag_id in enumerate(dag_ids):
dag = self.dagbag.dags[dag_id]
# Create a DagRun for each DAG, to prepare for next step
dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- execution_date=now,
- start_date=now,
- state=State.FAILED
+ run_type=DagRunType.MANUAL, execution_date=now, start_date=now, state=State.FAILED
)
# Test num-executions = 1 (default)
- args = self.parser.parse_args(['dags',
- 'next-execution',
- dag_id])
+ args = self.parser.parse_args(['dags', 'next-execution', dag_id])
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
dag_command.dag_next_execution(args)
out = temp_stdout.getvalue()
self.assertIn(expected_output[i], out)
# Test num-executions = 2
- args = self.parser.parse_args(['dags',
- 'next-execution',
- dag_id,
- '--num-executions',
- '2'])
+ args = self.parser.parse_args(['dags', 'next-execution', dag_id, '--num-executions', '2'])
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
dag_command.dag_next_execution(args)
out = temp_stdout.getvalue()
@@ -310,9 +326,7 @@ def test_next_execution(self):
dr = session.query(DagRun).filter(DagRun.dag_id.in_(dag_ids))
dr.delete(synchronize_session=False)
- @conf_vars({
- ('core', 'load_examples'): 'true'
- })
+ @conf_vars({('core', 'load_examples'): 'true'})
def test_cli_report(self):
args = self.parser.parse_args(['dags', 'report'])
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
@@ -322,9 +336,7 @@ def test_cli_report(self):
self.assertIn("airflow/example_dags/example_complex.py ", out)
self.assertIn("['example_complex']", out)
- @conf_vars({
- ('core', 'load_examples'): 'true'
- })
+ @conf_vars({('core', 'load_examples'): 'true'})
def test_cli_list_dags(self):
args = self.parser.parse_args(['dags', 'list', '--output=fancy_grid'])
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
@@ -335,49 +347,74 @@ def test_cli_list_dags(self):
self.assertIn("airflow/example_dags/example_complex.py", out)
def test_cli_list_dag_runs(self):
- dag_command.dag_trigger(self.parser.parse_args([
- 'dags', 'trigger', 'example_bash_operator', ]))
- args = self.parser.parse_args(['dags',
- 'list-runs',
- '--dag-id',
- 'example_bash_operator',
- '--no-backfill',
- '--start-date',
- DEFAULT_DATE.isoformat(),
- '--end-date',
- timezone.make_aware(datetime.max).isoformat()])
+ dag_command.dag_trigger(
+ self.parser.parse_args(
+ [
+ 'dags',
+ 'trigger',
+ 'example_bash_operator',
+ ]
+ )
+ )
+ args = self.parser.parse_args(
+ [
+ 'dags',
+ 'list-runs',
+ '--dag-id',
+ 'example_bash_operator',
+ '--no-backfill',
+ '--start-date',
+ DEFAULT_DATE.isoformat(),
+ '--end-date',
+ timezone.make_aware(datetime.max).isoformat(),
+ ]
+ )
dag_command.dag_list_dag_runs(args)
def test_cli_list_jobs_with_args(self):
- args = self.parser.parse_args(['dags', 'list-jobs', '--dag-id',
- 'example_bash_operator',
- '--state', 'success',
- '--limit', '100',
- '--output', 'tsv'])
+ args = self.parser.parse_args(
+ [
+ 'dags',
+ 'list-jobs',
+ '--dag-id',
+ 'example_bash_operator',
+ '--state',
+ 'success',
+ '--limit',
+ '100',
+ '--output',
+ 'tsv',
+ ]
+ )
dag_command.dag_list_jobs(args)
def test_pause(self):
- args = self.parser.parse_args([
- 'dags', 'pause', 'example_bash_operator'])
+ args = self.parser.parse_args(['dags', 'pause', 'example_bash_operator'])
dag_command.dag_pause(args)
self.assertIn(self.dagbag.dags['example_bash_operator'].get_is_paused(), [True, 1])
- args = self.parser.parse_args([
- 'dags', 'unpause', 'example_bash_operator'])
+ args = self.parser.parse_args(['dags', 'unpause', 'example_bash_operator'])
dag_command.dag_unpause(args)
self.assertIn(self.dagbag.dags['example_bash_operator'].get_is_paused(), [False, 0])
def test_trigger_dag(self):
- dag_command.dag_trigger(self.parser.parse_args([
- 'dags', 'trigger', 'example_bash_operator',
- '--conf', '{"foo": "bar"}']))
+ dag_command.dag_trigger(
+ self.parser.parse_args(['dags', 'trigger', 'example_bash_operator', '--conf', '{"foo": "bar"}'])
+ )
self.assertRaises(
ValueError,
dag_command.dag_trigger,
- self.parser.parse_args([
- 'dags', 'trigger', 'example_bash_operator',
- '--run-id', 'trigger_dag_xxx',
- '--conf', 'NOT JSON'])
+ self.parser.parse_args(
+ [
+ 'dags',
+ 'trigger',
+ 'example_bash_operator',
+ '--run-id',
+ 'trigger_dag_xxx',
+ '--conf',
+ 'NOT JSON',
+ ]
+ ),
)
def test_delete_dag(self):
@@ -386,16 +423,12 @@ def test_delete_dag(self):
session = settings.Session()
session.add(DM(dag_id=key))
session.commit()
- dag_command.dag_delete(self.parser.parse_args([
- 'dags', 'delete', key, '--yes']))
+ dag_command.dag_delete(self.parser.parse_args(['dags', 'delete', key, '--yes']))
self.assertEqual(session.query(DM).filter_by(dag_id=key).count(), 0)
self.assertRaises(
AirflowException,
dag_command.dag_delete,
- self.parser.parse_args([
- 'dags', 'delete',
- 'does_not_exist_dag',
- '--yes'])
+ self.parser.parse_args(['dags', 'delete', 'does_not_exist_dag', '--yes']),
)
def test_delete_dag_existing_file(self):
@@ -407,8 +440,7 @@ def test_delete_dag_existing_file(self):
with tempfile.NamedTemporaryFile() as f:
session.add(DM(dag_id=key, fileloc=f.name))
session.commit()
- dag_command.dag_delete(self.parser.parse_args([
- 'dags', 'delete', key, '--yes']))
+ dag_command.dag_delete(self.parser.parse_args(['dags', 'delete', key, '--yes']))
self.assertEqual(session.query(DM).filter_by(dag_id=key).count(), 0)
def test_cli_list_jobs(self):
@@ -416,8 +448,12 @@ def test_cli_list_jobs(self):
dag_command.dag_list_jobs(args)
def test_dag_state(self):
- self.assertEqual(None, dag_command.dag_state(self.parser.parse_args([
- 'dags', 'state', 'example_bash_operator', DEFAULT_DATE.isoformat()])))
+ self.assertEqual(
+ None,
+ dag_command.dag_state(
+ self.parser.parse_args(['dags', 'state', 'example_bash_operator', DEFAULT_DATE.isoformat()])
+ ),
+ )
@mock.patch("airflow.cli.commands.dag_command.DebugExecutor")
@mock.patch("airflow.cli.commands.dag_command.get_dag")
@@ -425,20 +461,21 @@ def test_dag_test(self, mock_get_dag, mock_executor):
cli_args = self.parser.parse_args(['dags', 'test', 'example_bash_operator', DEFAULT_DATE.isoformat()])
dag_command.dag_test(cli_args)
- mock_get_dag.assert_has_calls([
- mock.call(
- subdir=cli_args.subdir, dag_id='example_bash_operator'
- ),
- mock.call().clear(
- start_date=cli_args.execution_date, end_date=cli_args.execution_date,
- dag_run_state=State.NONE,
- ),
- mock.call().run(
- executor=mock_executor.return_value,
- start_date=cli_args.execution_date,
- end_date=cli_args.execution_date
- )
- ])
+ mock_get_dag.assert_has_calls(
+ [
+ mock.call(subdir=cli_args.subdir, dag_id='example_bash_operator'),
+ mock.call().clear(
+ start_date=cli_args.execution_date,
+ end_date=cli_args.execution_date,
+ dag_run_state=State.NONE,
+ ),
+ mock.call().run(
+ executor=mock_executor.return_value,
+ start_date=cli_args.execution_date,
+ end_date=cli_args.execution_date,
+ ),
+ ]
+ )
@mock.patch(
"airflow.cli.commands.dag_command.render_dag", **{'return_value.source': "SOURCE"} # type: ignore
@@ -446,30 +483,28 @@ def test_dag_test(self, mock_get_dag, mock_executor):
@mock.patch("airflow.cli.commands.dag_command.DebugExecutor")
@mock.patch("airflow.cli.commands.dag_command.get_dag")
def test_dag_test_show_dag(self, mock_get_dag, mock_executor, mock_render_dag):
- cli_args = self.parser.parse_args([
- 'dags', 'test', 'example_bash_operator', DEFAULT_DATE.isoformat(), '--show-dagrun'
- ])
+ cli_args = self.parser.parse_args(
+ ['dags', 'test', 'example_bash_operator', DEFAULT_DATE.isoformat(), '--show-dagrun']
+ )
with contextlib.redirect_stdout(io.StringIO()) as stdout:
dag_command.dag_test(cli_args)
output = stdout.getvalue()
- mock_get_dag.assert_has_calls([
- mock.call(
- subdir=cli_args.subdir, dag_id='example_bash_operator'
- ),
- mock.call().clear(
- start_date=cli_args.execution_date,
- end_date=cli_args.execution_date,
- dag_run_state=State.NONE,
- ),
- mock.call().run(
- executor=mock_executor.return_value,
- start_date=cli_args.execution_date,
- end_date=cli_args.execution_date
- )
- ])
- mock_render_dag.assert_has_calls([
- mock.call(mock_get_dag.return_value, tis=[])
- ])
+ mock_get_dag.assert_has_calls(
+ [
+ mock.call(subdir=cli_args.subdir, dag_id='example_bash_operator'),
+ mock.call().clear(
+ start_date=cli_args.execution_date,
+ end_date=cli_args.execution_date,
+ dag_run_state=State.NONE,
+ ),
+ mock.call().run(
+ executor=mock_executor.return_value,
+ start_date=cli_args.execution_date,
+ end_date=cli_args.execution_date,
+ ),
+ ]
+ )
+ mock_render_dag.assert_has_calls([mock.call(mock_get_dag.return_value, tis=[])])
self.assertIn("SOURCE", output)
diff --git a/tests/cli/commands/test_db_command.py b/tests/cli/commands/test_db_command.py
index e8ed61048761f..8543db9f6cc86 100644
--- a/tests/cli/commands/test_db_command.py
+++ b/tests/cli/commands/test_db_command.py
@@ -56,56 +56,48 @@ def test_cli_upgradedb(self, mock_upgradedb):
@mock.patch("airflow.cli.commands.db_command.execute_interactive")
@mock.patch("airflow.cli.commands.db_command.NamedTemporaryFile")
- @mock.patch(
- "airflow.cli.commands.db_command.settings.engine.url",
- make_url("mysql://root@mysql/airflow")
- )
+ @mock.patch("airflow.cli.commands.db_command.settings.engine.url", make_url("mysql://root@mysql/airflow"))
def test_cli_shell_mysql(self, mock_tmp_file, mock_execute_interactive):
mock_tmp_file.return_value.__enter__.return_value.name = "/tmp/name"
db_command.shell(self.parser.parse_args(['db', 'shell']))
- mock_execute_interactive.assert_called_once_with(
- ['mysql', '--defaults-extra-file=/tmp/name']
- )
+ mock_execute_interactive.assert_called_once_with(['mysql', '--defaults-extra-file=/tmp/name'])
mock_tmp_file.return_value.__enter__.return_value.write.assert_called_once_with(
- b'[client]\nhost = mysql\nuser = root\npassword = \nport = '
- b'\ndatabase = airflow'
+ b'[client]\nhost = mysql\nuser = root\npassword = \nport = ' b'\ndatabase = airflow'
)
@mock.patch("airflow.cli.commands.db_command.execute_interactive")
@mock.patch(
- "airflow.cli.commands.db_command.settings.engine.url",
- make_url("sqlite:////root/airflow/airflow.db")
+ "airflow.cli.commands.db_command.settings.engine.url", make_url("sqlite:////root/airflow/airflow.db")
)
def test_cli_shell_sqlite(self, mock_execute_interactive):
db_command.shell(self.parser.parse_args(['db', 'shell']))
- mock_execute_interactive.assert_called_once_with(
- ['sqlite3', '/root/airflow/airflow.db']
- )
+ mock_execute_interactive.assert_called_once_with(['sqlite3', '/root/airflow/airflow.db'])
@mock.patch("airflow.cli.commands.db_command.execute_interactive")
@mock.patch(
"airflow.cli.commands.db_command.settings.engine.url",
- make_url("postgresql+psycopg2://postgres:airflow@postgres/airflow")
+ make_url("postgresql+psycopg2://postgres:airflow@postgres/airflow"),
)
def test_cli_shell_postgres(self, mock_execute_interactive):
db_command.shell(self.parser.parse_args(['db', 'shell']))
- mock_execute_interactive.assert_called_once_with(
- ['psql'], env=mock.ANY
- )
+ mock_execute_interactive.assert_called_once_with(['psql'], env=mock.ANY)
_, kwargs = mock_execute_interactive.call_args
env = kwargs['env']
postgres_env = {k: v for k, v in env.items() if k.startswith('PG')}
- self.assertEqual({
- 'PGDATABASE': 'airflow',
- 'PGHOST': 'postgres',
- 'PGPASSWORD': 'airflow',
- 'PGPORT': '',
- 'PGUSER': 'postgres'
- }, postgres_env)
+ self.assertEqual(
+ {
+ 'PGDATABASE': 'airflow',
+ 'PGHOST': 'postgres',
+ 'PGPASSWORD': 'airflow',
+ 'PGPORT': '',
+ 'PGUSER': 'postgres',
+ },
+ postgres_env,
+ )
@mock.patch(
"airflow.cli.commands.db_command.settings.engine.url",
- make_url("invalid+psycopg2://postgres:airflow@postgres/airflow")
+ make_url("invalid+psycopg2://postgres:airflow@postgres/airflow"),
)
def test_cli_shell_invalid(self):
with self.assertRaisesRegex(AirflowException, r"Unknown driver: invalid\+psycopg2"):
diff --git a/tests/cli/commands/test_info_command.py b/tests/cli/commands/test_info_command.py
index 56eb584f9241a..bd69704f68f92 100644
--- a/tests/cli/commands/test_info_command.py
+++ b/tests/cli/commands/test_info_command.py
@@ -54,7 +54,10 @@ def test_should_remove_pii_from_path(self):
"postgresql+psycopg2://:airflow@postgres/airflow",
"postgresql+psycopg2://:PASSWORD@postgres/airflow",
),
- ("postgresql+psycopg2://postgres/airflow", "postgresql+psycopg2://postgres/airflow",),
+ (
+ "postgresql+psycopg2://postgres/airflow",
+ "postgresql+psycopg2://postgres/airflow",
+ ),
]
)
def test_should_remove_pii_from_url(self, before, after):
@@ -100,10 +103,12 @@ def test_should_read_config(self):
class TestConfigInfoLogging(unittest.TestCase):
def test_should_read_logging_configuration(self):
- with conf_vars({
- ('logging', 'remote_logging'): 'True',
- ('logging', 'remote_base_log_folder'): 'stackdriver://logs-name',
- }):
+ with conf_vars(
+ {
+ ('logging', 'remote_logging'): 'True',
+ ('logging', 'remote_base_log_folder'): 'stackdriver://logs-name',
+ }
+ ):
importlib.reload(airflow_local_settings)
configure_logging()
instance = info_command.ConfigInfo(info_command.NullAnonymizer())
@@ -166,7 +171,7 @@ def test_show_info_anonymize(self):
"link": "https://file.io/TEST",
"expiry": "14 days",
},
- }
+ },
)
def test_show_info_anonymize_fileio(self, mock_requests):
with contextlib.redirect_stdout(io.StringIO()) as stdout:
diff --git a/tests/cli/commands/test_kubernetes_command.py b/tests/cli/commands/test_kubernetes_command.py
index 9e21aaa82a11d..3c4dda8e1bc83 100644
--- a/tests/cli/commands/test_kubernetes_command.py
+++ b/tests/cli/commands/test_kubernetes_command.py
@@ -35,9 +35,18 @@ def setUpClass(cls):
def test_generate_dag_yaml(self):
with tempfile.TemporaryDirectory("airflow_dry_run_test/") as directory:
file_name = "example_bash_operator_run_after_loop_2020-11-03T00_00_00_plus_00_00.yml"
- kubernetes_command.generate_pod_yaml(self.parser.parse_args([
- 'kubernetes', 'generate-dag-yaml',
- 'example_bash_operator', "2020-11-03", "--output-path", directory]))
+ kubernetes_command.generate_pod_yaml(
+ self.parser.parse_args(
+ [
+ 'kubernetes',
+ 'generate-dag-yaml',
+ 'example_bash_operator',
+ "2020-11-03",
+ "--output-path",
+ directory,
+ ]
+ )
+ )
self.assertEqual(len(os.listdir(directory)), 1)
out_dir = directory + "/airflow_yaml_output/"
self.assertEqual(len(os.listdir(out_dir)), 6)
diff --git a/tests/cli/commands/test_legacy_commands.py b/tests/cli/commands/test_legacy_commands.py
index 42a04ff5bb008..444cda07c1f08 100644
--- a/tests/cli/commands/test_legacy_commands.py
+++ b/tests/cli/commands/test_legacy_commands.py
@@ -24,11 +24,35 @@
from airflow.cli.commands import config_command
from airflow.cli.commands.legacy_commands import COMMAND_MAP, check_legacy_command
-LEGACY_COMMANDS = ["worker", "flower", "trigger_dag", "delete_dag", "show_dag", "list_dag",
- "dag_status", "backfill", "list_dag_runs", "pause", "unpause", "test",
- "clear", "list_tasks", "task_failed_deps", "task_state", "run",
- "render", "initdb", "resetdb", "upgradedb", "checkdb", "shell", "pool",
- "list_users", "create_user", "delete_user"]
+LEGACY_COMMANDS = [
+ "worker",
+ "flower",
+ "trigger_dag",
+ "delete_dag",
+ "show_dag",
+ "list_dag",
+ "dag_status",
+ "backfill",
+ "list_dag_runs",
+ "pause",
+ "unpause",
+ "test",
+ "clear",
+ "list_tasks",
+ "task_failed_deps",
+ "task_state",
+ "run",
+ "render",
+ "initdb",
+ "resetdb",
+ "upgradedb",
+ "checkdb",
+ "shell",
+ "pool",
+ "list_users",
+ "create_user",
+ "delete_user",
+]
class TestCliDeprecatedCommandsValue(unittest.TestCase):
@@ -37,15 +61,16 @@ def setUpClass(cls):
cls.parser = cli_parser.get_parser()
def test_should_display_value(self):
- with self.assertRaises(SystemExit) as cm_exception, \
- contextlib.redirect_stderr(io.StringIO()) as temp_stderr:
+ with self.assertRaises(SystemExit) as cm_exception, contextlib.redirect_stderr(
+ io.StringIO()
+ ) as temp_stderr:
config_command.get_value(self.parser.parse_args(['worker']))
self.assertEqual(2, cm_exception.exception.code)
self.assertIn(
"`airflow worker` command, has been removed, "
"please use `airflow celery worker`, see help above.",
- temp_stderr.getvalue().strip()
+ temp_stderr.getvalue().strip(),
)
def test_command_map(self):
@@ -58,4 +83,5 @@ def test_check_legacy_command(self):
check_legacy_command(action, 'list_users')
self.assertEqual(
str(e.exception),
- "argument : `airflow list_users` command, has been removed, please use `airflow users list`")
+ "argument : `airflow list_users` command, has been removed, please use `airflow users list`",
+ )
diff --git a/tests/cli/commands/test_pool_command.py b/tests/cli/commands/test_pool_command.py
index 135166d033608..be95d09ede633 100644
--- a/tests/cli/commands/test_pool_command.py
+++ b/tests/cli/commands/test_pool_command.py
@@ -81,18 +81,9 @@ def test_pool_delete(self):
def test_pool_import_export(self):
# Create two pools first
pool_config_input = {
- "foo": {
- "description": "foo_test",
- "slots": 1
- },
- 'default_pool': {
- 'description': 'Default pool',
- 'slots': 128
- },
- "baz": {
- "description": "baz_test",
- "slots": 2
- }
+ "foo": {"description": "foo_test", "slots": 1},
+ 'default_pool': {'description': 'Default pool', 'slots': 128},
+ "baz": {"description": "baz_test", "slots": 2},
}
with open('pools_import.json', mode='w') as file:
json.dump(pool_config_input, file)
@@ -106,8 +97,7 @@ def test_pool_import_export(self):
with open('pools_export.json', mode='r') as file:
pool_config_output = json.load(file)
self.assertEqual(
- pool_config_input,
- pool_config_output,
- "Input and output pool files are not same")
+ pool_config_input, pool_config_output, "Input and output pool files are not same"
+ )
os.remove('pools_import.json')
os.remove('pools_export.json')
diff --git a/tests/cli/commands/test_role_command.py b/tests/cli/commands/test_role_command.py
index 20af87986d6d0..e5a93fcebf8ee 100644
--- a/tests/cli/commands/test_role_command.py
+++ b/tests/cli/commands/test_role_command.py
@@ -36,6 +36,7 @@ def setUpClass(cls):
def setUp(self):
from airflow.www import app as application
+
self.app = application.create_app(testing=True)
self.appbuilder = self.app.appbuilder # pylint: disable=no-member
self.clear_roles_and_roles()
@@ -56,9 +57,7 @@ def test_cli_create_roles(self):
self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamA'))
self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamB'))
- args = self.parser.parse_args([
- 'roles', 'create', 'FakeTeamA', 'FakeTeamB'
- ])
+ args = self.parser.parse_args(['roles', 'create', 'FakeTeamA', 'FakeTeamB'])
role_command.roles_create(args)
self.assertIsNotNone(self.appbuilder.sm.find_role('FakeTeamA'))
@@ -68,9 +67,7 @@ def test_cli_create_roles_is_reentrant(self):
self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamA'))
self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamB'))
- args = self.parser.parse_args([
- 'roles', 'create', 'FakeTeamA', 'FakeTeamB'
- ])
+ args = self.parser.parse_args(['roles', 'create', 'FakeTeamA', 'FakeTeamB'])
role_command.roles_create(args)
diff --git a/tests/cli/commands/test_sync_perm_command.py b/tests/cli/commands/test_sync_perm_command.py
index 5019d48d00d83..e1944472e1cf0 100644
--- a/tests/cli/commands/test_sync_perm_command.py
+++ b/tests/cli/commands/test_sync_perm_command.py
@@ -35,19 +35,17 @@ def setUpClass(cls):
@mock.patch("airflow.cli.commands.sync_perm_command.cached_app")
@mock.patch("airflow.cli.commands.sync_perm_command.DagBag")
def test_cli_sync_perm(self, dagbag_mock, mock_cached_app):
- self.expect_dagbag_contains([
- DAG('has_access_control',
- access_control={
- 'Public': {permissions.ACTION_CAN_READ}
- }),
- DAG('no_access_control')
- ], dagbag_mock)
+ self.expect_dagbag_contains(
+ [
+ DAG('has_access_control', access_control={'Public': {permissions.ACTION_CAN_READ}}),
+ DAG('no_access_control'),
+ ],
+ dagbag_mock,
+ )
appbuilder = mock_cached_app.return_value.appbuilder
appbuilder.sm = mock.Mock()
- args = self.parser.parse_args([
- 'sync-perm'
- ])
+ args = self.parser.parse_args(['sync-perm'])
sync_perm_command.sync_perm(args)
assert appbuilder.sm.sync_roles.call_count == 1
@@ -55,8 +53,7 @@ def test_cli_sync_perm(self, dagbag_mock, mock_cached_app):
dagbag_mock.assert_called_once_with(read_dags_from_db=True)
self.assertEqual(2, len(appbuilder.sm.sync_perm_for_dag.mock_calls))
appbuilder.sm.sync_perm_for_dag.assert_any_call(
- 'has_access_control',
- {'Public': {permissions.ACTION_CAN_READ}}
+ 'has_access_control', {'Public': {permissions.ACTION_CAN_READ}}
)
appbuilder.sm.sync_perm_for_dag.assert_any_call(
'no_access_control',
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 94b8ac4e96807..113c961c21e91 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -65,15 +65,14 @@ def test_cli_list_tasks(self):
args = self.parser.parse_args(['tasks', 'list', dag_id])
task_command.task_list(args)
- args = self.parser.parse_args([
- 'tasks', 'list', 'example_bash_operator', '--tree'])
+ args = self.parser.parse_args(['tasks', 'list', 'example_bash_operator', '--tree'])
task_command.task_list(args)
def test_test(self):
"""Test the `airflow test` command"""
- args = self.parser.parse_args([
- "tasks", "test", "example_python_operator", 'print_the_context', '2018-01-01'
- ])
+ args = self.parser.parse_args(
+ ["tasks", "test", "example_python_operator", 'print_the_context', '2018-01-01']
+ )
with redirect_stdout(io.StringIO()) as stdout:
task_command.task_test(args)
@@ -91,13 +90,15 @@ def test_run_naive_taskinstance(self, mock_local_job):
dag = self.dagbag.get_dag('test_run_ignores_all_dependencies')
task0_id = 'test_run_dependent_task'
- args0 = ['tasks',
- 'run',
- '--ignore-all-dependencies',
- '--local',
- dag_id,
- task0_id,
- naive_date.isoformat()]
+ args0 = [
+ 'tasks',
+ 'run',
+ '--ignore-all-dependencies',
+ '--local',
+ dag_id,
+ task0_id,
+ naive_date.isoformat(),
+ ]
task_command.task_run(self.parser.parse_args(args0), dag=dag)
mock_local_job.assert_called_once_with(
@@ -112,67 +113,119 @@ def test_run_naive_taskinstance(self, mock_local_job):
)
def test_cli_test(self):
- task_command.task_test(self.parser.parse_args([
- 'tasks', 'test', 'example_bash_operator', 'runme_0',
- DEFAULT_DATE.isoformat()]))
- task_command.task_test(self.parser.parse_args([
- 'tasks', 'test', 'example_bash_operator', 'runme_0', '--dry-run',
- DEFAULT_DATE.isoformat()]))
+ task_command.task_test(
+ self.parser.parse_args(
+ ['tasks', 'test', 'example_bash_operator', 'runme_0', DEFAULT_DATE.isoformat()]
+ )
+ )
+ task_command.task_test(
+ self.parser.parse_args(
+ ['tasks', 'test', 'example_bash_operator', 'runme_0', '--dry-run', DEFAULT_DATE.isoformat()]
+ )
+ )
def test_cli_test_with_params(self):
- task_command.task_test(self.parser.parse_args([
- 'tasks', 'test', 'example_passing_params_via_test_command', 'run_this',
- '--task-params', '{"foo":"bar"}', DEFAULT_DATE.isoformat()]))
- task_command.task_test(self.parser.parse_args([
- 'tasks', 'test', 'example_passing_params_via_test_command', 'also_run_this',
- '--task-params', '{"foo":"bar"}', DEFAULT_DATE.isoformat()]))
+ task_command.task_test(
+ self.parser.parse_args(
+ [
+ 'tasks',
+ 'test',
+ 'example_passing_params_via_test_command',
+ 'run_this',
+ '--task-params',
+ '{"foo":"bar"}',
+ DEFAULT_DATE.isoformat(),
+ ]
+ )
+ )
+ task_command.task_test(
+ self.parser.parse_args(
+ [
+ 'tasks',
+ 'test',
+ 'example_passing_params_via_test_command',
+ 'also_run_this',
+ '--task-params',
+ '{"foo":"bar"}',
+ DEFAULT_DATE.isoformat(),
+ ]
+ )
+ )
def test_cli_test_with_env_vars(self):
with redirect_stdout(io.StringIO()) as stdout:
- task_command.task_test(self.parser.parse_args([
- 'tasks', 'test', 'example_passing_params_via_test_command', 'env_var_test_task',
- '--env-vars', '{"foo":"bar"}', DEFAULT_DATE.isoformat()]))
+ task_command.task_test(
+ self.parser.parse_args(
+ [
+ 'tasks',
+ 'test',
+ 'example_passing_params_via_test_command',
+ 'env_var_test_task',
+ '--env-vars',
+ '{"foo":"bar"}',
+ DEFAULT_DATE.isoformat(),
+ ]
+ )
+ )
output = stdout.getvalue()
self.assertIn('foo=bar', output)
self.assertIn('AIRFLOW_TEST_MODE=True', output)
def test_cli_run(self):
- task_command.task_run(self.parser.parse_args([
- 'tasks', 'run', 'example_bash_operator', 'runme_0', '--local',
- DEFAULT_DATE.isoformat()]))
+ task_command.task_run(
+ self.parser.parse_args(
+ ['tasks', 'run', 'example_bash_operator', 'runme_0', '--local', DEFAULT_DATE.isoformat()]
+ )
+ )
@parameterized.expand(
[
- ("--ignore-all-dependencies", ),
- ("--ignore-depends-on-past", ),
+ ("--ignore-all-dependencies",),
+ ("--ignore-depends-on-past",),
("--ignore-dependencies",),
("--force",),
],
-
)
def test_cli_run_invalid_raw_option(self, option: str):
with self.assertRaisesRegex(
- AirflowException,
- "Option --raw does not work with some of the other options on this command."
+ AirflowException, "Option --raw does not work with some of the other options on this command."
):
- task_command.task_run(self.parser.parse_args([ # type: ignore
- 'tasks', 'run', 'example_bash_operator', 'runme_0', DEFAULT_DATE.isoformat(), '--raw', option
- ]))
+ task_command.task_run(
+ self.parser.parse_args(
+ [ # type: ignore
+ 'tasks',
+ 'run',
+ 'example_bash_operator',
+ 'runme_0',
+ DEFAULT_DATE.isoformat(),
+ '--raw',
+ option,
+ ]
+ )
+ )
def test_cli_run_mutually_exclusive(self):
- with self.assertRaisesRegex(
- AirflowException,
- "Option --raw and --local are mutually exclusive."
- ):
- task_command.task_run(self.parser.parse_args([
- 'tasks', 'run', 'example_bash_operator', 'runme_0', DEFAULT_DATE.isoformat(), '--raw',
- '--local'
- ]))
+ with self.assertRaisesRegex(AirflowException, "Option --raw and --local are mutually exclusive."):
+ task_command.task_run(
+ self.parser.parse_args(
+ [
+ 'tasks',
+ 'run',
+ 'example_bash_operator',
+ 'runme_0',
+ DEFAULT_DATE.isoformat(),
+ '--raw',
+ '--local',
+ ]
+ )
+ )
def test_task_state(self):
- task_command.task_state(self.parser.parse_args([
- 'tasks', 'state', 'example_bash_operator', 'runme_0',
- DEFAULT_DATE.isoformat()]))
+ task_command.task_state(
+ self.parser.parse_args(
+ ['tasks', 'state', 'example_bash_operator', 'runme_0', DEFAULT_DATE.isoformat()]
+ )
+ )
def test_task_states_for_dag_run(self):
@@ -188,58 +241,61 @@ def test_task_states_for_dag_run(self):
ti_end = ti2.end_date
with redirect_stdout(io.StringIO()) as stdout:
- task_command.task_states_for_dag_run(self.parser.parse_args([
- 'tasks', 'states-for-dag-run', 'example_python_operator', defaut_date2.isoformat()]))
+ task_command.task_states_for_dag_run(
+ self.parser.parse_args(
+ ['tasks', 'states-for-dag-run', 'example_python_operator', defaut_date2.isoformat()]
+ )
+ )
actual_out = stdout.getvalue()
- formatted_rows = [('example_python_operator',
- '2016-01-09 00:00:00+00:00',
- 'print_the_context',
- 'success',
- ti_start,
- ti_end)]
-
- expected = tabulate(formatted_rows,
- ['dag',
- 'exec_date',
- 'task',
- 'state',
- 'start_date',
- 'end_date'],
- tablefmt="plain")
+ formatted_rows = [
+ (
+ 'example_python_operator',
+ '2016-01-09 00:00:00+00:00',
+ 'print_the_context',
+ 'success',
+ ti_start,
+ ti_end,
+ )
+ ]
+
+ expected = tabulate(
+ formatted_rows, ['dag', 'exec_date', 'task', 'state', 'start_date', 'end_date'], tablefmt="plain"
+ )
# Check that prints, and log messages, are shown
self.assertIn(expected.replace("\n", ""), actual_out.replace("\n", ""))
def test_subdag_clear(self):
- args = self.parser.parse_args([
- 'tasks', 'clear', 'example_subdag_operator', '--yes'])
+ args = self.parser.parse_args(['tasks', 'clear', 'example_subdag_operator', '--yes'])
task_command.task_clear(args)
- args = self.parser.parse_args([
- 'tasks', 'clear', 'example_subdag_operator', '--yes', '--exclude-subdags'])
+ args = self.parser.parse_args(
+ ['tasks', 'clear', 'example_subdag_operator', '--yes', '--exclude-subdags']
+ )
task_command.task_clear(args)
def test_parentdag_downstream_clear(self):
- args = self.parser.parse_args([
- 'tasks', 'clear', 'example_subdag_operator.section-1', '--yes'])
+ args = self.parser.parse_args(['tasks', 'clear', 'example_subdag_operator.section-1', '--yes'])
task_command.task_clear(args)
- args = self.parser.parse_args([
- 'tasks', 'clear', 'example_subdag_operator.section-1', '--yes',
- '--exclude-parentdag'])
+ args = self.parser.parse_args(
+ ['tasks', 'clear', 'example_subdag_operator.section-1', '--yes', '--exclude-parentdag']
+ )
task_command.task_clear(args)
@pytest.mark.quarantined
def test_local_run(self):
- args = self.parser.parse_args([
- 'tasks',
- 'run',
- 'example_python_operator',
- 'print_the_context',
- '2018-04-27T08:39:51.298439+00:00',
- '--interactive',
- '--subdir',
- '/root/dags/example_python_operator.py'
- ])
+ args = self.parser.parse_args(
+ [
+ 'tasks',
+ 'run',
+ 'example_python_operator',
+ 'print_the_context',
+ '2018-04-27T08:39:51.298439+00:00',
+ '--interactive',
+ '--subdir',
+ '/root/dags/example_python_operator.py',
+ ]
+ )
dag = get_dag(args.subdir, args.dag_id)
reset(dag.dag_id)
@@ -253,7 +309,6 @@ def test_local_run(self):
class TestLogsfromTaskRunCommand(unittest.TestCase):
-
def setUp(self) -> None:
self.dag_id = "test_logging_dag"
self.task_id = "test_task"
@@ -299,40 +354,52 @@ def test_logging_with_run_task(self):
# as that is what gets displayed
with conf_vars({('core', 'dags_folder'): self.dag_path}):
- task_command.task_run(self.parser.parse_args([
- 'tasks', 'run', self.dag_id, self.task_id, '--local', self.execution_date_str]))
+ task_command.task_run(
+ self.parser.parse_args(
+ ['tasks', 'run', self.dag_id, self.task_id, '--local', self.execution_date_str]
+ )
+ )
with open(self.ti_log_file_path) as l_file:
logs = l_file.read()
- print(logs) # In case of a test failures this line would show detailed log
+ print(logs) # In case of a test failures this line would show detailed log
logs_list = logs.splitlines()
self.assertIn("INFO - Started process", logs)
self.assertIn(f"Subtask {self.task_id}", logs)
self.assertIn("standard_task_runner.py", logs)
- self.assertIn(f"INFO - Running: ['airflow', 'tasks', 'run', '{self.dag_id}', "
- f"'{self.task_id}', '{self.execution_date_str}',", logs)
+ self.assertIn(
+ f"INFO - Running: ['airflow', 'tasks', 'run', '{self.dag_id}', "
+ f"'{self.task_id}', '{self.execution_date_str}',",
+ logs,
+ )
self.assert_log_line("Log from DAG Logger", logs_list)
self.assert_log_line("Log from TI Logger", logs_list)
self.assert_log_line("Log from Print statement", logs_list, expect_from_logging_mixin=True)
- self.assertIn(f"INFO - Marking task as SUCCESS. dag_id={self.dag_id}, "
- f"task_id={self.task_id}, execution_date=20170101T000000", logs)
+ self.assertIn(
+ f"INFO - Marking task as SUCCESS. dag_id={self.dag_id}, "
+ f"task_id={self.task_id}, execution_date=20170101T000000",
+ logs,
+ )
@mock.patch("airflow.task.task_runner.standard_task_runner.CAN_FORK", False)
def test_logging_with_run_task_subprocess(self):
# We are not using self.assertLogs as we want to verify what actually is stored in the Log file
# as that is what gets displayed
with conf_vars({('core', 'dags_folder'): self.dag_path}):
- task_command.task_run(self.parser.parse_args([
- 'tasks', 'run', self.dag_id, self.task_id, '--local', self.execution_date_str]))
+ task_command.task_run(
+ self.parser.parse_args(
+ ['tasks', 'run', self.dag_id, self.task_id, '--local', self.execution_date_str]
+ )
+ )
with open(self.ti_log_file_path) as l_file:
logs = l_file.read()
- print(logs) # In case of a test failures this line would show detailed log
+ print(logs) # In case of a test failures this line would show detailed log
logs_list = logs.splitlines()
self.assertIn(f"Subtask {self.task_id}", logs)
@@ -341,10 +408,16 @@ def test_logging_with_run_task_subprocess(self):
self.assert_log_line("Log from TI Logger", logs_list)
self.assert_log_line("Log from Print statement", logs_list, expect_from_logging_mixin=True)
- self.assertIn(f"INFO - Running: ['airflow', 'tasks', 'run', '{self.dag_id}', "
- f"'{self.task_id}', '{self.execution_date_str}',", logs)
- self.assertIn(f"INFO - Marking task as SUCCESS. dag_id={self.dag_id}, "
- f"task_id={self.task_id}, execution_date=20170101T000000", logs)
+ self.assertIn(
+ f"INFO - Running: ['airflow', 'tasks', 'run', '{self.dag_id}', "
+ f"'{self.task_id}', '{self.execution_date_str}',",
+ logs,
+ )
+ self.assertIn(
+ f"INFO - Marking task as SUCCESS. dag_id={self.dag_id}, "
+ f"task_id={self.task_id}, execution_date=20170101T000000",
+ logs,
+ )
def test_log_file_template_with_run_task(self):
"""Verify that the taskinstance has the right context for log_filename_template"""
@@ -405,46 +478,43 @@ def test_run_ignores_all_dependencies(self):
dag.clear()
task0_id = 'test_run_dependent_task'
- args0 = ['tasks',
- 'run',
- '--ignore-all-dependencies',
- dag_id,
- task0_id,
- DEFAULT_DATE.isoformat()]
+ args0 = ['tasks', 'run', '--ignore-all-dependencies', dag_id, task0_id, DEFAULT_DATE.isoformat()]
task_command.task_run(self.parser.parse_args(args0))
- ti_dependent0 = TaskInstance(
- task=dag.get_task(task0_id),
- execution_date=DEFAULT_DATE)
+ ti_dependent0 = TaskInstance(task=dag.get_task(task0_id), execution_date=DEFAULT_DATE)
ti_dependent0.refresh_from_db()
self.assertEqual(ti_dependent0.state, State.FAILED)
task1_id = 'test_run_dependency_task'
- args1 = ['tasks',
- 'run',
- '--ignore-all-dependencies',
- dag_id,
- task1_id,
- (DEFAULT_DATE + timedelta(days=1)).isoformat()]
+ args1 = [
+ 'tasks',
+ 'run',
+ '--ignore-all-dependencies',
+ dag_id,
+ task1_id,
+ (DEFAULT_DATE + timedelta(days=1)).isoformat(),
+ ]
task_command.task_run(self.parser.parse_args(args1))
ti_dependency = TaskInstance(
- task=dag.get_task(task1_id),
- execution_date=DEFAULT_DATE + timedelta(days=1))
+ task=dag.get_task(task1_id), execution_date=DEFAULT_DATE + timedelta(days=1)
+ )
ti_dependency.refresh_from_db()
self.assertEqual(ti_dependency.state, State.FAILED)
task2_id = 'test_run_dependent_task'
- args2 = ['tasks',
- 'run',
- '--ignore-all-dependencies',
- dag_id,
- task2_id,
- (DEFAULT_DATE + timedelta(days=1)).isoformat()]
+ args2 = [
+ 'tasks',
+ 'run',
+ '--ignore-all-dependencies',
+ dag_id,
+ task2_id,
+ (DEFAULT_DATE + timedelta(days=1)).isoformat(),
+ ]
task_command.task_run(self.parser.parse_args(args2))
ti_dependent = TaskInstance(
- task=dag.get_task(task2_id),
- execution_date=DEFAULT_DATE + timedelta(days=1))
+ task=dag.get_task(task2_id), execution_date=DEFAULT_DATE + timedelta(days=1)
+ )
ti_dependent.refresh_from_db()
self.assertEqual(ti_dependent.state, State.SUCCESS)
diff --git a/tests/cli/commands/test_user_command.py b/tests/cli/commands/test_user_command.py
index 70595430ab5c4..188a977676b3c 100644
--- a/tests/cli/commands/test_user_command.py
+++ b/tests/cli/commands/test_user_command.py
@@ -47,6 +47,7 @@ def setUpClass(cls):
def setUp(self):
from airflow.www import app as application
+
self.app = application.create_app(testing=True)
self.appbuilder = self.app.appbuilder # pylint: disable=no-member
self.clear_roles_and_roles()
@@ -64,41 +65,94 @@ def clear_roles_and_roles(self):
self.appbuilder.sm.delete_role(role_name)
def test_cli_create_user_random_password(self):
- args = self.parser.parse_args([
- 'users', 'create', '--username', 'test1', '--lastname', 'doe',
- '--firstname', 'jon',
- '--email', 'jdoe@foo.com', '--role', 'Viewer', '--use-random-password'
- ])
+ args = self.parser.parse_args(
+ [
+ 'users',
+ 'create',
+ '--username',
+ 'test1',
+ '--lastname',
+ 'doe',
+ '--firstname',
+ 'jon',
+ '--email',
+ 'jdoe@foo.com',
+ '--role',
+ 'Viewer',
+ '--use-random-password',
+ ]
+ )
user_command.users_create(args)
def test_cli_create_user_supplied_password(self):
- args = self.parser.parse_args([
- 'users', 'create', '--username', 'test2', '--lastname', 'doe',
- '--firstname', 'jon',
- '--email', 'jdoe@apache.org', '--role', 'Viewer', '--password', 'test'
- ])
+ args = self.parser.parse_args(
+ [
+ 'users',
+ 'create',
+ '--username',
+ 'test2',
+ '--lastname',
+ 'doe',
+ '--firstname',
+ 'jon',
+ '--email',
+ 'jdoe@apache.org',
+ '--role',
+ 'Viewer',
+ '--password',
+ 'test',
+ ]
+ )
user_command.users_create(args)
def test_cli_delete_user(self):
- args = self.parser.parse_args([
- 'users', 'create', '--username', 'test3', '--lastname', 'doe',
- '--firstname', 'jon',
- '--email', 'jdoe@example.com', '--role', 'Viewer', '--use-random-password'
- ])
+ args = self.parser.parse_args(
+ [
+ 'users',
+ 'create',
+ '--username',
+ 'test3',
+ '--lastname',
+ 'doe',
+ '--firstname',
+ 'jon',
+ '--email',
+ 'jdoe@example.com',
+ '--role',
+ 'Viewer',
+ '--use-random-password',
+ ]
+ )
user_command.users_create(args)
- args = self.parser.parse_args([
- 'users', 'delete', '--username', 'test3',
- ])
+ args = self.parser.parse_args(
+ [
+ 'users',
+ 'delete',
+ '--username',
+ 'test3',
+ ]
+ )
user_command.users_delete(args)
def test_cli_list_users(self):
for i in range(0, 3):
- args = self.parser.parse_args([
- 'users', 'create', '--username', f'user{i}', '--lastname',
- 'doe', '--firstname', 'jon',
- '--email', f'jdoe+{i}@gmail.com', '--role', 'Viewer',
- '--use-random-password'
- ])
+ args = self.parser.parse_args(
+ [
+ 'users',
+ 'create',
+ '--username',
+ f'user{i}',
+ '--lastname',
+ 'doe',
+ '--firstname',
+ 'jon',
+ '--email',
+ f'jdoe+{i}@gmail.com',
+ '--role',
+ 'Viewer',
+ '--use-random-password',
+ ]
+ )
user_command.users_create(args)
with redirect_stdout(io.StringIO()) as stdout:
user_command.users_list(self.parser.parse_args(['users', 'list']))
@@ -122,15 +176,19 @@ def assert_user_not_in_roles(email, roles):
assert_user_not_in_roles(TEST_USER2_EMAIL, ['Public'])
users = [
{
- "username": "imported_user1", "lastname": "doe1",
- "firstname": "jon", "email": TEST_USER1_EMAIL,
- "roles": ["Admin", "Op"]
+ "username": "imported_user1",
+ "lastname": "doe1",
+ "firstname": "jon",
+ "email": TEST_USER1_EMAIL,
+ "roles": ["Admin", "Op"],
},
{
- "username": "imported_user2", "lastname": "doe2",
- "firstname": "jon", "email": TEST_USER2_EMAIL,
- "roles": ["Public"]
- }
+ "username": "imported_user2",
+ "lastname": "doe2",
+ "firstname": "jon",
+ "email": TEST_USER2_EMAIL,
+ "roles": ["Public"],
+ },
]
self._import_users_from_file(users)
@@ -139,15 +197,19 @@ def assert_user_not_in_roles(email, roles):
users = [
{
- "username": "imported_user1", "lastname": "doe1",
- "firstname": "jon", "email": TEST_USER1_EMAIL,
- "roles": ["Public"]
+ "username": "imported_user1",
+ "lastname": "doe1",
+ "firstname": "jon",
+ "email": TEST_USER1_EMAIL,
+ "roles": ["Public"],
},
{
- "username": "imported_user2", "lastname": "doe2",
- "firstname": "jon", "email": TEST_USER2_EMAIL,
- "roles": ["Admin"]
- }
+ "username": "imported_user2",
+ "lastname": "doe2",
+ "firstname": "jon",
+ "email": TEST_USER2_EMAIL,
+ "roles": ["Admin"],
+ },
]
self._import_users_from_file(users)
@@ -157,12 +219,20 @@ def assert_user_not_in_roles(email, roles):
assert_user_in_roles(TEST_USER2_EMAIL, ['Admin'])
def test_cli_export_users(self):
- user1 = {"username": "imported_user1", "lastname": "doe1",
- "firstname": "jon", "email": TEST_USER1_EMAIL,
- "roles": ["Public"]}
- user2 = {"username": "imported_user2", "lastname": "doe2",
- "firstname": "jon", "email": TEST_USER2_EMAIL,
- "roles": ["Admin"]}
+ user1 = {
+ "username": "imported_user1",
+ "lastname": "doe1",
+ "firstname": "jon",
+ "email": TEST_USER1_EMAIL,
+ "roles": ["Public"],
+ }
+ user2 = {
+ "username": "imported_user2",
+ "lastname": "doe2",
+ "firstname": "jon",
+ "email": TEST_USER2_EMAIL,
+ "roles": ["Admin"],
+ }
self._import_users_from_file([user1, user2])
users_filename = self._export_users_to_file()
@@ -190,63 +260,79 @@ def _import_users_from_file(self, user_list):
f.write(json_file_content.encode())
f.flush()
- args = self.parser.parse_args([
- 'users', 'import', f.name
- ])
+ args = self.parser.parse_args(['users', 'import', f.name])
user_command.users_import(args)
finally:
os.remove(f.name)
def _export_users_to_file(self):
f = tempfile.NamedTemporaryFile(delete=False)
- args = self.parser.parse_args([
- 'users', 'export', f.name
- ])
+ args = self.parser.parse_args(['users', 'export', f.name])
user_command.users_export(args)
return f.name
def test_cli_add_user_role(self):
- args = self.parser.parse_args([
- 'users', 'create', '--username', 'test4', '--lastname', 'doe',
- '--firstname', 'jon',
- '--email', TEST_USER1_EMAIL, '--role', 'Viewer', '--use-random-password'
- ])
+ args = self.parser.parse_args(
+ [
+ 'users',
+ 'create',
+ '--username',
+ 'test4',
+ '--lastname',
+ 'doe',
+ '--firstname',
+ 'jon',
+ '--email',
+ TEST_USER1_EMAIL,
+ '--role',
+ 'Viewer',
+ '--use-random-password',
+ ]
+ )
user_command.users_create(args)
self.assertFalse(
_does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Op'),
- "User should not yet be a member of role 'Op'"
+ "User should not yet be a member of role 'Op'",
)
- args = self.parser.parse_args([
- 'users', 'add-role', '--username', 'test4', '--role', 'Op'
- ])
+ args = self.parser.parse_args(['users', 'add-role', '--username', 'test4', '--role', 'Op'])
user_command.users_manage_role(args, remove=False)
self.assertTrue(
_does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Op'),
- "User should have been added to role 'Op'"
+ "User should have been added to role 'Op'",
)
def test_cli_remove_user_role(self):
- args = self.parser.parse_args([
- 'users', 'create', '--username', 'test4', '--lastname', 'doe',
- '--firstname', 'jon',
- '--email', TEST_USER1_EMAIL, '--role', 'Viewer', '--use-random-password'
- ])
+ args = self.parser.parse_args(
+ [
+ 'users',
+ 'create',
+ '--username',
+ 'test4',
+ '--lastname',
+ 'doe',
+ '--firstname',
+ 'jon',
+ '--email',
+ TEST_USER1_EMAIL,
+ '--role',
+ 'Viewer',
+ '--use-random-password',
+ ]
+ )
user_command.users_create(args)
self.assertTrue(
_does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Viewer'),
- "User should have been created with role 'Viewer'"
+ "User should have been created with role 'Viewer'",
)
- args = self.parser.parse_args([
- 'users', 'remove-role', '--username', 'test4', '--role', 'Viewer'
- ])
+ args = self.parser.parse_args(['users', 'remove-role', '--username', 'test4', '--role', 'Viewer'])
user_command.users_manage_role(args, remove=True)
self.assertFalse(
_does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Viewer'),
- "User should have been removed from role 'Viewer'"
+ "User should have been removed from role 'Viewer'",
)
diff --git a/tests/cli/commands/test_variable_command.py b/tests/cli/commands/test_variable_command.py
index 71a4d9842caf1..04c253921eba9 100644
--- a/tests/cli/commands/test_variable_command.py
+++ b/tests/cli/commands/test_variable_command.py
@@ -43,8 +43,7 @@ def tearDown(self):
def test_variables_set(self):
"""Test variable_set command"""
- variable_command.variables_set(self.parser.parse_args([
- 'variables', 'set', 'foo', 'bar']))
+ variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'foo', 'bar']))
self.assertIsNotNone(Variable.get("foo"))
self.assertRaises(KeyError, Variable.get, "foo1")
@@ -52,53 +51,48 @@ def test_variables_get(self):
Variable.set('foo', {'foo': 'bar'}, serialize_json=True)
with redirect_stdout(io.StringIO()) as stdout:
- variable_command.variables_get(self.parser.parse_args([
- 'variables', 'get', 'foo']))
+ variable_command.variables_get(self.parser.parse_args(['variables', 'get', 'foo']))
self.assertEqual('{\n "foo": "bar"\n}\n', stdout.getvalue())
def test_get_variable_default_value(self):
with redirect_stdout(io.StringIO()) as stdout:
- variable_command.variables_get(self.parser.parse_args([
- 'variables', 'get', 'baz', '--default', 'bar']))
+ variable_command.variables_get(
+ self.parser.parse_args(['variables', 'get', 'baz', '--default', 'bar'])
+ )
self.assertEqual("bar\n", stdout.getvalue())
def test_get_variable_missing_variable(self):
with self.assertRaises(SystemExit):
- variable_command.variables_get(self.parser.parse_args([
- 'variables', 'get', 'no-existing-VAR']))
+ variable_command.variables_get(self.parser.parse_args(['variables', 'get', 'no-existing-VAR']))
def test_variables_set_different_types(self):
"""Test storage of various data types"""
# Set a dict
- variable_command.variables_set(self.parser.parse_args([
- 'variables', 'set', 'dict', '{"foo": "oops"}']))
+ variable_command.variables_set(
+ self.parser.parse_args(['variables', 'set', 'dict', '{"foo": "oops"}'])
+ )
# Set a list
- variable_command.variables_set(self.parser.parse_args([
- 'variables', 'set', 'list', '["oops"]']))
+ variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'list', '["oops"]']))
# Set str
- variable_command.variables_set(self.parser.parse_args([
- 'variables', 'set', 'str', 'hello string']))
+ variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'str', 'hello string']))
# Set int
- variable_command.variables_set(self.parser.parse_args([
- 'variables', 'set', 'int', '42']))
+ variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'int', '42']))
# Set float
- variable_command.variables_set(self.parser.parse_args([
- 'variables', 'set', 'float', '42.0']))
+ variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'float', '42.0']))
# Set true
- variable_command.variables_set(self.parser.parse_args([
- 'variables', 'set', 'true', 'true']))
+ variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'true', 'true']))
# Set false
- variable_command.variables_set(self.parser.parse_args([
- 'variables', 'set', 'false', 'false']))
+ variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'false', 'false']))
# Set none
- variable_command.variables_set(self.parser.parse_args([
- 'variables', 'set', 'null', 'null']))
+ variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'null', 'null']))
# Export and then import
- variable_command.variables_export(self.parser.parse_args([
- 'variables', 'export', 'variables_types.json']))
- variable_command.variables_import(self.parser.parse_args([
- 'variables', 'import', 'variables_types.json']))
+ variable_command.variables_export(
+ self.parser.parse_args(['variables', 'export', 'variables_types.json'])
+ )
+ variable_command.variables_import(
+ self.parser.parse_args(['variables', 'import', 'variables_types.json'])
+ )
# Assert value
self.assertEqual({'foo': 'oops'}, Variable.get('dict', deserialize_json=True))
@@ -115,26 +109,21 @@ def test_variables_set_different_types(self):
def test_variables_list(self):
"""Test variable_list command"""
# Test command is received
- variable_command.variables_list(self.parser.parse_args([
- 'variables', 'list']))
+ variable_command.variables_list(self.parser.parse_args(['variables', 'list']))
def test_variables_delete(self):
"""Test variable_delete command"""
- variable_command.variables_set(self.parser.parse_args([
- 'variables', 'set', 'foo', 'bar']))
- variable_command.variables_delete(self.parser.parse_args([
- 'variables', 'delete', 'foo']))
+ variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'foo', 'bar']))
+ variable_command.variables_delete(self.parser.parse_args(['variables', 'delete', 'foo']))
self.assertRaises(KeyError, Variable.get, "foo")
def test_variables_import(self):
"""Test variables_import command"""
- variable_command.variables_import(self.parser.parse_args([
- 'variables', 'import', os.devnull]))
+ variable_command.variables_import(self.parser.parse_args(['variables', 'import', os.devnull]))
def test_variables_export(self):
"""Test variables_export command"""
- variable_command.variables_export(self.parser.parse_args([
- 'variables', 'export', os.devnull]))
+ variable_command.variables_export(self.parser.parse_args(['variables', 'export', os.devnull]))
def test_variables_isolation(self):
"""Test isolation of variables"""
@@ -142,30 +131,22 @@ def test_variables_isolation(self):
tmp2 = tempfile.NamedTemporaryFile(delete=True)
# First export
- variable_command.variables_set(self.parser.parse_args([
- 'variables', 'set', 'foo', '{"foo":"bar"}']))
- variable_command.variables_set(self.parser.parse_args([
- 'variables', 'set', 'bar', 'original']))
- variable_command.variables_export(self.parser.parse_args([
- 'variables', 'export', tmp1.name]))
+ variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'foo', '{"foo":"bar"}']))
+ variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'bar', 'original']))
+ variable_command.variables_export(self.parser.parse_args(['variables', 'export', tmp1.name]))
first_exp = open(tmp1.name)
- variable_command.variables_set(self.parser.parse_args([
- 'variables', 'set', 'bar', 'updated']))
- variable_command.variables_set(self.parser.parse_args([
- 'variables', 'set', 'foo', '{"foo":"oops"}']))
- variable_command.variables_delete(self.parser.parse_args([
- 'variables', 'delete', 'foo']))
- variable_command.variables_import(self.parser.parse_args([
- 'variables', 'import', tmp1.name]))
+ variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'bar', 'updated']))
+ variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'foo', '{"foo":"oops"}']))
+ variable_command.variables_delete(self.parser.parse_args(['variables', 'delete', 'foo']))
+ variable_command.variables_import(self.parser.parse_args(['variables', 'import', tmp1.name]))
self.assertEqual('original', Variable.get('bar'))
self.assertEqual('{\n "foo": "bar"\n}', Variable.get('foo'))
# Second export
- variable_command.variables_export(self.parser.parse_args([
- 'variables', 'export', tmp2.name]))
+ variable_command.variables_export(self.parser.parse_args(['variables', 'export', tmp2.name]))
second_exp = open(tmp2.name)
self.assertEqual(first_exp.read(), second_exp.read())
diff --git a/tests/cli/commands/test_webserver_command.py b/tests/cli/commands/test_webserver_command.py
index 2468466fb6513..0829d8d611fa0 100644
--- a/tests/cli/commands/test_webserver_command.py
+++ b/tests/cli/commands/test_webserver_command.py
@@ -34,7 +34,6 @@
class TestGunicornMonitor(unittest.TestCase):
-
def setUp(self) -> None:
self.monitor = GunicornMonitor(
gunicorn_master_pid=1,
@@ -124,8 +123,9 @@ def _prepare_test_file(filepath: str, size: int):
file.flush()
def test_should_detect_changes_in_directory(self):
- with tempfile.TemporaryDirectory() as tempdir,\
- mock.patch("airflow.cli.commands.webserver_command.settings.PLUGINS_FOLDER", tempdir):
+ with tempfile.TemporaryDirectory() as tempdir, mock.patch(
+ "airflow.cli.commands.webserver_command.settings.PLUGINS_FOLDER", tempdir
+ ):
self._prepare_test_file(f"{tempdir}/file1.txt", 100)
self._prepare_test_file(f"{tempdir}/nested/nested/nested/nested/file2.txt", 200)
self._prepare_test_file(f"{tempdir}/file3.txt", 300)
@@ -172,7 +172,6 @@ def test_should_detect_changes_in_directory(self):
class TestCLIGetNumReadyWorkersRunning(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
cls.parser = cli_parser.get_parser()
@@ -271,7 +270,7 @@ def test_cli_webserver_foreground(self):
"os.environ",
AIRFLOW__CORE__DAGS_FOLDER="/dev/null",
AIRFLOW__CORE__LOAD_EXAMPLES="False",
- AIRFLOW__WEBSERVER__WORKERS="1"
+ AIRFLOW__WEBSERVER__WORKERS="1",
):
# Run webserver in foreground and terminate it.
proc = subprocess.Popen(["airflow", "webserver"])
@@ -293,7 +292,7 @@ def test_cli_webserver_foreground_with_pid(self):
"os.environ",
AIRFLOW__CORE__DAGS_FOLDER="/dev/null",
AIRFLOW__CORE__LOAD_EXAMPLES="False",
- AIRFLOW__WEBSERVER__WORKERS="1"
+ AIRFLOW__WEBSERVER__WORKERS="1",
):
proc = subprocess.Popen(["airflow", "webserver", "--pid", pidfile])
self.assertEqual(None, proc.poll())
@@ -307,12 +306,12 @@ def test_cli_webserver_foreground_with_pid(self):
@pytest.mark.quarantined
def test_cli_webserver_background(self):
- with tempfile.TemporaryDirectory(prefix="gunicorn") as tmpdir, \
- mock.patch.dict(
- "os.environ",
- AIRFLOW__CORE__DAGS_FOLDER="/dev/null",
- AIRFLOW__CORE__LOAD_EXAMPLES="False",
- AIRFLOW__WEBSERVER__WORKERS="1"):
+ with tempfile.TemporaryDirectory(prefix="gunicorn") as tmpdir, mock.patch.dict(
+ "os.environ",
+ AIRFLOW__CORE__DAGS_FOLDER="/dev/null",
+ AIRFLOW__CORE__LOAD_EXAMPLES="False",
+ AIRFLOW__WEBSERVER__WORKERS="1",
+ ):
pidfile_webserver = f"{tmpdir}/pidflow-webserver.pid"
pidfile_monitor = f"{tmpdir}/pidflow-webserver-monitor.pid"
stdout = f"{tmpdir}/airflow-webserver.out"
@@ -320,15 +319,21 @@ def test_cli_webserver_background(self):
logfile = f"{tmpdir}/airflow-webserver.log"
try:
# Run webserver as daemon in background. Note that the wait method is not called.
- proc = subprocess.Popen([
- "airflow",
- "webserver",
- "--daemon",
- "--pid", pidfile_webserver,
- "--stdout", stdout,
- "--stderr", stderr,
- "--log-file", logfile,
- ])
+ proc = subprocess.Popen(
+ [
+ "airflow",
+ "webserver",
+ "--daemon",
+ "--pid",
+ pidfile_webserver,
+ "--stdout",
+ stdout,
+ "--stderr",
+ stderr,
+ "--log-file",
+ logfile,
+ ]
+ )
self.assertEqual(None, proc.poll())
pid_monitor = self._wait_pidfile(pidfile_monitor)
@@ -354,8 +359,9 @@ def test_cli_webserver_background(self):
raise
# Patch for causing webserver timeout
- @mock.patch("airflow.cli.commands.webserver_command.GunicornMonitor._get_num_workers_running",
- return_value=0)
+ @mock.patch(
+ "airflow.cli.commands.webserver_command.GunicornMonitor._get_num_workers_running", return_value=0
+ )
def test_cli_webserver_shutdown_when_gunicorn_master_is_killed(self, _):
# Shorten timeout so that this test doesn't take too long time
args = self.parser.parse_args(['webserver'])
@@ -370,8 +376,7 @@ def test_cli_webserver_debug(self):
sleep(3) # wait for webserver to start
return_code = proc.poll()
self.assertEqual(
- None,
- return_code,
- f"webserver terminated with return code {return_code} in debug mode")
+ None, return_code, f"webserver terminated with return code {return_code} in debug mode"
+ )
proc.terminate()
self.assertEqual(-15, proc.wait(60))
diff --git a/tests/cli/test_cli_parser.py b/tests/cli/test_cli_parser.py
index 9593d411149f5..4a2dd8524420f 100644
--- a/tests/cli/test_cli_parser.py
+++ b/tests/cli/test_cli_parser.py
@@ -35,34 +35,28 @@
class TestCli(TestCase):
-
def test_arg_option_long_only(self):
"""
Test if the name of cli.args long option valid
"""
optional_long = [
- arg
- for arg in cli_args.values()
- if len(arg.flags) == 1 and arg.flags[0].startswith("-")
+ arg for arg in cli_args.values() if len(arg.flags) == 1 and arg.flags[0].startswith("-")
]
for arg in optional_long:
- self.assertIsNone(ILLEGAL_LONG_OPTION_PATTERN.match(arg.flags[0]),
- f"{arg.flags[0]} is not match")
+ self.assertIsNone(ILLEGAL_LONG_OPTION_PATTERN.match(arg.flags[0]), f"{arg.flags[0]} is not match")
def test_arg_option_mix_short_long(self):
"""
Test if the name of cli.args mix option (-s, --long) valid
"""
optional_mix = [
- arg
- for arg in cli_args.values()
- if len(arg.flags) == 2 and arg.flags[0].startswith("-")
+ arg for arg in cli_args.values() if len(arg.flags) == 2 and arg.flags[0].startswith("-")
]
for arg in optional_mix:
- self.assertIsNotNone(LEGAL_SHORT_OPTION_PATTERN.match(arg.flags[0]),
- f"{arg.flags[0]} is not match")
- self.assertIsNone(ILLEGAL_LONG_OPTION_PATTERN.match(arg.flags[1]),
- f"{arg.flags[1]} is not match")
+ self.assertIsNotNone(
+ LEGAL_SHORT_OPTION_PATTERN.match(arg.flags[0]), f"{arg.flags[0]} is not match"
+ )
+ self.assertIsNone(ILLEGAL_LONG_OPTION_PATTERN.match(arg.flags[1]), f"{arg.flags[1]} is not match")
def test_subcommand_conflict(self):
"""
@@ -75,8 +69,9 @@ def test_subcommand_conflict(self):
}
for group_name, sub in subcommand.items():
name = [command.name.lower() for command in sub]
- self.assertEqual(len(name), len(set(name)),
- f"Command group {group_name} have conflict subcommand")
+ self.assertEqual(
+ len(name), len(set(name)), f"Command group {group_name} have conflict subcommand"
+ )
def test_subcommand_arg_name_conflict(self):
"""
@@ -90,9 +85,11 @@ def test_subcommand_arg_name_conflict(self):
for group, command in subcommand.items():
for com in command:
conflict_arg = [arg for arg, count in Counter(com.args).items() if count > 1]
- self.assertListEqual([], conflict_arg,
- f"Command group {group} function {com.name} have "
- f"conflict args name {conflict_arg}")
+ self.assertListEqual(
+ [],
+ conflict_arg,
+ f"Command group {group} function {com.name} have " f"conflict args name {conflict_arg}",
+ )
def test_subcommand_arg_flag_conflict(self):
"""
@@ -106,35 +103,35 @@ def test_subcommand_arg_flag_conflict(self):
for group, command in subcommand.items():
for com in command:
position = [
- a.flags[0]
- for a in com.args
- if (len(a.flags) == 1
- and not a.flags[0].startswith("-"))
+ a.flags[0] for a in com.args if (len(a.flags) == 1 and not a.flags[0].startswith("-"))
]
conflict_position = [arg for arg, count in Counter(position).items() if count > 1]
- self.assertListEqual([], conflict_position,
- f"Command group {group} function {com.name} have conflict "
- f"position flags {conflict_position}")
-
- long_option = [a.flags[0]
- for a in com.args
- if (len(a.flags) == 1
- and a.flags[0].startswith("-"))] + \
- [a.flags[1]
- for a in com.args if len(a.flags) == 2]
+ self.assertListEqual(
+ [],
+ conflict_position,
+ f"Command group {group} function {com.name} have conflict "
+ f"position flags {conflict_position}",
+ )
+
+ long_option = [
+ a.flags[0] for a in com.args if (len(a.flags) == 1 and a.flags[0].startswith("-"))
+ ] + [a.flags[1] for a in com.args if len(a.flags) == 2]
conflict_long_option = [arg for arg, count in Counter(long_option).items() if count > 1]
- self.assertListEqual([], conflict_long_option,
- f"Command group {group} function {com.name} have conflict "
- f"long option flags {conflict_long_option}")
-
- short_option = [
- a.flags[0]
- for a in com.args if len(a.flags) == 2
- ]
+ self.assertListEqual(
+ [],
+ conflict_long_option,
+ f"Command group {group} function {com.name} have conflict "
+ f"long option flags {conflict_long_option}",
+ )
+
+ short_option = [a.flags[0] for a in com.args if len(a.flags) == 2]
conflict_short_option = [arg for arg, count in Counter(short_option).items() if count > 1]
- self.assertEqual([], conflict_short_option,
- f"Command group {group} function {com.name} have conflict "
- f"short option flags {conflict_short_option}")
+ self.assertEqual(
+ [],
+ conflict_short_option,
+ f"Command group {group} function {com.name} have conflict "
+ f"short option flags {conflict_short_option}",
+ )
def test_falsy_default_value(self):
arg = cli_parser.Arg(("--test",), default=0, type=int)
@@ -166,9 +163,7 @@ def test_should_display_help(self):
for command_as_args in (
[[top_command.name]]
if isinstance(top_command, cli_parser.ActionCommand)
- else [
- [top_command.name, nested_command.name] for nested_command in top_command.subcommands
- ]
+ else [[top_command.name, nested_command.name] for nested_command in top_command.subcommands]
)
]
for cmd_args in all_command_as_args:
diff --git a/tests/cluster_policies/__init__.py b/tests/cluster_policies/__init__.py
index 4d9c4b356b020..77e5253f2254b 100644
--- a/tests/cluster_policies/__init__.py
+++ b/tests/cluster_policies/__init__.py
@@ -24,10 +24,12 @@
# [START example_cluster_policy_rule]
def task_must_have_owners(task: BaseOperator):
- if not task.owner or task.owner.lower() == conf.get('operators',
- 'default_owner'):
+ if not task.owner or task.owner.lower() == conf.get('operators', 'default_owner'):
raise AirflowClusterPolicyViolation(
- f'''Task must have non-None non-default owner. Current value: {task.owner}''')
+ f'''Task must have non-None non-default owner. Current value: {task.owner}'''
+ )
+
+
# [END example_cluster_policy_rule]
@@ -50,10 +52,13 @@ def _check_task_rules(current_task: BaseOperator):
raise AirflowClusterPolicyViolation(
f"DAG policy violation (DAG ID: {current_task.dag_id}, Path: {current_task.dag.filepath}):\n"
f"Notices:\n"
- f"{notices_list}")
+ f"{notices_list}"
+ )
def cluster_policy(task: BaseOperator):
"""Ensure Tasks have non-default owners."""
_check_task_rules(task)
+
+
# [END example_list_of_cluster_policy_rules]
diff --git a/tests/conftest.py b/tests/conftest.py
index dee74c1e327cb..98a7d1f5b0478 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -30,11 +30,12 @@
os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = os.path.join(tests_directory, "dags")
os.environ["AIRFLOW__CORE__UNIT_TEST_MODE"] = "True"
-os.environ["AWS_DEFAULT_REGION"] = (os.environ.get("AWS_DEFAULT_REGION") or "us-east-1")
-os.environ["CREDENTIALS_DIR"] = (os.environ.get('CREDENTIALS_DIR') or "/files/airflow-breeze-config/keys")
+os.environ["AWS_DEFAULT_REGION"] = os.environ.get("AWS_DEFAULT_REGION") or "us-east-1"
+os.environ["CREDENTIALS_DIR"] = os.environ.get('CREDENTIALS_DIR') or "/files/airflow-breeze-config/keys"
from tests.test_utils.perf.perf_kit.sqlalchemy import ( # noqa isort:skip # pylint: disable=wrong-import-position
- count_queries, trace_queries
+ count_queries,
+ trace_queries,
)
@@ -60,6 +61,7 @@ def reset_db():
"""
from airflow.utils import db
+
db.resetdb()
yield
@@ -94,11 +96,7 @@ def pytest_print(text):
if columns == ['num']:
# It is very unlikely that the user wants to display only numbers, but probably
# the user just wants to count the queries.
- exit_stack.enter_context( # pylint: disable=no-member
- count_queries(
- print_fn=pytest_print
- )
- )
+ exit_stack.enter_context(count_queries(print_fn=pytest_print)) # pylint: disable=no-member
elif any(c for c in ['time', 'trace', 'sql', 'parameters']):
exit_stack.enter_context( # pylint: disable=no-member
trace_queries(
@@ -107,7 +105,7 @@ def pytest_print(text):
display_trace='trace' in columns,
display_sql='sql' in columns,
display_parameters='parameters' in columns,
- print_fn=pytest_print
+ print_fn=pytest_print,
)
)
@@ -130,7 +128,7 @@ def pytest_addoption(parser):
action="append",
metavar="INTEGRATIONS",
help="only run tests matching integration specified: "
- "[cassandra,kerberos,mongo,openldap,presto,rabbitmq,redis]. ",
+ "[cassandra,kerberos,mongo,openldap,presto,rabbitmq,redis]. ",
)
group.addoption(
"--backend",
@@ -177,6 +175,7 @@ def initial_db_init():
os.system("airflow resetdb -y")
else:
from airflow.utils import db
+
db.resetdb()
@@ -193,6 +192,7 @@ def breeze_test_helper(request):
return
from airflow import __version__
+
if __version__.startswith("1.10"):
os.environ['RUN_AIRFLOW_1_10'] = "true"
@@ -235,18 +235,10 @@ def breeze_test_helper(request):
def pytest_configure(config):
- config.addinivalue_line(
- "markers", "integration(name): mark test to run with named integration"
- )
- config.addinivalue_line(
- "markers", "backend(name): mark test to run with named backend"
- )
- config.addinivalue_line(
- "markers", "system(name): mark test to run with named system"
- )
- config.addinivalue_line(
- "markers", "long_running: mark test that run for a long time (many minutes)"
- )
+ config.addinivalue_line("markers", "integration(name): mark test to run with named integration")
+ config.addinivalue_line("markers", "backend(name): mark test to run with named backend")
+ config.addinivalue_line("markers", "system(name): mark test to run with named system")
+ config.addinivalue_line("markers", "long_running: mark test that run for a long time (many minutes)")
config.addinivalue_line(
"markers", "quarantined: mark test that are in quarantine (i.e. flaky, need to be isolated and fixed)"
)
@@ -256,9 +248,7 @@ def pytest_configure(config):
config.addinivalue_line(
"markers", "credential_file(name): mark tests that require credential file in CREDENTIALS_DIR"
)
- config.addinivalue_line(
- "markers", "airflow_2: mark tests that works only on Airflow 2.0 / master"
- )
+ config.addinivalue_line("markers", "airflow_2: mark tests that works only on Airflow 2.0 / master")
def skip_if_not_marked_with_integration(selected_integrations, item):
@@ -266,10 +256,11 @@ def skip_if_not_marked_with_integration(selected_integrations, item):
integration_name = marker.args[0]
if integration_name in selected_integrations or "all" in selected_integrations:
return
- pytest.skip("The test is skipped because it does not have the right integration marker. "
- "Only tests marked with pytest.mark.integration(INTEGRATION) are run with INTEGRATION"
- " being one of {integration}. {item}".
- format(integration=selected_integrations, item=item))
+ pytest.skip(
+ "The test is skipped because it does not have the right integration marker. "
+ "Only tests marked with pytest.mark.integration(INTEGRATION) are run with INTEGRATION"
+ " being one of {integration}. {item}".format(integration=selected_integrations, item=item)
+ )
def skip_if_not_marked_with_backend(selected_backend, item):
@@ -277,10 +268,11 @@ def skip_if_not_marked_with_backend(selected_backend, item):
backend_names = marker.args
if selected_backend in backend_names:
return
- pytest.skip("The test is skipped because it does not have the right backend marker "
- "Only tests marked with pytest.mark.backend('{backend}') are run"
- ": {item}".
- format(backend=selected_backend, item=item))
+ pytest.skip(
+ "The test is skipped because it does not have the right backend marker "
+ "Only tests marked with pytest.mark.backend('{backend}') are run"
+ ": {item}".format(backend=selected_backend, item=item)
+ )
def skip_if_not_marked_with_system(selected_systems, item):
@@ -288,39 +280,46 @@ def skip_if_not_marked_with_system(selected_systems, item):
systems_name = marker.args[0]
if systems_name in selected_systems or "all" in selected_systems:
return
- pytest.skip("The test is skipped because it does not have the right system marker. "
- "Only tests marked with pytest.mark.system(SYSTEM) are run with SYSTEM"
- " being one of {systems}. {item}".
- format(systems=selected_systems, item=item))
+ pytest.skip(
+ "The test is skipped because it does not have the right system marker. "
+ "Only tests marked with pytest.mark.system(SYSTEM) are run with SYSTEM"
+ " being one of {systems}. {item}".format(systems=selected_systems, item=item)
+ )
def skip_system_test(item):
for marker in item.iter_markers(name="system"):
- pytest.skip("The test is skipped because it has system marker. "
- "System tests are only run when --system flag "
- "with the right system ({system}) is passed to pytest. {item}".
- format(system=marker.args[0], item=item))
+ pytest.skip(
+ "The test is skipped because it has system marker. "
+ "System tests are only run when --system flag "
+ "with the right system ({system}) is passed to pytest. {item}".format(
+ system=marker.args[0], item=item
+ )
+ )
def skip_long_running_test(item):
for _ in item.iter_markers(name="long_running"):
- pytest.skip("The test is skipped because it has long_running marker. "
- "And --include-long-running flag is not passed to pytest. {item}".
- format(item=item))
+ pytest.skip(
+ "The test is skipped because it has long_running marker. "
+ "And --include-long-running flag is not passed to pytest. {item}".format(item=item)
+ )
def skip_quarantined_test(item):
for _ in item.iter_markers(name="quarantined"):
- pytest.skip("The test is skipped because it has quarantined marker. "
- "And --include-quarantined flag is passed to pytest. {item}".
- format(item=item))
+ pytest.skip(
+ "The test is skipped because it has quarantined marker. "
+ "And --include-quarantined flag is passed to pytest. {item}".format(item=item)
+ )
def skip_heisen_test(item):
for _ in item.iter_markers(name="heisentests"):
- pytest.skip("The test is skipped because it has heisentests marker. "
- "And --include-heisentests flag is passed to pytest. {item}".
- format(item=item))
+ pytest.skip(
+ "The test is skipped because it has heisentests marker. "
+ "And --include-heisentests flag is passed to pytest. {item}".format(item=item)
+ )
def skip_if_integration_disabled(marker, item):
@@ -328,12 +327,17 @@ def skip_if_integration_disabled(marker, item):
environment_variable_name = "INTEGRATION_" + integration_name.upper()
environment_variable_value = os.environ.get(environment_variable_name)
if not environment_variable_value or environment_variable_value != "true":
- pytest.skip("The test requires {integration_name} integration started and "
- "{name} environment variable to be set to true (it is '{value}')."
- " It can be set by specifying '--integration {integration_name}' at breeze startup"
- ": {item}".
- format(name=environment_variable_name, value=environment_variable_value,
- integration_name=integration_name, item=item))
+ pytest.skip(
+ "The test requires {integration_name} integration started and "
+ "{name} environment variable to be set to true (it is '{value}')."
+ " It can be set by specifying '--integration {integration_name}' at breeze startup"
+ ": {item}".format(
+ name=environment_variable_name,
+ value=environment_variable_value,
+ integration_name=integration_name,
+ item=item,
+ )
+ )
def skip_if_wrong_backend(marker, item):
@@ -341,12 +345,17 @@ def skip_if_wrong_backend(marker, item):
environment_variable_name = "BACKEND"
environment_variable_value = os.environ.get(environment_variable_name)
if not environment_variable_value or environment_variable_value not in valid_backend_names:
- pytest.skip("The test requires one of {valid_backend_names} backend started and "
- "{name} environment variable to be set to 'true' (it is '{value}')."
- " It can be set by specifying backend at breeze startup"
- ": {item}".
- format(name=environment_variable_name, value=environment_variable_value,
- valid_backend_names=valid_backend_names, item=item))
+ pytest.skip(
+ "The test requires one of {valid_backend_names} backend started and "
+ "{name} environment variable to be set to 'true' (it is '{value}')."
+ " It can be set by specifying backend at breeze startup"
+ ": {item}".format(
+ name=environment_variable_name,
+ value=environment_variable_value,
+ valid_backend_names=valid_backend_names,
+ item=item,
+ )
+ )
def skip_if_credential_file_missing(item):
@@ -354,8 +363,7 @@ def skip_if_credential_file_missing(item):
credential_file = marker.args[0]
credential_path = os.path.join(os.environ.get('CREDENTIALS_DIR'), credential_file)
if not os.path.exists(credential_path):
- pytest.skip("The test requires credential file {path}: {item}".
- format(path=credential_path, item=item))
+ pytest.skip(f"The test requires credential file {credential_path}: {item}")
def skip_if_airflow_2_test(item):
diff --git a/tests/core/test_config_templates.py b/tests/core/test_config_templates.py
index af6cc4cc91e12..f603739074bb5 100644
--- a/tests/core/test_config_templates.py
+++ b/tests/core/test_config_templates.py
@@ -69,24 +69,34 @@
'admin',
'elasticsearch',
'elasticsearch_configs',
- 'kubernetes'
+ 'kubernetes',
]
class TestAirflowCfg(unittest.TestCase):
- @parameterized.expand([
- ("default_airflow.cfg",),
- ("default_test.cfg",),
- ])
+ @parameterized.expand(
+ [
+ ("default_airflow.cfg",),
+ ("default_test.cfg",),
+ ]
+ )
def test_should_be_ascii_file(self, filename: str):
with open(os.path.join(CONFIG_TEMPLATES_FOLDER, filename), "rb") as f:
content = f.read().decode("ascii")
self.assertTrue(content)
- @parameterized.expand([
- ("default_airflow.cfg", DEFAULT_AIRFLOW_SECTIONS,),
- ("default_test.cfg", DEFAULT_TEST_SECTIONS,),
- ])
+ @parameterized.expand(
+ [
+ (
+ "default_airflow.cfg",
+ DEFAULT_AIRFLOW_SECTIONS,
+ ),
+ (
+ "default_test.cfg",
+ DEFAULT_TEST_SECTIONS,
+ ),
+ ]
+ )
def test_should_be_ini_file(self, filename: str, expected_sections):
filepath = os.path.join(CONFIG_TEMPLATES_FOLDER, filename)
config = configparser.ConfigParser()
diff --git a/tests/core/test_configuration.py b/tests/core/test_configuration.py
index 7a07c3be02105..e202ce5b51a77 100644
--- a/tests/core/test_configuration.py
+++ b/tests/core/test_configuration.py
@@ -26,48 +26,49 @@
from airflow import configuration
from airflow.configuration import (
- DEFAULT_CONFIG, AirflowConfigException, AirflowConfigParser, conf, expand_env_var, get_airflow_config,
- get_airflow_home, parameterized_config, run_command,
+ DEFAULT_CONFIG,
+ AirflowConfigException,
+ AirflowConfigParser,
+ conf,
+ expand_env_var,
+ get_airflow_config,
+ get_airflow_home,
+ parameterized_config,
+ run_command,
)
from tests.test_utils.config import conf_vars
from tests.test_utils.reset_warning_registry import reset_warning_registry
-@unittest.mock.patch.dict('os.environ', {
- 'AIRFLOW__TESTSECTION__TESTKEY': 'testvalue',
- 'AIRFLOW__TESTSECTION__TESTPERCENT': 'with%percent',
- 'AIRFLOW__TESTCMDENV__ITSACOMMAND_CMD': 'echo -n "OK"',
- 'AIRFLOW__TESTCMDENV__NOTACOMMAND_CMD': 'echo -n "NOT OK"'
-})
+@unittest.mock.patch.dict(
+ 'os.environ',
+ {
+ 'AIRFLOW__TESTSECTION__TESTKEY': 'testvalue',
+ 'AIRFLOW__TESTSECTION__TESTPERCENT': 'with%percent',
+ 'AIRFLOW__TESTCMDENV__ITSACOMMAND_CMD': 'echo -n "OK"',
+ 'AIRFLOW__TESTCMDENV__NOTACOMMAND_CMD': 'echo -n "NOT OK"',
+ },
+)
class TestConf(unittest.TestCase):
-
def test_airflow_home_default(self):
with unittest.mock.patch.dict('os.environ'):
if 'AIRFLOW_HOME' in os.environ:
del os.environ['AIRFLOW_HOME']
- self.assertEqual(
- get_airflow_home(),
- expand_env_var('~/airflow'))
+ self.assertEqual(get_airflow_home(), expand_env_var('~/airflow'))
def test_airflow_home_override(self):
with unittest.mock.patch.dict('os.environ', AIRFLOW_HOME='/path/to/airflow'):
- self.assertEqual(
- get_airflow_home(),
- '/path/to/airflow')
+ self.assertEqual(get_airflow_home(), '/path/to/airflow')
def test_airflow_config_default(self):
with unittest.mock.patch.dict('os.environ'):
if 'AIRFLOW_CONFIG' in os.environ:
del os.environ['AIRFLOW_CONFIG']
- self.assertEqual(
- get_airflow_config('/home/airflow'),
- expand_env_var('/home/airflow/airflow.cfg'))
+ self.assertEqual(get_airflow_config('/home/airflow'), expand_env_var('/home/airflow/airflow.cfg'))
def test_airflow_config_override(self):
with unittest.mock.patch.dict('os.environ', AIRFLOW_CONFIG='/path/to/airflow/airflow.cfg'):
- self.assertEqual(
- get_airflow_config('/home//airflow'),
- '/path/to/airflow/airflow.cfg')
+ self.assertEqual(get_airflow_config('/home//airflow'), '/path/to/airflow/airflow.cfg')
@conf_vars({("core", "percent"): "with%%inside"})
def test_case_sensitivity(self):
@@ -87,15 +88,13 @@ def test_env_var_config(self):
self.assertTrue(conf.has_option('testsection', 'testkey'))
with unittest.mock.patch.dict(
- 'os.environ',
- AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY='nested'
+ 'os.environ', AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY='nested'
):
opt = conf.get('kubernetes_environment_variables', 'AIRFLOW__TESTSECTION__TESTKEY')
self.assertEqual(opt, 'nested')
@mock.patch.dict(
- 'os.environ',
- AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY='nested'
+ 'os.environ', AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY='nested'
)
@conf_vars({("core", "percent"): "with%%inside"})
def test_conf_as_dict(self):
@@ -109,18 +108,15 @@ def test_conf_as_dict(self):
# test env vars
self.assertEqual(cfg_dict['testsection']['testkey'], '< hidden >')
self.assertEqual(
- cfg_dict['kubernetes_environment_variables']['AIRFLOW__TESTSECTION__TESTKEY'],
- '< hidden >')
+ cfg_dict['kubernetes_environment_variables']['AIRFLOW__TESTSECTION__TESTKEY'], '< hidden >'
+ )
def test_conf_as_dict_source(self):
# test display_source
cfg_dict = conf.as_dict(display_source=True)
- self.assertEqual(
- cfg_dict['core']['load_examples'][1], 'airflow.cfg')
- self.assertEqual(
- cfg_dict['core']['load_default_connections'][1], 'airflow.cfg')
- self.assertEqual(
- cfg_dict['testsection']['testkey'], ('< hidden >', 'env var'))
+ self.assertEqual(cfg_dict['core']['load_examples'][1], 'airflow.cfg')
+ self.assertEqual(cfg_dict['core']['load_default_connections'][1], 'airflow.cfg')
+ self.assertEqual(cfg_dict['testsection']['testkey'], ('< hidden >', 'env var'))
def test_conf_as_dict_sensitive(self):
# test display_sensitive
@@ -130,8 +126,7 @@ def test_conf_as_dict_sensitive(self):
# test display_source and display_sensitive
cfg_dict = conf.as_dict(display_sensitive=True, display_source=True)
- self.assertEqual(
- cfg_dict['testsection']['testkey'], ('testvalue', 'env var'))
+ self.assertEqual(cfg_dict['testsection']['testkey'], ('testvalue', 'env var'))
@conf_vars({("core", "percent"): "with%%inside"})
def test_conf_as_dict_raw(self):
@@ -166,8 +161,7 @@ def test_command_precedence(self):
key6 = value6
'''
- test_conf = AirflowConfigParser(
- default_config=parameterized_config(test_config_default))
+ test_conf = AirflowConfigParser(default_config=parameterized_config(test_config_default))
test_conf.read_string(test_config)
test_conf.sensitive_config_values = test_conf.sensitive_config_values | {
('test', 'key2'),
@@ -204,10 +198,12 @@ def test_command_precedence(self):
self.assertEqual('printf key4_result', cfg_dict['test']['key4_cmd'])
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
- @conf_vars({
- ("secrets", "backend"): "airflow.providers.hashicorp.secrets.vault.VaultBackend",
- ("secrets", "backend_kwargs"): '{"url": "http://127.0.0.1:8200", "token": "token"}',
- })
+ @conf_vars(
+ {
+ ("secrets", "backend"): "airflow.providers.hashicorp.secrets.vault.VaultBackend",
+ ("secrets", "backend_kwargs"): '{"url": "http://127.0.0.1:8200", "token": "token"}',
+ }
+ )
def test_config_from_secret_backend(self, mock_hvac):
"""Get Config Value from a Secret Backend"""
mock_client = mock.MagicMock()
@@ -217,14 +213,18 @@ def test_config_from_secret_backend(self, mock_hvac):
'lease_id': '',
'renewable': False,
'lease_duration': 0,
- 'data': {'data': {'value': 'sqlite:////Users/airflow/airflow/airflow.db'},
- 'metadata': {'created_time': '2020-03-28T02:10:54.301784Z',
- 'deletion_time': '',
- 'destroyed': False,
- 'version': 1}},
+ 'data': {
+ 'data': {'value': 'sqlite:////Users/airflow/airflow/airflow.db'},
+ 'metadata': {
+ 'created_time': '2020-03-28T02:10:54.301784Z',
+ 'deletion_time': '',
+ 'destroyed': False,
+ 'version': 1,
+ },
+ },
'wrap_info': None,
'warnings': None,
- 'auth': None
+ 'auth': None,
}
test_config = '''[test]
@@ -241,7 +241,8 @@ def test_config_from_secret_backend(self, mock_hvac):
}
self.assertEqual(
- 'sqlite:////Users/airflow/airflow/airflow.db', test_conf.get('test', 'sql_alchemy_conn'))
+ 'sqlite:////Users/airflow/airflow/airflow.db', test_conf.get('test', 'sql_alchemy_conn')
+ )
def test_getboolean(self):
"""Test AirflowConfigParser.getboolean"""
@@ -268,7 +269,7 @@ def test_getboolean(self):
re.escape(
'Failed to convert value to bool. Please check "key1" key in "type_validation" section. '
'Current value: "non_bool_value".'
- )
+ ),
):
test_conf.getboolean('type_validation', 'key1')
self.assertTrue(isinstance(test_conf.getboolean('true', 'key3'), bool))
@@ -295,7 +296,7 @@ def test_getint(self):
re.escape(
'Failed to convert value to int. Please check "key1" key in "invalid" section. '
'Current value: "str".'
- )
+ ),
):
test_conf.getint('invalid', 'key1')
self.assertTrue(isinstance(test_conf.getint('valid', 'key2'), int))
@@ -316,7 +317,7 @@ def test_getfloat(self):
re.escape(
'Failed to convert value to float. Please check "key1" key in "invalid" section. '
'Current value: "str".'
- )
+ ),
):
test_conf.getfloat('invalid', 'key1')
self.assertTrue(isinstance(test_conf.getfloat('valid', 'key2'), float))
@@ -342,8 +343,7 @@ def test_remove_option(self):
key2 = airflow
'''
- test_conf = AirflowConfigParser(
- default_config=parameterized_config(test_config_default))
+ test_conf = AirflowConfigParser(default_config=parameterized_config(test_config_default))
test_conf.read_string(test_config)
self.assertEqual('hello', test_conf.get('test', 'key1'))
@@ -366,20 +366,13 @@ def test_getsection(self):
[testsection]
key3 = value3
'''
- test_conf = AirflowConfigParser(
- default_config=parameterized_config(test_config_default))
+ test_conf = AirflowConfigParser(default_config=parameterized_config(test_config_default))
test_conf.read_string(test_config)
+ self.assertEqual(OrderedDict([('key1', 'hello'), ('key2', 'airflow')]), test_conf.getsection('test'))
self.assertEqual(
- OrderedDict([('key1', 'hello'), ('key2', 'airflow')]),
- test_conf.getsection('test')
- )
- self.assertEqual(
- OrderedDict([
- ('key3', 'value3'),
- ('testkey', 'testvalue'),
- ('testpercent', 'with%percent')]),
- test_conf.getsection('testsection')
+ OrderedDict([('key3', 'value3'), ('testkey', 'testvalue'), ('testpercent', 'with%percent')]),
+ test_conf.getsection('testsection'),
)
def test_get_section_should_respect_cmd_env_variable(self):
@@ -390,9 +383,7 @@ def test_get_section_should_respect_cmd_env_variable(self):
os.chmod(cmd_file.name, 0o0555)
cmd_file.close()
- with mock.patch.dict(
- "os.environ", {"AIRFLOW__KUBERNETES__GIT_PASSWORD_CMD": cmd_file.name}
- ):
+ with mock.patch.dict("os.environ", {"AIRFLOW__KUBERNETES__GIT_PASSWORD_CMD": cmd_file.name}):
content = conf.getsection("kubernetes")
os.unlink(cmd_file.name)
self.assertEqual(content["git_password"], "difficult_unpredictable_cat_password")
@@ -406,13 +397,12 @@ def test_kubernetes_environment_variables_section(self):
test_config_default = '''
[kubernetes_environment_variables]
'''
- test_conf = AirflowConfigParser(
- default_config=parameterized_config(test_config_default))
+ test_conf = AirflowConfigParser(default_config=parameterized_config(test_config_default))
test_conf.read_string(test_config)
self.assertEqual(
OrderedDict([('key1', 'hello'), ('AIRFLOW_HOME', '/root/airflow')]),
- test_conf.getsection('kubernetes_environment_variables')
+ test_conf.getsection('kubernetes_environment_variables'),
)
def test_broker_transport_options(self):
@@ -422,10 +412,12 @@ def test_broker_transport_options(self):
self.assertTrue(isinstance(section_dict['_test_only_float'], float))
self.assertTrue(isinstance(section_dict['_test_only_string'], str))
- @conf_vars({
- ("celery", "worker_concurrency"): None,
- ("celery", "celeryd_concurrency"): None,
- })
+ @conf_vars(
+ {
+ ("celery", "worker_concurrency"): None,
+ ("celery", "celeryd_concurrency"): None,
+ }
+ )
def test_deprecated_options(self):
# Guarantee we have a deprecated setting, so we test the deprecation
# lookup even if we remove this explicit fallback
@@ -443,10 +435,12 @@ def test_deprecated_options(self):
with self.assertWarns(DeprecationWarning), conf_vars({('celery', 'celeryd_concurrency'): '99'}):
self.assertEqual(conf.getint('celery', 'worker_concurrency'), 99)
- @conf_vars({
- ('logging', 'logging_level'): None,
- ('core', 'logging_level'): None,
- })
+ @conf_vars(
+ {
+ ('logging', 'logging_level'): None,
+ ('core', 'logging_level'): None,
+ }
+ )
def test_deprecated_options_with_new_section(self):
# Guarantee we have a deprecated setting, so we test the deprecation
# lookup even if we remove this explicit fallback
@@ -465,11 +459,13 @@ def test_deprecated_options_with_new_section(self):
with self.assertWarns(DeprecationWarning), conf_vars({('core', 'logging_level'): 'VALUE'}):
self.assertEqual(conf.get('logging', 'logging_level'), "VALUE")
- @conf_vars({
- ("celery", "result_backend"): None,
- ("celery", "celery_result_backend"): None,
- ("celery", "celery_result_backend_cmd"): None,
- })
+ @conf_vars(
+ {
+ ("celery", "result_backend"): None,
+ ("celery", "celery_result_backend"): None,
+ ("celery", "celery_result_backend_cmd"): None,
+ }
+ )
def test_deprecated_options_cmd(self):
# Guarantee we have a deprecated setting, so we test the deprecation
# lookup even if we remove this explicit fallback
@@ -497,14 +493,16 @@ def make_config():
'hostname_callable': (re.compile(r':'), r'.', '2.0'),
},
}
- test_conf.read_dict({
- 'core': {
- 'executor': 'SequentialExecutor',
- 'task_runner': 'BashTaskRunner',
- 'sql_alchemy_conn': 'sqlite://',
- 'hostname_callable': 'socket:getfqdn',
- },
- })
+ test_conf.read_dict(
+ {
+ 'core': {
+ 'executor': 'SequentialExecutor',
+ 'task_runner': 'BashTaskRunner',
+ 'sql_alchemy_conn': 'sqlite://',
+ 'hostname_callable': 'socket:getfqdn',
+ },
+ }
+ )
return test_conf
with self.assertWarns(FutureWarning):
@@ -526,9 +524,11 @@ def make_config():
with reset_warning_registry():
with warnings.catch_warnings(record=True) as warning:
- with unittest.mock.patch.dict('os.environ',
- AIRFLOW__CORE__TASK_RUNNER='NotBashTaskRunner',
- AIRFLOW__CORE__HOSTNAME_CALLABLE='CarrierPigeon'):
+ with unittest.mock.patch.dict(
+ 'os.environ',
+ AIRFLOW__CORE__TASK_RUNNER='NotBashTaskRunner',
+ AIRFLOW__CORE__HOSTNAME_CALLABLE='CarrierPigeon',
+ ):
test_conf = make_config()
self.assertEqual(test_conf.get('core', 'task_runner'), 'NotBashTaskRunner')
@@ -537,8 +537,17 @@ def make_config():
self.assertListEqual([], warning)
def test_deprecated_funcs(self):
- for func in ['load_test_config', 'get', 'getboolean', 'getfloat', 'getint', 'has_option',
- 'remove_option', 'as_dict', 'set']:
+ for func in [
+ 'load_test_config',
+ 'get',
+ 'getboolean',
+ 'getfloat',
+ 'getint',
+ 'has_option',
+ 'remove_option',
+ 'as_dict',
+ 'set',
+ ]:
with mock.patch(f'airflow.configuration.conf.{func}') as mock_method:
with self.assertWarns(DeprecationWarning):
getattr(configuration, func)()
@@ -583,10 +592,7 @@ def test_config_use_original_when_original_and_fallback_are_present(self):
fernet_key = conf.get('core', 'FERNET_KEY')
with conf_vars({('core', 'FERNET_KEY_CMD'): 'printf HELLO'}):
- fallback_fernet_key = conf.get(
- "core",
- "FERNET_KEY"
- )
+ fallback_fernet_key = conf.get("core", "FERNET_KEY")
self.assertEqual(fernet_key, fallback_fernet_key)
@@ -648,10 +654,7 @@ def test_store_dag_code_default_config(self):
self.assertTrue(store_serialized_dags)
self.assertTrue(store_dag_code)
- @conf_vars({
- ("core", "store_serialized_dags"): "True",
- ("core", "store_dag_code"): "False"
- })
+ @conf_vars({("core", "store_serialized_dags"): "True", ("core", "store_dag_code"): "False"})
def test_store_dag_code_config_when_set(self):
store_serialized_dags = conf.getboolean('core', 'store_serialized_dags', fallback=False)
store_dag_code = conf.getboolean("core", "store_dag_code", fallback=store_serialized_dags)
diff --git a/tests/core/test_core.py b/tests/core/test_core.py
index 71a22368d7461..44f4aa2f9633a 100644
--- a/tests/core/test_core.py
+++ b/tests/core/test_core.py
@@ -54,6 +54,7 @@ class OperatorSubclass(BaseOperator):
"""
An operator to test template substitution
"""
+
template_fields = ['some_templated_field']
def __init__(self, some_templated_field, *args, **kwargs):
@@ -78,15 +79,11 @@ def setUp(self):
def tearDown(self):
session = Session()
- session.query(DagRun).filter(
- DagRun.dag_id == TEST_DAG_ID).delete(
- synchronize_session=False)
- session.query(TaskInstance).filter(
- TaskInstance.dag_id == TEST_DAG_ID).delete(
- synchronize_session=False)
- session.query(TaskFail).filter(
- TaskFail.dag_id == TEST_DAG_ID).delete(
- synchronize_session=False)
+ session.query(DagRun).filter(DagRun.dag_id == TEST_DAG_ID).delete(synchronize_session=False)
+ session.query(TaskInstance).filter(TaskInstance.dag_id == TEST_DAG_ID).delete(
+ synchronize_session=False
+ )
+ session.query(TaskFail).filter(TaskFail.dag_id == TEST_DAG_ID).delete(synchronize_session=False)
session.commit()
session.close()
clear_db_dags()
@@ -102,10 +99,8 @@ def test_check_operators(self):
self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)
op = CheckOperator(
- task_id='check',
- sql="select count(*) from operator_test_table",
- conn_id=conn_id,
- dag=self.dag)
+ task_id='check', sql="select count(*) from operator_test_table", conn_id=conn_id, dag=self.dag
+ )
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@@ -115,16 +110,15 @@ def test_check_operators(self):
tolerance=0.1,
conn_id=conn_id,
sql="SELECT 100",
- dag=self.dag)
+ dag=self.dag,
+ )
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
captain_hook.run("drop table operator_test_table")
def test_clear_api(self):
task = self.dag_bash.tasks[0]
- task.clear(
- start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
- upstream=True, downstream=True)
+ task.clear(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, upstream=True, downstream=True)
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.are_dependents_done()
@@ -139,7 +133,8 @@ def test_illegal_args(self):
task_id='test_illegal_args',
bash_command='echo success',
dag=self.dag,
- illegal_argument_1234='hello?')
+ illegal_argument_1234='hello?',
+ )
assert any(msg in str(w) for w in warning.warnings)
def test_illegal_args_forbidden(self):
@@ -152,17 +147,15 @@ def test_illegal_args_forbidden(self):
task_id='test_illegal_args',
bash_command='echo success',
dag=self.dag,
- illegal_argument_1234='hello?')
+ illegal_argument_1234='hello?',
+ )
self.assertIn(
- ('Invalid arguments were passed to BashOperator '
- '(task_id: test_illegal_args).'),
- str(ctx.exception))
+ ('Invalid arguments were passed to BashOperator ' '(task_id: test_illegal_args).'),
+ str(ctx.exception),
+ )
def test_bash_operator(self):
- op = BashOperator(
- task_id='test_bash_operator',
- bash_command="echo success",
- dag=self.dag)
+ op = BashOperator(task_id='test_bash_operator', bash_command="echo success", dag=self.dag)
self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@@ -172,22 +165,22 @@ def test_bash_operator_multi_byte_output(self):
task_id='test_multi_byte_bash_operator',
bash_command="echo \u2600",
dag=self.dag,
- output_encoding='utf-8')
+ output_encoding='utf-8',
+ )
self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_bash_operator_kill(self):
import psutil
+
sleep_time = "100%d" % os.getpid()
op = BashOperator(
task_id='test_bash_operator_kill',
execution_timeout=timedelta(seconds=1),
bash_command="/bin/bash -c 'sleep %s'" % sleep_time,
- dag=self.dag)
- self.assertRaises(
- AirflowTaskTimeout,
- op.run,
- start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ dag=self.dag,
+ )
+ self.assertRaises(AirflowTaskTimeout, op.run, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
sleep(2)
pid = -1
for proc in psutil.process_iter():
@@ -210,26 +203,23 @@ def check_failure(context, test_case=self):
task_id='check_on_failure_callback',
bash_command="exit 1",
dag=self.dag,
- on_failure_callback=check_failure)
+ on_failure_callback=check_failure,
+ )
self.assertRaises(
- AirflowException,
- op.run,
- start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ AirflowException, op.run, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
+ )
self.assertTrue(data['called'])
def test_dryrun(self):
- op = BashOperator(
- task_id='test_dryrun',
- bash_command="echo success",
- dag=self.dag)
+ op = BashOperator(task_id='test_dryrun', bash_command="echo success", dag=self.dag)
op.dry_run()
def test_sqlite(self):
import airflow.providers.sqlite.operators.sqlite
+
op = airflow.providers.sqlite.operators.sqlite.SqliteOperator(
- task_id='time_sqlite',
- sql="CREATE TABLE IF NOT EXISTS unitest (dummy VARCHAR(20))",
- dag=self.dag)
+ task_id='time_sqlite', sql="CREATE TABLE IF NOT EXISTS unitest (dummy VARCHAR(20))", dag=self.dag
+ )
self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@@ -238,11 +228,11 @@ def test_timeout(self):
task_id='test_timeout',
execution_timeout=timedelta(seconds=1),
python_callable=lambda: sleep(5),
- dag=self.dag)
+ dag=self.dag,
+ )
self.assertRaises(
- AirflowTaskTimeout,
- op.run,
- start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ AirflowTaskTimeout, op.run, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
+ )
def test_python_op(self):
def test_py_op(templates_dict, ds, **kwargs):
@@ -250,25 +240,20 @@ def test_py_op(templates_dict, ds, **kwargs):
raise Exception("failure")
op = PythonOperator(
- task_id='test_py_op',
- python_callable=test_py_op,
- templates_dict={'ds': "{{ ds }}"},
- dag=self.dag)
+ task_id='test_py_op', python_callable=test_py_op, templates_dict={'ds': "{{ ds }}"}, dag=self.dag
+ )
self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_complex_template(self):
def verify_templated_field(context):
- self.assertEqual(context['ti'].task.some_templated_field['bar'][1],
- context['ds'])
+ self.assertEqual(context['ti'].task.some_templated_field['bar'][1], context['ds'])
op = OperatorSubclass(
task_id='test_complex_template',
- some_templated_field={
- 'foo': '123',
- 'bar': ['baz', '{{ ds }}']
- },
- dag=self.dag)
+ some_templated_field={'foo': '123', 'bar': ['baz', '{{ ds }}']},
+ dag=self.dag,
+ )
op.execute = verify_templated_field
self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@@ -286,20 +271,16 @@ def __bool__(self): # pylint: disable=invalid-bool-returned, bad-option-value
return NotImplemented
op = OperatorSubclass(
- task_id='test_bad_template_obj',
- some_templated_field=NonBoolObject(),
- dag=self.dag)
+ task_id='test_bad_template_obj', some_templated_field=NonBoolObject(), dag=self.dag
+ )
op.resolve_template_files()
def test_task_get_template(self):
TI = TaskInstance
- ti = TI(
- task=self.runme_0, execution_date=DEFAULT_DATE)
+ ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE)
ti.dag = self.dag_bash
self.dag_bash.create_dagrun(
- run_type=DagRunType.MANUAL,
- state=State.RUNNING,
- execution_date=DEFAULT_DATE
+ run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE
)
ti.run(ignore_ti_state=True)
context = ti.get_template_context()
@@ -328,20 +309,16 @@ def test_task_get_template(self):
def test_local_task_job(self):
TI = TaskInstance
- ti = TI(
- task=self.runme_0, execution_date=DEFAULT_DATE)
+ ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE)
job = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
job.run()
def test_raw_job(self):
TI = TaskInstance
- ti = TI(
- task=self.runme_0, execution_date=DEFAULT_DATE)
+ ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE)
ti.dag = self.dag_bash
self.dag_bash.create_dagrun(
- run_type=DagRunType.MANUAL,
- state=State.RUNNING,
- execution_date=DEFAULT_DATE
+ run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE
)
ti.run(ignore_ti_state=True)
@@ -353,20 +330,16 @@ def test_round_time(self):
rt2 = round_time(datetime(2015, 1, 2), relativedelta(months=1))
self.assertEqual(datetime(2015, 1, 1, 0, 0), rt2)
- rt3 = round_time(datetime(2015, 9, 16, 0, 0), timedelta(1), datetime(
- 2015, 9, 14, 0, 0))
+ rt3 = round_time(datetime(2015, 9, 16, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0))
self.assertEqual(datetime(2015, 9, 16, 0, 0), rt3)
- rt4 = round_time(datetime(2015, 9, 15, 0, 0), timedelta(1), datetime(
- 2015, 9, 14, 0, 0))
+ rt4 = round_time(datetime(2015, 9, 15, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0))
self.assertEqual(datetime(2015, 9, 15, 0, 0), rt4)
- rt5 = round_time(datetime(2015, 9, 14, 0, 0), timedelta(1), datetime(
- 2015, 9, 14, 0, 0))
+ rt5 = round_time(datetime(2015, 9, 14, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0))
self.assertEqual(datetime(2015, 9, 14, 0, 0), rt5)
- rt6 = round_time(datetime(2015, 9, 13, 0, 0), timedelta(1), datetime(
- 2015, 9, 14, 0, 0))
+ rt6 = round_time(datetime(2015, 9, 13, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0))
self.assertEqual(datetime(2015, 9, 14, 0, 0), rt6)
def test_infer_time_unit(self):
@@ -390,29 +363,25 @@ def test_scale_time_units(self):
assert_array_almost_equal(arr2, [110.0, 50.0, 10.0, 100.0], decimal=3)
arr3 = scale_time_units([100000, 50000, 10000, 20000], 'hours')
- assert_array_almost_equal(arr3, [27.778, 13.889, 2.778, 5.556],
- decimal=3)
+ assert_array_almost_equal(arr3, [27.778, 13.889, 2.778, 5.556], decimal=3)
arr4 = scale_time_units([200000, 100000], 'days')
assert_array_almost_equal(arr4, [2.315, 1.157], decimal=3)
def test_bad_trigger_rule(self):
with self.assertRaises(AirflowException):
- DummyOperator(
- task_id='test_bad_trigger',
- trigger_rule="non_existent",
- dag=self.dag)
+ DummyOperator(task_id='test_bad_trigger', trigger_rule="non_existent", dag=self.dag)
def test_terminate_task(self):
"""If a task instance's db state get deleted, it should fail"""
from airflow.executors.sequential_executor import SequentialExecutor
+
TI = TaskInstance
dag = self.dagbag.dags.get('test_utils')
task = dag.task_dict.get('sleeps_forever')
ti = TI(task=task, execution_date=DEFAULT_DATE)
- job = LocalTaskJob(
- task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
+ job = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
# Running task instance asynchronously
proc = multiprocessing.Process(target=job.run)
@@ -423,11 +392,11 @@ def test_terminate_task(self):
ti.refresh_from_db(session=session)
# making sure it's actually running
self.assertEqual(State.RUNNING, ti.state)
- ti = session.query(TI).filter_by(
- dag_id=task.dag_id,
- task_id=task.task_id,
- execution_date=DEFAULT_DATE
- ).one()
+ ti = (
+ session.query(TI)
+ .filter_by(dag_id=task.dag_id, task_id=task.task_id, execution_date=DEFAULT_DATE)
+ .one()
+ )
# deleting the instance should result in a failure
session.delete(ti)
@@ -443,16 +412,14 @@ def test_terminate_task(self):
def test_task_fail_duration(self):
"""If a task fails, the duration should be recorded in TaskFail"""
- op1 = BashOperator(
- task_id='pass_sleepy',
- bash_command='sleep 3',
- dag=self.dag)
+ op1 = BashOperator(task_id='pass_sleepy', bash_command='sleep 3', dag=self.dag)
op2 = BashOperator(
task_id='fail_sleepy',
bash_command='sleep 5',
execution_timeout=timedelta(seconds=3),
retry_delay=timedelta(seconds=0),
- dag=self.dag)
+ dag=self.dag,
+ )
session = settings.Session()
try:
op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@@ -462,14 +429,16 @@ def test_task_fail_duration(self):
op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
except Exception: # pylint: disable=broad-except
pass
- op1_fails = session.query(TaskFail).filter_by(
- task_id='pass_sleepy',
- dag_id=self.dag.dag_id,
- execution_date=DEFAULT_DATE).all()
- op2_fails = session.query(TaskFail).filter_by(
- task_id='fail_sleepy',
- dag_id=self.dag.dag_id,
- execution_date=DEFAULT_DATE).all()
+ op1_fails = (
+ session.query(TaskFail)
+ .filter_by(task_id='pass_sleepy', dag_id=self.dag.dag_id, execution_date=DEFAULT_DATE)
+ .all()
+ )
+ op2_fails = (
+ session.query(TaskFail)
+ .filter_by(task_id='fail_sleepy', dag_id=self.dag.dag_id, execution_date=DEFAULT_DATE)
+ .all()
+ )
self.assertEqual(0, len(op1_fails))
self.assertEqual(1, len(op2_fails))
@@ -484,18 +453,16 @@ def test_externally_triggered_dagrun(self):
execution_ds_nodash = execution_ds.replace('-', '')
dag = DAG(
- TEST_DAG_ID,
- default_args=self.args,
- schedule_interval=timedelta(weeks=1),
- start_date=DEFAULT_DATE)
- task = DummyOperator(task_id='test_externally_triggered_dag_context',
- dag=dag)
- dag.create_dagrun(run_type=DagRunType.SCHEDULED,
- execution_date=execution_date,
- state=State.RUNNING,
- external_trigger=True)
- task.run(
- start_date=execution_date, end_date=execution_date)
+ TEST_DAG_ID, default_args=self.args, schedule_interval=timedelta(weeks=1), start_date=DEFAULT_DATE
+ )
+ task = DummyOperator(task_id='test_externally_triggered_dag_context', dag=dag)
+ dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ execution_date=execution_date,
+ state=State.RUNNING,
+ external_trigger=True,
+ )
+ task.run(start_date=execution_date, end_date=execution_date)
ti = TI(task=task, execution_date=execution_date)
context = ti.get_template_context()
diff --git a/tests/core/test_core_to_contrib.py b/tests/core/test_core_to_contrib.py
index e44287ccc7b26..e91e69d987525 100644
--- a/tests/core/test_core_to_contrib.py
+++ b/tests/core/test_core_to_contrib.py
@@ -35,9 +35,7 @@ def assert_warning(msg: str, warning: Any):
assert any(msg in str(w) for w in warning.warnings), error
def assert_is_subclass(self, clazz, other):
- self.assertTrue(
- issubclass(clazz, other), f"{clazz} is not subclass of {other}"
- )
+ self.assertTrue(issubclass(clazz, other), f"{clazz} is not subclass of {other}")
def assert_proper_import(self, old_resource, new_resource):
new_path, _, _ = new_resource.rpartition(".")
@@ -66,9 +64,7 @@ def get_class_from_path(path_to_class, parent=False):
if isabstract(class_) and not parent:
class_name = f"Mock({class_.__name__})"
- attributes = {
- a: mock.MagicMock() for a in class_.__abstractmethods__
- }
+ attributes = {a: mock.MagicMock() for a in class_.__abstractmethods__}
new_class = type(class_name, (class_,), attributes)
return new_class
@@ -111,9 +107,7 @@ def test_no_redirect_to_deprecated_classes(self):
This will tell us to use new_A instead of old_B.
"""
- all_classes_by_old = {
- old: new for new, old in ALL
- }
+ all_classes_by_old = {old: new for new, old in ALL}
for new, old in ALL:
# Using if statement allows us to create a developer-friendly message only when we need it.
diff --git a/tests/core/test_example_dags_system.py b/tests/core/test_example_dags_system.py
index 0ce411d5e5222..4d16f0c24e6d8 100644
--- a/tests/core/test_example_dags_system.py
+++ b/tests/core/test_example_dags_system.py
@@ -23,12 +23,14 @@
@pytest.mark.system("core")
class TestExampleDagsSystem(SystemTest):
- @parameterized.expand([
- "example_bash_operator",
- "example_branch_operator",
- "tutorial_etl_dag",
- "tutorial_functional_etl_dag",
- "example_dag_decorator",
- ])
+ @parameterized.expand(
+ [
+ "example_bash_operator",
+ "example_branch_operator",
+ "tutorial_etl_dag",
+ "tutorial_functional_etl_dag",
+ "example_dag_decorator",
+ ]
+ )
def test_dag_example(self, dag_id):
self.run_dag(dag_id=dag_id)
diff --git a/tests/core/test_impersonation_tests.py b/tests/core/test_impersonation_tests.py
index 56c5cdf43b0c3..5fe10e03c5afd 100644
--- a/tests/core/test_impersonation_tests.py
+++ b/tests/core/test_impersonation_tests.py
@@ -35,12 +35,9 @@
from airflow.utils.timezone import datetime
DEV_NULL = '/dev/null'
-TEST_DAG_FOLDER = os.path.join(
- os.path.dirname(os.path.realpath(__file__)), 'dags')
-TEST_DAG_CORRUPTED_FOLDER = os.path.join(
- os.path.dirname(os.path.realpath(__file__)), 'dags_corrupted')
-TEST_UTILS_FOLDER = os.path.join(
- os.path.dirname(os.path.realpath(__file__)), 'test_utils')
+TEST_DAG_FOLDER = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'dags')
+TEST_DAG_CORRUPTED_FOLDER = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'dags_corrupted')
+TEST_UTILS_FOLDER = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_utils')
DEFAULT_DATE = datetime(2015, 1, 1)
TEST_USER = 'airflow_test_user'
@@ -54,6 +51,7 @@ def mock_custom_module_path(path: str):
the :envvar:`PYTHONPATH` environment variable set and sets the environment variable
:envvar:`PYTHONPATH` to change the module load directory for child scripts.
"""
+
def wrapper(func):
@functools.wraps(func)
def decorator(*args, **kwargs):
@@ -64,6 +62,7 @@ def decorator(*args, **kwargs):
return func(*args, **kwargs)
finally:
sys.path = copy_sys_path
+
return decorator
return wrapper
@@ -72,28 +71,31 @@ def decorator(*args, **kwargs):
def grant_permissions():
airflow_home = os.environ['AIRFLOW_HOME']
subprocess.check_call(
- 'find "%s" -exec sudo chmod og+w {} +; sudo chmod og+rx /root' % airflow_home, shell=True)
+ 'find "%s" -exec sudo chmod og+w {} +; sudo chmod og+rx /root' % airflow_home, shell=True
+ )
def revoke_permissions():
airflow_home = os.environ['AIRFLOW_HOME']
subprocess.check_call(
- 'find "%s" -exec sudo chmod og-w {} +; sudo chmod og-rx /root' % airflow_home, shell=True)
+ 'find "%s" -exec sudo chmod og-w {} +; sudo chmod og-rx /root' % airflow_home, shell=True
+ )
def check_original_docker_image():
if not os.path.isfile('/.dockerenv') or os.environ.get('PYTHON_BASE_IMAGE') is None:
- raise unittest.SkipTest("""Adding/removing a user as part of a test is very bad for host os
+ raise unittest.SkipTest(
+ """Adding/removing a user as part of a test is very bad for host os
(especially if the user already existed to begin with on the OS), therefore we check if we run inside
the official docker container and only allow running the test there. This is done by checking /.dockerenv
file (always present inside container) and checking for PYTHON_BASE_IMAGE variable.
-""")
+"""
+ )
def create_user():
try:
- subprocess.check_output(['sudo', 'useradd', '-m', TEST_USER, '-g',
- str(os.getegid())])
+ subprocess.check_output(['sudo', 'useradd', '-m', TEST_USER, '-g', str(os.getegid())])
except OSError as e:
if e.errno == errno.ENOENT:
raise unittest.SkipTest(
@@ -111,7 +113,6 @@ def create_user():
@pytest.mark.quarantined
class TestImpersonation(unittest.TestCase):
-
def setUp(self):
check_original_docker_image()
grant_permissions()
@@ -133,14 +134,9 @@ def run_backfill(self, dag_id, task_id):
dag = self.dagbag.get_dag(dag_id)
dag.clear()
- BackfillJob(
- dag=dag,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE).run()
+ BackfillJob(dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE).run()
- ti = models.TaskInstance(
- task=dag.get_task(task_id),
- execution_date=DEFAULT_DATE)
+ ti = models.TaskInstance(task=dag.get_task(task_id), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
self.assertEqual(ti.state, State.SUCCESS)
@@ -149,10 +145,7 @@ def test_impersonation(self):
"""
Tests that impersonating a unix user works
"""
- self.run_backfill(
- 'test_impersonation',
- 'test_impersonated_user'
- )
+ self.run_backfill('test_impersonation', 'test_impersonated_user')
def test_no_impersonation(self):
"""
@@ -170,25 +163,18 @@ def test_default_impersonation(self):
If default_impersonation=TEST_USER, tests that the job defaults
to running as TEST_USER for a test without run_as_user set
"""
- self.run_backfill(
- 'test_default_impersonation',
- 'test_deelevated_user'
- )
+ self.run_backfill('test_default_impersonation', 'test_deelevated_user')
def test_impersonation_subdag(self):
"""
Tests that impersonation using a subdag correctly passes the right configuration
:return:
"""
- self.run_backfill(
- 'impersonation_subdag',
- 'test_subdag_operation'
- )
+ self.run_backfill('impersonation_subdag', 'test_subdag_operation')
@pytest.mark.quarantined
class TestImpersonationWithCustomPythonPath(unittest.TestCase):
-
@mock_custom_module_path(TEST_UTILS_FOLDER)
def setUp(self):
check_original_docker_image()
@@ -211,14 +197,9 @@ def run_backfill(self, dag_id, task_id):
dag = self.dagbag.get_dag(dag_id)
dag.clear()
- BackfillJob(
- dag=dag,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE).run()
+ BackfillJob(dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE).run()
- ti = models.TaskInstance(
- task=dag.get_task(task_id),
- execution_date=DEFAULT_DATE)
+ ti = models.TaskInstance(task=dag.get_task(task_id), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
self.assertEqual(ti.state, State.SUCCESS)
@@ -232,7 +213,4 @@ def test_impersonation_custom(self):
# PYTHONPATH is already set in script triggering tests
assert 'PYTHONPATH' in os.environ
- self.run_backfill(
- 'impersonation_with_custom_pkg',
- 'exec_python_fn'
- )
+ self.run_backfill('impersonation_with_custom_pkg', 'exec_python_fn')
diff --git a/tests/core/test_local_settings.py b/tests/core/test_local_settings.py
index 9f32d3583a09f..8eacf0e542f5c 100644
--- a/tests/core/test_local_settings.py
+++ b/tests/core/test_local_settings.py
@@ -96,6 +96,7 @@ def test_initialize_order(self, prepare_syspath, import_local_settings):
mock.attach_mock(import_local_settings, "import_local_settings")
import airflow.settings
+
airflow.settings.initialize()
mock.assert_has_calls([call.prepare_syspath(), call.import_local_settings()])
@@ -107,6 +108,7 @@ def test_import_with_dunder_all_not_specified(self):
"""
with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL, "airflow_local_settings"):
from airflow import settings
+
settings.import_local_settings()
with self.assertRaises(AttributeError):
@@ -119,6 +121,7 @@ def test_import_with_dunder_all(self):
"""
with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL, "airflow_local_settings"):
from airflow import settings
+
settings.import_local_settings()
task_instance = MagicMock()
@@ -133,6 +136,7 @@ def test_import_local_settings_without_syspath(self, log_mock):
if there is no airflow_local_settings module on the syspath.
"""
from airflow import settings
+
settings.import_local_settings()
log_mock.assert_called_once_with("Failed to import airflow_local_settings.", exc_info=True)
@@ -143,6 +147,7 @@ def test_policy_function(self):
"""
with SettingsContext(SETTINGS_FILE_POLICY, "airflow_local_settings"):
from airflow import settings
+
settings.import_local_settings()
task_instance = MagicMock()
@@ -157,6 +162,7 @@ def test_pod_mutation_hook(self):
"""
with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK, "airflow_local_settings"):
from airflow import settings
+
settings.import_local_settings()
pod = MagicMock()
@@ -167,6 +173,7 @@ def test_pod_mutation_hook(self):
def test_custom_policy(self):
with SettingsContext(SETTINGS_FILE_CUSTOM_POLICY, "airflow_local_settings"):
from airflow import settings
+
settings.import_local_settings()
task_instance = MagicMock()
diff --git a/tests/core/test_logging_config.py b/tests/core/test_logging_config.py
index 99b05fd927ab9..9c6983319c1a0 100644
--- a/tests/core/test_logging_config.py
+++ b/tests/core/test_logging_config.py
@@ -188,6 +188,7 @@ def tearDown(self):
# When we try to load an invalid config file, we expect an error
def test_loading_invalid_local_settings(self):
from airflow.logging_config import configure_logging, log
+
with settings_context(SETTINGS_FILE_INVALID):
with patch.object(log, 'error') as mock_info:
# Load config
@@ -204,56 +205,54 @@ def test_loading_valid_complex_local_settings(self):
dir_structure = module_structure.replace('.', '/')
with settings_context(SETTINGS_FILE_VALID, dir_structure):
from airflow.logging_config import configure_logging, log
+
with patch.object(log, 'info') as mock_info:
configure_logging()
mock_info.assert_called_once_with(
'Successfully imported user-defined logging config from %s',
- 'etc.airflow.config.{}.LOGGING_CONFIG'.format(
- SETTINGS_DEFAULT_NAME
- )
+ f'etc.airflow.config.{SETTINGS_DEFAULT_NAME}.LOGGING_CONFIG',
)
# When we try to load a valid config
def test_loading_valid_local_settings(self):
with settings_context(SETTINGS_FILE_VALID):
from airflow.logging_config import configure_logging, log
+
with patch.object(log, 'info') as mock_info:
configure_logging()
mock_info.assert_called_once_with(
'Successfully imported user-defined logging config from %s',
- '{}.LOGGING_CONFIG'.format(
- SETTINGS_DEFAULT_NAME
- )
+ f'{SETTINGS_DEFAULT_NAME}.LOGGING_CONFIG',
)
# When we load an empty file, it should go to default
def test_loading_no_local_settings(self):
with settings_context(SETTINGS_FILE_EMPTY):
from airflow.logging_config import configure_logging
+
with self.assertRaises(ImportError):
configure_logging()
# When the key is not available in the configuration
def test_when_the_config_key_does_not_exists(self):
from airflow import logging_config
+
with conf_vars({('logging', 'logging_config_class'): None}):
with patch.object(logging_config.log, 'debug') as mock_debug:
logging_config.configure_logging()
- mock_debug.assert_any_call(
- 'Could not find key logging_config_class in config'
- )
+ mock_debug.assert_any_call('Could not find key logging_config_class in config')
# Just default
def test_loading_local_settings_without_logging_config(self):
from airflow.logging_config import configure_logging, log
+
with patch.object(log, 'debug') as mock_info:
configure_logging()
- mock_info.assert_called_once_with(
- 'Unable to load custom logging, using default config instead'
- )
+ mock_info.assert_called_once_with('Unable to load custom logging, using default config instead')
def test_1_9_config(self):
from airflow.logging_config import configure_logging
+
with conf_vars({('logging', 'task_log_reader'): 'file.task'}):
with self.assertWarnsRegex(DeprecationWarning, r'file.task'):
configure_logging()
@@ -265,11 +264,13 @@ def test_loading_remote_logging_with_wasb_handler(self):
from airflow.logging_config import configure_logging
from airflow.utils.log.wasb_task_handler import WasbTaskHandler
- with conf_vars({
- ('logging', 'remote_logging'): 'True',
- ('logging', 'remote_log_conn_id'): 'some_wasb',
- ('logging', 'remote_base_log_folder'): 'wasb://some-folder',
- }):
+ with conf_vars(
+ {
+ ('logging', 'remote_logging'): 'True',
+ ('logging', 'remote_log_conn_id'): 'some_wasb',
+ ('logging', 'remote_base_log_folder'): 'wasb://some-folder',
+ }
+ ):
importlib.reload(airflow_local_settings)
configure_logging()
diff --git a/tests/core/test_sentry.py b/tests/core/test_sentry.py
index 3d5979a25ddfd..bfff5843fdfc4 100644
--- a/tests/core/test_sentry.py
+++ b/tests/core/test_sentry.py
@@ -64,6 +64,7 @@ class TestSentryHook(unittest.TestCase):
@conf_vars({('sentry', 'sentry_on'): 'True'})
def setUp(self):
from airflow import sentry
+
importlib.reload(sentry)
self.sentry = sentry.ConfiguredSentry()
diff --git a/tests/core/test_sqlalchemy_config.py b/tests/core/test_sqlalchemy_config.py
index 6f0ee94b9c141..a7cd08b727bd7 100644
--- a/tests/core/test_sqlalchemy_config.py
+++ b/tests/core/test_sqlalchemy_config.py
@@ -25,13 +25,7 @@
from airflow.exceptions import AirflowConfigException
from tests.test_utils.config import conf_vars
-SQL_ALCHEMY_CONNECT_ARGS = {
- 'test': 43503,
- 'dict': {
- 'is': 1,
- 'supported': 'too'
- }
-}
+SQL_ALCHEMY_CONNECT_ARGS = {'test': 43503, 'dict': {'is': 1, 'supported': 'too'}}
class TestSqlAlchemySettings(unittest.TestCase):
@@ -50,11 +44,9 @@ def tearDown(self):
@patch('airflow.settings.scoped_session')
@patch('airflow.settings.sessionmaker')
@patch('airflow.settings.create_engine')
- def test_configure_orm_with_default_values(self,
- mock_create_engine,
- mock_sessionmaker,
- mock_scoped_session,
- mock_setup_event_handlers):
+ def test_configure_orm_with_default_values(
+ self, mock_create_engine, mock_sessionmaker, mock_scoped_session, mock_setup_event_handlers
+ ):
settings.configure_orm()
mock_create_engine.assert_called_once_with(
settings.SQL_ALCHEMY_CONN,
@@ -63,22 +55,22 @@ def test_configure_orm_with_default_values(self,
max_overflow=10,
pool_pre_ping=True,
pool_recycle=1800,
- pool_size=5
+ pool_size=5,
)
@patch('airflow.settings.setup_event_handlers')
@patch('airflow.settings.scoped_session')
@patch('airflow.settings.sessionmaker')
@patch('airflow.settings.create_engine')
- def test_sql_alchemy_connect_args(self,
- mock_create_engine,
- mock_sessionmaker,
- mock_scoped_session,
- mock_setup_event_handlers):
+ def test_sql_alchemy_connect_args(
+ self, mock_create_engine, mock_sessionmaker, mock_scoped_session, mock_setup_event_handlers
+ ):
config = {
- ('core', 'sql_alchemy_connect_args'):
- 'tests.core.test_sqlalchemy_config.SQL_ALCHEMY_CONNECT_ARGS',
- ('core', 'sql_alchemy_pool_enabled'): 'False'
+ (
+ 'core',
+ 'sql_alchemy_connect_args',
+ ): 'tests.core.test_sqlalchemy_config.SQL_ALCHEMY_CONNECT_ARGS',
+ ('core', 'sql_alchemy_pool_enabled'): 'False',
}
with conf_vars(config):
settings.configure_orm()
@@ -86,21 +78,19 @@ def test_sql_alchemy_connect_args(self,
settings.SQL_ALCHEMY_CONN,
connect_args=SQL_ALCHEMY_CONNECT_ARGS,
poolclass=NullPool,
- encoding='utf-8'
+ encoding='utf-8',
)
@patch('airflow.settings.setup_event_handlers')
@patch('airflow.settings.scoped_session')
@patch('airflow.settings.sessionmaker')
@patch('airflow.settings.create_engine')
- def test_sql_alchemy_invalid_connect_args(self,
- mock_create_engine,
- mock_sessionmaker,
- mock_scoped_session,
- mock_setup_event_handlers):
+ def test_sql_alchemy_invalid_connect_args(
+ self, mock_create_engine, mock_sessionmaker, mock_scoped_session, mock_setup_event_handlers
+ ):
config = {
('core', 'sql_alchemy_connect_args'): 'does.not.exist',
- ('core', 'sql_alchemy_pool_enabled'): 'False'
+ ('core', 'sql_alchemy_pool_enabled'): 'False',
}
with self.assertRaises(AirflowConfigException):
with conf_vars(config):
diff --git a/tests/core/test_stats.py b/tests/core/test_stats.py
index fd2acd7135ad7..9f9adeaf14a06 100644
--- a/tests/core/test_stats.py
+++ b/tests/core/test_stats.py
@@ -48,6 +48,7 @@ class InvalidCustomStatsd:
This custom Statsd class is invalid because it does not subclass
statsd.StatsClient.
"""
+
incr_calls = 0
def __init__(self, host=None, port=None, prefix=None):
@@ -62,7 +63,6 @@ def _reset(cls):
class TestStats(unittest.TestCase):
-
def setUp(self):
self.statsd_client = Mock()
self.stats = SafeStatsdLogger(self.statsd_client)
@@ -85,52 +85,54 @@ def test_stat_name_must_only_include_allowed_characters(self):
self.stats.incr('test/$tats')
self.statsd_client.assert_not_called()
- @conf_vars({
- ('scheduler', 'statsd_on'): 'True'
- })
+ @conf_vars({('scheduler', 'statsd_on'): 'True'})
@mock.patch("statsd.StatsClient")
def test_does_send_stats_using_statsd(self, mock_statsd):
importlib.reload(airflow.stats)
airflow.stats.Stats.incr("dummy_key")
mock_statsd.return_value.incr.assert_called_once_with('dummy_key', 1, 1)
- @conf_vars({
- ('scheduler', 'statsd_on'): 'True'
- })
+ @conf_vars({('scheduler', 'statsd_on'): 'True'})
@mock.patch("datadog.DogStatsd")
def test_does_not_send_stats_using_dogstatsd(self, mock_dogstatsd):
importlib.reload(airflow.stats)
airflow.stats.Stats.incr("dummy_key")
mock_dogstatsd.return_value.assert_not_called()
- @conf_vars({
- ("scheduler", "statsd_on"): "True",
- ("scheduler", "statsd_custom_client_path"): "tests.core.test_stats.CustomStatsd",
- })
+ @conf_vars(
+ {
+ ("scheduler", "statsd_on"): "True",
+ ("scheduler", "statsd_custom_client_path"): "tests.core.test_stats.CustomStatsd",
+ }
+ )
def test_load_custom_statsd_client(self):
importlib.reload(airflow.stats)
self.assertEqual('CustomStatsd', type(airflow.stats.Stats.statsd).__name__)
- @conf_vars({
- ("scheduler", "statsd_on"): "True",
- ("scheduler", "statsd_custom_client_path"): "tests.core.test_stats.CustomStatsd",
- })
+ @conf_vars(
+ {
+ ("scheduler", "statsd_on"): "True",
+ ("scheduler", "statsd_custom_client_path"): "tests.core.test_stats.CustomStatsd",
+ }
+ )
def test_does_use_custom_statsd_client(self):
importlib.reload(airflow.stats)
airflow.stats.Stats.incr("dummy_key")
assert airflow.stats.Stats.statsd.incr_calls == 1
- @conf_vars({
- ("scheduler", "statsd_on"): "True",
- ("scheduler", "statsd_custom_client_path"): "tests.core.test_stats.InvalidCustomStatsd",
- })
+ @conf_vars(
+ {
+ ("scheduler", "statsd_on"): "True",
+ ("scheduler", "statsd_custom_client_path"): "tests.core.test_stats.InvalidCustomStatsd",
+ }
+ )
def test_load_invalid_custom_stats_client(self):
with self.assertRaisesRegex(
AirflowConfigException,
re.escape(
'Your custom Statsd client must extend the statsd.'
'StatsClient in order to ensure backwards compatibility.'
- )
+ ),
):
importlib.reload(airflow.stats)
@@ -140,7 +142,6 @@ def tearDown(self) -> None:
class TestDogStats(unittest.TestCase):
-
def setUp(self):
self.dogstatsd_client = Mock()
self.dogstatsd = SafeDogStatsdLogger(self.dogstatsd_client)
@@ -163,9 +164,7 @@ def test_stat_name_must_only_include_allowed_characters_with_dogstatsd(self):
self.dogstatsd.incr('test/$tats')
self.dogstatsd_client.assert_not_called()
- @conf_vars({
- ('scheduler', 'statsd_datadog_enabled'): 'True'
- })
+ @conf_vars({('scheduler', 'statsd_datadog_enabled'): 'True'})
@mock.patch("datadog.DogStatsd")
def test_does_send_stats_using_dogstatsd_when_dogstatsd_on(self, mock_dogstatsd):
importlib.reload(airflow.stats)
@@ -174,9 +173,7 @@ def test_does_send_stats_using_dogstatsd_when_dogstatsd_on(self, mock_dogstatsd)
metric='dummy_key', sample_rate=1, tags=[], value=1
)
- @conf_vars({
- ('scheduler', 'statsd_datadog_enabled'): 'True'
- })
+ @conf_vars({('scheduler', 'statsd_datadog_enabled'): 'True'})
@mock.patch("datadog.DogStatsd")
def test_does_send_stats_using_dogstatsd_with_tags(self, mock_dogstatsd):
importlib.reload(airflow.stats)
@@ -185,10 +182,7 @@ def test_does_send_stats_using_dogstatsd_with_tags(self, mock_dogstatsd):
metric='dummy_key', sample_rate=1, tags=['key1:value1', 'key2:value2'], value=1
)
- @conf_vars({
- ('scheduler', 'statsd_on'): 'True',
- ('scheduler', 'statsd_datadog_enabled'): 'True'
- })
+ @conf_vars({('scheduler', 'statsd_on'): 'True', ('scheduler', 'statsd_datadog_enabled'): 'True'})
@mock.patch("datadog.DogStatsd")
def test_does_send_stats_using_dogstatsd_when_statsd_and_dogstatsd_both_on(self, mock_dogstatsd):
importlib.reload(airflow.stats)
@@ -197,10 +191,7 @@ def test_does_send_stats_using_dogstatsd_when_statsd_and_dogstatsd_both_on(self,
metric='dummy_key', sample_rate=1, tags=[], value=1
)
- @conf_vars({
- ('scheduler', 'statsd_on'): 'True',
- ('scheduler', 'statsd_datadog_enabled'): 'True'
- })
+ @conf_vars({('scheduler', 'statsd_on'): 'True', ('scheduler', 'statsd_datadog_enabled'): 'True'})
@mock.patch("statsd.StatsClient")
def test_does_not_send_stats_using_statsd_when_statsd_and_dogstatsd_both_on(self, mock_statsd):
importlib.reload(airflow.stats)
@@ -213,7 +204,6 @@ def tearDown(self) -> None:
class TestStatsWithAllowList(unittest.TestCase):
-
def setUp(self):
self.statsd_client = Mock()
self.stats = SafeStatsdLogger(self.statsd_client, AllowListValidator("stats_one, stats_two"))
@@ -232,7 +222,6 @@ def test_not_increment_counter_if_not_allowed(self):
class TestDogStatsWithAllowList(unittest.TestCase):
-
def setUp(self):
self.dogstatsd_client = Mock()
self.dogstats = SafeDogStatsdLogger(self.dogstatsd_client, AllowListValidator("stats_one, stats_two"))
@@ -263,40 +252,48 @@ def always_valid(stat_name):
class TestCustomStatsName(unittest.TestCase):
- @conf_vars({
- ('scheduler', 'statsd_on'): 'True',
- ('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_invalid'
- })
+ @conf_vars(
+ {
+ ('scheduler', 'statsd_on'): 'True',
+ ('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_invalid',
+ }
+ )
@mock.patch("statsd.StatsClient")
def test_does_not_send_stats_using_statsd_when_the_name_is_not_valid(self, mock_statsd):
importlib.reload(airflow.stats)
airflow.stats.Stats.incr("dummy_key")
mock_statsd.return_value.assert_not_called()
- @conf_vars({
- ('scheduler', 'statsd_datadog_enabled'): 'True',
- ('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_invalid'
- })
+ @conf_vars(
+ {
+ ('scheduler', 'statsd_datadog_enabled'): 'True',
+ ('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_invalid',
+ }
+ )
@mock.patch("datadog.DogStatsd")
def test_does_not_send_stats_using_dogstatsd_when_the_name_is_not_valid(self, mock_dogstatsd):
importlib.reload(airflow.stats)
airflow.stats.Stats.incr("dummy_key")
mock_dogstatsd.return_value.assert_not_called()
- @conf_vars({
- ('scheduler', 'statsd_on'): 'True',
- ('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_valid'
- })
+ @conf_vars(
+ {
+ ('scheduler', 'statsd_on'): 'True',
+ ('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_valid',
+ }
+ )
@mock.patch("statsd.StatsClient")
def test_does_send_stats_using_statsd_when_the_name_is_valid(self, mock_statsd):
importlib.reload(airflow.stats)
airflow.stats.Stats.incr("dummy_key")
mock_statsd.return_value.incr.assert_called_once_with('dummy_key', 1, 1)
- @conf_vars({
- ('scheduler', 'statsd_datadog_enabled'): 'True',
- ('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_valid'
- })
+ @conf_vars(
+ {
+ ('scheduler', 'statsd_datadog_enabled'): 'True',
+ ('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_valid',
+ }
+ )
@mock.patch("datadog.DogStatsd")
def test_does_send_stats_using_dogstatsd_when_the_name_is_valid(self, mock_dogstatsd):
importlib.reload(airflow.stats)
diff --git a/tests/dags/subdir2/test_dont_ignore_this.py b/tests/dags/subdir2/test_dont_ignore_this.py
index d1894d620649e..6af1f19a239ec 100644
--- a/tests/dags/subdir2/test_dont_ignore_this.py
+++ b/tests/dags/subdir2/test_dont_ignore_this.py
@@ -23,7 +23,4 @@
DEFAULT_DATE = datetime(2019, 12, 1)
dag = DAG(dag_id='test_dag_under_subdir2', start_date=DEFAULT_DATE, schedule_interval=None)
-task = BashOperator(
- task_id='task1',
- bash_command='echo "test dag under sub directory subdir2"',
- dag=dag)
+task = BashOperator(task_id='task1', bash_command='echo "test dag under sub directory subdir2"', dag=dag)
diff --git a/tests/dags/test_backfill_pooled_tasks.py b/tests/dags/test_backfill_pooled_tasks.py
index ced09c67ecff6..02a68b31b5952 100644
--- a/tests/dags/test_backfill_pooled_tasks.py
+++ b/tests/dags/test_backfill_pooled_tasks.py
@@ -33,4 +33,5 @@
dag=dag,
pool='test_backfill_pooled_task_pool',
owner='airflow',
- start_date=datetime(2016, 2, 1))
+ start_date=datetime(2016, 2, 1),
+)
diff --git a/tests/dags/test_clear_subdag.py b/tests/dags/test_clear_subdag.py
index 1f11cdd7ed92a..4a59126f91a46 100644
--- a/tests/dags/test_clear_subdag.py
+++ b/tests/dags/test_clear_subdag.py
@@ -32,11 +32,7 @@ def create_subdag_opt(main_dag):
schedule_interval=None,
concurrency=2,
)
- BashOperator(
- bash_command="echo 1",
- task_id="daily_job_subdag_task",
- dag=subdag
- )
+ BashOperator(bash_command="echo 1", task_id="daily_job_subdag_task", dag=subdag)
return SubDagOperator(
task_id=subdag_name,
subdag=subdag,
@@ -48,12 +44,7 @@ def create_subdag_opt(main_dag):
start_date = datetime.datetime(2016, 1, 1)
-dag = DAG(
- dag_id=dag_name,
- concurrency=3,
- start_date=start_date,
- schedule_interval="0 0 * * *"
-)
+dag = DAG(dag_id=dag_name, concurrency=3, start_date=start_date, schedule_interval="0 0 * * *")
daily_job_irrelevant = BashOperator(
bash_command="echo 1",
diff --git a/tests/dags/test_cli_triggered_dags.py b/tests/dags/test_cli_triggered_dags.py
index 4facacb48a711..631a7317ace73 100644
--- a/tests/dags/test_cli_triggered_dags.py
+++ b/tests/dags/test_cli_triggered_dags.py
@@ -24,9 +24,7 @@
from airflow.utils.timezone import datetime
DEFAULT_DATE = datetime(2016, 1, 1)
-default_args = dict(
- start_date=DEFAULT_DATE,
- owner='airflow')
+default_args = dict(start_date=DEFAULT_DATE, owner='airflow')
def fail():
@@ -40,14 +38,9 @@ def success(ti=None, *args, **kwargs):
# DAG tests that tasks ignore all dependencies
-dag1 = DAG(dag_id='test_run_ignores_all_dependencies',
- default_args=dict(depends_on_past=True, **default_args))
-dag1_task1 = PythonOperator(
- task_id='test_run_dependency_task',
- python_callable=fail,
- dag=dag1)
-dag1_task2 = PythonOperator(
- task_id='test_run_dependent_task',
- python_callable=success,
- dag=dag1)
+dag1 = DAG(
+ dag_id='test_run_ignores_all_dependencies', default_args=dict(depends_on_past=True, **default_args)
+)
+dag1_task1 = PythonOperator(task_id='test_run_dependency_task', python_callable=fail, dag=dag1)
+dag1_task2 = PythonOperator(task_id='test_run_dependent_task', python_callable=success, dag=dag1)
dag1_task1.set_downstream(dag1_task2)
diff --git a/tests/dags/test_default_impersonation.py b/tests/dags/test_default_impersonation.py
index 5268287bcd762..4e379d3edea90 100644
--- a/tests/dags/test_default_impersonation.py
+++ b/tests/dags/test_default_impersonation.py
@@ -39,7 +39,10 @@
echo current user $(whoami) is not {user}!
exit 1
fi
- """.format(user=deelevated_user))
+ """.format(
+ user=deelevated_user
+ )
+)
task = BashOperator(
task_id='test_deelevated_user',
diff --git a/tests/dags/test_default_views.py b/tests/dags/test_default_views.py
index ab1b19227649a..f30177f356f5c 100644
--- a/tests/dags/test_default_views.py
+++ b/tests/dags/test_default_views.py
@@ -18,20 +18,18 @@
from airflow.models import DAG
from airflow.utils.dates import days_ago
-args = {
- 'owner': 'airflow',
- 'retries': 3,
- 'start_date': days_ago(2)
-}
+args = {'owner': 'airflow', 'retries': 3, 'start_date': days_ago(2)}
tree_dag = DAG(
- dag_id='test_tree_view', default_args=args,
+ dag_id='test_tree_view',
+ default_args=args,
schedule_interval='0 0 * * *',
default_view='tree',
)
graph_dag = DAG(
- dag_id='test_graph_view', default_args=args,
+ dag_id='test_graph_view',
+ default_args=args,
schedule_interval='0 0 * * *',
default_view='graph',
)
diff --git a/tests/dags/test_double_trigger.py b/tests/dags/test_double_trigger.py
index e1c03441afef5..9c8aae11c5150 100644
--- a/tests/dags/test_double_trigger.py
+++ b/tests/dags/test_double_trigger.py
@@ -28,6 +28,4 @@
}
dag = DAG(dag_id='test_localtaskjob_double_trigger', default_args=args)
-task = DummyOperator(
- task_id='test_localtaskjob_double_trigger_task',
- dag=dag)
+task = DummyOperator(task_id='test_localtaskjob_double_trigger_task', dag=dag)
diff --git a/tests/dags/test_example_bash_operator.py b/tests/dags/test_example_bash_operator.py
index 4d2dfdb24fb6f..1e28f664dcc4f 100644
--- a/tests/dags/test_example_bash_operator.py
+++ b/tests/dags/test_example_bash_operator.py
@@ -22,35 +22,30 @@
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.dates import days_ago
-args = {
- 'owner': 'airflow',
- 'retries': 3,
- 'start_date': days_ago(2)
-}
+args = {'owner': 'airflow', 'retries': 3, 'start_date': days_ago(2)}
dag = DAG(
- dag_id='test_example_bash_operator', default_args=args,
+ dag_id='test_example_bash_operator',
+ default_args=args,
schedule_interval='0 0 * * *',
- dagrun_timeout=timedelta(minutes=60))
+ dagrun_timeout=timedelta(minutes=60),
+)
cmd = 'ls -l'
run_this_last = DummyOperator(task_id='run_this_last', dag=dag)
-run_this = BashOperator(
- task_id='run_after_loop', bash_command='echo 1', dag=dag)
+run_this = BashOperator(task_id='run_after_loop', bash_command='echo 1', dag=dag)
run_this.set_downstream(run_this_last)
for i in range(3):
task = BashOperator(
- task_id='runme_' + str(i),
- bash_command='echo "{{ task_instance_key_str }}" && sleep 1',
- dag=dag)
+ task_id='runme_' + str(i), bash_command='echo "{{ task_instance_key_str }}" && sleep 1', dag=dag
+ )
task.set_downstream(run_this)
task = BashOperator(
- task_id='also_run_this',
- bash_command='echo "run_id={{ run_id }} | dag_run={{ dag_run }}"',
- dag=dag)
+ task_id='also_run_this', bash_command='echo "run_id={{ run_id }} | dag_run={{ dag_run }}"', dag=dag
+)
task.set_downstream(run_this_last)
if __name__ == "__main__":
diff --git a/tests/dags/test_heartbeat_failed_fast.py b/tests/dags/test_heartbeat_failed_fast.py
index 5c74d49a578d6..01d5d6be88fd2 100644
--- a/tests/dags/test_heartbeat_failed_fast.py
+++ b/tests/dags/test_heartbeat_failed_fast.py
@@ -28,7 +28,4 @@
}
dag = DAG(dag_id='test_heartbeat_failed_fast', default_args=args)
-task = BashOperator(
- task_id='test_heartbeat_failed_fast_op',
- bash_command='sleep 7',
- dag=dag)
+task = BashOperator(task_id='test_heartbeat_failed_fast_op', bash_command='sleep 7', dag=dag)
diff --git a/tests/dags/test_impersonation.py b/tests/dags/test_impersonation.py
index d58f9cae5ee97..e72a9e11eb05a 100644
--- a/tests/dags/test_impersonation.py
+++ b/tests/dags/test_impersonation.py
@@ -39,7 +39,10 @@
echo current user is not {user}!
exit 1
fi
- """.format(user=run_as_user))
+ """.format(
+ user=run_as_user
+ )
+)
task = BashOperator(
task_id='test_impersonated_user',
diff --git a/tests/dags/test_impersonation_subdag.py b/tests/dags/test_impersonation_subdag.py
index a55d5227e749a..3ea52002b2ba8 100644
--- a/tests/dags/test_impersonation_subdag.py
+++ b/tests/dags/test_impersonation_subdag.py
@@ -25,11 +25,7 @@
DEFAULT_DATE = datetime(2016, 1, 1)
-default_args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE,
- 'run_as_user': 'airflow_test_user'
-}
+default_args = {'owner': 'airflow', 'start_date': DEFAULT_DATE, 'run_as_user': 'airflow_test_user'}
dag = DAG(dag_id='impersonation_subdag', default_args=default_args)
@@ -38,25 +34,15 @@ def print_today():
print(f'Today is {datetime.utcnow()}')
-subdag = DAG('impersonation_subdag.test_subdag_operation',
- default_args=default_args)
+subdag = DAG('impersonation_subdag.test_subdag_operation', default_args=default_args)
-PythonOperator(
- python_callable=print_today,
- task_id='exec_python_fn',
- dag=subdag)
+PythonOperator(python_callable=print_today, task_id='exec_python_fn', dag=subdag)
-BashOperator(
- task_id='exec_bash_operator',
- bash_command='echo "Running within SubDag"',
- dag=subdag
-)
+BashOperator(task_id='exec_bash_operator', bash_command='echo "Running within SubDag"', dag=subdag)
-subdag_operator = SubDagOperator(task_id='test_subdag_operation',
- subdag=subdag,
- mode='reschedule',
- poke_interval=1,
- dag=dag)
+subdag_operator = SubDagOperator(
+ task_id='test_subdag_operation', subdag=subdag, mode='reschedule', poke_interval=1, dag=dag
+)
diff --git a/tests/dags/test_invalid_cron.py b/tests/dags/test_invalid_cron.py
index 60e9f822c18a3..270abd377a452 100644
--- a/tests/dags/test_invalid_cron.py
+++ b/tests/dags/test_invalid_cron.py
@@ -24,11 +24,5 @@
# Cron expression. This invalid DAG will be used to
# test whether dagbag.process_file() can identify
# invalid Cron expression.
-dag1 = DAG(
- dag_id='test_invalid_cron',
- start_date=datetime(2015, 1, 1),
- schedule_interval="0 100 * * *")
-dag1_task1 = DummyOperator(
- task_id='task1',
- dag=dag1,
- owner='airflow')
+dag1 = DAG(dag_id='test_invalid_cron', start_date=datetime(2015, 1, 1), schedule_interval="0 100 * * *")
+dag1_task1 = DummyOperator(task_id='task1', dag=dag1, owner='airflow')
diff --git a/tests/dags/test_issue_1225.py b/tests/dags/test_issue_1225.py
index b962f23c112ba..d0daf9dc4ee89 100644
--- a/tests/dags/test_issue_1225.py
+++ b/tests/dags/test_issue_1225.py
@@ -30,9 +30,7 @@
from airflow.utils.trigger_rule import TriggerRule
DEFAULT_DATE = datetime(2016, 1, 1)
-default_args = dict(
- start_date=DEFAULT_DATE,
- owner='airflow')
+default_args = dict(start_date=DEFAULT_DATE, owner='airflow')
def fail():
@@ -45,19 +43,18 @@ def fail():
dag1_task1 = DummyOperator(
task_id='test_backfill_pooled_task',
dag=dag1,
- pool='test_backfill_pooled_task_pool',)
+ pool='test_backfill_pooled_task_pool',
+)
# dag2 has been moved to test_prev_dagrun_dep.py
# DAG tests that a Dag run that doesn't complete is marked failed
dag3 = DAG(dag_id='test_dagrun_states_fail', default_args=default_args)
-dag3_task1 = PythonOperator(
- task_id='test_dagrun_fail',
- dag=dag3,
- python_callable=fail)
+dag3_task1 = PythonOperator(task_id='test_dagrun_fail', dag=dag3, python_callable=fail)
dag3_task2 = DummyOperator(
task_id='test_dagrun_succeed',
- dag=dag3,)
+ dag=dag3,
+)
dag3_task2.set_upstream(dag3_task1)
# DAG tests that a Dag run that completes but has a failure is marked success
@@ -67,11 +64,7 @@ def fail():
dag=dag4,
python_callable=fail,
)
-dag4_task2 = DummyOperator(
- task_id='test_dagrun_succeed',
- dag=dag4,
- trigger_rule=TriggerRule.ALL_FAILED
-)
+dag4_task2 = DummyOperator(task_id='test_dagrun_succeed', dag=dag4, trigger_rule=TriggerRule.ALL_FAILED)
dag4_task2.set_upstream(dag4_task1)
# DAG tests that a Dag run that completes but has a root failure is marked fail
@@ -91,11 +84,13 @@ def fail():
dag6_task1 = DummyOperator(
task_id='test_depends_on_past',
depends_on_past=True,
- dag=dag6,)
+ dag=dag6,
+)
dag6_task2 = DummyOperator(
task_id='test_depends_on_past_2',
depends_on_past=True,
- dag=dag6,)
+ dag=dag6,
+)
dag6_task2.set_upstream(dag6_task1)
@@ -103,7 +98,7 @@ def fail():
dag8 = DAG(dag_id='test_dagrun_states_root_fail_unfinished', default_args=default_args)
dag8_task1 = DummyOperator(
task_id='test_dagrun_unfinished', # The test will unset the task instance state after
- # running this test
+ # running this test
dag=dag8,
)
dag8_task2 = PythonOperator(
diff --git a/tests/dags/test_latest_runs.py b/tests/dags/test_latest_runs.py
index a4ea8eb8afaea..9bd15688a49e1 100644
--- a/tests/dags/test_latest_runs.py
+++ b/tests/dags/test_latest_runs.py
@@ -24,8 +24,4 @@
for i in range(1, 2):
dag = DAG(dag_id=f'test_latest_runs_{i}')
- task = DummyOperator(
- task_id='dummy_task',
- dag=dag,
- owner='airflow',
- start_date=datetime(2016, 2, 1))
+ task = DummyOperator(task_id='dummy_task', dag=dag, owner='airflow', start_date=datetime(2016, 2, 1))
diff --git a/tests/dags/test_logging_in_dag.py b/tests/dags/test_logging_in_dag.py
index 6d828fa631d4a..8b0c629de750a 100644
--- a/tests/dags/test_logging_in_dag.py
+++ b/tests/dags/test_logging_in_dag.py
@@ -35,11 +35,7 @@ def test_logging_fn(**kwargs):
print("Log from Print statement")
-dag = DAG(
- dag_id='test_logging_dag',
- schedule_interval=None,
- start_date=datetime(2016, 1, 1)
-)
+dag = DAG(dag_id='test_logging_dag', schedule_interval=None, start_date=datetime(2016, 1, 1))
PythonOperator(
task_id='test_task',
diff --git a/tests/dags/test_mark_success.py b/tests/dags/test_mark_success.py
index d731026b26f54..759b39f33457b 100644
--- a/tests/dags/test_mark_success.py
+++ b/tests/dags/test_mark_success.py
@@ -31,7 +31,5 @@
dag = DAG(dag_id='test_mark_success', default_args=args)
task = PythonOperator(
- task_id='task1',
- python_callable=lambda x: sleep(x), # pylint: disable=W0108
- op_args=[600],
- dag=dag)
+ task_id='task1', python_callable=lambda x: sleep(x), op_args=[600], dag=dag # pylint: disable=W0108
+)
diff --git a/tests/dags/test_missing_owner.py b/tests/dags/test_missing_owner.py
index 16f715c061b96..ef81f98859f36 100644
--- a/tests/dags/test_missing_owner.py
+++ b/tests/dags/test_missing_owner.py
@@ -29,4 +29,6 @@
dagrun_timeout=timedelta(minutes=60),
tags=["example"],
) as dag:
- run_this_last = DummyOperator(task_id="test_task",)
+ run_this_last = DummyOperator(
+ task_id="test_task",
+ )
diff --git a/tests/dags/test_multiple_dags.py b/tests/dags/test_multiple_dags.py
index 0f7132615883b..44aa5cdfc2c5b 100644
--- a/tests/dags/test_multiple_dags.py
+++ b/tests/dags/test_multiple_dags.py
@@ -21,11 +21,7 @@
from airflow.operators.bash import BashOperator
from airflow.utils.dates import days_ago
-args = {
- 'owner': 'airflow',
- 'retries': 3,
- 'start_date': days_ago(2)
-}
+args = {'owner': 'airflow', 'retries': 3, 'start_date': days_ago(2)}
def create_dag(suffix):
@@ -33,7 +29,7 @@ def create_dag(suffix):
dag_id=f'test_multiple_dags__{suffix}',
default_args=args,
schedule_interval='0 0 * * *',
- dagrun_timeout=timedelta(minutes=60)
+ dagrun_timeout=timedelta(minutes=60),
)
with dag:
diff --git a/tests/dags/test_no_impersonation.py b/tests/dags/test_no_impersonation.py
index e068690ae8ac2..c5c9cea0cdafa 100644
--- a/tests/dags/test_no_impersonation.py
+++ b/tests/dags/test_no_impersonation.py
@@ -38,7 +38,8 @@
echo 'current uid does not have root privileges!'
exit 1
fi
- """)
+ """
+)
task = BashOperator(
task_id='test_superuser',
diff --git a/tests/dags/test_on_failure_callback.py b/tests/dags/test_on_failure_callback.py
index 28ea084fd1872..1daddba6da2a4 100644
--- a/tests/dags/test_on_failure_callback.py
+++ b/tests/dags/test_on_failure_callback.py
@@ -37,7 +37,5 @@ def write_data_to_callback(*arg, **kwargs): # pylint: disable=unused-argument
task = DummyOperator(
- task_id='test_om_failure_callback_task',
- dag=dag,
- on_failure_callback=write_data_to_callback
+ task_id='test_om_failure_callback_task', dag=dag, on_failure_callback=write_data_to_callback
)
diff --git a/tests/dags/test_on_kill.py b/tests/dags/test_on_kill.py
index ec172689dea6b..cb75edf349124 100644
--- a/tests/dags/test_on_kill.py
+++ b/tests/dags/test_on_kill.py
@@ -40,10 +40,5 @@ def on_kill(self):
# DAG tests backfill with pooled tasks
# Previously backfill would queue the task but never run it
-dag1 = DAG(
- dag_id='test_on_kill',
- start_date=datetime(2015, 1, 1))
-dag1_task1 = DummyWithOnKill(
- task_id='task1',
- dag=dag1,
- owner='airflow')
+dag1 = DAG(dag_id='test_on_kill', start_date=datetime(2015, 1, 1))
+dag1_task1 = DummyWithOnKill(task_id='task1', dag=dag1, owner='airflow')
diff --git a/tests/dags/test_prev_dagrun_dep.py b/tests/dags/test_prev_dagrun_dep.py
index b89e99626763a..6103c0e5c22b5 100644
--- a/tests/dags/test_prev_dagrun_dep.py
+++ b/tests/dags/test_prev_dagrun_dep.py
@@ -27,11 +27,19 @@
# DAG tests depends_on_past dependencies
dag_dop = DAG(dag_id="test_depends_on_past", default_args=default_args)
with dag_dop:
- dag_dop_task = DummyOperator(task_id="test_dop_task", depends_on_past=True,)
+ dag_dop_task = DummyOperator(
+ task_id="test_dop_task",
+ depends_on_past=True,
+ )
# DAG tests wait_for_downstream dependencies
dag_wfd = DAG(dag_id="test_wait_for_downstream", default_args=default_args)
with dag_wfd:
- dag_wfd_upstream = DummyOperator(task_id="upstream_task", wait_for_downstream=True,)
- dag_wfd_downstream = DummyOperator(task_id="downstream_task",)
+ dag_wfd_upstream = DummyOperator(
+ task_id="upstream_task",
+ wait_for_downstream=True,
+ )
+ dag_wfd_downstream = DummyOperator(
+ task_id="downstream_task",
+ )
dag_wfd_upstream >> dag_wfd_downstream
diff --git a/tests/dags/test_retry_handling_job.py b/tests/dags/test_retry_handling_job.py
index 894cfc9430711..42bf9a47331e1 100644
--- a/tests/dags/test_retry_handling_job.py
+++ b/tests/dags/test_retry_handling_job.py
@@ -34,7 +34,4 @@
dag = DAG('test_retry_handling_job', default_args=default_args, schedule_interval='@once')
-task1 = BashOperator(
- task_id='test_retry_handling_op',
- bash_command='exit 1',
- dag=dag)
+task1 = BashOperator(task_id='test_retry_handling_op', bash_command='exit 1', dag=dag)
diff --git a/tests/dags/test_scheduler_dags.py b/tests/dags/test_scheduler_dags.py
index bc47fbcf1abd1..35a9bd06a7c7b 100644
--- a/tests/dags/test_scheduler_dags.py
+++ b/tests/dags/test_scheduler_dags.py
@@ -25,26 +25,11 @@
# DAG tests backfill with pooled tasks
# Previously backfill would queue the task but never run it
-dag1 = DAG(
- dag_id='test_start_date_scheduling',
- start_date=datetime.utcnow() + timedelta(days=1))
-dag1_task1 = DummyOperator(
- task_id='dummy',
- dag=dag1,
- owner='airflow')
+dag1 = DAG(dag_id='test_start_date_scheduling', start_date=datetime.utcnow() + timedelta(days=1))
+dag1_task1 = DummyOperator(task_id='dummy', dag=dag1, owner='airflow')
-dag2 = DAG(
- dag_id='test_task_start_date_scheduling',
- start_date=DEFAULT_DATE
-)
+dag2 = DAG(dag_id='test_task_start_date_scheduling', start_date=DEFAULT_DATE)
dag2_task1 = DummyOperator(
- task_id='dummy1',
- dag=dag2,
- owner='airflow',
- start_date=DEFAULT_DATE + timedelta(days=3)
-)
-dag2_task2 = DummyOperator(
- task_id='dummy2',
- dag=dag2,
- owner='airflow'
+ task_id='dummy1', dag=dag2, owner='airflow', start_date=DEFAULT_DATE + timedelta(days=3)
)
+dag2_task2 = DummyOperator(task_id='dummy2', dag=dag2, owner='airflow')
diff --git a/tests/dags/test_task_view_type_check.py b/tests/dags/test_task_view_type_check.py
index 6f4585ba0258a..5003cadefe137 100644
--- a/tests/dags/test_task_view_type_check.py
+++ b/tests/dags/test_task_view_type_check.py
@@ -28,15 +28,14 @@
from airflow.operators.python import PythonOperator
DEFAULT_DATE = datetime(2016, 1, 1)
-default_args = dict(
- start_date=DEFAULT_DATE,
- owner='airflow')
+default_args = dict(start_date=DEFAULT_DATE, owner='airflow')
class CallableClass:
"""
Class that is callable.
"""
+
def __call__(self):
"""A __call__ method """
diff --git a/tests/dags/test_with_non_default_owner.py b/tests/dags/test_with_non_default_owner.py
index eebbb64c621bb..3e004d38fe5d2 100644
--- a/tests/dags/test_with_non_default_owner.py
+++ b/tests/dags/test_with_non_default_owner.py
@@ -29,4 +29,7 @@
dagrun_timeout=timedelta(minutes=60),
tags=["example"],
) as dag:
- run_this_last = DummyOperator(task_id="test_task", owner="John",)
+ run_this_last = DummyOperator(
+ task_id="test_task",
+ owner="John",
+ )
diff --git a/tests/dags_corrupted/test_impersonation_custom.py b/tests/dags_corrupted/test_impersonation_custom.py
index f9deb45ae0f97..77ea1ed59951c 100644
--- a/tests/dags_corrupted/test_impersonation_custom.py
+++ b/tests/dags_corrupted/test_impersonation_custom.py
@@ -32,11 +32,7 @@
DEFAULT_DATE = datetime(2016, 1, 1)
-args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE,
- 'run_as_user': 'airflow_test_user'
-}
+args = {'owner': 'airflow', 'start_date': DEFAULT_DATE, 'run_as_user': 'airflow_test_user'}
dag = DAG(dag_id='impersonation_with_custom_pkg', default_args=args)
@@ -48,15 +44,10 @@ def print_today():
def check_hive_conf():
from airflow.configuration import conf
+
assert conf.get('hive', 'default_hive_mapred_queue') == 'airflow'
-PythonOperator(
- python_callable=print_today,
- task_id='exec_python_fn',
- dag=dag)
+PythonOperator(python_callable=print_today, task_id='exec_python_fn', dag=dag)
-PythonOperator(
- python_callable=check_hive_conf,
- task_id='exec_check_hive_conf_fn',
- dag=dag)
+PythonOperator(python_callable=check_hive_conf, task_id='exec_check_hive_conf_fn', dag=dag)
diff --git a/tests/dags_with_system_exit/a_system_exit.py b/tests/dags_with_system_exit/a_system_exit.py
index 530fd8fc500e0..fd0a86f2dc530 100644
--- a/tests/dags_with_system_exit/a_system_exit.py
+++ b/tests/dags_with_system_exit/a_system_exit.py
@@ -26,8 +26,6 @@
DEFAULT_DATE = datetime(2100, 1, 1)
-dag1 = DAG(
- dag_id='test_system_exit',
- start_date=DEFAULT_DATE)
+dag1 = DAG(dag_id='test_system_exit', start_date=DEFAULT_DATE)
sys.exit(-1)
diff --git a/tests/dags_with_system_exit/b_test_scheduler_dags.py b/tests/dags_with_system_exit/b_test_scheduler_dags.py
index ff70b374c40fc..9fabdb526f0ca 100644
--- a/tests/dags_with_system_exit/b_test_scheduler_dags.py
+++ b/tests/dags_with_system_exit/b_test_scheduler_dags.py
@@ -23,11 +23,6 @@
DEFAULT_DATE = datetime(2000, 1, 1)
-dag1 = DAG(
- dag_id='exit_test_dag',
- start_date=DEFAULT_DATE)
+dag1 = DAG(dag_id='exit_test_dag', start_date=DEFAULT_DATE)
-dag1_task1 = DummyOperator(
- task_id='dummy',
- dag=dag1,
- owner='airflow')
+dag1_task1 = DummyOperator(task_id='dummy', dag=dag1, owner='airflow')
diff --git a/tests/dags_with_system_exit/c_system_exit.py b/tests/dags_with_system_exit/c_system_exit.py
index 1f919dd34c836..9222bc699ad5c 100644
--- a/tests/dags_with_system_exit/c_system_exit.py
+++ b/tests/dags_with_system_exit/c_system_exit.py
@@ -26,8 +26,6 @@
DEFAULT_DATE = datetime(2100, 1, 1)
-dag1 = DAG(
- dag_id='test_system_exit',
- start_date=DEFAULT_DATE)
+dag1 = DAG(dag_id='test_system_exit', start_date=DEFAULT_DATE)
sys.exit(-1)
diff --git a/tests/deprecated_classes.py b/tests/deprecated_classes.py
index 170b0c75503b0..d94774b2715ca 100644
--- a/tests/deprecated_classes.py
+++ b/tests/deprecated_classes.py
@@ -250,7 +250,6 @@
(
'airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook',
'airflow.contrib.hooks.sagemaker_hook.SageMakerHook',
-
),
(
'airflow.providers.mongo.hooks.mongo.MongoHook',
@@ -466,8 +465,7 @@
(
"airflow.providers.google.cloud.operators.compute"
".ComputeEngineInstanceGroupUpdateManagerTemplateOperator",
- "airflow.contrib.operators.gcp_compute_operator."
- "GceInstanceGroupManagerUpdateTemplateOperator",
+ "airflow.contrib.operators.gcp_compute_operator.GceInstanceGroupManagerUpdateTemplateOperator",
),
(
"airflow.providers.google.cloud.operators.compute.ComputeEngineStartInstanceOperator",
@@ -628,8 +626,7 @@
(
"airflow.providers.google.cloud.operators.natural_language."
"CloudNaturalLanguageAnalyzeEntitiesOperator",
- "airflow.contrib.operators.gcp_natural_language_operator."
- "CloudLanguageAnalyzeEntitiesOperator",
+ "airflow.contrib.operators.gcp_natural_language_operator.CloudLanguageAnalyzeEntitiesOperator",
),
(
"airflow.providers.google.cloud.operators.natural_language."
@@ -640,8 +637,7 @@
(
"airflow.providers.google.cloud.operators.natural_language."
"CloudNaturalLanguageAnalyzeSentimentOperator",
- "airflow.contrib.operators.gcp_natural_language_operator."
- "CloudLanguageAnalyzeSentimentOperator",
+ "airflow.contrib.operators.gcp_natural_language_operator.CloudLanguageAnalyzeSentimentOperator",
),
(
"airflow.providers.google.cloud.operators.natural_language."
@@ -690,32 +686,27 @@
(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service."
"CloudDataTransferServiceCancelOperationOperator",
- "airflow.contrib.operators.gcp_transfer_operator."
- "GcpTransferServiceOperationCancelOperator",
+ "airflow.contrib.operators.gcp_transfer_operator.GcpTransferServiceOperationCancelOperator",
),
(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service."
"CloudDataTransferServiceGetOperationOperator",
- "airflow.contrib.operators.gcp_transfer_operator."
- "GcpTransferServiceOperationGetOperator",
+ "airflow.contrib.operators.gcp_transfer_operator.GcpTransferServiceOperationGetOperator",
),
(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service."
"CloudDataTransferServicePauseOperationOperator",
- "airflow.contrib.operators.gcp_transfer_operator."
- "GcpTransferServiceOperationPauseOperator",
+ "airflow.contrib.operators.gcp_transfer_operator.GcpTransferServiceOperationPauseOperator",
),
(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service."
"CloudDataTransferServiceResumeOperationOperator",
- "airflow.contrib.operators.gcp_transfer_operator."
- "GcpTransferServiceOperationResumeOperator",
+ "airflow.contrib.operators.gcp_transfer_operator.GcpTransferServiceOperationResumeOperator",
),
(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service."
"CloudDataTransferServiceListOperationsOperator",
- "airflow.contrib.operators.gcp_transfer_operator."
- "GcpTransferServiceOperationsListOperator",
+ "airflow.contrib.operators.gcp_transfer_operator.GcpTransferServiceOperationsListOperator",
),
(
"airflow.providers.google.cloud.operators.translate.CloudTranslateTextOperator",
@@ -801,8 +792,7 @@
),
(
"airflow.providers.google.cloud.operators.vision.CloudVisionRemoveProductFromProductSetOperator",
- "airflow.contrib.operators.gcp_vision_operator."
- "CloudVisionRemoveProductFromProductSetOperator",
+ "airflow.contrib.operators.gcp_vision_operator.CloudVisionRemoveProductFromProductSetOperator",
),
(
"airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator",
@@ -841,66 +831,53 @@
"airflow.contrib.operators.pubsub_operator.PubSubTopicDeleteOperator",
),
(
- "airflow.providers.google.cloud."
- "operators.dataproc.DataprocCreateClusterOperator",
+ "airflow.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperator",
"airflow.contrib.operators.dataproc_operator.DataprocClusterCreateOperator",
),
(
- "airflow.providers.google.cloud."
- "operators.dataproc.DataprocDeleteClusterOperator",
+ "airflow.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperator",
"airflow.contrib.operators.dataproc_operator.DataprocClusterDeleteOperator",
),
(
- "airflow.providers.google.cloud."
- "operators.dataproc.DataprocScaleClusterOperator",
+ "airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator",
"airflow.contrib.operators.dataproc_operator.DataprocClusterScaleOperator",
),
(
- "airflow.providers.google.cloud."
- "operators.dataproc.DataprocSubmitHadoopJobOperator",
+ "airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHadoopJobOperator",
"airflow.contrib.operators.dataproc_operator.DataProcHadoopOperator",
),
(
- "airflow.providers.google.cloud."
- "operators.dataproc.DataprocSubmitHiveJobOperator",
+ "airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHiveJobOperator",
"airflow.contrib.operators.dataproc_operator.DataProcHiveOperator",
),
(
- "airflow.providers.google.cloud."
- "operators.dataproc.DataprocJobBaseOperator",
+ "airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator",
"airflow.contrib.operators.dataproc_operator.DataProcJobBaseOperator",
),
(
- "airflow.providers.google.cloud."
- "operators.dataproc.DataprocSubmitPigJobOperator",
+ "airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPigJobOperator",
"airflow.contrib.operators.dataproc_operator.DataProcPigOperator",
),
(
- "airflow.providers.google.cloud."
- "operators.dataproc.DataprocSubmitPySparkJobOperator",
+ "airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPySparkJobOperator",
"airflow.contrib.operators.dataproc_operator.DataProcPySparkOperator",
),
(
- "airflow.providers.google.cloud."
- "operators.dataproc.DataprocSubmitSparkJobOperator",
+ "airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkJobOperator",
"airflow.contrib.operators.dataproc_operator.DataProcSparkOperator",
),
(
- "airflow.providers.google.cloud."
- "operators.dataproc.DataprocSubmitSparkSqlJobOperator",
+ "airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkSqlJobOperator",
"airflow.contrib.operators.dataproc_operator.DataProcSparkSqlOperator",
),
(
"airflow.providers.google.cloud."
"operators.dataproc.DataprocInstantiateInlineWorkflowTemplateOperator",
- "airflow.contrib.operators.dataproc_operator."
- "DataprocWorkflowTemplateInstantiateInlineOperator",
+ "airflow.contrib.operators.dataproc_operator.DataprocWorkflowTemplateInstantiateInlineOperator",
),
(
- "airflow.providers.google.cloud."
- "operators.dataproc.DataprocInstantiateWorkflowTemplateOperator",
- "airflow.contrib.operators.dataproc_operator."
- "DataprocWorkflowTemplateInstantiateOperator",
+ "airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateWorkflowTemplateOperator",
+ "airflow.contrib.operators.dataproc_operator.DataprocWorkflowTemplateInstantiateOperator",
),
(
"airflow.providers.google.cloud.operators.bigquery.BigQueryCreateEmptyDatasetOperator",
@@ -1285,43 +1262,43 @@
),
(
'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLBaseOperator',
- 'airflow.contrib.operators.gcp_sql_operator.CloudSqlBaseOperator'
+ 'airflow.contrib.operators.gcp_sql_operator.CloudSqlBaseOperator',
),
(
'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLCreateInstanceDatabaseOperator',
- 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDatabaseCreateOperator'
+ 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDatabaseCreateOperator',
),
(
'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLCreateInstanceOperator',
- 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceCreateOperator'
+ 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceCreateOperator',
),
(
'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLDeleteInstanceDatabaseOperator',
- 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDatabaseDeleteOperator'
+ 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDatabaseDeleteOperator',
),
(
'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLDeleteInstanceOperator',
- 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDeleteOperator'
+ 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDeleteOperator',
),
(
'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLExecuteQueryOperator',
- 'airflow.contrib.operators.gcp_sql_operator.CloudSqlQueryOperator'
+ 'airflow.contrib.operators.gcp_sql_operator.CloudSqlQueryOperator',
),
(
'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLExportInstanceOperator',
- 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceExportOperator'
+ 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceExportOperator',
),
(
'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLImportInstanceOperator',
- 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceImportOperator'
+ 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceImportOperator',
),
(
'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLInstancePatchOperator',
- 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstancePatchOperator'
+ 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstancePatchOperator',
),
(
'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLPatchInstanceDatabaseOperator',
- 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDatabasePatchOperator'
+ 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDatabasePatchOperator',
),
(
'airflow.providers.jira.operators.jira.JiraOperator',
@@ -1379,14 +1356,12 @@
),
(
"airflow.providers.google.cloud.sensors.bigtable.BigtableTableReplicationCompletedSensor",
- "airflow.contrib.operators.gcp_bigtable_operator."
- "BigtableTableWaitForReplicationSensor",
+ "airflow.contrib.operators.gcp_bigtable_operator.BigtableTableWaitForReplicationSensor",
),
(
"airflow.providers.google.cloud.sensors.cloud_storage_transfer_service."
"CloudDataTransferServiceJobStatusSensor",
- "airflow.contrib.sensors.gcp_transfer_sensor."
- "GCPTransferServiceWaitForJobStatusSensor",
+ "airflow.contrib.sensors.gcp_transfer_sensor.GCPTransferServiceWaitForJobStatusSensor",
),
(
"airflow.providers.google.cloud.sensors.pubsub.PubSubPullSensor",
@@ -1575,7 +1550,7 @@
(
'airflow.providers.sftp.sensors.sftp.SFTPSensor',
'airflow.contrib.sensors.sftp_sensor.SFTPSensor',
- )
+ ),
]
TRANSFERS = [
@@ -1619,8 +1594,7 @@
),
(
"airflow.providers.google.cloud.transfers.postgres_to_gcs.PostgresToGCSOperator",
- "airflow.contrib.operators.postgres_to_gcs_operator."
- "PostgresToGoogleCloudStorageOperator",
+ "airflow.contrib.operators.postgres_to_gcs_operator.PostgresToGoogleCloudStorageOperator",
),
(
"airflow.providers.google.cloud.transfers.bigquery_to_bigquery.BigQueryToBigQueryOperator",
@@ -1738,7 +1712,7 @@
(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service'
'.CloudDataTransferServiceS3ToGCSOperator',
- 'airflow.contrib.operators.s3_to_gcs_transfer_operator.CloudDataTransferServiceS3ToGCSOperator'
+ 'airflow.contrib.operators.s3_to_gcs_transfer_operator.CloudDataTransferServiceS3ToGCSOperator',
),
(
'airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator',
@@ -1754,34 +1728,34 @@
(
'airflow.utils.weekday.WeekDay',
'airflow.contrib.utils.weekday.WeekDay',
- )
+ ),
]
LOGS = [
(
"airflow.providers.amazon.aws.log.s3_task_handler.S3TaskHandler",
- "airflow.utils.log.s3_task_handler.S3TaskHandler"
+ "airflow.utils.log.s3_task_handler.S3TaskHandler",
),
(
'airflow.providers.amazon.aws.log.cloudwatch_task_handler.CloudwatchTaskHandler',
- 'airflow.utils.log.cloudwatch_task_handler.CloudwatchTaskHandler'
+ 'airflow.utils.log.cloudwatch_task_handler.CloudwatchTaskHandler',
),
(
'airflow.providers.elasticsearch.log.es_task_handler.ElasticsearchTaskHandler',
- 'airflow.utils.log.es_task_handler.ElasticsearchTaskHandler'
+ 'airflow.utils.log.es_task_handler.ElasticsearchTaskHandler',
),
(
"airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverTaskHandler",
- "airflow.utils.log.stackdriver_task_handler.StackdriverTaskHandler"
+ "airflow.utils.log.stackdriver_task_handler.StackdriverTaskHandler",
),
(
"airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler",
- "airflow.utils.log.gcs_task_handler.GCSTaskHandler"
+ "airflow.utils.log.gcs_task_handler.GCSTaskHandler",
),
(
"airflow.providers.microsoft.azure.log.wasb_task_handler.WasbTaskHandler",
- "airflow.utils.log.wasb_task_handler.WasbTaskHandler"
- )
+ "airflow.utils.log.wasb_task_handler.WasbTaskHandler",
+ ),
]
ALL = HOOKS + OPERATORS + SECRETS + SENSORS + TRANSFERS + UTILS + LOGS
diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py
index f324fa753b21f..06fe03067c84f 100644
--- a/tests/executors/test_base_executor.py
+++ b/tests/executors/test_base_executor.py
@@ -51,9 +51,11 @@ def test_get_event_buffer(self):
def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync):
executor = BaseExecutor()
executor.heartbeat()
- calls = [mock.call('executor.open_slots', mock.ANY),
- mock.call('executor.queued_tasks', mock.ANY),
- mock.call('executor.running_tasks', mock.ANY)]
+ calls = [
+ mock.call('executor.open_slots', mock.ANY),
+ mock.call('executor.queued_tasks', mock.ANY),
+ mock.call('executor.running_tasks', mock.ANY),
+ ]
mock_stats_gauge.assert_has_calls(calls)
def test_try_adopt_task_instances(self):
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index feb245ef13a85..1701fed4ce736 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -49,10 +49,7 @@
def _prepare_test_bodies():
if 'CELERY_BROKER_URLS' in os.environ:
- return [
- (url, )
- for url in os.environ['CELERY_BROKER_URLS'].split(',')
- ]
+ return [(url,) for url in os.environ['CELERY_BROKER_URLS'].split(',')]
return [(conf.get('celery', 'BROKER_URL'))]
@@ -97,7 +94,6 @@ def _prepare_app(broker_url=None, execute=None):
class TestCeleryExecutor(unittest.TestCase):
-
def setUp(self) -> None:
db.clear_db_runs()
db.clear_db_jobs()
@@ -127,12 +123,20 @@ def fake_execute_command(command):
execute_date = datetime.now()
task_tuples_to_send = [
- (('success', 'fake_simple_ti', execute_date, 0),
- None, success_command, celery_executor.celery_configuration['task_default_queue'],
- celery_executor.execute_command),
- (('fail', 'fake_simple_ti', execute_date, 0),
- None, fail_command, celery_executor.celery_configuration['task_default_queue'],
- celery_executor.execute_command)
+ (
+ ('success', 'fake_simple_ti', execute_date, 0),
+ None,
+ success_command,
+ celery_executor.celery_configuration['task_default_queue'],
+ celery_executor.execute_command,
+ ),
+ (
+ ('fail', 'fake_simple_ti', execute_date, 0),
+ None,
+ fail_command,
+ celery_executor.celery_configuration['task_default_queue'],
+ celery_executor.execute_command,
+ ),
]
# "Enqueue" them. We don't have a real SimpleTaskInstance, so directly edit the dict
@@ -145,24 +149,22 @@ def fake_execute_command(command):
list(executor.tasks.keys()),
[
('success', 'fake_simple_ti', execute_date, 0),
- ('fail', 'fake_simple_ti', execute_date, 0)
- ]
+ ('fail', 'fake_simple_ti', execute_date, 0),
+ ],
)
self.assertEqual(
- executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0],
- State.QUEUED
+ executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0], State.QUEUED
)
self.assertEqual(
- executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)][0],
- State.QUEUED
+ executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)][0], State.QUEUED
)
executor.end(synchronous=True)
- self.assertEqual(executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0],
- State.SUCCESS)
- self.assertEqual(executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)][0],
- State.FAILED)
+ self.assertEqual(
+ executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0], State.SUCCESS
+ )
+ self.assertEqual(executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)][0], State.FAILED)
self.assertNotIn('success', executor.tasks)
self.assertNotIn('fail', executor.tasks)
@@ -182,14 +184,15 @@ def fake_execute_command():
# which will cause TypeError when calling task.apply_async()
executor = celery_executor.CeleryExecutor()
task = BashOperator(
- task_id="test",
- bash_command="true",
- dag=DAG(dag_id='id'),
- start_date=datetime.now()
+ task_id="test", bash_command="true", dag=DAG(dag_id='id'), start_date=datetime.now()
)
when = datetime.now()
- value_tuple = 'command', 1, None, \
- SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.now()))
+ value_tuple = (
+ 'command',
+ 1,
+ None,
+ SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.now())),
+ )
key = ('fail', 'fake_simple_ti', when, 0)
executor.queued_tasks[key] = value_tuple
executor.heartbeat()
@@ -202,9 +205,7 @@ def test_exception_propagation(self):
with _prepare_app(), self.assertLogs(celery_executor.log) as cm:
executor = celery_executor.CeleryExecutor()
- executor.tasks = {
- 'key': FakeCeleryResult()
- }
+ executor.tasks = {'key': FakeCeleryResult()}
executor.bulk_state_fetcher._get_many_using_multiprocessing(executor.tasks.values())
self.assertTrue(any(celery_executor.CELERY_FETCH_ERR_MSG_HEADER in line for line in cm.output))
@@ -216,20 +217,21 @@ def test_exception_propagation(self):
def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync):
executor = celery_executor.CeleryExecutor()
executor.heartbeat()
- calls = [mock.call('executor.open_slots', mock.ANY),
- mock.call('executor.queued_tasks', mock.ANY),
- mock.call('executor.running_tasks', mock.ANY)]
+ calls = [
+ mock.call('executor.open_slots', mock.ANY),
+ mock.call('executor.queued_tasks', mock.ANY),
+ mock.call('executor.running_tasks', mock.ANY),
+ ]
mock_stats_gauge.assert_has_calls(calls)
- @parameterized.expand((
- [['true'], ValueError],
- [['airflow', 'version'], ValueError],
- [['airflow', 'tasks', 'run'], None]
- ))
+ @parameterized.expand(
+ ([['true'], ValueError], [['airflow', 'version'], ValueError], [['airflow', 'tasks', 'run'], None])
+ )
def test_command_validation(self, command, expected_exception):
# Check that we validate _on the receiving_ side, not just sending side
- with mock.patch('airflow.executors.celery_executor._execute_in_subprocess') as mock_subproc, \
- mock.patch('airflow.executors.celery_executor._execute_in_fork') as mock_fork:
+ with mock.patch(
+ 'airflow.executors.celery_executor._execute_in_subprocess'
+ ) as mock_subproc, mock.patch('airflow.executors.celery_executor._execute_in_fork') as mock_fork:
if expected_exception:
with pytest.raises(expected_exception):
celery_executor.execute_command(command)
@@ -238,8 +240,7 @@ def test_command_validation(self, command, expected_exception):
else:
celery_executor.execute_command(command)
# One of these should be called.
- assert mock_subproc.call_args == ((command,),) or \
- mock_fork.call_args == ((command,),)
+ assert mock_subproc.call_args == ((command,),) or mock_fork.call_args == ((command,),)
@pytest.mark.backend("mysql", "postgres")
def test_try_adopt_task_instances_none(self):
@@ -289,8 +290,8 @@ def test_try_adopt_task_instances(self):
dict(executor.adopted_task_timeouts),
{
key_1: queued_dttm + executor.task_adoption_timeout,
- key_2: queued_dttm + executor.task_adoption_timeout
- }
+ key_2: queued_dttm + executor.task_adoption_timeout,
+ },
)
self.assertEqual(executor.tasks, {key_1: AsyncResult("231"), key_2: AsyncResult("232")})
self.assertEqual(not_adopted_tis, [])
@@ -313,14 +314,11 @@ def test_check_for_stalled_adopted_tasks(self):
executor = celery_executor.CeleryExecutor()
executor.adopted_task_timeouts = {
key_1: queued_dttm + executor.task_adoption_timeout,
- key_2: queued_dttm + executor.task_adoption_timeout
+ key_2: queued_dttm + executor.task_adoption_timeout,
}
executor.tasks = {key_1: AsyncResult("231"), key_2: AsyncResult("232")}
executor.sync()
- self.assertEqual(
- executor.event_buffer,
- {key_1: (State.FAILED, None), key_2: (State.FAILED, None)}
- )
+ self.assertEqual(executor.event_buffer, {key_1: (State.FAILED, None), key_2: (State.FAILED, None)})
self.assertEqual(executor.tasks, {})
self.assertEqual(executor.adopted_task_timeouts, {})
@@ -350,10 +348,10 @@ def __ne__(self, other):
class TestBulkStateFetcher(unittest.TestCase):
-
- @mock.patch("celery.backends.base.BaseKeyValueStoreBackend.mget", return_value=[
- json.dumps({"status": "SUCCESS", "task_id": "123"})
- ])
+ @mock.patch(
+ "celery.backends.base.BaseKeyValueStoreBackend.mget",
+ return_value=[json.dumps({"status": "SUCCESS", "task_id": "123"})],
+ )
@pytest.mark.integration("redis")
@pytest.mark.integration("rabbitmq")
@pytest.mark.backend("mysql", "postgres")
@@ -362,10 +360,12 @@ def test_should_support_kv_backend(self, mock_mget):
mock_backend = BaseKeyValueStoreBackend(app=celery_executor.app)
with mock.patch.object(celery_executor.app, 'backend', mock_backend):
fetcher = BulkStateFetcher()
- result = fetcher.get_many([
- mock.MagicMock(task_id="123"),
- mock.MagicMock(task_id="456"),
- ])
+ result = fetcher.get_many(
+ [
+ mock.MagicMock(task_id="123"),
+ mock.MagicMock(task_id="456"),
+ ]
+ )
# Assert called - ignore order
mget_args, _ = mock_mget.call_args
@@ -389,10 +389,12 @@ def test_should_support_db_backend(self, mock_session):
]
fetcher = BulkStateFetcher()
- result = fetcher.get_many([
- mock.MagicMock(task_id="123"),
- mock.MagicMock(task_id="456"),
- ])
+ result = fetcher.get_many(
+ [
+ mock.MagicMock(task_id="123"),
+ mock.MagicMock(task_id="456"),
+ ]
+ )
self.assertEqual(result, {'123': ('SUCCESS', None), '456': ("PENDING", None)})
@@ -405,9 +407,11 @@ def test_should_support_base_backend(self):
with mock.patch.object(celery_executor.app, 'backend', mock_backend):
fetcher = BulkStateFetcher(1)
- result = fetcher.get_many([
- ClassWithCustomAttributes(task_id="123", state='SUCCESS'),
- ClassWithCustomAttributes(task_id="456", state="PENDING"),
- ])
+ result = fetcher.get_many(
+ [
+ ClassWithCustomAttributes(task_id="123", state='SUCCESS'),
+ ClassWithCustomAttributes(task_id="456", state="PENDING"),
+ ]
+ )
self.assertEqual(result, {'123': ('SUCCESS', None), '456': ("PENDING", None)})
diff --git a/tests/executors/test_celery_kubernetes_executor.py b/tests/executors/test_celery_kubernetes_executor.py
index a4f4852da11b0..cc8a958b2ed5a 100644
--- a/tests/executors/test_celery_kubernetes_executor.py
+++ b/tests/executors/test_celery_kubernetes_executor.py
@@ -74,7 +74,8 @@ def when_using_k8s_executor():
cke.queue_command(simple_task_instance, command, priority, queue)
k8s_executor_mock.queue_command.assert_called_once_with(
- simple_task_instance, command, priority, queue)
+ simple_task_instance, command, priority, queue
+ )
celery_executor_mock.queue_command.assert_not_called()
def when_using_celery_executor():
@@ -88,7 +89,8 @@ def when_using_celery_executor():
cke.queue_command(simple_task_instance, command, priority, queue)
celery_executor_mock.queue_command.assert_called_once_with(
- simple_task_instance, command, priority, queue)
+ simple_task_instance, command, priority, queue
+ )
k8s_executor_mock.queue_command.assert_not_called()
when_using_k8s_executor()
@@ -121,7 +123,7 @@ def when_using_k8s_executor():
ignore_task_deps,
ignore_ti_state,
pool,
- cfg_path
+ cfg_path,
)
k8s_executor_mock.queue_task_instance.assert_called_once_with(
@@ -133,7 +135,7 @@ def when_using_k8s_executor():
ignore_task_deps,
ignore_ti_state,
pool,
- cfg_path
+ cfg_path,
)
celery_executor_mock.queue_task_instance.assert_not_called()
@@ -154,7 +156,7 @@ def when_using_celery_executor():
ignore_task_deps,
ignore_ti_state,
pool,
- cfg_path
+ cfg_path,
)
k8s_executor_mock.queue_task_instance.assert_not_called()
@@ -167,7 +169,7 @@ def when_using_celery_executor():
ignore_task_deps,
ignore_ti_state,
pool,
- cfg_path
+ cfg_path,
)
when_using_k8s_executor()
diff --git a/tests/executors/test_dask_executor.py b/tests/executors/test_dask_executor.py
index 2f43271acb4a8..09a22f7bedbc3 100644
--- a/tests/executors/test_dask_executor.py
+++ b/tests/executors/test_dask_executor.py
@@ -26,6 +26,7 @@
try:
from distributed import LocalCluster
+
# utility functions imported from the dask testing suite to instantiate a test
# cluster for tls tests
from distributed.utils_test import cluster as dask_testing_cluster, get_cert, tls_security
@@ -38,7 +39,6 @@
class TestBaseDask(unittest.TestCase):
-
def assert_tasks_on_executor(self, executor):
success_command = ['airflow', 'tasks', 'run', '--help']
@@ -49,10 +49,8 @@ def assert_tasks_on_executor(self, executor):
executor.execute_async(key='success', command=success_command)
executor.execute_async(key='fail', command=fail_command)
- success_future = next(
- k for k, v in executor.futures.items() if v == 'success')
- fail_future = next(
- k for k, v in executor.futures.items() if v == 'fail')
+ success_future = next(k for k, v in executor.futures.items() if v == 'success')
+ fail_future = next(k for k, v in executor.futures.items() if v == 'fail')
# wait for the futures to execute, with a timeout
timeout = timezone.utcnow() + timedelta(seconds=30)
@@ -60,7 +58,8 @@ def assert_tasks_on_executor(self, executor):
if timezone.utcnow() > timeout:
raise ValueError(
'The futures should have finished; there is probably '
- 'an error communicating with the Dask cluster.')
+ 'an error communicating with the Dask cluster.'
+ )
# both tasks should have finished
self.assertTrue(success_future.done())
@@ -72,7 +71,6 @@ def assert_tasks_on_executor(self, executor):
class TestDaskExecutor(TestBaseDask):
-
def setUp(self):
self.dagbag = DagBag(include_examples=True)
self.cluster = LocalCluster()
@@ -92,8 +90,8 @@ def test_backfill_integration(self):
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE,
ignore_first_depends_on_past=True,
- executor=DaskExecutor(
- cluster_address=self.cluster.scheduler_address))
+ executor=DaskExecutor(cluster_address=self.cluster.scheduler_address),
+ )
job.run()
def tearDown(self):
@@ -101,21 +99,22 @@ def tearDown(self):
class TestDaskExecutorTLS(TestBaseDask):
-
def setUp(self):
self.dagbag = DagBag(include_examples=True)
- @conf_vars({
- ('dask', 'tls_ca'): get_cert('tls-ca-cert.pem'),
- ('dask', 'tls_cert'): get_cert('tls-key-cert.pem'),
- ('dask', 'tls_key'): get_cert('tls-key.pem'),
- })
+ @conf_vars(
+ {
+ ('dask', 'tls_ca'): get_cert('tls-ca-cert.pem'),
+ ('dask', 'tls_cert'): get_cert('tls-key-cert.pem'),
+ ('dask', 'tls_key'): get_cert('tls-key.pem'),
+ }
+ )
def test_tls(self):
# These use test certs that ship with dask/distributed and should not be
# used in production
with dask_testing_cluster(
worker_kwargs={'security': tls_security(), "protocol": "tls"},
- scheduler_kwargs={'security': tls_security(), "protocol": "tls"}
+ scheduler_kwargs={'security': tls_security(), "protocol": "tls"},
) as (cluster, _):
executor = DaskExecutor(cluster_address=cluster['address'])
@@ -133,7 +132,9 @@ def test_tls(self):
def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync):
executor = DaskExecutor()
executor.heartbeat()
- calls = [mock.call('executor.open_slots', mock.ANY),
- mock.call('executor.queued_tasks', mock.ANY),
- mock.call('executor.running_tasks', mock.ANY)]
+ calls = [
+ mock.call('executor.open_slots', mock.ANY),
+ mock.call('executor.queued_tasks', mock.ANY),
+ mock.call('executor.running_tasks', mock.ANY),
+ ]
mock_stats_gauge.assert_has_calls(calls)
diff --git a/tests/executors/test_executor_loader.py b/tests/executors/test_executor_loader.py
index 60f4bcd7e0d5d..63ef8dd71961f 100644
--- a/tests/executors/test_executor_loader.py
+++ b/tests/executors/test_executor_loader.py
@@ -38,44 +38,37 @@ class FakePlugin(plugins_manager.AirflowPlugin):
class TestExecutorLoader(unittest.TestCase):
-
def setUp(self) -> None:
ExecutorLoader._default_executor = None
def tearDown(self) -> None:
ExecutorLoader._default_executor = None
- @parameterized.expand([
- ("CeleryExecutor", ),
- ("CeleryKubernetesExecutor", ),
- ("DebugExecutor", ),
- ("KubernetesExecutor", ),
- ("LocalExecutor", ),
- ])
+ @parameterized.expand(
+ [
+ ("CeleryExecutor",),
+ ("CeleryKubernetesExecutor",),
+ ("DebugExecutor",),
+ ("KubernetesExecutor",),
+ ("LocalExecutor",),
+ ]
+ )
def test_should_support_executor_from_core(self, executor_name):
- with conf_vars({
- ("core", "executor"): executor_name
- }):
+ with conf_vars({("core", "executor"): executor_name}):
executor = ExecutorLoader.get_default_executor()
self.assertIsNotNone(executor)
self.assertEqual(executor_name, executor.__class__.__name__)
- @mock.patch("airflow.plugins_manager.plugins", [
- FakePlugin()
- ])
+ @mock.patch("airflow.plugins_manager.plugins", [FakePlugin()])
@mock.patch("airflow.plugins_manager.executors_modules", None)
def test_should_support_plugins(self):
- with conf_vars({
- ("core", "executor"): f"{TEST_PLUGIN_NAME}.FakeExecutor"
- }):
+ with conf_vars({("core", "executor"): f"{TEST_PLUGIN_NAME}.FakeExecutor"}):
executor = ExecutorLoader.get_default_executor()
self.assertIsNotNone(executor)
self.assertEqual("FakeExecutor", executor.__class__.__name__)
def test_should_support_custom_path(self):
- with conf_vars({
- ("core", "executor"): "tests.executors.test_executor_loader.FakeExecutor"
- }):
+ with conf_vars({("core", "executor"): "tests.executors.test_executor_loader.FakeExecutor"}):
executor = ExecutorLoader.get_default_executor()
self.assertIsNotNone(executor)
self.assertEqual("FakeExecutor", executor.__class__.__name__)
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index 4c38352ceef25..9765669d32aa9 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -32,7 +32,9 @@
from kubernetes.client.rest import ApiException
from airflow.executors.kubernetes_executor import (
- AirflowKubernetesScheduler, KubernetesExecutor, create_pod_id,
+ AirflowKubernetesScheduler,
+ KubernetesExecutor,
+ create_pod_id,
)
from airflow.kubernetes import pod_generator
from airflow.kubernetes.pod_generator import PodGenerator
@@ -57,38 +59,29 @@ def _cases(self):
("my.dag.id", "my.task.id"),
("MYDAGID", "MYTASKID"),
("my_dag_id", "my_task_id"),
- ("mydagid" * 200, "my_task_id" * 200)
+ ("mydagid" * 200, "my_task_id" * 200),
]
- cases.extend([
- (self._gen_random_string(seed, 200), self._gen_random_string(seed, 200))
- for seed in range(100)
- ])
+ cases.extend(
+ [(self._gen_random_string(seed, 200), self._gen_random_string(seed, 200)) for seed in range(100)]
+ )
return cases
@staticmethod
def _is_valid_pod_id(name):
regex = r"^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$"
- return (
- len(name) <= 253 and
- all(ch.lower() == ch for ch in name) and
- re.match(regex, name))
+ return len(name) <= 253 and all(ch.lower() == ch for ch in name) and re.match(regex, name)
@staticmethod
def _is_safe_label_value(value):
regex = r'^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$'
- return (
- len(value) <= 63 and
- re.match(regex, value))
+ return len(value) <= 63 and re.match(regex, value)
- @unittest.skipIf(AirflowKubernetesScheduler is None,
- 'kubernetes python package is not installed')
+ @unittest.skipIf(AirflowKubernetesScheduler is None, 'kubernetes python package is not installed')
def test_create_pod_id(self):
for dag_id, task_id in self._cases():
- pod_name = PodGenerator.make_unique_pod_id(
- create_pod_id(dag_id, task_id)
- )
+ pod_name = PodGenerator.make_unique_pod_id(create_pod_id(dag_id, task_id))
self.assertTrue(self._is_valid_pod_id(pod_name))
def test_make_safe_label_value(self):
@@ -98,23 +91,16 @@ def test_make_safe_label_value(self):
safe_task_id = pod_generator.make_safe_label_value(task_id)
self.assertTrue(self._is_safe_label_value(safe_task_id))
dag_id = "my_dag_id"
- self.assertEqual(
- dag_id,
- pod_generator.make_safe_label_value(dag_id)
- )
+ self.assertEqual(dag_id, pod_generator.make_safe_label_value(dag_id))
dag_id = "my_dag_id_" + "a" * 64
self.assertEqual(
- "my_dag_id_" + "a" * 43 + "-0ce114c45",
- pod_generator.make_safe_label_value(dag_id)
+ "my_dag_id_" + "a" * 43 + "-0ce114c45", pod_generator.make_safe_label_value(dag_id)
)
def test_execution_date_serialize_deserialize(self):
datetime_obj = datetime.now()
- serialized_datetime = \
- pod_generator.datetime_to_label_safe_datestring(
- datetime_obj)
- new_datetime_obj = pod_generator.label_safe_datestring_to_datetime(
- serialized_datetime)
+ serialized_datetime = pod_generator.datetime_to_label_safe_datestring(datetime_obj)
+ new_datetime_obj = pod_generator.label_safe_datestring_to_datetime(serialized_datetime)
self.assertEqual(datetime_obj, new_datetime_obj)
@@ -129,27 +115,27 @@ def setUp(self) -> None:
self.kubernetes_executor = KubernetesExecutor()
self.kubernetes_executor.job_id = "5"
- @unittest.skipIf(AirflowKubernetesScheduler is None,
- 'kubernetes python package is not installed')
+ @unittest.skipIf(AirflowKubernetesScheduler is None, 'kubernetes python package is not installed')
@mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
@mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
def test_run_next_exception(self, mock_get_kube_client, mock_kubernetes_job_watcher):
import sys
+
path = sys.path[0] + '/tests/kubernetes/pod_generator_base_with_secrets.yaml'
# When a quota is exceeded this is the ApiException we get
response = HTTPResponse(
body='{"kind": "Status", "apiVersion": "v1", "metadata": {}, "status": "Failure", '
- '"message": "pods \\"podname\\" is forbidden: exceeded quota: compute-resources, '
- 'requested: limits.memory=4Gi, used: limits.memory=6508Mi, limited: limits.memory=10Gi", '
- '"reason": "Forbidden", "details": {"name": "podname", "kind": "pods"}, "code": 403}')
+ '"message": "pods \\"podname\\" is forbidden: exceeded quota: compute-resources, '
+ 'requested: limits.memory=4Gi, used: limits.memory=6508Mi, limited: limits.memory=10Gi", '
+ '"reason": "Forbidden", "details": {"name": "podname", "kind": "pods"}, "code": 403}'
+ )
response.status = 403
response.reason = "Forbidden"
# A mock kube_client that throws errors when making a pod
mock_kube_client = mock.patch('kubernetes.client.CoreV1Api', autospec=True)
- mock_kube_client.create_namespaced_pod = mock.MagicMock(
- side_effect=ApiException(http_resp=response))
+ mock_kube_client.create_namespaced_pod = mock.MagicMock(side_effect=ApiException(http_resp=response))
mock_get_kube_client.return_value = mock_kube_client
mock_api_client = mock.MagicMock()
mock_api_client.sanitize_for_serialization.return_value = {}
@@ -163,10 +149,11 @@ def test_run_next_exception(self, mock_get_kube_client, mock_kubernetes_job_watc
kubernetes_executor.start()
# Execute a task while the Api Throws errors
try_number = 1
- kubernetes_executor.execute_async(key=('dag', 'task', datetime.utcnow(), try_number),
- queue=None,
- command=['airflow', 'tasks', 'run', 'true', 'some_parameter'],
- )
+ kubernetes_executor.execute_async(
+ key=('dag', 'task', datetime.utcnow(), try_number),
+ queue=None,
+ command=['airflow', 'tasks', 'run', 'true', 'some_parameter'],
+ )
kubernetes_executor.sync()
kubernetes_executor.sync()
@@ -188,9 +175,11 @@ def test_run_next_exception(self, mock_get_kube_client, mock_kubernetes_job_watc
def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync, mock_kube_config):
executor = self.kubernetes_executor
executor.heartbeat()
- calls = [mock.call('executor.open_slots', mock.ANY),
- mock.call('executor.queued_tasks', mock.ANY),
- mock.call('executor.running_tasks', mock.ANY)]
+ calls = [
+ mock.call('executor.open_slots', mock.ANY),
+ mock.call('executor.queued_tasks', mock.ANY),
+ mock.call('executor.running_tasks', mock.ANY),
+ ]
mock_stats_gauge.assert_has_calls(calls)
@mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
@@ -218,10 +207,7 @@ def test_change_state_success(self, mock_delete_pod, mock_get_kube_client, mock_
@mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
@mock.patch('airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod')
def test_change_state_failed_no_deletion(
- self,
- mock_delete_pod,
- mock_get_kube_client,
- mock_kubernetes_job_watcher
+ self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher
):
executor = self.kubernetes_executor
executor.kube_config.delete_worker_pods = False
@@ -232,13 +218,15 @@ def test_change_state_failed_no_deletion(
executor._change_state(key, State.FAILED, 'pod_id', 'default')
self.assertTrue(executor.event_buffer[key][0] == State.FAILED)
mock_delete_pod.assert_not_called()
-# pylint: enable=unused-argument
+
+ # pylint: enable=unused-argument
@mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
@mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
@mock.patch('airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod')
- def test_change_state_skip_pod_deletion(self, mock_delete_pod, mock_get_kube_client,
- mock_kubernetes_job_watcher):
+ def test_change_state_skip_pod_deletion(
+ self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher
+ ):
test_time = timezone.utcnow()
executor = self.kubernetes_executor
executor.kube_config.delete_worker_pods = False
@@ -253,8 +241,9 @@ def test_change_state_skip_pod_deletion(self, mock_delete_pod, mock_get_kube_cli
@mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
@mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
@mock.patch('airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod')
- def test_change_state_failed_pod_deletion(self, mock_delete_pod, mock_get_kube_client,
- mock_kubernetes_job_watcher):
+ def test_change_state_failed_pod_deletion(
+ self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher
+ ):
executor = self.kubernetes_executor
executor.kube_config.delete_worker_pods_on_failure = True
@@ -271,24 +260,22 @@ def test_adopt_launched_task(self, mock_kube_client):
pod_ids = {"dagtask": {}}
pod = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(
- name="foo",
- labels={
- "airflow-worker": "bar",
- "dag_id": "dag",
- "task_id": "task"
- }
+ name="foo", labels={"airflow-worker": "bar", "dag_id": "dag", "task_id": "task"}
)
)
executor.adopt_launched_task(mock_kube_client, pod=pod, pod_ids=pod_ids)
self.assertEqual(
mock_kube_client.patch_namespaced_pod.call_args[1],
- {'body': {'metadata': {'labels': {'airflow-worker': 'modified',
- 'dag_id': 'dag',
- 'task_id': 'task'},
- 'name': 'foo'}
- },
- 'name': 'foo',
- 'namespace': None}
+ {
+ 'body': {
+ 'metadata': {
+ 'labels': {'airflow-worker': 'modified', 'dag_id': 'dag', 'task_id': 'task'},
+ 'name': 'foo',
+ }
+ },
+ 'name': 'foo',
+ 'namespace': None,
+ },
)
self.assertDictEqual(pod_ids, {})
@@ -306,12 +293,7 @@ def test_not_adopt_unassigned_task(self, mock_kube_client):
pod_ids = {"foobar": {}}
pod = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(
- name="foo",
- labels={
- "airflow-worker": "bar",
- "dag_id": "dag",
- "task_id": "task"
- }
+ name="foo", labels={"airflow-worker": "bar", "dag_id": "dag", "task_id": "task"}
)
)
executor.adopt_launched_task(mock_kube_client, pod=pod, pod_ids=pod_ids)
diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py
index 20722bf478fd4..b12f783c1b5f7 100644
--- a/tests/executors/test_local_executor.py
+++ b/tests/executors/test_local_executor.py
@@ -91,14 +91,16 @@ def _test_execute(self, parallelism, success_command, fail_command):
self.assertEqual(executor.workers_used, expected)
def test_execution_subprocess_unlimited_parallelism(self):
- with mock.patch.object(settings, 'EXECUTE_TASKS_NEW_PYTHON_INTERPRETER',
- new_callable=mock.PropertyMock) as option:
+ with mock.patch.object(
+ settings, 'EXECUTE_TASKS_NEW_PYTHON_INTERPRETER', new_callable=mock.PropertyMock
+ ) as option:
option.return_value = True
self.execution_parallelism_subprocess(parallelism=0) # pylint: disable=no-value-for-parameter
def test_execution_subprocess_limited_parallelism(self):
- with mock.patch.object(settings, 'EXECUTE_TASKS_NEW_PYTHON_INTERPRETER',
- new_callable=mock.PropertyMock) as option:
+ with mock.patch.object(
+ settings, 'EXECUTE_TASKS_NEW_PYTHON_INTERPRETER', new_callable=mock.PropertyMock
+ ) as option:
option.return_value = True
self.execution_parallelism_subprocess(parallelism=2) # pylint: disable=no-value-for-parameter
@@ -116,7 +118,9 @@ def test_execution_limited_parallelism_fork(self):
def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync):
executor = LocalExecutor()
executor.heartbeat()
- calls = [mock.call('executor.open_slots', mock.ANY),
- mock.call('executor.queued_tasks', mock.ANY),
- mock.call('executor.running_tasks', mock.ANY)]
+ calls = [
+ mock.call('executor.open_slots', mock.ANY),
+ mock.call('executor.queued_tasks', mock.ANY),
+ mock.call('executor.running_tasks', mock.ANY),
+ ]
mock_stats_gauge.assert_has_calls(calls)
diff --git a/tests/executors/test_sequential_executor.py b/tests/executors/test_sequential_executor.py
index 0189a6ebe401f..bb97e98b2a018 100644
--- a/tests/executors/test_sequential_executor.py
+++ b/tests/executors/test_sequential_executor.py
@@ -23,14 +23,15 @@
class TestSequentialExecutor(unittest.TestCase):
-
@mock.patch('airflow.executors.sequential_executor.SequentialExecutor.sync')
@mock.patch('airflow.executors.base_executor.BaseExecutor.trigger_tasks')
@mock.patch('airflow.executors.base_executor.Stats.gauge')
def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync):
executor = SequentialExecutor()
executor.heartbeat()
- calls = [mock.call('executor.open_slots', mock.ANY),
- mock.call('executor.queued_tasks', mock.ANY),
- mock.call('executor.running_tasks', mock.ANY)]
+ calls = [
+ mock.call('executor.open_slots', mock.ANY),
+ mock.call('executor.queued_tasks', mock.ANY),
+ mock.call('executor.running_tasks', mock.ANY),
+ ]
mock_stats_gauge.assert_has_calls(calls)
diff --git a/tests/hooks/test_dbapi_hook.py b/tests/hooks/test_dbapi_hook.py
index 5e16aaf4cde80..a122e01bdfc02 100644
--- a/tests/hooks/test_dbapi_hook.py
+++ b/tests/hooks/test_dbapi_hook.py
@@ -25,7 +25,6 @@
class TestDbApiHook(unittest.TestCase):
-
def setUp(self):
super().setUp()
@@ -45,8 +44,7 @@ def get_conn(self):
def test_get_records(self):
statement = "SQL"
- rows = [("hello",),
- ("world",)]
+ rows = [("hello",), ("world",)]
self.cur.fetchall.return_value = rows
@@ -59,8 +57,7 @@ def test_get_records(self):
def test_get_records_parameters(self):
statement = "SQL"
parameters = ["X", "Y", "Z"]
- rows = [("hello",),
- ("world",)]
+ rows = [("hello",), ("world",)]
self.cur.fetchall.return_value = rows
@@ -83,8 +80,7 @@ def test_get_records_exception(self):
def test_insert_rows(self):
table = "table"
- rows = [("hello",),
- ("world",)]
+ rows = [("hello",), ("world",)]
self.db_hook.insert_rows(table, rows)
@@ -100,8 +96,7 @@ def test_insert_rows(self):
def test_insert_rows_replace(self):
table = "table"
- rows = [("hello",),
- ("world",)]
+ rows = [("hello",), ("world",)]
self.db_hook.insert_rows(table, rows, replace=True)
@@ -117,8 +112,7 @@ def test_insert_rows_replace(self):
def test_insert_rows_target_fields(self):
table = "table"
- rows = [("hello",),
- ("world",)]
+ rows = [("hello",), ("world",)]
target_fields = ["field"]
self.db_hook.insert_rows(table, rows, target_fields)
@@ -135,8 +129,7 @@ def test_insert_rows_target_fields(self):
def test_insert_rows_commit_every(self):
table = "table"
- rows = [("hello",),
- ("world",)]
+ rows = [("hello",), ("world",)]
commit_every = 1
self.db_hook.insert_rows(table, rows, commit_every=commit_every)
@@ -152,25 +145,24 @@ def test_insert_rows_commit_every(self):
self.cur.execute.assert_any_call(sql, row)
def test_get_uri_schema_not_none(self):
- self.db_hook.get_connection = mock.MagicMock(return_value=Connection(
- conn_type="conn_type",
- host="host",
- login="login",
- password="password",
- schema="schema",
- port=1
- ))
+ self.db_hook.get_connection = mock.MagicMock(
+ return_value=Connection(
+ conn_type="conn_type",
+ host="host",
+ login="login",
+ password="password",
+ schema="schema",
+ port=1,
+ )
+ )
self.assertEqual("conn_type://login:password@host:1/schema", self.db_hook.get_uri())
def test_get_uri_schema_none(self):
- self.db_hook.get_connection = mock.MagicMock(return_value=Connection(
- conn_type="conn_type",
- host="host",
- login="login",
- password="password",
- schema=None,
- port=1
- ))
+ self.db_hook.get_connection = mock.MagicMock(
+ return_value=Connection(
+ conn_type="conn_type", host="host", login="login", password="password", schema=None, port=1
+ )
+ )
self.assertEqual("conn_type://login:password@host:1/", self.db_hook.get_uri())
def test_run_log(self):
diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py
index b9a3b6a754a96..9898859236416 100644
--- a/tests/jobs/test_backfill_job.py
+++ b/tests/jobs/test_backfill_job.py
@@ -31,7 +31,10 @@
from airflow import settings
from airflow.cli import cli_parser
from airflow.exceptions import (
- AirflowException, AirflowTaskTimeout, DagConcurrencyLimitReached, NoAvailablePoolSlot,
+ AirflowException,
+ AirflowTaskTimeout,
+ DagConcurrencyLimitReached,
+ NoAvailablePoolSlot,
TaskConcurrencyLimitReached,
)
from airflow.jobs.backfill_job import BackfillJob
@@ -53,19 +56,11 @@
@pytest.mark.heisentests
class TestBackfillJob(unittest.TestCase):
-
def _get_dummy_dag(self, dag_id, pool=Pool.DEFAULT_POOL_NAME, task_concurrency=None):
- dag = DAG(
- dag_id=dag_id,
- start_date=DEFAULT_DATE,
- schedule_interval='@daily')
+ dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, schedule_interval='@daily')
with dag:
- DummyOperator(
- task_id='op',
- pool=pool,
- task_concurrency=task_concurrency,
- dag=dag)
+ DummyOperator(task_id='op', pool=pool, task_concurrency=task_concurrency, dag=dag)
dag.clear()
return dag
@@ -105,7 +100,7 @@ def test_unfinished_dag_runs_set_to_failed(self):
dag=dag,
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=8),
- ignore_first_depends_on_past=True
+ ignore_first_depends_on_past=True,
)
job._set_unfinished_dag_runs_to_failed([dag_run])
@@ -129,7 +124,7 @@ def test_dag_run_with_finished_tasks_set_to_success(self):
dag=dag,
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=8),
- ignore_first_depends_on_past=True
+ ignore_first_depends_on_past=True,
)
job._set_unfinished_dag_runs_to_failed([dag_run])
@@ -154,10 +149,7 @@ def test_trigger_controller_dag(self):
self.assertFalse(task_instances_list)
job = BackfillJob(
- dag=dag,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- ignore_first_depends_on_past=True
+ dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_first_depends_on_past=True
)
job.run()
@@ -181,7 +173,7 @@ def test_backfill_multi_dates(self):
start_date=DEFAULT_DATE,
end_date=end_date,
executor=executor,
- ignore_first_depends_on_past=True
+ ignore_first_depends_on_past=True,
)
job.run()
@@ -201,20 +193,19 @@ def test_backfill_multi_dates(self):
("run_this_last", end_date),
]
self.assertListEqual(
- [((dag.dag_id, task_id, when, 1), (State.SUCCESS, None))
- for (task_id, when) in expected_execution_order],
- executor.sorted_tasks
+ [
+ ((dag.dag_id, task_id, when, 1), (State.SUCCESS, None))
+ for (task_id, when) in expected_execution_order
+ ],
+ executor.sorted_tasks,
)
session = settings.Session()
- drs = session.query(DagRun).filter(
- DagRun.dag_id == dag.dag_id
- ).order_by(DagRun.execution_date).all()
+ drs = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).order_by(DagRun.execution_date).all()
self.assertTrue(drs[0].execution_date == DEFAULT_DATE)
self.assertTrue(drs[0].state == State.SUCCESS)
- self.assertTrue(drs[1].execution_date ==
- DEFAULT_DATE + datetime.timedelta(days=1))
+ self.assertTrue(drs[1].execution_date == DEFAULT_DATE + datetime.timedelta(days=1))
self.assertTrue(drs[1].state == State.SUCCESS)
dag.clear()
@@ -241,8 +232,7 @@ def test_backfill_multi_dates(self):
],
[
"example_bash_operator",
- ("runme_0", "runme_1", "runme_2", "also_run_this", "run_after_loop",
- "run_this_last"),
+ ("runme_0", "runme_1", "runme_2", "also_run_this", "run_after_loop", "run_this_last"),
],
[
"example_skip_dag",
@@ -277,13 +267,16 @@ def test_backfill_examples(self, dag_id, expected_execution_order):
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE,
executor=executor,
- ignore_first_depends_on_past=True)
+ ignore_first_depends_on_past=True,
+ )
job.run()
self.assertListEqual(
- [((dag_id, task_id, DEFAULT_DATE, 1), (State.SUCCESS, None))
- for task_id in expected_execution_order],
- executor.sorted_tasks
+ [
+ ((dag_id, task_id, DEFAULT_DATE, 1), (State.SUCCESS, None))
+ for task_id in expected_execution_order
+ ],
+ executor.sorted_tasks,
)
def test_backfill_conf(self):
@@ -292,11 +285,13 @@ def test_backfill_conf(self):
executor = MockExecutor()
conf_ = json.loads("""{"key": "value"}""")
- job = BackfillJob(dag=dag,
- executor=executor,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=2),
- conf=conf_)
+ job = BackfillJob(
+ dag=dag,
+ executor=executor,
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=2),
+ conf=conf_,
+ )
job.run()
dr = DagRun.find(dag_id='test_backfill_conf')
@@ -548,10 +543,7 @@ def test_backfill_respect_pool_limit(self, mock_log):
self.assertGreater(times_pool_limit_reached_in_debug, 0)
def test_backfill_run_rescheduled(self):
- dag = DAG(
- dag_id='test_backfill_run_rescheduled',
- start_date=DEFAULT_DATE,
- schedule_interval='@daily')
+ dag = DAG(dag_id='test_backfill_run_rescheduled', start_date=DEFAULT_DATE, schedule_interval='@daily')
with dag:
DummyOperator(
@@ -563,142 +555,130 @@ def test_backfill_run_rescheduled(self):
executor = MockExecutor()
- job = BackfillJob(dag=dag,
- executor=executor,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=2),
- )
+ job = BackfillJob(
+ dag=dag,
+ executor=executor,
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=2),
+ )
job.run()
- ti = TI(task=dag.get_task('test_backfill_run_rescheduled_task-1'),
- execution_date=DEFAULT_DATE)
+ ti = TI(task=dag.get_task('test_backfill_run_rescheduled_task-1'), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
ti.set_state(State.UP_FOR_RESCHEDULE)
- job = BackfillJob(dag=dag,
- executor=executor,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=2),
- rerun_failed_tasks=True
- )
+ job = BackfillJob(
+ dag=dag,
+ executor=executor,
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=2),
+ rerun_failed_tasks=True,
+ )
job.run()
- ti = TI(task=dag.get_task('test_backfill_run_rescheduled_task-1'),
- execution_date=DEFAULT_DATE)
+ ti = TI(task=dag.get_task('test_backfill_run_rescheduled_task-1'), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
self.assertEqual(ti.state, State.SUCCESS)
def test_backfill_rerun_failed_tasks(self):
- dag = DAG(
- dag_id='test_backfill_rerun_failed',
- start_date=DEFAULT_DATE,
- schedule_interval='@daily')
+ dag = DAG(dag_id='test_backfill_rerun_failed', start_date=DEFAULT_DATE, schedule_interval='@daily')
with dag:
- DummyOperator(
- task_id='test_backfill_rerun_failed_task-1',
- dag=dag)
+ DummyOperator(task_id='test_backfill_rerun_failed_task-1', dag=dag)
dag.clear()
executor = MockExecutor()
- job = BackfillJob(dag=dag,
- executor=executor,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=2),
- )
+ job = BackfillJob(
+ dag=dag,
+ executor=executor,
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=2),
+ )
job.run()
- ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'),
- execution_date=DEFAULT_DATE)
+ ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
ti.set_state(State.FAILED)
- job = BackfillJob(dag=dag,
- executor=executor,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=2),
- rerun_failed_tasks=True
- )
+ job = BackfillJob(
+ dag=dag,
+ executor=executor,
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=2),
+ rerun_failed_tasks=True,
+ )
job.run()
- ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'),
- execution_date=DEFAULT_DATE)
+ ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
self.assertEqual(ti.state, State.SUCCESS)
def test_backfill_rerun_upstream_failed_tasks(self):
dag = DAG(
- dag_id='test_backfill_rerun_upstream_failed',
- start_date=DEFAULT_DATE,
- schedule_interval='@daily')
+ dag_id='test_backfill_rerun_upstream_failed', start_date=DEFAULT_DATE, schedule_interval='@daily'
+ )
with dag:
- op1 = DummyOperator(task_id='test_backfill_rerun_upstream_failed_task-1',
- dag=dag)
- op2 = DummyOperator(task_id='test_backfill_rerun_upstream_failed_task-2',
- dag=dag)
+ op1 = DummyOperator(task_id='test_backfill_rerun_upstream_failed_task-1', dag=dag)
+ op2 = DummyOperator(task_id='test_backfill_rerun_upstream_failed_task-2', dag=dag)
op1.set_upstream(op2)
dag.clear()
executor = MockExecutor()
- job = BackfillJob(dag=dag,
- executor=executor,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=2),
- )
+ job = BackfillJob(
+ dag=dag,
+ executor=executor,
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=2),
+ )
job.run()
- ti = TI(task=dag.get_task('test_backfill_rerun_upstream_failed_task-1'),
- execution_date=DEFAULT_DATE)
+ ti = TI(task=dag.get_task('test_backfill_rerun_upstream_failed_task-1'), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
ti.set_state(State.UPSTREAM_FAILED)
- job = BackfillJob(dag=dag,
- executor=executor,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=2),
- rerun_failed_tasks=True
- )
+ job = BackfillJob(
+ dag=dag,
+ executor=executor,
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=2),
+ rerun_failed_tasks=True,
+ )
job.run()
- ti = TI(task=dag.get_task('test_backfill_rerun_upstream_failed_task-1'),
- execution_date=DEFAULT_DATE)
+ ti = TI(task=dag.get_task('test_backfill_rerun_upstream_failed_task-1'), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
self.assertEqual(ti.state, State.SUCCESS)
def test_backfill_rerun_failed_tasks_without_flag(self):
- dag = DAG(
- dag_id='test_backfill_rerun_failed',
- start_date=DEFAULT_DATE,
- schedule_interval='@daily')
+ dag = DAG(dag_id='test_backfill_rerun_failed', start_date=DEFAULT_DATE, schedule_interval='@daily')
with dag:
- DummyOperator(
- task_id='test_backfill_rerun_failed_task-1',
- dag=dag)
+ DummyOperator(task_id='test_backfill_rerun_failed_task-1', dag=dag)
dag.clear()
executor = MockExecutor()
- job = BackfillJob(dag=dag,
- executor=executor,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=2),
- )
+ job = BackfillJob(
+ dag=dag,
+ executor=executor,
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=2),
+ )
job.run()
- ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'),
- execution_date=DEFAULT_DATE)
+ ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
ti.set_state(State.FAILED)
- job = BackfillJob(dag=dag,
- executor=executor,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=2),
- rerun_failed_tasks=False
- )
+ job = BackfillJob(
+ dag=dag,
+ executor=executor,
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=2),
+ rerun_failed_tasks=False,
+ )
with self.assertRaises(AirflowException):
job.run()
@@ -707,7 +687,8 @@ def test_backfill_ordered_concurrent_execute(self):
dag = DAG(
dag_id='test_backfill_ordered_concurrent_execute',
start_date=DEFAULT_DATE,
- schedule_interval="@daily")
+ schedule_interval="@daily",
+ )
with dag:
op1 = DummyOperator(task_id='leave1')
@@ -724,11 +705,12 @@ def test_backfill_ordered_concurrent_execute(self):
dag.clear()
executor = MockExecutor(parallelism=16)
- job = BackfillJob(dag=dag,
- executor=executor,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=2),
- )
+ job = BackfillJob(
+ dag=dag,
+ executor=executor,
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=2),
+ )
job.run()
date0 = DEFAULT_DATE
@@ -748,15 +730,12 @@ def test_backfill_ordered_concurrent_execute(self):
('leave1', date2),
('leave2', date0),
('leave2', date1),
- ('leave2', date2)
+ ('leave2', date2),
],
- [('upstream_level_1', date0), ('upstream_level_1', date1),
- ('upstream_level_1', date2)],
- [('upstream_level_2', date0), ('upstream_level_2', date1),
- ('upstream_level_2', date2)],
- [('upstream_level_3', date0), ('upstream_level_3', date1),
- ('upstream_level_3', date2)],
- ]
+ [('upstream_level_1', date0), ('upstream_level_1', date1), ('upstream_level_1', date2)],
+ [('upstream_level_2', date0), ('upstream_level_2', date1), ('upstream_level_2', date2)],
+ [('upstream_level_3', date0), ('upstream_level_3', date1), ('upstream_level_3', date2)],
+ ],
)
def test_backfill_pooled_tasks(self):
@@ -773,11 +752,7 @@ def test_backfill_pooled_tasks(self):
dag.clear()
executor = MockExecutor(do_update=True)
- job = BackfillJob(
- dag=dag,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- executor=executor)
+ job = BackfillJob(dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor)
# run with timeout because this creates an infinite loop if not
# caught
@@ -786,9 +761,7 @@ def test_backfill_pooled_tasks(self):
job.run()
except AirflowTaskTimeout:
pass
- ti = TI(
- task=dag.get_task('test_backfill_pooled_task'),
- execution_date=DEFAULT_DATE)
+ ti = TI(task=dag.get_task('test_backfill_pooled_task'), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
self.assertEqual(ti.state, State.SUCCESS)
@@ -804,14 +777,16 @@ def test_backfill_depends_on_past(self):
self.assertRaisesRegex(
AirflowException,
'BackfillJob is deadlocked',
- BackfillJob(dag=dag, start_date=run_date, end_date=run_date).run)
+ BackfillJob(dag=dag, start_date=run_date, end_date=run_date).run,
+ )
BackfillJob(
dag=dag,
start_date=run_date,
end_date=run_date,
executor=MockExecutor(),
- ignore_first_depends_on_past=True).run()
+ ignore_first_depends_on_past=True,
+ ).run()
# ti should have succeeded
ti = TI(dag.tasks[0], run_date)
@@ -833,10 +808,7 @@ def test_backfill_depends_on_past_backwards(self):
dag.clear()
executor = MockExecutor()
- job = BackfillJob(dag=dag,
- executor=executor,
- ignore_first_depends_on_past=True,
- **kwargs)
+ job = BackfillJob(dag=dag, executor=executor, ignore_first_depends_on_past=True, **kwargs)
job.run()
ti = TI(dag.get_task('test_dop_task'), end_date)
@@ -846,13 +818,11 @@ def test_backfill_depends_on_past_backwards(self):
# raises backwards
expected_msg = 'You cannot backfill backwards because one or more tasks depend_on_past: {}'.format(
- 'test_dop_task')
+ 'test_dop_task'
+ )
with self.assertRaisesRegex(AirflowException, expected_msg):
executor = MockExecutor()
- job = BackfillJob(dag=dag,
- executor=executor,
- run_backwards=True,
- **kwargs)
+ job = BackfillJob(dag=dag, executor=executor, run_backwards=True, **kwargs)
job.run()
def test_cli_receives_delay_arg(self):
@@ -878,7 +848,7 @@ def _get_dag_test_max_active_limits(self, dag_id, max_active_runs=1):
dag_id=dag_id,
start_date=DEFAULT_DATE,
schedule_interval="@hourly",
- max_active_runs=max_active_runs
+ max_active_runs=max_active_runs,
)
with dag:
@@ -895,18 +865,16 @@ def _get_dag_test_max_active_limits(self, dag_id, max_active_runs=1):
def test_backfill_max_limit_check_within_limit(self):
dag = self._get_dag_test_max_active_limits(
- 'test_backfill_max_limit_check_within_limit',
- max_active_runs=16)
+ 'test_backfill_max_limit_check_within_limit', max_active_runs=16
+ )
start_date = DEFAULT_DATE - datetime.timedelta(hours=1)
end_date = DEFAULT_DATE
executor = MockExecutor()
- job = BackfillJob(dag=dag,
- start_date=start_date,
- end_date=end_date,
- executor=executor,
- donot_pickle=True)
+ job = BackfillJob(
+ dag=dag, start_date=start_date, end_date=end_date, executor=executor, donot_pickle=True
+ )
job.run()
dagruns = DagRun.find(dag_id=dag.dag_id)
@@ -943,16 +911,14 @@ def run_backfill(cond):
thread_session.close()
executor = MockExecutor()
- job = BackfillJob(dag=dag,
- start_date=start_date,
- end_date=end_date,
- executor=executor,
- donot_pickle=True)
+ job = BackfillJob(
+ dag=dag, start_date=start_date, end_date=end_date, executor=executor, donot_pickle=True
+ )
job.run()
- backfill_job_thread = threading.Thread(target=run_backfill,
- name="run_backfill",
- args=(dag_run_created_cond,))
+ backfill_job_thread = threading.Thread(
+ target=run_backfill, name="run_backfill", args=(dag_run_created_cond,)
+ )
dag_run_created_cond.acquire()
with create_session() as session:
@@ -982,23 +948,22 @@ def run_backfill(cond):
dag_run_created_cond.release()
def test_backfill_max_limit_check_no_count_existing(self):
- dag = self._get_dag_test_max_active_limits(
- 'test_backfill_max_limit_check_no_count_existing')
+ dag = self._get_dag_test_max_active_limits('test_backfill_max_limit_check_no_count_existing')
start_date = DEFAULT_DATE
end_date = DEFAULT_DATE
# Existing dagrun that is within the backfill range
- dag.create_dagrun(run_id="test_existing_backfill",
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE)
+ dag.create_dagrun(
+ run_id="test_existing_backfill",
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ )
executor = MockExecutor()
- job = BackfillJob(dag=dag,
- start_date=start_date,
- end_date=end_date,
- executor=executor,
- donot_pickle=True)
+ job = BackfillJob(
+ dag=dag, start_date=start_date, end_date=end_date, executor=executor, donot_pickle=True
+ )
job.run()
# BackfillJob will run since the existing DagRun does not count for the max
@@ -1010,8 +975,7 @@ def test_backfill_max_limit_check_no_count_existing(self):
self.assertEqual(State.SUCCESS, dagruns[0].state)
def test_backfill_max_limit_check_complete_loop(self):
- dag = self._get_dag_test_max_active_limits(
- 'test_backfill_max_limit_check_complete_loop')
+ dag = self._get_dag_test_max_active_limits('test_backfill_max_limit_check_complete_loop')
start_date = DEFAULT_DATE - datetime.timedelta(hours=1)
end_date = DEFAULT_DATE
@@ -1019,11 +983,9 @@ def test_backfill_max_limit_check_complete_loop(self):
# backfill job 3 times
success_expected = 2
executor = MockExecutor()
- job = BackfillJob(dag=dag,
- start_date=start_date,
- end_date=end_date,
- executor=executor,
- donot_pickle=True)
+ job = BackfillJob(
+ dag=dag, start_date=start_date, end_date=end_date, executor=executor, donot_pickle=True
+ )
job.run()
success_dagruns = len(DagRun.find(dag_id=dag.dag_id, state=State.SUCCESS))
@@ -1032,10 +994,7 @@ def test_backfill_max_limit_check_complete_loop(self):
self.assertEqual(0, running_dagruns) # no dag_runs in running state are left
def test_sub_set_subdag(self):
- dag = DAG(
- 'test_sub_set_subdag',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('test_sub_set_subdag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
with dag:
op1 = DummyOperator(task_id='leave1')
@@ -1050,19 +1009,13 @@ def test_sub_set_subdag(self):
op3.set_downstream(op4)
dag.clear()
- dr = dag.create_dagrun(run_id="test",
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE)
+ dr = dag.create_dagrun(
+ run_id="test", state=State.RUNNING, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE
+ )
executor = MockExecutor()
- sub_dag = dag.sub_dag(task_ids_or_regex="leave*",
- include_downstream=False,
- include_upstream=False)
- job = BackfillJob(dag=sub_dag,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- executor=executor)
+ sub_dag = dag.sub_dag(task_ids_or_regex="leave*", include_downstream=False, include_upstream=False)
+ job = BackfillJob(dag=sub_dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor)
job.run()
self.assertRaises(sqlalchemy.orm.exc.NoResultFound, dr.refresh_from_db)
@@ -1093,10 +1046,9 @@ def test_backfill_fill_blanks(self):
op6 = DummyOperator(task_id='op6')
dag.clear()
- dr = dag.create_dagrun(run_id='test',
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE)
+ dr = dag.create_dagrun(
+ run_id='test', state=State.RUNNING, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE
+ )
executor = MockExecutor()
session = settings.Session()
@@ -1119,14 +1071,8 @@ def test_backfill_fill_blanks(self):
session.commit()
session.close()
- job = BackfillJob(dag=dag,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- executor=executor)
- self.assertRaisesRegex(
- AirflowException,
- 'Some task instances failed',
- job.run)
+ job = BackfillJob(dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor)
+ self.assertRaisesRegex(AirflowException, 'Some task instances failed', job.run)
self.assertRaises(sqlalchemy.orm.exc.NoResultFound, dr.refresh_from_db)
# the run_id should have changed, so a refresh won't work
@@ -1155,11 +1101,9 @@ def test_backfill_execute_subdag(self):
start_date = timezone.utcnow()
executor = MockExecutor()
- job = BackfillJob(dag=subdag,
- start_date=start_date,
- end_date=start_date,
- executor=executor,
- donot_pickle=True)
+ job = BackfillJob(
+ dag=subdag, start_date=start_date, end_date=start_date, executor=executor, donot_pickle=True
+ )
job.run()
subdag_op_task.pre_execute(context={'execution_date': start_date})
@@ -1177,8 +1121,7 @@ def test_backfill_execute_subdag(self):
with create_session() as session:
successful_subdag_runs = (
- session
- .query(DagRun)
+ session.query(DagRun)
.filter(DagRun.dag_id == subdag.dag_id)
.filter(DagRun.execution_date == start_date)
# pylint: disable=comparison-with-callable
@@ -1198,42 +1141,30 @@ def test_subdag_clear_parentdag_downstream_clear(self):
subdag = subdag_op_task.subdag
executor = MockExecutor()
- job = BackfillJob(dag=dag,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- executor=executor,
- donot_pickle=True)
+ job = BackfillJob(
+ dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor, donot_pickle=True
+ )
with timeout(seconds=30):
job.run()
- ti_subdag = TI(
- task=dag.get_task('daily_job'),
- execution_date=DEFAULT_DATE)
+ ti_subdag = TI(task=dag.get_task('daily_job'), execution_date=DEFAULT_DATE)
ti_subdag.refresh_from_db()
self.assertEqual(ti_subdag.state, State.SUCCESS)
- ti_irrelevant = TI(
- task=dag.get_task('daily_job_irrelevant'),
- execution_date=DEFAULT_DATE)
+ ti_irrelevant = TI(task=dag.get_task('daily_job_irrelevant'), execution_date=DEFAULT_DATE)
ti_irrelevant.refresh_from_db()
self.assertEqual(ti_irrelevant.state, State.SUCCESS)
- ti_downstream = TI(
- task=dag.get_task('daily_job_downstream'),
- execution_date=DEFAULT_DATE)
+ ti_downstream = TI(task=dag.get_task('daily_job_downstream'), execution_date=DEFAULT_DATE)
ti_downstream.refresh_from_db()
self.assertEqual(ti_downstream.state, State.SUCCESS)
sdag = subdag.sub_dag(
- task_ids_or_regex='daily_job_subdag_task',
- include_downstream=True,
- include_upstream=False)
+ task_ids_or_regex='daily_job_subdag_task', include_downstream=True, include_upstream=False
+ )
- sdag.clear(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- include_parentdag=True)
+ sdag.clear(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, include_parentdag=True)
ti_subdag.refresh_from_db()
self.assertEqual(State.NONE, ti_subdag.state)
@@ -1257,16 +1188,13 @@ def test_backfill_execute_subdag_with_removed_task(self):
subdag = dag.get_task('section-1').subdag
executor = MockExecutor()
- job = BackfillJob(dag=subdag,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- executor=executor,
- donot_pickle=True)
+ job = BackfillJob(
+ dag=subdag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor, donot_pickle=True
+ )
removed_task_ti = TI(
- task=DummyOperator(task_id='removed_task'),
- execution_date=DEFAULT_DATE,
- state=State.REMOVED)
+ task=DummyOperator(task_id='removed_task'), execution_date=DEFAULT_DATE, state=State.REMOVED
+ )
removed_task_ti.dag_id = subdag.dag_id
session = settings.Session()
@@ -1277,10 +1205,13 @@ def test_backfill_execute_subdag_with_removed_task(self):
job.run()
for task in subdag.tasks:
- instance = session.query(TI).filter(
- TI.dag_id == subdag.dag_id,
- TI.task_id == task.task_id,
- TI.execution_date == DEFAULT_DATE).first()
+ instance = (
+ session.query(TI)
+ .filter(
+ TI.dag_id == subdag.dag_id, TI.task_id == task.task_id, TI.execution_date == DEFAULT_DATE
+ )
+ .first()
+ )
self.assertIsNotNone(instance)
self.assertEqual(instance.state, State.SUCCESS)
@@ -1292,23 +1223,20 @@ def test_backfill_execute_subdag_with_removed_task(self):
dag.clear()
def test_update_counters(self):
- dag = DAG(
- dag_id='test_manage_executor_state',
- start_date=DEFAULT_DATE)
+ dag = DAG(dag_id='test_manage_executor_state', start_date=DEFAULT_DATE)
- task1 = DummyOperator(
- task_id='dummy',
- dag=dag,
- owner='airflow')
+ task1 = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
job = BackfillJob(dag=dag)
session = settings.Session()
- dr = dag.create_dagrun(run_type=DagRunType.SCHEDULED,
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
+ dr = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
ti = TI(task1, dr.execution_date)
ti.refresh_from_db()
@@ -1389,12 +1317,8 @@ def test_update_counters(self):
session.close()
def test_dag_get_run_dates(self):
-
def get_test_dag_for_backfill(schedule_interval=None):
- dag = DAG(
- dag_id='test_get_dates',
- start_date=DEFAULT_DATE,
- schedule_interval=schedule_interval)
+ dag = DAG(dag_id='test_get_dates', start_date=DEFAULT_DATE, schedule_interval=schedule_interval)
DummyOperator(
task_id='dummy',
dag=dag,
@@ -1403,18 +1327,23 @@ def get_test_dag_for_backfill(schedule_interval=None):
return dag
test_dag = get_test_dag_for_backfill()
- self.assertEqual([DEFAULT_DATE], test_dag.get_run_dates(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE))
+ self.assertEqual(
+ [DEFAULT_DATE], test_dag.get_run_dates(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ )
test_dag = get_test_dag_for_backfill(schedule_interval="@hourly")
- self.assertEqual([DEFAULT_DATE - datetime.timedelta(hours=3),
- DEFAULT_DATE - datetime.timedelta(hours=2),
- DEFAULT_DATE - datetime.timedelta(hours=1),
- DEFAULT_DATE],
- test_dag.get_run_dates(
- start_date=DEFAULT_DATE - datetime.timedelta(hours=3),
- end_date=DEFAULT_DATE, ))
+ self.assertEqual(
+ [
+ DEFAULT_DATE - datetime.timedelta(hours=3),
+ DEFAULT_DATE - datetime.timedelta(hours=2),
+ DEFAULT_DATE - datetime.timedelta(hours=1),
+ DEFAULT_DATE,
+ ],
+ test_dag.get_run_dates(
+ start_date=DEFAULT_DATE - datetime.timedelta(hours=3),
+ end_date=DEFAULT_DATE,
+ ),
+ )
def test_backfill_run_backwards(self):
dag = self.dagbag.get_dag("test_start_date_scheduling")
@@ -1427,14 +1356,17 @@ def test_backfill_run_backwards(self):
dag=dag,
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
- run_backwards=True
+ run_backwards=True,
)
job.run()
session = settings.Session()
- tis = session.query(TI).filter(
- TI.dag_id == 'test_start_date_scheduling' and TI.task_id == 'dummy'
- ).order_by(TI.execution_date).all()
+ tis = (
+ session.query(TI)
+ .filter(TI.dag_id == 'test_start_date_scheduling' and TI.task_id == 'dummy')
+ .order_by(TI.execution_date)
+ .all()
+ )
queued_times = [ti.queued_dttm for ti in tis]
self.assertTrue(queued_times == sorted(queued_times, reverse=True))
@@ -1449,9 +1381,7 @@ def test_reset_orphaned_tasks_with_orphans(self):
states = [State.QUEUED, State.SCHEDULED, State.NONE, State.RUNNING, State.SUCCESS]
states_to_reset = [State.QUEUED, State.SCHEDULED, State.NONE]
- dag = DAG(dag_id=prefix,
- start_date=DEFAULT_DATE,
- schedule_interval="@daily")
+ dag = DAG(dag_id=prefix, start_date=DEFAULT_DATE, schedule_interval="@daily")
tasks = []
for i in range(len(states)):
task_id = f"{prefix}_task_{i}"
@@ -1543,9 +1473,7 @@ def test_job_id_is_assigned_to_dag_run(self):
DummyOperator(task_id="dummy_task", dag=dag)
job = BackfillJob(
- dag=dag,
- executor=MockExecutor(),
- start_date=datetime.datetime.now() - datetime.timedelta(days=1)
+ dag=dag, executor=MockExecutor(), start_date=datetime.datetime.now() - datetime.timedelta(days=1)
)
job.run()
dr: DagRun = dag.get_last_dagrun()
diff --git a/tests/jobs/test_base_job.py b/tests/jobs/test_base_job.py
index 2d1496cb6e7a6..3ffe4adbee7b0 100644
--- a/tests/jobs/test_base_job.py
+++ b/tests/jobs/test_base_job.py
@@ -32,9 +32,7 @@
class MockJob(BaseJob):
- __mapper_args__ = {
- 'polymorphic_identity': 'MockJob'
- }
+ __mapper_args__ = {'polymorphic_identity': 'MockJob'}
def __init__(self, func, **kwargs):
self.func = func
@@ -54,6 +52,7 @@ def test_state_success(self):
def test_state_sysexit(self):
import sys
+
job = MockJob(lambda: sys.exit(0))
job.run()
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index f81027f724ba9..bcdf1fb5d3ce1 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -67,23 +67,19 @@ def test_localtaskjob_essential_attr(self):
proper values without intervention
"""
dag = DAG(
- 'test_localtaskjob_essential_attr',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ 'test_localtaskjob_essential_attr', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}
+ )
with dag:
op1 = DummyOperator(task_id='op1')
dag.clear()
- dr = dag.create_dagrun(run_id="test",
- state=State.SUCCESS,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE)
+ dr = dag.create_dagrun(
+ run_id="test", state=State.SUCCESS, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE
+ )
ti = dr.get_task_instance(task_id=op1.task_id)
- job1 = LocalTaskJob(task_instance=ti,
- ignore_ti_state=True,
- executor=SequentialExecutor())
+ job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
essential_attr = ["dag_id", "job_type", "start_date", "hostname"]
@@ -96,28 +92,25 @@ def test_localtaskjob_essential_attr(self):
@patch('os.getpid')
def test_localtaskjob_heartbeat(self, mock_pid):
session = settings.Session()
- dag = DAG(
- 'test_localtaskjob_heartbeat',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
with dag:
op1 = DummyOperator(task_id='op1')
dag.clear()
- dr = dag.create_dagrun(run_id="test",
- state=State.SUCCESS,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
+ dr = dag.create_dagrun(
+ run_id="test",
+ state=State.SUCCESS,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
ti = dr.get_task_instance(task_id=op1.task_id, session=session)
ti.state = State.RUNNING
ti.hostname = "blablabla"
session.commit()
- job1 = LocalTaskJob(task_instance=ti,
- ignore_ti_state=True,
- executor=SequentialExecutor())
+ job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
self.assertRaises(AirflowException, job1.heartbeat_callback)
mock_pid.return_value = 1
@@ -148,11 +141,13 @@ def test_heartbeat_failed_fast(self):
dag = dagbag.get_dag(dag_id)
task = dag.get_task(task_id)
- dag.create_dagrun(run_id="test_heartbeat_failed_fast_run",
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
+ dag.create_dagrun(
+ run_id="test_heartbeat_failed_fast_run",
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
ti.state = State.RUNNING
@@ -189,11 +184,13 @@ def test_mark_success_no_kill(self):
session = settings.Session()
dag.clear()
- dag.create_dagrun(run_id="test",
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
+ dag.create_dagrun(
+ run_id="test",
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
@@ -226,11 +223,13 @@ def test_localtaskjob_double_trigger(self):
session = settings.Session()
dag.clear()
- dr = dag.create_dagrun(run_id="test",
- state=State.SUCCESS,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
+ dr = dag.create_dagrun(
+ run_id="test",
+ state=State.SUCCESS,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
ti = dr.get_task_instance(task_id=task.task_id, session=session)
ti.state = State.RUNNING
ti.hostname = get_hostname()
@@ -240,9 +239,9 @@ def test_localtaskjob_double_trigger(self):
ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti_run.refresh_from_db()
- job1 = LocalTaskJob(task_instance=ti_run,
- executor=SequentialExecutor())
+ job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor())
from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
+
with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_method:
job1.run()
mock_method.assert_not_called()
@@ -265,16 +264,17 @@ def test_localtaskjob_maintain_heart_rate(self):
session = settings.Session()
dag.clear()
- dag.create_dagrun(run_id="test",
- state=State.SUCCESS,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
+ dag.create_dagrun(
+ run_id="test",
+ state=State.SUCCESS,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti_run.refresh_from_db()
- job1 = LocalTaskJob(task_instance=ti_run,
- executor=SequentialExecutor())
+ job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor())
# this should make sure we only heartbeat once and exit at the second
# loop in _execute()
@@ -285,6 +285,7 @@ def multi_return_code():
time_start = time.time()
from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
+
with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_start:
with patch.object(StandardTaskRunner, 'return_code') as mock_ret_code:
mock_ret_code.side_effect = multi_return_code
@@ -331,22 +332,23 @@ def task_function(ti):
task = PythonOperator(
task_id='test_state_succeeded1',
python_callable=task_function,
- on_failure_callback=check_failure)
+ on_failure_callback=check_failure,
+ )
session = settings.Session()
dag.clear()
- dag.create_dagrun(run_id="test",
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
+ dag.create_dagrun(
+ run_id="test",
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
- job1 = LocalTaskJob(task_instance=ti,
- ignore_ti_state=True,
- executor=SequentialExecutor())
+ job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
with timeout(30):
# This should be _much_ shorter to run.
            # If you change this limit, make the timeout in the callable above bigger
@@ -355,8 +357,9 @@ def task_function(ti):
ti.refresh_from_db()
self.assertEqual(ti.state, State.FAILED)
self.assertTrue(data['called'])
- self.assertNotIn('reached_end_of_sleep', data,
- 'Task should not have been allowed to run to completion')
+ self.assertNotIn(
+ 'reached_end_of_sleep', data, 'Task should not have been allowed to run to completion'
+ )
@pytest.mark.quarantined
def test_mark_success_on_success_callback(self):
@@ -367,33 +370,28 @@ def test_mark_success_on_success_callback(self):
data = {'called': False}
def success_callback(context):
- self.assertEqual(context['dag_run'].dag_id,
- 'test_mark_success')
+ self.assertEqual(context['dag_run'].dag_id, 'test_mark_success')
data['called'] = True
- dag = DAG(dag_id='test_mark_success',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
- task = DummyOperator(
- task_id='test_state_succeeded1',
- dag=dag,
- on_success_callback=success_callback)
+ task = DummyOperator(task_id='test_state_succeeded1', dag=dag, on_success_callback=success_callback)
session = settings.Session()
dag.clear()
- dag.create_dagrun(run_id="test",
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
+ dag.create_dagrun(
+ run_id="test",
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
- job1 = LocalTaskJob(task_instance=ti,
- ignore_ti_state=True,
- executor=SequentialExecutor())
+ job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
+
job1.task_runner = StandardTaskRunner(job1)
process = multiprocessing.Process(target=job1.run)
process.start()
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index a7589ef57dca6..5dd1bec6785b7 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -58,8 +58,14 @@
from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.config import conf_vars, env_vars
from tests.test_utils.db import (
- clear_db_dags, clear_db_errors, clear_db_jobs, clear_db_pools, clear_db_runs, clear_db_serialized_dags,
- clear_db_sla_miss, set_default_pool_slots,
+ clear_db_dags,
+ clear_db_errors,
+ clear_db_jobs,
+ clear_db_pools,
+ clear_db_runs,
+ clear_db_serialized_dags,
+ clear_db_sla_miss,
+ set_default_pool_slots,
)
from tests.test_utils.mock_executor import MockExecutor
@@ -77,11 +83,7 @@
# files contain a DAG (otherwise Airflow will skip them)
PARSEABLE_DAG_FILE_CONTENTS = '"airflow DAG"'
UNPARSEABLE_DAG_FILE_CONTENTS = 'airflow DAG'
-INVALID_DAG_WITH_DEPTH_FILE_CONTENTS = (
- "def something():\n"
- " return airflow_DAG\n"
- "something()"
-)
+INVALID_DAG_WITH_DEPTH_FILE_CONTENTS = "def something():\n return airflow_DAG\nsomething()"
# Filename to be used for dags that are created in an ad-hoc manner and can be removed/
# created at runtime
@@ -97,7 +99,6 @@ def disable_load_example():
@pytest.mark.usefixtures("disable_load_example")
class TestDagFileProcessor(unittest.TestCase):
-
@staticmethod
def clean_db():
clear_db_runs()
@@ -123,7 +124,8 @@ def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timed
dag_id='test_scheduler_reschedule',
start_date=start_date,
# Make sure it only creates a single DAG Run
- end_date=end_date)
+ end_date=end_date,
+ )
dag.clear()
dag.is_subdag = False
with create_session() as session:
@@ -150,14 +152,13 @@ def test_dag_file_processor_sla_miss_callback(self):
# Create dag with a start of 1 day ago, but an sla of 0
# so we'll already have an sla_miss on the books.
test_start_date = days_ago(1)
- dag = DAG(dag_id='test_sla_miss',
- sla_miss_callback=sla_callback,
- default_args={'start_date': test_start_date,
- 'sla': datetime.timedelta()})
+ dag = DAG(
+ dag_id='test_sla_miss',
+ sla_miss_callback=sla_callback,
+ default_args={'start_date': test_start_date, 'sla': datetime.timedelta()},
+ )
- task = DummyOperator(task_id='dummy',
- dag=dag,
- owner='airflow')
+ task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
@@ -181,10 +182,11 @@ def test_dag_file_processor_sla_miss_callback_invalid_sla(self):
# so we'll already have an sla_miss on the books.
# Pass anything besides a timedelta object to the sla argument.
test_start_date = days_ago(1)
- dag = DAG(dag_id='test_sla_miss',
- sla_miss_callback=sla_callback,
- default_args={'start_date': test_start_date,
- 'sla': None})
+ dag = DAG(
+ dag_id='test_sla_miss',
+ sla_miss_callback=sla_callback,
+ default_args={'start_date': test_start_date, 'sla': None},
+ )
task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
@@ -209,10 +211,11 @@ def test_dag_file_processor_sla_miss_callback_sent_notification(self):
# Create dag with a start of 2 days ago, but an sla of 1 day
# ago so we'll already have an sla_miss on the books
test_start_date = days_ago(2)
- dag = DAG(dag_id='test_sla_miss',
- sla_miss_callback=sla_callback,
- default_args={'start_date': test_start_date,
- 'sla': datetime.timedelta(days=1)})
+ dag = DAG(
+ dag_id='test_sla_miss',
+ sla_miss_callback=sla_callback,
+ default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+ )
task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
@@ -220,11 +223,15 @@ def test_dag_file_processor_sla_miss_callback_sent_notification(self):
session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
# Create an SlaMiss where notification was sent, but email was not
- session.merge(SlaMiss(task_id='dummy',
- dag_id='test_sla_miss',
- execution_date=test_start_date,
- email_sent=False,
- notification_sent=True))
+ session.merge(
+ SlaMiss(
+ task_id='dummy',
+ dag_id='test_sla_miss',
+ execution_date=test_start_date,
+ email_sent=False,
+ notification_sent=True,
+ )
+ )
# Now call manage_slas and see if the sla_miss callback gets called
dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
@@ -242,21 +249,18 @@ def test_dag_file_processor_sla_miss_callback_exception(self):
sla_callback = MagicMock(side_effect=RuntimeError('Could not call function'))
test_start_date = days_ago(2)
- dag = DAG(dag_id='test_sla_miss',
- sla_miss_callback=sla_callback,
- default_args={'start_date': test_start_date})
+ dag = DAG(
+ dag_id='test_sla_miss',
+ sla_miss_callback=sla_callback,
+ default_args={'start_date': test_start_date},
+ )
- task = DummyOperator(task_id='dummy',
- dag=dag,
- owner='airflow',
- sla=datetime.timedelta(hours=1))
+ task = DummyOperator(task_id='dummy', dag=dag, owner='airflow', sla=datetime.timedelta(hours=1))
session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
# Create an SlaMiss where notification was sent, but email was not
- session.merge(SlaMiss(task_id='dummy',
- dag_id='test_sla_miss',
- execution_date=test_start_date))
+ session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
# Now call manage_slas and see if the sla_miss callback gets called
mock_log = mock.MagicMock()
@@ -264,32 +268,28 @@ def test_dag_file_processor_sla_miss_callback_exception(self):
dag_file_processor.manage_slas(dag=dag, session=session)
assert sla_callback.called
mock_log.exception.assert_called_once_with(
- 'Could not call sla_miss_callback for DAG %s',
- 'test_sla_miss')
+ 'Could not call sla_miss_callback for DAG %s', 'test_sla_miss'
+ )
@mock.patch('airflow.jobs.scheduler_job.send_email')
def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock_send_email):
session = settings.Session()
test_start_date = days_ago(2)
- dag = DAG(dag_id='test_sla_miss',
- default_args={'start_date': test_start_date,
- 'sla': datetime.timedelta(days=1)})
+ dag = DAG(
+ dag_id='test_sla_miss',
+ default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+ )
email1 = 'test1@test.com'
- task = DummyOperator(task_id='sla_missed',
- dag=dag,
- owner='airflow',
- email=email1,
- sla=datetime.timedelta(hours=1))
+ task = DummyOperator(
+ task_id='sla_missed', dag=dag, owner='airflow', email=email1, sla=datetime.timedelta(hours=1)
+ )
session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
email2 = 'test2@test.com'
- DummyOperator(task_id='sla_not_missed',
- dag=dag,
- owner='airflow',
- email=email2)
+ DummyOperator(task_id='sla_not_missed', dag=dag, owner='airflow', email=email2)
session.merge(SlaMiss(task_id='sla_missed', dag_id='test_sla_miss', execution_date=test_start_date))
@@ -316,15 +316,14 @@ def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock
mock_send_email.side_effect = RuntimeError('Could not send an email')
test_start_date = days_ago(2)
- dag = DAG(dag_id='test_sla_miss',
- default_args={'start_date': test_start_date,
- 'sla': datetime.timedelta(days=1)})
+ dag = DAG(
+ dag_id='test_sla_miss',
+ default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+ )
- task = DummyOperator(task_id='dummy',
- dag=dag,
- owner='airflow',
- email='test@test.com',
- sla=datetime.timedelta(hours=1))
+ task = DummyOperator(
+ task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1)
+ )
session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
@@ -336,8 +335,8 @@ def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock
dag_file_processor.manage_slas(dag=dag, session=session)
mock_log.exception.assert_called_once_with(
- 'Could not send SLA Miss email notification for DAG %s',
- 'test_sla_miss')
+ 'Could not send SLA Miss email notification for DAG %s', 'test_sla_miss'
+ )
mock_stats_incr.assert_called_once_with('sla_email_notification_failure')
def test_dag_file_processor_sla_miss_deleted_task(self):
@@ -348,47 +347,48 @@ def test_dag_file_processor_sla_miss_deleted_task(self):
session = settings.Session()
test_start_date = days_ago(2)
- dag = DAG(dag_id='test_sla_miss',
- default_args={'start_date': test_start_date,
- 'sla': datetime.timedelta(days=1)})
+ dag = DAG(
+ dag_id='test_sla_miss',
+ default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+ )
- task = DummyOperator(task_id='dummy',
- dag=dag,
- owner='airflow',
- email='test@test.com',
- sla=datetime.timedelta(hours=1))
+ task = DummyOperator(
+ task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1)
+ )
session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
# Create an SlaMiss where notification was sent, but email was not
- session.merge(SlaMiss(task_id='dummy_deleted', dag_id='test_sla_miss',
- execution_date=test_start_date))
+ session.merge(
+ SlaMiss(task_id='dummy_deleted', dag_id='test_sla_miss', execution_date=test_start_date)
+ )
mock_log = mock.MagicMock()
dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
dag_file_processor.manage_slas(dag=dag, session=session)
- @parameterized.expand([
- [State.NONE, None, None],
- [State.UP_FOR_RETRY, timezone.utcnow() - datetime.timedelta(minutes=30),
- timezone.utcnow() - datetime.timedelta(minutes=15)],
- [State.UP_FOR_RESCHEDULE, timezone.utcnow() - datetime.timedelta(minutes=30),
- timezone.utcnow() - datetime.timedelta(minutes=15)],
- ])
+ @parameterized.expand(
+ [
+ [State.NONE, None, None],
+ [
+ State.UP_FOR_RETRY,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ [
+ State.UP_FOR_RESCHEDULE,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ ]
+ )
def test_dag_file_processor_process_task_instances(self, state, start_date, end_date):
"""
Test if _process_task_instances puts the right task instances into the
mock_list.
"""
- dag = DAG(
- dag_id='test_scheduler_process_execute_task',
- start_date=DEFAULT_DATE)
- BashOperator(
- task_id='dummy',
- dag=dag,
- owner='airflow',
- bash_command='echo hi'
- )
+ dag = DAG(dag_id='test_scheduler_process_execute_task', start_date=DEFAULT_DATE)
+ BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo hi')
with create_session() as session:
orm_dag = DagModel(dag_id=dag.dag_id)
@@ -419,30 +419,33 @@ def test_dag_file_processor_process_task_instances(self, state, start_date, end_
session.refresh(ti)
assert ti.state == State.SCHEDULED
- @parameterized.expand([
- [State.NONE, None, None],
- [State.UP_FOR_RETRY, timezone.utcnow() - datetime.timedelta(minutes=30),
- timezone.utcnow() - datetime.timedelta(minutes=15)],
- [State.UP_FOR_RESCHEDULE, timezone.utcnow() - datetime.timedelta(minutes=30),
- timezone.utcnow() - datetime.timedelta(minutes=15)],
- ])
+ @parameterized.expand(
+ [
+ [State.NONE, None, None],
+ [
+ State.UP_FOR_RETRY,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ [
+ State.UP_FOR_RESCHEDULE,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ ]
+ )
def test_dag_file_processor_process_task_instances_with_task_concurrency(
- self, state, start_date, end_date,
+ self,
+ state,
+ start_date,
+ end_date,
):
"""
Test if _process_task_instances puts the right task instances into the
mock_list.
"""
- dag = DAG(
- dag_id='test_scheduler_process_execute_task_with_task_concurrency',
- start_date=DEFAULT_DATE)
- BashOperator(
- task_id='dummy',
- task_concurrency=2,
- dag=dag,
- owner='airflow',
- bash_command='echo Hi'
- )
+ dag = DAG(dag_id='test_scheduler_process_execute_task_with_task_concurrency', start_date=DEFAULT_DATE)
+ BashOperator(task_id='dummy', task_concurrency=2, dag=dag, owner='airflow', bash_command='echo Hi')
with create_session() as session:
orm_dag = DagModel(dag_id=dag.dag_id)
@@ -473,13 +476,21 @@ def test_dag_file_processor_process_task_instances_with_task_concurrency(
session.refresh(ti)
assert ti.state == State.SCHEDULED
- @parameterized.expand([
- [State.NONE, None, None],
- [State.UP_FOR_RETRY, timezone.utcnow() - datetime.timedelta(minutes=30),
- timezone.utcnow() - datetime.timedelta(minutes=15)],
- [State.UP_FOR_RESCHEDULE, timezone.utcnow() - datetime.timedelta(minutes=30),
- timezone.utcnow() - datetime.timedelta(minutes=15)],
- ])
+ @parameterized.expand(
+ [
+ [State.NONE, None, None],
+ [
+ State.UP_FOR_RETRY,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ [
+ State.UP_FOR_RESCHEDULE,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ ]
+ )
def test_dag_file_processor_process_task_instances_depends_on_past(self, state, start_date, end_date):
"""
Test if _process_task_instances puts the right task instances into the
@@ -492,18 +503,8 @@ def test_dag_file_processor_process_task_instances_depends_on_past(self, state,
'depends_on_past': True,
},
)
- BashOperator(
- task_id='dummy1',
- dag=dag,
- owner='airflow',
- bash_command='echo hi'
- )
- BashOperator(
- task_id='dummy2',
- dag=dag,
- owner='airflow',
- bash_command='echo hi'
- )
+ BashOperator(task_id='dummy1', dag=dag, owner='airflow', bash_command='echo hi')
+ BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo hi')
with create_session() as session:
orm_dag = DagModel(dag_id=dag.dag_id)
@@ -586,17 +587,10 @@ def test_runs_respected_after_clear(self):
Test if _process_task_instances only schedules TIs up to max_active_runs
(related to issue AIRFLOW-137)
"""
- dag = DAG(
- dag_id='test_scheduler_max_active_runs_respected_after_clear',
- start_date=DEFAULT_DATE)
+ dag = DAG(dag_id='test_scheduler_max_active_runs_respected_after_clear', start_date=DEFAULT_DATE)
dag.max_active_runs = 3
- BashOperator(
- task_id='dummy',
- dag=dag,
- owner='airflow',
- bash_command='echo Hi'
- )
+ BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo Hi')
session = settings.Session()
orm_dag = DagModel(dag_id=dag.dag_id)
@@ -664,16 +658,12 @@ def test_execute_on_failure_callbacks(self, mock_ti_handle_failure):
requests = [
TaskCallbackRequest(
- full_filepath="A",
- simple_task_instance=SimpleTaskInstance(ti),
- msg="Message"
+ full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message"
)
]
dag_file_processor.execute_callbacks(dagbag, requests)
mock_ti_handle_failure.assert_called_once_with(
- "Message",
- conf.getboolean('core', 'unit_test_mode'),
- mock.ANY
+ "Message", conf.getboolean('core', 'unit_test_mode'), mock.ANY
)
def test_process_file_should_failure_callback(self):
@@ -696,7 +686,7 @@ def test_process_file_should_failure_callback(self):
TaskCallbackRequest(
full_filepath=dag.full_filepath,
simple_task_instance=SimpleTaskInstance(ti),
- msg="Message"
+ msg="Message",
)
]
callback_file.close()
@@ -738,15 +728,19 @@ def test_should_mark_dummy_task_as_success(self):
dags = scheduler_job.dagbag.dags.values()
self.assertEqual(['test_only_dummy_tasks'], [dag.dag_id for dag in dags])
self.assertEqual(5, len(tis))
- self.assertEqual({
- ('test_task_a', 'success'),
- ('test_task_b', None),
- ('test_task_c', 'success'),
- ('test_task_on_execute', 'scheduled'),
- ('test_task_on_success', 'scheduled'),
- }, {(ti.task_id, ti.state) for ti in tis})
- for state, start_date, end_date, duration in [(ti.state, ti.start_date, ti.end_date, ti.duration) for
- ti in tis]:
+ self.assertEqual(
+ {
+ ('test_task_a', 'success'),
+ ('test_task_b', None),
+ ('test_task_c', 'success'),
+ ('test_task_on_execute', 'scheduled'),
+ ('test_task_on_success', 'scheduled'),
+ },
+ {(ti.task_id, ti.state) for ti in tis},
+ )
+ for state, start_date, end_date, duration in [
+ (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis
+ ]:
if state == 'success':
self.assertIsNotNone(start_date)
self.assertIsNotNone(end_date)
@@ -761,15 +755,19 @@ def test_should_mark_dummy_task_as_success(self):
tis = session.query(TaskInstance).all()
self.assertEqual(5, len(tis))
- self.assertEqual({
- ('test_task_a', 'success'),
- ('test_task_b', 'success'),
- ('test_task_c', 'success'),
- ('test_task_on_execute', 'scheduled'),
- ('test_task_on_success', 'scheduled'),
- }, {(ti.task_id, ti.state) for ti in tis})
- for state, start_date, end_date, duration in [(ti.state, ti.start_date, ti.end_date, ti.duration) for
- ti in tis]:
+ self.assertEqual(
+ {
+ ('test_task_a', 'success'),
+ ('test_task_b', 'success'),
+ ('test_task_c', 'success'),
+ ('test_task_on_execute', 'scheduled'),
+ ('test_task_on_success', 'scheduled'),
+ },
+ {(ti.task_id, ti.state) for ti in tis},
+ )
+ for state, start_date, end_date, duration in [
+ (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis
+ ]:
if state == 'success':
self.assertIsNotNone(start_date)
self.assertIsNotNone(end_date)
@@ -782,7 +780,6 @@ def test_should_mark_dummy_task_as_success(self):
@pytest.mark.usefixtures("disable_load_example")
class TestSchedulerJob(unittest.TestCase):
-
def setUp(self):
clear_db_runs()
clear_db_pools()
@@ -841,9 +838,8 @@ def run_single_scheduler_loop_with_no_dags(self, dags_folder):
:type dags_folder: str
"""
scheduler = SchedulerJob(
- executor=self.null_exec,
- num_times_parse_dags=1,
- subdir=os.path.join(dags_folder))
+ executor=self.null_exec, num_times_parse_dags=1, subdir=os.path.join(dags_folder)
+ )
scheduler.heartrate = 0
scheduler.run()
@@ -851,15 +847,12 @@ def test_no_orphan_process_will_be_left(self):
empty_dir = mkdtemp()
current_process = psutil.Process()
old_children = current_process.children(recursive=True)
- scheduler = SchedulerJob(subdir=empty_dir,
- num_runs=1,
- executor=MockExecutor(do_update=False))
+ scheduler = SchedulerJob(subdir=empty_dir, num_runs=1, executor=MockExecutor(do_update=False))
scheduler.run()
shutil.rmtree(empty_dir)
# Remove potential noise created by previous tests.
- current_children = set(current_process.children(recursive=True)) - set(
- old_children)
+ current_children = set(current_process.children(recursive=True)) - set(old_children)
self.assertFalse(current_children)
@mock.patch('airflow.jobs.scheduler_job.TaskCallbackRequest')
@@ -900,9 +893,9 @@ def test_process_executor_events(self, mock_stats_incr, mock_task_callback):
full_filepath='/test_path1/',
simple_task_instance=mock.ANY,
msg='Executor reports task instance '
- ' '
- 'finished (failed) although the task says its queued. (Info: None) '
- 'Was the task killed externally?'
+ ' '
+ 'finished (failed) although the task says its queued. (Info: None) '
+ 'Was the task killed externally?',
)
scheduler.processor_agent.send_callback_to_execute.assert_called_once_with(task_callback)
scheduler.processor_agent.reset_mock()
@@ -929,9 +922,7 @@ def test_process_executor_events_uses_inmemory_try_number(self):
executor = MagicMock()
scheduler = SchedulerJob(executor=executor)
scheduler.processor_agent = MagicMock()
- event_buffer = {
- TaskInstanceKey(dag_id, task_id, execution_date, try_number): (State.SUCCESS, None)
- }
+ event_buffer = {TaskInstanceKey(dag_id, task_id, execution_date, try_number): (State.SUCCESS, None)}
executor.get_event_buffer.return_value = event_buffer
dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
@@ -1141,12 +1132,12 @@ def test_find_executable_task_instances_pool(self):
state=State.RUNNING,
)
- tis = ([
+ tis = [
TaskInstance(task1, dr1.execution_date),
TaskInstance(task2, dr1.execution_date),
TaskInstance(task1, dr2.execution_date),
- TaskInstance(task2, dr2.execution_date)
- ])
+ TaskInstance(task2, dr2.execution_date),
+ ]
for ti in tis:
ti.state = State.SCHEDULED
session.merge(ti)
@@ -1278,9 +1269,7 @@ def test_find_executable_task_instances_none(self):
)
session.flush()
- self.assertEqual(0, len(scheduler._executable_task_instances_to_queued(
- max_tis=32,
- session=session)))
+ self.assertEqual(0, len(scheduler._executable_task_instances_to_queued(max_tis=32, session=session)))
session.rollback()
def test_find_executable_task_instances_concurrency(self):
@@ -1600,10 +1589,7 @@ def test_critical_section_execute_task_instances(self):
self.assertEqual(State.RUNNING, dr1.state)
self.assertEqual(
- 2,
- DAG.get_num_task_instances(
- dag_id, dag.task_ids, states=[State.RUNNING], session=session
- )
+ 2, DAG.get_num_task_instances(dag_id, dag.task_ids, states=[State.RUNNING], session=session)
)
# create second dag run
@@ -1636,7 +1622,7 @@ def test_critical_section_execute_task_instances(self):
3,
DAG.get_num_task_instances(
dag_id, dag.task_ids, states=[State.RUNNING, State.QUEUED], session=session
- )
+ ),
)
self.assertEqual(State.RUNNING, ti1.state)
self.assertEqual(State.RUNNING, ti2.state)
@@ -1691,9 +1677,9 @@ def test_execute_task_instances_limit(self):
self.assertEqual(2, res)
scheduler.max_tis_per_query = 8
- with mock.patch.object(type(scheduler.executor),
- 'slots_available',
- new_callable=mock.PropertyMock) as mock_slots:
+ with mock.patch.object(
+ type(scheduler.executor), 'slots_available', new_callable=mock.PropertyMock
+ ) as mock_slots:
mock_slots.return_value = 2
# Check that we don't "overfill" the executor
self.assertEqual(2, res)
@@ -1722,17 +1708,21 @@ def test_change_state_for_tis_without_dagrun(self):
DummyOperator(task_id='dummy', dag=dag3, owner='airflow')
session = settings.Session()
- dr1 = dag1.create_dagrun(run_type=DagRunType.SCHEDULED,
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
-
- dr2 = dag2.create_dagrun(run_type=DagRunType.SCHEDULED,
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
+ dr1 = dag1.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
+
+ dr2 = dag2.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
ti1a = dr1.get_task_instance(task_id='dummy', session=session)
ti1a.state = State.SCHEDULED
@@ -1760,9 +1750,8 @@ def test_change_state_for_tis_without_dagrun(self):
scheduler.dagbag.collect_dags_from_db()
scheduler._change_state_for_tis_without_dagrun(
- old_states=[State.SCHEDULED, State.QUEUED],
- new_state=State.NONE,
- session=session)
+ old_states=[State.SCHEDULED, State.QUEUED], new_state=State.NONE, session=session
+ )
ti1a = dr1.get_task_instance(task_id='dummy', session=session)
ti1a.refresh_from_db(session=session)
@@ -1790,9 +1779,8 @@ def test_change_state_for_tis_without_dagrun(self):
session.commit()
scheduler._change_state_for_tis_without_dagrun(
- old_states=[State.SCHEDULED, State.QUEUED],
- new_state=State.NONE,
- session=session)
+ old_states=[State.SCHEDULED, State.QUEUED], new_state=State.NONE, session=session
+ )
ti1a.refresh_from_db(session=session)
self.assertEqual(ti1a.state, State.SCHEDULED)
@@ -1805,14 +1793,9 @@ def test_change_state_for_tis_without_dagrun(self):
self.assertEqual(ti2.state, State.SCHEDULED)
def test_change_state_for_tasks_failed_to_execute(self):
- dag = DAG(
- dag_id='dag_id',
- start_date=DEFAULT_DATE)
+ dag = DAG(dag_id='dag_id', start_date=DEFAULT_DATE)
- task = DummyOperator(
- task_id='task_id',
- dag=dag,
- owner='airflow')
+ task = DummyOperator(task_id='task_id', dag=dag, owner='airflow')
dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
# If there's no left over task in executor.queued_tasks, nothing happens
@@ -1858,23 +1841,28 @@ def test_adopt_or_reset_orphaned_tasks(self):
dag = DAG(
'test_execute_helper_reset_orphaned_tasks',
start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ default_args={'owner': 'owner1'},
+ )
with dag:
op1 = DummyOperator(task_id='op1')
dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
dag.clear()
- dr = dag.create_dagrun(run_type=DagRunType.SCHEDULED,
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
- dr2 = dag.create_dagrun(run_type=DagRunType.BACKFILL_JOB,
- state=State.RUNNING,
- execution_date=DEFAULT_DATE + datetime.timedelta(1),
- start_date=DEFAULT_DATE,
- session=session)
+ dr = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
+ dr2 = dag.create_dagrun(
+ run_type=DagRunType.BACKFILL_JOB,
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE + datetime.timedelta(1),
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
ti = dr.get_task_instance(task_id=op1.task_id, session=session)
ti.state = State.SCHEDULED
ti2 = dr2.get_task_instance(task_id=op1.task_id, session=session)
@@ -1894,15 +1882,17 @@ def test_adopt_or_reset_orphaned_tasks(self):
ti2 = dr2.get_task_instance(task_id=op1.task_id, session=session)
self.assertEqual(ti2.state, State.SCHEDULED, "Tasks run by Backfill Jobs should not be reset")
- @parameterized.expand([
- [State.UP_FOR_RETRY, State.FAILED],
- [State.QUEUED, State.NONE],
- [State.SCHEDULED, State.NONE],
- [State.UP_FOR_RESCHEDULE, State.NONE],
- ])
- def test_scheduler_loop_should_change_state_for_tis_without_dagrun(self,
- initial_task_state,
- expected_task_state):
+ @parameterized.expand(
+ [
+ [State.UP_FOR_RETRY, State.FAILED],
+ [State.QUEUED, State.NONE],
+ [State.SCHEDULED, State.NONE],
+ [State.UP_FOR_RESCHEDULE, State.NONE],
+ ]
+ )
+ def test_scheduler_loop_should_change_state_for_tis_without_dagrun(
+ self, initial_task_state, expected_task_state
+ ):
session = settings.Session()
dag_id = 'test_execute_helper_should_change_state_for_tis_without_dagrun'
dag = DAG(dag_id, start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
@@ -1918,11 +1908,13 @@ def test_scheduler_loop_should_change_state_for_tis_without_dagrun(self,
dag = DagBag(read_dags_from_db=True, include_examples=False).get_dag(dag_id)
# Create DAG run with FAILED state
dag.clear()
- dr = dag.create_dagrun(run_type=DagRunType.SCHEDULED,
- state=State.FAILED,
- execution_date=DEFAULT_DATE + timedelta(days=1),
- start_date=DEFAULT_DATE + timedelta(days=1),
- session=session)
+ dr = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ state=State.FAILED,
+ execution_date=DEFAULT_DATE + timedelta(days=1),
+ start_date=DEFAULT_DATE + timedelta(days=1),
+ session=session,
+ )
ti = dr.get_task_instance(task_id=op1.task_id, session=session)
ti.state = initial_task_state
session.commit()
@@ -1956,16 +1948,11 @@ def test_dagrun_timeout_verify_max_active_runs(self):
Test if a dagrun would be scheduled if max_dag_runs has
been reached but dagrun_timeout is also reached
"""
- dag = DAG(
- dag_id='test_scheduler_verify_max_active_runs_and_dagrun_timeout',
- start_date=DEFAULT_DATE)
+ dag = DAG(dag_id='test_scheduler_verify_max_active_runs_and_dagrun_timeout', start_date=DEFAULT_DATE)
dag.max_active_runs = 1
dag.dagrun_timeout = datetime.timedelta(seconds=60)
- DummyOperator(
- task_id='dummy',
- dag=dag,
- owner='airflow')
+ DummyOperator(task_id='dummy', dag=dag, owner='airflow')
scheduler = SchedulerJob()
scheduler.dagbag.bag_dag(dag, root_dag=dag)
@@ -2012,7 +1999,7 @@ def test_dagrun_timeout_verify_max_active_runs(self):
dag_id=dr.dag_id,
is_failure_callback=True,
execution_date=dr.execution_date,
- msg="timed_out"
+ msg="timed_out",
)
# Verify dag failure callback request is sent to file processor
@@ -2025,15 +2012,10 @@ def test_dagrun_timeout_fails_run(self):
"""
Test if a dagrun will be set to failed on timeout, even without max_active_runs
"""
- dag = DAG(
- dag_id='test_scheduler_fail_dagrun_timeout',
- start_date=DEFAULT_DATE)
+ dag = DAG(dag_id='test_scheduler_fail_dagrun_timeout', start_date=DEFAULT_DATE)
dag.dagrun_timeout = datetime.timedelta(seconds=60)
- DummyOperator(
- task_id='dummy',
- dag=dag,
- owner='airflow')
+ DummyOperator(task_id='dummy', dag=dag, owner='airflow')
scheduler = SchedulerJob()
scheduler.dagbag.bag_dag(dag, root_dag=dag)
@@ -2071,7 +2053,7 @@ def test_dagrun_timeout_fails_run(self):
dag_id=dr.dag_id,
is_failure_callback=True,
execution_date=dr.execution_date,
- msg="timed_out"
+ msg="timed_out",
)
# Verify dag failure callback request is sent to file processor
@@ -2080,10 +2062,7 @@ def test_dagrun_timeout_fails_run(self):
session.rollback()
session.close()
- @parameterized.expand([
- (State.SUCCESS, "success"),
- (State.FAILED, "task_failure")
- ])
+ @parameterized.expand([(State.SUCCESS, "success"), (State.FAILED, "task_failure")])
def test_dagrun_callbacks_are_called(self, state, expected_callback_msg):
"""
Test that if a DagRun is successful and a success callback is defined, it is sent to the DagFileProcessor.
@@ -2093,7 +2072,7 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg):
dag_id='test_dagrun_callbacks_are_called',
start_date=DEFAULT_DATE,
on_success_callback=lambda x: print("success"),
- on_failure_callback=lambda x: print("failed")
+ on_failure_callback=lambda x: print("failed"),
)
DummyOperator(task_id='dummy', dag=dag, owner='airflow')
@@ -2129,7 +2108,7 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg):
dag_id=dr.dag_id,
is_failure_callback=bool(state == State.FAILED),
execution_date=dr.execution_date,
- msg=expected_callback_msg
+ msg=expected_callback_msg,
)
# Verify dag failure callback request is sent to file processor
@@ -2142,13 +2121,8 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg):
session.close()
def test_do_not_schedule_removed_task(self):
- dag = DAG(
- dag_id='test_scheduler_do_not_schedule_removed_task',
- start_date=DEFAULT_DATE)
- DummyOperator(
- task_id='dummy',
- dag=dag,
- owner='airflow')
+ dag = DAG(dag_id='test_scheduler_do_not_schedule_removed_task', start_date=DEFAULT_DATE)
+ DummyOperator(task_id='dummy', dag=dag, owner='airflow')
session = settings.Session()
dag.sync_to_db(session=session)
@@ -2165,9 +2139,7 @@ def test_do_not_schedule_removed_task(self):
self.assertIsNotNone(dr)
# Re-create the DAG, but remove the task
- dag = DAG(
- dag_id='test_scheduler_do_not_schedule_removed_task',
- start_date=DEFAULT_DATE)
+ dag = DAG(dag_id='test_scheduler_do_not_schedule_removed_task', start_date=DEFAULT_DATE)
dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
scheduler = SchedulerJob()
@@ -2179,13 +2151,14 @@ def test_do_not_schedule_removed_task(self):
@provide_session
def evaluate_dagrun(
- self,
- dag_id,
- expected_task_states, # dict of task_id: state
- dagrun_state,
- run_kwargs=None,
- advance_execution_date=False,
- session=None): # pylint: disable=unused-argument
+ self,
+ dag_id,
+ expected_task_states, # dict of task_id: state
+ dagrun_state,
+ run_kwargs=None,
+ advance_execution_date=False,
+ session=None,
+ ): # pylint: disable=unused-argument
"""
Helper for testing DagRun states with simple two-task DAGs.
@@ -2248,7 +2221,8 @@ def test_dagrun_fail(self):
'test_dagrun_fail': State.FAILED,
'test_dagrun_succeed': State.UPSTREAM_FAILED,
},
- dagrun_state=State.FAILED)
+ dagrun_state=State.FAILED,
+ )
def test_dagrun_success(self):
"""
@@ -2260,7 +2234,8 @@ def test_dagrun_success(self):
'test_dagrun_fail': State.FAILED,
'test_dagrun_succeed': State.SUCCESS,
},
- dagrun_state=State.SUCCESS)
+ dagrun_state=State.SUCCESS,
+ )
def test_dagrun_root_fail(self):
"""
@@ -2272,7 +2247,8 @@ def test_dagrun_root_fail(self):
'test_dagrun_succeed': State.SUCCESS,
'test_dagrun_fail': State.FAILED,
},
- dagrun_state=State.FAILED)
+ dagrun_state=State.FAILED,
+ )
def test_dagrun_root_fail_unfinished(self):
"""
@@ -2311,10 +2287,7 @@ def test_dagrun_root_after_dagrun_unfinished(self):
dag_id = 'test_dagrun_states_root_future'
dag = self.dagbag.get_dag(dag_id)
dag.sync_to_db()
- scheduler = SchedulerJob(
- num_runs=1,
- executor=self.null_exec,
- subdir=dag.fileloc)
+ scheduler = SchedulerJob(num_runs=1, executor=self.null_exec, subdir=dag.fileloc)
scheduler.run()
first_run = DagRun.find(dag_id=dag_id, execution_date=DEFAULT_DATE)[0]
@@ -2339,7 +2312,8 @@ def test_dagrun_deadlock_ignore_depends_on_past_advance_ex_date(self):
},
dagrun_state=State.SUCCESS,
advance_execution_date=True,
- run_kwargs=dict(ignore_first_depends_on_past=True))
+ run_kwargs=dict(ignore_first_depends_on_past=True),
+ )
def test_dagrun_deadlock_ignore_depends_on_past(self):
"""
@@ -2355,7 +2329,8 @@ def test_dagrun_deadlock_ignore_depends_on_past(self):
'test_depends_on_past_2': State.SUCCESS,
},
dagrun_state=State.SUCCESS,
- run_kwargs=dict(ignore_first_depends_on_past=True))
+ run_kwargs=dict(ignore_first_depends_on_past=True),
+ )
def test_scheduler_start_date(self):
"""
@@ -2372,14 +2347,11 @@ def test_scheduler_start_date(self):
other_dag.is_paused_upon_creation = True
other_dag.sync_to_db()
- scheduler = SchedulerJob(executor=self.null_exec,
- subdir=dag.fileloc,
- num_runs=1)
+ scheduler = SchedulerJob(executor=self.null_exec, subdir=dag.fileloc, num_runs=1)
scheduler.run()
# zero tasks ran
- self.assertEqual(
- len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 0)
+ self.assertEqual(len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 0)
session.commit()
self.assertListEqual([], self.null_exec.sorted_tasks)
@@ -2388,32 +2360,24 @@ def test_scheduler_start_date(self):
# That behavior still exists, but now it will only do so if after the
# start date
bf_exec = MockExecutor()
- backfill = BackfillJob(
- executor=bf_exec,
- dag=dag,
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE)
+ backfill = BackfillJob(executor=bf_exec, dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
backfill.run()
# one task ran
- self.assertEqual(
- len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 1)
+ self.assertEqual(len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 1)
self.assertListEqual(
[
(TaskInstanceKey(dag.dag_id, 'dummy', DEFAULT_DATE, 1), (State.SUCCESS, None)),
],
- bf_exec.sorted_tasks
+ bf_exec.sorted_tasks,
)
session.commit()
- scheduler = SchedulerJob(dag.fileloc,
- executor=self.null_exec,
- num_runs=1)
+ scheduler = SchedulerJob(dag.fileloc, executor=self.null_exec, num_runs=1)
scheduler.run()
# still one task
- self.assertEqual(
- len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 1)
+ self.assertEqual(len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 1)
session.commit()
self.assertListEqual([], self.null_exec.sorted_tasks)
@@ -2435,9 +2399,7 @@ def test_scheduler_task_start_date(self):
dagbag.sync_to_db()
- scheduler = SchedulerJob(executor=self.null_exec,
- subdir=dag.fileloc,
- num_runs=2)
+ scheduler = SchedulerJob(executor=self.null_exec, subdir=dag.fileloc, num_runs=2)
scheduler.run()
session = settings.Session()
@@ -2458,16 +2420,17 @@ def test_scheduler_multiprocessing(self):
dag = self.dagbag.get_dag(dag_id)
dag.clear()
- scheduler = SchedulerJob(executor=self.null_exec,
- subdir=os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py'),
- num_runs=1)
+ scheduler = SchedulerJob(
+ executor=self.null_exec,
+ subdir=os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py'),
+ num_runs=1,
+ )
scheduler.run()
# zero tasks ran
dag_id = 'test_start_date_scheduling'
session = settings.Session()
- self.assertEqual(
- len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 0)
+ self.assertEqual(len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 0)
@conf_vars({("core", "mp_start_method"): "spawn"})
def test_scheduler_multiprocessing_with_spawn_method(self):
@@ -2480,26 +2443,24 @@ def test_scheduler_multiprocessing_with_spawn_method(self):
dag = self.dagbag.get_dag(dag_id)
dag.clear()
- scheduler = SchedulerJob(executor=self.null_exec,
- subdir=os.path.join(
- TEST_DAG_FOLDER, 'test_scheduler_dags.py'),
- num_runs=1)
+ scheduler = SchedulerJob(
+ executor=self.null_exec,
+ subdir=os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py'),
+ num_runs=1,
+ )
scheduler.run()
# zero tasks ran
dag_id = 'test_start_date_scheduling'
with create_session() as session:
- self.assertEqual(
- session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).count(), 0)
+ self.assertEqual(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).count(), 0)
def test_scheduler_verify_pool_full(self):
"""
Test task instances not queued when pool is full
"""
- dag = DAG(
- dag_id='test_scheduler_verify_pool_full',
- start_date=DEFAULT_DATE)
+ dag = DAG(dag_id='test_scheduler_verify_pool_full', start_date=DEFAULT_DATE)
BashOperator(
task_id='dummy',
@@ -2509,9 +2470,11 @@ def test_scheduler_verify_pool_full(self):
bash_command='echo hi',
)
- dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"),
- include_examples=False,
- read_dags_from_db=True)
+ dagbag = DagBag(
+ dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"),
+ include_examples=False,
+ read_dags_from_db=True,
+ )
dagbag.bag_dag(dag=dag, root_dag=dag)
dagbag.sync_to_db()
@@ -2549,9 +2512,7 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self):
Variation with non-default pool_slots
"""
- dag = DAG(
- dag_id='test_scheduler_verify_pool_full_2_slots_per_task',
- start_date=DEFAULT_DATE)
+ dag = DAG(dag_id='test_scheduler_verify_pool_full_2_slots_per_task', start_date=DEFAULT_DATE)
BashOperator(
task_id='dummy',
@@ -2562,9 +2523,11 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self):
bash_command='echo hi',
)
- dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"),
- include_examples=False,
- read_dags_from_db=True)
+ dagbag = DagBag(
+ dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"),
+ include_examples=False,
+ read_dags_from_db=True,
+ )
dagbag.bag_dag(dag=dag, root_dag=dag)
dagbag.sync_to_db()
@@ -2601,9 +2564,7 @@ def test_scheduler_verify_priority_and_slots(self):
Though tasks with lower priority might be executed.
"""
- dag = DAG(
- dag_id='test_scheduler_verify_priority_and_slots',
- start_date=DEFAULT_DATE)
+ dag = DAG(dag_id='test_scheduler_verify_priority_and_slots', start_date=DEFAULT_DATE)
# Medium priority, not enough slots
BashOperator(
@@ -2636,9 +2597,11 @@ def test_scheduler_verify_priority_and_slots(self):
bash_command='echo hi',
)
- dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"),
- include_examples=False,
- read_dags_from_db=True)
+ dagbag = DagBag(
+ dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"),
+ include_examples=False,
+ read_dags_from_db=True,
+ )
dagbag.bag_dag(dag=dag, root_dag=dag)
dagbag.sync_to_db()
@@ -2664,16 +2627,25 @@ def test_scheduler_verify_priority_and_slots(self):
# Only second and third
self.assertEqual(len(task_instances_list), 2)
- ti0 = session.query(TaskInstance)\
- .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t0').first()
+ ti0 = (
+ session.query(TaskInstance)
+ .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t0')
+ .first()
+ )
self.assertEqual(ti0.state, State.SCHEDULED)
- ti1 = session.query(TaskInstance)\
- .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t1').first()
+ ti1 = (
+ session.query(TaskInstance)
+ .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t1')
+ .first()
+ )
self.assertEqual(ti1.state, State.QUEUED)
- ti2 = session.query(TaskInstance)\
- .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t2').first()
+ ti2 = (
+ session.query(TaskInstance)
+ .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t2')
+ .first()
+ )
self.assertEqual(ti2.state, State.QUEUED)
def test_verify_integrity_if_dag_not_changed(self):
@@ -2711,12 +2683,16 @@ def test_verify_integrity_if_dag_not_changed(self):
assert scheduled_tis == 1
- tis_count = session.query(func.count(TaskInstance.task_id)).filter(
- TaskInstance.dag_id == dr.dag_id,
- TaskInstance.execution_date == dr.execution_date,
- TaskInstance.task_id == dr.dag.tasks[0].task_id,
- TaskInstance.state == State.SCHEDULED
- ).scalar()
+ tis_count = (
+ session.query(func.count(TaskInstance.task_id))
+ .filter(
+ TaskInstance.dag_id == dr.dag_id,
+ TaskInstance.execution_date == dr.execution_date,
+ TaskInstance.task_id == dr.dag.tasks[0].task_id,
+ TaskInstance.state == State.SCHEDULED,
+ )
+ .scalar()
+ )
assert tis_count == 1
latest_dag_version = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session)
@@ -2776,11 +2752,15 @@ def test_verify_integrity_if_dag_changed(self):
assert scheduler.dagbag.dags == {'test_verify_integrity_if_dag_changed': dag}
assert len(scheduler.dagbag.dags.get("test_verify_integrity_if_dag_changed").tasks) == 2
- tis_count = session.query(func.count(TaskInstance.task_id)).filter(
- TaskInstance.dag_id == dr.dag_id,
- TaskInstance.execution_date == dr.execution_date,
- TaskInstance.state == State.SCHEDULED
- ).scalar()
+ tis_count = (
+ session.query(func.count(TaskInstance.task_id))
+ .filter(
+ TaskInstance.dag_id == dr.dag_id,
+ TaskInstance.execution_date == dr.execution_date,
+ TaskInstance.state == State.SCHEDULED,
+ )
+ .scalar()
+ )
assert tis_count == 2
latest_dag_version = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session)
@@ -2798,16 +2778,10 @@ def test_retry_still_in_executor(self):
dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), include_examples=False)
dagbag.dags.clear()
- dag = DAG(
- dag_id='test_retry_still_in_executor',
- start_date=DEFAULT_DATE,
- schedule_interval="@once")
+ dag = DAG(dag_id='test_retry_still_in_executor', start_date=DEFAULT_DATE, schedule_interval="@once")
dag_task1 = BashOperator(
- task_id='test_retry_handling_op',
- bash_command='exit 1',
- retries=1,
- dag=dag,
- owner='airflow')
+ task_id='test_retry_handling_op', bash_command='exit 1', retries=1, dag=dag, owner='airflow'
+ )
dag.clear()
dag.is_subdag = False
@@ -2825,17 +2799,22 @@ def do_schedule(mock_dagbag):
# Use a empty file since the above mock will return the
# expected DAGs. Also specify only a single file so that it doesn't
# try to schedule the above DAG repeatedly.
- scheduler = SchedulerJob(num_runs=1,
- executor=executor,
- subdir=os.path.join(settings.DAGS_FOLDER,
- "no_dags.py"))
+ scheduler = SchedulerJob(
+ num_runs=1, executor=executor, subdir=os.path.join(settings.DAGS_FOLDER, "no_dags.py")
+ )
scheduler.heartrate = 0
scheduler.run()
do_schedule() # pylint: disable=no-value-for-parameter
with create_session() as session:
- ti = session.query(TaskInstance).filter(TaskInstance.dag_id == 'test_retry_still_in_executor',
- TaskInstance.task_id == 'test_retry_handling_op').first()
+ ti = (
+ session.query(TaskInstance)
+ .filter(
+ TaskInstance.dag_id == 'test_retry_still_in_executor',
+ TaskInstance.task_id == 'test_retry_handling_op',
+ )
+ .first()
+ )
ti.task = dag_task1
def run_with_error(ti, ignore_ti_state=False):
@@ -2875,14 +2854,16 @@ def test_retry_handling_job(self):
dag_task1 = dag.get_task("test_retry_handling_op")
dag.clear()
- scheduler = SchedulerJob(dag_id=dag.dag_id,
- num_runs=1)
+ scheduler = SchedulerJob(dag_id=dag.dag_id, num_runs=1)
scheduler.heartrate = 0
scheduler.run()
session = settings.Session()
- ti = session.query(TaskInstance).filter(TaskInstance.dag_id == dag.dag_id,
- TaskInstance.task_id == dag_task1.task_id).first()
+ ti = (
+ session.query(TaskInstance)
+ .filter(TaskInstance.dag_id == dag.dag_id, TaskInstance.task_id == dag_task1.task_id)
+ .first()
+ )
# make sure the counter has increased
self.assertEqual(ti.try_number, 2)
@@ -2894,23 +2875,15 @@ def test_dag_get_active_runs(self):
"""
now = timezone.utcnow()
- six_hours_ago_to_the_hour = \
- (now - datetime.timedelta(hours=6)).replace(minute=0, second=0, microsecond=0)
+ six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace(
+ minute=0, second=0, microsecond=0
+ )
start_date = six_hours_ago_to_the_hour
dag_name1 = 'get_active_runs_test'
- default_args = {
- 'owner': 'airflow',
- 'depends_on_past': False,
- 'start_date': start_date
-
- }
- dag1 = DAG(dag_name1,
- schedule_interval='* * * * *',
- max_active_runs=1,
- default_args=default_args
- )
+ default_args = {'owner': 'airflow', 'depends_on_past': False, 'start_date': start_date}
+ dag1 = DAG(dag_name1, schedule_interval='* * * * *', max_active_runs=1, default_args=default_args)
run_this_1 = DummyOperator(task_id='run_this_1', dag=dag1)
run_this_2 = DummyOperator(task_id='run_this_2', dag=dag1)
@@ -2963,10 +2936,8 @@ def test_add_unparseable_file_before_sched_start_creates_import_error(self):
self.assertEqual(len(import_errors), 1)
import_error = import_errors[0]
- self.assertEqual(import_error.filename,
- unparseable_filename)
- self.assertEqual(import_error.stacktrace,
- f"invalid syntax ({TEMP_DAG_FILENAME}, line 1)")
+ self.assertEqual(import_error.filename, unparseable_filename)
+ self.assertEqual(import_error.stacktrace, f"invalid syntax ({TEMP_DAG_FILENAME}, line 1)")
@conf_vars({("core", "dagbag_import_error_tracebacks"): "False"})
def test_add_unparseable_file_after_sched_start_creates_import_error(self):
@@ -2993,10 +2964,8 @@ def test_add_unparseable_file_after_sched_start_creates_import_error(self):
self.assertEqual(len(import_errors), 1)
import_error = import_errors[0]
- self.assertEqual(import_error.filename,
- unparseable_filename)
- self.assertEqual(import_error.stacktrace,
- f"invalid syntax ({TEMP_DAG_FILENAME}, line 1)")
+ self.assertEqual(import_error.filename, unparseable_filename)
+ self.assertEqual(import_error.stacktrace, f"invalid syntax ({TEMP_DAG_FILENAME}, line 1)")
def test_no_import_errors_with_parseable_dag(self):
try:
@@ -3028,9 +2997,8 @@ def test_new_import_error_replaces_old(self):
# Generate replacement import error (the error will be on the second line now)
with open(unparseable_filename, 'w') as unparseable_file:
unparseable_file.writelines(
- PARSEABLE_DAG_FILE_CONTENTS +
- os.linesep +
- UNPARSEABLE_DAG_FILE_CONTENTS)
+ PARSEABLE_DAG_FILE_CONTENTS + os.linesep + UNPARSEABLE_DAG_FILE_CONTENTS
+ )
self.run_single_scheduler_loop_with_no_dags(dags_folder)
finally:
shutil.rmtree(dags_folder)
@@ -3040,10 +3008,8 @@ def test_new_import_error_replaces_old(self):
self.assertEqual(len(import_errors), 1)
import_error = import_errors[0]
- self.assertEqual(import_error.filename,
- unparseable_filename)
- self.assertEqual(import_error.stacktrace,
- f"invalid syntax ({TEMP_DAG_FILENAME}, line 2)")
+ self.assertEqual(import_error.filename, unparseable_filename)
+ self.assertEqual(import_error.stacktrace, f"invalid syntax ({TEMP_DAG_FILENAME}, line 2)")
def test_remove_error_clears_import_error(self):
try:
@@ -3057,8 +3023,7 @@ def test_remove_error_clears_import_error(self):
# Remove the import error from the file
with open(filename_to_parse, 'w') as file_to_parse:
- file_to_parse.writelines(
- PARSEABLE_DAG_FILE_CONTENTS)
+ file_to_parse.writelines(PARSEABLE_DAG_FILE_CONTENTS)
self.run_single_scheduler_loop_with_no_dags(dags_folder)
finally:
shutil.rmtree(dags_folder)
@@ -3113,8 +3078,7 @@ def test_import_error_tracebacks(self):
"NameError: name 'airflow_DAG' is not defined\n"
)
self.assertEqual(
- import_error.stacktrace,
- expected_stacktrace.format(unparseable_filename, unparseable_filename)
+ import_error.stacktrace, expected_stacktrace.format(unparseable_filename, unparseable_filename)
)
@conf_vars({("core", "dagbag_import_error_traceback_depth"): "1"})
@@ -3140,9 +3104,7 @@ def test_import_error_traceback_depth(self):
" return airflow_DAG\n"
"NameError: name 'airflow_DAG' is not defined\n"
)
- self.assertEqual(
- import_error.stacktrace, expected_stacktrace.format(unparseable_filename)
- )
+ self.assertEqual(import_error.stacktrace, expected_stacktrace.format(unparseable_filename))
def test_import_error_tracebacks_zip(self):
dags_folder = mkdtemp()
@@ -3170,8 +3132,7 @@ def test_import_error_tracebacks_zip(self):
"NameError: name 'airflow_DAG' is not defined\n"
)
self.assertEqual(
- import_error.stacktrace,
- expected_stacktrace.format(invalid_dag_filename, invalid_dag_filename)
+ import_error.stacktrace, expected_stacktrace.format(invalid_dag_filename, invalid_dag_filename)
)
@conf_vars({("core", "dagbag_import_error_traceback_depth"): "1"})
@@ -3198,9 +3159,7 @@ def test_import_error_tracebacks_zip_depth(self):
" return airflow_DAG\n"
"NameError: name 'airflow_DAG' is not defined\n"
)
- self.assertEqual(
- import_error.stacktrace, expected_stacktrace.format(invalid_dag_filename)
- )
+ self.assertEqual(import_error.stacktrace, expected_stacktrace.format(invalid_dag_filename))
def test_list_py_file_paths(self):
"""
@@ -3220,8 +3179,7 @@ def test_list_py_file_paths(self):
for file_name in files:
if file_name.endswith('.py') or file_name.endswith('.zip'):
if file_name not in ignored_files:
- expected_files.add(
- f'{root}/{file_name}')
+ expected_files.add(f'{root}/{file_name}')
for file_path in list_py_file_paths(TEST_DAG_FOLDER, include_examples=False):
detected_files.add(file_path)
self.assertEqual(detected_files, expected_files)
@@ -3243,13 +3201,14 @@ def test_list_py_file_paths(self):
smart_sensor_dag_folder = airflow.smart_sensor_dags.__path__[0]
for root, _, files in os.walk(smart_sensor_dag_folder):
for file_name in files:
- if (file_name.endswith('.py') or file_name.endswith('.zip')) and \
- file_name not in ['__init__.py']:
+ if (file_name.endswith('.py') or file_name.endswith('.zip')) and file_name not in [
+ '__init__.py'
+ ]:
expected_files.add(os.path.join(root, file_name))
detected_files.clear()
- for file_path in list_py_file_paths(TEST_DAG_FOLDER,
- include_examples=True,
- include_smart_sensor=True):
+ for file_path in list_py_file_paths(
+ TEST_DAG_FOLDER, include_examples=True, include_smart_sensor=True
+ ):
detected_files.add(file_path)
self.assertEqual(detected_files, expected_files)
@@ -3268,12 +3227,14 @@ def test_adopt_or_reset_orphaned_tasks_external_triggered_dag(self):
scheduler = SchedulerJob()
session = settings.Session()
- dr1 = dag.create_dagrun(run_type=DagRunType.SCHEDULED,
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- external_trigger=True,
- session=session)
+ dr1 = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ external_trigger=True,
+ session=session,
+ )
ti = dr1.get_task_instances(session=session)[0]
ti.state = State.SCHEDULED
session.merge(ti)
@@ -3294,11 +3255,13 @@ def test_adopt_or_reset_orphaned_tasks_backfill_dag(self):
session.add(scheduler)
session.flush()
- dr1 = dag.create_dagrun(run_type=DagRunType.BACKFILL_JOB,
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
+ dr1 = dag.create_dagrun(
+ run_type=DagRunType.BACKFILL_JOB,
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
ti = dr1.get_task_instances(session=session)[0]
ti.state = State.SCHEDULED
session.merge(ti)
@@ -3342,11 +3305,13 @@ def test_reset_orphaned_tasks_no_orphans(self):
session.add(scheduler)
session.flush()
- dr1 = dag.create_dagrun(run_type=DagRunType.SCHEDULED,
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
+ dr1 = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
tis = dr1.get_task_instances(session=session)
tis[0].state = State.RUNNING
tis[0].queued_by_job_id = scheduler.id
@@ -3370,11 +3335,13 @@ def test_reset_orphaned_tasks_non_running_dagruns(self):
session.add(scheduler)
session.flush()
- dr1 = dag.create_dagrun(run_type=DagRunType.SCHEDULED,
- state=State.SUCCESS,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
+ dr1 = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ state=State.SUCCESS,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
tis = dr1.get_task_instances(session=session)
self.assertEqual(1, len(tis))
tis[0].state = State.SCHEDULED
@@ -3409,7 +3376,7 @@ def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self):
execution_date=DEFAULT_DATE,
start_date=timezone.utcnow(),
state=State.RUNNING,
- session=session
+ session=session,
)
ti1, ti2 = dr1.get_task_instances(session=session)
@@ -3482,9 +3449,7 @@ def test_send_sla_callbacks_to_processor_sla_with_task_slas(self):
)
def test_scheduler_sets_job_id_on_dag_run(self):
- dag = DAG(
- dag_id='test_scheduler_sets_job_id_on_dag_run',
- start_date=DEFAULT_DATE)
+ dag = DAG(dag_id='test_scheduler_sets_job_id_on_dag_run', start_date=DEFAULT_DATE)
DummyOperator(
task_id='dummy',
@@ -3494,7 +3459,7 @@ def test_scheduler_sets_job_id_on_dag_run(self):
dagbag = DagBag(
dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"),
include_examples=False,
- read_dags_from_db=True
+ read_dags_from_db=True,
)
dagbag.bag_dag(dag=dag, root_dag=dag)
dagbag.sync_to_db()
@@ -3587,7 +3552,7 @@ def test_do_schedule_max_active_runs_and_manual_trigger(self):
dagbag = DagBag(
dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"),
include_examples=False,
- read_dags_from_db=True
+ read_dags_from_db=True,
)
dagbag.bag_dag(dag=dag, root_dag=dag)
@@ -3665,9 +3630,7 @@ def test_task_with_upstream_skip_process_task_instances():
"""
clear_db_runs()
with DAG(
- dag_id='test_task_with_upstream_skip_dag',
- start_date=DEFAULT_DATE,
- schedule_interval=None
+ dag_id='test_task_with_upstream_skip_dag', start_date=DEFAULT_DATE, schedule_interval=None
) as dag:
dummy1 = DummyOperator(task_id='dummy1')
dummy2 = DummyOperator(task_id="dummy2")
@@ -3676,9 +3639,7 @@ def test_task_with_upstream_skip_process_task_instances():
# dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
dag.clear()
- dr = dag.create_dagrun(run_type=DagRunType.MANUAL,
- state=State.RUNNING,
- execution_date=DEFAULT_DATE)
+ dr = dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)
assert dr is not None
with create_session() as session:
@@ -3729,17 +3690,24 @@ def setUp(self) -> None:
)
@pytest.mark.quarantined
def test_execute_queries_count_with_harvested_dags(self, expected_query_count, dag_count, task_count):
- with mock.patch.dict("os.environ", {
- "PERF_DAGS_COUNT": str(dag_count),
- "PERF_TASKS_COUNT": str(task_count),
- "PERF_START_AGO": "1d",
- "PERF_SCHEDULE_INTERVAL": "30m",
- "PERF_SHAPE": "no_structure",
- }), conf_vars({
- ('scheduler', 'use_job_schedule'): 'True',
- ('core', 'load_examples'): 'False',
- ('core', 'store_serialized_dags'): 'True',
- }), mock.patch.object(settings, 'STORE_SERIALIZED_DAGS', True):
+ with mock.patch.dict(
+ "os.environ",
+ {
+ "PERF_DAGS_COUNT": str(dag_count),
+ "PERF_TASKS_COUNT": str(task_count),
+ "PERF_START_AGO": "1d",
+ "PERF_SCHEDULE_INTERVAL": "30m",
+ "PERF_SHAPE": "no_structure",
+ },
+ ), conf_vars(
+ {
+ ('scheduler', 'use_job_schedule'): 'True',
+ ('core', 'load_examples'): 'False',
+ ('core', 'store_serialized_dags'): 'True',
+ }
+ ), mock.patch.object(
+ settings, 'STORE_SERIALIZED_DAGS', True
+ ):
dagruns = []
dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False, read_dags_from_db=False)
dagbag.sync_to_db()
@@ -3775,35 +3743,35 @@ def test_execute_queries_count_with_harvested_dags(self, expected_query_count, d
# pylint: disable=bad-whitespace
# expected, dag_count, task_count, start_ago, schedule_interval, shape
# One DAG with one task per DAG file
- ([10, 10, 10, 10], 1, 1, "1d", "None", "no_structure"), # noqa
- ([10, 10, 10, 10], 1, 1, "1d", "None", "linear"), # noqa
- ([22, 14, 14, 14], 1, 1, "1d", "@once", "no_structure"), # noqa
- ([22, 14, 14, 14], 1, 1, "1d", "@once", "linear"), # noqa
- ([22, 24, 27, 30], 1, 1, "1d", "30m", "no_structure"), # noqa
- ([22, 24, 27, 30], 1, 1, "1d", "30m", "linear"), # noqa
- ([22, 24, 27, 30], 1, 1, "1d", "30m", "binary_tree"), # noqa
- ([22, 24, 27, 30], 1, 1, "1d", "30m", "star"), # noqa
- ([22, 24, 27, 30], 1, 1, "1d", "30m", "grid"), # noqa
+ ([10, 10, 10, 10], 1, 1, "1d", "None", "no_structure"), # noqa
+ ([10, 10, 10, 10], 1, 1, "1d", "None", "linear"), # noqa
+ ([22, 14, 14, 14], 1, 1, "1d", "@once", "no_structure"), # noqa
+ ([22, 14, 14, 14], 1, 1, "1d", "@once", "linear"), # noqa
+ ([22, 24, 27, 30], 1, 1, "1d", "30m", "no_structure"), # noqa
+ ([22, 24, 27, 30], 1, 1, "1d", "30m", "linear"), # noqa
+ ([22, 24, 27, 30], 1, 1, "1d", "30m", "binary_tree"), # noqa
+ ([22, 24, 27, 30], 1, 1, "1d", "30m", "star"), # noqa
+ ([22, 24, 27, 30], 1, 1, "1d", "30m", "grid"), # noqa
# One DAG with five tasks per DAG file
- ([10, 10, 10, 10], 1, 5, "1d", "None", "no_structure"), # noqa
- ([10, 10, 10, 10], 1, 5, "1d", "None", "linear"), # noqa
- ([22, 14, 14, 14], 1, 5, "1d", "@once", "no_structure"), # noqa
- ([23, 15, 15, 15], 1, 5, "1d", "@once", "linear"), # noqa
- ([22, 24, 27, 30], 1, 5, "1d", "30m", "no_structure"), # noqa
- ([23, 26, 30, 34], 1, 5, "1d", "30m", "linear"), # noqa
- ([23, 26, 30, 34], 1, 5, "1d", "30m", "binary_tree"), # noqa
- ([23, 26, 30, 34], 1, 5, "1d", "30m", "star"), # noqa
- ([23, 26, 30, 34], 1, 5, "1d", "30m", "grid"), # noqa
+ ([10, 10, 10, 10], 1, 5, "1d", "None", "no_structure"), # noqa
+ ([10, 10, 10, 10], 1, 5, "1d", "None", "linear"), # noqa
+ ([22, 14, 14, 14], 1, 5, "1d", "@once", "no_structure"), # noqa
+ ([23, 15, 15, 15], 1, 5, "1d", "@once", "linear"), # noqa
+ ([22, 24, 27, 30], 1, 5, "1d", "30m", "no_structure"), # noqa
+ ([23, 26, 30, 34], 1, 5, "1d", "30m", "linear"), # noqa
+ ([23, 26, 30, 34], 1, 5, "1d", "30m", "binary_tree"), # noqa
+ ([23, 26, 30, 34], 1, 5, "1d", "30m", "star"), # noqa
+ ([23, 26, 30, 34], 1, 5, "1d", "30m", "grid"), # noqa
# 10 DAGs with 10 tasks per DAG file
- ([10, 10, 10, 10], 10, 10, "1d", "None", "no_structure"), # noqa
- ([10, 10, 10, 10], 10, 10, "1d", "None", "linear"), # noqa
- ([85, 38, 38, 38], 10, 10, "1d", "@once", "no_structure"), # noqa
- ([95, 51, 51, 51], 10, 10, "1d", "@once", "linear"), # noqa
- ([85, 99, 99, 99], 10, 10, "1d", "30m", "no_structure"), # noqa
- ([95, 125, 125, 125], 10, 10, "1d", "30m", "linear"), # noqa
- ([95, 119, 119, 119], 10, 10, "1d", "30m", "binary_tree"), # noqa
- ([95, 119, 119, 119], 10, 10, "1d", "30m", "star"), # noqa
- ([95, 119, 119, 119], 10, 10, "1d", "30m", "grid"), # noqa
+ ([10, 10, 10, 10], 10, 10, "1d", "None", "no_structure"), # noqa
+ ([10, 10, 10, 10], 10, 10, "1d", "None", "linear"), # noqa
+ ([85, 38, 38, 38], 10, 10, "1d", "@once", "no_structure"), # noqa
+ ([95, 51, 51, 51], 10, 10, "1d", "@once", "linear"), # noqa
+ ([85, 99, 99, 99], 10, 10, "1d", "30m", "no_structure"), # noqa
+ ([95, 125, 125, 125], 10, 10, "1d", "30m", "linear"), # noqa
+ ([95, 119, 119, 119], 10, 10, "1d", "30m", "binary_tree"), # noqa
+ ([95, 119, 119, 119], 10, 10, "1d", "30m", "star"), # noqa
+ ([95, 119, 119, 119], 10, 10, "1d", "30m", "grid"), # noqa
# pylint: enable=bad-whitespace
]
)
@@ -3811,16 +3779,23 @@ def test_execute_queries_count_with_harvested_dags(self, expected_query_count, d
def test_process_dags_queries_count(
self, expected_query_counts, dag_count, task_count, start_ago, schedule_interval, shape
):
- with mock.patch.dict("os.environ", {
- "PERF_DAGS_COUNT": str(dag_count),
- "PERF_TASKS_COUNT": str(task_count),
- "PERF_START_AGO": start_ago,
- "PERF_SCHEDULE_INTERVAL": schedule_interval,
- "PERF_SHAPE": shape,
- }), conf_vars({
- ('scheduler', 'use_job_schedule'): 'True',
- ('core', 'store_serialized_dags'): 'True',
- }), mock.patch.object(settings, 'STORE_SERIALIZED_DAGS', True):
+ with mock.patch.dict(
+ "os.environ",
+ {
+ "PERF_DAGS_COUNT": str(dag_count),
+ "PERF_TASKS_COUNT": str(task_count),
+ "PERF_START_AGO": start_ago,
+ "PERF_SCHEDULE_INTERVAL": schedule_interval,
+ "PERF_SHAPE": shape,
+ },
+ ), conf_vars(
+ {
+ ('scheduler', 'use_job_schedule'): 'True',
+ ('core', 'store_serialized_dags'): 'True',
+ }
+ ), mock.patch.object(
+ settings, 'STORE_SERIALIZED_DAGS', True
+ ):
dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False)
dagbag.sync_to_db()
diff --git a/tests/kubernetes/models/test_secret.py b/tests/kubernetes/models/test_secret.py
index 7036cf7f8d640..183a9ef9bd381 100644
--- a/tests/kubernetes/models/test_secret.py
+++ b/tests/kubernetes/models/test_secret.py
@@ -27,65 +27,50 @@
class TestSecret(unittest.TestCase):
-
def test_to_env_secret(self):
secret = Secret('env', 'name', 'secret', 'key')
- self.assertEqual(secret.to_env_secret(), k8s.V1EnvVar(
- name='NAME',
- value_from=k8s.V1EnvVarSource(
- secret_key_ref=k8s.V1SecretKeySelector(
- name='secret',
- key='key'
- )
- )
- ))
+ self.assertEqual(
+ secret.to_env_secret(),
+ k8s.V1EnvVar(
+ name='NAME',
+ value_from=k8s.V1EnvVarSource(
+ secret_key_ref=k8s.V1SecretKeySelector(name='secret', key='key')
+ ),
+ ),
+ )
def test_to_env_from_secret(self):
secret = Secret('env', None, 'secret')
- self.assertEqual(secret.to_env_from_secret(), k8s.V1EnvFromSource(
- secret_ref=k8s.V1SecretEnvSource(name='secret')
- ))
+ self.assertEqual(
+ secret.to_env_from_secret(), k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name='secret'))
+ )
@mock.patch('uuid.uuid4')
def test_to_volume_secret(self, mock_uuid):
mock_uuid.return_value = '0'
secret = Secret('volume', '/etc/foo', 'secret_b')
- self.assertEqual(secret.to_volume_secret(), (
- k8s.V1Volume(
- name='secretvol0',
- secret=k8s.V1SecretVolumeSource(
- secret_name='secret_b'
- )
+ self.assertEqual(
+ secret.to_volume_secret(),
+ (
+ k8s.V1Volume(name='secretvol0', secret=k8s.V1SecretVolumeSource(secret_name='secret_b')),
+ k8s.V1VolumeMount(mount_path='/etc/foo', name='secretvol0', read_only=True),
),
- k8s.V1VolumeMount(
- mount_path='/etc/foo',
- name='secretvol0',
- read_only=True
- )
- ))
+ )
@mock.patch('uuid.uuid4')
def test_only_mount_sub_secret(self, mock_uuid):
mock_uuid.return_value = '0'
- items = [k8s.V1KeyToPath(
- key="my-username",
- path="/extra/path"
- )
- ]
+ items = [k8s.V1KeyToPath(key="my-username", path="/extra/path")]
secret = Secret('volume', '/etc/foo', 'secret_b', items=items)
- self.assertEqual(secret.to_volume_secret(), (
- k8s.V1Volume(
- name='secretvol0',
- secret=k8s.V1SecretVolumeSource(
- secret_name='secret_b',
- items=items)
+ self.assertEqual(
+ secret.to_volume_secret(),
+ (
+ k8s.V1Volume(
+ name='secretvol0', secret=k8s.V1SecretVolumeSource(secret_name='secret_b', items=items)
+ ),
+ k8s.V1VolumeMount(mount_path='/etc/foo', name='secretvol0', read_only=True),
),
- k8s.V1VolumeMount(
- mount_path='/etc/foo',
- name='secretvol0',
- read_only=True
- )
- ))
+ )
@mock.patch('uuid.uuid4')
def test_attach_to_pod(self, mock_uuid):
@@ -104,44 +89,60 @@ def test_attach_to_pod(self, mock_uuid):
k8s_client = ApiClient()
pod = append_to_pod(pod, secrets)
result = k8s_client.sanitize_for_serialization(pod)
- self.assertEqual(result,
- {'apiVersion': 'v1',
- 'kind': 'Pod',
- 'metadata': {'labels': {'app': 'myapp'},
- 'name': 'myapp-pod-cf4a56d281014217b0272af6216feb48',
- 'namespace': 'default'},
- 'spec': {'containers': [{'command': ['sh', '-c', 'echo Hello Kubernetes!'],
- 'env': [{'name': 'ENVIRONMENT', 'value': 'prod'},
- {'name': 'LOG_LEVEL', 'value': 'warning'},
- {'name': 'TARGET',
- 'valueFrom':
- {'secretKeyRef': {'key': 'source_b',
- 'name': 'secret_b'}}}],
- 'envFrom': [{'configMapRef': {'name': 'configmap_a'}},
- {'secretRef': {'name': 'secret_a'}}],
- 'image': 'busybox',
- 'name': 'base',
- 'ports': [{'containerPort': 1234, 'name': 'foo'}],
- 'resources': {'limits': {'memory': '200Mi'},
- 'requests': {'memory': '100Mi'}},
- 'volumeMounts': [{'mountPath': '/airflow/xcom',
- 'name': 'xcom'},
- {'mountPath': '/etc/foo',
- 'name': 'secretvol' + str(static_uuid),
- 'readOnly': True}]},
- {'command': ['sh',
- '-c',
- 'trap "exit 0" INT; while true; do sleep '
- '30; done;'],
- 'image': 'alpine',
- 'name': 'airflow-xcom-sidecar',
- 'resources': {'requests': {'cpu': '1m'}},
- 'volumeMounts': [{'mountPath': '/airflow/xcom',
- 'name': 'xcom'}]}],
- 'hostNetwork': True,
- 'imagePullSecrets': [{'name': 'pull_secret_a'},
- {'name': 'pull_secret_b'}],
- 'securityContext': {'fsGroup': 2000, 'runAsUser': 1000},
- 'volumes': [{'emptyDir': {}, 'name': 'xcom'},
- {'name': 'secretvol' + str(static_uuid),
- 'secret': {'secretName': 'secret_b'}}]}})
+ self.assertEqual(
+ result,
+ {
+ 'apiVersion': 'v1',
+ 'kind': 'Pod',
+ 'metadata': {
+ 'labels': {'app': 'myapp'},
+ 'name': 'myapp-pod-cf4a56d281014217b0272af6216feb48',
+ 'namespace': 'default',
+ },
+ 'spec': {
+ 'containers': [
+ {
+ 'command': ['sh', '-c', 'echo Hello Kubernetes!'],
+ 'env': [
+ {'name': 'ENVIRONMENT', 'value': 'prod'},
+ {'name': 'LOG_LEVEL', 'value': 'warning'},
+ {
+ 'name': 'TARGET',
+ 'valueFrom': {'secretKeyRef': {'key': 'source_b', 'name': 'secret_b'}},
+ },
+ ],
+ 'envFrom': [
+ {'configMapRef': {'name': 'configmap_a'}},
+ {'secretRef': {'name': 'secret_a'}},
+ ],
+ 'image': 'busybox',
+ 'name': 'base',
+ 'ports': [{'containerPort': 1234, 'name': 'foo'}],
+ 'resources': {'limits': {'memory': '200Mi'}, 'requests': {'memory': '100Mi'}},
+ 'volumeMounts': [
+ {'mountPath': '/airflow/xcom', 'name': 'xcom'},
+ {
+ 'mountPath': '/etc/foo',
+ 'name': 'secretvol' + str(static_uuid),
+ 'readOnly': True,
+ },
+ ],
+ },
+ {
+ 'command': ['sh', '-c', 'trap "exit 0" INT; while true; do sleep 30; done;'],
+ 'image': 'alpine',
+ 'name': 'airflow-xcom-sidecar',
+ 'resources': {'requests': {'cpu': '1m'}},
+ 'volumeMounts': [{'mountPath': '/airflow/xcom', 'name': 'xcom'}],
+ },
+ ],
+ 'hostNetwork': True,
+ 'imagePullSecrets': [{'name': 'pull_secret_a'}, {'name': 'pull_secret_b'}],
+ 'securityContext': {'fsGroup': 2000, 'runAsUser': 1000},
+ 'volumes': [
+ {'emptyDir': {}, 'name': 'xcom'},
+ {'name': 'secretvol' + str(static_uuid), 'secret': {'secretName': 'secret_b'}},
+ ],
+ },
+ },
+ )
diff --git a/tests/kubernetes/test_client.py b/tests/kubernetes/test_client.py
index fae79a5470bae..73faedd370a9e 100644
--- a/tests/kubernetes/test_client.py
+++ b/tests/kubernetes/test_client.py
@@ -25,7 +25,6 @@
class TestClient(unittest.TestCase):
-
@mock.patch('airflow.kubernetes.kube_client.config')
def test_load_cluster_config(self, _):
client = get_kube_client(in_cluster=True)
diff --git a/tests/kubernetes/test_pod_generator.py b/tests/kubernetes/test_pod_generator.py
index 865c156b45678..8af12f9a03fd6 100644
--- a/tests/kubernetes/test_pod_generator.py
+++ b/tests/kubernetes/test_pod_generator.py
@@ -25,13 +25,16 @@
from airflow import __version__
from airflow.exceptions import AirflowConfigException
from airflow.kubernetes.pod_generator import (
- PodDefaults, PodGenerator, datetime_to_label_safe_datestring, extend_object_field, merge_objects,
+ PodDefaults,
+ PodGenerator,
+ datetime_to_label_safe_datestring,
+ extend_object_field,
+ merge_objects,
)
from airflow.kubernetes.secret import Secret
class TestPodGenerator(unittest.TestCase):
-
def setUp(self):
self.static_uuid = uuid.UUID('cf4a56d2-8101-4217-b027-2af6216feb48')
self.deserialize_result = {
@@ -39,23 +42,19 @@ def setUp(self):
'kind': 'Pod',
'metadata': {'name': 'memory-demo', 'namespace': 'mem-example'},
'spec': {
- 'containers': [{
- 'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'],
- 'command': ['stress'],
- 'image': 'apache/airflow:stress-2020.07.10-1.0.4',
- 'name': 'memory-demo-ctr',
- 'resources': {
- 'limits': {'memory': '200Mi'},
- 'requests': {'memory': '100Mi'}
+ 'containers': [
+ {
+ 'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'],
+ 'command': ['stress'],
+ 'image': 'apache/airflow:stress-2020.07.10-1.0.4',
+ 'name': 'memory-demo-ctr',
+ 'resources': {'limits': {'memory': '200Mi'}, 'requests': {'memory': '100Mi'}},
}
- }]
- }
+ ]
+ },
}
- self.envs = {
- 'ENVIRONMENT': 'prod',
- 'LOG_LEVEL': 'warning'
- }
+ self.envs = {'ENVIRONMENT': 'prod', 'LOG_LEVEL': 'warning'}
self.secrets = [
# This should be a secretRef
Secret('env', None, 'secret_a'),
@@ -77,7 +76,7 @@ def setUp(self):
'task_id': self.task_id,
'try_number': str(self.try_number),
'airflow_version': __version__.replace('+', '-'),
- 'kubernetes_executor': 'True'
+ 'kubernetes_executor': 'True',
}
self.annotations = {
'dag_id': self.dag_id,
@@ -98,12 +97,7 @@ def setUp(self):
"memory": "1Gi",
"ephemeral-storage": "2Gi",
},
- limits={
- "cpu": 2,
- "memory": "2Gi",
- "ephemeral-storage": "4Gi",
- 'nvidia.com/gpu': 1
- }
+ limits={"cpu": 2, "memory": "2Gi", "ephemeral-storage": "4Gi", 'nvidia.com/gpu': 1},
)
self.k8s_client = ApiClient()
@@ -120,14 +114,9 @@ def setUp(self):
k8s.V1Container(
name='base',
image='busybox',
- command=[
- 'sh', '-c', 'echo Hello Kubernetes!'
- ],
+ command=['sh', '-c', 'echo Hello Kubernetes!'],
env=[
- k8s.V1EnvVar(
- name='ENVIRONMENT',
- value='prod'
- ),
+ k8s.V1EnvVar(name='ENVIRONMENT', value='prod'),
k8s.V1EnvVar(
name="LOG_LEVEL",
value='warning',
@@ -135,44 +124,22 @@ def setUp(self):
k8s.V1EnvVar(
name='TARGET',
value_from=k8s.V1EnvVarSource(
- secret_key_ref=k8s.V1SecretKeySelector(
- name='secret_b',
- key='source_b'
- )
+ secret_key_ref=k8s.V1SecretKeySelector(name='secret_b', key='source_b')
),
- )
- ],
- env_from=[
- k8s.V1EnvFromSource(
- config_map_ref=k8s.V1ConfigMapEnvSource(
- name='configmap_a'
- )
- ),
- k8s.V1EnvFromSource(
- config_map_ref=k8s.V1ConfigMapEnvSource(
- name='configmap_b'
- )
- ),
- k8s.V1EnvFromSource(
- secret_ref=k8s.V1SecretEnvSource(
- name='secret_a'
- )
),
],
- ports=[
- k8s.V1ContainerPort(
- name="foo",
- container_port=1234
- )
+ env_from=[
+ k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name='configmap_a')),
+ k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name='configmap_b')),
+ k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name='secret_a')),
],
+ ports=[k8s.V1ContainerPort(name="foo", container_port=1234)],
resources=k8s.V1ResourceRequirements(
- requests={
- 'memory': '100Mi'
- },
+ requests={'memory': '100Mi'},
limits={
'memory': '200Mi',
- }
- )
+ },
+ ),
)
],
security_context=k8s.V1PodSecurityContext(
@@ -181,13 +148,9 @@ def setUp(self):
),
host_network=True,
image_pull_secrets=[
- k8s.V1LocalObjectReference(
- name="pull_secret_a"
- ),
- k8s.V1LocalObjectReference(
- name="pull_secret_b"
- )
- ]
+ k8s.V1LocalObjectReference(name="pull_secret_a"),
+ k8s.V1LocalObjectReference(name="pull_secret_b"),
+ ],
),
)
@@ -196,31 +159,20 @@ def test_gen_pod_extract_xcom(self, mock_uuid):
mock_uuid.return_value = self.static_uuid
path = sys.path[0] + '/tests/kubernetes/pod_generator_base_with_secrets.yaml'
- pod_generator = PodGenerator(
- pod_template_file=path,
- extract_xcom=True
- )
+ pod_generator = PodGenerator(pod_template_file=path, extract_xcom=True)
result = pod_generator.gen_pod()
result_dict = self.k8s_client.sanitize_for_serialization(result)
container_two = {
'name': 'airflow-xcom-sidecar',
'image': "alpine",
'command': ['sh', '-c', PodDefaults.XCOM_CMD],
- 'volumeMounts': [
- {
- 'name': 'xcom',
- 'mountPath': '/airflow/xcom'
- }
- ],
+ 'volumeMounts': [{'name': 'xcom', 'mountPath': '/airflow/xcom'}],
'resources': {'requests': {'cpu': '1m'}},
}
self.expected.spec.containers.append(container_two)
base_container: k8s.V1Container = self.expected.spec.containers[0]
base_container.volume_mounts = base_container.volume_mounts or []
- base_container.volume_mounts.append(k8s.V1VolumeMount(
- name="xcom",
- mount_path="/airflow/xcom"
- ))
+ base_container.volume_mounts.append(k8s.V1VolumeMount(name="xcom", mount_path="/airflow/xcom"))
self.expected.spec.containers[0] = base_container
self.expected.spec.volumes = self.expected.spec.volumes or []
self.expected.spec.volumes.append(
@@ -240,153 +192,148 @@ def test_from_obj(self):
"pod_override": k8s.V1Pod(
api_version="v1",
kind="Pod",
- metadata=k8s.V1ObjectMeta(
- name="foo",
- annotations={"test": "annotation"}
- ),
+ metadata=k8s.V1ObjectMeta(name="foo", annotations={"test": "annotation"}),
spec=k8s.V1PodSpec(
containers=[
k8s.V1Container(
name="base",
volume_mounts=[
k8s.V1VolumeMount(
- mount_path="/foo/",
- name="example-kubernetes-test-volume"
+ mount_path="/foo/", name="example-kubernetes-test-volume"
)
- ]
+ ],
)
],
volumes=[
k8s.V1Volume(
name="example-kubernetes-test-volume",
- host_path=k8s.V1HostPathVolumeSource(
- path="/tmp/"
- )
+ host_path=k8s.V1HostPathVolumeSource(path="/tmp/"),
)
- ]
- )
+ ],
+ ),
)
}
)
result = self.k8s_client.sanitize_for_serialization(result)
- self.assertEqual({
- 'apiVersion': 'v1',
- 'kind': 'Pod',
- 'metadata': {
- 'name': 'foo',
- 'annotations': {'test': 'annotation'},
+ self.assertEqual(
+ {
+ 'apiVersion': 'v1',
+ 'kind': 'Pod',
+ 'metadata': {
+ 'name': 'foo',
+ 'annotations': {'test': 'annotation'},
+ },
+ 'spec': {
+ 'containers': [
+ {
+ 'name': 'base',
+ 'volumeMounts': [
+ {'mountPath': '/foo/', 'name': 'example-kubernetes-test-volume'}
+ ],
+ }
+ ],
+ 'volumes': [{'hostPath': {'path': '/tmp/'}, 'name': 'example-kubernetes-test-volume'}],
+ },
},
- 'spec': {
- 'containers': [{
- 'name': 'base',
- 'volumeMounts': [{
- 'mountPath': '/foo/',
- 'name': 'example-kubernetes-test-volume'
- }],
- }],
- 'volumes': [{
- 'hostPath': {'path': '/tmp/'},
- 'name': 'example-kubernetes-test-volume'
- }],
- }
- }, result)
- result = PodGenerator.from_obj({
- "KubernetesExecutor": {
- "annotations": {"test": "annotation"},
- "volumes": [
- {
- "name": "example-kubernetes-test-volume",
- "hostPath": {"path": "/tmp/"},
- },
- ],
- "volume_mounts": [
- {
- "mountPath": "/foo/",
- "name": "example-kubernetes-test-volume",
- },
- ],
+ result,
+ )
+ result = PodGenerator.from_obj(
+ {
+ "KubernetesExecutor": {
+ "annotations": {"test": "annotation"},
+ "volumes": [
+ {
+ "name": "example-kubernetes-test-volume",
+ "hostPath": {"path": "/tmp/"},
+ },
+ ],
+ "volume_mounts": [
+ {
+ "mountPath": "/foo/",
+ "name": "example-kubernetes-test-volume",
+ },
+ ],
+ }
}
- })
+ )
result_from_pod = PodGenerator.from_obj(
- {"pod_override":
- k8s.V1Pod(
- metadata=k8s.V1ObjectMeta(
- annotations={"test": "annotation"}
- ),
+ {
+ "pod_override": k8s.V1Pod(
+ metadata=k8s.V1ObjectMeta(annotations={"test": "annotation"}),
spec=k8s.V1PodSpec(
containers=[
k8s.V1Container(
name="base",
volume_mounts=[
k8s.V1VolumeMount(
- name="example-kubernetes-test-volume",
- mount_path="/foo/"
+ name="example-kubernetes-test-volume", mount_path="/foo/"
)
- ]
+ ],
)
],
- volumes=[
- k8s.V1Volume(
- name="example-kubernetes-test-volume",
- host_path="/tmp/"
- )
- ]
- )
+ volumes=[k8s.V1Volume(name="example-kubernetes-test-volume", host_path="/tmp/")],
+ ),
)
- }
+ }
)
result = self.k8s_client.sanitize_for_serialization(result)
result_from_pod = self.k8s_client.sanitize_for_serialization(result_from_pod)
- expected_from_pod = {'metadata': {'annotations': {'test': 'annotation'}},
- 'spec': {'containers': [
- {'name': 'base',
- 'volumeMounts': [{'mountPath': '/foo/',
- 'name': 'example-kubernetes-test-volume'}]}],
- 'volumes': [{'hostPath': '/tmp/',
- 'name': 'example-kubernetes-test-volume'}]}}
- self.assertEqual(result_from_pod, expected_from_pod, "There was a discrepency"
- " between KubernetesExecutor and pod_override")
-
- self.assertEqual({
- 'apiVersion': 'v1',
- 'kind': 'Pod',
- 'metadata': {
- 'annotations': {'test': 'annotation'},
- },
+ expected_from_pod = {
+ 'metadata': {'annotations': {'test': 'annotation'}},
'spec': {
- 'containers': [{
- 'args': [],
- 'command': [],
- 'env': [],
- 'envFrom': [],
- 'name': 'base',
- 'ports': [],
- 'volumeMounts': [{
- 'mountPath': '/foo/',
- 'name': 'example-kubernetes-test-volume'
- }],
- }],
- 'hostNetwork': False,
- 'imagePullSecrets': [],
- 'volumes': [{
- 'hostPath': {'path': '/tmp/'},
- 'name': 'example-kubernetes-test-volume'
- }],
- }
- }, result)
+ 'containers': [
+ {
+ 'name': 'base',
+ 'volumeMounts': [{'mountPath': '/foo/', 'name': 'example-kubernetes-test-volume'}],
+ }
+ ],
+ 'volumes': [{'hostPath': '/tmp/', 'name': 'example-kubernetes-test-volume'}],
+ },
+ }
+ self.assertEqual(
+ result_from_pod,
+ expected_from_pod,
+ "There was a discrepency" " between KubernetesExecutor and pod_override",
+ )
+
+ self.assertEqual(
+ {
+ 'apiVersion': 'v1',
+ 'kind': 'Pod',
+ 'metadata': {
+ 'annotations': {'test': 'annotation'},
+ },
+ 'spec': {
+ 'containers': [
+ {
+ 'args': [],
+ 'command': [],
+ 'env': [],
+ 'envFrom': [],
+ 'name': 'base',
+ 'ports': [],
+ 'volumeMounts': [
+ {'mountPath': '/foo/', 'name': 'example-kubernetes-test-volume'}
+ ],
+ }
+ ],
+ 'hostNetwork': False,
+ 'imagePullSecrets': [],
+ 'volumes': [{'hostPath': {'path': '/tmp/'}, 'name': 'example-kubernetes-test-volume'}],
+ },
+ },
+ result,
+ )
@mock.patch('uuid.uuid4')
def test_reconcile_pods_empty_mutator_pod(self, mock_uuid):
mock_uuid.return_value = self.static_uuid
path = sys.path[0] + '/tests/kubernetes/pod_generator_base_with_secrets.yaml'
- pod_generator = PodGenerator(
- pod_template_file=path,
- extract_xcom=True
- )
+ pod_generator = PodGenerator(pod_template_file=path, extract_xcom=True)
base_pod = pod_generator.gen_pod()
mutator_pod = None
name = 'name1-' + self.static_uuid.hex
@@ -405,10 +352,7 @@ def test_reconcile_pods(self, mock_uuid):
mock_uuid.return_value = self.static_uuid
path = sys.path[0] + '/tests/kubernetes/pod_generator_base_with_secrets.yaml'
- base_pod = PodGenerator(
- pod_template_file=path,
- extract_xcom=False
- ).gen_pod()
+ base_pod = PodGenerator(pod_template_file=path, extract_xcom=False).gen_pod()
mutator_pod = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(
@@ -416,19 +360,23 @@ def test_reconcile_pods(self, mock_uuid):
labels={"bar": "baz"},
),
spec=k8s.V1PodSpec(
- containers=[k8s.V1Container(
- image='',
- name='name',
- command=['/bin/command2.sh', 'arg2'],
- volume_mounts=[k8s.V1VolumeMount(mount_path="/foo/",
- name="example-kubernetes-test-volume2")]
- )],
+ containers=[
+ k8s.V1Container(
+ image='',
+ name='name',
+ command=['/bin/command2.sh', 'arg2'],
+ volume_mounts=[
+ k8s.V1VolumeMount(mount_path="/foo/", name="example-kubernetes-test-volume2")
+ ],
+ )
+ ],
volumes=[
- k8s.V1Volume(host_path=k8s.V1HostPathVolumeSource(path="/tmp/"),
- name="example-kubernetes-test-volume2")
- ]
- )
-
+ k8s.V1Volume(
+ host_path=k8s.V1HostPathVolumeSource(path="/tmp/"),
+ name="example-kubernetes-test-volume2",
+ )
+ ],
+ ),
)
result = PodGenerator.reconcile_pods(base_pod, mutator_pod)
@@ -437,15 +385,15 @@ def test_reconcile_pods(self, mock_uuid):
expected.metadata.labels['bar'] = 'baz'
expected.spec.volumes = expected.spec.volumes or []
expected.spec.volumes.append(
- k8s.V1Volume(host_path=k8s.V1HostPathVolumeSource(path="/tmp/"),
- name="example-kubernetes-test-volume2")
+ k8s.V1Volume(
+ host_path=k8s.V1HostPathVolumeSource(path="/tmp/"), name="example-kubernetes-test-volume2"
+ )
)
base_container: k8s.V1Container = expected.spec.containers[0]
base_container.command = ['/bin/command2.sh', 'arg2']
base_container.volume_mounts = [
- k8s.V1VolumeMount(mount_path="/foo/",
- name="example-kubernetes-test-volume2")
+ k8s.V1VolumeMount(mount_path="/foo/", name="example-kubernetes-test-volume2")
]
base_container.name = "name"
expected.spec.containers[0] = base_container
@@ -464,13 +412,7 @@ def test_construct_pod(self, mock_uuid):
spec=k8s.V1PodSpec(
containers=[
k8s.V1Container(
- name='',
- resources=k8s.V1ResourceRequirements(
- limits={
- 'cpu': '1m',
- 'memory': '1G'
- }
- )
+ name='', resources=k8s.V1ResourceRequirements(limits={'cpu': '1m', 'memory': '1G'})
)
]
)
@@ -497,9 +439,7 @@ def test_construct_pod(self, mock_uuid):
expected.metadata.namespace = 'test_namespace'
expected.spec.containers[0].command = ['command']
expected.spec.containers[0].image = 'airflow_image'
- expected.spec.containers[0].resources = {'limits': {'cpu': '1m',
- 'memory': '1G'}
- }
+ expected.spec.containers[0].resources = {'limits': {'cpu': '1m', 'memory': '1G'}}
result_dict = self.k8s_client.sanitize_for_serialization(result)
expected_dict = self.k8s_client.sanitize_for_serialization(self.expected)
@@ -560,10 +500,7 @@ def test_merge_objects(self):
base_annotations = {'foo1': 'bar1'}
base_labels = {'foo1': 'bar1'}
client_annotations = {'foo2': 'bar2'}
- base_obj = k8s.V1ObjectMeta(
- annotations=base_annotations,
- labels=base_labels
- )
+ base_obj = k8s.V1ObjectMeta(annotations=base_annotations, labels=base_labels)
client_obj = k8s.V1ObjectMeta(annotations=client_annotations)
res = merge_objects(base_obj, client_obj)
client_obj.labels = base_labels
diff --git a/tests/kubernetes/test_pod_launcher.py b/tests/kubernetes/test_pod_launcher.py
index b85ac2bc5bb59..2b55a7770bdcf 100644
--- a/tests/kubernetes/test_pod_launcher.py
+++ b/tests/kubernetes/test_pod_launcher.py
@@ -24,7 +24,6 @@
class TestPodLauncher(unittest.TestCase):
-
def setUp(self):
self.mock_kube_client = mock.Mock()
self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client)
@@ -39,79 +38,77 @@ def test_read_pod_logs_retries_successfully(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.read_namespaced_pod_log.side_effect = [
BaseHTTPError('Boom'),
- mock.sentinel.logs
+ mock.sentinel.logs,
]
logs = self.pod_launcher.read_pod_logs(mock.sentinel)
self.assertEqual(mock.sentinel.logs, logs)
- self.mock_kube_client.read_namespaced_pod_log.assert_has_calls([
- mock.call(
- _preload_content=False,
- container='base',
- follow=True,
- timestamps=False,
- name=mock.sentinel.metadata.name,
- namespace=mock.sentinel.metadata.namespace
- ),
- mock.call(
- _preload_content=False,
- container='base',
- follow=True,
- timestamps=False,
- name=mock.sentinel.metadata.name,
- namespace=mock.sentinel.metadata.namespace
- )
- ])
+ self.mock_kube_client.read_namespaced_pod_log.assert_has_calls(
+ [
+ mock.call(
+ _preload_content=False,
+ container='base',
+ follow=True,
+ timestamps=False,
+ name=mock.sentinel.metadata.name,
+ namespace=mock.sentinel.metadata.namespace,
+ ),
+ mock.call(
+ _preload_content=False,
+ container='base',
+ follow=True,
+ timestamps=False,
+ name=mock.sentinel.metadata.name,
+ namespace=mock.sentinel.metadata.namespace,
+ ),
+ ]
+ )
def test_read_pod_logs_retries_fails(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.read_namespaced_pod_log.side_effect = [
BaseHTTPError('Boom'),
BaseHTTPError('Boom'),
- BaseHTTPError('Boom')
+ BaseHTTPError('Boom'),
]
- self.assertRaises(
- AirflowException,
- self.pod_launcher.read_pod_logs,
- mock.sentinel
- )
+ self.assertRaises(AirflowException, self.pod_launcher.read_pod_logs, mock.sentinel)
def test_read_pod_logs_successfully_with_tail_lines(self):
mock.sentinel.metadata = mock.MagicMock()
- self.mock_kube_client.read_namespaced_pod_log.side_effect = [
- mock.sentinel.logs
- ]
+ self.mock_kube_client.read_namespaced_pod_log.side_effect = [mock.sentinel.logs]
logs = self.pod_launcher.read_pod_logs(mock.sentinel, tail_lines=100)
self.assertEqual(mock.sentinel.logs, logs)
- self.mock_kube_client.read_namespaced_pod_log.assert_has_calls([
- mock.call(
- _preload_content=False,
- container='base',
- follow=True,
- timestamps=False,
- name=mock.sentinel.metadata.name,
- namespace=mock.sentinel.metadata.namespace,
- tail_lines=100
- ),
- ])
+ self.mock_kube_client.read_namespaced_pod_log.assert_has_calls(
+ [
+ mock.call(
+ _preload_content=False,
+ container='base',
+ follow=True,
+ timestamps=False,
+ name=mock.sentinel.metadata.name,
+ namespace=mock.sentinel.metadata.namespace,
+ tail_lines=100,
+ ),
+ ]
+ )
def test_read_pod_logs_successfully_with_since_seconds(self):
mock.sentinel.metadata = mock.MagicMock()
- self.mock_kube_client.read_namespaced_pod_log.side_effect = [
- mock.sentinel.logs
- ]
+ self.mock_kube_client.read_namespaced_pod_log.side_effect = [mock.sentinel.logs]
logs = self.pod_launcher.read_pod_logs(mock.sentinel, since_seconds=2)
self.assertEqual(mock.sentinel.logs, logs)
- self.mock_kube_client.read_namespaced_pod_log.assert_has_calls([
- mock.call(
- _preload_content=False,
- container='base',
- follow=True,
- timestamps=False,
- name=mock.sentinel.metadata.name,
- namespace=mock.sentinel.metadata.namespace,
- since_seconds=2
- ),
- ])
+ self.mock_kube_client.read_namespaced_pod_log.assert_has_calls(
+ [
+ mock.call(
+ _preload_content=False,
+ container='base',
+ follow=True,
+ timestamps=False,
+ name=mock.sentinel.metadata.name,
+ namespace=mock.sentinel.metadata.namespace,
+ since_seconds=2,
+ ),
+ ]
+ )
def test_read_pod_events_successfully_returns_events(self):
mock.sentinel.metadata = mock.MagicMock()
@@ -123,33 +120,31 @@ def test_read_pod_events_retries_successfully(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.list_namespaced_event.side_effect = [
BaseHTTPError('Boom'),
- mock.sentinel.events
+ mock.sentinel.events,
]
events = self.pod_launcher.read_pod_events(mock.sentinel)
self.assertEqual(mock.sentinel.events, events)
- self.mock_kube_client.list_namespaced_event.assert_has_calls([
- mock.call(
- namespace=mock.sentinel.metadata.namespace,
- field_selector=f"involvedObject.name={mock.sentinel.metadata.name}"
- ),
- mock.call(
- namespace=mock.sentinel.metadata.namespace,
- field_selector=f"involvedObject.name={mock.sentinel.metadata.name}"
- )
- ])
+ self.mock_kube_client.list_namespaced_event.assert_has_calls(
+ [
+ mock.call(
+ namespace=mock.sentinel.metadata.namespace,
+ field_selector=f"involvedObject.name={mock.sentinel.metadata.name}",
+ ),
+ mock.call(
+ namespace=mock.sentinel.metadata.namespace,
+ field_selector=f"involvedObject.name={mock.sentinel.metadata.name}",
+ ),
+ ]
+ )
def test_read_pod_events_retries_fails(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.list_namespaced_event.side_effect = [
BaseHTTPError('Boom'),
BaseHTTPError('Boom'),
- BaseHTTPError('Boom')
+ BaseHTTPError('Boom'),
]
- self.assertRaises(
- AirflowException,
- self.pod_launcher.read_pod_events,
- mock.sentinel
- )
+ self.assertRaises(AirflowException, self.pod_launcher.read_pod_events, mock.sentinel)
def test_read_pod_returns_logs(self):
mock.sentinel.metadata = mock.MagicMock()
@@ -161,31 +156,30 @@ def test_read_pod_retries_successfully(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.read_namespaced_pod.side_effect = [
BaseHTTPError('Boom'),
- mock.sentinel.pod_info
+ mock.sentinel.pod_info,
]
pod_info = self.pod_launcher.read_pod(mock.sentinel)
self.assertEqual(mock.sentinel.pod_info, pod_info)
- self.mock_kube_client.read_namespaced_pod.assert_has_calls([
- mock.call(mock.sentinel.metadata.name, mock.sentinel.metadata.namespace),
- mock.call(mock.sentinel.metadata.name, mock.sentinel.metadata.namespace)
- ])
+ self.mock_kube_client.read_namespaced_pod.assert_has_calls(
+ [
+ mock.call(mock.sentinel.metadata.name, mock.sentinel.metadata.namespace),
+ mock.call(mock.sentinel.metadata.name, mock.sentinel.metadata.namespace),
+ ]
+ )
def test_read_pod_retries_fails(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.read_namespaced_pod.side_effect = [
BaseHTTPError('Boom'),
BaseHTTPError('Boom'),
- BaseHTTPError('Boom')
+ BaseHTTPError('Boom'),
]
- self.assertRaises(
- AirflowException,
- self.pod_launcher.read_pod,
- mock.sentinel
- )
+ self.assertRaises(AirflowException, self.pod_launcher.read_pod, mock.sentinel)
def test_parse_log_line(self):
- timestamp, message = \
- self.pod_launcher.parse_log_line('2020-10-08T14:16:17.793417674Z Valid message\n')
+ timestamp, message = self.pod_launcher.parse_log_line(
+ '2020-10-08T14:16:17.793417674Z Valid message\n'
+ )
self.assertEqual(timestamp, '2020-10-08T14:16:17.793417674Z')
self.assertEqual(message, 'Valid message')
diff --git a/tests/kubernetes/test_refresh_config.py b/tests/kubernetes/test_refresh_config.py
index 98a671de08bee..ca3740efd2257 100644
--- a/tests/kubernetes/test_refresh_config.py
+++ b/tests/kubernetes/test_refresh_config.py
@@ -23,7 +23,6 @@
class TestRefreshKubeConfigLoader(TestCase):
-
def test_parse_timestamp_should_convert_z_timezone_to_unix_timestamp(self):
ts = _parse_timestamp("2020-01-13T13:42:20Z")
self.assertEqual(1578922940, ts)
diff --git a/tests/lineage/test_lineage.py b/tests/lineage/test_lineage.py
index dddaaaa70722d..fddf723576cf9 100644
--- a/tests/lineage/test_lineage.py
+++ b/tests/lineage/test_lineage.py
@@ -27,12 +27,8 @@
class TestLineage(unittest.TestCase):
-
def test_lineage(self):
- dag = DAG(
- dag_id='test_prepare_lineage',
- start_date=DEFAULT_DATE
- )
+ dag = DAG(dag_id='test_prepare_lineage', start_date=DEFAULT_DATE)
f1s = "/tmp/does_not_exist_1-{}"
f2s = "/tmp/does_not_exist_2-{}"
@@ -42,16 +38,17 @@ def test_lineage(self):
file3 = File(f3s)
with dag:
- op1 = DummyOperator(task_id='leave1',
- inlets=file1,
- outlets=[file2, ])
+ op1 = DummyOperator(
+ task_id='leave1',
+ inlets=file1,
+ outlets=[
+ file2,
+ ],
+ )
op2 = DummyOperator(task_id='leave2')
- op3 = DummyOperator(task_id='upstream_level_1',
- inlets=AUTO,
- outlets=file3)
+ op3 = DummyOperator(task_id='upstream_level_1', inlets=AUTO, outlets=file3)
op4 = DummyOperator(task_id='upstream_level_2')
- op5 = DummyOperator(task_id='upstream_level_3',
- inlets=["leave1", "upstream_level_1"])
+ op5 = DummyOperator(task_id='upstream_level_3', inlets=["leave1", "upstream_level_1"])
op1.set_downstream(op3)
op2.set_downstream(op3)
@@ -61,14 +58,10 @@ def test_lineage(self):
dag.clear()
# execution_date is set in the context in order to avoid creating task instances
- ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE),
- "execution_date": DEFAULT_DATE}
- ctx2 = {"ti": TI(task=op2, execution_date=DEFAULT_DATE),
- "execution_date": DEFAULT_DATE}
- ctx3 = {"ti": TI(task=op3, execution_date=DEFAULT_DATE),
- "execution_date": DEFAULT_DATE}
- ctx5 = {"ti": TI(task=op5, execution_date=DEFAULT_DATE),
- "execution_date": DEFAULT_DATE}
+ ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}
+ ctx2 = {"ti": TI(task=op2, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}
+ ctx3 = {"ti": TI(task=op3, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}
+ ctx5 = {"ti": TI(task=op5, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}
# prepare with manual inlets and outlets
op1.pre_execute(ctx1)
@@ -101,10 +94,7 @@ def test_lineage(self):
def test_lineage_render(self):
# tests that inlets / outlets are rendered if they are added
# after initialization
- dag = DAG(
- dag_id='test_lineage_render',
- start_date=DEFAULT_DATE
- )
+ dag = DAG(dag_id='test_lineage_render', start_date=DEFAULT_DATE)
with dag:
op1 = DummyOperator(task_id='task1')
@@ -116,8 +106,7 @@ def test_lineage_render(self):
op1.outlets.append(file1)
# execution_date is set in the context in order to avoid creating task instances
- ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE),
- "execution_date": DEFAULT_DATE}
+ ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}
op1.pre_execute(ctx1)
self.assertEqual(op1.inlets[0].url, f1s.format(DEFAULT_DATE))
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index 8f835afd39e0a..9232d48d4c7a0 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -55,15 +55,8 @@ def __ne__(self, other):
# Objects with circular references (for testing purposes)
-object1 = ClassWithCustomAttributes(
- attr="{{ foo }}_1",
- template_fields=["ref"]
-)
-object2 = ClassWithCustomAttributes(
- attr="{{ foo }}_2",
- ref=object1,
- template_fields=["ref"]
-)
+object1 = ClassWithCustomAttributes(attr="{{ foo }}_1", template_fields=["ref"])
+object2 = ClassWithCustomAttributes(attr="{{ foo }}_2", ref=object1, template_fields=["ref"])
setattr(object1, 'ref', object2)
@@ -98,29 +91,31 @@ class TestBaseOperator(unittest.TestCase):
),
(
# check deep nested fields can be templated
- ClassWithCustomAttributes(nested1=ClassWithCustomAttributes(att1="{{ foo }}_1",
- att2="{{ foo }}_2",
- template_fields=["att1"]),
- nested2=ClassWithCustomAttributes(att3="{{ foo }}_3",
- att4="{{ foo }}_4",
- template_fields=["att3"]),
- template_fields=["nested1"]),
+ ClassWithCustomAttributes(
+ nested1=ClassWithCustomAttributes(
+ att1="{{ foo }}_1", att2="{{ foo }}_2", template_fields=["att1"]
+ ),
+ nested2=ClassWithCustomAttributes(
+ att3="{{ foo }}_3", att4="{{ foo }}_4", template_fields=["att3"]
+ ),
+ template_fields=["nested1"],
+ ),
{"foo": "bar"},
- ClassWithCustomAttributes(nested1=ClassWithCustomAttributes(att1="bar_1",
- att2="{{ foo }}_2",
- template_fields=["att1"]),
- nested2=ClassWithCustomAttributes(att3="{{ foo }}_3",
- att4="{{ foo }}_4",
- template_fields=["att3"]),
- template_fields=["nested1"]),
+ ClassWithCustomAttributes(
+ nested1=ClassWithCustomAttributes(
+ att1="bar_1", att2="{{ foo }}_2", template_fields=["att1"]
+ ),
+ nested2=ClassWithCustomAttributes(
+ att3="{{ foo }}_3", att4="{{ foo }}_4", template_fields=["att3"]
+ ),
+ template_fields=["nested1"],
+ ),
),
(
# check null value on nested template field
- ClassWithCustomAttributes(att1=None,
- template_fields=["att1"]),
+ ClassWithCustomAttributes(att1=None, template_fields=["att1"]),
{},
- ClassWithCustomAttributes(att1=None,
- template_fields=["att1"]),
+ ClassWithCustomAttributes(att1=None, template_fields=["att1"]),
),
(
# check there is no RecursionError on circular references
@@ -214,8 +209,9 @@ def test_nested_template_fields_declared_must_exist(self):
with self.assertRaises(AttributeError) as e:
task.render_template(ClassWithCustomAttributes(template_fields=["missing_field"]), {})
- self.assertEqual("'ClassWithCustomAttributes' object has no attribute 'missing_field'",
- str(e.exception))
+ self.assertEqual(
+ "'ClassWithCustomAttributes' object has no attribute 'missing_field'", str(e.exception)
+ )
def test_jinja_invalid_expression_is_just_propagated(self):
"""Test render_template propagates Jinja invalid expression errors."""
@@ -236,9 +232,9 @@ def test_jinja_env_creation(self, mock_jinja_env):
def test_set_jinja_env_additional_option(self):
"""Test render_template given various input types."""
- with DAG("test-dag",
- start_date=DEFAULT_DATE,
- jinja_environment_kwargs={'keep_trailing_newline': True}):
+ with DAG(
+ "test-dag", start_date=DEFAULT_DATE, jinja_environment_kwargs={'keep_trailing_newline': True}
+ ):
task = DummyOperator(task_id="op1")
result = task.render_template("{{ foo }}\n\n", {"foo": "bar"})
@@ -246,9 +242,7 @@ def test_set_jinja_env_additional_option(self):
def test_override_jinja_env_option(self):
"""Test render_template given various input types."""
- with DAG("test-dag",
- start_date=DEFAULT_DATE,
- jinja_environment_kwargs={'cache_size': 50}):
+ with DAG("test-dag", start_date=DEFAULT_DATE, jinja_environment_kwargs={'cache_size': 50}):
task = DummyOperator(task_id="op1")
result = task.render_template("{{ foo }}", {"foo": "bar"})
@@ -270,9 +264,7 @@ def test_default_email_on_actions(self):
def test_email_on_actions(self):
test_task = DummyOperator(
- task_id='test_default_email_on_actions',
- email_on_retry=False,
- email_on_failure=True
+ task_id='test_default_email_on_actions', email_on_retry=False, email_on_failure=True
)
assert test_task.email_on_retry is False
assert test_task.email_on_failure is True
@@ -291,10 +283,7 @@ def test_cross_downstream(self):
def test_chain(self):
dag = DAG(dag_id='test_chain', start_date=datetime.now())
- [op1, op2, op3, op4, op5, op6] = [
- DummyOperator(task_id=f't{i}', dag=dag)
- for i in range(1, 7)
- ]
+ [op1, op2, op3, op4, op5, op6] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 7)]
chain(op1, [op2, op3], [op4, op5], op6)
self.assertCountEqual([op2, op3], op1.get_direct_relatives(upstream=False))
@@ -390,9 +379,7 @@ def test_setattr_performs_no_custom_action_at_execute_time(self):
op = CustomOp(task_id="test_task")
op_copy = op.prepare_for_execution()
- with mock.patch(
- "airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies"
- ) as method_mock:
+ with mock.patch("airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies") as method_mock:
op_copy.execute({})
assert method_mock.call_count == 0
diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py
index 9efe51ae150db..1909e6c9fa08b 100644
--- a/tests/models/test_cleartasks.py
+++ b/tests/models/test_cleartasks.py
@@ -30,7 +30,6 @@
class TestClearTasks(unittest.TestCase):
-
def setUp(self) -> None:
db.clear_db_runs()
@@ -38,8 +37,11 @@ def tearDown(self):
db.clear_db_runs()
def test_clear_task_instances(self):
- dag = DAG('test_clear_task_instances', start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=10))
+ dag = DAG(
+ 'test_clear_task_instances',
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+ )
task0 = DummyOperator(task_id='0', owner='test', dag=dag)
task1 = DummyOperator(task_id='1', owner='test', dag=dag, retries=2)
ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
@@ -54,8 +56,7 @@ def test_clear_task_instances(self):
ti0.run()
ti1.run()
with create_session() as session:
- qry = session.query(TI).filter(
- TI.dag_id == dag.dag_id).all()
+ qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
clear_task_instances(qry, session, dag=dag)
ti0.refresh_from_db()
@@ -67,8 +68,11 @@ def test_clear_task_instances(self):
self.assertEqual(ti1.max_tries, 3)
def test_clear_task_instances_without_task(self):
- dag = DAG('test_clear_task_instances_without_task', start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=10))
+ dag = DAG(
+ 'test_clear_task_instances_without_task',
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+ )
task0 = DummyOperator(task_id='task0', owner='test', dag=dag)
task1 = DummyOperator(task_id='task1', owner='test', dag=dag, retries=2)
ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
@@ -89,8 +93,7 @@ def test_clear_task_instances_without_task(self):
self.assertFalse(dag.has_task(task1.task_id))
with create_session() as session:
- qry = session.query(TI).filter(
- TI.dag_id == dag.dag_id).all()
+ qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
clear_task_instances(qry, session)
# When dag is None, max_tries will be maximum of original max_tries or try_number.
@@ -103,8 +106,11 @@ def test_clear_task_instances_without_task(self):
self.assertEqual(ti1.max_tries, 2)
def test_clear_task_instances_without_dag(self):
- dag = DAG('test_clear_task_instances_without_dag', start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=10))
+ dag = DAG(
+ 'test_clear_task_instances_without_dag',
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+ )
task0 = DummyOperator(task_id='task_0', owner='test', dag=dag)
task1 = DummyOperator(task_id='task_1', owner='test', dag=dag, retries=2)
ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
@@ -120,8 +126,7 @@ def test_clear_task_instances_without_dag(self):
ti1.run()
with create_session() as session:
- qry = session.query(TI).filter(
- TI.dag_id == dag.dag_id).all()
+ qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
clear_task_instances(qry, session)
# When dag is None, max_tries will be maximum of original max_tries or try_number.
@@ -134,8 +139,9 @@ def test_clear_task_instances_without_dag(self):
self.assertEqual(ti1.max_tries, 2)
def test_dag_clear(self):
- dag = DAG('test_dag_clear', start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=10))
+ dag = DAG(
+ 'test_dag_clear', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)
+ )
task0 = DummyOperator(task_id='test_dag_clear_task_0', owner='test', dag=dag)
ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
@@ -155,8 +161,7 @@ def test_dag_clear(self):
self.assertEqual(ti0.state, State.NONE)
self.assertEqual(ti0.max_tries, 1)
- task1 = DummyOperator(task_id='test_dag_clear_task_1', owner='test',
- dag=dag, retries=2)
+ task1 = DummyOperator(task_id='test_dag_clear_task_1', owner='test', dag=dag, retries=2)
ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
self.assertEqual(ti1.max_tries, 2)
ti1.try_number = 1
@@ -181,11 +186,15 @@ def test_dags_clear(self):
dags, tis = [], []
num_of_dags = 5
for i in range(num_of_dags):
- dag = DAG('test_dag_clear_' + str(i), start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=10))
- ti = TI(task=DummyOperator(task_id='test_task_clear_' + str(i), owner='test',
- dag=dag),
- execution_date=DEFAULT_DATE)
+ dag = DAG(
+ 'test_dag_clear_' + str(i),
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+ )
+ ti = TI(
+ task=DummyOperator(task_id='test_task_clear_' + str(i), owner='test', dag=dag),
+ execution_date=DEFAULT_DATE,
+ )
dag.create_dagrun(
execution_date=ti.execution_date,
@@ -227,6 +236,7 @@ def test_dags_clear(self):
# test only_failed
from random import randint
+
failed_dag_idx = randint(0, len(tis) - 1)
tis[failed_dag_idx].state = State.FAILED
session.merge(tis[failed_dag_idx])
@@ -246,8 +256,11 @@ def test_dags_clear(self):
self.assertEqual(tis[i].max_tries, 2)
def test_operator_clear(self):
- dag = DAG('test_operator_clear', start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=10))
+ dag = DAG(
+ 'test_operator_clear',
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+ )
op1 = DummyOperator(task_id='bash_op', owner='test', dag=dag)
op2 = DummyOperator(task_id='dummy_op', owner='test', dag=dag, retries=1)
diff --git a/tests/models/test_connection.py b/tests/models/test_connection.py
index df3ace8fdb80b..fbd275d01164c 100644
--- a/tests/models/test_connection.py
+++ b/tests/models/test_connection.py
@@ -126,7 +126,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
),
UriTestCaseConfig(
test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?'
- 'extra1=a%20value&extra2=%2Fpath%2F',
+ 'extra1=a%20value&extra2=%2Fpath%2F',
test_conn_attributes=dict(
conn_type='scheme',
host='host/location',
@@ -134,9 +134,9 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
login='user',
password='password',
port=1234,
- extra_dejson={'extra1': 'a value', 'extra2': '/path/'}
+ extra_dejson={'extra1': 'a value', 'extra2': '/path/'},
),
- description='with extras'
+ description='with extras',
),
UriTestCaseConfig(
test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?extra1=a%20value&extra2=',
@@ -147,13 +147,13 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
login='user',
password='password',
port=1234,
- extra_dejson={'extra1': 'a value', 'extra2': ''}
+ extra_dejson={'extra1': 'a value', 'extra2': ''},
),
- description='with empty extras'
+ description='with empty extras',
),
UriTestCaseConfig(
test_conn_uri='scheme://user:password@host%2Flocation%3Ax%3Ay:1234/schema?'
- 'extra1=a%20value&extra2=%2Fpath%2F',
+ 'extra1=a%20value&extra2=%2Fpath%2F',
test_conn_attributes=dict(
conn_type='scheme',
host='host/location:x:y',
@@ -163,7 +163,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
port=1234,
extra_dejson={'extra1': 'a value', 'extra2': '/path/'},
),
- description='with colon in hostname'
+ description='with colon in hostname',
),
UriTestCaseConfig(
test_conn_uri='scheme://user:password%20with%20space@host%2Flocation%3Ax%3Ay:1234/schema',
@@ -175,7 +175,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
password='password with space',
port=1234,
),
- description='with encoded password'
+ description='with encoded password',
),
UriTestCaseConfig(
test_conn_uri='scheme://domain%2Fuser:password@host%2Flocation%3Ax%3Ay:1234/schema',
@@ -199,7 +199,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
password='password with space',
port=1234,
),
- description='with encoded schema'
+ description='with encoded schema',
),
UriTestCaseConfig(
test_conn_uri='scheme://user:password%20with%20space@host:1234',
@@ -211,7 +211,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
password='password with space',
port=1234,
),
- description='no schema'
+ description='no schema',
),
UriTestCaseConfig(
test_conn_uri='google-cloud-platform://?extra__google_cloud_platform__key_'
@@ -229,7 +229,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
extra__google_cloud_platform__key_path='/keys/key.json',
extra__google_cloud_platform__scope='https://www.googleapis.com/auth/cloud-platform',
extra__google_cloud_platform__project='airflow',
- )
+ ),
),
description='with underscore',
),
@@ -243,7 +243,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
password=None,
port=1234,
),
- description='without auth info'
+ description='without auth info',
),
UriTestCaseConfig(
test_conn_uri='scheme://%2FTmP%2F:1234',
@@ -435,9 +435,12 @@ def test_connection_from_with_auth_info(self, uri, uri_parts):
self.assertEqual(connection.port, uri_parts.port)
self.assertEqual(connection.schema, uri_parts.schema)
- @mock.patch.dict('os.environ', {
- 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
- })
+ @mock.patch.dict(
+ 'os.environ',
+ {
+ 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
+ },
+ )
def test_using_env_var(self):
conn = SqliteHook.get_connection(conn_id='test_uri')
self.assertEqual('ec2.compute.com', conn.host)
@@ -446,9 +449,12 @@ def test_using_env_var(self):
self.assertEqual('password', conn.password)
self.assertEqual(5432, conn.port)
- @mock.patch.dict('os.environ', {
- 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
- })
+ @mock.patch.dict(
+ 'os.environ',
+ {
+ 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
+ },
+ )
def test_using_unix_socket_env_var(self):
conn = SqliteHook.get_connection(conn_id='test_uri_no_creds')
self.assertEqual('ec2.compute.com', conn.host)
@@ -458,9 +464,14 @@ def test_using_unix_socket_env_var(self):
self.assertIsNone(conn.port)
def test_param_setup(self):
- conn = Connection(conn_id='local_mysql', conn_type='mysql',
- host='localhost', login='airflow',
- password='airflow', schema='airflow')
+ conn = Connection(
+ conn_id='local_mysql',
+ conn_type='mysql',
+ host='localhost',
+ login='airflow',
+ password='airflow',
+ schema='airflow',
+ )
self.assertEqual('localhost', conn.host)
self.assertEqual('airflow', conn.schema)
self.assertEqual('airflow', conn.login)
@@ -471,9 +482,12 @@ def test_env_var_priority(self):
conn = SqliteHook.get_connection(conn_id='airflow_db')
self.assertNotEqual('ec2.compute.com', conn.host)
- with mock.patch.dict('os.environ', {
- 'AIRFLOW_CONN_AIRFLOW_DB': 'postgres://username:password@ec2.compute.com:5432/the_database',
- }):
+ with mock.patch.dict(
+ 'os.environ',
+ {
+ 'AIRFLOW_CONN_AIRFLOW_DB': 'postgres://username:password@ec2.compute.com:5432/the_database',
+ },
+ ):
conn = SqliteHook.get_connection(conn_id='airflow_db')
self.assertEqual('ec2.compute.com', conn.host)
self.assertEqual('the_database', conn.schema)
@@ -481,10 +495,13 @@ def test_env_var_priority(self):
self.assertEqual('password', conn.password)
self.assertEqual(5432, conn.port)
- @mock.patch.dict('os.environ', {
- 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
- 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
- })
+ @mock.patch.dict(
+ 'os.environ',
+ {
+ 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
+ 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
+ },
+ )
def test_dbapi_get_uri(self):
conn = BaseHook.get_connection(conn_id='test_uri')
hook = conn.get_hook()
@@ -493,10 +510,13 @@ def test_dbapi_get_uri(self):
hook2 = conn2.get_hook()
self.assertEqual('postgres://ec2.compute.com/the_database', hook2.get_uri())
- @mock.patch.dict('os.environ', {
- 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
- 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
- })
+ @mock.patch.dict(
+ 'os.environ',
+ {
+ 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
+ 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
+ },
+ )
def test_dbapi_get_sqlalchemy_engine(self):
conn = BaseHook.get_connection(conn_id='test_uri')
hook = conn.get_hook()
@@ -504,10 +524,13 @@ def test_dbapi_get_sqlalchemy_engine(self):
self.assertIsInstance(engine, sqlalchemy.engine.Engine)
self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', str(engine.url))
- @mock.patch.dict('os.environ', {
- 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
- 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
- })
+ @mock.patch.dict(
+ 'os.environ',
+ {
+ 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
+ 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
+ },
+ )
def test_get_connections_env_var(self):
conns = SqliteHook.get_connections(conn_id='test_uri')
assert len(conns) == 1
@@ -523,7 +546,7 @@ def test_connection_mixed(self):
re.escape(
"You must create an object using the URI or individual values (conn_type, host, login, "
"password, schema, port or extra).You can't mix these two ways to create this object."
- )
+ ),
):
Connection(conn_id="TEST_ID", uri="mysql://", schema="AAA")
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 7fd0f2dbab695..7b09ba89ddba4 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -63,7 +63,6 @@
class TestDag(unittest.TestCase):
-
def setUp(self) -> None:
clear_db_runs()
clear_db_dags()
@@ -78,15 +77,9 @@ def tearDown(self) -> None:
@staticmethod
def _clean_up(dag_id: str):
with create_session() as session:
- session.query(DagRun).filter(
- DagRun.dag_id == dag_id).delete(
- synchronize_session=False)
- session.query(TI).filter(
- TI.dag_id == dag_id).delete(
- synchronize_session=False)
- session.query(TaskFail).filter(
- TaskFail.dag_id == dag_id).delete(
- synchronize_session=False)
+ session.query(DagRun).filter(DagRun.dag_id == dag_id).delete(synchronize_session=False)
+ session.query(TI).filter(TI.dag_id == dag_id).delete(synchronize_session=False)
+ session.query(TaskFail).filter(TaskFail.dag_id == dag_id).delete(synchronize_session=False)
@staticmethod
def _occur_before(a, b, list_):
@@ -122,9 +115,7 @@ def test_params_passed_and_params_in_default_args_no_override(self):
params1 = {'parameter1': 1}
params2 = {'parameter2': 2}
- dag = models.DAG('test-dag',
- default_args={'params': params1},
- params=params2)
+ dag = models.DAG('test-dag', default_args={'params': params1}, params=params2)
params_combined = params1.copy()
params_combined.update(params2)
@@ -134,43 +125,29 @@ def test_dag_invalid_default_view(self):
"""
Test invalid `default_view` of DAG initialization
"""
- with self.assertRaisesRegex(AirflowException,
- 'Invalid values of dag.default_view: only support'):
- models.DAG(
- dag_id='test-invalid-default_view',
- default_view='airflow'
- )
+ with self.assertRaisesRegex(AirflowException, 'Invalid values of dag.default_view: only support'):
+ models.DAG(dag_id='test-invalid-default_view', default_view='airflow')
def test_dag_default_view_default_value(self):
"""
Test `default_view` default value of DAG initialization
"""
- dag = models.DAG(
- dag_id='test-default_default_view'
- )
- self.assertEqual(conf.get('webserver', 'dag_default_view').lower(),
- dag.default_view)
+ dag = models.DAG(dag_id='test-default_default_view')
+ self.assertEqual(conf.get('webserver', 'dag_default_view').lower(), dag.default_view)
def test_dag_invalid_orientation(self):
"""
Test invalid `orientation` of DAG initialization
"""
- with self.assertRaisesRegex(AirflowException,
- 'Invalid values of dag.orientation: only support'):
- models.DAG(
- dag_id='test-invalid-orientation',
- orientation='airflow'
- )
+ with self.assertRaisesRegex(AirflowException, 'Invalid values of dag.orientation: only support'):
+ models.DAG(dag_id='test-invalid-orientation', orientation='airflow')
def test_dag_orientation_default_value(self):
"""
Test `orientation` default value of DAG initialization
"""
- dag = models.DAG(
- dag_id='test-default_orientation'
- )
- self.assertEqual(conf.get('webserver', 'dag_orientation'),
- dag.orientation)
+ dag = models.DAG(dag_id='test-default_orientation')
+ self.assertEqual(conf.get('webserver', 'dag_orientation'), dag.orientation)
def test_dag_as_context_manager(self):
"""
@@ -178,14 +155,8 @@ def test_dag_as_context_manager(self):
When used as a context manager, Operators are automatically added to
the DAG (unless they specify a different DAG)
"""
- dag = DAG(
- 'dag',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
- dag2 = DAG(
- 'dag2',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner2'})
+ dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
+ dag2 = DAG('dag2', start_date=DEFAULT_DATE, default_args={'owner': 'owner2'})
with dag:
op1 = DummyOperator(task_id='op1')
@@ -263,10 +234,7 @@ def test_dag_topological_sort_include_subdag_tasks(self):
self.assertTrue(self._occur_before('b_child', 'b_parent', topological_list))
def test_dag_topological_sort1(self):
- dag = DAG(
- 'dag',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
# A -> B
# A -> C -> D
@@ -292,10 +260,7 @@ def test_dag_topological_sort1(self):
self.assertTrue(topological_list[3] == op1)
def test_dag_topological_sort2(self):
- dag = DAG(
- 'dag',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
# C -> (A u B) -> D
# C -> E
@@ -332,10 +297,7 @@ def test_dag_topological_sort2(self):
self.assertTrue(topological_list[4] == op3)
def test_dag_topological_sort_dag_without_tasks(self):
- dag = DAG(
- 'dag',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
self.assertEqual((), dag.topological_sort())
@@ -356,8 +318,9 @@ def test_dag_start_date_propagates_to_end_date(self):
An explicit check that the `tzinfo` attributes for both are the same is an extra check.
"""
- dag = DAG('DAG', default_args={'start_date': '2019-06-05T00:00:00+05:00',
- 'end_date': '2019-06-05T00:00:00'})
+ dag = DAG(
+ 'DAG', default_args={'start_date': '2019-06-05T00:00:00+05:00', 'end_date': '2019-06-05T00:00:00'}
+ )
self.assertEqual(dag.default_args['start_date'], dag.default_args['end_date'])
self.assertEqual(dag.default_args['start_date'].tzinfo, dag.default_args['end_date'].tzinfo)
@@ -383,12 +346,10 @@ def test_dag_task_priority_weight_total(self):
# Fully connected parallel tasks. i.e. every task at each parallel
# stage is dependent on every task in the previous stage.
# Default weight should be calculated using downstream descendants
- with DAG('dag', start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'}) as dag:
+ with DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) as dag:
pipeline = [
- [DummyOperator(
- task_id=f'stage{i}.{j}', priority_weight=weight)
- for j in range(0, width)] for i in range(0, depth)
+ [DummyOperator(task_id=f'stage{i}.{j}', priority_weight=weight) for j in range(0, width)]
+ for i in range(0, depth)
]
for i, stage in enumerate(pipeline):
if i == 0:
@@ -412,13 +373,17 @@ def test_dag_task_priority_weight_total_using_upstream(self):
width = 5
depth = 5
pattern = re.compile('stage(\\d*).(\\d*)')
- with DAG('dag', start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'}) as dag:
+ with DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) as dag:
pipeline = [
- [DummyOperator(
- task_id=f'stage{i}.{j}', priority_weight=weight,
- weight_rule=WeightRule.UPSTREAM)
- for j in range(0, width)] for i in range(0, depth)
+ [
+ DummyOperator(
+ task_id=f'stage{i}.{j}',
+ priority_weight=weight,
+ weight_rule=WeightRule.UPSTREAM,
+ )
+ for j in range(0, width)
+ ]
+ for i in range(0, depth)
]
for i, stage in enumerate(pipeline):
if i == 0:
@@ -441,13 +406,17 @@ def test_dag_task_priority_weight_total_using_absolute(self):
weight = 10
width = 5
depth = 5
- with DAG('dag', start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'}) as dag:
+ with DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) as dag:
pipeline = [
- [DummyOperator(
- task_id=f'stage{i}.{j}', priority_weight=weight,
- weight_rule=WeightRule.ABSOLUTE)
- for j in range(0, width)] for i in range(0, depth)
+ [
+ DummyOperator(
+ task_id=f'stage{i}.{j}',
+ priority_weight=weight,
+ weight_rule=WeightRule.ABSOLUTE,
+ )
+ for j in range(0, width)
+ ]
+ for i in range(0, depth)
]
for i, stage in enumerate(pipeline):
if i == 0:
@@ -490,40 +459,29 @@ def test_get_num_task_instances(self):
session.merge(ti4)
session.commit()
+ self.assertEqual(0, DAG.get_num_task_instances(test_dag_id, ['fakename'], session=session))
+ self.assertEqual(4, DAG.get_num_task_instances(test_dag_id, [test_task_id], session=session))
self.assertEqual(
- 0,
- DAG.get_num_task_instances(test_dag_id, ['fakename'], session=session)
- )
- self.assertEqual(
- 4,
- DAG.get_num_task_instances(test_dag_id, [test_task_id], session=session)
- )
- self.assertEqual(
- 4,
- DAG.get_num_task_instances(
- test_dag_id, ['fakename', test_task_id], session=session)
+ 4, DAG.get_num_task_instances(test_dag_id, ['fakename', test_task_id], session=session)
)
self.assertEqual(
- 1,
- DAG.get_num_task_instances(
- test_dag_id, [test_task_id], states=[None], session=session)
+ 1, DAG.get_num_task_instances(test_dag_id, [test_task_id], states=[None], session=session)
)
self.assertEqual(
2,
- DAG.get_num_task_instances(
- test_dag_id, [test_task_id], states=[State.RUNNING], session=session)
+ DAG.get_num_task_instances(test_dag_id, [test_task_id], states=[State.RUNNING], session=session),
)
self.assertEqual(
3,
DAG.get_num_task_instances(
- test_dag_id, [test_task_id],
- states=[None, State.RUNNING], session=session)
+ test_dag_id, [test_task_id], states=[None, State.RUNNING], session=session
+ ),
)
self.assertEqual(
4,
DAG.get_num_task_instances(
- test_dag_id, [test_task_id],
- states=[None, State.QUEUED, State.RUNNING], session=session)
+ test_dag_id, [test_task_id], states=[None, State.QUEUED, State.RUNNING], session=session
+ ),
)
session.close()
@@ -578,10 +536,10 @@ def test_following_previous_schedule(self):
Make sure DST transitions are properly observed
"""
local_tz = pendulum.timezone('Europe/Zurich')
- start = local_tz.convert(datetime.datetime(2018, 10, 28, 2, 55),
- dst_rule=pendulum.PRE_TRANSITION)
- self.assertEqual(start.isoformat(), "2018-10-28T02:55:00+02:00",
- "Pre-condition: start date is in DST")
+ start = local_tz.convert(datetime.datetime(2018, 10, 28, 2, 55), dst_rule=pendulum.PRE_TRANSITION)
+ self.assertEqual(
+ start.isoformat(), "2018-10-28T02:55:00+02:00", "Pre-condition: start date is in DST"
+ )
utc = timezone.convert_to_utc(start)
@@ -608,8 +566,7 @@ def test_following_previous_schedule_daily_dag_cest_to_cet(self):
Make sure DST transitions are properly observed
"""
local_tz = pendulum.timezone('Europe/Zurich')
- start = local_tz.convert(datetime.datetime(2018, 10, 27, 3),
- dst_rule=pendulum.PRE_TRANSITION)
+ start = local_tz.convert(datetime.datetime(2018, 10, 27, 3), dst_rule=pendulum.PRE_TRANSITION)
utc = timezone.convert_to_utc(start)
@@ -638,8 +595,7 @@ def test_following_previous_schedule_daily_dag_cet_to_cest(self):
Make sure DST transitions are properly observed
"""
local_tz = pendulum.timezone('Europe/Zurich')
- start = local_tz.convert(datetime.datetime(2018, 3, 25, 2),
- dst_rule=pendulum.PRE_TRANSITION)
+ start = local_tz.convert(datetime.datetime(2018, 3, 25, 2), dst_rule=pendulum.PRE_TRANSITION)
utc = timezone.convert_to_utc(start)
@@ -669,12 +625,8 @@ def test_following_schedule_relativedelta(self):
"""
dag_id = "test_schedule_dag_relativedelta"
delta = relativedelta(hours=+1)
- dag = DAG(dag_id=dag_id,
- schedule_interval=delta)
- dag.add_task(BaseOperator(
- task_id="faketastic",
- owner='Also fake',
- start_date=TEST_DATE))
+ dag = DAG(dag_id=dag_id, schedule_interval=delta)
+ dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE))
_next = dag.following_schedule(TEST_DATE)
self.assertEqual(_next.isoformat(), "2015-01-02T01:00:00+00:00")
@@ -687,22 +639,21 @@ def test_dagtag_repr(self):
dag = DAG('dag-test-dagtag', start_date=DEFAULT_DATE, tags=['tag-1', 'tag-2'])
dag.sync_to_db()
with create_session() as session:
- self.assertEqual({'tag-1', 'tag-2'},
- {repr(t) for t in session.query(DagTag).filter(
- DagTag.dag_id == 'dag-test-dagtag').all()})
+ self.assertEqual(
+ {'tag-1', 'tag-2'},
+ {repr(t) for t in session.query(DagTag).filter(DagTag.dag_id == 'dag-test-dagtag').all()},
+ )
def test_bulk_write_to_db(self):
clear_db_dags()
- dags = [
- DAG(f'dag-bulk-sync-{i}', start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(0, 4)
- ]
+ dags = [DAG(f'dag-bulk-sync-{i}', start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(0, 4)]
with assert_queries_count(5):
DAG.bulk_write_to_db(dags)
with create_session() as session:
self.assertEqual(
{'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'},
- {row[0] for row in session.query(DagModel.dag_id).all()}
+ {row[0] for row in session.query(DagModel.dag_id).all()},
)
self.assertEqual(
{
@@ -711,7 +662,7 @@ def test_bulk_write_to_db(self):
('dag-bulk-sync-2', 'test-dag'),
('dag-bulk-sync-3', 'test-dag'),
},
- set(session.query(DagTag.dag_id, DagTag.name).all())
+ set(session.query(DagTag.dag_id, DagTag.name).all()),
)
# Re-sync should do fewer queries
with assert_queries_count(3):
@@ -726,7 +677,7 @@ def test_bulk_write_to_db(self):
with create_session() as session:
self.assertEqual(
{'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'},
- {row[0] for row in session.query(DagModel.dag_id).all()}
+ {row[0] for row in session.query(DagModel.dag_id).all()},
)
self.assertEqual(
{
@@ -739,7 +690,7 @@ def test_bulk_write_to_db(self):
('dag-bulk-sync-3', 'test-dag'),
('dag-bulk-sync-3', 'test-dag2'),
},
- set(session.query(DagTag.dag_id, DagTag.name).all())
+ set(session.query(DagTag.dag_id, DagTag.name).all()),
)
# Removing tags
for dag in dags:
@@ -749,7 +700,7 @@ def test_bulk_write_to_db(self):
with create_session() as session:
self.assertEqual(
{'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'},
- {row[0] for row in session.query(DagModel.dag_id).all()}
+ {row[0] for row in session.query(DagModel.dag_id).all()},
)
self.assertEqual(
{
@@ -758,7 +709,7 @@ def test_bulk_write_to_db(self):
('dag-bulk-sync-2', 'test-dag2'),
('dag-bulk-sync-3', 'test-dag2'),
},
- set(session.query(DagTag.dag_id, DagTag.name).all())
+ set(session.query(DagTag.dag_id, DagTag.name).all()),
)
def test_bulk_write_to_db_max_active_runs(self):
@@ -766,15 +717,10 @@ def test_bulk_write_to_db_max_active_runs(self):
Test that DagModel.next_dagrun_create_after is set to NULL when the dag cannot be created due to max
active runs being hit.
"""
- dag = DAG(
- dag_id='test_scheduler_verify_max_active_runs',
- start_date=DEFAULT_DATE)
+ dag = DAG(dag_id='test_scheduler_verify_max_active_runs', start_date=DEFAULT_DATE)
dag.max_active_runs = 1
- DummyOperator(
- task_id='dummy',
- dag=dag,
- owner='airflow')
+ DummyOperator(task_id='dummy', dag=dag, owner='airflow')
session = settings.Session()
dag.clear()
@@ -807,15 +753,14 @@ def test_sync_to_db(self):
)
with dag:
DummyOperator(task_id='task', owner='owner1')
- subdag = DAG('dag.subtask', start_date=DEFAULT_DATE, )
+ subdag = DAG(
+ 'dag.subtask',
+ start_date=DEFAULT_DATE,
+ )
# parent_dag and is_subdag were set by DagBag. We don't use DagBag, so these values are not set.
subdag.parent_dag = dag
subdag.is_subdag = True
- SubDagOperator(
- task_id='subtask',
- owner='owner2',
- subdag=subdag
- )
+ SubDagOperator(task_id='subtask', owner='owner2', subdag=subdag)
session = settings.Session()
dag.sync_to_db(session=session)
@@ -823,8 +768,7 @@ def test_sync_to_db(self):
self.assertEqual(set(orm_dag.owners.split(', ')), {'owner1', 'owner2'})
self.assertTrue(orm_dag.is_active)
self.assertIsNotNone(orm_dag.default_view)
- self.assertEqual(orm_dag.default_view,
- conf.get('webserver', 'dag_default_view').lower())
+ self.assertEqual(orm_dag.default_view, conf.get('webserver', 'dag_default_view').lower())
self.assertEqual(orm_dag.safe_dag_id, 'dag')
orm_subdag = session.query(DagModel).filter(DagModel.dag_id == 'dag.subtask').one()
@@ -848,7 +792,7 @@ def test_sync_to_db_default_view(self):
subdag=DAG(
'dag.subtask',
start_date=DEFAULT_DATE,
- )
+ ),
)
session = settings.Session()
dag.sync_to_db(session=session)
@@ -877,10 +821,7 @@ def test_is_paused_subdag(self, session):
)
with dag:
- SubDagOperator(
- task_id='subdag',
- subdag=subdag
- )
+ SubDagOperator(task_id='subdag', subdag=subdag)
# parent_dag and is_subdag were set by DagBag. We don't use DagBag, so these values are not set.
subdag.parent_dag = dag
@@ -892,63 +833,70 @@ def test_is_paused_subdag(self, session):
dag.sync_to_db(session=session)
- unpaused_dags = session.query(
- DagModel.dag_id, DagModel.is_paused
- ).filter(
- DagModel.dag_id.in_([subdag_id, dag_id]),
- ).all()
+ unpaused_dags = (
+ session.query(DagModel.dag_id, DagModel.is_paused)
+ .filter(
+ DagModel.dag_id.in_([subdag_id, dag_id]),
+ )
+ .all()
+ )
- self.assertEqual({
- (dag_id, False),
- (subdag_id, False),
- }, set(unpaused_dags))
+ self.assertEqual(
+ {
+ (dag_id, False),
+ (subdag_id, False),
+ },
+ set(unpaused_dags),
+ )
DagModel.get_dagmodel(dag.dag_id).set_is_paused(is_paused=True, including_subdags=False)
- paused_dags = session.query(
- DagModel.dag_id, DagModel.is_paused
- ).filter(
- DagModel.dag_id.in_([subdag_id, dag_id]),
- ).all()
+ paused_dags = (
+ session.query(DagModel.dag_id, DagModel.is_paused)
+ .filter(
+ DagModel.dag_id.in_([subdag_id, dag_id]),
+ )
+ .all()
+ )
- self.assertEqual({
- (dag_id, True),
- (subdag_id, False),
- }, set(paused_dags))
+ self.assertEqual(
+ {
+ (dag_id, True),
+ (subdag_id, False),
+ },
+ set(paused_dags),
+ )
DagModel.get_dagmodel(dag.dag_id).set_is_paused(is_paused=True)
- paused_dags = session.query(
- DagModel.dag_id, DagModel.is_paused
- ).filter(
- DagModel.dag_id.in_([subdag_id, dag_id]),
- ).all()
+ paused_dags = (
+ session.query(DagModel.dag_id, DagModel.is_paused)
+ .filter(
+ DagModel.dag_id.in_([subdag_id, dag_id]),
+ )
+ .all()
+ )
- self.assertEqual({
- (dag_id, True),
- (subdag_id, True),
- }, set(paused_dags))
+ self.assertEqual(
+ {
+ (dag_id, True),
+ (subdag_id, True),
+ },
+ set(paused_dags),
+ )
def test_existing_dag_is_paused_upon_creation(self):
- dag = DAG(
- 'dag_paused'
- )
+ dag = DAG('dag_paused')
dag.sync_to_db()
self.assertFalse(dag.get_is_paused())
- dag = DAG(
- 'dag_paused',
- is_paused_upon_creation=True
- )
+ dag = DAG('dag_paused', is_paused_upon_creation=True)
dag.sync_to_db()
# Since the dag existed before, it should not follow the pause flag upon creation
self.assertFalse(dag.get_is_paused())
def test_new_dag_is_paused_upon_creation(self):
- dag = DAG(
- 'new_nonexisting_dag',
- is_paused_upon_creation=True
- )
+ dag = DAG('new_nonexisting_dag', is_paused_upon_creation=True)
session = settings.Session()
dag.sync_to_db(session=session)
@@ -1045,9 +993,7 @@ def test_tree_view(self):
def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self):
"""Verify tasks with Duplicate task_id raises error"""
- with self.assertRaisesRegex(
- DuplicateTaskIdFound, "Task id 't1' has already been added to the DAG"
- ):
+ with self.assertRaisesRegex(DuplicateTaskIdFound, "Task id 't1' has already been added to the DAG"):
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
op1 = DummyOperator(task_id="t1")
op2 = BashOperator(task_id="t1", bash_command="sleep 1")
@@ -1057,9 +1003,7 @@ def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self):
def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self):
"""Verify tasks with Duplicate task_id raises error"""
- with self.assertRaisesRegex(
- DuplicateTaskIdFound, "Task id 't1' has already been added to the DAG"
- ):
+ with self.assertRaisesRegex(DuplicateTaskIdFound, "Task id 't1' has already been added to the DAG"):
dag = DAG("test_dag", start_date=DEFAULT_DATE)
op1 = DummyOperator(task_id="t1", dag=dag)
op2 = DummyOperator(task_id="t1", dag=dag)
@@ -1096,10 +1040,7 @@ def test_schedule_dag_no_previous_runs(self):
"""
dag_id = "test_schedule_dag_no_previous_runs"
dag = DAG(dag_id=dag_id)
- dag.add_task(BaseOperator(
- task_id="faketastic",
- owner='Also fake',
- start_date=TEST_DATE))
+ dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE))
dag_run = dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
@@ -1113,8 +1054,7 @@ def test_schedule_dag_no_previous_runs(self):
self.assertEqual(
TEST_DATE,
dag_run.execution_date,
- msg='dag_run.execution_date did not match expectation: {}'
- .format(dag_run.execution_date)
+ msg=f'dag_run.execution_date did not match expectation: {dag_run.execution_date}',
)
self.assertEqual(State.RUNNING, dag_run.state)
self.assertFalse(dag_run.external_trigger)
@@ -1133,12 +1073,10 @@ def test_dag_handle_callback_crash(self, mock_stats):
dag_id=dag_id,
# callback with invalid signature should not cause crashes
on_success_callback=lambda: 1,
- on_failure_callback=mock_callback_with_exception)
+ on_failure_callback=mock_callback_with_exception,
+ )
when = TEST_DATE
- dag.add_task(BaseOperator(
- task_id="faketastic",
- owner='Also fake',
- start_date=when))
+ dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=when))
dag_run = dag.create_dagrun(State.RUNNING, when, run_type=DagRunType.MANUAL)
# should not raise any exception
@@ -1157,18 +1095,15 @@ def test_next_dagrun_after_fake_scheduled_previous(self):
"""
delta = datetime.timedelta(hours=1)
dag_id = "test_schedule_dag_fake_scheduled_previous"
- dag = DAG(dag_id=dag_id,
- schedule_interval=delta,
- start_date=DEFAULT_DATE)
- dag.add_task(BaseOperator(
- task_id="faketastic",
- owner='Also fake',
- start_date=DEFAULT_DATE))
-
- dag.create_dagrun(run_type=DagRunType.SCHEDULED,
- execution_date=DEFAULT_DATE,
- state=State.SUCCESS,
- external_trigger=True)
+ dag = DAG(dag_id=dag_id, schedule_interval=delta, start_date=DEFAULT_DATE)
+ dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=DEFAULT_DATE))
+
+ dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ execution_date=DEFAULT_DATE,
+ state=State.SUCCESS,
+ external_trigger=True,
+ )
dag.sync_to_db()
with create_session() as session:
model = session.query(DagModel).get((dag.dag_id,))
@@ -1189,17 +1124,12 @@ def test_schedule_dag_once(self):
dag = DAG(dag_id=dag_id)
dag.schedule_interval = '@once'
self.assertEqual(dag.normalized_schedule_interval, None)
- dag.add_task(BaseOperator(
- task_id="faketastic",
- owner='Also fake',
- start_date=TEST_DATE))
+ dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE))
# Sync once to create the DagModel
dag.sync_to_db()
- dag.create_dagrun(run_type=DagRunType.SCHEDULED,
- execution_date=TEST_DATE,
- state=State.SUCCESS)
+ dag.create_dagrun(run_type=DagRunType.SCHEDULED, execution_date=TEST_DATE, state=State.SUCCESS)
# Then sync again after creating the dag run -- this should update next_dagrun
dag.sync_to_db()
@@ -1217,10 +1147,7 @@ def test_fractional_seconds(self):
dag_id = "test_fractional_seconds"
dag = DAG(dag_id=dag_id)
dag.schedule_interval = '@once'
- dag.add_task(BaseOperator(
- task_id="faketastic",
- owner='Also fake',
- start_date=TEST_DATE))
+ dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE))
start_date = timezone.utcnow()
@@ -1229,15 +1156,13 @@ def test_fractional_seconds(self):
execution_date=start_date,
start_date=start_date,
state=State.RUNNING,
- external_trigger=False
+ external_trigger=False,
)
run.refresh_from_db()
- self.assertEqual(start_date, run.execution_date,
- "dag run execution_date loses precision")
- self.assertEqual(start_date, run.start_date,
- "dag run start_date loses precision ")
+ self.assertEqual(start_date, run.execution_date, "dag run execution_date loses precision")
+ self.assertEqual(start_date, run.start_date, "dag run start_date loses precision ")
self._clean_up(dag_id)
def test_pickling(self):
@@ -1262,8 +1187,7 @@ class DAGsubclass(DAG):
dag_diff_name = DAG(test_dag_id + '_neq', default_args=args)
dag_subclass = DAGsubclass(test_dag_id, default_args=args)
- dag_subclass_diff_name = DAGsubclass(
- test_dag_id + '2', default_args=args)
+ dag_subclass_diff_name = DAGsubclass(test_dag_id + '2', default_args=args)
for dag_ in [dag_eq, dag_diff_name, dag_subclass, dag_subclass_diff_name]:
dag_.last_loaded = dag.last_loaded
@@ -1307,23 +1231,21 @@ def test_get_paused_dag_ids(self):
self.assertEqual(paused_dag_ids, {dag_id})
with create_session() as session:
- session.query(DagModel).filter(
- DagModel.dag_id == dag_id).delete(
- synchronize_session=False)
-
- @parameterized.expand([
- (None, None),
- ("@daily", "0 0 * * *"),
- ("@weekly", "0 0 * * 0"),
- ("@monthly", "0 0 1 * *"),
- ("@quarterly", "0 0 1 */3 *"),
- ("@yearly", "0 0 1 1 *"),
- ("@once", None),
- (datetime.timedelta(days=1), datetime.timedelta(days=1)),
- ])
- def test_normalized_schedule_interval(
- self, schedule_interval, expected_n_schedule_interval
- ):
+ session.query(DagModel).filter(DagModel.dag_id == dag_id).delete(synchronize_session=False)
+
+ @parameterized.expand(
+ [
+ (None, None),
+ ("@daily", "0 0 * * *"),
+ ("@weekly", "0 0 * * 0"),
+ ("@monthly", "0 0 1 * *"),
+ ("@quarterly", "0 0 1 */3 *"),
+ ("@yearly", "0 0 1 1 *"),
+ ("@once", None),
+ (datetime.timedelta(days=1), datetime.timedelta(days=1)),
+ ]
+ )
+ def test_normalized_schedule_interval(self, schedule_interval, expected_n_schedule_interval):
dag = DAG("test_schedule_interval", schedule_interval=schedule_interval)
self.assertEqual(dag.normalized_schedule_interval, expected_n_schedule_interval)
@@ -1399,20 +1321,24 @@ def test_clear_set_dagrun_state(self, dag_run_state):
session=session,
)
- dagruns = session.query(
- DagRun,
- ).filter(
- DagRun.dag_id == dag_id,
- ).all()
+ dagruns = (
+ session.query(
+ DagRun,
+ )
+ .filter(
+ DagRun.dag_id == dag_id,
+ )
+ .all()
+ )
self.assertEqual(len(dagruns), 1)
dagrun = dagruns[0] # type: DagRun
self.assertEqual(dagrun.state, dag_run_state)
- @parameterized.expand([
- (state, State.NONE)
- for state in State.task_states if state != State.RUNNING
- ] + [(State.RUNNING, State.SHUTDOWN)]) # type: ignore
+ @parameterized.expand(
+ [(state, State.NONE) for state in State.task_states if state != State.RUNNING]
+ + [(State.RUNNING, State.SHUTDOWN)]
+ ) # type: ignore
def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]):
dag_id = 'test_clear_dag'
self._clean_up(dag_id)
@@ -1440,11 +1366,15 @@ def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]):
session=session,
)
- task_instances = session.query(
- TI,
- ).filter(
- TI.dag_id == dag_id,
- ).all()
+ task_instances = (
+ session.query(
+ TI,
+ )
+ .filter(
+ TI.dag_id == dag_id,
+ )
+ .all()
+ )
self.assertEqual(len(task_instances), 1)
task_instance = task_instances[0] # type: TI
@@ -1453,9 +1383,8 @@ def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]):
def test_next_dagrun_after_date_once(self):
dag = DAG(
- 'test_scheduler_dagrun_once',
- start_date=timezone.datetime(2015, 1, 1),
- schedule_interval="@once")
+ 'test_scheduler_dagrun_once', start_date=timezone.datetime(2015, 1, 1), schedule_interval="@once"
+ )
next_date = dag.next_dagrun_after_date(None)
@@ -1474,10 +1403,7 @@ def test_next_dagrun_after_date_start_end_dates(self):
start_date = DEFAULT_DATE
end_date = start_date + (runs - 1) * delta
dag_id = "test_schedule_dag_start_end_dates"
- dag = DAG(dag_id=dag_id,
- start_date=start_date,
- end_date=end_date,
- schedule_interval=delta)
+ dag = DAG(dag_id=dag_id, start_date=start_date, end_date=end_date, schedule_interval=delta)
dag.add_task(BaseOperator(task_id='faketastic', owner='Also fake'))
# Create and schedule the dag runs
@@ -1504,11 +1430,13 @@ def make_dag(dag_id, schedule_interval, start_date, catchup):
'owner': 'airflow',
'depends_on_past': False,
}
- dag = DAG(dag_id,
- schedule_interval=schedule_interval,
- start_date=start_date,
- catchup=catchup,
- default_args=default_args)
+ dag = DAG(
+ dag_id,
+ schedule_interval=schedule_interval,
+ start_date=start_date,
+ catchup=catchup,
+ default_args=default_args,
+ )
op1 = DummyOperator(task_id='t1', dag=dag)
op2 = DummyOperator(task_id='t2', dag=dag)
@@ -1519,23 +1447,28 @@ def make_dag(dag_id, schedule_interval, start_date, catchup):
now = timezone.utcnow()
six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace(
- minute=0, second=0, microsecond=0)
+ minute=0, second=0, microsecond=0
+ )
half_an_hour_ago = now - datetime.timedelta(minutes=30)
two_hours_ago = now - datetime.timedelta(hours=2)
- dag1 = make_dag(dag_id='dag_without_catchup_ten_minute',
- schedule_interval='*/10 * * * *',
- start_date=six_hours_ago_to_the_hour,
- catchup=False)
+ dag1 = make_dag(
+ dag_id='dag_without_catchup_ten_minute',
+ schedule_interval='*/10 * * * *',
+ start_date=six_hours_ago_to_the_hour,
+ catchup=False,
+ )
next_date = dag1.next_dagrun_after_date(None)
# The DR should be scheduled in the last half an hour, not 6 hours ago
assert next_date > half_an_hour_ago
assert next_date < timezone.utcnow()
- dag2 = make_dag(dag_id='dag_without_catchup_hourly',
- schedule_interval='@hourly',
- start_date=six_hours_ago_to_the_hour,
- catchup=False)
+ dag2 = make_dag(
+ dag_id='dag_without_catchup_hourly',
+ schedule_interval='@hourly',
+ start_date=six_hours_ago_to_the_hour,
+ catchup=False,
+ )
next_date = dag2.next_dagrun_after_date(None)
# The DR should be scheduled in the last 2 hours, not 6 hours ago
@@ -1543,10 +1476,12 @@ def make_dag(dag_id, schedule_interval, start_date, catchup):
# The DR should be scheduled BEFORE now
assert next_date < timezone.utcnow()
- dag3 = make_dag(dag_id='dag_without_catchup_once',
- schedule_interval='@once',
- start_date=six_hours_ago_to_the_hour,
- catchup=False)
+ dag3 = make_dag(
+ dag_id='dag_without_catchup_once',
+ schedule_interval='@once',
+ start_date=six_hours_ago_to_the_hour,
+ catchup=False,
+ )
next_date = dag3.next_dagrun_after_date(None)
# The DR should be scheduled in the last 2 hours, not 6 hours ago
@@ -1562,7 +1497,8 @@ def test_next_dagrun_after_date_timedelta_schedule_and_catchup_false(self):
'test_scheduler_dagrun_once_with_timedelta_and_catchup_false',
start_date=timezone.datetime(2015, 1, 1),
schedule_interval=timedelta(days=1),
- catchup=False)
+ catchup=False,
+ )
next_date = dag.next_dagrun_after_date(None)
assert next_date == timezone.datetime(2020, 1, 4)
@@ -1581,7 +1517,8 @@ def test_next_dagrun_after_date_timedelta_schedule_and_catchup_true(self):
'test_scheduler_dagrun_once_with_timedelta_and_catchup_true',
start_date=timezone.datetime(2020, 5, 1),
schedule_interval=timedelta(days=1),
- catchup=True)
+ catchup=True,
+ )
next_date = dag.next_dagrun_after_date(None)
assert next_date == timezone.datetime(2020, 5, 1)
@@ -1606,12 +1543,9 @@ def test_next_dagrun_after_auto_align(self):
dag = DAG(
dag_id='test_scheduler_auto_align_1',
start_date=timezone.datetime(2016, 1, 1, 10, 10, 0),
- schedule_interval="4 5 * * *"
+ schedule_interval="4 5 * * *",
)
- DummyOperator(
- task_id='dummy',
- dag=dag,
- owner='airflow')
+ DummyOperator(task_id='dummy', dag=dag, owner='airflow')
next_date = dag.next_dagrun_after_date(None)
assert next_date == timezone.datetime(2016, 1, 2, 5, 4)
@@ -1619,12 +1553,9 @@ def test_next_dagrun_after_auto_align(self):
dag = DAG(
dag_id='test_scheduler_auto_align_2',
start_date=timezone.datetime(2016, 1, 1, 10, 10, 0),
- schedule_interval="10 10 * * *"
+ schedule_interval="10 10 * * *",
)
- DummyOperator(
- task_id='dummy',
- dag=dag,
- owner='airflow')
+ DummyOperator(task_id='dummy', dag=dag, owner='airflow')
next_date = dag.next_dagrun_after_date(None)
assert next_date == timezone.datetime(2016, 1, 1, 10, 10)
@@ -1639,9 +1570,11 @@ def subdag(parent_dag_name, child_dag_name, args):
"""
Create a subdag.
"""
- dag_subdag = DAG(dag_id=f'{parent_dag_name}.{child_dag_name}',
- schedule_interval="@daily",
- default_args=args)
+ dag_subdag = DAG(
+ dag_id=f'{parent_dag_name}.{child_dag_name}',
+ schedule_interval="@daily",
+ default_args=args,
+ )
for i in range(2):
DummyOperator(task_id='{}-task-{}'.format(child_dag_name, i + 1), dag=dag_subdag)
@@ -1673,11 +1606,11 @@ def subdag(parent_dag_name, child_dag_name, args):
def test_replace_outdated_access_control_actions(self):
outdated_permissions = {
'role1': {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT},
- 'role2': {permissions.DEPRECATED_ACTION_CAN_DAG_READ, permissions.DEPRECATED_ACTION_CAN_DAG_EDIT}
+ 'role2': {permissions.DEPRECATED_ACTION_CAN_DAG_READ, permissions.DEPRECATED_ACTION_CAN_DAG_EDIT},
}
updated_permissions = {
'role1': {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT},
- 'role2': {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT}
+ 'role2': {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT},
}
with pytest.warns(DeprecationWarning):
@@ -1690,15 +1623,9 @@ def test_replace_outdated_access_control_actions(self):
class TestDagModel:
-
def test_dags_needing_dagruns_not_too_early(self):
- dag = DAG(
- dag_id='far_future_dag',
- start_date=timezone.datetime(2038, 1, 1))
- DummyOperator(
- task_id='dummy',
- dag=dag,
- owner='airflow')
+ dag = DAG(dag_id='far_future_dag', start_date=timezone.datetime(2038, 1, 1))
+ DummyOperator(task_id='dummy', dag=dag, owner='airflow')
session = settings.Session()
orm_dag = DagModel(
@@ -1722,13 +1649,8 @@ def test_dags_needing_dagruns_only_unpaused(self):
"""
We should never create dagruns for paused DAGs
"""
- dag = DAG(
- dag_id='test_dags',
- start_date=DEFAULT_DATE)
- DummyOperator(
- task_id='dummy',
- dag=dag,
- owner='airflow')
+ dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
+ DummyOperator(task_id='dummy', dag=dag, owner='airflow')
session = settings.Session()
orm_dag = DagModel(
@@ -1755,17 +1677,18 @@ def test_dags_needing_dagruns_only_unpaused(self):
class TestQueries(unittest.TestCase):
-
def setUp(self) -> None:
clear_db_runs()
def tearDown(self) -> None:
clear_db_runs()
- @parameterized.expand([
- (3, ),
- (12, ),
- ])
+ @parameterized.expand(
+ [
+ (3,),
+ (12,),
+ ]
+ )
def test_count_number_queries(self, tasks_count):
dag = DAG('test_dagrun_query_count', start_date=DEFAULT_DATE)
for i in range(tasks_count):
@@ -1799,6 +1722,7 @@ def tearDown(self):
def test_set_dag_id(self):
"""Test that checks you can set dag_id from decorator."""
+
@dag_decorator('test', default_args=self.DEFAULT_ARGS)
def noop_pipeline():
@task_decorator
@@ -1806,12 +1730,14 @@ def return_num(num):
return num
return_num(4)
+
dag = noop_pipeline()
assert isinstance(dag, DAG)
assert dag.dag_id, 'test'
def test_default_dag_id(self):
"""Test that @dag uses function name as default dag id."""
+
@dag_decorator(default_args=self.DEFAULT_ARGS)
def noop_pipeline():
@task_decorator
@@ -1819,22 +1745,26 @@ def return_num(num):
return num
return_num(4)
+
dag = noop_pipeline()
assert isinstance(dag, DAG)
assert dag.dag_id, 'noop_pipeline'
def test_documentation_added(self):
"""Test that @dag uses function docs as doc_md for DAG object"""
+
@dag_decorator(default_args=self.DEFAULT_ARGS)
def noop_pipeline():
"""
Regular DAG documentation
"""
+
@task_decorator
def return_num(num):
return num
return_num(4)
+
dag = noop_pipeline()
assert isinstance(dag, DAG)
assert dag.dag_id, 'test'
@@ -1842,6 +1772,7 @@ def return_num(num):
def test_fails_if_arg_not_set(self):
"""Test that @dag decorated function fails if positional argument is not set"""
+
@dag_decorator(default_args=self.DEFAULT_ARGS)
def noop_pipeline(value):
@task_decorator
@@ -1856,6 +1787,7 @@ def return_num(num):
def test_dag_param_resolves(self):
"""Test that dag param is correctly resolved by operator"""
+
@dag_decorator(default_args=self.DEFAULT_ARGS)
def xcom_pass_to_op(value=self.VALUE):
@task_decorator
@@ -1871,7 +1803,7 @@ def return_num(num):
run_id=DagRunType.MANUAL.value,
start_date=timezone.utcnow(),
execution_date=self.DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE)
@@ -1880,11 +1812,13 @@ def return_num(num):
def test_dag_param_dagrun_parameterized(self):
"""Test that dag param is correctly overwritten when set in dag run"""
+
@dag_decorator(default_args=self.DEFAULT_ARGS)
def xcom_pass_to_op(value=self.VALUE):
@task_decorator
def return_num(num):
return num
+
assert isinstance(value, DagParam)
xcom_arg = return_num(value)
@@ -1897,7 +1831,7 @@ def return_num(num):
start_date=timezone.utcnow(),
execution_date=self.DEFAULT_DATE,
state=State.RUNNING,
- conf={'value': new_value}
+ conf={'value': new_value},
)
self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE)
@@ -1906,6 +1840,7 @@ def return_num(num):
def test_set_params_for_dag(self):
"""Test that dag param is correctly set when using dag decorator"""
+
@dag_decorator(default_args=self.DEFAULT_ARGS)
def xcom_pass_to_op(value=self.VALUE):
@task_decorator
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index d43950996bc8b..a0d7fbed5f062 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -64,8 +64,7 @@ def test_get_existing_dag(self):
"""
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=True)
- some_expected_dag_ids = ["example_bash_operator",
- "example_branch_operator"]
+ some_expected_dag_ids = ["example_bash_operator", "example_branch_operator"]
for dag_id in some_expected_dag_ids:
dag = dagbag.get_dag(dag_id)
@@ -105,9 +104,7 @@ def test_safe_mode_heuristic_match(self):
dagbag = models.DagBag(include_examples=False, safe_mode=True)
self.assertEqual(len(dagbag.dagbag_stats), 1)
- self.assertEqual(
- dagbag.dagbag_stats[0].file,
- "/{}".format(os.path.basename(f.name)))
+ self.assertEqual(dagbag.dagbag_stats[0].file, "/{}".format(os.path.basename(f.name)))
def test_safe_mode_heuristic_mismatch(self):
"""With safe mode enabled, a file not matching the discovery heuristics
@@ -119,15 +116,12 @@ def test_safe_mode_heuristic_mismatch(self):
self.assertEqual(len(dagbag.dagbag_stats), 0)
def test_safe_mode_disabled(self):
- """With safe mode disabled, an empty python file should be discovered.
- """
+ """With safe mode disabled, an empty python file should be discovered."""
with NamedTemporaryFile(dir=self.empty_dir, suffix=".py") as f:
with conf_vars({('core', 'dags_folder'): self.empty_dir}):
dagbag = models.DagBag(include_examples=False, safe_mode=False)
self.assertEqual(len(dagbag.dagbag_stats), 1)
- self.assertEqual(
- dagbag.dagbag_stats[0].file,
- "/{}".format(os.path.basename(f.name)))
+ self.assertEqual(dagbag.dagbag_stats[0].file, "/{}".format(os.path.basename(f.name)))
def test_process_file_that_contains_multi_bytes_char(self):
"""
@@ -153,7 +147,7 @@ def test_zip_skip_log(self):
self.assertIn(
f'INFO:airflow.models.dagbag.DagBag:File {test_zip_path}:file_no_airflow_dag.py '
'assumed to contain no DAGs. Skipping.',
- cm.output
+ cm.output,
)
def test_zip(self):
@@ -218,7 +212,7 @@ def test_get_dag_fileloc(self):
'example_bash_operator': 'airflow/example_dags/example_bash_operator.py',
'example_subdag_operator': 'airflow/example_dags/example_subdag_operator.py',
'example_subdag_operator.section-1': 'airflow/example_dags/subdags/subdag.py',
- 'test_zip_dag': 'dags/test_zip.zip/test_zip.py'
+ 'test_zip_dag': 'dags/test_zip.zip/test_zip.py',
}
for dag_id, path in expected.items():
@@ -233,14 +227,10 @@ def test_refresh_py_dag(self, mock_dagmodel):
example_dags_folder = airflow.example_dags.__path__[0]
dag_id = "example_bash_operator"
- fileloc = os.path.realpath(
- os.path.join(example_dags_folder, "example_bash_operator.py")
- )
+ fileloc = os.path.realpath(os.path.join(example_dags_folder, "example_bash_operator.py"))
mock_dagmodel.return_value = DagModel()
- mock_dagmodel.return_value.last_expired = datetime.max.replace(
- tzinfo=timezone.utc
- )
+ mock_dagmodel.return_value.last_expired = datetime.max.replace(tzinfo=timezone.utc)
mock_dagmodel.return_value.fileloc = fileloc
class _TestDagBag(DagBag):
@@ -265,14 +255,10 @@ def test_refresh_packaged_dag(self, mock_dagmodel):
Test that we can refresh a packaged DAG
"""
dag_id = "test_zip_dag"
- fileloc = os.path.realpath(
- os.path.join(TEST_DAGS_FOLDER, "test_zip.zip/test_zip.py")
- )
+ fileloc = os.path.realpath(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip/test_zip.py"))
mock_dagmodel.return_value = DagModel()
- mock_dagmodel.return_value.last_expired = datetime.max.replace(
- tzinfo=timezone.utc
- )
+ mock_dagmodel.return_value.last_expired = datetime.max.replace(tzinfo=timezone.utc)
mock_dagmodel.return_value.fileloc = fileloc
class _TestDagBag(DagBag):
@@ -296,8 +282,7 @@ def process_dag(self, create_dag):
Helper method to process a file generated from the input create_dag function.
"""
# write source to file
- source = textwrap.dedent(''.join(
- inspect.getsource(create_dag).splitlines(True)[1:-1]))
+ source = textwrap.dedent(''.join(inspect.getsource(create_dag).splitlines(True)[1:-1]))
f = NamedTemporaryFile()
f.write(source.encode('utf8'))
f.flush()
@@ -306,8 +291,7 @@ def process_dag(self, create_dag):
found_dags = dagbag.process_file(f.name)
return dagbag, found_dags, f.name
- def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag,
- should_be_found=True):
+ def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag, should_be_found=True):
expected_dag_ids = list(map(lambda dag: dag.dag_id, expected_parent_dag.subdags))
expected_dag_ids.append(expected_parent_dag.dag_id)
@@ -316,14 +300,16 @@ def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag,
for dag_id in expected_dag_ids:
actual_dagbag.log.info('validating %s' % dag_id)
self.assertEqual(
- dag_id in actual_found_dag_ids, should_be_found,
- 'dag "%s" should %shave been found after processing dag "%s"' %
- (dag_id, '' if should_be_found else 'not ', expected_parent_dag.dag_id)
+ dag_id in actual_found_dag_ids,
+ should_be_found,
+ 'dag "%s" should %shave been found after processing dag "%s"'
+ % (dag_id, '' if should_be_found else 'not ', expected_parent_dag.dag_id),
)
self.assertEqual(
- dag_id in actual_dagbag.dags, should_be_found,
- 'dag "%s" should %sbe in dagbag.dags after processing dag "%s"' %
- (dag_id, '' if should_be_found else 'not ', expected_parent_dag.dag_id)
+ dag_id in actual_dagbag.dags,
+ should_be_found,
+ 'dag "%s" should %sbe in dagbag.dags after processing dag "%s"'
+ % (dag_id, '' if should_be_found else 'not ', expected_parent_dag.dag_id),
)
def test_load_subdags(self):
@@ -334,14 +320,10 @@ def standard_subdag():
from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.subdag_operator import SubDagOperator
+
dag_name = 'master'
- default_args = {
- 'owner': 'owner1',
- 'start_date': datetime.datetime(2016, 1, 1)
- }
- dag = DAG(
- dag_name,
- default_args=default_args)
+ default_args = {'owner': 'owner1', 'start_date': datetime.datetime(2016, 1, 1)}
+ dag = DAG(dag_name, default_args=default_args)
# master:
# A -> opSubDag_0
@@ -352,6 +334,7 @@ def standard_subdag():
# -> subdag_1.task
with dag:
+
def subdag_0():
subdag_0 = DAG('master.op_subdag_0', default_args=default_args)
DummyOperator(task_id='subdag_0.task', dag=subdag_0)
@@ -362,10 +345,8 @@ def subdag_1():
DummyOperator(task_id='subdag_1.task', dag=subdag_1)
return subdag_1
- op_subdag_0 = SubDagOperator(
- task_id='op_subdag_0', dag=dag, subdag=subdag_0())
- op_subdag_1 = SubDagOperator(
- task_id='op_subdag_1', dag=dag, subdag=subdag_1())
+ op_subdag_0 = SubDagOperator(task_id='op_subdag_0', dag=dag, subdag=subdag_0())
+ op_subdag_1 = SubDagOperator(task_id='op_subdag_1', dag=dag, subdag=subdag_1())
op_a = DummyOperator(task_id='A')
op_a.set_downstream(op_subdag_0)
@@ -390,14 +371,10 @@ def nested_subdags():
from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.subdag_operator import SubDagOperator
+
dag_name = 'master'
- default_args = {
- 'owner': 'owner1',
- 'start_date': datetime.datetime(2016, 1, 1)
- }
- dag = DAG(
- dag_name,
- default_args=default_args)
+ default_args = {'owner': 'owner1', 'start_date': datetime.datetime(2016, 1, 1)}
+ dag = DAG(dag_name, default_args=default_args)
# master:
# A -> op_subdag_0
@@ -418,27 +395,24 @@ def nested_subdags():
# -> subdag_d.task
with dag:
+
def subdag_a():
- subdag_a = DAG(
- 'master.op_subdag_0.opSubdag_A', default_args=default_args)
+ subdag_a = DAG('master.op_subdag_0.opSubdag_A', default_args=default_args)
DummyOperator(task_id='subdag_a.task', dag=subdag_a)
return subdag_a
def subdag_b():
- subdag_b = DAG(
- 'master.op_subdag_0.opSubdag_B', default_args=default_args)
+ subdag_b = DAG('master.op_subdag_0.opSubdag_B', default_args=default_args)
DummyOperator(task_id='subdag_b.task', dag=subdag_b)
return subdag_b
def subdag_c():
- subdag_c = DAG(
- 'master.op_subdag_1.opSubdag_C', default_args=default_args)
+ subdag_c = DAG('master.op_subdag_1.opSubdag_C', default_args=default_args)
DummyOperator(task_id='subdag_c.task', dag=subdag_c)
return subdag_c
def subdag_d():
- subdag_d = DAG(
- 'master.op_subdag_1.opSubdag_D', default_args=default_args)
+ subdag_d = DAG('master.op_subdag_1.opSubdag_D', default_args=default_args)
DummyOperator(task_id='subdag_d.task', dag=subdag_d)
return subdag_d
@@ -454,10 +428,8 @@ def subdag_1():
SubDagOperator(task_id='opSubdag_D', dag=subdag_1, subdag=subdag_d())
return subdag_1
- op_subdag_0 = SubDagOperator(
- task_id='op_subdag_0', dag=dag, subdag=subdag_0())
- op_subdag_1 = SubDagOperator(
- task_id='op_subdag_1', dag=dag, subdag=subdag_1())
+ op_subdag_0 = SubDagOperator(task_id='op_subdag_0', dag=dag, subdag=subdag_0())
+ op_subdag_1 = SubDagOperator(task_id='op_subdag_1', dag=dag, subdag=subdag_1())
op_a = DummyOperator(task_id='A')
op_a.set_downstream(op_subdag_0)
@@ -488,14 +460,10 @@ def basic_cycle():
from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator
+
dag_name = 'cycle_dag'
- default_args = {
- 'owner': 'owner1',
- 'start_date': datetime.datetime(2016, 1, 1)
- }
- dag = DAG(
- dag_name,
- default_args=default_args)
+ default_args = {'owner': 'owner1', 'start_date': datetime.datetime(2016, 1, 1)}
+ dag = DAG(dag_name, default_args=default_args)
# A -> A
with dag:
@@ -523,14 +491,10 @@ def nested_subdag_cycle():
from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.subdag_operator import SubDagOperator
+
dag_name = 'nested_cycle'
- default_args = {
- 'owner': 'owner1',
- 'start_date': datetime.datetime(2016, 1, 1)
- }
- dag = DAG(
- dag_name,
- default_args=default_args)
+ default_args = {'owner': 'owner1', 'start_date': datetime.datetime(2016, 1, 1)}
+ dag = DAG(dag_name, default_args=default_args)
# cycle:
# A -> op_subdag_0
@@ -551,30 +515,26 @@ def nested_subdag_cycle():
# -> subdag_d.task
with dag:
+
def subdag_a():
- subdag_a = DAG(
- 'nested_cycle.op_subdag_0.opSubdag_A', default_args=default_args)
+ subdag_a = DAG('nested_cycle.op_subdag_0.opSubdag_A', default_args=default_args)
DummyOperator(task_id='subdag_a.task', dag=subdag_a)
return subdag_a
def subdag_b():
- subdag_b = DAG(
- 'nested_cycle.op_subdag_0.opSubdag_B', default_args=default_args)
+ subdag_b = DAG('nested_cycle.op_subdag_0.opSubdag_B', default_args=default_args)
DummyOperator(task_id='subdag_b.task', dag=subdag_b)
return subdag_b
def subdag_c():
- subdag_c = DAG(
- 'nested_cycle.op_subdag_1.opSubdag_C', default_args=default_args)
- op_subdag_c_task = DummyOperator(
- task_id='subdag_c.task', dag=subdag_c)
+ subdag_c = DAG('nested_cycle.op_subdag_1.opSubdag_C', default_args=default_args)
+ op_subdag_c_task = DummyOperator(task_id='subdag_c.task', dag=subdag_c)
# introduce a loop in opSubdag_C
op_subdag_c_task.set_downstream(op_subdag_c_task)
return subdag_c
def subdag_d():
- subdag_d = DAG(
- 'nested_cycle.op_subdag_1.opSubdag_D', default_args=default_args)
+ subdag_d = DAG('nested_cycle.op_subdag_1.opSubdag_D', default_args=default_args)
DummyOperator(task_id='subdag_d.task', dag=subdag_d)
return subdag_d
@@ -590,10 +550,8 @@ def subdag_1():
SubDagOperator(task_id='opSubdag_D', dag=subdag_1, subdag=subdag_d())
return subdag_1
- op_subdag_0 = SubDagOperator(
- task_id='op_subdag_0', dag=dag, subdag=subdag_0())
- op_subdag_1 = SubDagOperator(
- task_id='op_subdag_1', dag=dag, subdag=subdag_1())
+ op_subdag_0 = SubDagOperator(task_id='op_subdag_0', dag=dag, subdag=subdag_0())
+ op_subdag_1 = SubDagOperator(task_id='op_subdag_1', dag=dag, subdag=subdag_1())
op_a = DummyOperator(task_id='A')
op_a.set_downstream(op_subdag_0)
@@ -655,7 +613,8 @@ def test_serialized_dags_are_written_to_db_on_sync(self):
dagbag = DagBag(
dag_folder=os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py"),
- include_examples=False)
+ include_examples=False,
+ )
dagbag.sync_to_db()
self.assertFalse(dagbag.read_dags_from_db)
@@ -682,11 +641,13 @@ def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_sdag_sync_to_db
dagbag.sync_to_db(session=mock_session)
# Test that 3 attempts were made to run 'DAG.bulk_write_to_db' successfully
- mock_bulk_write_to_db.assert_has_calls([
- mock.call(mock.ANY, session=mock.ANY),
- mock.call(mock.ANY, session=mock.ANY),
- mock.call(mock.ANY, session=mock.ANY),
- ])
+ mock_bulk_write_to_db.assert_has_calls(
+ [
+ mock.call(mock.ANY, session=mock.ANY),
+ mock.call(mock.ANY, session=mock.ANY),
+ mock.call(mock.ANY, session=mock.ANY),
+ ]
+ )
# Assert that rollback is called twice (i.e. whenever OperationalError occurs)
mock_session.rollback.assert_has_calls([mock.call(), mock.call()])
# Check that 'SerializedDagModel.bulk_sync_to_db' is also called
@@ -760,9 +721,7 @@ def test_cluster_policy_violation(self):
"""
dag_file = os.path.join(TEST_DAGS_FOLDER, "test_missing_owner.py")
- dagbag = DagBag(dag_folder=dag_file,
- include_smart_sensor=False,
- include_examples=False)
+ dagbag = DagBag(dag_folder=dag_file, include_smart_sensor=False, include_examples=False)
self.assertEqual(set(), set(dagbag.dag_ids))
expected_import_errors = {
dag_file: (
@@ -778,12 +737,9 @@ def test_cluster_policy_obeyed(self):
"""test that dag successfully imported without import errors when tasks
obey cluster policy.
"""
- dag_file = os.path.join(TEST_DAGS_FOLDER,
- "test_with_non_default_owner.py")
+ dag_file = os.path.join(TEST_DAGS_FOLDER, "test_with_non_default_owner.py")
- dagbag = DagBag(dag_folder=dag_file,
- include_examples=False,
- include_smart_sensor=False)
+ dagbag = DagBag(dag_folder=dag_file, include_examples=False, include_smart_sensor=False)
self.assertEqual({"test_with_non_default_owner"}, set(dagbag.dag_ids))
self.assertEqual({}, dagbag.import_errors)
diff --git a/tests/models/test_dagcode.py b/tests/models/test_dagcode.py
index c16ebf1d2f7f5..ee2eb2aaa61d4 100644
--- a/tests/models/test_dagcode.py
+++ b/tests/models/test_dagcode.py
@@ -22,6 +22,7 @@
from airflow import AirflowException, example_dags as example_dags_module
from airflow.models import DagBag
from airflow.models.dagcode import DagCode
+
# To move it to a shared module.
from airflow.utils.file import open_maybe_zipped
from airflow.utils.session import create_session
@@ -82,7 +83,7 @@ def test_bulk_sync_to_db_half_files(self):
"""Dg code can be bulk written into database."""
example_dags = make_example_dags(example_dags_module)
files = [dag.fileloc for dag in example_dags.values()]
- half_files = files[:int(len(files) / 2)]
+ half_files = files[: int(len(files) / 2)]
with create_session() as session:
DagCode.bulk_sync_to_db(half_files, session=session)
session.commit()
@@ -107,11 +108,12 @@ def _compare_example_dags(self, example_dags):
dag.fileloc = dag.parent_dag.fileloc
self.assertTrue(DagCode.has_dag(dag.fileloc))
dag_fileloc_hash = DagCode.dag_fileloc_hash(dag.fileloc)
- result = session.query(
- DagCode.fileloc, DagCode.fileloc_hash, DagCode.source_code) \
- .filter(DagCode.fileloc == dag.fileloc) \
- .filter(DagCode.fileloc_hash == dag_fileloc_hash) \
+ result = (
+ session.query(DagCode.fileloc, DagCode.fileloc_hash, DagCode.source_code)
+ .filter(DagCode.fileloc == dag.fileloc)
+ .filter(DagCode.fileloc_hash == dag_fileloc_hash)
.one()
+ )
self.assertEqual(result.fileloc, dag.fileloc)
with open_maybe_zipped(dag.fileloc, 'r') as source:
@@ -145,9 +147,7 @@ def test_db_code_updated_on_dag_file_change(self):
example_dag.sync_to_db()
with create_session() as session:
- result = session.query(DagCode) \
- .filter(DagCode.fileloc == example_dag.fileloc) \
- .one()
+ result = session.query(DagCode).filter(DagCode.fileloc == example_dag.fileloc).one()
self.assertEqual(result.fileloc, example_dag.fileloc)
self.assertIsNotNone(result.source_code)
@@ -160,9 +160,7 @@ def test_db_code_updated_on_dag_file_change(self):
example_dag.sync_to_db()
with create_session() as session:
- new_result = session.query(DagCode) \
- .filter(DagCode.fileloc == example_dag.fileloc) \
- .one()
+ new_result = session.query(DagCode).filter(DagCode.fileloc == example_dag.fileloc).one()
self.assertEqual(new_result.fileloc, example_dag.fileloc)
self.assertEqual(new_result.source_code, "# dummy code")
diff --git a/tests/models/test_dagparam.py b/tests/models/test_dagparam.py
index 2f723baea3fa0..1eb28cbc54c1c 100644
--- a/tests/models/test_dagparam.py
+++ b/tests/models/test_dagparam.py
@@ -57,7 +57,7 @@ def return_num(num):
run_id=DagRunType.MANUAL.value,
start_date=timezone.utcnow(),
execution_date=self.DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
# pylint: disable=maybe-no-member
@@ -84,7 +84,7 @@ def return_num(num):
start_date=timezone.utcnow(),
execution_date=self.DEFAULT_DATE,
state=State.RUNNING,
- conf={'value': new_value}
+ conf={'value': new_value},
)
# pylint: disable=maybe-no-member
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index ec7f6e69ed573..33bf32d6373e6 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -37,7 +37,6 @@
class TestDagRun(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
cls.dagbag = DagBag(include_examples=True)
@@ -46,12 +45,14 @@ def setUp(self):
clear_db_runs()
clear_db_pools()
- def create_dag_run(self, dag,
- state=State.RUNNING,
- task_states=None,
- execution_date=None,
- is_backfill=False,
- ):
+ def create_dag_run(
+ self,
+ dag,
+ state=State.RUNNING,
+ task_states=None,
+ execution_date=None,
+ is_backfill=False,
+ ):
now = timezone.utcnow()
if execution_date is None:
execution_date = now
@@ -88,15 +89,11 @@ def test_clear_task_instances_for_backfill_dagrun(self):
ti0 = TI(task=task0, execution_date=now)
ti0.run()
- qry = session.query(TI).filter(
- TI.dag_id == dag.dag_id).all()
+ qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
clear_task_instances(qry, session)
session.commit()
ti0.refresh_from_db()
- dr0 = session.query(DagRun).filter(
- DagRun.dag_id == dag_id,
- DagRun.execution_date == now
- ).first()
+ dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.execution_date == now).first()
self.assertEqual(dr0.state, State.RUNNING)
def test_dagrun_find(self):
@@ -127,33 +124,21 @@ def test_dagrun_find(self):
session.commit()
- self.assertEqual(1,
- len(models.DagRun.find(dag_id=dag_id1, external_trigger=True)))
- self.assertEqual(0,
- len(models.DagRun.find(dag_id=dag_id1, external_trigger=False)))
- self.assertEqual(0,
- len(models.DagRun.find(dag_id=dag_id2, external_trigger=True)))
- self.assertEqual(1,
- len(models.DagRun.find(dag_id=dag_id2, external_trigger=False)))
+ self.assertEqual(1, len(models.DagRun.find(dag_id=dag_id1, external_trigger=True)))
+ self.assertEqual(0, len(models.DagRun.find(dag_id=dag_id1, external_trigger=False)))
+ self.assertEqual(0, len(models.DagRun.find(dag_id=dag_id2, external_trigger=True)))
+ self.assertEqual(1, len(models.DagRun.find(dag_id=dag_id2, external_trigger=False)))
def test_dagrun_success_when_all_skipped(self):
"""
Tests that a DAG run succeeds when all tasks are skipped
"""
- dag = DAG(
- dag_id='test_dagrun_success_when_all_skipped',
- start_date=timezone.datetime(2017, 1, 1)
- )
+ dag = DAG(dag_id='test_dagrun_success_when_all_skipped', start_date=timezone.datetime(2017, 1, 1))
dag_task1 = ShortCircuitOperator(
- task_id='test_short_circuit_false',
- dag=dag,
- python_callable=lambda: False)
- dag_task2 = DummyOperator(
- task_id='test_state_skipped1',
- dag=dag)
- dag_task3 = DummyOperator(
- task_id='test_state_skipped2',
- dag=dag)
+ task_id='test_short_circuit_false', dag=dag, python_callable=lambda: False
+ )
+ dag_task2 = DummyOperator(task_id='test_state_skipped1', dag=dag)
+ dag_task3 = DummyOperator(task_id='test_state_skipped2', dag=dag)
dag_task1.set_downstream(dag_task2)
dag_task2.set_downstream(dag_task3)
@@ -163,19 +148,14 @@ def test_dagrun_success_when_all_skipped(self):
'test_state_skipped2': State.SKIPPED,
}
- dag_run = self.create_dag_run(dag=dag,
- state=State.RUNNING,
- task_states=initial_task_states)
+ dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states)
dag_run.update_state()
self.assertEqual(State.SUCCESS, dag_run.state)
def test_dagrun_success_conditions(self):
session = settings.Session()
- dag = DAG(
- 'test_dagrun_success_conditions',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('test_dagrun_success_conditions', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
# A -> B
# A -> C -> D
@@ -191,10 +171,9 @@ def test_dagrun_success_conditions(self):
dag.clear()
now = timezone.utcnow()
- dr = dag.create_dagrun(run_id='test_dagrun_success_conditions',
- state=State.RUNNING,
- execution_date=now,
- start_date=now)
+ dr = dag.create_dagrun(
+ run_id='test_dagrun_success_conditions', state=State.RUNNING, execution_date=now, start_date=now
+ )
# op1 = root
ti_op1 = dr.get_task_instance(task_id=op1.task_id)
@@ -217,10 +196,7 @@ def test_dagrun_success_conditions(self):
def test_dagrun_deadlock(self):
session = settings.Session()
- dag = DAG(
- 'text_dagrun_deadlock',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('text_dagrun_deadlock', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
with dag:
op1 = DummyOperator(task_id='A')
@@ -230,10 +206,9 @@ def test_dagrun_deadlock(self):
dag.clear()
now = timezone.utcnow()
- dr = dag.create_dagrun(run_id='test_dagrun_deadlock',
- state=State.RUNNING,
- execution_date=now,
- start_date=now)
+ dr = dag.create_dagrun(
+ run_id='test_dagrun_deadlock', state=State.RUNNING, execution_date=now, start_date=now
+ )
ti_op1 = dr.get_task_instance(task_id=op1.task_id)
ti_op1.set_state(state=State.SUCCESS, session=session)
@@ -250,17 +225,18 @@ def test_dagrun_deadlock(self):
def test_dagrun_no_deadlock_with_shutdown(self):
session = settings.Session()
- dag = DAG('test_dagrun_no_deadlock_with_shutdown',
- start_date=DEFAULT_DATE)
+ dag = DAG('test_dagrun_no_deadlock_with_shutdown', start_date=DEFAULT_DATE)
with dag:
op1 = DummyOperator(task_id='upstream_task')
op2 = DummyOperator(task_id='downstream_task')
op2.set_upstream(op1)
- dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_with_shutdown',
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE)
+ dr = dag.create_dagrun(
+ run_id='test_dagrun_no_deadlock_with_shutdown',
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ )
upstream_ti = dr.get_task_instance(task_id='upstream_task')
upstream_ti.set_state(State.SHUTDOWN, session=session)
@@ -269,21 +245,24 @@ def test_dagrun_no_deadlock_with_shutdown(self):
def test_dagrun_no_deadlock_with_depends_on_past(self):
session = settings.Session()
- dag = DAG('test_dagrun_no_deadlock',
- start_date=DEFAULT_DATE)
+ dag = DAG('test_dagrun_no_deadlock', start_date=DEFAULT_DATE)
with dag:
DummyOperator(task_id='dop', depends_on_past=True)
DummyOperator(task_id='tc', task_concurrency=1)
dag.clear()
- dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_1',
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE)
- dr2 = dag.create_dagrun(run_id='test_dagrun_no_deadlock_2',
- state=State.RUNNING,
- execution_date=DEFAULT_DATE + datetime.timedelta(days=1),
- start_date=DEFAULT_DATE + datetime.timedelta(days=1))
+ dr = dag.create_dagrun(
+ run_id='test_dagrun_no_deadlock_1',
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ )
+ dr2 = dag.create_dagrun(
+ run_id='test_dagrun_no_deadlock_2',
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE + datetime.timedelta(days=1),
+ start_date=DEFAULT_DATE + datetime.timedelta(days=1),
+ )
ti1_op1 = dr.get_task_instance(task_id='dop')
dr2.get_task_instance(task_id='dop')
ti2_op1 = dr.get_task_instance(task_id='tc')
@@ -302,22 +281,15 @@ def test_dagrun_no_deadlock_with_depends_on_past(self):
def test_dagrun_success_callback(self):
def on_success_callable(context):
- self.assertEqual(
- context['dag_run'].dag_id,
- 'test_dagrun_success_callback'
- )
+ self.assertEqual(context['dag_run'].dag_id, 'test_dagrun_success_callback')
dag = DAG(
dag_id='test_dagrun_success_callback',
start_date=datetime.datetime(2017, 1, 1),
on_success_callback=on_success_callable,
)
- dag_task1 = DummyOperator(
- task_id='test_state_succeeded1',
- dag=dag)
- dag_task2 = DummyOperator(
- task_id='test_state_succeeded2',
- dag=dag)
+ dag_task1 = DummyOperator(task_id='test_state_succeeded1', dag=dag)
+ dag_task2 = DummyOperator(task_id='test_state_succeeded2', dag=dag)
dag_task1.set_downstream(dag_task2)
initial_task_states = {
@@ -325,9 +297,7 @@ def on_success_callable(context):
'test_state_succeeded2': State.SUCCESS,
}
- dag_run = self.create_dag_run(dag=dag,
- state=State.RUNNING,
- task_states=initial_task_states)
+ dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states)
_, callback = dag_run.update_state()
self.assertEqual(State.SUCCESS, dag_run.state)
# Callbacks are not added until handle_callback = False is passed to dag_run.update_state()
@@ -335,22 +305,15 @@ def on_success_callable(context):
def test_dagrun_failure_callback(self):
def on_failure_callable(context):
- self.assertEqual(
- context['dag_run'].dag_id,
- 'test_dagrun_failure_callback'
- )
+ self.assertEqual(context['dag_run'].dag_id, 'test_dagrun_failure_callback')
dag = DAG(
dag_id='test_dagrun_failure_callback',
start_date=datetime.datetime(2017, 1, 1),
on_failure_callback=on_failure_callable,
)
- dag_task1 = DummyOperator(
- task_id='test_state_succeeded1',
- dag=dag)
- dag_task2 = DummyOperator(
- task_id='test_state_failed2',
- dag=dag)
+ dag_task1 = DummyOperator(task_id='test_state_succeeded1', dag=dag)
+ dag_task2 = DummyOperator(task_id='test_state_failed2', dag=dag)
initial_task_states = {
'test_state_succeeded1': State.SUCCESS,
@@ -358,9 +321,7 @@ def on_failure_callable(context):
}
dag_task1.set_downstream(dag_task2)
- dag_run = self.create_dag_run(dag=dag,
- state=State.RUNNING,
- task_states=initial_task_states)
+ dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states)
_, callback = dag_run.update_state()
self.assertEqual(State.FAILED, dag_run.state)
# Callbacks are not added until handle_callback = False is passed to dag_run.update_state()
@@ -369,8 +330,7 @@ def on_failure_callable(context):
def test_dagrun_update_state_with_handle_callback_success(self):
def on_success_callable(context):
self.assertEqual(
- context['dag_run'].dag_id,
- 'test_dagrun_update_state_with_handle_callback_success'
+ context['dag_run'].dag_id, 'test_dagrun_update_state_with_handle_callback_success'
)
dag = DAG(
@@ -378,12 +338,8 @@ def on_success_callable(context):
start_date=datetime.datetime(2017, 1, 1),
on_success_callback=on_success_callable,
)
- dag_task1 = DummyOperator(
- task_id='test_state_succeeded1',
- dag=dag)
- dag_task2 = DummyOperator(
- task_id='test_state_succeeded2',
- dag=dag)
+ dag_task1 = DummyOperator(task_id='test_state_succeeded1', dag=dag)
+ dag_task2 = DummyOperator(task_id='test_state_succeeded2', dag=dag)
dag_task1.set_downstream(dag_task2)
initial_task_states = {
@@ -402,14 +358,13 @@ def on_success_callable(context):
dag_id="test_dagrun_update_state_with_handle_callback_success",
execution_date=dag_run.execution_date,
is_failure_callback=False,
- msg="success"
+ msg="success",
)
def test_dagrun_update_state_with_handle_callback_failure(self):
def on_failure_callable(context):
self.assertEqual(
- context['dag_run'].dag_id,
- 'test_dagrun_update_state_with_handle_callback_failure'
+ context['dag_run'].dag_id, 'test_dagrun_update_state_with_handle_callback_failure'
)
dag = DAG(
@@ -417,12 +372,8 @@ def on_failure_callable(context):
start_date=datetime.datetime(2017, 1, 1),
on_failure_callback=on_failure_callable,
)
- dag_task1 = DummyOperator(
- task_id='test_state_succeeded1',
- dag=dag)
- dag_task2 = DummyOperator(
- task_id='test_state_failed2',
- dag=dag)
+ dag_task1 = DummyOperator(task_id='test_state_succeeded1', dag=dag)
+ dag_task2 = DummyOperator(task_id='test_state_failed2', dag=dag)
dag_task1.set_downstream(dag_task2)
initial_task_states = {
@@ -441,24 +392,20 @@ def on_failure_callable(context):
dag_id="test_dagrun_update_state_with_handle_callback_failure",
execution_date=dag_run.execution_date,
is_failure_callback=True,
- msg="task_failure"
+ msg="task_failure",
)
def test_dagrun_set_state_end_date(self):
session = settings.Session()
- dag = DAG(
- 'test_dagrun_set_state_end_date',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('test_dagrun_set_state_end_date', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
dag.clear()
now = timezone.utcnow()
- dr = dag.create_dagrun(run_id='test_dagrun_set_state_end_date',
- state=State.RUNNING,
- execution_date=now,
- start_date=now)
+ dr = dag.create_dagrun(
+ run_id='test_dagrun_set_state_end_date', state=State.RUNNING, execution_date=now, start_date=now
+ )
# Initial end_date should be NULL
# State.SUCCESS and State.FAILED are both end states and should set end_date
@@ -471,9 +418,7 @@ def test_dagrun_set_state_end_date(self):
session.merge(dr)
session.commit()
- dr_database = session.query(DagRun).filter(
- DagRun.run_id == 'test_dagrun_set_state_end_date'
- ).one()
+ dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_set_state_end_date').one()
self.assertIsNotNone(dr_database.end_date)
self.assertEqual(dr.end_date, dr_database.end_date)
@@ -481,18 +426,14 @@ def test_dagrun_set_state_end_date(self):
session.merge(dr)
session.commit()
- dr_database = session.query(DagRun).filter(
- DagRun.run_id == 'test_dagrun_set_state_end_date'
- ).one()
+ dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_set_state_end_date').one()
self.assertIsNone(dr_database.end_date)
dr.set_state(State.FAILED)
session.merge(dr)
session.commit()
- dr_database = session.query(DagRun).filter(
- DagRun.run_id == 'test_dagrun_set_state_end_date'
- ).one()
+ dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_set_state_end_date').one()
self.assertIsNotNone(dr_database.end_date)
self.assertEqual(dr.end_date, dr_database.end_date)
@@ -501,9 +442,8 @@ def test_dagrun_update_state_end_date(self):
session = settings.Session()
dag = DAG(
- 'test_dagrun_update_state_end_date',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ 'test_dagrun_update_state_end_date', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}
+ )
# A -> B
with dag:
@@ -514,10 +454,12 @@ def test_dagrun_update_state_end_date(self):
dag.clear()
now = timezone.utcnow()
- dr = dag.create_dagrun(run_id='test_dagrun_update_state_end_date',
- state=State.RUNNING,
- execution_date=now,
- start_date=now)
+ dr = dag.create_dagrun(
+ run_id='test_dagrun_update_state_end_date',
+ state=State.RUNNING,
+ execution_date=now,
+ start_date=now,
+ )
# Initial end_date should be NULL
# State.SUCCESS and State.FAILED are both end states and should set end_date
@@ -533,9 +475,7 @@ def test_dagrun_update_state_end_date(self):
dr.update_state()
- dr_database = session.query(DagRun).filter(
- DagRun.run_id == 'test_dagrun_update_state_end_date'
- ).one()
+ dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_update_state_end_date').one()
self.assertIsNotNone(dr_database.end_date)
self.assertEqual(dr.end_date, dr_database.end_date)
@@ -543,9 +483,7 @@ def test_dagrun_update_state_end_date(self):
ti_op2.set_state(state=State.RUNNING, session=session)
dr.update_state()
- dr_database = session.query(DagRun).filter(
- DagRun.run_id == 'test_dagrun_update_state_end_date'
- ).one()
+ dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_update_state_end_date').one()
self.assertEqual(dr._state, State.RUNNING)
self.assertIsNone(dr.end_date)
@@ -555,9 +493,7 @@ def test_dagrun_update_state_end_date(self):
ti_op2.set_state(state=State.FAILED, session=session)
dr.update_state()
- dr_database = session.query(DagRun).filter(
- DagRun.run_id == 'test_dagrun_update_state_end_date'
- ).one()
+ dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_update_state_end_date').one()
self.assertIsNotNone(dr_database.end_date)
self.assertEqual(dr.end_date, dr_database.end_date)
@@ -566,14 +502,8 @@ def test_get_task_instance_on_empty_dagrun(self):
"""
Make sure that a proper value is returned when a dagrun has no task instances
"""
- dag = DAG(
- dag_id='test_get_task_instance_on_empty_dagrun',
- start_date=timezone.datetime(2017, 1, 1)
- )
- ShortCircuitOperator(
- task_id='test_short_circuit_false',
- dag=dag,
- python_callable=lambda: False)
+ dag = DAG(dag_id='test_get_task_instance_on_empty_dagrun', start_date=timezone.datetime(2017, 1, 1))
+ ShortCircuitOperator(task_id='test_short_circuit_false', dag=dag, python_callable=lambda: False)
session = settings.Session()
@@ -597,9 +527,7 @@ def test_get_task_instance_on_empty_dagrun(self):
def test_get_latest_runs(self):
session = settings.Session()
- dag = DAG(
- dag_id='test_latest_runs_1',
- start_date=DEFAULT_DATE)
+ dag = DAG(dag_id='test_latest_runs_1', start_date=DEFAULT_DATE)
self.create_dag_run(dag, execution_date=timezone.datetime(2015, 1, 1))
self.create_dag_run(dag, execution_date=timezone.datetime(2015, 1, 2))
dagruns = models.DagRun.get_latest_runs(session)
@@ -614,11 +542,9 @@ def test_is_backfill(self):
dagrun = self.create_dag_run(dag, execution_date=DEFAULT_DATE)
dagrun.run_type = DagRunType.BACKFILL_JOB
- dagrun2 = self.create_dag_run(
- dag, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
+ dagrun2 = self.create_dag_run(dag, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
- dagrun3 = self.create_dag_run(
- dag, execution_date=DEFAULT_DATE + datetime.timedelta(days=2))
+ dagrun3 = self.create_dag_run(dag, execution_date=DEFAULT_DATE + datetime.timedelta(days=2))
dagrun3.run_id = None
self.assertTrue(dagrun.is_backfill)
@@ -694,13 +620,15 @@ def mutate_task_instance(task_instance):
task = dagrun.get_task_instances()[0]
assert task.queue == 'queue1'
- @parameterized.expand([
- (State.SUCCESS, True),
- (State.SKIPPED, True),
- (State.RUNNING, False),
- (State.FAILED, False),
- (State.NONE, False),
- ])
+ @parameterized.expand(
+ [
+ (State.SUCCESS, True),
+ (State.SKIPPED, True),
+ (State.RUNNING, False),
+ (State.FAILED, False),
+ (State.NONE, False),
+ ]
+ )
def test_depends_on_past(self, prev_ti_state, is_ti_success):
dag_id = 'test_depends_on_past'
@@ -718,13 +646,15 @@ def test_depends_on_past(self, prev_ti_state, is_ti_success):
ti.run()
self.assertEqual(ti.state == State.SUCCESS, is_ti_success)
- @parameterized.expand([
- (State.SUCCESS, True),
- (State.SKIPPED, True),
- (State.RUNNING, False),
- (State.FAILED, False),
- (State.NONE, False),
- ])
+ @parameterized.expand(
+ [
+ (State.SUCCESS, True),
+ (State.SKIPPED, True),
+ (State.RUNNING, False),
+ (State.FAILED, False),
+ (State.NONE, False),
+ ]
+ )
def test_wait_for_downstream(self, prev_ti_state, is_ti_success):
dag_id = 'test_wait_for_downstream'
dag = self.dagbag.get_dag(dag_id)
@@ -751,13 +681,8 @@ def test_next_dagruns_to_examine_only_unpaused(self):
Check that "next_dagruns_to_examine" ignores runs from paused/inactive DAGs
"""
- dag = DAG(
- dag_id='test_dags',
- start_date=DEFAULT_DATE)
- DummyOperator(
- task_id='dummy',
- dag=dag,
- owner='airflow')
+ dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
+ DummyOperator(task_id='dummy', dag=dag, owner='airflow')
session = settings.Session()
orm_dag = DagModel(
@@ -769,11 +694,13 @@ def test_next_dagruns_to_examine_only_unpaused(self):
)
session.add(orm_dag)
session.flush()
- dr = dag.create_dagrun(run_type=DagRunType.SCHEDULED,
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
+ dr = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
runs = DagRun.next_dagruns_to_examine(session).all()
diff --git a/tests/models/test_pool.py b/tests/models/test_pool.py
index 1860ed446935a..bea4b2ebcfa94 100644
--- a/tests/models/test_pool.py
+++ b/tests/models/test_pool.py
@@ -31,7 +31,6 @@
class TestPool(unittest.TestCase):
-
def setUp(self):
clear_db_runs()
clear_db_pools()
@@ -44,7 +43,8 @@ def test_open_slots(self):
pool = Pool(pool='test_pool', slots=5)
dag = DAG(
dag_id='test_open_slots',
- start_date=DEFAULT_DATE, )
+ start_date=DEFAULT_DATE,
+ )
op1 = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
op2 = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')
ti1 = TI(task=op1, execution_date=DEFAULT_DATE)
@@ -63,26 +63,30 @@ def test_open_slots(self):
self.assertEqual(1, pool.running_slots()) # pylint: disable=no-value-for-parameter
self.assertEqual(1, pool.queued_slots()) # pylint: disable=no-value-for-parameter
self.assertEqual(2, pool.occupied_slots()) # pylint: disable=no-value-for-parameter
- self.assertEqual({
- "default_pool": {
- "open": 128,
- "queued": 0,
- "total": 128,
- "running": 0,
- },
- "test_pool": {
- "open": 3,
- "queued": 1,
- "running": 1,
- "total": 5,
+ self.assertEqual(
+ {
+ "default_pool": {
+ "open": 128,
+ "queued": 0,
+ "total": 128,
+ "running": 0,
+ },
+ "test_pool": {
+ "open": 3,
+ "queued": 1,
+ "running": 1,
+ "total": 5,
+ },
},
- }, pool.slots_stats())
+ pool.slots_stats(),
+ )
def test_infinite_slots(self):
pool = Pool(pool='test_pool', slots=-1)
dag = DAG(
dag_id='test_infinite_slots',
- start_date=DEFAULT_DATE, )
+ start_date=DEFAULT_DATE,
+ )
op1 = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
op2 = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')
ti1 = TI(task=op1, execution_date=DEFAULT_DATE)
@@ -108,7 +112,8 @@ def test_default_pool_open_slots(self):
dag = DAG(
dag_id='test_default_pool_open_slots',
- start_date=DEFAULT_DATE, )
+ start_date=DEFAULT_DATE,
+ )
op1 = DummyOperator(task_id='dummy1', dag=dag)
op2 = DummyOperator(task_id='dummy2', dag=dag, pool_slots=2)
ti1 = TI(task=op1, execution_date=DEFAULT_DATE)
diff --git a/tests/models/test_renderedtifields.py b/tests/models/test_renderedtifields.py
index ec988a7e7227e..d2eab8d70e434 100644
--- a/tests/models/test_renderedtifields.py
+++ b/tests/models/test_renderedtifields.py
@@ -68,36 +68,41 @@ def setUp(self):
def tearDown(self):
clear_rendered_ti_fields()
- @parameterized.expand([
- (None, None),
- ([], []),
- ({}, {}),
- ("test-string", "test-string"),
- ({"foo": "bar"}, {"foo": "bar"}),
- ("{{ task.task_id }}", "test"),
- (date(2018, 12, 6), "2018-12-06"),
- (datetime(2018, 12, 6, 10, 55), "2018-12-06 10:55:00+00:00"),
- (
- ClassWithCustomAttributes(
- att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"]),
- "ClassWithCustomAttributes({'att1': 'test', 'att2': '{{ task.task_id }}', "
- "'template_fields': ['att1']})",
- ),
- (
- ClassWithCustomAttributes(nested1=ClassWithCustomAttributes(att1="{{ task.task_id }}",
- att2="{{ task.task_id }}",
- template_fields=["att1"]),
- nested2=ClassWithCustomAttributes(att3="{{ task.task_id }}",
- att4="{{ task.task_id }}",
- template_fields=["att3"]),
- template_fields=["nested1"]),
- "ClassWithCustomAttributes({'nested1': ClassWithCustomAttributes("
- "{'att1': 'test', 'att2': '{{ task.task_id }}', 'template_fields': ['att1']}), "
- "'nested2': ClassWithCustomAttributes("
- "{'att3': '{{ task.task_id }}', 'att4': '{{ task.task_id }}', 'template_fields': ['att3']}), "
- "'template_fields': ['nested1']})",
- ),
- ])
+ @parameterized.expand(
+ [
+ (None, None),
+ ([], []),
+ ({}, {}),
+ ("test-string", "test-string"),
+ ({"foo": "bar"}, {"foo": "bar"}),
+ ("{{ task.task_id }}", "test"),
+ (date(2018, 12, 6), "2018-12-06"),
+ (datetime(2018, 12, 6, 10, 55), "2018-12-06 10:55:00+00:00"),
+ (
+ ClassWithCustomAttributes(
+ att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"]
+ ),
+ "ClassWithCustomAttributes({'att1': 'test', 'att2': '{{ task.task_id }}', "
+ "'template_fields': ['att1']})",
+ ),
+ (
+ ClassWithCustomAttributes(
+ nested1=ClassWithCustomAttributes(
+ att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"]
+ ),
+ nested2=ClassWithCustomAttributes(
+ att3="{{ task.task_id }}", att4="{{ task.task_id }}", template_fields=["att3"]
+ ),
+ template_fields=["nested1"],
+ ),
+ "ClassWithCustomAttributes({'nested1': ClassWithCustomAttributes("
+ "{'att1': 'test', 'att2': '{{ task.task_id }}', 'template_fields': ['att1']}), "
+ "'nested2': ClassWithCustomAttributes("
+ "{'att3': '{{ task.task_id }}', 'att4': '{{ task.task_id }}', 'template_fields': ['att3']}), "
+ "'template_fields': ['nested1']})",
+ ),
+ ]
+ )
def test_get_templated_fields(self, templated_field, expected_rendered_field):
"""
Test that template_fields are rendered correctly, stored in the Database,
@@ -118,8 +123,8 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field):
session.add(rtif)
self.assertEqual(
- {"bash_command": expected_rendered_field, "env": None},
- RTIF.get_templated_fields(ti=ti))
+ {"bash_command": expected_rendered_field, "env": None}, RTIF.get_templated_fields(ti=ti)
+ )
# Test the else part of get_templated_fields
# i.e. for the TIs that are not stored in RTIF table
@@ -130,14 +135,16 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field):
ti2 = TI(task_2, EXECUTION_DATE)
self.assertIsNone(RTIF.get_templated_fields(ti=ti2))
- @parameterized.expand([
- (0, 1, 0, 1),
- (1, 1, 1, 1),
- (1, 0, 1, 0),
- (3, 1, 1, 1),
- (4, 2, 2, 1),
- (5, 2, 2, 1),
- ])
+ @parameterized.expand(
+ [
+ (0, 1, 0, 1),
+ (1, 1, 1, 1),
+ (1, 0, 1, 0),
+ (3, 1, 1, 1),
+ (4, 2, 2, 1),
+ (5, 2, 2, 1),
+ ]
+ )
def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count):
"""
Test that old records are deleted from rendered_task_instance_fields table
@@ -156,8 +163,7 @@ def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expect
session.add_all(rtif_list)
session.commit()
- result = session.query(RTIF)\
- .filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
+ result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
for rtif in rtif_list:
self.assertIn(rtif, result)
@@ -167,8 +173,7 @@ def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expect
# Verify old records are deleted and only 'num_to_keep' records are kept
with assert_queries_count(expected_query_count):
RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
- result = session.query(RTIF) \
- .filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
+ result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
self.assertEqual(remaining_rtifs, len(result))
def test_write(self):
@@ -186,17 +191,16 @@ def test_write(self):
rtif = RTIF(TI(task=task, execution_date=EXECUTION_DATE))
rtif.write()
- result = session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields).filter(
- RTIF.dag_id == rtif.dag_id,
- RTIF.task_id == rtif.task_id,
- RTIF.execution_date == rtif.execution_date
- ).first()
- self.assertEqual(
- (
- 'test_write', 'test', {
- 'bash_command': 'echo test_val', 'env': None
- }
- ), result)
+ result = (
+ session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
+ .filter(
+ RTIF.dag_id == rtif.dag_id,
+ RTIF.task_id == rtif.task_id,
+ RTIF.execution_date == rtif.execution_date,
+ )
+ .first()
+ )
+ self.assertEqual(('test_write', 'test', {'bash_command': 'echo test_val', 'env': None}), result)
# Test that overwrite saves new values to the DB
Variable.delete("test_key")
@@ -208,14 +212,15 @@ def test_write(self):
rtif_updated = RTIF(TI(task=updated_task, execution_date=EXECUTION_DATE))
rtif_updated.write()
- result_updated = session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields).filter(
- RTIF.dag_id == rtif_updated.dag_id,
- RTIF.task_id == rtif_updated.task_id,
- RTIF.execution_date == rtif_updated.execution_date
- ).first()
+ result_updated = (
+ session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
+ .filter(
+ RTIF.dag_id == rtif_updated.dag_id,
+ RTIF.task_id == rtif_updated.task_id,
+ RTIF.execution_date == rtif_updated.execution_date,
+ )
+ .first()
+ )
self.assertEqual(
- (
- 'test_write', 'test', {
- 'bash_command': 'echo test_val_updated', 'env': None
- }
- ), result_updated)
+ ('test_write', 'test', {'bash_command': 'echo test_val_updated', 'env': None}), result_updated
+ )
diff --git a/tests/models/test_sensorinstance.py b/tests/models/test_sensorinstance.py
index 168d97f82d916..6246df8c14d98 100644
--- a/tests/models/test_sensorinstance.py
+++ b/tests/models/test_sensorinstance.py
@@ -24,22 +24,20 @@
class SensorInstanceTest(unittest.TestCase):
-
def test_get_classpath(self):
# Test the classpath in/out airflow
- obj1 = NamedHivePartitionSensor(
- partition_names=['test_partition'],
- task_id='meta_partition_test_1')
+ obj1 = NamedHivePartitionSensor(partition_names=['test_partition'], task_id='meta_partition_test_1')
obj1_classpath = SensorInstance.get_classpath(obj1)
- obj1_importpath = "airflow.providers.apache.hive." \
- "sensors.named_hive_partition.NamedHivePartitionSensor"
+ obj1_importpath = (
+ "airflow.providers.apache.hive.sensors.named_hive_partition.NamedHivePartitionSensor"
+ )
self.assertEqual(obj1_classpath, obj1_importpath)
def test_callable():
return
- obj3 = PythonSensor(python_callable=test_callable,
- task_id='python_sensor_test')
+
+ obj3 = PythonSensor(python_callable=test_callable, task_id='python_sensor_test')
obj3_classpath = SensorInstance.get_classpath(obj3)
obj3_importpath = "airflow.sensors.python.PythonSensor"
diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py
index a6309327f2af4..590ceed56220d 100644
--- a/tests/models/test_serialized_dag.py
+++ b/tests/models/test_serialized_dag.py
@@ -52,8 +52,7 @@ def tearDown(self):
def test_dag_fileloc_hash(self):
"""Verifies the correctness of hashing file path."""
- self.assertEqual(DagCode.dag_fileloc_hash('/airflow/dags/test_dag.py'),
- 33826252060516589)
+ self.assertEqual(DagCode.dag_fileloc_hash('/airflow/dags/test_dag.py'), 33826252060516589)
def _write_example_dags(self):
example_dags = make_example_dags(example_dags_module)
@@ -68,8 +67,7 @@ def test_write_dag(self):
with create_session() as session:
for dag in example_dags.values():
self.assertTrue(SDM.has_dag(dag.dag_id))
- result = session.query(
- SDM.fileloc, SDM.data).filter(SDM.dag_id == dag.dag_id).one()
+ result = session.query(SDM.fileloc, SDM.data).filter(SDM.dag_id == dag.dag_id).one()
self.assertTrue(result.fileloc == dag.full_filepath)
# Verifies JSON schema.
@@ -142,7 +140,9 @@ def test_remove_dags_by_filepath(self):
def test_bulk_sync_to_db(self):
dags = [
- DAG("dag_1"), DAG("dag_2"), DAG("dag_3"),
+ DAG("dag_1"),
+ DAG("dag_2"),
+ DAG("dag_3"),
]
with assert_queries_count(10):
SDM.bulk_sync_to_db(dags)
diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py
index 0a96ee23b6195..a9145701c6458 100644
--- a/tests/models/test_skipmixin.py
+++ b/tests/models/test_skipmixin.py
@@ -35,7 +35,6 @@
class TestSkipMixin(unittest.TestCase):
-
@patch('airflow.utils.timezone.utcnow')
def test_skip(self, mock_now):
session = settings.Session()
@@ -52,11 +51,7 @@ def test_skip(self, mock_now):
execution_date=now,
state=State.FAILED,
)
- SkipMixin().skip(
- dag_run=dag_run,
- execution_date=now,
- tasks=tasks,
- session=session)
+ SkipMixin().skip(dag_run=dag_run, execution_date=now, tasks=tasks, session=session)
session.query(TI).filter(
TI.dag_id == 'dag',
@@ -77,11 +72,7 @@ def test_skip_none_dagrun(self, mock_now):
)
with dag:
tasks = [DummyOperator(task_id='task')]
- SkipMixin().skip(
- dag_run=None,
- execution_date=now,
- tasks=tasks,
- session=session)
+ SkipMixin().skip(dag_run=None, execution_date=now, tasks=tasks, session=session)
session.query(TI).filter(
TI.dag_id == 'dag',
@@ -113,10 +104,7 @@ def test_skip_all_except(self):
ti2 = TI(task2, execution_date=DEFAULT_DATE)
ti3 = TI(task3, execution_date=DEFAULT_DATE)
- SkipMixin().skip_all_except(
- ti=ti1,
- branch_task_ids=['task2']
- )
+ SkipMixin().skip_all_except(ti=ti1, branch_task_ids=['task2'])
def get_state(ti):
ti.refresh_from_db()
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index bcf4ca0346530..f91d9070d9189 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -34,7 +34,14 @@
from airflow.exceptions import AirflowException, AirflowFailException, AirflowSkipException
from airflow.jobs.scheduler_job import SchedulerJob
from airflow.models import (
- DAG, DagModel, DagRun, Pool, RenderedTaskInstanceFields, TaskInstance as TI, TaskReschedule, Variable,
+ DAG,
+ DagModel,
+ DagRun,
+ Pool,
+ RenderedTaskInstanceFields,
+ TaskInstance as TI,
+ TaskReschedule,
+ Variable,
)
from airflow.operators.bash import BashOperator
from airflow.operators.dummy_operator import DummyOperator
@@ -74,15 +81,17 @@ def wrap_task_instance(self, ti):
def success_handler(self, context): # pylint: disable=unused-argument
self.callback_ran = True
session = settings.Session()
- temp_instance = session.query(TI).filter(
- TI.task_id == self.task_id).filter(
- TI.dag_id == self.dag_id).filter(
- TI.execution_date == self.execution_date).one()
+ temp_instance = (
+ session.query(TI)
+ .filter(TI.task_id == self.task_id)
+ .filter(TI.dag_id == self.dag_id)
+ .filter(TI.execution_date == self.execution_date)
+ .one()
+ )
self.task_state_in_callback = temp_instance.state
class TestTaskInstance(unittest.TestCase):
-
@staticmethod
def clean_db():
db.clear_db_dags()
@@ -106,8 +115,7 @@ def test_set_task_dates(self):
"""
Test that tasks properly take start/end dates from DAGs
"""
- dag = DAG('dag', start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=10))
+ dag = DAG('dag', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10))
op1 = DummyOperator(task_id='op_1', owner='test')
@@ -115,31 +123,29 @@ def test_set_task_dates(self):
# dag should assign its dates to op1 because op1 has no dates
dag.add_task(op1)
- self.assertTrue(
- op1.start_date == dag.start_date and op1.end_date == dag.end_date)
+ self.assertTrue(op1.start_date == dag.start_date and op1.end_date == dag.end_date)
op2 = DummyOperator(
task_id='op_2',
owner='test',
start_date=DEFAULT_DATE - datetime.timedelta(days=1),
- end_date=DEFAULT_DATE + datetime.timedelta(days=11))
+ end_date=DEFAULT_DATE + datetime.timedelta(days=11),
+ )
# dag should assign its dates to op2 because they are more restrictive
dag.add_task(op2)
- self.assertTrue(
- op2.start_date == dag.start_date and op2.end_date == dag.end_date)
+ self.assertTrue(op2.start_date == dag.start_date and op2.end_date == dag.end_date)
op3 = DummyOperator(
task_id='op_3',
owner='test',
start_date=DEFAULT_DATE + datetime.timedelta(days=1),
- end_date=DEFAULT_DATE + datetime.timedelta(days=9))
+ end_date=DEFAULT_DATE + datetime.timedelta(days=9),
+ )
# op3 should keep its dates because they are more restrictive
dag.add_task(op3)
- self.assertTrue(
- op3.start_date == DEFAULT_DATE + datetime.timedelta(days=1))
- self.assertTrue(
- op3.end_date == DEFAULT_DATE + datetime.timedelta(days=9))
+ self.assertTrue(op3.start_date == DEFAULT_DATE + datetime.timedelta(days=1))
+ self.assertTrue(op3.end_date == DEFAULT_DATE + datetime.timedelta(days=9))
def test_timezone_awareness(self):
naive_datetime = DEFAULT_DATE.replace(tzinfo=None)
@@ -168,9 +174,9 @@ def test_timezone_awareness(self):
def test_task_naive_datetime(self):
naive_datetime = DEFAULT_DATE.replace(tzinfo=None)
- op_no_dag = DummyOperator(task_id='test_task_naive_datetime',
- start_date=naive_datetime,
- end_date=naive_datetime)
+ op_no_dag = DummyOperator(
+ task_id='test_task_naive_datetime', start_date=naive_datetime, end_date=naive_datetime
+ )
self.assertTrue(op_no_dag.start_date.tzinfo)
self.assertTrue(op_no_dag.end_date.tzinfo)
@@ -213,9 +219,7 @@ def test_infer_dag(self):
op4 = DummyOperator(task_id='test_op_4', owner='test', dag=dag2)
# double check dags
- self.assertEqual(
- [i.has_dag() for i in [op1, op2, op3, op4]],
- [False, False, True, True])
+ self.assertEqual([i.has_dag() for i in [op1, op2, op3, op4]], [False, False, True, True])
# can't combine operators with no dags
self.assertRaises(AirflowException, op1.set_downstream, op2)
@@ -246,8 +250,12 @@ def test_bitshift_compose_operators(self):
def test_requeue_over_dag_concurrency(self, mock_concurrency_reached):
mock_concurrency_reached.return_value = True
- dag = DAG(dag_id='test_requeue_over_dag_concurrency', start_date=DEFAULT_DATE,
- max_active_runs=1, concurrency=2)
+ dag = DAG(
+ dag_id='test_requeue_over_dag_concurrency',
+ start_date=DEFAULT_DATE,
+ max_active_runs=1,
+ concurrency=2,
+ )
task = DummyOperator(task_id='test_requeue_over_dag_concurrency_op', dag=dag)
ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
@@ -259,10 +267,13 @@ def test_requeue_over_dag_concurrency(self, mock_concurrency_reached):
self.assertEqual(ti.state, State.NONE)
def test_requeue_over_task_concurrency(self):
- dag = DAG(dag_id='test_requeue_over_task_concurrency', start_date=DEFAULT_DATE,
- max_active_runs=1, concurrency=2)
- task = DummyOperator(task_id='test_requeue_over_task_concurrency_op', dag=dag,
- task_concurrency=0)
+ dag = DAG(
+ dag_id='test_requeue_over_task_concurrency',
+ start_date=DEFAULT_DATE,
+ max_active_runs=1,
+ concurrency=2,
+ )
+ task = DummyOperator(task_id='test_requeue_over_task_concurrency_op', dag=dag, task_concurrency=0)
ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
# TI.run() will sync from DB before validating deps.
@@ -273,10 +284,13 @@ def test_requeue_over_task_concurrency(self):
self.assertEqual(ti.state, State.NONE)
def test_requeue_over_pool_concurrency(self):
- dag = DAG(dag_id='test_requeue_over_pool_concurrency', start_date=DEFAULT_DATE,
- max_active_runs=1, concurrency=2)
- task = DummyOperator(task_id='test_requeue_over_pool_concurrency_op', dag=dag,
- task_concurrency=0)
+ dag = DAG(
+ dag_id='test_requeue_over_pool_concurrency',
+ start_date=DEFAULT_DATE,
+ max_active_runs=1,
+ concurrency=2,
+ )
+ task = DummyOperator(task_id='test_requeue_over_pool_concurrency_op', dag=dag, task_concurrency=0)
ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
# TI.run() will sync from DB before validating deps.
@@ -297,9 +311,9 @@ def test_not_requeue_non_requeueable_task_instance(self):
dag=dag,
pool='test_pool',
owner='airflow',
- start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
- ti = TI(
- task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
+ start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+ )
+ ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
with create_session() as session:
session.add(ti)
session.commit()
@@ -309,16 +323,13 @@ def test_not_requeue_non_requeueable_task_instance(self):
patch_dict = {}
for dep in all_non_requeueable_deps:
class_name = dep.__class__.__name__
- dep_patch = patch('{}.{}.{}'.format(dep.__module__, class_name,
- dep._get_dep_statuses.__name__))
+ dep_patch = patch(f'{dep.__module__}.{class_name}.{dep._get_dep_statuses.__name__}')
method_patch = dep_patch.start()
- method_patch.return_value = iter([TIDepStatus('mock_' + class_name, True,
- 'mock')])
+ method_patch.return_value = iter([TIDepStatus('mock_' + class_name, True, 'mock')])
patch_dict[class_name] = (dep_patch, method_patch)
for class_name, (dep_patch, method_patch) in patch_dict.items():
- method_patch.return_value = iter(
- [TIDepStatus('mock_' + class_name, False, 'mock')])
+ method_patch.return_value = iter([TIDepStatus('mock_' + class_name, False, 'mock')])
ti.run()
self.assertEqual(ti.state, State.QUEUED)
dep_patch.return_value = TIDepStatus('mock_' + class_name, True, 'mock')
@@ -331,17 +342,16 @@ def test_mark_non_runnable_task_as_success(self):
        test that running a task with the mark_success param updates task state
        to SUCCESS without running the task even though it fails dependency checks.
"""
- non_runnable_state = (
- set(State.task_states) - RUNNABLE_STATES - set(State.SUCCESS)).pop()
+ non_runnable_state = (set(State.task_states) - RUNNABLE_STATES - set(State.SUCCESS)).pop()
dag = models.DAG(dag_id='test_mark_non_runnable_task_as_success')
task = DummyOperator(
task_id='test_mark_non_runnable_task_as_success_op',
dag=dag,
pool='test_pool',
owner='airflow',
- start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
- ti = TI(
- task=task, execution_date=timezone.utcnow(), state=non_runnable_state)
+ start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+ )
+ ti = TI(task=task, execution_date=timezone.utcnow(), state=non_runnable_state)
# TI.run() will sync from DB before validating deps.
with create_session() as session:
session.add(ti)
@@ -361,9 +371,13 @@ def test_run_pooling_task(self):
        test that running a task in an existing pool updates task state to SUCCESS.
"""
dag = models.DAG(dag_id='test_run_pooling_task')
- task = DummyOperator(task_id='test_run_pooling_task_op', dag=dag,
- pool='test_pool', owner='airflow',
- start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
+ task = DummyOperator(
+ task_id='test_run_pooling_task_op',
+ dag=dag,
+ pool='test_pool',
+ owner='airflow',
+ start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+ )
ti = TI(task=task, execution_date=timezone.utcnow())
dag.create_dagrun(
@@ -380,11 +394,17 @@ def test_pool_slots_property(self):
"""
        test that trying to create a task with pool_slots less than 1 raises an exception
"""
+
def create_task_instance():
dag = models.DAG(dag_id='test_run_pooling_task')
- task = DummyOperator(task_id='test_run_pooling_task_op', dag=dag,
- pool='test_pool', pool_slots=0, owner='airflow',
- start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
+ task = DummyOperator(
+ task_id='test_run_pooling_task_op',
+ dag=dag,
+ pool='test_pool',
+ pool_slots=0,
+ owner='airflow',
+ start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+ )
return TI(task=task, execution_date=timezone.utcnow())
self.assertRaises(AirflowException, create_task_instance)
@@ -395,9 +415,12 @@ def test_ti_updates_with_task(self, session=None):
        test that updating the executor_config propagates to the TaskInstance DB
"""
with models.DAG(dag_id='test_run_pooling_task') as dag:
- task = DummyOperator(task_id='test_run_pooling_task_op', owner='airflow',
- executor_config={'foo': 'bar'},
- start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
+ task = DummyOperator(
+ task_id='test_run_pooling_task_op',
+ owner='airflow',
+ executor_config={'foo': 'bar'},
+ start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+ )
ti = TI(task=task, execution_date=timezone.utcnow())
dag.create_dagrun(
@@ -411,9 +434,12 @@ def test_ti_updates_with_task(self, session=None):
tis = dag.get_task_instances()
self.assertEqual({'foo': 'bar'}, tis[0].executor_config)
with models.DAG(dag_id='test_run_pooling_task') as dag:
- task2 = DummyOperator(task_id='test_run_pooling_task_op', owner='airflow',
- executor_config={'bar': 'baz'},
- start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
+ task2 = DummyOperator(
+ task_id='test_run_pooling_task_op',
+ owner='airflow',
+ executor_config={'bar': 'baz'},
+ start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+ )
ti = TI(task=task2, execution_date=timezone.utcnow())
@@ -440,7 +466,8 @@ def test_run_pooling_task_with_mark_success(self):
dag=dag,
pool='test_pool',
owner='airflow',
- start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
+ start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+ )
ti = TI(task=task, execution_date=timezone.utcnow())
dag.create_dagrun(
@@ -466,7 +493,8 @@ def raise_skip_exception():
dag=dag,
python_callable=raise_skip_exception,
owner='airflow',
- start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
+ start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+ )
ti = TI(task=task, execution_date=timezone.utcnow())
dag.create_dagrun(
execution_date=ti.execution_date,
@@ -488,7 +516,8 @@ def test_retry_delay(self):
retry_delay=datetime.timedelta(seconds=3),
dag=dag,
owner='airflow',
- start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
+ start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+ )
def run_with_error(ti):
try:
@@ -532,7 +561,8 @@ def test_retry_handling(self):
retry_delay=datetime.timedelta(seconds=0),
dag=dag,
owner='test_pool',
- start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
+ start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+ )
def run_with_error(ti):
try:
@@ -540,8 +570,7 @@ def run_with_error(ti):
except AirflowException:
pass
- ti = TI(
- task=task, execution_date=timezone.utcnow())
+ ti = TI(task=task, execution_date=timezone.utcnow())
self.assertEqual(ti.try_number, 1)
# first run -- up for retry
@@ -588,9 +617,9 @@ def test_next_retry_datetime(self):
max_retry_delay=max_delay,
dag=dag,
owner='airflow',
- start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
- ti = TI(
- task=task, execution_date=DEFAULT_DATE)
+ start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+ )
+ ti = TI(task=task, execution_date=DEFAULT_DATE)
ti.end_date = pendulum.instance(timezone.utcnow())
date = ti.next_retry_datetime()
@@ -632,9 +661,9 @@ def test_next_retry_datetime_short_intervals(self):
max_retry_delay=max_delay,
dag=dag,
owner='airflow',
- start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
- ti = TI(
- task=task, execution_date=DEFAULT_DATE)
+ start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+ )
+ ti = TI(task=task, execution_date=DEFAULT_DATE)
ti.end_date = pendulum.instance(timezone.utcnow())
date = ti.next_retry_datetime()
@@ -666,7 +695,8 @@ def func():
dag=dag,
owner='airflow',
pool='test_pool',
- start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
+ start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+ )
ti = TI(task=task, execution_date=timezone.utcnow())
self.assertEqual(ti._try_number, 0)
@@ -678,10 +708,15 @@ def func():
run_type=DagRunType.SCHEDULED,
)
- def run_ti_and_assert(run_date, expected_start_date, expected_end_date,
- expected_duration,
- expected_state, expected_try_number,
- expected_task_reschedule_count):
+ def run_ti_and_assert(
+ run_date,
+ expected_start_date,
+ expected_end_date,
+ expected_duration,
+ expected_state,
+ expected_try_number,
+ expected_task_reschedule_count,
+ ):
with freeze_time(run_date):
try:
ti.run()
@@ -767,16 +802,22 @@ def func():
dag=dag,
owner='airflow',
pool='test_pool',
- start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
+ start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+ )
ti = TI(task=task, execution_date=timezone.utcnow())
self.assertEqual(ti._try_number, 0)
self.assertEqual(ti.try_number, 1)
- def run_ti_and_assert(run_date, expected_start_date, expected_end_date,
- expected_duration,
- expected_state, expected_try_number,
- expected_task_reschedule_count):
+ def run_ti_and_assert(
+ run_date,
+ expected_start_date,
+ expected_end_date,
+ expected_duration,
+ expected_state,
+ expected_try_number,
+ expected_task_reschedule_count,
+ ):
with freeze_time(run_date):
try:
ti.run()
@@ -808,10 +849,7 @@ def run_ti_and_assert(run_date, expected_start_date, expected_end_date,
self.assertFalse(trs)
def test_depends_on_past(self):
- dag = DAG(
- dag_id='test_depends_on_past',
- start_date=DEFAULT_DATE
- )
+ dag = DAG(dag_id='test_depends_on_past', start_date=DEFAULT_DATE)
task = DummyOperator(
task_id='test_dop_task',
@@ -836,10 +874,7 @@ def test_depends_on_past(self):
self.assertIs(ti.state, None)
# ignore first depends_on_past to allow the run
- task.run(
- start_date=run_date,
- end_date=run_date,
- ignore_first_depends_on_past=True)
+ task.run(start_date=run_date, end_date=run_date, ignore_first_depends_on_past=True)
ti.refresh_from_db()
self.assertEqual(ti.state, State.SUCCESS)
@@ -847,58 +882,64 @@ def test_depends_on_past(self):
# of the trigger_rule under various circumstances
# Numeric fields are in order:
# successes, skipped, failed, upstream_failed, done
- @parameterized.expand([
-
- #
- # Tests for all_success
- #
- ['all_success', 5, 0, 0, 0, 0, True, None, True],
- ['all_success', 2, 0, 0, 0, 0, True, None, False],
- ['all_success', 2, 0, 1, 0, 0, True, State.UPSTREAM_FAILED, False],
- ['all_success', 2, 1, 0, 0, 0, True, State.SKIPPED, False],
- #
- # Tests for one_success
- #
- ['one_success', 5, 0, 0, 0, 5, True, None, True],
- ['one_success', 2, 0, 0, 0, 2, True, None, True],
- ['one_success', 2, 0, 1, 0, 3, True, None, True],
- ['one_success', 2, 1, 0, 0, 3, True, None, True],
- #
- # Tests for all_failed
- #
- ['all_failed', 5, 0, 0, 0, 5, True, State.SKIPPED, False],
- ['all_failed', 0, 0, 5, 0, 5, True, None, True],
- ['all_failed', 2, 0, 0, 0, 2, True, State.SKIPPED, False],
- ['all_failed', 2, 0, 1, 0, 3, True, State.SKIPPED, False],
- ['all_failed', 2, 1, 0, 0, 3, True, State.SKIPPED, False],
- #
- # Tests for one_failed
- #
- ['one_failed', 5, 0, 0, 0, 0, True, None, False],
- ['one_failed', 2, 0, 0, 0, 0, True, None, False],
- ['one_failed', 2, 0, 1, 0, 0, True, None, True],
- ['one_failed', 2, 1, 0, 0, 3, True, None, False],
- ['one_failed', 2, 3, 0, 0, 5, True, State.SKIPPED, False],
- #
- # Tests for done
- #
- ['all_done', 5, 0, 0, 0, 5, True, None, True],
- ['all_done', 2, 0, 0, 0, 2, True, None, False],
- ['all_done', 2, 0, 1, 0, 3, True, None, False],
- ['all_done', 2, 1, 0, 0, 3, True, None, False]
- ])
- def test_check_task_dependencies(self, trigger_rule, successes, skipped,
- failed, upstream_failed, done,
- flag_upstream_failed,
- expect_state, expect_completed):
+ @parameterized.expand(
+ [
+ #
+ # Tests for all_success
+ #
+ ['all_success', 5, 0, 0, 0, 0, True, None, True],
+ ['all_success', 2, 0, 0, 0, 0, True, None, False],
+ ['all_success', 2, 0, 1, 0, 0, True, State.UPSTREAM_FAILED, False],
+ ['all_success', 2, 1, 0, 0, 0, True, State.SKIPPED, False],
+ #
+ # Tests for one_success
+ #
+ ['one_success', 5, 0, 0, 0, 5, True, None, True],
+ ['one_success', 2, 0, 0, 0, 2, True, None, True],
+ ['one_success', 2, 0, 1, 0, 3, True, None, True],
+ ['one_success', 2, 1, 0, 0, 3, True, None, True],
+ #
+ # Tests for all_failed
+ #
+ ['all_failed', 5, 0, 0, 0, 5, True, State.SKIPPED, False],
+ ['all_failed', 0, 0, 5, 0, 5, True, None, True],
+ ['all_failed', 2, 0, 0, 0, 2, True, State.SKIPPED, False],
+ ['all_failed', 2, 0, 1, 0, 3, True, State.SKIPPED, False],
+ ['all_failed', 2, 1, 0, 0, 3, True, State.SKIPPED, False],
+ #
+ # Tests for one_failed
+ #
+ ['one_failed', 5, 0, 0, 0, 0, True, None, False],
+ ['one_failed', 2, 0, 0, 0, 0, True, None, False],
+ ['one_failed', 2, 0, 1, 0, 0, True, None, True],
+ ['one_failed', 2, 1, 0, 0, 3, True, None, False],
+ ['one_failed', 2, 3, 0, 0, 5, True, State.SKIPPED, False],
+ #
+ # Tests for done
+ #
+ ['all_done', 5, 0, 0, 0, 5, True, None, True],
+ ['all_done', 2, 0, 0, 0, 2, True, None, False],
+ ['all_done', 2, 0, 1, 0, 3, True, None, False],
+ ['all_done', 2, 1, 0, 0, 3, True, None, False],
+ ]
+ )
+ def test_check_task_dependencies(
+ self,
+ trigger_rule,
+ successes,
+ skipped,
+ failed,
+ upstream_failed,
+ done,
+ flag_upstream_failed,
+ expect_state,
+ expect_completed,
+ ):
start_date = timezone.datetime(2016, 2, 1, 0, 0, 0)
dag = models.DAG('test-dag', start_date=start_date)
- downstream = DummyOperator(task_id='downstream',
- dag=dag, owner='airflow',
- trigger_rule=trigger_rule)
+ downstream = DummyOperator(task_id='downstream', dag=dag, owner='airflow', trigger_rule=trigger_rule)
for i in range(5):
- task = DummyOperator(task_id=f'runme_{i}',
- dag=dag, owner='airflow')
+ task = DummyOperator(task_id=f'runme_{i}', dag=dag, owner='airflow')
task.set_downstream(downstream)
run_date = task.start_date + datetime.timedelta(days=5)
@@ -923,20 +964,24 @@ def test_respects_prev_dagrun_dep(self):
ti = TI(task, DEFAULT_DATE)
failing_status = [TIDepStatus('test fail status name', False, 'test fail reason')]
passing_status = [TIDepStatus('test pass status name', True, 'test passing reason')]
- with patch('airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses',
- return_value=failing_status):
+ with patch(
+ 'airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses', return_value=failing_status
+ ):
self.assertFalse(ti.are_dependencies_met())
- with patch('airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses',
- return_value=passing_status):
+ with patch(
+ 'airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses', return_value=passing_status
+ ):
self.assertTrue(ti.are_dependencies_met())
- @parameterized.expand([
- (State.SUCCESS, True),
- (State.SKIPPED, True),
- (State.RUNNING, False),
- (State.FAILED, False),
- (State.NONE, False),
- ])
+ @parameterized.expand(
+ [
+ (State.SUCCESS, True),
+ (State.SKIPPED, True),
+ (State.RUNNING, False),
+ (State.FAILED, False),
+ (State.NONE, False),
+ ]
+ )
def test_are_dependents_done(self, downstream_ti_state, expected_are_dependents_done):
with DAG(dag_id='test_dag'):
task = DummyOperator(task_id='task', start_date=DEFAULT_DATE)
@@ -954,8 +999,10 @@ def test_xcom_pull(self):
Test xcom_pull, using different filtering methods.
"""
dag = models.DAG(
- dag_id='test_xcom', schedule_interval='@monthly',
- start_date=timezone.datetime(2016, 6, 1, 0, 0, 0))
+ dag_id='test_xcom',
+ schedule_interval='@monthly',
+ start_date=timezone.datetime(2016, 6, 1, 0, 0, 0),
+ )
exec_date = timezone.utcnow()
@@ -982,8 +1029,7 @@ def test_xcom_pull(self):
result = ti1.xcom_pull(task_ids='test_xcom_2', key='foo')
self.assertEqual(result, 'baz')
# Pull the values pushed by both tasks
- result = ti1.xcom_pull(
- task_ids=['test_xcom_1', 'test_xcom_2'], key='foo')
+ result = ti1.xcom_pull(task_ids=['test_xcom_1', 'test_xcom_2'], key='foo')
self.assertEqual(result, ['baz', 'bar'])
def test_xcom_pull_after_success(self):
@@ -999,7 +1045,8 @@ def test_xcom_pull_after_success(self):
dag=dag,
pool='test_xcom',
owner='airflow',
- start_date=timezone.datetime(2016, 6, 2, 0, 0, 0))
+ start_date=timezone.datetime(2016, 6, 2, 0, 0, 0),
+ )
exec_date = timezone.utcnow()
ti = TI(task=task, execution_date=exec_date)
@@ -1039,7 +1086,8 @@ def test_xcom_pull_different_execution_date(self):
dag=dag,
pool='test_xcom',
owner='airflow',
- start_date=timezone.datetime(2016, 6, 2, 0, 0, 0))
+ start_date=timezone.datetime(2016, 6, 2, 0, 0, 0),
+ )
exec_date = timezone.utcnow()
ti = TI(task=task, execution_date=exec_date)
@@ -1054,18 +1102,14 @@ def test_xcom_pull_different_execution_date(self):
self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
ti.run()
exec_date += datetime.timedelta(days=1)
- ti = TI(
- task=task, execution_date=exec_date)
+ ti = TI(task=task, execution_date=exec_date)
ti.run()
# We have set a new execution date (and did not pass in
        # 'include_prior_dates') which means this task should now have a cleared
# xcom value
self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), None)
# We *should* get a value using 'include_prior_dates'
- self.assertEqual(ti.xcom_pull(task_ids='test_xcom',
- key=key,
- include_prior_dates=True),
- value)
+ self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key, include_prior_dates=True), value)
def test_xcom_push_flag(self):
"""
@@ -1082,7 +1126,7 @@ def test_xcom_push_flag(self):
python_callable=lambda: value,
do_xcom_push=False,
owner='airflow',
- start_date=datetime.datetime(2017, 1, 1)
+ start_date=datetime.datetime(2017, 1, 1),
)
ti = TI(task=task, execution_date=datetime.datetime(2017, 1, 1))
dag.create_dagrun(
@@ -1091,12 +1135,7 @@ def test_xcom_push_flag(self):
run_type=DagRunType.SCHEDULED,
)
ti.run()
- self.assertEqual(
- ti.xcom_pull(
- task_ids=task_id, key=models.XCOM_RETURN_KEY
- ),
- None
- )
+ self.assertEqual(ti.xcom_pull(task_ids=task_id, key=models.XCOM_RETURN_KEY), None)
def test_post_execute_hook(self):
"""
@@ -1118,7 +1157,8 @@ def post_execute(self, context, result=None):
dag=dag,
python_callable=lambda: 'error',
owner='airflow',
- start_date=timezone.datetime(2017, 2, 1))
+ start_date=timezone.datetime(2017, 2, 1),
+ )
ti = TI(task=task, execution_date=timezone.utcnow())
with self.assertRaises(TestError):
@@ -1127,8 +1167,7 @@ def post_execute(self, context, result=None):
def test_check_and_change_state_before_execution(self):
dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
- ti = TI(
- task=task, execution_date=timezone.utcnow())
+ ti = TI(task=task, execution_date=timezone.utcnow())
self.assertEqual(ti._try_number, 0)
self.assertTrue(ti.check_and_change_state_before_execution())
# State should be running, and try_number column should be incremented
@@ -1140,8 +1179,7 @@ def test_check_and_change_state_before_execution_dep_not_met(self):
task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
task2 = DummyOperator(task_id='task2', dag=dag, start_date=DEFAULT_DATE)
task >> task2
- ti = TI(
- task=task2, execution_date=timezone.utcnow())
+ ti = TI(task=task2, execution_date=timezone.utcnow())
self.assertFalse(ti.check_and_change_state_before_execution())
def test_try_number(self):
@@ -1212,8 +1250,8 @@ def test_mark_success_url(self):
task = DummyOperator(task_id='op', dag=dag)
ti = TI(task=task, execution_date=now)
query = urllib.parse.parse_qs(
- urllib.parse.urlparse(ti.mark_success_url).query,
- keep_blank_values=True, strict_parsing=True)
+ urllib.parse.urlparse(ti.mark_success_url).query, keep_blank_values=True, strict_parsing=True
+ )
self.assertEqual(query['dag_id'][0], 'dag')
self.assertEqual(query['task_id'][0], 'op')
self.assertEqual(pendulum.parse(query['execution_date'][0]), now)
@@ -1252,11 +1290,8 @@ def test_overwrite_params_with_dag_run_conf_none(self):
def test_email_alert(self, mock_send_email):
dag = models.DAG(dag_id='test_failure_email')
task = BashOperator(
- task_id='test_email_alert',
- dag=dag,
- bash_command='exit 1',
- start_date=DEFAULT_DATE,
- email='to')
+ task_id='test_email_alert', dag=dag, bash_command='exit 1', start_date=DEFAULT_DATE, email='to'
+ )
ti = TI(task=task, execution_date=timezone.utcnow())
@@ -1271,10 +1306,12 @@ def test_email_alert(self, mock_send_email):
self.assertIn('test_email_alert', body)
self.assertIn('Try 1', body)
- @conf_vars({
- ('email', 'subject_template'): '/subject/path',
- ('email', 'html_content_template'): '/html_content/path',
- })
+ @conf_vars(
+ {
+ ('email', 'subject_template'): '/subject/path',
+ ('email', 'html_content_template'): '/html_content/path',
+ }
+ )
@patch('airflow.models.taskinstance.send_email')
def test_email_alert_with_config(self, mock_send_email):
dag = models.DAG(dag_id='test_failure_email')
@@ -1283,7 +1320,8 @@ def test_email_alert_with_config(self, mock_send_email):
dag=dag,
bash_command='exit 1',
start_date=DEFAULT_DATE,
- email='to')
+ email='to',
+ )
ti = TI(task=task, execution_date=timezone.utcnow())
@@ -1318,10 +1356,17 @@ def test_set_duration_empty_dates(self):
def test_success_callback_no_race_condition(self):
callback_wrapper = CallbackWrapper()
- dag = DAG('test_success_callback_no_race_condition', start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=10))
- task = DummyOperator(task_id='op', email='test@test.test',
- on_success_callback=callback_wrapper.success_handler, dag=dag)
+ dag = DAG(
+ 'test_success_callback_no_race_condition',
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+ )
+ task = DummyOperator(
+ task_id='op',
+ email='test@test.test',
+ on_success_callback=callback_wrapper.success_handler,
+ dag=dag,
+ )
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.state = State.RUNNING
session = settings.Session()
@@ -1343,8 +1388,9 @@ def test_success_callback_no_race_condition(self):
self.assertEqual(ti.state, State.SUCCESS)
@staticmethod
- def _test_previous_dates_setup(schedule_interval: Union[str, datetime.timedelta, None],
- catchup: bool, scenario: List[str]) -> list:
+ def _test_previous_dates_setup(
+ schedule_interval: Union[str, datetime.timedelta, None], catchup: bool, scenario: List[str]
+ ) -> list:
dag_id = 'test_previous_dates'
dag = models.DAG(dag_id=dag_id, schedule_interval=schedule_interval, catchup=catchup)
task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
@@ -1355,7 +1401,7 @@ def get_test_ti(session, execution_date: pendulum.DateTime, state: str) -> TI:
state=state,
execution_date=execution_date,
start_date=pendulum.now('UTC'),
- session=session
+ session=session,
)
ti = TI(task=task, execution_date=execution_date)
ti.set_state(state=State.SUCCESS, session=session)
@@ -1392,15 +1438,9 @@ def test_previous_ti(self, _, schedule_interval, catchup) -> None:
self.assertIsNone(ti_list[0].get_previous_ti())
- self.assertEqual(
- ti_list[2].get_previous_ti().execution_date,
- ti_list[1].execution_date
- )
+ self.assertEqual(ti_list[2].get_previous_ti().execution_date, ti_list[1].execution_date)
- self.assertNotEqual(
- ti_list[2].get_previous_ti().execution_date,
- ti_list[0].execution_date
- )
+ self.assertNotEqual(ti_list[2].get_previous_ti().execution_date, ti_list[0].execution_date)
@parameterized.expand(_prev_dates_param_list)
def test_previous_ti_success(self, _, schedule_interval, catchup) -> None:
@@ -1413,13 +1453,11 @@ def test_previous_ti_success(self, _, schedule_interval, catchup) -> None:
self.assertIsNone(ti_list[1].get_previous_ti(state=State.SUCCESS))
self.assertEqual(
- ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date,
- ti_list[1].execution_date
+ ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date, ti_list[1].execution_date
)
self.assertNotEqual(
- ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date,
- ti_list[2].execution_date
+ ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date, ti_list[2].execution_date
)
@parameterized.expand(_prev_dates_param_list)
@@ -1432,12 +1470,10 @@ def test_previous_execution_date_success(self, _, schedule_interval, catchup) ->
self.assertIsNone(ti_list[0].get_previous_execution_date(state=State.SUCCESS))
self.assertIsNone(ti_list[1].get_previous_execution_date(state=State.SUCCESS))
self.assertEqual(
- ti_list[3].get_previous_execution_date(state=State.SUCCESS),
- ti_list[1].execution_date
+ ti_list[3].get_previous_execution_date(state=State.SUCCESS), ti_list[1].execution_date
)
self.assertNotEqual(
- ti_list[3].get_previous_execution_date(state=State.SUCCESS),
- ti_list[2].execution_date
+ ti_list[3].get_previous_execution_date(state=State.SUCCESS), ti_list[2].execution_date
)
@parameterized.expand(_prev_dates_param_list)
@@ -1460,8 +1496,10 @@ def test_previous_start_date_success(self, _, schedule_interval, catchup) -> Non
def test_pendulum_template_dates(self):
dag = models.DAG(
- dag_id='test_pendulum_template_dates', schedule_interval='0 12 * * *',
- start_date=timezone.datetime(2016, 6, 1, 0, 0, 0))
+ dag_id='test_pendulum_template_dates',
+ schedule_interval='0 12 * * *',
+ start_date=timezone.datetime(2016, 6, 1, 0, 0, 0),
+ )
task = DummyOperator(task_id='test_pendulum_template_dates_task', dag=dag)
ti = TI(task=task, execution_date=timezone.utcnow())
@@ -1544,16 +1582,16 @@ def test_execute_callback(self):
def on_execute_callable(context):
nonlocal called
called = True
- self.assertEqual(
- context['dag_run'].dag_id,
- 'test_dagrun_execute_callback'
- )
+ self.assertEqual(context['dag_run'].dag_id, 'test_dagrun_execute_callback')
- dag = DAG('test_execute_callback', start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=10))
- task = DummyOperator(task_id='op', email='test@test.test',
- on_execute_callback=on_execute_callable,
- dag=dag)
+ dag = DAG(
+ 'test_execute_callback',
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+ )
+ task = DummyOperator(
+ task_id='op', email='test@test.test', on_execute_callback=on_execute_callable, dag=dag
+ )
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.state = State.RUNNING
session = settings.Session()
@@ -1580,10 +1618,12 @@ def test_handle_failure(self):
mock_on_failure_1 = mock.MagicMock()
mock_on_retry_1 = mock.MagicMock()
- task1 = DummyOperator(task_id="test_handle_failure_on_failure",
- on_failure_callback=mock_on_failure_1,
- on_retry_callback=mock_on_retry_1,
- dag=dag)
+ task1 = DummyOperator(
+ task_id="test_handle_failure_on_failure",
+ on_failure_callback=mock_on_failure_1,
+ on_retry_callback=mock_on_retry_1,
+ dag=dag,
+ )
ti1 = TI(task=task1, execution_date=start_date)
ti1.state = State.FAILED
ti1.handle_failure("test failure handling")
@@ -1594,11 +1634,13 @@ def test_handle_failure(self):
mock_on_failure_2 = mock.MagicMock()
mock_on_retry_2 = mock.MagicMock()
- task2 = DummyOperator(task_id="test_handle_failure_on_retry",
- on_failure_callback=mock_on_failure_2,
- on_retry_callback=mock_on_retry_2,
- retries=1,
- dag=dag)
+ task2 = DummyOperator(
+ task_id="test_handle_failure_on_retry",
+ on_failure_callback=mock_on_failure_2,
+ on_retry_callback=mock_on_retry_2,
+ retries=1,
+ dag=dag,
+ )
ti2 = TI(task=task2, execution_date=start_date)
ti2.state = State.FAILED
ti2.handle_failure("test retry handling")
@@ -1611,11 +1653,13 @@ def test_handle_failure(self):
# test the scenario where normally we would retry but have been asked to fail
mock_on_failure_3 = mock.MagicMock()
mock_on_retry_3 = mock.MagicMock()
- task3 = DummyOperator(task_id="test_handle_failure_on_force_fail",
- on_failure_callback=mock_on_failure_3,
- on_retry_callback=mock_on_retry_3,
- retries=1,
- dag=dag)
+ task3 = DummyOperator(
+ task_id="test_handle_failure_on_force_fail",
+ on_failure_callback=mock_on_failure_3,
+ on_retry_callback=mock_on_retry_3,
+ retries=1,
+ dag=dag,
+ )
ti3 = TI(task=task3, execution_date=start_date)
ti3.state = State.FAILED
ti3.handle_failure("test force_fail handling", force_fail=True)
@@ -1635,7 +1679,7 @@ def fail():
python_callable=fail,
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
- retries=1
+ retries=1,
)
ti = TI(task=task, execution_date=timezone.utcnow())
try:
@@ -1655,7 +1699,7 @@ def fail():
python_callable=fail,
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
- retries=1
+ retries=1,
)
ti = TI(task=task, execution_date=timezone.utcnow())
try:
@@ -1667,23 +1711,27 @@ def fail():
def _env_var_check_callback(self):
self.assertEqual('test_echo_env_variables', os.environ['AIRFLOW_CTX_DAG_ID'])
self.assertEqual('hive_in_python_op', os.environ['AIRFLOW_CTX_TASK_ID'])
- self.assertEqual(DEFAULT_DATE.isoformat(),
- os.environ['AIRFLOW_CTX_EXECUTION_DATE'])
- self.assertEqual(DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE),
- os.environ['AIRFLOW_CTX_DAG_RUN_ID'])
+ self.assertEqual(DEFAULT_DATE.isoformat(), os.environ['AIRFLOW_CTX_EXECUTION_DATE'])
+ self.assertEqual(
+ DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE), os.environ['AIRFLOW_CTX_DAG_RUN_ID']
+ )
def test_echo_env_variables(self):
- dag = DAG('test_echo_env_variables', start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=10))
- op = PythonOperator(task_id='hive_in_python_op',
- dag=dag,
- python_callable=self._env_var_check_callback)
+ dag = DAG(
+ 'test_echo_env_variables',
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+ )
+ op = PythonOperator(
+ task_id='hive_in_python_op', dag=dag, python_callable=self._env_var_check_callback
+ )
dag.create_dagrun(
run_type=DagRunType.MANUAL,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
state=State.RUNNING,
- external_trigger=False)
+ external_trigger=False,
+ )
ti = TI(task=op, execution_date=DEFAULT_DATE)
ti.state = State.RUNNING
session = settings.Session()
@@ -1695,15 +1743,19 @@ def test_echo_env_variables(self):
@patch.object(Stats, 'incr')
def test_task_stats(self, stats_mock):
- dag = DAG('test_task_start_end_stats', start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + datetime.timedelta(days=10))
+ dag = DAG(
+ 'test_task_start_end_stats',
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+ )
op = DummyOperator(task_id='dummy_op', dag=dag)
dag.create_dagrun(
run_id='manual__' + DEFAULT_DATE.isoformat(),
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
state=State.RUNNING,
- external_trigger=False)
+ external_trigger=False,
+ )
ti = TI(task=op, execution_date=DEFAULT_DATE)
ti.state = State.RUNNING
session = settings.Session()
@@ -1725,10 +1777,18 @@ def test_generate_command_default_param(self):
def test_generate_command_specific_param(self):
dag_id = 'test_generate_command_specific_param'
task_id = 'task'
- assert_command = ['airflow', 'tasks', 'run', dag_id,
- task_id, DEFAULT_DATE.isoformat(), '--mark-success']
- generate_command = TI.generate_command(dag_id=dag_id, task_id=task_id,
- execution_date=DEFAULT_DATE, mark_success=True)
+ assert_command = [
+ 'airflow',
+ 'tasks',
+ 'run',
+ dag_id,
+ task_id,
+ DEFAULT_DATE.isoformat(),
+ '--mark-success',
+ ]
+ generate_command = TI.generate_command(
+ dag_id=dag_id, task_id=task_id, execution_date=DEFAULT_DATE, mark_success=True
+ )
assert assert_command == generate_command
def test_get_rendered_template_fields(self):
@@ -1759,50 +1819,49 @@ def validate_ti_states(self, dag_run, ti_state_mapping, error_message):
task_instance = dag_run.get_task_instance(task_id=task_id)
self.assertEqual(task_instance.state, expected_state, error_message)
- @parameterized.expand([
- (
- {('scheduler', 'schedule_after_task_execution'): 'True'},
- {'A': 'B', 'B': 'C'},
- {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
- {'A': State.SUCCESS, 'B': State.SCHEDULED, 'C': State.NONE},
- {'A': State.SUCCESS, 'B': State.SUCCESS, 'C': State.SCHEDULED},
- "A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. Same for B and C."
- ),
- (
- {('scheduler', 'schedule_after_task_execution'): 'False'},
- {'A': 'B', 'B': 'C'},
- {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
- {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE},
- None,
- "A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED."
- ),
- (
- {('scheduler', 'schedule_after_task_execution'): 'True'},
- {'A': 'B', 'C': 'B', 'D': 'C'},
- {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
- {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
- None,
- "D -> C -> B & A -> B, when A runs but C isn't QUEUED yet, B shouldn't be QUEUED."
- ),
- (
- {('scheduler', 'schedule_after_task_execution'): 'True'},
- {'A': 'C', 'B': 'C'},
- {'A': State.QUEUED, 'B': State.FAILED, 'C': State.NONE},
- {'A': State.SUCCESS, 'B': State.FAILED, 'C': State.UPSTREAM_FAILED},
- None,
- "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED."
- ),
- ])
+ @parameterized.expand(
+ [
+ (
+ {('scheduler', 'schedule_after_task_execution'): 'True'},
+ {'A': 'B', 'B': 'C'},
+ {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
+ {'A': State.SUCCESS, 'B': State.SCHEDULED, 'C': State.NONE},
+ {'A': State.SUCCESS, 'B': State.SUCCESS, 'C': State.SCHEDULED},
+ "A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. Same for B and C.",
+ ),
+ (
+ {('scheduler', 'schedule_after_task_execution'): 'False'},
+ {'A': 'B', 'B': 'C'},
+ {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
+ {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE},
+ None,
+ "A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED.",
+ ),
+ (
+ {('scheduler', 'schedule_after_task_execution'): 'True'},
+ {'A': 'B', 'C': 'B', 'D': 'C'},
+ {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
+ {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
+ None,
+ "D -> C -> B & A -> B, when A runs but C isn't QUEUED yet, B shouldn't be QUEUED.",
+ ),
+ (
+ {('scheduler', 'schedule_after_task_execution'): 'True'},
+ {'A': 'C', 'B': 'C'},
+ {'A': State.QUEUED, 'B': State.FAILED, 'C': State.NONE},
+ {'A': State.SUCCESS, 'B': State.FAILED, 'C': State.UPSTREAM_FAILED},
+ None,
+ "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED.",
+ ),
+ ]
+ )
def test_fast_follow(
self, conf, dependencies, init_state, first_run_state, second_run_state, error_message
):
with conf_vars(conf):
session = settings.Session()
- dag = DAG(
- 'test_dagrun_fast_follow',
- start_date=DEFAULT_DATE
- )
+ dag = DAG('test_dagrun_fast_follow', start_date=DEFAULT_DATE)
dag_model = DagModel(
dag_id=dag.dag_id,
@@ -1857,9 +1916,16 @@ def test_fast_follow(
@pytest.mark.parametrize("pool_override", [None, "test_pool2"])
def test_refresh_from_task(pool_override):
- task = DummyOperator(task_id="dummy", queue="test_queue", pool="test_pool1", pool_slots=3,
- priority_weight=10, run_as_user="test", retries=30,
- executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}})
+ task = DummyOperator(
+ task_id="dummy",
+ queue="test_queue",
+ pool="test_pool1",
+ pool_slots=3,
+ priority_weight=10,
+ run_as_user="test",
+ retries=30,
+ executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}},
+ )
ti = TI(task, execution_date=pendulum.datetime(2020, 1, 1))
ti.refresh_from_task(task, pool_override=pool_override)
@@ -1898,11 +1964,13 @@ def setUp(self) -> None:
def tearDown(self) -> None:
self._clean()
- @parameterized.expand([
- # Expected queries, mark_success
- (10, False),
- (5, True),
- ])
+ @parameterized.expand(
+ [
+ # Expected queries, mark_success
+ (10, False),
+ (5, True),
+ ]
+ )
def test_execute_queries_count(self, expected_query_count, mark_success):
with create_session() as session:
dag = DAG('test_queries', start_date=DEFAULT_DATE)
diff --git a/tests/models/test_timestamp.py b/tests/models/test_timestamp.py
index 979fb1d5b6415..8424a687ec94e 100644
--- a/tests/models/test_timestamp.py
+++ b/tests/models/test_timestamp.py
@@ -38,8 +38,7 @@ def clear_db(session=None):
def add_log(execdate, session, timezone_override=None):
- dag = DAG(dag_id='logging',
- default_args={'start_date': execdate})
+ dag = DAG(dag_id='logging', default_args={'start_date': execdate})
task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
task_instance = TaskInstance(task=task, execution_date=execdate, state='success')
session.merge(task_instance)
diff --git a/tests/models/test_variable.py b/tests/models/test_variable.py
index 4b507e616cd1b..54074dd3a6ea3 100644
--- a/tests/models/test_variable.py
+++ b/tests/models/test_variable.py
@@ -58,10 +58,7 @@ def test_variable_with_encryption(self):
self.assertTrue(test_var.is_encrypted)
self.assertEqual(test_var.val, 'value')
- @parameterized.expand([
- 'value',
- ''
- ])
+ @parameterized.expand(['value', ''])
def test_var_with_encryption_rotate_fernet_key(self, test_value):
"""
Tests rotating encrypted variables.
@@ -106,8 +103,7 @@ def test_variable_set_existing_value_to_blank(self):
def test_get_non_existing_var_should_return_default(self):
default_value = "some default val"
- self.assertEqual(default_value, Variable.get("thisIdDoesNotExist",
- default_var=default_value))
+ self.assertEqual(default_value, Variable.get("thisIdDoesNotExist", default_var=default_value))
def test_get_non_existing_var_should_raise_key_error(self):
with self.assertRaises(KeyError):
@@ -118,9 +114,10 @@ def test_get_non_existing_var_with_none_default_should_return_none(self):
def test_get_non_existing_var_should_not_deserialize_json_default(self):
default_value = "}{ this is a non JSON default }{"
- self.assertEqual(default_value, Variable.get("thisIdDoesNotExist",
- default_var=default_value,
- deserialize_json=True))
+ self.assertEqual(
+ default_value,
+ Variable.get("thisIdDoesNotExist", default_var=default_value, deserialize_json=True),
+ )
def test_variable_setdefault_round_trip(self):
key = "tested_var_setdefault_1_id"
diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py
index 39586d1638468..9793a26faf6f8 100644
--- a/tests/models/test_xcom.py
+++ b/tests/models/test_xcom.py
@@ -32,7 +32,6 @@ def serialize_value(_):
class TestXCom(unittest.TestCase):
-
def setUp(self) -> None:
db.clear_db_xcom()
@@ -45,9 +44,7 @@ def test_resolve_xcom_class(self):
assert issubclass(cls, CustomXCom)
assert cls().serialize_value(None) == "custom_value"
- @conf_vars(
- {("core", "xcom_backend"): "", ("core", "enable_xcom_pickling"): "False"}
- )
+ @conf_vars({("core", "xcom_backend"): "", ("core", "enable_xcom_pickling"): "False"})
def test_resolve_xcom_class_fallback_to_basexcom(self):
cls = resolve_xcom_backend()
assert issubclass(cls, BaseXCom)
@@ -68,24 +65,28 @@ def test_xcom_disable_pickle_type(self):
key = "xcom_test1"
dag_id = "test_dag1"
task_id = "test_task1"
- XCom.set(key=key,
- value=json_obj,
- dag_id=dag_id,
- task_id=task_id,
- execution_date=execution_date)
+ XCom.set(key=key, value=json_obj, dag_id=dag_id, task_id=task_id, execution_date=execution_date)
- ret_value = XCom.get_many(key=key,
- dag_ids=dag_id,
- task_ids=task_id,
- execution_date=execution_date).first().value
+ ret_value = (
+ XCom.get_many(key=key, dag_ids=dag_id, task_ids=task_id, execution_date=execution_date)
+ .first()
+ .value
+ )
self.assertEqual(ret_value, json_obj)
session = settings.Session()
- ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id,
- XCom.task_id == task_id,
- XCom.execution_date == execution_date
- ).first().value
+ ret_value = (
+ session.query(XCom)
+ .filter(
+ XCom.key == key,
+ XCom.dag_id == dag_id,
+ XCom.task_id == task_id,
+ XCom.execution_date == execution_date,
+ )
+ .first()
+ .value
+ )
self.assertEqual(ret_value, json_obj)
@@ -96,24 +97,24 @@ def test_xcom_get_one_disable_pickle_type(self):
key = "xcom_test1"
dag_id = "test_dag1"
task_id = "test_task1"
- XCom.set(key=key,
- value=json_obj,
- dag_id=dag_id,
- task_id=task_id,
- execution_date=execution_date)
+ XCom.set(key=key, value=json_obj, dag_id=dag_id, task_id=task_id, execution_date=execution_date)
- ret_value = XCom.get_one(key=key,
- dag_id=dag_id,
- task_id=task_id,
- execution_date=execution_date)
+ ret_value = XCom.get_one(key=key, dag_id=dag_id, task_id=task_id, execution_date=execution_date)
self.assertEqual(ret_value, json_obj)
session = settings.Session()
- ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id,
- XCom.task_id == task_id,
- XCom.execution_date == execution_date
- ).first().value
+ ret_value = (
+ session.query(XCom)
+ .filter(
+ XCom.key == key,
+ XCom.dag_id == dag_id,
+ XCom.task_id == task_id,
+ XCom.execution_date == execution_date,
+ )
+ .first()
+ .value
+ )
self.assertEqual(ret_value, json_obj)
@@ -124,24 +125,28 @@ def test_xcom_enable_pickle_type(self):
key = "xcom_test2"
dag_id = "test_dag2"
task_id = "test_task2"
- XCom.set(key=key,
- value=json_obj,
- dag_id=dag_id,
- task_id=task_id,
- execution_date=execution_date)
+ XCom.set(key=key, value=json_obj, dag_id=dag_id, task_id=task_id, execution_date=execution_date)
- ret_value = XCom.get_many(key=key,
- dag_ids=dag_id,
- task_ids=task_id,
- execution_date=execution_date).first().value
+ ret_value = (
+ XCom.get_many(key=key, dag_ids=dag_id, task_ids=task_id, execution_date=execution_date)
+ .first()
+ .value
+ )
self.assertEqual(ret_value, json_obj)
session = settings.Session()
- ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id,
- XCom.task_id == task_id,
- XCom.execution_date == execution_date
- ).first().value
+ ret_value = (
+ session.query(XCom)
+ .filter(
+ XCom.key == key,
+ XCom.dag_id == dag_id,
+ XCom.task_id == task_id,
+ XCom.execution_date == execution_date,
+ )
+ .first()
+ .value
+ )
self.assertEqual(ret_value, json_obj)
@@ -152,24 +157,24 @@ def test_xcom_get_one_enable_pickle_type(self):
key = "xcom_test3"
dag_id = "test_dag"
task_id = "test_task3"
- XCom.set(key=key,
- value=json_obj,
- dag_id=dag_id,
- task_id=task_id,
- execution_date=execution_date)
+ XCom.set(key=key, value=json_obj, dag_id=dag_id, task_id=task_id, execution_date=execution_date)
- ret_value = XCom.get_one(key=key,
- dag_id=dag_id,
- task_id=task_id,
- execution_date=execution_date)
+ ret_value = XCom.get_one(key=key, dag_id=dag_id, task_id=task_id, execution_date=execution_date)
self.assertEqual(ret_value, json_obj)
session = settings.Session()
- ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id,
- XCom.task_id == task_id,
- XCom.execution_date == execution_date
- ).first().value
+ ret_value = (
+ session.query(XCom)
+ .filter(
+ XCom.key == key,
+ XCom.dag_id == dag_id,
+ XCom.task_id == task_id,
+ XCom.execution_date == execution_date,
+ )
+ .first()
+ .value
+ )
self.assertEqual(ret_value, json_obj)
@@ -179,12 +184,15 @@ class PickleRce:
def __reduce__(self):
return os.system, ("ls -alt",)
- self.assertRaises(TypeError, XCom.set,
- key="xcom_test3",
- value=PickleRce(),
- dag_id="test_dag3",
- task_id="test_task3",
- execution_date=timezone.utcnow())
+ self.assertRaises(
+ TypeError,
+ XCom.set,
+ key="xcom_test3",
+ value=PickleRce(),
+ dag_id="test_dag3",
+ task_id="test_task3",
+ execution_date=timezone.utcnow(),
+ )
@conf_vars({("core", "xcom_enable_pickling"): "True"})
def test_xcom_get_many(self):
@@ -196,17 +204,9 @@ def test_xcom_get_many(self):
dag_id2 = "test_dag5"
task_id2 = "test_task5"
- XCom.set(key=key,
- value=json_obj,
- dag_id=dag_id1,
- task_id=task_id1,
- execution_date=execution_date)
-
- XCom.set(key=key,
- value=json_obj,
- dag_id=dag_id2,
- task_id=task_id2,
- execution_date=execution_date)
+ XCom.set(key=key, value=json_obj, dag_id=dag_id1, task_id=task_id1, execution_date=execution_date)
+
+ XCom.set(key=key, value=json_obj, dag_id=dag_id2, task_id=task_id2, execution_date=execution_date)
results = XCom.get_many(key=key, execution_date=execution_date)
diff --git a/tests/models/test_xcom_arg.py b/tests/models/test_xcom_arg.py
index 2f38ed72bdafc..6c08de069e0da 100644
--- a/tests/models/test_xcom_arg.py
+++ b/tests/models/test_xcom_arg.py
@@ -62,17 +62,21 @@ def test_xcom_ctor(self):
assert actual.key == "test_key"
# Asserting the overridden __eq__ method
assert actual == XComArg(python_op, "test_key")
- assert str(actual) == "task_instance.xcom_pull(" \
- "task_ids=\'test_xcom_op\', " \
- "dag_id=\'test_xcom_dag\', " \
- "key=\'test_key\')"
+ assert (
+ str(actual) == "task_instance.xcom_pull("
+ "task_ids=\'test_xcom_op\', "
+ "dag_id=\'test_xcom_dag\', "
+ "key=\'test_key\')"
+ )
def test_xcom_key_is_empty_str(self):
python_op = build_python_op()
actual = XComArg(python_op, key="")
assert actual.key == ""
- assert str(actual) == "task_instance.xcom_pull(task_ids='test_xcom_op', " \
- "dag_id='test_xcom_dag', key='')"
+ assert (
+ str(actual) == "task_instance.xcom_pull(task_ids='test_xcom_op', "
+ "dag_id='test_xcom_dag', key='')"
+ )
def test_set_downstream(self):
with DAG("test_set_downstream", default_args=DEFAULT_ARGS):
diff --git a/tests/operators/test_bash.py b/tests/operators/test_bash.py
index 22dd60051affa..9c0cca18d817e 100644
--- a/tests/operators/test_bash.py
+++ b/tests/operators/test_bash.py
@@ -36,7 +36,6 @@
class TestBashOperator(unittest.TestCase):
-
def test_echo_env_variables(self):
"""
Test that env variables are exported correctly to the
@@ -46,13 +45,11 @@ def test_echo_env_variables(self):
now = now.replace(tzinfo=timezone.utc)
dag = DAG(
- dag_id='bash_op_test', default_args={
- 'owner': 'airflow',
- 'retries': 100,
- 'start_date': DEFAULT_DATE
- },
+ dag_id='bash_op_test',
+ default_args={'owner': 'airflow', 'retries': 100, 'start_date': DEFAULT_DATE},
schedule_interval='@daily',
- dagrun_timeout=timedelta(minutes=60))
+ dagrun_timeout=timedelta(minutes=60),
+ )
dag.create_dagrun(
run_type=DagRunType.MANUAL,
@@ -67,19 +64,17 @@ def test_echo_env_variables(self):
task_id='echo_env_vars',
dag=dag,
bash_command='echo $AIRFLOW_HOME>> {0};'
- 'echo $PYTHONPATH>> {0};'
- 'echo $AIRFLOW_CTX_DAG_ID >> {0};'
- 'echo $AIRFLOW_CTX_TASK_ID>> {0};'
- 'echo $AIRFLOW_CTX_EXECUTION_DATE>> {0};'
- 'echo $AIRFLOW_CTX_DAG_RUN_ID>> {0};'.format(tmp_file.name)
+ 'echo $PYTHONPATH>> {0};'
+ 'echo $AIRFLOW_CTX_DAG_ID >> {0};'
+ 'echo $AIRFLOW_CTX_TASK_ID>> {0};'
+ 'echo $AIRFLOW_CTX_EXECUTION_DATE>> {0};'
+ 'echo $AIRFLOW_CTX_DAG_RUN_ID>> {0};'.format(tmp_file.name),
)
- with mock.patch.dict('os.environ', {
- 'AIRFLOW_HOME': 'MY_PATH_TO_AIRFLOW_HOME',
- 'PYTHONPATH': 'AWESOME_PYTHONPATH'
- }):
- task.run(DEFAULT_DATE, DEFAULT_DATE,
- ignore_first_depends_on_past=True, ignore_ti_state=True)
+ with mock.patch.dict(
+ 'os.environ', {'AIRFLOW_HOME': 'MY_PATH_TO_AIRFLOW_HOME', 'PYTHONPATH': 'AWESOME_PYTHONPATH'}
+ ):
+ task.run(DEFAULT_DATE, DEFAULT_DATE, ignore_first_depends_on_past=True, ignore_ti_state=True)
with open(tmp_file.name) as file:
output = ''.join(file.readlines())
@@ -92,60 +87,44 @@ def test_echo_env_variables(self):
self.assertIn(DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE), output)
def test_return_value(self):
- bash_operator = BashOperator(
- bash_command='echo "stdout"',
- task_id='test_return_value',
- dag=None
- )
+ bash_operator = BashOperator(bash_command='echo "stdout"', task_id='test_return_value', dag=None)
return_value = bash_operator.execute(context={})
self.assertEqual(return_value, 'stdout')
def test_raise_exception_on_non_zero_exit_code(self):
- bash_operator = BashOperator(
- bash_command='exit 42',
- task_id='test_return_value',
- dag=None
- )
+ bash_operator = BashOperator(bash_command='exit 42', task_id='test_return_value', dag=None)
with self.assertRaisesRegex(
- AirflowException,
- "Bash command failed\\. The command returned a non-zero exit code\\."
+ AirflowException, "Bash command failed\\. The command returned a non-zero exit code\\."
):
bash_operator.execute(context={})
def test_task_retries(self):
bash_operator = BashOperator(
- bash_command='echo "stdout"',
- task_id='test_task_retries',
- retries=2,
- dag=None
+ bash_command='echo "stdout"', task_id='test_task_retries', retries=2, dag=None
)
self.assertEqual(bash_operator.retries, 2)
def test_default_retries(self):
- bash_operator = BashOperator(
- bash_command='echo "stdout"',
- task_id='test_default_retries',
- dag=None
- )
+ bash_operator = BashOperator(bash_command='echo "stdout"', task_id='test_default_retries', dag=None)
self.assertEqual(bash_operator.retries, 0)
@mock.patch.dict('os.environ', clear=True)
- @mock.patch("airflow.operators.bash.TemporaryDirectory", **{ # type: ignore
- 'return_value.__enter__.return_value': '/tmp/airflowtmpcatcat'
- })
- @mock.patch("airflow.operators.bash.Popen", **{ # type: ignore
- 'return_value.stdout.readline.side_effect': [b'BAR', b'BAZ'],
- 'return_value.returncode': 0
- })
+ @mock.patch(
+ "airflow.operators.bash.TemporaryDirectory",
+ **{'return_value.__enter__.return_value': '/tmp/airflowtmpcatcat'}, # type: ignore
+ )
+ @mock.patch(
+ "airflow.operators.bash.Popen",
+ **{ # type: ignore
+ 'return_value.stdout.readline.side_effect': [b'BAR', b'BAZ'],
+ 'return_value.returncode': 0,
+ },
+ )
def test_should_exec_subprocess(self, mock_popen, mock_temporary_directory):
- bash_operator = BashOperator(
- bash_command='echo "stdout"',
- task_id='test_return_value',
- dag=None
- )
+ bash_operator = BashOperator(bash_command='echo "stdout"', task_id='test_return_value', dag=None)
bash_operator.execute({})
mock_popen.assert_called_once_with(
@@ -154,5 +133,5 @@ def test_should_exec_subprocess(self, mock_popen, mock_temporary_directory):
env={},
preexec_fn=mock.ANY,
stderr=STDOUT,
- stdout=PIPE
+ stdout=PIPE,
)
diff --git a/tests/operators/test_branch_operator.py b/tests/operators/test_branch_operator.py
index 9553a725581dc..f97f0dab8b0cf 100644
--- a/tests/operators/test_branch_operator.py
+++ b/tests/operators/test_branch_operator.py
@@ -51,11 +51,11 @@ def setUpClass(cls):
session.query(TI).delete()
def setUp(self):
- self.dag = DAG('branch_operator_test',
- default_args={
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE},
- schedule_interval=INTERVAL)
+ self.dag = DAG(
+ 'branch_operator_test',
+ default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
+ schedule_interval=INTERVAL,
+ )
self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
@@ -79,10 +79,7 @@ def test_without_dag_run(self):
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
with create_session() as session:
- tis = session.query(TI).filter(
- TI.dag_id == self.dag.dag_id,
- TI.execution_date == DEFAULT_DATE
- )
+ tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE)
for ti in tis:
if ti.task_id == 'make_choice':
@@ -107,10 +104,7 @@ def test_branch_list_without_dag_run(self):
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
with create_session() as session:
- tis = session.query(TI).filter(
- TI.dag_id == self.dag.dag_id,
- TI.execution_date == DEFAULT_DATE
- )
+ tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE)
expected = {
"make_choice": State.SUCCESS,
@@ -135,7 +129,7 @@ def test_with_dag_run(self):
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -161,7 +155,7 @@ def test_with_skip_in_branch_downstream_dependencies(self):
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
diff --git a/tests/operators/test_dagrun_operator.py b/tests/operators/test_dagrun_operator.py
index 804305b6204d0..0ef41587932df 100644
--- a/tests/operators/test_dagrun_operator.py
+++ b/tests/operators/test_dagrun_operator.py
@@ -149,11 +149,13 @@ def test_trigger_dagrun_operator_templated_conf(self):
def test_trigger_dagrun_with_reset_dag_run_false(self):
"""Test TriggerDagRunOperator with reset_dag_run."""
execution_date = DEFAULT_DATE
- task = TriggerDagRunOperator(task_id="test_task",
- trigger_dag_id=TRIGGERED_DAG_ID,
- execution_date=execution_date,
- reset_dag_run=False,
- dag=self.dag)
+ task = TriggerDagRunOperator(
+ task_id="test_task",
+ trigger_dag_id=TRIGGERED_DAG_ID,
+ execution_date=execution_date,
+ reset_dag_run=False,
+ dag=self.dag,
+ )
task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True)
with self.assertRaises(DagRunAlreadyExists):
@@ -162,11 +164,13 @@ def test_trigger_dagrun_with_reset_dag_run_false(self):
def test_trigger_dagrun_with_reset_dag_run_true(self):
"""Test TriggerDagRunOperator with reset_dag_run."""
execution_date = DEFAULT_DATE
- task = TriggerDagRunOperator(task_id="test_task",
- trigger_dag_id=TRIGGERED_DAG_ID,
- execution_date=execution_date,
- reset_dag_run=True,
- dag=self.dag)
+ task = TriggerDagRunOperator(
+ task_id="test_task",
+ trigger_dag_id=TRIGGERED_DAG_ID,
+ execution_date=execution_date,
+ reset_dag_run=True,
+ dag=self.dag,
+ )
task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True)
task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True)
diff --git a/tests/operators/test_email.py b/tests/operators/test_email.py
index 867c1116a9370..c02b04ff494ea 100644
--- a/tests/operators/test_email.py
+++ b/tests/operators/test_email.py
@@ -34,15 +34,13 @@
class TestEmailOperator(unittest.TestCase):
-
def setUp(self):
super().setUp()
self.dag = DAG(
'test_dag',
- default_args={
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE},
- schedule_interval=INTERVAL)
+ default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
+ schedule_interval=INTERVAL,
+ )
self.addCleanup(self.dag.clear)
def _run_as_operator(self, **kwargs):
@@ -52,12 +50,11 @@ def _run_as_operator(self, **kwargs):
html_content='The quick brown fox jumps over the lazy dog',
task_id='task',
dag=self.dag,
- **kwargs)
+ **kwargs,
+ )
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
def test_execute(self):
- with conf_vars(
- {('email', 'email_backend'): 'tests.operators.test_email.send_email_test'}
- ):
+ with conf_vars({('email', 'email_backend'): 'tests.operators.test_email.send_email_test'}):
self._run_as_operator()
assert send_email_test.call_count == 1
diff --git a/tests/operators/test_generic_transfer.py b/tests/operators/test_generic_transfer.py
index 6ad670615b55b..68b5f8f251a2a 100644
--- a/tests/operators/test_generic_transfer.py
+++ b/tests/operators/test_generic_transfer.py
@@ -37,10 +37,7 @@
@pytest.mark.backend("mysql")
class TestMySql(unittest.TestCase):
def setUp(self):
- args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- }
+ args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
dag = DAG(TEST_DAG_ID, default_args=args)
self.dag = dag
@@ -50,7 +47,12 @@ def tearDown(self):
for table in drop_tables:
conn.execute(f"DROP TABLE IF EXISTS {table}")
- @parameterized.expand([("mysqlclient",), ("mysql-connector-python",), ])
+ @parameterized.expand(
+ [
+ ("mysqlclient",),
+ ("mysql-connector-python",),
+ ]
+ )
def test_mysql_to_mysql(self, client):
with MySqlContext(client):
sql = "SELECT * FROM INFORMATION_SCHEMA.TABLES LIMIT 100;"
@@ -58,14 +60,14 @@ def test_mysql_to_mysql(self, client):
task_id='test_m2m',
preoperator=[
"DROP TABLE IF EXISTS test_mysql_to_mysql",
- "CREATE TABLE IF NOT EXISTS "
- "test_mysql_to_mysql LIKE INFORMATION_SCHEMA.TABLES"
+ "CREATE TABLE IF NOT EXISTS test_mysql_to_mysql LIKE INFORMATION_SCHEMA.TABLES",
],
source_conn_id='airflow_db',
destination_conn_id='airflow_db',
destination_table="test_mysql_to_mysql",
sql=sql,
- dag=self.dag)
+ dag=self.dag,
+ )
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@@ -89,12 +91,12 @@ def test_postgres_to_postgres(self):
task_id='test_p2p',
preoperator=[
"DROP TABLE IF EXISTS test_postgres_to_postgres",
- "CREATE TABLE IF NOT EXISTS "
- "test_postgres_to_postgres (LIKE INFORMATION_SCHEMA.TABLES)"
+ "CREATE TABLE IF NOT EXISTS test_postgres_to_postgres (LIKE INFORMATION_SCHEMA.TABLES)",
],
source_conn_id='postgres_default',
destination_conn_id='postgres_default',
destination_table="test_postgres_to_postgres",
sql=sql,
- dag=self.dag)
+ dag=self.dag,
+ )
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
diff --git a/tests/operators/test_latest_only_operator.py b/tests/operators/test_latest_only_operator.py
index b3987564b3d8d..a514d642c6808 100644
--- a/tests/operators/test_latest_only_operator.py
+++ b/tests/operators/test_latest_only_operator.py
@@ -40,23 +40,22 @@
def get_task_instances(task_id):
session = settings.Session()
- return session \
- .query(TaskInstance) \
- .filter(TaskInstance.task_id == task_id) \
- .order_by(TaskInstance.execution_date) \
+ return (
+ session.query(TaskInstance)
+ .filter(TaskInstance.task_id == task_id)
+ .order_by(TaskInstance.execution_date)
.all()
+ )
class TestLatestOnlyOperator(unittest.TestCase):
-
def setUp(self):
super().setUp()
self.dag = DAG(
'test_dag',
- default_args={
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE},
- schedule_interval=INTERVAL)
+ default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
+ schedule_interval=INTERVAL,
+ )
with create_session() as session:
session.query(DagRun).delete()
session.query(TaskInstance).delete()
@@ -65,25 +64,16 @@ def setUp(self):
self.addCleanup(freezer.stop)
def test_run(self):
- task = LatestOnlyOperator(
- task_id='latest',
- dag=self.dag)
+ task = LatestOnlyOperator(task_id='latest', dag=self.dag)
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
def test_skipping_non_latest(self):
- latest_task = LatestOnlyOperator(
- task_id='latest',
- dag=self.dag)
- downstream_task = DummyOperator(
- task_id='downstream',
- dag=self.dag)
- downstream_task2 = DummyOperator(
- task_id='downstream_2',
- dag=self.dag)
+ latest_task = LatestOnlyOperator(task_id='latest', dag=self.dag)
+ downstream_task = DummyOperator(task_id='downstream', dag=self.dag)
+ downstream_task2 = DummyOperator(task_id='downstream_2', dag=self.dag)
downstream_task3 = DummyOperator(
- task_id='downstream_3',
- trigger_rule=TriggerRule.NONE_FAILED,
- dag=self.dag)
+ task_id='downstream_3', trigger_rule=TriggerRule.NONE_FAILED, dag=self.dag
+ )
downstream_task.set_upstream(latest_task)
downstream_task2.set_upstream(downstream_task)
@@ -116,51 +106,53 @@ def test_skipping_non_latest(self):
downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE)
latest_instances = get_task_instances('latest')
- exec_date_to_latest_state = {
- ti.execution_date: ti.state for ti in latest_instances}
- self.assertEqual({
- timezone.datetime(2016, 1, 1): 'success',
- timezone.datetime(2016, 1, 1, 12): 'success',
- timezone.datetime(2016, 1, 2): 'success'},
- exec_date_to_latest_state)
+ exec_date_to_latest_state = {ti.execution_date: ti.state for ti in latest_instances}
+ self.assertEqual(
+ {
+ timezone.datetime(2016, 1, 1): 'success',
+ timezone.datetime(2016, 1, 1, 12): 'success',
+ timezone.datetime(2016, 1, 2): 'success',
+ },
+ exec_date_to_latest_state,
+ )
downstream_instances = get_task_instances('downstream')
- exec_date_to_downstream_state = {
- ti.execution_date: ti.state for ti in downstream_instances}
- self.assertEqual({
- timezone.datetime(2016, 1, 1): 'skipped',
- timezone.datetime(2016, 1, 1, 12): 'skipped',
- timezone.datetime(2016, 1, 2): 'success'},
- exec_date_to_downstream_state)
+ exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
+ self.assertEqual(
+ {
+ timezone.datetime(2016, 1, 1): 'skipped',
+ timezone.datetime(2016, 1, 1, 12): 'skipped',
+ timezone.datetime(2016, 1, 2): 'success',
+ },
+ exec_date_to_downstream_state,
+ )
downstream_instances = get_task_instances('downstream_2')
- exec_date_to_downstream_state = {
- ti.execution_date: ti.state for ti in downstream_instances}
- self.assertEqual({
- timezone.datetime(2016, 1, 1): None,
- timezone.datetime(2016, 1, 1, 12): None,
- timezone.datetime(2016, 1, 2): 'success'},
- exec_date_to_downstream_state)
+ exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
+ self.assertEqual(
+ {
+ timezone.datetime(2016, 1, 1): None,
+ timezone.datetime(2016, 1, 1, 12): None,
+ timezone.datetime(2016, 1, 2): 'success',
+ },
+ exec_date_to_downstream_state,
+ )
downstream_instances = get_task_instances('downstream_3')
- exec_date_to_downstream_state = {
- ti.execution_date: ti.state for ti in downstream_instances}
- self.assertEqual({
- timezone.datetime(2016, 1, 1): 'success',
- timezone.datetime(2016, 1, 1, 12): 'success',
- timezone.datetime(2016, 1, 2): 'success'},
- exec_date_to_downstream_state)
+ exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
+ self.assertEqual(
+ {
+ timezone.datetime(2016, 1, 1): 'success',
+ timezone.datetime(2016, 1, 1, 12): 'success',
+ timezone.datetime(2016, 1, 2): 'success',
+ },
+ exec_date_to_downstream_state,
+ )
def test_not_skipping_external(self):
- latest_task = LatestOnlyOperator(
- task_id='latest',
- dag=self.dag)
- downstream_task = DummyOperator(
- task_id='downstream',
- dag=self.dag)
- downstream_task2 = DummyOperator(
- task_id='downstream_2',
- dag=self.dag)
+ latest_task = LatestOnlyOperator(task_id='latest', dag=self.dag)
+ downstream_task = DummyOperator(task_id='downstream', dag=self.dag)
+ downstream_task2 = DummyOperator(task_id='downstream_2', dag=self.dag)
downstream_task.set_upstream(latest_task)
downstream_task2.set_upstream(downstream_task)
@@ -194,28 +186,34 @@ def test_not_skipping_external(self):
downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE)
latest_instances = get_task_instances('latest')
- exec_date_to_latest_state = {
- ti.execution_date: ti.state for ti in latest_instances}
- self.assertEqual({
- timezone.datetime(2016, 1, 1): 'success',
- timezone.datetime(2016, 1, 1, 12): 'success',
- timezone.datetime(2016, 1, 2): 'success'},
- exec_date_to_latest_state)
+ exec_date_to_latest_state = {ti.execution_date: ti.state for ti in latest_instances}
+ self.assertEqual(
+ {
+ timezone.datetime(2016, 1, 1): 'success',
+ timezone.datetime(2016, 1, 1, 12): 'success',
+ timezone.datetime(2016, 1, 2): 'success',
+ },
+ exec_date_to_latest_state,
+ )
downstream_instances = get_task_instances('downstream')
- exec_date_to_downstream_state = {
- ti.execution_date: ti.state for ti in downstream_instances}
- self.assertEqual({
- timezone.datetime(2016, 1, 1): 'success',
- timezone.datetime(2016, 1, 1, 12): 'success',
- timezone.datetime(2016, 1, 2): 'success'},
- exec_date_to_downstream_state)
+ exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
+ self.assertEqual(
+ {
+ timezone.datetime(2016, 1, 1): 'success',
+ timezone.datetime(2016, 1, 1, 12): 'success',
+ timezone.datetime(2016, 1, 2): 'success',
+ },
+ exec_date_to_downstream_state,
+ )
downstream_instances = get_task_instances('downstream_2')
- exec_date_to_downstream_state = {
- ti.execution_date: ti.state for ti in downstream_instances}
- self.assertEqual({
- timezone.datetime(2016, 1, 1): 'success',
- timezone.datetime(2016, 1, 1, 12): 'success',
- timezone.datetime(2016, 1, 2): 'success'},
- exec_date_to_downstream_state)
+ exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
+ self.assertEqual(
+ {
+ timezone.datetime(2016, 1, 1): 'success',
+ timezone.datetime(2016, 1, 1, 12): 'success',
+ timezone.datetime(2016, 1, 2): 'success',
+ },
+ exec_date_to_downstream_state,
+ )
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index f804f06e855ad..f1bb085d80334 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -34,7 +34,11 @@
from airflow.models.xcom_arg import XComArg
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python import (
- BranchPythonOperator, PythonOperator, PythonVirtualenvOperator, ShortCircuitOperator, get_current_context,
+ BranchPythonOperator,
+ PythonOperator,
+ PythonVirtualenvOperator,
+ ShortCircuitOperator,
+ get_current_context,
task as task_decorator,
)
from airflow.utils import timezone
@@ -49,10 +53,12 @@
INTERVAL = timedelta(hours=12)
FROZEN_NOW = timezone.datetime(2016, 1, 2, 12, 1, 1)
-TI_CONTEXT_ENV_VARS = ['AIRFLOW_CTX_DAG_ID',
- 'AIRFLOW_CTX_TASK_ID',
- 'AIRFLOW_CTX_EXECUTION_DATE',
- 'AIRFLOW_CTX_DAG_RUN_ID']
+TI_CONTEXT_ENV_VARS = [
+ 'AIRFLOW_CTX_DAG_ID',
+ 'AIRFLOW_CTX_TASK_ID',
+ 'AIRFLOW_CTX_EXECUTION_DATE',
+ 'AIRFLOW_CTX_DAG_RUN_ID',
+]
class Call:
@@ -88,11 +94,7 @@ def setUpClass(cls):
def setUp(self):
super().setUp()
- self.dag = DAG(
- 'test_dag',
- default_args={
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE})
+ self.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
self.addCleanup(self.dag.clear)
self.clear_run()
self.addCleanup(self.clear_run)
@@ -113,21 +115,12 @@ def _assert_calls_equal(self, first, second):
self.assertTupleEqual(first.args, second.args)
# eliminate context (conf, dag_run, task_instance, etc.)
test_args = ["an_int", "a_date", "a_templated_string"]
- first.kwargs = {
- key: value
- for (key, value) in first.kwargs.items()
- if key in test_args
- }
- second.kwargs = {
- key: value
- for (key, value) in second.kwargs.items()
- if key in test_args
- }
+ first.kwargs = {key: value for (key, value) in first.kwargs.items() if key in test_args}
+ second.kwargs = {key: value for (key, value) in second.kwargs.items() if key in test_args}
self.assertDictEqual(first.kwargs, second.kwargs)
class TestPythonOperator(TestPythonBase):
-
def do_run(self):
self.run = True
@@ -136,10 +129,7 @@ def is_run(self):
def test_python_operator_run(self):
"""Tests that the python callable is invoked on task run."""
- task = PythonOperator(
- python_callable=self.do_run,
- task_id='python_operator',
- dag=self.dag)
+ task = PythonOperator(python_callable=self.do_run, task_id='python_operator', dag=self.dag)
self.assertFalse(self.is_run())
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
self.assertTrue(self.is_run())
@@ -149,16 +139,10 @@ def test_python_operator_python_callable_is_callable(self):
the python_callable argument is callable."""
not_callable = {}
with self.assertRaises(AirflowException):
- PythonOperator(
- python_callable=not_callable,
- task_id='python_operator',
- dag=self.dag)
+ PythonOperator(python_callable=not_callable, task_id='python_operator', dag=self.dag)
not_callable = None
with self.assertRaises(AirflowException):
- PythonOperator(
- python_callable=not_callable,
- task_id='python_operator',
- dag=self.dag)
+ PythonOperator(python_callable=not_callable, task_id='python_operator', dag=self.dag)
def test_python_callable_arguments_are_templatized(self):
"""Test PythonOperator op_args are templatized"""
@@ -174,19 +158,15 @@ def test_python_callable_arguments_are_templatized(self):
# a Mock instance cannot be used as a callable function or test fails with a
# TypeError: Object of type Mock is not JSON serializable
python_callable=build_recording_function(recorded_calls),
- op_args=[
- 4,
- date(2019, 1, 1),
- "dag {{dag.dag_id}} ran on {{ds}}.",
- named_tuple
- ],
- dag=self.dag)
+ op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple],
+ dag=self.dag,
+ )
self.dag.create_dagrun(
run_type=DagRunType.MANUAL,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -194,10 +174,12 @@ def test_python_callable_arguments_are_templatized(self):
self.assertEqual(1, len(recorded_calls))
self._assert_calls_equal(
recorded_calls[0],
- Call(4,
- date(2019, 1, 1),
- f"dag {self.dag.dag_id} ran on {ds_templated}.",
- Named(ds_templated, 'unchanged'))
+ Call(
+ 4,
+ date(2019, 1, 1),
+ f"dag {self.dag.dag_id} ran on {ds_templated}.",
+ Named(ds_templated, 'unchanged'),
+ ),
)
def test_python_callable_keyword_arguments_are_templatized(self):
@@ -212,25 +194,29 @@ def test_python_callable_keyword_arguments_are_templatized(self):
op_kwargs={
'an_int': 4,
'a_date': date(2019, 1, 1),
- 'a_templated_string': "dag {{dag.dag_id}} ran on {{ds}}."
+ 'a_templated_string': "dag {{dag.dag_id}} ran on {{ds}}.",
},
- dag=self.dag)
+ dag=self.dag,
+ )
self.dag.create_dagrun(
run_type=DagRunType.MANUAL,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
self.assertEqual(1, len(recorded_calls))
self._assert_calls_equal(
recorded_calls[0],
- Call(an_int=4,
- a_date=date(2019, 1, 1),
- a_templated_string="dag {} ran on {}.".format(
- self.dag.dag_id, DEFAULT_DATE.date().isoformat()))
+ Call(
+ an_int=4,
+ a_date=date(2019, 1, 1),
+ a_templated_string="dag {} ran on {}.".format(
+ self.dag.dag_id, DEFAULT_DATE.date().isoformat()
+ ),
+ ),
)
def test_python_operator_shallow_copy_attr(self):
@@ -239,15 +225,15 @@ def test_python_operator_shallow_copy_attr(self):
python_callable=not_callable,
task_id='python_operator',
op_kwargs={'certain_attrs': ''},
- dag=self.dag
+ dag=self.dag,
)
new_task = copy.deepcopy(original_task)
# shallow copy op_kwargs
- self.assertEqual(id(original_task.op_kwargs['certain_attrs']),
- id(new_task.op_kwargs['certain_attrs']))
+ self.assertEqual(
+ id(original_task.op_kwargs['certain_attrs']), id(new_task.op_kwargs['certain_attrs'])
+ )
# shallow copy python_callable
- self.assertEqual(id(original_task.python_callable),
- id(new_task.python_callable))
+ self.assertEqual(id(original_task.python_callable), id(new_task.python_callable))
def test_conflicting_kwargs(self):
self.dag.create_dagrun(
@@ -265,10 +251,7 @@ def func(dag):
raise RuntimeError(f"Should not be triggered, dag: {dag}")
python_operator = PythonOperator(
- task_id='python_operator',
- op_args=[1],
- python_callable=func,
- dag=self.dag
+ task_id='python_operator', op_args=[1], python_callable=func, dag=self.dag
)
with self.assertRaises(ValueError) as context:
@@ -296,7 +279,7 @@ def func(custom, dag):
op_kwargs={'custom': 1},
python_callable=func,
provide_context=True,
- dag=self.dag
+ dag=self.dag,
)
python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -314,10 +297,7 @@ def func(custom, dag):
self.assertIsNotNone(dag, "dag should be set")
python_operator = PythonOperator(
- task_id='python_operator',
- op_kwargs={'custom': 1},
- python_callable=func,
- dag=self.dag
+ task_id='python_operator', op_kwargs={'custom': 1}, python_callable=func, dag=self.dag
)
python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -335,16 +315,12 @@ def func(**context):
self.assertGreater(len(context), 0, "Context has not been injected")
python_operator = PythonOperator(
- task_id='python_operator',
- op_kwargs={'custom': 1},
- python_callable=func,
- dag=self.dag
+ task_id='python_operator', op_kwargs={'custom': 1}, python_callable=func, dag=self.dag
)
python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
class TestAirflowTaskDecorator(TestPythonBase):
-
def test_python_operator_python_callable_is_callable(self):
"""Tests that @task will only instantiate if
the python_callable argument is callable."""
@@ -369,6 +345,7 @@ def test_fail_method(self):
"""Tests that @task will fail if signature is not binding."""
with pytest.raises(AirflowException):
+
class Test:
num = 2
@@ -389,7 +366,7 @@ def add_number(num: int):
run_id=DagRunType.MANUAL,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
with pytest.raises(AirflowException):
@@ -407,7 +384,7 @@ def add_number(num: int):
run_id=DagRunType.MANUAL,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
with pytest.raises(AirflowException):
@@ -427,14 +404,15 @@ def test_python_callable_arguments_are_templatized(self):
# a Mock instance cannot be used as a callable function or test fails with a
# TypeError: Object of type Mock is not JSON serializable
build_recording_function(recorded_calls),
- dag=self.dag)
+ dag=self.dag,
+ )
ret = task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple)
self.dag.create_dagrun(
run_id=DagRunType.MANUAL,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) # pylint: disable=maybe-no-member
@@ -442,10 +420,12 @@ def test_python_callable_arguments_are_templatized(self):
assert len(recorded_calls) == 1
self._assert_calls_equal(
recorded_calls[0],
- Call(4,
- date(2019, 1, 1),
- f"dag {self.dag.dag_id} ran on {ds_templated}.",
- Named(ds_templated, 'unchanged'))
+ Call(
+ 4,
+ date(2019, 1, 1),
+ f"dag {self.dag.dag_id} ran on {ds_templated}.",
+ Named(ds_templated, 'unchanged'),
+ ),
)
def test_python_callable_keyword_arguments_are_templatized(self):
@@ -456,24 +436,27 @@ def test_python_callable_keyword_arguments_are_templatized(self):
# a Mock instance cannot be used as a callable function or test fails with a
# TypeError: Object of type Mock is not JSON serializable
build_recording_function(recorded_calls),
- dag=self.dag
+ dag=self.dag,
)
ret = task(an_int=4, a_date=date(2019, 1, 1), a_templated_string="dag {{dag.dag_id}} ran on {{ds}}.")
self.dag.create_dagrun(
run_id=DagRunType.MANUAL,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) # pylint: disable=maybe-no-member
assert len(recorded_calls) == 1
self._assert_calls_equal(
recorded_calls[0],
- Call(an_int=4,
- a_date=date(2019, 1, 1),
- a_templated_string="dag {} ran on {}.".format(
- self.dag.dag_id, DEFAULT_DATE.date().isoformat()))
+ Call(
+ an_int=4,
+ a_date=date(2019, 1, 1),
+ a_templated_string="dag {} ran on {}.".format(
+ self.dag.dag_id, DEFAULT_DATE.date().isoformat()
+ ),
+ ),
)
def test_manual_task_id(self):
@@ -523,10 +506,7 @@ def test_multiple_outputs(self):
@task_decorator(multiple_outputs=True)
def return_dict(number: int):
- return {
- 'number': number + 1,
- '43': 43
- }
+ return {'number': number + 1, '43': 43}
test_number = 10
with self.dag:
@@ -536,7 +516,7 @@ def return_dict(number: int):
run_id=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) # pylint: disable=maybe-no-member
@@ -578,7 +558,7 @@ def add_num(number: int, num2: int = 2):
run_id=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
bigger_number.operator.run( # pylint: disable=maybe-no-member
@@ -655,11 +635,11 @@ def setUpClass(cls):
session.query(TI).delete()
def setUp(self):
- self.dag = DAG('branch_operator_test',
- default_args={
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE},
- schedule_interval=INTERVAL)
+ self.dag = DAG(
+ 'branch_operator_test',
+ default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
+ schedule_interval=INTERVAL,
+ )
self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
@@ -674,9 +654,9 @@ def tearDown(self):
def test_without_dag_run(self):
"""This checks the defensive against non existent tasks in a dag run"""
- branch_op = BranchPythonOperator(task_id='make_choice',
- dag=self.dag,
- python_callable=lambda: 'branch_1')
+ branch_op = BranchPythonOperator(
+ task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
+ )
self.branch_1.set_upstream(branch_op)
self.branch_2.set_upstream(branch_op)
self.dag.clear()
@@ -684,10 +664,7 @@ def test_without_dag_run(self):
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
with create_session() as session:
- tis = session.query(TI).filter(
- TI.dag_id == self.dag.dag_id,
- TI.execution_date == DEFAULT_DATE
- )
+ tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE)
for ti in tis:
if ti.task_id == 'make_choice':
@@ -702,9 +679,9 @@ def test_without_dag_run(self):
def test_branch_list_without_dag_run(self):
"""This checks if the BranchPythonOperator supports branching off to a list of tasks."""
- branch_op = BranchPythonOperator(task_id='make_choice',
- dag=self.dag,
- python_callable=lambda: ['branch_1', 'branch_2'])
+ branch_op = BranchPythonOperator(
+ task_id='make_choice', dag=self.dag, python_callable=lambda: ['branch_1', 'branch_2']
+ )
self.branch_1.set_upstream(branch_op)
self.branch_2.set_upstream(branch_op)
self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
@@ -714,10 +691,7 @@ def test_branch_list_without_dag_run(self):
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
with create_session() as session:
- tis = session.query(TI).filter(
- TI.dag_id == self.dag.dag_id,
- TI.execution_date == DEFAULT_DATE
- )
+ tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE)
expected = {
"make_choice": State.SUCCESS,
@@ -733,9 +707,9 @@ def test_branch_list_without_dag_run(self):
raise ValueError(f'Invalid task id {ti.task_id} found!')
def test_with_dag_run(self):
- branch_op = BranchPythonOperator(task_id='make_choice',
- dag=self.dag,
- python_callable=lambda: 'branch_1')
+ branch_op = BranchPythonOperator(
+ task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
+ )
self.branch_1.set_upstream(branch_op)
self.branch_2.set_upstream(branch_op)
@@ -745,7 +719,7 @@ def test_with_dag_run(self):
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -762,9 +736,9 @@ def test_with_dag_run(self):
raise ValueError(f'Invalid task id {ti.task_id} found!')
def test_with_skip_in_branch_downstream_dependencies(self):
- branch_op = BranchPythonOperator(task_id='make_choice',
- dag=self.dag,
- python_callable=lambda: 'branch_1')
+ branch_op = BranchPythonOperator(
+ task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
+ )
branch_op >> self.branch_1 >> self.branch_2
branch_op >> self.branch_2
@@ -774,7 +748,7 @@ def test_with_skip_in_branch_downstream_dependencies(self):
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -791,9 +765,9 @@ def test_with_skip_in_branch_downstream_dependencies(self):
raise ValueError(f'Invalid task id {ti.task_id} found!')
def test_with_skip_in_branch_downstream_dependencies2(self):
- branch_op = BranchPythonOperator(task_id='make_choice',
- dag=self.dag,
- python_callable=lambda: 'branch_2')
+ branch_op = BranchPythonOperator(
+ task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_2'
+ )
branch_op >> self.branch_1 >> self.branch_2
branch_op >> self.branch_2
@@ -803,7 +777,7 @@ def test_with_skip_in_branch_downstream_dependencies2(self):
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -820,9 +794,9 @@ def test_with_skip_in_branch_downstream_dependencies2(self):
raise ValueError(f'Invalid task id {ti.task_id} found!')
def test_xcom_push(self):
- branch_op = BranchPythonOperator(task_id='make_choice',
- dag=self.dag,
- python_callable=lambda: 'branch_1')
+ branch_op = BranchPythonOperator(
+ task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
+ )
self.branch_1.set_upstream(branch_op)
self.branch_2.set_upstream(branch_op)
@@ -832,7 +806,7 @@ def test_xcom_push(self):
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -840,17 +814,16 @@ def test_xcom_push(self):
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == 'make_choice':
- self.assertEqual(
- ti.xcom_pull(task_ids='make_choice'), 'branch_1')
+ self.assertEqual(ti.xcom_pull(task_ids='make_choice'), 'branch_1')
def test_clear_skipped_downstream_task(self):
"""
After a downstream task is skipped by BranchPythonOperator, clearing the skipped task
should not cause it to be executed.
"""
- branch_op = BranchPythonOperator(task_id='make_choice',
- dag=self.dag,
- python_callable=lambda: 'branch_1')
+ branch_op = BranchPythonOperator(
+ task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
+ )
branches = [self.branch_1, self.branch_2]
branch_op >> branches
self.dag.clear()
@@ -859,7 +832,7 @@ def test_clear_skipped_downstream_task(self):
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -919,15 +892,12 @@ def tearDown(self):
def test_without_dag_run(self):
"""This checks the defensive against non existent tasks in a dag run"""
value = False
- dag = DAG('shortcircuit_operator_test_without_dag_run',
- default_args={
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- },
- schedule_interval=INTERVAL)
- short_op = ShortCircuitOperator(task_id='make_choice',
- dag=dag,
- python_callable=lambda: value)
+ dag = DAG(
+ 'shortcircuit_operator_test_without_dag_run',
+ default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
+ schedule_interval=INTERVAL,
+ )
+ short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value)
branch_1 = DummyOperator(task_id='branch_1', dag=dag)
branch_1.set_upstream(short_op)
branch_2 = DummyOperator(task_id='branch_2', dag=dag)
@@ -939,10 +909,7 @@ def test_without_dag_run(self):
short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
with create_session() as session:
- tis = session.query(TI).filter(
- TI.dag_id == dag.dag_id,
- TI.execution_date == DEFAULT_DATE
- )
+ tis = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.execution_date == DEFAULT_DATE)
for ti in tis:
if ti.task_id == 'make_choice':
@@ -972,15 +939,12 @@ def test_without_dag_run(self):
def test_with_dag_run(self):
value = False
- dag = DAG('shortcircuit_operator_test_with_dag_run',
- default_args={
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- },
- schedule_interval=INTERVAL)
- short_op = ShortCircuitOperator(task_id='make_choice',
- dag=dag,
- python_callable=lambda: value)
+ dag = DAG(
+ 'shortcircuit_operator_test_with_dag_run',
+ default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
+ schedule_interval=INTERVAL,
+ )
+ short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value)
branch_1 = DummyOperator(task_id='branch_1', dag=dag)
branch_1.set_upstream(short_op)
branch_2 = DummyOperator(task_id='branch_2', dag=dag)
@@ -994,7 +958,7 @@ def test_with_dag_run(self):
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -1035,15 +999,12 @@ def test_clear_skipped_downstream_task(self):
After a downstream task is skipped by ShortCircuitOperator, clearing the skipped task
should not cause it to be executed.
"""
- dag = DAG('shortcircuit_clear_skipped_downstream_task',
- default_args={
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- },
- schedule_interval=INTERVAL)
- short_op = ShortCircuitOperator(task_id='make_choice',
- dag=dag,
- python_callable=lambda: False)
+ dag = DAG(
+ 'shortcircuit_clear_skipped_downstream_task',
+ default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
+ schedule_interval=INTERVAL,
+ )
+ short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: False)
downstream = DummyOperator(task_id='downstream', dag=dag)
short_op >> downstream
@@ -1054,7 +1015,7 @@ def test_clear_skipped_downstream_task(self):
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -1072,9 +1033,7 @@ def test_clear_skipped_downstream_task(self):
# Clear downstream
with create_session() as session:
- clear_task_instances([t for t in tis if t.task_id == "downstream"],
- session=session,
- dag=dag)
+ clear_task_instances([t for t in tis if t.task_id == "downstream"], session=session, dag=dag)
# Run downstream again
downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -1093,24 +1052,19 @@ def test_clear_skipped_downstream_task(self):
class TestPythonVirtualenvOperator(unittest.TestCase):
-
def setUp(self):
super().setUp()
self.dag = DAG(
'test_dag',
- default_args={
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE},
- schedule_interval=INTERVAL)
+ default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
+ schedule_interval=INTERVAL,
+ )
self.addCleanup(self.dag.clear)
def _run_as_operator(self, fn, python_version=sys.version_info[0], **kwargs):
task = PythonVirtualenvOperator(
- python_callable=fn,
- python_version=python_version,
- task_id='task',
- dag=self.dag,
- **kwargs)
+ python_callable=fn, python_version=python_version, task_id='task', dag=self.dag, **kwargs
+ )
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
return task
@@ -1146,11 +1100,11 @@ def f():
self._run_as_operator(f, requirements=['funcsigs'], system_site_packages=True)
def test_with_requirements_pinned(self):
- self.assertNotEqual(
- '0.4', funcsigs.__version__, 'Please update this string if this fails')
+ self.assertNotEqual('0.4', funcsigs.__version__, 'Please update this string if this fails')
def f():
import funcsigs # noqa: F401 # pylint: disable=redefined-outer-name,reimported
+
if funcsigs.__version__ != '0.4':
raise Exception
@@ -1160,15 +1114,13 @@ def test_unpinned_requirements(self):
def f():
import funcsigs # noqa: F401 # pylint: disable=redefined-outer-name,reimported,unused-import
- self._run_as_operator(
- f, requirements=['funcsigs', 'dill'], system_site_packages=False)
+ self._run_as_operator(f, requirements=['funcsigs', 'dill'], system_site_packages=False)
def test_range_requirements(self):
def f():
import funcsigs # noqa: F401 # pylint: disable=redefined-outer-name,reimported,unused-import
- self._run_as_operator(
- f, requirements=['funcsigs>1.0', 'dill'], system_site_packages=False)
+ self._run_as_operator(f, requirements=['funcsigs>1.0', 'dill'], system_site_packages=False)
def test_fail(self):
def f():
@@ -1193,6 +1145,7 @@ def f():
def test_python_3(self):
def f():
import sys # pylint: disable=reimported,unused-import,redefined-outer-name
+
print(sys.version)
try:
{}.iteritems() # pylint: disable=no-member
@@ -1234,8 +1187,7 @@ def f():
if virtualenv_string_args[0] != virtualenv_string_args[2]:
raise Exception
- self._run_as_operator(
- f, python_version=self._invert_python_major_version(), string_args=[1, 2, 1])
+ self._run_as_operator(f, python_version=self._invert_python_major_version(), string_args=[1, 2, 1])
def test_with_args(self):
def f(a, b, c=False, d=False):
@@ -1254,10 +1206,7 @@ def f():
def test_lambda(self):
with self.assertRaises(AirflowException):
- PythonVirtualenvOperator(
- python_callable=lambda x: 4,
- task_id='task',
- dag=self.dag)
+ PythonVirtualenvOperator(python_callable=lambda x: 4, task_id='task', dag=self.dag)
def test_nonimported_as_arg(self):
def f(_):
@@ -1305,16 +1254,11 @@ def f(
dag_run,
task,
# other
- **context
+ **context,
): # pylint: disable=unused-argument,too-many-arguments,too-many-locals
pass
- self._run_as_operator(
- f,
- use_dill=True,
- system_site_packages=True,
- requirements=None
- )
+ self._run_as_operator(f, use_dill=True, system_site_packages=True, requirements=None)
def test_pendulum_context(self):
def f(
@@ -1344,15 +1288,12 @@ def f(
prev_execution_date_success,
prev_start_date_success,
# other
- **context
+ **context,
): # pylint: disable=unused-argument,too-many-arguments,too-many-locals
pass
self._run_as_operator(
- f,
- use_dill=True,
- system_site_packages=False,
- requirements=['pendulum', 'lazy_object_proxy']
+ f, use_dill=True, system_site_packages=False, requirements=['pendulum', 'lazy_object_proxy']
)
def test_base_context(self):
@@ -1377,16 +1318,11 @@ def f(
yesterday_ds,
yesterday_ds_nodash,
# other
- **context
+ **context,
): # pylint: disable=unused-argument,too-many-arguments,too-many-locals
pass
- self._run_as_operator(
- f,
- use_dill=True,
- system_site_packages=False,
- requirements=None
- )
+ self._run_as_operator(f, use_dill=True, system_site_packages=False, requirements=None)
DEFAULT_ARGS = {
@@ -1416,7 +1352,9 @@ def test_context_removed_after_exit(self):
with set_current_context(example_context):
pass
- with pytest.raises(AirflowException, ):
+ with pytest.raises(
+ AirflowException,
+ ):
get_current_context()
def test_nested_context(self):
@@ -1477,7 +1415,7 @@ def test_get_context_in_old_style_context_task(self):
[
("task1", [State.SUCCESS, State.SUCCESS, State.SUCCESS]),
("join", [State.SUCCESS, State.SKIPPED, State.SUCCESS]),
- ]
+ ],
)
def test_empty_branch(choice, expected_states):
"""
diff --git a/tests/operators/test_sql.py b/tests/operators/test_sql.py
index b6c9ec12595fe..a2bd128576907 100644
--- a/tests/operators/test_sql.py
+++ b/tests/operators/test_sql.py
@@ -25,7 +25,10 @@
from airflow.exceptions import AirflowException
from airflow.models import DAG, DagRun, TaskInstance as TI
from airflow.operators.check_operator import (
- CheckOperator, IntervalCheckOperator, ThresholdCheckOperator, ValueCheckOperator,
+ CheckOperator,
+ IntervalCheckOperator,
+ ThresholdCheckOperator,
+ ValueCheckOperator,
)
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.sql import BranchSQLOperator
@@ -98,8 +101,7 @@ def _construct_operator(self, sql, pass_value, tolerance=None):
def test_pass_value_template_string(self):
pass_value_str = "2018-03-22"
- operator = self._construct_operator(
- "select date from tab1;", "{{ ds }}")
+ operator = self._construct_operator("select date from tab1;", "{{ ds }}")
operator.render_template_fields({"ds": pass_value_str})
@@ -108,8 +110,7 @@ def test_pass_value_template_string(self):
def test_pass_value_template_string_float(self):
pass_value_float = 4.0
- operator = self._construct_operator(
- "select date from tab1;", pass_value_float)
+ operator = self._construct_operator("select date from tab1;", pass_value_float)
operator.render_template_fields({})
@@ -134,8 +135,7 @@ def test_execute_fail(self, mock_get_db_hook):
mock_hook.get_first.return_value = [11]
mock_get_db_hook.return_value = mock_hook
- operator = self._construct_operator(
- "select value from tab1 limit 1;", 5, 1)
+ operator = self._construct_operator("select value from tab1 limit 1;", 5, 1)
with self.assertRaisesRegex(AirflowException, "Tolerance:100.0%"):
operator.execute()
@@ -155,7 +155,9 @@ def test_invalid_ratio_formula(self):
with self.assertRaisesRegex(AirflowException, "Invalid diff_method"):
self._construct_operator(
table="test_table",
- metric_thresholds={"f1": 1, },
+ metric_thresholds={
+ "f1": 1,
+ },
ratio_formula="abs",
ignore_zero=False,
)
@@ -168,7 +170,9 @@ def test_execute_not_ignore_zero(self, mock_get_db_hook):
operator = self._construct_operator(
table="test_table",
- metric_thresholds={"f1": 1, },
+ metric_thresholds={
+ "f1": 1,
+ },
ratio_formula="max_over_min",
ignore_zero=False,
)
@@ -184,7 +188,9 @@ def test_execute_ignore_zero(self, mock_get_db_hook):
operator = self._construct_operator(
table="test_table",
- metric_thresholds={"f1": 1, },
+ metric_thresholds={
+ "f1": 1,
+ },
ratio_formula="max_over_min",
ignore_zero=True,
)
@@ -208,7 +214,12 @@ def returned_row():
operator = self._construct_operator(
table="test_table",
- metric_thresholds={"f0": 1.0, "f1": 1.5, "f2": 2.0, "f3": 2.5, },
+ metric_thresholds={
+ "f0": 1.0,
+ "f1": 1.5,
+ "f2": 2.0,
+ "f3": 2.5,
+ },
ratio_formula="max_over_min",
ignore_zero=True,
)
@@ -233,7 +244,12 @@ def returned_row():
operator = self._construct_operator(
table="test_table",
- metric_thresholds={"f0": 0.5, "f1": 0.6, "f2": 0.7, "f3": 0.8, },
+ metric_thresholds={
+ "f0": 0.5,
+ "f1": 0.6,
+ "f2": 0.7,
+ "f3": 0.8,
+ },
ratio_formula="relative_diff",
ignore_zero=True,
)
@@ -260,9 +276,7 @@ def test_pass_min_value_max_value(self, mock_get_db_hook):
mock_hook.get_first.return_value = (10,)
mock_get_db_hook.return_value = mock_hook
- operator = self._construct_operator(
- "Select avg(val) from table1 limit 1", 1, 100
- )
+ operator = self._construct_operator("Select avg(val) from table1 limit 1", 1, 100)
operator.execute()
@@ -272,9 +286,7 @@ def test_fail_min_value_max_value(self, mock_get_db_hook):
mock_hook.get_first.return_value = (10,)
mock_get_db_hook.return_value = mock_hook
- operator = self._construct_operator(
- "Select avg(val) from table1 limit 1", 20, 100
- )
+ operator = self._construct_operator("Select avg(val) from table1 limit 1", 20, 100)
with self.assertRaisesRegex(AirflowException, "10.*20.0.*100.0"):
operator.execute()
@@ -285,8 +297,7 @@ def test_pass_min_sql_max_sql(self, mock_get_db_hook):
mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
mock_get_db_hook.return_value = mock_hook
- operator = self._construct_operator(
- "Select 10", "Select 1", "Select 100")
+ operator = self._construct_operator("Select 10", "Select 1", "Select 100")
operator.execute()
@@ -296,8 +307,7 @@ def test_fail_min_sql_max_sql(self, mock_get_db_hook):
mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
mock_get_db_hook.return_value = mock_hook
- operator = self._construct_operator(
- "Select 10", "Select 20", "Select 100")
+ operator = self._construct_operator("Select 10", "Select 20", "Select 100")
with self.assertRaisesRegex(AirflowException, "10.*20.*100"):
operator.execute()
@@ -367,8 +377,7 @@ def test_unsupported_conn_type(self):
)
with self.assertRaises(AirflowException):
- op.run(start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE, ignore_ti_state=True)
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_invalid_conn(self):
"""Check if BranchSQLOperator throws an exception for invalid connection """
@@ -382,8 +391,7 @@ def test_invalid_conn(self):
)
with self.assertRaises(AirflowException):
- op.run(start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE, ignore_ti_state=True)
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_invalid_follow_task_true(self):
"""Check if BranchSQLOperator throws an exception for invalid connection """
@@ -397,8 +405,7 @@ def test_invalid_follow_task_true(self):
)
with self.assertRaises(AirflowException):
- op.run(start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE, ignore_ti_state=True)
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_invalid_follow_task_false(self):
"""Check if BranchSQLOperator throws an exception for invalid connection """
@@ -412,8 +419,7 @@ def test_invalid_follow_task_false(self):
)
with self.assertRaises(AirflowException):
- op.run(start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE, ignore_ti_state=True)
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@pytest.mark.backend("mysql")
def test_sql_branch_operator_mysql(self):
@@ -426,9 +432,7 @@ def test_sql_branch_operator_mysql(self):
follow_task_ids_if_false="branch_2",
dag=self.dag,
)
- branch_op.run(
- start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
- )
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@pytest.mark.backend("postgres")
def test_sql_branch_operator_postgres(self):
@@ -441,9 +445,7 @@ def test_sql_branch_operator_postgres(self):
follow_task_ids_if_false="branch_2",
dag=self.dag,
)
- branch_op.run(
- start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
- )
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@mock.patch("airflow.operators.sql.BaseHook")
def test_branch_single_value_with_dag_run(self, mock_hook):
@@ -469,9 +471,7 @@ def test_branch_single_value_with_dag_run(self, mock_hook):
)
mock_hook.get_connection("mysql_default").conn_type = "mysql"
- mock_get_records = (
- mock_hook.get_connection.return_value.get_hook.return_value.get_first
- )
+ mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_first
mock_get_records.return_value = 1
@@ -512,9 +512,7 @@ def test_branch_true_with_dag_run(self, mock_hook):
)
mock_hook.get_connection("mysql_default").conn_type = "mysql"
- mock_get_records = (
- mock_hook.get_connection.return_value.get_hook.return_value.get_first
- )
+ mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_first
for true_value in SUPPORTED_TRUE_VALUES:
mock_get_records.return_value = true_value
@@ -556,9 +554,7 @@ def test_branch_false_with_dag_run(self, mock_hook):
)
mock_hook.get_connection("mysql_default").conn_type = "mysql"
- mock_get_records = (
- mock_hook.get_connection.return_value.get_hook.return_value.get_first
- )
+ mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_first
for false_value in SUPPORTED_FALSE_VALUES:
mock_get_records.return_value = false_value
@@ -602,9 +598,7 @@ def test_branch_list_with_dag_run(self, mock_hook):
)
mock_hook.get_connection("mysql_default").conn_type = "mysql"
- mock_get_records = (
- mock_hook.get_connection.return_value.get_hook.return_value.get_first
- )
+ mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_first
mock_get_records.return_value = [["1"]]
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -646,9 +640,7 @@ def test_invalid_query_result_with_dag_run(self, mock_hook):
)
mock_hook.get_connection("mysql_default").conn_type = "mysql"
- mock_get_records = (
- mock_hook.get_connection.return_value.get_hook.return_value.get_first
- )
+ mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_first
mock_get_records.return_value = ["Invalid Value"]
@@ -679,9 +671,7 @@ def test_with_skip_in_branch_downstream_dependencies(self, mock_hook):
)
mock_hook.get_connection("mysql_default").conn_type = "mysql"
- mock_get_records = (
- mock_hook.get_connection.return_value.get_hook.return_value.get_first
- )
+ mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_first
for true_value in SUPPORTED_TRUE_VALUES:
mock_get_records.return_value = [true_value]
@@ -723,9 +713,7 @@ def test_with_skip_in_branch_downstream_dependencies2(self, mock_hook):
)
mock_hook.get_connection("mysql_default").conn_type = "mysql"
- mock_get_records = (
- mock_hook.get_connection.return_value.get_hook.return_value.get_first
- )
+ mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_first
for false_value in SUPPORTED_FALSE_VALUES:
mock_get_records.return_value = [false_value]
diff --git a/tests/operators/test_subdag_operator.py b/tests/operators/test_subdag_operator.py
index f1946d36af96e..bd2b578d631fd 100644
--- a/tests/operators/test_subdag_operator.py
+++ b/tests/operators/test_subdag_operator.py
@@ -42,7 +42,6 @@
class TestSubDagOperator(unittest.TestCase):
-
def setUp(self):
clear_db_runs()
self.dag_run_running = DagRun()
@@ -63,15 +62,9 @@ def test_subdag_name(self):
subdag_bad3 = DAG('bad.bad', default_args=default_args)
SubDagOperator(task_id='test', dag=dag, subdag=subdag_good)
- self.assertRaises(
- AirflowException,
- SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad1)
- self.assertRaises(
- AirflowException,
- SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad2)
- self.assertRaises(
- AirflowException,
- SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad3)
+ self.assertRaises(AirflowException, SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad1)
+ self.assertRaises(AirflowException, SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad2)
+ self.assertRaises(AirflowException, SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad3)
def test_subdag_in_context_manager(self):
"""
@@ -101,14 +94,12 @@ def test_subdag_pools(self):
DummyOperator(task_id='dummy', dag=subdag, pool='test_pool_1')
self.assertRaises(
- AirflowException,
- SubDagOperator,
- task_id='child', dag=dag, subdag=subdag, pool='test_pool_1')
+ AirflowException, SubDagOperator, task_id='child', dag=dag, subdag=subdag, pool='test_pool_1'
+ )
# recreate dag because failed subdagoperator was already added
dag = DAG('parent', default_args=default_args)
- SubDagOperator(
- task_id='child', dag=dag, subdag=subdag, pool='test_pool_10')
+ SubDagOperator(task_id='child', dag=dag, subdag=subdag, pool='test_pool_10')
session.delete(pool_1)
session.delete(pool_10)
@@ -132,9 +123,7 @@ def test_subdag_pools_no_possible_conflict(self):
DummyOperator(task_id='dummy', dag=subdag, pool='test_pool_10')
mock_session = Mock()
- SubDagOperator(
- task_id='child', dag=dag, subdag=subdag, pool='test_pool_1',
- session=mock_session)
+ SubDagOperator(task_id='child', dag=dag, subdag=subdag, pool='test_pool_1', session=mock_session)
self.assertFalse(mock_session.query.called)
session.delete(pool_1)
@@ -272,22 +261,19 @@ def test_rerun_failed_subdag(self):
sub_dagrun.refresh_from_db()
self.assertEqual(sub_dagrun.state, State.RUNNING)
- @parameterized.expand([
- (SkippedStatePropagationOptions.ALL_LEAVES, [State.SKIPPED, State.SKIPPED], True),
- (SkippedStatePropagationOptions.ALL_LEAVES, [State.SKIPPED, State.SUCCESS], False),
- (SkippedStatePropagationOptions.ANY_LEAF, [State.SKIPPED, State.SUCCESS], True),
- (SkippedStatePropagationOptions.ANY_LEAF, [State.FAILED, State.SKIPPED], True),
- (None, [State.SKIPPED, State.SKIPPED], False),
- ])
+ @parameterized.expand(
+ [
+ (SkippedStatePropagationOptions.ALL_LEAVES, [State.SKIPPED, State.SKIPPED], True),
+ (SkippedStatePropagationOptions.ALL_LEAVES, [State.SKIPPED, State.SUCCESS], False),
+ (SkippedStatePropagationOptions.ANY_LEAF, [State.SKIPPED, State.SUCCESS], True),
+ (SkippedStatePropagationOptions.ANY_LEAF, [State.FAILED, State.SKIPPED], True),
+ (None, [State.SKIPPED, State.SKIPPED], False),
+ ]
+ )
@mock.patch('airflow.operators.subdag_operator.SubDagOperator.skip')
@mock.patch('airflow.operators.subdag_operator.get_task_instance')
def test_subdag_with_propagate_skipped_state(
- self,
- propagate_option,
- states,
- skip_parent,
- mock_get_task_instance,
- mock_skip
+ self, propagate_option, states, skip_parent, mock_get_task_instance, mock_skip
):
"""
Tests that skipped state of leaf tasks propagates to the parent dag.
@@ -296,15 +282,10 @@ def test_subdag_with_propagate_skipped_state(
dag = DAG('parent', default_args=default_args)
subdag = DAG('parent.test', default_args=default_args)
subdag_task = SubDagOperator(
- task_id='test',
- subdag=subdag,
- dag=dag,
- poke_interval=1,
- propagate_skipped_state=propagate_option
+ task_id='test', subdag=subdag, dag=dag, poke_interval=1, propagate_skipped_state=propagate_option
)
dummy_subdag_tasks = [
- DummyOperator(task_id=f'dummy_subdag_{i}', dag=subdag)
- for i in range(len(states))
+ DummyOperator(task_id=f'dummy_subdag_{i}', dag=subdag) for i in range(len(states))
]
dummy_dag_task = DummyOperator(task_id='dummy_dag', dag=dag)
subdag_task >> dummy_dag_task
@@ -312,25 +293,14 @@ def test_subdag_with_propagate_skipped_state(
subdag_task._get_dagrun = Mock()
subdag_task._get_dagrun.return_value = self.dag_run_success
mock_get_task_instance.side_effect = [
- TaskInstance(
- task=task,
- execution_date=DEFAULT_DATE,
- state=state
- ) for task, state in zip(dummy_subdag_tasks, states)
+ TaskInstance(task=task, execution_date=DEFAULT_DATE, state=state)
+ for task, state in zip(dummy_subdag_tasks, states)
]
- context = {
- 'execution_date': DEFAULT_DATE,
- 'dag_run': DagRun(),
- 'task': subdag_task
- }
+ context = {'execution_date': DEFAULT_DATE, 'dag_run': DagRun(), 'task': subdag_task}
subdag_task.post_execute(context)
if skip_parent:
- mock_skip.assert_called_once_with(
- context['dag_run'],
- context['execution_date'],
- [dummy_dag_task]
- )
+ mock_skip.assert_called_once_with(context['dag_run'], context['execution_date'], [dummy_dag_task])
else:
mock_skip.assert_not_called()
diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py
index 1f719f16d5fb7..d7a4df1f40fbb 100644
--- a/tests/plugins/test_plugin.py
+++ b/tests/plugins/test_plugin.py
@@ -20,14 +20,21 @@
from flask_appbuilder import BaseView as AppBuilderBaseView, expose
from airflow.executors.base_executor import BaseExecutor
+
# Importing base classes that we need to derive
from airflow.hooks.base_hook import BaseHook
from airflow.models.baseoperator import BaseOperator
+
# This is the class you derive to create a plugin
from airflow.plugins_manager import AirflowPlugin
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from tests.test_utils.mock_operators import (
- AirflowLink, AirflowLink2, CustomBaseIndexOpLink, CustomOpLink, GithubLink, GoogleLink,
+ AirflowLink,
+ AirflowLink2,
+ CustomBaseIndexOpLink,
+ CustomOpLink,
+ GithubLink,
+ GoogleLink,
)
@@ -66,22 +73,24 @@ def test(self):
v_appbuilder_view = PluginTestAppBuilderBaseView()
-v_appbuilder_package = {"name": "Test View",
- "category": "Test Plugin",
- "view": v_appbuilder_view}
+v_appbuilder_package = {"name": "Test View", "category": "Test Plugin", "view": v_appbuilder_view}
# Creating a flask appbuilder Menu Item
-appbuilder_mitem = {"name": "Google",
- "category": "Search",
- "category_icon": "fa-th",
- "href": "https://www.google.com"}
+appbuilder_mitem = {
+ "name": "Google",
+ "category": "Search",
+ "category_icon": "fa-th",
+ "href": "https://www.google.com",
+}
# Creating a flask blueprint to integrate the templates and static folder
bp = Blueprint(
- "test_plugin", __name__,
+ "test_plugin",
+ __name__,
template_folder='templates', # registers airflow/plugins/templates as a Jinja template folder
static_folder='static',
- static_url_path='/static/test_plugin')
+ static_url_path='/static/test_plugin',
+)
# Defining the plugin class
@@ -99,9 +108,7 @@ class AirflowTestPlugin(AirflowPlugin):
AirflowLink(),
GithubLink(),
]
- operator_extra_links = [
- GoogleLink(), AirflowLink2(), CustomOpLink(), CustomBaseIndexOpLink(1)
- ]
+ operator_extra_links = [GoogleLink(), AirflowLink2(), CustomOpLink(), CustomBaseIndexOpLink(1)]
class MockPluginA(AirflowPlugin):
diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py
index 7b5357ba58bd3..07e500dc01b96 100644
--- a/tests/plugins/test_plugins_manager.py
+++ b/tests/plugins/test_plugins_manager.py
@@ -31,15 +31,20 @@ def setUp(self):
def test_flaskappbuilder_views(self):
from tests.plugins.test_plugin import v_appbuilder_package
+
appbuilder_class_name = str(v_appbuilder_package['view'].__class__.__name__)
- plugin_views = [view for view in self.appbuilder.baseviews
- if view.blueprint.name == appbuilder_class_name]
+ plugin_views = [
+ view for view in self.appbuilder.baseviews if view.blueprint.name == appbuilder_class_name
+ ]
self.assertTrue(len(plugin_views) == 1)
# view should have a menu item matching category of v_appbuilder_package
- links = [menu_item for menu_item in self.appbuilder.menu.menu
- if menu_item.name == v_appbuilder_package['category']]
+ links = [
+ menu_item
+ for menu_item in self.appbuilder.menu.menu
+ if menu_item.name == v_appbuilder_package['category']
+ ]
self.assertTrue(len(links) == 1)
@@ -52,8 +57,11 @@ def test_flaskappbuilder_menu_links(self):
from tests.plugins.test_plugin import appbuilder_mitem
# menu item should exist matching appbuilder_mitem
- links = [menu_item for menu_item in self.appbuilder.menu.menu
- if menu_item.name == appbuilder_mitem['category']]
+ links = [
+ menu_item
+ for menu_item in self.appbuilder.menu.menu
+ if menu_item.name == appbuilder_mitem['category']
+ ]
self.assertTrue(len(links) == 1)
@@ -115,6 +123,7 @@ class TestNonPropertyHook(BaseHook):
with mock_plugin_manager(plugins=[AirflowTestPropertyPlugin()]):
from airflow import plugins_manager
+
plugins_manager.integrate_dag_plugins()
self.assertIn('AirflowTestPropertyPlugin', str(plugins_manager.plugins))
@@ -132,24 +141,24 @@ class AirflowAdminMenuLinksPlugin(AirflowPlugin):
menu_links = [mock.MagicMock()]
- with mock_plugin_manager(plugins=[
- AirflowAdminViewsPlugin(),
- AirflowAdminMenuLinksPlugin()
- ]):
+ with mock_plugin_manager(plugins=[AirflowAdminViewsPlugin(), AirflowAdminMenuLinksPlugin()]):
from airflow import plugins_manager
# assert not logs
with self.assertLogs(plugins_manager.log) as cm:
plugins_manager.initialize_web_ui_plugins()
- self.assertEqual(cm.output, [
- 'WARNING:airflow.plugins_manager:Plugin \'test_admin_views_plugin\' may not be '
- 'compatible with the current Airflow version. Please contact the author of '
- 'the plugin.',
- 'WARNING:airflow.plugins_manager:Plugin \'test_menu_links_plugin\' may not be '
- 'compatible with the current Airflow version. Please contact the author of '
- 'the plugin.'
- ])
+ self.assertEqual(
+ cm.output,
+ [
+ 'WARNING:airflow.plugins_manager:Plugin \'test_admin_views_plugin\' may not be '
+ 'compatible with the current Airflow version. Please contact the author of '
+ 'the plugin.',
+ 'WARNING:airflow.plugins_manager:Plugin \'test_menu_links_plugin\' may not be '
+ 'compatible with the current Airflow version. Please contact the author of '
+ 'the plugin.',
+ ],
+ )
def test_should_not_warning_about_fab_plugins(self):
class AirflowAdminViewsPlugin(AirflowPlugin):
@@ -162,10 +171,7 @@ class AirflowAdminMenuLinksPlugin(AirflowPlugin):
appbuilder_menu_items = [mock.MagicMock()]
- with mock_plugin_manager(plugins=[
- AirflowAdminViewsPlugin(),
- AirflowAdminMenuLinksPlugin()
- ]):
+ with mock_plugin_manager(plugins=[AirflowAdminViewsPlugin(), AirflowAdminMenuLinksPlugin()]):
from airflow import plugins_manager
# assert not logs
@@ -185,10 +191,7 @@ class AirflowAdminMenuLinksPlugin(AirflowPlugin):
menu_links = [mock.MagicMock()]
appbuilder_menu_items = [mock.MagicMock()]
- with mock_plugin_manager(plugins=[
- AirflowAdminViewsPlugin(),
- AirflowAdminMenuLinksPlugin()
- ]):
+ with mock_plugin_manager(plugins=[AirflowAdminViewsPlugin(), AirflowAdminMenuLinksPlugin()]):
from airflow import plugins_manager
# assert not logs
diff --git a/tests/providers/amazon/aws/hooks/test_glacier.py b/tests/providers/amazon/aws/hooks/test_glacier.py
index 63b93d5004fd1..0f60f371884f1 100644
--- a/tests/providers/amazon/aws/hooks/test_glacier.py
+++ b/tests/providers/amazon/aws/hooks/test_glacier.py
@@ -17,8 +17,8 @@
# under the License.
import unittest
-
from unittest import mock
+
from testfixtures import LogCapture
from airflow.providers.amazon.aws.hooks.glacier import GlacierHook
diff --git a/tests/providers/amazon/aws/hooks/test_glue.py b/tests/providers/amazon/aws/hooks/test_glue.py
index d450bc780dde3..512e7c84691cd 100644
--- a/tests/providers/amazon/aws/hooks/test_glue.py
+++ b/tests/providers/amazon/aws/hooks/test_glue.py
@@ -17,7 +17,6 @@
# under the License.
import json
import unittest
-
from unittest import mock
from airflow.providers.amazon.aws.hooks.glue import AwsGlueJobHook
diff --git a/tests/providers/amazon/aws/hooks/test_sagemaker.py b/tests/providers/amazon/aws/hooks/test_sagemaker.py
index a94281a6452fa..b6ecfce873f19 100644
--- a/tests/providers/amazon/aws/hooks/test_sagemaker.py
+++ b/tests/providers/amazon/aws/hooks/test_sagemaker.py
@@ -20,8 +20,8 @@
import time
import unittest
from datetime import datetime
-
from unittest import mock
+
from tzlocal import get_localzone
from airflow.exceptions import AirflowException
diff --git a/tests/providers/amazon/aws/hooks/test_secrets_manager.py b/tests/providers/amazon/aws/hooks/test_secrets_manager.py
index c83260c3ba397..916a5ea04def0 100644
--- a/tests/providers/amazon/aws/hooks/test_secrets_manager.py
+++ b/tests/providers/amazon/aws/hooks/test_secrets_manager.py
@@ -17,9 +17,9 @@
# under the License.
#
-import unittest
import base64
import json
+import unittest
from airflow.providers.amazon.aws.hooks.secrets_manager import SecretsManagerHook
diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py
index 1eef2c4b5f484..61e43d26aec1c 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -16,7 +16,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.models import DAG, TaskInstance
diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py
index 4bf36c455132c..f72e12a2ece43 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -20,7 +20,6 @@
# pylint: disable=missing-docstring
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py
index 16f39bc116342..a8bb79216164f 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -20,8 +20,8 @@
import sys
import unittest
from copy import deepcopy
-
from unittest import mock
+
from parameterized import parameterized
from airflow.exceptions import AirflowException
diff --git a/tests/providers/amazon/aws/operators/test_glacier.py b/tests/providers/amazon/aws/operators/test_glacier.py
index 561ad20683b7b..e86e82cf929b8 100644
--- a/tests/providers/amazon/aws/operators/test_glacier.py
+++ b/tests/providers/amazon/aws/operators/test_glacier.py
@@ -16,13 +16,9 @@
# specific language governing permissions and limitations
# under the License.
-from unittest import TestCase
+from unittest import TestCase, mock
-from unittest import mock
-
-from airflow.providers.amazon.aws.operators.glacier import (
- GlacierCreateJobOperator,
-)
+from airflow.providers.amazon.aws.operators.glacier import GlacierCreateJobOperator
AWS_CONN_ID = "aws_default"
BUCKET_NAME = "airflow_bucket"
diff --git a/tests/providers/amazon/aws/operators/test_glacier_system.py b/tests/providers/amazon/aws/operators/test_glacier_system.py
index b31c69958e76b..99a5315504b25 100644
--- a/tests/providers/amazon/aws/operators/test_glacier_system.py
+++ b/tests/providers/amazon/aws/operators/test_glacier_system.py
@@ -18,7 +18,6 @@
from tests.test_utils.amazon_system_helpers import AWS_DAG_FOLDER, AmazonSystemTest
from tests.test_utils.gcp_system_helpers import GoogleSystemTest
-
BUCKET = "data_from_glacier"
diff --git a/tests/providers/amazon/aws/operators/test_glue.py b/tests/providers/amazon/aws/operators/test_glue.py
index f0a08f5551f22..144bebed29f83 100644
--- a/tests/providers/amazon/aws/operators/test_glue.py
+++ b/tests/providers/amazon/aws/operators/test_glue.py
@@ -16,7 +16,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow import configuration
diff --git a/tests/providers/amazon/aws/operators/test_s3_bucket.py b/tests/providers/amazon/aws/operators/test_s3_bucket.py
index 5d090b81a0b40..e501125ae90bf 100644
--- a/tests/providers/amazon/aws/operators/test_s3_bucket.py
+++ b/tests/providers/amazon/aws/operators/test_s3_bucket.py
@@ -17,8 +17,8 @@
# under the License.
import os
import unittest
-
from unittest import mock
+
from moto import mock_s3
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
diff --git a/tests/providers/amazon/aws/operators/test_s3_list.py b/tests/providers/amazon/aws/operators/test_s3_list.py
index 5c271a38d23b4..b51c8fac89b65 100644
--- a/tests/providers/amazon/aws/operators/test_s3_list.py
+++ b/tests/providers/amazon/aws/operators/test_s3_list.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.amazon.aws.operators.s3_list import S3ListOperator
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
index 44563204571f1..d08c92453842c 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
@@ -17,8 +17,8 @@
# under the License.
import unittest
-
from unittest import mock
+
from botocore.exceptions import ClientError
from airflow.exceptions import AirflowException
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
index 6d36284f01f30..e725c0228a706 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_model.py b/tests/providers/amazon/aws/operators/test_sagemaker_model.py
index cf8aaa925a879..6676f1009c395 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_model.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_model.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
index 1e51a9f523211..5e55b39006167 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
@@ -16,8 +16,8 @@
# under the License.
import unittest
-
from unittest import mock
+
from parameterized import parameterized
from airflow.exceptions import AirflowException
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
index e7134860a2207..7424d6f7f8b25 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
@@ -16,7 +16,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index ccbaa09b51918..739baed701dbe 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
index a197ddf3d925b..9978944a18eec 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/amazon/aws/sensors/test_athena.py b/tests/providers/amazon/aws/sensors/test_athena.py
index ad83cfe29544d..8f5f2c7173cf7 100644
--- a/tests/providers/amazon/aws/sensors/test_athena.py
+++ b/tests/providers/amazon/aws/sensors/test_athena.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/amazon/aws/sensors/test_glacier.py b/tests/providers/amazon/aws/sensors/test_glacier.py
index 410f984cdacc1..d954c6e7ed599 100644
--- a/tests/providers/amazon/aws/sensors/test_glacier.py
+++ b/tests/providers/amazon/aws/sensors/test_glacier.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow import AirflowException
diff --git a/tests/providers/amazon/aws/sensors/test_glue.py b/tests/providers/amazon/aws/sensors/test_glue.py
index 21f0083869a54..ff503b1467970 100644
--- a/tests/providers/amazon/aws/sensors/test_glue.py
+++ b/tests/providers/amazon/aws/sensors/test_glue.py
@@ -16,7 +16,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow import configuration
diff --git a/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py b/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py
index 26f683f6cb0c9..a4c264bbe6359 100644
--- a/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py
+++ b/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py
index 042859c53b00a..ca4090b12f2dd 100644
--- a/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_training.py b/tests/providers/amazon/aws/sensors/test_sagemaker_training.py
index 011d23f4696dc..c6d5b782a68ed 100644
--- a/tests/providers/amazon/aws/sensors/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_training.py
@@ -18,7 +18,6 @@
import unittest
from datetime import datetime
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py b/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py
index 3f576f11211f8..b0463373c10ac 100644
--- a/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py b/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py
index 90e8a5e7e1845..32b2553f1f5c6 100644
--- a/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
index 404431e81de6b..fc62cc4635c19 100644
--- a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
diff --git a/tests/providers/amazon/aws/transfers/test_glacier_to_gcs.py b/tests/providers/amazon/aws/transfers/test_glacier_to_gcs.py
index a2a49eb057947..f871f5acfa72c 100644
--- a/tests/providers/amazon/aws/transfers/test_glacier_to_gcs.py
+++ b/tests/providers/amazon/aws/transfers/test_glacier_to_gcs.py
@@ -15,9 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from unittest import TestCase
-
-from unittest import mock
+from unittest import TestCase, mock
from airflow.providers.amazon.aws.transfers.glacier_to_gcs import GlacierToGCSOperator
diff --git a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
index 44deffaa489a6..8ac0735b6b98f 100644
--- a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
@@ -16,7 +16,6 @@
# specific language governing permissions and limitations
# under the License.
import unittest
-
from unittest import mock
from airflow.models import TaskInstance
diff --git a/tests/providers/apache/cassandra/hooks/test_cassandra.py b/tests/providers/apache/cassandra/hooks/test_cassandra.py
index 391cfd0705e5b..5e37acef674ec 100644
--- a/tests/providers/apache/cassandra/hooks/test_cassandra.py
+++ b/tests/providers/apache/cassandra/hooks/test_cassandra.py
@@ -17,8 +17,8 @@
# under the License.
import unittest
-
from unittest import mock
+
import pytest
from cassandra.cluster import Cluster
from cassandra.policies import (
diff --git a/tests/providers/apache/druid/operators/test_druid_check.py b/tests/providers/apache/druid/operators/test_druid_check.py
index b1bc844190db3..bb84e9d47daec 100644
--- a/tests/providers/apache/druid/operators/test_druid_check.py
+++ b/tests/providers/apache/druid/operators/test_druid_check.py
@@ -19,7 +19,6 @@
import unittest
from datetime import datetime
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/apache/hive/transfers/test_s3_to_hive.py b/tests/providers/apache/hive/transfers/test_s3_to_hive.py
index ad3fe94f24245..473b889335b65 100644
--- a/tests/providers/apache/hive/transfers/test_s3_to_hive.py
+++ b/tests/providers/apache/hive/transfers/test_s3_to_hive.py
@@ -26,7 +26,6 @@
from gzip import GzipFile
from itertools import product
from tempfile import NamedTemporaryFile, mkdtemp
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/apache/pig/hooks/test_pig.py b/tests/providers/apache/pig/hooks/test_pig.py
index da8de176a8fa1..a714368b88354 100644
--- a/tests/providers/apache/pig/hooks/test_pig.py
+++ b/tests/providers/apache/pig/hooks/test_pig.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.apache.pig.hooks.pig import PigCliHook
diff --git a/tests/providers/apache/pig/operators/test_pig.py b/tests/providers/apache/pig/operators/test_pig.py
index 174f24c0607ea..92546d122e175 100644
--- a/tests/providers/apache/pig/operators/test_pig.py
+++ b/tests/providers/apache/pig/operators/test_pig.py
@@ -16,7 +16,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.apache.pig.hooks.pig import PigCliHook
diff --git a/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py b/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py
index 1431ec1a86c86..18bb75f192a3a 100644
--- a/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py
+++ b/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from unittest import mock
+
import pytest
from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
diff --git a/tests/providers/cloudant/hooks/test_cloudant.py b/tests/providers/cloudant/hooks/test_cloudant.py
index cf291c027cf82..d0c78dbebb97c 100644
--- a/tests/providers/cloudant/hooks/test_cloudant.py
+++ b/tests/providers/cloudant/hooks/test_cloudant.py
@@ -16,7 +16,6 @@
# specific language governing permissions and limitations
# under the License.
import unittest
-
from unittest.mock import patch
from airflow.exceptions import AirflowException
diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
index e03540abecb61..9a01d8fa5007b 100644
--- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
@@ -23,11 +23,10 @@
from unittest import mock
from unittest.mock import patch
+import kubernetes
from parameterized import parameterized
-import kubernetes
from airflow import AirflowException
-
from airflow.models import Connection
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.utils import db
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index ac70edcf91e59..94a094be7ba83 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -20,8 +20,8 @@
import itertools
import json
import unittest
-
from unittest import mock
+
from requests import exceptions as requests_exceptions
from airflow import __version__
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index c26e464f96795..8f43ad5391b65 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -18,7 +18,6 @@
#
import unittest
from datetime import datetime
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/datadog/sensors/test_datadog.py b/tests/providers/datadog/sensors/test_datadog.py
index cf9633b752b61..10dae5f31a508 100644
--- a/tests/providers/datadog/sensors/test_datadog.py
+++ b/tests/providers/datadog/sensors/test_datadog.py
@@ -19,7 +19,6 @@
import json
import unittest
from typing import List
-
from unittest.mock import patch
from airflow.models import Connection
diff --git a/tests/providers/docker/hooks/test_docker.py b/tests/providers/docker/hooks/test_docker.py
index d66823a96030b..49e810b3d2859 100644
--- a/tests/providers/docker/hooks/test_docker.py
+++ b/tests/providers/docker/hooks/test_docker.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/docker/operators/test_docker.py b/tests/providers/docker/operators/test_docker.py
index 49cd6a69bf3d7..631d42be2df54 100644
--- a/tests/providers/docker/operators/test_docker.py
+++ b/tests/providers/docker/operators/test_docker.py
@@ -17,7 +17,6 @@
# under the License.
import logging
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/docker/operators/test_docker_swarm.py b/tests/providers/docker/operators/test_docker_swarm.py
index 784e8e7fc3d03..a3208ef00138b 100644
--- a/tests/providers/docker/operators/test_docker_swarm.py
+++ b/tests/providers/docker/operators/test_docker_swarm.py
@@ -17,8 +17,8 @@
# under the License.
import unittest
-
from unittest import mock
+
import requests
from docker import APIClient
diff --git a/tests/providers/exasol/hooks/test_exasol.py b/tests/providers/exasol/hooks/test_exasol.py
index 110259e5c6ad2..4b964e102ad33 100644
--- a/tests/providers/exasol/hooks/test_exasol.py
+++ b/tests/providers/exasol/hooks/test_exasol.py
@@ -19,7 +19,6 @@
import json
import unittest
-
from unittest import mock
from airflow import models
diff --git a/tests/providers/exasol/operators/test_exasol.py b/tests/providers/exasol/operators/test_exasol.py
index a882075a72ac4..2b486a14c8057 100644
--- a/tests/providers/exasol/operators/test_exasol.py
+++ b/tests/providers/exasol/operators/test_exasol.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.exasol.operators.exasol import ExasolOperator
diff --git a/tests/providers/facebook/ads/hooks/test_ads.py b/tests/providers/facebook/ads/hooks/test_ads.py
index 366277a7a03b2..8d4d88b021c98 100644
--- a/tests/providers/facebook/ads/hooks/test_ads.py
+++ b/tests/providers/facebook/ads/hooks/test_ads.py
@@ -16,6 +16,7 @@
# under the License.
from unittest import mock
+
import pytest
from airflow.providers.facebook.ads.hooks.ads import FacebookAdsReportingHook
diff --git a/tests/providers/google/ads/hooks/test_ads.py b/tests/providers/google/ads/hooks/test_ads.py
index 1928f79760ba9..83d17b7d716c4 100644
--- a/tests/providers/google/ads/hooks/test_ads.py
+++ b/tests/providers/google/ads/hooks/test_ads.py
@@ -16,6 +16,7 @@
# under the License.
from unittest import mock
+
import pytest
from airflow.providers.google.ads.hooks.ads import GoogleAdsHook
diff --git a/tests/providers/google/cloud/hooks/test_automl.py b/tests/providers/google/cloud/hooks/test_automl.py
index 353921b644a74..898001c3b17a2 100644
--- a/tests/providers/google/cloud/hooks/test_automl.py
+++ b/tests/providers/google/cloud/hooks/test_automl.py
@@ -17,8 +17,8 @@
# under the License.
#
import unittest
-
from unittest import mock
+
from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
diff --git a/tests/providers/google/cloud/hooks/test_bigquery_dts.py b/tests/providers/google/cloud/hooks/test_bigquery_dts.py
index 82944f8afb42e..cf8b1cd899fae 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery_dts.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery_dts.py
@@ -18,8 +18,8 @@
import unittest
from copy import deepcopy
-
from unittest import mock
+
from google.cloud.bigquery_datatransfer_v1 import DataTransferServiceClient
from google.cloud.bigquery_datatransfer_v1.types import TransferConfig
from google.protobuf.json_format import ParseDict
diff --git a/tests/providers/google/cloud/hooks/test_cloud_build.py b/tests/providers/google/cloud/hooks/test_cloud_build.py
index ecd7a16377495..9a3b67ec258e6 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_build.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_build.py
@@ -21,7 +21,6 @@
import unittest
from typing import Optional
from unittest import mock
-
from unittest.mock import PropertyMock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py
index 4fda2ad776907..b7d869b535cb4 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py
@@ -20,9 +20,9 @@
import re
import unittest
from copy import deepcopy
-
from unittest import mock
from unittest.mock import MagicMock, PropertyMock
+
from googleapiclient.errors import HttpError
from parameterized import parameterized
diff --git a/tests/providers/google/cloud/hooks/test_compute.py b/tests/providers/google/cloud/hooks/test_compute.py
index 01ee15113f827..64f6d928fcf60 100644
--- a/tests/providers/google/cloud/hooks/test_compute.py
+++ b/tests/providers/google/cloud/hooks/test_compute.py
@@ -19,7 +19,6 @@
# pylint: disable=too-many-lines
import unittest
-
from unittest import mock
from unittest.mock import PropertyMock
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index 9de46a5b81e09..be486eed511df 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -21,9 +21,9 @@
import shlex
import unittest
from typing import Any, Dict
-
from unittest import mock
from unittest.mock import MagicMock
+
from parameterized import parameterized
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/cloud/hooks/test_dataproc.py b/tests/providers/google/cloud/hooks/test_dataproc.py
index ebca9f5f340d7..43a28ee7c0594 100644
--- a/tests/providers/google/cloud/hooks/test_dataproc.py
+++ b/tests/providers/google/cloud/hooks/test_dataproc.py
@@ -17,8 +17,8 @@
# under the License.
#
import unittest
-
from unittest import mock
+
from google.cloud.dataproc_v1beta2.types import JobStatus # pylint: disable=no-name-in-module
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/cloud/hooks/test_datastore.py b/tests/providers/google/cloud/hooks/test_datastore.py
index dfb15d173808d..3d9216ab9daee 100644
--- a/tests/providers/google/cloud/hooks/test_datastore.py
+++ b/tests/providers/google/cloud/hooks/test_datastore.py
@@ -18,9 +18,8 @@
#
import unittest
-from unittest.mock import call, patch
-
from unittest import mock
+from unittest.mock import call, patch
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.datastore import DatastoreHook
diff --git a/tests/providers/google/cloud/hooks/test_dlp.py b/tests/providers/google/cloud/hooks/test_dlp.py
index 7e5c5c8dc601f..c9f14d99e81cf 100644
--- a/tests/providers/google/cloud/hooks/test_dlp.py
+++ b/tests/providers/google/cloud/hooks/test_dlp.py
@@ -24,9 +24,9 @@
import unittest
from typing import Any, Dict
-
from unittest import mock
from unittest.mock import PropertyMock
+
from google.cloud.dlp_v2.types import DlpJob
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/cloud/hooks/test_functions.py b/tests/providers/google/cloud/hooks/test_functions.py
index 02fd13847492e..76304b00b3a07 100644
--- a/tests/providers/google/cloud/hooks/test_functions.py
+++ b/tests/providers/google/cloud/hooks/test_functions.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from unittest.mock import PropertyMock
diff --git a/tests/providers/google/cloud/hooks/test_kms.py b/tests/providers/google/cloud/hooks/test_kms.py
index 94383040cb9f1..894532aaed41b 100644
--- a/tests/providers/google/cloud/hooks/test_kms.py
+++ b/tests/providers/google/cloud/hooks/test_kms.py
@@ -19,7 +19,6 @@
import unittest
from base64 import b64decode, b64encode
from collections import namedtuple
-
from unittest import mock
from airflow.providers.google.cloud.hooks.kms import CloudKMSHook
diff --git a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
index d93abe52b4009..892c9f7a6fb53 100644
--- a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
@@ -17,9 +17,9 @@
# under the License.
#
import unittest
-
from unittest import mock
from unittest.mock import PropertyMock
+
from google.cloud.container_v1.types import Cluster
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/cloud/hooks/test_life_sciences.py b/tests/providers/google/cloud/hooks/test_life_sciences.py
index 9f6c553d997a1..e203e87f3eb79 100644
--- a/tests/providers/google/cloud/hooks/test_life_sciences.py
+++ b/tests/providers/google/cloud/hooks/test_life_sciences.py
@@ -20,7 +20,6 @@
"""
import unittest
from unittest import mock
-
from unittest.mock import PropertyMock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/cloud/hooks/test_natural_language.py b/tests/providers/google/cloud/hooks/test_natural_language.py
index 1a4b588017490..745b304506364 100644
--- a/tests/providers/google/cloud/hooks/test_natural_language.py
+++ b/tests/providers/google/cloud/hooks/test_natural_language.py
@@ -18,8 +18,8 @@
#
import unittest
from typing import Any, Dict
-
from unittest import mock
+
from google.cloud.language_v1.proto.language_service_pb2 import Document
from airflow.providers.google.cloud.hooks.natural_language import CloudNaturalLanguageHook
diff --git a/tests/providers/google/cloud/hooks/test_pubsub.py b/tests/providers/google/cloud/hooks/test_pubsub.py
index 7c1fcad6953eb..f519c70504a4b 100644
--- a/tests/providers/google/cloud/hooks/test_pubsub.py
+++ b/tests/providers/google/cloud/hooks/test_pubsub.py
@@ -18,8 +18,8 @@
import unittest
from typing import List
-
from unittest import mock
+
from google.api_core.exceptions import AlreadyExists, GoogleAPICallError
from google.cloud.exceptions import NotFound
from google.cloud.pubsub_v1.types import ReceivedMessage
diff --git a/tests/providers/google/cloud/hooks/test_spanner.py b/tests/providers/google/cloud/hooks/test_spanner.py
index 2fb4527a5c5de..27b7a06e67456 100644
--- a/tests/providers/google/cloud/hooks/test_spanner.py
+++ b/tests/providers/google/cloud/hooks/test_spanner.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from unittest.mock import PropertyMock
diff --git a/tests/providers/google/cloud/hooks/test_speech_to_text.py b/tests/providers/google/cloud/hooks/test_speech_to_text.py
index cd2c8fe5debfa..be5d933137823 100644
--- a/tests/providers/google/cloud/hooks/test_speech_to_text.py
+++ b/tests/providers/google/cloud/hooks/test_speech_to_text.py
@@ -18,7 +18,6 @@
#
import unittest
-
from unittest.mock import PropertyMock, patch
from airflow.providers.google.cloud.hooks.speech_to_text import CloudSpeechToTextHook
diff --git a/tests/providers/google/cloud/hooks/test_stackdriver.py b/tests/providers/google/cloud/hooks/test_stackdriver.py
index ae16064829364..6892d0552f559 100644
--- a/tests/providers/google/cloud/hooks/test_stackdriver.py
+++ b/tests/providers/google/cloud/hooks/test_stackdriver.py
@@ -18,8 +18,8 @@
import json
import unittest
-
from unittest import mock
+
from google.api_core.gapic_v1.method import DEFAULT
from google.cloud import monitoring_v3
from google.protobuf.json_format import ParseDict
diff --git a/tests/providers/google/cloud/hooks/test_tasks.py b/tests/providers/google/cloud/hooks/test_tasks.py
index 8c7791843c546..8be6686cf6c1d 100644
--- a/tests/providers/google/cloud/hooks/test_tasks.py
+++ b/tests/providers/google/cloud/hooks/test_tasks.py
@@ -18,8 +18,8 @@
#
import unittest
from typing import Any, Dict
-
from unittest import mock
+
from google.cloud.tasks_v2.types import Queue, Task
from airflow.providers.google.cloud.hooks.tasks import CloudTasksHook
diff --git a/tests/providers/google/cloud/hooks/test_text_to_speech.py b/tests/providers/google/cloud/hooks/test_text_to_speech.py
index ed6338c19a538..cc627a380b0dd 100644
--- a/tests/providers/google/cloud/hooks/test_text_to_speech.py
+++ b/tests/providers/google/cloud/hooks/test_text_to_speech.py
@@ -18,7 +18,6 @@
#
import unittest
-
from unittest.mock import PropertyMock, patch
from airflow.providers.google.cloud.hooks.text_to_speech import CloudTextToSpeechHook
diff --git a/tests/providers/google/cloud/hooks/test_translate.py b/tests/providers/google/cloud/hooks/test_translate.py
index c00eb026742c7..43c559c3b1a35 100644
--- a/tests/providers/google/cloud/hooks/test_translate.py
+++ b/tests/providers/google/cloud/hooks/test_translate.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook
diff --git a/tests/providers/google/cloud/hooks/test_video_intelligence.py b/tests/providers/google/cloud/hooks/test_video_intelligence.py
index f6973f4c1e840..624d0c597b95c 100644
--- a/tests/providers/google/cloud/hooks/test_video_intelligence.py
+++ b/tests/providers/google/cloud/hooks/test_video_intelligence.py
@@ -17,8 +17,8 @@
# under the License.
#
import unittest
-
from unittest import mock
+
from google.cloud.videointelligence_v1 import enums
from airflow.providers.google.cloud.hooks.video_intelligence import CloudVideoIntelligenceHook
diff --git a/tests/providers/google/cloud/hooks/test_vision.py b/tests/providers/google/cloud/hooks/test_vision.py
index f6cfbd882118f..31f004c98419f 100644
--- a/tests/providers/google/cloud/hooks/test_vision.py
+++ b/tests/providers/google/cloud/hooks/test_vision.py
@@ -16,8 +16,8 @@
# specific language governing permissions and limitations
# under the License.
import unittest
-
from unittest import mock
+
from google.cloud.vision import enums
from google.cloud.vision_v1 import ProductSearchClient
from google.cloud.vision_v1.proto.image_annotator_pb2 import (
diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py
index 22650a79b3638..903600b0fa815 100644
--- a/tests/providers/google/cloud/operators/test_automl.py
+++ b/tests/providers/google/cloud/operators/test_automl.py
@@ -18,8 +18,8 @@
#
import copy
import unittest
-
from unittest import mock
+
from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient
from airflow.providers.google.cloud.operators.automl import (
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py
index b3558bcaddd4f..3fa01f88273e1 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -18,9 +18,9 @@
import unittest
from datetime import datetime
+from unittest import mock
from unittest.mock import MagicMock
-from unittest import mock
import pytest
from google.cloud.exceptions import Conflict
from parameterized import parameterized
diff --git a/tests/providers/google/cloud/operators/test_bigquery_dts.py b/tests/providers/google/cloud/operators/test_bigquery_dts.py
index d25e1abd18140..4d423527acc8e 100644
--- a/tests/providers/google/cloud/operators/test_bigquery_dts.py
+++ b/tests/providers/google/cloud/operators/test_bigquery_dts.py
@@ -16,7 +16,6 @@
# specific language governing permissions and limitations
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.operators.bigquery_dts import (
diff --git a/tests/providers/google/cloud/operators/test_cloud_build.py b/tests/providers/google/cloud/operators/test_cloud_build.py
index 924bb45df95bd..970ee0606ef0f 100644
--- a/tests/providers/google/cloud/operators/test_cloud_build.py
+++ b/tests/providers/google/cloud/operators/test_cloud_build.py
@@ -20,9 +20,8 @@
import tempfile
from copy import deepcopy
from datetime import datetime
-from unittest import TestCase
+from unittest import TestCase, mock
-from unittest import mock
from parameterized import parameterized
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/cloud/operators/test_cloud_memorystore.py b/tests/providers/google/cloud/operators/test_cloud_memorystore.py
index d64acf8bcd388..8ef60bd9b62b1 100644
--- a/tests/providers/google/cloud/operators/test_cloud_memorystore.py
+++ b/tests/providers/google/cloud/operators/test_cloud_memorystore.py
@@ -19,9 +19,9 @@
from unittest import TestCase, mock
from google.api_core.retry import Retry
+from google.cloud.memcache_v1beta2.types import cloud_memcache
from google.cloud.redis_v1.gapic.enums import FailoverInstanceRequest
from google.cloud.redis_v1.types import Instance
-from google.cloud.memcache_v1beta2.types import cloud_memcache
from airflow.providers.google.cloud.operators.cloud_memorystore import (
CloudMemorystoreCreateInstanceAndImportOperator,
@@ -32,13 +32,13 @@
CloudMemorystoreGetInstanceOperator,
CloudMemorystoreImportOperator,
CloudMemorystoreListInstancesOperator,
- CloudMemorystoreScaleInstanceOperator,
- CloudMemorystoreUpdateInstanceOperator,
CloudMemorystoreMemcachedCreateInstanceOperator,
CloudMemorystoreMemcachedDeleteInstanceOperator,
CloudMemorystoreMemcachedGetInstanceOperator,
CloudMemorystoreMemcachedListInstancesOperator,
CloudMemorystoreMemcachedUpdateInstanceOperator,
+ CloudMemorystoreScaleInstanceOperator,
+ CloudMemorystoreUpdateInstanceOperator,
)
TEST_GCP_CONN_ID = "test-gcp-conn-id"
diff --git a/tests/providers/google/cloud/operators/test_cloud_sql.py b/tests/providers/google/cloud/operators/test_cloud_sql.py
index 13b36d01a8eb2..2e14687ae53ac 100644
--- a/tests/providers/google/cloud/operators/test_cloud_sql.py
+++ b/tests/providers/google/cloud/operators/test_cloud_sql.py
@@ -20,8 +20,8 @@
import os
import unittest
-
from unittest import mock
+
from parameterized import parameterized
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
index 31b8c53b14ce4..9dc3060b4c432 100644
--- a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
+++ b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
@@ -20,8 +20,8 @@
from copy import deepcopy
from datetime import date, time
from typing import Dict
-
from unittest import mock
+
from botocore.credentials import Credentials
from freezegun import freeze_time
from parameterized import parameterized
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index 48ccdfbeda6c5..02d95e52179ae 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -18,15 +18,14 @@
#
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.operators.dataflow import (
CheckJobRunning,
DataflowCreateJavaJobOperator,
DataflowCreatePythonJobOperator,
- DataflowTemplatedJobStartOperator,
DataflowStartFlexTemplateOperator,
+ DataflowTemplatedJobStartOperator,
)
from airflow.version import version
diff --git a/tests/providers/google/cloud/operators/test_dataproc_system.py b/tests/providers/google/cloud/operators/test_dataproc_system.py
index 48d4fdf8e6bbb..568af28f53fa0 100644
--- a/tests/providers/google/cloud/operators/test_dataproc_system.py
+++ b/tests/providers/google/cloud/operators/test_dataproc_system.py
@@ -17,7 +17,7 @@
# under the License.
import pytest
-from airflow.providers.google.cloud.example_dags.example_dataproc import PYSPARK_MAIN, BUCKET, SPARKR_MAIN
+from airflow.providers.google.cloud.example_dags.example_dataproc import BUCKET, PYSPARK_MAIN, SPARKR_MAIN
from tests.providers.google.cloud.utils.gcp_authenticator import GCP_DATAPROC_KEY
from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context
diff --git a/tests/providers/google/cloud/operators/test_dlp.py b/tests/providers/google/cloud/operators/test_dlp.py
index 6b6ad09489208..7c68102784f7e 100644
--- a/tests/providers/google/cloud/operators/test_dlp.py
+++ b/tests/providers/google/cloud/operators/test_dlp.py
@@ -22,7 +22,6 @@
"""
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.operators.dlp import (
diff --git a/tests/providers/google/cloud/operators/test_dlp_system.py b/tests/providers/google/cloud/operators/test_dlp_system.py
index d27951fd7c2c5..12296ae248e78 100644
--- a/tests/providers/google/cloud/operators/test_dlp_system.py
+++ b/tests/providers/google/cloud/operators/test_dlp_system.py
@@ -23,9 +23,9 @@
"""
import pytest
+from airflow.providers.google.cloud.example_dags.example_dlp import OUTPUT_BUCKET, OUTPUT_FILENAME
from tests.providers.google.cloud.utils.gcp_authenticator import GCP_DLP_KEY
from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context
-from airflow.providers.google.cloud.example_dags.example_dlp import OUTPUT_BUCKET, OUTPUT_FILENAME
@pytest.fixture(scope="class")
diff --git a/tests/providers/google/cloud/operators/test_functions.py b/tests/providers/google/cloud/operators/test_functions.py
index a96d45ae1ee2c..4a8546054eb16 100644
--- a/tests/providers/google/cloud/operators/test_functions.py
+++ b/tests/providers/google/cloud/operators/test_functions.py
@@ -18,8 +18,8 @@
import unittest
from copy import deepcopy
-
from unittest import mock
+
from googleapiclient.errors import HttpError
from parameterized import parameterized
diff --git a/tests/providers/google/cloud/operators/test_gcs.py b/tests/providers/google/cloud/operators/test_gcs.py
index 08d649dd9ca5e..b41cf97710a6f 100644
--- a/tests/providers/google/cloud/operators/test_gcs.py
+++ b/tests/providers/google/cloud/operators/test_gcs.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.operators.gcs import (
diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
index 2bf5933ba0da2..44b9b68daea14 100644
--- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
@@ -18,9 +18,9 @@
import json
import os
import unittest
-
from unittest import mock
from unittest.mock import PropertyMock
+
from parameterized import parameterized
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/cloud/operators/test_life_sciences.py b/tests/providers/google/cloud/operators/test_life_sciences.py
index d7432825657b4..cc08e13ff8ff8 100644
--- a/tests/providers/google/cloud/operators/test_life_sciences.py
+++ b/tests/providers/google/cloud/operators/test_life_sciences.py
@@ -18,7 +18,6 @@
"""Tests for Google Life Sciences Run Pipeline operator """
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.operators.life_sciences import LifeSciencesRunPipelineOperator
diff --git a/tests/providers/google/cloud/operators/test_mlengine_utils.py b/tests/providers/google/cloud/operators/test_mlengine_utils.py
index 8cf5009a99754..6c133154b61c0 100644
--- a/tests/providers/google/cloud/operators/test_mlengine_utils.py
+++ b/tests/providers/google/cloud/operators/test_mlengine_utils.py
@@ -17,9 +17,8 @@
import datetime
import unittest
-from unittest.mock import ANY, patch
-
from unittest import mock
+from unittest.mock import ANY, patch
from airflow.exceptions import AirflowException
from airflow.models.dag import DAG
diff --git a/tests/providers/google/cloud/operators/test_pubsub.py b/tests/providers/google/cloud/operators/test_pubsub.py
index 662baa8e7db4d..0bb1dfd8d2a86 100644
--- a/tests/providers/google/cloud/operators/test_pubsub.py
+++ b/tests/providers/google/cloud/operators/test_pubsub.py
@@ -18,8 +18,8 @@
import unittest
from typing import Any, Dict, List
-
from unittest import mock
+
from google.cloud.pubsub_v1.types import ReceivedMessage
from google.protobuf.json_format import MessageToDict, ParseDict
diff --git a/tests/providers/google/cloud/operators/test_spanner.py b/tests/providers/google/cloud/operators/test_spanner.py
index 1ce86f2044d0e..4daccdb0e3936 100644
--- a/tests/providers/google/cloud/operators/test_spanner.py
+++ b/tests/providers/google/cloud/operators/test_spanner.py
@@ -16,8 +16,8 @@
# specific language governing permissions and limitations
# under the License.
import unittest
-
from unittest import mock
+
from parameterized import parameterized
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/cloud/operators/test_speech_to_text.py b/tests/providers/google/cloud/operators/test_speech_to_text.py
index bd527bdc4cd4f..c9325ebe5924e 100644
--- a/tests/providers/google/cloud/operators/test_speech_to_text.py
+++ b/tests/providers/google/cloud/operators/test_speech_to_text.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest.mock import MagicMock, Mock, patch
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/cloud/operators/test_stackdriver.py b/tests/providers/google/cloud/operators/test_stackdriver.py
index 3cd126329c5c7..28901b4b32416 100644
--- a/tests/providers/google/cloud/operators/test_stackdriver.py
+++ b/tests/providers/google/cloud/operators/test_stackdriver.py
@@ -18,8 +18,8 @@
import json
import unittest
-
from unittest import mock
+
from google.api_core.gapic_v1.method import DEFAULT
from airflow.providers.google.cloud.operators.stackdriver import (
diff --git a/tests/providers/google/cloud/operators/test_tasks.py b/tests/providers/google/cloud/operators/test_tasks.py
index 73eae98a8949e..cac1441f67698 100644
--- a/tests/providers/google/cloud/operators/test_tasks.py
+++ b/tests/providers/google/cloud/operators/test_tasks.py
@@ -17,8 +17,8 @@
# under the License.
import unittest
-
from unittest import mock
+
from google.cloud.tasks_v2.types import Queue, Task
from airflow.providers.google.cloud.operators.tasks import (
diff --git a/tests/providers/google/cloud/operators/test_text_to_speech.py b/tests/providers/google/cloud/operators/test_text_to_speech.py
index dd6a628cfcc23..006c6b5df77a6 100644
--- a/tests/providers/google/cloud/operators/test_text_to_speech.py
+++ b/tests/providers/google/cloud/operators/test_text_to_speech.py
@@ -17,8 +17,8 @@
# under the License.
import unittest
-
from unittest.mock import ANY, Mock, PropertyMock, patch
+
from parameterized import parameterized
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/cloud/operators/test_translate.py b/tests/providers/google/cloud/operators/test_translate.py
index 5a882e5275053..c32bd2f709701 100644
--- a/tests/providers/google/cloud/operators/test_translate.py
+++ b/tests/providers/google/cloud/operators/test_translate.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.operators.translate import CloudTranslateTextOperator
diff --git a/tests/providers/google/cloud/operators/test_translate_speech.py b/tests/providers/google/cloud/operators/test_translate_speech.py
index 015d4ee00edd6..fc1c6376cadcc 100644
--- a/tests/providers/google/cloud/operators/test_translate_speech.py
+++ b/tests/providers/google/cloud/operators/test_translate_speech.py
@@ -17,8 +17,8 @@
# under the License.
import unittest
-
from unittest import mock
+
from google.cloud.speech_v1.proto.cloud_speech_pb2 import (
RecognizeResponse,
SpeechRecognitionAlternative,
diff --git a/tests/providers/google/cloud/operators/test_video_intelligence.py b/tests/providers/google/cloud/operators/test_video_intelligence.py
index f3c8360516588..c6b5f322ae59f 100644
--- a/tests/providers/google/cloud/operators/test_video_intelligence.py
+++ b/tests/providers/google/cloud/operators/test_video_intelligence.py
@@ -17,8 +17,8 @@
# under the License.
import unittest
-
from unittest import mock
+
from google.cloud.videointelligence_v1 import enums
from google.cloud.videointelligence_v1.proto.video_intelligence_pb2 import AnnotateVideoResponse
diff --git a/tests/providers/google/cloud/operators/test_vision.py b/tests/providers/google/cloud/operators/test_vision.py
index 5dcaf0ef007d0..2ca8d9ae70b89 100644
--- a/tests/providers/google/cloud/operators/test_vision.py
+++ b/tests/providers/google/cloud/operators/test_vision.py
@@ -17,8 +17,8 @@
# under the License.
import unittest
-
from unittest import mock
+
from google.api_core.exceptions import AlreadyExists
from google.cloud.vision_v1.types import Product, ProductSet, ReferenceImage
diff --git a/tests/providers/google/cloud/sensors/test_bigquery_dts.py b/tests/providers/google/cloud/sensors/test_bigquery_dts.py
index 4bb466dbb3d1a..92a116ef8e3e4 100644
--- a/tests/providers/google/cloud/sensors/test_bigquery_dts.py
+++ b/tests/providers/google/cloud/sensors/test_bigquery_dts.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.sensors.bigquery_dts import BigQueryDataTransferServiceTransferRunSensor
diff --git a/tests/providers/google/cloud/sensors/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/sensors/test_cloud_storage_transfer_service.py
index e86746986b10d..aa169fdf1dd23 100644
--- a/tests/providers/google/cloud/sensors/test_cloud_storage_transfer_service.py
+++ b/tests/providers/google/cloud/sensors/test_cloud_storage_transfer_service.py
@@ -16,8 +16,8 @@
# specific language governing permissions and limitations
# under the License.
import unittest
-
from unittest import mock
+
from parameterized import parameterized
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import GcpTransferOperationStatus
diff --git a/tests/providers/google/cloud/sensors/test_dataproc.py b/tests/providers/google/cloud/sensors/test_dataproc.py
index f9b40035e8cb5..f2a1d45704cc8 100644
--- a/tests/providers/google/cloud/sensors/test_dataproc.py
+++ b/tests/providers/google/cloud/sensors/test_dataproc.py
@@ -19,9 +19,9 @@
from unittest import mock
from google.cloud.dataproc_v1beta2.types import JobStatus
+
from airflow import AirflowException
from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor
-
from airflow.version import version as airflow_version
AIRFLOW_VERSION = "v" + airflow_version.replace(".", "-").replace("+", "-")
diff --git a/tests/providers/google/cloud/sensors/test_pubsub.py b/tests/providers/google/cloud/sensors/test_pubsub.py
index 46c85e3158e50..5c1ead938fd1a 100644
--- a/tests/providers/google/cloud/sensors/test_pubsub.py
+++ b/tests/providers/google/cloud/sensors/test_pubsub.py
@@ -18,8 +18,8 @@
import unittest
from typing import Any, Dict, List
-
from unittest import mock
+
from google.cloud.pubsub_v1.types import ReceivedMessage
from google.protobuf.json_format import MessageToDict, ParseDict
diff --git a/tests/providers/google/cloud/transfers/test_adls_to_gcs.py b/tests/providers/google/cloud/transfers/test_adls_to_gcs.py
index 94003ffbd1a3d..68649c242c6d1 100644
--- a/tests/providers/google/cloud/transfers/test_adls_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_adls_to_gcs.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.transfers.adls_to_gcs import ADLSToGCSOperator
diff --git a/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs.py b/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs.py
index 11da033e5d944..1a46622869b27 100644
--- a/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs.py
@@ -16,7 +16,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs import AzureFileShareToGCSOperator
diff --git a/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs_system.py
index ebd36652ea55b..76e69cdc67500 100644
--- a/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs_system.py
+++ b/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs_system.py
@@ -19,22 +19,21 @@
import pytest
+from airflow.models import Connection
from airflow.providers.google.cloud.example_dags.example_azure_fileshare_to_gcs import (
+ AZURE_DIRECTORY_NAME,
AZURE_SHARE_NAME,
DEST_GCS_BUCKET,
- AZURE_DIRECTORY_NAME,
)
from airflow.utils.session import create_session
-
-from airflow.models import Connection
from tests.providers.google.cloud.utils.gcp_authenticator import GCP_GCS_KEY
from tests.test_utils.azure_system_helpers import AzureSystemTest, provide_azure_fileshare
from tests.test_utils.db import clear_db_connections
from tests.test_utils.gcp_system_helpers import (
- GoogleSystemTest,
- provide_gcs_bucket,
CLOUD_DAG_FOLDER,
+ GoogleSystemTest,
provide_gcp_context,
+ provide_gcs_bucket,
)
AZURE_LOGIN = os.environ.get('AZURE_LOGIN', 'default_login')
diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py b/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py
index 11f71fbb4d451..162182958655d 100644
--- a/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py
+++ b/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import BigQueryToBigQueryOperator
diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py
index d25dbf1955bd1..2ddac81e76929 100644
--- a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator
diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_mysql.py b/tests/providers/google/cloud/transfers/test_bigquery_to_mysql.py
index 29811ab33ac77..acdf63dd0b103 100644
--- a/tests/providers/google/cloud/transfers/test_bigquery_to_mysql.py
+++ b/tests/providers/google/cloud/transfers/test_bigquery_to_mysql.py
@@ -16,7 +16,6 @@
# specific language governing permissions and limitations
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.transfers.bigquery_to_mysql import BigQueryToMySqlOperator
diff --git a/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py b/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py
index 7016d8709c8d2..4d6481d1b1e72 100644
--- a/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py
@@ -18,7 +18,6 @@
import unittest
from unittest import mock
-
from unittest.mock import call
from airflow.providers.google.cloud.transfers.cassandra_to_gcs import CassandraToGCSOperator
diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
index a602bee4257bc..9b9aabfe4ebcb 100644
--- a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
+++ b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator
diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
index e280a6d9d0188..0b2e1eb0ca2fa 100644
--- a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
@@ -18,7 +18,6 @@
import unittest
from datetime import datetime
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_local.py b/tests/providers/google/cloud/transfers/test_gcs_to_local.py
index b703d63741114..fcd03d6c3dca1 100644
--- a/tests/providers/google/cloud/transfers/test_gcs_to_local.py
+++ b/tests/providers/google/cloud/transfers/test_gcs_to_local.py
@@ -16,7 +16,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator
diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_local_system.py b/tests/providers/google/cloud/transfers/test_gcs_to_local_system.py
index 1b5447b2ec642..55762bab3836f 100644
--- a/tests/providers/google/cloud/transfers/test_gcs_to_local_system.py
+++ b/tests/providers/google/cloud/transfers/test_gcs_to_local_system.py
@@ -21,8 +21,8 @@
from airflow.providers.google.cloud.example_dags.example_gcs_to_local import (
BUCKET,
- PATH_TO_REMOTE_FILE,
PATH_TO_LOCAL_FILE,
+ PATH_TO_REMOTE_FILE,
)
from tests.providers.google.cloud.utils.gcp_authenticator import GCP_GCS_KEY
from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context
diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py b/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py
index 81653447cd93a..3cf3d537392f9 100644
--- a/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py
+++ b/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py
@@ -19,7 +19,6 @@
import os
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/cloud/transfers/test_local_to_gcs.py b/tests/providers/google/cloud/transfers/test_local_to_gcs.py
index 5cf2509f06500..800d9d8a23d35 100644
--- a/tests/providers/google/cloud/transfers/test_local_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_local_to_gcs.py
@@ -21,8 +21,8 @@
import os
import unittest
from glob import glob
-
from unittest import mock
+
import pytest
from airflow.models.dag import DAG
diff --git a/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py b/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py
index f537aef205a6b..8f22ef4e26d88 100644
--- a/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow import PY38
diff --git a/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py b/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py
index f96f830622d0c..d5690c63afa7d 100644
--- a/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py
@@ -19,8 +19,8 @@
import datetime
import decimal
import unittest
-
from unittest import mock
+
from _mysql_exceptions import ProgrammingError
from parameterized import parameterized
diff --git a/tests/providers/google/cloud/transfers/test_mysql_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_mysql_to_gcs_system.py
index 65e2bc7af5330..08aeac929ca8c 100644
--- a/tests/providers/google/cloud/transfers/test_mysql_to_gcs_system.py
+++ b/tests/providers/google/cloud/transfers/test_mysql_to_gcs_system.py
@@ -16,10 +16,10 @@
# specific language governing permissions and limitations
# under the License.
import pytest
-from psycopg2 import ProgrammingError, OperationalError
+from psycopg2 import OperationalError, ProgrammingError
-from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.google.cloud.example_dags.example_mysql_to_gcs import GCS_BUCKET
+from airflow.providers.mysql.hooks.mysql import MySqlHook
from tests.providers.google.cloud.utils.gcp_authenticator import GCP_GCS_KEY
from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context
diff --git a/tests/providers/google/cloud/transfers/test_s3_to_gcs.py b/tests/providers/google/cloud/transfers/test_s3_to_gcs.py
index e59fa34a6cd71..990702fdbdb9a 100644
--- a/tests/providers/google/cloud/transfers/test_s3_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_s3_to_gcs.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator
diff --git a/tests/providers/google/cloud/transfers/test_s3_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_s3_to_gcs_system.py
index 66b4217452f69..6f25f641443fa 100644
--- a/tests/providers/google/cloud/transfers/test_s3_to_gcs_system.py
+++ b/tests/providers/google/cloud/transfers/test_s3_to_gcs_system.py
@@ -16,9 +16,10 @@
# under the License.
import pytest
+
from airflow.providers.google.cloud.example_dags.example_s3_to_gcs import UPLOAD_FILE
from tests.providers.google.cloud.utils.gcp_authenticator import GCP_GCS_KEY
-from tests.test_utils.gcp_system_helpers import GoogleSystemTest, provide_gcp_context, CLOUD_DAG_FOLDER
+from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context
FILENAME = UPLOAD_FILE.split('/')[-1]
diff --git a/tests/providers/google/cloud/transfers/test_salesforce_to_gcs.py b/tests/providers/google/cloud/transfers/test_salesforce_to_gcs.py
index e52e8ddf39ded..5aed4b7535662 100644
--- a/tests/providers/google/cloud/transfers/test_salesforce_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_salesforce_to_gcs.py
@@ -17,7 +17,6 @@
import unittest
from collections import OrderedDict
-
from unittest import mock
from airflow.providers.google.cloud.hooks.gcs import GCSHook
diff --git a/tests/providers/google/cloud/transfers/test_salesforce_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_salesforce_to_gcs_system.py
index ffdf49584956f..d5e45b6682bab 100644
--- a/tests/providers/google/cloud/transfers/test_salesforce_to_gcs_system.py
+++ b/tests/providers/google/cloud/transfers/test_salesforce_to_gcs_system.py
@@ -16,6 +16,7 @@
# under the License.
import os
+
import pytest
from tests.providers.google.cloud.utils.gcp_authenticator import GCP_BIGQUERY_KEY
diff --git a/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py b/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py
index 681b6fe657ec7..03b21d48c085e 100644
--- a/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py
@@ -19,7 +19,6 @@
import os
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
index 889352d6357d4..b4bc5a3263fb9 100644
--- a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
@@ -17,9 +17,9 @@
import json
import unittest
+from unittest import mock
from unittest.mock import Mock
-from unittest import mock
import unicodecsv as csv
from airflow.providers.google.cloud.hooks.gcs import GCSHook
diff --git a/tests/providers/google/cloud/utils/test_credentials_provider.py b/tests/providers/google/cloud/utils/test_credentials_provider.py
index 0aa9770dd2a74..c49872c0c29a9 100644
--- a/tests/providers/google/cloud/utils/test_credentials_provider.py
+++ b/tests/providers/google/cloud/utils/test_credentials_provider.py
@@ -20,9 +20,9 @@
import re
import unittest
from io import StringIO
+from unittest import mock
from uuid import uuid4
-from unittest import mock
from google.auth.environment_vars import CREDENTIALS
from parameterized import parameterized
diff --git a/tests/providers/google/firebase/hooks/test_firestore.py b/tests/providers/google/firebase/hooks/test_firestore.py
index dea7c78691565..ee1117a5875bc 100644
--- a/tests/providers/google/firebase/hooks/test_firestore.py
+++ b/tests/providers/google/firebase/hooks/test_firestore.py
@@ -21,7 +21,6 @@
import unittest
from typing import Optional
from unittest import mock
-
from unittest.mock import PropertyMock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/google/suite/hooks/test_drive.py b/tests/providers/google/suite/hooks/test_drive.py
index ba0e31bdb9cd0..37f224734a8ba 100644
--- a/tests/providers/google/suite/hooks/test_drive.py
+++ b/tests/providers/google/suite/hooks/test_drive.py
@@ -17,7 +17,6 @@
# under the License.
#
import unittest
-
from unittest import mock
from airflow.providers.google.suite.hooks.drive import GoogleDriveHook
diff --git a/tests/providers/google/suite/hooks/test_sheets.py b/tests/providers/google/suite/hooks/test_sheets.py
index adcbb3de28762..56f5f5e93a495 100644
--- a/tests/providers/google/suite/hooks/test_sheets.py
+++ b/tests/providers/google/suite/hooks/test_sheets.py
@@ -21,7 +21,6 @@
"""
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/grpc/hooks/test_grpc.py b/tests/providers/grpc/hooks/test_grpc.py
index 07d432980821b..4ec9267c4670a 100644
--- a/tests/providers/grpc/hooks/test_grpc.py
+++ b/tests/providers/grpc/hooks/test_grpc.py
@@ -17,7 +17,6 @@
import unittest
from io import StringIO
-
from unittest import mock
from airflow.exceptions import AirflowConfigException
diff --git a/tests/providers/grpc/operators/test_grpc.py b/tests/providers/grpc/operators/test_grpc.py
index 5e96249b5bc24..1bbed8b789c0c 100644
--- a/tests/providers/grpc/operators/test_grpc.py
+++ b/tests/providers/grpc/operators/test_grpc.py
@@ -16,7 +16,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.grpc.operators.grpc import GrpcOperator
diff --git a/tests/providers/http/hooks/test_http.py b/tests/providers/http/hooks/test_http.py
index ae666ef4ccc6f..6824d64ae6284 100644
--- a/tests/providers/http/hooks/test_http.py
+++ b/tests/providers/http/hooks/test_http.py
@@ -17,8 +17,8 @@
# under the License.
import json
import unittest
-
from unittest import mock
+
import requests
import requests_mock
import tenacity
diff --git a/tests/providers/http/operators/test_http.py b/tests/providers/http/operators/test_http.py
index d9c7c2ba11f96..75ca9dc91b9bc 100644
--- a/tests/providers/http/operators/test_http.py
+++ b/tests/providers/http/operators/test_http.py
@@ -17,8 +17,8 @@
# under the License.
import unittest
-
from unittest import mock
+
import requests_mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/http/sensors/test_http.py b/tests/providers/http/sensors/test_http.py
index 046390fac846b..04b6c4946c758 100644
--- a/tests/providers/http/sensors/test_http.py
+++ b/tests/providers/http/sensors/test_http.py
@@ -16,9 +16,9 @@
# specific language governing permissions and limitations
# under the License.
import unittest
+from unittest import mock
from unittest.mock import patch
-from unittest import mock
import requests
from airflow.exceptions import AirflowException, AirflowSensorTimeout
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_batch.py b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
index 6d9188fae4869..9549e5795a2e5 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_batch.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
@@ -18,8 +18,8 @@
#
import json
import unittest
-
from unittest import mock
+
from azure.batch import BatchServiceClient, models as batch_models
from airflow.models import Connection
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
index 1703f4092bfc5..701abc5a34143 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
@@ -22,8 +22,8 @@
import logging
import unittest
import uuid
-
from unittest import mock
+
from azure.cosmos.cosmos_client import CosmosClient
from airflow.exceptions import AirflowException
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
index 90a608e1b7464..8fa0feea54146 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
@@ -18,7 +18,6 @@
#
import json
import unittest
-
from unittest import mock
from airflow.models import Connection
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
index d298ccd8d19e9..a9a6298ddf641 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
@@ -27,9 +27,9 @@
import json
import unittest
-
from unittest import mock
-from azure.storage.file import File, Directory
+
+from azure.storage.file import Directory, File
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.azure_fileshare import AzureFileShareHook
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py
index d7cd0dc4bdaca..12a784790c73e 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -21,7 +21,6 @@
import json
import unittest
from collections import namedtuple
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/microsoft/azure/operators/test_adls_list.py b/tests/providers/microsoft/azure/operators/test_adls_list.py
index c5be2b5ded627..9b0a5c2f7ed09 100644
--- a/tests/providers/microsoft/azure/operators/test_adls_list.py
+++ b/tests/providers/microsoft/azure/operators/test_adls_list.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.microsoft.azure.operators.adls_list import AzureDataLakeStorageListOperator
diff --git a/tests/providers/microsoft/azure/operators/test_azure_batch.py b/tests/providers/microsoft/azure/operators/test_azure_batch.py
index 18702f568db0f..48e5a7a812383 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_batch.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_batch.py
@@ -18,7 +18,6 @@
#
import json
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
index 22c495249d86c..90f9ece7c3435 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
@@ -19,9 +19,9 @@
import unittest
from collections import namedtuple
+from unittest import mock
from unittest.mock import MagicMock
-from unittest import mock
from azure.mgmt.containerinstance.models import ContainerState, Event
from airflow.exceptions import AirflowException
diff --git a/tests/providers/microsoft/azure/operators/test_azure_cosmos.py b/tests/providers/microsoft/azure/operators/test_azure_cosmos.py
index ac1299012c1c9..26144406cb892 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_cosmos.py
@@ -21,7 +21,6 @@
import json
import unittest
import uuid
-
from unittest import mock
from airflow.models import Connection
diff --git a/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py b/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py
index 8310f72bedf67..04df34cd89ca6 100644
--- a/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py
+++ b/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py
@@ -19,7 +19,6 @@
import datetime
import unittest
-
from unittest import mock
from airflow.models.dag import DAG
diff --git a/tests/providers/microsoft/azure/sensors/test_wasb.py b/tests/providers/microsoft/azure/sensors/test_wasb.py
index 8b1c6a1d44509..5aaec19e5e203 100644
--- a/tests/providers/microsoft/azure/sensors/test_wasb.py
+++ b/tests/providers/microsoft/azure/sensors/test_wasb.py
@@ -19,7 +19,6 @@
import datetime
import unittest
-
from unittest import mock
from airflow.models.dag import DAG
diff --git a/tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py b/tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py
index cc1d7b1d51606..b9517a894c727 100644
--- a/tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py
+++ b/tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py
@@ -15,12 +15,9 @@
# specific language governing permissions and limitations
# under the License.
import unittest
-
from unittest import mock
-from airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs import (
- AzureBlobStorageToGCSOperator,
-)
+from airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs import AzureBlobStorageToGCSOperator
WASB_CONN_ID = "wasb_default"
GCP_CONN_ID = "google_cloud_default"
diff --git a/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py b/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py
index d7b708ba8b6a7..5a4f14c729259 100644
--- a/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py
+++ b/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py
@@ -19,7 +19,6 @@
import datetime
import unittest
-
from unittest import mock
from airflow.models.dag import DAG
diff --git a/tests/providers/microsoft/azure/transfers/test_local_to_adls.py b/tests/providers/microsoft/azure/transfers/test_local_to_adls.py
index d79188dd6962f..0bcc371393a35 100644
--- a/tests/providers/microsoft/azure/transfers/test_local_to_adls.py
+++ b/tests/providers/microsoft/azure/transfers/test_local_to_adls.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py b/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py
index 49f591f4278e7..8d16878b2d45d 100644
--- a/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py
+++ b/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py
@@ -19,9 +19,9 @@
import os
import unittest
from tempfile import TemporaryDirectory
-
from unittest import mock
from unittest.mock import MagicMock
+
import unicodecsv as csv
from airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake import (
diff --git a/tests/providers/microsoft/mssql/hooks/test_mssql.py b/tests/providers/microsoft/mssql/hooks/test_mssql.py
index 0065f9438ef41..91aed5d76fd44 100644
--- a/tests/providers/microsoft/mssql/hooks/test_mssql.py
+++ b/tests/providers/microsoft/mssql/hooks/test_mssql.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow import PY38
diff --git a/tests/providers/microsoft/mssql/operators/test_mssql.py b/tests/providers/microsoft/mssql/operators/test_mssql.py
index 4e4c6c272d1f8..1304b123985db 100644
--- a/tests/providers/microsoft/mssql/operators/test_mssql.py
+++ b/tests/providers/microsoft/mssql/operators/test_mssql.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow import PY38
diff --git a/tests/providers/openfaas/hooks/test_openfaas.py b/tests/providers/openfaas/hooks/test_openfaas.py
index 10332e36f9051..5c68def0c5c1c 100644
--- a/tests/providers/openfaas/hooks/test_openfaas.py
+++ b/tests/providers/openfaas/hooks/test_openfaas.py
@@ -18,8 +18,8 @@
#
import unittest
-
from unittest import mock
+
import requests_mock
from airflow.exceptions import AirflowException
diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py
index 2176306895d8d..d27ca44d68704 100644
--- a/tests/providers/oracle/hooks/test_oracle.py
+++ b/tests/providers/oracle/hooks/test_oracle.py
@@ -19,8 +19,8 @@
import json
import unittest
from datetime import datetime
-
from unittest import mock
+
import numpy
from airflow.models import Connection
diff --git a/tests/providers/oracle/operators/test_oracle.py b/tests/providers/oracle/operators/test_oracle.py
index bbc3fc9221c98..2c4d984647e8c 100644
--- a/tests/providers/oracle/operators/test_oracle.py
+++ b/tests/providers/oracle/operators/test_oracle.py
@@ -16,7 +16,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.providers.oracle.hooks.oracle import OracleHook
diff --git a/tests/providers/oracle/transfers/test_oracle_to_oracle.py b/tests/providers/oracle/transfers/test_oracle_to_oracle.py
index f7536afe6d1a5..a036a3e990df0 100644
--- a/tests/providers/oracle/transfers/test_oracle_to_oracle.py
+++ b/tests/providers/oracle/transfers/test_oracle_to_oracle.py
@@ -18,7 +18,6 @@
import unittest
from unittest import mock
-
from unittest.mock import MagicMock
from airflow.providers.oracle.transfers.oracle_to_oracle import OracleToOracleOperator
diff --git a/tests/providers/pagerduty/hooks/test_pagerduty.py b/tests/providers/pagerduty/hooks/test_pagerduty.py
index 50a115437ff8c..95e9719cd10bb 100644
--- a/tests/providers/pagerduty/hooks/test_pagerduty.py
+++ b/tests/providers/pagerduty/hooks/test_pagerduty.py
@@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
import unittest
+
import requests_mock
from airflow.models import Connection
diff --git a/tests/providers/plexus/hooks/test_plexus.py b/tests/providers/plexus/hooks/test_plexus.py
index f90d35b90da27..3419777af4119 100644
--- a/tests/providers/plexus/hooks/test_plexus.py
+++ b/tests/providers/plexus/hooks/test_plexus.py
@@ -21,6 +21,7 @@
import arrow
import pytest
from requests.exceptions import Timeout
+
from airflow.exceptions import AirflowException
from airflow.providers.plexus.hooks.plexus import PlexusHook
diff --git a/tests/providers/plexus/operators/test_job.py b/tests/providers/plexus/operators/test_job.py
index 44873ff4f2521..7c902332fc5b9 100644
--- a/tests/providers/plexus/operators/test_job.py
+++ b/tests/providers/plexus/operators/test_job.py
@@ -17,8 +17,10 @@
from unittest import mock
from unittest.mock import Mock
+
import pytest
from requests.exceptions import Timeout
+
from airflow.exceptions import AirflowException
from airflow.providers.plexus.operators.job import PlexusJobOperator
diff --git a/tests/providers/qubole/operators/test_qubole_check.py b/tests/providers/qubole/operators/test_qubole_check.py
index e3f09a56cc89f..227f0490a5d35 100644
--- a/tests/providers/qubole/operators/test_qubole_check.py
+++ b/tests/providers/qubole/operators/test_qubole_check.py
@@ -18,8 +18,8 @@
#
import unittest
from datetime import datetime
-
from unittest import mock
+
from qds_sdk.commands import HiveCommand
from airflow.exceptions import AirflowException
diff --git a/tests/providers/samba/hooks/test_samba.py b/tests/providers/samba/hooks/test_samba.py
index e5f294d93a0c7..86d9c764207d8 100644
--- a/tests/providers/samba/hooks/test_samba.py
+++ b/tests/providers/samba/hooks/test_samba.py
@@ -17,9 +17,9 @@
# under the License.
import unittest
+from unittest import mock
from unittest.mock import call
-from unittest import mock
import smbclient
from airflow.exceptions import AirflowException
diff --git a/tests/providers/sendgrid/utils/test_emailer.py b/tests/providers/sendgrid/utils/test_emailer.py
index df872013447ba..bb1a5f2cf5f20 100644
--- a/tests/providers/sendgrid/utils/test_emailer.py
+++ b/tests/providers/sendgrid/utils/test_emailer.py
@@ -21,7 +21,6 @@
import os
import tempfile
import unittest
-
from unittest import mock
from airflow.providers.sendgrid.utils.emailer import send_email
diff --git a/tests/providers/singularity/operators/test_singularity.py b/tests/providers/singularity/operators/test_singularity.py
index 324ca35fd9df1..71ab5ace64db9 100644
--- a/tests/providers/singularity/operators/test_singularity.py
+++ b/tests/providers/singularity/operators/test_singularity.py
@@ -17,8 +17,8 @@
# under the License.
import unittest
-
from unittest import mock
+
from parameterized import parameterized
from spython.instance import Instance
diff --git a/tests/providers/slack/hooks/test_slack.py b/tests/providers/slack/hooks/test_slack.py
index 4ca0060f8fbaf..6ebec3a38386d 100644
--- a/tests/providers/slack/hooks/test_slack.py
+++ b/tests/providers/slack/hooks/test_slack.py
@@ -17,8 +17,8 @@
# under the License.
import unittest
-
from unittest import mock
+
from slack.errors import SlackApiError
from airflow.exceptions import AirflowException
diff --git a/tests/providers/slack/operators/test_slack.py b/tests/providers/slack/operators/test_slack.py
index 39e31a3c9bc99..e505282c033b3 100644
--- a/tests/providers/slack/operators/test_slack.py
+++ b/tests/providers/slack/operators/test_slack.py
@@ -18,7 +18,6 @@
import json
import unittest
-
from unittest import mock
from airflow.providers.slack.operators.slack import SlackAPIFileOperator, SlackAPIPostOperator
diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py
index 9d0d9ff6fa9c4..3bec67174ce71 100644
--- a/tests/providers/snowflake/operators/test_snowflake.py
+++ b/tests/providers/snowflake/operators/test_snowflake.py
@@ -17,7 +17,6 @@
# under the License.
import unittest
-
from unittest import mock
from airflow.models.dag import DAG
diff --git a/tests/providers/ssh/hooks/test_ssh.py b/tests/providers/ssh/hooks/test_ssh.py
index be52138daedab..d09a6678edc70 100644
--- a/tests/providers/ssh/hooks/test_ssh.py
+++ b/tests/providers/ssh/hooks/test_ssh.py
@@ -16,13 +16,13 @@
# specific language governing permissions and limitations
# under the License.
import json
+import random
+import string
import unittest
from io import StringIO
-
from typing import Optional
-import random
-import string
from unittest import mock
+
import paramiko
from airflow.models import Connection
diff --git a/tests/secrets/test_local_filesystem.py b/tests/secrets/test_local_filesystem.py
index bc6d0dea0a079..5ee4813f8285b 100644
--- a/tests/secrets/test_local_filesystem.py
+++ b/tests/secrets/test_local_filesystem.py
@@ -107,10 +107,13 @@ def test_missing_file(self, mock_exists):
@parameterized.expand(
(
("KEY: AAA", {"KEY": "AAA"}),
- ("""
+ (
+ """
KEY_A: AAA
KEY_B: BBB
- """, {"KEY_A": "AAA", "KEY_B": "BBB"}),
+ """,
+ {"KEY_A": "AAA", "KEY_B": "BBB"},
+ ),
)
)
def test_yaml_file_should_load_variables(self, file_content, expected_variables):
@@ -141,8 +144,7 @@ def test_env_file_should_load_connection(self, file_content, expected_connection
with mock_local_file(file_content):
connection_by_conn_id = local_filesystem.load_connections_dict("a.env")
connection_uris_by_conn_id = {
- conn_id: connection.get_uri()
- for conn_id, connection in connection_by_conn_id.items()
+ conn_id: connection.get_uri() for conn_id, connection in connection_by_conn_id.items()
}
self.assertEqual(expected_connection_uris, connection_uris_by_conn_id)
@@ -170,8 +172,7 @@ def test_json_file_should_load_connection(self, file_content, expected_connectio
with mock_local_file(json.dumps(file_content)):
connections_by_conn_id = local_filesystem.load_connections_dict("a.json")
connection_uris_by_conn_id = {
- conn_id: connection.get_uri()
- for conn_id, connection in connections_by_conn_id.items()
+ conn_id: connection.get_uri() for conn_id, connection in connections_by_conn_id.items()
}
self.assertEqual(expected_connection_uris, connection_uris_by_conn_id)
@@ -203,7 +204,8 @@ def test_missing_file(self, mock_exists):
@parameterized.expand(
(
("""CONN_A: 'mysql://host_a'""", {"CONN_A": "mysql://host_a"}),
- ("""
+ (
+ """
conn_a: mysql://hosta
conn_b:
conn_type: scheme
@@ -216,18 +218,22 @@ def test_missing_file(self, mock_exists):
extra__google_cloud_platform__keyfile_dict:
a: b
extra__google_cloud_platform__keyfile_path: asaa""",
- {"conn_a": "mysql://hosta",
- "conn_b": ''.join("""scheme://Login:None@host:1234/lschema?
+ {
+ "conn_a": "mysql://hosta",
+ "conn_b": ''.join(
+ """scheme://Login:None@host:1234/lschema?
extra__google_cloud_platform__keyfile_dict=%7B%27a%27%3A+%27b%27%7D
- &extra__google_cloud_platform__keyfile_path=asaa""".split())}),
+ &extra__google_cloud_platform__keyfile_path=asaa""".split()
+ ),
+ },
+ ),
)
)
def test_yaml_file_should_load_connection(self, file_content, expected_connection_uris):
with mock_local_file(file_content):
connections_by_conn_id = local_filesystem.load_connections_dict("a.yaml")
connection_uris_by_conn_id = {
- conn_id: connection.get_uri()
- for conn_id, connection in connections_by_conn_id.items()
+ conn_id: connection.get_uri() for conn_id, connection in connections_by_conn_id.items()
}
self.assertEqual(expected_connection_uris, connection_uris_by_conn_id)
@@ -295,7 +301,8 @@ def test_yaml_file_should_load_connection_extras(self, file_content, expected_ex
@parameterized.expand(
(
- ("""conn_c:
+ (
+ """conn_c:
conn_type: scheme
host: host
schema: lschema
@@ -307,7 +314,9 @@ def test_yaml_file_should_load_connection_extras(self, file_content, expected_ex
extra_dejson:
aws_conn_id: bbb
region_name: ccc
- """, "The extra and extra_dejson parameters are mutually exclusive."),
+ """,
+ "The extra and extra_dejson parameters are mutually exclusive.",
+ ),
)
)
def test_yaml_invalid_extra(self, file_content, expected_message):
@@ -316,9 +325,7 @@ def test_yaml_invalid_extra(self, file_content, expected_message):
local_filesystem.load_connections_dict("a.yaml")
@parameterized.expand(
- (
- "CONN_ID=mysql://host_1/\nCONN_ID=mysql://host_2/",
- ),
+ ("CONN_ID=mysql://host_1/\nCONN_ID=mysql://host_2/",),
)
def test_ensure_unique_connection_env(self, file_content):
with mock_local_file(file_content):
@@ -327,12 +334,8 @@ def test_ensure_unique_connection_env(self, file_content):
@parameterized.expand(
(
- (
- {"CONN_ID": ["mysql://host_1", "mysql://host_2"]},
- ),
- (
- {"CONN_ID": [{"uri": "mysql://host_1"}, {"uri": "mysql://host_2"}]},
- ),
+ ({"CONN_ID": ["mysql://host_1", "mysql://host_2"]},),
+ ({"CONN_ID": [{"uri": "mysql://host_1"}, {"uri": "mysql://host_2"}]},),
)
)
def test_ensure_unique_connection_json(self, file_content):
@@ -342,10 +345,12 @@ def test_ensure_unique_connection_json(self, file_content):
@parameterized.expand(
(
- ("""
+ (
+ """
conn_a:
- mysql://hosta
- - mysql://hostb"""),
+ - mysql://hostb"""
+ ),
),
)
def test_ensure_unique_connection_yaml(self, file_content):
diff --git a/tests/secrets/test_secrets.py b/tests/secrets/test_secrets.py
index 53a11a14b6c90..ed13adec61dac 100644
--- a/tests/secrets/test_secrets.py
+++ b/tests/secrets/test_secrets.py
@@ -42,11 +42,15 @@ def test_get_connections_first_try(self, mock_env_get, mock_meta_get):
mock_env_get.assert_called_once_with(conn_id="fake_conn_id")
mock_meta_get.not_called()
- @conf_vars({
- ("secrets", "backend"):
- "airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend",
- ("secrets", "backend_kwargs"): '{"connections_prefix": "/airflow", "profile_name": null}',
- })
+ @conf_vars(
+ {
+ (
+ "secrets",
+ "backend",
+ ): "airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend",
+ ("secrets", "backend_kwargs"): '{"connections_prefix": "/airflow", "profile_name": null}',
+ }
+ )
def test_initialize_secrets_backends(self):
backends = initialize_secrets_backends()
backend_classes = [backend.__class__.__name__ for backend in backends]
@@ -54,30 +58,44 @@ def test_initialize_secrets_backends(self):
self.assertEqual(3, len(backends))
self.assertIn('SystemsManagerParameterStoreBackend', backend_classes)
- @conf_vars({
- ("secrets", "backend"):
- "airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend",
- ("secrets", "backend_kwargs"): '{"use_ssl": false}',
- })
+ @conf_vars(
+ {
+ (
+ "secrets",
+ "backend",
+ ): "airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend",
+ ("secrets", "backend_kwargs"): '{"use_ssl": false}',
+ }
+ )
def test_backends_kwargs(self):
backends = initialize_secrets_backends()
systems_manager = [
- backend for backend in backends
+ backend
+ for backend in backends
if backend.__class__.__name__ == 'SystemsManagerParameterStoreBackend'
][0]
self.assertEqual(systems_manager.kwargs, {'use_ssl': False})
- @conf_vars({
- ("secrets", "backend"):
- "airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend",
- ("secrets", "backend_kwargs"): '{"connections_prefix": "/airflow", "profile_name": null}',
- })
- @mock.patch.dict('os.environ', {
- 'AIRFLOW_CONN_TEST_MYSQL': 'mysql://airflow:airflow@host:5432/airflow',
- })
- @mock.patch("airflow.providers.amazon.aws.secrets.systems_manager."
- "SystemsManagerParameterStoreBackend.get_conn_uri")
+ @conf_vars(
+ {
+ (
+ "secrets",
+ "backend",
+ ): "airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend",
+ ("secrets", "backend_kwargs"): '{"connections_prefix": "/airflow", "profile_name": null}',
+ }
+ )
+ @mock.patch.dict(
+ 'os.environ',
+ {
+ 'AIRFLOW_CONN_TEST_MYSQL': 'mysql://airflow:airflow@host:5432/airflow',
+ },
+ )
+ @mock.patch(
+ "airflow.providers.amazon.aws.secrets.systems_manager."
+ "SystemsManagerParameterStoreBackend.get_conn_uri"
+ )
def test_backend_fallback_to_env_var(self, mock_get_uri):
mock_get_uri.return_value = None
@@ -94,7 +112,6 @@ def test_backend_fallback_to_env_var(self, mock_get_uri):
class TestVariableFromSecrets(unittest.TestCase):
-
def setUp(self) -> None:
clear_db_variables()
diff --git a/tests/secrets/test_secrets_backends.py b/tests/secrets/test_secrets_backends.py
index ced829a96ce23..5318b8b1212d9 100644
--- a/tests/secrets/test_secrets_backends.py
+++ b/tests/secrets/test_secrets_backends.py
@@ -36,14 +36,11 @@ def __init__(self, conn_id, variation: str):
self.conn_id = conn_id
self.var_name = "AIRFLOW_CONN_" + self.conn_id.upper()
self.host = f"host_{variation}.com"
- self.conn_uri = (
- "mysql://user:pw@" + self.host + "/schema?extra1=val%2B1&extra2=val%2B2"
- )
+ self.conn_uri = "mysql://user:pw@" + self.host + "/schema?extra1=val%2B1&extra2=val%2B2"
self.conn = Connection(conn_id=self.conn_id, uri=self.conn_uri)
class TestBaseSecretsBackend(unittest.TestCase):
-
def setUp(self) -> None:
clear_db_variables()
@@ -51,10 +48,12 @@ def tearDown(self) -> None:
clear_db_connections()
clear_db_variables()
- @parameterized.expand([
- ('default', {"path_prefix": "PREFIX", "secret_id": "ID"}, "PREFIX/ID"),
- ('with_sep', {"path_prefix": "PREFIX", "secret_id": "ID", "sep": "-"}, "PREFIX-ID")
- ])
+ @parameterized.expand(
+ [
+ ('default', {"path_prefix": "PREFIX", "secret_id": "ID"}, "PREFIX/ID"),
+ ('with_sep', {"path_prefix": "PREFIX", "secret_id": "ID", "sep": "-"}, "PREFIX-ID"),
+ ]
+ )
def test_build_path(self, _, kwargs, output):
build_path = BaseSecretsBackend.build_path
self.assertEqual(build_path(**kwargs), output)
@@ -78,14 +77,15 @@ def test_connection_metastore_secrets_backend(self):
metastore_backend = MetastoreBackend()
conn_list = metastore_backend.get_connections("sample_2")
host_list = {x.host for x in conn_list}
- self.assertEqual(
- {sample_conn_2.host.lower()}, set(host_list)
- )
+ self.assertEqual({sample_conn_2.host.lower()}, set(host_list))
- @mock.patch.dict('os.environ', {
- 'AIRFLOW_VAR_HELLO': 'World',
- 'AIRFLOW_VAR_EMPTY_STR': '',
- })
+ @mock.patch.dict(
+ 'os.environ',
+ {
+ 'AIRFLOW_VAR_HELLO': 'World',
+ 'AIRFLOW_VAR_EMPTY_STR': '',
+ },
+ )
def test_variable_env_secrets_backend(self):
env_secrets_backend = EnvironmentVariablesBackend()
variable_value = env_secrets_backend.get_variable(key="hello")
diff --git a/tests/security/test_kerberos.py b/tests/security/test_kerberos.py
index 93e880990d90a..3f08540f38491 100644
--- a/tests/security/test_kerberos.py
+++ b/tests/security/test_kerberos.py
@@ -30,16 +30,18 @@
@unittest.skipIf(KRB5_KTNAME is None, 'Skipping Kerberos API tests due to missing KRB5_KTNAME')
class TestKerberos(unittest.TestCase):
def setUp(self):
- self.args = Namespace(keytab=KRB5_KTNAME, principal=None, pid=None,
- daemon=None, stdout=None, stderr=None, log_file=None)
+ self.args = Namespace(
+ keytab=KRB5_KTNAME, principal=None, pid=None, daemon=None, stdout=None, stderr=None, log_file=None
+ )
@conf_vars({('kerberos', 'keytab'): KRB5_KTNAME})
def test_renew_from_kt(self):
"""
We expect no result, but a successful run. No more TypeError
"""
- self.assertIsNone(renew_from_kt(principal=self.args.principal, # pylint: disable=no-member
- keytab=self.args.keytab))
+ self.assertIsNone(
+ renew_from_kt(principal=self.args.principal, keytab=self.args.keytab) # pylint: disable=no-member
+ )
@conf_vars({('kerberos', 'keytab'): ''})
def test_args_from_cli(self):
@@ -49,13 +51,14 @@ def test_args_from_cli(self):
self.args.keytab = "test_keytab"
with self.assertRaises(SystemExit) as err:
- renew_from_kt(principal=self.args.principal, # pylint: disable=no-member
- keytab=self.args.keytab)
+ renew_from_kt(principal=self.args.principal, keytab=self.args.keytab) # pylint: disable=no-member
with self.assertLogs(kerberos.log) as log:
self.assertIn(
'kinit: krb5_init_creds_set_keytab: Failed to find '
'airflow@LUPUS.GRIDDYNAMICS.NET in keytab FILE:{} '
- '(unknown enctype)'.format(self.args.keytab), log.output)
+ '(unknown enctype)'.format(self.args.keytab),
+ log.output,
+ )
self.assertEqual(err.exception.code, 1)
diff --git a/tests/sensors/test_base_sensor.py b/tests/sensors/test_base_sensor.py
index cc0bbef359524..be3ee694f9876 100644
--- a/tests/sensors/test_base_sensor.py
+++ b/tests/sensors/test_base_sensor.py
@@ -59,10 +59,7 @@ def clean_db():
db.clear_db_xcom()
def setUp(self):
- args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- }
+ args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
self.dag = DAG(TEST_DAG_ID, default_args=args)
self.clean_db()
@@ -74,7 +71,7 @@ def _make_dag_run(self):
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
def _make_sensor(self, return_value, task_id=SENSOR_OP, **kwargs):
@@ -86,17 +83,9 @@ def _make_sensor(self, return_value, task_id=SENSOR_OP, **kwargs):
if timeout not in kwargs:
kwargs[timeout] = 0
- sensor = DummySensor(
- task_id=task_id,
- return_value=return_value,
- dag=self.dag,
- **kwargs
- )
+ sensor = DummySensor(task_id=task_id, return_value=return_value, dag=self.dag, **kwargs)
- dummy_op = DummyOperator(
- task_id=DUMMY_OP,
- dag=self.dag
- )
+ dummy_op = DummyOperator(task_id=DUMMY_OP, dag=self.dag)
dummy_op.set_upstream(sensor)
return sensor
@@ -146,10 +135,8 @@ def test_soft_fail(self):
def test_soft_fail_with_retries(self):
sensor = self._make_sensor(
- return_value=False,
- soft_fail=True,
- retries=1,
- retry_delay=timedelta(milliseconds=1))
+ return_value=False, soft_fail=True, retries=1, retry_delay=timedelta(milliseconds=1)
+ )
dr = self._make_dag_run()
# first run fails and task instance is marked up to retry
@@ -175,11 +162,7 @@ def test_soft_fail_with_retries(self):
self.assertEqual(ti.state, State.NONE)
def test_ok_with_reschedule(self):
- sensor = self._make_sensor(
- return_value=None,
- poke_interval=10,
- timeout=25,
- mode='reschedule')
+ sensor = self._make_sensor(return_value=None, poke_interval=10, timeout=25, mode='reschedule')
sensor.poke = Mock(side_effect=[False, False, True])
dr = self._make_dag_run()
@@ -199,8 +182,9 @@ def test_ok_with_reschedule(self):
task_reschedules = TaskReschedule.find_for_task_instance(ti)
self.assertEqual(len(task_reschedules), 1)
self.assertEqual(task_reschedules[0].start_date, date1)
- self.assertEqual(task_reschedules[0].reschedule_date,
- date1 + timedelta(seconds=sensor.poke_interval))
+ self.assertEqual(
+ task_reschedules[0].reschedule_date, date1 + timedelta(seconds=sensor.poke_interval)
+ )
if ti.task_id == DUMMY_OP:
self.assertEqual(ti.state, State.NONE)
@@ -220,8 +204,9 @@ def test_ok_with_reschedule(self):
task_reschedules = TaskReschedule.find_for_task_instance(ti)
self.assertEqual(len(task_reschedules), 2)
self.assertEqual(task_reschedules[1].start_date, date2)
- self.assertEqual(task_reschedules[1].reschedule_date,
- date2 + timedelta(seconds=sensor.poke_interval))
+ self.assertEqual(
+ task_reschedules[1].reschedule_date, date2 + timedelta(seconds=sensor.poke_interval)
+ )
if ti.task_id == DUMMY_OP:
self.assertEqual(ti.state, State.NONE)
@@ -240,11 +225,7 @@ def test_ok_with_reschedule(self):
self.assertEqual(ti.state, State.NONE)
def test_fail_with_reschedule(self):
- sensor = self._make_sensor(
- return_value=False,
- poke_interval=10,
- timeout=5,
- mode='reschedule')
+ sensor = self._make_sensor(return_value=False, poke_interval=10, timeout=5, mode='reschedule')
dr = self._make_dag_run()
# first poke returns False and task is re-scheduled
@@ -274,11 +255,8 @@ def test_fail_with_reschedule(self):
def test_soft_fail_with_reschedule(self):
sensor = self._make_sensor(
- return_value=False,
- poke_interval=10,
- timeout=5,
- soft_fail=True,
- mode='reschedule')
+ return_value=False, poke_interval=10, timeout=5, soft_fail=True, mode='reschedule'
+ )
dr = self._make_dag_run()
# first poke returns False and task is re-scheduled
@@ -312,7 +290,8 @@ def test_ok_with_reschedule_and_retry(self):
timeout=5,
retries=1,
retry_delay=timedelta(seconds=10),
- mode='reschedule')
+ mode='reschedule',
+ )
sensor.poke = Mock(side_effect=[False, False, False, True])
dr = self._make_dag_run()
@@ -329,8 +308,9 @@ def test_ok_with_reschedule_and_retry(self):
task_reschedules = TaskReschedule.find_for_task_instance(ti)
self.assertEqual(len(task_reschedules), 1)
self.assertEqual(task_reschedules[0].start_date, date1)
- self.assertEqual(task_reschedules[0].reschedule_date,
- date1 + timedelta(seconds=sensor.poke_interval))
+ self.assertEqual(
+ task_reschedules[0].reschedule_date, date1 + timedelta(seconds=sensor.poke_interval)
+ )
self.assertEqual(task_reschedules[0].try_number, 1)
if ti.task_id == DUMMY_OP:
self.assertEqual(ti.state, State.NONE)
@@ -361,8 +341,9 @@ def test_ok_with_reschedule_and_retry(self):
task_reschedules = TaskReschedule.find_for_task_instance(ti)
self.assertEqual(len(task_reschedules), 1)
self.assertEqual(task_reschedules[0].start_date, date3)
- self.assertEqual(task_reschedules[0].reschedule_date,
- date3 + timedelta(seconds=sensor.poke_interval))
+ self.assertEqual(
+ task_reschedules[0].reschedule_date, date3 + timedelta(seconds=sensor.poke_interval)
+ )
self.assertEqual(task_reschedules[0].try_number, 2)
if ti.task_id == DUMMY_OP:
self.assertEqual(ti.state, State.NONE)
@@ -391,22 +372,20 @@ def test_should_not_include_ready_to_reschedule_dep_in_poke_mode(self):
def test_invalid_mode(self):
with self.assertRaises(AirflowException):
- self._make_sensor(
- return_value=True,
- mode='foo')
+ self._make_sensor(return_value=True, mode='foo')
def test_ok_with_custom_reschedule_exception(self):
- sensor = self._make_sensor(
- return_value=None,
- mode='reschedule')
+ sensor = self._make_sensor(return_value=None, mode='reschedule')
date1 = timezone.utcnow()
date2 = date1 + timedelta(seconds=60)
date3 = date1 + timedelta(seconds=120)
- sensor.poke = Mock(side_effect=[
- AirflowRescheduleException(date2),
- AirflowRescheduleException(date3),
- True,
- ])
+ sensor.poke = Mock(
+ side_effect=[
+ AirflowRescheduleException(date2),
+ AirflowRescheduleException(date3),
+ True,
+ ]
+ )
dr = self._make_dag_run()
# first poke returns False and task is re-scheduled
@@ -455,11 +434,7 @@ def test_ok_with_custom_reschedule_exception(self):
self.assertEqual(ti.state, State.NONE)
def test_reschedule_with_test_mode(self):
- sensor = self._make_sensor(
- return_value=None,
- poke_interval=10,
- timeout=25,
- mode='reschedule')
+ sensor = self._make_sensor(return_value=None, poke_interval=10, timeout=25, mode='reschedule')
sensor.poke = Mock(side_effect=[False])
dr = self._make_dag_run()
@@ -467,9 +442,7 @@ def test_reschedule_with_test_mode(self):
date1 = timezone.utcnow()
with freeze_time(date1):
for date in self.dag.date_range(DEFAULT_DATE, end_date=DEFAULT_DATE):
- TaskInstance(sensor, date).run(
- ignore_ti_state=True,
- test_mode=True)
+ TaskInstance(sensor, date).run(ignore_ti_state=True, test_mode=True)
tis = dr.get_task_instances()
self.assertEqual(len(tis), 2)
for ti in tis:
@@ -491,20 +464,20 @@ def test_sensor_with_invalid_poke_interval(self):
task_id='test_sensor_task_1',
return_value=None,
poke_interval=negative_poke_interval,
- timeout=25)
+ timeout=25,
+ )
with self.assertRaises(AirflowException):
self._make_sensor(
task_id='test_sensor_task_2',
return_value=None,
poke_interval=non_number_poke_interval,
- timeout=25)
+ timeout=25,
+ )
self._make_sensor(
- task_id='test_sensor_task_3',
- return_value=None,
- poke_interval=positive_poke_interval,
- timeout=25)
+ task_id='test_sensor_task_3', return_value=None, poke_interval=positive_poke_interval, timeout=25
+ )
def test_sensor_with_invalid_timeout(self):
negative_timeout = -25
@@ -512,30 +485,20 @@ def test_sensor_with_invalid_timeout(self):
positive_timeout = 25
with self.assertRaises(AirflowException):
self._make_sensor(
- task_id='test_sensor_task_1',
- return_value=None,
- poke_interval=10,
- timeout=negative_timeout)
+ task_id='test_sensor_task_1', return_value=None, poke_interval=10, timeout=negative_timeout
+ )
with self.assertRaises(AirflowException):
self._make_sensor(
- task_id='test_sensor_task_2',
- return_value=None,
- poke_interval=10,
- timeout=non_number_timeout)
+ task_id='test_sensor_task_2', return_value=None, poke_interval=10, timeout=non_number_timeout
+ )
self._make_sensor(
- task_id='test_sensor_task_3',
- return_value=None,
- poke_interval=10,
- timeout=positive_timeout)
+ task_id='test_sensor_task_3', return_value=None, poke_interval=10, timeout=positive_timeout
+ )
def test_sensor_with_exponential_backoff_off(self):
- sensor = self._make_sensor(
- return_value=None,
- poke_interval=5,
- timeout=60,
- exponential_backoff=False)
+ sensor = self._make_sensor(return_value=None, poke_interval=5, timeout=60, exponential_backoff=False)
started_at = timezone.utcnow() - timedelta(seconds=10)
self.assertEqual(sensor._get_next_poke_interval(started_at, 1), sensor.poke_interval)
@@ -543,11 +506,7 @@ def test_sensor_with_exponential_backoff_off(self):
def test_sensor_with_exponential_backoff_on(self):
- sensor = self._make_sensor(
- return_value=None,
- poke_interval=5,
- timeout=60,
- exponential_backoff=True)
+ sensor = self._make_sensor(return_value=None, poke_interval=5, timeout=60, exponential_backoff=True)
with patch('airflow.utils.timezone.utcnow') as mock_utctime:
mock_utctime.return_value = DEFAULT_DATE
@@ -582,22 +541,14 @@ def change_mode(self, mode):
class TestPokeModeOnly(unittest.TestCase):
-
def setUp(self):
- self.dagbag = DagBag(
- dag_folder=DEV_NULL,
- include_examples=True
- )
- self.args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- }
+ self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
+ self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
self.dag = DAG(TEST_DAG_ID, default_args=self.args)
def test_poke_mode_only_allows_poke_mode(self):
try:
- sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=False,
- dag=self.dag)
+ sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=False, dag=self.dag)
except ValueError:
self.fail("__init__ failed with mode='poke'.")
try:
@@ -610,18 +561,15 @@ def test_poke_mode_only_allows_poke_mode(self):
self.fail("class method failed without changing mode from 'poke'.")
def test_poke_mode_only_bad_class_method(self):
- sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=False,
- dag=self.dag)
+ sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=False, dag=self.dag)
with self.assertRaises(ValueError):
sensor.change_mode('reschedule')
def test_poke_mode_only_bad_init(self):
with self.assertRaises(ValueError):
- DummyPokeOnlySensor(task_id='foo', mode='reschedule',
- poke_changes_mode=False, dag=self.dag)
+ DummyPokeOnlySensor(task_id='foo', mode='reschedule', poke_changes_mode=False, dag=self.dag)
def test_poke_mode_only_bad_poke(self):
- sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=True,
- dag=self.dag)
+ sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=True, dag=self.dag)
with self.assertRaises(ValueError):
sensor.poke({})
diff --git a/tests/sensors/test_bash.py b/tests/sensors/test_bash.py
index 3322ba2b74b1b..934d67cd74ec4 100644
--- a/tests/sensors/test_bash.py
+++ b/tests/sensors/test_bash.py
@@ -27,10 +27,7 @@
class TestBashSensor(unittest.TestCase):
def setUp(self):
- args = {
- 'owner': 'airflow',
- 'start_date': datetime.datetime(2017, 1, 1)
- }
+ args = {'owner': 'airflow', 'start_date': datetime.datetime(2017, 1, 1)}
dag = DAG('test_dag_id', default_args=args)
self.dag = dag
@@ -41,7 +38,7 @@ def test_true_condition(self):
output_encoding='utf-8',
poke_interval=1,
timeout=2,
- dag=self.dag
+ dag=self.dag,
)
op.execute(None)
@@ -52,7 +49,7 @@ def test_false_condition(self):
output_encoding='utf-8',
poke_interval=1,
timeout=2,
- dag=self.dag
+ dag=self.dag,
)
with self.assertRaises(AirflowSensorTimeout):
op.execute(None)
diff --git a/tests/sensors/test_date_time_sensor.py b/tests/sensors/test_date_time_sensor.py
index 88370153c7397..87b7dbdbdd221 100644
--- a/tests/sensors/test_date_time_sensor.py
+++ b/tests/sensors/test_date_time_sensor.py
@@ -44,13 +44,19 @@ def setup_class(cls):
]
)
def test_valid_input(self, task_id, target_time, expected):
- op = DateTimeSensor(task_id=task_id, target_time=target_time, dag=self.dag,)
+ op = DateTimeSensor(
+ task_id=task_id,
+ target_time=target_time,
+ dag=self.dag,
+ )
assert op.target_time == expected
def test_invalid_input(self):
with pytest.raises(TypeError):
DateTimeSensor(
- task_id="test", target_time=timezone.utcnow().time(), dag=self.dag,
+ task_id="test",
+ target_time=timezone.utcnow().time(),
+ dag=self.dag,
)
@parameterized.expand(
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index 7a1e917e8b31b..a99edf9260e0e 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -39,24 +39,13 @@
class TestExternalTaskSensor(unittest.TestCase):
-
def setUp(self):
- self.dagbag = DagBag(
- dag_folder=DEV_NULL,
- include_examples=True
- )
- self.args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- }
+ self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
+ self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
self.dag = DAG(TEST_DAG_ID, default_args=self.args)
def test_time_sensor(self):
- op = TimeSensor(
- task_id=TEST_TASK_ID,
- target_time=time(0),
- dag=self.dag
- )
+ op = TimeSensor(task_id=TEST_TASK_ID, target_time=time(0), dag=self.dag)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor(self):
@@ -65,13 +54,9 @@ def test_external_task_sensor(self):
task_id='test_external_task_sensor_check',
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
- dag=self.dag
- )
- op.run(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- ignore_ti_state=True
+ dag=self.dag,
)
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_catch_overlap_allowed_failed_state(self):
with self.assertRaises(AirflowException):
@@ -81,7 +66,7 @@ def test_catch_overlap_allowed_failed_state(self):
external_task_id=TEST_TASK_ID,
allowed_states=[State.SUCCESS],
failed_states=[State.SUCCESS],
- dag=self.dag
+ dag=self.dag,
)
def test_external_task_sensor_wrong_failed_states(self):
@@ -91,7 +76,7 @@ def test_external_task_sensor_wrong_failed_states(self):
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
failed_states=["invalid_state"],
- dag=self.dag
+ dag=self.dag,
)
def test_external_task_sensor_failed_states(self):
@@ -101,13 +86,9 @@ def test_external_task_sensor_failed_states(self):
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
failed_states=["failed"],
- dag=self.dag
- )
- op.run(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- ignore_ti_state=True
+ dag=self.dag,
)
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_failed_states_as_success(self):
self.test_time_sensor()
@@ -117,57 +98,38 @@ def test_external_task_sensor_failed_states_as_success(self):
external_task_id=TEST_TASK_ID,
allowed_states=["failed"],
failed_states=["success"],
- dag=self.dag
+ dag=self.dag,
)
with self.assertRaises(AirflowException) as cm:
- op.run(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- ignore_ti_state=True
- )
- self.assertEqual(str(cm.exception),
- "The external task "
- "time_sensor_check in DAG "
- "unit_test_dag failed.")
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ self.assertEqual(
+ str(cm.exception), "The external task " "time_sensor_check in DAG " "unit_test_dag failed."
+ )
def test_external_dag_sensor(self):
- other_dag = DAG(
- 'other_dag',
- default_args=self.args,
- end_date=DEFAULT_DATE,
- schedule_interval='@once')
+ other_dag = DAG('other_dag', default_args=self.args, end_date=DEFAULT_DATE, schedule_interval='@once')
other_dag.create_dagrun(
- run_id='test',
- start_date=DEFAULT_DATE,
- execution_date=DEFAULT_DATE,
- state=State.SUCCESS)
+ run_id='test', start_date=DEFAULT_DATE, execution_date=DEFAULT_DATE, state=State.SUCCESS
+ )
op = ExternalTaskSensor(
task_id='test_external_dag_sensor_check',
external_dag_id='other_dag',
external_task_id=None,
- dag=self.dag
- )
- op.run(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- ignore_ti_state=True
+ dag=self.dag,
)
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_templated_sensor(self):
with self.dag:
sensor = ExternalTaskSensor(
- task_id='templated_task',
- external_dag_id='dag_{{ ds }}',
- external_task_id='task_{{ ds }}'
+ task_id='templated_task', external_dag_id='dag_{{ ds }}', external_task_id='task_{{ ds }}'
)
instance = TaskInstance(sensor, DEFAULT_DATE)
instance.render_templates()
- self.assertEqual(sensor.external_dag_id,
- f"dag_{DEFAULT_DATE.date()}")
- self.assertEqual(sensor.external_task_id,
- f"task_{DEFAULT_DATE.date()}")
+ self.assertEqual(sensor.external_dag_id, f"dag_{DEFAULT_DATE.date()}")
+ self.assertEqual(sensor.external_task_id, f"task_{DEFAULT_DATE.date()}")
def test_external_task_sensor_fn_multiple_execution_dates(self):
bash_command_code = """
@@ -180,84 +142,71 @@ def test_external_task_sensor_fn_multiple_execution_dates(self):
exit 0
"""
dag_external_id = TEST_DAG_ID + '_external'
- dag_external = DAG(
- dag_external_id,
- default_args=self.args,
- schedule_interval=timedelta(seconds=1))
+ dag_external = DAG(dag_external_id, default_args=self.args, schedule_interval=timedelta(seconds=1))
task_external_with_failure = BashOperator(
- task_id="task_external_with_failure",
- bash_command=bash_command_code,
- retries=0,
- dag=dag_external)
+ task_id="task_external_with_failure", bash_command=bash_command_code, retries=0, dag=dag_external
+ )
task_external_without_failure = DummyOperator(
- task_id="task_external_without_failure",
- retries=0,
- dag=dag_external)
+ task_id="task_external_without_failure", retries=0, dag=dag_external
+ )
task_external_without_failure.run(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + timedelta(seconds=1),
- ignore_ti_state=True)
+ start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(seconds=1), ignore_ti_state=True
+ )
session = settings.Session()
TI = TaskInstance
try:
task_external_with_failure.run(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE + timedelta(seconds=1),
- ignore_ti_state=True)
+ start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(seconds=1), ignore_ti_state=True
+ )
# The test_with_failure task is expected to fail
# once per minute (the run on the first second of
# each minute).
except Exception as e: # pylint: disable=broad-except
- failed_tis = session.query(TI).filter(
- TI.dag_id == dag_external_id,
- TI.state == State.FAILED,
- TI.execution_date == DEFAULT_DATE + timedelta(seconds=1)).all()
- if len(failed_tis) == 1 and \
- failed_tis[0].task_id == 'task_external_with_failure':
+ failed_tis = (
+ session.query(TI)
+ .filter(
+ TI.dag_id == dag_external_id,
+ TI.state == State.FAILED,
+ TI.execution_date == DEFAULT_DATE + timedelta(seconds=1),
+ )
+ .all()
+ )
+ if len(failed_tis) == 1 and failed_tis[0].task_id == 'task_external_with_failure':
pass
else:
raise e
dag_id = TEST_DAG_ID
- dag = DAG(
- dag_id,
- default_args=self.args,
- schedule_interval=timedelta(minutes=1))
+ dag = DAG(dag_id, default_args=self.args, schedule_interval=timedelta(minutes=1))
task_without_failure = ExternalTaskSensor(
task_id='task_without_failure',
external_dag_id=dag_external_id,
external_task_id='task_external_without_failure',
- execution_date_fn=lambda dt: [dt + timedelta(seconds=i)
- for i in range(2)],
+ execution_date_fn=lambda dt: [dt + timedelta(seconds=i) for i in range(2)],
allowed_states=['success'],
retries=0,
timeout=1,
poke_interval=1,
- dag=dag)
+ dag=dag,
+ )
task_with_failure = ExternalTaskSensor(
task_id='task_with_failure',
external_dag_id=dag_external_id,
external_task_id='task_external_with_failure',
- execution_date_fn=lambda dt: [dt + timedelta(seconds=i)
- for i in range(2)],
+ execution_date_fn=lambda dt: [dt + timedelta(seconds=i) for i in range(2)],
allowed_states=['success'],
retries=0,
timeout=1,
poke_interval=1,
- dag=dag)
+ dag=dag,
+ )
- task_without_failure.run(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- ignore_ti_state=True)
+ task_without_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
with self.assertRaises(AirflowSensorTimeout):
- task_with_failure.run(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- ignore_ti_state=True)
+ task_with_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_delta(self):
self.test_time_sensor()
@@ -267,13 +216,9 @@ def test_external_task_sensor_delta(self):
external_task_id=TEST_TASK_ID,
execution_delta=timedelta(0),
allowed_states=['success'],
- dag=self.dag
- )
- op.run(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- ignore_ti_state=True
+ dag=self.dag,
)
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_fn(self):
self.test_time_sensor()
@@ -284,13 +229,9 @@ def test_external_task_sensor_fn(self):
external_task_id=TEST_TASK_ID,
execution_date_fn=lambda dt: dt + timedelta(0),
allowed_states=['success'],
- dag=self.dag
- )
- op1.run(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- ignore_ti_state=True
+ dag=self.dag,
)
+ op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
# double check that the execution is being called by failing the test
op2 = ExternalTaskSensor(
task_id='test_external_task_sensor_check_delta_2',
@@ -300,14 +241,10 @@ def test_external_task_sensor_fn(self):
allowed_states=['success'],
timeout=1,
poke_interval=1,
- dag=self.dag
+ dag=self.dag,
)
with self.assertRaises(exceptions.AirflowSensorTimeout):
- op2.run(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- ignore_ti_state=True
- )
+ op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_fn_multiple_args(self):
"""Check this task sensor passes multiple args with full context. If no failure, means clean run."""
@@ -323,13 +260,9 @@ def my_func(dt, context):
external_task_id=TEST_TASK_ID,
execution_date_fn=my_func,
allowed_states=['success'],
- dag=self.dag
- )
- op1.run(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- ignore_ti_state=True
+ dag=self.dag,
)
+ op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_error_delta_and_fn(self):
self.test_time_sensor()
@@ -342,7 +275,7 @@ def test_external_task_sensor_error_delta_and_fn(self):
execution_delta=timedelta(0),
execution_date_fn=lambda dt: dt,
allowed_states=['success'],
- dag=self.dag
+ dag=self.dag,
)
def test_catch_invalid_allowed_states(self):
@@ -352,7 +285,7 @@ def test_catch_invalid_allowed_states(self):
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
allowed_states=['invalid_state'],
- dag=self.dag
+ dag=self.dag,
)
with self.assertRaises(ValueError):
@@ -361,7 +294,7 @@ def test_catch_invalid_allowed_states(self):
external_dag_id=TEST_DAG_ID,
external_task_id=None,
allowed_states=['invalid_state'],
- dag=self.dag
+ dag=self.dag,
)
def test_external_task_sensor_waits_for_task_check_existence(self):
@@ -370,15 +303,11 @@ def test_external_task_sensor_waits_for_task_check_existence(self):
external_dag_id="example_bash_operator",
external_task_id="non-existing-task",
check_existence=True,
- dag=self.dag
+ dag=self.dag,
)
with self.assertRaises(AirflowException):
- op.run(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- ignore_ti_state=True
- )
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_waits_for_dag_check_existence(self):
op = ExternalTaskSensor(
@@ -386,15 +315,11 @@ def test_external_task_sensor_waits_for_dag_check_existence(self):
external_dag_id="non-existing-dag",
external_task_id=None,
check_existence=True,
- dag=self.dag
+ dag=self.dag,
)
with self.assertRaises(AirflowException):
- op.run(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE,
- ignore_ti_state=True
- )
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
class TestExternalTaskMarker(unittest.TestCase):
@@ -407,7 +332,7 @@ def test_serialized_external_task_marker(self):
task_id="parent_task",
external_dag_id="external_task_marker_child",
external_task_id="child_task1",
- dag=dag
+ dag=dag,
)
serialized_op = SerializedBaseOperator.serialize_operator(task)
@@ -438,42 +363,33 @@ def dag_bag_ext():
dag_0 = DAG("dag_0", start_date=DEFAULT_DATE, schedule_interval=None)
task_a_0 = DummyOperator(task_id="task_a_0", dag=dag_0)
- task_b_0 = ExternalTaskMarker(task_id="task_b_0",
- external_dag_id="dag_1",
- external_task_id="task_a_1",
- recursion_depth=3,
- dag=dag_0)
+ task_b_0 = ExternalTaskMarker(
+ task_id="task_b_0", external_dag_id="dag_1", external_task_id="task_a_1", recursion_depth=3, dag=dag_0
+ )
task_a_0 >> task_b_0
dag_1 = DAG("dag_1", start_date=DEFAULT_DATE, schedule_interval=None)
- task_a_1 = ExternalTaskSensor(task_id="task_a_1",
- external_dag_id=dag_0.dag_id,
- external_task_id=task_b_0.task_id,
- dag=dag_1)
- task_b_1 = ExternalTaskMarker(task_id="task_b_1",
- external_dag_id="dag_2",
- external_task_id="task_a_2",
- recursion_depth=2,
- dag=dag_1)
+ task_a_1 = ExternalTaskSensor(
+ task_id="task_a_1", external_dag_id=dag_0.dag_id, external_task_id=task_b_0.task_id, dag=dag_1
+ )
+ task_b_1 = ExternalTaskMarker(
+ task_id="task_b_1", external_dag_id="dag_2", external_task_id="task_a_2", recursion_depth=2, dag=dag_1
+ )
task_a_1 >> task_b_1
dag_2 = DAG("dag_2", start_date=DEFAULT_DATE, schedule_interval=None)
- task_a_2 = ExternalTaskSensor(task_id="task_a_2",
- external_dag_id=dag_1.dag_id,
- external_task_id=task_b_1.task_id,
- dag=dag_2)
- task_b_2 = ExternalTaskMarker(task_id="task_b_2",
- external_dag_id="dag_3",
- external_task_id="task_a_3",
- recursion_depth=1,
- dag=dag_2)
+ task_a_2 = ExternalTaskSensor(
+ task_id="task_a_2", external_dag_id=dag_1.dag_id, external_task_id=task_b_1.task_id, dag=dag_2
+ )
+ task_b_2 = ExternalTaskMarker(
+ task_id="task_b_2", external_dag_id="dag_3", external_task_id="task_a_3", recursion_depth=1, dag=dag_2
+ )
task_a_2 >> task_b_2
dag_3 = DAG("dag_3", start_date=DEFAULT_DATE, schedule_interval=None)
- task_a_3 = ExternalTaskSensor(task_id="task_a_3",
- external_dag_id=dag_2.dag_id,
- external_task_id=task_b_2.task_id,
- dag=dag_3)
+ task_a_3 = ExternalTaskSensor(
+ task_id="task_a_3", external_dag_id=dag_2.dag_id, external_task_id=task_b_2.task_id, dag=dag_3
+ )
task_b_3 = DummyOperator(task_id="task_b_3", dag=dag_3)
task_a_3 >> task_b_3
@@ -587,23 +503,18 @@ def dag_bag_cyclic():
dag_0 = DAG("dag_0", start_date=DEFAULT_DATE, schedule_interval=None)
task_a_0 = DummyOperator(task_id="task_a_0", dag=dag_0)
- task_b_0 = ExternalTaskMarker(task_id="task_b_0",
- external_dag_id="dag_1",
- external_task_id="task_a_1",
- recursion_depth=3,
- dag=dag_0)
+ task_b_0 = ExternalTaskMarker(
+ task_id="task_b_0", external_dag_id="dag_1", external_task_id="task_a_1", recursion_depth=3, dag=dag_0
+ )
task_a_0 >> task_b_0
dag_1 = DAG("dag_1", start_date=DEFAULT_DATE, schedule_interval=None)
- task_a_1 = ExternalTaskSensor(task_id="task_a_1",
- external_dag_id=dag_0.dag_id,
- external_task_id=task_b_0.task_id,
- dag=dag_1)
- task_b_1 = ExternalTaskMarker(task_id="task_b_1",
- external_dag_id="dag_0",
- external_task_id="task_a_0",
- recursion_depth=2,
- dag=dag_1)
+ task_a_1 = ExternalTaskSensor(
+ task_id="task_a_1", external_dag_id=dag_0.dag_id, external_task_id=task_b_0.task_id, dag=dag_1
+ )
+ task_b_1 = ExternalTaskMarker(
+ task_id="task_b_1", external_dag_id="dag_0", external_task_id="task_a_0", recursion_depth=2, dag=dag_1
+ )
task_a_1 >> task_b_1
for dag in [dag_0, dag_1]:
@@ -639,11 +550,13 @@ def dag_bag_multiple():
start = DummyOperator(task_id="start", dag=agg_dag)
for i in range(25):
- task = ExternalTaskMarker(task_id=f"{daily_task.task_id}_{i}",
- external_dag_id=daily_dag.dag_id,
- external_task_id=daily_task.task_id,
- execution_date="{{ macros.ds_add(ds, -1 * %s) }}" % i,
- dag=agg_dag)
+ task = ExternalTaskMarker(
+ task_id=f"{daily_task.task_id}_{i}",
+ external_dag_id=daily_dag.dag_id,
+ external_task_id=daily_task.task_id,
+ execution_date="{{ macros.ds_add(ds, -1 * %s) }}" % i,
+ dag=agg_dag,
+ )
start >> task
yield dag_bag
@@ -689,16 +602,20 @@ def dag_bag_head_tail():
"""
dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False)
with DAG("head_tail", start_date=DEFAULT_DATE, schedule_interval="@daily") as dag:
- head = ExternalTaskSensor(task_id='head',
- external_dag_id=dag.dag_id,
- external_task_id="tail",
- execution_delta=timedelta(days=1),
- mode="reschedule")
+ head = ExternalTaskSensor(
+ task_id='head',
+ external_dag_id=dag.dag_id,
+ external_task_id="tail",
+ execution_delta=timedelta(days=1),
+ mode="reschedule",
+ )
body = DummyOperator(task_id="body")
- tail = ExternalTaskMarker(task_id="tail",
- external_dag_id=dag.dag_id,
- external_task_id=head.task_id,
- execution_date="{{ tomorrow_ds_nodash }}")
+ tail = ExternalTaskMarker(
+ task_id="tail",
+ external_dag_id=dag.dag_id,
+ external_task_id=head.task_id,
+ execution_date="{{ tomorrow_ds_nodash }}",
+ )
head >> body >> tail
dag_bag.bag_dag(dag=dag, root_dag=dag)
diff --git a/tests/sensors/test_filesystem.py b/tests/sensors/test_filesystem.py
index 03ba2758b9e3e..8332f9c553cf7 100644
--- a/tests/sensors/test_filesystem.py
+++ b/tests/sensors/test_filesystem.py
@@ -34,11 +34,9 @@
class TestFileSensor(unittest.TestCase):
def setUp(self):
from airflow.hooks.filesystem import FSHook
+
hook = FSHook()
- args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- }
+ args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
self.hook = hook
self.dag = dag
@@ -53,8 +51,7 @@ def test_simple(self):
timeout=0,
)
task._hook = self.hook
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
- ignore_ti_state=True)
+ task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_file_in_nonexistent_dir(self):
temp_dir = tempfile.mkdtemp()
@@ -64,13 +61,12 @@ def test_file_in_nonexistent_dir(self):
fs_conn_id='fs_default',
dag=self.dag,
timeout=0,
- poke_interval=1
+ poke_interval=1,
)
task._hook = self.hook
try:
with self.assertRaises(AirflowSensorTimeout):
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
- ignore_ti_state=True)
+ task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
finally:
shutil.rmtree(temp_dir)
@@ -82,13 +78,12 @@ def test_empty_dir(self):
fs_conn_id='fs_default',
dag=self.dag,
timeout=0,
- poke_interval=1
+ poke_interval=1,
)
task._hook = self.hook
try:
with self.assertRaises(AirflowSensorTimeout):
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
- ignore_ti_state=True)
+ task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
finally:
shutil.rmtree(temp_dir)
@@ -105,8 +100,7 @@ def test_file_in_dir(self):
try:
# `touch` the dir
open(temp_dir + "/file", "a").close()
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
- ignore_ti_state=True)
+ task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
finally:
shutil.rmtree(temp_dir)
@@ -119,8 +113,7 @@ def test_default_fs_conn_id(self):
timeout=0,
)
task._hook = self.hook
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
- ignore_ti_state=True)
+ task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_wildcard_file(self):
suffix = '.txt'
@@ -134,8 +127,7 @@ def test_wildcard_file(self):
timeout=0,
)
task._hook = self.hook
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
- ignore_ti_state=True)
+ task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_subdirectory_not_empty(self):
suffix = '.txt'
@@ -151,8 +143,7 @@ def test_subdirectory_not_empty(self):
timeout=0,
)
task._hook = self.hook
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
- ignore_ti_state=True)
+ task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
shutil.rmtree(temp_dir)
def test_subdirectory_empty(self):
@@ -164,11 +155,10 @@ def test_subdirectory_empty(self):
fs_conn_id='fs_default',
dag=self.dag,
timeout=0,
- poke_interval=1
+ poke_interval=1,
)
task._hook = self.hook
with self.assertRaises(AirflowSensorTimeout):
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
- ignore_ti_state=True)
+ task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
shutil.rmtree(temp_dir)
diff --git a/tests/sensors/test_python.py b/tests/sensors/test_python.py
index 445496c8b6714..c8eee8428420e 100644
--- a/tests/sensors/test_python.py
+++ b/tests/sensors/test_python.py
@@ -31,12 +31,8 @@
class TestPythonSensor(TestPythonBase):
-
def test_python_sensor_true(self):
- op = PythonSensor(
- task_id='python_sensor_check_true',
- python_callable=lambda: True,
- dag=self.dag)
+ op = PythonSensor(task_id='python_sensor_check_true', python_callable=lambda: True, dag=self.dag)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_python_sensor_false(self):
@@ -45,15 +41,13 @@ def test_python_sensor_false(self):
timeout=0.01,
poke_interval=0.01,
python_callable=lambda: False,
- dag=self.dag)
+ dag=self.dag,
+ )
with self.assertRaises(AirflowSensorTimeout):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_python_sensor_raise(self):
- op = PythonSensor(
- task_id='python_sensor_check_raise',
- python_callable=lambda: 1 / 0,
- dag=self.dag)
+ op = PythonSensor(task_id='python_sensor_check_raise', python_callable=lambda: 1 / 0, dag=self.dag)
with self.assertRaises(ZeroDivisionError):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@@ -73,19 +67,15 @@ def test_python_callable_arguments_are_templatized(self):
# a Mock instance cannot be used as a callable function or test fails with a
# TypeError: Object of type Mock is not JSON serializable
python_callable=build_recording_function(recorded_calls),
- op_args=[
- 4,
- date(2019, 1, 1),
- "dag {{dag.dag_id}} ran on {{ds}}.",
- named_tuple
- ],
- dag=self.dag)
+ op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple],
+ dag=self.dag,
+ )
self.dag.create_dagrun(
run_type=DagRunType.MANUAL,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
with self.assertRaises(AirflowSensorTimeout):
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -95,10 +85,12 @@ def test_python_callable_arguments_are_templatized(self):
self.assertEqual(2, len(recorded_calls))
self._assert_calls_equal(
recorded_calls[0],
- Call(4,
- date(2019, 1, 1),
- f"dag {self.dag.dag_id} ran on {ds_templated}.",
- Named(ds_templated, 'unchanged'))
+ Call(
+ 4,
+ date(2019, 1, 1),
+ f"dag {self.dag.dag_id} ran on {ds_templated}.",
+ Named(ds_templated, 'unchanged'),
+ ),
)
def test_python_callable_keyword_arguments_are_templatized(self):
@@ -115,15 +107,16 @@ def test_python_callable_keyword_arguments_are_templatized(self):
op_kwargs={
'an_int': 4,
'a_date': date(2019, 1, 1),
- 'a_templated_string': "dag {{dag.dag_id}} ran on {{ds}}."
+ 'a_templated_string': "dag {{dag.dag_id}} ran on {{ds}}.",
},
- dag=self.dag)
+ dag=self.dag,
+ )
self.dag.create_dagrun(
run_type=DagRunType.MANUAL,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
with self.assertRaises(AirflowSensorTimeout):
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -132,8 +125,11 @@ def test_python_callable_keyword_arguments_are_templatized(self):
self.assertEqual(2, len(recorded_calls))
self._assert_calls_equal(
recorded_calls[0],
- Call(an_int=4,
- a_date=date(2019, 1, 1),
- a_templated_string="dag {} ran on {}.".format(
- self.dag.dag_id, DEFAULT_DATE.date().isoformat()))
+ Call(
+ an_int=4,
+ a_date=date(2019, 1, 1),
+ a_templated_string="dag {} ran on {}.".format(
+ self.dag.dag_id, DEFAULT_DATE.date().isoformat()
+ ),
+ ),
)
diff --git a/tests/sensors/test_smart_sensor_operator.py b/tests/sensors/test_smart_sensor_operator.py
index 90f5c5c67dd7b..fa7fc079bb208 100644
--- a/tests/sensors/test_smart_sensor_operator.py
+++ b/tests/sensors/test_smart_sensor_operator.py
@@ -43,13 +43,10 @@
class DummySmartSensor(SmartSensorOperator):
- def __init__(self,
- shard_max=conf.getint('smart_sensor', 'shard_code_upper_limit'),
- shard_min=0,
- **kwargs):
- super().__init__(shard_min=shard_min,
- shard_max=shard_max,
- **kwargs)
+ def __init__(
+ self, shard_max=conf.getint('smart_sensor', 'shard_code_upper_limit'), shard_min=0, **kwargs
+ ):
+ super().__init__(shard_min=shard_min, shard_max=shard_max, **kwargs)
class DummySensor(BaseSensorOperator):
@@ -73,10 +70,7 @@ def setUp(self):
os.environ['AIRFLOW__SMART_SENSER__USE_SMART_SENSOR'] = 'true'
os.environ['AIRFLOW__SMART_SENSER__SENSORS_ENABLED'] = 'DummySmartSensor'
- args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- }
+ args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
self.dag = DAG(TEST_DAG_ID, default_args=args)
self.sensor_dag = DAG(TEST_SENSOR_DAG_ID, default_args=args)
self.log = logging.getLogger('BaseSmartTest')
@@ -102,7 +96,7 @@ def _make_dag_run(self):
run_id='manual__' + TEST_DAG_ID,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
def _make_sensor_dag_run(self):
@@ -110,7 +104,7 @@ def _make_sensor_dag_run(self):
run_id='manual__' + TEST_SENSOR_DAG_ID,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
- state=State.RUNNING
+ state=State.RUNNING,
)
def _make_sensor(self, return_value, **kwargs):
@@ -121,12 +115,7 @@ def _make_sensor(self, return_value, **kwargs):
if timeout not in kwargs:
kwargs[timeout] = 0
- sensor = DummySensor(
- task_id=SENSOR_OP,
- return_value=return_value,
- dag=self.sensor_dag,
- **kwargs
- )
+ sensor = DummySensor(task_id=SENSOR_OP, return_value=return_value, dag=self.sensor_dag, **kwargs)
return sensor
@@ -139,12 +128,7 @@ def _make_sensor_instance(self, index, return_value, **kwargs):
kwargs[timeout] = 0
task_id = SENSOR_OP + str(index)
- sensor = DummySensor(
- task_id=task_id,
- return_value=return_value,
- dag=self.sensor_dag,
- **kwargs
- )
+ sensor = DummySensor(task_id=task_id, return_value=return_value, dag=self.sensor_dag, **kwargs)
ti = TaskInstance(task=sensor, execution_date=DEFAULT_DATE)
@@ -158,16 +142,9 @@ def _make_smart_operator(self, index, **kwargs):
if smart_sensor_timeout not in kwargs:
kwargs[smart_sensor_timeout] = 0
- smart_task = DummySmartSensor(
- task_id=SMART_OP + "_" + str(index),
- dag=self.dag,
- **kwargs
- )
+ smart_task = DummySmartSensor(task_id=SMART_OP + "_" + str(index), dag=self.dag, **kwargs)
- dummy_op = DummyOperator(
- task_id=DUMMY_OP,
- dag=self.dag
- )
+ dummy_op = DummyOperator(task_id=DUMMY_OP, dag=self.dag)
dummy_op.set_upstream(smart_task)
return smart_task
@@ -315,11 +292,13 @@ def test_register_in_sensor_service(self):
session = settings.Session()
SI = SensorInstance
- sensor_instance = session.query(SI).filter(
- SI.dag_id == si1.dag_id,
- SI.task_id == si1.task_id,
- SI.execution_date == si1.execution_date) \
+ sensor_instance = (
+ session.query(SI)
+ .filter(
+ SI.dag_id == si1.dag_id, SI.task_id == si1.task_id, SI.execution_date == si1.execution_date
+ )
.first()
+ )
self.assertIsNotNone(sensor_instance)
self.assertEqual(sensor_instance.state, State.SENSING)
diff --git a/tests/sensors/test_sql_sensor.py b/tests/sensors/test_sql_sensor.py
index ac58f26926007..18948856abffa 100644
--- a/tests/sensors/test_sql_sensor.py
+++ b/tests/sensors/test_sql_sensor.py
@@ -32,13 +32,9 @@
class TestSqlSensor(TestHiveEnvironment):
-
def setUp(self):
super().setUp()
- args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- }
+ args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
self.dag = DAG(TEST_DAG_ID, default_args=args)
def test_unsupported_conn_type(self):
@@ -46,7 +42,7 @@ def test_unsupported_conn_type(self):
task_id='sql_sensor_check',
conn_id='redis_default',
sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
- dag=self.dag
+ dag=self.dag,
)
with self.assertRaises(AirflowException):
@@ -58,7 +54,7 @@ def test_sql_sensor_mysql(self):
task_id='sql_sensor_check_1',
conn_id='mysql_default',
sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
- dag=self.dag
+ dag=self.dag,
)
op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@@ -67,7 +63,7 @@ def test_sql_sensor_mysql(self):
conn_id='mysql_default',
sql="SELECT count(%s) FROM INFORMATION_SCHEMA.TABLES",
parameters=["table_name"],
- dag=self.dag
+ dag=self.dag,
)
op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@@ -77,7 +73,7 @@ def test_sql_sensor_postgres(self):
task_id='sql_sensor_check_1',
conn_id='postgres_default',
sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
- dag=self.dag
+ dag=self.dag,
)
op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@@ -86,7 +82,7 @@ def test_sql_sensor_postgres(self):
conn_id='postgres_default',
sql="SELECT count(%s) FROM INFORMATION_SCHEMA.TABLES",
parameters=["table_name"],
- dag=self.dag
+ dag=self.dag,
)
op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@@ -125,10 +121,7 @@ def test_sql_sensor_postgres_poke(self, mock_hook):
@mock.patch('airflow.sensors.sql_sensor.BaseHook')
def test_sql_sensor_postgres_poke_fail_on_empty(self, mock_hook):
op = SqlSensor(
- task_id='sql_sensor_check',
- conn_id='postgres_default',
- sql="SELECT 1",
- fail_on_empty=True
+ task_id='sql_sensor_check', conn_id='postgres_default', sql="SELECT 1", fail_on_empty=True
)
mock_hook.get_connection('postgres_default').conn_type = "postgres"
@@ -140,10 +133,7 @@ def test_sql_sensor_postgres_poke_fail_on_empty(self, mock_hook):
@mock.patch('airflow.sensors.sql_sensor.BaseHook')
def test_sql_sensor_postgres_poke_success(self, mock_hook):
op = SqlSensor(
- task_id='sql_sensor_check',
- conn_id='postgres_default',
- sql="SELECT 1",
- success=lambda x: x in [1]
+ task_id='sql_sensor_check', conn_id='postgres_default', sql="SELECT 1", success=lambda x: x in [1]
)
mock_hook.get_connection('postgres_default').conn_type = "postgres"
@@ -161,10 +151,7 @@ def test_sql_sensor_postgres_poke_success(self, mock_hook):
@mock.patch('airflow.sensors.sql_sensor.BaseHook')
def test_sql_sensor_postgres_poke_failure(self, mock_hook):
op = SqlSensor(
- task_id='sql_sensor_check',
- conn_id='postgres_default',
- sql="SELECT 1",
- failure=lambda x: x in [1]
+ task_id='sql_sensor_check', conn_id='postgres_default', sql="SELECT 1", failure=lambda x: x in [1]
)
mock_hook.get_connection('postgres_default').conn_type = "postgres"
@@ -183,7 +170,7 @@ def test_sql_sensor_postgres_poke_failure_success(self, mock_hook):
conn_id='postgres_default',
sql="SELECT 1",
failure=lambda x: x in [1],
- success=lambda x: x in [2]
+ success=lambda x: x in [2],
)
mock_hook.get_connection('postgres_default').conn_type = "postgres"
@@ -205,7 +192,7 @@ def test_sql_sensor_postgres_poke_failure_success_same(self, mock_hook):
conn_id='postgres_default',
sql="SELECT 1",
failure=lambda x: x in [1],
- success=lambda x: x in [1]
+ success=lambda x: x in [1],
)
mock_hook.get_connection('postgres_default').conn_type = "postgres"
@@ -248,13 +235,13 @@ def test_sql_sensor_postgres_poke_invalid_success(self, mock_hook):
self.assertRaises(AirflowException, op.poke, None)
@unittest.skipIf(
- 'AIRFLOW_RUNALL_TESTS' not in os.environ,
- "Skipped because AIRFLOW_RUNALL_TESTS is not set")
+ 'AIRFLOW_RUNALL_TESTS' not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set"
+ )
def test_sql_sensor_presto(self):
op = SqlSensor(
task_id='hdfs_sensor_check',
conn_id='presto_default',
sql="SELECT 'x' FROM airflow.static_babynames LIMIT 1;",
- dag=self.dag)
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
- ignore_ti_state=True)
+ dag=self.dag,
+ )
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
diff --git a/tests/sensors/test_time_sensor.py b/tests/sensors/test_time_sensor.py
index 28d8936e338dd..567297d0917dc 100644
--- a/tests/sensors/test_time_sensor.py
+++ b/tests/sensors/test_time_sensor.py
@@ -28,9 +28,7 @@
DEFAULT_TIMEZONE = "Asia/Singapore" # UTC+08:00
DEFAULT_DATE_WO_TZ = datetime(2015, 1, 1)
-DEFAULT_DATE_WITH_TZ = datetime(
- 2015, 1, 1, tzinfo=pendulum.tz.timezone(DEFAULT_TIMEZONE)
-)
+DEFAULT_DATE_WITH_TZ = datetime(2015, 1, 1, tzinfo=pendulum.tz.timezone(DEFAULT_TIMEZONE))
@patch(
diff --git a/tests/sensors/test_timedelta_sensor.py b/tests/sensors/test_timedelta_sensor.py
index 615c922e68560..d613337161b91 100644
--- a/tests/sensors/test_timedelta_sensor.py
+++ b/tests/sensors/test_timedelta_sensor.py
@@ -30,14 +30,10 @@
class TestTimedeltaSensor(unittest.TestCase):
def setUp(self):
- self.dagbag = DagBag(
- dag_folder=DEV_NULL, include_examples=True)
+ self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
self.dag = DAG(TEST_DAG_ID, default_args=self.args)
def test_timedelta_sensor(self):
- op = TimeDeltaSensor(
- task_id='timedelta_sensor_check',
- delta=timedelta(seconds=2),
- dag=self.dag)
+ op = TimeDeltaSensor(task_id='timedelta_sensor_check', delta=timedelta(seconds=2), dag=self.dag)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
diff --git a/tests/sensors/test_timeout_sensor.py b/tests/sensors/test_timeout_sensor.py
index 09b35b3ad2247..3df6e101ece5b 100644
--- a/tests/sensors/test_timeout_sensor.py
+++ b/tests/sensors/test_timeout_sensor.py
@@ -39,9 +39,7 @@ class TimeoutTestSensor(BaseSensorOperator):
"""
@apply_defaults
- def __init__(self,
- return_value=False,
- **kwargs):
+ def __init__(self, return_value=False, **kwargs):
self.return_value = return_value
super().__init__(**kwargs)
@@ -65,10 +63,7 @@ def execute(self, context):
class TestSensorTimeout(unittest.TestCase):
def setUp(self):
- args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- }
+ args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
self.dag = DAG(TEST_DAG_ID, default_args=args)
def test_timeout(self):
@@ -78,10 +73,8 @@ def test_timeout(self):
return_value=False,
poke_interval=5,
params={'time_jump': timedelta(days=2, seconds=1)},
- dag=self.dag
+ dag=self.dag,
)
self.assertRaises(
- AirflowSensorTimeout,
- op.run,
- start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
+ AirflowSensorTimeout, op.run, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
)
diff --git a/tests/sensors/test_weekday_sensor.py b/tests/sensors/test_weekday_sensor.py
index 75b5cb9961b17..26b65a06e9147 100644
--- a/tests/sensors/test_weekday_sensor.py
+++ b/tests/sensors/test_weekday_sensor.py
@@ -37,7 +37,6 @@
class TestDayOfWeekSensor(unittest.TestCase):
-
@staticmethod
def clean_db():
db.clear_db_runs()
@@ -45,34 +44,28 @@ def clean_db():
def setUp(self):
self.clean_db()
- self.dagbag = DagBag(
- dag_folder=DEV_NULL,
- include_examples=True
- )
- self.args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- }
+ self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
+ self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
dag = DAG(TEST_DAG_ID, default_args=self.args)
self.dag = dag
def tearDown(self):
self.clean_db()
- @parameterized.expand([
- ("with-string", 'Thursday'),
- ("with-enum", WeekDay.THURSDAY),
- ("with-enum-set", {WeekDay.THURSDAY}),
- ("with-enum-set-2-items", {WeekDay.THURSDAY, WeekDay.FRIDAY}),
- ("with-string-set", {'Thursday'}),
- ("with-string-set-2-items", {'Thursday', 'Friday'}),
- ])
+ @parameterized.expand(
+ [
+ ("with-string", 'Thursday'),
+ ("with-enum", WeekDay.THURSDAY),
+ ("with-enum-set", {WeekDay.THURSDAY}),
+ ("with-enum-set-2-items", {WeekDay.THURSDAY, WeekDay.FRIDAY}),
+ ("with-string-set", {'Thursday'}),
+ ("with-string-set-2-items", {'Thursday', 'Friday'}),
+ ]
+ )
def test_weekday_sensor_true(self, _, week_day):
op = DayOfWeekSensor(
- task_id='weekday_sensor_check_true',
- week_day=week_day,
- use_task_execution_day=True,
- dag=self.dag)
+ task_id='weekday_sensor_check_true', week_day=week_day, use_task_execution_day=True, dag=self.dag
+ )
op.run(start_date=WEEKDAY_DATE, end_date=WEEKDAY_DATE, ignore_ti_state=True)
self.assertEqual(op.week_day, week_day)
@@ -83,32 +76,35 @@ def test_weekday_sensor_false(self):
timeout=2,
week_day='Tuesday',
use_task_execution_day=True,
- dag=self.dag)
+ dag=self.dag,
+ )
with self.assertRaises(AirflowSensorTimeout):
op.run(start_date=WEEKDAY_DATE, end_date=WEEKDAY_DATE, ignore_ti_state=True)
def test_invalid_weekday_number(self):
invalid_week_day = 'Thsday'
- with self.assertRaisesRegex(AttributeError,
- f'Invalid Week Day passed: "{invalid_week_day}"'):
+ with self.assertRaisesRegex(AttributeError, f'Invalid Week Day passed: "{invalid_week_day}"'):
DayOfWeekSensor(
task_id='weekday_sensor_invalid_weekday_num',
week_day=invalid_week_day,
use_task_execution_day=True,
- dag=self.dag)
+ dag=self.dag,
+ )
def test_weekday_sensor_with_invalid_type(self):
invalid_week_day = ['Thsday']
- with self.assertRaisesRegex(TypeError,
- 'Unsupported Type for week_day parameter:'
- ' {}. It should be one of str, set or '
- 'Weekday enum type'.format(type(invalid_week_day))
- ):
+ with self.assertRaisesRegex(
+ TypeError,
+ 'Unsupported Type for week_day parameter:'
+ ' {}. It should be one of str, set or '
+ 'Weekday enum type'.format(type(invalid_week_day)),
+ ):
DayOfWeekSensor(
task_id='weekday_sensor_check_true',
week_day=invalid_week_day,
use_task_execution_day=True,
- dag=self.dag)
+ dag=self.dag,
+ )
def test_weekday_sensor_timeout_with_set(self):
op = DayOfWeekSensor(
@@ -117,6 +113,7 @@ def test_weekday_sensor_timeout_with_set(self):
timeout=2,
week_day={WeekDay.MONDAY, WeekDay.TUESDAY},
use_task_execution_day=True,
- dag=self.dag)
+ dag=self.dag,
+ )
with self.assertRaises(AirflowSensorTimeout):
op.run(start_date=WEEKDAY_DATE, end_date=WEEKDAY_DATE, ignore_ti_state=True)
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 08a8aaad938a9..51be0b0abf903 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -44,17 +44,12 @@
executor_config_pod = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(name="my-name"),
- spec=k8s.V1PodSpec(containers=[
- k8s.V1Container(
- name="base",
- volume_mounts=[
- k8s.V1VolumeMount(
- name="my-vol",
- mount_path="/vol/"
- )
- ]
- )
- ]))
+ spec=k8s.V1PodSpec(
+ containers=[
+ k8s.V1Container(name="base", volume_mounts=[k8s.V1VolumeMount(name="my-vol", mount_path="/vol/")])
+ ]
+ ),
+)
serialized_simple_dag_ground_truth = {
"__version": 1,
@@ -64,24 +59,22 @@
"__var": {
"depends_on_past": False,
"retries": 1,
- "retry_delay": {
- "__type": "timedelta",
- "__var": 300.0
- }
- }
+ "retry_delay": {"__type": "timedelta", "__var": 300.0},
+ },
},
"start_date": 1564617600.0,
- '_task_group': {'_group_id': None,
- 'prefix_group_id': True,
- 'children': {'bash_task': ('operator', 'bash_task'),
- 'custom_task': ('operator', 'custom_task')},
- 'tooltip': '',
- 'ui_color': 'CornflowerBlue',
- 'ui_fgcolor': '#000',
- 'upstream_group_ids': [],
- 'downstream_group_ids': [],
- 'upstream_task_ids': [],
- 'downstream_task_ids': []},
+ '_task_group': {
+ '_group_id': None,
+ 'prefix_group_id': True,
+ 'children': {'bash_task': ('operator', 'bash_task'), 'custom_task': ('operator', 'custom_task')},
+ 'tooltip': '',
+ 'ui_color': 'CornflowerBlue',
+ 'ui_fgcolor': '#000',
+ 'upstream_group_ids': [],
+ 'downstream_group_ids': [],
+ 'upstream_task_ids': [],
+ 'downstream_task_ids': [],
+ },
"is_paused_upon_creation": False,
"_dag_id": "simple_dag",
"fileloc": None,
@@ -103,12 +96,15 @@
"_task_type": "BashOperator",
"_task_module": "airflow.operators.bash",
"pool": "default_pool",
- "executor_config": {'__type': 'dict',
- '__var': {"pod_override": {
- '__type': 'k8s.V1Pod',
- '__var': PodGenerator.serialize_pod(executor_config_pod)}
- }
- }
+ "executor_config": {
+ '__type': 'dict',
+ '__var': {
+ "pod_override": {
+ '__type': 'k8s.V1Pod',
+ '__var': PodGenerator.serialize_pod(executor_config_pod),
+ }
+ },
+ },
},
{
"task_id": "custom_task",
@@ -134,13 +130,10 @@
"__var": {
"test_role": {
"__type": "set",
- "__var": [
- permissions.ACTION_CAN_READ,
- permissions.ACTION_CAN_EDIT
- ]
+ "__var": [permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT],
}
- }
- }
+ },
+ },
},
}
@@ -166,13 +159,15 @@ def make_simple_dag():
},
start_date=datetime(2019, 8, 1),
is_paused_upon_creation=False,
- access_control={
- "test_role": {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT}
- }
+ access_control={"test_role": {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT}},
) as dag:
CustomOperator(task_id='custom_task')
- BashOperator(task_id='bash_task', bash_command='echo {{ task.task_id }}', owner='airflow',
- executor_config={"pod_override": executor_config_pod})
+ BashOperator(
+ task_id='bash_task',
+ bash_command='echo {{ task.task_id }}',
+ owner='airflow',
+ executor_config={"pod_override": executor_config_pod},
+ )
return {'simple_dag': dag}
@@ -189,19 +184,15 @@ def make_user_defined_macro_filter_dag():
def compute_next_execution_date(dag, execution_date):
return dag.following_schedule(execution_date)
- default_args = {
- 'start_date': datetime(2019, 7, 10)
- }
+ default_args = {'start_date': datetime(2019, 7, 10)}
dag = DAG(
'user_defined_macro_filter_dag',
default_args=default_args,
user_defined_macros={
'next_execution_date': compute_next_execution_date,
},
- user_defined_filters={
- 'hello': lambda name: 'Hello %s' % name
- },
- catchup=False
+ user_defined_filters={'hello': lambda name: 'Hello %s' % name},
+ catchup=False,
)
BashOperator(
task_id='echo',
@@ -252,14 +243,18 @@ def setUp(self):
super().setUp()
BaseHook.get_connection = mock.Mock(
return_value=Connection(
- extra=('{'
- '"project_id": "mock", '
- '"location": "mock", '
- '"instance": "mock", '
- '"database_type": "postgres", '
- '"use_proxy": "False", '
- '"use_ssl": "False"'
- '}')))
+ extra=(
+ '{'
+ '"project_id": "mock", '
+ '"location": "mock", '
+ '"instance": "mock", '
+ '"database_type": "postgres", '
+ '"use_proxy": "False", '
+ '"use_ssl": "False"'
+ '}'
+ )
+ )
+ )
self.maxDiff = None # pylint: disable=invalid-name
def test_serialization(self):
@@ -272,14 +267,11 @@ def test_serialization(self):
serialized_dags[v.dag_id] = dag
# Compares with the ground truth of JSON string.
- self.validate_serialized_dag(
- serialized_dags['simple_dag'],
- serialized_simple_dag_ground_truth)
+ self.validate_serialized_dag(serialized_dags['simple_dag'], serialized_simple_dag_ground_truth)
def validate_serialized_dag(self, json_dag, ground_truth_dag):
"""Verify serialized DAGs match the ground truth."""
- self.assertTrue(
- json_dag['dag']['fileloc'].split('/')[-1] == 'test_dag_serialization.py')
+ self.assertTrue(json_dag['dag']['fileloc'].split('/')[-1] == 'test_dag_serialization.py')
json_dag['dag']['fileloc'] = None
def sorted_serialized_dag(dag_dict: dict):
@@ -289,8 +281,7 @@ def sorted_serialized_dag(dag_dict: dict):
items should not matter but assertEqual would fail if the order of
items changes in the dag dictionary
"""
- dag_dict["dag"]["tasks"] = sorted(dag_dict["dag"]["tasks"],
- key=lambda x: sorted(x.keys()))
+ dag_dict["dag"]["tasks"] = sorted(dag_dict["dag"]["tasks"], key=lambda x: sorted(x.keys()))
dag_dict["dag"]["_access_control"]["__var"]["test_role"]["__var"] = sorted(
dag_dict["dag"]["_access_control"]["__var"]["test_role"]["__var"]
)
@@ -306,8 +297,7 @@ def test_deserialization_across_process(self):
# and once here to get a DAG to compare to) we don't want to load all
# dags.
queue = multiprocessing.Queue()
- proc = multiprocessing.Process(
- target=serialize_subprocess, args=(queue, "airflow/example_dags"))
+ proc = multiprocessing.Process(target=serialize_subprocess, args=(queue, "airflow/example_dags"))
proc.daemon = True
proc.start()
@@ -328,10 +318,12 @@ def test_deserialization_across_process(self):
self.validate_deserialized_dag(stringified_dags[dag_id], dags[dag_id])
def test_roundtrip_provider_example_dags(self):
- dags = collect_dags([
- "airflow/providers/*/example_dags",
- "airflow/providers/*/*/example_dags",
- ])
+ dags = collect_dags(
+ [
+ "airflow/providers/*/example_dags",
+ "airflow/providers/*/*/example_dags",
+ ]
+ )
# Verify deserialized DAGs.
for dag in dags.values():
@@ -346,14 +338,14 @@ def validate_deserialized_dag(self, serialized_dag, dag):
fields_to_check = dag.get_serialized_fields() - {
# Doesn't implement __eq__ properly. Check manually
'timezone',
-
# Need to check fields in it, to exclude functions
'default_args',
- "_task_group"
+ "_task_group",
}
for field in fields_to_check:
- assert getattr(serialized_dag, field) == getattr(dag, field), \
- f'{dag.dag_id}.{field} does not match'
+ assert getattr(serialized_dag, field) == getattr(
+ dag, field
+ ), f'{dag.dag_id}.{field} does not match'
if dag.default_args:
for k, v in dag.default_args.items():
@@ -361,8 +353,9 @@ def validate_deserialized_dag(self, serialized_dag, dag):
# Check we stored _something_.
assert k in serialized_dag.default_args
else:
- assert v == serialized_dag.default_args[k], \
- f'{dag.dag_id}.default_args[{k}] does not match'
+ assert (
+ v == serialized_dag.default_args[k]
+ ), f'{dag.dag_id}.default_args[{k}] does not match'
assert serialized_dag.timezone.name == dag.timezone.name
@@ -373,7 +366,11 @@ def validate_deserialized_dag(self, serialized_dag, dag):
# and is equal to fileloc
assert serialized_dag.full_filepath == dag.fileloc
- def validate_deserialized_task(self, serialized_task, task,):
+ def validate_deserialized_task(
+ self,
+ serialized_task,
+ task,
+ ):
"""Verify non-airflow operators are casted to BaseOperator."""
assert isinstance(serialized_task, SerializedBaseOperator)
assert not isinstance(task, SerializedBaseOperator)
@@ -381,17 +378,16 @@ def validate_deserialized_task(self, serialized_task, task,):
fields_to_check = task.get_serialized_fields() - {
# Checked separately
- '_task_type', 'subdag',
-
+ '_task_type',
+ 'subdag',
# Type is excluded, so don't check it
'_log',
-
# List vs tuple. Check separately
'template_fields',
-
# We store the string, real dag has the actual code
- 'on_failure_callback', 'on_success_callback', 'on_retry_callback',
-
+ 'on_failure_callback',
+ 'on_success_callback',
+ 'on_retry_callback',
# Checked separately
'resources',
}
@@ -403,8 +399,9 @@ def validate_deserialized_task(self, serialized_task, task,):
assert serialized_task.downstream_task_ids == task.downstream_task_ids
for field in fields_to_check:
- assert getattr(serialized_task, field) == getattr(task, field), \
- f'{task.dag.dag_id}.{task.task_id}.{field} does not match'
+ assert getattr(serialized_task, field) == getattr(
+ task, field
+ ), f'{task.dag.dag_id}.{task.task_id}.{field} does not match'
if serialized_task.resources is None:
assert task.resources is None or task.resources == []
@@ -419,17 +416,22 @@ def validate_deserialized_task(self, serialized_task, task,):
else:
assert serialized_task.subdag is None
- @parameterized.expand([
- (datetime(2019, 8, 1, tzinfo=timezone.utc), None, datetime(2019, 8, 1, tzinfo=timezone.utc)),
- (datetime(2019, 8, 1, tzinfo=timezone.utc), datetime(2019, 8, 2, tzinfo=timezone.utc),
- datetime(2019, 8, 2, tzinfo=timezone.utc)),
- (datetime(2019, 8, 1, tzinfo=timezone.utc), datetime(2019, 7, 30, tzinfo=timezone.utc),
- datetime(2019, 8, 1, tzinfo=timezone.utc)),
- ])
- def test_deserialization_start_date(self,
- dag_start_date,
- task_start_date,
- expected_task_start_date):
+ @parameterized.expand(
+ [
+ (datetime(2019, 8, 1, tzinfo=timezone.utc), None, datetime(2019, 8, 1, tzinfo=timezone.utc)),
+ (
+ datetime(2019, 8, 1, tzinfo=timezone.utc),
+ datetime(2019, 8, 2, tzinfo=timezone.utc),
+ datetime(2019, 8, 2, tzinfo=timezone.utc),
+ ),
+ (
+ datetime(2019, 8, 1, tzinfo=timezone.utc),
+ datetime(2019, 7, 30, tzinfo=timezone.utc),
+ datetime(2019, 8, 1, tzinfo=timezone.utc),
+ ),
+ ]
+ )
+ def test_deserialization_start_date(self, dag_start_date, task_start_date, expected_task_start_date):
dag = DAG(dag_id='simple_dag', start_date=dag_start_date)
BaseOperator(task_id='simple_task', dag=dag, start_date=task_start_date)
@@ -446,19 +448,23 @@ def test_deserialization_start_date(self,
simple_task = dag.task_dict["simple_task"]
self.assertEqual(simple_task.start_date, expected_task_start_date)
- @parameterized.expand([
- (datetime(2019, 8, 1, tzinfo=timezone.utc), None, datetime(2019, 8, 1, tzinfo=timezone.utc)),
- (datetime(2019, 8, 1, tzinfo=timezone.utc), datetime(2019, 8, 2, tzinfo=timezone.utc),
- datetime(2019, 8, 1, tzinfo=timezone.utc)),
- (datetime(2019, 8, 1, tzinfo=timezone.utc), datetime(2019, 7, 30, tzinfo=timezone.utc),
- datetime(2019, 7, 30, tzinfo=timezone.utc)),
- ])
- def test_deserialization_end_date(self,
- dag_end_date,
- task_end_date,
- expected_task_end_date):
- dag = DAG(dag_id='simple_dag', start_date=datetime(2019, 8, 1),
- end_date=dag_end_date)
+ @parameterized.expand(
+ [
+ (datetime(2019, 8, 1, tzinfo=timezone.utc), None, datetime(2019, 8, 1, tzinfo=timezone.utc)),
+ (
+ datetime(2019, 8, 1, tzinfo=timezone.utc),
+ datetime(2019, 8, 2, tzinfo=timezone.utc),
+ datetime(2019, 8, 1, tzinfo=timezone.utc),
+ ),
+ (
+ datetime(2019, 8, 1, tzinfo=timezone.utc),
+ datetime(2019, 7, 30, tzinfo=timezone.utc),
+ datetime(2019, 7, 30, tzinfo=timezone.utc),
+ ),
+ ]
+ )
+ def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_task_end_date):
+ dag = DAG(dag_id='simple_dag', start_date=datetime(2019, 8, 1), end_date=dag_end_date)
BaseOperator(task_id='simple_task', dag=dag, end_date=task_end_date)
serialized_dag = SerializedDAG.to_dict(dag)
@@ -473,12 +479,14 @@ def test_deserialization_end_date(self,
simple_task = dag.task_dict["simple_task"]
self.assertEqual(simple_task.end_date, expected_task_end_date)
- @parameterized.expand([
- (None, None, None),
- ("@weekly", "@weekly", "0 0 * * 0"),
- ("@once", "@once", None),
- ({"__type": "timedelta", "__var": 86400.0}, timedelta(days=1), timedelta(days=1)),
- ])
+ @parameterized.expand(
+ [
+ (None, None, None),
+ ("@weekly", "@weekly", "0 0 * * 0"),
+ ("@once", "@once", None),
+ ({"__type": "timedelta", "__var": 86400.0}, timedelta(days=1), timedelta(days=1)),
+ ]
+ )
def test_deserialization_schedule_interval(
self, serialized_schedule_interval, expected_schedule_interval, expected_n_schedule_interval
):
@@ -501,14 +509,16 @@ def test_deserialization_schedule_interval(
self.assertEqual(dag.schedule_interval, expected_schedule_interval)
self.assertEqual(dag.normalized_schedule_interval, expected_n_schedule_interval)
- @parameterized.expand([
- (relativedelta(days=-1), {"__type": "relativedelta", "__var": {"days": -1}}),
- (relativedelta(month=1, days=-1), {"__type": "relativedelta", "__var": {"month": 1, "days": -1}}),
- # Every friday
- (relativedelta(weekday=FR), {"__type": "relativedelta", "__var": {"weekday": [4]}}),
- # Every second friday
- (relativedelta(weekday=FR(2)), {"__type": "relativedelta", "__var": {"weekday": [4, 2]}})
- ])
+ @parameterized.expand(
+ [
+ (relativedelta(days=-1), {"__type": "relativedelta", "__var": {"days": -1}}),
+ (relativedelta(month=1, days=-1), {"__type": "relativedelta", "__var": {"month": 1, "days": -1}}),
+ # Every friday
+ (relativedelta(weekday=FR), {"__type": "relativedelta", "__var": {"weekday": [4]}}),
+ # Every second friday
+ (relativedelta(weekday=FR(2)), {"__type": "relativedelta", "__var": {"weekday": [4, 2]}}),
+ ]
+ )
def test_roundtrip_relativedelta(self, val, expected):
serialized = SerializedDAG._serialize(val)
self.assertDictEqual(serialized, expected)
@@ -516,10 +526,12 @@ def test_roundtrip_relativedelta(self, val, expected):
round_tripped = SerializedDAG._deserialize(serialized)
self.assertEqual(val, round_tripped)
- @parameterized.expand([
- (None, {}),
- ({"param_1": "value_1"}, {"param_1": "value_1"}),
- ])
+ @parameterized.expand(
+ [
+ (None, {}),
+ ({"param_1": "value_1"}, {"param_1": "value_1"}),
+ ]
+ )
def test_dag_params_roundtrip(self, val, expected_val):
"""
Test that params work both on Serialized DAGs & Tasks
@@ -538,17 +550,18 @@ def test_dag_params_roundtrip(self, val, expected_val):
self.assertEqual(expected_val, deserialized_dag.params)
self.assertEqual(expected_val, deserialized_simple_task.params)
- @parameterized.expand([
- (None, {}),
- ({"param_1": "value_1"}, {"param_1": "value_1"}),
- ])
+ @parameterized.expand(
+ [
+ (None, {}),
+ ({"param_1": "value_1"}, {"param_1": "value_1"}),
+ ]
+ )
def test_task_params_roundtrip(self, val, expected_val):
"""
Test that params work both on Serialized DAGs & Tasks
"""
dag = DAG(dag_id='simple_dag')
- BaseOperator(task_id='simple_task', dag=dag, params=val,
- start_date=datetime(2019, 8, 1))
+ BaseOperator(task_id='simple_task', dag=dag, params=val, start_date=datetime(2019, 8, 1))
serialized_dag = SerializedDAG.to_dict(dag)
if val:
@@ -589,7 +602,7 @@ def test_extra_serialized_field_and_operator_links(self):
# Check Serialized version of operator link only contains the inbuilt Op Link
self.assertEqual(
serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
- [{'tests.test_utils.mock_operators.CustomOpLink': {}}]
+ [{'tests.test_utils.mock_operators.CustomOpLink': {}}],
)
# Test all the extra_links are set
@@ -614,6 +627,7 @@ def test_extra_operator_links_logs_error_for_non_registered_extra_links(self):
class TaskStateLink(BaseOperatorLink):
"""OperatorLink not registered via Plugins nor a built-in OperatorLink"""
+
name = 'My Link'
def get_link(self, operator, dttm):
@@ -621,6 +635,7 @@ def get_link(self, operator, dttm):
class MyOperator(BaseOperator):
"""Just a DummyOperator using above defined Extra Operator Link"""
+
operator_extra_links = [TaskStateLink()]
def execute(self, context):
@@ -672,12 +687,14 @@ def test_extra_serialized_field_and_multiple_operator_links(self):
[
{'tests.test_utils.mock_operators.CustomBaseIndexOpLink': {'index': 0}},
{'tests.test_utils.mock_operators.CustomBaseIndexOpLink': {'index': 1}},
- ]
+ ],
)
# Test all the extra_links are set
- self.assertCountEqual(simple_task.extra_links, [
- 'BigQuery Console #1', 'BigQuery Console #2', 'airflow', 'github', 'google'])
+ self.assertCountEqual(
+ simple_task.extra_links,
+ ['BigQuery Console #1', 'BigQuery Console #2', 'airflow', 'github', 'google'],
+ )
ti = TaskInstance(task=simple_task, execution_date=test_date)
ti.xcom_push('search_query', ["dummy_value_1", "dummy_value_2"])
@@ -715,42 +732,48 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)
- @parameterized.expand([
- (None, None),
- ([], []),
- ({}, {}),
- ("{{ task.task_id }}", "{{ task.task_id }}"),
- (["{{ task.task_id }}", "{{ task.task_id }}"]),
- ({"foo": "{{ task.task_id }}"}, {"foo": "{{ task.task_id }}"}),
- ({"foo": {"bar": "{{ task.task_id }}"}}, {"foo": {"bar": "{{ task.task_id }}"}}),
- (
- [{"foo1": {"bar": "{{ task.task_id }}"}}, {"foo2": {"bar": "{{ task.task_id }}"}}],
- [{"foo1": {"bar": "{{ task.task_id }}"}}, {"foo2": {"bar": "{{ task.task_id }}"}}],
- ),
- (
- {"foo": {"bar": {"{{ task.task_id }}": ["sar"]}}},
- {"foo": {"bar": {"{{ task.task_id }}": ["sar"]}}}),
- (
- ClassWithCustomAttributes(
- att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"]),
- "ClassWithCustomAttributes("
- "{'att1': '{{ task.task_id }}', 'att2': '{{ task.task_id }}', 'template_fields': ['att1']})",
- ),
- (
- ClassWithCustomAttributes(nested1=ClassWithCustomAttributes(att1="{{ task.task_id }}",
- att2="{{ task.task_id }}",
- template_fields=["att1"]),
- nested2=ClassWithCustomAttributes(att3="{{ task.task_id }}",
- att4="{{ task.task_id }}",
- template_fields=["att3"]),
- template_fields=["nested1"]),
- "ClassWithCustomAttributes("
- "{'nested1': ClassWithCustomAttributes({'att1': '{{ task.task_id }}', "
- "'att2': '{{ task.task_id }}', 'template_fields': ['att1']}), "
- "'nested2': ClassWithCustomAttributes({'att3': '{{ task.task_id }}', "
- "'att4': '{{ task.task_id }}', 'template_fields': ['att3']}), 'template_fields': ['nested1']})",
- ),
- ])
+ @parameterized.expand(
+ [
+ (None, None),
+ ([], []),
+ ({}, {}),
+ ("{{ task.task_id }}", "{{ task.task_id }}"),
+ (["{{ task.task_id }}", "{{ task.task_id }}"]),
+ ({"foo": "{{ task.task_id }}"}, {"foo": "{{ task.task_id }}"}),
+ ({"foo": {"bar": "{{ task.task_id }}"}}, {"foo": {"bar": "{{ task.task_id }}"}}),
+ (
+ [{"foo1": {"bar": "{{ task.task_id }}"}}, {"foo2": {"bar": "{{ task.task_id }}"}}],
+ [{"foo1": {"bar": "{{ task.task_id }}"}}, {"foo2": {"bar": "{{ task.task_id }}"}}],
+ ),
+ (
+ {"foo": {"bar": {"{{ task.task_id }}": ["sar"]}}},
+ {"foo": {"bar": {"{{ task.task_id }}": ["sar"]}}},
+ ),
+ (
+ ClassWithCustomAttributes(
+ att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"]
+ ),
+ "ClassWithCustomAttributes("
+ "{'att1': '{{ task.task_id }}', 'att2': '{{ task.task_id }}', 'template_fields': ['att1']})",
+ ),
+ (
+ ClassWithCustomAttributes(
+ nested1=ClassWithCustomAttributes(
+ att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"]
+ ),
+ nested2=ClassWithCustomAttributes(
+ att3="{{ task.task_id }}", att4="{{ task.task_id }}", template_fields=["att3"]
+ ),
+ template_fields=["nested1"],
+ ),
+ "ClassWithCustomAttributes("
+ "{'nested1': ClassWithCustomAttributes({'att1': '{{ task.task_id }}', "
+ "'att2': '{{ task.task_id }}', 'template_fields': ['att1']}), "
+ "'nested2': ClassWithCustomAttributes({'att3': '{{ task.task_id }}', 'att4': "
+ "'{{ task.task_id }}', 'template_fields': ['att3']}), 'template_fields': ['nested1']})",
+ ),
+ ]
+ )
def test_templated_fields_exist_in_serialized_dag(self, templated_field, expected_field):
"""
Test that templated_fields exists for all Operators in Serialized DAG
@@ -781,8 +804,9 @@ def test_dag_serialized_fields_with_schema(self):
self.assertEqual(set(DAG.get_serialized_fields()), dag_params)
def test_operator_subclass_changing_base_defaults(self):
- assert BaseOperator(task_id='dummy').do_xcom_push is True, \
- "Precondition check! If this fails the test won't make sense"
+ assert (
+ BaseOperator(task_id='dummy').do_xcom_push is True
+ ), "Precondition check! If this fails the test won't make sense"
class MyOperator(BaseOperator):
def __init__(self, do_xcom_push=False, **kwargs):
@@ -804,49 +828,53 @@ def test_no_new_fields_added_to_base_operator(self):
"""
base_operator = BaseOperator(task_id="10")
fields = base_operator.__dict__
- self.assertEqual({'_BaseOperator__instantiated': True,
- '_dag': None,
- '_downstream_task_ids': set(),
- '_inlets': [],
- '_log': base_operator.log,
- '_outlets': [],
- '_upstream_task_ids': set(),
- 'depends_on_past': False,
- 'do_xcom_push': True,
- 'email': None,
- 'email_on_failure': True,
- 'email_on_retry': True,
- 'end_date': None,
- 'execution_timeout': None,
- 'executor_config': {},
- 'inlets': [],
- 'label': '10',
- 'max_retry_delay': None,
- 'on_execute_callback': None,
- 'on_failure_callback': None,
- 'on_retry_callback': None,
- 'on_success_callback': None,
- 'outlets': [],
- 'owner': 'airflow',
- 'params': {},
- 'pool': 'default_pool',
- 'pool_slots': 1,
- 'priority_weight': 1,
- 'queue': 'default',
- 'resources': None,
- 'retries': 0,
- 'retry_delay': timedelta(0, 300),
- 'retry_exponential_backoff': False,
- 'run_as_user': None,
- 'sla': None,
- 'start_date': None,
- 'subdag': None,
- 'task_concurrency': None,
- 'task_id': '10',
- 'trigger_rule': 'all_success',
- 'wait_for_downstream': False,
- 'weight_rule': 'downstream'}, fields,
- """
+ self.assertEqual(
+ {
+ '_BaseOperator__instantiated': True,
+ '_dag': None,
+ '_downstream_task_ids': set(),
+ '_inlets': [],
+ '_log': base_operator.log,
+ '_outlets': [],
+ '_upstream_task_ids': set(),
+ 'depends_on_past': False,
+ 'do_xcom_push': True,
+ 'email': None,
+ 'email_on_failure': True,
+ 'email_on_retry': True,
+ 'end_date': None,
+ 'execution_timeout': None,
+ 'executor_config': {},
+ 'inlets': [],
+ 'label': '10',
+ 'max_retry_delay': None,
+ 'on_execute_callback': None,
+ 'on_failure_callback': None,
+ 'on_retry_callback': None,
+ 'on_success_callback': None,
+ 'outlets': [],
+ 'owner': 'airflow',
+ 'params': {},
+ 'pool': 'default_pool',
+ 'pool_slots': 1,
+ 'priority_weight': 1,
+ 'queue': 'default',
+ 'resources': None,
+ 'retries': 0,
+ 'retry_delay': timedelta(0, 300),
+ 'retry_exponential_backoff': False,
+ 'run_as_user': None,
+ 'sla': None,
+ 'start_date': None,
+ 'subdag': None,
+ 'task_concurrency': None,
+ 'task_id': '10',
+ 'trigger_rule': 'all_success',
+ 'wait_for_downstream': False,
+ 'weight_rule': 'downstream',
+ },
+ fields,
+ """
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
ACTION NEEDED! PLEASE READ THIS CAREFULLY AND CORRECT TESTS CAREFULLY
@@ -858,8 +886,8 @@ def test_no_new_fields_added_to_base_operator(self):
Note that we do not support versioning yet so you should only add optional fields to BaseOperator.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
- """
- )
+ """,
+ )
def test_task_group_serialization(self):
"""
diff --git a/tests/task/task_runner/test_cgroup_task_runner.py b/tests/task/task_runner/test_cgroup_task_runner.py
index c83cf6242fc2a..c5fb97098abe0 100644
--- a/tests/task/task_runner/test_cgroup_task_runner.py
+++ b/tests/task/task_runner/test_cgroup_task_runner.py
@@ -22,7 +22,6 @@
class TestCgroupTaskRunner(unittest.TestCase):
-
@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__")
@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.on_finish")
def test_cgroup_task_runner_super_calls(self, mock_super_on_finish, mock_super_init):
diff --git a/tests/task/task_runner/test_standard_task_runner.py b/tests/task/task_runner/test_standard_task_runner.py
index b37040c2adfe8..b9e1db7a4cfcf 100644
--- a/tests/task/task_runner/test_standard_task_runner.py
+++ b/tests/task/task_runner/test_standard_task_runner.py
@@ -40,24 +40,16 @@
'version': 1,
'disable_existing_loggers': False,
'formatters': {
- 'airflow.task': {
- 'format': '[%(asctime)s] {{%(filename)s:%(lineno)d}} %(levelname)s - %(message)s'
- },
+ 'airflow.task': {'format': '[%(asctime)s] {{%(filename)s:%(lineno)d}} %(levelname)s - %(message)s'},
},
'handlers': {
'console': {
'class': 'logging.StreamHandler',
'formatter': 'airflow.task',
- 'stream': 'ext://sys.stdout'
+ 'stream': 'ext://sys.stdout',
}
},
- 'loggers': {
- 'airflow': {
- 'handlers': ['console'],
- 'level': 'INFO',
- 'propagate': False
- }
- }
+ 'loggers': {'airflow': {'handlers': ['console'], 'level': 'INFO', 'propagate': False}},
}
@@ -79,7 +71,12 @@ def test_start_and_terminate(self):
local_task_job.task_instance = mock.MagicMock()
local_task_job.task_instance.run_as_user = None
local_task_job.task_instance.command_as_list.return_value = [
- 'airflow', 'tasks', 'test', 'test_on_kill', 'task1', '2016-01-01'
+ 'airflow',
+ 'tasks',
+ 'test',
+ 'test_on_kill',
+ 'task1',
+ '2016-01-01',
]
runner = StandardTaskRunner(local_task_job)
@@ -104,7 +101,12 @@ def test_start_and_terminate_run_as_user(self):
local_task_job.task_instance = mock.MagicMock()
local_task_job.task_instance.run_as_user = getpass.getuser()
local_task_job.task_instance.command_as_list.return_value = [
- 'airflow', 'tasks', 'test', 'test_on_kill', 'task1', '2016-01-01'
+ 'airflow',
+ 'tasks',
+ 'test',
+ 'test_on_kill',
+ 'task1',
+ '2016-01-01',
]
runner = StandardTaskRunner(local_task_job)
@@ -146,11 +148,13 @@ def test_on_kill(self):
session = settings.Session()
dag.clear()
- dag.create_dagrun(run_id="test",
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
+ dag.create_dagrun(
+ run_id="test",
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
ti = TI(task=task, execution_date=DEFAULT_DATE)
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
session.commit()
diff --git a/tests/task/task_runner/test_task_runner.py b/tests/task/task_runner/test_task_runner.py
index be4187d36166a..4601333ee18ea 100644
--- a/tests/task/task_runner/test_task_runner.py
+++ b/tests/task/task_runner/test_task_runner.py
@@ -27,16 +27,11 @@
class GetTaskRunner(unittest.TestCase):
-
- @parameterized.expand([
- (import_path, ) for import_path in CORE_TASK_RUNNERS.values()
- ])
+ @parameterized.expand([(import_path,) for import_path in CORE_TASK_RUNNERS.values()])
def test_should_have_valid_imports(self, import_path):
self.assertIsNotNone(import_string(import_path))
- @mock.patch(
- 'airflow.task.task_runner.base_task_runner.subprocess'
- )
+ @mock.patch('airflow.task.task_runner.base_task_runner.subprocess')
@mock.patch('airflow.task.task_runner._TASK_RUNNER_NAME', "StandardTaskRunner")
def test_should_support_core_task_runner(self, mock_subprocess):
local_task_job = mock.MagicMock(
@@ -48,7 +43,7 @@ def test_should_support_core_task_runner(self, mock_subprocess):
@mock.patch(
'airflow.task.task_runner._TASK_RUNNER_NAME',
- "tests.task.task_runner.test_task_runner.custom_task_runner"
+ "tests.task.task_runner.test_task_runner.custom_task_runner",
)
def test_should_support_custom_task_runner(self):
local_task_job = mock.MagicMock(
diff --git a/tests/test_utils/amazon_system_helpers.py b/tests/test_utils/amazon_system_helpers.py
index f252a66121711..ac0d153a65e99 100644
--- a/tests/test_utils/amazon_system_helpers.py
+++ b/tests/test_utils/amazon_system_helpers.py
@@ -28,9 +28,7 @@
from tests.test_utils.logging_command_executor import get_executor
from tests.test_utils.system_tests_class import SystemTest
-AWS_DAG_FOLDER = os.path.join(
- AIRFLOW_MAIN_FOLDER, "airflow", "providers", "amazon", "aws", "example_dags"
-)
+AWS_DAG_FOLDER = os.path.join(AIRFLOW_MAIN_FOLDER, "airflow", "providers", "amazon", "aws", "example_dags")
@contextmanager
@@ -53,7 +51,6 @@ def provide_aws_s3_bucket(name):
@pytest.mark.system("amazon")
class AmazonSystemTest(SystemTest):
-
@staticmethod
def _region_name():
return os.environ.get("REGION_NAME")
@@ -85,8 +82,7 @@ def execute_with_ctx(cls, cmd: List[str]):
executor.execute_cmd(cmd=cmd)
@staticmethod
- def create_connection(aws_conn_id: str,
- region: str) -> None:
+ def create_connection(aws_conn_id: str, region: str) -> None:
"""
Create aws connection with region
@@ -137,8 +133,7 @@ def create_emr_default_roles(cls) -> None:
cls.execute_with_ctx(cmd)
@staticmethod
- def create_ecs_cluster(aws_conn_id: str,
- cluster_name: str) -> None:
+ def create_ecs_cluster(aws_conn_id: str, cluster_name: str) -> None:
"""
Create ecs cluster with given name
@@ -174,8 +169,7 @@ def create_ecs_cluster(aws_conn_id: str,
)
@staticmethod
- def delete_ecs_cluster(aws_conn_id: str,
- cluster_name: str) -> None:
+ def delete_ecs_cluster(aws_conn_id: str, cluster_name: str) -> None:
"""
Delete ecs cluster with given short name or full Amazon Resource Name (ARN)
@@ -193,14 +187,16 @@ def delete_ecs_cluster(aws_conn_id: str,
)
@staticmethod
- def create_ecs_task_definition(aws_conn_id: str,
- task_definition: str,
- container: str,
- image: str,
- execution_role_arn: str,
- awslogs_group: str,
- awslogs_region: str,
- awslogs_stream_prefix: str) -> None:
+ def create_ecs_task_definition(
+ aws_conn_id: str,
+ task_definition: str,
+ container: str,
+ image: str,
+ execution_role_arn: str,
+ awslogs_group: str,
+ awslogs_region: str,
+ awslogs_stream_prefix: str,
+ ) -> None:
"""
Create ecs task definition with given name
@@ -256,8 +252,7 @@ def create_ecs_task_definition(aws_conn_id: str,
)
@staticmethod
- def delete_ecs_task_definition(aws_conn_id: str,
- task_definition: str) -> None:
+ def delete_ecs_task_definition(aws_conn_id: str, task_definition: str) -> None:
"""
Delete all revisions of given ecs task definition
@@ -283,8 +278,7 @@ def delete_ecs_task_definition(aws_conn_id: str,
)
@staticmethod
- def is_ecs_task_definition_exists(aws_conn_id: str,
- task_definition: str) -> bool:
+ def is_ecs_task_definition_exists(aws_conn_id: str, task_definition: str) -> bool:
"""
Check whether given task definition exits in ecs
diff --git a/tests/test_utils/api_connexion_utils.py b/tests/test_utils/api_connexion_utils.py
index 919c99f57560a..522c70c1b5ae3 100644
--- a/tests/test_utils/api_connexion_utils.py
+++ b/tests/test_utils/api_connexion_utils.py
@@ -69,5 +69,5 @@ def assert_401(response):
'detail': None,
'status': 401,
'title': 'Unauthorized',
- 'type': EXCEPTIONS_LINK_MAP[401]
+ 'type': EXCEPTIONS_LINK_MAP[401],
}
diff --git a/tests/test_utils/asserts.py b/tests/test_utils/asserts.py
index 220331de1540f..dccaad10e2894 100644
--- a/tests/test_utils/asserts.py
+++ b/tests/test_utils/asserts.py
@@ -32,6 +32,7 @@
def assert_equal_ignore_multiple_spaces(case, first, second, msg=None):
def _trim(s):
return re.sub(r"\s+", " ", s.strip())
+
return case.assertEqual(_trim(first), _trim(second), msg)
@@ -42,6 +43,7 @@ class CountQueries:
Does not support multiple processes. When a new process is started in context, its queries will
not be included.
"""
+
def __init__(self):
self.result = Counter()
@@ -55,10 +57,11 @@ def __exit__(self, type_, value, tb):
def after_cursor_execute(self, *args, **kwargs):
stack = [
- f for f in traceback.extract_stack()
- if 'sqlalchemy' not in f.filename and
- __file__ != f.filename and
- ('session.py' not in f.filename and f.name != 'wrapper')
+ f
+ for f in traceback.extract_stack()
+ if 'sqlalchemy' not in f.filename
+ and __file__ != f.filename
+ and ('session.py' not in f.filename and f.name != 'wrapper')
]
stack_info = ">".join([f"{f.filename.rpartition('/')[-1]}:{f.name}" for f in stack][-3:])
lineno = stack[-1].lineno
@@ -75,9 +78,12 @@ def assert_queries_count(expected_count, message_fmt=None):
count = sum(result.values())
if expected_count != count:
- message_fmt = message_fmt or "The expected number of db queries is {expected_count}. " \
- "The current number is {current_count}.\n\n" \
- "Recorded query locations:"
+ message_fmt = (
+ message_fmt
+ or "The expected number of db queries is {expected_count}. "
+ "The current number is {current_count}.\n\n"
+ "Recorded query locations:"
+ )
message = message_fmt.format(current_count=count, expected_count=expected_count)
for location, count in result.items():
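
Beyond the rewrapping, the asserts.py hunk above centres on the query-counting helpers, which hang an `after_cursor_execute` listener on the SQLAlchemy engine. For readers unfamiliar with that mechanism, here is a minimal standalone sketch of the same pattern; it is not taken from the patch, and the in-memory SQLite engine, the `counter` dict and the table name are illustrative assumptions.

from sqlalchemy import create_engine, event, text

engine = create_engine("sqlite://")
counter = {"queries": 0}

def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    # Called once per executed statement; just tally it.
    counter["queries"] += 1

event.listen(engine, "after_cursor_execute", after_cursor_execute)
with engine.connect() as conn:
    conn.execute(text("CREATE TABLE t (x INTEGER)"))
    conn.execute(text("INSERT INTO t VALUES (1)"))
event.remove(engine, "after_cursor_execute", after_cursor_execute)

print(counter["queries"])  # -> 2, one per executed statement
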
diff --git a/tests/test_utils/azure_system_helpers.py b/tests/test_utils/azure_system_helpers.py
index 6f6f1e8b3722c..add526d752109 100644
--- a/tests/test_utils/azure_system_helpers.py
+++ b/tests/test_utils/azure_system_helpers.py
@@ -39,7 +39,6 @@ def provide_azure_fileshare(share_name: str, wasb_conn_id: str, file_name: str,
@pytest.mark.system("azure")
class AzureSystemTest(SystemTest):
-
@classmethod
def create_share(cls, share_name: str, wasb_conn_id: str):
hook = AzureFileShareHook(wasb_conn_id=wasb_conn_id)
diff --git a/tests/test_utils/db.py b/tests/test_utils/db.py
index cd2ef469faa35..cfc43682f1ea0 100644
--- a/tests/test_utils/db.py
+++ b/tests/test_utils/db.py
@@ -17,8 +17,20 @@
# under the License.
from airflow.jobs.base_job import BaseJob
from airflow.models import (
- Connection, DagModel, DagRun, DagTag, Log, Pool, RenderedTaskInstanceFields, SlaMiss, TaskFail,
- TaskInstance, TaskReschedule, Variable, XCom, errors,
+ Connection,
+ DagModel,
+ DagRun,
+ DagTag,
+ Log,
+ Pool,
+ RenderedTaskInstanceFields,
+ SlaMiss,
+ TaskFail,
+ TaskInstance,
+ TaskReschedule,
+ Variable,
+ XCom,
+ errors,
)
from airflow.models.dagcode import DagCode
from airflow.models.serialized_dag import SerializedDagModel
diff --git a/tests/test_utils/gcp_system_helpers.py b/tests/test_utils/gcp_system_helpers.py
index b80b76be29825..250593af240d0 100644
--- a/tests/test_utils/gcp_system_helpers.py
+++ b/tests/test_utils/gcp_system_helpers.py
@@ -86,19 +86,24 @@ def provide_gcp_context(
:type project_id: str
"""
key_file_path = resolve_full_gcp_key_path(key_file_path) # type: ignore
- with provide_gcp_conn_and_credentials(key_file_path, scopes, project_id), \
- tempfile.TemporaryDirectory() as gcloud_config_tmp, \
- mock.patch.dict('os.environ', {CLOUD_SDK_CONFIG_DIR: gcloud_config_tmp}):
+ with provide_gcp_conn_and_credentials(
+ key_file_path, scopes, project_id
+ ), tempfile.TemporaryDirectory() as gcloud_config_tmp, mock.patch.dict(
+ 'os.environ', {CLOUD_SDK_CONFIG_DIR: gcloud_config_tmp}
+ ):
executor = get_executor()
if project_id:
- executor.execute_cmd([
- "gcloud", "config", "set", "core/project", project_id
- ])
+ executor.execute_cmd(["gcloud", "config", "set", "core/project", project_id])
if key_file_path:
- executor.execute_cmd([
- "gcloud", "auth", "activate-service-account", f"--key-file={key_file_path}",
- ])
+ executor.execute_cmd(
+ [
+ "gcloud",
+ "auth",
+ "activate-service-account",
+ f"--key-file={key_file_path}",
+ ]
+ )
yield
@@ -121,8 +126,9 @@ def _service_key():
return os.environ.get(CREDENTIALS)
@classmethod
- def execute_with_ctx(cls, cmd: List[str], key: str = GCP_GCS_KEY, project_id=None, scopes=None,
- silent: bool = False):
+ def execute_with_ctx(
+ cls, cmd: List[str], key: str = GCP_GCS_KEY, project_id=None, scopes=None, silent: bool = False
+ ):
"""
Executes command with context created by provide_gcp_context and activated
service key.
@@ -149,9 +155,7 @@ def delete_gcs_bucket(cls, name: str):
@classmethod
def upload_to_gcs(cls, source_uri: str, target_uri: str):
- cls.execute_with_ctx(
- ["gsutil", "cp", source_uri, target_uri], key=GCP_GCS_KEY
- )
+ cls.execute_with_ctx(["gsutil", "cp", source_uri, target_uri], key=GCP_GCS_KEY)
@classmethod
def upload_content_to_gcs(cls, lines: str, bucket: str, filename: str):
@@ -192,10 +196,18 @@ def create_secret(cls, name: str, value: str):
with tempfile.NamedTemporaryFile() as tmp:
tmp.write(value.encode("UTF-8"))
tmp.flush()
- cmd = ["gcloud", "secrets", "create", name,
- "--replication-policy", "automatic",
- "--project", GoogleSystemTest._project_id(),
- "--data-file", tmp.name]
+ cmd = [
+ "gcloud",
+ "secrets",
+ "create",
+ name,
+ "--replication-policy",
+ "automatic",
+ "--project",
+ GoogleSystemTest._project_id(),
+ "--data-file",
+ tmp.name,
+ ]
cls.execute_with_ctx(cmd, key=GCP_SECRET_MANAGER_KEY)
@classmethod
@@ -203,7 +215,15 @@ def update_secret(cls, name: str, value: str):
with tempfile.NamedTemporaryFile() as tmp:
tmp.write(value.encode("UTF-8"))
tmp.flush()
- cmd = ["gcloud", "secrets", "versions", "add", name,
- "--project", GoogleSystemTest._project_id(),
- "--data-file", tmp.name]
+ cmd = [
+ "gcloud",
+ "secrets",
+ "versions",
+ "add",
+ name,
+ "--project",
+ GoogleSystemTest._project_id(),
+ "--data-file",
+ tmp.name,
+ ]
cls.execute_with_ctx(cmd, key=GCP_SECRET_MANAGER_KEY)
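
The `provide_gcp_context` change above replaces a backslash-continued `with` statement by one that Black can wrap inside the call parentheses, combining a temporary directory with an environment override. The pattern is easy to reuse on its own; a small standalone sketch follows (not from the patch — the `EXAMPLE_CONFIG_DIR` variable and the helper name are illustrative assumptions).

import os
import tempfile
from contextlib import contextmanager
from unittest import mock

@contextmanager
def provide_temp_config_dir(env_var="EXAMPLE_CONFIG_DIR"):
    # Create a throwaway directory and expose it through an environment variable
    # for the duration of the block, mirroring how provide_gcp_context points
    # CLOUD_SDK_CONFIG_DIR at a TemporaryDirectory.
    with tempfile.TemporaryDirectory() as tmp_dir, mock.patch.dict("os.environ", {env_var: tmp_dir}):
        yield tmp_dir

with provide_temp_config_dir() as config_dir:
    assert os.environ["EXAMPLE_CONFIG_DIR"] == config_dir
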
diff --git a/tests/test_utils/hdfs_utils.py b/tests/test_utils/hdfs_utils.py
index 91a4a464a870f..348396f84e2e4 100644
--- a/tests/test_utils/hdfs_utils.py
+++ b/tests/test_utils/hdfs_utils.py
@@ -33,7 +33,6 @@ class FakeSnakeBiteClientException(Exception):
class FakeSnakeBiteClient:
-
def __init__(self):
self.started = True
@@ -48,126 +47,142 @@ def ls(self, path, include_toplevel=False): # pylint: disable=invalid-name
if path[0] == '/datadirectory/empty_directory' and not include_toplevel:
return []
elif path[0] == '/datadirectory/datafile':
- return [{
- 'group': 'supergroup',
- 'permission': 420,
- 'file_type': 'f',
- 'access_time': 1481122343796,
- 'block_replication': 3,
- 'modification_time': 1481122343862,
- 'length': 0,
- 'blocksize': 134217728,
- 'owner': 'hdfs',
- 'path': '/datadirectory/datafile'
- }]
+ return [
+ {
+ 'group': 'supergroup',
+ 'permission': 420,
+ 'file_type': 'f',
+ 'access_time': 1481122343796,
+ 'block_replication': 3,
+ 'modification_time': 1481122343862,
+ 'length': 0,
+ 'blocksize': 134217728,
+ 'owner': 'hdfs',
+ 'path': '/datadirectory/datafile',
+ }
+ ]
elif path[0] == '/datadirectory/empty_directory' and include_toplevel:
- return [{
- 'group': 'supergroup',
- 'permission': 493,
- 'file_type': 'd',
- 'access_time': 0,
- 'block_replication': 0,
- 'modification_time': 1481132141540,
- 'length': 0,
- 'blocksize': 0,
- 'owner': 'hdfs',
- 'path': '/datadirectory/empty_directory'
- }]
+ return [
+ {
+ 'group': 'supergroup',
+ 'permission': 493,
+ 'file_type': 'd',
+ 'access_time': 0,
+ 'block_replication': 0,
+ 'modification_time': 1481132141540,
+ 'length': 0,
+ 'blocksize': 0,
+ 'owner': 'hdfs',
+ 'path': '/datadirectory/empty_directory',
+ }
+ ]
elif path[0] == '/datadirectory/not_empty_directory' and include_toplevel:
- return [{
- 'group': 'supergroup',
- 'permission': 493,
- 'file_type': 'd',
- 'access_time': 0,
- 'block_replication': 0,
- 'modification_time': 1481132141540,
- 'length': 0,
- 'blocksize': 0,
- 'owner': 'hdfs',
- 'path': '/datadirectory/empty_directory'
- }, {
- 'group': 'supergroup',
- 'permission': 420,
- 'file_type': 'f',
- 'access_time': 1481122343796,
- 'block_replication': 3,
- 'modification_time': 1481122343862,
- 'length': 0,
- 'blocksize': 134217728,
- 'owner': 'hdfs',
- 'path': '/datadirectory/not_empty_directory/test_file'
- }]
+ return [
+ {
+ 'group': 'supergroup',
+ 'permission': 493,
+ 'file_type': 'd',
+ 'access_time': 0,
+ 'block_replication': 0,
+ 'modification_time': 1481132141540,
+ 'length': 0,
+ 'blocksize': 0,
+ 'owner': 'hdfs',
+ 'path': '/datadirectory/empty_directory',
+ },
+ {
+ 'group': 'supergroup',
+ 'permission': 420,
+ 'file_type': 'f',
+ 'access_time': 1481122343796,
+ 'block_replication': 3,
+ 'modification_time': 1481122343862,
+ 'length': 0,
+ 'blocksize': 134217728,
+ 'owner': 'hdfs',
+ 'path': '/datadirectory/not_empty_directory/test_file',
+ },
+ ]
elif path[0] == '/datadirectory/not_empty_directory':
- return [{
- 'group': 'supergroup',
- 'permission': 420,
- 'file_type': 'f',
- 'access_time': 1481122343796,
- 'block_replication': 3,
- 'modification_time': 1481122343862,
- 'length': 0,
- 'blocksize': 134217728,
- 'owner': 'hdfs',
- 'path': '/datadirectory/not_empty_directory/test_file'
- }]
+ return [
+ {
+ 'group': 'supergroup',
+ 'permission': 420,
+ 'file_type': 'f',
+ 'access_time': 1481122343796,
+ 'block_replication': 3,
+ 'modification_time': 1481122343862,
+ 'length': 0,
+ 'blocksize': 134217728,
+ 'owner': 'hdfs',
+ 'path': '/datadirectory/not_empty_directory/test_file',
+ }
+ ]
elif path[0] == '/datadirectory/not_existing_file_or_directory':
raise FakeSnakeBiteClientException
elif path[0] == '/datadirectory/regex_dir':
- return [{
- 'group': 'supergroup',
- 'permission': 420,
- 'file_type': 'f',
- 'access_time': 1481122343796,
- 'block_replication': 3,
- 'modification_time': 1481122343862, 'length': 12582912,
- 'blocksize': 134217728,
- 'owner': 'hdfs',
- 'path': '/datadirectory/regex_dir/test1file'
- }, {
- 'group': 'supergroup',
- 'permission': 420,
- 'file_type': 'f',
- 'access_time': 1481122343796,
- 'block_replication': 3,
- 'modification_time': 1481122343862,
- 'length': 12582912,
- 'blocksize': 134217728,
- 'owner': 'hdfs',
- 'path': '/datadirectory/regex_dir/test2file'
- }, {
- 'group': 'supergroup',
- 'permission': 420,
- 'file_type': 'f',
- 'access_time': 1481122343796,
- 'block_replication': 3,
- 'modification_time': 1481122343862,
- 'length': 12582912,
- 'blocksize': 134217728,
- 'owner': 'hdfs',
- 'path': '/datadirectory/regex_dir/test3file'
- }, {
- 'group': 'supergroup',
- 'permission': 420,
- 'file_type': 'f',
- 'access_time': 1481122343796,
- 'block_replication': 3,
- 'modification_time': 1481122343862,
- 'length': 12582912,
- 'blocksize': 134217728,
- 'owner': 'hdfs',
- 'path': '/datadirectory/regex_dir/copying_file_1.txt._COPYING_'
- }, {
- 'group': 'supergroup',
- 'permission': 420,
- 'file_type': 'f',
- 'access_time': 1481122343796,
- 'block_replication': 3,
- 'modification_time': 1481122343862,
- 'length': 12582912,
- 'blocksize': 134217728,
- 'owner': 'hdfs',
- 'path': '/datadirectory/regex_dir/copying_file_3.txt.sftp'
- }]
+ return [
+ {
+ 'group': 'supergroup',
+ 'permission': 420,
+ 'file_type': 'f',
+ 'access_time': 1481122343796,
+ 'block_replication': 3,
+ 'modification_time': 1481122343862,
+ 'length': 12582912,
+ 'blocksize': 134217728,
+ 'owner': 'hdfs',
+ 'path': '/datadirectory/regex_dir/test1file',
+ },
+ {
+ 'group': 'supergroup',
+ 'permission': 420,
+ 'file_type': 'f',
+ 'access_time': 1481122343796,
+ 'block_replication': 3,
+ 'modification_time': 1481122343862,
+ 'length': 12582912,
+ 'blocksize': 134217728,
+ 'owner': 'hdfs',
+ 'path': '/datadirectory/regex_dir/test2file',
+ },
+ {
+ 'group': 'supergroup',
+ 'permission': 420,
+ 'file_type': 'f',
+ 'access_time': 1481122343796,
+ 'block_replication': 3,
+ 'modification_time': 1481122343862,
+ 'length': 12582912,
+ 'blocksize': 134217728,
+ 'owner': 'hdfs',
+ 'path': '/datadirectory/regex_dir/test3file',
+ },
+ {
+ 'group': 'supergroup',
+ 'permission': 420,
+ 'file_type': 'f',
+ 'access_time': 1481122343796,
+ 'block_replication': 3,
+ 'modification_time': 1481122343862,
+ 'length': 12582912,
+ 'blocksize': 134217728,
+ 'owner': 'hdfs',
+ 'path': '/datadirectory/regex_dir/copying_file_1.txt._COPYING_',
+ },
+ {
+ 'group': 'supergroup',
+ 'permission': 420,
+ 'file_type': 'f',
+ 'access_time': 1481122343796,
+ 'block_replication': 3,
+ 'modification_time': 1481122343862,
+ 'length': 12582912,
+ 'blocksize': 134217728,
+ 'owner': 'hdfs',
+ 'path': '/datadirectory/regex_dir/copying_file_3.txt.sftp',
+ },
+ ]
else:
raise FakeSnakeBiteClientException
diff --git a/tests/test_utils/logging_command_executor.py b/tests/test_utils/logging_command_executor.py
index 081b21247bcf6..1ebf729acd860 100644
--- a/tests/test_utils/logging_command_executor.py
+++ b/tests/test_utils/logging_command_executor.py
@@ -24,7 +24,6 @@
class LoggingCommandExecutor(LoggingMixin):
-
def execute_cmd(self, cmd, silent=False, cwd=None, env=None):
if silent:
self.log.info("Executing in silent mode: '%s'", " ".join([shlex.quote(c) for c in cmd]))
@@ -33,8 +32,12 @@ def execute_cmd(self, cmd, silent=False, cwd=None, env=None):
else:
self.log.info("Executing: '%s'", " ".join([shlex.quote(c) for c in cmd]))
process = subprocess.Popen(
- args=cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
- universal_newlines=True, cwd=cwd, env=env
+ args=cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ universal_newlines=True,
+ cwd=cwd,
+ env=env,
)
output, err = process.communicate()
retcode = process.poll()
@@ -46,16 +49,16 @@ def execute_cmd(self, cmd, silent=False, cwd=None, env=None):
def check_output(self, cmd):
self.log.info("Executing for output: '%s'", " ".join([shlex.quote(c) for c in cmd]))
- process = subprocess.Popen(args=cmd, stdout=subprocess.PIPE,
- stderr=subprocess.PIPE)
+ process = subprocess.Popen(args=cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, err = process.communicate()
retcode = process.poll()
if retcode:
self.log.error("Error when executing '%s'", " ".join([shlex.quote(c) for c in cmd]))
self.log.info("Stdout: %s", output)
self.log.info("Stderr: %s", err)
- raise AirflowException("Retcode {} on {} with stdout: {}, stderr: {}".
- format(retcode, " ".join(cmd), output, err))
+ raise AirflowException(
+ "Retcode {} on {} with stdout: {}, stderr: {}".format(retcode, " ".join(cmd), output, err)
+ )
return output
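
The executor hunk above only reflows the `subprocess.Popen` call and the raised `AirflowException`, but the capture-and-check pattern it formats is worth seeing in isolation. A minimal sketch, not part of the patch, assuming an `echo` binary is on PATH:

import shlex
import subprocess

cmd = ["echo", "hello world"]
process = subprocess.Popen(
    args=cmd,
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    universal_newlines=True,
)
output, err = process.communicate()
retcode = process.poll()
if retcode:
    raise RuntimeError(f"Retcode {retcode} on {' '.join(shlex.quote(c) for c in cmd)}: {err}")
print(output.strip())  # -> hello world
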
diff --git a/tests/test_utils/mock_hooks.py b/tests/test_utils/mock_hooks.py
index 900cf206720d7..2c6e2893f4822 100644
--- a/tests/test_utils/mock_hooks.py
+++ b/tests/test_utils/mock_hooks.py
@@ -27,8 +27,7 @@
class MockHiveMetastoreHook(HiveMetastoreHook):
def __init__(self, *args, **kwargs):
self._find_valid_server = mock.MagicMock(return_value={})
- self.get_metastore_client = mock.MagicMock(
- return_value=mock.MagicMock())
+ self.get_metastore_client = mock.MagicMock(return_value=mock.MagicMock())
super().__init__()
@@ -64,8 +63,7 @@ def __init__(self, *args, **kwargs):
self.conn.execute = mock.MagicMock()
self.get_conn = mock.MagicMock(return_value=self.conn)
- self.get_first = mock.MagicMock(
- return_value=[['val_0', 'val_1'], 'val_2'])
+ self.get_first = mock.MagicMock(return_value=[['val_0', 'val_1'], 'val_2'])
super().__init__(*args, **kwargs)
diff --git a/tests/test_utils/mock_operators.py b/tests/test_utils/mock_operators.py
index ce410876413bf..534770ea4100b 100644
--- a/tests/test_utils/mock_operators.py
+++ b/tests/test_utils/mock_operators.py
@@ -51,6 +51,7 @@ class AirflowLink(BaseOperatorLink):
"""
Operator Link for Apache Airflow Website
"""
+
name = 'airflow'
def get_link(self, operator, dttm):
@@ -62,9 +63,8 @@ class Dummy2TestOperator(BaseOperator):
Example of an Operator that has an extra operator link
and will be overridden by the one defined in tests/plugins/test_plugin.py
"""
- operator_extra_links = (
- AirflowLink(),
- )
+
+ operator_extra_links = (AirflowLink(),)
class Dummy3TestOperator(BaseOperator):
@@ -72,6 +72,7 @@ class Dummy3TestOperator(BaseOperator):
Example of an operator that has no extra Operator link.
An operator link would be added to this operator via Airflow plugin
"""
+
operator_extra_links = ()
@@ -113,12 +114,8 @@ def operator_extra_links(self):
Return operator extra links
"""
if isinstance(self.bash_command, str) or self.bash_command is None:
- return (
- CustomOpLink(),
- )
- return (
- CustomBaseIndexOpLink(i) for i, _ in enumerate(self.bash_command)
- )
+ return (CustomOpLink(),)
+ return (CustomBaseIndexOpLink(i) for i, _ in enumerate(self.bash_command))
@apply_defaults
def __init__(self, bash_command=None, **kwargs):
@@ -134,6 +131,7 @@ class GoogleLink(BaseOperatorLink):
"""
Operator Link for Apache Airflow Website for Google
"""
+
name = 'google'
operators = [Dummy3TestOperator, CustomOperator]
@@ -145,6 +143,7 @@ class AirflowLink2(BaseOperatorLink):
"""
Operator Link for Apache Airflow Website for 1.10.5
"""
+
name = 'airflow'
operators = [Dummy2TestOperator, Dummy3TestOperator]
@@ -156,6 +155,7 @@ class GithubLink(BaseOperatorLink):
"""
Operator Link for Apache Airflow GitHub
"""
+
name = 'github'
def get_link(self, operator, dttm):
diff --git a/tests/test_utils/mock_process.py b/tests/test_utils/mock_process.py
index abd1e42b969da..4031d55d8d67e 100644
--- a/tests/test_utils/mock_process.py
+++ b/tests/test_utils/mock_process.py
@@ -25,8 +25,7 @@ def __init__(self, extra_dejson=None, *args, **kwargs):
self.get_records = mock.MagicMock(return_value=[['test_record']])
output = kwargs.get('output', ['' for _ in range(10)])
- self.readline = mock.MagicMock(
- side_effect=[line.encode() for line in output])
+ self.readline = mock.MagicMock(side_effect=[line.encode() for line in output])
def status(self, *args, **kwargs):
return True
@@ -35,8 +34,7 @@ def status(self, *args, **kwargs):
class MockStdOut:
def __init__(self, *args, **kwargs):
output = kwargs.get('output', ['' for _ in range(10)])
- self.readline = mock.MagicMock(
- side_effect=[line.encode() for line in output])
+ self.readline = mock.MagicMock(side_effect=[line.encode() for line in output])
class MockSubProcess:
@@ -54,8 +52,10 @@ def wait(self):
class MockConnectionCursor:
def __init__(self, *args, **kwargs):
self.arraysize = None
- self.description = [('hive_server_hook.a', 'INT_TYPE', None, None, None, None, True),
- ('hive_server_hook.b', 'INT_TYPE', None, None, None, None, True)]
+ self.description = [
+ ('hive_server_hook.a', 'INT_TYPE', None, None, None, None, True),
+ ('hive_server_hook.b', 'INT_TYPE', None, None, None, None, True),
+ ]
self.iterable = [(1, 1), (2, 2)]
self.conn_exists = kwargs.get('exists', True)
diff --git a/tests/test_utils/perf/dags/elastic_dag.py b/tests/test_utils/perf/dags/elastic_dag.py
index 4de9cd30f5c8b..9aa0a4d621c8c 100644
--- a/tests/test_utils/perf/dags/elastic_dag.py
+++ b/tests/test_utils/perf/dags/elastic_dag.py
@@ -141,6 +141,7 @@ class DagShape(Enum):
"""
Define shape of the Dag that will be used for testing.
"""
+
NO_STRUCTURE = "no_structure"
LINEAR = "linear"
BINARY_TREE = "binary_tree"
@@ -168,25 +169,25 @@ class DagShape(Enum):
for dag_no in range(1, DAG_COUNT + 1):
dag = DAG(
- dag_id=safe_dag_id("__".join(
- [
- DAG_PREFIX,
- f"SHAPE={SHAPE.name.lower()}",
- f"DAGS_COUNT={dag_no}_of_{DAG_COUNT}",
- f"TASKS_COUNT=${TASKS_COUNT}",
- f"START_DATE=${START_DATE_ENV}",
- f"SCHEDULE_INTERVAL=${SCHEDULE_INTERVAL_ENV}",
- ]
- )),
+ dag_id=safe_dag_id(
+ "__".join(
+ [
+ DAG_PREFIX,
+ f"SHAPE={SHAPE.name.lower()}",
+ f"DAGS_COUNT={dag_no}_of_{DAG_COUNT}",
+ f"TASKS_COUNT=${TASKS_COUNT}",
+ f"START_DATE=${START_DATE_ENV}",
+ f"SCHEDULE_INTERVAL=${SCHEDULE_INTERVAL_ENV}",
+ ]
+ )
+ ),
is_paused_upon_creation=False,
default_args=args,
schedule_interval=SCHEDULE_INTERVAL,
)
elastic_dag_tasks = [
- BashOperator(
- task_id="__".join(["tasks", f"{i}_of_{TASKS_COUNT}"]), bash_command='echo test', dag=dag
- )
+ BashOperator(task_id="__".join(["tasks", f"{i}_of_{TASKS_COUNT}"]), bash_command='echo test', dag=dag)
for i in range(1, TASKS_COUNT + 1)
]
diff --git a/tests/test_utils/perf/dags/perf_dag_1.py b/tests/test_utils/perf/dags/perf_dag_1.py
index 021a910b54433..3757c7d40e092 100644
--- a/tests/test_utils/perf/dags/perf_dag_1.py
+++ b/tests/test_utils/perf/dags/perf_dag_1.py
@@ -30,14 +30,12 @@
}
dag = DAG(
- dag_id='perf_dag_1', default_args=args,
- schedule_interval='@daily',
- dagrun_timeout=timedelta(minutes=60))
+ dag_id='perf_dag_1', default_args=args, schedule_interval='@daily', dagrun_timeout=timedelta(minutes=60)
+)
task_1 = BashOperator(
- task_id='perf_task_1',
- bash_command='sleep 5; echo "run_id={{ run_id }} | dag_run={{ dag_run }}"',
- dag=dag)
+ task_id='perf_task_1', bash_command='sleep 5; echo "run_id={{ run_id }} | dag_run={{ dag_run }}"', dag=dag
+)
for i in range(2, 5):
task = BashOperator(
@@ -45,7 +43,8 @@
bash_command='''
sleep 5; echo "run_id={{ run_id }} | dag_run={{ dag_run }}"
''',
- dag=dag)
+ dag=dag,
+ )
task.set_upstream(task_1)
if __name__ == "__main__":
diff --git a/tests/test_utils/perf/dags/perf_dag_2.py b/tests/test_utils/perf/dags/perf_dag_2.py
index d9ef47efa1195..a0e8bba9429fe 100644
--- a/tests/test_utils/perf/dags/perf_dag_2.py
+++ b/tests/test_utils/perf/dags/perf_dag_2.py
@@ -30,14 +30,12 @@
}
dag = DAG(
- dag_id='perf_dag_2', default_args=args,
- schedule_interval='@daily',
- dagrun_timeout=timedelta(minutes=60))
+ dag_id='perf_dag_2', default_args=args, schedule_interval='@daily', dagrun_timeout=timedelta(minutes=60)
+)
task_1 = BashOperator(
- task_id='perf_task_1',
- bash_command='sleep 5; echo "run_id={{ run_id }} | dag_run={{ dag_run }}"',
- dag=dag)
+ task_id='perf_task_1', bash_command='sleep 5; echo "run_id={{ run_id }} | dag_run={{ dag_run }}"', dag=dag
+)
for i in range(2, 5):
task = BashOperator(
@@ -45,7 +43,8 @@
bash_command='''
sleep 5; echo "run_id={{ run_id }} | dag_run={{ dag_run }}"
''',
- dag=dag)
+ dag=dag,
+ )
task.set_upstream(task_1)
if __name__ == "__main__":
diff --git a/tests/test_utils/perf/perf_kit/memory.py b/tests/test_utils/perf/perf_kit/memory.py
index f84c505efafd3..5236e236d05db 100644
--- a/tests/test_utils/perf/perf_kit/memory.py
+++ b/tests/test_utils/perf/perf_kit/memory.py
@@ -36,6 +36,7 @@ def _human_readable_size(size, decimal_places=3):
class TraceMemoryResult:
"""Trace results of memory,"""
+
def __init__(self):
self.before = 0
self.after = 0
diff --git a/tests/test_utils/perf/perf_kit/repeat_and_time.py b/tests/test_utils/perf/perf_kit/repeat_and_time.py
index 8efd7f10b6bda..b0434b1c579d0 100644
--- a/tests/test_utils/perf/perf_kit/repeat_and_time.py
+++ b/tests/test_utils/perf/perf_kit/repeat_and_time.py
@@ -24,6 +24,7 @@
class TimingResult:
"""Timing result."""
+
def __init__(self):
self.start_time = 0
self.end_time = 0
@@ -86,6 +87,7 @@ def timeout(seconds=1):
:param seconds: Number of seconds
"""
+
def handle_timeout(signum, frame):
raise TimeoutException("Process timed out.")
diff --git a/tests/test_utils/perf/perf_kit/sqlalchemy.py b/tests/test_utils/perf/perf_kit/sqlalchemy.py
index e5c7c359cec90..b0c5a3ed70f5e 100644
--- a/tests/test_utils/perf/perf_kit/sqlalchemy.py
+++ b/tests/test_utils/perf/perf_kit/sqlalchemy.py
@@ -26,9 +26,8 @@ def _pretty_format_sql(text: str):
import pygments
from pygments.formatters.terminal import TerminalFormatter
from pygments.lexers.sql import SqlLexer
- text = pygments.highlight(
- code=text, formatter=TerminalFormatter(), lexer=SqlLexer()
- ).rstrip()
+
+ text = pygments.highlight(code=text, formatter=TerminalFormatter(), lexer=SqlLexer()).rstrip()
return text
@@ -43,6 +42,7 @@ class TraceQueries:
:param display_parameters: If True, display SQL statement parameters
:param print_fn: The function used to display the text. By default,``builtins.print``
"""
+
def __init__(
self,
*,
@@ -51,7 +51,7 @@ def __init__(
display_trace: bool = True,
display_sql: bool = False,
display_parameters: bool = True,
- print_fn: Callable[[str], None] = print
+ print_fn: Callable[[str], None] = print,
):
self.display_num = display_num
self.display_time = display_time
@@ -61,13 +61,15 @@ def __init__(
self.print_fn = print_fn
self.query_count = 0
- def before_cursor_execute(self,
- conn,
- cursor, # pylint: disable=unused-argument
- statement, # pylint: disable=unused-argument
- parameters, # pylint: disable=unused-argument
- context, # pylint: disable=unused-argument
- executemany): # pylint: disable=unused-argument
+ def before_cursor_execute(
+ self,
+ conn,
+ cursor, # pylint: disable=unused-argument
+ statement, # pylint: disable=unused-argument
+ parameters, # pylint: disable=unused-argument
+ context, # pylint: disable=unused-argument
+ executemany,
+ ): # pylint: disable=unused-argument
"""
Executed before cursor.
@@ -83,13 +85,15 @@ def before_cursor_execute(self,
conn.info.setdefault("query_start_time", []).append(time.monotonic())
self.query_count += 1
- def after_cursor_execute(self,
- conn,
- cursor, # pylint: disable=unused-argument
- statement,
- parameters,
- context, # pylint: disable=unused-argument
- executemany): # pylint: disable=unused-argument
+ def after_cursor_execute(
+ self,
+ conn,
+ cursor, # pylint: disable=unused-argument
+ statement,
+ parameters,
+ context, # pylint: disable=unused-argument
+ executemany,
+ ): # pylint: disable=unused-argument
"""
Executed after cursor.
@@ -139,6 +143,7 @@ def __enter__(self):
def __exit__(self, type_, value, traceback): # noqa pylint: disable=redefined-outer-name
import airflow.settings
+
event.remove(airflow.settings.engine, "before_cursor_execute", self.before_cursor_execute)
event.remove(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute)
@@ -150,6 +155,7 @@ class CountQueriesResult:
"""
Counter for number of queries.
"""
+
def __init__(self):
self.count = 0
@@ -180,13 +186,15 @@ def __exit__(self, type_, value, traceback): # noqa pylint: disable=redefined-o
event.remove(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute)
self.print_fn(f"Count SQL queries: {self.result.count}")
- def after_cursor_execute(self,
- conn, # pylint: disable=unused-argument
- cursor, # pylint: disable=unused-argument
- statement, # pylint: disable=unused-argument
- parameters, # pylint: disable=unused-argument
- context, # pylint: disable=unused-argument
- executemany): # pylint: disable=unused-argument
+ def after_cursor_execute(
+ self,
+ conn, # pylint: disable=unused-argument
+ cursor, # pylint: disable=unused-argument
+ statement, # pylint: disable=unused-argument
+ parameters, # pylint: disable=unused-argument
+ context, # pylint: disable=unused-argument
+ executemany,
+ ): # pylint: disable=unused-argument
"""
Executed after cursor.
@@ -212,13 +220,16 @@ def case():
from airflow.jobs.scheduler_job import DagFileProcessor
- with mock.patch.dict("os.environ", {
- "PERF_DAGS_COUNT": "200",
- "PERF_TASKS_COUNT": "10",
- "PERF_START_AGO": "2d",
- "PERF_SCHEDULE_INTERVAL": "None",
- "PERF_SHAPE": "no_structure",
- }):
+ with mock.patch.dict(
+ "os.environ",
+ {
+ "PERF_DAGS_COUNT": "200",
+ "PERF_TASKS_COUNT": "10",
+ "PERF_START_AGO": "2d",
+ "PERF_SCHEDULE_INTERVAL": "None",
+ "PERF_SHAPE": "no_structure",
+ },
+ ):
log = logging.getLogger(__name__)
processor = DagFileProcessor(dag_ids=[], log=log)
dag_file = os.path.join(os.path.dirname(__file__), os.path.pardir, "dags", "elastic_dag.py")
diff --git a/tests/test_utils/perf/scheduler_dag_execution_timing.py b/tests/test_utils/perf/scheduler_dag_execution_timing.py
index 8b9b979357e44..87ea15f5a05c7 100755
--- a/tests/test_utils/perf/scheduler_dag_execution_timing.py
+++ b/tests/test_utils/perf/scheduler_dag_execution_timing.py
@@ -33,6 +33,7 @@ class ShortCircuitExecutorMixin:
"""
Mixin class to manage the scheduler state during the performance test run.
"""
+
def __init__(self, dag_ids_to_watch, num_runs):
super().__init__()
self.num_runs_per_dag = num_runs
@@ -48,8 +49,9 @@ def reset(self, dag_ids_to_watch):
# A "cache" of DagRun row, so we don't have to look it up each
# time. This is to try and reduce the impact of our
# benchmarking code on runtime,
- runs={}
- ) for dag_id in dag_ids_to_watch
+ runs={},
+ )
+ for dag_id in dag_ids_to_watch
}
def change_state(self, key, state, info=None):
@@ -58,6 +60,7 @@ def change_state(self, key, state, info=None):
and then shut down the scheduler after the task is complete
"""
from airflow.utils.state import State
+
super().change_state(key, state, info=info)
dag_id, _, execution_date, __ = key
@@ -87,8 +90,9 @@ def change_state(self, key, state, info=None):
self.log.warning("STOPPING SCHEDULER -- all runs complete")
self.scheduler_job.processor_agent._done = True # pylint: disable=protected-access
return
- self.log.warning("WAITING ON %d RUNS",
- sum(map(attrgetter('waiting_for'), self.dags_to_watch.values())))
+ self.log.warning(
+ "WAITING ON %d RUNS", sum(map(attrgetter('waiting_for'), self.dags_to_watch.values()))
+ )
def get_executor_under_test(dotted_path):
@@ -114,6 +118,7 @@ class ShortCircuitExecutor(ShortCircuitExecutorMixin, executor_cls):
"""
Placeholder class that implements the inheritance hierarchy
"""
+
scheduler_job = None
return ShortCircuitExecutor
@@ -142,6 +147,7 @@ def pause_all_dags(session):
Pause all Dags
"""
from airflow.models.dag import DagModel
+
session.query(DagModel).update({'is_paused': True})
@@ -154,9 +160,11 @@ def create_dag_runs(dag, num_runs, session):
try:
from airflow.utils.types import DagRunType
+
id_prefix = f'{DagRunType.SCHEDULED.value}__'
except ImportError:
from airflow.models.dagrun import DagRun
+
id_prefix = DagRun.ID_PREFIX # pylint: disable=no-member
next_run_date = dag.normalize_schedule(dag.start_date or min(t.start_date for t in dag.tasks))
@@ -176,18 +184,27 @@ def create_dag_runs(dag, num_runs, session):
@click.command()
@click.option('--num-runs', default=1, help='number of DagRun, to run for each DAG')
@click.option('--repeat', default=3, help='number of times to run test, to reduce variance')
-@click.option('--pre-create-dag-runs', is_flag=True, default=False,
- help='''Pre-create the dag runs and stop the scheduler creating more.
+@click.option(
+ '--pre-create-dag-runs',
+ is_flag=True,
+ default=False,
+ help='''Pre-create the dag runs and stop the scheduler creating more.
Warning: this makes the scheduler do (slightly) less work so may skew your numbers. Use sparingly!
- ''')
-@click.option('--executor-class', default='MockExecutor',
- help=textwrap.dedent('''
+ ''',
+)
+@click.option(
+ '--executor-class',
+ default='MockExecutor',
+ help=textwrap.dedent(
+ '''
Dotted path Executor class to test, for example
'airflow.executors.local_executor.LocalExecutor'. Defaults to MockExecutor which doesn't run tasks.
- '''))
+ '''
+    ),
+)
@click.argument('dag_ids', required=True, nargs=-1)
-def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids): # pylint: disable=too-many-locals
+def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids):  # pylint: disable=too-many-locals
"""
This script can be used to measure the total "scheduler overhead" of Airflow.
@@ -250,7 +267,8 @@ def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids): # pyl
message = (
f"DAG {dag_id} has incorrect end_date ({end_date}) for number of runs! "
f"It should be "
- f" {next_run_date}")
+ f" {next_run_date}"
+ )
sys.exit(message)
if pre_create_dag_runs:
@@ -298,13 +316,10 @@ def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids): # pyl
msg = "Time for %d dag runs of %d dags with %d total tasks: %.4fs"
if len(times) > 1:
- print((msg + " (±%.3fs)") % (
- num_runs,
- len(dags),
- total_tasks,
- statistics.mean(times),
- statistics.stdev(times)
- ))
+ print(
+ (msg + " (±%.3fs)")
+ % (num_runs, len(dags), total_tasks, statistics.mean(times), statistics.stdev(times))
+ )
else:
print(msg % (num_runs, len(dags), total_tasks, times[0]))
diff --git a/tests/test_utils/perf/scheduler_ops_metrics.py b/tests/test_utils/perf/scheduler_ops_metrics.py
index 0d3b4a474c640..a0ffd420c1ae9 100644
--- a/tests/test_utils/perf/scheduler_ops_metrics.py
+++ b/tests/test_utils/perf/scheduler_ops_metrics.py
@@ -59,9 +59,8 @@ class SchedulerMetricsJob(SchedulerJob):
You can specify timeout in seconds as an optional parameter.
Its default value is 6 seconds.
"""
- __mapper_args__ = {
- 'polymorphic_identity': 'SchedulerMetricsJob'
- }
+
+ __mapper_args__ = {'polymorphic_identity': 'SchedulerMetricsJob'}
def __init__(self, dag_ids, subdir, max_runtime_secs):
self.max_runtime_secs = max_runtime_secs
@@ -73,23 +72,32 @@ def print_stats(self):
"""
session = settings.Session()
TI = TaskInstance
- tis = (
- session
- .query(TI)
- .filter(TI.dag_id.in_(DAG_IDS))
- .all()
- )
+ tis = session.query(TI).filter(TI.dag_id.in_(DAG_IDS)).all()
successful_tis = [x for x in tis if x.state == State.SUCCESS]
- ti_perf = [(ti.dag_id, ti.task_id, ti.execution_date,
- (ti.queued_dttm - self.start_date).total_seconds(),
- (ti.start_date - self.start_date).total_seconds(),
- (ti.end_date - self.start_date).total_seconds(),
- ti.duration) for ti in successful_tis]
- ti_perf_df = pd.DataFrame(ti_perf, columns=['dag_id', 'task_id',
- 'execution_date',
- 'queue_delay',
- 'start_delay', 'land_time',
- 'duration'])
+ ti_perf = [
+ (
+ ti.dag_id,
+ ti.task_id,
+ ti.execution_date,
+ (ti.queued_dttm - self.start_date).total_seconds(),
+ (ti.start_date - self.start_date).total_seconds(),
+ (ti.end_date - self.start_date).total_seconds(),
+ ti.duration,
+ )
+ for ti in successful_tis
+ ]
+ ti_perf_df = pd.DataFrame(
+ ti_perf,
+ columns=[
+ 'dag_id',
+ 'task_id',
+ 'execution_date',
+ 'queue_delay',
+ 'start_delay',
+ 'land_time',
+ 'duration',
+ ],
+ )
print('Performance Results')
print('###################')
@@ -99,9 +107,15 @@ def print_stats(self):
print('###################')
if len(tis) > len(successful_tis):
print("WARNING!! The following task instances haven't completed")
- print(pd.DataFrame([(ti.dag_id, ti.task_id, ti.execution_date, ti.state)
- for ti in filter(lambda x: x.state != State.SUCCESS, tis)],
- columns=['dag_id', 'task_id', 'execution_date', 'state']))
+ print(
+ pd.DataFrame(
+ [
+ (ti.dag_id, ti.task_id, ti.execution_date, ti.state)
+ for ti in filter(lambda x: x.state != State.SUCCESS, tis)
+ ],
+ columns=['dag_id', 'task_id', 'execution_date', 'state'],
+ )
+ )
session.commit()
@@ -114,23 +128,21 @@ def heartbeat(self):
# Get all the relevant task instances
TI = TaskInstance
successful_tis = (
- session
- .query(TI)
- .filter(TI.dag_id.in_(DAG_IDS))
- .filter(TI.state.in_([State.SUCCESS]))
- .all()
+ session.query(TI).filter(TI.dag_id.in_(DAG_IDS)).filter(TI.state.in_([State.SUCCESS])).all()
)
session.commit()
dagbag = DagBag(SUBDIR)
dags = [dagbag.dags[dag_id] for dag_id in DAG_IDS]
# the tasks in perf_dag_1 and per_dag_2 have a daily schedule interval.
- num_task_instances = sum([(timezone.utcnow() - task.start_date).days
- for dag in dags for task in dag.tasks])
+ num_task_instances = sum(
+ [(timezone.utcnow() - task.start_date).days for dag in dags for task in dag.tasks]
+ )
- if (len(successful_tis) == num_task_instances or
- (timezone.utcnow() - self.start_date).total_seconds() >
- self.max_runtime_secs):
+ if (
+ len(successful_tis) == num_task_instances
+ or (timezone.utcnow() - self.start_date).total_seconds() > self.max_runtime_secs
+ ):
if len(successful_tis) == num_task_instances:
self.log.info("All tasks processed! Printing stats.")
else:
@@ -145,9 +157,13 @@ def clear_dag_runs():
Remove any existing DAG runs for the perf test DAGs.
"""
session = settings.Session()
- drs = session.query(DagRun).filter(
- DagRun.dag_id.in_(DAG_IDS),
- ).all()
+ drs = (
+ session.query(DagRun)
+ .filter(
+ DagRun.dag_id.in_(DAG_IDS),
+ )
+ .all()
+ )
for dr in drs:
logging.info('Deleting DagRun :: %s', dr)
session.delete(dr)
@@ -159,12 +175,7 @@ def clear_dag_task_instances():
"""
session = settings.Session()
TI = TaskInstance
- tis = (
- session
- .query(TI)
- .filter(TI.dag_id.in_(DAG_IDS))
- .all()
- )
+ tis = session.query(TI).filter(TI.dag_id.in_(DAG_IDS)).all()
for ti in tis:
logging.info('Deleting TaskInstance :: %s', ti)
session.delete(ti)
@@ -176,8 +187,7 @@ def set_dags_paused_state(is_paused):
Toggle the pause state of the DAGs in the test.
"""
session = settings.Session()
- dag_models = session.query(DagModel).filter(
- DagModel.dag_id.in_(DAG_IDS))
+ dag_models = session.query(DagModel).filter(DagModel.dag_id.in_(DAG_IDS))
for dag_model in dag_models:
logging.info('Setting DAG :: %s is_paused=%s', dag_model, is_paused)
dag_model.is_paused = is_paused
@@ -205,10 +215,7 @@ def main():
clear_dag_runs()
clear_dag_task_instances()
- job = SchedulerMetricsJob(
- dag_ids=DAG_IDS,
- subdir=SUBDIR,
- max_runtime_secs=max_runtime_secs)
+ job = SchedulerMetricsJob(dag_ids=DAG_IDS, subdir=SUBDIR, max_runtime_secs=max_runtime_secs)
job.run()
diff --git a/tests/test_utils/perf/sql_queries.py b/tests/test_utils/perf/sql_queries.py
index fc3320cb81823..a8e57aa7fe184 100644
--- a/tests/test_utils/perf/sql_queries.py
+++ b/tests/test_utils/perf/sql_queries.py
@@ -34,9 +34,7 @@
LOG_LEVEL = "INFO"
LOG_FILE = "/files/sql_stats.log" # Default to run in Breeze
-os.environ[
- "AIRFLOW__LOGGING__LOGGING_CONFIG_CLASS"
-] = "scripts.perf.sql_queries.DEBUG_LOGGING_CONFIG"
+os.environ["AIRFLOW__LOGGING__LOGGING_CONFIG_CLASS"] = "scripts.perf.sql_queries.DEBUG_LOGGING_CONFIG"
DEBUG_LOGGING_CONFIG = {
"version": 1,
@@ -76,6 +74,7 @@ class Query(NamedTuple):
"""
Define attributes of the queries that will be picked up by the performance tests.
"""
+
function: str
file: str
location: int
@@ -110,6 +109,7 @@ def reset_db():
Wrapper function that calls the airflows resetdb function.
"""
from airflow.utils.db import resetdb
+
resetdb()
diff --git a/tests/test_utils/remote_user_api_auth_backend.py b/tests/test_utils/remote_user_api_auth_backend.py
index 3793831d1c953..e5f189a0af94b 100644
--- a/tests/test_utils/remote_user_api_auth_backend.py
+++ b/tests/test_utils/remote_user_api_auth_backend.py
@@ -38,9 +38,8 @@ def init_app(_):
def _lookup_user(user_email_or_username: str):
security_manager = current_app.appbuilder.sm
- user = (
- security_manager.find_user(email=user_email_or_username)
- or security_manager.find_user(username=user_email_or_username)
+ user = security_manager.find_user(email=user_email_or_username) or security_manager.find_user(
+ username=user_email_or_username
)
if not user:
return None
@@ -53,6 +52,7 @@ def _lookup_user(user_email_or_username: str):
def requires_authentication(function: T):
"""Decorator for functions that require authentication"""
+
@wraps(function)
def decorated(*args, **kwargs):
user_id = request.remote_user
diff --git a/tests/test_utils/test_remote_user_api_auth_backend.py b/tests/test_utils/test_remote_user_api_auth_backend.py
index c63d0f46fdf5a..966412fbf4e97 100644
--- a/tests/test_utils/test_remote_user_api_auth_backend.py
+++ b/tests/test_utils/test_remote_user_api_auth_backend.py
@@ -49,9 +49,7 @@ def test_success_using_username(self):
clear_db_pools()
with self.app.test_client() as test_client:
- response = test_client.get(
- "/api/experimental/pools", environ_overrides={'REMOTE_USER': "test"}
- )
+ response = test_client.get("/api/experimental/pools", environ_overrides={'REMOTE_USER': "test"})
self.assertEqual("test@fab.org", current_user.email)
self.assertEqual(200, response.status_code)
diff --git a/tests/ti_deps/deps/fake_models.py b/tests/ti_deps/deps/fake_models.py
index a7f16b2fa44b7..63346101c8aeb 100644
--- a/tests/ti_deps/deps/fake_models.py
+++ b/tests/ti_deps/deps/fake_models.py
@@ -20,7 +20,6 @@
class FakeTI:
-
def __init__(self, **kwds):
self.__dict__.update(kwds)
@@ -32,13 +31,11 @@ def are_dependents_done(self, session): # pylint: disable=unused-argument
class FakeTask:
-
def __init__(self, **kwds):
self.__dict__.update(kwds)
class FakeDag:
-
def __init__(self, **kwds):
self.__dict__.update(kwds)
@@ -47,6 +44,5 @@ def get_running_dagruns(self, _):
class FakeContext:
-
def __init__(self, **kwds):
self.__dict__.update(kwds)
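
These hunks all make the same change: Black drops the blank line that used to sit between the class statement and its first method. A small runnable sketch of the resulting layout (FakeTask here is just a stand-in mirroring the fake models above):

class FakeTask:
    def __init__(self, **kwds):
        # No blank line between the class header and __init__ once Black runs.
        self.__dict__.update(kwds)


task = FakeTask(task_id="op1", retries=0)
print(task.task_id)
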
diff --git a/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py b/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py
index 3b27fe63be376..f6fba610f6710 100644
--- a/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py
+++ b/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py
@@ -24,7 +24,6 @@
class TestDagTISlotsAvailableDep(unittest.TestCase):
-
def test_concurrency_reached(self):
"""
Test concurrency reached should fail dep
diff --git a/tests/ti_deps/deps/test_dag_unpaused_dep.py b/tests/ti_deps/deps/test_dag_unpaused_dep.py
index a397af0498074..1303f616c9370 100644
--- a/tests/ti_deps/deps/test_dag_unpaused_dep.py
+++ b/tests/ti_deps/deps/test_dag_unpaused_dep.py
@@ -24,7 +24,6 @@
class TestDagUnpausedDep(unittest.TestCase):
-
def test_concurrency_reached(self):
"""
Test paused DAG should fail dependency
diff --git a/tests/ti_deps/deps/test_dagrun_exists_dep.py b/tests/ti_deps/deps/test_dagrun_exists_dep.py
index 567a3aa6508e6..a1696c0794810 100644
--- a/tests/ti_deps/deps/test_dagrun_exists_dep.py
+++ b/tests/ti_deps/deps/test_dagrun_exists_dep.py
@@ -25,7 +25,6 @@
class TestDagrunRunningDep(unittest.TestCase):
-
@patch('airflow.models.DagRun.find', return_value=())
def test_dagrun_doesnt_exist(self, mock_dagrun_find):
"""
diff --git a/tests/ti_deps/deps/test_not_in_retry_period_dep.py b/tests/ti_deps/deps/test_not_in_retry_period_dep.py
index bab5372ec1d86..2fd28d32e9675 100644
--- a/tests/ti_deps/deps/test_not_in_retry_period_dep.py
+++ b/tests/ti_deps/deps/test_not_in_retry_period_dep.py
@@ -29,9 +29,7 @@
class TestNotInRetryPeriodDep(unittest.TestCase):
-
- def _get_task_instance(self, state, end_date=None,
- retry_delay=timedelta(minutes=15)):
+ def _get_task_instance(self, state, end_date=None, retry_delay=timedelta(minutes=15)):
task = Mock(retry_delay=retry_delay, retry_exponential_backoff=False)
ti = TaskInstance(task=task, state=state, execution_date=None)
ti.end_date = end_date
@@ -42,8 +40,7 @@ def test_still_in_retry_period(self):
"""
Task instances that are in their retry period should fail this dep
"""
- ti = self._get_task_instance(State.UP_FOR_RETRY,
- end_date=datetime(2016, 1, 1, 15, 30))
+ ti = self._get_task_instance(State.UP_FOR_RETRY, end_date=datetime(2016, 1, 1, 15, 30))
self.assertTrue(ti.is_premature)
self.assertFalse(NotInRetryPeriodDep().is_met(ti=ti))
@@ -52,8 +49,7 @@ def test_retry_period_finished(self):
"""
        Task instances that have had their retry period elapse should pass this dep
"""
- ti = self._get_task_instance(State.UP_FOR_RETRY,
- end_date=datetime(2016, 1, 1))
+ ti = self._get_task_instance(State.UP_FOR_RETRY, end_date=datetime(2016, 1, 1))
self.assertFalse(ti.is_premature)
self.assertTrue(NotInRetryPeriodDep().is_met(ti=ti))
diff --git a/tests/ti_deps/deps/test_not_previously_skipped_dep.py b/tests/ti_deps/deps/test_not_previously_skipped_dep.py
index 658fee3644076..9ddd0287c78d1 100644
--- a/tests/ti_deps/deps/test_not_previously_skipped_dep.py
+++ b/tests/ti_deps/deps/test_not_previously_skipped_dep.py
@@ -50,9 +50,7 @@ def test_no_skipmixin_parent():
A simple DAG with no branching. Both op1 and op2 are DummyOperator. NotPreviouslySkippedDep is met.
"""
start_date = pendulum.datetime(2020, 1, 1)
- dag = DAG(
- "test_no_skipmixin_parent_dag", schedule_interval=None, start_date=start_date
- )
+ dag = DAG("test_no_skipmixin_parent_dag", schedule_interval=None, start_date=start_date)
op1 = DummyOperator(task_id="op1", dag=dag)
op2 = DummyOperator(task_id="op2", dag=dag)
op1 >> op2
@@ -71,9 +69,7 @@ def test_parent_follow_branch():
A simple DAG with a BranchPythonOperator that follows op2. NotPreviouslySkippedDep is met.
"""
start_date = pendulum.datetime(2020, 1, 1)
- dag = DAG(
- "test_parent_follow_branch_dag", schedule_interval=None, start_date=start_date
- )
+ dag = DAG("test_parent_follow_branch_dag", schedule_interval=None, start_date=start_date)
dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=start_date)
op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op2", dag=dag)
op2 = DummyOperator(task_id="op2", dag=dag)
@@ -96,9 +92,7 @@ def test_parent_skip_branch():
session.query(DagRun).delete()
session.query(TaskInstance).delete()
start_date = pendulum.datetime(2020, 1, 1)
- dag = DAG(
- "test_parent_skip_branch_dag", schedule_interval=None, start_date=start_date
- )
+ dag = DAG("test_parent_skip_branch_dag", schedule_interval=None, start_date=start_date)
dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=start_date)
op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3", dag=dag)
op2 = DummyOperator(task_id="op2", dag=dag)
@@ -120,9 +114,7 @@ def test_parent_not_executed():
executed (no xcom data). NotPreviouslySkippedDep is met (no decision).
"""
start_date = pendulum.datetime(2020, 1, 1)
- dag = DAG(
- "test_parent_not_executed_dag", schedule_interval=None, start_date=start_date
- )
+ dag = DAG("test_parent_not_executed_dag", schedule_interval=None, start_date=start_date)
op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3", dag=dag)
op2 = DummyOperator(task_id="op2", dag=dag)
op3 = DummyOperator(task_id="op3", dag=dag)
diff --git a/tests/ti_deps/deps/test_prev_dagrun_dep.py b/tests/ti_deps/deps/test_prev_dagrun_dep.py
index 948ebd5d39e71..4a05d77ea72c1 100644
--- a/tests/ti_deps/deps/test_prev_dagrun_dep.py
+++ b/tests/ti_deps/deps/test_prev_dagrun_dep.py
@@ -28,7 +28,6 @@
class TestPrevDagrunDep(unittest.TestCase):
-
def _get_task(self, **kwargs):
return BaseOperator(task_id='test_task', dag=DAG('test_dag'), **kwargs)
@@ -37,14 +36,16 @@ def test_not_depends_on_past(self):
If depends on past isn't set in the task then the previous dagrun should be
ignored, even though there is no previous_ti which would normally fail the dep
"""
- task = self._get_task(depends_on_past=False,
- start_date=datetime(2016, 1, 1),
- wait_for_downstream=False)
- prev_ti = Mock(task=task, state=State.SUCCESS,
- are_dependents_done=Mock(return_value=True),
- execution_date=datetime(2016, 1, 2))
- ti = Mock(task=task, previous_ti=prev_ti,
- execution_date=datetime(2016, 1, 3))
+ task = self._get_task(
+ depends_on_past=False, start_date=datetime(2016, 1, 1), wait_for_downstream=False
+ )
+ prev_ti = Mock(
+ task=task,
+ state=State.SUCCESS,
+ are_dependents_done=Mock(return_value=True),
+ execution_date=datetime(2016, 1, 2),
+ )
+ ti = Mock(task=task, previous_ti=prev_ti, execution_date=datetime(2016, 1, 3))
dep_context = DepContext(ignore_depends_on_past=False)
self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context))
@@ -54,14 +55,16 @@ def test_context_ignore_depends_on_past(self):
If the context overrides depends_on_past then the dep should be met,
even though there is no previous_ti which would normally fail the dep
"""
- task = self._get_task(depends_on_past=True,
- start_date=datetime(2016, 1, 1),
- wait_for_downstream=False)
- prev_ti = Mock(task=task, state=State.SUCCESS,
- are_dependents_done=Mock(return_value=True),
- execution_date=datetime(2016, 1, 2))
- ti = Mock(task=task, previous_ti=prev_ti,
- execution_date=datetime(2016, 1, 3))
+ task = self._get_task(
+ depends_on_past=True, start_date=datetime(2016, 1, 1), wait_for_downstream=False
+ )
+ prev_ti = Mock(
+ task=task,
+ state=State.SUCCESS,
+ are_dependents_done=Mock(return_value=True),
+ execution_date=datetime(2016, 1, 2),
+ )
+ ti = Mock(task=task, previous_ti=prev_ti, execution_date=datetime(2016, 1, 3))
dep_context = DepContext(ignore_depends_on_past=True)
self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context))
@@ -70,12 +73,11 @@ def test_first_task_run(self):
"""
The first task run for a TI should pass since it has no previous dagrun.
"""
- task = self._get_task(depends_on_past=True,
- start_date=datetime(2016, 1, 1),
- wait_for_downstream=False)
+ task = self._get_task(
+ depends_on_past=True, start_date=datetime(2016, 1, 1), wait_for_downstream=False
+ )
prev_ti = None
- ti = Mock(task=task, previous_ti=prev_ti,
- execution_date=datetime(2016, 1, 1))
+ ti = Mock(task=task, previous_ti=prev_ti, execution_date=datetime(2016, 1, 1))
dep_context = DepContext(ignore_depends_on_past=False)
self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context))
@@ -84,13 +86,11 @@ def test_prev_ti_bad_state(self):
"""
If the previous TI did not complete execution this dep should fail.
"""
- task = self._get_task(depends_on_past=True,
- start_date=datetime(2016, 1, 1),
- wait_for_downstream=False)
- prev_ti = Mock(state=State.NONE,
- are_dependents_done=Mock(return_value=True))
- ti = Mock(task=task, previous_ti=prev_ti,
- execution_date=datetime(2016, 1, 2))
+ task = self._get_task(
+ depends_on_past=True, start_date=datetime(2016, 1, 1), wait_for_downstream=False
+ )
+ prev_ti = Mock(state=State.NONE, are_dependents_done=Mock(return_value=True))
+ ti = Mock(task=task, previous_ti=prev_ti, execution_date=datetime(2016, 1, 2))
dep_context = DepContext(ignore_depends_on_past=False)
self.assertFalse(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context))
@@ -101,13 +101,9 @@ def test_failed_wait_for_downstream(self):
previous dagrun then it should fail this dep if the downstream TIs of
the previous TI are not done.
"""
- task = self._get_task(depends_on_past=True,
- start_date=datetime(2016, 1, 1),
- wait_for_downstream=True)
- prev_ti = Mock(state=State.SUCCESS,
- are_dependents_done=Mock(return_value=False))
- ti = Mock(task=task, previous_ti=prev_ti,
- execution_date=datetime(2016, 1, 2))
+ task = self._get_task(depends_on_past=True, start_date=datetime(2016, 1, 1), wait_for_downstream=True)
+ prev_ti = Mock(state=State.SUCCESS, are_dependents_done=Mock(return_value=False))
+ ti = Mock(task=task, previous_ti=prev_ti, execution_date=datetime(2016, 1, 2))
dep_context = DepContext(ignore_depends_on_past=False)
self.assertFalse(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context))
@@ -116,16 +112,9 @@ def test_all_met(self):
"""
Test to make sure all of the conditions for the dep are met
"""
- task = self._get_task(depends_on_past=True,
- start_date=datetime(2016, 1, 1),
- wait_for_downstream=True)
- prev_ti = Mock(state=State.SUCCESS,
- are_dependents_done=Mock(return_value=True))
- ti = Mock(
- task=task,
- execution_date=datetime(2016, 1, 2),
- **{'get_previous_ti.return_value': prev_ti}
- )
+ task = self._get_task(depends_on_past=True, start_date=datetime(2016, 1, 1), wait_for_downstream=True)
+ prev_ti = Mock(state=State.SUCCESS, are_dependents_done=Mock(return_value=True))
+ ti = Mock(task=task, execution_date=datetime(2016, 1, 2), **{'get_previous_ti.return_value': prev_ti})
dep_context = DepContext(ignore_depends_on_past=False)
self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context))
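
The pattern throughout this file follows from the repository's wider line length (the reformatted lines here run to roughly 110 characters): calls that now fit are joined back onto a single line, and calls that still do not fit are exploded with one keyword argument per line plus a trailing comma. A hedged, self-contained sketch using stand-in values rather than the real task objects:

from datetime import datetime
from unittest.mock import Mock

# Short enough for one line, so Black joins the old three-line call back together.
ti = Mock(task="fake_task", previous_ti=None, execution_date=datetime(2016, 1, 3))

# The longer constructors in the hunks above take this shape instead: one keyword
# argument per line plus a trailing comma, which also keeps the call exploded on
# later Black runs.
prev_ti = Mock(
    task="fake_task",
    state="success",
    are_dependents_done=Mock(return_value=True),
    execution_date=datetime(2016, 1, 2),
)
print(ti.execution_date, prev_ti.state)
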
diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
index 66562b30a8462..f9bbb9f282b3f 100644
--- a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
+++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
@@ -28,7 +28,6 @@
class TestNotInReschedulePeriodDep(unittest.TestCase):
-
def _get_task_instance(self, state):
dag = DAG('test_dag')
task = Mock(dag=dag)
@@ -37,9 +36,14 @@ def _get_task_instance(self, state):
def _get_task_reschedule(self, reschedule_date):
task = Mock(dag_id='test_dag', task_id='test_task')
- reschedule = TaskReschedule(task=task, execution_date=None, try_number=None,
- start_date=reschedule_date, end_date=reschedule_date,
- reschedule_date=reschedule_date)
+ reschedule = TaskReschedule(
+ task=task,
+ execution_date=None,
+ try_number=None,
+ start_date=reschedule_date,
+ end_date=reschedule_date,
+ reschedule_date=reschedule_date,
+ )
return reschedule
def test_should_pass_if_ignore_in_reschedule_period_is_set(self):
@@ -59,8 +63,9 @@ def test_should_pass_if_no_reschedule_record_exists(self, mock_query_for_task_in
@patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
def test_should_pass_after_reschedule_date_one(self, mock_query_for_task_instance):
- mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = \
+ mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = (
self._get_task_reschedule(utcnow() - timedelta(minutes=1))
+ )
ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti))
@@ -76,8 +81,9 @@ def test_should_pass_after_reschedule_date_multiple(self, mock_query_for_task_in
@patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
def test_should_fail_before_reschedule_date_one(self, mock_query_for_task_instance):
- mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = \
+ mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = (
self._get_task_reschedule(utcnow() + timedelta(minutes=1))
+ )
ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
self.assertFalse(ReadyToRescheduleDep().is_met(ti=ti))
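
The two assignment hunks above also show Black replacing a trailing-backslash line continuation with parentheses around the right-hand side. A minimal stand-alone sketch of that transformation, with a plain dictionary standing in for the mocked query object:

def first_reschedule(query_result):
    # Before the patch this assignment used a trailing backslash to continue the
    # line; after Black the right-hand side is wrapped in parentheses instead.
    reschedule = (
        query_result["with_entities"]["first"]
    )
    return reschedule


print(first_reschedule({"with_entities": {"first": "2020-01-01T00:00:00+00:00"}}))
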
diff --git a/tests/ti_deps/deps/test_runnable_exec_date_dep.py b/tests/ti_deps/deps/test_runnable_exec_date_dep.py
index 4c08f577d163e..f12f3a284a890 100644
--- a/tests/ti_deps/deps/test_runnable_exec_date_dep.py
+++ b/tests/ti_deps/deps/test_runnable_exec_date_dep.py
@@ -30,13 +30,16 @@
@freeze_time('2016-11-01')
-@pytest.mark.parametrize("allow_trigger_in_future,schedule_interval,execution_date,is_met", [
- (True, None, datetime(2016, 11, 3), True),
- (True, "@daily", datetime(2016, 11, 3), False),
- (False, None, datetime(2016, 11, 3), False),
- (False, "@daily", datetime(2016, 11, 3), False),
- (False, "@daily", datetime(2016, 11, 1), True),
- (False, None, datetime(2016, 11, 1), True)]
+@pytest.mark.parametrize(
+ "allow_trigger_in_future,schedule_interval,execution_date,is_met",
+ [
+ (True, None, datetime(2016, 11, 3), True),
+ (True, "@daily", datetime(2016, 11, 3), False),
+ (False, None, datetime(2016, 11, 3), False),
+ (False, "@daily", datetime(2016, 11, 3), False),
+ (False, "@daily", datetime(2016, 11, 1), True),
+ (False, None, datetime(2016, 11, 1), True),
+ ],
)
def test_exec_date_dep(allow_trigger_in_future, schedule_interval, execution_date, is_met):
"""
@@ -49,7 +52,8 @@ def test_exec_date_dep(allow_trigger_in_future, schedule_interval, execution_dat
'test_localtaskjob_heartbeat',
start_date=datetime(2015, 1, 1),
end_date=datetime(2016, 11, 5),
- schedule_interval=schedule_interval)
+ schedule_interval=schedule_interval,
+ )
with dag:
op1 = DummyOperator(task_id='op1')
@@ -59,7 +63,6 @@ def test_exec_date_dep(allow_trigger_in_future, schedule_interval, execution_dat
class TestRunnableExecDateDep(unittest.TestCase):
-
def _get_task_instance(self, execution_date, dag_end_date=None, task_end_date=None):
dag = Mock(end_date=dag_end_date)
task = Mock(dag=dag, end_date=task_end_date)
@@ -74,7 +77,8 @@ def test_exec_date_after_end_date(self):
'test_localtaskjob_heartbeat',
start_date=datetime(2015, 1, 1),
end_date=datetime(2016, 11, 5),
- schedule_interval=None)
+ schedule_interval=None,
+ )
with dag:
op1 = DummyOperator(task_id='op1')
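
The parametrize hunk above is typical of how Black lays out decorator arguments that exceed the line limit: the argument-name string and the list of cases each get their own line, every case ends with a comma, and the closing bracket and parenthesis are dedented. A small runnable sketch (pytest assumed available; the test itself is made up):

import pytest


@pytest.mark.parametrize(
    "allow_trigger_in_future,is_met",
    [
        (True, True),
        (False, False),
    ],
)
def test_exec_date_dep_sketch(allow_trigger_in_future, is_met):
    # Hypothetical test body; only the decorator layout matters here.
    assert allow_trigger_in_future is is_met
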
diff --git a/tests/ti_deps/deps/test_task_concurrency.py b/tests/ti_deps/deps/test_task_concurrency.py
index 4d8f5fba5a8f7..9bdd1fab1c971 100644
--- a/tests/ti_deps/deps/test_task_concurrency.py
+++ b/tests/ti_deps/deps/test_task_concurrency.py
@@ -27,7 +27,6 @@
class TestTaskConcurrencyDep(unittest.TestCase):
-
def _get_task(self, **kwargs):
return BaseOperator(task_id='test_task', dag=DAG('test_dag'), **kwargs)
diff --git a/tests/ti_deps/deps/test_task_not_running_dep.py b/tests/ti_deps/deps/test_task_not_running_dep.py
index 2a4ab563f556d..353db3fafb8db 100644
--- a/tests/ti_deps/deps/test_task_not_running_dep.py
+++ b/tests/ti_deps/deps/test_task_not_running_dep.py
@@ -25,7 +25,6 @@
class TestTaskNotRunningDep(unittest.TestCase):
-
def test_not_running_state(self):
ti = Mock(state=State.QUEUED, end_date=datetime(2016, 1, 1))
self.assertTrue(TaskNotRunningDep().is_met(ti=ti))
diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py
index 76b61b8af4f1f..c993e6c05fb66 100644
--- a/tests/ti_deps/deps/test_trigger_rule_dep.py
+++ b/tests/ti_deps/deps/test_trigger_rule_dep.py
@@ -34,11 +34,8 @@
class TestTriggerRuleDep(unittest.TestCase):
-
- def _get_task_instance(self, trigger_rule=TriggerRule.ALL_SUCCESS,
- state=None, upstream_task_ids=None):
- task = BaseOperator(task_id='test_task', trigger_rule=trigger_rule,
- start_date=datetime(2015, 1, 1))
+ def _get_task_instance(self, trigger_rule=TriggerRule.ALL_SUCCESS, state=None, upstream_task_ids=None):
+ task = BaseOperator(task_id='test_task', trigger_rule=trigger_rule, start_date=datetime(2015, 1, 1))
if upstream_task_ids:
task._upstream_task_ids.update(upstream_task_ids)
return TaskInstance(task=task, state=state, execution_date=task.start_date)
@@ -62,15 +59,18 @@ def test_one_success_tr_success(self):
One-success trigger rule success
"""
ti = self._get_task_instance(TriggerRule.ONE_SUCCESS, State.UP_FOR_RETRY)
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=1,
- skipped=2,
- failed=2,
- upstream_failed=2,
- done=2,
- flag_upstream_failed=False,
- session="Fake Session"))
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=1,
+ skipped=2,
+ failed=2,
+ upstream_failed=2,
+ done=2,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 0)
def test_one_success_tr_failure(self):
@@ -78,15 +78,18 @@ def test_one_success_tr_failure(self):
One-success trigger rule failure
"""
ti = self._get_task_instance(TriggerRule.ONE_SUCCESS)
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=0,
- skipped=2,
- failed=2,
- upstream_failed=2,
- done=2,
- flag_upstream_failed=False,
- session="Fake Session"))
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=0,
+ skipped=2,
+ failed=2,
+ upstream_failed=2,
+ done=2,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)
@@ -95,15 +98,18 @@ def test_one_failure_tr_failure(self):
One-failure trigger rule failure
"""
ti = self._get_task_instance(TriggerRule.ONE_FAILED)
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=2,
- skipped=0,
- failed=0,
- upstream_failed=0,
- done=2,
- flag_upstream_failed=False,
- session="Fake Session"))
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=2,
+ skipped=0,
+ failed=0,
+ upstream_failed=0,
+ done=2,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)
@@ -112,61 +118,72 @@ def test_one_failure_tr_success(self):
One-failure trigger rule success
"""
ti = self._get_task_instance(TriggerRule.ONE_FAILED)
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=0,
- skipped=2,
- failed=2,
- upstream_failed=0,
- done=2,
- flag_upstream_failed=False,
- session="Fake Session"))
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=0,
+ skipped=2,
+ failed=2,
+ upstream_failed=0,
+ done=2,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 0)
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=0,
- skipped=2,
- failed=0,
- upstream_failed=2,
- done=2,
- flag_upstream_failed=False,
- session="Fake Session"))
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=0,
+ skipped=2,
+ failed=0,
+ upstream_failed=2,
+ done=2,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 0)
def test_all_success_tr_success(self):
"""
All-success trigger rule success
"""
- ti = self._get_task_instance(TriggerRule.ALL_SUCCESS,
- upstream_task_ids=["FakeTaskID"])
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=1,
- skipped=0,
- failed=0,
- upstream_failed=0,
- done=1,
- flag_upstream_failed=False,
- session="Fake Session"))
+ ti = self._get_task_instance(TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID"])
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=1,
+ skipped=0,
+ failed=0,
+ upstream_failed=0,
+ done=1,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 0)
def test_all_success_tr_failure(self):
"""
All-success trigger rule failure
"""
- ti = self._get_task_instance(TriggerRule.ALL_SUCCESS,
- upstream_task_ids=["FakeTaskID",
- "OtherFakeTaskID"])
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=1,
- skipped=0,
- failed=1,
- upstream_failed=0,
- done=2,
- flag_upstream_failed=False,
- session="Fake Session"))
+ ti = self._get_task_instance(
+ TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]
+ )
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=1,
+ skipped=0,
+ failed=1,
+ upstream_failed=0,
+ done=2,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)
@@ -174,18 +191,21 @@ def test_all_success_tr_skip(self):
"""
All-success trigger rule fails when some upstream tasks are skipped.
"""
- ti = self._get_task_instance(TriggerRule.ALL_SUCCESS,
- upstream_task_ids=["FakeTaskID",
- "OtherFakeTaskID"])
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=1,
- skipped=1,
- failed=0,
- upstream_failed=0,
- done=2,
- flag_upstream_failed=False,
- session="Fake Session"))
+ ti = self._get_task_instance(
+ TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]
+ )
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=1,
+ skipped=1,
+ failed=0,
+ upstream_failed=0,
+ done=2,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)
@@ -194,18 +214,21 @@ def test_all_success_tr_skip_flag_upstream(self):
All-success trigger rule fails when some upstream tasks are skipped. The state of the ti
should be set to SKIPPED when flag_upstream_failed is True.
"""
- ti = self._get_task_instance(TriggerRule.ALL_SUCCESS,
- upstream_task_ids=["FakeTaskID",
- "OtherFakeTaskID"])
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=1,
- skipped=1,
- failed=0,
- upstream_failed=0,
- done=2,
- flag_upstream_failed=True,
- session=Mock()))
+ ti = self._get_task_instance(
+ TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]
+ )
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=1,
+ skipped=1,
+ failed=0,
+ upstream_failed=0,
+ done=2,
+ flag_upstream_failed=True,
+ session=Mock(),
+ )
+ )
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)
self.assertEqual(ti.state, State.SKIPPED)
@@ -214,36 +237,42 @@ def test_none_failed_tr_success(self):
"""
All success including skip trigger rule success
"""
- ti = self._get_task_instance(TriggerRule.NONE_FAILED,
- upstream_task_ids=["FakeTaskID",
- "OtherFakeTaskID"])
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=1,
- skipped=1,
- failed=0,
- upstream_failed=0,
- done=2,
- flag_upstream_failed=False,
- session="Fake Session"))
+ ti = self._get_task_instance(
+ TriggerRule.NONE_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]
+ )
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=1,
+ skipped=1,
+ failed=0,
+ upstream_failed=0,
+ done=2,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 0)
def test_none_failed_tr_skipped(self):
"""
All success including all upstream skips trigger rule success
"""
- ti = self._get_task_instance(TriggerRule.NONE_FAILED,
- upstream_task_ids=["FakeTaskID",
- "OtherFakeTaskID"])
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=0,
- skipped=2,
- failed=0,
- upstream_failed=0,
- done=2,
- flag_upstream_failed=True,
- session=Mock()))
+ ti = self._get_task_instance(
+ TriggerRule.NONE_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]
+ )
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=0,
+ skipped=2,
+ failed=0,
+ upstream_failed=0,
+ done=2,
+ flag_upstream_failed=True,
+ session=Mock(),
+ )
+ )
self.assertEqual(len(dep_statuses), 0)
self.assertEqual(ti.state, State.NONE)
@@ -251,19 +280,21 @@ def test_none_failed_tr_failure(self):
"""
All success including skip trigger rule failure
"""
- ti = self._get_task_instance(TriggerRule.NONE_FAILED,
- upstream_task_ids=["FakeTaskID",
- "OtherFakeTaskID",
- "FailedFakeTaskID"])
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=1,
- skipped=1,
- failed=1,
- upstream_failed=0,
- done=3,
- flag_upstream_failed=False,
- session="Fake Session"))
+ ti = self._get_task_instance(
+ TriggerRule.NONE_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"]
+ )
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=1,
+ skipped=1,
+ failed=1,
+ upstream_failed=0,
+ done=3,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)
@@ -271,36 +302,42 @@ def test_none_failed_or_skipped_tr_success(self):
"""
All success including skip trigger rule success
"""
- ti = self._get_task_instance(TriggerRule.NONE_FAILED_OR_SKIPPED,
- upstream_task_ids=["FakeTaskID",
- "OtherFakeTaskID"])
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=1,
- skipped=1,
- failed=0,
- upstream_failed=0,
- done=2,
- flag_upstream_failed=False,
- session="Fake Session"))
+ ti = self._get_task_instance(
+ TriggerRule.NONE_FAILED_OR_SKIPPED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]
+ )
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=1,
+ skipped=1,
+ failed=0,
+ upstream_failed=0,
+ done=2,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 0)
def test_none_failed_or_skipped_tr_skipped(self):
"""
All success including all upstream skips trigger rule success
"""
- ti = self._get_task_instance(TriggerRule.NONE_FAILED_OR_SKIPPED,
- upstream_task_ids=["FakeTaskID",
- "OtherFakeTaskID"])
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=0,
- skipped=2,
- failed=0,
- upstream_failed=0,
- done=2,
- flag_upstream_failed=True,
- session=Mock()))
+ ti = self._get_task_instance(
+ TriggerRule.NONE_FAILED_OR_SKIPPED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]
+ )
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=0,
+ skipped=2,
+ failed=0,
+ upstream_failed=0,
+ done=2,
+ flag_upstream_failed=True,
+ session=Mock(),
+ )
+ )
self.assertEqual(len(dep_statuses), 0)
self.assertEqual(ti.state, State.SKIPPED)
@@ -308,19 +345,22 @@ def test_none_failed_or_skipped_tr_failure(self):
"""
All success including skip trigger rule failure
"""
- ti = self._get_task_instance(TriggerRule.NONE_FAILED_OR_SKIPPED,
- upstream_task_ids=["FakeTaskID",
- "OtherFakeTaskID",
- "FailedFakeTaskID"])
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=1,
- skipped=1,
- failed=1,
- upstream_failed=0,
- done=3,
- flag_upstream_failed=False,
- session="Fake Session"))
+ ti = self._get_task_instance(
+ TriggerRule.NONE_FAILED_OR_SKIPPED,
+ upstream_task_ids=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"],
+ )
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=1,
+ skipped=1,
+ failed=1,
+ upstream_failed=0,
+ done=3,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)
@@ -328,36 +368,42 @@ def test_all_failed_tr_success(self):
"""
All-failed trigger rule success
"""
- ti = self._get_task_instance(TriggerRule.ALL_FAILED,
- upstream_task_ids=["FakeTaskID",
- "OtherFakeTaskID"])
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=0,
- skipped=0,
- failed=2,
- upstream_failed=0,
- done=2,
- flag_upstream_failed=False,
- session="Fake Session"))
+ ti = self._get_task_instance(
+ TriggerRule.ALL_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]
+ )
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=0,
+ skipped=0,
+ failed=2,
+ upstream_failed=0,
+ done=2,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 0)
def test_all_failed_tr_failure(self):
"""
All-failed trigger rule failure
"""
- ti = self._get_task_instance(TriggerRule.ALL_FAILED,
- upstream_task_ids=["FakeTaskID",
- "OtherFakeTaskID"])
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=2,
- skipped=0,
- failed=0,
- upstream_failed=0,
- done=2,
- flag_upstream_failed=False,
- session="Fake Session"))
+ ti = self._get_task_instance(
+ TriggerRule.ALL_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]
+ )
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=2,
+ skipped=0,
+ failed=0,
+ upstream_failed=0,
+ done=2,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)
@@ -365,36 +411,42 @@ def test_all_done_tr_success(self):
"""
All-done trigger rule success
"""
- ti = self._get_task_instance(TriggerRule.ALL_DONE,
- upstream_task_ids=["FakeTaskID",
- "OtherFakeTaskID"])
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=2,
- skipped=0,
- failed=0,
- upstream_failed=0,
- done=2,
- flag_upstream_failed=False,
- session="Fake Session"))
+ ti = self._get_task_instance(
+ TriggerRule.ALL_DONE, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]
+ )
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=2,
+ skipped=0,
+ failed=0,
+ upstream_failed=0,
+ done=2,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 0)
def test_all_done_tr_failure(self):
"""
All-done trigger rule failure
"""
- ti = self._get_task_instance(TriggerRule.ALL_DONE,
- upstream_task_ids=["FakeTaskID",
- "OtherFakeTaskID"])
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=1,
- skipped=0,
- failed=0,
- upstream_failed=0,
- done=1,
- flag_upstream_failed=False,
- session="Fake Session"))
+ ti = self._get_task_instance(
+ TriggerRule.ALL_DONE, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]
+ )
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=1,
+ skipped=0,
+ failed=0,
+ upstream_failed=0,
+ done=1,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)
@@ -403,78 +455,92 @@ def test_none_skipped_tr_success(self):
None-skipped trigger rule success
"""
- ti = self._get_task_instance(TriggerRule.NONE_SKIPPED,
- upstream_task_ids=["FakeTaskID",
- "OtherFakeTaskID",
- "FailedFakeTaskID"])
+ ti = self._get_task_instance(
+ TriggerRule.NONE_SKIPPED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"]
+ )
with create_session() as session:
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=2,
- skipped=0,
- failed=1,
- upstream_failed=0,
- done=3,
- flag_upstream_failed=False,
- session=session))
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=2,
+ skipped=0,
+ failed=1,
+ upstream_failed=0,
+ done=3,
+ flag_upstream_failed=False,
+ session=session,
+ )
+ )
self.assertEqual(len(dep_statuses), 0)
# with `flag_upstream_failed` set to True
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=0,
- skipped=0,
- failed=3,
- upstream_failed=0,
- done=3,
- flag_upstream_failed=True,
- session=session))
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=0,
+ skipped=0,
+ failed=3,
+ upstream_failed=0,
+ done=3,
+ flag_upstream_failed=True,
+ session=session,
+ )
+ )
self.assertEqual(len(dep_statuses), 0)
def test_none_skipped_tr_failure(self):
"""
None-skipped trigger rule failure
"""
- ti = self._get_task_instance(TriggerRule.NONE_SKIPPED,
- upstream_task_ids=["FakeTaskID",
- "SkippedTaskID"])
+ ti = self._get_task_instance(
+ TriggerRule.NONE_SKIPPED, upstream_task_ids=["FakeTaskID", "SkippedTaskID"]
+ )
with create_session() as session:
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=1,
- skipped=1,
- failed=0,
- upstream_failed=0,
- done=2,
- flag_upstream_failed=False,
- session=session))
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=1,
+ skipped=1,
+ failed=0,
+ upstream_failed=0,
+ done=2,
+ flag_upstream_failed=False,
+ session=session,
+ )
+ )
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)
# with `flag_upstream_failed` set to True
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=1,
- skipped=1,
- failed=0,
- upstream_failed=0,
- done=2,
- flag_upstream_failed=True,
- session=session))
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=1,
+ skipped=1,
+ failed=0,
+ upstream_failed=0,
+ done=2,
+ flag_upstream_failed=True,
+ session=session,
+ )
+ )
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)
# Fail until all upstream tasks have completed execution
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=0,
- skipped=0,
- failed=0,
- upstream_failed=0,
- done=0,
- flag_upstream_failed=False,
- session=session))
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=0,
+ skipped=0,
+ failed=0,
+ upstream_failed=0,
+ done=0,
+ flag_upstream_failed=False,
+ session=session,
+ )
+ )
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)
@@ -484,15 +550,18 @@ def test_unknown_tr(self):
"""
ti = self._get_task_instance()
ti.task.trigger_rule = "Unknown Trigger Rule"
- dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
- ti=ti,
- successes=1,
- skipped=0,
- failed=0,
- upstream_failed=0,
- done=1,
- flag_upstream_failed=False,
- session="Fake Session"))
+ dep_statuses = tuple(
+ TriggerRuleDep()._evaluate_trigger_rule(
+ ti=ti,
+ successes=1,
+ skipped=0,
+ failed=0,
+ upstream_failed=0,
+ done=1,
+ flag_upstream_failed=False,
+ session="Fake Session",
+ )
+ )
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)
@@ -506,10 +575,7 @@ def test_get_states_count_upstream_ti(self):
get_states_count_upstream_ti = TriggerRuleDep._get_states_count_upstream_ti
session = settings.Session()
now = timezone.utcnow()
- dag = DAG(
- 'test_dagrun_with_pre_tis',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('test_dagrun_with_pre_tis', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
with dag:
op1 = DummyOperator(task_id='A')
@@ -524,10 +590,9 @@ def test_get_states_count_upstream_ti(self):
clear_db_runs()
dag.clear()
- dr = dag.create_dagrun(run_id='test_dagrun_with_pre_tis',
- state=State.RUNNING,
- execution_date=now,
- start_date=now)
+ dr = dag.create_dagrun(
+ run_id='test_dagrun_with_pre_tis', state=State.RUNNING, execution_date=now, start_date=now
+ )
ti_op1 = TaskInstance(task=dag.get_task(op1.task_id), execution_date=dr.execution_date)
ti_op2 = TaskInstance(task=dag.get_task(op2.task_id), execution_date=dr.execution_date)
@@ -545,13 +610,16 @@ def test_get_states_count_upstream_ti(self):
# check handling with cases that tasks are triggered from backfill with no finished tasks
finished_tasks = DepContext().ensure_finished_tasks(ti_op2.task.dag, ti_op2.execution_date, session)
- self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op2),
- (1, 0, 0, 0, 1))
+ self.assertEqual(
+ get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op2), (1, 0, 0, 0, 1)
+ )
finished_tasks = dr.get_task_instances(state=State.finished, session=session)
- self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op4),
- (1, 0, 1, 0, 2))
- self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op5),
- (2, 0, 1, 0, 3))
+ self.assertEqual(
+ get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op4), (1, 0, 1, 0, 2)
+ )
+ self.assertEqual(
+ get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op5), (2, 0, 1, 0, 3)
+ )
dr.update_state()
self.assertEqual(State.SUCCESS, dr.state)
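
Most hunks in this file wrap the same nested call: when tuple(...) around an already long call cannot fit, Black indents each nesting level, lists the keyword arguments one per line with a trailing comma, and stacks the closing parentheses. A runnable sketch with a stand-in generator in place of TriggerRuleDep()._evaluate_trigger_rule:

def evaluate_trigger_rule(successes, skipped, failed, upstream_failed, done):
    # Stand-in for the real method, which yields failed dependency statuses.
    if failed or upstream_failed or done < successes + skipped + failed:
        yield "dependency not met"


dep_statuses = tuple(
    evaluate_trigger_rule(
        successes=1,
        skipped=0,
        failed=1,
        upstream_failed=0,
        done=2,
    )
)
print(len(dep_statuses))
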
diff --git a/tests/ti_deps/deps/test_valid_state_dep.py b/tests/ti_deps/deps/test_valid_state_dep.py
index 74c7e2aaf9d48..7e6ee7fcf2fac 100644
--- a/tests/ti_deps/deps/test_valid_state_dep.py
+++ b/tests/ti_deps/deps/test_valid_state_dep.py
@@ -26,7 +26,6 @@
class TestValidStateDep(unittest.TestCase):
-
def test_valid_state(self):
"""
Valid state should pass this dep
diff --git a/tests/utils/log/test_file_processor_handler.py b/tests/utils/log/test_file_processor_handler.py
index a3556e5cc7a83..20115059bc7a8 100644
--- a/tests/utils/log/test_file_processor_handler.py
+++ b/tests/utils/log/test_file_processor_handler.py
@@ -37,8 +37,7 @@ def setUp(self):
def test_non_template(self):
date = timezone.utcnow().strftime("%Y-%m-%d")
- handler = FileProcessorHandler(base_log_folder=self.base_log_folder,
- filename_template=self.filename)
+ handler = FileProcessorHandler(base_log_folder=self.base_log_folder, filename_template=self.filename)
handler.dag_dir = self.dag_dir
path = os.path.join(self.base_log_folder, "latest")
@@ -50,8 +49,9 @@ def test_non_template(self):
def test_template(self):
date = timezone.utcnow().strftime("%Y-%m-%d")
- handler = FileProcessorHandler(base_log_folder=self.base_log_folder,
- filename_template=self.filename_template)
+ handler = FileProcessorHandler(
+ base_log_folder=self.base_log_folder, filename_template=self.filename_template
+ )
handler.dag_dir = self.dag_dir
path = os.path.join(self.base_log_folder, "latest")
@@ -62,8 +62,7 @@ def test_template(self):
self.assertTrue(os.path.exists(os.path.join(path, "logfile.log")))
def test_symlink_latest_log_directory(self):
- handler = FileProcessorHandler(base_log_folder=self.base_log_folder,
- filename_template=self.filename)
+ handler = FileProcessorHandler(base_log_folder=self.base_log_folder, filename_template=self.filename)
handler.dag_dir = self.dag_dir
date1 = (timezone.utcnow() + timedelta(days=1)).strftime("%Y-%m-%d")
@@ -92,8 +91,7 @@ def test_symlink_latest_log_directory(self):
self.assertTrue(os.path.exists(os.path.join(link, "log2")))
def test_symlink_latest_log_directory_exists(self):
- handler = FileProcessorHandler(base_log_folder=self.base_log_folder,
- filename_template=self.filename)
+ handler = FileProcessorHandler(base_log_folder=self.base_log_folder, filename_template=self.filename)
handler.dag_dir = self.dag_dir
date1 = (timezone.utcnow() + timedelta(days=1)).strftime("%Y-%m-%d")
diff --git a/tests/utils/log/test_json_formatter.py b/tests/utils/log/test_json_formatter.py
index a14bcc3c80540..ba827fb25e87e 100644
--- a/tests/utils/log/test_json_formatter.py
+++ b/tests/utils/log/test_json_formatter.py
@@ -30,6 +30,7 @@ class TestJSONFormatter(unittest.TestCase):
"""
    TestJSONFormatter class combines all tests for JSONFormatter
"""
+
def test_json_formatter_is_not_none(self):
"""
JSONFormatter instance should return not none
@@ -61,5 +62,6 @@ def test_format_with_extras(self):
log_record = makeLogRecord({"label": "value"})
json_fmt = JSONFormatter(json_fields=["label"], extras={'pod_extra': 'useful_message'})
# compare as a dicts to not fail on sorting errors
- self.assertDictEqual(json.loads(json_fmt.format(log_record)),
- {"label": "value", "pod_extra": "useful_message"})
+ self.assertDictEqual(
+ json.loads(json_fmt.format(log_record)), {"label": "value", "pod_extra": "useful_message"}
+ )
diff --git a/tests/utils/log/test_log_reader.py b/tests/utils/log/test_log_reader.py
index 2ae98c4534cff..b4d1331e4effa 100644
--- a/tests/utils/log/test_log_reader.py
+++ b/tests/utils/log/test_log_reader.py
@@ -105,10 +105,12 @@ def test_test_read_log_chunks_should_read_one_try(self):
self.assertEqual(
[
- ('',
- f"*** Reading local file: "
- f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
- f"try_number=1.\n")
+ (
+ '',
+ f"*** Reading local file: "
+ f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
+ f"try_number=1.\n",
+ )
],
logs[0],
)
@@ -120,18 +122,30 @@ def test_test_read_log_chunks_should_read_all_files(self):
self.assertEqual(
[
- [('',
- "*** Reading local file: "
- f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
- "try_number=1.\n")],
- [('',
- f"*** Reading local file: "
- f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/2.log\n"
- f"try_number=2.\n")],
- [('',
- f"*** Reading local file: "
- f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log\n"
- f"try_number=3.\n")],
+ [
+ (
+ '',
+ "*** Reading local file: "
+ f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
+ "try_number=1.\n",
+ )
+ ],
+ [
+ (
+ '',
+ f"*** Reading local file: "
+ f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/2.log\n"
+ f"try_number=2.\n",
+ )
+ ],
+ [
+ (
+ '',
+ f"*** Reading local file: "
+ f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log\n"
+ f"try_number=3.\n",
+ )
+ ],
],
logs,
)
diff --git a/tests/utils/test_cli_util.py b/tests/utils/test_cli_util.py
index dfa71e10c7413..64cd57fd16775 100644
--- a/tests/utils/test_cli_util.py
+++ b/tests/utils/test_cli_util.py
@@ -33,19 +33,19 @@
class TestCliUtil(unittest.TestCase):
-
def test_metrics_build(self):
func_name = 'test'
exec_date = datetime.utcnow()
- namespace = Namespace(dag_id='foo', task_id='bar',
- subcommand='test', execution_date=exec_date)
+ namespace = Namespace(dag_id='foo', task_id='bar', subcommand='test', execution_date=exec_date)
metrics = cli._build_metrics(func_name, namespace)
- expected = {'user': os.environ.get('USER'),
- 'sub_command': 'test',
- 'dag_id': 'foo',
- 'task_id': 'bar',
- 'execution_date': exec_date}
+ expected = {
+ 'user': os.environ.get('USER'),
+ 'sub_command': 'test',
+ 'dag_id': 'foo',
+ 'task_id': 'bar',
+ 'execution_date': exec_date,
+ }
for k, v in expected.items():
self.assertEqual(v, metrics.get(k))
@@ -93,24 +93,24 @@ def test_get_dags(self):
[
(
"airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password test",
- "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password ********"
+ "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password ********",
),
(
"airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin -p test",
- "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin -p ********"
+ "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin -p ********",
),
(
"airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password=test",
- "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password=********"
+ "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password=********",
),
(
"airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin -p=test",
- "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin -p=********"
+ "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin -p=********",
),
(
"airflow connections add dsfs --conn-login asd --conn-password test --conn-type google",
"airflow connections add dsfs --conn-login asd --conn-password ******** --conn-type google",
- )
+ ),
]
)
def test_cli_create_user_supplied_password_is_masked(self, given_command, expected_masked_command):
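
Two smaller touches recur in this file: dictionary literals that span several lines get one key per line plus a trailing comma, and the last tuple in the parametrize list gains a trailing comma as well. A sketch of the dictionary layout, with keys mirroring the metrics dict above but placeholder values:

expected = {
    "user": "airflow",
    "sub_command": "test",
    "dag_id": "foo",
    "task_id": "bar",
}
for key, value in expected.items():
    print(key, value)
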
diff --git a/tests/utils/test_compression.py b/tests/utils/test_compression.py
index 1e62e8cf6ee18..b288ca60fde64 100644
--- a/tests/utils/test_compression.py
+++ b/tests/utils/test_compression.py
@@ -29,7 +29,6 @@
class TestCompression(unittest.TestCase):
-
def setUp(self):
self.file_names = {}
try:
@@ -38,15 +37,12 @@ def setUp(self):
line2 = b"2\tCompressionUtil\n"
self.tmp_dir = tempfile.mkdtemp(prefix='test_utils_compression_')
# create sample txt, gz and bz2 files
- with tempfile.NamedTemporaryFile(mode='wb+',
- dir=self.tmp_dir,
- delete=False) as f_txt:
+ with tempfile.NamedTemporaryFile(mode='wb+', dir=self.tmp_dir, delete=False) as f_txt:
self._set_fn(f_txt.name, '.txt')
f_txt.writelines([header, line1, line2])
fn_gz = self._get_fn('.txt') + ".gz"
- with gzip.GzipFile(filename=fn_gz,
- mode="wb") as f_gz:
+ with gzip.GzipFile(filename=fn_gz, mode="wb") as f_gz:
self._set_fn(fn_gz, '.gz')
f_gz.writelines([header, line1, line2])
@@ -80,21 +76,22 @@ def _get_fn(self, ext):
def test_uncompress_file(self):
# Testing txt file type
- self.assertRaisesRegex(NotImplementedError,
- "^Received .txt format. Only gz and bz2.*",
- compression.uncompress_file,
- **{'input_file_name': None,
- 'file_extension': '.txt',
- 'dest_dir': None
- })
+ self.assertRaisesRegex(
+ NotImplementedError,
+ "^Received .txt format. Only gz and bz2.*",
+ compression.uncompress_file,
+ **{'input_file_name': None, 'file_extension': '.txt', 'dest_dir': None},
+ )
# Testing gz file type
fn_txt = self._get_fn('.txt')
fn_gz = self._get_fn('.gz')
txt_gz = compression.uncompress_file(fn_gz, '.gz', self.tmp_dir)
- self.assertTrue(filecmp.cmp(txt_gz, fn_txt, shallow=False),
- msg="Uncompressed file doest match original")
+ self.assertTrue(
+ filecmp.cmp(txt_gz, fn_txt, shallow=False), msg="Uncompressed file doest match original"
+ )
# Testing bz2 file type
fn_bz2 = self._get_fn('.bz2')
txt_bz2 = compression.uncompress_file(fn_bz2, '.bz2', self.tmp_dir)
- self.assertTrue(filecmp.cmp(txt_bz2, fn_txt, shallow=False),
- msg="Uncompressed file doest match original")
+ self.assertTrue(
+ filecmp.cmp(txt_bz2, fn_txt, shallow=False), msg="Uncompressed file doest match original"
+ )
diff --git a/tests/utils/test_dag_cycle.py b/tests/utils/test_dag_cycle.py
index 320faafb45d4d..fec7440e82e12 100644
--- a/tests/utils/test_dag_cycle.py
+++ b/tests/utils/test_dag_cycle.py
@@ -27,19 +27,13 @@
class TestCycleTester(unittest.TestCase):
def test_cycle_empty(self):
# test empty
- dag = DAG(
- 'dag',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
self.assertFalse(_test_cycle(dag))
def test_cycle_single_task(self):
# test single task
- dag = DAG(
- 'dag',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
with dag:
DummyOperator(task_id='A')
@@ -47,10 +41,7 @@ def test_cycle_single_task(self):
self.assertFalse(_test_cycle(dag))
def test_semi_complex(self):
- dag = DAG(
- 'dag',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
# A -> B -> C
# B -> D
@@ -67,10 +58,7 @@ def test_semi_complex(self):
def test_cycle_no_cycle(self):
# test no cycle
- dag = DAG(
- 'dag',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
# A -> B -> C
# B -> D
@@ -91,10 +79,7 @@ def test_cycle_no_cycle(self):
def test_cycle_loop(self):
# test self loop
- dag = DAG(
- 'dag',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
# A -> A
with dag:
@@ -106,10 +91,7 @@ def test_cycle_loop(self):
def test_cycle_downstream_loop(self):
# test downstream self loop
- dag = DAG(
- 'dag',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
# A -> B -> C -> D -> E -> E
with dag:
@@ -129,10 +111,7 @@ def test_cycle_downstream_loop(self):
def test_cycle_large_loop(self):
# large loop
- dag = DAG(
- 'dag',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
# A -> B -> C -> D -> E -> A
with dag:
@@ -150,10 +129,7 @@ def test_cycle_large_loop(self):
def test_cycle_arbitrary_loop(self):
# test arbitrary loop
- dag = DAG(
- 'dag',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'})
+ dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
# E-> A -> B -> F -> A
# -> C -> F
diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py
index 84feff7f654f9..8a9dc342555a9 100644
--- a/tests/utils/test_dag_processing.py
+++ b/tests/utils/test_dag_processing.py
@@ -36,7 +36,11 @@
from airflow.utils import timezone
from airflow.utils.callback_requests import TaskCallbackRequest
from airflow.utils.dag_processing import (
- DagFileProcessorAgent, DagFileProcessorManager, DagFileStat, DagParsingSignal, DagParsingStat,
+ DagFileProcessorAgent,
+ DagFileProcessorManager,
+ DagFileStat,
+ DagParsingSignal,
+ DagParsingStat,
)
from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped
from airflow.utils.session import create_session
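
The removed import line already ended in a comma, which Black treats as a request to keep the parenthesized list exploded one name per line (its "magic trailing comma"). The same shape with a standard-library module, purely as an illustration:

from os.path import (
    basename,
    dirname,
    join,
)

print(join(dirname("/tmp/dags/example.py"), basename("/tmp/dags/example.py")))
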
@@ -45,8 +49,7 @@
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags
-TEST_DAG_FOLDER = os.path.join(
- os.path.dirname(os.path.realpath(__file__)), os.pardir, 'dags')
+TEST_DAG_FOLDER = os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, 'dags')
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -132,7 +135,8 @@ def test_max_runs_when_no_files(self):
signal_conn=child_pipe,
dag_ids=[],
pickle_dags=False,
- async_mode=async_mode)
+ async_mode=async_mode,
+ )
self.run_processor_manager_one_loop(manager, parent_pipe)
child_pipe.close()
@@ -147,11 +151,11 @@ def test_set_file_paths_when_processor_file_path_not_in_new_file_paths(self):
signal_conn=MagicMock(),
dag_ids=[],
pickle_dags=False,
- async_mode=True)
+ async_mode=True,
+ )
mock_processor = MagicMock()
- mock_processor.stop.side_effect = AttributeError(
- 'DagFileProcessor object has no attribute stop')
+ mock_processor.stop.side_effect = AttributeError('DagFileProcessor object has no attribute stop')
mock_processor.terminate.side_effect = None
manager._processors['missing_file.txt'] = mock_processor
@@ -169,11 +173,11 @@ def test_set_file_paths_when_processor_file_path_is_in_new_file_paths(self):
signal_conn=MagicMock(),
dag_ids=[],
pickle_dags=False,
- async_mode=True)
+ async_mode=True,
+ )
mock_processor = MagicMock()
- mock_processor.stop.side_effect = AttributeError(
- 'DagFileProcessor object has no attribute stop')
+ mock_processor.stop.side_effect = AttributeError('DagFileProcessor object has no attribute stop')
mock_processor.terminate.side_effect = None
manager._processors['abc.txt'] = mock_processor
@@ -190,7 +194,8 @@ def test_find_zombies(self):
signal_conn=MagicMock(),
dag_ids=[],
pickle_dags=False,
- async_mode=True)
+ async_mode=True,
+ )
dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False)
with create_session() as session:
@@ -211,7 +216,8 @@ def test_find_zombies(self):
session.commit()
manager._last_zombie_query_time = timezone.utcnow() - timedelta(
- seconds=manager._zombie_threshold_secs + 1)
+ seconds=manager._zombie_threshold_secs + 1
+ )
manager._find_zombies() # pylint: disable=no-value-for-parameter
requests = manager._callback_to_execute[dag.full_filepath]
self.assertEqual(1, len(requests))
@@ -232,8 +238,7 @@ def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_p
file processors until the next zombie detection logic is invoked.
"""
test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py')
- with conf_vars({('scheduler', 'max_threads'): '1',
- ('core', 'load_examples'): 'False'}):
+ with conf_vars({('scheduler', 'max_threads'): '1', ('core', 'load_examples'): 'False'}):
dagbag = DagBag(test_dag_path, read_dags_from_db=False)
with create_session() as session:
session.query(LJ).delete()
@@ -257,7 +262,7 @@ def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_p
TaskCallbackRequest(
full_filepath=dag.full_filepath,
simple_task_instance=SimpleTaskInstance(ti),
- msg="Message"
+ msg="Message",
)
]
@@ -282,7 +287,8 @@ def fake_processor_factory(*args, **kwargs):
signal_conn=child_pipe,
dag_ids=[],
pickle_dags=False,
- async_mode=async_mode)
+ async_mode=async_mode,
+ )
self.run_processor_manager_one_loop(manager, parent_pipe)
@@ -296,10 +302,9 @@ def fake_processor_factory(*args, **kwargs):
assert fake_processors[-1]._file_path == test_dag_path
callback_requests = fake_processors[-1]._callback_requests
- assert (
- {zombie.simple_task_instance.key for zombie in expected_failure_callback_requests} ==
- {result.simple_task_instance.key for result in callback_requests}
- )
+ assert {zombie.simple_task_instance.key for zombie in expected_failure_callback_requests} == {
+ result.simple_task_instance.key for result in callback_requests
+ }
child_pipe.close()
parent_pipe.close()
@@ -316,7 +321,8 @@ def test_kill_timed_out_processors_kill(self, mock_kill, mock_pid):
signal_conn=MagicMock(),
dag_ids=[],
pickle_dags=False,
- async_mode=True)
+ async_mode=True,
+ )
processor = DagFileProcessorProcess('abc.txt', False, [], [])
processor._start_time = timezone.make_aware(datetime.min)
@@ -336,7 +342,8 @@ def test_kill_timed_out_processors_no_kill(self, mock_dag_file_processor, mock_p
signal_conn=MagicMock(),
dag_ids=[],
pickle_dags=False,
- async_mode=True)
+ async_mode=True,
+ )
processor = DagFileProcessorProcess('abc.txt', False, [], [])
processor._start_time = timezone.make_aware(datetime.max)
@@ -373,7 +380,8 @@ def test_dag_with_system_exit(self):
processor_timeout=timedelta(seconds=5),
signal_conn=child_pipe,
pickle_dags=False,
- async_mode=True)
+ async_mode=True,
+ )
manager._run_parsing_loop()
@@ -407,10 +415,7 @@ def tearDown(self):
@staticmethod
def _processor_factory(file_path, zombies, dag_ids, pickle_dags):
- return DagFileProcessorProcess(file_path,
- pickle_dags,
- dag_ids,
- zombies)
+ return DagFileProcessorProcess(file_path, pickle_dags, dag_ids, zombies)
def test_reload_module(self):
"""
@@ -431,13 +436,9 @@ class path, thus when reloading logging module the airflow.processor_manager
pass
# Starting dag processing with 0 max_runs to avoid redundant operations.
- processor_agent = DagFileProcessorAgent(test_dag_path,
- 0,
- type(self)._processor_factory,
- timedelta.max,
- [],
- False,
- async_mode)
+ processor_agent = DagFileProcessorAgent(
+ test_dag_path, 0, type(self)._processor_factory, timedelta.max, [], False, async_mode
+ )
processor_agent.start()
if not async_mode:
processor_agent.run_single_parsing_loop()
@@ -455,13 +456,9 @@ def test_parse_once(self):
test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py')
async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn')
- processor_agent = DagFileProcessorAgent(test_dag_path,
- 1,
- type(self)._processor_factory,
- timedelta.max,
- [],
- False,
- async_mode)
+ processor_agent = DagFileProcessorAgent(
+ test_dag_path, 1, type(self)._processor_factory, timedelta.max, [], False, async_mode
+ )
processor_agent.start()
if not async_mode:
processor_agent.run_single_parsing_loop()
@@ -491,13 +488,9 @@ def test_launch_process(self):
pass
# Starting dag processing with 0 max_runs to avoid redundant operations.
- processor_agent = DagFileProcessorAgent(test_dag_path,
- 0,
- type(self)._processor_factory,
- timedelta.max,
- [],
- False,
- async_mode)
+ processor_agent = DagFileProcessorAgent(
+ test_dag_path, 0, type(self)._processor_factory, timedelta.max, [], False, async_mode
+ )
processor_agent.start()
if not async_mode:
processor_agent.run_single_parsing_loop()
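Most of the hunks above are driven by Black's "magic trailing comma": adding a comma after the last argument (as done for msg="Message" and async_mode=True) tells Black to keep the call exploded with one argument per line, while a call without one is collapsed whenever it fits within the configured line length. A minimal sketch of that behaviour, assuming a Black release new enough to implement the magic trailing comma and to expose format_str/FileMode; the 110-character line length is likewise an assumption for illustration only:

import black

# Line length assumed for this sketch.
mode = black.FileMode(line_length=110)

without_comma = (
    "DagFileProcessorProcess(\n"
    "    file_path,\n"
    "    pickle_dags,\n"
    "    dag_ids,\n"
    "    zombies\n"
    ")\n"
)
with_comma = without_comma.replace("zombies\n", "zombies,\n")

# No trailing comma and the call fits on one line, so Black collapses it.
print(black.format_str(without_comma, mode=mode), end="")
# A trailing ("magic") comma makes Black keep one argument per line.
print(black.format_str(with_comma, mode=mode), end="")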
diff --git a/tests/utils/test_dates.py b/tests/utils/test_dates.py
index c2f0f7b6861f4..0cab6ae58d574 100644
--- a/tests/utils/test_dates.py
+++ b/tests/utils/test_dates.py
@@ -25,7 +25,6 @@
class TestDates(unittest.TestCase):
-
def test_days_ago(self):
today = pendulum.today()
today_midnight = pendulum.instance(datetime.fromordinal(today.date().toordinal()))
@@ -44,39 +43,31 @@ def test_parse_execution_date(self):
bad_execution_date_str = '2017-11-06TXX:00:00Z'
self.assertEqual(
- timezone.datetime(2017, 11, 2, 0, 0, 0),
- dates.parse_execution_date(execution_date_str_wo_ms))
+ timezone.datetime(2017, 11, 2, 0, 0, 0), dates.parse_execution_date(execution_date_str_wo_ms)
+ )
self.assertEqual(
timezone.datetime(2017, 11, 5, 16, 18, 30, 989729),
- dates.parse_execution_date(execution_date_str_w_ms))
+ dates.parse_execution_date(execution_date_str_w_ms),
+ )
self.assertRaises(ValueError, dates.parse_execution_date, bad_execution_date_str)
class TestUtilsDatesDateRange(unittest.TestCase):
-
def test_no_delta(self):
- self.assertEqual(dates.date_range(datetime(2016, 1, 1), datetime(2016, 1, 3)),
- [])
+ self.assertEqual(dates.date_range(datetime(2016, 1, 1), datetime(2016, 1, 3)), [])
def test_end_date_before_start_date(self):
with self.assertRaisesRegex(Exception, "Wait. start_date needs to be before end_date"):
- dates.date_range(datetime(2016, 2, 1),
- datetime(2016, 1, 1),
- delta=timedelta(seconds=1))
+ dates.date_range(datetime(2016, 2, 1), datetime(2016, 1, 1), delta=timedelta(seconds=1))
def test_both_end_date_and_num_given(self):
with self.assertRaisesRegex(Exception, "Wait. Either specify end_date OR num"):
- dates.date_range(datetime(2016, 1, 1),
- datetime(2016, 1, 3),
- num=2,
- delta=timedelta(seconds=1))
+ dates.date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), num=2, delta=timedelta(seconds=1))
def test_invalid_delta(self):
exception_msg = "Wait. delta must be either datetime.timedelta or cron expression as str"
with self.assertRaisesRegex(Exception, exception_msg):
- dates.date_range(datetime(2016, 1, 1),
- datetime(2016, 1, 3),
- delta=1)
+ dates.date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), delta=1)
def test_positive_num_given(self):
for num in range(1, 10):
diff --git a/tests/utils/test_db.py b/tests/utils/test_db.py
index e724543f6b951..70a1a9c12794e 100644
--- a/tests/utils/test_db.py
+++ b/tests/utils/test_db.py
@@ -32,7 +32,6 @@
class TestDb(unittest.TestCase):
-
def test_database_schema_and_sqlalchemy_model_are_in_sync(self):
all_meta_data = MetaData()
for (table_name, table) in airflow_base.metadata.tables.items():
@@ -45,62 +44,35 @@ def test_database_schema_and_sqlalchemy_model_are_in_sync(self):
# known diffs to ignore
ignores = [
# ignore tables created by celery
- lambda t: (t[0] == 'remove_table' and
- t[1].name == 'celery_taskmeta'),
- lambda t: (t[0] == 'remove_table' and
- t[1].name == 'celery_tasksetmeta'),
-
+ lambda t: (t[0] == 'remove_table' and t[1].name == 'celery_taskmeta'),
+ lambda t: (t[0] == 'remove_table' and t[1].name == 'celery_tasksetmeta'),
# ignore indices created by celery
- lambda t: (t[0] == 'remove_index' and
- t[1].name == 'task_id'),
- lambda t: (t[0] == 'remove_index' and
- t[1].name == 'taskset_id'),
-
+ lambda t: (t[0] == 'remove_index' and t[1].name == 'task_id'),
+ lambda t: (t[0] == 'remove_index' and t[1].name == 'taskset_id'),
# Ignore all the fab tables
- lambda t: (t[0] == 'remove_table' and
- t[1].name == 'ab_permission'),
- lambda t: (t[0] == 'remove_table' and
- t[1].name == 'ab_register_user'),
- lambda t: (t[0] == 'remove_table' and
- t[1].name == 'ab_role'),
- lambda t: (t[0] == 'remove_table' and
- t[1].name == 'ab_permission_view'),
- lambda t: (t[0] == 'remove_table' and
- t[1].name == 'ab_permission_view_role'),
- lambda t: (t[0] == 'remove_table' and
- t[1].name == 'ab_user_role'),
- lambda t: (t[0] == 'remove_table' and
- t[1].name == 'ab_user'),
- lambda t: (t[0] == 'remove_table' and
- t[1].name == 'ab_view_menu'),
-
+ lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_permission'),
+ lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_register_user'),
+ lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_role'),
+ lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_permission_view'),
+ lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_permission_view_role'),
+ lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_user_role'),
+ lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_user'),
+ lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_view_menu'),
# Ignore all the fab indices
- lambda t: (t[0] == 'remove_index' and
- t[1].name == 'permission_id'),
- lambda t: (t[0] == 'remove_index' and
- t[1].name == 'name'),
- lambda t: (t[0] == 'remove_index' and
- t[1].name == 'user_id'),
- lambda t: (t[0] == 'remove_index' and
- t[1].name == 'username'),
- lambda t: (t[0] == 'remove_index' and
- t[1].name == 'field_string'),
- lambda t: (t[0] == 'remove_index' and
- t[1].name == 'email'),
- lambda t: (t[0] == 'remove_index' and
- t[1].name == 'permission_view_id'),
-
+ lambda t: (t[0] == 'remove_index' and t[1].name == 'permission_id'),
+ lambda t: (t[0] == 'remove_index' and t[1].name == 'name'),
+ lambda t: (t[0] == 'remove_index' and t[1].name == 'user_id'),
+ lambda t: (t[0] == 'remove_index' and t[1].name == 'username'),
+ lambda t: (t[0] == 'remove_index' and t[1].name == 'field_string'),
+ lambda t: (t[0] == 'remove_index' and t[1].name == 'email'),
+ lambda t: (t[0] == 'remove_index' and t[1].name == 'permission_view_id'),
# from test_security unit test
- lambda t: (t[0] == 'remove_table' and
- t[1].name == 'some_model'),
+ lambda t: (t[0] == 'remove_table' and t[1].name == 'some_model'),
]
for ignore in ignores:
diff = [d for d in diff if not ignore(d)]
- self.assertFalse(
- diff,
- 'Database schema and SQLAlchemy model are not in sync: ' + str(diff)
- )
+ self.assertFalse(diff, 'Database schema and SQLAlchemy model are not in sync: ' + str(diff))
def test_only_single_head_revision_in_migrations(self):
config = Config()
diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py
index 94920b9dee4cc..274eb77da0711 100644
--- a/tests/utils/test_decorators.py
+++ b/tests/utils/test_decorators.py
@@ -37,7 +37,6 @@ def __init__(self, test_sub_param, **kwargs):
class TestApplyDefault(unittest.TestCase):
-
def test_apply(self):
dummy = DummyClass(test_param=True)
self.assertTrue(dummy.test_param)
@@ -60,8 +59,7 @@ def test_default_args(self):
self.assertTrue(dummy_class.test_param)
self.assertTrue(dummy_subclass.test_sub_param)
- with self.assertRaisesRegex(AirflowException,
- 'Argument.*test_sub_param.*required'):
+ with self.assertRaisesRegex(AirflowException, 'Argument.*test_sub_param.*required'):
DummySubClass(default_args=default_args) # pylint: disable=no-value-for-parameter
def test_incorrect_default_args(self):
diff --git a/tests/utils/test_docs.py b/tests/utils/test_docs.py
index 8a86bcada557c..81d6a1d6e8e95 100644
--- a/tests/utils/test_docs.py
+++ b/tests/utils/test_docs.py
@@ -24,12 +24,14 @@
class TestGetDocsUrl(unittest.TestCase):
- @parameterized.expand([
- ('2.0.0.dev0', None, 'https://airflow.readthedocs.io/en/latest/'),
- ('2.0.0.dev0', 'migration.html', 'https://airflow.readthedocs.io/en/latest/migration.html'),
- ('1.10.0', None, 'https://airflow.apache.org/docs/1.10.0/'),
- ('1.10.0', 'migration.html', 'https://airflow.apache.org/docs/1.10.0/migration.html'),
- ])
+ @parameterized.expand(
+ [
+ ('2.0.0.dev0', None, 'https://airflow.readthedocs.io/en/latest/'),
+ ('2.0.0.dev0', 'migration.html', 'https://airflow.readthedocs.io/en/latest/migration.html'),
+ ('1.10.0', None, 'https://airflow.apache.org/docs/1.10.0/'),
+ ('1.10.0', 'migration.html', 'https://airflow.apache.org/docs/1.10.0/migration.html'),
+ ]
+ )
    def test_should_return_link(self, version, page, expected_url):
        with mock.patch('airflow.version.version', version):
            self.assertEqual(expected_url, get_docs_url(page))
diff --git a/tests/utils/test_email.py b/tests/utils/test_email.py
index 2a0810f21d599..8081fc85a74cf 100644
--- a/tests/utils/test_email.py
+++ b/tests/utils/test_email.py
@@ -34,48 +34,40 @@
class TestEmail(unittest.TestCase):
-
def test_get_email_address_single_email(self):
emails_string = 'test1@example.com'
- self.assertEqual(
- get_email_address_list(emails_string), [emails_string])
+ self.assertEqual(get_email_address_list(emails_string), [emails_string])
def test_get_email_address_comma_sep_string(self):
emails_string = 'test1@example.com, test2@example.com'
- self.assertEqual(
- get_email_address_list(emails_string), EMAILS)
+ self.assertEqual(get_email_address_list(emails_string), EMAILS)
def test_get_email_address_colon_sep_string(self):
emails_string = 'test1@example.com; test2@example.com'
- self.assertEqual(
- get_email_address_list(emails_string), EMAILS)
+ self.assertEqual(get_email_address_list(emails_string), EMAILS)
def test_get_email_address_list(self):
emails_list = ['test1@example.com', 'test2@example.com']
- self.assertEqual(
- get_email_address_list(emails_list), EMAILS)
+ self.assertEqual(get_email_address_list(emails_list), EMAILS)
def test_get_email_address_tuple(self):
emails_tuple = ('test1@example.com', 'test2@example.com')
- self.assertEqual(
- get_email_address_list(emails_tuple), EMAILS)
+ self.assertEqual(get_email_address_list(emails_tuple), EMAILS)
def test_get_email_address_invalid_type(self):
emails_string = 1
- self.assertRaises(
- TypeError, get_email_address_list, emails_string)
+ self.assertRaises(TypeError, get_email_address_list, emails_string)
def test_get_email_address_invalid_type_in_iterable(self):
emails_list = ['test1@example.com', 2]
- self.assertRaises(
- TypeError, get_email_address_list, emails_list)
+ self.assertRaises(TypeError, get_email_address_list, emails_list)
def setUp(self):
conf.remove_option('email', 'EMAIL_BACKEND')
@@ -91,8 +83,16 @@ def test_custom_backend(self, mock_send_email):
with conf_vars({('email', 'email_backend'): 'tests.utils.test_email.send_email_test'}):
utils.email.send_email('to', 'subject', 'content')
send_email_test.assert_called_once_with(
- 'to', 'subject', 'content', files=None, dryrun=False,
- cc=None, bcc=None, mime_charset='utf-8', mime_subtype='mixed')
+ 'to',
+ 'subject',
+ 'content',
+ files=None,
+ dryrun=False,
+ cc=None,
+ bcc=None,
+ mime_charset='utf-8',
+ mime_subtype='mixed',
+ )
self.assertFalse(mock_send_email.called)
def test_build_mime_message(self):
@@ -162,8 +162,10 @@ def test_send_bcc_smtp(self, mock_send_mime):
self.assertEqual('subject', msg['Subject'])
self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), msg['From'])
self.assertEqual(2, len(msg.get_payload()))
- self.assertEqual('attachment; filename="' + os.path.basename(attachment.name) + '"',
- msg.get_payload()[-1].get('Content-Disposition'))
+ self.assertEqual(
+ 'attachment; filename="' + os.path.basename(attachment.name) + '"',
+ msg.get_payload()[-1].get('Content-Disposition'),
+ )
mimeapp = MIMEApplication('attachment')
self.assertEqual(mimeapp.get_payload(), msg.get_payload()[-1].get_payload())
@@ -204,10 +206,12 @@ def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl):
def test_send_mime_noauth(self, mock_smtp, mock_smtp_ssl):
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
- with conf_vars({
- ('smtp', 'smtp_user'): None,
- ('smtp', 'smtp_password'): None,
- }):
+ with conf_vars(
+ {
+ ('smtp', 'smtp_user'): None,
+ ('smtp', 'smtp_password'): None,
+ }
+ ):
utils.email.send_mime_email('from', 'to', MIMEMultipart(), dryrun=False)
self.assertFalse(mock_smtp_ssl.called)
mock_smtp.assert_called_once_with(
diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py
index b513e3e401a0e..85c53c5a0cc48 100644
--- a/tests/utils/test_helpers.py
+++ b/tests/utils/test_helpers.py
@@ -27,7 +27,6 @@
class TestHelpers(unittest.TestCase):
-
def test_render_log_filename(self):
try_number = 1
dag_id = 'test_render_log_filename_dag'
@@ -41,10 +40,9 @@ def test_render_log_filename(self):
filename_template = "{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log"
ts = ti.get_template_context()['ts']
- expected_filename = "{dag_id}/{task_id}/{ts}/{try_number}.log".format(dag_id=dag_id,
- task_id=task_id,
- ts=ts,
- try_number=try_number)
+ expected_filename = "{dag_id}/{task_id}/{ts}/{try_number}.log".format(
+ dag_id=dag_id, task_id=task_id, ts=ts, try_number=try_number
+ )
rendered_filename = helpers.render_log_filename(ti, try_number, filename_template)
@@ -62,22 +60,15 @@ def test_chunks(self):
self.assertEqual(list(helpers.chunks([1, 2, 3], 2)), [[1, 2], [3]])
def test_reduce_in_chunks(self):
- self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + [y],
- [1, 2, 3, 4, 5],
- []),
- [[1, 2, 3, 4, 5]])
-
- self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + [y],
- [1, 2, 3, 4, 5],
- [],
- 2),
- [[1, 2], [3, 4], [5]])
-
- self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + y[0] * y[1],
- [1, 2, 3, 4],
- 0,
- 2),
- 14)
+ self.assertEqual(
+ helpers.reduce_in_chunks(lambda x, y: x + [y], [1, 2, 3, 4, 5], []), [[1, 2, 3, 4, 5]]
+ )
+
+ self.assertEqual(
+ helpers.reduce_in_chunks(lambda x, y: x + [y], [1, 2, 3, 4, 5], [], 2), [[1, 2], [3, 4], [5]]
+ )
+
+ self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + y[0] * y[1], [1, 2, 3, 4], 0, 2), 14)
def test_is_container(self):
self.assertFalse(helpers.is_container("a string is not a container"))
@@ -89,14 +80,10 @@ def test_is_container(self):
self.assertFalse(helpers.is_container(10))
def test_as_tuple(self):
- self.assertEqual(
- helpers.as_tuple("a string is not a container"),
- ("a string is not a container",)
- )
+ self.assertEqual(helpers.as_tuple("a string is not a container"), ("a string is not a container",))
self.assertEqual(
- helpers.as_tuple(["a", "list", "is", "a", "container"]),
- ("a", "list", "is", "a", "container")
+ helpers.as_tuple(["a", "list", "is", "a", "container"]), ("a", "list", "is", "a", "container")
)
def test_as_tuple_iter(self):
@@ -111,8 +98,7 @@ def test_as_tuple_no_iter(self):
def test_convert_camel_to_snake(self):
self.assertEqual(helpers.convert_camel_to_snake('LocalTaskJob'), 'local_task_job')
- self.assertEqual(helpers.convert_camel_to_snake('somethingVeryRandom'),
- 'something_very_random')
+ self.assertEqual(helpers.convert_camel_to_snake('somethingVeryRandom'), 'something_very_random')
def test_merge_dicts(self):
"""
diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py
index 69363b84feb8d..a127ff98e6cab 100644
--- a/tests/utils/test_json.py
+++ b/tests/utils/test_json.py
@@ -26,37 +26,21 @@
class TestAirflowJsonEncoder(unittest.TestCase):
-
def test_encode_datetime(self):
obj = datetime.strptime('2017-05-21 00:00:00', '%Y-%m-%d %H:%M:%S')
- self.assertEqual(
- json.dumps(obj, cls=utils_json.AirflowJsonEncoder),
- '"2017-05-21T00:00:00Z"'
- )
+ self.assertEqual(json.dumps(obj, cls=utils_json.AirflowJsonEncoder), '"2017-05-21T00:00:00Z"')
def test_encode_date(self):
- self.assertEqual(
- json.dumps(date(2017, 5, 21), cls=utils_json.AirflowJsonEncoder),
- '"2017-05-21"'
- )
+ self.assertEqual(json.dumps(date(2017, 5, 21), cls=utils_json.AirflowJsonEncoder), '"2017-05-21"')
def test_encode_numpy_int(self):
- self.assertEqual(
- json.dumps(np.int32(5), cls=utils_json.AirflowJsonEncoder),
- '5'
- )
+ self.assertEqual(json.dumps(np.int32(5), cls=utils_json.AirflowJsonEncoder), '5')
def test_encode_numpy_bool(self):
- self.assertEqual(
- json.dumps(np.bool_(True), cls=utils_json.AirflowJsonEncoder),
- 'true'
- )
+ self.assertEqual(json.dumps(np.bool_(True), cls=utils_json.AirflowJsonEncoder), 'true')
def test_encode_numpy_float(self):
- self.assertEqual(
- json.dumps(np.float16(3.76953125), cls=utils_json.AirflowJsonEncoder),
- '3.76953125'
- )
+ self.assertEqual(json.dumps(np.float16(3.76953125), cls=utils_json.AirflowJsonEncoder), '3.76953125')
def test_encode_k8s_v1pod(self):
from kubernetes.client import models as k8s
@@ -73,17 +57,21 @@ def test_encode_k8s_v1pod(self):
image="bar",
)
]
- )
+ ),
)
self.assertEqual(
json.loads(json.dumps(pod, cls=utils_json.AirflowJsonEncoder)),
- {"metadata": {"name": "foo", "namespace": "bar"},
- "spec": {"containers": [{"image": "bar", "name": "foo"}]}}
+ {
+ "metadata": {"name": "foo", "namespace": "bar"},
+ "spec": {"containers": [{"image": "bar", "name": "foo"}]},
+ },
)
def test_encode_raises(self):
- self.assertRaisesRegex(TypeError,
- "^.*is not JSON serializable$",
- json.dumps,
- Exception,
- cls=utils_json.AirflowJsonEncoder)
+ self.assertRaisesRegex(
+ TypeError,
+ "^.*is not JSON serializable$",
+ json.dumps,
+ Exception,
+ cls=utils_json.AirflowJsonEncoder,
+ )
diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py
index df23ed1af4b36..32a8d251226cb 100644
--- a/tests/utils/test_log_handlers.py
+++ b/tests/utils/test_log_handlers.py
@@ -38,7 +38,6 @@
class TestFileTaskLogHandler(unittest.TestCase):
-
def clean_up(self):
with create_session() as session:
session.query(DagRun).delete()
@@ -66,6 +65,7 @@ def test_default_task_logging_setup(self):
def test_file_task_handler(self):
def task_callable(ti, **kwargs):
ti.log.info("test")
+
dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE)
dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)
task = PythonOperator(
@@ -78,8 +78,9 @@ def task_callable(ti, **kwargs):
logger = ti.log
ti.log.disabled = False
- file_handler = next((handler for handler in logger.handlers
- if handler.name == FILE_TASK_HANDLER), None)
+ file_handler = next(
+ (handler for handler in logger.handlers if handler.name == FILE_TASK_HANDLER), None
+ )
self.assertIsNotNone(file_handler)
set_context(logger, ti)
@@ -106,11 +107,7 @@ def task_callable(ti, **kwargs):
# We should expect our log line from the callable above to appear in
# the logs we read back
- self.assertRegex(
- logs[0][0][-1],
- target_re,
- "Logs were " + str(logs)
- )
+ self.assertRegex(logs[0][0][-1], target_re, "Logs were " + str(logs))
# Remove the generated tmp log file.
os.remove(log_filename)
@@ -118,6 +115,7 @@ def task_callable(ti, **kwargs):
def test_file_task_handler_running(self):
def task_callable(ti, **kwargs):
ti.log.info("test")
+
dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE)
task = PythonOperator(
task_id='task_for_testing_file_log_handler',
@@ -131,8 +129,9 @@ def task_callable(ti, **kwargs):
logger = ti.log
ti.log.disabled = False
- file_handler = next((handler for handler in logger.handlers
- if handler.name == FILE_TASK_HANDLER), None)
+ file_handler = next(
+ (handler for handler in logger.handlers if handler.name == FILE_TASK_HANDLER), None
+ )
self.assertIsNotNone(file_handler)
set_context(logger, ti)
@@ -159,25 +158,26 @@ def task_callable(ti, **kwargs):
class TestFilenameRendering(unittest.TestCase):
-
def setUp(self):
dag = DAG('dag_for_testing_filename_rendering', start_date=DEFAULT_DATE)
task = DummyOperator(task_id='task_for_testing_filename_rendering', dag=dag)
self.ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
def test_python_formatting(self):
- expected_filename = \
- 'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log' \
+ expected_filename = (
+ 'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log'
% DEFAULT_DATE.isoformat()
+ )
fth = FileTaskHandler('', '{dag_id}/{task_id}/{execution_date}/{try_number}.log')
rendered_filename = fth._render_filename(self.ti, 42)
self.assertEqual(expected_filename, rendered_filename)
def test_jinja_rendering(self):
- expected_filename = \
- 'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log' \
+ expected_filename = (
+ 'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log'
% DEFAULT_DATE.isoformat()
+ )
fth = FileTaskHandler('', '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log')
rendered_filename = fth._render_filename(self.ti, 42)
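The TestFilenameRendering hunk above only swaps the backslash line continuations for a parenthesized expression; the % formatting still applies to the same literal, so the expected filename is unchanged. A small self-contained check (the date below is a stand-in for illustration, not the test suite's real DEFAULT_DATE):

from datetime import datetime

DEFAULT_DATE = datetime(2019, 1, 1)  # stand-in value for illustration only

old_style = \
    'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log' \
    % DEFAULT_DATE.isoformat()
new_style = (
    'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log'
    % DEFAULT_DATE.isoformat()
)

# Both spellings build the identical string; only the source layout differs.
assert old_style == new_style
print(new_style)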
diff --git a/tests/utils/test_logging_mixin.py b/tests/utils/test_logging_mixin.py
index 7a6bda50f83d8..583eefd4bccbf 100644
--- a/tests/utils/test_logging_mixin.py
+++ b/tests/utils/test_logging_mixin.py
@@ -25,18 +25,20 @@
class TestLoggingMixin(unittest.TestCase):
def setUp(self):
- warnings.filterwarnings(
- action='always'
- )
+ warnings.filterwarnings(action='always')
def test_set_context(self):
handler1 = mock.MagicMock()
handler2 = mock.MagicMock()
parent = mock.MagicMock()
parent.propagate = False
- parent.handlers = [handler1, ]
+ parent.handlers = [
+ handler1,
+ ]
log = mock.MagicMock()
- log.handlers = [handler2, ]
+ log.handlers = [
+ handler2,
+ ]
log.parent = parent
log.propagate = True
diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py
index 2217785b4f4d4..2f06371018a7f 100644
--- a/tests/utils/test_net.py
+++ b/tests/utils/test_net.py
@@ -29,7 +29,6 @@ def get_hostname():
class TestGetHostname(unittest.TestCase):
-
@mock.patch('socket.getfqdn', return_value='first')
@conf_vars({('core', 'hostname_callable'): None})
def test_get_hostname_unset(self, mock_getfqdn):
@@ -51,6 +50,6 @@ def test_get_hostname_set_missing(self):
re.escape(
'The object could not be loaded. Please check "hostname_callable" key in "core" section. '
'Current value: "tests.utils.test_net.missing_func"'
- )
+ ),
):
net.get_hostname()
diff --git a/tests/utils/test_operator_helpers.py b/tests/utils/test_operator_helpers.py
index d6a861287567d..c448901f3c8b8 100644
--- a/tests/utils/test_operator_helpers.py
+++ b/tests/utils/test_operator_helpers.py
@@ -24,7 +24,6 @@
class TestOperatorHelpers(unittest.TestCase):
-
def setUp(self):
super().setUp()
self.dag_id = 'dag_id'
@@ -37,21 +36,15 @@ def setUp(self):
'dag_run': mock.MagicMock(
name='dag_run',
run_id=self.dag_run_id,
- execution_date=datetime.strptime(self.execution_date,
- '%Y-%m-%dT%H:%M:%S'),
+ execution_date=datetime.strptime(self.execution_date, '%Y-%m-%dT%H:%M:%S'),
),
'task_instance': mock.MagicMock(
name='task_instance',
task_id=self.task_id,
dag_id=self.dag_id,
- execution_date=datetime.strptime(self.execution_date,
- '%Y-%m-%dT%H:%M:%S'),
+ execution_date=datetime.strptime(self.execution_date, '%Y-%m-%dT%H:%M:%S'),
),
- 'task': mock.MagicMock(
- name='task',
- owner=self.owner,
- email=self.email
- )
+ 'task': mock.MagicMock(name='task', owner=self.owner, email=self.email),
}
def test_context_to_airflow_vars_empty_context(self):
@@ -66,19 +59,18 @@ def test_context_to_airflow_vars_all_context(self):
'airflow.ctx.task_id': self.task_id,
'airflow.ctx.dag_run_id': self.dag_run_id,
'airflow.ctx.dag_owner': 'owner1,owner2',
- 'airflow.ctx.dag_email': 'email1@test.com'
- }
+ 'airflow.ctx.dag_email': 'email1@test.com',
+ },
)
self.assertDictEqual(
- operator_helpers.context_to_airflow_vars(self.context,
- in_env_var_format=True),
+ operator_helpers.context_to_airflow_vars(self.context, in_env_var_format=True),
{
'AIRFLOW_CTX_DAG_ID': self.dag_id,
'AIRFLOW_CTX_EXECUTION_DATE': self.execution_date,
'AIRFLOW_CTX_TASK_ID': self.task_id,
'AIRFLOW_CTX_DAG_RUN_ID': self.dag_run_id,
'AIRFLOW_CTX_DAG_OWNER': 'owner1,owner2',
- 'AIRFLOW_CTX_DAG_EMAIL': 'email1@test.com'
- }
+ 'AIRFLOW_CTX_DAG_EMAIL': 'email1@test.com',
+ },
)
diff --git a/tests/utils/test_process_utils.py b/tests/utils/test_process_utils.py
index 1620bfe747bf9..ee1097e93f22c 100644
--- a/tests/utils/test_process_utils.py
+++ b/tests/utils/test_process_utils.py
@@ -39,7 +39,6 @@
class TestReapProcessGroup(unittest.TestCase):
-
@staticmethod
def _ignores_sigterm(child_pid, child_setup_done):
def signal_handler(unused_signum, unused_frame):
@@ -55,11 +54,13 @@ def signal_handler(unused_signum, unused_frame):
def _parent_of_ignores_sigterm(parent_pid, child_pid, setup_done):
def signal_handler(unused_signum, unused_frame):
pass
+
os.setsid()
signal.signal(signal.SIGTERM, signal_handler)
child_setup_done = multiprocessing.Semaphore(0)
- child = multiprocessing.Process(target=TestReapProcessGroup._ignores_sigterm,
- args=[child_pid, child_setup_done])
+ child = multiprocessing.Process(
+ target=TestReapProcessGroup._ignores_sigterm, args=[child_pid, child_setup_done]
+ )
child.start()
child_setup_done.acquire(timeout=5.0)
parent_pid.value = os.getpid()
@@ -96,19 +97,13 @@ def test_reap_process_group(self):
class TestExecuteInSubProcess(unittest.TestCase):
-
def test_should_print_all_messages1(self):
with self.assertLogs(log) as logs:
execute_in_subprocess(["bash", "-c", "echo CAT; echo KITTY;"])
msgs = [record.getMessage() for record in logs.records]
- self.assertEqual([
- "Executing cmd: bash -c 'echo CAT; echo KITTY;'",
- 'Output:',
- 'CAT',
- 'KITTY'
- ], msgs)
+ self.assertEqual(["Executing cmd: bash -c 'echo CAT; echo KITTY;'", 'Output:', 'CAT', 'KITTY'], msgs)
def test_should_raise_exception(self):
with self.assertRaises(CalledProcessError):
diff --git a/tests/utils/test_python_virtualenv.py b/tests/utils/test_python_virtualenv.py
index 218bce343ddaf..59ad504b84e1f 100644
--- a/tests/utils/test_python_virtualenv.py
+++ b/tests/utils/test_python_virtualenv.py
@@ -23,14 +23,10 @@
class TestPrepareVirtualenv(unittest.TestCase):
-
@mock.patch('airflow.utils.python_virtualenv.execute_in_subprocess')
def test_should_create_virtualenv(self, mock_execute_in_subprocess):
python_bin = prepare_virtualenv(
- venv_directory="/VENV",
- python_bin="pythonVER",
- system_site_packages=False,
- requirements=[]
+ venv_directory="/VENV", python_bin="pythonVER", system_site_packages=False, requirements=[]
)
self.assertEqual("/VENV/bin/python", python_bin)
mock_execute_in_subprocess.assert_called_once_with(['virtualenv', '/VENV', '--python=pythonVER'])
@@ -38,10 +34,7 @@ def test_should_create_virtualenv(self, mock_execute_in_subprocess):
@mock.patch('airflow.utils.python_virtualenv.execute_in_subprocess')
def test_should_create_virtualenv_with_system_packages(self, mock_execute_in_subprocess):
python_bin = prepare_virtualenv(
- venv_directory="/VENV",
- python_bin="pythonVER",
- system_site_packages=True,
- requirements=[]
+ venv_directory="/VENV", python_bin="pythonVER", system_site_packages=True, requirements=[]
)
self.assertEqual("/VENV/bin/python", python_bin)
mock_execute_in_subprocess.assert_called_once_with(
@@ -54,14 +47,10 @@ def test_should_create_virtualenv_with_extra_packages(self, mock_execute_in_subp
venv_directory="/VENV",
python_bin="pythonVER",
system_site_packages=False,
- requirements=['apache-beam[gcp]']
+ requirements=['apache-beam[gcp]'],
)
self.assertEqual("/VENV/bin/python", python_bin)
- mock_execute_in_subprocess.assert_any_call(
- ['virtualenv', '/VENV', '--python=pythonVER']
- )
+ mock_execute_in_subprocess.assert_any_call(['virtualenv', '/VENV', '--python=pythonVER'])
- mock_execute_in_subprocess.assert_called_with(
- ['/VENV/bin/pip', 'install', 'apache-beam[gcp]']
- )
+ mock_execute_in_subprocess.assert_called_with(['/VENV/bin/pip', 'install', 'apache-beam[gcp]'])
diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py
index ac8a52a776ef9..dd6729bb6882d 100644
--- a/tests/utils/test_sqlalchemy.py
+++ b/tests/utils/test_sqlalchemy.py
@@ -94,28 +94,66 @@ def test_process_bind_param_naive(self):
state=State.NONE,
execution_date=start_date,
start_date=start_date,
- session=self.session
+ session=self.session,
)
dag.clear()
- @parameterized.expand([
- ("postgresql", True, {'skip_locked': True}, ),
- ("mysql", False, {}, ),
- ("mysql", True, {'skip_locked': True}, ),
- ("sqlite", False, {'skip_locked': True}, ),
- ])
+ @parameterized.expand(
+ [
+ (
+ "postgresql",
+ True,
+ {'skip_locked': True},
+ ),
+ (
+ "mysql",
+ False,
+ {},
+ ),
+ (
+ "mysql",
+ True,
+ {'skip_locked': True},
+ ),
+ (
+ "sqlite",
+ False,
+ {'skip_locked': True},
+ ),
+ ]
+ )
def test_skip_locked(self, dialect, supports_for_update_of, expected_return_value):
session = mock.Mock()
session.bind.dialect.name = dialect
session.bind.dialect.supports_for_update_of = supports_for_update_of
self.assertEqual(skip_locked(session=session), expected_return_value)
- @parameterized.expand([
- ("postgresql", True, {'nowait': True}, ),
- ("mysql", False, {}, ),
- ("mysql", True, {'nowait': True}, ),
- ("sqlite", False, {'nowait': True, }, ),
- ])
+ @parameterized.expand(
+ [
+ (
+ "postgresql",
+ True,
+ {'nowait': True},
+ ),
+ (
+ "mysql",
+ False,
+ {},
+ ),
+ (
+ "mysql",
+ True,
+ {'nowait': True},
+ ),
+ (
+ "sqlite",
+ False,
+ {
+ 'nowait': True,
+ },
+ ),
+ ]
+ )
def test_nowait(self, dialect, supports_for_update_of, expected_return_value):
session = mock.Mock()
session.bind.dialect.name = dialect
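In the test_sqlalchemy.py hunks above, each parameter tuple expands to one element per line because it already ended with a trailing comma, which Black now treats as a request to keep the collection exploded. The comma itself never changes the data handed to @parameterized.expand, as this stand-alone check shows:

compact = ("postgresql", True, {'skip_locked': True})
exploded = (
    "postgresql",
    True,
    {'skip_locked': True},
)

# The trailing comma only affects Black's layout; both tuples compare equal.
assert compact == exploded
print(exploded)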
diff --git a/tests/utils/test_task_handler_with_custom_formatter.py b/tests/utils/test_task_handler_with_custom_formatter.py
index 37f874442252c..3c9ea13b092c3 100644
--- a/tests/utils/test_task_handler_with_custom_formatter.py
+++ b/tests/utils/test_task_handler_with_custom_formatter.py
@@ -29,8 +29,7 @@
DEFAULT_DATE = datetime(2019, 1, 1)
TASK_LOGGER = 'airflow.task'
TASK_HANDLER = 'task'
-TASK_HANDLER_CLASS = 'airflow.utils.log.task_handler_with_custom_formatter.' \
- 'TaskHandlerWithCustomFormatter'
+TASK_HANDLER_CLASS = 'airflow.utils.log.task_handler_with_custom_formatter.TaskHandlerWithCustomFormatter'
PREV_TASK_HANDLER = DEFAULT_LOGGING_CONFIG['handlers']['task']
@@ -40,7 +39,7 @@ def setUp(self):
DEFAULT_LOGGING_CONFIG['handlers']['task'] = {
'class': TASK_HANDLER_CLASS,
'formatter': 'airflow',
- 'stream': 'sys.stdout'
+ 'stream': 'sys.stdout',
}
logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)
diff --git a/tests/utils/test_timezone.py b/tests/utils/test_timezone.py
index 400813e88a4cc..d7249bde1cdd4 100644
--- a/tests/utils/test_timezone.py
+++ b/tests/utils/test_timezone.py
@@ -55,10 +55,12 @@ def test_convert_to_utc(self):
def test_make_naive(self):
self.assertEqual(
timezone.make_naive(datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT), EAT),
- datetime.datetime(2011, 9, 1, 13, 20, 30))
+ datetime.datetime(2011, 9, 1, 13, 20, 30),
+ )
self.assertEqual(
timezone.make_naive(datetime.datetime(2011, 9, 1, 17, 20, 30, tzinfo=ICT), EAT),
- datetime.datetime(2011, 9, 1, 13, 20, 30))
+ datetime.datetime(2011, 9, 1, 13, 20, 30),
+ )
with self.assertRaises(ValueError):
timezone.make_naive(datetime.datetime(2011, 9, 1, 13, 20, 30), EAT)
@@ -66,6 +68,7 @@ def test_make_naive(self):
def test_make_aware(self):
self.assertEqual(
timezone.make_aware(datetime.datetime(2011, 9, 1, 13, 20, 30), EAT),
- datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT))
+ datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT),
+ )
with self.assertRaises(ValueError):
timezone.make_aware(datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT), EAT)
diff --git a/tests/utils/test_trigger_rule.py b/tests/utils/test_trigger_rule.py
index 1ea399ebde3b7..afbe0fcf48a37 100644
--- a/tests/utils/test_trigger_rule.py
+++ b/tests/utils/test_trigger_rule.py
@@ -22,7 +22,6 @@
class TestTriggerRule(unittest.TestCase):
-
def test_valid_trigger_rules(self):
self.assertTrue(TriggerRule.is_valid(TriggerRule.ALL_SUCCESS))
self.assertTrue(TriggerRule.is_valid(TriggerRule.ALL_FAILED))
diff --git a/tests/utils/test_weight_rule.py b/tests/utils/test_weight_rule.py
index de63d5f9c276f..862e1b426f7fa 100644
--- a/tests/utils/test_weight_rule.py
+++ b/tests/utils/test_weight_rule.py
@@ -22,7 +22,6 @@
class TestWeightRule(unittest.TestCase):
-
def test_valid_weight_rules(self):
self.assertTrue(WeightRule.is_valid(WeightRule.DOWNSTREAM))
self.assertTrue(WeightRule.is_valid(WeightRule.UPSTREAM))
diff --git a/tests/www/api/experimental/test_dag_runs_endpoint.py b/tests/www/api/experimental/test_dag_runs_endpoint.py
index 8e7502f321f06..4e020c76e6188 100644
--- a/tests/www/api/experimental/test_dag_runs_endpoint.py
+++ b/tests/www/api/experimental/test_dag_runs_endpoint.py
@@ -26,7 +26,6 @@
class TestDagRunsEndpoint(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
super().setUpClass()
@@ -55,8 +54,7 @@ def test_get_dag_runs_success(self):
url_template = '/api/experimental/dags/{}/dag_runs'
dag_id = 'example_bash_operator'
# Create DagRun
- dag_run = trigger_dag(
- dag_id=dag_id, run_id='test_get_dag_runs_success')
+ dag_run = trigger_dag(dag_id=dag_id, run_id='test_get_dag_runs_success')
response = self.app.get(url_template.format(dag_id))
self.assertEqual(200, response.status_code)
@@ -71,8 +69,7 @@ def test_get_dag_runs_success_with_state_parameter(self):
url_template = '/api/experimental/dags/{}/dag_runs?state=running'
dag_id = 'example_bash_operator'
# Create DagRun
- dag_run = trigger_dag(
- dag_id=dag_id, run_id='test_get_dag_runs_success')
+ dag_run = trigger_dag(dag_id=dag_id, run_id='test_get_dag_runs_success')
response = self.app.get(url_template.format(dag_id))
self.assertEqual(200, response.status_code)
@@ -87,8 +84,7 @@ def test_get_dag_runs_success_with_capital_state_parameter(self):
url_template = '/api/experimental/dags/{}/dag_runs?state=RUNNING'
dag_id = 'example_bash_operator'
# Create DagRun
- dag_run = trigger_dag(
- dag_id=dag_id, run_id='test_get_dag_runs_success')
+ dag_run = trigger_dag(dag_id=dag_id, run_id='test_get_dag_runs_success')
response = self.app.get(url_template.format(dag_id))
self.assertEqual(200, response.status_code)
diff --git a/tests/www/api/experimental/test_endpoints.py b/tests/www/api/experimental/test_endpoints.py
index 760e0ac639e44..cbdf1c44c6b0c 100644
--- a/tests/www/api/experimental/test_endpoints.py
+++ b/tests/www/api/experimental/test_endpoints.py
@@ -53,13 +53,11 @@ def assert_deprecated(self, resp):
self.assertEqual('true', resp.headers['Deprecation'])
self.assertRegex(
resp.headers['Link'],
- r'\<.+/stable-rest-api/migration.html\>; '
- 'rel="deprecation"; type="text/html"',
+ r'\<.+/stable-rest-api/migration.html\>; ' 'rel="deprecation"; type="text/html"',
)
class TestApiExperimental(TestBase):
-
@classmethod
def setUpClass(cls):
super().setUpClass()
@@ -94,40 +92,30 @@ def test_info(self):
def test_task_info(self):
url_template = '/api/experimental/dags/{}/tasks/{}'
- response = self.client.get(
- url_template.format('example_bash_operator', 'runme_0')
- )
+ response = self.client.get(url_template.format('example_bash_operator', 'runme_0'))
self.assert_deprecated(response)
self.assertIn('"email"', response.data.decode('utf-8'))
self.assertNotIn('error', response.data.decode('utf-8'))
self.assertEqual(200, response.status_code)
- response = self.client.get(
- url_template.format('example_bash_operator', 'DNE')
- )
+ response = self.client.get(url_template.format('example_bash_operator', 'DNE'))
self.assertIn('error', response.data.decode('utf-8'))
self.assertEqual(404, response.status_code)
- response = self.client.get(
- url_template.format('DNE', 'DNE')
- )
+ response = self.client.get(url_template.format('DNE', 'DNE'))
self.assertIn('error', response.data.decode('utf-8'))
self.assertEqual(404, response.status_code)
def test_get_dag_code(self):
url_template = '/api/experimental/dags/{}/code'
- response = self.client.get(
- url_template.format('example_bash_operator')
- )
+ response = self.client.get(url_template.format('example_bash_operator'))
self.assert_deprecated(response)
self.assertIn('BashOperator(', response.data.decode('utf-8'))
self.assertEqual(200, response.status_code)
- response = self.client.get(
- url_template.format('xyz')
- )
+ response = self.client.get(url_template.format('xyz'))
self.assertEqual(404, response.status_code)
def test_dag_paused(self):
@@ -135,9 +123,7 @@ def test_dag_paused(self):
paused_url_template = '/api/experimental/dags/{}/paused'
paused_url = paused_url_template.format('example_bash_operator')
- response = self.client.get(
- pause_url_template.format('example_bash_operator', 'true')
- )
+ response = self.client.get(pause_url_template.format('example_bash_operator', 'true'))
self.assert_deprecated(response)
self.assertIn('ok', response.data.decode('utf-8'))
self.assertEqual(200, response.status_code)
@@ -147,9 +133,7 @@ def test_dag_paused(self):
self.assertEqual(200, paused_response.status_code)
self.assertEqual({"is_paused": True}, paused_response.json)
- response = self.client.get(
- pause_url_template.format('example_bash_operator', 'false')
- )
+ response = self.client.get(pause_url_template.format('example_bash_operator', 'false'))
self.assertIn('ok', response.data.decode('utf-8'))
self.assertEqual(200, response.status_code)
@@ -164,13 +148,12 @@ def test_trigger_dag(self):
response = self.client.post(
url_template.format('example_bash_operator'),
data=json.dumps({'run_id': run_id}),
- content_type="application/json"
+ content_type="application/json",
)
self.assert_deprecated(response)
self.assertEqual(200, response.status_code)
- response_execution_date = parse_datetime(
- json.loads(response.data.decode('utf-8'))['execution_date'])
+ response_execution_date = parse_datetime(json.loads(response.data.decode('utf-8'))['execution_date'])
self.assertEqual(0, response_execution_date.microsecond)
# Check execution_date is correct
@@ -184,9 +167,7 @@ def test_trigger_dag(self):
# Test error for nonexistent dag
response = self.client.post(
- url_template.format('does_not_exist_dag'),
- data=json.dumps({}),
- content_type="application/json"
+ url_template.format('does_not_exist_dag'), data=json.dumps({}), content_type="application/json"
)
self.assertEqual(404, response.status_code)
@@ -200,7 +181,7 @@ def test_trigger_dag_for_date(self):
response = self.client.post(
url_template.format(dag_id),
data=json.dumps({'execution_date': datetime_string}),
- content_type="application/json"
+ content_type="application/json",
)
self.assert_deprecated(response)
self.assertEqual(200, response.status_code)
@@ -209,33 +190,28 @@ def test_trigger_dag_for_date(self):
dagbag = DagBag()
dag = dagbag.get_dag(dag_id)
dag_run = dag.get_dagrun(execution_date)
- self.assertTrue(dag_run,
- 'Dag Run not found for execution date {}'
- .format(execution_date))
+ self.assertTrue(dag_run, f'Dag Run not found for execution date {execution_date}')
# Test correct execution with execution date and microseconds replaced
response = self.client.post(
url_template.format(dag_id),
data=json.dumps({'execution_date': datetime_string, 'replace_microseconds': 'true'}),
- content_type="application/json"
+ content_type="application/json",
)
self.assertEqual(200, response.status_code)
- response_execution_date = parse_datetime(
- json.loads(response.data.decode('utf-8'))['execution_date'])
+ response_execution_date = parse_datetime(json.loads(response.data.decode('utf-8'))['execution_date'])
self.assertEqual(0, response_execution_date.microsecond)
dagbag = DagBag()
dag = dagbag.get_dag(dag_id)
dag_run = dag.get_dagrun(response_execution_date)
- self.assertTrue(dag_run,
- 'Dag Run not found for execution date {}'
- .format(execution_date))
+ self.assertTrue(dag_run, f'Dag Run not found for execution date {execution_date}')
# Test error for nonexistent dag
response = self.client.post(
url_template.format('does_not_exist_dag'),
data=json.dumps({'execution_date': datetime_string}),
- content_type="application/json"
+ content_type="application/json",
)
self.assertEqual(404, response.status_code)
@@ -243,7 +219,7 @@ def test_trigger_dag_for_date(self):
response = self.client.post(
url_template.format(dag_id),
data=json.dumps({'execution_date': 'not_a_datetime'}),
- content_type="application/json"
+ content_type="application/json",
)
self.assertEqual(400, response.status_code)
@@ -253,19 +229,13 @@ def test_task_instance_info(self):
task_id = 'also_run_this'
execution_date = utcnow().replace(microsecond=0)
datetime_string = quote_plus(execution_date.isoformat())
- wrong_datetime_string = quote_plus(
- datetime(1990, 1, 1, 1, 1, 1).isoformat()
- )
+ wrong_datetime_string = quote_plus(datetime(1990, 1, 1, 1, 1, 1).isoformat())
# Create DagRun
- trigger_dag(dag_id=dag_id,
- run_id='test_task_instance_info_run',
- execution_date=execution_date)
+ trigger_dag(dag_id=dag_id, run_id='test_task_instance_info_run', execution_date=execution_date)
# Test Correct execution
- response = self.client.get(
- url_template.format(dag_id, datetime_string, task_id)
- )
+ response = self.client.get(url_template.format(dag_id, datetime_string, task_id))
self.assert_deprecated(response)
self.assertEqual(200, response.status_code)
self.assertIn('state', response.data.decode('utf-8'))
@@ -273,30 +243,23 @@ def test_task_instance_info(self):
# Test error for nonexistent dag
response = self.client.get(
- url_template.format('does_not_exist_dag', datetime_string,
- task_id),
+ url_template.format('does_not_exist_dag', datetime_string, task_id),
)
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for nonexistent task
- response = self.client.get(
- url_template.format(dag_id, datetime_string, 'does_not_exist_task')
- )
+ response = self.client.get(url_template.format(dag_id, datetime_string, 'does_not_exist_task'))
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for nonexistent dag run (wrong execution_date)
- response = self.client.get(
- url_template.format(dag_id, wrong_datetime_string, task_id)
- )
+ response = self.client.get(url_template.format(dag_id, wrong_datetime_string, task_id))
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for bad datetime format
- response = self.client.get(
- url_template.format(dag_id, 'not_a_datetime', task_id)
- )
+ response = self.client.get(url_template.format(dag_id, 'not_a_datetime', task_id))
self.assertEqual(400, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
@@ -305,19 +268,13 @@ def test_dagrun_status(self):
dag_id = 'example_bash_operator'
execution_date = utcnow().replace(microsecond=0)
datetime_string = quote_plus(execution_date.isoformat())
- wrong_datetime_string = quote_plus(
- datetime(1990, 1, 1, 1, 1, 1).isoformat()
- )
+ wrong_datetime_string = quote_plus(datetime(1990, 1, 1, 1, 1, 1).isoformat())
# Create DagRun
- trigger_dag(dag_id=dag_id,
- run_id='test_task_instance_info_run',
- execution_date=execution_date)
+ trigger_dag(dag_id=dag_id, run_id='test_task_instance_info_run', execution_date=execution_date)
# Test Correct execution
- response = self.client.get(
- url_template.format(dag_id, datetime_string)
- )
+ response = self.client.get(url_template.format(dag_id, datetime_string))
self.assert_deprecated(response)
self.assertEqual(200, response.status_code)
self.assertIn('state', response.data.decode('utf-8'))
@@ -331,16 +288,12 @@ def test_dagrun_status(self):
self.assertIn('error', response.data.decode('utf-8'))
# Test error for nonexistent dag run (wrong execution_date)
- response = self.client.get(
- url_template.format(dag_id, wrong_datetime_string)
- )
+ response = self.client.get(url_template.format(dag_id, wrong_datetime_string))
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for bad datetime format
- response = self.client.get(
- url_template.format(dag_id, 'not_a_datetime')
- )
+ response = self.client.get(url_template.format(dag_id, 'not_a_datetime'))
self.assertEqual(400, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
@@ -368,19 +321,13 @@ def test_lineage_info(self):
dag_id = 'example_papermill_operator'
execution_date = utcnow().replace(microsecond=0)
datetime_string = quote_plus(execution_date.isoformat())
- wrong_datetime_string = quote_plus(
- datetime(1990, 1, 1, 1, 1, 1).isoformat()
- )
+ wrong_datetime_string = quote_plus(datetime(1990, 1, 1, 1, 1, 1).isoformat())
# create DagRun
- trigger_dag(dag_id=dag_id,
- run_id='test_lineage_info_run',
- execution_date=execution_date)
+ trigger_dag(dag_id=dag_id, run_id='test_lineage_info_run', execution_date=execution_date)
# test correct execution
- response = self.client.get(
- url_template.format(dag_id, datetime_string)
- )
+ response = self.client.get(url_template.format(dag_id, datetime_string))
self.assert_deprecated(response)
self.assertEqual(200, response.status_code)
self.assertIn('task_ids', response.data.decode('utf-8'))
@@ -394,16 +341,12 @@ def test_lineage_info(self):
self.assertIn('error', response.data.decode('utf-8'))
# Test error for nonexistent dag run (wrong execution_date)
- response = self.client.get(
- url_template.format(dag_id, wrong_datetime_string)
- )
+ response = self.client.get(url_template.format(dag_id, wrong_datetime_string))
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for bad datetime format
- response = self.client.get(
- url_template.format(dag_id, 'not_a_datetime')
- )
+ response = self.client.get(url_template.format(dag_id, 'not_a_datetime'))
self.assertEqual(400, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
@@ -444,14 +387,12 @@ def test_get_pool(self):
)
self.assert_deprecated(response)
self.assertEqual(response.status_code, 200)
- self.assertEqual(json.loads(response.data.decode('utf-8')),
- self.pool.to_json())
+ self.assertEqual(json.loads(response.data.decode('utf-8')), self.pool.to_json())
def test_get_pool_non_existing(self):
response = self.client.get('/api/experimental/pools/foo')
self.assertEqual(response.status_code, 404)
- self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
- "Pool 'foo' doesn't exist")
+ self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], "Pool 'foo' doesn't exist")
def test_get_pools(self):
response = self.client.get('/api/experimental/pools')
@@ -465,11 +406,13 @@ def test_get_pools(self):
def test_create_pool(self):
response = self.client.post(
'/api/experimental/pools',
- data=json.dumps({
- 'name': 'foo',
- 'slots': 1,
- 'description': '',
- }),
+ data=json.dumps(
+ {
+ 'name': 'foo',
+ 'slots': 1,
+ 'description': '',
+ }
+ ),
content_type='application/json',
)
self.assert_deprecated(response)
@@ -484,11 +427,13 @@ def test_create_pool_with_bad_name(self):
for name in ('', ' '):
response = self.client.post(
'/api/experimental/pools',
- data=json.dumps({
- 'name': name,
- 'slots': 1,
- 'description': '',
- }),
+ data=json.dumps(
+ {
+ 'name': name,
+ 'slots': 1,
+ 'description': '',
+ }
+ ),
content_type='application/json',
)
self.assertEqual(response.status_code, 400)
@@ -504,8 +449,7 @@ def test_delete_pool(self):
)
self.assert_deprecated(response)
self.assertEqual(response.status_code, 200)
- self.assertEqual(json.loads(response.data.decode('utf-8')),
- self.pool.to_json())
+ self.assertEqual(json.loads(response.data.decode('utf-8')), self.pool.to_json())
self.assertEqual(self._get_pool_count(), self.TOTAL_POOL_COUNT - 1)
def test_delete_pool_non_existing(self):
@@ -513,8 +457,7 @@ def test_delete_pool_non_existing(self):
'/api/experimental/pools/foo',
)
self.assertEqual(response.status_code, 404)
- self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
- "Pool 'foo' doesn't exist")
+ self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], "Pool 'foo' doesn't exist")
def test_delete_default_pool(self):
clear_db_pools()
@@ -522,5 +465,4 @@ def test_delete_default_pool(self):
'/api/experimental/pools/default_pool',
)
self.assertEqual(response.status_code, 400)
- self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
- "default_pool cannot be deleted")
+ self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], "default_pool cannot be deleted")
diff --git a/tests/www/api/experimental/test_kerberos_endpoints.py b/tests/www/api/experimental/test_kerberos_endpoints.py
index 402e5cf059b73..8425d39eb393f 100644
--- a/tests/www/api/experimental/test_kerberos_endpoints.py
+++ b/tests/www/api/experimental/test_kerberos_endpoints.py
@@ -42,10 +42,12 @@ def setUpClass(cls):
for dag in dagbag.dags.values():
dag.sync_to_db()
- @conf_vars({
- ("api", "auth_backend"): "airflow.api.auth.backend.kerberos_auth",
- ("kerberos", "keytab"): KRB5_KTNAME,
- })
+ @conf_vars(
+ {
+ ("api", "auth_backend"): "airflow.api.auth.backend.kerberos_auth",
+ ("kerberos", "keytab"): KRB5_KTNAME,
+ }
+ )
def setUp(self):
self.app = application.create_app(testing=True)
@@ -59,7 +61,7 @@ def test_trigger_dag(self):
response = client.post(
url_template.format('example_bash_operator'),
data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())),
- content_type="application/json"
+ content_type="application/json",
)
self.assertEqual(401, response.status_code)
@@ -84,7 +86,7 @@ class Request:
url_template.format('example_bash_operator'),
data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())),
content_type="application/json",
- headers=response.request.headers
+ headers=response.request.headers,
)
self.assertEqual(200, response2.status_code)
@@ -94,7 +96,7 @@ def test_unauthorized(self):
response = client.post(
url_template.format('example_bash_operator'),
data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())),
- content_type="application/json"
+ content_type="application/json",
)
self.assertEqual(401, response.status_code)
diff --git a/tests/www/test_app.py b/tests/www/test_app.py
index 198a73ed1f825..5740e9a55ace9 100644
--- a/tests/www/test_app.py
+++ b/tests/www/test_app.py
@@ -32,16 +32,19 @@ class TestApp(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
from airflow import settings
+
settings.configure_orm()
- @conf_vars({
- ('webserver', 'enable_proxy_fix'): 'True',
- ('webserver', 'proxy_fix_x_for'): '1',
- ('webserver', 'proxy_fix_x_proto'): '1',
- ('webserver', 'proxy_fix_x_host'): '1',
- ('webserver', 'proxy_fix_x_port'): '1',
- ('webserver', 'proxy_fix_x_prefix'): '1'
- })
+ @conf_vars(
+ {
+ ('webserver', 'enable_proxy_fix'): 'True',
+ ('webserver', 'proxy_fix_x_for'): '1',
+ ('webserver', 'proxy_fix_x_proto'): '1',
+ ('webserver', 'proxy_fix_x_host'): '1',
+ ('webserver', 'proxy_fix_x_port'): '1',
+ ('webserver', 'proxy_fix_x_prefix'): '1',
+ }
+ )
@mock.patch("airflow.www.app.app", None)
def test_should_respect_proxy_fix(self):
app = application.cached_app(testing=True)
@@ -77,9 +80,11 @@ def debug_view():
self.assertEqual(b"success", response.get_data())
self.assertEqual(response.status_code, 200)
- @conf_vars({
- ('webserver', 'base_url'): 'http://localhost:8080/internal-client',
- })
+ @conf_vars(
+ {
+ ('webserver', 'base_url'): 'http://localhost:8080/internal-client',
+ }
+ )
@mock.patch("airflow.www.app.app", None)
def test_should_respect_base_url_ignore_proxy_headers(self):
app = application.cached_app(testing=True)
@@ -115,15 +120,17 @@ def debug_view():
self.assertEqual(b"success", response.get_data())
self.assertEqual(response.status_code, 200)
- @conf_vars({
- ('webserver', 'base_url'): 'http://localhost:8080/internal-client',
- ('webserver', 'enable_proxy_fix'): 'True',
- ('webserver', 'proxy_fix_x_for'): '1',
- ('webserver', 'proxy_fix_x_proto'): '1',
- ('webserver', 'proxy_fix_x_host'): '1',
- ('webserver', 'proxy_fix_x_port'): '1',
- ('webserver', 'proxy_fix_x_prefix'): '1'
- })
+ @conf_vars(
+ {
+ ('webserver', 'base_url'): 'http://localhost:8080/internal-client',
+ ('webserver', 'enable_proxy_fix'): 'True',
+ ('webserver', 'proxy_fix_x_for'): '1',
+ ('webserver', 'proxy_fix_x_proto'): '1',
+ ('webserver', 'proxy_fix_x_host'): '1',
+ ('webserver', 'proxy_fix_x_port'): '1',
+ ('webserver', 'proxy_fix_x_prefix'): '1',
+ }
+ )
@mock.patch("airflow.www.app.app", None)
def test_should_respect_base_url_when_proxy_fix_and_base_url_is_set_up_but_headers_missing(self):
app = application.cached_app(testing=True)
@@ -153,15 +160,17 @@ def debug_view():
self.assertEqual(b"success", response.get_data())
self.assertEqual(response.status_code, 200)
- @conf_vars({
- ('webserver', 'base_url'): 'http://localhost:8080/internal-client',
- ('webserver', 'enable_proxy_fix'): 'True',
- ('webserver', 'proxy_fix_x_for'): '1',
- ('webserver', 'proxy_fix_x_proto'): '1',
- ('webserver', 'proxy_fix_x_host'): '1',
- ('webserver', 'proxy_fix_x_port'): '1',
- ('webserver', 'proxy_fix_x_prefix'): '1'
- })
+ @conf_vars(
+ {
+ ('webserver', 'base_url'): 'http://localhost:8080/internal-client',
+ ('webserver', 'enable_proxy_fix'): 'True',
+ ('webserver', 'proxy_fix_x_for'): '1',
+ ('webserver', 'proxy_fix_x_proto'): '1',
+ ('webserver', 'proxy_fix_x_host'): '1',
+ ('webserver', 'proxy_fix_x_port'): '1',
+ ('webserver', 'proxy_fix_x_prefix'): '1',
+ }
+ )
@mock.patch("airflow.www.app.app", None)
def test_should_respect_base_url_and_proxy_when_proxy_fix_and_base_url_is_set_up(self):
app = application.cached_app(testing=True)
@@ -197,21 +206,18 @@ def debug_view():
self.assertEqual(b"success", response.get_data())
self.assertEqual(response.status_code, 200)
- @conf_vars({
- ('core', 'sql_alchemy_pool_enabled'): 'True',
- ('core', 'sql_alchemy_pool_size'): '3',
- ('core', 'sql_alchemy_max_overflow'): '5',
- ('core', 'sql_alchemy_pool_recycle'): '120',
- ('core', 'sql_alchemy_pool_pre_ping'): 'True',
- })
+ @conf_vars(
+ {
+ ('core', 'sql_alchemy_pool_enabled'): 'True',
+ ('core', 'sql_alchemy_pool_size'): '3',
+ ('core', 'sql_alchemy_max_overflow'): '5',
+ ('core', 'sql_alchemy_pool_recycle'): '120',
+ ('core', 'sql_alchemy_pool_pre_ping'): 'True',
+ }
+ )
@mock.patch("airflow.www.app.app", None)
@pytest.mark.backend("mysql", "postgres")
def test_should_set_sqlalchemy_engine_options(self):
app = application.cached_app(testing=True)
- engine_params = {
- 'pool_size': 3,
- 'pool_recycle': 120,
- 'pool_pre_ping': True,
- 'max_overflow': 5
- }
+ engine_params = {'pool_size': 3, 'pool_recycle': 120, 'pool_pre_ping': True, 'max_overflow': 5}
self.assertEqual(app.config['SQLALCHEMY_ENGINE_OPTIONS'], engine_params)
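The @conf_vars hunks above merely move the dictionary argument onto its own indented block; the decorator still receives the same mapping, so the tests behave identically. A minimal sketch with a stand-in decorator (conf_vars is Airflow's own test helper; the record_overrides name below is purely illustrative):

import functools


def record_overrides(overrides):
    """Illustrative stand-in for a decorator such as conf_vars that accepts one mapping."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        wrapper.overrides = overrides  # a real helper would apply these settings around the call
        return wrapper

    return decorator


@record_overrides({('webserver', 'enable_proxy_fix'): 'True'})
def old_layout():
    return "ok"


@record_overrides(
    {
        ('webserver', 'enable_proxy_fix'): 'True',
    }
)
def new_layout():
    return "ok"


# Both layouts pass the same mapping and decorate identically; only the formatting differs.
assert old_layout.overrides == new_layout.overrides
assert old_layout() == new_layout() == "ok"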
diff --git a/tests/www/test_security.py b/tests/www/test_security.py
index b44f67482da86..1bf460b8713c1 100644
--- a/tests/www/test_security.py
+++ b/tests/www/test_security.py
@@ -187,7 +187,14 @@ def test_get_all_permissions_views(self, mock_get_user_roles):
username = 'get_all_permissions_views'
with self.app.app_context():
- user = fab_utils.create_user(self.app, username, role_name, permissions=[(role_perm, role_vm),],)
+ user = fab_utils.create_user(
+ self.app,
+ username,
+ role_name,
+ permissions=[
+ (role_perm, role_vm),
+ ],
+ )
role = user.roles[0]
mock_get_user_roles.return_value = [role]
@@ -281,9 +288,7 @@ def test_all_dag_access_doesnt_give_non_dag_access(self):
)
self.assertFalse(
self.security_manager.has_access(
- permissions.ACTION_CAN_READ,
- permissions.RESOURCE_TASK_INSTANCE,
- user
+ permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE, user
)
)
@@ -293,7 +298,11 @@ def test_access_control_with_invalid_permission(self):
'can_eat_pudding', # clearly not a real permission
]
username = "LaUser"
- user = fab_utils.create_user(self.app, username=username, role_name='team-a',)
+ user = fab_utils.create_user(
+ self.app,
+ username=username,
+ role_name='team-a',
+ )
for permission in invalid_permissions:
self.expect_user_is_in_role(user, rolename='team-a')
with self.assertRaises(AirflowException) as context:
@@ -306,7 +315,12 @@ def test_access_control_is_set_on_init(self):
username = 'access_control_is_set_on_init'
role_name = 'team-a'
with self.app.app_context():
- user = fab_utils.create_user(self.app, username, role_name, permissions=[],)
+ user = fab_utils.create_user(
+ self.app,
+ username,
+ role_name,
+ permissions=[],
+ )
self.expect_user_is_in_role(user, rolename='team-a')
self.security_manager.sync_perm_for_dag(
'access_control_test',
@@ -329,7 +343,12 @@ def test_access_control_stale_perms_are_revoked(self):
username = 'access_control_stale_perms_are_revoked'
role_name = 'team-a'
with self.app.app_context():
- user = fab_utils.create_user(self.app, username, role_name, permissions=[],)
+ user = fab_utils.create_user(
+ self.app,
+ username,
+ role_name,
+ permissions=[],
+ )
self.expect_user_is_in_role(user, rolename='team-a')
self.security_manager.sync_perm_for_dag(
'access_control_test', access_control={'team-a': READ_WRITE}
diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py
index 51d73fba286b0..efdb249e8cbd2 100644
--- a/tests/www/test_utils.py
+++ b/tests/www/test_utils.py
@@ -61,7 +61,7 @@ def test_sensitive_variable_fields_should_be_hidden(
[
(None, 'TRELLO_API', False),
('token', 'TRELLO_KEY', False),
- ('token, mysecretword', 'TRELLO_KEY', False)
+ ('token, mysecretword', 'TRELLO_KEY', False),
],
)
def test_normal_variable_fields_should_not_be_hidden(
@@ -70,22 +70,15 @@ def test_normal_variable_fields_should_not_be_hidden(
with conf_vars({('admin', 'sensitive_variable_fields'): str(sensitive_variable_fields)}):
self.assertEqual(expected_result, utils.should_hide_value_for_key(key))
- def check_generate_pages_html(self, current_page, total_pages,
- window=7, check_middle=False):
+ def check_generate_pages_html(self, current_page, total_pages, window=7, check_middle=False):
extra_links = 4 # first, prev, next, last
search = "'>\"/>
"
- html_str = utils.generate_pages(current_page, total_pages,
- search=search)
+ html_str = utils.generate_pages(current_page, total_pages, search=search)
- self.assertNotIn(search, html_str,
- "The raw search string shouldn't appear in the output")
- self.assertIn('search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E',
- html_str)
+ self.assertNotIn(search, html_str, "The raw search string shouldn't appear in the output")
+ self.assertIn('search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E', html_str)
- self.assertTrue(
- callable(html_str.__html__),
- "Should return something that is HTML-escaping aware"
- )
+ self.assertTrue(callable(html_str.__html__), "Should return something that is HTML-escaping aware")
dom = BeautifulSoup(html_str, 'html.parser')
self.assertIsNotNone(dom)
@@ -112,25 +105,20 @@ def check_generate_pages_html(self, current_page, total_pages,
self.assertListEqual(query['search'], [search])
def test_generate_pager_current_start(self):
- self.check_generate_pages_html(current_page=0,
- total_pages=6)
+ self.check_generate_pages_html(current_page=0, total_pages=6)
def test_generate_pager_current_middle(self):
- self.check_generate_pages_html(current_page=10,
- total_pages=20,
- check_middle=True)
+ self.check_generate_pages_html(current_page=10, total_pages=20, check_middle=True)
def test_generate_pager_current_end(self):
- self.check_generate_pages_html(current_page=38,
- total_pages=39)
+ self.check_generate_pages_html(current_page=38, total_pages=39)
def test_params_no_values(self):
"""Should return an empty string if no params are passed"""
self.assertEqual('', utils.get_params())
def test_params_search(self):
- self.assertEqual('search=bash_',
- utils.get_params(search='bash_'))
+ self.assertEqual('search=bash_', utils.get_params(search='bash_'))
def test_params_none_and_zero(self):
query_str = utils.get_params(a=0, b=None, c='true')
@@ -140,16 +128,13 @@ def test_params_none_and_zero(self):
def test_params_all(self):
query = utils.get_params(status='active', page=3, search='bash_')
- self.assertEqual(
- {'page': ['3'],
- 'search': ['bash_'],
- 'status': ['active']},
- parse_qs(query)
- )
+ self.assertEqual({'page': ['3'], 'search': ['bash_'], 'status': ['active']}, parse_qs(query))
def test_params_escape(self):
- self.assertEqual('search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E',
-                         utils.get_params(search="'>\"/><img src=x onerror=alert(1)>"))
+ self.assertEqual(
+ 'search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E',
+            utils.get_params(search="'>\"/><img src=x onerror=alert(1)>"),
+ )
def test_state_token(self):
# It's shouldn't possible to set these odd values anymore, but lets
@@ -168,12 +153,13 @@ def test_state_token(self):
def test_task_instance_link(self):
from airflow.www.app import cached_app
+
with cached_app(testing=True).test_request_context():
-            html = str(utils.task_instance_link({
-                'dag_id': '<a&1>',
-                'task_id': '<b2>',
-                'execution_date': datetime.now()
-            }))
+            html = str(
+                utils.task_instance_link(
+                    {'dag_id': '<a&1>', 'task_id': '<b2>', 'execution_date': datetime.now()}
+                )
+            )
self.assertIn('%3Ca%261%3E', html)
self.assertIn('%3Cb2%3E', html)
@@ -182,23 +168,20 @@ def test_task_instance_link(self):
def test_dag_link(self):
from airflow.www.app import cached_app
+
with cached_app(testing=True).test_request_context():
-            html = str(utils.dag_link({
-                'dag_id': '<a&1>',
-                'execution_date': datetime.now()
-            }))
+            html = str(utils.dag_link({'dag_id': '<a&1>', 'execution_date': datetime.now()}))
self.assertIn('%3Ca%261%3E', html)
        self.assertNotIn('<a&1>', html)
def test_dag_run_link(self):
from airflow.www.app import cached_app
+
with cached_app(testing=True).test_request_context():
-            html = str(utils.dag_run_link({
-                'dag_id': '<a&1>',
-                'run_id': '<b2>',
-                'execution_date': datetime.now()
-            }))
+            html = str(
+                utils.dag_run_link({'dag_id': '<a&1>', 'run_id': '<b2>', 'execution_date': datetime.now()})
+            )
self.assertIn('%3Ca%261%3E', html)
self.assertIn('%3Cb2%3E', html)
@@ -207,13 +190,13 @@ def test_dag_run_link(self):
class TestAttrRenderer(unittest.TestCase):
-
def setUp(self):
self.attr_renderer = utils.get_attr_renderer()
def test_python_callable(self):
def example_callable(unused_self):
print("example")
+
rendered = self.attr_renderer["python_callable"](example_callable)
self.assertIn('"example"', rendered)
@@ -233,7 +216,6 @@ def test_markdown_none(self):
class TestWrappedMarkdown(unittest.TestCase):
-
def test_wrapped_markdown_with_docstring_curly_braces(self):
rendered = wrapped_markdown("{braces}", css_class="a_class")
        self.assertEqual('<div class="a_class" ><p>{braces}</p></div>', rendered)
@@ -242,4 +224,6 @@ def test_wrapped_markdown_with_some_markdown(self):
rendered = wrapped_markdown("*italic*\n**bold**\n", css_class="a_class")
self.assertEqual(
            '''<div class="a_class" ><p><em>italic</em>
-<strong>bold</strong></p></div>''', rendered)
+<strong>bold</strong></p></div>''',
+            rendered,
+        )
diff --git a/tests/www/test_validators.py b/tests/www/test_validators.py
index e63f6b042e529..df36b61f1d294 100644
--- a/tests/www/test_validators.py
+++ b/tests/www/test_validators.py
@@ -23,7 +23,6 @@
class TestGreaterEqualThan(unittest.TestCase):
-
def setUp(self):
super().setUp()
self.form_field_mock = mock.MagicMock(data='2017-05-06')
@@ -39,8 +38,7 @@ def _validate(self, fieldname=None, message=None):
if fieldname is None:
fieldname = 'other_field'
- validator = validators.GreaterEqualThan(fieldname=fieldname,
- message=message)
+ validator = validators.GreaterEqualThan(fieldname=fieldname, message=message)
return validator(self.form_mock, self.form_field_mock)
@@ -92,7 +90,6 @@ def test_validation_raises_custom_message(self):
class TestValidJson(unittest.TestCase):
-
def setUp(self):
super().setUp()
self.form_field_mock = mock.MagicMock(data='{"valid":"True"}')
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index f56d79b7bd454..e17180579447c 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -175,6 +175,7 @@ def capture_templates(self) -> Generator[List[TemplateWithContext], None, None]:
def record(sender, template, context, **extra): # pylint: disable=unused-argument
recorded.append(TemplateWithContext(template, context))
+
template_rendered.connect(record, self.app) # type: ignore
try:
yield recorded
@@ -218,8 +219,7 @@ def create_user_and_login(self, username, role_name, perms):
role_name=role_name,
permissions=perms,
)
- self.login(username=username,
- password=username)
+ self.login(username=username, password=username)
class TestConnectionModelView(TestBase):
@@ -231,7 +231,7 @@ def setUp(self):
'host': 'localhost',
'port': 8080,
'username': 'root',
- 'password': 'admin'
+ 'password': 'admin',
}
def tearDown(self):
@@ -239,20 +239,14 @@ def tearDown(self):
super().tearDown()
def test_create_connection(self):
- resp = self.client.post('/connection/add',
- data=self.connection,
- follow_redirects=True)
+ resp = self.client.post('/connection/add', data=self.connection, follow_redirects=True)
self.check_content_in_response('Added Row', resp)
class TestVariableModelView(TestBase):
def setUp(self):
super().setUp()
- self.variable = {
- 'key': 'test_key',
- 'val': 'text_val',
- 'is_encrypted': True
- }
+ self.variable = {'key': 'test_key', 'val': 'text_val', 'is_encrypted': True}
def tearDown(self):
self.clear_table(models.Variable)
@@ -265,18 +259,17 @@ def test_can_handle_error_on_decrypt(self):
# update the variable with a wrong value, given that is encrypted
Var = models.Variable # pylint: disable=invalid-name
- (self.session.query(Var)
+ (
+ self.session.query(Var)
.filter(Var.key == self.variable['key'])
- .update({
- 'val': 'failed_value_not_encrypted'
- }, synchronize_session=False))
+ .update({'val': 'failed_value_not_encrypted'}, synchronize_session=False)
+ )
self.session.commit()
# retrieve Variables page, should not fail and contain the Invalid
# label for the variable
resp = self.client.get('/variable/list', follow_redirects=True)
- self.check_content_in_response(
- 'Invalid', resp)
+ self.check_content_in_response('Invalid', resp)
def test_xss_prevention(self):
xss = "/variable/list/
"
@@ -286,12 +279,10 @@ def test_xss_prevention(self):
follow_redirects=True,
)
self.assertEqual(resp.status_code, 404)
-        self.assertNotIn("<img src='' onerror='alert(1);'>",
-                         resp.data.decode("utf-8"))
+        self.assertNotIn("<img src='' onerror='alert(1);'>", resp.data.decode("utf-8"))
def test_import_variables_no_file(self):
- resp = self.client.post('/variable/varimport',
- follow_redirects=True)
+ resp = self.client.post('/variable/varimport', follow_redirects=True)
self.check_content_in_response('Missing file or syntax error.', resp)
def test_import_variables_failed(self):
@@ -308,16 +299,17 @@ def test_import_variables_failed(self):
# python 2.7
bytes_content = io.BytesIO(bytes(content))
- resp = self.client.post('/variable/varimport',
- data={'file': (bytes_content, 'test.json')},
- follow_redirects=True)
+ resp = self.client.post(
+ '/variable/varimport', data={'file': (bytes_content, 'test.json')}, follow_redirects=True
+ )
self.check_content_in_response('1 variable(s) failed to be updated.', resp)
def test_import_variables_success(self):
self.assertEqual(self.session.query(models.Variable).count(), 0)
- content = ('{"str_key": "str_value", "int_key": 60,'
- '"list_key": [1, 2], "dict_key": {"k_a": 2, "k_b": 3}}')
+ content = (
+ '{"str_key": "str_value", "int_key": 60, "list_key": [1, 2], "dict_key": {"k_a": 2, "k_b": 3}}'
+ )
try:
# python 3+
bytes_content = io.BytesIO(bytes(content, encoding='utf-8'))
@@ -325,9 +317,9 @@ def test_import_variables_success(self):
# python 2.7
bytes_content = io.BytesIO(bytes(content))
- resp = self.client.post('/variable/varimport',
- data={'file': (bytes_content, 'test.json')},
- follow_redirects=True)
+ resp = self.client.post(
+ '/variable/varimport', data={'file': (bytes_content, 'test.json')}, follow_redirects=True
+ )
self.check_content_in_response('4 variable(s) successfully updated.', resp)
@@ -494,19 +486,22 @@ def prepare_dagruns(self):
run_type=DagRunType.SCHEDULED,
execution_date=self.EXAMPLE_DAG_DEFAULT_DATE,
start_date=timezone.utcnow(),
- state=State.RUNNING)
+ state=State.RUNNING,
+ )
self.sub_dagrun = self.sub_dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=self.EXAMPLE_DAG_DEFAULT_DATE,
start_date=timezone.utcnow(),
- state=State.RUNNING)
+ state=State.RUNNING,
+ )
self.xcom_dagrun = self.xcom_dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=self.EXAMPLE_DAG_DEFAULT_DATE,
start_date=timezone.utcnow(),
- state=State.RUNNING)
+ state=State.RUNNING,
+ )
def test_index(self):
with assert_queries_count(42):
@@ -527,55 +522,67 @@ def test_health(self):
# case-1: healthy scheduler status
last_scheduler_heartbeat_for_testing_1 = timezone.utcnow()
- self.session.add(BaseJob(job_type='SchedulerJob',
- state='running',
- latest_heartbeat=last_scheduler_heartbeat_for_testing_1))
+ self.session.add(
+ BaseJob(
+ job_type='SchedulerJob',
+ state='running',
+ latest_heartbeat=last_scheduler_heartbeat_for_testing_1,
+ )
+ )
self.session.commit()
resp_json = json.loads(self.client.get('health', follow_redirects=True).data.decode('utf-8'))
self.assertEqual('healthy', resp_json['metadatabase']['status'])
self.assertEqual('healthy', resp_json['scheduler']['status'])
- self.assertEqual(last_scheduler_heartbeat_for_testing_1.isoformat(),
- resp_json['scheduler']['latest_scheduler_heartbeat'])
-
- self.session.query(BaseJob).\
- filter(BaseJob.job_type == 'SchedulerJob',
- BaseJob.state == 'running',
- BaseJob.latest_heartbeat == last_scheduler_heartbeat_for_testing_1).\
- delete()
+ self.assertEqual(
+ last_scheduler_heartbeat_for_testing_1.isoformat(),
+ resp_json['scheduler']['latest_scheduler_heartbeat'],
+ )
+
+ self.session.query(BaseJob).filter(
+ BaseJob.job_type == 'SchedulerJob',
+ BaseJob.state == 'running',
+ BaseJob.latest_heartbeat == last_scheduler_heartbeat_for_testing_1,
+ ).delete()
self.session.commit()
# case-2: unhealthy scheduler status - scenario 1 (SchedulerJob is running too slowly)
last_scheduler_heartbeat_for_testing_2 = timezone.utcnow() - timedelta(minutes=1)
- (self.session
- .query(BaseJob)
- .filter(BaseJob.job_type == 'SchedulerJob')
- .update({'latest_heartbeat': last_scheduler_heartbeat_for_testing_2 - timedelta(seconds=1)}))
- self.session.add(BaseJob(job_type='SchedulerJob',
- state='running',
- latest_heartbeat=last_scheduler_heartbeat_for_testing_2))
+ (
+ self.session.query(BaseJob)
+ .filter(BaseJob.job_type == 'SchedulerJob')
+ .update({'latest_heartbeat': last_scheduler_heartbeat_for_testing_2 - timedelta(seconds=1)})
+ )
+ self.session.add(
+ BaseJob(
+ job_type='SchedulerJob',
+ state='running',
+ latest_heartbeat=last_scheduler_heartbeat_for_testing_2,
+ )
+ )
self.session.commit()
resp_json = json.loads(self.client.get('health', follow_redirects=True).data.decode('utf-8'))
self.assertEqual('healthy', resp_json['metadatabase']['status'])
self.assertEqual('unhealthy', resp_json['scheduler']['status'])
- self.assertEqual(last_scheduler_heartbeat_for_testing_2.isoformat(),
- resp_json['scheduler']['latest_scheduler_heartbeat'])
-
- self.session.query(BaseJob).\
- filter(BaseJob.job_type == 'SchedulerJob',
- BaseJob.state == 'running',
- BaseJob.latest_heartbeat == last_scheduler_heartbeat_for_testing_2).\
- delete()
+ self.assertEqual(
+ last_scheduler_heartbeat_for_testing_2.isoformat(),
+ resp_json['scheduler']['latest_scheduler_heartbeat'],
+ )
+
+ self.session.query(BaseJob).filter(
+ BaseJob.job_type == 'SchedulerJob',
+ BaseJob.state == 'running',
+ BaseJob.latest_heartbeat == last_scheduler_heartbeat_for_testing_2,
+ ).delete()
self.session.commit()
# case-3: unhealthy scheduler status - scenario 2 (no running SchedulerJob)
- self.session.query(BaseJob).\
- filter(BaseJob.job_type == 'SchedulerJob',
- BaseJob.state == 'running').\
- delete()
+ self.session.query(BaseJob).filter(
+ BaseJob.job_type == 'SchedulerJob', BaseJob.state == 'running'
+ ).delete()
self.session.commit()
resp_json = json.loads(self.client.get('health', follow_redirects=True).data.decode('utf-8'))
@@ -588,13 +595,15 @@ def test_home(self):
with self.capture_templates() as templates:
resp = self.client.get('home', follow_redirects=True)
self.check_content_in_response('DAGs', resp)
- val_state_color_mapping = 'const STATE_COLOR = {"failed": "red", ' \
- '"null": "lightblue", "queued": "gray", ' \
- '"removed": "lightgrey", "running": "lime", ' \
- '"scheduled": "tan", "sensing": "lightseagreen", ' \
- '"shutdown": "blue", "skipped": "pink", ' \
- '"success": "green", "up_for_reschedule": "turquoise", ' \
- '"up_for_retry": "gold", "upstream_failed": "orange"};'
+ val_state_color_mapping = (
+ 'const STATE_COLOR = {"failed": "red", '
+ '"null": "lightblue", "queued": "gray", '
+ '"removed": "lightgrey", "running": "lime", '
+ '"scheduled": "tan", "sensing": "lightseagreen", '
+ '"shutdown": "blue", "skipped": "pink", '
+ '"success": "green", "up_for_reschedule": "turquoise", '
+ '"up_for_retry": "gold", "upstream_failed": "orange"};'
+ )
self.check_content_in_response(val_state_color_mapping, resp)
self.assertEqual(len(templates), 1)
@@ -629,6 +638,7 @@ def test_permissionsviews_list(self):
def test_home_filter_tags(self):
from airflow.www.views import FILTER_TAGS_COOKIE
+
with self.client:
self.client.get('home?tags=example&tags=data', follow_redirects=True)
self.assertEqual('example,data', flask_session[FILTER_TAGS_COOKIE])
@@ -638,6 +648,7 @@ def test_home_filter_tags(self):
def test_home_status_filter_cookie(self):
from airflow.www.views import FILTER_STATUS_COOKIE
+
with self.client:
self.client.get('home', follow_redirects=True)
self.assertEqual('all', flask_session[FILTER_STATUS_COOKIE])
@@ -652,20 +663,23 @@ def test_home_status_filter_cookie(self):
self.assertEqual('all', flask_session[FILTER_STATUS_COOKIE])
def test_task(self):
- url = ('task?task_id=runme_0&dag_id=example_bash_operator&execution_date={}'
- .format(self.percent_encode(self.EXAMPLE_DAG_DEFAULT_DATE)))
+ url = 'task?task_id=runme_0&dag_id=example_bash_operator&execution_date={}'.format(
+ self.percent_encode(self.EXAMPLE_DAG_DEFAULT_DATE)
+ )
resp = self.client.get(url, follow_redirects=True)
self.check_content_in_response('Task Instance Details', resp)
def test_xcom(self):
- url = ('xcom?task_id=runme_0&dag_id=example_bash_operator&execution_date={}'
- .format(self.percent_encode(self.EXAMPLE_DAG_DEFAULT_DATE)))
+ url = 'xcom?task_id=runme_0&dag_id=example_bash_operator&execution_date={}'.format(
+ self.percent_encode(self.EXAMPLE_DAG_DEFAULT_DATE)
+ )
resp = self.client.get(url, follow_redirects=True)
self.check_content_in_response('XCom', resp)
def test_rendered(self):
- url = ('rendered?task_id=runme_0&dag_id=example_bash_operator&execution_date={}'
- .format(self.percent_encode(self.EXAMPLE_DAG_DEFAULT_DATE)))
+ url = 'rendered?task_id=runme_0&dag_id=example_bash_operator&execution_date={}'.format(
+ self.percent_encode(self.EXAMPLE_DAG_DEFAULT_DATE)
+ )
resp = self.client.get(url, follow_redirects=True)
self.check_content_in_response('Rendered Template', resp)
@@ -681,8 +695,7 @@ def test_dag_stats(self):
def test_task_stats(self):
resp = self.client.post('task_stats', follow_redirects=True)
self.assertEqual(resp.status_code, 200)
- self.assertEqual(set(list(resp.json.items())[0][1][0].keys()),
- {'state', 'count'})
+ self.assertEqual(set(list(resp.json.items())[0][1][0].keys()), {'state', 'count'})
@conf_vars({("webserver", "show_recent_stats_for_completed_runs"): "False"})
def test_task_stats_only_noncompleted(self):
@@ -704,12 +717,14 @@ def test_view_uses_existing_dagbag(self, endpoint):
resp = self.client.get(url, follow_redirects=True)
self.check_content_in_response('example_bash_operator', resp)
- @parameterized.expand([
- ("hello\nworld", r'\"conf\":{\"abc\":\"hello\\nworld\"}'),
- ("hello'world", r'\"conf\":{\"abc\":\"hello\\u0027world\"}'),
- ("