From 4a5ce39ada36cead975429f3892aa532a7430df7 Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Wed, 7 Oct 2020 20:53:24 +0100
Subject: [PATCH 1/2] Enable Black across entire Repo

---
 .pre-commit-config.yaml | 5 ++---
 setup.cfg | 6 +-----
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 43564d067e24c..be8573802024b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -159,8 +159,7 @@ repos:
     rev: 20.8b1
     hooks:
       - id: black
-        files: api_connexion/.*\.py|.*providers.*\.py|^chart/tests/.*\.py
-        exclude: .*kubernetes_pod\.py|.*google/common/hooks/base_google\.py$
+        exclude: .*kubernetes_pod\.py|.*google/common/hooks/base_google\.py$|^airflow/configuration.py$
         args: [--config=./pyproject.toml]
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v3.3.0
@@ -203,7 +202,7 @@ repos:
         name: Run isort to sort imports
         types: [python]
         # To keep consistent with the global isort skip config defined in setup.cfg
-        exclude: ^build/.*$|^.tox/.*$|^venv/.*$|.*api_connexion/.*\.py|.*providers.*\.py
+        exclude: ^build/.*$|^.tox/.*$|^venv/.*$
   - repo: https://github.com/pycqa/pydocstyle
     rev: 5.1.1
     hooks:
diff --git a/setup.cfg b/setup.cfg
index 9cb0fa9ddb617..3d07da425bb55 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -70,11 +70,7 @@ ignore_errors = True
 line_length=110
 combine_as_imports = true
 default_section = THIRDPARTY
-include_trailing_comma = true
 known_first_party=airflow,tests
-multi_line_output=5
 # Need to be consistent with the exclude config defined in pre-commit-config.yaml
 skip=build,.tox,venv
-# ToDo: Enable the below before Airflow 2.0
-# profile = "black"
-skip_glob=*/api_connexion/**/*.py,*/providers/**/*.py,provider_packages/**,chart/tests/*.py
+profile = black

From 9bdaf18f7f4c222603107ee73c198a1c87169bab Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Wed, 7 Oct 2020 20:59:21 +0100
Subject: [PATCH 2/2] Black Formatting

---
 airflow/__init__.py | 3 +
 airflow/api/__init__.py | 5 +-
 airflow/api/auth/backend/basic_auth.py | 5 +-
 airflow/api/auth/backend/default.py | 1 +
 airflow/api/auth/backend/kerberos_auth.py | 5 +-
 airflow/api/client/__init__.py | 2 +-
 airflow/api/client/json_client.py | 30 +-
 airflow/api/client/local_client.py | 7 +-
 airflow/api/common/experimental/__init__.py | 8 +-
 airflow/api/common/experimental/delete_dag.py | 11 +-
 .../api/common/experimental/get_dag_runs.py | 22 +-
 .../api/common/experimental/get_lineage.py | 10 +-
 .../common/experimental/get_task_instance.py | 3 +-
 airflow/api/common/experimental/mark_tasks.py | 73 +-
 .../api/common/experimental/trigger_dag.py | 19 +-
 .../endpoints/dag_run_endpoint.py | 2 +-
 .../endpoints/dag_source_endpoint.py | 1 -
 .../endpoints/task_instance_endpoint.py | 20 +-
 .../api_connexion/schemas/sla_miss_schema.py | 1 +
 .../schemas/task_instance_schema.py | 6 +-
 airflow/cli/cli_parser.py | 943 +++++------
 airflow/cli/commands/cheat_sheet_command.py | 1 +
 airflow/cli/commands/config_command.py | 4 +-
 airflow/cli/commands/connection_command.py | 75 +-
 airflow/cli/commands/dag_command.py | 102 +-
 airflow/cli/commands/db_command.py | 10 +-
 airflow/cli/commands/info_command.py | 10 +-
 airflow/cli/commands/legacy_commands.py | 2 +-
 airflow/cli/commands/pool_command.py | 13 +-
 airflow/cli/commands/role_command.py | 4 +-
 .../cli/commands/rotate_fernet_key_command.py | 3 +-
 airflow/cli/commands/scheduler_command.py | 11 +-
 airflow/cli/commands/sync_perm_command.py | 4 +-
 airflow/cli/commands/task_command.py | 82 +-
 airflow/cli/commands/user_command.py | 63 +-
airflow/cli/commands/variable_command.py | 11 +- airflow/cli/commands/webserver_command.py | 106 +- .../airflow_local_settings.py | 43 +- airflow/config_templates/default_celery.py | 49 +- airflow/configuration.py | 1 + airflow/contrib/__init__.py | 4 +- airflow/contrib/hooks/aws_dynamodb_hook.py | 3 +- .../contrib/hooks/aws_glue_catalog_hook.py | 3 +- airflow/contrib/hooks/aws_hook.py | 6 +- airflow/contrib/hooks/aws_logs_hook.py | 3 +- .../hooks/azure_container_instance_hook.py | 3 +- .../hooks/azure_container_registry_hook.py | 3 +- .../hooks/azure_container_volume_hook.py | 3 +- airflow/contrib/hooks/azure_cosmos_hook.py | 3 +- airflow/contrib/hooks/azure_data_lake_hook.py | 3 +- airflow/contrib/hooks/azure_fileshare_hook.py | 3 +- airflow/contrib/hooks/bigquery_hook.py | 9 +- airflow/contrib/hooks/cassandra_hook.py | 3 +- airflow/contrib/hooks/cloudant_hook.py | 3 +- airflow/contrib/hooks/databricks_hook.py | 17 +- airflow/contrib/hooks/datadog_hook.py | 3 +- airflow/contrib/hooks/datastore_hook.py | 3 +- airflow/contrib/hooks/dingding_hook.py | 3 +- airflow/contrib/hooks/discord_webhook_hook.py | 3 +- airflow/contrib/hooks/emr_hook.py | 3 +- airflow/contrib/hooks/fs_hook.py | 3 +- airflow/contrib/hooks/ftp_hook.py | 3 +- airflow/contrib/hooks/gcp_api_base_hook.py | 6 +- airflow/contrib/hooks/gcp_bigtable_hook.py | 3 +- airflow/contrib/hooks/gcp_cloud_build_hook.py | 3 +- airflow/contrib/hooks/gcp_compute_hook.py | 6 +- airflow/contrib/hooks/gcp_container_hook.py | 6 +- airflow/contrib/hooks/gcp_dataflow_hook.py | 6 +- airflow/contrib/hooks/gcp_dataproc_hook.py | 6 +- airflow/contrib/hooks/gcp_dlp_hook.py | 3 +- airflow/contrib/hooks/gcp_function_hook.py | 6 +- airflow/contrib/hooks/gcp_kms_hook.py | 6 +- airflow/contrib/hooks/gcp_mlengine_hook.py | 3 +- .../hooks/gcp_natural_language_hook.py | 3 +- airflow/contrib/hooks/gcp_pubsub_hook.py | 3 +- .../contrib/hooks/gcp_speech_to_text_hook.py | 6 +- airflow/contrib/hooks/gcp_tasks_hook.py | 3 +- .../contrib/hooks/gcp_text_to_speech_hook.py | 6 +- airflow/contrib/hooks/gcp_transfer_hook.py | 6 +- airflow/contrib/hooks/gcp_translate_hook.py | 3 +- .../hooks/gcp_video_intelligence_hook.py | 3 +- airflow/contrib/hooks/gcp_vision_hook.py | 3 +- airflow/contrib/hooks/gcs_hook.py | 6 +- airflow/contrib/hooks/gdrive_hook.py | 3 +- airflow/contrib/hooks/grpc_hook.py | 3 +- airflow/contrib/hooks/imap_hook.py | 3 +- airflow/contrib/hooks/jenkins_hook.py | 3 +- airflow/contrib/hooks/mongo_hook.py | 3 +- airflow/contrib/hooks/openfaas_hook.py | 3 +- airflow/contrib/hooks/opsgenie_alert_hook.py | 3 +- airflow/contrib/hooks/pagerduty_hook.py | 3 +- airflow/contrib/hooks/pinot_hook.py | 3 +- airflow/contrib/hooks/qubole_check_hook.py | 3 +- airflow/contrib/hooks/qubole_hook.py | 3 +- airflow/contrib/hooks/redis_hook.py | 3 +- airflow/contrib/hooks/sagemaker_hook.py | 9 +- airflow/contrib/hooks/salesforce_hook.py | 3 +- airflow/contrib/hooks/segment_hook.py | 3 +- airflow/contrib/hooks/slack_webhook_hook.py | 3 +- airflow/contrib/hooks/spark_jdbc_hook.py | 3 +- airflow/contrib/hooks/spark_sql_hook.py | 3 +- airflow/contrib/hooks/spark_submit_hook.py | 3 +- airflow/contrib/hooks/sqoop_hook.py | 3 +- airflow/contrib/hooks/ssh_hook.py | 3 +- airflow/contrib/hooks/vertica_hook.py | 3 +- airflow/contrib/hooks/wasb_hook.py | 3 +- airflow/contrib/hooks/winrm_hook.py | 3 +- .../contrib/operators/adls_list_operator.py | 3 +- airflow/contrib/operators/adls_to_gcs.py | 6 +- .../azure_container_instances_operator.py | 3 +- .../operators/azure_cosmos_operator.py | 3 
+- .../operators/bigquery_check_operator.py | 7 +- .../contrib/operators/bigquery_get_data.py | 3 +- .../contrib/operators/bigquery_operator.py | 12 +- .../contrib/operators/bigquery_to_bigquery.py | 3 +- airflow/contrib/operators/bigquery_to_gcs.py | 6 +- .../operators/bigquery_to_mysql_operator.py | 3 +- airflow/contrib/operators/cassandra_to_gcs.py | 6 +- .../contrib/operators/databricks_operator.py | 6 +- .../contrib/operators/dataflow_operator.py | 16 +- .../contrib/operators/dataproc_operator.py | 53 +- .../operators/datastore_export_operator.py | 6 +- .../operators/datastore_import_operator.py | 6 +- .../contrib/operators/dingding_operator.py | 3 +- .../operators/discord_webhook_operator.py | 3 +- .../operators/docker_swarm_operator.py | 3 +- airflow/contrib/operators/druid_operator.py | 3 +- airflow/contrib/operators/dynamodb_to_s3.py | 3 +- airflow/contrib/operators/ecs_operator.py | 3 +- .../operators/emr_add_steps_operator.py | 3 +- .../operators/emr_create_job_flow_operator.py | 3 +- .../emr_terminate_job_flow_operator.py | 3 +- airflow/contrib/operators/file_to_gcs.py | 6 +- airflow/contrib/operators/file_to_wasb.py | 3 +- .../operators/gcp_bigtable_operator.py | 28 +- .../operators/gcp_cloud_build_operator.py | 3 +- .../contrib/operators/gcp_compute_operator.py | 30 +- .../operators/gcp_container_operator.py | 16 +- airflow/contrib/operators/gcp_dlp_operator.py | 58 +- .../operators/gcp_function_operator.py | 12 +- .../gcp_natural_language_operator.py | 21 +- .../contrib/operators/gcp_spanner_operator.py | 33 +- .../operators/gcp_speech_to_text_operator.py | 6 +- airflow/contrib/operators/gcp_sql_operator.py | 15 +- .../contrib/operators/gcp_tasks_operator.py | 19 +- .../operators/gcp_text_to_speech_operator.py | 6 +- .../operators/gcp_transfer_operator.py | 48 +- .../operators/gcp_translate_operator.py | 3 +- .../gcp_translate_speech_operator.py | 6 +- .../gcp_video_intelligence_operator.py | 6 +- .../contrib/operators/gcp_vision_operator.py | 59 +- airflow/contrib/operators/gcs_acl_operator.py | 12 +- .../contrib/operators/gcs_delete_operator.py | 6 +- .../operators/gcs_download_operator.py | 6 +- .../contrib/operators/gcs_list_operator.py | 6 +- airflow/contrib/operators/gcs_operator.py | 6 +- airflow/contrib/operators/gcs_to_bq.py | 6 +- airflow/contrib/operators/gcs_to_gcs.py | 6 +- .../operators/gcs_to_gcs_transfer_operator.py | 3 +- .../operators/gcs_to_gdrive_operator.py | 6 +- airflow/contrib/operators/gcs_to_s3.py | 6 +- airflow/contrib/operators/grpc_operator.py | 3 +- airflow/contrib/operators/hive_to_dynamodb.py | 3 +- .../imap_attachment_to_s3_operator.py | 3 +- .../operators/jenkins_job_trigger_operator.py | 3 +- airflow/contrib/operators/jira_operator.py | 3 +- .../operators/kubernetes_pod_operator.py | 3 +- .../contrib/operators/mlengine_operator.py | 19 +- airflow/contrib/operators/mongo_to_s3.py | 3 +- airflow/contrib/operators/mssql_to_gcs.py | 6 +- airflow/contrib/operators/mysql_to_gcs.py | 6 +- .../operators/opsgenie_alert_operator.py | 3 +- .../oracle_to_azure_data_lake_transfer.py | 3 +- .../operators/oracle_to_oracle_transfer.py | 6 +- .../operators/postgres_to_gcs_operator.py | 6 +- airflow/contrib/operators/pubsub_operator.py | 7 +- .../operators/qubole_check_operator.py | 6 +- airflow/contrib/operators/qubole_operator.py | 3 +- .../operators/redis_publish_operator.py | 3 +- .../operators/s3_copy_object_operator.py | 3 +- .../operators/s3_delete_objects_operator.py | 3 +- airflow/contrib/operators/s3_list_operator.py | 3 +- 
.../contrib/operators/s3_to_gcs_operator.py | 3 +- .../operators/s3_to_gcs_transfer_operator.py | 7 +- .../contrib/operators/s3_to_sftp_operator.py | 3 +- .../operators/sagemaker_base_operator.py | 3 +- .../sagemaker_endpoint_config_operator.py | 3 +- .../operators/sagemaker_endpoint_operator.py | 3 +- .../operators/sagemaker_model_operator.py | 3 +- .../operators/sagemaker_training_operator.py | 3 +- .../operators/sagemaker_transform_operator.py | 3 +- .../operators/sagemaker_tuning_operator.py | 3 +- .../operators/segment_track_event_operator.py | 3 +- .../contrib/operators/sftp_to_s3_operator.py | 3 +- .../operators/slack_webhook_operator.py | 3 +- .../contrib/operators/snowflake_operator.py | 3 +- .../contrib/operators/spark_jdbc_operator.py | 3 +- .../contrib/operators/spark_sql_operator.py | 3 +- .../operators/spark_submit_operator.py | 3 +- airflow/contrib/operators/sql_to_gcs.py | 6 +- airflow/contrib/operators/sqoop_operator.py | 3 +- airflow/contrib/operators/ssh_operator.py | 3 +- airflow/contrib/operators/vertica_operator.py | 3 +- airflow/contrib/operators/vertica_to_hive.py | 6 +- airflow/contrib/operators/vertica_to_mysql.py | 6 +- .../operators/wasb_delete_blob_operator.py | 3 +- airflow/contrib/operators/winrm_operator.py | 3 +- .../contrib/secrets/gcp_secrets_manager.py | 3 +- .../aws_glue_catalog_partition_sensor.py | 3 +- .../contrib/sensors/azure_cosmos_sensor.py | 3 +- airflow/contrib/sensors/bash_sensor.py | 3 +- .../contrib/sensors/celery_queue_sensor.py | 3 +- airflow/contrib/sensors/datadog_sensor.py | 3 +- airflow/contrib/sensors/emr_base_sensor.py | 3 +- .../contrib/sensors/emr_job_flow_sensor.py | 3 +- airflow/contrib/sensors/emr_step_sensor.py | 3 +- airflow/contrib/sensors/file_sensor.py | 3 +- airflow/contrib/sensors/ftp_sensor.py | 3 +- .../contrib/sensors/gcp_transfer_sensor.py | 6 +- airflow/contrib/sensors/gcs_sensor.py | 19 +- airflow/contrib/sensors/hdfs_sensor.py | 9 +- .../contrib/sensors/imap_attachment_sensor.py | 3 +- airflow/contrib/sensors/mongo_sensor.py | 3 +- airflow/contrib/sensors/pubsub_sensor.py | 3 +- airflow/contrib/sensors/python_sensor.py | 3 +- airflow/contrib/sensors/qubole_sensor.py | 7 +- airflow/contrib/sensors/redis_key_sensor.py | 3 +- .../contrib/sensors/redis_pub_sub_sensor.py | 3 +- .../contrib/sensors/sagemaker_base_sensor.py | 3 +- .../sensors/sagemaker_endpoint_sensor.py | 3 +- .../sensors/sagemaker_training_sensor.py | 6 +- .../sensors/sagemaker_transform_sensor.py | 3 +- .../sensors/sagemaker_tuning_sensor.py | 3 +- airflow/contrib/sensors/wasb_sensor.py | 3 +- airflow/contrib/sensors/weekday_sensor.py | 3 +- .../contrib/task_runner/cgroup_task_runner.py | 3 +- airflow/contrib/utils/__init__.py | 5 +- airflow/contrib/utils/gcp_field_sanitizer.py | 6 +- airflow/contrib/utils/gcp_field_validator.py | 7 +- airflow/contrib/utils/log/__init__.py | 5 +- .../log/task_handler_with_custom_formatter.py | 3 +- .../contrib/utils/mlengine_operator_utils.py | 3 +- .../utils/mlengine_prediction_summary.py | 3 +- airflow/contrib/utils/sendgrid.py | 3 +- airflow/contrib/utils/weekday.py | 3 +- airflow/example_dags/example_bash_operator.py | 2 +- .../example_dags/example_branch_operator.py | 2 +- .../example_branch_python_dop_operator_3.py | 9 +- airflow/example_dags/example_dag_decorator.py | 11 +- .../example_external_task_marker_dag.py | 24 +- .../example_kubernetes_executor.py | 38 +- .../example_kubernetes_executor_config.py | 63 +- airflow/example_dags/example_latest_only.py | 2 +- .../example_latest_only_with_trigger.py | 2 +- 
.../example_dags/example_nested_branch_dag.py | 5 +- ...example_passing_params_via_test_command.py | 16 +- .../example_dags/example_python_operator.py | 7 +- .../example_dags/example_subdag_operator.py | 6 +- .../example_trigger_controller_dag.py | 2 +- .../example_trigger_target_dag.py | 2 +- airflow/example_dags/example_xcom.py | 2 +- airflow/example_dags/subdags/subdag.py | 2 + airflow/example_dags/tutorial.py | 1 + .../tutorial_decorated_etl_dag.py | 3 + airflow/example_dags/tutorial_etl_dag.py | 4 + airflow/executors/base_executor.py | 62 +- airflow/executors/celery_executor.py | 65 +- .../executors/celery_kubernetes_executor.py | 40 +- airflow/executors/dask_executor.py | 12 +- airflow/executors/debug_executor.py | 10 +- airflow/executors/executor_loader.py | 9 +- airflow/executors/kubernetes_executor.py | 311 ++-- airflow/executors/local_executor.py | 52 +- airflow/executors/sequential_executor.py | 12 +- airflow/hooks/base_hook.py | 2 +- airflow/hooks/dbapi_hook.py | 31 +- airflow/hooks/docker_hook.py | 3 +- airflow/hooks/druid_hook.py | 3 +- airflow/hooks/hdfs_hook.py | 3 +- airflow/hooks/hive_hooks.py | 8 +- airflow/hooks/http_hook.py | 3 +- airflow/hooks/jdbc_hook.py | 3 +- airflow/hooks/mssql_hook.py | 3 +- airflow/hooks/mysql_hook.py | 3 +- airflow/hooks/oracle_hook.py | 3 +- airflow/hooks/pig_hook.py | 3 +- airflow/hooks/postgres_hook.py | 3 +- airflow/hooks/presto_hook.py | 3 +- airflow/hooks/samba_hook.py | 3 +- airflow/hooks/slack_hook.py | 3 +- airflow/hooks/sqlite_hook.py | 3 +- airflow/hooks/webhdfs_hook.py | 3 +- airflow/hooks/zendesk_hook.py | 3 +- airflow/jobs/backfill_job.py | 297 ++-- airflow/jobs/base_job.py | 30 +- airflow/jobs/local_task_job.py | 77 +- airflow/jobs/scheduler_job.py | 431 +++-- airflow/kubernetes/kube_client.py | 22 +- airflow/kubernetes/pod_generator.py | 81 +- .../kubernetes/pod_generator_deprecated.py | 43 +- airflow/kubernetes/pod_launcher.py | 108 +- airflow/kubernetes/pod_runtime_info_env.py | 6 +- airflow/kubernetes/refresh_config.py | 6 +- airflow/kubernetes/secret.py | 39 +- airflow/lineage/__init__.py | 67 +- airflow/lineage/entities.py | 2 +- airflow/logging_config.py | 15 +- airflow/macros/hive.py | 16 +- airflow/migrations/env.py | 5 +- .../versions/03bc53e68815_add_sm_dag_index.py | 4 +- .../versions/05f30312d566_merge_heads.py | 4 +- .../0a2a5b66e19d_add_task_reschedule_table.py | 20 +- .../0e2a74e0fc9f_add_time_zone_awareness.py | 188 +-- ...add_dag_id_state_index_on_dag_run_table.py | 4 +- .../13eb55f81627_for_compatibility.py | 4 +- .../1507a7289a2f_create_is_encrypted.py | 17 +- ...e3_add_is_encrypted_column_to_variable_.py | 4 +- .../versions/1b38cef5b76e_add_dagrun.py | 30 +- .../211e584da130_add_ti_state_index.py | 4 +- ...24_add_executor_config_to_task_instance.py | 4 +- .../versions/2e541a1dcfed_task_duration.py | 14 +- .../2e82aab8ef20_rename_user_table.py | 4 +- ...0f54d61_more_logging_into_task_isntance.py | 4 +- ...4_add_kubernetes_resource_checkpointing.py | 15 +- .../versions/40e67319e3a9_dagrun_config.py | 4 +- .../41f5f12752f8_add_superuser_field.py | 4 +- .../versions/4446e08588_dagrun_start_end.py | 4 +- ..._add_fractional_seconds_to_mysql_tables.py | 202 +-- .../502898887f84_adding_extra_to_log.py | 4 +- ..._mssql_exec_date_rendered_task_instance.py | 4 +- .../versions/52d714495f0_job_id_indices.py | 7 +- ...61833c1c74b_add_password_column_to_user.py | 4 +- ...de9cddf6c9_add_task_fails_journal_table.py | 4 +- ...4a4_make_taskinstance_pool_not_nullable.py | 14 +- ...hange_datetime_to_datetime2_6_on_mssql_.py | 
155 +- .../7939bcff74ba_add_dagtags_table.py | 7 +- .../849da589634d_prefix_dag_permissions.py | 10 +- .../8504051e801b_xcom_dag_task_indices.py | 3 +- ...add_rendered_task_instance_fields_table.py | 4 +- .../856955da8476_fix_sqlite_foreign_key.py | 43 +- ...5c0_add_kubernetes_scheduler_uniqueness.py | 23 +- ...3f6d53_add_unique_constraint_to_conn_id.py | 22 +- ...c8_task_reschedule_fk_on_cascade_delete.py | 18 +- .../947454bf1dff_add_ti_job_id_index.py | 4 +- .../952da73b5eff_add_dag_code_table.py | 18 +- .../versions/9635ae0956e7_index_faskfail.py | 10 +- ..._add_scheduling_decision_to_dagrun_and_.py | 1 + ...b_add_pool_slots_field_to_task_instance.py | 4 +- .../a56c9515abdc_remove_dag_stat_table.py | 14 +- ...dd_precision_to_execution_date_in_mysql.py | 10 +- .../versions/b0125267960b_merge_heads.py | 4 +- ...6_add_a_column_to_track_the_encryption_.py | 7 +- ...dd_notification_sent_column_to_sla_miss.py | 4 +- ...6_make_xcom_value_column_a_large_binary.py | 4 +- ...4f3d11e8b_drop_kuberesourceversion_and_.py | 30 +- .../bf00311e1990_add_index_to_taskinstance.py | 11 +- .../c8ffec048a3b_add_fields_to_dag.py | 4 +- ...7_add_max_tries_column_to_task_instance.py | 24 +- .../cf5dc11e79ad_drop_user_and_chart.py | 17 +- ...ae31099d61_increase_text_size_for_mysql.py | 4 +- .../d38e04c12aa2_add_serialized_dag_table.py | 25 +- ..._add_dag_hash_column_to_serialized_dag_.py | 3 +- .../versions/dd25f486b8ea_add_idx_log_dag.py | 4 +- ...4ecb8fbee3_add_schedule_interval_to_dag.py | 4 +- ...e357a868_update_schema_for_smart_sensor.py | 15 +- .../versions/e3a246e0dc1_current_schema.py | 67 +- ...433877c24_fix_mysql_not_null_constraint.py | 4 +- .../f2ca10b85618_add_dag_stats_table.py | 22 +- ...increase_length_for_connection_password.py | 20 +- airflow/models/base.py | 4 +- airflow/models/baseoperator.py | 293 ++-- airflow/models/connection.py | 80 +- airflow/models/crypto.py | 11 +- airflow/models/dag.py | 546 +++--- airflow/models/dagbag.py | 115 +- airflow/models/dagcode.py | 52 +- airflow/models/dagrun.py | 217 +-- airflow/models/log.py | 4 +- airflow/models/pool.py | 16 +- airflow/models/renderedtifields.py | 80 +- airflow/models/sensorinstance.py | 28 +- airflow/models/serialized_dag.py | 46 +- airflow/models/skipmixin.py | 32 +- airflow/models/slamiss.py | 7 +- airflow/models/taskfail.py | 5 +- airflow/models/taskinstance.py | 452 ++--- airflow/models/taskreschedule.py | 29 +- airflow/models/variable.py | 13 +- airflow/models/xcom.py | 108 +- airflow/models/xcom_arg.py | 13 +- airflow/operators/bash.py | 28 +- airflow/operators/bash_operator.py | 3 +- airflow/operators/check_operator.py | 20 +- airflow/operators/dagrun_operator.py | 7 +- airflow/operators/docker_operator.py | 3 +- airflow/operators/druid_check_operator.py | 3 +- airflow/operators/email.py | 35 +- airflow/operators/email_operator.py | 3 +- airflow/operators/gcs_to_s3.py | 3 +- airflow/operators/generic_transfer.py | 22 +- .../operators/google_api_to_s3_transfer.py | 13 +- airflow/operators/hive_operator.py | 3 +- airflow/operators/hive_stats_operator.py | 3 +- airflow/operators/hive_to_druid.py | 6 +- airflow/operators/hive_to_mysql.py | 6 +- airflow/operators/hive_to_samba_operator.py | 3 +- airflow/operators/http_operator.py | 3 +- airflow/operators/jdbc_operator.py | 3 +- airflow/operators/latest_only.py | 10 +- airflow/operators/latest_only_operator.py | 3 +- airflow/operators/mssql_operator.py | 3 +- airflow/operators/mssql_to_hive.py | 6 +- airflow/operators/mysql_operator.py | 3 +- airflow/operators/mysql_to_hive.py | 6 +- 
airflow/operators/oracle_operator.py | 3 +- airflow/operators/papermill_operator.py | 3 +- airflow/operators/pig_operator.py | 3 +- airflow/operators/postgres_operator.py | 3 +- airflow/operators/presto_check_operator.py | 12 +- airflow/operators/presto_to_mysql.py | 6 +- airflow/operators/python.py | 116 +- airflow/operators/python_operator.py | 8 +- airflow/operators/redshift_to_s3_operator.py | 6 +- .../operators/s3_file_transform_operator.py | 3 +- airflow/operators/s3_to_hive_operator.py | 6 +- airflow/operators/s3_to_redshift_operator.py | 6 +- airflow/operators/slack_operator.py | 3 +- airflow/operators/sql.py | 75 +- airflow/operators/sql_branch_operator.py | 6 +- airflow/operators/sqlite_operator.py | 3 +- airflow/operators/subdag_operator.py | 56 +- airflow/plugins_manager.py | 60 +- .../example_dags/example_glacier_to_gcs.py | 4 +- .../providers/amazon/aws/hooks/base_aws.py | 2 +- .../amazon/aws/hooks/cloud_formation.py | 2 +- .../hooks/elasticache_replication_group.py | 3 +- airflow/providers/amazon/aws/hooks/glacier.py | 1 + .../amazon/aws/hooks/glue_catalog.py | 2 +- .../providers/amazon/aws/hooks/sagemaker.py | 2 +- .../amazon/aws/hooks/secrets_manager.py | 1 + airflow/providers/amazon/aws/hooks/sns.py | 2 +- .../providers/amazon/aws/operators/batch.py | 2 +- .../amazon/aws/operators/datasync.py | 2 +- airflow/providers/amazon/aws/operators/ecs.py | 2 +- .../aws/operators/sagemaker_transform.py | 2 +- .../providers/amazon/aws/sensors/emr_base.py | 2 +- .../amazon/aws/sensors/sagemaker_training.py | 3 +- .../amazon/aws/transfers/dynamodb_to_s3.py | 2 +- .../amazon/aws/transfers/gcs_to_s3.py | 2 +- .../amazon/aws/transfers/glacier_to_gcs.py | 2 +- .../amazon/aws/transfers/hive_to_dynamodb.py | 2 +- .../amazon/aws/transfers/mongo_to_s3.py | 2 +- .../amazon/backport_provider_setup.py | 2 +- .../cassandra/backport_provider_setup.py | 2 +- .../apache/druid/backport_provider_setup.py | 2 +- .../apache/hdfs/backport_provider_setup.py | 2 +- .../apache/hive/backport_provider_setup.py | 2 +- .../apache/kylin/backport_provider_setup.py | 2 +- .../apache/livy/backport_provider_setup.py | 2 +- .../apache/pig/backport_provider_setup.py | 2 +- .../apache/pinot/backport_provider_setup.py | 2 +- .../apache/spark/backport_provider_setup.py | 2 +- .../apache/sqoop/backport_provider_setup.py | 2 +- .../celery/backport_provider_setup.py | 2 +- .../cloudant/backport_provider_setup.py | 2 +- .../kubernetes/backport_provider_setup.py | 2 +- .../cncf/kubernetes/hooks/kubernetes.py | 2 +- .../kubernetes/operators/kubernetes_pod.py | 7 +- .../databricks/backport_provider_setup.py | 2 +- .../providers/databricks/hooks/databricks.py | 2 +- .../databricks/operators/databricks.py | 2 +- .../datadog/backport_provider_setup.py | 2 +- .../dingding/backport_provider_setup.py | 2 +- airflow/providers/dingding/hooks/dingding.py | 2 +- .../providers/dingding/operators/dingding.py | 3 +- .../discord/backport_provider_setup.py | 2 +- .../docker/backport_provider_setup.py | 2 +- .../elasticsearch/backport_provider_setup.py | 2 +- .../exasol/backport_provider_setup.py | 2 +- airflow/providers/exasol/hooks/exasol.py | 2 +- .../facebook/backport_provider_setup.py | 2 +- .../providers/ftp/backport_provider_setup.py | 2 +- airflow/providers/ftp/hooks/ftp.py | 2 +- .../google/backport_provider_setup.py | 2 +- .../example_azure_fileshare_to_gcs.py | 2 +- .../example_dags/example_cloud_memorystore.py | 6 +- .../example_dags/example_mysql_to_gcs.py | 1 + .../providers/google/cloud/hooks/bigquery.py | 4 +- 
.../google/cloud/hooks/cloud_memorystore.py | 2 +- .../providers/google/cloud/hooks/cloud_sql.py | 2 +- .../hooks/cloud_storage_transfer_service.py | 2 +- .../providers/google/cloud/hooks/datastore.py | 2 +- airflow/providers/google/cloud/hooks/gcs.py | 2 +- airflow/providers/google/cloud/hooks/gdm.py | 2 +- .../providers/google/cloud/hooks/mlengine.py | 2 +- .../cloud_storage_transfer_service.py | 2 +- .../providers/google/cloud/operators/dlp.py | 2 +- .../sensors/cloud_storage_transfer_service.py | 2 +- .../google/cloud/sensors/dataproc.py | 2 +- .../cloud/transfers/azure_fileshare_to_gcs.py | 4 +- .../google/cloud/transfers/presto_to_gcs.py | 2 +- .../cloud/utils/credentials_provider.py | 2 +- .../marketing_platform/operators/analytics.py | 2 +- .../providers/grpc/backport_provider_setup.py | 2 +- .../hashicorp/backport_provider_setup.py | 2 +- .../providers/http/backport_provider_setup.py | 2 +- .../providers/imap/backport_provider_setup.py | 2 +- .../providers/jdbc/backport_provider_setup.py | 2 +- .../jenkins/backport_provider_setup.py | 2 +- .../providers/jira/backport_provider_setup.py | 2 +- .../azure/backport_provider_setup.py | 2 +- .../example_dags/example_azure_blob_to_gcs.py | 4 +- .../example_dags/example_local_to_adls.py | 1 + .../microsoft/azure/hooks/azure_batch.py | 2 +- .../microsoft/azure/hooks/azure_fileshare.py | 4 +- .../microsoft/azure/log/wasb_task_handler.py | 2 +- .../operators/azure_container_instances.py | 6 +- .../azure/transfers/azure_blob_to_gcs.py | 2 +- .../azure/transfers/local_to_adls.py | 3 +- .../mssql/backport_provider_setup.py | 2 +- .../winrm/backport_provider_setup.py | 2 +- .../mongo/backport_provider_setup.py | 2 +- .../mysql/backport_provider_setup.py | 2 +- .../providers/odbc/backport_provider_setup.py | 2 +- airflow/providers/odbc/hooks/odbc.py | 2 +- .../openfaas/backport_provider_setup.py | 2 +- .../opsgenie/backport_provider_setup.py | 2 +- .../opsgenie/hooks/opsgenie_alert.py | 2 +- .../opsgenie/operators/opsgenie_alert.py | 2 +- .../oracle/backport_provider_setup.py | 2 +- airflow/providers/oracle/hooks/oracle.py | 2 +- .../pagerduty/backport_provider_setup.py | 2 +- .../plexus/backport_provider_setup.py | 2 +- airflow/providers/plexus/operators/job.py | 2 +- .../postgres/backport_provider_setup.py | 2 +- airflow/providers/postgres/hooks/postgres.py | 2 +- .../presto/backport_provider_setup.py | 2 +- airflow/providers/presto/hooks/presto.py | 2 +- .../qubole/backport_provider_setup.py | 2 +- airflow/providers/qubole/hooks/qubole.py | 4 +- .../providers/qubole/hooks/qubole_check.py | 2 +- .../qubole/operators/qubole_check.py | 2 +- .../redis/backport_provider_setup.py | 2 +- .../salesforce/backport_provider_setup.py | 2 +- .../providers/salesforce/hooks/salesforce.py | 2 +- airflow/providers/salesforce/hooks/tableau.py | 2 +- .../samba/backport_provider_setup.py | 2 +- .../segment/backport_provider_setup.py | 2 +- .../providers/sftp/backport_provider_setup.py | 2 +- .../singularity/backport_provider_setup.py | 2 +- .../slack/backport_provider_setup.py | 2 +- airflow/providers/slack/operators/slack.py | 2 +- .../slack/operators/slack_webhook.py | 2 +- .../snowflake/backport_provider_setup.py | 2 +- .../providers/snowflake/hooks/snowflake.py | 2 +- .../sqlite/backport_provider_setup.py | 2 +- .../providers/ssh/backport_provider_setup.py | 2 +- airflow/providers/ssh/hooks/ssh.py | 2 +- .../vertica/backport_provider_setup.py | 2 +- .../yandex/backport_provider_setup.py | 2 +- airflow/providers/yandex/hooks/yandex.py | 2 +- 
.../zendesk/backport_provider_setup.py | 2 +- airflow/secrets/base_secrets.py | 3 +- airflow/secrets/local_filesystem.py | 15 +- airflow/secrets/metastore.py | 2 + airflow/security/kerberos.py | 49 +- airflow/sensors/base_sensor_operator.py | 98 +- airflow/sensors/bash.py | 16 +- airflow/sensors/date_time_sensor.py | 8 +- airflow/sensors/external_task_sensor.py | 103 +- airflow/sensors/filesystem.py | 5 +- airflow/sensors/hdfs_sensor.py | 3 +- airflow/sensors/hive_partition_sensor.py | 3 +- airflow/sensors/http_sensor.py | 3 +- airflow/sensors/metastore_partition_sensor.py | 3 +- .../sensors/named_hive_partition_sensor.py | 3 +- airflow/sensors/python.py | 14 +- airflow/sensors/s3_key_sensor.py | 3 +- airflow/sensors/s3_prefix_sensor.py | 3 +- airflow/sensors/smart_sensor_operator.py | 169 +- airflow/sensors/sql_sensor.py | 35 +- airflow/sensors/web_hdfs_sensor.py | 3 +- airflow/sensors/weekday_sensor.py | 15 +- airflow/sentry.py | 21 +- airflow/serialization/serialized_objects.py | 113 +- airflow/settings.py | 61 +- airflow/stats.py | 30 +- airflow/task/task_runner/base_task_runner.py | 16 +- .../task/task_runner/cgroup_task_runner.py | 49 +- airflow/ti_deps/dep_context.py | 21 +- airflow/ti_deps/dependencies_deps.py | 5 +- airflow/ti_deps/deps/base_ti_dep.py | 9 +- .../deps/dag_ti_slots_available_dep.py | 4 +- airflow/ti_deps/deps/dag_unpaused_dep.py | 3 +- airflow/ti_deps/deps/dagrun_exists_dep.py | 20 +- airflow/ti_deps/deps/dagrun_id_dep.py | 7 +- .../deps/exec_date_after_start_date_dep.py | 13 +- .../ti_deps/deps/not_in_retry_period_dep.py | 12 +- .../deps/not_previously_skipped_dep.py | 17 +- .../ti_deps/deps/pool_slots_available_dep.py | 16 +- airflow/ti_deps/deps/prev_dagrun_dep.py | 31 +- airflow/ti_deps/deps/ready_to_reschedule.py | 18 +- .../ti_deps/deps/runnable_exec_date_dep.py | 18 +- airflow/ti_deps/deps/task_concurrency_dep.py | 6 +- airflow/ti_deps/deps/task_not_running_dep.py | 3 +- airflow/ti_deps/deps/trigger_rule_dep.py | 95 +- airflow/ti_deps/deps/valid_state_dep.py | 10 +- airflow/typing_compat.py | 4 +- airflow/utils/callback_requests.py | 4 +- airflow/utils/cli.py | 37 +- airflow/utils/compression.py | 14 +- airflow/utils/dag_processing.py | 240 ++- airflow/utils/dates.py | 6 +- airflow/utils/db.py | 150 +- airflow/utils/decorators.py | 15 +- airflow/utils/dot_renderer.py | 30 +- airflow/utils/email.py | 35 +- airflow/utils/file.py | 33 +- airflow/utils/helpers.py | 31 +- airflow/utils/json.py | 25 +- airflow/utils/log/cloudwatch_task_handler.py | 3 +- airflow/utils/log/colored_log.py | 12 +- airflow/utils/log/es_task_handler.py | 3 +- airflow/utils/log/file_processor_handler.py | 16 +- airflow/utils/log/file_task_handler.py | 40 +- airflow/utils/log/gcs_task_handler.py | 3 +- airflow/utils/log/json_formatter.py | 3 +- airflow/utils/log/log_reader.py | 13 +- airflow/utils/log/logging_mixin.py | 11 +- airflow/utils/log/s3_task_handler.py | 3 +- airflow/utils/log/stackdriver_task_handler.py | 3 +- airflow/utils/log/wasb_task_handler.py | 3 +- airflow/utils/module_loading.py | 4 +- airflow/utils/operator_helpers.py | 61 +- airflow/utils/operator_resources.py | 16 +- airflow/utils/orm_event_handlers.py | 21 +- airflow/utils/process_utils.py | 24 +- airflow/utils/python_virtualenv.py | 10 +- airflow/utils/serve_logs.py | 6 +- airflow/utils/session.py | 4 +- airflow/utils/sqlalchemy.py | 26 +- airflow/utils/state.py | 41 +- airflow/utils/task_group.py | 18 +- airflow/utils/timeout.py | 2 +- airflow/utils/timezone.py | 13 +- airflow/utils/weekday.py | 4 +- 
airflow/www/api/experimental/endpoints.py | 61 +- airflow/www/app.py | 6 +- airflow/www/auth.py | 8 +- airflow/www/extensions/init_jinja_globals.py | 2 +- airflow/www/extensions/init_security.py | 5 +- airflow/www/forms.py | 170 +- airflow/www/security.py | 10 +- airflow/www/utils.py | 168 +- airflow/www/validators.py | 11 +- airflow/www/views.py | 1472 ++++++++++------- airflow/www/widgets.py | 5 +- .../tests/test_celery_kubernetes_executor.py | 1 + ...est_celery_kubernetes_pod_launcher_role.py | 1 + chart/tests/test_chart_quality.py | 3 +- .../test_dags_persistent_volume_claim.py | 1 + chart/tests/test_flower_authorization.py | 1 + chart/tests/test_git_sync_scheduler.py | 1 + chart/tests/test_git_sync_webserver.py | 1 + chart/tests/test_git_sync_worker.py | 1 + chart/tests/test_migrate_database_job.py | 1 + chart/tests/test_pod_template_file.py | 3 +- chart/tests/test_scheduler.py | 1 + chart/tests/test_worker.py | 1 + dags/test_dag.py | 10 +- dev/airflow-github | 55 +- dev/airflow-license | 26 +- dev/send_email.py | 117 +- docs/build_docs.py | 70 +- docs/conf.py | 57 +- docs/exts/docroles.py | 28 +- docs/exts/exampleinclude.py | 10 +- kubernetes_tests/test_kubernetes_executor.py | 115 +- .../test_kubernetes_pod_operator.py | 326 ++-- metastore_browser/hive_metastore.py | 47 +- .../import_all_provider_classes.py | 19 +- .../prepare_provider_packages.py | 565 ++++--- .../refactor_provider_packages.py | 158 +- provider_packages/remove_old_releases.py | 34 +- .../pre_commit_check_order_setup.py | 20 +- .../pre_commit_check_setup_installation.py | 19 +- .../ci/pre_commit/pre_commit_yaml_to_cfg.py | 28 +- .../update_quarantined_test_status.py | 64 +- scripts/tools/list-integrations.py | 9 +- setup.py | 92 +- .../disable_checks_for_tests.py | 18 +- tests/airflow_pylint/do_not_use_asserts.py | 5 +- tests/always/test_example_dags.py | 4 +- tests/always/test_project_structure.py | 53 +- tests/api/auth/backend/test_basic_auth.py | 58 +- tests/api/auth/test_client.py | 29 +- tests/api/client/test_local_client.py | 60 +- .../common/experimental/test_delete_dag.py | 56 +- .../common/experimental/test_mark_tasks.py | 299 ++-- tests/api/common/experimental/test_pool.py | 82 +- .../common/experimental/test_trigger_dag.py | 13 +- .../endpoints/test_config_endpoint.py | 2 - .../endpoints/test_task_instance_endpoint.py | 8 +- .../endpoints/test_variable_endpoint.py | 2 +- .../endpoints/test_xcom_endpoint.py | 2 +- .../schemas/test_task_instance_schema.py | 4 +- tests/build_provider_packages_dependencies.py | 27 +- tests/cli/commands/test_celery_command.py | 32 +- .../cli/commands/test_cheat_sheet_command.py | 3 +- tests/cli/commands/test_config_command.py | 20 +- tests/cli/commands/test_connection_command.py | 372 +++-- tests/cli/commands/test_dag_command.py | 303 ++-- tests/cli/commands/test_db_command.py | 44 +- tests/cli/commands/test_info_command.py | 17 +- tests/cli/commands/test_kubernetes_command.py | 15 +- tests/cli/commands/test_legacy_commands.py | 44 +- tests/cli/commands/test_pool_command.py | 20 +- tests/cli/commands/test_role_command.py | 9 +- tests/cli/commands/test_sync_perm_command.py | 21 +- tests/cli/commands/test_task_command.py | 320 ++-- tests/cli/commands/test_user_command.py | 226 ++- tests/cli/commands/test_variable_command.py | 89 +- tests/cli/commands/test_webserver_command.py | 57 +- tests/cli/test_cli_parser.py | 87 +- tests/cluster_policies/__init__.py | 13 +- tests/conftest.py | 136 +- tests/core/test_config_templates.py | 28 +- tests/core/test_configuration.py | 211 +-- 
tests/core/test_core.py | 205 +-- tests/core/test_core_to_contrib.py | 12 +- tests/core/test_example_dags_system.py | 16 +- tests/core/test_impersonation_tests.py | 66 +- tests/core/test_local_settings.py | 7 + tests/core/test_logging_config.py | 35 +- tests/core/test_sentry.py | 1 + tests/core/test_sqlalchemy_config.py | 46 +- tests/core/test_stats.py | 103 +- tests/dags/subdir2/test_dont_ignore_this.py | 5 +- tests/dags/test_backfill_pooled_tasks.py | 3 +- tests/dags/test_clear_subdag.py | 13 +- tests/dags/test_cli_triggered_dags.py | 19 +- tests/dags/test_default_impersonation.py | 5 +- tests/dags/test_default_views.py | 12 +- tests/dags/test_double_trigger.py | 4 +- tests/dags/test_example_bash_operator.py | 25 +- tests/dags/test_heartbeat_failed_fast.py | 5 +- tests/dags/test_impersonation.py | 5 +- tests/dags/test_impersonation_subdag.py | 28 +- tests/dags/test_invalid_cron.py | 10 +- tests/dags/test_issue_1225.py | 29 +- tests/dags/test_latest_runs.py | 6 +- tests/dags/test_logging_in_dag.py | 6 +- tests/dags/test_mark_success.py | 6 +- tests/dags/test_missing_owner.py | 4 +- tests/dags/test_multiple_dags.py | 8 +- tests/dags/test_no_impersonation.py | 3 +- tests/dags/test_on_failure_callback.py | 4 +- tests/dags/test_on_kill.py | 9 +- tests/dags/test_prev_dagrun_dep.py | 14 +- tests/dags/test_retry_handling_job.py | 5 +- tests/dags/test_scheduler_dags.py | 25 +- tests/dags/test_task_view_type_check.py | 5 +- tests/dags/test_with_non_default_owner.py | 5 +- .../test_impersonation_custom.py | 17 +- tests/dags_with_system_exit/a_system_exit.py | 4 +- .../b_test_scheduler_dags.py | 9 +- tests/dags_with_system_exit/c_system_exit.py | 4 +- tests/deprecated_classes.py | 116 +- tests/executors/test_base_executor.py | 8 +- tests/executors/test_celery_executor.py | 134 +- .../test_celery_kubernetes_executor.py | 14 +- tests/executors/test_dask_executor.py | 39 +- tests/executors/test_executor_loader.py | 33 +- tests/executors/test_kubernetes_executor.py | 124 +- tests/executors/test_local_executor.py | 18 +- tests/executors/test_sequential_executor.py | 9 +- tests/hooks/test_dbapi_hook.py | 50 +- tests/jobs/test_backfill_job.py | 510 +++--- tests/jobs/test_base_job.py | 5 +- tests/jobs/test_local_task_job.py | 146 +- tests/jobs/test_scheduler_job.py | 975 ++++++----- tests/kubernetes/models/test_secret.py | 167 +- tests/kubernetes/test_client.py | 1 - tests/kubernetes/test_pod_generator.py | 377 ++--- tests/kubernetes/test_pod_launcher.py | 162 +- tests/kubernetes/test_refresh_config.py | 1 - tests/lineage/test_lineage.py | 43 +- tests/models/test_baseoperator.py | 77 +- tests/models/test_cleartasks.py | 61 +- tests/models/test_connection.py | 97 +- tests/models/test_dag.py | 575 +++---- tests/models/test_dagbag.py | 168 +- tests/models/test_dagcode.py | 20 +- tests/models/test_dagparam.py | 4 +- tests/models/test_dagrun.py | 301 ++-- tests/models/test_pool.py | 39 +- tests/models/test_renderedtifields.py | 135 +- tests/models/test_sensorinstance.py | 14 +- tests/models/test_serialized_dag.py | 10 +- tests/models/test_skipmixin.py | 18 +- tests/models/test_taskinstance.py | 666 ++++---- tests/models/test_timestamp.py | 3 +- tests/models/test_variable.py | 15 +- tests/models/test_xcom.py | 146 +- tests/models/test_xcom_arg.py | 16 +- tests/operators/test_bash.py | 83 +- tests/operators/test_branch_operator.py | 24 +- tests/operators/test_dagrun_operator.py | 24 +- tests/operators/test_email.py | 15 +- tests/operators/test_generic_transfer.py | 24 +- 
tests/operators/test_latest_only_operator.py | 162 +- tests/operators/test_python.py | 360 ++-- tests/operators/test_sql.py | 102 +- tests/operators/test_subdag_operator.py | 76 +- tests/plugins/test_plugin.py | 33 +- tests/plugins/test_plugins_manager.py | 55 +- .../amazon/aws/hooks/test_glacier.py | 2 +- tests/providers/amazon/aws/hooks/test_glue.py | 1 - .../amazon/aws/hooks/test_sagemaker.py | 2 +- .../amazon/aws/hooks/test_secrets_manager.py | 2 +- .../amazon/aws/operators/test_athena.py | 1 - .../amazon/aws/operators/test_batch.py | 1 - .../amazon/aws/operators/test_ecs.py | 2 +- .../amazon/aws/operators/test_glacier.py | 8 +- .../aws/operators/test_glacier_system.py | 1 - .../amazon/aws/operators/test_glue.py | 1 - .../amazon/aws/operators/test_s3_bucket.py | 2 +- .../amazon/aws/operators/test_s3_list.py | 1 - .../aws/operators/test_sagemaker_endpoint.py | 2 +- .../test_sagemaker_endpoint_config.py | 1 - .../aws/operators/test_sagemaker_model.py | 1 - .../operators/test_sagemaker_processing.py | 2 +- .../aws/operators/test_sagemaker_training.py | 1 - .../aws/operators/test_sagemaker_transform.py | 1 - .../aws/operators/test_sagemaker_tuning.py | 1 - .../amazon/aws/sensors/test_athena.py | 1 - .../amazon/aws/sensors/test_glacier.py | 1 - .../providers/amazon/aws/sensors/test_glue.py | 1 - .../sensors/test_glue_catalog_partition.py | 1 - .../aws/sensors/test_sagemaker_endpoint.py | 1 - .../aws/sensors/test_sagemaker_training.py | 1 - .../aws/sensors/test_sagemaker_transform.py | 1 - .../aws/sensors/test_sagemaker_tuning.py | 1 - .../amazon/aws/transfers/test_gcs_to_s3.py | 1 - .../aws/transfers/test_glacier_to_gcs.py | 4 +- .../amazon/aws/transfers/test_mongo_to_s3.py | 1 - .../apache/cassandra/hooks/test_cassandra.py | 2 +- .../druid/operators/test_druid_check.py | 1 - .../apache/hive/transfers/test_s3_to_hive.py | 1 - tests/providers/apache/pig/hooks/test_pig.py | 1 - .../apache/pig/operators/test_pig.py | 1 - .../spark/hooks/test_spark_jdbc_script.py | 1 + .../providers/cloudant/hooks/test_cloudant.py | 1 - .../cncf/kubernetes/hooks/test_kubernetes.py | 3 +- .../databricks/hooks/test_databricks.py | 2 +- .../databricks/operators/test_databricks.py | 1 - .../providers/datadog/sensors/test_datadog.py | 1 - tests/providers/docker/hooks/test_docker.py | 1 - .../providers/docker/operators/test_docker.py | 1 - .../docker/operators/test_docker_swarm.py | 2 +- tests/providers/exasol/hooks/test_exasol.py | 1 - .../providers/exasol/operators/test_exasol.py | 1 - .../providers/facebook/ads/hooks/test_ads.py | 1 + tests/providers/google/ads/hooks/test_ads.py | 1 + .../google/cloud/hooks/test_automl.py | 2 +- .../google/cloud/hooks/test_bigquery_dts.py | 2 +- .../google/cloud/hooks/test_cloud_build.py | 1 - .../test_cloud_storage_transfer_service.py | 2 +- .../google/cloud/hooks/test_compute.py | 1 - .../google/cloud/hooks/test_dataflow.py | 2 +- .../google/cloud/hooks/test_dataproc.py | 2 +- .../google/cloud/hooks/test_datastore.py | 3 +- .../providers/google/cloud/hooks/test_dlp.py | 2 +- .../google/cloud/hooks/test_functions.py | 1 - .../providers/google/cloud/hooks/test_kms.py | 1 - .../cloud/hooks/test_kubernetes_engine.py | 2 +- .../google/cloud/hooks/test_life_sciences.py | 1 - .../cloud/hooks/test_natural_language.py | 2 +- .../google/cloud/hooks/test_pubsub.py | 2 +- .../google/cloud/hooks/test_spanner.py | 1 - .../google/cloud/hooks/test_speech_to_text.py | 1 - .../google/cloud/hooks/test_stackdriver.py | 2 +- .../google/cloud/hooks/test_tasks.py | 2 +- 
.../google/cloud/hooks/test_text_to_speech.py | 1 - .../google/cloud/hooks/test_translate.py | 1 - .../cloud/hooks/test_video_intelligence.py | 2 +- .../google/cloud/hooks/test_vision.py | 2 +- .../google/cloud/operators/test_automl.py | 2 +- .../google/cloud/operators/test_bigquery.py | 2 +- .../cloud/operators/test_bigquery_dts.py | 1 - .../cloud/operators/test_cloud_build.py | 3 +- .../cloud/operators/test_cloud_memorystore.py | 6 +- .../google/cloud/operators/test_cloud_sql.py | 2 +- .../test_cloud_storage_transfer_service.py | 2 +- .../google/cloud/operators/test_dataflow.py | 3 +- .../cloud/operators/test_dataproc_system.py | 2 +- .../google/cloud/operators/test_dlp.py | 1 - .../google/cloud/operators/test_dlp_system.py | 2 +- .../google/cloud/operators/test_functions.py | 2 +- .../google/cloud/operators/test_gcs.py | 1 - .../cloud/operators/test_kubernetes_engine.py | 2 +- .../cloud/operators/test_life_sciences.py | 1 - .../cloud/operators/test_mlengine_utils.py | 3 +- .../google/cloud/operators/test_pubsub.py | 2 +- .../google/cloud/operators/test_spanner.py | 2 +- .../cloud/operators/test_speech_to_text.py | 1 - .../cloud/operators/test_stackdriver.py | 2 +- .../google/cloud/operators/test_tasks.py | 2 +- .../cloud/operators/test_text_to_speech.py | 2 +- .../google/cloud/operators/test_translate.py | 1 - .../cloud/operators/test_translate_speech.py | 2 +- .../operators/test_video_intelligence.py | 2 +- .../google/cloud/operators/test_vision.py | 2 +- .../google/cloud/sensors/test_bigquery_dts.py | 1 - .../test_cloud_storage_transfer_service.py | 2 +- .../google/cloud/sensors/test_dataproc.py | 2 +- .../google/cloud/sensors/test_pubsub.py | 2 +- .../cloud/transfers/test_adls_to_gcs.py | 1 - .../transfers/test_azure_fileshare_to_gcs.py | 1 - .../test_azure_fileshare_to_gcs_system.py | 9 +- .../transfers/test_bigquery_to_bigquery.py | 1 - .../cloud/transfers/test_bigquery_to_gcs.py | 1 - .../cloud/transfers/test_bigquery_to_mysql.py | 1 - .../cloud/transfers/test_cassandra_to_gcs.py | 1 - .../cloud/transfers/test_gcs_to_bigquery.py | 1 - .../google/cloud/transfers/test_gcs_to_gcs.py | 1 - .../cloud/transfers/test_gcs_to_local.py | 1 - .../transfers/test_gcs_to_local_system.py | 2 +- .../cloud/transfers/test_gcs_to_sftp.py | 1 - .../cloud/transfers/test_local_to_gcs.py | 2 +- .../cloud/transfers/test_mssql_to_gcs.py | 1 - .../cloud/transfers/test_mysql_to_gcs.py | 2 +- .../transfers/test_mysql_to_gcs_system.py | 4 +- .../google/cloud/transfers/test_s3_to_gcs.py | 1 - .../cloud/transfers/test_s3_to_gcs_system.py | 3 +- .../cloud/transfers/test_salesforce_to_gcs.py | 1 - .../test_salesforce_to_gcs_system.py | 1 + .../cloud/transfers/test_sftp_to_gcs.py | 1 - .../google/cloud/transfers/test_sql_to_gcs.py | 2 +- .../cloud/utils/test_credentials_provider.py | 2 +- .../google/firebase/hooks/test_firestore.py | 1 - .../google/suite/hooks/test_drive.py | 1 - .../google/suite/hooks/test_sheets.py | 1 - tests/providers/grpc/hooks/test_grpc.py | 1 - tests/providers/grpc/operators/test_grpc.py | 1 - tests/providers/http/hooks/test_http.py | 2 +- tests/providers/http/operators/test_http.py | 2 +- tests/providers/http/sensors/test_http.py | 2 +- .../microsoft/azure/hooks/test_azure_batch.py | 2 +- .../azure/hooks/test_azure_cosmos.py | 2 +- .../azure/hooks/test_azure_data_lake.py | 1 - .../azure/hooks/test_azure_fileshare.py | 4 +- .../microsoft/azure/hooks/test_wasb.py | 1 - .../azure/operators/test_adls_list.py | 1 - .../azure/operators/test_azure_batch.py | 1 - 
.../test_azure_container_instances.py | 2 +- .../azure/operators/test_azure_cosmos.py | 1 - .../azure/operators/test_wasb_delete_blob.py | 1 - .../microsoft/azure/sensors/test_wasb.py | 1 - .../azure/transfers/test_azure_blob_to_gcs.py | 5 +- .../azure/transfers/test_file_to_wasb.py | 1 - .../azure/transfers/test_local_to_adls.py | 1 - .../test_oracle_to_azure_data_lake.py | 2 +- .../microsoft/mssql/hooks/test_mssql.py | 1 - .../microsoft/mssql/operators/test_mssql.py | 1 - .../providers/openfaas/hooks/test_openfaas.py | 2 +- tests/providers/oracle/hooks/test_oracle.py | 2 +- .../providers/oracle/operators/test_oracle.py | 1 - .../oracle/transfers/test_oracle_to_oracle.py | 1 - .../pagerduty/hooks/test_pagerduty.py | 1 + tests/providers/plexus/hooks/test_plexus.py | 1 + tests/providers/plexus/operators/test_job.py | 2 + .../qubole/operators/test_qubole_check.py | 2 +- tests/providers/samba/hooks/test_samba.py | 2 +- .../providers/sendgrid/utils/test_emailer.py | 1 - .../singularity/operators/test_singularity.py | 2 +- tests/providers/slack/hooks/test_slack.py | 2 +- tests/providers/slack/operators/test_slack.py | 1 - .../snowflake/operators/test_snowflake.py | 1 - tests/providers/ssh/hooks/test_ssh.py | 6 +- tests/secrets/test_local_filesystem.py | 55 +- tests/secrets/test_secrets.py | 61 +- tests/secrets/test_secrets_backends.py | 30 +- tests/security/test_kerberos.py | 17 +- tests/sensors/test_base_sensor.py | 162 +- tests/sensors/test_bash.py | 9 +- tests/sensors/test_date_time_sensor.py | 10 +- tests/sensors/test_external_task_sensor.py | 309 ++-- tests/sensors/test_filesystem.py | 36 +- tests/sensors/test_python.py | 54 +- tests/sensors/test_smart_sensor_operator.py | 55 +- tests/sensors/test_sql_sensor.py | 45 +- tests/sensors/test_time_sensor.py | 4 +- tests/sensors/test_timedelta_sensor.py | 8 +- tests/sensors/test_timeout_sensor.py | 15 +- tests/sensors/test_weekday_sensor.py | 61 +- tests/serialization/test_dag_serialization.py | 472 +++--- .../task_runner/test_cgroup_task_runner.py | 1 - .../task_runner/test_standard_task_runner.py | 40 +- tests/task/task_runner/test_task_runner.py | 11 +- tests/test_utils/amazon_system_helpers.py | 38 +- tests/test_utils/api_connexion_utils.py | 2 +- tests/test_utils/asserts.py | 20 +- tests/test_utils/azure_system_helpers.py | 1 - tests/test_utils/db.py | 16 +- tests/test_utils/gcp_system_helpers.py | 62 +- tests/test_utils/hdfs_utils.py | 245 +-- tests/test_utils/logging_command_executor.py | 17 +- tests/test_utils/mock_hooks.py | 6 +- tests/test_utils/mock_operators.py | 18 +- tests/test_utils/mock_process.py | 12 +- tests/test_utils/perf/dags/elastic_dag.py | 27 +- tests/test_utils/perf/dags/perf_dag_1.py | 13 +- tests/test_utils/perf/dags/perf_dag_2.py | 13 +- tests/test_utils/perf/perf_kit/memory.py | 1 + .../perf/perf_kit/repeat_and_time.py | 2 + tests/test_utils/perf/perf_kit/sqlalchemy.py | 75 +- .../perf/scheduler_dag_execution_timing.py | 53 +- .../test_utils/perf/scheduler_ops_metrics.py | 101 +- tests/test_utils/perf/sql_queries.py | 6 +- .../remote_user_api_auth_backend.py | 6 +- .../test_remote_user_api_auth_backend.py | 4 +- tests/ti_deps/deps/fake_models.py | 4 - .../deps/test_dag_ti_slots_available_dep.py | 1 - tests/ti_deps/deps/test_dag_unpaused_dep.py | 1 - tests/ti_deps/deps/test_dagrun_exists_dep.py | 1 - .../deps/test_not_in_retry_period_dep.py | 10 +- .../deps/test_not_previously_skipped_dep.py | 16 +- tests/ti_deps/deps/test_prev_dagrun_dep.py | 81 +- .../deps/test_ready_to_reschedule_dep.py | 18 +- 
.../deps/test_runnable_exec_date_dep.py | 24 +- tests/ti_deps/deps/test_task_concurrency.py | 1 - .../ti_deps/deps/test_task_not_running_dep.py | 1 - tests/ti_deps/deps/test_trigger_rule_dep.py | 656 ++++---- tests/ti_deps/deps/test_valid_state_dep.py | 1 - .../utils/log/test_file_processor_handler.py | 14 +- tests/utils/log/test_json_formatter.py | 6 +- tests/utils/log/test_log_reader.py | 46 +- tests/utils/test_cli_util.py | 26 +- tests/utils/test_compression.py | 31 +- tests/utils/test_dag_cycle.py | 40 +- tests/utils/test_dag_processing.py | 89 +- tests/utils/test_dates.py | 25 +- tests/utils/test_db.py | 70 +- tests/utils/test_decorators.py | 4 +- tests/utils/test_docs.py | 14 +- tests/utils/test_email.py | 50 +- tests/utils/test_helpers.py | 44 +- tests/utils/test_json.py | 46 +- tests/utils/test_log_handlers.py | 30 +- tests/utils/test_logging_mixin.py | 12 +- tests/utils/test_net.py | 3 +- tests/utils/test_operator_helpers.py | 24 +- tests/utils/test_process_utils.py | 15 +- tests/utils/test_python_virtualenv.py | 21 +- tests/utils/test_sqlalchemy.py | 64 +- ...test_task_handler_with_custom_formatter.py | 5 +- tests/utils/test_timezone.py | 9 +- tests/utils/test_trigger_rule.py | 1 - tests/utils/test_weight_rule.py | 1 - .../experimental/test_dag_runs_endpoint.py | 10 +- tests/www/api/experimental/test_endpoints.py | 166 +- .../experimental/test_kerberos_endpoints.py | 16 +- tests/www/test_app.py | 90 +- tests/www/test_security.py | 33 +- tests/www/test_utils.py | 78 +- tests/www/test_validators.py | 5 +- tests/www/test_views.py | 1095 ++++++------ 1068 files changed, 15410 insertions(+), 15132 deletions(-) diff --git a/airflow/__init__.py b/airflow/__init__.py index 486bd5fbbe148..35e755ece0e04 100644 --- a/airflow/__init__.py +++ b/airflow/__init__.py @@ -55,15 +55,18 @@ def __getattr__(name): # PEP-562: Lazy loaded attributes on python modules if name == "DAG": from airflow.models.dag import DAG # pylint: disable=redefined-outer-name + return DAG if name == "AirflowException": from airflow.exceptions import AirflowException # pylint: disable=redefined-outer-name + return AirflowException raise AttributeError(f"module {__name__} has no attribute {name}") if not settings.LAZY_LOAD_PLUGINS: from airflow import plugins_manager + plugins_manager.ensure_plugins_loaded() diff --git a/airflow/api/__init__.py b/airflow/api/__init__.py index 63dbfcff93283..c869252858a5b 100644 --- a/airflow/api/__init__.py +++ b/airflow/api/__init__.py @@ -38,8 +38,5 @@ def load_auth(): log.info("Loaded API auth backend: %s", auth_backend) return auth_backend except ImportError as err: - log.critical( - "Cannot import %s for API authentication due to: %s", - auth_backend, err - ) + log.critical("Cannot import %s for API authentication due to: %s", auth_backend, err) raise AirflowException(err) diff --git a/airflow/api/auth/backend/basic_auth.py b/airflow/api/auth/backend/basic_auth.py index bd42708661f6e..9b0e32087b6ba 100644 --- a/airflow/api/auth/backend/basic_auth.py +++ b/airflow/api/auth/backend/basic_auth.py @@ -53,13 +53,12 @@ def auth_current_user() -> Optional[User]: def requires_authentication(function: T): """Decorator for functions that require authentication""" + @wraps(function) def decorated(*args, **kwargs): if auth_current_user() is not None: return function(*args, **kwargs) else: - return Response( - "Unauthorized", 401, {"WWW-Authenticate": "Basic"} - ) + return Response("Unauthorized", 401, {"WWW-Authenticate": "Basic"}) return cast(T, decorated) diff --git 
a/airflow/api/auth/backend/default.py b/airflow/api/auth/backend/default.py index 70ae82b2d6c60..bfa96dfb0e9f7 100644 --- a/airflow/api/auth/backend/default.py +++ b/airflow/api/auth/backend/default.py @@ -33,6 +33,7 @@ def init_app(_): def requires_authentication(function: T): """Decorator for functions that require authentication""" + @wraps(function) def decorated(*args, **kwargs): return function(*args, **kwargs) diff --git a/airflow/api/auth/backend/kerberos_auth.py b/airflow/api/auth/backend/kerberos_auth.py index 41042dec32c01..08a0fecace7cd 100644 --- a/airflow/api/auth/backend/kerberos_auth.py +++ b/airflow/api/auth/backend/kerberos_auth.py @@ -132,6 +132,7 @@ def _gssapi_authenticate(token): def requires_authentication(function: T): """Decorator for functions that require authentication with Kerberos""" + @wraps(function) def decorated(*args, **kwargs): header = request.headers.get("Authorization") @@ -144,11 +145,11 @@ def decorated(*args, **kwargs): response = function(*args, **kwargs) response = make_response(response) if ctx.kerberos_token is not None: - response.headers['WWW-Authenticate'] = ' '.join(['negotiate', - ctx.kerberos_token]) + response.headers['WWW-Authenticate'] = ' '.join(['negotiate', ctx.kerberos_token]) return response if return_code != kerberos.AUTH_GSS_CONTINUE: return _forbidden() return _unauthorized() + return cast(T, decorated) diff --git a/airflow/api/client/__init__.py b/airflow/api/client/__init__.py index 7431dfa4cfc34..53c85062e63e9 100644 --- a/airflow/api/client/__init__.py +++ b/airflow/api/client/__init__.py @@ -35,6 +35,6 @@ def get_current_api_client() -> Client: api_client = api_module.Client( api_base_url=conf.get('cli', 'endpoint_url'), auth=getattr(auth_backend, 'CLIENT_AUTH', None), - session=session + session=session, ) return api_client diff --git a/airflow/api/client/json_client.py b/airflow/api/client/json_client.py index c17307f3f1508..1ffe7fd88cd5b 100644 --- a/airflow/api/client/json_client.py +++ b/airflow/api/client/json_client.py @@ -45,12 +45,15 @@ def _request(self, url, method='GET', json=None): def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None): endpoint = f'/api/experimental/dags/{dag_id}/dag_runs' url = urljoin(self._api_base_url, endpoint) - data = self._request(url, method='POST', - json={ - "run_id": run_id, - "conf": conf, - "execution_date": execution_date, - }) + data = self._request( + url, + method='POST', + json={ + "run_id": run_id, + "conf": conf, + "execution_date": execution_date, + }, + ) return data['message'] def delete_dag(self, dag_id): @@ -74,12 +77,15 @@ def get_pools(self): def create_pool(self, name, slots, description): endpoint = '/api/experimental/pools' url = urljoin(self._api_base_url, endpoint) - pool = self._request(url, method='POST', - json={ - 'name': name, - 'slots': slots, - 'description': description, - }) + pool = self._request( + url, + method='POST', + json={ + 'name': name, + 'slots': slots, + 'description': description, + }, + ) return pool['pool'], pool['slots'], pool['description'] def delete_pool(self, name): diff --git a/airflow/api/client/local_client.py b/airflow/api/client/local_client.py index 5c08f1ab3eb13..7ce0d1655da6e 100644 --- a/airflow/api/client/local_client.py +++ b/airflow/api/client/local_client.py @@ -26,10 +26,9 @@ class Client(api_client.Client): """Local API client implementation.""" def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None): - dag_run = trigger_dag.trigger_dag(dag_id=dag_id, - run_id=run_id, - 
conf=conf, - execution_date=execution_date) + dag_run = trigger_dag.trigger_dag( + dag_id=dag_id, run_id=run_id, conf=conf, execution_date=execution_date + ) return f"Created {dag_run}" def delete_dag(self, dag_id): diff --git a/airflow/api/common/experimental/__init__.py b/airflow/api/common/experimental/__init__.py index ebaea5e658322..b161e04346358 100644 --- a/airflow/api/common/experimental/__init__.py +++ b/airflow/api/common/experimental/__init__.py @@ -29,10 +29,7 @@ def check_and_get_dag(dag_id: str, task_id: Optional[str] = None) -> DagModel: if dag_model is None: raise DagNotFound(f"Dag id {dag_id} not found in DagModel") - dagbag = DagBag( - dag_folder=dag_model.fileloc, - read_dags_from_db=True - ) + dagbag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True) dag = dagbag.get_dag(dag_id) if not dag: error_message = f"Dag id {dag_id} not found" @@ -47,7 +44,6 @@ def check_and_get_dagrun(dag: DagModel, execution_date: datetime) -> DagRun: """Get DagRun object and check that it exists""" dagrun = dag.get_dagrun(execution_date=execution_date) if not dagrun: - error_message = ('Dag Run for date {} not found in dag {}' - .format(execution_date, dag.dag_id)) + error_message = f'Dag Run for date {execution_date} not found in dag {dag.dag_id}' raise DagRunNotFound(error_message) return dagrun diff --git a/airflow/api/common/experimental/delete_dag.py b/airflow/api/common/experimental/delete_dag.py index ee50c7d60090e..d27c21f18b2fa 100644 --- a/airflow/api/common/experimental/delete_dag.py +++ b/airflow/api/common/experimental/delete_dag.py @@ -60,13 +60,14 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> i if dag.is_subdag: parent_dag_id, task_id = dag_id.rsplit(".", 1) for model in TaskFail, models.TaskInstance: - count += session.query(model).filter(model.dag_id == parent_dag_id, - model.task_id == task_id).delete() + count += ( + session.query(model).filter(model.dag_id == parent_dag_id, model.task_id == task_id).delete() + ) # Delete entries in Import Errors table for a deleted DAG # This handles the case when the dag_id is changed in the file - session.query(models.ImportError).filter( - models.ImportError.filename == dag.fileloc - ).delete(synchronize_session='fetch') + session.query(models.ImportError).filter(models.ImportError.filename == dag.fileloc).delete( + synchronize_session='fetch' + ) return count diff --git a/airflow/api/common/experimental/get_dag_runs.py b/airflow/api/common/experimental/get_dag_runs.py index 8b85c9919ac8d..cd939676bb3ac 100644 --- a/airflow/api/common/experimental/get_dag_runs.py +++ b/airflow/api/common/experimental/get_dag_runs.py @@ -38,16 +38,16 @@ def get_dag_runs(dag_id: str, state: Optional[str] = None) -> List[Dict[str, Any dag_runs = [] state = state.lower() if state else None for run in DagRun.find(dag_id=dag_id, state=state): - dag_runs.append({ - 'id': run.id, - 'run_id': run.run_id, - 'state': run.state, - 'dag_id': run.dag_id, - 'execution_date': run.execution_date.isoformat(), - 'start_date': ((run.start_date or '') and - run.start_date.isoformat()), - 'dag_run_url': url_for('Airflow.graph', dag_id=run.dag_id, - execution_date=run.execution_date) - }) + dag_runs.append( + { + 'id': run.id, + 'run_id': run.run_id, + 'state': run.state, + 'dag_id': run.dag_id, + 'execution_date': run.execution_date.isoformat(), + 'start_date': ((run.start_date or '') and run.start_date.isoformat()), + 'dag_run_url': url_for('Airflow.graph', dag_id=run.dag_id, execution_date=run.execution_date), + } + ) return 
dag_runs diff --git a/airflow/api/common/experimental/get_lineage.py b/airflow/api/common/experimental/get_lineage.py index 2d0e97dcdaf8d..c6857a65fb279 100644 --- a/airflow/api/common/experimental/get_lineage.py +++ b/airflow/api/common/experimental/get_lineage.py @@ -31,10 +31,12 @@ def get_lineage(dag_id: str, execution_date: datetime.datetime, session=None) -> dag = check_and_get_dag(dag_id) check_and_get_dagrun(dag, execution_date) - inlets: List[XCom] = XCom.get_many(dag_ids=dag_id, execution_date=execution_date, - key=PIPELINE_INLETS, session=session).all() - outlets: List[XCom] = XCom.get_many(dag_ids=dag_id, execution_date=execution_date, - key=PIPELINE_OUTLETS, session=session).all() + inlets: List[XCom] = XCom.get_many( + dag_ids=dag_id, execution_date=execution_date, key=PIPELINE_INLETS, session=session + ).all() + outlets: List[XCom] = XCom.get_many( + dag_ids=dag_id, execution_date=execution_date, key=PIPELINE_OUTLETS, session=session + ).all() lineage: Dict[str, Dict[str, Any]] = {} for meta in inlets: diff --git a/airflow/api/common/experimental/get_task_instance.py b/airflow/api/common/experimental/get_task_instance.py index 0b1dc8e18d8a0..07f41cbc400c1 100644 --- a/airflow/api/common/experimental/get_task_instance.py +++ b/airflow/api/common/experimental/get_task_instance.py @@ -31,8 +31,7 @@ def get_task_instance(dag_id: str, task_id: str, execution_date: datetime) -> Ta # Get task instance object and check that it exists task_instance = dagrun.get_task_instance(task_id) if not task_instance: - error_message = ('Task {} instance for date {} not found' - .format(task_id, execution_date)) + error_message = f'Task {task_id} instance for date {execution_date} not found' raise TaskInstanceNotFound(error_message) return task_instance diff --git a/airflow/api/common/experimental/mark_tasks.py b/airflow/api/common/experimental/mark_tasks.py index 3f1cb8363e486..852038dfe88a9 100644 --- a/airflow/api/common/experimental/mark_tasks.py +++ b/airflow/api/common/experimental/mark_tasks.py @@ -69,7 +69,7 @@ def set_state( past: bool = False, state: str = State.SUCCESS, commit: bool = False, - session=None + session=None, ): # pylint: disable=too-many-arguments,too-many-locals """ Set the state of a task instance and if needed its relatives. Can set state @@ -137,33 +137,24 @@ def set_state( # Flake and pylint disagree about correct indents here def all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates): # noqa: E123 """Get *all* tasks of the sub dags""" - qry_sub_dag = session.query(TaskInstance). \ - filter( - TaskInstance.dag_id.in_(sub_dag_run_ids), - TaskInstance.execution_date.in_(confirmed_dates) - ). \ - filter( - or_( - TaskInstance.state.is_(None), - TaskInstance.state != state - ) + qry_sub_dag = ( + session.query(TaskInstance) + .filter(TaskInstance.dag_id.in_(sub_dag_run_ids), TaskInstance.execution_date.in_(confirmed_dates)) + .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)) ) # noqa: E123 return qry_sub_dag def get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates): """Get all tasks of the main dag that will be affected by a state change""" - qry_dag = session.query(TaskInstance). \ - filter( - TaskInstance.dag_id == dag.dag_id, - TaskInstance.execution_date.in_(confirmed_dates), - TaskInstance.task_id.in_(task_ids) # noqa: E123 - ). 
\ - filter( - or_( - TaskInstance.state.is_(None), - TaskInstance.state != state + qry_dag = ( + session.query(TaskInstance) + .filter( + TaskInstance.dag_id == dag.dag_id, + TaskInstance.execution_date.in_(confirmed_dates), + TaskInstance.task_id.in_(task_ids), # noqa: E123 ) + .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)) ) return qry_dag @@ -186,10 +177,12 @@ def get_subdag_runs(dag, session, state, task_ids, commit, confirmed_dates): # this works as a kind of integrity check # it creates missing dag runs for subdag operators, # maybe this should be moved to dagrun.verify_integrity - dag_runs = _create_dagruns(current_task.subdag, - execution_dates=confirmed_dates, - state=State.RUNNING, - run_type=DagRunType.BACKFILL_JOB) + dag_runs = _create_dagruns( + current_task.subdag, + execution_dates=confirmed_dates, + state=State.RUNNING, + run_type=DagRunType.BACKFILL_JOB, + ) verify_dagruns(dag_runs, commit, state, session, current_task) @@ -279,10 +272,9 @@ def _set_dag_run_state(dag_id, execution_date, state, session=None): :param state: target state :param session: database session """ - dag_run = session.query(DagRun).filter( - DagRun.dag_id == dag_id, - DagRun.execution_date == execution_date - ).one() + dag_run = ( + session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date).one() + ) dag_run.state = state if state == State.RUNNING: dag_run.start_date = timezone.utcnow() @@ -316,8 +308,9 @@ def set_dag_run_state_to_success(dag, execution_date, commit=False, session=None # Mark all task instances of the dag run to success. for task in dag.tasks: task.dag = dag - return set_state(tasks=dag.tasks, execution_date=execution_date, - state=State.SUCCESS, commit=commit, session=session) + return set_state( + tasks=dag.tasks, execution_date=execution_date, state=State.SUCCESS, commit=commit, session=session + ) @provide_session @@ -343,10 +336,15 @@ def set_dag_run_state_to_failed(dag, execution_date, commit=False, session=None) # Mark only RUNNING task instances. 
task_ids = [task.task_id for task in dag.tasks] - tis = session.query(TaskInstance).filter( - TaskInstance.dag_id == dag.dag_id, - TaskInstance.execution_date == execution_date, - TaskInstance.task_id.in_(task_ids)).filter(TaskInstance.state == State.RUNNING) + tis = ( + session.query(TaskInstance) + .filter( + TaskInstance.dag_id == dag.dag_id, + TaskInstance.execution_date == execution_date, + TaskInstance.task_id.in_(task_ids), + ) + .filter(TaskInstance.state == State.RUNNING) + ) task_ids_of_running_tis = [task_instance.task_id for task_instance in tis] tasks = [] @@ -356,8 +354,9 @@ def set_dag_run_state_to_failed(dag, execution_date, commit=False, session=None) task.dag = dag tasks.append(task) - return set_state(tasks=tasks, execution_date=execution_date, - state=State.FAILED, commit=commit, session=session) + return set_state( + tasks=tasks, execution_date=execution_date, state=State.FAILED, commit=commit, session=session + ) @provide_session diff --git a/airflow/api/common/experimental/trigger_dag.py b/airflow/api/common/experimental/trigger_dag.py index e1d3ceebff242..519079e0d30b4 100644 --- a/airflow/api/common/experimental/trigger_dag.py +++ b/airflow/api/common/experimental/trigger_dag.py @@ -63,16 +63,15 @@ def _trigger_dag( if min_dag_start_date and execution_date < min_dag_start_date: raise ValueError( "The execution_date [{}] should be >= start_date [{}] from DAG's default_args".format( - execution_date.isoformat(), - min_dag_start_date.isoformat())) + execution_date.isoformat(), min_dag_start_date.isoformat() + ) + ) run_id = run_id or DagRun.generate_run_id(DagRunType.MANUAL, execution_date) dag_run = DagRun.find(dag_id=dag_id, run_id=run_id) if dag_run: - raise DagRunAlreadyExists( - f"Run id {run_id} already exists for dag id {dag_id}" - ) + raise DagRunAlreadyExists(f"Run id {run_id} already exists for dag id {dag_id}") run_conf = None if conf: @@ -95,11 +94,11 @@ def _trigger_dag( def trigger_dag( - dag_id: str, - run_id: Optional[str] = None, - conf: Optional[Union[dict, str]] = None, - execution_date: Optional[datetime] = None, - replace_microseconds: bool = True, + dag_id: str, + run_id: Optional[str] = None, + conf: Optional[Union[dict, str]] = None, + execution_date: Optional[datetime] = None, + replace_microseconds: bool = True, ) -> Optional[DagRun]: """Triggers execution of DAG specified by dag_id diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index bac2ad215bfe4..d6a155546efd4 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
from connexion import NoContent -from flask import g, request, current_app +from flask import current_app, g, request from marshmallow import ValidationError from airflow.api_connexion import security diff --git a/airflow/api_connexion/endpoints/dag_source_endpoint.py b/airflow/api_connexion/endpoints/dag_source_endpoint.py index b023ae60a32a6..0281b2250fcae 100644 --- a/airflow/api_connexion/endpoints/dag_source_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_source_endpoint.py @@ -25,7 +25,6 @@ from airflow.models.dagcode import DagCode from airflow.security import permissions - log = logging.getLogger(__name__) diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index 37e9b02cdc4f6..b84c59ad9f896 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -16,31 +16,31 @@ # under the License. from typing import Any, List, Optional, Tuple -from flask import request, current_app +from flask import current_app, request from marshmallow import ValidationError from sqlalchemy import and_, func from airflow.api.common.experimental.mark_tasks import set_state from airflow.api_connexion import security -from airflow.api_connexion.exceptions import NotFound, BadRequest +from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.parameters import format_datetime, format_parameters from airflow.api_connexion.schemas.task_instance_schema import ( - clear_task_instance_form, TaskInstanceCollection, - task_instance_collection_schema, - task_instance_schema, - task_instance_batch_form, - task_instance_reference_collection_schema, TaskInstanceReferenceCollection, + clear_task_instance_form, set_task_instance_state_form, + task_instance_batch_form, + task_instance_collection_schema, + task_instance_reference_collection_schema, + task_instance_schema, ) from airflow.exceptions import SerializedDagNotFound -from airflow.models.dagrun import DagRun as DR -from airflow.models.taskinstance import clear_task_instances, TaskInstance as TI from airflow.models import SlaMiss +from airflow.models.dagrun import DagRun as DR +from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances from airflow.security import permissions -from airflow.utils.state import State from airflow.utils.session import provide_session +from airflow.utils.state import State @security.requires_access( diff --git a/airflow/api_connexion/schemas/sla_miss_schema.py b/airflow/api_connexion/schemas/sla_miss_schema.py index 341b8aa2a7c88..9413e37cbde21 100644 --- a/airflow/api_connexion/schemas/sla_miss_schema.py +++ b/airflow/api_connexion/schemas/sla_miss_schema.py @@ -16,6 +16,7 @@ # under the License. from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field + from airflow.models import SlaMiss diff --git a/airflow/api_connexion/schemas/task_instance_schema.py b/airflow/api_connexion/schemas/task_instance_schema.py index e825665f9528a..27bf4a6cd891e 100644 --- a/airflow/api_connexion/schemas/task_instance_schema.py +++ b/airflow/api_connexion/schemas/task_instance_schema.py @@ -15,14 +15,14 @@ # specific language governing permissions and limitations # under the License. 
-from typing import List, NamedTuple, Tuple, Optional +from typing import List, NamedTuple, Optional, Tuple -from marshmallow import Schema, fields, ValidationError, validates_schema, validate +from marshmallow import Schema, ValidationError, fields, validate, validates_schema from marshmallow.utils import get_value from airflow.api_connexion.schemas.enum_schemas import TaskInstanceStateField from airflow.api_connexion.schemas.sla_miss_schema import SlaMissSchema -from airflow.models import TaskInstance, SlaMiss +from airflow.models import SlaMiss, TaskInstance from airflow.utils.state import State diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py index 84f1349f1121d..c9dbfec837ca8 100644 --- a/airflow/cli/cli_parser.py +++ b/airflow/cli/cli_parser.py @@ -92,8 +92,18 @@ class Arg: """Class to keep information about command line argument""" # pylint: disable=redefined-builtin,unused-argument - def __init__(self, flags=_UNSET, help=_UNSET, action=_UNSET, default=_UNSET, nargs=_UNSET, type=_UNSET, - choices=_UNSET, required=_UNSET, metavar=_UNSET): + def __init__( + self, + flags=_UNSET, + help=_UNSET, + action=_UNSET, + default=_UNSET, + nargs=_UNSET, + type=_UNSET, + choices=_UNSET, + required=_UNSET, + metavar=_UNSET, + ): self.flags = flags self.kwargs = {} for k, v in locals().items(): @@ -103,6 +113,7 @@ def __init__(self, flags=_UNSET, help=_UNSET, action=_UNSET, default=_UNSET, nar continue self.kwargs[k] = v + # pylint: enable=redefined-builtin,unused-argument def add_to_parser(self, parser: argparse.ArgumentParser): @@ -122,65 +133,47 @@ def positive_int(value): # Shared -ARG_DAG_ID = Arg( - ("dag_id",), - help="The id of the dag") -ARG_TASK_ID = Arg( - ("task_id",), - help="The id of the task") -ARG_EXECUTION_DATE = Arg( - ("execution_date",), - help="The execution date of the DAG", - type=parsedate) +ARG_DAG_ID = Arg(("dag_id",), help="The id of the dag") +ARG_TASK_ID = Arg(("task_id",), help="The id of the task") +ARG_EXECUTION_DATE = Arg(("execution_date",), help="The execution date of the DAG", type=parsedate) ARG_TASK_REGEX = Arg( - ("-t", "--task-regex"), - help="The regex to filter specific task_ids to backfill (optional)") + ("-t", "--task-regex"), help="The regex to filter specific task_ids to backfill (optional)" +) ARG_SUBDIR = Arg( ("-S", "--subdir"), help=( "File location or directory from which to look for the dag. " "Defaults to '[AIRFLOW_HOME]/dags' where [AIRFLOW_HOME] is the " - "value you set for 'AIRFLOW_HOME' config you set in 'airflow.cfg' "), - default='[AIRFLOW_HOME]/dags' if BUILD_DOCS else settings.DAGS_FOLDER) -ARG_START_DATE = Arg( - ("-s", "--start-date"), - help="Override start_date YYYY-MM-DD", - type=parsedate) -ARG_END_DATE = Arg( - ("-e", "--end-date"), - help="Override end_date YYYY-MM-DD", - type=parsedate) + "value you set for 'AIRFLOW_HOME' config you set in 'airflow.cfg' " + ), + default='[AIRFLOW_HOME]/dags' if BUILD_DOCS else settings.DAGS_FOLDER, +) +ARG_START_DATE = Arg(("-s", "--start-date"), help="Override start_date YYYY-MM-DD", type=parsedate) +ARG_END_DATE = Arg(("-e", "--end-date"), help="Override end_date YYYY-MM-DD", type=parsedate) ARG_OUTPUT_PATH = Arg( - ("-o", "--output-path",), + ( + "-o", + "--output-path", + ), help="The output for generated yaml files", type=str, - default="[CWD]" if BUILD_DOCS else os.getcwd()) + default="[CWD]" if BUILD_DOCS else os.getcwd(), +) ARG_DRY_RUN = Arg( ("-n", "--dry-run"), help="Perform a dry run for each task. 
Only renders Template Fields for each task, nothing else", - action="store_true") -ARG_PID = Arg( - ("--pid",), - help="PID file location", - nargs='?') + action="store_true", +) +ARG_PID = Arg(("--pid",), help="PID file location", nargs='?') ARG_DAEMON = Arg( - ("-D", "--daemon"), - help="Daemonize instead of running in the foreground", - action="store_true") -ARG_STDERR = Arg( - ("--stderr",), - help="Redirect stderr to this file") -ARG_STDOUT = Arg( - ("--stdout",), - help="Redirect stdout to this file") -ARG_LOG_FILE = Arg( - ("-l", "--log-file"), - help="Location of the log file") + ("-D", "--daemon"), help="Daemonize instead of running in the foreground", action="store_true" +) +ARG_STDERR = Arg(("--stderr",), help="Redirect stderr to this file") +ARG_STDOUT = Arg(("--stdout",), help="Redirect stdout to this file") +ARG_LOG_FILE = Arg(("-l", "--log-file"), help="Location of the log file") ARG_YES = Arg( - ("-y", "--yes"), - help="Do not prompt to confirm reset. Use with care!", - action="store_true", - default=False) + ("-y", "--yes"), help="Do not prompt to confirm reset. Use with care!", action="store_true", default=False +) ARG_OUTPUT = Arg( ("--output",), help=( @@ -189,238 +182,181 @@ def positive_int(value): ), metavar="FORMAT", choices=tabulate_formats, - default="plain") + default="plain", +) ARG_COLOR = Arg( ('--color',), help="Do emit colored output (default: auto)", choices={ColorMode.ON, ColorMode.OFF, ColorMode.AUTO}, - default=ColorMode.AUTO) + default=ColorMode.AUTO, +) # list_dag_runs -ARG_DAG_ID_OPT = Arg( - ("-d", "--dag-id"), - help="The id of the dag" -) +ARG_DAG_ID_OPT = Arg(("-d", "--dag-id"), help="The id of the dag") ARG_NO_BACKFILL = Arg( - ("--no-backfill",), - help="filter all the backfill dagruns given the dag id", - action="store_true") -ARG_STATE = Arg( - ("--state",), - help="Only list the dag runs corresponding to the state") + ("--no-backfill",), help="filter all the backfill dagruns given the dag id", action="store_true" +) +ARG_STATE = Arg(("--state",), help="Only list the dag runs corresponding to the state") # list_jobs -ARG_LIMIT = Arg( - ("--limit",), - help="Return a limited number of records") +ARG_LIMIT = Arg(("--limit",), help="Return a limited number of records") # next_execution ARG_NUM_EXECUTIONS = Arg( ("-n", "--num-executions"), default=1, type=positive_int, - help="The number of next execution datetimes to show") + help="The number of next execution datetimes to show", +) # backfill ARG_MARK_SUCCESS = Arg( - ("-m", "--mark-success"), - help="Mark jobs as succeeded without running them", - action="store_true") -ARG_VERBOSE = Arg( - ("-v", "--verbose"), - help="Make logging output more verbose", - action="store_true") -ARG_LOCAL = Arg( - ("-l", "--local"), - help="Run the task using the LocalExecutor", - action="store_true") + ("-m", "--mark-success"), help="Mark jobs as succeeded without running them", action="store_true" +) +ARG_VERBOSE = Arg(("-v", "--verbose"), help="Make logging output more verbose", action="store_true") +ARG_LOCAL = Arg(("-l", "--local"), help="Run the task using the LocalExecutor", action="store_true") ARG_DONOT_PICKLE = Arg( ("-x", "--donot-pickle"), help=( "Do not attempt to pickle the DAG object to send over " "to the workers, just tell the workers to run their version " - "of the code"), - action="store_true") + "of the code" + ), + action="store_true", +) ARG_BF_IGNORE_DEPENDENCIES = Arg( ("-i", "--ignore-dependencies"), help=( "Skip upstream tasks, run only the tasks " "matching the regexp. 
Only works in conjunction " - "with task_regex"), - action="store_true") + "with task_regex" + ), + action="store_true", +) ARG_BF_IGNORE_FIRST_DEPENDS_ON_PAST = Arg( ("-I", "--ignore-first-depends-on-past"), help=( "Ignores depends_on_past dependencies for the first " "set of tasks only (subsequent executions in the backfill " - "DO respect depends_on_past)"), - action="store_true") + "DO respect depends_on_past)" + ), + action="store_true", +) ARG_POOL = Arg(("--pool",), "Resource pool to use") ARG_DELAY_ON_LIMIT = Arg( ("--delay-on-limit",), - help=("Amount of time in seconds to wait when the limit " - "on maximum active dag runs (max_active_runs) has " - "been reached before trying to execute a dag run " - "again"), + help=( + "Amount of time in seconds to wait when the limit " + "on maximum active dag runs (max_active_runs) has " + "been reached before trying to execute a dag run " + "again" + ), type=float, - default=1.0) + default=1.0, +) ARG_RESET_DAG_RUN = Arg( ("--reset-dagruns",), help=( "if set, the backfill will delete existing " "backfill-related DAG runs and start " - "anew with fresh, running DAG runs"), - action="store_true") + "anew with fresh, running DAG runs" + ), + action="store_true", +) ARG_RERUN_FAILED_TASKS = Arg( ("--rerun-failed-tasks",), help=( "if set, the backfill will auto-rerun " "all the failed tasks for the backfill date range " - "instead of throwing exceptions"), - action="store_true") + "instead of throwing exceptions" + ), + action="store_true", +) ARG_RUN_BACKWARDS = Arg( - ("-B", "--run-backwards",), + ( + "-B", + "--run-backwards", + ), help=( "if set, the backfill will run tasks from the most " "recent day first. if there are tasks that depend_on_past " - "this option will throw an exception"), - action="store_true") + "this option will throw an exception" + ), + action="store_true", +) # test_dag ARG_SHOW_DAGRUN = Arg( - ("--show-dagrun", ), + ("--show-dagrun",), help=( "After completing the backfill, shows the diagram for current DAG Run.\n" "\n" - "The diagram is in DOT language\n"), - action='store_true') + "The diagram is in DOT language\n" + ), + action='store_true', +) ARG_IMGCAT_DAGRUN = Arg( - ("--imgcat-dagrun", ), + ("--imgcat-dagrun",), help=( "After completing the dag run, prints a diagram on the screen for the " "current DAG Run using the imgcat tool.\n" ), - action='store_true') + action='store_true', +) ARG_SAVE_DAGRUN = Arg( - ("--save-dagrun", ), + ("--save-dagrun",), help=( - "After completing the backfill, saves the diagram for current DAG Run to the indicated file.\n" - "\n" - )) + "After completing the backfill, saves the diagram for current DAG Run to the indicated file.\n" "\n" + ), +) # list_tasks -ARG_TREE = Arg( - ("-t", "--tree"), - help="Tree view", - action="store_true") +ARG_TREE = Arg(("-t", "--tree"), help="Tree view", action="store_true") # clear -ARG_UPSTREAM = Arg( - ("-u", "--upstream"), - help="Include upstream tasks", - action="store_true") -ARG_ONLY_FAILED = Arg( - ("-f", "--only-failed"), - help="Only failed jobs", - action="store_true") -ARG_ONLY_RUNNING = Arg( - ("-r", "--only-running"), - help="Only running jobs", - action="store_true") -ARG_DOWNSTREAM = Arg( - ("-d", "--downstream"), - help="Include downstream tasks", - action="store_true") -ARG_EXCLUDE_SUBDAGS = Arg( - ("-x", "--exclude-subdags"), - help="Exclude subdags", - action="store_true") +ARG_UPSTREAM = Arg(("-u", "--upstream"), help="Include upstream tasks", action="store_true") +ARG_ONLY_FAILED = Arg(("-f", "--only-failed"), help="Only failed 
jobs", action="store_true") +ARG_ONLY_RUNNING = Arg(("-r", "--only-running"), help="Only running jobs", action="store_true") +ARG_DOWNSTREAM = Arg(("-d", "--downstream"), help="Include downstream tasks", action="store_true") +ARG_EXCLUDE_SUBDAGS = Arg(("-x", "--exclude-subdags"), help="Exclude subdags", action="store_true") ARG_EXCLUDE_PARENTDAG = Arg( ("-X", "--exclude-parentdag"), help="Exclude ParentDAGS if the task cleared is a part of a SubDAG", - action="store_true") + action="store_true", +) ARG_DAG_REGEX = Arg( - ("-R", "--dag-regex"), - help="Search dag_id as regex instead of exact string", - action="store_true") + ("-R", "--dag-regex"), help="Search dag_id as regex instead of exact string", action="store_true" +) # show_dag -ARG_SAVE = Arg( - ("-s", "--save"), - help="Saves the result to the indicated file.") +ARG_SAVE = Arg(("-s", "--save"), help="Saves the result to the indicated file.") -ARG_IMGCAT = Arg( - ("--imgcat", ), - help=( - "Displays graph using the imgcat tool."), - action='store_true') +ARG_IMGCAT = Arg(("--imgcat",), help=("Displays graph using the imgcat tool."), action='store_true') # trigger_dag -ARG_RUN_ID = Arg( - ("-r", "--run-id"), - help="Helps to identify this run") -ARG_CONF = Arg( - ('-c', '--conf'), - help="JSON string that gets pickled into the DagRun's conf attribute") -ARG_EXEC_DATE = Arg( - ("-e", "--exec-date"), - help="The execution date of the DAG", - type=parsedate) +ARG_RUN_ID = Arg(("-r", "--run-id"), help="Helps to identify this run") +ARG_CONF = Arg(('-c', '--conf'), help="JSON string that gets pickled into the DagRun's conf attribute") +ARG_EXEC_DATE = Arg(("-e", "--exec-date"), help="The execution date of the DAG", type=parsedate) # pool -ARG_POOL_NAME = Arg( - ("pool",), - metavar='NAME', - help="Pool name") -ARG_POOL_SLOTS = Arg( - ("slots",), - type=int, - help="Pool slots") -ARG_POOL_DESCRIPTION = Arg( - ("description",), - help="Pool description") -ARG_POOL_IMPORT = Arg( - ("file",), - metavar="FILEPATH", - help="Import pools from JSON file") -ARG_POOL_EXPORT = Arg( - ("file",), - metavar="FILEPATH", - help="Export all pools to JSON file") +ARG_POOL_NAME = Arg(("pool",), metavar='NAME', help="Pool name") +ARG_POOL_SLOTS = Arg(("slots",), type=int, help="Pool slots") +ARG_POOL_DESCRIPTION = Arg(("description",), help="Pool description") +ARG_POOL_IMPORT = Arg(("file",), metavar="FILEPATH", help="Import pools from JSON file") +ARG_POOL_EXPORT = Arg(("file",), metavar="FILEPATH", help="Export all pools to JSON file") # variables -ARG_VAR = Arg( - ("key",), - help="Variable key") -ARG_VAR_VALUE = Arg( - ("value",), - metavar='VALUE', - help="Variable value") +ARG_VAR = Arg(("key",), help="Variable key") +ARG_VAR_VALUE = Arg(("value",), metavar='VALUE', help="Variable value") ARG_DEFAULT = Arg( - ("-d", "--default"), - metavar="VAL", - default=None, - help="Default value returned if variable does not exist") -ARG_JSON = Arg( - ("-j", "--json"), - help="Deserialize JSON variable", - action="store_true") -ARG_VAR_IMPORT = Arg( - ("file",), - help="Import variables from JSON file") -ARG_VAR_EXPORT = Arg( - ("file",), - help="Export all variables to JSON file") + ("-d", "--default"), metavar="VAL", default=None, help="Default value returned if variable does not exist" +) +ARG_JSON = Arg(("-j", "--json"), help="Deserialize JSON variable", action="store_true") +ARG_VAR_IMPORT = Arg(("file",), help="Import variables from JSON file") +ARG_VAR_EXPORT = Arg(("file",), help="Export all variables to JSON file") # kerberos -ARG_PRINCIPAL = Arg( - 
("principal",), - help="kerberos principal", - nargs='?') -ARG_KEYTAB = Arg( - ("-k", "--keytab"), - help="keytab", - nargs='?', - default=conf.get('kerberos', 'keytab')) +ARG_PRINCIPAL = Arg(("principal",), help="kerberos principal", nargs='?') +ARG_KEYTAB = Arg(("-k", "--keytab"), help="keytab", nargs='?', default=conf.get('kerberos', 'keytab')) # run # TODO(aoen): "force" is a poor choice of name here since it implies it overrides # all dependencies (not just past success), e.g. the ignore_depends_on_past @@ -429,18 +365,20 @@ def positive_int(value): # instead. ARG_INTERACTIVE = Arg( ('-N', '--interactive'), - help='Do not capture standard output and error streams ' - '(useful for interactive debugging)', - action='store_true') + help='Do not capture standard output and error streams ' '(useful for interactive debugging)', + action='store_true', +) ARG_FORCE = Arg( ("-f", "--force"), help="Ignore previous task instance state, rerun regardless if task already succeeded/failed", - action="store_true") + action="store_true", +) ARG_RAW = Arg(("-r", "--raw"), argparse.SUPPRESS, "store_true") ARG_IGNORE_ALL_DEPENDENCIES = Arg( ("-A", "--ignore-all-dependencies"), help="Ignores all non-critical dependencies, including ignore_ti_state and ignore_task_deps", - action="store_true") + action="store_true", +) # TODO(aoen): ignore_dependencies is a poor choice of name here because it is too # vague (e.g. a task being in the appropriate state to be run is also a dependency # but is not ignored by this flag), the name 'ignore_task_dependencies' is @@ -449,24 +387,19 @@ def positive_int(value): ARG_IGNORE_DEPENDENCIES = Arg( ("-i", "--ignore-dependencies"), help="Ignore task-specific dependencies, e.g. upstream, depends_on_past, and retry delay dependencies", - action="store_true") + action="store_true", +) ARG_IGNORE_DEPENDS_ON_PAST = Arg( ("-I", "--ignore-depends-on-past"), help="Ignore depends_on_past dependencies (but respect upstream dependencies)", - action="store_true") + action="store_true", +) ARG_SHIP_DAG = Arg( - ("--ship-dag",), - help="Pickles (serializes) the DAG and ships it to the worker", - action="store_true") -ARG_PICKLE = Arg( - ("-p", "--pickle"), - help="Serialized pickle object of the entire dag (used internally)") -ARG_JOB_ID = Arg( - ("-j", "--job-id"), - help=argparse.SUPPRESS) -ARG_CFG_PATH = Arg( - ("--cfg-path",), - help="Path to config file to use instead of airflow.cfg") + ("--ship-dag",), help="Pickles (serializes) the DAG and ships it to the worker", action="store_true" +) +ARG_PICKLE = Arg(("-p", "--pickle"), help="Serialized pickle object of the entire dag (used internally)") +ARG_JOB_ID = Arg(("-j", "--job-id"), help=argparse.SUPPRESS) +ARG_CFG_PATH = Arg(("--cfg-path",), help="Path to config file to use instead of airflow.cfg") ARG_MIGRATION_TIMEOUT = Arg( ("-t", "--migration-wait-timeout"), help="timeout to wait for db to migrate ", @@ -479,220 +412,194 @@ def positive_int(value): ("-p", "--port"), default=conf.get('webserver', 'WEB_SERVER_PORT'), type=int, - help="The port on which to run the server") + help="The port on which to run the server", +) ARG_SSL_CERT = Arg( ("--ssl-cert",), default=conf.get('webserver', 'WEB_SERVER_SSL_CERT'), - help="Path to the SSL certificate for the webserver") + help="Path to the SSL certificate for the webserver", +) ARG_SSL_KEY = Arg( ("--ssl-key",), default=conf.get('webserver', 'WEB_SERVER_SSL_KEY'), - help="Path to the key to use with the SSL certificate") + help="Path to the key to use with the SSL certificate", +) 
ARG_WORKERS = Arg( ("-w", "--workers"), default=conf.get('webserver', 'WORKERS'), type=int, - help="Number of workers to run the webserver on") + help="Number of workers to run the webserver on", +) ARG_WORKERCLASS = Arg( ("-k", "--workerclass"), default=conf.get('webserver', 'WORKER_CLASS'), choices=['sync', 'eventlet', 'gevent', 'tornado'], - help="The worker class to use for Gunicorn") + help="The worker class to use for Gunicorn", +) ARG_WORKER_TIMEOUT = Arg( ("-t", "--worker-timeout"), default=conf.get('webserver', 'WEB_SERVER_WORKER_TIMEOUT'), type=int, - help="The timeout for waiting on webserver workers") + help="The timeout for waiting on webserver workers", +) ARG_HOSTNAME = Arg( ("-H", "--hostname"), default=conf.get('webserver', 'WEB_SERVER_HOST'), - help="Set the hostname on which to run the web server") + help="Set the hostname on which to run the web server", +) ARG_DEBUG = Arg( - ("-d", "--debug"), - help="Use the server that ships with Flask in debug mode", - action="store_true") + ("-d", "--debug"), help="Use the server that ships with Flask in debug mode", action="store_true" +) ARG_ACCESS_LOGFILE = Arg( ("-A", "--access-logfile"), default=conf.get('webserver', 'ACCESS_LOGFILE'), - help="The logfile to store the webserver access log. Use '-' to print to " - "stderr") + help="The logfile to store the webserver access log. Use '-' to print to " "stderr", +) ARG_ERROR_LOGFILE = Arg( ("-E", "--error-logfile"), default=conf.get('webserver', 'ERROR_LOGFILE'), - help="The logfile to store the webserver error log. Use '-' to print to " - "stderr") + help="The logfile to store the webserver error log. Use '-' to print to " "stderr", +) # scheduler ARG_NUM_RUNS = Arg( ("-n", "--num-runs"), default=conf.getint('scheduler', 'num_runs'), type=int, - help="Set the number of runs to execute before exiting") + help="Set the number of runs to execute before exiting", +) ARG_DO_PICKLE = Arg( ("-p", "--do-pickle"), default=False, help=( "Attempt to pickle the DAG object to send over " "to the workers, instead of letting workers run their version " - "of the code"), - action="store_true") + "of the code" + ), + action="store_true", +) # worker ARG_QUEUES = Arg( ("-q", "--queues"), help="Comma delimited list of queues to serve", - default=conf.get('celery', 'DEFAULT_QUEUE')) + default=conf.get('celery', 'DEFAULT_QUEUE'), +) ARG_CONCURRENCY = Arg( ("-c", "--concurrency"), type=int, help="The number of worker processes", - default=conf.get('celery', 'worker_concurrency')) + default=conf.get('celery', 'worker_concurrency'), +) ARG_CELERY_HOSTNAME = Arg( ("-H", "--celery-hostname"), - help=("Set the hostname of celery worker " - "if you have multiple workers on a single machine")) + help=("Set the hostname of celery worker " "if you have multiple workers on a single machine"), +) ARG_UMASK = Arg( ("-u", "--umask"), help="Set the umask of celery worker in daemon mode", - default=conf.get('celery', 'worker_umask')) + default=conf.get('celery', 'worker_umask'), +) # flower ARG_BROKER_API = Arg(("-a", "--broker-api"), help="Broker API") ARG_FLOWER_HOSTNAME = Arg( ("-H", "--hostname"), default=conf.get('celery', 'FLOWER_HOST'), - help="Set the hostname on which to run the server") + help="Set the hostname on which to run the server", +) ARG_FLOWER_PORT = Arg( ("-p", "--port"), default=conf.get('celery', 'FLOWER_PORT'), type=int, - help="The port on which to run the server") -ARG_FLOWER_CONF = Arg( - ("-c", "--flower-conf"), - help="Configuration file for flower") + help="The port on which to run the 
server", +) +ARG_FLOWER_CONF = Arg(("-c", "--flower-conf"), help="Configuration file for flower") ARG_FLOWER_URL_PREFIX = Arg( - ("-u", "--url-prefix"), - default=conf.get('celery', 'FLOWER_URL_PREFIX'), - help="URL prefix for Flower") + ("-u", "--url-prefix"), default=conf.get('celery', 'FLOWER_URL_PREFIX'), help="URL prefix for Flower" +) ARG_FLOWER_BASIC_AUTH = Arg( ("-A", "--basic-auth"), default=conf.get('celery', 'FLOWER_BASIC_AUTH'), - help=("Securing Flower with Basic Authentication. " - "Accepts user:password pairs separated by a comma. " - "Example: flower_basic_auth = user1:password1,user2:password2")) -ARG_TASK_PARAMS = Arg( - ("-t", "--task-params"), - help="Sends a JSON params dict to the task") + help=( + "Securing Flower with Basic Authentication. " + "Accepts user:password pairs separated by a comma. " + "Example: flower_basic_auth = user1:password1,user2:password2" + ), +) +ARG_TASK_PARAMS = Arg(("-t", "--task-params"), help="Sends a JSON params dict to the task") ARG_POST_MORTEM = Arg( - ("-m", "--post-mortem"), - action="store_true", - help="Open debugger on uncaught exception") + ("-m", "--post-mortem"), action="store_true", help="Open debugger on uncaught exception" +) ARG_ENV_VARS = Arg( - ("--env-vars", ), + ("--env-vars",), help="Set env var in both parsing time and runtime for each of entry supplied in a JSON dict", - type=json.loads) + type=json.loads, +) # connections -ARG_CONN_ID = Arg( - ('conn_id',), - help='Connection id, required to get/add/delete a connection', - type=str) +ARG_CONN_ID = Arg(('conn_id',), help='Connection id, required to get/add/delete a connection', type=str) ARG_CONN_ID_FILTER = Arg( - ('--conn-id',), - help='If passed, only items with the specified connection ID will be displayed', - type=str) + ('--conn-id',), help='If passed, only items with the specified connection ID will be displayed', type=str +) ARG_CONN_URI = Arg( - ('--conn-uri',), - help='Connection URI, required to add a connection without conn_type', - type=str) + ('--conn-uri',), help='Connection URI, required to add a connection without conn_type', type=str +) ARG_CONN_TYPE = Arg( - ('--conn-type',), - help='Connection type, required to add a connection without conn_uri', - type=str) -ARG_CONN_HOST = Arg( - ('--conn-host',), - help='Connection host, optional when adding a connection', - type=str) -ARG_CONN_LOGIN = Arg( - ('--conn-login',), - help='Connection login, optional when adding a connection', - type=str) + ('--conn-type',), help='Connection type, required to add a connection without conn_uri', type=str +) +ARG_CONN_HOST = Arg(('--conn-host',), help='Connection host, optional when adding a connection', type=str) +ARG_CONN_LOGIN = Arg(('--conn-login',), help='Connection login, optional when adding a connection', type=str) ARG_CONN_PASSWORD = Arg( - ('--conn-password',), - help='Connection password, optional when adding a connection', - type=str) + ('--conn-password',), help='Connection password, optional when adding a connection', type=str +) ARG_CONN_SCHEMA = Arg( - ('--conn-schema',), - help='Connection schema, optional when adding a connection', - type=str) -ARG_CONN_PORT = Arg( - ('--conn-port',), - help='Connection port, optional when adding a connection', - type=str) + ('--conn-schema',), help='Connection schema, optional when adding a connection', type=str +) +ARG_CONN_PORT = Arg(('--conn-port',), help='Connection port, optional when adding a connection', type=str) ARG_CONN_EXTRA = Arg( - ('--conn-extra',), - help='Connection `Extra` field, optional when 
adding a connection', - type=str) + ('--conn-extra',), help='Connection `Extra` field, optional when adding a connection', type=str +) ARG_CONN_EXPORT = Arg( ('file',), help='Output file path for exporting the connections', - type=argparse.FileType('w', encoding='UTF-8')) + type=argparse.FileType('w', encoding='UTF-8'), +) ARG_CONN_EXPORT_FORMAT = Arg( - ('--format',), - help='Format of the connections data in file', - type=str, - choices=['json', 'yaml', 'env']) + ('--format',), help='Format of the connections data in file', type=str, choices=['json', 'yaml', 'env'] +) # users -ARG_USERNAME = Arg( - ('-u', '--username'), - help='Username of the user', - required=True, - type=str) -ARG_USERNAME_OPTIONAL = Arg( - ('-u', '--username'), - help='Username of the user', - type=str) -ARG_FIRSTNAME = Arg( - ('-f', '--firstname'), - help='First name of the user', - required=True, - type=str) -ARG_LASTNAME = Arg( - ('-l', '--lastname'), - help='Last name of the user', - required=True, - type=str) +ARG_USERNAME = Arg(('-u', '--username'), help='Username of the user', required=True, type=str) +ARG_USERNAME_OPTIONAL = Arg(('-u', '--username'), help='Username of the user', type=str) +ARG_FIRSTNAME = Arg(('-f', '--firstname'), help='First name of the user', required=True, type=str) +ARG_LASTNAME = Arg(('-l', '--lastname'), help='Last name of the user', required=True, type=str) ARG_ROLE = Arg( ('-r', '--role'), - help='Role of the user. Existing roles include Admin, ' - 'User, Op, Viewer, and Public', - required=True, - type=str,) -ARG_EMAIL = Arg( - ('-e', '--email'), - help='Email of the user', + help='Role of the user. Existing roles include Admin, ' 'User, Op, Viewer, and Public', required=True, - type=str) -ARG_EMAIL_OPTIONAL = Arg( - ('-e', '--email'), - help='Email of the user', - type=str) + type=str, +) +ARG_EMAIL = Arg(('-e', '--email'), help='Email of the user', required=True, type=str) +ARG_EMAIL_OPTIONAL = Arg(('-e', '--email'), help='Email of the user', type=str) ARG_PASSWORD = Arg( ('-p', '--password'), - help='Password of the user, required to create a user ' - 'without --use-random-password', - type=str) + help='Password of the user, required to create a user ' 'without --use-random-password', + type=str, +) ARG_USE_RANDOM_PASSWORD = Arg( ('--use-random-password',), help='Do not prompt for password. Use random string instead.' - ' Required to create a user without --password ', + ' Required to create a user without --password ', default=False, - action='store_true') + action='store_true', +) ARG_USER_IMPORT = Arg( ("import",), metavar="FILEPATH", - help="Import users from JSON file. Example format::\n" + - textwrap.indent(textwrap.dedent(''' + help="Import users from JSON file. 
Example format::\n" + + textwrap.indent( + textwrap.dedent( + ''' [ { "email": "foo@bar.org", @@ -701,49 +608,33 @@ def positive_int(value): "roles": ["Public"], "username": "jondoe" } - ]'''), " " * 4)) -ARG_USER_EXPORT = Arg( - ("export",), - metavar="FILEPATH", - help="Export all users to JSON file") + ]''' + ), + " " * 4, + ), +) +ARG_USER_EXPORT = Arg(("export",), metavar="FILEPATH", help="Export all users to JSON file") # roles -ARG_CREATE_ROLE = Arg( - ('-c', '--create'), - help='Create a new role', - action='store_true') -ARG_LIST_ROLES = Arg( - ('-l', '--list'), - help='List roles', - action='store_true') -ARG_ROLES = Arg( - ('role',), - help='The name of a role', - nargs='*') -ARG_AUTOSCALE = Arg( - ('-a', '--autoscale'), - help="Minimum and Maximum number of worker to autoscale") +ARG_CREATE_ROLE = Arg(('-c', '--create'), help='Create a new role', action='store_true') +ARG_LIST_ROLES = Arg(('-l', '--list'), help='List roles', action='store_true') +ARG_ROLES = Arg(('role',), help='The name of a role', nargs='*') +ARG_AUTOSCALE = Arg(('-a', '--autoscale'), help="Minimum and Maximum number of worker to autoscale") ARG_SKIP_SERVE_LOGS = Arg( ("-s", "--skip-serve-logs"), default=False, help="Don't start the serve logs process along with the workers", - action="store_true") + action="store_true", +) # info ARG_ANONYMIZE = Arg( ('--anonymize',), - help=( - 'Minimize any personal identifiable information. ' - 'Use it when sharing output with others.' - ), - action='store_true' + help=('Minimize any personal identifiable information. ' 'Use it when sharing output with others.'), + action='store_true', ) ARG_FILE_IO = Arg( - ('--file-io',), - help=( - 'Send output to file.io service and returns link.' - ), - action='store_true' + ('--file-io',), help=('Send output to file.io service and returns link.'), action='store_true' ) # config @@ -764,7 +655,12 @@ def positive_int(value): ) ALTERNATIVE_CONN_SPECS_ARGS = [ - ARG_CONN_TYPE, ARG_CONN_HOST, ARG_CONN_LOGIN, ARG_CONN_PASSWORD, ARG_CONN_SCHEMA, ARG_CONN_PORT + ARG_CONN_TYPE, + ARG_CONN_HOST, + ARG_CONN_LOGIN, + ARG_CONN_PASSWORD, + ARG_CONN_SCHEMA, + ARG_CONN_PORT, ] @@ -867,23 +763,30 @@ class GroupCommand(NamedTuple): ActionCommand( name='show', help="Displays DAG's tasks with their dependencies", - description=("The --imgcat option only works in iTerm.\n" - "\n" - "For more information, see: https://www.iterm2.com/documentation-images.html\n" - "\n" - "The --save option saves the result to the indicated file.\n" - "\n" - "The file format is determined by the file extension. " - "For more information about supported " - "format, see: https://www.graphviz.org/doc/info/output.html\n" - "\n" - "If you want to create a PNG file then you should execute the following command:\n" - "airflow dags show --save output.png\n" - "\n" - "If you want to create a DOT file then you should execute the following command:\n" - "airflow dags show --save output.dot\n"), + description=( + "The --imgcat option only works in iTerm.\n" + "\n" + "For more information, see: https://www.iterm2.com/documentation-images.html\n" + "\n" + "The --save option saves the result to the indicated file.\n" + "\n" + "The file format is determined by the file extension. 
" + "For more information about supported " + "format, see: https://www.graphviz.org/doc/info/output.html\n" + "\n" + "If you want to create a PNG file then you should execute the following command:\n" + "airflow dags show --save output.png\n" + "\n" + "If you want to create a DOT file then you should execute the following command:\n" + "airflow dags show --save output.dot\n" + ), func=lazy_load_command('airflow.cli.commands.dag_command.dag_show'), - args=(ARG_DAG_ID, ARG_SUBDIR, ARG_SAVE, ARG_IMGCAT,), + args=( + ARG_DAG_ID, + ARG_SUBDIR, + ARG_SAVE, + ARG_IMGCAT, + ), ), ActionCommand( name='backfill', @@ -896,36 +799,58 @@ class GroupCommand(NamedTuple): ), func=lazy_load_command('airflow.cli.commands.dag_command.dag_backfill'), args=( - ARG_DAG_ID, ARG_TASK_REGEX, ARG_START_DATE, ARG_END_DATE, ARG_MARK_SUCCESS, ARG_LOCAL, - ARG_DONOT_PICKLE, ARG_YES, ARG_BF_IGNORE_DEPENDENCIES, ARG_BF_IGNORE_FIRST_DEPENDS_ON_PAST, - ARG_SUBDIR, ARG_POOL, ARG_DELAY_ON_LIMIT, ARG_DRY_RUN, ARG_VERBOSE, ARG_CONF, - ARG_RESET_DAG_RUN, ARG_RERUN_FAILED_TASKS, ARG_RUN_BACKWARDS + ARG_DAG_ID, + ARG_TASK_REGEX, + ARG_START_DATE, + ARG_END_DATE, + ARG_MARK_SUCCESS, + ARG_LOCAL, + ARG_DONOT_PICKLE, + ARG_YES, + ARG_BF_IGNORE_DEPENDENCIES, + ARG_BF_IGNORE_FIRST_DEPENDS_ON_PAST, + ARG_SUBDIR, + ARG_POOL, + ARG_DELAY_ON_LIMIT, + ARG_DRY_RUN, + ARG_VERBOSE, + ARG_CONF, + ARG_RESET_DAG_RUN, + ARG_RERUN_FAILED_TASKS, + ARG_RUN_BACKWARDS, ), ), ActionCommand( name='test', help="Execute one single DagRun", - description=("Execute one single DagRun for a given DAG and execution date, " - "using the DebugExecutor.\n" - "\n" - "The --imgcat-dagrun option only works in iTerm.\n" - "\n" - "For more information, see: https://www.iterm2.com/documentation-images.html\n" - "\n" - "If --save-dagrun is used, then, after completing the backfill, saves the diagram " - "for current DAG Run to the indicated file.\n" - "The file format is determined by the file extension. " - "For more information about supported format, " - "see: https://www.graphviz.org/doc/info/output.html\n" - "\n" - "If you want to create a PNG file then you should execute the following command:\n" - "airflow dags test --save-dagrun output.png\n" - "\n" - "If you want to create a DOT file then you should execute the following command:\n" - "airflow dags test --save-dagrun output.dot\n"), + description=( + "Execute one single DagRun for a given DAG and execution date, " + "using the DebugExecutor.\n" + "\n" + "The --imgcat-dagrun option only works in iTerm.\n" + "\n" + "For more information, see: https://www.iterm2.com/documentation-images.html\n" + "\n" + "If --save-dagrun is used, then, after completing the backfill, saves the diagram " + "for current DAG Run to the indicated file.\n" + "The file format is determined by the file extension. 
" + "For more information about supported format, " + "see: https://www.graphviz.org/doc/info/output.html\n" + "\n" + "If you want to create a PNG file then you should execute the following command:\n" + "airflow dags test --save-dagrun output.png\n" + "\n" + "If you want to create a DOT file then you should execute the following command:\n" + "airflow dags test --save-dagrun output.dot\n" + ), func=lazy_load_command('airflow.cli.commands.dag_command.dag_test'), args=( - ARG_DAG_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_SHOW_DAGRUN, ARG_IMGCAT_DAGRUN, ARG_SAVE_DAGRUN + ARG_DAG_ID, + ARG_EXECUTION_DATE, + ARG_SUBDIR, + ARG_SHOW_DAGRUN, + ARG_IMGCAT_DAGRUN, + ARG_SAVE_DAGRUN, ), ), ) @@ -941,9 +866,19 @@ class GroupCommand(NamedTuple): help="Clear a set of task instance, as if they never ran", func=lazy_load_command('airflow.cli.commands.task_command.task_clear'), args=( - ARG_DAG_ID, ARG_TASK_REGEX, ARG_START_DATE, ARG_END_DATE, ARG_SUBDIR, ARG_UPSTREAM, - ARG_DOWNSTREAM, ARG_YES, ARG_ONLY_FAILED, ARG_ONLY_RUNNING, ARG_EXCLUDE_SUBDAGS, - ARG_EXCLUDE_PARENTDAG, ARG_DAG_REGEX + ARG_DAG_ID, + ARG_TASK_REGEX, + ARG_START_DATE, + ARG_END_DATE, + ARG_SUBDIR, + ARG_UPSTREAM, + ARG_DOWNSTREAM, + ARG_YES, + ARG_ONLY_FAILED, + ARG_ONLY_RUNNING, + ARG_EXCLUDE_SUBDAGS, + ARG_EXCLUDE_PARENTDAG, + ARG_DAG_REGEX, ), ), ActionCommand( @@ -974,9 +909,22 @@ class GroupCommand(NamedTuple): help="Run a single task instance", func=lazy_load_command('airflow.cli.commands.task_command.task_run'), args=( - ARG_DAG_ID, ARG_TASK_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_MARK_SUCCESS, ARG_FORCE, - ARG_POOL, ARG_CFG_PATH, ARG_LOCAL, ARG_RAW, ARG_IGNORE_ALL_DEPENDENCIES, - ARG_IGNORE_DEPENDENCIES, ARG_IGNORE_DEPENDS_ON_PAST, ARG_SHIP_DAG, ARG_PICKLE, ARG_JOB_ID, + ARG_DAG_ID, + ARG_TASK_ID, + ARG_EXECUTION_DATE, + ARG_SUBDIR, + ARG_MARK_SUCCESS, + ARG_FORCE, + ARG_POOL, + ARG_CFG_PATH, + ARG_LOCAL, + ARG_RAW, + ARG_IGNORE_ALL_DEPENDENCIES, + ARG_IGNORE_DEPENDENCIES, + ARG_IGNORE_DEPENDS_ON_PAST, + ARG_SHIP_DAG, + ARG_PICKLE, + ARG_JOB_ID, ARG_INTERACTIVE, ), ), @@ -989,8 +937,14 @@ class GroupCommand(NamedTuple): ), func=lazy_load_command('airflow.cli.commands.task_command.task_test'), args=( - ARG_DAG_ID, ARG_TASK_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_DRY_RUN, - ARG_TASK_PARAMS, ARG_POST_MORTEM, ARG_ENV_VARS + ARG_DAG_ID, + ARG_TASK_ID, + ARG_EXECUTION_DATE, + ARG_SUBDIR, + ARG_DRY_RUN, + ARG_TASK_PARAMS, + ARG_POST_MORTEM, + ARG_ENV_VARS, ), ), ActionCommand( @@ -1011,31 +965,48 @@ class GroupCommand(NamedTuple): name='get', help='Get pool size', func=lazy_load_command('airflow.cli.commands.pool_command.pool_get'), - args=(ARG_POOL_NAME, ARG_OUTPUT,), + args=( + ARG_POOL_NAME, + ARG_OUTPUT, + ), ), ActionCommand( name='set', help='Configure pool', func=lazy_load_command('airflow.cli.commands.pool_command.pool_set'), - args=(ARG_POOL_NAME, ARG_POOL_SLOTS, ARG_POOL_DESCRIPTION, ARG_OUTPUT,), + args=( + ARG_POOL_NAME, + ARG_POOL_SLOTS, + ARG_POOL_DESCRIPTION, + ARG_OUTPUT, + ), ), ActionCommand( name='delete', help='Delete pool', func=lazy_load_command('airflow.cli.commands.pool_command.pool_delete'), - args=(ARG_POOL_NAME, ARG_OUTPUT,), + args=( + ARG_POOL_NAME, + ARG_OUTPUT, + ), ), ActionCommand( name='import', help='Import pools', func=lazy_load_command('airflow.cli.commands.pool_command.pool_import'), - args=(ARG_POOL_IMPORT, ARG_OUTPUT,), + args=( + ARG_POOL_IMPORT, + ARG_OUTPUT, + ), ), ActionCommand( name='export', help='Export all pools', func=lazy_load_command('airflow.cli.commands.pool_command.pool_export'), - 
args=(ARG_POOL_EXPORT, ARG_OUTPUT,), + args=( + ARG_POOL_EXPORT, + ARG_OUTPUT, + ), ), ) VARIABLES_COMMANDS = ( @@ -1086,9 +1057,7 @@ class GroupCommand(NamedTuple): ActionCommand( name="check-migrations", help="Check if migration have finished", - description=( - "Check if migration have finished (or continually check until timeout)" - ), + description=("Check if migration have finished (or continually check until timeout)"), func=lazy_load_command('airflow.cli.commands.db_command.check_migrations'), args=(ARG_MIGRATION_TIMEOUT,), ), @@ -1145,18 +1114,23 @@ class GroupCommand(NamedTuple): ActionCommand( name='export', help='Export all connections', - description=("All connections can be exported in STDOUT using the following command:\n" - "airflow connections export -\n" - "The file format can be determined by the provided file extension. eg, The following " - "command will export the connections in JSON format:\n" - "airflow connections export /tmp/connections.json\n" - "The --format parameter can be used to mention the connections format. eg, " - "the default format is JSON in STDOUT mode, which can be overridden using: \n" - "airflow connections export - --format yaml\n" - "The --format parameter can also be used for the files, for example:\n" - "airflow connections export /tmp/connections --format json\n"), + description=( + "All connections can be exported in STDOUT using the following command:\n" + "airflow connections export -\n" + "The file format can be determined by the provided file extension. eg, The following " + "command will export the connections in JSON format:\n" + "airflow connections export /tmp/connections.json\n" + "The --format parameter can be used to mention the connections format. eg, " + "the default format is JSON in STDOUT mode, which can be overridden using: \n" + "airflow connections export - --format yaml\n" + "The --format parameter can also be used for the files, for example:\n" + "airflow connections export /tmp/connections --format json\n" + ), func=lazy_load_command('airflow.cli.commands.connection_command.connections_export'), - args=(ARG_CONN_EXPORT, ARG_CONN_EXPORT_FORMAT,), + args=( + ARG_CONN_EXPORT, + ARG_CONN_EXPORT_FORMAT, + ), ), ) USERS_COMMANDS = ( @@ -1171,8 +1145,13 @@ class GroupCommand(NamedTuple): help='Create a user', func=lazy_load_command('airflow.cli.commands.user_command.users_create'), args=( - ARG_ROLE, ARG_USERNAME, ARG_EMAIL, ARG_FIRSTNAME, ARG_LASTNAME, ARG_PASSWORD, - ARG_USE_RANDOM_PASSWORD + ARG_ROLE, + ARG_USERNAME, + ARG_EMAIL, + ARG_FIRSTNAME, + ARG_LASTNAME, + ARG_PASSWORD, + ARG_USE_RANDOM_PASSWORD, ), epilog=( 'examples:\n' @@ -1184,7 +1163,7 @@ class GroupCommand(NamedTuple): ' --lastname LAST_NAME \\\n' ' --role Admin \\\n' ' --email admin@example.org' - ) + ), ), ActionCommand( name='delete', @@ -1238,8 +1217,17 @@ class GroupCommand(NamedTuple): help="Start a Celery worker node", func=lazy_load_command('airflow.cli.commands.celery_command.worker'), args=( - ARG_QUEUES, ARG_CONCURRENCY, ARG_CELERY_HOSTNAME, ARG_PID, ARG_DAEMON, - ARG_UMASK, ARG_STDOUT, ARG_STDERR, ARG_LOG_FILE, ARG_AUTOSCALE, ARG_SKIP_SERVE_LOGS + ARG_QUEUES, + ARG_CONCURRENCY, + ARG_CELERY_HOSTNAME, + ARG_PID, + ARG_DAEMON, + ARG_UMASK, + ARG_STDOUT, + ARG_STDERR, + ARG_LOG_FILE, + ARG_AUTOSCALE, + ARG_SKIP_SERVE_LOGS, ), ), ActionCommand( @@ -1247,9 +1235,17 @@ class GroupCommand(NamedTuple): help="Start a Celery Flower", func=lazy_load_command('airflow.cli.commands.celery_command.flower'), args=( - ARG_FLOWER_HOSTNAME, ARG_FLOWER_PORT, 
ARG_FLOWER_CONF, ARG_FLOWER_URL_PREFIX, - ARG_FLOWER_BASIC_AUTH, ARG_BROKER_API, ARG_PID, ARG_DAEMON, ARG_STDOUT, ARG_STDERR, - ARG_LOG_FILE + ARG_FLOWER_HOSTNAME, + ARG_FLOWER_PORT, + ARG_FLOWER_CONF, + ARG_FLOWER_URL_PREFIX, + ARG_FLOWER_BASIC_AUTH, + ARG_BROKER_API, + ARG_PID, + ARG_DAEMON, + ARG_STDOUT, + ARG_STDERR, + ARG_LOG_FILE, ), ), ActionCommand( @@ -1257,7 +1253,7 @@ class GroupCommand(NamedTuple): help="Stop the Celery worker gracefully", func=lazy_load_command('airflow.cli.commands.celery_command.stop_worker'), args=(), - ) + ), ) CONFIG_COMMANDS = ( @@ -1265,13 +1261,16 @@ class GroupCommand(NamedTuple): name='get-value', help='Print the value of the configuration', func=lazy_load_command('airflow.cli.commands.config_command.get_value'), - args=(ARG_SECTION, ARG_OPTION, ), + args=( + ARG_SECTION, + ARG_OPTION, + ), ), ActionCommand( name='list', help='List options for the configuration', func=lazy_load_command('airflow.cli.commands.config_command.show_config'), - args=(ARG_COLOR, ), + args=(ARG_COLOR,), ), ) @@ -1280,12 +1279,12 @@ class GroupCommand(NamedTuple): name='cleanup-pods', help="Clean up Kubernetes pods in evicted/failed/succeeded states", func=lazy_load_command('airflow.cli.commands.kubernetes_command.cleanup_pods'), - args=(ARG_NAMESPACE, ), + args=(ARG_NAMESPACE,), ), ActionCommand( name='generate-dag-yaml', help="Generate YAML files for all tasks in DAG. Useful for debugging tasks without " - "launching into a cluster", + "launching into a cluster", func=lazy_load_command('airflow.cli.commands.kubernetes_command.generate_pod_yaml'), args=(ARG_DAG_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_OUTPUT_PATH), ), @@ -1298,9 +1297,7 @@ class GroupCommand(NamedTuple): subcommands=DAGS_COMMANDS, ), GroupCommand( - name="kubernetes", - help='Tools to help run the KubernetesExecutor', - subcommands=KUBERNETES_COMMANDS + name="kubernetes", help='Tools to help run the KubernetesExecutor', subcommands=KUBERNETES_COMMANDS ), GroupCommand( name='tasks', @@ -1333,9 +1330,21 @@ class GroupCommand(NamedTuple): help="Start a Airflow webserver instance", func=lazy_load_command('airflow.cli.commands.webserver_command.webserver'), args=( - ARG_PORT, ARG_WORKERS, ARG_WORKERCLASS, ARG_WORKER_TIMEOUT, ARG_HOSTNAME, ARG_PID, - ARG_DAEMON, ARG_STDOUT, ARG_STDERR, ARG_ACCESS_LOGFILE, ARG_ERROR_LOGFILE, ARG_LOG_FILE, - ARG_SSL_CERT, ARG_SSL_KEY, ARG_DEBUG + ARG_PORT, + ARG_WORKERS, + ARG_WORKERCLASS, + ARG_WORKER_TIMEOUT, + ARG_HOSTNAME, + ARG_PID, + ARG_DAEMON, + ARG_STDOUT, + ARG_STDERR, + ARG_ACCESS_LOGFILE, + ARG_ERROR_LOGFILE, + ARG_LOG_FILE, + ARG_SSL_CERT, + ARG_SSL_KEY, + ARG_DEBUG, ), ), ActionCommand( @@ -1343,8 +1352,14 @@ class GroupCommand(NamedTuple): help="Start a scheduler instance", func=lazy_load_command('airflow.cli.commands.scheduler_command.scheduler'), args=( - ARG_SUBDIR, ARG_NUM_RUNS, ARG_DO_PICKLE, ARG_PID, ARG_DAEMON, ARG_STDOUT, - ARG_STDERR, ARG_LOG_FILE + ARG_SUBDIR, + ARG_NUM_RUNS, + ARG_DO_PICKLE, + ARG_PID, + ARG_DAEMON, + ARG_STDOUT, + ARG_STDERR, + ARG_LOG_FILE, ), ), ActionCommand( @@ -1391,16 +1406,15 @@ class GroupCommand(NamedTuple): ), args=(), ), - GroupCommand( - name="config", - help='View configuration', - subcommands=CONFIG_COMMANDS - ), + GroupCommand(name="config", help='View configuration', subcommands=CONFIG_COMMANDS), ActionCommand( name='info', help='Show information about current Airflow and environment', func=lazy_load_command('airflow.cli.commands.info_command.show_info'), - args=(ARG_ANONYMIZE, ARG_FILE_IO, ), + args=( + ARG_ANONYMIZE, + 
ARG_FILE_IO, + ), ), ActionCommand( name='plugins', @@ -1415,13 +1429,11 @@ class GroupCommand(NamedTuple): 'Start celery components. Works only when using CeleryExecutor. For more information, see ' 'https://airflow.readthedocs.io/en/stable/executor/celery.html' ), - subcommands=CELERY_COMMANDS - ) + subcommands=CELERY_COMMANDS, + ), ] ALL_COMMANDS_DICT: Dict[str, CLICommand] = {sp.name: sp for sp in airflow_commands} -DAG_CLI_COMMANDS: Set[str] = { - 'list_tasks', 'backfill', 'test', 'run', 'pause', 'unpause', 'list_dag_runs' -} +DAG_CLI_COMMANDS: Set[str] = {'list_tasks', 'backfill', 'test', 'run', 'pause', 'unpause', 'list_dag_runs'} class AirflowHelpFormatter(argparse.HelpFormatter): @@ -1483,17 +1495,18 @@ def get_parser(dag_parser: bool = False) -> argparse.ArgumentParser: def _sort_args(args: Iterable[Arg]) -> Iterable[Arg]: """Sort subcommand optional args, keep positional args""" + def get_long_option(arg: Arg): """Get long option from Arg.flags""" return arg.flags[0] if len(arg.flags) == 1 else arg.flags[1] + positional, optional = partition(lambda x: x.flags[0].startswith("-"), args) yield from positional yield from sorted(optional, key=lambda x: get_long_option(x).lower()) def _add_command( - subparsers: argparse._SubParsersAction, # pylint: disable=protected-access - sub: CLICommand + subparsers: argparse._SubParsersAction, sub: CLICommand # pylint: disable=protected-access ) -> None: sub_proc = subparsers.add_parser( sub.name, help=sub.help, description=sub.description or sub.help, epilog=sub.epilog diff --git a/airflow/cli/commands/cheat_sheet_command.py b/airflow/cli/commands/cheat_sheet_command.py index 24a86d319d99c..0e9abaa3009af 100644 --- a/airflow/cli/commands/cheat_sheet_command.py +++ b/airflow/cli/commands/cheat_sheet_command.py @@ -41,6 +41,7 @@ def cheat_sheet(args): def display_commands_index(): """Display list of all commands.""" + def display_recursive(prefix: List[str], commands: Iterable[Union[GroupCommand, ActionCommand]]): actions: List[ActionCommand] groups: List[GroupCommand] diff --git a/airflow/cli/commands/config_command.py b/airflow/cli/commands/config_command.py index 0ec624cffb4f3..2eedfa5cbc373 100644 --- a/airflow/cli/commands/config_command.py +++ b/airflow/cli/commands/config_command.py @@ -32,9 +32,7 @@ def show_config(args): conf.write(output) code = output.getvalue() if should_use_colors(args): - code = pygments.highlight( - code=code, formatter=get_terminal_formatter(), lexer=IniLexer() - ) + code = pygments.highlight(code=code, formatter=get_terminal_formatter(), lexer=IniLexer()) print(code) diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py index 3a8d31cd5d155..6708806165692 100644 --- a/airflow/cli/commands/connection_command.py +++ b/airflow/cli/commands/connection_command.py @@ -47,7 +47,8 @@ def _tabulate_connection(conns: List[Connection], tablefmt: str): 'Is Encrypted': conn.is_encrypted, 'Is Extra Encrypted': conn.is_encrypted, 'Extra': conn.extra, - } for conn in conns + } + for conn in conns ] msg = tabulate(tabulate_data, tablefmt=tablefmt, headers='keys') @@ -67,7 +68,7 @@ def _yamulate_connection(conn: Connection): 'Is Encrypted': conn.is_encrypted, 'Is Extra Encrypted': conn.is_encrypted, 'Extra': conn.extra_dejson, - 'URI': conn.get_uri() + 'URI': conn.get_uri(), } return yaml.safe_dump(yaml_data, sort_keys=False) @@ -150,8 +151,10 @@ def connections_export(args): _, filetype = os.path.splitext(args.file.name) filetype = filetype.lower() if filetype not in allowed_formats: 
- msg = f"Unsupported file format. " \ - f"The file must have the extension {', '.join(allowed_formats)}" + msg = ( + f"Unsupported file format. " + f"The file must have the extension {', '.join(allowed_formats)}" + ) raise SystemExit(msg) connections = session.query(Connection).order_by(Connection.conn_id).all() @@ -164,8 +167,7 @@ def connections_export(args): print(f"Connections successfully exported to {args.file.name}") -alternative_conn_specs = ['conn_type', 'conn_host', - 'conn_login', 'conn_password', 'conn_schema', 'conn_port'] +alternative_conn_specs = ['conn_type', 'conn_host', 'conn_login', 'conn_password', 'conn_schema', 'conn_port'] @cli_utils.action_logging @@ -181,42 +183,51 @@ def connections_add(args): elif not args.conn_type: missing_args.append('conn-uri or conn-type') if missing_args: - msg = ('The following args are required to add a connection:' + - f' {missing_args!r}') + msg = 'The following args are required to add a connection:' + f' {missing_args!r}' raise SystemExit(msg) if invalid_args: - msg = ('The following args are not compatible with the ' + - 'add flag and --conn-uri flag: {invalid!r}') + msg = 'The following args are not compatible with the ' + 'add flag and --conn-uri flag: {invalid!r}' msg = msg.format(invalid=invalid_args) raise SystemExit(msg) if args.conn_uri: new_conn = Connection(conn_id=args.conn_id, uri=args.conn_uri) else: - new_conn = Connection(conn_id=args.conn_id, - conn_type=args.conn_type, - host=args.conn_host, - login=args.conn_login, - password=args.conn_password, - schema=args.conn_schema, - port=args.conn_port) + new_conn = Connection( + conn_id=args.conn_id, + conn_type=args.conn_type, + host=args.conn_host, + login=args.conn_login, + password=args.conn_password, + schema=args.conn_schema, + port=args.conn_port, + ) if args.conn_extra is not None: new_conn.set_extra(args.conn_extra) with create_session() as session: - if not (session.query(Connection) - .filter(Connection.conn_id == new_conn.conn_id).first()): + if not session.query(Connection).filter(Connection.conn_id == new_conn.conn_id).first(): session.add(new_conn) msg = '\n\tSuccessfully added `conn_id`={conn_id} : {uri}\n' - msg = msg.format(conn_id=new_conn.conn_id, - uri=args.conn_uri or - urlunparse((args.conn_type, - '{login}:{password}@{host}:{port}' - .format(login=args.conn_login or '', - password='******' if args.conn_password else '', - host=args.conn_host or '', - port=args.conn_port or ''), - args.conn_schema or '', '', '', ''))) + msg = msg.format( + conn_id=new_conn.conn_id, + uri=args.conn_uri + or urlunparse( + ( + args.conn_type, + '{login}:{password}@{host}:{port}'.format( + login=args.conn_login or '', + password='******' if args.conn_password else '', + host=args.conn_host or '', + port=args.conn_port or '', + ), + args.conn_schema or '', + '', + '', + '', + ) + ), + ) print(msg) else: msg = '\n\tA connection with `conn_id`={conn_id} already exists\n' @@ -229,18 +240,14 @@ def connections_delete(args): """Deletes connection from DB""" with create_session() as session: try: - to_delete = (session - .query(Connection) - .filter(Connection.conn_id == args.conn_id) - .one()) + to_delete = session.query(Connection).filter(Connection.conn_id == args.conn_id).one() except exc.NoResultFound: msg = '\n\tDid not find a connection with `conn_id`={conn_id}\n' msg = msg.format(conn_id=args.conn_id) print(msg) return except exc.MultipleResultsFound: - msg = ('\n\tFound more than one connection with ' + - '`conn_id`={conn_id}\n') + msg = '\n\tFound more than one 
connection with ' + '`conn_id`={conn_id}\n' msg = msg.format(conn_id=args.conn_id) print(msg) return diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py index 456b5b9838ddf..a3d1fe1bfbff5 100644 --- a/airflow/cli/commands/dag_command.py +++ b/airflow/cli/commands/dag_command.py @@ -52,12 +52,10 @@ def _tabulate_dag_runs(dag_runs: List[DagRun], tablefmt: str = "fancy_grid") -> 'Execution date': dag_run.execution_date.isoformat(), 'Start date': dag_run.start_date.isoformat() if dag_run.start_date else '', 'End date': dag_run.end_date.isoformat() if dag_run.end_date else '', - } for dag_run in dag_runs - ) - return tabulate( - tabular_data=tabulate_data, - tablefmt=tablefmt + } + for dag_run in dag_runs ) + return tabulate(tabular_data=tabulate_data, tablefmt=tablefmt) def _tabulate_dags(dags: List[DAG], tablefmt: str = "fancy_grid") -> str: @@ -66,27 +64,25 @@ def _tabulate_dags(dags: List[DAG], tablefmt: str = "fancy_grid") -> str: 'DAG ID': dag.dag_id, 'Filepath': dag.filepath, 'Owner': dag.owner, - } for dag in sorted(dags, key=lambda d: d.dag_id) - ) - return tabulate( - tabular_data=tabulate_data, - tablefmt=tablefmt, - headers='keys' + } + for dag in sorted(dags, key=lambda d: d.dag_id) ) + return tabulate(tabular_data=tabulate_data, tablefmt=tablefmt, headers='keys') @cli_utils.action_logging def dag_backfill(args, dag=None): """Creates backfill job or dry run for a DAG""" - logging.basicConfig( - level=settings.LOGGING_LEVEL, - format=settings.SIMPLE_LOG_FORMAT) + logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT) signal.signal(signal.SIGTERM, sigint_handler) import warnings - warnings.warn('--ignore-first-depends-on-past is deprecated as the value is always set to True', - category=PendingDeprecationWarning) + + warnings.warn( + '--ignore-first-depends-on-past is deprecated as the value is always set to True', + category=PendingDeprecationWarning, + ) if args.ignore_first_depends_on_past is False: args.ignore_first_depends_on_past = True @@ -102,16 +98,15 @@ def dag_backfill(args, dag=None): if args.task_regex: dag = dag.partial_subset( - task_ids_or_regex=args.task_regex, - include_upstream=not args.ignore_dependencies) + task_ids_or_regex=args.task_regex, include_upstream=not args.ignore_dependencies + ) run_conf = None if args.conf: run_conf = json.loads(args.conf) if args.dry_run: - print("Dry run of DAG {} on {}".format(args.dag_id, - args.start_date)) + print(f"Dry run of DAG {args.dag_id} on {args.start_date}") for task in dag.tasks: print(f"Task {task.task_id}") ti = TaskInstance(task, args.start_date) @@ -132,8 +127,7 @@ def dag_backfill(args, dag=None): end_date=args.end_date, mark_success=args.mark_success, local=args.local, - donot_pickle=(args.donot_pickle or - conf.getboolean('core', 'donot_pickle')), + donot_pickle=(args.donot_pickle or conf.getboolean('core', 'donot_pickle')), ignore_first_depends_on_past=args.ignore_first_depends_on_past, ignore_task_deps=args.ignore_dependencies, pool=args.pool, @@ -141,7 +135,7 @@ def dag_backfill(args, dag=None): verbose=args.verbose, conf=run_conf, rerun_failed_tasks=args.rerun_failed_tasks, - run_backwards=args.run_backwards + run_backwards=args.run_backwards, ) @@ -150,10 +144,9 @@ def dag_trigger(args): """Creates a dag run for the specified dag""" api_client = get_current_api_client() try: - message = api_client.trigger_dag(dag_id=args.dag_id, - run_id=args.run_id, - conf=args.conf, - execution_date=args.exec_date) + message = api_client.trigger_dag( + 
dag_id=args.dag_id, run_id=args.run_id, conf=args.conf, execution_date=args.exec_date + ) print(message) except OSError as err: raise AirflowException(err) @@ -163,9 +156,13 @@ def dag_trigger(args): def dag_delete(args): """Deletes all DB records related to the specified dag""" api_client = get_current_api_client() - if args.yes or input( - "This will drop all existing records related to the specified DAG. " - "Proceed? (y/n)").upper() == "Y": + if ( + args.yes + or input( + "This will drop all existing records related to the specified DAG. " "Proceed? (y/n)" + ).upper() + == "Y" + ): try: message = api_client.delete_dag(dag_id=args.dag_id) print(message) @@ -207,7 +204,7 @@ def dag_show(args): print( "Option --save and --imgcat are mutually exclusive. " "Please remove one option to execute the command.", - file=sys.stderr + file=sys.stderr, ) sys.exit(1) elif filename: @@ -280,8 +277,11 @@ def dag_next_execution(args): next_execution_dttm = dag.following_schedule(latest_execution_date) if next_execution_dttm is None: - print("[WARN] No following schedule can be found. " + - "This DAG may have schedule interval '@once' or `None`.", file=sys.stderr) + print( + "[WARN] No following schedule can be found. " + + "This DAG may have schedule interval '@once' or `None`.", + file=sys.stderr, + ) print(None) else: print(next_execution_dttm) @@ -327,17 +327,18 @@ def dag_list_jobs(args, dag=None): queries.append(BaseJob.state == args.state) with create_session() as session: - all_jobs = (session - .query(BaseJob) - .filter(*queries) - .order_by(BaseJob.start_date.desc()) - .limit(args.limit) - .all()) + all_jobs = ( + session.query(BaseJob) + .filter(*queries) + .order_by(BaseJob.start_date.desc()) + .limit(args.limit) + .all() + ) fields = ['dag_id', 'state', 'job_type', 'start_date', 'end_date'] all_jobs = [[job.__getattribute__(field) for field in fields] for job in all_jobs] - msg = tabulate(all_jobs, - [field.capitalize().replace('_', ' ') for field in fields], - tablefmt=args.output) + msg = tabulate( + all_jobs, [field.capitalize().replace('_', ' ') for field in fields], tablefmt=args.output + ) print(msg) @@ -367,10 +368,7 @@ def dag_list_dag_runs(args, dag=None): return dag_runs.sort(key=lambda x: x.execution_date, reverse=True) - table = _tabulate_dag_runs( - dag_runs, - tablefmt=args.output - ) + table = _tabulate_dag_runs(dag_runs, tablefmt=args.output) print(table) @@ -389,10 +387,14 @@ def dag_test(args, session=None): imgcat = args.imgcat_dagrun filename = args.save_dagrun if show_dagrun or imgcat or filename: - tis = session.query(TaskInstance).filter( - TaskInstance.dag_id == args.dag_id, - TaskInstance.execution_date == args.execution_date, - ).all() + tis = ( + session.query(TaskInstance) + .filter( + TaskInstance.dag_id == args.dag_id, + TaskInstance.execution_date == args.execution_date, + ) + .all() + ) dot_graph = render_dag(dag, tis=tis) print() diff --git a/airflow/cli/commands/db_command.py b/airflow/cli/commands/db_command.py index 6ba5b375f3eb7..1b05bd0cce6af 100644 --- a/airflow/cli/commands/db_command.py +++ b/airflow/cli/commands/db_command.py @@ -35,9 +35,7 @@ def initdb(args): def resetdb(args): """Resets the metadata database""" print("DB: " + repr(settings.engine.url)) - if args.yes or input("This will drop existing tables " - "if they exist. Proceed? " - "(y/n)").upper() == "Y": + if args.yes or input("This will drop existing tables " "if they exist. Proceed? 
" "(y/n)").upper() == "Y": db.resetdb() else: print("Cancelled") @@ -63,14 +61,16 @@ def shell(args): if url.get_backend_name() == 'mysql': with NamedTemporaryFile(suffix="my.cnf") as f: - content = textwrap.dedent(f""" + content = textwrap.dedent( + f""" [client] host = {url.host} user = {url.username} password = {url.password or ""} port = {url.port or ""} database = {url.database} - """).strip() + """ + ).strip() f.write(content.encode()) f.flush() execute_interactive(["mysql", f"--defaults-extra-file={f.name}"]) diff --git a/airflow/cli/commands/info_command.py b/airflow/cli/commands/info_command.py index 3fcad451464a0..a60a8c2cfd48a 100644 --- a/airflow/cli/commands/info_command.py +++ b/airflow/cli/commands/info_command.py @@ -303,19 +303,17 @@ def __init__(self, anonymizer: Anonymizer): @property def task_logging_handler(self): """Returns task logging handler.""" + def get_fullname(o): module = o.__class__.__module__ if module is None or module == str.__class__.__module__: return o.__class__.__name__ # Avoid reporting __builtin__ else: return module + '.' + o.__class__.__name__ + try: - handler_names = [ - get_fullname(handler) for handler in logging.getLogger('airflow.task').handlers - ] - return ", ".join( - handler_names - ) + handler_names = [get_fullname(handler) for handler in logging.getLogger('airflow.task').handlers] + return ", ".join(handler_names) except Exception: # noqa pylint: disable=broad-except return "NOT AVAILABLE" diff --git a/airflow/cli/commands/legacy_commands.py b/airflow/cli/commands/legacy_commands.py index b66cc6f575371..94f9b690327b2 100644 --- a/airflow/cli/commands/legacy_commands.py +++ b/airflow/cli/commands/legacy_commands.py @@ -44,7 +44,7 @@ "pool": "pools", "list_users": "users list", "create_user": "users create", - "delete_user": "users delete" + "delete_user": "users delete", } diff --git a/airflow/cli/commands/pool_command.py b/airflow/cli/commands/pool_command.py index d824bd40ee16b..4984618d8e026 100644 --- a/airflow/cli/commands/pool_command.py +++ b/airflow/cli/commands/pool_command.py @@ -27,8 +27,7 @@ def _tabulate_pools(pools, tablefmt="fancy_grid"): - return "\n%s" % tabulate(pools, ['Pool', 'Slots', 'Description'], - tablefmt=tablefmt) + return "\n%s" % tabulate(pools, ['Pool', 'Slots', 'Description'], tablefmt=tablefmt) def pool_list(args): @@ -49,9 +48,7 @@ def pool_get(args): def pool_set(args): """Creates new pool with a given name and slots""" api_client = get_current_api_client() - pools = [api_client.create_pool(name=args.pool, - slots=args.slots, - description=args.description)] + pools = [api_client.create_pool(name=args.pool, slots=args.slots, description=args.description)] print(_tabulate_pools(pools=pools, tablefmt=args.output)) @@ -97,9 +94,9 @@ def pool_import_helper(filepath): counter = 0 for k, v in pools_json.items(): if isinstance(v, dict) and len(v) == 2: - pools.append(api_client.create_pool(name=k, - slots=v["slots"], - description=v["description"])) + pools.append( + api_client.create_pool(name=k, slots=v["slots"], description=v["description"]) + ) counter += 1 else: pass diff --git a/airflow/cli/commands/role_command.py b/airflow/cli/commands/role_command.py index b4e6f591a1048..fb34d8c9a6003 100644 --- a/airflow/cli/commands/role_command.py +++ b/airflow/cli/commands/role_command.py @@ -29,9 +29,7 @@ def roles_list(args): roles = appbuilder.sm.get_all_roles() print("Existing roles:\n") role_names = sorted([[r.name] for r in roles]) - msg = tabulate(role_names, - headers=['Role'], - tablefmt=args.output) + msg 
= tabulate(role_names, headers=['Role'], tablefmt=args.output) print(msg) diff --git a/airflow/cli/commands/rotate_fernet_key_command.py b/airflow/cli/commands/rotate_fernet_key_command.py index 784eae86572c7..10c32d2a1fbaa 100644 --- a/airflow/cli/commands/rotate_fernet_key_command.py +++ b/airflow/cli/commands/rotate_fernet_key_command.py @@ -24,8 +24,7 @@ def rotate_fernet_key(args): """Rotates all encrypted connection credentials and variables""" with create_session() as session: - for conn in session.query(Connection).filter( - Connection.is_encrypted | Connection.is_extra_encrypted): + for conn in session.query(Connection).filter(Connection.is_encrypted | Connection.is_extra_encrypted): conn.rotate_fernet_key() for var in session.query(Variable).filter(Variable.is_encrypted): var.rotate_fernet_key() diff --git a/airflow/cli/commands/scheduler_command.py b/airflow/cli/commands/scheduler_command.py index f0f019a58ac12..100a0f1d0f775 100644 --- a/airflow/cli/commands/scheduler_command.py +++ b/airflow/cli/commands/scheduler_command.py @@ -34,14 +34,13 @@ def scheduler(args): job = SchedulerJob( subdir=process_subdir(args.subdir), num_runs=args.num_runs, - do_pickle=args.do_pickle) + do_pickle=args.do_pickle, + ) if args.daemon: - pid, stdout, stderr, log_file = setup_locations("scheduler", - args.pid, - args.stdout, - args.stderr, - args.log_file) + pid, stdout, stderr, log_file = setup_locations( + "scheduler", args.pid, args.stdout, args.stderr, args.log_file + ) handle = setup_logging(log_file) stdout = open(stdout, 'w+') stderr = open(stderr, 'w+') diff --git a/airflow/cli/commands/sync_perm_command.py b/airflow/cli/commands/sync_perm_command.py index 3a31ea0df3ee3..b072e8de1016a 100644 --- a/airflow/cli/commands/sync_perm_command.py +++ b/airflow/cli/commands/sync_perm_command.py @@ -30,6 +30,4 @@ def sync_perm(args): print('Updating permission on all DAG views') dags = DagBag(read_dags_from_db=True).dags.values() for dag in dags: - appbuilder.sm.sync_perm_for_dag( - dag.dag_id, - dag.access_control) + appbuilder.sm.sync_perm_for_dag(dag.dag_id, dag.access_control) diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 46f670b5742c0..3fa1d1fb5a0fa 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -77,8 +77,7 @@ def _run_task_by_executor(args, dag, ti): session.add(pickle) pickle_id = pickle.id # TODO: This should be written to a log - print('Pickled dag {dag} as pickle_id: {pickle_id}'.format( - dag=dag, pickle_id=pickle_id)) + print(f'Pickled dag {dag} as pickle_id: {pickle_id}') except Exception as e: print('Could not pickle the DAG') print(e) @@ -94,7 +93,8 @@ def _run_task_by_executor(args, dag, ti): ignore_depends_on_past=args.ignore_depends_on_past, ignore_task_deps=args.ignore_dependencies, ignore_ti_state=args.force, - pool=args.pool) + pool=args.pool, + ) executor.heartbeat() executor.end() @@ -109,12 +109,16 @@ def _run_task_by_local_task_job(args, ti): ignore_depends_on_past=args.ignore_depends_on_past, ignore_task_deps=args.ignore_dependencies, ignore_ti_state=args.force, - pool=args.pool) + pool=args.pool, + ) run_job.run() RAW_TASK_UNSUPPORTED_OPTION = [ - "ignore_all_dependencies", "ignore_depends_on_past", "ignore_dependencies", "force" + "ignore_all_dependencies", + "ignore_depends_on_past", + "ignore_dependencies", + "force", ] @@ -182,8 +186,9 @@ def task_run(args, dag=None): _run_task_by_selected_method(args, dag, ti) else: if settings.DONOT_MODIFY_HANDLERS: - with 
redirect_stdout(StreamLogWriter(ti.log, logging.INFO)), \ - redirect_stderr(StreamLogWriter(ti.log, logging.WARN)): + with redirect_stdout(StreamLogWriter(ti.log, logging.INFO)), redirect_stderr( + StreamLogWriter(ti.log, logging.WARN) + ): _run_task_by_selected_method(args, dag, ti) else: # Get all the Handlers from 'airflow.task' logger @@ -201,8 +206,9 @@ def task_run(args, dag=None): root_logger.addHandler(handler) root_logger.setLevel(logging.getLogger('airflow.task').level) - with redirect_stdout(StreamLogWriter(ti.log, logging.INFO)), \ - redirect_stderr(StreamLogWriter(ti.log, logging.WARN)): + with redirect_stdout(StreamLogWriter(ti.log, logging.INFO)), redirect_stderr( + StreamLogWriter(ti.log, logging.WARN) + ): _run_task_by_selected_method(args, dag, ti) # We need to restore the handlers to the loggers as celery worker process @@ -300,15 +306,18 @@ def task_states_for_dag_run(args): """Get the status of all task instances in a DagRun""" session = settings.Session() - tis = session.query( - TaskInstance.dag_id, - TaskInstance.execution_date, - TaskInstance.task_id, - TaskInstance.state, - TaskInstance.start_date, - TaskInstance.end_date).filter( - TaskInstance.dag_id == args.dag_id, - TaskInstance.execution_date == args.execution_date).all() + tis = ( + session.query( + TaskInstance.dag_id, + TaskInstance.execution_date, + TaskInstance.task_id, + TaskInstance.state, + TaskInstance.start_date, + TaskInstance.end_date, + ) + .filter(TaskInstance.dag_id == args.dag_id, TaskInstance.execution_date == args.execution_date) + .all() + ) if len(tis) == 0: raise AirflowException("DagRun does not exist.") @@ -316,18 +325,18 @@ def task_states_for_dag_run(args): formatted_rows = [] for ti in tis: - formatted_rows.append((ti.dag_id, - ti.execution_date, - ti.task_id, - ti.state, - ti.start_date, - ti.end_date)) + formatted_rows.append( + (ti.dag_id, ti.execution_date, ti.task_id, ti.state, ti.start_date, ti.end_date) + ) print( - "\n%s" % - tabulate( - formatted_rows, [ - 'dag', 'exec_date', 'task', 'state', 'start_date', 'end_date'], tablefmt=args.output)) + "\n%s" + % tabulate( + formatted_rows, + ['dag', 'exec_date', 'task', 'state', 'start_date', 'end_date'], + tablefmt=args.output, + ) + ) session.close() @@ -387,20 +396,24 @@ def task_render(args): ti = TaskInstance(task, args.execution_date) ti.render_templates() for attr in task.__class__.template_fields: - print(textwrap.dedent("""\ + print( + textwrap.dedent( + """\ # ---------------------------------------------------------- # property: {} # ---------------------------------------------------------- {} - """.format(attr, getattr(task, attr)))) + """.format( + attr, getattr(task, attr) + ) + ) + ) @cli_utils.action_logging def task_clear(args): """Clears all task instances or only those matched by regex for a DAG(s)""" - logging.basicConfig( - level=settings.LOGGING_LEVEL, - format=settings.SIMPLE_LOG_FORMAT) + logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT) if args.dag_id and not args.subdir and not args.dag_regex and not args.task_regex: dags = get_dag_by_file_location(args.dag_id) @@ -413,7 +426,8 @@ def task_clear(args): dags[idx] = dag.partial_subset( task_ids_or_regex=args.task_regex, include_downstream=args.downstream, - include_upstream=args.upstream) + include_upstream=args.upstream, + ) DAG.clear_dags( dags, diff --git a/airflow/cli/commands/user_command.py b/airflow/cli/commands/user_command.py index 90a8f2a1f35eb..7536ce6928260 100644 --- a/airflow/cli/commands/user_command.py +++ 
b/airflow/cli/commands/user_command.py @@ -36,8 +36,7 @@ def users_list(args): users = appbuilder.sm.get_all_users() fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles'] users = [[user.__getattribute__(field) for field in fields] for user in users] - msg = tabulate(users, [field.capitalize().replace('_', ' ') for field in fields], - tablefmt=args.output) + msg = tabulate(users, [field.capitalize().replace('_', ' ') for field in fields], tablefmt=args.output) print(msg) @@ -63,8 +62,7 @@ def users_create(args): if appbuilder.sm.find_user(args.username): print(f'{args.username} already exist in the db') return - user = appbuilder.sm.add_user(args.username, args.firstname, args.lastname, - args.email, role, password) + user = appbuilder.sm.add_user(args.username, args.firstname, args.lastname, args.email, role, password) if user: print(f'{args.role} user {args.username} created.') else: @@ -77,8 +75,7 @@ def users_delete(args): appbuilder = cached_app().appbuilder # pylint: disable=no-member try: - user = next(u for u in appbuilder.sm.get_all_users() - if u.username == args.username) + user = next(u for u in appbuilder.sm.get_all_users() if u.username == args.username) except StopIteration: raise SystemExit(f'{args.username} is not a valid user.') @@ -95,15 +92,12 @@ def users_manage_role(args, remove=False): raise SystemExit('Missing args: must supply one of --username or --email') if args.username and args.email: - raise SystemExit('Conflicting args: must supply either --username' - ' or --email, but not both') + raise SystemExit('Conflicting args: must supply either --username' ' or --email, but not both') appbuilder = cached_app().appbuilder # pylint: disable=no-member - user = (appbuilder.sm.find_user(username=args.username) or - appbuilder.sm.find_user(email=args.email)) + user = appbuilder.sm.find_user(username=args.username) or appbuilder.sm.find_user(email=args.email) if not user: - raise SystemExit('User "{}" does not exist'.format( - args.username or args.email)) + raise SystemExit('User "{}" does not exist'.format(args.username or args.email)) role = appbuilder.sm.find_role(args.role) if not role: @@ -114,24 +108,16 @@ def users_manage_role(args, remove=False): if role in user.roles: user.roles = [r for r in user.roles if r != role] appbuilder.sm.update_user(user) - print('User "{}" removed from role "{}".'.format( - user, - args.role)) + print(f'User "{user}" removed from role "{args.role}".') else: - raise SystemExit('User "{}" is not a member of role "{}".'.format( - user, - args.role)) + raise SystemExit(f'User "{user}" is not a member of role "{args.role}".') else: if role in user.roles: - raise SystemExit('User "{}" is already a member of role "{}".'.format( - user, - args.role)) + raise SystemExit(f'User "{user}" is already a member of role "{args.role}".') else: user.roles.append(role) appbuilder.sm.update_user(user) - print('User "{}" added to role "{}".'.format( - user, - args.role)) + print(f'User "{user}" added to role "{args.role}".') def users_export(args): @@ -146,9 +132,12 @@ def remove_underscores(s): return re.sub("_", "", s) users = [ - {remove_underscores(field): user.__getattribute__(field) - if field != 'roles' else [r.name for r in user.roles] - for field in fields} + { + remove_underscores(field): user.__getattribute__(field) + if field != 'roles' + else [r.name for r in user.roles] + for field in fields + } for user in users ] @@ -175,12 +164,10 @@ def users_import(args): users_created, users_updated = _import_users(users_list) if 
users_created: - print("Created the following users:\n\t{}".format( - "\n\t".join(users_created))) + print("Created the following users:\n\t{}".format("\n\t".join(users_created))) if users_updated: - print("Updated the following users:\n\t{}".format( - "\n\t".join(users_updated))) + print("Updated the following users:\n\t{}".format("\n\t".join(users_updated))) def _import_users(users_list): # pylint: disable=redefined-outer-name @@ -199,12 +186,10 @@ def _import_users(users_list): # pylint: disable=redefined-outer-name else: roles.append(role) - required_fields = ['username', 'firstname', 'lastname', - 'email', 'roles'] + required_fields = ['username', 'firstname', 'lastname', 'email', 'roles'] for field in required_fields: if not user.get(field): - print("Error: '{}' is a required field, but was not " - "specified".format(field)) + print("Error: '{}' is a required field, but was not " "specified".format(field)) sys.exit(1) existing_user = appbuilder.sm.find_user(email=user['email']) @@ -215,9 +200,11 @@ def _import_users(users_list): # pylint: disable=redefined-outer-name existing_user.last_name = user['lastname'] if existing_user.username != user['username']: - print("Error: Changing the username is not allowed - " - "please delete and recreate the user with " - "email '{}'".format(user['email'])) + print( + "Error: Changing the username is not allowed - " + "please delete and recreate the user with " + "email '{}'".format(user['email']) + ) sys.exit(1) appbuilder.sm.update_user(existing_user) diff --git a/airflow/cli/commands/variable_command.py b/airflow/cli/commands/variable_command.py index 1781b3667625a..e76dc800eac35 100644 --- a/airflow/cli/commands/variable_command.py +++ b/airflow/cli/commands/variable_command.py @@ -37,17 +37,10 @@ def variables_get(args): """Displays variable by a given name""" try: if args.default is None: - var = Variable.get( - args.key, - deserialize_json=args.json - ) + var = Variable.get(args.key, deserialize_json=args.json) print(var) else: - var = Variable.get( - args.key, - deserialize_json=args.json, - default_var=args.default - ) + var = Variable.get(args.key, deserialize_json=args.json, default_var=args.default) print(var) except (ValueError, KeyError) as e: print(str(e), file=sys.stderr) diff --git a/airflow/cli/commands/webserver_command.py b/airflow/cli/commands/webserver_command.py index 5a4f7e24f986b..1b3f91308220c 100644 --- a/airflow/cli/commands/webserver_command.py +++ b/airflow/cli/commands/webserver_command.py @@ -87,7 +87,7 @@ def __init__( master_timeout: int, worker_refresh_interval: int, worker_refresh_batch_size: int, - reload_on_plugin_change: bool + reload_on_plugin_change: bool, ): super().__init__() self.gunicorn_master_proc = psutil.Process(gunicorn_master_pid) @@ -152,9 +152,7 @@ def _wait_until_true(self, fn, timeout: int = 0) -> None: start_time = time.time() while not fn(): if 0 < timeout <= time.time() - start_time: - raise AirflowWebServerTimeout( - f"No response from gunicorn master within {timeout} seconds" - ) + raise AirflowWebServerTimeout(f"No response from gunicorn master within {timeout} seconds") sleep(0.1) def _spawn_new_workers(self, count: int) -> None: @@ -170,7 +168,7 @@ def _spawn_new_workers(self, count: int) -> None: excess += 1 self._wait_until_true( lambda: self.num_workers_expected + excess == self._get_num_workers_running(), - timeout=self.master_timeout + timeout=self.master_timeout, ) def _kill_old_workers(self, count: int) -> None: @@ -185,7 +183,8 @@ def _kill_old_workers(self, count: int) -> 
None: self.gunicorn_master_proc.send_signal(signal.SIGTTOU) self._wait_until_true( lambda: self.num_workers_expected + count == self._get_num_workers_running(), - timeout=self.master_timeout) + timeout=self.master_timeout, + ) def _reload_gunicorn(self) -> None: """ @@ -197,8 +196,7 @@ def _reload_gunicorn(self) -> None: self.gunicorn_master_proc.send_signal(signal.SIGHUP) sleep(1) self._wait_until_true( - lambda: self.num_workers_expected == self._get_num_workers_running(), - timeout=self.master_timeout + lambda: self.num_workers_expected == self._get_num_workers_running(), timeout=self.master_timeout ) def start(self) -> NoReturn: @@ -206,7 +204,7 @@ def start(self) -> NoReturn: try: # pylint: disable=too-many-nested-blocks self._wait_until_true( lambda: self.num_workers_expected == self._get_num_workers_running(), - timeout=self.master_timeout + timeout=self.master_timeout, ) while True: if not self.gunicorn_master_proc.is_running(): @@ -232,7 +230,8 @@ def _check_workers(self) -> None: if num_ready_workers_running < num_workers_running: self.log.debug( '[%d / %d] Some workers are starting up, waiting...', - num_ready_workers_running, num_workers_running + num_ready_workers_running, + num_workers_running, ) sleep(1) return @@ -251,9 +250,9 @@ def _check_workers(self) -> None: # to increase number of workers if num_workers_running < self.num_workers_expected: self.log.error( - "[%d / %d] Some workers seem to have died and gunicorn did not restart " - "them as expected", - num_ready_workers_running, num_workers_running + "[%d / %d] Some workers seem to have died and gunicorn did not restart " "them as expected", + num_ready_workers_running, + num_workers_running, ) sleep(10) num_workers_running = self._get_num_workers_running() @@ -263,7 +262,9 @@ def _check_workers(self) -> None: ) self.log.debug( '[%d / %d] Spawning %d workers', - num_ready_workers_running, num_workers_running, new_worker_count + num_ready_workers_running, + num_workers_running, + new_worker_count, ) self._spawn_new_workers(num_workers_running) return @@ -273,12 +274,14 @@ def _check_workers(self) -> None: # If workers should be restarted periodically. if self.worker_refresh_interval > 0 and self._last_refresh_time: # and we refreshed the workers a long time ago, refresh the workers - last_refresh_diff = (time.time() - self._last_refresh_time) + last_refresh_diff = time.time() - self._last_refresh_time if self.worker_refresh_interval < last_refresh_diff: num_new_workers = self.worker_refresh_batch_size self.log.debug( '[%d / %d] Starting doing a refresh. Starting %d workers.', - num_ready_workers_running, num_workers_running, num_new_workers + num_ready_workers_running, + num_workers_running, + num_new_workers, ) self._spawn_new_workers(num_new_workers) self._last_refresh_time = time.time() @@ -293,14 +296,16 @@ def _check_workers(self) -> None: self.log.debug( '[%d / %d] Plugins folder changed. 
The gunicorn will be restarted the next time the ' 'plugin directory is checked, if there is no change in it.', - num_ready_workers_running, num_workers_running + num_ready_workers_running, + num_workers_running, ) self._restart_on_next_plugin_check = True self._last_plugin_state = new_state elif self._restart_on_next_plugin_check: self.log.debug( '[%d / %d] Starts reloading the gunicorn configuration.', - num_ready_workers_running, num_workers_running + num_ready_workers_running, + num_workers_running, ) self._restart_on_next_plugin_check = False self._last_refresh_time = time.time() @@ -315,25 +320,24 @@ def webserver(args): access_logfile = args.access_logfile or conf.get('webserver', 'access_logfile') error_logfile = args.error_logfile or conf.get('webserver', 'error_logfile') num_workers = args.workers or conf.get('webserver', 'workers') - worker_timeout = (args.worker_timeout or - conf.get('webserver', 'web_server_worker_timeout')) + worker_timeout = args.worker_timeout or conf.get('webserver', 'web_server_worker_timeout') ssl_cert = args.ssl_cert or conf.get('webserver', 'web_server_ssl_cert') ssl_key = args.ssl_key or conf.get('webserver', 'web_server_ssl_key') if not ssl_cert and ssl_key: - raise AirflowException( - 'An SSL certificate must also be provided for use with ' + ssl_key) + raise AirflowException('An SSL certificate must also be provided for use with ' + ssl_key) if ssl_cert and not ssl_key: - raise AirflowException( - 'An SSL key must also be provided for use with ' + ssl_cert) + raise AirflowException('An SSL key must also be provided for use with ' + ssl_cert) if args.debug: - print( - "Starting the web server on port {} and host {}.".format( - args.port, args.hostname)) + print(f"Starting the web server on port {args.port} and host {args.hostname}.") app = create_app(testing=conf.getboolean('core', 'unit_test_mode')) - app.run(debug=True, use_reloader=not app.config['TESTING'], - port=args.port, host=args.hostname, - ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None) + app.run( + debug=True, + use_reloader=not app.config['TESTING'], + port=args.port, + host=args.hostname, + ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None, + ) else: # This pre-warms the cache, and makes possible errors # get reported earlier (i.e. 
before demonization) @@ -342,33 +346,49 @@ def webserver(args): os.environ.pop('SKIP_DAGS_PARSING') pid_file, stdout, stderr, log_file = setup_locations( - "webserver", args.pid, args.stdout, args.stderr, args.log_file) + "webserver", args.pid, args.stdout, args.stderr, args.log_file + ) # Check if webserver is already running if not, remove old pidfile check_if_pidfile_process_is_running(pid_file=pid_file, process_name="webserver") print( - textwrap.dedent('''\ + textwrap.dedent( + '''\ Running the Gunicorn Server with: Workers: {num_workers} {workerclass} Host: {hostname}:{port} Timeout: {worker_timeout} Logfiles: {access_logfile} {error_logfile} =================================================================\ - '''.format(num_workers=num_workers, workerclass=args.workerclass, - hostname=args.hostname, port=args.port, - worker_timeout=worker_timeout, access_logfile=access_logfile, - error_logfile=error_logfile))) + '''.format( + num_workers=num_workers, + workerclass=args.workerclass, + hostname=args.hostname, + port=args.port, + worker_timeout=worker_timeout, + access_logfile=access_logfile, + error_logfile=error_logfile, + ) + ) + ) run_args = [ 'gunicorn', - '--workers', str(num_workers), - '--worker-class', str(args.workerclass), - '--timeout', str(worker_timeout), - '--bind', args.hostname + ':' + str(args.port), - '--name', 'airflow-webserver', - '--pid', pid_file, - '--config', 'python:airflow.www.gunicorn_config', + '--workers', + str(num_workers), + '--worker-class', + str(args.workerclass), + '--timeout', + str(worker_timeout), + '--bind', + args.hostname + ':' + str(args.port), + '--name', + 'airflow-webserver', + '--pid', + pid_file, + '--config', + 'python:airflow.www.gunicorn_config', ] if args.access_logfile: diff --git a/airflow/config_templates/airflow_local_settings.py b/airflow/config_templates/airflow_local_settings.py index 67c54447a9ef6..16b372cc3302c 100644 --- a/airflow/config_templates/airflow_local_settings.py +++ b/airflow/config_templates/airflow_local_settings.py @@ -58,19 +58,17 @@ 'version': 1, 'disable_existing_loggers': False, 'formatters': { - 'airflow': { - 'format': LOG_FORMAT - }, + 'airflow': {'format': LOG_FORMAT}, 'airflow_coloured': { 'format': COLORED_LOG_FORMAT if COLORED_LOG else LOG_FORMAT, - 'class': COLORED_FORMATTER_CLASS if COLORED_LOG else 'logging.Formatter' + 'class': COLORED_FORMATTER_CLASS if COLORED_LOG else 'logging.Formatter', }, }, 'handlers': { 'console': { 'class': 'airflow.utils.log.logging_mixin.RedirectStdHandler', 'formatter': 'airflow_coloured', - 'stream': 'sys.stdout' + 'stream': 'sys.stdout', }, 'task': { 'class': 'airflow.utils.log.file_task_handler.FileTaskHandler', @@ -83,7 +81,7 @@ 'formatter': 'airflow', 'base_log_folder': os.path.expanduser(PROCESSOR_LOG_FOLDER), 'filename_template': PROCESSOR_FILENAME_TEMPLATE, - } + }, }, 'loggers': { 'airflow.processor': { @@ -100,19 +98,22 @@ 'handler': ['console'], 'level': FAB_LOG_LEVEL, 'propagate': True, - } + }, }, 'root': { 'handlers': ['console'], 'level': LOG_LEVEL, - } + }, } EXTRA_LOGGER_NAMES: str = conf.get('logging', 'EXTRA_LOGGER_NAMES', fallback=None) if EXTRA_LOGGER_NAMES: new_loggers = { logger_name.strip(): { - 'handler': ['console'], 'level': LOG_LEVEL, 'propagate': True, } + 'handler': ['console'], + 'level': LOG_LEVEL, + 'propagate': True, + } for logger_name in EXTRA_LOGGER_NAMES.split(",") } DEFAULT_LOGGING_CONFIG['loggers'].update(new_loggers) @@ -125,7 +126,7 @@ 'filename': DAG_PROCESSOR_MANAGER_LOG_LOCATION, 'mode': 'a', 'maxBytes': 104857600, # 
100MB - 'backupCount': 5 + 'backupCount': 5, } }, 'loggers': { @@ -134,22 +135,21 @@ 'level': LOG_LEVEL, 'propagate': False, } - } + }, } # Only update the handlers and loggers when CONFIG_PROCESSOR_MANAGER_LOGGER is set. # This is to avoid exceptions when initializing RotatingFileHandler multiple times # in multiple processes. if os.environ.get('CONFIG_PROCESSOR_MANAGER_LOGGER') == 'True': - DEFAULT_LOGGING_CONFIG['handlers'] \ - .update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers']) - DEFAULT_LOGGING_CONFIG['loggers'] \ - .update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['loggers']) + DEFAULT_LOGGING_CONFIG['handlers'].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers']) + DEFAULT_LOGGING_CONFIG['loggers'].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['loggers']) # Manually create log directory for processor_manager handler as RotatingFileHandler # will only create file but not the directory. - processor_manager_handler_config: Dict[str, Any] = \ - DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers']['processor_manager'] + processor_manager_handler_config: Dict[str, Any] = DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers'][ + 'processor_manager' + ] directory: str = os.path.dirname(processor_manager_handler_config['filename']) Path(directory).mkdir(parents=True, exist_ok=True, mode=0o755) @@ -204,7 +204,7 @@ 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), 'gcs_log_folder': REMOTE_BASE_LOG_FOLDER, 'filename_template': FILENAME_TEMPLATE, - 'gcp_key_path': key_path + 'gcp_key_path': key_path, }, } @@ -232,7 +232,7 @@ 'class': 'airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverTaskHandler', 'formatter': 'airflow', 'name': log_name, - 'gcp_key_path': key_path + 'gcp_key_path': key_path, } } @@ -257,7 +257,7 @@ 'frontend': ELASTICSEARCH_FRONTEND, 'write_stdout': ELASTICSEARCH_WRITE_STDOUT, 'json_format': ELASTICSEARCH_JSON_FORMAT, - 'json_fields': ELASTICSEARCH_JSON_FIELDS + 'json_fields': ELASTICSEARCH_JSON_FIELDS, }, } @@ -266,4 +266,5 @@ raise AirflowException( "Incorrect remote log configuration. Please check the configuration of option 'host' in " "section 'elasticsearch' if you are using Elasticsearch. In the other case, " - "'remote_base_log_folder' option in 'core' section.") + "'remote_base_log_folder' option in 'core' section." 
+ ) diff --git a/airflow/config_templates/default_celery.py b/airflow/config_templates/default_celery.py index b12bf14e13e10..14d4c0f2fc597 100644 --- a/airflow/config_templates/default_celery.py +++ b/airflow/config_templates/default_celery.py @@ -59,30 +59,43 @@ def _broker_supports_visibility_timeout(url): try: if celery_ssl_active: if 'amqp://' in broker_url: - broker_use_ssl = {'keyfile': conf.get('celery', 'SSL_KEY'), - 'certfile': conf.get('celery', 'SSL_CERT'), - 'ca_certs': conf.get('celery', 'SSL_CACERT'), - 'cert_reqs': ssl.CERT_REQUIRED} + broker_use_ssl = { + 'keyfile': conf.get('celery', 'SSL_KEY'), + 'certfile': conf.get('celery', 'SSL_CERT'), + 'ca_certs': conf.get('celery', 'SSL_CACERT'), + 'cert_reqs': ssl.CERT_REQUIRED, + } elif 'redis://' in broker_url: - broker_use_ssl = {'ssl_keyfile': conf.get('celery', 'SSL_KEY'), - 'ssl_certfile': conf.get('celery', 'SSL_CERT'), - 'ssl_ca_certs': conf.get('celery', 'SSL_CACERT'), - 'ssl_cert_reqs': ssl.CERT_REQUIRED} + broker_use_ssl = { + 'ssl_keyfile': conf.get('celery', 'SSL_KEY'), + 'ssl_certfile': conf.get('celery', 'SSL_CERT'), + 'ssl_ca_certs': conf.get('celery', 'SSL_CACERT'), + 'ssl_cert_reqs': ssl.CERT_REQUIRED, + } else: - raise AirflowException('The broker you configured does not support SSL_ACTIVE to be True. ' - 'Please use RabbitMQ or Redis if you would like to use SSL for broker.') + raise AirflowException( + 'The broker you configured does not support SSL_ACTIVE to be True. ' + 'Please use RabbitMQ or Redis if you would like to use SSL for broker.' + ) DEFAULT_CELERY_CONFIG['broker_use_ssl'] = broker_use_ssl except AirflowConfigException: - raise AirflowException('AirflowConfigException: SSL_ACTIVE is True, ' - 'please ensure SSL_KEY, ' - 'SSL_CERT and SSL_CACERT are set') + raise AirflowException( + 'AirflowConfigException: SSL_ACTIVE is True, ' + 'please ensure SSL_KEY, ' + 'SSL_CERT and SSL_CACERT are set' + ) except Exception as e: - raise AirflowException('Exception: There was an unknown Celery SSL Error. ' - 'Please ensure you want to use ' - 'SSL and/or have all necessary certs and key ({}).'.format(e)) + raise AirflowException( + 'Exception: There was an unknown Celery SSL Error. ' + 'Please ensure you want to use ' + 'SSL and/or have all necessary certs and key ({}).'.format(e) + ) result_backend = DEFAULT_CELERY_CONFIG['result_backend'] if 'amqp://' in result_backend or 'redis://' in result_backend or 'rpc://' in result_backend: - log.warning("You have configured a result_backend of %s, it is highly recommended " - "to use an alternative result_backend (i.e. a database).", result_backend) + log.warning( + "You have configured a result_backend of %s, it is highly recommended " + "to use an alternative result_backend (i.e. 
a database).", + result_backend, + ) diff --git a/airflow/configuration.py b/airflow/configuration.py index bc4b041a3e50e..b5d1aa2ec2265 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -28,6 +28,7 @@ import warnings from base64 import b64encode from collections import OrderedDict + # Ignored Mypy on configparser because it thinks the configparser module has no _UNSET attribute from configparser import _UNSET, ConfigParser, NoOptionError, NoSectionError # type: ignore from json.decoder import JSONDecodeError diff --git a/airflow/contrib/__init__.py b/airflow/contrib/__init__.py index e852a5918ebb4..3a89862127ebc 100644 --- a/airflow/contrib/__init__.py +++ b/airflow/contrib/__init__.py @@ -19,6 +19,4 @@ import warnings -warnings.warn( - "This module is deprecated.", DeprecationWarning, stacklevel=2 -) +warnings.warn("This module is deprecated.", DeprecationWarning, stacklevel=2) diff --git a/airflow/contrib/hooks/aws_dynamodb_hook.py b/airflow/contrib/hooks/aws_dynamodb_hook.py index 9ea6e1d29121a..fc6bd1b79c4cc 100644 --- a/airflow/contrib/hooks/aws_dynamodb_hook.py +++ b/airflow/contrib/hooks/aws_dynamodb_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.dynamodb`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/aws_glue_catalog_hook.py b/airflow/contrib/hooks/aws_glue_catalog_hook.py index f36820e2196e4..59733ea44b320 100644 --- a/airflow/contrib/hooks/aws_glue_catalog_hook.py +++ b/airflow/contrib/hooks/aws_glue_catalog_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.glue_catalog`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py index 55bd95139e70e..c18577656fcbd 100644 --- a/airflow/contrib/hooks/aws_hook.py +++ b/airflow/contrib/hooks/aws_hook.py @@ -24,7 +24,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.base_aws`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -34,6 +35,7 @@ class AwsHook(AwsBaseHook): def __init__(self, *args, **kwargs): warnings.warn( "This class is deprecated. Please use `airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/aws_logs_hook.py b/airflow/contrib/hooks/aws_logs_hook.py index d2b3446aad3a8..ab27562f55624 100644 --- a/airflow/contrib/hooks/aws_logs_hook.py +++ b/airflow/contrib/hooks/aws_logs_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.logs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/azure_container_instance_hook.py b/airflow/contrib/hooks/azure_container_instance_hook.py index eaa0ff6cfd054..eaa9f091752b4 100644 --- a/airflow/contrib/hooks/azure_container_instance_hook.py +++ b/airflow/contrib/hooks/azure_container_instance_hook.py @@ -30,5 +30,6 @@ warnings.warn( "This module is deprecated. 
" "Please use `airflow.providers.microsoft.azure.hooks.azure_container_instance`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/azure_container_registry_hook.py b/airflow/contrib/hooks/azure_container_registry_hook.py index 22260c49bf31a..fe9cfdc488e70 100644 --- a/airflow/contrib/hooks/azure_container_registry_hook.py +++ b/airflow/contrib/hooks/azure_container_registry_hook.py @@ -30,5 +30,6 @@ warnings.warn( "This module is deprecated. " "Please use `airflow.providers.microsoft.azure.hooks.azure_container_registry`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/azure_container_volume_hook.py b/airflow/contrib/hooks/azure_container_volume_hook.py index 2547a5fca3fdb..4c747cadd3472 100644 --- a/airflow/contrib/hooks/azure_container_volume_hook.py +++ b/airflow/contrib/hooks/azure_container_volume_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.azure_container_volume`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/azure_cosmos_hook.py b/airflow/contrib/hooks/azure_cosmos_hook.py index 35ef4aea403c1..d39766935f7e1 100644 --- a/airflow/contrib/hooks/azure_cosmos_hook.py +++ b/airflow/contrib/hooks/azure_cosmos_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.azure_cosmos`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/azure_data_lake_hook.py b/airflow/contrib/hooks/azure_data_lake_hook.py index d11836f490178..0e21185e8eda6 100644 --- a/airflow/contrib/hooks/azure_data_lake_hook.py +++ b/airflow/contrib/hooks/azure_data_lake_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.azure_data_lake`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/azure_fileshare_hook.py b/airflow/contrib/hooks/azure_fileshare_hook.py index e166aa995bfe9..b4a9f99e39d34 100644 --- a/airflow/contrib/hooks/azure_fileshare_hook.py +++ b/airflow/contrib/hooks/azure_fileshare_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.azure_fileshare`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py index 6c9a95baa6524..6b95898c76d63 100644 --- a/airflow/contrib/hooks/bigquery_hook.py +++ b/airflow/contrib/hooks/bigquery_hook.py @@ -21,11 +21,16 @@ # pylint: disable=unused-import from airflow.providers.google.cloud.hooks.bigquery import ( # noqa - BigQueryBaseCursor, BigQueryConnection, BigQueryCursor, BigQueryHook, BigQueryPandasConnector, + BigQueryBaseCursor, + BigQueryConnection, + BigQueryCursor, + BigQueryHook, + BigQueryPandasConnector, GbqConnector, ) warnings.warn( "This module is deprecated. 
Please use `airflow.providers.google.cloud.hooks.bigquery`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/cassandra_hook.py b/airflow/contrib/hooks/cassandra_hook.py index c0002862ae264..223f15b237223 100644 --- a/airflow/contrib/hooks/cassandra_hook.py +++ b/airflow/contrib/hooks/cassandra_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.cassandra.hooks.cassandra`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/cloudant_hook.py b/airflow/contrib/hooks/cloudant_hook.py index 453dd36a99abd..4d1da6c987ed5 100644 --- a/airflow/contrib/hooks/cloudant_hook.py +++ b/airflow/contrib/hooks/cloudant_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.cloudant.hooks.cloudant`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py index 3c610e316c21a..b2f177983294b 100644 --- a/airflow/contrib/hooks/databricks_hook.py +++ b/airflow/contrib/hooks/databricks_hook.py @@ -21,12 +21,21 @@ # pylint: disable=unused-import from airflow.providers.databricks.hooks.databricks import ( # noqa - CANCEL_RUN_ENDPOINT, GET_RUN_ENDPOINT, RESTART_CLUSTER_ENDPOINT, RUN_LIFE_CYCLE_STATES, RUN_NOW_ENDPOINT, - START_CLUSTER_ENDPOINT, SUBMIT_RUN_ENDPOINT, TERMINATE_CLUSTER_ENDPOINT, USER_AGENT_HEADER, - DatabricksHook, RunState, + CANCEL_RUN_ENDPOINT, + GET_RUN_ENDPOINT, + RESTART_CLUSTER_ENDPOINT, + RUN_LIFE_CYCLE_STATES, + RUN_NOW_ENDPOINT, + START_CLUSTER_ENDPOINT, + SUBMIT_RUN_ENDPOINT, + TERMINATE_CLUSTER_ENDPOINT, + USER_AGENT_HEADER, + DatabricksHook, + RunState, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.databricks.hooks.databricks`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/datadog_hook.py b/airflow/contrib/hooks/datadog_hook.py index 4746724e0d2ce..3751127db8393 100644 --- a/airflow/contrib/hooks/datadog_hook.py +++ b/airflow/contrib/hooks/datadog_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.datadog.hooks.datadog`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/datastore_hook.py b/airflow/contrib/hooks/datastore_hook.py index 2d395e10951fc..3a072eca543c0 100644 --- a/airflow/contrib/hooks/datastore_hook.py +++ b/airflow/contrib/hooks/datastore_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.datastore`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/dingding_hook.py b/airflow/contrib/hooks/dingding_hook.py index 5a4aaf78ed140..66f9c2954aa60 100644 --- a/airflow/contrib/hooks/dingding_hook.py +++ b/airflow/contrib/hooks/dingding_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.dingding.hooks.dingding`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/discord_webhook_hook.py b/airflow/contrib/hooks/discord_webhook_hook.py index 8de11e0016a18..e6c3cd00913f0 100644 --- a/airflow/contrib/hooks/discord_webhook_hook.py +++ b/airflow/contrib/hooks/discord_webhook_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.discord.hooks.discord_webhook`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/emr_hook.py b/airflow/contrib/hooks/emr_hook.py index a8cadcb6f727a..0ff71d4af9fba 100644 --- a/airflow/contrib/hooks/emr_hook.py +++ b/airflow/contrib/hooks/emr_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.emr`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/fs_hook.py b/airflow/contrib/hooks/fs_hook.py index 93adafda16209..73d409136a2e8 100644 --- a/airflow/contrib/hooks/fs_hook.py +++ b/airflow/contrib/hooks/fs_hook.py @@ -23,6 +23,5 @@ from airflow.hooks.filesystem import FSHook # noqa warnings.warn( - "This module is deprecated. Please use `airflow.hooks.filesystem`.", - DeprecationWarning, stacklevel=2 + "This module is deprecated. Please use `airflow.hooks.filesystem`.", DeprecationWarning, stacklevel=2 ) diff --git a/airflow/contrib/hooks/ftp_hook.py b/airflow/contrib/hooks/ftp_hook.py index afe236bed01a8..bdb47b68c3c1d 100644 --- a/airflow/contrib/hooks/ftp_hook.py +++ b/airflow/contrib/hooks/ftp_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.ftp.hooks.ftp`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/gcp_api_base_hook.py b/airflow/contrib/hooks/gcp_api_base_hook.py index ca173f44955bb..468d7a3699b70 100644 --- a/airflow/contrib/hooks/gcp_api_base_hook.py +++ b/airflow/contrib/hooks/gcp_api_base_hook.py @@ -22,7 +22,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.common.hooks.base_google`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -36,6 +37,7 @@ def __init__(self, *args, **kwargs): warnings.warn( "This class is deprecated. Please use " "`airflow.providers.google.common.hooks.base_google.GoogleBaseHook`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_bigtable_hook.py b/airflow/contrib/hooks/gcp_bigtable_hook.py index 2926f5f0d4f4c..372293070a9eb 100644 --- a/airflow/contrib/hooks/gcp_bigtable_hook.py +++ b/airflow/contrib/hooks/gcp_bigtable_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.bigtable`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/gcp_cloud_build_hook.py b/airflow/contrib/hooks/gcp_cloud_build_hook.py index 79009cd3ebc42..34535622a139e 100644 --- a/airflow/contrib/hooks/gcp_cloud_build_hook.py +++ b/airflow/contrib/hooks/gcp_cloud_build_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.google.cloud.hooks.cloud_build`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/gcp_compute_hook.py b/airflow/contrib/hooks/gcp_compute_hook.py index 84e2e4f7f59e7..f24f0bd4be6c4 100644 --- a/airflow/contrib/hooks/gcp_compute_hook.py +++ b/airflow/contrib/hooks/gcp_compute_hook.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use airflow.providers.google.cloud.hooks.compute`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -36,7 +37,8 @@ class GceHook(ComputeEngineHook): def __init__(self, *args, **kwargs): warnings.warn( "This class is deprecated. Please use `airflow.providers.google.cloud.hooks.compute`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_container_hook.py b/airflow/contrib/hooks/gcp_container_hook.py index 5d8d900a90d4d..d014a440eaa28 100644 --- a/airflow/contrib/hooks/gcp_container_hook.py +++ b/airflow/contrib/hooks/gcp_container_hook.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.kubernetes_engine`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -33,6 +34,7 @@ class GKEClusterHook(GKEHook): def __init__(self, *args, **kwargs): warnings.warn( "This class is deprecated. Please use `airflow.providers.google.cloud.hooks.container.GKEHook`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_dataflow_hook.py b/airflow/contrib/hooks/gcp_dataflow_hook.py index 54080c513fafa..8e1502325a705 100644 --- a/airflow/contrib/hooks/gcp_dataflow_hook.py +++ b/airflow/contrib/hooks/gcp_dataflow_hook.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.dataflow`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -34,6 +35,7 @@ def __init__(self, *args, **kwargs): warnings.warn( "This class is deprecated. " "Please use `airflow.providers.google.cloud.hooks.dataflow.DataflowHook`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_dataproc_hook.py b/airflow/contrib/hooks/gcp_dataproc_hook.py index 75c6df88a9d33..54f618e41a80a 100644 --- a/airflow/contrib/hooks/gcp_dataproc_hook.py +++ b/airflow/contrib/hooks/gcp_dataproc_hook.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.dataproc`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,7 +38,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.hooks.dataproc.DataprocHook`.""", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_dlp_hook.py b/airflow/contrib/hooks/gcp_dlp_hook.py index cb1fbfc74637b..bf496e4f93a98 100644 --- a/airflow/contrib/hooks/gcp_dlp_hook.py +++ b/airflow/contrib/hooks/gcp_dlp_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.google.cloud.hooks.dlp`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/gcp_function_hook.py b/airflow/contrib/hooks/gcp_function_hook.py index d684d8d451658..d728fc98dfc81 100644 --- a/airflow/contrib/hooks/gcp_function_hook.py +++ b/airflow/contrib/hooks/gcp_function_hook.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.functions`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,7 +38,8 @@ def __init__(self, *args, **kwargs): warnings.warn( "This class is deprecated. " "Please use `airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_kms_hook.py b/airflow/contrib/hooks/gcp_kms_hook.py index 6bd57399b4b38..7fce040de5031 100644 --- a/airflow/contrib/hooks/gcp_kms_hook.py +++ b/airflow/contrib/hooks/gcp_kms_hook.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.kms`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -33,6 +34,7 @@ class GoogleCloudKMSHook(CloudKMSHook): def __init__(self, *args, **kwargs): warnings.warn( "This class is deprecated. Please use `airflow.providers.google.cloud.hooks.kms.CloudKMSHook`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_mlengine_hook.py b/airflow/contrib/hooks/gcp_mlengine_hook.py index 07687c0b189ca..cec3f31de694b 100644 --- a/airflow/contrib/hooks/gcp_mlengine_hook.py +++ b/airflow/contrib/hooks/gcp_mlengine_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.mlengine`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/gcp_natural_language_hook.py b/airflow/contrib/hooks/gcp_natural_language_hook.py index 70f9405c01e2a..0243ff6c7a7ee 100644 --- a/airflow/contrib/hooks/gcp_natural_language_hook.py +++ b/airflow/contrib/hooks/gcp_natural_language_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.natural_language`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/gcp_pubsub_hook.py b/airflow/contrib/hooks/gcp_pubsub_hook.py index e04f3cd6712e1..23f7ab0868329 100644 --- a/airflow/contrib/hooks/gcp_pubsub_hook.py +++ b/airflow/contrib/hooks/gcp_pubsub_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.pubsub`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/gcp_speech_to_text_hook.py b/airflow/contrib/hooks/gcp_speech_to_text_hook.py index dedece29b11c3..5b27c779c8b6b 100644 --- a/airflow/contrib/hooks/gcp_speech_to_text_hook.py +++ b/airflow/contrib/hooks/gcp_speech_to_text_hook.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.speech_to_text`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,7 +38,8 @@ def __init__(self, *args, **kwargs): warnings.warn( "This class is deprecated. 
" "Please use `airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_tasks_hook.py b/airflow/contrib/hooks/gcp_tasks_hook.py index e577de66e6130..837c210c8b1b9 100644 --- a/airflow/contrib/hooks/gcp_tasks_hook.py +++ b/airflow/contrib/hooks/gcp_tasks_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.tasks`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/gcp_text_to_speech_hook.py b/airflow/contrib/hooks/gcp_text_to_speech_hook.py index 593502f9ab7a1..8a7c8fb23bf77 100644 --- a/airflow/contrib/hooks/gcp_text_to_speech_hook.py +++ b/airflow/contrib/hooks/gcp_text_to_speech_hook.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.text_to_speech`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,7 +38,8 @@ def __init__(self, *args, **kwargs): warnings.warn( "This class is deprecated. " "Please use `airflow.providers.google.cloud.hooks.text_to_speech.CloudTextToSpeechHook`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_transfer_hook.py b/airflow/contrib/hooks/gcp_transfer_hook.py index 535c24fe2cf06..57c098d52b7f9 100644 --- a/airflow/contrib/hooks/gcp_transfer_hook.py +++ b/airflow/contrib/hooks/gcp_transfer_hook.py @@ -27,7 +27,8 @@ warnings.warn( "This module is deprecated. " "Please use `airflow.providers.google.cloud.hooks.cloud_storage_transfer_service`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -43,6 +44,7 @@ def __init__(self, *args, **kwargs): Please use `airflow.providers.google.cloud.hooks.cloud_storage_transfer_service .CloudDataTransferServiceHook`.""", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_translate_hook.py b/airflow/contrib/hooks/gcp_translate_hook.py index 016f8efc6175c..a0b1788ec8dc0 100644 --- a/airflow/contrib/hooks/gcp_translate_hook.py +++ b/airflow/contrib/hooks/gcp_translate_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.translate`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/gcp_video_intelligence_hook.py b/airflow/contrib/hooks/gcp_video_intelligence_hook.py index 76ad51ef68ceb..9a99748ea9452 100644 --- a/airflow/contrib/hooks/gcp_video_intelligence_hook.py +++ b/airflow/contrib/hooks/gcp_video_intelligence_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.video_intelligence`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/gcp_vision_hook.py b/airflow/contrib/hooks/gcp_vision_hook.py index 05ac86a5e86f2..98ac270ba1144 100644 --- a/airflow/contrib/hooks/gcp_vision_hook.py +++ b/airflow/contrib/hooks/gcp_vision_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.google.cloud.hooks.vision`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/gcs_hook.py b/airflow/contrib/hooks/gcs_hook.py index 625e1aa3313c7..7134f327e5f6a 100644 --- a/airflow/contrib/hooks/gcs_hook.py +++ b/airflow/contrib/hooks/gcs_hook.py @@ -22,7 +22,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -32,6 +33,7 @@ class GoogleCloudStorageHook(GCSHook): def __init__(self, *args, **kwargs): warnings.warn( "This class is deprecated. Please use `airflow.providers.google.cloud.hooks.gcs.GCSHook`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gdrive_hook.py b/airflow/contrib/hooks/gdrive_hook.py index b2deb96529a83..c6d8172e7bb49 100644 --- a/airflow/contrib/hooks/gdrive_hook.py +++ b/airflow/contrib/hooks/gdrive_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.suite.hooks.drive`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/grpc_hook.py b/airflow/contrib/hooks/grpc_hook.py index 7f89d401169ed..631b47d418b21 100644 --- a/airflow/contrib/hooks/grpc_hook.py +++ b/airflow/contrib/hooks/grpc_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.grpc.hooks.grpc`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/imap_hook.py b/airflow/contrib/hooks/imap_hook.py index 12f0ab1306c98..c868dfa1ac968 100644 --- a/airflow/contrib/hooks/imap_hook.py +++ b/airflow/contrib/hooks/imap_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.imap.hooks.imap`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/jenkins_hook.py b/airflow/contrib/hooks/jenkins_hook.py index 72dabb6fbb38e..1e260e4595cbe 100644 --- a/airflow/contrib/hooks/jenkins_hook.py +++ b/airflow/contrib/hooks/jenkins_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.jenkins.hooks.jenkins`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/mongo_hook.py b/airflow/contrib/hooks/mongo_hook.py index db25b1131f580..ac5a3a0b753b0 100644 --- a/airflow/contrib/hooks/mongo_hook.py +++ b/airflow/contrib/hooks/mongo_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.mongo.hooks.mongo`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/openfaas_hook.py b/airflow/contrib/hooks/openfaas_hook.py index 4ea87488189a7..32ab77ced95d6 100644 --- a/airflow/contrib/hooks/openfaas_hook.py +++ b/airflow/contrib/hooks/openfaas_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.openfaas.hooks.openfaas`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/opsgenie_alert_hook.py b/airflow/contrib/hooks/opsgenie_alert_hook.py index 5a6b338cc996f..9bf46057a6ec9 100644 --- a/airflow/contrib/hooks/opsgenie_alert_hook.py +++ b/airflow/contrib/hooks/opsgenie_alert_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.opsgenie.hooks.opsgenie_alert`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/pagerduty_hook.py b/airflow/contrib/hooks/pagerduty_hook.py index cfce5e86160b6..0f1bbdd0943e2 100644 --- a/airflow/contrib/hooks/pagerduty_hook.py +++ b/airflow/contrib/hooks/pagerduty_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.pagerduty.hooks.pagerduty`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/pinot_hook.py b/airflow/contrib/hooks/pinot_hook.py index 5ecfbfa35dfc3..d3b9816c3ffbb 100644 --- a/airflow/contrib/hooks/pinot_hook.py +++ b/airflow/contrib/hooks/pinot_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.pinot.hooks.pinot`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/qubole_check_hook.py b/airflow/contrib/hooks/qubole_check_hook.py index 03fed8c1e6da6..8df277d17c641 100644 --- a/airflow/contrib/hooks/qubole_check_hook.py +++ b/airflow/contrib/hooks/qubole_check_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.qubole.hooks.qubole_check`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/qubole_hook.py b/airflow/contrib/hooks/qubole_hook.py index ab2e6cd8b04e6..aee57ba675ada 100644 --- a/airflow/contrib/hooks/qubole_hook.py +++ b/airflow/contrib/hooks/qubole_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.qubole.hooks.qubole`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/redis_hook.py b/airflow/contrib/hooks/redis_hook.py index 7938f57a4ac2b..d2bd8aba1b3af 100644 --- a/airflow/contrib/hooks/redis_hook.py +++ b/airflow/contrib/hooks/redis_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.redis.hooks.redis`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/sagemaker_hook.py b/airflow/contrib/hooks/sagemaker_hook.py index c2f60125a815f..8d33c27359c23 100644 --- a/airflow/contrib/hooks/sagemaker_hook.py +++ b/airflow/contrib/hooks/sagemaker_hook.py @@ -21,11 +21,16 @@ # pylint: disable=unused-import from airflow.providers.amazon.aws.hooks.sagemaker import ( # noqa - LogState, Position, SageMakerHook, argmin, secondary_training_status_changed, + LogState, + Position, + SageMakerHook, + argmin, + secondary_training_status_changed, secondary_training_status_message, ) warnings.warn( "This module is deprecated. 
Please use `airflow.providers.amazon.aws.hooks.sagemaker`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/salesforce_hook.py b/airflow/contrib/hooks/salesforce_hook.py index 5ea627a0b0674..d1462e1c490b8 100644 --- a/airflow/contrib/hooks/salesforce_hook.py +++ b/airflow/contrib/hooks/salesforce_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.salesforce.hooks.salesforce`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/segment_hook.py b/airflow/contrib/hooks/segment_hook.py index c286fbe405d6b..7e3247c1018b2 100644 --- a/airflow/contrib/hooks/segment_hook.py +++ b/airflow/contrib/hooks/segment_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.segment.hooks.segment`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/slack_webhook_hook.py b/airflow/contrib/hooks/slack_webhook_hook.py index 420a34bf7a90a..7e4ec92d73157 100644 --- a/airflow/contrib/hooks/slack_webhook_hook.py +++ b/airflow/contrib/hooks/slack_webhook_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.slack.hooks.slack_webhook`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/spark_jdbc_hook.py b/airflow/contrib/hooks/spark_jdbc_hook.py index af3ed2c0149bf..189e5bf65c6cb 100644 --- a/airflow/contrib/hooks/spark_jdbc_hook.py +++ b/airflow/contrib/hooks/spark_jdbc_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.spark.hooks.spark_jdbc`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/spark_sql_hook.py b/airflow/contrib/hooks/spark_sql_hook.py index 760c0e7808e82..639fae1392c9f 100644 --- a/airflow/contrib/hooks/spark_sql_hook.py +++ b/airflow/contrib/hooks/spark_sql_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.spark.hooks.spark_sql`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py index d067ba7f186ad..5c69e48c552fe 100644 --- a/airflow/contrib/hooks/spark_submit_hook.py +++ b/airflow/contrib/hooks/spark_submit_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.spark.hooks.spark_submit`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/sqoop_hook.py b/airflow/contrib/hooks/sqoop_hook.py index ccf1df3948929..ce5c737371a54 100644 --- a/airflow/contrib/hooks/sqoop_hook.py +++ b/airflow/contrib/hooks/sqoop_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.sqoop.hooks.sqoop`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/ssh_hook.py b/airflow/contrib/hooks/ssh_hook.py index 2f1b02e7c5295..d374339390f3d 100644 --- a/airflow/contrib/hooks/ssh_hook.py +++ b/airflow/contrib/hooks/ssh_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.ssh.hooks.ssh`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/vertica_hook.py b/airflow/contrib/hooks/vertica_hook.py index 9a38a496d2aba..592fccc2b3dd2 100644 --- a/airflow/contrib/hooks/vertica_hook.py +++ b/airflow/contrib/hooks/vertica_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.vertica.hooks.vertica`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/wasb_hook.py b/airflow/contrib/hooks/wasb_hook.py index 6cb1f66ad6dba..a1b62ad539bb4 100644 --- a/airflow/contrib/hooks/wasb_hook.py +++ b/airflow/contrib/hooks/wasb_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.wasb`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/hooks/winrm_hook.py b/airflow/contrib/hooks/winrm_hook.py index c45c825404d04..24e5abec28536 100644 --- a/airflow/contrib/hooks/winrm_hook.py +++ b/airflow/contrib/hooks/winrm_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.microsoft.winrm.hooks.winrm`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/adls_list_operator.py b/airflow/contrib/operators/adls_list_operator.py index f6643ecdf65c3..1572a22d72d22 100644 --- a/airflow/contrib/operators/adls_list_operator.py +++ b/airflow/contrib/operators/adls_list_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.adls_list`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/adls_to_gcs.py b/airflow/contrib/operators/adls_to_gcs.py index 0b273ca93a275..7839b4750728b 100644 --- a/airflow/contrib/operators/adls_to_gcs.py +++ b/airflow/contrib/operators/adls_to_gcs.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.adls_to_gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.transfers.adls_to_gcs.ADLSToGCSOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/azure_container_instances_operator.py b/airflow/contrib/operators/azure_container_instances_operator.py index 56ea84fb7bf27..7efd0084cccf3 100644 --- a/airflow/contrib/operators/azure_container_instances_operator.py +++ b/airflow/contrib/operators/azure_container_instances_operator.py @@ -29,5 +29,6 @@ warnings.warn( "This module is deprecated. " "Please use `airflow.providers.microsoft.azure.operators.azure_container_instances`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/azure_cosmos_operator.py b/airflow/contrib/operators/azure_cosmos_operator.py index 48eb2c2932dba..f08b8ddf1f5e9 100644 --- a/airflow/contrib/operators/azure_cosmos_operator.py +++ b/airflow/contrib/operators/azure_cosmos_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.microsoft.azure.operators.azure_cosmos`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/bigquery_check_operator.py b/airflow/contrib/operators/bigquery_check_operator.py index 29069fd55aebe..c6813946c1af1 100644 --- a/airflow/contrib/operators/bigquery_check_operator.py +++ b/airflow/contrib/operators/bigquery_check_operator.py @@ -21,10 +21,13 @@ # pylint: disable=unused-import from airflow.providers.google.cloud.operators.bigquery import ( # noqa - BigQueryCheckOperator, BigQueryIntervalCheckOperator, BigQueryValueCheckOperator, + BigQueryCheckOperator, + BigQueryIntervalCheckOperator, + BigQueryValueCheckOperator, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigquery`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/bigquery_get_data.py b/airflow/contrib/operators/bigquery_get_data.py index 9fe6be369907a..0a5a176f65031 100644 --- a/airflow/contrib/operators/bigquery_get_data.py +++ b/airflow/contrib/operators/bigquery_get_data.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigquery`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/bigquery_operator.py b/airflow/contrib/operators/bigquery_operator.py index c01c10366f538..32746264fa8e5 100644 --- a/airflow/contrib/operators/bigquery_operator.py +++ b/airflow/contrib/operators/bigquery_operator.py @@ -21,9 +21,15 @@ # pylint: disable=unused-import from airflow.providers.google.cloud.operators.bigquery import ( # noqa; noqa; noqa; noqa; noqa - BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryCreateExternalTableOperator, - BigQueryDeleteDatasetOperator, BigQueryExecuteQueryOperator, BigQueryGetDatasetOperator, - BigQueryGetDatasetTablesOperator, BigQueryPatchDatasetOperator, BigQueryUpdateDatasetOperator, + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryCreateExternalTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryExecuteQueryOperator, + BigQueryGetDatasetOperator, + BigQueryGetDatasetTablesOperator, + BigQueryPatchDatasetOperator, + BigQueryUpdateDatasetOperator, BigQueryUpsertTableOperator, ) diff --git a/airflow/contrib/operators/bigquery_to_bigquery.py b/airflow/contrib/operators/bigquery_to_bigquery.py index e585ccdf04a00..74df81e70ec91 100644 --- a/airflow/contrib/operators/bigquery_to_bigquery.py +++ b/airflow/contrib/operators/bigquery_to_bigquery.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.bigquery_to_bigquery`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/bigquery_to_gcs.py b/airflow/contrib/operators/bigquery_to_gcs.py index d51397da64ef2..243bc4343ee68 100644 --- a/airflow/contrib/operators/bigquery_to_gcs.py +++ b/airflow/contrib/operators/bigquery_to_gcs.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.bigquery_to_gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. 
Please use `airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryToGCSOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/bigquery_to_mysql_operator.py b/airflow/contrib/operators/bigquery_to_mysql_operator.py index 37fec73ac92b3..bebb8edc3de92 100644 --- a/airflow/contrib/operators/bigquery_to_mysql_operator.py +++ b/airflow/contrib/operators/bigquery_to_mysql_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.bigquery_to_mysql`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/cassandra_to_gcs.py b/airflow/contrib/operators/cassandra_to_gcs.py index 9ac70f8f1da46..3d889e686f351 100644 --- a/airflow/contrib/operators/cassandra_to_gcs.py +++ b/airflow/contrib/operators/cassandra_to_gcs.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.cassandra_to_gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py index a0b8e495d9d16..6f7c950b366dc 100644 --- a/airflow/contrib/operators/databricks_operator.py +++ b/airflow/contrib/operators/databricks_operator.py @@ -21,10 +21,12 @@ # pylint: disable=unused-import from airflow.providers.databricks.operators.databricks import ( # noqa - DatabricksRunNowOperator, DatabricksSubmitRunOperator, + DatabricksRunNowOperator, + DatabricksSubmitRunOperator, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.databricks.operators.databricks`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py index 1f1f79470eeab..3d03a8bf5f779 100644 --- a/airflow/contrib/operators/dataflow_operator.py +++ b/airflow/contrib/operators/dataflow_operator.py @@ -20,12 +20,15 @@ import warnings from airflow.providers.google.cloud.operators.dataflow import ( - DataflowCreateJavaJobOperator, DataflowCreatePythonJobOperator, DataflowTemplatedJobStartOperator, + DataflowCreateJavaJobOperator, + DataflowCreatePythonJobOperator, + DataflowTemplatedJobStartOperator, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.dataflow`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -39,7 +42,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.dataflow.DataflowCreateJavaJobOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -55,7 +59,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. 
Please use `airflow.providers.google.cloud.operators.dataflow.DataflowCreatePythonJobOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -71,6 +76,7 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/dataproc_operator.py b/airflow/contrib/operators/dataproc_operator.py index 0337b58eeff46..feac6274a9db0 100644 --- a/airflow/contrib/operators/dataproc_operator.py +++ b/airflow/contrib/operators/dataproc_operator.py @@ -20,11 +20,18 @@ import warnings from airflow.providers.google.cloud.operators.dataproc import ( - DataprocCreateClusterOperator, DataprocDeleteClusterOperator, - DataprocInstantiateInlineWorkflowTemplateOperator, DataprocInstantiateWorkflowTemplateOperator, - DataprocJobBaseOperator, DataprocScaleClusterOperator, DataprocSubmitHadoopJobOperator, - DataprocSubmitHiveJobOperator, DataprocSubmitPigJobOperator, DataprocSubmitPySparkJobOperator, - DataprocSubmitSparkJobOperator, DataprocSubmitSparkSqlJobOperator, + DataprocCreateClusterOperator, + DataprocDeleteClusterOperator, + DataprocInstantiateInlineWorkflowTemplateOperator, + DataprocInstantiateWorkflowTemplateOperator, + DataprocJobBaseOperator, + DataprocScaleClusterOperator, + DataprocSubmitHadoopJobOperator, + DataprocSubmitHiveJobOperator, + DataprocSubmitPigJobOperator, + DataprocSubmitPySparkJobOperator, + DataprocSubmitSparkJobOperator, + DataprocSubmitSparkSqlJobOperator, ) warnings.warn( @@ -44,7 +51,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -59,7 +67,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -74,7 +83,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -90,7 +100,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHadoopJobOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -106,7 +117,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHiveJobOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -121,7 +133,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. 
Please use `airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -136,7 +149,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPigJobOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -152,7 +166,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPySparkJobOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -168,7 +183,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkJobOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -184,7 +200,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkSqlJobOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -202,7 +219,8 @@ def __init__(self, *args, **kwargs): Please use `airflow.providers.google.cloud.operators.dataproc .DataprocInstantiateInlineWorkflowTemplateOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -220,6 +238,7 @@ def __init__(self, *args, **kwargs): Please use `airflow.providers.google.cloud.operators.dataproc .DataprocInstantiateWorkflowTemplateOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/datastore_export_operator.py b/airflow/contrib/operators/datastore_export_operator.py index e2aca08c8df1f..bc88e6e47e988 100644 --- a/airflow/contrib/operators/datastore_export_operator.py +++ b/airflow/contrib/operators/datastore_export_operator.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.datastore`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -38,6 +39,7 @@ def __init__(self, *args, **kwargs): """This class is deprecated.l Please use `airflow.providers.google.cloud.operators.datastore.CloudDatastoreExportEntitiesOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/datastore_import_operator.py b/airflow/contrib/operators/datastore_import_operator.py index 2bd8b2493b0d8..4375336588e9e 100644 --- a/airflow/contrib/operators/datastore_import_operator.py +++ b/airflow/contrib/operators/datastore_import_operator.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.datastore`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -38,6 +39,7 @@ def __init__(self, *args, **kwargs): """This class is deprecated. 
Please use `airflow.providers.google.cloud.operators.datastore.CloudDatastoreImportEntitiesOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/dingding_operator.py b/airflow/contrib/operators/dingding_operator.py index 7f911d46ed133..0fb52aacff631 100644 --- a/airflow/contrib/operators/dingding_operator.py +++ b/airflow/contrib/operators/dingding_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.dingding.operators.dingding`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/discord_webhook_operator.py b/airflow/contrib/operators/discord_webhook_operator.py index 51c12b840ad8c..d6392bc9c5a69 100644 --- a/airflow/contrib/operators/discord_webhook_operator.py +++ b/airflow/contrib/operators/discord_webhook_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.discord.operators.discord_webhook`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/docker_swarm_operator.py b/airflow/contrib/operators/docker_swarm_operator.py index 3ac517fd80b6c..796a60d0eb8f9 100644 --- a/airflow/contrib/operators/docker_swarm_operator.py +++ b/airflow/contrib/operators/docker_swarm_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.docker.operators.docker_swarm`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/druid_operator.py b/airflow/contrib/operators/druid_operator.py index 912aea975c9ee..ea3d1fb705177 100644 --- a/airflow/contrib/operators/druid_operator.py +++ b/airflow/contrib/operators/druid_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.druid.operators.druid`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/dynamodb_to_s3.py b/airflow/contrib/operators/dynamodb_to_s3.py index 748a69dc92263..28ade7edc50b3 100644 --- a/airflow/contrib/operators/dynamodb_to_s3.py +++ b/airflow/contrib/operators/dynamodb_to_s3.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.dynamodb_to_s3`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/ecs_operator.py b/airflow/contrib/operators/ecs_operator.py index 37fc32d8c59d9..8c8573847b7ca 100644 --- a/airflow/contrib/operators/ecs_operator.py +++ b/airflow/contrib/operators/ecs_operator.py @@ -25,7 +25,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ecs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/emr_add_steps_operator.py b/airflow/contrib/operators/emr_add_steps_operator.py index 19f2db0b119b3..421d4dcd3395b 100644 --- a/airflow/contrib/operators/emr_add_steps_operator.py +++ b/airflow/contrib/operators/emr_add_steps_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.amazon.aws.operators.emr_add_steps`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/emr_create_job_flow_operator.py b/airflow/contrib/operators/emr_create_job_flow_operator.py index eafa3111c0e02..bf59df5879360 100644 --- a/airflow/contrib/operators/emr_create_job_flow_operator.py +++ b/airflow/contrib/operators/emr_create_job_flow_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_create_job_flow`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/emr_terminate_job_flow_operator.py b/airflow/contrib/operators/emr_terminate_job_flow_operator.py index 281f64cc80ae9..0158399413427 100644 --- a/airflow/contrib/operators/emr_terminate_job_flow_operator.py +++ b/airflow/contrib/operators/emr_terminate_job_flow_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_terminate_job_flow`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/file_to_gcs.py b/airflow/contrib/operators/file_to_gcs.py index 7b7b259d33f15..4e696a56e91c4 100644 --- a/airflow/contrib/operators/file_to_gcs.py +++ b/airflow/contrib/operators/file_to_gcs.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.local_to_gcs`,", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -38,6 +39,7 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.transfers.local_to_gcs.LocalFilesystemToGCSOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/file_to_wasb.py b/airflow/contrib/operators/file_to_wasb.py index 8ed0da71336ca..3c9866cbe664c 100644 --- a/airflow/contrib/operators/file_to_wasb.py +++ b/airflow/contrib/operators/file_to_wasb.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.microsoft.azure.transfers.file_to_wasb`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/gcp_bigtable_operator.py b/airflow/contrib/operators/gcp_bigtable_operator.py index 507a55f79dfce..7256d64052d14 100644 --- a/airflow/contrib/operators/gcp_bigtable_operator.py +++ b/airflow/contrib/operators/gcp_bigtable_operator.py @@ -23,15 +23,19 @@ import warnings from airflow.providers.google.cloud.operators.bigtable import ( - BigtableCreateInstanceOperator, BigtableCreateTableOperator, BigtableDeleteInstanceOperator, - BigtableDeleteTableOperator, BigtableUpdateClusterOperator, + BigtableCreateInstanceOperator, + BigtableCreateTableOperator, + BigtableDeleteInstanceOperator, + BigtableDeleteTableOperator, + BigtableUpdateClusterOperator, ) from airflow.providers.google.cloud.sensors.bigtable import BigtableTableReplicationCompletedSensor warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigtable`" " or `airflow.providers.google.cloud.sensors.bigtable`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -45,7 +49,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. 
Please use `airflow.providers.google.cloud.operators.bigtable.BigtableUpdateClusterOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -60,7 +65,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.bigtable.BigtableCreateInstanceOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -75,7 +81,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.bigtable.BigtableDeleteInstanceOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -90,7 +97,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.bigtable.BigtableCreateTableOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -105,7 +113,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.bigtable.BigtableDeleteTableOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -122,6 +131,7 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.sensors.bigtable.BigtableTableReplicationCompletedSensor`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_cloud_build_operator.py b/airflow/contrib/operators/gcp_cloud_build_operator.py index 318ca765c0aaa..ea585089b6440 100644 --- a/airflow/contrib/operators/gcp_cloud_build_operator.py +++ b/airflow/contrib/operators/gcp_cloud_build_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.cloud_build`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/gcp_compute_operator.py b/airflow/contrib/operators/gcp_compute_operator.py index 968b4b2df1299..efd513e193ab7 100644 --- a/airflow/contrib/operators/gcp_compute_operator.py +++ b/airflow/contrib/operators/gcp_compute_operator.py @@ -20,14 +20,18 @@ import warnings from airflow.providers.google.cloud.operators.compute import ( - ComputeEngineBaseOperator, ComputeEngineCopyInstanceTemplateOperator, - ComputeEngineInstanceGroupUpdateManagerTemplateOperator, ComputeEngineSetMachineTypeOperator, - ComputeEngineStartInstanceOperator, ComputeEngineStopInstanceOperator, + ComputeEngineBaseOperator, + ComputeEngineCopyInstanceTemplateOperator, + ComputeEngineInstanceGroupUpdateManagerTemplateOperator, + ComputeEngineSetMachineTypeOperator, + ComputeEngineStartInstanceOperator, + ComputeEngineStopInstanceOperator, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.compute`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -41,7 +45,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. 
Please use `airflow.providers.google.cloud.operators.compute.ComputeEngineBaseOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -58,7 +63,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.compute .ComputeEngineInstanceGroupUpdateManagerTemplateOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -75,7 +81,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.compute .ComputeEngineStartInstanceOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -91,7 +98,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.compute .ComputeEngineStopInstanceOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -107,7 +115,8 @@ def __init__(self, *args, **kwargs): """"This class is deprecated. Please use `airflow.providers.google.cloud.operators.compute .ComputeEngineCopyInstanceTemplateOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -123,6 +132,7 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.compute .ComputeEngineSetMachineTypeOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_container_operator.py b/airflow/contrib/operators/gcp_container_operator.py index 886571d3d0dd9..449bfbef62493 100644 --- a/airflow/contrib/operators/gcp_container_operator.py +++ b/airflow/contrib/operators/gcp_container_operator.py @@ -20,12 +20,15 @@ import warnings from airflow.providers.google.cloud.operators.kubernetes_engine import ( - GKECreateClusterOperator, GKEDeleteClusterOperator, GKEStartPodOperator, + GKECreateClusterOperator, + GKEDeleteClusterOperator, + GKEStartPodOperator, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.kubernetes_engine`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -39,7 +42,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.container.GKECreateClusterOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -54,7 +58,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.container.GKEDeleteClusterOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -69,6 +74,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. 
Please use `airflow.providers.google.cloud.operators.container.GKEStartPodOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_dlp_operator.py b/airflow/contrib/operators/gcp_dlp_operator.py index cbf5194089ef3..99d373fd49488 100644 --- a/airflow/contrib/operators/gcp_dlp_operator.py +++ b/airflow/contrib/operators/gcp_dlp_operator.py @@ -21,23 +21,42 @@ # pylint: disable=unused-import from airflow.providers.google.cloud.operators.dlp import ( # noqa - CloudDLPCancelDLPJobOperator, CloudDLPCreateDeidentifyTemplateOperator, CloudDLPCreateDLPJobOperator, - CloudDLPCreateInspectTemplateOperator, CloudDLPCreateJobTriggerOperator, - CloudDLPCreateStoredInfoTypeOperator, CloudDLPDeidentifyContentOperator, - CloudDLPDeleteDeidentifyTemplateOperator, CloudDLPDeleteDLPJobOperator, - CloudDLPDeleteInspectTemplateOperator, CloudDLPDeleteJobTriggerOperator, - CloudDLPDeleteStoredInfoTypeOperator, CloudDLPGetDeidentifyTemplateOperator, CloudDLPGetDLPJobOperator, - CloudDLPGetDLPJobTriggerOperator, CloudDLPGetInspectTemplateOperator, CloudDLPGetStoredInfoTypeOperator, - CloudDLPInspectContentOperator, CloudDLPListDeidentifyTemplatesOperator, CloudDLPListDLPJobsOperator, - CloudDLPListInfoTypesOperator, CloudDLPListInspectTemplatesOperator, CloudDLPListJobTriggersOperator, - CloudDLPListStoredInfoTypesOperator, CloudDLPRedactImageOperator, CloudDLPReidentifyContentOperator, - CloudDLPUpdateDeidentifyTemplateOperator, CloudDLPUpdateInspectTemplateOperator, - CloudDLPUpdateJobTriggerOperator, CloudDLPUpdateStoredInfoTypeOperator, + CloudDLPCancelDLPJobOperator, + CloudDLPCreateDeidentifyTemplateOperator, + CloudDLPCreateDLPJobOperator, + CloudDLPCreateInspectTemplateOperator, + CloudDLPCreateJobTriggerOperator, + CloudDLPCreateStoredInfoTypeOperator, + CloudDLPDeidentifyContentOperator, + CloudDLPDeleteDeidentifyTemplateOperator, + CloudDLPDeleteDLPJobOperator, + CloudDLPDeleteInspectTemplateOperator, + CloudDLPDeleteJobTriggerOperator, + CloudDLPDeleteStoredInfoTypeOperator, + CloudDLPGetDeidentifyTemplateOperator, + CloudDLPGetDLPJobOperator, + CloudDLPGetDLPJobTriggerOperator, + CloudDLPGetInspectTemplateOperator, + CloudDLPGetStoredInfoTypeOperator, + CloudDLPInspectContentOperator, + CloudDLPListDeidentifyTemplatesOperator, + CloudDLPListDLPJobsOperator, + CloudDLPListInfoTypesOperator, + CloudDLPListInspectTemplatesOperator, + CloudDLPListJobTriggersOperator, + CloudDLPListStoredInfoTypesOperator, + CloudDLPRedactImageOperator, + CloudDLPReidentifyContentOperator, + CloudDLPUpdateDeidentifyTemplateOperator, + CloudDLPUpdateInspectTemplateOperator, + CloudDLPUpdateJobTriggerOperator, + CloudDLPUpdateStoredInfoTypeOperator, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.dlp`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -51,8 +70,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteDLPJobOperator`.""", - - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -67,7 +86,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. 
Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -82,7 +102,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobTriggerOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -97,6 +118,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPListDLPJobsOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_function_operator.py b/airflow/contrib/operators/gcp_function_operator.py index 664f5e5ed8274..cebf4e70619a5 100644 --- a/airflow/contrib/operators/gcp_function_operator.py +++ b/airflow/contrib/operators/gcp_function_operator.py @@ -20,12 +20,14 @@ import warnings from airflow.providers.google.cloud.operators.functions import ( - CloudFunctionDeleteFunctionOperator, CloudFunctionDeployFunctionOperator, + CloudFunctionDeleteFunctionOperator, + CloudFunctionDeployFunctionOperator, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.functions`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -40,7 +42,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.function.CloudFunctionDeleteFunctionOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -56,6 +59,7 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.function.CloudFunctionDeployFunctionOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_natural_language_operator.py b/airflow/contrib/operators/gcp_natural_language_operator.py index daac37eca33eb..e9ec096b5771e 100644 --- a/airflow/contrib/operators/gcp_natural_language_operator.py +++ b/airflow/contrib/operators/gcp_natural_language_operator.py @@ -20,15 +20,18 @@ import warnings from airflow.providers.google.cloud.operators.natural_language import ( - CloudNaturalLanguageAnalyzeEntitiesOperator, CloudNaturalLanguageAnalyzeEntitySentimentOperator, - CloudNaturalLanguageAnalyzeSentimentOperator, CloudNaturalLanguageClassifyTextOperator, + CloudNaturalLanguageAnalyzeEntitiesOperator, + CloudNaturalLanguageAnalyzeEntitySentimentOperator, + CloudNaturalLanguageAnalyzeSentimentOperator, + CloudNaturalLanguageClassifyTextOperator, ) warnings.warn( """This module is deprecated. Please use `airflow.providers.google.cloud.operators.natural_language` """, - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -46,7 +49,8 @@ def __init__(self, *args, **kwargs): `airflow.providers.google.cloud.operators.natural_language .CloudNaturalLanguageAnalyzeEntitiesOperator`. """, - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -66,7 +70,8 @@ def __init__(self, *args, **kwargs): `airflow.providers.google.cloud.operators.natural_language .CloudNaturalLanguageAnalyzeEntitySentimentOperator`. 
""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -84,7 +89,8 @@ def __init__(self, *args, **kwargs): Please use `airflow.providers.google.cloud.operators.natural_language .CloudNaturalLanguageAnalyzeSentimentOperator`. """, - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -102,6 +108,7 @@ def __init__(self, *args, **kwargs): Please use `airflow.providers.google.cloud.operators.natural_language .CloudNaturalLanguageClassifyTextOperator`. """, - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_spanner_operator.py b/airflow/contrib/operators/gcp_spanner_operator.py index bd5f4fadd3433..d15c7c0dc430c 100644 --- a/airflow/contrib/operators/gcp_spanner_operator.py +++ b/airflow/contrib/operators/gcp_spanner_operator.py @@ -20,9 +20,12 @@ import warnings from airflow.providers.google.cloud.operators.spanner import ( - SpannerDeleteDatabaseInstanceOperator, SpannerDeleteInstanceOperator, - SpannerDeployDatabaseInstanceOperator, SpannerDeployInstanceOperator, - SpannerQueryDatabaseInstanceOperator, SpannerUpdateDatabaseInstanceOperator, + SpannerDeleteDatabaseInstanceOperator, + SpannerDeleteInstanceOperator, + SpannerDeployDatabaseInstanceOperator, + SpannerDeployInstanceOperator, + SpannerQueryDatabaseInstanceOperator, + SpannerUpdateDatabaseInstanceOperator, ) warnings.warn( @@ -40,7 +43,9 @@ class CloudSpannerInstanceDatabaseDeleteOperator(SpannerDeleteDatabaseInstanceOp def __init__(self, *args, **kwargs): warnings.warn( - self.__doc__, DeprecationWarning, stacklevel=3, + self.__doc__, + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -53,7 +58,9 @@ class CloudSpannerInstanceDatabaseDeployOperator(SpannerDeployDatabaseInstanceOp def __init__(self, *args, **kwargs): warnings.warn( - self.__doc__, DeprecationWarning, stacklevel=3, + self.__doc__, + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -66,7 +73,9 @@ class CloudSpannerInstanceDatabaseQueryOperator(SpannerQueryDatabaseInstanceOper def __init__(self, *args, **kwargs): warnings.warn( - self.__doc__, DeprecationWarning, stacklevel=3, + self.__doc__, + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -79,7 +88,9 @@ class CloudSpannerInstanceDatabaseUpdateOperator(SpannerUpdateDatabaseInstanceOp def __init__(self, *args, **kwargs): warnings.warn( - self.__doc__, DeprecationWarning, stacklevel=3, + self.__doc__, + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -92,7 +103,9 @@ class CloudSpannerInstanceDeleteOperator(SpannerDeleteInstanceOperator): def __init__(self, *args, **kwargs): warnings.warn( - self.__doc__, DeprecationWarning, stacklevel=3, + self.__doc__, + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -105,6 +118,8 @@ class CloudSpannerInstanceDeployOperator(SpannerDeployInstanceOperator): def __init__(self, *args, **kwargs): warnings.warn( - self.__doc__, DeprecationWarning, stacklevel=3, + self.__doc__, + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_speech_to_text_operator.py b/airflow/contrib/operators/gcp_speech_to_text_operator.py index 4d2c644096a00..c24a027cdcaac 100644 --- a/airflow/contrib/operators/gcp_speech_to_text_operator.py +++ b/airflow/contrib/operators/gcp_speech_to_text_operator.py @@ 
-23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.speech_to_text`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -40,6 +41,7 @@ def __init__(self, *args, **kwargs): Please use `airflow.providers.google.cloud.operators.speech_to_text .CloudSpeechToTextRecognizeSpeechOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_sql_operator.py b/airflow/contrib/operators/gcp_sql_operator.py index 00a5b6edb399d..9ec00f8fd1048 100644 --- a/airflow/contrib/operators/gcp_sql_operator.py +++ b/airflow/contrib/operators/gcp_sql_operator.py @@ -20,15 +20,22 @@ import warnings from airflow.providers.google.cloud.operators.cloud_sql import ( - CloudSQLBaseOperator, CloudSQLCreateInstanceDatabaseOperator, CloudSQLCreateInstanceOperator, - CloudSQLDeleteInstanceDatabaseOperator, CloudSQLDeleteInstanceOperator, CloudSQLExecuteQueryOperator, - CloudSQLExportInstanceOperator, CloudSQLImportInstanceOperator, CloudSQLInstancePatchOperator, + CloudSQLBaseOperator, + CloudSQLCreateInstanceDatabaseOperator, + CloudSQLCreateInstanceOperator, + CloudSQLDeleteInstanceDatabaseOperator, + CloudSQLDeleteInstanceOperator, + CloudSQLExecuteQueryOperator, + CloudSQLExportInstanceOperator, + CloudSQLImportInstanceOperator, + CloudSQLInstancePatchOperator, CloudSQLPatchInstanceDatabaseOperator, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.cloud_sql`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/gcp_tasks_operator.py b/airflow/contrib/operators/gcp_tasks_operator.py index 1e39bdd0f7a70..e1a53242cb43d 100644 --- a/airflow/contrib/operators/gcp_tasks_operator.py +++ b/airflow/contrib/operators/gcp_tasks_operator.py @@ -21,14 +21,23 @@ # pylint: disable=unused-import from airflow.providers.google.cloud.operators.tasks import ( # noqa - CloudTasksQueueCreateOperator, CloudTasksQueueDeleteOperator, CloudTasksQueueGetOperator, - CloudTasksQueuePauseOperator, CloudTasksQueuePurgeOperator, CloudTasksQueueResumeOperator, - CloudTasksQueuesListOperator, CloudTasksQueueUpdateOperator, CloudTasksTaskCreateOperator, - CloudTasksTaskDeleteOperator, CloudTasksTaskGetOperator, CloudTasksTaskRunOperator, + CloudTasksQueueCreateOperator, + CloudTasksQueueDeleteOperator, + CloudTasksQueueGetOperator, + CloudTasksQueuePauseOperator, + CloudTasksQueuePurgeOperator, + CloudTasksQueueResumeOperator, + CloudTasksQueuesListOperator, + CloudTasksQueueUpdateOperator, + CloudTasksTaskCreateOperator, + CloudTasksTaskDeleteOperator, + CloudTasksTaskGetOperator, + CloudTasksTaskRunOperator, CloudTasksTasksListOperator, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.tasks`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/gcp_text_to_speech_operator.py b/airflow/contrib/operators/gcp_text_to_speech_operator.py index 9a43b7cee0a0f..03c2bcde9cae2 100644 --- a/airflow/contrib/operators/gcp_text_to_speech_operator.py +++ b/airflow/contrib/operators/gcp_text_to_speech_operator.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.google.cloud.operators.text_to_speech`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -38,6 +39,7 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.text_to_speech.CloudTextToSpeechSynthesizeOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_transfer_operator.py b/airflow/contrib/operators/gcp_transfer_operator.py index 2f5f0a1cfd7da..a939fdc9e5c82 100644 --- a/airflow/contrib/operators/gcp_transfer_operator.py +++ b/airflow/contrib/operators/gcp_transfer_operator.py @@ -23,17 +23,23 @@ import warnings from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import ( - CloudDataTransferServiceCancelOperationOperator, CloudDataTransferServiceCreateJobOperator, - CloudDataTransferServiceDeleteJobOperator, CloudDataTransferServiceGCSToGCSOperator, - CloudDataTransferServiceGetOperationOperator, CloudDataTransferServiceListOperationsOperator, - CloudDataTransferServicePauseOperationOperator, CloudDataTransferServiceResumeOperationOperator, - CloudDataTransferServiceS3ToGCSOperator, CloudDataTransferServiceUpdateJobOperator, + CloudDataTransferServiceCancelOperationOperator, + CloudDataTransferServiceCreateJobOperator, + CloudDataTransferServiceDeleteJobOperator, + CloudDataTransferServiceGCSToGCSOperator, + CloudDataTransferServiceGetOperationOperator, + CloudDataTransferServiceListOperationsOperator, + CloudDataTransferServicePauseOperationOperator, + CloudDataTransferServiceResumeOperationOperator, + CloudDataTransferServiceS3ToGCSOperator, + CloudDataTransferServiceUpdateJobOperator, ) warnings.warn( "This module is deprecated. " "Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -49,7 +55,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.data_transfer .CloudDataTransferServiceCreateJobOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -66,7 +73,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.data_transfer .CloudDataTransferServiceDeleteJobOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -83,7 +91,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.data_transfer .CloudDataTransferServiceUpdateJobOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -101,7 +110,8 @@ def __init__(self, *args, **kwargs): Please use `airflow.providers.google.cloud.operators.data_transfer .CloudDataTransferServiceCancelOperationOperator`. """, - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -118,7 +128,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. 
Please use `airflow.providers.google.cloud.operators.data_transfer .CloudDataTransferServiceGetOperationOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -136,7 +147,8 @@ def __init__(self, *args, **kwargs): Please use `airflow.providers.google.cloud.operators.data_transfer .CloudDataTransferServicePauseOperationOperator`. """, - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -154,7 +166,8 @@ def __init__(self, *args, **kwargs): Please use `airflow.providers.google.cloud.operators.data_transfer .CloudDataTransferServiceResumeOperationOperator`. """, - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -172,7 +185,8 @@ def __init__(self, *args, **kwargs): Please use `airflow.providers.google.cloud.operators.data_transfer .CloudDataTransferServiceListOperationsOperator`. """, - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -190,7 +204,8 @@ def __init__(self, *args, **kwargs): Please use `airflow.providers.google.cloud.operators.data_transfer .CloudDataTransferServiceGCSToGCSOperator`. """, - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -208,6 +223,7 @@ def __init__(self, *args, **kwargs): Please use `airflow.providers.google.cloud.operators.data_transfer .CloudDataTransferServiceS3ToGCSOperator`. """, - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_translate_operator.py b/airflow/contrib/operators/gcp_translate_operator.py index 8b1c939d20c34..22f302d150515 100644 --- a/airflow/contrib/operators/gcp_translate_operator.py +++ b/airflow/contrib/operators/gcp_translate_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.translate`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/gcp_translate_speech_operator.py b/airflow/contrib/operators/gcp_translate_speech_operator.py index 121158d6e552e..724697852639d 100644 --- a/airflow/contrib/operators/gcp_translate_speech_operator.py +++ b/airflow/contrib/operators/gcp_translate_speech_operator.py @@ -24,7 +24,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.translate_speech`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -40,6 +41,7 @@ def __init__(self, *args, **kwargs): Please use `airflow.providers.google.cloud.operators.translate_speech.CloudTranslateSpeechOperator`. 
""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_video_intelligence_operator.py b/airflow/contrib/operators/gcp_video_intelligence_operator.py index 401c70d577c56..3e08eb6d0e2db 100644 --- a/airflow/contrib/operators/gcp_video_intelligence_operator.py +++ b/airflow/contrib/operators/gcp_video_intelligence_operator.py @@ -21,11 +21,13 @@ # pylint: disable=unused-import from airflow.providers.google.cloud.operators.video_intelligence import ( # noqa - CloudVideoIntelligenceDetectVideoExplicitContentOperator, CloudVideoIntelligenceDetectVideoLabelsOperator, + CloudVideoIntelligenceDetectVideoExplicitContentOperator, + CloudVideoIntelligenceDetectVideoLabelsOperator, CloudVideoIntelligenceDetectVideoShotsOperator, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.video_intelligence`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/gcp_vision_operator.py b/airflow/contrib/operators/gcp_vision_operator.py index 99f89737882fc..3975f775ea577 100644 --- a/airflow/contrib/operators/gcp_vision_operator.py +++ b/airflow/contrib/operators/gcp_vision_operator.py @@ -20,18 +20,28 @@ import warnings from airflow.providers.google.cloud.operators.vision import ( # noqa # pylint: disable=unused-import - CloudVisionAddProductToProductSetOperator, CloudVisionCreateProductOperator, - CloudVisionCreateProductSetOperator, CloudVisionCreateReferenceImageOperator, - CloudVisionDeleteProductOperator, CloudVisionDeleteProductSetOperator, - CloudVisionDetectImageLabelsOperator, CloudVisionDetectImageSafeSearchOperator, - CloudVisionDetectTextOperator, CloudVisionGetProductOperator, CloudVisionGetProductSetOperator, - CloudVisionImageAnnotateOperator, CloudVisionRemoveProductFromProductSetOperator, - CloudVisionTextDetectOperator, CloudVisionUpdateProductOperator, CloudVisionUpdateProductSetOperator, + CloudVisionAddProductToProductSetOperator, + CloudVisionCreateProductOperator, + CloudVisionCreateProductSetOperator, + CloudVisionCreateReferenceImageOperator, + CloudVisionDeleteProductOperator, + CloudVisionDeleteProductSetOperator, + CloudVisionDetectImageLabelsOperator, + CloudVisionDetectImageSafeSearchOperator, + CloudVisionDetectTextOperator, + CloudVisionGetProductOperator, + CloudVisionGetProductSetOperator, + CloudVisionImageAnnotateOperator, + CloudVisionRemoveProductFromProductSetOperator, + CloudVisionTextDetectOperator, + CloudVisionUpdateProductOperator, + CloudVisionUpdateProductSetOperator, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.vision`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -45,7 +55,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.vision.CloudVisionImageAnnotateOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -60,7 +71,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.vision.CloudVisionTextDetectOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -75,7 +87,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. 
Please use `airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -90,7 +103,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -105,7 +119,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.vision.CloudVisionGetProductOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -121,7 +136,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductSetOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -137,7 +153,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductSetOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -153,7 +170,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.vision.CloudVisionGetProductSetOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -169,7 +187,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductSetOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -185,7 +204,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -201,6 +221,7 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.vision.CloudVisionCreateReferenceImageOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_acl_operator.py b/airflow/contrib/operators/gcs_acl_operator.py index 4803c994d7dbe..efd81d9f225f9 100644 --- a/airflow/contrib/operators/gcs_acl_operator.py +++ b/airflow/contrib/operators/gcs_acl_operator.py @@ -20,12 +20,14 @@ import warnings from airflow.providers.google.cloud.operators.gcs import ( - GCSBucketCreateAclEntryOperator, GCSObjectCreateAclEntryOperator, + GCSBucketCreateAclEntryOperator, + GCSObjectCreateAclEntryOperator, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -39,7 +41,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. 
Please use `airflow.providers.google.cloud.operators.gcs.GCSBucketCreateAclEntryOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -54,6 +57,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.gcs.GCSObjectCreateAclEntryOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_delete_operator.py b/airflow/contrib/operators/gcs_delete_operator.py index e02926f7a9890..4ea63fe3e9a75 100644 --- a/airflow/contrib/operators/gcs_delete_operator.py +++ b/airflow/contrib/operators/gcs_delete_operator.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.gcs.GCSDeleteObjectsOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_download_operator.py b/airflow/contrib/operators/gcs_download_operator.py index 43f959d529f7c..bfc421ac83840 100644 --- a/airflow/contrib/operators/gcs_download_operator.py +++ b/airflow/contrib/operators/gcs_download_operator.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.gcs.GCSToLocalFilesystemOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_list_operator.py b/airflow/contrib/operators/gcs_list_operator.py index de17ce6d03daa..5ac4a9525935c 100644 --- a/airflow/contrib/operators/gcs_list_operator.py +++ b/airflow/contrib/operators/gcs_list_operator.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.gcs.GCSListObjectsOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_operator.py b/airflow/contrib/operators/gcs_operator.py index bef41c5fc4dab..72975bc0f2ad8 100644 --- a/airflow/contrib/operators/gcs_operator.py +++ b/airflow/contrib/operators/gcs_operator.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. 
Please use `airflow.providers.google.cloud.operators.gcs.GCSCreateBucketOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py index 27af7b7953a74..5a9c4546fd52e 100644 --- a/airflow/contrib/operators/gcs_to_bq.py +++ b/airflow/contrib/operators/gcs_to_bq.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.gcs_to_bigquery`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.transfers.gcs_to_bq.GCSToBigQueryOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_to_gcs.py b/airflow/contrib/operators/gcs_to_gcs.py index 4737fa0bf06f3..ca02151657e8d 100644 --- a/airflow/contrib/operators/gcs_to_gcs.py +++ b/airflow/contrib/operators/gcs_to_gcs.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.gcs_to_gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSToGCSOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py b/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py index fa1ec22012fac..99d7ca2d805cd 100644 --- a/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py +++ b/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py @@ -27,5 +27,6 @@ warnings.warn( "This module is deprecated. " "Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/gcs_to_gdrive_operator.py b/airflow/contrib/operators/gcs_to_gdrive_operator.py index 6ec7462d950cf..67f312d6ba39b 100644 --- a/airflow/contrib/operators/gcs_to_gdrive_operator.py +++ b/airflow/contrib/operators/gcs_to_gdrive_operator.py @@ -23,7 +23,7 @@ from airflow.providers.google.suite.transfers.gcs_to_gdrive import GCSToGoogleDriveOperator # noqa warnings.warn( - "This module is deprecated. " - "Please use `airflow.providers.google.suite.transfers.gcs_to_gdrive.", - DeprecationWarning, stacklevel=2 + "This module is deprecated. " "Please use `airflow.providers.google.suite.transfers.gcs_to_gdrive.", + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/gcs_to_s3.py b/airflow/contrib/operators/gcs_to_s3.py index dd358ee44cc87..ed7fef6fb767f 100644 --- a/airflow/contrib/operators/gcs_to_s3.py +++ b/airflow/contrib/operators/gcs_to_s3.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.gcs_to_s3`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs): warnings.warn( "This class is deprecated. 
" "Please use `airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSToS3Operator`.", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/grpc_operator.py b/airflow/contrib/operators/grpc_operator.py index 03a9b6e249dae..abda1207a14ad 100644 --- a/airflow/contrib/operators/grpc_operator.py +++ b/airflow/contrib/operators/grpc_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.grpc.operators.grpc`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/hive_to_dynamodb.py b/airflow/contrib/operators/hive_to_dynamodb.py index 6784680272273..48959601c0eea 100644 --- a/airflow/contrib/operators/hive_to_dynamodb.py +++ b/airflow/contrib/operators/hive_to_dynamodb.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.hive_to_dynamodb`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/imap_attachment_to_s3_operator.py b/airflow/contrib/operators/imap_attachment_to_s3_operator.py index 597d6beba0cd3..316da8a2d8e6a 100644 --- a/airflow/contrib/operators/imap_attachment_to_s3_operator.py +++ b/airflow/contrib/operators/imap_attachment_to_s3_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.imap_attachment_to_s3`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/jenkins_job_trigger_operator.py b/airflow/contrib/operators/jenkins_job_trigger_operator.py index 13f68ae3b686b..538a0a2a8025c 100644 --- a/airflow/contrib/operators/jenkins_job_trigger_operator.py +++ b/airflow/contrib/operators/jenkins_job_trigger_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.jenkins.operators.jenkins_job_trigger`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/jira_operator.py b/airflow/contrib/operators/jira_operator.py index d7e977f6138a1..d84662972adf5 100644 --- a/airflow/contrib/operators/jira_operator.py +++ b/airflow/contrib/operators/jira_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.jira.operators.jira`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py index 2663ceddf2f63..a6587350e7101 100644 --- a/airflow/contrib/operators/kubernetes_pod_operator.py +++ b/airflow/contrib/operators/kubernetes_pod_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.cncf.kubernetes.operators.kubernetes_pod`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/mlengine_operator.py b/airflow/contrib/operators/mlengine_operator.py index fd6b530bb8bfd..149ff82180177 100644 --- a/airflow/contrib/operators/mlengine_operator.py +++ b/airflow/contrib/operators/mlengine_operator.py @@ -20,13 +20,16 @@ import warnings from airflow.providers.google.cloud.operators.mlengine import ( - MLEngineManageModelOperator, MLEngineManageVersionOperator, MLEngineStartBatchPredictionJobOperator, + MLEngineManageModelOperator, + MLEngineManageVersionOperator, + MLEngineStartBatchPredictionJobOperator, MLEngineStartTrainingJobOperator, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.operators.mlengine`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -41,7 +44,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -56,7 +60,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.operators.mlengine.MLEngineManageModelOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -72,7 +77,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -88,6 +94,7 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.operators.mlengine.MLEngineManageVersionOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/mongo_to_s3.py b/airflow/contrib/operators/mongo_to_s3.py index 82449ee5d7a8a..fe8bd1d3fef2a 100644 --- a/airflow/contrib/operators/mongo_to_s3.py +++ b/airflow/contrib/operators/mongo_to_s3.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.mongo_to_s3`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/mssql_to_gcs.py b/airflow/contrib/operators/mssql_to_gcs.py index b094c27f59f11..3527327c5c34b 100644 --- a/airflow/contrib/operators/mssql_to_gcs.py +++ b/airflow/contrib/operators/mssql_to_gcs.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.mssql_to_gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. 
Please use `airflow.providers.google.cloud.transfers.mssql_to_gcs.MSSQLToGCSOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/mysql_to_gcs.py b/airflow/contrib/operators/mysql_to_gcs.py index dbb5cdfccb804..73a137527e9a0 100644 --- a/airflow/contrib/operators/mysql_to_gcs.py +++ b/airflow/contrib/operators/mysql_to_gcs.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.mysql_to_gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.transfers.mysql_to_gcs.MySQLToGCSOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/opsgenie_alert_operator.py b/airflow/contrib/operators/opsgenie_alert_operator.py index 877b87affe786..acf036aba7a34 100644 --- a/airflow/contrib/operators/opsgenie_alert_operator.py +++ b/airflow/contrib/operators/opsgenie_alert_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.opsgenie.operators.opsgenie_alert`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py b/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py index 2cdca434975dd..c27d94826fb39 100644 --- a/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py +++ b/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py @@ -30,5 +30,6 @@ warnings.warn( "This module is deprecated. " "Please use `airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/oracle_to_oracle_transfer.py b/airflow/contrib/operators/oracle_to_oracle_transfer.py index e522951963015..28a39c61e73f8 100644 --- a/airflow/contrib/operators/oracle_to_oracle_transfer.py +++ b/airflow/contrib/operators/oracle_to_oracle_transfer.py @@ -27,7 +27,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.oracle.transfers.oracle_to_oracle`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -43,6 +44,7 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.oracle.transfers.oracle_to_oracle.OracleToOracleOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/postgres_to_gcs_operator.py b/airflow/contrib/operators/postgres_to_gcs_operator.py index 0cbdca34fa140..62f2d06e202a7 100644 --- a/airflow/contrib/operators/postgres_to_gcs_operator.py +++ b/airflow/contrib/operators/postgres_to_gcs_operator.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.postgres_to_gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. 
Please use `airflow.providers.google.cloud.transfers.postgres_to_gcs.PostgresToGCSOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/pubsub_operator.py b/airflow/contrib/operators/pubsub_operator.py index d2bab4fa1e91f..6527c8f2f5a1b 100644 --- a/airflow/contrib/operators/pubsub_operator.py +++ b/airflow/contrib/operators/pubsub_operator.py @@ -23,8 +23,11 @@ import warnings from airflow.providers.google.cloud.operators.pubsub import ( - PubSubCreateSubscriptionOperator, PubSubCreateTopicOperator, PubSubDeleteSubscriptionOperator, - PubSubDeleteTopicOperator, PubSubPublishMessageOperator, + PubSubCreateSubscriptionOperator, + PubSubCreateTopicOperator, + PubSubDeleteSubscriptionOperator, + PubSubDeleteTopicOperator, + PubSubPublishMessageOperator, ) warnings.warn( diff --git a/airflow/contrib/operators/qubole_check_operator.py b/airflow/contrib/operators/qubole_check_operator.py index 7436d67061a27..6eefe0bc8c199 100644 --- a/airflow/contrib/operators/qubole_check_operator.py +++ b/airflow/contrib/operators/qubole_check_operator.py @@ -21,10 +21,12 @@ # pylint: disable=unused-import from airflow.providers.qubole.operators.qubole_check import ( # noqa - QuboleCheckOperator, QuboleValueCheckOperator, + QuboleCheckOperator, + QuboleValueCheckOperator, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.qubole.operators.qubole_check`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/qubole_operator.py b/airflow/contrib/operators/qubole_operator.py index 86cc465d284d2..521c1dd0f199b 100644 --- a/airflow/contrib/operators/qubole_operator.py +++ b/airflow/contrib/operators/qubole_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.qubole.operators.qubole`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/redis_publish_operator.py b/airflow/contrib/operators/redis_publish_operator.py index f84aaed477813..adee8f12c749d 100644 --- a/airflow/contrib/operators/redis_publish_operator.py +++ b/airflow/contrib/operators/redis_publish_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.redis.operators.redis_publish`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/s3_copy_object_operator.py b/airflow/contrib/operators/s3_copy_object_operator.py index 44d34c3dbe71d..81a92154560ac 100644 --- a/airflow/contrib/operators/s3_copy_object_operator.py +++ b/airflow/contrib/operators/s3_copy_object_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_copy_object`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/s3_delete_objects_operator.py b/airflow/contrib/operators/s3_delete_objects_operator.py index 211524e338499..39ab997e71f4f 100644 --- a/airflow/contrib/operators/s3_delete_objects_operator.py +++ b/airflow/contrib/operators/s3_delete_objects_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.amazon.aws.operators.s3_delete_objects`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/s3_list_operator.py b/airflow/contrib/operators/s3_list_operator.py index 56a1d9f840a97..e6ae59aea1a2c 100644 --- a/airflow/contrib/operators/s3_list_operator.py +++ b/airflow/contrib/operators/s3_list_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_list`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/s3_to_gcs_operator.py b/airflow/contrib/operators/s3_to_gcs_operator.py index 7a43f761c5402..d6c1f04b5aaf6 100644 --- a/airflow/contrib/operators/s3_to_gcs_operator.py +++ b/airflow/contrib/operators/s3_to_gcs_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.s3_to_gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/s3_to_gcs_transfer_operator.py b/airflow/contrib/operators/s3_to_gcs_transfer_operator.py index 84915a8d93748..d82657b774bc5 100644 --- a/airflow/contrib/operators/s3_to_gcs_transfer_operator.py +++ b/airflow/contrib/operators/s3_to_gcs_transfer_operator.py @@ -22,10 +22,13 @@ import warnings # pylint: disable=unused-import,line-too-long -from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import CloudDataTransferServiceS3ToGCSOperator # noqa isort:skip +from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import ( # noqa isort:skip + CloudDataTransferServiceS3ToGCSOperator, +) warnings.warn( "This module is deprecated. " "Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/s3_to_sftp_operator.py b/airflow/contrib/operators/s3_to_sftp_operator.py index b247ce53816e2..1e5a934264e2c 100644 --- a/airflow/contrib/operators/s3_to_sftp_operator.py +++ b/airflow/contrib/operators/s3_to_sftp_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.s3_to_sftp`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/sagemaker_base_operator.py b/airflow/contrib/operators/sagemaker_base_operator.py index c44a1cd189dc8..006424a4f492f 100644 --- a/airflow/contrib/operators/sagemaker_base_operator.py +++ b/airflow/contrib/operators/sagemaker_base_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_base`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/sagemaker_endpoint_config_operator.py b/airflow/contrib/operators/sagemaker_endpoint_config_operator.py index cd8b8c8c3995a..cf828d4ac7317 100644 --- a/airflow/contrib/operators/sagemaker_endpoint_config_operator.py +++ b/airflow/contrib/operators/sagemaker_endpoint_config_operator.py @@ -30,5 +30,6 @@ warnings.warn( "This module is deprecated. 
" "Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint_config`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/sagemaker_endpoint_operator.py b/airflow/contrib/operators/sagemaker_endpoint_operator.py index 6cc6abb370aed..cb1de3f5d65d4 100644 --- a/airflow/contrib/operators/sagemaker_endpoint_operator.py +++ b/airflow/contrib/operators/sagemaker_endpoint_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/sagemaker_model_operator.py b/airflow/contrib/operators/sagemaker_model_operator.py index 174e63851477b..c03799dfee397 100644 --- a/airflow/contrib/operators/sagemaker_model_operator.py +++ b/airflow/contrib/operators/sagemaker_model_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_model`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/sagemaker_training_operator.py b/airflow/contrib/operators/sagemaker_training_operator.py index 17f2b137768fe..f6b1a0ecf71b8 100644 --- a/airflow/contrib/operators/sagemaker_training_operator.py +++ b/airflow/contrib/operators/sagemaker_training_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_training`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/sagemaker_transform_operator.py b/airflow/contrib/operators/sagemaker_transform_operator.py index 490a0c2749def..e4f8460d6ee52 100644 --- a/airflow/contrib/operators/sagemaker_transform_operator.py +++ b/airflow/contrib/operators/sagemaker_transform_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_transform`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/sagemaker_tuning_operator.py b/airflow/contrib/operators/sagemaker_tuning_operator.py index 24cde0fe5bca5..2113eab2a1df6 100644 --- a/airflow/contrib/operators/sagemaker_tuning_operator.py +++ b/airflow/contrib/operators/sagemaker_tuning_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_tuning`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/segment_track_event_operator.py b/airflow/contrib/operators/segment_track_event_operator.py index e1a17f37450f7..49fc35e7efea0 100644 --- a/airflow/contrib/operators/segment_track_event_operator.py +++ b/airflow/contrib/operators/segment_track_event_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.segment.operators.segment_track_event`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/sftp_to_s3_operator.py b/airflow/contrib/operators/sftp_to_s3_operator.py index 817880c14c706..d52fab2f6c057 100644 --- a/airflow/contrib/operators/sftp_to_s3_operator.py +++ b/airflow/contrib/operators/sftp_to_s3_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.amazon.aws.transfers.sftp_to_s3`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/slack_webhook_operator.py b/airflow/contrib/operators/slack_webhook_operator.py index bdf517424a5c4..627ca143a9cd5 100644 --- a/airflow/contrib/operators/slack_webhook_operator.py +++ b/airflow/contrib/operators/slack_webhook_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.slack.operators.slack_webhook`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/snowflake_operator.py b/airflow/contrib/operators/snowflake_operator.py index de14627669133..41fe2d0a44a95 100644 --- a/airflow/contrib/operators/snowflake_operator.py +++ b/airflow/contrib/operators/snowflake_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.snowflake.operators.snowflake`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/spark_jdbc_operator.py b/airflow/contrib/operators/spark_jdbc_operator.py index 178663e04c99e..55c0fc1c4cce7 100644 --- a/airflow/contrib/operators/spark_jdbc_operator.py +++ b/airflow/contrib/operators/spark_jdbc_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.spark.operators.spark_jdbc`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/spark_sql_operator.py b/airflow/contrib/operators/spark_sql_operator.py index 37dfda92269a3..b6d6a5f152a77 100644 --- a/airflow/contrib/operators/spark_sql_operator.py +++ b/airflow/contrib/operators/spark_sql_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.spark.operators.spark_sql`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/spark_submit_operator.py b/airflow/contrib/operators/spark_submit_operator.py index 4b0c469f331c6..5c3253be5b79b 100644 --- a/airflow/contrib/operators/spark_submit_operator.py +++ b/airflow/contrib/operators/spark_submit_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.spark.operators.spark_submit`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/sql_to_gcs.py b/airflow/contrib/operators/sql_to_gcs.py index cfeaa9b5395c0..01ba7dffe961c 100644 --- a/airflow/contrib/operators/sql_to_gcs.py +++ b/airflow/contrib/operators/sql_to_gcs.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.sql_to_gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.transfers.sql_to_gcs.BaseSQLToGCSOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/sqoop_operator.py b/airflow/contrib/operators/sqoop_operator.py index d0daa2b8da938..7464a8c1cb16a 100644 --- a/airflow/contrib/operators/sqoop_operator.py +++ b/airflow/contrib/operators/sqoop_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.apache.sqoop.operators.sqoop`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/ssh_operator.py b/airflow/contrib/operators/ssh_operator.py index cc29431b8c30d..82e88e2155356 100644 --- a/airflow/contrib/operators/ssh_operator.py +++ b/airflow/contrib/operators/ssh_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.ssh.operators.ssh`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/vertica_operator.py b/airflow/contrib/operators/vertica_operator.py index 2504516a98396..71ef3d15d30ed 100644 --- a/airflow/contrib/operators/vertica_operator.py +++ b/airflow/contrib/operators/vertica_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.vertica.operators.vertica`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/vertica_to_hive.py b/airflow/contrib/operators/vertica_to_hive.py index 84484ee69d94f..a1ed0f65ea51d 100644 --- a/airflow/contrib/operators/vertica_to_hive.py +++ b/airflow/contrib/operators/vertica_to_hive.py @@ -26,7 +26,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.vertica_to_hive`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -42,6 +43,7 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.apache.hive.transfers.vertica_to_hive.VerticaToHiveOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/vertica_to_mysql.py b/airflow/contrib/operators/vertica_to_mysql.py index 7a1d174ebe9bc..acef39e9cf3bb 100644 --- a/airflow/contrib/operators/vertica_to_mysql.py +++ b/airflow/contrib/operators/vertica_to_mysql.py @@ -27,7 +27,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.mysql.transfers.vertica_to_mysql`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -43,6 +44,7 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.mysql.transfers.vertica_to_mysql.VerticaToMySqlOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/wasb_delete_blob_operator.py b/airflow/contrib/operators/wasb_delete_blob_operator.py index 44fff265b724b..f204e8b174a8a 100644 --- a/airflow/contrib/operators/wasb_delete_blob_operator.py +++ b/airflow/contrib/operators/wasb_delete_blob_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.wasb_delete_blob`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/operators/winrm_operator.py b/airflow/contrib/operators/winrm_operator.py index 9d24e774b7e34..593ae63a58534 100644 --- a/airflow/contrib/operators/winrm_operator.py +++ b/airflow/contrib/operators/winrm_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.microsoft.winrm.operators.winrm`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/secrets/gcp_secrets_manager.py b/airflow/contrib/secrets/gcp_secrets_manager.py index 6142a7ffc5928..0adcbbdc7bc41 100644 --- a/airflow/contrib/secrets/gcp_secrets_manager.py +++ b/airflow/contrib/secrets/gcp_secrets_manager.py @@ -40,6 +40,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend`.""", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py b/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py index 99698a7cccbc3..97bff996e9098 100644 --- a/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py +++ b/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.glue_catalog_partition`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/azure_cosmos_sensor.py b/airflow/contrib/sensors/azure_cosmos_sensor.py index beae4a9978dd8..3fdce37d55673 100644 --- a/airflow/contrib/sensors/azure_cosmos_sensor.py +++ b/airflow/contrib/sensors/azure_cosmos_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.microsoft.azure.sensors.azure_cosmos`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/bash_sensor.py b/airflow/contrib/sensors/bash_sensor.py index 136f78a13a132..c294fb1e8d1d4 100644 --- a/airflow/contrib/sensors/bash_sensor.py +++ b/airflow/contrib/sensors/bash_sensor.py @@ -23,6 +23,5 @@ from airflow.sensors.bash import STDOUT, BashSensor, Popen, TemporaryDirectory, gettempdir # noqa warnings.warn( - "This module is deprecated. Please use `airflow.sensors.bash`.", - DeprecationWarning, stacklevel=2 + "This module is deprecated. Please use `airflow.sensors.bash`.", DeprecationWarning, stacklevel=2 ) diff --git a/airflow/contrib/sensors/celery_queue_sensor.py b/airflow/contrib/sensors/celery_queue_sensor.py index 61b2ea6959afe..5516eb468a979 100644 --- a/airflow/contrib/sensors/celery_queue_sensor.py +++ b/airflow/contrib/sensors/celery_queue_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.celery.sensors.celery_queue`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/datadog_sensor.py b/airflow/contrib/sensors/datadog_sensor.py index d32445068c641..d4bdaf40d9669 100644 --- a/airflow/contrib/sensors/datadog_sensor.py +++ b/airflow/contrib/sensors/datadog_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.datadog.sensors.datadog`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/emr_base_sensor.py b/airflow/contrib/sensors/emr_base_sensor.py index 2519440529def..595f84782a703 100644 --- a/airflow/contrib/sensors/emr_base_sensor.py +++ b/airflow/contrib/sensors/emr_base_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.amazon.aws.sensors.emr_base`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/emr_job_flow_sensor.py b/airflow/contrib/sensors/emr_job_flow_sensor.py index dc2a098933250..1e8f62a669bac 100644 --- a/airflow/contrib/sensors/emr_job_flow_sensor.py +++ b/airflow/contrib/sensors/emr_job_flow_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_job_flow`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/emr_step_sensor.py b/airflow/contrib/sensors/emr_step_sensor.py index 598919824583b..d64b20e6b0c99 100644 --- a/airflow/contrib/sensors/emr_step_sensor.py +++ b/airflow/contrib/sensors/emr_step_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_step`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/file_sensor.py b/airflow/contrib/sensors/file_sensor.py index eced61c89302d..97d3c458a6440 100644 --- a/airflow/contrib/sensors/file_sensor.py +++ b/airflow/contrib/sensors/file_sensor.py @@ -23,6 +23,5 @@ from airflow.sensors.filesystem import FileSensor # noqa warnings.warn( - "This module is deprecated. Please use `airflow.sensors.filesystem`.", - DeprecationWarning, stacklevel=2 + "This module is deprecated. Please use `airflow.sensors.filesystem`.", DeprecationWarning, stacklevel=2 ) diff --git a/airflow/contrib/sensors/ftp_sensor.py b/airflow/contrib/sensors/ftp_sensor.py index 548efe90f4b7b..b8f3f6d766f06 100644 --- a/airflow/contrib/sensors/ftp_sensor.py +++ b/airflow/contrib/sensors/ftp_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.ftp.sensors.ftp`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/gcp_transfer_sensor.py b/airflow/contrib/sensors/gcp_transfer_sensor.py index 7aa9ec68164cd..6adb234e8432e 100644 --- a/airflow/contrib/sensors/gcp_transfer_sensor.py +++ b/airflow/contrib/sensors/gcp_transfer_sensor.py @@ -29,7 +29,8 @@ warnings.warn( "This module is deprecated. " "Please use `airflow.providers.google.cloud.sensors.cloud_storage_transfer_service`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -44,6 +45,7 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.google.cloud.sensors.transfer.CloudDataTransferServiceJobStatusSensor`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/sensors/gcs_sensor.py b/airflow/contrib/sensors/gcs_sensor.py index 42182c7e6bfcc..2198886ba9ad9 100644 --- a/airflow/contrib/sensors/gcs_sensor.py +++ b/airflow/contrib/sensors/gcs_sensor.py @@ -20,13 +20,16 @@ import warnings from airflow.providers.google.cloud.sensors.gcs import ( - GCSObjectExistenceSensor, GCSObjectsWtihPrefixExistenceSensor, GCSObjectUpdateSensor, + GCSObjectExistenceSensor, + GCSObjectsWtihPrefixExistenceSensor, + GCSObjectUpdateSensor, GCSUploadSessionCompleteSensor, ) warnings.warn( "This module is deprecated. 
Please use `airflow.providers.google.cloud.sensors.gcs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -40,7 +43,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSensor`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -55,7 +59,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectUpdateSensor`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -70,7 +75,8 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectsWtihPrefixExistenceSensor`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -85,6 +91,7 @@ def __init__(self, *args, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.google.cloud.sensors.gcs.GCSUploadSessionCompleteSensor`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/sensors/hdfs_sensor.py b/airflow/contrib/sensors/hdfs_sensor.py index 744c1af83c880..192314684272d 100644 --- a/airflow/contrib/sensors/hdfs_sensor.py +++ b/airflow/contrib/sensors/hdfs_sensor.py @@ -27,7 +27,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hdfs.sensors.hdfs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -44,7 +45,8 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.apache.hdfs.sensors.hdfs.HdfsFolderSensor`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) @@ -62,6 +64,7 @@ def __init__(self, *args, **kwargs): """This class is deprecated. Please use `airflow.providers.apache.hdfs.sensors.hdfs.HdfsRegexSensor`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(*args, **kwargs) diff --git a/airflow/contrib/sensors/imap_attachment_sensor.py b/airflow/contrib/sensors/imap_attachment_sensor.py index 3c1fc676c598f..03d0dfe7caea4 100644 --- a/airflow/contrib/sensors/imap_attachment_sensor.py +++ b/airflow/contrib/sensors/imap_attachment_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.imap.sensors.imap_attachment`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/mongo_sensor.py b/airflow/contrib/sensors/mongo_sensor.py index d817dee9df49b..28e158676d3fe 100644 --- a/airflow/contrib/sensors/mongo_sensor.py +++ b/airflow/contrib/sensors/mongo_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.mongo.sensors.mongo`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/pubsub_sensor.py b/airflow/contrib/sensors/pubsub_sensor.py index 9fe3354f4ce5f..2a9cb520e4687 100644 --- a/airflow/contrib/sensors/pubsub_sensor.py +++ b/airflow/contrib/sensors/pubsub_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.google.cloud.sensors.pubsub`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/python_sensor.py b/airflow/contrib/sensors/python_sensor.py index 6d9f745ef83f5..7e39a2a573fea 100644 --- a/airflow/contrib/sensors/python_sensor.py +++ b/airflow/contrib/sensors/python_sensor.py @@ -23,6 +23,5 @@ from airflow.sensors.python import PythonSensor # noqa warnings.warn( - "This module is deprecated. Please use `airflow.sensors.python`.", - DeprecationWarning, stacklevel=2 + "This module is deprecated. Please use `airflow.sensors.python`.", DeprecationWarning, stacklevel=2 ) diff --git a/airflow/contrib/sensors/qubole_sensor.py b/airflow/contrib/sensors/qubole_sensor.py index a503c3ea5c898..ec39a40fd9f1b 100644 --- a/airflow/contrib/sensors/qubole_sensor.py +++ b/airflow/contrib/sensors/qubole_sensor.py @@ -21,10 +21,13 @@ # pylint: disable=unused-import from airflow.providers.qubole.sensors.qubole import ( # noqa - QuboleFileSensor, QubolePartitionSensor, QuboleSensor, + QuboleFileSensor, + QubolePartitionSensor, + QuboleSensor, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.qubole.sensors.qubole`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/redis_key_sensor.py b/airflow/contrib/sensors/redis_key_sensor.py index 77dd1a431a3bd..f9762de724cf6 100644 --- a/airflow/contrib/sensors/redis_key_sensor.py +++ b/airflow/contrib/sensors/redis_key_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.redis.sensors.redis_key`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/redis_pub_sub_sensor.py b/airflow/contrib/sensors/redis_pub_sub_sensor.py index 121a46559354a..3b541de5dbb9c 100644 --- a/airflow/contrib/sensors/redis_pub_sub_sensor.py +++ b/airflow/contrib/sensors/redis_pub_sub_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.redis.sensors.redis_pub_sub`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/sagemaker_base_sensor.py b/airflow/contrib/sensors/sagemaker_base_sensor.py index 1e5e60ad7f45f..54b8323d1f798 100644 --- a/airflow/contrib/sensors/sagemaker_base_sensor.py +++ b/airflow/contrib/sensors/sagemaker_base_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_base`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/sagemaker_endpoint_sensor.py b/airflow/contrib/sensors/sagemaker_endpoint_sensor.py index afda91af082e1..f6ebe815d8e28 100644 --- a/airflow/contrib/sensors/sagemaker_endpoint_sensor.py +++ b/airflow/contrib/sensors/sagemaker_endpoint_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.amazon.aws.sensors.sagemaker_endpoint`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/sagemaker_training_sensor.py b/airflow/contrib/sensors/sagemaker_training_sensor.py index 6c1ce75ee5483..61f8d432a4820 100644 --- a/airflow/contrib/sensors/sagemaker_training_sensor.py +++ b/airflow/contrib/sensors/sagemaker_training_sensor.py @@ -21,10 +21,12 @@ # pylint: disable=unused-import from airflow.providers.amazon.aws.sensors.sagemaker_training import ( # noqa - SageMakerHook, SageMakerTrainingSensor, + SageMakerHook, + SageMakerTrainingSensor, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_training`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/sagemaker_transform_sensor.py b/airflow/contrib/sensors/sagemaker_transform_sensor.py index f2bdc73ed10ee..daf06add899eb 100644 --- a/airflow/contrib/sensors/sagemaker_transform_sensor.py +++ b/airflow/contrib/sensors/sagemaker_transform_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_transform`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/sagemaker_tuning_sensor.py b/airflow/contrib/sensors/sagemaker_tuning_sensor.py index 3cbbf1d902e43..a1886516ad6b5 100644 --- a/airflow/contrib/sensors/sagemaker_tuning_sensor.py +++ b/airflow/contrib/sensors/sagemaker_tuning_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_tuning`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/wasb_sensor.py b/airflow/contrib/sensors/wasb_sensor.py index bfe0ec115c4ff..22584473399e7 100644 --- a/airflow/contrib/sensors/wasb_sensor.py +++ b/airflow/contrib/sensors/wasb_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.microsoft.azure.sensors.wasb`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/sensors/weekday_sensor.py b/airflow/contrib/sensors/weekday_sensor.py index 890d07378ae7d..8f462bbb6df42 100644 --- a/airflow/contrib/sensors/weekday_sensor.py +++ b/airflow/contrib/sensors/weekday_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.sensors.weekday_sensor`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/task_runner/cgroup_task_runner.py b/airflow/contrib/task_runner/cgroup_task_runner.py index 0f03fed792b0a..9970da6e4674b 100644 --- a/airflow/contrib/task_runner/cgroup_task_runner.py +++ b/airflow/contrib/task_runner/cgroup_task_runner.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.task.task_runner.cgroup_task_runner`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/utils/__init__.py b/airflow/contrib/utils/__init__.py index 9404bc6de998f..4e6cbb7d51e23 100644 --- a/airflow/contrib/utils/__init__.py +++ b/airflow/contrib/utils/__init__.py @@ -19,7 +19,4 @@ import warnings -warnings.warn( - "This module is deprecated. Please use `airflow.utils`.", - DeprecationWarning, stacklevel=2 -) +warnings.warn("This module is deprecated. 
Please use `airflow.utils`.", DeprecationWarning, stacklevel=2) diff --git a/airflow/contrib/utils/gcp_field_sanitizer.py b/airflow/contrib/utils/gcp_field_sanitizer.py index 1e39bb287bf42..a7dca19c2f37d 100644 --- a/airflow/contrib/utils/gcp_field_sanitizer.py +++ b/airflow/contrib/utils/gcp_field_sanitizer.py @@ -21,10 +21,12 @@ # pylint: disable=unused-import from airflow.providers.google.cloud.utils.field_sanitizer import ( # noqa - GcpBodyFieldSanitizer, GcpFieldSanitizerException, + GcpBodyFieldSanitizer, + GcpFieldSanitizerException, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.utils.field_sanitizer`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/utils/gcp_field_validator.py b/airflow/contrib/utils/gcp_field_validator.py index 08d677c2c17e0..e623d37ccdb6c 100644 --- a/airflow/contrib/utils/gcp_field_validator.py +++ b/airflow/contrib/utils/gcp_field_validator.py @@ -21,10 +21,13 @@ # pylint: disable=unused-import from airflow.providers.google.cloud.utils.field_validator import ( # noqa - GcpBodyFieldValidator, GcpFieldValidationException, GcpValidationSpecificationException, + GcpBodyFieldValidator, + GcpFieldValidationException, + GcpValidationSpecificationException, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.utils.field_validator`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/utils/log/__init__.py b/airflow/contrib/utils/log/__init__.py index aecb6b83be87f..ecc47d3d85122 100644 --- a/airflow/contrib/utils/log/__init__.py +++ b/airflow/contrib/utils/log/__init__.py @@ -18,7 +18,4 @@ import warnings -warnings.warn( - "This module is deprecated. Please use `airflow.utils.log`.", - DeprecationWarning, stacklevel=2 -) +warnings.warn("This module is deprecated. Please use `airflow.utils.log`.", DeprecationWarning, stacklevel=2) diff --git a/airflow/contrib/utils/log/task_handler_with_custom_formatter.py b/airflow/contrib/utils/log/task_handler_with_custom_formatter.py index a6ebd66483c85..5045c90ad3f62 100644 --- a/airflow/contrib/utils/log/task_handler_with_custom_formatter.py +++ b/airflow/contrib/utils/log/task_handler_with_custom_formatter.py @@ -23,5 +23,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.utils.log.task_handler_with_custom_formatter`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/utils/mlengine_operator_utils.py b/airflow/contrib/utils/mlengine_operator_utils.py index 80db88d62016b..0c9439874e4ad 100644 --- a/airflow/contrib/utils/mlengine_operator_utils.py +++ b/airflow/contrib/utils/mlengine_operator_utils.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.utils.mlengine_operator_utils`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/utils/mlengine_prediction_summary.py b/airflow/contrib/utils/mlengine_prediction_summary.py index c245d4a3de0ca..2ac126f9cca48 100644 --- a/airflow/contrib/utils/mlengine_prediction_summary.py +++ b/airflow/contrib/utils/mlengine_prediction_summary.py @@ -28,5 +28,6 @@ warnings.warn( "This module is deprecated. 
" "Please use `airflow.providers.google.cloud.utils.mlengine_prediction_summary`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/contrib/utils/sendgrid.py b/airflow/contrib/utils/sendgrid.py index 8801e76cd402f..16408f92d3a5a 100644 --- a/airflow/contrib/utils/sendgrid.py +++ b/airflow/contrib/utils/sendgrid.py @@ -30,6 +30,7 @@ def send_email(*args, **kwargs): """This function is deprecated. Please use `airflow.providers.sendgrid.utils.emailer.send_email`.""" warnings.warn( "This function is deprecated. Please use `airflow.providers.sendgrid.utils.emailer.send_email`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) return import_string('airflow.providers.sendgrid.utils.emailer.send_email')(*args, **kwargs) diff --git a/airflow/contrib/utils/weekday.py b/airflow/contrib/utils/weekday.py index 85e85cceed5ea..d2548aa391ad3 100644 --- a/airflow/contrib/utils/weekday.py +++ b/airflow/contrib/utils/weekday.py @@ -21,6 +21,5 @@ from airflow.utils.weekday import WeekDay # noqa warnings.warn( - "This module is deprecated. Please use `airflow.utils.weekday`.", - DeprecationWarning, stacklevel=2 + "This module is deprecated. Please use `airflow.utils.weekday`.", DeprecationWarning, stacklevel=2 ) diff --git a/airflow/example_dags/example_bash_operator.py b/airflow/example_dags/example_bash_operator.py index 0bda87a9a314b..aa3004f65a380 100644 --- a/airflow/example_dags/example_bash_operator.py +++ b/airflow/example_dags/example_bash_operator.py @@ -36,7 +36,7 @@ start_date=days_ago(2), dagrun_timeout=timedelta(minutes=60), tags=['example', 'example2'], - params={"example_key": "example_value"} + params={"example_key": "example_value"}, ) run_this_last = DummyOperator( diff --git a/airflow/example_dags/example_branch_operator.py b/airflow/example_dags/example_branch_operator.py index 6efacc8808138..97256c65dfb3c 100644 --- a/airflow/example_dags/example_branch_operator.py +++ b/airflow/example_dags/example_branch_operator.py @@ -34,7 +34,7 @@ default_args=args, start_date=days_ago(2), schedule_interval="@daily", - tags=['example', 'example2'] + tags=['example', 'example2'], ) run_this_first = DummyOperator( diff --git a/airflow/example_dags/example_branch_python_dop_operator_3.py b/airflow/example_dags/example_branch_python_dop_operator_3.py index 0a45e6093ebe9..9d60fa3eb2b58 100644 --- a/airflow/example_dags/example_branch_python_dop_operator_3.py +++ b/airflow/example_dags/example_branch_python_dop_operator_3.py @@ -36,7 +36,7 @@ schedule_interval='*/1 * * * *', start_date=days_ago(2), default_args=args, - tags=['example'] + tags=['example'], ) @@ -48,8 +48,11 @@ def should_run(**kwargs): :return: Id of the task to run :rtype: str """ - print('------------- exec dttm = {} and minute = {}'. 
- format(kwargs['execution_date'], kwargs['execution_date'].minute)) + print( + '------------- exec dttm = {} and minute = {}'.format( + kwargs['execution_date'], kwargs['execution_date'].minute + ) + ) if kwargs['execution_date'].minute % 2 == 0: return "dummy_task_1" else: diff --git a/airflow/example_dags/example_dag_decorator.py b/airflow/example_dags/example_dag_decorator.py index 6f8f38229aace..04c87ac399cd5 100644 --- a/airflow/example_dags/example_dag_decorator.py +++ b/airflow/example_dags/example_dag_decorator.py @@ -40,25 +40,20 @@ def example_dag_decorator(email: str = 'example@example.com'): :type email: str """ # Using default connection as it's set to httpbin.org by default - get_ip = SimpleHttpOperator( - task_id='get_ip', endpoint='get', method='GET' - ) + get_ip = SimpleHttpOperator(task_id='get_ip', endpoint='get', method='GET') @task(multiple_outputs=True) def prepare_email(raw_json: str) -> Dict[str, str]: external_ip = json.loads(raw_json)['origin'] return { 'subject': f'Server connected from {external_ip}', - 'body': f'Seems like today your server executing Airflow is connected from IP {external_ip}
<br>' + 'body': f'Seems like today your server executing Airflow is connected from IP {external_ip}<br>
', } email_info = prepare_email(get_ip.output) EmailOperator( - task_id='send_email', - to=email, - subject=email_info['subject'], - html_content=email_info['body'] + task_id='send_email', to=email, subject=email_info['subject'], html_content=email_info['body'] ) diff --git a/airflow/example_dags/example_external_task_marker_dag.py b/airflow/example_dags/example_external_task_marker_dag.py index 13dab0dcd015f..b31240f9226ef 100644 --- a/airflow/example_dags/example_external_task_marker_dag.py +++ b/airflow/example_dags/example_external_task_marker_dag.py @@ -52,9 +52,11 @@ tags=['example2'], ) as parent_dag: # [START howto_operator_external_task_marker] - parent_task = ExternalTaskMarker(task_id="parent_task", - external_dag_id="example_external_task_marker_child", - external_task_id="child_task1") + parent_task = ExternalTaskMarker( + task_id="parent_task", + external_dag_id="example_external_task_marker_child", + external_task_id="child_task1", + ) # [END howto_operator_external_task_marker] with DAG( @@ -64,13 +66,15 @@ tags=['example2'], ) as child_dag: # [START howto_operator_external_task_sensor] - child_task1 = ExternalTaskSensor(task_id="child_task1", - external_dag_id=parent_dag.dag_id, - external_task_id=parent_task.task_id, - timeout=600, - allowed_states=['success'], - failed_states=['failed', 'skipped'], - mode="reschedule") + child_task1 = ExternalTaskSensor( + task_id="child_task1", + external_dag_id=parent_dag.dag_id, + external_task_id=parent_task.task_id, + timeout=600, + allowed_states=['success'], + failed_states=['failed', 'skipped'], + mode="reschedule", + ) # [END howto_operator_external_task_sensor] child_task2 = DummyOperator(task_id="child_task2") child_task1 >> child_task2 diff --git a/airflow/example_dags/example_kubernetes_executor.py b/airflow/example_dags/example_kubernetes_executor.py index aa4fc32f1b2c1..a782efcf77e4a 100644 --- a/airflow/example_dags/example_kubernetes_executor.py +++ b/airflow/example_dags/example_kubernetes_executor.py @@ -43,24 +43,14 @@ { 'topologyKey': 'kubernetes.io/hostname', 'labelSelector': { - 'matchExpressions': [ - { - 'key': 'app', - 'operator': 'In', - 'values': ['airflow'] - } - ] - } + 'matchExpressions': [{'key': 'app', 'operator': 'In', 'values': ['airflow']}] + }, } ] } } - tolerations = [{ - 'key': 'dedicated', - 'operator': 'Equal', - 'value': 'airflow' - }] + tolerations = [{'key': 'dedicated', 'operator': 'Equal', 'value': 'airflow'}] def use_zip_binary(): """ @@ -74,23 +64,20 @@ def use_zip_binary(): raise SystemError("The zip binary is missing") # You don't have to use any special KubernetesExecutor configuration if you don't want to - start_task = PythonOperator( - task_id="start_task", - python_callable=print_stuff - ) + start_task = PythonOperator(task_id="start_task", python_callable=print_stuff) # But you can if you want to one_task = PythonOperator( task_id="one_task", python_callable=print_stuff, - executor_config={"KubernetesExecutor": {"image": "airflow/ci:latest"}} + executor_config={"KubernetesExecutor": {"image": "airflow/ci:latest"}}, ) # Use the zip binary, which is only found in this special docker image two_task = PythonOperator( task_id="two_task", python_callable=use_zip_binary, - executor_config={"KubernetesExecutor": {"image": "airflow/ci_zip:latest"}} + executor_config={"KubernetesExecutor": {"image": "airflow/ci_zip:latest"}}, ) # Limit resources on this operator/task with node affinity & tolerations @@ -98,17 +85,20 @@ def use_zip_binary(): task_id="three_task", python_callable=print_stuff, 
executor_config={ - "KubernetesExecutor": {"request_memory": "128Mi", - "limit_memory": "128Mi", - "tolerations": tolerations, - "affinity": affinity}} + "KubernetesExecutor": { + "request_memory": "128Mi", + "limit_memory": "128Mi", + "tolerations": tolerations, + "affinity": affinity, + } + }, ) # Add arbitrary labels to worker pods four_task = PythonOperator( task_id="four_task", python_callable=print_stuff, - executor_config={"KubernetesExecutor": {"labels": {"foo": "bar"}}} + executor_config={"KubernetesExecutor": {"labels": {"foo": "bar"}}}, ) start_task >> [one_task, two_task, three_task, four_task] diff --git a/airflow/example_dags/example_kubernetes_executor_config.py b/airflow/example_dags/example_kubernetes_executor_config.py index d361227a4e7d1..57b2c4a8435e0 100644 --- a/airflow/example_dags/example_kubernetes_executor_config.py +++ b/airflow/example_dags/example_kubernetes_executor_config.py @@ -67,12 +67,9 @@ def test_volume_mount(): start_task = PythonOperator( task_id="start_task", python_callable=print_stuff, - executor_config={"pod_override": k8s.V1Pod( - metadata=k8s.V1ObjectMeta( - annotations={"test": "annotation"} - ) - ) - } + executor_config={ + "pod_override": k8s.V1Pod(metadata=k8s.V1ObjectMeta(annotations={"test": "annotation"})) + }, ) # [START task_with_volume] @@ -86,24 +83,19 @@ def test_volume_mount(): k8s.V1Container( name="base", volume_mounts=[ - k8s.V1VolumeMount( - mount_path="/foo/", - name="example-kubernetes-test-volume" - ) - ] + k8s.V1VolumeMount(mount_path="/foo/", name="example-kubernetes-test-volume") + ], ) ], volumes=[ k8s.V1Volume( name="example-kubernetes-test-volume", - host_path=k8s.V1HostPathVolumeSource( - path="/tmp/" - ) + host_path=k8s.V1HostPathVolumeSource(path="/tmp/"), ) - ] + ], ) ), - } + }, ) # [END task_with_volume] @@ -117,31 +109,22 @@ def test_volume_mount(): containers=[ k8s.V1Container( name="base", - volume_mounts=[k8s.V1VolumeMount( - mount_path="/shared/", - name="shared-empty-dir" - )] + volume_mounts=[k8s.V1VolumeMount(mount_path="/shared/", name="shared-empty-dir")], ), k8s.V1Container( name="sidecar", image="ubuntu", args=["echo \"retrieved from mount\" > /shared/test.txt"], command=["bash", "-cx"], - volume_mounts=[k8s.V1VolumeMount( - mount_path="/shared/", - name="shared-empty-dir" - )] - ) + volume_mounts=[k8s.V1VolumeMount(mount_path="/shared/", name="shared-empty-dir")], + ), ], volumes=[ - k8s.V1Volume( - name="shared-empty-dir", - empty_dir=k8s.V1EmptyDirVolumeSource() - ), - ] + k8s.V1Volume(name="shared-empty-dir", empty_dir=k8s.V1EmptyDirVolumeSource()), + ], ) ), - } + }, ) # [END task_with_sidecar] @@ -149,27 +132,15 @@ def test_volume_mount(): third_task = PythonOperator( task_id="non_root_task", python_callable=print_stuff, - executor_config={"pod_override": k8s.V1Pod( - metadata=k8s.V1ObjectMeta( - labels={ - "release": "stable" - } - ) - ) - } + executor_config={"pod_override": k8s.V1Pod(metadata=k8s.V1ObjectMeta(labels={"release": "stable"}))}, ) other_ns_task = PythonOperator( task_id="other_namespace_task", python_callable=print_stuff, executor_config={ - "KubernetesExecutor": { - "namespace": "test-namespace", - "labels": { - "release": "stable" - } - } - } + "KubernetesExecutor": {"namespace": "test-namespace", "labels": {"release": "stable"}} + }, ) start_task >> volume_task >> third_task diff --git a/airflow/example_dags/example_latest_only.py b/airflow/example_dags/example_latest_only.py index fda514b8fd8fb..9be8a82958170 100644 --- a/airflow/example_dags/example_latest_only.py +++ 
b/airflow/example_dags/example_latest_only.py @@ -29,7 +29,7 @@ dag_id='latest_only', schedule_interval=dt.timedelta(hours=4), start_date=days_ago(2), - tags=['example2', 'example3'] + tags=['example2', 'example3'], ) latest_only = LatestOnlyOperator(task_id='latest_only', dag=dag) diff --git a/airflow/example_dags/example_latest_only_with_trigger.py b/airflow/example_dags/example_latest_only_with_trigger.py index 2568139ec1bb4..e9d136063b89c 100644 --- a/airflow/example_dags/example_latest_only_with_trigger.py +++ b/airflow/example_dags/example_latest_only_with_trigger.py @@ -32,7 +32,7 @@ dag_id='latest_only_with_trigger', schedule_interval=dt.timedelta(hours=4), start_date=days_ago(2), - tags=['example3'] + tags=['example3'], ) latest_only = LatestOnlyOperator(task_id='latest_only', dag=dag) diff --git a/airflow/example_dags/example_nested_branch_dag.py b/airflow/example_dags/example_nested_branch_dag.py index 0050d9a8df862..ffc1ad5cf70a0 100644 --- a/airflow/example_dags/example_nested_branch_dag.py +++ b/airflow/example_dags/example_nested_branch_dag.py @@ -28,10 +28,7 @@ from airflow.utils.dates import days_ago with DAG( - dag_id="example_nested_branch_dag", - start_date=days_ago(2), - schedule_interval="@daily", - tags=["example"] + dag_id="example_nested_branch_dag", start_date=days_ago(2), schedule_interval="@daily", tags=["example"] ) as dag: branch_1 = BranchPythonOperator(task_id="branch_1", python_callable=lambda: "true_1") join_1 = DummyOperator(task_id="join_1", trigger_rule="none_failed_or_skipped") diff --git a/airflow/example_dags/example_passing_params_via_test_command.py b/airflow/example_dags/example_passing_params_via_test_command.py index bcd5318c65fe1..8eaadd7eb245f 100644 --- a/airflow/example_dags/example_passing_params_via_test_command.py +++ b/airflow/example_dags/example_passing_params_via_test_command.py @@ -34,7 +34,7 @@ schedule_interval='*/1 * * * *', start_date=days_ago(1), dagrun_timeout=timedelta(minutes=4), - tags=['example'] + tags=['example'], ) @@ -45,8 +45,12 @@ def my_py_command(test_mode, params): -t '{"foo":"bar"}'` """ if test_mode: - print(" 'foo' was passed in via test={} command : kwargs[params][foo] \ - = {}".format(test_mode, params["foo"])) + print( + " 'foo' was passed in via test={} command : kwargs[params][foo] \ + = {}".format( + test_mode, params["foo"] + ) + ) # Print out the value of "miff", passed in below via the Python Operator print(" 'miff' was passed in via task params = {}".format(params["miff"])) return 1 @@ -83,10 +87,6 @@ def print_env_vars(test_mode): print("AIRFLOW_TEST_MODE={}".format(os.environ.get('AIRFLOW_TEST_MODE'))) -env_var_test_task = PythonOperator( - task_id='env_var_test_task', - python_callable=print_env_vars, - dag=dag -) +env_var_test_task = PythonOperator(task_id='env_var_test_task', python_callable=print_env_vars, dag=dag) run_this >> also_run_this diff --git a/airflow/example_dags/example_python_operator.py b/airflow/example_dags/example_python_operator.py index 5b6d7b54d0b74..d5e16a55e95b8 100644 --- a/airflow/example_dags/example_python_operator.py +++ b/airflow/example_dags/example_python_operator.py @@ -33,7 +33,7 @@ default_args=args, schedule_interval=None, start_date=days_ago(2), - tags=['example'] + tags=['example'], ) @@ -83,6 +83,7 @@ def callable_virtualenv(): from time import sleep from colorama import Back, Fore, Style + print(Fore.RED + 'some red text') print(Back.GREEN + 'and with a green background') print(Style.DIM + 'and in dim text') @@ -96,9 +97,7 @@ def callable_virtualenv(): 
virtualenv_task = PythonVirtualenvOperator( task_id="virtualenv_python", python_callable=callable_virtualenv, - requirements=[ - "colorama==0.4.0" - ], + requirements=["colorama==0.4.0"], system_site_packages=False, dag=dag, ) diff --git a/airflow/example_dags/example_subdag_operator.py b/airflow/example_dags/example_subdag_operator.py index f21e3d4db1577..de2853d3eb842 100644 --- a/airflow/example_dags/example_subdag_operator.py +++ b/airflow/example_dags/example_subdag_operator.py @@ -32,11 +32,7 @@ } dag = DAG( - dag_id=DAG_NAME, - default_args=args, - start_date=days_ago(2), - schedule_interval="@once", - tags=['example'] + dag_id=DAG_NAME, default_args=args, start_date=days_ago(2), schedule_interval="@once", tags=['example'] ) start = DummyOperator( diff --git a/airflow/example_dags/example_trigger_controller_dag.py b/airflow/example_dags/example_trigger_controller_dag.py index f8fd5d6610b98..39bc766f90b3f 100644 --- a/airflow/example_dags/example_trigger_controller_dag.py +++ b/airflow/example_dags/example_trigger_controller_dag.py @@ -30,7 +30,7 @@ default_args={"owner": "airflow"}, start_date=days_ago(2), schedule_interval="@once", - tags=['example'] + tags=['example'], ) trigger = TriggerDagRunOperator( diff --git a/airflow/example_dags/example_trigger_target_dag.py b/airflow/example_dags/example_trigger_target_dag.py index 3f4cfd0f3c856..035527546d289 100644 --- a/airflow/example_dags/example_trigger_target_dag.py +++ b/airflow/example_dags/example_trigger_target_dag.py @@ -32,7 +32,7 @@ default_args={"owner": "airflow"}, start_date=days_ago(2), schedule_interval=None, - tags=['example'] + tags=['example'], ) diff --git a/airflow/example_dags/example_xcom.py b/airflow/example_dags/example_xcom.py index b3956822a3f64..779e392c70083 100644 --- a/airflow/example_dags/example_xcom.py +++ b/airflow/example_dags/example_xcom.py @@ -26,7 +26,7 @@ schedule_interval="@once", start_date=days_ago(2), default_args={'owner': 'airflow'}, - tags=['example'] + tags=['example'], ) value_1 = [1, 2, 3] diff --git a/airflow/example_dags/subdags/subdag.py b/airflow/example_dags/subdags/subdag.py index e65a5c98eaa49..9815af33f93d5 100644 --- a/airflow/example_dags/subdags/subdag.py +++ b/airflow/example_dags/subdags/subdag.py @@ -49,4 +49,6 @@ def subdag(parent_dag_name, child_dag_name, args): ) return dag_subdag + + # [END subdag] diff --git a/airflow/example_dags/tutorial.py b/airflow/example_dags/tutorial.py index 39f779c905b67..a00051c43abe2 100644 --- a/airflow/example_dags/tutorial.py +++ b/airflow/example_dags/tutorial.py @@ -27,6 +27,7 @@ # The DAG object; we'll need this to instantiate a DAG from airflow import DAG + # Operators; we need this to operate! 
from airflow.operators.bash import BashOperator from airflow.utils.dates import days_ago diff --git a/airflow/example_dags/tutorial_decorated_etl_dag.py b/airflow/example_dags/tutorial_decorated_etl_dag.py index 0f78940d824e3..b351e7318ad6f 100644 --- a/airflow/example_dags/tutorial_decorated_etl_dag.py +++ b/airflow/example_dags/tutorial_decorated_etl_dag.py @@ -71,6 +71,7 @@ def extract(): order_data_dict = json.loads(data_string) return order_data_dict + # [END extract] # [START transform] @@ -87,6 +88,7 @@ def transform(order_data_dict: dict): total_order_value += value return {"total_order_value": total_order_value} + # [END transform] # [START load] @@ -99,6 +101,7 @@ def load(total_order_value: float): """ print("Total order value is: %.2f" % total_order_value) + # [END load] # [START main_flow] diff --git a/airflow/example_dags/tutorial_etl_dag.py b/airflow/example_dags/tutorial_etl_dag.py index 4a4405e0d2c18..48b519b5e59eb 100644 --- a/airflow/example_dags/tutorial_etl_dag.py +++ b/airflow/example_dags/tutorial_etl_dag.py @@ -30,6 +30,7 @@ # The DAG object; we'll need this to instantiate a DAG from airflow import DAG + # Operators; we need this to operate! from airflow.operators.python import PythonOperator from airflow.utils.dates import days_ago @@ -63,6 +64,7 @@ def extract(**kwargs): ti = kwargs['ti'] data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}' ti.xcom_push('order_data', data_string) + # [END extract_function] # [START transform_function] @@ -78,6 +80,7 @@ def transform(**kwargs): total_value = {"total_order_value": total_order_value} total_value_json_string = json.dumps(total_value) ti.xcom_push('total_order_value', total_value_json_string) + # [END transform_function] # [START load_function] @@ -87,6 +90,7 @@ def load(**kwargs): total_order_value = json.loads(total_value_string) print(total_order_value) + # [END load_function] # [START main_flow] diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index 8477019de614d..8ae4c327bd63a 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -60,19 +60,20 @@ class BaseExecutor(LoggingMixin): def __init__(self, parallelism: int = PARALLELISM): super().__init__() self.parallelism: int = parallelism - self.queued_tasks: OrderedDict[TaskInstanceKey, QueuedTaskInstanceType] \ - = OrderedDict() + self.queued_tasks: OrderedDict[TaskInstanceKey, QueuedTaskInstanceType] = OrderedDict() self.running: Set[TaskInstanceKey] = set() self.event_buffer: Dict[TaskInstanceKey, EventBufferValueType] = {} def start(self): # pragma: no cover """Executors may need to get things started.""" - def queue_command(self, - task_instance: TaskInstance, - command: CommandType, - priority: int = 1, - queue: Optional[str] = None): + def queue_command( + self, + task_instance: TaskInstance, + command: CommandType, + priority: int = 1, + queue: Optional[str] = None, + ): """Queues command to task""" if task_instance.key not in self.queued_tasks and task_instance.key not in self.running: self.log.info("Adding to queue: %s", command) @@ -81,16 +82,17 @@ def queue_command(self, self.log.error("could not queue task %s", task_instance.key) def queue_task_instance( - self, - task_instance: TaskInstance, - mark_success: bool = False, - pickle_id: Optional[str] = None, - ignore_all_deps: bool = False, - ignore_depends_on_past: bool = False, - ignore_task_deps: bool = False, - ignore_ti_state: bool = False, - pool: Optional[str] = None, - cfg_path: Optional[str] = None) -> None: + 
self, + task_instance: TaskInstance, + mark_success: bool = False, + pickle_id: Optional[str] = None, + ignore_all_deps: bool = False, + ignore_depends_on_past: bool = False, + ignore_task_deps: bool = False, + ignore_ti_state: bool = False, + pool: Optional[str] = None, + cfg_path: Optional[str] = None, + ) -> None: """Queues task instance.""" pool = pool or task_instance.pool @@ -107,13 +109,15 @@ def queue_task_instance( ignore_ti_state=ignore_ti_state, pool=pool, pickle_id=pickle_id, - cfg_path=cfg_path) + cfg_path=cfg_path, + ) self.log.debug("created command %s", command_list_to_run) self.queue_command( task_instance, command_list_to_run, priority=task_instance.task.priority_weight_total, - queue=task_instance.task.queue) + queue=task_instance.task.queue, + ) def has_task(self, task_instance: TaskInstance) -> bool: """ @@ -163,7 +167,8 @@ def order_queued_tasks_by_priority(self) -> List[Tuple[TaskInstanceKey, QueuedTa return sorted( [(k, v) for k, v in self.queued_tasks.items()], # pylint: disable=unnecessary-comprehension key=lambda x: x[1][1], - reverse=True) + reverse=True, + ) def trigger_tasks(self, open_slots: int) -> None: """ @@ -177,10 +182,7 @@ def trigger_tasks(self, open_slots: int) -> None: key, (command, _, _, ti) = sorted_queue.pop(0) self.queued_tasks.pop(key) self.running.add(key) - self.execute_async(key=key, - command=command, - queue=None, - executor_config=ti.executor_config) + self.execute_async(key=key, command=command, queue=None, executor_config=ti.executor_config) def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None: """ @@ -235,11 +237,13 @@ def get_event_buffer(self, dag_ids=None) -> Dict[TaskInstanceKey, EventBufferVal return cleared_events - def execute_async(self, - key: TaskInstanceKey, - command: CommandType, - queue: Optional[str] = None, - executor_config: Optional[Any] = None) -> None: # pragma: no cover + def execute_async( + self, + key: TaskInstanceKey, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None, + ) -> None: # pragma: no cover """ This method will execute the command asynchronously. 
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index 716d1ed973873..86d76d86cf0c7 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -71,9 +71,7 @@ else: celery_configuration = DEFAULT_CELERY_CONFIG -app = Celery( - conf.get('celery', 'CELERY_APP_NAME'), - config_source=celery_configuration) +app = Celery(conf.get('celery', 'CELERY_APP_NAME'), config_source=celery_configuration) @app.task @@ -103,6 +101,7 @@ def _execute_in_fork(command_to_exec: CommandType) -> None: ret = 1 try: from airflow.cli.cli_parser import get_parser + parser = get_parser() # [1:] - remove "airflow" from the start of the command args = parser.parse_args(command_to_exec[1:]) @@ -123,8 +122,7 @@ def _execute_in_subprocess(command_to_exec: CommandType) -> None: env = os.environ.copy() try: # pylint: disable=unexpected-keyword-arg - subprocess.check_output(command_to_exec, stderr=subprocess.STDOUT, - close_fds=True, env=env) + subprocess.check_output(command_to_exec, stderr=subprocess.STDOUT, close_fds=True, env=env) # pylint: disable=unexpected-keyword-arg except subprocess.CalledProcessError as e: log.exception('execute_command encountered a CalledProcessError') @@ -153,8 +151,9 @@ def __init__(self, exception: Exception, exception_traceback: str): TaskInstanceInCelery = Tuple[TaskInstanceKey, SimpleTaskInstance, CommandType, Optional[str], Task] -def send_task_to_executor(task_tuple: TaskInstanceInCelery) \ - -> Tuple[TaskInstanceKey, CommandType, Union[AsyncResult, ExceptionWithTraceback]]: +def send_task_to_executor( + task_tuple: TaskInstanceInCelery, +) -> Tuple[TaskInstanceKey, CommandType, Union[AsyncResult, ExceptionWithTraceback]]: """Sends task to executor.""" key, _, command, queue, task_to_run = task_tuple try: @@ -185,10 +184,13 @@ def on_celery_import_modules(*args, **kwargs): import airflow.operators.bash import airflow.operators.python import airflow.operators.subdag_operator # noqa: F401 + try: import kubernetes.client # noqa: F401 except ImportError: pass + + # pylint: enable=unused-import @@ -220,10 +222,7 @@ def __init__(self): ) def start(self) -> None: - self.log.debug( - 'Starting Celery Executor using %s processes for syncing', - self._sync_parallelism - ) + self.log.debug('Starting Celery Executor using %s processes for syncing', self._sync_parallelism) def _num_tasks_per_send_process(self, to_send_count: int) -> int: """ @@ -232,8 +231,7 @@ def _num_tasks_per_send_process(self, to_send_count: int) -> int: :return: Number of tasks that should be sent per process :rtype: int """ - return max(1, - int(math.ceil(1.0 * to_send_count / self._sync_parallelism))) + return max(1, int(math.ceil(1.0 * to_send_count / self._sync_parallelism))) def trigger_tasks(self, open_slots: int) -> None: """ @@ -297,14 +295,14 @@ def reset_signals(): # Since we are run from inside the SchedulerJob, we don't to # inherit the signal handlers that we registered there. 
import signal + signal.signal(signal.SIGINT, signal.SIG_DFL) signal.signal(signal.SIGTERM, signal.SIG_DFL) with Pool(processes=num_processes, initializer=reset_signals) as send_pool: key_and_async_results = send_pool.map( - send_task_to_executor, - task_tuples_to_send, - chunksize=chunksize) + send_task_to_executor, task_tuples_to_send, chunksize=chunksize + ) return key_and_async_results def sync(self) -> None: @@ -348,7 +346,7 @@ def _check_for_stalled_adopted_tasks(self): "Adopted tasks were still pending after %s, assuming they never made it to celery and " "clearing:\n\t%s", self.task_adoption_timeout, - "\n\t".join([repr(x) for x in timedout_keys]) + "\n\t".join([repr(x) for x in timedout_keys]), ) for key in timedout_keys: self.event_buffer[key] = (State.FAILED, None) @@ -394,11 +392,13 @@ def end(self, synchronous: bool = False) -> None: time.sleep(5) self.sync() - def execute_async(self, - key: TaskInstanceKey, - command: CommandType, - queue: Optional[str] = None, - executor_config: Optional[Any] = None): + def execute_async( + self, + key: TaskInstanceKey, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None, + ): """Do not allow async execution for Celery executor.""" raise AirflowException("No Async execution for Celery executor.") @@ -456,14 +456,14 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance if adopted: task_instance_str = '\n\t'.join(adopted) - self.log.info("Adopted the following %d tasks from a dead executor\n\t%s", - len(adopted), task_instance_str) + self.log.info( + "Adopted the following %d tasks from a dead executor\n\t%s", len(adopted), task_instance_str + ) return not_adopted_tis -def fetch_celery_task_state(async_result: AsyncResult) -> \ - Tuple[str, Union[str, ExceptionWithTraceback], Any]: +def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, ExceptionWithTraceback], Any]: """ Fetch and return the state of the given celery task. The scope of this function is global so that it can be called by subprocesses in the pool. 
@@ -535,8 +535,9 @@ def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, EventBufferValu return self._prepare_state_and_info_by_task_dict(task_ids, task_results_by_task_id) @staticmethod - def _prepare_state_and_info_by_task_dict(task_ids, - task_results_by_task_id) -> Mapping[str, EventBufferValueType]: + def _prepare_state_and_info_by_task_dict( + task_ids, task_results_by_task_id + ) -> Mapping[str, EventBufferValueType]: state_info: MutableMapping[str, EventBufferValueType] = {} for task_id in task_ids: task_result = task_results_by_task_id.get(task_id) @@ -556,16 +557,16 @@ def _get_many_using_multiprocessing(self, async_results) -> Mapping[str, EventBu chunksize = max(1, math.floor(math.ceil(1.0 * len(async_results) / self._sync_parallelism))) task_id_to_states_and_info = sync_pool.map( - fetch_celery_task_state, - async_results, - chunksize=chunksize) + fetch_celery_task_state, async_results, chunksize=chunksize + ) states_and_info_by_task_id: MutableMapping[str, EventBufferValueType] = {} for task_id, state_or_exception, info in task_id_to_states_and_info: if isinstance(state_or_exception, ExceptionWithTraceback): self.log.error( # pylint: disable=logging-not-lazy CELERY_FETCH_ERR_MSG_HEADER + ":%s\n%s\n", - state_or_exception.exception, state_or_exception.traceback + state_or_exception.exception, + state_or_exception.traceback, ) else: states_and_info_by_task_id[task_id] = state_or_exception, info diff --git a/airflow/executors/celery_kubernetes_executor.py b/airflow/executors/celery_kubernetes_executor.py index 0c400f0910079..b58d579ee2df6 100644 --- a/airflow/executors/celery_kubernetes_executor.py +++ b/airflow/executors/celery_kubernetes_executor.py @@ -63,31 +63,30 @@ def queue_command( task_instance: TaskInstance, command: CommandType, priority: int = 1, - queue: Optional[str] = None + queue: Optional[str] = None, ): """Queues command via celery or kubernetes executor""" executor = self._router(task_instance) - self.log.debug( - "Using executor: %s for %s", executor.__class__.__name__, task_instance.key - ) + self.log.debug("Using executor: %s for %s", executor.__class__.__name__, task_instance.key) executor.queue_command(task_instance, command, priority, queue) def queue_task_instance( - self, - task_instance: TaskInstance, - mark_success: bool = False, - pickle_id: Optional[str] = None, - ignore_all_deps: bool = False, - ignore_depends_on_past: bool = False, - ignore_task_deps: bool = False, - ignore_ti_state: bool = False, - pool: Optional[str] = None, - cfg_path: Optional[str] = None) -> None: + self, + task_instance: TaskInstance, + mark_success: bool = False, + pickle_id: Optional[str] = None, + ignore_all_deps: bool = False, + ignore_depends_on_past: bool = False, + ignore_task_deps: bool = False, + ignore_ti_state: bool = False, + pool: Optional[str] = None, + cfg_path: Optional[str] = None, + ) -> None: """Queues task instance via celery or kubernetes executor""" executor = self._router(SimpleTaskInstance(task_instance)) - self.log.debug("Using executor: %s to queue_task_instance for %s", - executor.__class__.__name__, task_instance.key - ) + self.log.debug( + "Using executor: %s to queue_task_instance for %s", executor.__class__.__name__, task_instance.key + ) executor.queue_task_instance( task_instance, mark_success, @@ -97,7 +96,7 @@ def queue_task_instance( ignore_task_deps, ignore_ti_state, pool, - cfg_path + cfg_path, ) def has_task(self, task_instance: TaskInstance) -> bool: @@ -107,8 +106,9 @@ def has_task(self, task_instance: TaskInstance) -> 
bool: :param task_instance: TaskInstance :return: True if the task is known to this executor """ - return self.celery_executor.has_task(task_instance) \ - or self.kubernetes_executor.has_task(task_instance) + return self.celery_executor.has_task(task_instance) or self.kubernetes_executor.has_task( + task_instance + ) def heartbeat(self) -> None: """Heartbeat sent to trigger new jobs in celery and kubernetes executor""" diff --git a/airflow/executors/dask_executor.py b/airflow/executors/dask_executor.py index e2cb64fda9961..f26ad0ee7d4c2 100644 --- a/airflow/executors/dask_executor.py +++ b/airflow/executors/dask_executor.py @@ -65,11 +65,13 @@ def start(self) -> None: self.client = Client(self.cluster_address, security=security) self.futures = {} - def execute_async(self, - key: TaskInstanceKey, - command: CommandType, - queue: Optional[str] = None, - executor_config: Optional[Any] = None) -> None: + def execute_async( + self, + key: TaskInstanceKey, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None, + ) -> None: self.validate_command(command) diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py index 391b7c701e53a..580dc653957de 100644 --- a/airflow/executors/debug_executor.py +++ b/airflow/executors/debug_executor.py @@ -49,7 +49,7 @@ def __init__(self): self.tasks_params: Dict[TaskInstanceKey, Dict[str, Any]] = {} self.fail_fast = conf.getboolean("debug", "fail_fast") - def execute_async(self, *args, **kwargs) -> None: # pylint: disable=signature-differs + def execute_async(self, *args, **kwargs) -> None: # pylint: disable=signature-differs """The method is replaced by custom trigger_task implementation.""" def sync(self) -> None: @@ -63,9 +63,7 @@ def sync(self) -> None: continue if self._terminated.is_set(): - self.log.info( - "Executor is terminated! Stopping %s to %s", ti.key, State.FAILED - ) + self.log.info("Executor is terminated! Stopping %s to %s", ti.key, State.FAILED) ti.set_state(State.FAILED) self.change_state(ti.key, State.FAILED) continue @@ -77,9 +75,7 @@ def _run_task(self, ti: TaskInstance) -> bool: key = ti.key try: params = self.tasks_params.pop(ti.key, {}) - ti._run_raw_task( # pylint: disable=protected-access - job_id=ti.job_id, **params - ) + ti._run_raw_task(job_id=ti.job_id, **params) # pylint: disable=protected-access self.change_state(key, State.SUCCESS) return True except Exception as e: # pylint: disable=broad-except diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py index 6039a7adcf598..5ae3e66f80804 100644 --- a/airflow/executors/executor_loader.py +++ b/airflow/executors/executor_loader.py @@ -45,7 +45,7 @@ class ExecutorLoader: CELERY_KUBERNETES_EXECUTOR: 'airflow.executors.celery_kubernetes_executor.CeleryKubernetesExecutor', DASK_EXECUTOR: 'airflow.executors.dask_executor.DaskExecutor', KUBERNETES_EXECUTOR: 'airflow.executors.kubernetes_executor.KubernetesExecutor', - DEBUG_EXECUTOR: 'airflow.executors.debug_executor.DebugExecutor' + DEBUG_EXECUTOR: 'airflow.executors.debug_executor.DebugExecutor', } @classmethod @@ -55,6 +55,7 @@ def get_default_executor(cls) -> BaseExecutor: return cls._default_executor from airflow.configuration import conf + executor_name = conf.get('core', 'EXECUTOR') cls._default_executor = cls.load_executor(executor_name) @@ -83,12 +84,14 @@ def load_executor(cls, executor_name: str) -> BaseExecutor: if executor_name.count(".") == 1: log.debug( "The executor name looks like the plugin path (executor_name=%s). 
Trying to load a " - "executor from a plugin", executor_name + "executor from a plugin", + executor_name, ) with suppress(ImportError), suppress(AttributeError): # Load plugins here for executors as at that time the plugins might not have been # initialized yet from airflow import plugins_manager + plugins_manager.integrate_executor_plugins() return import_string(f"airflow.executors.{executor_name}")() @@ -118,5 +121,5 @@ def __load_celery_kubernetes_executor(cls) -> BaseExecutor: UNPICKLEABLE_EXECUTORS = ( ExecutorLoader.LOCAL_EXECUTOR, ExecutorLoader.SEQUENTIAL_EXECUTOR, - ExecutorLoader.DASK_EXECUTOR + ExecutorLoader.DASK_EXECUTOR, ) diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index 79b75b38b18a2..44d42a7ad846b 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -85,20 +85,18 @@ def __init__(self): # pylint: disable=too-many-statements self.airflow_home = settings.AIRFLOW_HOME self.dags_folder = conf.get(self.core_section, 'dags_folder') self.parallelism = conf.getint(self.core_section, 'parallelism') - self.pod_template_file = conf.get(self.kubernetes_section, 'pod_template_file', - fallback=None) + self.pod_template_file = conf.get(self.kubernetes_section, 'pod_template_file', fallback=None) - self.delete_worker_pods = conf.getboolean( - self.kubernetes_section, 'delete_worker_pods') + self.delete_worker_pods = conf.getboolean(self.kubernetes_section, 'delete_worker_pods') self.delete_worker_pods_on_failure = conf.getboolean( - self.kubernetes_section, 'delete_worker_pods_on_failure') + self.kubernetes_section, 'delete_worker_pods_on_failure' + ) self.worker_pods_creation_batch_size = conf.getint( - self.kubernetes_section, 'worker_pods_creation_batch_size') + self.kubernetes_section, 'worker_pods_creation_batch_size' + ) - self.worker_container_repository = conf.get( - self.kubernetes_section, 'worker_container_repository') - self.worker_container_tag = conf.get( - self.kubernetes_section, 'worker_container_tag') + self.worker_container_repository = conf.get(self.kubernetes_section, 'worker_container_repository') + self.worker_container_tag = conf.get(self.kubernetes_section, 'worker_container_tag') self.kube_image = f'{self.worker_container_repository}:{self.worker_container_tag}' # The Kubernetes Namespace in which the Scheduler and Webserver reside. 
Note @@ -116,10 +114,12 @@ def __init__(self): # pylint: disable=too-many-statements kube_client_request_args = conf.get(self.kubernetes_section, 'kube_client_request_args') if kube_client_request_args: self.kube_client_request_args = json.loads(kube_client_request_args) - if self.kube_client_request_args['_request_timeout'] and \ - isinstance(self.kube_client_request_args['_request_timeout'], list): - self.kube_client_request_args['_request_timeout'] = \ - tuple(self.kube_client_request_args['_request_timeout']) + if self.kube_client_request_args['_request_timeout'] and isinstance( + self.kube_client_request_args['_request_timeout'], list + ): + self.kube_client_request_args['_request_timeout'] = tuple( + self.kube_client_request_args['_request_timeout'] + ) else: self.kube_client_request_args = {} delete_option_kwargs = conf.get(self.kubernetes_section, 'delete_option_kwargs') @@ -141,13 +141,15 @@ def _get_security_context_val(self, scontext: str) -> Union[str, int]: class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): """Watches for Kubernetes jobs""" - def __init__(self, - namespace: Optional[str], - multi_namespace_mode: bool, - watcher_queue: 'Queue[KubernetesWatchType]', - resource_version: Optional[str], - scheduler_job_id: Optional[str], - kube_config: Configuration): + def __init__( + self, + namespace: Optional[str], + multi_namespace_mode: bool, + watcher_queue: 'Queue[KubernetesWatchType]', + resource_version: Optional[str], + scheduler_job_id: Optional[str], + kube_config: Configuration, + ): super().__init__() self.namespace = namespace self.multi_namespace_mode = multi_namespace_mode @@ -163,28 +165,31 @@ def run(self) -> None: raise AirflowException(NOT_STARTED_MESSAGE) while True: try: - self.resource_version = self._run(kube_client, self.resource_version, - self.scheduler_job_id, self.kube_config) + self.resource_version = self._run( + kube_client, self.resource_version, self.scheduler_job_id, self.kube_config + ) except ReadTimeoutError: - self.log.warning("There was a timeout error accessing the Kube API. " - "Retrying request.", exc_info=True) + self.log.warning( + "There was a timeout error accessing the Kube API. " "Retrying request.", exc_info=True + ) time.sleep(1) except Exception: self.log.exception('Unknown error in KubernetesJobWatcher. 
Failing') raise else: - self.log.warning('Watch died gracefully, starting back up with: ' - 'last resource_version: %s', self.resource_version) - - def _run(self, - kube_client: client.CoreV1Api, - resource_version: Optional[str], - scheduler_job_id: str, - kube_config: Any) -> Optional[str]: - self.log.info( - 'Event: and now my watch begins starting at resource_version: %s', - resource_version - ) + self.log.warning( + 'Watch died gracefully, starting back up with: ' 'last resource_version: %s', + self.resource_version, + ) + + def _run( + self, + kube_client: client.CoreV1Api, + resource_version: Optional[str], + scheduler_job_id: str, + kube_config: Any, + ) -> Optional[str]: + self.log.info('Event: and now my watch begins starting at resource_version: %s', resource_version) watcher = watch.Watch() kwargs = {'label_selector': f'airflow-worker={scheduler_job_id}'} @@ -196,20 +201,16 @@ def _run(self, last_resource_version: Optional[str] = None if self.multi_namespace_mode: - list_worker_pods = functools.partial(watcher.stream, - kube_client.list_pod_for_all_namespaces, - **kwargs) + list_worker_pods = functools.partial( + watcher.stream, kube_client.list_pod_for_all_namespaces, **kwargs + ) else: - list_worker_pods = functools.partial(watcher.stream, - kube_client.list_namespaced_pod, - self.namespace, - **kwargs) + list_worker_pods = functools.partial( + watcher.stream, kube_client.list_namespaced_pod, self.namespace, **kwargs + ) for event in list_worker_pods(): task = event['object'] - self.log.info( - 'Event: %s had an event of type %s', - task.metadata.name, event['type'] - ) + self.log.info('Event: %s had an event of type %s', task.metadata.name, event['type']) if event['type'] == 'ERROR': return self.process_error(event) annotations = task.metadata.annotations @@ -234,29 +235,28 @@ def _run(self, def process_error(self, event: Any) -> str: """Process error response""" - self.log.error( - 'Encountered Error response from k8s list namespaced pod stream => %s', - event - ) + self.log.error('Encountered Error response from k8s list namespaced pod stream => %s', event) raw_object = event['raw_object'] if raw_object['code'] == 410: self.log.info( - 'Kubernetes resource version is too old, must reset to 0 => %s', - (raw_object['message'],) + 'Kubernetes resource version is too old, must reset to 0 => %s', (raw_object['message'],) ) # Return resource version 0 return '0' raise AirflowException( - 'Kubernetes failure for %s with code %s and message: %s' % - (raw_object['reason'], raw_object['code'], raw_object['message']) + 'Kubernetes failure for %s with code %s and message: %s' + % (raw_object['reason'], raw_object['code'], raw_object['message']) ) - def process_status(self, pod_id: str, - namespace: str, - status: str, - annotations: Dict[str, str], - resource_version: str, - event: Any) -> None: + def process_status( + self, + pod_id: str, + namespace: str, + status: str, + annotations: Dict[str, str], + resource_version: str, + event: Any, + ) -> None: """Process status response""" if status == 'Pending': if event['type'] == 'DELETED': @@ -277,19 +277,26 @@ def process_status(self, pod_id: str, else: self.log.warning( 'Event: Invalid state: %s on pod: %s in namespace %s with annotations: %s with ' - 'resource_version: %s', status, pod_id, namespace, annotations, resource_version + 'resource_version: %s', + status, + pod_id, + namespace, + annotations, + resource_version, ) class AirflowKubernetesScheduler(LoggingMixin): """Airflow Scheduler for Kubernetes""" - def __init__(self, - 
kube_config: Any, - task_queue: 'Queue[KubernetesJobType]', - result_queue: 'Queue[KubernetesResultsType]', - kube_client: client.CoreV1Api, - scheduler_job_id: str): + def __init__( + self, + kube_config: Any, + task_queue: 'Queue[KubernetesJobType]', + result_queue: 'Queue[KubernetesResultsType]', + kube_client: client.CoreV1Api, + scheduler_job_id: str, + ): super().__init__() self.log.debug("Creating Kubernetes executor") self.kube_config = kube_config @@ -306,12 +313,14 @@ def __init__(self, def _make_kube_watcher(self) -> KubernetesJobWatcher: resource_version = ResourceVersion().resource_version - watcher = KubernetesJobWatcher(watcher_queue=self.watcher_queue, - namespace=self.kube_config.kube_namespace, - multi_namespace_mode=self.kube_config.multi_namespace_mode, - resource_version=resource_version, - scheduler_job_id=self.scheduler_job_id, - kube_config=self.kube_config) + watcher = KubernetesJobWatcher( + watcher_queue=self.watcher_queue, + namespace=self.kube_config.kube_namespace, + multi_namespace_mode=self.kube_config.multi_namespace_mode, + resource_version=resource_version, + scheduler_job_id=self.scheduler_job_id, + kube_config=self.kube_config, + ) watcher.start() return watcher @@ -320,8 +329,8 @@ def _health_check_kube_watcher(self): self.log.debug("KubeJobWatcher alive, continuing") else: self.log.error( - 'Error while health checking kube watcher process. ' - 'Process died for unknown reasons') + 'Error while health checking kube watcher process. ' 'Process died for unknown reasons' + ) self.kube_watcher = self._make_kube_watcher() def run_next(self, next_job: KubernetesJobType) -> None: @@ -340,8 +349,9 @@ def run_next(self, next_job: KubernetesJobType) -> None: base_worker_pod = PodGenerator.deserialize_model_file(self.kube_config.pod_template_file) if not base_worker_pod: - raise AirflowException("could not find a valid worker template yaml at {}" - .format(self.kube_config.pod_template_file)) + raise AirflowException( + f"could not find a valid worker template yaml at {self.kube_config.pod_template_file}" + ) pod = PodGenerator.construct_pod( namespace=self.namespace, @@ -354,7 +364,7 @@ def run_next(self, next_job: KubernetesJobType) -> None: date=execution_date, command=command, pod_override_object=kube_executor_config, - base_worker_pod=base_worker_pod + base_worker_pod=base_worker_pod, ) # Reconcile the pod generated by the Operator and the Pod # generated by the .cfg file @@ -370,8 +380,11 @@ def delete_pod(self, pod_id: str, namespace: str) -> None: try: self.log.debug("Deleting pod %s in namespace %s", pod_id, namespace) self.kube_client.delete_namespaced_pod( - pod_id, namespace, body=client.V1DeleteOptions(**self.kube_config.delete_option_kwargs), - **self.kube_config.kube_client_request_args) + pod_id, + namespace, + body=client.V1DeleteOptions(**self.kube_config.delete_option_kwargs), + **self.kube_config.kube_client_request_args, + ) except ApiException as e: # If the pod is already deleted if e.status != 404: @@ -403,8 +416,7 @@ def process_watcher_task(self, task: KubernetesWatchType) -> None: """Process the task by watcher.""" pod_id, namespace, state, annotations, resource_version = task self.log.info( - 'Attempting to finish pod; pod_id: %s; state: %s; annotations: %s', - pod_id, state, annotations + 'Attempting to finish pod; pod_id: %s; state: %s; annotations: %s', pod_id, state, annotations ) key = self._annotations_to_key(annotations=annotations) if key: @@ -434,7 +446,7 @@ def _make_safe_pod_id(safe_dag_id: str, safe_task_id: str, 
safe_uuid: str) -> st """ safe_key = safe_dag_id + safe_task_id - safe_pod_id = safe_key[:MAX_POD_ID_LEN - len(safe_uuid) - 1] + "-" + safe_uuid + safe_pod_id = safe_key[: MAX_POD_ID_LEN - len(safe_uuid) - 1] + "-" + safe_uuid return safe_pod_id @@ -526,67 +538,60 @@ def clear_not_launched_queued_tasks(self, session=None) -> None: self.log.debug("Clearing tasks that have not been launched") if not self.kube_client: raise AirflowException(NOT_STARTED_MESSAGE) - queued_tasks = session \ - .query(TaskInstance) \ - .filter(TaskInstance.state == State.QUEUED).all() - self.log.info( - 'When executor started up, found %s queued task instances', - len(queued_tasks) - ) + queued_tasks = session.query(TaskInstance).filter(TaskInstance.state == State.QUEUED).all() + self.log.info('When executor started up, found %s queued task instances', len(queued_tasks)) for task in queued_tasks: # pylint: disable=protected-access self.log.debug("Checking task %s", task) - dict_string = ( - "dag_id={},task_id={},execution_date={},airflow-worker={}".format( - pod_generator.make_safe_label_value(task.dag_id), - pod_generator.make_safe_label_value(task.task_id), - pod_generator.datetime_to_label_safe_datestring( - task.execution_date - ), - self.scheduler_job_id - ) + dict_string = "dag_id={},task_id={},execution_date={},airflow-worker={}".format( + pod_generator.make_safe_label_value(task.dag_id), + pod_generator.make_safe_label_value(task.task_id), + pod_generator.datetime_to_label_safe_datestring(task.execution_date), + self.scheduler_job_id, ) # pylint: enable=protected-access kwargs = dict(label_selector=dict_string) if self.kube_config.kube_client_request_args: for key, value in self.kube_config.kube_client_request_args.items(): kwargs[key] = value - pod_list = self.kube_client.list_namespaced_pod( - self.kube_config.kube_namespace, **kwargs) + pod_list = self.kube_client.list_namespaced_pod(self.kube_config.kube_namespace, **kwargs) if not pod_list.items: self.log.info( - 'TaskInstance: %s found in queued state but was not launched, ' - 'rescheduling', task + 'TaskInstance: %s found in queued state but was not launched, ' 'rescheduling', task ) session.query(TaskInstance).filter( TaskInstance.dag_id == task.dag_id, TaskInstance.task_id == task.task_id, - TaskInstance.execution_date == task.execution_date + TaskInstance.execution_date == task.execution_date, ).update({TaskInstance.state: State.NONE}) def _inject_secrets(self) -> None: def _create_or_update_secret(secret_name, secret_path): try: return self.kube_client.create_namespaced_secret( - self.kube_config.executor_namespace, kubernetes.client.V1Secret( - data={ - 'key.json': base64.b64encode(open(secret_path).read())}, - metadata=kubernetes.client.V1ObjectMeta(name=secret_name)), - **self.kube_config.kube_client_request_args) + self.kube_config.executor_namespace, + kubernetes.client.V1Secret( + data={'key.json': base64.b64encode(open(secret_path).read())}, + metadata=kubernetes.client.V1ObjectMeta(name=secret_name), + ), + **self.kube_config.kube_client_request_args, + ) except ApiException as e: if e.status == 409: return self.kube_client.replace_namespaced_secret( - secret_name, self.kube_config.executor_namespace, + secret_name, + self.kube_config.executor_namespace, kubernetes.client.V1Secret( - data={'key.json': base64.b64encode( - open(secret_path).read())}, - metadata=kubernetes.client.V1ObjectMeta(name=secret_name)), - **self.kube_config.kube_client_request_args) + data={'key.json': base64.b64encode(open(secret_path).read())}, + 
metadata=kubernetes.client.V1ObjectMeta(name=secret_name), + ), + **self.kube_config.kube_client_request_args, + ) self.log.exception( - 'Exception while trying to inject secret. ' - 'Secret name: %s, error details: %s', - secret_name, e + 'Exception while trying to inject secret. ' 'Secret name: %s, error details: %s', + secret_name, + e, ) raise @@ -599,22 +604,20 @@ def start(self) -> None: self.log.debug('Start with scheduler_job_id: %s', self.scheduler_job_id) self.kube_client = get_kube_client() self.kube_scheduler = AirflowKubernetesScheduler( - self.kube_config, self.task_queue, self.result_queue, - self.kube_client, self.scheduler_job_id + self.kube_config, self.task_queue, self.result_queue, self.kube_client, self.scheduler_job_id ) self._inject_secrets() self.clear_not_launched_queued_tasks() - def execute_async(self, - key: TaskInstanceKey, - command: CommandType, - queue: Optional[str] = None, - executor_config: Optional[Any] = None) -> None: + def execute_async( + self, + key: TaskInstanceKey, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None, + ) -> None: """Executes task asynchronously""" - self.log.info( - 'Add task %s with command %s with executor_config %s', - key, command, executor_config - ) + self.log.info('Add task %s with command %s with executor_config %s', key, command, executor_config) kube_executor_config = PodGenerator.from_obj(executor_config) if not self.task_queue: raise AirflowException(NOT_STARTED_MESSAGE) @@ -652,7 +655,9 @@ def sync(self) -> None: except Exception as e: # pylint: disable=broad-except self.log.exception( "Exception: %s when attempting to change state of %s to %s, re-queueing.", - e, results, state + e, + results, + state, ) self.result_queue.put(results) finally: @@ -675,8 +680,10 @@ def sync(self) -> None: key, _, _ = task self.change_state(key, State.FAILED, e) else: - self.log.warning('ApiException when attempting to run task, re-queueing. ' - 'Message: %s', json.loads(e.body)['message']) + self.log.warning( + 'ApiException when attempting to run task, re-queueing. 
' 'Message: %s', + json.loads(e.body)['message'], + ) self.task_queue.put(task) finally: self.task_queue.task_done() @@ -684,11 +691,7 @@ def sync(self) -> None: break # pylint: enable=too-many-nested-blocks - def _change_state(self, - key: TaskInstanceKey, - state: Optional[str], - pod_id: str, - namespace: str) -> None: + def _change_state(self, key: TaskInstanceKey, state: Optional[str], pod_id: str, namespace: str) -> None: if state != State.RUNNING: if self.kube_config.delete_worker_pods: if not self.kube_scheduler: @@ -706,18 +709,12 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance tis_to_flush = [ti for ti in tis if not ti.external_executor_id] scheduler_job_ids = [ti.external_executor_id for ti in tis] pod_ids = { - create_pod_id(dag_id=ti.dag_id, task_id=ti.task_id): ti - for ti in tis if ti.external_executor_id + create_pod_id(dag_id=ti.dag_id, task_id=ti.task_id): ti for ti in tis if ti.external_executor_id } kube_client: client.CoreV1Api = self.kube_client for scheduler_job_id in scheduler_job_ids: - kwargs = { - 'label_selector': f'airflow-worker={scheduler_job_id}' - } - pod_list = kube_client.list_namespaced_pod( - namespace=self.kube_config.kube_namespace, - **kwargs - ) + kwargs = {'label_selector': f'airflow-worker={scheduler_job_id}'} + pod_list = kube_client.list_namespaced_pod(namespace=self.kube_config.kube_namespace, **kwargs) for pod in pod_list.items: self.adopt_launched_task(kube_client, pod, pod_ids) self._adopt_completed_pods(kube_client) @@ -738,8 +735,11 @@ def adopt_launched_task(self, kube_client, pod, pod_ids: dict): task_id = pod.metadata.labels['task_id'] pod_id = create_pod_id(dag_id=dag_id, task_id=task_id) if pod_id not in pod_ids: - self.log.error("attempting to adopt task %s in dag %s" - " which was not specified by database", task_id, dag_id) + self.log.error( + "attempting to adopt task %s in dag %s" " which was not specified by database", + task_id, + dag_id, + ) else: try: kube_client.patch_namespaced_pod( @@ -798,13 +798,18 @@ def _flush_result_queue(self) -> None: self.log.warning('Executor shutting down, flushing results=%s', results) try: key, state, pod_id, namespace, resource_version = results - self.log.info('Changing state of %s to %s : resource_version=%d', results, state, - resource_version) + self.log.info( + 'Changing state of %s to %s : resource_version=%d', results, state, resource_version + ) try: self._change_state(key, state, pod_id, namespace) except Exception as e: # pylint: disable=broad-except - self.log.exception('Ignoring exception: %s when attempting to change state of %s ' - 'to %s.', e, results, state) + self.log.exception( + 'Ignoring exception: %s when attempting to change state of %s ' 'to %s.', + e, + results, + state, + ) finally: self.result_queue.task_done() except Empty: diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py index 4ad4033666276..729a45e8b1e43 100644 --- a/airflow/executors/local_executor.py +++ b/airflow/executors/local_executor.py @@ -36,7 +36,8 @@ from airflow.exceptions import AirflowException from airflow.executors.base_executor import NOT_STARTED_MESSAGE, PARALLELISM, BaseExecutor, CommandType from airflow.models.taskinstance import ( # pylint: disable=unused-import # noqa: F401 - TaskInstanceKey, TaskInstanceStateType, + TaskInstanceKey, + TaskInstanceStateType, ) from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State @@ -100,6 +101,7 @@ def _execute_work_in_fork(self, command: 
CommandType) -> str: return State.SUCCESS if ret == 0 else State.FAILED from airflow.sentry import Sentry + ret = 1 try: import signal @@ -140,10 +142,9 @@ class LocalWorker(LocalWorkerBase): :param command: Command to execute """ - def __init__(self, - result_queue: 'Queue[TaskInstanceStateType]', - key: TaskInstanceKey, - command: CommandType): + def __init__( + self, result_queue: 'Queue[TaskInstanceStateType]', key: TaskInstanceKey, command: CommandType + ): super().__init__(result_queue) self.key: TaskInstanceKey = key self.command: CommandType = command @@ -162,9 +163,7 @@ class QueuedLocalWorker(LocalWorkerBase): :param result_queue: queue where worker puts results after finishing tasks """ - def __init__(self, - task_queue: 'Queue[ExecutorWorkType]', - result_queue: 'Queue[TaskInstanceStateType]'): + def __init__(self, task_queue: 'Queue[ExecutorWorkType]', result_queue: 'Queue[TaskInstanceStateType]'): super().__init__(result_queue=result_queue) self.task_queue = task_queue @@ -196,8 +195,9 @@ def __init__(self, parallelism: int = PARALLELISM): self.workers: List[QueuedLocalWorker] = [] self.workers_used: int = 0 self.workers_active: int = 0 - self.impl: Optional[Union['LocalExecutor.UnlimitedParallelism', - 'LocalExecutor.LimitedParallelism']] = None + self.impl: Optional[ + Union['LocalExecutor.UnlimitedParallelism', 'LocalExecutor.LimitedParallelism'] + ] = None class UnlimitedParallelism: """ @@ -216,11 +216,13 @@ def start(self) -> None: self.executor.workers_active = 0 # pylint: disable=unused-argument # pragma: no cover - def execute_async(self, - key: TaskInstanceKey, - command: CommandType, - queue: Optional[str] = None, - executor_config: Optional[Any] = None) -> None: + def execute_async( + self, + key: TaskInstanceKey, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None, + ) -> None: """ Executes task asynchronously. @@ -289,7 +291,7 @@ def execute_async( key: TaskInstanceKey, command: CommandType, queue: Optional[str] = None, # pylint: disable=unused-argument - executor_config: Optional[Any] = None # pylint: disable=unused-argument + executor_config: Optional[Any] = None, # pylint: disable=unused-argument ) -> None: """ Executes task asynchronously. 
@@ -331,15 +333,21 @@ def start(self) -> None: self.workers = [] self.workers_used = 0 self.workers_active = 0 - self.impl = (LocalExecutor.UnlimitedParallelism(self) if self.parallelism == 0 - else LocalExecutor.LimitedParallelism(self)) + self.impl = ( + LocalExecutor.UnlimitedParallelism(self) + if self.parallelism == 0 + else LocalExecutor.LimitedParallelism(self) + ) self.impl.start() - def execute_async(self, key: TaskInstanceKey, - command: CommandType, - queue: Optional[str] = None, - executor_config: Optional[Any] = None) -> None: + def execute_async( + self, + key: TaskInstanceKey, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None, + ) -> None: """Execute asynchronously.""" if not self.impl: raise AirflowException(NOT_STARTED_MESSAGE) diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py index 18a4747790c45..456e3e9893e8b 100644 --- a/airflow/executors/sequential_executor.py +++ b/airflow/executors/sequential_executor.py @@ -44,11 +44,13 @@ def __init__(self): super().__init__() self.commands_to_run = [] - def execute_async(self, - key: TaskInstanceKey, - command: CommandType, - queue: Optional[str] = None, - executor_config: Optional[Any] = None) -> None: + def execute_async( + self, + key: TaskInstanceKey, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None, + ) -> None: self.validate_command(command) self.commands_to_run.append((key, command)) diff --git a/airflow/hooks/base_hook.py b/airflow/hooks/base_hook.py index 9a54077519ea1..cb9f6dcab4345 100644 --- a/airflow/hooks/base_hook.py +++ b/airflow/hooks/base_hook.py @@ -64,7 +64,7 @@ def get_connection(cls, conn_id: str) -> Connection: conn.schema, conn.login, "XXXXXXXX" if conn.password else None, - "XXXXXXXX" if conn.extra_dejson else None + "XXXXXXXX" if conn.extra_dejson else None, ) return conn diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py index fde2463409df0..17a4c81e42f8b 100644 --- a/airflow/hooks/dbapi_hook.py +++ b/airflow/hooks/dbapi_hook.py @@ -67,11 +67,7 @@ def __init__(self, *args, **kwargs): def get_conn(self): """Returns a connection object""" db = self.get_connection(getattr(self, self.conn_name_attr)) - return self.connector.connect( - host=db.host, - port=db.port, - username=db.login, - schema=db.schema) + return self.connector.connect(host=db.host, port=db.port, username=db.login, schema=db.schema) def get_uri(self) -> str: """ @@ -86,8 +82,7 @@ def get_uri(self) -> str: host = conn.host if conn.port is not None: host += f':{conn.port}' - uri = '{conn.conn_type}://{login}{host}/'.format( - conn=conn, login=login, host=host) + uri = f'{conn.conn_type}://{login}{host}/' if conn.schema: uri += conn.schema return uri @@ -199,7 +194,7 @@ def set_autocommit(self, conn, autocommit): if not self.supports_autocommit and autocommit: self.log.warning( "%s connection doesn't support autocommit but autocommit activated.", - getattr(self, self.conn_name_attr) + getattr(self, self.conn_name_attr), ) conn.autocommit = autocommit @@ -238,7 +233,9 @@ def _generate_insert_sql(table, values, target_fields, replace, **kwargs): :return: The generated INSERT or REPLACE SQL statement :rtype: str """ - placeholders = ["%s", ] * len(values) + placeholders = [ + "%s", + ] * len(values) if target_fields: target_fields = ", ".join(target_fields) @@ -250,14 +247,10 @@ def _generate_insert_sql(table, values, target_fields, replace, **kwargs): sql = "INSERT INTO " else: sql = 
"REPLACE INTO " - sql += "{} {} VALUES ({})".format( - table, - target_fields, - ",".join(placeholders)) + sql += "{} {} VALUES ({})".format(table, target_fields, ",".join(placeholders)) return sql - def insert_rows(self, table, rows, target_fields=None, commit_every=1000, - replace=False, **kwargs): + def insert_rows(self, table, rows, target_fields=None, commit_every=1000, replace=False, **kwargs): """ A generic way to insert a set of tuples into a table, a new transaction is created every commit_every rows @@ -287,15 +280,11 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=1000, for cell in row: lst.append(self._serialize_cell(cell, conn)) values = tuple(lst) - sql = self._generate_insert_sql( - table, values, target_fields, replace, **kwargs - ) + sql = self._generate_insert_sql(table, values, target_fields, replace, **kwargs) cur.execute(sql, values) if commit_every and i % commit_every == 0: conn.commit() - self.log.info( - "Loaded %s rows into %s so far", i, table - ) + self.log.info("Loaded %s rows into %s so far", i, table) conn.commit() self.log.info("Done loading. Loaded a total of %s rows", i) diff --git a/airflow/hooks/docker_hook.py b/airflow/hooks/docker_hook.py index b7da340770da6..ff02d3c1978c9 100644 --- a/airflow/hooks/docker_hook.py +++ b/airflow/hooks/docker_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.docker.hooks.docker`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/druid_hook.py b/airflow/hooks/druid_hook.py index 11f7d6c4429b3..9ae87b01ef352 100644 --- a/airflow/hooks/druid_hook.py +++ b/airflow/hooks/druid_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.druid.hooks.druid`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/hdfs_hook.py b/airflow/hooks/hdfs_hook.py index 10e81c5c6d39e..9b0cdb66d311b 100644 --- a/airflow/hooks/hdfs_hook.py +++ b/airflow/hooks/hdfs_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hdfs.hooks.hdfs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py index 990ee470430a6..5009647593e49 100644 --- a/airflow/hooks/hive_hooks.py +++ b/airflow/hooks/hive_hooks.py @@ -21,10 +21,14 @@ # pylint: disable=unused-import from airflow.providers.apache.hive.hooks.hive import ( # noqa - HIVE_QUEUE_PRIORITIES, HiveCliHook, HiveMetastoreHook, HiveServer2Hook, + HIVE_QUEUE_PRIORITIES, + HiveCliHook, + HiveMetastoreHook, + HiveServer2Hook, ) warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hive.hooks.hive`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/http_hook.py b/airflow/hooks/http_hook.py index 0bc5203996b6e..80fed8f0d94e3 100644 --- a/airflow/hooks/http_hook.py +++ b/airflow/hooks/http_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.http.hooks.http`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/jdbc_hook.py b/airflow/hooks/jdbc_hook.py index 4ae5efca20162..1f7e127b98844 100644 --- a/airflow/hooks/jdbc_hook.py +++ b/airflow/hooks/jdbc_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.jdbc.hooks.jdbc`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/mssql_hook.py b/airflow/hooks/mssql_hook.py index 2584ea76ee73f..cc9bc848bcf72 100644 --- a/airflow/hooks/mssql_hook.py +++ b/airflow/hooks/mssql_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.microsoft.mssql.hooks.mssql`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/mysql_hook.py b/airflow/hooks/mysql_hook.py index a077a535a9aa1..97e3669613cc3 100644 --- a/airflow/hooks/mysql_hook.py +++ b/airflow/hooks/mysql_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.mysql.hooks.mysql`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/oracle_hook.py b/airflow/hooks/oracle_hook.py index e820ee3bc1aa6..202e9fe9f4f0f 100644 --- a/airflow/hooks/oracle_hook.py +++ b/airflow/hooks/oracle_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.oracle.hooks.oracle`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/pig_hook.py b/airflow/hooks/pig_hook.py index c750c87615f1f..9035a262263cc 100644 --- a/airflow/hooks/pig_hook.py +++ b/airflow/hooks/pig_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.pig.hooks.pig`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/postgres_hook.py b/airflow/hooks/postgres_hook.py index e72d517da2c2b..1202a3f144d40 100644 --- a/airflow/hooks/postgres_hook.py +++ b/airflow/hooks/postgres_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.postgres.hooks.postgres`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/presto_hook.py b/airflow/hooks/presto_hook.py index 75358eecb4b02..2daca69966954 100644 --- a/airflow/hooks/presto_hook.py +++ b/airflow/hooks/presto_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.presto.hooks.presto`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/samba_hook.py b/airflow/hooks/samba_hook.py index 72c1af508804c..715655dca0ac6 100644 --- a/airflow/hooks/samba_hook.py +++ b/airflow/hooks/samba_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.samba.hooks.samba`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/slack_hook.py b/airflow/hooks/slack_hook.py index de2b6720f71da..c0c9dadbc0f8f 100644 --- a/airflow/hooks/slack_hook.py +++ b/airflow/hooks/slack_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.slack.hooks.slack`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/sqlite_hook.py b/airflow/hooks/sqlite_hook.py index 9b770c77aabcd..e8bdd64d47013 100644 --- a/airflow/hooks/sqlite_hook.py +++ b/airflow/hooks/sqlite_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.sqlite.hooks.sqlite`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/webhdfs_hook.py b/airflow/hooks/webhdfs_hook.py index 6c134215c6987..3d4291a2ab60f 100644 --- a/airflow/hooks/webhdfs_hook.py +++ b/airflow/hooks/webhdfs_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hdfs.hooks.webhdfs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/hooks/zendesk_hook.py b/airflow/hooks/zendesk_hook.py index db20fe0b03711..d5a5e659f5770 100644 --- a/airflow/hooks/zendesk_hook.py +++ b/airflow/hooks/zendesk_hook.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.zendesk.hooks.zendesk`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py index ad7ba436e5f6b..63a711aa22e84 100644 --- a/airflow/jobs/backfill_job.py +++ b/airflow/jobs/backfill_job.py @@ -28,7 +28,11 @@ from airflow import models from airflow.exceptions import ( - AirflowException, BackfillUnfinished, DagConcurrencyLimitReached, NoAvailablePoolSlot, PoolNotFound, + AirflowException, + BackfillUnfinished, + DagConcurrencyLimitReached, + NoAvailablePoolSlot, + PoolNotFound, TaskConcurrencyLimitReached, ) from airflow.executors.executor_loader import ExecutorLoader @@ -54,9 +58,7 @@ class BackfillJob(BaseJob): STATES_COUNT_AS_RUNNING = (State.RUNNING, State.QUEUED) - __mapper_args__ = { - 'polymorphic_identity': 'BackfillJob' - } + __mapper_args__ = {'polymorphic_identity': 'BackfillJob'} class _DagRunTaskStatus: """ @@ -93,19 +95,20 @@ class _DagRunTaskStatus: """ # TODO(edgarRd): AIRFLOW-1444: Add consistency check on counts - def __init__(self, # pylint: disable=too-many-arguments - to_run=None, - running=None, - skipped=None, - succeeded=None, - failed=None, - not_ready=None, - deadlocked=None, - active_runs=None, - executed_dag_run_dates=None, - finished_runs=0, - total_runs=0, - ): + def __init__( # pylint: disable=too-many-arguments + self, + to_run=None, + running=None, + skipped=None, + succeeded=None, + failed=None, + not_ready=None, + deadlocked=None, + active_runs=None, + executed_dag_run_dates=None, + finished_runs=0, + total_runs=0, + ): self.to_run = to_run or OrderedDict() self.running = running or {} self.skipped = skipped or set() @@ -119,21 +122,23 @@ def __init__(self, # pylint: disable=too-many-arguments self.total_runs = total_runs def __init__( # pylint: disable=too-many-arguments - self, - dag, - start_date=None, - end_date=None, - mark_success=False, - donot_pickle=False, - ignore_first_depends_on_past=False, - ignore_task_deps=False, - pool=None, - delay_on_limit_secs=1.0, - verbose=False, - conf=None, - rerun_failed_tasks=False, - run_backwards=False, - *args, **kwargs): + self, + dag, + start_date=None, + end_date=None, + mark_success=False, + donot_pickle=False, + ignore_first_depends_on_past=False, + ignore_task_deps=False, + pool=None, + delay_on_limit_secs=1.0, + verbose=False, + conf=None, + rerun_failed_tasks=False, + run_backwards=False, + *args, + **kwargs, + ): """ :param dag: DAG object. :type dag: airflow.models.DAG @@ -234,7 +239,7 @@ def _update_counters(self, ti_status, session=None): self.log.warning( "FIXME: task instance %s state was set to none externally or " "reaching concurrency limits. 
Re-adding task to queue.", - ti + ti, ) tis_to_be_scheduled.append(ti) ti_status.running.pop(key) @@ -260,10 +265,7 @@ def _manage_executor_state(self, running): for key, value in list(executor.get_event_buffer().items()): state, info = value if key not in running: - self.log.warning( - "%s state %s not in running=%s", - key, state, running.values() - ) + self.log.warning("%s state %s not in running=%s", key, state, running.values()) continue ti = running[key] @@ -272,9 +274,11 @@ def _manage_executor_state(self, running): self.log.debug("Executor state: %s task %s", state, ti) if state in (State.FAILED, State.SUCCESS) and ti.state in self.STATES_COUNT_AS_RUNNING: - msg = ("Executor reports task instance {} finished ({}) " - "although the task says its {}. Was the task " - "killed externally? Info: {}".format(ti, state, ti.state, info)) + msg = ( + "Executor reports task instance {} finished ({}) " + "although the task says its {}. Was the task " + "killed externally? Info: {}".format(ti, state, ti.state, info) + ) self.log.error(msg) ti.handle_failure(msg) @@ -297,11 +301,7 @@ def _get_dag_run(self, run_date: datetime, dag: DAG, session: Session = None): # check if we are scheduling on top of a already existing dag_run # we could find a "scheduled" run instead of a "backfill" - runs = DagRun.find( - dag_id=dag.dag_id, - execution_date=run_date, - session=session - ) + runs = DagRun.find(dag_id=dag.dag_id, execution_date=run_date, session=session) run: Optional[DagRun] if runs: run = runs[0] @@ -312,8 +312,7 @@ def _get_dag_run(self, run_date: datetime, dag: DAG, session: Session = None): # enforce max_active_runs limit for dag, special cases already # handled by respect_dag_max_active_limit - if (respect_dag_max_active_limit and - current_active_dag_count >= dag.max_active_runs): + if respect_dag_max_active_limit and current_active_dag_count >= dag.max_active_runs: return None run = run or dag.create_dagrun( @@ -378,22 +377,28 @@ def _log_progress(self, ti_status): self.log.info( '[backfill progress] | finished run %s of %s | tasks waiting: %s | succeeded: %s | ' 'running: %s | failed: %s | skipped: %s | deadlocked: %s | not ready: %s', - ti_status.finished_runs, ti_status.total_runs, len(ti_status.to_run), len(ti_status.succeeded), - len(ti_status.running), len(ti_status.failed), len(ti_status.skipped), len(ti_status.deadlocked), - len(ti_status.not_ready) + ti_status.finished_runs, + ti_status.total_runs, + len(ti_status.to_run), + len(ti_status.succeeded), + len(ti_status.running), + len(ti_status.failed), + len(ti_status.skipped), + len(ti_status.deadlocked), + len(ti_status.not_ready), ) - self.log.debug( - "Finished dag run loop iteration. Remaining tasks %s", - ti_status.to_run.values() - ) + self.log.debug("Finished dag run loop iteration. Remaining tasks %s", ti_status.to_run.values()) @provide_session - def _process_backfill_task_instances(self, # pylint: disable=too-many-statements - ti_status, - executor, - pickle_id, - start_date=None, session=None): + def _process_backfill_task_instances( # pylint: disable=too-many-statements + self, + ti_status, + executor, + pickle_id, + start_date=None, + session=None, + ): """ Process a set of task instances from a set of dag runs. 
Special handling is done to account for different task instance states that could be present when running @@ -414,8 +419,7 @@ def _process_backfill_task_instances(self, # pylint: disable=too-many-statement """ executed_run_dates = [] - while ((len(ti_status.to_run) > 0 or len(ti_status.running) > 0) and - len(ti_status.deadlocked) == 0): + while (len(ti_status.to_run) > 0 or len(ti_status.running) > 0) and len(ti_status.deadlocked) == 0: self.log.debug("*** Clearing out not_ready list ***") ti_status.not_ready.clear() @@ -430,11 +434,10 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return task = self.dag.get_task(ti.task_id, include_subdags=True) ti.task = task - ignore_depends_on_past = ( - self.ignore_first_depends_on_past and - ti.execution_date == (start_date or ti.start_date)) - self.log.debug( - "Task instance to run %s state %s", ti, ti.state) + ignore_depends_on_past = self.ignore_first_depends_on_past and ti.execution_date == ( + start_date or ti.start_date + ) + self.log.debug("Task instance to run %s state %s", ti, ti.state) # The task was already marked successful or skipped by a # different Job. Don't rerun it. @@ -457,8 +460,7 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return # in case max concurrency has been reached at task runtime elif ti.state == State.NONE: self.log.warning( - "FIXME: task instance {} state was set to None " - "externally. This should not happen" + "FIXME: task instance {} state was set to None " "externally. This should not happen" ) ti.set_state(State.SCHEDULED, session=session) if self.rerun_failed_tasks: @@ -483,19 +485,17 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return deps=BACKFILL_QUEUED_DEPS, ignore_depends_on_past=ignore_depends_on_past, ignore_task_deps=self.ignore_task_deps, - flag_upstream_failed=True) + flag_upstream_failed=True, + ) # Is the task runnable? 
-- then run it # the dependency checker can change states of tis if ti.are_dependencies_met( - dep_context=backfill_context, - session=session, - verbose=self.verbose): + dep_context=backfill_context, session=session, verbose=self.verbose + ): if executor.has_task(ti): self.log.debug( - "Task Instance %s already in executor " - "waiting for queue to clear", - ti + "Task Instance %s already in executor " "waiting for queue to clear", ti ) else: self.log.debug('Sending %s to executor', ti) @@ -507,7 +507,8 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return cfg_path = None if self.executor_class in ( - ExecutorLoader.LOCAL_EXECUTOR, ExecutorLoader.SEQUENTIAL_EXECUTOR + ExecutorLoader.LOCAL_EXECUTOR, + ExecutorLoader.SEQUENTIAL_EXECUTOR, ): cfg_path = tmp_configuration_copy() @@ -518,7 +519,8 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return ignore_task_deps=self.ignore_task_deps, ignore_depends_on_past=ignore_depends_on_past, pool=self.pool, - cfg_path=cfg_path) + cfg_path=cfg_path, + ) ti_status.running[key] = ti ti_status.to_run.pop(key) session.commit() @@ -534,9 +536,7 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return # special case if ti.state == State.UP_FOR_RETRY: - self.log.debug( - "Task instance %s retry period not " - "expired yet", ti) + self.log.debug("Task instance %s retry period not " "expired yet", ti) if key in ti_status.running: ti_status.running.pop(key) ti_status.to_run[key] = ti @@ -544,9 +544,7 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return # special case if ti.state == State.UP_FOR_RESCHEDULE: - self.log.debug( - "Task instance %s reschedule period not " - "expired yet", ti) + self.log.debug("Task instance %s reschedule period not " "expired yet", ti) if key in ti_status.running: ti_status.running.pop(key) ti_status.to_run[key] = ti @@ -562,9 +560,7 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return if task.task_id != ti.task_id: continue - pool = session.query(models.Pool) \ - .filter(models.Pool.pool == task.pool) \ - .first() + pool = session.query(models.Pool).filter(models.Pool.pool == task.pool).first() if not pool: raise PoolNotFound(f'Unknown pool: {task.pool}') @@ -572,8 +568,8 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return if open_slots <= 0: raise NoAvailablePoolSlot( "Not scheduling since there are " - "{} open slots in pool {}".format( - open_slots, task.pool)) + "{} open slots in pool {}".format(open_slots, task.pool) + ) num_running_task_instances_in_dag = DAG.get_num_task_instances( self.dag_id, @@ -582,8 +578,7 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return if num_running_task_instances_in_dag >= self.dag.concurrency: raise DagConcurrencyLimitReached( - "Not scheduling since DAG concurrency limit " - "is reached." + "Not scheduling since DAG concurrency limit " "is reached." ) if task.task_concurrency: @@ -595,8 +590,7 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return if num_running_task_instances_in_task >= task.task_concurrency: raise TaskConcurrencyLimitReached( - "Not scheduling since Task concurrency limit " - "is reached." + "Not scheduling since Task concurrency limit " "is reached." 
) _per_task_process(key, ti) @@ -610,13 +604,12 @@ def _per_task_process(key, ti, session=None): # pylint: disable=too-many-return # If the set of tasks that aren't ready ever equals the set of # tasks to run and there are no running tasks then the backfill # is deadlocked - if (ti_status.not_ready and - ti_status.not_ready == set(ti_status.to_run) and - len(ti_status.running) == 0): - self.log.warning( - "Deadlock discovered for ti_status.to_run=%s", - ti_status.to_run.values() - ) + if ( + ti_status.not_ready + and ti_status.not_ready == set(ti_status.to_run) + and len(ti_status.running) == 0 + ): + self.log.warning("Deadlock discovered for ti_status.to_run=%s", ti_status.to_run.values()) ti_status.deadlocked.update(ti_status.to_run.values()) ti_status.to_run.clear() @@ -645,19 +638,17 @@ def _collect_errors(self, ti_status, session=None): def tabulate_ti_keys_set(set_ti_keys: Set[TaskInstanceKey]) -> str: # Sorting by execution date first sorted_ti_keys = sorted( - set_ti_keys, key=lambda ti_key: - (ti_key.execution_date, ti_key.dag_id, ti_key.task_id, ti_key.try_number) + set_ti_keys, + key=lambda ti_key: (ti_key.execution_date, ti_key.dag_id, ti_key.task_id, ti_key.try_number), ) return tabulate(sorted_ti_keys, headers=["DAG ID", "Task ID", "Execution date", "Try number"]) def tabulate_tis_set(set_tis: Set[TaskInstance]) -> str: # Sorting by execution date first sorted_tis = sorted( - set_tis, key=lambda ti: (ti.execution_date, ti.dag_id, ti.task_id, ti.try_number)) - tis_values = ( - (ti.dag_id, ti.task_id, ti.execution_date, ti.try_number) - for ti in sorted_tis + set_tis, key=lambda ti: (ti.execution_date, ti.dag_id, ti.task_id, ti.try_number) ) + tis_values = ((ti.dag_id, ti.task_id, ti.execution_date, ti.try_number) for ti in sorted_tis) return tabulate(tis_values, headers=["DAG ID", "Task ID", "Execution date", "Try number"]) err = '' @@ -670,19 +661,21 @@ def tabulate_tis_set(set_tis: Set[TaskInstance]) -> str: t.are_dependencies_met( dep_context=DepContext(ignore_depends_on_past=False), session=session, - verbose=self.verbose) != - t.are_dependencies_met( - dep_context=DepContext(ignore_depends_on_past=True), - session=session, - verbose=self.verbose) - for t in ti_status.deadlocked) + verbose=self.verbose, + ) + != t.are_dependencies_met( + dep_context=DepContext(ignore_depends_on_past=True), session=session, verbose=self.verbose + ) + for t in ti_status.deadlocked + ) if deadlocked_depends_on_past: err += ( 'Some of the deadlocked tasks were unable to run because ' 'of "depends_on_past" relationships. Try running the ' 'backfill with the option ' '"ignore_first_depends_on_past=True" or passing "-I" at ' - 'the command line.') + 'the command line.' + ) err += '\nThese tasks have succeeded:\n' err += tabulate_ti_keys_set(ti_status.succeeded) err += '\n\nThese tasks are running:\n' @@ -697,8 +690,7 @@ def tabulate_tis_set(set_tis: Set[TaskInstance]) -> str: return err @provide_session - def _execute_for_run_dates(self, run_dates, ti_status, executor, pickle_id, - start_date, session=None): + def _execute_for_run_dates(self, run_dates, ti_status, executor, pickle_id, start_date, session=None): """ Computes the dag runs and their respective task instances for the given run dates and executes the task instances. 
@@ -720,8 +712,7 @@ def _execute_for_run_dates(self, run_dates, ti_status, executor, pickle_id, for next_run_date in run_dates: for dag in [self.dag] + self.dag.subdags: dag_run = self._get_dag_run(next_run_date, dag, session=session) - tis_map = self._task_instances_for_dag_run(dag_run, - session=session) + tis_map = self._task_instances_for_dag_run(dag_run, session=session) if dag_run is None: continue @@ -733,7 +724,8 @@ def _execute_for_run_dates(self, run_dates, ti_status, executor, pickle_id, executor=executor, pickle_id=pickle_id, start_date=start_date, - session=session) + session=session, + ) ti_status.executed_dag_run_dates.update(processed_dag_run_dates) @@ -764,14 +756,15 @@ def _execute(self, session=None): start_date = self.bf_start_date # Get intervals between the start/end dates, which will turn into dag runs - run_dates = self.dag.get_run_dates(start_date=start_date, - end_date=self.bf_end_date) + run_dates = self.dag.get_run_dates(start_date=start_date, end_date=self.bf_end_date) if self.run_backwards: tasks_that_depend_on_past = [t.task_id for t in self.dag.task_dict.values() if t.depends_on_past] if tasks_that_depend_on_past: raise AirflowException( 'You cannot backfill backwards because one or more tasks depend_on_past: {}'.format( - ",".join(tasks_that_depend_on_past))) + ",".join(tasks_that_depend_on_past) + ) + ) run_dates = run_dates[::-1] if len(run_dates) == 0: @@ -782,7 +775,9 @@ def _execute(self, session=None): pickle_id = None if not self.donot_pickle and self.executor_class not in ( - ExecutorLoader.LOCAL_EXECUTOR, ExecutorLoader.SEQUENTIAL_EXECUTOR, ExecutorLoader.DASK_EXECUTOR, + ExecutorLoader.LOCAL_EXECUTOR, + ExecutorLoader.SEQUENTIAL_EXECUTOR, + ExecutorLoader.DASK_EXECUTOR, ): pickle = DagPickle(self.dag) session.add(pickle) @@ -797,19 +792,20 @@ def _execute(self, session=None): try: # pylint: disable=too-many-nested-blocks remaining_dates = ti_status.total_runs while remaining_dates > 0: - dates_to_process = [run_date for run_date in run_dates if run_date not in - ti_status.executed_dag_run_dates] - - self._execute_for_run_dates(run_dates=dates_to_process, - ti_status=ti_status, - executor=executor, - pickle_id=pickle_id, - start_date=start_date, - session=session) - - remaining_dates = ( - ti_status.total_runs - len(ti_status.executed_dag_run_dates) + dates_to_process = [ + run_date for run_date in run_dates if run_date not in ti_status.executed_dag_run_dates + ] + + self._execute_for_run_dates( + run_dates=dates_to_process, + ti_status=ti_status, + executor=executor, + pickle_id=pickle_id, + start_date=start_date, + session=session, ) + + remaining_dates = ti_status.total_runs - len(ti_status.executed_dag_run_dates) err = self._collect_errors(ti_status=ti_status, session=session) if err: raise BackfillUnfinished(err, ti_status) @@ -818,7 +814,7 @@ def _execute(self, session=None): self.log.info( "max_active_runs limit for dag %s has been reached " " - waiting for other dag runs to finish", - self.dag_id + self.dag_id, ) time.sleep(self.delay_on_limit_secs) except (KeyboardInterrupt, SystemExit): @@ -854,21 +850,23 @@ def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None): resettable_states = [State.SCHEDULED, State.QUEUED] if filter_by_dag_run is None: resettable_tis = ( - session - .query(TaskInstance) + session.query(TaskInstance) .join( DagRun, and_( TaskInstance.dag_id == DagRun.dag_id, - TaskInstance.execution_date == DagRun.execution_date)) + TaskInstance.execution_date == DagRun.execution_date, + ), + ) .filter( # 
pylint: disable=comparison-with-callable DagRun.state == State.RUNNING, DagRun.run_type != DagRunType.BACKFILL_JOB, - TaskInstance.state.in_(resettable_states))).all() + TaskInstance.state.in_(resettable_states), + ) + ).all() else: - resettable_tis = filter_by_dag_run.get_task_instances(state=resettable_states, - session=session) + resettable_tis = filter_by_dag_run.get_task_instances(state=resettable_states, session=session) tis_to_reset = [] # Can't use an update here since it doesn't support joins for ti in resettable_tis: @@ -883,9 +881,12 @@ def query(result, items): return result filter_for_tis = TaskInstance.filter_for_tis(items) - reset_tis = session.query(TaskInstance).filter( - filter_for_tis, TaskInstance.state.in_(resettable_states) - ).with_for_update().all() + reset_tis = ( + session.query(TaskInstance) + .filter(filter_for_tis, TaskInstance.state.in_(resettable_states)) + .with_for_update() + .all() + ) for ti in reset_tis: ti.state = State.NONE @@ -893,16 +894,10 @@ def query(result, items): return result + reset_tis - reset_tis = helpers.reduce_in_chunks(query, - tis_to_reset, - [], - self.max_tis_per_query) + reset_tis = helpers.reduce_in_chunks(query, tis_to_reset, [], self.max_tis_per_query) task_instance_str = '\n\t'.join([repr(x) for x in reset_tis]) session.commit() - self.log.info( - "Reset the following %s TaskInstances:\n\t%s", - len(reset_tis), task_instance_str - ) + self.log.info("Reset the following %s TaskInstances:\n\t%s", len(reset_tis), task_instance_str) return len(reset_tis) diff --git a/airflow/jobs/base_job.py b/airflow/jobs/base_job.py index caff8c52279f3..1ca0aaa19404b 100644 --- a/airflow/jobs/base_job.py +++ b/airflow/jobs/base_job.py @@ -53,7 +53,9 @@ class BaseJob(Base, LoggingMixin): __tablename__ = "job" id = Column(Integer, primary_key=True) - dag_id = Column(String(ID_LEN),) + dag_id = Column( + String(ID_LEN), + ) state = Column(String(20)) job_type = Column(String(30)) start_date = Column(UtcDateTime()) @@ -63,10 +65,7 @@ class BaseJob(Base, LoggingMixin): hostname = Column(String(500)) unixname = Column(String(1000)) - __mapper_args__ = { - 'polymorphic_on': job_type, - 'polymorphic_identity': 'BaseJob' - } + __mapper_args__ = {'polymorphic_on': job_type, 'polymorphic_identity': 'BaseJob'} __table_args__ = ( Index('job_type_heart', job_type, latest_heartbeat), @@ -95,11 +94,7 @@ class BaseJob(Base, LoggingMixin): heartrate = conf.getfloat('scheduler', 'JOB_HEARTBEAT_SEC') - def __init__( - self, - executor=None, - heartrate=None, - *args, **kwargs): + def __init__(self, executor=None, heartrate=None, *args, **kwargs): self.hostname = get_hostname() self.executor = executor or ExecutorLoader.get_default_executor() self.executor_class = self.executor.__class__.__name__ @@ -139,8 +134,9 @@ def is_alive(self, grace_multiplier=2.1): :rtype: boolean """ return ( - self.state == State.RUNNING and - (timezone.utcnow() - self.latest_heartbeat).total_seconds() < self.heartrate * grace_multiplier + self.state == State.RUNNING + and (timezone.utcnow() - self.latest_heartbeat).total_seconds() + < self.heartrate * grace_multiplier ) @provide_session @@ -206,9 +202,9 @@ def heartbeat(self, only_if_necessary: bool = False): # Figure out how long to sleep for sleep_for = 0 if self.latest_heartbeat: - seconds_remaining = self.heartrate - \ - (timezone.utcnow() - self.latest_heartbeat)\ - .total_seconds() + seconds_remaining = ( + self.heartrate - (timezone.utcnow() - self.latest_heartbeat).total_seconds() + ) sleep_for = max(0, seconds_remaining) 
sleep(sleep_for) @@ -224,9 +220,7 @@ def heartbeat(self, only_if_necessary: bool = False): self.heartbeat_callback(session=session) self.log.debug('[heartbeat]') except OperationalError: - Stats.incr( - convert_camel_to_snake(self.__class__.__name__) + '_heartbeat_failure', 1, - 1) + Stats.incr(convert_camel_to_snake(self.__class__.__name__) + '_heartbeat_failure', 1, 1) self.log.exception("%s heartbeat got an exception", self.__class__.__name__) # We didn't manage to heartbeat, so make sure that the timestamp isn't updated self.latest_heartbeat = previous_heartbeat diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index 0c3e86215c44e..f4d4ef0efcf86 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -36,21 +36,21 @@ class LocalTaskJob(BaseJob): """LocalTaskJob runs a single task instance.""" - __mapper_args__ = { - 'polymorphic_identity': 'LocalTaskJob' - } + __mapper_args__ = {'polymorphic_identity': 'LocalTaskJob'} def __init__( - self, - task_instance: TaskInstance, - ignore_all_deps: bool = False, - ignore_depends_on_past: bool = False, - ignore_task_deps: bool = False, - ignore_ti_state: bool = False, - mark_success: bool = False, - pickle_id: Optional[str] = None, - pool: Optional[str] = None, - *args, **kwargs): + self, + task_instance: TaskInstance, + ignore_all_deps: bool = False, + ignore_depends_on_past: bool = False, + ignore_task_deps: bool = False, + ignore_ti_state: bool = False, + mark_success: bool = False, + pickle_id: Optional[str] = None, + pool: Optional[str] = None, + *args, + **kwargs, + ): self.task_instance = task_instance self.dag_id = task_instance.dag_id self.ignore_all_deps = ignore_all_deps @@ -82,13 +82,14 @@ def signal_handler(signum, frame): signal.signal(signal.SIGTERM, signal_handler) if not self.task_instance.check_and_change_state_before_execution( - mark_success=self.mark_success, - ignore_all_deps=self.ignore_all_deps, - ignore_depends_on_past=self.ignore_depends_on_past, - ignore_task_deps=self.ignore_task_deps, - ignore_ti_state=self.ignore_ti_state, - job_id=self.id, - pool=self.pool): + mark_success=self.mark_success, + ignore_all_deps=self.ignore_all_deps, + ignore_depends_on_past=self.ignore_depends_on_past, + ignore_task_deps=self.ignore_task_deps, + ignore_ti_state=self.ignore_ti_state, + job_id=self.id, + pool=self.pool, + ): self.log.info("Task is not able to be run") return @@ -104,10 +105,12 @@ def signal_handler(signum, frame): max_wait_time = max( 0, # Make sure this value is never negative, min( - (heartbeat_time_limit - - (timezone.utcnow() - self.latest_heartbeat).total_seconds() * 0.75), + ( + heartbeat_time_limit + - (timezone.utcnow() - self.latest_heartbeat).total_seconds() * 0.75 + ), self.heartrate, - ) + ), ) return_code = self.task_runner.return_code(timeout=max_wait_time) @@ -124,10 +127,10 @@ def signal_handler(signum, frame): if time_since_last_heartbeat > heartbeat_time_limit: Stats.incr('local_task_job_prolonged_heartbeat_failure', 1, 1) self.log.error("Heartbeat time limit exceeded!") - raise AirflowException("Time since last heartbeat({:.2f}s) " - "exceeded limit ({}s)." 
- .format(time_since_last_heartbeat, - heartbeat_time_limit)) + raise AirflowException( + "Time since last heartbeat({:.2f}s) " + "exceeded limit ({}s).".format(time_since_last_heartbeat, heartbeat_time_limit) + ) finally: self.on_kill() @@ -150,25 +153,21 @@ def heartbeat_callback(self, session=None): fqdn = get_hostname() same_hostname = fqdn == ti.hostname if not same_hostname: - self.log.warning("The recorded hostname %s " - "does not match this instance's hostname " - "%s", ti.hostname, fqdn) + self.log.warning( + "The recorded hostname %s " "does not match this instance's hostname " "%s", + ti.hostname, + fqdn, + ) raise AirflowException("Hostname of job runner does not match") current_pid = os.getpid() same_process = ti.pid == current_pid if not same_process: - self.log.warning("Recorded pid %s does not match " - "the current pid %s", ti.pid, current_pid) + self.log.warning("Recorded pid %s does not match " "the current pid %s", ti.pid, current_pid) raise AirflowException("PID of job runner does not match") - elif ( - self.task_runner.return_code() is None and - hasattr(self.task_runner, 'process') - ): + elif self.task_runner.return_code() is None and hasattr(self.task_runner, 'process'): self.log.warning( - "State of this instance has been externally set to %s. " - "Terminating instance.", - ti.state + "State of this instance has been externally set to %s. " "Terminating instance.", ti.state ) if ti.state == State.FAILED and ti.task.on_failure_callback: context = ti.get_template_context() diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 38a79ffa51fb2..ceee033789aae 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -52,7 +52,10 @@ from airflow.ti_deps.dependencies_states import EXECUTION_STATES from airflow.utils import timezone from airflow.utils.callback_requests import ( - CallbackRequest, DagCallbackRequest, SlaCallbackRequest, TaskCallbackRequest, + CallbackRequest, + DagCallbackRequest, + SlaCallbackRequest, + TaskCallbackRequest, ) from airflow.utils.dag_processing import AbstractDagFileProcessorProcess, DagFileProcessorAgent from airflow.utils.email import get_email_address_list, send_email @@ -184,9 +187,7 @@ def _run_file_processor( ) result_channel.send(result) end_time = time.time() - log.info( - "Processing %s took %.3f seconds", file_path, end_time - start_time - ) + log.info("Processing %s took %.3f seconds", file_path, end_time - start_time) except Exception: # pylint: disable=broad-except # Log exceptions through the logging framework. log.exception("Got an exception! 
Propagating...") @@ -213,9 +214,9 @@ def start(self) -> None: self._pickle_dags, self._dag_ids, f"DagFileProcessor{self._instance_id}", - self._callback_requests + self._callback_requests, ), - name=f"DagFileProcessor{self._instance_id}-Process" + name=f"DagFileProcessor{self._instance_id}-Process", ) self._process = process self._start_time = timezone.utcnow() @@ -400,28 +401,24 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: return qry = ( - session - .query( - TI.task_id, - func.max(TI.execution_date).label('max_ti') - ) + session.query(TI.task_id, func.max(TI.execution_date).label('max_ti')) .with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql') .filter(TI.dag_id == dag.dag_id) - .filter( - or_( - TI.state == State.SUCCESS, - TI.state == State.SKIPPED - ) - ) + .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED)) .filter(TI.task_id.in_(dag.task_ids)) - .group_by(TI.task_id).subquery('sq') + .group_by(TI.task_id) + .subquery('sq') ) - max_tis: List[TI] = session.query(TI).filter( - TI.dag_id == dag.dag_id, - TI.task_id == qry.c.task_id, - TI.execution_date == qry.c.max_ti, - ).all() + max_tis: List[TI] = ( + session.query(TI) + .filter( + TI.dag_id == dag.dag_id, + TI.task_id == qry.c.task_id, + TI.execution_date == qry.c.max_ti, + ) + .all() + ) ts = timezone.utcnow() for ti in max_tis: @@ -433,31 +430,26 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: while dttm < timezone.utcnow(): following_schedule = dag.following_schedule(dttm) if following_schedule + task.sla < timezone.utcnow(): - session.merge(SlaMiss( - task_id=ti.task_id, - dag_id=ti.dag_id, - execution_date=dttm, - timestamp=ts)) + session.merge( + SlaMiss(task_id=ti.task_id, dag_id=ti.dag_id, execution_date=dttm, timestamp=ts) + ) dttm = dag.following_schedule(dttm) session.commit() + # pylint: disable=singleton-comparison slas: List[SlaMiss] = ( - session - .query(SlaMiss) - .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id) # noqa pylint: disable=singleton-comparison + session.query(SlaMiss) + .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id) # noqa .all() ) + # pylint: enable=singleton-comparison if slas: # pylint: disable=too-many-nested-blocks sla_dates: List[datetime.datetime] = [sla.execution_date for sla in slas] fetched_tis: List[TI] = ( - session - .query(TI) - .filter( - TI.state != State.SUCCESS, - TI.execution_date.in_(sla_dates), - TI.dag_id == dag.dag_id - ).all() + session.query(TI) + .filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id) + .all() ) blocking_tis: List[TI] = [] for ti in fetched_tis: @@ -468,12 +460,10 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: session.delete(ti) session.commit() - task_list = "\n".join([ - sla.task_id + ' on ' + sla.execution_date.isoformat() - for sla in slas]) - blocking_task_list = "\n".join([ - ti.task_id + ' on ' + ti.execution_date.isoformat() - for ti in blocking_tis]) + task_list = "\n".join([sla.task_id + ' on ' + sla.execution_date.isoformat() for sla in slas]) + blocking_task_list = "\n".join( + [ti.task_id + ' on ' + ti.execution_date.isoformat() for ti in blocking_tis] + ) # Track whether email or any alert notification sent # We consider email or the alert callback as notifications email_sent = False @@ -482,8 +472,7 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: # Execute the alert callback self.log.info('Calling SLA miss callback') try: - dag.sla_miss_callback(dag, 
task_list, blocking_task_list, slas, - blocking_tis) + dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis) notification_sent = True except Exception: # pylint: disable=broad-except self.log.exception("Could not call sla_miss_callback for DAG %s", dag.dag_id) @@ -501,8 +490,8 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: except TaskNotFound: # task already deleted from DAG, skip it self.log.warning( - "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.", - sla.task_id) + "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.", sla.task_id + ) continue tasks_missed_sla.append(task) @@ -515,17 +504,12 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: emails |= set(task.email) if emails: try: - send_email( - emails, - f"[airflow] SLA miss on DAG={dag.dag_id}", - email_content - ) + send_email(emails, f"[airflow] SLA miss on DAG={dag.dag_id}", email_content) email_sent = True notification_sent = True except Exception: # pylint: disable=broad-except Stats.incr('sla_email_notification_failure') - self.log.exception("Could not send SLA Miss email notification for" - " DAG %s", dag.dag_id) + self.log.exception("Could not send SLA Miss email notification for" " DAG %s", dag.dag_id) # If we sent any notification, update the sla_miss table if notification_sent: for sla in slas: @@ -548,24 +532,18 @@ def update_import_errors(session: Session, dagbag: DagBag) -> None: """ # Clear the errors of the processed files for dagbag_file in dagbag.file_last_changed: - session.query(errors.ImportError).filter( - errors.ImportError.filename == dagbag_file - ).delete() + session.query(errors.ImportError).filter(errors.ImportError.filename == dagbag_file).delete() # Add the errors of the processed files for filename, stacktrace in dagbag.import_errors.items(): - session.add(errors.ImportError( - filename=filename, - timestamp=timezone.utcnow(), - stacktrace=stacktrace)) + session.add( + errors.ImportError(filename=filename, timestamp=timezone.utcnow(), stacktrace=stacktrace) + ) session.commit() @provide_session def execute_callbacks( - self, - dagbag: DagBag, - callback_requests: List[CallbackRequest], - session: Session = None + self, dagbag: DagBag, callback_requests: List[CallbackRequest], session: Session = None ) -> None: """ Execute on failure callbacks. These objects can come from SchedulerJob or from @@ -588,7 +566,7 @@ def execute_callbacks( self.log.exception( "Error executing %s callback for file: %s", request.__class__.__name__, - request.full_filepath + request.full_filepath, ) session.commit() @@ -598,10 +576,7 @@ def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, se dag = dagbag.dags[request.dag_id] dag_run = dag.get_dagrun(execution_date=request.execution_date, session=session) dag.handle_callback( - dagrun=dag_run, - success=not request.is_failure_callback, - reason=request.msg, - session=session + dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session ) def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest): @@ -627,7 +602,7 @@ def process_file( file_path: str, callback_requests: List[CallbackRequest], pickle_dags: bool = False, - session: Session = None + session: Session = None, ) -> Tuple[int, int]: """ Process a Python file containing Airflow DAGs. 
@@ -728,20 +703,20 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes :type do_pickle: bool """ - __mapper_args__ = { - 'polymorphic_identity': 'SchedulerJob' - } + __mapper_args__ = {'polymorphic_identity': 'SchedulerJob'} heartrate: int = conf.getint('scheduler', 'SCHEDULER_HEARTBEAT_SEC') def __init__( - self, - subdir: str = settings.DAGS_FOLDER, - num_runs: int = conf.getint('scheduler', 'num_runs'), - num_times_parse_dags: int = -1, - processor_poll_interval: float = conf.getfloat('scheduler', 'processor_poll_interval'), - do_pickle: bool = False, - log: Any = None, - *args, **kwargs): + self, + subdir: str = settings.DAGS_FOLDER, + num_runs: int = conf.getint('scheduler', 'num_runs'), + num_times_parse_dags: int = -1, + processor_poll_interval: float = conf.getfloat('scheduler', 'processor_poll_interval'), + do_pickle: bool = False, + log: Any = None, + *args, + **kwargs, + ): self.subdir = subdir self.num_runs = num_runs @@ -796,16 +771,13 @@ def is_alive(self, grace_multiplier: Optional[float] = None) -> bool: return super().is_alive(grace_multiplier=grace_multiplier) scheduler_health_check_threshold: int = conf.getint('scheduler', 'scheduler_health_check_threshold') return ( - self.state == State.RUNNING and - (timezone.utcnow() - self.latest_heartbeat).total_seconds() < scheduler_health_check_threshold + self.state == State.RUNNING + and (timezone.utcnow() - self.latest_heartbeat).total_seconds() < scheduler_health_check_threshold ) @provide_session def _change_state_for_tis_without_dagrun( - self, - old_states: List[str], - new_state: str, - session: Session = None + self, old_states: List[str], new_state: str, session: Session = None ) -> None: """ For all DAG IDs in the DagBag, look for task instances in the @@ -820,20 +792,24 @@ def _change_state_for_tis_without_dagrun( :type new_state: airflow.utils.state.State """ tis_changed = 0 - query = session \ - .query(models.TaskInstance) \ - .outerjoin(models.TaskInstance.dag_run) \ - .filter(models.TaskInstance.dag_id.in_(list(self.dagbag.dag_ids))) \ - .filter(models.TaskInstance.state.in_(old_states)) \ - .filter(or_( - # pylint: disable=comparison-with-callable - models.DagRun.state != State.RUNNING, - models.DagRun.state.is_(None))) # pylint: disable=no-member + query = ( + session.query(models.TaskInstance) + .outerjoin(models.TaskInstance.dag_run) + .filter(models.TaskInstance.dag_id.in_(list(self.dagbag.dag_ids))) + .filter(models.TaskInstance.state.in_(old_states)) + .filter( + or_( + # pylint: disable=comparison-with-callable + models.DagRun.state != State.RUNNING, + # pylint: disable=no-member + models.DagRun.state.is_(None), + ) + ) + ) # We need to do this for mysql as well because it can cause deadlocks # as discussed in https://issues.apache.org/jira/browse/AIRFLOW-2516 if self.using_sqlite or self.using_mysql: - tis_to_change: List[TI] = with_row_locks(query, of=TI, - **skip_locked(session=session)).all() + tis_to_change: List[TI] = with_row_locks(query, of=TI, **skip_locked(session=session)).all() for ti in tis_to_change: ti.set_state(new_state, session=session) tis_changed += 1 @@ -847,25 +823,29 @@ def _change_state_for_tis_without_dagrun( # Only add end_date and duration if the new_state is 'success', 'failed' or 'skipped' if new_state in State.finished: - ti_prop_update.update({ - models.TaskInstance.end_date: current_time, - models.TaskInstance.duration: 0, - }) + ti_prop_update.update( + { + models.TaskInstance.end_date: current_time, + models.TaskInstance.duration: 0, + } + ) - 
tis_changed = session \ - .query(models.TaskInstance) \ + tis_changed = ( + session.query(models.TaskInstance) .filter( models.TaskInstance.dag_id == subq.c.dag_id, models.TaskInstance.task_id == subq.c.task_id, - models.TaskInstance.execution_date == - subq.c.execution_date) \ + models.TaskInstance.execution_date == subq.c.execution_date, + ) .update(ti_prop_update, synchronize_session=False) + ) session.flush() if tis_changed > 0: self.log.warning( "Set %s task instances to state=%s as their associated DagRun was not in RUNNING state", - tis_changed, new_state + tis_changed, + new_state, ) Stats.gauge('scheduler.tasks.without_dagrun', tis_changed) @@ -883,8 +863,7 @@ def __get_concurrency_maps( :rtype: tuple[dict[str, int], dict[tuple[str, str], int]] """ ti_concurrency_query: List[Tuple[str, str, int]] = ( - session - .query(TI.task_id, TI.dag_id, func.count('*')) + session.query(TI.task_id, TI.dag_id, func.count('*')) .filter(TI.state.in_(states)) .group_by(TI.task_id, TI.dag_id) ).all() @@ -927,11 +906,9 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = # DagRuns which are not backfilled, in the given states, # and the dag is not paused query = ( - session - .query(TI) + session.query(TI) .outerjoin(TI.dag_run) - .filter(or_(DR.run_id.is_(None), - DR.run_type != DagRunType.BACKFILL_JOB)) + .filter(or_(DR.run_id.is_(None), DR.run_type != DagRunType.BACKFILL_JOB)) .join(TI.dag_model) .filter(not_(DM.is_paused)) .filter(TI.state == State.SCHEDULED) @@ -952,12 +929,8 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = return executable_tis # Put one task instance on each line - task_instance_str = "\n\t".join( - [repr(x) for x in task_instances_to_examine]) - self.log.info( - "%s tasks up for execution:\n\t%s", len(task_instances_to_examine), - task_instance_str - ) + task_instance_str = "\n\t".join([repr(x) for x in task_instances_to_examine]) + self.log.info("%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str) pool_to_task_instances: DefaultDict[str, List[models.Pool]] = defaultdict(list) for task_instance in task_instances_to_examine: @@ -967,7 +940,8 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = dag_concurrency_map: DefaultDict[str, int] task_concurrency_map: DefaultDict[Tuple[str, str], int] dag_concurrency_map, task_concurrency_map = self.__get_concurrency_maps( - states=list(EXECUTION_STATES), session=session) + states=list(EXECUTION_STATES), session=session + ) num_tasks_in_executor = 0 # Number of tasks that cannot be scheduled because of no open slot in pool @@ -979,10 +953,7 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = for pool, task_instances in pool_to_task_instances.items(): pool_name = pool if pool not in pools: - self.log.warning( - "Tasks using non-existent pool '%s' will not be scheduled", - pool - ) + self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool) continue open_slots = pools[pool]["open"] @@ -991,19 +962,19 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = self.log.info( "Figuring out tasks to run in Pool(name=%s) with %s open slots " "and %s task instances ready to be queued", - pool, open_slots, num_ready + pool, + open_slots, + num_ready, ) priority_sorted_task_instances = sorted( - task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date)) + task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date) + 
) num_starving_tasks = 0 for current_index, task_instance in enumerate(priority_sorted_task_instances): if open_slots <= 0: - self.log.info( - "Not scheduling since there are %s open slots in pool %s", - open_slots, pool - ) + self.log.info("Not scheduling since there are %s open slots in pool %s", open_slots, pool) # Can't schedule any more since there are no more open slots. num_unhandled = len(priority_sorted_task_instances) - current_index num_starving_tasks += num_unhandled @@ -1018,13 +989,17 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = dag_concurrency_limit = task_instance.dag_model.concurrency self.log.info( "DAG %s has %s/%s running and queued tasks", - dag_id, current_dag_concurrency, dag_concurrency_limit + dag_id, + current_dag_concurrency, + dag_concurrency_limit, ) if current_dag_concurrency >= dag_concurrency_limit: self.log.info( "Not executing %s since the number of tasks running or queued " "from DAG %s is >= to the DAG's task concurrency limit of %s", - task_instance, dag_id, dag_concurrency_limit + task_instance, + dag_id, + dag_concurrency_limit, ) continue @@ -1035,7 +1010,8 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = serialized_dag = self.dagbag.get_dag(dag_id, session=session) if serialized_dag.has_task(task_instance.task_id): task_concurrency_limit = serialized_dag.get_task( - task_instance.task_id).task_concurrency + task_instance.task_id + ).task_concurrency if task_concurrency_limit is not None: current_task_concurrency = task_concurrency_map[ @@ -1043,14 +1019,22 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = ] if current_task_concurrency >= task_concurrency_limit: - self.log.info("Not executing %s since the task concurrency for" - " this task has been reached.", task_instance) + self.log.info( + "Not executing %s since the task concurrency for" + " this task has been reached.", + task_instance, + ) continue if task_instance.pool_slots > open_slots: - self.log.info("Not executing %s since it requires %s slots " - "but there are %s open slots in the pool %s.", - task_instance, task_instance.pool_slots, open_slots, pool) + self.log.info( + "Not executing %s since it requires %s slots " + "but there are %s open slots in the pool %s.", + task_instance, + task_instance.pool_slots, + open_slots, + pool, + ) num_starving_tasks += 1 num_starving_tasks_total += 1 # Though we can execute tasks with lower priority if there's enough room @@ -1067,10 +1051,8 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = Stats.gauge('scheduler.tasks.running', num_tasks_in_executor) Stats.gauge('scheduler.tasks.executable', len(executable_tis)) - task_instance_str = "\n\t".join( - [repr(x) for x in executable_tis]) - self.log.info( - "Setting the following tasks to queued state:\n\t%s", task_instance_str) + task_instance_str = "\n\t".join([repr(x) for x in executable_tis]) + self.log.info("Setting the following tasks to queued state:\n\t%s", task_instance_str) # set TIs to queued state filter_for_tis = TI.filter_for_tis(executable_tis) @@ -1078,17 +1060,14 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = # TODO[ha]: should we use func.now()? How does that work with DB timezone on mysql when it's not # UTC? 
{TI.state: State.QUEUED, TI.queued_dttm: timezone.utcnow(), TI.queued_by_job_id: self.id}, - synchronize_session=False + synchronize_session=False, ) for ti in executable_tis: make_transient(ti) return executable_tis - def _enqueue_task_instances_with_queued_state( - self, - task_instances: List[TI] - ) -> None: + def _enqueue_task_instances_with_queued_state(self, task_instances: List[TI]) -> None: """ Takes task_instances, which should have been set to queued, and enqueues them with the executor. @@ -1115,10 +1094,7 @@ def _enqueue_task_instances_with_queued_state( priority = ti.priority_weight queue = ti.queue - self.log.info( - "Sending %s to executor with priority %s and queue %s", - ti.key, priority, queue - ) + self.log.info("Sending %s to executor with priority %s and queue %s", ti.key, priority, queue) self.executor.queue_command( ti, @@ -1164,17 +1140,18 @@ def _change_state_for_tasks_failed_to_execute(self, session: Session = None): if not self.executor.queued_tasks: return - filter_for_ti_state_change = ( - [and_( + filter_for_ti_state_change = [ + and_( TI.dag_id == dag_id, TI.task_id == task_id, TI.execution_date == execution_date, # The TI.try_number will return raw try_number+1 since the # ti is not running. And we need to -1 to match the DB record. TI._try_number == try_number - 1, # pylint: disable=protected-access - TI.state == State.QUEUED) - for dag_id, task_id, execution_date, try_number - in self.executor.queued_tasks.keys()]) + TI.state == State.QUEUED, + ) + for dag_id, task_id, execution_date, try_number in self.executor.queued_tasks.keys() + ] ti_query = session.query(TI).filter(or_(*filter_for_ti_state_change)) tis_to_set_to_scheduled: List[TI] = with_row_locks(ti_query).all() if not tis_to_set_to_scheduled: @@ -1211,7 +1188,11 @@ def _process_executor_events(self, session: Session = None) -> int: self.log.info( "Executor reports execution of %s.%s execution_date=%s " "exited with status %s for try_number %s", - ti_key.dag_id, ti_key.task_id, ti_key.execution_date, state, ti_key.try_number + ti_key.dag_id, + ti_key.task_id, + ti_key.execution_date, + state, + ti_key.try_number, ) if state in (State.FAILED, State.SUCCESS, State.QUEUED): tis_with_right_state.append(ti_key) @@ -1236,8 +1217,10 @@ def _process_executor_events(self, session: Session = None) -> int: if ti.try_number == buffer_key.try_number and ti.state == State.QUEUED: Stats.incr('scheduler.tasks.killed_externally') - msg = "Executor reports task instance %s finished (%s) although the " \ - "task says its %s. (Info: %s) Was the task killed externally?" + msg = ( + "Executor reports task instance %s finished (%s) although the " + "task says its %s. (Info: %s) Was the task killed externally?" + ) self.log.error(msg, ti, state, ti.state, info) request = TaskCallbackRequest( full_filepath=ti.dag_model.fileloc, @@ -1297,8 +1280,7 @@ def _execute(self) -> None: # deleted. 
if self.processor_agent.all_files_processed: self.log.info( - "Deactivating DAGs that haven't been touched since %s", - execute_start_time.isoformat() + "Deactivating DAGs that haven't been touched since %s", execute_start_time.isoformat() ) models.DAG.deactivate_stale_dags(execute_start_time) @@ -1316,14 +1298,11 @@ def _create_dag_file_processor( file_path: str, callback_requests: List[CallbackRequest], dag_ids: Optional[List[str]], - pickle_dags: bool + pickle_dags: bool, ) -> DagFileProcessorProcess: """Creates DagFileProcessorProcess instance.""" return DagFileProcessorProcess( - file_path=file_path, - pickle_dags=pickle_dags, - dag_ids=dag_ids, - callback_requests=callback_requests + file_path=file_path, pickle_dags=pickle_dags, dag_ids=dag_ids, callback_requests=callback_requests ) def _run_scheduler_loop(self) -> None: @@ -1384,14 +1363,16 @@ def _run_scheduler_loop(self) -> None: if loop_count >= self.num_runs > 0: self.log.info( "Exiting scheduler loop as requested number of runs (%d - got to %d) has been reached", - self.num_runs, loop_count, + self.num_runs, + loop_count, ) break if self.processor_agent.done: self.log.info( "Exiting scheduler loop as requested DAG parse count (%d) has been reached after %d " " scheduler loops", - self.num_times_parse_dags, loop_count, + self.num_times_parse_dags, + loop_count, ) break @@ -1455,13 +1436,17 @@ def _do_scheduling(self, session) -> int: # dag_id -- only tasks from those runs will be scheduled. active_runs_by_dag_id = defaultdict(set) - query = session.query( - TI.dag_id, - TI.execution_date, - ).filter( - TI.dag_id.in_(list({dag_run.dag_id for dag_run in dag_runs})), - TI.state.notin_(list(State.finished)) - ).group_by(TI.dag_id, TI.execution_date) + query = ( + session.query( + TI.dag_id, + TI.execution_date, + ) + .filter( + TI.dag_id.in_(list({dag_run.dag_id for dag_run in dag_runs})), + TI.state.notin_(list(State.finished)), + ) + .group_by(TI.dag_id, TI.execution_date) + ) for dag_id, execution_date in query: active_runs_by_dag_id[dag_id].add(execution_date) @@ -1478,18 +1463,13 @@ def _do_scheduling(self, session) -> int: # TODO[HA]: Do we need to do it every time? try: self._change_state_for_tis_without_dagrun( - old_states=[State.UP_FOR_RETRY], - new_state=State.FAILED, - session=session + old_states=[State.UP_FOR_RETRY], new_state=State.FAILED, session=session ) self._change_state_for_tis_without_dagrun( - old_states=[State.QUEUED, - State.SCHEDULED, - State.UP_FOR_RESCHEDULE, - State.SENSING], + old_states=[State.QUEUED, State.SCHEDULED, State.UP_FOR_RESCHEDULE, State.SENSING], new_state=State.NONE, - session=session + session=session, ) guard.commit() @@ -1560,11 +1540,16 @@ def _update_dag_next_dagruns(self, dag_models: Iterable[DagModel], session: Sess """ # Check max_active_runs, to see if we are _now_ at the limit for any of # these dag? 
(we've just created a DagRun for them after all) - active_runs_of_dags = dict(session.query(DagRun.dag_id, func.count('*')).filter( - DagRun.dag_id.in_([o.dag_id for o in dag_models]), - DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable - DagRun.external_trigger.is_(False), - ).group_by(DagRun.dag_id).all()) + active_runs_of_dags = dict( + session.query(DagRun.dag_id, func.count('*')) + .filter( + DagRun.dag_id.in_([o.dag_id for o in dag_models]), + DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable + DagRun.external_trigger.is_(False), + ) + .group_by(DagRun.dag_id) + .all() + ) for dag_model in dag_models: dag = self.dagbag.get_dag(dag_model.dag_id, session=session) @@ -1572,12 +1557,15 @@ def _update_dag_next_dagruns(self, dag_models: Iterable[DagModel], session: Sess if dag.max_active_runs and active_runs_of_dag >= dag.max_active_runs: self.log.info( "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs", - dag.dag_id, active_runs_of_dag, dag.max_active_runs + dag.dag_id, + active_runs_of_dag, + dag.max_active_runs, ) dag_model.next_dagrun_create_after = None else: - dag_model.next_dagrun, dag_model.next_dagrun_create_after = \ - dag.next_dagrun_info(dag_model.next_dagrun) + dag_model.next_dagrun, dag_model.next_dagrun_create_after = dag.next_dagrun_info( + dag_model.next_dagrun + ) def _schedule_dag_run( self, @@ -1598,14 +1586,13 @@ def _schedule_dag_run( dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) if not dag: - self.log.error( - "Couldn't find dag %s in DagBag/DB!", dag_run.dag_id - ) + self.log.error("Couldn't find dag %s in DagBag/DB!", dag_run.dag_id) return 0 if ( - dag_run.start_date and dag.dagrun_timeout and - dag_run.start_date < timezone.utcnow() - dag.dagrun_timeout + dag_run.start_date + and dag.dagrun_timeout + and dag_run.start_date < timezone.utcnow() - dag.dagrun_timeout ): dag_run.state = State.FAILED dag_run.end_date = timezone.utcnow() @@ -1620,7 +1607,7 @@ def _schedule_dag_run( dag_id=dag.dag_id, execution_date=dag_run.execution_date, is_failure_callback=True, - msg='timed_out' + msg='timed_out', ) # Send SLA & DAG Success/Failure Callbacks to be executed @@ -1629,15 +1616,14 @@ def _schedule_dag_run( return 0 if dag_run.execution_date > timezone.utcnow() and not dag.allow_future_exec_dates: - self.log.error( - "Execution date is in future: %s", - dag_run.execution_date - ) + self.log.error("Execution date is in future: %s", dag_run.execution_date) return 0 if dag.max_active_runs: - if len(currently_active_runs) >= dag.max_active_runs and \ - dag_run.execution_date not in currently_active_runs: + if ( + len(currently_active_runs) >= dag.max_active_runs + and dag_run.execution_date not in currently_active_runs + ): self.log.info( "DAG %s already has %d active runs, not queuing any tasks for run %s", dag.dag_id, @@ -1675,9 +1661,7 @@ def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None): dag_run.verify_integrity(session=session) def _send_dag_callbacks_to_processor( - self, - dag_run: DagRun, - callback: Optional[DagCallbackRequest] = None + self, dag_run: DagRun, callback: Optional[DagCallbackRequest] = None ): if not self.processor_agent: raise ValueError("Processor agent is not started.") @@ -1700,8 +1684,7 @@ def _send_sla_callbacks_to_processor(self, dag: DAG): raise ValueError("Processor agent is not started.") self.processor_agent.send_sla_callback_request_to_execute( - full_filepath=dag.fileloc, - dag_id=dag.dag_id + 
full_filepath=dag.fileloc, dag_id=dag.dag_id ) @provide_session @@ -1727,10 +1710,14 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None): """ timeout = conf.getint('scheduler', 'scheduler_health_check_threshold') - num_failed = session.query(SchedulerJob).filter( - SchedulerJob.state == State.RUNNING, - SchedulerJob.latest_heartbeat < (timezone.utcnow() - timedelta(seconds=timeout)) - ).update({"state": State.FAILED}) + num_failed = ( + session.query(SchedulerJob) + .filter( + SchedulerJob.state == State.RUNNING, + SchedulerJob.latest_heartbeat < (timezone.utcnow() - timedelta(seconds=timeout)), + ) + .update({"state": State.FAILED}) + ) if num_failed: self.log.info("Marked %d SchedulerJob instances as failed", num_failed) @@ -1738,7 +1725,8 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None): resettable_states = [State.SCHEDULED, State.QUEUED, State.RUNNING] query = ( - session.query(TI).filter(TI.state.in_(resettable_states)) + session.query(TI) + .filter(TI.state.in_(resettable_states)) # outerjoin is because we didn't use to have queued_by_job # set, so we need to pick up anything pre upgrade. This (and the # "or queued_by_job_id IS NONE") can go as soon as scheduler HA is @@ -1746,9 +1734,11 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None): .outerjoin(TI.queued_by_job) .filter(or_(TI.queued_by_job_id.is_(None), SchedulerJob.state != State.RUNNING)) .join(TI.dag_run) - .filter(DagRun.run_type != DagRunType.BACKFILL_JOB, - # pylint: disable=comparison-with-callable - DagRun.state == State.RUNNING) + .filter( + DagRun.run_type != DagRunType.BACKFILL_JOB, + # pylint: disable=comparison-with-callable + DagRun.state == State.RUNNING, + ) .options(load_only(TI.dag_id, TI.task_id, TI.execution_date)) ) @@ -1770,8 +1760,9 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None): if to_reset: task_instance_str = '\n\t'.join(reset_tis_message) - self.log.info("Reset the following %s orphaned TaskInstances:\n\t%s", - len(to_reset), task_instance_str) + self.log.info( + "Reset the following %s orphaned TaskInstances:\n\t%s", len(to_reset), task_instance_str + ) # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller # decide when to commit diff --git a/airflow/kubernetes/kube_client.py b/airflow/kubernetes/kube_client.py index e51274feec46f..7e8c5e83d7f3c 100644 --- a/airflow/kubernetes/kube_client.py +++ b/airflow/kubernetes/kube_client.py @@ -26,13 +26,15 @@ from kubernetes.client.rest import ApiException # pylint: disable=unused-import from airflow.kubernetes.refresh_config import ( # pylint: disable=ungrouped-imports - RefreshConfiguration, load_kube_config, + RefreshConfiguration, + load_kube_config, ) + has_kubernetes = True - def _get_kube_config(in_cluster: bool, - cluster_context: Optional[str], - config_file: Optional[str]) -> Optional[Configuration]: + def _get_kube_config( + in_cluster: bool, cluster_context: Optional[str], config_file: Optional[str] + ) -> Optional[Configuration]: if in_cluster: # load_incluster_config set default configuration with config populated by k8s config.load_incluster_config() @@ -41,8 +43,7 @@ def _get_kube_config(in_cluster: bool, # this block can be replaced with just config.load_kube_config once # refresh_config module is replaced with upstream fix cfg = RefreshConfiguration() - load_kube_config( - client_configuration=cfg, config_file=config_file, context=cluster_context) + load_kube_config(client_configuration=cfg, config_file=config_file, 
context=cluster_context) return cfg def _get_client_with_patched_configuration(cfg: Optional[Configuration]) -> client.CoreV1Api: @@ -57,6 +58,7 @@ def _get_client_with_patched_configuration(cfg: Optional[Configuration]) -> clie else: return client.CoreV1Api() + except ImportError as e: # We need an exception class to be able to use it in ``except`` elsewhere # in the code base @@ -92,9 +94,11 @@ def _enable_tcp_keepalive() -> None: HTTPConnection.default_socket_options = HTTPConnection.default_socket_options + socket_options -def get_kube_client(in_cluster: bool = conf.getboolean('kubernetes', 'in_cluster'), - cluster_context: Optional[str] = None, - config_file: Optional[str] = None) -> client.CoreV1Api: +def get_kube_client( + in_cluster: bool = conf.getboolean('kubernetes', 'in_cluster'), + cluster_context: Optional[str] = None, + config_file: Optional[str] = None, +) -> client.CoreV1Api: """ Retrieves Kubernetes client diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py index bf54a6ac81f9d..c981780bddca2 100644 --- a/airflow/kubernetes/pod_generator.py +++ b/airflow/kubernetes/pod_generator.py @@ -50,14 +50,8 @@ class PodDefaults: XCOM_MOUNT_PATH = '/airflow/xcom' SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar' XCOM_CMD = 'trap "exit 0" INT; while true; do sleep 30; done;' - VOLUME_MOUNT = k8s.V1VolumeMount( - name='xcom', - mount_path=XCOM_MOUNT_PATH - ) - VOLUME = k8s.V1Volume( - name='xcom', - empty_dir=k8s.V1EmptyDirVolumeSource() - ) + VOLUME_MOUNT = k8s.V1VolumeMount(name='xcom', mount_path=XCOM_MOUNT_PATH) + VOLUME = k8s.V1Volume(name='xcom', empty_dir=k8s.V1EmptyDirVolumeSource()) SIDECAR_CONTAINER = k8s.V1Container( name=SIDECAR_CONTAINER_NAME, command=['sh', '-c', XCOM_CMD], @@ -85,7 +79,7 @@ def make_safe_label_value(string): if len(safe_label) > MAX_LABEL_LEN or string != safe_label: safe_hash = hashlib.md5(string.encode()).hexdigest()[:9] - safe_label = safe_label[:MAX_LABEL_LEN - len(safe_hash) - 1] + "-" + safe_hash + safe_label = safe_label[: MAX_LABEL_LEN - len(safe_hash) - 1] + "-" + safe_hash return safe_label @@ -134,14 +128,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals self, pod: Optional[k8s.V1Pod] = None, pod_template_file: Optional[str] = None, - extract_xcom: bool = True + extract_xcom: bool = True, ): if not pod_template_file and not pod: - raise AirflowConfigException("Podgenerator requires either a " - "`pod` or a `pod_template_file` argument") + raise AirflowConfigException( + "Podgenerator requires either a " "`pod` or a `pod_template_file` argument" + ) if pod_template_file and pod: - raise AirflowConfigException("Cannot pass both `pod` " - "and `pod_template_file` arguments") + raise AirflowConfigException("Cannot pass both `pod` " "and `pod_template_file` arguments") if pod_template_file: self.ud_pod = self.deserialize_model_file(pod_template_file) @@ -184,25 +178,29 @@ def from_obj(obj) -> Optional[Union[dict, k8s.V1Pod]]: k8s_object = obj.get("pod_override", None) if k8s_legacy_object and k8s_object: - raise AirflowConfigException("Can not have both a legacy and new" - "executor_config object. Please delete the KubernetesExecutor" - "dict and only use the pod_override kubernetes.client.models.V1Pod" - "object.") + raise AirflowConfigException( + "Can not have both a legacy and new" + "executor_config object. Please delete the KubernetesExecutor" + "dict and only use the pod_override kubernetes.client.models.V1Pod" + "object." 
+ ) if not k8s_object and not k8s_legacy_object: return None if isinstance(k8s_object, k8s.V1Pod): return k8s_object elif isinstance(k8s_legacy_object, dict): - warnings.warn('Using a dictionary for the executor_config is deprecated and will soon be removed.' - 'please use a `kubernetes.client.models.V1Pod` class with a "pod_override" key' - ' instead. ', - category=DeprecationWarning) + warnings.warn( + 'Using a dictionary for the executor_config is deprecated and will soon be removed.' + 'please use a `kubernetes.client.models.V1Pod` class with a "pod_override" key' + ' instead. ', + category=DeprecationWarning, + ) return PodGenerator.from_legacy_obj(obj) else: raise TypeError( - 'Cannot convert a non-kubernetes.client.models.V1Pod' - 'object into a KubernetesExecutorConfig') + 'Cannot convert a non-kubernetes.client.models.V1Pod' 'object into a KubernetesExecutorConfig' + ) @staticmethod def from_legacy_obj(obj) -> Optional[k8s.V1Pod]: @@ -223,12 +221,12 @@ def from_legacy_obj(obj) -> Optional[k8s.V1Pod]: requests = { 'cpu': namespaced.pop('request_cpu', None), 'memory': namespaced.pop('request_memory', None), - 'ephemeral-storage': namespaced.get('ephemeral-storage') # We pop this one in limits + 'ephemeral-storage': namespaced.get('ephemeral-storage'), # We pop this one in limits } limits = { 'cpu': namespaced.pop('limit_cpu', None), 'memory': namespaced.pop('limit_memory', None), - 'ephemeral-storage': namespaced.pop('ephemeral-storage', None) + 'ephemeral-storage': namespaced.pop('ephemeral-storage', None), } all_resources = list(requests.values()) + list(limits.values()) if all(r is None for r in all_resources): @@ -237,10 +235,7 @@ def from_legacy_obj(obj) -> Optional[k8s.V1Pod]: # remove None's so they don't become 0's requests = {k: v for k, v in requests.items() if v is not None} limits = {k: v for k, v in limits.items() if v is not None} - resources = k8s.V1ResourceRequirements( - requests=requests, - limits=limits - ) + resources = k8s.V1ResourceRequirements(requests=requests, limits=limits) namespaced['resources'] = resources return PodGeneratorDeprecated(**namespaced).gen_pod() @@ -292,8 +287,9 @@ def reconcile_metadata(base_meta, client_meta): return None @staticmethod - def reconcile_specs(base_spec: Optional[k8s.V1PodSpec], - client_spec: Optional[k8s.V1PodSpec]) -> Optional[k8s.V1PodSpec]: + def reconcile_specs( + base_spec: Optional[k8s.V1PodSpec], client_spec: Optional[k8s.V1PodSpec] + ) -> Optional[k8s.V1PodSpec]: """ :param base_spec: has the base attributes which are overwritten if they exist in the client_spec and remain if they do not exist in the client_spec @@ -316,8 +312,9 @@ def reconcile_specs(base_spec: Optional[k8s.V1PodSpec], return None @staticmethod - def reconcile_containers(base_containers: List[k8s.V1Container], - client_containers: List[k8s.V1Container]) -> List[k8s.V1Container]: + def reconcile_containers( + base_containers: List[k8s.V1Container], client_containers: List[k8s.V1Container] + ) -> List[k8s.V1Container]: """ :param base_containers: has the base attributes which are overwritten if they exist in the client_containers and remain if they do not exist in the client_containers @@ -358,7 +355,7 @@ def construct_pod( # pylint: disable=too-many-arguments pod_override_object: Optional[k8s.V1Pod], base_worker_pod: k8s.V1Pod, namespace: str, - scheduler_job_id: str + scheduler_job_id: str, ) -> k8s.V1Pod: """ Construct a pod by gathering and consolidating the configuration from 3 places: @@ -391,7 +388,7 @@ def construct_pod( # pylint: 
disable=too-many-arguments 'try_number': str(try_number), 'airflow_version': airflow_version.replace('+', '-'), 'kubernetes_executor': 'True', - } + }, ), spec=k8s.V1PodSpec( containers=[ @@ -401,7 +398,7 @@ def construct_pod( # pylint: disable=too-many-arguments image=image, ) ] - ) + ), ) # Reconcile the pods starting with the first chronologically, @@ -449,8 +446,7 @@ def deserialize_model_dict(pod_dict: dict) -> k8s.V1Pod: @return: """ api_client = ApiClient() - return api_client._ApiClient__deserialize_model( # pylint: disable=W0212 - pod_dict, k8s.V1Pod) + return api_client._ApiClient__deserialize_model(pod_dict, k8s.V1Pod) # pylint: disable=W0212 @staticmethod def make_unique_pod_id(pod_id): @@ -466,7 +462,7 @@ def make_unique_pod_id(pod_id): return None safe_uuid = uuid.uuid4().hex - safe_pod_id = pod_id[:MAX_POD_ID_LEN - len(safe_uuid) - 1] + "-" + safe_uuid + safe_pod_id = pod_id[: MAX_POD_ID_LEN - len(safe_uuid) - 1] + "-" + safe_uuid return safe_pod_id @@ -513,8 +509,9 @@ def extend_object_field(base_obj, client_obj, field_name): base_obj_field = getattr(base_obj, field_name, None) client_obj_field = getattr(client_obj, field_name, None) - if (not isinstance(base_obj_field, list) and base_obj_field is not None) or \ - (not isinstance(client_obj_field, list) and client_obj_field is not None): + if (not isinstance(base_obj_field, list) and base_obj_field is not None) or ( + not isinstance(client_obj_field, list) and client_obj_field is not None + ): raise ValueError("The chosen field must be a list.") if not base_obj_field: diff --git a/airflow/kubernetes/pod_generator_deprecated.py b/airflow/kubernetes/pod_generator_deprecated.py index cdf9a9182b99c..79bdcb4406e98 100644 --- a/airflow/kubernetes/pod_generator_deprecated.py +++ b/airflow/kubernetes/pod_generator_deprecated.py @@ -39,14 +39,8 @@ class PodDefaults: XCOM_MOUNT_PATH = '/airflow/xcom' SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar' XCOM_CMD = 'trap "exit 0" INT; while true; do sleep 30; done;' - VOLUME_MOUNT = k8s.V1VolumeMount( - name='xcom', - mount_path=XCOM_MOUNT_PATH - ) - VOLUME = k8s.V1Volume( - name='xcom', - empty_dir=k8s.V1EmptyDirVolumeSource() - ) + VOLUME_MOUNT = k8s.V1VolumeMount(name='xcom', mount_path=XCOM_MOUNT_PATH) + VOLUME = k8s.V1Volume(name='xcom', empty_dir=k8s.V1EmptyDirVolumeSource()) SIDECAR_CONTAINER = k8s.V1Container( name=SIDECAR_CONTAINER_NAME, command=['sh', '-c', XCOM_CMD], @@ -74,7 +68,7 @@ def make_safe_label_value(string): if len(safe_label) > MAX_LABEL_LEN or string != safe_label: safe_hash = hashlib.md5(string.encode()).hexdigest()[:9] - safe_label = safe_label[:MAX_LABEL_LEN - len(safe_hash) - 1] + "-" + safe_hash + safe_label = safe_label[: MAX_LABEL_LEN - len(safe_hash) - 1] + "-" + safe_hash return safe_label @@ -199,21 +193,16 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals if envs: if isinstance(envs, dict): for key, val in envs.items(): - self.container.env.append(k8s.V1EnvVar( - name=key, - value=val - )) + self.container.env.append(k8s.V1EnvVar(name=key, value=val)) elif isinstance(envs, list): self.container.env.extend(envs) configmaps = configmaps or [] self.container.env_from = [] for configmap in configmaps: - self.container.env_from.append(k8s.V1EnvFromSource( - config_map_ref=k8s.V1ConfigMapEnvSource( - name=configmap - ) - )) + self.container.env_from.append( + k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name=configmap)) + ) self.container.command = cmds or [] self.container.args = args or [] @@ -241,9 +230,7 @@ def __init__( # 
pylint: disable=too-many-arguments,too-many-locals if image_pull_secrets: for image_pull_secret in image_pull_secrets.split(','): - self.spec.image_pull_secrets.append(k8s.V1LocalObjectReference( - name=image_pull_secret - )) + self.spec.image_pull_secrets.append(k8s.V1LocalObjectReference(name=image_pull_secret)) # Attach sidecar self.extract_xcom = extract_xcom @@ -289,7 +276,8 @@ def from_obj(obj) -> Optional[k8s.V1Pod]: if not isinstance(obj, dict): raise TypeError( 'Cannot convert a non-dictionary or non-PodGenerator ' - 'object into a KubernetesExecutorConfig') + 'object into a KubernetesExecutorConfig' + ) # We do not want to extract constant here from ExecutorLoader because it is just # A name in dictionary rather than executor selection mechanism and it causes cyclic import @@ -304,21 +292,18 @@ def from_obj(obj) -> Optional[k8s.V1Pod]: requests = { 'cpu': namespaced.get('request_cpu'), 'memory': namespaced.get('request_memory'), - 'ephemeral-storage': namespaced.get('ephemeral-storage') + 'ephemeral-storage': namespaced.get('ephemeral-storage'), } limits = { 'cpu': namespaced.get('limit_cpu'), 'memory': namespaced.get('limit_memory'), - 'ephemeral-storage': namespaced.get('ephemeral-storage') + 'ephemeral-storage': namespaced.get('ephemeral-storage'), } all_resources = list(requests.values()) + list(limits.values()) if all(r is None for r in all_resources): resources = None else: - resources = k8s.V1ResourceRequirements( - requests=requests, - limits=limits - ) + resources = k8s.V1ResourceRequirements(requests=requests, limits=limits) namespaced['resources'] = resources return PodGenerator(**namespaced).gen_pod() @@ -336,6 +321,6 @@ def make_unique_pod_id(dag_id): return None safe_uuid = uuid.uuid4().hex - safe_pod_id = dag_id[:MAX_POD_ID_LEN - len(safe_uuid) - 1] + "-" + safe_uuid + safe_pod_id = dag_id[: MAX_POD_ID_LEN - len(safe_uuid) - 1] + "-" + safe_uuid return safe_pod_id diff --git a/airflow/kubernetes/pod_launcher.py b/airflow/kubernetes/pod_launcher.py index e528e2c4ae3a0..08d0bf4f37de9 100644 --- a/airflow/kubernetes/pod_launcher.py +++ b/airflow/kubernetes/pod_launcher.py @@ -49,11 +49,13 @@ class PodStatus: class PodLauncher(LoggingMixin): """Launches PODS""" - def __init__(self, - kube_client: client.CoreV1Api = None, - in_cluster: bool = True, - cluster_context: Optional[str] = None, - extract_xcom: bool = False): + def __init__( + self, + kube_client: client.CoreV1Api = None, + in_cluster: bool = True, + cluster_context: Optional[str] = None, + extract_xcom: bool = False, + ): """ Creates the launcher. 
@@ -63,8 +65,7 @@ def __init__(self, :param extract_xcom: whether we should extract xcom """ super().__init__() - self._client = kube_client or get_kube_client(in_cluster=in_cluster, - cluster_context=cluster_context) + self._client = kube_client or get_kube_client(in_cluster=in_cluster, cluster_context=cluster_context) self._watch = watch.Watch() self.extract_xcom = extract_xcom @@ -77,12 +78,12 @@ def run_pod_async(self, pod: V1Pod, **kwargs): self.log.debug('Pod Creation Request: \n%s', json_pod) try: - resp = self._client.create_namespaced_pod(body=sanitized_pod, - namespace=pod.metadata.namespace, **kwargs) + resp = self._client.create_namespaced_pod( + body=sanitized_pod, namespace=pod.metadata.namespace, **kwargs + ) self.log.debug('Pod Creation Response: %s', resp) except Exception as e: - self.log.exception('Exception when attempting ' - 'to create Namespaced Pod: %s', json_pod) + self.log.exception('Exception when attempting ' 'to create Namespaced Pod: %s', json_pod) raise e return resp @@ -90,16 +91,14 @@ def delete_pod(self, pod: V1Pod): """Deletes POD""" try: self._client.delete_namespaced_pod( - pod.metadata.name, pod.metadata.namespace, body=client.V1DeleteOptions()) + pod.metadata.name, pod.metadata.namespace, body=client.V1DeleteOptions() + ) except ApiException as e: # If the pod is already deleted if e.status != 404: raise - def start_pod( - self, - pod: V1Pod, - startup_timeout: int = 120): + def start_pod(self, pod: V1Pod, startup_timeout: int = 120): """ Launches the pod synchronously and waits for completion. @@ -170,13 +169,11 @@ def parse_log_line(self, line: str) -> Tuple[str, str]: if split_at == -1: raise Exception(f'Log not in "{{timestamp}} {{log}}" format. Got: {line}') timestamp = line[:split_at] - message = line[split_at + 1:].rstrip() + message = line[split_at + 1 :].rstrip() return timestamp, message def _task_status(self, event): - self.log.info( - 'Event: %s had an event of type %s', - event.metadata.name, event.status.phase) + self.log.info('Event: %s had an event of type %s', event.metadata.name, event.status.phase) status = self.process_status(event.metadata.name, event.status.phase) return status @@ -193,22 +190,19 @@ def pod_is_running(self, pod: V1Pod): def base_container_is_running(self, pod: V1Pod): """Tests if base container is running""" event = self.read_pod(pod) - status = next(iter(filter(lambda s: s.name == 'base', - event.status.container_statuses)), None) + status = next(iter(filter(lambda s: s.name == 'base', event.status.container_statuses)), None) if not status: return False return status.state.running is not None - @tenacity.retry( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_exponential(), - reraise=True - ) - def read_pod_logs(self, - pod: V1Pod, - tail_lines: Optional[int] = None, - timestamps: bool = False, - since_seconds: Optional[int] = None): + @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True) + def read_pod_logs( + self, + pod: V1Pod, + tail_lines: Optional[int] = None, + timestamps: bool = False, + since_seconds: Optional[int] = None, + ): """Reads log from the POD""" additional_kwargs = {} if since_seconds: @@ -225,54 +219,44 @@ def read_pod_logs(self, follow=True, timestamps=timestamps, _preload_content=False, - **additional_kwargs + **additional_kwargs, ) except BaseHTTPError as e: - raise AirflowException( - f'There was an error reading the kubernetes API: {e}' - ) + raise AirflowException(f'There was an error reading the kubernetes API: {e}') - 
@tenacity.retry( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_exponential(), - reraise=True - ) + @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True) def read_pod_events(self, pod): """Reads events from the POD""" try: return self._client.list_namespaced_event( - namespace=pod.metadata.namespace, - field_selector=f"involvedObject.name={pod.metadata.name}" + namespace=pod.metadata.namespace, field_selector=f"involvedObject.name={pod.metadata.name}" ) except BaseHTTPError as e: - raise AirflowException( - f'There was an error reading the kubernetes API: {e}' - ) + raise AirflowException(f'There was an error reading the kubernetes API: {e}') - @tenacity.retry( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_exponential(), - reraise=True - ) + @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True) def read_pod(self, pod: V1Pod): """Read POD information""" try: return self._client.read_namespaced_pod(pod.metadata.name, pod.metadata.namespace) except BaseHTTPError as e: - raise AirflowException( - f'There was an error reading the kubernetes API: {e}' - ) + raise AirflowException(f'There was an error reading the kubernetes API: {e}') def _extract_xcom(self, pod: V1Pod): - resp = kubernetes_stream(self._client.connect_get_namespaced_pod_exec, - pod.metadata.name, pod.metadata.namespace, - container=PodDefaults.SIDECAR_CONTAINER_NAME, - command=['/bin/sh'], stdin=True, stdout=True, - stderr=True, tty=False, - _preload_content=False) + resp = kubernetes_stream( + self._client.connect_get_namespaced_pod_exec, + pod.metadata.name, + pod.metadata.namespace, + container=PodDefaults.SIDECAR_CONTAINER_NAME, + command=['/bin/sh'], + stdin=True, + stdout=True, + stderr=True, + tty=False, + _preload_content=False, + ) try: - result = self._exec_pod_command( - resp, f'cat {PodDefaults.XCOM_MOUNT_PATH}/return.json') + result = self._exec_pod_command(resp, f'cat {PodDefaults.XCOM_MOUNT_PATH}/return.json') self._exec_pod_command(resp, 'kill -s SIGINT 1') finally: resp.close() diff --git a/airflow/kubernetes/pod_runtime_info_env.py b/airflow/kubernetes/pod_runtime_info_env.py index 57ac8de4bd33d..fa841cdacd15b 100644 --- a/airflow/kubernetes/pod_runtime_info_env.py +++ b/airflow/kubernetes/pod_runtime_info_env.py @@ -42,11 +42,7 @@ def to_k8s_client_obj(self) -> k8s.V1EnvVar: """:return: kubernetes.client.models.V1EnvVar""" return k8s.V1EnvVar( name=self.name, - value_from=k8s.V1EnvVarSource( - field_ref=k8s.V1ObjectFieldSelector( - field_path=self.field_path - ) - ) + value_from=k8s.V1EnvVarSource(field_ref=k8s.V1ObjectFieldSelector(field_path=self.field_path)), ) def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod: diff --git a/airflow/kubernetes/refresh_config.py b/airflow/kubernetes/refresh_config.py index 53b7decfd1207..9067cb15da71b 100644 --- a/airflow/kubernetes/refresh_config.py +++ b/airflow/kubernetes/refresh_config.py @@ -104,7 +104,8 @@ def _get_kube_config_loader_for_yaml_file(filename, **kwargs) -> Optional[Refres return RefreshKubeConfigLoader( config_dict=yaml.safe_load(f), config_base_path=os.path.abspath(os.path.dirname(filename)), - **kwargs) + **kwargs, + ) def load_kube_config(client_configuration, config_file=None, context=None): @@ -117,6 +118,5 @@ def load_kube_config(client_configuration, config_file=None, context=None): if config_file is None: config_file = os.path.expanduser(KUBE_CONFIG_DEFAULT_LOCATION) - loader = _get_kube_config_loader_for_yaml_file( - 
config_file, active_context=context, config_persister=None) + loader = _get_kube_config_loader_for_yaml_file(config_file, active_context=context, config_persister=None) loader.load_and_set(client_configuration) diff --git a/airflow/kubernetes/secret.py b/airflow/kubernetes/secret.py index 197464f562a78..20ed27b1ffb31 100644 --- a/airflow/kubernetes/secret.py +++ b/airflow/kubernetes/secret.py @@ -62,9 +62,7 @@ def __init__(self, deploy_type, deploy_target, secret, key=None, items=None): self.deploy_target = deploy_target.upper() if key is not None and deploy_target is None: - raise AirflowConfigException( - 'If `key` is set, `deploy_target` should not be None' - ) + raise AirflowConfigException('If `key` is set, `deploy_target` should not be None') self.secret = secret self.key = key @@ -74,18 +72,13 @@ def to_env_secret(self) -> k8s.V1EnvVar: return k8s.V1EnvVar( name=self.deploy_target, value_from=k8s.V1EnvVarSource( - secret_key_ref=k8s.V1SecretKeySelector( - name=self.secret, - key=self.key - ) - ) + secret_key_ref=k8s.V1SecretKeySelector(name=self.secret, key=self.key) + ), ) def to_env_from_secret(self) -> k8s.V1EnvFromSource: """Reads from environment to secret""" - return k8s.V1EnvFromSource( - secret_ref=k8s.V1SecretEnvSource(name=self.secret) - ) + return k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name=self.secret)) def to_volume_secret(self) -> Tuple[k8s.V1Volume, k8s.V1VolumeMount]: """Converts to volume secret""" @@ -93,14 +86,7 @@ def to_volume_secret(self) -> Tuple[k8s.V1Volume, k8s.V1VolumeMount]: volume = k8s.V1Volume(name=vol_id, secret=k8s.V1SecretVolumeSource(secret_name=self.secret)) if self.items: volume.secret.items = self.items - return ( - volume, - k8s.V1VolumeMount( - mount_path=self.deploy_target, - name=vol_id, - read_only=True - ) - ) + return (volume, k8s.V1VolumeMount(mount_path=self.deploy_target, name=vol_id, read_only=True)) def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod: """Attaches to pod""" @@ -123,16 +109,11 @@ def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod: def __eq__(self, other): return ( - self.deploy_type == other.deploy_type and - self.deploy_target == other.deploy_target and - self.secret == other.secret and - self.key == other.key + self.deploy_type == other.deploy_type + and self.deploy_target == other.deploy_target + and self.secret == other.secret + and self.key == other.key ) def __repr__(self): - return 'Secret({}, {}, {}, {})'.format( - self.deploy_type, - self.deploy_target, - self.secret, - self.key - ) + return f'Secret({self.deploy_type}, {self.deploy_target}, {self.secret}, {self.key})' diff --git a/airflow/lineage/__init__.py b/airflow/lineage/__init__.py index 7641f39489ede..0f2a961e6dd2d 100644 --- a/airflow/lineage/__init__.py +++ b/airflow/lineage/__init__.py @@ -54,9 +54,14 @@ def _get_instance(meta: Metadata): def _render_object(obj: Any, context) -> Any: """Renders a attr annotated object. 
Will set non serializable attributes to none""" - return structure(json.loads(ENV.from_string( - json.dumps(unstructure(obj), default=lambda o: None) - ).render(**context).encode('utf-8')), type(obj)) + return structure( + json.loads( + ENV.from_string(json.dumps(unstructure(obj), default=lambda o: None)) + .render(**context) + .encode('utf-8') + ), + type(obj), + ) def _to_dataset(obj: Any, source: str) -> Optional[Metadata]: @@ -81,26 +86,21 @@ def apply_lineage(func: T) -> T: @wraps(func) def wrapper(self, context, *args, **kwargs): - self.log.debug("Lineage called with inlets: %s, outlets: %s", - self.inlets, self.outlets) + self.log.debug("Lineage called with inlets: %s, outlets: %s", self.inlets, self.outlets) ret_val = func(self, context, *args, **kwargs) - outlets = [unstructure(_to_dataset(x, f"{self.dag_id}.{self.task_id}")) - for x in self.outlets] - inlets = [unstructure(_to_dataset(x, None)) - for x in self.inlets] + outlets = [unstructure(_to_dataset(x, f"{self.dag_id}.{self.task_id}")) for x in self.outlets] + inlets = [unstructure(_to_dataset(x, None)) for x in self.inlets] if self.outlets: - self.xcom_push(context, - key=PIPELINE_OUTLETS, - value=outlets, - execution_date=context['ti'].execution_date) + self.xcom_push( + context, key=PIPELINE_OUTLETS, value=outlets, execution_date=context['ti'].execution_date + ) if self.inlets: - self.xcom_push(context, - key=PIPELINE_INLETS, - value=inlets, - execution_date=context['ti'].execution_date) + self.xcom_push( + context, key=PIPELINE_INLETS, value=inlets, execution_date=context['ti'].execution_date + ) return ret_val @@ -123,27 +123,28 @@ def wrapper(self, context, *args, **kwargs): self.log.debug("Preparing lineage inlets and outlets") if isinstance(self._inlets, (str, Operator)) or attr.has(self._inlets): - self._inlets = [self._inlets, ] + self._inlets = [ + self._inlets, + ] if self._inlets and isinstance(self._inlets, list): # get task_ids that are specified as parameter and make sure they are upstream - task_ids = set( - filter(lambda x: isinstance(x, str) and x.lower() != AUTO, self._inlets) - ).union( - map(lambda op: op.task_id, - filter(lambda op: isinstance(op, Operator), self._inlets)) - ).intersection(self.get_flat_relative_ids(upstream=True)) + task_ids = ( + set(filter(lambda x: isinstance(x, str) and x.lower() != AUTO, self._inlets)) + .union(map(lambda op: op.task_id, filter(lambda op: isinstance(op, Operator), self._inlets))) + .intersection(self.get_flat_relative_ids(upstream=True)) + ) # pick up unique direct upstream task_ids if AUTO is specified if AUTO.upper() in self._inlets or AUTO.lower() in self._inlets: task_ids = task_ids.union(task_ids.symmetric_difference(self.upstream_task_ids)) - _inlets = self.xcom_pull(context, task_ids=task_ids, - dag_id=self.dag_id, key=PIPELINE_OUTLETS) + _inlets = self.xcom_pull(context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS) # re-instantiate the obtained inlets - _inlets = [_get_instance(structure(item, Metadata)) - for sublist in _inlets if sublist for item in sublist] + _inlets = [ + _get_instance(structure(item, Metadata)) for sublist in _inlets if sublist for item in sublist + ] self.inlets.extend(_inlets) self.inlets.extend(self._inlets) @@ -152,16 +153,16 @@ def wrapper(self, context, *args, **kwargs): raise AttributeError("inlets is not a list, operator, string or attr annotated object") if not isinstance(self._outlets, list): - self._outlets = [self._outlets, ] + self._outlets = [ + self._outlets, + ] self.outlets.extend(self._outlets) # 
render inlets and outlets - self.inlets = [_render_object(i, context) - for i in self.inlets if attr.has(i)] + self.inlets = [_render_object(i, context) for i in self.inlets if attr.has(i)] - self.outlets = [_render_object(i, context) - for i in self.outlets if attr.has(i)] + self.outlets = [_render_object(i, context) for i in self.outlets if attr.has(i)] self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets) return func(self, context, *args, **kwargs) diff --git a/airflow/lineage/entities.py b/airflow/lineage/entities.py index 9ded4ff67c935..f2bad75796197 100644 --- a/airflow/lineage/entities.py +++ b/airflow/lineage/entities.py @@ -63,7 +63,7 @@ class Column: # https://github.com/python/mypy/issues/6136 is resolved, use # `attr.converters.default_if_none(default=False)` # pylint: disable=missing-docstring -def default_if_none(arg: Optional[bool]) -> bool: # noqa: D103 +def default_if_none(arg: Optional[bool]) -> bool: # noqa: D103 return arg or False diff --git a/airflow/logging_config.py b/airflow/logging_config.py index e2b6adb2537da..e827273ce5a1b 100644 --- a/airflow/logging_config.py +++ b/airflow/logging_config.py @@ -43,19 +43,12 @@ def configure_logging(): if not isinstance(logging_config, dict): raise ValueError("Logging Config should be of dict type") - log.info( - 'Successfully imported user-defined logging config from %s', - logging_class_path - ) + log.info('Successfully imported user-defined logging config from %s', logging_class_path) except Exception as err: # Import default logging configurations. - raise ImportError( - 'Unable to load custom logging from {} due to {}' - .format(logging_class_path, err) - ) + raise ImportError(f'Unable to load custom logging from {logging_class_path} due to {err}') else: - logging_class_path = 'airflow.config_templates.' \ - 'airflow_local_settings.DEFAULT_LOGGING_CONFIG' + logging_class_path = 'airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG' logging_config = import_string(logging_class_path) log.debug('Unable to load custom logging, using default config instead') @@ -73,7 +66,7 @@ def configure_logging(): return logging_class_path -def validate_logging_config(logging_config): # pylint: disable=unused-argument +def validate_logging_config(logging_config): # pylint: disable=unused-argument """Validate the provided Logging Config""" # Now lets validate the other logging-related settings task_log_reader = conf.get('logging', 'task_log_reader') diff --git a/airflow/macros/hive.py b/airflow/macros/hive.py index c6e1e200fc919..39c66cdec1229 100644 --- a/airflow/macros/hive.py +++ b/airflow/macros/hive.py @@ -20,8 +20,8 @@ def max_partition( - table, schema="default", field=None, filter_map=None, - metastore_conn_id='metastore_default'): + table, schema="default", field=None, filter_map=None, metastore_conn_id='metastore_default' +): """ Gets the max partition for a table. @@ -47,11 +47,11 @@ def max_partition( '2015-01-01' """ from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook + if '.' 
in table: schema, table = table.split('.') hive_hook = HiveMetastoreHook(metastore_conn_id=metastore_conn_id) - return hive_hook.max_partition( - schema=schema, table_name=table, field=field, filter_map=filter_map) + return hive_hook.max_partition(schema=schema, table_name=table, field=field, filter_map=filter_map) def _closest_date(target_dt, date_list, before_target=None): @@ -79,9 +79,7 @@ def _closest_date(target_dt, date_list, before_target=None): return min(date_list, key=time_after).date() -def closest_ds_partition( - table, ds, before=True, schema="default", - metastore_conn_id='metastore_default'): +def closest_ds_partition(table, ds, before=True, schema="default", metastore_conn_id='metastore_default'): """ This function finds the date in a list closest to the target date. An optional parameter can be given to get the closest before or after. @@ -104,6 +102,7 @@ def closest_ds_partition( '2015-01-01' """ from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook + if '.' in table: schema, table = table.split('.') hive_hook = HiveMetastoreHook(metastore_conn_id=metastore_conn_id) @@ -114,8 +113,7 @@ def closest_ds_partition( if ds in part_vals: return ds else: - parts = [datetime.datetime.strptime(pv, '%Y-%m-%d') - for pv in part_vals] + parts = [datetime.datetime.strptime(pv, '%Y-%m-%d') for pv in part_vals] target_dt = datetime.datetime.strptime(ds, '%Y-%m-%d') closest_ds = _closest_date(target_dt, parts, before_target=before) return closest_ds.isoformat() diff --git a/airflow/migrations/env.py b/airflow/migrations/env.py index 2230724eb1a45..459cb7c176662 100644 --- a/airflow/migrations/env.py +++ b/airflow/migrations/env.py @@ -72,7 +72,8 @@ def run_migrations_offline(): target_metadata=target_metadata, literal_binds=True, compare_type=COMPARE_TYPE, - render_as_batch=True) + render_as_batch=True, + ) with context.begin_transaction(): context.run_migrations() @@ -94,7 +95,7 @@ def run_migrations_online(): target_metadata=target_metadata, compare_type=COMPARE_TYPE, include_object=include_object, - render_as_batch=True + render_as_batch=True, ) with context.begin_transaction(): diff --git a/airflow/migrations/versions/03bc53e68815_add_sm_dag_index.py b/airflow/migrations/versions/03bc53e68815_add_sm_dag_index.py index f3e7330ac11e3..d66bf5c96e3d5 100644 --- a/airflow/migrations/versions/03bc53e68815_add_sm_dag_index.py +++ b/airflow/migrations/versions/03bc53e68815_add_sm_dag_index.py @@ -32,9 +32,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.create_index('sm_dag', 'sla_miss', ['dag_id'], unique=False) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_index('sm_dag', table_name='sla_miss') diff --git a/airflow/migrations/versions/05f30312d566_merge_heads.py b/airflow/migrations/versions/05f30312d566_merge_heads.py index 68e7dbd3a0943..ffe2330196270 100644 --- a/airflow/migrations/versions/05f30312d566_merge_heads.py +++ b/airflow/migrations/versions/05f30312d566_merge_heads.py @@ -30,9 +30,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 pass -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 pass diff --git a/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py b/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py index 378c62dfde65e..4c572f44ec5a2 100644 --- a/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py +++ b/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py @@ -41,19 +41,19 @@ # For 
Microsoft SQL Server, TIMESTAMP is a row-id type, # having nothing to do with date-time. DateTime() will # be sufficient. -def mssql_timestamp(): # noqa: D103 +def mssql_timestamp(): # noqa: D103 return sa.DateTime() -def mysql_timestamp(): # noqa: D103 +def mysql_timestamp(): # noqa: D103 return mysql.TIMESTAMP(fsp=6) -def sa_timestamp(): # noqa: D103 +def sa_timestamp(): # noqa: D103 return sa.TIMESTAMP(timezone=True) -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 # See 0e2a74e0fc9f_add_time_zone_awareness conn = op.get_bind() if conn.dialect.name == 'mysql': @@ -79,16 +79,12 @@ def upgrade(): # noqa: D103 sa.ForeignKeyConstraint( ['task_id', 'dag_id', 'execution_date'], ['task_instance.task_id', 'task_instance.dag_id', 'task_instance.execution_date'], - name='task_reschedule_dag_task_date_fkey') - ) - op.create_index( - INDEX_NAME, - TABLE_NAME, - ['dag_id', 'task_id', 'execution_date'], - unique=False + name='task_reschedule_dag_task_date_fkey', + ), ) + op.create_index(INDEX_NAME, TABLE_NAME, ['dag_id', 'task_id', 'execution_date'], unique=False) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_index(INDEX_NAME, table_name=TABLE_NAME) op.drop_table(TABLE_NAME) diff --git a/airflow/migrations/versions/0e2a74e0fc9f_add_time_zone_awareness.py b/airflow/migrations/versions/0e2a74e0fc9f_add_time_zone_awareness.py index daa6a36f1ea6d..f18809c994b10 100644 --- a/airflow/migrations/versions/0e2a74e0fc9f_add_time_zone_awareness.py +++ b/airflow/migrations/versions/0e2a74e0fc9f_add_time_zone_awareness.py @@ -34,16 +34,14 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 conn = op.get_bind() if conn.dialect.name == "mysql": conn.execute("SET time_zone = '+00:00'") cur = conn.execute("SELECT @@explicit_defaults_for_timestamp") res = cur.fetchall() if res[0][0] == 0: - raise Exception( - "Global variable explicit_defaults_for_timestamp needs to be on (1) for mysql" - ) + raise Exception("Global variable explicit_defaults_for_timestamp needs to be on (1) for mysql") op.alter_column( table_name="chart", @@ -56,12 +54,8 @@ def upgrade(): # noqa: D103 column_name="last_scheduler_run", type_=mysql.TIMESTAMP(fsp=6), ) - op.alter_column( - table_name="dag", column_name="last_pickled", type_=mysql.TIMESTAMP(fsp=6) - ) - op.alter_column( - table_name="dag", column_name="last_expired", type_=mysql.TIMESTAMP(fsp=6) - ) + op.alter_column(table_name="dag", column_name="last_pickled", type_=mysql.TIMESTAMP(fsp=6)) + op.alter_column(table_name="dag", column_name="last_expired", type_=mysql.TIMESTAMP(fsp=6)) op.alter_column( table_name="dag_pickle", @@ -74,12 +68,8 @@ def upgrade(): # noqa: D103 column_name="execution_date", type_=mysql.TIMESTAMP(fsp=6), ) - op.alter_column( - table_name="dag_run", column_name="start_date", type_=mysql.TIMESTAMP(fsp=6) - ) - op.alter_column( - table_name="dag_run", column_name="end_date", type_=mysql.TIMESTAMP(fsp=6) - ) + op.alter_column(table_name="dag_run", column_name="start_date", type_=mysql.TIMESTAMP(fsp=6)) + op.alter_column(table_name="dag_run", column_name="end_date", type_=mysql.TIMESTAMP(fsp=6)) op.alter_column( table_name="import_error", @@ -87,24 +77,16 @@ def upgrade(): # noqa: D103 type_=mysql.TIMESTAMP(fsp=6), ) - op.alter_column( - table_name="job", column_name="start_date", type_=mysql.TIMESTAMP(fsp=6) - ) - op.alter_column( - table_name="job", column_name="end_date", type_=mysql.TIMESTAMP(fsp=6) - ) + op.alter_column(table_name="job", column_name="start_date", type_=mysql.TIMESTAMP(fsp=6)) + 
op.alter_column(table_name="job", column_name="end_date", type_=mysql.TIMESTAMP(fsp=6)) op.alter_column( table_name="job", column_name="latest_heartbeat", type_=mysql.TIMESTAMP(fsp=6), ) - op.alter_column( - table_name="log", column_name="dttm", type_=mysql.TIMESTAMP(fsp=6) - ) - op.alter_column( - table_name="log", column_name="execution_date", type_=mysql.TIMESTAMP(fsp=6) - ) + op.alter_column(table_name="log", column_name="dttm", type_=mysql.TIMESTAMP(fsp=6)) + op.alter_column(table_name="log", column_name="execution_date", type_=mysql.TIMESTAMP(fsp=6)) op.alter_column( table_name="sla_miss", @@ -112,9 +94,7 @@ def upgrade(): # noqa: D103 type_=mysql.TIMESTAMP(fsp=6), nullable=False, ) - op.alter_column( - table_name="sla_miss", column_name="timestamp", type_=mysql.TIMESTAMP(fsp=6) - ) + op.alter_column(table_name="sla_miss", column_name="timestamp", type_=mysql.TIMESTAMP(fsp=6)) op.alter_column( table_name="task_fail", @@ -126,9 +106,7 @@ def upgrade(): # noqa: D103 column_name="start_date", type_=mysql.TIMESTAMP(fsp=6), ) - op.alter_column( - table_name="task_fail", column_name="end_date", type_=mysql.TIMESTAMP(fsp=6) - ) + op.alter_column(table_name="task_fail", column_name="end_date", type_=mysql.TIMESTAMP(fsp=6)) op.alter_column( table_name="task_instance", @@ -152,9 +130,7 @@ def upgrade(): # noqa: D103 type_=mysql.TIMESTAMP(fsp=6), ) - op.alter_column( - table_name="xcom", column_name="timestamp", type_=mysql.TIMESTAMP(fsp=6) - ) + op.alter_column(table_name="xcom", column_name="timestamp", type_=mysql.TIMESTAMP(fsp=6)) op.alter_column( table_name="xcom", column_name="execution_date", @@ -225,18 +201,14 @@ def upgrade(): # noqa: D103 column_name="start_date", type_=sa.TIMESTAMP(timezone=True), ) - op.alter_column( - table_name="job", column_name="end_date", type_=sa.TIMESTAMP(timezone=True) - ) + op.alter_column(table_name="job", column_name="end_date", type_=sa.TIMESTAMP(timezone=True)) op.alter_column( table_name="job", column_name="latest_heartbeat", type_=sa.TIMESTAMP(timezone=True), ) - op.alter_column( - table_name="log", column_name="dttm", type_=sa.TIMESTAMP(timezone=True) - ) + op.alter_column(table_name="log", column_name="dttm", type_=sa.TIMESTAMP(timezone=True)) op.alter_column( table_name="log", column_name="execution_date", @@ -305,25 +277,19 @@ def upgrade(): # noqa: D103 ) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 conn = op.get_bind() if conn.dialect.name == "mysql": conn.execute("SET time_zone = '+00:00'") - op.alter_column( - table_name="chart", column_name="last_modified", type_=mysql.DATETIME(fsp=6) - ) + op.alter_column(table_name="chart", column_name="last_modified", type_=mysql.DATETIME(fsp=6)) op.alter_column( table_name="dag", column_name="last_scheduler_run", type_=mysql.DATETIME(fsp=6), ) - op.alter_column( - table_name="dag", column_name="last_pickled", type_=mysql.DATETIME(fsp=6) - ) - op.alter_column( - table_name="dag", column_name="last_expired", type_=mysql.DATETIME(fsp=6) - ) + op.alter_column(table_name="dag", column_name="last_pickled", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="dag", column_name="last_expired", type_=mysql.DATETIME(fsp=6)) op.alter_column( table_name="dag_pickle", @@ -336,12 +302,8 @@ def downgrade(): # noqa: D103 column_name="execution_date", type_=mysql.DATETIME(fsp=6), ) - op.alter_column( - table_name="dag_run", column_name="start_date", type_=mysql.DATETIME(fsp=6) - ) - op.alter_column( - table_name="dag_run", column_name="end_date", type_=mysql.DATETIME(fsp=6) - ) + 
op.alter_column(table_name="dag_run", column_name="start_date", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="dag_run", column_name="end_date", type_=mysql.DATETIME(fsp=6)) op.alter_column( table_name="import_error", @@ -349,24 +311,16 @@ def downgrade(): # noqa: D103 type_=mysql.DATETIME(fsp=6), ) - op.alter_column( - table_name="job", column_name="start_date", type_=mysql.DATETIME(fsp=6) - ) - op.alter_column( - table_name="job", column_name="end_date", type_=mysql.DATETIME(fsp=6) - ) + op.alter_column(table_name="job", column_name="start_date", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="job", column_name="end_date", type_=mysql.DATETIME(fsp=6)) op.alter_column( table_name="job", column_name="latest_heartbeat", type_=mysql.DATETIME(fsp=6), ) - op.alter_column( - table_name="log", column_name="dttm", type_=mysql.DATETIME(fsp=6) - ) - op.alter_column( - table_name="log", column_name="execution_date", type_=mysql.DATETIME(fsp=6) - ) + op.alter_column(table_name="log", column_name="dttm", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="log", column_name="execution_date", type_=mysql.DATETIME(fsp=6)) op.alter_column( table_name="sla_miss", @@ -374,9 +328,7 @@ def downgrade(): # noqa: D103 type_=mysql.DATETIME(fsp=6), nullable=False, ) - op.alter_column( - table_name="sla_miss", column_name="timestamp", type_=mysql.DATETIME(fsp=6) - ) + op.alter_column(table_name="sla_miss", column_name="timestamp", type_=mysql.DATETIME(fsp=6)) op.alter_column( table_name="task_fail", @@ -388,9 +340,7 @@ def downgrade(): # noqa: D103 column_name="start_date", type_=mysql.DATETIME(fsp=6), ) - op.alter_column( - table_name="task_fail", column_name="end_date", type_=mysql.DATETIME(fsp=6) - ) + op.alter_column(table_name="task_fail", column_name="end_date", type_=mysql.DATETIME(fsp=6)) op.alter_column( table_name="task_instance", @@ -414,12 +364,8 @@ def downgrade(): # noqa: D103 type_=mysql.DATETIME(fsp=6), ) - op.alter_column( - table_name="xcom", column_name="timestamp", type_=mysql.DATETIME(fsp=6) - ) - op.alter_column( - table_name="xcom", column_name="execution_date", type_=mysql.DATETIME(fsp=6) - ) + op.alter_column(table_name="xcom", column_name="timestamp", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="xcom", column_name="execution_date", type_=mysql.DATETIME(fsp=6)) else: if conn.dialect.name in ("sqlite", "mssql"): return @@ -429,48 +375,26 @@ def downgrade(): # noqa: D103 if conn.dialect.name == "postgresql": conn.execute("set timezone=UTC") - op.alter_column( - table_name="chart", column_name="last_modified", type_=sa.DateTime() - ) + op.alter_column(table_name="chart", column_name="last_modified", type_=sa.DateTime()) - op.alter_column( - table_name="dag", column_name="last_scheduler_run", type_=sa.DateTime() - ) - op.alter_column( - table_name="dag", column_name="last_pickled", type_=sa.DateTime() - ) - op.alter_column( - table_name="dag", column_name="last_expired", type_=sa.DateTime() - ) + op.alter_column(table_name="dag", column_name="last_scheduler_run", type_=sa.DateTime()) + op.alter_column(table_name="dag", column_name="last_pickled", type_=sa.DateTime()) + op.alter_column(table_name="dag", column_name="last_expired", type_=sa.DateTime()) - op.alter_column( - table_name="dag_pickle", column_name="created_dttm", type_=sa.DateTime() - ) + op.alter_column(table_name="dag_pickle", column_name="created_dttm", type_=sa.DateTime()) - op.alter_column( - table_name="dag_run", column_name="execution_date", type_=sa.DateTime() - ) - op.alter_column( 
- table_name="dag_run", column_name="start_date", type_=sa.DateTime() - ) - op.alter_column( - table_name="dag_run", column_name="end_date", type_=sa.DateTime() - ) + op.alter_column(table_name="dag_run", column_name="execution_date", type_=sa.DateTime()) + op.alter_column(table_name="dag_run", column_name="start_date", type_=sa.DateTime()) + op.alter_column(table_name="dag_run", column_name="end_date", type_=sa.DateTime()) - op.alter_column( - table_name="import_error", column_name="timestamp", type_=sa.DateTime() - ) + op.alter_column(table_name="import_error", column_name="timestamp", type_=sa.DateTime()) op.alter_column(table_name="job", column_name="start_date", type_=sa.DateTime()) op.alter_column(table_name="job", column_name="end_date", type_=sa.DateTime()) - op.alter_column( - table_name="job", column_name="latest_heartbeat", type_=sa.DateTime() - ) + op.alter_column(table_name="job", column_name="latest_heartbeat", type_=sa.DateTime()) op.alter_column(table_name="log", column_name="dttm", type_=sa.DateTime()) - op.alter_column( - table_name="log", column_name="execution_date", type_=sa.DateTime() - ) + op.alter_column(table_name="log", column_name="execution_date", type_=sa.DateTime()) op.alter_column( table_name="sla_miss", @@ -478,19 +402,11 @@ def downgrade(): # noqa: D103 type_=sa.DateTime(), nullable=False, ) - op.alter_column( - table_name="sla_miss", column_name="timestamp", type_=sa.DateTime() - ) + op.alter_column(table_name="sla_miss", column_name="timestamp", type_=sa.DateTime()) - op.alter_column( - table_name="task_fail", column_name="execution_date", type_=sa.DateTime() - ) - op.alter_column( - table_name="task_fail", column_name="start_date", type_=sa.DateTime() - ) - op.alter_column( - table_name="task_fail", column_name="end_date", type_=sa.DateTime() - ) + op.alter_column(table_name="task_fail", column_name="execution_date", type_=sa.DateTime()) + op.alter_column(table_name="task_fail", column_name="start_date", type_=sa.DateTime()) + op.alter_column(table_name="task_fail", column_name="end_date", type_=sa.DateTime()) op.alter_column( table_name="task_instance", @@ -498,17 +414,9 @@ def downgrade(): # noqa: D103 type_=sa.DateTime(), nullable=False, ) - op.alter_column( - table_name="task_instance", column_name="start_date", type_=sa.DateTime() - ) - op.alter_column( - table_name="task_instance", column_name="end_date", type_=sa.DateTime() - ) - op.alter_column( - table_name="task_instance", column_name="queued_dttm", type_=sa.DateTime() - ) + op.alter_column(table_name="task_instance", column_name="start_date", type_=sa.DateTime()) + op.alter_column(table_name="task_instance", column_name="end_date", type_=sa.DateTime()) + op.alter_column(table_name="task_instance", column_name="queued_dttm", type_=sa.DateTime()) op.alter_column(table_name="xcom", column_name="timestamp", type_=sa.DateTime()) - op.alter_column( - table_name="xcom", column_name="execution_date", type_=sa.DateTime() - ) + op.alter_column(table_name="xcom", column_name="execution_date", type_=sa.DateTime()) diff --git a/airflow/migrations/versions/127d2bf2dfa7_add_dag_id_state_index_on_dag_run_table.py b/airflow/migrations/versions/127d2bf2dfa7_add_dag_id_state_index_on_dag_run_table.py index df79a9512c965..855e55c5ee15d 100644 --- a/airflow/migrations/versions/127d2bf2dfa7_add_dag_id_state_index_on_dag_run_table.py +++ b/airflow/migrations/versions/127d2bf2dfa7_add_dag_id_state_index_on_dag_run_table.py @@ -32,9 +32,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 
op.create_index('dag_id_state', 'dag_run', ['dag_id', 'state'], unique=False) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_index('dag_id_state', table_name='dag_run') diff --git a/airflow/migrations/versions/13eb55f81627_for_compatibility.py b/airflow/migrations/versions/13eb55f81627_for_compatibility.py index ec43294be6111..538db1a49ffba 100644 --- a/airflow/migrations/versions/13eb55f81627_for_compatibility.py +++ b/airflow/migrations/versions/13eb55f81627_for_compatibility.py @@ -31,9 +31,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 pass -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 pass diff --git a/airflow/migrations/versions/1507a7289a2f_create_is_encrypted.py b/airflow/migrations/versions/1507a7289a2f_create_is_encrypted.py index 876f79f63a15b..7afdeb2c1831b 100644 --- a/airflow/migrations/versions/1507a7289a2f_create_is_encrypted.py +++ b/airflow/migrations/versions/1507a7289a2f_create_is_encrypted.py @@ -34,14 +34,11 @@ depends_on = None connectionhelper = sa.Table( - 'connection', - sa.MetaData(), - sa.Column('id', sa.Integer, primary_key=True), - sa.Column('is_encrypted') + 'connection', sa.MetaData(), sa.Column('id', sa.Integer, primary_key=True), sa.Column('is_encrypted') ) -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 # first check if the user already has this done. This should only be # true for users who are upgrading from a previous version of Airflow # that predates Alembic integration @@ -55,15 +52,11 @@ def upgrade(): # noqa: D103 if 'is_encrypted' in col_names: return - op.add_column( - 'connection', - sa.Column('is_encrypted', sa.Boolean, unique=False, default=False)) + op.add_column('connection', sa.Column('is_encrypted', sa.Boolean, unique=False, default=False)) conn = op.get_bind() - conn.execute( - connectionhelper.update().values(is_encrypted=False) - ) + conn.execute(connectionhelper.update().values(is_encrypted=False)) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_column('connection', 'is_encrypted') diff --git a/airflow/migrations/versions/1968acfc09e3_add_is_encrypted_column_to_variable_.py b/airflow/migrations/versions/1968acfc09e3_add_is_encrypted_column_to_variable_.py index d9b4fe015932e..e880d77fee220 100644 --- a/airflow/migrations/versions/1968acfc09e3_add_is_encrypted_column_to_variable_.py +++ b/airflow/migrations/versions/1968acfc09e3_add_is_encrypted_column_to_variable_.py @@ -33,9 +33,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.add_column('variable', sa.Column('is_encrypted', sa.Boolean, default=False)) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_column('variable', 'is_encrypted') diff --git a/airflow/migrations/versions/1b38cef5b76e_add_dagrun.py b/airflow/migrations/versions/1b38cef5b76e_add_dagrun.py index 8edcbb4d59618..7edebfc4cf9d2 100644 --- a/airflow/migrations/versions/1b38cef5b76e_add_dagrun.py +++ b/airflow/migrations/versions/1b38cef5b76e_add_dagrun.py @@ -34,18 +34,20 @@ depends_on = None -def upgrade(): # noqa: D103 - op.create_table('dag_run', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('dag_id', sa.String(length=250), nullable=True), - sa.Column('execution_date', sa.DateTime(), nullable=True), - sa.Column('state', sa.String(length=50), nullable=True), - sa.Column('run_id', sa.String(length=250), nullable=True), - sa.Column('external_trigger', sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('dag_id', 
'execution_date'), - sa.UniqueConstraint('dag_id', 'run_id')) - - -def downgrade(): # noqa: D103 +def upgrade(): # noqa: D103 + op.create_table( + 'dag_run', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('dag_id', sa.String(length=250), nullable=True), + sa.Column('execution_date', sa.DateTime(), nullable=True), + sa.Column('state', sa.String(length=50), nullable=True), + sa.Column('run_id', sa.String(length=250), nullable=True), + sa.Column('external_trigger', sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('dag_id', 'execution_date'), + sa.UniqueConstraint('dag_id', 'run_id'), + ) + + +def downgrade(): # noqa: D103 op.drop_table('dag_run') diff --git a/airflow/migrations/versions/211e584da130_add_ti_state_index.py b/airflow/migrations/versions/211e584da130_add_ti_state_index.py index a6f946332118b..7df1550733d7e 100644 --- a/airflow/migrations/versions/211e584da130_add_ti_state_index.py +++ b/airflow/migrations/versions/211e584da130_add_ti_state_index.py @@ -32,9 +32,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.create_index('ti_state', 'task_instance', ['state'], unique=False) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_index('ti_state', table_name='task_instance') diff --git a/airflow/migrations/versions/27c6a30d7c24_add_executor_config_to_task_instance.py b/airflow/migrations/versions/27c6a30d7c24_add_executor_config_to_task_instance.py index 768a67fe529bd..d0853efb79a2f 100644 --- a/airflow/migrations/versions/27c6a30d7c24_add_executor_config_to_task_instance.py +++ b/airflow/migrations/versions/27c6a30d7c24_add_executor_config_to_task_instance.py @@ -38,9 +38,9 @@ NEW_COLUMN = "executor_config" -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.add_column(TASK_INSTANCE_TABLE, sa.Column(NEW_COLUMN, sa.PickleType(pickler=dill))) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_column(TASK_INSTANCE_TABLE, NEW_COLUMN) diff --git a/airflow/migrations/versions/2e541a1dcfed_task_duration.py b/airflow/migrations/versions/2e541a1dcfed_task_duration.py index b071f1c4b8d1e..12d8e2e5a608d 100644 --- a/airflow/migrations/versions/2e541a1dcfed_task_duration.py +++ b/airflow/migrations/versions/2e541a1dcfed_task_duration.py @@ -35,14 +35,16 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 # use batch_alter_table to support SQLite workaround with op.batch_alter_table("task_instance") as batch_op: - batch_op.alter_column('duration', - existing_type=mysql.INTEGER(display_width=11), - type_=sa.Float(), - existing_nullable=True) + batch_op.alter_column( + 'duration', + existing_type=mysql.INTEGER(display_width=11), + type_=sa.Float(), + existing_nullable=True, + ) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 pass diff --git a/airflow/migrations/versions/2e82aab8ef20_rename_user_table.py b/airflow/migrations/versions/2e82aab8ef20_rename_user_table.py index 0acc346d9abb4..3dcbe47460efa 100644 --- a/airflow/migrations/versions/2e82aab8ef20_rename_user_table.py +++ b/airflow/migrations/versions/2e82aab8ef20_rename_user_table.py @@ -32,9 +32,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.rename_table('user', 'users') -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.rename_table('users', 'user') diff --git a/airflow/migrations/versions/338e90f54d61_more_logging_into_task_isntance.py b/airflow/migrations/versions/338e90f54d61_more_logging_into_task_isntance.py index 
50245a3a8cb83..60ed6628a77dd 100644 --- a/airflow/migrations/versions/338e90f54d61_more_logging_into_task_isntance.py +++ b/airflow/migrations/versions/338e90f54d61_more_logging_into_task_isntance.py @@ -33,11 +33,11 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.add_column('task_instance', sa.Column('operator', sa.String(length=1000), nullable=True)) op.add_column('task_instance', sa.Column('queued_dttm', sa.DateTime(), nullable=True)) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_column('task_instance', 'queued_dttm') op.drop_column('task_instance', 'operator') diff --git a/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py b/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py index d22d87864a859..2f06756dcf054 100644 --- a/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py +++ b/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py @@ -35,10 +35,10 @@ RESOURCE_TABLE = "kube_resource_version" -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 columns_and_constraints = [ sa.Column("one_row_id", sa.Boolean, server_default=sa.true(), primary_key=True), - sa.Column("resource_version", sa.String(255)) + sa.Column("resource_version", sa.String(255)), ] conn = op.get_bind() @@ -53,15 +53,10 @@ def upgrade(): # noqa: D103 sa.CheckConstraint("one_row_id", name="kube_resource_version_one_row_id") ) - table = op.create_table( - RESOURCE_TABLE, - *columns_and_constraints - ) + table = op.create_table(RESOURCE_TABLE, *columns_and_constraints) - op.bulk_insert(table, [ - {"resource_version": ""} - ]) + op.bulk_insert(table, [{"resource_version": ""}]) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_table(RESOURCE_TABLE) diff --git a/airflow/migrations/versions/40e67319e3a9_dagrun_config.py b/airflow/migrations/versions/40e67319e3a9_dagrun_config.py index d123f8e383d55..96c211eebfba4 100644 --- a/airflow/migrations/versions/40e67319e3a9_dagrun_config.py +++ b/airflow/migrations/versions/40e67319e3a9_dagrun_config.py @@ -33,9 +33,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.add_column('dag_run', sa.Column('conf', sa.PickleType(), nullable=True)) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_column('dag_run', 'conf') diff --git a/airflow/migrations/versions/41f5f12752f8_add_superuser_field.py b/airflow/migrations/versions/41f5f12752f8_add_superuser_field.py index a7d5767f99bd1..572845b4d04d5 100644 --- a/airflow/migrations/versions/41f5f12752f8_add_superuser_field.py +++ b/airflow/migrations/versions/41f5f12752f8_add_superuser_field.py @@ -33,9 +33,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.add_column('users', sa.Column('superuser', sa.Boolean(), default=False)) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_column('users', 'superuser') diff --git a/airflow/migrations/versions/4446e08588_dagrun_start_end.py b/airflow/migrations/versions/4446e08588_dagrun_start_end.py index ec20c807b7a03..2ee527361d415 100644 --- a/airflow/migrations/versions/4446e08588_dagrun_start_end.py +++ b/airflow/migrations/versions/4446e08588_dagrun_start_end.py @@ -34,11 +34,11 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.add_column('dag_run', sa.Column('end_date', sa.DateTime(), nullable=True)) op.add_column('dag_run', sa.Column('start_date', sa.DateTime(), 
nullable=True)) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_column('dag_run', 'start_date') op.drop_column('dag_run', 'end_date') diff --git a/airflow/migrations/versions/4addfa1236f1_add_fractional_seconds_to_mysql_tables.py b/airflow/migrations/versions/4addfa1236f1_add_fractional_seconds_to_mysql_tables.py index 20d77ce037286..919119fb8b88c 100644 --- a/airflow/migrations/versions/4addfa1236f1_add_fractional_seconds_to_mysql_tables.py +++ b/airflow/migrations/versions/4addfa1236f1_add_fractional_seconds_to_mysql_tables.py @@ -34,126 +34,86 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 if context.config.get_main_option('sqlalchemy.url').startswith('mysql'): - op.alter_column(table_name='dag', column_name='last_scheduler_run', - type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='dag', column_name='last_pickled', - type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='dag', column_name='last_expired', - type_=mysql.DATETIME(fsp=6)) - - op.alter_column(table_name='dag_pickle', column_name='created_dttm', - type_=mysql.DATETIME(fsp=6)) - - op.alter_column(table_name='dag_run', column_name='execution_date', - type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='dag_run', column_name='start_date', - type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='dag_run', column_name='end_date', - type_=mysql.DATETIME(fsp=6)) - - op.alter_column(table_name='import_error', column_name='timestamp', - type_=mysql.DATETIME(fsp=6)) - - op.alter_column(table_name='job', column_name='start_date', - type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='job', column_name='end_date', - type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='job', column_name='latest_heartbeat', - type_=mysql.DATETIME(fsp=6)) - - op.alter_column(table_name='log', column_name='dttm', - type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='log', column_name='execution_date', - type_=mysql.DATETIME(fsp=6)) - - op.alter_column(table_name='sla_miss', column_name='execution_date', - type_=mysql.DATETIME(fsp=6), - nullable=False) - op.alter_column(table_name='sla_miss', column_name='timestamp', - type_=mysql.DATETIME(fsp=6)) - - op.alter_column(table_name='task_fail', column_name='execution_date', - type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='task_fail', column_name='start_date', - type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='task_fail', column_name='end_date', - type_=mysql.DATETIME(fsp=6)) - - op.alter_column(table_name='task_instance', column_name='execution_date', - type_=mysql.DATETIME(fsp=6), - nullable=False) - op.alter_column(table_name='task_instance', column_name='start_date', - type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='task_instance', column_name='end_date', - type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='task_instance', column_name='queued_dttm', - type_=mysql.DATETIME(fsp=6)) - - op.alter_column(table_name='xcom', column_name='timestamp', - type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='xcom', column_name='execution_date', - type_=mysql.DATETIME(fsp=6)) - - -def downgrade(): # noqa: D103 + op.alter_column(table_name='dag', column_name='last_scheduler_run', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name='dag', column_name='last_pickled', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name='dag', column_name='last_expired', type_=mysql.DATETIME(fsp=6)) + + op.alter_column(table_name='dag_pickle', column_name='created_dttm', 
type_=mysql.DATETIME(fsp=6)) + + op.alter_column(table_name='dag_run', column_name='execution_date', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name='dag_run', column_name='start_date', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name='dag_run', column_name='end_date', type_=mysql.DATETIME(fsp=6)) + + op.alter_column(table_name='import_error', column_name='timestamp', type_=mysql.DATETIME(fsp=6)) + + op.alter_column(table_name='job', column_name='start_date', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name='job', column_name='end_date', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name='job', column_name='latest_heartbeat', type_=mysql.DATETIME(fsp=6)) + + op.alter_column(table_name='log', column_name='dttm', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name='log', column_name='execution_date', type_=mysql.DATETIME(fsp=6)) + + op.alter_column( + table_name='sla_miss', column_name='execution_date', type_=mysql.DATETIME(fsp=6), nullable=False + ) + op.alter_column(table_name='sla_miss', column_name='timestamp', type_=mysql.DATETIME(fsp=6)) + + op.alter_column(table_name='task_fail', column_name='execution_date', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name='task_fail', column_name='start_date', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name='task_fail', column_name='end_date', type_=mysql.DATETIME(fsp=6)) + + op.alter_column( + table_name='task_instance', + column_name='execution_date', + type_=mysql.DATETIME(fsp=6), + nullable=False, + ) + op.alter_column(table_name='task_instance', column_name='start_date', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name='task_instance', column_name='end_date', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name='task_instance', column_name='queued_dttm', type_=mysql.DATETIME(fsp=6)) + + op.alter_column(table_name='xcom', column_name='timestamp', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name='xcom', column_name='execution_date', type_=mysql.DATETIME(fsp=6)) + + +def downgrade(): # noqa: D103 if context.config.get_main_option('sqlalchemy.url').startswith('mysql'): - op.alter_column(table_name='dag', column_name='last_scheduler_run', - type_=mysql.DATETIME()) - op.alter_column(table_name='dag', column_name='last_pickled', - type_=mysql.DATETIME()) - op.alter_column(table_name='dag', column_name='last_expired', - type_=mysql.DATETIME()) - - op.alter_column(table_name='dag_pickle', column_name='created_dttm', - type_=mysql.DATETIME()) - - op.alter_column(table_name='dag_run', column_name='execution_date', - type_=mysql.DATETIME()) - op.alter_column(table_name='dag_run', column_name='start_date', - type_=mysql.DATETIME()) - op.alter_column(table_name='dag_run', column_name='end_date', - type_=mysql.DATETIME()) - - op.alter_column(table_name='import_error', column_name='timestamp', - type_=mysql.DATETIME()) - - op.alter_column(table_name='job', column_name='start_date', - type_=mysql.DATETIME()) - op.alter_column(table_name='job', column_name='end_date', - type_=mysql.DATETIME()) - op.alter_column(table_name='job', column_name='latest_heartbeat', - type_=mysql.DATETIME()) - - op.alter_column(table_name='log', column_name='dttm', - type_=mysql.DATETIME()) - op.alter_column(table_name='log', column_name='execution_date', - type_=mysql.DATETIME()) - - op.alter_column(table_name='sla_miss', column_name='execution_date', - type_=mysql.DATETIME(), nullable=False) - op.alter_column(table_name='sla_miss', column_name='timestamp', - type_=mysql.DATETIME()) - - 
op.alter_column(table_name='task_fail', column_name='execution_date', - type_=mysql.DATETIME()) - op.alter_column(table_name='task_fail', column_name='start_date', - type_=mysql.DATETIME()) - op.alter_column(table_name='task_fail', column_name='end_date', - type_=mysql.DATETIME()) - - op.alter_column(table_name='task_instance', column_name='execution_date', - type_=mysql.DATETIME(), - nullable=False) - op.alter_column(table_name='task_instance', column_name='start_date', - type_=mysql.DATETIME()) - op.alter_column(table_name='task_instance', column_name='end_date', - type_=mysql.DATETIME()) - op.alter_column(table_name='task_instance', column_name='queued_dttm', - type_=mysql.DATETIME()) - - op.alter_column(table_name='xcom', column_name='timestamp', - type_=mysql.DATETIME()) - op.alter_column(table_name='xcom', column_name='execution_date', - type_=mysql.DATETIME()) + op.alter_column(table_name='dag', column_name='last_scheduler_run', type_=mysql.DATETIME()) + op.alter_column(table_name='dag', column_name='last_pickled', type_=mysql.DATETIME()) + op.alter_column(table_name='dag', column_name='last_expired', type_=mysql.DATETIME()) + + op.alter_column(table_name='dag_pickle', column_name='created_dttm', type_=mysql.DATETIME()) + + op.alter_column(table_name='dag_run', column_name='execution_date', type_=mysql.DATETIME()) + op.alter_column(table_name='dag_run', column_name='start_date', type_=mysql.DATETIME()) + op.alter_column(table_name='dag_run', column_name='end_date', type_=mysql.DATETIME()) + + op.alter_column(table_name='import_error', column_name='timestamp', type_=mysql.DATETIME()) + + op.alter_column(table_name='job', column_name='start_date', type_=mysql.DATETIME()) + op.alter_column(table_name='job', column_name='end_date', type_=mysql.DATETIME()) + op.alter_column(table_name='job', column_name='latest_heartbeat', type_=mysql.DATETIME()) + + op.alter_column(table_name='log', column_name='dttm', type_=mysql.DATETIME()) + op.alter_column(table_name='log', column_name='execution_date', type_=mysql.DATETIME()) + + op.alter_column( + table_name='sla_miss', column_name='execution_date', type_=mysql.DATETIME(), nullable=False + ) + op.alter_column(table_name='sla_miss', column_name='timestamp', type_=mysql.DATETIME()) + + op.alter_column(table_name='task_fail', column_name='execution_date', type_=mysql.DATETIME()) + op.alter_column(table_name='task_fail', column_name='start_date', type_=mysql.DATETIME()) + op.alter_column(table_name='task_fail', column_name='end_date', type_=mysql.DATETIME()) + + op.alter_column( + table_name='task_instance', column_name='execution_date', type_=mysql.DATETIME(), nullable=False + ) + op.alter_column(table_name='task_instance', column_name='start_date', type_=mysql.DATETIME()) + op.alter_column(table_name='task_instance', column_name='end_date', type_=mysql.DATETIME()) + op.alter_column(table_name='task_instance', column_name='queued_dttm', type_=mysql.DATETIME()) + + op.alter_column(table_name='xcom', column_name='timestamp', type_=mysql.DATETIME()) + op.alter_column(table_name='xcom', column_name='execution_date', type_=mysql.DATETIME()) diff --git a/airflow/migrations/versions/502898887f84_adding_extra_to_log.py b/airflow/migrations/versions/502898887f84_adding_extra_to_log.py index 606b6b8b5a1be..0f00e110c1d5a 100644 --- a/airflow/migrations/versions/502898887f84_adding_extra_to_log.py +++ b/airflow/migrations/versions/502898887f84_adding_extra_to_log.py @@ -33,9 +33,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 
op.add_column('log', sa.Column('extra', sa.Text(), nullable=True)) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_column('log', 'extra') diff --git a/airflow/migrations/versions/52d53670a240_fix_mssql_exec_date_rendered_task_instance.py b/airflow/migrations/versions/52d53670a240_fix_mssql_exec_date_rendered_task_instance.py index c2898b86423a1..84daac5013518 100644 --- a/airflow/migrations/versions/52d53670a240_fix_mssql_exec_date_rendered_task_instance.py +++ b/airflow/migrations/versions/52d53670a240_fix_mssql_exec_date_rendered_task_instance.py @@ -52,7 +52,7 @@ def upgrade(): sa.Column('task_id', sa.String(length=250), nullable=False), sa.Column('execution_date', mssql.DATETIME2, nullable=False), sa.Column('rendered_fields', json_type(), nullable=False), - sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date') + sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date'), ) @@ -72,5 +72,5 @@ def downgrade(): sa.Column('task_id', sa.String(length=250), nullable=False), sa.Column('execution_date', sa.TIMESTAMP, nullable=False), sa.Column('rendered_fields', json_type(), nullable=False), - sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date') + sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date'), ) diff --git a/airflow/migrations/versions/52d714495f0_job_id_indices.py b/airflow/migrations/versions/52d714495f0_job_id_indices.py index 2006271a48cee..fc3ecad87f200 100644 --- a/airflow/migrations/versions/52d714495f0_job_id_indices.py +++ b/airflow/migrations/versions/52d714495f0_job_id_indices.py @@ -32,10 +32,9 @@ depends_on = None -def upgrade(): # noqa: D103 - op.create_index('idx_job_state_heartbeat', 'job', - ['state', 'latest_heartbeat'], unique=False) +def upgrade(): # noqa: D103 + op.create_index('idx_job_state_heartbeat', 'job', ['state', 'latest_heartbeat'], unique=False) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_index('idx_job_state_heartbeat', table_name='job') diff --git a/airflow/migrations/versions/561833c1c74b_add_password_column_to_user.py b/airflow/migrations/versions/561833c1c74b_add_password_column_to_user.py index bbf835c5cca62..144259ef06f9c 100644 --- a/airflow/migrations/versions/561833c1c74b_add_password_column_to_user.py +++ b/airflow/migrations/versions/561833c1c74b_add_password_column_to_user.py @@ -33,9 +33,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.add_column('user', sa.Column('password', sa.String(255))) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_column('user', 'password') diff --git a/airflow/migrations/versions/64de9cddf6c9_add_task_fails_journal_table.py b/airflow/migrations/versions/64de9cddf6c9_add_task_fails_journal_table.py index d798a7b67b36e..40dd9ddbfae19 100644 --- a/airflow/migrations/versions/64de9cddf6c9_add_task_fails_journal_table.py +++ b/airflow/migrations/versions/64de9cddf6c9_add_task_fails_journal_table.py @@ -35,7 +35,7 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.create_table( 'task_fail', sa.Column('id', sa.Integer(), nullable=False), @@ -49,5 +49,5 @@ def upgrade(): # noqa: D103 ) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_table('task_fail') diff --git a/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py b/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py index 8526a9a7446c6..c59c52225757f 100644 --- 
a/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py +++ b/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py @@ -84,10 +84,9 @@ class TaskInstance(Base): # type: ignore def upgrade(): """Make TaskInstance.pool field not nullable.""" with create_session() as session: - session.query(TaskInstance) \ - .filter(TaskInstance.pool.is_(None)) \ - .update({TaskInstance.pool: 'default_pool'}, - synchronize_session=False) # Avoid select updated rows + session.query(TaskInstance).filter(TaskInstance.pool.is_(None)).update( + {TaskInstance.pool: 'default_pool'}, synchronize_session=False + ) # Avoid select updated rows session.commit() conn = op.get_bind() @@ -124,8 +123,7 @@ def downgrade(): op.create_index('ti_pool', 'task_instance', ['pool', 'state', 'priority_weight']) with create_session() as session: - session.query(TaskInstance) \ - .filter(TaskInstance.pool == 'default_pool') \ - .update({TaskInstance.pool: None}, - synchronize_session=False) # Avoid select updated rows + session.query(TaskInstance).filter(TaskInstance.pool == 'default_pool').update( + {TaskInstance.pool: None}, synchronize_session=False + ) # Avoid select updated rows session.commit() diff --git a/airflow/migrations/versions/74effc47d867_change_datetime_to_datetime2_6_on_mssql_.py b/airflow/migrations/versions/74effc47d867_change_datetime_to_datetime2_6_on_mssql_.py index 3a6746d1bdba7..5505acd0e17e6 100644 --- a/airflow/migrations/versions/74effc47d867_change_datetime_to_datetime2_6_on_mssql_.py +++ b/airflow/migrations/versions/74effc47d867_change_datetime_to_datetime2_6_on_mssql_.py @@ -43,7 +43,8 @@ def upgrade(): result = conn.execute( """SELECT CASE WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion')) like '8%' THEN '2000' WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion')) - like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion""").fetchone() + like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion""" + ).fetchone() mssql_version = result[0] if mssql_version in ("2000", "2005"): return @@ -51,37 +52,49 @@ def upgrade(): with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op: task_reschedule_batch_op.drop_index('idx_task_reschedule_dag_task_date') task_reschedule_batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey') - task_reschedule_batch_op.alter_column(column_name="execution_date", - type_=mssql.DATETIME2(precision=6), nullable=False, ) - task_reschedule_batch_op.alter_column(column_name='start_date', - type_=mssql.DATETIME2(precision=6)) + task_reschedule_batch_op.alter_column( + column_name="execution_date", + type_=mssql.DATETIME2(precision=6), + nullable=False, + ) + task_reschedule_batch_op.alter_column( + column_name='start_date', type_=mssql.DATETIME2(precision=6) + ) task_reschedule_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME2(precision=6)) - task_reschedule_batch_op.alter_column(column_name='reschedule_date', - type_=mssql.DATETIME2(precision=6)) + task_reschedule_batch_op.alter_column( + column_name='reschedule_date', type_=mssql.DATETIME2(precision=6) + ) with op.batch_alter_table('task_instance') as task_instance_batch_op: task_instance_batch_op.drop_index('ti_state_lkp') task_instance_batch_op.drop_index('ti_dag_date') - modify_execution_date_with_constraint(conn, task_instance_batch_op, 'task_instance', - mssql.DATETIME2(precision=6), False) + modify_execution_date_with_constraint( + conn, task_instance_batch_op, 'task_instance', 
mssql.DATETIME2(precision=6), False + ) task_instance_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME2(precision=6)) task_instance_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME2(precision=6)) task_instance_batch_op.alter_column(column_name='queued_dttm', type_=mssql.DATETIME2(precision=6)) - task_instance_batch_op.create_index('ti_state_lkp', ['dag_id', 'task_id', 'execution_date'], - unique=False) + task_instance_batch_op.create_index( + 'ti_state_lkp', ['dag_id', 'task_id', 'execution_date'], unique=False + ) task_instance_batch_op.create_index('ti_dag_date', ['dag_id', 'execution_date'], unique=False) with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op: - task_reschedule_batch_op.create_foreign_key('task_reschedule_dag_task_date_fkey', 'task_instance', - ['task_id', 'dag_id', 'execution_date'], - ['task_id', 'dag_id', 'execution_date'], - ondelete='CASCADE') - task_reschedule_batch_op.create_index('idx_task_reschedule_dag_task_date', - ['dag_id', 'task_id', 'execution_date'], unique=False) + task_reschedule_batch_op.create_foreign_key( + 'task_reschedule_dag_task_date_fkey', + 'task_instance', + ['task_id', 'dag_id', 'execution_date'], + ['task_id', 'dag_id', 'execution_date'], + ondelete='CASCADE', + ) + task_reschedule_batch_op.create_index( + 'idx_task_reschedule_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False + ) with op.batch_alter_table('dag_run') as dag_run_batch_op: - modify_execution_date_with_constraint(conn, dag_run_batch_op, 'dag_run', - mssql.DATETIME2(precision=6), None) + modify_execution_date_with_constraint( + conn, dag_run_batch_op, 'dag_run', mssql.DATETIME2(precision=6), None + ) dag_run_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME2(precision=6)) dag_run_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME2(precision=6)) @@ -89,34 +102,41 @@ def upgrade(): op.alter_column(table_name='log', column_name='dttm', type_=mssql.DATETIME2(precision=6)) with op.batch_alter_table('sla_miss') as sla_miss_batch_op: - modify_execution_date_with_constraint(conn, sla_miss_batch_op, 'sla_miss', - mssql.DATETIME2(precision=6), False) + modify_execution_date_with_constraint( + conn, sla_miss_batch_op, 'sla_miss', mssql.DATETIME2(precision=6), False + ) sla_miss_batch_op.alter_column(column_name='timestamp', type_=mssql.DATETIME2(precision=6)) op.drop_index('idx_task_fail_dag_task_date', table_name='task_fail') - op.alter_column(table_name="task_fail", column_name="execution_date", - type_=mssql.DATETIME2(precision=6)) + op.alter_column( + table_name="task_fail", column_name="execution_date", type_=mssql.DATETIME2(precision=6) + ) op.alter_column(table_name='task_fail', column_name='start_date', type_=mssql.DATETIME2(precision=6)) op.alter_column(table_name='task_fail', column_name='end_date', type_=mssql.DATETIME2(precision=6)) - op.create_index('idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'], - unique=False) + op.create_index( + 'idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'], unique=False + ) op.drop_index('idx_xcom_dag_task_date', table_name='xcom') op.alter_column(table_name="xcom", column_name="execution_date", type_=mssql.DATETIME2(precision=6)) op.alter_column(table_name='xcom', column_name='timestamp', type_=mssql.DATETIME2(precision=6)) - op.create_index('idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_id', 'execution_date'], - unique=False) + op.create_index( + 
'idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_id', 'execution_date'], unique=False + ) - op.alter_column(table_name='dag', column_name='last_scheduler_run', - type_=mssql.DATETIME2(precision=6)) + op.alter_column( + table_name='dag', column_name='last_scheduler_run', type_=mssql.DATETIME2(precision=6) + ) op.alter_column(table_name='dag', column_name='last_pickled', type_=mssql.DATETIME2(precision=6)) op.alter_column(table_name='dag', column_name='last_expired', type_=mssql.DATETIME2(precision=6)) - op.alter_column(table_name='dag_pickle', column_name='created_dttm', - type_=mssql.DATETIME2(precision=6)) + op.alter_column( + table_name='dag_pickle', column_name='created_dttm', type_=mssql.DATETIME2(precision=6) + ) - op.alter_column(table_name='import_error', column_name='timestamp', - type_=mssql.DATETIME2(precision=6)) + op.alter_column( + table_name='import_error', column_name='timestamp', type_=mssql.DATETIME2(precision=6) + ) op.drop_index('job_type_heart', table_name='job') op.drop_index('idx_job_state_heartbeat', table_name='job') @@ -134,7 +154,8 @@ def downgrade(): result = conn.execute( """SELECT CASE WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion')) like '8%' THEN '2000' WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion')) - like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion""").fetchone() + like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion""" + ).fetchone() mssql_version = result[0] if mssql_version in ("2000", "2005"): return @@ -142,8 +163,9 @@ def downgrade(): with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op: task_reschedule_batch_op.drop_index('idx_task_reschedule_dag_task_date') task_reschedule_batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey') - task_reschedule_batch_op.alter_column(column_name="execution_date", type_=mssql.DATETIME, - nullable=False) + task_reschedule_batch_op.alter_column( + column_name="execution_date", type_=mssql.DATETIME, nullable=False + ) task_reschedule_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME) task_reschedule_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME) task_reschedule_batch_op.alter_column(column_name='reschedule_date', type_=mssql.DATETIME) @@ -151,23 +173,28 @@ def downgrade(): with op.batch_alter_table('task_instance') as task_instance_batch_op: task_instance_batch_op.drop_index('ti_state_lkp') task_instance_batch_op.drop_index('ti_dag_date') - modify_execution_date_with_constraint(conn, task_instance_batch_op, 'task_instance', - mssql.DATETIME, False) + modify_execution_date_with_constraint( + conn, task_instance_batch_op, 'task_instance', mssql.DATETIME, False + ) task_instance_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME) task_instance_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME) task_instance_batch_op.alter_column(column_name='queued_dttm', type_=mssql.DATETIME) - task_instance_batch_op.create_index('ti_state_lkp', - ['dag_id', 'task_id', 'execution_date'], unique=False) - task_instance_batch_op.create_index('ti_dag_date', - ['dag_id', 'execution_date'], unique=False) + task_instance_batch_op.create_index( + 'ti_state_lkp', ['dag_id', 'task_id', 'execution_date'], unique=False + ) + task_instance_batch_op.create_index('ti_dag_date', ['dag_id', 'execution_date'], unique=False) with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op: - task_reschedule_batch_op.create_foreign_key('task_reschedule_dag_task_date_fkey', 
'task_instance', - ['task_id', 'dag_id', 'execution_date'], - ['task_id', 'dag_id', 'execution_date'], - ondelete='CASCADE') - task_reschedule_batch_op.create_index('idx_task_reschedule_dag_task_date', - ['dag_id', 'task_id', 'execution_date'], unique=False) + task_reschedule_batch_op.create_foreign_key( + 'task_reschedule_dag_task_date_fkey', + 'task_instance', + ['task_id', 'dag_id', 'execution_date'], + ['task_id', 'dag_id', 'execution_date'], + ondelete='CASCADE', + ) + task_reschedule_batch_op.create_index( + 'idx_task_reschedule_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False + ) with op.batch_alter_table('dag_run') as dag_run_batch_op: modify_execution_date_with_constraint(conn, dag_run_batch_op, 'dag_run', mssql.DATETIME, None) @@ -185,14 +212,16 @@ def downgrade(): op.alter_column(table_name="task_fail", column_name="execution_date", type_=mssql.DATETIME) op.alter_column(table_name='task_fail', column_name='start_date', type_=mssql.DATETIME) op.alter_column(table_name='task_fail', column_name='end_date', type_=mssql.DATETIME) - op.create_index('idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'], - unique=False) + op.create_index( + 'idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'], unique=False + ) op.drop_index('idx_xcom_dag_task_date', table_name='xcom') op.alter_column(table_name="xcom", column_name="execution_date", type_=mssql.DATETIME) op.alter_column(table_name='xcom', column_name='timestamp', type_=mssql.DATETIME) - op.create_index('idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_ild', 'execution_date'], - unique=False) + op.create_index( + 'idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_ild', 'execution_date'], unique=False + ) op.alter_column(table_name='dag', column_name='last_scheduler_run', type_=mssql.DATETIME) op.alter_column(table_name='dag', column_name='last_pickled', type_=mssql.DATETIME) @@ -229,7 +258,9 @@ def get_table_constraints(conn, table_name): JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE AS ccu ON ccu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME WHERE tc.TABLE_NAME = '{table_name}' AND (tc.CONSTRAINT_TYPE = 'PRIMARY KEY' or UPPER(tc.CONSTRAINT_TYPE) = 'UNIQUE') - """.format(table_name=table_name) + """.format( + table_name=table_name + ) result = conn.execute(query).fetchall() constraint_dict = defaultdict(list) for constraint, constraint_type, column in result: @@ -267,15 +298,9 @@ def drop_constraint(operator, constraint_dict): for constraint, columns in constraint_dict.items(): if 'execution_date' in columns: if constraint[1].lower().startswith("primary"): - operator.drop_constraint( - constraint[0], - type_='primary' - ) + operator.drop_constraint(constraint[0], type_='primary') elif constraint[1].lower().startswith("unique"): - operator.drop_constraint( - constraint[0], - type_='unique' - ) + operator.drop_constraint(constraint[0], type_='unique') def create_constraint(operator, constraint_dict): @@ -288,14 +313,10 @@ def create_constraint(operator, constraint_dict): for constraint, columns in constraint_dict.items(): if 'execution_date' in columns: if constraint[1].lower().startswith("primary"): - operator.create_primary_key( - constraint_name=constraint[0], - columns=reorder_columns(columns) - ) + operator.create_primary_key(constraint_name=constraint[0], columns=reorder_columns(columns)) elif constraint[1].lower().startswith("unique"): operator.create_unique_constraint( - constraint_name=constraint[0], - columns=reorder_columns(columns) + 
constraint_name=constraint[0], columns=reorder_columns(columns) ) diff --git a/airflow/migrations/versions/7939bcff74ba_add_dagtags_table.py b/airflow/migrations/versions/7939bcff74ba_add_dagtags_table.py index 3bd49033e6e30..8b8b93c49f750 100644 --- a/airflow/migrations/versions/7939bcff74ba_add_dagtags_table.py +++ b/airflow/migrations/versions/7939bcff74ba_add_dagtags_table.py @@ -40,8 +40,11 @@ def upgrade(): 'dag_tag', sa.Column('name', sa.String(length=100), nullable=False), sa.Column('dag_id', sa.String(length=250), nullable=False), - sa.ForeignKeyConstraint(['dag_id'], ['dag.dag_id'], ), - sa.PrimaryKeyConstraint('name', 'dag_id') + sa.ForeignKeyConstraint( + ['dag_id'], + ['dag.dag_id'], + ), + sa.PrimaryKeyConstraint('name', 'dag_id'), ) diff --git a/airflow/migrations/versions/849da589634d_prefix_dag_permissions.py b/airflow/migrations/versions/849da589634d_prefix_dag_permissions.py index 23fd96a21d289..240b39e1305fd 100644 --- a/airflow/migrations/versions/849da589634d_prefix_dag_permissions.py +++ b/airflow/migrations/versions/849da589634d_prefix_dag_permissions.py @@ -34,18 +34,16 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 permissions = ['can_dag_read', 'can_dag_edit'] view_menus = cached_app().appbuilder.sm.get_all_view_menu() convert_permissions(permissions, view_menus, upgrade_action, upgrade_dag_id) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 permissions = ['can_read', 'can_edit'] vms = cached_app().appbuilder.sm.get_all_view_menu() - view_menus = [ - vm for vm in vms if (vm.name == permissions.RESOURCE_DAGS or vm.name.startswith('DAG:')) - ] + view_menus = [vm for vm in vms if (vm.name == permissions.RESOURCE_DAGS or vm.name.startswith('DAG:'))] convert_permissions(permissions, view_menus, downgrade_action, downgrade_dag_id) @@ -63,7 +61,7 @@ def downgrade_dag_id(dag_id): if dag_id == permissions.RESOURCE_DAGS: return 'all_dags' if dag_id.startswith("DAG:"): - return dag_id[len("DAG:"):] + return dag_id[len("DAG:") :] return dag_id diff --git a/airflow/migrations/versions/8504051e801b_xcom_dag_task_indices.py b/airflow/migrations/versions/8504051e801b_xcom_dag_task_indices.py index 7eadb56000137..6f60ac6b12335 100644 --- a/airflow/migrations/versions/8504051e801b_xcom_dag_task_indices.py +++ b/airflow/migrations/versions/8504051e801b_xcom_dag_task_indices.py @@ -35,8 +35,7 @@ def upgrade(): """Create Index.""" - op.create_index('idx_xcom_dag_task_date', 'xcom', - ['dag_id', 'task_id', 'execution_date'], unique=False) + op.create_index('idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_id', 'execution_date'], unique=False) def downgrade(): diff --git a/airflow/migrations/versions/852ae6c715af_add_rendered_task_instance_fields_table.py b/airflow/migrations/versions/852ae6c715af_add_rendered_task_instance_fields_table.py index 01357e7c6f706..282286d24e745 100644 --- a/airflow/migrations/versions/852ae6c715af_add_rendered_task_instance_fields_table.py +++ b/airflow/migrations/versions/852ae6c715af_add_rendered_task_instance_fields_table.py @@ -55,10 +55,10 @@ def upgrade(): sa.Column('task_id', sa.String(length=250), nullable=False), sa.Column('execution_date', sa.TIMESTAMP(timezone=True), nullable=False), sa.Column('rendered_fields', json_type(), nullable=False), - sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date') + sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date'), ) def downgrade(): """Drop RenderedTaskInstanceFields table""" - op.drop_table(TABLE_NAME) # pylint: disable=no-member + 
op.drop_table(TABLE_NAME) # pylint: disable=no-member diff --git a/airflow/migrations/versions/856955da8476_fix_sqlite_foreign_key.py b/airflow/migrations/versions/856955da8476_fix_sqlite_foreign_key.py index 9b804ce9c9243..fd8936c4be71e 100644 --- a/airflow/migrations/versions/856955da8476_fix_sqlite_foreign_key.py +++ b/airflow/migrations/versions/856955da8476_fix_sqlite_foreign_key.py @@ -43,29 +43,30 @@ def upgrade(): # which would fail because referenced user table doesn't exist. # # Use batch_alter_table to support SQLite workaround. - chart_table = sa.Table('chart', - sa.MetaData(), - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('label', sa.String(length=200), nullable=True), - sa.Column('conn_id', sa.String(length=250), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=True), - sa.Column('chart_type', sa.String(length=100), nullable=True), - sa.Column('sql_layout', sa.String(length=50), nullable=True), - sa.Column('sql', sa.Text(), nullable=True), - sa.Column('y_log_scale', sa.Boolean(), nullable=True), - sa.Column('show_datatable', sa.Boolean(), nullable=True), - sa.Column('show_sql', sa.Boolean(), nullable=True), - sa.Column('height', sa.Integer(), nullable=True), - sa.Column('default_params', sa.String(length=5000), nullable=True), - sa.Column('x_is_date', sa.Boolean(), nullable=True), - sa.Column('iteration_no', sa.Integer(), nullable=True), - sa.Column('last_modified', sa.DateTime(), nullable=True), - sa.PrimaryKeyConstraint('id')) + chart_table = sa.Table( + 'chart', + sa.MetaData(), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('label', sa.String(length=200), nullable=True), + sa.Column('conn_id', sa.String(length=250), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.Column('chart_type', sa.String(length=100), nullable=True), + sa.Column('sql_layout', sa.String(length=50), nullable=True), + sa.Column('sql', sa.Text(), nullable=True), + sa.Column('y_log_scale', sa.Boolean(), nullable=True), + sa.Column('show_datatable', sa.Boolean(), nullable=True), + sa.Column('show_sql', sa.Boolean(), nullable=True), + sa.Column('height', sa.Integer(), nullable=True), + sa.Column('default_params', sa.String(length=5000), nullable=True), + sa.Column('x_is_date', sa.Boolean(), nullable=True), + sa.Column('iteration_no', sa.Integer(), nullable=True), + sa.Column('last_modified', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id'), + ) with op.batch_alter_table('chart', copy_from=chart_table) as batch_op: - batch_op.create_foreign_key('chart_user_id_fkey', 'users', - ['user_id'], ['id']) + batch_op.create_foreign_key('chart_user_id_fkey', 'users', ['user_id'], ['id']) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 # Downgrade would fail because the broken FK constraint can't be re-created. 
pass diff --git a/airflow/migrations/versions/86770d1215c0_add_kubernetes_scheduler_uniqueness.py b/airflow/migrations/versions/86770d1215c0_add_kubernetes_scheduler_uniqueness.py index b5928fdcd82ea..db3ccdc043bf4 100644 --- a/airflow/migrations/versions/86770d1215c0_add_kubernetes_scheduler_uniqueness.py +++ b/airflow/migrations/versions/86770d1215c0_add_kubernetes_scheduler_uniqueness.py @@ -35,34 +35,25 @@ RESOURCE_TABLE = "kube_worker_uuid" -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 columns_and_constraints = [ sa.Column("one_row_id", sa.Boolean, server_default=sa.true(), primary_key=True), - sa.Column("worker_uuid", sa.String(255)) + sa.Column("worker_uuid", sa.String(255)), ] conn = op.get_bind() # alembic creates an invalid SQL for mssql and mysql dialects if conn.dialect.name in {"mysql"}: - columns_and_constraints.append( - sa.CheckConstraint("one_row_id<>0", name="kube_worker_one_row_id") - ) + columns_and_constraints.append(sa.CheckConstraint("one_row_id<>0", name="kube_worker_one_row_id")) elif conn.dialect.name not in {"mssql"}: - columns_and_constraints.append( - sa.CheckConstraint("one_row_id", name="kube_worker_one_row_id") - ) + columns_and_constraints.append(sa.CheckConstraint("one_row_id", name="kube_worker_one_row_id")) - table = op.create_table( - RESOURCE_TABLE, - *columns_and_constraints - ) + table = op.create_table(RESOURCE_TABLE, *columns_and_constraints) - op.bulk_insert(table, [ - {"worker_uuid": ""} - ]) + op.bulk_insert(table, [{"worker_uuid": ""}]) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_table(RESOURCE_TABLE) diff --git a/airflow/migrations/versions/8d48763f6d53_add_unique_constraint_to_conn_id.py b/airflow/migrations/versions/8d48763f6d53_add_unique_constraint_to_conn_id.py index 34ce9e5aa5679..2c743e4813185 100644 --- a/airflow/migrations/versions/8d48763f6d53_add_unique_constraint_to_conn_id.py +++ b/airflow/migrations/versions/8d48763f6d53_add_unique_constraint_to_conn_id.py @@ -38,16 +38,9 @@ def upgrade(): """Apply add unique constraint to conn_id and set it as non-nullable""" try: with op.batch_alter_table('connection') as batch_op: - batch_op.create_unique_constraint( - constraint_name="unique_conn_id", - columns=["conn_id"] - ) + batch_op.create_unique_constraint(constraint_name="unique_conn_id", columns=["conn_id"]) - batch_op.alter_column( - "conn_id", - nullable=False, - existing_type=sa.String(250) - ) + batch_op.alter_column("conn_id", nullable=False, existing_type=sa.String(250)) except sa.exc.IntegrityError: raise Exception("Make sure there are no duplicate connections with the same conn_id or null values") @@ -55,13 +48,6 @@ def upgrade(): def downgrade(): """Unapply add unique constraint to conn_id and set it as non-nullable""" with op.batch_alter_table('connection') as batch_op: - batch_op.drop_constraint( - constraint_name="unique_conn_id", - type_="unique" - ) + batch_op.drop_constraint(constraint_name="unique_conn_id", type_="unique") - batch_op.alter_column( - "conn_id", - nullable=True, - existing_type=sa.String(250) - ) + batch_op.alter_column("conn_id", nullable=True, existing_type=sa.String(250)) diff --git a/airflow/migrations/versions/939bb1e647c8_task_reschedule_fk_on_cascade_delete.py b/airflow/migrations/versions/939bb1e647c8_task_reschedule_fk_on_cascade_delete.py index 4465ba1ae5b8c..ffb61a39eaaa8 100644 --- a/airflow/migrations/versions/939bb1e647c8_task_reschedule_fk_on_cascade_delete.py +++ b/airflow/migrations/versions/939bb1e647c8_task_reschedule_fk_on_cascade_delete.py @@ 
-32,30 +32,24 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 with op.batch_alter_table('task_reschedule') as batch_op: - batch_op.drop_constraint( - 'task_reschedule_dag_task_date_fkey', - type_='foreignkey' - ) + batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey') batch_op.create_foreign_key( 'task_reschedule_dag_task_date_fkey', 'task_instance', ['task_id', 'dag_id', 'execution_date'], ['task_id', 'dag_id', 'execution_date'], - ondelete='CASCADE' + ondelete='CASCADE', ) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 with op.batch_alter_table('task_reschedule') as batch_op: - batch_op.drop_constraint( - 'task_reschedule_dag_task_date_fkey', - type_='foreignkey' - ) + batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey') batch_op.create_foreign_key( 'task_reschedule_dag_task_date_fkey', 'task_instance', ['task_id', 'dag_id', 'execution_date'], - ['task_id', 'dag_id', 'execution_date'] + ['task_id', 'dag_id', 'execution_date'], ) diff --git a/airflow/migrations/versions/947454bf1dff_add_ti_job_id_index.py b/airflow/migrations/versions/947454bf1dff_add_ti_job_id_index.py index b18224e6a6158..a1d8b8f6099ba 100644 --- a/airflow/migrations/versions/947454bf1dff_add_ti_job_id_index.py +++ b/airflow/migrations/versions/947454bf1dff_add_ti_job_id_index.py @@ -32,9 +32,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.create_index('ti_job_id', 'task_instance', ['job_id'], unique=False) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_index('ti_job_id', table_name='task_instance') diff --git a/airflow/migrations/versions/952da73b5eff_add_dag_code_table.py b/airflow/migrations/versions/952da73b5eff_add_dag_code_table.py index a937b53055918..63fb68999e6ec 100644 --- a/airflow/migrations/versions/952da73b5eff_add_dag_code_table.py +++ b/airflow/migrations/versions/952da73b5eff_add_dag_code_table.py @@ -51,20 +51,22 @@ class SerializedDagModel(Base): fileloc_hash = sa.Column(sa.BigInteger, nullable=False) """Apply add source code table""" - op.create_table('dag_code', # pylint: disable=no-member - sa.Column('fileloc_hash', sa.BigInteger(), - nullable=False, primary_key=True, autoincrement=False), - sa.Column('fileloc', sa.String(length=2000), nullable=False), - sa.Column('source_code', sa.UnicodeText(), nullable=False), - sa.Column('last_updated', sa.TIMESTAMP(timezone=True), nullable=False)) + op.create_table( + 'dag_code', # pylint: disable=no-member + sa.Column('fileloc_hash', sa.BigInteger(), nullable=False, primary_key=True, autoincrement=False), + sa.Column('fileloc', sa.String(length=2000), nullable=False), + sa.Column('source_code', sa.UnicodeText(), nullable=False), + sa.Column('last_updated', sa.TIMESTAMP(timezone=True), nullable=False), + ) conn = op.get_bind() if conn.dialect.name != 'sqlite': if conn.dialect.name == "mssql": op.drop_index('idx_fileloc_hash', 'serialized_dag') - op.alter_column(table_name='serialized_dag', column_name='fileloc_hash', - type_=sa.BigInteger(), nullable=False) + op.alter_column( + table_name='serialized_dag', column_name='fileloc_hash', type_=sa.BigInteger(), nullable=False + ) if conn.dialect.name == "mssql": op.create_index('idx_fileloc_hash', 'serialized_dag', ['fileloc_hash']) diff --git a/airflow/migrations/versions/9635ae0956e7_index_faskfail.py b/airflow/migrations/versions/9635ae0956e7_index_faskfail.py index 9508170dcf0a8..c924b3a06af2a 100644 --- 
a/airflow/migrations/versions/9635ae0956e7_index_faskfail.py +++ b/airflow/migrations/versions/9635ae0956e7_index_faskfail.py @@ -31,11 +31,11 @@ depends_on = None -def upgrade(): # noqa: D103 - op.create_index('idx_task_fail_dag_task_date', - 'task_fail', - ['dag_id', 'task_id', 'execution_date'], unique=False) +def upgrade(): # noqa: D103 + op.create_index( + 'idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'], unique=False + ) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_index('idx_task_fail_dag_task_date', table_name='task_fail') diff --git a/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py b/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py index a1ed2344f6b42..ac14cfe3445cd 100644 --- a/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py +++ b/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py @@ -58,6 +58,7 @@ def upgrade(): try: from airflow.configuration import conf + concurrency = conf.getint('core', 'dag_concurrency', fallback=16) except: # noqa concurrency = 16 diff --git a/airflow/migrations/versions/a4c2fd67d16b_add_pool_slots_field_to_task_instance.py b/airflow/migrations/versions/a4c2fd67d16b_add_pool_slots_field_to_task_instance.py index 1b742e4245649..121c7fa5e4fef 100644 --- a/airflow/migrations/versions/a4c2fd67d16b_add_pool_slots_field_to_task_instance.py +++ b/airflow/migrations/versions/a4c2fd67d16b_add_pool_slots_field_to_task_instance.py @@ -34,9 +34,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.add_column('task_instance', sa.Column('pool_slots', sa.Integer, default=1)) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_column('task_instance', 'pool_slots') diff --git a/airflow/migrations/versions/a56c9515abdc_remove_dag_stat_table.py b/airflow/migrations/versions/a56c9515abdc_remove_dag_stat_table.py index 05adcae8b93ff..a9ea301e2737f 100644 --- a/airflow/migrations/versions/a56c9515abdc_remove_dag_stat_table.py +++ b/airflow/migrations/versions/a56c9515abdc_remove_dag_stat_table.py @@ -41,9 +41,11 @@ def upgrade(): def downgrade(): """Create dag_stats table""" - op.create_table('dag_stats', - sa.Column('dag_id', sa.String(length=250), nullable=False), - sa.Column('state', sa.String(length=50), nullable=False), - sa.Column('count', sa.Integer(), nullable=False, default=0), - sa.Column('dirty', sa.Boolean(), nullable=False, default=False), - sa.PrimaryKeyConstraint('dag_id', 'state')) + op.create_table( + 'dag_stats', + sa.Column('dag_id', sa.String(length=250), nullable=False), + sa.Column('state', sa.String(length=50), nullable=False), + sa.Column('count', sa.Integer(), nullable=False, default=0), + sa.Column('dirty', sa.Boolean(), nullable=False, default=False), + sa.PrimaryKeyConstraint('dag_id', 'state'), + ) diff --git a/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py b/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py index 7c122ef3df1c6..29cce5feee310 100644 --- a/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py +++ b/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py @@ -42,10 +42,7 @@ def upgrade(): conn = op.get_bind() if conn.dialect.name == "mysql": op.alter_column( - table_name=TABLE_NAME, - column_name=COLUMN_NAME, - type_=mysql.TIMESTAMP(fsp=6), - nullable=False + 
table_name=TABLE_NAME, column_name=COLUMN_NAME, type_=mysql.TIMESTAMP(fsp=6), nullable=False ) @@ -54,8 +51,5 @@ def downgrade(): conn = op.get_bind() if conn.dialect.name == "mysql": op.alter_column( - table_name=TABLE_NAME, - column_name=COLUMN_NAME, - type_=mysql.TIMESTAMP(), - nullable=False + table_name=TABLE_NAME, column_name=COLUMN_NAME, type_=mysql.TIMESTAMP(), nullable=False ) diff --git a/airflow/migrations/versions/b0125267960b_merge_heads.py b/airflow/migrations/versions/b0125267960b_merge_heads.py index 007b99a2a3921..5c05dd78d3711 100644 --- a/airflow/migrations/versions/b0125267960b_merge_heads.py +++ b/airflow/migrations/versions/b0125267960b_merge_heads.py @@ -31,9 +31,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 pass -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 pass diff --git a/airflow/migrations/versions/bba5a7cfc896_add_a_column_to_track_the_encryption_.py b/airflow/migrations/versions/bba5a7cfc896_add_a_column_to_track_the_encryption_.py index b03b42c61d7cf..4b2cacd90f775 100644 --- a/airflow/migrations/versions/bba5a7cfc896_add_a_column_to_track_the_encryption_.py +++ b/airflow/migrations/versions/bba5a7cfc896_add_a_column_to_track_the_encryption_.py @@ -34,10 +34,9 @@ depends_on = None -def upgrade(): # noqa: D103 - op.add_column('connection', - sa.Column('is_extra_encrypted', sa.Boolean, default=False)) +def upgrade(): # noqa: D103 + op.add_column('connection', sa.Column('is_extra_encrypted', sa.Boolean, default=False)) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_column('connection', 'is_extra_encrypted') diff --git a/airflow/migrations/versions/bbc73705a13e_add_notification_sent_column_to_sla_miss.py b/airflow/migrations/versions/bbc73705a13e_add_notification_sent_column_to_sla_miss.py index f5d94f208ace2..2e73d05890950 100644 --- a/airflow/migrations/versions/bbc73705a13e_add_notification_sent_column_to_sla_miss.py +++ b/airflow/migrations/versions/bbc73705a13e_add_notification_sent_column_to_sla_miss.py @@ -33,9 +33,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.add_column('sla_miss', sa.Column('notification_sent', sa.Boolean, default=False)) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_column('sla_miss', 'notification_sent') diff --git a/airflow/migrations/versions/bdaa763e6c56_make_xcom_value_column_a_large_binary.py b/airflow/migrations/versions/bdaa763e6c56_make_xcom_value_column_a_large_binary.py index 1bb0cdcd21dc9..cd4fa0d51c04a 100644 --- a/airflow/migrations/versions/bdaa763e6c56_make_xcom_value_column_a_large_binary.py +++ b/airflow/migrations/versions/bdaa763e6c56_make_xcom_value_column_a_large_binary.py @@ -34,7 +34,7 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 # There can be data truncation here as LargeBinary can be smaller than the pickle # type. 
# use batch_alter_table to support SQLite workaround @@ -42,7 +42,7 @@ def upgrade(): # noqa: D103 batch_op.alter_column('value', type_=sa.LargeBinary()) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 # use batch_alter_table to support SQLite workaround with op.batch_alter_table("xcom") as batch_op: batch_op.alter_column('value', type_=sa.PickleType(pickler=dill)) diff --git a/airflow/migrations/versions/bef4f3d11e8b_drop_kuberesourceversion_and_.py b/airflow/migrations/versions/bef4f3d11e8b_drop_kuberesourceversion_and_.py index daa00d611ccd5..d58d959ba1e16 100644 --- a/airflow/migrations/versions/bef4f3d11e8b_drop_kuberesourceversion_and_.py +++ b/airflow/migrations/versions/bef4f3d11e8b_drop_kuberesourceversion_and_.py @@ -53,35 +53,26 @@ def downgrade(): def _add_worker_uuid_table(): columns_and_constraints = [ sa.Column("one_row_id", sa.Boolean, server_default=sa.true(), primary_key=True), - sa.Column("worker_uuid", sa.String(255)) + sa.Column("worker_uuid", sa.String(255)), ] conn = op.get_bind() # alembic creates an invalid SQL for mssql and mysql dialects if conn.dialect.name in {"mysql"}: - columns_and_constraints.append( - sa.CheckConstraint("one_row_id<>0", name="kube_worker_one_row_id") - ) + columns_and_constraints.append(sa.CheckConstraint("one_row_id<>0", name="kube_worker_one_row_id")) elif conn.dialect.name not in {"mssql"}: - columns_and_constraints.append( - sa.CheckConstraint("one_row_id", name="kube_worker_one_row_id") - ) + columns_and_constraints.append(sa.CheckConstraint("one_row_id", name="kube_worker_one_row_id")) - table = op.create_table( - WORKER_RESOURCEVERSION_TABLE, - *columns_and_constraints - ) + table = op.create_table(WORKER_RESOURCEVERSION_TABLE, *columns_and_constraints) - op.bulk_insert(table, [ - {"worker_uuid": ""} - ]) + op.bulk_insert(table, [{"worker_uuid": ""}]) def _add_resource_table(): columns_and_constraints = [ sa.Column("one_row_id", sa.Boolean, server_default=sa.true(), primary_key=True), - sa.Column("resource_version", sa.String(255)) + sa.Column("resource_version", sa.String(255)), ] conn = op.get_bind() @@ -96,11 +87,6 @@ def _add_resource_table(): sa.CheckConstraint("one_row_id", name="kube_resource_version_one_row_id") ) - table = op.create_table( - WORKER_RESOURCEVERSION_TABLE, - *columns_and_constraints - ) + table = op.create_table(WORKER_RESOURCEVERSION_TABLE, *columns_and_constraints) - op.bulk_insert(table, [ - {"resource_version": ""} - ]) + op.bulk_insert(table, [{"resource_version": ""}]) diff --git a/airflow/migrations/versions/bf00311e1990_add_index_to_taskinstance.py b/airflow/migrations/versions/bf00311e1990_add_index_to_taskinstance.py index d03868241ed36..845ce35a026c1 100644 --- a/airflow/migrations/versions/bf00311e1990_add_index_to_taskinstance.py +++ b/airflow/migrations/versions/bf00311e1990_add_index_to_taskinstance.py @@ -33,14 +33,9 @@ depends_on = None -def upgrade(): # noqa: D103 - op.create_index( - 'ti_dag_date', - 'task_instance', - ['dag_id', 'execution_date'], - unique=False - ) +def upgrade(): # noqa: D103 + op.create_index('ti_dag_date', 'task_instance', ['dag_id', 'execution_date'], unique=False) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_index('ti_dag_date', table_name='task_instance') diff --git a/airflow/migrations/versions/c8ffec048a3b_add_fields_to_dag.py b/airflow/migrations/versions/c8ffec048a3b_add_fields_to_dag.py index bd0453a07dd37..c620286f7f15e 100644 --- a/airflow/migrations/versions/c8ffec048a3b_add_fields_to_dag.py +++ 
b/airflow/migrations/versions/c8ffec048a3b_add_fields_to_dag.py @@ -34,11 +34,11 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.add_column('dag', sa.Column('description', sa.Text(), nullable=True)) op.add_column('dag', sa.Column('default_view', sa.String(25), nullable=True)) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_column('dag', 'description') op.drop_column('dag', 'default_view') diff --git a/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py b/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py index 46f9178781b6b..e6169b2728f28 100644 --- a/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py +++ b/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py @@ -31,6 +31,7 @@ from airflow import settings from airflow.models import DagBag + # revision identifiers, used by Alembic. from airflow.models.base import COLLATION_ARGS @@ -54,7 +55,7 @@ class TaskInstance(Base): # noqa: D101 # type: ignore try_number = Column(Integer, default=0) -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.add_column('task_instance', sa.Column('max_tries', sa.Integer, server_default="-1")) # Check if table task_instance exist before data migration. This check is # needed for database that does not create table until migration finishes. @@ -69,15 +70,11 @@ def upgrade(): # noqa: D103 sessionmaker = sa.orm.sessionmaker() session = sessionmaker(bind=connection) dagbag = DagBag(settings.DAGS_FOLDER) - query = session.query(sa.func.count(TaskInstance.max_tries)).filter( - TaskInstance.max_tries == -1 - ) + query = session.query(sa.func.count(TaskInstance.max_tries)).filter(TaskInstance.max_tries == -1) # Separate db query in batch to prevent loading entire table # into memory and cause out of memory error. while query.scalar(): - tis = session.query(TaskInstance).filter( - TaskInstance.max_tries == -1 - ).limit(BATCH_SIZE).all() + tis = session.query(TaskInstance).filter(TaskInstance.max_tries == -1).limit(BATCH_SIZE).all() for ti in tis: dag = dagbag.get_dag(ti.dag_id) if not dag or not dag.has_task(ti.task_id): @@ -100,20 +97,16 @@ def upgrade(): # noqa: D103 session.commit() -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 engine = settings.engine if engine.dialect.has_table(engine, 'task_instance'): connection = op.get_bind() sessionmaker = sa.orm.sessionmaker() session = sessionmaker(bind=connection) dagbag = DagBag(settings.DAGS_FOLDER) - query = session.query(sa.func.count(TaskInstance.max_tries)).filter( - TaskInstance.max_tries != -1 - ) + query = session.query(sa.func.count(TaskInstance.max_tries)).filter(TaskInstance.max_tries != -1) while query.scalar(): - tis = session.query(TaskInstance).filter( - TaskInstance.max_tries != -1 - ).limit(BATCH_SIZE).all() + tis = session.query(TaskInstance).filter(TaskInstance.max_tries != -1).limit(BATCH_SIZE).all() for ti in tis: dag = dagbag.get_dag(ti.dag_id) if not dag or not dag.has_task(ti.task_id): @@ -124,8 +117,7 @@ def downgrade(): # noqa: D103 # left to retry by itself. So the current try_number should be # max number of self retry (task.retries) minus number of # times left for task instance to try the task. 
- ti.try_number = max(0, task.retries - (ti.max_tries - - ti.try_number)) + ti.try_number = max(0, task.retries - (ti.max_tries - ti.try_number)) ti.max_tries = -1 session.merge(ti) session.commit() diff --git a/airflow/migrations/versions/cf5dc11e79ad_drop_user_and_chart.py b/airflow/migrations/versions/cf5dc11e79ad_drop_user_and_chart.py index 86e976983a154..35b07a4c9e9df 100644 --- a/airflow/migrations/versions/cf5dc11e79ad_drop_user_and_chart.py +++ b/airflow/migrations/versions/cf5dc11e79ad_drop_user_and_chart.py @@ -34,7 +34,7 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 # We previously had a KnownEvent's table, but we deleted the table without # a down migration to remove it (so we didn't delete anyone's data if they # were happing to use the feature. @@ -49,13 +49,15 @@ def upgrade(): # noqa: D103 op.drop_constraint('known_event_user_id_fkey', 'known_event') if "chart" in tables: - op.drop_table("chart", ) + op.drop_table( + "chart", + ) if "users" in tables: op.drop_table("users") -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 conn = op.get_bind() op.create_table( @@ -66,7 +68,7 @@ def downgrade(): # noqa: D103 sa.Column('password', sa.String(255)), sa.Column('superuser', sa.Boolean(), default=False), sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('username') + sa.UniqueConstraint('username'), ) op.create_table( @@ -86,8 +88,11 @@ def downgrade(): # noqa: D103 sa.Column('x_is_date', sa.Boolean(), nullable=True), sa.Column('iteration_no', sa.Integer(), nullable=True), sa.Column('last_modified', sa.DateTime(), nullable=True), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + sa.ForeignKeyConstraint( + ['user_id'], + ['users.id'], + ), + sa.PrimaryKeyConstraint('id'), ) if conn.dialect.name == 'mysql': diff --git a/airflow/migrations/versions/d2ae31099d61_increase_text_size_for_mysql.py b/airflow/migrations/versions/d2ae31099d61_increase_text_size_for_mysql.py index bb6f2a54f902b..e4f81569c4d61 100644 --- a/airflow/migrations/versions/d2ae31099d61_increase_text_size_for_mysql.py +++ b/airflow/migrations/versions/d2ae31099d61_increase_text_size_for_mysql.py @@ -33,11 +33,11 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 if context.config.get_main_option('sqlalchemy.url').startswith('mysql'): op.alter_column(table_name='variable', column_name='val', type_=mysql.MEDIUMTEXT) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 if context.config.get_main_option('sqlalchemy.url').startswith('mysql'): op.alter_column(table_name='variable', column_name='val', type_=mysql.TEXT) diff --git a/airflow/migrations/versions/d38e04c12aa2_add_serialized_dag_table.py b/airflow/migrations/versions/d38e04c12aa2_add_serialized_dag_table.py index 2dd35a901e702..2a446e6a169f8 100644 --- a/airflow/migrations/versions/d38e04c12aa2_add_serialized_dag_table.py +++ b/airflow/migrations/versions/d38e04c12aa2_add_serialized_dag_table.py @@ -47,24 +47,23 @@ def upgrade(): except (sa.exc.OperationalError, sa.exc.ProgrammingError): json_type = sa.Text - op.create_table('serialized_dag', # pylint: disable=no-member - sa.Column('dag_id', sa.String(length=250), nullable=False), - sa.Column('fileloc', sa.String(length=2000), nullable=False), - sa.Column('fileloc_hash', sa.Integer(), nullable=False), - sa.Column('data', json_type(), nullable=False), - sa.Column('last_updated', sa.DateTime(), nullable=False), - sa.PrimaryKeyConstraint('dag_id')) - op.create_index( # pylint: 
disable=no-member - 'idx_fileloc_hash', 'serialized_dag', ['fileloc_hash']) + op.create_table( + 'serialized_dag', # pylint: disable=no-member + sa.Column('dag_id', sa.String(length=250), nullable=False), + sa.Column('fileloc', sa.String(length=2000), nullable=False), + sa.Column('fileloc_hash', sa.Integer(), nullable=False), + sa.Column('data', json_type(), nullable=False), + sa.Column('last_updated', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('dag_id'), + ) + op.create_index('idx_fileloc_hash', 'serialized_dag', ['fileloc_hash']) # pylint: disable=no-member if conn.dialect.name == "mysql": conn.execute("SET time_zone = '+00:00'") cur = conn.execute("SELECT @@explicit_defaults_for_timestamp") res = cur.fetchall() if res[0][0] == 0: - raise Exception( - "Global variable explicit_defaults_for_timestamp needs to be on (1) for mysql" - ) + raise Exception("Global variable explicit_defaults_for_timestamp needs to be on (1) for mysql") op.alter_column( # pylint: disable=no-member table_name="serialized_dag", @@ -91,4 +90,4 @@ def upgrade(): def downgrade(): """Downgrade version.""" - op.drop_table('serialized_dag') # pylint: disable=no-member + op.drop_table('serialized_dag') # pylint: disable=no-member diff --git a/airflow/migrations/versions/da3f683c3a5a_add_dag_hash_column_to_serialized_dag_.py b/airflow/migrations/versions/da3f683c3a5a_add_dag_hash_column_to_serialized_dag_.py index 4dbc77a4c5aa4..e757828731d76 100644 --- a/airflow/migrations/versions/da3f683c3a5a_add_dag_hash_column_to_serialized_dag_.py +++ b/airflow/migrations/versions/da3f683c3a5a_add_dag_hash_column_to_serialized_dag_.py @@ -38,7 +38,8 @@ def upgrade(): """Apply Add dag_hash Column to serialized_dag table""" op.add_column( 'serialized_dag', - sa.Column('dag_hash', sa.String(32), nullable=False, server_default='Hash not calculated yet')) + sa.Column('dag_hash', sa.String(32), nullable=False, server_default='Hash not calculated yet'), + ) def downgrade(): diff --git a/airflow/migrations/versions/dd25f486b8ea_add_idx_log_dag.py b/airflow/migrations/versions/dd25f486b8ea_add_idx_log_dag.py index c3530945d8269..220535ac7c5e8 100644 --- a/airflow/migrations/versions/dd25f486b8ea_add_idx_log_dag.py +++ b/airflow/migrations/versions/dd25f486b8ea_add_idx_log_dag.py @@ -31,9 +31,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.create_index('idx_log_dag', 'log', ['dag_id'], unique=False) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_index('idx_log_dag', table_name='log') diff --git a/airflow/migrations/versions/dd4ecb8fbee3_add_schedule_interval_to_dag.py b/airflow/migrations/versions/dd4ecb8fbee3_add_schedule_interval_to_dag.py index ed335d2e3508b..b5fdc29b9a13d 100644 --- a/airflow/migrations/versions/dd4ecb8fbee3_add_schedule_interval_to_dag.py +++ b/airflow/migrations/versions/dd4ecb8fbee3_add_schedule_interval_to_dag.py @@ -34,9 +34,9 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 op.add_column('dag', sa.Column('schedule_interval', sa.Text(), nullable=True)) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_column('dag', 'schedule_interval') diff --git a/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py b/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py index 27227aee9377d..ed73a7e6ab92b 100644 --- a/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py +++ 
b/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py @@ -35,15 +35,15 @@ depends_on = None -def mssql_timestamp(): # noqa: D103 +def mssql_timestamp(): # noqa: D103 return sa.DateTime() -def mysql_timestamp(): # noqa: D103 +def mysql_timestamp(): # noqa: D103 return mysql.TIMESTAMP(fsp=6) -def sa_timestamp(): # noqa: D103 +def sa_timestamp(): # noqa: D103 return sa.TIMESTAMP(timezone=True) @@ -74,14 +74,9 @@ def upgrade(): # noqa: D103 sa.Column('execution_context', sa.Text(), nullable=True), sa.Column('created_at', timestamp(), default=func.now(), nullable=False), sa.Column('updated_at', timestamp(), default=func.now(), nullable=False), - sa.PrimaryKeyConstraint('id') - ) - op.create_index( - 'ti_primary_key', - 'sensor_instance', - ['dag_id', 'task_id', 'execution_date'], - unique=True + sa.PrimaryKeyConstraint('id'), ) + op.create_index('ti_primary_key', 'sensor_instance', ['dag_id', 'task_id', 'execution_date'], unique=True) op.create_index('si_hashcode', 'sensor_instance', ['hashcode'], unique=False) op.create_index('si_shardcode', 'sensor_instance', ['shardcode'], unique=False) op.create_index('si_state_shard', 'sensor_instance', ['state', 'shardcode'], unique=False) diff --git a/airflow/migrations/versions/e3a246e0dc1_current_schema.py b/airflow/migrations/versions/e3a246e0dc1_current_schema.py index dfa4f5818c69f..60e6cdf3b3c29 100644 --- a/airflow/migrations/versions/e3a246e0dc1_current_schema.py +++ b/airflow/migrations/versions/e3a246e0dc1_current_schema.py @@ -38,7 +38,7 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 conn = op.get_bind() inspector = Inspector.from_engine(conn) tables = inspector.get_table_names() @@ -55,7 +55,7 @@ def upgrade(): # noqa: D103 sa.Column('password', sa.String(length=500), nullable=True), sa.Column('port', sa.Integer(), nullable=True), sa.Column('extra', sa.String(length=5000), nullable=True), - sa.PrimaryKeyConstraint('id') + sa.PrimaryKeyConstraint('id'), ) if 'dag' not in tables: op.create_table( @@ -71,7 +71,7 @@ def upgrade(): # noqa: D103 sa.Column('pickle_id', sa.Integer(), nullable=True), sa.Column('fileloc', sa.String(length=2000), nullable=True), sa.Column('owners', sa.String(length=2000), nullable=True), - sa.PrimaryKeyConstraint('dag_id') + sa.PrimaryKeyConstraint('dag_id'), ) if 'dag_pickle' not in tables: op.create_table( @@ -80,7 +80,7 @@ def upgrade(): # noqa: D103 sa.Column('pickle', sa.PickleType(), nullable=True), sa.Column('created_dttm', sa.DateTime(), nullable=True), sa.Column('pickle_hash', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + sa.PrimaryKeyConstraint('id'), ) if 'import_error' not in tables: op.create_table( @@ -89,7 +89,7 @@ def upgrade(): # noqa: D103 sa.Column('timestamp', sa.DateTime(), nullable=True), sa.Column('filename', sa.String(length=1024), nullable=True), sa.Column('stacktrace', sa.Text(), nullable=True), - sa.PrimaryKeyConstraint('id') + sa.PrimaryKeyConstraint('id'), ) if 'job' not in tables: op.create_table( @@ -104,14 +104,9 @@ def upgrade(): # noqa: D103 sa.Column('executor_class', sa.String(length=500), nullable=True), sa.Column('hostname', sa.String(length=500), nullable=True), sa.Column('unixname', sa.String(length=1000), nullable=True), - sa.PrimaryKeyConstraint('id') - ) - op.create_index( - 'job_type_heart', - 'job', - ['job_type', 'latest_heartbeat'], - unique=False + sa.PrimaryKeyConstraint('id'), ) + op.create_index('job_type_heart', 'job', ['job_type', 'latest_heartbeat'], unique=False) if 'log' not in tables: 
op.create_table( 'log', @@ -122,7 +117,7 @@ def upgrade(): # noqa: D103 sa.Column('event', sa.String(length=30), nullable=True), sa.Column('execution_date', sa.DateTime(), nullable=True), sa.Column('owner', sa.String(length=500), nullable=True), - sa.PrimaryKeyConstraint('id') + sa.PrimaryKeyConstraint('id'), ) if 'sla_miss' not in tables: op.create_table( @@ -133,7 +128,7 @@ def upgrade(): # noqa: D103 sa.Column('email_sent', sa.Boolean(), nullable=True), sa.Column('timestamp', sa.DateTime(), nullable=True), sa.Column('description', sa.Text(), nullable=True), - sa.PrimaryKeyConstraint('task_id', 'dag_id', 'execution_date') + sa.PrimaryKeyConstraint('task_id', 'dag_id', 'execution_date'), ) if 'slot_pool' not in tables: op.create_table( @@ -143,7 +138,7 @@ def upgrade(): # noqa: D103 sa.Column('slots', sa.Integer(), nullable=True), sa.Column('description', sa.Text(), nullable=True), sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('pool') + sa.UniqueConstraint('pool'), ) if 'task_instance' not in tables: op.create_table( @@ -162,25 +157,12 @@ def upgrade(): # noqa: D103 sa.Column('pool', sa.String(length=50), nullable=True), sa.Column('queue', sa.String(length=50), nullable=True), sa.Column('priority_weight', sa.Integer(), nullable=True), - sa.PrimaryKeyConstraint('task_id', 'dag_id', 'execution_date') + sa.PrimaryKeyConstraint('task_id', 'dag_id', 'execution_date'), ) + op.create_index('ti_dag_state', 'task_instance', ['dag_id', 'state'], unique=False) + op.create_index('ti_pool', 'task_instance', ['pool', 'state', 'priority_weight'], unique=False) op.create_index( - 'ti_dag_state', - 'task_instance', - ['dag_id', 'state'], - unique=False - ) - op.create_index( - 'ti_pool', - 'task_instance', - ['pool', 'state', 'priority_weight'], - unique=False - ) - op.create_index( - 'ti_state_lkp', - 'task_instance', - ['dag_id', 'task_id', 'execution_date', 'state'], - unique=False + 'ti_state_lkp', 'task_instance', ['dag_id', 'task_id', 'execution_date', 'state'], unique=False ) if 'user' not in tables: @@ -190,7 +172,7 @@ def upgrade(): # noqa: D103 sa.Column('username', sa.String(length=250), nullable=True), sa.Column('email', sa.String(length=500), nullable=True), sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('username') + sa.UniqueConstraint('username'), ) if 'variable' not in tables: op.create_table( @@ -199,7 +181,7 @@ def upgrade(): # noqa: D103 sa.Column('key', sa.String(length=250), nullable=True), sa.Column('val', sa.Text(), nullable=True), sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('key') + sa.UniqueConstraint('key'), ) if 'chart' not in tables: op.create_table( @@ -219,8 +201,11 @@ def upgrade(): # noqa: D103 sa.Column('x_is_date', sa.Boolean(), nullable=True), sa.Column('iteration_no', sa.Integer(), nullable=True), sa.Column('last_modified', sa.DateTime(), nullable=True), - sa.ForeignKeyConstraint(['user_id'], ['user.id'], ), - sa.PrimaryKeyConstraint('id') + sa.ForeignKeyConstraint( + ['user_id'], + ['user.id'], + ), + sa.PrimaryKeyConstraint('id'), ) if 'xcom' not in tables: op.create_table( @@ -228,19 +213,15 @@ def upgrade(): # noqa: D103 sa.Column('id', sa.Integer(), nullable=False), sa.Column('key', sa.String(length=512, **COLLATION_ARGS), nullable=True), sa.Column('value', sa.PickleType(), nullable=True), - sa.Column( - 'timestamp', - sa.DateTime(), - default=func.now(), - nullable=False), + sa.Column('timestamp', sa.DateTime(), default=func.now(), nullable=False), sa.Column('execution_date', sa.DateTime(), nullable=False), sa.Column('task_id', 
sa.String(length=250, **COLLATION_ARGS), nullable=False), sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), - sa.PrimaryKeyConstraint('id') + sa.PrimaryKeyConstraint('id'), ) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 op.drop_table('chart') op.drop_table('variable') op.drop_table('user') diff --git a/airflow/migrations/versions/f23433877c24_fix_mysql_not_null_constraint.py b/airflow/migrations/versions/f23433877c24_fix_mysql_not_null_constraint.py index f9725d38e4118..7a0a3c8cc1706 100644 --- a/airflow/migrations/versions/f23433877c24_fix_mysql_not_null_constraint.py +++ b/airflow/migrations/versions/f23433877c24_fix_mysql_not_null_constraint.py @@ -32,7 +32,7 @@ depends_on = None -def upgrade(): # noqa: D103 +def upgrade(): # noqa: D103 conn = op.get_bind() if conn.dialect.name == 'mysql': conn.execute("SET time_zone = '+00:00'") @@ -41,7 +41,7 @@ def upgrade(): # noqa: D103 op.alter_column('xcom', 'timestamp', existing_type=mysql.TIMESTAMP(fsp=6), nullable=False) -def downgrade(): # noqa: D103 +def downgrade(): # noqa: D103 conn = op.get_bind() if conn.dialect.name == 'mysql': conn.execute("SET time_zone = '+00:00'") diff --git a/airflow/migrations/versions/f2ca10b85618_add_dag_stats_table.py b/airflow/migrations/versions/f2ca10b85618_add_dag_stats_table.py index 7ab2fe943726f..1db0440cbd4a4 100644 --- a/airflow/migrations/versions/f2ca10b85618_add_dag_stats_table.py +++ b/airflow/migrations/versions/f2ca10b85618_add_dag_stats_table.py @@ -33,14 +33,16 @@ depends_on = None -def upgrade(): # noqa: D103 - op.create_table('dag_stats', - sa.Column('dag_id', sa.String(length=250), nullable=False), - sa.Column('state', sa.String(length=50), nullable=False), - sa.Column('count', sa.Integer(), nullable=False, default=0), - sa.Column('dirty', sa.Boolean(), nullable=False, default=False), - sa.PrimaryKeyConstraint('dag_id', 'state')) - - -def downgrade(): # noqa: D103 +def upgrade(): # noqa: D103 + op.create_table( + 'dag_stats', + sa.Column('dag_id', sa.String(length=250), nullable=False), + sa.Column('state', sa.String(length=50), nullable=False), + sa.Column('count', sa.Integer(), nullable=False, default=0), + sa.Column('dirty', sa.Boolean(), nullable=False, default=False), + sa.PrimaryKeyConstraint('dag_id', 'state'), + ) + + +def downgrade(): # noqa: D103 op.drop_table('dag_stats') diff --git a/airflow/migrations/versions/fe461863935f_increase_length_for_connection_password.py b/airflow/migrations/versions/fe461863935f_increase_length_for_connection_password.py index 4b14e8824dcdc..0e97630248a9b 100644 --- a/airflow/migrations/versions/fe461863935f_increase_length_for_connection_password.py +++ b/airflow/migrations/versions/fe461863935f_increase_length_for_connection_password.py @@ -37,16 +37,20 @@ def upgrade(): """Apply increase_length_for_connection_password""" with op.batch_alter_table('connection', schema=None) as batch_op: - batch_op.alter_column('password', - existing_type=sa.VARCHAR(length=500), - type_=sa.String(length=5000), - existing_nullable=True) + batch_op.alter_column( + 'password', + existing_type=sa.VARCHAR(length=500), + type_=sa.String(length=5000), + existing_nullable=True, + ) def downgrade(): """Unapply increase_length_for_connection_password""" with op.batch_alter_table('connection', schema=None) as batch_op: - batch_op.alter_column('password', - existing_type=sa.String(length=5000), - type_=sa.VARCHAR(length=500), - existing_nullable=True) + batch_op.alter_column( + 'password', + existing_type=sa.String(length=5000), + 
type_=sa.VARCHAR(length=500), + existing_nullable=True, + ) diff --git a/airflow/models/base.py b/airflow/models/base.py index dc1953439ecf1..55584e4f2a27b 100644 --- a/airflow/models/base.py +++ b/airflow/models/base.py @@ -26,9 +26,7 @@ SQL_ALCHEMY_SCHEMA = conf.get("core", "SQL_ALCHEMY_SCHEMA") metadata = ( - None - if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace() - else MetaData(schema=SQL_ALCHEMY_SCHEMA) + None if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace() else MetaData(schema=SQL_ALCHEMY_SCHEMA) ) Base = declarative_base(metadata=metadata) # type: Any diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 5e61ad40f1b36..fa4a840fe81e8 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -25,8 +25,20 @@ from abc import ABCMeta, abstractmethod from datetime import datetime, timedelta from typing import ( - TYPE_CHECKING, Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Tuple, - Type, Union, + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, ) import attr @@ -282,8 +294,12 @@ class derived from this one results in the creation of a task object, pool = "" # type: str # base list which includes all the attrs that don't need deep copy. - _base_operator_shallow_copy_attrs: Tuple[str, ...] = \ - ('user_defined_macros', 'user_defined_filters', 'params', '_log',) + _base_operator_shallow_copy_attrs: Tuple[str, ...] = ( + 'user_defined_macros', + 'user_defined_filters', + 'params', + '_log', + ) # each operator should override this class attr for shallow copy attrs. shallow_copy_attrs: Tuple[str, ...] = () @@ -365,7 +381,7 @@ def __init__( inlets: Optional[Any] = None, outlets: Optional[Any] = None, task_group: Optional["TaskGroup"] = None, - **kwargs + **kwargs, ): from airflow.models.dag import DagContext from airflow.utils.task_group import TaskGroupContext @@ -375,17 +391,15 @@ def __init__( if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'): raise AirflowException( "Invalid arguments were passed to {c} (task_id: {t}). Invalid " - "arguments were:\n**kwargs: {k}".format( - c=self.__class__.__name__, k=kwargs, t=task_id), + "arguments were:\n**kwargs: {k}".format(c=self.__class__.__name__, k=kwargs, t=task_id), ) warnings.warn( 'Invalid arguments were passed to {c} (task_id: {t}). ' 'Support for passing such arguments will be dropped in ' 'future. Invalid arguments were:' - '\n**kwargs: {k}'.format( - c=self.__class__.__name__, k=kwargs, t=task_id), + '\n**kwargs: {k}'.format(c=self.__class__.__name__, k=kwargs, t=task_id), category=PendingDeprecationWarning, - stacklevel=3 + stacklevel=3, ) validate_key(task_id) self.task_id = task_id @@ -412,9 +426,13 @@ def __init__( if not TriggerRule.is_valid(trigger_rule): raise AirflowException( "The trigger_rule must be one of {all_triggers}," - "'{d}.{t}'; received '{tr}'." 
- .format(all_triggers=TriggerRule.all_triggers(), - d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule)) + "'{d}.{t}'; received '{tr}'.".format( + all_triggers=TriggerRule.all_triggers(), + d=dag.dag_id if dag else "", + t=task_id, + tr=trigger_rule, + ) + ) self.trigger_rule = trigger_rule self.depends_on_past = depends_on_past @@ -427,8 +445,7 @@ def __init__( self.pool = Pool.DEFAULT_POOL_NAME if pool is None else pool self.pool_slots = pool_slots if self.pool_slots < 1: - raise AirflowException("pool slots for %s in dag %s cannot be less than 1" - % (self.task_id, dag.dag_id)) + raise AirflowException(f"pool slots for {self.task_id} in dag {dag.dag_id} cannot be less than 1") self.sla = sla self.execution_timeout = execution_timeout self.on_execute_callback = on_execute_callback @@ -448,9 +465,13 @@ def __init__( if not WeightRule.is_valid(weight_rule): raise AirflowException( "The weight_rule must be one of {all_weight_rules}," - "'{d}.{t}'; received '{tr}'." - .format(all_weight_rules=WeightRule.all_weight_rules, - d=dag.dag_id if dag else "", t=task_id, tr=weight_rule)) + "'{d}.{t}'; received '{tr}'.".format( + all_weight_rules=WeightRule.all_weight_rules, + d=dag.dag_id if dag else "", + t=task_id, + tr=weight_rule, + ) + ) self.weight_rule = weight_rule self.resources: Optional[Resources] = Resources(**resources) if resources else None self.run_as_user = run_as_user @@ -468,6 +489,7 @@ def __init__( # subdag parameter is only set for SubDagOperator. # Setting it to None by default as other Operators do not have that field from airflow.models.dag import DAG + self.subdag: Optional[DAG] = None self._log = logging.getLogger("airflow.task.operators") @@ -480,10 +502,22 @@ def __init__( self._outlets: List = [] if inlets: - self._inlets = inlets if isinstance(inlets, list) else [inlets, ] + self._inlets = ( + inlets + if isinstance(inlets, list) + else [ + inlets, + ] + ) if outlets: - self._outlets = outlets if isinstance(outlets, list) else [outlets, ] + self._outlets = ( + outlets + if isinstance(outlets, list) + else [ + outlets, + ] + ) def __eq__(self, other): if type(self) is type(other) and self.task_id == other.task_id: @@ -587,8 +621,7 @@ def dag(self) -> Any: if self.has_dag(): return self._dag else: - raise AirflowException( - f'Operator {self} has not been assigned to a DAG yet') + raise AirflowException(f'Operator {self} has not been assigned to a DAG yet') @dag.setter def dag(self, dag: Any): @@ -597,15 +630,14 @@ def dag(self, dag: Any): that same DAG are ok. 
""" from airflow.models.dag import DAG + if dag is None: self._dag = None return if not isinstance(dag, DAG): - raise TypeError( - f'Expected DAG; received {dag.__class__.__name__}') + raise TypeError(f'Expected DAG; received {dag.__class__.__name__}') elif self.has_dag() and self.dag is not dag: - raise AirflowException( - f"The DAG assigned to {self} can not be changed.") + raise AirflowException(f"The DAG assigned to {self} can not be changed.") elif self.task_id not in dag.task_dict: dag.add_task(self) elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] is not self: @@ -671,7 +703,7 @@ def set_xcomargs_dependencies(self) -> None: """ from airflow.models.xcom_arg import XComArg - def apply_set_upstream(arg: Any): # noqa + def apply_set_upstream(arg: Any): # noqa if isinstance(arg, XComArg): self.set_upstream(arg.operator) elif isinstance(arg, (tuple, set, list)): @@ -712,10 +744,13 @@ def priority_weight_total(self) -> int: if not self._dag: return self.priority_weight from airflow.models.dag import DAG + dag: DAG = self._dag return self.priority_weight + sum( - map(lambda task_id: dag.task_dict[task_id].priority_weight, - self.get_flat_relative_ids(upstream=upstream)) + map( + lambda task_id: dag.task_dict[task_id].priority_weight, + self.get_flat_relative_ids(upstream=upstream), + ) ) @cached_property @@ -723,6 +758,7 @@ def operator_extra_link_dict(self) -> Dict[str, Any]: """Returns dictionary of all extra links for the operator""" op_extra_links_from_plugin: Dict[str, Any] = {} from airflow import plugins_manager + plugins_manager.initialize_extra_operators_links_plugins() if plugins_manager.operator_extra_links is None: raise AirflowException("Can't load operators") @@ -730,9 +766,7 @@ def operator_extra_link_dict(self) -> Dict[str, Any]: if ope.operators and self.__class__ in ope.operators: op_extra_links_from_plugin.update({ope.name: ope}) - operator_extra_links_all = { - link.name: link for link in self.operator_extra_links - } + operator_extra_links_all = {link.name: link for link in self.operator_extra_links} # Extra links defined in Plugins overrides operator links defined in operator operator_extra_links_all.update(op_extra_links_from_plugin) @@ -742,6 +776,7 @@ def operator_extra_link_dict(self) -> Dict[str, Any]: def global_operator_extra_link_dict(self) -> Dict[str, Any]: """Returns dictionary of all global extra links""" from airflow import plugins_manager + plugins_manager.initialize_extra_operators_links_plugins() if plugins_manager.global_operator_extra_links is None: raise AirflowException("Can't load operators") @@ -786,8 +821,9 @@ def __deepcopy__(self, memo): result = cls.__new__(cls) memo[id(self)] = result - shallow_copy = cls.shallow_copy_attrs + \ - cls._base_operator_shallow_copy_attrs # pylint: disable=protected-access + shallow_copy = ( + cls.shallow_copy_attrs + cls._base_operator_shallow_copy_attrs + ) # pylint: disable=protected-access for k, v in self.__dict__.items(): if k not in shallow_copy: @@ -821,8 +857,12 @@ def render_template_fields(self, context: Dict, jinja_env: Optional[jinja2.Envir self._do_render_template_fields(self, self.template_fields, context, jinja_env, set()) def _do_render_template_fields( - self, parent: Any, template_fields: Iterable[str], context: Dict, jinja_env: jinja2.Environment, - seen_oids: Set + self, + parent: Any, + template_fields: Iterable[str], + context: Dict, + jinja_env: jinja2.Environment, + seen_oids: Set, ) -> None: for attr_name in template_fields: content = getattr(parent, attr_name) @@ -830,9 
+870,12 @@ def _do_render_template_fields( rendered_content = self.render_template(content, context, jinja_env, seen_oids) setattr(parent, attr_name, rendered_content) - def render_template( # pylint: disable=too-many-return-statements - self, content: Any, context: Dict, jinja_env: Optional[jinja2.Environment] = None, - seen_oids: Optional[Set] = None + def render_template( # pylint: disable=too-many-return-statements + self, + content: Any, + context: Dict, + jinja_env: Optional[jinja2.Environment] = None, + seen_oids: Optional[Set] = None, ) -> Any: """ Render a templated string. The content can be a collection holding multiple templated strings and will @@ -921,8 +964,7 @@ def resolve_template_files(self) -> None: content = getattr(self, field, None) if content is None: # pylint: disable=no-else-continue continue - elif isinstance(content, str) and \ - any(content.endswith(ext) for ext in self.template_ext): + elif isinstance(content, str) and any(content.endswith(ext) for ext in self.template_ext): env = self.get_template_env() try: setattr(self, field, env.loader.get_source(env, content)[0]) @@ -931,8 +973,9 @@ def resolve_template_files(self) -> None: elif isinstance(content, list): env = self.dag.get_template_env() for i in range(len(content)): # pylint: disable=consider-using-enumerate - if isinstance(content[i], str) and \ - any(content[i].endswith(ext) for ext in self.template_ext): + if isinstance(content[i], str) and any( + content[i].endswith(ext) for ext in self.template_ext + ): try: content[i] = env.loader.get_source(env, content[i])[0] except Exception as e: # pylint: disable=broad-except @@ -960,12 +1003,14 @@ def downstream_task_ids(self) -> Set[str]: return self._downstream_task_ids @provide_session - def clear(self, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, - upstream: bool = False, - downstream: bool = False, - session: Session = None): + def clear( + self, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + upstream: bool = False, + downstream: bool = False, + session: Session = None, + ): """ Clears the state of task instances associated with the task, following the parameters specified. @@ -980,12 +1025,10 @@ def clear(self, tasks = [self.task_id] if upstream: - tasks += [ - t.task_id for t in self.get_flat_relatives(upstream=True)] + tasks += [t.task_id for t in self.get_flat_relatives(upstream=True)] if downstream: - tasks += [ - t.task_id for t in self.get_flat_relatives(upstream=False)] + tasks += [t.task_id for t in self.get_flat_relatives(upstream=False)] qry = qry.filter(TaskInstance.task_id.in_(tasks)) results = qry.all() @@ -995,26 +1038,32 @@ def clear(self, return count @provide_session - def get_task_instances(self, start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, - session: Session = None) -> List[TaskInstance]: + def get_task_instances( + self, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + session: Session = None, + ) -> List[TaskInstance]: """ Get a set of task instance related to this task for a specific date range. 
""" end_date = end_date or timezone.utcnow() - return session.query(TaskInstance)\ - .filter(TaskInstance.dag_id == self.dag_id)\ - .filter(TaskInstance.task_id == self.task_id)\ - .filter(TaskInstance.execution_date >= start_date)\ - .filter(TaskInstance.execution_date <= end_date)\ - .order_by(TaskInstance.execution_date)\ + return ( + session.query(TaskInstance) + .filter(TaskInstance.dag_id == self.dag_id) + .filter(TaskInstance.task_id == self.task_id) + .filter(TaskInstance.execution_date >= start_date) + .filter(TaskInstance.execution_date <= end_date) + .order_by(TaskInstance.execution_date) .all() + ) - def get_flat_relative_ids(self, - upstream: bool = False, - found_descendants: Optional[Set[str]] = None, - ) -> Set[str]: + def get_flat_relative_ids( + self, + upstream: bool = False, + found_descendants: Optional[Set[str]] = None, + ) -> Set[str]: """Get a flat set of relatives' ids, either upstream or downstream.""" if not self._dag: return set() @@ -1036,17 +1085,18 @@ def get_flat_relatives(self, upstream: bool = False): if not self._dag: return set() from airflow.models.dag import DAG + dag: DAG = self._dag - return list(map(lambda task_id: dag.task_dict[task_id], - self.get_flat_relative_ids(upstream))) + return list(map(lambda task_id: dag.task_dict[task_id], self.get_flat_relative_ids(upstream))) def run( - self, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, - ignore_first_depends_on_past: bool = True, - ignore_ti_state: bool = False, - mark_success: bool = False) -> None: + self, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + ignore_first_depends_on_past: bool = True, + ignore_ti_state: bool = False, + mark_success: bool = False, + ) -> None: """Run a set of task instances for a date range.""" start_date = start_date or self.start_date end_date = end_date or self.end_date or timezone.utcnow() @@ -1054,9 +1104,9 @@ def run( for execution_date in self.dag.date_range(start_date, end_date=end_date): TaskInstance(self, execution_date).run( mark_success=mark_success, - ignore_depends_on_past=( - execution_date == start_date and ignore_first_depends_on_past), - ignore_ti_state=ignore_ti_state) + ignore_depends_on_past=(execution_date == start_date and ignore_first_depends_on_past), + ignore_ti_state=ignore_ti_state, + ) def dry_run(self) -> None: """Performs dry run for the operator - just render template fields.""" @@ -1088,8 +1138,7 @@ def get_direct_relatives(self, upstream: bool = False) -> List["BaseOperator"]: return self.downstream_list def __repr__(self): - return "".format( - self=self) + return "".format(self=self) @property def task_type(self) -> str: @@ -1099,8 +1148,7 @@ def task_type(self) -> str: def add_only_new(self, item_set: Set[str], item: str) -> None: """Adds only new items to item set""" if item in item_set: - self.log.warning( - 'Dependency %s, %s already registered', self, item) + self.log.warning('Dependency %s, %s already registered', self, item) else: item_set.add(item) @@ -1133,25 +1181,29 @@ def _set_relatives( if not isinstance(task, BaseOperator): raise AirflowException( "Relationships can only be set between " - "Operators; received {}".format(task.__class__.__name__)) + "Operators; received {}".format(task.__class__.__name__) + ) # relationships can only be set if the tasks share a single DAG. Tasks # without a DAG are assigned to that DAG. 
dags = { task._dag.dag_id: task._dag # type: ignore # pylint: disable=protected-access,no-member - for task in self.roots + task_list if task.has_dag()} # pylint: disable=no-member + for task in self.roots + task_list + if task.has_dag() # pylint: disable=no-member + } if len(dags) > 1: raise AirflowException( - 'Tried to set relationships between tasks in ' - 'more than one DAG: {}'.format(dags.values())) + 'Tried to set relationships between tasks in ' 'more than one DAG: {}'.format(dags.values()) + ) elif len(dags) == 1: dag = dags.popitem()[1] else: raise AirflowException( "Tried to create relationships between tasks that don't have " "DAGs yet. Set the DAG for at least one " - "task and try again: {}".format([self] + task_list)) + "task and try again: {}".format([self] + task_list) + ) if dag and not self.has_dag(): self.dag = dag @@ -1184,14 +1236,15 @@ def set_upstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]) def output(self): """Returns reference to XCom pushed by current operator""" from airflow.models.xcom_arg import XComArg + return XComArg(operator=self) @staticmethod def xcom_push( - context: Any, - key: str, - value: Any, - execution_date: Optional[datetime] = None, + context: Any, + key: str, + value: Any, + execution_date: Optional[datetime] = None, ) -> None: """ Make an XCom available for tasks to pull. @@ -1208,18 +1261,15 @@ def xcom_push( task on a future date without it being immediately visible. :type execution_date: datetime """ - context['ti'].xcom_push( - key=key, - value=value, - execution_date=execution_date) + context['ti'].xcom_push(key=key, value=value, execution_date=execution_date) @staticmethod def xcom_pull( - context: Any, - task_ids: Optional[List[str]] = None, - dag_id: Optional[str] = None, - key: str = XCOM_RETURN_KEY, - include_prior_dates: Optional[bool] = None, + context: Any, + task_ids: Optional[List[str]] = None, + dag_id: Optional[str] = None, + key: str = XCOM_RETURN_KEY, + include_prior_dates: Optional[bool] = None, ) -> Any: """ Pull XComs that optionally meet certain criteria. 
@@ -1253,16 +1303,15 @@ def xcom_pull( :type include_prior_dates: bool """ return context['ti'].xcom_pull( - key=key, - task_ids=task_ids, - dag_id=dag_id, - include_prior_dates=include_prior_dates) + key=key, task_ids=task_ids, dag_id=dag_id, include_prior_dates=include_prior_dates + ) @cached_property def extra_links(self) -> List[str]: """@property: extra links for the task""" - return list(set(self.operator_extra_link_dict.keys()) - .union(self.global_operator_extra_link_dict.keys())) + return list( + set(self.operator_extra_link_dict.keys()).union(self.global_operator_extra_link_dict.keys()) + ) def get_extra_links(self, dttm: datetime, link_name: str) -> Optional[Dict[str, Any]]: """ @@ -1288,11 +1337,25 @@ def get_serialized_fields(cls): """Stringified DAGs and operators contain exactly these fields.""" if not cls.__serialized_fields: cls.__serialized_fields = frozenset( - vars(BaseOperator(task_id='test')).keys() - { - 'inlets', 'outlets', '_upstream_task_ids', 'default_args', 'dag', '_dag', + vars(BaseOperator(task_id='test')).keys() + - { + 'inlets', + 'outlets', + '_upstream_task_ids', + 'default_args', + 'dag', + '_dag', '_BaseOperator__instantiated', - } | {'_task_type', 'subdag', 'ui_color', 'ui_fgcolor', - 'template_fields', 'template_fields_renderers'}) + } + | { + '_task_type', + 'subdag', + 'ui_color', + 'ui_fgcolor', + 'template_fields', + 'template_fields_renderers', + } + ) return cls.__serialized_fields @@ -1341,19 +1404,23 @@ def chain(*tasks: Union[BaseOperator, Sequence[BaseOperator]]): if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence): raise TypeError( 'Chain not supported between instances of {up_type} and {down_type}'.format( - up_type=type(up_task), down_type=type(down_task))) + up_type=type(up_task), down_type=type(down_task) + ) + ) up_task_list = up_task down_task_list = down_task if len(up_task_list) != len(down_task_list): raise AirflowException( f'Chain not supported different length Iterable ' - f'but get {len(up_task_list)} and {len(down_task_list)}') + f'but get {len(up_task_list)} and {len(down_task_list)}' + ) for up_t, down_t in zip(up_task_list, down_task_list): up_t.set_downstream(down_t) -def cross_downstream(from_tasks: Sequence[BaseOperator], - to_tasks: Union[BaseOperator, Sequence[BaseOperator]]): +def cross_downstream( + from_tasks: Sequence[BaseOperator], to_tasks: Union[BaseOperator, Sequence[BaseOperator]] +): r""" Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks. 
diff --git a/airflow/models/connection.py b/airflow/models/connection.py index d724a3f18122c..9f50e04d14ab1 100644 --- a/airflow/models/connection.py +++ b/airflow/models/connection.py @@ -39,7 +39,7 @@ CONN_TYPE_TO_HOOK = { "azure_batch": ( "airflow.providers.microsoft.azure.hooks.azure_batch.AzureBatchHook", - "azure_batch_conn_id" + "azure_batch_conn_id", ), "azure_cosmos": ( "airflow.providers.microsoft.azure.hooks.azure_cosmos.AzureCosmosDBHook", @@ -55,7 +55,7 @@ "docker": ("airflow.providers.docker.hooks.docker.DockerHook", "docker_conn_id"), "elasticsearch": ( "airflow.providers.elasticsearch.hooks.elasticsearch.ElasticsearchHook", - "elasticsearch_conn_id" + "elasticsearch_conn_id", ), "exasol": ("airflow.providers.exasol.hooks.exasol.ExasolHook", "exasol_conn_id"), "gcpcloudsql": ( @@ -93,10 +93,7 @@ def parse_netloc_to_hostname(*args, **kwargs): """This method is deprecated.""" - warnings.warn( - "This method is deprecated.", - DeprecationWarning - ) + warnings.warn("This method is deprecated.", DeprecationWarning) return _parse_netloc_to_hostname(*args, **kwargs) @@ -170,7 +167,7 @@ def __init__( schema: Optional[str] = None, port: Optional[int] = None, extra: Optional[str] = None, - uri: Optional[str] = None + uri: Optional[str] = None, ): super().__init__() self.conn_id = conn_id @@ -196,8 +193,7 @@ def __init__( def parse_from_uri(self, **uri): """This method is deprecated. Please use uri parameter in constructor.""" warnings.warn( - "This method is deprecated. Please use uri parameter in constructor.", - DeprecationWarning + "This method is deprecated. Please use uri parameter in constructor.", DeprecationWarning ) self._parse_from_uri(**uri) @@ -212,10 +208,8 @@ def _parse_from_uri(self, uri: str): self.host = _parse_netloc_to_hostname(uri_parts) quoted_schema = uri_parts.path[1:] self.schema = unquote(quoted_schema) if quoted_schema else quoted_schema - self.login = unquote(uri_parts.username) \ - if uri_parts.username else uri_parts.username - self.password = unquote(uri_parts.password) \ - if uri_parts.password else uri_parts.password + self.login = unquote(uri_parts.username) if uri_parts.username else uri_parts.username + self.password = unquote(uri_parts.password) if uri_parts.password else uri_parts.password self.port = uri_parts.port if uri_parts.query: self.extra = json.dumps(dict(parse_qsl(uri_parts.query, keep_blank_values=True))) @@ -263,7 +257,10 @@ def get_password(self) -> Optional[str]: if not fernet.is_encrypted: raise AirflowException( "Can't decrypt encrypted password for login={}, \ - FERNET_KEY configuration is missing".format(self.login)) + FERNET_KEY configuration is missing".format( + self.login + ) + ) return fernet.decrypt(bytes(self._password, 'utf-8')).decode() else: return self._password @@ -276,10 +273,9 @@ def set_password(self, value: Optional[str]): self.is_encrypted = fernet.is_encrypted @declared_attr - def password(cls): # pylint: disable=no-self-argument + def password(cls): # pylint: disable=no-self-argument """Password. 
The value is decrypted/encrypted when reading/setting the value.""" - return synonym('_password', - descriptor=property(cls.get_password, cls.set_password)) + return synonym('_password', descriptor=property(cls.get_password, cls.set_password)) def get_extra(self) -> Dict: """Return encrypted extra-data.""" @@ -288,7 +284,10 @@ def get_extra(self) -> Dict: if not fernet.is_encrypted: raise AirflowException( "Can't decrypt `extra` params for login={},\ - FERNET_KEY configuration is missing".format(self.login)) + FERNET_KEY configuration is missing".format( + self.login + ) + ) return fernet.decrypt(bytes(self._extra, 'utf-8')).decode() else: return self._extra @@ -304,10 +303,9 @@ def set_extra(self, value: str): self.is_extra_encrypted = False @declared_attr - def extra(cls): # pylint: disable=no-self-argument + def extra(cls): # pylint: disable=no-self-argument """Extra data. The value is decrypted/encrypted when reading/setting the value.""" - return synonym('_extra', - descriptor=property(cls.get_extra, cls.set_extra)) + return synonym('_extra', descriptor=property(cls.get_extra, cls.set_extra)) def rotate_fernet_key(self): """Encrypts data with a new key. See: :ref:`security/fernet`""" @@ -337,17 +335,17 @@ def log_info(self): "This method is deprecated. You can read each field individually or " "use the default representation (__repr__).", DeprecationWarning, - stacklevel=2 + stacklevel=2, + ) + return "id: {}. Host: {}, Port: {}, Schema: {}, " "Login: {}, Password: {}, extra: {}".format( + self.conn_id, + self.host, + self.port, + self.schema, + self.login, + "XXXXXXXX" if self.password else None, + "XXXXXXXX" if self.extra_dejson else None, ) - return ("id: {}. Host: {}, Port: {}, Schema: {}, " - "Login: {}, Password: {}, extra: {}". - format(self.conn_id, - self.host, - self.port, - self.schema, - self.login, - "XXXXXXXX" if self.password else None, - "XXXXXXXX" if self.extra_dejson else None)) def debug_info(self): """ @@ -358,17 +356,17 @@ def debug_info(self): "This method is deprecated. You can read each field individually or " "use the default representation (__repr__).", DeprecationWarning, - stacklevel=2 + stacklevel=2, + ) + return "id: {}. Host: {}, Port: {}, Schema: {}, " "Login: {}, Password: {}, extra: {}".format( + self.conn_id, + self.host, + self.port, + self.schema, + self.login, + "XXXXXXXX" if self.password else None, + self.extra_dejson, ) - return ("id: {}. Host: {}, Port: {}, Schema: {}, " - "Login: {}, Password: {}, extra: {}". - format(self.conn_id, - self.host, - self.port, - self.schema, - self.login, - "XXXXXXXX" if self.password else None, - self.extra_dejson)) @property def extra_dejson(self) -> Dict: diff --git a/airflow/models/crypto.py b/airflow/models/crypto.py index 8b55448055a91..d6e0ee8341ef8 100644 --- a/airflow/models/crypto.py +++ b/airflow/models/crypto.py @@ -79,15 +79,12 @@ def get_fernet(): try: fernet_key = conf.get('core', 'FERNET_KEY') if not fernet_key: - log.warning( - "empty cryptography key - values will not be stored encrypted." 
- ) + log.warning("empty cryptography key - values will not be stored encrypted.") _fernet = NullFernet() else: - _fernet = MultiFernet([ - Fernet(fernet_part.encode('utf-8')) - for fernet_part in fernet_key.split(',') - ]) + _fernet = MultiFernet( + [Fernet(fernet_part.encode('utf-8')) for fernet_part in fernet_key.split(',')] + ) _fernet.is_encrypted = True except (ValueError, TypeError) as value_error: raise AirflowException(f"Could not create Fernet object: {value_error}") diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 166553d9db415..78ad2479f4e16 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -29,7 +29,18 @@ from datetime import datetime, timedelta from inspect import signature from typing import ( - TYPE_CHECKING, Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type, Union, + TYPE_CHECKING, + Callable, + Collection, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Set, + Tuple, + Type, + Union, cast, ) @@ -253,7 +264,7 @@ def __init__( access_control: Optional[Dict] = None, is_paused_upon_creation: Optional[bool] = None, jinja_environment_kwargs: Optional[Dict] = None, - tags: Optional[List[str]] = None + tags: Optional[List[str]] = None, ): from airflow.utils.task_group import TaskGroup @@ -286,9 +297,7 @@ def __init__( self.timezone = start_date.tzinfo elif 'start_date' in self.default_args and self.default_args['start_date']: if isinstance(self.default_args['start_date'], str): - self.default_args['start_date'] = ( - timezone.parse(self.default_args['start_date']) - ) + self.default_args['start_date'] = timezone.parse(self.default_args['start_date']) self.timezone = self.default_args['start_date'].tzinfo if not hasattr(self, 'timezone') or not self.timezone: @@ -297,8 +306,8 @@ def __init__( # Apply the timezone we settled on to end_date if it wasn't supplied if 'end_date' in self.default_args and self.default_args['end_date']: if isinstance(self.default_args['end_date'], str): - self.default_args['end_date'] = ( - timezone.parse(self.default_args['end_date'], timezone=self.timezone) + self.default_args['end_date'] = timezone.parse( + self.default_args['end_date'], timezone=self.timezone ) self.start_date = timezone.convert_to_utc(start_date) @@ -306,13 +315,9 @@ def __init__( # also convert tasks if 'start_date' in self.default_args: - self.default_args['start_date'] = ( - timezone.convert_to_utc(self.default_args['start_date']) - ) + self.default_args['start_date'] = timezone.convert_to_utc(self.default_args['start_date']) if 'end_date' in self.default_args: - self.default_args['end_date'] = ( - timezone.convert_to_utc(self.default_args['end_date']) - ) + self.default_args['end_date'] = timezone.convert_to_utc(self.default_args['end_date']) self.schedule_interval = schedule_interval if isinstance(template_searchpath, str): @@ -328,13 +333,17 @@ def __init__( if default_view in DEFAULT_VIEW_PRESETS: self._default_view: str = default_view else: - raise AirflowException(f'Invalid values of dag.default_view: only support ' - f'{DEFAULT_VIEW_PRESETS}, but get {default_view}') + raise AirflowException( + f'Invalid values of dag.default_view: only support ' + f'{DEFAULT_VIEW_PRESETS}, but get {default_view}' + ) if orientation in ORIENTATION_PRESETS: self.orientation = orientation else: - raise AirflowException(f'Invalid values of dag.orientation: only support ' - f'{ORIENTATION_PRESETS}, but get {orientation}') + raise AirflowException( + f'Invalid values of dag.orientation: only support ' + 
f'{ORIENTATION_PRESETS}, but get {orientation}' + ) self.catchup = catchup self.is_subdag = False # DagBag.bag_dag() will set this to True if appropriate @@ -354,8 +363,7 @@ def __repr__(self): return f"" def __eq__(self, other): - if (type(self) == type(other) and - self.dag_id == other.dag_id): + if type(self) == type(other) and self.dag_id == other.dag_id: # Use getattr() instead of __dict__ as __dict__ doesn't return # correct values for properties. @@ -414,7 +422,8 @@ def _upgrade_outdated_dag_access_control(access_control=None): warnings.warn( "The 'can_dag_read' and 'can_dag_edit' permissions are deprecated. " "Please use 'can_read' and 'can_edit', respectively.", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) return updated_access_control @@ -428,8 +437,8 @@ def date_range( if num is not None: end_date = None return utils_date_range( - start_date=start_date, end_date=end_date, - num=num, delta=self.normalized_schedule_interval) + start_date=start_date, end_date=end_date, num=num, delta=self.normalized_schedule_interval + ) def is_fixed_time_schedule(self): """ @@ -516,8 +525,9 @@ def next_dagrun_info( "automated" DagRuns for this dag (scheduled or backfill, but not manual) """ - if (self.schedule_interval == "@once" and date_last_automated_dagrun) or \ - self.schedule_interval is None: + if ( + self.schedule_interval == "@once" and date_last_automated_dagrun + ) or self.schedule_interval is None: # Manual trigger, or already created the run for @once, can short circuit return (None, None) next_execution_date = self.next_dagrun_after_date(date_last_automated_dagrun) @@ -588,10 +598,7 @@ def next_dagrun_after_date(self, date_last_automated_dagrun: Optional[pendulum.D if next_run_date == self.start_date: next_run_date = self.normalize_schedule(self.start_date) - self.log.debug( - "Dag start date: %s. Next run date: %s", - self.start_date, next_run_date - ) + self.log.debug("Dag start date: %s. Next run date: %s", self.start_date, next_run_date) # Don't schedule a dag beyond its end_date (as specified by the dag param) if next_run_date and self.end_date and next_run_date > self.end_date: @@ -630,8 +637,7 @@ def get_run_dates(self, start_date, end_date=None): # next run date for a subdag isn't relevant (schedule_interval for subdags # is ignored) so we use the dag run's start date in the case of a subdag - next_run_date = (self.normalize_schedule(using_start_date) - if not self.is_subdag else using_start_date) + next_run_date = self.normalize_schedule(using_start_date) if not self.is_subdag else using_start_date while next_run_date and next_run_date <= using_end_date: run_dates.append(next_run_date) @@ -653,8 +659,9 @@ def normalize_schedule(self, dttm): @provide_session def get_last_dagrun(self, session=None, include_externally_triggered=False): - return get_last_dagrun(self.dag_id, session=session, - include_externally_triggered=include_externally_triggered) + return get_last_dagrun( + self.dag_id, session=session, include_externally_triggered=include_externally_triggered + ) @property def dag_id(self) -> str: @@ -720,8 +727,7 @@ def tasks(self) -> List[BaseOperator]: @tasks.setter def tasks(self, val): - raise AttributeError( - 'DAG.tasks can not be modified. Use dag.add_task() instead.') + raise AttributeError('DAG.tasks can not be modified. 
Use dag.add_task() instead.') @property def task_ids(self) -> List[str]: @@ -783,8 +789,7 @@ def concurrency_reached(self): @provide_session def get_is_paused(self, session=None): """Returns a boolean indicating whether this DAG is paused""" - qry = session.query(DagModel).filter( - DagModel.dag_id == self.dag_id) + qry = session.query(DagModel).filter(DagModel.dag_id == self.dag_id) return qry.value(DagModel.is_paused) @property @@ -870,10 +875,11 @@ def get_num_active_runs(self, external_trigger=None, session=None): :return: number greater than 0 for active dag runs """ # .count() is inefficient - query = (session - .query(func.count()) - .filter(DagRun.dag_id == self.dag_id) - .filter(DagRun.state == State.RUNNING)) + query = ( + session.query(func.count()) + .filter(DagRun.dag_id == self.dag_id) + .filter(DagRun.state == State.RUNNING) + ) if external_trigger is not None: query = query.filter(DagRun.external_trigger == external_trigger) @@ -892,10 +898,9 @@ def get_dagrun(self, execution_date, session=None): """ dagrun = ( session.query(DagRun) - .filter( - DagRun.dag_id == self.dag_id, - DagRun.execution_date == execution_date) - .first()) + .filter(DagRun.dag_id == self.dag_id, DagRun.execution_date == execution_date) + .first() + ) return dagrun @@ -914,17 +919,17 @@ def get_dagruns_between(self, start_date, end_date, session=None): .filter( DagRun.dag_id == self.dag_id, DagRun.execution_date >= start_date, - DagRun.execution_date <= end_date) - .all()) + DagRun.execution_date <= end_date, + ) + .all() + ) return dagruns @provide_session def get_latest_execution_date(self, session=None): """Returns the latest date for which at least one dag run exists""" - return session.query(func.max(DagRun.execution_date)).filter( - DagRun.dag_id == self.dag_id - ).scalar() + return session.query(func.max(DagRun.execution_date)).filter(DagRun.dag_id == self.dag_id).scalar() @property def latest_execution_date(self): @@ -941,12 +946,16 @@ def subdags(self): """Returns a list of the subdag objects associated to this DAG""" # Check SubDag for class but don't check class directly from airflow.operators.subdag_operator import SubDagOperator + subdag_lst = [] for task in self.tasks: - if (isinstance(task, SubDagOperator) or - # TODO remove in Airflow 2.0 - type(task).__name__ == 'SubDagOperator' or - task.task_type == 'SubDagOperator'): + if ( + isinstance(task, SubDagOperator) + or + # TODO remove in Airflow 2.0 + type(task).__name__ == 'SubDagOperator' + or task.task_type == 'SubDagOperator' + ): subdag_lst.append(task.subdag) subdag_lst += task.subdag.subdags return subdag_lst @@ -967,7 +976,7 @@ def get_template_env(self) -> jinja2.Environment: 'loader': jinja2.FileSystemLoader(searchpath), 'undefined': self.template_undefined, 'extensions': ["jinja2.ext.do"], - 'cache_size': 0 + 'cache_size': 0, } if self.jinja_environment_kwargs: jinja_env_options.update(self.jinja_environment_kwargs) @@ -988,16 +997,13 @@ def set_dependency(self, upstream_task_id, downstream_task_id): Simple utility method to set dependency between two tasks that already have been added to the DAG using add_task() """ - self.get_task(upstream_task_id).set_downstream( - self.get_task(downstream_task_id)) + self.get_task(upstream_task_id).set_downstream(self.get_task(downstream_task_id)) @provide_session - def get_task_instances( - self, start_date=None, end_date=None, state=None, session=None): + def get_task_instances(self, start_date=None, end_date=None, state=None, session=None): if not start_date: start_date = 
(timezone.utcnow() - timedelta(30)).date() - start_date = timezone.make_aware( - datetime.combine(start_date, datetime.min.time())) + start_date = timezone.make_aware(datetime.combine(start_date, datetime.min.time())) tis = session.query(TaskInstance).filter( TaskInstance.dag_id == self.dag_id, @@ -1020,8 +1026,7 @@ def get_task_instances( else: not_none_state = [s for s in state if s] tis = tis.filter( - or_(TaskInstance.state.in_(not_none_state), - TaskInstance.state.is_(None)) + or_(TaskInstance.state.in_(not_none_state), TaskInstance.state.is_(None)) ) else: tis = tis.filter(TaskInstance.state.in_(state)) @@ -1088,18 +1093,17 @@ def topological_sort(self, include_subdag_tasks: bool = False): graph_sorted.extend(node.subdag.topological_sort(include_subdag_tasks=True)) if not acyclic: - raise AirflowException("A cyclic dependency occurred in dag: {}" - .format(self.dag_id)) + raise AirflowException(f"A cyclic dependency occurred in dag: {self.dag_id}") return tuple(graph_sorted) @provide_session def set_dag_runs_state( - self, - state: str = State.RUNNING, - session: Session = None, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, + self, + state: str = State.RUNNING, + session: Session = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> None: query = session.query(DagRun).filter_by(dag_id=self.dag_id) if start_date: @@ -1110,20 +1114,22 @@ def set_dag_runs_state( @provide_session def clear( - self, start_date=None, end_date=None, - only_failed=False, - only_running=False, - confirm_prompt=False, - include_subdags=True, - include_parentdag=True, - dag_run_state: str = State.RUNNING, - dry_run=False, - session=None, - get_tis=False, - recursion_depth=0, - max_recursion_depth=None, - dag_bag=None, - visited_external_tis=None, + self, + start_date=None, + end_date=None, + only_failed=False, + only_running=False, + confirm_prompt=False, + include_subdags=True, + include_parentdag=True, + dag_run_state: str = State.RUNNING, + dry_run=False, + session=None, + get_tis=False, + recursion_depth=0, + max_recursion_depth=None, + dag_bag=None, + visited_external_tis=None, ): """ Clears a set of task instances associated with the current dag for @@ -1169,10 +1175,7 @@ def clear( # Crafting the right filter for dag_id and task_ids combo conditions = [] for dag in self.subdags + [self]: - conditions.append( - (TI.dag_id == dag.dag_id) & - TI.task_id.in_(dag.task_ids) - ) + conditions.append((TI.dag_id == dag.dag_id) & TI.task_id.in_(dag.task_ids)) tis = tis.filter(or_(*conditions)) else: tis = session.query(TI).filter(TI.dag_id == self.dag_id) @@ -1182,32 +1185,34 @@ def clear( p_dag = self.parent_dag.sub_dag( task_ids_or_regex=r"^{}$".format(self.dag_id.split('.')[1]), include_upstream=False, - include_downstream=True) + include_downstream=True, + ) - tis = tis.union(p_dag.clear( - start_date=start_date, end_date=end_date, - only_failed=only_failed, - only_running=only_running, - confirm_prompt=confirm_prompt, - include_subdags=include_subdags, - include_parentdag=False, - dag_run_state=dag_run_state, - get_tis=True, - session=session, - recursion_depth=recursion_depth, - max_recursion_depth=max_recursion_depth, - dag_bag=dag_bag, - visited_external_tis=visited_external_tis - )) + tis = tis.union( + p_dag.clear( + start_date=start_date, + end_date=end_date, + only_failed=only_failed, + only_running=only_running, + confirm_prompt=confirm_prompt, + include_subdags=include_subdags, + include_parentdag=False, + 
dag_run_state=dag_run_state, + get_tis=True, + session=session, + recursion_depth=recursion_depth, + max_recursion_depth=max_recursion_depth, + dag_bag=dag_bag, + visited_external_tis=visited_external_tis, + ) + ) if start_date: tis = tis.filter(TI.execution_date >= start_date) if end_date: tis = tis.filter(TI.execution_date <= end_date) if only_failed: - tis = tis.filter(or_( - TI.state == State.FAILED, - TI.state == State.UPSTREAM_FAILED)) + tis = tis.filter(or_(TI.state == State.FAILED, TI.state == State.UPSTREAM_FAILED)) if only_running: tis = tis.filter(TI.state == State.RUNNING) @@ -1224,8 +1229,9 @@ def clear( if ti_key not in visited_external_tis: # Only clear this ExternalTaskMarker if it's not already visited by the # recursive calls to dag.clear(). - task: ExternalTaskMarker = cast(ExternalTaskMarker, - copy.copy(self.get_task(ti.task_id))) + task: ExternalTaskMarker = cast( + ExternalTaskMarker, copy.copy(self.get_task(ti.task_id)) + ) ti.task = task if recursion_depth == 0: @@ -1235,16 +1241,19 @@ def clear( if recursion_depth + 1 > max_recursion_depth: # Prevent cycles or accidents. - raise AirflowException("Maximum recursion depth {} reached for {} {}. " - "Attempted to clear too many tasks " - "or there may be a cyclic dependency." - .format(max_recursion_depth, - ExternalTaskMarker.__name__, ti.task_id)) + raise AirflowException( + "Maximum recursion depth {} reached for {} {}. " + "Attempted to clear too many tasks " + "or there may be a cyclic dependency.".format( + max_recursion_depth, ExternalTaskMarker.__name__, ti.task_id + ) + ) ti.render_templates() - external_tis = session.query(TI).filter(TI.dag_id == task.external_dag_id, - TI.task_id == task.external_task_id, - TI.execution_date == - pendulum.parse(task.execution_date)) + external_tis = session.query(TI).filter( + TI.dag_id == task.external_dag_id, + TI.task_id == task.external_task_id, + TI.execution_date == pendulum.parse(task.execution_date), + ) for tii in external_tis: if not dag_bag: @@ -1255,22 +1264,26 @@ def clear( downstream = external_dag.sub_dag( task_ids_or_regex=fr"^{tii.task_id}$", include_upstream=False, - include_downstream=True + include_downstream=True, + ) + tis = tis.union( + downstream.clear( + start_date=tii.execution_date, + end_date=tii.execution_date, + only_failed=only_failed, + only_running=only_running, + confirm_prompt=confirm_prompt, + include_subdags=include_subdags, + include_parentdag=False, + dag_run_state=dag_run_state, + get_tis=True, + session=session, + recursion_depth=recursion_depth + 1, + max_recursion_depth=max_recursion_depth, + dag_bag=dag_bag, + visited_external_tis=visited_external_tis, + ) ) - tis = tis.union(downstream.clear(start_date=tii.execution_date, - end_date=tii.execution_date, - only_failed=only_failed, - only_running=only_running, - confirm_prompt=confirm_prompt, - include_subdags=include_subdags, - include_parentdag=False, - dag_run_state=dag_run_state, - get_tis=True, - session=session, - recursion_depth=recursion_depth + 1, - max_recursion_depth=max_recursion_depth, - dag_bag=dag_bag, - visited_external_tis=visited_external_tis)) visited_external_tis.add(ti_key) if get_tis: @@ -1291,9 +1304,8 @@ def clear( if confirm_prompt: ti_list = "\n".join([str(t) for t in tis]) question = ( - "You are about to delete these {count} tasks:\n" - "{ti_list}\n\n" - "Are you sure? (yes/no): ").format(count=count, ti_list=ti_list) + "You are about to delete these {count} tasks:\n" "{ti_list}\n\n" "Are you sure? 
(yes/no): " + ).format(count=count, ti_list=ti_list) do_it = utils.helpers.ask_yesno(question) if do_it: @@ -1318,16 +1330,17 @@ def clear( @classmethod def clear_dags( - cls, dags, - start_date=None, - end_date=None, - only_failed=False, - only_running=False, - confirm_prompt=False, - include_subdags=True, - include_parentdag=False, - dag_run_state=State.RUNNING, - dry_run=False, + cls, + dags, + start_date=None, + end_date=None, + only_failed=False, + only_running=False, + confirm_prompt=False, + include_subdags=True, + include_parentdag=False, + dag_run_state=State.RUNNING, + dry_run=False, ): all_tis = [] for dag in dags: @@ -1340,7 +1353,8 @@ def clear_dags( include_subdags=include_subdags, include_parentdag=include_parentdag, dag_run_state=dag_run_state, - dry_run=True) + dry_run=True, + ) all_tis.extend(tis) if dry_run: @@ -1354,22 +1368,22 @@ def clear_dags( if confirm_prompt: ti_list = "\n".join([str(t) for t in all_tis]) question = ( - "You are about to delete these {} tasks:\n" - "{}\n\n" - "Are you sure? (yes/no): ").format(count, ti_list) + "You are about to delete these {} tasks:\n" "{}\n\n" "Are you sure? (yes/no): " + ).format(count, ti_list) do_it = utils.helpers.ask_yesno(question) if do_it: for dag in dags: - dag.clear(start_date=start_date, - end_date=end_date, - only_failed=only_failed, - only_running=only_running, - confirm_prompt=False, - include_subdags=include_subdags, - dag_run_state=dag_run_state, - dry_run=False, - ) + dag.clear( + start_date=start_date, + end_date=end_date, + only_failed=only_failed, + only_running=only_running, + confirm_prompt=False, + include_subdags=include_subdags, + dag_run_state=dag_run_state, + dry_run=False, + ) else: count = 0 print("Cancelled, nothing was cleared.") @@ -1432,8 +1446,7 @@ def partial_subset( self._task_group = task_group if isinstance(task_ids_or_regex, (str, PatternType)): - matched_tasks = [ - t for t in self.tasks if re.findall(task_ids_or_regex, t.task_id)] + matched_tasks = [t for t in self.tasks if re.findall(task_ids_or_regex, t.task_id)] else: matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex] @@ -1448,8 +1461,10 @@ def partial_subset( # Compiling the unique list of tasks that made the cut # Make sure to not recursively deepcopy the dag while copying the task - dag.task_dict = {t.task_id: copy.deepcopy(t, {id(t.dag): dag}) # type: ignore - for t in matched_tasks + also_include} + dag.task_dict = { + t.task_id: copy.deepcopy(t, {id(t.dag): dag}) # type: ignore + for t in matched_tasks + also_include + } def filter_task_group(group, parent_group): """Exclude tasks not included in the subdag from the given TaskGroup.""" @@ -1487,8 +1502,7 @@ def filter_task_group(group, parent_group): # Removing upstream/downstream references to tasks that did not # make the cut t._upstream_task_ids = t.upstream_task_ids.intersection(dag.task_dict.keys()) - t._downstream_task_ids = t.downstream_task_ids.intersection( - dag.task_dict.keys()) + t._downstream_task_ids = t.downstream_task_ids.intersection(dag.task_dict.keys()) if len(dag.tasks) < len(self.tasks): dag.partial = True @@ -1523,12 +1537,10 @@ def pickle_info(self): @provide_session def pickle(self, session=None) -> DagPickle: - dag = session.query( - DagModel).filter(DagModel.dag_id == self.dag_id).first() + dag = session.query(DagModel).filter(DagModel.dag_id == self.dag_id).first() dp = None if dag and dag.pickle_id: - dp = session.query(DagPickle).filter( - DagPickle.id == dag.pickle_id).first() + dp = 
session.query(DagPickle).filter(DagPickle.id == dag.pickle_id).first() if not dp or dp.pickle != self: dp = DagPickle(dag=self) session.add(dp) @@ -1540,6 +1552,7 @@ def pickle(self, session=None) -> DagPickle: def tree_view(self) -> None: """Print an ASCII tree representation of the DAG.""" + def get_downstream(task, level=0): print((" " * level * 4) + str(task)) level += 1 @@ -1552,6 +1565,7 @@ def get_downstream(task, level=0): @property def task(self): from airflow.operators.python import task + return functools.partial(task, dag=self) def add_task(self, task): @@ -1579,10 +1593,10 @@ def add_task(self, task): elif task.end_date and self.end_date: task.end_date = min(task.end_date, self.end_date) - if ((task.task_id in self.task_dict and self.task_dict[task.task_id] is not task) - or task.task_id in self._task_group.used_group_ids): - raise DuplicateTaskIdFound( - f"Task id '{task.task_id}' has already been added to the DAG") + if ( + task.task_id in self.task_dict and self.task_dict[task.task_id] is not task + ) or task.task_id in self._task_group.used_group_ids: + raise DuplicateTaskIdFound(f"Task id '{task.task_id}' has already been added to the DAG") else: self.task_dict[task.task_id] = task task.dag = self @@ -1602,21 +1616,21 @@ def add_tasks(self, tasks): self.add_task(task) def run( - self, - start_date=None, - end_date=None, - mark_success=False, - local=False, - executor=None, - donot_pickle=conf.getboolean('core', 'donot_pickle'), - ignore_task_deps=False, - ignore_first_depends_on_past=True, - pool=None, - delay_on_limit_secs=1.0, - verbose=False, - conf=None, - rerun_failed_tasks=False, - run_backwards=False, + self, + start_date=None, + end_date=None, + mark_success=False, + local=False, + executor=None, + donot_pickle=conf.getboolean('core', 'donot_pickle'), + ignore_task_deps=False, + ignore_first_depends_on_past=True, + pool=None, + delay_on_limit_secs=1.0, + verbose=False, + conf=None, + rerun_failed_tasks=False, + run_backwards=False, ): """ Runs the DAG. 
@@ -1654,11 +1668,14 @@ def run( """ from airflow.jobs.backfill_job import BackfillJob + if not executor and local: from airflow.executors.local_executor import LocalExecutor + executor = LocalExecutor() elif not executor: from airflow.executors.executor_loader import ExecutorLoader + executor = ExecutorLoader.get_default_executor() job = BackfillJob( self, @@ -1681,6 +1698,7 @@ def run( def cli(self): """Exposes a CLI specific to this DAG""" from airflow.cli import cli_parser + parser = cli_parser.get_parser(dag_parser=True) args = parser.parse_args() args.func(args, self) @@ -1747,7 +1765,7 @@ def create_dagrun( state=state, run_type=run_type, dag_hash=dag_hash, - creating_job_id=creating_job_id + creating_job_id=creating_job_id, ) session.add(run) session.flush() @@ -1791,10 +1809,10 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None): dag_by_ids = {dag.dag_id: dag for dag in dags} dag_ids = set(dag_by_ids.keys()) query = ( - session - .query(DagModel) + session.query(DagModel) .options(joinedload(DagModel.tags, innerjoin=False)) - .filter(DagModel.dag_id.in_(dag_ids))) + .filter(DagModel.dag_id.in_(dag_ids)) + ) orm_dags = with_row_locks(query, of=DagModel).all() existing_dag_ids = {orm_dag.dag_id for orm_dag in orm_dags} @@ -1811,21 +1829,31 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None): orm_dags.append(orm_dag) # Get the latest dag run for each existing dag as a single query (avoid n+1 query) - most_recent_dag_runs = dict(session.query(DagRun.dag_id, func.max_(DagRun.execution_date)).filter( - DagRun.dag_id.in_(existing_dag_ids), - or_( - DagRun.run_type == DagRunType.BACKFILL_JOB, - DagRun.run_type == DagRunType.SCHEDULED, - DagRun.external_trigger.is_(True), - ), - ).group_by(DagRun.dag_id).all()) + most_recent_dag_runs = dict( + session.query(DagRun.dag_id, func.max_(DagRun.execution_date)) + .filter( + DagRun.dag_id.in_(existing_dag_ids), + or_( + DagRun.run_type == DagRunType.BACKFILL_JOB, + DagRun.run_type == DagRunType.SCHEDULED, + DagRun.external_trigger.is_(True), + ), + ) + .group_by(DagRun.dag_id) + .all() + ) # Get number of active dagruns for all dags we are processing as a single query. 
- num_active_runs = dict(session.query(DagRun.dag_id, func.count('*')).filter( - DagRun.dag_id.in_(existing_dag_ids), - DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable - DagRun.external_trigger.is_(False) - ).group_by(DagRun.dag_id).all()) + num_active_runs = dict( + session.query(DagRun.dag_id, func.count('*')) + .filter( + DagRun.dag_id.in_(existing_dag_ids), + DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable + DagRun.external_trigger.is_(False), + ) + .group_by(DagRun.dag_id) + .all() + ) for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id): dag = dag_by_ids[orm_dag.dag_id] @@ -1843,9 +1871,7 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None): orm_dag.description = dag.description orm_dag.schedule_interval = dag.schedule_interval orm_dag.concurrency = dag.concurrency - orm_dag.has_task_concurrency_limits = any( - t.task_concurrency is not None for t in dag.tasks - ) + orm_dag.has_task_concurrency_limits = any(t.task_concurrency is not None for t in dag.tasks) orm_dag.calculate_dagrun_date_fields( dag, @@ -1906,8 +1932,7 @@ def deactivate_unknown_dags(active_dag_ids, session=None): """ if len(active_dag_ids) == 0: return - for dag in session.query( - DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all(): + for dag in session.query(DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all(): dag.is_active = False session.merge(dag) session.commit() @@ -1924,12 +1949,15 @@ def deactivate_stale_dags(expiration_date, session=None): :type expiration_date: datetime :return: None """ - for dag in session.query( - DagModel).filter(DagModel.last_scheduler_run < expiration_date, - DagModel.is_active).all(): + for dag in ( + session.query(DagModel) + .filter(DagModel.last_scheduler_run < expiration_date, DagModel.is_active) + .all() + ): log.info( "Deactivating DAG ID %s since it was last touched by the scheduler at %s", - dag.dag_id, dag.last_scheduler_run.isoformat() + dag.dag_id, + dag.last_scheduler_run.isoformat(), ) dag.is_active = False session.merge(dag) @@ -1965,9 +1993,9 @@ def get_num_task_instances(dag_id, task_ids=None, states=None, session=None): qry = qry.filter(TaskInstance.state.is_(None)) else: not_none_states = [state for state in states if state] - qry = qry.filter(or_( - TaskInstance.state.in_(not_none_states), - TaskInstance.state.is_(None))) + qry = qry.filter( + or_(TaskInstance.state.in_(not_none_states), TaskInstance.state.is_(None)) + ) else: qry = qry.filter(TaskInstance.state.in_(states)) return qry.scalar() @@ -1977,12 +2005,25 @@ def get_serialized_fields(cls): """Stringified DAGs and operators contain exactly these fields.""" if not cls.__serialized_fields: cls.__serialized_fields = frozenset(vars(DAG(dag_id='test')).keys()) - { - 'parent_dag', '_old_context_manager_dags', 'safe_dag_id', 'last_loaded', - '_full_filepath', 'user_defined_filters', 'user_defined_macros', - 'partial', '_old_context_manager_dags', - '_pickle_id', '_log', 'is_subdag', 'task_dict', 'template_searchpath', - 'sla_miss_callback', 'on_success_callback', 'on_failure_callback', - 'template_undefined', 'jinja_environment_kwargs' + 'parent_dag', + '_old_context_manager_dags', + 'safe_dag_id', + 'last_loaded', + '_full_filepath', + 'user_defined_filters', + 'user_defined_macros', + 'partial', + '_old_context_manager_dags', + '_pickle_id', + '_log', + 'is_subdag', + 'task_dict', + 'template_searchpath', + 'sla_miss_callback', + 'on_success_callback', + 'on_failure_callback', + 'template_undefined', + 
'jinja_environment_kwargs', } return cls.__serialized_fields @@ -2009,9 +2050,7 @@ class DagModel(Base): root_dag_id = Column(String(ID_LEN)) # A DAG can be paused from the UI / DB # Set this default value of is_paused based on a configuration value! - is_paused_at_creation = conf\ - .getboolean('core', - 'dags_are_paused_at_creation') + is_paused_at_creation = conf.getboolean('core', 'dags_are_paused_at_creation') is_paused = Column(Boolean, default=is_paused_at_creation) # Whether the DAG is a subdag is_subdag = Column(Boolean, default=False) @@ -2058,11 +2097,7 @@ class DagModel(Base): Index('idx_next_dagrun_create_after', next_dagrun_create_after, unique=False), ) - NUM_DAGS_PER_DAGRUN_QUERY = conf.getint( - 'scheduler', - 'max_dagruns_to_create_per_loop', - fallback=10 - ) + NUM_DAGS_PER_DAGRUN_QUERY = conf.getint('scheduler', 'max_dagruns_to_create_per_loop', fallback=10) def __init__(self, **kwargs): super().__init__(**kwargs) @@ -2091,8 +2126,9 @@ def get_current(cls, dag_id, session=None): @provide_session def get_last_dagrun(self, session=None, include_externally_triggered=False): - return get_last_dagrun(self.dag_id, session=session, - include_externally_triggered=include_externally_triggered) + return get_last_dagrun( + self.dag_id, session=session, include_externally_triggered=include_externally_triggered + ) @staticmethod @provide_session @@ -2127,10 +2163,7 @@ def safe_dag_id(self): return self.dag_id.replace('.', '__dot__') @provide_session - def set_is_paused(self, - is_paused: bool, - including_subdags: bool = True, - session=None) -> None: + def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session=None) -> None: """ Pause/Un-pause a DAG. @@ -2142,12 +2175,8 @@ def set_is_paused(self, DagModel.dag_id == self.dag_id, ] if including_subdags: - filter_query.append( - DagModel.root_dag_id == self.dag_id - ) - session.query(DagModel).filter(or_( - *filter_query - )).update( + filter_query.append(DagModel.root_dag_id == self.dag_id) + session.query(DagModel).filter(or_(*filter_query)).update( {DagModel.is_paused: is_paused}, synchronize_session='fetch' ) session.commit() @@ -2162,8 +2191,7 @@ def deactivate_deleted_dags(cls, alive_dag_filelocs: List[str], session=None): :param alive_dag_filelocs: file paths of alive DAGs :param session: ORM Session """ - log.debug("Deactivating DAGs (for which DAG files are deleted) from %s table ", - cls.__tablename__) + log.debug("Deactivating DAGs (for which DAG files are deleted) from %s table ", cls.__tablename__) dag_models = session.query(cls).all() try: @@ -2195,21 +2223,21 @@ def dags_needing_dagruns(cls, session: Session): # TODO[HA]: Bake this query, it is run _A lot_ # We limit so that _one_ scheduler doesn't try to do all the creation # of dag runs - query = session.query(cls).filter( - cls.is_paused.is_(False), - cls.is_active.is_(True), - cls.next_dagrun_create_after <= func.now(), - ).order_by( - cls.next_dagrun_create_after - ).limit(cls.NUM_DAGS_PER_DAGRUN_QUERY) + query = ( + session.query(cls) + .filter( + cls.is_paused.is_(False), + cls.is_active.is_(True), + cls.next_dagrun_create_after <= func.now(), + ) + .order_by(cls.next_dagrun_create_after) + .limit(cls.NUM_DAGS_PER_DAGRUN_QUERY) + ) return with_row_locks(query, of=cls, **skip_locked(session=session)) def calculate_dagrun_date_fields( - self, - dag: DAG, - most_recent_dag_run: Optional[pendulum.DateTime], - active_runs_of_dag: int + self, dag: DAG, most_recent_dag_run: Optional[pendulum.DateTime], active_runs_of_dag: int ) -> None: """ 
Calculate ``next_dagrun`` and `next_dagrun_create_after`` @@ -2224,7 +2252,9 @@ def calculate_dagrun_date_fields( # Since this happens every time the dag is parsed it would be quite spammy at info log.debug( "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs", - dag.dag_id, active_runs_of_dag, dag.max_active_runs + dag.dag_id, + active_runs_of_dag, + dag.max_active_runs, ) self.next_dagrun_create_after = None @@ -2241,6 +2271,7 @@ def dag(*dag_args, **dag_kwargs): :param dag_kwargs: Kwargs for DAG object. :type dag_kwargs: dict """ + def wrapper(f: Callable): # Get dag initializer signature and bind it to validate that dag_args, and dag_kwargs are correct dag_sig = signature(DAG.__init__) @@ -2276,7 +2307,9 @@ def factory(*args, **kwargs): # Return dag object such that it's accessible in Globals. return dag_obj + return factory + return wrapper @@ -2287,6 +2320,7 @@ def factory(*args, **kwargs): from sqlalchemy.orm import relationship from airflow.models.serialized_dag import SerializedDagModel + DagModel.serialized_dag = relationship(SerializedDagModel) diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 7f3bab2e9e587..004d6f7c22551 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -97,13 +97,16 @@ def __init__( ): # Avoid circular import from airflow.models.dag import DAG + super().__init__() if store_serialized_dags: warnings.warn( "The store_serialized_dags parameter has been deprecated. " "You should pass the read_dags_from_db parameter.", - DeprecationWarning, stacklevel=2) + DeprecationWarning, + stacklevel=2, + ) read_dags_from_db = store_serialized_dags dag_folder = dag_folder or settings.DAGS_FOLDER @@ -125,7 +128,8 @@ def __init__( dag_folder=dag_folder, include_examples=include_examples, include_smart_sensor=include_smart_sensor, - safe_mode=safe_mode) + safe_mode=safe_mode, + ) def size(self) -> int: """:return: the amount of dags contained in this dagbag""" @@ -135,8 +139,9 @@ def size(self) -> int: def store_serialized_dags(self) -> bool: """Whether or not to read dags from DB""" warnings.warn( - "The store_serialized_dags property has been deprecated. " - "Use read_dags_from_db instead.", DeprecationWarning, stacklevel=2 + "The store_serialized_dags property has been deprecated. " "Use read_dags_from_db instead.", + DeprecationWarning, + stacklevel=2, ) return self.read_dags_from_db @@ -158,6 +163,7 @@ def get_dag(self, dag_id, session: Session = None): if self.read_dags_from_db: # Import here so that serialized dag is only imported when serialization is enabled from airflow.models.serialized_dag import SerializedDagModel + if dag_id not in self.dags: # Load from DB if not (yet) in the bag self._add_dag_from_db(dag_id=dag_id, session=session) @@ -169,8 +175,8 @@ def get_dag(self, dag_id, session: Session = None): # 3. if (2) is yes, fetch the Serialized DAG. 
min_serialized_dag_fetch_secs = timedelta(seconds=settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL) if ( - dag_id in self.dags_last_fetched and - timezone.utcnow() > self.dags_last_fetched[dag_id] + min_serialized_dag_fetch_secs + dag_id in self.dags_last_fetched + and timezone.utcnow() > self.dags_last_fetched[dag_id] + min_serialized_dag_fetch_secs ): sd_last_updated_datetime = SerializedDagModel.get_last_updated_datetime( dag_id=dag_id, @@ -196,11 +202,12 @@ def get_dag(self, dag_id, session: Session = None): # If the dag corresponding to root_dag_id is absent or expired is_missing = root_dag_id not in self.dags - is_expired = (orm_dag.last_expired and dag and dag.last_loaded < orm_dag.last_expired) + is_expired = orm_dag.last_expired and dag and dag.last_loaded < orm_dag.last_expired if is_missing or is_expired: # Reprocess source file found_dags = self.process_file( - filepath=correct_maybe_zipped(orm_dag.fileloc), only_if_updated=False) + filepath=correct_maybe_zipped(orm_dag.fileloc), only_if_updated=False + ) # If the source file no longer exports `dag_id`, delete it from self.dags if found_dags and dag_id in [found_dag.dag_id for found_dag in found_dags]: @@ -212,6 +219,7 @@ def get_dag(self, dag_id, session: Session = None): def _add_dag_from_db(self, dag_id: str, session: Session): """Add DAG to DagBag from DB""" from airflow.models.serialized_dag import SerializedDagModel + row = SerializedDagModel.get(dag_id, session) if not row: raise SerializedDagNotFound(f"DAG '{dag_id}' not found in serialized_dag table") @@ -240,9 +248,11 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): # This failed before in what may have been a git sync # race condition file_last_changed_on_disk = datetime.fromtimestamp(os.path.getmtime(filepath)) - if only_if_updated \ - and filepath in self.file_last_changed \ - and file_last_changed_on_disk == self.file_last_changed[filepath]: + if ( + only_if_updated + and filepath in self.file_last_changed + and file_last_changed_on_disk == self.file_last_changed[filepath] + ): return [] except Exception as e: # pylint: disable=broad-except self.log.exception(e) @@ -315,8 +325,7 @@ def _load_modules_from_zip(self, filepath, safe_mode): if not self.has_logged or True: self.has_logged = True self.log.info( - "File %s:%s assumed to contain no DAGs. Skipping.", - filepath, zip_info.filename + "File %s:%s assumed to contain no DAGs. 
Skipping.", filepath, zip_info.filename ) continue @@ -341,12 +350,7 @@ def _process_modules(self, filepath, mods, file_last_changed_on_disk): from airflow.models.dag import DAG # Avoid circular import is_zipfile = zipfile.is_zipfile(filepath) - top_level_dags = [ - o - for m in mods - for o in list(m.__dict__.values()) - if isinstance(o, DAG) - ] + top_level_dags = [o for m in mods for o in list(m.__dict__.values()) if isinstance(o, DAG)] found_dags = [] @@ -362,15 +366,11 @@ def _process_modules(self, filepath, mods, file_last_changed_on_disk): self.bag_dag(dag=dag, root_dag=dag) found_dags.append(dag) found_dags += dag.subdags - except (CroniterBadCronError, - CroniterBadDateError, - CroniterNotAlphaError) as cron_e: + except (CroniterBadCronError, CroniterBadDateError, CroniterNotAlphaError) as cron_e: self.log.exception("Failed to bag_dag: %s", dag.full_filepath) self.import_errors[dag.full_filepath] = f"Invalid Cron expression: {cron_e}" - self.file_last_changed[dag.full_filepath] = \ - file_last_changed_on_disk - except (AirflowDagCycleException, - AirflowClusterPolicyViolation) as exception: + self.file_last_changed[dag.full_filepath] = file_last_changed_on_disk + except (AirflowDagCycleException, AirflowClusterPolicyViolation) as exception: self.log.exception("Failed to bag_dag: %s", dag.full_filepath) self.import_errors[dag.full_filepath] = str(exception) self.file_last_changed[dag.full_filepath] = file_last_changed_on_disk @@ -412,12 +412,13 @@ def bag_dag(self, dag, root_dag): raise cycle_exception def collect_dags( - self, - dag_folder=None, - only_if_updated=True, - include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'), - include_smart_sensor=conf.getboolean('smart_sensor', 'USE_SMART_SENSOR'), - safe_mode=conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE')): + self, + dag_folder=None, + only_if_updated=True, + include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'), + include_smart_sensor=conf.getboolean('smart_sensor', 'USE_SMART_SENSOR'), + safe_mode=conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'), + ): """ Given a file path or a folder, this method looks for python modules, imports them and adds them to the dagbag collection. 
@@ -440,26 +441,26 @@ def collect_dags( stats = [] dag_folder = correct_maybe_zipped(dag_folder) - for filepath in list_py_file_paths(dag_folder, - safe_mode=safe_mode, - include_examples=include_examples, - include_smart_sensor=include_smart_sensor): + for filepath in list_py_file_paths( + dag_folder, + safe_mode=safe_mode, + include_examples=include_examples, + include_smart_sensor=include_smart_sensor, + ): try: file_parse_start_dttm = timezone.utcnow() - found_dags = self.process_file( - filepath, - only_if_updated=only_if_updated, - safe_mode=safe_mode - ) + found_dags = self.process_file(filepath, only_if_updated=only_if_updated, safe_mode=safe_mode) file_parse_end_dttm = timezone.utcnow() - stats.append(FileLoadStat( - file=filepath.replace(settings.DAGS_FOLDER, ''), - duration=file_parse_end_dttm - file_parse_start_dttm, - dag_num=len(found_dags), - task_num=sum([len(dag.tasks) for dag in found_dags]), - dags=str([dag.dag_id for dag in found_dags]), - )) + stats.append( + FileLoadStat( + file=filepath.replace(settings.DAGS_FOLDER, ''), + duration=file_parse_end_dttm - file_parse_start_dttm, + dag_num=len(found_dags), + task_num=sum([len(dag.tasks) for dag in found_dags]), + dags=str([dag.dag_id for dag in found_dags]), + ) + ) except Exception as e: # pylint: disable=broad-except self.log.exception(e) @@ -468,19 +469,17 @@ def collect_dags( Stats.gauge('collect_dags', durations, 1) Stats.gauge('dagbag_size', len(self.dags), 1) Stats.gauge('dagbag_import_errors', len(self.import_errors), 1) - self.dagbag_stats = sorted( - stats, key=lambda x: x.duration, reverse=True) + self.dagbag_stats = sorted(stats, key=lambda x: x.duration, reverse=True) for file_stat in self.dagbag_stats: # file_stat.file similar format: /subdir/dag_name.py # TODO: Remove for Airflow 2.0 filename = file_stat.file.split('/')[-1].replace('.py', '') - Stats.timing('dag.loading-duration.{}'. - format(filename), - file_stat.duration) + Stats.timing(f'dag.loading-duration.{filename}', file_stat.duration) def collect_dags_from_db(self): """Collects DAGs from database.""" from airflow.models.serialized_dag import SerializedDagModel + start_dttm = timezone.utcnow() self.log.info("Filling up the DagBag from database") @@ -508,7 +507,8 @@ def dagbag_report(self): task_num = sum([o.task_num for o in stats]) table = tabulate(stats, headers="keys") - report = textwrap.dedent(f"""\n + report = textwrap.dedent( + f"""\n ------------------------------------------------------------------- DagBag loading stats for {dag_folder} ------------------------------------------------------------------- @@ -516,7 +516,8 @@ def dagbag_report(self): Total task number: {task_num} DagBag parsing time: {duration} {table} - """) + """ + ) return report @provide_session @@ -534,13 +535,13 @@ def sync_to_db(self, session: Optional[Session] = None): wait=tenacity.wait_random_exponential(multiplier=0.5, max=5), stop=tenacity.stop_after_attempt(settings.MAX_DB_RETRIES), before_sleep=tenacity.before_sleep_log(self.log, logging.DEBUG), - reraise=True + reraise=True, ): with attempt: self.log.debug( "Running dagbag.sync_to_db with retries. 
Try %d of %d", attempt.retry_state.attempt_number, - settings.MAX_DB_RETRIES + settings.MAX_DB_RETRIES, ) self.log.debug("Calling the DAG.bulk_sync_to_db method") try: diff --git a/airflow/models/dagcode.py b/airflow/models/dagcode.py index 5194c3dc7e5b1..67e12b7e2a196 100644 --- a/airflow/models/dagcode.py +++ b/airflow/models/dagcode.py @@ -46,8 +46,7 @@ class DagCode(Base): __tablename__ = 'dag_code' - fileloc_hash = Column( - BigInteger, nullable=False, primary_key=True, autoincrement=False) + fileloc_hash = Column(BigInteger, nullable=False, primary_key=True, autoincrement=False) fileloc = Column(String(2000), nullable=False) # The max length of fileloc exceeds the limit of indexing. last_updated = Column(UtcDateTime, nullable=False) @@ -76,12 +75,9 @@ def bulk_sync_to_db(cls, filelocs: Iterable[str], session=None): :param session: ORM Session """ filelocs = set(filelocs) - filelocs_to_hashes = { - fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs - } + filelocs_to_hashes = {fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs} existing_orm_dag_codes = ( - session - .query(DagCode) + session.query(DagCode) .filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values())) .with_for_update(of=DagCode) .all() @@ -94,29 +90,20 @@ def bulk_sync_to_db(cls, filelocs: Iterable[str], session=None): else: existing_orm_dag_codes_map = {} - existing_orm_dag_codes_by_fileloc_hashes = { - orm.fileloc_hash: orm for orm in existing_orm_dag_codes - } - existing_orm_filelocs = { - orm.fileloc for orm in existing_orm_dag_codes_by_fileloc_hashes.values() - } + existing_orm_dag_codes_by_fileloc_hashes = {orm.fileloc_hash: orm for orm in existing_orm_dag_codes} + existing_orm_filelocs = {orm.fileloc for orm in existing_orm_dag_codes_by_fileloc_hashes.values()} if not existing_orm_filelocs.issubset(filelocs): conflicting_filelocs = existing_orm_filelocs.difference(filelocs) - hashes_to_filelocs = { - DagCode.dag_fileloc_hash(fileloc): fileloc for fileloc in filelocs - } + hashes_to_filelocs = {DagCode.dag_fileloc_hash(fileloc): fileloc for fileloc in filelocs} message = "" for fileloc in conflicting_filelocs: - message += ("Filename '{}' causes a hash collision in the " + - "database with '{}'. Please rename the file.")\ - .format( - hashes_to_filelocs[DagCode.dag_fileloc_hash(fileloc)], - fileloc) + message += ( + "Filename '{}' causes a hash collision in the " + + "database with '{}'. Please rename the file." 
+ ).format(hashes_to_filelocs[DagCode.dag_fileloc_hash(fileloc)], fileloc) raise AirflowException(message) - existing_filelocs = { - dag_code.fileloc for dag_code in existing_orm_dag_codes - } + existing_filelocs = {dag_code.fileloc for dag_code in existing_orm_dag_codes} missing_filelocs = filelocs.difference(existing_filelocs) for fileloc in missing_filelocs: @@ -143,14 +130,13 @@ def remove_deleted_code(cls, alive_dag_filelocs: List[str], session=None): :param alive_dag_filelocs: file paths of alive DAGs :param session: ORM Session """ - alive_fileloc_hashes = [ - cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs] + alive_fileloc_hashes = [cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs] log.debug("Deleting code from %s table ", cls.__tablename__) session.query(cls).filter( - cls.fileloc_hash.notin_(alive_fileloc_hashes), - cls.fileloc.notin_(alive_dag_filelocs)).delete(synchronize_session='fetch') + cls.fileloc_hash.notin_(alive_fileloc_hashes), cls.fileloc.notin_(alive_dag_filelocs) + ).delete(synchronize_session='fetch') @classmethod @provide_session @@ -161,8 +147,7 @@ def has_dag(cls, fileloc: str, session=None) -> bool: :param session: ORM Session """ fileloc_hash = cls.dag_fileloc_hash(fileloc) - return session.query(exists().where(cls.fileloc_hash == fileloc_hash))\ - .scalar() + return session.query(exists().where(cls.fileloc_hash == fileloc_hash)).scalar() @classmethod def get_code_by_fileloc(cls, fileloc: str) -> str: @@ -193,9 +178,7 @@ def _get_code_from_file(fileloc): @classmethod @provide_session def _get_code_from_db(cls, fileloc, session=None): - dag_code = session.query(cls) \ - .filter(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)) \ - .first() + dag_code = session.query(cls).filter(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)).first() if not dag_code: raise DagCodeNotFound() else: @@ -214,5 +197,4 @@ def dag_fileloc_hash(full_filepath: str) -> int: import hashlib # Only 7 bytes because MySQL BigInteger can hold only 8 bytes (signed). 
- return struct.unpack('>Q', hashlib.sha1( - full_filepath.encode('utf-8')).digest()[-8:])[0] >> 8 + return struct.unpack('>Q', hashlib.sha1(full_filepath.encode('utf-8')).digest()[-8:])[0] >> 8 diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 6ae17d143a6b7..ad03bad1033c2 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -19,7 +19,17 @@ from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, Union from sqlalchemy import ( - Boolean, Column, DateTime, Index, Integer, PickleType, String, UniqueConstraint, and_, func, or_, + Boolean, + Column, + DateTime, + Index, + Integer, + PickleType, + String, + UniqueConstraint, + and_, + func, + or_, ) from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.declarative import declared_attr @@ -125,13 +135,13 @@ def __init__( def __repr__(self): return ( - '' + '' ).format( dag_id=self.dag_id, execution_date=self.execution_date, run_id=self.run_id, - external_trigger=self.external_trigger) + external_trigger=self.external_trigger, + ) def get_state(self): return self._state @@ -157,11 +167,15 @@ def refresh_from_db(self, session: Session = None): exec_date = func.cast(self.execution_date, DateTime) - dr = session.query(DR).filter( - DR.dag_id == self.dag_id, - func.cast(DR.execution_date, DateTime) == exec_date, - DR.run_id == self.run_id - ).one() + dr = ( + session.query(DR) + .filter( + DR.dag_id == self.dag_id, + func.cast(DR.execution_date, DateTime) == exec_date, + DR.run_id == self.run_id, + ) + .one() + ) self.id = dr.id self.state = dr.state @@ -187,18 +201,21 @@ def next_dagruns_to_examine( max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE # TODO: Bake this query, it is run _A lot_ - query = session.query(cls).filter( - cls.state == State.RUNNING, - cls.run_type != DagRunType.BACKFILL_JOB - ).join( - DagModel, - DagModel.dag_id == cls.dag_id, - ).filter( - DagModel.is_paused.is_(False), - DagModel.is_active.is_(True), - ).order_by( - nulls_first(cls.last_scheduling_decision, session=session), - cls.execution_date, + query = ( + session.query(cls) + .filter(cls.state == State.RUNNING, cls.run_type != DagRunType.BACKFILL_JOB) + .join( + DagModel, + DagModel.dag_id == cls.dag_id, + ) + .filter( + DagModel.is_paused.is_(False), + DagModel.is_active.is_(True), + ) + .order_by( + nulls_first(cls.last_scheduling_decision, session=session), + cls.execution_date, + ) ) if not settings.ALLOW_FUTURE_EXEC_DATES: @@ -218,7 +235,7 @@ def find( run_type: Optional[DagRunType] = None, session: Session = None, execution_start_date: Optional[datetime] = None, - execution_end_date: Optional[datetime] = None + execution_end_date: Optional[datetime] = None, ) -> List["DagRun"]: """ Returns a set of dag runs for the given search criteria. 
@@ -300,10 +317,7 @@ def get_task_instances(self, state=None, session=None): tis = tis.filter(TI.state.is_(None)) else: not_none_state = [s for s in state if s] - tis = tis.filter( - or_(TI.state.in_(not_none_state), - TI.state.is_(None)) - ) + tis = tis.filter(or_(TI.state.in_(not_none_state), TI.state.is_(None))) else: tis = tis.filter(TI.state.in_(state)) @@ -321,11 +335,11 @@ def get_task_instance(self, task_id: str, session: Session = None): :param session: Sqlalchemy ORM Session :type session: Session """ - ti = session.query(TI).filter( - TI.dag_id == self.dag_id, - TI.execution_date == self.execution_date, - TI.task_id == task_id - ).first() + ti = ( + session.query(TI) + .filter(TI.dag_id == self.dag_id, TI.execution_date == self.execution_date, TI.task_id == task_id) + .first() + ) return ti @@ -349,27 +363,25 @@ def get_previous_dagrun(self, state: Optional[str] = None, session: Session = No ] if state is not None: filters.append(DagRun.state == state) - return session.query(DagRun).filter( - *filters - ).order_by( - DagRun.execution_date.desc() - ).first() + return session.query(DagRun).filter(*filters).order_by(DagRun.execution_date.desc()).first() @provide_session def get_previous_scheduled_dagrun(self, session: Session = None): """The previous, SCHEDULED DagRun, if there is one""" dag = self.get_dag() - return session.query(DagRun).filter( - DagRun.dag_id == self.dag_id, - DagRun.execution_date == dag.previous_schedule(self.execution_date) - ).first() + return ( + session.query(DagRun) + .filter( + DagRun.dag_id == self.dag_id, + DagRun.execution_date == dag.previous_schedule(self.execution_date), + ) + .first() + ) @provide_session def update_state( - self, - session: Session = None, - execute_callbacks: bool = True + self, session: Session = None, execute_callbacks: bool = True ) -> Tuple[List[TI], Optional[callback_requests.DagCallbackRequest]]: """ Determines the overall state of the DagRun based on the state @@ -403,10 +415,13 @@ def update_state( if unfinished_tasks and none_depends_on_past and none_task_concurrency: # small speed up - are_runnable_tasks = schedulable_tis or self._are_premature_tis( - unfinished_tasks, finished_tasks, session) or changed_tis + are_runnable_tasks = ( + schedulable_tis + or self._are_premature_tis(unfinished_tasks, finished_tasks, session) + or changed_tis + ) - duration = (timezone.utcnow() - start_dttm) + duration = timezone.utcnow() - start_dttm Stats.timing(f"dagrun.dependency-check.{self.dag_id}", duration) leaf_task_ids = {t.task_id for t in dag.leaves} @@ -426,7 +441,7 @@ def update_state( dag_id=self.dag_id, execution_date=self.execution_date, is_failure_callback=True, - msg='task_failure' + msg='task_failure', ) # if all leafs succeeded and no unfinished tasks, the run succeeded @@ -443,12 +458,11 @@ def update_state( dag_id=self.dag_id, execution_date=self.execution_date, is_failure_callback=False, - msg='success' + msg='success', ) # if *all tasks* are deadlocked, the run failed - elif (unfinished_tasks and none_depends_on_past and - none_task_concurrency and not are_runnable_tasks): + elif unfinished_tasks and none_depends_on_past and none_task_concurrency and not are_runnable_tasks: self.log.error('Deadlock; marking run %s failed', self) self.set_state(State.FAILED) if execute_callbacks: @@ -459,7 +473,7 @@ def update_state( dag_id=self.dag_id, execution_date=self.execution_date, is_failure_callback=True, - msg='all_tasks_deadlocked' + msg='all_tasks_deadlocked', ) # finally, if the roots aren't done, the dag is still 
running @@ -487,9 +501,7 @@ def task_instance_scheduling_decisions(self, session: Session = None) -> TISched finished_tasks = [t for t in tis if t.state in State.finished] if unfinished_tasks: scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES] - self.log.debug( - "number of scheduleable tasks for %s: %s task(s)", - self, len(scheduleable_tasks)) + self.log.debug("number of scheduleable tasks for %s: %s task(s)", self, len(scheduleable_tasks)) schedulable_tis, changed_tis = self._get_ready_tis(scheduleable_tasks, finished_tasks, session) return TISchedulingDecision( @@ -517,10 +529,9 @@ def _get_ready_tis( for st in scheduleable_tasks: old_state = st.state if st.are_dependencies_met( - dep_context=DepContext( - flag_upstream_failed=True, - finished_tasks=finished_tasks), - session=session): + dep_context=DepContext(flag_upstream_failed=True, finished_tasks=finished_tasks), + session=session, + ): ready_tis.append(st) else: old_states[st.key] = old_state @@ -547,8 +558,10 @@ def _are_premature_tis( flag_upstream_failed=True, ignore_in_retry_period=True, ignore_in_reschedule_period=True, - finished_tasks=finished_tasks), - session=session): + finished_tasks=finished_tasks, + ), + session=session, + ): return True return False @@ -556,7 +569,7 @@ def _emit_duration_stats_for_finished_state(self): if self.state == State.RUNNING: return - duration = (self.end_date - self.start_date) + duration = self.end_date - self.start_date if self.state is State.SUCCESS: Stats.timing(f'dagrun.duration.success.{self.dag_id}', duration) elif self.state == State.FAILED: @@ -586,16 +599,15 @@ def verify_integrity(self, session: Session = None): if ti.state == State.REMOVED: pass # ti has already been removed, just ignore it elif self.state is not State.RUNNING and not dag.partial: - self.log.warning("Failed to get task '%s' for dag '%s'. " - "Marking it as removed.", ti, dag) - Stats.incr( - f"task_removed_from_dag.{dag.dag_id}", 1, 1) + self.log.warning( + "Failed to get task '%s' for dag '%s'. " "Marking it as removed.", ti, dag + ) + Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1) ti.state = State.REMOVED should_restore_task = (task is not None) and ti.state == State.REMOVED if should_restore_task: - self.log.info("Restoring task '%s' which was previously " - "removed from DAG '%s'", ti, dag) + self.log.info("Restoring task '%s' which was previously " "removed from DAG '%s'", ti, dag) Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1) ti.state = State.NONE session.merge(ti) @@ -606,9 +618,7 @@ def verify_integrity(self, session: Session = None): continue if task.task_id not in task_ids: - Stats.incr( - f"task_instance_created-{task.task_type}", - 1, 1) + Stats.incr(f"task_instance_created-{task.task_type}", 1, 1) ti = TI(task, self.execution_date) task_instance_mutation_hook(ti) session.add(ti) @@ -617,8 +627,9 @@ def verify_integrity(self, session: Session = None): session.flush() except IntegrityError as err: self.log.info(str(err)) - self.log.info('Hit IntegrityError while creating the TIs for ' - f'{dag.dag_id} - {self.execution_date}.') + self.log.info( + 'Hit IntegrityError while creating the TIs for ' f'{dag.dag_id} - {self.execution_date}.' + ) self.log.info('Doing session rollback.') # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive. 
session.rollback() @@ -654,19 +665,16 @@ def is_backfill(self): def get_latest_runs(cls, session=None): """Returns the latest DagRun for each DAG""" subquery = ( - session - .query( - cls.dag_id, - func.max(cls.execution_date).label('execution_date')) + session.query(cls.dag_id, func.max(cls.execution_date).label('execution_date')) .group_by(cls.dag_id) .subquery() ) dagruns = ( - session - .query(cls) - .join(subquery, - and_(cls.dag_id == subquery.c.dag_id, - cls.execution_date == subquery.c.execution_date)) + session.query(cls) + .join( + subquery, + and_(cls.dag_id == subquery.c.dag_id, cls.execution_date == subquery.c.execution_date), + ) .all() ) return dagruns @@ -686,9 +694,9 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = None) - # Get list of TIs that do not need to executed, these are # tasks using DummyOperator and without on_execute_callback / on_success_callback dummy_tis = { - ti for ti in schedulable_tis - if - ( + ti + for ti in schedulable_tis + if ( ti.task.task_type == "DummyOperator" and not ti.task.on_execute_callback and not ti.task.on_success_callback @@ -699,23 +707,34 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = None) - count = 0 if schedulable_ti_ids: - count += session.query(TI).filter( - TI.dag_id == self.dag_id, - TI.execution_date == self.execution_date, - TI.task_id.in_(schedulable_ti_ids) - ).update({TI.state: State.SCHEDULED}, synchronize_session=False) + count += ( + session.query(TI) + .filter( + TI.dag_id == self.dag_id, + TI.execution_date == self.execution_date, + TI.task_id.in_(schedulable_ti_ids), + ) + .update({TI.state: State.SCHEDULED}, synchronize_session=False) + ) # Tasks using DummyOperator should not be executed, mark them as success if dummy_tis: - count += session.query(TI).filter( - TI.dag_id == self.dag_id, - TI.execution_date == self.execution_date, - TI.task_id.in_(ti.task_id for ti in dummy_tis) - ).update({ - TI.state: State.SUCCESS, - TI.start_date: timezone.utcnow(), - TI.end_date: timezone.utcnow(), - TI.duration: 0 - }, synchronize_session=False) + count += ( + session.query(TI) + .filter( + TI.dag_id == self.dag_id, + TI.execution_date == self.execution_date, + TI.task_id.in_(ti.task_id for ti in dummy_tis), + ) + .update( + { + TI.state: State.SUCCESS, + TI.start_date: timezone.utcnow(), + TI.end_date: timezone.utcnow(), + TI.duration: 0, + }, + synchronize_session=False, + ) + ) return count diff --git a/airflow/models/log.py b/airflow/models/log.py index 8f12c63f514ad..7842f63483360 100644 --- a/airflow/models/log.py +++ b/airflow/models/log.py @@ -37,9 +37,7 @@ class Log(Base): owner = Column(String(500)) extra = Column(Text) - __table_args__ = ( - Index('idx_log_dag', dag_id), - ) + __table_args__ = (Index('idx_log_dag', dag_id),) def __init__(self, event, task_instance, owner=None, extra=None, **kwargs): self.dttm = timezone.utcnow() diff --git a/airflow/models/pool.py b/airflow/models/pool.py index 541a464c076cb..67cdee52098d2 100644 --- a/airflow/models/pool.py +++ b/airflow/models/pool.py @@ -53,7 +53,7 @@ class Pool(Base): DEFAULT_POOL_NAME = 'default_pool' def __repr__(self): - return str(self.pool) # pylint: disable=E0012 + return str(self.pool) # pylint: disable=E0012 @staticmethod @provide_session @@ -126,9 +126,7 @@ def slots_stats( elif state == "queued": stats_dict["queued"] = count else: - raise AirflowException( - f"Unexpected state. Expected values: {EXECUTION_STATES}." - ) + raise AirflowException(f"Unexpected state. 
Expected values: {EXECUTION_STATES}.") # calculate open metric for pool_name, stats_dict in pools.items(): @@ -162,9 +160,9 @@ def occupied_slots(self, session: Session): :return: the used number of slots """ from airflow.models.taskinstance import TaskInstance # Avoid circular import + return ( - session - .query(func.sum(TaskInstance.pool_slots)) + session.query(func.sum(TaskInstance.pool_slots)) .filter(TaskInstance.pool == self.pool) .filter(TaskInstance.state.in_(list(EXECUTION_STATES))) .scalar() @@ -181,8 +179,7 @@ def running_slots(self, session: Session): from airflow.models.taskinstance import TaskInstance # Avoid circular import return ( - session - .query(func.sum(TaskInstance.pool_slots)) + session.query(func.sum(TaskInstance.pool_slots)) .filter(TaskInstance.pool == self.pool) .filter(TaskInstance.state == State.RUNNING) .scalar() @@ -199,8 +196,7 @@ def queued_slots(self, session: Session): from airflow.models.taskinstance import TaskInstance # Avoid circular import return ( - session - .query(func.sum(TaskInstance.pool_slots)) + session.query(func.sum(TaskInstance.pool_slots)) .filter(TaskInstance.pool == self.pool) .filter(TaskInstance.state == State.QUEUED) .scalar() diff --git a/airflow/models/renderedtifields.py b/airflow/models/renderedtifields.py index 661b0a8ac219e..8102b01762847 100644 --- a/airflow/models/renderedtifields.py +++ b/airflow/models/renderedtifields.py @@ -50,9 +50,7 @@ def __init__(self, ti: TaskInstance, render_templates=True): if render_templates: ti.render_templates() self.rendered_fields = { - field: serialize_template_field( - getattr(self.task, field) - ) for field in self.task.template_fields + field: serialize_template_field(getattr(self.task, field)) for field in self.task.template_fields } def __repr__(self): @@ -69,11 +67,13 @@ def get_templated_fields(cls, ti: TaskInstance, session: Session = None) -> Opti :param session: SqlAlchemy Session :return: Rendered Templated TI field """ - result = session.query(cls.rendered_fields).filter( - cls.dag_id == ti.dag_id, - cls.task_id == ti.task_id, - cls.execution_date == ti.execution_date - ).one_or_none() + result = ( + session.query(cls.rendered_fields) + .filter( + cls.dag_id == ti.dag_id, cls.task_id == ti.task_id, cls.execution_date == ti.execution_date + ) + .one_or_none() + ) if result: rendered_fields = result.rendered_fields @@ -92,9 +92,11 @@ def write(self, session: Session = None): @classmethod @provide_session def delete_old_records( - cls, task_id: str, dag_id: str, + cls, + task_id: str, + dag_id: str, num_to_keep=conf.getint("core", "max_num_rendered_ti_fields_per_task", fallback=0), - session: Session = None + session: Session = None, ): """ Keep only Last X (num_to_keep) number of records for a task by deleting others @@ -107,22 +109,22 @@ def delete_old_records( if num_to_keep <= 0: return - tis_to_keep_query = session \ - .query(cls.dag_id, cls.task_id, cls.execution_date) \ - .filter(cls.dag_id == dag_id, cls.task_id == task_id) \ - .order_by(cls.execution_date.desc()) \ + tis_to_keep_query = ( + session.query(cls.dag_id, cls.task_id, cls.execution_date) + .filter(cls.dag_id == dag_id, cls.task_id == task_id) + .order_by(cls.execution_date.desc()) .limit(num_to_keep) + ) if session.bind.dialect.name in ["postgresql", "sqlite"]: # Fetch Top X records given dag_id & task_id ordered by Execution Date subq1 = tis_to_keep_query.subquery('subq1') - session.query(cls) \ - .filter( - cls.dag_id == dag_id, - cls.task_id == task_id, - tuple_(cls.dag_id, cls.task_id, 
cls.execution_date).notin_(subq1)) \ - .delete(synchronize_session=False) + session.query(cls).filter( + cls.dag_id == dag_id, + cls.task_id == task_id, + tuple_(cls.dag_id, cls.task_id, cls.execution_date).notin_(subq1), + ).delete(synchronize_session=False) elif session.bind.dialect.name in ["mysql"]: # Fetch Top X records given dag_id & task_id ordered by Execution Date subq1 = tis_to_keep_query.subquery('subq1') @@ -131,28 +133,26 @@ def delete_old_records( # Workaround for MySQL Limitation (https://stackoverflow.com/a/19344141/5691525) # Limitation: This version of MySQL does not yet support # LIMIT & IN/ALL/ANY/SOME subquery - subq2 = ( - session - .query(subq1.c.dag_id, subq1.c.task_id, subq1.c.execution_date) - .subquery('subq2') - ) + subq2 = session.query(subq1.c.dag_id, subq1.c.task_id, subq1.c.execution_date).subquery('subq2') - session.query(cls) \ - .filter( - cls.dag_id == dag_id, - cls.task_id == task_id, - tuple_(cls.dag_id, cls.task_id, cls.execution_date).notin_(subq2)) \ - .delete(synchronize_session=False) + session.query(cls).filter( + cls.dag_id == dag_id, + cls.task_id == task_id, + tuple_(cls.dag_id, cls.task_id, cls.execution_date).notin_(subq2), + ).delete(synchronize_session=False) else: # Fetch Top X records given dag_id & task_id ordered by Execution Date tis_to_keep = tis_to_keep_query.all() - filter_tis = [not_(and_( - cls.dag_id == ti.dag_id, - cls.task_id == ti.task_id, - cls.execution_date == ti.execution_date - )) for ti in tis_to_keep] - - session.query(cls) \ - .filter(and_(*filter_tis)) \ - .delete(synchronize_session=False) + filter_tis = [ + not_( + and_( + cls.dag_id == ti.dag_id, + cls.task_id == ti.task_id, + cls.execution_date == ti.execution_date, + ) + ) + for ti in tis_to_keep + ] + + session.query(cls).filter(and_(*filter_tis)).delete(synchronize_session=False) diff --git a/airflow/models/sensorinstance.py b/airflow/models/sensorinstance.py index 3e0c417babe3b..f8e6ef82919c0 100644 --- a/airflow/models/sensorinstance.py +++ b/airflow/models/sensorinstance.py @@ -60,14 +60,10 @@ class SensorInstance(Base): poke_context = Column(Text, nullable=False) execution_context = Column(Text) created_at = Column(UtcDateTime, default=timezone.utcnow(), nullable=False) - updated_at = Column(UtcDateTime, - default=timezone.utcnow(), - onupdate=timezone.utcnow(), - nullable=False) + updated_at = Column(UtcDateTime, default=timezone.utcnow(), onupdate=timezone.utcnow(), nullable=False) __table_args__ = ( Index('ti_primary_key', dag_id, task_id, execution_date, unique=True), - Index('si_hashcode', hashcode), Index('si_shardcode', shardcode), Index('si_state_shard', state, shardcode), @@ -118,11 +114,16 @@ def register(cls, ti, poke_context, execution_context, session=None): encoded_poke = json.dumps(poke_context) encoded_execution_context = json.dumps(execution_context) - sensor = session.query(SensorInstance).filter( - SensorInstance.dag_id == ti.dag_id, - SensorInstance.task_id == ti.task_id, - SensorInstance.execution_date == ti.execution_date - ).with_for_update().first() + sensor = ( + session.query(SensorInstance) + .filter( + SensorInstance.dag_id == ti.dag_id, + SensorInstance.task_id == ti.task_id, + SensorInstance.execution_date == ti.execution_date, + ) + .with_for_update() + .first() + ) if sensor is None: sensor = SensorInstance(ti=ti) @@ -161,6 +162,7 @@ def try_number(self, value): self._try_number = value def __repr__(self): - return "<{self.__class__.__name__}: id: {self.id} poke_context: {self.poke_context} " \ - "execution_context: 
{self.execution_context} state: {self.state}>".format( - self=self) + return ( + "<{self.__class__.__name__}: id: {self.id} poke_context: {self.poke_context} " + "execution_context: {self.execution_context} state: {self.state}>".format(self=self) + ) diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py index 2eedd4759c6ab..dfd8b3a753c77 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -69,9 +69,7 @@ class SerializedDagModel(Base): last_updated = Column(UtcDateTime, nullable=False) dag_hash = Column(String(32), nullable=False) - __table_args__ = ( - Index('idx_fileloc_hash', fileloc_hash, unique=False), - ) + __table_args__ = (Index('idx_fileloc_hash', fileloc_hash, unique=False),) dag_runs = relationship( DagRun, @@ -115,16 +113,19 @@ def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session: # If Yes, does nothing # If No or the DAG does not exists, updates / writes Serialized DAG to DB if min_update_interval is not None: - if session.query(exists().where( - and_(cls.dag_id == dag.dag_id, - (timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated)) + if session.query( + exists().where( + and_( + cls.dag_id == dag.dag_id, + (timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated, + ) + ) ).scalar(): return log.debug("Checking if DAG (%s) changed", dag.dag_id) new_serialized_dag = cls(dag) - serialized_dag_hash_from_db = session.query( - cls.dag_hash).filter(cls.dag_id == dag.dag_id).scalar() + serialized_dag_hash_from_db = session.query(cls.dag_hash).filter(cls.dag_id == dag.dag_id).scalar() if serialized_dag_hash_from_db == new_serialized_dag.dag_hash: log.debug("Serialized DAG (%s) is unchanged. Skipping writing to DB", dag.dag_id) @@ -154,8 +155,10 @@ def read_all_dags(cls, session: Session = None) -> Dict[str, 'SerializedDAG']: dags[row.dag_id] = dag else: log.warning( - "dag_id Mismatch in DB: Row with dag_id '%s' has Serialised DAG " - "with '%s' dag_id", row.dag_id, dag.dag_id) + "dag_id Mismatch in DB: Row with dag_id '%s' has Serialised DAG " "with '%s' dag_id", + row.dag_id, + dag.dag_id, + ) return dags @property @@ -186,16 +189,18 @@ def remove_deleted_dags(cls, alive_dag_filelocs: List[str], session=None): :param alive_dag_filelocs: file paths of alive DAGs :param session: ORM Session """ - alive_fileloc_hashes = [ - DagCode.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs] + alive_fileloc_hashes = [DagCode.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs] - log.debug("Deleting Serialized DAGs (for which DAG files are deleted) " - "from %s table ", cls.__tablename__) + log.debug( + "Deleting Serialized DAGs (for which DAG files are deleted) " "from %s table ", cls.__tablename__ + ) # pylint: disable=no-member - session.execute(cls.__table__.delete().where( - and_(cls.fileloc_hash.notin_(alive_fileloc_hashes), - cls.fileloc.notin_(alive_dag_filelocs)))) + session.execute( + cls.__table__.delete().where( + and_(cls.fileloc_hash.notin_(alive_fileloc_hashes), cls.fileloc.notin_(alive_dag_filelocs)) + ) + ) @classmethod @provide_session @@ -224,8 +229,7 @@ def get(cls, dag_id: str, session: Session = None) -> Optional['SerializedDagMod # If we didn't find a matching DAG id then ask the DAG table to find # out the root dag - root_dag_id = session.query( - DagModel.root_dag_id).filter(DagModel.dag_id == dag_id).scalar() + root_dag_id = session.query(DagModel.root_dag_id).filter(DagModel.dag_id == dag_id).scalar() return 
session.query(cls).filter(cls.dag_id == root_dag_id).one_or_none() @@ -245,9 +249,7 @@ def bulk_sync_to_db(dags: List[DAG], session: Session = None): for dag in dags: if not dag.is_subdag: SerializedDagModel.write_dag( - dag, - min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL, - session=session + dag, min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL, session=session ) @classmethod diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index 43033310e1ae4..dc40329087f8c 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -70,7 +70,11 @@ def _set_state_to_skipped(self, dag_run, execution_date, tasks, session): @provide_session def skip( - self, dag_run, execution_date, tasks, session=None, + self, + dag_run, + execution_date, + tasks, + session=None, ): """ Sets tasks instances to skipped from the same dag run. @@ -105,12 +109,10 @@ def skip( task_id=task_id, dag_id=dag_run.dag_id, execution_date=dag_run.execution_date, - session=session + session=session, ) - def skip_all_except( - self, ti: TaskInstance, branch_task_ids: Union[str, Iterable[str]] - ): + def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[str, Iterable[str]]): """ This method implements the logic for a branching operator; given a single task ID or list of task IDs to follow, this skips all other tasks @@ -145,25 +147,15 @@ def skip_all_except( # task1 # for branch_task_id in list(branch_task_ids): - branch_task_ids.update( - dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False) - ) - - skip_tasks = [ - t - for t in downstream_tasks - if t.task_id not in branch_task_ids - ] + branch_task_ids.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False)) + + skip_tasks = [t for t in downstream_tasks if t.task_id not in branch_task_ids] follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_ids] self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks]) with create_session() as session: - self._set_state_to_skipped( - dag_run, ti.execution_date, skip_tasks, session=session - ) + self._set_state_to_skipped(dag_run, ti.execution_date, skip_tasks, session=session) # For some reason, session.commit() needs to happen before xcom_push. # Otherwise the session is not committed. 
session.commit() - ti.xcom_push( - key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids} - ) + ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}) diff --git a/airflow/models/slamiss.py b/airflow/models/slamiss.py index 7c7160bf07bd4..6c841e3ee6e05 100644 --- a/airflow/models/slamiss.py +++ b/airflow/models/slamiss.py @@ -39,10 +39,7 @@ class SlaMiss(Base): description = Column(Text) notification_sent = Column(Boolean, default=False) - __table_args__ = ( - Index('sm_dag', dag_id, unique=False), - ) + __table_args__ = (Index('sm_dag', dag_id, unique=False),) def __repr__(self): - return str(( - self.dag_id, self.task_id, self.execution_date.isoformat())) + return str((self.dag_id, self.task_id, self.execution_date.isoformat())) diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py index 5b3979203d617..23d3086e2d013 100644 --- a/airflow/models/taskfail.py +++ b/airflow/models/taskfail.py @@ -35,10 +35,7 @@ class TaskFail(Base): end_date = Column(UtcDateTime) duration = Column(Integer) - __table_args__ = ( - Index('idx_task_fail_dag_task_date', dag_id, task_id, execution_date, - unique=False), - ) + __table_args__ = (Index('idx_task_fail_dag_task_date', dag_id, task_id, execution_date, unique=False),) def __init__(self, task, execution_date, start_date, end_date): self.dag_id = task.dag_id diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 872682d780179..6a73b40d674a6 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -42,8 +42,12 @@ from airflow import settings from airflow.configuration import conf from airflow.exceptions import ( - AirflowException, AirflowFailException, AirflowRescheduleException, AirflowSkipException, - AirflowSmartSensorException, AirflowTaskTimeout, + AirflowException, + AirflowFailException, + AirflowRescheduleException, + AirflowSkipException, + AirflowSmartSensorException, + AirflowTaskTimeout, ) from airflow.models.base import COLLATION_ARGS, ID_LEN, Base from airflow.models.log import Log @@ -92,11 +96,12 @@ def set_current_context(context: Context): ) -def clear_task_instances(tis, - session, - activate_dag_runs=True, - dag=None, - ): +def clear_task_instances( + tis, + session, + activate_dag_runs=True, + dag=None, +): """ Clears a set of task instances, but makes sure the running ones get killed. 
@@ -132,20 +137,26 @@ def clear_task_instances(tis, TR.dag_id == ti.dag_id, TR.task_id == ti.task_id, TR.execution_date == ti.execution_date, - TR.try_number == ti.try_number + TR.try_number == ti.try_number, ).delete() if job_ids: from airflow.jobs.base_job import BaseJob + for job in session.query(BaseJob).filter(BaseJob.id.in_(job_ids)).all(): # noqa job.state = State.SHUTDOWN if activate_dag_runs and tis: from airflow.models.dagrun import DagRun # Avoid circular import - drs = session.query(DagRun).filter( - DagRun.dag_id.in_({ti.dag_id for ti in tis}), - DagRun.execution_date.in_({ti.execution_date for ti in tis}), - ).all() + + drs = ( + session.query(DagRun) + .filter( + DagRun.dag_id.in_({ti.dag_id for ti in tis}), + DagRun.execution_date.in_({ti.execution_date for ti in tis}), + ) + .all() + ) for dr in drs: dr.state = State.RUNNING dr.start_date = timezone.utcnow() @@ -167,18 +178,14 @@ def primary(self) -> Tuple[str, str, datetime]: @property def reduced(self) -> 'TaskInstanceKey': """Remake the key by subtracting 1 from try number to match in memory information""" - return TaskInstanceKey( - self.dag_id, self.task_id, self.execution_date, max(1, self.try_number - 1) - ) + return TaskInstanceKey(self.dag_id, self.task_id, self.execution_date, max(1, self.try_number - 1)) def with_try_number(self, try_number: int) -> 'TaskInstanceKey': """Returns TaskInstanceKey with provided ``try_number``""" - return TaskInstanceKey( - self.dag_id, self.task_id, self.execution_date, try_number - ) + return TaskInstanceKey(self.dag_id, self.task_id, self.execution_date, try_number) -class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904 +class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904 """ Task instances store the state of a task instance. This table is the authority and single source of truth around what tasks have run and the @@ -247,11 +254,12 @@ def __init__(self, task, execution_date: datetime, state: Optional[str] = None): # make sure we have a localized execution_date stored in UTC if execution_date and not timezone.is_localized(execution_date): - self.log.warning("execution date %s has no timezone information. Using " - "default from dag or system", execution_date) + self.log.warning( + "execution date %s has no timezone information. Using " "default from dag or system", + execution_date, + ) if self.task.has_dag(): - execution_date = timezone.make_aware(execution_date, - self.task.dag.timezone) + execution_date = timezone.make_aware(execution_date, self.task.dag.timezone) else: execution_date = timezone.make_aware(execution_date) @@ -314,19 +322,20 @@ def next_try_number(self): """Setting Next Try Number""" return self._try_number + 1 - def command_as_list( # pylint: disable=too-many-arguments - self, - mark_success=False, - ignore_all_deps=False, - ignore_task_deps=False, - ignore_depends_on_past=False, - ignore_ti_state=False, - local=False, - pickle_id=None, - raw=False, - job_id=None, - pool=None, - cfg_path=None): + def command_as_list( # pylint: disable=too-many-arguments + self, + mark_success=False, + ignore_all_deps=False, + ignore_task_deps=False, + ignore_depends_on_past=False, + ignore_ti_state=False, + local=False, + pickle_id=None, + raw=False, + job_id=None, + pool=None, + cfg_path=None, + ): """ Returns a command that can be executed anywhere where airflow is installed. 
This command is part of the message sent to executors by @@ -357,25 +366,27 @@ def command_as_list( # pylint: disable=too-many-arguments raw=raw, job_id=job_id, pool=pool, - cfg_path=cfg_path) + cfg_path=cfg_path, + ) @staticmethod - def generate_command(dag_id: str, # pylint: disable=too-many-arguments - task_id: str, - execution_date: datetime, - mark_success: bool = False, - ignore_all_deps: bool = False, - ignore_depends_on_past: bool = False, - ignore_task_deps: bool = False, - ignore_ti_state: bool = False, - local: bool = False, - pickle_id: Optional[int] = None, - file_path: Optional[str] = None, - raw: bool = False, - job_id: Optional[str] = None, - pool: Optional[str] = None, - cfg_path: Optional[str] = None - ) -> List[str]: + def generate_command( + dag_id: str, # pylint: disable=too-many-arguments + task_id: str, + execution_date: datetime, + mark_success: bool = False, + ignore_all_deps: bool = False, + ignore_depends_on_past: bool = False, + ignore_task_deps: bool = False, + ignore_ti_state: bool = False, + local: bool = False, + pickle_id: Optional[int] = None, + file_path: Optional[str] = None, + raw: bool = False, + job_id: Optional[str] = None, + pool: Optional[str] = None, + cfg_path: Optional[str] = None, + ) -> List[str]: """ Generates the shell command required to execute this task instance. @@ -457,10 +468,7 @@ def log_url(self): iso = quote(self.execution_date.isoformat()) base_url = conf.get('webserver', 'BASE_URL') return base_url + ( # noqa - "/log?" - "execution_date={iso}" - "&task_id={task_id}" - "&dag_id={dag_id}" + "/log?" "execution_date={iso}" "&task_id={task_id}" "&dag_id={dag_id}" ).format(iso=iso, task_id=self.task_id, dag_id=self.dag_id) @property @@ -487,11 +495,15 @@ def current_state(self, session=None) -> str: :param session: SQLAlchemy ORM Session :type session: Session """ - ti = session.query(TaskInstance).filter( - TaskInstance.dag_id == self.dag_id, - TaskInstance.task_id == self.task_id, - TaskInstance.execution_date == self.execution_date, - ).all() + ti = ( + session.query(TaskInstance) + .filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.execution_date == self.execution_date, + ) + .all() + ) if ti: state = ti[0].state else: @@ -528,7 +540,8 @@ def refresh_from_db(self, session=None, lock_for_update=False) -> None: qry = session.query(TaskInstance).filter( TaskInstance.dag_id == self.dag_id, TaskInstance.task_id == self.task_id, - TaskInstance.execution_date == self.execution_date) + TaskInstance.execution_date == self.execution_date, + ) if lock_for_update: ti = qry.with_for_update().first() @@ -542,7 +555,7 @@ def refresh_from_db(self, session=None, lock_for_update=False) -> None: self.state = ti.state # Get the raw value of try_number column, don't read through the # accessor here otherwise it will be incremented by one already. 
- self.try_number = ti._try_number # noqa pylint: disable=protected-access + self.try_number = ti._try_number # noqa pylint: disable=protected-access self.max_tries = ti.max_tries self.hostname = ti.hostname self.unixname = ti.unixname @@ -589,7 +602,7 @@ def clear_xcom_data(self, session=None): session.query(XCom).filter( XCom.dag_id == self.dag_id, XCom.task_id == self.task_id, - XCom.execution_date == self.execution_date + XCom.execution_date == self.execution_date, ).delete() session.commit() self.log.debug("XCom data cleared") @@ -656,9 +669,7 @@ def are_dependents_done(self, session=None): @provide_session def get_previous_ti( - self, - state: Optional[str] = None, - session: Session = None + self, state: Optional[str] = None, session: Session = None ) -> Optional['TaskInstance']: """ The task instance for the task that ran before this task instance. @@ -745,9 +756,7 @@ def get_previous_execution_date( @provide_session def get_previous_start_date( - self, - state: Optional[str] = None, - session: Session = None + self, state: Optional[str] = None, session: Session = None ) -> Optional[pendulum.DateTime]: """ The start date from property previous_ti_success. @@ -776,11 +785,7 @@ def previous_start_date_success(self) -> Optional[pendulum.DateTime]: return self.get_previous_start_date(state=State.SUCCESS) @provide_session - def are_dependencies_met( - self, - dep_context=None, - session=None, - verbose=False): + def are_dependencies_met(self, dep_context=None, session=None, verbose=False): """ Returns whether or not all the conditions are met for this task instance to be run given the context for the dependencies (e.g. a task instance being force run from @@ -798,14 +803,14 @@ def are_dependencies_met( dep_context = dep_context or DepContext() failed = False verbose_aware_logger = self.log.info if verbose else self.log.debug - for dep_status in self.get_failed_dep_statuses( - dep_context=dep_context, - session=session): + for dep_status in self.get_failed_dep_statuses(dep_context=dep_context, session=session): failed = True verbose_aware_logger( "Dependencies not met for %s, dependency '%s' FAILED: %s", - self, dep_status.dep_name, dep_status.reason + self, + dep_status.dep_name, + dep_status.reason, ) if failed: @@ -815,21 +820,18 @@ def are_dependencies_met( return True @provide_session - def get_failed_dep_statuses( - self, - dep_context=None, - session=None): + def get_failed_dep_statuses(self, dep_context=None, session=None): """Get failed Dependencies""" dep_context = dep_context or DepContext() for dep in dep_context.deps | self.task.deps: - for dep_status in dep.get_dep_statuses( - self, - session, - dep_context): + for dep_status in dep.get_dep_statuses(self, session, dep_context): self.log.debug( "%s dependency '%s' PASSED: %s, %s", - self, dep_status.dep_name, dep_status.passed, dep_status.reason + self, + dep_status.dep_name, + dep_status.passed, + dep_status.reason, ) if not dep_status.passed: @@ -837,8 +839,7 @@ def get_failed_dep_statuses( def __repr__(self): return ( # noqa - "" + "" ).format(ti=self) def next_retry_datetime(self): @@ -853,11 +854,14 @@ def next_retry_datetime(self): # will occur in the modded_hash calculation. 
min_backoff = int(math.ceil(delay.total_seconds() * (2 ** (self.try_number - 2)))) # deterministic per task instance - ti_hash = int(hashlib.sha1("{}#{}#{}#{}".format(self.dag_id, # noqa - self.task_id, - self.execution_date, - self.try_number) - .encode('utf-8')).hexdigest(), 16) + ti_hash = int( + hashlib.sha1( + "{}#{}#{}#{}".format( + self.dag_id, self.task_id, self.execution_date, self.try_number # noqa + ).encode('utf-8') + ).hexdigest(), + 16, + ) # between 1 and 1.0 * delay * (2^retry_number) modded_hash = min_backoff + ti_hash % min_backoff # timedelta has a maximum representable value. The exponentiation @@ -865,10 +869,7 @@ def next_retry_datetime(self): # of tries (around 50 if the initial delay is 1s, even fewer if # the delay is larger). Cap the value here before creating a # timedelta object so the operation doesn't fail. - delay_backoff_in_seconds = min( - modded_hash, - timedelta.max.total_seconds() - 1 - ) + delay_backoff_in_seconds = min(modded_hash, timedelta.max.total_seconds() - 1) delay = timedelta(seconds=delay_backoff_in_seconds) if self.task.max_retry_delay: delay = min(self.task.max_retry_delay, delay) @@ -879,8 +880,7 @@ def ready_for_retry(self): Checks on whether the task instance is in the right state and timeframe to be retried. """ - return (self.state == State.UP_FOR_RETRY and - self.next_retry_datetime() < timezone.utcnow()) + return self.state == State.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow() @provide_session def get_dagrun(self, session: Session = None): @@ -891,26 +891,29 @@ def get_dagrun(self, session: Session = None): :return: DagRun """ from airflow.models.dagrun import DagRun # Avoid circular import - dr = session.query(DagRun).filter( - DagRun.dag_id == self.dag_id, - DagRun.execution_date == self.execution_date - ).first() + + dr = ( + session.query(DagRun) + .filter(DagRun.dag_id == self.dag_id, DagRun.execution_date == self.execution_date) + .first() + ) return dr @provide_session - def check_and_change_state_before_execution( # pylint: disable=too-many-arguments - self, - verbose: bool = True, - ignore_all_deps: bool = False, - ignore_depends_on_past: bool = False, - ignore_task_deps: bool = False, - ignore_ti_state: bool = False, - mark_success: bool = False, - test_mode: bool = False, - job_id: Optional[str] = None, - pool: Optional[str] = None, - session=None) -> bool: + def check_and_change_state_before_execution( # pylint: disable=too-many-arguments + self, + verbose: bool = True, + ignore_all_deps: bool = False, + ignore_depends_on_past: bool = False, + ignore_task_deps: bool = False, + ignore_ti_state: bool = False, + mark_success: bool = False, + test_mode: bool = False, + job_id: Optional[str] = None, + pool: Optional[str] = None, + session=None, + ) -> bool: """ Checks dependencies and then sets state to RUNNING if they are met. 
Returns True if and only if state is set to RUNNING, which implies that task should be @@ -960,11 +963,11 @@ def check_and_change_state_before_execution( # pylint: disable=too-many-argum ignore_all_deps=ignore_all_deps, ignore_ti_state=ignore_ti_state, ignore_depends_on_past=ignore_depends_on_past, - ignore_task_deps=ignore_task_deps) + ignore_task_deps=ignore_task_deps, + ) if not self.are_dependencies_met( - dep_context=non_requeueable_dep_context, - session=session, - verbose=True): + dep_context=non_requeueable_dep_context, session=session, verbose=True + ): session.commit() return False @@ -987,17 +990,17 @@ def check_and_change_state_before_execution( # pylint: disable=too-many-argum ignore_all_deps=ignore_all_deps, ignore_depends_on_past=ignore_depends_on_past, ignore_task_deps=ignore_task_deps, - ignore_ti_state=ignore_ti_state) - if not self.are_dependencies_met( - dep_context=dep_context, - session=session, - verbose=True): + ignore_ti_state=ignore_ti_state, + ) + if not self.are_dependencies_met(dep_context=dep_context, session=session, verbose=True): self.state = State.NONE self.log.warning(hr_line_break) self.log.warning( "Rescheduling due to concurrency limits reached " "at task runtime. Attempt %s of " - "%s. State set to NONE.", self.try_number, self.max_tries + 1 + "%s. State set to NONE.", + self.try_number, + self.max_tries + 1, ) self.log.warning(hr_line_break) self.queued_dttm = timezone.utcnow() @@ -1040,12 +1043,13 @@ def _date_or_empty(self, attr): @provide_session @Sentry.enrich_errors def _run_raw_task( - self, - mark_success: bool = False, - test_mode: bool = False, - job_id: Optional[str] = None, - pool: Optional[str] = None, - session=None) -> None: + self, + mark_success: bool = False, + test_mode: bool = False, + job_id: Optional[str] = None, + pool: Optional[str] = None, + session=None, + ) -> None: """ Immediately runs the task (without checking or changing db state before execution) and then sets the appropriate final state after @@ -1130,7 +1134,7 @@ def _run_raw_task( self.task_id, self._date_or_empty('execution_date'), self._date_or_empty('start_date'), - self._date_or_empty('end_date') + self._date_or_empty('end_date'), ) self.set_duration() if not test_mode: @@ -1149,10 +1153,12 @@ def _run_mini_scheduler_on_child_tasks(self, session=None) -> None: try: # Re-select the row with a lock - dag_run = with_row_locks(session.query(DagRun).filter_by( - dag_id=self.dag_id, - execution_date=self.execution_date, - )).one() + dag_run = with_row_locks( + session.query(DagRun).filter_by( + dag_id=self.dag_id, + execution_date=self.execution_date, + ) + ).one() # Get a partial dag with just the specific tasks we want to # examine. 
In order for dep checks to work correctly, we @@ -1174,9 +1180,7 @@ def _run_mini_scheduler_on_child_tasks(self, session=None) -> None: if task_id not in self.task.downstream_task_ids } - schedulable_tis = [ - ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids - ] + schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids] for schedulable_ti in schedulable_tis: if not hasattr(schedulable_ti, "task"): schedulable_ti.task = self.task.dag.get_task(schedulable_ti.task_id) @@ -1193,10 +1197,7 @@ def _run_mini_scheduler_on_child_tasks(self, session=None) -> None: ) session.rollback() - def _prepare_and_execute_task_with_callbacks( - self, - context, - task): + def _prepare_and_execute_task_with_callbacks(self, context, task): """Prepare Task for Execution""" from airflow.models.renderedtifields import RenderedTaskInstanceFields @@ -1220,9 +1221,10 @@ def signal_handler(signum, frame): # pylint: disable=unused-argument # Export context to make it available for operators to use. airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) - self.log.info("Exporting the following env vars:\n%s", - '\n'.join([f"{k}={v}" - for k, v in airflow_context_vars.items()])) + self.log.info( + "Exporting the following env vars:\n%s", + '\n'.join([f"{k}={v}" for k, v in airflow_context_vars.items()]), + ) os.environ.update(airflow_context_vars) @@ -1238,8 +1240,9 @@ def signal_handler(signum, frame): # pylint: disable=unused-argument try: registered = task_copy.register_in_sensor_service(self, context) except Exception as e: - self.log.warning("Failed to register in sensor service." - "Continue to run task in non smart sensor mode.") + self.log.warning( + "Failed to register in sensor service." "Continue to run task in non smart sensor mode." 
+ ) self.log.exception(e, exc_info=True) if registered: @@ -1255,9 +1258,10 @@ def signal_handler(signum, frame): # pylint: disable=unused-argument end_time = time.time() duration = timedelta(seconds=end_time - start_time) - Stats.timing('dag.{dag_id}.{task_id}.duration'.format(dag_id=task_copy.dag_id, - task_id=task_copy.task_id), - duration) + Stats.timing( + f'dag.{task_copy.dag_id}.{task_copy.task_id}.duration', + duration, + ) Stats.incr(f'operator_successes_{self.task.task_type}', 1, 1) Stats.incr('ti_successes') @@ -1310,17 +1314,18 @@ def _run_execute_callback(self, context, task): @provide_session def run( # pylint: disable=too-many-arguments - self, - verbose: bool = True, - ignore_all_deps: bool = False, - ignore_depends_on_past: bool = False, - ignore_task_deps: bool = False, - ignore_ti_state: bool = False, - mark_success: bool = False, - test_mode: bool = False, - job_id: Optional[str] = None, - pool: Optional[str] = None, - session=None) -> None: + self, + verbose: bool = True, + ignore_all_deps: bool = False, + ignore_depends_on_past: bool = False, + ignore_task_deps: bool = False, + ignore_ti_state: bool = False, + mark_success: bool = False, + test_mode: bool = False, + job_id: Optional[str] = None, + pool: Optional[str] = None, + session=None, + ) -> None: """Run TaskInstance""" res = self.check_and_change_state_before_execution( verbose=verbose, @@ -1332,14 +1337,12 @@ def run( # pylint: disable=too-many-arguments test_mode=test_mode, job_id=job_id, pool=pool, - session=session) + session=session, + ) if res: self._run_raw_task( - mark_success=mark_success, - test_mode=test_mode, - job_id=job_id, - pool=pool, - session=session) + mark_success=mark_success, test_mode=test_mode, job_id=job_id, pool=pool, session=session + ) def dry_run(self): """Only Renders Templates for the TI""" @@ -1351,11 +1354,7 @@ def dry_run(self): task_copy.dry_run() @provide_session - def _handle_reschedule(self, - actual_start_date, - reschedule_exception, - test_mode=False, - session=None): + def _handle_reschedule(self, actual_start_date, reschedule_exception, test_mode=False, session=None): # Don't record reschedule request in test mode if test_mode: return @@ -1364,9 +1363,16 @@ def _handle_reschedule(self, self.set_duration() # Log reschedule request - session.add(TaskReschedule(self.task, self.execution_date, self._try_number, - actual_start_date, self.end_date, - reschedule_exception.reschedule_date)) + session.add( + TaskReschedule( + self.task, + self.execution_date, + self._try_number, + actual_start_date, + self.end_date, + reschedule_exception.reschedule_date, + ) + ) # set state self.state = State.UP_FOR_RESCHEDULE @@ -1431,12 +1437,12 @@ def handle_failure(self, error, test_mode=None, context=None, force_fail=False, self.task_id, self._safe_date('execution_date', '%Y%m%dT%H%M%S'), self._safe_date('start_date', '%Y%m%dT%H%M%S'), - self._safe_date('end_date', '%Y%m%dT%H%M%S') + self._safe_date('end_date', '%Y%m%dT%H%M%S'), ) if email_for_state and task.email: try: self.email_alert(error) - except Exception as exec2: # pylint: disable=broad-except + except Exception as exec2: # pylint: disable=broad-except self.log.error('Failed to send email to: %s', task.email) self.log.exception(exec2) @@ -1444,7 +1450,7 @@ def handle_failure(self, error, test_mode=None, context=None, force_fail=False, if callback: try: callback(context) - except Exception as exec3: # pylint: disable=broad-except + except Exception as exec3: # pylint: disable=broad-except self.log.error("Failed at executing 
callback") self.log.exception(exec3) @@ -1475,11 +1481,10 @@ def get_template_context(self, session=None) -> Context: # pylint: disable=too- if task.dag.params: params.update(task.dag.params) from airflow.models.dagrun import DagRun # Avoid circular import + dag_run = ( session.query(DagRun) - .filter_by( - dag_id=task.dag.dag_id, - execution_date=self.execution_date) + .filter_by(dag_id=task.dag.dag_id, execution_date=self.execution_date) .first() ) run_id = dag_run.run_id if dag_run else None @@ -1523,7 +1528,8 @@ def get_template_context(self, session=None) -> Context: # pylint: disable=too- tomorrow_ds_nodash = tomorrow_ds.replace('-', '') ti_key_str = "{dag_id}__{task_id}__{ds_nodash}".format( - dag_id=task.dag_id, task_id=task.task_id, ds_nodash=ds_nodash) + dag_id=task.dag_id, task_id=task.task_id, ds_nodash=ds_nodash + ) if task.params: params.update(task.params) @@ -1584,7 +1590,7 @@ def __repr__(self): def get( item: str, # pylint: disable=protected-access - default_var: Any = Variable._Variable__NO_DEFAULT_SENTINEL, # noqa + default_var: Any = Variable._Variable__NO_DEFAULT_SENTINEL, # noqa ): """Get Airflow Variable after deserializing JSON value""" return Variable.get(item, default_var=default_var, deserialize_json=True) @@ -1607,9 +1613,11 @@ def get( 'prev_ds_nodash': prev_ds_nodash, 'prev_execution_date': prev_execution_date, 'prev_execution_date_success': lazy_object_proxy.Proxy( - lambda: self.get_previous_execution_date(state=State.SUCCESS)), + lambda: self.get_previous_execution_date(state=State.SUCCESS) + ), 'prev_start_date_success': lazy_object_proxy.Proxy( - lambda: self.get_previous_start_date(state=State.SUCCESS)), + lambda: self.get_previous_start_date(state=State.SUCCESS) + ), 'run_id': run_id, 'task': task, 'task_instance': self, @@ -1632,6 +1640,7 @@ def get( def get_rendered_template_fields(self): """Fetch rendered template fields from DB""" from airflow.models.renderedtifields import RenderedTaskInstanceFields + rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self) if rendered_task_instance_fields: for field_name, rendered_value in rendered_task_instance_fields.items(): @@ -1693,15 +1702,18 @@ def get_email_subject_content(self, exception): jinja_context = {'ti': self} # This function is called after changing the state # from State.RUNNING so need to subtract 1 from self.try_number. 
- jinja_context.update(dict( - exception=exception, - exception_html=exception_html, - try_number=self.try_number - 1, - max_tries=self.max_tries)) + jinja_context.update( + dict( + exception=exception, + exception_html=exception_html, + try_number=self.try_number - 1, + max_tries=self.max_tries, + ) + ) jinja_env = jinja2.Environment( - loader=jinja2.FileSystemLoader(os.path.dirname(__file__)), - autoescape=True) + loader=jinja2.FileSystemLoader(os.path.dirname(__file__)), autoescape=True + ) subject = jinja_env.from_string(default_subject).render(**jinja_context) html_content = jinja_env.from_string(default_html_content).render(**jinja_context) html_content_err = jinja_env.from_string(default_html_content_err).render(**jinja_context) @@ -1709,11 +1721,14 @@ def get_email_subject_content(self, exception): else: jinja_context = self.get_template_context() - jinja_context.update(dict( - exception=exception, - exception_html=exception_html, - try_number=self.try_number - 1, - max_tries=self.max_tries)) + jinja_context.update( + dict( + exception=exception, + exception_html=exception_html, + try_number=self.try_number - 1, + max_tries=self.max_tries, + ) + ) jinja_env = self.task.get_template_env() @@ -1772,8 +1787,8 @@ def xcom_push( if execution_date and execution_date < self.execution_date: raise ValueError( 'execution_date can not be in the past (current ' - 'execution_date is {}; received {})'.format( - self.execution_date, execution_date)) + 'execution_date is {}; received {})'.format(self.execution_date, execution_date) + ) XCom.set( key=key, @@ -1785,13 +1800,13 @@ def xcom_push( ) @provide_session - def xcom_pull( # pylint: disable=inconsistent-return-statements + def xcom_pull( # pylint: disable=inconsistent-return-statements self, task_ids: Optional[Union[str, Iterable[str]]] = None, dag_id: Optional[str] = None, key: str = XCOM_RETURN_KEY, include_prior_dates: bool = False, - session: Session = None + session: Session = None, ) -> Any: """ Pull XComs that optionally meet certain criteria. 
@@ -1833,7 +1848,7 @@ def xcom_pull( # pylint: disable=inconsistent-return-statements dag_ids=dag_id, task_ids=task_ids, include_prior_dates=include_prior_dates, - session=session + session=session, ).with_entities(XCom.value) # Since we're only fetching the values field, and not the @@ -1851,11 +1866,15 @@ def xcom_pull( # pylint: disable=inconsistent-return-statements def get_num_running_task_instances(self, session): """Return Number of running TIs from the DB""" # .count() is inefficient - return session.query(func.count()).filter( - TaskInstance.dag_id == self.dag_id, - TaskInstance.task_id == self.task_id, - TaskInstance.state == State.RUNNING - ).scalar() + return ( + session.query(func.count()) + .filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.state == State.RUNNING, + ) + .scalar() + ) def init_run_context(self, raw=False): """Sets the log context.""" @@ -1863,9 +1882,7 @@ def init_run_context(self, raw=False): self._set_context(self) @staticmethod - def filter_for_tis( - tis: Iterable[Union["TaskInstance", TaskInstanceKey]] - ) -> Optional[BooleanClauseList]: + def filter_for_tis(tis: Iterable[Union["TaskInstance", TaskInstanceKey]]) -> Optional[BooleanClauseList]: """Returns SQLAlchemy filter to query selected task instances""" if not tis: return None @@ -1895,7 +1912,8 @@ def filter_for_tis( TaskInstance.dag_id == ti.dag_id, TaskInstance.task_id == ti.task_id, TaskInstance.execution_date == ti.execution_date, - ) for ti in tis + ) + for ti in tis ) @@ -1993,7 +2011,8 @@ def construct_task_instance(self, session=None, lock_for_update=False) -> TaskIn qry = session.query(TaskInstance).filter( TaskInstance.dag_id == self._dag_id, TaskInstance.task_id == self._task_id, - TaskInstance.execution_date == self._execution_date) + TaskInstance.execution_date == self._execution_date, + ) if lock_for_update: ti = qry.with_for_update().first() @@ -2010,5 +2029,6 @@ def construct_task_instance(self, session=None, lock_for_update=False) -> TaskIn from airflow.job.base_job import BaseJob from airflow.models.dagrun import DagRun + TaskInstance.dag_run = relationship(DagRun) TaskInstance.queued_by_job = relationship(BaseJob) diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py index 88e18cdea73c2..6ef99f87ce90f 100644 --- a/airflow/models/taskreschedule.py +++ b/airflow/models/taskreschedule.py @@ -39,17 +39,16 @@ class TaskReschedule(Base): reschedule_date = Column(UtcDateTime, nullable=False) __table_args__ = ( - Index('idx_task_reschedule_dag_task_date', dag_id, task_id, execution_date, - unique=False), - ForeignKeyConstraint([task_id, dag_id, execution_date], - ['task_instance.task_id', 'task_instance.dag_id', - 'task_instance.execution_date'], - name='task_reschedule_dag_task_date_fkey', - ondelete='CASCADE') + Index('idx_task_reschedule_dag_task_date', dag_id, task_id, execution_date, unique=False), + ForeignKeyConstraint( + [task_id, dag_id, execution_date], + ['task_instance.task_id', 'task_instance.dag_id', 'task_instance.execution_date'], + name='task_reschedule_dag_task_date_fkey', + ondelete='CASCADE', + ), ) - def __init__(self, task, execution_date, try_number, start_date, end_date, - reschedule_date): + def __init__(self, task, execution_date, try_number, start_date, end_date, reschedule_date): self.dag_id = task.dag_id self.task_id = task.task_id self.execution_date = execution_date @@ -73,13 +72,11 @@ def query_for_task_instance(task_instance, descending=False, session=None): :type descending: bool 
""" TR = TaskReschedule - qry = ( - session - .query(TR) - .filter(TR.dag_id == task_instance.dag_id, - TR.task_id == task_instance.task_id, - TR.execution_date == task_instance.execution_date, - TR.try_number == task_instance.try_number) + qry = session.query(TR).filter( + TR.dag_id == task_instance.dag_id, + TR.task_id == task_instance.task_id, + TR.execution_date == task_instance.execution_date, + TR.try_number == task_instance.try_number, ) if descending: return qry.order_by(desc(TR.id)) diff --git a/airflow/models/variable.py b/airflow/models/variable.py index a95e1c905492b..571f41658dc4a 100644 --- a/airflow/models/variable.py +++ b/airflow/models/variable.py @@ -77,7 +77,7 @@ def set_val(self, value): self.is_encrypted = fernet.is_encrypted @declared_attr - def val(cls): # pylint: disable=no-self-argument + def val(cls): # pylint: disable=no-self-argument """Get Airflow Variable from Metadata DB and decode it using the Fernet Key""" return synonym('_val', descriptor=property(cls.get_val, cls.set_val)) @@ -96,8 +96,7 @@ def setdefault(cls, key, default, deserialize_json=False): and un-encode it when retrieving a value :return: Mixed """ - obj = Variable.get(key, default_var=None, - deserialize_json=deserialize_json) + obj = Variable.get(key, default_var=None, deserialize_json=deserialize_json) if obj is None: if default is not None: Variable.set(key, default, serialize_json=deserialize_json) @@ -135,13 +134,7 @@ def get( @classmethod @provide_session - def set( - cls, - key: str, - value: Any, - serialize_json: bool = False, - session: Session = None - ): + def set(cls, key: str, value: Any, serialize_json: bool = False, session: Session = None): """ Sets a value for an Airflow Variable with a given Key diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 6300bb05a22b2..bd3ce168077ee 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -72,20 +72,12 @@ def init_on_load(self): def __repr__(self): return ''.format( - key=self.key, - task_id=self.task_id, - execution_date=self.execution_date) + key=self.key, task_id=self.task_id, execution_date=self.execution_date + ) @classmethod @provide_session - def set( - cls, - key, - value, - execution_date, - task_id, - dag_id, - session=None): + def set(cls, key, value, execution_date, task_id, dag_id, session=None): """ Store an XCom value. 
@@ -97,32 +89,27 @@ def set( # remove any duplicate XComs session.query(cls).filter( - cls.key == key, - cls.execution_date == execution_date, - cls.task_id == task_id, - cls.dag_id == dag_id).delete() + cls.key == key, cls.execution_date == execution_date, cls.task_id == task_id, cls.dag_id == dag_id + ).delete() session.commit() # insert new XCom - session.add(XCom( - key=key, - value=value, - execution_date=execution_date, - task_id=task_id, - dag_id=dag_id)) + session.add(XCom(key=key, value=value, execution_date=execution_date, task_id=task_id, dag_id=dag_id)) session.commit() @classmethod @provide_session - def get_one(cls, - execution_date: pendulum.DateTime, - key: Optional[str] = None, - task_id: Optional[Union[str, Iterable[str]]] = None, - dag_id: Optional[Union[str, Iterable[str]]] = None, - include_prior_dates: bool = False, - session: Session = None) -> Optional[Any]: + def get_one( + cls, + execution_date: pendulum.DateTime, + key: Optional[str] = None, + task_id: Optional[Union[str, Iterable[str]]] = None, + dag_id: Optional[Union[str, Iterable[str]]] = None, + include_prior_dates: bool = False, + session: Session = None, + ) -> Optional[Any]: """ Retrieve an XCom value, optionally meeting certain criteria. Returns None of there are no results. @@ -145,26 +132,30 @@ def get_one(cls, :param session: database session :type session: sqlalchemy.orm.session.Session """ - result = cls.get_many(execution_date=execution_date, - key=key, - task_ids=task_id, - dag_ids=dag_id, - include_prior_dates=include_prior_dates, - session=session).first() + result = cls.get_many( + execution_date=execution_date, + key=key, + task_ids=task_id, + dag_ids=dag_id, + include_prior_dates=include_prior_dates, + session=session, + ).first() if result: return result.value return None @classmethod @provide_session - def get_many(cls, - execution_date: pendulum.DateTime, - key: Optional[str] = None, - task_ids: Optional[Union[str, Iterable[str]]] = None, - dag_ids: Optional[Union[str, Iterable[str]]] = None, - include_prior_dates: bool = False, - limit: Optional[int] = None, - session: Session = None) -> Query: + def get_many( + cls, + execution_date: pendulum.DateTime, + key: Optional[str] = None, + task_ids: Optional[Union[str, Iterable[str]]] = None, + dag_ids: Optional[Union[str, Iterable[str]]] = None, + include_prior_dates: bool = False, + limit: Optional[int] = None, + session: Session = None, + ) -> Query: """ Composes a query to get one or more values from the xcom table. @@ -212,10 +203,11 @@ def get_many(cls, else: filters.append(cls.execution_date == execution_date) - query = (session - .query(cls) - .filter(and_(*filters)) - .order_by(cls.execution_date.desc(), cls.timestamp.desc())) + query = ( + session.query(cls) + .filter(and_(*filters)) + .order_by(cls.execution_date.desc(), cls.timestamp.desc()) + ) if limit: return query.limit(limit) @@ -230,9 +222,7 @@ def delete(cls, xcoms, session=None): xcoms = [xcoms] for xcom in xcoms: if not isinstance(xcom, XCom): - raise TypeError( - f'Expected XCom; received {xcom.__class__.__name__}' - ) + raise TypeError(f'Expected XCom; received {xcom.__class__.__name__}') session.delete(xcom) session.commit() @@ -244,10 +234,12 @@ def serialize_value(value: Any): try: return json.dumps(value).encode('UTF-8') except (ValueError, TypeError): - log.error("Could not serialize the XCOM value into JSON. 
" - "If you are using pickles instead of JSON " - "for XCOM, then you need to enable pickle " - "support for XCOM in your airflow config.") + log.error( + "Could not serialize the XCOM value into JSON. " + "If you are using pickles instead of JSON " + "for XCOM, then you need to enable pickle " + "support for XCOM in your airflow config." + ) raise @staticmethod @@ -259,10 +251,12 @@ def deserialize_value(result) -> Any: try: return json.loads(result.value.decode('UTF-8')) except JSONDecodeError: - log.error("Could not deserialize the XCOM value from JSON. " - "If you are using pickles instead of JSON " - "for XCOM, then you need to enable pickle " - "support for XCOM in your airflow config.") + log.error( + "Could not deserialize the XCOM value from JSON. " + "If you are using pickles instead of JSON " + "for XCOM, then you need to enable pickle " + "support for XCOM in your airflow config." + ) raise diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index aa29864fbe351..d09de69fd90e2 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -63,8 +63,7 @@ def __init__(self, operator: BaseOperator, key: str = XCOM_RETURN_KEY): self._key = key def __eq__(self, other): - return (self.operator == other.operator - and self.key == other.key) + return self.operator == other.operator and self.key == other.key def __getitem__(self, item): """Implements xcomresult['some_result_key']""" @@ -80,9 +79,10 @@ def __str__(self): :return: """ - xcom_pull_kwargs = [f"task_ids='{self.operator.task_id}'", - f"dag_id='{self.operator.dag.dag_id}'", - ] + xcom_pull_kwargs = [ + f"task_ids='{self.operator.task_id}'", + f"dag_id='{self.operator.dag.dag_id}'", + ] if self.key is not None: xcom_pull_kwargs.append(f"key='{self.key}'") @@ -132,7 +132,8 @@ def resolve(self, context: Dict) -> Any: if not resolved_value: raise AirflowException( f'XComArg result from {self.operator.task_id} at {self.operator.dag.dag_id} ' - f'with key="{self.key}"" is not found!') + f'with key="{self.key}"" is not found!' 
+ ) resolved_value = resolved_value[0] return resolved_value diff --git a/airflow/operators/bash.py b/airflow/operators/bash.py index 7f99a609bb9cf..6522025a3f7fc 100644 --- a/airflow/operators/bash.py +++ b/airflow/operators/bash.py @@ -103,17 +103,21 @@ class BashOperator(BaseOperator): template_fields = ('bash_command', 'env') template_fields_renderers = {'bash_command': 'bash', 'env': 'json'} - template_ext = ('.sh', '.bash',) + template_ext = ( + '.sh', + '.bash', + ) ui_color = '#f0ede4' @apply_defaults def __init__( - self, - *, - bash_command: str, - env: Optional[Dict[str, str]] = None, - output_encoding: str = 'utf-8', - **kwargs) -> None: + self, + *, + bash_command: str, + env: Optional[Dict[str, str]] = None, + output_encoding: str = 'utf-8', + **kwargs, + ) -> None: super().__init__(**kwargs) self.bash_command = bash_command @@ -136,9 +140,10 @@ def execute(self, context): env = os.environ.copy() airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) - self.log.debug('Exporting the following env vars:\n%s', - '\n'.join([f"{k}={v}" - for k, v in airflow_context_vars.items()])) + self.log.debug( + 'Exporting the following env vars:\n%s', + '\n'.join([f"{k}={v}" for k, v in airflow_context_vars.items()]), + ) env.update(airflow_context_vars) with TemporaryDirectory(prefix='airflowtmp') as tmp_dir: @@ -158,7 +163,8 @@ def pre_exec(): stderr=STDOUT, cwd=tmp_dir, env=env, - preexec_fn=pre_exec) + preexec_fn=pre_exec, + ) self.log.info('Output:') line = '' diff --git a/airflow/operators/bash_operator.py b/airflow/operators/bash_operator.py index 328771b37aa8b..49cdecf81c14c 100644 --- a/airflow/operators/bash_operator.py +++ b/airflow/operators/bash_operator.py @@ -23,6 +23,5 @@ from airflow.operators.bash import STDOUT, BashOperator, Popen, gettempdir # noqa warnings.warn( - "This module is deprecated. Please use `airflow.operators.bash`.", - DeprecationWarning, stacklevel=2 + "This module is deprecated. Please use `airflow.operators.bash`.", DeprecationWarning, stacklevel=2 ) diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py index 6b7ea555c5fdd..d18a25c65765b 100644 --- a/airflow/operators/check_operator.py +++ b/airflow/operators/check_operator.py @@ -21,12 +21,14 @@ import warnings from airflow.operators.sql import ( - SQLCheckOperator, SQLIntervalCheckOperator, SQLThresholdCheckOperator, SQLValueCheckOperator, + SQLCheckOperator, + SQLIntervalCheckOperator, + SQLThresholdCheckOperator, + SQLValueCheckOperator, ) warnings.warn( - "This module is deprecated. Please use `airflow.operators.sql`.", - DeprecationWarning, stacklevel=2 + "This module is deprecated. Please use `airflow.operators.sql`.", DeprecationWarning, stacklevel=2 ) @@ -40,7 +42,8 @@ def __init__(self, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.operators.sql.SQLCheckOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) @@ -55,7 +58,8 @@ def __init__(self, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.operators.sql.SQLIntervalCheckOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) @@ -70,7 +74,8 @@ def __init__(self, **kwargs): warnings.warn( """This class is deprecated. 
Please use `airflow.operators.sql.SQLThresholdCheckOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) @@ -85,6 +90,7 @@ def __init__(self, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.operators.sql.SQLValueCheckOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) diff --git a/airflow/operators/dagrun_operator.py b/airflow/operators/dagrun_operator.py index b926081def5df..0ca0cf5335cc9 100644 --- a/airflow/operators/dagrun_operator.py +++ b/airflow/operators/dagrun_operator.py @@ -73,7 +73,7 @@ def __init__( conf: Optional[Dict] = None, execution_date: Optional[Union[str, datetime.datetime]] = None, reset_dag_run: bool = False, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.trigger_dag_id = trigger_dag_id @@ -119,10 +119,7 @@ def execute(self, context: Dict): if dag_model is None: raise DagNotFound(f"Dag id {self.trigger_dag_id} not found in DagModel") - dag_bag = DagBag( - dag_folder=dag_model.fileloc, - read_dags_from_db=True - ) + dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True) dag = dag_bag.get_dag(self.trigger_dag_id) diff --git a/airflow/operators/docker_operator.py b/airflow/operators/docker_operator.py index 245ee589c0196..eacdf9b7280d2 100644 --- a/airflow/operators/docker_operator.py +++ b/airflow/operators/docker_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.docker.operators.docker`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/druid_check_operator.py b/airflow/operators/druid_check_operator.py index aad52ee431650..d8f30c3727384 100644 --- a/airflow/operators/druid_check_operator.py +++ b/airflow/operators/druid_check_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.apache.druid.operators.druid_check`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/email.py b/airflow/operators/email.py index 80d11310d1a13..c979eaff2066a 100644 --- a/airflow/operators/email.py +++ b/airflow/operators/email.py @@ -52,16 +52,18 @@ class EmailOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=invalid-name - self, *, - to: Union[List[str], str], - subject: str, - html_content: str, - files: Optional[List] = None, - cc: Optional[Union[List[str], str]] = None, - bcc: Optional[Union[List[str], str]] = None, - mime_subtype: str = 'mixed', - mime_charset: str = 'utf-8', - **kwargs) -> None: + self, + *, + to: Union[List[str], str], + subject: str, + html_content: str, + files: Optional[List] = None, + cc: Optional[Union[List[str], str]] = None, + bcc: Optional[Union[List[str], str]] = None, + mime_subtype: str = 'mixed', + mime_charset: str = 'utf-8', + **kwargs, + ) -> None: super().__init__(**kwargs) self.to = to # pylint: disable=invalid-name self.subject = subject @@ -73,6 +75,13 @@ def __init__( # pylint: disable=invalid-name self.mime_charset = mime_charset def execute(self, context): - send_email(self.to, self.subject, self.html_content, - files=self.files, cc=self.cc, bcc=self.bcc, - mime_subtype=self.mime_subtype, mime_charset=self.mime_charset) + send_email( + self.to, + self.subject, + self.html_content, + files=self.files, + cc=self.cc, + bcc=self.bcc, + mime_subtype=self.mime_subtype, + mime_charset=self.mime_charset, + ) diff --git a/airflow/operators/email_operator.py b/airflow/operators/email_operator.py index 135eb68f13f0b..b3588941ec8fd 100644 --- a/airflow/operators/email_operator.py +++ b/airflow/operators/email_operator.py @@ -23,6 +23,5 @@ from airflow.operators.email import EmailOperator # noqa warnings.warn( - "This module is deprecated. Please use `airflow.operators.email`.", - DeprecationWarning, stacklevel=2 + "This module is deprecated. Please use `airflow.operators.email`.", DeprecationWarning, stacklevel=2 ) diff --git a/airflow/operators/gcs_to_s3.py b/airflow/operators/gcs_to_s3.py index 19affc7c48553..5c9fb37f9f604 100644 --- a/airflow/operators/gcs_to_s3.py +++ b/airflow/operators/gcs_to_s3.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.amazon.aws.transfers.gcs_to_s3`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/generic_transfer.py b/airflow/operators/generic_transfer.py index 59f8fdcd86c57..a166b3a7a47a8 100644 --- a/airflow/operators/generic_transfer.py +++ b/airflow/operators/generic_transfer.py @@ -45,19 +45,23 @@ class GenericTransfer(BaseOperator): """ template_fields = ('sql', 'destination_table', 'preoperator') - template_ext = ('.sql', '.hql',) + template_ext = ( + '.sql', + '.hql', + ) ui_color = '#b0f07c' @apply_defaults def __init__( - self, - *, - sql: str, - destination_table: str, - source_conn_id: str, - destination_conn_id: str, - preoperator: Optional[Union[str, List[str]]] = None, - **kwargs) -> None: + self, + *, + sql: str, + destination_table: str, + source_conn_id: str, + destination_conn_id: str, + preoperator: Optional[Union[str, List[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.sql = sql self.destination_table = destination_table diff --git a/airflow/operators/google_api_to_s3_transfer.py b/airflow/operators/google_api_to_s3_transfer.py index 6f6d26500ef94..31b47fa0f08d4 100644 --- a/airflow/operators/google_api_to_s3_transfer.py +++ b/airflow/operators/google_api_to_s3_transfer.py @@ -25,9 +25,9 @@ from airflow.providers.amazon.aws.transfers.google_api_to_s3 import GoogleApiToS3Operator warnings.warn( - "This module is deprecated. " - "Please use `airflow.providers.amazon.aws.transfers.google_api_to_s3`.", - DeprecationWarning, stacklevel=2 + "This module is deprecated. " "Please use `airflow.providers.amazon.aws.transfers.google_api_to_s3`.", + DeprecationWarning, + stacklevel=2, ) @@ -42,8 +42,9 @@ def __init__(self, **kwargs): warnings.warn( """This class is deprecated. Please use - `airflow.providers.amazon.aws.transfers.""" + - "google_api_to_s3_transfer.GoogleApiToS3Operator`.", - DeprecationWarning, stacklevel=3 + `airflow.providers.amazon.aws.transfers.""" + + "google_api_to_s3_transfer.GoogleApiToS3Operator`.", + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) diff --git a/airflow/operators/hive_operator.py b/airflow/operators/hive_operator.py index 196a863de7831..7429733e77403 100644 --- a/airflow/operators/hive_operator.py +++ b/airflow/operators/hive_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hive.operators.hive`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/hive_stats_operator.py b/airflow/operators/hive_stats_operator.py index 4529d06cfeefb..737ccc4cd7ead 100644 --- a/airflow/operators/hive_stats_operator.py +++ b/airflow/operators/hive_stats_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hive.operators.hive_stats`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/hive_to_druid.py b/airflow/operators/hive_to_druid.py index 88a57cba0dc22..6d29f0dbed00a 100644 --- a/airflow/operators/hive_to_druid.py +++ b/airflow/operators/hive_to_druid.py @@ -27,7 +27,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.druid.transfers.hive_to_druid`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -43,6 +44,7 @@ def __init__(self, **kwargs): """This class is deprecated. 
Please use `airflow.providers.apache.druid.transfers.hive_to_druid.HiveToDruidOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) diff --git a/airflow/operators/hive_to_mysql.py b/airflow/operators/hive_to_mysql.py index 6264e57e61caa..bc42d89ec036c 100644 --- a/airflow/operators/hive_to_mysql.py +++ b/airflow/operators/hive_to_mysql.py @@ -27,7 +27,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.hive_to_mysql`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -43,6 +44,7 @@ def __init__(self, **kwargs): """This class is deprecated. Please use `airflow.providers.apache.hive.transfers.hive_to_mysql.HiveToMySqlOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) diff --git a/airflow/operators/hive_to_samba_operator.py b/airflow/operators/hive_to_samba_operator.py index d1d1e942e9ef1..3716bf90e9b3e 100644 --- a/airflow/operators/hive_to_samba_operator.py +++ b/airflow/operators/hive_to_samba_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.hive_to_samba`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/http_operator.py b/airflow/operators/http_operator.py index edecef3c345d3..34e938144f009 100644 --- a/airflow/operators/http_operator.py +++ b/airflow/operators/http_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.http.operators.http`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/jdbc_operator.py b/airflow/operators/jdbc_operator.py index 3361c4783824c..c0bf7021eab1a 100644 --- a/airflow/operators/jdbc_operator.py +++ b/airflow/operators/jdbc_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.jdbc.operators.jdbc`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/latest_only.py b/airflow/operators/latest_only.py index d19e1b047efaa..e458c72e28e3f 100644 --- a/airflow/operators/latest_only.py +++ b/airflow/operators/latest_only.py @@ -44,17 +44,17 @@ def choose_branch(self, context: Dict) -> Union[str, Iterable[str]]: # If the DAG Run is externally triggered, then return without # skipping downstream tasks if context['dag_run'] and context['dag_run'].external_trigger: - self.log.info( - "Externally triggered DAG_Run: allowing execution to proceed.") + self.log.info("Externally triggered DAG_Run: allowing execution to proceed.") return list(context['task'].get_direct_relative_ids(upstream=False)) now = pendulum.now('UTC') - left_window = context['dag'].following_schedule( - context['execution_date']) + left_window = context['dag'].following_schedule(context['execution_date']) right_window = context['dag'].following_schedule(left_window) self.log.info( 'Checking latest only with left_window: %s right_window: %s now: %s', - left_window, right_window, now + left_window, + right_window, + now, ) if not left_window < now <= right_window: diff --git a/airflow/operators/latest_only_operator.py b/airflow/operators/latest_only_operator.py index af6c90e19397e..cb1a531260305 100644 --- a/airflow/operators/latest_only_operator.py +++ b/airflow/operators/latest_only_operator.py @@ -22,6 +22,5 @@ from airflow.operators.latest_only import LatestOnlyOperator # noqa warnings.warn( - "This module is deprecated. Please use `airflow.operators.latest_only`.", - DeprecationWarning, stacklevel=2 + "This module is deprecated. Please use `airflow.operators.latest_only`.", DeprecationWarning, stacklevel=2 ) diff --git a/airflow/operators/mssql_operator.py b/airflow/operators/mssql_operator.py index 07a20192a51a2..8b5e8cc9caee3 100644 --- a/airflow/operators/mssql_operator.py +++ b/airflow/operators/mssql_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.microsoft.mssql.operators.mssql`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/mssql_to_hive.py b/airflow/operators/mssql_to_hive.py index 3fd6f7496280d..8749d709d9de6 100644 --- a/airflow/operators/mssql_to_hive.py +++ b/airflow/operators/mssql_to_hive.py @@ -26,7 +26,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.mssql_to_hive`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -42,6 +43,7 @@ def __init__(self, **kwargs): """This class is deprecated. Please use `airflow.providers.apache.hive.transfers.mssql_to_hive.MsSqlToHiveOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) diff --git a/airflow/operators/mysql_operator.py b/airflow/operators/mysql_operator.py index 6b2b1849ecbb7..05ae818cc485a 100644 --- a/airflow/operators/mysql_operator.py +++ b/airflow/operators/mysql_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.mysql.operators.mysql`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/mysql_to_hive.py b/airflow/operators/mysql_to_hive.py index 47735bb8d679c..243980f55f187 100644 --- a/airflow/operators/mysql_to_hive.py +++ b/airflow/operators/mysql_to_hive.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.mysql_to_hive`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.apache.hive.transfers.mysql_to_hive.MySqlToHiveOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) diff --git a/airflow/operators/oracle_operator.py b/airflow/operators/oracle_operator.py index 7b5e5a1e4dc9c..1734d22d259a2 100644 --- a/airflow/operators/oracle_operator.py +++ b/airflow/operators/oracle_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.oracle.operators.oracle`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/papermill_operator.py b/airflow/operators/papermill_operator.py index f071866a12bad..ab2d4fcc316e4 100644 --- a/airflow/operators/papermill_operator.py +++ b/airflow/operators/papermill_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.papermill.operators.papermill`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/pig_operator.py b/airflow/operators/pig_operator.py index 82e49452b6df5..4eec846a1f3e6 100644 --- a/airflow/operators/pig_operator.py +++ b/airflow/operators/pig_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.pig.operators.pig`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/postgres_operator.py b/airflow/operators/postgres_operator.py index 815311c162d93..c6f91b6bc1455 100644 --- a/airflow/operators/postgres_operator.py +++ b/airflow/operators/postgres_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.postgres.operators.postgres`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/presto_check_operator.py b/airflow/operators/presto_check_operator.py index 07546c297a889..79cc12ba7e5ff 100644 --- a/airflow/operators/presto_check_operator.py +++ b/airflow/operators/presto_check_operator.py @@ -23,8 +23,7 @@ from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator # noqa warnings.warn( - "This module is deprecated. Please use `airflow.operators.sql`.", - DeprecationWarning, stacklevel=2 + "This module is deprecated. Please use `airflow.operators.sql`.", DeprecationWarning, stacklevel=2 ) @@ -38,7 +37,8 @@ def __init__(self, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.operators.sql.SQLCheckOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) @@ -55,7 +55,8 @@ def __init__(self, **kwargs): This class is deprecated.l Please use `airflow.operators.sql.SQLIntervalCheckOperator`. 
""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) @@ -72,6 +73,7 @@ def __init__(self, **kwargs): This class is deprecated.l Please use `airflow.operators.sql.SQLValueCheckOperator`. """, - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) diff --git a/airflow/operators/presto_to_mysql.py b/airflow/operators/presto_to_mysql.py index 6a50d6bf6b694..f5ccd423067dd 100644 --- a/airflow/operators/presto_to_mysql.py +++ b/airflow/operators/presto_to_mysql.py @@ -27,7 +27,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.mysql.transfers.presto_to_mysql`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -43,6 +44,7 @@ def __init__(self, **kwargs): """This class is deprecated. Please use `airflow.providers.mysql.transfers.presto_to_mysql.PrestoToMySqlOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) diff --git a/airflow/operators/python.py b/airflow/operators/python.py index 344f0ecaceb4b..08546cabc6e6e 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -73,7 +73,10 @@ class PythonOperator(BaseOperator): # since we won't mutate the arguments, we should just do the shallow copy # there are some cases we can't deepcopy the objects(e.g protobuf). - shallow_copy_attrs = ('python_callable', 'op_kwargs',) + shallow_copy_attrs = ( + 'python_callable', + 'op_kwargs', + ) @apply_defaults def __init__( @@ -84,11 +87,14 @@ def __init__( op_kwargs: Optional[Dict] = None, templates_dict: Optional[Dict] = None, templates_exts: Optional[List[str]] = None, - **kwargs + **kwargs, ) -> None: if kwargs.get("provide_context"): - warnings.warn("provide_context is deprecated as of 2.0 and is no longer required", - DeprecationWarning, stacklevel=2) + warnings.warn( + "provide_context is deprecated as of 2.0 and is no longer required", + DeprecationWarning, + stacklevel=2, + ) kwargs.pop('provide_context', None) super().__init__(**kwargs) if not callable(python_callable): @@ -101,9 +107,7 @@ def __init__( self.template_ext = templates_exts @staticmethod - def determine_op_kwargs(python_callable: Callable, - context: Dict, - num_op_args: int = 0) -> Dict: + def determine_op_kwargs(python_callable: Callable, context: Dict, num_op_args: int = 0) -> Dict: """ Function that will inspect the signature of a python_callable to determine which values need to be passed to the function. 
@@ -190,7 +194,7 @@ def __init__( op_args: Tuple[Any], op_kwargs: Dict[str, Any], multiple_outputs: bool = False, - **kwargs + **kwargs, ) -> None: kwargs['task_id'] = self._get_unique_task_id(task_id, kwargs.get('dag')) super().__init__(**kwargs) @@ -221,9 +225,11 @@ def _get_unique_task_id(task_id: str, dag: Optional[DAG] = None) -> str: return task_id core = re.split(r'__\d+$', task_id)[0] suffixes = sorted( - [int(re.split(r'^.+__', task_id)[1]) - for task_id in dag.task_ids - if re.match(rf'^{core}__\d+$', task_id)] + [ + int(re.split(r'^.+__', task_id)[1]) + for task_id in dag.task_ids + if re.match(rf'^{core}__\d+$', task_id) + ] ) if not suffixes: return f'{core}__1' @@ -251,13 +257,16 @@ def execute(self, context: Dict): if isinstance(return_value, dict): for key in return_value.keys(): if not isinstance(key, str): - raise AirflowException('Returned dictionary keys must be strings when using ' - f'multiple_outputs, found {key} ({type(key)}) instead') + raise AirflowException( + 'Returned dictionary keys must be strings when using ' + f'multiple_outputs, found {key} ({type(key)}) instead' + ) for key, value in return_value.items(): self.xcom_push(context, key, value) else: - raise AirflowException(f'Returned output was type {type(return_value)} expected dictionary ' - 'for multiple_outputs') + raise AirflowException( + f'Returned output was type {type(return_value)} expected dictionary ' 'for multiple_outputs' + ) return return_value @@ -265,9 +274,7 @@ def execute(self, context: Dict): def task( - python_callable: Optional[Callable] = None, - multiple_outputs: bool = False, - **kwargs + python_callable: Optional[Callable] = None, multiple_outputs: bool = False, **kwargs ) -> Callable[[T], T]: """ Python operator decorator. Wraps a function into an Airflow operator. @@ -282,6 +289,7 @@ def task( :type multiple_outputs: bool """ + def wrapper(f: T): """ Python wrapper to generate PythonDecoratedOperator out of simple python functions. 
@@ -292,12 +300,19 @@ def wrapper(f: T): @functools.wraps(f) def factory(*args, **f_kwargs): - op = _PythonDecoratedOperator(python_callable=f, op_args=args, op_kwargs=f_kwargs, - multiple_outputs=multiple_outputs, **kwargs) + op = _PythonDecoratedOperator( + python_callable=f, + op_args=args, + op_kwargs=f_kwargs, + multiple_outputs=multiple_outputs, + **kwargs, + ) if f.__doc__: op.doc_md = f.__doc__ return XComArg(op) + return cast(T, factory) + if callable(python_callable): return wrapper(python_callable) elif python_callable is not None: @@ -427,22 +442,16 @@ class PythonVirtualenvOperator(PythonOperator): 'ts_nodash', 'ts_nodash_with_tz', 'yesterday_ds', - 'yesterday_ds_nodash' + 'yesterday_ds_nodash', } PENDULUM_SERIALIZABLE_CONTEXT_KEYS = { 'execution_date', 'next_execution_date', 'prev_execution_date', 'prev_execution_date_success', - 'prev_start_date_success' - } - AIRFLOW_SERIALIZABLE_CONTEXT_KEYS = { - 'macros', - 'conf', - 'dag', - 'dag_run', - 'task' + 'prev_start_date_success', } + AIRFLOW_SERIALIZABLE_CONTEXT_KEYS = {'macros', 'conf', 'dag', 'dag_run', 'task'} @apply_defaults def __init__( # pylint: disable=too-many-arguments @@ -458,26 +467,31 @@ def __init__( # pylint: disable=too-many-arguments string_args: Optional[Iterable[str]] = None, templates_dict: Optional[Dict] = None, templates_exts: Optional[List[str]] = None, - **kwargs + **kwargs, ): if ( - not isinstance(python_callable, types.FunctionType) or - isinstance(python_callable, types.LambdaType) and python_callable.__name__ == "<lambda>" + not isinstance(python_callable, types.FunctionType) + or isinstance(python_callable, types.LambdaType) + and python_callable.__name__ == "<lambda>" ): raise AirflowException('PythonVirtualenvOperator only supports functions for python_callable arg') if ( - python_version and str(python_version)[0] != str(sys.version_info.major) and - (op_args or op_kwargs) + python_version + and str(python_version)[0] != str(sys.version_info.major) + and (op_args or op_kwargs) ): - raise AirflowException("Passing op_args or op_kwargs is not supported across different Python " - "major versions for PythonVirtualenvOperator. Please use string_args.") + raise AirflowException( + "Passing op_args or op_kwargs is not supported across different Python " + "major versions for PythonVirtualenvOperator. Please use string_args."
+ ) super().__init__( python_callable=python_callable, op_args=op_args, op_kwargs=op_kwargs, templates_dict=templates_dict, templates_exts=templates_exts, - **kwargs) + **kwargs, + ) self.requirements = list(requirements or []) self.string_args = string_args or [] self.python_version = python_version @@ -505,7 +519,7 @@ def execute_callable(self): venv_directory=tmp_dir, python_bin=f'python{self.python_version}' if self.python_version else None, system_site_packages=self.system_site_packages, - requirements=self.requirements + requirements=self.requirements, ) self._write_args(input_filename) @@ -516,18 +530,20 @@ def execute_callable(self): op_kwargs=self.op_kwargs, pickling_library=self.pickling_library.__name__, python_callable=self.python_callable.__name__, - python_callable_source=dedent(inspect.getsource(self.python_callable)) + python_callable_source=dedent(inspect.getsource(self.python_callable)), ), - filename=script_filename + filename=script_filename, ) - execute_in_subprocess(cmd=[ - f'{tmp_dir}/bin/python', - script_filename, - input_filename, - output_filename, - string_args_filename - ]) + execute_in_subprocess( + cmd=[ + f'{tmp_dir}/bin/python', + script_filename, + input_filename, + output_filename, + string_args_filename, + ] + ) return self._read_result(output_filename) @@ -561,8 +577,10 @@ def _read_result(self, filename): try: return self.pickling_library.load(file) except ValueError: - self.log.error("Error deserializing result. Note that result deserialization " - "is not supported across major Python versions.") + self.log.error( + "Error deserializing result. Note that result deserialization " + "is not supported across major Python versions." + ) raise diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 5318c15cc08ed..7ae3d4fa78c80 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -21,10 +21,12 @@ # pylint: disable=unused-import from airflow.operators.python import ( # noqa - BranchPythonOperator, PythonOperator, PythonVirtualenvOperator, ShortCircuitOperator, + BranchPythonOperator, + PythonOperator, + PythonVirtualenvOperator, + ShortCircuitOperator, ) warnings.warn( - "This module is deprecated. Please use `airflow.operators.python`.", - DeprecationWarning, stacklevel=2 + "This module is deprecated. Please use `airflow.operators.python`.", DeprecationWarning, stacklevel=2 ) diff --git a/airflow/operators/redshift_to_s3_operator.py b/airflow/operators/redshift_to_s3_operator.py index a1ebadb4a54cc..20b92306432f2 100644 --- a/airflow/operators/redshift_to_s3_operator.py +++ b/airflow/operators/redshift_to_s3_operator.py @@ -26,7 +26,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.redshift_to_s3`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -42,6 +43,7 @@ def __init__(self, **kwargs): """This class is deprecated. Please use `airflow.providers.amazon.aws.transfers.redshift_to_s3.RedshiftToS3Operator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) diff --git a/airflow/operators/s3_file_transform_operator.py b/airflow/operators/s3_file_transform_operator.py index 8b175a8c7f688..6616778495225 100644 --- a/airflow/operators/s3_file_transform_operator.py +++ b/airflow/operators/s3_file_transform_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.amazon.aws.operators.s3_file_transform`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/s3_to_hive_operator.py b/airflow/operators/s3_to_hive_operator.py index d7108cc072304..584a31f37c0ed 100644 --- a/airflow/operators/s3_to_hive_operator.py +++ b/airflow/operators/s3_to_hive_operator.py @@ -23,7 +23,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.s3_to_hive`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -37,6 +38,7 @@ def __init__(self, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.providers.apache.hive.transfers.s3_to_hive.S3ToHiveOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) diff --git a/airflow/operators/s3_to_redshift_operator.py b/airflow/operators/s3_to_redshift_operator.py index 2148fdb429b39..b885c58c3e6ce 100644 --- a/airflow/operators/s3_to_redshift_operator.py +++ b/airflow/operators/s3_to_redshift_operator.py @@ -26,7 +26,8 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.s3_to_redshift`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) @@ -42,6 +43,7 @@ def __init__(self, **kwargs): """This class is deprecated. Please use `airflow.providers.amazon.aws.transfers.s3_to_redshift.S3ToRedshiftOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) diff --git a/airflow/operators/slack_operator.py b/airflow/operators/slack_operator.py index e154d9f40b576..dedd7699bc94f 100644 --- a/airflow/operators/slack_operator.py +++ b/airflow/operators/slack_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.slack.operators.slack`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index 45cb07c2c131e..3210ff1ae0148 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -71,13 +71,14 @@ class SQLCheckOperator(BaseOperator): """ template_fields: Iterable[str] = ("sql",) - template_ext: Iterable[str] = (".hql", ".sql",) + template_ext: Iterable[str] = ( + ".hql", + ".sql", + ) ui_color = "#fff7e6" @apply_defaults - def __init__( - self, *, sql: str, conn_id: Optional[str] = None, **kwargs - ) -> None: + def __init__(self, *, sql: str, conn_id: Optional[str] = None, **kwargs) -> None: super().__init__(**kwargs) self.conn_id = conn_id self.sql = sql @@ -90,11 +91,7 @@ def execute(self, context=None): if not records: raise AirflowException("The query returned None") elif not all(bool(r) for r in records): - raise AirflowException( - "Test failed.\nQuery:\n{query}\nResults:\n{records!s}".format( - query=self.sql, records=records - ) - ) + raise AirflowException(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}") self.log.info("Success.") @@ -148,7 +145,8 @@ class SQLValueCheckOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, sql: str, pass_value: Any, tolerance: Any = None, @@ -191,9 +189,7 @@ def execute(self, context=None): try: numeric_records = self._to_float(records) except (ValueError, TypeError): - raise AirflowException( - f"Converting a result to float failed.\n{error_msg}" - ) + raise AirflowException(f"Converting a result to float failed.\n{error_msg}") tests = self._get_numeric_matches(numeric_records, pass_value_conv) else: tests = [] @@ -257,7 +253,10 @@ class SQLIntervalCheckOperator(BaseOperator): __mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"} template_fields: Iterable[str] = ("sql1", "sql2") - template_ext: Iterable[str] = (".hql", ".sql",) + template_ext: Iterable[str] = ( + ".hql", + ".sql", + ) ui_color = "#fff7e6" ratio_formulas = { @@ -267,7 +266,8 @@ class SQLIntervalCheckOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, table: str, metrics_thresholds: Dict[str, int], date_filter_column: Optional[str] = "ds", @@ -279,15 +279,10 @@ def __init__( ): super().__init__(**kwargs) if ratio_formula not in self.ratio_formulas: - msg_template = ( - "Invalid diff_method: {diff_method}. " - "Supported diff methods are: {diff_methods}" - ) + msg_template = "Invalid diff_method: {diff_method}. 
Supported diff methods are: {diff_methods}" raise AirflowException( - msg_template.format( - diff_method=ratio_formula, diff_methods=self.ratio_formulas - ) + msg_template.format(diff_method=ratio_formula, diff_methods=self.ratio_formulas) ) self.ratio_formula = ratio_formula self.ignore_zero = ignore_zero @@ -332,9 +327,7 @@ def execute(self, context=None): ratios[metric] = None test_results[metric] = self.ignore_zero else: - ratios[metric] = self.ratio_formulas[self.ratio_formula]( - current[metric], reference[metric] - ) + ratios[metric] = self.ratio_formulas[self.ratio_formula](current[metric], reference[metric]) test_results[metric] = ratios[metric] < threshold self.log.info( @@ -368,9 +361,7 @@ def execute(self, context=None): self.metrics_thresholds[k], ) raise AirflowException( - "The following tests have failed:\n {}".format( - ", ".join(sorted(failed_tests)) - ) + "The following tests have failed:\n {}".format(", ".join(sorted(failed_tests))) ) self.log.info("All tests have passed") @@ -411,7 +402,8 @@ class SQLThresholdCheckOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, sql: str, min_threshold: Any, max_threshold: Any, @@ -500,7 +492,8 @@ class BranchSQLOperator(BaseOperator, SkipMixin): @apply_defaults def __init__( - self, *, + self, + *, sql: str, follow_task_ids_if_true: List[str], follow_task_ids_if_false: List[str], @@ -525,7 +518,9 @@ def _get_hook(self): if conn.conn_type not in ALLOWED_CONN_TYPE: raise AirflowException( "The connection type is not supported by BranchSQLOperator.\ - Supported connection types: {}".format(list(ALLOWED_CONN_TYPE)) + Supported connection types: {}".format( + list(ALLOWED_CONN_TYPE) + ) ) if not self._hook: @@ -540,22 +535,16 @@ def execute(self, context: Dict): self._hook = self._get_hook() if self._hook is None: - raise AirflowException( - "Failed to establish connection to '%s'" % self.conn_id - ) + raise AirflowException("Failed to establish connection to '%s'" % self.conn_id) if self.sql is None: raise AirflowException("Expected 'sql' parameter is missing.") if self.follow_task_ids_if_true is None: - raise AirflowException( - "Expected 'follow_task_ids_if_true' parameter is missing." - ) + raise AirflowException("Expected 'follow_task_ids_if_true' parameter is missing.") if self.follow_task_ids_if_false is None: - raise AirflowException( - "Expected 'follow_task_ids_if_false' parameter is missing." 
- ) + raise AirflowException("Expected 'follow_task_ids_if_false' parameter is missing.") self.log.info( "Executing: %s (with parameters %s) with connection: %s", @@ -595,16 +584,14 @@ def execute(self, context: Dict): follow_branch = self.follow_task_ids_if_true else: raise AirflowException( - "Unexpected query return result '%s' type '%s'" - % (query_result, type(query_result)) + "Unexpected query return result '{}' type '{}'".format(query_result, type(query_result)) ) if follow_branch is None: follow_branch = self.follow_task_ids_if_false except ValueError: raise AirflowException( - "Unexpected query return result '%s' type '%s'" - % (query_result, type(query_result)) + "Unexpected query return result '{}' type '{}'".format(query_result, type(query_result)) ) self.skip_all_except(context["ti"], follow_branch) diff --git a/airflow/operators/sql_branch_operator.py b/airflow/operators/sql_branch_operator.py index 3c773b6109040..5dfa7b36ab2f6 100644 --- a/airflow/operators/sql_branch_operator.py +++ b/airflow/operators/sql_branch_operator.py @@ -20,8 +20,7 @@ from airflow.operators.sql import BranchSQLOperator warnings.warn( - """This module is deprecated. Please use `airflow.operators.sql`.""", - DeprecationWarning, stacklevel=2 + """This module is deprecated. Please use `airflow.operators.sql`.""", DeprecationWarning, stacklevel=2 ) @@ -35,6 +34,7 @@ def __init__(self, **kwargs): warnings.warn( """This class is deprecated. Please use `airflow.operators.sql.BranchSQLOperator`.""", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) super().__init__(**kwargs) diff --git a/airflow/operators/sqlite_operator.py b/airflow/operators/sqlite_operator.py index 6729852a428e7..daa496da70f2f 100644 --- a/airflow/operators/sqlite_operator.py +++ b/airflow/operators/sqlite_operator.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.sqlite.operators.sqlite`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/operators/subdag_operator.py b/airflow/operators/subdag_operator.py index 7ef0392de68b8..f5c147e940439 100644 --- a/airflow/operators/subdag_operator.py +++ b/airflow/operators/subdag_operator.py @@ -63,13 +63,15 @@ class SubDagOperator(BaseSensorOperator): @provide_session @apply_defaults - def __init__(self, - *, - subdag: DAG, - session: Optional[Session] = None, - conf: Optional[Dict] = None, - propagate_skipped_state: Optional[SkippedStatePropagationOptions] = None, - **kwargs) -> None: + def __init__( + self, + *, + subdag: DAG, + session: Optional[Session] = None, + conf: Optional[Dict] = None, + propagate_skipped_state: Optional[SkippedStatePropagationOptions] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.subdag = subdag self.conf = conf @@ -88,7 +90,8 @@ def _validate_dag(self, kwargs): raise AirflowException( "The subdag's dag_id should have the form '{{parent_dag_id}}.{{this_task_id}}'. 
" "Expected '{d}.{t}'; received '{rcvd}'.".format( - d=dag.dag_id, t=kwargs['task_id'], rcvd=self.subdag.dag_id) + d=dag.dag_id, t=kwargs['task_id'], rcvd=self.subdag.dag_id + ) ) def _validate_pool(self, session): @@ -96,11 +99,7 @@ def _validate_pool(self, session): conflicts = [t for t in self.subdag.tasks if t.pool == self.pool] if conflicts: # only query for pool conflicts if one may exist - pool = (session - .query(Pool) - .filter(Pool.slots == 1) - .filter(Pool.pool == self.pool) - .first()) + pool = session.query(Pool).filter(Pool.slots == 1).filter(Pool.pool == self.pool).first() if pool and any(t.pool == self.pool for t in self.subdag.tasks): raise AirflowException( 'SubDagOperator {sd} and subdag task{plural} {t} both ' @@ -109,7 +108,7 @@ def _validate_pool(self, session): sd=self.task_id, plural=len(conflicts) > 1, t=', '.join(t.task_id for t in conflicts), - p=self.pool + p=self.pool, ) ) @@ -132,11 +131,12 @@ def _reset_dag_run_and_task_instances(self, dag_run, execution_date): with create_session() as session: dag_run.state = State.RUNNING session.merge(dag_run) - failed_task_instances = (session - .query(TaskInstance) - .filter(TaskInstance.dag_id == self.subdag.dag_id) - .filter(TaskInstance.execution_date == execution_date) - .filter(TaskInstance.state.in_([State.FAILED, State.UPSTREAM_FAILED]))) + failed_task_instances = ( + session.query(TaskInstance) + .filter(TaskInstance.dag_id == self.subdag.dag_id) + .filter(TaskInstance.execution_date == execution_date) + .filter(TaskInstance.state.in_([State.FAILED, State.UPSTREAM_FAILED])) + ) for task_instance in failed_task_instances: task_instance.state = State.NONE @@ -172,9 +172,7 @@ def post_execute(self, context, result=None): self.log.info("Execution finished. State is %s", dag_run.state) if dag_run.state != State.SUCCESS: - raise AirflowException( - f"Expected state: SUCCESS. Actual state: {dag_run.state}" - ) + raise AirflowException(f"Expected state: SUCCESS. Actual state: {dag_run.state}") if self.propagate_skipped_state and self._check_skipped_states(context): self._skip_downstream_tasks(context) @@ -187,16 +185,15 @@ def _check_skipped_states(self, context): if self.propagate_skipped_state == SkippedStatePropagationOptions.ALL_LEAVES: return all(ti.state == State.SKIPPED for ti in leaves_tis) raise AirflowException( - f'Unimplemented SkippedStatePropagationOptions {self.propagate_skipped_state} used.') + f'Unimplemented SkippedStatePropagationOptions {self.propagate_skipped_state} used.' 
+ ) def _get_leaves_tis(self, execution_date): leaves_tis = [] for leaf in self.subdag.leaves: try: ti = get_task_instance( - dag_id=self.subdag.dag_id, - task_id=leaf.task_id, - execution_date=execution_date + dag_id=self.subdag.dag_id, task_id=leaf.task_id, execution_date=execution_date ) leaves_tis.append(ti) except TaskInstanceNotFound: @@ -204,8 +201,11 @@ def _get_leaves_tis(self, execution_date): return leaves_tis def _skip_downstream_tasks(self, context): - self.log.info('Skipping downstream tasks because propagate_skipped_state is set to %s ' - 'and skipped task(s) were found.', self.propagate_skipped_state) + self.log.info( + 'Skipping downstream tasks because propagate_skipped_state is set to %s ' + 'and skipped task(s) were found.', + self.propagate_skipped_state, + ) downstream_tasks = context['task'].downstream_list self.log.debug('Downstream task_ids %s', downstream_tasks) diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index f6fd08b6c0e2e..99c05b9fce6b3 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -160,9 +160,9 @@ def is_valid_plugin(plugin_obj): global plugins # pylint: disable=global-statement if ( - inspect.isclass(plugin_obj) and - issubclass(plugin_obj, AirflowPlugin) and - (plugin_obj is not AirflowPlugin) + inspect.isclass(plugin_obj) + and issubclass(plugin_obj, AirflowPlugin) + and (plugin_obj is not AirflowPlugin) ): plugin_obj.validate() return plugin_obj not in plugins @@ -202,8 +202,7 @@ def load_plugins_from_plugin_directory(): global plugins # pylint: disable=global-statement log.debug("Loading plugins from directory: %s", settings.PLUGINS_FOLDER) - for file_path in find_path_from_directory( - settings.PLUGINS_FOLDER, ".airflowignore"): + for file_path in find_path_from_directory(settings.PLUGINS_FOLDER, ".airflowignore"): if not os.path.isfile(file_path): continue @@ -239,9 +238,11 @@ def make_module(name: str, objects: List[Any]): name = name.lower() module = types.ModuleType(name) module._name = name.split('.')[-1] # type: ignore - module._objects = objects # type: ignore + module._objects = objects # type: ignore module.__dict__.update((o.__name__, o) for o in objects) return module + + # pylint: enable=protected-access @@ -277,9 +278,11 @@ def initialize_web_ui_plugins(): global flask_appbuilder_menu_links # pylint: enable=global-statement - if flask_blueprints is not None and \ - flask_appbuilder_views is not None and \ - flask_appbuilder_menu_links is not None: + if ( + flask_blueprints is not None + and flask_appbuilder_views is not None + and flask_appbuilder_menu_links is not None + ): return ensure_plugins_loaded() @@ -296,17 +299,15 @@ def initialize_web_ui_plugins(): for plugin in plugins: flask_appbuilder_views.extend(plugin.appbuilder_views) flask_appbuilder_menu_links.extend(plugin.appbuilder_menu_items) - flask_blueprints.extend([{ - 'name': plugin.name, - 'blueprint': bp - } for bp in plugin.flask_blueprints]) + flask_blueprints.extend([{'name': plugin.name, 'blueprint': bp} for bp in plugin.flask_blueprints]) if (plugin.admin_views and not plugin.appbuilder_views) or ( - plugin.menu_links and not plugin.appbuilder_menu_items): + plugin.menu_links and not plugin.appbuilder_menu_items + ): log.warning( "Plugin \'%s\' may not be compatible with the current Airflow version. 
" "Please contact the author of the plugin.", - plugin.name + plugin.name, ) @@ -318,9 +319,11 @@ def initialize_extra_operators_links_plugins(): global registered_operator_link_classes # pylint: enable=global-statement - if global_operator_extra_links is not None and \ - operator_extra_links is not None and \ - registered_operator_link_classes is not None: + if ( + global_operator_extra_links is not None + and operator_extra_links is not None + and registered_operator_link_classes is not None + ): return ensure_plugins_loaded() @@ -338,11 +341,12 @@ def initialize_extra_operators_links_plugins(): global_operator_extra_links.extend(plugin.global_operator_extra_links) operator_extra_links.extend(list(plugin.operator_extra_links)) - registered_operator_link_classes.update({ - "{}.{}".format(link.__class__.__module__, - link.__class__.__name__): link.__class__ - for link in plugin.operator_extra_links - }) + registered_operator_link_classes.update( + { + f"{link.__class__.__module__}.{link.__class__.__name__}": link.__class__ + for link in plugin.operator_extra_links + } + ) def integrate_executor_plugins() -> None: @@ -384,10 +388,12 @@ def integrate_dag_plugins() -> None: global macros_modules # pylint: enable=global-statement - if operators_modules is not None and \ - sensors_modules is not None and \ - hooks_modules is not None and \ - macros_modules is not None: + if ( + operators_modules is not None + and sensors_modules is not None + and hooks_modules is not None + and macros_modules is not None + ): return ensure_plugins_loaded() diff --git a/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py b/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py index 9fd2a2034e0ba..87d6aa6221c69 100644 --- a/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py +++ b/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py @@ -17,9 +17,7 @@ import os from airflow import models -from airflow.providers.amazon.aws.operators.glacier import ( - GlacierCreateJobOperator, -) +from airflow.providers.amazon.aws.operators.glacier import GlacierCreateJobOperator from airflow.providers.amazon.aws.sensors.glacier import GlacierJobOperationSensor from airflow.providers.amazon.aws.transfers.glacier_to_gcs import GlacierToGCSOperator from airflow.utils.dates import days_ago diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index b16ea435f7f01..12878377aed17 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -29,8 +29,8 @@ from typing import Any, Dict, Optional, Tuple, Union import boto3 -from botocore.credentials import ReadOnlyCredentials from botocore.config import Config +from botocore.credentials import ReadOnlyCredentials from cached_property import cached_property from airflow.exceptions import AirflowException diff --git a/airflow/providers/amazon/aws/hooks/cloud_formation.py b/airflow/providers/amazon/aws/hooks/cloud_formation.py index 610cded6fec9b..ef20a3083610a 100644 --- a/airflow/providers/amazon/aws/hooks/cloud_formation.py +++ b/airflow/providers/amazon/aws/hooks/cloud_formation.py @@ -19,8 +19,8 @@ """This module contains AWS CloudFormation Hook""" from typing import Optional, Union -from botocore.exceptions import ClientError from boto3 import client, resource +from botocore.exceptions import ClientError from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook diff --git 
a/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py b/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py index 54305d5b86129..d6ddb9cb661c5 100644 --- a/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py +++ b/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py @@ -15,9 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Optional - from time import sleep +from typing import Optional from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook diff --git a/airflow/providers/amazon/aws/hooks/glacier.py b/airflow/providers/amazon/aws/hooks/glacier.py index 8b2f239f162e5..c2e0509ef2f8a 100644 --- a/airflow/providers/amazon/aws/hooks/glacier.py +++ b/airflow/providers/amazon/aws/hooks/glacier.py @@ -18,6 +18,7 @@ from typing import Any, Dict + from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook diff --git a/airflow/providers/amazon/aws/hooks/glue_catalog.py b/airflow/providers/amazon/aws/hooks/glue_catalog.py index 65d44666bc0f0..32c4d7c6635e4 100644 --- a/airflow/providers/amazon/aws/hooks/glue_catalog.py +++ b/airflow/providers/amazon/aws/hooks/glue_catalog.py @@ -17,7 +17,7 @@ # under the License. """This module contains AWS Glue Catalog Hook""" -from typing import Set, Optional +from typing import Optional, Set from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py index 9190d27e1eb69..bd51813e25716 100644 --- a/airflow/providers/amazon/aws/hooks/sagemaker.py +++ b/airflow/providers/amazon/aws/hooks/sagemaker.py @@ -22,7 +22,7 @@ import time import warnings from functools import partial -from typing import Dict, List, Optional, Set, Any, Callable, Generator +from typing import Any, Callable, Dict, Generator, List, Optional, Set from botocore.exceptions import ClientError diff --git a/airflow/providers/amazon/aws/hooks/secrets_manager.py b/airflow/providers/amazon/aws/hooks/secrets_manager.py index 117c83d9c7dba..3d0289a3a7eb5 100644 --- a/airflow/providers/amazon/aws/hooks/secrets_manager.py +++ b/airflow/providers/amazon/aws/hooks/secrets_manager.py @@ -20,6 +20,7 @@ import base64 import json from typing import Union + from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook diff --git a/airflow/providers/amazon/aws/hooks/sns.py b/airflow/providers/amazon/aws/hooks/sns.py index 795e78a961723..af1ce6ec27594 100644 --- a/airflow/providers/amazon/aws/hooks/sns.py +++ b/airflow/providers/amazon/aws/hooks/sns.py @@ -18,7 +18,7 @@ """This module contains AWS SNS hook""" import json -from typing import Optional, Union, Dict +from typing import Dict, Optional, Union from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index 82868e6aeaa9a..07a5b4a4f77d0 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -26,7 +26,7 @@ - http://boto3.readthedocs.io/en/latest/reference/services/batch.html - https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html """ -from typing import Dict, Optional, Any +from typing import Any, Dict, Optional from airflow.exceptions import AirflowException from airflow.models import BaseOperator diff 
--git a/airflow/providers/amazon/aws/operators/datasync.py b/airflow/providers/amazon/aws/operators/datasync.py index 661054afd85b7..aa7a6b9455cee 100644 --- a/airflow/providers/amazon/aws/operators/datasync.py +++ b/airflow/providers/amazon/aws/operators/datasync.py @@ -19,7 +19,7 @@ import logging import random -from typing import Optional, List +from typing import List, Optional from airflow.exceptions import AirflowException from airflow.models import BaseOperator diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 2ec170667c284..9584a31d77271 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -18,7 +18,7 @@ import re import sys from datetime import datetime -from typing import Optional, Dict +from typing import Dict, Optional from botocore.waiter import Waiter diff --git a/airflow/providers/amazon/aws/operators/sagemaker_transform.py b/airflow/providers/amazon/aws/operators/sagemaker_transform.py index 1dadb3daae6ce..7caf9f1beee69 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_transform.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_transform.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Optional, List +from typing import List, Optional from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook diff --git a/airflow/providers/amazon/aws/sensors/emr_base.py b/airflow/providers/amazon/aws/sensors/emr_base.py index d862c6b4efc12..0fa0878457578 100644 --- a/airflow/providers/amazon/aws/sensors/emr_base.py +++ b/airflow/providers/amazon/aws/sensors/emr_base.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Optional, Iterable +from typing import Any, Dict, Iterable, Optional from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.emr import EmrHook diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_training.py b/airflow/providers/amazon/aws/sensors/sagemaker_training.py index 9cd76688facdf..302069ef9c5be 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker_training.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker_training.py @@ -15,9 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Optional - import time +from typing import Optional from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor diff --git a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py index 0d95ee7861d9c..47b6c1998d362 100644 --- a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py @@ -25,7 +25,7 @@ from copy import copy from os.path import getsize from tempfile import NamedTemporaryFile -from typing import Any, Callable, Dict, Optional, IO +from typing import IO, Any, Callable, Dict, Optional from uuid import uuid4 from airflow.models import BaseOperator diff --git a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py index 48559f8588779..4b552e4a575e4 100644 --- a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py @@ -17,7 +17,7 @@ # under the License. """This module contains Google Cloud Storage to S3 operator.""" import warnings -from typing import Iterable, Optional, Sequence, Union, Dict, List, cast +from typing import Dict, Iterable, List, Optional, Sequence, Union, cast from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook diff --git a/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py b/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py index 3506003181ba6..5ee9802f1a87e 100644 --- a/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py +++ b/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. import tempfile -from typing import Optional, Union, Sequence +from typing import Optional, Sequence, Union from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.glacier import GlacierHook diff --git a/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py b/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py index 16568dba9ddf8..2939be57a2096 100644 --- a/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py +++ b/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py @@ -19,7 +19,7 @@ """This module contains operator to move data from Hive to DynamoDB.""" import json -from typing import Optional, Callable +from typing import Callable, Optional from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.dynamodb import AwsDynamoDBHook diff --git a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py index b88087e291a51..0855fe99e3908 100644 --- a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. 
import json -from typing import Optional, Any, Iterable, Union, cast +from typing import Any, Iterable, Optional, Union, cast from bson import json_util diff --git a/airflow/providers/amazon/backport_provider_setup.py b/airflow/providers/amazon/backport_provider_setup.py index 7c7054264c664..f21a080fc8829 100644 --- a/airflow/providers/amazon/backport_provider_setup.py +++ b/airflow/providers/amazon/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/apache/cassandra/backport_provider_setup.py b/airflow/providers/apache/cassandra/backport_provider_setup.py index 73632d72fdbfd..db239d0c4b6d3 100644 --- a/airflow/providers/apache/cassandra/backport_provider_setup.py +++ b/airflow/providers/apache/cassandra/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/apache/druid/backport_provider_setup.py b/airflow/providers/apache/druid/backport_provider_setup.py index cd9f547e65181..6674a2214890b 100644 --- a/airflow/providers/apache/druid/backport_provider_setup.py +++ b/airflow/providers/apache/druid/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/apache/hdfs/backport_provider_setup.py b/airflow/providers/apache/hdfs/backport_provider_setup.py index 9252805189dc7..2815d167df70e 100644 --- a/airflow/providers/apache/hdfs/backport_provider_setup.py +++ b/airflow/providers/apache/hdfs/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/apache/hive/backport_provider_setup.py b/airflow/providers/apache/hive/backport_provider_setup.py index 9fdd0f292bc5b..e817ba7bcbce3 100644 --- a/airflow/providers/apache/hive/backport_provider_setup.py +++ b/airflow/providers/apache/hive/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/apache/kylin/backport_provider_setup.py b/airflow/providers/apache/kylin/backport_provider_setup.py index 9b3e4ac996752..61a4bc7917f67 100644 --- a/airflow/providers/apache/kylin/backport_provider_setup.py +++ b/airflow/providers/apache/kylin/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/apache/livy/backport_provider_setup.py b/airflow/providers/apache/livy/backport_provider_setup.py index a5ee750127c73..27024e91dd367 100644 --- a/airflow/providers/apache/livy/backport_provider_setup.py +++ b/airflow/providers/apache/livy/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/apache/pig/backport_provider_setup.py b/airflow/providers/apache/pig/backport_provider_setup.py index 
9fa0d6fdde6cc..35150daedb124 100644 --- a/airflow/providers/apache/pig/backport_provider_setup.py +++ b/airflow/providers/apache/pig/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/apache/pinot/backport_provider_setup.py b/airflow/providers/apache/pinot/backport_provider_setup.py index 6b0836b949ebf..b626b84c86767 100644 --- a/airflow/providers/apache/pinot/backport_provider_setup.py +++ b/airflow/providers/apache/pinot/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/apache/spark/backport_provider_setup.py b/airflow/providers/apache/spark/backport_provider_setup.py index ccb99e2f61af8..40eb4bb4184e7 100644 --- a/airflow/providers/apache/spark/backport_provider_setup.py +++ b/airflow/providers/apache/spark/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/apache/sqoop/backport_provider_setup.py b/airflow/providers/apache/sqoop/backport_provider_setup.py index ac933f5547714..ff4ebe2c9f7e7 100644 --- a/airflow/providers/apache/sqoop/backport_provider_setup.py +++ b/airflow/providers/apache/sqoop/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/celery/backport_provider_setup.py b/airflow/providers/celery/backport_provider_setup.py index 68fcc3b06aaaf..4b050fc0ceaa8 100644 --- a/airflow/providers/celery/backport_provider_setup.py +++ b/airflow/providers/celery/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/cloudant/backport_provider_setup.py b/airflow/providers/cloudant/backport_provider_setup.py index 643c7cbbe8266..92bfb7e1522b8 100644 --- a/airflow/providers/cloudant/backport_provider_setup.py +++ b/airflow/providers/cloudant/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/cncf/kubernetes/backport_provider_setup.py b/airflow/providers/cncf/kubernetes/backport_provider_setup.py index b9d647bacc93b..fe620b92def92 100644 --- a/airflow/providers/cncf/kubernetes/backport_provider_setup.py +++ b/airflow/providers/cncf/kubernetes/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index 95f7962569863..7c203a8d87c60 100644 --- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import tempfile -from typing import Generator, Optional, Tuple, Union, Any +from typing import Any, Generator, Optional, Tuple, Union import yaml from cached_property import cached_property diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py index 4bf38a720556b..fed4b5bbf7210 100644 --- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py +++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py @@ -16,21 +16,20 @@ # under the License. """Executes task in a Kubernetes POD""" import re -from typing import Dict, Iterable, List, Optional, Tuple, Any +from typing import Any, Dict, Iterable, List, Optional, Tuple import yaml -from kubernetes.client import CoreV1Api -from kubernetes.client import models as k8s +from kubernetes.client import CoreV1Api, models as k8s from airflow.exceptions import AirflowException from airflow.kubernetes import kube_client, pod_generator, pod_launcher +from airflow.kubernetes.pod_generator import PodGenerator from airflow.kubernetes.secret import Secret from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults from airflow.utils.helpers import validate_key from airflow.utils.state import State from airflow.version import version as airflow_version -from airflow.kubernetes.pod_generator import PodGenerator class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-attributes diff --git a/airflow/providers/databricks/backport_provider_setup.py b/airflow/providers/databricks/backport_provider_setup.py index 4b90c4c2a77ff..7dabf94d54b26 100644 --- a/airflow/providers/databricks/backport_provider_setup.py +++ b/airflow/providers/databricks/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 623502373f9c8..ef71c2c9bf8d9 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -26,7 +26,7 @@ from urllib.parse import urlparse import requests -from requests import exceptions as requests_exceptions, PreparedRequest +from requests import PreparedRequest, exceptions as requests_exceptions from requests.auth import AuthBase from airflow import __version__ diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 7602de526adda..fe753bc0921c9 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -19,7 +19,7 @@ """This module contains Databricks operators.""" import time -from typing import Union, Optional, Any, Dict, List +from typing import Any, Dict, List, Optional, Union from airflow.exceptions import AirflowException from airflow.models import BaseOperator diff --git a/airflow/providers/datadog/backport_provider_setup.py b/airflow/providers/datadog/backport_provider_setup.py index 282bb26b2bd57..02e757c7689e4 100644 --- a/airflow/providers/datadog/backport_provider_setup.py +++ b/airflow/providers/datadog/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git 
a/airflow/providers/dingding/backport_provider_setup.py b/airflow/providers/dingding/backport_provider_setup.py index efeec2daf495c..78482b2446d91 100644 --- a/airflow/providers/dingding/backport_provider_setup.py +++ b/airflow/providers/dingding/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/dingding/hooks/dingding.py b/airflow/providers/dingding/hooks/dingding.py index 45163d7be8769..4bd37e41328bb 100644 --- a/airflow/providers/dingding/hooks/dingding.py +++ b/airflow/providers/dingding/hooks/dingding.py @@ -17,7 +17,7 @@ # under the License. import json -from typing import Union, Optional, List +from typing import List, Optional, Union import requests from requests import Session diff --git a/airflow/providers/dingding/operators/dingding.py b/airflow/providers/dingding/operators/dingding.py index 8b5f57a66c585..6be1bd3ad0040 100644 --- a/airflow/providers/dingding/operators/dingding.py +++ b/airflow/providers/dingding/operators/dingding.py @@ -15,10 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Union, Optional, List +from typing import List, Optional, Union from airflow.models import BaseOperator - from airflow.providers.dingding.hooks.dingding import DingdingHook from airflow.utils.decorators import apply_defaults diff --git a/airflow/providers/discord/backport_provider_setup.py b/airflow/providers/discord/backport_provider_setup.py index 1563c75c58053..363d4d090ab9c 100644 --- a/airflow/providers/discord/backport_provider_setup.py +++ b/airflow/providers/discord/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/docker/backport_provider_setup.py b/airflow/providers/docker/backport_provider_setup.py index 4a88da8f3f484..42cedfb68e1c0 100644 --- a/airflow/providers/docker/backport_provider_setup.py +++ b/airflow/providers/docker/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/elasticsearch/backport_provider_setup.py b/airflow/providers/elasticsearch/backport_provider_setup.py index 013efe4985b14..27ba27cf84d74 100644 --- a/airflow/providers/elasticsearch/backport_provider_setup.py +++ b/airflow/providers/elasticsearch/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/exasol/backport_provider_setup.py b/airflow/providers/exasol/backport_provider_setup.py index 4dcef0e04ffd6..1be0b3ffaf81f 100644 --- a/airflow/providers/exasol/backport_provider_setup.py +++ b/airflow/providers/exasol/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py index 1b01222c8f306..935eb8b1b5fd9 100644 --- a/airflow/providers/exasol/hooks/exasol.py +++ 
b/airflow/providers/exasol/hooks/exasol.py @@ -17,7 +17,7 @@ # under the License. from contextlib import closing -from typing import Union, Optional, List, Tuple, Any +from typing import Any, List, Optional, Tuple, Union import pyexasol from pyexasol import ExaConnection diff --git a/airflow/providers/facebook/backport_provider_setup.py b/airflow/providers/facebook/backport_provider_setup.py index b9a1d398333db..0262eb87d71c8 100644 --- a/airflow/providers/facebook/backport_provider_setup.py +++ b/airflow/providers/facebook/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/ftp/backport_provider_setup.py b/airflow/providers/ftp/backport_provider_setup.py index e116e07750a72..237f15cc3a750 100644 --- a/airflow/providers/ftp/backport_provider_setup.py +++ b/airflow/providers/ftp/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/ftp/hooks/ftp.py b/airflow/providers/ftp/hooks/ftp.py index 87ae53a7ca0c7..18711896c2e44 100644 --- a/airflow/providers/ftp/hooks/ftp.py +++ b/airflow/providers/ftp/hooks/ftp.py @@ -20,7 +20,7 @@ import datetime import ftplib import os.path -from typing import List, Optional, Any +from typing import Any, List, Optional from airflow.hooks.base_hook import BaseHook diff --git a/airflow/providers/google/backport_provider_setup.py b/airflow/providers/google/backport_provider_setup.py index d150da5045691..d22578696510f 100644 --- a/airflow/providers/google/backport_provider_setup.py +++ b/airflow/providers/google/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/google/cloud/example_dags/example_azure_fileshare_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_azure_fileshare_to_gcs.py index da7de2b3eec6a..c260cb8616c0d 100644 --- a/airflow/providers/google/cloud/example_dags/example_azure_fileshare_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_azure_fileshare_to_gcs.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import os -from datetime import timedelta, datetime +from datetime import datetime, timedelta from airflow import DAG from airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs import AzureFileShareToGCSOperator diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py b/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py index 70522a55ad433..441c165981cca 100644 --- a/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py +++ b/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py @@ -21,8 +21,8 @@ import os from urllib.parse import urlparse -from google.cloud.redis_v1.gapic.enums import FailoverInstanceRequest, Instance from google.cloud.memcache_v1beta2.types import cloud_memcache +from google.cloud.redis_v1.gapic.enums import FailoverInstanceRequest, Instance from airflow import models from airflow.operators.bash import BashOperator @@ -36,8 +36,6 @@ CloudMemorystoreGetInstanceOperator, CloudMemorystoreImportOperator, CloudMemorystoreListInstancesOperator, - CloudMemorystoreScaleInstanceOperator, - CloudMemorystoreUpdateInstanceOperator, CloudMemorystoreMemcachedApplyParametersOperator, CloudMemorystoreMemcachedCreateInstanceOperator, CloudMemorystoreMemcachedDeleteInstanceOperator, @@ -45,6 +43,8 @@ CloudMemorystoreMemcachedListInstancesOperator, CloudMemorystoreMemcachedUpdateInstanceOperator, CloudMemorystoreMemcachedUpdateParametersOperator, + CloudMemorystoreScaleInstanceOperator, + CloudMemorystoreUpdateInstanceOperator, ) from airflow.providers.google.cloud.operators.gcs import GCSBucketCreateAclEntryOperator from airflow.utils import dates diff --git a/airflow/providers/google/cloud/example_dags/example_mysql_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_mysql_to_gcs.py index dbec5dcc87fa8..cdf97dc5b7da1 100644 --- a/airflow/providers/google/cloud/example_dags/example_mysql_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_mysql_to_gcs.py @@ -16,6 +16,7 @@ # under the License. 
import os + from airflow import models from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator from airflow.utils import dates diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index 96e40560aa439..3ee600809a0c6 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -26,7 +26,7 @@ import time import warnings from copy import deepcopy -from datetime import timedelta, datetime +from datetime import datetime, timedelta from typing import Any, Dict, Iterable, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union from google.api_core.retry import Retry @@ -43,7 +43,7 @@ from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference from google.cloud.bigquery.table import EncryptionConfiguration, Row, Table, TableReference from google.cloud.exceptions import NotFound -from googleapiclient.discovery import build, Resource +from googleapiclient.discovery import Resource, build from pandas import DataFrame from pandas_gbq import read_gbq from pandas_gbq.gbq import ( diff --git a/airflow/providers/google/cloud/hooks/cloud_memorystore.py b/airflow/providers/google/cloud/hooks/cloud_memorystore.py index 693a82f77dc02..bfc01f94285df 100644 --- a/airflow/providers/google/cloud/hooks/cloud_memorystore.py +++ b/airflow/providers/google/cloud/hooks/cloud_memorystore.py @@ -18,8 +18,8 @@ """Hooks for Cloud Memorystore service""" from typing import Dict, Optional, Sequence, Tuple, Union -from google.api_core.exceptions import NotFound from google.api_core import path_template +from google.api_core.exceptions import NotFound from google.api_core.retry import Retry from google.cloud.memcache_v1beta2 import CloudMemcacheClient from google.cloud.memcache_v1beta2.types import cloud_memcache diff --git a/airflow/providers/google/cloud/hooks/cloud_sql.py b/airflow/providers/google/cloud/hooks/cloud_sql.py index 0443be78f3905..cc6d5725562ed 100644 --- a/airflow/providers/google/cloud/hooks/cloud_sql.py +++ b/airflow/providers/google/cloud/hooks/cloud_sql.py @@ -37,7 +37,7 @@ from urllib.parse import quote_plus import requests -from googleapiclient.discovery import build, Resource +from googleapiclient.discovery import Resource, build from googleapiclient.errors import HttpError from sqlalchemy.orm import Session diff --git a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py index b7aae109e1f03..5600a29a8696a 100644 --- a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +++ b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py @@ -25,7 +25,7 @@ from datetime import timedelta from typing import List, Optional, Sequence, Set, Union -from googleapiclient.discovery import build, Resource +from googleapiclient.discovery import Resource, build from googleapiclient.errors import HttpError from airflow.exceptions import AirflowException diff --git a/airflow/providers/google/cloud/hooks/datastore.py b/airflow/providers/google/cloud/hooks/datastore.py index c8ca3de1cb7d9..d36c295a3919d 100644 --- a/airflow/providers/google/cloud/hooks/datastore.py +++ b/airflow/providers/google/cloud/hooks/datastore.py @@ -23,7 +23,7 @@ import warnings from typing import Any, Dict, Optional, Sequence, Union -from googleapiclient.discovery import build, Resource +from googleapiclient.discovery import Resource, 
build from airflow.providers.google.common.hooks.base_google import GoogleBaseHook diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py index 6d2f3480fcec8..f59ff8ed9fa67 100644 --- a/airflow/providers/google/cloud/hooks/gcs.py +++ b/airflow/providers/google/cloud/hooks/gcs.py @@ -27,7 +27,7 @@ from io import BytesIO from os import path from tempfile import NamedTemporaryFile -from typing import Callable, Optional, Sequence, Set, Tuple, TypeVar, Union, cast, List +from typing import Callable, List, Optional, Sequence, Set, Tuple, TypeVar, Union, cast from urllib.parse import urlparse from google.api_core.exceptions import NotFound diff --git a/airflow/providers/google/cloud/hooks/gdm.py b/airflow/providers/google/cloud/hooks/gdm.py index e726d2f366711..f11a3903f3005 100644 --- a/airflow/providers/google/cloud/hooks/gdm.py +++ b/airflow/providers/google/cloud/hooks/gdm.py @@ -19,7 +19,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union -from googleapiclient.discovery import build, Resource +from googleapiclient.discovery import Resource, build from airflow.exceptions import AirflowException from airflow.providers.google.common.hooks.base_google import GoogleBaseHook diff --git a/airflow/providers/google/cloud/hooks/mlengine.py b/airflow/providers/google/cloud/hooks/mlengine.py index b17cc46fef257..e0182ce91c70a 100644 --- a/airflow/providers/google/cloud/hooks/mlengine.py +++ b/airflow/providers/google/cloud/hooks/mlengine.py @@ -21,7 +21,7 @@ import time from typing import Callable, Dict, List, Optional -from googleapiclient.discovery import build, Resource +from googleapiclient.discovery import Resource, build from googleapiclient.errors import HttpError from airflow.providers.google.common.hooks.base_google import GoogleBaseHook diff --git a/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py index 37b616574fe28..127e24e0155e3 100644 --- a/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +++ b/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py @@ -19,7 +19,7 @@ """This module contains Google Cloud Transfer operators.""" from copy import deepcopy from datetime import date, time -from typing import Dict, Optional, Sequence, Union, List +from typing import Dict, List, Optional, Sequence, Union from airflow.exceptions import AirflowException from airflow.models import BaseOperator diff --git a/airflow/providers/google/cloud/operators/dlp.py b/airflow/providers/google/cloud/operators/dlp.py index 5d28d6639be8b..099feb2f1aed0 100644 --- a/airflow/providers/google/cloud/operators/dlp.py +++ b/airflow/providers/google/cloud/operators/dlp.py @@ -24,7 +24,7 @@ """ from typing import Dict, Optional, Sequence, Tuple, Union -from google.api_core.exceptions import AlreadyExists, NotFound, InvalidArgument +from google.api_core.exceptions import AlreadyExists, InvalidArgument, NotFound from google.api_core.retry import Retry from google.cloud.dlp_v2.types import ( ByteContentItem, diff --git a/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py index dec6de1633159..02c2472fa4468 100644 --- a/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +++ b/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py @@ -19,10 +19,10 @@ from typing import Optional, Sequence, 
Set, Union from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( - CloudDataTransferServiceHook, COUNTERS, METADATA, NAME, + CloudDataTransferServiceHook, ) from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/providers/google/cloud/sensors/dataproc.py b/airflow/providers/google/cloud/sensors/dataproc.py index fb8a990674445..9606fca6b4859 100644 --- a/airflow/providers/google/cloud/sensors/dataproc.py +++ b/airflow/providers/google/cloud/sensors/dataproc.py @@ -20,10 +20,10 @@ from google.cloud.dataproc_v1beta2.types import JobStatus +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.dataproc import DataprocHook from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults -from airflow.exceptions import AirflowException class DataprocJobSensor(BaseSensorOperator): diff --git a/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py b/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py index d424838897229..913b9485e8f90 100644 --- a/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py @@ -17,11 +17,11 @@ # under the License. from tempfile import NamedTemporaryFile -from typing import Optional, Union, Sequence, Iterable +from typing import Iterable, Optional, Sequence, Union from airflow import AirflowException from airflow.models import BaseOperator -from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url, GCSHook, gcs_object_is_directory +from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url, gcs_object_is_directory from airflow.providers.microsoft.azure.hooks.azure_fileshare import AzureFileShareHook from airflow.utils.decorators import apply_defaults diff --git a/airflow/providers/google/cloud/transfers/presto_to_gcs.py b/airflow/providers/google/cloud/transfers/presto_to_gcs.py index 7543f83d4fc94..70a4062b0def0 100644 --- a/airflow/providers/google/cloud/transfers/presto_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/presto_to_gcs.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, List, Tuple, Dict +from typing import Any, Dict, List, Tuple from prestodb.client import PrestoResult from prestodb.dbapi import Cursor as PrestoCursor diff --git a/airflow/providers/google/cloud/utils/credentials_provider.py b/airflow/providers/google/cloud/utils/credentials_provider.py index c06c5de833958..e39b0a814a7d0 100644 --- a/airflow/providers/google/cloud/utils/credentials_provider.py +++ b/airflow/providers/google/cloud/utils/credentials_provider.py @@ -23,7 +23,7 @@ import logging import tempfile from contextlib import ExitStack, contextmanager -from typing import Collection, Dict, Optional, Sequence, Tuple, Union, Generator +from typing import Collection, Dict, Generator, Optional, Sequence, Tuple, Union from urllib.parse import urlencode import google.auth diff --git a/airflow/providers/google/marketing_platform/operators/analytics.py b/airflow/providers/google/marketing_platform/operators/analytics.py index 803984db74f25..2a1f6a14c6f23 100644 --- a/airflow/providers/google/marketing_platform/operators/analytics.py +++ b/airflow/providers/google/marketing_platform/operators/analytics.py @@ -18,7 +18,7 @@ """This module contains Google Analytics 360 operators.""" import csv from tempfile import NamedTemporaryFile -from typing import Dict, Optional, Sequence, Union, Any, List +from typing import Any, Dict, List, Optional, Sequence, Union from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook diff --git a/airflow/providers/grpc/backport_provider_setup.py b/airflow/providers/grpc/backport_provider_setup.py index 9094ac31cf2de..18974469c319a 100644 --- a/airflow/providers/grpc/backport_provider_setup.py +++ b/airflow/providers/grpc/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/hashicorp/backport_provider_setup.py b/airflow/providers/hashicorp/backport_provider_setup.py index ecb051985fb26..ea75f9fc70e68 100644 --- a/airflow/providers/hashicorp/backport_provider_setup.py +++ b/airflow/providers/hashicorp/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/http/backport_provider_setup.py b/airflow/providers/http/backport_provider_setup.py index ebae13e66b903..010bf50d37883 100644 --- a/airflow/providers/http/backport_provider_setup.py +++ b/airflow/providers/http/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/imap/backport_provider_setup.py b/airflow/providers/imap/backport_provider_setup.py index b515f9ee49e17..86c6a3ad5ed8d 100644 --- a/airflow/providers/imap/backport_provider_setup.py +++ b/airflow/providers/imap/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/jdbc/backport_provider_setup.py b/airflow/providers/jdbc/backport_provider_setup.py index 11d751b5b7b59..798c8a9145952 100644 --- a/airflow/providers/jdbc/backport_provider_setup.py +++ b/airflow/providers/jdbc/backport_provider_setup.py @@ -29,8 +29,8 @@ 
import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/jenkins/backport_provider_setup.py b/airflow/providers/jenkins/backport_provider_setup.py index f22d0df9fe692..d35948529b3e6 100644 --- a/airflow/providers/jenkins/backport_provider_setup.py +++ b/airflow/providers/jenkins/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/jira/backport_provider_setup.py b/airflow/providers/jira/backport_provider_setup.py index a12a10efd172b..525a6272b585d 100644 --- a/airflow/providers/jira/backport_provider_setup.py +++ b/airflow/providers/jira/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/microsoft/azure/backport_provider_setup.py b/airflow/providers/microsoft/azure/backport_provider_setup.py index 93e31c4348fe9..6580d1106aee2 100644 --- a/airflow/providers/microsoft/azure/backport_provider_setup.py +++ b/airflow/providers/microsoft/azure/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/microsoft/azure/example_dags/example_azure_blob_to_gcs.py b/airflow/providers/microsoft/azure/example_dags/example_azure_blob_to_gcs.py index 7183972545c83..693286fe00858 100644 --- a/airflow/providers/microsoft/azure/example_dags/example_azure_blob_to_gcs.py +++ b/airflow/providers/microsoft/azure/example_dags/example_azure_blob_to_gcs.py @@ -19,9 +19,7 @@ from airflow import DAG from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor -from airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs import ( - AzureBlobStorageToGCSOperator, -) +from airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs import AzureBlobStorageToGCSOperator from airflow.utils.dates import days_ago BLOB_NAME = os.environ.get("AZURE_BLOB_NAME", "file.txt") diff --git a/airflow/providers/microsoft/azure/example_dags/example_local_to_adls.py b/airflow/providers/microsoft/azure/example_dags/example_local_to_adls.py index 78383195082b5..294b5f71b5538 100644 --- a/airflow/providers/microsoft/azure/example_dags/example_local_to_adls.py +++ b/airflow/providers/microsoft/azure/example_dags/example_local_to_adls.py @@ -16,6 +16,7 @@ # under the License. 
import os + from airflow import models from airflow.providers.microsoft.azure.transfers.local_to_adls import LocalToAzureDataLakeStorageOperator from airflow.utils.dates import days_ago diff --git a/airflow/providers/microsoft/azure/hooks/azure_batch.py b/airflow/providers/microsoft/azure/hooks/azure_batch.py index 864233c8315cf..ab4e6e90ffaba 100644 --- a/airflow/providers/microsoft/azure/hooks/azure_batch.py +++ b/airflow/providers/microsoft/azure/hooks/azure_batch.py @@ -21,7 +21,7 @@ from typing import Optional, Set from azure.batch import BatchServiceClient, batch_auth, models as batch_models -from azure.batch.models import PoolAddParameter, JobAddParameter, TaskAddParameter +from azure.batch.models import JobAddParameter, PoolAddParameter, TaskAddParameter from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook diff --git a/airflow/providers/microsoft/azure/hooks/azure_fileshare.py b/airflow/providers/microsoft/azure/hooks/azure_fileshare.py index 81d683960ad66..95b462c7dc25b 100644 --- a/airflow/providers/microsoft/azure/hooks/azure_fileshare.py +++ b/airflow/providers/microsoft/azure/hooks/azure_fileshare.py @@ -16,9 +16,9 @@ # specific language governing permissions and limitations # under the License. # -from typing import Optional, List +from typing import List, Optional -from azure.storage.file import FileService, File +from azure.storage.file import File, FileService from airflow.hooks.base_hook import BaseHook diff --git a/airflow/providers/microsoft/azure/log/wasb_task_handler.py b/airflow/providers/microsoft/azure/log/wasb_task_handler.py index a8eb6db69a1d3..3f12135aff62e 100644 --- a/airflow/providers/microsoft/azure/log/wasb_task_handler.py +++ b/airflow/providers/microsoft/azure/log/wasb_task_handler.py @@ -17,7 +17,7 @@ # under the License. import os import shutil -from typing import Optional, Tuple, Dict +from typing import Dict, Optional, Tuple from azure.common import AzureHttpError from cached_property import cached_property diff --git a/airflow/providers/microsoft/azure/operators/azure_container_instances.py b/airflow/providers/microsoft/azure/operators/azure_container_instances.py index 3bf30d9e3d748..75c0ba532e710 100644 --- a/airflow/providers/microsoft/azure/operators/azure_container_instances.py +++ b/airflow/providers/microsoft/azure/operators/azure_container_instances.py @@ -19,17 +19,17 @@ import re from collections import namedtuple from time import sleep -from typing import Any, List, Optional, Sequence, Union, Dict +from typing import Any, Dict, List, Optional, Sequence, Union from azure.mgmt.containerinstance.models import ( Container, ContainerGroup, + ContainerPort, EnvironmentVariable, + IpAddress, ResourceRequests, ResourceRequirements, VolumeMount, - IpAddress, - ContainerPort, ) from msrestazure.azure_exceptions import CloudError diff --git a/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py b/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py index a33a922138699..ccc7577a38f89 100644 --- a/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py +++ b/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py @@ -17,8 +17,8 @@ # under the License. 
# import tempfile +from typing import Optional, Sequence, Union -from typing import Optional, Union, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.microsoft.azure.hooks.wasb import WasbHook diff --git a/airflow/providers/microsoft/azure/transfers/local_to_adls.py b/airflow/providers/microsoft/azure/transfers/local_to_adls.py index 755a171c3606a..bf6947b653de7 100644 --- a/airflow/providers/microsoft/azure/transfers/local_to_adls.py +++ b/airflow/providers/microsoft/azure/transfers/local_to_adls.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict, Any, Optional +from typing import Any, Dict, Optional + from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook diff --git a/airflow/providers/microsoft/mssql/backport_provider_setup.py b/airflow/providers/microsoft/mssql/backport_provider_setup.py index 6868ab659fe7d..2c314e03a2f08 100644 --- a/airflow/providers/microsoft/mssql/backport_provider_setup.py +++ b/airflow/providers/microsoft/mssql/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/microsoft/winrm/backport_provider_setup.py b/airflow/providers/microsoft/winrm/backport_provider_setup.py index 876d4fc4a99e1..b79a1b5f3d7c3 100644 --- a/airflow/providers/microsoft/winrm/backport_provider_setup.py +++ b/airflow/providers/microsoft/winrm/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/mongo/backport_provider_setup.py b/airflow/providers/mongo/backport_provider_setup.py index ffc5500562aef..da729e4f05bee 100644 --- a/airflow/providers/mongo/backport_provider_setup.py +++ b/airflow/providers/mongo/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/mysql/backport_provider_setup.py b/airflow/providers/mysql/backport_provider_setup.py index a43120e7b3c83..af8383b1fa41e 100644 --- a/airflow/providers/mysql/backport_provider_setup.py +++ b/airflow/providers/mysql/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/odbc/backport_provider_setup.py b/airflow/providers/odbc/backport_provider_setup.py index f27329190f770..0e5c5cf361232 100644 --- a/airflow/providers/odbc/backport_provider_setup.py +++ b/airflow/providers/odbc/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/odbc/hooks/odbc.py b/airflow/providers/odbc/hooks/odbc.py index 3eabdf2e63f66..fc12549422754 100644 --- a/airflow/providers/odbc/hooks/odbc.py +++ b/airflow/providers/odbc/hooks/odbc.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains ODBC hook.""" -from typing import Optional, Any +from typing import Any, Optional from urllib.parse import quote_plus import pyodbc diff --git a/airflow/providers/openfaas/backport_provider_setup.py b/airflow/providers/openfaas/backport_provider_setup.py index ca748ad95305d..63400ec278376 100644 --- a/airflow/providers/openfaas/backport_provider_setup.py +++ b/airflow/providers/openfaas/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/opsgenie/backport_provider_setup.py b/airflow/providers/opsgenie/backport_provider_setup.py index e5d21ae961099..9ffb68eaf2c03 100644 --- a/airflow/providers/opsgenie/backport_provider_setup.py +++ b/airflow/providers/opsgenie/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/opsgenie/hooks/opsgenie_alert.py b/airflow/providers/opsgenie/hooks/opsgenie_alert.py index 9e5aa6b6b8e08..a65b9a3ea615b 100644 --- a/airflow/providers/opsgenie/hooks/opsgenie_alert.py +++ b/airflow/providers/opsgenie/hooks/opsgenie_alert.py @@ -18,7 +18,7 @@ # import json -from typing import Optional, Any +from typing import Any, Optional import requests diff --git a/airflow/providers/opsgenie/operators/opsgenie_alert.py b/airflow/providers/opsgenie/operators/opsgenie_alert.py index faea2d720e3b4..f4511dd1920bf 100644 --- a/airflow/providers/opsgenie/operators/opsgenie_alert.py +++ b/airflow/providers/opsgenie/operators/opsgenie_alert.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. # -from typing import Optional, List, Dict, Any +from typing import Any, Dict, List, Optional from airflow.models import BaseOperator from airflow.providers.opsgenie.hooks.opsgenie_alert import OpsgenieAlertHook diff --git a/airflow/providers/oracle/backport_provider_setup.py b/airflow/providers/oracle/backport_provider_setup.py index 3302022d81a53..a2fe377e4fbff 100644 --- a/airflow/providers/oracle/backport_provider_setup.py +++ b/airflow/providers/oracle/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/oracle/hooks/oracle.py b/airflow/providers/oracle/hooks/oracle.py index 020267b914f7f..efb3fcc547f6f 100644 --- a/airflow/providers/oracle/hooks/oracle.py +++ b/airflow/providers/oracle/hooks/oracle.py @@ -17,7 +17,7 @@ # under the License. 
from datetime import datetime -from typing import Optional, List +from typing import List, Optional import cx_Oracle import numpy diff --git a/airflow/providers/pagerduty/backport_provider_setup.py b/airflow/providers/pagerduty/backport_provider_setup.py index e28bd21907333..6b8d1ceec5d73 100644 --- a/airflow/providers/pagerduty/backport_provider_setup.py +++ b/airflow/providers/pagerduty/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/plexus/backport_provider_setup.py b/airflow/providers/plexus/backport_provider_setup.py index a7ffb31e3bbb1..f92d8dd8a9023 100644 --- a/airflow/providers/plexus/backport_provider_setup.py +++ b/airflow/providers/plexus/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/plexus/operators/job.py b/airflow/providers/plexus/operators/job.py index a1aec95cf37b5..ece5df6e5f318 100644 --- a/airflow/providers/plexus/operators/job.py +++ b/airflow/providers/plexus/operators/job.py @@ -17,7 +17,7 @@ import logging import time -from typing import Dict, Any, Optional +from typing import Any, Dict, Optional import requests diff --git a/airflow/providers/postgres/backport_provider_setup.py b/airflow/providers/postgres/backport_provider_setup.py index 8a1bffd0bde7d..f3c5d6ed1ed83 100644 --- a/airflow/providers/postgres/backport_provider_setup.py +++ b/airflow/providers/postgres/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index 825a0d027036e..b207edbf651fe 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -24,7 +24,7 @@ import psycopg2.extensions import psycopg2.extras from psycopg2.extensions import connection -from psycopg2.extras import DictCursor, RealDictCursor, NamedTupleCursor +from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor from airflow.hooks.dbapi_hook import DbApiHook from airflow.models.connection import Connection diff --git a/airflow/providers/presto/backport_provider_setup.py b/airflow/providers/presto/backport_provider_setup.py index b24fff549507c..cf9f269a838f4 100644 --- a/airflow/providers/presto/backport_provider_setup.py +++ b/airflow/providers/presto/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py index 26575d3252a43..a5a0576e1b670 100644 --- a/airflow/providers/presto/hooks/presto.py +++ b/airflow/providers/presto/hooks/presto.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Optional, Any, Iterable import os +from typing import Any, Iterable, Optional import prestodb from prestodb.exceptions import DatabaseError diff --git a/airflow/providers/qubole/backport_provider_setup.py b/airflow/providers/qubole/backport_provider_setup.py index 4d41bf8a3be95..8387e4a5d0556 100644 --- a/airflow/providers/qubole/backport_provider_setup.py +++ b/airflow/providers/qubole/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/qubole/hooks/qubole.py b/airflow/providers/qubole/hooks/qubole.py index eb3a374e91a77..56c0a5220a55e 100644 --- a/airflow/providers/qubole/hooks/qubole.py +++ b/airflow/providers/qubole/hooks/qubole.py @@ -22,7 +22,7 @@ import os import pathlib import time -from typing import List, Dict, Tuple +from typing import Dict, List, Tuple from qds_sdk.commands import ( Command, @@ -31,12 +31,12 @@ DbTapQueryCommand, HadoopCommand, HiveCommand, + JupyterNotebookCommand, PigCommand, PrestoCommand, ShellCommand, SparkCommand, SqlCommand, - JupyterNotebookCommand, ) from qds_sdk.qubole import Qubole diff --git a/airflow/providers/qubole/hooks/qubole_check.py b/airflow/providers/qubole/hooks/qubole_check.py index 1c6bdf43e2eac..987ea59ae2021 100644 --- a/airflow/providers/qubole/hooks/qubole_check.py +++ b/airflow/providers/qubole/hooks/qubole_check.py @@ -18,7 +18,7 @@ # import logging from io import StringIO -from typing import List, Union, Optional +from typing import List, Optional, Union from qds_sdk.commands import Command diff --git a/airflow/providers/qubole/operators/qubole_check.py b/airflow/providers/qubole/operators/qubole_check.py index fc1561c28eeb0..41a4bc7e138b4 100644 --- a/airflow/providers/qubole/operators/qubole_check.py +++ b/airflow/providers/qubole/operators/qubole_check.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. 
# -from typing import Iterable, Union, Optional +from typing import Iterable, Optional, Union from airflow.exceptions import AirflowException from airflow.operators.check_operator import CheckOperator, ValueCheckOperator diff --git a/airflow/providers/redis/backport_provider_setup.py b/airflow/providers/redis/backport_provider_setup.py index ce6f307c2ea50..bba8751ae7b12 100644 --- a/airflow/providers/redis/backport_provider_setup.py +++ b/airflow/providers/redis/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/salesforce/backport_provider_setup.py b/airflow/providers/salesforce/backport_provider_setup.py index 076714068eb1d..2743d4f88efe0 100644 --- a/airflow/providers/salesforce/backport_provider_setup.py +++ b/airflow/providers/salesforce/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/salesforce/hooks/salesforce.py b/airflow/providers/salesforce/hooks/salesforce.py index 1fea62139a48b..affe286a69da3 100644 --- a/airflow/providers/salesforce/hooks/salesforce.py +++ b/airflow/providers/salesforce/hooks/salesforce.py @@ -25,7 +25,7 @@ """ import logging import time -from typing import Optional, List, Iterable +from typing import Iterable, List, Optional import pandas as pd from simple_salesforce import Salesforce, api diff --git a/airflow/providers/salesforce/hooks/tableau.py b/airflow/providers/salesforce/hooks/tableau.py index 80c67640963c3..bd47d10b3e67d 100644 --- a/airflow/providers/salesforce/hooks/tableau.py +++ b/airflow/providers/salesforce/hooks/tableau.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
from enum import Enum -from typing import Optional, Any +from typing import Any, Optional from tableauserverclient import Pager, PersonalAccessTokenAuth, Server, TableauAuth from tableauserverclient.server import Auth diff --git a/airflow/providers/samba/backport_provider_setup.py b/airflow/providers/samba/backport_provider_setup.py index 6829dacf9c238..a426466494544 100644 --- a/airflow/providers/samba/backport_provider_setup.py +++ b/airflow/providers/samba/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/segment/backport_provider_setup.py b/airflow/providers/segment/backport_provider_setup.py index 3813732bca634..54a77898a277a 100644 --- a/airflow/providers/segment/backport_provider_setup.py +++ b/airflow/providers/segment/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/sftp/backport_provider_setup.py b/airflow/providers/sftp/backport_provider_setup.py index 3359f0fb798ac..a0664098df6dd 100644 --- a/airflow/providers/sftp/backport_provider_setup.py +++ b/airflow/providers/sftp/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/singularity/backport_provider_setup.py b/airflow/providers/singularity/backport_provider_setup.py index 26162133dd530..e9f51448cc4b1 100644 --- a/airflow/providers/singularity/backport_provider_setup.py +++ b/airflow/providers/singularity/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/slack/backport_provider_setup.py b/airflow/providers/slack/backport_provider_setup.py index e51f2073d02ef..98475e28ce5e3 100644 --- a/airflow/providers/slack/backport_provider_setup.py +++ b/airflow/providers/slack/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/slack/operators/slack.py b/airflow/providers/slack/operators/slack.py index 622bab8110638..0263466c40cb6 100644 --- a/airflow/providers/slack/operators/slack.py +++ b/airflow/providers/slack/operators/slack.py @@ -17,7 +17,7 @@ # under the License. import json -from typing import Dict, List, Optional, Any +from typing import Any, Dict, List, Optional from airflow.models import BaseOperator from airflow.providers.slack.hooks.slack import SlackHook diff --git a/airflow/providers/slack/operators/slack_webhook.py b/airflow/providers/slack/operators/slack_webhook.py index 7899aa931bfac..1a7725a81c6ae 100644 --- a/airflow/providers/slack/operators/slack_webhook.py +++ b/airflow/providers/slack/operators/slack_webhook.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. 
# -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional from airflow.providers.http.operators.http import SimpleHttpOperator from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook diff --git a/airflow/providers/snowflake/backport_provider_setup.py b/airflow/providers/snowflake/backport_provider_setup.py index 6e67634cdacde..47ccac1f39e27 100644 --- a/airflow/providers/snowflake/backport_provider_setup.py +++ b/airflow/providers/snowflake/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index 00a1e2bcb35d6..24fef1132b5cc 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict, Optional, Tuple, Any +from typing import Any, Dict, Optional, Tuple from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization diff --git a/airflow/providers/sqlite/backport_provider_setup.py b/airflow/providers/sqlite/backport_provider_setup.py index 6c0fcd219b48a..da49f3006dd7c 100644 --- a/airflow/providers/sqlite/backport_provider_setup.py +++ b/airflow/providers/sqlite/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/ssh/backport_provider_setup.py b/airflow/providers/ssh/backport_provider_setup.py index 94137e01b1422..c073e129ada98 100644 --- a/airflow/providers/ssh/backport_provider_setup.py +++ b/airflow/providers/ssh/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/ssh/hooks/ssh.py b/airflow/providers/ssh/hooks/ssh.py index d453c468c72a5..e28e779f54254 100644 --- a/airflow/providers/ssh/hooks/ssh.py +++ b/airflow/providers/ssh/hooks/ssh.py @@ -20,7 +20,7 @@ import os import warnings from io import StringIO -from typing import Optional, Union, Tuple +from typing import Optional, Tuple, Union import paramiko from paramiko.config import SSH_PORT diff --git a/airflow/providers/vertica/backport_provider_setup.py b/airflow/providers/vertica/backport_provider_setup.py index 313311ee44383..620ac43f7ffd8 100644 --- a/airflow/providers/vertica/backport_provider_setup.py +++ b/airflow/providers/vertica/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/yandex/backport_provider_setup.py b/airflow/providers/yandex/backport_provider_setup.py index 6125d675de79d..2ce94832e5dd1 100644 --- a/airflow/providers/yandex/backport_provider_setup.py +++ b/airflow/providers/yandex/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/providers/yandex/hooks/yandex.py 
b/airflow/providers/yandex/hooks/yandex.py index 5bb71a706879f..8121f29294f7b 100644 --- a/airflow/providers/yandex/hooks/yandex.py +++ b/airflow/providers/yandex/hooks/yandex.py @@ -16,7 +16,7 @@ # under the License. import json -from typing import Optional, Dict, Any, Union +from typing import Any, Dict, Optional, Union import yandexcloud diff --git a/airflow/providers/zendesk/backport_provider_setup.py b/airflow/providers/zendesk/backport_provider_setup.py index 2e6a4336aae9e..42843864a7b4b 100644 --- a/airflow/providers/zendesk/backport_provider_setup.py +++ b/airflow/providers/zendesk/backport_provider_setup.py @@ -29,8 +29,8 @@ import logging import os import sys - from os.path import dirname + from setuptools import find_packages, setup logger = logging.getLogger(__name__) diff --git a/airflow/secrets/base_secrets.py b/airflow/secrets/base_secrets.py index 3d2f2576194ab..63bc94ae01942 100644 --- a/airflow/secrets/base_secrets.py +++ b/airflow/secrets/base_secrets.py @@ -59,6 +59,7 @@ def get_connections(self, conn_id: str) -> List['Connection']: :type conn_id: str """ from airflow.models.connection import Connection + conn_uri = self.get_conn_uri(conn_id=conn_id) if not conn_uri: return [] @@ -74,7 +75,7 @@ def get_variable(self, key: str) -> Optional[str]: """ raise NotImplementedError() - def get_config(self, key: str) -> Optional[str]: # pylint: disable=unused-argument + def get_config(self, key: str) -> Optional[str]: # pylint: disable=unused-argument """ Return value for Airflow Config Key diff --git a/airflow/secrets/local_filesystem.py b/airflow/secrets/local_filesystem.py index 2b249cc7968bd..a7d9e4f2d0a24 100644 --- a/airflow/secrets/local_filesystem.py +++ b/airflow/secrets/local_filesystem.py @@ -28,7 +28,10 @@ import yaml from airflow.exceptions import ( - AirflowException, AirflowFileParseException, ConnectionNotUnique, FileSyntaxError, + AirflowException, + AirflowFileParseException, + ConnectionNotUnique, + FileSyntaxError, ) from airflow.secrets.base_secrets import BaseSecretsBackend from airflow.utils.file import COMMENT_PATTERN @@ -82,7 +85,12 @@ def _parse_env_file(file_path: str) -> Tuple[Dict[str, List[str]], List[FileSynt key, value = var_parts if not key: - errors.append(FileSyntaxError(line_no=line_no, message="Invalid line format. Key is empty.",)) + errors.append( + FileSyntaxError( + line_no=line_no, + message="Invalid line format. Key is empty.", + ) + ) secrets[key].append(value) return secrets, errors @@ -236,7 +244,8 @@ def load_connections(file_path) -> Dict[str, List[Any]]: """This function is deprecated. Please use `airflow.secrets.local_filesystem.load_connections_dict`.",""" warnings.warn( "This function is deprecated. 
Please use `airflow.secrets.local_filesystem.load_connections_dict`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) return {k: [v] for k, v in load_connections_dict(file_path).values()} diff --git a/airflow/secrets/metastore.py b/airflow/secrets/metastore.py index 497bc4476d66b..e896f082ee2cf 100644 --- a/airflow/secrets/metastore.py +++ b/airflow/secrets/metastore.py @@ -33,6 +33,7 @@ class MetastoreBackend(BaseSecretsBackend): @provide_session def get_connections(self, conn_id, session=None) -> List['Connection']: from airflow.models.connection import Connection + conn_list = session.query(Connection).filter(Connection.conn_id == conn_id).all() session.expunge_all() return conn_list @@ -46,6 +47,7 @@ def get_variable(self, key: str, session=None): :return: Variable Value """ from airflow.models.variable import Variable + var_value = session.query(Variable).filter(Variable.key == key).first() session.expunge_all() if var_value: diff --git a/airflow/security/kerberos.py b/airflow/security/kerberos.py index 37c1d9ff7f745..87c269dbe91e6 100644 --- a/airflow/security/kerberos.py +++ b/airflow/security/kerberos.py @@ -58,33 +58,36 @@ def renew_from_kt(principal: str, keytab: str, exit_on_fail: bool = True): # minutes to give ourselves a large renewal buffer. renewal_lifetime = "%sm" % conf.getint('kerberos', 'reinit_frequency') - cmd_principal = principal or conf.get('kerberos', 'principal').replace( - "_HOST", socket.getfqdn() - ) + cmd_principal = principal or conf.get('kerberos', 'principal').replace("_HOST", socket.getfqdn()) cmdv = [ conf.get('kerberos', 'kinit_path'), - "-r", renewal_lifetime, + "-r", + renewal_lifetime, "-k", # host ticket - "-t", keytab, # specify keytab - "-c", conf.get('kerberos', 'ccache'), # specify credentials cache - cmd_principal + "-t", + keytab, # specify keytab + "-c", + conf.get('kerberos', 'ccache'), # specify credentials cache + cmd_principal, ] log.info("Re-initialising kerberos from keytab: %s", " ".join(cmdv)) - subp = subprocess.Popen(cmdv, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - close_fds=True, - bufsize=-1, - universal_newlines=True) + subp = subprocess.Popen( + cmdv, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + close_fds=True, + bufsize=-1, + universal_newlines=True, + ) subp.wait() if subp.returncode != 0: log.error( "Couldn't reinit from keytab! `kinit' exited with %s.\n%s\n%s", subp.returncode, "\n".join(subp.stdout.readlines() if subp.stdout else []), - "\n".join(subp.stderr.readlines() if subp.stderr else []) + "\n".join(subp.stderr.readlines() if subp.stderr else []), ) if exit_on_fail: sys.exit(subp.returncode) @@ -113,13 +116,14 @@ def perform_krb181_workaround(principal: str): :param principal: principal name :return: None """ - cmdv = [conf.get('kerberos', 'kinit_path'), - "-c", conf.get('kerberos', 'ccache'), - "-R"] # Renew ticket_cache + cmdv = [ + conf.get('kerberos', 'kinit_path'), + "-c", + conf.get('kerberos', 'ccache'), + "-R", + ] # Renew ticket_cache - log.info( - "Renewing kerberos ticket to work around kerberos 1.8.1: %s", " ".join(cmdv) - ) + log.info("Renewing kerberos ticket to work around kerberos 1.8.1: %s", " ".join(cmdv)) ret = subprocess.call(cmdv, close_fds=True) @@ -132,7 +136,10 @@ def perform_krb181_workaround(principal: str): "the ticket for '%s' is still renewable:\n $ kinit -f -c %s\nIf the 'renew until' date is the " "same as the 'valid starting' date, the ticket cannot be renewed. 
Please check your KDC " "configuration, and the ticket renewal policy (maxrenewlife) for the '%s' and `krbtgt' " - "principals.", princ, ccache, princ + "principals.", + princ, + ccache, + princ, ) return ret diff --git a/airflow/sensors/base_sensor_operator.py b/airflow/sensors/base_sensor_operator.py index d98fb97ea9b9a..cb639bc3176fa 100644 --- a/airflow/sensors/base_sensor_operator.py +++ b/airflow/sensors/base_sensor_operator.py @@ -25,7 +25,10 @@ from airflow.configuration import conf from airflow.exceptions import ( - AirflowException, AirflowRescheduleException, AirflowSensorTimeout, AirflowSkipException, + AirflowException, + AirflowRescheduleException, + AirflowSensorTimeout, + AirflowSkipException, ) from airflow.models import BaseOperator, SensorInstance from airflow.models.skipmixin import SkipMixin @@ -76,17 +79,27 @@ class BaseSensorOperator(BaseOperator, SkipMixin): # setup. Smart sensor serialize these attributes into a different DB column so # that smart sensor service is able to handle corresponding execution details # without breaking the sensor poking logic with dedup. - execution_fields = ('poke_interval', 'retries', 'execution_timeout', 'timeout', - 'email', 'email_on_retry', 'email_on_failure',) + execution_fields = ( + 'poke_interval', + 'retries', + 'execution_timeout', + 'timeout', + 'email', + 'email_on_retry', + 'email_on_failure', + ) @apply_defaults - def __init__(self, *, - poke_interval: float = 60, - timeout: float = 60 * 60 * 24 * 7, - soft_fail: bool = False, - mode: str = 'poke', - exponential_backoff: bool = False, - **kwargs) -> None: + def __init__( + self, + *, + poke_interval: float = 60, + timeout: float = 60 * 60 * 24 * 7, + soft_fail: bool = False, + mode: str = 'poke', + exponential_backoff: bool = False, + **kwargs, + ) -> None: super().__init__(**kwargs) self.poke_interval = poke_interval self.soft_fail = soft_fail @@ -96,22 +109,24 @@ def __init__(self, *, self._validate_input_values() self.sensor_service_enabled = conf.getboolean('smart_sensor', 'use_smart_sensor') self.sensors_support_sensor_service = set( - map(lambda l: l.strip(), conf.get('smart_sensor', 'sensors_enabled').split(','))) + map(lambda l: l.strip(), conf.get('smart_sensor', 'sensors_enabled').split(',')) + ) def _validate_input_values(self) -> None: if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0: - raise AirflowException( - "The poke_interval must be a non-negative number") + raise AirflowException("The poke_interval must be a non-negative number") if not isinstance(self.timeout, (int, float)) or self.timeout < 0: - raise AirflowException( - "The timeout must be a non-negative number") + raise AirflowException("The timeout must be a non-negative number") if self.mode not in self.valid_modes: raise AirflowException( "The mode must be one of {valid_modes}," - "'{d}.{t}'; received '{m}'." 
- .format(valid_modes=self.valid_modes, - d=self.dag.dag_id if self.dag else "", - t=self.task_id, m=self.mode)) + "'{d}.{t}'; received '{m}'.".format( + valid_modes=self.valid_modes, + d=self.dag.dag_id if self.dag else "", + t=self.task_id, + m=self.mode, + ) + ) def poke(self, context: Dict) -> bool: """ @@ -121,10 +136,12 @@ def poke(self, context: Dict) -> bool: raise AirflowException('Override me.') def is_smart_sensor_compatible(self): - check_list = [not self.sensor_service_enabled, - self.on_success_callback, - self.on_retry_callback, - self.on_failure_callback] + check_list = [ + not self.sensor_service_enabled, + self.on_success_callback, + self.on_retry_callback, + self.on_failure_callback, + ] for status in check_list: if status: return False @@ -195,14 +212,13 @@ def execute(self, context: Dict) -> Any: # give it a chance and fail with timeout. # This gives the ability to set up non-blocking AND soft-fail sensors. if self.soft_fail and not context['ti'].is_eligible_to_retry(): - raise AirflowSkipException( - f"Snap. Time is OUT. DAG id: {log_dag_id}") + raise AirflowSkipException(f"Snap. Time is OUT. DAG id: {log_dag_id}") else: - raise AirflowSensorTimeout( - f"Snap. Time is OUT. DAG id: {log_dag_id}") + raise AirflowSensorTimeout(f"Snap. Time is OUT. DAG id: {log_dag_id}") if self.reschedule: reschedule_date = timezone.utcnow() + timedelta( - seconds=self._get_next_poke_interval(started_at, try_number)) + seconds=self._get_next_poke_interval(started_at, try_number) + ) raise AirflowRescheduleException(reschedule_date) else: sleep(self._get_next_poke_interval(started_at, try_number)) @@ -215,17 +231,18 @@ def _get_next_poke_interval(self, started_at, try_number): min_backoff = int(self.poke_interval * (2 ** (try_number - 2))) current_time = timezone.utcnow() - run_hash = int(hashlib.sha1("{}#{}#{}#{}".format( - self.dag_id, self.task_id, started_at, try_number - ).encode("utf-8")).hexdigest(), 16) + run_hash = int( + hashlib.sha1( + f"{self.dag_id}#{self.task_id}#{started_at}#{try_number}".encode("utf-8") + ).hexdigest(), + 16, + ) modded_hash = min_backoff + run_hash % min_backoff - delay_backoff_in_seconds = min( - modded_hash, - timedelta.max.total_seconds() - 1 + delay_backoff_in_seconds = min(modded_hash, timedelta.max.total_seconds() - 1) + new_interval = min( + self.timeout - int((current_time - started_at).total_seconds()), delay_backoff_in_seconds ) - new_interval = min(self.timeout - int((current_time - started_at).total_seconds()), - delay_backoff_in_seconds) self.log.info("new %s interval is %s", self.mode, new_interval) return new_interval else: @@ -269,6 +286,7 @@ def poke_mode_only(cls): :param cls: BaseSensor class to enforce methods only use 'poke' mode. :type cls: type """ + def decorate(cls_type): def mode_getter(_): return 'poke' @@ -278,9 +296,11 @@ def mode_setter(_, value): raise ValueError("cannot set mode to 'poke'.") if not issubclass(cls_type, BaseSensorOperator): - raise ValueError(f"poke_mode_only decorator should only be " - f"applied to subclasses of BaseSensorOperator," - f" got:{cls_type}.") + raise ValueError( + f"poke_mode_only decorator should only be " + f"applied to subclasses of BaseSensorOperator," + f" got:{cls_type}." 
+ ) cls_type.mode = property(mode_getter, mode_setter) diff --git a/airflow/sensors/bash.py b/airflow/sensors/bash.py index 711a7d40d24e8..371edba6d064e 100644 --- a/airflow/sensors/bash.py +++ b/airflow/sensors/bash.py @@ -45,11 +45,7 @@ class BashSensor(BaseSensorOperator): template_fields = ('bash_command', 'env') @apply_defaults - def __init__(self, *, - bash_command, - env=None, - output_encoding='utf-8', - **kwargs): + def __init__(self, *, bash_command, env=None, output_encoding='utf-8', **kwargs): super().__init__(**kwargs) self.bash_command = bash_command self.env = env @@ -72,9 +68,13 @@ def poke(self, context): self.log.info("Running command: %s", bash_command) resp = Popen( # pylint: disable=subprocess-popen-preexec-fn ['bash', fname], - stdout=PIPE, stderr=STDOUT, - close_fds=True, cwd=tmp_dir, - env=self.env, preexec_fn=os.setsid) + stdout=PIPE, + stderr=STDOUT, + close_fds=True, + cwd=tmp_dir, + env=self.env, + preexec_fn=os.setsid, + ) self.log.info("Output:") for line in iter(resp.stdout.readline, b''): diff --git a/airflow/sensors/date_time_sensor.py b/airflow/sensors/date_time_sensor.py index 0d479ed6394f3..028e505bdc15f 100644 --- a/airflow/sensors/date_time_sensor.py +++ b/airflow/sensors/date_time_sensor.py @@ -57,9 +57,7 @@ class DateTimeSensor(BaseSensorOperator): template_fields = ("target_time",) @apply_defaults - def __init__( - self, *, target_time: Union[str, datetime.datetime], **kwargs - ) -> None: + def __init__(self, *, target_time: Union[str, datetime.datetime], **kwargs) -> None: super().__init__(**kwargs) if isinstance(target_time, datetime.datetime): self.target_time = target_time.isoformat() @@ -67,9 +65,7 @@ def __init__( self.target_time = target_time else: raise TypeError( - "Expected str or datetime.datetime type for target_time. Got {}".format( - type(target_time) - ) + "Expected str or datetime.datetime type for target_time. 
Got {}".format(type(target_time)) ) def poke(self, context: Dict) -> bool: diff --git a/airflow/sensors/external_task_sensor.py b/airflow/sensors/external_task_sensor.py index c32939f3d167b..70cd971ce7a76 100644 --- a/airflow/sensors/external_task_sensor.py +++ b/airflow/sensors/external_task_sensor.py @@ -85,15 +85,18 @@ def operator_extra_links(self): return [ExternalTaskSensorLink()] @apply_defaults - def __init__(self, *, - external_dag_id, - external_task_id=None, - allowed_states=None, - failed_states=None, - execution_delta=None, - execution_date_fn=None, - check_existence=False, - **kwargs): + def __init__( + self, + *, + external_dag_id, + external_task_id=None, + allowed_states=None, + failed_states=None, + execution_delta=None, + execution_date_fn=None, + check_existence=False, + **kwargs, + ): super().__init__(**kwargs) self.allowed_states = allowed_states or [State.SUCCESS] self.failed_states = failed_states or [] @@ -102,9 +105,10 @@ def __init__(self, *, total_states = set(total_states) if set(self.failed_states).intersection(set(self.allowed_states)): - raise AirflowException("Duplicate values provided as allowed " - "`{}` and failed states `{}`" - .format(self.allowed_states, self.failed_states)) + raise AirflowException( + "Duplicate values provided as allowed " + "`{}` and failed states `{}`".format(self.allowed_states, self.failed_states) + ) if external_task_id: if not total_states <= set(State.task_states): @@ -121,7 +125,8 @@ def __init__(self, *, if execution_delta is not None and execution_date_fn is not None: raise ValueError( 'Only one of `execution_delta` or `execution_date_fn` may ' - 'be provided to ExternalTaskSensor; not both.') + 'be provided to ExternalTaskSensor; not both.' + ) self.execution_delta = execution_delta self.execution_date_fn = execution_date_fn @@ -141,20 +146,16 @@ def poke(self, context, session=None): dttm = context['execution_date'] dttm_filter = dttm if isinstance(dttm, list) else [dttm] - serialized_dttm_filter = ','.join( - [datetime.isoformat() for datetime in dttm_filter]) + serialized_dttm_filter = ','.join([datetime.isoformat() for datetime in dttm_filter]) self.log.info( - 'Poking for %s.%s on %s ... ', - self.external_dag_id, self.external_task_id, serialized_dttm_filter + 'Poking for %s.%s on %s ... 
', self.external_dag_id, self.external_task_id, serialized_dttm_filter ) DM = DagModel # we only do the check for 1st time, no need for subsequent poke if self.check_existence and not self.has_checked_existence: - dag_to_wait = session.query(DM).filter( - DM.dag_id == self.external_dag_id - ).first() + dag_to_wait = session.query(DM).filter(DM.dag_id == self.external_dag_id).first() if not dag_to_wait: raise AirflowException(f'The external DAG {self.external_dag_id} does not exist.') @@ -204,19 +205,27 @@ def get_count(self, dttm_filter, session, states): if self.external_task_id: # .count() is inefficient - count = session.query(func.count()).filter( - TI.dag_id == self.external_dag_id, - TI.task_id == self.external_task_id, - TI.state.in_(states), # pylint: disable=no-member - TI.execution_date.in_(dttm_filter), - ).scalar() + count = ( + session.query(func.count()) + .filter( + TI.dag_id == self.external_dag_id, + TI.task_id == self.external_task_id, + TI.state.in_(states), # pylint: disable=no-member + TI.execution_date.in_(dttm_filter), + ) + .scalar() + ) else: # .count() is inefficient - count = session.query(func.count()).filter( - DR.dag_id == self.external_dag_id, - DR.state.in_(states), # pylint: disable=no-member - DR.execution_date.in_(dttm_filter), - ).scalar() + count = ( + session.query(func.count()) + .filter( + DR.dag_id == self.external_dag_id, + DR.state.in_(states), # pylint: disable=no-member + DR.execution_date.in_(dttm_filter), + ) + .scalar() + ) return count def _handle_execution_date_fn(self, context): @@ -235,9 +244,7 @@ def _handle_execution_date_fn(self, context): if num_fxn_params == 2: return self.execution_date_fn(context['execution_date'], context) - raise AirflowException( - f'execution_date_fn passed {num_fxn_params} args but only allowed up to 2' - ) + raise AirflowException(f'execution_date_fn passed {num_fxn_params} args but only allowed up to 2') class ExternalTaskMarker(DummyOperator): @@ -266,12 +273,15 @@ class ExternalTaskMarker(DummyOperator): __serialized_fields: Optional[FrozenSet[str]] = None @apply_defaults - def __init__(self, *, - external_dag_id, - external_task_id, - execution_date: Optional[Union[str, datetime.datetime]] = "{{ execution_date.isoformat() }}", - recursion_depth: int = 10, - **kwargs): + def __init__( + self, + *, + external_dag_id, + external_task_id, + execution_date: Optional[Union[str, datetime.datetime]] = "{{ execution_date.isoformat() }}", + recursion_depth: int = 10, + **kwargs, + ): super().__init__(**kwargs) self.external_dag_id = external_dag_id self.external_task_id = external_task_id @@ -280,8 +290,11 @@ def __init__(self, *, elif isinstance(execution_date, str): self.execution_date = execution_date else: - raise TypeError('Expected str or datetime.datetime type for execution_date. Got {}' - .format(type(execution_date))) + raise TypeError( + 'Expected str or datetime.datetime type for execution_date. 
Got {}'.format( + type(execution_date) + ) + ) if recursion_depth <= 0: raise ValueError("recursion_depth should be a positive integer") self.recursion_depth = recursion_depth @@ -290,9 +303,5 @@ def __init__(self, *, def get_serialized_fields(cls): """Serialized ExternalTaskMarker contain exactly these fields + templated_fields .""" if not cls.__serialized_fields: - cls.__serialized_fields = frozenset( - super().get_serialized_fields() | { - "recursion_depth" - } - ) + cls.__serialized_fields = frozenset(super().get_serialized_fields() | {"recursion_depth"}) return cls.__serialized_fields diff --git a/airflow/sensors/filesystem.py b/airflow/sensors/filesystem.py index 6270ba25bc438..328fd362edd6f 100644 --- a/airflow/sensors/filesystem.py +++ b/airflow/sensors/filesystem.py @@ -44,10 +44,7 @@ class FileSensor(BaseSensorOperator): ui_color = '#91818a' @apply_defaults - def __init__(self, *, - filepath, - fs_conn_id='fs_default', - **kwargs): + def __init__(self, *, filepath, fs_conn_id='fs_default', **kwargs): super().__init__(**kwargs) self.filepath = filepath self.fs_conn_id = fs_conn_id diff --git a/airflow/sensors/hdfs_sensor.py b/airflow/sensors/hdfs_sensor.py index 0204a2f969185..ed8830e4bbaef 100644 --- a/airflow/sensors/hdfs_sensor.py +++ b/airflow/sensors/hdfs_sensor.py @@ -25,5 +25,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hdfs.sensors.hdfs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/sensors/hive_partition_sensor.py b/airflow/sensors/hive_partition_sensor.py index 1768b3a226c07..028a9fd2dcace 100644 --- a/airflow/sensors/hive_partition_sensor.py +++ b/airflow/sensors/hive_partition_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hive.sensors.hive_partition`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/sensors/http_sensor.py b/airflow/sensors/http_sensor.py index 26755c2ba8a11..ad85b60447d00 100644 --- a/airflow/sensors/http_sensor.py +++ b/airflow/sensors/http_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.http.sensors.http`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/sensors/metastore_partition_sensor.py b/airflow/sensors/metastore_partition_sensor.py index e4db0ccd4bb04..5ab966f12057c 100644 --- a/airflow/sensors/metastore_partition_sensor.py +++ b/airflow/sensors/metastore_partition_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hive.sensors.metastore_partition`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/sensors/named_hive_partition_sensor.py b/airflow/sensors/named_hive_partition_sensor.py index 37d8eb31ae91b..eeb28e1656533 100644 --- a/airflow/sensors/named_hive_partition_sensor.py +++ b/airflow/sensors/named_hive_partition_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.apache.hive.sensors.named_hive_partition`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/sensors/python.py b/airflow/sensors/python.py index 183405188635b..68fcc66b94e7a 100644 --- a/airflow/sensors/python.py +++ b/airflow/sensors/python.py @@ -50,12 +50,14 @@ class PythonSensor(BaseSensorOperator): @apply_defaults def __init__( - self, *, - python_callable: Callable, - op_args: Optional[List] = None, - op_kwargs: Optional[Dict] = None, - templates_dict: Optional[Dict] = None, - **kwargs): + self, + *, + python_callable: Callable, + op_args: Optional[List] = None, + op_kwargs: Optional[Dict] = None, + templates_dict: Optional[Dict] = None, + **kwargs, + ): super().__init__(**kwargs) self.python_callable = python_callable self.op_args = op_args or [] diff --git a/airflow/sensors/s3_key_sensor.py b/airflow/sensors/s3_key_sensor.py index d8bada859f867..26b973e8c6ab9 100644 --- a/airflow/sensors/s3_key_sensor.py +++ b/airflow/sensors/s3_key_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3_key`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/sensors/s3_prefix_sensor.py b/airflow/sensors/s3_prefix_sensor.py index 27ef2bd903d47..f09854e7f87c6 100644 --- a/airflow/sensors/s3_prefix_sensor.py +++ b/airflow/sensors/s3_prefix_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3_prefix`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/sensors/smart_sensor_operator.py b/airflow/sensors/smart_sensor_operator.py index 369002ca86b21..4917ee64cc0c7 100644 --- a/airflow/sensors/smart_sensor_operator.py +++ b/airflow/sensors/smart_sensor_operator.py @@ -94,10 +94,12 @@ def __eq__(self, other): if not isinstance(other, SensorWork): return NotImplemented - return self.dag_id == other.dag_id and \ - self.task_id == other.task_id and \ - self.execution_date == other.execution_date and \ - self.try_number == other.try_number + return ( + self.dag_id == other.dag_id + and self.task_id == other.task_id + and self.execution_date == other.execution_date + and self.try_number == other.try_number + ) @staticmethod def create_new_task_handler(): @@ -117,10 +119,9 @@ def _get_sensor_logger(self, si): # The created log_id is used inside of smart sensor as the key to fetch # the corresponding in memory log handler. si.raw = False # Otherwise set_context will fail - log_id = "-".join([si.dag_id, - si.task_id, - si.execution_date.strftime("%Y_%m_%dT%H_%M_%S_%f"), - str(si.try_number)]) + log_id = "-".join( + [si.dag_id, si.task_id, si.execution_date.strftime("%Y_%m_%dT%H_%M_%S_%f"), str(si.try_number)] + ) logger = logging.getLogger('airflow.task' + '.' + log_id) if len(logger.handlers) == 0: @@ -128,10 +129,11 @@ def _get_sensor_logger(self, si): logger.addHandler(handler) set_context(logger, si) - line_break = ("-" * 120) + line_break = "-" * 120 logger.info(line_break) - logger.info("Processing sensor task %s in smart sensor service on host: %s", - self.ti_key, get_hostname()) + logger.info( + "Processing sensor task %s in smart sensor service on host: %s", self.ti_key, get_hostname() + ) logger.info(line_break) return logger @@ -200,10 +202,12 @@ class SensorExceptionInfo: infra failure, give the task more chance to retry before fail it. 
""" - def __init__(self, - exception_info, - is_infra_failure=False, - infra_failure_retry_window=datetime.timedelta(minutes=130)): + def __init__( + self, + exception_info, + is_infra_failure=False, + infra_failure_retry_window=datetime.timedelta(minutes=130), + ): self._exception_info = exception_info self._is_infra_failure = is_infra_failure self._infra_failure_retry_window = infra_failure_retry_window @@ -302,15 +306,17 @@ class SmartSensorOperator(BaseOperator, SkipMixin): ui_color = '#e6f1f2' @apply_defaults - def __init__(self, - poke_interval=180, - smart_sensor_timeout=60 * 60 * 24 * 7, - soft_fail=False, - shard_min=0, - shard_max=100000, - poke_timeout=6.0, - *args, - **kwargs): + def __init__( + self, + poke_interval=180, + smart_sensor_timeout=60 * 60 * 24 * 7, + soft_fail=False, + shard_min=0, + shard_max=100000, + poke_timeout=6.0, + *args, + **kwargs, + ): super().__init__(*args, **kwargs) # super(SmartSensorOperator, self).__init__(*args, **kwargs) self.poke_interval = poke_interval @@ -330,11 +336,9 @@ def __init__(self, def _validate_input_values(self): if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0: - raise AirflowException( - "The poke_interval must be a non-negative number") + raise AirflowException("The poke_interval must be a non-negative number") if not isinstance(self.timeout, (int, float)) or self.timeout < 0: - raise AirflowException( - "The timeout must be a non-negative number") + raise AirflowException("The timeout must be a non-negative number") @provide_session def _load_sensor_works(self, session=None): @@ -345,10 +349,11 @@ def _load_sensor_works(self, session=None): """ SI = SensorInstance start_query_time = time.time() - query = session.query(SI) \ - .filter(SI.state == State.SENSING)\ - .filter(SI.shardcode < self.shard_max, - SI.shardcode >= self.shard_min) + query = ( + session.query(SI) + .filter(SI.state == State.SENSING) + .filter(SI.shardcode < self.shard_max, SI.shardcode >= self.shard_min) + ) tis = query.all() self.log.info("Performance query %s tis, time: %s", len(tis), time.time() - start_query_time) @@ -387,10 +392,11 @@ def _update_ti_hostname(self, sensor_works, session=None): def update_ti_hostname_with_count(count, ti_keys): # Using or_ instead of in_ here to prevent from full table scan. - tis = session.query(TI) \ - .filter(or_(tuple_(TI.dag_id, TI.task_id, TI.execution_date) == ti_key - for ti_key in ti_keys)) \ + tis = ( + session.query(TI) + .filter(or_(tuple_(TI.dag_id, TI.task_id, TI.execution_date) == ti_key for ti_key in ti_keys)) .all() + ) for ti in tis: ti.hostname = self.hostname @@ -414,6 +420,7 @@ def _mark_multi_state(self, operator, poke_hash, encoded_poke_context, state, se :param state: Set multiple sensor tasks to this state. :param session: The sqlalchemy session. 
""" + def mark_state(ti, sensor_instance): ti.state = state sensor_instance.state = state @@ -426,14 +433,22 @@ def mark_state(ti, sensor_instance): count_marked = 0 try: - query_result = session.query(TI, SI)\ - .join(TI, and_(TI.dag_id == SI.dag_id, - TI.task_id == SI.task_id, - TI.execution_date == SI.execution_date)) \ - .filter(SI.state == State.SENSING) \ - .filter(SI.hashcode == poke_hash) \ - .filter(SI.operator == operator) \ - .with_for_update().all() + query_result = ( + session.query(TI, SI) + .join( + TI, + and_( + TI.dag_id == SI.dag_id, + TI.task_id == SI.task_id, + TI.execution_date == SI.execution_date, + ), + ) + .filter(SI.state == State.SENSING) + .filter(SI.hashcode == poke_hash) + .filter(SI.operator == operator) + .with_for_update() + .all() + ) end_date = timezone.utcnow() for ti, sensor_instance in query_result: @@ -451,8 +466,7 @@ def mark_state(ti, sensor_instance): session.commit() except Exception as e: # pylint: disable=broad-except - self.log.warning("Exception _mark_multi_state in smart sensor for hashcode %s", - str(poke_hash)) + self.log.warning("Exception _mark_multi_state in smart sensor for hashcode %s", str(poke_hash)) self.log.exception(e, exc_info=True) self.log.info("Marked %s tasks out of %s to state %s", count_marked, len(query_result), state) @@ -469,6 +483,7 @@ def _retry_or_fail_task(self, sensor_work, error, session=None): :type error: str. :param session: The sqlalchemy session. """ + def email_alert(task_instance, error_info): try: subject, html_content, _ = task_instance.get_email_subject_content(error_info) @@ -480,18 +495,19 @@ def email_alert(task_instance, error_info): sensor_work.log.exception(e, exc_info=True) def handle_failure(sensor_work, ti): - if sensor_work.execution_context.get('retries') and \ - ti.try_number <= ti.max_tries: + if sensor_work.execution_context.get('retries') and ti.try_number <= ti.max_tries: # retry ti.state = State.UP_FOR_RETRY - if sensor_work.execution_context.get('email_on_retry') and \ - sensor_work.execution_context.get('email'): + if sensor_work.execution_context.get('email_on_retry') and sensor_work.execution_context.get( + 'email' + ): sensor_work.log.info("%s sending email alert for retry", sensor_work.ti_key) email_alert(ti, error) else: ti.state = State.FAILED - if sensor_work.execution_context.get('email_on_failure') and \ - sensor_work.execution_context.get('email'): + if sensor_work.execution_context.get( + 'email_on_failure' + ) and sensor_work.execution_context.get('email'): sensor_work.log.info("%s sending email alert for failure", sensor_work.ti_key) email_alert(ti, error) @@ -499,23 +515,23 @@ def handle_failure(sensor_work, ti): dag_id, task_id, execution_date = sensor_work.ti_key TI = TaskInstance SI = SensorInstance - sensor_instance = session.query(SI).filter( - SI.dag_id == dag_id, - SI.task_id == task_id, - SI.execution_date == execution_date) \ - .with_for_update() \ + sensor_instance = ( + session.query(SI) + .filter(SI.dag_id == dag_id, SI.task_id == task_id, SI.execution_date == execution_date) + .with_for_update() .first() + ) if sensor_instance.hashcode != sensor_work.hashcode: # Return without setting state return - ti = session.query(TI).filter( - TI.dag_id == dag_id, - TI.task_id == task_id, - TI.execution_date == execution_date) \ - .with_for_update() \ + ti = ( + session.query(TI) + .filter(TI.dag_id == dag_id, TI.task_id == task_id, TI.execution_date == execution_date) + .with_for_update() .first() + ) if ti: if ti.state == State.SENSING: @@ -531,8 +547,9 @@ def 
handle_failure(sensor_work, ti): session.merge(ti) session.commit() - sensor_work.log.info("Task %s got an error: %s. Set the state to failed. Exit.", - str(sensor_work.ti_key), error) + sensor_work.log.info( + "Task %s got an error: %s. Set the state to failed. Exit.", str(sensor_work.ti_key), error + ) sensor_work.close_sensor_logger() except AirflowException as e: @@ -568,8 +585,9 @@ def _handle_poke_exception(self, sensor_work): if sensor_exception.fail_current_run: if sensor_exception.is_infra_failure: - sensor_work.log.exception("Task %s failed by infra failure in smart sensor.", - sensor_work.ti_key) + sensor_work.log.exception( + "Task %s failed by infra failure in smart sensor.", sensor_work.ti_key + ) # There is a risk for sensor object cached in smart sensor keep throwing # exception and cause an infra failure. To make sure the sensor tasks after # retry will not fall into same object and have endless infra failure, @@ -619,10 +637,12 @@ def _execute_sensor_work(self, sensor_work): # Got a landed signal, mark all tasks waiting for this partition cached_work.set_state(PokeState.LANDED) - self._mark_multi_state(sensor_work.operator, - sensor_work.hashcode, - sensor_work.encoded_poke_context, - State.SUCCESS) + self._mark_multi_state( + sensor_work.operator, + sensor_work.hashcode, + sensor_work.encoded_poke_context, + State.SUCCESS, + ) log.info("Task %s succeeded", str(ti_key)) sensor_work.close_sensor_logger() @@ -642,12 +662,12 @@ def _execute_sensor_work(self, sensor_work): if cache_key in self.cached_sensor_exceptions: self.cached_sensor_exceptions[cache_key].set_latest_exception( - exception_info, - is_infra_failure=is_infra_failure) + exception_info, is_infra_failure=is_infra_failure + ) else: self.cached_sensor_exceptions[cache_key] = SensorExceptionInfo( - exception_info, - is_infra_failure=is_infra_failure) + exception_info, is_infra_failure=is_infra_failure + ) self._handle_poke_exception(sensor_work) @@ -671,8 +691,7 @@ def poke(self, sensor_work): """ cached_work = self.cached_dedup_works[sensor_work.cache_key] if not cached_work.sensor_task: - init_args = dict(list(sensor_work.poke_context.items()) - + [('task_id', sensor_work.task_id)]) + init_args = dict(list(sensor_work.poke_context.items()) + [('task_id', sensor_work.task_id)]) operator_class = import_string(sensor_work.op_classpath) cached_work.sensor_task = operator_class(**init_args) diff --git a/airflow/sensors/sql_sensor.py b/airflow/sensors/sql_sensor.py index b8dc1432538b5..aabff0d547c1c 100644 --- a/airflow/sensors/sql_sensor.py +++ b/airflow/sensors/sql_sensor.py @@ -52,12 +52,16 @@ class SqlSensor(BaseSensorOperator): """ template_fields: Iterable[str] = ('sql',) - template_ext: Iterable[str] = ('.hql', '.sql',) + template_ext: Iterable[str] = ( + '.hql', + '.sql', + ) ui_color = '#7c7287' @apply_defaults - def __init__(self, *, conn_id, sql, parameters=None, success=None, failure=None, fail_on_empty=False, - **kwargs): + def __init__( + self, *, conn_id, sql, parameters=None, success=None, failure=None, fail_on_empty=False, **kwargs + ): self.conn_id = conn_id self.sql = sql self.parameters = parameters @@ -69,12 +73,24 @@ def __init__(self, *, conn_id, sql, parameters=None, success=None, failure=None, def _get_hook(self): conn = BaseHook.get_connection(self.conn_id) - allowed_conn_type = {'google_cloud_platform', 'jdbc', 'mssql', - 'mysql', 'odbc', 'oracle', 'postgres', - 'presto', 'snowflake', 'sqlite', 'vertica'} + allowed_conn_type = { + 'google_cloud_platform', + 'jdbc', + 'mssql', + 'mysql', + 
'odbc', + 'oracle', + 'postgres', + 'presto', + 'snowflake', + 'sqlite', + 'vertica', + } if conn.conn_type not in allowed_conn_type: - raise AirflowException("The connection type is not supported by SqlSensor. " + - "Supported connection types: {}".format(list(allowed_conn_type))) + raise AirflowException( + "The connection type is not supported by SqlSensor. " + + "Supported connection types: {}".format(list(allowed_conn_type)) + ) return conn.get_hook() def poke(self, context): @@ -91,8 +107,7 @@ def poke(self, context): if self.failure is not None: if callable(self.failure): if self.failure(first_cell): - raise AirflowException( - f"Failure criteria met. self.failure({first_cell}) returned True") + raise AirflowException(f"Failure criteria met. self.failure({first_cell}) returned True") else: raise AirflowException(f"self.failure is present, but not callable -> {self.success}") if self.success is not None: diff --git a/airflow/sensors/web_hdfs_sensor.py b/airflow/sensors/web_hdfs_sensor.py index 7ad6cae83f17a..015d0ab684c2a 100644 --- a/airflow/sensors/web_hdfs_sensor.py +++ b/airflow/sensors/web_hdfs_sensor.py @@ -24,5 +24,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.apache.hdfs.sensors.web_hdfs`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/sensors/weekday_sensor.py b/airflow/sensors/weekday_sensor.py index 0340794664587..cadbeaa1a009d 100644 --- a/airflow/sensors/weekday_sensor.py +++ b/airflow/sensors/weekday_sensor.py @@ -73,9 +73,7 @@ class DayOfWeekSensor(BaseSensorOperator): """ @apply_defaults - def __init__(self, *, week_day, - use_task_execution_day=False, - **kwargs): + def __init__(self, *, week_day, use_task_execution_day=False, **kwargs): super().__init__(**kwargs) self.week_day = week_day self.use_task_execution_day = use_task_execution_day @@ -91,12 +89,15 @@ def __init__(self, *, week_day, else: raise TypeError( 'Unsupported Type for week_day parameter: {}. 
It should be one of str' - ', set or Weekday enum type'.format(type(week_day))) + ', set or Weekday enum type'.format(type(week_day)) + ) def poke(self, context): - self.log.info('Poking until weekday is in %s, Today is %s', - self.week_day, - WeekDay(timezone.utcnow().isoweekday()).name) + self.log.info( + 'Poking until weekday is in %s, Today is %s', + self.week_day, + WeekDay(timezone.utcnow().isoweekday()).name, + ) if self.use_task_execution_day: return context['execution_date'].isoweekday() in self._week_day_num else: diff --git a/airflow/sentry.py b/airflow/sentry.py index 253534383e941..f16b05722f85a 100644 --- a/airflow/sentry.py +++ b/airflow/sentry.py @@ -50,6 +50,7 @@ def flush(self): Sentry: DummySentry = DummySentry() if conf.getboolean("sentry", 'sentry_on', fallback=False): import sentry_sdk + # Verify blinker installation from blinker import signal # noqa: F401 pylint: disable=unused-import from sentry_sdk.integrations.flask import FlaskIntegration @@ -58,14 +59,19 @@ def flush(self): class ConfiguredSentry(DummySentry): """Configure Sentry SDK.""" - SCOPE_TAGS = frozenset( - ("task_id", "dag_id", "execution_date", "operator", "try_number") - ) + SCOPE_TAGS = frozenset(("task_id", "dag_id", "execution_date", "operator", "try_number")) SCOPE_CRUMBS = frozenset(("task_id", "state", "operator", "duration")) UNSUPPORTED_SENTRY_OPTIONS = frozenset( - ("integrations", "in_app_include", "in_app_exclude", "ignore_errors", - "before_breadcrumb", "before_send", "transport") + ( + "integrations", + "in_app_include", + "in_app_exclude", + "ignore_errors", + "before_breadcrumb", + "before_send", + "transport", + ) ) def __init__(self): @@ -94,12 +100,11 @@ def __init__(self): # supported backward compability with old way dsn option dsn = old_way_dsn or new_way_dsn - unsupported_options = self.UNSUPPORTED_SENTRY_OPTIONS.intersection( - sentry_config_opts.keys()) + unsupported_options = self.UNSUPPORTED_SENTRY_OPTIONS.intersection(sentry_config_opts.keys()) if unsupported_options: log.warning( "There are unsupported options in [sentry] section: %s", - ", ".join(unsupported_options) + ", ".join(unsupported_options), ) if dsn: diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 5b0def99ce751..7ce5e1a1d7f3c 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -43,6 +43,7 @@ # isort: off from kubernetes.client import models as k8s from airflow.kubernetes.pod_generator import PodGenerator + # isort: on HAS_KUBERNETES = True except ImportError: @@ -99,8 +100,9 @@ def from_json(cls, serialized_obj: str) -> Union['BaseSerialization', dict, list return cls.from_dict(json.loads(serialized_obj)) @classmethod - def from_dict(cls, serialized_obj: Dict[Encoding, Any]) -> \ - Union['BaseSerialization', dict, list, set, tuple]: + def from_dict( + cls, serialized_obj: Dict[Encoding, Any] + ) -> Union['BaseSerialization', dict, list, set, tuple]: """Deserializes a python dict stored with type decorators and reconstructs all DAGs and operators it contains. 
""" @@ -138,14 +140,14 @@ def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool: return True return cls._value_is_hardcoded_default(attrname, var, instance) - return ( - isinstance(var, cls._excluded_types) or - cls._value_is_hardcoded_default(attrname, var, instance) + return isinstance(var, cls._excluded_types) or cls._value_is_hardcoded_default( + attrname, var, instance ) @classmethod - def serialize_to_json(cls, object_to_serialize: Union[BaseOperator, DAG], decorated_fields: Set) \ - -> Dict[str, Any]: + def serialize_to_json( + cls, object_to_serialize: Union[BaseOperator, DAG], decorated_fields: Set + ) -> Dict[str, Any]: """Serializes an object to json""" serialized_object: Dict[str, Any] = {} keys_to_serialize = object_to_serialize.get_serialized_fields() @@ -184,10 +186,7 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r return var.value return var elif isinstance(var, dict): - return cls._encode( - {str(k): cls._serialize(v) for k, v in var.items()}, - type_=DAT.DICT - ) + return cls._encode({str(k): cls._serialize(v) for k, v in var.items()}, type_=DAT.DICT) elif isinstance(var, list): return [cls._serialize(v) for v in var] elif HAS_KUBERNETES and isinstance(var, k8s.V1Pod): @@ -215,12 +214,10 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r return str(get_python_source(var)) elif isinstance(var, set): # FIXME: casts set to list in customized serialization in future. - return cls._encode( - [cls._serialize(v) for v in var], type_=DAT.SET) + return cls._encode([cls._serialize(v) for v in var], type_=DAT.SET) elif isinstance(var, tuple): # FIXME: casts tuple to list in customized serialization in future. - return cls._encode( - [cls._serialize(v) for v in var], type_=DAT.TUPLE) + return cls._encode([cls._serialize(v) for v in var], type_=DAT.TUPLE) elif isinstance(var, TaskGroup): return SerializedTaskGroup.serialize_task_group(var) else: @@ -256,9 +253,7 @@ def _deserialize(cls, encoded_var: Any) -> Any: # pylint: disable=too-many-retu return pendulum.from_timestamp(var) elif type_ == DAT.POD: if not HAS_KUBERNETES: - raise RuntimeError( - "Cannot deserialize POD objects without kubernetes libraries installed!" - ) + raise RuntimeError("Cannot deserialize POD objects without kubernetes libraries installed!") pod = PodGenerator.deserialize_model_dict(var) return pod elif type_ == DAT.TIMEDELTA: @@ -306,8 +301,9 @@ def _value_is_hardcoded_default(cls, attrname: str, value: Any, instance: Any) - ``field = field or {}`` set. 
""" # pylint: disable=unused-argument - if attrname in cls._CONSTRUCTOR_PARAMS and \ - (cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []])): + if attrname in cls._CONSTRUCTOR_PARAMS and ( + cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []]) + ): return True return False @@ -322,7 +318,8 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): _decorated_fields = {'executor_config'} _CONSTRUCTOR_PARAMS = { - k: v.default for k, v in signature(BaseOperator.__init__).parameters.items() + k: v.default + for k, v in signature(BaseOperator.__init__).parameters.items() if v.default is not v.empty } @@ -354,8 +351,9 @@ def serialize_operator(cls, op: BaseOperator) -> dict: serialize_op['_task_type'] = op.__class__.__name__ serialize_op['_task_module'] = op.__class__.__module__ if op.operator_extra_links: - serialize_op['_operator_extra_links'] = \ - cls._serialize_operator_extra_links(op.operator_extra_links) + serialize_op['_operator_extra_links'] = cls._serialize_operator_extra_links( + op.operator_extra_links + ) # Store all template_fields as they are if there are JSON Serializable # If not, store them as strings @@ -371,6 +369,7 @@ def serialize_operator(cls, op: BaseOperator) -> dict: def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator: """Deserializes an operator from a JSON object.""" from airflow import plugins_manager + plugins_manager.initialize_extra_operators_links_plugins() if plugins_manager.operator_extra_links is None: @@ -386,8 +385,10 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator: for ope in plugins_manager.operator_extra_links: for operator in ope.operators: - if operator.__name__ == encoded_op["_task_type"] and \ - operator.__module__ == encoded_op["_task_module"]: + if ( + operator.__name__ == encoded_op["_task_type"] + and operator.__module__ == encoded_op["_task_module"] + ): op_extra_links_from_plugin.update({ope.name: ope}) # If OperatorLinks are defined in Plugins but not in the Operator that is being Serialized @@ -444,10 +445,7 @@ def _is_excluded(cls, var: Any, attrname: str, op: BaseOperator): return super()._is_excluded(var, attrname, op) @classmethod - def _deserialize_operator_extra_links( - cls, - encoded_op_links: list - ) -> Dict[str, BaseOperatorLink]: + def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> Dict[str, BaseOperatorLink]: """ Deserialize Operator Links if the Classes are registered in Airflow Plugins. Error is raised if the OperatorLink is not found in Plugins too. 
@@ -456,6 +454,7 @@ def _deserialize_operator_extra_links( :return: De-Serialized Operator Link """ from airflow import plugins_manager + plugins_manager.initialize_extra_operators_links_plugins() if plugins_manager.registered_operator_link_classes is None: @@ -502,20 +501,14 @@ def _deserialize_operator_extra_links( log.error("Operator Link class %r not registered", _operator_link_class_path) return {} - op_predefined_extra_link: BaseOperatorLink = cattr.structure( - data, single_op_link_class) + op_predefined_extra_link: BaseOperatorLink = cattr.structure(data, single_op_link_class) - op_predefined_extra_links.update( - {op_predefined_extra_link.name: op_predefined_extra_link} - ) + op_predefined_extra_links.update({op_predefined_extra_link.name: op_predefined_extra_link}) return op_predefined_extra_links @classmethod - def _serialize_operator_extra_links( - cls, - operator_extra_links: Iterable[BaseOperatorLink] - ): + def _serialize_operator_extra_links(cls, operator_extra_links: Iterable[BaseOperatorLink]): """ Serialize Operator Links. Store the import path of the OperatorLink and the arguments passed to it. Example @@ -531,8 +524,9 @@ def _serialize_operator_extra_links( op_link_arguments = {} serialize_operator_extra_links.append( { - "{}.{}".format(operator_extra_link.__class__.__module__, - operator_extra_link.__class__.__name__): op_link_arguments + "{}.{}".format( + operator_extra_link.__class__.__module__, operator_extra_link.__class__.__name__ + ): op_link_arguments } ) @@ -563,7 +557,8 @@ def __get_constructor_defaults(): # pylint: disable=no-method-argument 'access_control': '_access_control', } return { - param_to_attr.get(k, k): v.default for k, v in signature(DAG.__init__).parameters.items() + param_to_attr.get(k, k): v.default + for k, v in signature(DAG.__init__).parameters.items() if v.default is not v.empty } @@ -590,9 +585,7 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG': if k == "_downstream_task_ids": v = set(v) elif k == "tasks": - v = { - task["task_id"]: SerializedBaseOperator.deserialize_operator(task) for task in v - } + v = {task["task_id"]: SerializedBaseOperator.deserialize_operator(task) for task in v} k = "task_dict" elif k == "timezone": v = cls._deserialize_timezone(v) @@ -610,9 +603,7 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG': # pylint: disable=protected-access if "_task_group" in encoded_dag: dag._task_group = SerializedTaskGroup.deserialize_task_group( # type: ignore - encoded_dag["_task_group"], - None, - dag.task_dict + encoded_dag["_task_group"], None, dag.task_dict ) else: # This must be old data that had no task_group. Create a root TaskGroup and add @@ -641,17 +632,15 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG': for task_id in serializable_task.downstream_task_ids: # Bypass set_upstream etc here - it does more than we want - dag.task_dict[task_id]._upstream_task_ids.add(serializable_task.task_id) # noqa: E501 # pylint: disable=protected-access + # noqa: E501 # pylint: disable=protected-access + dag.task_dict[task_id]._upstream_task_ids.add(serializable_task.task_id) return dag @classmethod def to_dict(cls, var: Any) -> dict: """Stringifies DAGs and operators contained by var and returns a dict of var.""" - json_dict = { - "__version": cls.SERIALIZER_VERSION, - "dag": cls.serialize_dag(var) - } + json_dict = {"__version": cls.SERIALIZER_VERSION, "dag": cls.serialize_dag(var)} # Validate Serialized DAG with Json Schema. 
Raises Error if it mismatches cls.validate_schema(json_dict) @@ -683,15 +672,14 @@ def serialize_task_group(cls, task_group: TaskGroup) -> Optional[Union[Dict[str, "ui_fgcolor": task_group.ui_fgcolor, "children": { label: (DAT.OP, child.task_id) - if isinstance(child, BaseOperator) else - (DAT.TASK_GROUP, SerializedTaskGroup.serialize_task_group(child)) + if isinstance(child, BaseOperator) + else (DAT.TASK_GROUP, SerializedTaskGroup.serialize_task_group(child)) for label, child in task_group.children.items() }, "upstream_group_ids": cls._serialize(list(task_group.upstream_group_ids)), "downstream_group_ids": cls._serialize(list(task_group.downstream_group_ids)), "upstream_task_ids": cls._serialize(list(task_group.upstream_task_ids)), "downstream_task_ids": cls._serialize(list(task_group.downstream_task_ids)), - } return serialize_group @@ -701,7 +689,7 @@ def deserialize_task_group( cls, encoded_group: Dict[str, Any], parent_group: Optional[TaskGroup], - task_dict: Dict[str, BaseOperator] + task_dict: Dict[str, BaseOperator], ) -> Optional[TaskGroup]: """Deserializes a TaskGroup from a JSON object.""" if not encoded_group: @@ -712,15 +700,12 @@ def deserialize_task_group( key: cls._deserialize(encoded_group[key]) for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"] } - group = SerializedTaskGroup( - group_id=group_id, - parent_group=parent_group, - **kwargs - ) + group = SerializedTaskGroup(group_id=group_id, parent_group=parent_group, **kwargs) group.children = { - label: task_dict[val] if _type == DAT.OP # type: ignore - else SerializedTaskGroup.deserialize_task_group(val, group, task_dict) for label, (_type, val) - in encoded_group["children"].items() + label: task_dict[val] + if _type == DAT.OP # type: ignore + else SerializedTaskGroup.deserialize_task_group(val, group, task_dict) + for label, (_type, val) in encoded_group["children"].items() } group.upstream_group_ids = set(cls._deserialize(encoded_group["upstream_group_ids"])) group.downstream_group_ids = set(cls._deserialize(encoded_group["downstream_group_ids"])) diff --git a/airflow/settings.py b/airflow/settings.py index 5c321095b3bae..1f38cfc1dad3b 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -49,13 +49,15 @@ log.info("Configured default timezone %s", TIMEZONE) -HEADER = '\n'.join([ - r' ____________ _____________', - r' ____ |__( )_________ __/__ /________ __', - r'____ /| |_ /__ ___/_ /_ __ /_ __ \_ | /| / /', - r'___ ___ | / _ / _ __/ _ / / /_/ /_ |/ |/ /', - r' _/_/ |_/_/ /_/ /_/ /_/ \____/____/|__/', -]) +HEADER = '\n'.join( + [ + r' ____________ _____________', + r' ____ |__( )_________ __/__ /________ __', + r'____ /| |_ /__ ___/_ /_ __ /_ __ \_ | /| / /', + r'___ ___ | / _ / _ __/ _ / / /_/ /_ |/ |/ /', + r' _/_/ |_/_/ /_/ /_/ /_/ \____/____/|__/', + ] +) LOGGING_LEVEL = logging.INFO @@ -146,11 +148,7 @@ def configure_vars(): SQL_ALCHEMY_CONN = conf.get('core', 'SQL_ALCHEMY_CONN') DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER')) - PLUGINS_FOLDER = conf.get( - 'core', - 'plugins_folder', - fallback=os.path.join(AIRFLOW_HOME, 'plugins') - ) + PLUGINS_FOLDER = conf.get('core', 'plugins_folder', fallback=os.path.join(AIRFLOW_HOME, 'plugins')) def configure_orm(disable_connection_pool=False): @@ -172,12 +170,14 @@ def configure_orm(disable_connection_pool=False): engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args) setup_event_handlers(engine) - Session = scoped_session(sessionmaker( - autocommit=False, - autoflush=False, - bind=engine, - 
expire_on_commit=False, - )) + Session = scoped_session( + sessionmaker( + autocommit=False, + autoflush=False, + bind=engine, + expire_on_commit=False, + ) + ) def prepare_engine_args(disable_connection_pool=False): @@ -218,8 +218,14 @@ def prepare_engine_args(disable_connection_pool=False): # https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic pool_pre_ping = conf.getboolean('core', 'SQL_ALCHEMY_POOL_PRE_PING', fallback=True) - log.debug("settings.prepare_engine_args(): Using pool settings. pool_size=%d, max_overflow=%d, " - "pool_recycle=%d, pid=%d", pool_size, max_overflow, pool_recycle, os.getpid()) + log.debug( + "settings.prepare_engine_args(): Using pool settings. pool_size=%d, max_overflow=%d, " + "pool_recycle=%d, pid=%d", + pool_size, + max_overflow, + pool_recycle, + os.getpid(), + ) engine_args['pool_size'] = pool_size engine_args['pool_recycle'] = pool_recycle engine_args['pool_pre_ping'] = pool_pre_ping @@ -244,18 +250,22 @@ def dispose_orm(): def configure_adapters(): """Register Adapters and DB Converters""" from pendulum import DateTime as Pendulum + try: from sqlite3 import register_adapter + register_adapter(Pendulum, lambda val: val.isoformat(' ')) except ImportError: pass try: import MySQLdb.converters + MySQLdb.converters.conversions[Pendulum] = MySQLdb.converters.DateTime2literal except ImportError: pass try: import pymysql.converters + pymysql.converters.conversions[Pendulum] = pymysql.converters.escape_datetime except ImportError: pass @@ -334,6 +344,8 @@ def initialize(): # Ensure we close DB connections at scheduler and gunicon worker terminations atexit.register(dispose_orm) + + # pylint: enable=global-statement @@ -341,19 +353,16 @@ def initialize(): KILOBYTE = 1024 MEGABYTE = KILOBYTE * KILOBYTE -WEB_COLORS = {'LIGHTBLUE': '#4d9de0', - 'LIGHTORANGE': '#FF9933'} +WEB_COLORS = {'LIGHTBLUE': '#4d9de0', 'LIGHTORANGE': '#FF9933'} # Updating serialized DAG can not be faster than a minimum interval to reduce database # write rate. -MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint( - 'core', 'min_serialized_dag_update_interval', fallback=30) +MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint('core', 'min_serialized_dag_update_interval', fallback=30) # Fetching serialized DAG can not be faster than a minimum interval to reduce database # read rate. This config controls when your DAGs are updated in the Webserver -MIN_SERIALIZED_DAG_FETCH_INTERVAL = conf.getint( - 'core', 'min_serialized_dag_fetch_interval', fallback=10) +MIN_SERIALIZED_DAG_FETCH_INTERVAL = conf.getint('core', 'min_serialized_dag_fetch_interval', fallback=10) # Whether to persist DAG files code in DB. If set to True, Webserver reads file contents # from DB instead of trying to access files in a DAG folder. diff --git a/airflow/stats.py b/airflow/stats.py index b7457da28cbe0..baa2cb5e15260 100644 --- a/airflow/stats.py +++ b/airflow/stats.py @@ -125,15 +125,26 @@ def stat_name_default_handler(stat_name, max_length=250) -> str: if not isinstance(stat_name, str): raise InvalidStatsNameException('The stat_name has to be a string') if len(stat_name) > max_length: - raise InvalidStatsNameException(textwrap.dedent("""\ + raise InvalidStatsNameException( + textwrap.dedent( + """\ The stat_name ({stat_name}) has to be less than {max_length} characters. 
- """.format(stat_name=stat_name, max_length=max_length))) + """.format( + stat_name=stat_name, max_length=max_length + ) + ) + ) if not all((c in ALLOWED_CHARACTERS) for c in stat_name): - raise InvalidStatsNameException(textwrap.dedent("""\ + raise InvalidStatsNameException( + textwrap.dedent( + """\ The stat name ({stat_name}) has to be composed with characters in {allowed_characters}. - """.format(stat_name=stat_name, - allowed_characters=ALLOWED_CHARACTERS))) + """.format( + stat_name=stat_name, allowed_characters=ALLOWED_CHARACTERS + ) + ) + ) return stat_name @@ -149,6 +160,7 @@ def validate_stat(fn: T) -> T: """Check if stat name contains invalid characters. Log and not emit stats if name is invalid """ + @wraps(fn) def wrapper(_self, stat, *args, **kwargs): try: @@ -315,7 +327,8 @@ def get_statsd_logger(cls): statsd = stats_class( host=conf.get('scheduler', 'statsd_host'), port=conf.getint('scheduler', 'statsd_port'), - prefix=conf.get('scheduler', 'statsd_prefix')) + prefix=conf.get('scheduler', 'statsd_prefix'), + ) allow_list_validator = AllowListValidator(conf.get('scheduler', 'statsd_allow_list', fallback=None)) return SafeStatsdLogger(statsd, allow_list_validator) @@ -323,11 +336,13 @@ def get_statsd_logger(cls): def get_dogstatsd_logger(cls): """Get DataDog statsd logger""" from datadog import DogStatsd + dogstatsd = DogStatsd( host=conf.get('scheduler', 'statsd_host'), port=conf.getint('scheduler', 'statsd_port'), namespace=conf.get('scheduler', 'statsd_prefix'), - constant_tags=cls.get_constant_tags()) + constant_tags=cls.get_constant_tags(), + ) dogstatsd_allow_list = conf.get('scheduler', 'statsd_allow_list', fallback=None) allow_list_validator = AllowListValidator(dogstatsd_allow_list) return SafeDogStatsdLogger(dogstatsd, allow_list_validator) @@ -348,5 +363,6 @@ def get_constant_tags(cls): if TYPE_CHECKING: Stats: StatsLogger else: + class Stats(metaclass=_Stats): # noqa: D101 """Empty class for Stats - we use metaclass to inject the right one""" diff --git a/airflow/task/task_runner/base_task_runner.py b/airflow/task/task_runner/base_task_runner.py index 851d138620eda..743685ec54739 100644 --- a/airflow/task/task_runner/base_task_runner.py +++ b/airflow/task/task_runner/base_task_runner.py @@ -65,10 +65,7 @@ def __init__(self, local_task_job): cfg_path = tmp_configuration_copy(chmod=0o600) # Give ownership of file to user; only they can read and write - subprocess.call( - ['sudo', 'chown', self.run_as_user, cfg_path], - close_fds=True - ) + subprocess.call(['sudo', 'chown', self.run_as_user, cfg_path], close_fds=True) # propagate PYTHONPATH environment variable pythonpath_value = os.environ.get(PYTHONPATH_VAR, '') @@ -102,9 +99,12 @@ def _read_task_logs(self, stream): line = line.decode('utf-8') if not line: break - self.log.info('Job %s: Subtask %s %s', - self._task_instance.job_id, self._task_instance.task_id, - line.rstrip('\n')) + self.log.info( + 'Job %s: Subtask %s %s', + self._task_instance.job_id, + self._task_instance.task_id, + line.rstrip('\n'), + ) def run_command(self, run_with=None): """ @@ -128,7 +128,7 @@ def run_command(self, run_with=None): universal_newlines=True, close_fds=True, env=os.environ.copy(), - preexec_fn=os.setsid + preexec_fn=os.setsid, ) # Start daemon thread to read subprocess logging output diff --git a/airflow/task/task_runner/cgroup_task_runner.py b/airflow/task/task_runner/cgroup_task_runner.py index ab1e88c16611a..6cd0ce6aa048b 100644 --- a/airflow/task/task_runner/cgroup_task_runner.py +++ 
b/airflow/task/task_runner/cgroup_task_runner.py @@ -92,8 +92,7 @@ def _create_cgroup(self, path): node = node.create_cgroup(path_element) else: self.log.debug( - "Not creating cgroup %s in %s since it already exists", - path_element, node.path.decode() + "Not creating cgroup %s in %s since it already exists", path_element, node.path.decode() ) node = name_to_node[path_element] return node @@ -122,20 +121,21 @@ def _delete_cgroup(self, path): def start(self): # Use bash if it's already in a cgroup cgroups = self._get_cgroup_names() - if ((cgroups.get("cpu") and cgroups.get("cpu") != "/") or - (cgroups.get("memory") and cgroups.get("memory") != "/")): + if (cgroups.get("cpu") and cgroups.get("cpu") != "/") or ( + cgroups.get("memory") and cgroups.get("memory") != "/" + ): self.log.debug( - "Already running in a cgroup (cpu: %s memory: %s) so not " - "creating another one", - cgroups.get("cpu"), cgroups.get("memory") + "Already running in a cgroup (cpu: %s memory: %s) so not " "creating another one", + cgroups.get("cpu"), + cgroups.get("memory"), ) self.process = self.run_command() return # Create a unique cgroup name - cgroup_name = "airflow/{}/{}".format(datetime.datetime.utcnow(). - strftime("%Y-%m-%d"), - str(uuid.uuid4())) + cgroup_name = "airflow/{}/{}".format( + datetime.datetime.utcnow().strftime("%Y-%m-%d"), str(uuid.uuid4()) + ) self.mem_cgroup_name = f"memory/{cgroup_name}" self.cpu_cgroup_name = f"cpu/{cgroup_name}" @@ -151,30 +151,19 @@ def start(self): mem_cgroup_node = self._create_cgroup(self.mem_cgroup_name) self._created_mem_cgroup = True if self._mem_mb_limit > 0: - self.log.debug( - "Setting %s with %s MB of memory", - self.mem_cgroup_name, self._mem_mb_limit - ) + self.log.debug("Setting %s with %s MB of memory", self.mem_cgroup_name, self._mem_mb_limit) mem_cgroup_node.controller.limit_in_bytes = self._mem_mb_limit * 1024 * 1024 # Create the CPU cgroup cpu_cgroup_node = self._create_cgroup(self.cpu_cgroup_name) self._created_cpu_cgroup = True if self._cpu_shares > 0: - self.log.debug( - "Setting %s with %s CPU shares", - self.cpu_cgroup_name, self._cpu_shares - ) + self.log.debug("Setting %s with %s CPU shares", self.cpu_cgroup_name, self._cpu_shares) cpu_cgroup_node.controller.shares = self._cpu_shares # Start the process w/ cgroups - self.log.debug( - "Starting task process with cgroups cpu,memory: %s", - cgroup_name - ) - self.process = self.run_command( - ['cgexec', '-g', f'cpu,memory:{cgroup_name}'] - ) + self.log.debug("Starting task process with cgroups cpu,memory: %s", cgroup_name) + self.process = self.run_command(['cgexec', '-g', f'cpu,memory:{cgroup_name}']) def return_code(self): return_code = self.process.poll() @@ -186,10 +175,12 @@ def return_code(self): # I wasn't able to track down the root cause of the package install failures, but # we might want to revisit that approach at some other point. if return_code == 137: - self.log.error("Task failed with return code of 137. This may indicate " - "that it was killed due to excessive memory usage. " - "Please consider optimizing your task or using the " - "resources argument to reserve more memory for your task") + self.log.error( + "Task failed with return code of 137. This may indicate " + "that it was killed due to excessive memory usage. 
" + "Please consider optimizing your task or using the " + "resources argument to reserve more memory for your task" + ) return return_code def terminate(self): diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py index 5534863f60550..a44d14bd7b48f 100644 --- a/airflow/ti_deps/dep_context.py +++ b/airflow/ti_deps/dep_context.py @@ -64,16 +64,17 @@ class DepContext: """ def __init__( - self, - deps=None, - flag_upstream_failed: bool = False, - ignore_all_deps: bool = False, - ignore_depends_on_past: bool = False, - ignore_in_retry_period: bool = False, - ignore_in_reschedule_period: bool = False, - ignore_task_deps: bool = False, - ignore_ti_state: bool = False, - finished_tasks=None): + self, + deps=None, + flag_upstream_failed: bool = False, + ignore_all_deps: bool = False, + ignore_depends_on_past: bool = False, + ignore_in_retry_period: bool = False, + ignore_in_reschedule_period: bool = False, + ignore_task_deps: bool = False, + ignore_ti_state: bool = False, + finished_tasks=None, + ): self.deps = deps or set() self.flag_upstream_failed = flag_upstream_failed self.ignore_all_deps = ignore_all_deps diff --git a/airflow/ti_deps/dependencies_deps.py b/airflow/ti_deps/dependencies_deps.py index 3b37ffa31e1a9..7062995887be6 100644 --- a/airflow/ti_deps/dependencies_deps.py +++ b/airflow/ti_deps/dependencies_deps.py @@ -16,7 +16,10 @@ # under the License. from airflow.ti_deps.dependencies_states import ( - BACKFILL_QUEUEABLE_STATES, QUEUEABLE_STATES, RUNNABLE_STATES, SCHEDULEABLE_STATES, + BACKFILL_QUEUEABLE_STATES, + QUEUEABLE_STATES, + RUNNABLE_STATES, + SCHEDULEABLE_STATES, ) from airflow.ti_deps.deps.dag_ti_slots_available_dep import DagTISlotsAvailableDep from airflow.ti_deps.deps.dag_unpaused_dep import DagUnpausedDep diff --git a/airflow/ti_deps/deps/base_ti_dep.py b/airflow/ti_deps/deps/base_ti_dep.py index 335c55bc3e5f2..b8810c4592770 100644 --- a/airflow/ti_deps/deps/base_ti_dep.py +++ b/airflow/ti_deps/deps/base_ti_dep.py @@ -91,13 +91,11 @@ def get_dep_statuses(self, ti, session, dep_context=None): dep_context = DepContext() if self.IGNOREABLE and dep_context.ignore_all_deps: - yield self._passing_status( - reason="Context specified all dependencies should be ignored.") + yield self._passing_status(reason="Context specified all dependencies should be ignored.") return if self.IS_TASK_DEP and dep_context.ignore_task_deps: - yield self._passing_status( - reason="Context specified all task dependencies should be ignored.") + yield self._passing_status(reason="Context specified all task dependencies should be ignored.") return yield from self._get_dep_statuses(ti, session, dep_context) @@ -117,8 +115,7 @@ def is_met(self, ti, session, dep_context=None): state that can be used by this dependency. 
:type dep_context: BaseDepContext """ - return all(status.passed for status in - self.get_dep_statuses(ti, session, dep_context)) + return all(status.passed for status in self.get_dep_statuses(ti, session, dep_context)) @provide_session def get_failure_reasons(self, ti, session, dep_context=None): diff --git a/airflow/ti_deps/deps/dag_ti_slots_available_dep.py b/airflow/ti_deps/deps/dag_ti_slots_available_dep.py index e48de027ffee9..dfc96777472e1 100644 --- a/airflow/ti_deps/deps/dag_ti_slots_available_dep.py +++ b/airflow/ti_deps/deps/dag_ti_slots_available_dep.py @@ -31,5 +31,5 @@ def _get_dep_statuses(self, ti, session, dep_context): if ti.task.dag.get_concurrency_reached(session): yield self._failing_status( reason="The maximum number of running tasks ({}) for this task's DAG " - "'{}' has been reached.".format(ti.task.dag.concurrency, - ti.dag_id)) + "'{}' has been reached.".format(ti.task.dag.concurrency, ti.dag_id) + ) diff --git a/airflow/ti_deps/deps/dag_unpaused_dep.py b/airflow/ti_deps/deps/dag_unpaused_dep.py index b65acc3f6cc26..7d33398c9adf4 100644 --- a/airflow/ti_deps/deps/dag_unpaused_dep.py +++ b/airflow/ti_deps/deps/dag_unpaused_dep.py @@ -29,5 +29,4 @@ class DagUnpausedDep(BaseTIDep): @provide_session def _get_dep_statuses(self, ti, session, dep_context): if ti.task.dag.get_is_paused(session): - yield self._failing_status( - reason=f"Task's DAG '{ti.dag_id}' is paused.") + yield self._failing_status(reason=f"Task's DAG '{ti.dag_id}' is paused.") diff --git a/airflow/ti_deps/deps/dagrun_exists_dep.py b/airflow/ti_deps/deps/dagrun_exists_dep.py index bb1528170e99b..f244f15bdd05b 100644 --- a/airflow/ti_deps/deps/dagrun_exists_dep.py +++ b/airflow/ti_deps/deps/dagrun_exists_dep.py @@ -34,24 +34,22 @@ def _get_dep_statuses(self, ti, session, dep_context): if not dagrun: # The import is needed here to avoid a circular dependency from airflow.models.dagrun import DagRun + running_dagruns = DagRun.find( - dag_id=dag.dag_id, - state=State.RUNNING, - external_trigger=False, - session=session + dag_id=dag.dag_id, state=State.RUNNING, external_trigger=False, session=session ) if len(running_dagruns) >= dag.max_active_runs: - reason = ("The maximum number of active dag runs ({}) for this task " - "instance's DAG '{}' has been reached.".format( - dag.max_active_runs, - ti.dag_id)) + reason = ( + "The maximum number of active dag runs ({}) for this task " + "instance's DAG '{}' has been reached.".format(dag.max_active_runs, ti.dag_id) + ) else: reason = "Unknown reason" - yield self._failing_status( - reason=f"Task instance's dagrun did not exist: {reason}.") + yield self._failing_status(reason=f"Task instance's dagrun did not exist: {reason}.") else: if dagrun.state != State.RUNNING: yield self._failing_status( reason="Task instance's dagrun was not in the 'running' state but in " - "the state '{}'.".format(dagrun.state)) + "the state '{}'.".format(dagrun.state) + ) diff --git a/airflow/ti_deps/deps/dagrun_id_dep.py b/airflow/ti_deps/deps/dagrun_id_dep.py index e01b975ba685c..2ed84f1f362ab 100644 --- a/airflow/ti_deps/deps/dagrun_id_dep.py +++ b/airflow/ti_deps/deps/dagrun_id_dep.py @@ -47,8 +47,9 @@ def _get_dep_statuses(self, ti, session, dep_context=None): if not dagrun or not dagrun.run_id or dagrun.run_type != DagRunType.BACKFILL_JOB: yield self._passing_status( reason=f"Task's DagRun doesn't exist or run_id is either NULL " - f"or run_type is not {DagRunType.BACKFILL_JOB}") + f"or run_type is not {DagRunType.BACKFILL_JOB}" + ) else: yield self._failing_status( - 
reason=f"Task's DagRun run_id is not NULL " - f"and run type is {DagRunType.BACKFILL_JOB}") + reason=f"Task's DagRun run_id is not NULL " f"and run type is {DagRunType.BACKFILL_JOB}" + ) diff --git a/airflow/ti_deps/deps/exec_date_after_start_date_dep.py b/airflow/ti_deps/deps/exec_date_after_start_date_dep.py index c1ca967eaac02..4e33683bbebe9 100644 --- a/airflow/ti_deps/deps/exec_date_after_start_date_dep.py +++ b/airflow/ti_deps/deps/exec_date_after_start_date_dep.py @@ -31,14 +31,13 @@ def _get_dep_statuses(self, ti, session, dep_context): if ti.task.start_date and ti.execution_date < ti.task.start_date: yield self._failing_status( reason="The execution date is {} but this is before the task's start " - "date {}.".format( - ti.execution_date.isoformat(), - ti.task.start_date.isoformat())) + "date {}.".format(ti.execution_date.isoformat(), ti.task.start_date.isoformat()) + ) - if (ti.task.dag and ti.task.dag.start_date and - ti.execution_date < ti.task.dag.start_date): + if ti.task.dag and ti.task.dag.start_date and ti.execution_date < ti.task.dag.start_date: yield self._failing_status( reason="The execution date is {} but this is before the task's " "DAG's start date {}.".format( - ti.execution_date.isoformat(), - ti.task.dag.start_date.isoformat())) + ti.execution_date.isoformat(), ti.task.dag.start_date.isoformat() + ) + ) diff --git a/airflow/ti_deps/deps/not_in_retry_period_dep.py b/airflow/ti_deps/deps/not_in_retry_period_dep.py index d19e0c8117d7e..08aa8ecb529b3 100644 --- a/airflow/ti_deps/deps/not_in_retry_period_dep.py +++ b/airflow/ti_deps/deps/not_in_retry_period_dep.py @@ -33,12 +33,12 @@ class NotInRetryPeriodDep(BaseTIDep): def _get_dep_statuses(self, ti, session, dep_context): if dep_context.ignore_in_retry_period: yield self._passing_status( - reason="The context specified that being in a retry period was permitted.") + reason="The context specified that being in a retry period was permitted." + ) return if ti.state != State.UP_FOR_RETRY: - yield self._passing_status( - reason="The task instance was not marked for retrying.") + yield self._passing_status(reason="The task instance was not marked for retrying.") return # Calculate the date first so that it is always smaller than the timestamp used by @@ -48,6 +48,6 @@ def _get_dep_statuses(self, ti, session, dep_context): if ti.is_premature: yield self._failing_status( reason="Task is not ready for retry yet but will be retried " - "automatically. Current date is {} and task will be retried " - "at {}.".format(cur_date.isoformat(), - next_task_retry_date.isoformat())) + "automatically. 
Current date is {} and task will be retried " + "at {}.".format(cur_date.isoformat(), next_task_retry_date.isoformat()) + ) diff --git a/airflow/ti_deps/deps/not_previously_skipped_dep.py b/airflow/ti_deps/deps/not_previously_skipped_dep.py index 4ecef93ad847b..08413dfd571d6 100644 --- a/airflow/ti_deps/deps/not_previously_skipped_dep.py +++ b/airflow/ti_deps/deps/not_previously_skipped_dep.py @@ -28,19 +28,18 @@ class NotPreviouslySkippedDep(BaseTIDep): IGNORABLE = True IS_TASK_DEP = True - def _get_dep_statuses( - self, ti, session, dep_context - ): # pylint: disable=signature-differs + def _get_dep_statuses(self, ti, session, dep_context): # pylint: disable=signature-differs from airflow.models.skipmixin import ( - XCOM_SKIPMIXIN_FOLLOWED, XCOM_SKIPMIXIN_KEY, XCOM_SKIPMIXIN_SKIPPED, SkipMixin, + XCOM_SKIPMIXIN_FOLLOWED, + XCOM_SKIPMIXIN_KEY, + XCOM_SKIPMIXIN_SKIPPED, + SkipMixin, ) from airflow.utils.state import State upstream = ti.task.get_direct_relatives(upstream=True) - finished_tasks = dep_context.ensure_finished_tasks( - ti.task.dag, ti.execution_date, session - ) + finished_tasks = dep_context.ensure_finished_tasks(ti.task.dag, ti.execution_date, session) finished_task_ids = {t.task_id for t in finished_tasks} @@ -50,9 +49,7 @@ def _get_dep_statuses( # This can happen if the parent task has not yet run. continue - prev_result = ti.xcom_pull( - task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session - ) + prev_result = ti.xcom_pull(task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session) if prev_result is None: # This can happen if the parent task has not yet run. diff --git a/airflow/ti_deps/deps/pool_slots_available_dep.py b/airflow/ti_deps/deps/pool_slots_available_dep.py index 2c58013f33304..17a8e75dd3c51 100644 --- a/airflow/ti_deps/deps/pool_slots_available_dep.py +++ b/airflow/ti_deps/deps/pool_slots_available_dep.py @@ -49,8 +49,8 @@ def _get_dep_statuses(self, ti, session, dep_context=None): pools = session.query(Pool).filter(Pool.pool == pool_name).all() if not pools: yield self._failing_status( - reason=("Tasks using non-existent pool '%s' will not be scheduled", - pool_name)) + reason=("Tasks using non-existent pool '%s' will not be scheduled", pool_name) + ) return else: # Controlled by UNIQUE key in slot_pool table, @@ -62,12 +62,14 @@ def _get_dep_statuses(self, ti, session, dep_context=None): if open_slots <= (ti.pool_slots - 1): yield self._failing_status( - reason=("Not scheduling since there are %s open slots in pool %s " - "and require %s pool slots", - open_slots, pool_name, ti.pool_slots) + reason=( + "Not scheduling since there are %s open slots in pool %s and require %s pool slots", + open_slots, + pool_name, + ti.pool_slots, + ) ) else: yield self._passing_status( - reason=( - "There are enough open slots in %s to execute the task", pool_name) + reason=("There are enough open slots in %s to execute the task", pool_name) ) diff --git a/airflow/ti_deps/deps/prev_dagrun_dep.py b/airflow/ti_deps/deps/prev_dagrun_dep.py index bdd2bf622be6c..d5e717afb95d4 100644 --- a/airflow/ti_deps/deps/prev_dagrun_dep.py +++ b/airflow/ti_deps/deps/prev_dagrun_dep.py @@ -35,26 +35,24 @@ class PrevDagrunDep(BaseTIDep): def _get_dep_statuses(self, ti, session, dep_context): if dep_context.ignore_depends_on_past: yield self._passing_status( - reason="The context specified that the state of past DAGs could be " - "ignored.") + reason="The context specified that the state of past DAGs could be " "ignored." 
+ ) return if not ti.task.depends_on_past: - yield self._passing_status( - reason="The task did not have depends_on_past set.") + yield self._passing_status(reason="The task did not have depends_on_past set.") return # Don't depend on the previous task instance if we are the first task dag = ti.task.dag if dag.catchup: if dag.previous_schedule(ti.execution_date) is None: - yield self._passing_status( - reason="This task does not have a schedule or is @once" - ) + yield self._passing_status(reason="This task does not have a schedule or is @once") return if dag.previous_schedule(ti.execution_date) < ti.task.start_date: yield self._passing_status( - reason="This task instance was the first task instance for its task.") + reason="This task instance was the first task instance for its task." + ) return else: dr = ti.get_dagrun(session=session) @@ -62,25 +60,28 @@ def _get_dep_statuses(self, ti, session, dep_context): if not last_dagrun: yield self._passing_status( - reason="This task instance was the first task instance for its task.") + reason="This task instance was the first task instance for its task." + ) return previous_ti = ti.get_previous_ti(session=session) if not previous_ti: yield self._failing_status( reason="depends_on_past is true for this task's DAG, but the previous " - "task instance has not run yet.") + "task instance has not run yet." + ) return if previous_ti.state not in {State.SKIPPED, State.SUCCESS}: yield self._failing_status( reason="depends_on_past is true for this task, but the previous task " - "instance {} is in the state '{}' which is not a successful " - "state.".format(previous_ti, previous_ti.state)) + "instance {} is in the state '{}' which is not a successful " + "state.".format(previous_ti, previous_ti.state) + ) previous_ti.task = ti.task - if (ti.task.wait_for_downstream and - not previous_ti.are_dependents_done(session=session)): + if ti.task.wait_for_downstream and not previous_ti.are_dependents_done(session=session): yield self._failing_status( reason="The tasks downstream of the previous task instance {} haven't " - "completed (and wait_for_downstream is True).".format(previous_ti)) + "completed (and wait_for_downstream is True).".format(previous_ti) + ) diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py b/airflow/ti_deps/deps/ready_to_reschedule.py index bdce41fb46e37..a4396070eeada 100644 --- a/airflow/ti_deps/deps/ready_to_reschedule.py +++ b/airflow/ti_deps/deps/ready_to_reschedule.py @@ -42,13 +42,14 @@ def _get_dep_statuses(self, ti, session, dep_context): """ if dep_context.ignore_in_reschedule_period: yield self._passing_status( - reason="The context specified that being in a reschedule period was " - "permitted.") + reason="The context specified that being in a reschedule period was " "permitted." + ) return if ti.state not in self.RESCHEDULEABLE_STATES: yield self._passing_status( - reason="The task instance is not in State_UP_FOR_RESCHEDULE or NONE state.") + reason="The task instance is not in State_UP_FOR_RESCHEDULE or NONE state." 
+ ) return task_reschedule = ( @@ -57,18 +58,17 @@ def _get_dep_statuses(self, ti, session, dep_context): .first() ) if not task_reschedule: - yield self._passing_status( - reason="There is no reschedule request for this task instance.") + yield self._passing_status(reason="There is no reschedule request for this task instance.") return now = timezone.utcnow() next_reschedule_date = task_reschedule.reschedule_date if now >= next_reschedule_date: - yield self._passing_status( - reason="Task instance id ready for reschedule.") + yield self._passing_status(reason="Task instance id ready for reschedule.") return yield self._failing_status( reason="Task is not ready for reschedule yet but will be rescheduled " - "automatically. Current date is {} and task will be rescheduled " - "at {}.".format(now.isoformat(), next_reschedule_date.isoformat())) + "automatically. Current date is {} and task will be rescheduled " + "at {}.".format(now.isoformat(), next_reschedule_date.isoformat()) + ) diff --git a/airflow/ti_deps/deps/runnable_exec_date_dep.py b/airflow/ti_deps/deps/runnable_exec_date_dep.py index 20cb8b6995e1f..30d07c5d52392 100644 --- a/airflow/ti_deps/deps/runnable_exec_date_dep.py +++ b/airflow/ti_deps/deps/runnable_exec_date_dep.py @@ -36,21 +36,17 @@ def _get_dep_statuses(self, ti, session, dep_context): if ti.execution_date > cur_date and not ti.task.dag.allow_future_exec_dates: yield self._failing_status( reason="Execution date {} is in the future (the current " - "date is {}).".format(ti.execution_date.isoformat(), - cur_date.isoformat())) + "date is {}).".format(ti.execution_date.isoformat(), cur_date.isoformat()) + ) if ti.task.end_date and ti.execution_date > ti.task.end_date: yield self._failing_status( reason="The execution date is {} but this is after the task's end date " - "{}.".format( - ti.execution_date.isoformat(), - ti.task.end_date.isoformat())) + "{}.".format(ti.execution_date.isoformat(), ti.task.end_date.isoformat()) + ) - if (ti.task.dag and - ti.task.dag.end_date and - ti.execution_date > ti.task.dag.end_date): + if ti.task.dag and ti.task.dag.end_date and ti.execution_date > ti.task.dag.end_date: yield self._failing_status( reason="The execution date is {} but this is after the task's DAG's " - "end date {}.".format( - ti.execution_date.isoformat(), - ti.task.dag.end_date.isoformat())) + "end date {}.".format(ti.execution_date.isoformat(), ti.task.dag.end_date.isoformat()) + ) diff --git a/airflow/ti_deps/deps/task_concurrency_dep.py b/airflow/ti_deps/deps/task_concurrency_dep.py index d1456f9112ce8..2824301771d81 100644 --- a/airflow/ti_deps/deps/task_concurrency_dep.py +++ b/airflow/ti_deps/deps/task_concurrency_dep.py @@ -34,10 +34,8 @@ def _get_dep_statuses(self, ti, session, dep_context): return if ti.get_num_running_task_instances(session) >= ti.task.task_concurrency: - yield self._failing_status(reason="The max task concurrency " - "has been reached.") + yield self._failing_status(reason="The max task concurrency " "has been reached.") return else: - yield self._passing_status(reason="The max task concurrency " - "has not been reached.") + yield self._passing_status(reason="The max task concurrency " "has not been reached.") return diff --git a/airflow/ti_deps/deps/task_not_running_dep.py b/airflow/ti_deps/deps/task_not_running_dep.py index 184e65d1474fe..b0c9f3979bbbf 100644 --- a/airflow/ti_deps/deps/task_not_running_dep.py +++ b/airflow/ti_deps/deps/task_not_running_dep.py @@ -40,5 +40,4 @@ def _get_dep_statuses(self, ti, session, dep_context=None): yield 
self._passing_status(reason="Task is not in running state.") return - yield self._failing_status( - reason='Task is in the running state') + yield self._failing_status(reason='Task is in the running state') diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index 3544781e21601..e350863590278 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -46,15 +46,19 @@ def _get_states_count_upstream_ti(ti, finished_tasks): :type finished_tasks: list[airflow.models.TaskInstance] """ counter = Counter(task.state for task in finished_tasks if task.task_id in ti.task.upstream_task_ids) - return counter.get(State.SUCCESS, 0), counter.get(State.SKIPPED, 0), counter.get(State.FAILED, 0), \ - counter.get(State.UPSTREAM_FAILED, 0), sum(counter.values()) + return ( + counter.get(State.SUCCESS, 0), + counter.get(State.SKIPPED, 0), + counter.get(State.FAILED, 0), + counter.get(State.UPSTREAM_FAILED, 0), + sum(counter.values()), + ) @provide_session def _get_dep_statuses(self, ti, session, dep_context): # Checking that all upstream dependencies have succeeded if not ti.task.upstream_list: - yield self._passing_status( - reason="The task instance did not have any upstream tasks.") + yield self._passing_status(reason="The task instance did not have any upstream tasks.") return if ti.task.trigger_rule == TR.DUMMY: @@ -62,8 +66,8 @@ def _get_dep_statuses(self, ti, session, dep_context): return # see if the task name is in the task upstream for our task successes, skipped, failed, upstream_failed, done = self._get_states_count_upstream_ti( - ti=ti, - finished_tasks=dep_context.ensure_finished_tasks(ti.task.dag, ti.execution_date, session)) + ti=ti, finished_tasks=dep_context.ensure_finished_tasks(ti.task.dag, ti.execution_date, session) + ) yield from self._evaluate_trigger_rule( ti=ti, @@ -73,19 +77,13 @@ def _get_dep_statuses(self, ti, session, dep_context): upstream_failed=upstream_failed, done=done, flag_upstream_failed=dep_context.flag_upstream_failed, - session=session) + session=session, + ) @provide_session def _evaluate_trigger_rule( # pylint: disable=too-many-branches - self, - ti, - successes, - skipped, - failed, - upstream_failed, - done, - flag_upstream_failed, - session): + self, ti, successes, skipped, failed, upstream_failed, done, flag_upstream_failed, session + ): """ Yields a dependency status that indicate whether the given task instance's trigger rule was met. @@ -115,8 +113,12 @@ def _evaluate_trigger_rule( # pylint: disable=too-many-branches trigger_rule = task.trigger_rule upstream_done = done >= upstream upstream_tasks_state = { - "total": upstream, "successes": successes, "skipped": skipped, - "failed": failed, "upstream_failed": upstream_failed, "done": done + "total": upstream, + "successes": successes, + "skipped": skipped, + "failed": failed, + "upstream_failed": upstream_failed, + "done": done, } # TODO(aoen): Ideally each individual trigger rules would be its own class, but # this isn't very feasible at the moment since the database queries need to be @@ -154,68 +156,77 @@ def _evaluate_trigger_rule( # pylint: disable=too-many-branches yield self._failing_status( reason="Task's trigger rule '{}' requires one upstream " "task success, but none were found. 
" - "upstream_tasks_state={}, upstream_task_ids={}" - .format(trigger_rule, upstream_tasks_state, task.upstream_task_ids)) + "upstream_tasks_state={}, upstream_task_ids={}".format( + trigger_rule, upstream_tasks_state, task.upstream_task_ids + ) + ) elif trigger_rule == TR.ONE_FAILED: if not failed and not upstream_failed: yield self._failing_status( reason="Task's trigger rule '{}' requires one upstream " "task failure, but none were found. " - "upstream_tasks_state={}, upstream_task_ids={}" - .format(trigger_rule, upstream_tasks_state, task.upstream_task_ids)) + "upstream_tasks_state={}, upstream_task_ids={}".format( + trigger_rule, upstream_tasks_state, task.upstream_task_ids + ) + ) elif trigger_rule == TR.ALL_SUCCESS: num_failures = upstream - successes if num_failures > 0: yield self._failing_status( reason="Task's trigger rule '{}' requires all upstream " "tasks to have succeeded, but found {} non-success(es). " - "upstream_tasks_state={}, upstream_task_ids={}" - .format(trigger_rule, num_failures, upstream_tasks_state, - task.upstream_task_ids)) + "upstream_tasks_state={}, upstream_task_ids={}".format( + trigger_rule, num_failures, upstream_tasks_state, task.upstream_task_ids + ) + ) elif trigger_rule == TR.ALL_FAILED: num_successes = upstream - failed - upstream_failed if num_successes > 0: yield self._failing_status( reason="Task's trigger rule '{}' requires all upstream " "tasks to have failed, but found {} non-failure(s). " - "upstream_tasks_state={}, upstream_task_ids={}" - .format(trigger_rule, num_successes, upstream_tasks_state, - task.upstream_task_ids)) + "upstream_tasks_state={}, upstream_task_ids={}".format( + trigger_rule, num_successes, upstream_tasks_state, task.upstream_task_ids + ) + ) elif trigger_rule == TR.ALL_DONE: if not upstream_done: yield self._failing_status( reason="Task's trigger rule '{}' requires all upstream " "tasks to have completed, but found {} task(s) that " "were not done. upstream_tasks_state={}, " - "upstream_task_ids={}" - .format(trigger_rule, upstream_done, upstream_tasks_state, - task.upstream_task_ids)) + "upstream_task_ids={}".format( + trigger_rule, upstream_done, upstream_tasks_state, task.upstream_task_ids + ) + ) elif trigger_rule == TR.NONE_FAILED: num_failures = upstream - successes - skipped if num_failures > 0: yield self._failing_status( reason="Task's trigger rule '{}' requires all upstream " "tasks to have succeeded or been skipped, but found {} non-success(es). " - "upstream_tasks_state={}, upstream_task_ids={}" - .format(trigger_rule, num_failures, upstream_tasks_state, - task.upstream_task_ids)) + "upstream_tasks_state={}, upstream_task_ids={}".format( + trigger_rule, num_failures, upstream_tasks_state, task.upstream_task_ids + ) + ) elif trigger_rule == TR.NONE_FAILED_OR_SKIPPED: num_failures = upstream - successes - skipped if num_failures > 0: yield self._failing_status( reason="Task's trigger rule '{}' requires all upstream " "tasks to have succeeded or been skipped, but found {} non-success(es). " - "upstream_tasks_state={}, upstream_task_ids={}" - .format(trigger_rule, num_failures, upstream_tasks_state, - task.upstream_task_ids)) + "upstream_tasks_state={}, upstream_task_ids={}".format( + trigger_rule, num_failures, upstream_tasks_state, task.upstream_task_ids + ) + ) elif trigger_rule == TR.NONE_SKIPPED: if not upstream_done or (skipped > 0): yield self._failing_status( reason="Task's trigger rule '{}' requires all upstream " "tasks to not have been skipped, but found {} task(s) skipped. 
" - "upstream_tasks_state={}, upstream_task_ids={}" - .format(trigger_rule, skipped, upstream_tasks_state, - task.upstream_task_ids)) + "upstream_tasks_state={}, upstream_task_ids={}".format( + trigger_rule, skipped, upstream_tasks_state, task.upstream_task_ids + ) + ) else: - yield self._failing_status( - reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.") + yield self._failing_status(reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.") diff --git a/airflow/ti_deps/deps/valid_state_dep.py b/airflow/ti_deps/deps/valid_state_dep.py index 1609afb819e0c..ea39ae772f383 100644 --- a/airflow/ti_deps/deps/valid_state_dep.py +++ b/airflow/ti_deps/deps/valid_state_dep.py @@ -38,8 +38,7 @@ def __init__(self, valid_states): super().__init__() if not valid_states: - raise AirflowException( - 'ValidStatesDep received an empty set of valid states.') + raise AirflowException('ValidStatesDep received an empty set of valid states.') self._valid_states = valid_states def __eq__(self, other): @@ -51,8 +50,7 @@ def __hash__(self): @provide_session def _get_dep_statuses(self, ti, session, dep_context): if dep_context.ignore_ti_state: - yield self._passing_status( - reason="Context specified that state should be ignored.") + yield self._passing_status(reason="Context specified that state should be ignored.") return if ti.state in self._valid_states: @@ -61,5 +59,5 @@ def _get_dep_statuses(self, ti, session, dep_context): yield self._failing_status( reason="Task is in the '{}' state which is not a valid state for " - "execution. The task must be cleared in order to be run.".format( - ti.state)) + "execution. The task must be cleared in order to be run.".format(ti.state) + ) diff --git a/airflow/typing_compat.py b/airflow/typing_compat.py index 935dd0ea4de2b..6fd6d8c8252f2 100644 --- a/airflow/typing_compat.py +++ b/airflow/typing_compat.py @@ -26,7 +26,9 @@ # python 3.8 we can safely remove this shim import after Airflow drops # support for <3.8 from typing import ( # type: ignore # noqa # pylint: disable=unused-import - Protocol, TypedDict, runtime_checkable, + Protocol, + TypedDict, + runtime_checkable, ) except ImportError: from typing_extensions import Protocol, TypedDict, runtime_checkable # type: ignore # noqa diff --git a/airflow/utils/callback_requests.py b/airflow/utils/callback_requests.py index fe8017c721fb8..5561955dcf20a 100644 --- a/airflow/utils/callback_requests.py +++ b/airflow/utils/callback_requests.py @@ -56,7 +56,7 @@ def __init__( full_filepath: str, simple_task_instance: SimpleTaskInstance, is_failure_callback: Optional[bool] = True, - msg: Optional[str] = None + msg: Optional[str] = None, ): super().__init__(full_filepath=full_filepath, msg=msg) self.simple_task_instance = simple_task_instance @@ -80,7 +80,7 @@ def __init__( dag_id: str, execution_date: datetime, is_failure_callback: Optional[bool] = True, - msg: Optional[str] = None + msg: Optional[str] = None, ): super().__init__(full_filepath=full_filepath, msg=msg) self.dag_id = dag_id diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py index d1fdf63d880d4..0578acb7597bf 100644 --- a/airflow/utils/cli.py +++ b/airflow/utils/cli.py @@ -63,6 +63,7 @@ def action_logging(f: T) -> T: :param f: function instance :return: wrapped function """ + @functools.wraps(f) def wrapper(*args, **kwargs): """ @@ -76,8 +77,9 @@ def wrapper(*args, **kwargs): if not args: raise ValueError("Args should be set") if not isinstance(args[0], Namespace): - raise ValueError("1st positional argument should be argparse.Namespace 
instance," - f"but is {type(args[0])}") + raise ValueError( + "1st positional argument should be argparse.Namespace instance," f"but is {type(args[0])}" + ) metrics = _build_metrics(f.__name__, args[0]) cli_action_loggers.on_pre_execution(**metrics) try: @@ -115,12 +117,17 @@ def _build_metrics(func_name, namespace): if command.startswith(f'{sensitive_field}='): full_command[idx] = f'{sensitive_field}={"*" * 8}' - metrics = {'sub_command': func_name, 'start_datetime': datetime.utcnow(), - 'full_command': f'{full_command}', 'user': getpass.getuser()} + metrics = { + 'sub_command': func_name, + 'start_datetime': datetime.utcnow(), + 'full_command': f'{full_command}', + 'user': getpass.getuser(), + } if not isinstance(namespace, Namespace): - raise ValueError("namespace argument should be argparse.Namespace instance," - f"but is {type(namespace)}") + raise ValueError( + "namespace argument should be argparse.Namespace instance," f"but is {type(namespace)}" + ) tmp_dic = vars(namespace) metrics['dag_id'] = tmp_dic.get('dag_id') metrics['task_id'] = tmp_dic.get('task_id') @@ -135,7 +142,8 @@ def _build_metrics(func_name, namespace): extra=extra, task_id=metrics.get('task_id'), dag_id=metrics.get('dag_id'), - execution_date=metrics.get('execution_date')) + execution_date=metrics.get('execution_date'), + ) metrics['log'] = log return metrics @@ -157,7 +165,8 @@ def get_dag_by_file_location(dag_id: str): if dag_model is None: raise AirflowException( 'dag_id could not be found: {}. Either the dag did not exist or it failed to ' - 'parse.'.format(dag_id)) + 'parse.'.format(dag_id) + ) dagbag = DagBag(dag_folder=dag_model.fileloc) return dagbag.dags[dag_id] @@ -168,7 +177,8 @@ def get_dag(subdir: Optional[str], dag_id: str) -> DAG: if dag_id not in dagbag.dags: raise AirflowException( 'dag_id could not be found: {}. Either the dag did not exist or it failed to ' - 'parse.'.format(dag_id)) + 'parse.'.format(dag_id) + ) return dagbag.dags[dag_id] @@ -181,7 +191,8 @@ def get_dags(subdir: Optional[str], dag_id: str, use_regex: bool = False): if not matched_dags: raise AirflowException( 'dag_id could not be found with regex: {}. Either the dag did not exist ' - 'or it failed to parse.'.format(dag_id)) + 'or it failed to parse.'.format(dag_id) + ) return matched_dags @@ -238,11 +249,9 @@ def sigquit_handler(sig, frame): # pylint: disable=unused-argument id_to_name = {th.ident: th.name for th in threading.enumerate()} code = [] for thread_id, stack in sys._current_frames().items(): # pylint: disable=protected-access - code.append("\n# Thread: {}({})" - .format(id_to_name.get(thread_id, ""), thread_id)) + code.append("\n# Thread: {}({})".format(id_to_name.get(thread_id, ""), thread_id)) for filename, line_number, name, line in traceback.extract_stack(stack): - code.append('File: "{}", line {}, in {}' - .format(filename, line_number, name)) + code.append(f'File: "{filename}", line {line_number}, in {name}') if line: code.append(f" {line.strip()}") print("\n".join(code)) diff --git a/airflow/utils/compression.py b/airflow/utils/compression.py index dff8c8d3ed7ab..aa39b031dd715 100644 --- a/airflow/utils/compression.py +++ b/airflow/utils/compression.py @@ -25,16 +25,16 @@ def uncompress_file(input_file_name, file_extension, dest_dir): """Uncompress gz and bz2 files""" if file_extension.lower() not in ('.gz', '.bz2'): - raise NotImplementedError("Received {} format. Only gz and bz2 " - "files can currently be uncompressed." - .format(file_extension)) + raise NotImplementedError( + "Received {} format. 
Only gz and bz2 " + "files can currently be uncompressed.".format(file_extension) + ) if file_extension.lower() == '.gz': fmodule = gzip.GzipFile elif file_extension.lower() == '.bz2': fmodule = bz2.BZ2File - with fmodule(input_file_name, mode='rb') as f_compressed,\ - NamedTemporaryFile(dir=dest_dir, - mode='wb', - delete=False) as f_uncompressed: + with fmodule(input_file_name, mode='rb') as f_compressed, NamedTemporaryFile( + dir=dest_dir, mode='wb', delete=False + ) as f_uncompressed: shutil.copyfileobj(f_compressed, f_uncompressed) return f_uncompressed.name diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index e1b0358f379ee..793d211c57a4c 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -194,13 +194,12 @@ def __init__( dag_directory: str, max_runs: int, processor_factory: Callable[ - [str, List[CallbackRequest], Optional[List[str]], bool], - AbstractDagFileProcessorProcess + [str, List[CallbackRequest], Optional[List[str]], bool], AbstractDagFileProcessorProcess ], processor_timeout: timedelta, dag_ids: Optional[List[str]], pickle_dags: bool, - async_mode: bool + async_mode: bool, ): super().__init__() self._file_path_queue: List[str] = [] @@ -241,8 +240,8 @@ def start(self) -> None: child_signal_conn, self._dag_ids, self._pickle_dags, - self._async_mode - ) + self._async_mode, + ), ) self._process = process @@ -326,15 +325,12 @@ def wait_until_finished(self) -> None: def _run_processor_manager( dag_directory: str, max_runs: int, - processor_factory: Callable[ - [str, List[CallbackRequest]], - AbstractDagFileProcessorProcess - ], + processor_factory: Callable[[str, List[CallbackRequest]], AbstractDagFileProcessorProcess], processor_timeout: timedelta, signal_conn: MultiprocessingConnection, dag_ids: Optional[List[str]], pickle_dags: bool, - async_mode: bool + async_mode: bool, ) -> None: # Make this process start as a new process group - that makes it easy @@ -355,14 +351,16 @@ def _run_processor_manager( importlib.reload(airflow.settings) airflow.settings.initialize() del os.environ['CONFIG_PROCESSOR_MANAGER_LOGGER'] - processor_manager = DagFileProcessorManager(dag_directory, - max_runs, - processor_factory, - processor_timeout, - signal_conn, - dag_ids, - pickle_dags, - async_mode) + processor_manager = DagFileProcessorManager( + dag_directory, + max_runs, + processor_factory, + processor_timeout, + signal_conn, + dag_ids, + pickle_dags, + async_mode, + ) processor_manager.start() @@ -397,7 +395,8 @@ def _heartbeat_manager(self): if not self.done: self.log.warning( "DagFileProcessorManager (PID=%d) exited with exit code %d - re-launching", - self._process.pid, self._process.exitcode + self._process.pid, + self._process.exitcode, ) self.start() @@ -409,7 +408,9 @@ def _heartbeat_manager(self): Stats.incr('dag_processing.manager_stalls') self.log.error( "DagFileProcessorManager (PID=%d) last sent a heartbeat %.2f seconds ago! 
Restarting it", - self._process.pid, parsing_stat_age) + self._process.pid, + parsing_stat_age, + ) reap_process_group(self._process.pid, logger=self.log) self.start() @@ -486,18 +487,17 @@ class DagFileProcessorManager(LoggingMixin): # pylint: disable=too-many-instanc :type async_mode: bool """ - def __init__(self, - dag_directory: str, - max_runs: int, - processor_factory: Callable[ - [str, List[CallbackRequest]], - AbstractDagFileProcessorProcess - ], - processor_timeout: timedelta, - signal_conn: MultiprocessingConnection, - dag_ids: Optional[List[str]], - pickle_dags: bool, - async_mode: bool = True): + def __init__( + self, + dag_directory: str, + max_runs: int, + processor_factory: Callable[[str, List[CallbackRequest]], AbstractDagFileProcessorProcess], + processor_timeout: timedelta, + signal_conn: MultiprocessingConnection, + dag_ids: Optional[List[str]], + pickle_dags: bool, + async_mode: bool = True, + ): super().__init__() self._file_paths: List[str] = [] self._file_path_queue: List[str] = [] @@ -514,20 +514,18 @@ def __init__(self, if 'sqlite' in conf.get('core', 'sql_alchemy_conn') and self._parallelism > 1: self.log.warning( "Because we cannot use more than 1 thread (max_threads = " - "%d ) when using sqlite. So we set parallelism to 1.", self._parallelism + "%d ) when using sqlite. So we set parallelism to 1.", + self._parallelism, ) self._parallelism = 1 # Parse and schedule each file no faster than this interval. - self._file_process_interval = conf.getint('scheduler', - 'min_file_process_interval') + self._file_process_interval = conf.getint('scheduler', 'min_file_process_interval') # How often to print out DAG file processing stats to the log. Default to # 30 seconds. - self.print_stats_interval = conf.getint('scheduler', - 'print_stats_interval') + self.print_stats_interval = conf.getint('scheduler', 'print_stats_interval') # How many seconds do we wait for tasks to heartbeat before mark them as zombies. - self._zombie_threshold_secs = ( - conf.getint('scheduler', 'scheduler_zombie_task_threshold')) + self._zombie_threshold_secs = conf.getint('scheduler', 'scheduler_zombie_task_threshold') # Should store dag file source in a database? self.store_dag_code = STORE_DAG_CODE @@ -641,8 +639,7 @@ def _run_parsing_loop(self): # This shouldn't happen, as in sync mode poll should block for # ever. Lets be defensive about that. self.log.warning( - "wait() unexpectedly returned nothing ready after infinite timeout (%r)!", - poll_time + "wait() unexpectedly returned nothing ready after infinite timeout (%r)!", poll_time ) continue @@ -676,8 +673,7 @@ def _run_parsing_loop(self): self._num_run += 1 if not self._async_mode: - self.log.debug( - "Waiting for processors to finish since we're using sqlite") + self.log.debug("Waiting for processors to finish since we're using sqlite") # Wait until the running DAG processors are finished before # sending a DagParsingStat message back. 
This means the Agent # can tell we've got to the end of this iteration when it sees @@ -692,15 +688,17 @@ def _run_parsing_loop(self): all_files_processed = all(self.get_last_finish_time(x) is not None for x in self.file_paths) max_runs_reached = self.max_runs_reached() - dag_parsing_stat = DagParsingStat(self._file_paths, - max_runs_reached, - all_files_processed, - ) + dag_parsing_stat = DagParsingStat( + self._file_paths, + max_runs_reached, + all_files_processed, + ) self._signal_conn.send(dag_parsing_stat) if max_runs_reached: - self.log.info("Exiting dag parsing loop as all files " - "have been processed %s times", self._max_runs) + self.log.info( + "Exiting dag parsing loop as all files " "have been processed %s times", self._max_runs + ) break if self._async_mode: @@ -740,12 +738,12 @@ def _refresh_dag_dir(self): if self.store_dag_code: from airflow.models.dagcode import DagCode + DagCode.remove_deleted_code(self._file_paths) def _print_stat(self): """Occasionally print out stats about how fast the files are getting processed""" - if 0 < self.print_stats_interval < ( - timezone.utcnow() - self.last_stat_print_time).total_seconds(): + if 0 < self.print_stats_interval < (timezone.utcnow() - self.last_stat_print_time).total_seconds(): if self._file_paths: self._log_file_processing_stats(self._file_paths) self.last_stat_print_time = timezone.utcnow() @@ -760,9 +758,7 @@ def clear_nonexistent_import_errors(self, session): """ query = session.query(errors.ImportError) if self._file_paths: - query = query.filter( - ~errors.ImportError.filename.in_(self._file_paths) - ) + query = query.filter(~errors.ImportError.filename.in_(self._file_paths)) query.delete(synchronize_session='fetch') session.commit() @@ -783,13 +779,7 @@ def _log_file_processing_stats(self, known_file_paths): # Last Runtime: If the process ran before, how long did it take to # finish in seconds # Last Run: When the file finished processing in the previous run. - headers = ["File Path", - "PID", - "Runtime", - "# DAGs", - "# Errors", - "Last Runtime", - "Last Run"] + headers = ["File Path", "PID", "Runtime", "# DAGs", "# Errors", "Last Runtime", "Last Run"] rows = [] now = timezone.utcnow() @@ -802,7 +792,7 @@ def _log_file_processing_stats(self, known_file_paths): processor_pid = self.get_pid(file_path) processor_start_time = self.get_start_time(file_path) - runtime = ((now - processor_start_time) if processor_start_time else None) + runtime = (now - processor_start_time) if processor_start_time else None last_run = self.get_last_finish_time(file_path) if last_run: seconds_ago = (now - last_run).total_seconds() @@ -812,34 +802,33 @@ def _log_file_processing_stats(self, known_file_paths): # TODO: Remove before Airflow 2.0 Stats.timing(f'dag_processing.last_runtime.{file_name}', runtime) - rows.append((file_path, - processor_pid, - runtime, - num_dags, - num_errors, - last_runtime, - last_run)) + rows.append((file_path, processor_pid, runtime, num_dags, num_errors, last_runtime, last_run)) # Sort by longest last runtime. 
(Can't sort None values in python3) rows = sorted(rows, key=lambda x: x[3] or 0.0) formatted_rows = [] for file_path, pid, runtime, num_dags, num_errors, last_runtime, last_run in rows: - formatted_rows.append((file_path, - pid, - f"{runtime.total_seconds():.2f}s" if runtime else None, - num_dags, - num_errors, - f"{last_runtime:.2f}s" if last_runtime else None, - last_run.strftime("%Y-%m-%dT%H:%M:%S") if last_run else None - )) - log_str = ("\n" + - "=" * 80 + - "\n" + - "DAG File Processing Stats\n\n" + - tabulate(formatted_rows, headers=headers) + - "\n" + - "=" * 80) + formatted_rows.append( + ( + file_path, + pid, + f"{runtime.total_seconds():.2f}s" if runtime else None, + num_dags, + num_errors, + f"{last_runtime:.2f}s" if last_runtime else None, + last_run.strftime("%Y-%m-%dT%H:%M:%S") if last_run else None, + ) + ) + log_str = ( + "\n" + + "=" * 80 + + "\n" + + "DAG File Processing Stats\n\n" + + tabulate(formatted_rows, headers=headers) + + "\n" + + "=" * 80 + ) self.log.info(log_str) @@ -937,8 +926,7 @@ def set_file_paths(self, new_file_paths): :return: None """ self._file_paths = new_file_paths - self._file_path_queue = [x for x in self._file_path_queue - if x in new_file_paths] + self._file_path_queue = [x for x in self._file_path_queue if x in new_file_paths] # Stop processors that are working on deleted files filtered_processors = {} for file_path, processor in self._processors.items(): @@ -966,8 +954,7 @@ def _collect_results_from_processor(self, processor) -> None: num_dags, count_import_errors = processor.result else: self.log.error( - "Processor for %s exited with return code %s.", - processor.file_path, processor.exit_code + "Processor for %s exited with return code %s.", processor.file_path, processor.exit_code ) count_import_errors = -1 num_dags = 0 @@ -993,11 +980,9 @@ def collect_results(self) -> None: self._processors.pop(processor.file_path) self._collect_results_from_processor(processor) - self.log.debug("%s/%s DAG parsing processes running", - len(self._processors), self._parallelism) + self.log.debug("%s/%s DAG parsing processes running", len(self._processors), self._parallelism) - self.log.debug("%s file paths queued for processing", - len(self._file_path_queue)) + self.log.debug("%s file paths queued for processing", len(self._file_path_queue)) def start_new_processes(self): """Start more processors if we have enough slots and files to process""" @@ -1005,19 +990,14 @@ def start_new_processes(self): file_path = self._file_path_queue.pop(0) callback_to_execute_for_file = self._callback_to_execute[file_path] processor = self._processor_factory( - file_path, - callback_to_execute_for_file, - self._dag_ids, - self._pickle_dags) + file_path, callback_to_execute_for_file, self._dag_ids, self._pickle_dags + ) del self._callback_to_execute[file_path] Stats.incr('dag_processing.processes') processor.start() - self.log.debug( - "Started a process (PID: %s) to generate tasks for %s", - processor.pid, file_path - ) + self.log.debug("Started a process (PID: %s) to generate tasks for %s", processor.pid, file_path) self._processors[file_path] = processor self.waitables[processor.waitable_handle] = processor @@ -1031,39 +1011,36 @@ def prepare_file_path_queue(self): file_paths_recently_processed = [] for file_path in self._file_paths: last_finish_time = self.get_last_finish_time(file_path) - if (last_finish_time is not None and - (now - last_finish_time).total_seconds() < - self._file_process_interval): + if ( + last_finish_time is not None + and (now - 
last_finish_time).total_seconds() < self._file_process_interval + ): file_paths_recently_processed.append(file_path) - files_paths_at_run_limit = [file_path - for file_path, stat in self._file_stats.items() - if stat.run_count == self._max_runs] + files_paths_at_run_limit = [ + file_path for file_path, stat in self._file_stats.items() if stat.run_count == self._max_runs + ] - files_paths_to_queue = list(set(self._file_paths) - - set(file_paths_in_progress) - - set(file_paths_recently_processed) - - set(files_paths_at_run_limit)) + files_paths_to_queue = list( + set(self._file_paths) + - set(file_paths_in_progress) + - set(file_paths_recently_processed) + - set(files_paths_at_run_limit) + ) for file_path, processor in self._processors.items(): self.log.debug( "File path %s is still being processed (started: %s)", - processor.file_path, processor.start_time.isoformat() + processor.file_path, + processor.start_time.isoformat(), ) - self.log.debug( - "Queuing the following files for processing:\n\t%s", - "\n\t".join(files_paths_to_queue) - ) + self.log.debug("Queuing the following files for processing:\n\t%s", "\n\t".join(files_paths_to_queue)) for file_path in files_paths_to_queue: if file_path not in self._file_stats: self._file_stats[file_path] = DagFileStat( - num_dags=0, - import_errors=0, - last_finish_time=None, - last_duration=None, - run_count=0 + num_dags=0, import_errors=0, last_finish_time=None, last_duration=None, run_count=0 ) self._file_path_queue.extend(files_paths_to_queue) @@ -1075,10 +1052,13 @@ def _find_zombies(self, session): and update the current zombie list. """ now = timezone.utcnow() - if not self._last_zombie_query_time or \ - (now - self._last_zombie_query_time).total_seconds() > self._zombie_query_interval: + if ( + not self._last_zombie_query_time + or (now - self._last_zombie_query_time).total_seconds() > self._zombie_query_interval + ): # to avoid circular imports from airflow.jobs.local_task_job import LocalTaskJob as LJ + self.log.info("Finding 'running' jobs without a recent heartbeat") TI = airflow.models.TaskInstance DM = airflow.models.DagModel @@ -1095,7 +1075,8 @@ def _find_zombies(self, session): LJ.state != State.RUNNING, LJ.latest_heartbeat < limit_dttm, ) - ).all() + ) + .all() ) self._last_zombie_query_time = timezone.utcnow() @@ -1116,9 +1097,11 @@ def _kill_timed_out_processors(self): duration = now - processor.start_time if duration > self._processor_timeout: self.log.error( - "Processor for %s with PID %s started at %s has timed out, " - "killing it.", - file_path, processor.pid, processor.start_time.isoformat()) + "Processor for %s with PID %s started at %s has timed out, " "killing it.", + file_path, + processor.pid, + processor.start_time.isoformat(), + ) Stats.decr('dag_processing.processes') Stats.incr('dag_processing.processor_timeouts') # TODO: Remove after Airflow 2.0 @@ -1164,8 +1147,9 @@ def emit_metrics(self): parse_time = (timezone.utcnow() - self._parsing_start_time).total_seconds() Stats.gauge('dag_processing.total_parse_time', parse_time) Stats.gauge('dagbag_size', sum(stat.num_dags for stat in self._file_stats.values())) - Stats.gauge('dag_processing.import_errors', - sum(stat.import_errors for stat in self._file_stats.values())) + Stats.gauge( + 'dag_processing.import_errors', sum(stat.import_errors for stat in self._file_stats.values()) + ) # TODO: Remove before Airflow 2.0 Stats.gauge('collect_dags', parse_time) diff --git a/airflow/utils/dates.py b/airflow/utils/dates.py index 2377e12cc2b1a..de5e52b668e57 100644 --- 
a/airflow/utils/dates.py +++ b/airflow/utils/dates.py @@ -244,11 +244,7 @@ def days_ago(n, hour=0, minute=0, second=0, microsecond=0): Get a datetime object representing `n` days ago. By default the time is set to midnight. """ - today = timezone.utcnow().replace( - hour=hour, - minute=minute, - second=second, - microsecond=microsecond) + today = timezone.utcnow().replace(hour=hour, minute=minute, second=second, microsecond=microsecond) return today - timedelta(days=n) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 479ced65cd5ca..92cb224a0ec2b 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -25,14 +25,34 @@ from airflow.configuration import conf from airflow.jobs.base_job import BaseJob # noqa: F401 # pylint: disable=unused-import from airflow.models import ( # noqa: F401 # pylint: disable=unused-import - DAG, XCOM_RETURN_KEY, BaseOperator, BaseOperatorLink, Connection, DagBag, DagModel, DagPickle, DagRun, - DagTag, Log, Pool, SkipMixin, SlaMiss, TaskFail, TaskInstance, TaskReschedule, Variable, XCom, + DAG, + XCOM_RETURN_KEY, + BaseOperator, + BaseOperatorLink, + Connection, + DagBag, + DagModel, + DagPickle, + DagRun, + DagTag, + Log, + Pool, + SkipMixin, + SlaMiss, + TaskFail, + TaskInstance, + TaskReschedule, + Variable, + XCom, ) + # We need to add this model manually to get reset working well from airflow.models.serialized_dag import SerializedDagModel # noqa: F401 # pylint: disable=unused-import + # TODO: remove create_session once we decide to break backward compatibility from airflow.utils.session import ( # noqa: F401 # pylint: disable=unused-import - create_session, provide_session, + create_session, + provide_session, ) log = logging.getLogger(__name__) @@ -52,8 +72,7 @@ def add_default_pool_if_not_exists(session=None): if not Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session): default_pool = Pool( pool=Pool.DEFAULT_POOL_NAME, - slots=conf.getint(section='core', key='non_pooled_task_slot_count', - fallback=128), + slots=conf.getint(section='core', key='non_pooled_task_slot_count', fallback=128), description="Default pool", ) session.add(default_pool) @@ -72,14 +91,14 @@ def create_default_connections(session=None): password="", schema="airflow", ), - session + session, ) merge_conn( Connection( conn_id="aws_default", conn_type="aws", ), - session + session, ) merge_conn( Connection( @@ -87,7 +106,7 @@ def create_default_connections(session=None): conn_type="azure_batch", login="", password="", - extra='''{"account_url": ""}''' + extra='''{"account_url": ""}''', ) ) merge_conn( @@ -96,7 +115,7 @@ def create_default_connections(session=None): conn_type="azure_container_instances", extra='{"tenantId": "", "subscriptionId": "" }', ), - session + session, ) merge_conn( Connection( @@ -104,15 +123,16 @@ def create_default_connections(session=None): conn_type="azure_cosmos", extra='{"database_name": "", "collection_name": "" }', ), - session + session, ) merge_conn( Connection( - conn_id='azure_data_explorer_default', conn_type='azure_data_explorer', + conn_id='azure_data_explorer_default', + conn_type='azure_data_explorer', host='https://.kusto.windows.net', extra='''{"auth_method": "", "tenant": "", "certificate": "", - "thumbprint": ""}''' + "thumbprint": ""}''', ), session, ) @@ -122,7 +142,7 @@ def create_default_connections(session=None): conn_type="azure_data_lake", extra='{"tenant": "", "account_name": "" }', ), - session + session, ) merge_conn( Connection( @@ -131,7 +151,7 @@ def create_default_connections(session=None): host="cassandra", 
port=9042, ), - session + session, ) merge_conn( Connection( @@ -139,7 +159,7 @@ def create_default_connections(session=None): conn_type="databricks", host="localhost", ), - session + session, ) merge_conn( Connection( @@ -148,7 +168,7 @@ def create_default_connections(session=None): host="", password="", ), - session + session, ) merge_conn( Connection( @@ -158,7 +178,7 @@ def create_default_connections(session=None): port=8082, extra='{"endpoint": "druid/v2/sql"}', ), - session + session, ) merge_conn( Connection( @@ -168,7 +188,7 @@ def create_default_connections(session=None): port=8081, extra='{"endpoint": "druid/indexer/v1/task"}', ), - session + session, ) merge_conn( Connection( @@ -176,9 +196,9 @@ def create_default_connections(session=None): conn_type="elasticsearch", host="localhost", schema="http", - port=9200 + port=9200, ), - session + session, ) merge_conn( Connection( @@ -229,7 +249,7 @@ def create_default_connections(session=None): } """, ), - session + session, ) merge_conn( Connection( @@ -243,7 +263,7 @@ def create_default_connections(session=None): } """, ), - session + session, ) merge_conn( Connection( @@ -259,7 +279,7 @@ def create_default_connections(session=None): conn_type="google_cloud_platform", schema="default", ), - session + session, ) merge_conn( Connection( @@ -270,7 +290,7 @@ def create_default_connections(session=None): extra='{"use_beeline": true, "auth": ""}', schema="default", ), - session + session, ) merge_conn( Connection( @@ -280,7 +300,7 @@ def create_default_connections(session=None): schema="default", port=10000, ), - session + session, ) merge_conn( Connection( @@ -288,14 +308,14 @@ def create_default_connections(session=None): conn_type="http", host="https://www.httpbin.org/", ), - session + session, ) merge_conn( Connection( conn_id='kubernetes_default', conn_type='kubernetes', ), - session + session, ) merge_conn( Connection( @@ -304,19 +324,11 @@ def create_default_connections(session=None): host='localhost', port=7070, login="ADMIN", - password="KYLIN" - ), - session - ) - merge_conn( - Connection( - conn_id="livy_default", - conn_type="livy", - host="livy", - port=8998 + password="KYLIN", ), - session + session, ) + merge_conn(Connection(conn_id="livy_default", conn_type="livy", host="livy", port=8998), session) merge_conn( Connection( conn_id="local_mysql", @@ -326,7 +338,7 @@ def create_default_connections(session=None): password="airflow", schema="airflow", ), - session + session, ) merge_conn( Connection( @@ -336,17 +348,9 @@ def create_default_connections(session=None): extra='{"authMechanism": "PLAIN"}', port=9083, ), - session - ) - merge_conn( - Connection( - conn_id="mongo_default", - conn_type="mongo", - host="mongo", - port=27017 - ), - session + session, ) + merge_conn(Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017), session) merge_conn( Connection( conn_id="mssql_default", @@ -354,7 +358,7 @@ def create_default_connections(session=None): host="localhost", port=1433, ), - session + session, ) merge_conn( Connection( @@ -364,7 +368,7 @@ def create_default_connections(session=None): schema="airflow", host="mysql", ), - session + session, ) merge_conn( Connection( @@ -373,7 +377,7 @@ def create_default_connections(session=None): host="", password="", ), - session + session, ) merge_conn( Connection( @@ -381,7 +385,7 @@ def create_default_connections(session=None): conn_type="pig_cli", schema="default", ), - session + session, ) merge_conn( Connection( @@ -390,7 +394,7 @@ def 
create_default_connections(session=None): host="localhost", port=9000, ), - session + session, ) merge_conn( Connection( @@ -400,7 +404,7 @@ def create_default_connections(session=None): port=9000, extra='{"endpoint": "/query", "schema": "http"}', ), - session + session, ) merge_conn( Connection( @@ -411,7 +415,7 @@ def create_default_connections(session=None): schema="airflow", host="postgres", ), - session + session, ) merge_conn( Connection( @@ -421,7 +425,7 @@ def create_default_connections(session=None): schema="hive", port=3400, ), - session + session, ) merge_conn( Connection( @@ -429,7 +433,7 @@ def create_default_connections(session=None): conn_type="qubole", host="localhost", ), - session + session, ) merge_conn( Connection( @@ -439,7 +443,7 @@ def create_default_connections(session=None): port=6379, extra='{"db": 0}', ), - session + session, ) merge_conn( Connection( @@ -447,7 +451,7 @@ def create_default_connections(session=None): conn_type="segment", extra='{"write_key": "my-segment-write-key"}', ), - session + session, ) merge_conn( Connection( @@ -458,7 +462,7 @@ def create_default_connections(session=None): login="airflow", extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}', ), - session + session, ) merge_conn( Connection( @@ -467,7 +471,7 @@ def create_default_connections(session=None): host="yarn", extra='{"queue": "root.default"}', ), - session + session, ) merge_conn( Connection( @@ -475,7 +479,7 @@ def create_default_connections(session=None): conn_type="sqlite", host="/tmp/sqlite_default.db", ), - session + session, ) merge_conn( Connection( @@ -483,7 +487,7 @@ def create_default_connections(session=None): conn_type="sqoop", host="rdbms", ), - session + session, ) merge_conn( Connection( @@ -491,7 +495,7 @@ def create_default_connections(session=None): conn_type="ssh", host="localhost", ), - session + session, ) merge_conn( Connection( @@ -502,7 +506,7 @@ def create_default_connections(session=None): password="password", extra='{"site_id": "my_site"}', ), - session + session, ) merge_conn( Connection( @@ -511,7 +515,7 @@ def create_default_connections(session=None): host="localhost", port=5433, ), - session + session, ) merge_conn( Connection( @@ -519,7 +523,7 @@ def create_default_connections(session=None): conn_type="wasb", extra='{"sas_token": null}', ), - session + session, ) merge_conn( Connection( @@ -528,7 +532,7 @@ def create_default_connections(session=None): host="localhost", port=50070, ), - session + session, ) merge_conn( Connection( @@ -536,7 +540,7 @@ def create_default_connections(session=None): conn_type='yandexcloud', schema='default', ), - session + session, ) @@ -555,6 +559,7 @@ def initdb(): DAG.deactivate_unknown_dags(dagbag.dags.keys()) from flask_appbuilder.models.sqla import Base + Base.metadata.create_all(settings.engine) # pylint: disable=no-member @@ -590,8 +595,7 @@ def check_migrations(timeout): if source_heads == db_heads: break if ticker >= timeout: - raise TimeoutError("There are still unapplied migrations after {} " - "seconds.".format(ticker)) + raise TimeoutError("There are still unapplied migrations after {} " "seconds.".format(ticker)) ticker += 1 time.sleep(1) log.info('Waiting for migrations... 
%s second(s)', ticker) @@ -649,6 +653,7 @@ def drop_airflow_models(connection): Base.metadata.remove(chart) # alembic adds significant import time, so we import it lazily from alembic.migration import MigrationContext # noqa + migration_ctx = MigrationContext.configure(connection) version = migration_ctx._version # noqa pylint: disable=protected-access if version.exists(connection): @@ -662,6 +667,7 @@ def drop_flask_models(connection): @return: """ from flask_appbuilder.models.sqla import Base + Base.metadata.drop_all(connection) # pylint: disable=no-member diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py index e8238c971e95e..40fb9e0702f7b 100644 --- a/airflow/utils/decorators.py +++ b/airflow/utils/decorators.py @@ -46,17 +46,19 @@ def apply_defaults(func: T) -> T: # have a different sig_cache. sig_cache = signature(func) non_optional_args = { - name for (name, param) in sig_cache.parameters.items() - if param.default == param.empty and - param.name != 'self' and - param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)} + name + for (name, param) in sig_cache.parameters.items() + if param.default == param.empty + and param.name != 'self' + and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) + } @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: from airflow.models.dag import DagContext + if len(args) > 1: - raise AirflowException( - "Use keyword arguments when initializing operators") + raise AirflowException("Use keyword arguments when initializing operators") dag_args: Dict[str, Any] = {} dag_params: Dict[str, Any] = {} @@ -91,6 +93,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: result = func(*args, **kwargs) return result + return cast(T, wrapper) diff --git a/airflow/utils/dot_renderer.py b/airflow/utils/dot_renderer.py index 65e9a4a099fc2..990c7a7d126fb 100644 --- a/airflow/utils/dot_renderer.py +++ b/airflow/utils/dot_renderer.py @@ -55,8 +55,14 @@ def render_dag(dag: DAG, tis: Optional[List[TaskInstance]] = None) -> graphviz.D :return: Graphviz object :rtype: graphviz.Digraph """ - dot = graphviz.Digraph(dag.dag_id, graph_attr={"rankdir": dag.orientation if dag.orientation else "LR", - "labelloc": "t", "label": dag.dag_id}) + dot = graphviz.Digraph( + dag.dag_id, + graph_attr={ + "rankdir": dag.orientation if dag.orientation else "LR", + "labelloc": "t", + "label": dag.dag_id, + }, + ) states_by_task_id = None if tis is not None: states_by_task_id = {ti.task_id: ti.state for ti in tis} @@ -66,16 +72,20 @@ def render_dag(dag: DAG, tis: Optional[List[TaskInstance]] = None) -> graphviz.D "style": "filled,rounded", } if states_by_task_id is None: - node_attrs.update({ - "color": _refine_color(task.ui_fgcolor), - "fillcolor": _refine_color(task.ui_color), - }) + node_attrs.update( + { + "color": _refine_color(task.ui_fgcolor), + "fillcolor": _refine_color(task.ui_color), + } + ) else: state = states_by_task_id.get(task.task_id, State.NONE) - node_attrs.update({ - "color": State.color_fg(state), - "fillcolor": State.color(state), - }) + node_attrs.update( + { + "color": State.color_fg(state), + "fillcolor": State.color(state), + } + ) dot.node( task.task_id, _attributes=node_attrs, diff --git a/airflow/utils/email.py b/airflow/utils/email.py index f32757d2633c6..17d3b6e97bd7c 100644 --- a/airflow/utils/email.py +++ b/airflow/utils/email.py @@ -32,17 +32,35 @@ log = logging.getLogger(__name__) -def send_email(to: Union[List[str], Iterable[str]], subject: str, html_content: str, - files=None, dryrun=False, cc=None, bcc=None, - 
mime_subtype='mixed', mime_charset='utf-8', **kwargs): +def send_email( + to: Union[List[str], Iterable[str]], + subject: str, + html_content: str, + files=None, + dryrun=False, + cc=None, + bcc=None, + mime_subtype='mixed', + mime_charset='utf-8', + **kwargs, +): """Send email using backend specified in EMAIL_BACKEND.""" backend = conf.getimport('email', 'EMAIL_BACKEND') to_list = get_email_address_list(to) to_comma_separated = ", ".join(to_list) - return backend(to_comma_separated, subject, html_content, files=files, - dryrun=dryrun, cc=cc, bcc=bcc, - mime_subtype=mime_subtype, mime_charset=mime_charset, **kwargs) + return backend( + to_comma_separated, + subject, + html_content, + files=files, + dryrun=dryrun, + cc=cc, + bcc=bcc, + mime_subtype=mime_subtype, + mime_charset=mime_charset, + **kwargs, + ) def send_email_smtp( @@ -132,10 +150,7 @@ def build_mime_message( for fname in files or []: basename = os.path.basename(fname) with open(fname, "rb") as file: - part = MIMEApplication( - file.read(), - Name=basename - ) + part = MIMEApplication(file.read(), Name=basename) part['Content-Disposition'] = f'attachment; filename="{basename}"' part['Content-ID'] = f'<{basename}>' msg.attach(part) diff --git a/airflow/utils/file.py b/airflow/utils/file.py index 405a6b029548e..553c506696e5b 100644 --- a/airflow/utils/file.py +++ b/airflow/utils/file.py @@ -31,9 +31,11 @@ def TemporaryDirectory(*args, **kwargs): # pylint: disable=invalid-name """This function is deprecated. Please use `tempfile.TemporaryDirectory`""" import warnings from tempfile import TemporaryDirectory as TmpDir + warnings.warn( "This function is deprecated. Please use `tempfile.TemporaryDirectory`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) return TmpDir(*args, **kwargs) @@ -49,9 +51,11 @@ def mkdirs(path, mode): :type mode: int """ import warnings + warnings.warn( f"This function is deprecated. Please use `pathlib.Path({path}).mkdir`", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) Path(path).mkdir(mode=mode, parents=True, exist_ok=True) @@ -85,9 +89,7 @@ def open_maybe_zipped(fileloc, mode='r'): return open(fileloc, mode=mode) -def find_path_from_directory( - base_dir_path: str, - ignore_file_name: str) -> Generator[str, None, None]: +def find_path_from_directory(base_dir_path: str, ignore_file_name: str) -> Generator[str, None, None]: """ Search the file and return the path of the file that should not be ignored. :param base_dir_path: the base path to be searched for. 
@@ -110,8 +112,9 @@ def find_path_from_directory( dirs[:] = [ subdir for subdir in dirs - if not any(p.search( - os.path.join(os.path.relpath(root, str(base_dir_path)), subdir)) for p in patterns) + if not any( + p.search(os.path.join(os.path.relpath(root, str(base_dir_path)), subdir)) for p in patterns + ) ] patterns_by_dir.update({os.path.join(root, sd): patterns.copy() for sd in dirs}) @@ -126,11 +129,12 @@ def find_path_from_directory( yield str(abs_file_path) -def list_py_file_paths(directory: str, - safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE', fallback=True), - include_examples: Optional[bool] = None, - include_smart_sensor: Optional[bool] = - conf.getboolean('smart_sensor', 'use_smart_sensor')): +def list_py_file_paths( + directory: str, + safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE', fallback=True), + include_examples: Optional[bool] = None, + include_smart_sensor: Optional[bool] = conf.getboolean('smart_sensor', 'use_smart_sensor'), +): """ Traverse a directory and look for Python files. @@ -159,10 +163,12 @@ def list_py_file_paths(directory: str, find_dag_file_paths(directory, file_paths, safe_mode) if include_examples: from airflow import example_dags + example_dag_folder = example_dags.__path__[0] # type: ignore file_paths.extend(list_py_file_paths(example_dag_folder, safe_mode, False, False)) if include_smart_sensor: from airflow import smart_sensor_dags + smart_sensor_dag_folder = smart_sensor_dags.__path__[0] # type: ignore file_paths.extend(list_py_file_paths(smart_sensor_dag_folder, safe_mode, False, False)) return file_paths @@ -170,8 +176,7 @@ def list_py_file_paths(directory: str, def find_dag_file_paths(directory: str, file_paths: list, safe_mode: bool): """Finds file paths of all DAG files.""" - for file_path in find_path_from_directory( - directory, ".airflowignore"): + for file_path in find_path_from_directory(directory, ".airflowignore"): try: if not os.path.isfile(file_path): continue diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py index 01f88fcbac790..5ccb618f58766 100644 --- a/airflow/utils/helpers.py +++ b/airflow/utils/helpers.py @@ -36,12 +36,12 @@ def validate_key(k, max_length=250): if not isinstance(k, str): raise TypeError("The key has to be a string") elif len(k) > max_length: - raise AirflowException( - f"The key has to be less than {max_length} characters") + raise AirflowException(f"The key has to be less than {max_length} characters") elif not KEY_REGEX.match(k): raise AirflowException( "The key ({k}) has to be made of alphanumeric characters, dashes, " - "dots and underscores exclusively".format(k=k)) + "dots and underscores exclusively".format(k=k) + ) else: return True @@ -101,15 +101,10 @@ def chunks(items: List[T], chunk_size: int) -> Generator[List[T], None, None]: if chunk_size <= 0: raise ValueError('Chunk size must be a positive integer') for i in range(0, len(items), chunk_size): - yield items[i:i + chunk_size] + yield items[i : i + chunk_size] -def reduce_in_chunks( - fn: Callable[[S, List[T]], S], - iterable: List[T], - initializer: S, - chunk_size: int = 0 -): +def reduce_in_chunks(fn: Callable[[S, List[T]], S], iterable: List[T], initializer: S, chunk_size: int = 0): """ Reduce the given list of items by splitting it into chunks of the given size and passing each chunk through the reducer @@ -155,10 +150,12 @@ def render_log_filename(ti, try_number, filename_template): jinja_context['try_number'] = try_number return filename_jinja_template.render(**jinja_context) - return 
filename_template.format(dag_id=ti.dag_id, - task_id=ti.task_id, - execution_date=ti.execution_date.isoformat(), - try_number=try_number) + return filename_template.format( + dag_id=ti.dag_id, + task_id=ti.task_id, + execution_date=ti.execution_date.isoformat(), + try_number=try_number, + ) def convert_camel_to_snake(camel_str): @@ -191,7 +188,8 @@ def chain(*args, **kwargs): """This function is deprecated. Please use `airflow.models.baseoperator.chain`.""" warnings.warn( "This function is deprecated. Please use `airflow.models.baseoperator.chain`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) return import_string('airflow.models.baseoperator.chain')(*args, **kwargs) @@ -200,6 +198,7 @@ def cross_downstream(*args, **kwargs): """This function is deprecated. Please use `airflow.models.baseoperator.cross_downstream`.""" warnings.warn( "This function is deprecated. Please use `airflow.models.baseoperator.cross_downstream`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) return import_string('airflow.models.baseoperator.cross_downstream')(*args, **kwargs) diff --git a/airflow/utils/json.py b/airflow/utils/json.py index 6d1cc5f036d50..d388f11e39899 100644 --- a/airflow/utils/json.py +++ b/airflow/utils/json.py @@ -43,17 +43,32 @@ def _default(obj): return obj.strftime('%Y-%m-%dT%H:%M:%SZ') elif isinstance(obj, date): return obj.strftime('%Y-%m-%d') - elif isinstance(obj, (np.int_, np.intc, np.intp, np.int8, np.int16, - np.int32, np.int64, np.uint8, np.uint16, - np.uint32, np.uint64)): + elif isinstance( + obj, + ( + np.int_, + np.intc, + np.intp, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ), + ): return int(obj) elif isinstance(obj, np.bool_): return bool(obj) - elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64, - np.complex_, np.complex64, np.complex128)): + elif isinstance( + obj, (np.float_, np.float16, np.float32, np.float64, np.complex_, np.complex64, np.complex128) + ): return float(obj) elif k8s is not None and isinstance(obj, k8s.V1Pod): from airflow.kubernetes.pod_generator import PodGenerator + return PodGenerator.serialize_pod(obj) raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable") diff --git a/airflow/utils/log/cloudwatch_task_handler.py b/airflow/utils/log/cloudwatch_task_handler.py index 1ba2586151698..cee5bfd87def8 100644 --- a/airflow/utils/log/cloudwatch_task_handler.py +++ b/airflow/utils/log/cloudwatch_task_handler.py @@ -23,5 +23,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.amazon.aws.log.cloudwatch_task_handler`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/utils/log/colored_log.py b/airflow/utils/log/colored_log.py index 3d7a8f13e3182..59ca3a2f0465a 100644 --- a/airflow/utils/log/colored_log.py +++ b/airflow/utils/log/colored_log.py @@ -66,12 +66,10 @@ def _color_record_args(self, record: LogRecord) -> LogRecord: elif isinstance(record.args, dict): if self._count_number_of_arguments_in_message(record) > 1: # Case of logging.debug("a %(a)d b %(b)s", {'a':1, 'b':2}) - record.args = { - key: self._color_arg(value) for key, value in record.args.items() - } + record.args = {key: self._color_arg(value) for key, value in record.args.items()} else: # Case of single dict passed to formatted string - record.args = self._color_arg(record.args) # type: ignore + record.args = self._color_arg(record.args) # type: ignore elif isinstance(record.args, str): record.args = self._color_arg(record.args) return record @@ -84,8 +82,9 @@ def _color_record_traceback(self, record: LogRecord) -> LogRecord: record.exc_text = self.formatException(record.exc_info) if record.exc_text: - record.exc_text = self.color(self.log_colors, record.levelname) + \ - record.exc_text + escape_codes['reset'] + record.exc_text = ( + self.color(self.log_colors, record.levelname) + record.exc_text + escape_codes['reset'] + ) return record @@ -96,4 +95,5 @@ def format(self, record: LogRecord) -> str: return super().format(record) except ValueError: # I/O operation on closed file from logging import Formatter + return Formatter().format(record) diff --git a/airflow/utils/log/es_task_handler.py b/airflow/utils/log/es_task_handler.py index 019fa39de944b..64de08ee9671d 100644 --- a/airflow/utils/log/es_task_handler.py +++ b/airflow/utils/log/es_task_handler.py @@ -23,5 +23,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.elasticsearch.log.es_task_handler`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/utils/log/file_processor_handler.py b/airflow/utils/log/file_processor_handler.py index 220787b6a3028..9d2768eb5da7a 100644 --- a/airflow/utils/log/file_processor_handler.py +++ b/airflow/utils/log/file_processor_handler.py @@ -40,8 +40,7 @@ def __init__(self, base_log_folder, filename_template): self.handler = None self.base_log_folder = base_log_folder self.dag_dir = os.path.expanduser(settings.DAGS_FOLDER) - self.filename_template, self.filename_jinja_template = \ - parse_template_string(filename_template) + self.filename_template, self.filename_jinja_template = parse_template_string(filename_template) self._cur_date = datetime.today() Path(self._get_log_directory()).mkdir(parents=True, exist_ok=True) @@ -85,6 +84,7 @@ def _render_filename(self, filename): # is always inside the log dir as other DAGs. To be differentiate with regular DAGs, # their logs will be in the `log_dir/native_dags`. 
import airflow + airflow_directory = airflow.__path__[0] if filename.startswith(airflow_directory): filename = os.path.join("native_dags", os.path.relpath(filename, airflow_directory)) @@ -119,17 +119,14 @@ def _symlink_latest_log_directory(self): if os.readlink(latest_log_directory_path) != log_directory: os.unlink(latest_log_directory_path) os.symlink(log_directory, latest_log_directory_path) - elif (os.path.isdir(latest_log_directory_path) or - os.path.isfile(latest_log_directory_path)): + elif os.path.isdir(latest_log_directory_path) or os.path.isfile(latest_log_directory_path): logging.warning( - "%s already exists as a dir/file. Skip creating symlink.", - latest_log_directory_path + "%s already exists as a dir/file. Skip creating symlink.", latest_log_directory_path ) else: os.symlink(log_directory, latest_log_directory_path) except OSError: - logging.warning("OSError while attempting to symlink " - "the latest log directory") + logging.warning("OSError while attempting to symlink " "the latest log directory") def _init_file(self, filename): """ @@ -138,8 +135,7 @@ def _init_file(self, filename): :param filename: task instance object :return: relative log path of the given task instance """ - relative_log_file_path = os.path.join( - self._get_log_directory(), self._render_filename(filename)) + relative_log_file_path = os.path.join(self._get_log_directory(), self._render_filename(filename)) log_file_path = os.path.abspath(relative_log_file_path) directory = os.path.dirname(log_file_path) diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index b1170174c35aa..3a06ff3e93013 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -43,8 +43,7 @@ def __init__(self, base_log_folder: str, filename_template: str): super().__init__() self.handler = None # type: Optional[logging.FileHandler] self.local_base = base_log_folder - self.filename_template, self.filename_jinja_template = \ - parse_template_string(filename_template) + self.filename_template, self.filename_jinja_template = parse_template_string(filename_template) def set_context(self, ti: TaskInstance): """ @@ -83,10 +82,12 @@ def _render_filename(self, ti, try_number): } return self.filename_jinja_template.render(**jinja_context) - return self.filename_template.format(dag_id=ti.dag_id, - task_id=ti.task_id, - execution_date=ti.execution_date.isoformat(), - try_number=try_number) + return self.filename_template.format( + dag_id=ti.dag_id, + task_id=ti.task_id, + execution_date=ti.execution_date.isoformat(), + try_number=try_number, + ) def _read_grouped_logs(self): return False @@ -118,7 +119,7 @@ def _read(self, ti, try_number, metadata=None): # pylint: disable=unused-argume except Exception as e: # pylint: disable=broad-except log = f"*** Failed to load local log file: {location}\n" log += "*** {}\n".format(str(e)) - elif conf.get('core', 'executor') == 'KubernetesExecutor': # pylint: disable=too-many-nested-blocks + elif conf.get('core', 'executor') == 'KubernetesExecutor': # pylint: disable=too-many-nested-blocks try: from airflow.kubernetes.kube_client import get_kube_client @@ -129,14 +130,18 @@ def _read(self, ti, try_number, metadata=None): # pylint: disable=unused-argume # is returned for the fqdn to comply with the 63 character limit imposed by DNS standards # on any label of a FQDN. 
pod_list = kube_client.list_namespaced_pod(conf.get('kubernetes', 'namespace')) - matches = [pod.metadata.name for pod in pod_list.items - if pod.metadata.name.startswith(ti.hostname)] + matches = [ + pod.metadata.name + for pod in pod_list.items + if pod.metadata.name.startswith(ti.hostname) + ] if len(matches) == 1: if len(matches[0]) > len(ti.hostname): ti.hostname = matches[0] - log += '*** Trying to get logs (last 100 lines) from worker pod {} ***\n\n'\ - .format(ti.hostname) + log += '*** Trying to get logs (last 100 lines) from worker pod {} ***\n\n'.format( + ti.hostname + ) res = kube_client.read_namespaced_pod_log( name=ti.hostname, @@ -144,22 +149,17 @@ def _read(self, ti, try_number, metadata=None): # pylint: disable=unused-argume container='base', follow=False, tail_lines=100, - _preload_content=False + _preload_content=False, ) for line in res: log += line.decode() except Exception as f: # pylint: disable=broad-except - log += '*** Unable to fetch logs from worker pod {} ***\n{}\n\n'.format( - ti.hostname, str(f) - ) + log += '*** Unable to fetch logs from worker pod {} ***\n{}\n\n'.format(ti.hostname, str(f)) else: - url = os.path.join( - "http://{ti.hostname}:{worker_log_server_port}/log", log_relative_path - ).format( - ti=ti, - worker_log_server_port=conf.get('celery', 'WORKER_LOG_SERVER_PORT') + url = os.path.join("http://{ti.hostname}:{worker_log_server_port}/log", log_relative_path).format( + ti=ti, worker_log_server_port=conf.get('celery', 'WORKER_LOG_SERVER_PORT') ) log += f"*** Log file does not exist: {location}\n" log += f"*** Fetching from: {url}\n" diff --git a/airflow/utils/log/gcs_task_handler.py b/airflow/utils/log/gcs_task_handler.py index 63251a4e91f07..bf8576b8906d5 100644 --- a/airflow/utils/log/gcs_task_handler.py +++ b/airflow/utils/log/gcs_task_handler.py @@ -23,5 +23,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.log.gcs_task_handler`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/utils/log/json_formatter.py b/airflow/utils/log/json_formatter.py index 4d517deb1542b..73d461942186c 100644 --- a/airflow/utils/log/json_formatter.py +++ b/airflow/utils/log/json_formatter.py @@ -42,7 +42,6 @@ def usesTime(self): def format(self, record): super().format(record) - record_dict = {label: getattr(record, label, None) - for label in self.json_fields} + record_dict = {label: getattr(record, label, None) for label in self.json_fields} merged_record = merge_dicts(record_dict, self.extras) return json.dumps(merged_record) diff --git a/airflow/utils/log/log_reader.py b/airflow/utils/log/log_reader.py index 42d03d901b3cc..bef39a0b25af2 100644 --- a/airflow/utils/log/log_reader.py +++ b/airflow/utils/log/log_reader.py @@ -29,8 +29,9 @@ class TaskLogReader: """Task log reader""" - def read_log_chunks(self, ti: TaskInstance, try_number: Optional[int], - metadata) -> Tuple[List[str], Dict[str, Any]]: + def read_log_chunks( + self, ti: TaskInstance, try_number: Optional[int], metadata + ) -> Tuple[List[str], Dict[str, Any]]: """ Reads chunks of Task Instance logs. 
@@ -58,8 +59,7 @@ def read_log_chunks(self, ti: TaskInstance, try_number: Optional[int], metadata = metadatas[0] return logs, metadata - def read_log_stream(self, ti: TaskInstance, try_number: Optional[int], - metadata: dict) -> Iterator[str]: + def read_log_stream(self, ti: TaskInstance, try_number: Optional[int], metadata: dict) -> Iterator[str]: """ Used to continuously read log to the end @@ -115,7 +115,6 @@ def render_log_filename(self, ti: TaskInstance, try_number: Optional[int] = None """ filename_template = conf.get('logging', 'LOG_FILENAME_TEMPLATE') attachment_filename = render_log_filename( - ti=ti, - try_number="all" if try_number is None else try_number, - filename_template=filename_template) + ti=ti, try_number="all" if try_number is None else try_number, filename_template=filename_template + ) return attachment_filename diff --git a/airflow/utils/log/logging_mixin.py b/airflow/utils/log/logging_mixin.py index 624f6b931e937..12e949b8b2318 100644 --- a/airflow/utils/log/logging_mixin.py +++ b/airflow/utils/log/logging_mixin.py @@ -46,9 +46,7 @@ def log(self) -> Logger: # FIXME: LoggingMixin should have a default _log field. return self._log # type: ignore except AttributeError: - self._log = logging.getLogger( - self.__class__.__module__ + '.' + self.__class__.__name__ - ) + self._log = logging.getLogger(self.__class__.__module__ + '.' + self.__class__.__name__) return self._log def _set_context(self, context): @@ -91,7 +89,7 @@ def close(self): """ @property - def closed(self): # noqa: D402 + def closed(self): # noqa: D402 """ Returns False to indicate that the stream is not closed (as it will be open for the duration of Airflow's lifecycle). @@ -141,8 +139,9 @@ class RedirectStdHandler(StreamHandler): # pylint: disable=super-init-not-called def __init__(self, stream): if not isinstance(stream, str): - raise Exception("Cannot use file like objects. Use 'stdout' or 'stderr'" - " as a str and without 'ext://'.") + raise Exception( + "Cannot use file like objects. Use 'stdout' or 'stderr'" " as a str and without 'ext://'." + ) self._use_stderr = True if 'stdout' in stream: diff --git a/airflow/utils/log/s3_task_handler.py b/airflow/utils/log/s3_task_handler.py index 6bccff79f4e4d..cf8a7b7d5efbe 100644 --- a/airflow/utils/log/s3_task_handler.py +++ b/airflow/utils/log/s3_task_handler.py @@ -23,5 +23,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.amazon.aws.log.s3_task_handler`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/utils/log/stackdriver_task_handler.py b/airflow/utils/log/stackdriver_task_handler.py index 0b96380ad5cbc..182c6636a3bd8 100644 --- a/airflow/utils/log/stackdriver_task_handler.py +++ b/airflow/utils/log/stackdriver_task_handler.py @@ -22,5 +22,6 @@ warnings.warn( "This module is deprecated. Please use `airflow.providers.google.cloud.log.stackdriver_task_handler`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/utils/log/wasb_task_handler.py b/airflow/utils/log/wasb_task_handler.py index 7b41933a1ecd2..263ef5641a71b 100644 --- a/airflow/utils/log/wasb_task_handler.py +++ b/airflow/utils/log/wasb_task_handler.py @@ -23,5 +23,6 @@ warnings.warn( "This module is deprecated. 
Please use `airflow.providers.microsoft.azure.log.wasb_task_handler`.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/utils/module_loading.py b/airflow/utils/module_loading.py index dbd5622d203f3..e863f8641894e 100644 --- a/airflow/utils/module_loading.py +++ b/airflow/utils/module_loading.py @@ -34,6 +34,4 @@ def import_string(dotted_path): try: return getattr(module, class_name) except AttributeError: - raise ImportError('Module "{}" does not define a "{}" attribute/class'.format( - module_path, class_name) - ) + raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class') diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py index c2e72289255d2..72d837fb3431f 100644 --- a/airflow/utils/operator_helpers.py +++ b/airflow/utils/operator_helpers.py @@ -18,18 +18,24 @@ # AIRFLOW_VAR_NAME_FORMAT_MAPPING = { - 'AIRFLOW_CONTEXT_DAG_ID': {'default': 'airflow.ctx.dag_id', - 'env_var_format': 'AIRFLOW_CTX_DAG_ID'}, - 'AIRFLOW_CONTEXT_TASK_ID': {'default': 'airflow.ctx.task_id', - 'env_var_format': 'AIRFLOW_CTX_TASK_ID'}, - 'AIRFLOW_CONTEXT_EXECUTION_DATE': {'default': 'airflow.ctx.execution_date', - 'env_var_format': 'AIRFLOW_CTX_EXECUTION_DATE'}, - 'AIRFLOW_CONTEXT_DAG_RUN_ID': {'default': 'airflow.ctx.dag_run_id', - 'env_var_format': 'AIRFLOW_CTX_DAG_RUN_ID'}, - 'AIRFLOW_CONTEXT_DAG_OWNER': {'default': 'airflow.ctx.dag_owner', - 'env_var_format': 'AIRFLOW_CTX_DAG_OWNER'}, - 'AIRFLOW_CONTEXT_DAG_EMAIL': {'default': 'airflow.ctx.dag_email', - 'env_var_format': 'AIRFLOW_CTX_DAG_EMAIL'}, + 'AIRFLOW_CONTEXT_DAG_ID': {'default': 'airflow.ctx.dag_id', 'env_var_format': 'AIRFLOW_CTX_DAG_ID'}, + 'AIRFLOW_CONTEXT_TASK_ID': {'default': 'airflow.ctx.task_id', 'env_var_format': 'AIRFLOW_CTX_TASK_ID'}, + 'AIRFLOW_CONTEXT_EXECUTION_DATE': { + 'default': 'airflow.ctx.execution_date', + 'env_var_format': 'AIRFLOW_CTX_EXECUTION_DATE', + }, + 'AIRFLOW_CONTEXT_DAG_RUN_ID': { + 'default': 'airflow.ctx.dag_run_id', + 'env_var_format': 'AIRFLOW_CTX_DAG_RUN_ID', + }, + 'AIRFLOW_CONTEXT_DAG_OWNER': { + 'default': 'airflow.ctx.dag_owner', + 'env_var_format': 'AIRFLOW_CTX_DAG_OWNER', + }, + 'AIRFLOW_CONTEXT_DAG_EMAIL': { + 'default': 'airflow.ctx.dag_email', + 'env_var_format': 'AIRFLOW_CTX_DAG_EMAIL', + }, } @@ -54,33 +60,32 @@ def context_to_airflow_vars(context, in_env_var_format=False): task = context.get('task') if task and task.email: if isinstance(task.email, str): - params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_EMAIL'][ - name_format]] = task.email + params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_EMAIL'][name_format]] = task.email elif isinstance(task.email, list): # os env variable value needs to be string - params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_EMAIL'][ - name_format]] = ','.join(task.email) + params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_EMAIL'][name_format]] = ','.join( + task.email + ) if task and task.owner: if isinstance(task.owner, str): - params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_OWNER'][ - name_format]] = task.owner + params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_OWNER'][name_format]] = task.owner elif isinstance(task.owner, list): # os env variable value needs to be string - params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_OWNER'][ - name_format]] = ','.join(task.owner) + params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_OWNER'][name_format]] = ','.join( + task.owner + ) 
task_instance = context.get('task_instance') if task_instance and task_instance.dag_id: - params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID'][ - name_format]] = task_instance.dag_id + params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID'][name_format]] = task_instance.dag_id if task_instance and task_instance.task_id: - params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID'][ - name_format]] = task_instance.task_id + params[ + AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID'][name_format] + ] = task_instance.task_id if task_instance and task_instance.execution_date: params[ - AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][ - name_format]] = task_instance.execution_date.isoformat() + AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][name_format] + ] = task_instance.execution_date.isoformat() dag_run = context.get('dag_run') if dag_run and dag_run.run_id: - params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][ - name_format]] = dag_run.run_id + params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][name_format]] = dag_run.run_id return params diff --git a/airflow/utils/operator_resources.py b/airflow/utils/operator_resources.py index b7f42ba2043ef..878102183ea32 100644 --- a/airflow/utils/operator_resources.py +++ b/airflow/utils/operator_resources.py @@ -45,7 +45,8 @@ def __init__(self, name, units_str, qty): if qty < 0: raise AirflowException( 'Received resource quantity {} for resource {} but resource quantity ' - 'must be non-negative.'.format(qty, name)) + 'must be non-negative.'.format(qty, name) + ) self._name = name self._units_str = units_str @@ -119,12 +120,13 @@ class Resources: :type gpus: long """ - def __init__(self, - cpus=conf.getint('operators', 'default_cpus'), - ram=conf.getint('operators', 'default_ram'), - disk=conf.getint('operators', 'default_disk'), - gpus=conf.getint('operators', 'default_gpus') - ): + def __init__( + self, + cpus=conf.getint('operators', 'default_cpus'), + ram=conf.getint('operators', 'default_ram'), + disk=conf.getint('operators', 'default_disk'), + gpus=conf.getint('operators', 'default_gpus'), + ): self.cpus = CpuResource(cpus) self.ram = RamResource(ram) self.disk = DiskResource(disk) diff --git a/airflow/utils/orm_event_handlers.py b/airflow/utils/orm_event_handlers.py index ab4954739a715..4f7d2b3fdbef3 100644 --- a/airflow/utils/orm_event_handlers.py +++ b/airflow/utils/orm_event_handlers.py @@ -36,6 +36,7 @@ def connect(dbapi_connection, connection_record): connection_record.info['pid'] = os.getpid() if engine.dialect.name == "sqlite": + @event.listens_for(engine, "connect") def set_sqlite_pragma(dbapi_connection, connection_record): cursor = dbapi_connection.cursor() @@ -44,6 +45,7 @@ def set_sqlite_pragma(dbapi_connection, connection_record): # this ensures sanity in mysql when storing datetimes (not required for postgres) if engine.dialect.name == "mysql": + @event.listens_for(engine, "connect") def set_mysql_timezone(dbapi_connection, connection_record): cursor = dbapi_connection.cursor() @@ -59,7 +61,9 @@ def checkout(dbapi_connection, connection_record, connection_proxy): "Connection record belongs to pid {}, " "attempting to check out in pid {}".format(connection_record.info['pid'], pid) ) + if conf.getboolean('debug', 'sqlalchemy_stats', fallback=False): + @event.listens_for(engine, "before_cursor_execute") def before_cursor_execute(conn, cursor, statement, parameters, context, executemany): 
conn.info.setdefault('query_start_time', []).append(time.time()) @@ -68,12 +72,19 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem def after_cursor_execute(conn, cursor, statement, parameters, context, executemany): total = time.time() - conn.info['query_start_time'].pop() file_name = [ - f"'{f.name}':{f.filename}:{f.lineno}" for f - in traceback.extract_stack() if 'sqlalchemy' not in f.filename][-1] + f"'{f.name}':{f.filename}:{f.lineno}" + for f in traceback.extract_stack() + if 'sqlalchemy' not in f.filename + ][-1] stack = [f for f in traceback.extract_stack() if 'sqlalchemy' not in f.filename] stack_info = ">".join([f"{f.filename.rpartition('/')[-1]}:{f.name}" for f in stack][-3:]) conn.info.setdefault('query_start_time', []).append(time.monotonic()) - log.info("@SQLALCHEMY %s |$ %s |$ %s |$ %s ", - total, file_name, stack_info, statement.replace("\n", " ") - ) + log.info( + "@SQLALCHEMY %s |$ %s |$ %s |$ %s ", + total, + file_name, + stack_info, + statement.replace("\n", " "), + ) + # pylint: enable=unused-argument, unused-variable diff --git a/airflow/utils/process_utils.py b/airflow/utils/process_utils.py index 226ec86453d97..1ee1ac60937c9 100644 --- a/airflow/utils/process_utils.py +++ b/airflow/utils/process_utils.py @@ -41,9 +41,7 @@ # When killing processes, time to wait after issuing a SIGTERM before issuing a # SIGKILL. -DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM = conf.getint( - 'core', 'KILLED_TASK_CLEANUP_TIME' -) +DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM = conf.getint('core', 'KILLED_TASK_CLEANUP_TIME') def reap_process_group(pgid, logger, sig=signal.SIGTERM, timeout=DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM): @@ -130,13 +128,7 @@ def execute_in_subprocess(cmd: List[str]): :type cmd: List[str] """ log.info("Executing cmd: %s", " ".join([shlex.quote(c) for c in cmd])) - proc = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - bufsize=0, - close_fds=True - ) + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=0, close_fds=True) log.info("Output:") if proc.stdout: with proc.stdout: @@ -164,12 +156,7 @@ def execute_interactive(cmd: List[str], **kwargs): try: # pylint: disable=too-many-nested-blocks # use os.setsid() make it run in a new process group, or bash job control will not be enabled proc = subprocess.Popen( - cmd, - stdin=slave_fd, - stdout=slave_fd, - stderr=slave_fd, - universal_newlines=True, - **kwargs + cmd, stdin=slave_fd, stdout=slave_fd, stderr=slave_fd, universal_newlines=True, **kwargs ) while proc.poll() is None: @@ -238,10 +225,7 @@ def patch_environ(new_env_variables: Dict[str, str]): :param new_env_variables: Environment variables to set """ - current_env_state = { - key: os.environ.get(key) - for key in new_env_variables.keys() - } + current_env_state = {key: os.environ.get(key) for key in new_env_variables.keys()} os.environ.update(new_env_variables) try: # pylint: disable=too-many-nested-blocks yield diff --git a/airflow/utils/python_virtualenv.py b/airflow/utils/python_virtualenv.py index 486371904647b..185a60154df5d 100644 --- a/airflow/utils/python_virtualenv.py +++ b/airflow/utils/python_virtualenv.py @@ -43,10 +43,7 @@ def _generate_pip_install_cmd(tmp_dir: str, requirements: List[str]) -> Optional def prepare_virtualenv( - venv_directory: str, - python_bin: str, - system_site_packages: bool, - requirements: List[str] + venv_directory: str, python_bin: str, system_site_packages: bool, requirements: List[str] ) -> str: """ Creates a virtual environment and installs 
the additional python packages @@ -83,9 +80,6 @@ def write_python_script(jinja_context: dict, filename: str): :type filename: str """ template_loader = jinja2.FileSystemLoader(searchpath=os.path.dirname(__file__)) - template_env = jinja2.Environment( - loader=template_loader, - undefined=jinja2.StrictUndefined - ) + template_env = jinja2.Environment(loader=template_loader, undefined=jinja2.StrictUndefined) template = template_env.get_template('python_virtualenv_script.jinja2') template.stream(**jinja_context).dump(filename) diff --git a/airflow/utils/serve_logs.py b/airflow/utils/serve_logs.py index 57717393d23e9..0fefa420b7d84 100644 --- a/airflow/utils/serve_logs.py +++ b/airflow/utils/serve_logs.py @@ -32,10 +32,8 @@ def serve_logs(): def serve_logs_view(filename): # pylint: disable=unused-variable log_directory = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER')) return flask.send_from_directory( - log_directory, - filename, - mimetype="application/json", - as_attachment=False) + log_directory, filename, mimetype="application/json", as_attachment=False + ) worker_log_server_port = conf.getint('celery', 'WORKER_LOG_SERVER_PORT') flask_app.run(host='0.0.0.0', port=worker_log_server_port) diff --git a/airflow/utils/session.py b/airflow/utils/session.py index 979c23a2fad8c..ea526edb12288 100644 --- a/airflow/utils/session.py +++ b/airflow/utils/session.py @@ -46,13 +46,13 @@ def provide_session(func: Callable[..., RT]) -> Callable[..., RT]: database transaction, you pass it to the function, if not this wrapper will create one and close it for you. """ + @wraps(func) def wrapper(*args, **kwargs) -> RT: arg_session = 'session' func_params = func.__code__.co_varnames - session_in_args = arg_session in func_params and \ - func_params.index(arg_session) < len(args) + session_in_args = arg_session in func_params and func_params.index(arg_session) < len(args) session_in_kwargs = arg_session in kwargs if session_in_kwargs or session_in_args: diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py index 8025b4304c3c8..a5bba18678361 100644 --- a/airflow/utils/sqlalchemy.py +++ b/airflow/utils/sqlalchemy.py @@ -59,8 +59,7 @@ class UtcDateTime(TypeDecorator): def process_bind_param(self, value, dialect): if value is not None: if not isinstance(value, datetime.datetime): - raise TypeError('expected datetime.datetime, not ' + - repr(value)) + raise TypeError('expected datetime.datetime, not ' + repr(value)) elif value.tzinfo is None: raise ValueError('naive datetime is disallowed') # For mysql we should store timestamps as naive values @@ -70,6 +69,7 @@ def process_bind_param(self, value, dialect): # See https://issues.apache.org/jira/browse/AIRFLOW-7001 if using_mysql: from airflow.utils.timezone import make_naive + return make_naive(value, timezone=utc) return value.astimezone(utc) return None @@ -99,17 +99,27 @@ class Interval(TypeDecorator): attr_keys = { datetime.timedelta: ('days', 'seconds', 'microseconds'), relativedelta.relativedelta: ( - 'years', 'months', 'days', 'leapdays', 'hours', 'minutes', 'seconds', 'microseconds', - 'year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond', + 'years', + 'months', + 'days', + 'leapdays', + 'hours', + 'minutes', + 'seconds', + 'microseconds', + 'year', + 'month', + 'day', + 'hour', + 'minute', + 'second', + 'microsecond', ), } def process_bind_param(self, value, dialect): if isinstance(value, tuple(self.attr_keys)): - attrs = { - key: getattr(value, key) - for key in self.attr_keys[type(value)] - } + attrs = {key: 
getattr(value, key) for key in self.attr_keys[type(value)]} return json.dumps({'type': type(value).__name__, 'attrs': attrs}) return json.dumps(value) diff --git a/airflow/utils/state.py b/airflow/utils/state.py index 81164ff00ca0a..3ead44fec792f 100644 --- a/airflow/utils/state.py +++ b/airflow/utils/state.py @@ -95,20 +95,19 @@ def color_fg(cls, state): return 'white' return 'black' - running = frozenset([ - RUNNING, - SENSING - ]) + running = frozenset([RUNNING, SENSING]) """ A list of states indicating that a task is being executed. """ - finished = frozenset([ - SUCCESS, - FAILED, - SKIPPED, - UPSTREAM_FAILED, - ]) + finished = frozenset( + [ + SUCCESS, + FAILED, + SKIPPED, + UPSTREAM_FAILED, + ] + ) """ A list of states indicating a task has reached a terminal state (i.e. it has "finished") and needs no further action. @@ -118,16 +117,18 @@ def color_fg(cls, state): case, it is no longer running. """ - unfinished = frozenset([ - NONE, - SCHEDULED, - QUEUED, - RUNNING, - SENSING, - SHUTDOWN, - UP_FOR_RETRY, - UP_FOR_RESCHEDULE, - ]) + unfinished = frozenset( + [ + NONE, + SCHEDULED, + QUEUED, + RUNNING, + SENSING, + SHUTDOWN, + UP_FOR_RETRY, + UP_FOR_RESCHEDULE, + ] + ) """ A list of states indicating that a task either has not completed a run or has not even started. diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index ee77694269654..d3d29c2a99668 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -180,8 +180,10 @@ def update_relative(self, other: "TaskMixin", upstream=True) -> None: # Handles setting relationship between a TaskGroup and a task for task in other.roots: if not isinstance(task, BaseOperator): - raise AirflowException("Relationships can only be set between TaskGroup " - f"or operators; received {task.__class__.__name__}") + raise AirflowException( + "Relationships can only be set between TaskGroup " + f"or operators; received {task.__class__.__name__}" + ) if upstream: self.upstream_task_ids.add(task.task_id) @@ -189,9 +191,7 @@ def update_relative(self, other: "TaskMixin", upstream=True) -> None: self.downstream_task_ids.add(task.task_id) def _set_relative( - self, - task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]], - upstream: bool = False + self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]], upstream: bool = False ) -> None: """ Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup. 
@@ -210,15 +210,11 @@ def _set_relative( for task_like in task_or_task_list: self.update_relative(task_like, upstream) - def set_downstream( - self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]] - ) -> None: + def set_downstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]) -> None: """Set a TaskGroup/task/list of task downstream of this TaskGroup.""" self._set_relative(task_or_task_list, upstream=False) - def set_upstream( - self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]] - ) -> None: + def set_upstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]) -> None: """Set a TaskGroup/task/list of task upstream of this TaskGroup.""" self._set_relative(task_or_task_list, upstream=True) diff --git a/airflow/utils/timeout.py b/airflow/utils/timeout.py index 8c5b2490786db..b94a7bb63893e 100644 --- a/airflow/utils/timeout.py +++ b/airflow/utils/timeout.py @@ -31,7 +31,7 @@ def __init__(self, seconds=1, error_message='Timeout'): self.seconds = seconds self.error_message = error_message + ', PID: ' + str(os.getpid()) - def handle_timeout(self, signum, frame): # pylint: disable=unused-argument + def handle_timeout(self, signum, frame): # pylint: disable=unused-argument """Logs information and raises AirflowTaskTimeout.""" self.log.error("Process timed out, PID: %s", str(os.getpid())) raise AirflowTaskTimeout(self.error_message) diff --git a/airflow/utils/timezone.py b/airflow/utils/timezone.py index 95a555f3624e5..d302cbe1a7c6c 100644 --- a/airflow/utils/timezone.py +++ b/airflow/utils/timezone.py @@ -109,8 +109,7 @@ def make_aware(value, timezone=None): # Check that we won't overwrite the timezone of an aware datetime. if is_localized(value): - raise ValueError( - "make_aware expects a naive datetime, got %s" % value) + raise ValueError("make_aware expects a naive datetime, got %s" % value) if hasattr(value, 'fold'): # In case of python 3.6 we want to do the same that pendulum does for python3.5 # i.e in case we move clock back we want to schedule the run at the time of the second @@ -146,13 +145,9 @@ def make_naive(value, timezone=None): date = value.astimezone(timezone) # cross library compatibility - naive = dt.datetime(date.year, - date.month, - date.day, - date.hour, - date.minute, - date.second, - date.microsecond) + naive = dt.datetime( + date.year, date.month, date.day, date.hour, date.minute, date.second, date.microsecond + ) return naive diff --git a/airflow/utils/weekday.py b/airflow/utils/weekday.py index d83b134956099..698ec6067187f 100644 --- a/airflow/utils/weekday.py +++ b/airflow/utils/weekday.py @@ -42,8 +42,6 @@ def get_weekday_number(cls, week_day_str): sanitized_week_day_str = week_day_str.upper() if sanitized_week_day_str not in cls.__members__: - raise AttributeError( - f'Invalid Week Day passed: "{week_day_str}"' - ) + raise AttributeError(f'Invalid Week Day passed: "{week_day_str}"') return cls[sanitized_week_day_str] diff --git a/airflow/www/api/experimental/endpoints.py b/airflow/www/api/experimental/endpoints.py index 60b5dc6102c7e..ada059b8575ef 100644 --- a/airflow/www/api/experimental/endpoints.py +++ b/airflow/www/api/experimental/endpoints.py @@ -42,6 +42,7 @@ def requires_authentication(function: T): """Decorator for functions that require authentication""" + @wraps(function) def decorated(*args, **kwargs): return current_app.api_auth.requires_authentication(function)(*args, **kwargs) @@ -98,15 +99,15 @@ def trigger_dag(dag_id): except ValueError: error_message = ( 'Given execution date, {}, could not be 
identified ' - 'as a date. Example date format: 2015-11-16T14:34:15+00:00' - .format(execution_date)) + 'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format(execution_date) + ) log.error(error_message) response = jsonify({'error': error_message}) response.status_code = 400 return response - replace_microseconds = (execution_date is None) + replace_microseconds = execution_date is None if 'replace_microseconds' in data: replace_microseconds = to_boolean(data['replace_microseconds']) @@ -122,9 +123,7 @@ def trigger_dag(dag_id): log.info("User %s created %s", g.user, dr) response = jsonify( - message=f"Created {dr}", - execution_date=dr.execution_date.isoformat(), - run_id=dr.run_id + message=f"Created {dr}", execution_date=dr.execution_date.isoformat(), run_id=dr.run_id ) return response @@ -206,9 +205,7 @@ def task_info(dag_id, task_id): return response # JSONify and return. - fields = {k: str(v) - for k, v in vars(t_info).items() - if not k.startswith('_')} + fields = {k: str(v) for k, v in vars(t_info).items() if not k.startswith('_')} return jsonify(fields) @@ -236,8 +233,8 @@ def dag_is_paused(dag_id): @api_experimental.route( - '/dags//dag_runs//tasks/', - methods=['GET']) + '/dags//dag_runs//tasks/', methods=['GET'] +) @requires_authentication def task_instance_info(dag_id, execution_date, task_id): """ @@ -252,8 +249,8 @@ def task_instance_info(dag_id, execution_date, task_id): except ValueError: error_message = ( 'Given execution date, {}, could not be identified ' - 'as a date. Example date format: 2015-11-16T14:34:15+00:00' - .format(execution_date)) + 'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format(execution_date) + ) log.error(error_message) response = jsonify({'error': error_message}) response.status_code = 400 @@ -269,15 +266,11 @@ def task_instance_info(dag_id, execution_date, task_id): return response # JSONify and return. - fields = {k: str(v) - for k, v in vars(ti_info).items() - if not k.startswith('_')} + fields = {k: str(v) for k, v in vars(ti_info).items() if not k.startswith('_')} return jsonify(fields) -@api_experimental.route( - '/dags//dag_runs/', - methods=['GET']) +@api_experimental.route('/dags//dag_runs/', methods=['GET']) @requires_authentication def dag_run_status(dag_id, execution_date): """ @@ -292,8 +285,8 @@ def dag_run_status(dag_id, execution_date): except ValueError: error_message = ( 'Given execution date, {}, could not be identified ' - 'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format( - execution_date)) + 'as a date. 
Example date format: 2015-11-16T14:34:15+00:00'.format(execution_date) + ) log.error(error_message) response = jsonify({'error': error_message}) response.status_code = 400 @@ -316,18 +309,21 @@ def dag_run_status(dag_id, execution_date): def latest_dag_runs(): """Returns the latest DagRun for each DAG formatted for the UI""" from airflow.models import DagRun + dagruns = DagRun.get_latest_runs() payload = [] for dagrun in dagruns: if dagrun.execution_date: - payload.append({ - 'dag_id': dagrun.dag_id, - 'execution_date': dagrun.execution_date.isoformat(), - 'start_date': ((dagrun.start_date or '') and - dagrun.start_date.isoformat()), - 'dag_run_url': url_for('Airflow.graph', dag_id=dagrun.dag_id, - execution_date=dagrun.execution_date) - }) + payload.append( + { + 'dag_id': dagrun.dag_id, + 'execution_date': dagrun.execution_date.isoformat(), + 'start_date': ((dagrun.start_date or '') and dagrun.start_date.isoformat()), + 'dag_run_url': url_for( + 'Airflow.graph', dag_id=dagrun.dag_id, execution_date=dagrun.execution_date + ), + } + ) return jsonify(items=payload) # old flask versions don't support jsonifying arrays @@ -392,8 +388,7 @@ def delete_pool(name): return jsonify(pool.to_json()) -@api_experimental.route('/lineage//', - methods=['GET']) +@api_experimental.route('/lineage//', methods=['GET']) def get_lineage(dag_id: str, execution_date: str): """Get Lineage details for a DagRun""" # Convert string datetime into actual datetime @@ -402,8 +397,8 @@ def get_lineage(dag_id: str, execution_date: str): except ValueError: error_message = ( 'Given execution date, {}, could not be identified ' - 'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format( - execution_date)) + 'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format(execution_date) + ) log.error(error_message) response = jsonify({'error': error_message}) response.status_code = 400 diff --git a/airflow/www/app.py b/airflow/www/app.py index 5338dddcc17b5..04ce366425155 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -37,7 +37,11 @@ from airflow.www.extensions.init_security import init_api_experimental_auth, init_xframe_protection from airflow.www.extensions.init_session import init_logout_timeout, init_permanent_session from airflow.www.extensions.init_views import ( - init_api_connexion, init_api_experimental, init_appbuilder_views, init_error_handlers, init_flash_views, + init_api_connexion, + init_api_experimental, + init_appbuilder_views, + init_error_handlers, + init_flash_views, init_plugins, ) from airflow.www.extensions.init_wsgi_middlewares import init_wsgi_middleware diff --git a/airflow/www/auth.py b/airflow/www/auth.py index 41bdf8666f5fc..370f7c626d41f 100644 --- a/airflow/www/auth.py +++ b/airflow/www/auth.py @@ -25,6 +25,7 @@ def has_access(permissions: Optional[Sequence[Tuple[str, str]]] = None) -> Callable[[T], T]: """Factory for decorator that checks current user's permissions against required permissions.""" + def requires_access_decorator(func: T): @wraps(func) def decorated(*args, **kwargs): @@ -34,7 +35,12 @@ def decorated(*args, **kwargs): else: access_denied = "Access is Denied" flash(access_denied, "danger") - return redirect(url_for(appbuilder.sm.auth_view.__class__.__name__ + ".login", next=request.url,)) + return redirect( + url_for( + appbuilder.sm.auth_view.__class__.__name__ + ".login", + next=request.url, + ) + ) return cast(T, decorated) diff --git a/airflow/www/extensions/init_jinja_globals.py b/airflow/www/extensions/init_jinja_globals.py index 
3554fff2f74fe..ff5a0a5899f39 100644 --- a/airflow/www/extensions/init_jinja_globals.py +++ b/airflow/www/extensions/init_jinja_globals.py @@ -64,7 +64,7 @@ def prepare_jinja_globals(): 'log_animation_speed': conf.getint('webserver', 'log_animation_speed', fallback=1000), 'state_color_mapping': STATE_COLORS, 'airflow_version': airflow_version, - 'git_version': git_version + 'git_version': git_version, } if 'analytics_tool' in conf.getsection('webserver'): diff --git a/airflow/www/extensions/init_security.py b/airflow/www/extensions/init_security.py index e2cc7bf6f8c78..544deebeb3af4 100644 --- a/airflow/www/extensions/init_security.py +++ b/airflow/www/extensions/init_security.py @@ -53,8 +53,5 @@ def init_api_experimental_auth(app): app.api_auth = import_module(auth_backend) app.api_auth.init_app(app) except ImportError as err: - log.critical( - "Cannot import %s for API authentication due to: %s", - auth_backend, err - ) + log.critical("Cannot import %s for API authentication due to: %s", auth_backend, err) raise AirflowException(err) diff --git a/airflow/www/forms.py b/airflow/www/forms.py index 4c6cb90a11033..30abea5565747 100644 --- a/airflow/www/forms.py +++ b/airflow/www/forms.py @@ -22,14 +22,23 @@ import pendulum from flask_appbuilder.fieldwidgets import ( - BS3PasswordFieldWidget, BS3TextAreaFieldWidget, BS3TextFieldWidget, Select2Widget, + BS3PasswordFieldWidget, + BS3TextAreaFieldWidget, + BS3TextFieldWidget, + Select2Widget, ) from flask_appbuilder.forms import DynamicForm from flask_babel import lazy_gettext from flask_wtf import FlaskForm from wtforms import widgets from wtforms.fields import ( - BooleanField, Field, IntegerField, PasswordField, SelectField, StringField, TextAreaField, + BooleanField, + Field, + IntegerField, + PasswordField, + SelectField, + StringField, + TextAreaField, ) from wtforms.validators import DataRequired, NumberRange, Optional @@ -86,8 +95,7 @@ def _get_default_timezone(self): class DateTimeForm(FlaskForm): """Date filter form needed for task views""" - execution_date = DateTimeWithTimezoneField( - "Execution date", widget=AirflowDateTimePickerWidget()) + execution_date = DateTimeWithTimezoneField("Execution date", widget=AirflowDateTimePickerWidget()) class DateTimeWithNumRunsForm(FlaskForm): @@ -97,14 +105,19 @@ class DateTimeWithNumRunsForm(FlaskForm): """ base_date = DateTimeWithTimezoneField( - "Anchor date", widget=AirflowDateTimePickerWidget(), default=timezone.utcnow()) - num_runs = SelectField("Number of runs", default=25, choices=( - (5, "5"), - (25, "25"), - (50, "50"), - (100, "100"), - (365, "365"), - )) + "Anchor date", widget=AirflowDateTimePickerWidget(), default=timezone.utcnow() + ) + num_runs = SelectField( + "Number of runs", + default=25, + choices=( + (5, "5"), + (25, "25"), + (50, "50"), + (100, "100"), + (365, "365"), + ), + ) class DateTimeWithNumRunsWithDagRunsForm(DateTimeWithNumRunsForm): @@ -116,33 +129,26 @@ class DateTimeWithNumRunsWithDagRunsForm(DateTimeWithNumRunsForm): class DagRunForm(DynamicForm): """Form for editing and adding DAG Run""" - dag_id = StringField( - lazy_gettext('Dag Id'), - validators=[DataRequired()], - widget=BS3TextFieldWidget()) - start_date = DateTimeWithTimezoneField( - lazy_gettext('Start Date'), - widget=AirflowDateTimePickerWidget()) - end_date = DateTimeWithTimezoneField( - lazy_gettext('End Date'), - widget=AirflowDateTimePickerWidget()) - run_id = StringField( - lazy_gettext('Run Id'), - validators=[DataRequired()], - widget=BS3TextFieldWidget()) + dag_id = 
StringField(lazy_gettext('Dag Id'), validators=[DataRequired()], widget=BS3TextFieldWidget()) + start_date = DateTimeWithTimezoneField(lazy_gettext('Start Date'), widget=AirflowDateTimePickerWidget()) + end_date = DateTimeWithTimezoneField(lazy_gettext('End Date'), widget=AirflowDateTimePickerWidget()) + run_id = StringField(lazy_gettext('Run Id'), validators=[DataRequired()], widget=BS3TextFieldWidget()) state = SelectField( lazy_gettext('State'), - choices=(('success', 'success'), ('running', 'running'), ('failed', 'failed'),), - widget=Select2Widget()) + choices=( + ('success', 'success'), + ('running', 'running'), + ('failed', 'failed'), + ), + widget=Select2Widget(), + ) execution_date = DateTimeWithTimezoneField( - lazy_gettext('Execution Date'), - widget=AirflowDateTimePickerWidget()) - external_trigger = BooleanField( - lazy_gettext('External Trigger')) + lazy_gettext('Execution Date'), widget=AirflowDateTimePickerWidget() + ) + external_trigger = BooleanField(lazy_gettext('External Trigger')) conf = TextAreaField( - lazy_gettext('Conf'), - validators=[ValidJson(), Optional()], - widget=BS3TextAreaFieldWidget()) + lazy_gettext('Conf'), validators=[ValidJson(), Optional()], widget=BS3TextAreaFieldWidget() + ) def populate_obj(self, item): """Populates the attributes of the passed obj with data from the form’s fields.""" @@ -214,83 +220,62 @@ def populate_obj(self, item): class ConnectionForm(DynamicForm): """Form for editing and adding Connection""" - conn_id = StringField( - lazy_gettext('Conn Id'), - widget=BS3TextFieldWidget()) + conn_id = StringField(lazy_gettext('Conn Id'), widget=BS3TextFieldWidget()) conn_type = SelectField( lazy_gettext('Conn Type'), choices=sorted(_connection_types, key=itemgetter(1)), # pylint: disable=protected-access - widget=Select2Widget()) - host = StringField( - lazy_gettext('Host'), - widget=BS3TextFieldWidget()) - schema = StringField( - lazy_gettext('Schema'), - widget=BS3TextFieldWidget()) - login = StringField( - lazy_gettext('Login'), - widget=BS3TextFieldWidget()) - password = PasswordField( - lazy_gettext('Password'), - widget=BS3PasswordFieldWidget()) - port = IntegerField( - lazy_gettext('Port'), - validators=[Optional()], - widget=BS3TextFieldWidget()) - extra = TextAreaField( - lazy_gettext('Extra'), - widget=BS3TextAreaFieldWidget()) + widget=Select2Widget(), + ) + host = StringField(lazy_gettext('Host'), widget=BS3TextFieldWidget()) + schema = StringField(lazy_gettext('Schema'), widget=BS3TextFieldWidget()) + login = StringField(lazy_gettext('Login'), widget=BS3TextFieldWidget()) + password = PasswordField(lazy_gettext('Password'), widget=BS3PasswordFieldWidget()) + port = IntegerField(lazy_gettext('Port'), validators=[Optional()], widget=BS3TextFieldWidget()) + extra = TextAreaField(lazy_gettext('Extra'), widget=BS3TextAreaFieldWidget()) # Used to customized the form, the forms elements get rendered # and results are stored in the extra field as json. All of these # need to be prefixed with extra__ and then the conn_type ___ as in # extra__{conn_type}__name. 
You can also hide form elements and rename # others from the connection_form.js file - extra__jdbc__drv_path = StringField( - lazy_gettext('Driver Path'), - widget=BS3TextFieldWidget()) - extra__jdbc__drv_clsname = StringField( - lazy_gettext('Driver Class'), - widget=BS3TextFieldWidget()) + extra__jdbc__drv_path = StringField(lazy_gettext('Driver Path'), widget=BS3TextFieldWidget()) + extra__jdbc__drv_clsname = StringField(lazy_gettext('Driver Class'), widget=BS3TextFieldWidget()) extra__google_cloud_platform__project = StringField( - lazy_gettext('Project Id'), - widget=BS3TextFieldWidget()) + lazy_gettext('Project Id'), widget=BS3TextFieldWidget() + ) extra__google_cloud_platform__key_path = StringField( - lazy_gettext('Keyfile Path'), - widget=BS3TextFieldWidget()) + lazy_gettext('Keyfile Path'), widget=BS3TextFieldWidget() + ) extra__google_cloud_platform__keyfile_dict = PasswordField( - lazy_gettext('Keyfile JSON'), - widget=BS3PasswordFieldWidget()) + lazy_gettext('Keyfile JSON'), widget=BS3PasswordFieldWidget() + ) extra__google_cloud_platform__scope = StringField( - lazy_gettext('Scopes (comma separated)'), - widget=BS3TextFieldWidget()) + lazy_gettext('Scopes (comma separated)'), widget=BS3TextFieldWidget() + ) extra__google_cloud_platform__num_retries = IntegerField( lazy_gettext('Number of Retries'), validators=[NumberRange(min=0)], widget=BS3TextFieldWidget(), - default=5) - extra__grpc__auth_type = StringField( - lazy_gettext('Grpc Auth Type'), - widget=BS3TextFieldWidget()) + default=5, + ) + extra__grpc__auth_type = StringField(lazy_gettext('Grpc Auth Type'), widget=BS3TextFieldWidget()) extra__grpc__credential_pem_file = StringField( - lazy_gettext('Credential Keyfile Path'), - widget=BS3TextFieldWidget()) - extra__grpc__scopes = StringField( - lazy_gettext('Scopes (comma separated)'), - widget=BS3TextFieldWidget()) + lazy_gettext('Credential Keyfile Path'), widget=BS3TextFieldWidget() + ) + extra__grpc__scopes = StringField(lazy_gettext('Scopes (comma separated)'), widget=BS3TextFieldWidget()) extra__yandexcloud__service_account_json = PasswordField( lazy_gettext('Service account auth JSON'), widget=BS3PasswordFieldWidget(), description='Service account auth JSON. Looks like ' - '{"id", "...", "service_account_id": "...", "private_key": "..."}. ' - 'Will be used instead of OAuth token and SA JSON file path field if specified.', + '{"id", "...", "service_account_id": "...", "private_key": "..."}. ' + 'Will be used instead of OAuth token and SA JSON file path field if specified.', ) extra__yandexcloud__service_account_json_path = StringField( lazy_gettext('Service account auth JSON file path'), widget=BS3TextFieldWidget(), description='Service account auth JSON file path. File content looks like ' - '{"id", "...", "service_account_id": "...", "private_key": "..."}. ' - 'Will be used instead of OAuth token if specified.', + '{"id", "...", "service_account_id": "...", "private_key": "..."}. ' + 'Will be used instead of OAuth token if specified.', ) extra__yandexcloud__oauth = PasswordField( lazy_gettext('OAuth Token'), @@ -308,14 +293,11 @@ class ConnectionForm(DynamicForm): description='Optional. 
This key will be placed to all created Compute nodes' 'to let you have a root shell there', ) - extra__kubernetes__in_cluster = BooleanField( - lazy_gettext('In cluster configuration')) + extra__kubernetes__in_cluster = BooleanField(lazy_gettext('In cluster configuration')) extra__kubernetes__kube_config_path = StringField( - lazy_gettext('Kube config path'), - widget=BS3TextFieldWidget()) + lazy_gettext('Kube config path'), widget=BS3TextFieldWidget() + ) extra__kubernetes__kube_config = StringField( - lazy_gettext('Kube config (JSON format)'), - widget=BS3TextFieldWidget()) - extra__kubernetes__namespace = StringField( - lazy_gettext('Namespace'), - widget=BS3TextFieldWidget()) + lazy_gettext('Kube config (JSON format)'), widget=BS3TextFieldWidget() + ) + extra__kubernetes__namespace = StringField(lazy_gettext('Namespace'), widget=BS3TextFieldWidget()) diff --git a/airflow/www/security.py b/airflow/www/security.py index be35ddbd65ef6..0654732597f7b 100644 --- a/airflow/www/security.py +++ b/airflow/www/security.py @@ -130,8 +130,14 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin): ROLE_CONFIGS = [ {'role': 'Viewer', 'perms': VIEWER_PERMISSIONS}, - {'role': 'User', 'perms': VIEWER_PERMISSIONS + USER_PERMISSIONS,}, - {'role': 'Op', 'perms': VIEWER_PERMISSIONS + USER_PERMISSIONS + OP_PERMISSIONS,}, + { + 'role': 'User', + 'perms': VIEWER_PERMISSIONS + USER_PERMISSIONS, + }, + { + 'role': 'Op', + 'perms': VIEWER_PERMISSIONS + USER_PERMISSIONS + OP_PERMISSIONS, + }, { 'role': 'Admin', 'perms': VIEWER_PERMISSIONS + USER_PERMISSIONS + OP_PERMISSIONS + ADMIN_PERMISSIONS, diff --git a/airflow/www/utils.py b/airflow/www/utils.py index 697f8806b1903..6095fb306ce5a 100644 --- a/airflow/www/utils.py +++ b/airflow/www/utils.py @@ -72,11 +72,7 @@ def get_params(**kwargs): return urlencode({d: v for d, v in kwargs.items() if v is not None}) -def generate_pages(current_page, - num_of_pages, - search=None, - status=None, - window=7): +def generate_pages(current_page, num_of_pages, search=None, status=None, window=7): """ Generates the HTML for a paging component using a similar logic to the paging auto-generated by Flask managed views. The paging component defines a number of @@ -97,43 +93,51 @@ def generate_pages(current_page, :return: the HTML string of the paging component """ void_link = 'javascript:void(0)' - first_node = Markup("""
  • + first_node = Markup( + """
  • « -
  • """) +""" + ) - previous_node = Markup("""""") +""" + ) - next_node = Markup("""""") +""" + ) - last_node = Markup("""
  • + last_node = Markup( + """
  • » -
  • """) +""" + ) - page_node = Markup("""
  • + page_node = Markup( + """
  • {page_num} -
  • """) +""" + ) output = [Markup('
      ')] is_disabled = 'disabled' if current_page <= 0 else '' - output.append(first_node.format(href_link="?{}" # noqa - .format(get_params(page=0, - search=search, - status=status)), - disabled=is_disabled)) + output.append( + first_node.format( + href_link="?{}".format(get_params(page=0, search=search, status=status)), # noqa + disabled=is_disabled, + ) + ) page_link = void_link if current_page > 0: - page_link = '?{}'.format(get_params(page=(current_page - 1), - search=search, - status=status)) + page_link = '?{}'.format(get_params(page=(current_page - 1), search=search, status=status)) - output.append(previous_node.format(href_link=page_link, # noqa - disabled=is_disabled)) + output.append(previous_node.format(href_link=page_link, disabled=is_disabled)) # noqa mid = int(window / 2) last_page = num_of_pages - 1 @@ -151,27 +155,28 @@ def is_current(current, page): # noqa for page in pages: vals = { 'is_active': 'active' if is_current(current_page, page) else '', - 'href_link': void_link if is_current(current_page, page) - else '?{}'.format(get_params(page=page, - search=search, - status=status)), - 'page_num': page + 1 + 'href_link': void_link + if is_current(current_page, page) + else '?{}'.format(get_params(page=page, search=search, status=status)), + 'page_num': page + 1, } output.append(page_node.format(**vals)) # noqa is_disabled = 'disabled' if current_page >= num_of_pages - 1 else '' - page_link = (void_link if current_page >= num_of_pages - 1 - else '?{}'.format(get_params(page=current_page + 1, - search=search, - status=status))) + page_link = ( + void_link + if current_page >= num_of_pages - 1 + else '?{}'.format(get_params(page=current_page + 1, search=search, status=status)) + ) output.append(next_node.format(href_link=page_link, disabled=is_disabled)) # noqa - output.append(last_node.format(href_link="?{}" # noqa - .format(get_params(page=last_page, - search=search, - status=status)), - disabled=is_disabled)) + output.append( + last_node.format( + href_link="?{}".format(get_params(page=last_page, search=search, status=status)), # noqa + disabled=is_disabled, + ) + ) output.append(Markup('
    ')) @@ -186,10 +191,8 @@ def epoch(dttm): def json_response(obj): """Returns a json response from a json serializable python object""" return Response( - response=json.dumps( - obj, indent=4, cls=AirflowJsonEncoder), - status=200, - mimetype="application/json") + response=json.dumps(obj, indent=4, cls=AirflowJsonEncoder), status=200, mimetype="application/json" + ) def make_cache_key(*args, **kwargs): @@ -204,16 +207,10 @@ def task_instance_link(attr): dag_id = attr.get('dag_id') task_id = attr.get('task_id') execution_date = attr.get('execution_date') - url = url_for( - 'Airflow.task', - dag_id=dag_id, - task_id=task_id, - execution_date=execution_date.isoformat()) + url = url_for('Airflow.task', dag_id=dag_id, task_id=task_id, execution_date=execution_date.isoformat()) url_root = url_for( - 'Airflow.graph', - dag_id=dag_id, - root=task_id, - execution_date=execution_date.isoformat()) + 'Airflow.graph', dag_id=dag_id, root=task_id, execution_date=execution_date.isoformat() + ) return Markup( # noqa """ @@ -223,7 +220,8 @@ def task_instance_link(attr): aria-hidden="true">filter_alt - """).format(url=url, task_id=task_id, url_root=url_root) + """ + ).format(url=url, task_id=task_id, url_root=url_root) def state_token(state): @@ -234,7 +232,8 @@ def state_token(state): """ {state} - """).format(color=color, state=state, fg_color=fg_color) + """ + ).format(color=color, state=state, fg_color=fg_color) def state_f(attr): @@ -245,14 +244,17 @@ def state_f(attr): def nobr_f(attr_name): """Returns a formatted string with HTML with a Non-breaking Text element""" + def nobr(attr): f = attr.get(attr_name) return Markup("{}").format(f) # noqa + return nobr def datetime_f(attr_name): """Returns a formatted string with HTML for given DataTime""" + def dt(attr): # pylint: disable=invalid-name f = attr.get(attr_name) as_iso = f.isoformat() if f else '' @@ -263,16 +265,21 @@ def dt(attr): # pylint: disable=invalid-name f = f[5:] # The empty title will be replaced in JS code when non-UTC dates are displayed return Markup('').format(as_iso, f) # noqa + return dt + + # pylint: enable=invalid-name def json_f(attr_name): """Returns a formatted string with HTML for given JSON serializable""" + def json_(attr): f = attr.get(attr_name) serialized = json.dumps(f) return Markup('{}').format(serialized) # noqa + return json_ @@ -280,12 +287,8 @@ def dag_link(attr): """Generates a URL to the Graph View for a Dag.""" dag_id = attr.get('dag_id') execution_date = attr.get('execution_date') - url = url_for( - 'Airflow.graph', - dag_id=dag_id, - execution_date=execution_date) - return Markup( # noqa - '{}').format(url, dag_id) # noqa + url = url_for('Airflow.graph', dag_id=dag_id, execution_date=execution_date) + return Markup('{}').format(url, dag_id) # noqa # noqa def dag_run_link(attr): @@ -293,13 +296,8 @@ def dag_run_link(attr): dag_id = attr.get('dag_id') run_id = attr.get('run_id') execution_date = attr.get('execution_date') - url = url_for( - 'Airflow.graph', - dag_id=dag_id, - run_id=run_id, - execution_date=execution_date) - return Markup( # noqa - '{run_id}').format(url=url, run_id=run_id) # noqa + url = url_for('Airflow.graph', dag_id=dag_id, run_id=run_id, execution_date=execution_date) + return Markup('{run_id}').format(url=url, run_id=run_id) # noqa # noqa def pygment_html_render(s, lexer=lexers.TextLexer): # noqa pylint: disable=no-member @@ -328,9 +326,7 @@ def wrapped_markdown(s, css_class=None): if s is None: return None - return Markup( - f'
    ' + markdown.markdown(s) + "
    " - ) + return Markup(f'
    ' + markdown.markdown(s) + "
    ") # pylint: disable=no-member @@ -353,6 +349,8 @@ def get_attr_renderer(): 'rst': lambda x: render(x, lexers.RstLexer), 'yaml': lambda x: render(x, lexers.YamlLexer), } + + # pylint: enable=no-member @@ -397,12 +395,11 @@ class UtcAwareFilterConverter(fab_sqlafilters.SQLAFilterConverter): # noqa: D10 """Retrieve conversion tables for UTC-Aware filters.""" conversion_table = ( - (('is_utcdatetime', [UtcAwareFilterEqual, - UtcAwareFilterGreater, - UtcAwareFilterSmaller, - UtcAwareFilterNotEqual]),) + - fab_sqlafilters.SQLAFilterConverter.conversion_table - ) + ( + 'is_utcdatetime', + [UtcAwareFilterEqual, UtcAwareFilterGreater, UtcAwareFilterSmaller, UtcAwareFilterNotEqual], + ), + ) + fab_sqlafilters.SQLAFilterConverter.conversion_table class CustomSQLAInterface(SQLAInterface): @@ -418,11 +415,9 @@ def __init__(self, obj, session=None): def clean_column_names(): if self.list_properties: - self.list_properties = { - k.lstrip('_'): v for k, v in self.list_properties.items()} + self.list_properties = {k.lstrip('_'): v for k, v in self.list_properties.items()} if self.list_columns: - self.list_columns = { - k.lstrip('_'): v for k, v in self.list_columns.items()} + self.list_columns = {k.lstrip('_'): v for k, v in self.list_columns.items()} clean_column_names() @@ -432,9 +427,11 @@ def is_utcdatetime(self, col_name): if col_name in self.list_columns: obj = self.list_columns[col_name].type - return isinstance(obj, UtcDateTime) or \ - isinstance(obj, sqla.types.TypeDecorator) and \ - isinstance(obj.impl, UtcDateTime) + return ( + isinstance(obj, UtcDateTime) + or isinstance(obj, sqla.types.TypeDecorator) + and isinstance(obj.impl, UtcDateTime) + ) return False filter_converter_class = UtcAwareFilterConverter @@ -444,6 +441,5 @@ def is_utcdatetime(self, col_name): # subclass) so we have no other option than to edit the conversion table in # place FieldConverter.conversion_table = ( - (('is_utcdatetime', DateTimeWithTimezoneField, AirflowDateTimePickerWidget),) + - FieldConverter.conversion_table -) + ('is_utcdatetime', DateTimeWithTimezoneField, AirflowDateTimePickerWidget), +) + FieldConverter.conversion_table diff --git a/airflow/www/validators.py b/airflow/www/validators.py index 282789c15dbb0..9699a47e46b3e 100644 --- a/airflow/www/validators.py +++ b/airflow/www/validators.py @@ -37,17 +37,14 @@ def __call__(self, form, field): try: other = form[self.fieldname] except KeyError: - raise ValidationError( - field.gettext("Invalid field name '%s'." % self.fieldname) - ) + raise ValidationError(field.gettext("Invalid field name '%s'." 
% self.fieldname)) if field.data is None or other.data is None: return if field.data < other.data: message_args = { - 'other_label': - hasattr(other, 'label') and other.label.text or self.fieldname, + 'other_label': hasattr(other, 'label') and other.label.text or self.fieldname, 'other_name': self.fieldname, } message = self.message @@ -77,6 +74,4 @@ def __call__(self, form, field): json.loads(field.data) except JSONDecodeError as ex: message = self.message or f'JSON Validation Error: {ex}' - raise ValidationError( - message=field.gettext(message.format(field.data)) - ) + raise ValidationError(message=field.gettext(message.format(field.data))) diff --git a/airflow/www/views.py b/airflow/www/views.py index 3d16e77924bd0..fe818f26fb8f3 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -34,8 +34,19 @@ import nvd3 import sqlalchemy as sqla from flask import ( - Markup, Response, current_app, escape, flash, g, jsonify, make_response, redirect, render_template, - request, session as flask_session, url_for, + Markup, + Response, + current_app, + escape, + flash, + g, + jsonify, + make_response, + redirect, + render_template, + request, + session as flask_session, + url_for, ) from flask_appbuilder import BaseView, ModelView, expose from flask_appbuilder.actions import action @@ -51,7 +62,8 @@ import airflow from airflow import models, plugins_manager, settings from airflow.api.common.experimental.mark_tasks import ( - set_dag_run_state_to_failed, set_dag_run_state_to_success, + set_dag_run_state_to_failed, + set_dag_run_state_to_success, ) from airflow.configuration import AIRFLOW_CONFIG, conf from airflow.exceptions import AirflowException @@ -76,7 +88,11 @@ from airflow.www import auth, utils as wwwutils from airflow.www.decorators import action_logging, gzipped from airflow.www.forms import ( - ConnectionForm, DagRunForm, DateTimeForm, DateTimeWithNumRunsForm, DateTimeWithNumRunsWithDagRunsForm, + ConnectionForm, + DagRunForm, + DateTimeForm, + DateTimeWithNumRunsForm, + DateTimeWithNumRunsWithDagRunsForm, ) from airflow.www.widgets import AirflowModelListWidget @@ -120,9 +136,7 @@ def get_date_time_num_runs_dag_runs_form_data(www_request, session, dag): drs = ( session.query(DagRun) - .filter( - DagRun.dag_id == dag.dag_id, - DagRun.execution_date <= base_date) + .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date <= base_date) .order_by(desc(DagRun.execution_date)) .limit(num_runs) .all() @@ -164,34 +178,39 @@ def task_group_to_dict(task_group): 'style': f"fill:{task_group.ui_color};", 'rx': 5, 'ry': 5, - } + }, } - children = [task_group_to_dict(child) for child in - sorted(task_group.children.values(), key=lambda t: t.label)] + children = [ + task_group_to_dict(child) for child in sorted(task_group.children.values(), key=lambda t: t.label) + ] if task_group.upstream_group_ids or task_group.upstream_task_ids: - children.append({ - 'id': task_group.upstream_join_id, - 'value': { - 'label': '', - 'labelStyle': f"fill:{task_group.ui_fgcolor};", - 'style': f"fill:{task_group.ui_color};", - 'shape': 'circle', + children.append( + { + 'id': task_group.upstream_join_id, + 'value': { + 'label': '', + 'labelStyle': f"fill:{task_group.ui_fgcolor};", + 'style': f"fill:{task_group.ui_color};", + 'shape': 'circle', + }, } - }) + ) if task_group.downstream_group_ids or task_group.downstream_task_ids: # This is the join node used to reduce the number of edges between two TaskGroup. 
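# The join-node comment above deserves a quick illustration: if every task in one TaskGroup
# pointed directly at every task in the next, the graph view would draw n * m edges; routing
# them through a single proxy node cuts that to n + m. A minimal, self-contained sketch of
# that idea follows — the helper names and example task ids are illustrative only and are not
# part of this patch.
def direct_edges(upstream_ids, downstream_ids):
    # Every upstream task points at every downstream task: n * m edges.
    return [(u, d) for u in upstream_ids for d in downstream_ids]


def edges_via_join_node(upstream_ids, downstream_ids, join_id):
    # Route everything through one proxy node instead: n + m edges.
    return [(u, join_id) for u in upstream_ids] + [(join_id, d) for d in downstream_ids]


ups = [f"group_a.task_{i}" for i in range(4)]
downs = [f"group_b.task_{i}" for i in range(5)]
assert len(direct_edges(ups, downs)) == 20
assert len(edges_via_join_node(ups, downs, "group_a.downstream_join_id")) == 9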
- children.append({ - 'id': task_group.downstream_join_id, - 'value': { - 'label': '', - 'labelStyle': f"fill:{task_group.ui_fgcolor};", - 'style': f"fill:{task_group.ui_color};", - 'shape': 'circle', + children.append( + { + 'id': task_group.downstream_join_id, + 'value': { + 'label': '', + 'labelStyle': f"fill:{task_group.ui_fgcolor};", + 'style': f"fill:{task_group.ui_color};", + 'shape': 'circle', + }, } - }) + ) return { "id": task_group.group_id, @@ -204,7 +223,7 @@ def task_group_to_dict(task_group): 'clusterLabelPos': 'top', }, 'tooltip': task_group.tooltip, - 'children': children + 'children': children, } @@ -299,39 +318,47 @@ def get_downstream(task): for root in dag.roots: get_downstream(root) - return [{'source_id': source_id, 'target_id': target_id} - for source_id, target_id - in sorted(edges.union(edges_to_add) - edges_to_skip)] + return [ + {'source_id': source_id, 'target_id': target_id} + for source_id, target_id in sorted(edges.union(edges_to_add) - edges_to_skip) + ] ###################################################################################### # Error handlers ###################################################################################### + def circles(error): # pylint: disable=unused-argument """Show Circles on screen for any error in the Webserver""" - return render_template( - 'airflow/circles.html', hostname=socket.getfqdn() if conf.getboolean( # noqa - 'webserver', - 'EXPOSE_HOSTNAME', - fallback=True) else 'redact'), 404 + return ( + render_template( + 'airflow/circles.html', + hostname=socket.getfqdn() + if conf.getboolean('webserver', 'EXPOSE_HOSTNAME', fallback=True) # noqa + else 'redact', + ), + 404, + ) def show_traceback(error): # pylint: disable=unused-argument """Show Traceback for a given error""" - return render_template( - 'airflow/traceback.html', # noqa - python_version=sys.version.split(" ")[0], - airflow_version=version, - hostname=socket.getfqdn() if conf.getboolean( - 'webserver', - 'EXPOSE_HOSTNAME', - fallback=True) else 'redact', - info=traceback.format_exc() if conf.getboolean( - 'webserver', - 'EXPOSE_STACKTRACE', - fallback=True) else 'Error! Please contact server admin.' - ), 500 + return ( + render_template( + 'airflow/traceback.html', # noqa + python_version=sys.version.split(" ")[0], + airflow_version=version, + hostname=socket.getfqdn() + if conf.getboolean('webserver', 'EXPOSE_HOSTNAME', fallback=True) + else 'redact', + info=traceback.format_exc() + if conf.getboolean('webserver', 'EXPOSE_STACKTRACE', fallback=True) + else 'Error! Please contact server admin.', + ), + 500, + ) + ###################################################################################### # BaseViews @@ -342,6 +369,7 @@ class AirflowBaseView(BaseView): # noqa: D101 """Base View to set Airflow related properties""" from airflow import macros + route_base = '' # Make our macros available to our UI templates too. @@ -354,7 +382,7 @@ def render_template(self, *args, **kwargs): *args, # Cache this at most once per request, not for the lifetime of the view instance scheduler_job=lazy_object_proxy.Proxy(SchedulerJob.most_recent_job), - **kwargs + **kwargs, ) @@ -367,9 +395,7 @@ def health(self): An endpoint helping check the health status of the Airflow instance, including metadatabase and scheduler. 
""" - payload = { - 'metadatabase': {'status': 'unhealthy'} - } + payload = {'metadatabase': {'status': 'unhealthy'}} latest_scheduler_heartbeat = None scheduler_status = 'unhealthy' @@ -384,19 +410,22 @@ def health(self): except Exception: # noqa pylint: disable=broad-except payload['metadatabase']['status'] = 'unhealthy' - payload['scheduler'] = {'status': scheduler_status, - 'latest_scheduler_heartbeat': latest_scheduler_heartbeat} + payload['scheduler'] = { + 'status': scheduler_status, + 'latest_scheduler_heartbeat': latest_scheduler_heartbeat, + } return wwwutils.json_response(payload) @expose('/home') - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - ]) # pylint: disable=too-many-locals,too-many-statements + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), + ] + ) # pylint: disable=too-many-locals,too-many-statements def index(self): """Home view.""" - hide_paused_dags_by_default = conf.getboolean('webserver', - 'hide_paused_dags_by_default') + hide_paused_dags_by_default = conf.getboolean('webserver', 'hide_paused_dags_by_default') default_dag_run = conf.getint('webserver', 'default_dag_run_display_number') num_runs = request.args.get('num_runs') @@ -448,15 +477,13 @@ def get_int_arg(value, default=0): with create_session() as session: # read orm_dags from the db - dags_query = session.query(DagModel).filter( - ~DagModel.is_subdag, DagModel.is_active - ) + dags_query = session.query(DagModel).filter(~DagModel.is_subdag, DagModel.is_active) # pylint: disable=no-member if arg_search_query: dags_query = dags_query.filter( - DagModel.dag_id.ilike('%' + arg_search_query + '%') | # noqa - DagModel.owners.ilike('%' + arg_search_query + '%') # noqa + DagModel.dag_id.ilike('%' + arg_search_query + '%') + | DagModel.owners.ilike('%' + arg_search_query + '%') # noqa # noqa ) if arg_tags_filter: @@ -472,7 +499,8 @@ def get_int_arg(value, default=0): is_paused_count = dict( all_dags.with_entities(DagModel.is_paused, func.count(DagModel.dag_id)) - .group_by(DagModel.is_paused).all() + .group_by(DagModel.is_paused) + .all() ) status_count_active = is_paused_count.get(False, 0) status_count_paused = is_paused_count.get(True, 0) @@ -487,8 +515,13 @@ def get_int_arg(value, default=0): current_dags = all_dags num_of_all_dags = all_dags_count - dags = current_dags.order_by(DagModel.dag_id).options( - joinedload(DagModel.tags)).offset(start).limit(dags_per_page).all() + dags = ( + current_dags.order_by(DagModel.dag_id) + .options(joinedload(DagModel.tags)) + .offset(start) + .limit(dags_per_page) + .all() + ) dagtags = session.query(DagTag.name).distinct(DagTag.name).all() tags = [ @@ -499,17 +532,15 @@ def get_int_arg(value, default=0): import_errors = session.query(errors.ImportError).all() for import_error in import_errors: - flash( - "Broken DAG: [{ie.filename}] {ie.stacktrace}".format(ie=import_error), - "dag_import_error") + flash("Broken DAG: [{ie.filename}] {ie.stacktrace}".format(ie=import_error), "dag_import_error") from airflow.plugins_manager import import_errors as plugin_import_errors + for filename, stacktrace in plugin_import_errors.items(): flash( - "Broken plugin: [{filename}] {stacktrace}".format( - stacktrace=stacktrace, - filename=filename), - "error") + f"Broken plugin: [{filename}] {stacktrace}", + "error", + ) num_of_pages = int(math.ceil(num_of_all_dags / float(dags_per_page))) @@ -526,10 +557,12 @@ def get_int_arg(value, default=0): num_dag_from=min(start + 1, num_of_all_dags), num_dag_to=min(end, 
num_of_all_dags), num_of_all_dags=num_of_all_dags, - paging=wwwutils.generate_pages(current_page, - num_of_pages, - search=escape(arg_search_query) if arg_search_query else None, - status=arg_status_filter if arg_status_filter else None), + paging=wwwutils.generate_pages( + current_page, + num_of_pages, + search=escape(arg_search_query) if arg_search_query else None, + status=arg_status_filter if arg_status_filter else None, + ), num_runs=num_runs, tags=tags, state_color=state_color_mapping, @@ -537,7 +570,8 @@ def get_int_arg(value, default=0): status_count_all=all_dags_count, status_count_active=status_count_active, status_count_paused=status_count_paused, - tags_filter=arg_tags_filter) + tags_filter=arg_tags_filter, + ) @expose('/dag_stats', methods=['POST']) @auth.has_access( @@ -555,13 +589,12 @@ def dag_stats(self, session=None): if permissions.RESOURCE_DAG in allowed_dag_ids: allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)] - dag_state_stats = session.query(dr.dag_id, dr.state, sqla.func.count(dr.state))\ - .group_by(dr.dag_id, dr.state) + dag_state_stats = session.query(dr.dag_id, dr.state, sqla.func.count(dr.state)).group_by( + dr.dag_id, dr.state + ) # Filter by post parameters - selected_dag_ids = { - unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id - } + selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id} if selected_dag_ids: filter_dag_ids = selected_dag_ids.intersection(allowed_dag_ids) @@ -584,10 +617,7 @@ def dag_stats(self, session=None): payload[dag_id] = [] for state in State.dag_states: count = data.get(dag_id, {}).get(state, 0) - payload[dag_id].append({ - 'state': state, - 'count': count - }) + payload[dag_id].append({'state': state, 'count': count}) return wwwutils.json_response(payload) @@ -611,9 +641,7 @@ def task_stats(self, session=None): allowed_dag_ids = {dag_id for dag_id, in session.query(models.DagModel.dag_id)} # Filter by post parameters - selected_dag_ids = { - unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id - } + selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id} if selected_dag_ids: filter_dag_ids = selected_dag_ids.intersection(allowed_dag_ids) @@ -623,39 +651,41 @@ def task_stats(self, session=None): # pylint: disable=comparison-with-callable running_dag_run_query_result = ( session.query(DagRun.dag_id, DagRun.execution_date) - .join(DagModel, DagModel.dag_id == DagRun.dag_id) - .filter(DagRun.state == State.RUNNING, DagModel.is_active) + .join(DagModel, DagModel.dag_id == DagRun.dag_id) + .filter(DagRun.state == State.RUNNING, DagModel.is_active) ) # pylint: enable=comparison-with-callable # pylint: disable=no-member if selected_dag_ids: - running_dag_run_query_result = \ - running_dag_run_query_result.filter(DagRun.dag_id.in_(filter_dag_ids)) + running_dag_run_query_result = running_dag_run_query_result.filter( + DagRun.dag_id.in_(filter_dag_ids) + ) # pylint: enable=no-member running_dag_run_query_result = running_dag_run_query_result.subquery('running_dag_run') # pylint: disable=no-member # Select all task_instances from active dag_runs. 
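# task_stats first narrows task instances to those belonging to running DAG runs (optionally
# adding the most recent completed run), then folds the resulting per-(dag_id, state) counts
# into one {'state': ..., 'count': ...} entry per state and per selected DAG. A plain-Python
# sketch of that final fold, with made-up sample rows and a shortened state list standing in
# for Airflow's State.task_states:
from collections import Counter

TASK_STATES = ["success", "running", "failed", "queued"]  # stand-in for State.task_states


def task_stats_payload(rows, filter_dag_ids):
    counts = Counter(rows)  # maps (dag_id, state) -> number of task instances
    return {
        dag_id: [{"state": state, "count": counts.get((dag_id, state), 0)} for state in TASK_STATES]
        for dag_id in filter_dag_ids
    }


rows = [("example_dag", "success"), ("example_dag", "success"), ("example_dag", "failed")]
assert task_stats_payload(rows, {"example_dag"})["example_dag"][0] == {"state": "success", "count": 2}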
- running_task_instance_query_result = ( - session.query(TaskInstance.dag_id.label('dag_id'), TaskInstance.state.label('state')) - .join(running_dag_run_query_result, - and_(running_dag_run_query_result.c.dag_id == TaskInstance.dag_id, - running_dag_run_query_result.c.execution_date == TaskInstance.execution_date)) + running_task_instance_query_result = session.query( + TaskInstance.dag_id.label('dag_id'), TaskInstance.state.label('state') + ).join( + running_dag_run_query_result, + and_( + running_dag_run_query_result.c.dag_id == TaskInstance.dag_id, + running_dag_run_query_result.c.execution_date == TaskInstance.execution_date, + ), ) if selected_dag_ids: - running_task_instance_query_result = \ - running_task_instance_query_result.filter(TaskInstance.dag_id.in_(filter_dag_ids)) + running_task_instance_query_result = running_task_instance_query_result.filter( + TaskInstance.dag_id.in_(filter_dag_ids) + ) # pylint: enable=no-member if conf.getboolean('webserver', 'SHOW_RECENT_STATS_FOR_COMPLETED_RUNS', fallback=True): # pylint: disable=comparison-with-callable last_dag_run = ( - session.query( - DagRun.dag_id, - sqla.func.max(DagRun.execution_date).label('execution_date') - ) + session.query(DagRun.dag_id, sqla.func.max(DagRun.execution_date).label('execution_date')) .join(DagModel, DagModel.dag_id == DagRun.dag_id) .filter(DagRun.state != State.RUNNING, DagModel.is_active) .group_by(DagRun.dag_id) @@ -669,30 +699,33 @@ def task_stats(self, session=None): # Select all task_instances from active dag_runs. # If no dag_run is active, return task instances from most recent dag_run. - last_task_instance_query_result = ( - session.query(TaskInstance.dag_id.label('dag_id'), TaskInstance.state.label('state')) - .join(last_dag_run, - and_(last_dag_run.c.dag_id == TaskInstance.dag_id, - last_dag_run.c.execution_date == TaskInstance.execution_date)) + last_task_instance_query_result = session.query( + TaskInstance.dag_id.label('dag_id'), TaskInstance.state.label('state') + ).join( + last_dag_run, + and_( + last_dag_run.c.dag_id == TaskInstance.dag_id, + last_dag_run.c.execution_date == TaskInstance.execution_date, + ), ) # pylint: disable=no-member if selected_dag_ids: - last_task_instance_query_result = \ - last_task_instance_query_result.filter(TaskInstance.dag_id.in_(filter_dag_ids)) + last_task_instance_query_result = last_task_instance_query_result.filter( + TaskInstance.dag_id.in_(filter_dag_ids) + ) # pylint: enable=no-member final_task_instance_query_result = union_all( - last_task_instance_query_result, - running_task_instance_query_result).alias('final_ti') + last_task_instance_query_result, running_task_instance_query_result + ).alias('final_ti') else: final_task_instance_query_result = running_task_instance_query_result.subquery('final_ti') - qry = ( - session.query(final_task_instance_query_result.c.dag_id, - final_task_instance_query_result.c.state, sqla.func.count()) - .group_by(final_task_instance_query_result.c.dag_id, - final_task_instance_query_result.c.state) - ) + qry = session.query( + final_task_instance_query_result.c.dag_id, + final_task_instance_query_result.c.state, + sqla.func.count(), + ).group_by(final_task_instance_query_result.c.dag_id, final_task_instance_query_result.c.state) data = {} for dag_id, state, count in qry: @@ -705,10 +738,7 @@ def task_stats(self, session=None): payload[dag_id] = [] for state in State.task_states: count = data.get(dag_id, {}).get(state, 0) - payload[dag_id].append({ - 'state': state, - 'count': count - }) + 
payload[dag_id].append({'state': state, 'count': count}) return wwwutils.json_response(payload) @expose('/last_dagruns', methods=['POST']) @@ -727,9 +757,7 @@ def last_dagruns(self, session=None): allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)] # Filter by post parameters - selected_dag_ids = { - unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id - } + selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id} if selected_dag_ids: filter_dag_ids = selected_dag_ids.intersection(allowed_dag_ids) @@ -753,7 +781,8 @@ def last_dagruns(self, session=None): 'dag_id': r.dag_id, 'execution_date': r.execution_date.isoformat(), 'start_date': r.start_date.isoformat(), - } for r in query + } + for r in query } return wwwutils.json_response(resp) @@ -775,22 +804,30 @@ def code(self, session=None): dag_id = request.args.get('dag_id') dag_orm = DagModel.get_dagmodel(dag_id, session=session) code = DagCode.get_code_by_fileloc(dag_orm.fileloc) - html_code = Markup(highlight( - code, lexers.PythonLexer(), HtmlFormatter(linenos=True))) # pylint: disable=no-member + html_code = Markup( + highlight( + code, lexers.PythonLexer(), HtmlFormatter(linenos=True) # pylint: disable=no-member + ) + ) except Exception as e: # pylint: disable=broad-except all_errors += ( - "Exception encountered during " + - f"dag_id retrieval/dag retrieval fallback/code highlighting:\n\n{e}\n" + "Exception encountered during " + + f"dag_id retrieval/dag retrieval fallback/code highlighting:\n\n{e}\n" ) html_code = Markup('

<p>Failed to load file.</p><p>Details: {}</p>

    ').format( # noqa - escape(all_errors)) + escape(all_errors) + ) return self.render_template( - 'airflow/dag_code.html', html_code=html_code, dag=dag_orm, title=dag_id, + 'airflow/dag_code.html', + html_code=html_code, + dag=dag_orm, + title=dag_id, root=request.args.get('root'), demo_mode=conf.getboolean('webserver', 'demo_mode'), - wrapped=conf.getboolean('webserver', 'default_wrap')) + wrapped=conf.getboolean('webserver', 'default_wrap'), + ) @expose('/dag_details') @auth.has_access( @@ -809,19 +846,14 @@ def dag_details(self, session=None): states = ( session.query(TaskInstance.state, sqla.func.count(TaskInstance.dag_id)) - .filter(TaskInstance.dag_id == dag_id) - .group_by(TaskInstance.state) - .all() + .filter(TaskInstance.dag_id == dag_id) + .group_by(TaskInstance.state) + .all() ) - active_runs = models.DagRun.find( - dag_id=dag_id, - state=State.RUNNING, - external_trigger=False - ) + active_runs = models.DagRun.find(dag_id=dag_id, state=State.RUNNING, external_trigger=False) - tags = session.query(models.DagTag).filter( - models.DagTag.dag_id == dag_id).all() + tags = session.query(models.DagTag).filter(models.DagTag.dag_id == dag_id).all() return self.render_template( 'airflow/dag_details.html', @@ -831,7 +863,7 @@ def dag_details(self, session=None): states=states, State=State, active_runs=active_runs, - tags=tags + tags=tags, ) @expose('/rendered') @@ -877,8 +909,9 @@ def rendered(self): content = json.dumps(content, sort_keys=True, indent=4) html_dict[template_field] = renderers[renderer](content) else: - html_dict[template_field] = \ - Markup("
<pre><code>{}</code></pre>
    ").format(pformat(content)) # noqa + html_dict[template_field] = Markup("
<pre><code>{}</code></pre>
    ").format( + pformat(content) + ) # noqa return self.render_template( 'airflow/ti_code.html', @@ -888,14 +921,17 @@ def rendered(self): execution_date=execution_date, form=form, root=root, - title=title) + title=title, + ) @expose('/get_logs_with_metadata') - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), + ] + ) @action_logging @provide_session def get_logs_with_metadata(self, session=None): @@ -921,8 +957,8 @@ def get_logs_with_metadata(self, session=None): except ValueError: error_message = ( 'Given execution date, {}, could not be identified ' - 'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format( - execution_date)) + 'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format(execution_date) + ) response = jsonify({'error': error_message}) response.status_code = 400 @@ -933,23 +969,24 @@ def get_logs_with_metadata(self, session=None): return jsonify( message="Task log handler does not support read logs.", error=True, - metadata={ - "end_of_log": True - } + metadata={"end_of_log": True}, ) - ti = session.query(models.TaskInstance).filter( - models.TaskInstance.dag_id == dag_id, - models.TaskInstance.task_id == task_id, - models.TaskInstance.execution_date == execution_date).first() + ti = ( + session.query(models.TaskInstance) + .filter( + models.TaskInstance.dag_id == dag_id, + models.TaskInstance.task_id == task_id, + models.TaskInstance.execution_date == execution_date, + ) + .first() + ) if ti is None: return jsonify( message="*** Task instance did not exist in the DB\n", error=True, - metadata={ - "end_of_log": True - } + metadata={"end_of_log": True}, ) try: @@ -968,13 +1005,10 @@ def get_logs_with_metadata(self, session=None): return Response( response=log_stream, mimetype="text/plain", - headers={ - "Content-Disposition": f"attachment; filename={attachment_filename}" - }) + headers={"Content-Disposition": f"attachment; filename={attachment_filename}"}, + ) except AttributeError as e: - error_message = [ - f"Task log handler does not support read logs.\n{str(e)}\n" - ] + error_message = [f"Task log handler does not support read logs.\n{str(e)}\n"] metadata['end_of_log'] = True return jsonify(message=error_message, error=True, metadata=metadata) @@ -997,10 +1031,15 @@ def log(self, session=None): form = DateTimeForm(data={'execution_date': dttm}) dag_model = DagModel.get_dagmodel(dag_id) - ti = session.query(models.TaskInstance).filter( - models.TaskInstance.dag_id == dag_id, - models.TaskInstance.task_id == task_id, - models.TaskInstance.execution_date == dttm).first() + ti = ( + session.query(models.TaskInstance) + .filter( + models.TaskInstance.dag_id == dag_id, + models.TaskInstance.task_id == task_id, + models.TaskInstance.execution_date == dttm, + ) + .first() + ) num_logs = 0 if ti is not None: @@ -1012,10 +1051,16 @@ def log(self, session=None): root = request.args.get('root', '') return self.render_template( 'airflow/ti_log.html', - logs=logs, dag=dag_model, title="Log by attempts", - dag_id=dag_id, task_id=task_id, - execution_date=execution_date, form=form, - root=root, wrapped=conf.getboolean('webserver', 'default_wrap')) + logs=logs, + dag=dag_model, + title="Log 
by attempts", + dag_id=dag_id, + task_id=task_id, + execution_date=execution_date, + form=form, + root=root, + wrapped=conf.getboolean('webserver', 'default_wrap'), + ) @expose('/redirect_to_external_log') @auth.has_access( @@ -1035,10 +1080,15 @@ def redirect_to_external_log(self, session=None): dttm = timezone.parse(execution_date) try_number = request.args.get('try_number', 1) - ti = session.query(models.TaskInstance).filter( - models.TaskInstance.dag_id == dag_id, - models.TaskInstance.task_id == task_id, - models.TaskInstance.execution_date == dttm).first() + ti = ( + session.query(models.TaskInstance) + .filter( + models.TaskInstance.dag_id == dag_id, + models.TaskInstance.task_id == task_id, + models.TaskInstance.execution_date == dttm, + ) + .first() + ) if not ti: flash(f"Task [{dag_id}.{task_id}] does not exist", "error") @@ -1074,10 +1124,7 @@ def task(self): dag = current_app.dag_bag.get_dag(dag_id) if not dag or task_id not in dag.task_ids: - flash( - "Task [{}.{}] doesn't seem to exist" - " at the moment".format(dag_id, task_id), - "error") + flash("Task [{}.{}] doesn't seem to exist" " at the moment".format(dag_id, task_id), "error") return redirect(url_for('Airflow.index')) task = copy.copy(dag.get_task(task_id)) task.resolve_template_files() @@ -1096,8 +1143,7 @@ def task(self): if not attr_name.startswith('_'): attr = getattr(task, attr_name) # pylint: disable=unidiomatic-typecheck - if type(attr) != type(self.task) and \ - attr_name not in wwwutils.get_attr_renderer(): # noqa + if type(attr) != type(self.task) and attr_name not in wwwutils.get_attr_renderer(): # noqa task_attrs.append((attr_name, str(attr))) # pylint: enable=unidiomatic-typecheck @@ -1106,24 +1152,29 @@ def task(self): for attr_name in wwwutils.get_attr_renderer(): if hasattr(task, attr_name): source = getattr(task, attr_name) - special_attrs_rendered[attr_name] = \ - wwwutils.get_attr_renderer()[attr_name](source) - - no_failed_deps_result = [( - "Unknown", - "All dependencies are met but the task instance is not running. In most " - "cases this just means that the task will probably be scheduled soon " - "unless:
<br/>\n- The scheduler is down or under heavy load<br/>
    \n{}\n" - "
<br/>\nIf this task instance does not start soon please contact your " - "Airflow administrator for assistance.".format( - "- This task instance already ran and had it's state changed manually " - "(e.g. cleared in the UI)<br/>
    " if ti.state == State.NONE else ""))] + special_attrs_rendered[attr_name] = wwwutils.get_attr_renderer()[attr_name](source) + + no_failed_deps_result = [ + ( + "Unknown", + "All dependencies are met but the task instance is not running. In most " + "cases this just means that the task will probably be scheduled soon " + "unless:
<br/>\n- The scheduler is down or under heavy load<br/>
    \n{}\n" + "
<br/>\nIf this task instance does not start soon please contact your " + "Airflow administrator for assistance.".format( + "- This task instance already ran and had it's state changed manually " + "(e.g. cleared in the UI)<br/>
    " + if ti.state == State.NONE + else "" + ), + ) + ] # Use the scheduler's context to figure out which dependencies are not met dep_context = DepContext(SCHEDULER_QUEUED_DEPS) - failed_dep_reasons = [(dep.dep_name, dep.reason) for dep in - ti.get_failed_dep_statuses( - dep_context=dep_context)] + failed_dep_reasons = [ + (dep.dep_name, dep.reason) for dep in ti.get_failed_dep_statuses(dep_context=dep_context) + ] title = "Task Instance Details" return self.render_template( @@ -1136,7 +1187,9 @@ def task(self): special_attrs_rendered=special_attrs_rendered, form=form, root=root, - dag=dag, title=title) + dag=dag, + title=title, + ) @expose('/xcom') @auth.has_access( @@ -1164,15 +1217,14 @@ def xcom(self, session=None): ti = session.query(ti_db).filter(ti_db.dag_id == dag_id and ti_db.task_id == task_id).first() if not ti: - flash( - "Task [{}.{}] doesn't seem to exist" - " at the moment".format(dag_id, task_id), - "error") + flash("Task [{}.{}] doesn't seem to exist" " at the moment".format(dag_id, task_id), "error") return redirect(url_for('Airflow.index')) - xcomlist = session.query(XCom).filter( - XCom.dag_id == dag_id, XCom.task_id == task_id, - XCom.execution_date == dttm).all() + xcomlist = ( + session.query(XCom) + .filter(XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.execution_date == dttm) + .all() + ) attributes = [] for xcom in xcomlist: @@ -1187,7 +1239,9 @@ def xcom(self, session=None): execution_date=execution_date, form=form, root=root, - dag=dag, title=title) + dag=dag, + title=title, + ) @expose('/run', methods=['POST']) @auth.has_access( @@ -1217,12 +1271,14 @@ def run(self): try: from airflow.executors.celery_executor import CeleryExecutor # noqa + valid_celery_config = isinstance(executor, CeleryExecutor) except ImportError: pass try: from airflow.executors.kubernetes_executor import KubernetesExecutor # noqa + valid_kubernetes_config = isinstance(executor, KubernetesExecutor) except ImportError: pass @@ -1239,14 +1295,16 @@ def run(self): deps=RUNNING_DEPS, ignore_all_deps=ignore_all_deps, ignore_task_deps=ignore_task_deps, - ignore_ti_state=ignore_ti_state) + ignore_ti_state=ignore_ti_state, + ) failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context)) if failed_deps: - failed_deps_str = ", ".join( - [f"{dep.dep_name}: {dep.reason}" for dep in failed_deps]) - flash("Could not queue task instance for execution, dependencies not met: " - "{}".format(failed_deps_str), - "error") + failed_deps_str = ", ".join([f"{dep.dep_name}: {dep.reason}" for dep in failed_deps]) + flash( + "Could not queue task instance for execution, dependencies not met: " + "{}".format(failed_deps_str), + "error", + ) return redirect(origin) executor.start() @@ -1254,11 +1312,10 @@ def run(self): ti, ignore_all_deps=ignore_all_deps, ignore_task_deps=ignore_task_deps, - ignore_ti_state=ignore_ti_state) + ignore_ti_state=ignore_ti_state, + ) executor.heartbeat() - flash( - "Sent {} to the message queue, " - "it should start any moment now.".format(ti)) + flash("Sent {} to the message queue, " "it should start any moment now.".format(ti)) return redirect(origin) @expose('/delete', methods=['POST']) @@ -1282,13 +1339,10 @@ def delete(self): flash(f"DAG with id {dag_id} not found. Cannot delete", 'error') return redirect(request.referrer) except DagFileExists: - flash("Dag id {} is still in DagBag. " - "Remove the DAG file first.".format(dag_id), - 'error') + flash("Dag id {} is still in DagBag. 
" "Remove the DAG file first.".format(dag_id), 'error') return redirect(request.referrer) - flash("Deleting DAG with id {}. May take a couple minutes to fully" - " disappear.".format(dag_id)) + flash("Deleting DAG with id {}. May take a couple minutes to fully" " disappear.".format(dag_id)) # Upon success return to origin. return redirect(origin) @@ -1320,10 +1374,7 @@ def trigger(self, session=None): except TypeError: flash("Could not pre-populate conf field due to non-JSON-serializable data-types") return self.render_template( - 'airflow/trigger.html', - dag_id=dag_id, - origin=origin, - conf=default_conf + 'airflow/trigger.html', dag_id=dag_id, origin=origin, conf=default_conf ) dag_orm = session.query(models.DagModel).filter(models.DagModel.dag_id == dag_id).first() @@ -1345,10 +1396,7 @@ def trigger(self, session=None): except json.decoder.JSONDecodeError: flash("Invalid JSON configuration", "error") return self.render_template( - 'airflow/trigger.html', - dag_id=dag_id, - origin=origin, - conf=request_conf + 'airflow/trigger.html', dag_id=dag_id, origin=origin, conf=request_conf ) dag = current_app.dag_bag.get_dag(dag_id) @@ -1361,13 +1409,12 @@ def trigger(self, session=None): dag_hash=current_app.dag_bag.dags_hash.get(dag_id), ) - flash( - "Triggered {}, " - "it should start any moment now.".format(dag_id)) + flash("Triggered {}, " "it should start any moment now.".format(dag_id)) return redirect(origin) - def _clear_dag_tis(self, dag, start_date, end_date, origin, - recursive=False, confirmed=False, only_failed=False): + def _clear_dag_tis( + self, dag, start_date, end_date, origin, recursive=False, confirmed=False, only_failed=False + ): if confirmed: count = dag.clear( start_date=start_date, @@ -1401,9 +1448,9 @@ def _clear_dag_tis(self, dag, start_date, end_date, origin, response = self.render_template( 'airflow/confirm.html', - message=("Here's the list of task instances you are about " - "to clear:"), - details=details) + message=("Here's the list of task instances you are about " "to clear:"), + details=details, + ) return response @@ -1435,19 +1482,29 @@ def clear(self): dag = dag.sub_dag( task_ids_or_regex=fr"^{task_id}$", include_downstream=downstream, - include_upstream=upstream) + include_upstream=upstream, + ) end_date = execution_date if not future else None start_date = execution_date if not past else None - return self._clear_dag_tis(dag, start_date, end_date, origin, - recursive=recursive, confirmed=confirmed, only_failed=only_failed) + return self._clear_dag_tis( + dag, + start_date, + end_date, + origin, + recursive=recursive, + confirmed=confirmed, + only_failed=only_failed, + ) @expose('/dagrun_clear', methods=['POST']) - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_TASK_INSTANCE), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_TASK_INSTANCE), + ] + ) @action_logging def dagrun_clear(self): """Clears the DagRun""" @@ -1461,14 +1518,15 @@ def dagrun_clear(self): start_date = execution_date end_date = execution_date - return self._clear_dag_tis(dag, start_date, end_date, origin, - recursive=True, confirmed=confirmed) + return self._clear_dag_tis(dag, start_date, end_date, origin, recursive=True, confirmed=confirmed) @expose('/blocked', methods=['POST']) - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, 
permissions.RESOURCE_DAG_RUN), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + ] + ) @provide_session def blocked(self, session=None): """Mark Dag Blocked.""" @@ -1478,9 +1536,7 @@ def blocked(self, session=None): allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)] # Filter by post parameters - selected_dag_ids = { - unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id - } + selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id} if selected_dag_ids: filter_dag_ids = selected_dag_ids.intersection(allowed_dag_ids) @@ -1493,9 +1549,9 @@ def blocked(self, session=None): # pylint: disable=comparison-with-callable dags = ( session.query(DagRun.dag_id, sqla.func.count(DagRun.id)) - .filter(DagRun.state == State.RUNNING) - .filter(DagRun.dag_id.in_(filter_dag_ids)) - .group_by(DagRun.dag_id) + .filter(DagRun.state == State.RUNNING) + .filter(DagRun.dag_id.in_(filter_dag_ids)) + .group_by(DagRun.dag_id) ) # pylint: enable=comparison-with-callable @@ -1506,11 +1562,13 @@ def blocked(self, session=None): if dag: # TODO: Make max_active_runs a column so we can query for it directly max_active_runs = dag.max_active_runs - payload.append({ - 'dag_id': dag_id, - 'active_dag_run': active_dag_runs, - 'max_active_runs': max_active_runs, - }) + payload.append( + { + 'dag_id': dag_id, + 'active_dag_run': active_dag_runs, + 'max_active_runs': max_active_runs, + } + ) return wwwutils.json_response(payload) def _mark_dagrun_state_as_failed(self, dag_id, execution_date, confirmed, origin): @@ -1537,7 +1595,8 @@ def _mark_dagrun_state_as_failed(self, dag_id, execution_date, confirmed, origin response = self.render_template( 'airflow/confirm.html', message="Here's the list of task instances you are about to mark as failed", - details=details) + details=details, + ) return response @@ -1553,8 +1612,7 @@ def _mark_dagrun_state_as_success(self, dag_id, execution_date, confirmed, origi flash(f'Cannot find DAG: {dag_id}', 'error') return redirect(origin) - new_dag_state = set_dag_run_state_to_success(dag, execution_date, - commit=confirmed) + new_dag_state = set_dag_run_state_to_success(dag, execution_date, commit=confirmed) if confirmed: flash('Marked success on {} task instances'.format(len(new_dag_state))) @@ -1566,15 +1624,18 @@ def _mark_dagrun_state_as_success(self, dag_id, execution_date, confirmed, origi response = self.render_template( 'airflow/confirm.html', message="Here's the list of task instances you are about to mark as success", - details=details) + details=details, + ) return response @expose('/dagrun_failed', methods=['POST']) - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), + ] + ) @action_logging def dagrun_failed(self): """Mark DagRun failed.""" @@ -1582,14 +1643,15 @@ def dagrun_failed(self): execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == 'true' origin = get_safe_url(request.form.get('origin')) - return self._mark_dagrun_state_as_failed(dag_id, execution_date, - confirmed, origin) + return self._mark_dagrun_state_as_failed(dag_id, execution_date, confirmed, origin) @expose('/dagrun_success', methods=['POST']) - 
@auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), + ] + ) @action_logging def dagrun_success(self): """Mark DagRun success""" @@ -1597,13 +1659,21 @@ def dagrun_success(self): execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == 'true' origin = get_safe_url(request.form.get('origin')) - return self._mark_dagrun_state_as_success(dag_id, execution_date, - confirmed, origin) - - def _mark_task_instance_state(self, # pylint: disable=too-many-arguments - dag_id, task_id, origin, execution_date, - confirmed, upstream, downstream, - future, past, state): + return self._mark_dagrun_state_as_success(dag_id, execution_date, confirmed, origin) + + def _mark_task_instance_state( # pylint: disable=too-many-arguments + self, + dag_id, + task_id, + origin, + execution_date, + confirmed, + upstream, + downstream, + future, + past, + state, + ): dag = current_app.dag_bag.get_dag(dag_id) task = dag.get_task(task_id) task.dag = dag @@ -1618,33 +1688,48 @@ def _mark_task_instance_state(self, # pylint: disable=too-many-arguments from airflow.api.common.experimental.mark_tasks import set_state if confirmed: - altered = set_state(tasks=[task], execution_date=execution_date, - upstream=upstream, downstream=downstream, - future=future, past=past, state=state, - commit=True) + altered = set_state( + tasks=[task], + execution_date=execution_date, + upstream=upstream, + downstream=downstream, + future=future, + past=past, + state=state, + commit=True, + ) flash("Marked {} on {} task instances".format(state, len(altered))) return redirect(origin) - to_be_altered = set_state(tasks=[task], execution_date=execution_date, - upstream=upstream, downstream=downstream, - future=future, past=past, state=state, - commit=False) + to_be_altered = set_state( + tasks=[task], + execution_date=execution_date, + upstream=upstream, + downstream=downstream, + future=future, + past=past, + state=state, + commit=False, + ) details = "\n".join([str(t) for t in to_be_altered]) response = self.render_template( "airflow/confirm.html", message=(f"Here's the list of task instances you are about to mark as {state}:"), - details=details) + details=details, + ) return response @expose('/failed', methods=['POST']) - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), + ] + ) @action_logging def failed(self): """Mark task as failed.""" @@ -1659,15 +1744,26 @@ def failed(self): future = request.form.get('failed_future') == "true" past = request.form.get('failed_past') == "true" - return self._mark_task_instance_state(dag_id, task_id, origin, execution_date, - confirmed, upstream, downstream, - future, past, State.FAILED) + return self._mark_task_instance_state( + dag_id, + task_id, + origin, + execution_date, + confirmed, + upstream, + downstream, + future, + past, + State.FAILED, + ) @expose('/success', methods=['POST']) - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ]) + @auth.has_access( + [ + 
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), + ] + ) @action_logging def success(self): """Mark task as success.""" @@ -1682,16 +1778,27 @@ def success(self): future = request.form.get('success_future') == "true" past = request.form.get('success_past') == "true" - return self._mark_task_instance_state(dag_id, task_id, origin, execution_date, - confirmed, upstream, downstream, - future, past, State.SUCCESS) + return self._mark_task_instance_state( + dag_id, + task_id, + origin, + execution_date, + confirmed, + upstream, + downstream, + future, + past, + State.SUCCESS, + ) @expose('/tree') - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), + ] + ) @gzipped # pylint: disable=too-many-locals @action_logging # pylint: disable=too-many-locals def tree(self): @@ -1705,10 +1812,7 @@ def tree(self): root = request.args.get('root') if root: - dag = dag.sub_dag( - task_ids_or_regex=root, - include_downstream=False, - include_upstream=True) + dag = dag.sub_dag(task_ids_or_regex=root, include_downstream=False, include_upstream=True) base_date = request.args.get('base_date') num_runs = request.args.get('num_runs') @@ -1725,16 +1829,12 @@ def tree(self): with create_session() as session: dag_runs = ( session.query(DagRun) - .filter( - DagRun.dag_id == dag.dag_id, - DagRun.execution_date <= base_date) + .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date <= base_date) .order_by(DagRun.execution_date.desc()) .limit(num_runs) .all() ) - dag_runs = { - dr.execution_date: alchemy_to_dict(dr) for dr in dag_runs - } + dag_runs = {dr.execution_date: alchemy_to_dict(dr) for dr in dag_runs} dates = sorted(list(dag_runs.keys())) max_date = max(dates) if dates else None @@ -1782,10 +1882,7 @@ def recurse_nodes(task, visited): node = { 'name': task.task_id, - 'instances': [ - encode_ti(task_instances.get((task_id, d))) - for d in dates - ], + 'instances': [encode_ti(task_instances.get((task_id, d))) for d in dates], 'num_dep': len(task.downstream_list), 'operator': task.task_type, 'retries': task.retries, @@ -1795,8 +1892,10 @@ def recurse_nodes(task, visited): if task.downstream_list: children = [ - recurse_nodes(t, visited) for t in task.downstream_list - if node_count < node_limit or t not in visited] + recurse_nodes(t, visited) + for t in task.downstream_list + if node_count < node_limit or t not in visited + ] # D3 tree uses children vs _children to define what is # expanded or not. 
The following block makes it such that @@ -1823,14 +1922,10 @@ def recurse_nodes(task, visited): data = { 'name': '[DAG]', 'children': [recurse_nodes(t, set()) for t in dag.roots], - 'instances': [ - dag_runs.get(d) or {'execution_date': d.isoformat()} - for d in dates - ], + 'instances': [dag_runs.get(d) or {'execution_date': d.isoformat()} for d in dates], } - form = DateTimeWithNumRunsForm(data={'base_date': max_date, - 'num_runs': num_runs}) + form = DateTimeWithNumRunsForm(data={'base_date': max_date, 'num_runs': num_runs}) doc_md = wwwutils.wrapped_markdown(getattr(dag, 'doc_md', None), css_class='dag-doc') @@ -1851,16 +1946,20 @@ def recurse_nodes(task, visited): dag=dag, doc_md=doc_md, data=data, - blur=blur, num_runs=num_runs, + blur=blur, + num_runs=num_runs, show_external_log_redirect=task_log_reader.supports_external_link, - external_log_name=external_log_name) + external_log_name=external_log_name, + ) @expose('/graph') - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), + ] + ) @gzipped @action_logging @provide_session @@ -1875,10 +1974,7 @@ def graph(self, session=None): root = request.args.get('root') if root: - dag = dag.sub_dag( - task_ids_or_regex=root, - include_upstream=True, - include_downstream=False) + dag = dag.sub_dag(task_ids_or_regex=root, include_upstream=True, include_downstream=False) arrange = request.args.get('arrange', dag.orientation) @@ -1892,26 +1988,28 @@ def graph(self, session=None): class GraphForm(DateTimeWithNumRunsWithDagRunsForm): """Graph Form class.""" - arrange = SelectField("Layout", choices=( - ('LR', "Left > Right"), - ('RL', "Right > Left"), - ('TB', "Top > Bottom"), - ('BT', "Bottom > Top"), - )) + arrange = SelectField( + "Layout", + choices=( + ('LR', "Left > Right"), + ('RL', "Right > Left"), + ('TB', "Top > Bottom"), + ('BT', "Bottom > Top"), + ), + ) form = GraphForm(data=dt_nr_dr_data) form.execution_date.choices = dt_nr_dr_data['dr_choices'] - task_instances = { - ti.task_id: alchemy_to_dict(ti) - for ti in dag.get_task_instances(dttm, dttm)} + task_instances = {ti.task_id: alchemy_to_dict(ti) for ti in dag.get_task_instances(dttm, dttm)} tasks = { t.task_id: { 'dag_id': t.dag_id, 'task_type': t.task_type, 'extra_links': t.extra_links, } - for t in dag.tasks} + for t in dag.tasks + } if not tasks: flash("No tasks found", "error") session.commit() @@ -1942,13 +2040,16 @@ class GraphForm(DateTimeWithNumRunsWithDagRunsForm): edges=edges, show_external_log_redirect=task_log_reader.supports_external_link, external_log_name=external_log_name, - dag_run_state=dt_nr_dr_data['dr_state']) + dag_run_state=dt_nr_dr_data['dr_state'], + ) @expose('/duration') - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + ] + ) @action_logging # pylint: disable=too-many-locals @provide_session # pylint: disable=too-many-locals def duration(self, session=None): @@ -1979,16 +2080,11 @@ def duration(self, session=None): root = 
request.args.get('root') if root: - dag = dag.sub_dag( - task_ids_or_regex=root, - include_upstream=True, - include_downstream=False) + dag = dag.sub_dag(task_ids_or_regex=root, include_upstream=True, include_downstream=False) chart_height = wwwutils.get_chart_height(dag) - chart = nvd3.lineChart( - name="lineChart", x_is_date=True, height=chart_height, width="1200") - cum_chart = nvd3.lineChart( - name="cumLineChart", x_is_date=True, height=chart_height, width="1200") + chart = nvd3.lineChart(name="lineChart", x_is_date=True, height=chart_height, width="1200") + cum_chart = nvd3.lineChart(name="cumLineChart", x_is_date=True, height=chart_height, width="1200") y_points = defaultdict(list) x_points = defaultdict(list) @@ -1997,17 +2093,22 @@ def duration(self, session=None): task_instances = dag.get_task_instances(start_date=min_date, end_date=base_date) ti_fails = ( session.query(TaskFail) - .filter(TaskFail.dag_id == dag.dag_id, - TaskFail.execution_date >= min_date, - TaskFail.execution_date <= base_date, - TaskFail.task_id.in_([t.task_id for t in dag.tasks])) + .filter( + TaskFail.dag_id == dag.dag_id, + TaskFail.execution_date >= min_date, + TaskFail.execution_date <= base_date, + TaskFail.task_id.in_([t.task_id for t in dag.tasks]), + ) .all() ) fails_totals = defaultdict(int) for failed_task_instance in ti_fails: - dict_key = (failed_task_instance.dag_id, failed_task_instance.task_id, - failed_task_instance.execution_date) + dict_key = ( + failed_task_instance.dag_id, + failed_task_instance.task_id, + failed_task_instance.execution_date, + ) if failed_task_instance.duration: fails_totals[dict_key] += failed_task_instance.duration @@ -2025,34 +2126,38 @@ def duration(self, session=None): y_unit = infer_time_unit([d for t in y_points.values() for d in t]) cum_y_unit = infer_time_unit([d for t in cumulative_y.values() for d in t]) # update the y Axis on both charts to have the correct time units - chart.create_y_axis('yAxis', format='.02f', custom_format=False, - label=f'Duration ({y_unit})') + chart.create_y_axis('yAxis', format='.02f', custom_format=False, label=f'Duration ({y_unit})') chart.axislist['yAxis']['axisLabelDistance'] = '-15' - cum_chart.create_y_axis('yAxis', format='.02f', custom_format=False, - label=f'Duration ({cum_y_unit})') + cum_chart.create_y_axis('yAxis', format='.02f', custom_format=False, label=f'Duration ({cum_y_unit})') cum_chart.axislist['yAxis']['axisLabelDistance'] = '-15' for task in dag.tasks: if x_points[task.task_id]: - chart.add_serie(name=task.task_id, x=x_points[task.task_id], - y=scale_time_units(y_points[task.task_id], y_unit)) - cum_chart.add_serie(name=task.task_id, x=x_points[task.task_id], - y=scale_time_units(cumulative_y[task.task_id], - cum_y_unit)) + chart.add_serie( + name=task.task_id, + x=x_points[task.task_id], + y=scale_time_units(y_points[task.task_id], y_unit), + ) + cum_chart.add_serie( + name=task.task_id, + x=x_points[task.task_id], + y=scale_time_units(cumulative_y[task.task_id], cum_y_unit), + ) dates = sorted(list({ti.execution_date for ti in task_instances})) max_date = max([ti.execution_date for ti in task_instances]) if dates else None session.commit() - form = DateTimeWithNumRunsForm(data={'base_date': max_date, - 'num_runs': num_runs}) + form = DateTimeWithNumRunsForm(data={'base_date': max_date, 'num_runs': num_runs}) chart.buildcontent() cum_chart.buildcontent() s_index = cum_chart.htmlcontent.rfind('});') - cum_chart.htmlcontent = (cum_chart.htmlcontent[:s_index] + - "$( document ).trigger('chartload')" + - 
cum_chart.htmlcontent[s_index:]) + cum_chart.htmlcontent = ( + cum_chart.htmlcontent[:s_index] + + "$( document ).trigger('chartload')" + + cum_chart.htmlcontent[s_index:] + ) return self.render_template( 'airflow/duration_chart.html', @@ -2065,10 +2170,12 @@ def duration(self, session=None): ) @expose('/tries') - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + ] + ) @action_logging @provide_session def tries(self, session=None): @@ -2090,15 +2197,12 @@ def tries(self, session=None): root = request.args.get('root') if root: - dag = dag.sub_dag( - task_ids_or_regex=root, - include_upstream=True, - include_downstream=False) + dag = dag.sub_dag(task_ids_or_regex=root, include_upstream=True, include_downstream=False) chart_height = wwwutils.get_chart_height(dag) chart = nvd3.lineChart( - name="lineChart", x_is_date=True, y_axis_format='d', height=chart_height, - width="1200") + name="lineChart", x_is_date=True, y_axis_format='d', height=chart_height, width="1200" + ) for task in dag.tasks: y_points = [] @@ -2119,8 +2223,7 @@ def tries(self, session=None): session.commit() - form = DateTimeWithNumRunsForm(data={'base_date': max_date, - 'num_runs': num_runs}) + form = DateTimeWithNumRunsForm(data={'base_date': max_date, 'num_runs': num_runs}) chart.buildcontent() @@ -2135,10 +2238,12 @@ def tries(self, session=None): ) @expose('/landing_times') - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + ] + ) @action_logging @provide_session def landing_times(self, session=None): @@ -2160,14 +2265,10 @@ def landing_times(self, session=None): root = request.args.get('root') if root: - dag = dag.sub_dag( - task_ids_or_regex=root, - include_upstream=True, - include_downstream=False) + dag = dag.sub_dag(task_ids_or_regex=root, include_upstream=True, include_downstream=False) chart_height = wwwutils.get_chart_height(dag) - chart = nvd3.lineChart( - name="lineChart", x_is_date=True, height=chart_height, width="1200") + chart = nvd3.lineChart(name="lineChart", x_is_date=True, height=chart_height, width="1200") y_points = {} x_points = {} for task in dag.tasks: @@ -2188,13 +2289,15 @@ def landing_times(self, session=None): # for the DAG y_unit = infer_time_unit([d for t in y_points.values() for d in t]) # update the y Axis to have the correct time units - chart.create_y_axis('yAxis', format='.02f', custom_format=False, - label=f'Landing Time ({y_unit})') + chart.create_y_axis('yAxis', format='.02f', custom_format=False, label=f'Landing Time ({y_unit})') chart.axislist['yAxis']['axisLabelDistance'] = '-15' for task in dag.tasks: if x_points[task.task_id]: - chart.add_serie(name=task.task_id, x=x_points[task.task_id], - y=scale_time_units(y_points[task.task_id], y_unit)) + chart.add_serie( + name=task.task_id, + x=x_points[task.task_id], + y=scale_time_units(y_points[task.task_id], y_unit), + ) tis = dag.get_task_instances(start_date=min_date, end_date=base_date) dates = sorted(list({ti.execution_date for ti in tis})) @@ -2202,8 +2305,7 @@ def landing_times(self, session=None): session.commit() - form = 
DateTimeWithNumRunsForm(data={'base_date': max_date, - 'num_runs': num_runs}) + form = DateTimeWithNumRunsForm(data={'base_date': max_date, 'num_runs': num_runs}) chart.buildcontent() return self.render_template( 'airflow/chart.html', @@ -2217,29 +2319,31 @@ def landing_times(self, session=None): ) @expose('/paused', methods=['POST']) - @auth.has_access([ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + ] + ) @action_logging def paused(self): """Toggle paused.""" dag_id = request.args.get('dag_id') is_paused = request.args.get('is_paused') == 'false' - models.DagModel.get_dagmodel(dag_id).set_is_paused( - is_paused=is_paused) + models.DagModel.get_dagmodel(dag_id).set_is_paused(is_paused=is_paused) return "OK" @expose('/refresh', methods=['POST']) - @auth.has_access([ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + ] + ) @action_logging @provide_session def refresh(self, session=None): """Refresh DAG.""" dag_id = request.values.get('dag_id') - orm_dag = session.query( - DagModel).filter(DagModel.dag_id == dag_id).first() + orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).first() if orm_dag: orm_dag.last_expired = timezone.utcnow() @@ -2254,9 +2358,11 @@ def refresh(self, session=None): return redirect(request.referrer) @expose('/refresh_all', methods=['POST']) - @auth.has_access([ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + ] + ) @action_logging def refresh_all(self): """Refresh everything""" @@ -2269,10 +2375,12 @@ def refresh_all(self): return redirect(url_for('Airflow.index')) @expose('/gantt') - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + ] + ) @action_logging @provide_session def gantt(self, session=None): @@ -2283,10 +2391,7 @@ def gantt(self, session=None): root = request.args.get('root') if root: - dag = dag.sub_dag( - task_ids_or_regex=root, - include_upstream=True, - include_downstream=False) + dag = dag.sub_dag(task_ids_or_regex=root, include_upstream=True, include_downstream=False) dt_nr_dr_data = get_date_time_num_runs_dag_runs_form_data(request, session, dag) dttm = dt_nr_dr_data['dttm'] @@ -2294,18 +2399,24 @@ def gantt(self, session=None): form = DateTimeWithNumRunsWithDagRunsForm(data=dt_nr_dr_data) form.execution_date.choices = dt_nr_dr_data['dr_choices'] - tis = [ - ti for ti in dag.get_task_instances(dttm, dttm) - if ti.start_date and ti.state] + tis = [ti for ti in dag.get_task_instances(dttm, dttm) if ti.start_date and ti.state] tis = sorted(tis, key=lambda ti: ti.start_date) - ti_fails = list(itertools.chain(*[( - session - .query(TaskFail) - .filter(TaskFail.dag_id == ti.dag_id, - TaskFail.task_id == ti.task_id, - TaskFail.execution_date == ti.execution_date) - .all() - ) for ti in tis])) + ti_fails = list( + itertools.chain( + *[ + ( + session.query(TaskFail) + .filter( + TaskFail.dag_id == ti.dag_id, + TaskFail.task_id == ti.task_id, + TaskFail.execution_date == ti.execution_date, + ) + .all() + ) + for ti in tis + ] + ) + ) # determine bars to show in the gantt chart gantt_bar_items = [] 
@@ -2333,8 +2444,9 @@ def gantt(self, session=None): else: try_count = 1 prev_task_id = failed_task_instance.task_id - gantt_bar_items.append((failed_task_instance.task_id, start_date, end_date, State.FAILED, - try_count)) + gantt_bar_items.append( + (failed_task_instance.task_id, start_date, end_date, State.FAILED, try_count) + ) tf_count += 1 task = dag.get_task(failed_task_instance.task_id) task_dict = alchemy_to_dict(failed_task_instance) @@ -2364,10 +2476,12 @@ def gantt(self, session=None): ) @expose('/extra_links') - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + ] + ) @action_logging def extra_links(self): """ @@ -2397,11 +2511,10 @@ def extra_links(self): if not dag or task_id not in dag.task_ids: response = jsonify( - {'url': None, - 'error': "can't find dag {dag} or task_id {task_id}".format( - dag=dag, - task_id=task_id - )} + { + 'url': None, + 'error': f"can't find dag {dag} or task_id {task_id}", + } ) response.status_code = 404 return response @@ -2418,16 +2531,17 @@ def extra_links(self): response.status_code = 200 return response else: - response = jsonify( - {'url': None, 'error': f'No URL found for {link_name}'}) + response = jsonify({'url': None, 'error': f'No URL found for {link_name}'}) response.status_code = 404 return response @expose('/object/task_instances') - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + ] + ) @action_logging def task_instances(self): """Shows task instances.""" @@ -2440,9 +2554,7 @@ def task_instances(self): else: return "Error: Invalid execution_date" - task_instances = { - ti.task_id: alchemy_to_dict(ti) - for ti in dag.get_task_instances(dttm, dttm)} + task_instances = {ti.task_id: alchemy_to_dict(ti) for ti in dag.get_task_instances(dttm, dttm)} return json.dumps(task_instances, cls=utils_json.AirflowJsonEncoder) @@ -2453,12 +2565,17 @@ class ConfigurationView(AirflowBaseView): default_view = 'conf' class_permission_name = permissions.RESOURCE_CONFIG - base_permissions = [permissions.ACTION_CAN_READ, permissions.ACTION_CAN_ACCESS_MENU,] + base_permissions = [ + permissions.ACTION_CAN_READ, + permissions.ACTION_CAN_ACCESS_MENU, + ] @expose('/configuration') - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG), + ] + ) def conf(self): """Shows configuration.""" raw = request.args.get('raw') == "true" @@ -2468,31 +2585,36 @@ def conf(self): if conf.getboolean("webserver", "expose_config"): with open(AIRFLOW_CONFIG) as file: config = file.read() - table = [(section, key, value, source) - for section, parameters in conf.as_dict(True, True).items() - for key, (value, source) in parameters.items()] + table = [ + (section, key, value, source) + for section, parameters in conf.as_dict(True, True).items() + for key, (value, source) in parameters.items() + ] else: config = ( "# Your Airflow administrator chose not to expose the " - "configuration, most likely for security reasons.") + "configuration, most likely for security 
reasons." + ) table = None if raw: - return Response( - response=config, - status=200, - mimetype="application/text") + return Response(response=config, status=200, mimetype="application/text") else: - code_html = Markup(highlight( - config, - lexers.IniLexer(), # Lexer call pylint: disable=no-member - HtmlFormatter(noclasses=True)) + code_html = Markup( + highlight( + config, + lexers.IniLexer(), # Lexer call pylint: disable=no-member + HtmlFormatter(noclasses=True), + ) ) return self.render_template( 'airflow/config.html', pre_subtitle=settings.HEADER + " v" + airflow.__version__, - code_html=code_html, title=title, subtitle=subtitle, - table=table) + code_html=code_html, + title=title, + subtitle=subtitle, + table=table, + ) class RedocView(AirflowBaseView): @@ -2511,10 +2633,11 @@ def redoc(self): # ModelViews ###################################################################################### + class DagFilter(BaseFilter): """Filter using DagIDs""" - def apply(self, query, func): # noqa pylint: disable=redefined-outer-name,unused-argument + def apply(self, query, func): # noqa pylint: disable=redefined-outer-name,unused-argument if current_app.appbuilder.sm.has_all_dags_access(): return query filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) @@ -2595,8 +2718,7 @@ class XComModelView(AirflowModelView): 'dag_id': wwwutils.dag_link, } - @action('muldelete', 'Delete', "Are you sure you want to delete selected records?", - single=False) + @action('muldelete', 'Delete', "Are you sure you want to delete selected records?", single=False) def action_muldelete(self, items): """Multiple delete action.""" self.datamodel.delete_all(items) @@ -2638,38 +2760,49 @@ class ConnectionModelView(AirflowModelView): permissions.ACTION_CAN_ACCESS_MENU, ] - extra_fields = ['extra__jdbc__drv_path', 'extra__jdbc__drv_clsname', - 'extra__google_cloud_platform__project', - 'extra__google_cloud_platform__key_path', - 'extra__google_cloud_platform__keyfile_dict', - 'extra__google_cloud_platform__scope', - 'extra__google_cloud_platform__num_retries', - 'extra__grpc__auth_type', - 'extra__grpc__credential_pem_file', - 'extra__grpc__scopes', - 'extra__yandexcloud__service_account_json', - 'extra__yandexcloud__service_account_json_path', - 'extra__yandexcloud__oauth', - 'extra__yandexcloud__public_ssh_key', - 'extra__yandexcloud__folder_id', - 'extra__kubernetes__in_cluster', - 'extra__kubernetes__kube_config', - 'extra__kubernetes__namespace'] - list_columns = ['conn_id', 'conn_type', 'host', 'port', 'is_encrypted', - 'is_extra_encrypted'] - add_columns = edit_columns = ['conn_id', 'conn_type', 'host', 'schema', - 'login', 'password', 'port', 'extra'] + extra_fields + extra_fields = [ + 'extra__jdbc__drv_path', + 'extra__jdbc__drv_clsname', + 'extra__google_cloud_platform__project', + 'extra__google_cloud_platform__key_path', + 'extra__google_cloud_platform__keyfile_dict', + 'extra__google_cloud_platform__scope', + 'extra__google_cloud_platform__num_retries', + 'extra__grpc__auth_type', + 'extra__grpc__credential_pem_file', + 'extra__grpc__scopes', + 'extra__yandexcloud__service_account_json', + 'extra__yandexcloud__service_account_json_path', + 'extra__yandexcloud__oauth', + 'extra__yandexcloud__public_ssh_key', + 'extra__yandexcloud__folder_id', + 'extra__kubernetes__in_cluster', + 'extra__kubernetes__kube_config', + 'extra__kubernetes__namespace', + ] + list_columns = ['conn_id', 'conn_type', 'host', 'port', 'is_encrypted', 'is_extra_encrypted'] + add_columns = edit_columns = [ + 
'conn_id', + 'conn_type', + 'host', + 'schema', + 'login', + 'password', + 'port', + 'extra', + ] + extra_fields add_form = edit_form = ConnectionForm add_template = 'airflow/conn_create.html' edit_template = 'airflow/conn_edit.html' base_order = ('conn_id', 'asc') - @action('muldelete', 'Delete', 'Are you sure you want to delete selected records?', - single=False) - @auth.has_access([ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - ]) + @action('muldelete', 'Delete', 'Are you sure you want to delete selected records?', single=False) + @auth.has_access( + [ + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + ] + ) def action_muldelete(self, items): """Multiple delete.""" self.datamodel.delete_all(items) @@ -2680,9 +2813,7 @@ def process_form(self, form, is_created): """Process form data.""" formdata = form.data if formdata['conn_type'] in ['jdbc', 'google_cloud_platform', 'grpc', 'yandexcloud', 'kubernetes']: - extra = { - key: formdata[key] - for key in self.extra_fields if key in formdata} + extra = {key: formdata[key] for key in self.extra_fields if key in formdata} form.extra.data = json.dumps(extra) def prefill_form(self, form, pk): @@ -2795,8 +2926,7 @@ class PoolModelView(AirflowModelView): base_order = ('pool', 'asc') - @action('muldelete', 'Delete', 'Are you sure you want to delete selected records?', - single=False) + @action('muldelete', 'Delete', 'Are you sure you want to delete selected records?', single=False) def action_muldelete(self, items): """Multiple delete.""" if any(item.pool == models.Pool.DEFAULT_POOL_NAME for item in items): @@ -2823,8 +2953,8 @@ def frunning_slots(self): if pool_id is not None and running_slots is not None: url = url_for('TaskInstanceModelView.list', _flt_3_pool=pool_id, _flt_3_state='running') return Markup("<a href='{url}'>{running_slots}</a>").format( # noqa - url=url, - running_slots=running_slots) + url=url, running_slots=running_slots + ) else: return Markup('<span class="label label-danger">Invalid</span>') @@ -2835,20 +2965,14 @@ def fqueued_slots(self): if pool_id is not None and queued_slots is not None: url = url_for('TaskInstanceModelView.list', _flt_3_pool=pool_id, _flt_3_state='queued') return Markup("<a href='{url}'>{queued_slots}</a>").format( # noqa - url=url, queued_slots=queued_slots) + url=url, queued_slots=queued_slots + ) else: return Markup('<span class="label label-danger">Invalid</span>') - formatters_columns = { - 'pool': pool_link, - 'running_slots': frunning_slots, - 'queued_slots': fqueued_slots - } + formatters_columns = {'pool': pool_link, 'running_slots': frunning_slots, 'queued_slots': fqueued_slots} - validators_columns = { - 'pool': [validators.DataRequired()], - 'slots': [validators.NumberRange(min=-1)] - } + validators_columns = {'pool': [validators.DataRequired()], 'slots': [validators.NumberRange(min=-1)]} class VariableModelView(AirflowModelView): @@ -2901,16 +3025,13 @@ def hidden_field_formatter(self): 'val': hidden_field_formatter, } - validators_columns = { - 'key': [validators.DataRequired()] - } + validators_columns = {'key': [validators.DataRequired()]} def prefill_form(self, form, request_id): # pylint: disable=unused-argument if wwwutils.should_hide_value_for_key(form.key.data): form.val.data = '*' * 8 - @action('muldelete', 'Delete', 'Are you sure you want to delete selected records?', - single=False) + @action('muldelete', 'Delete', 'Are you sure you want to delete selected records?', single=False) def action_muldelete(self, items): """Multiple delete.""" self.datamodel.delete_all(items) @@ -2976,14 +3097,35 @@ class JobModelView(AirflowModelView): method_permission_name = { 'list': 'read', }
- base_permissions = [permissions.ACTION_CAN_READ, permissions.ACTION_CAN_ACCESS_MENU,] + base_permissions = [ + permissions.ACTION_CAN_READ, + permissions.ACTION_CAN_ACCESS_MENU, + ] - list_columns = ['id', 'dag_id', 'state', 'job_type', 'start_date', - 'end_date', 'latest_heartbeat', - 'executor_class', 'hostname', 'unixname'] - search_columns = ['id', 'dag_id', 'state', 'job_type', 'start_date', - 'end_date', 'latest_heartbeat', 'executor_class', - 'hostname', 'unixname'] + list_columns = [ + 'id', + 'dag_id', + 'state', + 'job_type', + 'start_date', + 'end_date', + 'latest_heartbeat', + 'executor_class', + 'hostname', + 'unixname', + ] + search_columns = [ + 'id', + 'dag_id', + 'state', + 'job_type', + 'start_date', + 'end_date', + 'latest_heartbeat', + 'executor_class', + 'hostname', + 'unixname', + ] base_order = ('start_date', 'desc') @@ -3041,8 +3183,7 @@ class DagRunModelView(AirflowModelView): 'conf': wwwutils.json_f('conf'), } - @action('muldelete', "Delete", "Are you sure you want to delete selected records?", - single=False) + @action('muldelete', "Delete", "Are you sure you want to delete selected records?", single=False) @provide_session def action_muldelete(self, items, session=None): # noqa # pylint: disable=unused-argument """Multiple delete.""" @@ -3060,8 +3201,9 @@ def action_set_running(self, drs, session=None): try: count = 0 dirty_ids = [] - for dr in session.query(DagRun).filter( - DagRun.id.in_([dagrun.id for dagrun in drs])).all(): # noqa pylint: disable=no-member + for dr in ( + session.query(DagRun).filter(DagRun.id.in_([dagrun.id for dagrun in drs])).all() + ): # noqa pylint: disable=no-member dirty_ids.append(dr.dag_id) count += 1 dr.start_date = timezone.utcnow() @@ -3073,9 +3215,12 @@ def action_set_running(self, drs, session=None): flash('Failed to set state', 'error') return redirect(self.get_default_url()) - @action('set_failed', "Set state to 'failed'", - "All running task instances would also be marked as failed, are you sure?", - single=False) + @action( + 'set_failed', + "Set state to 'failed'", + "All running task instances would also be marked as failed, are you sure?", + single=False, + ) @provide_session def action_set_failed(self, drs, session=None): """Set state to failed.""" @@ -3083,26 +3228,29 @@ def action_set_failed(self, drs, session=None): count = 0 dirty_ids = [] altered_tis = [] - for dr in session.query(DagRun).filter( - DagRun.id.in_([dagrun.id for dagrun in drs])).all(): # noqa pylint: disable=no-member + for dr in ( + session.query(DagRun).filter(DagRun.id.in_([dagrun.id for dagrun in drs])).all() + ): # noqa pylint: disable=no-member dirty_ids.append(dr.dag_id) count += 1 - altered_tis += \ - set_dag_run_state_to_failed(current_app.dag_bag.get_dag(dr.dag_id), - dr.execution_date, - commit=True, - session=session) + altered_tis += set_dag_run_state_to_failed( + current_app.dag_bag.get_dag(dr.dag_id), dr.execution_date, commit=True, session=session + ) altered_ti_count = len(altered_tis) flash( "{count} dag runs and {altered_ti_count} task instances " - "were set to failed".format(count=count, altered_ti_count=altered_ti_count)) + "were set to failed".format(count=count, altered_ti_count=altered_ti_count) + ) except Exception: # noqa pylint: disable=broad-except flash('Failed to set state', 'error') return redirect(self.get_default_url()) - @action('set_success', "Set state to 'success'", - "All task instances would also be marked as success, are you sure?", - single=False) + @action( + 'set_success', + "Set state to 'success'", + 
"All task instances would also be marked as success, are you sure?", + single=False, + ) @provide_session def action_set_success(self, drs, session=None): """Set state to success.""" @@ -3110,26 +3258,24 @@ def action_set_success(self, drs, session=None): count = 0 dirty_ids = [] altered_tis = [] - for dr in session.query(DagRun).filter( - DagRun.id.in_([dagrun.id for dagrun in drs])).all(): # noqa pylint: disable=no-member + for dr in ( + session.query(DagRun).filter(DagRun.id.in_([dagrun.id for dagrun in drs])).all() + ): # noqa pylint: disable=no-member dirty_ids.append(dr.dag_id) count += 1 - altered_tis += \ - set_dag_run_state_to_success(current_app.dag_bag.get_dag(dr.dag_id), - dr.execution_date, - commit=True, - session=session) + altered_tis += set_dag_run_state_to_success( + current_app.dag_bag.get_dag(dr.dag_id), dr.execution_date, commit=True, session=session + ) altered_ti_count = len(altered_tis) flash( "{count} dag runs and {altered_ti_count} task instances " - "were set to success".format(count=count, altered_ti_count=altered_ti_count)) + "were set to success".format(count=count, altered_ti_count=altered_ti_count) + ) except Exception: # noqa pylint: disable=broad-except flash('Failed to set state', 'error') return redirect(self.get_default_url()) - @action('clear', "Clear the state", - "All task instances would be cleared, are you sure?", - single=False) + @action('clear', "Clear the state", "All task instances would be cleared, are you sure?", single=False) @provide_session def action_clear(self, drs, session=None): """Clears the state.""" @@ -3147,8 +3293,10 @@ def action_clear(self, drs, session=None): cleared_ti_count += len(tis) models.clear_task_instances(tis, session, dag=dag) - flash("{count} dag runs and {altered_ti_count} task instances " - "were cleared".format(count=count, altered_ti_count=cleared_ti_count)) + flash( + "{count} dag runs and {altered_ti_count} task instances " + "were cleared".format(count=count, altered_ti_count=cleared_ti_count) + ) except Exception: # noqa pylint: disable=broad-except flash('Failed to clear state', 'error') return redirect(self.get_default_url()) @@ -3165,10 +3313,12 @@ class LogModelView(AirflowModelView): method_permission_name = { 'list': 'read', } - base_permissions = [permissions.ACTION_CAN_READ, permissions.ACTION_CAN_ACCESS_MENU,] + base_permissions = [ + permissions.ACTION_CAN_READ, + permissions.ACTION_CAN_ACCESS_MENU, + ] - list_columns = ['id', 'dttm', 'dag_id', 'task_id', 'event', 'execution_date', - 'owner', 'extra'] + list_columns = ['id', 'dttm', 'dag_id', 'task_id', 'event', 'execution_date', 'owner', 'extra'] search_columns = ['dag_id', 'task_id', 'event', 'execution_date', 'owner', 'extra'] base_order = ('dttm', 'desc') @@ -3194,13 +3344,24 @@ class TaskRescheduleModelView(AirflowModelView): 'list': 'read', } - base_permissions = [permissions.ACTION_CAN_READ, permissions.ACTION_CAN_ACCESS_MENU,] + base_permissions = [ + permissions.ACTION_CAN_READ, + permissions.ACTION_CAN_ACCESS_MENU, + ] - list_columns = ['id', 'dag_id', 'task_id', 'execution_date', 'try_number', - 'start_date', 'end_date', 'duration', 'reschedule_date'] + list_columns = [ + 'id', + 'dag_id', + 'task_id', + 'execution_date', + 'try_number', + 'start_date', + 'end_date', + 'duration', + 'reschedule_date', + ] - search_columns = ['dag_id', 'task_id', 'execution_date', 'start_date', 'end_date', - 'reschedule_date'] + search_columns = ['dag_id', 'task_id', 'execution_date', 'start_date', 'end_date', 'reschedule_date'] base_order = ('id', 
'desc') @@ -3251,15 +3412,40 @@ class TaskInstanceModelView(AirflowModelView): page_size = PAGE_SIZE - list_columns = ['state', 'dag_id', 'task_id', 'execution_date', 'operator', - 'start_date', 'end_date', 'duration', 'job_id', 'hostname', - 'unixname', 'priority_weight', 'queue', 'queued_dttm', 'try_number', - 'pool', 'log_url'] + list_columns = [ + 'state', + 'dag_id', + 'task_id', + 'execution_date', + 'operator', + 'start_date', + 'end_date', + 'duration', + 'job_id', + 'hostname', + 'unixname', + 'priority_weight', + 'queue', + 'queued_dttm', + 'try_number', + 'pool', + 'log_url', + ] order_columns = [item for item in list_columns if item not in ['try_number', 'log_url']] - search_columns = ['state', 'dag_id', 'task_id', 'execution_date', 'hostname', - 'queue', 'pool', 'operator', 'start_date', 'end_date'] + search_columns = [ + 'state', + 'dag_id', + 'task_id', + 'execution_date', + 'hostname', + 'queue', + 'pool', + 'operator', + 'start_date', + 'end_date', + ] base_order = ('job_id', 'asc') @@ -3294,10 +3480,15 @@ def duration_f(self): } @provide_session - @action('clear', lazy_gettext('Clear'), - lazy_gettext('Are you sure you want to clear the state of the selected task' - ' instance(s) and set their dagruns to the running state?'), - single=False) + @action( + 'clear', + lazy_gettext('Clear'), + lazy_gettext( + 'Are you sure you want to clear the state of the selected task' + ' instance(s) and set their dagruns to the running state?' + ), + single=False, + ) def action_clear(self, task_instances, session=None): """Clears the action.""" try: @@ -3326,15 +3517,20 @@ def set_task_instance_state(self, tis, target_state, session=None): for ti in tis: ti.set_state(target_state, session) session.commit() - flash("{count} task instances were set to '{target_state}'".format( - count=count, target_state=target_state)) + flash( + "{count} task instances were set to '{target_state}'".format( + count=count, target_state=target_state + ) + ) except Exception: # noqa pylint: disable=broad-except flash('Failed to set state', 'error') @action('set_running', "Set state to 'running'", '', single=False) - @auth.has_access([ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + ] + ) def action_set_running(self, tis): """Set state to 'running'""" self.set_task_instance_state(tis, State.RUNNING) @@ -3342,9 +3538,11 @@ def action_set_running(self, tis): return redirect(self.get_redirect()) @action('set_failed', "Set state to 'failed'", '', single=False) - @auth.has_access([ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + ] + ) def action_set_failed(self, tis): """Set state to 'failed'""" self.set_task_instance_state(tis, State.FAILED) @@ -3352,9 +3550,11 @@ def action_set_failed(self, tis): return redirect(self.get_redirect()) @action('set_success', "Set state to 'success'", '', single=False) - @auth.has_access([ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + ] + ) def action_set_success(self, tis): """Set state to 'success'""" self.set_task_instance_state(tis, State.SUCCESS) @@ -3362,9 +3562,11 @@ def action_set_success(self, tis): return redirect(self.get_redirect()) @action('set_retry', "Set state to 'up_for_retry'", '', single=False) - @auth.has_access([ - (permissions.ACTION_CAN_EDIT, 
permissions.RESOURCE_DAG), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + ] + ) def action_set_retry(self, tis): """Set state to 'up_for_retry'""" self.set_task_instance_state(tis, State.UP_FOR_RETRY) @@ -3387,38 +3589,46 @@ class DagModelView(AirflowModelView): base_permissions = [ permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT, - permissions.ACTION_CAN_DELETE + permissions.ACTION_CAN_DELETE, ] - list_columns = ['dag_id', 'is_paused', 'last_scheduler_run', - 'last_expired', 'scheduler_lock', 'fileloc', 'owners'] + list_columns = [ + 'dag_id', + 'is_paused', + 'last_scheduler_run', + 'last_expired', + 'scheduler_lock', + 'fileloc', + 'owners', + ] - formatters_columns = { - 'dag_id': wwwutils.dag_link - } + formatters_columns = {'dag_id': wwwutils.dag_link} base_filters = [['dag_id', DagFilter, lambda: []]] def get_query(self): """Default filters for model""" return ( - super().get_query() # noqa pylint: disable=no-member - .filter(or_(models.DagModel.is_active, - models.DagModel.is_paused)) + super() # noqa pylint: disable=no-member + .get_query() + .filter(or_(models.DagModel.is_active, models.DagModel.is_paused)) .filter(~models.DagModel.is_subdag) ) def get_count_query(self): """Default filters for model""" return ( - super().get_count_query() # noqa pylint: disable=no-member + super() # noqa pylint: disable=no-member + .get_count_query() .filter(models.DagModel.is_active) .filter(~models.DagModel.is_subdag) ) - @auth.has_access([ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - ]) + @auth.has_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + ] + ) @provide_session @expose('/autocomplete') def autocomplete(self, session=None): @@ -3430,12 +3640,12 @@ def autocomplete(self, session=None): # Provide suggestions of dag_ids and owners dag_ids_query = session.query(DagModel.dag_id.label('item')).filter( # pylint: disable=no-member - ~DagModel.is_subdag, DagModel.is_active, - DagModel.dag_id.ilike('%' + query + '%')) # noqa pylint: disable=no-member + ~DagModel.is_subdag, DagModel.is_active, DagModel.dag_id.ilike('%' + query + '%') + ) # noqa pylint: disable=no-member owners_query = session.query(func.distinct(DagModel.owners).label('item')).filter( - ~DagModel.is_subdag, DagModel.is_active, - DagModel.owners.ilike('%' + query + '%')) # noqa pylint: disable=no-member + ~DagModel.is_subdag, DagModel.is_active, DagModel.owners.ilike('%' + query + '%') + ) # noqa pylint: disable=no-member # Hide DAGs if not showing status: "all" status = flask_session.get(FILTER_STATUS_COOKIE) diff --git a/airflow/www/widgets.py b/airflow/www/widgets.py index 199c274ac0e7c..172380c2a733c 100644 --- a/airflow/www/widgets.py +++ b/airflow/www/widgets.py @@ -36,7 +36,6 @@ class AirflowDateTimePickerWidget: "" '' "" - ) def __call__(self, field, **kwargs): @@ -46,6 +45,4 @@ def __call__(self, field, **kwargs): field.data = "" template = self.data_template - return Markup( - template % {"text": html_params(type="text", value=field.data, **kwargs)} - ) + return Markup(template % {"text": html_params(type="text", value=field.data, **kwargs)}) diff --git a/chart/tests/test_celery_kubernetes_executor.py b/chart/tests/test_celery_kubernetes_executor.py index 57c39809b9351..fb219297553f3 100644 --- a/chart/tests/test_celery_kubernetes_executor.py +++ b/chart/tests/test_celery_kubernetes_executor.py @@ -18,6 +18,7 @@ import unittest import jmespath + from tests.helm_template_generator import render_chart diff --git 
a/chart/tests/test_celery_kubernetes_pod_launcher_role.py b/chart/tests/test_celery_kubernetes_pod_launcher_role.py index 535be11fafabc..952bc390b2786 100644 --- a/chart/tests/test_celery_kubernetes_pod_launcher_role.py +++ b/chart/tests/test_celery_kubernetes_pod_launcher_role.py @@ -18,6 +18,7 @@ import unittest import jmespath + from tests.helm_template_generator import render_chart diff --git a/chart/tests/test_chart_quality.py b/chart/tests/test_chart_quality.py index 32237cb7aa43d..5f17165ea3dea 100644 --- a/chart/tests/test_chart_quality.py +++ b/chart/tests/test_chart_quality.py @@ -18,11 +18,10 @@ import json import os import unittest -import yaml +import yaml from jsonschema import validate - CHART_FOLDER = os.path.dirname(os.path.dirname(__file__)) diff --git a/chart/tests/test_dags_persistent_volume_claim.py b/chart/tests/test_dags_persistent_volume_claim.py index 069a0cd5998d2..946c40fe7c319 100644 --- a/chart/tests/test_dags_persistent_volume_claim.py +++ b/chart/tests/test_dags_persistent_volume_claim.py @@ -18,6 +18,7 @@ import unittest import jmespath + from tests.helm_template_generator import render_chart diff --git a/chart/tests/test_flower_authorization.py b/chart/tests/test_flower_authorization.py index f0cc5b07b2af7..0520ddd28f273 100644 --- a/chart/tests/test_flower_authorization.py +++ b/chart/tests/test_flower_authorization.py @@ -18,6 +18,7 @@ import unittest import jmespath + from tests.helm_template_generator import render_chart diff --git a/chart/tests/test_git_sync_scheduler.py b/chart/tests/test_git_sync_scheduler.py index a01c0f216e8f4..068f36c83cf71 100644 --- a/chart/tests/test_git_sync_scheduler.py +++ b/chart/tests/test_git_sync_scheduler.py @@ -18,6 +18,7 @@ import unittest import jmespath + from tests.helm_template_generator import render_chart diff --git a/chart/tests/test_git_sync_webserver.py b/chart/tests/test_git_sync_webserver.py index 30c3f330a48a8..75ec51bf1fd83 100644 --- a/chart/tests/test_git_sync_webserver.py +++ b/chart/tests/test_git_sync_webserver.py @@ -18,6 +18,7 @@ import unittest import jmespath + from tests.helm_template_generator import render_chart diff --git a/chart/tests/test_git_sync_worker.py b/chart/tests/test_git_sync_worker.py index a70d311a121b1..e5036d76a4384 100644 --- a/chart/tests/test_git_sync_worker.py +++ b/chart/tests/test_git_sync_worker.py @@ -18,6 +18,7 @@ import unittest import jmespath + from tests.helm_template_generator import render_chart diff --git a/chart/tests/test_migrate_database_job.py b/chart/tests/test_migrate_database_job.py index 0524315d5e608..4b92acaf60100 100644 --- a/chart/tests/test_migrate_database_job.py +++ b/chart/tests/test_migrate_database_job.py @@ -18,6 +18,7 @@ import unittest import jmespath + from tests.helm_template_generator import render_chart diff --git a/chart/tests/test_pod_template_file.py b/chart/tests/test_pod_template_file.py index 0673e08ba7b18..d9334deb07d31 100644 --- a/chart/tests/test_pod_template_file.py +++ b/chart/tests/test_pod_template_file.py @@ -17,10 +17,11 @@ import unittest from os import remove -from os.path import realpath, dirname +from os.path import dirname, realpath from shutil import copyfile import jmespath + from tests.helm_template_generator import render_chart ROOT_FOLDER = realpath(dirname(realpath(__file__)) + "/..") diff --git a/chart/tests/test_scheduler.py b/chart/tests/test_scheduler.py index 976984823d55d..eb5225e35c389 100644 --- a/chart/tests/test_scheduler.py +++ b/chart/tests/test_scheduler.py @@ -18,6 +18,7 @@ import unittest import 
jmespath + from tests.helm_template_generator import render_chart diff --git a/chart/tests/test_worker.py b/chart/tests/test_worker.py index 2fc6d17c48162..9b3515ef81562 100644 --- a/chart/tests/test_worker.py +++ b/chart/tests/test_worker.py @@ -18,6 +18,7 @@ import unittest import jmespath + from tests.helm_template_generator import render_chart diff --git a/dags/test_dag.py b/dags/test_dag.py index 8a1695f310e3f..75f279c5a1dd0 100644 --- a/dags/test_dag.py +++ b/dags/test_dag.py @@ -23,17 +23,11 @@ from airflow.utils.dates import days_ago now = datetime.now() -now_to_the_hour = ( - now - timedelta(0, 0, 0, 0, 0, 3) -).replace(minute=0, second=0, microsecond=0) +now_to_the_hour = (now - timedelta(0, 0, 0, 0, 0, 3)).replace(minute=0, second=0, microsecond=0) START_DATE = now_to_the_hour DAG_NAME = 'test_dag_v1' -default_args = { - 'owner': 'airflow', - 'depends_on_past': True, - 'start_date': days_ago(2) -} +default_args = {'owner': 'airflow', 'depends_on_past': True, 'start_date': days_ago(2)} dag = DAG(DAG_NAME, schedule_interval='*/10 * * * *', default_args=default_args) run_this_1 = DummyOperator(task_id='run_this_1', dag=dag) diff --git a/dev/airflow-github b/dev/airflow-github index 17490f7c55ff8..10cfdfa862133 100755 --- a/dev/airflow-github +++ b/dev/airflow-github @@ -29,8 +29,7 @@ from github import Github from github.Issue import Issue from github.PullRequest import PullRequest -GIT_COMMIT_FIELDS = ['id', 'author_name', - 'author_email', 'date', 'subject', 'body'] +GIT_COMMIT_FIELDS = ['id', 'author_name', 'author_email', 'date', 'subject', 'body'] GIT_LOG_FORMAT = '%x1f'.join(['%h', '%an', '%ae', '%ad', '%s', '%b']) + '%x1e' pr_title_re = re.compile(r".*\((#[0-9]{1,6})\)$") @@ -131,17 +130,20 @@ def cli(): @cli.command(short_help='Compare a GitHub target version against git merges') @click.argument('target_version') @click.argument('github-token', envvar='GITHUB_TOKEN') -@click.option('--previous-version', - 'previous_version', - help="Specify the previous tag on the working branch to limit" - " searching for few commits to find the cherry-picked commits") +@click.option( + '--previous-version', + 'previous_version', + help="Specify the previous tag on the working branch to limit" + " searching for few commits to find the cherry-picked commits", +) @click.option('--unmerged', 'show_uncherrypicked_only', help="Show unmerged issues only", is_flag=True) def compare(target_version, github_token, previous_version=None, show_uncherrypicked_only=False): repo = git.Repo(".", search_parent_directories=True) github_handler = Github(github_token) - milestone_issues: List[Issue] = list(github_handler.search_issues( - f"repo:apache/airflow milestone:\"Airflow {target_version}\"")) + milestone_issues: List[Issue] = list( + github_handler.search_issues(f"repo:apache/airflow milestone:\"Airflow {target_version}\"") + ) num_cherrypicked = 0 num_uncherrypicked = Counter() @@ -151,13 +153,16 @@ def compare(target_version, github_token, previous_version=None, show_uncherrypi # !s forces as string formatstr = "{id:<8}|{typ!s:<15}|{status!s}|{description:<83.83}|{merged:<6}|{commit:>9.7}" - print(formatstr.format( - id="ISSUE", - typ="TYPE", - status="STATUS".ljust(10), - description="DESCRIPTION", - merged="MERGED", - commit="COMMIT")) + print( + formatstr.format( + id="ISSUE", + typ="TYPE", + status="STATUS".ljust(10), + description="DESCRIPTION", + merged="MERGED", + commit="COMMIT", + ) + ) for issue in milestone_issues: commit_in_master = get_commit_in_master_associated_with_pr(repo, issue) @@ 
-180,13 +185,17 @@ def compare(target_version, github_token, previous_version=None, show_uncherrypi description=issue.title, ) - print(formatstr.format( - **fields, - merged=cherrypicked, - commit=commit_in_master if commit_in_master else "")) + print( + formatstr.format( + **fields, merged=cherrypicked, commit=commit_in_master if commit_in_master else "" + ) + ) - print("Commits on branch: {:d}, {:d} ({}) yet to be cherry-picked".format( - num_cherrypicked, sum(num_uncherrypicked.values()), dict(num_uncherrypicked))) + print( + "Commits on branch: {:d}, {:d} ({}) yet to be cherry-picked".format( + num_cherrypicked, sum(num_uncherrypicked.values()), dict(num_uncherrypicked) + ) + ) @cli.command(short_help='Build a CHANGELOG grouped by GitHub Issue type') @@ -196,8 +205,7 @@ def compare(target_version, github_token, previous_version=None, show_uncherrypi def changelog(previous_version, target_version, github_token): repo = git.Repo(".", search_parent_directories=True) # Get a list of issues/PRs that have been committed on the current branch. - log_args = [ - f'--format={GIT_LOG_FORMAT}', previous_version + ".." + target_version] + log_args = [f'--format={GIT_LOG_FORMAT}', previous_version + ".." + target_version] log = repo.git.log(*log_args) log = log.strip('\n\x1e').split("\x1e") @@ -221,6 +229,7 @@ def changelog(previous_version, target_version, github_token): if __name__ == "__main__": import doctest + (failure_count, test_count) = doctest.testmod() if failure_count: sys.exit(-1) diff --git a/dev/airflow-license b/dev/airflow-license index 3487309ea38d3..2aac7e59fd5f0 100755 --- a/dev/airflow-license +++ b/dev/airflow-license @@ -24,10 +24,22 @@ import string import slugify # order is important -_licenses = {'MIT': ['Permission is hereby granted free of charge', 'The above copyright notice and this permission notice shall'], - 'BSD-3': ['Redistributions of source code must retain the above copyright', 'Redistributions in binary form must reproduce the above copyright', 'specific prior written permission'], - 'BSD-2': ['Redistributions of source code must retain the above copyright', 'Redistributions in binary form must reproduce the above copyright'], - 'AL': ['http://www.apache.org/licenses/LICENSE-2.0']} +_licenses = { + 'MIT': [ + 'Permission is hereby granted free of charge', + 'The above copyright notice and this permission notice shall', + ], + 'BSD-3': [ + 'Redistributions of source code must retain the above copyright', + 'Redistributions in binary form must reproduce the above copyright', + 'specific prior written permission', + ], + 'BSD-2': [ + 'Redistributions of source code must retain the above copyright', + 'Redistributions in binary form must reproduce the above copyright', + ], + 'AL': ['http://www.apache.org/licenses/LICENSE-2.0'], +} def get_notices(): @@ -56,16 +68,14 @@ def parse_license_file(project_name): if __name__ == "__main__": - print("{:<30}|{:<50}||{:<20}||{:<10}" - .format("PROJECT", "URL", "LICENSE TYPE DEFINED", "DETECTED")) + print("{:<30}|{:<50}||{:<20}||{:<10}".format("PROJECT", "URL", "LICENSE TYPE DEFINED", "DETECTED")) notices = get_notices() for notice in notices: notice = notice[0] license = parse_license_file(notice[1]) - print("{:<30}|{:<50}||{:<20}||{:<10}" - .format(notice[1], notice[2][:50], notice[0], license)) + print("{:<30}|{:<50}||{:<20}||{:<10}".format(notice[1], notice[2][:50], notice[0], license)) file_count = len([name for name in os.listdir("../licenses")]) print("Defined licenses: {} Files found: {}".format(len(notices), 
file_count)) diff --git a/dev/send_email.py b/dev/send_email.py index 6d698f7ede669..0e9c8b7a45f02 100755 --- a/dev/send_email.py +++ b/dev/send_email.py @@ -37,10 +37,7 @@ SMTP_PORT = 587 SMTP_SERVER = "mail-relay.apache.org" -MAILING_LIST = { - "dev": "dev@airflow.apache.org", - "users": "users@airflow.apache.org" -} +MAILING_LIST = {"dev": "dev@airflow.apache.org", "users": "users@airflow.apache.org"} def string_comma_to_list(message: str) -> List[str]: @@ -51,9 +48,13 @@ def string_comma_to_list(message: str) -> List[str]: def send_email( - smtp_server: str, smpt_port: int, - username: str, password: str, - sender_email: str, receiver_email: Union[str, List], message: str, + smtp_server: str, + smpt_port: int, + username: str, + password: str, + sender_email: str, + receiver_email: Union[str, List], + message: str, ): """ Send a simple text email (SMTP) @@ -92,8 +93,7 @@ def show_message(entity: str, message: str): def inter_send_email( - username: str, password: str, sender_email: str, receiver_email: Union[str, List], - message: str + username: str, password: str, sender_email: str, receiver_email: Union[str, List], message: str ): """ Send email using SMTP @@ -104,7 +104,13 @@ def inter_send_email( try: send_email( - SMTP_SERVER, SMTP_PORT, username, password, sender_email, receiver_email, message, + SMTP_SERVER, + SMTP_PORT, + username, + password, + sender_email, + receiver_email, + message, ) click.secho("✅ Email sent successfully", fg="green") except smtplib.SMTPAuthenticationError: @@ -117,10 +123,8 @@ class BaseParameters: """ Base Class to send emails using Apache Creds and for Jinja templating """ - def __init__( - self, name=None, email=None, username=None, password=None, - version=None, version_rc=None - ): + + def __init__(self, name=None, email=None, username=None, password=None, version=None, version_rc=None): self.name = name self.email = email self.username = username @@ -136,42 +140,53 @@ def __repr__(self): @click.group(context_settings=dict(help_option_names=["-h", "--help"])) @click.pass_context @click.option( - "-e", "--apache_email", + "-e", + "--apache_email", prompt="Apache Email", - envvar="APACHE_EMAIL", show_envvar=True, + envvar="APACHE_EMAIL", + show_envvar=True, help="Your Apache email will be used for SMTP From", - required=True + required=True, ) @click.option( - "-u", "--apache_username", + "-u", + "--apache_username", prompt="Apache Username", - envvar="APACHE_USERNAME", show_envvar=True, + envvar="APACHE_USERNAME", + show_envvar=True, help="Your LDAP Apache username", required=True, ) @click.password_option( # type: ignore - "-p", "--apache_password", + "-p", + "--apache_password", prompt="Apache Password", - envvar="APACHE_PASSWORD", show_envvar=True, + envvar="APACHE_PASSWORD", + show_envvar=True, help="Your LDAP Apache password", required=True, ) @click.option( - "-v", "--version", + "-v", + "--version", prompt="Version", - envvar="AIRFLOW_VERSION", show_envvar=True, + envvar="AIRFLOW_VERSION", + show_envvar=True, help="Release Version", required=True, ) @click.option( - "-rc", "--version_rc", + "-rc", + "--version_rc", prompt="Version (with RC)", - envvar="AIRFLOW_VERSION_RC", show_envvar=True, + envvar="AIRFLOW_VERSION_RC", + show_envvar=True, help="Release Candidate Version", required=True, ) @click.option( # type: ignore - "-n", "--name", + "-n", + "--name", prompt="Your Name", default=lambda: os.environ.get('USER', ''), show_default="Current User", @@ -180,9 +195,13 @@ def __repr__(self): required=True, ) def cli( - ctx, apache_email: str, - 
apache_username: str, apache_password: str, version: str, version_rc: str, - name: str + ctx, + apache_email: str, + apache_username: str, + apache_password: str, + version: str, + version_rc: str, + name: str, ): """ 🚀 CLI to send emails for the following: @@ -232,7 +251,8 @@ def vote(base_parameters, receiver_email: str): @cli.command("result") @click.option( - "-re", "--receiver_email", + "-re", + "--receiver_email", default=MAILING_LIST.get("dev"), type=click.STRING, prompt="The receiver email (To:)", @@ -258,25 +278,23 @@ def vote(base_parameters, receiver_email: str): @click.pass_obj def result( base_parameters, - receiver_email: str, vote_bindings: str, vote_nonbindings: str, vote_negatives: str, + receiver_email: str, + vote_bindings: str, + vote_nonbindings: str, + vote_negatives: str, ): """ Send email with results of voting on RC """ template_file = "templates/result_email.j2" base_parameters.template_arguments["receiver_email"] = receiver_email - base_parameters.template_arguments["vote_bindings"] = string_comma_to_list( - vote_bindings - ) - base_parameters.template_arguments["vote_nonbindings"] = string_comma_to_list( - vote_nonbindings - ) - base_parameters.template_arguments["vote_negatives"] = string_comma_to_list( - vote_negatives - ) + base_parameters.template_arguments["vote_bindings"] = string_comma_to_list(vote_bindings) + base_parameters.template_arguments["vote_nonbindings"] = string_comma_to_list(vote_nonbindings) + base_parameters.template_arguments["vote_negatives"] = string_comma_to_list(vote_negatives) message = render_template(template_file, **base_parameters.template_arguments) inter_send_email( - base_parameters.username, base_parameters.password, + base_parameters.username, + base_parameters.password, base_parameters.template_arguments["sender_email"], base_parameters.template_arguments["receiver_email"], message, @@ -288,12 +306,10 @@ def result( "--receiver_email", default=",".join(list(MAILING_LIST.values())), prompt="The receiver email (To:)", - help="Receiver's email address. If more than 1, separate them by comma" + help="Receiver's email address. 
If more than 1, separate them by comma", ) @click.pass_obj -def announce( - base_parameters, receiver_email: str -): +def announce(base_parameters, receiver_email: str): """ Send email to announce release of the new version """ @@ -304,7 +320,8 @@ def announce( message = render_template(template_file, **base_parameters.template_arguments) inter_send_email( - base_parameters.username, base_parameters.password, + base_parameters.username, + base_parameters.password, base_parameters.template_arguments["sender_email"], base_parameters.template_arguments["receiver_email"], message, @@ -312,14 +329,12 @@ def announce( if click.confirm("Show Slack message for announcement?", default=True): base_parameters.template_arguments["slack_rc"] = False - slack_msg = render_template( - "templates/slack.j2", **base_parameters.template_arguments) + slack_msg = render_template("templates/slack.j2", **base_parameters.template_arguments) show_message("Slack", slack_msg) if click.confirm("Show Twitter message for announcement?", default=True): - twitter_msg = render_template( - "templates/twitter.j2", **base_parameters.template_arguments) + twitter_msg = render_template("templates/twitter.j2", **base_parameters.template_arguments) show_message("Twitter", twitter_msg) if __name__ == '__main__': - cli() # pylint: disable=no-value-for-parameter + cli() # pylint: disable=no-value-for-parameter diff --git a/docs/build_docs.py b/docs/build_docs.py index 8f565d8e1993f..2799a46169d78 100755 --- a/docs/build_docs.py +++ b/docs/build_docs.py @@ -87,8 +87,13 @@ def __lt__(self, other): line_no_b = other.line_no or 0 context_line_a = self.context_line or '' context_line_b = other.context_line or '' - return (file_path_a, line_no_a, context_line_a, self.spelling, self.message) < \ - (file_path_b, line_no_b, context_line_b, other.spelling, other.message) + return (file_path_a, line_no_a, context_line_a, self.spelling, self.message) < ( + file_path_b, + line_no_b, + context_line_b, + other.spelling, + other.message, + ) build_errors: List[DocBuildError] = [] @@ -221,7 +226,7 @@ def generate_build_error(path, line_no, operator_name): f".. seealso::\n" f" For more information on how to use this operator, take a look at the guide:\n" f" :ref:`howto/operator:{operator_name}`\n" - ) + ), ) # Extract operators for which there are existing .rst guides @@ -261,9 +266,7 @@ def generate_build_error(path, line_no, operator_name): if f":ref:`howto/operator:{existing_operator}`" in ast.get_docstring(class_def): continue - build_errors.append( - generate_build_error(py_module_path, class_def.lineno, existing_operator) - ) + build_errors.append(generate_build_error(py_module_path, class_def.lineno, existing_operator)) def assert_file_not_contains(file_path: str, pattern: str, message: str) -> None: @@ -325,8 +328,9 @@ def check_class_links_in_operators_and_hooks_ref() -> None: airflow_modules = find_modules() - find_modules(deprecated_only=True) airflow_modules = { - o for o in airflow_modules if any(f".{d}." in o for d in - ["operators", "hooks", "sensors", "transfers"]) + o + for o in airflow_modules + if any(f".{d}." 
in o for d in ["operators", "hooks", "sensors", "transfers"]) } missing_modules = airflow_modules - current_modules_in_file @@ -359,20 +363,12 @@ def check_guide_links_in_operators_and_hooks_ref() -> None: if "_partials" not in guide ] # Remove partials and index - all_guides = [ - guide - for guide in all_guides - if "/_partials/" not in guide and not guide.endswith("index") - ] + all_guides = [guide for guide in all_guides if "/_partials/" not in guide and not guide.endswith("index")] with open(os.path.join(DOCS_DIR, "operators-and-hooks-ref.rst")) as ref_file: content = ref_file.read() - missing_guides = [ - guide - for guide in all_guides - if guide not in content - ] + missing_guides = [guide for guide in all_guides if guide not in content] if missing_guides: guide_text_list = "\n".join(f":doc:`How to use <{guide}>`" for guide in missing_guides) @@ -403,7 +399,7 @@ def check_exampleinclude_for_example_dags(): message=( "literalinclude directive is prohibited for example DAGs. \n" "You should use the exampleinclude directive to include example DAGs." - ) + ), ) @@ -418,7 +414,7 @@ def check_enforce_code_block(): message=( "We recommend using the code-block directive instead of the code directive. " "The code-block directive is more feature-full." - ) + ), ) @@ -444,10 +440,12 @@ def check_google_guides(): doc_files = glob(f"{DOCS_DIR}/howto/operator/google/**/*.rst", recursive=True) doc_names = {f.split("/")[-1].rsplit(".")[0] for f in doc_files} - operators_files = chain(*[ - glob(f"{ROOT_PACKAGE_DIR}/providers/google/*/{resource_type}/*.py") - for resource_type in ["operators", "sensors", "transfers"] - ]) + operators_files = chain( + *[ + glob(f"{ROOT_PACKAGE_DIR}/providers/google/*/{resource_type}/*.py") + for resource_type in ["operators", "sensors", "transfers"] + ] + ) operators_files = (f for f in operators_files if not f.endswith("__init__.py")) operator_names = {f.split("/")[-1].rsplit(".")[0] for f in operators_files} @@ -495,6 +493,7 @@ def guess_lexer_for_filename(filename): lexer = get_lexer_for_filename(filename) except ClassNotFound: from pygments.lexers.special import TextLexer + lexer = TextLexer() return lexer @@ -504,6 +503,7 @@ def guess_lexer_for_filename(filename): with suppress(ImportError): import pygments from pygments.formatters.terminal import TerminalFormatter + code = pygments.highlight( code=code, formatter=TerminalFormatter(), lexer=guess_lexer_for_filename(file_path) ) @@ -541,9 +541,7 @@ def parse_sphinx_warnings(warning_text: str) -> List[DocBuildError]: except Exception: # noqa pylint: disable=broad-except # If an exception occurred while parsing the warning message, display the raw warning message. 
sphinx_build_errors.append( - DocBuildError( - file_path=None, line_no=None, message=sphinx_warning - ) + DocBuildError(file_path=None, line_no=None, message=sphinx_warning) ) else: sphinx_build_errors.append(DocBuildError(file_path=None, line_no=None, message=sphinx_warning)) @@ -635,7 +633,7 @@ def check_spelling() -> None: "-D", # override the extensions because one of them throws an error on the spelling builder f"extensions={','.join(extensions_to_use)}", ".", # path to documentation source files - "_build/spelling" + "_build/spelling", ] print("Executing cmd: ", " ".join([shlex.quote(c) for c in build_cmd])) @@ -643,8 +641,12 @@ def check_spelling() -> None: if completed_proc.returncode != 0: spelling_errors.append( SpellingError( - file_path=None, line_no=None, spelling=None, suggestion=None, context_line=None, - message=f"Sphinx spellcheck returned non-zero exit status: {completed_proc.returncode}." + file_path=None, + line_no=None, + spelling=None, + suggestion=None, + context_line=None, + message=f"Sphinx spellcheck returned non-zero exit status: {completed_proc.returncode}.", ) ) @@ -710,10 +712,10 @@ def print_build_errors_and_exit(message) -> None: parser = argparse.ArgumentParser(description='Builds documentation and runs spell checking') -parser.add_argument('--docs-only', dest='docs_only', action='store_true', - help='Only build documentation') -parser.add_argument('--spellcheck-only', dest='spellcheck_only', action='store_true', - help='Only perform spellchecking') +parser.add_argument('--docs-only', dest='docs_only', action='store_true', help='Only build documentation') +parser.add_argument( + '--spellcheck-only', dest='spellcheck_only', action='store_true', help='Only perform spellchecking' +) args = parser.parse_args() diff --git a/docs/conf.py b/docs/conf.py index a035b596f878e..aab471a28f1a3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -142,14 +142,9 @@ "sphinxcontrib.spelling", ] -autodoc_default_options = { - 'show-inheritance': True, - 'members': True -} +autodoc_default_options = {'show-inheritance': True, 'members': True} -jinja_contexts = { - 'config_ctx': {"configs": default_config_yaml()} -} +jinja_contexts = {'config_ctx': {"configs": default_config_yaml()}} viewcode_follow_imported_members = True @@ -213,7 +208,7 @@ # Templates or partials 'autoapi_templates', 'howto/operator/google/_partials', - 'howto/operator/microsoft/_partials' + 'howto/operator/microsoft/_partials', ] ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) @@ -225,7 +220,9 @@ def _get_rst_filepath_from_path(filepath: str): elif os.path.isfile(filepath) and filepath.endswith('/__init__.py'): result = filepath.rpartition("/")[0] else: - result = filepath.rpartition(".",)[0] + result = filepath.rpartition( + ".", + )[0] result += "/index.rst" result = f"_api/{os.path.relpath(result, ROOT_DIR)}" @@ -252,8 +249,7 @@ def _get_rst_filepath_from_path(filepath: str): } providers_package_indexes = { - f"_api/{os.path.relpath(name, ROOT_DIR)}/index.rst" - for name in providers_packages_roots + f"_api/{os.path.relpath(name, ROOT_DIR)}/index.rst" for name in providers_packages_roots } exclude_patterns.extend(providers_package_indexes) @@ -269,10 +265,7 @@ def _get_rst_filepath_from_path(filepath: str): for p in excluded_packages_in_providers for path in glob(f"{p}/**/*", recursive=True) } -excluded_files_in_providers |= { - _get_rst_filepath_from_path(name) - for name in excluded_packages_in_providers -} +excluded_files_in_providers |= {_get_rst_filepath_from_path(name) 
for name in excluded_packages_in_providers} exclude_patterns.extend(excluded_files_in_providers) @@ -437,10 +430,8 @@ def _get_rst_filepath_from_path(filepath: str): latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # 'preamble': '', } # type: Dict[str,str] @@ -449,8 +440,7 @@ def _get_rst_filepath_from_path(filepath: str): # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - ('index', 'Airflow.tex', 'Airflow Documentation', - 'Apache Airflow', 'manual'), + ('index', 'Airflow.tex', 'Airflow Documentation', 'Apache Airflow', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of @@ -478,10 +468,7 @@ def _get_rst_filepath_from_path(filepath: str): # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'airflow', 'Airflow Documentation', - ['Apache Airflow'], 1) -] +man_pages = [('index', 'airflow', 'Airflow Documentation', ['Apache Airflow'], 1)] # If true, show URL addresses after external links. # man_show_urls = False @@ -492,12 +479,17 @@ def _get_rst_filepath_from_path(filepath: str): # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) -texinfo_documents = [( - 'index', 'Airflow', 'Airflow Documentation', - 'Apache Airflow', 'Airflow', - 'Airflow is a system to programmatically author, schedule and monitor data pipelines.', - 'Miscellaneous' -), ] +texinfo_documents = [ + ( + 'index', + 'Airflow', + 'Airflow Documentation', + 'Apache Airflow', + 'Airflow', + 'Airflow is a system to programmatically author, schedule and monitor data pipelines.', + 'Miscellaneous', + ), +] # Documents to append as an appendix to all manuals. # texinfo_appendices = [] @@ -546,10 +538,7 @@ def _get_rst_filepath_from_path(filepath: str): redirects_file = 'redirects.txt' # -- Options for redoc docs ---------------------------------- -OPENAPI_FILE = os.path.join( - os.path.dirname(__file__), - "..", "airflow", "api_connexion", "openapi", "v1.yaml" -) +OPENAPI_FILE = os.path.join(os.path.dirname(__file__), "..", "airflow", "api_connexion", "openapi", "v1.yaml") redoc = [ { 'name': 'Airflow REST API', @@ -558,7 +547,7 @@ def _get_rst_filepath_from_path(filepath: str): 'opts': { 'hide-hostname': True, 'no-auto-auth': True, - } + }, }, ] diff --git a/docs/exts/docroles.py b/docs/exts/docroles.py index 5dd71cd1bfc6d..a1ac0cb2ecbee 100644 --- a/docs/exts/docroles.py +++ b/docs/exts/docroles.py @@ -52,21 +52,21 @@ def get_template_field(env, fullname): template_fields = getattr(clazz, "template_fields") if not template_fields: - raise RoleException( - f"Could not find the template fields for {classname} class in {modname} module." 
- ) + raise RoleException(f"Could not find the template fields for {classname} class in {modname} module.") return list(template_fields) -def template_field_role(app, - typ, # pylint: disable=unused-argument - rawtext, - text, - lineno, - inliner, - options=None, # pylint: disable=unused-argument - content=None): # pylint: disable=unused-argument +def template_field_role( + app, + typ, # pylint: disable=unused-argument + rawtext, + text, + lineno, + inliner, + options=None, # pylint: disable=unused-argument + content=None, +): # pylint: disable=unused-argument """ A role that allows you to include a list of template fields in the middle of the text. This is especially useful when writing guides describing how to use the operator. @@ -90,7 +90,10 @@ def template_field_role(app, try: template_fields = get_template_field(app.env, text) except RoleException as e: - msg = inliner.reporter.error(f"invalid class name {text} \n{e}", line=lineno) + msg = inliner.reporter.error( + f"invalid class name {text} \n{e}", + line=lineno, + ) prb = inliner.problematic(rawtext, rawtext, msg) return [prb], [msg] @@ -106,6 +109,7 @@ def template_field_role(app, def setup(app): """Sets the extension up""" from docutils.parsers.rst import roles # pylint: disable=wrong-import-order + roles.register_local_role("template-fields", partial(template_field_role, app)) return {"version": "builtin", "parallel_read_safe": True, "parallel_write_safe": True} diff --git a/docs/exts/exampleinclude.py b/docs/exts/exampleinclude.py index 774098b6cb32a..141ca8202bef6 100644 --- a/docs/exts/exampleinclude.py +++ b/docs/exts/exampleinclude.py @@ -34,6 +34,7 @@ try: import sphinx_airflow_theme # pylint: disable=unused-import + airflow_theme_is_available = True except ImportError: airflow_theme_is_available = False @@ -147,8 +148,9 @@ def register_source(app, env, modname): try: analyzer = ModuleAnalyzer.for_module(modname) except Exception as ex: # pylint: disable=broad-except - logger.info("Module \"%s\" could not be loaded. Full source will not be available. \"%s\"", - modname, ex) + logger.info( + "Module \"%s\" could not be loaded. Full source will not be available. 
\"%s\"", modname, ex + ) env._viewcode_modules[modname] = False return False @@ -168,6 +170,8 @@ def register_source(app, env, modname): env._viewcode_modules[modname] = entry return True + + # pylint: enable=protected-access @@ -232,6 +236,8 @@ def doctree_read(app, doctree): onlynode = create_node(env, relative_path, show_button) objnode.replace_self(onlynode) + + # pylint: enable=protected-access diff --git a/kubernetes_tests/test_kubernetes_executor.py b/kubernetes_tests/test_kubernetes_executor.py index de2726aca7f74..ba4900f01c180 100644 --- a/kubernetes_tests/test_kubernetes_executor.py +++ b/kubernetes_tests/test_kubernetes_executor.py @@ -27,7 +27,7 @@ from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry -CLUSTER_FORWARDED_PORT = (os.environ.get('CLUSTER_FORWARDED_PORT') or "8080") +CLUSTER_FORWARDED_PORT = os.environ.get('CLUSTER_FORWARDED_PORT') or "8080" KUBERNETES_HOST_PORT = (os.environ.get('CLUSTER_HOST') or "localhost") + ":" + CLUSTER_FORWARDED_PORT print() @@ -36,7 +36,6 @@ class TestKubernetesExecutor(unittest.TestCase): - @staticmethod def _describe_resources(namespace: str): print("=" * 80) @@ -93,8 +92,7 @@ def setUp(self): def tearDown(self): self.session.close() - def monitor_task(self, host, execution_date, dag_id, task_id, expected_final_state, - timeout): + def monitor_task(self, host, execution_date, dag_id, task_id, expected_final_state, timeout): tries = 0 state = '' max_tries = max(int(timeout / 5), 1) @@ -103,9 +101,10 @@ def monitor_task(self, host, execution_date, dag_id, task_id, expected_final_sta time.sleep(5) # Trigger a new dagrun try: - get_string = \ - f'http://{host}/api/experimental/dags/{dag_id}/' \ + get_string = ( + f'http://{host}/api/experimental/dags/{dag_id}/' f'dag_runs/{execution_date}/tasks/{task_id}' + ) print(f"Calling [monitor_task]#1 {get_string}") result = self.session.get(get_string) if result.status_code == 404: @@ -129,18 +128,14 @@ def monitor_task(self, host, execution_date, dag_id, task_id, expected_final_sta print(f"The expected state is wrong {state} != {expected_final_state} (expected)!") self.assertEqual(state, expected_final_state) - def ensure_dag_expected_state(self, host, execution_date, dag_id, - expected_final_state, - timeout): + def ensure_dag_expected_state(self, host, execution_date, dag_id, expected_final_state, timeout): tries = 0 state = '' max_tries = max(int(timeout / 5), 1) # Wait some time for the operator to complete while tries < max_tries: time.sleep(5) - get_string = \ - f'http://{host}/api/experimental/dags/{dag_id}/' \ - f'dag_runs/{execution_date}' + get_string = f'http://{host}/api/experimental/dags/{dag_id}/' f'dag_runs/{execution_date}' print(f"Calling {get_string}") # Trigger a new dagrun result = self.session.get(get_string) @@ -148,8 +143,7 @@ def ensure_dag_expected_state(self, host, execution_date, dag_id, result_json = result.json() print(f"Received: {result}") state = result_json['state'] - check_call( - ["echo", f"Attempt {tries}: Current state of dag is {state}"]) + check_call(["echo", f"Attempt {tries}: Current state of dag is {state}"]) print(f"Attempt {tries}: Current state of dag is {state}") if state == expected_final_state: @@ -162,8 +156,7 @@ def ensure_dag_expected_state(self, host, execution_date, dag_id, # Maybe check if we can retrieve the logs, but then we need to extend the API def start_dag(self, dag_id, host): - get_string = f'http://{host}/api/experimental/' \ - f'dags/{dag_id}/paused/false' + get_string = f'http://{host}/api/experimental/' 
f'dags/{dag_id}/paused/false' print(f"Calling [start_dag]#1 {get_string}") result = self.session.get(get_string) try: @@ -171,10 +164,8 @@ def start_dag(self, dag_id, host): except ValueError: result_json = str(result) print(f"Received [start_dag]#1 {result_json}") - self.assertEqual(result.status_code, 200, "Could not enable DAG: {result}" - .format(result=result_json)) - post_string = f'http://{host}/api/experimental/' \ - f'dags/{dag_id}/dag_runs' + self.assertEqual(result.status_code, 200, f"Could not enable DAG: {result_json}") + post_string = f'http://{host}/api/experimental/' f'dags/{dag_id}/dag_runs' print(f"Calling [start_dag]#2 {post_string}") # Trigger a new dagrun result = self.session.post(post_string, json={}) @@ -183,17 +174,18 @@ def start_dag(self, dag_id, host): except ValueError: result_json = str(result) print(f"Received [start_dag]#2 {result_json}") - self.assertEqual(result.status_code, 200, "Could not trigger a DAG-run: {result}" - .format(result=result_json)) + self.assertEqual(result.status_code, 200, f"Could not trigger a DAG-run: {result_json}") time.sleep(1) get_string = f'http://{host}/api/experimental/latest_runs' print(f"Calling [start_dag]#3 {get_string}") result = self.session.get(get_string) - self.assertEqual(result.status_code, 200, "Could not get the latest DAG-run:" - " {result}" - .format(result=result.json())) + self.assertEqual( + result.status_code, + 200, + "Could not get the latest DAG-run:" " {result}".format(result=result.json()), + ) result_json = result.json() print(f"Received: [start_dag]#3 {result_json}") return result_json @@ -217,16 +209,22 @@ def test_integration_run_dag(self): print(f"Found the job with execution date {execution_date}") # Wait some time for the operator to complete - self.monitor_task(host=host, - execution_date=execution_date, - dag_id=dag_id, - task_id='start_task', - expected_final_state='success', timeout=300) + self.monitor_task( + host=host, + execution_date=execution_date, + dag_id=dag_id, + task_id='start_task', + expected_final_state='success', + timeout=300, + ) - self.ensure_dag_expected_state(host=host, - execution_date=execution_date, - dag_id=dag_id, - expected_final_state='success', timeout=300) + self.ensure_dag_expected_state( + host=host, + execution_date=execution_date, + dag_id=dag_id, + expected_final_state='success', + timeout=300, + ) def test_integration_run_dag_with_scheduler_failure(self): host = KUBERNETES_HOST_PORT @@ -239,23 +237,32 @@ def test_integration_run_dag_with_scheduler_failure(self): time.sleep(10) # give time for pod to restart # Wait some time for the operator to complete - self.monitor_task(host=host, - execution_date=execution_date, - dag_id=dag_id, - task_id='start_task', - expected_final_state='success', timeout=300) - - self.monitor_task(host=host, - execution_date=execution_date, - dag_id=dag_id, - task_id='other_namespace_task', - expected_final_state='success', timeout=300) - - self.ensure_dag_expected_state(host=host, - execution_date=execution_date, - dag_id=dag_id, - expected_final_state='success', timeout=300) - - self.assertEqual(self._num_pods_in_namespace('test-namespace'), - 0, - "failed to delete pods in other namespace") + self.monitor_task( + host=host, + execution_date=execution_date, + dag_id=dag_id, + task_id='start_task', + expected_final_state='success', + timeout=300, + ) + + self.monitor_task( + host=host, + execution_date=execution_date, + dag_id=dag_id, + task_id='other_namespace_task', + expected_final_state='success', + timeout=300, + ) + + 
self.ensure_dag_expected_state( + host=host, + execution_date=execution_date, + dag_id=dag_id, + expected_final_state='success', + timeout=300, + ) + + self.assertEqual( + self._num_pods_in_namespace('test-namespace'), 0, "failed to delete pods in other namespace" + ) diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py index c78a821b3b30c..33eb553c06f26 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator.py +++ b/kubernetes_tests/test_kubernetes_pod_operator.py @@ -46,8 +46,7 @@ def create_context(task): dag = DAG(dag_id="dag") tzinfo = pendulum.timezone("Europe/Amsterdam") execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo) - task_instance = TaskInstance(task=task, - execution_date=execution_date) + task_instance = TaskInstance(task=task, execution_date=execution_date) return { "dag": dag, "ts": execution_date.isoformat(), @@ -57,7 +56,6 @@ def create_context(task): class TestKubernetesPodOperatorSystem(unittest.TestCase): - def get_current_task_name(self): # reverse test name to make pod name unique (it has limited length) return "_" + unittest.TestCase.id(self).replace(".", "_")[::-1] @@ -73,26 +71,30 @@ def setUp(self): 'name': ANY, 'annotations': {}, 'labels': { - 'foo': 'bar', 'kubernetes_pod_operator': 'True', + 'foo': 'bar', + 'kubernetes_pod_operator': 'True', 'airflow_version': airflow_version.replace('+', '-'), 'execution_date': '2016-01-01T0100000100-a2f50a31f', 'dag_id': 'dag', 'task_id': ANY, - 'try_number': '1'}, + 'try_number': '1', + }, }, 'spec': { 'affinity': {}, - 'containers': [{ - 'image': 'ubuntu:16.04', - 'args': ["echo 10"], - 'command': ["bash", "-cx"], - 'env': [], - 'envFrom': [], - 'resources': {}, - 'name': 'base', - 'ports': [], - 'volumeMounts': [], - }], + 'containers': [ + { + 'image': 'ubuntu:16.04', + 'args': ["echo 10"], + 'command': ["bash", "-cx"], + 'env': [], + 'envFrom': [], + 'resources': {}, + 'name': 'base', + 'ports': [], + 'volumeMounts': [], + } + ], 'hostNetwork': False, 'imagePullSecrets': [], 'initContainers': [], @@ -102,13 +104,14 @@ def setUp(self): 'serviceAccountName': 'default', 'tolerations': [], 'volumes': [], - } + }, } def tearDown(self) -> None: client = kube_client.get_kube_client(in_cluster=False) client.delete_collection_namespaced_pod(namespace="default") import time + time.sleep(1) def test_do_xcom_push_defaults_false(self): @@ -222,7 +225,7 @@ def test_pod_dnspolicy(self): in_cluster=False, do_xcom_push=False, hostnetwork=True, - dnspolicy=dns_policy + dnspolicy=dns_policy, ) context = create_context(k) k.execute(context) @@ -244,7 +247,7 @@ def test_pod_schedulername(self): task_id="task" + self.get_current_task_name(), in_cluster=False, do_xcom_push=False, - schedulername=scheduler_name + schedulername=scheduler_name, ) context = create_context(k) k.execute(context) @@ -253,9 +256,7 @@ def test_pod_schedulername(self): self.assertEqual(self.expected_pod, actual_pod) def test_pod_node_selectors(self): - node_selectors = { - 'beta.kubernetes.io/os': 'linux' - } + node_selectors = {'beta.kubernetes.io/os': 'linux'} k = KubernetesPodOperator( namespace='default', image="ubuntu:16.04", @@ -276,17 +277,8 @@ def test_pod_node_selectors(self): def test_pod_resources(self): resources = k8s.V1ResourceRequirements( - requests={ - 'memory': '64Mi', - 'cpu': '250m', - 'ephemeral-storage': '1Gi' - }, - limits={ - 'memory': '64Mi', - 'cpu': 0.25, - 'nvidia.com/gpu': None, - 'ephemeral-storage': '2Gi' - } + requests={'memory': '64Mi', 'cpu': '250m', 
'ephemeral-storage': '1Gi'}, + limits={'memory': '64Mi', 'cpu': 0.25, 'nvidia.com/gpu': None, 'ephemeral-storage': '2Gi'}, ) k = KubernetesPodOperator( namespace='default', @@ -304,17 +296,8 @@ def test_pod_resources(self): k.execute(context) actual_pod = self.api_client.sanitize_for_serialization(k.pod) self.expected_pod['spec']['containers'][0]['resources'] = { - 'requests': { - 'memory': '64Mi', - 'cpu': '250m', - 'ephemeral-storage': '1Gi' - }, - 'limits': { - 'memory': '64Mi', - 'cpu': 0.25, - 'nvidia.com/gpu': None, - 'ephemeral-storage': '2Gi' - } + 'requests': {'memory': '64Mi', 'cpu': '250m', 'ephemeral-storage': '1Gi'}, + 'limits': {'memory': '64Mi', 'cpu': 0.25, 'nvidia.com/gpu': None, 'ephemeral-storage': '2Gi'}, } self.assertEqual(self.expected_pod, actual_pod) @@ -325,11 +308,7 @@ def test_pod_affinity(self): 'nodeSelectorTerms': [ { 'matchExpressions': [ - { - 'key': 'beta.kubernetes.io/os', - 'operator': 'In', - 'values': ['linux'] - } + {'key': 'beta.kubernetes.io/os', 'operator': 'In', 'values': ['linux']} ] } ] @@ -375,30 +354,24 @@ def test_port(self): context = create_context(k) k.execute(context=context) actual_pod = self.api_client.sanitize_for_serialization(k.pod) - self.expected_pod['spec']['containers'][0]['ports'] = [{ - 'name': 'http', - 'containerPort': 80 - }] + self.expected_pod['spec']['containers'][0]['ports'] = [{'name': 'http', 'containerPort': 80}] self.assertEqual(self.expected_pod, actual_pod) def test_volume_mount(self): with mock.patch.object(PodLauncher, 'log') as mock_logger: volume_mount = k8s.V1VolumeMount( - name='test-volume', - mount_path='/tmp/test_volume', - sub_path=None, - read_only=False + name='test-volume', mount_path='/tmp/test_volume', sub_path=None, read_only=False ) volume = k8s.V1Volume( name='test-volume', - persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource( - claim_name='test-volume' - ) + persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(claim_name='test-volume'), ) - args = ["echo \"retrieved from mount\" > /tmp/test_volume/test.txt " - "&& cat /tmp/test_volume/test.txt"] + args = [ + "echo \"retrieved from mount\" > /tmp/test_volume/test.txt " + "&& cat /tmp/test_volume/test.txt" + ] k = KubernetesPodOperator( namespace='default', image="ubuntu:16.04", @@ -417,17 +390,12 @@ def test_volume_mount(self): mock_logger.info.assert_any_call('retrieved from mount') actual_pod = self.api_client.sanitize_for_serialization(k.pod) self.expected_pod['spec']['containers'][0]['args'] = args - self.expected_pod['spec']['containers'][0]['volumeMounts'] = [{ - 'name': 'test-volume', - 'mountPath': '/tmp/test_volume', - 'readOnly': False - }] - self.expected_pod['spec']['volumes'] = [{ - 'name': 'test-volume', - 'persistentVolumeClaim': { - 'claimName': 'test-volume' - } - }] + self.expected_pod['spec']['containers'][0]['volumeMounts'] = [ + {'name': 'test-volume', 'mountPath': '/tmp/test_volume', 'readOnly': False} + ] + self.expected_pod['spec']['volumes'] = [ + {'name': 'test-volume', 'persistentVolumeClaim': {'claimName': 'test-volume'}} + ] self.assertEqual(self.expected_pod, actual_pod) def test_run_as_user_root(self): @@ -604,9 +572,7 @@ def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start): from airflow.utils.state import State configmap_name = "test-config-map" - env_from = [k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource( - name=configmap_name - ))] + env_from = [k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name=configmap_name))] # WHEN k = 
KubernetesPodOperator( namespace='default', @@ -618,15 +584,13 @@ def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start): task_id="task" + self.get_current_task_name(), in_cluster=False, do_xcom_push=False, - env_from=env_from + env_from=env_from, ) # THEN mock_monitor.return_value = (State.SUCCESS, None) context = create_context(k) k.execute(context) - self.assertEqual( - mock_start.call_args[0][0].spec.containers[0].env_from, env_from - ) + self.assertEqual(mock_start.call_args[0][0].spec.containers[0].env_from, env_from) @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod") @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod") @@ -634,6 +598,7 @@ def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start): def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock): # GIVEN from airflow.utils.state import State + secret_ref = 'secret_name' secrets = [Secret('env', None, secret_ref)] # WHEN @@ -655,29 +620,17 @@ def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock): k.execute(context) self.assertEqual( start_mock.call_args[0][0].spec.containers[0].env_from, - [k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource( - name=secret_ref - ))] + [k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name=secret_ref))], ) def test_env_vars(self): # WHEN env_vars = [ - k8s.V1EnvVar( - name="ENV1", - value="val1" - ), - k8s.V1EnvVar( - name="ENV2", - value="val2" - ), + k8s.V1EnvVar(name="ENV1", value="val1"), + k8s.V1EnvVar(name="ENV2", value="val2"), k8s.V1EnvVar( name="ENV3", - value_from=k8s.V1EnvVarSource( - field_ref=k8s.V1ObjectFieldSelector( - field_path="status.podIP" - ) - ) + value_from=k8s.V1EnvVarSource(field_ref=k8s.V1ObjectFieldSelector(field_path="status.podIP")), ), ] @@ -702,14 +655,7 @@ def test_env_vars(self): self.expected_pod['spec']['containers'][0]['env'] = [ {'name': 'ENV1', 'value': 'val1'}, {'name': 'ENV2', 'value': 'val2'}, - { - 'name': 'ENV3', - 'valueFrom': { - 'fieldRef': { - 'fieldPath': 'status.podIP' - } - } - } + {'name': 'ENV3', 'valueFrom': {'fieldRef': {'fieldPath': 'status.podIP'}}}, ] self.assertEqual(self.expected_pod, actual_pod) @@ -719,7 +665,7 @@ def test_pod_template_file_system(self): task_id="task" + self.get_current_task_name(), in_cluster=False, pod_template_file=fixture, - do_xcom_push=True + do_xcom_push=True, ) context = create_context(k) @@ -735,7 +681,7 @@ def test_pod_template_file_with_overrides_system(self): env_vars=[k8s.V1EnvVar(name="env_name", value="value")], in_cluster=False, pod_template_file=fixture, - do_xcom_push=True + do_xcom_push=True, ) context = create_context(k) @@ -747,20 +693,14 @@ def test_pod_template_file_with_overrides_system(self): def test_init_container(self): # GIVEN - volume_mounts = [k8s.V1VolumeMount( - mount_path='/etc/foo', - name='test-volume', - sub_path=None, - read_only=True - )] - - init_environments = [k8s.V1EnvVar( - name='key1', - value='value1' - ), k8s.V1EnvVar( - name='key2', - value='value2' - )] + volume_mounts = [ + k8s.V1VolumeMount(mount_path='/etc/foo', name='test-volume', sub_path=None, read_only=True) + ] + + init_environments = [ + k8s.V1EnvVar(name='key1', value='value1'), + k8s.V1EnvVar(name='key2', value='value2'), + ] init_container = k8s.V1Container( name="init-container", @@ -768,32 +708,20 @@ def test_init_container(self): env=init_environments, volume_mounts=volume_mounts, command=["bash", "-cx"], - args=["echo 10"] + args=["echo 10"], ) volume = k8s.V1Volume( name='test-volume', - 
persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource( - claim_name='test-volume' - ) + persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(claim_name='test-volume'), ) expected_init_container = { 'name': 'init-container', 'image': 'ubuntu:16.04', 'command': ['bash', '-cx'], 'args': ['echo 10'], - 'env': [{ - 'name': 'key1', - 'value': 'value1' - }, { - 'name': 'key2', - 'value': 'value2' - }], - 'volumeMounts': [{ - 'mountPath': '/etc/foo', - 'name': 'test-volume', - 'readOnly': True - }], + 'env': [{'name': 'key1', 'value': 'value1'}, {'name': 'key2', 'value': 'value2'}], + 'volumeMounts': [{'mountPath': '/etc/foo', 'name': 'test-volume', 'readOnly': True}], } k = KubernetesPodOperator( @@ -813,36 +741,30 @@ def test_init_container(self): k.execute(context) actual_pod = self.api_client.sanitize_for_serialization(k.pod) self.expected_pod['spec']['initContainers'] = [expected_init_container] - self.expected_pod['spec']['volumes'] = [{ - 'name': 'test-volume', - 'persistentVolumeClaim': { - 'claimName': 'test-volume' - } - }] + self.expected_pod['spec']['volumes'] = [ + {'name': 'test-volume', 'persistentVolumeClaim': {'claimName': 'test-volume'}} + ] self.assertEqual(self.expected_pod, actual_pod) @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod") @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod") @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_pod_template_file( - self, - mock_client, - monitor_mock, - start_mock # pylint: disable=unused-argument + self, mock_client, monitor_mock, start_mock # pylint: disable=unused-argument ): from airflow.utils.state import State + path = sys.path[0] + '/tests/kubernetes/pod.yaml' k = KubernetesPodOperator( - task_id="task" + self.get_current_task_name(), - pod_template_file=path, - do_xcom_push=True + task_id="task" + self.get_current_task_name(), pod_template_file=path, do_xcom_push=True ) monitor_mock.return_value = (State.SUCCESS, None) context = create_context(k) with self.assertLogs(k.log, level=logging.DEBUG) as cm: k.execute(context) - expected_line = textwrap.dedent("""\ + expected_line = textwrap.dedent( + """\ DEBUG:airflow.task.operators:Starting pod: api_version: v1 kind: Pod @@ -851,65 +773,57 @@ def test_pod_template_file( cluster_name: null creation_timestamp: null deletion_grace_period_seconds: null\ - """).strip() + """ + ).strip() self.assertTrue(any(line.startswith(expected_line) for line in cm.output)) actual_pod = self.api_client.sanitize_for_serialization(k.pod) - expected_dict = {'apiVersion': 'v1', - 'kind': 'Pod', - 'metadata': {'annotations': {}, - 'labels': {}, - 'name': 'memory-demo', - 'namespace': 'mem-example'}, - 'spec': {'affinity': {}, - 'containers': [{'args': ['--vm', - '1', - '--vm-bytes', - '150M', - '--vm-hang', - '1'], - 'command': ['stress'], - 'env': [], - 'envFrom': [], - 'image': 'apache/airflow:stress-2020.07.10-1.0.4', - 'name': 'base', - 'ports': [], - 'resources': {'limits': {'memory': '200Mi'}, - 'requests': {'memory': '100Mi'}}, - 'volumeMounts': [{'mountPath': '/airflow/xcom', - 'name': 'xcom'}]}, - {'command': ['sh', - '-c', - 'trap "exit 0" INT; while true; do sleep ' - '30; done;'], - 'image': 'alpine', - 'name': 'airflow-xcom-sidecar', - 'resources': {'requests': {'cpu': '1m'}}, - 'volumeMounts': [{'mountPath': '/airflow/xcom', - 'name': 'xcom'}]}], - 'hostNetwork': False, - 'imagePullSecrets': [], - 'initContainers': [], - 'nodeSelector': {}, - 'restartPolicy': 'Never', - 'securityContext': {}, - 
'serviceAccountName': 'default', - 'tolerations': [], - 'volumes': [{'emptyDir': {}, 'name': 'xcom'}]}} + expected_dict = { + 'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': {'annotations': {}, 'labels': {}, 'name': 'memory-demo', 'namespace': 'mem-example'}, + 'spec': { + 'affinity': {}, + 'containers': [ + { + 'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'], + 'command': ['stress'], + 'env': [], + 'envFrom': [], + 'image': 'apache/airflow:stress-2020.07.10-1.0.4', + 'name': 'base', + 'ports': [], + 'resources': {'limits': {'memory': '200Mi'}, 'requests': {'memory': '100Mi'}}, + 'volumeMounts': [{'mountPath': '/airflow/xcom', 'name': 'xcom'}], + }, + { + 'command': ['sh', '-c', 'trap "exit 0" INT; while true; do sleep 30; done;'], + 'image': 'alpine', + 'name': 'airflow-xcom-sidecar', + 'resources': {'requests': {'cpu': '1m'}}, + 'volumeMounts': [{'mountPath': '/airflow/xcom', 'name': 'xcom'}], + }, + ], + 'hostNetwork': False, + 'imagePullSecrets': [], + 'initContainers': [], + 'nodeSelector': {}, + 'restartPolicy': 'Never', + 'securityContext': {}, + 'serviceAccountName': 'default', + 'tolerations': [], + 'volumes': [{'emptyDir': {}, 'name': 'xcom'}], + }, + } self.assertEqual(expected_dict, actual_pod) @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod") @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod") @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_pod_priority_class_name( - self, - mock_client, - monitor_mock, - start_mock # pylint: disable=unused-argument + self, mock_client, monitor_mock, start_mock # pylint: disable=unused-argument ): - """Test ability to assign priorityClassName to pod - - """ + """Test ability to assign priorityClassName to pod""" from airflow.utils.state import State priority_class_name = "medium-test" @@ -949,9 +863,9 @@ def test_pod_name(self): ) @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod") - def test_on_kill(self, - monitor_mock): # pylint: disable=unused-argument + def test_on_kill(self, monitor_mock): # pylint: disable=unused-argument from airflow.utils.state import State + client = kube_client.get_kube_client(in_cluster=False) name = "test" namespace = "default" @@ -979,6 +893,7 @@ def test_on_kill(self, def test_reattach_failing_pod_once(self): from airflow.utils.state import State + client = kube_client.get_kube_client(in_cluster=False) name = "test" namespace = "default" @@ -1009,11 +924,14 @@ def test_reattach_failing_pod_once(self): k.execute(context) pod = client.read_namespaced_pod(name=name, namespace=namespace) self.assertEqual(pod.metadata.labels["already_checked"], "True") - with mock.patch("airflow.providers.cncf.kubernetes" - ".operators.kubernetes_pod.KubernetesPodOperator" - ".create_new_pod_for_operator") as create_mock: + with mock.patch( + "airflow.providers.cncf.kubernetes" + ".operators.kubernetes_pod.KubernetesPodOperator" + ".create_new_pod_for_operator" + ) as create_mock: create_mock.return_value = ("success", {}, {}) k.execute(context) create_mock.assert_called_once() + # pylint: enable=unused-argument diff --git a/metastore_browser/hive_metastore.py b/metastore_browser/hive_metastore.py index 6335891dba9d0..58d4da3ca62a3 100644 --- a/metastore_browser/hive_metastore.py +++ b/metastore_browser/hive_metastore.py @@ -63,16 +63,14 @@ def index(self): """ hook = MySqlHook(METASTORE_MYSQL_CONN_ID) df = hook.get_pandas_df(sql) - df.db = ( - '' + df.db + '') + df.db = '' + df.db + '' table = df.to_html( classes="table table-striped 
table-bordered table-hover", index=False, escape=False, - na_rep='',) - return self.render_template( - "metastore_browser/dbs.html", table=Markup(table)) + na_rep='', + ) + return self.render_template("metastore_browser/dbs.html", table=Markup(table)) @expose('/table/') def table(self): @@ -81,8 +79,8 @@ def table(self): metastore = HiveMetastoreHook(METASTORE_CONN_ID) table = metastore.get_table(table_name) return self.render_template( - "metastore_browser/table.html", - table=table, table_name=table_name, datetime=datetime, int=int) + "metastore_browser/table.html", table=table, table_name=table_name, datetime=datetime, int=int + ) @expose('/db/') def db(self): @@ -90,8 +88,7 @@ def db(self): db = request.args.get("db") metastore = HiveMetastoreHook(METASTORE_CONN_ID) tables = sorted(metastore.get_tables(db=db), key=lambda x: x.tableName) - return self.render_template( - "metastore_browser/db.html", tables=tables, db=db) + return self.render_template("metastore_browser/db.html", tables=tables, db=db) @gzipped @expose('/partitions/') @@ -114,13 +111,16 @@ def partitions(self): b.TBL_NAME like '{table}' AND d.NAME like '{schema}' ORDER BY PART_NAME DESC - """.format(table=table, schema=schema) + """.format( + table=table, schema=schema + ) hook = MySqlHook(METASTORE_MYSQL_CONN_ID) df = hook.get_pandas_df(sql) return df.to_html( classes="table table-striped table-bordered table-hover", index=False, - na_rep='',) + na_rep='', + ) @gzipped @expose('/objects/') @@ -144,11 +144,11 @@ def objects(self): b.NAME NOT LIKE '%temp%' {where_clause} LIMIT {LIMIT}; - """.format(where_clause=where_clause, LIMIT=TABLE_SELECTOR_LIMIT) + """.format( + where_clause=where_clause, LIMIT=TABLE_SELECTOR_LIMIT + ) hook = MySqlHook(METASTORE_MYSQL_CONN_ID) - data = [ - {'id': row[0], 'text': row[0]} - for row in hook.get_records(sql)] + data = [{'id': row[0], 'text': row[0]} for row in hook.get_records(sql)] return json.dumps(data) @gzipped @@ -162,7 +162,8 @@ def data(self): return df.to_html( classes="table table-striped table-bordered table-hover", index=False, - na_rep='',) + na_rep='', + ) @expose('/ddl/') def ddl(self): @@ -175,10 +176,12 @@ def ddl(self): # Creating a flask blueprint to integrate the templates and static folder bp = Blueprint( - "metastore_browser", __name__, + "metastore_browser", + __name__, template_folder='templates', static_folder='static', - static_url_path='/static/metastore_browser') + static_url_path='/static/metastore_browser', +) class MetastoreBrowserPlugin(AirflowPlugin): @@ -186,6 +189,6 @@ class MetastoreBrowserPlugin(AirflowPlugin): name = "metastore_browser" flask_blueprints = [bp] - appbuilder_views = [{"name": "Hive Metadata Browser", - "category": "Plugins", - "view": MetastoreBrowserView()}] + appbuilder_views = [ + {"name": "Hive Metadata Browser", "category": "Plugins", "view": MetastoreBrowserView()} + ] diff --git a/provider_packages/import_all_provider_classes.py b/provider_packages/import_all_provider_classes.py index 87e7f1ea41fb1..ff0ffbc1199eb 100755 --- a/provider_packages/import_all_provider_classes.py +++ b/provider_packages/import_all_provider_classes.py @@ -25,9 +25,9 @@ from typing import List -def import_all_provider_classes(source_paths: str, - provider_ids: List[str] = None, - print_imports: bool = False) -> List[str]: +def import_all_provider_classes( + source_paths: str, provider_ids: List[str] = None, print_imports: bool = False +) -> List[str]: """ Imports all classes in providers packages. 
This method loads and imports all the classes found in providers, so that we can find all the subclasses @@ -61,8 +61,9 @@ def mk_prefix(provider_id): provider_ids = [''] for provider_id in provider_ids: - for modinfo in pkgutil.walk_packages(mk_path(provider_id), prefix=mk_prefix(provider_id), - onerror=onerror): + for modinfo in pkgutil.walk_packages( + mk_path(provider_id), prefix=mk_prefix(provider_id), onerror=onerror + ): if print_imports: print(f"Importing module: {modinfo.name}") try: @@ -78,9 +79,12 @@ def mk_prefix(provider_id): exception_str = traceback.format_exc() tracebacks.append(exception_str) if tracebacks: - print(""" + print( + """ ERROR: There were some import errors -""", file=sys.stderr) +""", + file=sys.stderr, + ) for trace in tracebacks: print("----------------------------------------", file=sys.stderr) print(trace, file=sys.stderr) @@ -93,6 +97,7 @@ def mk_prefix(provider_id): if __name__ == '__main__': try: import airflow.providers + install_source_path = list(iter(airflow.providers.__path__)) except ImportError as e: print("----------------------------------------", file=sys.stderr) diff --git a/provider_packages/prepare_provider_packages.py b/provider_packages/prepare_provider_packages.py index ef12b2aedfa7e..5eed2407ea1ff 100644 --- a/provider_packages/prepare_provider_packages.py +++ b/provider_packages/prepare_provider_packages.py @@ -36,7 +36,6 @@ from packaging.version import Version - PROVIDER_TEMPLATE_PREFIX = "PROVIDER_" BACKPORT_PROVIDER_TEMPLATE_PREFIX = "BACKPORT_PROVIDER_" @@ -183,8 +182,9 @@ def get_pip_package_name(provider_package_id: str, backport_packages: bool) -> s :param backport_packages: whether to prepare regular (False) or backport (True) packages :return: the name of pip package """ - return ("apache-airflow-backport-providers-" if backport_packages else "apache-airflow-providers-") \ - + provider_package_id.replace(".", "-") + return ( + "apache-airflow-backport-providers-" if backport_packages else "apache-airflow-providers-" + ) + provider_package_id.replace(".", "-") def get_long_description(provider_package_id: str, backport_packages: bool) -> str: @@ -196,8 +196,9 @@ def get_long_description(provider_package_id: str, backport_packages: bool) -> s :return: content of the description (BACKPORT_PROVIDER_README/README file) """ package_folder = get_target_providers_package_folder(provider_package_id) - readme_file = os.path.join(package_folder, - "BACKPORT_PROVIDER_README.md" if backport_packages else "README.md") + readme_file = os.path.join( + package_folder, "BACKPORT_PROVIDER_README.md" if backport_packages else "README.md" + ) if not os.path.exists(readme_file): return "" with open(readme_file, encoding='utf-8', mode="r") as file: @@ -216,9 +217,9 @@ def get_long_description(provider_package_id: str, backport_packages: bool) -> s return long_description -def get_package_release_version(provider_package_id: str, - backport_packages: bool, - version_suffix: str = "") -> str: +def get_package_release_version( + provider_package_id: str, backport_packages: bool, version_suffix: str = "" +) -> str: """ Returns release version including optional suffix. @@ -227,14 +228,15 @@ def get_package_release_version(provider_package_id: str, :param version_suffix: optional suffix (rc1, rc2 etc). 
:return: """ - return get_latest_release( - get_package_path(provider_package_id=provider_package_id), - backport_packages=backport_packages).release_version + version_suffix + return ( + get_latest_release( + get_package_path(provider_package_id=provider_package_id), backport_packages=backport_packages + ).release_version + + version_suffix + ) -def get_install_requirements( - provider_package_id: str, - backport_packages: bool) -> List[str]: +def get_install_requirements(provider_package_id: str, backport_packages: bool) -> List[str]: """ Returns install requirements for the package. @@ -246,8 +248,11 @@ def get_install_requirements( dependencies = PROVIDERS_REQUIREMENTS[provider_package_id] if backport_packages: - airflow_dependency = 'apache-airflow~=1.10' if provider_package_id != 'cncf.kubernetes' \ + airflow_dependency = ( + 'apache-airflow~=1.10' + if provider_package_id != 'cncf.kubernetes' else 'apache-airflow>=1.10.12, <2.0.0' + ) else: airflow_dependency = 'apache-airflow>=2.0.0a0' install_requires = [airflow_dependency] @@ -260,10 +265,7 @@ def get_setup_requirements() -> List[str]: Returns setup requirements (common for all package for now). :return: setup requirements """ - return [ - 'setuptools', - 'wheel' - ] + return ['setuptools', 'wheel'] def get_package_extras(provider_package_id: str, backport_packages: bool) -> Dict[str, List[str]]: @@ -278,9 +280,14 @@ def get_package_extras(provider_package_id: str, backport_packages: bool) -> Dic return {} with open(DEPENDENCIES_JSON_FILE) as dependencies_file: cross_provider_dependencies: Dict[str, List[str]] = json.load(dependencies_file) - extras_dict = {module: [get_pip_package_name(module, backport_packages=backport_packages)] - for module in cross_provider_dependencies[provider_package_id]} \ - if cross_provider_dependencies.get(provider_package_id) else {} + extras_dict = ( + { + module: [get_pip_package_name(module, backport_packages=backport_packages)] + for module in cross_provider_dependencies[provider_package_id] + } + if cross_provider_dependencies.get(provider_package_id) + else {} + ) return extras_dict @@ -360,6 +367,7 @@ def inherits_from(the_class: Type, expected_ancestor: Type) -> bool: if expected_ancestor is None: return False import inspect + mro = inspect.getmro(the_class) return the_class is not expected_ancestor and expected_ancestor in mro @@ -371,6 +379,7 @@ def is_class(the_class: Type) -> bool: :return: true if it is a class """ import inspect + return inspect.isclass(the_class) @@ -386,14 +395,15 @@ def package_name_matches(the_class: Type, expected_pattern: Optional[str]) -> bo def find_all_entities( - imported_classes: List[str], - base_package: str, - ancestor_match: Type, - sub_package_pattern_match: str, - expected_class_name_pattern: str, - unexpected_class_name_patterns: Set[str], - exclude_class_type: Type = None, - false_positive_class_names: Optional[Set[str]] = None) -> VerifiedEntities: + imported_classes: List[str], + base_package: str, + ancestor_match: Type, + sub_package_pattern_match: str, + expected_class_name_pattern: str, + unexpected_class_name_patterns: Set[str], + exclude_class_type: Type = None, + false_positive_class_names: Optional[Set[str]] = None, +) -> VerifiedEntities: """ Returns set of entities containing all subclasses in package specified. 
@@ -412,35 +422,44 @@ def find_all_entities( for imported_name in imported_classes: module, class_name = imported_name.rsplit(".", maxsplit=1) the_class = getattr(importlib.import_module(module), class_name) - if is_class(the_class=the_class) \ - and not is_example_dag(imported_name=imported_name) \ - and is_from_the_expected_base_package(the_class=the_class, expected_package=base_package) \ - and is_imported_from_same_module(the_class=the_class, imported_name=imported_name) \ - and inherits_from(the_class=the_class, expected_ancestor=ancestor_match) \ - and not inherits_from(the_class=the_class, expected_ancestor=exclude_class_type) \ - and package_name_matches(the_class=the_class, expected_pattern=sub_package_pattern_match): + if ( + is_class(the_class=the_class) + and not is_example_dag(imported_name=imported_name) + and is_from_the_expected_base_package(the_class=the_class, expected_package=base_package) + and is_imported_from_same_module(the_class=the_class, imported_name=imported_name) + and inherits_from(the_class=the_class, expected_ancestor=ancestor_match) + and not inherits_from(the_class=the_class, expected_ancestor=exclude_class_type) + and package_name_matches(the_class=the_class, expected_pattern=sub_package_pattern_match) + ): if not false_positive_class_names or class_name not in false_positive_class_names: if not re.match(expected_class_name_pattern, class_name): wrong_entities.append( - (the_class, f"The class name {class_name} is wrong. " - f"It should match {expected_class_name_pattern}")) + ( + the_class, + f"The class name {class_name} is wrong. " + f"It should match {expected_class_name_pattern}", + ) + ) continue if unexpected_class_name_patterns: for unexpected_class_name_pattern in unexpected_class_name_patterns: if re.match(unexpected_class_name_pattern, class_name): wrong_entities.append( - (the_class, - f"The class name {class_name} is wrong. " - f"It should not match {unexpected_class_name_pattern}")) + ( + the_class, + f"The class name {class_name} is wrong. " + f"It should not match {unexpected_class_name_pattern}", + ) + ) continue found_entities.add(imported_name) return VerifiedEntities(all_entities=found_entities, wrong_entities=wrong_entities) -def convert_new_classes_to_table(entity_type: EntityType, - new_entities: List[str], - full_package_name: str) -> str: +def convert_new_classes_to_table( + entity_type: EntityType, new_entities: List[str], full_package_name: str +) -> str: """ Converts new entities tp a markdown table. @@ -450,15 +469,15 @@ def convert_new_classes_to_table(entity_type: EntityType, :return: table of new classes """ from tabulate import tabulate + headers = [f"New Airflow 2.0 {entity_type.value.lower()}: `{full_package_name}` package"] - table = [(get_class_code_link(full_package_name, class_name, "master"),) - for class_name in new_entities] + table = [(get_class_code_link(full_package_name, class_name, "master"),) for class_name in new_entities] return tabulate(table, headers=headers, tablefmt="pipe") -def convert_moved_classes_to_table(entity_type: EntityType, - moved_entities: Dict[str, str], - full_package_name: str) -> str: +def convert_moved_classes_to_table( + entity_type: EntityType, moved_entities: Dict[str, str], full_package_name: str +) -> str: """ Converts moved entities to a markdown table :param entity_type: type of entities -> operators, sensors etc. 
@@ -467,21 +486,27 @@ def convert_moved_classes_to_table(entity_type: EntityType, :return: table of moved classes """ from tabulate import tabulate - headers = [f"Airflow 2.0 {entity_type.value.lower()}: `{full_package_name}` package", - "Airflow 1.10.* previous location (usually `airflow.contrib`)"] + + headers = [ + f"Airflow 2.0 {entity_type.value.lower()}: `{full_package_name}` package", + "Airflow 1.10.* previous location (usually `airflow.contrib`)", + ] table = [ - (get_class_code_link(full_package_name, to_class, "master"), - get_class_code_link("airflow", moved_entities[to_class], "v1-10-stable")) + ( + get_class_code_link(full_package_name, to_class, "master"), + get_class_code_link("airflow", moved_entities[to_class], "v1-10-stable"), + ) for to_class in sorted(moved_entities.keys()) ] return tabulate(table, headers=headers, tablefmt="pipe") def get_details_about_classes( - entity_type: EntityType, - entities: Set[str], - wrong_entities: List[Tuple[type, str]], - full_package_name: str) -> EntityTypeSummary: + entity_type: EntityType, + entities: Set[str], + wrong_entities: List[Tuple[type, str]], + full_package_name: str, +) -> EntityTypeSummary: """ Splits the set of entities into new and moved, depending on their presence in the dict of objects retrieved from the test_contrib_to_core. Updates all_entities with the split class. @@ -518,7 +543,7 @@ def get_details_about_classes( moved_entities=moved_entities, full_package_name=full_package_name, ), - wrong_entities=wrong_entities + wrong_entities=wrong_entities, ) @@ -527,7 +552,7 @@ def strip_package_from_class(base_package: str, class_name: str) -> str: Strips base package name from the class (if it starts with the package name). """ if class_name.startswith(base_package): - return class_name[len(base_package) + 1:] + return class_name[len(base_package) + 1 :] else: return class_name @@ -553,8 +578,10 @@ def get_class_code_link(base_package: str, class_name: str, git_tag: str) -> str :return: URL to the class """ url_prefix = f'https://github.com/apache/airflow/blob/{git_tag}/' - return f'[{strip_package_from_class(base_package, class_name)}]' \ - f'({convert_class_name_to_url(url_prefix, class_name)})' + return ( + f'[{strip_package_from_class(base_package, class_name)}]' + f'({convert_class_name_to_url(url_prefix, class_name)})' + ) def print_wrong_naming(entity_type: EntityType, wrong_classes: List[Tuple[type, str]]): @@ -569,8 +596,9 @@ def print_wrong_naming(entity_type: EntityType, wrong_classes: List[Tuple[type, print(f"{entity_type}: {message}", file=sys.stderr) -def get_package_class_summary(full_package_name: str, imported_classes: List[str]) \ - -> Dict[EntityType, EntityTypeSummary]: +def get_package_class_summary( + full_package_name: str, imported_classes: List[str] +) -> Dict[EntityType, EntityTypeSummary]: """ Gets summary of the package in the form of dictionary containing all types of entities :param full_package_name: full package name @@ -582,72 +610,80 @@ def get_package_class_summary(full_package_name: str, imported_classes: List[str from airflow.secrets import BaseSecretsBackend from airflow.sensors.base_sensor_operator import BaseSensorOperator - all_verified_entities: Dict[EntityType, VerifiedEntities] = {EntityType.Operators: find_all_entities( - imported_classes=imported_classes, - base_package=full_package_name, - sub_package_pattern_match=r".*\.operators\..*", - ancestor_match=BaseOperator, - expected_class_name_pattern=OPERATORS_PATTERN, - unexpected_class_name_patterns=ALL_PATTERNS - 
{OPERATORS_PATTERN}, - exclude_class_type=BaseSensorOperator, - false_positive_class_names={ - 'CloudVisionAddProductToProductSetOperator', - 'CloudDataTransferServiceGCSToGCSOperator', - 'CloudDataTransferServiceS3ToGCSOperator', - 'BigQueryCreateDataTransferOperator', - 'CloudTextToSpeechSynthesizeOperator', - 'CloudSpeechToTextRecognizeSpeechOperator', - } - ), EntityType.Sensors: find_all_entities( - imported_classes=imported_classes, - base_package=full_package_name, - sub_package_pattern_match=r".*\.sensors\..*", - ancestor_match=BaseSensorOperator, - expected_class_name_pattern=SENSORS_PATTERN, - unexpected_class_name_patterns=ALL_PATTERNS - {OPERATORS_PATTERN, SENSORS_PATTERN} - ), EntityType.Hooks: find_all_entities( - imported_classes=imported_classes, - base_package=full_package_name, - sub_package_pattern_match=r".*\.hooks\..*", - ancestor_match=BaseHook, - expected_class_name_pattern=HOOKS_PATTERN, - unexpected_class_name_patterns=ALL_PATTERNS - {HOOKS_PATTERN} - ), EntityType.Secrets: find_all_entities( - imported_classes=imported_classes, - sub_package_pattern_match=r".*\.secrets\..*", - base_package=full_package_name, - ancestor_match=BaseSecretsBackend, - expected_class_name_pattern=SECRETS_PATTERN, - unexpected_class_name_patterns=ALL_PATTERNS - {SECRETS_PATTERN}, - ), EntityType.Transfers: find_all_entities( - imported_classes=imported_classes, - base_package=full_package_name, - sub_package_pattern_match=r".*\.transfers\..*", - ancestor_match=BaseOperator, - expected_class_name_pattern=TRANSFERS_PATTERN, - unexpected_class_name_patterns=ALL_PATTERNS - {OPERATORS_PATTERN, TRANSFERS_PATTERN}, - )} + all_verified_entities: Dict[EntityType, VerifiedEntities] = { + EntityType.Operators: find_all_entities( + imported_classes=imported_classes, + base_package=full_package_name, + sub_package_pattern_match=r".*\.operators\..*", + ancestor_match=BaseOperator, + expected_class_name_pattern=OPERATORS_PATTERN, + unexpected_class_name_patterns=ALL_PATTERNS - {OPERATORS_PATTERN}, + exclude_class_type=BaseSensorOperator, + false_positive_class_names={ + 'CloudVisionAddProductToProductSetOperator', + 'CloudDataTransferServiceGCSToGCSOperator', + 'CloudDataTransferServiceS3ToGCSOperator', + 'BigQueryCreateDataTransferOperator', + 'CloudTextToSpeechSynthesizeOperator', + 'CloudSpeechToTextRecognizeSpeechOperator', + }, + ), + EntityType.Sensors: find_all_entities( + imported_classes=imported_classes, + base_package=full_package_name, + sub_package_pattern_match=r".*\.sensors\..*", + ancestor_match=BaseSensorOperator, + expected_class_name_pattern=SENSORS_PATTERN, + unexpected_class_name_patterns=ALL_PATTERNS - {OPERATORS_PATTERN, SENSORS_PATTERN}, + ), + EntityType.Hooks: find_all_entities( + imported_classes=imported_classes, + base_package=full_package_name, + sub_package_pattern_match=r".*\.hooks\..*", + ancestor_match=BaseHook, + expected_class_name_pattern=HOOKS_PATTERN, + unexpected_class_name_patterns=ALL_PATTERNS - {HOOKS_PATTERN}, + ), + EntityType.Secrets: find_all_entities( + imported_classes=imported_classes, + sub_package_pattern_match=r".*\.secrets\..*", + base_package=full_package_name, + ancestor_match=BaseSecretsBackend, + expected_class_name_pattern=SECRETS_PATTERN, + unexpected_class_name_patterns=ALL_PATTERNS - {SECRETS_PATTERN}, + ), + EntityType.Transfers: find_all_entities( + imported_classes=imported_classes, + base_package=full_package_name, + sub_package_pattern_match=r".*\.transfers\..*", + ancestor_match=BaseOperator, + 
expected_class_name_pattern=TRANSFERS_PATTERN, + unexpected_class_name_patterns=ALL_PATTERNS - {OPERATORS_PATTERN, TRANSFERS_PATTERN}, + ), + } for entity in EntityType: print_wrong_naming(entity, all_verified_entities[entity].wrong_entities) - entities_summary: Dict[EntityType, EntityTypeSummary] = {} # noqa + entities_summary: Dict[EntityType, EntityTypeSummary] = {} # noqa for entity_type in EntityType: entities_summary[entity_type] = get_details_about_classes( entity_type, all_verified_entities[entity_type].all_entities, all_verified_entities[entity_type].wrong_entities, - full_package_name) + full_package_name, + ) return entities_summary def render_template( - template_name: str, - context: Dict[str, Any], - extension: str, - autoescape: bool = True, - keep_trailing_newline: bool = False) -> str: + template_name: str, + context: Dict[str, Any], + extension: str, + autoescape: bool = True, + keep_trailing_newline: bool = False, +) -> str: """ Renders template based on it's name. Reads the template from _TEMPLATE.md.jinja2 in current dir. :param template_name: name of the template to use @@ -658,12 +694,13 @@ def render_template( :return: rendered template """ import jinja2 + template_loader = jinja2.FileSystemLoader(searchpath=MY_DIR_PATH) template_env = jinja2.Environment( loader=template_loader, undefined=jinja2.StrictUndefined, autoescape=autoescape, - keep_trailing_newline=keep_trailing_newline + keep_trailing_newline=keep_trailing_newline, ) template = template_env.get_template(f"{template_name}_TEMPLATE{extension}.jinja2") content: str = template.render(context) @@ -684,6 +721,7 @@ def convert_git_changes_to_table(changes: str, base_url: str) -> str: :return: markdown-formatted table """ from tabulate import tabulate + lines = changes.split("\n") headers = ["Commit", "Committed", "Subject"] table_data = [] @@ -702,6 +740,7 @@ def convert_pip_requirements_to_table(requirements: Iterable[str]) -> str: :return: markdown-formatted table """ from tabulate import tabulate + headers = ["PIP package", "Version required"] table_data = [] for dependency in requirements: @@ -715,8 +754,7 @@ def convert_pip_requirements_to_table(requirements: Iterable[str]) -> str: return tabulate(table_data, headers=headers, tablefmt="pipe") -def convert_cross_package_dependencies_to_table( - cross_package_dependencies: List[str], base_url: str) -> str: +def convert_cross_package_dependencies_to_table(cross_package_dependencies: List[str], base_url: str) -> str: """ Converts cross-package dependencies to a markdown table :param cross_package_dependencies: list of cross-package dependencies @@ -724,6 +762,7 @@ def convert_cross_package_dependencies_to_table( :return: markdown-formatted table """ from tabulate import tabulate + headers = ["Dependent package", "Extra"] table_data = [] for dependency in cross_package_dependencies: @@ -757,8 +796,8 @@ def convert_cross_package_dependencies_to_table( Keeps information about historical releases. """ ReleaseInfo = collections.namedtuple( - "ReleaseInfo", - "release_version release_version_no_leading_zeros last_commit_hash content file_name") + "ReleaseInfo", "release_version release_version_no_leading_zeros last_commit_hash content file_name" +) def strip_leading_zeros_in_calver(calver_version: str) -> str: @@ -805,15 +844,19 @@ def get_all_releases(provider_package_path: str, backport_packages: bool) -> Lis print("No commit found. 
This seems to be first time you run it", file=sys.stderr) else: last_commit_hash = found.group(1) - release_version = file_name[len(changes_file_prefix):][:-3] - release_version_no_leading_zeros = strip_leading_zeros_in_calver(release_version) \ - if backport_packages else release_version + release_version = file_name[len(changes_file_prefix) :][:-3] + release_version_no_leading_zeros = ( + strip_leading_zeros_in_calver(release_version) if backport_packages else release_version + ) past_releases.append( - ReleaseInfo(release_version=release_version, - release_version_no_leading_zeros=release_version_no_leading_zeros, - last_commit_hash=last_commit_hash, - content=content, - file_name=file_name)) + ReleaseInfo( + release_version=release_version, + release_version_no_leading_zeros=release_version_no_leading_zeros, + last_commit_hash=last_commit_hash, + content=content, + file_name=file_name, + ) + ) return past_releases @@ -825,21 +868,24 @@ def get_latest_release(provider_package_path: str, backport_packages: bool) -> R :param backport_packages: whether to prepare regular (False) or backport (True) packages :return: latest release information """ - releases = get_all_releases(provider_package_path=provider_package_path, - backport_packages=backport_packages) + releases = get_all_releases( + provider_package_path=provider_package_path, backport_packages=backport_packages + ) if len(releases) == 0: - return ReleaseInfo(release_version="0.0.0", - release_version_no_leading_zeros="0.0.0", - last_commit_hash="no_hash", - content="empty", - file_name="no_file") + return ReleaseInfo( + release_version="0.0.0", + release_version_no_leading_zeros="0.0.0", + last_commit_hash="no_hash", + content="empty", + file_name="no_file", + ) else: return releases[0] -def get_previous_release_info(previous_release_version: str, - past_releases: List[ReleaseInfo], - current_release_version: str) -> Optional[str]: +def get_previous_release_info( + previous_release_version: str, past_releases: List[ReleaseInfo], current_release_version: str +) -> Optional[str]: """ Find previous release. In case we are re-running current release we assume that last release was the previous one. This is needed so that we can generate list of changes since the previous release. 
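For readers following the release-notes logic above: the body of get_previous_release_info is not part of this hunk, so the snippet below is only a hedged sketch of the behaviour its docstring describes (re-running the current release treats the release before it as the "previous" one). The helper name and sample versions are illustrative, not taken from the module.

def previous_release_sketch(past_versions, current_version):
    # Sketch only: if the current version is being re-released, skip it and fall
    # back to the release before it; otherwise the newest past release is the
    # previous one. Returns None when there is nothing to compare against.
    if past_versions and past_versions[0] == current_version:
        past_versions = past_versions[1:]
    return past_versions[0] if past_versions else None


assert previous_release_sketch(["2020.10.29", "2020.6.24"], "2020.10.29") == "2020.6.24"
assert previous_release_sketch(["2020.6.24"], "2020.10.29") == "2020.6.24"
assert previous_release_sketch([], "2020.10.29") is None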
@@ -859,9 +905,8 @@ def get_previous_release_info(previous_release_version: str, def check_if_release_version_ok( - past_releases: List[ReleaseInfo], - current_release_version: str, - backport_packages: bool) -> Tuple[str, Optional[str]]: + past_releases: List[ReleaseInfo], current_release_version: str, backport_packages: bool +) -> Tuple[str, Optional[str]]: """ Check if the release version passed is not later than the last release version :param past_releases: all past releases (if there are any) @@ -880,8 +925,11 @@ def check_if_release_version_ok( current_release_version = "0.0.1" # TODO: replace with maintained version if previous_release_version: if Version(current_release_version) < Version(previous_release_version): - print(f"The release {current_release_version} must be not less than " - f"{previous_release_version} - last release for the package", file=sys.stderr) + print( + f"The release {current_release_version} must be not less than " + f"{previous_release_version} - last release for the package", + file=sys.stderr, + ) sys.exit(2) return current_release_version, previous_release_version @@ -907,17 +955,23 @@ def make_sure_remote_apache_exists_and_fetch(): :return: """ try: - subprocess.check_call(["git", "remote", "add", "apache-https-for-providers", - "https://github.com/apache/airflow.git"], - stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + subprocess.check_call( + ["git", "remote", "add", "apache-https-for-providers", "https://github.com/apache/airflow.git"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) except subprocess.CalledProcessError as e: if e.returncode == 128: - print("The remote `apache-https-for-providers` already exists. If you have trouble running " - "git log delete the remote", file=sys.stderr) + print( + "The remote `apache-https-for-providers` already exists. If you have trouble running " + "git log delete the remote", + file=sys.stderr, + ) else: raise - subprocess.check_call(["git", "fetch", "apache-https-for-providers"], - stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + subprocess.check_call( + ["git", "fetch", "apache-https-for-providers"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + ) def get_git_command(base_commit: Optional[str]) -> List[str]: @@ -926,18 +980,22 @@ def get_git_command(base_commit: Optional[str]) -> List[str]: :param base_commit: if present - base commit from which to start the log from :return: git command to run """ - git_cmd = ["git", "log", "apache-https-for-providers/master", - "--pretty=format:%H %h %cd %s", "--date=short"] + git_cmd = [ + "git", + "log", + "apache-https-for-providers/master", + "--pretty=format:%H %h %cd %s", + "--date=short", + ] if base_commit: git_cmd.append(f"{base_commit}...HEAD") git_cmd.extend(['--', '.']) return git_cmd -def store_current_changes(provider_package_path: str, - current_release_version: str, - current_changes: str, - backport_packages: bool) -> None: +def store_current_changes( + provider_package_path: str, current_release_version: str, current_changes: str, backport_packages: bool +) -> None: """ Stores current changes in the *_changes_YYYY.MM.DD.md file. 
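The guard in check_if_release_version_ok above relies on Version comparing calendar-version strings numerically rather than lexically. The import is not visible in this hunk, so the example assumes packaging.version.Version is the class in use; adjust if the module imports a different Version.

from packaging.version import Version  # assumed import; only the name `Version` appears in the hunk

# Calver components compare as numbers, so October (10) sorts after June (6)
# even though the string "1" sorts before "6".
assert Version("2020.10.29") > Version("2020.6.24")

# This is the situation the guard rejects: a "new" release older than the last one.
assert Version("2020.5.1") < Version("2020.10.29")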
@@ -948,7 +1006,8 @@ def store_current_changes(provider_package_path: str, """ current_changes_file_path = os.path.join( provider_package_path, - get_provider_changes_prefix(backport_packages=backport_packages) + current_release_version + ".md") + get_provider_changes_prefix(backport_packages=backport_packages) + current_release_version + ".md", + ) with open(current_changes_file_path, "wt") as current_changes_file: current_changes_file.write(current_changes) current_changes_file.write("\n") @@ -999,7 +1058,8 @@ def is_camel_case_with_acronyms(s: str): def check_if_classes_are_properly_named( - entity_summary: Dict[EntityType, EntityTypeSummary]) -> Tuple[int, int]: + entity_summary: Dict[EntityType, EntityTypeSummary] +) -> Tuple[int, int]: """ Check if all entities in the dictionary are named properly. It prints names at the output and returns the status of class names. @@ -1014,12 +1074,16 @@ def check_if_classes_are_properly_named( _, class_name = class_full_name.rsplit(".", maxsplit=1) error_encountered = False if not is_camel_case_with_acronyms(class_name): - print(f"The class {class_full_name} is wrongly named. The " - f"class name should be CamelCaseWithACRONYMS !") + print( + f"The class {class_full_name} is wrongly named. The " + f"class name should be CamelCaseWithACRONYMS !" + ) error_encountered = True if not class_name.endswith(class_suffix): - print(f"The class {class_full_name} is wrongly named. It is one of the {entity_type.value}" - f" so it should end with {class_suffix}") + print( + f"The class {class_full_name} is wrongly named. It is one of the {entity_type.value}" + f" so it should end with {class_suffix}" + ) error_encountered = True total_class_number += 1 if error_encountered: @@ -1034,11 +1098,13 @@ def get_package_pip_name(provider_package_id: str, backport_packages: bool): return f"apache-airflow-providers-{provider_package_id.replace('.', '-')}" -def update_release_notes_for_package(provider_package_id: str, - current_release_version: str, - version_suffix: str, - imported_classes: List[str], - backport_packages: bool) -> Tuple[int, int]: +def update_release_notes_for_package( + provider_package_id: str, + current_release_version: str, + version_suffix: str, + imported_classes: List[str], + backport_packages: bool, +) -> Tuple[int, int]: """ Updates release notes (BACKPORT_PROVIDER_README.md/README.md) for the package. Returns Tuple of total number of entities and badly named entities. 
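A small, self-contained illustration of the suffix rule enforced by check_if_classes_are_properly_named above. is_camel_case_with_acronyms is defined outside this hunk, so only the suffix part is shown, and the helper name below is made up for the example.

def ends_with_expected_suffix(class_full_name: str, class_suffix: str) -> bool:
    # Mirrors the check above: only the trailing class name matters, and it has
    # to end with the suffix of its entity type (Operator, Hook, Sensor, ...).
    _, class_name = class_full_name.rsplit(".", maxsplit=1)
    return class_name.endswith(class_suffix)


assert ends_with_expected_suffix("airflow.providers.http.operators.http.SimpleHttpOperator", "Operator")
assert not ends_with_expected_suffix("airflow.providers.http.hooks.http.HttpHook", "Operator")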
@@ -1055,27 +1121,35 @@ def update_release_notes_for_package(provider_package_id: str, full_package_name = f"airflow.providers.{provider_package_id}" provider_package_path = get_package_path(provider_package_id) entity_summaries = get_package_class_summary(full_package_name, imported_classes) - past_releases = get_all_releases(provider_package_path=provider_package_path, - backport_packages=backport_packages) + past_releases = get_all_releases( + provider_package_path=provider_package_path, backport_packages=backport_packages + ) current_release_version, previous_release = check_if_release_version_ok( - past_releases, current_release_version, backport_packages) - cross_providers_dependencies = \ - get_cross_provider_dependent_packages(provider_package_id=provider_package_id) - previous_release = get_previous_release_info(previous_release_version=previous_release, - past_releases=past_releases, - current_release_version=current_release_version) + past_releases, current_release_version, backport_packages + ) + cross_providers_dependencies = get_cross_provider_dependent_packages( + provider_package_id=provider_package_id + ) + previous_release = get_previous_release_info( + previous_release_version=previous_release, + past_releases=past_releases, + current_release_version=current_release_version, + ) git_cmd = get_git_command(previous_release) changes = subprocess.check_output(git_cmd, cwd=provider_package_path, universal_newlines=True) changes_table = convert_git_changes_to_table( - changes, - base_url="https://github.com/apache/airflow/commit/") + changes, base_url="https://github.com/apache/airflow/commit/" + ) pip_requirements_table = convert_pip_requirements_to_table(PROVIDERS_REQUIREMENTS[provider_package_id]) - cross_providers_dependencies_table = \ - convert_cross_package_dependencies_to_table( - cross_providers_dependencies, - base_url="https://github.com/apache/airflow/tree/master/airflow/providers/") - release_version_no_leading_zeros = strip_leading_zeros_in_calver(current_release_version) \ - if backport_packages else current_release_version + cross_providers_dependencies_table = convert_cross_package_dependencies_to_table( + cross_providers_dependencies, + base_url="https://github.com/apache/airflow/tree/master/airflow/providers/", + ) + release_version_no_leading_zeros = ( + strip_leading_zeros_in_calver(current_release_version) + if backport_packages + else current_release_version + ) context: Dict[str, Any] = { "ENTITY_TYPES": list(EntityType), "README_FILE": "BACKPORT_PROVIDER_README.md" if backport_packages else "README.md", @@ -1095,17 +1169,21 @@ def update_release_notes_for_package(provider_package_id: str, "PROVIDER_TYPE": "Backport provider" if BACKPORT_PACKAGES else "Provider", "PROVIDERS_FOLDER": "backport-providers" if BACKPORT_PACKAGES else "providers", "INSTALL_REQUIREMENTS": get_install_requirements( - provider_package_id=provider_package_id, - backport_packages=backport_packages + provider_package_id=provider_package_id, backport_packages=backport_packages ), "SETUP_REQUIREMENTS": get_setup_requirements(), "EXTRAS_REQUIREMENTS": get_package_extras( - provider_package_id=provider_package_id, - backport_packages=backport_packages - ) + provider_package_id=provider_package_id, backport_packages=backport_packages + ), } - prepare_readme_and_changes_files(backport_packages, context, current_release_version, entity_summaries, - provider_package_id, provider_package_path) + prepare_readme_and_changes_files( + backport_packages, + context, + current_release_version, + 
entity_summaries, + provider_package_id, + provider_package_path, + ) prepare_setup_py_file(context, provider_package_path) prepare_setup_cfg_file(context, provider_package_path) total, bad = check_if_classes_are_properly_named(entity_summaries) @@ -1125,17 +1203,27 @@ def get_template_name(backport_packages: bool, template_suffix: str) -> str: :param template_suffix: suffix to add :return template name """ - return (BACKPORT_PROVIDER_TEMPLATE_PREFIX if backport_packages else PROVIDER_TEMPLATE_PREFIX) \ - + template_suffix + return ( + BACKPORT_PROVIDER_TEMPLATE_PREFIX if backport_packages else PROVIDER_TEMPLATE_PREFIX + ) + template_suffix -def prepare_readme_and_changes_files(backport_packages, context, current_release_version, entity_summaries, - provider_package_id, provider_package_path): +def prepare_readme_and_changes_files( + backport_packages, + context, + current_release_version, + entity_summaries, + provider_package_id, + provider_package_path, +): changes_template_name = get_template_name(backport_packages, "CHANGES") current_changes = render_template(template_name=changes_template_name, context=context, extension='.md') - store_current_changes(provider_package_path=provider_package_path, - current_release_version=current_release_version, - current_changes=current_changes, backport_packages=backport_packages) + store_current_changes( + provider_package_path=provider_package_path, + current_release_version=current_release_version, + current_changes=current_changes, + backport_packages=backport_packages, + ) context['ENTITIES'] = entity_summaries context['ENTITY_NAMES'] = ENTITY_NAMES all_releases = get_all_releases(provider_package_path, backport_packages=backport_packages) @@ -1147,8 +1235,9 @@ def prepare_readme_and_changes_files(backport_packages, context, current_release readme += render_template(template_name=classes_template_name, context=context, extension='.md') for a_release in all_releases: readme += a_release.content - readme_file_path = os.path.join(provider_package_path, - "BACKPORT_PROVIDER_README.md" if backport_packages else "README.md") + readme_file_path = os.path.join( + provider_package_path, "BACKPORT_PROVIDER_README.md" if backport_packages else "README.md" + ) old_text = "" if os.path.isfile(readme_file_path): with open(readme_file_path) as readme_file_read: @@ -1174,10 +1263,7 @@ def prepare_setup_py_file(context, provider_package_path): setup_target_prefix = BACKPORT_PROVIDER_PREFIX if BACKPORT_PACKAGES else "" setup_file_path = os.path.abspath(os.path.join(provider_package_path, setup_target_prefix + "setup.py")) setup_content = render_template( - template_name=setup_template_name, - context=context, - extension='.py', - autoescape=False + template_name=setup_template_name, context=context, extension='.py', autoescape=False ) with open(setup_file_path, "wt") as setup_file: setup_file.write(setup_content) @@ -1194,16 +1280,15 @@ def prepare_setup_cfg_file(context, provider_package_path): context=context, extension='.cfg', autoescape=False, - keep_trailing_newline=True + keep_trailing_newline=True, ) with open(setup_file_path, "wt") as setup_file: setup_file.write(setup_content) -def update_release_notes_for_packages(provider_ids: List[str], - release_version: str, - version_suffix: str, - backport_packages: bool): +def update_release_notes_for_packages( + provider_ids: List[str], release_version: str, version_suffix: str, backport_packages: bool +): """ Updates release notes for the list of packages specified. 
:param provider_ids: list of provider ids @@ -1213,7 +1298,8 @@ def update_release_notes_for_packages(provider_ids: List[str], :return: """ imported_classes = import_all_provider_classes( - source_paths=[PROVIDERS_PATH], provider_ids=provider_ids, print_imports=False) + source_paths=[PROVIDERS_PATH], provider_ids=provider_ids, print_imports=False + ) make_sure_remote_apache_exists_and_fetch() if len(provider_ids) == 0: if backport_packages: @@ -1231,11 +1317,8 @@ def update_release_notes_for_packages(provider_ids: List[str], print() for package in provider_ids: inc_total, inc_bad = update_release_notes_for_package( - package, - release_version, - version_suffix, - imported_classes, - backport_packages) + package, release_version, version_suffix, imported_classes, backport_packages + ) total += inc_total bad += inc_bad if bad == 0: @@ -1287,12 +1370,12 @@ def verify_provider_package(package: str) -> None: :return: None """ if package not in get_provider_packages(): - raise Exception(f"The package {package} is not a provider package. " - f"Use one of {get_provider_packages()}") + raise Exception( + f"The package {package} is not a provider package. " f"Use one of {get_provider_packages()}" + ) -def copy_setup_py(provider_package_id: str, - backport_packages: bool) -> None: +def copy_setup_py(provider_package_id: str, backport_packages: bool) -> None: """ Copies setup.py to provider_package directory. :param provider_package_id: package from which to copy the setup.py @@ -1301,12 +1384,13 @@ def copy_setup_py(provider_package_id: str, """ setup_source_prefix = BACKPORT_PROVIDER_PREFIX if backport_packages else "" provider_package_path = get_package_path(provider_package_id) - copyfile(os.path.join(provider_package_path, setup_source_prefix + "setup.py"), - os.path.join(MY_DIR_PATH, "setup.py")) + copyfile( + os.path.join(provider_package_path, setup_source_prefix + "setup.py"), + os.path.join(MY_DIR_PATH, "setup.py"), + ) -def copy_setup_cfg(provider_package_id: str, - backport_packages: bool) -> None: +def copy_setup_cfg(provider_package_id: str, backport_packages: bool) -> None: """ Copies setup.py to provider_package directory. :param provider_package_id: package from which to copy the setup.cfg @@ -1315,12 +1399,13 @@ def copy_setup_cfg(provider_package_id: str, """ setup_source_prefix = BACKPORT_PROVIDER_PREFIX if backport_packages else "" provider_package_path = get_package_path(provider_package_id) - copyfile(os.path.join(provider_package_path, setup_source_prefix + "setup.cfg"), - os.path.join(MY_DIR_PATH, "setup.cfg")) + copyfile( + os.path.join(provider_package_path, setup_source_prefix + "setup.cfg"), + os.path.join(MY_DIR_PATH, "setup.cfg"), + ) -def copy_readme_and_changelog(provider_package_id: str, - backport_packages: bool) -> None: +def copy_readme_and_changelog(provider_package_id: str, backport_packages: bool) -> None: """ Copies the right README.md/CHANGELOG.txt to provider_package directory. 
:param provider_package_id: package from which to copy the setup.py @@ -1347,7 +1432,7 @@ def copy_readme_and_changelog(provider_package_id: str, LIST_BACKPORTABLE_PACKAGES = "list-backportable-packages" UPDATE_PACKAGE_RELEASE_NOTES = "update-package-release-notes" - BACKPORT_PACKAGES = (os.getenv('BACKPORT_PACKAGES') == "true") + BACKPORT_PACKAGES = os.getenv('BACKPORT_PACKAGES') == "true" suffix = "" provider_names = get_provider_packages() @@ -1356,16 +1441,22 @@ def copy_readme_and_changelog(provider_package_id: str, possible_first_params.append(LIST_BACKPORTABLE_PACKAGES) possible_first_params.append(UPDATE_PACKAGE_RELEASE_NOTES) if len(sys.argv) == 1: - print(""" + print( + """ ERROR! Missing first param" -""", file=sys.stderr) +""", + file=sys.stderr, + ) usage() sys.exit(1) if sys.argv[1] == "--version-suffix": if len(sys.argv) < 3: - print(""" + print( + """ ERROR! --version-suffix needs parameter! -""", file=sys.stderr) +""", + file=sys.stderr, + ) usage() sys.exit(1) suffix = sys.argv[2] @@ -1375,9 +1466,12 @@ def copy_readme_and_changelog(provider_package_id: str, sys.exit(0) if sys.argv[1] not in possible_first_params: - print(f""" + print( + f""" ERROR! Wrong first param: {sys.argv[1]} -""", file=sys.stderr) +""", + file=sys.stderr, + ) usage() print() sys.exit(1) @@ -1410,7 +1504,8 @@ def copy_readme_and_changelog(provider_package_id: str, package_list, release_version=release_ver, version_suffix=suffix, - backport_packages=BACKPORT_PACKAGES) + backport_packages=BACKPORT_PACKAGES, + ) sys.exit(0) _provider_package = sys.argv[1] diff --git a/provider_packages/refactor_provider_packages.py b/provider_packages/refactor_provider_packages.py index 9ea80e9532967..ccacef21cfdc8 100755 --- a/provider_packages/refactor_provider_packages.py +++ b/provider_packages/refactor_provider_packages.py @@ -27,7 +27,9 @@ from fissix.pytree import Leaf from provider_packages.prepare_provider_packages import ( - get_source_airflow_folder, get_source_providers_folder, get_target_providers_folder, + get_source_airflow_folder, + get_source_providers_folder, + get_target_providers_folder, get_target_providers_package_folder, ) @@ -36,6 +38,7 @@ def copy_provider_sources() -> None: """ Copies provider sources to directory where they will be refactored. """ + def rm_build_dir() -> None: """ Removes build directory. @@ -111,6 +114,7 @@ def remove_class(self, class_name) -> None: :param class_name: name to remove """ + def _remover(node: LN, capture: Capture, filename: Filename) -> None: node.remove() @@ -181,27 +185,23 @@ def add_provide_context_to_python_operators(self) -> None: ) """ + def add_provide_context_to_python_operator(node: LN, capture: Capture, filename: Filename) -> None: fn_args = capture['function_arguments'][0] - if len(fn_args.children) > 0 and (not isinstance(fn_args.children[-1], Leaf) - or fn_args.children[-1].type != token.COMMA): + if len(fn_args.children) > 0 and ( + not isinstance(fn_args.children[-1], Leaf) or fn_args.children[-1].type != token.COMMA + ): fn_args.append_child(Comma()) provide_context_arg = KeywordArg(Name('provide_context'), Name('True')) provide_context_arg.prefix = fn_args.children[0].prefix fn_args.append_child(provide_context_arg) + (self.qry.select_function("PythonOperator").is_call().modify(add_provide_context_to_python_operator)) ( - self.qry. - select_function("PythonOperator"). - is_call(). - modify(add_provide_context_to_python_operator) - ) - ( - self.qry. - select_function("BranchPythonOperator"). - is_call(). 
- modify(add_provide_context_to_python_operator) + self.qry.select_function("BranchPythonOperator") + .is_call() + .modify(add_provide_context_to_python_operator) ) def remove_super_init_call(self): @@ -259,6 +259,7 @@ def __init__(self, context=None): self.max_ingestion_time = max_ingestion_time """ + def remove_super_init_call_modifier(node: LN, capture: Capture, filename: Filename) -> None: for ch in node.post_order(): if isinstance(ch, Leaf) and ch.value == "super": @@ -294,6 +295,7 @@ def remove_tags(self): ) as dag: """ + def remove_tags_modifier(_: LN, capture: Capture, filename: Filename) -> None: for node in capture['function_arguments'][0].post_order(): if isinstance(node, Leaf) and node.value == "tags" and node.type == TOKEN.NAME: @@ -324,6 +326,7 @@ class GCSUploadSessionCompleteSensor(BaseSensorOperator): Checks for changes in the number of objects at prefix in Google Cloud Storage """ + def find_and_remove_poke_mode_only_import(node: LN): for child in node.children: if isinstance(child, Leaf) and child.type == 1 and child.value == 'poke_mode_only': @@ -353,9 +356,14 @@ def find_root_remove_import(node: LN): find_and_remove_poke_mode_only_import(current_node) def is_poke_mode_only_decorator(node: LN) -> bool: - return node.children and len(node.children) >= 2 and \ - isinstance(node.children[0], Leaf) and node.children[0].value == '@' and \ - isinstance(node.children[1], Leaf) and node.children[1].value == 'poke_mode_only' + return ( + node.children + and len(node.children) >= 2 + and isinstance(node.children[0], Leaf) + and node.children[0].value == '@' + and isinstance(node.children[1], Leaf) + and node.children[1].value == 'poke_mode_only' + ) def remove_poke_mode_only_modifier(node: LN, capture: Capture, filename: Filename) -> None: for child in capture['node'].parent.children: @@ -391,36 +399,37 @@ def refactor_amazon_package(self): def amazon_package_filter(node: LN, capture: Capture, filename: Filename) -> bool: return filename.startswith("./airflow/providers/amazon/") - os.makedirs(os.path.join(get_target_providers_package_folder("amazon"), "common", "utils"), - exist_ok=True) + os.makedirs( + os.path.join(get_target_providers_package_folder("amazon"), "common", "utils"), exist_ok=True + ) copyfile( os.path.join(get_source_airflow_folder(), "airflow", "utils", "__init__.py"), - os.path.join(get_target_providers_package_folder("amazon"), "common", "__init__.py") + os.path.join(get_target_providers_package_folder("amazon"), "common", "__init__.py"), ) copyfile( os.path.join(get_source_airflow_folder(), "airflow", "utils", "__init__.py"), - os.path.join(get_target_providers_package_folder("amazon"), "common", "utils", "__init__.py") + os.path.join(get_target_providers_package_folder("amazon"), "common", "utils", "__init__.py"), ) copyfile( os.path.join(get_source_airflow_folder(), "airflow", "typing_compat.py"), - os.path.join(get_target_providers_package_folder("amazon"), "common", "utils", "typing_compat.py") + os.path.join( + get_target_providers_package_folder("amazon"), "common", "utils", "typing_compat.py" + ), ) ( - self.qry. - select_module("airflow.typing_compat"). - filter(callback=amazon_package_filter). 
- rename("airflow.providers.amazon.common.utils.typing_compat") + self.qry.select_module("airflow.typing_compat") + .filter(callback=amazon_package_filter) + .rename("airflow.providers.amazon.common.utils.typing_compat") ) copyfile( os.path.join(get_source_airflow_folder(), "airflow", "utils", "email.py"), - os.path.join(get_target_providers_package_folder("amazon"), "common", "utils", "email.py") + os.path.join(get_target_providers_package_folder("amazon"), "common", "utils", "email.py"), ) ( - self.qry. - select_module("airflow.utils.email"). - filter(callback=amazon_package_filter). - rename("airflow.providers.amazon.common.utils.email") + self.qry.select_module("airflow.utils.email") + .filter(callback=amazon_package_filter) + .rename("airflow.providers.amazon.common.utils.email") ) def refactor_google_package(self): @@ -517,6 +526,7 @@ def _generate_virtualenv_cmd(tmp_dir: str, python_bin: str, system_site_packages KEY_REGEX = re.compile(r'^[\\w.-]+$') """ + def google_package_filter(node: LN, capture: Capture, filename: Filename) -> bool: return filename.startswith("./airflow/providers/google/") @@ -524,64 +534,66 @@ def pure_airflow_models_filter(node: LN, capture: Capture, filename: Filename) - """Check if select is exactly [airflow, . , models]""" return len(list(node.children[1].leaves())) == 3 - os.makedirs(os.path.join(get_target_providers_package_folder("google"), "common", "utils"), - exist_ok=True) + os.makedirs( + os.path.join(get_target_providers_package_folder("google"), "common", "utils"), exist_ok=True + ) copyfile( os.path.join(get_source_airflow_folder(), "airflow", "utils", "__init__.py"), - os.path.join(get_target_providers_package_folder("google"), "common", "utils", "__init__.py") + os.path.join(get_target_providers_package_folder("google"), "common", "utils", "__init__.py"), ) copyfile( os.path.join(get_source_airflow_folder(), "airflow", "utils", "python_virtualenv.py"), - os.path.join(get_target_providers_package_folder("google"), "common", "utils", - "python_virtualenv.py") + os.path.join( + get_target_providers_package_folder("google"), "common", "utils", "python_virtualenv.py" + ), ) - copy_helper_py_file(os.path.join( - get_target_providers_package_folder("google"), "common", "utils", "helpers.py")) + copy_helper_py_file( + os.path.join(get_target_providers_package_folder("google"), "common", "utils", "helpers.py") + ) copyfile( os.path.join(get_source_airflow_folder(), "airflow", "utils", "module_loading.py"), - os.path.join(get_target_providers_package_folder("google"), "common", "utils", - "module_loading.py") + os.path.join( + get_target_providers_package_folder("google"), "common", "utils", "module_loading.py" + ), ) ( - self.qry. - select_module("airflow.utils.python_virtualenv"). - filter(callback=google_package_filter). - rename("airflow.providers.google.common.utils.python_virtualenv") + self.qry.select_module("airflow.utils.python_virtualenv") + .filter(callback=google_package_filter) + .rename("airflow.providers.google.common.utils.python_virtualenv") ) copyfile( os.path.join(get_source_airflow_folder(), "airflow", "utils", "process_utils.py"), - os.path.join(get_target_providers_package_folder("google"), "common", "utils", "process_utils.py") + os.path.join( + get_target_providers_package_folder("google"), "common", "utils", "process_utils.py" + ), ) ( - self.qry. - select_module("airflow.utils.process_utils"). - filter(callback=google_package_filter). 
- rename("airflow.providers.google.common.utils.process_utils") + self.qry.select_module("airflow.utils.process_utils") + .filter(callback=google_package_filter) + .rename("airflow.providers.google.common.utils.process_utils") ) ( - self.qry. - select_module("airflow.utils.helpers"). - filter(callback=google_package_filter). - rename("airflow.providers.google.common.utils.helpers") + self.qry.select_module("airflow.utils.helpers") + .filter(callback=google_package_filter) + .rename("airflow.providers.google.common.utils.helpers") ) ( - self.qry. - select_module("airflow.utils.module_loading"). - filter(callback=google_package_filter). - rename("airflow.providers.google.common.utils.module_loading") + self.qry.select_module("airflow.utils.module_loading") + .filter(callback=google_package_filter) + .rename("airflow.providers.google.common.utils.module_loading") ) ( # Fix BaseOperatorLinks imports - self.qry.select_module("airflow.models"). - is_filename(include=r"bigquery\.py|mlengine\.py"). - filter(callback=google_package_filter). - filter(pure_airflow_models_filter). - rename("airflow.models.baseoperator") + self.qry.select_module("airflow.models") + .is_filename(include=r"bigquery\.py|mlengine\.py") + .filter(callback=google_package_filter) + .filter(pure_airflow_models_filter) + .rename("airflow.models.baseoperator") ) def refactor_odbc_package(self): @@ -608,22 +620,21 @@ def refactor_odbc_package(self): """ + def odbc_package_filter(node: LN, capture: Capture, filename: Filename) -> bool: return filename.startswith("./airflow/providers/odbc/") os.makedirs(os.path.join(get_target_providers_folder(), "odbc", "utils"), exist_ok=True) copyfile( os.path.join(get_source_airflow_folder(), "airflow", "utils", "__init__.py"), - os.path.join(get_target_providers_package_folder("odbc"), "utils", "__init__.py") + os.path.join(get_target_providers_package_folder("odbc"), "utils", "__init__.py"), ) - copy_helper_py_file(os.path.join( - get_target_providers_package_folder("odbc"), "utils", "helpers.py")) + copy_helper_py_file(os.path.join(get_target_providers_package_folder("odbc"), "utils", "helpers.py")) ( - self.qry. - select_module("airflow.utils.helpers"). - filter(callback=odbc_package_filter). - rename("airflow.providers.odbc.utils.helpers") + self.qry.select_module("airflow.utils.helpers") + .filter(callback=odbc_package_filter) + .rename("airflow.providers.odbc.utils.helpers") ) def refactor_kubernetes_pod_operator(self): @@ -631,11 +642,10 @@ def kubernetes_package_filter(node: LN, capture: Capture, filename: Filename) -> return filename.startswith("./airflow/providers/cncf/kubernetes") ( - self.qry. - select_class("KubernetesPodOperator"). - select_method("add_xcom_sidecar"). - filter(callback=kubernetes_package_filter). 
- rename("add_sidecar") + self.qry.select_class("KubernetesPodOperator") + .select_method("add_xcom_sidecar") + .filter(callback=kubernetes_package_filter) + .rename("add_sidecar") ) def do_refactor(self, in_process: bool = False) -> None: # noqa @@ -653,7 +663,7 @@ def do_refactor(self, in_process: bool = False) -> None: # noqa if __name__ == '__main__': - BACKPORT_PACKAGES = (os.getenv('BACKPORT_PACKAGES') == "true") + BACKPORT_PACKAGES = os.getenv('BACKPORT_PACKAGES') == "true" in_process = False if len(sys.argv) > 1: if sys.argv[1] in ['--help', '-h']: diff --git a/provider_packages/remove_old_releases.py b/provider_packages/remove_old_releases.py index 7383c0d548cff..fb8643d74853d 100644 --- a/provider_packages/remove_old_releases.py +++ b/provider_packages/remove_old_releases.py @@ -40,13 +40,15 @@ class VersionedFile(NamedTuple): def split_version_and_suffix(file_name: str, suffix: str) -> VersionedFile: - no_suffix_file = file_name[:-len(suffix)] + no_suffix_file = file_name[: -len(suffix)] no_version_file, version = no_suffix_file.rsplit("-", 1) - return VersionedFile(base=no_version_file + "-", - version=version, - suffix=suffix, - type=no_version_file + "-" + suffix, - comparable_version=LooseVersion(version)) + return VersionedFile( + base=no_version_file + "-", + version=version, + suffix=suffix, + type=no_version_file + "-" + suffix, + comparable_version=LooseVersion(version), + ) def process_all_files(directory: str, suffix: str, execute: bool): @@ -63,8 +65,10 @@ def process_all_files(directory: str, suffix: str, execute: bool): for package_types in package_types_dicts.values(): if len(package_types) == 1: versioned_file = package_types[0] - print("Leaving the only version: " - f"${versioned_file.base + versioned_file.version + versioned_file.suffix}") + print( + "Leaving the only version: " + f"${versioned_file.base + versioned_file.version + versioned_file.suffix}" + ) # Leave only last version from each type for versioned_file in package_types[:-1]: command = ["svn", "rm", versioned_file.base + versioned_file.version + versioned_file.suffix] @@ -76,10 +80,16 @@ def process_all_files(directory: str, suffix: str, execute: bool): def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description='Removes old releases.') - parser.add_argument('--directory', dest='directory', action='store', required=True, - help='Directory to remove old releases in') - parser.add_argument('--execute', dest='execute', action='store_true', - help='Execute the removal rather than dry run') + parser.add_argument( + '--directory', + dest='directory', + action='store', + required=True, + help='Directory to remove old releases in', + ) + parser.add_argument( + '--execute', dest='execute', action='store_true', help='Execute the removal rather than dry run' + ) return parser.parse_args() diff --git a/scripts/ci/pre_commit/pre_commit_check_order_setup.py b/scripts/ci/pre_commit/pre_commit_check_order_setup.py index f255c2d0e34de..cfcb297e68e4b 100755 --- a/scripts/ci/pre_commit/pre_commit_check_order_setup.py +++ b/scripts/ci/pre_commit/pre_commit_check_order_setup.py @@ -38,8 +38,9 @@ def _check_list_sorted(the_list: List[str], message: str) -> None: while sorted_list[i] == the_list[i]: i += 1 print(f"{message} NOK") - errors.append(f"ERROR in {message}. First wrongly sorted element" - f" {the_list[i]}. Should be {sorted_list[i]}") + errors.append( + f"ERROR in {message}. First wrongly sorted element" f" {the_list[i]}. 
Should be {sorted_list[i]}" + ) def setup() -> str: @@ -55,7 +56,8 @@ def check_main_dependent_group(setup_context: str) -> None: '# Start dependencies group' and '# End dependencies group' in setup.py """ pattern_main_dependent_group = re.compile( - '# Start dependencies group\n(.*)# End dependencies group', re.DOTALL) + '# Start dependencies group\n(.*)# End dependencies group', re.DOTALL + ) main_dependent_group = pattern_main_dependent_group.findall(setup_context)[0] pattern_sub_dependent = re.compile(' = \\[.*?\\]\n', re.DOTALL) @@ -76,8 +78,7 @@ def check_sub_dependent_group(setup_context: str) -> None: pattern_dependent_version = re.compile('[~|><=;].*') for group_name in dependent_group_names: - pattern_sub_dependent = re.compile( - f'{group_name} = \\[(.*?)\\]', re.DOTALL) + pattern_sub_dependent = re.compile(f'{group_name} = \\[(.*?)\\]', re.DOTALL) sub_dependent = pattern_sub_dependent.findall(setup_context)[0] pattern_dependent = re.compile('\'(.*?)\'') dependent = pattern_dependent.findall(sub_dependent) @@ -104,8 +105,7 @@ def check_install_and_setup_requires(setup_context: str) -> None: Test for an order of dependencies in function do_setup section install_requires and setup_requires in setup.py """ - pattern_install_and_setup_requires = re.compile( - '(setup_requires) ?= ?\\[(.*?)\\]', re.DOTALL) + pattern_install_and_setup_requires = re.compile('(setup_requires) ?= ?\\[(.*?)\\]', re.DOTALL) install_and_setup_requires = pattern_install_and_setup_requires.findall(setup_context) for dependent_requires in install_and_setup_requires: @@ -123,7 +123,8 @@ def check_extras_require(setup_context: str) -> None: extras_require in setup.py """ pattern_extras_requires = re.compile( - r'EXTRAS_REQUIREMENTS: Dict\[str, Iterable\[str\]] = {(.*?)}', re.DOTALL) + r'EXTRAS_REQUIREMENTS: Dict\[str, Iterable\[str\]] = {(.*?)}', re.DOTALL + ) extras_requires = pattern_extras_requires.findall(setup_context)[0] pattern_dependent = re.compile('\'(.*?)\'') @@ -137,7 +138,8 @@ def check_provider_requirements(setup_context: str) -> None: providers_require in setup.py """ pattern_extras_requires = re.compile( - r'PROVIDERS_REQUIREMENTS: Dict\[str, Iterable\[str\]\] = {(.*?)}', re.DOTALL) + r'PROVIDERS_REQUIREMENTS: Dict\[str, Iterable\[str\]\] = {(.*?)}', re.DOTALL + ) extras_requires = pattern_extras_requires.findall(setup_context)[0] pattern_dependent = re.compile('"(.*?)"') diff --git a/scripts/ci/pre_commit/pre_commit_check_setup_installation.py b/scripts/ci/pre_commit/pre_commit_check_setup_installation.py index 7c30683c2c805..20bec5ed3959d 100755 --- a/scripts/ci/pre_commit/pre_commit_check_setup_installation.py +++ b/scripts/ci/pre_commit/pre_commit_check_setup_installation.py @@ -44,12 +44,12 @@ def get_extras_from_setup() -> Dict[str, List[str]]: """ setup_content = get_file_content(SETUP_PY_FILE) - extras_section_regex = re.compile( - r'^EXTRAS_REQUIREMENTS: Dict[^{]+{([^}]+)}', re.MULTILINE) + extras_section_regex = re.compile(r'^EXTRAS_REQUIREMENTS: Dict[^{]+{([^}]+)}', re.MULTILINE) extras_section = extras_section_regex.findall(setup_content)[0] extras_regex = re.compile( - rf'^\s+[\"\']({PY_IDENTIFIER})[\"\']:\s*({PY_IDENTIFIER})[^#\n]*(#\s*TODO.*)?$', re.MULTILINE) + rf'^\s+[\"\']({PY_IDENTIFIER})[\"\']:\s*({PY_IDENTIFIER})[^#\n]*(#\s*TODO.*)?$', re.MULTILINE + ) extras_dict: Dict[str, List[str]] = {} for extras in extras_regex.findall(extras_section): @@ -67,8 +67,9 @@ def get_extras_from_docs() -> List[str]: """ docs_content = get_file_content('docs', DOCS_FILE) - extras_section_regex 
= re.compile(rf'^\|[^|]+\|.*pip install .apache-airflow\[({PY_IDENTIFIER})\].', - re.MULTILINE) + extras_section_regex = re.compile( + rf'^\|[^|]+\|.*pip install .apache-airflow\[({PY_IDENTIFIER})\].', re.MULTILINE + ) extras = extras_section_regex.findall(docs_content) extras = list(filter(lambda entry: entry != 'all', extras)) @@ -90,10 +91,11 @@ def get_extras_from_docs() -> List[str]: if f"'{extras}'" not in setup_packages_str: output_table += "| {:20} | {:^10} | {:^10} |\n".format(extras, "", "V") - if(output_table == ""): + if output_table == "": exit(0) - print(f""" + print( + f""" ERROR "EXTRAS_REQUIREMENTS" section in {SETUP_PY_FILE} should be synchronized @@ -101,7 +103,8 @@ def get_extras_from_docs() -> List[str]: here is a list of packages that are used but are not documented, or documented although not used. - """) + """ + ) print(".{:_^22}.{:_^12}.{:_^12}.".format("NAME", "SETUP", "INSTALLATION")) print(output_table) diff --git a/scripts/ci/pre_commit/pre_commit_yaml_to_cfg.py b/scripts/ci/pre_commit/pre_commit_yaml_to_cfg.py index 9b31cef30f1c6..6558901890a14 100755 --- a/scripts/ci/pre_commit/pre_commit_yaml_to_cfg.py +++ b/scripts/ci/pre_commit/pre_commit_yaml_to_cfg.py @@ -91,7 +91,8 @@ def _write_section(configfile, section): section_description = None if section["description"] is not None: section_description = list( - filter(lambda x: (x is not None) or x != "", section["description"].splitlines())) + filter(lambda x: (x is not None) or x != "", section["description"].splitlines()) + ) if section_description: configfile.write("\n") for single_line_desc in section_description: @@ -106,8 +107,7 @@ def _write_section(configfile, section): def _write_option(configfile, idx, option): option_description = None if option["description"] is not None: - option_description = list( - filter(lambda x: x is not None, option["description"].splitlines())) + option_description = list(filter(lambda x: x is not None, option["description"].splitlines())) if option_description: if idx != 0: @@ -141,26 +141,26 @@ def _write_option(configfile, idx, option): if __name__ == '__main__': airflow_config_dir = os.path.join( - os.path.dirname(__file__), - os.pardir, os.pardir, os.pardir, - "airflow", "config_templates") + os.path.dirname(__file__), os.pardir, os.pardir, os.pardir, "airflow", "config_templates" + ) airflow_default_config_path = os.path.join(airflow_config_dir, "default_airflow.cfg") airflow_config_yaml_file_path = os.path.join(airflow_config_dir, "config.yml") write_config( - yaml_config_file_path=airflow_config_yaml_file_path, - default_cfg_file_path=airflow_default_config_path + yaml_config_file_path=airflow_config_yaml_file_path, default_cfg_file_path=airflow_default_config_path ) providers_dir = os.path.join( - os.path.dirname(__file__), - os.pardir, os.pardir, os.pardir, - "airflow", "providers") + os.path.dirname(__file__), os.pardir, os.pardir, os.pardir, "airflow", "providers" + ) for root, dir_names, file_names in os.walk(providers_dir): for file_name in file_names: - if root.endswith("config_templates") and file_name == 'config.yml' and \ - os.path.isfile(os.path.join(root, "default_config.cfg")): + if ( + root.endswith("config_templates") + and file_name == 'config.yml' + and os.path.isfile(os.path.join(root, "default_config.cfg")) + ): write_config( yaml_config_file_path=os.path.join(root, "config.yml"), - default_cfg_file_path=os.path.join(root, "default_config.cfg") + default_cfg_file_path=os.path.join(root, "default_config.cfg"), ) diff --git 
a/scripts/in_container/update_quarantined_test_status.py b/scripts/in_container/update_quarantined_test_status.py index ba5bf3f3209c4..73cf962904c58 100755 --- a/scripts/in_container/update_quarantined_test_status.py +++ b/scripts/in_container/update_quarantined_test_status.py @@ -66,8 +66,10 @@ class TestHistory(NamedTuple): def get_url(result: TestResult) -> str: - return f"[{result.name}](https://github.com/{user}/{repo}/blob/" \ - f"master/{result.file}?test_id={result.test_id}#L{result.line})" + return ( + f"[{result.name}](https://github.com/{user}/{repo}/blob/" + f"master/{result.file}?test_id={result.test_id}#L{result.line})" + ) def parse_state_history(history_string: str) -> List[bool]: @@ -136,11 +138,7 @@ def update_test_history(history: TestHistory, last_status: bool): def create_test_history(result: TestResult) -> TestHistory: print(f"Creating test history {result}") return TestHistory( - test_id=result.test_id, - name=result.name, - url=get_url(result), - states=[result.result], - comment="" + test_id=result.test_id, name=result.name, url=get_url(result), states=[result.result], comment="" ) @@ -151,9 +149,9 @@ def get_history_status(history: TestHistory): return "Flaky" if all(history.states): return "Stable" - if all(history.states[0:num_runs - 1]): + if all(history.states[0 : num_runs - 1]): return "Just one more" - if all(history.states[0:int(num_runs / 2)]): + if all(history.states[0 : int(num_runs / 2)]): return "Almost there" return "Flaky" @@ -163,13 +161,15 @@ def get_table(history_map: Dict[str, TestHistory]) -> str: the_table: List[List[str]] = [] for ordered_key in sorted(history_map.keys()): history = history_map[ordered_key] - the_table.append([ - history.url, - "Succeeded" if history.states[0] else "Failed", - " ".join([reverse_status_map[state] for state in history.states]), - get_history_status(history), - history.comment - ]) + the_table.append( + [ + history.url, + "Succeeded" if history.states[0] else "Failed", + " ".join([reverse_status_map[state] for state in history.states]), + get_history_status(history), + history.comment, + ] + ) return tabulate(the_table, headers, tablefmt="github") @@ -187,14 +187,16 @@ def get_table(history_map: Dict[str, TestHistory]) -> str: if len(test.contents) > 0 and test.contents[0].name == 'skipped': print(f"skipping {test['name']}") continue - test_results.append(TestResult( - test_id=test['classname'] + "::" + test['name'], - file=test['file'], - line=test['line'], - name=test['name'], - classname=test['classname'], - result=len(test.contents) == 0 - )) + test_results.append( + TestResult( + test_id=test['classname'] + "::" + test['name'], + file=test['file'], + line=test['line'], + name=test['name'], + classname=test['classname'], + result=len(test.contents) == 0, + ) + ) token = os.environ.get("GITHUB_TOKEN") print(f"Token: {token}") @@ -221,8 +223,7 @@ def get_table(history_map: Dict[str, TestHistory]) -> str: for test_result in test_results: previous_results = parsed_test_map.get(test_result.test_id) if previous_results: - updated_results = update_test_history( - previous_results, test_result.result) + updated_results = update_test_history(previous_results, test_result.result) new_test_map[previous_results.test_id] = updated_results else: new_history = create_test_history(test_result) @@ -234,8 +235,9 @@ def get_table(history_map: Dict[str, TestHistory]) -> str: print(table) print() with open(join(dirname(realpath(__file__)), "quarantine_issue_header.md")) as f: - header = jinja2.Template(f.read(), autoescape=True, 
undefined=StrictUndefined).\ - render(DATE_UTC_NOW=datetime.utcnow()) - quarantined_issue.edit(title=None, - body=header + "\n\n" + str(table), - state='open' if len(test_results) > 0 else 'closed') + header = jinja2.Template(f.read(), autoescape=True, undefined=StrictUndefined).render( + DATE_UTC_NOW=datetime.utcnow() + ) + quarantined_issue.edit( + title=None, body=header + "\n\n" + str(table), state='open' if len(test_results) > 0 else 'closed' + ) diff --git a/scripts/tools/list-integrations.py b/scripts/tools/list-integrations.py index 73c514c7c1ee3..ebfbc67be58fb 100755 --- a/scripts/tools/list-integrations.py +++ b/scripts/tools/list-integrations.py @@ -96,9 +96,7 @@ def _find_clazzes(directory, base_class): """ parser = argparse.ArgumentParser( # noqa - description=HELP, - formatter_class=argparse.RawTextHelpFormatter, - epilog=EPILOG + description=HELP, formatter_class=argparse.RawTextHelpFormatter, epilog=EPILOG ) # argparse handle `-h/--help/` internally parser.parse_args() @@ -111,8 +109,9 @@ def _find_clazzes(directory, base_class): } for integration_base_directory, integration_class in RESOURCE_TYPES.items(): - for integration_directory in glob(f"{AIRFLOW_ROOT}/airflow/**/{integration_base_directory}", - recursive=True): + for integration_directory in glob( + f"{AIRFLOW_ROOT}/airflow/**/{integration_base_directory}", recursive=True + ): if "contrib" in integration_directory: continue diff --git a/setup.py b/setup.py index f0872b0c8d8bc..992ee1f438e78 100644 --- a/setup.py +++ b/setup.py @@ -131,6 +131,7 @@ def git_version(version_: str) -> str: """ try: import git + try: repo = git.Repo(os.path.join(*[my_dir, '.git'])) except git.NoSuchPathError: @@ -208,10 +209,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version cloudant = [ 'cloudant>=2.0', ] -dask = [ - 'cloudpickle>=1.4.1, <1.5.0', - 'distributed>=2.11.1, <2.20' -] +dask = ['cloudpickle>=1.4.1, <1.5.0', 'distributed>=2.11.1, <2.20'] databricks = [ 'requests>=2.20.0, <3', ] @@ -227,7 +225,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version 'sphinx-rtd-theme>=0.1.6', 'sphinxcontrib-httpdomain>=1.7.0', "sphinxcontrib-redoc>=1.6.0", - "sphinxcontrib-spelling==5.2.1" + "sphinxcontrib-spelling==5.2.1", ] docker = [ 'docker~=3.0', @@ -316,9 +314,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version 'cryptography>=2.0.0', 'kubernetes>=3.0.0, <12.0.0', ] -kylin = [ - 'kylinpy>=2.6' -] +kylin = ['kylinpy>=2.6'] ldap = [ 'ldap3>=2.5.1', ] @@ -359,9 +355,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version postgres = [ 'psycopg2-binary>=2.7.4', ] -presto = [ - 'presto-python-client>=0.7.0,<0.8' -] +presto = ['presto-python-client>=0.7.0,<0.8'] qds = [ 'qds-sdk>=1.10.4', ] @@ -429,8 +423,21 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version ] # End dependencies group -all_dbs = (cassandra + cloudant + druid + exasol + hdfs + hive + mongo + mssql + mysql + - pinot + postgres + presto + vertica) +all_dbs = ( + cassandra + + cloudant + + druid + + exasol + + hdfs + + hive + + mongo + + mssql + + mysql + + pinot + + postgres + + presto + + vertica +) ############################################################################################################ # IMPORTANT NOTE!!!!!!!!!!!!!!! 
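The setup.py hunk above shows both directions Black can take: short literals such as the dask, kylin and presto lists are collapsed onto a single line, while the long all_dbs sum is split with one operand per line. To preview what Black would do to a given snippet, something like the following works, assuming the black package from the pinned 20.8b1 hook is importable; line length 110 matches the isort setting in setup.cfg and is assumed to be what pyproject.toml configures for Black as well.

import black

# Fits within 110 characters, so Black keeps (or collapses) it to a single line.
src = "dask = ['cloudpickle>=1.4.1, <1.5.0', 'distributed>=2.11.1, <2.20']\n"
print(black.format_str(src, mode=black.Mode(line_length=110)))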
@@ -456,7 +463,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version 'jira', 'mongomock', 'moto==1.3.14', # TODO - fix Datasync issues to get higher version of moto: - # See: https://github.com/apache/airflow/issues/10985 + # See: https://github.com/apache/airflow/issues/10985 'parameterized', 'paramiko', 'pipdeptree', @@ -599,7 +606,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version 'jdbc': jdbc, 'jira': jira, 'kerberos': kerberos, - 'kubernetes': kubernetes, # TODO: remove this in Airflow 2.1 + 'kubernetes': kubernetes, # TODO: remove this in Airflow 2.1 'ldap': ldap, "microsoft.azure": azure, "microsoft.mssql": mssql, @@ -639,16 +646,22 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version } # Make devel_all contain all providers + extras + unique -devel_all = list(set(devel + - [req for req_list in EXTRAS_REQUIREMENTS.values() for req in req_list] + - [req for req_list in PROVIDERS_REQUIREMENTS.values() for req in req_list])) +devel_all = list( + set( + devel + + [req for req_list in EXTRAS_REQUIREMENTS.values() for req in req_list] + + [req for req_list in PROVIDERS_REQUIREMENTS.values() for req in req_list] + ) +) PACKAGES_EXCLUDED_FOR_ALL = [] if PY3: - PACKAGES_EXCLUDED_FOR_ALL.extend([ - 'snakebite', - ]) + PACKAGES_EXCLUDED_FOR_ALL.extend( + [ + 'snakebite', + ] + ) # Those packages are excluded because they break tests (downgrading mock) and they are # not needed to run our test suite. @@ -668,13 +681,17 @@ def is_package_excluded(package: str, exclusion_list: List[str]): return any(package.startswith(excluded_package) for excluded_package in exclusion_list) -devel_all = [package for package in devel_all if not is_package_excluded( - package=package, - exclusion_list=PACKAGES_EXCLUDED_FOR_ALL) +devel_all = [ + package + for package in devel_all + if not is_package_excluded(package=package, exclusion_list=PACKAGES_EXCLUDED_FOR_ALL) ] -devel_ci = [package for package in devel_all if not is_package_excluded( - package=package, - exclusion_list=PACKAGES_EXCLUDED_FOR_CI + PACKAGES_EXCLUDED_FOR_ALL) +devel_ci = [ + package + for package in devel_all + if not is_package_excluded( + package=package, exclusion_list=PACKAGES_EXCLUDED_FOR_CI + PACKAGES_EXCLUDED_FOR_ALL + ) ] EXTRAS_REQUIREMENTS.update( @@ -748,9 +765,11 @@ def is_package_excluded(package: str, exclusion_list: List[str]): def do_setup(): """Perform the Airflow package setup.""" install_providers_from_sources = os.getenv('INSTALL_PROVIDERS_FROM_SOURCES') - exclude_patterns = \ - [] if install_providers_from_sources and install_providers_from_sources == 'true' \ + exclude_patterns = ( + [] + if install_providers_from_sources and install_providers_from_sources == 'true' else ['airflow.providers', 'airflow.providers.*'] + ) write_version() setup( name='apache-airflow', @@ -759,13 +778,15 @@ def do_setup(): long_description_content_type='text/markdown', license='Apache License 2.0', version=version, - packages=find_packages( - include=['airflow*'], - exclude=exclude_patterns), + packages=find_packages(include=['airflow*'], exclude=exclude_patterns), package_data={ 'airflow': ['py.typed'], - '': ['airflow/alembic.ini', "airflow/git_version", "*.ipynb", - "airflow/providers/cncf/kubernetes/example_dags/*.yaml"], + '': [ + 'airflow/alembic.ini', + "airflow/git_version", + "*.ipynb", + "airflow/providers/cncf/kubernetes/example_dags/*.yaml", + ], 'airflow.api_connexion.openapi': ['*.yaml'], 'airflow.serialization': ["*.json"], }, @@ 
-800,8 +821,7 @@ def do_setup(): author='Apache Software Foundation', author_email='dev@airflow.apache.org', url='http://airflow.apache.org/', - download_url=( - 'https://archive.apache.org/dist/airflow/' + version), + download_url=('https://archive.apache.org/dist/airflow/' + version), cmdclass={ 'extra_clean': CleanCommand, 'compile_assets': CompileAssets, diff --git a/tests/airflow_pylint/disable_checks_for_tests.py b/tests/airflow_pylint/disable_checks_for_tests.py index e7b23b2f9abe3..cbdf9b2ce936f 100644 --- a/tests/airflow_pylint/disable_checks_for_tests.py +++ b/tests/airflow_pylint/disable_checks_for_tests.py @@ -19,8 +19,9 @@ from astroid import MANAGER, scoped_nodes from pylint.lint import PyLinter -DISABLED_CHECKS_FOR_TESTS = \ +DISABLED_CHECKS_FOR_TESTS = ( "missing-docstring, no-self-use, too-many-public-methods, protected-access, do-not-use-asserts" +) def register(_: PyLinter): @@ -42,18 +43,21 @@ def transform(mod): :param mod: astroid module :return: None """ - if mod.name.startswith("test_") or \ - mod.name.startswith("tests.") or \ - mod.name.startswith("kubernetes_tests.") or \ - mod.name.startswith("chart."): + if ( + mod.name.startswith("test_") + or mod.name.startswith("tests.") + or mod.name.startswith("kubernetes_tests.") + or mod.name.startswith("chart.") + ): decoded_lines = mod.stream().read().decode("utf-8").split("\n") if decoded_lines[0].startswith("# pylint: disable="): decoded_lines[0] = decoded_lines[0] + " " + DISABLED_CHECKS_FOR_TESTS elif decoded_lines[0].startswith("#") or decoded_lines[0].strip() == "": decoded_lines[0] = "# pylint: disable=" + DISABLED_CHECKS_FOR_TESTS else: - raise Exception(f"The first line of module {mod.name} is not a comment or empty. " - f"Please make sure it is!") + raise Exception( + f"The first line of module {mod.name} is not a comment or empty. " f"Please make sure it is!" + ) # pylint will read from `.file_bytes` attribute later when tokenization mod.file_bytes = "\n".join(decoded_lines).encode("utf-8") diff --git a/tests/airflow_pylint/do_not_use_asserts.py b/tests/airflow_pylint/do_not_use_asserts.py index e042c694bad38..47a0e208b3862 100644 --- a/tests/airflow_pylint/do_not_use_asserts.py +++ b/tests/airflow_pylint/do_not_use_asserts.py @@ -29,13 +29,14 @@ class DoNotUseAssertsChecker(BaseChecker): 'E7401': ( 'Do not use asserts.', 'do-not-use-asserts', - 'Asserts should not be used in the main Airflow code.' + 'Asserts should not be used in the main Airflow code.', ), } def visit_assert(self, node): self.add_message( - self.name, node=node, + self.name, + node=node, ) diff --git a/tests/always/test_example_dags.py b/tests/always/test_example_dags.py index 0261a4370c389..339a42a4c2f47 100644 --- a/tests/always/test_example_dags.py +++ b/tests/always/test_example_dags.py @@ -26,9 +26,7 @@ os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir) ) -NO_DB_QUERY_EXCEPTION = [ - "/airflow/example_dags/example_subdag_operator.py" -] +NO_DB_QUERY_EXCEPTION = ["/airflow/example_dags/example_subdag_operator.py"] class TestExampleDags(unittest.TestCase): diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index 6da4a53649420..bcde9bb9afe1e 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -57,8 +57,12 @@ def test_providers_modules_should_have_tests(self): Assert every module in /airflow/providers has a corresponding test_ file in tests/airflow/providers. 
""" # Deprecated modules that don't have corresponded test - expected_missing_providers_modules = {('airflow/providers/amazon/aws/hooks/aws_dynamodb.py', - 'tests/providers/amazon/aws/hooks/test_aws_dynamodb.py')} + expected_missing_providers_modules = { + ( + 'airflow/providers/amazon/aws/hooks/aws_dynamodb.py', + 'tests/providers/amazon/aws/hooks/test_aws_dynamodb.py', + ) + } # TODO: Should we extend this test to cover other directories? modules_files = glob.glob(f"{ROOT_FOLDER}/airflow/providers/**/*.py", recursive=True) @@ -71,13 +75,13 @@ def test_providers_modules_should_have_tests(self): modules_files = (f for f in modules_files if not f.endswith("__init__.py")) # Change airflow/ to tests/ expected_test_files = ( - f'tests/{f.partition("/")[2]}' - for f in modules_files if not f.endswith("__init__.py") + f'tests/{f.partition("/")[2]}' for f in modules_files if not f.endswith("__init__.py") ) # Add test_ prefix to filename expected_test_files = ( f'{f.rpartition("/")[0]}/test_{f.rpartition("/")[2]}' - for f in expected_test_files if not f.endswith("__init__.py") + for f in expected_test_files + if not f.endswith("__init__.py") ) current_test_files = glob.glob(f"{ROOT_FOLDER}/tests/providers/**/*.py", recursive=True) @@ -126,9 +130,7 @@ def test_example_dags(self): ) example_dags_files = self.find_resource_files(resource_type="example_dags") # Generate tuple of department and service e.g. ('marketing_platform', 'display_video') - operator_sets = [ - (f.split("/")[-3], f.split("/")[-1].rsplit(".")[0]) for f in operators_modules - ] + operator_sets = [(f.split("/")[-3], f.split("/")[-1].rsplit(".")[0]) for f in operators_modules] example_sets = [ (f.split("/")[-3], f.split("/")[-1].rsplit(".")[0].replace("example_", "", 1)) for f in example_dags_files @@ -174,7 +176,10 @@ def has_example_dag(operator_set): @parameterized.expand( [ - (resource_type, suffix,) + ( + resource_type, + suffix, + ) for suffix in ["_system.py", "_system_helper.py"] for resource_type in ["operators", "sensors", "tranfers"] ] @@ -186,12 +191,8 @@ def test_detect_invalid_system_tests(self, resource_type, filename_suffix): files = {f for f in operators_tests if f.endswith(filename_suffix)} expected_files = (f"tests/{f[8:]}" for f in operators_files) - expected_files = ( - f.replace(".py", filename_suffix).replace("/test_", "/") for f in expected_files - ) - expected_files = { - f'{f.rpartition("/")[0]}/test_{f.rpartition("/")[2]}' for f in expected_files - } + expected_files = (f.replace(".py", filename_suffix).replace("/test_", "/") for f in expected_files) + expected_files = {f'{f.rpartition("/")[0]}/test_{f.rpartition("/")[2]}' for f in expected_files} self.assertEqual(set(), files - expected_files) @@ -200,7 +201,7 @@ def find_resource_files( top_level_directory: str = "airflow", department: str = "*", resource_type: str = "*", - service: str = "*" + service: str = "*", ): python_files = glob.glob( f"{ROOT_FOLDER}/{top_level_directory}/providers/google/{department}/{resource_type}/{service}.py" @@ -215,16 +216,14 @@ def find_resource_files( class TestOperatorsHooks(unittest.TestCase): def test_no_illegal_suffixes(self): illegal_suffixes = ["_operator.py", "_hook.py", "_sensor.py"] - files = itertools.chain(*[ - glob.glob(f"{ROOT_FOLDER}/{part}/providers/**/{resource_type}/*.py", recursive=True) - for resource_type in ["operators", "hooks", "sensors", "example_dags"] - for part in ["airflow", "tests"] - ]) - - invalid_files = [ - f - for f in files - if any(f.endswith(suffix) for suffix in illegal_suffixes) - ] 
+ files = itertools.chain( + *[ + glob.glob(f"{ROOT_FOLDER}/{part}/providers/**/{resource_type}/*.py", recursive=True) + for resource_type in ["operators", "hooks", "sensors", "example_dags"] + for part in ["airflow", "tests"] + ] + ) + + invalid_files = [f for f in files if any(f.endswith(suffix) for suffix in illegal_suffixes)] self.assertEqual([], invalid_files) diff --git a/tests/api/auth/backend/test_basic_auth.py b/tests/api/auth/backend/test_basic_auth.py index 06461347384a3..4cf427780e201 100644 --- a/tests/api/auth/backend/test_basic_auth.py +++ b/tests/api/auth/backend/test_basic_auth.py @@ -29,9 +29,7 @@ class TestBasicAuth(unittest.TestCase): def setUp(self) -> None: - with conf_vars( - {("api", "auth_backend"): "airflow.api.auth.backend.basic_auth"} - ): + with conf_vars({("api", "auth_backend"): "airflow.api.auth.backend.basic_auth"}): self.app = create_app(testing=True) self.appbuilder = self.app.appbuilder # pylint: disable=no-member @@ -52,9 +50,7 @@ def test_success(self): clear_db_pools() with self.app.test_client() as test_client: - response = test_client.get( - "/api/v1/pools", headers={"Authorization": token} - ) + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) assert current_user.email == "test@fab.org" assert response.status_code == 200 @@ -72,37 +68,37 @@ def test_success(self): "total_entries": 1, } - @parameterized.expand([ - ("basic",), - ("basic ",), - ("bearer",), - ("test:test",), - (b64encode(b"test:test").decode(),), - ("bearer ",), - ("basic: ",), - ("basic 123",), - ]) + @parameterized.expand( + [ + ("basic",), + ("basic ",), + ("bearer",), + ("test:test",), + (b64encode(b"test:test").decode(),), + ("bearer ",), + ("basic: ",), + ("basic 123",), + ] + ) def test_malformed_headers(self, token): with self.app.test_client() as test_client: - response = test_client.get( - "/api/v1/pools", headers={"Authorization": token} - ) + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) assert response.status_code == 401 assert response.headers["Content-Type"] == "application/problem+json" assert response.headers["WWW-Authenticate"] == "Basic" assert_401(response) - @parameterized.expand([ - ("basic " + b64encode(b"test").decode(),), - ("basic " + b64encode(b"test:").decode(),), - ("basic " + b64encode(b"test:123").decode(),), - ("basic " + b64encode(b"test test").decode(),), - ]) + @parameterized.expand( + [ + ("basic " + b64encode(b"test").decode(),), + ("basic " + b64encode(b"test:").decode(),), + ("basic " + b64encode(b"test:123").decode(),), + ("basic " + b64encode(b"test test").decode(),), + ] + ) def test_invalid_auth_header(self, token): with self.app.test_client() as test_client: - response = test_client.get( - "/api/v1/pools", headers={"Authorization": token} - ) + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) assert response.status_code == 401 assert response.headers["Content-Type"] == "application/problem+json" assert response.headers["WWW-Authenticate"] == "Basic" @@ -110,9 +106,7 @@ def test_invalid_auth_header(self, token): def test_experimental_api(self): with self.app.test_client() as test_client: - response = test_client.get( - "/api/experimental/pools", headers={"Authorization": "Basic"} - ) + response = test_client.get("/api/experimental/pools", headers={"Authorization": "Basic"}) assert response.status_code == 401 assert response.headers["WWW-Authenticate"] == "Basic" assert response.data == b'Unauthorized' @@ -120,7 +114,7 @@ def test_experimental_api(self): 
clear_db_pools() response = test_client.get( "/api/experimental/pools", - headers={"Authorization": "Basic " + b64encode(b"test:test").decode()} + headers={"Authorization": "Basic " + b64encode(b"test:test").decode()}, ) assert response.status_code == 200 assert response.json[0]["pool"] == 'default_pool' diff --git a/tests/api/auth/test_client.py b/tests/api/auth/test_client.py index 3275f6e58c03f..8652b12772b5f 100644 --- a/tests/api/auth/test_client.py +++ b/tests/api/auth/test_client.py @@ -23,14 +23,15 @@ class TestGetCurrentApiClient(unittest.TestCase): - @mock.patch("airflow.api.client.json_client.Client") @mock.patch("airflow.api.auth.backend.default.CLIENT_AUTH", "CLIENT_AUTH") - @conf_vars({ - ("api", 'auth_backend'): 'airflow.api.auth.backend.default', - ("cli", 'api_client'): 'airflow.api.client.json_client', - ("cli", 'endpoint_url'): 'http://localhost:1234', - }) + @conf_vars( + { + ("api", 'auth_backend'): 'airflow.api.auth.backend.default', + ("cli", 'api_client'): 'airflow.api.client.json_client', + ("cli", 'endpoint_url'): 'http://localhost:1234', + } + ) def test_should_create_client(self, mock_client): result = get_current_api_client() @@ -41,17 +42,17 @@ def test_should_create_client(self, mock_client): @mock.patch("airflow.api.client.json_client.Client") @mock.patch("airflow.providers.google.common.auth_backend.google_openid.create_client_session") - @conf_vars({ - ("api", 'auth_backend'): 'airflow.providers.google.common.auth_backend.google_openid', - ("cli", 'api_client'): 'airflow.api.client.json_client', - ("cli", 'endpoint_url'): 'http://localhost:1234', - }) + @conf_vars( + { + ("api", 'auth_backend'): 'airflow.providers.google.common.auth_backend.google_openid', + ("cli", 'api_client'): 'airflow.api.client.json_client', + ("cli", 'endpoint_url'): 'http://localhost:1234', + } + ) def test_should_create_google_open_id_client(self, mock_create_client_session, mock_client): result = get_current_api_client() mock_client.assert_called_once_with( - api_base_url='http://localhost:1234', - auth=None, - session=mock_create_client_session.return_value + api_base_url='http://localhost:1234', auth=None, session=mock_create_client_session.return_value ) self.assertEqual(mock_client.return_value, result) diff --git a/tests/api/client/test_local_client.py b/tests/api/client/test_local_client.py index d3260a2cba016..d574615a243cf 100644 --- a/tests/api/client/test_local_client.py +++ b/tests/api/client/test_local_client.py @@ -38,7 +38,6 @@ class TestLocalClient(unittest.TestCase): - @classmethod def setUpClass(cls): super().setUpClass() @@ -67,44 +66,52 @@ def test_trigger_dag(self, mock): with freeze_time(EXECDATE): # no execution date, execution date should be set automatically self.client.trigger_dag(dag_id=test_dag_id) - mock.assert_called_once_with(run_id=run_id, - execution_date=EXECDATE_NOFRACTIONS, - state=State.RUNNING, - conf=None, - external_trigger=True, - dag_hash=ANY) + mock.assert_called_once_with( + run_id=run_id, + execution_date=EXECDATE_NOFRACTIONS, + state=State.RUNNING, + conf=None, + external_trigger=True, + dag_hash=ANY, + ) mock.reset_mock() # execution date with microseconds cutoff self.client.trigger_dag(dag_id=test_dag_id, execution_date=EXECDATE) - mock.assert_called_once_with(run_id=run_id, - execution_date=EXECDATE_NOFRACTIONS, - state=State.RUNNING, - conf=None, - external_trigger=True, - dag_hash=ANY) + mock.assert_called_once_with( + run_id=run_id, + execution_date=EXECDATE_NOFRACTIONS, + state=State.RUNNING, + conf=None, + 
external_trigger=True, + dag_hash=ANY, + ) mock.reset_mock() # run id custom_run_id = "my_run_id" self.client.trigger_dag(dag_id=test_dag_id, run_id=custom_run_id) - mock.assert_called_once_with(run_id=custom_run_id, - execution_date=EXECDATE_NOFRACTIONS, - state=State.RUNNING, - conf=None, - external_trigger=True, - dag_hash=ANY) + mock.assert_called_once_with( + run_id=custom_run_id, + execution_date=EXECDATE_NOFRACTIONS, + state=State.RUNNING, + conf=None, + external_trigger=True, + dag_hash=ANY, + ) mock.reset_mock() # test conf conf = '{"name": "John"}' self.client.trigger_dag(dag_id=test_dag_id, conf=conf) - mock.assert_called_once_with(run_id=run_id, - execution_date=EXECDATE_NOFRACTIONS, - state=State.RUNNING, - conf=json.loads(conf), - external_trigger=True, - dag_hash=ANY) + mock.assert_called_once_with( + run_id=run_id, + execution_date=EXECDATE_NOFRACTIONS, + state=State.RUNNING, + conf=json.loads(conf), + external_trigger=True, + dag_hash=ANY, + ) mock.reset_mock() def test_delete_dag(self): @@ -129,8 +136,7 @@ def test_get_pools(self): self.client.create_pool(name='foo1', slots=1, description='') self.client.create_pool(name='foo2', slots=2, description='') pools = sorted(self.client.get_pools(), key=lambda p: p[0]) - self.assertEqual(pools, [('default_pool', 128, 'Default pool'), - ('foo1', 1, ''), ('foo2', 2, '')]) + self.assertEqual(pools, [('default_pool', 128, 'Default pool'), ('foo1', 1, ''), ('foo2', 2, '')]) def test_create_pool(self): pool = self.client.create_pool(name='foo', slots=1, description='') diff --git a/tests/api/common/experimental/test_delete_dag.py b/tests/api/common/experimental/test_delete_dag.py index 00058713c63a0..471138f9f5cd3 100644 --- a/tests/api/common/experimental/test_delete_dag.py +++ b/tests/api/common/experimental/test_delete_dag.py @@ -37,7 +37,6 @@ class TestDeleteDAGCatchError(unittest.TestCase): - def setUp(self): self.dagbag = models.DagBag(include_examples=True) self.dag_id = 'example_bash_operator' @@ -59,30 +58,47 @@ def setup_dag_models(self, for_sub_dag=False): if for_sub_dag: self.key = "test_dag_id.test_subdag" - task = DummyOperator(task_id='dummy', - dag=models.DAG(dag_id=self.key, - default_args={'start_date': days_ago(2)}), - owner='airflow') + task = DummyOperator( + task_id='dummy', + dag=models.DAG(dag_id=self.key, default_args={'start_date': days_ago(2)}), + owner='airflow', + ) test_date = days_ago(1) with create_session() as session: session.add(DM(dag_id=self.key, fileloc=self.dag_file_path, is_subdag=for_sub_dag)) session.add(DR(dag_id=self.key, run_type=DagRunType.MANUAL)) - session.add(TI(task=task, - execution_date=test_date, - state=State.SUCCESS)) + session.add(TI(task=task, execution_date=test_date, state=State.SUCCESS)) # flush to ensure task instance if written before # task reschedule because of FK constraint session.flush() - session.add(LOG(dag_id=self.key, task_id=None, task_instance=None, - execution_date=test_date, event="varimport")) - session.add(TF(task=task, execution_date=test_date, - start_date=test_date, end_date=test_date)) - session.add(TR(task=task, execution_date=test_date, - start_date=test_date, end_date=test_date, - try_number=1, reschedule_date=test_date)) - session.add(IE(timestamp=test_date, filename=self.dag_file_path, - stacktrace="NameError: name 'airflow' is not defined")) + session.add( + LOG( + dag_id=self.key, + task_id=None, + task_instance=None, + execution_date=test_date, + event="varimport", + ) + ) + session.add(TF(task=task, execution_date=test_date, start_date=test_date, 
end_date=test_date)) + session.add( + TR( + task=task, + execution_date=test_date, + start_date=test_date, + end_date=test_date, + try_number=1, + reschedule_date=test_date, + ) + ) + session.add( + IE( + timestamp=test_date, + filename=self.dag_file_path, + stacktrace="NameError: name 'airflow' is not defined", + ) + ) def tearDown(self): with create_session() as session: @@ -102,8 +118,7 @@ def check_dag_models_exists(self): self.assertEqual(session.query(TF).filter(TF.dag_id == self.key).count(), 1) self.assertEqual(session.query(TR).filter(TR.dag_id == self.key).count(), 1) self.assertEqual(session.query(LOG).filter(LOG.dag_id == self.key).count(), 1) - self.assertEqual( - session.query(IE).filter(IE.filename == self.dag_file_path).count(), 1) + self.assertEqual(session.query(IE).filter(IE.filename == self.dag_file_path).count(), 1) def check_dag_models_removed(self, expect_logs=1): with create_session() as session: @@ -113,8 +128,7 @@ def check_dag_models_removed(self, expect_logs=1): self.assertEqual(session.query(TF).filter(TF.dag_id == self.key).count(), 0) self.assertEqual(session.query(TR).filter(TR.dag_id == self.key).count(), 0) self.assertEqual(session.query(LOG).filter(LOG.dag_id == self.key).count(), expect_logs) - self.assertEqual( - session.query(IE).filter(IE.filename == self.dag_file_path).count(), 0) + self.assertEqual(session.query(IE).filter(IE.filename == self.dag_file_path).count(), 0) def test_delete_dag_successful_delete(self): self.setup_dag_models() diff --git a/tests/api/common/experimental/test_mark_tasks.py b/tests/api/common/experimental/test_mark_tasks.py index dde25f5b857b2..4ff093100c93f 100644 --- a/tests/api/common/experimental/test_mark_tasks.py +++ b/tests/api/common/experimental/test_mark_tasks.py @@ -24,7 +24,10 @@ from airflow import models from airflow.api.common.experimental.mark_tasks import ( - _create_dagruns, set_dag_run_state_to_failed, set_dag_run_state_to_running, set_dag_run_state_to_success, + _create_dagruns, + set_dag_run_state_to_failed, + set_dag_run_state_to_running, + set_dag_run_state_to_success, set_state, ) from airflow.models import DagRun @@ -39,7 +42,6 @@ class TestMarkTasks(unittest.TestCase): - @classmethod def setUpClass(cls): models.DagBag(include_examples=True, read_dags_from_db=False).sync_to_db() @@ -53,31 +55,32 @@ def setUpClass(cls): cls.dag3.sync_to_db() cls.execution_dates = [days_ago(2), days_ago(1)] start_date3 = cls.dag3.start_date - cls.dag3_execution_dates = [start_date3, start_date3 + timedelta(days=1), - start_date3 + timedelta(days=2)] + cls.dag3_execution_dates = [ + start_date3, + start_date3 + timedelta(days=1), + start_date3 + timedelta(days=2), + ] def setUp(self): clear_db_runs() - drs = _create_dagruns(self.dag1, self.execution_dates, - state=State.RUNNING, - run_type=DagRunType.SCHEDULED) + drs = _create_dagruns( + self.dag1, self.execution_dates, state=State.RUNNING, run_type=DagRunType.SCHEDULED + ) for dr in drs: dr.dag = self.dag1 dr.verify_integrity() - drs = _create_dagruns(self.dag2, - [self.dag2.start_date], - state=State.RUNNING, - run_type=DagRunType.SCHEDULED) + drs = _create_dagruns( + self.dag2, [self.dag2.start_date], state=State.RUNNING, run_type=DagRunType.SCHEDULED + ) for dr in drs: dr.dag = self.dag2 dr.verify_integrity() - drs = _create_dagruns(self.dag3, - self.dag3_execution_dates, - state=State.SUCCESS, - run_type=DagRunType.MANUAL) + drs = _create_dagruns( + self.dag3, self.dag3_execution_dates, state=State.SUCCESS, run_type=DagRunType.MANUAL + ) for dr in drs: dr.dag = 
self.dag3 dr.verify_integrity() @@ -89,19 +92,17 @@ def tearDown(self): def snapshot_state(dag, execution_dates): TI = models.TaskInstance with create_session() as session: - return session.query(TI).filter( - TI.dag_id == dag.dag_id, - TI.execution_date.in_(execution_dates) - ).all() + return ( + session.query(TI) + .filter(TI.dag_id == dag.dag_id, TI.execution_date.in_(execution_dates)) + .all() + ) @provide_session def verify_state(self, dag, task_ids, execution_dates, state, old_tis, session=None): TI = models.TaskInstance - tis = session.query(TI).filter( - TI.dag_id == dag.dag_id, - TI.execution_date.in_(execution_dates) - ).all() + tis = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.execution_date.in_(execution_dates)).all() self.assertTrue(len(tis) > 0) @@ -120,63 +121,97 @@ def test_mark_tasks_now(self): # set one task to success but do not commit snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates) task = self.dag1.get_task("runme_1") - altered = set_state(tasks=[task], execution_date=self.execution_dates[0], - upstream=False, downstream=False, future=False, - past=False, state=State.SUCCESS, commit=False) + altered = set_state( + tasks=[task], + execution_date=self.execution_dates[0], + upstream=False, + downstream=False, + future=False, + past=False, + state=State.SUCCESS, + commit=False, + ) self.assertEqual(len(altered), 1) - self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], - None, snapshot) + self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], None, snapshot) # set one and only one task to success - altered = set_state(tasks=[task], execution_date=self.execution_dates[0], - upstream=False, downstream=False, future=False, - past=False, state=State.SUCCESS, commit=True) + altered = set_state( + tasks=[task], + execution_date=self.execution_dates[0], + upstream=False, + downstream=False, + future=False, + past=False, + state=State.SUCCESS, + commit=True, + ) self.assertEqual(len(altered), 1) - self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], - State.SUCCESS, snapshot) + self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], State.SUCCESS, snapshot) # set no tasks - altered = set_state(tasks=[task], execution_date=self.execution_dates[0], - upstream=False, downstream=False, future=False, - past=False, state=State.SUCCESS, commit=True) + altered = set_state( + tasks=[task], + execution_date=self.execution_dates[0], + upstream=False, + downstream=False, + future=False, + past=False, + state=State.SUCCESS, + commit=True, + ) self.assertEqual(len(altered), 0) - self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], - State.SUCCESS, snapshot) + self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], State.SUCCESS, snapshot) # set task to other than success - altered = set_state(tasks=[task], execution_date=self.execution_dates[0], - upstream=False, downstream=False, future=False, - past=False, state=State.FAILED, commit=True) + altered = set_state( + tasks=[task], + execution_date=self.execution_dates[0], + upstream=False, + downstream=False, + future=False, + past=False, + state=State.FAILED, + commit=True, + ) self.assertEqual(len(altered), 1) - self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], - State.FAILED, snapshot) + self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], State.FAILED, snapshot) # don't alter other tasks snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates) task = 
self.dag1.get_task("runme_0") - altered = set_state(tasks=[task], execution_date=self.execution_dates[0], - upstream=False, downstream=False, future=False, - past=False, state=State.SUCCESS, commit=True) + altered = set_state( + tasks=[task], + execution_date=self.execution_dates[0], + upstream=False, + downstream=False, + future=False, + past=False, + state=State.SUCCESS, + commit=True, + ) self.assertEqual(len(altered), 1) - self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], - State.SUCCESS, snapshot) + self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], State.SUCCESS, snapshot) # set one task as FAILED. dag3 has schedule_interval None snapshot = TestMarkTasks.snapshot_state(self.dag3, self.dag3_execution_dates) task = self.dag3.get_task("run_this") - altered = set_state(tasks=[task], execution_date=self.dag3_execution_dates[1], - upstream=False, downstream=False, future=False, - past=False, state=State.FAILED, commit=True) + altered = set_state( + tasks=[task], + execution_date=self.dag3_execution_dates[1], + upstream=False, + downstream=False, + future=False, + past=False, + state=State.FAILED, + commit=True, + ) # exactly one TaskInstance should have been altered self.assertEqual(len(altered), 1) # task should have been marked as failed - self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[1]], - State.FAILED, snapshot) + self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[1]], State.FAILED, snapshot) # tasks on other days should be unchanged - self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[0]], - None, snapshot) - self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[2]], - None, snapshot) + self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[0]], None, snapshot) + self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[2]], None, snapshot) def test_mark_downstream(self): # test downstream @@ -186,9 +221,16 @@ def test_mark_downstream(self): task_ids = [t.task_id for t in relatives] task_ids.append(task.task_id) - altered = set_state(tasks=[task], execution_date=self.execution_dates[0], - upstream=False, downstream=True, future=False, - past=False, state=State.SUCCESS, commit=True) + altered = set_state( + tasks=[task], + execution_date=self.execution_dates[0], + upstream=False, + downstream=True, + future=False, + past=False, + state=State.SUCCESS, + commit=True, + ) self.assertEqual(len(altered), 3) self.verify_state(self.dag1, task_ids, [self.execution_dates[0]], State.SUCCESS, snapshot) @@ -200,28 +242,48 @@ def test_mark_upstream(self): task_ids = [t.task_id for t in relatives] task_ids.append(task.task_id) - altered = set_state(tasks=[task], execution_date=self.execution_dates[0], - upstream=True, downstream=False, future=False, - past=False, state=State.SUCCESS, commit=True) + altered = set_state( + tasks=[task], + execution_date=self.execution_dates[0], + upstream=True, + downstream=False, + future=False, + past=False, + state=State.SUCCESS, + commit=True, + ) self.assertEqual(len(altered), 4) - self.verify_state(self.dag1, task_ids, [self.execution_dates[0]], - State.SUCCESS, snapshot) + self.verify_state(self.dag1, task_ids, [self.execution_dates[0]], State.SUCCESS, snapshot) def test_mark_tasks_future(self): # set one task to success towards end of scheduled dag runs snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates) task = self.dag1.get_task("runme_1") - altered = set_state(tasks=[task], 
execution_date=self.execution_dates[0], - upstream=False, downstream=False, future=True, - past=False, state=State.SUCCESS, commit=True) + altered = set_state( + tasks=[task], + execution_date=self.execution_dates[0], + upstream=False, + downstream=False, + future=True, + past=False, + state=State.SUCCESS, + commit=True, + ) self.assertEqual(len(altered), 2) self.verify_state(self.dag1, [task.task_id], self.execution_dates, State.SUCCESS, snapshot) snapshot = TestMarkTasks.snapshot_state(self.dag3, self.dag3_execution_dates) task = self.dag3.get_task("run_this") - altered = set_state(tasks=[task], execution_date=self.dag3_execution_dates[1], - upstream=False, downstream=False, future=True, - past=False, state=State.FAILED, commit=True) + altered = set_state( + tasks=[task], + execution_date=self.dag3_execution_dates[1], + upstream=False, + downstream=False, + future=True, + past=False, + state=State.FAILED, + commit=True, + ) self.assertEqual(len(altered), 2) self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[0]], None, snapshot) self.verify_state(self.dag3, [task.task_id], self.dag3_execution_dates[1:], State.FAILED, snapshot) @@ -230,17 +292,31 @@ def test_mark_tasks_past(self): # set one task to success towards end of scheduled dag runs snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates) task = self.dag1.get_task("runme_1") - altered = set_state(tasks=[task], execution_date=self.execution_dates[1], - upstream=False, downstream=False, future=False, - past=True, state=State.SUCCESS, commit=True) + altered = set_state( + tasks=[task], + execution_date=self.execution_dates[1], + upstream=False, + downstream=False, + future=False, + past=True, + state=State.SUCCESS, + commit=True, + ) self.assertEqual(len(altered), 2) self.verify_state(self.dag1, [task.task_id], self.execution_dates, State.SUCCESS, snapshot) snapshot = TestMarkTasks.snapshot_state(self.dag3, self.dag3_execution_dates) task = self.dag3.get_task("run_this") - altered = set_state(tasks=[task], execution_date=self.dag3_execution_dates[1], - upstream=False, downstream=False, future=False, - past=True, state=State.FAILED, commit=True) + altered = set_state( + tasks=[task], + execution_date=self.dag3_execution_dates[1], + upstream=False, + downstream=False, + future=False, + past=True, + state=State.FAILED, + commit=True, + ) self.assertEqual(len(altered), 2) self.verify_state(self.dag3, [task.task_id], self.dag3_execution_dates[:2], State.FAILED, snapshot) self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[2]], None, snapshot) @@ -249,12 +325,20 @@ def test_mark_tasks_multiple(self): # set multiple tasks to success snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates) tasks = [self.dag1.get_task("runme_1"), self.dag1.get_task("runme_2")] - altered = set_state(tasks=tasks, execution_date=self.execution_dates[0], - upstream=False, downstream=False, future=False, - past=False, state=State.SUCCESS, commit=True) + altered = set_state( + tasks=tasks, + execution_date=self.execution_dates[0], + upstream=False, + downstream=False, + future=False, + past=False, + state=State.SUCCESS, + commit=True, + ) self.assertEqual(len(altered), 2) - self.verify_state(self.dag1, [task.task_id for task in tasks], [self.execution_dates[0]], - State.SUCCESS, snapshot) + self.verify_state( + self.dag1, [task.task_id for task in tasks], [self.execution_dates[0]], State.SUCCESS, snapshot + ) # TODO: this backend should be removed once a fixing solution is found later # We skip 
it here because this test case is working with Postgres & SQLite @@ -267,20 +351,25 @@ def test_mark_tasks_subdag(self): task_ids = [t.task_id for t in relatives] task_ids.append(task.task_id) - altered = set_state(tasks=[task], execution_date=self.execution_dates[0], - upstream=False, downstream=True, future=False, - past=False, state=State.SUCCESS, commit=True) + altered = set_state( + tasks=[task], + execution_date=self.execution_dates[0], + upstream=False, + downstream=True, + future=False, + past=False, + state=State.SUCCESS, + commit=True, + ) self.assertEqual(len(altered), 14) # cannot use snapshot here as that will require drilling down the # sub dag tree essentially recreating the same code as in the # tested logic. - self.verify_state(self.dag2, task_ids, [self.execution_dates[0]], - State.SUCCESS, []) + self.verify_state(self.dag2, task_ids, [self.execution_dates[0]], State.SUCCESS, []) class TestMarkDAGRun(unittest.TestCase): - @classmethod def setUpClass(cls): dagbag = models.DagBag(include_examples=True, read_dags_from_db=False) @@ -318,17 +407,12 @@ def _verify_task_instance_states_remain_default(self, dr): @provide_session def _verify_task_instance_states(self, dag, date, state, session=None): TI = models.TaskInstance - tis = session.query(TI)\ - .filter(TI.dag_id == dag.dag_id, TI.execution_date == date) + tis = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.execution_date == date) for ti in tis: self.assertEqual(ti.state, state) def _create_test_dag_run(self, state, date): - return self.dag1.create_dagrun( - run_type=DagRunType.MANUAL, - state=state, - execution_date=date - ) + return self.dag1.create_dagrun(run_type=DagRunType.MANUAL, state=state, execution_date=date) def _verify_dag_run_state(self, dag, date, state): drs = models.DagRun.find(dag_id=dag.dag_id, execution_date=date) @@ -341,10 +425,7 @@ def _verify_dag_run_dates(self, dag, date, state, middle_time, session=None): # When target state is RUNNING, we should set start_date, # otherwise we should set end_date. 
DR = DagRun - dr = session.query(DR).filter( - DR.dag_id == dag.dag_id, - DR.execution_date == date - ).one() + dr = session.query(DR).filter(DR.dag_id == dag.dag_id, DR.execution_date == date).one() if state == State.RUNNING: # Since the DAG is running, the start_date must be updated after creation self.assertGreater(dr.start_date, middle_time) @@ -515,19 +596,19 @@ def test_set_state_with_multiple_dagruns(self, session=None): run_type=DagRunType.MANUAL, state=State.FAILED, execution_date=self.execution_dates[0], - session=session + session=session, ) self.dag2.create_dagrun( run_type=DagRunType.MANUAL, state=State.FAILED, execution_date=self.execution_dates[1], - session=session + session=session, ) self.dag2.create_dagrun( run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=self.execution_dates[2], - session=session + session=session, ) altered = set_dag_run_state_to_success(self.dag2, self.execution_dates[1], commit=True) @@ -543,11 +624,9 @@ def count_dag_tasks(dag): self._verify_dag_run_state(self.dag2, self.execution_dates[1], State.SUCCESS) # Make sure other dag status are not changed - models.DagRun.find(dag_id=self.dag2.dag_id, - execution_date=self.execution_dates[0]) + models.DagRun.find(dag_id=self.dag2.dag_id, execution_date=self.execution_dates[0]) self._verify_dag_run_state(self.dag2, self.execution_dates[0], State.FAILED) - models.DagRun.find(dag_id=self.dag2.dag_id, - execution_date=self.execution_dates[2]) + models.DagRun.find(dag_id=self.dag2.dag_id, execution_date=self.execution_dates[2]) self._verify_dag_run_state(self.dag2, self.execution_dates[2], State.RUNNING) def test_set_dag_run_state_edge_cases(self): @@ -569,13 +648,13 @@ def test_set_dag_run_state_edge_cases(self): # This will throw ValueError since dag.latest_execution_date # need to be 0 does not exist. 
- self.assertRaises(ValueError, set_dag_run_state_to_success, self.dag2, - timezone.make_naive(self.execution_dates[0])) + self.assertRaises( + ValueError, set_dag_run_state_to_success, self.dag2, timezone.make_naive(self.execution_dates[0]) + ) # altered = set_dag_run_state_to_success(self.dag1, self.execution_dates[0]) # DagRun does not exist # This will throw ValueError since dag.latest_execution_date does not exist - self.assertRaises(ValueError, set_dag_run_state_to_success, - self.dag2, self.execution_dates[0]) + self.assertRaises(ValueError, set_dag_run_state_to_success, self.dag2, self.execution_dates[0]) def test_set_dag_run_state_to_failed_no_running_tasks(self): """ diff --git a/tests/api/common/experimental/test_pool.py b/tests/api/common/experimental/test_pool.py index d564ad2da9317..433823978a39b 100644 --- a/tests/api/common/experimental/test_pool.py +++ b/tests/api/common/experimental/test_pool.py @@ -52,28 +52,21 @@ def test_get_pool(self): self.assertEqual(pool.pool, self.pools[0].pool) def test_get_pool_non_existing(self): - self.assertRaisesRegex(PoolNotFound, - "^Pool 'test' doesn't exist$", - pool_api.get_pool, - name='test') + self.assertRaisesRegex(PoolNotFound, "^Pool 'test' doesn't exist$", pool_api.get_pool, name='test') def test_get_pool_bad_name(self): for name in ('', ' '): - self.assertRaisesRegex(AirflowBadRequest, - "^Pool name shouldn't be empty$", - pool_api.get_pool, - name=name) + self.assertRaisesRegex( + AirflowBadRequest, "^Pool name shouldn't be empty$", pool_api.get_pool, name=name + ) def test_get_pools(self): - pools = sorted(pool_api.get_pools(), - key=lambda p: p.pool) + pools = sorted(pool_api.get_pools(), key=lambda p: p.pool) self.assertEqual(pools[0].pool, self.pools[0].pool) self.assertEqual(pools[1].pool, self.pools[1].pool) def test_create_pool(self): - pool = pool_api.create_pool(name='foo', - slots=5, - description='') + pool = pool_api.create_pool(name='foo', slots=5, description='') self.assertEqual(pool.pool, 'foo') self.assertEqual(pool.slots, 5) self.assertEqual(pool.description, '') @@ -81,9 +74,7 @@ def test_create_pool(self): self.assertEqual(session.query(models.Pool).count(), self.TOTAL_POOL_COUNT + 1) def test_create_pool_existing(self): - pool = pool_api.create_pool(name=self.pools[0].pool, - slots=5, - description='') + pool = pool_api.create_pool(name=self.pools[0].pool, slots=5, description='') self.assertEqual(pool.pool, self.pools[0].pool) self.assertEqual(pool.slots, 5) self.assertEqual(pool.description, '') @@ -92,30 +83,36 @@ def test_create_pool_existing(self): def test_create_pool_bad_name(self): for name in ('', ' '): - self.assertRaisesRegex(AirflowBadRequest, - "^Pool name shouldn't be empty$", - pool_api.create_pool, - name=name, - slots=5, - description='') + self.assertRaisesRegex( + AirflowBadRequest, + "^Pool name shouldn't be empty$", + pool_api.create_pool, + name=name, + slots=5, + description='', + ) def test_create_pool_name_too_long(self): long_name = ''.join(random.choices(string.ascii_lowercase, k=300)) column_length = models.Pool.pool.property.columns[0].type.length - self.assertRaisesRegex(AirflowBadRequest, - "^Pool name can't be more than %d characters$" % column_length, - pool_api.create_pool, - name=long_name, - slots=5, - description='') + self.assertRaisesRegex( + AirflowBadRequest, + "^Pool name can't be more than %d characters$" % column_length, + pool_api.create_pool, + name=long_name, + slots=5, + description='', + ) def test_create_pool_bad_slots(self): - 
self.assertRaisesRegex(AirflowBadRequest, - "^Bad value for `slots`: foo$", - pool_api.create_pool, - name='foo', - slots='foo', - description='') + self.assertRaisesRegex( + AirflowBadRequest, + "^Bad value for `slots`: foo$", + pool_api.create_pool, + name='foo', + slots='foo', + description='', + ) def test_delete_pool(self): pool = pool_api.delete_pool(name=self.pools[-1].pool) @@ -124,19 +121,16 @@ def test_delete_pool(self): self.assertEqual(session.query(models.Pool).count(), self.TOTAL_POOL_COUNT - 1) def test_delete_pool_non_existing(self): - self.assertRaisesRegex(pool_api.PoolNotFound, - "^Pool 'test' doesn't exist$", - pool_api.delete_pool, - name='test') + self.assertRaisesRegex( + pool_api.PoolNotFound, "^Pool 'test' doesn't exist$", pool_api.delete_pool, name='test' + ) def test_delete_pool_bad_name(self): for name in ('', ' '): - self.assertRaisesRegex(AirflowBadRequest, - "^Pool name shouldn't be empty$", - pool_api.delete_pool, - name=name) + self.assertRaisesRegex( + AirflowBadRequest, "^Pool name shouldn't be empty$", pool_api.delete_pool, name=name + ) def test_delete_default_pool_not_allowed(self): - with self.assertRaisesRegex(AirflowBadRequest, - "^default_pool cannot be deleted$"): + with self.assertRaisesRegex(AirflowBadRequest, "^default_pool cannot be deleted$"): pool_api.delete_pool(Pool.DEFAULT_POOL_NAME) diff --git a/tests/api/common/experimental/test_trigger_dag.py b/tests/api/common/experimental/test_trigger_dag.py index 13ea2a32d2246..9fb772d3576f9 100644 --- a/tests/api/common/experimental/test_trigger_dag.py +++ b/tests/api/common/experimental/test_trigger_dag.py @@ -29,7 +29,6 @@ class TestTriggerDag(unittest.TestCase): - def setUp(self) -> None: db.clear_db_runs() @@ -107,11 +106,13 @@ def test_trigger_dag_with_valid_start_date(self, dag_bag_mock): assert len(triggers) == 1 - @parameterized.expand([ - (None, {}), - ({"foo": "bar"}, {"foo": "bar"}), - ('{"foo": "bar"}', {"foo": "bar"}), - ]) + @parameterized.expand( + [ + (None, {}), + ({"foo": "bar"}, {"foo": "bar"}), + ('{"foo": "bar"}', {"foo": "bar"}), + ] + ) @mock.patch('airflow.models.DagBag') def test_trigger_dag_with_conf(self, conf, expected_conf, dag_bag_mock): dag_id = "trigger_dag_with_conf" diff --git a/tests/api_connexion/endpoints/test_config_endpoint.py b/tests/api_connexion/endpoints/test_config_endpoint.py index 9acfaea2cf94f..5a6b14c9bec99 100644 --- a/tests/api_connexion/endpoints/test_config_endpoint.py +++ b/tests/api_connexion/endpoints/test_config_endpoint.py @@ -16,7 +16,6 @@ # under the License. import textwrap - from unittest.mock import patch from airflow.security import permissions @@ -24,7 +23,6 @@ from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars - MOCK_CONF = { 'core': { 'parallelism': '1024', diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index 4bc78683b177b..34c5a388f264c 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -14,22 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-import unittest import datetime as dt import getpass +import unittest from unittest import mock from parameterized import parameterized -from airflow.models import DagBag, DagRun, TaskInstance, SlaMiss +from airflow.models import DagBag, DagRun, SlaMiss, TaskInstance from airflow.security import permissions -from airflow.utils.types import DagRunType from airflow.utils.session import provide_session from airflow.utils.state import State from airflow.utils.timezone import datetime +from airflow.utils.types import DagRunType from airflow.www import app +from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars -from tests.test_utils.api_connexion_utils import create_user, delete_user, assert_401 from tests.test_utils.db import clear_db_runs, clear_db_sla_miss DEFAULT_DATETIME_1 = datetime(2020, 1, 1) diff --git a/tests/api_connexion/endpoints/test_variable_endpoint.py b/tests/api_connexion/endpoints/test_variable_endpoint.py index 9f768709236de..1c6cb657746de 100644 --- a/tests/api_connexion/endpoints/test_variable_endpoint.py +++ b/tests/api_connexion/endpoints/test_variable_endpoint.py @@ -18,9 +18,9 @@ from parameterized import parameterized -from airflow.security import permissions from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import Variable +from airflow.security import permissions from airflow.www import app from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py index be1e0d2bc1ba5..92ef069719f59 100644 --- a/tests/api_connexion/endpoints/test_xcom_endpoint.py +++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py @@ -14,8 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from datetime import timedelta import unittest +from datetime import timedelta from parameterized import parameterized diff --git a/tests/api_connexion/schemas/test_task_instance_schema.py b/tests/api_connexion/schemas/test_task_instance_schema.py index c164aec8cdace..a2d1342e7228a 100644 --- a/tests/api_connexion/schemas/test_task_instance_schema.py +++ b/tests/api_connexion/schemas/test_task_instance_schema.py @@ -24,13 +24,13 @@ from airflow.api_connexion.schemas.task_instance_schema import ( clear_task_instance_form, - task_instance_schema, set_task_instance_state_form, + task_instance_schema, ) from airflow.models import DAG, SlaMiss, TaskInstance as TI from airflow.operators.dummy_operator import DummyOperator -from airflow.utils.state import State from airflow.utils.session import create_session, provide_session +from airflow.utils.state import State from airflow.utils.timezone import datetime diff --git a/tests/build_provider_packages_dependencies.py b/tests/build_provider_packages_dependencies.py index 9542759a73c92..3832bbc0e3000 100644 --- a/tests/build_provider_packages_dependencies.py +++ b/tests/build_provider_packages_dependencies.py @@ -67,8 +67,10 @@ def get_provider_from_file_name(file_name: str) -> Optional[str]: :param file_name: name of the file :return: provider name or None if no provider could be found """ - if AIRFLOW_PROVIDERS_FILE_PREFIX not in file_name and \ - AIRFLOW_TESTS_PROVIDERS_FILE_PREFIX not in file_name: + if ( + AIRFLOW_PROVIDERS_FILE_PREFIX not in file_name + and AIRFLOW_TESTS_PROVIDERS_FILE_PREFIX not in file_name + ): # We should only check file that are provider errors.append(f"Wrong file not in the providers package = {file_name}") return None @@ -84,9 +86,9 @@ def get_provider_from_file_name(file_name: str) -> Optional[str]: def get_file_suffix(file_name): if AIRFLOW_PROVIDERS_FILE_PREFIX in file_name: - return file_name[file_name.find(AIRFLOW_PROVIDERS_FILE_PREFIX):] + return file_name[file_name.find(AIRFLOW_PROVIDERS_FILE_PREFIX) :] if AIRFLOW_TESTS_PROVIDERS_FILE_PREFIX in file_name: - return file_name[file_name.find(AIRFLOW_TESTS_PROVIDERS_FILE_PREFIX):] + return file_name[file_name.find(AIRFLOW_TESTS_PROVIDERS_FILE_PREFIX) :] return None @@ -99,7 +101,7 @@ def get_provider_from_import(import_name: str) -> Optional[str]: if AIRFLOW_PROVIDERS_IMPORT_PREFIX not in import_name: # skip silently - we expect non-providers imports return None - suffix = import_name[import_name.find(AIRFLOW_PROVIDERS_IMPORT_PREFIX):] + suffix = import_name[import_name.find(AIRFLOW_PROVIDERS_IMPORT_PREFIX) :] split_import = suffix.split(".")[2:] provider = find_provider(split_import) if not provider: @@ -111,6 +113,7 @@ class ImportFinder(NodeVisitor): """ AST visitor that collects all imported names in its imports """ + def __init__(self, filename): self.imports: List[str] = [] self.filename = filename @@ -174,12 +177,16 @@ def check_if_different_provider_used(file_name: str): def parse_arguments(): import argparse + parser = argparse.ArgumentParser( - description='Checks if dependencies between packages are handled correctly.') - parser.add_argument("-f", "--provider-dependencies-file", - help="Stores dependencies between providers in the file") - parser.add_argument("-d", "--documentation-file", - help="Updates package documentation in the file specified (.rst)") + description='Checks if dependencies between packages are handled correctly.' 
+ ) + parser.add_argument( + "-f", "--provider-dependencies-file", help="Stores dependencies between providers in the file" + ) + parser.add_argument( + "-d", "--documentation-file", help="Updates package documentation in the file specified (.rst)" + ) parser.add_argument('files', nargs='*') args = parser.parse_args() diff --git a/tests/cli/commands/test_celery_command.py b/tests/cli/commands/test_celery_command.py index f7187cef3a6d6..5bd2855b8b53a 100644 --- a/tests/cli/commands/test_celery_command.py +++ b/tests/cli/commands/test_celery_command.py @@ -64,7 +64,6 @@ def test_validate_session_dbapi_exception(self, mock_session): @pytest.mark.integration("rabbitmq") @pytest.mark.backend("mysql", "postgres") class TestWorkerServeLogs(unittest.TestCase): - @classmethod def setUpClass(cls): cls.parser = cli_parser.get_parser() @@ -123,10 +122,7 @@ def test_if_right_pid_is_read(self, mock_process, mock_setup_locations): @mock.patch("airflow.cli.commands.celery_command.setup_locations") @conf_vars({("core", "executor"): "CeleryExecutor"}) def test_same_pid_file_is_used_in_start_and_stop( - self, - mock_setup_locations, - mock_celery_worker, - mock_read_pid_from_pidfile + self, mock_setup_locations, mock_celery_worker, mock_read_pid_from_pidfile ): pid_file = "test_pid_file" mock_setup_locations.return_value = (pid_file, None, None, None) @@ -164,18 +160,20 @@ def test_worker_started_with_required_arguments(self, mock_worker, mock_popen, m celery_hostname = "celery_hostname" queues = "queue" autoscale = "2,5" - args = self.parser.parse_args([ - 'celery', - 'worker', - '--autoscale', - autoscale, - '--concurrency', - concurrency, - '--celery-hostname', - celery_hostname, - '--queues', - queues - ]) + args = self.parser.parse_args( + [ + 'celery', + 'worker', + '--autoscale', + autoscale, + '--concurrency', + concurrency, + '--celery-hostname', + celery_hostname, + '--queues', + queues, + ] + ) with mock.patch('celery.platforms.check_privileges') as mock_privil: mock_privil.return_value = 0 diff --git a/tests/cli/commands/test_cheat_sheet_command.py b/tests/cli/commands/test_cheat_sheet_command.py index c07259b4f9406..50239d6f95fc9 100644 --- a/tests/cli/commands/test_cheat_sheet_command.py +++ b/tests/cli/commands/test_cheat_sheet_command.py @@ -71,7 +71,7 @@ def noop(): help='Help text D', func=noop, args=(), - ) + ), ] EXPECTED_OUTPUT = """\ @@ -92,7 +92,6 @@ def noop(): class TestCheatSheetCommand(unittest.TestCase): - @classmethod def setUpClass(cls): cls.parser = cli_parser.get_parser() diff --git a/tests/cli/commands/test_config_command.py b/tests/cli/commands/test_config_command.py index c7c8925b9e1fe..56686545fe7da 100644 --- a/tests/cli/commands/test_config_command.py +++ b/tests/cli/commands/test_config_command.py @@ -35,9 +35,7 @@ def test_cli_show_config_should_write_data(self, mock_conf, mock_stringio): config_command.show_config(self.parser.parse_args(['config', 'list', '--color', 'off'])) mock_conf.write.assert_called_once_with(mock_stringio.return_value.__enter__.return_value) - @conf_vars({ - ('core', 'testkey'): 'test_value' - }) + @conf_vars({('core', 'testkey'): 'test_value'}) def test_cli_show_config_should_display_key(self): with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: config_command.show_config(self.parser.parse_args(['config', 'list', '--color', 'off'])) @@ -50,9 +48,7 @@ class TestCliConfigGetValue(unittest.TestCase): def setUpClass(cls): cls.parser = cli_parser.get_parser() - @conf_vars({ - ('core', 'test_key'): 'test_value' - }) + @conf_vars({('core', 
'test_key'): 'test_value'}) def test_should_display_value(self): with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: config_command.get_value(self.parser.parse_args(['config', 'get-value', 'core', 'test_key'])) @@ -65,9 +61,9 @@ def test_should_raise_exception_when_section_is_missing(self, mock_conf): mock_conf.has_option.return_value = True with contextlib.redirect_stderr(io.StringIO()) as temp_stderr, self.assertRaises(SystemExit) as cm: - config_command.get_value(self.parser.parse_args( - ['config', 'get-value', 'missing-section', 'dags_folder'] - )) + config_command.get_value( + self.parser.parse_args(['config', 'get-value', 'missing-section', 'dags_folder']) + ) self.assertEqual(1, cm.exception.code) self.assertEqual( "The section [missing-section] is not found in config.", temp_stderr.getvalue().strip() @@ -79,9 +75,9 @@ def test_should_raise_exception_when_option_is_missing(self, mock_conf): mock_conf.has_option.return_value = False with contextlib.redirect_stderr(io.StringIO()) as temp_stderr, self.assertRaises(SystemExit) as cm: - config_command.get_value(self.parser.parse_args( - ['config', 'get-value', 'missing-section', 'dags_folder'] - )) + config_command.get_value( + self.parser.parse_args(['config', 'get-value', 'missing-section', 'dags_folder']) + ) self.assertEqual(1, cm.exception.code) self.assertEqual( "The option [missing-section/dags_folder] is not found in config.", temp_stderr.getvalue().strip() diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py index 9065f510302ce..d7112268ab623 100644 --- a/tests/cli/commands/test_connection_command.py +++ b/tests/cli/commands/test_connection_command.py @@ -42,36 +42,67 @@ def tearDown(self): def test_cli_connection_get(self): with redirect_stdout(io.StringIO()) as stdout: - connection_command.connections_get(self.parser.parse_args( - ["connections", "get", "google_cloud_default"] - )) + connection_command.connections_get( + self.parser.parse_args(["connections", "get", "google_cloud_default"]) + ) stdout = stdout.getvalue() - self.assertIn( - "URI: google-cloud-platform:///default", - stdout - ) + self.assertIn("URI: google-cloud-platform:///default", stdout) def test_cli_connection_get_invalid(self): with self.assertRaisesRegex(SystemExit, re.escape("Connection not found.")): - connection_command.connections_get(self.parser.parse_args( - ["connections", "get", "INVALID"] - )) + connection_command.connections_get(self.parser.parse_args(["connections", "get", "INVALID"])) class TestCliListConnections(unittest.TestCase): EXPECTED_CONS = [ - ('airflow_db', 'mysql', ), - ('google_cloud_default', 'google_cloud_platform', ), - ('http_default', 'http', ), - ('local_mysql', 'mysql', ), - ('mongo_default', 'mongo', ), - ('mssql_default', 'mssql', ), - ('mysql_default', 'mysql', ), - ('pinot_broker_default', 'pinot', ), - ('postgres_default', 'postgres', ), - ('presto_default', 'presto', ), - ('sqlite_default', 'sqlite', ), - ('vertica_default', 'vertica', ), + ( + 'airflow_db', + 'mysql', + ), + ( + 'google_cloud_default', + 'google_cloud_platform', + ), + ( + 'http_default', + 'http', + ), + ( + 'local_mysql', + 'mysql', + ), + ( + 'mongo_default', + 'mongo', + ), + ( + 'mssql_default', + 'mssql', + ), + ( + 'mysql_default', + 'mysql', + ), + ( + 'pinot_broker_default', + 'pinot', + ), + ( + 'postgres_default', + 'postgres', + ), + ( + 'presto_default', + 'presto', + ), + ( + 'sqlite_default', + 'sqlite', + ), + ( + 'vertica_default', + 'vertica', + ), ] def setUp(self): 
@@ -126,7 +157,7 @@ def setUp(self, session=None): password="plainpassword", schema="airflow", ), - session + session, ) merge_conn( Connection( @@ -136,7 +167,7 @@ def setUp(self, session=None): port=8082, extra='{"endpoint": "druid/v2/sql"}', ), - session + session, ) self.parser = cli_parser.get_parser() @@ -146,34 +177,32 @@ def tearDown(self): def test_cli_connections_export_should_return_error_for_invalid_command(self): with self.assertRaises(SystemExit): - self.parser.parse_args([ - "connections", - "export", - ]) + self.parser.parse_args( + [ + "connections", + "export", + ] + ) def test_cli_connections_export_should_return_error_for_invalid_format(self): with self.assertRaises(SystemExit): - self.parser.parse_args([ - "connections", - "export", - "--format", - "invalid", - "/path/to/file" - ]) + self.parser.parse_args(["connections", "export", "--format", "invalid", "/path/to/file"]) @mock.patch('os.path.splitext') @mock.patch('builtins.open', new_callable=mock.mock_open()) - def test_cli_connections_export_should_return_error_for_invalid_export_format(self, - mock_file_open, - mock_splittext): + def test_cli_connections_export_should_return_error_for_invalid_export_format( + self, mock_file_open, mock_splittext + ): output_filepath = '/tmp/connections.invalid' mock_splittext.return_value = (None, '.invalid') - args = self.parser.parse_args([ - "connections", - "export", - output_filepath, - ]) + args = self.parser.parse_args( + [ + "connections", + "export", + output_filepath, + ] + ) with self.assertRaisesRegex( SystemExit, r"Unsupported file format. The file must have the extension .yaml, .json, .env" ): @@ -186,21 +215,24 @@ def test_cli_connections_export_should_return_error_for_invalid_export_format(se @mock.patch('os.path.splitext') @mock.patch('builtins.open', new_callable=mock.mock_open()) @mock.patch.object(connection_command, 'create_session') - def test_cli_connections_export_should_return_error_if_create_session_fails(self, mock_session, - mock_file_open, - mock_splittext): + def test_cli_connections_export_should_return_error_if_create_session_fails( + self, mock_session, mock_file_open, mock_splittext + ): output_filepath = '/tmp/connections.json' def my_side_effect(): raise Exception("dummy exception") + mock_session.side_effect = my_side_effect mock_splittext.return_value = (None, '.json') - args = self.parser.parse_args([ - "connections", - "export", - output_filepath, - ]) + args = self.parser.parse_args( + [ + "connections", + "export", + output_filepath, + ] + ) with self.assertRaisesRegex(Exception, r"dummy exception"): connection_command.connections_export(args) @@ -211,22 +243,26 @@ def my_side_effect(): @mock.patch('os.path.splitext') @mock.patch('builtins.open', new_callable=mock.mock_open()) @mock.patch.object(connection_command, 'create_session') - def test_cli_connections_export_should_return_error_if_fetching_connections_fails(self, mock_session, - mock_file_open, - mock_splittext): + def test_cli_connections_export_should_return_error_if_fetching_connections_fails( + self, mock_session, mock_file_open, mock_splittext + ): output_filepath = '/tmp/connections.json' def my_side_effect(_): raise Exception("dummy exception") - mock_session.return_value.__enter__.return_value.query.return_value.order_by.side_effect = \ + + mock_session.return_value.__enter__.return_value.query.return_value.order_by.side_effect = ( my_side_effect + ) mock_splittext.return_value = (None, '.json') - args = self.parser.parse_args([ - "connections", - "export", - 
output_filepath, - ]) + args = self.parser.parse_args( + [ + "connections", + "export", + output_filepath, + ] + ) with self.assertRaisesRegex(Exception, r"dummy exception"): connection_command.connections_export(args) @@ -237,19 +273,21 @@ def my_side_effect(_): @mock.patch('os.path.splitext') @mock.patch('builtins.open', new_callable=mock.mock_open()) @mock.patch.object(connection_command, 'create_session') - def test_cli_connections_export_should_not_return_error_if_connections_is_empty(self, mock_session, - mock_file_open, - mock_splittext): + def test_cli_connections_export_should_not_return_error_if_connections_is_empty( + self, mock_session, mock_file_open, mock_splittext + ): output_filepath = '/tmp/connections.json' mock_session.return_value.__enter__.return_value.query.return_value.all.return_value = [] mock_splittext.return_value = (None, '.json') - args = self.parser.parse_args([ - "connections", - "export", - output_filepath, - ]) + args = self.parser.parse_args( + [ + "connections", + "export", + output_filepath, + ] + ) connection_command.connections_export(args) mock_splittext.assert_called_once() @@ -262,33 +300,38 @@ def test_cli_connections_export_should_export_as_json(self, mock_file_open, mock output_filepath = '/tmp/connections.json' mock_splittext.return_value = (None, '.json') - args = self.parser.parse_args([ - "connections", - "export", - output_filepath, - ]) + args = self.parser.parse_args( + [ + "connections", + "export", + output_filepath, + ] + ) connection_command.connections_export(args) - expected_connections = json.dumps({ - "airflow_db": { - "conn_type": "mysql", - "host": "mysql", - "login": "root", - "password": "plainpassword", - "schema": "airflow", - "port": None, - "extra": None, + expected_connections = json.dumps( + { + "airflow_db": { + "conn_type": "mysql", + "host": "mysql", + "login": "root", + "password": "plainpassword", + "schema": "airflow", + "port": None, + "extra": None, + }, + "druid_broker_default": { + "conn_type": "druid", + "host": "druid-broker", + "login": None, + "password": None, + "schema": None, + "port": 8082, + "extra": "{\"endpoint\": \"druid/v2/sql\"}", + }, }, - "druid_broker_default": { - "conn_type": "druid", - "host": "druid-broker", - "login": None, - "password": None, - "schema": None, - "port": 8082, - "extra": "{\"endpoint\": \"druid/v2/sql\"}", - } - }, indent=2) + indent=2, + ) mock_splittext.assert_called_once() mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None) @@ -300,29 +343,33 @@ def test_cli_connections_export_should_export_as_yaml(self, mock_file_open, mock output_filepath = '/tmp/connections.yaml' mock_splittext.return_value = (None, '.yaml') - args = self.parser.parse_args([ - "connections", - "export", - output_filepath, - ]) + args = self.parser.parse_args( + [ + "connections", + "export", + output_filepath, + ] + ) connection_command.connections_export(args) - expected_connections = ("airflow_db:\n" - " conn_type: mysql\n" - " extra: null\n" - " host: mysql\n" - " login: root\n" - " password: plainpassword\n" - " port: null\n" - " schema: airflow\n" - "druid_broker_default:\n" - " conn_type: druid\n" - " extra: \'{\"endpoint\": \"druid/v2/sql\"}\'\n" - " host: druid-broker\n" - " login: null\n" - " password: null\n" - " port: 8082\n" - " schema: null\n") + expected_connections = ( + "airflow_db:\n" + " conn_type: mysql\n" + " extra: null\n" + " host: mysql\n" + " login: root\n" + " password: plainpassword\n" + " port: null\n" + " schema: airflow\n" + 
"druid_broker_default:\n" + " conn_type: druid\n" + " extra: \'{\"endpoint\": \"druid/v2/sql\"}\'\n" + " host: druid-broker\n" + " login: null\n" + " password: null\n" + " port: 8082\n" + " schema: null\n" + ) mock_splittext.assert_called_once() mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None) mock_file_open.return_value.write.assert_called_once_with(expected_connections) @@ -333,19 +380,20 @@ def test_cli_connections_export_should_export_as_env(self, mock_file_open, mock_ output_filepath = '/tmp/connections.env' mock_splittext.return_value = (None, '.env') - args = self.parser.parse_args([ - "connections", - "export", - output_filepath, - ]) + args = self.parser.parse_args( + [ + "connections", + "export", + output_filepath, + ] + ) connection_command.connections_export(args) expected_connections = [ "airflow_db=mysql://root:plainpassword@mysql/airflow\n" "druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql\n", - "druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql\n" - "airflow_db=mysql://root:plainpassword@mysql/airflow\n" + "airflow_db=mysql://root:plainpassword@mysql/airflow\n", ] mock_splittext.assert_called_once() @@ -355,24 +403,26 @@ def test_cli_connections_export_should_export_as_env(self, mock_file_open, mock_ @mock.patch('os.path.splitext') @mock.patch('builtins.open', new_callable=mock.mock_open()) - def test_cli_connections_export_should_export_as_env_for_uppercase_file_extension(self, mock_file_open, - mock_splittext): + def test_cli_connections_export_should_export_as_env_for_uppercase_file_extension( + self, mock_file_open, mock_splittext + ): output_filepath = '/tmp/connections.ENV' mock_splittext.return_value = (None, '.ENV') - args = self.parser.parse_args([ - "connections", - "export", - output_filepath, - ]) + args = self.parser.parse_args( + [ + "connections", + "export", + output_filepath, + ] + ) connection_command.connections_export(args) expected_connections = [ "airflow_db=mysql://root:plainpassword@mysql/airflow\n" "druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql\n", - "druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql\n" - "airflow_db=mysql://root:plainpassword@mysql/airflow\n" + "airflow_db=mysql://root:plainpassword@mysql/airflow\n", ] mock_splittext.assert_called_once() @@ -382,39 +432,45 @@ def test_cli_connections_export_should_export_as_env_for_uppercase_file_extensio @mock.patch('os.path.splitext') @mock.patch('builtins.open', new_callable=mock.mock_open()) - def test_cli_connections_export_should_force_export_as_specified_format(self, mock_file_open, - mock_splittext): + def test_cli_connections_export_should_force_export_as_specified_format( + self, mock_file_open, mock_splittext + ): output_filepath = '/tmp/connections.yaml' - args = self.parser.parse_args([ - "connections", - "export", - output_filepath, - "--format", - "json", - ]) + args = self.parser.parse_args( + [ + "connections", + "export", + output_filepath, + "--format", + "json", + ] + ) connection_command.connections_export(args) - expected_connections = json.dumps({ - "airflow_db": { - "conn_type": "mysql", - "host": "mysql", - "login": "root", - "password": "plainpassword", - "schema": "airflow", - "port": None, - "extra": None, + expected_connections = json.dumps( + { + "airflow_db": { + "conn_type": "mysql", + "host": "mysql", + "login": "root", + "password": "plainpassword", + "schema": "airflow", + "port": None, + "extra": None, + }, + "druid_broker_default": { + 
"conn_type": "druid", + "host": "druid-broker", + "login": None, + "password": None, + "schema": None, + "port": 8082, + "extra": "{\"endpoint\": \"druid/v2/sql\"}", + }, }, - "druid_broker_default": { - "conn_type": "druid", - "host": "druid-broker", - "login": None, - "password": None, - "schema": None, - "port": 8082, - "extra": "{\"endpoint\": \"druid/v2/sql\"}", - } - }, indent=2) + indent=2, + ) mock_splittext.assert_not_called() mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None) mock_file_open.return_value.write.assert_called_once_with(expected_connections) @@ -607,7 +663,7 @@ def test_cli_delete_connections(self, session=None): Connection( conn_id="new1", conn_type="mysql", host="mysql", login="root", password="", schema="airflow" ), - session=session + session=session, ) # Delete connections with redirect_stdout(io.StringIO()) as stdout: diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py index 87b58320389a7..331bbf811292a 100644 --- a/tests/cli/commands/test_dag_command.py +++ b/tests/cli/commands/test_dag_command.py @@ -38,23 +38,16 @@ dag_folder_path = '/'.join(os.path.realpath(__file__).split('/')[:-1]) DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1)) -TEST_DAG_FOLDER = os.path.join( - os.path.dirname(dag_folder_path), 'dags') +TEST_DAG_FOLDER = os.path.join(os.path.dirname(dag_folder_path), 'dags') TEST_DAG_ID = 'unit_tests' EXAMPLE_DAGS_FOLDER = os.path.join( - os.path.dirname( - os.path.dirname( - os.path.dirname(os.path.realpath(__file__)) - ) - ), - "airflow/example_dags" + os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), "airflow/example_dags" ) class TestCliDags(unittest.TestCase): - @classmethod def setUpClass(cls): cls.dagbag = DagBag(include_examples=True) @@ -68,9 +61,11 @@ def tearDownClass(cls) -> None: @mock.patch("airflow.cli.commands.dag_command.DAG.run") def test_backfill(self, mock_run): - dag_command.dag_backfill(self.parser.parse_args([ - 'dags', 'backfill', 'example_bash_operator', - '--start-date', DEFAULT_DATE.isoformat()])) + dag_command.dag_backfill( + self.parser.parse_args( + ['dags', 'backfill', 'example_bash_operator', '--start-date', DEFAULT_DATE.isoformat()] + ) + ) mock_run.assert_called_once_with( start_date=DEFAULT_DATE, @@ -91,9 +86,21 @@ def test_backfill(self, mock_run): dag = self.dagbag.get_dag('example_bash_operator') with contextlib.redirect_stdout(io.StringIO()) as stdout: - dag_command.dag_backfill(self.parser.parse_args([ - 'dags', 'backfill', 'example_bash_operator', '--task-regex', 'runme_0', '--dry-run', - '--start-date', DEFAULT_DATE.isoformat()]), dag=dag) + dag_command.dag_backfill( + self.parser.parse_args( + [ + 'dags', + 'backfill', + 'example_bash_operator', + '--task-regex', + 'runme_0', + '--dry-run', + '--start-date', + DEFAULT_DATE.isoformat(), + ] + ), + dag=dag, + ) output = stdout.getvalue() self.assertIn(f"Dry run of DAG example_bash_operator on {DEFAULT_DATE.isoformat()}\n", output) @@ -101,15 +108,35 @@ def test_backfill(self, mock_run): mock_run.assert_not_called() # Dry run shouldn't run the backfill - dag_command.dag_backfill(self.parser.parse_args([ - 'dags', 'backfill', 'example_bash_operator', '--dry-run', - '--start-date', DEFAULT_DATE.isoformat()]), dag=dag) + dag_command.dag_backfill( + self.parser.parse_args( + [ + 'dags', + 'backfill', + 'example_bash_operator', + '--dry-run', + '--start-date', + DEFAULT_DATE.isoformat(), + ] + ), + dag=dag, + ) mock_run.assert_not_called() # Dry run shouldn't run 
the backfill - dag_command.dag_backfill(self.parser.parse_args([ - 'dags', 'backfill', 'example_bash_operator', '--local', - '--start-date', DEFAULT_DATE.isoformat()]), dag=dag) + dag_command.dag_backfill( + self.parser.parse_args( + [ + 'dags', + 'backfill', + 'example_bash_operator', + '--local', + '--start-date', + DEFAULT_DATE.isoformat(), + ] + ), + dag=dag, + ) mock_run.assert_called_once_with( start_date=DEFAULT_DATE, @@ -130,8 +157,7 @@ def test_backfill(self, mock_run): def test_show_dag_print(self): with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: - dag_command.dag_show(self.parser.parse_args([ - 'dags', 'show', 'example_bash_operator'])) + dag_command.dag_show(self.parser.parse_args(['dags', 'show', 'example_bash_operator'])) out = temp_stdout.getvalue() self.assertIn("label=example_bash_operator", out) self.assertIn("graph [label=example_bash_operator labelloc=t rankdir=LR]", out) @@ -140,9 +166,9 @@ def test_show_dag_print(self): @mock.patch("airflow.cli.commands.dag_command.render_dag") def test_show_dag_dave(self, mock_render_dag): with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: - dag_command.dag_show(self.parser.parse_args([ - 'dags', 'show', 'example_bash_operator', '--save', 'awesome.png'] - )) + dag_command.dag_show( + self.parser.parse_args(['dags', 'show', 'example_bash_operator', '--save', 'awesome.png']) + ) out = temp_stdout.getvalue() mock_render_dag.return_value.render.assert_called_once_with( cleanup=True, filename='awesome', format='png' @@ -155,9 +181,9 @@ def test_show_dag_imgcat(self, mock_render_dag, mock_popen): mock_render_dag.return_value.pipe.return_value = b"DOT_DATA" mock_popen.return_value.communicate.return_value = (b"OUT", b"ERR") with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: - dag_command.dag_show(self.parser.parse_args([ - 'dags', 'show', 'example_bash_operator', '--imgcat'] - )) + dag_command.dag_show( + self.parser.parse_args(['dags', 'show', 'example_bash_operator', '--imgcat']) + ) out = temp_stdout.getvalue() mock_render_dag.return_value.pipe.assert_called_once_with(format='png') mock_popen.return_value.communicate.assert_called_once_with(b'DOT_DATA') @@ -243,10 +269,12 @@ def test_cli_backfill_depends_on_past_backwards(self, mock_run): ) def test_next_execution(self): - dag_ids = ['example_bash_operator', # schedule_interval is '0 0 * * *' - 'latest_only', # schedule_interval is timedelta(hours=4) - 'example_python_operator', # schedule_interval=None - 'example_xcom'] # schedule_interval="@once" + dag_ids = [ + 'example_bash_operator', # schedule_interval is '0 0 * * *' + 'latest_only', # schedule_interval is timedelta(hours=4) + 'example_python_operator', # schedule_interval=None + 'example_xcom', + ] # schedule_interval="@once" # Delete DagRuns with create_session() as session: @@ -254,9 +282,7 @@ def test_next_execution(self): dr.delete(synchronize_session=False) # Test None output - args = self.parser.parse_args(['dags', - 'next-execution', - dag_ids[0]]) + args = self.parser.parse_args(['dags', 'next-execution', dag_ids[0]]) with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: dag_command.dag_next_execution(args) out = temp_stdout.getvalue() @@ -266,40 +292,30 @@ def test_next_execution(self): # The details below is determined by the schedule_interval of example DAGs now = DEFAULT_DATE - expected_output = [str(now + timedelta(days=1)), - str(now + timedelta(hours=4)), - "None", - "None"] - expected_output_2 = [str(now + timedelta(days=1)) + os.linesep + str(now + 
timedelta(days=2)), - str(now + timedelta(hours=4)) + os.linesep + str(now + timedelta(hours=8)), - "None", - "None"] + expected_output = [str(now + timedelta(days=1)), str(now + timedelta(hours=4)), "None", "None"] + expected_output_2 = [ + str(now + timedelta(days=1)) + os.linesep + str(now + timedelta(days=2)), + str(now + timedelta(hours=4)) + os.linesep + str(now + timedelta(hours=8)), + "None", + "None", + ] for i, dag_id in enumerate(dag_ids): dag = self.dagbag.dags[dag_id] # Create a DagRun for each DAG, to prepare for next step dag.create_dagrun( - run_type=DagRunType.MANUAL, - execution_date=now, - start_date=now, - state=State.FAILED + run_type=DagRunType.MANUAL, execution_date=now, start_date=now, state=State.FAILED ) # Test num-executions = 1 (default) - args = self.parser.parse_args(['dags', - 'next-execution', - dag_id]) + args = self.parser.parse_args(['dags', 'next-execution', dag_id]) with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: dag_command.dag_next_execution(args) out = temp_stdout.getvalue() self.assertIn(expected_output[i], out) # Test num-executions = 2 - args = self.parser.parse_args(['dags', - 'next-execution', - dag_id, - '--num-executions', - '2']) + args = self.parser.parse_args(['dags', 'next-execution', dag_id, '--num-executions', '2']) with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: dag_command.dag_next_execution(args) out = temp_stdout.getvalue() @@ -310,9 +326,7 @@ def test_next_execution(self): dr = session.query(DagRun).filter(DagRun.dag_id.in_(dag_ids)) dr.delete(synchronize_session=False) - @conf_vars({ - ('core', 'load_examples'): 'true' - }) + @conf_vars({('core', 'load_examples'): 'true'}) def test_cli_report(self): args = self.parser.parse_args(['dags', 'report']) with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: @@ -322,9 +336,7 @@ def test_cli_report(self): self.assertIn("airflow/example_dags/example_complex.py ", out) self.assertIn("['example_complex']", out) - @conf_vars({ - ('core', 'load_examples'): 'true' - }) + @conf_vars({('core', 'load_examples'): 'true'}) def test_cli_list_dags(self): args = self.parser.parse_args(['dags', 'list', '--output=fancy_grid']) with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: @@ -335,49 +347,74 @@ def test_cli_list_dags(self): self.assertIn("airflow/example_dags/example_complex.py", out) def test_cli_list_dag_runs(self): - dag_command.dag_trigger(self.parser.parse_args([ - 'dags', 'trigger', 'example_bash_operator', ])) - args = self.parser.parse_args(['dags', - 'list-runs', - '--dag-id', - 'example_bash_operator', - '--no-backfill', - '--start-date', - DEFAULT_DATE.isoformat(), - '--end-date', - timezone.make_aware(datetime.max).isoformat()]) + dag_command.dag_trigger( + self.parser.parse_args( + [ + 'dags', + 'trigger', + 'example_bash_operator', + ] + ) + ) + args = self.parser.parse_args( + [ + 'dags', + 'list-runs', + '--dag-id', + 'example_bash_operator', + '--no-backfill', + '--start-date', + DEFAULT_DATE.isoformat(), + '--end-date', + timezone.make_aware(datetime.max).isoformat(), + ] + ) dag_command.dag_list_dag_runs(args) def test_cli_list_jobs_with_args(self): - args = self.parser.parse_args(['dags', 'list-jobs', '--dag-id', - 'example_bash_operator', - '--state', 'success', - '--limit', '100', - '--output', 'tsv']) + args = self.parser.parse_args( + [ + 'dags', + 'list-jobs', + '--dag-id', + 'example_bash_operator', + '--state', + 'success', + '--limit', + '100', + '--output', + 'tsv', + ] + ) dag_command.dag_list_jobs(args) def 
test_pause(self): - args = self.parser.parse_args([ - 'dags', 'pause', 'example_bash_operator']) + args = self.parser.parse_args(['dags', 'pause', 'example_bash_operator']) dag_command.dag_pause(args) self.assertIn(self.dagbag.dags['example_bash_operator'].get_is_paused(), [True, 1]) - args = self.parser.parse_args([ - 'dags', 'unpause', 'example_bash_operator']) + args = self.parser.parse_args(['dags', 'unpause', 'example_bash_operator']) dag_command.dag_unpause(args) self.assertIn(self.dagbag.dags['example_bash_operator'].get_is_paused(), [False, 0]) def test_trigger_dag(self): - dag_command.dag_trigger(self.parser.parse_args([ - 'dags', 'trigger', 'example_bash_operator', - '--conf', '{"foo": "bar"}'])) + dag_command.dag_trigger( + self.parser.parse_args(['dags', 'trigger', 'example_bash_operator', '--conf', '{"foo": "bar"}']) + ) self.assertRaises( ValueError, dag_command.dag_trigger, - self.parser.parse_args([ - 'dags', 'trigger', 'example_bash_operator', - '--run-id', 'trigger_dag_xxx', - '--conf', 'NOT JSON']) + self.parser.parse_args( + [ + 'dags', + 'trigger', + 'example_bash_operator', + '--run-id', + 'trigger_dag_xxx', + '--conf', + 'NOT JSON', + ] + ), ) def test_delete_dag(self): @@ -386,16 +423,12 @@ def test_delete_dag(self): session = settings.Session() session.add(DM(dag_id=key)) session.commit() - dag_command.dag_delete(self.parser.parse_args([ - 'dags', 'delete', key, '--yes'])) + dag_command.dag_delete(self.parser.parse_args(['dags', 'delete', key, '--yes'])) self.assertEqual(session.query(DM).filter_by(dag_id=key).count(), 0) self.assertRaises( AirflowException, dag_command.dag_delete, - self.parser.parse_args([ - 'dags', 'delete', - 'does_not_exist_dag', - '--yes']) + self.parser.parse_args(['dags', 'delete', 'does_not_exist_dag', '--yes']), ) def test_delete_dag_existing_file(self): @@ -407,8 +440,7 @@ def test_delete_dag_existing_file(self): with tempfile.NamedTemporaryFile() as f: session.add(DM(dag_id=key, fileloc=f.name)) session.commit() - dag_command.dag_delete(self.parser.parse_args([ - 'dags', 'delete', key, '--yes'])) + dag_command.dag_delete(self.parser.parse_args(['dags', 'delete', key, '--yes'])) self.assertEqual(session.query(DM).filter_by(dag_id=key).count(), 0) def test_cli_list_jobs(self): @@ -416,8 +448,12 @@ def test_cli_list_jobs(self): dag_command.dag_list_jobs(args) def test_dag_state(self): - self.assertEqual(None, dag_command.dag_state(self.parser.parse_args([ - 'dags', 'state', 'example_bash_operator', DEFAULT_DATE.isoformat()]))) + self.assertEqual( + None, + dag_command.dag_state( + self.parser.parse_args(['dags', 'state', 'example_bash_operator', DEFAULT_DATE.isoformat()]) + ), + ) @mock.patch("airflow.cli.commands.dag_command.DebugExecutor") @mock.patch("airflow.cli.commands.dag_command.get_dag") @@ -425,20 +461,21 @@ def test_dag_test(self, mock_get_dag, mock_executor): cli_args = self.parser.parse_args(['dags', 'test', 'example_bash_operator', DEFAULT_DATE.isoformat()]) dag_command.dag_test(cli_args) - mock_get_dag.assert_has_calls([ - mock.call( - subdir=cli_args.subdir, dag_id='example_bash_operator' - ), - mock.call().clear( - start_date=cli_args.execution_date, end_date=cli_args.execution_date, - dag_run_state=State.NONE, - ), - mock.call().run( - executor=mock_executor.return_value, - start_date=cli_args.execution_date, - end_date=cli_args.execution_date - ) - ]) + mock_get_dag.assert_has_calls( + [ + mock.call(subdir=cli_args.subdir, dag_id='example_bash_operator'), + mock.call().clear( + start_date=cli_args.execution_date, + 
end_date=cli_args.execution_date, + dag_run_state=State.NONE, + ), + mock.call().run( + executor=mock_executor.return_value, + start_date=cli_args.execution_date, + end_date=cli_args.execution_date, + ), + ] + ) @mock.patch( "airflow.cli.commands.dag_command.render_dag", **{'return_value.source': "SOURCE"} # type: ignore @@ -446,30 +483,28 @@ def test_dag_test(self, mock_get_dag, mock_executor): @mock.patch("airflow.cli.commands.dag_command.DebugExecutor") @mock.patch("airflow.cli.commands.dag_command.get_dag") def test_dag_test_show_dag(self, mock_get_dag, mock_executor, mock_render_dag): - cli_args = self.parser.parse_args([ - 'dags', 'test', 'example_bash_operator', DEFAULT_DATE.isoformat(), '--show-dagrun' - ]) + cli_args = self.parser.parse_args( + ['dags', 'test', 'example_bash_operator', DEFAULT_DATE.isoformat(), '--show-dagrun'] + ) with contextlib.redirect_stdout(io.StringIO()) as stdout: dag_command.dag_test(cli_args) output = stdout.getvalue() - mock_get_dag.assert_has_calls([ - mock.call( - subdir=cli_args.subdir, dag_id='example_bash_operator' - ), - mock.call().clear( - start_date=cli_args.execution_date, - end_date=cli_args.execution_date, - dag_run_state=State.NONE, - ), - mock.call().run( - executor=mock_executor.return_value, - start_date=cli_args.execution_date, - end_date=cli_args.execution_date - ) - ]) - mock_render_dag.assert_has_calls([ - mock.call(mock_get_dag.return_value, tis=[]) - ]) + mock_get_dag.assert_has_calls( + [ + mock.call(subdir=cli_args.subdir, dag_id='example_bash_operator'), + mock.call().clear( + start_date=cli_args.execution_date, + end_date=cli_args.execution_date, + dag_run_state=State.NONE, + ), + mock.call().run( + executor=mock_executor.return_value, + start_date=cli_args.execution_date, + end_date=cli_args.execution_date, + ), + ] + ) + mock_render_dag.assert_has_calls([mock.call(mock_get_dag.return_value, tis=[])]) self.assertIn("SOURCE", output) diff --git a/tests/cli/commands/test_db_command.py b/tests/cli/commands/test_db_command.py index e8ed61048761f..8543db9f6cc86 100644 --- a/tests/cli/commands/test_db_command.py +++ b/tests/cli/commands/test_db_command.py @@ -56,56 +56,48 @@ def test_cli_upgradedb(self, mock_upgradedb): @mock.patch("airflow.cli.commands.db_command.execute_interactive") @mock.patch("airflow.cli.commands.db_command.NamedTemporaryFile") - @mock.patch( - "airflow.cli.commands.db_command.settings.engine.url", - make_url("mysql://root@mysql/airflow") - ) + @mock.patch("airflow.cli.commands.db_command.settings.engine.url", make_url("mysql://root@mysql/airflow")) def test_cli_shell_mysql(self, mock_tmp_file, mock_execute_interactive): mock_tmp_file.return_value.__enter__.return_value.name = "/tmp/name" db_command.shell(self.parser.parse_args(['db', 'shell'])) - mock_execute_interactive.assert_called_once_with( - ['mysql', '--defaults-extra-file=/tmp/name'] - ) + mock_execute_interactive.assert_called_once_with(['mysql', '--defaults-extra-file=/tmp/name']) mock_tmp_file.return_value.__enter__.return_value.write.assert_called_once_with( - b'[client]\nhost = mysql\nuser = root\npassword = \nport = ' - b'\ndatabase = airflow' + b'[client]\nhost = mysql\nuser = root\npassword = \nport = ' b'\ndatabase = airflow' ) @mock.patch("airflow.cli.commands.db_command.execute_interactive") @mock.patch( - "airflow.cli.commands.db_command.settings.engine.url", - make_url("sqlite:////root/airflow/airflow.db") + "airflow.cli.commands.db_command.settings.engine.url", make_url("sqlite:////root/airflow/airflow.db") ) def 
test_cli_shell_sqlite(self, mock_execute_interactive): db_command.shell(self.parser.parse_args(['db', 'shell'])) - mock_execute_interactive.assert_called_once_with( - ['sqlite3', '/root/airflow/airflow.db'] - ) + mock_execute_interactive.assert_called_once_with(['sqlite3', '/root/airflow/airflow.db']) @mock.patch("airflow.cli.commands.db_command.execute_interactive") @mock.patch( "airflow.cli.commands.db_command.settings.engine.url", - make_url("postgresql+psycopg2://postgres:airflow@postgres/airflow") + make_url("postgresql+psycopg2://postgres:airflow@postgres/airflow"), ) def test_cli_shell_postgres(self, mock_execute_interactive): db_command.shell(self.parser.parse_args(['db', 'shell'])) - mock_execute_interactive.assert_called_once_with( - ['psql'], env=mock.ANY - ) + mock_execute_interactive.assert_called_once_with(['psql'], env=mock.ANY) _, kwargs = mock_execute_interactive.call_args env = kwargs['env'] postgres_env = {k: v for k, v in env.items() if k.startswith('PG')} - self.assertEqual({ - 'PGDATABASE': 'airflow', - 'PGHOST': 'postgres', - 'PGPASSWORD': 'airflow', - 'PGPORT': '', - 'PGUSER': 'postgres' - }, postgres_env) + self.assertEqual( + { + 'PGDATABASE': 'airflow', + 'PGHOST': 'postgres', + 'PGPASSWORD': 'airflow', + 'PGPORT': '', + 'PGUSER': 'postgres', + }, + postgres_env, + ) @mock.patch( "airflow.cli.commands.db_command.settings.engine.url", - make_url("invalid+psycopg2://postgres:airflow@postgres/airflow") + make_url("invalid+psycopg2://postgres:airflow@postgres/airflow"), ) def test_cli_shell_invalid(self): with self.assertRaisesRegex(AirflowException, r"Unknown driver: invalid\+psycopg2"): diff --git a/tests/cli/commands/test_info_command.py b/tests/cli/commands/test_info_command.py index 56eb584f9241a..bd69704f68f92 100644 --- a/tests/cli/commands/test_info_command.py +++ b/tests/cli/commands/test_info_command.py @@ -54,7 +54,10 @@ def test_should_remove_pii_from_path(self): "postgresql+psycopg2://:airflow@postgres/airflow", "postgresql+psycopg2://:PASSWORD@postgres/airflow", ), - ("postgresql+psycopg2://postgres/airflow", "postgresql+psycopg2://postgres/airflow",), + ( + "postgresql+psycopg2://postgres/airflow", + "postgresql+psycopg2://postgres/airflow", + ), ] ) def test_should_remove_pii_from_url(self, before, after): @@ -100,10 +103,12 @@ def test_should_read_config(self): class TestConfigInfoLogging(unittest.TestCase): def test_should_read_logging_configuration(self): - with conf_vars({ - ('logging', 'remote_logging'): 'True', - ('logging', 'remote_base_log_folder'): 'stackdriver://logs-name', - }): + with conf_vars( + { + ('logging', 'remote_logging'): 'True', + ('logging', 'remote_base_log_folder'): 'stackdriver://logs-name', + } + ): importlib.reload(airflow_local_settings) configure_logging() instance = info_command.ConfigInfo(info_command.NullAnonymizer()) @@ -166,7 +171,7 @@ def test_show_info_anonymize(self): "link": "https://file.io/TEST", "expiry": "14 days", }, - } + }, ) def test_show_info_anonymize_fileio(self, mock_requests): with contextlib.redirect_stdout(io.StringIO()) as stdout: diff --git a/tests/cli/commands/test_kubernetes_command.py b/tests/cli/commands/test_kubernetes_command.py index 9e21aaa82a11d..3c4dda8e1bc83 100644 --- a/tests/cli/commands/test_kubernetes_command.py +++ b/tests/cli/commands/test_kubernetes_command.py @@ -35,9 +35,18 @@ def setUpClass(cls): def test_generate_dag_yaml(self): with tempfile.TemporaryDirectory("airflow_dry_run_test/") as directory: file_name = 
"example_bash_operator_run_after_loop_2020-11-03T00_00_00_plus_00_00.yml" - kubernetes_command.generate_pod_yaml(self.parser.parse_args([ - 'kubernetes', 'generate-dag-yaml', - 'example_bash_operator', "2020-11-03", "--output-path", directory])) + kubernetes_command.generate_pod_yaml( + self.parser.parse_args( + [ + 'kubernetes', + 'generate-dag-yaml', + 'example_bash_operator', + "2020-11-03", + "--output-path", + directory, + ] + ) + ) self.assertEqual(len(os.listdir(directory)), 1) out_dir = directory + "/airflow_yaml_output/" self.assertEqual(len(os.listdir(out_dir)), 6) diff --git a/tests/cli/commands/test_legacy_commands.py b/tests/cli/commands/test_legacy_commands.py index 42a04ff5bb008..444cda07c1f08 100644 --- a/tests/cli/commands/test_legacy_commands.py +++ b/tests/cli/commands/test_legacy_commands.py @@ -24,11 +24,35 @@ from airflow.cli.commands import config_command from airflow.cli.commands.legacy_commands import COMMAND_MAP, check_legacy_command -LEGACY_COMMANDS = ["worker", "flower", "trigger_dag", "delete_dag", "show_dag", "list_dag", - "dag_status", "backfill", "list_dag_runs", "pause", "unpause", "test", - "clear", "list_tasks", "task_failed_deps", "task_state", "run", - "render", "initdb", "resetdb", "upgradedb", "checkdb", "shell", "pool", - "list_users", "create_user", "delete_user"] +LEGACY_COMMANDS = [ + "worker", + "flower", + "trigger_dag", + "delete_dag", + "show_dag", + "list_dag", + "dag_status", + "backfill", + "list_dag_runs", + "pause", + "unpause", + "test", + "clear", + "list_tasks", + "task_failed_deps", + "task_state", + "run", + "render", + "initdb", + "resetdb", + "upgradedb", + "checkdb", + "shell", + "pool", + "list_users", + "create_user", + "delete_user", +] class TestCliDeprecatedCommandsValue(unittest.TestCase): @@ -37,15 +61,16 @@ def setUpClass(cls): cls.parser = cli_parser.get_parser() def test_should_display_value(self): - with self.assertRaises(SystemExit) as cm_exception, \ - contextlib.redirect_stderr(io.StringIO()) as temp_stderr: + with self.assertRaises(SystemExit) as cm_exception, contextlib.redirect_stderr( + io.StringIO() + ) as temp_stderr: config_command.get_value(self.parser.parse_args(['worker'])) self.assertEqual(2, cm_exception.exception.code) self.assertIn( "`airflow worker` command, has been removed, " "please use `airflow celery worker`, see help above.", - temp_stderr.getvalue().strip() + temp_stderr.getvalue().strip(), ) def test_command_map(self): @@ -58,4 +83,5 @@ def test_check_legacy_command(self): check_legacy_command(action, 'list_users') self.assertEqual( str(e.exception), - "argument : `airflow list_users` command, has been removed, please use `airflow users list`") + "argument : `airflow list_users` command, has been removed, please use `airflow users list`", + ) diff --git a/tests/cli/commands/test_pool_command.py b/tests/cli/commands/test_pool_command.py index 135166d033608..be95d09ede633 100644 --- a/tests/cli/commands/test_pool_command.py +++ b/tests/cli/commands/test_pool_command.py @@ -81,18 +81,9 @@ def test_pool_delete(self): def test_pool_import_export(self): # Create two pools first pool_config_input = { - "foo": { - "description": "foo_test", - "slots": 1 - }, - 'default_pool': { - 'description': 'Default pool', - 'slots': 128 - }, - "baz": { - "description": "baz_test", - "slots": 2 - } + "foo": {"description": "foo_test", "slots": 1}, + 'default_pool': {'description': 'Default pool', 'slots': 128}, + "baz": {"description": "baz_test", "slots": 2}, } with open('pools_import.json', mode='w') as file: 
json.dump(pool_config_input, file) @@ -106,8 +97,7 @@ def test_pool_import_export(self): with open('pools_export.json', mode='r') as file: pool_config_output = json.load(file) self.assertEqual( - pool_config_input, - pool_config_output, - "Input and output pool files are not same") + pool_config_input, pool_config_output, "Input and output pool files are not same" + ) os.remove('pools_import.json') os.remove('pools_export.json') diff --git a/tests/cli/commands/test_role_command.py b/tests/cli/commands/test_role_command.py index 20af87986d6d0..e5a93fcebf8ee 100644 --- a/tests/cli/commands/test_role_command.py +++ b/tests/cli/commands/test_role_command.py @@ -36,6 +36,7 @@ def setUpClass(cls): def setUp(self): from airflow.www import app as application + self.app = application.create_app(testing=True) self.appbuilder = self.app.appbuilder # pylint: disable=no-member self.clear_roles_and_roles() @@ -56,9 +57,7 @@ def test_cli_create_roles(self): self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamA')) self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamB')) - args = self.parser.parse_args([ - 'roles', 'create', 'FakeTeamA', 'FakeTeamB' - ]) + args = self.parser.parse_args(['roles', 'create', 'FakeTeamA', 'FakeTeamB']) role_command.roles_create(args) self.assertIsNotNone(self.appbuilder.sm.find_role('FakeTeamA')) @@ -68,9 +67,7 @@ def test_cli_create_roles_is_reentrant(self): self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamA')) self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamB')) - args = self.parser.parse_args([ - 'roles', 'create', 'FakeTeamA', 'FakeTeamB' - ]) + args = self.parser.parse_args(['roles', 'create', 'FakeTeamA', 'FakeTeamB']) role_command.roles_create(args) diff --git a/tests/cli/commands/test_sync_perm_command.py b/tests/cli/commands/test_sync_perm_command.py index 5019d48d00d83..e1944472e1cf0 100644 --- a/tests/cli/commands/test_sync_perm_command.py +++ b/tests/cli/commands/test_sync_perm_command.py @@ -35,19 +35,17 @@ def setUpClass(cls): @mock.patch("airflow.cli.commands.sync_perm_command.cached_app") @mock.patch("airflow.cli.commands.sync_perm_command.DagBag") def test_cli_sync_perm(self, dagbag_mock, mock_cached_app): - self.expect_dagbag_contains([ - DAG('has_access_control', - access_control={ - 'Public': {permissions.ACTION_CAN_READ} - }), - DAG('no_access_control') - ], dagbag_mock) + self.expect_dagbag_contains( + [ + DAG('has_access_control', access_control={'Public': {permissions.ACTION_CAN_READ}}), + DAG('no_access_control'), + ], + dagbag_mock, + ) appbuilder = mock_cached_app.return_value.appbuilder appbuilder.sm = mock.Mock() - args = self.parser.parse_args([ - 'sync-perm' - ]) + args = self.parser.parse_args(['sync-perm']) sync_perm_command.sync_perm(args) assert appbuilder.sm.sync_roles.call_count == 1 @@ -55,8 +53,7 @@ def test_cli_sync_perm(self, dagbag_mock, mock_cached_app): dagbag_mock.assert_called_once_with(read_dags_from_db=True) self.assertEqual(2, len(appbuilder.sm.sync_perm_for_dag.mock_calls)) appbuilder.sm.sync_perm_for_dag.assert_any_call( - 'has_access_control', - {'Public': {permissions.ACTION_CAN_READ}} + 'has_access_control', {'Public': {permissions.ACTION_CAN_READ}} ) appbuilder.sm.sync_perm_for_dag.assert_any_call( 'no_access_control', diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py index 94b8ac4e96807..113c961c21e91 100644 --- a/tests/cli/commands/test_task_command.py +++ b/tests/cli/commands/test_task_command.py @@ -65,15 +65,14 @@ def test_cli_list_tasks(self): args = 
self.parser.parse_args(['tasks', 'list', dag_id]) task_command.task_list(args) - args = self.parser.parse_args([ - 'tasks', 'list', 'example_bash_operator', '--tree']) + args = self.parser.parse_args(['tasks', 'list', 'example_bash_operator', '--tree']) task_command.task_list(args) def test_test(self): """Test the `airflow test` command""" - args = self.parser.parse_args([ - "tasks", "test", "example_python_operator", 'print_the_context', '2018-01-01' - ]) + args = self.parser.parse_args( + ["tasks", "test", "example_python_operator", 'print_the_context', '2018-01-01'] + ) with redirect_stdout(io.StringIO()) as stdout: task_command.task_test(args) @@ -91,13 +90,15 @@ def test_run_naive_taskinstance(self, mock_local_job): dag = self.dagbag.get_dag('test_run_ignores_all_dependencies') task0_id = 'test_run_dependent_task' - args0 = ['tasks', - 'run', - '--ignore-all-dependencies', - '--local', - dag_id, - task0_id, - naive_date.isoformat()] + args0 = [ + 'tasks', + 'run', + '--ignore-all-dependencies', + '--local', + dag_id, + task0_id, + naive_date.isoformat(), + ] task_command.task_run(self.parser.parse_args(args0), dag=dag) mock_local_job.assert_called_once_with( @@ -112,67 +113,119 @@ def test_run_naive_taskinstance(self, mock_local_job): ) def test_cli_test(self): - task_command.task_test(self.parser.parse_args([ - 'tasks', 'test', 'example_bash_operator', 'runme_0', - DEFAULT_DATE.isoformat()])) - task_command.task_test(self.parser.parse_args([ - 'tasks', 'test', 'example_bash_operator', 'runme_0', '--dry-run', - DEFAULT_DATE.isoformat()])) + task_command.task_test( + self.parser.parse_args( + ['tasks', 'test', 'example_bash_operator', 'runme_0', DEFAULT_DATE.isoformat()] + ) + ) + task_command.task_test( + self.parser.parse_args( + ['tasks', 'test', 'example_bash_operator', 'runme_0', '--dry-run', DEFAULT_DATE.isoformat()] + ) + ) def test_cli_test_with_params(self): - task_command.task_test(self.parser.parse_args([ - 'tasks', 'test', 'example_passing_params_via_test_command', 'run_this', - '--task-params', '{"foo":"bar"}', DEFAULT_DATE.isoformat()])) - task_command.task_test(self.parser.parse_args([ - 'tasks', 'test', 'example_passing_params_via_test_command', 'also_run_this', - '--task-params', '{"foo":"bar"}', DEFAULT_DATE.isoformat()])) + task_command.task_test( + self.parser.parse_args( + [ + 'tasks', + 'test', + 'example_passing_params_via_test_command', + 'run_this', + '--task-params', + '{"foo":"bar"}', + DEFAULT_DATE.isoformat(), + ] + ) + ) + task_command.task_test( + self.parser.parse_args( + [ + 'tasks', + 'test', + 'example_passing_params_via_test_command', + 'also_run_this', + '--task-params', + '{"foo":"bar"}', + DEFAULT_DATE.isoformat(), + ] + ) + ) def test_cli_test_with_env_vars(self): with redirect_stdout(io.StringIO()) as stdout: - task_command.task_test(self.parser.parse_args([ - 'tasks', 'test', 'example_passing_params_via_test_command', 'env_var_test_task', - '--env-vars', '{"foo":"bar"}', DEFAULT_DATE.isoformat()])) + task_command.task_test( + self.parser.parse_args( + [ + 'tasks', + 'test', + 'example_passing_params_via_test_command', + 'env_var_test_task', + '--env-vars', + '{"foo":"bar"}', + DEFAULT_DATE.isoformat(), + ] + ) + ) output = stdout.getvalue() self.assertIn('foo=bar', output) self.assertIn('AIRFLOW_TEST_MODE=True', output) def test_cli_run(self): - task_command.task_run(self.parser.parse_args([ - 'tasks', 'run', 'example_bash_operator', 'runme_0', '--local', - DEFAULT_DATE.isoformat()])) + task_command.task_run( + self.parser.parse_args( + 
['tasks', 'run', 'example_bash_operator', 'runme_0', '--local', DEFAULT_DATE.isoformat()] + ) + ) @parameterized.expand( [ - ("--ignore-all-dependencies", ), - ("--ignore-depends-on-past", ), + ("--ignore-all-dependencies",), + ("--ignore-depends-on-past",), ("--ignore-dependencies",), ("--force",), ], - ) def test_cli_run_invalid_raw_option(self, option: str): with self.assertRaisesRegex( - AirflowException, - "Option --raw does not work with some of the other options on this command." + AirflowException, "Option --raw does not work with some of the other options on this command." ): - task_command.task_run(self.parser.parse_args([ # type: ignore - 'tasks', 'run', 'example_bash_operator', 'runme_0', DEFAULT_DATE.isoformat(), '--raw', option - ])) + task_command.task_run( + self.parser.parse_args( + [ # type: ignore + 'tasks', + 'run', + 'example_bash_operator', + 'runme_0', + DEFAULT_DATE.isoformat(), + '--raw', + option, + ] + ) + ) def test_cli_run_mutually_exclusive(self): - with self.assertRaisesRegex( - AirflowException, - "Option --raw and --local are mutually exclusive." - ): - task_command.task_run(self.parser.parse_args([ - 'tasks', 'run', 'example_bash_operator', 'runme_0', DEFAULT_DATE.isoformat(), '--raw', - '--local' - ])) + with self.assertRaisesRegex(AirflowException, "Option --raw and --local are mutually exclusive."): + task_command.task_run( + self.parser.parse_args( + [ + 'tasks', + 'run', + 'example_bash_operator', + 'runme_0', + DEFAULT_DATE.isoformat(), + '--raw', + '--local', + ] + ) + ) def test_task_state(self): - task_command.task_state(self.parser.parse_args([ - 'tasks', 'state', 'example_bash_operator', 'runme_0', - DEFAULT_DATE.isoformat()])) + task_command.task_state( + self.parser.parse_args( + ['tasks', 'state', 'example_bash_operator', 'runme_0', DEFAULT_DATE.isoformat()] + ) + ) def test_task_states_for_dag_run(self): @@ -188,58 +241,61 @@ def test_task_states_for_dag_run(self): ti_end = ti2.end_date with redirect_stdout(io.StringIO()) as stdout: - task_command.task_states_for_dag_run(self.parser.parse_args([ - 'tasks', 'states-for-dag-run', 'example_python_operator', defaut_date2.isoformat()])) + task_command.task_states_for_dag_run( + self.parser.parse_args( + ['tasks', 'states-for-dag-run', 'example_python_operator', defaut_date2.isoformat()] + ) + ) actual_out = stdout.getvalue() - formatted_rows = [('example_python_operator', - '2016-01-09 00:00:00+00:00', - 'print_the_context', - 'success', - ti_start, - ti_end)] - - expected = tabulate(formatted_rows, - ['dag', - 'exec_date', - 'task', - 'state', - 'start_date', - 'end_date'], - tablefmt="plain") + formatted_rows = [ + ( + 'example_python_operator', + '2016-01-09 00:00:00+00:00', + 'print_the_context', + 'success', + ti_start, + ti_end, + ) + ] + + expected = tabulate( + formatted_rows, ['dag', 'exec_date', 'task', 'state', 'start_date', 'end_date'], tablefmt="plain" + ) # Check that prints, and log messages, are shown self.assertIn(expected.replace("\n", ""), actual_out.replace("\n", "")) def test_subdag_clear(self): - args = self.parser.parse_args([ - 'tasks', 'clear', 'example_subdag_operator', '--yes']) + args = self.parser.parse_args(['tasks', 'clear', 'example_subdag_operator', '--yes']) task_command.task_clear(args) - args = self.parser.parse_args([ - 'tasks', 'clear', 'example_subdag_operator', '--yes', '--exclude-subdags']) + args = self.parser.parse_args( + ['tasks', 'clear', 'example_subdag_operator', '--yes', '--exclude-subdags'] + ) task_command.task_clear(args) def 
test_parentdag_downstream_clear(self): - args = self.parser.parse_args([ - 'tasks', 'clear', 'example_subdag_operator.section-1', '--yes']) + args = self.parser.parse_args(['tasks', 'clear', 'example_subdag_operator.section-1', '--yes']) task_command.task_clear(args) - args = self.parser.parse_args([ - 'tasks', 'clear', 'example_subdag_operator.section-1', '--yes', - '--exclude-parentdag']) + args = self.parser.parse_args( + ['tasks', 'clear', 'example_subdag_operator.section-1', '--yes', '--exclude-parentdag'] + ) task_command.task_clear(args) @pytest.mark.quarantined def test_local_run(self): - args = self.parser.parse_args([ - 'tasks', - 'run', - 'example_python_operator', - 'print_the_context', - '2018-04-27T08:39:51.298439+00:00', - '--interactive', - '--subdir', - '/root/dags/example_python_operator.py' - ]) + args = self.parser.parse_args( + [ + 'tasks', + 'run', + 'example_python_operator', + 'print_the_context', + '2018-04-27T08:39:51.298439+00:00', + '--interactive', + '--subdir', + '/root/dags/example_python_operator.py', + ] + ) dag = get_dag(args.subdir, args.dag_id) reset(dag.dag_id) @@ -253,7 +309,6 @@ def test_local_run(self): class TestLogsfromTaskRunCommand(unittest.TestCase): - def setUp(self) -> None: self.dag_id = "test_logging_dag" self.task_id = "test_task" @@ -299,40 +354,52 @@ def test_logging_with_run_task(self): # as that is what gets displayed with conf_vars({('core', 'dags_folder'): self.dag_path}): - task_command.task_run(self.parser.parse_args([ - 'tasks', 'run', self.dag_id, self.task_id, '--local', self.execution_date_str])) + task_command.task_run( + self.parser.parse_args( + ['tasks', 'run', self.dag_id, self.task_id, '--local', self.execution_date_str] + ) + ) with open(self.ti_log_file_path) as l_file: logs = l_file.read() - print(logs) # In case of a test failures this line would show detailed log + print(logs) # In case of a test failures this line would show detailed log logs_list = logs.splitlines() self.assertIn("INFO - Started process", logs) self.assertIn(f"Subtask {self.task_id}", logs) self.assertIn("standard_task_runner.py", logs) - self.assertIn(f"INFO - Running: ['airflow', 'tasks', 'run', '{self.dag_id}', " - f"'{self.task_id}', '{self.execution_date_str}',", logs) + self.assertIn( + f"INFO - Running: ['airflow', 'tasks', 'run', '{self.dag_id}', " + f"'{self.task_id}', '{self.execution_date_str}',", + logs, + ) self.assert_log_line("Log from DAG Logger", logs_list) self.assert_log_line("Log from TI Logger", logs_list) self.assert_log_line("Log from Print statement", logs_list, expect_from_logging_mixin=True) - self.assertIn(f"INFO - Marking task as SUCCESS. dag_id={self.dag_id}, " - f"task_id={self.task_id}, execution_date=20170101T000000", logs) + self.assertIn( + f"INFO - Marking task as SUCCESS. 
dag_id={self.dag_id}, " + f"task_id={self.task_id}, execution_date=20170101T000000", + logs, + ) @mock.patch("airflow.task.task_runner.standard_task_runner.CAN_FORK", False) def test_logging_with_run_task_subprocess(self): # We are not using self.assertLogs as we want to verify what actually is stored in the Log file # as that is what gets displayed with conf_vars({('core', 'dags_folder'): self.dag_path}): - task_command.task_run(self.parser.parse_args([ - 'tasks', 'run', self.dag_id, self.task_id, '--local', self.execution_date_str])) + task_command.task_run( + self.parser.parse_args( + ['tasks', 'run', self.dag_id, self.task_id, '--local', self.execution_date_str] + ) + ) with open(self.ti_log_file_path) as l_file: logs = l_file.read() - print(logs) # In case of a test failures this line would show detailed log + print(logs) # In case of a test failures this line would show detailed log logs_list = logs.splitlines() self.assertIn(f"Subtask {self.task_id}", logs) @@ -341,10 +408,16 @@ def test_logging_with_run_task_subprocess(self): self.assert_log_line("Log from TI Logger", logs_list) self.assert_log_line("Log from Print statement", logs_list, expect_from_logging_mixin=True) - self.assertIn(f"INFO - Running: ['airflow', 'tasks', 'run', '{self.dag_id}', " - f"'{self.task_id}', '{self.execution_date_str}',", logs) - self.assertIn(f"INFO - Marking task as SUCCESS. dag_id={self.dag_id}, " - f"task_id={self.task_id}, execution_date=20170101T000000", logs) + self.assertIn( + f"INFO - Running: ['airflow', 'tasks', 'run', '{self.dag_id}', " + f"'{self.task_id}', '{self.execution_date_str}',", + logs, + ) + self.assertIn( + f"INFO - Marking task as SUCCESS. dag_id={self.dag_id}, " + f"task_id={self.task_id}, execution_date=20170101T000000", + logs, + ) def test_log_file_template_with_run_task(self): """Verify that the taskinstance has the right context for log_filename_template""" @@ -405,46 +478,43 @@ def test_run_ignores_all_dependencies(self): dag.clear() task0_id = 'test_run_dependent_task' - args0 = ['tasks', - 'run', - '--ignore-all-dependencies', - dag_id, - task0_id, - DEFAULT_DATE.isoformat()] + args0 = ['tasks', 'run', '--ignore-all-dependencies', dag_id, task0_id, DEFAULT_DATE.isoformat()] task_command.task_run(self.parser.parse_args(args0)) - ti_dependent0 = TaskInstance( - task=dag.get_task(task0_id), - execution_date=DEFAULT_DATE) + ti_dependent0 = TaskInstance(task=dag.get_task(task0_id), execution_date=DEFAULT_DATE) ti_dependent0.refresh_from_db() self.assertEqual(ti_dependent0.state, State.FAILED) task1_id = 'test_run_dependency_task' - args1 = ['tasks', - 'run', - '--ignore-all-dependencies', - dag_id, - task1_id, - (DEFAULT_DATE + timedelta(days=1)).isoformat()] + args1 = [ + 'tasks', + 'run', + '--ignore-all-dependencies', + dag_id, + task1_id, + (DEFAULT_DATE + timedelta(days=1)).isoformat(), + ] task_command.task_run(self.parser.parse_args(args1)) ti_dependency = TaskInstance( - task=dag.get_task(task1_id), - execution_date=DEFAULT_DATE + timedelta(days=1)) + task=dag.get_task(task1_id), execution_date=DEFAULT_DATE + timedelta(days=1) + ) ti_dependency.refresh_from_db() self.assertEqual(ti_dependency.state, State.FAILED) task2_id = 'test_run_dependent_task' - args2 = ['tasks', - 'run', - '--ignore-all-dependencies', - dag_id, - task2_id, - (DEFAULT_DATE + timedelta(days=1)).isoformat()] + args2 = [ + 'tasks', + 'run', + '--ignore-all-dependencies', + dag_id, + task2_id, + (DEFAULT_DATE + timedelta(days=1)).isoformat(), + ] task_command.task_run(self.parser.parse_args(args2)) 
ti_dependent = TaskInstance( - task=dag.get_task(task2_id), - execution_date=DEFAULT_DATE + timedelta(days=1)) + task=dag.get_task(task2_id), execution_date=DEFAULT_DATE + timedelta(days=1) + ) ti_dependent.refresh_from_db() self.assertEqual(ti_dependent.state, State.SUCCESS) diff --git a/tests/cli/commands/test_user_command.py b/tests/cli/commands/test_user_command.py index 70595430ab5c4..188a977676b3c 100644 --- a/tests/cli/commands/test_user_command.py +++ b/tests/cli/commands/test_user_command.py @@ -47,6 +47,7 @@ def setUpClass(cls): def setUp(self): from airflow.www import app as application + self.app = application.create_app(testing=True) self.appbuilder = self.app.appbuilder # pylint: disable=no-member self.clear_roles_and_roles() @@ -64,41 +65,94 @@ def clear_roles_and_roles(self): self.appbuilder.sm.delete_role(role_name) def test_cli_create_user_random_password(self): - args = self.parser.parse_args([ - 'users', 'create', '--username', 'test1', '--lastname', 'doe', - '--firstname', 'jon', - '--email', 'jdoe@foo.com', '--role', 'Viewer', '--use-random-password' - ]) + args = self.parser.parse_args( + [ + 'users', + 'create', + '--username', + 'test1', + '--lastname', + 'doe', + '--firstname', + 'jon', + '--email', + 'jdoe@foo.com', + '--role', + 'Viewer', + '--use-random-password', + ] + ) user_command.users_create(args) def test_cli_create_user_supplied_password(self): - args = self.parser.parse_args([ - 'users', 'create', '--username', 'test2', '--lastname', 'doe', - '--firstname', 'jon', - '--email', 'jdoe@apache.org', '--role', 'Viewer', '--password', 'test' - ]) + args = self.parser.parse_args( + [ + 'users', + 'create', + '--username', + 'test2', + '--lastname', + 'doe', + '--firstname', + 'jon', + '--email', + 'jdoe@apache.org', + '--role', + 'Viewer', + '--password', + 'test', + ] + ) user_command.users_create(args) def test_cli_delete_user(self): - args = self.parser.parse_args([ - 'users', 'create', '--username', 'test3', '--lastname', 'doe', - '--firstname', 'jon', - '--email', 'jdoe@example.com', '--role', 'Viewer', '--use-random-password' - ]) + args = self.parser.parse_args( + [ + 'users', + 'create', + '--username', + 'test3', + '--lastname', + 'doe', + '--firstname', + 'jon', + '--email', + 'jdoe@example.com', + '--role', + 'Viewer', + '--use-random-password', + ] + ) user_command.users_create(args) - args = self.parser.parse_args([ - 'users', 'delete', '--username', 'test3', - ]) + args = self.parser.parse_args( + [ + 'users', + 'delete', + '--username', + 'test3', + ] + ) user_command.users_delete(args) def test_cli_list_users(self): for i in range(0, 3): - args = self.parser.parse_args([ - 'users', 'create', '--username', f'user{i}', '--lastname', - 'doe', '--firstname', 'jon', - '--email', f'jdoe+{i}@gmail.com', '--role', 'Viewer', - '--use-random-password' - ]) + args = self.parser.parse_args( + [ + 'users', + 'create', + '--username', + f'user{i}', + '--lastname', + 'doe', + '--firstname', + 'jon', + '--email', + f'jdoe+{i}@gmail.com', + '--role', + 'Viewer', + '--use-random-password', + ] + ) user_command.users_create(args) with redirect_stdout(io.StringIO()) as stdout: user_command.users_list(self.parser.parse_args(['users', 'list'])) @@ -122,15 +176,19 @@ def assert_user_not_in_roles(email, roles): assert_user_not_in_roles(TEST_USER2_EMAIL, ['Public']) users = [ { - "username": "imported_user1", "lastname": "doe1", - "firstname": "jon", "email": TEST_USER1_EMAIL, - "roles": ["Admin", "Op"] + "username": "imported_user1", + "lastname": "doe1", + 
"firstname": "jon", + "email": TEST_USER1_EMAIL, + "roles": ["Admin", "Op"], }, { - "username": "imported_user2", "lastname": "doe2", - "firstname": "jon", "email": TEST_USER2_EMAIL, - "roles": ["Public"] - } + "username": "imported_user2", + "lastname": "doe2", + "firstname": "jon", + "email": TEST_USER2_EMAIL, + "roles": ["Public"], + }, ] self._import_users_from_file(users) @@ -139,15 +197,19 @@ def assert_user_not_in_roles(email, roles): users = [ { - "username": "imported_user1", "lastname": "doe1", - "firstname": "jon", "email": TEST_USER1_EMAIL, - "roles": ["Public"] + "username": "imported_user1", + "lastname": "doe1", + "firstname": "jon", + "email": TEST_USER1_EMAIL, + "roles": ["Public"], }, { - "username": "imported_user2", "lastname": "doe2", - "firstname": "jon", "email": TEST_USER2_EMAIL, - "roles": ["Admin"] - } + "username": "imported_user2", + "lastname": "doe2", + "firstname": "jon", + "email": TEST_USER2_EMAIL, + "roles": ["Admin"], + }, ] self._import_users_from_file(users) @@ -157,12 +219,20 @@ def assert_user_not_in_roles(email, roles): assert_user_in_roles(TEST_USER2_EMAIL, ['Admin']) def test_cli_export_users(self): - user1 = {"username": "imported_user1", "lastname": "doe1", - "firstname": "jon", "email": TEST_USER1_EMAIL, - "roles": ["Public"]} - user2 = {"username": "imported_user2", "lastname": "doe2", - "firstname": "jon", "email": TEST_USER2_EMAIL, - "roles": ["Admin"]} + user1 = { + "username": "imported_user1", + "lastname": "doe1", + "firstname": "jon", + "email": TEST_USER1_EMAIL, + "roles": ["Public"], + } + user2 = { + "username": "imported_user2", + "lastname": "doe2", + "firstname": "jon", + "email": TEST_USER2_EMAIL, + "roles": ["Admin"], + } self._import_users_from_file([user1, user2]) users_filename = self._export_users_to_file() @@ -190,63 +260,79 @@ def _import_users_from_file(self, user_list): f.write(json_file_content.encode()) f.flush() - args = self.parser.parse_args([ - 'users', 'import', f.name - ]) + args = self.parser.parse_args(['users', 'import', f.name]) user_command.users_import(args) finally: os.remove(f.name) def _export_users_to_file(self): f = tempfile.NamedTemporaryFile(delete=False) - args = self.parser.parse_args([ - 'users', 'export', f.name - ]) + args = self.parser.parse_args(['users', 'export', f.name]) user_command.users_export(args) return f.name def test_cli_add_user_role(self): - args = self.parser.parse_args([ - 'users', 'create', '--username', 'test4', '--lastname', 'doe', - '--firstname', 'jon', - '--email', TEST_USER1_EMAIL, '--role', 'Viewer', '--use-random-password' - ]) + args = self.parser.parse_args( + [ + 'users', + 'create', + '--username', + 'test4', + '--lastname', + 'doe', + '--firstname', + 'jon', + '--email', + TEST_USER1_EMAIL, + '--role', + 'Viewer', + '--use-random-password', + ] + ) user_command.users_create(args) self.assertFalse( _does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Op'), - "User should not yet be a member of role 'Op'" + "User should not yet be a member of role 'Op'", ) - args = self.parser.parse_args([ - 'users', 'add-role', '--username', 'test4', '--role', 'Op' - ]) + args = self.parser.parse_args(['users', 'add-role', '--username', 'test4', '--role', 'Op']) user_command.users_manage_role(args, remove=False) self.assertTrue( _does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Op'), - "User should have been added to role 'Op'" + "User should have been added to role 'Op'", ) def test_cli_remove_user_role(self): 
- args = self.parser.parse_args([ - 'users', 'create', '--username', 'test4', '--lastname', 'doe', - '--firstname', 'jon', - '--email', TEST_USER1_EMAIL, '--role', 'Viewer', '--use-random-password' - ]) + args = self.parser.parse_args( + [ + 'users', + 'create', + '--username', + 'test4', + '--lastname', + 'doe', + '--firstname', + 'jon', + '--email', + TEST_USER1_EMAIL, + '--role', + 'Viewer', + '--use-random-password', + ] + ) user_command.users_create(args) self.assertTrue( _does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Viewer'), - "User should have been created with role 'Viewer'" + "User should have been created with role 'Viewer'", ) - args = self.parser.parse_args([ - 'users', 'remove-role', '--username', 'test4', '--role', 'Viewer' - ]) + args = self.parser.parse_args(['users', 'remove-role', '--username', 'test4', '--role', 'Viewer']) user_command.users_manage_role(args, remove=True) self.assertFalse( _does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Viewer'), - "User should have been removed from role 'Viewer'" + "User should have been removed from role 'Viewer'", ) diff --git a/tests/cli/commands/test_variable_command.py b/tests/cli/commands/test_variable_command.py index 71a4d9842caf1..04c253921eba9 100644 --- a/tests/cli/commands/test_variable_command.py +++ b/tests/cli/commands/test_variable_command.py @@ -43,8 +43,7 @@ def tearDown(self): def test_variables_set(self): """Test variable_set command""" - variable_command.variables_set(self.parser.parse_args([ - 'variables', 'set', 'foo', 'bar'])) + variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'foo', 'bar'])) self.assertIsNotNone(Variable.get("foo")) self.assertRaises(KeyError, Variable.get, "foo1") @@ -52,53 +51,48 @@ def test_variables_get(self): Variable.set('foo', {'foo': 'bar'}, serialize_json=True) with redirect_stdout(io.StringIO()) as stdout: - variable_command.variables_get(self.parser.parse_args([ - 'variables', 'get', 'foo'])) + variable_command.variables_get(self.parser.parse_args(['variables', 'get', 'foo'])) self.assertEqual('{\n "foo": "bar"\n}\n', stdout.getvalue()) def test_get_variable_default_value(self): with redirect_stdout(io.StringIO()) as stdout: - variable_command.variables_get(self.parser.parse_args([ - 'variables', 'get', 'baz', '--default', 'bar'])) + variable_command.variables_get( + self.parser.parse_args(['variables', 'get', 'baz', '--default', 'bar']) + ) self.assertEqual("bar\n", stdout.getvalue()) def test_get_variable_missing_variable(self): with self.assertRaises(SystemExit): - variable_command.variables_get(self.parser.parse_args([ - 'variables', 'get', 'no-existing-VAR'])) + variable_command.variables_get(self.parser.parse_args(['variables', 'get', 'no-existing-VAR'])) def test_variables_set_different_types(self): """Test storage of various data types""" # Set a dict - variable_command.variables_set(self.parser.parse_args([ - 'variables', 'set', 'dict', '{"foo": "oops"}'])) + variable_command.variables_set( + self.parser.parse_args(['variables', 'set', 'dict', '{"foo": "oops"}']) + ) # Set a list - variable_command.variables_set(self.parser.parse_args([ - 'variables', 'set', 'list', '["oops"]'])) + variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'list', '["oops"]'])) # Set str - variable_command.variables_set(self.parser.parse_args([ - 'variables', 'set', 'str', 'hello string'])) + variable_command.variables_set(self.parser.parse_args(['variables', 'set', 
'str', 'hello string'])) # Set int - variable_command.variables_set(self.parser.parse_args([ - 'variables', 'set', 'int', '42'])) + variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'int', '42'])) # Set float - variable_command.variables_set(self.parser.parse_args([ - 'variables', 'set', 'float', '42.0'])) + variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'float', '42.0'])) # Set true - variable_command.variables_set(self.parser.parse_args([ - 'variables', 'set', 'true', 'true'])) + variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'true', 'true'])) # Set false - variable_command.variables_set(self.parser.parse_args([ - 'variables', 'set', 'false', 'false'])) + variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'false', 'false'])) # Set none - variable_command.variables_set(self.parser.parse_args([ - 'variables', 'set', 'null', 'null'])) + variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'null', 'null'])) # Export and then import - variable_command.variables_export(self.parser.parse_args([ - 'variables', 'export', 'variables_types.json'])) - variable_command.variables_import(self.parser.parse_args([ - 'variables', 'import', 'variables_types.json'])) + variable_command.variables_export( + self.parser.parse_args(['variables', 'export', 'variables_types.json']) + ) + variable_command.variables_import( + self.parser.parse_args(['variables', 'import', 'variables_types.json']) + ) # Assert value self.assertEqual({'foo': 'oops'}, Variable.get('dict', deserialize_json=True)) @@ -115,26 +109,21 @@ def test_variables_set_different_types(self): def test_variables_list(self): """Test variable_list command""" # Test command is received - variable_command.variables_list(self.parser.parse_args([ - 'variables', 'list'])) + variable_command.variables_list(self.parser.parse_args(['variables', 'list'])) def test_variables_delete(self): """Test variable_delete command""" - variable_command.variables_set(self.parser.parse_args([ - 'variables', 'set', 'foo', 'bar'])) - variable_command.variables_delete(self.parser.parse_args([ - 'variables', 'delete', 'foo'])) + variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'foo', 'bar'])) + variable_command.variables_delete(self.parser.parse_args(['variables', 'delete', 'foo'])) self.assertRaises(KeyError, Variable.get, "foo") def test_variables_import(self): """Test variables_import command""" - variable_command.variables_import(self.parser.parse_args([ - 'variables', 'import', os.devnull])) + variable_command.variables_import(self.parser.parse_args(['variables', 'import', os.devnull])) def test_variables_export(self): """Test variables_export command""" - variable_command.variables_export(self.parser.parse_args([ - 'variables', 'export', os.devnull])) + variable_command.variables_export(self.parser.parse_args(['variables', 'export', os.devnull])) def test_variables_isolation(self): """Test isolation of variables""" @@ -142,30 +131,22 @@ def test_variables_isolation(self): tmp2 = tempfile.NamedTemporaryFile(delete=True) # First export - variable_command.variables_set(self.parser.parse_args([ - 'variables', 'set', 'foo', '{"foo":"bar"}'])) - variable_command.variables_set(self.parser.parse_args([ - 'variables', 'set', 'bar', 'original'])) - variable_command.variables_export(self.parser.parse_args([ - 'variables', 'export', tmp1.name])) + variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'foo', 
'{"foo":"bar"}'])) + variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'bar', 'original'])) + variable_command.variables_export(self.parser.parse_args(['variables', 'export', tmp1.name])) first_exp = open(tmp1.name) - variable_command.variables_set(self.parser.parse_args([ - 'variables', 'set', 'bar', 'updated'])) - variable_command.variables_set(self.parser.parse_args([ - 'variables', 'set', 'foo', '{"foo":"oops"}'])) - variable_command.variables_delete(self.parser.parse_args([ - 'variables', 'delete', 'foo'])) - variable_command.variables_import(self.parser.parse_args([ - 'variables', 'import', tmp1.name])) + variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'bar', 'updated'])) + variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'foo', '{"foo":"oops"}'])) + variable_command.variables_delete(self.parser.parse_args(['variables', 'delete', 'foo'])) + variable_command.variables_import(self.parser.parse_args(['variables', 'import', tmp1.name])) self.assertEqual('original', Variable.get('bar')) self.assertEqual('{\n "foo": "bar"\n}', Variable.get('foo')) # Second export - variable_command.variables_export(self.parser.parse_args([ - 'variables', 'export', tmp2.name])) + variable_command.variables_export(self.parser.parse_args(['variables', 'export', tmp2.name])) second_exp = open(tmp2.name) self.assertEqual(first_exp.read(), second_exp.read()) diff --git a/tests/cli/commands/test_webserver_command.py b/tests/cli/commands/test_webserver_command.py index 2468466fb6513..0829d8d611fa0 100644 --- a/tests/cli/commands/test_webserver_command.py +++ b/tests/cli/commands/test_webserver_command.py @@ -34,7 +34,6 @@ class TestGunicornMonitor(unittest.TestCase): - def setUp(self) -> None: self.monitor = GunicornMonitor( gunicorn_master_pid=1, @@ -124,8 +123,9 @@ def _prepare_test_file(filepath: str, size: int): file.flush() def test_should_detect_changes_in_directory(self): - with tempfile.TemporaryDirectory() as tempdir,\ - mock.patch("airflow.cli.commands.webserver_command.settings.PLUGINS_FOLDER", tempdir): + with tempfile.TemporaryDirectory() as tempdir, mock.patch( + "airflow.cli.commands.webserver_command.settings.PLUGINS_FOLDER", tempdir + ): self._prepare_test_file(f"{tempdir}/file1.txt", 100) self._prepare_test_file(f"{tempdir}/nested/nested/nested/nested/file2.txt", 200) self._prepare_test_file(f"{tempdir}/file3.txt", 300) @@ -172,7 +172,6 @@ def test_should_detect_changes_in_directory(self): class TestCLIGetNumReadyWorkersRunning(unittest.TestCase): - @classmethod def setUpClass(cls): cls.parser = cli_parser.get_parser() @@ -271,7 +270,7 @@ def test_cli_webserver_foreground(self): "os.environ", AIRFLOW__CORE__DAGS_FOLDER="/dev/null", AIRFLOW__CORE__LOAD_EXAMPLES="False", - AIRFLOW__WEBSERVER__WORKERS="1" + AIRFLOW__WEBSERVER__WORKERS="1", ): # Run webserver in foreground and terminate it. 
proc = subprocess.Popen(["airflow", "webserver"]) @@ -293,7 +292,7 @@ def test_cli_webserver_foreground_with_pid(self): "os.environ", AIRFLOW__CORE__DAGS_FOLDER="/dev/null", AIRFLOW__CORE__LOAD_EXAMPLES="False", - AIRFLOW__WEBSERVER__WORKERS="1" + AIRFLOW__WEBSERVER__WORKERS="1", ): proc = subprocess.Popen(["airflow", "webserver", "--pid", pidfile]) self.assertEqual(None, proc.poll()) @@ -307,12 +306,12 @@ def test_cli_webserver_foreground_with_pid(self): @pytest.mark.quarantined def test_cli_webserver_background(self): - with tempfile.TemporaryDirectory(prefix="gunicorn") as tmpdir, \ - mock.patch.dict( - "os.environ", - AIRFLOW__CORE__DAGS_FOLDER="/dev/null", - AIRFLOW__CORE__LOAD_EXAMPLES="False", - AIRFLOW__WEBSERVER__WORKERS="1"): + with tempfile.TemporaryDirectory(prefix="gunicorn") as tmpdir, mock.patch.dict( + "os.environ", + AIRFLOW__CORE__DAGS_FOLDER="/dev/null", + AIRFLOW__CORE__LOAD_EXAMPLES="False", + AIRFLOW__WEBSERVER__WORKERS="1", + ): pidfile_webserver = f"{tmpdir}/pidflow-webserver.pid" pidfile_monitor = f"{tmpdir}/pidflow-webserver-monitor.pid" stdout = f"{tmpdir}/airflow-webserver.out" @@ -320,15 +319,21 @@ def test_cli_webserver_background(self): logfile = f"{tmpdir}/airflow-webserver.log" try: # Run webserver as daemon in background. Note that the wait method is not called. - proc = subprocess.Popen([ - "airflow", - "webserver", - "--daemon", - "--pid", pidfile_webserver, - "--stdout", stdout, - "--stderr", stderr, - "--log-file", logfile, - ]) + proc = subprocess.Popen( + [ + "airflow", + "webserver", + "--daemon", + "--pid", + pidfile_webserver, + "--stdout", + stdout, + "--stderr", + stderr, + "--log-file", + logfile, + ] + ) self.assertEqual(None, proc.poll()) pid_monitor = self._wait_pidfile(pidfile_monitor) @@ -354,8 +359,9 @@ def test_cli_webserver_background(self): raise # Patch for causing webserver timeout - @mock.patch("airflow.cli.commands.webserver_command.GunicornMonitor._get_num_workers_running", - return_value=0) + @mock.patch( + "airflow.cli.commands.webserver_command.GunicornMonitor._get_num_workers_running", return_value=0 + ) def test_cli_webserver_shutdown_when_gunicorn_master_is_killed(self, _): # Shorten timeout so that this test doesn't take too long time args = self.parser.parse_args(['webserver']) @@ -370,8 +376,7 @@ def test_cli_webserver_debug(self): sleep(3) # wait for webserver to start return_code = proc.poll() self.assertEqual( - None, - return_code, - f"webserver terminated with return code {return_code} in debug mode") + None, return_code, f"webserver terminated with return code {return_code} in debug mode" + ) proc.terminate() self.assertEqual(-15, proc.wait(60)) diff --git a/tests/cli/test_cli_parser.py b/tests/cli/test_cli_parser.py index 9593d411149f5..4a2dd8524420f 100644 --- a/tests/cli/test_cli_parser.py +++ b/tests/cli/test_cli_parser.py @@ -35,34 +35,28 @@ class TestCli(TestCase): - def test_arg_option_long_only(self): """ Test if the name of cli.args long option valid """ optional_long = [ - arg - for arg in cli_args.values() - if len(arg.flags) == 1 and arg.flags[0].startswith("-") + arg for arg in cli_args.values() if len(arg.flags) == 1 and arg.flags[0].startswith("-") ] for arg in optional_long: - self.assertIsNone(ILLEGAL_LONG_OPTION_PATTERN.match(arg.flags[0]), - f"{arg.flags[0]} is not match") + self.assertIsNone(ILLEGAL_LONG_OPTION_PATTERN.match(arg.flags[0]), f"{arg.flags[0]} is not match") def test_arg_option_mix_short_long(self): """ Test if the name of cli.args mix option (-s, --long) valid """ optional_mix = [ - 
arg - for arg in cli_args.values() - if len(arg.flags) == 2 and arg.flags[0].startswith("-") + arg for arg in cli_args.values() if len(arg.flags) == 2 and arg.flags[0].startswith("-") ] for arg in optional_mix: - self.assertIsNotNone(LEGAL_SHORT_OPTION_PATTERN.match(arg.flags[0]), - f"{arg.flags[0]} is not match") - self.assertIsNone(ILLEGAL_LONG_OPTION_PATTERN.match(arg.flags[1]), - f"{arg.flags[1]} is not match") + self.assertIsNotNone( + LEGAL_SHORT_OPTION_PATTERN.match(arg.flags[0]), f"{arg.flags[0]} is not match" + ) + self.assertIsNone(ILLEGAL_LONG_OPTION_PATTERN.match(arg.flags[1]), f"{arg.flags[1]} is not match") def test_subcommand_conflict(self): """ @@ -75,8 +69,9 @@ def test_subcommand_conflict(self): } for group_name, sub in subcommand.items(): name = [command.name.lower() for command in sub] - self.assertEqual(len(name), len(set(name)), - f"Command group {group_name} have conflict subcommand") + self.assertEqual( + len(name), len(set(name)), f"Command group {group_name} have conflict subcommand" + ) def test_subcommand_arg_name_conflict(self): """ @@ -90,9 +85,11 @@ def test_subcommand_arg_name_conflict(self): for group, command in subcommand.items(): for com in command: conflict_arg = [arg for arg, count in Counter(com.args).items() if count > 1] - self.assertListEqual([], conflict_arg, - f"Command group {group} function {com.name} have " - f"conflict args name {conflict_arg}") + self.assertListEqual( + [], + conflict_arg, + f"Command group {group} function {com.name} have " f"conflict args name {conflict_arg}", + ) def test_subcommand_arg_flag_conflict(self): """ @@ -106,35 +103,35 @@ def test_subcommand_arg_flag_conflict(self): for group, command in subcommand.items(): for com in command: position = [ - a.flags[0] - for a in com.args - if (len(a.flags) == 1 - and not a.flags[0].startswith("-")) + a.flags[0] for a in com.args if (len(a.flags) == 1 and not a.flags[0].startswith("-")) ] conflict_position = [arg for arg, count in Counter(position).items() if count > 1] - self.assertListEqual([], conflict_position, - f"Command group {group} function {com.name} have conflict " - f"position flags {conflict_position}") - - long_option = [a.flags[0] - for a in com.args - if (len(a.flags) == 1 - and a.flags[0].startswith("-"))] + \ - [a.flags[1] - for a in com.args if len(a.flags) == 2] + self.assertListEqual( + [], + conflict_position, + f"Command group {group} function {com.name} have conflict " + f"position flags {conflict_position}", + ) + + long_option = [ + a.flags[0] for a in com.args if (len(a.flags) == 1 and a.flags[0].startswith("-")) + ] + [a.flags[1] for a in com.args if len(a.flags) == 2] conflict_long_option = [arg for arg, count in Counter(long_option).items() if count > 1] - self.assertListEqual([], conflict_long_option, - f"Command group {group} function {com.name} have conflict " - f"long option flags {conflict_long_option}") - - short_option = [ - a.flags[0] - for a in com.args if len(a.flags) == 2 - ] + self.assertListEqual( + [], + conflict_long_option, + f"Command group {group} function {com.name} have conflict " + f"long option flags {conflict_long_option}", + ) + + short_option = [a.flags[0] for a in com.args if len(a.flags) == 2] conflict_short_option = [arg for arg, count in Counter(short_option).items() if count > 1] - self.assertEqual([], conflict_short_option, - f"Command group {group} function {com.name} have conflict " - f"short option flags {conflict_short_option}") + self.assertEqual( + [], + conflict_short_option, + f"Command group {group} function 
{com.name} have conflict " + f"short option flags {conflict_short_option}", + ) def test_falsy_default_value(self): arg = cli_parser.Arg(("--test",), default=0, type=int) @@ -166,9 +163,7 @@ def test_should_display_help(self): for command_as_args in ( [[top_command.name]] if isinstance(top_command, cli_parser.ActionCommand) - else [ - [top_command.name, nested_command.name] for nested_command in top_command.subcommands - ] + else [[top_command.name, nested_command.name] for nested_command in top_command.subcommands] ) ] for cmd_args in all_command_as_args: diff --git a/tests/cluster_policies/__init__.py b/tests/cluster_policies/__init__.py index 4d9c4b356b020..77e5253f2254b 100644 --- a/tests/cluster_policies/__init__.py +++ b/tests/cluster_policies/__init__.py @@ -24,10 +24,12 @@ # [START example_cluster_policy_rule] def task_must_have_owners(task: BaseOperator): - if not task.owner or task.owner.lower() == conf.get('operators', - 'default_owner'): + if not task.owner or task.owner.lower() == conf.get('operators', 'default_owner'): raise AirflowClusterPolicyViolation( - f'''Task must have non-None non-default owner. Current value: {task.owner}''') + f'''Task must have non-None non-default owner. Current value: {task.owner}''' + ) + + # [END example_cluster_policy_rule] @@ -50,10 +52,13 @@ def _check_task_rules(current_task: BaseOperator): raise AirflowClusterPolicyViolation( f"DAG policy violation (DAG ID: {current_task.dag_id}, Path: {current_task.dag.filepath}):\n" f"Notices:\n" - f"{notices_list}") + f"{notices_list}" + ) def cluster_policy(task: BaseOperator): """Ensure Tasks have non-default owners.""" _check_task_rules(task) + + # [END example_list_of_cluster_policy_rules] diff --git a/tests/conftest.py b/tests/conftest.py index dee74c1e327cb..98a7d1f5b0478 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,11 +30,12 @@ os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = os.path.join(tests_directory, "dags") os.environ["AIRFLOW__CORE__UNIT_TEST_MODE"] = "True" -os.environ["AWS_DEFAULT_REGION"] = (os.environ.get("AWS_DEFAULT_REGION") or "us-east-1") -os.environ["CREDENTIALS_DIR"] = (os.environ.get('CREDENTIALS_DIR') or "/files/airflow-breeze-config/keys") +os.environ["AWS_DEFAULT_REGION"] = os.environ.get("AWS_DEFAULT_REGION") or "us-east-1" +os.environ["CREDENTIALS_DIR"] = os.environ.get('CREDENTIALS_DIR') or "/files/airflow-breeze-config/keys" from tests.test_utils.perf.perf_kit.sqlalchemy import ( # noqa isort:skip # pylint: disable=wrong-import-position - count_queries, trace_queries + count_queries, + trace_queries, ) @@ -60,6 +61,7 @@ def reset_db(): """ from airflow.utils import db + db.resetdb() yield @@ -94,11 +96,7 @@ def pytest_print(text): if columns == ['num']: # It is very unlikely that the user wants to display only numbers, but probably # the user just wants to count the queries. 
- exit_stack.enter_context( # pylint: disable=no-member - count_queries( - print_fn=pytest_print - ) - ) + exit_stack.enter_context(count_queries(print_fn=pytest_print)) # pylint: disable=no-member elif any(c for c in ['time', 'trace', 'sql', 'parameters']): exit_stack.enter_context( # pylint: disable=no-member trace_queries( @@ -107,7 +105,7 @@ def pytest_print(text): display_trace='trace' in columns, display_sql='sql' in columns, display_parameters='parameters' in columns, - print_fn=pytest_print + print_fn=pytest_print, ) ) @@ -130,7 +128,7 @@ def pytest_addoption(parser): action="append", metavar="INTEGRATIONS", help="only run tests matching integration specified: " - "[cassandra,kerberos,mongo,openldap,presto,rabbitmq,redis]. ", + "[cassandra,kerberos,mongo,openldap,presto,rabbitmq,redis]. ", ) group.addoption( "--backend", @@ -177,6 +175,7 @@ def initial_db_init(): os.system("airflow resetdb -y") else: from airflow.utils import db + db.resetdb() @@ -193,6 +192,7 @@ def breeze_test_helper(request): return from airflow import __version__ + if __version__.startswith("1.10"): os.environ['RUN_AIRFLOW_1_10'] = "true" @@ -235,18 +235,10 @@ def breeze_test_helper(request): def pytest_configure(config): - config.addinivalue_line( - "markers", "integration(name): mark test to run with named integration" - ) - config.addinivalue_line( - "markers", "backend(name): mark test to run with named backend" - ) - config.addinivalue_line( - "markers", "system(name): mark test to run with named system" - ) - config.addinivalue_line( - "markers", "long_running: mark test that run for a long time (many minutes)" - ) + config.addinivalue_line("markers", "integration(name): mark test to run with named integration") + config.addinivalue_line("markers", "backend(name): mark test to run with named backend") + config.addinivalue_line("markers", "system(name): mark test to run with named system") + config.addinivalue_line("markers", "long_running: mark test that run for a long time (many minutes)") config.addinivalue_line( "markers", "quarantined: mark test that are in quarantine (i.e. flaky, need to be isolated and fixed)" ) @@ -256,9 +248,7 @@ def pytest_configure(config): config.addinivalue_line( "markers", "credential_file(name): mark tests that require credential file in CREDENTIALS_DIR" ) - config.addinivalue_line( - "markers", "airflow_2: mark tests that works only on Airflow 2.0 / master" - ) + config.addinivalue_line("markers", "airflow_2: mark tests that works only on Airflow 2.0 / master") def skip_if_not_marked_with_integration(selected_integrations, item): @@ -266,10 +256,11 @@ def skip_if_not_marked_with_integration(selected_integrations, item): integration_name = marker.args[0] if integration_name in selected_integrations or "all" in selected_integrations: return - pytest.skip("The test is skipped because it does not have the right integration marker. " - "Only tests marked with pytest.mark.integration(INTEGRATION) are run with INTEGRATION" - " being one of {integration}. {item}". - format(integration=selected_integrations, item=item)) + pytest.skip( + "The test is skipped because it does not have the right integration marker. " + "Only tests marked with pytest.mark.integration(INTEGRATION) are run with INTEGRATION" + " being one of {integration}. 
{item}".format(integration=selected_integrations, item=item) + ) def skip_if_not_marked_with_backend(selected_backend, item): @@ -277,10 +268,11 @@ def skip_if_not_marked_with_backend(selected_backend, item): backend_names = marker.args if selected_backend in backend_names: return - pytest.skip("The test is skipped because it does not have the right backend marker " - "Only tests marked with pytest.mark.backend('{backend}') are run" - ": {item}". - format(backend=selected_backend, item=item)) + pytest.skip( + "The test is skipped because it does not have the right backend marker " + "Only tests marked with pytest.mark.backend('{backend}') are run" + ": {item}".format(backend=selected_backend, item=item) + ) def skip_if_not_marked_with_system(selected_systems, item): @@ -288,39 +280,46 @@ def skip_if_not_marked_with_system(selected_systems, item): systems_name = marker.args[0] if systems_name in selected_systems or "all" in selected_systems: return - pytest.skip("The test is skipped because it does not have the right system marker. " - "Only tests marked with pytest.mark.system(SYSTEM) are run with SYSTEM" - " being one of {systems}. {item}". - format(systems=selected_systems, item=item)) + pytest.skip( + "The test is skipped because it does not have the right system marker. " + "Only tests marked with pytest.mark.system(SYSTEM) are run with SYSTEM" + " being one of {systems}. {item}".format(systems=selected_systems, item=item) + ) def skip_system_test(item): for marker in item.iter_markers(name="system"): - pytest.skip("The test is skipped because it has system marker. " - "System tests are only run when --system flag " - "with the right system ({system}) is passed to pytest. {item}". - format(system=marker.args[0], item=item)) + pytest.skip( + "The test is skipped because it has system marker. " + "System tests are only run when --system flag " + "with the right system ({system}) is passed to pytest. {item}".format( + system=marker.args[0], item=item + ) + ) def skip_long_running_test(item): for _ in item.iter_markers(name="long_running"): - pytest.skip("The test is skipped because it has long_running marker. " - "And --include-long-running flag is not passed to pytest. {item}". - format(item=item)) + pytest.skip( + "The test is skipped because it has long_running marker. " + "And --include-long-running flag is not passed to pytest. {item}".format(item=item) + ) def skip_quarantined_test(item): for _ in item.iter_markers(name="quarantined"): - pytest.skip("The test is skipped because it has quarantined marker. " - "And --include-quarantined flag is passed to pytest. {item}". - format(item=item)) + pytest.skip( + "The test is skipped because it has quarantined marker. " + "And --include-quarantined flag is passed to pytest. {item}".format(item=item) + ) def skip_heisen_test(item): for _ in item.iter_markers(name="heisentests"): - pytest.skip("The test is skipped because it has heisentests marker. " - "And --include-heisentests flag is passed to pytest. {item}". - format(item=item)) + pytest.skip( + "The test is skipped because it has heisentests marker. " + "And --include-heisentests flag is passed to pytest. 
{item}".format(item=item) + ) def skip_if_integration_disabled(marker, item): @@ -328,12 +327,17 @@ def skip_if_integration_disabled(marker, item): environment_variable_name = "INTEGRATION_" + integration_name.upper() environment_variable_value = os.environ.get(environment_variable_name) if not environment_variable_value or environment_variable_value != "true": - pytest.skip("The test requires {integration_name} integration started and " - "{name} environment variable to be set to true (it is '{value}')." - " It can be set by specifying '--integration {integration_name}' at breeze startup" - ": {item}". - format(name=environment_variable_name, value=environment_variable_value, - integration_name=integration_name, item=item)) + pytest.skip( + "The test requires {integration_name} integration started and " + "{name} environment variable to be set to true (it is '{value}')." + " It can be set by specifying '--integration {integration_name}' at breeze startup" + ": {item}".format( + name=environment_variable_name, + value=environment_variable_value, + integration_name=integration_name, + item=item, + ) + ) def skip_if_wrong_backend(marker, item): @@ -341,12 +345,17 @@ def skip_if_wrong_backend(marker, item): environment_variable_name = "BACKEND" environment_variable_value = os.environ.get(environment_variable_name) if not environment_variable_value or environment_variable_value not in valid_backend_names: - pytest.skip("The test requires one of {valid_backend_names} backend started and " - "{name} environment variable to be set to 'true' (it is '{value}')." - " It can be set by specifying backend at breeze startup" - ": {item}". - format(name=environment_variable_name, value=environment_variable_value, - valid_backend_names=valid_backend_names, item=item)) + pytest.skip( + "The test requires one of {valid_backend_names} backend started and " + "{name} environment variable to be set to 'true' (it is '{value}')." + " It can be set by specifying backend at breeze startup" + ": {item}".format( + name=environment_variable_name, + value=environment_variable_value, + valid_backend_names=valid_backend_names, + item=item, + ) + ) def skip_if_credential_file_missing(item): @@ -354,8 +363,7 @@ def skip_if_credential_file_missing(item): credential_file = marker.args[0] credential_path = os.path.join(os.environ.get('CREDENTIALS_DIR'), credential_file) if not os.path.exists(credential_path): - pytest.skip("The test requires credential file {path}: {item}". 
- format(path=credential_path, item=item)) + pytest.skip(f"The test requires credential file {credential_path}: {item}") def skip_if_airflow_2_test(item): diff --git a/tests/core/test_config_templates.py b/tests/core/test_config_templates.py index af6cc4cc91e12..f603739074bb5 100644 --- a/tests/core/test_config_templates.py +++ b/tests/core/test_config_templates.py @@ -69,24 +69,34 @@ 'admin', 'elasticsearch', 'elasticsearch_configs', - 'kubernetes' + 'kubernetes', ] class TestAirflowCfg(unittest.TestCase): - @parameterized.expand([ - ("default_airflow.cfg",), - ("default_test.cfg",), - ]) + @parameterized.expand( + [ + ("default_airflow.cfg",), + ("default_test.cfg",), + ] + ) def test_should_be_ascii_file(self, filename: str): with open(os.path.join(CONFIG_TEMPLATES_FOLDER, filename), "rb") as f: content = f.read().decode("ascii") self.assertTrue(content) - @parameterized.expand([ - ("default_airflow.cfg", DEFAULT_AIRFLOW_SECTIONS,), - ("default_test.cfg", DEFAULT_TEST_SECTIONS,), - ]) + @parameterized.expand( + [ + ( + "default_airflow.cfg", + DEFAULT_AIRFLOW_SECTIONS, + ), + ( + "default_test.cfg", + DEFAULT_TEST_SECTIONS, + ), + ] + ) def test_should_be_ini_file(self, filename: str, expected_sections): filepath = os.path.join(CONFIG_TEMPLATES_FOLDER, filename) config = configparser.ConfigParser() diff --git a/tests/core/test_configuration.py b/tests/core/test_configuration.py index 7a07c3be02105..e202ce5b51a77 100644 --- a/tests/core/test_configuration.py +++ b/tests/core/test_configuration.py @@ -26,48 +26,49 @@ from airflow import configuration from airflow.configuration import ( - DEFAULT_CONFIG, AirflowConfigException, AirflowConfigParser, conf, expand_env_var, get_airflow_config, - get_airflow_home, parameterized_config, run_command, + DEFAULT_CONFIG, + AirflowConfigException, + AirflowConfigParser, + conf, + expand_env_var, + get_airflow_config, + get_airflow_home, + parameterized_config, + run_command, ) from tests.test_utils.config import conf_vars from tests.test_utils.reset_warning_registry import reset_warning_registry -@unittest.mock.patch.dict('os.environ', { - 'AIRFLOW__TESTSECTION__TESTKEY': 'testvalue', - 'AIRFLOW__TESTSECTION__TESTPERCENT': 'with%percent', - 'AIRFLOW__TESTCMDENV__ITSACOMMAND_CMD': 'echo -n "OK"', - 'AIRFLOW__TESTCMDENV__NOTACOMMAND_CMD': 'echo -n "NOT OK"' -}) +@unittest.mock.patch.dict( + 'os.environ', + { + 'AIRFLOW__TESTSECTION__TESTKEY': 'testvalue', + 'AIRFLOW__TESTSECTION__TESTPERCENT': 'with%percent', + 'AIRFLOW__TESTCMDENV__ITSACOMMAND_CMD': 'echo -n "OK"', + 'AIRFLOW__TESTCMDENV__NOTACOMMAND_CMD': 'echo -n "NOT OK"', + }, +) class TestConf(unittest.TestCase): - def test_airflow_home_default(self): with unittest.mock.patch.dict('os.environ'): if 'AIRFLOW_HOME' in os.environ: del os.environ['AIRFLOW_HOME'] - self.assertEqual( - get_airflow_home(), - expand_env_var('~/airflow')) + self.assertEqual(get_airflow_home(), expand_env_var('~/airflow')) def test_airflow_home_override(self): with unittest.mock.patch.dict('os.environ', AIRFLOW_HOME='/path/to/airflow'): - self.assertEqual( - get_airflow_home(), - '/path/to/airflow') + self.assertEqual(get_airflow_home(), '/path/to/airflow') def test_airflow_config_default(self): with unittest.mock.patch.dict('os.environ'): if 'AIRFLOW_CONFIG' in os.environ: del os.environ['AIRFLOW_CONFIG'] - self.assertEqual( - get_airflow_config('/home/airflow'), - expand_env_var('/home/airflow/airflow.cfg')) + self.assertEqual(get_airflow_config('/home/airflow'), expand_env_var('/home/airflow/airflow.cfg')) def 
test_airflow_config_override(self): with unittest.mock.patch.dict('os.environ', AIRFLOW_CONFIG='/path/to/airflow/airflow.cfg'): - self.assertEqual( - get_airflow_config('/home//airflow'), - '/path/to/airflow/airflow.cfg') + self.assertEqual(get_airflow_config('/home//airflow'), '/path/to/airflow/airflow.cfg') @conf_vars({("core", "percent"): "with%%inside"}) def test_case_sensitivity(self): @@ -87,15 +88,13 @@ def test_env_var_config(self): self.assertTrue(conf.has_option('testsection', 'testkey')) with unittest.mock.patch.dict( - 'os.environ', - AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY='nested' + 'os.environ', AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY='nested' ): opt = conf.get('kubernetes_environment_variables', 'AIRFLOW__TESTSECTION__TESTKEY') self.assertEqual(opt, 'nested') @mock.patch.dict( - 'os.environ', - AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY='nested' + 'os.environ', AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY='nested' ) @conf_vars({("core", "percent"): "with%%inside"}) def test_conf_as_dict(self): @@ -109,18 +108,15 @@ def test_conf_as_dict(self): # test env vars self.assertEqual(cfg_dict['testsection']['testkey'], '< hidden >') self.assertEqual( - cfg_dict['kubernetes_environment_variables']['AIRFLOW__TESTSECTION__TESTKEY'], - '< hidden >') + cfg_dict['kubernetes_environment_variables']['AIRFLOW__TESTSECTION__TESTKEY'], '< hidden >' + ) def test_conf_as_dict_source(self): # test display_source cfg_dict = conf.as_dict(display_source=True) - self.assertEqual( - cfg_dict['core']['load_examples'][1], 'airflow.cfg') - self.assertEqual( - cfg_dict['core']['load_default_connections'][1], 'airflow.cfg') - self.assertEqual( - cfg_dict['testsection']['testkey'], ('< hidden >', 'env var')) + self.assertEqual(cfg_dict['core']['load_examples'][1], 'airflow.cfg') + self.assertEqual(cfg_dict['core']['load_default_connections'][1], 'airflow.cfg') + self.assertEqual(cfg_dict['testsection']['testkey'], ('< hidden >', 'env var')) def test_conf_as_dict_sensitive(self): # test display_sensitive @@ -130,8 +126,7 @@ def test_conf_as_dict_sensitive(self): # test display_source and display_sensitive cfg_dict = conf.as_dict(display_sensitive=True, display_source=True) - self.assertEqual( - cfg_dict['testsection']['testkey'], ('testvalue', 'env var')) + self.assertEqual(cfg_dict['testsection']['testkey'], ('testvalue', 'env var')) @conf_vars({("core", "percent"): "with%%inside"}) def test_conf_as_dict_raw(self): @@ -166,8 +161,7 @@ def test_command_precedence(self): key6 = value6 ''' - test_conf = AirflowConfigParser( - default_config=parameterized_config(test_config_default)) + test_conf = AirflowConfigParser(default_config=parameterized_config(test_config_default)) test_conf.read_string(test_config) test_conf.sensitive_config_values = test_conf.sensitive_config_values | { ('test', 'key2'), @@ -204,10 +198,12 @@ def test_command_precedence(self): self.assertEqual('printf key4_result', cfg_dict['test']['key4_cmd']) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - @conf_vars({ - ("secrets", "backend"): "airflow.providers.hashicorp.secrets.vault.VaultBackend", - ("secrets", "backend_kwargs"): '{"url": "http://127.0.0.1:8200", "token": "token"}', - }) + @conf_vars( + { + ("secrets", "backend"): "airflow.providers.hashicorp.secrets.vault.VaultBackend", + ("secrets", "backend_kwargs"): '{"url": "http://127.0.0.1:8200", "token": "token"}', + } + ) def 
test_config_from_secret_backend(self, mock_hvac): """Get Config Value from a Secret Backend""" mock_client = mock.MagicMock() @@ -217,14 +213,18 @@ def test_config_from_secret_backend(self, mock_hvac): 'lease_id': '', 'renewable': False, 'lease_duration': 0, - 'data': {'data': {'value': 'sqlite:////Users/airflow/airflow/airflow.db'}, - 'metadata': {'created_time': '2020-03-28T02:10:54.301784Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}}, + 'data': { + 'data': {'value': 'sqlite:////Users/airflow/airflow/airflow.db'}, + 'metadata': { + 'created_time': '2020-03-28T02:10:54.301784Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + }, 'wrap_info': None, 'warnings': None, - 'auth': None + 'auth': None, } test_config = '''[test] @@ -241,7 +241,8 @@ def test_config_from_secret_backend(self, mock_hvac): } self.assertEqual( - 'sqlite:////Users/airflow/airflow/airflow.db', test_conf.get('test', 'sql_alchemy_conn')) + 'sqlite:////Users/airflow/airflow/airflow.db', test_conf.get('test', 'sql_alchemy_conn') + ) def test_getboolean(self): """Test AirflowConfigParser.getboolean""" @@ -268,7 +269,7 @@ def test_getboolean(self): re.escape( 'Failed to convert value to bool. Please check "key1" key in "type_validation" section. ' 'Current value: "non_bool_value".' - ) + ), ): test_conf.getboolean('type_validation', 'key1') self.assertTrue(isinstance(test_conf.getboolean('true', 'key3'), bool)) @@ -295,7 +296,7 @@ def test_getint(self): re.escape( 'Failed to convert value to int. Please check "key1" key in "invalid" section. ' 'Current value: "str".' - ) + ), ): test_conf.getint('invalid', 'key1') self.assertTrue(isinstance(test_conf.getint('valid', 'key2'), int)) @@ -316,7 +317,7 @@ def test_getfloat(self): re.escape( 'Failed to convert value to float. Please check "key1" key in "invalid" section. ' 'Current value: "str".' 
- ) + ), ): test_conf.getfloat('invalid', 'key1') self.assertTrue(isinstance(test_conf.getfloat('valid', 'key2'), float)) @@ -342,8 +343,7 @@ def test_remove_option(self): key2 = airflow ''' - test_conf = AirflowConfigParser( - default_config=parameterized_config(test_config_default)) + test_conf = AirflowConfigParser(default_config=parameterized_config(test_config_default)) test_conf.read_string(test_config) self.assertEqual('hello', test_conf.get('test', 'key1')) @@ -366,20 +366,13 @@ def test_getsection(self): [testsection] key3 = value3 ''' - test_conf = AirflowConfigParser( - default_config=parameterized_config(test_config_default)) + test_conf = AirflowConfigParser(default_config=parameterized_config(test_config_default)) test_conf.read_string(test_config) + self.assertEqual(OrderedDict([('key1', 'hello'), ('key2', 'airflow')]), test_conf.getsection('test')) self.assertEqual( - OrderedDict([('key1', 'hello'), ('key2', 'airflow')]), - test_conf.getsection('test') - ) - self.assertEqual( - OrderedDict([ - ('key3', 'value3'), - ('testkey', 'testvalue'), - ('testpercent', 'with%percent')]), - test_conf.getsection('testsection') + OrderedDict([('key3', 'value3'), ('testkey', 'testvalue'), ('testpercent', 'with%percent')]), + test_conf.getsection('testsection'), ) def test_get_section_should_respect_cmd_env_variable(self): @@ -390,9 +383,7 @@ def test_get_section_should_respect_cmd_env_variable(self): os.chmod(cmd_file.name, 0o0555) cmd_file.close() - with mock.patch.dict( - "os.environ", {"AIRFLOW__KUBERNETES__GIT_PASSWORD_CMD": cmd_file.name} - ): + with mock.patch.dict("os.environ", {"AIRFLOW__KUBERNETES__GIT_PASSWORD_CMD": cmd_file.name}): content = conf.getsection("kubernetes") os.unlink(cmd_file.name) self.assertEqual(content["git_password"], "difficult_unpredictable_cat_password") @@ -406,13 +397,12 @@ def test_kubernetes_environment_variables_section(self): test_config_default = ''' [kubernetes_environment_variables] ''' - test_conf = AirflowConfigParser( - default_config=parameterized_config(test_config_default)) + test_conf = AirflowConfigParser(default_config=parameterized_config(test_config_default)) test_conf.read_string(test_config) self.assertEqual( OrderedDict([('key1', 'hello'), ('AIRFLOW_HOME', '/root/airflow')]), - test_conf.getsection('kubernetes_environment_variables') + test_conf.getsection('kubernetes_environment_variables'), ) def test_broker_transport_options(self): @@ -422,10 +412,12 @@ def test_broker_transport_options(self): self.assertTrue(isinstance(section_dict['_test_only_float'], float)) self.assertTrue(isinstance(section_dict['_test_only_string'], str)) - @conf_vars({ - ("celery", "worker_concurrency"): None, - ("celery", "celeryd_concurrency"): None, - }) + @conf_vars( + { + ("celery", "worker_concurrency"): None, + ("celery", "celeryd_concurrency"): None, + } + ) def test_deprecated_options(self): # Guarantee we have a deprecated setting, so we test the deprecation # lookup even if we remove this explicit fallback @@ -443,10 +435,12 @@ def test_deprecated_options(self): with self.assertWarns(DeprecationWarning), conf_vars({('celery', 'celeryd_concurrency'): '99'}): self.assertEqual(conf.getint('celery', 'worker_concurrency'), 99) - @conf_vars({ - ('logging', 'logging_level'): None, - ('core', 'logging_level'): None, - }) + @conf_vars( + { + ('logging', 'logging_level'): None, + ('core', 'logging_level'): None, + } + ) def test_deprecated_options_with_new_section(self): # Guarantee we have a deprecated setting, so we test the deprecation # lookup even if 
we remove this explicit fallback @@ -465,11 +459,13 @@ def test_deprecated_options_with_new_section(self): with self.assertWarns(DeprecationWarning), conf_vars({('core', 'logging_level'): 'VALUE'}): self.assertEqual(conf.get('logging', 'logging_level'), "VALUE") - @conf_vars({ - ("celery", "result_backend"): None, - ("celery", "celery_result_backend"): None, - ("celery", "celery_result_backend_cmd"): None, - }) + @conf_vars( + { + ("celery", "result_backend"): None, + ("celery", "celery_result_backend"): None, + ("celery", "celery_result_backend_cmd"): None, + } + ) def test_deprecated_options_cmd(self): # Guarantee we have a deprecated setting, so we test the deprecation # lookup even if we remove this explicit fallback @@ -497,14 +493,16 @@ def make_config(): 'hostname_callable': (re.compile(r':'), r'.', '2.0'), }, } - test_conf.read_dict({ - 'core': { - 'executor': 'SequentialExecutor', - 'task_runner': 'BashTaskRunner', - 'sql_alchemy_conn': 'sqlite://', - 'hostname_callable': 'socket:getfqdn', - }, - }) + test_conf.read_dict( + { + 'core': { + 'executor': 'SequentialExecutor', + 'task_runner': 'BashTaskRunner', + 'sql_alchemy_conn': 'sqlite://', + 'hostname_callable': 'socket:getfqdn', + }, + } + ) return test_conf with self.assertWarns(FutureWarning): @@ -526,9 +524,11 @@ def make_config(): with reset_warning_registry(): with warnings.catch_warnings(record=True) as warning: - with unittest.mock.patch.dict('os.environ', - AIRFLOW__CORE__TASK_RUNNER='NotBashTaskRunner', - AIRFLOW__CORE__HOSTNAME_CALLABLE='CarrierPigeon'): + with unittest.mock.patch.dict( + 'os.environ', + AIRFLOW__CORE__TASK_RUNNER='NotBashTaskRunner', + AIRFLOW__CORE__HOSTNAME_CALLABLE='CarrierPigeon', + ): test_conf = make_config() self.assertEqual(test_conf.get('core', 'task_runner'), 'NotBashTaskRunner') @@ -537,8 +537,17 @@ def make_config(): self.assertListEqual([], warning) def test_deprecated_funcs(self): - for func in ['load_test_config', 'get', 'getboolean', 'getfloat', 'getint', 'has_option', - 'remove_option', 'as_dict', 'set']: + for func in [ + 'load_test_config', + 'get', + 'getboolean', + 'getfloat', + 'getint', + 'has_option', + 'remove_option', + 'as_dict', + 'set', + ]: with mock.patch(f'airflow.configuration.conf.{func}') as mock_method: with self.assertWarns(DeprecationWarning): getattr(configuration, func)() @@ -583,10 +592,7 @@ def test_config_use_original_when_original_and_fallback_are_present(self): fernet_key = conf.get('core', 'FERNET_KEY') with conf_vars({('core', 'FERNET_KEY_CMD'): 'printf HELLO'}): - fallback_fernet_key = conf.get( - "core", - "FERNET_KEY" - ) + fallback_fernet_key = conf.get("core", "FERNET_KEY") self.assertEqual(fernet_key, fallback_fernet_key) @@ -648,10 +654,7 @@ def test_store_dag_code_default_config(self): self.assertTrue(store_serialized_dags) self.assertTrue(store_dag_code) - @conf_vars({ - ("core", "store_serialized_dags"): "True", - ("core", "store_dag_code"): "False" - }) + @conf_vars({("core", "store_serialized_dags"): "True", ("core", "store_dag_code"): "False"}) def test_store_dag_code_config_when_set(self): store_serialized_dags = conf.getboolean('core', 'store_serialized_dags', fallback=False) store_dag_code = conf.getboolean("core", "store_dag_code", fallback=store_serialized_dags) diff --git a/tests/core/test_core.py b/tests/core/test_core.py index 71a22368d7461..44f4aa2f9633a 100644 --- a/tests/core/test_core.py +++ b/tests/core/test_core.py @@ -54,6 +54,7 @@ class OperatorSubclass(BaseOperator): """ An operator to test template substitution """ + 
template_fields = ['some_templated_field'] def __init__(self, some_templated_field, *args, **kwargs): @@ -78,15 +79,11 @@ def setUp(self): def tearDown(self): session = Session() - session.query(DagRun).filter( - DagRun.dag_id == TEST_DAG_ID).delete( - synchronize_session=False) - session.query(TaskInstance).filter( - TaskInstance.dag_id == TEST_DAG_ID).delete( - synchronize_session=False) - session.query(TaskFail).filter( - TaskFail.dag_id == TEST_DAG_ID).delete( - synchronize_session=False) + session.query(DagRun).filter(DagRun.dag_id == TEST_DAG_ID).delete(synchronize_session=False) + session.query(TaskInstance).filter(TaskInstance.dag_id == TEST_DAG_ID).delete( + synchronize_session=False + ) + session.query(TaskFail).filter(TaskFail.dag_id == TEST_DAG_ID).delete(synchronize_session=False) session.commit() session.close() clear_db_dags() @@ -102,10 +99,8 @@ def test_check_operators(self): self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) op = CheckOperator( - task_id='check', - sql="select count(*) from operator_test_table", - conn_id=conn_id, - dag=self.dag) + task_id='check', sql="select count(*) from operator_test_table", conn_id=conn_id, dag=self.dag + ) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -115,16 +110,15 @@ def test_check_operators(self): tolerance=0.1, conn_id=conn_id, sql="SELECT 100", - dag=self.dag) + dag=self.dag, + ) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) captain_hook.run("drop table operator_test_table") def test_clear_api(self): task = self.dag_bash.tasks[0] - task.clear( - start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - upstream=True, downstream=True) + task.clear(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, upstream=True, downstream=True) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.are_dependents_done() @@ -139,7 +133,8 @@ def test_illegal_args(self): task_id='test_illegal_args', bash_command='echo success', dag=self.dag, - illegal_argument_1234='hello?') + illegal_argument_1234='hello?', + ) assert any(msg in str(w) for w in warning.warnings) def test_illegal_args_forbidden(self): @@ -152,17 +147,15 @@ def test_illegal_args_forbidden(self): task_id='test_illegal_args', bash_command='echo success', dag=self.dag, - illegal_argument_1234='hello?') + illegal_argument_1234='hello?', + ) self.assertIn( - ('Invalid arguments were passed to BashOperator ' - '(task_id: test_illegal_args).'), - str(ctx.exception)) + ('Invalid arguments were passed to BashOperator ' '(task_id: test_illegal_args).'), + str(ctx.exception), + ) def test_bash_operator(self): - op = BashOperator( - task_id='test_bash_operator', - bash_command="echo success", - dag=self.dag) + op = BashOperator(task_id='test_bash_operator', bash_command="echo success", dag=self.dag) self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -172,22 +165,22 @@ def test_bash_operator_multi_byte_output(self): task_id='test_multi_byte_bash_operator', bash_command="echo \u2600", dag=self.dag, - output_encoding='utf-8') + output_encoding='utf-8', + ) self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_bash_operator_kill(self): import psutil + sleep_time = "100%d" % os.getpid() op = BashOperator( 
task_id='test_bash_operator_kill', execution_timeout=timedelta(seconds=1), bash_command="/bin/bash -c 'sleep %s'" % sleep_time, - dag=self.dag) - self.assertRaises( - AirflowTaskTimeout, - op.run, - start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag=self.dag, + ) + self.assertRaises(AirflowTaskTimeout, op.run, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) sleep(2) pid = -1 for proc in psutil.process_iter(): @@ -210,26 +203,23 @@ def check_failure(context, test_case=self): task_id='check_on_failure_callback', bash_command="exit 1", dag=self.dag, - on_failure_callback=check_failure) + on_failure_callback=check_failure, + ) self.assertRaises( - AirflowException, - op.run, - start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + AirflowException, op.run, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True + ) self.assertTrue(data['called']) def test_dryrun(self): - op = BashOperator( - task_id='test_dryrun', - bash_command="echo success", - dag=self.dag) + op = BashOperator(task_id='test_dryrun', bash_command="echo success", dag=self.dag) op.dry_run() def test_sqlite(self): import airflow.providers.sqlite.operators.sqlite + op = airflow.providers.sqlite.operators.sqlite.SqliteOperator( - task_id='time_sqlite', - sql="CREATE TABLE IF NOT EXISTS unitest (dummy VARCHAR(20))", - dag=self.dag) + task_id='time_sqlite', sql="CREATE TABLE IF NOT EXISTS unitest (dummy VARCHAR(20))", dag=self.dag + ) self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -238,11 +228,11 @@ def test_timeout(self): task_id='test_timeout', execution_timeout=timedelta(seconds=1), python_callable=lambda: sleep(5), - dag=self.dag) + dag=self.dag, + ) self.assertRaises( - AirflowTaskTimeout, - op.run, - start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + AirflowTaskTimeout, op.run, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True + ) def test_python_op(self): def test_py_op(templates_dict, ds, **kwargs): @@ -250,25 +240,20 @@ def test_py_op(templates_dict, ds, **kwargs): raise Exception("failure") op = PythonOperator( - task_id='test_py_op', - python_callable=test_py_op, - templates_dict={'ds': "{{ ds }}"}, - dag=self.dag) + task_id='test_py_op', python_callable=test_py_op, templates_dict={'ds': "{{ ds }}"}, dag=self.dag + ) self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_complex_template(self): def verify_templated_field(context): - self.assertEqual(context['ti'].task.some_templated_field['bar'][1], - context['ds']) + self.assertEqual(context['ti'].task.some_templated_field['bar'][1], context['ds']) op = OperatorSubclass( task_id='test_complex_template', - some_templated_field={ - 'foo': '123', - 'bar': ['baz', '{{ ds }}'] - }, - dag=self.dag) + some_templated_field={'foo': '123', 'bar': ['baz', '{{ ds }}']}, + dag=self.dag, + ) op.execute = verify_templated_field self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -286,20 +271,16 @@ def __bool__(self): # pylint: disable=invalid-bool-returned, bad-option-value return NotImplemented op = OperatorSubclass( - task_id='test_bad_template_obj', - some_templated_field=NonBoolObject(), - dag=self.dag) + 
task_id='test_bad_template_obj', some_templated_field=NonBoolObject(), dag=self.dag + ) op.resolve_template_files() def test_task_get_template(self): TI = TaskInstance - ti = TI( - task=self.runme_0, execution_date=DEFAULT_DATE) + ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE) ti.dag = self.dag_bash self.dag_bash.create_dagrun( - run_type=DagRunType.MANUAL, - state=State.RUNNING, - execution_date=DEFAULT_DATE + run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE ) ti.run(ignore_ti_state=True) context = ti.get_template_context() @@ -328,20 +309,16 @@ def test_task_get_template(self): def test_local_task_job(self): TI = TaskInstance - ti = TI( - task=self.runme_0, execution_date=DEFAULT_DATE) + ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE) job = LocalTaskJob(task_instance=ti, ignore_ti_state=True) job.run() def test_raw_job(self): TI = TaskInstance - ti = TI( - task=self.runme_0, execution_date=DEFAULT_DATE) + ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE) ti.dag = self.dag_bash self.dag_bash.create_dagrun( - run_type=DagRunType.MANUAL, - state=State.RUNNING, - execution_date=DEFAULT_DATE + run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE ) ti.run(ignore_ti_state=True) @@ -353,20 +330,16 @@ def test_round_time(self): rt2 = round_time(datetime(2015, 1, 2), relativedelta(months=1)) self.assertEqual(datetime(2015, 1, 1, 0, 0), rt2) - rt3 = round_time(datetime(2015, 9, 16, 0, 0), timedelta(1), datetime( - 2015, 9, 14, 0, 0)) + rt3 = round_time(datetime(2015, 9, 16, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0)) self.assertEqual(datetime(2015, 9, 16, 0, 0), rt3) - rt4 = round_time(datetime(2015, 9, 15, 0, 0), timedelta(1), datetime( - 2015, 9, 14, 0, 0)) + rt4 = round_time(datetime(2015, 9, 15, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0)) self.assertEqual(datetime(2015, 9, 15, 0, 0), rt4) - rt5 = round_time(datetime(2015, 9, 14, 0, 0), timedelta(1), datetime( - 2015, 9, 14, 0, 0)) + rt5 = round_time(datetime(2015, 9, 14, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0)) self.assertEqual(datetime(2015, 9, 14, 0, 0), rt5) - rt6 = round_time(datetime(2015, 9, 13, 0, 0), timedelta(1), datetime( - 2015, 9, 14, 0, 0)) + rt6 = round_time(datetime(2015, 9, 13, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0)) self.assertEqual(datetime(2015, 9, 14, 0, 0), rt6) def test_infer_time_unit(self): @@ -390,29 +363,25 @@ def test_scale_time_units(self): assert_array_almost_equal(arr2, [110.0, 50.0, 10.0, 100.0], decimal=3) arr3 = scale_time_units([100000, 50000, 10000, 20000], 'hours') - assert_array_almost_equal(arr3, [27.778, 13.889, 2.778, 5.556], - decimal=3) + assert_array_almost_equal(arr3, [27.778, 13.889, 2.778, 5.556], decimal=3) arr4 = scale_time_units([200000, 100000], 'days') assert_array_almost_equal(arr4, [2.315, 1.157], decimal=3) def test_bad_trigger_rule(self): with self.assertRaises(AirflowException): - DummyOperator( - task_id='test_bad_trigger', - trigger_rule="non_existent", - dag=self.dag) + DummyOperator(task_id='test_bad_trigger', trigger_rule="non_existent", dag=self.dag) def test_terminate_task(self): """If a task instance's db state get deleted, it should fail""" from airflow.executors.sequential_executor import SequentialExecutor + TI = TaskInstance dag = self.dagbag.dags.get('test_utils') task = dag.task_dict.get('sleeps_forever') ti = TI(task=task, execution_date=DEFAULT_DATE) - job = LocalTaskJob( - task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) + job = 
LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) # Running task instance asynchronously proc = multiprocessing.Process(target=job.run) @@ -423,11 +392,11 @@ def test_terminate_task(self): ti.refresh_from_db(session=session) # making sure it's actually running self.assertEqual(State.RUNNING, ti.state) - ti = session.query(TI).filter_by( - dag_id=task.dag_id, - task_id=task.task_id, - execution_date=DEFAULT_DATE - ).one() + ti = ( + session.query(TI) + .filter_by(dag_id=task.dag_id, task_id=task.task_id, execution_date=DEFAULT_DATE) + .one() + ) # deleting the instance should result in a failure session.delete(ti) @@ -443,16 +412,14 @@ def test_terminate_task(self): def test_task_fail_duration(self): """If a task fails, the duration should be recorded in TaskFail""" - op1 = BashOperator( - task_id='pass_sleepy', - bash_command='sleep 3', - dag=self.dag) + op1 = BashOperator(task_id='pass_sleepy', bash_command='sleep 3', dag=self.dag) op2 = BashOperator( task_id='fail_sleepy', bash_command='sleep 5', execution_timeout=timedelta(seconds=3), retry_delay=timedelta(seconds=0), - dag=self.dag) + dag=self.dag, + ) session = settings.Session() try: op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -462,14 +429,16 @@ def test_task_fail_duration(self): op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) except Exception: # pylint: disable=broad-except pass - op1_fails = session.query(TaskFail).filter_by( - task_id='pass_sleepy', - dag_id=self.dag.dag_id, - execution_date=DEFAULT_DATE).all() - op2_fails = session.query(TaskFail).filter_by( - task_id='fail_sleepy', - dag_id=self.dag.dag_id, - execution_date=DEFAULT_DATE).all() + op1_fails = ( + session.query(TaskFail) + .filter_by(task_id='pass_sleepy', dag_id=self.dag.dag_id, execution_date=DEFAULT_DATE) + .all() + ) + op2_fails = ( + session.query(TaskFail) + .filter_by(task_id='fail_sleepy', dag_id=self.dag.dag_id, execution_date=DEFAULT_DATE) + .all() + ) self.assertEqual(0, len(op1_fails)) self.assertEqual(1, len(op2_fails)) @@ -484,18 +453,16 @@ def test_externally_triggered_dagrun(self): execution_ds_nodash = execution_ds.replace('-', '') dag = DAG( - TEST_DAG_ID, - default_args=self.args, - schedule_interval=timedelta(weeks=1), - start_date=DEFAULT_DATE) - task = DummyOperator(task_id='test_externally_triggered_dag_context', - dag=dag) - dag.create_dagrun(run_type=DagRunType.SCHEDULED, - execution_date=execution_date, - state=State.RUNNING, - external_trigger=True) - task.run( - start_date=execution_date, end_date=execution_date) + TEST_DAG_ID, default_args=self.args, schedule_interval=timedelta(weeks=1), start_date=DEFAULT_DATE + ) + task = DummyOperator(task_id='test_externally_triggered_dag_context', dag=dag) + dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=execution_date, + state=State.RUNNING, + external_trigger=True, + ) + task.run(start_date=execution_date, end_date=execution_date) ti = TI(task=task, execution_date=execution_date) context = ti.get_template_context() diff --git a/tests/core/test_core_to_contrib.py b/tests/core/test_core_to_contrib.py index e44287ccc7b26..e91e69d987525 100644 --- a/tests/core/test_core_to_contrib.py +++ b/tests/core/test_core_to_contrib.py @@ -35,9 +35,7 @@ def assert_warning(msg: str, warning: Any): assert any(msg in str(w) for w in warning.warnings), error def assert_is_subclass(self, clazz, other): - self.assertTrue( - issubclass(clazz, other), f"{clazz} is not subclass of {other}" - ) + 
self.assertTrue(issubclass(clazz, other), f"{clazz} is not subclass of {other}") def assert_proper_import(self, old_resource, new_resource): new_path, _, _ = new_resource.rpartition(".") @@ -66,9 +64,7 @@ def get_class_from_path(path_to_class, parent=False): if isabstract(class_) and not parent: class_name = f"Mock({class_.__name__})" - attributes = { - a: mock.MagicMock() for a in class_.__abstractmethods__ - } + attributes = {a: mock.MagicMock() for a in class_.__abstractmethods__} new_class = type(class_name, (class_,), attributes) return new_class @@ -111,9 +107,7 @@ def test_no_redirect_to_deprecated_classes(self): This will tell us to use new_A instead of old_B. """ - all_classes_by_old = { - old: new for new, old in ALL - } + all_classes_by_old = {old: new for new, old in ALL} for new, old in ALL: # Using if statement allows us to create a developer-friendly message only when we need it. diff --git a/tests/core/test_example_dags_system.py b/tests/core/test_example_dags_system.py index 0ce411d5e5222..4d16f0c24e6d8 100644 --- a/tests/core/test_example_dags_system.py +++ b/tests/core/test_example_dags_system.py @@ -23,12 +23,14 @@ @pytest.mark.system("core") class TestExampleDagsSystem(SystemTest): - @parameterized.expand([ - "example_bash_operator", - "example_branch_operator", - "tutorial_etl_dag", - "tutorial_functional_etl_dag", - "example_dag_decorator", - ]) + @parameterized.expand( + [ + "example_bash_operator", + "example_branch_operator", + "tutorial_etl_dag", + "tutorial_functional_etl_dag", + "example_dag_decorator", + ] + ) def test_dag_example(self, dag_id): self.run_dag(dag_id=dag_id) diff --git a/tests/core/test_impersonation_tests.py b/tests/core/test_impersonation_tests.py index 56c5cdf43b0c3..5fe10e03c5afd 100644 --- a/tests/core/test_impersonation_tests.py +++ b/tests/core/test_impersonation_tests.py @@ -35,12 +35,9 @@ from airflow.utils.timezone import datetime DEV_NULL = '/dev/null' -TEST_DAG_FOLDER = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'dags') -TEST_DAG_CORRUPTED_FOLDER = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'dags_corrupted') -TEST_UTILS_FOLDER = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'test_utils') +TEST_DAG_FOLDER = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'dags') +TEST_DAG_CORRUPTED_FOLDER = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'dags_corrupted') +TEST_UTILS_FOLDER = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_utils') DEFAULT_DATE = datetime(2015, 1, 1) TEST_USER = 'airflow_test_user' @@ -54,6 +51,7 @@ def mock_custom_module_path(path: str): the :envvar:`PYTHONPATH` environment variable set and sets the environment variable :envvar:`PYTHONPATH` to change the module load directory for child scripts. 
""" + def wrapper(func): @functools.wraps(func) def decorator(*args, **kwargs): @@ -64,6 +62,7 @@ def decorator(*args, **kwargs): return func(*args, **kwargs) finally: sys.path = copy_sys_path + return decorator return wrapper @@ -72,28 +71,31 @@ def decorator(*args, **kwargs): def grant_permissions(): airflow_home = os.environ['AIRFLOW_HOME'] subprocess.check_call( - 'find "%s" -exec sudo chmod og+w {} +; sudo chmod og+rx /root' % airflow_home, shell=True) + 'find "%s" -exec sudo chmod og+w {} +; sudo chmod og+rx /root' % airflow_home, shell=True + ) def revoke_permissions(): airflow_home = os.environ['AIRFLOW_HOME'] subprocess.check_call( - 'find "%s" -exec sudo chmod og-w {} +; sudo chmod og-rx /root' % airflow_home, shell=True) + 'find "%s" -exec sudo chmod og-w {} +; sudo chmod og-rx /root' % airflow_home, shell=True + ) def check_original_docker_image(): if not os.path.isfile('/.dockerenv') or os.environ.get('PYTHON_BASE_IMAGE') is None: - raise unittest.SkipTest("""Adding/removing a user as part of a test is very bad for host os + raise unittest.SkipTest( + """Adding/removing a user as part of a test is very bad for host os (especially if the user already existed to begin with on the OS), therefore we check if we run inside a the official docker container and only allow to run the test there. This is done by checking /.dockerenv file (always present inside container) and checking for PYTHON_BASE_IMAGE variable. -""") +""" + ) def create_user(): try: - subprocess.check_output(['sudo', 'useradd', '-m', TEST_USER, '-g', - str(os.getegid())]) + subprocess.check_output(['sudo', 'useradd', '-m', TEST_USER, '-g', str(os.getegid())]) except OSError as e: if e.errno == errno.ENOENT: raise unittest.SkipTest( @@ -111,7 +113,6 @@ def create_user(): @pytest.mark.quarantined class TestImpersonation(unittest.TestCase): - def setUp(self): check_original_docker_image() grant_permissions() @@ -133,14 +134,9 @@ def run_backfill(self, dag_id, task_id): dag = self.dagbag.get_dag(dag_id) dag.clear() - BackfillJob( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE).run() + BackfillJob(dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE).run() - ti = models.TaskInstance( - task=dag.get_task(task_id), - execution_date=DEFAULT_DATE) + ti = models.TaskInstance(task=dag.get_task(task_id), execution_date=DEFAULT_DATE) ti.refresh_from_db() self.assertEqual(ti.state, State.SUCCESS) @@ -149,10 +145,7 @@ def test_impersonation(self): """ Tests that impersonating a unix user works """ - self.run_backfill( - 'test_impersonation', - 'test_impersonated_user' - ) + self.run_backfill('test_impersonation', 'test_impersonated_user') def test_no_impersonation(self): """ @@ -170,25 +163,18 @@ def test_default_impersonation(self): If default_impersonation=TEST_USER, tests that the job defaults to running as TEST_USER for a test without run_as_user set """ - self.run_backfill( - 'test_default_impersonation', - 'test_deelevated_user' - ) + self.run_backfill('test_default_impersonation', 'test_deelevated_user') def test_impersonation_subdag(self): """ Tests that impersonation using a subdag correctly passes the right configuration :return: """ - self.run_backfill( - 'impersonation_subdag', - 'test_subdag_operation' - ) + self.run_backfill('impersonation_subdag', 'test_subdag_operation') @pytest.mark.quarantined class TestImpersonationWithCustomPythonPath(unittest.TestCase): - @mock_custom_module_path(TEST_UTILS_FOLDER) def setUp(self): check_original_docker_image() @@ -211,14 +197,9 @@ def run_backfill(self, 
dag_id, task_id): dag = self.dagbag.get_dag(dag_id) dag.clear() - BackfillJob( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE).run() + BackfillJob(dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE).run() - ti = models.TaskInstance( - task=dag.get_task(task_id), - execution_date=DEFAULT_DATE) + ti = models.TaskInstance(task=dag.get_task(task_id), execution_date=DEFAULT_DATE) ti.refresh_from_db() self.assertEqual(ti.state, State.SUCCESS) @@ -232,7 +213,4 @@ def test_impersonation_custom(self): # PYTHONPATH is already set in script triggering tests assert 'PYTHONPATH' in os.environ - self.run_backfill( - 'impersonation_with_custom_pkg', - 'exec_python_fn' - ) + self.run_backfill('impersonation_with_custom_pkg', 'exec_python_fn') diff --git a/tests/core/test_local_settings.py b/tests/core/test_local_settings.py index 9f32d3583a09f..8eacf0e542f5c 100644 --- a/tests/core/test_local_settings.py +++ b/tests/core/test_local_settings.py @@ -96,6 +96,7 @@ def test_initialize_order(self, prepare_syspath, import_local_settings): mock.attach_mock(import_local_settings, "import_local_settings") import airflow.settings + airflow.settings.initialize() mock.assert_has_calls([call.prepare_syspath(), call.import_local_settings()]) @@ -107,6 +108,7 @@ def test_import_with_dunder_all_not_specified(self): """ with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL, "airflow_local_settings"): from airflow import settings + settings.import_local_settings() with self.assertRaises(AttributeError): @@ -119,6 +121,7 @@ def test_import_with_dunder_all(self): """ with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL, "airflow_local_settings"): from airflow import settings + settings.import_local_settings() task_instance = MagicMock() @@ -133,6 +136,7 @@ def test_import_local_settings_without_syspath(self, log_mock): if there is no airflow_local_settings module on the syspath. 
""" from airflow import settings + settings.import_local_settings() log_mock.assert_called_once_with("Failed to import airflow_local_settings.", exc_info=True) @@ -143,6 +147,7 @@ def test_policy_function(self): """ with SettingsContext(SETTINGS_FILE_POLICY, "airflow_local_settings"): from airflow import settings + settings.import_local_settings() task_instance = MagicMock() @@ -157,6 +162,7 @@ def test_pod_mutation_hook(self): """ with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK, "airflow_local_settings"): from airflow import settings + settings.import_local_settings() pod = MagicMock() @@ -167,6 +173,7 @@ def test_pod_mutation_hook(self): def test_custom_policy(self): with SettingsContext(SETTINGS_FILE_CUSTOM_POLICY, "airflow_local_settings"): from airflow import settings + settings.import_local_settings() task_instance = MagicMock() diff --git a/tests/core/test_logging_config.py b/tests/core/test_logging_config.py index 99b05fd927ab9..9c6983319c1a0 100644 --- a/tests/core/test_logging_config.py +++ b/tests/core/test_logging_config.py @@ -188,6 +188,7 @@ def tearDown(self): # When we try to load an invalid config file, we expect an error def test_loading_invalid_local_settings(self): from airflow.logging_config import configure_logging, log + with settings_context(SETTINGS_FILE_INVALID): with patch.object(log, 'error') as mock_info: # Load config @@ -204,56 +205,54 @@ def test_loading_valid_complex_local_settings(self): dir_structure = module_structure.replace('.', '/') with settings_context(SETTINGS_FILE_VALID, dir_structure): from airflow.logging_config import configure_logging, log + with patch.object(log, 'info') as mock_info: configure_logging() mock_info.assert_called_once_with( 'Successfully imported user-defined logging config from %s', - 'etc.airflow.config.{}.LOGGING_CONFIG'.format( - SETTINGS_DEFAULT_NAME - ) + f'etc.airflow.config.{SETTINGS_DEFAULT_NAME}.LOGGING_CONFIG', ) # When we try to load a valid config def test_loading_valid_local_settings(self): with settings_context(SETTINGS_FILE_VALID): from airflow.logging_config import configure_logging, log + with patch.object(log, 'info') as mock_info: configure_logging() mock_info.assert_called_once_with( 'Successfully imported user-defined logging config from %s', - '{}.LOGGING_CONFIG'.format( - SETTINGS_DEFAULT_NAME - ) + f'{SETTINGS_DEFAULT_NAME}.LOGGING_CONFIG', ) # When we load an empty file, it should go to default def test_loading_no_local_settings(self): with settings_context(SETTINGS_FILE_EMPTY): from airflow.logging_config import configure_logging + with self.assertRaises(ImportError): configure_logging() # When the key is not available in the configuration def test_when_the_config_key_does_not_exists(self): from airflow import logging_config + with conf_vars({('logging', 'logging_config_class'): None}): with patch.object(logging_config.log, 'debug') as mock_debug: logging_config.configure_logging() - mock_debug.assert_any_call( - 'Could not find key logging_config_class in config' - ) + mock_debug.assert_any_call('Could not find key logging_config_class in config') # Just default def test_loading_local_settings_without_logging_config(self): from airflow.logging_config import configure_logging, log + with patch.object(log, 'debug') as mock_info: configure_logging() - mock_info.assert_called_once_with( - 'Unable to load custom logging, using default config instead' - ) + mock_info.assert_called_once_with('Unable to load custom logging, using default config instead') def test_1_9_config(self): from 
airflow.logging_config import configure_logging + with conf_vars({('logging', 'task_log_reader'): 'file.task'}): with self.assertWarnsRegex(DeprecationWarning, r'file.task'): configure_logging() @@ -265,11 +264,13 @@ def test_loading_remote_logging_with_wasb_handler(self): from airflow.logging_config import configure_logging from airflow.utils.log.wasb_task_handler import WasbTaskHandler - with conf_vars({ - ('logging', 'remote_logging'): 'True', - ('logging', 'remote_log_conn_id'): 'some_wasb', - ('logging', 'remote_base_log_folder'): 'wasb://some-folder', - }): + with conf_vars( + { + ('logging', 'remote_logging'): 'True', + ('logging', 'remote_log_conn_id'): 'some_wasb', + ('logging', 'remote_base_log_folder'): 'wasb://some-folder', + } + ): importlib.reload(airflow_local_settings) configure_logging() diff --git a/tests/core/test_sentry.py b/tests/core/test_sentry.py index 3d5979a25ddfd..bfff5843fdfc4 100644 --- a/tests/core/test_sentry.py +++ b/tests/core/test_sentry.py @@ -64,6 +64,7 @@ class TestSentryHook(unittest.TestCase): @conf_vars({('sentry', 'sentry_on'): 'True'}) def setUp(self): from airflow import sentry + importlib.reload(sentry) self.sentry = sentry.ConfiguredSentry() diff --git a/tests/core/test_sqlalchemy_config.py b/tests/core/test_sqlalchemy_config.py index 6f0ee94b9c141..a7cd08b727bd7 100644 --- a/tests/core/test_sqlalchemy_config.py +++ b/tests/core/test_sqlalchemy_config.py @@ -25,13 +25,7 @@ from airflow.exceptions import AirflowConfigException from tests.test_utils.config import conf_vars -SQL_ALCHEMY_CONNECT_ARGS = { - 'test': 43503, - 'dict': { - 'is': 1, - 'supported': 'too' - } -} +SQL_ALCHEMY_CONNECT_ARGS = {'test': 43503, 'dict': {'is': 1, 'supported': 'too'}} class TestSqlAlchemySettings(unittest.TestCase): @@ -50,11 +44,9 @@ def tearDown(self): @patch('airflow.settings.scoped_session') @patch('airflow.settings.sessionmaker') @patch('airflow.settings.create_engine') - def test_configure_orm_with_default_values(self, - mock_create_engine, - mock_sessionmaker, - mock_scoped_session, - mock_setup_event_handlers): + def test_configure_orm_with_default_values( + self, mock_create_engine, mock_sessionmaker, mock_scoped_session, mock_setup_event_handlers + ): settings.configure_orm() mock_create_engine.assert_called_once_with( settings.SQL_ALCHEMY_CONN, @@ -63,22 +55,22 @@ def test_configure_orm_with_default_values(self, max_overflow=10, pool_pre_ping=True, pool_recycle=1800, - pool_size=5 + pool_size=5, ) @patch('airflow.settings.setup_event_handlers') @patch('airflow.settings.scoped_session') @patch('airflow.settings.sessionmaker') @patch('airflow.settings.create_engine') - def test_sql_alchemy_connect_args(self, - mock_create_engine, - mock_sessionmaker, - mock_scoped_session, - mock_setup_event_handlers): + def test_sql_alchemy_connect_args( + self, mock_create_engine, mock_sessionmaker, mock_scoped_session, mock_setup_event_handlers + ): config = { - ('core', 'sql_alchemy_connect_args'): - 'tests.core.test_sqlalchemy_config.SQL_ALCHEMY_CONNECT_ARGS', - ('core', 'sql_alchemy_pool_enabled'): 'False' + ( + 'core', + 'sql_alchemy_connect_args', + ): 'tests.core.test_sqlalchemy_config.SQL_ALCHEMY_CONNECT_ARGS', + ('core', 'sql_alchemy_pool_enabled'): 'False', } with conf_vars(config): settings.configure_orm() @@ -86,21 +78,19 @@ def test_sql_alchemy_connect_args(self, settings.SQL_ALCHEMY_CONN, connect_args=SQL_ALCHEMY_CONNECT_ARGS, poolclass=NullPool, - encoding='utf-8' + encoding='utf-8', ) @patch('airflow.settings.setup_event_handlers') 
@patch('airflow.settings.scoped_session') @patch('airflow.settings.sessionmaker') @patch('airflow.settings.create_engine') - def test_sql_alchemy_invalid_connect_args(self, - mock_create_engine, - mock_sessionmaker, - mock_scoped_session, - mock_setup_event_handlers): + def test_sql_alchemy_invalid_connect_args( + self, mock_create_engine, mock_sessionmaker, mock_scoped_session, mock_setup_event_handlers + ): config = { ('core', 'sql_alchemy_connect_args'): 'does.not.exist', - ('core', 'sql_alchemy_pool_enabled'): 'False' + ('core', 'sql_alchemy_pool_enabled'): 'False', } with self.assertRaises(AirflowConfigException): with conf_vars(config): diff --git a/tests/core/test_stats.py b/tests/core/test_stats.py index fd2acd7135ad7..9f9adeaf14a06 100644 --- a/tests/core/test_stats.py +++ b/tests/core/test_stats.py @@ -48,6 +48,7 @@ class InvalidCustomStatsd: This custom Statsd class is invalid because it does not subclass statsd.StatsClient. """ + incr_calls = 0 def __init__(self, host=None, port=None, prefix=None): @@ -62,7 +63,6 @@ def _reset(cls): class TestStats(unittest.TestCase): - def setUp(self): self.statsd_client = Mock() self.stats = SafeStatsdLogger(self.statsd_client) @@ -85,52 +85,54 @@ def test_stat_name_must_only_include_allowed_characters(self): self.stats.incr('test/$tats') self.statsd_client.assert_not_called() - @conf_vars({ - ('scheduler', 'statsd_on'): 'True' - }) + @conf_vars({('scheduler', 'statsd_on'): 'True'}) @mock.patch("statsd.StatsClient") def test_does_send_stats_using_statsd(self, mock_statsd): importlib.reload(airflow.stats) airflow.stats.Stats.incr("dummy_key") mock_statsd.return_value.incr.assert_called_once_with('dummy_key', 1, 1) - @conf_vars({ - ('scheduler', 'statsd_on'): 'True' - }) + @conf_vars({('scheduler', 'statsd_on'): 'True'}) @mock.patch("datadog.DogStatsd") def test_does_not_send_stats_using_dogstatsd(self, mock_dogstatsd): importlib.reload(airflow.stats) airflow.stats.Stats.incr("dummy_key") mock_dogstatsd.return_value.assert_not_called() - @conf_vars({ - ("scheduler", "statsd_on"): "True", - ("scheduler", "statsd_custom_client_path"): "tests.core.test_stats.CustomStatsd", - }) + @conf_vars( + { + ("scheduler", "statsd_on"): "True", + ("scheduler", "statsd_custom_client_path"): "tests.core.test_stats.CustomStatsd", + } + ) def test_load_custom_statsd_client(self): importlib.reload(airflow.stats) self.assertEqual('CustomStatsd', type(airflow.stats.Stats.statsd).__name__) - @conf_vars({ - ("scheduler", "statsd_on"): "True", - ("scheduler", "statsd_custom_client_path"): "tests.core.test_stats.CustomStatsd", - }) + @conf_vars( + { + ("scheduler", "statsd_on"): "True", + ("scheduler", "statsd_custom_client_path"): "tests.core.test_stats.CustomStatsd", + } + ) def test_does_use_custom_statsd_client(self): importlib.reload(airflow.stats) airflow.stats.Stats.incr("dummy_key") assert airflow.stats.Stats.statsd.incr_calls == 1 - @conf_vars({ - ("scheduler", "statsd_on"): "True", - ("scheduler", "statsd_custom_client_path"): "tests.core.test_stats.InvalidCustomStatsd", - }) + @conf_vars( + { + ("scheduler", "statsd_on"): "True", + ("scheduler", "statsd_custom_client_path"): "tests.core.test_stats.InvalidCustomStatsd", + } + ) def test_load_invalid_custom_stats_client(self): with self.assertRaisesRegex( AirflowConfigException, re.escape( 'Your custom Statsd client must extend the statsd.' 'StatsClient in order to ensure backwards compatibility.' 
- ) + ), ): importlib.reload(airflow.stats) @@ -140,7 +142,6 @@ def tearDown(self) -> None: class TestDogStats(unittest.TestCase): - def setUp(self): self.dogstatsd_client = Mock() self.dogstatsd = SafeDogStatsdLogger(self.dogstatsd_client) @@ -163,9 +164,7 @@ def test_stat_name_must_only_include_allowed_characters_with_dogstatsd(self): self.dogstatsd.incr('test/$tats') self.dogstatsd_client.assert_not_called() - @conf_vars({ - ('scheduler', 'statsd_datadog_enabled'): 'True' - }) + @conf_vars({('scheduler', 'statsd_datadog_enabled'): 'True'}) @mock.patch("datadog.DogStatsd") def test_does_send_stats_using_dogstatsd_when_dogstatsd_on(self, mock_dogstatsd): importlib.reload(airflow.stats) @@ -174,9 +173,7 @@ def test_does_send_stats_using_dogstatsd_when_dogstatsd_on(self, mock_dogstatsd) metric='dummy_key', sample_rate=1, tags=[], value=1 ) - @conf_vars({ - ('scheduler', 'statsd_datadog_enabled'): 'True' - }) + @conf_vars({('scheduler', 'statsd_datadog_enabled'): 'True'}) @mock.patch("datadog.DogStatsd") def test_does_send_stats_using_dogstatsd_with_tags(self, mock_dogstatsd): importlib.reload(airflow.stats) @@ -185,10 +182,7 @@ def test_does_send_stats_using_dogstatsd_with_tags(self, mock_dogstatsd): metric='dummy_key', sample_rate=1, tags=['key1:value1', 'key2:value2'], value=1 ) - @conf_vars({ - ('scheduler', 'statsd_on'): 'True', - ('scheduler', 'statsd_datadog_enabled'): 'True' - }) + @conf_vars({('scheduler', 'statsd_on'): 'True', ('scheduler', 'statsd_datadog_enabled'): 'True'}) @mock.patch("datadog.DogStatsd") def test_does_send_stats_using_dogstatsd_when_statsd_and_dogstatsd_both_on(self, mock_dogstatsd): importlib.reload(airflow.stats) @@ -197,10 +191,7 @@ def test_does_send_stats_using_dogstatsd_when_statsd_and_dogstatsd_both_on(self, metric='dummy_key', sample_rate=1, tags=[], value=1 ) - @conf_vars({ - ('scheduler', 'statsd_on'): 'True', - ('scheduler', 'statsd_datadog_enabled'): 'True' - }) + @conf_vars({('scheduler', 'statsd_on'): 'True', ('scheduler', 'statsd_datadog_enabled'): 'True'}) @mock.patch("statsd.StatsClient") def test_does_not_send_stats_using_statsd_when_statsd_and_dogstatsd_both_on(self, mock_statsd): importlib.reload(airflow.stats) @@ -213,7 +204,6 @@ def tearDown(self) -> None: class TestStatsWithAllowList(unittest.TestCase): - def setUp(self): self.statsd_client = Mock() self.stats = SafeStatsdLogger(self.statsd_client, AllowListValidator("stats_one, stats_two")) @@ -232,7 +222,6 @@ def test_not_increment_counter_if_not_allowed(self): class TestDogStatsWithAllowList(unittest.TestCase): - def setUp(self): self.dogstatsd_client = Mock() self.dogstats = SafeDogStatsdLogger(self.dogstatsd_client, AllowListValidator("stats_one, stats_two")) @@ -263,40 +252,48 @@ def always_valid(stat_name): class TestCustomStatsName(unittest.TestCase): - @conf_vars({ - ('scheduler', 'statsd_on'): 'True', - ('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_invalid' - }) + @conf_vars( + { + ('scheduler', 'statsd_on'): 'True', + ('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_invalid', + } + ) @mock.patch("statsd.StatsClient") def test_does_not_send_stats_using_statsd_when_the_name_is_not_valid(self, mock_statsd): importlib.reload(airflow.stats) airflow.stats.Stats.incr("dummy_key") mock_statsd.return_value.assert_not_called() - @conf_vars({ - ('scheduler', 'statsd_datadog_enabled'): 'True', - ('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_invalid' - }) + @conf_vars( + { + ('scheduler', 'statsd_datadog_enabled'): 'True', + 
('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_invalid', + } + ) @mock.patch("datadog.DogStatsd") def test_does_not_send_stats_using_dogstatsd_when_the_name_is_not_valid(self, mock_dogstatsd): importlib.reload(airflow.stats) airflow.stats.Stats.incr("dummy_key") mock_dogstatsd.return_value.assert_not_called() - @conf_vars({ - ('scheduler', 'statsd_on'): 'True', - ('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_valid' - }) + @conf_vars( + { + ('scheduler', 'statsd_on'): 'True', + ('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_valid', + } + ) @mock.patch("statsd.StatsClient") def test_does_send_stats_using_statsd_when_the_name_is_valid(self, mock_statsd): importlib.reload(airflow.stats) airflow.stats.Stats.incr("dummy_key") mock_statsd.return_value.incr.assert_called_once_with('dummy_key', 1, 1) - @conf_vars({ - ('scheduler', 'statsd_datadog_enabled'): 'True', - ('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_valid' - }) + @conf_vars( + { + ('scheduler', 'statsd_datadog_enabled'): 'True', + ('scheduler', 'stat_name_handler'): 'tests.core.test_stats.always_valid', + } + ) @mock.patch("datadog.DogStatsd") def test_does_send_stats_using_dogstatsd_when_the_name_is_valid(self, mock_dogstatsd): importlib.reload(airflow.stats) diff --git a/tests/dags/subdir2/test_dont_ignore_this.py b/tests/dags/subdir2/test_dont_ignore_this.py index d1894d620649e..6af1f19a239ec 100644 --- a/tests/dags/subdir2/test_dont_ignore_this.py +++ b/tests/dags/subdir2/test_dont_ignore_this.py @@ -23,7 +23,4 @@ DEFAULT_DATE = datetime(2019, 12, 1) dag = DAG(dag_id='test_dag_under_subdir2', start_date=DEFAULT_DATE, schedule_interval=None) -task = BashOperator( - task_id='task1', - bash_command='echo "test dag under sub directory subdir2"', - dag=dag) +task = BashOperator(task_id='task1', bash_command='echo "test dag under sub directory subdir2"', dag=dag) diff --git a/tests/dags/test_backfill_pooled_tasks.py b/tests/dags/test_backfill_pooled_tasks.py index ced09c67ecff6..02a68b31b5952 100644 --- a/tests/dags/test_backfill_pooled_tasks.py +++ b/tests/dags/test_backfill_pooled_tasks.py @@ -33,4 +33,5 @@ dag=dag, pool='test_backfill_pooled_task_pool', owner='airflow', - start_date=datetime(2016, 2, 1)) + start_date=datetime(2016, 2, 1), +) diff --git a/tests/dags/test_clear_subdag.py b/tests/dags/test_clear_subdag.py index 1f11cdd7ed92a..4a59126f91a46 100644 --- a/tests/dags/test_clear_subdag.py +++ b/tests/dags/test_clear_subdag.py @@ -32,11 +32,7 @@ def create_subdag_opt(main_dag): schedule_interval=None, concurrency=2, ) - BashOperator( - bash_command="echo 1", - task_id="daily_job_subdag_task", - dag=subdag - ) + BashOperator(bash_command="echo 1", task_id="daily_job_subdag_task", dag=subdag) return SubDagOperator( task_id=subdag_name, subdag=subdag, @@ -48,12 +44,7 @@ def create_subdag_opt(main_dag): start_date = datetime.datetime(2016, 1, 1) -dag = DAG( - dag_id=dag_name, - concurrency=3, - start_date=start_date, - schedule_interval="0 0 * * *" -) +dag = DAG(dag_id=dag_name, concurrency=3, start_date=start_date, schedule_interval="0 0 * * *") daily_job_irrelevant = BashOperator( bash_command="echo 1", diff --git a/tests/dags/test_cli_triggered_dags.py b/tests/dags/test_cli_triggered_dags.py index 4facacb48a711..631a7317ace73 100644 --- a/tests/dags/test_cli_triggered_dags.py +++ b/tests/dags/test_cli_triggered_dags.py @@ -24,9 +24,7 @@ from airflow.utils.timezone import datetime DEFAULT_DATE = datetime(2016, 1, 1) -default_args = dict( - 
start_date=DEFAULT_DATE, - owner='airflow') +default_args = dict(start_date=DEFAULT_DATE, owner='airflow') def fail(): @@ -40,14 +38,9 @@ def success(ti=None, *args, **kwargs): # DAG tests that tasks ignore all dependencies -dag1 = DAG(dag_id='test_run_ignores_all_dependencies', - default_args=dict(depends_on_past=True, **default_args)) -dag1_task1 = PythonOperator( - task_id='test_run_dependency_task', - python_callable=fail, - dag=dag1) -dag1_task2 = PythonOperator( - task_id='test_run_dependent_task', - python_callable=success, - dag=dag1) +dag1 = DAG( + dag_id='test_run_ignores_all_dependencies', default_args=dict(depends_on_past=True, **default_args) +) +dag1_task1 = PythonOperator(task_id='test_run_dependency_task', python_callable=fail, dag=dag1) +dag1_task2 = PythonOperator(task_id='test_run_dependent_task', python_callable=success, dag=dag1) dag1_task1.set_downstream(dag1_task2) diff --git a/tests/dags/test_default_impersonation.py b/tests/dags/test_default_impersonation.py index 5268287bcd762..4e379d3edea90 100644 --- a/tests/dags/test_default_impersonation.py +++ b/tests/dags/test_default_impersonation.py @@ -39,7 +39,10 @@ echo current user $(whoami) is not {user}! exit 1 fi - """.format(user=deelevated_user)) + """.format( + user=deelevated_user + ) +) task = BashOperator( task_id='test_deelevated_user', diff --git a/tests/dags/test_default_views.py b/tests/dags/test_default_views.py index ab1b19227649a..f30177f356f5c 100644 --- a/tests/dags/test_default_views.py +++ b/tests/dags/test_default_views.py @@ -18,20 +18,18 @@ from airflow.models import DAG from airflow.utils.dates import days_ago -args = { - 'owner': 'airflow', - 'retries': 3, - 'start_date': days_ago(2) -} +args = {'owner': 'airflow', 'retries': 3, 'start_date': days_ago(2)} tree_dag = DAG( - dag_id='test_tree_view', default_args=args, + dag_id='test_tree_view', + default_args=args, schedule_interval='0 0 * * *', default_view='tree', ) graph_dag = DAG( - dag_id='test_graph_view', default_args=args, + dag_id='test_graph_view', + default_args=args, schedule_interval='0 0 * * *', default_view='graph', ) diff --git a/tests/dags/test_double_trigger.py b/tests/dags/test_double_trigger.py index e1c03441afef5..9c8aae11c5150 100644 --- a/tests/dags/test_double_trigger.py +++ b/tests/dags/test_double_trigger.py @@ -28,6 +28,4 @@ } dag = DAG(dag_id='test_localtaskjob_double_trigger', default_args=args) -task = DummyOperator( - task_id='test_localtaskjob_double_trigger_task', - dag=dag) +task = DummyOperator(task_id='test_localtaskjob_double_trigger_task', dag=dag) diff --git a/tests/dags/test_example_bash_operator.py b/tests/dags/test_example_bash_operator.py index 4d2dfdb24fb6f..1e28f664dcc4f 100644 --- a/tests/dags/test_example_bash_operator.py +++ b/tests/dags/test_example_bash_operator.py @@ -22,35 +22,30 @@ from airflow.operators.dummy_operator import DummyOperator from airflow.utils.dates import days_ago -args = { - 'owner': 'airflow', - 'retries': 3, - 'start_date': days_ago(2) -} +args = {'owner': 'airflow', 'retries': 3, 'start_date': days_ago(2)} dag = DAG( - dag_id='test_example_bash_operator', default_args=args, + dag_id='test_example_bash_operator', + default_args=args, schedule_interval='0 0 * * *', - dagrun_timeout=timedelta(minutes=60)) + dagrun_timeout=timedelta(minutes=60), +) cmd = 'ls -l' run_this_last = DummyOperator(task_id='run_this_last', dag=dag) -run_this = BashOperator( - task_id='run_after_loop', bash_command='echo 1', dag=dag) +run_this = BashOperator(task_id='run_after_loop', bash_command='echo 
1', dag=dag) run_this.set_downstream(run_this_last) for i in range(3): task = BashOperator( - task_id='runme_' + str(i), - bash_command='echo "{{ task_instance_key_str }}" && sleep 1', - dag=dag) + task_id='runme_' + str(i), bash_command='echo "{{ task_instance_key_str }}" && sleep 1', dag=dag + ) task.set_downstream(run_this) task = BashOperator( - task_id='also_run_this', - bash_command='echo "run_id={{ run_id }} | dag_run={{ dag_run }}"', - dag=dag) + task_id='also_run_this', bash_command='echo "run_id={{ run_id }} | dag_run={{ dag_run }}"', dag=dag +) task.set_downstream(run_this_last) if __name__ == "__main__": diff --git a/tests/dags/test_heartbeat_failed_fast.py b/tests/dags/test_heartbeat_failed_fast.py index 5c74d49a578d6..01d5d6be88fd2 100644 --- a/tests/dags/test_heartbeat_failed_fast.py +++ b/tests/dags/test_heartbeat_failed_fast.py @@ -28,7 +28,4 @@ } dag = DAG(dag_id='test_heartbeat_failed_fast', default_args=args) -task = BashOperator( - task_id='test_heartbeat_failed_fast_op', - bash_command='sleep 7', - dag=dag) +task = BashOperator(task_id='test_heartbeat_failed_fast_op', bash_command='sleep 7', dag=dag) diff --git a/tests/dags/test_impersonation.py b/tests/dags/test_impersonation.py index d58f9cae5ee97..e72a9e11eb05a 100644 --- a/tests/dags/test_impersonation.py +++ b/tests/dags/test_impersonation.py @@ -39,7 +39,10 @@ echo current user is not {user}! exit 1 fi - """.format(user=run_as_user)) + """.format( + user=run_as_user + ) +) task = BashOperator( task_id='test_impersonated_user', diff --git a/tests/dags/test_impersonation_subdag.py b/tests/dags/test_impersonation_subdag.py index a55d5227e749a..3ea52002b2ba8 100644 --- a/tests/dags/test_impersonation_subdag.py +++ b/tests/dags/test_impersonation_subdag.py @@ -25,11 +25,7 @@ DEFAULT_DATE = datetime(2016, 1, 1) -default_args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - 'run_as_user': 'airflow_test_user' -} +default_args = {'owner': 'airflow', 'start_date': DEFAULT_DATE, 'run_as_user': 'airflow_test_user'} dag = DAG(dag_id='impersonation_subdag', default_args=default_args) @@ -38,25 +34,15 @@ def print_today(): print(f'Today is {datetime.utcnow()}') -subdag = DAG('impersonation_subdag.test_subdag_operation', - default_args=default_args) +subdag = DAG('impersonation_subdag.test_subdag_operation', default_args=default_args) -PythonOperator( - python_callable=print_today, - task_id='exec_python_fn', - dag=subdag) +PythonOperator(python_callable=print_today, task_id='exec_python_fn', dag=subdag) -BashOperator( - task_id='exec_bash_operator', - bash_command='echo "Running within SubDag"', - dag=subdag -) +BashOperator(task_id='exec_bash_operator', bash_command='echo "Running within SubDag"', dag=subdag) -subdag_operator = SubDagOperator(task_id='test_subdag_operation', - subdag=subdag, - mode='reschedule', - poke_interval=1, - dag=dag) +subdag_operator = SubDagOperator( + task_id='test_subdag_operation', subdag=subdag, mode='reschedule', poke_interval=1, dag=dag +) diff --git a/tests/dags/test_invalid_cron.py b/tests/dags/test_invalid_cron.py index 60e9f822c18a3..270abd377a452 100644 --- a/tests/dags/test_invalid_cron.py +++ b/tests/dags/test_invalid_cron.py @@ -24,11 +24,5 @@ # Cron expression. This invalid DAG will be used to # test whether dagbag.process_file() can identify # invalid Cron expression. 
-dag1 = DAG( - dag_id='test_invalid_cron', - start_date=datetime(2015, 1, 1), - schedule_interval="0 100 * * *") -dag1_task1 = DummyOperator( - task_id='task1', - dag=dag1, - owner='airflow') +dag1 = DAG(dag_id='test_invalid_cron', start_date=datetime(2015, 1, 1), schedule_interval="0 100 * * *") +dag1_task1 = DummyOperator(task_id='task1', dag=dag1, owner='airflow') diff --git a/tests/dags/test_issue_1225.py b/tests/dags/test_issue_1225.py index b962f23c112ba..d0daf9dc4ee89 100644 --- a/tests/dags/test_issue_1225.py +++ b/tests/dags/test_issue_1225.py @@ -30,9 +30,7 @@ from airflow.utils.trigger_rule import TriggerRule DEFAULT_DATE = datetime(2016, 1, 1) -default_args = dict( - start_date=DEFAULT_DATE, - owner='airflow') +default_args = dict(start_date=DEFAULT_DATE, owner='airflow') def fail(): @@ -45,19 +43,18 @@ def fail(): dag1_task1 = DummyOperator( task_id='test_backfill_pooled_task', dag=dag1, - pool='test_backfill_pooled_task_pool',) + pool='test_backfill_pooled_task_pool', +) # dag2 has been moved to test_prev_dagrun_dep.py # DAG tests that a Dag run that doesn't complete is marked failed dag3 = DAG(dag_id='test_dagrun_states_fail', default_args=default_args) -dag3_task1 = PythonOperator( - task_id='test_dagrun_fail', - dag=dag3, - python_callable=fail) +dag3_task1 = PythonOperator(task_id='test_dagrun_fail', dag=dag3, python_callable=fail) dag3_task2 = DummyOperator( task_id='test_dagrun_succeed', - dag=dag3,) + dag=dag3, +) dag3_task2.set_upstream(dag3_task1) # DAG tests that a Dag run that completes but has a failure is marked success @@ -67,11 +64,7 @@ def fail(): dag=dag4, python_callable=fail, ) -dag4_task2 = DummyOperator( - task_id='test_dagrun_succeed', - dag=dag4, - trigger_rule=TriggerRule.ALL_FAILED -) +dag4_task2 = DummyOperator(task_id='test_dagrun_succeed', dag=dag4, trigger_rule=TriggerRule.ALL_FAILED) dag4_task2.set_upstream(dag4_task1) # DAG tests that a Dag run that completes but has a root failure is marked fail @@ -91,11 +84,13 @@ def fail(): dag6_task1 = DummyOperator( task_id='test_depends_on_past', depends_on_past=True, - dag=dag6,) + dag=dag6, +) dag6_task2 = DummyOperator( task_id='test_depends_on_past_2', depends_on_past=True, - dag=dag6,) + dag=dag6, +) dag6_task2.set_upstream(dag6_task1) @@ -103,7 +98,7 @@ def fail(): dag8 = DAG(dag_id='test_dagrun_states_root_fail_unfinished', default_args=default_args) dag8_task1 = DummyOperator( task_id='test_dagrun_unfinished', # The test will unset the task instance state after - # running this test + # running this test dag=dag8, ) dag8_task2 = PythonOperator( diff --git a/tests/dags/test_latest_runs.py b/tests/dags/test_latest_runs.py index a4ea8eb8afaea..9bd15688a49e1 100644 --- a/tests/dags/test_latest_runs.py +++ b/tests/dags/test_latest_runs.py @@ -24,8 +24,4 @@ for i in range(1, 2): dag = DAG(dag_id=f'test_latest_runs_{i}') - task = DummyOperator( - task_id='dummy_task', - dag=dag, - owner='airflow', - start_date=datetime(2016, 2, 1)) + task = DummyOperator(task_id='dummy_task', dag=dag, owner='airflow', start_date=datetime(2016, 2, 1)) diff --git a/tests/dags/test_logging_in_dag.py b/tests/dags/test_logging_in_dag.py index 6d828fa631d4a..8b0c629de750a 100644 --- a/tests/dags/test_logging_in_dag.py +++ b/tests/dags/test_logging_in_dag.py @@ -35,11 +35,7 @@ def test_logging_fn(**kwargs): print("Log from Print statement") -dag = DAG( - dag_id='test_logging_dag', - schedule_interval=None, - start_date=datetime(2016, 1, 1) -) +dag = DAG(dag_id='test_logging_dag', schedule_interval=None, 
start_date=datetime(2016, 1, 1)) PythonOperator( task_id='test_task', diff --git a/tests/dags/test_mark_success.py b/tests/dags/test_mark_success.py index d731026b26f54..759b39f33457b 100644 --- a/tests/dags/test_mark_success.py +++ b/tests/dags/test_mark_success.py @@ -31,7 +31,5 @@ dag = DAG(dag_id='test_mark_success', default_args=args) task = PythonOperator( - task_id='task1', - python_callable=lambda x: sleep(x), # pylint: disable=W0108 - op_args=[600], - dag=dag) + task_id='task1', python_callable=lambda x: sleep(x), op_args=[600], dag=dag # pylint: disable=W0108 +) diff --git a/tests/dags/test_missing_owner.py b/tests/dags/test_missing_owner.py index 16f715c061b96..ef81f98859f36 100644 --- a/tests/dags/test_missing_owner.py +++ b/tests/dags/test_missing_owner.py @@ -29,4 +29,6 @@ dagrun_timeout=timedelta(minutes=60), tags=["example"], ) as dag: - run_this_last = DummyOperator(task_id="test_task",) + run_this_last = DummyOperator( + task_id="test_task", + ) diff --git a/tests/dags/test_multiple_dags.py b/tests/dags/test_multiple_dags.py index 0f7132615883b..44aa5cdfc2c5b 100644 --- a/tests/dags/test_multiple_dags.py +++ b/tests/dags/test_multiple_dags.py @@ -21,11 +21,7 @@ from airflow.operators.bash import BashOperator from airflow.utils.dates import days_ago -args = { - 'owner': 'airflow', - 'retries': 3, - 'start_date': days_ago(2) -} +args = {'owner': 'airflow', 'retries': 3, 'start_date': days_ago(2)} def create_dag(suffix): @@ -33,7 +29,7 @@ def create_dag(suffix): dag_id=f'test_multiple_dags__{suffix}', default_args=args, schedule_interval='0 0 * * *', - dagrun_timeout=timedelta(minutes=60) + dagrun_timeout=timedelta(minutes=60), ) with dag: diff --git a/tests/dags/test_no_impersonation.py b/tests/dags/test_no_impersonation.py index e068690ae8ac2..c5c9cea0cdafa 100644 --- a/tests/dags/test_no_impersonation.py +++ b/tests/dags/test_no_impersonation.py @@ -38,7 +38,8 @@ echo 'current uid does not have root privileges!' 
exit 1 fi - """) + """ +) task = BashOperator( task_id='test_superuser', diff --git a/tests/dags/test_on_failure_callback.py b/tests/dags/test_on_failure_callback.py index 28ea084fd1872..1daddba6da2a4 100644 --- a/tests/dags/test_on_failure_callback.py +++ b/tests/dags/test_on_failure_callback.py @@ -37,7 +37,5 @@ def write_data_to_callback(*arg, **kwargs): # pylint: disable=unused-argument task = DummyOperator( - task_id='test_om_failure_callback_task', - dag=dag, - on_failure_callback=write_data_to_callback + task_id='test_om_failure_callback_task', dag=dag, on_failure_callback=write_data_to_callback ) diff --git a/tests/dags/test_on_kill.py b/tests/dags/test_on_kill.py index ec172689dea6b..cb75edf349124 100644 --- a/tests/dags/test_on_kill.py +++ b/tests/dags/test_on_kill.py @@ -40,10 +40,5 @@ def on_kill(self): # DAG tests backfill with pooled tasks # Previously backfill would queue the task but never run it -dag1 = DAG( - dag_id='test_on_kill', - start_date=datetime(2015, 1, 1)) -dag1_task1 = DummyWithOnKill( - task_id='task1', - dag=dag1, - owner='airflow') +dag1 = DAG(dag_id='test_on_kill', start_date=datetime(2015, 1, 1)) +dag1_task1 = DummyWithOnKill(task_id='task1', dag=dag1, owner='airflow') diff --git a/tests/dags/test_prev_dagrun_dep.py b/tests/dags/test_prev_dagrun_dep.py index b89e99626763a..6103c0e5c22b5 100644 --- a/tests/dags/test_prev_dagrun_dep.py +++ b/tests/dags/test_prev_dagrun_dep.py @@ -27,11 +27,19 @@ # DAG tests depends_on_past dependencies dag_dop = DAG(dag_id="test_depends_on_past", default_args=default_args) with dag_dop: - dag_dop_task = DummyOperator(task_id="test_dop_task", depends_on_past=True,) + dag_dop_task = DummyOperator( + task_id="test_dop_task", + depends_on_past=True, + ) # DAG tests wait_for_downstream dependencies dag_wfd = DAG(dag_id="test_wait_for_downstream", default_args=default_args) with dag_wfd: - dag_wfd_upstream = DummyOperator(task_id="upstream_task", wait_for_downstream=True,) - dag_wfd_downstream = DummyOperator(task_id="downstream_task",) + dag_wfd_upstream = DummyOperator( + task_id="upstream_task", + wait_for_downstream=True, + ) + dag_wfd_downstream = DummyOperator( + task_id="downstream_task", + ) dag_wfd_upstream >> dag_wfd_downstream diff --git a/tests/dags/test_retry_handling_job.py b/tests/dags/test_retry_handling_job.py index 894cfc9430711..42bf9a47331e1 100644 --- a/tests/dags/test_retry_handling_job.py +++ b/tests/dags/test_retry_handling_job.py @@ -34,7 +34,4 @@ dag = DAG('test_retry_handling_job', default_args=default_args, schedule_interval='@once') -task1 = BashOperator( - task_id='test_retry_handling_op', - bash_command='exit 1', - dag=dag) +task1 = BashOperator(task_id='test_retry_handling_op', bash_command='exit 1', dag=dag) diff --git a/tests/dags/test_scheduler_dags.py b/tests/dags/test_scheduler_dags.py index bc47fbcf1abd1..35a9bd06a7c7b 100644 --- a/tests/dags/test_scheduler_dags.py +++ b/tests/dags/test_scheduler_dags.py @@ -25,26 +25,11 @@ # DAG tests backfill with pooled tasks # Previously backfill would queue the task but never run it -dag1 = DAG( - dag_id='test_start_date_scheduling', - start_date=datetime.utcnow() + timedelta(days=1)) -dag1_task1 = DummyOperator( - task_id='dummy', - dag=dag1, - owner='airflow') +dag1 = DAG(dag_id='test_start_date_scheduling', start_date=datetime.utcnow() + timedelta(days=1)) +dag1_task1 = DummyOperator(task_id='dummy', dag=dag1, owner='airflow') -dag2 = DAG( - dag_id='test_task_start_date_scheduling', - start_date=DEFAULT_DATE -) +dag2 = 
DAG(dag_id='test_task_start_date_scheduling', start_date=DEFAULT_DATE) dag2_task1 = DummyOperator( - task_id='dummy1', - dag=dag2, - owner='airflow', - start_date=DEFAULT_DATE + timedelta(days=3) -) -dag2_task2 = DummyOperator( - task_id='dummy2', - dag=dag2, - owner='airflow' + task_id='dummy1', dag=dag2, owner='airflow', start_date=DEFAULT_DATE + timedelta(days=3) ) +dag2_task2 = DummyOperator(task_id='dummy2', dag=dag2, owner='airflow') diff --git a/tests/dags/test_task_view_type_check.py b/tests/dags/test_task_view_type_check.py index 6f4585ba0258a..5003cadefe137 100644 --- a/tests/dags/test_task_view_type_check.py +++ b/tests/dags/test_task_view_type_check.py @@ -28,15 +28,14 @@ from airflow.operators.python import PythonOperator DEFAULT_DATE = datetime(2016, 1, 1) -default_args = dict( - start_date=DEFAULT_DATE, - owner='airflow') +default_args = dict(start_date=DEFAULT_DATE, owner='airflow') class CallableClass: """ Class that is callable. """ + def __call__(self): """A __call__ method """ diff --git a/tests/dags/test_with_non_default_owner.py b/tests/dags/test_with_non_default_owner.py index eebbb64c621bb..3e004d38fe5d2 100644 --- a/tests/dags/test_with_non_default_owner.py +++ b/tests/dags/test_with_non_default_owner.py @@ -29,4 +29,7 @@ dagrun_timeout=timedelta(minutes=60), tags=["example"], ) as dag: - run_this_last = DummyOperator(task_id="test_task", owner="John",) + run_this_last = DummyOperator( + task_id="test_task", + owner="John", + ) diff --git a/tests/dags_corrupted/test_impersonation_custom.py b/tests/dags_corrupted/test_impersonation_custom.py index f9deb45ae0f97..77ea1ed59951c 100644 --- a/tests/dags_corrupted/test_impersonation_custom.py +++ b/tests/dags_corrupted/test_impersonation_custom.py @@ -32,11 +32,7 @@ DEFAULT_DATE = datetime(2016, 1, 1) -args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - 'run_as_user': 'airflow_test_user' -} +args = {'owner': 'airflow', 'start_date': DEFAULT_DATE, 'run_as_user': 'airflow_test_user'} dag = DAG(dag_id='impersonation_with_custom_pkg', default_args=args) @@ -48,15 +44,10 @@ def print_today(): def check_hive_conf(): from airflow.configuration import conf + assert conf.get('hive', 'default_hive_mapred_queue') == 'airflow' -PythonOperator( - python_callable=print_today, - task_id='exec_python_fn', - dag=dag) +PythonOperator(python_callable=print_today, task_id='exec_python_fn', dag=dag) -PythonOperator( - python_callable=check_hive_conf, - task_id='exec_check_hive_conf_fn', - dag=dag) +PythonOperator(python_callable=check_hive_conf, task_id='exec_check_hive_conf_fn', dag=dag) diff --git a/tests/dags_with_system_exit/a_system_exit.py b/tests/dags_with_system_exit/a_system_exit.py index 530fd8fc500e0..fd0a86f2dc530 100644 --- a/tests/dags_with_system_exit/a_system_exit.py +++ b/tests/dags_with_system_exit/a_system_exit.py @@ -26,8 +26,6 @@ DEFAULT_DATE = datetime(2100, 1, 1) -dag1 = DAG( - dag_id='test_system_exit', - start_date=DEFAULT_DATE) +dag1 = DAG(dag_id='test_system_exit', start_date=DEFAULT_DATE) sys.exit(-1) diff --git a/tests/dags_with_system_exit/b_test_scheduler_dags.py b/tests/dags_with_system_exit/b_test_scheduler_dags.py index ff70b374c40fc..9fabdb526f0ca 100644 --- a/tests/dags_with_system_exit/b_test_scheduler_dags.py +++ b/tests/dags_with_system_exit/b_test_scheduler_dags.py @@ -23,11 +23,6 @@ DEFAULT_DATE = datetime(2000, 1, 1) -dag1 = DAG( - dag_id='exit_test_dag', - start_date=DEFAULT_DATE) +dag1 = DAG(dag_id='exit_test_dag', start_date=DEFAULT_DATE) -dag1_task1 = DummyOperator( - task_id='dummy', 
- dag=dag1, - owner='airflow') +dag1_task1 = DummyOperator(task_id='dummy', dag=dag1, owner='airflow') diff --git a/tests/dags_with_system_exit/c_system_exit.py b/tests/dags_with_system_exit/c_system_exit.py index 1f919dd34c836..9222bc699ad5c 100644 --- a/tests/dags_with_system_exit/c_system_exit.py +++ b/tests/dags_with_system_exit/c_system_exit.py @@ -26,8 +26,6 @@ DEFAULT_DATE = datetime(2100, 1, 1) -dag1 = DAG( - dag_id='test_system_exit', - start_date=DEFAULT_DATE) +dag1 = DAG(dag_id='test_system_exit', start_date=DEFAULT_DATE) sys.exit(-1) diff --git a/tests/deprecated_classes.py b/tests/deprecated_classes.py index 170b0c75503b0..d94774b2715ca 100644 --- a/tests/deprecated_classes.py +++ b/tests/deprecated_classes.py @@ -250,7 +250,6 @@ ( 'airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook', 'airflow.contrib.hooks.sagemaker_hook.SageMakerHook', - ), ( 'airflow.providers.mongo.hooks.mongo.MongoHook', @@ -466,8 +465,7 @@ ( "airflow.providers.google.cloud.operators.compute" ".ComputeEngineInstanceGroupUpdateManagerTemplateOperator", - "airflow.contrib.operators.gcp_compute_operator." - "GceInstanceGroupManagerUpdateTemplateOperator", + "airflow.contrib.operators.gcp_compute_operator.GceInstanceGroupManagerUpdateTemplateOperator", ), ( "airflow.providers.google.cloud.operators.compute.ComputeEngineStartInstanceOperator", @@ -628,8 +626,7 @@ ( "airflow.providers.google.cloud.operators.natural_language." "CloudNaturalLanguageAnalyzeEntitiesOperator", - "airflow.contrib.operators.gcp_natural_language_operator." - "CloudLanguageAnalyzeEntitiesOperator", + "airflow.contrib.operators.gcp_natural_language_operator.CloudLanguageAnalyzeEntitiesOperator", ), ( "airflow.providers.google.cloud.operators.natural_language." @@ -640,8 +637,7 @@ ( "airflow.providers.google.cloud.operators.natural_language." "CloudNaturalLanguageAnalyzeSentimentOperator", - "airflow.contrib.operators.gcp_natural_language_operator." - "CloudLanguageAnalyzeSentimentOperator", + "airflow.contrib.operators.gcp_natural_language_operator.CloudLanguageAnalyzeSentimentOperator", ), ( "airflow.providers.google.cloud.operators.natural_language." @@ -690,32 +686,27 @@ ( "airflow.providers.google.cloud.operators.cloud_storage_transfer_service." "CloudDataTransferServiceCancelOperationOperator", - "airflow.contrib.operators.gcp_transfer_operator." - "GcpTransferServiceOperationCancelOperator", + "airflow.contrib.operators.gcp_transfer_operator.GcpTransferServiceOperationCancelOperator", ), ( "airflow.providers.google.cloud.operators.cloud_storage_transfer_service." "CloudDataTransferServiceGetOperationOperator", - "airflow.contrib.operators.gcp_transfer_operator." - "GcpTransferServiceOperationGetOperator", + "airflow.contrib.operators.gcp_transfer_operator.GcpTransferServiceOperationGetOperator", ), ( "airflow.providers.google.cloud.operators.cloud_storage_transfer_service." "CloudDataTransferServicePauseOperationOperator", - "airflow.contrib.operators.gcp_transfer_operator." - "GcpTransferServiceOperationPauseOperator", + "airflow.contrib.operators.gcp_transfer_operator.GcpTransferServiceOperationPauseOperator", ), ( "airflow.providers.google.cloud.operators.cloud_storage_transfer_service." "CloudDataTransferServiceResumeOperationOperator", - "airflow.contrib.operators.gcp_transfer_operator." - "GcpTransferServiceOperationResumeOperator", + "airflow.contrib.operators.gcp_transfer_operator.GcpTransferServiceOperationResumeOperator", ), ( "airflow.providers.google.cloud.operators.cloud_storage_transfer_service." 
"CloudDataTransferServiceListOperationsOperator", - "airflow.contrib.operators.gcp_transfer_operator." - "GcpTransferServiceOperationsListOperator", + "airflow.contrib.operators.gcp_transfer_operator.GcpTransferServiceOperationsListOperator", ), ( "airflow.providers.google.cloud.operators.translate.CloudTranslateTextOperator", @@ -801,8 +792,7 @@ ), ( "airflow.providers.google.cloud.operators.vision.CloudVisionRemoveProductFromProductSetOperator", - "airflow.contrib.operators.gcp_vision_operator." - "CloudVisionRemoveProductFromProductSetOperator", + "airflow.contrib.operators.gcp_vision_operator.CloudVisionRemoveProductFromProductSetOperator", ), ( "airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator", @@ -841,66 +831,53 @@ "airflow.contrib.operators.pubsub_operator.PubSubTopicDeleteOperator", ), ( - "airflow.providers.google.cloud." - "operators.dataproc.DataprocCreateClusterOperator", + "airflow.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperator", "airflow.contrib.operators.dataproc_operator.DataprocClusterCreateOperator", ), ( - "airflow.providers.google.cloud." - "operators.dataproc.DataprocDeleteClusterOperator", + "airflow.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperator", "airflow.contrib.operators.dataproc_operator.DataprocClusterDeleteOperator", ), ( - "airflow.providers.google.cloud." - "operators.dataproc.DataprocScaleClusterOperator", + "airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator", "airflow.contrib.operators.dataproc_operator.DataprocClusterScaleOperator", ), ( - "airflow.providers.google.cloud." - "operators.dataproc.DataprocSubmitHadoopJobOperator", + "airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHadoopJobOperator", "airflow.contrib.operators.dataproc_operator.DataProcHadoopOperator", ), ( - "airflow.providers.google.cloud." - "operators.dataproc.DataprocSubmitHiveJobOperator", + "airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHiveJobOperator", "airflow.contrib.operators.dataproc_operator.DataProcHiveOperator", ), ( - "airflow.providers.google.cloud." - "operators.dataproc.DataprocJobBaseOperator", + "airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator", "airflow.contrib.operators.dataproc_operator.DataProcJobBaseOperator", ), ( - "airflow.providers.google.cloud." - "operators.dataproc.DataprocSubmitPigJobOperator", + "airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPigJobOperator", "airflow.contrib.operators.dataproc_operator.DataProcPigOperator", ), ( - "airflow.providers.google.cloud." - "operators.dataproc.DataprocSubmitPySparkJobOperator", + "airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPySparkJobOperator", "airflow.contrib.operators.dataproc_operator.DataProcPySparkOperator", ), ( - "airflow.providers.google.cloud." - "operators.dataproc.DataprocSubmitSparkJobOperator", + "airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkJobOperator", "airflow.contrib.operators.dataproc_operator.DataProcSparkOperator", ), ( - "airflow.providers.google.cloud." - "operators.dataproc.DataprocSubmitSparkSqlJobOperator", + "airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkSqlJobOperator", "airflow.contrib.operators.dataproc_operator.DataProcSparkSqlOperator", ), ( "airflow.providers.google.cloud." "operators.dataproc.DataprocInstantiateInlineWorkflowTemplateOperator", - "airflow.contrib.operators.dataproc_operator." 
- "DataprocWorkflowTemplateInstantiateInlineOperator", + "airflow.contrib.operators.dataproc_operator.DataprocWorkflowTemplateInstantiateInlineOperator", ), ( - "airflow.providers.google.cloud." - "operators.dataproc.DataprocInstantiateWorkflowTemplateOperator", - "airflow.contrib.operators.dataproc_operator." - "DataprocWorkflowTemplateInstantiateOperator", + "airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateWorkflowTemplateOperator", + "airflow.contrib.operators.dataproc_operator.DataprocWorkflowTemplateInstantiateOperator", ), ( "airflow.providers.google.cloud.operators.bigquery.BigQueryCreateEmptyDatasetOperator", @@ -1285,43 +1262,43 @@ ), ( 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLBaseOperator', - 'airflow.contrib.operators.gcp_sql_operator.CloudSqlBaseOperator' + 'airflow.contrib.operators.gcp_sql_operator.CloudSqlBaseOperator', ), ( 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLCreateInstanceDatabaseOperator', - 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDatabaseCreateOperator' + 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDatabaseCreateOperator', ), ( 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLCreateInstanceOperator', - 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceCreateOperator' + 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceCreateOperator', ), ( 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLDeleteInstanceDatabaseOperator', - 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDatabaseDeleteOperator' + 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDatabaseDeleteOperator', ), ( 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLDeleteInstanceOperator', - 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDeleteOperator' + 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDeleteOperator', ), ( 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLExecuteQueryOperator', - 'airflow.contrib.operators.gcp_sql_operator.CloudSqlQueryOperator' + 'airflow.contrib.operators.gcp_sql_operator.CloudSqlQueryOperator', ), ( 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLExportInstanceOperator', - 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceExportOperator' + 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceExportOperator', ), ( 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLImportInstanceOperator', - 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceImportOperator' + 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceImportOperator', ), ( 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLInstancePatchOperator', - 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstancePatchOperator' + 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstancePatchOperator', ), ( 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLPatchInstanceDatabaseOperator', - 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDatabasePatchOperator' + 'airflow.contrib.operators.gcp_sql_operator.CloudSqlInstanceDatabasePatchOperator', ), ( 'airflow.providers.jira.operators.jira.JiraOperator', @@ -1379,14 +1356,12 @@ ), ( "airflow.providers.google.cloud.sensors.bigtable.BigtableTableReplicationCompletedSensor", - "airflow.contrib.operators.gcp_bigtable_operator." 
- "BigtableTableWaitForReplicationSensor", + "airflow.contrib.operators.gcp_bigtable_operator.BigtableTableWaitForReplicationSensor", ), ( "airflow.providers.google.cloud.sensors.cloud_storage_transfer_service." "CloudDataTransferServiceJobStatusSensor", - "airflow.contrib.sensors.gcp_transfer_sensor." - "GCPTransferServiceWaitForJobStatusSensor", + "airflow.contrib.sensors.gcp_transfer_sensor.GCPTransferServiceWaitForJobStatusSensor", ), ( "airflow.providers.google.cloud.sensors.pubsub.PubSubPullSensor", @@ -1575,7 +1550,7 @@ ( 'airflow.providers.sftp.sensors.sftp.SFTPSensor', 'airflow.contrib.sensors.sftp_sensor.SFTPSensor', - ) + ), ] TRANSFERS = [ @@ -1619,8 +1594,7 @@ ), ( "airflow.providers.google.cloud.transfers.postgres_to_gcs.PostgresToGCSOperator", - "airflow.contrib.operators.postgres_to_gcs_operator." - "PostgresToGoogleCloudStorageOperator", + "airflow.contrib.operators.postgres_to_gcs_operator.PostgresToGoogleCloudStorageOperator", ), ( "airflow.providers.google.cloud.transfers.bigquery_to_bigquery.BigQueryToBigQueryOperator", @@ -1738,7 +1712,7 @@ ( 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service' '.CloudDataTransferServiceS3ToGCSOperator', - 'airflow.contrib.operators.s3_to_gcs_transfer_operator.CloudDataTransferServiceS3ToGCSOperator' + 'airflow.contrib.operators.s3_to_gcs_transfer_operator.CloudDataTransferServiceS3ToGCSOperator', ), ( 'airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator', @@ -1754,34 +1728,34 @@ ( 'airflow.utils.weekday.WeekDay', 'airflow.contrib.utils.weekday.WeekDay', - ) + ), ] LOGS = [ ( "airflow.providers.amazon.aws.log.s3_task_handler.S3TaskHandler", - "airflow.utils.log.s3_task_handler.S3TaskHandler" + "airflow.utils.log.s3_task_handler.S3TaskHandler", ), ( 'airflow.providers.amazon.aws.log.cloudwatch_task_handler.CloudwatchTaskHandler', - 'airflow.utils.log.cloudwatch_task_handler.CloudwatchTaskHandler' + 'airflow.utils.log.cloudwatch_task_handler.CloudwatchTaskHandler', ), ( 'airflow.providers.elasticsearch.log.es_task_handler.ElasticsearchTaskHandler', - 'airflow.utils.log.es_task_handler.ElasticsearchTaskHandler' + 'airflow.utils.log.es_task_handler.ElasticsearchTaskHandler', ), ( "airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverTaskHandler", - "airflow.utils.log.stackdriver_task_handler.StackdriverTaskHandler" + "airflow.utils.log.stackdriver_task_handler.StackdriverTaskHandler", ), ( "airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler", - "airflow.utils.log.gcs_task_handler.GCSTaskHandler" + "airflow.utils.log.gcs_task_handler.GCSTaskHandler", ), ( "airflow.providers.microsoft.azure.log.wasb_task_handler.WasbTaskHandler", - "airflow.utils.log.wasb_task_handler.WasbTaskHandler" - ) + "airflow.utils.log.wasb_task_handler.WasbTaskHandler", + ), ] ALL = HOOKS + OPERATORS + SECRETS + SENSORS + TRANSFERS + UTILS + LOGS diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py index f324fa753b21f..06fe03067c84f 100644 --- a/tests/executors/test_base_executor.py +++ b/tests/executors/test_base_executor.py @@ -51,9 +51,11 @@ def test_get_event_buffer(self): def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync): executor = BaseExecutor() executor.heartbeat() - calls = [mock.call('executor.open_slots', mock.ANY), - mock.call('executor.queued_tasks', mock.ANY), - mock.call('executor.running_tasks', mock.ANY)] + calls = [ + mock.call('executor.open_slots', mock.ANY), + 
mock.call('executor.queued_tasks', mock.ANY), + mock.call('executor.running_tasks', mock.ANY), + ] mock_stats_gauge.assert_has_calls(calls) def test_try_adopt_task_instances(self): diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py index feb245ef13a85..1701fed4ce736 100644 --- a/tests/executors/test_celery_executor.py +++ b/tests/executors/test_celery_executor.py @@ -49,10 +49,7 @@ def _prepare_test_bodies(): if 'CELERY_BROKER_URLS' in os.environ: - return [ - (url, ) - for url in os.environ['CELERY_BROKER_URLS'].split(',') - ] + return [(url,) for url in os.environ['CELERY_BROKER_URLS'].split(',')] return [(conf.get('celery', 'BROKER_URL'))] @@ -97,7 +94,6 @@ def _prepare_app(broker_url=None, execute=None): class TestCeleryExecutor(unittest.TestCase): - def setUp(self) -> None: db.clear_db_runs() db.clear_db_jobs() @@ -127,12 +123,20 @@ def fake_execute_command(command): execute_date = datetime.now() task_tuples_to_send = [ - (('success', 'fake_simple_ti', execute_date, 0), - None, success_command, celery_executor.celery_configuration['task_default_queue'], - celery_executor.execute_command), - (('fail', 'fake_simple_ti', execute_date, 0), - None, fail_command, celery_executor.celery_configuration['task_default_queue'], - celery_executor.execute_command) + ( + ('success', 'fake_simple_ti', execute_date, 0), + None, + success_command, + celery_executor.celery_configuration['task_default_queue'], + celery_executor.execute_command, + ), + ( + ('fail', 'fake_simple_ti', execute_date, 0), + None, + fail_command, + celery_executor.celery_configuration['task_default_queue'], + celery_executor.execute_command, + ), ] # "Enqueue" them. We don't have a real SimpleTaskInstance, so directly edit the dict @@ -145,24 +149,22 @@ def fake_execute_command(command): list(executor.tasks.keys()), [ ('success', 'fake_simple_ti', execute_date, 0), - ('fail', 'fake_simple_ti', execute_date, 0) - ] + ('fail', 'fake_simple_ti', execute_date, 0), + ], ) self.assertEqual( - executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0], - State.QUEUED + executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0], State.QUEUED ) self.assertEqual( - executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)][0], - State.QUEUED + executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)][0], State.QUEUED ) executor.end(synchronous=True) - self.assertEqual(executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0], - State.SUCCESS) - self.assertEqual(executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)][0], - State.FAILED) + self.assertEqual( + executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0], State.SUCCESS + ) + self.assertEqual(executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)][0], State.FAILED) self.assertNotIn('success', executor.tasks) self.assertNotIn('fail', executor.tasks) @@ -182,14 +184,15 @@ def fake_execute_command(): # which will cause TypeError when calling task.apply_async() executor = celery_executor.CeleryExecutor() task = BashOperator( - task_id="test", - bash_command="true", - dag=DAG(dag_id='id'), - start_date=datetime.now() + task_id="test", bash_command="true", dag=DAG(dag_id='id'), start_date=datetime.now() ) when = datetime.now() - value_tuple = 'command', 1, None, \ - SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.now())) + value_tuple = ( + 'command', + 1, + None, + SimpleTaskInstance(ti=TaskInstance(task=task, 
execution_date=datetime.now())), + ) key = ('fail', 'fake_simple_ti', when, 0) executor.queued_tasks[key] = value_tuple executor.heartbeat() @@ -202,9 +205,7 @@ def test_exception_propagation(self): with _prepare_app(), self.assertLogs(celery_executor.log) as cm: executor = celery_executor.CeleryExecutor() - executor.tasks = { - 'key': FakeCeleryResult() - } + executor.tasks = {'key': FakeCeleryResult()} executor.bulk_state_fetcher._get_many_using_multiprocessing(executor.tasks.values()) self.assertTrue(any(celery_executor.CELERY_FETCH_ERR_MSG_HEADER in line for line in cm.output)) @@ -216,20 +217,21 @@ def test_exception_propagation(self): def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync): executor = celery_executor.CeleryExecutor() executor.heartbeat() - calls = [mock.call('executor.open_slots', mock.ANY), - mock.call('executor.queued_tasks', mock.ANY), - mock.call('executor.running_tasks', mock.ANY)] + calls = [ + mock.call('executor.open_slots', mock.ANY), + mock.call('executor.queued_tasks', mock.ANY), + mock.call('executor.running_tasks', mock.ANY), + ] mock_stats_gauge.assert_has_calls(calls) - @parameterized.expand(( - [['true'], ValueError], - [['airflow', 'version'], ValueError], - [['airflow', 'tasks', 'run'], None] - )) + @parameterized.expand( + ([['true'], ValueError], [['airflow', 'version'], ValueError], [['airflow', 'tasks', 'run'], None]) + ) def test_command_validation(self, command, expected_exception): # Check that we validate _on the receiving_ side, not just sending side - with mock.patch('airflow.executors.celery_executor._execute_in_subprocess') as mock_subproc, \ - mock.patch('airflow.executors.celery_executor._execute_in_fork') as mock_fork: + with mock.patch( + 'airflow.executors.celery_executor._execute_in_subprocess' + ) as mock_subproc, mock.patch('airflow.executors.celery_executor._execute_in_fork') as mock_fork: if expected_exception: with pytest.raises(expected_exception): celery_executor.execute_command(command) @@ -238,8 +240,7 @@ def test_command_validation(self, command, expected_exception): else: celery_executor.execute_command(command) # One of these should be called. 
- assert mock_subproc.call_args == ((command,),) or \ - mock_fork.call_args == ((command,),) + assert mock_subproc.call_args == ((command,),) or mock_fork.call_args == ((command,),) @pytest.mark.backend("mysql", "postgres") def test_try_adopt_task_instances_none(self): @@ -289,8 +290,8 @@ def test_try_adopt_task_instances(self): dict(executor.adopted_task_timeouts), { key_1: queued_dttm + executor.task_adoption_timeout, - key_2: queued_dttm + executor.task_adoption_timeout - } + key_2: queued_dttm + executor.task_adoption_timeout, + }, ) self.assertEqual(executor.tasks, {key_1: AsyncResult("231"), key_2: AsyncResult("232")}) self.assertEqual(not_adopted_tis, []) @@ -313,14 +314,11 @@ def test_check_for_stalled_adopted_tasks(self): executor = celery_executor.CeleryExecutor() executor.adopted_task_timeouts = { key_1: queued_dttm + executor.task_adoption_timeout, - key_2: queued_dttm + executor.task_adoption_timeout + key_2: queued_dttm + executor.task_adoption_timeout, } executor.tasks = {key_1: AsyncResult("231"), key_2: AsyncResult("232")} executor.sync() - self.assertEqual( - executor.event_buffer, - {key_1: (State.FAILED, None), key_2: (State.FAILED, None)} - ) + self.assertEqual(executor.event_buffer, {key_1: (State.FAILED, None), key_2: (State.FAILED, None)}) self.assertEqual(executor.tasks, {}) self.assertEqual(executor.adopted_task_timeouts, {}) @@ -350,10 +348,10 @@ def __ne__(self, other): class TestBulkStateFetcher(unittest.TestCase): - - @mock.patch("celery.backends.base.BaseKeyValueStoreBackend.mget", return_value=[ - json.dumps({"status": "SUCCESS", "task_id": "123"}) - ]) + @mock.patch( + "celery.backends.base.BaseKeyValueStoreBackend.mget", + return_value=[json.dumps({"status": "SUCCESS", "task_id": "123"})], + ) @pytest.mark.integration("redis") @pytest.mark.integration("rabbitmq") @pytest.mark.backend("mysql", "postgres") @@ -362,10 +360,12 @@ def test_should_support_kv_backend(self, mock_mget): mock_backend = BaseKeyValueStoreBackend(app=celery_executor.app) with mock.patch.object(celery_executor.app, 'backend', mock_backend): fetcher = BulkStateFetcher() - result = fetcher.get_many([ - mock.MagicMock(task_id="123"), - mock.MagicMock(task_id="456"), - ]) + result = fetcher.get_many( + [ + mock.MagicMock(task_id="123"), + mock.MagicMock(task_id="456"), + ] + ) # Assert called - ignore order mget_args, _ = mock_mget.call_args @@ -389,10 +389,12 @@ def test_should_support_db_backend(self, mock_session): ] fetcher = BulkStateFetcher() - result = fetcher.get_many([ - mock.MagicMock(task_id="123"), - mock.MagicMock(task_id="456"), - ]) + result = fetcher.get_many( + [ + mock.MagicMock(task_id="123"), + mock.MagicMock(task_id="456"), + ] + ) self.assertEqual(result, {'123': ('SUCCESS', None), '456': ("PENDING", None)}) @@ -405,9 +407,11 @@ def test_should_support_base_backend(self): with mock.patch.object(celery_executor.app, 'backend', mock_backend): fetcher = BulkStateFetcher(1) - result = fetcher.get_many([ - ClassWithCustomAttributes(task_id="123", state='SUCCESS'), - ClassWithCustomAttributes(task_id="456", state="PENDING"), - ]) + result = fetcher.get_many( + [ + ClassWithCustomAttributes(task_id="123", state='SUCCESS'), + ClassWithCustomAttributes(task_id="456", state="PENDING"), + ] + ) self.assertEqual(result, {'123': ('SUCCESS', None), '456': ("PENDING", None)}) diff --git a/tests/executors/test_celery_kubernetes_executor.py b/tests/executors/test_celery_kubernetes_executor.py index a4f4852da11b0..cc8a958b2ed5a 100644 --- 
a/tests/executors/test_celery_kubernetes_executor.py +++ b/tests/executors/test_celery_kubernetes_executor.py @@ -74,7 +74,8 @@ def when_using_k8s_executor(): cke.queue_command(simple_task_instance, command, priority, queue) k8s_executor_mock.queue_command.assert_called_once_with( - simple_task_instance, command, priority, queue) + simple_task_instance, command, priority, queue + ) celery_executor_mock.queue_command.assert_not_called() def when_using_celery_executor(): @@ -88,7 +89,8 @@ def when_using_celery_executor(): cke.queue_command(simple_task_instance, command, priority, queue) celery_executor_mock.queue_command.assert_called_once_with( - simple_task_instance, command, priority, queue) + simple_task_instance, command, priority, queue + ) k8s_executor_mock.queue_command.assert_not_called() when_using_k8s_executor() @@ -121,7 +123,7 @@ def when_using_k8s_executor(): ignore_task_deps, ignore_ti_state, pool, - cfg_path + cfg_path, ) k8s_executor_mock.queue_task_instance.assert_called_once_with( @@ -133,7 +135,7 @@ def when_using_k8s_executor(): ignore_task_deps, ignore_ti_state, pool, - cfg_path + cfg_path, ) celery_executor_mock.queue_task_instance.assert_not_called() @@ -154,7 +156,7 @@ def when_using_celery_executor(): ignore_task_deps, ignore_ti_state, pool, - cfg_path + cfg_path, ) k8s_executor_mock.queue_task_instance.assert_not_called() @@ -167,7 +169,7 @@ def when_using_celery_executor(): ignore_task_deps, ignore_ti_state, pool, - cfg_path + cfg_path, ) when_using_k8s_executor() diff --git a/tests/executors/test_dask_executor.py b/tests/executors/test_dask_executor.py index 2f43271acb4a8..09a22f7bedbc3 100644 --- a/tests/executors/test_dask_executor.py +++ b/tests/executors/test_dask_executor.py @@ -26,6 +26,7 @@ try: from distributed import LocalCluster + # utility functions imported from the dask testing suite to instantiate a test # cluster for tls tests from distributed.utils_test import cluster as dask_testing_cluster, get_cert, tls_security @@ -38,7 +39,6 @@ class TestBaseDask(unittest.TestCase): - def assert_tasks_on_executor(self, executor): success_command = ['airflow', 'tasks', 'run', '--help'] @@ -49,10 +49,8 @@ def assert_tasks_on_executor(self, executor): executor.execute_async(key='success', command=success_command) executor.execute_async(key='fail', command=fail_command) - success_future = next( - k for k, v in executor.futures.items() if v == 'success') - fail_future = next( - k for k, v in executor.futures.items() if v == 'fail') + success_future = next(k for k, v in executor.futures.items() if v == 'success') + fail_future = next(k for k, v in executor.futures.items() if v == 'fail') # wait for the futures to execute, with a timeout timeout = timezone.utcnow() + timedelta(seconds=30) @@ -60,7 +58,8 @@ def assert_tasks_on_executor(self, executor): if timezone.utcnow() > timeout: raise ValueError( 'The futures should have finished; there is probably ' - 'an error communicating with the Dask cluster.') + 'an error communicating with the Dask cluster.' 
+ ) # both tasks should have finished self.assertTrue(success_future.done()) @@ -72,7 +71,6 @@ def assert_tasks_on_executor(self, executor): class TestDaskExecutor(TestBaseDask): - def setUp(self): self.dagbag = DagBag(include_examples=True) self.cluster = LocalCluster() @@ -92,8 +90,8 @@ def test_backfill_integration(self): start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_first_depends_on_past=True, - executor=DaskExecutor( - cluster_address=self.cluster.scheduler_address)) + executor=DaskExecutor(cluster_address=self.cluster.scheduler_address), + ) job.run() def tearDown(self): @@ -101,21 +99,22 @@ def tearDown(self): class TestDaskExecutorTLS(TestBaseDask): - def setUp(self): self.dagbag = DagBag(include_examples=True) - @conf_vars({ - ('dask', 'tls_ca'): get_cert('tls-ca-cert.pem'), - ('dask', 'tls_cert'): get_cert('tls-key-cert.pem'), - ('dask', 'tls_key'): get_cert('tls-key.pem'), - }) + @conf_vars( + { + ('dask', 'tls_ca'): get_cert('tls-ca-cert.pem'), + ('dask', 'tls_cert'): get_cert('tls-key-cert.pem'), + ('dask', 'tls_key'): get_cert('tls-key.pem'), + } + ) def test_tls(self): # These use test certs that ship with dask/distributed and should not be # used in production with dask_testing_cluster( worker_kwargs={'security': tls_security(), "protocol": "tls"}, - scheduler_kwargs={'security': tls_security(), "protocol": "tls"} + scheduler_kwargs={'security': tls_security(), "protocol": "tls"}, ) as (cluster, _): executor = DaskExecutor(cluster_address=cluster['address']) @@ -133,7 +132,9 @@ def test_tls(self): def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync): executor = DaskExecutor() executor.heartbeat() - calls = [mock.call('executor.open_slots', mock.ANY), - mock.call('executor.queued_tasks', mock.ANY), - mock.call('executor.running_tasks', mock.ANY)] + calls = [ + mock.call('executor.open_slots', mock.ANY), + mock.call('executor.queued_tasks', mock.ANY), + mock.call('executor.running_tasks', mock.ANY), + ] mock_stats_gauge.assert_has_calls(calls) diff --git a/tests/executors/test_executor_loader.py b/tests/executors/test_executor_loader.py index 60f4bcd7e0d5d..63ef8dd71961f 100644 --- a/tests/executors/test_executor_loader.py +++ b/tests/executors/test_executor_loader.py @@ -38,44 +38,37 @@ class FakePlugin(plugins_manager.AirflowPlugin): class TestExecutorLoader(unittest.TestCase): - def setUp(self) -> None: ExecutorLoader._default_executor = None def tearDown(self) -> None: ExecutorLoader._default_executor = None - @parameterized.expand([ - ("CeleryExecutor", ), - ("CeleryKubernetesExecutor", ), - ("DebugExecutor", ), - ("KubernetesExecutor", ), - ("LocalExecutor", ), - ]) + @parameterized.expand( + [ + ("CeleryExecutor",), + ("CeleryKubernetesExecutor",), + ("DebugExecutor",), + ("KubernetesExecutor",), + ("LocalExecutor",), + ] + ) def test_should_support_executor_from_core(self, executor_name): - with conf_vars({ - ("core", "executor"): executor_name - }): + with conf_vars({("core", "executor"): executor_name}): executor = ExecutorLoader.get_default_executor() self.assertIsNotNone(executor) self.assertEqual(executor_name, executor.__class__.__name__) - @mock.patch("airflow.plugins_manager.plugins", [ - FakePlugin() - ]) + @mock.patch("airflow.plugins_manager.plugins", [FakePlugin()]) @mock.patch("airflow.plugins_manager.executors_modules", None) def test_should_support_plugins(self): - with conf_vars({ - ("core", "executor"): f"{TEST_PLUGIN_NAME}.FakeExecutor" - }): + with conf_vars({("core", "executor"): 
f"{TEST_PLUGIN_NAME}.FakeExecutor"}): executor = ExecutorLoader.get_default_executor() self.assertIsNotNone(executor) self.assertEqual("FakeExecutor", executor.__class__.__name__) def test_should_support_custom_path(self): - with conf_vars({ - ("core", "executor"): "tests.executors.test_executor_loader.FakeExecutor" - }): + with conf_vars({("core", "executor"): "tests.executors.test_executor_loader.FakeExecutor"}): executor = ExecutorLoader.get_default_executor() self.assertIsNotNone(executor) self.assertEqual("FakeExecutor", executor.__class__.__name__) diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index 4c38352ceef25..9765669d32aa9 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -32,7 +32,9 @@ from kubernetes.client.rest import ApiException from airflow.executors.kubernetes_executor import ( - AirflowKubernetesScheduler, KubernetesExecutor, create_pod_id, + AirflowKubernetesScheduler, + KubernetesExecutor, + create_pod_id, ) from airflow.kubernetes import pod_generator from airflow.kubernetes.pod_generator import PodGenerator @@ -57,38 +59,29 @@ def _cases(self): ("my.dag.id", "my.task.id"), ("MYDAGID", "MYTASKID"), ("my_dag_id", "my_task_id"), - ("mydagid" * 200, "my_task_id" * 200) + ("mydagid" * 200, "my_task_id" * 200), ] - cases.extend([ - (self._gen_random_string(seed, 200), self._gen_random_string(seed, 200)) - for seed in range(100) - ]) + cases.extend( + [(self._gen_random_string(seed, 200), self._gen_random_string(seed, 200)) for seed in range(100)] + ) return cases @staticmethod def _is_valid_pod_id(name): regex = r"^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$" - return ( - len(name) <= 253 and - all(ch.lower() == ch for ch in name) and - re.match(regex, name)) + return len(name) <= 253 and all(ch.lower() == ch for ch in name) and re.match(regex, name) @staticmethod def _is_safe_label_value(value): regex = r'^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$' - return ( - len(value) <= 63 and - re.match(regex, value)) + return len(value) <= 63 and re.match(regex, value) - @unittest.skipIf(AirflowKubernetesScheduler is None, - 'kubernetes python package is not installed') + @unittest.skipIf(AirflowKubernetesScheduler is None, 'kubernetes python package is not installed') def test_create_pod_id(self): for dag_id, task_id in self._cases(): - pod_name = PodGenerator.make_unique_pod_id( - create_pod_id(dag_id, task_id) - ) + pod_name = PodGenerator.make_unique_pod_id(create_pod_id(dag_id, task_id)) self.assertTrue(self._is_valid_pod_id(pod_name)) def test_make_safe_label_value(self): @@ -98,23 +91,16 @@ def test_make_safe_label_value(self): safe_task_id = pod_generator.make_safe_label_value(task_id) self.assertTrue(self._is_safe_label_value(safe_task_id)) dag_id = "my_dag_id" - self.assertEqual( - dag_id, - pod_generator.make_safe_label_value(dag_id) - ) + self.assertEqual(dag_id, pod_generator.make_safe_label_value(dag_id)) dag_id = "my_dag_id_" + "a" * 64 self.assertEqual( - "my_dag_id_" + "a" * 43 + "-0ce114c45", - pod_generator.make_safe_label_value(dag_id) + "my_dag_id_" + "a" * 43 + "-0ce114c45", pod_generator.make_safe_label_value(dag_id) ) def test_execution_date_serialize_deserialize(self): datetime_obj = datetime.now() - serialized_datetime = \ - pod_generator.datetime_to_label_safe_datestring( - datetime_obj) - new_datetime_obj = pod_generator.label_safe_datestring_to_datetime( - serialized_datetime) + serialized_datetime = 
pod_generator.datetime_to_label_safe_datestring(datetime_obj) + new_datetime_obj = pod_generator.label_safe_datestring_to_datetime(serialized_datetime) self.assertEqual(datetime_obj, new_datetime_obj) @@ -129,27 +115,27 @@ def setUp(self) -> None: self.kubernetes_executor = KubernetesExecutor() self.kubernetes_executor.job_id = "5" - @unittest.skipIf(AirflowKubernetesScheduler is None, - 'kubernetes python package is not installed') + @unittest.skipIf(AirflowKubernetesScheduler is None, 'kubernetes python package is not installed') @mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher') @mock.patch('airflow.executors.kubernetes_executor.get_kube_client') def test_run_next_exception(self, mock_get_kube_client, mock_kubernetes_job_watcher): import sys + path = sys.path[0] + '/tests/kubernetes/pod_generator_base_with_secrets.yaml' # When a quota is exceeded this is the ApiException we get response = HTTPResponse( body='{"kind": "Status", "apiVersion": "v1", "metadata": {}, "status": "Failure", ' - '"message": "pods \\"podname\\" is forbidden: exceeded quota: compute-resources, ' - 'requested: limits.memory=4Gi, used: limits.memory=6508Mi, limited: limits.memory=10Gi", ' - '"reason": "Forbidden", "details": {"name": "podname", "kind": "pods"}, "code": 403}') + '"message": "pods \\"podname\\" is forbidden: exceeded quota: compute-resources, ' + 'requested: limits.memory=4Gi, used: limits.memory=6508Mi, limited: limits.memory=10Gi", ' + '"reason": "Forbidden", "details": {"name": "podname", "kind": "pods"}, "code": 403}' + ) response.status = 403 response.reason = "Forbidden" # A mock kube_client that throws errors when making a pod mock_kube_client = mock.patch('kubernetes.client.CoreV1Api', autospec=True) - mock_kube_client.create_namespaced_pod = mock.MagicMock( - side_effect=ApiException(http_resp=response)) + mock_kube_client.create_namespaced_pod = mock.MagicMock(side_effect=ApiException(http_resp=response)) mock_get_kube_client.return_value = mock_kube_client mock_api_client = mock.MagicMock() mock_api_client.sanitize_for_serialization.return_value = {} @@ -163,10 +149,11 @@ def test_run_next_exception(self, mock_get_kube_client, mock_kubernetes_job_watc kubernetes_executor.start() # Execute a task while the Api Throws errors try_number = 1 - kubernetes_executor.execute_async(key=('dag', 'task', datetime.utcnow(), try_number), - queue=None, - command=['airflow', 'tasks', 'run', 'true', 'some_parameter'], - ) + kubernetes_executor.execute_async( + key=('dag', 'task', datetime.utcnow(), try_number), + queue=None, + command=['airflow', 'tasks', 'run', 'true', 'some_parameter'], + ) kubernetes_executor.sync() kubernetes_executor.sync() @@ -188,9 +175,11 @@ def test_run_next_exception(self, mock_get_kube_client, mock_kubernetes_job_watc def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync, mock_kube_config): executor = self.kubernetes_executor executor.heartbeat() - calls = [mock.call('executor.open_slots', mock.ANY), - mock.call('executor.queued_tasks', mock.ANY), - mock.call('executor.running_tasks', mock.ANY)] + calls = [ + mock.call('executor.open_slots', mock.ANY), + mock.call('executor.queued_tasks', mock.ANY), + mock.call('executor.running_tasks', mock.ANY), + ] mock_stats_gauge.assert_has_calls(calls) @mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher') @@ -218,10 +207,7 @@ def test_change_state_success(self, mock_delete_pod, mock_get_kube_client, mock_ 
@mock.patch('airflow.executors.kubernetes_executor.get_kube_client') @mock.patch('airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod') def test_change_state_failed_no_deletion( - self, - mock_delete_pod, - mock_get_kube_client, - mock_kubernetes_job_watcher + self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher ): executor = self.kubernetes_executor executor.kube_config.delete_worker_pods = False @@ -232,13 +218,15 @@ def test_change_state_failed_no_deletion( executor._change_state(key, State.FAILED, 'pod_id', 'default') self.assertTrue(executor.event_buffer[key][0] == State.FAILED) mock_delete_pod.assert_not_called() -# pylint: enable=unused-argument + + # pylint: enable=unused-argument @mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher') @mock.patch('airflow.executors.kubernetes_executor.get_kube_client') @mock.patch('airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod') - def test_change_state_skip_pod_deletion(self, mock_delete_pod, mock_get_kube_client, - mock_kubernetes_job_watcher): + def test_change_state_skip_pod_deletion( + self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher + ): test_time = timezone.utcnow() executor = self.kubernetes_executor executor.kube_config.delete_worker_pods = False @@ -253,8 +241,9 @@ def test_change_state_skip_pod_deletion(self, mock_delete_pod, mock_get_kube_cli @mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher') @mock.patch('airflow.executors.kubernetes_executor.get_kube_client') @mock.patch('airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod') - def test_change_state_failed_pod_deletion(self, mock_delete_pod, mock_get_kube_client, - mock_kubernetes_job_watcher): + def test_change_state_failed_pod_deletion( + self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher + ): executor = self.kubernetes_executor executor.kube_config.delete_worker_pods_on_failure = True @@ -271,24 +260,22 @@ def test_adopt_launched_task(self, mock_kube_client): pod_ids = {"dagtask": {}} pod = k8s.V1Pod( metadata=k8s.V1ObjectMeta( - name="foo", - labels={ - "airflow-worker": "bar", - "dag_id": "dag", - "task_id": "task" - } + name="foo", labels={"airflow-worker": "bar", "dag_id": "dag", "task_id": "task"} ) ) executor.adopt_launched_task(mock_kube_client, pod=pod, pod_ids=pod_ids) self.assertEqual( mock_kube_client.patch_namespaced_pod.call_args[1], - {'body': {'metadata': {'labels': {'airflow-worker': 'modified', - 'dag_id': 'dag', - 'task_id': 'task'}, - 'name': 'foo'} - }, - 'name': 'foo', - 'namespace': None} + { + 'body': { + 'metadata': { + 'labels': {'airflow-worker': 'modified', 'dag_id': 'dag', 'task_id': 'task'}, + 'name': 'foo', + } + }, + 'name': 'foo', + 'namespace': None, + }, ) self.assertDictEqual(pod_ids, {}) @@ -306,12 +293,7 @@ def test_not_adopt_unassigned_task(self, mock_kube_client): pod_ids = {"foobar": {}} pod = k8s.V1Pod( metadata=k8s.V1ObjectMeta( - name="foo", - labels={ - "airflow-worker": "bar", - "dag_id": "dag", - "task_id": "task" - } + name="foo", labels={"airflow-worker": "bar", "dag_id": "dag", "task_id": "task"} ) ) executor.adopt_launched_task(mock_kube_client, pod=pod, pod_ids=pod_ids) diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py index 20722bf478fd4..b12f783c1b5f7 100644 --- a/tests/executors/test_local_executor.py +++ b/tests/executors/test_local_executor.py @@ -91,14 +91,16 @@ def _test_execute(self, parallelism, 
success_command, fail_command): self.assertEqual(executor.workers_used, expected) def test_execution_subprocess_unlimited_parallelism(self): - with mock.patch.object(settings, 'EXECUTE_TASKS_NEW_PYTHON_INTERPRETER', - new_callable=mock.PropertyMock) as option: + with mock.patch.object( + settings, 'EXECUTE_TASKS_NEW_PYTHON_INTERPRETER', new_callable=mock.PropertyMock + ) as option: option.return_value = True self.execution_parallelism_subprocess(parallelism=0) # pylint: disable=no-value-for-parameter def test_execution_subprocess_limited_parallelism(self): - with mock.patch.object(settings, 'EXECUTE_TASKS_NEW_PYTHON_INTERPRETER', - new_callable=mock.PropertyMock) as option: + with mock.patch.object( + settings, 'EXECUTE_TASKS_NEW_PYTHON_INTERPRETER', new_callable=mock.PropertyMock + ) as option: option.return_value = True self.execution_parallelism_subprocess(parallelism=2) # pylint: disable=no-value-for-parameter @@ -116,7 +118,9 @@ def test_execution_limited_parallelism_fork(self): def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync): executor = LocalExecutor() executor.heartbeat() - calls = [mock.call('executor.open_slots', mock.ANY), - mock.call('executor.queued_tasks', mock.ANY), - mock.call('executor.running_tasks', mock.ANY)] + calls = [ + mock.call('executor.open_slots', mock.ANY), + mock.call('executor.queued_tasks', mock.ANY), + mock.call('executor.running_tasks', mock.ANY), + ] mock_stats_gauge.assert_has_calls(calls) diff --git a/tests/executors/test_sequential_executor.py b/tests/executors/test_sequential_executor.py index 0189a6ebe401f..bb97e98b2a018 100644 --- a/tests/executors/test_sequential_executor.py +++ b/tests/executors/test_sequential_executor.py @@ -23,14 +23,15 @@ class TestSequentialExecutor(unittest.TestCase): - @mock.patch('airflow.executors.sequential_executor.SequentialExecutor.sync') @mock.patch('airflow.executors.base_executor.BaseExecutor.trigger_tasks') @mock.patch('airflow.executors.base_executor.Stats.gauge') def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync): executor = SequentialExecutor() executor.heartbeat() - calls = [mock.call('executor.open_slots', mock.ANY), - mock.call('executor.queued_tasks', mock.ANY), - mock.call('executor.running_tasks', mock.ANY)] + calls = [ + mock.call('executor.open_slots', mock.ANY), + mock.call('executor.queued_tasks', mock.ANY), + mock.call('executor.running_tasks', mock.ANY), + ] mock_stats_gauge.assert_has_calls(calls) diff --git a/tests/hooks/test_dbapi_hook.py b/tests/hooks/test_dbapi_hook.py index 5e16aaf4cde80..a122e01bdfc02 100644 --- a/tests/hooks/test_dbapi_hook.py +++ b/tests/hooks/test_dbapi_hook.py @@ -25,7 +25,6 @@ class TestDbApiHook(unittest.TestCase): - def setUp(self): super().setUp() @@ -45,8 +44,7 @@ def get_conn(self): def test_get_records(self): statement = "SQL" - rows = [("hello",), - ("world",)] + rows = [("hello",), ("world",)] self.cur.fetchall.return_value = rows @@ -59,8 +57,7 @@ def test_get_records(self): def test_get_records_parameters(self): statement = "SQL" parameters = ["X", "Y", "Z"] - rows = [("hello",), - ("world",)] + rows = [("hello",), ("world",)] self.cur.fetchall.return_value = rows @@ -83,8 +80,7 @@ def test_get_records_exception(self): def test_insert_rows(self): table = "table" - rows = [("hello",), - ("world",)] + rows = [("hello",), ("world",)] self.db_hook.insert_rows(table, rows) @@ -100,8 +96,7 @@ def test_insert_rows(self): def test_insert_rows_replace(self): table = "table" - rows = [("hello",), - 
("world",)] + rows = [("hello",), ("world",)] self.db_hook.insert_rows(table, rows, replace=True) @@ -117,8 +112,7 @@ def test_insert_rows_replace(self): def test_insert_rows_target_fields(self): table = "table" - rows = [("hello",), - ("world",)] + rows = [("hello",), ("world",)] target_fields = ["field"] self.db_hook.insert_rows(table, rows, target_fields) @@ -135,8 +129,7 @@ def test_insert_rows_target_fields(self): def test_insert_rows_commit_every(self): table = "table" - rows = [("hello",), - ("world",)] + rows = [("hello",), ("world",)] commit_every = 1 self.db_hook.insert_rows(table, rows, commit_every=commit_every) @@ -152,25 +145,24 @@ def test_insert_rows_commit_every(self): self.cur.execute.assert_any_call(sql, row) def test_get_uri_schema_not_none(self): - self.db_hook.get_connection = mock.MagicMock(return_value=Connection( - conn_type="conn_type", - host="host", - login="login", - password="password", - schema="schema", - port=1 - )) + self.db_hook.get_connection = mock.MagicMock( + return_value=Connection( + conn_type="conn_type", + host="host", + login="login", + password="password", + schema="schema", + port=1, + ) + ) self.assertEqual("conn_type://login:password@host:1/schema", self.db_hook.get_uri()) def test_get_uri_schema_none(self): - self.db_hook.get_connection = mock.MagicMock(return_value=Connection( - conn_type="conn_type", - host="host", - login="login", - password="password", - schema=None, - port=1 - )) + self.db_hook.get_connection = mock.MagicMock( + return_value=Connection( + conn_type="conn_type", host="host", login="login", password="password", schema=None, port=1 + ) + ) self.assertEqual("conn_type://login:password@host:1/", self.db_hook.get_uri()) def test_run_log(self): diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index b9a3b6a754a96..9898859236416 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -31,7 +31,10 @@ from airflow import settings from airflow.cli import cli_parser from airflow.exceptions import ( - AirflowException, AirflowTaskTimeout, DagConcurrencyLimitReached, NoAvailablePoolSlot, + AirflowException, + AirflowTaskTimeout, + DagConcurrencyLimitReached, + NoAvailablePoolSlot, TaskConcurrencyLimitReached, ) from airflow.jobs.backfill_job import BackfillJob @@ -53,19 +56,11 @@ @pytest.mark.heisentests class TestBackfillJob(unittest.TestCase): - def _get_dummy_dag(self, dag_id, pool=Pool.DEFAULT_POOL_NAME, task_concurrency=None): - dag = DAG( - dag_id=dag_id, - start_date=DEFAULT_DATE, - schedule_interval='@daily') + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, schedule_interval='@daily') with dag: - DummyOperator( - task_id='op', - pool=pool, - task_concurrency=task_concurrency, - dag=dag) + DummyOperator(task_id='op', pool=pool, task_concurrency=task_concurrency, dag=dag) dag.clear() return dag @@ -105,7 +100,7 @@ def test_unfinished_dag_runs_set_to_failed(self): dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=8), - ignore_first_depends_on_past=True + ignore_first_depends_on_past=True, ) job._set_unfinished_dag_runs_to_failed([dag_run]) @@ -129,7 +124,7 @@ def test_dag_run_with_finished_tasks_set_to_success(self): dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=8), - ignore_first_depends_on_past=True + ignore_first_depends_on_past=True, ) job._set_unfinished_dag_runs_to_failed([dag_run]) @@ -154,10 +149,7 @@ def test_trigger_controller_dag(self): self.assertFalse(task_instances_list) job = 
BackfillJob( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_first_depends_on_past=True + dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_first_depends_on_past=True ) job.run() @@ -181,7 +173,7 @@ def test_backfill_multi_dates(self): start_date=DEFAULT_DATE, end_date=end_date, executor=executor, - ignore_first_depends_on_past=True + ignore_first_depends_on_past=True, ) job.run() @@ -201,20 +193,19 @@ def test_backfill_multi_dates(self): ("run_this_last", end_date), ] self.assertListEqual( - [((dag.dag_id, task_id, when, 1), (State.SUCCESS, None)) - for (task_id, when) in expected_execution_order], - executor.sorted_tasks + [ + ((dag.dag_id, task_id, when, 1), (State.SUCCESS, None)) + for (task_id, when) in expected_execution_order + ], + executor.sorted_tasks, ) session = settings.Session() - drs = session.query(DagRun).filter( - DagRun.dag_id == dag.dag_id - ).order_by(DagRun.execution_date).all() + drs = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).order_by(DagRun.execution_date).all() self.assertTrue(drs[0].execution_date == DEFAULT_DATE) self.assertTrue(drs[0].state == State.SUCCESS) - self.assertTrue(drs[1].execution_date == - DEFAULT_DATE + datetime.timedelta(days=1)) + self.assertTrue(drs[1].execution_date == DEFAULT_DATE + datetime.timedelta(days=1)) self.assertTrue(drs[1].state == State.SUCCESS) dag.clear() @@ -241,8 +232,7 @@ def test_backfill_multi_dates(self): ], [ "example_bash_operator", - ("runme_0", "runme_1", "runme_2", "also_run_this", "run_after_loop", - "run_this_last"), + ("runme_0", "runme_1", "runme_2", "also_run_this", "run_after_loop", "run_this_last"), ], [ "example_skip_dag", @@ -277,13 +267,16 @@ def test_backfill_examples(self, dag_id, expected_execution_order): start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor, - ignore_first_depends_on_past=True) + ignore_first_depends_on_past=True, + ) job.run() self.assertListEqual( - [((dag_id, task_id, DEFAULT_DATE, 1), (State.SUCCESS, None)) - for task_id in expected_execution_order], - executor.sorted_tasks + [ + ((dag_id, task_id, DEFAULT_DATE, 1), (State.SUCCESS, None)) + for task_id in expected_execution_order + ], + executor.sorted_tasks, ) def test_backfill_conf(self): @@ -292,11 +285,13 @@ def test_backfill_conf(self): executor = MockExecutor() conf_ = json.loads("""{"key": "value"}""") - job = BackfillJob(dag=dag, - executor=executor, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - conf=conf_) + job = BackfillJob( + dag=dag, + executor=executor, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + conf=conf_, + ) job.run() dr = DagRun.find(dag_id='test_backfill_conf') @@ -548,10 +543,7 @@ def test_backfill_respect_pool_limit(self, mock_log): self.assertGreater(times_pool_limit_reached_in_debug, 0) def test_backfill_run_rescheduled(self): - dag = DAG( - dag_id='test_backfill_run_rescheduled', - start_date=DEFAULT_DATE, - schedule_interval='@daily') + dag = DAG(dag_id='test_backfill_run_rescheduled', start_date=DEFAULT_DATE, schedule_interval='@daily') with dag: DummyOperator( @@ -563,142 +555,130 @@ def test_backfill_run_rescheduled(self): executor = MockExecutor() - job = BackfillJob(dag=dag, - executor=executor, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - ) + job = BackfillJob( + dag=dag, + executor=executor, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + ) job.run() - ti = 
TI(task=dag.get_task('test_backfill_run_rescheduled_task-1'), - execution_date=DEFAULT_DATE) + ti = TI(task=dag.get_task('test_backfill_run_rescheduled_task-1'), execution_date=DEFAULT_DATE) ti.refresh_from_db() ti.set_state(State.UP_FOR_RESCHEDULE) - job = BackfillJob(dag=dag, - executor=executor, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - rerun_failed_tasks=True - ) + job = BackfillJob( + dag=dag, + executor=executor, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + rerun_failed_tasks=True, + ) job.run() - ti = TI(task=dag.get_task('test_backfill_run_rescheduled_task-1'), - execution_date=DEFAULT_DATE) + ti = TI(task=dag.get_task('test_backfill_run_rescheduled_task-1'), execution_date=DEFAULT_DATE) ti.refresh_from_db() self.assertEqual(ti.state, State.SUCCESS) def test_backfill_rerun_failed_tasks(self): - dag = DAG( - dag_id='test_backfill_rerun_failed', - start_date=DEFAULT_DATE, - schedule_interval='@daily') + dag = DAG(dag_id='test_backfill_rerun_failed', start_date=DEFAULT_DATE, schedule_interval='@daily') with dag: - DummyOperator( - task_id='test_backfill_rerun_failed_task-1', - dag=dag) + DummyOperator(task_id='test_backfill_rerun_failed_task-1', dag=dag) dag.clear() executor = MockExecutor() - job = BackfillJob(dag=dag, - executor=executor, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - ) + job = BackfillJob( + dag=dag, + executor=executor, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + ) job.run() - ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'), - execution_date=DEFAULT_DATE) + ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'), execution_date=DEFAULT_DATE) ti.refresh_from_db() ti.set_state(State.FAILED) - job = BackfillJob(dag=dag, - executor=executor, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - rerun_failed_tasks=True - ) + job = BackfillJob( + dag=dag, + executor=executor, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + rerun_failed_tasks=True, + ) job.run() - ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'), - execution_date=DEFAULT_DATE) + ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'), execution_date=DEFAULT_DATE) ti.refresh_from_db() self.assertEqual(ti.state, State.SUCCESS) def test_backfill_rerun_upstream_failed_tasks(self): dag = DAG( - dag_id='test_backfill_rerun_upstream_failed', - start_date=DEFAULT_DATE, - schedule_interval='@daily') + dag_id='test_backfill_rerun_upstream_failed', start_date=DEFAULT_DATE, schedule_interval='@daily' + ) with dag: - op1 = DummyOperator(task_id='test_backfill_rerun_upstream_failed_task-1', - dag=dag) - op2 = DummyOperator(task_id='test_backfill_rerun_upstream_failed_task-2', - dag=dag) + op1 = DummyOperator(task_id='test_backfill_rerun_upstream_failed_task-1', dag=dag) + op2 = DummyOperator(task_id='test_backfill_rerun_upstream_failed_task-2', dag=dag) op1.set_upstream(op2) dag.clear() executor = MockExecutor() - job = BackfillJob(dag=dag, - executor=executor, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - ) + job = BackfillJob( + dag=dag, + executor=executor, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + ) job.run() - ti = TI(task=dag.get_task('test_backfill_rerun_upstream_failed_task-1'), - execution_date=DEFAULT_DATE) + ti = 
TI(task=dag.get_task('test_backfill_rerun_upstream_failed_task-1'), execution_date=DEFAULT_DATE) ti.refresh_from_db() ti.set_state(State.UPSTREAM_FAILED) - job = BackfillJob(dag=dag, - executor=executor, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - rerun_failed_tasks=True - ) + job = BackfillJob( + dag=dag, + executor=executor, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + rerun_failed_tasks=True, + ) job.run() - ti = TI(task=dag.get_task('test_backfill_rerun_upstream_failed_task-1'), - execution_date=DEFAULT_DATE) + ti = TI(task=dag.get_task('test_backfill_rerun_upstream_failed_task-1'), execution_date=DEFAULT_DATE) ti.refresh_from_db() self.assertEqual(ti.state, State.SUCCESS) def test_backfill_rerun_failed_tasks_without_flag(self): - dag = DAG( - dag_id='test_backfill_rerun_failed', - start_date=DEFAULT_DATE, - schedule_interval='@daily') + dag = DAG(dag_id='test_backfill_rerun_failed', start_date=DEFAULT_DATE, schedule_interval='@daily') with dag: - DummyOperator( - task_id='test_backfill_rerun_failed_task-1', - dag=dag) + DummyOperator(task_id='test_backfill_rerun_failed_task-1', dag=dag) dag.clear() executor = MockExecutor() - job = BackfillJob(dag=dag, - executor=executor, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - ) + job = BackfillJob( + dag=dag, + executor=executor, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + ) job.run() - ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'), - execution_date=DEFAULT_DATE) + ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'), execution_date=DEFAULT_DATE) ti.refresh_from_db() ti.set_state(State.FAILED) - job = BackfillJob(dag=dag, - executor=executor, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - rerun_failed_tasks=False - ) + job = BackfillJob( + dag=dag, + executor=executor, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + rerun_failed_tasks=False, + ) with self.assertRaises(AirflowException): job.run() @@ -707,7 +687,8 @@ def test_backfill_ordered_concurrent_execute(self): dag = DAG( dag_id='test_backfill_ordered_concurrent_execute', start_date=DEFAULT_DATE, - schedule_interval="@daily") + schedule_interval="@daily", + ) with dag: op1 = DummyOperator(task_id='leave1') @@ -724,11 +705,12 @@ def test_backfill_ordered_concurrent_execute(self): dag.clear() executor = MockExecutor(parallelism=16) - job = BackfillJob(dag=dag, - executor=executor, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - ) + job = BackfillJob( + dag=dag, + executor=executor, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + ) job.run() date0 = DEFAULT_DATE @@ -748,15 +730,12 @@ def test_backfill_ordered_concurrent_execute(self): ('leave1', date2), ('leave2', date0), ('leave2', date1), - ('leave2', date2) + ('leave2', date2), ], - [('upstream_level_1', date0), ('upstream_level_1', date1), - ('upstream_level_1', date2)], - [('upstream_level_2', date0), ('upstream_level_2', date1), - ('upstream_level_2', date2)], - [('upstream_level_3', date0), ('upstream_level_3', date1), - ('upstream_level_3', date2)], - ] + [('upstream_level_1', date0), ('upstream_level_1', date1), ('upstream_level_1', date2)], + [('upstream_level_2', date0), ('upstream_level_2', date1), ('upstream_level_2', date2)], + [('upstream_level_3', date0), ('upstream_level_3', date1), ('upstream_level_3', 
date2)], + ], ) def test_backfill_pooled_tasks(self): @@ -773,11 +752,7 @@ def test_backfill_pooled_tasks(self): dag.clear() executor = MockExecutor(do_update=True) - job = BackfillJob( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - executor=executor) + job = BackfillJob(dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor) # run with timeout because this creates an infinite loop if not # caught @@ -786,9 +761,7 @@ def test_backfill_pooled_tasks(self): job.run() except AirflowTaskTimeout: pass - ti = TI( - task=dag.get_task('test_backfill_pooled_task'), - execution_date=DEFAULT_DATE) + ti = TI(task=dag.get_task('test_backfill_pooled_task'), execution_date=DEFAULT_DATE) ti.refresh_from_db() self.assertEqual(ti.state, State.SUCCESS) @@ -804,14 +777,16 @@ def test_backfill_depends_on_past(self): self.assertRaisesRegex( AirflowException, 'BackfillJob is deadlocked', - BackfillJob(dag=dag, start_date=run_date, end_date=run_date).run) + BackfillJob(dag=dag, start_date=run_date, end_date=run_date).run, + ) BackfillJob( dag=dag, start_date=run_date, end_date=run_date, executor=MockExecutor(), - ignore_first_depends_on_past=True).run() + ignore_first_depends_on_past=True, + ).run() # ti should have succeeded ti = TI(dag.tasks[0], run_date) @@ -833,10 +808,7 @@ def test_backfill_depends_on_past_backwards(self): dag.clear() executor = MockExecutor() - job = BackfillJob(dag=dag, - executor=executor, - ignore_first_depends_on_past=True, - **kwargs) + job = BackfillJob(dag=dag, executor=executor, ignore_first_depends_on_past=True, **kwargs) job.run() ti = TI(dag.get_task('test_dop_task'), end_date) @@ -846,13 +818,11 @@ def test_backfill_depends_on_past_backwards(self): # raises backwards expected_msg = 'You cannot backfill backwards because one or more tasks depend_on_past: {}'.format( - 'test_dop_task') + 'test_dop_task' + ) with self.assertRaisesRegex(AirflowException, expected_msg): executor = MockExecutor() - job = BackfillJob(dag=dag, - executor=executor, - run_backwards=True, - **kwargs) + job = BackfillJob(dag=dag, executor=executor, run_backwards=True, **kwargs) job.run() def test_cli_receives_delay_arg(self): @@ -878,7 +848,7 @@ def _get_dag_test_max_active_limits(self, dag_id, max_active_runs=1): dag_id=dag_id, start_date=DEFAULT_DATE, schedule_interval="@hourly", - max_active_runs=max_active_runs + max_active_runs=max_active_runs, ) with dag: @@ -895,18 +865,16 @@ def _get_dag_test_max_active_limits(self, dag_id, max_active_runs=1): def test_backfill_max_limit_check_within_limit(self): dag = self._get_dag_test_max_active_limits( - 'test_backfill_max_limit_check_within_limit', - max_active_runs=16) + 'test_backfill_max_limit_check_within_limit', max_active_runs=16 + ) start_date = DEFAULT_DATE - datetime.timedelta(hours=1) end_date = DEFAULT_DATE executor = MockExecutor() - job = BackfillJob(dag=dag, - start_date=start_date, - end_date=end_date, - executor=executor, - donot_pickle=True) + job = BackfillJob( + dag=dag, start_date=start_date, end_date=end_date, executor=executor, donot_pickle=True + ) job.run() dagruns = DagRun.find(dag_id=dag.dag_id) @@ -943,16 +911,14 @@ def run_backfill(cond): thread_session.close() executor = MockExecutor() - job = BackfillJob(dag=dag, - start_date=start_date, - end_date=end_date, - executor=executor, - donot_pickle=True) + job = BackfillJob( + dag=dag, start_date=start_date, end_date=end_date, executor=executor, donot_pickle=True + ) job.run() - backfill_job_thread = threading.Thread(target=run_backfill, - 
name="run_backfill", - args=(dag_run_created_cond,)) + backfill_job_thread = threading.Thread( + target=run_backfill, name="run_backfill", args=(dag_run_created_cond,) + ) dag_run_created_cond.acquire() with create_session() as session: @@ -982,23 +948,22 @@ def run_backfill(cond): dag_run_created_cond.release() def test_backfill_max_limit_check_no_count_existing(self): - dag = self._get_dag_test_max_active_limits( - 'test_backfill_max_limit_check_no_count_existing') + dag = self._get_dag_test_max_active_limits('test_backfill_max_limit_check_no_count_existing') start_date = DEFAULT_DATE end_date = DEFAULT_DATE # Existing dagrun that is within the backfill range - dag.create_dagrun(run_id="test_existing_backfill", - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE) + dag.create_dagrun( + run_id="test_existing_backfill", + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + ) executor = MockExecutor() - job = BackfillJob(dag=dag, - start_date=start_date, - end_date=end_date, - executor=executor, - donot_pickle=True) + job = BackfillJob( + dag=dag, start_date=start_date, end_date=end_date, executor=executor, donot_pickle=True + ) job.run() # BackfillJob will run since the existing DagRun does not count for the max @@ -1010,8 +975,7 @@ def test_backfill_max_limit_check_no_count_existing(self): self.assertEqual(State.SUCCESS, dagruns[0].state) def test_backfill_max_limit_check_complete_loop(self): - dag = self._get_dag_test_max_active_limits( - 'test_backfill_max_limit_check_complete_loop') + dag = self._get_dag_test_max_active_limits('test_backfill_max_limit_check_complete_loop') start_date = DEFAULT_DATE - datetime.timedelta(hours=1) end_date = DEFAULT_DATE @@ -1019,11 +983,9 @@ def test_backfill_max_limit_check_complete_loop(self): # backfill job 3 times success_expected = 2 executor = MockExecutor() - job = BackfillJob(dag=dag, - start_date=start_date, - end_date=end_date, - executor=executor, - donot_pickle=True) + job = BackfillJob( + dag=dag, start_date=start_date, end_date=end_date, executor=executor, donot_pickle=True + ) job.run() success_dagruns = len(DagRun.find(dag_id=dag.dag_id, state=State.SUCCESS)) @@ -1032,10 +994,7 @@ def test_backfill_max_limit_check_complete_loop(self): self.assertEqual(0, running_dagruns) # no dag_runs in running state are left def test_sub_set_subdag(self): - dag = DAG( - 'test_sub_set_subdag', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('test_sub_set_subdag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) with dag: op1 = DummyOperator(task_id='leave1') @@ -1050,19 +1009,13 @@ def test_sub_set_subdag(self): op3.set_downstream(op4) dag.clear() - dr = dag.create_dagrun(run_id="test", - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE) + dr = dag.create_dagrun( + run_id="test", state=State.RUNNING, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE + ) executor = MockExecutor() - sub_dag = dag.sub_dag(task_ids_or_regex="leave*", - include_downstream=False, - include_upstream=False) - job = BackfillJob(dag=sub_dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - executor=executor) + sub_dag = dag.sub_dag(task_ids_or_regex="leave*", include_downstream=False, include_upstream=False) + job = BackfillJob(dag=sub_dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor) job.run() self.assertRaises(sqlalchemy.orm.exc.NoResultFound, dr.refresh_from_db) @@ -1093,10 +1046,9 @@ def test_backfill_fill_blanks(self): op6 = 
DummyOperator(task_id='op6') dag.clear() - dr = dag.create_dagrun(run_id='test', - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE) + dr = dag.create_dagrun( + run_id='test', state=State.RUNNING, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE + ) executor = MockExecutor() session = settings.Session() @@ -1119,14 +1071,8 @@ def test_backfill_fill_blanks(self): session.commit() session.close() - job = BackfillJob(dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - executor=executor) - self.assertRaisesRegex( - AirflowException, - 'Some task instances failed', - job.run) + job = BackfillJob(dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor) + self.assertRaisesRegex(AirflowException, 'Some task instances failed', job.run) self.assertRaises(sqlalchemy.orm.exc.NoResultFound, dr.refresh_from_db) # the run_id should have changed, so a refresh won't work @@ -1155,11 +1101,9 @@ def test_backfill_execute_subdag(self): start_date = timezone.utcnow() executor = MockExecutor() - job = BackfillJob(dag=subdag, - start_date=start_date, - end_date=start_date, - executor=executor, - donot_pickle=True) + job = BackfillJob( + dag=subdag, start_date=start_date, end_date=start_date, executor=executor, donot_pickle=True + ) job.run() subdag_op_task.pre_execute(context={'execution_date': start_date}) @@ -1177,8 +1121,7 @@ def test_backfill_execute_subdag(self): with create_session() as session: successful_subdag_runs = ( - session - .query(DagRun) + session.query(DagRun) .filter(DagRun.dag_id == subdag.dag_id) .filter(DagRun.execution_date == start_date) # pylint: disable=comparison-with-callable @@ -1198,42 +1141,30 @@ def test_subdag_clear_parentdag_downstream_clear(self): subdag = subdag_op_task.subdag executor = MockExecutor() - job = BackfillJob(dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - executor=executor, - donot_pickle=True) + job = BackfillJob( + dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor, donot_pickle=True + ) with timeout(seconds=30): job.run() - ti_subdag = TI( - task=dag.get_task('daily_job'), - execution_date=DEFAULT_DATE) + ti_subdag = TI(task=dag.get_task('daily_job'), execution_date=DEFAULT_DATE) ti_subdag.refresh_from_db() self.assertEqual(ti_subdag.state, State.SUCCESS) - ti_irrelevant = TI( - task=dag.get_task('daily_job_irrelevant'), - execution_date=DEFAULT_DATE) + ti_irrelevant = TI(task=dag.get_task('daily_job_irrelevant'), execution_date=DEFAULT_DATE) ti_irrelevant.refresh_from_db() self.assertEqual(ti_irrelevant.state, State.SUCCESS) - ti_downstream = TI( - task=dag.get_task('daily_job_downstream'), - execution_date=DEFAULT_DATE) + ti_downstream = TI(task=dag.get_task('daily_job_downstream'), execution_date=DEFAULT_DATE) ti_downstream.refresh_from_db() self.assertEqual(ti_downstream.state, State.SUCCESS) sdag = subdag.sub_dag( - task_ids_or_regex='daily_job_subdag_task', - include_downstream=True, - include_upstream=False) + task_ids_or_regex='daily_job_subdag_task', include_downstream=True, include_upstream=False + ) - sdag.clear( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - include_parentdag=True) + sdag.clear(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, include_parentdag=True) ti_subdag.refresh_from_db() self.assertEqual(State.NONE, ti_subdag.state) @@ -1257,16 +1188,13 @@ def test_backfill_execute_subdag_with_removed_task(self): subdag = dag.get_task('section-1').subdag executor = MockExecutor() - job = BackfillJob(dag=subdag, - 
start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - executor=executor, - donot_pickle=True) + job = BackfillJob( + dag=subdag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor, donot_pickle=True + ) removed_task_ti = TI( - task=DummyOperator(task_id='removed_task'), - execution_date=DEFAULT_DATE, - state=State.REMOVED) + task=DummyOperator(task_id='removed_task'), execution_date=DEFAULT_DATE, state=State.REMOVED + ) removed_task_ti.dag_id = subdag.dag_id session = settings.Session() @@ -1277,10 +1205,13 @@ def test_backfill_execute_subdag_with_removed_task(self): job.run() for task in subdag.tasks: - instance = session.query(TI).filter( - TI.dag_id == subdag.dag_id, - TI.task_id == task.task_id, - TI.execution_date == DEFAULT_DATE).first() + instance = ( + session.query(TI) + .filter( + TI.dag_id == subdag.dag_id, TI.task_id == task.task_id, TI.execution_date == DEFAULT_DATE + ) + .first() + ) self.assertIsNotNone(instance) self.assertEqual(instance.state, State.SUCCESS) @@ -1292,23 +1223,20 @@ def test_backfill_execute_subdag_with_removed_task(self): dag.clear() def test_update_counters(self): - dag = DAG( - dag_id='test_manage_executor_state', - start_date=DEFAULT_DATE) + dag = DAG(dag_id='test_manage_executor_state', start_date=DEFAULT_DATE) - task1 = DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') + task1 = DummyOperator(task_id='dummy', dag=dag, owner='airflow') job = BackfillJob(dag=dag) session = settings.Session() - dr = dag.create_dagrun(run_type=DagRunType.SCHEDULED, - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) ti = TI(task1, dr.execution_date) ti.refresh_from_db() @@ -1389,12 +1317,8 @@ def test_update_counters(self): session.close() def test_dag_get_run_dates(self): - def get_test_dag_for_backfill(schedule_interval=None): - dag = DAG( - dag_id='test_get_dates', - start_date=DEFAULT_DATE, - schedule_interval=schedule_interval) + dag = DAG(dag_id='test_get_dates', start_date=DEFAULT_DATE, schedule_interval=schedule_interval) DummyOperator( task_id='dummy', dag=dag, @@ -1403,18 +1327,23 @@ def get_test_dag_for_backfill(schedule_interval=None): return dag test_dag = get_test_dag_for_backfill() - self.assertEqual([DEFAULT_DATE], test_dag.get_run_dates( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE)) + self.assertEqual( + [DEFAULT_DATE], test_dag.get_run_dates(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + ) test_dag = get_test_dag_for_backfill(schedule_interval="@hourly") - self.assertEqual([DEFAULT_DATE - datetime.timedelta(hours=3), - DEFAULT_DATE - datetime.timedelta(hours=2), - DEFAULT_DATE - datetime.timedelta(hours=1), - DEFAULT_DATE], - test_dag.get_run_dates( - start_date=DEFAULT_DATE - datetime.timedelta(hours=3), - end_date=DEFAULT_DATE, )) + self.assertEqual( + [ + DEFAULT_DATE - datetime.timedelta(hours=3), + DEFAULT_DATE - datetime.timedelta(hours=2), + DEFAULT_DATE - datetime.timedelta(hours=1), + DEFAULT_DATE, + ], + test_dag.get_run_dates( + start_date=DEFAULT_DATE - datetime.timedelta(hours=3), + end_date=DEFAULT_DATE, + ), + ) def test_backfill_run_backwards(self): dag = self.dagbag.get_dag("test_start_date_scheduling") @@ -1427,14 +1356,17 @@ def test_backfill_run_backwards(self): dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=1), - run_backwards=True + run_backwards=True, ) 
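        # Descriptive note: with run_backwards=True the backfill is expected to queue the later
        # execution dates first, which the reverse-sorted queued_times assertion below verifies.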
job.run() session = settings.Session() - tis = session.query(TI).filter( - TI.dag_id == 'test_start_date_scheduling' and TI.task_id == 'dummy' - ).order_by(TI.execution_date).all() + tis = ( + session.query(TI) + .filter(TI.dag_id == 'test_start_date_scheduling' and TI.task_id == 'dummy') + .order_by(TI.execution_date) + .all() + ) queued_times = [ti.queued_dttm for ti in tis] self.assertTrue(queued_times == sorted(queued_times, reverse=True)) @@ -1449,9 +1381,7 @@ def test_reset_orphaned_tasks_with_orphans(self): states = [State.QUEUED, State.SCHEDULED, State.NONE, State.RUNNING, State.SUCCESS] states_to_reset = [State.QUEUED, State.SCHEDULED, State.NONE] - dag = DAG(dag_id=prefix, - start_date=DEFAULT_DATE, - schedule_interval="@daily") + dag = DAG(dag_id=prefix, start_date=DEFAULT_DATE, schedule_interval="@daily") tasks = [] for i in range(len(states)): task_id = f"{prefix}_task_{i}" @@ -1543,9 +1473,7 @@ def test_job_id_is_assigned_to_dag_run(self): DummyOperator(task_id="dummy_task", dag=dag) job = BackfillJob( - dag=dag, - executor=MockExecutor(), - start_date=datetime.datetime.now() - datetime.timedelta(days=1) + dag=dag, executor=MockExecutor(), start_date=datetime.datetime.now() - datetime.timedelta(days=1) ) job.run() dr: DagRun = dag.get_last_dagrun() diff --git a/tests/jobs/test_base_job.py b/tests/jobs/test_base_job.py index 2d1496cb6e7a6..3ffe4adbee7b0 100644 --- a/tests/jobs/test_base_job.py +++ b/tests/jobs/test_base_job.py @@ -32,9 +32,7 @@ class MockJob(BaseJob): - __mapper_args__ = { - 'polymorphic_identity': 'MockJob' - } + __mapper_args__ = {'polymorphic_identity': 'MockJob'} def __init__(self, func, **kwargs): self.func = func @@ -54,6 +52,7 @@ def test_state_success(self): def test_state_sysexit(self): import sys + job = MockJob(lambda: sys.exit(0)) job.run() diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index f81027f724ba9..bcdf1fb5d3ce1 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -67,23 +67,19 @@ def test_localtaskjob_essential_attr(self): proper values without intervention """ dag = DAG( - 'test_localtaskjob_essential_attr', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + 'test_localtaskjob_essential_attr', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'} + ) with dag: op1 = DummyOperator(task_id='op1') dag.clear() - dr = dag.create_dagrun(run_id="test", - state=State.SUCCESS, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE) + dr = dag.create_dagrun( + run_id="test", state=State.SUCCESS, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE + ) ti = dr.get_task_instance(task_id=op1.task_id) - job1 = LocalTaskJob(task_instance=ti, - ignore_ti_state=True, - executor=SequentialExecutor()) + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) essential_attr = ["dag_id", "job_type", "start_date", "hostname"] @@ -96,28 +92,25 @@ def test_localtaskjob_essential_attr(self): @patch('os.getpid') def test_localtaskjob_heartbeat(self, mock_pid): session = settings.Session() - dag = DAG( - 'test_localtaskjob_heartbeat', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) with dag: op1 = DummyOperator(task_id='op1') dag.clear() - dr = dag.create_dagrun(run_id="test", - state=State.SUCCESS, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) + dr = dag.create_dagrun( + run_id="test", + 
state=State.SUCCESS, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) ti = dr.get_task_instance(task_id=op1.task_id, session=session) ti.state = State.RUNNING ti.hostname = "blablabla" session.commit() - job1 = LocalTaskJob(task_instance=ti, - ignore_ti_state=True, - executor=SequentialExecutor()) + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) self.assertRaises(AirflowException, job1.heartbeat_callback) mock_pid.return_value = 1 @@ -148,11 +141,13 @@ def test_heartbeat_failed_fast(self): dag = dagbag.get_dag(dag_id) task = dag.get_task(task_id) - dag.create_dagrun(run_id="test_heartbeat_failed_fast_run", - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) + dag.create_dagrun( + run_id="test_heartbeat_failed_fast_run", + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() ti.state = State.RUNNING @@ -189,11 +184,13 @@ def test_mark_success_no_kill(self): session = settings.Session() dag.clear() - dag.create_dagrun(run_id="test", - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) + dag.create_dagrun( + run_id="test", + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True) @@ -226,11 +223,13 @@ def test_localtaskjob_double_trigger(self): session = settings.Session() dag.clear() - dr = dag.create_dagrun(run_id="test", - state=State.SUCCESS, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) + dr = dag.create_dagrun( + run_id="test", + state=State.SUCCESS, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) ti = dr.get_task_instance(task_id=task.task_id, session=session) ti.state = State.RUNNING ti.hostname = get_hostname() @@ -240,9 +239,9 @@ def test_localtaskjob_double_trigger(self): ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti_run.refresh_from_db() - job1 = LocalTaskJob(task_instance=ti_run, - executor=SequentialExecutor()) + job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor()) from airflow.task.task_runner.standard_task_runner import StandardTaskRunner + with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_method: job1.run() mock_method.assert_not_called() @@ -265,16 +264,17 @@ def test_localtaskjob_maintain_heart_rate(self): session = settings.Session() dag.clear() - dag.create_dagrun(run_id="test", - state=State.SUCCESS, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) + dag.create_dagrun( + run_id="test", + state=State.SUCCESS, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti_run.refresh_from_db() - job1 = LocalTaskJob(task_instance=ti_run, - executor=SequentialExecutor()) + job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor()) # this should make sure we only heartbeat once and exit at the second # loop in _execute() @@ -285,6 +285,7 @@ def multi_return_code(): time_start = time.time() from airflow.task.task_runner.standard_task_runner import StandardTaskRunner + with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_start: with 
patch.object(StandardTaskRunner, 'return_code') as mock_ret_code: mock_ret_code.side_effect = multi_return_code @@ -331,22 +332,23 @@ def task_function(ti): task = PythonOperator( task_id='test_state_succeeded1', python_callable=task_function, - on_failure_callback=check_failure) + on_failure_callback=check_failure, + ) session = settings.Session() dag.clear() - dag.create_dagrun(run_id="test", - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) + dag.create_dagrun( + run_id="test", + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() - job1 = LocalTaskJob(task_instance=ti, - ignore_ti_state=True, - executor=SequentialExecutor()) + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) with timeout(30): # This should be _much_ shorter to run. # If you change this limit, make the timeout in the callbable above bigger @@ -355,8 +357,9 @@ def task_function(ti): ti.refresh_from_db() self.assertEqual(ti.state, State.FAILED) self.assertTrue(data['called']) - self.assertNotIn('reached_end_of_sleep', data, - 'Task should not have been allowed to run to completion') + self.assertNotIn( + 'reached_end_of_sleep', data, 'Task should not have been allowed to run to completion' + ) @pytest.mark.quarantined def test_mark_success_on_success_callback(self): @@ -367,33 +370,28 @@ def test_mark_success_on_success_callback(self): data = {'called': False} def success_callback(context): - self.assertEqual(context['dag_run'].dag_id, - 'test_mark_success') + self.assertEqual(context['dag_run'].dag_id, 'test_mark_success') data['called'] = True - dag = DAG(dag_id='test_mark_success', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) - task = DummyOperator( - task_id='test_state_succeeded1', - dag=dag, - on_success_callback=success_callback) + task = DummyOperator(task_id='test_state_succeeded1', dag=dag, on_success_callback=success_callback) session = settings.Session() dag.clear() - dag.create_dagrun(run_id="test", - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) + dag.create_dagrun( + run_id="test", + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() - job1 = LocalTaskJob(task_instance=ti, - ignore_ti_state=True, - executor=SequentialExecutor()) + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) from airflow.task.task_runner.standard_task_runner import StandardTaskRunner + job1.task_runner = StandardTaskRunner(job1) process = multiprocessing.Process(target=job1.run) process.start() diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index a7589ef57dca6..5dd1bec6785b7 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -58,8 +58,14 @@ from tests.test_utils.asserts import assert_queries_count from tests.test_utils.config import conf_vars, env_vars from tests.test_utils.db import ( - clear_db_dags, clear_db_errors, clear_db_jobs, clear_db_pools, clear_db_runs, clear_db_serialized_dags, - clear_db_sla_miss, set_default_pool_slots, + clear_db_dags, + clear_db_errors, + clear_db_jobs, + clear_db_pools, + 
clear_db_runs, + clear_db_serialized_dags, + clear_db_sla_miss, + set_default_pool_slots, ) from tests.test_utils.mock_executor import MockExecutor @@ -77,11 +83,7 @@ # files contain a DAG (otherwise Airflow will skip them) PARSEABLE_DAG_FILE_CONTENTS = '"airflow DAG"' UNPARSEABLE_DAG_FILE_CONTENTS = 'airflow DAG' -INVALID_DAG_WITH_DEPTH_FILE_CONTENTS = ( - "def something():\n" - " return airflow_DAG\n" - "something()" -) +INVALID_DAG_WITH_DEPTH_FILE_CONTENTS = "def something():\n return airflow_DAG\nsomething()" # Filename to be used for dags that are created in an ad-hoc manner and can be removed/ # created at runtime @@ -97,7 +99,6 @@ def disable_load_example(): @pytest.mark.usefixtures("disable_load_example") class TestDagFileProcessor(unittest.TestCase): - @staticmethod def clean_db(): clear_db_runs() @@ -123,7 +124,8 @@ def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timed dag_id='test_scheduler_reschedule', start_date=start_date, # Make sure it only creates a single DAG Run - end_date=end_date) + end_date=end_date, + ) dag.clear() dag.is_subdag = False with create_session() as session: @@ -150,14 +152,13 @@ def test_dag_file_processor_sla_miss_callback(self): # Create dag with a start of 1 day ago, but an sla of 0 # so we'll already have an sla_miss on the books. test_start_date = days_ago(1) - dag = DAG(dag_id='test_sla_miss', - sla_miss_callback=sla_callback, - default_args={'start_date': test_start_date, - 'sla': datetime.timedelta()}) + dag = DAG( + dag_id='test_sla_miss', + sla_miss_callback=sla_callback, + default_args={'start_date': test_start_date, 'sla': datetime.timedelta()}, + ) - task = DummyOperator(task_id='dummy', - dag=dag, - owner='airflow') + task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success')) @@ -181,10 +182,11 @@ def test_dag_file_processor_sla_miss_callback_invalid_sla(self): # so we'll already have an sla_miss on the books. # Pass anything besides a timedelta object to the sla argument. 
test_start_date = days_ago(1) - dag = DAG(dag_id='test_sla_miss', - sla_miss_callback=sla_callback, - default_args={'start_date': test_start_date, - 'sla': None}) + dag = DAG( + dag_id='test_sla_miss', + sla_miss_callback=sla_callback, + default_args={'start_date': test_start_date, 'sla': None}, + ) task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') @@ -209,10 +211,11 @@ def test_dag_file_processor_sla_miss_callback_sent_notification(self): # Create dag with a start of 2 days ago, but an sla of 1 day # ago so we'll already have an sla_miss on the books test_start_date = days_ago(2) - dag = DAG(dag_id='test_sla_miss', - sla_miss_callback=sla_callback, - default_args={'start_date': test_start_date, - 'sla': datetime.timedelta(days=1)}) + dag = DAG( + dag_id='test_sla_miss', + sla_miss_callback=sla_callback, + default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, + ) task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') @@ -220,11 +223,15 @@ def test_dag_file_processor_sla_miss_callback_sent_notification(self): session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success')) # Create an SlaMiss where notification was sent, but email was not - session.merge(SlaMiss(task_id='dummy', - dag_id='test_sla_miss', - execution_date=test_start_date, - email_sent=False, - notification_sent=True)) + session.merge( + SlaMiss( + task_id='dummy', + dag_id='test_sla_miss', + execution_date=test_start_date, + email_sent=False, + notification_sent=True, + ) + ) # Now call manage_slas and see if the sla_miss callback gets called dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) @@ -242,21 +249,18 @@ def test_dag_file_processor_sla_miss_callback_exception(self): sla_callback = MagicMock(side_effect=RuntimeError('Could not call function')) test_start_date = days_ago(2) - dag = DAG(dag_id='test_sla_miss', - sla_miss_callback=sla_callback, - default_args={'start_date': test_start_date}) + dag = DAG( + dag_id='test_sla_miss', + sla_miss_callback=sla_callback, + default_args={'start_date': test_start_date}, + ) - task = DummyOperator(task_id='dummy', - dag=dag, - owner='airflow', - sla=datetime.timedelta(hours=1)) + task = DummyOperator(task_id='dummy', dag=dag, owner='airflow', sla=datetime.timedelta(hours=1)) session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) # Create an SlaMiss where notification was sent, but email was not - session.merge(SlaMiss(task_id='dummy', - dag_id='test_sla_miss', - execution_date=test_start_date)) + session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) # Now call manage_slas and see if the sla_miss callback gets called mock_log = mock.MagicMock() @@ -264,32 +268,28 @@ def test_dag_file_processor_sla_miss_callback_exception(self): dag_file_processor.manage_slas(dag=dag, session=session) assert sla_callback.called mock_log.exception.assert_called_once_with( - 'Could not call sla_miss_callback for DAG %s', - 'test_sla_miss') + 'Could not call sla_miss_callback for DAG %s', 'test_sla_miss' + ) @mock.patch('airflow.jobs.scheduler_job.send_email') def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock_send_email): session = settings.Session() test_start_date = days_ago(2) - dag = DAG(dag_id='test_sla_miss', - default_args={'start_date': test_start_date, - 'sla': datetime.timedelta(days=1)}) + dag = DAG( + dag_id='test_sla_miss', + default_args={'start_date': test_start_date, 'sla': 
datetime.timedelta(days=1)}, + ) email1 = 'test1@test.com' - task = DummyOperator(task_id='sla_missed', - dag=dag, - owner='airflow', - email=email1, - sla=datetime.timedelta(hours=1)) + task = DummyOperator( + task_id='sla_missed', dag=dag, owner='airflow', email=email1, sla=datetime.timedelta(hours=1) + ) session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) email2 = 'test2@test.com' - DummyOperator(task_id='sla_not_missed', - dag=dag, - owner='airflow', - email=email2) + DummyOperator(task_id='sla_not_missed', dag=dag, owner='airflow', email=email2) session.merge(SlaMiss(task_id='sla_missed', dag_id='test_sla_miss', execution_date=test_start_date)) @@ -316,15 +316,14 @@ def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock mock_send_email.side_effect = RuntimeError('Could not send an email') test_start_date = days_ago(2) - dag = DAG(dag_id='test_sla_miss', - default_args={'start_date': test_start_date, - 'sla': datetime.timedelta(days=1)}) + dag = DAG( + dag_id='test_sla_miss', + default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, + ) - task = DummyOperator(task_id='dummy', - dag=dag, - owner='airflow', - email='test@test.com', - sla=datetime.timedelta(hours=1)) + task = DummyOperator( + task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1) + ) session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) @@ -336,8 +335,8 @@ def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock dag_file_processor.manage_slas(dag=dag, session=session) mock_log.exception.assert_called_once_with( - 'Could not send SLA Miss email notification for DAG %s', - 'test_sla_miss') + 'Could not send SLA Miss email notification for DAG %s', 'test_sla_miss' + ) mock_stats_incr.assert_called_once_with('sla_email_notification_failure') def test_dag_file_processor_sla_miss_deleted_task(self): @@ -348,47 +347,48 @@ def test_dag_file_processor_sla_miss_deleted_task(self): session = settings.Session() test_start_date = days_ago(2) - dag = DAG(dag_id='test_sla_miss', - default_args={'start_date': test_start_date, - 'sla': datetime.timedelta(days=1)}) + dag = DAG( + dag_id='test_sla_miss', + default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, + ) - task = DummyOperator(task_id='dummy', - dag=dag, - owner='airflow', - email='test@test.com', - sla=datetime.timedelta(hours=1)) + task = DummyOperator( + task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1) + ) session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) # Create an SlaMiss where notification was sent, but email was not - session.merge(SlaMiss(task_id='dummy_deleted', dag_id='test_sla_miss', - execution_date=test_start_date)) + session.merge( + SlaMiss(task_id='dummy_deleted', dag_id='test_sla_miss', execution_date=test_start_date) + ) mock_log = mock.MagicMock() dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) dag_file_processor.manage_slas(dag=dag, session=session) - @parameterized.expand([ - [State.NONE, None, None], - [State.UP_FOR_RETRY, timezone.utcnow() - datetime.timedelta(minutes=30), - timezone.utcnow() - datetime.timedelta(minutes=15)], - [State.UP_FOR_RESCHEDULE, timezone.utcnow() - datetime.timedelta(minutes=30), - timezone.utcnow() - datetime.timedelta(minutes=15)], - ]) + @parameterized.expand( + [ + [State.NONE, None, None], + [ + State.UP_FOR_RETRY, 
+ timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + [ + State.UP_FOR_RESCHEDULE, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + ] + ) def test_dag_file_processor_process_task_instances(self, state, start_date, end_date): """ Test if _process_task_instances puts the right task instances into the mock_list. """ - dag = DAG( - dag_id='test_scheduler_process_execute_task', - start_date=DEFAULT_DATE) - BashOperator( - task_id='dummy', - dag=dag, - owner='airflow', - bash_command='echo hi' - ) + dag = DAG(dag_id='test_scheduler_process_execute_task', start_date=DEFAULT_DATE) + BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo hi') with create_session() as session: orm_dag = DagModel(dag_id=dag.dag_id) @@ -419,30 +419,33 @@ def test_dag_file_processor_process_task_instances(self, state, start_date, end_ session.refresh(ti) assert ti.state == State.SCHEDULED - @parameterized.expand([ - [State.NONE, None, None], - [State.UP_FOR_RETRY, timezone.utcnow() - datetime.timedelta(minutes=30), - timezone.utcnow() - datetime.timedelta(minutes=15)], - [State.UP_FOR_RESCHEDULE, timezone.utcnow() - datetime.timedelta(minutes=30), - timezone.utcnow() - datetime.timedelta(minutes=15)], - ]) + @parameterized.expand( + [ + [State.NONE, None, None], + [ + State.UP_FOR_RETRY, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + [ + State.UP_FOR_RESCHEDULE, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + ] + ) def test_dag_file_processor_process_task_instances_with_task_concurrency( - self, state, start_date, end_date, + self, + state, + start_date, + end_date, ): """ Test if _process_task_instances puts the right task instances into the mock_list. 
""" - dag = DAG( - dag_id='test_scheduler_process_execute_task_with_task_concurrency', - start_date=DEFAULT_DATE) - BashOperator( - task_id='dummy', - task_concurrency=2, - dag=dag, - owner='airflow', - bash_command='echo Hi' - ) + dag = DAG(dag_id='test_scheduler_process_execute_task_with_task_concurrency', start_date=DEFAULT_DATE) + BashOperator(task_id='dummy', task_concurrency=2, dag=dag, owner='airflow', bash_command='echo Hi') with create_session() as session: orm_dag = DagModel(dag_id=dag.dag_id) @@ -473,13 +476,21 @@ def test_dag_file_processor_process_task_instances_with_task_concurrency( session.refresh(ti) assert ti.state == State.SCHEDULED - @parameterized.expand([ - [State.NONE, None, None], - [State.UP_FOR_RETRY, timezone.utcnow() - datetime.timedelta(minutes=30), - timezone.utcnow() - datetime.timedelta(minutes=15)], - [State.UP_FOR_RESCHEDULE, timezone.utcnow() - datetime.timedelta(minutes=30), - timezone.utcnow() - datetime.timedelta(minutes=15)], - ]) + @parameterized.expand( + [ + [State.NONE, None, None], + [ + State.UP_FOR_RETRY, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + [ + State.UP_FOR_RESCHEDULE, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + ] + ) def test_dag_file_processor_process_task_instances_depends_on_past(self, state, start_date, end_date): """ Test if _process_task_instances puts the right task instances into the @@ -492,18 +503,8 @@ def test_dag_file_processor_process_task_instances_depends_on_past(self, state, 'depends_on_past': True, }, ) - BashOperator( - task_id='dummy1', - dag=dag, - owner='airflow', - bash_command='echo hi' - ) - BashOperator( - task_id='dummy2', - dag=dag, - owner='airflow', - bash_command='echo hi' - ) + BashOperator(task_id='dummy1', dag=dag, owner='airflow', bash_command='echo hi') + BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo hi') with create_session() as session: orm_dag = DagModel(dag_id=dag.dag_id) @@ -586,17 +587,10 @@ def test_runs_respected_after_clear(self): Test if _process_task_instances only schedules ti's up to max_active_runs (related to issue AIRFLOW-137) """ - dag = DAG( - dag_id='test_scheduler_max_active_runs_respected_after_clear', - start_date=DEFAULT_DATE) + dag = DAG(dag_id='test_scheduler_max_active_runs_respected_after_clear', start_date=DEFAULT_DATE) dag.max_active_runs = 3 - BashOperator( - task_id='dummy', - dag=dag, - owner='airflow', - bash_command='echo Hi' - ) + BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo Hi') session = settings.Session() orm_dag = DagModel(dag_id=dag.dag_id) @@ -664,16 +658,12 @@ def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): requests = [ TaskCallbackRequest( - full_filepath="A", - simple_task_instance=SimpleTaskInstance(ti), - msg="Message" + full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message" ) ] dag_file_processor.execute_callbacks(dagbag, requests) mock_ti_handle_failure.assert_called_once_with( - "Message", - conf.getboolean('core', 'unit_test_mode'), - mock.ANY + "Message", conf.getboolean('core', 'unit_test_mode'), mock.ANY ) def test_process_file_should_failure_callback(self): @@ -696,7 +686,7 @@ def test_process_file_should_failure_callback(self): TaskCallbackRequest( full_filepath=dag.full_filepath, simple_task_instance=SimpleTaskInstance(ti), - msg="Message" + msg="Message", ) ] callback_file.close() @@ -738,15 +728,19 @@ 
def test_should_mark_dummy_task_as_success(self): dags = scheduler_job.dagbag.dags.values() self.assertEqual(['test_only_dummy_tasks'], [dag.dag_id for dag in dags]) self.assertEqual(5, len(tis)) - self.assertEqual({ - ('test_task_a', 'success'), - ('test_task_b', None), - ('test_task_c', 'success'), - ('test_task_on_execute', 'scheduled'), - ('test_task_on_success', 'scheduled'), - }, {(ti.task_id, ti.state) for ti in tis}) - for state, start_date, end_date, duration in [(ti.state, ti.start_date, ti.end_date, ti.duration) for - ti in tis]: + self.assertEqual( + { + ('test_task_a', 'success'), + ('test_task_b', None), + ('test_task_c', 'success'), + ('test_task_on_execute', 'scheduled'), + ('test_task_on_success', 'scheduled'), + }, + {(ti.task_id, ti.state) for ti in tis}, + ) + for state, start_date, end_date, duration in [ + (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis + ]: if state == 'success': self.assertIsNotNone(start_date) self.assertIsNotNone(end_date) @@ -761,15 +755,19 @@ def test_should_mark_dummy_task_as_success(self): tis = session.query(TaskInstance).all() self.assertEqual(5, len(tis)) - self.assertEqual({ - ('test_task_a', 'success'), - ('test_task_b', 'success'), - ('test_task_c', 'success'), - ('test_task_on_execute', 'scheduled'), - ('test_task_on_success', 'scheduled'), - }, {(ti.task_id, ti.state) for ti in tis}) - for state, start_date, end_date, duration in [(ti.state, ti.start_date, ti.end_date, ti.duration) for - ti in tis]: + self.assertEqual( + { + ('test_task_a', 'success'), + ('test_task_b', 'success'), + ('test_task_c', 'success'), + ('test_task_on_execute', 'scheduled'), + ('test_task_on_success', 'scheduled'), + }, + {(ti.task_id, ti.state) for ti in tis}, + ) + for state, start_date, end_date, duration in [ + (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis + ]: if state == 'success': self.assertIsNotNone(start_date) self.assertIsNotNone(end_date) @@ -782,7 +780,6 @@ def test_should_mark_dummy_task_as_success(self): @pytest.mark.usefixtures("disable_load_example") class TestSchedulerJob(unittest.TestCase): - def setUp(self): clear_db_runs() clear_db_pools() @@ -841,9 +838,8 @@ def run_single_scheduler_loop_with_no_dags(self, dags_folder): :type dags_folder: str """ scheduler = SchedulerJob( - executor=self.null_exec, - num_times_parse_dags=1, - subdir=os.path.join(dags_folder)) + executor=self.null_exec, num_times_parse_dags=1, subdir=os.path.join(dags_folder) + ) scheduler.heartrate = 0 scheduler.run() @@ -851,15 +847,12 @@ def test_no_orphan_process_will_be_left(self): empty_dir = mkdtemp() current_process = psutil.Process() old_children = current_process.children(recursive=True) - scheduler = SchedulerJob(subdir=empty_dir, - num_runs=1, - executor=MockExecutor(do_update=False)) + scheduler = SchedulerJob(subdir=empty_dir, num_runs=1, executor=MockExecutor(do_update=False)) scheduler.run() shutil.rmtree(empty_dir) # Remove potential noise created by previous tests. - current_children = set(current_process.children(recursive=True)) - set( - old_children) + current_children = set(current_process.children(recursive=True)) - set(old_children) self.assertFalse(current_children) @mock.patch('airflow.jobs.scheduler_job.TaskCallbackRequest') @@ -900,9 +893,9 @@ def test_process_executor_events(self, mock_stats_incr, mock_task_callback): full_filepath='/test_path1/', simple_task_instance=mock.ANY, msg='Executor reports task instance ' - ' ' - 'finished (failed) although the task says its queued. 
(Info: None) ' - 'Was the task killed externally?' + ' ' + 'finished (failed) although the task says its queued. (Info: None) ' + 'Was the task killed externally?', ) scheduler.processor_agent.send_callback_to_execute.assert_called_once_with(task_callback) scheduler.processor_agent.reset_mock() @@ -929,9 +922,7 @@ def test_process_executor_events_uses_inmemory_try_number(self): executor = MagicMock() scheduler = SchedulerJob(executor=executor) scheduler.processor_agent = MagicMock() - event_buffer = { - TaskInstanceKey(dag_id, task_id, execution_date, try_number): (State.SUCCESS, None) - } + event_buffer = {TaskInstanceKey(dag_id, task_id, execution_date, try_number): (State.SUCCESS, None)} executor.get_event_buffer.return_value = event_buffer dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) @@ -1141,12 +1132,12 @@ def test_find_executable_task_instances_pool(self): state=State.RUNNING, ) - tis = ([ + tis = [ TaskInstance(task1, dr1.execution_date), TaskInstance(task2, dr1.execution_date), TaskInstance(task1, dr2.execution_date), - TaskInstance(task2, dr2.execution_date) - ]) + TaskInstance(task2, dr2.execution_date), + ] for ti in tis: ti.state = State.SCHEDULED session.merge(ti) @@ -1278,9 +1269,7 @@ def test_find_executable_task_instances_none(self): ) session.flush() - self.assertEqual(0, len(scheduler._executable_task_instances_to_queued( - max_tis=32, - session=session))) + self.assertEqual(0, len(scheduler._executable_task_instances_to_queued(max_tis=32, session=session))) session.rollback() def test_find_executable_task_instances_concurrency(self): @@ -1600,10 +1589,7 @@ def test_critical_section_execute_task_instances(self): self.assertEqual(State.RUNNING, dr1.state) self.assertEqual( - 2, - DAG.get_num_task_instances( - dag_id, dag.task_ids, states=[State.RUNNING], session=session - ) + 2, DAG.get_num_task_instances(dag_id, dag.task_ids, states=[State.RUNNING], session=session) ) # create second dag run @@ -1636,7 +1622,7 @@ def test_critical_section_execute_task_instances(self): 3, DAG.get_num_task_instances( dag_id, dag.task_ids, states=[State.RUNNING, State.QUEUED], session=session - ) + ), ) self.assertEqual(State.RUNNING, ti1.state) self.assertEqual(State.RUNNING, ti2.state) @@ -1691,9 +1677,9 @@ def test_execute_task_instances_limit(self): self.assertEqual(2, res) scheduler.max_tis_per_query = 8 - with mock.patch.object(type(scheduler.executor), - 'slots_available', - new_callable=mock.PropertyMock) as mock_slots: + with mock.patch.object( + type(scheduler.executor), 'slots_available', new_callable=mock.PropertyMock + ) as mock_slots: mock_slots.return_value = 2 # Check that we don't "overfill" the executor self.assertEqual(2, res) @@ -1722,17 +1708,21 @@ def test_change_state_for_tis_without_dagrun(self): DummyOperator(task_id='dummy', dag=dag3, owner='airflow') session = settings.Session() - dr1 = dag1.create_dagrun(run_type=DagRunType.SCHEDULED, - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) - - dr2 = dag2.create_dagrun(run_type=DagRunType.SCHEDULED, - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) + dr1 = dag1.create_dagrun( + run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) + + dr2 = dag2.create_dagrun( + run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) ti1a = 
dr1.get_task_instance(task_id='dummy', session=session) ti1a.state = State.SCHEDULED @@ -1760,9 +1750,8 @@ def test_change_state_for_tis_without_dagrun(self): scheduler.dagbag.collect_dags_from_db() scheduler._change_state_for_tis_without_dagrun( - old_states=[State.SCHEDULED, State.QUEUED], - new_state=State.NONE, - session=session) + old_states=[State.SCHEDULED, State.QUEUED], new_state=State.NONE, session=session + ) ti1a = dr1.get_task_instance(task_id='dummy', session=session) ti1a.refresh_from_db(session=session) @@ -1790,9 +1779,8 @@ def test_change_state_for_tis_without_dagrun(self): session.commit() scheduler._change_state_for_tis_without_dagrun( - old_states=[State.SCHEDULED, State.QUEUED], - new_state=State.NONE, - session=session) + old_states=[State.SCHEDULED, State.QUEUED], new_state=State.NONE, session=session + ) ti1a.refresh_from_db(session=session) self.assertEqual(ti1a.state, State.SCHEDULED) @@ -1805,14 +1793,9 @@ def test_change_state_for_tis_without_dagrun(self): self.assertEqual(ti2.state, State.SCHEDULED) def test_change_state_for_tasks_failed_to_execute(self): - dag = DAG( - dag_id='dag_id', - start_date=DEFAULT_DATE) + dag = DAG(dag_id='dag_id', start_date=DEFAULT_DATE) - task = DummyOperator( - task_id='task_id', - dag=dag, - owner='airflow') + task = DummyOperator(task_id='task_id', dag=dag, owner='airflow') dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) # If there's no left over task in executor.queued_tasks, nothing happens @@ -1858,23 +1841,28 @@ def test_adopt_or_reset_orphaned_tasks(self): dag = DAG( 'test_execute_helper_reset_orphaned_tasks', start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + default_args={'owner': 'owner1'}, + ) with dag: op1 = DummyOperator(task_id='op1') dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) dag.clear() - dr = dag.create_dagrun(run_type=DagRunType.SCHEDULED, - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) - dr2 = dag.create_dagrun(run_type=DagRunType.BACKFILL_JOB, - state=State.RUNNING, - execution_date=DEFAULT_DATE + datetime.timedelta(1), - start_date=DEFAULT_DATE, - session=session) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) + dr2 = dag.create_dagrun( + run_type=DagRunType.BACKFILL_JOB, + state=State.RUNNING, + execution_date=DEFAULT_DATE + datetime.timedelta(1), + start_date=DEFAULT_DATE, + session=session, + ) ti = dr.get_task_instance(task_id=op1.task_id, session=session) ti.state = State.SCHEDULED ti2 = dr2.get_task_instance(task_id=op1.task_id, session=session) @@ -1894,15 +1882,17 @@ def test_adopt_or_reset_orphaned_tasks(self): ti2 = dr2.get_task_instance(task_id=op1.task_id, session=session) self.assertEqual(ti2.state, State.SCHEDULED, "Tasks run by Backfill Jobs should not be reset") - @parameterized.expand([ - [State.UP_FOR_RETRY, State.FAILED], - [State.QUEUED, State.NONE], - [State.SCHEDULED, State.NONE], - [State.UP_FOR_RESCHEDULE, State.NONE], - ]) - def test_scheduler_loop_should_change_state_for_tis_without_dagrun(self, - initial_task_state, - expected_task_state): + @parameterized.expand( + [ + [State.UP_FOR_RETRY, State.FAILED], + [State.QUEUED, State.NONE], + [State.SCHEDULED, State.NONE], + [State.UP_FOR_RESCHEDULE, State.NONE], + ] + ) + def test_scheduler_loop_should_change_state_for_tis_without_dagrun( + self, initial_task_state, expected_task_state + ): session = settings.Session() dag_id = 
'test_execute_helper_should_change_state_for_tis_without_dagrun' dag = DAG(dag_id, start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) @@ -1918,11 +1908,13 @@ def test_scheduler_loop_should_change_state_for_tis_without_dagrun(self, dag = DagBag(read_dags_from_db=True, include_examples=False).get_dag(dag_id) # Create DAG run with FAILED state dag.clear() - dr = dag.create_dagrun(run_type=DagRunType.SCHEDULED, - state=State.FAILED, - execution_date=DEFAULT_DATE + timedelta(days=1), - start_date=DEFAULT_DATE + timedelta(days=1), - session=session) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + state=State.FAILED, + execution_date=DEFAULT_DATE + timedelta(days=1), + start_date=DEFAULT_DATE + timedelta(days=1), + session=session, + ) ti = dr.get_task_instance(task_id=op1.task_id, session=session) ti.state = initial_task_state session.commit() @@ -1956,16 +1948,11 @@ def test_dagrun_timeout_verify_max_active_runs(self): Test if a a dagrun would be scheduled if max_dag_runs has been reached but dagrun_timeout is also reached """ - dag = DAG( - dag_id='test_scheduler_verify_max_active_runs_and_dagrun_timeout', - start_date=DEFAULT_DATE) + dag = DAG(dag_id='test_scheduler_verify_max_active_runs_and_dagrun_timeout', start_date=DEFAULT_DATE) dag.max_active_runs = 1 dag.dagrun_timeout = datetime.timedelta(seconds=60) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') + DummyOperator(task_id='dummy', dag=dag, owner='airflow') scheduler = SchedulerJob() scheduler.dagbag.bag_dag(dag, root_dag=dag) @@ -2012,7 +1999,7 @@ def test_dagrun_timeout_verify_max_active_runs(self): dag_id=dr.dag_id, is_failure_callback=True, execution_date=dr.execution_date, - msg="timed_out" + msg="timed_out", ) # Verify dag failure callback request is sent to file processor @@ -2025,15 +2012,10 @@ def test_dagrun_timeout_fails_run(self): """ Test if a a dagrun will be set failed if timeout, even without max_active_runs """ - dag = DAG( - dag_id='test_scheduler_fail_dagrun_timeout', - start_date=DEFAULT_DATE) + dag = DAG(dag_id='test_scheduler_fail_dagrun_timeout', start_date=DEFAULT_DATE) dag.dagrun_timeout = datetime.timedelta(seconds=60) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') + DummyOperator(task_id='dummy', dag=dag, owner='airflow') scheduler = SchedulerJob() scheduler.dagbag.bag_dag(dag, root_dag=dag) @@ -2071,7 +2053,7 @@ def test_dagrun_timeout_fails_run(self): dag_id=dr.dag_id, is_failure_callback=True, execution_date=dr.execution_date, - msg="timed_out" + msg="timed_out", ) # Verify dag failure callback request is sent to file processor @@ -2080,10 +2062,7 @@ def test_dagrun_timeout_fails_run(self): session.rollback() session.close() - @parameterized.expand([ - (State.SUCCESS, "success"), - (State.FAILED, "task_failure") - ]) + @parameterized.expand([(State.SUCCESS, "success"), (State.FAILED, "task_failure")]) def test_dagrun_callbacks_are_called(self, state, expected_callback_msg): """ Test if DagRun is successful, and if Success callbacks is defined, it is sent to DagFileProcessor. 
@@ -2093,7 +2072,7 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg): dag_id='test_dagrun_callbacks_are_called', start_date=DEFAULT_DATE, on_success_callback=lambda x: print("success"), - on_failure_callback=lambda x: print("failed") + on_failure_callback=lambda x: print("failed"), ) DummyOperator(task_id='dummy', dag=dag, owner='airflow') @@ -2129,7 +2108,7 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg): dag_id=dr.dag_id, is_failure_callback=bool(state == State.FAILED), execution_date=dr.execution_date, - msg=expected_callback_msg + msg=expected_callback_msg, ) # Verify dag failure callback request is sent to file processor @@ -2142,13 +2121,8 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg): session.close() def test_do_not_schedule_removed_task(self): - dag = DAG( - dag_id='test_scheduler_do_not_schedule_removed_task', - start_date=DEFAULT_DATE) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') + dag = DAG(dag_id='test_scheduler_do_not_schedule_removed_task', start_date=DEFAULT_DATE) + DummyOperator(task_id='dummy', dag=dag, owner='airflow') session = settings.Session() dag.sync_to_db(session=session) @@ -2165,9 +2139,7 @@ def test_do_not_schedule_removed_task(self): self.assertIsNotNone(dr) # Re-create the DAG, but remove the task - dag = DAG( - dag_id='test_scheduler_do_not_schedule_removed_task', - start_date=DEFAULT_DATE) + dag = DAG(dag_id='test_scheduler_do_not_schedule_removed_task', start_date=DEFAULT_DATE) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) scheduler = SchedulerJob() @@ -2179,13 +2151,14 @@ def test_do_not_schedule_removed_task(self): @provide_session def evaluate_dagrun( - self, - dag_id, - expected_task_states, # dict of task_id: state - dagrun_state, - run_kwargs=None, - advance_execution_date=False, - session=None): # pylint: disable=unused-argument + self, + dag_id, + expected_task_states, # dict of task_id: state + dagrun_state, + run_kwargs=None, + advance_execution_date=False, + session=None, + ): # pylint: disable=unused-argument """ Helper for testing DagRun states with simple two-task DAGS. 
@@ -2248,7 +2221,8 @@ def test_dagrun_fail(self): 'test_dagrun_fail': State.FAILED, 'test_dagrun_succeed': State.UPSTREAM_FAILED, }, - dagrun_state=State.FAILED) + dagrun_state=State.FAILED, + ) def test_dagrun_success(self): """ @@ -2260,7 +2234,8 @@ def test_dagrun_success(self): 'test_dagrun_fail': State.FAILED, 'test_dagrun_succeed': State.SUCCESS, }, - dagrun_state=State.SUCCESS) + dagrun_state=State.SUCCESS, + ) def test_dagrun_root_fail(self): """ @@ -2272,7 +2247,8 @@ def test_dagrun_root_fail(self): 'test_dagrun_succeed': State.SUCCESS, 'test_dagrun_fail': State.FAILED, }, - dagrun_state=State.FAILED) + dagrun_state=State.FAILED, + ) def test_dagrun_root_fail_unfinished(self): """ @@ -2311,10 +2287,7 @@ def test_dagrun_root_after_dagrun_unfinished(self): dag_id = 'test_dagrun_states_root_future' dag = self.dagbag.get_dag(dag_id) dag.sync_to_db() - scheduler = SchedulerJob( - num_runs=1, - executor=self.null_exec, - subdir=dag.fileloc) + scheduler = SchedulerJob(num_runs=1, executor=self.null_exec, subdir=dag.fileloc) scheduler.run() first_run = DagRun.find(dag_id=dag_id, execution_date=DEFAULT_DATE)[0] @@ -2339,7 +2312,8 @@ def test_dagrun_deadlock_ignore_depends_on_past_advance_ex_date(self): }, dagrun_state=State.SUCCESS, advance_execution_date=True, - run_kwargs=dict(ignore_first_depends_on_past=True)) + run_kwargs=dict(ignore_first_depends_on_past=True), + ) def test_dagrun_deadlock_ignore_depends_on_past(self): """ @@ -2355,7 +2329,8 @@ def test_dagrun_deadlock_ignore_depends_on_past(self): 'test_depends_on_past_2': State.SUCCESS, }, dagrun_state=State.SUCCESS, - run_kwargs=dict(ignore_first_depends_on_past=True)) + run_kwargs=dict(ignore_first_depends_on_past=True), + ) def test_scheduler_start_date(self): """ @@ -2372,14 +2347,11 @@ def test_scheduler_start_date(self): other_dag.is_paused_upon_creation = True other_dag.sync_to_db() - scheduler = SchedulerJob(executor=self.null_exec, - subdir=dag.fileloc, - num_runs=1) + scheduler = SchedulerJob(executor=self.null_exec, subdir=dag.fileloc, num_runs=1) scheduler.run() # zero tasks ran - self.assertEqual( - len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 0) + self.assertEqual(len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 0) session.commit() self.assertListEqual([], self.null_exec.sorted_tasks) @@ -2388,32 +2360,24 @@ def test_scheduler_start_date(self): # That behavior still exists, but now it will only do so if after the # start date bf_exec = MockExecutor() - backfill = BackfillJob( - executor=bf_exec, - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE) + backfill = BackfillJob(executor=bf_exec, dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) backfill.run() # one task ran - self.assertEqual( - len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 1) + self.assertEqual(len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 1) self.assertListEqual( [ (TaskInstanceKey(dag.dag_id, 'dummy', DEFAULT_DATE, 1), (State.SUCCESS, None)), ], - bf_exec.sorted_tasks + bf_exec.sorted_tasks, ) session.commit() - scheduler = SchedulerJob(dag.fileloc, - executor=self.null_exec, - num_runs=1) + scheduler = SchedulerJob(dag.fileloc, executor=self.null_exec, num_runs=1) scheduler.run() # still one task - self.assertEqual( - len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 1) + self.assertEqual(len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 1) session.commit() 
self.assertListEqual([], self.null_exec.sorted_tasks) @@ -2435,9 +2399,7 @@ def test_scheduler_task_start_date(self): dagbag.sync_to_db() - scheduler = SchedulerJob(executor=self.null_exec, - subdir=dag.fileloc, - num_runs=2) + scheduler = SchedulerJob(executor=self.null_exec, subdir=dag.fileloc, num_runs=2) scheduler.run() session = settings.Session() @@ -2458,16 +2420,17 @@ def test_scheduler_multiprocessing(self): dag = self.dagbag.get_dag(dag_id) dag.clear() - scheduler = SchedulerJob(executor=self.null_exec, - subdir=os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py'), - num_runs=1) + scheduler = SchedulerJob( + executor=self.null_exec, + subdir=os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py'), + num_runs=1, + ) scheduler.run() # zero tasks ran dag_id = 'test_start_date_scheduling' session = settings.Session() - self.assertEqual( - len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 0) + self.assertEqual(len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 0) @conf_vars({("core", "mp_start_method"): "spawn"}) def test_scheduler_multiprocessing_with_spawn_method(self): @@ -2480,26 +2443,24 @@ def test_scheduler_multiprocessing_with_spawn_method(self): dag = self.dagbag.get_dag(dag_id) dag.clear() - scheduler = SchedulerJob(executor=self.null_exec, - subdir=os.path.join( - TEST_DAG_FOLDER, 'test_scheduler_dags.py'), - num_runs=1) + scheduler = SchedulerJob( + executor=self.null_exec, + subdir=os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py'), + num_runs=1, + ) scheduler.run() # zero tasks ran dag_id = 'test_start_date_scheduling' with create_session() as session: - self.assertEqual( - session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).count(), 0) + self.assertEqual(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).count(), 0) def test_scheduler_verify_pool_full(self): """ Test task instances not queued when pool is full """ - dag = DAG( - dag_id='test_scheduler_verify_pool_full', - start_date=DEFAULT_DATE) + dag = DAG(dag_id='test_scheduler_verify_pool_full', start_date=DEFAULT_DATE) BashOperator( task_id='dummy', @@ -2509,9 +2470,11 @@ def test_scheduler_verify_pool_full(self): bash_command='echo hi', ) - dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), - include_examples=False, - read_dags_from_db=True) + dagbag = DagBag( + dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), + include_examples=False, + read_dags_from_db=True, + ) dagbag.bag_dag(dag=dag, root_dag=dag) dagbag.sync_to_db() @@ -2549,9 +2512,7 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self): Variation with non-default pool_slots """ - dag = DAG( - dag_id='test_scheduler_verify_pool_full_2_slots_per_task', - start_date=DEFAULT_DATE) + dag = DAG(dag_id='test_scheduler_verify_pool_full_2_slots_per_task', start_date=DEFAULT_DATE) BashOperator( task_id='dummy', @@ -2562,9 +2523,11 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self): bash_command='echo hi', ) - dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), - include_examples=False, - read_dags_from_db=True) + dagbag = DagBag( + dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), + include_examples=False, + read_dags_from_db=True, + ) dagbag.bag_dag(dag=dag, root_dag=dag) dagbag.sync_to_db() @@ -2601,9 +2564,7 @@ def test_scheduler_verify_priority_and_slots(self): Though tasks with lower priority might be executed. 
""" - dag = DAG( - dag_id='test_scheduler_verify_priority_and_slots', - start_date=DEFAULT_DATE) + dag = DAG(dag_id='test_scheduler_verify_priority_and_slots', start_date=DEFAULT_DATE) # Medium priority, not enough slots BashOperator( @@ -2636,9 +2597,11 @@ def test_scheduler_verify_priority_and_slots(self): bash_command='echo hi', ) - dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), - include_examples=False, - read_dags_from_db=True) + dagbag = DagBag( + dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), + include_examples=False, + read_dags_from_db=True, + ) dagbag.bag_dag(dag=dag, root_dag=dag) dagbag.sync_to_db() @@ -2664,16 +2627,25 @@ def test_scheduler_verify_priority_and_slots(self): # Only second and third self.assertEqual(len(task_instances_list), 2) - ti0 = session.query(TaskInstance)\ - .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t0').first() + ti0 = ( + session.query(TaskInstance) + .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t0') + .first() + ) self.assertEqual(ti0.state, State.SCHEDULED) - ti1 = session.query(TaskInstance)\ - .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t1').first() + ti1 = ( + session.query(TaskInstance) + .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t1') + .first() + ) self.assertEqual(ti1.state, State.QUEUED) - ti2 = session.query(TaskInstance)\ - .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t2').first() + ti2 = ( + session.query(TaskInstance) + .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t2') + .first() + ) self.assertEqual(ti2.state, State.QUEUED) def test_verify_integrity_if_dag_not_changed(self): @@ -2711,12 +2683,16 @@ def test_verify_integrity_if_dag_not_changed(self): assert scheduled_tis == 1 - tis_count = session.query(func.count(TaskInstance.task_id)).filter( - TaskInstance.dag_id == dr.dag_id, - TaskInstance.execution_date == dr.execution_date, - TaskInstance.task_id == dr.dag.tasks[0].task_id, - TaskInstance.state == State.SCHEDULED - ).scalar() + tis_count = ( + session.query(func.count(TaskInstance.task_id)) + .filter( + TaskInstance.dag_id == dr.dag_id, + TaskInstance.execution_date == dr.execution_date, + TaskInstance.task_id == dr.dag.tasks[0].task_id, + TaskInstance.state == State.SCHEDULED, + ) + .scalar() + ) assert tis_count == 1 latest_dag_version = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session) @@ -2776,11 +2752,15 @@ def test_verify_integrity_if_dag_changed(self): assert scheduler.dagbag.dags == {'test_verify_integrity_if_dag_changed': dag} assert len(scheduler.dagbag.dags.get("test_verify_integrity_if_dag_changed").tasks) == 2 - tis_count = session.query(func.count(TaskInstance.task_id)).filter( - TaskInstance.dag_id == dr.dag_id, - TaskInstance.execution_date == dr.execution_date, - TaskInstance.state == State.SCHEDULED - ).scalar() + tis_count = ( + session.query(func.count(TaskInstance.task_id)) + .filter( + TaskInstance.dag_id == dr.dag_id, + TaskInstance.execution_date == dr.execution_date, + TaskInstance.state == State.SCHEDULED, + ) + .scalar() + ) assert tis_count == 2 latest_dag_version = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session) @@ -2798,16 +2778,10 @@ def test_retry_still_in_executor(self): dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), include_examples=False) dagbag.dags.clear() - dag = DAG( - 
dag_id='test_retry_still_in_executor', - start_date=DEFAULT_DATE, - schedule_interval="@once") + dag = DAG(dag_id='test_retry_still_in_executor', start_date=DEFAULT_DATE, schedule_interval="@once") dag_task1 = BashOperator( - task_id='test_retry_handling_op', - bash_command='exit 1', - retries=1, - dag=dag, - owner='airflow') + task_id='test_retry_handling_op', bash_command='exit 1', retries=1, dag=dag, owner='airflow' + ) dag.clear() dag.is_subdag = False @@ -2825,17 +2799,22 @@ def do_schedule(mock_dagbag): # Use a empty file since the above mock will return the # expected DAGs. Also specify only a single file so that it doesn't # try to schedule the above DAG repeatedly. - scheduler = SchedulerJob(num_runs=1, - executor=executor, - subdir=os.path.join(settings.DAGS_FOLDER, - "no_dags.py")) + scheduler = SchedulerJob( + num_runs=1, executor=executor, subdir=os.path.join(settings.DAGS_FOLDER, "no_dags.py") + ) scheduler.heartrate = 0 scheduler.run() do_schedule() # pylint: disable=no-value-for-parameter with create_session() as session: - ti = session.query(TaskInstance).filter(TaskInstance.dag_id == 'test_retry_still_in_executor', - TaskInstance.task_id == 'test_retry_handling_op').first() + ti = ( + session.query(TaskInstance) + .filter( + TaskInstance.dag_id == 'test_retry_still_in_executor', + TaskInstance.task_id == 'test_retry_handling_op', + ) + .first() + ) ti.task = dag_task1 def run_with_error(ti, ignore_ti_state=False): @@ -2875,14 +2854,16 @@ def test_retry_handling_job(self): dag_task1 = dag.get_task("test_retry_handling_op") dag.clear() - scheduler = SchedulerJob(dag_id=dag.dag_id, - num_runs=1) + scheduler = SchedulerJob(dag_id=dag.dag_id, num_runs=1) scheduler.heartrate = 0 scheduler.run() session = settings.Session() - ti = session.query(TaskInstance).filter(TaskInstance.dag_id == dag.dag_id, - TaskInstance.task_id == dag_task1.task_id).first() + ti = ( + session.query(TaskInstance) + .filter(TaskInstance.dag_id == dag.dag_id, TaskInstance.task_id == dag_task1.task_id) + .first() + ) # make sure the counter has increased self.assertEqual(ti.try_number, 2) @@ -2894,23 +2875,15 @@ def test_dag_get_active_runs(self): """ now = timezone.utcnow() - six_hours_ago_to_the_hour = \ - (now - datetime.timedelta(hours=6)).replace(minute=0, second=0, microsecond=0) + six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace( + minute=0, second=0, microsecond=0 + ) start_date = six_hours_ago_to_the_hour dag_name1 = 'get_active_runs_test' - default_args = { - 'owner': 'airflow', - 'depends_on_past': False, - 'start_date': start_date - - } - dag1 = DAG(dag_name1, - schedule_interval='* * * * *', - max_active_runs=1, - default_args=default_args - ) + default_args = {'owner': 'airflow', 'depends_on_past': False, 'start_date': start_date} + dag1 = DAG(dag_name1, schedule_interval='* * * * *', max_active_runs=1, default_args=default_args) run_this_1 = DummyOperator(task_id='run_this_1', dag=dag1) run_this_2 = DummyOperator(task_id='run_this_2', dag=dag1) @@ -2963,10 +2936,8 @@ def test_add_unparseable_file_before_sched_start_creates_import_error(self): self.assertEqual(len(import_errors), 1) import_error = import_errors[0] - self.assertEqual(import_error.filename, - unparseable_filename) - self.assertEqual(import_error.stacktrace, - f"invalid syntax ({TEMP_DAG_FILENAME}, line 1)") + self.assertEqual(import_error.filename, unparseable_filename) + self.assertEqual(import_error.stacktrace, f"invalid syntax ({TEMP_DAG_FILENAME}, line 1)") @conf_vars({("core", 
"dagbag_import_error_tracebacks"): "False"}) def test_add_unparseable_file_after_sched_start_creates_import_error(self): @@ -2993,10 +2964,8 @@ def test_add_unparseable_file_after_sched_start_creates_import_error(self): self.assertEqual(len(import_errors), 1) import_error = import_errors[0] - self.assertEqual(import_error.filename, - unparseable_filename) - self.assertEqual(import_error.stacktrace, - f"invalid syntax ({TEMP_DAG_FILENAME}, line 1)") + self.assertEqual(import_error.filename, unparseable_filename) + self.assertEqual(import_error.stacktrace, f"invalid syntax ({TEMP_DAG_FILENAME}, line 1)") def test_no_import_errors_with_parseable_dag(self): try: @@ -3028,9 +2997,8 @@ def test_new_import_error_replaces_old(self): # Generate replacement import error (the error will be on the second line now) with open(unparseable_filename, 'w') as unparseable_file: unparseable_file.writelines( - PARSEABLE_DAG_FILE_CONTENTS + - os.linesep + - UNPARSEABLE_DAG_FILE_CONTENTS) + PARSEABLE_DAG_FILE_CONTENTS + os.linesep + UNPARSEABLE_DAG_FILE_CONTENTS + ) self.run_single_scheduler_loop_with_no_dags(dags_folder) finally: shutil.rmtree(dags_folder) @@ -3040,10 +3008,8 @@ def test_new_import_error_replaces_old(self): self.assertEqual(len(import_errors), 1) import_error = import_errors[0] - self.assertEqual(import_error.filename, - unparseable_filename) - self.assertEqual(import_error.stacktrace, - f"invalid syntax ({TEMP_DAG_FILENAME}, line 2)") + self.assertEqual(import_error.filename, unparseable_filename) + self.assertEqual(import_error.stacktrace, f"invalid syntax ({TEMP_DAG_FILENAME}, line 2)") def test_remove_error_clears_import_error(self): try: @@ -3057,8 +3023,7 @@ def test_remove_error_clears_import_error(self): # Remove the import error from the file with open(filename_to_parse, 'w') as file_to_parse: - file_to_parse.writelines( - PARSEABLE_DAG_FILE_CONTENTS) + file_to_parse.writelines(PARSEABLE_DAG_FILE_CONTENTS) self.run_single_scheduler_loop_with_no_dags(dags_folder) finally: shutil.rmtree(dags_folder) @@ -3113,8 +3078,7 @@ def test_import_error_tracebacks(self): "NameError: name 'airflow_DAG' is not defined\n" ) self.assertEqual( - import_error.stacktrace, - expected_stacktrace.format(unparseable_filename, unparseable_filename) + import_error.stacktrace, expected_stacktrace.format(unparseable_filename, unparseable_filename) ) @conf_vars({("core", "dagbag_import_error_traceback_depth"): "1"}) @@ -3140,9 +3104,7 @@ def test_import_error_traceback_depth(self): " return airflow_DAG\n" "NameError: name 'airflow_DAG' is not defined\n" ) - self.assertEqual( - import_error.stacktrace, expected_stacktrace.format(unparseable_filename) - ) + self.assertEqual(import_error.stacktrace, expected_stacktrace.format(unparseable_filename)) def test_import_error_tracebacks_zip(self): dags_folder = mkdtemp() @@ -3170,8 +3132,7 @@ def test_import_error_tracebacks_zip(self): "NameError: name 'airflow_DAG' is not defined\n" ) self.assertEqual( - import_error.stacktrace, - expected_stacktrace.format(invalid_dag_filename, invalid_dag_filename) + import_error.stacktrace, expected_stacktrace.format(invalid_dag_filename, invalid_dag_filename) ) @conf_vars({("core", "dagbag_import_error_traceback_depth"): "1"}) @@ -3198,9 +3159,7 @@ def test_import_error_tracebacks_zip_depth(self): " return airflow_DAG\n" "NameError: name 'airflow_DAG' is not defined\n" ) - self.assertEqual( - import_error.stacktrace, expected_stacktrace.format(invalid_dag_filename) - ) + self.assertEqual(import_error.stacktrace, 
expected_stacktrace.format(invalid_dag_filename)) def test_list_py_file_paths(self): """ @@ -3220,8 +3179,7 @@ def test_list_py_file_paths(self): for file_name in files: if file_name.endswith('.py') or file_name.endswith('.zip'): if file_name not in ignored_files: - expected_files.add( - f'{root}/{file_name}') + expected_files.add(f'{root}/{file_name}') for file_path in list_py_file_paths(TEST_DAG_FOLDER, include_examples=False): detected_files.add(file_path) self.assertEqual(detected_files, expected_files) @@ -3243,13 +3201,14 @@ def test_list_py_file_paths(self): smart_sensor_dag_folder = airflow.smart_sensor_dags.__path__[0] for root, _, files in os.walk(smart_sensor_dag_folder): for file_name in files: - if (file_name.endswith('.py') or file_name.endswith('.zip')) and \ - file_name not in ['__init__.py']: + if (file_name.endswith('.py') or file_name.endswith('.zip')) and file_name not in [ + '__init__.py' + ]: expected_files.add(os.path.join(root, file_name)) detected_files.clear() - for file_path in list_py_file_paths(TEST_DAG_FOLDER, - include_examples=True, - include_smart_sensor=True): + for file_path in list_py_file_paths( + TEST_DAG_FOLDER, include_examples=True, include_smart_sensor=True + ): detected_files.add(file_path) self.assertEqual(detected_files, expected_files) @@ -3268,12 +3227,14 @@ def test_adopt_or_reset_orphaned_tasks_external_triggered_dag(self): scheduler = SchedulerJob() session = settings.Session() - dr1 = dag.create_dagrun(run_type=DagRunType.SCHEDULED, - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - external_trigger=True, - session=session) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + external_trigger=True, + session=session, + ) ti = dr1.get_task_instances(session=session)[0] ti.state = State.SCHEDULED session.merge(ti) @@ -3294,11 +3255,13 @@ def test_adopt_or_reset_orphaned_tasks_backfill_dag(self): session.add(scheduler) session.flush() - dr1 = dag.create_dagrun(run_type=DagRunType.BACKFILL_JOB, - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) + dr1 = dag.create_dagrun( + run_type=DagRunType.BACKFILL_JOB, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) ti = dr1.get_task_instances(session=session)[0] ti.state = State.SCHEDULED session.merge(ti) @@ -3342,11 +3305,13 @@ def test_reset_orphaned_tasks_no_orphans(self): session.add(scheduler) session.flush() - dr1 = dag.create_dagrun(run_type=DagRunType.SCHEDULED, - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) tis = dr1.get_task_instances(session=session) tis[0].state = State.RUNNING tis[0].queued_by_job_id = scheduler.id @@ -3370,11 +3335,13 @@ def test_reset_orphaned_tasks_non_running_dagruns(self): session.add(scheduler) session.flush() - dr1 = dag.create_dagrun(run_type=DagRunType.SCHEDULED, - state=State.SUCCESS, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + state=State.SUCCESS, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) tis = dr1.get_task_instances(session=session) self.assertEqual(1, len(tis)) tis[0].state = State.SCHEDULED @@ 
-3409,7 +3376,7 @@ def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self): execution_date=DEFAULT_DATE, start_date=timezone.utcnow(), state=State.RUNNING, - session=session + session=session, ) ti1, ti2 = dr1.get_task_instances(session=session) @@ -3482,9 +3449,7 @@ def test_send_sla_callbacks_to_processor_sla_with_task_slas(self): ) def test_scheduler_sets_job_id_on_dag_run(self): - dag = DAG( - dag_id='test_scheduler_sets_job_id_on_dag_run', - start_date=DEFAULT_DATE) + dag = DAG(dag_id='test_scheduler_sets_job_id_on_dag_run', start_date=DEFAULT_DATE) DummyOperator( task_id='dummy', @@ -3494,7 +3459,7 @@ def test_scheduler_sets_job_id_on_dag_run(self): dagbag = DagBag( dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), include_examples=False, - read_dags_from_db=True + read_dags_from_db=True, ) dagbag.bag_dag(dag=dag, root_dag=dag) dagbag.sync_to_db() @@ -3587,7 +3552,7 @@ def test_do_schedule_max_active_runs_and_manual_trigger(self): dagbag = DagBag( dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), include_examples=False, - read_dags_from_db=True + read_dags_from_db=True, ) dagbag.bag_dag(dag=dag, root_dag=dag) @@ -3665,9 +3630,7 @@ def test_task_with_upstream_skip_process_task_instances(): """ clear_db_runs() with DAG( - dag_id='test_task_with_upstream_skip_dag', - start_date=DEFAULT_DATE, - schedule_interval=None + dag_id='test_task_with_upstream_skip_dag', start_date=DEFAULT_DATE, schedule_interval=None ) as dag: dummy1 = DummyOperator(task_id='dummy1') dummy2 = DummyOperator(task_id="dummy2") @@ -3676,9 +3639,7 @@ def test_task_with_upstream_skip_process_task_instances(): # dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) dag.clear() - dr = dag.create_dagrun(run_type=DagRunType.MANUAL, - state=State.RUNNING, - execution_date=DEFAULT_DATE) + dr = dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) assert dr is not None with create_session() as session: @@ -3729,17 +3690,24 @@ def setUp(self) -> None: ) @pytest.mark.quarantined def test_execute_queries_count_with_harvested_dags(self, expected_query_count, dag_count, task_count): - with mock.patch.dict("os.environ", { - "PERF_DAGS_COUNT": str(dag_count), - "PERF_TASKS_COUNT": str(task_count), - "PERF_START_AGO": "1d", - "PERF_SCHEDULE_INTERVAL": "30m", - "PERF_SHAPE": "no_structure", - }), conf_vars({ - ('scheduler', 'use_job_schedule'): 'True', - ('core', 'load_examples'): 'False', - ('core', 'store_serialized_dags'): 'True', - }), mock.patch.object(settings, 'STORE_SERIALIZED_DAGS', True): + with mock.patch.dict( + "os.environ", + { + "PERF_DAGS_COUNT": str(dag_count), + "PERF_TASKS_COUNT": str(task_count), + "PERF_START_AGO": "1d", + "PERF_SCHEDULE_INTERVAL": "30m", + "PERF_SHAPE": "no_structure", + }, + ), conf_vars( + { + ('scheduler', 'use_job_schedule'): 'True', + ('core', 'load_examples'): 'False', + ('core', 'store_serialized_dags'): 'True', + } + ), mock.patch.object( + settings, 'STORE_SERIALIZED_DAGS', True + ): dagruns = [] dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False, read_dags_from_db=False) dagbag.sync_to_db() @@ -3775,35 +3743,35 @@ def test_execute_queries_count_with_harvested_dags(self, expected_query_count, d # pylint: disable=bad-whitespace # expected, dag_count, task_count, start_ago, schedule_interval, shape # One DAG with one task per DAG file - ([10, 10, 10, 10], 1, 1, "1d", "None", "no_structure"), # noqa - ([10, 10, 10, 10], 1, 1, "1d", "None", "linear"), # noqa - ([22, 14, 14, 14], 1, 1, 
"1d", "@once", "no_structure"), # noqa - ([22, 14, 14, 14], 1, 1, "1d", "@once", "linear"), # noqa - ([22, 24, 27, 30], 1, 1, "1d", "30m", "no_structure"), # noqa - ([22, 24, 27, 30], 1, 1, "1d", "30m", "linear"), # noqa - ([22, 24, 27, 30], 1, 1, "1d", "30m", "binary_tree"), # noqa - ([22, 24, 27, 30], 1, 1, "1d", "30m", "star"), # noqa - ([22, 24, 27, 30], 1, 1, "1d", "30m", "grid"), # noqa + ([10, 10, 10, 10], 1, 1, "1d", "None", "no_structure"), # noqa + ([10, 10, 10, 10], 1, 1, "1d", "None", "linear"), # noqa + ([22, 14, 14, 14], 1, 1, "1d", "@once", "no_structure"), # noqa + ([22, 14, 14, 14], 1, 1, "1d", "@once", "linear"), # noqa + ([22, 24, 27, 30], 1, 1, "1d", "30m", "no_structure"), # noqa + ([22, 24, 27, 30], 1, 1, "1d", "30m", "linear"), # noqa + ([22, 24, 27, 30], 1, 1, "1d", "30m", "binary_tree"), # noqa + ([22, 24, 27, 30], 1, 1, "1d", "30m", "star"), # noqa + ([22, 24, 27, 30], 1, 1, "1d", "30m", "grid"), # noqa # One DAG with five tasks per DAG file - ([10, 10, 10, 10], 1, 5, "1d", "None", "no_structure"), # noqa - ([10, 10, 10, 10], 1, 5, "1d", "None", "linear"), # noqa - ([22, 14, 14, 14], 1, 5, "1d", "@once", "no_structure"), # noqa - ([23, 15, 15, 15], 1, 5, "1d", "@once", "linear"), # noqa - ([22, 24, 27, 30], 1, 5, "1d", "30m", "no_structure"), # noqa - ([23, 26, 30, 34], 1, 5, "1d", "30m", "linear"), # noqa - ([23, 26, 30, 34], 1, 5, "1d", "30m", "binary_tree"), # noqa - ([23, 26, 30, 34], 1, 5, "1d", "30m", "star"), # noqa - ([23, 26, 30, 34], 1, 5, "1d", "30m", "grid"), # noqa + ([10, 10, 10, 10], 1, 5, "1d", "None", "no_structure"), # noqa + ([10, 10, 10, 10], 1, 5, "1d", "None", "linear"), # noqa + ([22, 14, 14, 14], 1, 5, "1d", "@once", "no_structure"), # noqa + ([23, 15, 15, 15], 1, 5, "1d", "@once", "linear"), # noqa + ([22, 24, 27, 30], 1, 5, "1d", "30m", "no_structure"), # noqa + ([23, 26, 30, 34], 1, 5, "1d", "30m", "linear"), # noqa + ([23, 26, 30, 34], 1, 5, "1d", "30m", "binary_tree"), # noqa + ([23, 26, 30, 34], 1, 5, "1d", "30m", "star"), # noqa + ([23, 26, 30, 34], 1, 5, "1d", "30m", "grid"), # noqa # 10 DAGs with 10 tasks per DAG file - ([10, 10, 10, 10], 10, 10, "1d", "None", "no_structure"), # noqa - ([10, 10, 10, 10], 10, 10, "1d", "None", "linear"), # noqa - ([85, 38, 38, 38], 10, 10, "1d", "@once", "no_structure"), # noqa - ([95, 51, 51, 51], 10, 10, "1d", "@once", "linear"), # noqa - ([85, 99, 99, 99], 10, 10, "1d", "30m", "no_structure"), # noqa - ([95, 125, 125, 125], 10, 10, "1d", "30m", "linear"), # noqa - ([95, 119, 119, 119], 10, 10, "1d", "30m", "binary_tree"), # noqa - ([95, 119, 119, 119], 10, 10, "1d", "30m", "star"), # noqa - ([95, 119, 119, 119], 10, 10, "1d", "30m", "grid"), # noqa + ([10, 10, 10, 10], 10, 10, "1d", "None", "no_structure"), # noqa + ([10, 10, 10, 10], 10, 10, "1d", "None", "linear"), # noqa + ([85, 38, 38, 38], 10, 10, "1d", "@once", "no_structure"), # noqa + ([95, 51, 51, 51], 10, 10, "1d", "@once", "linear"), # noqa + ([85, 99, 99, 99], 10, 10, "1d", "30m", "no_structure"), # noqa + ([95, 125, 125, 125], 10, 10, "1d", "30m", "linear"), # noqa + ([95, 119, 119, 119], 10, 10, "1d", "30m", "binary_tree"), # noqa + ([95, 119, 119, 119], 10, 10, "1d", "30m", "star"), # noqa + ([95, 119, 119, 119], 10, 10, "1d", "30m", "grid"), # noqa # pylint: enable=bad-whitespace ] ) @@ -3811,16 +3779,23 @@ def test_execute_queries_count_with_harvested_dags(self, expected_query_count, d def test_process_dags_queries_count( self, expected_query_counts, dag_count, task_count, start_ago, schedule_interval, shape ): - with 
mock.patch.dict("os.environ", { - "PERF_DAGS_COUNT": str(dag_count), - "PERF_TASKS_COUNT": str(task_count), - "PERF_START_AGO": start_ago, - "PERF_SCHEDULE_INTERVAL": schedule_interval, - "PERF_SHAPE": shape, - }), conf_vars({ - ('scheduler', 'use_job_schedule'): 'True', - ('core', 'store_serialized_dags'): 'True', - }), mock.patch.object(settings, 'STORE_SERIALIZED_DAGS', True): + with mock.patch.dict( + "os.environ", + { + "PERF_DAGS_COUNT": str(dag_count), + "PERF_TASKS_COUNT": str(task_count), + "PERF_START_AGO": start_ago, + "PERF_SCHEDULE_INTERVAL": schedule_interval, + "PERF_SHAPE": shape, + }, + ), conf_vars( + { + ('scheduler', 'use_job_schedule'): 'True', + ('core', 'store_serialized_dags'): 'True', + } + ), mock.patch.object( + settings, 'STORE_SERIALIZED_DAGS', True + ): dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False) dagbag.sync_to_db() diff --git a/tests/kubernetes/models/test_secret.py b/tests/kubernetes/models/test_secret.py index 7036cf7f8d640..183a9ef9bd381 100644 --- a/tests/kubernetes/models/test_secret.py +++ b/tests/kubernetes/models/test_secret.py @@ -27,65 +27,50 @@ class TestSecret(unittest.TestCase): - def test_to_env_secret(self): secret = Secret('env', 'name', 'secret', 'key') - self.assertEqual(secret.to_env_secret(), k8s.V1EnvVar( - name='NAME', - value_from=k8s.V1EnvVarSource( - secret_key_ref=k8s.V1SecretKeySelector( - name='secret', - key='key' - ) - ) - )) + self.assertEqual( + secret.to_env_secret(), + k8s.V1EnvVar( + name='NAME', + value_from=k8s.V1EnvVarSource( + secret_key_ref=k8s.V1SecretKeySelector(name='secret', key='key') + ), + ), + ) def test_to_env_from_secret(self): secret = Secret('env', None, 'secret') - self.assertEqual(secret.to_env_from_secret(), k8s.V1EnvFromSource( - secret_ref=k8s.V1SecretEnvSource(name='secret') - )) + self.assertEqual( + secret.to_env_from_secret(), k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name='secret')) + ) @mock.patch('uuid.uuid4') def test_to_volume_secret(self, mock_uuid): mock_uuid.return_value = '0' secret = Secret('volume', '/etc/foo', 'secret_b') - self.assertEqual(secret.to_volume_secret(), ( - k8s.V1Volume( - name='secretvol0', - secret=k8s.V1SecretVolumeSource( - secret_name='secret_b' - ) + self.assertEqual( + secret.to_volume_secret(), + ( + k8s.V1Volume(name='secretvol0', secret=k8s.V1SecretVolumeSource(secret_name='secret_b')), + k8s.V1VolumeMount(mount_path='/etc/foo', name='secretvol0', read_only=True), ), - k8s.V1VolumeMount( - mount_path='/etc/foo', - name='secretvol0', - read_only=True - ) - )) + ) @mock.patch('uuid.uuid4') def test_only_mount_sub_secret(self, mock_uuid): mock_uuid.return_value = '0' - items = [k8s.V1KeyToPath( - key="my-username", - path="/extra/path" - ) - ] + items = [k8s.V1KeyToPath(key="my-username", path="/extra/path")] secret = Secret('volume', '/etc/foo', 'secret_b', items=items) - self.assertEqual(secret.to_volume_secret(), ( - k8s.V1Volume( - name='secretvol0', - secret=k8s.V1SecretVolumeSource( - secret_name='secret_b', - items=items) + self.assertEqual( + secret.to_volume_secret(), + ( + k8s.V1Volume( + name='secretvol0', secret=k8s.V1SecretVolumeSource(secret_name='secret_b', items=items) + ), + k8s.V1VolumeMount(mount_path='/etc/foo', name='secretvol0', read_only=True), ), - k8s.V1VolumeMount( - mount_path='/etc/foo', - name='secretvol0', - read_only=True - ) - )) + ) @mock.patch('uuid.uuid4') def test_attach_to_pod(self, mock_uuid): @@ -104,44 +89,60 @@ def test_attach_to_pod(self, mock_uuid): k8s_client = ApiClient() pod = 
append_to_pod(pod, secrets) result = k8s_client.sanitize_for_serialization(pod) - self.assertEqual(result, - {'apiVersion': 'v1', - 'kind': 'Pod', - 'metadata': {'labels': {'app': 'myapp'}, - 'name': 'myapp-pod-cf4a56d281014217b0272af6216feb48', - 'namespace': 'default'}, - 'spec': {'containers': [{'command': ['sh', '-c', 'echo Hello Kubernetes!'], - 'env': [{'name': 'ENVIRONMENT', 'value': 'prod'}, - {'name': 'LOG_LEVEL', 'value': 'warning'}, - {'name': 'TARGET', - 'valueFrom': - {'secretKeyRef': {'key': 'source_b', - 'name': 'secret_b'}}}], - 'envFrom': [{'configMapRef': {'name': 'configmap_a'}}, - {'secretRef': {'name': 'secret_a'}}], - 'image': 'busybox', - 'name': 'base', - 'ports': [{'containerPort': 1234, 'name': 'foo'}], - 'resources': {'limits': {'memory': '200Mi'}, - 'requests': {'memory': '100Mi'}}, - 'volumeMounts': [{'mountPath': '/airflow/xcom', - 'name': 'xcom'}, - {'mountPath': '/etc/foo', - 'name': 'secretvol' + str(static_uuid), - 'readOnly': True}]}, - {'command': ['sh', - '-c', - 'trap "exit 0" INT; while true; do sleep ' - '30; done;'], - 'image': 'alpine', - 'name': 'airflow-xcom-sidecar', - 'resources': {'requests': {'cpu': '1m'}}, - 'volumeMounts': [{'mountPath': '/airflow/xcom', - 'name': 'xcom'}]}], - 'hostNetwork': True, - 'imagePullSecrets': [{'name': 'pull_secret_a'}, - {'name': 'pull_secret_b'}], - 'securityContext': {'fsGroup': 2000, 'runAsUser': 1000}, - 'volumes': [{'emptyDir': {}, 'name': 'xcom'}, - {'name': 'secretvol' + str(static_uuid), - 'secret': {'secretName': 'secret_b'}}]}}) + self.assertEqual( + result, + { + 'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': { + 'labels': {'app': 'myapp'}, + 'name': 'myapp-pod-cf4a56d281014217b0272af6216feb48', + 'namespace': 'default', + }, + 'spec': { + 'containers': [ + { + 'command': ['sh', '-c', 'echo Hello Kubernetes!'], + 'env': [ + {'name': 'ENVIRONMENT', 'value': 'prod'}, + {'name': 'LOG_LEVEL', 'value': 'warning'}, + { + 'name': 'TARGET', + 'valueFrom': {'secretKeyRef': {'key': 'source_b', 'name': 'secret_b'}}, + }, + ], + 'envFrom': [ + {'configMapRef': {'name': 'configmap_a'}}, + {'secretRef': {'name': 'secret_a'}}, + ], + 'image': 'busybox', + 'name': 'base', + 'ports': [{'containerPort': 1234, 'name': 'foo'}], + 'resources': {'limits': {'memory': '200Mi'}, 'requests': {'memory': '100Mi'}}, + 'volumeMounts': [ + {'mountPath': '/airflow/xcom', 'name': 'xcom'}, + { + 'mountPath': '/etc/foo', + 'name': 'secretvol' + str(static_uuid), + 'readOnly': True, + }, + ], + }, + { + 'command': ['sh', '-c', 'trap "exit 0" INT; while true; do sleep 30; done;'], + 'image': 'alpine', + 'name': 'airflow-xcom-sidecar', + 'resources': {'requests': {'cpu': '1m'}}, + 'volumeMounts': [{'mountPath': '/airflow/xcom', 'name': 'xcom'}], + }, + ], + 'hostNetwork': True, + 'imagePullSecrets': [{'name': 'pull_secret_a'}, {'name': 'pull_secret_b'}], + 'securityContext': {'fsGroup': 2000, 'runAsUser': 1000}, + 'volumes': [ + {'emptyDir': {}, 'name': 'xcom'}, + {'name': 'secretvol' + str(static_uuid), 'secret': {'secretName': 'secret_b'}}, + ], + }, + }, + ) diff --git a/tests/kubernetes/test_client.py b/tests/kubernetes/test_client.py index fae79a5470bae..73faedd370a9e 100644 --- a/tests/kubernetes/test_client.py +++ b/tests/kubernetes/test_client.py @@ -25,7 +25,6 @@ class TestClient(unittest.TestCase): - @mock.patch('airflow.kubernetes.kube_client.config') def test_load_cluster_config(self, _): client = get_kube_client(in_cluster=True) diff --git a/tests/kubernetes/test_pod_generator.py b/tests/kubernetes/test_pod_generator.py 
index 865c156b45678..8af12f9a03fd6 100644 --- a/tests/kubernetes/test_pod_generator.py +++ b/tests/kubernetes/test_pod_generator.py @@ -25,13 +25,16 @@ from airflow import __version__ from airflow.exceptions import AirflowConfigException from airflow.kubernetes.pod_generator import ( - PodDefaults, PodGenerator, datetime_to_label_safe_datestring, extend_object_field, merge_objects, + PodDefaults, + PodGenerator, + datetime_to_label_safe_datestring, + extend_object_field, + merge_objects, ) from airflow.kubernetes.secret import Secret class TestPodGenerator(unittest.TestCase): - def setUp(self): self.static_uuid = uuid.UUID('cf4a56d2-8101-4217-b027-2af6216feb48') self.deserialize_result = { @@ -39,23 +42,19 @@ def setUp(self): 'kind': 'Pod', 'metadata': {'name': 'memory-demo', 'namespace': 'mem-example'}, 'spec': { - 'containers': [{ - 'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'], - 'command': ['stress'], - 'image': 'apache/airflow:stress-2020.07.10-1.0.4', - 'name': 'memory-demo-ctr', - 'resources': { - 'limits': {'memory': '200Mi'}, - 'requests': {'memory': '100Mi'} + 'containers': [ + { + 'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'], + 'command': ['stress'], + 'image': 'apache/airflow:stress-2020.07.10-1.0.4', + 'name': 'memory-demo-ctr', + 'resources': {'limits': {'memory': '200Mi'}, 'requests': {'memory': '100Mi'}}, } - }] - } + ] + }, } - self.envs = { - 'ENVIRONMENT': 'prod', - 'LOG_LEVEL': 'warning' - } + self.envs = {'ENVIRONMENT': 'prod', 'LOG_LEVEL': 'warning'} self.secrets = [ # This should be a secretRef Secret('env', None, 'secret_a'), @@ -77,7 +76,7 @@ def setUp(self): 'task_id': self.task_id, 'try_number': str(self.try_number), 'airflow_version': __version__.replace('+', '-'), - 'kubernetes_executor': 'True' + 'kubernetes_executor': 'True', } self.annotations = { 'dag_id': self.dag_id, @@ -98,12 +97,7 @@ def setUp(self): "memory": "1Gi", "ephemeral-storage": "2Gi", }, - limits={ - "cpu": 2, - "memory": "2Gi", - "ephemeral-storage": "4Gi", - 'nvidia.com/gpu': 1 - } + limits={"cpu": 2, "memory": "2Gi", "ephemeral-storage": "4Gi", 'nvidia.com/gpu': 1}, ) self.k8s_client = ApiClient() @@ -120,14 +114,9 @@ def setUp(self): k8s.V1Container( name='base', image='busybox', - command=[ - 'sh', '-c', 'echo Hello Kubernetes!' 
- ], + command=['sh', '-c', 'echo Hello Kubernetes!'], env=[ - k8s.V1EnvVar( - name='ENVIRONMENT', - value='prod' - ), + k8s.V1EnvVar(name='ENVIRONMENT', value='prod'), k8s.V1EnvVar( name="LOG_LEVEL", value='warning', @@ -135,44 +124,22 @@ def setUp(self): k8s.V1EnvVar( name='TARGET', value_from=k8s.V1EnvVarSource( - secret_key_ref=k8s.V1SecretKeySelector( - name='secret_b', - key='source_b' - ) + secret_key_ref=k8s.V1SecretKeySelector(name='secret_b', key='source_b') ), - ) - ], - env_from=[ - k8s.V1EnvFromSource( - config_map_ref=k8s.V1ConfigMapEnvSource( - name='configmap_a' - ) - ), - k8s.V1EnvFromSource( - config_map_ref=k8s.V1ConfigMapEnvSource( - name='configmap_b' - ) - ), - k8s.V1EnvFromSource( - secret_ref=k8s.V1SecretEnvSource( - name='secret_a' - ) ), ], - ports=[ - k8s.V1ContainerPort( - name="foo", - container_port=1234 - ) + env_from=[ + k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name='configmap_a')), + k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name='configmap_b')), + k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name='secret_a')), ], + ports=[k8s.V1ContainerPort(name="foo", container_port=1234)], resources=k8s.V1ResourceRequirements( - requests={ - 'memory': '100Mi' - }, + requests={'memory': '100Mi'}, limits={ 'memory': '200Mi', - } - ) + }, + ), ) ], security_context=k8s.V1PodSecurityContext( @@ -181,13 +148,9 @@ def setUp(self): ), host_network=True, image_pull_secrets=[ - k8s.V1LocalObjectReference( - name="pull_secret_a" - ), - k8s.V1LocalObjectReference( - name="pull_secret_b" - ) - ] + k8s.V1LocalObjectReference(name="pull_secret_a"), + k8s.V1LocalObjectReference(name="pull_secret_b"), + ], ), ) @@ -196,31 +159,20 @@ def test_gen_pod_extract_xcom(self, mock_uuid): mock_uuid.return_value = self.static_uuid path = sys.path[0] + '/tests/kubernetes/pod_generator_base_with_secrets.yaml' - pod_generator = PodGenerator( - pod_template_file=path, - extract_xcom=True - ) + pod_generator = PodGenerator(pod_template_file=path, extract_xcom=True) result = pod_generator.gen_pod() result_dict = self.k8s_client.sanitize_for_serialization(result) container_two = { 'name': 'airflow-xcom-sidecar', 'image': "alpine", 'command': ['sh', '-c', PodDefaults.XCOM_CMD], - 'volumeMounts': [ - { - 'name': 'xcom', - 'mountPath': '/airflow/xcom' - } - ], + 'volumeMounts': [{'name': 'xcom', 'mountPath': '/airflow/xcom'}], 'resources': {'requests': {'cpu': '1m'}}, } self.expected.spec.containers.append(container_two) base_container: k8s.V1Container = self.expected.spec.containers[0] base_container.volume_mounts = base_container.volume_mounts or [] - base_container.volume_mounts.append(k8s.V1VolumeMount( - name="xcom", - mount_path="/airflow/xcom" - )) + base_container.volume_mounts.append(k8s.V1VolumeMount(name="xcom", mount_path="/airflow/xcom")) self.expected.spec.containers[0] = base_container self.expected.spec.volumes = self.expected.spec.volumes or [] self.expected.spec.volumes.append( @@ -240,153 +192,148 @@ def test_from_obj(self): "pod_override": k8s.V1Pod( api_version="v1", kind="Pod", - metadata=k8s.V1ObjectMeta( - name="foo", - annotations={"test": "annotation"} - ), + metadata=k8s.V1ObjectMeta(name="foo", annotations={"test": "annotation"}), spec=k8s.V1PodSpec( containers=[ k8s.V1Container( name="base", volume_mounts=[ k8s.V1VolumeMount( - mount_path="/foo/", - name="example-kubernetes-test-volume" + mount_path="/foo/", name="example-kubernetes-test-volume" ) - ] + ], ) ], volumes=[ k8s.V1Volume( name="example-kubernetes-test-volume", - 
host_path=k8s.V1HostPathVolumeSource( - path="/tmp/" - ) + host_path=k8s.V1HostPathVolumeSource(path="/tmp/"), ) - ] - ) + ], + ), ) } ) result = self.k8s_client.sanitize_for_serialization(result) - self.assertEqual({ - 'apiVersion': 'v1', - 'kind': 'Pod', - 'metadata': { - 'name': 'foo', - 'annotations': {'test': 'annotation'}, + self.assertEqual( + { + 'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': { + 'name': 'foo', + 'annotations': {'test': 'annotation'}, + }, + 'spec': { + 'containers': [ + { + 'name': 'base', + 'volumeMounts': [ + {'mountPath': '/foo/', 'name': 'example-kubernetes-test-volume'} + ], + } + ], + 'volumes': [{'hostPath': {'path': '/tmp/'}, 'name': 'example-kubernetes-test-volume'}], + }, }, - 'spec': { - 'containers': [{ - 'name': 'base', - 'volumeMounts': [{ - 'mountPath': '/foo/', - 'name': 'example-kubernetes-test-volume' - }], - }], - 'volumes': [{ - 'hostPath': {'path': '/tmp/'}, - 'name': 'example-kubernetes-test-volume' - }], - } - }, result) - result = PodGenerator.from_obj({ - "KubernetesExecutor": { - "annotations": {"test": "annotation"}, - "volumes": [ - { - "name": "example-kubernetes-test-volume", - "hostPath": {"path": "/tmp/"}, - }, - ], - "volume_mounts": [ - { - "mountPath": "/foo/", - "name": "example-kubernetes-test-volume", - }, - ], + result, + ) + result = PodGenerator.from_obj( + { + "KubernetesExecutor": { + "annotations": {"test": "annotation"}, + "volumes": [ + { + "name": "example-kubernetes-test-volume", + "hostPath": {"path": "/tmp/"}, + }, + ], + "volume_mounts": [ + { + "mountPath": "/foo/", + "name": "example-kubernetes-test-volume", + }, + ], + } } - }) + ) result_from_pod = PodGenerator.from_obj( - {"pod_override": - k8s.V1Pod( - metadata=k8s.V1ObjectMeta( - annotations={"test": "annotation"} - ), + { + "pod_override": k8s.V1Pod( + metadata=k8s.V1ObjectMeta(annotations={"test": "annotation"}), spec=k8s.V1PodSpec( containers=[ k8s.V1Container( name="base", volume_mounts=[ k8s.V1VolumeMount( - name="example-kubernetes-test-volume", - mount_path="/foo/" + name="example-kubernetes-test-volume", mount_path="/foo/" ) - ] + ], ) ], - volumes=[ - k8s.V1Volume( - name="example-kubernetes-test-volume", - host_path="/tmp/" - ) - ] - ) + volumes=[k8s.V1Volume(name="example-kubernetes-test-volume", host_path="/tmp/")], + ), ) - } + } ) result = self.k8s_client.sanitize_for_serialization(result) result_from_pod = self.k8s_client.sanitize_for_serialization(result_from_pod) - expected_from_pod = {'metadata': {'annotations': {'test': 'annotation'}}, - 'spec': {'containers': [ - {'name': 'base', - 'volumeMounts': [{'mountPath': '/foo/', - 'name': 'example-kubernetes-test-volume'}]}], - 'volumes': [{'hostPath': '/tmp/', - 'name': 'example-kubernetes-test-volume'}]}} - self.assertEqual(result_from_pod, expected_from_pod, "There was a discrepency" - " between KubernetesExecutor and pod_override") - - self.assertEqual({ - 'apiVersion': 'v1', - 'kind': 'Pod', - 'metadata': { - 'annotations': {'test': 'annotation'}, - }, + expected_from_pod = { + 'metadata': {'annotations': {'test': 'annotation'}}, 'spec': { - 'containers': [{ - 'args': [], - 'command': [], - 'env': [], - 'envFrom': [], - 'name': 'base', - 'ports': [], - 'volumeMounts': [{ - 'mountPath': '/foo/', - 'name': 'example-kubernetes-test-volume' - }], - }], - 'hostNetwork': False, - 'imagePullSecrets': [], - 'volumes': [{ - 'hostPath': {'path': '/tmp/'}, - 'name': 'example-kubernetes-test-volume' - }], - } - }, result) + 'containers': [ + { + 'name': 'base', + 'volumeMounts': [{'mountPath': 
'/foo/', 'name': 'example-kubernetes-test-volume'}], + } + ], + 'volumes': [{'hostPath': '/tmp/', 'name': 'example-kubernetes-test-volume'}], + }, + } + self.assertEqual( + result_from_pod, + expected_from_pod, + "There was a discrepency" " between KubernetesExecutor and pod_override", + ) + + self.assertEqual( + { + 'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': { + 'annotations': {'test': 'annotation'}, + }, + 'spec': { + 'containers': [ + { + 'args': [], + 'command': [], + 'env': [], + 'envFrom': [], + 'name': 'base', + 'ports': [], + 'volumeMounts': [ + {'mountPath': '/foo/', 'name': 'example-kubernetes-test-volume'} + ], + } + ], + 'hostNetwork': False, + 'imagePullSecrets': [], + 'volumes': [{'hostPath': {'path': '/tmp/'}, 'name': 'example-kubernetes-test-volume'}], + }, + }, + result, + ) @mock.patch('uuid.uuid4') def test_reconcile_pods_empty_mutator_pod(self, mock_uuid): mock_uuid.return_value = self.static_uuid path = sys.path[0] + '/tests/kubernetes/pod_generator_base_with_secrets.yaml' - pod_generator = PodGenerator( - pod_template_file=path, - extract_xcom=True - ) + pod_generator = PodGenerator(pod_template_file=path, extract_xcom=True) base_pod = pod_generator.gen_pod() mutator_pod = None name = 'name1-' + self.static_uuid.hex @@ -405,10 +352,7 @@ def test_reconcile_pods(self, mock_uuid): mock_uuid.return_value = self.static_uuid path = sys.path[0] + '/tests/kubernetes/pod_generator_base_with_secrets.yaml' - base_pod = PodGenerator( - pod_template_file=path, - extract_xcom=False - ).gen_pod() + base_pod = PodGenerator(pod_template_file=path, extract_xcom=False).gen_pod() mutator_pod = k8s.V1Pod( metadata=k8s.V1ObjectMeta( @@ -416,19 +360,23 @@ def test_reconcile_pods(self, mock_uuid): labels={"bar": "baz"}, ), spec=k8s.V1PodSpec( - containers=[k8s.V1Container( - image='', - name='name', - command=['/bin/command2.sh', 'arg2'], - volume_mounts=[k8s.V1VolumeMount(mount_path="/foo/", - name="example-kubernetes-test-volume2")] - )], + containers=[ + k8s.V1Container( + image='', + name='name', + command=['/bin/command2.sh', 'arg2'], + volume_mounts=[ + k8s.V1VolumeMount(mount_path="/foo/", name="example-kubernetes-test-volume2") + ], + ) + ], volumes=[ - k8s.V1Volume(host_path=k8s.V1HostPathVolumeSource(path="/tmp/"), - name="example-kubernetes-test-volume2") - ] - ) - + k8s.V1Volume( + host_path=k8s.V1HostPathVolumeSource(path="/tmp/"), + name="example-kubernetes-test-volume2", + ) + ], + ), ) result = PodGenerator.reconcile_pods(base_pod, mutator_pod) @@ -437,15 +385,15 @@ def test_reconcile_pods(self, mock_uuid): expected.metadata.labels['bar'] = 'baz' expected.spec.volumes = expected.spec.volumes or [] expected.spec.volumes.append( - k8s.V1Volume(host_path=k8s.V1HostPathVolumeSource(path="/tmp/"), - name="example-kubernetes-test-volume2") + k8s.V1Volume( + host_path=k8s.V1HostPathVolumeSource(path="/tmp/"), name="example-kubernetes-test-volume2" + ) ) base_container: k8s.V1Container = expected.spec.containers[0] base_container.command = ['/bin/command2.sh', 'arg2'] base_container.volume_mounts = [ - k8s.V1VolumeMount(mount_path="/foo/", - name="example-kubernetes-test-volume2") + k8s.V1VolumeMount(mount_path="/foo/", name="example-kubernetes-test-volume2") ] base_container.name = "name" expected.spec.containers[0] = base_container @@ -464,13 +412,7 @@ def test_construct_pod(self, mock_uuid): spec=k8s.V1PodSpec( containers=[ k8s.V1Container( - name='', - resources=k8s.V1ResourceRequirements( - limits={ - 'cpu': '1m', - 'memory': '1G' - } - ) + name='', 
resources=k8s.V1ResourceRequirements(limits={'cpu': '1m', 'memory': '1G'}) ) ] ) @@ -497,9 +439,7 @@ def test_construct_pod(self, mock_uuid): expected.metadata.namespace = 'test_namespace' expected.spec.containers[0].command = ['command'] expected.spec.containers[0].image = 'airflow_image' - expected.spec.containers[0].resources = {'limits': {'cpu': '1m', - 'memory': '1G'} - } + expected.spec.containers[0].resources = {'limits': {'cpu': '1m', 'memory': '1G'}} result_dict = self.k8s_client.sanitize_for_serialization(result) expected_dict = self.k8s_client.sanitize_for_serialization(self.expected) @@ -560,10 +500,7 @@ def test_merge_objects(self): base_annotations = {'foo1': 'bar1'} base_labels = {'foo1': 'bar1'} client_annotations = {'foo2': 'bar2'} - base_obj = k8s.V1ObjectMeta( - annotations=base_annotations, - labels=base_labels - ) + base_obj = k8s.V1ObjectMeta(annotations=base_annotations, labels=base_labels) client_obj = k8s.V1ObjectMeta(annotations=client_annotations) res = merge_objects(base_obj, client_obj) client_obj.labels = base_labels diff --git a/tests/kubernetes/test_pod_launcher.py b/tests/kubernetes/test_pod_launcher.py index b85ac2bc5bb59..2b55a7770bdcf 100644 --- a/tests/kubernetes/test_pod_launcher.py +++ b/tests/kubernetes/test_pod_launcher.py @@ -24,7 +24,6 @@ class TestPodLauncher(unittest.TestCase): - def setUp(self): self.mock_kube_client = mock.Mock() self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client) @@ -39,79 +38,77 @@ def test_read_pod_logs_retries_successfully(self): mock.sentinel.metadata = mock.MagicMock() self.mock_kube_client.read_namespaced_pod_log.side_effect = [ BaseHTTPError('Boom'), - mock.sentinel.logs + mock.sentinel.logs, ] logs = self.pod_launcher.read_pod_logs(mock.sentinel) self.assertEqual(mock.sentinel.logs, logs) - self.mock_kube_client.read_namespaced_pod_log.assert_has_calls([ - mock.call( - _preload_content=False, - container='base', - follow=True, - timestamps=False, - name=mock.sentinel.metadata.name, - namespace=mock.sentinel.metadata.namespace - ), - mock.call( - _preload_content=False, - container='base', - follow=True, - timestamps=False, - name=mock.sentinel.metadata.name, - namespace=mock.sentinel.metadata.namespace - ) - ]) + self.mock_kube_client.read_namespaced_pod_log.assert_has_calls( + [ + mock.call( + _preload_content=False, + container='base', + follow=True, + timestamps=False, + name=mock.sentinel.metadata.name, + namespace=mock.sentinel.metadata.namespace, + ), + mock.call( + _preload_content=False, + container='base', + follow=True, + timestamps=False, + name=mock.sentinel.metadata.name, + namespace=mock.sentinel.metadata.namespace, + ), + ] + ) def test_read_pod_logs_retries_fails(self): mock.sentinel.metadata = mock.MagicMock() self.mock_kube_client.read_namespaced_pod_log.side_effect = [ BaseHTTPError('Boom'), BaseHTTPError('Boom'), - BaseHTTPError('Boom') + BaseHTTPError('Boom'), ] - self.assertRaises( - AirflowException, - self.pod_launcher.read_pod_logs, - mock.sentinel - ) + self.assertRaises(AirflowException, self.pod_launcher.read_pod_logs, mock.sentinel) def test_read_pod_logs_successfully_with_tail_lines(self): mock.sentinel.metadata = mock.MagicMock() - self.mock_kube_client.read_namespaced_pod_log.side_effect = [ - mock.sentinel.logs - ] + self.mock_kube_client.read_namespaced_pod_log.side_effect = [mock.sentinel.logs] logs = self.pod_launcher.read_pod_logs(mock.sentinel, tail_lines=100) self.assertEqual(mock.sentinel.logs, logs) - 
self.mock_kube_client.read_namespaced_pod_log.assert_has_calls([ - mock.call( - _preload_content=False, - container='base', - follow=True, - timestamps=False, - name=mock.sentinel.metadata.name, - namespace=mock.sentinel.metadata.namespace, - tail_lines=100 - ), - ]) + self.mock_kube_client.read_namespaced_pod_log.assert_has_calls( + [ + mock.call( + _preload_content=False, + container='base', + follow=True, + timestamps=False, + name=mock.sentinel.metadata.name, + namespace=mock.sentinel.metadata.namespace, + tail_lines=100, + ), + ] + ) def test_read_pod_logs_successfully_with_since_seconds(self): mock.sentinel.metadata = mock.MagicMock() - self.mock_kube_client.read_namespaced_pod_log.side_effect = [ - mock.sentinel.logs - ] + self.mock_kube_client.read_namespaced_pod_log.side_effect = [mock.sentinel.logs] logs = self.pod_launcher.read_pod_logs(mock.sentinel, since_seconds=2) self.assertEqual(mock.sentinel.logs, logs) - self.mock_kube_client.read_namespaced_pod_log.assert_has_calls([ - mock.call( - _preload_content=False, - container='base', - follow=True, - timestamps=False, - name=mock.sentinel.metadata.name, - namespace=mock.sentinel.metadata.namespace, - since_seconds=2 - ), - ]) + self.mock_kube_client.read_namespaced_pod_log.assert_has_calls( + [ + mock.call( + _preload_content=False, + container='base', + follow=True, + timestamps=False, + name=mock.sentinel.metadata.name, + namespace=mock.sentinel.metadata.namespace, + since_seconds=2, + ), + ] + ) def test_read_pod_events_successfully_returns_events(self): mock.sentinel.metadata = mock.MagicMock() @@ -123,33 +120,31 @@ def test_read_pod_events_retries_successfully(self): mock.sentinel.metadata = mock.MagicMock() self.mock_kube_client.list_namespaced_event.side_effect = [ BaseHTTPError('Boom'), - mock.sentinel.events + mock.sentinel.events, ] events = self.pod_launcher.read_pod_events(mock.sentinel) self.assertEqual(mock.sentinel.events, events) - self.mock_kube_client.list_namespaced_event.assert_has_calls([ - mock.call( - namespace=mock.sentinel.metadata.namespace, - field_selector=f"involvedObject.name={mock.sentinel.metadata.name}" - ), - mock.call( - namespace=mock.sentinel.metadata.namespace, - field_selector=f"involvedObject.name={mock.sentinel.metadata.name}" - ) - ]) + self.mock_kube_client.list_namespaced_event.assert_has_calls( + [ + mock.call( + namespace=mock.sentinel.metadata.namespace, + field_selector=f"involvedObject.name={mock.sentinel.metadata.name}", + ), + mock.call( + namespace=mock.sentinel.metadata.namespace, + field_selector=f"involvedObject.name={mock.sentinel.metadata.name}", + ), + ] + ) def test_read_pod_events_retries_fails(self): mock.sentinel.metadata = mock.MagicMock() self.mock_kube_client.list_namespaced_event.side_effect = [ BaseHTTPError('Boom'), BaseHTTPError('Boom'), - BaseHTTPError('Boom') + BaseHTTPError('Boom'), ] - self.assertRaises( - AirflowException, - self.pod_launcher.read_pod_events, - mock.sentinel - ) + self.assertRaises(AirflowException, self.pod_launcher.read_pod_events, mock.sentinel) def test_read_pod_returns_logs(self): mock.sentinel.metadata = mock.MagicMock() @@ -161,31 +156,30 @@ def test_read_pod_retries_successfully(self): mock.sentinel.metadata = mock.MagicMock() self.mock_kube_client.read_namespaced_pod.side_effect = [ BaseHTTPError('Boom'), - mock.sentinel.pod_info + mock.sentinel.pod_info, ] pod_info = self.pod_launcher.read_pod(mock.sentinel) self.assertEqual(mock.sentinel.pod_info, pod_info) - self.mock_kube_client.read_namespaced_pod.assert_has_calls([ - 
mock.call(mock.sentinel.metadata.name, mock.sentinel.metadata.namespace), - mock.call(mock.sentinel.metadata.name, mock.sentinel.metadata.namespace) - ]) + self.mock_kube_client.read_namespaced_pod.assert_has_calls( + [ + mock.call(mock.sentinel.metadata.name, mock.sentinel.metadata.namespace), + mock.call(mock.sentinel.metadata.name, mock.sentinel.metadata.namespace), + ] + ) def test_read_pod_retries_fails(self): mock.sentinel.metadata = mock.MagicMock() self.mock_kube_client.read_namespaced_pod.side_effect = [ BaseHTTPError('Boom'), BaseHTTPError('Boom'), - BaseHTTPError('Boom') + BaseHTTPError('Boom'), ] - self.assertRaises( - AirflowException, - self.pod_launcher.read_pod, - mock.sentinel - ) + self.assertRaises(AirflowException, self.pod_launcher.read_pod, mock.sentinel) def test_parse_log_line(self): - timestamp, message = \ - self.pod_launcher.parse_log_line('2020-10-08T14:16:17.793417674Z Valid message\n') + timestamp, message = self.pod_launcher.parse_log_line( + '2020-10-08T14:16:17.793417674Z Valid message\n' + ) self.assertEqual(timestamp, '2020-10-08T14:16:17.793417674Z') self.assertEqual(message, 'Valid message') diff --git a/tests/kubernetes/test_refresh_config.py b/tests/kubernetes/test_refresh_config.py index 98a671de08bee..ca3740efd2257 100644 --- a/tests/kubernetes/test_refresh_config.py +++ b/tests/kubernetes/test_refresh_config.py @@ -23,7 +23,6 @@ class TestRefreshKubeConfigLoader(TestCase): - def test_parse_timestamp_should_convert_z_timezone_to_unix_timestamp(self): ts = _parse_timestamp("2020-01-13T13:42:20Z") self.assertEqual(1578922940, ts) diff --git a/tests/lineage/test_lineage.py b/tests/lineage/test_lineage.py index dddaaaa70722d..fddf723576cf9 100644 --- a/tests/lineage/test_lineage.py +++ b/tests/lineage/test_lineage.py @@ -27,12 +27,8 @@ class TestLineage(unittest.TestCase): - def test_lineage(self): - dag = DAG( - dag_id='test_prepare_lineage', - start_date=DEFAULT_DATE - ) + dag = DAG(dag_id='test_prepare_lineage', start_date=DEFAULT_DATE) f1s = "/tmp/does_not_exist_1-{}" f2s = "/tmp/does_not_exist_2-{}" @@ -42,16 +38,17 @@ def test_lineage(self): file3 = File(f3s) with dag: - op1 = DummyOperator(task_id='leave1', - inlets=file1, - outlets=[file2, ]) + op1 = DummyOperator( + task_id='leave1', + inlets=file1, + outlets=[ + file2, + ], + ) op2 = DummyOperator(task_id='leave2') - op3 = DummyOperator(task_id='upstream_level_1', - inlets=AUTO, - outlets=file3) + op3 = DummyOperator(task_id='upstream_level_1', inlets=AUTO, outlets=file3) op4 = DummyOperator(task_id='upstream_level_2') - op5 = DummyOperator(task_id='upstream_level_3', - inlets=["leave1", "upstream_level_1"]) + op5 = DummyOperator(task_id='upstream_level_3', inlets=["leave1", "upstream_level_1"]) op1.set_downstream(op3) op2.set_downstream(op3) @@ -61,14 +58,10 @@ def test_lineage(self): dag.clear() # execution_date is set in the context in order to avoid creating task instances - ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE), - "execution_date": DEFAULT_DATE} - ctx2 = {"ti": TI(task=op2, execution_date=DEFAULT_DATE), - "execution_date": DEFAULT_DATE} - ctx3 = {"ti": TI(task=op3, execution_date=DEFAULT_DATE), - "execution_date": DEFAULT_DATE} - ctx5 = {"ti": TI(task=op5, execution_date=DEFAULT_DATE), - "execution_date": DEFAULT_DATE} + ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE} + ctx2 = {"ti": TI(task=op2, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE} + ctx3 = {"ti": TI(task=op3, execution_date=DEFAULT_DATE), 
"execution_date": DEFAULT_DATE} + ctx5 = {"ti": TI(task=op5, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE} # prepare with manual inlets and outlets op1.pre_execute(ctx1) @@ -101,10 +94,7 @@ def test_lineage(self): def test_lineage_render(self): # tests inlets / outlets are rendered if they are added # after initalization - dag = DAG( - dag_id='test_lineage_render', - start_date=DEFAULT_DATE - ) + dag = DAG(dag_id='test_lineage_render', start_date=DEFAULT_DATE) with dag: op1 = DummyOperator(task_id='task1') @@ -116,8 +106,7 @@ def test_lineage_render(self): op1.outlets.append(file1) # execution_date is set in the context in order to avoid creating task instances - ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE), - "execution_date": DEFAULT_DATE} + ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE} op1.pre_execute(ctx1) self.assertEqual(op1.inlets[0].url, f1s.format(DEFAULT_DATE)) diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 8f835afd39e0a..9232d48d4c7a0 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -55,15 +55,8 @@ def __ne__(self, other): # Objects with circular references (for testing purpose) -object1 = ClassWithCustomAttributes( - attr="{{ foo }}_1", - template_fields=["ref"] -) -object2 = ClassWithCustomAttributes( - attr="{{ foo }}_2", - ref=object1, - template_fields=["ref"] -) +object1 = ClassWithCustomAttributes(attr="{{ foo }}_1", template_fields=["ref"]) +object2 = ClassWithCustomAttributes(attr="{{ foo }}_2", ref=object1, template_fields=["ref"]) setattr(object1, 'ref', object2) @@ -98,29 +91,31 @@ class TestBaseOperator(unittest.TestCase): ), ( # check deep nested fields can be templated - ClassWithCustomAttributes(nested1=ClassWithCustomAttributes(att1="{{ foo }}_1", - att2="{{ foo }}_2", - template_fields=["att1"]), - nested2=ClassWithCustomAttributes(att3="{{ foo }}_3", - att4="{{ foo }}_4", - template_fields=["att3"]), - template_fields=["nested1"]), + ClassWithCustomAttributes( + nested1=ClassWithCustomAttributes( + att1="{{ foo }}_1", att2="{{ foo }}_2", template_fields=["att1"] + ), + nested2=ClassWithCustomAttributes( + att3="{{ foo }}_3", att4="{{ foo }}_4", template_fields=["att3"] + ), + template_fields=["nested1"], + ), {"foo": "bar"}, - ClassWithCustomAttributes(nested1=ClassWithCustomAttributes(att1="bar_1", - att2="{{ foo }}_2", - template_fields=["att1"]), - nested2=ClassWithCustomAttributes(att3="{{ foo }}_3", - att4="{{ foo }}_4", - template_fields=["att3"]), - template_fields=["nested1"]), + ClassWithCustomAttributes( + nested1=ClassWithCustomAttributes( + att1="bar_1", att2="{{ foo }}_2", template_fields=["att1"] + ), + nested2=ClassWithCustomAttributes( + att3="{{ foo }}_3", att4="{{ foo }}_4", template_fields=["att3"] + ), + template_fields=["nested1"], + ), ), ( # check null value on nested template field - ClassWithCustomAttributes(att1=None, - template_fields=["att1"]), + ClassWithCustomAttributes(att1=None, template_fields=["att1"]), {}, - ClassWithCustomAttributes(att1=None, - template_fields=["att1"]), + ClassWithCustomAttributes(att1=None, template_fields=["att1"]), ), ( # check there is no RecursionError on circular references @@ -214,8 +209,9 @@ def test_nested_template_fields_declared_must_exist(self): with self.assertRaises(AttributeError) as e: task.render_template(ClassWithCustomAttributes(template_fields=["missing_field"]), {}) - self.assertEqual("'ClassWithCustomAttributes' object has no 
attribute 'missing_field'", - str(e.exception)) + self.assertEqual( + "'ClassWithCustomAttributes' object has no attribute 'missing_field'", str(e.exception) + ) def test_jinja_invalid_expression_is_just_propagated(self): """Test render_template propagates Jinja invalid expression errors.""" @@ -236,9 +232,9 @@ def test_jinja_env_creation(self, mock_jinja_env): def test_set_jinja_env_additional_option(self): """Test render_template given various input types.""" - with DAG("test-dag", - start_date=DEFAULT_DATE, - jinja_environment_kwargs={'keep_trailing_newline': True}): + with DAG( + "test-dag", start_date=DEFAULT_DATE, jinja_environment_kwargs={'keep_trailing_newline': True} + ): task = DummyOperator(task_id="op1") result = task.render_template("{{ foo }}\n\n", {"foo": "bar"}) @@ -246,9 +242,7 @@ def test_set_jinja_env_additional_option(self): def test_override_jinja_env_option(self): """Test render_template given various input types.""" - with DAG("test-dag", - start_date=DEFAULT_DATE, - jinja_environment_kwargs={'cache_size': 50}): + with DAG("test-dag", start_date=DEFAULT_DATE, jinja_environment_kwargs={'cache_size': 50}): task = DummyOperator(task_id="op1") result = task.render_template("{{ foo }}", {"foo": "bar"}) @@ -270,9 +264,7 @@ def test_default_email_on_actions(self): def test_email_on_actions(self): test_task = DummyOperator( - task_id='test_default_email_on_actions', - email_on_retry=False, - email_on_failure=True + task_id='test_default_email_on_actions', email_on_retry=False, email_on_failure=True ) assert test_task.email_on_retry is False assert test_task.email_on_failure is True @@ -291,10 +283,7 @@ def test_cross_downstream(self): def test_chain(self): dag = DAG(dag_id='test_chain', start_date=datetime.now()) - [op1, op2, op3, op4, op5, op6] = [ - DummyOperator(task_id=f't{i}', dag=dag) - for i in range(1, 7) - ] + [op1, op2, op3, op4, op5, op6] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 7)] chain(op1, [op2, op3], [op4, op5], op6) self.assertCountEqual([op2, op3], op1.get_direct_relatives(upstream=False)) @@ -390,9 +379,7 @@ def test_setattr_performs_no_custom_action_at_execute_time(self): op = CustomOp(task_id="test_task") op_copy = op.prepare_for_execution() - with mock.patch( - "airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies" - ) as method_mock: + with mock.patch("airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies") as method_mock: op_copy.execute({}) assert method_mock.call_count == 0 diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py index 9efe51ae150db..1909e6c9fa08b 100644 --- a/tests/models/test_cleartasks.py +++ b/tests/models/test_cleartasks.py @@ -30,7 +30,6 @@ class TestClearTasks(unittest.TestCase): - def setUp(self) -> None: db.clear_db_runs() @@ -38,8 +37,11 @@ def tearDown(self): db.clear_db_runs() def test_clear_task_instances(self): - dag = DAG('test_clear_task_instances', start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=10)) + dag = DAG( + 'test_clear_task_instances', + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + ) task0 = DummyOperator(task_id='0', owner='test', dag=dag) task1 = DummyOperator(task_id='1', owner='test', dag=dag, retries=2) ti0 = TI(task=task0, execution_date=DEFAULT_DATE) @@ -54,8 +56,7 @@ def test_clear_task_instances(self): ti0.run() ti1.run() with create_session() as session: - qry = session.query(TI).filter( - TI.dag_id == dag.dag_id).all() + qry = session.query(TI).filter(TI.dag_id == 
dag.dag_id).all() clear_task_instances(qry, session, dag=dag) ti0.refresh_from_db() @@ -67,8 +68,11 @@ def test_clear_task_instances(self): self.assertEqual(ti1.max_tries, 3) def test_clear_task_instances_without_task(self): - dag = DAG('test_clear_task_instances_without_task', start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=10)) + dag = DAG( + 'test_clear_task_instances_without_task', + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + ) task0 = DummyOperator(task_id='task0', owner='test', dag=dag) task1 = DummyOperator(task_id='task1', owner='test', dag=dag, retries=2) ti0 = TI(task=task0, execution_date=DEFAULT_DATE) @@ -89,8 +93,7 @@ def test_clear_task_instances_without_task(self): self.assertFalse(dag.has_task(task1.task_id)) with create_session() as session: - qry = session.query(TI).filter( - TI.dag_id == dag.dag_id).all() + qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all() clear_task_instances(qry, session) # When dag is None, max_tries will be maximum of original max_tries or try_number. @@ -103,8 +106,11 @@ def test_clear_task_instances_without_task(self): self.assertEqual(ti1.max_tries, 2) def test_clear_task_instances_without_dag(self): - dag = DAG('test_clear_task_instances_without_dag', start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=10)) + dag = DAG( + 'test_clear_task_instances_without_dag', + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + ) task0 = DummyOperator(task_id='task_0', owner='test', dag=dag) task1 = DummyOperator(task_id='task_1', owner='test', dag=dag, retries=2) ti0 = TI(task=task0, execution_date=DEFAULT_DATE) @@ -120,8 +126,7 @@ def test_clear_task_instances_without_dag(self): ti1.run() with create_session() as session: - qry = session.query(TI).filter( - TI.dag_id == dag.dag_id).all() + qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all() clear_task_instances(qry, session) # When dag is None, max_tries will be maximum of original max_tries or try_number. 
@@ -134,8 +139,9 @@ def test_clear_task_instances_without_dag(self): self.assertEqual(ti1.max_tries, 2) def test_dag_clear(self): - dag = DAG('test_dag_clear', start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=10)) + dag = DAG( + 'test_dag_clear', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10) + ) task0 = DummyOperator(task_id='test_dag_clear_task_0', owner='test', dag=dag) ti0 = TI(task=task0, execution_date=DEFAULT_DATE) @@ -155,8 +161,7 @@ def test_dag_clear(self): self.assertEqual(ti0.state, State.NONE) self.assertEqual(ti0.max_tries, 1) - task1 = DummyOperator(task_id='test_dag_clear_task_1', owner='test', - dag=dag, retries=2) + task1 = DummyOperator(task_id='test_dag_clear_task_1', owner='test', dag=dag, retries=2) ti1 = TI(task=task1, execution_date=DEFAULT_DATE) self.assertEqual(ti1.max_tries, 2) ti1.try_number = 1 @@ -181,11 +186,15 @@ def test_dags_clear(self): dags, tis = [], [] num_of_dags = 5 for i in range(num_of_dags): - dag = DAG('test_dag_clear_' + str(i), start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=10)) - ti = TI(task=DummyOperator(task_id='test_task_clear_' + str(i), owner='test', - dag=dag), - execution_date=DEFAULT_DATE) + dag = DAG( + 'test_dag_clear_' + str(i), + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + ) + ti = TI( + task=DummyOperator(task_id='test_task_clear_' + str(i), owner='test', dag=dag), + execution_date=DEFAULT_DATE, + ) dag.create_dagrun( execution_date=ti.execution_date, @@ -227,6 +236,7 @@ def test_dags_clear(self): # test only_failed from random import randint + failed_dag_idx = randint(0, len(tis) - 1) tis[failed_dag_idx].state = State.FAILED session.merge(tis[failed_dag_idx]) @@ -246,8 +256,11 @@ def test_dags_clear(self): self.assertEqual(tis[i].max_tries, 2) def test_operator_clear(self): - dag = DAG('test_operator_clear', start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=10)) + dag = DAG( + 'test_operator_clear', + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + ) op1 = DummyOperator(task_id='bash_op', owner='test', dag=dag) op2 = DummyOperator(task_id='dummy_op', owner='test', dag=dag, retries=1) diff --git a/tests/models/test_connection.py b/tests/models/test_connection.py index df3ace8fdb80b..fbd275d01164c 100644 --- a/tests/models/test_connection.py +++ b/tests/models/test_connection.py @@ -126,7 +126,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self): ), UriTestCaseConfig( test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?' 
- 'extra1=a%20value&extra2=%2Fpath%2F', + 'extra1=a%20value&extra2=%2Fpath%2F', test_conn_attributes=dict( conn_type='scheme', host='host/location', @@ -134,9 +134,9 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self): login='user', password='password', port=1234, - extra_dejson={'extra1': 'a value', 'extra2': '/path/'} + extra_dejson={'extra1': 'a value', 'extra2': '/path/'}, ), - description='with extras' + description='with extras', ), UriTestCaseConfig( test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?extra1=a%20value&extra2=', @@ -147,13 +147,13 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self): login='user', password='password', port=1234, - extra_dejson={'extra1': 'a value', 'extra2': ''} + extra_dejson={'extra1': 'a value', 'extra2': ''}, ), - description='with empty extras' + description='with empty extras', ), UriTestCaseConfig( test_conn_uri='scheme://user:password@host%2Flocation%3Ax%3Ay:1234/schema?' - 'extra1=a%20value&extra2=%2Fpath%2F', + 'extra1=a%20value&extra2=%2Fpath%2F', test_conn_attributes=dict( conn_type='scheme', host='host/location:x:y', @@ -163,7 +163,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self): port=1234, extra_dejson={'extra1': 'a value', 'extra2': '/path/'}, ), - description='with colon in hostname' + description='with colon in hostname', ), UriTestCaseConfig( test_conn_uri='scheme://user:password%20with%20space@host%2Flocation%3Ax%3Ay:1234/schema', @@ -175,7 +175,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self): password='password with space', port=1234, ), - description='with encoded password' + description='with encoded password', ), UriTestCaseConfig( test_conn_uri='scheme://domain%2Fuser:password@host%2Flocation%3Ax%3Ay:1234/schema', @@ -199,7 +199,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self): password='password with space', port=1234, ), - description='with encoded schema' + description='with encoded schema', ), UriTestCaseConfig( test_conn_uri='scheme://user:password%20with%20space@host:1234', @@ -211,7 +211,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self): password='password with space', port=1234, ), - description='no schema' + description='no schema', ), UriTestCaseConfig( test_conn_uri='google-cloud-platform://?extra__google_cloud_platform__key_' @@ -229,7 +229,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self): extra__google_cloud_platform__key_path='/keys/key.json', extra__google_cloud_platform__scope='https://www.googleapis.com/auth/cloud-platform', extra__google_cloud_platform__project='airflow', - ) + ), ), description='with underscore', ), @@ -243,7 +243,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self): password=None, port=1234, ), - description='without auth info' + description='without auth info', ), UriTestCaseConfig( test_conn_uri='scheme://%2FTmP%2F:1234', @@ -435,9 +435,12 @@ def test_connection_from_with_auth_info(self, uri, uri_parts): self.assertEqual(connection.port, uri_parts.port) self.assertEqual(connection.schema, uri_parts.schema) - @mock.patch.dict('os.environ', { - 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database', - }) + @mock.patch.dict( + 'os.environ', + { + 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database', + }, + ) def test_using_env_var(self): conn = SqliteHook.get_connection(conn_id='test_uri') self.assertEqual('ec2.compute.com', conn.host) @@ -446,9 
+449,12 @@ def test_using_env_var(self): self.assertEqual('password', conn.password) self.assertEqual(5432, conn.port) - @mock.patch.dict('os.environ', { - 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', - }) + @mock.patch.dict( + 'os.environ', + { + 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', + }, + ) def test_using_unix_socket_env_var(self): conn = SqliteHook.get_connection(conn_id='test_uri_no_creds') self.assertEqual('ec2.compute.com', conn.host) @@ -458,9 +464,14 @@ def test_using_unix_socket_env_var(self): self.assertIsNone(conn.port) def test_param_setup(self): - conn = Connection(conn_id='local_mysql', conn_type='mysql', - host='localhost', login='airflow', - password='airflow', schema='airflow') + conn = Connection( + conn_id='local_mysql', + conn_type='mysql', + host='localhost', + login='airflow', + password='airflow', + schema='airflow', + ) self.assertEqual('localhost', conn.host) self.assertEqual('airflow', conn.schema) self.assertEqual('airflow', conn.login) @@ -471,9 +482,12 @@ def test_env_var_priority(self): conn = SqliteHook.get_connection(conn_id='airflow_db') self.assertNotEqual('ec2.compute.com', conn.host) - with mock.patch.dict('os.environ', { - 'AIRFLOW_CONN_AIRFLOW_DB': 'postgres://username:password@ec2.compute.com:5432/the_database', - }): + with mock.patch.dict( + 'os.environ', + { + 'AIRFLOW_CONN_AIRFLOW_DB': 'postgres://username:password@ec2.compute.com:5432/the_database', + }, + ): conn = SqliteHook.get_connection(conn_id='airflow_db') self.assertEqual('ec2.compute.com', conn.host) self.assertEqual('the_database', conn.schema) @@ -481,10 +495,13 @@ def test_env_var_priority(self): self.assertEqual('password', conn.password) self.assertEqual(5432, conn.port) - @mock.patch.dict('os.environ', { - 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database', - 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', - }) + @mock.patch.dict( + 'os.environ', + { + 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database', + 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', + }, + ) def test_dbapi_get_uri(self): conn = BaseHook.get_connection(conn_id='test_uri') hook = conn.get_hook() @@ -493,10 +510,13 @@ def test_dbapi_get_uri(self): hook2 = conn2.get_hook() self.assertEqual('postgres://ec2.compute.com/the_database', hook2.get_uri()) - @mock.patch.dict('os.environ', { - 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database', - 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', - }) + @mock.patch.dict( + 'os.environ', + { + 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database', + 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', + }, + ) def test_dbapi_get_sqlalchemy_engine(self): conn = BaseHook.get_connection(conn_id='test_uri') hook = conn.get_hook() @@ -504,10 +524,13 @@ def test_dbapi_get_sqlalchemy_engine(self): self.assertIsInstance(engine, sqlalchemy.engine.Engine) self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', str(engine.url)) - @mock.patch.dict('os.environ', { - 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database', - 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', - }) + @mock.patch.dict( + 'os.environ', + { + 'AIRFLOW_CONN_TEST_URI': 
'postgres://username:password@ec2.compute.com:5432/the_database', + 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', + }, + ) def test_get_connections_env_var(self): conns = SqliteHook.get_connections(conn_id='test_uri') assert len(conns) == 1 @@ -523,7 +546,7 @@ def test_connection_mixed(self): re.escape( "You must create an object using the URI or individual values (conn_type, host, login, " "password, schema, port or extra).You can't mix these two ways to create this object." - ) + ), ): Connection(conn_id="TEST_ID", uri="mysql://", schema="AAA") diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 7fd0f2dbab695..7b09ba89ddba4 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -63,7 +63,6 @@ class TestDag(unittest.TestCase): - def setUp(self) -> None: clear_db_runs() clear_db_dags() @@ -78,15 +77,9 @@ def tearDown(self) -> None: @staticmethod def _clean_up(dag_id: str): with create_session() as session: - session.query(DagRun).filter( - DagRun.dag_id == dag_id).delete( - synchronize_session=False) - session.query(TI).filter( - TI.dag_id == dag_id).delete( - synchronize_session=False) - session.query(TaskFail).filter( - TaskFail.dag_id == dag_id).delete( - synchronize_session=False) + session.query(DagRun).filter(DagRun.dag_id == dag_id).delete(synchronize_session=False) + session.query(TI).filter(TI.dag_id == dag_id).delete(synchronize_session=False) + session.query(TaskFail).filter(TaskFail.dag_id == dag_id).delete(synchronize_session=False) @staticmethod def _occur_before(a, b, list_): @@ -122,9 +115,7 @@ def test_params_passed_and_params_in_default_args_no_override(self): params1 = {'parameter1': 1} params2 = {'parameter2': 2} - dag = models.DAG('test-dag', - default_args={'params': params1}, - params=params2) + dag = models.DAG('test-dag', default_args={'params': params1}, params=params2) params_combined = params1.copy() params_combined.update(params2) @@ -134,43 +125,29 @@ def test_dag_invalid_default_view(self): """ Test invalid `default_view` of DAG initialization """ - with self.assertRaisesRegex(AirflowException, - 'Invalid values of dag.default_view: only support'): - models.DAG( - dag_id='test-invalid-default_view', - default_view='airflow' - ) + with self.assertRaisesRegex(AirflowException, 'Invalid values of dag.default_view: only support'): + models.DAG(dag_id='test-invalid-default_view', default_view='airflow') def test_dag_default_view_default_value(self): """ Test `default_view` default value of DAG initialization """ - dag = models.DAG( - dag_id='test-default_default_view' - ) - self.assertEqual(conf.get('webserver', 'dag_default_view').lower(), - dag.default_view) + dag = models.DAG(dag_id='test-default_default_view') + self.assertEqual(conf.get('webserver', 'dag_default_view').lower(), dag.default_view) def test_dag_invalid_orientation(self): """ Test invalid `orientation` of DAG initialization """ - with self.assertRaisesRegex(AirflowException, - 'Invalid values of dag.orientation: only support'): - models.DAG( - dag_id='test-invalid-orientation', - orientation='airflow' - ) + with self.assertRaisesRegex(AirflowException, 'Invalid values of dag.orientation: only support'): + models.DAG(dag_id='test-invalid-orientation', orientation='airflow') def test_dag_orientation_default_value(self): """ Test `orientation` default value of DAG initialization """ - dag = models.DAG( - dag_id='test-default_orientation' - ) - self.assertEqual(conf.get('webserver', 'dag_orientation'), - dag.orientation) + dag = 
models.DAG(dag_id='test-default_orientation') + self.assertEqual(conf.get('webserver', 'dag_orientation'), dag.orientation) def test_dag_as_context_manager(self): """ @@ -178,14 +155,8 @@ def test_dag_as_context_manager(self): When used as a context manager, Operators are automatically added to the DAG (unless they specify a different DAG) """ - dag = DAG( - 'dag', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) - dag2 = DAG( - 'dag2', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner2'}) + dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) + dag2 = DAG('dag2', start_date=DEFAULT_DATE, default_args={'owner': 'owner2'}) with dag: op1 = DummyOperator(task_id='op1') @@ -263,10 +234,7 @@ def test_dag_topological_sort_include_subdag_tasks(self): self.assertTrue(self._occur_before('b_child', 'b_parent', topological_list)) def test_dag_topological_sort1(self): - dag = DAG( - 'dag', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) # A -> B # A -> C -> D @@ -292,10 +260,7 @@ def test_dag_topological_sort1(self): self.assertTrue(topological_list[3] == op1) def test_dag_topological_sort2(self): - dag = DAG( - 'dag', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) # C -> (A u B) -> D # C -> E @@ -332,10 +297,7 @@ def test_dag_topological_sort2(self): self.assertTrue(topological_list[4] == op3) def test_dag_topological_sort_dag_without_tasks(self): - dag = DAG( - 'dag', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) self.assertEqual((), dag.topological_sort()) @@ -356,8 +318,9 @@ def test_dag_start_date_propagates_to_end_date(self): An explicit check the `tzinfo` attributes for both are the same is an extra check. """ - dag = DAG('DAG', default_args={'start_date': '2019-06-05T00:00:00+05:00', - 'end_date': '2019-06-05T00:00:00'}) + dag = DAG( + 'DAG', default_args={'start_date': '2019-06-05T00:00:00+05:00', 'end_date': '2019-06-05T00:00:00'} + ) self.assertEqual(dag.default_args['start_date'], dag.default_args['end_date']) self.assertEqual(dag.default_args['start_date'].tzinfo, dag.default_args['end_date'].tzinfo) @@ -383,12 +346,10 @@ def test_dag_task_priority_weight_total(self): # Fully connected parallel tasks. i.e. every task at each parallel # stage is dependent on every task in the previous stage. 
# Default weight should be calculated using downstream descendants - with DAG('dag', start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) as dag: + with DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) as dag: pipeline = [ - [DummyOperator( - task_id=f'stage{i}.{j}', priority_weight=weight) - for j in range(0, width)] for i in range(0, depth) + [DummyOperator(task_id=f'stage{i}.{j}', priority_weight=weight) for j in range(0, width)] + for i in range(0, depth) ] for i, stage in enumerate(pipeline): if i == 0: @@ -412,13 +373,17 @@ def test_dag_task_priority_weight_total_using_upstream(self): width = 5 depth = 5 pattern = re.compile('stage(\\d*).(\\d*)') - with DAG('dag', start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) as dag: + with DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) as dag: pipeline = [ - [DummyOperator( - task_id=f'stage{i}.{j}', priority_weight=weight, - weight_rule=WeightRule.UPSTREAM) - for j in range(0, width)] for i in range(0, depth) + [ + DummyOperator( + task_id=f'stage{i}.{j}', + priority_weight=weight, + weight_rule=WeightRule.UPSTREAM, + ) + for j in range(0, width) + ] + for i in range(0, depth) ] for i, stage in enumerate(pipeline): if i == 0: @@ -441,13 +406,17 @@ def test_dag_task_priority_weight_total_using_absolute(self): weight = 10 width = 5 depth = 5 - with DAG('dag', start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) as dag: + with DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) as dag: pipeline = [ - [DummyOperator( - task_id=f'stage{i}.{j}', priority_weight=weight, - weight_rule=WeightRule.ABSOLUTE) - for j in range(0, width)] for i in range(0, depth) + [ + DummyOperator( + task_id=f'stage{i}.{j}', + priority_weight=weight, + weight_rule=WeightRule.ABSOLUTE, + ) + for j in range(0, width) + ] + for i in range(0, depth) ] for i, stage in enumerate(pipeline): if i == 0: @@ -490,40 +459,29 @@ def test_get_num_task_instances(self): session.merge(ti4) session.commit() + self.assertEqual(0, DAG.get_num_task_instances(test_dag_id, ['fakename'], session=session)) + self.assertEqual(4, DAG.get_num_task_instances(test_dag_id, [test_task_id], session=session)) self.assertEqual( - 0, - DAG.get_num_task_instances(test_dag_id, ['fakename'], session=session) - ) - self.assertEqual( - 4, - DAG.get_num_task_instances(test_dag_id, [test_task_id], session=session) - ) - self.assertEqual( - 4, - DAG.get_num_task_instances( - test_dag_id, ['fakename', test_task_id], session=session) + 4, DAG.get_num_task_instances(test_dag_id, ['fakename', test_task_id], session=session) ) self.assertEqual( - 1, - DAG.get_num_task_instances( - test_dag_id, [test_task_id], states=[None], session=session) + 1, DAG.get_num_task_instances(test_dag_id, [test_task_id], states=[None], session=session) ) self.assertEqual( 2, - DAG.get_num_task_instances( - test_dag_id, [test_task_id], states=[State.RUNNING], session=session) + DAG.get_num_task_instances(test_dag_id, [test_task_id], states=[State.RUNNING], session=session), ) self.assertEqual( 3, DAG.get_num_task_instances( - test_dag_id, [test_task_id], - states=[None, State.RUNNING], session=session) + test_dag_id, [test_task_id], states=[None, State.RUNNING], session=session + ), ) self.assertEqual( 4, DAG.get_num_task_instances( - test_dag_id, [test_task_id], - states=[None, State.QUEUED, State.RUNNING], session=session) + test_dag_id, [test_task_id], states=[None, State.QUEUED, State.RUNNING], session=session + ), ) session.close() @@ -578,10 +536,10 
@@ def test_following_previous_schedule(self): Make sure DST transitions are properly observed """ local_tz = pendulum.timezone('Europe/Zurich') - start = local_tz.convert(datetime.datetime(2018, 10, 28, 2, 55), - dst_rule=pendulum.PRE_TRANSITION) - self.assertEqual(start.isoformat(), "2018-10-28T02:55:00+02:00", - "Pre-condition: start date is in DST") + start = local_tz.convert(datetime.datetime(2018, 10, 28, 2, 55), dst_rule=pendulum.PRE_TRANSITION) + self.assertEqual( + start.isoformat(), "2018-10-28T02:55:00+02:00", "Pre-condition: start date is in DST" + ) utc = timezone.convert_to_utc(start) @@ -608,8 +566,7 @@ def test_following_previous_schedule_daily_dag_cest_to_cet(self): Make sure DST transitions are properly observed """ local_tz = pendulum.timezone('Europe/Zurich') - start = local_tz.convert(datetime.datetime(2018, 10, 27, 3), - dst_rule=pendulum.PRE_TRANSITION) + start = local_tz.convert(datetime.datetime(2018, 10, 27, 3), dst_rule=pendulum.PRE_TRANSITION) utc = timezone.convert_to_utc(start) @@ -638,8 +595,7 @@ def test_following_previous_schedule_daily_dag_cet_to_cest(self): Make sure DST transitions are properly observed """ local_tz = pendulum.timezone('Europe/Zurich') - start = local_tz.convert(datetime.datetime(2018, 3, 25, 2), - dst_rule=pendulum.PRE_TRANSITION) + start = local_tz.convert(datetime.datetime(2018, 3, 25, 2), dst_rule=pendulum.PRE_TRANSITION) utc = timezone.convert_to_utc(start) @@ -669,12 +625,8 @@ def test_following_schedule_relativedelta(self): """ dag_id = "test_schedule_dag_relativedelta" delta = relativedelta(hours=+1) - dag = DAG(dag_id=dag_id, - schedule_interval=delta) - dag.add_task(BaseOperator( - task_id="faketastic", - owner='Also fake', - start_date=TEST_DATE)) + dag = DAG(dag_id=dag_id, schedule_interval=delta) + dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE)) _next = dag.following_schedule(TEST_DATE) self.assertEqual(_next.isoformat(), "2015-01-02T01:00:00+00:00") @@ -687,22 +639,21 @@ def test_dagtag_repr(self): dag = DAG('dag-test-dagtag', start_date=DEFAULT_DATE, tags=['tag-1', 'tag-2']) dag.sync_to_db() with create_session() as session: - self.assertEqual({'tag-1', 'tag-2'}, - {repr(t) for t in session.query(DagTag).filter( - DagTag.dag_id == 'dag-test-dagtag').all()}) + self.assertEqual( + {'tag-1', 'tag-2'}, + {repr(t) for t in session.query(DagTag).filter(DagTag.dag_id == 'dag-test-dagtag').all()}, + ) def test_bulk_write_to_db(self): clear_db_dags() - dags = [ - DAG(f'dag-bulk-sync-{i}', start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(0, 4) - ] + dags = [DAG(f'dag-bulk-sync-{i}', start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(0, 4)] with assert_queries_count(5): DAG.bulk_write_to_db(dags) with create_session() as session: self.assertEqual( {'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'}, - {row[0] for row in session.query(DagModel.dag_id).all()} + {row[0] for row in session.query(DagModel.dag_id).all()}, ) self.assertEqual( { @@ -711,7 +662,7 @@ def test_bulk_write_to_db(self): ('dag-bulk-sync-2', 'test-dag'), ('dag-bulk-sync-3', 'test-dag'), }, - set(session.query(DagTag.dag_id, DagTag.name).all()) + set(session.query(DagTag.dag_id, DagTag.name).all()), ) # Re-sync should do fewer queries with assert_queries_count(3): @@ -726,7 +677,7 @@ def test_bulk_write_to_db(self): with create_session() as session: self.assertEqual( {'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'}, - {row[0] for row in 
session.query(DagModel.dag_id).all()} + {row[0] for row in session.query(DagModel.dag_id).all()}, ) self.assertEqual( { @@ -739,7 +690,7 @@ def test_bulk_write_to_db(self): ('dag-bulk-sync-3', 'test-dag'), ('dag-bulk-sync-3', 'test-dag2'), }, - set(session.query(DagTag.dag_id, DagTag.name).all()) + set(session.query(DagTag.dag_id, DagTag.name).all()), ) # Removing tags for dag in dags: @@ -749,7 +700,7 @@ def test_bulk_write_to_db(self): with create_session() as session: self.assertEqual( {'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'}, - {row[0] for row in session.query(DagModel.dag_id).all()} + {row[0] for row in session.query(DagModel.dag_id).all()}, ) self.assertEqual( { @@ -758,7 +709,7 @@ def test_bulk_write_to_db(self): ('dag-bulk-sync-2', 'test-dag2'), ('dag-bulk-sync-3', 'test-dag2'), }, - set(session.query(DagTag.dag_id, DagTag.name).all()) + set(session.query(DagTag.dag_id, DagTag.name).all()), ) def test_bulk_write_to_db_max_active_runs(self): @@ -766,15 +717,10 @@ def test_bulk_write_to_db_max_active_runs(self): Test that DagModel.next_dagrun_create_after is set to NULL when the dag cannot be created due to max active runs being hit. """ - dag = DAG( - dag_id='test_scheduler_verify_max_active_runs', - start_date=DEFAULT_DATE) + dag = DAG(dag_id='test_scheduler_verify_max_active_runs', start_date=DEFAULT_DATE) dag.max_active_runs = 1 - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') + DummyOperator(task_id='dummy', dag=dag, owner='airflow') session = settings.Session() dag.clear() @@ -807,15 +753,14 @@ def test_sync_to_db(self): ) with dag: DummyOperator(task_id='task', owner='owner1') - subdag = DAG('dag.subtask', start_date=DEFAULT_DATE, ) + subdag = DAG( + 'dag.subtask', + start_date=DEFAULT_DATE, + ) # parent_dag and is_subdag was set by DagBag. We don't use DagBag, so this value is not set. subdag.parent_dag = dag subdag.is_subdag = True - SubDagOperator( - task_id='subtask', - owner='owner2', - subdag=subdag - ) + SubDagOperator(task_id='subtask', owner='owner2', subdag=subdag) session = settings.Session() dag.sync_to_db(session=session) @@ -823,8 +768,7 @@ def test_sync_to_db(self): self.assertEqual(set(orm_dag.owners.split(', ')), {'owner1', 'owner2'}) self.assertTrue(orm_dag.is_active) self.assertIsNotNone(orm_dag.default_view) - self.assertEqual(orm_dag.default_view, - conf.get('webserver', 'dag_default_view').lower()) + self.assertEqual(orm_dag.default_view, conf.get('webserver', 'dag_default_view').lower()) self.assertEqual(orm_dag.safe_dag_id, 'dag') orm_subdag = session.query(DagModel).filter(DagModel.dag_id == 'dag.subtask').one() @@ -848,7 +792,7 @@ def test_sync_to_db_default_view(self): subdag=DAG( 'dag.subtask', start_date=DEFAULT_DATE, - ) + ), ) session = settings.Session() dag.sync_to_db(session=session) @@ -877,10 +821,7 @@ def test_is_paused_subdag(self, session): ) with dag: - SubDagOperator( - task_id='subdag', - subdag=subdag - ) + SubDagOperator(task_id='subdag', subdag=subdag) # parent_dag and is_subdag was set by DagBag. We don't use DagBag, so this value is not set. 
subdag.parent_dag = dag @@ -892,63 +833,70 @@ def test_is_paused_subdag(self, session): dag.sync_to_db(session=session) - unpaused_dags = session.query( - DagModel.dag_id, DagModel.is_paused - ).filter( - DagModel.dag_id.in_([subdag_id, dag_id]), - ).all() + unpaused_dags = ( + session.query(DagModel.dag_id, DagModel.is_paused) + .filter( + DagModel.dag_id.in_([subdag_id, dag_id]), + ) + .all() + ) - self.assertEqual({ - (dag_id, False), - (subdag_id, False), - }, set(unpaused_dags)) + self.assertEqual( + { + (dag_id, False), + (subdag_id, False), + }, + set(unpaused_dags), + ) DagModel.get_dagmodel(dag.dag_id).set_is_paused(is_paused=True, including_subdags=False) - paused_dags = session.query( - DagModel.dag_id, DagModel.is_paused - ).filter( - DagModel.dag_id.in_([subdag_id, dag_id]), - ).all() + paused_dags = ( + session.query(DagModel.dag_id, DagModel.is_paused) + .filter( + DagModel.dag_id.in_([subdag_id, dag_id]), + ) + .all() + ) - self.assertEqual({ - (dag_id, True), - (subdag_id, False), - }, set(paused_dags)) + self.assertEqual( + { + (dag_id, True), + (subdag_id, False), + }, + set(paused_dags), + ) DagModel.get_dagmodel(dag.dag_id).set_is_paused(is_paused=True) - paused_dags = session.query( - DagModel.dag_id, DagModel.is_paused - ).filter( - DagModel.dag_id.in_([subdag_id, dag_id]), - ).all() + paused_dags = ( + session.query(DagModel.dag_id, DagModel.is_paused) + .filter( + DagModel.dag_id.in_([subdag_id, dag_id]), + ) + .all() + ) - self.assertEqual({ - (dag_id, True), - (subdag_id, True), - }, set(paused_dags)) + self.assertEqual( + { + (dag_id, True), + (subdag_id, True), + }, + set(paused_dags), + ) def test_existing_dag_is_paused_upon_creation(self): - dag = DAG( - 'dag_paused' - ) + dag = DAG('dag_paused') dag.sync_to_db() self.assertFalse(dag.get_is_paused()) - dag = DAG( - 'dag_paused', - is_paused_upon_creation=True - ) + dag = DAG('dag_paused', is_paused_upon_creation=True) dag.sync_to_db() # Since the dag existed before, it should not follow the pause flag upon creation self.assertFalse(dag.get_is_paused()) def test_new_dag_is_paused_upon_creation(self): - dag = DAG( - 'new_nonexisting_dag', - is_paused_upon_creation=True - ) + dag = DAG('new_nonexisting_dag', is_paused_upon_creation=True) session = settings.Session() dag.sync_to_db(session=session) @@ -1045,9 +993,7 @@ def test_tree_view(self): def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self): """Verify tasks with Duplicate task_id raises error""" - with self.assertRaisesRegex( - DuplicateTaskIdFound, "Task id 't1' has already been added to the DAG" - ): + with self.assertRaisesRegex(DuplicateTaskIdFound, "Task id 't1' has already been added to the DAG"): with DAG("test_dag", start_date=DEFAULT_DATE) as dag: op1 = DummyOperator(task_id="t1") op2 = BashOperator(task_id="t1", bash_command="sleep 1") @@ -1057,9 +1003,7 @@ def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self): def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self): """Verify tasks with Duplicate task_id raises error""" - with self.assertRaisesRegex( - DuplicateTaskIdFound, "Task id 't1' has already been added to the DAG" - ): + with self.assertRaisesRegex(DuplicateTaskIdFound, "Task id 't1' has already been added to the DAG"): dag = DAG("test_dag", start_date=DEFAULT_DATE) op1 = DummyOperator(task_id="t1", dag=dag) op2 = DummyOperator(task_id="t1", dag=dag) @@ -1096,10 +1040,7 @@ def test_schedule_dag_no_previous_runs(self): """ dag_id = "test_schedule_dag_no_previous_runs" dag = 
DAG(dag_id=dag_id) - dag.add_task(BaseOperator( - task_id="faketastic", - owner='Also fake', - start_date=TEST_DATE)) + dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE)) dag_run = dag.create_dagrun( run_type=DagRunType.SCHEDULED, @@ -1113,8 +1054,7 @@ def test_schedule_dag_no_previous_runs(self): self.assertEqual( TEST_DATE, dag_run.execution_date, - msg='dag_run.execution_date did not match expectation: {}' - .format(dag_run.execution_date) + msg=f'dag_run.execution_date did not match expectation: {dag_run.execution_date}', ) self.assertEqual(State.RUNNING, dag_run.state) self.assertFalse(dag_run.external_trigger) @@ -1133,12 +1073,10 @@ def test_dag_handle_callback_crash(self, mock_stats): dag_id=dag_id, # callback with invalid signature should not cause crashes on_success_callback=lambda: 1, - on_failure_callback=mock_callback_with_exception) + on_failure_callback=mock_callback_with_exception, + ) when = TEST_DATE - dag.add_task(BaseOperator( - task_id="faketastic", - owner='Also fake', - start_date=when)) + dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=when)) dag_run = dag.create_dagrun(State.RUNNING, when, run_type=DagRunType.MANUAL) # should not rause any exception @@ -1157,18 +1095,15 @@ def test_next_dagrun_after_fake_scheduled_previous(self): """ delta = datetime.timedelta(hours=1) dag_id = "test_schedule_dag_fake_scheduled_previous" - dag = DAG(dag_id=dag_id, - schedule_interval=delta, - start_date=DEFAULT_DATE) - dag.add_task(BaseOperator( - task_id="faketastic", - owner='Also fake', - start_date=DEFAULT_DATE)) - - dag.create_dagrun(run_type=DagRunType.SCHEDULED, - execution_date=DEFAULT_DATE, - state=State.SUCCESS, - external_trigger=True) + dag = DAG(dag_id=dag_id, schedule_interval=delta, start_date=DEFAULT_DATE) + dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=DEFAULT_DATE)) + + dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.SUCCESS, + external_trigger=True, + ) dag.sync_to_db() with create_session() as session: model = session.query(DagModel).get((dag.dag_id,)) @@ -1189,17 +1124,12 @@ def test_schedule_dag_once(self): dag = DAG(dag_id=dag_id) dag.schedule_interval = '@once' self.assertEqual(dag.normalized_schedule_interval, None) - dag.add_task(BaseOperator( - task_id="faketastic", - owner='Also fake', - start_date=TEST_DATE)) + dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE)) # Sync once to create the DagModel dag.sync_to_db() - dag.create_dagrun(run_type=DagRunType.SCHEDULED, - execution_date=TEST_DATE, - state=State.SUCCESS) + dag.create_dagrun(run_type=DagRunType.SCHEDULED, execution_date=TEST_DATE, state=State.SUCCESS) # Then sync again after creating the dag run -- this should update next_dagrun dag.sync_to_db() @@ -1217,10 +1147,7 @@ def test_fractional_seconds(self): dag_id = "test_fractional_seconds" dag = DAG(dag_id=dag_id) dag.schedule_interval = '@once' - dag.add_task(BaseOperator( - task_id="faketastic", - owner='Also fake', - start_date=TEST_DATE)) + dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE)) start_date = timezone.utcnow() @@ -1229,15 +1156,13 @@ def test_fractional_seconds(self): execution_date=start_date, start_date=start_date, state=State.RUNNING, - external_trigger=False + external_trigger=False, ) run.refresh_from_db() - self.assertEqual(start_date, run.execution_date, - "dag run execution_date loses precision") - 
self.assertEqual(start_date, run.start_date, - "dag run start_date loses precision ") + self.assertEqual(start_date, run.execution_date, "dag run execution_date loses precision") + self.assertEqual(start_date, run.start_date, "dag run start_date loses precision ") self._clean_up(dag_id) def test_pickling(self): @@ -1262,8 +1187,7 @@ class DAGsubclass(DAG): dag_diff_name = DAG(test_dag_id + '_neq', default_args=args) dag_subclass = DAGsubclass(test_dag_id, default_args=args) - dag_subclass_diff_name = DAGsubclass( - test_dag_id + '2', default_args=args) + dag_subclass_diff_name = DAGsubclass(test_dag_id + '2', default_args=args) for dag_ in [dag_eq, dag_diff_name, dag_subclass, dag_subclass_diff_name]: dag_.last_loaded = dag.last_loaded @@ -1307,23 +1231,21 @@ def test_get_paused_dag_ids(self): self.assertEqual(paused_dag_ids, {dag_id}) with create_session() as session: - session.query(DagModel).filter( - DagModel.dag_id == dag_id).delete( - synchronize_session=False) - - @parameterized.expand([ - (None, None), - ("@daily", "0 0 * * *"), - ("@weekly", "0 0 * * 0"), - ("@monthly", "0 0 1 * *"), - ("@quarterly", "0 0 1 */3 *"), - ("@yearly", "0 0 1 1 *"), - ("@once", None), - (datetime.timedelta(days=1), datetime.timedelta(days=1)), - ]) - def test_normalized_schedule_interval( - self, schedule_interval, expected_n_schedule_interval - ): + session.query(DagModel).filter(DagModel.dag_id == dag_id).delete(synchronize_session=False) + + @parameterized.expand( + [ + (None, None), + ("@daily", "0 0 * * *"), + ("@weekly", "0 0 * * 0"), + ("@monthly", "0 0 1 * *"), + ("@quarterly", "0 0 1 */3 *"), + ("@yearly", "0 0 1 1 *"), + ("@once", None), + (datetime.timedelta(days=1), datetime.timedelta(days=1)), + ] + ) + def test_normalized_schedule_interval(self, schedule_interval, expected_n_schedule_interval): dag = DAG("test_schedule_interval", schedule_interval=schedule_interval) self.assertEqual(dag.normalized_schedule_interval, expected_n_schedule_interval) @@ -1399,20 +1321,24 @@ def test_clear_set_dagrun_state(self, dag_run_state): session=session, ) - dagruns = session.query( - DagRun, - ).filter( - DagRun.dag_id == dag_id, - ).all() + dagruns = ( + session.query( + DagRun, + ) + .filter( + DagRun.dag_id == dag_id, + ) + .all() + ) self.assertEqual(len(dagruns), 1) dagrun = dagruns[0] # type: DagRun self.assertEqual(dagrun.state, dag_run_state) - @parameterized.expand([ - (state, State.NONE) - for state in State.task_states if state != State.RUNNING - ] + [(State.RUNNING, State.SHUTDOWN)]) # type: ignore + @parameterized.expand( + [(state, State.NONE) for state in State.task_states if state != State.RUNNING] + + [(State.RUNNING, State.SHUTDOWN)] + ) # type: ignore def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]): dag_id = 'test_clear_dag' self._clean_up(dag_id) @@ -1440,11 +1366,15 @@ def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]): session=session, ) - task_instances = session.query( - TI, - ).filter( - TI.dag_id == dag_id, - ).all() + task_instances = ( + session.query( + TI, + ) + .filter( + TI.dag_id == dag_id, + ) + .all() + ) self.assertEqual(len(task_instances), 1) task_instance = task_instances[0] # type: TI @@ -1453,9 +1383,8 @@ def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]): def test_next_dagrun_after_date_once(self): dag = DAG( - 'test_scheduler_dagrun_once', - start_date=timezone.datetime(2015, 1, 1), - schedule_interval="@once") + 'test_scheduler_dagrun_once', start_date=timezone.datetime(2015, 1, 1), 
schedule_interval="@once" + ) next_date = dag.next_dagrun_after_date(None) @@ -1474,10 +1403,7 @@ def test_next_dagrun_after_date_start_end_dates(self): start_date = DEFAULT_DATE end_date = start_date + (runs - 1) * delta dag_id = "test_schedule_dag_start_end_dates" - dag = DAG(dag_id=dag_id, - start_date=start_date, - end_date=end_date, - schedule_interval=delta) + dag = DAG(dag_id=dag_id, start_date=start_date, end_date=end_date, schedule_interval=delta) dag.add_task(BaseOperator(task_id='faketastic', owner='Also fake')) # Create and schedule the dag runs @@ -1504,11 +1430,13 @@ def make_dag(dag_id, schedule_interval, start_date, catchup): 'owner': 'airflow', 'depends_on_past': False, } - dag = DAG(dag_id, - schedule_interval=schedule_interval, - start_date=start_date, - catchup=catchup, - default_args=default_args) + dag = DAG( + dag_id, + schedule_interval=schedule_interval, + start_date=start_date, + catchup=catchup, + default_args=default_args, + ) op1 = DummyOperator(task_id='t1', dag=dag) op2 = DummyOperator(task_id='t2', dag=dag) @@ -1519,23 +1447,28 @@ def make_dag(dag_id, schedule_interval, start_date, catchup): now = timezone.utcnow() six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace( - minute=0, second=0, microsecond=0) + minute=0, second=0, microsecond=0 + ) half_an_hour_ago = now - datetime.timedelta(minutes=30) two_hours_ago = now - datetime.timedelta(hours=2) - dag1 = make_dag(dag_id='dag_without_catchup_ten_minute', - schedule_interval='*/10 * * * *', - start_date=six_hours_ago_to_the_hour, - catchup=False) + dag1 = make_dag( + dag_id='dag_without_catchup_ten_minute', + schedule_interval='*/10 * * * *', + start_date=six_hours_ago_to_the_hour, + catchup=False, + ) next_date = dag1.next_dagrun_after_date(None) # The DR should be scheduled in the last half an hour, not 6 hours ago assert next_date > half_an_hour_ago assert next_date < timezone.utcnow() - dag2 = make_dag(dag_id='dag_without_catchup_hourly', - schedule_interval='@hourly', - start_date=six_hours_ago_to_the_hour, - catchup=False) + dag2 = make_dag( + dag_id='dag_without_catchup_hourly', + schedule_interval='@hourly', + start_date=six_hours_ago_to_the_hour, + catchup=False, + ) next_date = dag2.next_dagrun_after_date(None) # The DR should be scheduled in the last 2 hours, not 6 hours ago @@ -1543,10 +1476,12 @@ def make_dag(dag_id, schedule_interval, start_date, catchup): # The DR should be scheduled BEFORE now assert next_date < timezone.utcnow() - dag3 = make_dag(dag_id='dag_without_catchup_once', - schedule_interval='@once', - start_date=six_hours_ago_to_the_hour, - catchup=False) + dag3 = make_dag( + dag_id='dag_without_catchup_once', + schedule_interval='@once', + start_date=six_hours_ago_to_the_hour, + catchup=False, + ) next_date = dag3.next_dagrun_after_date(None) # The DR should be scheduled in the last 2 hours, not 6 hours ago @@ -1562,7 +1497,8 @@ def test_next_dagrun_after_date_timedelta_schedule_and_catchup_false(self): 'test_scheduler_dagrun_once_with_timedelta_and_catchup_false', start_date=timezone.datetime(2015, 1, 1), schedule_interval=timedelta(days=1), - catchup=False) + catchup=False, + ) next_date = dag.next_dagrun_after_date(None) assert next_date == timezone.datetime(2020, 1, 4) @@ -1581,7 +1517,8 @@ def test_next_dagrun_after_date_timedelta_schedule_and_catchup_true(self): 'test_scheduler_dagrun_once_with_timedelta_and_catchup_true', start_date=timezone.datetime(2020, 5, 1), schedule_interval=timedelta(days=1), - catchup=True) + catchup=True, + ) next_date = 
dag.next_dagrun_after_date(None) assert next_date == timezone.datetime(2020, 5, 1) @@ -1606,12 +1543,9 @@ def test_next_dagrun_after_auto_align(self): dag = DAG( dag_id='test_scheduler_auto_align_1', start_date=timezone.datetime(2016, 1, 1, 10, 10, 0), - schedule_interval="4 5 * * *" + schedule_interval="4 5 * * *", ) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') + DummyOperator(task_id='dummy', dag=dag, owner='airflow') next_date = dag.next_dagrun_after_date(None) assert next_date == timezone.datetime(2016, 1, 2, 5, 4) @@ -1619,12 +1553,9 @@ def test_next_dagrun_after_auto_align(self): dag = DAG( dag_id='test_scheduler_auto_align_2', start_date=timezone.datetime(2016, 1, 1, 10, 10, 0), - schedule_interval="10 10 * * *" + schedule_interval="10 10 * * *", ) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') + DummyOperator(task_id='dummy', dag=dag, owner='airflow') next_date = dag.next_dagrun_after_date(None) assert next_date == timezone.datetime(2016, 1, 1, 10, 10) @@ -1639,9 +1570,11 @@ def subdag(parent_dag_name, child_dag_name, args): """ Create a subdag. """ - dag_subdag = DAG(dag_id=f'{parent_dag_name}.{child_dag_name}', - schedule_interval="@daily", - default_args=args) + dag_subdag = DAG( + dag_id=f'{parent_dag_name}.{child_dag_name}', + schedule_interval="@daily", + default_args=args, + ) for i in range(2): DummyOperator(task_id='{}-task-{}'.format(child_dag_name, i + 1), dag=dag_subdag) @@ -1673,11 +1606,11 @@ def subdag(parent_dag_name, child_dag_name, args): def test_replace_outdated_access_control_actions(self): outdated_permissions = { 'role1': {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT}, - 'role2': {permissions.DEPRECATED_ACTION_CAN_DAG_READ, permissions.DEPRECATED_ACTION_CAN_DAG_EDIT} + 'role2': {permissions.DEPRECATED_ACTION_CAN_DAG_READ, permissions.DEPRECATED_ACTION_CAN_DAG_EDIT}, } updated_permissions = { 'role1': {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT}, - 'role2': {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT} + 'role2': {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT}, } with pytest.warns(DeprecationWarning): @@ -1690,15 +1623,9 @@ def test_replace_outdated_access_control_actions(self): class TestDagModel: - def test_dags_needing_dagruns_not_too_early(self): - dag = DAG( - dag_id='far_future_dag', - start_date=timezone.datetime(2038, 1, 1)) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') + dag = DAG(dag_id='far_future_dag', start_date=timezone.datetime(2038, 1, 1)) + DummyOperator(task_id='dummy', dag=dag, owner='airflow') session = settings.Session() orm_dag = DagModel( @@ -1722,13 +1649,8 @@ def test_dags_needing_dagruns_only_unpaused(self): """ We should never create dagruns for unpaused DAGs """ - dag = DAG( - dag_id='test_dags', - start_date=DEFAULT_DATE) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') + dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE) + DummyOperator(task_id='dummy', dag=dag, owner='airflow') session = settings.Session() orm_dag = DagModel( @@ -1755,17 +1677,18 @@ def test_dags_needing_dagruns_only_unpaused(self): class TestQueries(unittest.TestCase): - def setUp(self) -> None: clear_db_runs() def tearDown(self) -> None: clear_db_runs() - @parameterized.expand([ - (3, ), - (12, ), - ]) + @parameterized.expand( + [ + (3,), + (12,), + ] + ) def test_count_number_queries(self, tasks_count): dag = DAG('test_dagrun_query_count', start_date=DEFAULT_DATE) for i in range(tasks_count): @@ -1799,6 +1722,7 @@ def 
tearDown(self): def test_set_dag_id(self): """Test that checks you can set dag_id from decorator.""" + @dag_decorator('test', default_args=self.DEFAULT_ARGS) def noop_pipeline(): @task_decorator @@ -1806,12 +1730,14 @@ def return_num(num): return num return_num(4) + dag = noop_pipeline() assert isinstance(dag, DAG) assert dag.dag_id, 'test' def test_default_dag_id(self): """Test that @dag uses function name as default dag id.""" + @dag_decorator(default_args=self.DEFAULT_ARGS) def noop_pipeline(): @task_decorator @@ -1819,22 +1745,26 @@ def return_num(num): return num return_num(4) + dag = noop_pipeline() assert isinstance(dag, DAG) assert dag.dag_id, 'noop_pipeline' def test_documentation_added(self): """Test that @dag uses function docs as doc_md for DAG object""" + @dag_decorator(default_args=self.DEFAULT_ARGS) def noop_pipeline(): """ Regular DAG documentation """ + @task_decorator def return_num(num): return num return_num(4) + dag = noop_pipeline() assert isinstance(dag, DAG) assert dag.dag_id, 'test' @@ -1842,6 +1772,7 @@ def return_num(num): def test_fails_if_arg_not_set(self): """Test that @dag decorated function fails if positional argument is not set""" + @dag_decorator(default_args=self.DEFAULT_ARGS) def noop_pipeline(value): @task_decorator @@ -1856,6 +1787,7 @@ def return_num(num): def test_dag_param_resolves(self): """Test that dag param is correctly resolved by operator""" + @dag_decorator(default_args=self.DEFAULT_ARGS) def xcom_pass_to_op(value=self.VALUE): @task_decorator @@ -1871,7 +1803,7 @@ def return_num(num): run_id=DagRunType.MANUAL.value, start_date=timezone.utcnow(), execution_date=self.DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE) @@ -1880,11 +1812,13 @@ def return_num(num): def test_dag_param_dagrun_parameterized(self): """Test that dag param is correctly overwritten when set in dag run""" + @dag_decorator(default_args=self.DEFAULT_ARGS) def xcom_pass_to_op(value=self.VALUE): @task_decorator def return_num(num): return num + assert isinstance(value, DagParam) xcom_arg = return_num(value) @@ -1897,7 +1831,7 @@ def return_num(num): start_date=timezone.utcnow(), execution_date=self.DEFAULT_DATE, state=State.RUNNING, - conf={'value': new_value} + conf={'value': new_value}, ) self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE) @@ -1906,6 +1840,7 @@ def return_num(num): def test_set_params_for_dag(self): """Test that dag param is correctly set when using dag decorator""" + @dag_decorator(default_args=self.DEFAULT_ARGS) def xcom_pass_to_op(value=self.VALUE): @task_decorator diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index d43950996bc8b..a0d7fbed5f062 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -64,8 +64,7 @@ def test_get_existing_dag(self): """ dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=True) - some_expected_dag_ids = ["example_bash_operator", - "example_branch_operator"] + some_expected_dag_ids = ["example_bash_operator", "example_branch_operator"] for dag_id in some_expected_dag_ids: dag = dagbag.get_dag(dag_id) @@ -105,9 +104,7 @@ def test_safe_mode_heuristic_match(self): dagbag = models.DagBag(include_examples=False, safe_mode=True) self.assertEqual(len(dagbag.dagbag_stats), 1) - self.assertEqual( - dagbag.dagbag_stats[0].file, - "/{}".format(os.path.basename(f.name))) + self.assertEqual(dagbag.dagbag_stats[0].file, "/{}".format(os.path.basename(f.name))) def 
test_safe_mode_heuristic_mismatch(self): """With safe mode enabled, a file not matching the discovery heuristics @@ -119,15 +116,12 @@ def test_safe_mode_heuristic_mismatch(self): self.assertEqual(len(dagbag.dagbag_stats), 0) def test_safe_mode_disabled(self): - """With safe mode disabled, an empty python file should be discovered. - """ + """With safe mode disabled, an empty python file should be discovered.""" with NamedTemporaryFile(dir=self.empty_dir, suffix=".py") as f: with conf_vars({('core', 'dags_folder'): self.empty_dir}): dagbag = models.DagBag(include_examples=False, safe_mode=False) self.assertEqual(len(dagbag.dagbag_stats), 1) - self.assertEqual( - dagbag.dagbag_stats[0].file, - "/{}".format(os.path.basename(f.name))) + self.assertEqual(dagbag.dagbag_stats[0].file, "/{}".format(os.path.basename(f.name))) def test_process_file_that_contains_multi_bytes_char(self): """ @@ -153,7 +147,7 @@ def test_zip_skip_log(self): self.assertIn( f'INFO:airflow.models.dagbag.DagBag:File {test_zip_path}:file_no_airflow_dag.py ' 'assumed to contain no DAGs. Skipping.', - cm.output + cm.output, ) def test_zip(self): @@ -218,7 +212,7 @@ def test_get_dag_fileloc(self): 'example_bash_operator': 'airflow/example_dags/example_bash_operator.py', 'example_subdag_operator': 'airflow/example_dags/example_subdag_operator.py', 'example_subdag_operator.section-1': 'airflow/example_dags/subdags/subdag.py', - 'test_zip_dag': 'dags/test_zip.zip/test_zip.py' + 'test_zip_dag': 'dags/test_zip.zip/test_zip.py', } for dag_id, path in expected.items(): @@ -233,14 +227,10 @@ def test_refresh_py_dag(self, mock_dagmodel): example_dags_folder = airflow.example_dags.__path__[0] dag_id = "example_bash_operator" - fileloc = os.path.realpath( - os.path.join(example_dags_folder, "example_bash_operator.py") - ) + fileloc = os.path.realpath(os.path.join(example_dags_folder, "example_bash_operator.py")) mock_dagmodel.return_value = DagModel() - mock_dagmodel.return_value.last_expired = datetime.max.replace( - tzinfo=timezone.utc - ) + mock_dagmodel.return_value.last_expired = datetime.max.replace(tzinfo=timezone.utc) mock_dagmodel.return_value.fileloc = fileloc class _TestDagBag(DagBag): @@ -265,14 +255,10 @@ def test_refresh_packaged_dag(self, mock_dagmodel): Test that we can refresh a packaged DAG """ dag_id = "test_zip_dag" - fileloc = os.path.realpath( - os.path.join(TEST_DAGS_FOLDER, "test_zip.zip/test_zip.py") - ) + fileloc = os.path.realpath(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip/test_zip.py")) mock_dagmodel.return_value = DagModel() - mock_dagmodel.return_value.last_expired = datetime.max.replace( - tzinfo=timezone.utc - ) + mock_dagmodel.return_value.last_expired = datetime.max.replace(tzinfo=timezone.utc) mock_dagmodel.return_value.fileloc = fileloc class _TestDagBag(DagBag): @@ -296,8 +282,7 @@ def process_dag(self, create_dag): Helper method to process a file generated from the input create_dag function. 
""" # write source to file - source = textwrap.dedent(''.join( - inspect.getsource(create_dag).splitlines(True)[1:-1])) + source = textwrap.dedent(''.join(inspect.getsource(create_dag).splitlines(True)[1:-1])) f = NamedTemporaryFile() f.write(source.encode('utf8')) f.flush() @@ -306,8 +291,7 @@ def process_dag(self, create_dag): found_dags = dagbag.process_file(f.name) return dagbag, found_dags, f.name - def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag, - should_be_found=True): + def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag, should_be_found=True): expected_dag_ids = list(map(lambda dag: dag.dag_id, expected_parent_dag.subdags)) expected_dag_ids.append(expected_parent_dag.dag_id) @@ -316,14 +300,16 @@ def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag, for dag_id in expected_dag_ids: actual_dagbag.log.info('validating %s' % dag_id) self.assertEqual( - dag_id in actual_found_dag_ids, should_be_found, - 'dag "%s" should %shave been found after processing dag "%s"' % - (dag_id, '' if should_be_found else 'not ', expected_parent_dag.dag_id) + dag_id in actual_found_dag_ids, + should_be_found, + 'dag "%s" should %shave been found after processing dag "%s"' + % (dag_id, '' if should_be_found else 'not ', expected_parent_dag.dag_id), ) self.assertEqual( - dag_id in actual_dagbag.dags, should_be_found, - 'dag "%s" should %sbe in dagbag.dags after processing dag "%s"' % - (dag_id, '' if should_be_found else 'not ', expected_parent_dag.dag_id) + dag_id in actual_dagbag.dags, + should_be_found, + 'dag "%s" should %sbe in dagbag.dags after processing dag "%s"' + % (dag_id, '' if should_be_found else 'not ', expected_parent_dag.dag_id), ) def test_load_subdags(self): @@ -334,14 +320,10 @@ def standard_subdag(): from airflow.models import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SubDagOperator + dag_name = 'master' - default_args = { - 'owner': 'owner1', - 'start_date': datetime.datetime(2016, 1, 1) - } - dag = DAG( - dag_name, - default_args=default_args) + default_args = {'owner': 'owner1', 'start_date': datetime.datetime(2016, 1, 1)} + dag = DAG(dag_name, default_args=default_args) # master: # A -> opSubDag_0 @@ -352,6 +334,7 @@ def standard_subdag(): # -> subdag_1.task with dag: + def subdag_0(): subdag_0 = DAG('master.op_subdag_0', default_args=default_args) DummyOperator(task_id='subdag_0.task', dag=subdag_0) @@ -362,10 +345,8 @@ def subdag_1(): DummyOperator(task_id='subdag_1.task', dag=subdag_1) return subdag_1 - op_subdag_0 = SubDagOperator( - task_id='op_subdag_0', dag=dag, subdag=subdag_0()) - op_subdag_1 = SubDagOperator( - task_id='op_subdag_1', dag=dag, subdag=subdag_1()) + op_subdag_0 = SubDagOperator(task_id='op_subdag_0', dag=dag, subdag=subdag_0()) + op_subdag_1 = SubDagOperator(task_id='op_subdag_1', dag=dag, subdag=subdag_1()) op_a = DummyOperator(task_id='A') op_a.set_downstream(op_subdag_0) @@ -390,14 +371,10 @@ def nested_subdags(): from airflow.models import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SubDagOperator + dag_name = 'master' - default_args = { - 'owner': 'owner1', - 'start_date': datetime.datetime(2016, 1, 1) - } - dag = DAG( - dag_name, - default_args=default_args) + default_args = {'owner': 'owner1', 'start_date': datetime.datetime(2016, 1, 1)} + dag = DAG(dag_name, default_args=default_args) # master: # A -> op_subdag_0 @@ -418,27 +395,24 @@ def 
nested_subdags(): # -> subdag_d.task with dag: + def subdag_a(): - subdag_a = DAG( - 'master.op_subdag_0.opSubdag_A', default_args=default_args) + subdag_a = DAG('master.op_subdag_0.opSubdag_A', default_args=default_args) DummyOperator(task_id='subdag_a.task', dag=subdag_a) return subdag_a def subdag_b(): - subdag_b = DAG( - 'master.op_subdag_0.opSubdag_B', default_args=default_args) + subdag_b = DAG('master.op_subdag_0.opSubdag_B', default_args=default_args) DummyOperator(task_id='subdag_b.task', dag=subdag_b) return subdag_b def subdag_c(): - subdag_c = DAG( - 'master.op_subdag_1.opSubdag_C', default_args=default_args) + subdag_c = DAG('master.op_subdag_1.opSubdag_C', default_args=default_args) DummyOperator(task_id='subdag_c.task', dag=subdag_c) return subdag_c def subdag_d(): - subdag_d = DAG( - 'master.op_subdag_1.opSubdag_D', default_args=default_args) + subdag_d = DAG('master.op_subdag_1.opSubdag_D', default_args=default_args) DummyOperator(task_id='subdag_d.task', dag=subdag_d) return subdag_d @@ -454,10 +428,8 @@ def subdag_1(): SubDagOperator(task_id='opSubdag_D', dag=subdag_1, subdag=subdag_d()) return subdag_1 - op_subdag_0 = SubDagOperator( - task_id='op_subdag_0', dag=dag, subdag=subdag_0()) - op_subdag_1 = SubDagOperator( - task_id='op_subdag_1', dag=dag, subdag=subdag_1()) + op_subdag_0 = SubDagOperator(task_id='op_subdag_0', dag=dag, subdag=subdag_0()) + op_subdag_1 = SubDagOperator(task_id='op_subdag_1', dag=dag, subdag=subdag_1()) op_a = DummyOperator(task_id='A') op_a.set_downstream(op_subdag_0) @@ -488,14 +460,10 @@ def basic_cycle(): from airflow.models import DAG from airflow.operators.dummy_operator import DummyOperator + dag_name = 'cycle_dag' - default_args = { - 'owner': 'owner1', - 'start_date': datetime.datetime(2016, 1, 1) - } - dag = DAG( - dag_name, - default_args=default_args) + default_args = {'owner': 'owner1', 'start_date': datetime.datetime(2016, 1, 1)} + dag = DAG(dag_name, default_args=default_args) # A -> A with dag: @@ -523,14 +491,10 @@ def nested_subdag_cycle(): from airflow.models import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SubDagOperator + dag_name = 'nested_cycle' - default_args = { - 'owner': 'owner1', - 'start_date': datetime.datetime(2016, 1, 1) - } - dag = DAG( - dag_name, - default_args=default_args) + default_args = {'owner': 'owner1', 'start_date': datetime.datetime(2016, 1, 1)} + dag = DAG(dag_name, default_args=default_args) # cycle: # A -> op_subdag_0 @@ -551,30 +515,26 @@ def nested_subdag_cycle(): # -> subdag_d.task with dag: + def subdag_a(): - subdag_a = DAG( - 'nested_cycle.op_subdag_0.opSubdag_A', default_args=default_args) + subdag_a = DAG('nested_cycle.op_subdag_0.opSubdag_A', default_args=default_args) DummyOperator(task_id='subdag_a.task', dag=subdag_a) return subdag_a def subdag_b(): - subdag_b = DAG( - 'nested_cycle.op_subdag_0.opSubdag_B', default_args=default_args) + subdag_b = DAG('nested_cycle.op_subdag_0.opSubdag_B', default_args=default_args) DummyOperator(task_id='subdag_b.task', dag=subdag_b) return subdag_b def subdag_c(): - subdag_c = DAG( - 'nested_cycle.op_subdag_1.opSubdag_C', default_args=default_args) - op_subdag_c_task = DummyOperator( - task_id='subdag_c.task', dag=subdag_c) + subdag_c = DAG('nested_cycle.op_subdag_1.opSubdag_C', default_args=default_args) + op_subdag_c_task = DummyOperator(task_id='subdag_c.task', dag=subdag_c) # introduce a loop in opSubdag_C op_subdag_c_task.set_downstream(op_subdag_c_task) return subdag_c def 
subdag_d(): - subdag_d = DAG( - 'nested_cycle.op_subdag_1.opSubdag_D', default_args=default_args) + subdag_d = DAG('nested_cycle.op_subdag_1.opSubdag_D', default_args=default_args) DummyOperator(task_id='subdag_d.task', dag=subdag_d) return subdag_d @@ -590,10 +550,8 @@ def subdag_1(): SubDagOperator(task_id='opSubdag_D', dag=subdag_1, subdag=subdag_d()) return subdag_1 - op_subdag_0 = SubDagOperator( - task_id='op_subdag_0', dag=dag, subdag=subdag_0()) - op_subdag_1 = SubDagOperator( - task_id='op_subdag_1', dag=dag, subdag=subdag_1()) + op_subdag_0 = SubDagOperator(task_id='op_subdag_0', dag=dag, subdag=subdag_0()) + op_subdag_1 = SubDagOperator(task_id='op_subdag_1', dag=dag, subdag=subdag_1()) op_a = DummyOperator(task_id='A') op_a.set_downstream(op_subdag_0) @@ -655,7 +613,8 @@ def test_serialized_dags_are_written_to_db_on_sync(self): dagbag = DagBag( dag_folder=os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py"), - include_examples=False) + include_examples=False, + ) dagbag.sync_to_db() self.assertFalse(dagbag.read_dags_from_db) @@ -682,11 +641,13 @@ def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_sdag_sync_to_db dagbag.sync_to_db(session=mock_session) # Test that 3 attempts were made to run 'DAG.bulk_write_to_db' successfully - mock_bulk_write_to_db.assert_has_calls([ - mock.call(mock.ANY, session=mock.ANY), - mock.call(mock.ANY, session=mock.ANY), - mock.call(mock.ANY, session=mock.ANY), - ]) + mock_bulk_write_to_db.assert_has_calls( + [ + mock.call(mock.ANY, session=mock.ANY), + mock.call(mock.ANY, session=mock.ANY), + mock.call(mock.ANY, session=mock.ANY), + ] + ) # Assert that rollback is called twice (i.e. whenever OperationalError occurs) mock_session.rollback.assert_has_calls([mock.call(), mock.call()]) # Check that 'SerializedDagModel.bulk_sync_to_db' is also called @@ -760,9 +721,7 @@ def test_cluster_policy_violation(self): """ dag_file = os.path.join(TEST_DAGS_FOLDER, "test_missing_owner.py") - dagbag = DagBag(dag_folder=dag_file, - include_smart_sensor=False, - include_examples=False) + dagbag = DagBag(dag_folder=dag_file, include_smart_sensor=False, include_examples=False) self.assertEqual(set(), set(dagbag.dag_ids)) expected_import_errors = { dag_file: ( @@ -778,12 +737,9 @@ def test_cluster_policy_obeyed(self): """test that dag successfully imported without import errors when tasks obey cluster policy. """ - dag_file = os.path.join(TEST_DAGS_FOLDER, - "test_with_non_default_owner.py") + dag_file = os.path.join(TEST_DAGS_FOLDER, "test_with_non_default_owner.py") - dagbag = DagBag(dag_folder=dag_file, - include_examples=False, - include_smart_sensor=False) + dagbag = DagBag(dag_folder=dag_file, include_examples=False, include_smart_sensor=False) self.assertEqual({"test_with_non_default_owner"}, set(dagbag.dag_ids)) self.assertEqual({}, dagbag.import_errors) diff --git a/tests/models/test_dagcode.py b/tests/models/test_dagcode.py index c16ebf1d2f7f5..ee2eb2aaa61d4 100644 --- a/tests/models/test_dagcode.py +++ b/tests/models/test_dagcode.py @@ -22,6 +22,7 @@ from airflow import AirflowException, example_dags as example_dags_module from airflow.models import DagBag from airflow.models.dagcode import DagCode + # To move it to a shared module. 
from airflow.utils.file import open_maybe_zipped from airflow.utils.session import create_session @@ -82,7 +83,7 @@ def test_bulk_sync_to_db_half_files(self): """Dg code can be bulk written into database.""" example_dags = make_example_dags(example_dags_module) files = [dag.fileloc for dag in example_dags.values()] - half_files = files[:int(len(files) / 2)] + half_files = files[: int(len(files) / 2)] with create_session() as session: DagCode.bulk_sync_to_db(half_files, session=session) session.commit() @@ -107,11 +108,12 @@ def _compare_example_dags(self, example_dags): dag.fileloc = dag.parent_dag.fileloc self.assertTrue(DagCode.has_dag(dag.fileloc)) dag_fileloc_hash = DagCode.dag_fileloc_hash(dag.fileloc) - result = session.query( - DagCode.fileloc, DagCode.fileloc_hash, DagCode.source_code) \ - .filter(DagCode.fileloc == dag.fileloc) \ - .filter(DagCode.fileloc_hash == dag_fileloc_hash) \ + result = ( + session.query(DagCode.fileloc, DagCode.fileloc_hash, DagCode.source_code) + .filter(DagCode.fileloc == dag.fileloc) + .filter(DagCode.fileloc_hash == dag_fileloc_hash) .one() + ) self.assertEqual(result.fileloc, dag.fileloc) with open_maybe_zipped(dag.fileloc, 'r') as source: @@ -145,9 +147,7 @@ def test_db_code_updated_on_dag_file_change(self): example_dag.sync_to_db() with create_session() as session: - result = session.query(DagCode) \ - .filter(DagCode.fileloc == example_dag.fileloc) \ - .one() + result = session.query(DagCode).filter(DagCode.fileloc == example_dag.fileloc).one() self.assertEqual(result.fileloc, example_dag.fileloc) self.assertIsNotNone(result.source_code) @@ -160,9 +160,7 @@ def test_db_code_updated_on_dag_file_change(self): example_dag.sync_to_db() with create_session() as session: - new_result = session.query(DagCode) \ - .filter(DagCode.fileloc == example_dag.fileloc) \ - .one() + new_result = session.query(DagCode).filter(DagCode.fileloc == example_dag.fileloc).one() self.assertEqual(new_result.fileloc, example_dag.fileloc) self.assertEqual(new_result.source_code, "# dummy code") diff --git a/tests/models/test_dagparam.py b/tests/models/test_dagparam.py index 2f723baea3fa0..1eb28cbc54c1c 100644 --- a/tests/models/test_dagparam.py +++ b/tests/models/test_dagparam.py @@ -57,7 +57,7 @@ def return_num(num): run_id=DagRunType.MANUAL.value, start_date=timezone.utcnow(), execution_date=self.DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) # pylint: disable=maybe-no-member @@ -84,7 +84,7 @@ def return_num(num): start_date=timezone.utcnow(), execution_date=self.DEFAULT_DATE, state=State.RUNNING, - conf={'value': new_value} + conf={'value': new_value}, ) # pylint: disable=maybe-no-member diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index ec7f6e69ed573..33bf32d6373e6 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -37,7 +37,6 @@ class TestDagRun(unittest.TestCase): - @classmethod def setUpClass(cls): cls.dagbag = DagBag(include_examples=True) @@ -46,12 +45,14 @@ def setUp(self): clear_db_runs() clear_db_pools() - def create_dag_run(self, dag, - state=State.RUNNING, - task_states=None, - execution_date=None, - is_backfill=False, - ): + def create_dag_run( + self, + dag, + state=State.RUNNING, + task_states=None, + execution_date=None, + is_backfill=False, + ): now = timezone.utcnow() if execution_date is None: execution_date = now @@ -88,15 +89,11 @@ def test_clear_task_instances_for_backfill_dagrun(self): ti0 = TI(task=task0, execution_date=now) ti0.run() - qry = session.query(TI).filter( - TI.dag_id == 
dag.dag_id).all() + qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all() clear_task_instances(qry, session) session.commit() ti0.refresh_from_db() - dr0 = session.query(DagRun).filter( - DagRun.dag_id == dag_id, - DagRun.execution_date == now - ).first() + dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.execution_date == now).first() self.assertEqual(dr0.state, State.RUNNING) def test_dagrun_find(self): @@ -127,33 +124,21 @@ def test_dagrun_find(self): session.commit() - self.assertEqual(1, - len(models.DagRun.find(dag_id=dag_id1, external_trigger=True))) - self.assertEqual(0, - len(models.DagRun.find(dag_id=dag_id1, external_trigger=False))) - self.assertEqual(0, - len(models.DagRun.find(dag_id=dag_id2, external_trigger=True))) - self.assertEqual(1, - len(models.DagRun.find(dag_id=dag_id2, external_trigger=False))) + self.assertEqual(1, len(models.DagRun.find(dag_id=dag_id1, external_trigger=True))) + self.assertEqual(0, len(models.DagRun.find(dag_id=dag_id1, external_trigger=False))) + self.assertEqual(0, len(models.DagRun.find(dag_id=dag_id2, external_trigger=True))) + self.assertEqual(1, len(models.DagRun.find(dag_id=dag_id2, external_trigger=False))) def test_dagrun_success_when_all_skipped(self): """ Tests that a DAG run succeeds when all tasks are skipped """ - dag = DAG( - dag_id='test_dagrun_success_when_all_skipped', - start_date=timezone.datetime(2017, 1, 1) - ) + dag = DAG(dag_id='test_dagrun_success_when_all_skipped', start_date=timezone.datetime(2017, 1, 1)) dag_task1 = ShortCircuitOperator( - task_id='test_short_circuit_false', - dag=dag, - python_callable=lambda: False) - dag_task2 = DummyOperator( - task_id='test_state_skipped1', - dag=dag) - dag_task3 = DummyOperator( - task_id='test_state_skipped2', - dag=dag) + task_id='test_short_circuit_false', dag=dag, python_callable=lambda: False + ) + dag_task2 = DummyOperator(task_id='test_state_skipped1', dag=dag) + dag_task3 = DummyOperator(task_id='test_state_skipped2', dag=dag) dag_task1.set_downstream(dag_task2) dag_task2.set_downstream(dag_task3) @@ -163,19 +148,14 @@ def test_dagrun_success_when_all_skipped(self): 'test_state_skipped2': State.SKIPPED, } - dag_run = self.create_dag_run(dag=dag, - state=State.RUNNING, - task_states=initial_task_states) + dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states) dag_run.update_state() self.assertEqual(State.SUCCESS, dag_run.state) def test_dagrun_success_conditions(self): session = settings.Session() - dag = DAG( - 'test_dagrun_success_conditions', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('test_dagrun_success_conditions', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) # A -> B # A -> C -> D @@ -191,10 +171,9 @@ def test_dagrun_success_conditions(self): dag.clear() now = timezone.utcnow() - dr = dag.create_dagrun(run_id='test_dagrun_success_conditions', - state=State.RUNNING, - execution_date=now, - start_date=now) + dr = dag.create_dagrun( + run_id='test_dagrun_success_conditions', state=State.RUNNING, execution_date=now, start_date=now + ) # op1 = root ti_op1 = dr.get_task_instance(task_id=op1.task_id) @@ -217,10 +196,7 @@ def test_dagrun_success_conditions(self): def test_dagrun_deadlock(self): session = settings.Session() - dag = DAG( - 'text_dagrun_deadlock', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('text_dagrun_deadlock', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) with dag: op1 = DummyOperator(task_id='A') @@ 
-230,10 +206,9 @@ def test_dagrun_deadlock(self): dag.clear() now = timezone.utcnow() - dr = dag.create_dagrun(run_id='test_dagrun_deadlock', - state=State.RUNNING, - execution_date=now, - start_date=now) + dr = dag.create_dagrun( + run_id='test_dagrun_deadlock', state=State.RUNNING, execution_date=now, start_date=now + ) ti_op1 = dr.get_task_instance(task_id=op1.task_id) ti_op1.set_state(state=State.SUCCESS, session=session) @@ -250,17 +225,18 @@ def test_dagrun_deadlock(self): def test_dagrun_no_deadlock_with_shutdown(self): session = settings.Session() - dag = DAG('test_dagrun_no_deadlock_with_shutdown', - start_date=DEFAULT_DATE) + dag = DAG('test_dagrun_no_deadlock_with_shutdown', start_date=DEFAULT_DATE) with dag: op1 = DummyOperator(task_id='upstream_task') op2 = DummyOperator(task_id='downstream_task') op2.set_upstream(op1) - dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_with_shutdown', - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE) + dr = dag.create_dagrun( + run_id='test_dagrun_no_deadlock_with_shutdown', + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + ) upstream_ti = dr.get_task_instance(task_id='upstream_task') upstream_ti.set_state(State.SHUTDOWN, session=session) @@ -269,21 +245,24 @@ def test_dagrun_no_deadlock_with_shutdown(self): def test_dagrun_no_deadlock_with_depends_on_past(self): session = settings.Session() - dag = DAG('test_dagrun_no_deadlock', - start_date=DEFAULT_DATE) + dag = DAG('test_dagrun_no_deadlock', start_date=DEFAULT_DATE) with dag: DummyOperator(task_id='dop', depends_on_past=True) DummyOperator(task_id='tc', task_concurrency=1) dag.clear() - dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_1', - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE) - dr2 = dag.create_dagrun(run_id='test_dagrun_no_deadlock_2', - state=State.RUNNING, - execution_date=DEFAULT_DATE + datetime.timedelta(days=1), - start_date=DEFAULT_DATE + datetime.timedelta(days=1)) + dr = dag.create_dagrun( + run_id='test_dagrun_no_deadlock_1', + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + ) + dr2 = dag.create_dagrun( + run_id='test_dagrun_no_deadlock_2', + state=State.RUNNING, + execution_date=DEFAULT_DATE + datetime.timedelta(days=1), + start_date=DEFAULT_DATE + datetime.timedelta(days=1), + ) ti1_op1 = dr.get_task_instance(task_id='dop') dr2.get_task_instance(task_id='dop') ti2_op1 = dr.get_task_instance(task_id='tc') @@ -302,22 +281,15 @@ def test_dagrun_no_deadlock_with_depends_on_past(self): def test_dagrun_success_callback(self): def on_success_callable(context): - self.assertEqual( - context['dag_run'].dag_id, - 'test_dagrun_success_callback' - ) + self.assertEqual(context['dag_run'].dag_id, 'test_dagrun_success_callback') dag = DAG( dag_id='test_dagrun_success_callback', start_date=datetime.datetime(2017, 1, 1), on_success_callback=on_success_callable, ) - dag_task1 = DummyOperator( - task_id='test_state_succeeded1', - dag=dag) - dag_task2 = DummyOperator( - task_id='test_state_succeeded2', - dag=dag) + dag_task1 = DummyOperator(task_id='test_state_succeeded1', dag=dag) + dag_task2 = DummyOperator(task_id='test_state_succeeded2', dag=dag) dag_task1.set_downstream(dag_task2) initial_task_states = { @@ -325,9 +297,7 @@ def on_success_callable(context): 'test_state_succeeded2': State.SUCCESS, } - dag_run = self.create_dag_run(dag=dag, - state=State.RUNNING, - task_states=initial_task_states) + dag_run = self.create_dag_run(dag=dag, 
state=State.RUNNING, task_states=initial_task_states) _, callback = dag_run.update_state() self.assertEqual(State.SUCCESS, dag_run.state) # Callbacks are not added until handle_callback = False is passed to dag_run.update_state() @@ -335,22 +305,15 @@ def on_success_callable(context): def test_dagrun_failure_callback(self): def on_failure_callable(context): - self.assertEqual( - context['dag_run'].dag_id, - 'test_dagrun_failure_callback' - ) + self.assertEqual(context['dag_run'].dag_id, 'test_dagrun_failure_callback') dag = DAG( dag_id='test_dagrun_failure_callback', start_date=datetime.datetime(2017, 1, 1), on_failure_callback=on_failure_callable, ) - dag_task1 = DummyOperator( - task_id='test_state_succeeded1', - dag=dag) - dag_task2 = DummyOperator( - task_id='test_state_failed2', - dag=dag) + dag_task1 = DummyOperator(task_id='test_state_succeeded1', dag=dag) + dag_task2 = DummyOperator(task_id='test_state_failed2', dag=dag) initial_task_states = { 'test_state_succeeded1': State.SUCCESS, @@ -358,9 +321,7 @@ def on_failure_callable(context): } dag_task1.set_downstream(dag_task2) - dag_run = self.create_dag_run(dag=dag, - state=State.RUNNING, - task_states=initial_task_states) + dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states) _, callback = dag_run.update_state() self.assertEqual(State.FAILED, dag_run.state) # Callbacks are not added until handle_callback = False is passed to dag_run.update_state() @@ -369,8 +330,7 @@ def on_failure_callable(context): def test_dagrun_update_state_with_handle_callback_success(self): def on_success_callable(context): self.assertEqual( - context['dag_run'].dag_id, - 'test_dagrun_update_state_with_handle_callback_success' + context['dag_run'].dag_id, 'test_dagrun_update_state_with_handle_callback_success' ) dag = DAG( @@ -378,12 +338,8 @@ def on_success_callable(context): start_date=datetime.datetime(2017, 1, 1), on_success_callback=on_success_callable, ) - dag_task1 = DummyOperator( - task_id='test_state_succeeded1', - dag=dag) - dag_task2 = DummyOperator( - task_id='test_state_succeeded2', - dag=dag) + dag_task1 = DummyOperator(task_id='test_state_succeeded1', dag=dag) + dag_task2 = DummyOperator(task_id='test_state_succeeded2', dag=dag) dag_task1.set_downstream(dag_task2) initial_task_states = { @@ -402,14 +358,13 @@ def on_success_callable(context): dag_id="test_dagrun_update_state_with_handle_callback_success", execution_date=dag_run.execution_date, is_failure_callback=False, - msg="success" + msg="success", ) def test_dagrun_update_state_with_handle_callback_failure(self): def on_failure_callable(context): self.assertEqual( - context['dag_run'].dag_id, - 'test_dagrun_update_state_with_handle_callback_failure' + context['dag_run'].dag_id, 'test_dagrun_update_state_with_handle_callback_failure' ) dag = DAG( @@ -417,12 +372,8 @@ def on_failure_callable(context): start_date=datetime.datetime(2017, 1, 1), on_failure_callback=on_failure_callable, ) - dag_task1 = DummyOperator( - task_id='test_state_succeeded1', - dag=dag) - dag_task2 = DummyOperator( - task_id='test_state_failed2', - dag=dag) + dag_task1 = DummyOperator(task_id='test_state_succeeded1', dag=dag) + dag_task2 = DummyOperator(task_id='test_state_failed2', dag=dag) dag_task1.set_downstream(dag_task2) initial_task_states = { @@ -441,24 +392,20 @@ def on_failure_callable(context): dag_id="test_dagrun_update_state_with_handle_callback_failure", execution_date=dag_run.execution_date, is_failure_callback=True, - msg="task_failure" + msg="task_failure", ) def 
test_dagrun_set_state_end_date(self): session = settings.Session() - dag = DAG( - 'test_dagrun_set_state_end_date', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('test_dagrun_set_state_end_date', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) dag.clear() now = timezone.utcnow() - dr = dag.create_dagrun(run_id='test_dagrun_set_state_end_date', - state=State.RUNNING, - execution_date=now, - start_date=now) + dr = dag.create_dagrun( + run_id='test_dagrun_set_state_end_date', state=State.RUNNING, execution_date=now, start_date=now + ) # Initial end_date should be NULL # State.SUCCESS and State.FAILED are all ending state and should set end_date @@ -471,9 +418,7 @@ def test_dagrun_set_state_end_date(self): session.merge(dr) session.commit() - dr_database = session.query(DagRun).filter( - DagRun.run_id == 'test_dagrun_set_state_end_date' - ).one() + dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_set_state_end_date').one() self.assertIsNotNone(dr_database.end_date) self.assertEqual(dr.end_date, dr_database.end_date) @@ -481,18 +426,14 @@ def test_dagrun_set_state_end_date(self): session.merge(dr) session.commit() - dr_database = session.query(DagRun).filter( - DagRun.run_id == 'test_dagrun_set_state_end_date' - ).one() + dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_set_state_end_date').one() self.assertIsNone(dr_database.end_date) dr.set_state(State.FAILED) session.merge(dr) session.commit() - dr_database = session.query(DagRun).filter( - DagRun.run_id == 'test_dagrun_set_state_end_date' - ).one() + dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_set_state_end_date').one() self.assertIsNotNone(dr_database.end_date) self.assertEqual(dr.end_date, dr_database.end_date) @@ -501,9 +442,8 @@ def test_dagrun_update_state_end_date(self): session = settings.Session() dag = DAG( - 'test_dagrun_update_state_end_date', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + 'test_dagrun_update_state_end_date', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'} + ) # A -> B with dag: @@ -514,10 +454,12 @@ def test_dagrun_update_state_end_date(self): dag.clear() now = timezone.utcnow() - dr = dag.create_dagrun(run_id='test_dagrun_update_state_end_date', - state=State.RUNNING, - execution_date=now, - start_date=now) + dr = dag.create_dagrun( + run_id='test_dagrun_update_state_end_date', + state=State.RUNNING, + execution_date=now, + start_date=now, + ) # Initial end_date should be NULL # State.SUCCESS and State.FAILED are all ending state and should set end_date @@ -533,9 +475,7 @@ def test_dagrun_update_state_end_date(self): dr.update_state() - dr_database = session.query(DagRun).filter( - DagRun.run_id == 'test_dagrun_update_state_end_date' - ).one() + dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_update_state_end_date').one() self.assertIsNotNone(dr_database.end_date) self.assertEqual(dr.end_date, dr_database.end_date) @@ -543,9 +483,7 @@ def test_dagrun_update_state_end_date(self): ti_op2.set_state(state=State.RUNNING, session=session) dr.update_state() - dr_database = session.query(DagRun).filter( - DagRun.run_id == 'test_dagrun_update_state_end_date' - ).one() + dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_update_state_end_date').one() self.assertEqual(dr._state, State.RUNNING) self.assertIsNone(dr.end_date) @@ -555,9 +493,7 @@ def test_dagrun_update_state_end_date(self): ti_op2.set_state(state=State.FAILED, 
session=session) dr.update_state() - dr_database = session.query(DagRun).filter( - DagRun.run_id == 'test_dagrun_update_state_end_date' - ).one() + dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_update_state_end_date').one() self.assertIsNotNone(dr_database.end_date) self.assertEqual(dr.end_date, dr_database.end_date) @@ -566,14 +502,8 @@ def test_get_task_instance_on_empty_dagrun(self): """ Make sure that a proper value is returned when a dagrun has no task instances """ - dag = DAG( - dag_id='test_get_task_instance_on_empty_dagrun', - start_date=timezone.datetime(2017, 1, 1) - ) - ShortCircuitOperator( - task_id='test_short_circuit_false', - dag=dag, - python_callable=lambda: False) + dag = DAG(dag_id='test_get_task_instance_on_empty_dagrun', start_date=timezone.datetime(2017, 1, 1)) + ShortCircuitOperator(task_id='test_short_circuit_false', dag=dag, python_callable=lambda: False) session = settings.Session() @@ -597,9 +527,7 @@ def test_get_task_instance_on_empty_dagrun(self): def test_get_latest_runs(self): session = settings.Session() - dag = DAG( - dag_id='test_latest_runs_1', - start_date=DEFAULT_DATE) + dag = DAG(dag_id='test_latest_runs_1', start_date=DEFAULT_DATE) self.create_dag_run(dag, execution_date=timezone.datetime(2015, 1, 1)) self.create_dag_run(dag, execution_date=timezone.datetime(2015, 1, 2)) dagruns = models.DagRun.get_latest_runs(session) @@ -614,11 +542,9 @@ def test_is_backfill(self): dagrun = self.create_dag_run(dag, execution_date=DEFAULT_DATE) dagrun.run_type = DagRunType.BACKFILL_JOB - dagrun2 = self.create_dag_run( - dag, execution_date=DEFAULT_DATE + datetime.timedelta(days=1)) + dagrun2 = self.create_dag_run(dag, execution_date=DEFAULT_DATE + datetime.timedelta(days=1)) - dagrun3 = self.create_dag_run( - dag, execution_date=DEFAULT_DATE + datetime.timedelta(days=2)) + dagrun3 = self.create_dag_run(dag, execution_date=DEFAULT_DATE + datetime.timedelta(days=2)) dagrun3.run_id = None self.assertTrue(dagrun.is_backfill) @@ -694,13 +620,15 @@ def mutate_task_instance(task_instance): task = dagrun.get_task_instances()[0] assert task.queue == 'queue1' - @parameterized.expand([ - (State.SUCCESS, True), - (State.SKIPPED, True), - (State.RUNNING, False), - (State.FAILED, False), - (State.NONE, False), - ]) + @parameterized.expand( + [ + (State.SUCCESS, True), + (State.SKIPPED, True), + (State.RUNNING, False), + (State.FAILED, False), + (State.NONE, False), + ] + ) def test_depends_on_past(self, prev_ti_state, is_ti_success): dag_id = 'test_depends_on_past' @@ -718,13 +646,15 @@ def test_depends_on_past(self, prev_ti_state, is_ti_success): ti.run() self.assertEqual(ti.state == State.SUCCESS, is_ti_success) - @parameterized.expand([ - (State.SUCCESS, True), - (State.SKIPPED, True), - (State.RUNNING, False), - (State.FAILED, False), - (State.NONE, False), - ]) + @parameterized.expand( + [ + (State.SUCCESS, True), + (State.SKIPPED, True), + (State.RUNNING, False), + (State.FAILED, False), + (State.NONE, False), + ] + ) def test_wait_for_downstream(self, prev_ti_state, is_ti_success): dag_id = 'test_wait_for_downstream' dag = self.dagbag.get_dag(dag_id) @@ -751,13 +681,8 @@ def test_next_dagruns_to_examine_only_unpaused(self): Check that "next_dagruns_to_examine" ignores runs from paused/inactive DAGs """ - dag = DAG( - dag_id='test_dags', - start_date=DEFAULT_DATE) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') + dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE) + DummyOperator(task_id='dummy', dag=dag, owner='airflow') 
session = settings.Session() orm_dag = DagModel( @@ -769,11 +694,13 @@ def test_next_dagruns_to_examine_only_unpaused(self): ) session.add(orm_dag) session.flush() - dr = dag.create_dagrun(run_type=DagRunType.SCHEDULED, - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) runs = DagRun.next_dagruns_to_examine(session).all() diff --git a/tests/models/test_pool.py b/tests/models/test_pool.py index 1860ed446935a..bea4b2ebcfa94 100644 --- a/tests/models/test_pool.py +++ b/tests/models/test_pool.py @@ -31,7 +31,6 @@ class TestPool(unittest.TestCase): - def setUp(self): clear_db_runs() clear_db_pools() @@ -44,7 +43,8 @@ def test_open_slots(self): pool = Pool(pool='test_pool', slots=5) dag = DAG( dag_id='test_open_slots', - start_date=DEFAULT_DATE, ) + start_date=DEFAULT_DATE, + ) op1 = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool') op2 = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool') ti1 = TI(task=op1, execution_date=DEFAULT_DATE) @@ -63,26 +63,30 @@ def test_open_slots(self): self.assertEqual(1, pool.running_slots()) # pylint: disable=no-value-for-parameter self.assertEqual(1, pool.queued_slots()) # pylint: disable=no-value-for-parameter self.assertEqual(2, pool.occupied_slots()) # pylint: disable=no-value-for-parameter - self.assertEqual({ - "default_pool": { - "open": 128, - "queued": 0, - "total": 128, - "running": 0, - }, - "test_pool": { - "open": 3, - "queued": 1, - "running": 1, - "total": 5, + self.assertEqual( + { + "default_pool": { + "open": 128, + "queued": 0, + "total": 128, + "running": 0, + }, + "test_pool": { + "open": 3, + "queued": 1, + "running": 1, + "total": 5, + }, }, - }, pool.slots_stats()) + pool.slots_stats(), + ) def test_infinite_slots(self): pool = Pool(pool='test_pool', slots=-1) dag = DAG( dag_id='test_infinite_slots', - start_date=DEFAULT_DATE, ) + start_date=DEFAULT_DATE, + ) op1 = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool') op2 = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool') ti1 = TI(task=op1, execution_date=DEFAULT_DATE) @@ -108,7 +112,8 @@ def test_default_pool_open_slots(self): dag = DAG( dag_id='test_default_pool_open_slots', - start_date=DEFAULT_DATE, ) + start_date=DEFAULT_DATE, + ) op1 = DummyOperator(task_id='dummy1', dag=dag) op2 = DummyOperator(task_id='dummy2', dag=dag, pool_slots=2) ti1 = TI(task=op1, execution_date=DEFAULT_DATE) diff --git a/tests/models/test_renderedtifields.py b/tests/models/test_renderedtifields.py index ec988a7e7227e..d2eab8d70e434 100644 --- a/tests/models/test_renderedtifields.py +++ b/tests/models/test_renderedtifields.py @@ -68,36 +68,41 @@ def setUp(self): def tearDown(self): clear_rendered_ti_fields() - @parameterized.expand([ - (None, None), - ([], []), - ({}, {}), - ("test-string", "test-string"), - ({"foo": "bar"}, {"foo": "bar"}), - ("{{ task.task_id }}", "test"), - (date(2018, 12, 6), "2018-12-06"), - (datetime(2018, 12, 6, 10, 55), "2018-12-06 10:55:00+00:00"), - ( - ClassWithCustomAttributes( - att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"]), - "ClassWithCustomAttributes({'att1': 'test', 'att2': '{{ task.task_id }}', " - "'template_fields': ['att1']})", - ), - ( - ClassWithCustomAttributes(nested1=ClassWithCustomAttributes(att1="{{ task.task_id }}", - att2="{{ task.task_id }}", - template_fields=["att1"]), - 
nested2=ClassWithCustomAttributes(att3="{{ task.task_id }}", - att4="{{ task.task_id }}", - template_fields=["att3"]), - template_fields=["nested1"]), - "ClassWithCustomAttributes({'nested1': ClassWithCustomAttributes(" - "{'att1': 'test', 'att2': '{{ task.task_id }}', 'template_fields': ['att1']}), " - "'nested2': ClassWithCustomAttributes(" - "{'att3': '{{ task.task_id }}', 'att4': '{{ task.task_id }}', 'template_fields': ['att3']}), " - "'template_fields': ['nested1']})", - ), - ]) + @parameterized.expand( + [ + (None, None), + ([], []), + ({}, {}), + ("test-string", "test-string"), + ({"foo": "bar"}, {"foo": "bar"}), + ("{{ task.task_id }}", "test"), + (date(2018, 12, 6), "2018-12-06"), + (datetime(2018, 12, 6, 10, 55), "2018-12-06 10:55:00+00:00"), + ( + ClassWithCustomAttributes( + att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"] + ), + "ClassWithCustomAttributes({'att1': 'test', 'att2': '{{ task.task_id }}', " + "'template_fields': ['att1']})", + ), + ( + ClassWithCustomAttributes( + nested1=ClassWithCustomAttributes( + att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"] + ), + nested2=ClassWithCustomAttributes( + att3="{{ task.task_id }}", att4="{{ task.task_id }}", template_fields=["att3"] + ), + template_fields=["nested1"], + ), + "ClassWithCustomAttributes({'nested1': ClassWithCustomAttributes(" + "{'att1': 'test', 'att2': '{{ task.task_id }}', 'template_fields': ['att1']}), " + "'nested2': ClassWithCustomAttributes(" + "{'att3': '{{ task.task_id }}', 'att4': '{{ task.task_id }}', 'template_fields': ['att3']}), " + "'template_fields': ['nested1']})", + ), + ] + ) def test_get_templated_fields(self, templated_field, expected_rendered_field): """ Test that template_fields are rendered correctly, stored in the Database, @@ -118,8 +123,8 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field): session.add(rtif) self.assertEqual( - {"bash_command": expected_rendered_field, "env": None}, - RTIF.get_templated_fields(ti=ti)) + {"bash_command": expected_rendered_field, "env": None}, RTIF.get_templated_fields(ti=ti) + ) # Test the else part of get_templated_fields # i.e. 
for the TIs that are not stored in RTIF table @@ -130,14 +135,16 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field): ti2 = TI(task_2, EXECUTION_DATE) self.assertIsNone(RTIF.get_templated_fields(ti=ti2)) - @parameterized.expand([ - (0, 1, 0, 1), - (1, 1, 1, 1), - (1, 0, 1, 0), - (3, 1, 1, 1), - (4, 2, 2, 1), - (5, 2, 2, 1), - ]) + @parameterized.expand( + [ + (0, 1, 0, 1), + (1, 1, 1, 1), + (1, 0, 1, 0), + (3, 1, 1, 1), + (4, 2, 2, 1), + (5, 2, 2, 1), + ] + ) def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count): """ Test that old records are deleted from rendered_task_instance_fields table @@ -156,8 +163,7 @@ def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expect session.add_all(rtif_list) session.commit() - result = session.query(RTIF)\ - .filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() + result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() for rtif in rtif_list: self.assertIn(rtif, result) @@ -167,8 +173,7 @@ def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expect # Verify old records are deleted and only 'num_to_keep' records are kept with assert_queries_count(expected_query_count): RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep) - result = session.query(RTIF) \ - .filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() + result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() self.assertEqual(remaining_rtifs, len(result)) def test_write(self): @@ -186,17 +191,16 @@ def test_write(self): rtif = RTIF(TI(task=task, execution_date=EXECUTION_DATE)) rtif.write() - result = session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields).filter( - RTIF.dag_id == rtif.dag_id, - RTIF.task_id == rtif.task_id, - RTIF.execution_date == rtif.execution_date - ).first() - self.assertEqual( - ( - 'test_write', 'test', { - 'bash_command': 'echo test_val', 'env': None - } - ), result) + result = ( + session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields) + .filter( + RTIF.dag_id == rtif.dag_id, + RTIF.task_id == rtif.task_id, + RTIF.execution_date == rtif.execution_date, + ) + .first() + ) + self.assertEqual(('test_write', 'test', {'bash_command': 'echo test_val', 'env': None}), result) # Test that overwrite saves new values to the DB Variable.delete("test_key") @@ -208,14 +212,15 @@ def test_write(self): rtif_updated = RTIF(TI(task=updated_task, execution_date=EXECUTION_DATE)) rtif_updated.write() - result_updated = session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields).filter( - RTIF.dag_id == rtif_updated.dag_id, - RTIF.task_id == rtif_updated.task_id, - RTIF.execution_date == rtif_updated.execution_date - ).first() + result_updated = ( + session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields) + .filter( + RTIF.dag_id == rtif_updated.dag_id, + RTIF.task_id == rtif_updated.task_id, + RTIF.execution_date == rtif_updated.execution_date, + ) + .first() + ) self.assertEqual( - ( - 'test_write', 'test', { - 'bash_command': 'echo test_val_updated', 'env': None - } - ), result_updated) + ('test_write', 'test', {'bash_command': 'echo test_val_updated', 'env': None}), result_updated + ) diff --git a/tests/models/test_sensorinstance.py b/tests/models/test_sensorinstance.py index 168d97f82d916..6246df8c14d98 100644 --- a/tests/models/test_sensorinstance.py +++ b/tests/models/test_sensorinstance.py @@ -24,22 
+24,20 @@ class SensorInstanceTest(unittest.TestCase):
-
     def test_get_classpath(self):
         # Test the classpath in/out airflow
-        obj1 = NamedHivePartitionSensor(
-            partition_names=['test_partition'],
-            task_id='meta_partition_test_1')
+        obj1 = NamedHivePartitionSensor(partition_names=['test_partition'], task_id='meta_partition_test_1')
 
         obj1_classpath = SensorInstance.get_classpath(obj1)
-        obj1_importpath = "airflow.providers.apache.hive." \
-                          "sensors.named_hive_partition.NamedHivePartitionSensor"
+        obj1_importpath = (
+            "airflow.providers.apache.hive.sensors.named_hive_partition.NamedHivePartitionSensor"
+        )
 
         self.assertEqual(obj1_classpath, obj1_importpath)
 
         def test_callable():
             return
-        obj3 = PythonSensor(python_callable=test_callable,
-                            task_id='python_sensor_test')
+
+        obj3 = PythonSensor(python_callable=test_callable, task_id='python_sensor_test')
         obj3_classpath = SensorInstance.get_classpath(obj3)
         obj3_importpath = "airflow.sensors.python.PythonSensor"
diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py
index a6309327f2af4..590ceed56220d 100644
--- a/tests/models/test_serialized_dag.py
+++ b/tests/models/test_serialized_dag.py
@@ -52,8 +52,7 @@ def tearDown(self):
 
     def test_dag_fileloc_hash(self):
         """Verifies the correctness of hashing file path."""
-        self.assertEqual(DagCode.dag_fileloc_hash('/airflow/dags/test_dag.py'),
-                         33826252060516589)
+        self.assertEqual(DagCode.dag_fileloc_hash('/airflow/dags/test_dag.py'), 33826252060516589)
 
     def _write_example_dags(self):
         example_dags = make_example_dags(example_dags_module)
@@ -68,8 +67,7 @@ def test_write_dag(self):
         with create_session() as session:
             for dag in example_dags.values():
                 self.assertTrue(SDM.has_dag(dag.dag_id))
-                result = session.query(
-                    SDM.fileloc, SDM.data).filter(SDM.dag_id == dag.dag_id).one()
+                result = session.query(SDM.fileloc, SDM.data).filter(SDM.dag_id == dag.dag_id).one()
                 self.assertTrue(result.fileloc == dag.full_filepath)
                 # Verifies JSON schema.
@@ -142,7 +140,9 @@ def test_remove_dags_by_filepath(self): def test_bulk_sync_to_db(self): dags = [ - DAG("dag_1"), DAG("dag_2"), DAG("dag_3"), + DAG("dag_1"), + DAG("dag_2"), + DAG("dag_3"), ] with assert_queries_count(10): SDM.bulk_sync_to_db(dags) diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py index 0a96ee23b6195..a9145701c6458 100644 --- a/tests/models/test_skipmixin.py +++ b/tests/models/test_skipmixin.py @@ -35,7 +35,6 @@ class TestSkipMixin(unittest.TestCase): - @patch('airflow.utils.timezone.utcnow') def test_skip(self, mock_now): session = settings.Session() @@ -52,11 +51,7 @@ def test_skip(self, mock_now): execution_date=now, state=State.FAILED, ) - SkipMixin().skip( - dag_run=dag_run, - execution_date=now, - tasks=tasks, - session=session) + SkipMixin().skip(dag_run=dag_run, execution_date=now, tasks=tasks, session=session) session.query(TI).filter( TI.dag_id == 'dag', @@ -77,11 +72,7 @@ def test_skip_none_dagrun(self, mock_now): ) with dag: tasks = [DummyOperator(task_id='task')] - SkipMixin().skip( - dag_run=None, - execution_date=now, - tasks=tasks, - session=session) + SkipMixin().skip(dag_run=None, execution_date=now, tasks=tasks, session=session) session.query(TI).filter( TI.dag_id == 'dag', @@ -113,10 +104,7 @@ def test_skip_all_except(self): ti2 = TI(task2, execution_date=DEFAULT_DATE) ti3 = TI(task3, execution_date=DEFAULT_DATE) - SkipMixin().skip_all_except( - ti=ti1, - branch_task_ids=['task2'] - ) + SkipMixin().skip_all_except(ti=ti1, branch_task_ids=['task2']) def get_state(ti): ti.refresh_from_db() diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index bcf4ca0346530..f91d9070d9189 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -34,7 +34,14 @@ from airflow.exceptions import AirflowException, AirflowFailException, AirflowSkipException from airflow.jobs.scheduler_job import SchedulerJob from airflow.models import ( - DAG, DagModel, DagRun, Pool, RenderedTaskInstanceFields, TaskInstance as TI, TaskReschedule, Variable, + DAG, + DagModel, + DagRun, + Pool, + RenderedTaskInstanceFields, + TaskInstance as TI, + TaskReschedule, + Variable, ) from airflow.operators.bash import BashOperator from airflow.operators.dummy_operator import DummyOperator @@ -74,15 +81,17 @@ def wrap_task_instance(self, ti): def success_handler(self, context): # pylint: disable=unused-argument self.callback_ran = True session = settings.Session() - temp_instance = session.query(TI).filter( - TI.task_id == self.task_id).filter( - TI.dag_id == self.dag_id).filter( - TI.execution_date == self.execution_date).one() + temp_instance = ( + session.query(TI) + .filter(TI.task_id == self.task_id) + .filter(TI.dag_id == self.dag_id) + .filter(TI.execution_date == self.execution_date) + .one() + ) self.task_state_in_callback = temp_instance.state class TestTaskInstance(unittest.TestCase): - @staticmethod def clean_db(): db.clear_db_dags() @@ -106,8 +115,7 @@ def test_set_task_dates(self): """ Test that tasks properly take start/end dates from DAGs """ - dag = DAG('dag', start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=10)) + dag = DAG('dag', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)) op1 = DummyOperator(task_id='op_1', owner='test') @@ -115,31 +123,29 @@ def test_set_task_dates(self): # dag should assign its dates to op1 because op1 has no dates dag.add_task(op1) - self.assertTrue( - op1.start_date == dag.start_date and op1.end_date == 
dag.end_date) + self.assertTrue(op1.start_date == dag.start_date and op1.end_date == dag.end_date) op2 = DummyOperator( task_id='op_2', owner='test', start_date=DEFAULT_DATE - datetime.timedelta(days=1), - end_date=DEFAULT_DATE + datetime.timedelta(days=11)) + end_date=DEFAULT_DATE + datetime.timedelta(days=11), + ) # dag should assign its dates to op2 because they are more restrictive dag.add_task(op2) - self.assertTrue( - op2.start_date == dag.start_date and op2.end_date == dag.end_date) + self.assertTrue(op2.start_date == dag.start_date and op2.end_date == dag.end_date) op3 = DummyOperator( task_id='op_3', owner='test', start_date=DEFAULT_DATE + datetime.timedelta(days=1), - end_date=DEFAULT_DATE + datetime.timedelta(days=9)) + end_date=DEFAULT_DATE + datetime.timedelta(days=9), + ) # op3 should keep its dates because they are more restrictive dag.add_task(op3) - self.assertTrue( - op3.start_date == DEFAULT_DATE + datetime.timedelta(days=1)) - self.assertTrue( - op3.end_date == DEFAULT_DATE + datetime.timedelta(days=9)) + self.assertTrue(op3.start_date == DEFAULT_DATE + datetime.timedelta(days=1)) + self.assertTrue(op3.end_date == DEFAULT_DATE + datetime.timedelta(days=9)) def test_timezone_awareness(self): naive_datetime = DEFAULT_DATE.replace(tzinfo=None) @@ -168,9 +174,9 @@ def test_timezone_awareness(self): def test_task_naive_datetime(self): naive_datetime = DEFAULT_DATE.replace(tzinfo=None) - op_no_dag = DummyOperator(task_id='test_task_naive_datetime', - start_date=naive_datetime, - end_date=naive_datetime) + op_no_dag = DummyOperator( + task_id='test_task_naive_datetime', start_date=naive_datetime, end_date=naive_datetime + ) self.assertTrue(op_no_dag.start_date.tzinfo) self.assertTrue(op_no_dag.end_date.tzinfo) @@ -213,9 +219,7 @@ def test_infer_dag(self): op4 = DummyOperator(task_id='test_op_4', owner='test', dag=dag2) # double check dags - self.assertEqual( - [i.has_dag() for i in [op1, op2, op3, op4]], - [False, False, True, True]) + self.assertEqual([i.has_dag() for i in [op1, op2, op3, op4]], [False, False, True, True]) # can't combine operators with no dags self.assertRaises(AirflowException, op1.set_downstream, op2) @@ -246,8 +250,12 @@ def test_bitshift_compose_operators(self): def test_requeue_over_dag_concurrency(self, mock_concurrency_reached): mock_concurrency_reached.return_value = True - dag = DAG(dag_id='test_requeue_over_dag_concurrency', start_date=DEFAULT_DATE, - max_active_runs=1, concurrency=2) + dag = DAG( + dag_id='test_requeue_over_dag_concurrency', + start_date=DEFAULT_DATE, + max_active_runs=1, + concurrency=2, + ) task = DummyOperator(task_id='test_requeue_over_dag_concurrency_op', dag=dag) ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED) @@ -259,10 +267,13 @@ def test_requeue_over_dag_concurrency(self, mock_concurrency_reached): self.assertEqual(ti.state, State.NONE) def test_requeue_over_task_concurrency(self): - dag = DAG(dag_id='test_requeue_over_task_concurrency', start_date=DEFAULT_DATE, - max_active_runs=1, concurrency=2) - task = DummyOperator(task_id='test_requeue_over_task_concurrency_op', dag=dag, - task_concurrency=0) + dag = DAG( + dag_id='test_requeue_over_task_concurrency', + start_date=DEFAULT_DATE, + max_active_runs=1, + concurrency=2, + ) + task = DummyOperator(task_id='test_requeue_over_task_concurrency_op', dag=dag, task_concurrency=0) ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED) # TI.run() will sync from DB before validating deps. 
@@ -273,10 +284,13 @@ def test_requeue_over_task_concurrency(self):
         self.assertEqual(ti.state, State.NONE)
 
     def test_requeue_over_pool_concurrency(self):
-        dag = DAG(dag_id='test_requeue_over_pool_concurrency', start_date=DEFAULT_DATE,
-                  max_active_runs=1, concurrency=2)
-        task = DummyOperator(task_id='test_requeue_over_pool_concurrency_op', dag=dag,
-                             task_concurrency=0)
+        dag = DAG(
+            dag_id='test_requeue_over_pool_concurrency',
+            start_date=DEFAULT_DATE,
+            max_active_runs=1,
+            concurrency=2,
+        )
+        task = DummyOperator(task_id='test_requeue_over_pool_concurrency_op', dag=dag, task_concurrency=0)
         ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
 
         # TI.run() will sync from DB before validating deps.
@@ -297,9 +311,9 @@ def test_not_requeue_non_requeueable_task_instance(self):
             dag=dag,
             pool='test_pool',
             owner='airflow',
-            start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
-        ti = TI(
-            task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
+            start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+        )
+        ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
         with create_session() as session:
             session.add(ti)
             session.commit()
@@ -309,16 +323,13 @@ def test_not_requeue_non_requeueable_task_instance(self):
         patch_dict = {}
         for dep in all_non_requeueable_deps:
             class_name = dep.__class__.__name__
-            dep_patch = patch('{}.{}.{}'.format(dep.__module__, class_name,
-                                                dep._get_dep_statuses.__name__))
+            dep_patch = patch(f'{dep.__module__}.{class_name}.{dep._get_dep_statuses.__name__}')
             method_patch = dep_patch.start()
-            method_patch.return_value = iter([TIDepStatus('mock_' + class_name, True,
-                                                          'mock')])
+            method_patch.return_value = iter([TIDepStatus('mock_' + class_name, True, 'mock')])
             patch_dict[class_name] = (dep_patch, method_patch)
 
         for class_name, (dep_patch, method_patch) in patch_dict.items():
-            method_patch.return_value = iter(
-                [TIDepStatus('mock_' + class_name, False, 'mock')])
+            method_patch.return_value = iter([TIDepStatus('mock_' + class_name, False, 'mock')])
             ti.run()
             self.assertEqual(ti.state, State.QUEUED)
             dep_patch.return_value = TIDepStatus('mock_' + class_name, True, 'mock')
@@ -331,17 +342,16 @@ def test_mark_non_runnable_task_as_success(self):
         test that running task with mark_success param update task state as SUCCESS
         without running task despite it fails dependency checks.
         """
-        non_runnable_state = (
-            set(State.task_states) - RUNNABLE_STATES - set(State.SUCCESS)).pop()
+        non_runnable_state = (set(State.task_states) - RUNNABLE_STATES - set(State.SUCCESS)).pop()
         dag = models.DAG(dag_id='test_mark_non_runnable_task_as_success')
         task = DummyOperator(
             task_id='test_mark_non_runnable_task_as_success_op',
             dag=dag,
             pool='test_pool',
             owner='airflow',
-            start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
-        ti = TI(
-            task=task, execution_date=timezone.utcnow(), state=non_runnable_state)
+            start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
+        )
+        ti = TI(task=task, execution_date=timezone.utcnow(), state=non_runnable_state)
         # TI.run() will sync from DB before validating deps.
         with create_session() as session:
             session.add(ti)
@@ -361,9 +371,13 @@ def test_run_pooling_task(self):
         test that running a task in an existing pool update task state as SUCCESS.
""" dag = models.DAG(dag_id='test_run_pooling_task') - task = DummyOperator(task_id='test_run_pooling_task_op', dag=dag, - pool='test_pool', owner='airflow', - start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) + task = DummyOperator( + task_id='test_run_pooling_task_op', + dag=dag, + pool='test_pool', + owner='airflow', + start_date=timezone.datetime(2016, 2, 1, 0, 0, 0), + ) ti = TI(task=task, execution_date=timezone.utcnow()) dag.create_dagrun( @@ -380,11 +394,17 @@ def test_pool_slots_property(self): """ test that try to create a task with pool_slots less than 1 """ + def create_task_instance(): dag = models.DAG(dag_id='test_run_pooling_task') - task = DummyOperator(task_id='test_run_pooling_task_op', dag=dag, - pool='test_pool', pool_slots=0, owner='airflow', - start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) + task = DummyOperator( + task_id='test_run_pooling_task_op', + dag=dag, + pool='test_pool', + pool_slots=0, + owner='airflow', + start_date=timezone.datetime(2016, 2, 1, 0, 0, 0), + ) return TI(task=task, execution_date=timezone.utcnow()) self.assertRaises(AirflowException, create_task_instance) @@ -395,9 +415,12 @@ def test_ti_updates_with_task(self, session=None): test that updating the executor_config propogates to the TaskInstance DB """ with models.DAG(dag_id='test_run_pooling_task') as dag: - task = DummyOperator(task_id='test_run_pooling_task_op', owner='airflow', - executor_config={'foo': 'bar'}, - start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) + task = DummyOperator( + task_id='test_run_pooling_task_op', + owner='airflow', + executor_config={'foo': 'bar'}, + start_date=timezone.datetime(2016, 2, 1, 0, 0, 0), + ) ti = TI(task=task, execution_date=timezone.utcnow()) dag.create_dagrun( @@ -411,9 +434,12 @@ def test_ti_updates_with_task(self, session=None): tis = dag.get_task_instances() self.assertEqual({'foo': 'bar'}, tis[0].executor_config) with models.DAG(dag_id='test_run_pooling_task') as dag: - task2 = DummyOperator(task_id='test_run_pooling_task_op', owner='airflow', - executor_config={'bar': 'baz'}, - start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) + task2 = DummyOperator( + task_id='test_run_pooling_task_op', + owner='airflow', + executor_config={'bar': 'baz'}, + start_date=timezone.datetime(2016, 2, 1, 0, 0, 0), + ) ti = TI(task=task2, execution_date=timezone.utcnow()) @@ -440,7 +466,8 @@ def test_run_pooling_task_with_mark_success(self): dag=dag, pool='test_pool', owner='airflow', - start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) + start_date=timezone.datetime(2016, 2, 1, 0, 0, 0), + ) ti = TI(task=task, execution_date=timezone.utcnow()) dag.create_dagrun( @@ -466,7 +493,8 @@ def raise_skip_exception(): dag=dag, python_callable=raise_skip_exception, owner='airflow', - start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) + start_date=timezone.datetime(2016, 2, 1, 0, 0, 0), + ) ti = TI(task=task, execution_date=timezone.utcnow()) dag.create_dagrun( execution_date=ti.execution_date, @@ -488,7 +516,8 @@ def test_retry_delay(self): retry_delay=datetime.timedelta(seconds=3), dag=dag, owner='airflow', - start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) + start_date=timezone.datetime(2016, 2, 1, 0, 0, 0), + ) def run_with_error(ti): try: @@ -532,7 +561,8 @@ def test_retry_handling(self): retry_delay=datetime.timedelta(seconds=0), dag=dag, owner='test_pool', - start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) + start_date=timezone.datetime(2016, 2, 1, 0, 0, 0), + ) def run_with_error(ti): try: @@ -540,8 +570,7 @@ def run_with_error(ti): except 
AirflowException: pass - ti = TI( - task=task, execution_date=timezone.utcnow()) + ti = TI(task=task, execution_date=timezone.utcnow()) self.assertEqual(ti.try_number, 1) # first run -- up for retry @@ -588,9 +617,9 @@ def test_next_retry_datetime(self): max_retry_delay=max_delay, dag=dag, owner='airflow', - start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( - task=task, execution_date=DEFAULT_DATE) + start_date=timezone.datetime(2016, 2, 1, 0, 0, 0), + ) + ti = TI(task=task, execution_date=DEFAULT_DATE) ti.end_date = pendulum.instance(timezone.utcnow()) date = ti.next_retry_datetime() @@ -632,9 +661,9 @@ def test_next_retry_datetime_short_intervals(self): max_retry_delay=max_delay, dag=dag, owner='airflow', - start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( - task=task, execution_date=DEFAULT_DATE) + start_date=timezone.datetime(2016, 2, 1, 0, 0, 0), + ) + ti = TI(task=task, execution_date=DEFAULT_DATE) ti.end_date = pendulum.instance(timezone.utcnow()) date = ti.next_retry_datetime() @@ -666,7 +695,8 @@ def func(): dag=dag, owner='airflow', pool='test_pool', - start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) + start_date=timezone.datetime(2016, 2, 1, 0, 0, 0), + ) ti = TI(task=task, execution_date=timezone.utcnow()) self.assertEqual(ti._try_number, 0) @@ -678,10 +708,15 @@ def func(): run_type=DagRunType.SCHEDULED, ) - def run_ti_and_assert(run_date, expected_start_date, expected_end_date, - expected_duration, - expected_state, expected_try_number, - expected_task_reschedule_count): + def run_ti_and_assert( + run_date, + expected_start_date, + expected_end_date, + expected_duration, + expected_state, + expected_try_number, + expected_task_reschedule_count, + ): with freeze_time(run_date): try: ti.run() @@ -767,16 +802,22 @@ def func(): dag=dag, owner='airflow', pool='test_pool', - start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) + start_date=timezone.datetime(2016, 2, 1, 0, 0, 0), + ) ti = TI(task=task, execution_date=timezone.utcnow()) self.assertEqual(ti._try_number, 0) self.assertEqual(ti.try_number, 1) - def run_ti_and_assert(run_date, expected_start_date, expected_end_date, - expected_duration, - expected_state, expected_try_number, - expected_task_reschedule_count): + def run_ti_and_assert( + run_date, + expected_start_date, + expected_end_date, + expected_duration, + expected_state, + expected_try_number, + expected_task_reschedule_count, + ): with freeze_time(run_date): try: ti.run() @@ -808,10 +849,7 @@ def run_ti_and_assert(run_date, expected_start_date, expected_end_date, self.assertFalse(trs) def test_depends_on_past(self): - dag = DAG( - dag_id='test_depends_on_past', - start_date=DEFAULT_DATE - ) + dag = DAG(dag_id='test_depends_on_past', start_date=DEFAULT_DATE) task = DummyOperator( task_id='test_dop_task', @@ -836,10 +874,7 @@ def test_depends_on_past(self): self.assertIs(ti.state, None) # ignore first depends_on_past to allow the run - task.run( - start_date=run_date, - end_date=run_date, - ignore_first_depends_on_past=True) + task.run(start_date=run_date, end_date=run_date, ignore_first_depends_on_past=True) ti.refresh_from_db() self.assertEqual(ti.state, State.SUCCESS) @@ -847,58 +882,64 @@ def test_depends_on_past(self): # of the trigger_rule under various circumstances # Numeric fields are in order: # successes, skipped, failed, upstream_failed, done - @parameterized.expand([ - - # - # Tests for all_success - # - ['all_success', 5, 0, 0, 0, 0, True, None, True], - ['all_success', 2, 0, 0, 0, 0, True, None, False], - ['all_success', 2, 0, 
1, 0, 0, True, State.UPSTREAM_FAILED, False], - ['all_success', 2, 1, 0, 0, 0, True, State.SKIPPED, False], - # - # Tests for one_success - # - ['one_success', 5, 0, 0, 0, 5, True, None, True], - ['one_success', 2, 0, 0, 0, 2, True, None, True], - ['one_success', 2, 0, 1, 0, 3, True, None, True], - ['one_success', 2, 1, 0, 0, 3, True, None, True], - # - # Tests for all_failed - # - ['all_failed', 5, 0, 0, 0, 5, True, State.SKIPPED, False], - ['all_failed', 0, 0, 5, 0, 5, True, None, True], - ['all_failed', 2, 0, 0, 0, 2, True, State.SKIPPED, False], - ['all_failed', 2, 0, 1, 0, 3, True, State.SKIPPED, False], - ['all_failed', 2, 1, 0, 0, 3, True, State.SKIPPED, False], - # - # Tests for one_failed - # - ['one_failed', 5, 0, 0, 0, 0, True, None, False], - ['one_failed', 2, 0, 0, 0, 0, True, None, False], - ['one_failed', 2, 0, 1, 0, 0, True, None, True], - ['one_failed', 2, 1, 0, 0, 3, True, None, False], - ['one_failed', 2, 3, 0, 0, 5, True, State.SKIPPED, False], - # - # Tests for done - # - ['all_done', 5, 0, 0, 0, 5, True, None, True], - ['all_done', 2, 0, 0, 0, 2, True, None, False], - ['all_done', 2, 0, 1, 0, 3, True, None, False], - ['all_done', 2, 1, 0, 0, 3, True, None, False] - ]) - def test_check_task_dependencies(self, trigger_rule, successes, skipped, - failed, upstream_failed, done, - flag_upstream_failed, - expect_state, expect_completed): + @parameterized.expand( + [ + # + # Tests for all_success + # + ['all_success', 5, 0, 0, 0, 0, True, None, True], + ['all_success', 2, 0, 0, 0, 0, True, None, False], + ['all_success', 2, 0, 1, 0, 0, True, State.UPSTREAM_FAILED, False], + ['all_success', 2, 1, 0, 0, 0, True, State.SKIPPED, False], + # + # Tests for one_success + # + ['one_success', 5, 0, 0, 0, 5, True, None, True], + ['one_success', 2, 0, 0, 0, 2, True, None, True], + ['one_success', 2, 0, 1, 0, 3, True, None, True], + ['one_success', 2, 1, 0, 0, 3, True, None, True], + # + # Tests for all_failed + # + ['all_failed', 5, 0, 0, 0, 5, True, State.SKIPPED, False], + ['all_failed', 0, 0, 5, 0, 5, True, None, True], + ['all_failed', 2, 0, 0, 0, 2, True, State.SKIPPED, False], + ['all_failed', 2, 0, 1, 0, 3, True, State.SKIPPED, False], + ['all_failed', 2, 1, 0, 0, 3, True, State.SKIPPED, False], + # + # Tests for one_failed + # + ['one_failed', 5, 0, 0, 0, 0, True, None, False], + ['one_failed', 2, 0, 0, 0, 0, True, None, False], + ['one_failed', 2, 0, 1, 0, 0, True, None, True], + ['one_failed', 2, 1, 0, 0, 3, True, None, False], + ['one_failed', 2, 3, 0, 0, 5, True, State.SKIPPED, False], + # + # Tests for done + # + ['all_done', 5, 0, 0, 0, 5, True, None, True], + ['all_done', 2, 0, 0, 0, 2, True, None, False], + ['all_done', 2, 0, 1, 0, 3, True, None, False], + ['all_done', 2, 1, 0, 0, 3, True, None, False], + ] + ) + def test_check_task_dependencies( + self, + trigger_rule, + successes, + skipped, + failed, + upstream_failed, + done, + flag_upstream_failed, + expect_state, + expect_completed, + ): start_date = timezone.datetime(2016, 2, 1, 0, 0, 0) dag = models.DAG('test-dag', start_date=start_date) - downstream = DummyOperator(task_id='downstream', - dag=dag, owner='airflow', - trigger_rule=trigger_rule) + downstream = DummyOperator(task_id='downstream', dag=dag, owner='airflow', trigger_rule=trigger_rule) for i in range(5): - task = DummyOperator(task_id=f'runme_{i}', - dag=dag, owner='airflow') + task = DummyOperator(task_id=f'runme_{i}', dag=dag, owner='airflow') task.set_downstream(downstream) run_date = task.start_date + datetime.timedelta(days=5) @@ -923,20 +964,24 
@@ def test_respects_prev_dagrun_dep(self): ti = TI(task, DEFAULT_DATE) failing_status = [TIDepStatus('test fail status name', False, 'test fail reason')] passing_status = [TIDepStatus('test pass status name', True, 'test passing reason')] - with patch('airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses', - return_value=failing_status): + with patch( + 'airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses', return_value=failing_status + ): self.assertFalse(ti.are_dependencies_met()) - with patch('airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses', - return_value=passing_status): + with patch( + 'airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses', return_value=passing_status + ): self.assertTrue(ti.are_dependencies_met()) - @parameterized.expand([ - (State.SUCCESS, True), - (State.SKIPPED, True), - (State.RUNNING, False), - (State.FAILED, False), - (State.NONE, False), - ]) + @parameterized.expand( + [ + (State.SUCCESS, True), + (State.SKIPPED, True), + (State.RUNNING, False), + (State.FAILED, False), + (State.NONE, False), + ] + ) def test_are_dependents_done(self, downstream_ti_state, expected_are_dependents_done): with DAG(dag_id='test_dag'): task = DummyOperator(task_id='task', start_date=DEFAULT_DATE) @@ -954,8 +999,10 @@ def test_xcom_pull(self): Test xcom_pull, using different filtering methods. """ dag = models.DAG( - dag_id='test_xcom', schedule_interval='@monthly', - start_date=timezone.datetime(2016, 6, 1, 0, 0, 0)) + dag_id='test_xcom', + schedule_interval='@monthly', + start_date=timezone.datetime(2016, 6, 1, 0, 0, 0), + ) exec_date = timezone.utcnow() @@ -982,8 +1029,7 @@ def test_xcom_pull(self): result = ti1.xcom_pull(task_ids='test_xcom_2', key='foo') self.assertEqual(result, 'baz') # Pull the values pushed by both tasks - result = ti1.xcom_pull( - task_ids=['test_xcom_1', 'test_xcom_2'], key='foo') + result = ti1.xcom_pull(task_ids=['test_xcom_1', 'test_xcom_2'], key='foo') self.assertEqual(result, ['baz', 'bar']) def test_xcom_pull_after_success(self): @@ -999,7 +1045,8 @@ def test_xcom_pull_after_success(self): dag=dag, pool='test_xcom', owner='airflow', - start_date=timezone.datetime(2016, 6, 2, 0, 0, 0)) + start_date=timezone.datetime(2016, 6, 2, 0, 0, 0), + ) exec_date = timezone.utcnow() ti = TI(task=task, execution_date=exec_date) @@ -1039,7 +1086,8 @@ def test_xcom_pull_different_execution_date(self): dag=dag, pool='test_xcom', owner='airflow', - start_date=timezone.datetime(2016, 6, 2, 0, 0, 0)) + start_date=timezone.datetime(2016, 6, 2, 0, 0, 0), + ) exec_date = timezone.utcnow() ti = TI(task=task, execution_date=exec_date) @@ -1054,18 +1102,14 @@ def test_xcom_pull_different_execution_date(self): self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value) ti.run() exec_date += datetime.timedelta(days=1) - ti = TI( - task=task, execution_date=exec_date) + ti = TI(task=task, execution_date=exec_date) ti.run() # We have set a new execution date (and did not pass in # 'include_prior_dates'which means this task should now have a cleared # xcom value self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), None) # We *should* get a value using 'include_prior_dates' - self.assertEqual(ti.xcom_pull(task_ids='test_xcom', - key=key, - include_prior_dates=True), - value) + self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key, include_prior_dates=True), value) def test_xcom_push_flag(self): """ @@ -1082,7 +1126,7 @@ def test_xcom_push_flag(self): python_callable=lambda: value, 
do_xcom_push=False, owner='airflow', - start_date=datetime.datetime(2017, 1, 1) + start_date=datetime.datetime(2017, 1, 1), ) ti = TI(task=task, execution_date=datetime.datetime(2017, 1, 1)) dag.create_dagrun( @@ -1091,12 +1135,7 @@ def test_xcom_push_flag(self): run_type=DagRunType.SCHEDULED, ) ti.run() - self.assertEqual( - ti.xcom_pull( - task_ids=task_id, key=models.XCOM_RETURN_KEY - ), - None - ) + self.assertEqual(ti.xcom_pull(task_ids=task_id, key=models.XCOM_RETURN_KEY), None) def test_post_execute_hook(self): """ @@ -1118,7 +1157,8 @@ def post_execute(self, context, result=None): dag=dag, python_callable=lambda: 'error', owner='airflow', - start_date=timezone.datetime(2017, 2, 1)) + start_date=timezone.datetime(2017, 2, 1), + ) ti = TI(task=task, execution_date=timezone.utcnow()) with self.assertRaises(TestError): @@ -1127,8 +1167,7 @@ def post_execute(self, context, result=None): def test_check_and_change_state_before_execution(self): dag = models.DAG(dag_id='test_check_and_change_state_before_execution') task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE) - ti = TI( - task=task, execution_date=timezone.utcnow()) + ti = TI(task=task, execution_date=timezone.utcnow()) self.assertEqual(ti._try_number, 0) self.assertTrue(ti.check_and_change_state_before_execution()) # State should be running, and try_number column should be incremented @@ -1140,8 +1179,7 @@ def test_check_and_change_state_before_execution_dep_not_met(self): task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE) task2 = DummyOperator(task_id='task2', dag=dag, start_date=DEFAULT_DATE) task >> task2 - ti = TI( - task=task2, execution_date=timezone.utcnow()) + ti = TI(task=task2, execution_date=timezone.utcnow()) self.assertFalse(ti.check_and_change_state_before_execution()) def test_try_number(self): @@ -1212,8 +1250,8 @@ def test_mark_success_url(self): task = DummyOperator(task_id='op', dag=dag) ti = TI(task=task, execution_date=now) query = urllib.parse.parse_qs( - urllib.parse.urlparse(ti.mark_success_url).query, - keep_blank_values=True, strict_parsing=True) + urllib.parse.urlparse(ti.mark_success_url).query, keep_blank_values=True, strict_parsing=True + ) self.assertEqual(query['dag_id'][0], 'dag') self.assertEqual(query['task_id'][0], 'op') self.assertEqual(pendulum.parse(query['execution_date'][0]), now) @@ -1252,11 +1290,8 @@ def test_overwrite_params_with_dag_run_conf_none(self): def test_email_alert(self, mock_send_email): dag = models.DAG(dag_id='test_failure_email') task = BashOperator( - task_id='test_email_alert', - dag=dag, - bash_command='exit 1', - start_date=DEFAULT_DATE, - email='to') + task_id='test_email_alert', dag=dag, bash_command='exit 1', start_date=DEFAULT_DATE, email='to' + ) ti = TI(task=task, execution_date=timezone.utcnow()) @@ -1271,10 +1306,12 @@ def test_email_alert(self, mock_send_email): self.assertIn('test_email_alert', body) self.assertIn('Try 1', body) - @conf_vars({ - ('email', 'subject_template'): '/subject/path', - ('email', 'html_content_template'): '/html_content/path', - }) + @conf_vars( + { + ('email', 'subject_template'): '/subject/path', + ('email', 'html_content_template'): '/html_content/path', + } + ) @patch('airflow.models.taskinstance.send_email') def test_email_alert_with_config(self, mock_send_email): dag = models.DAG(dag_id='test_failure_email') @@ -1283,7 +1320,8 @@ def test_email_alert_with_config(self, mock_send_email): dag=dag, bash_command='exit 1', start_date=DEFAULT_DATE, - email='to') + email='to', + ) ti = TI(task=task, 
execution_date=timezone.utcnow()) @@ -1318,10 +1356,17 @@ def test_set_duration_empty_dates(self): def test_success_callback_no_race_condition(self): callback_wrapper = CallbackWrapper() - dag = DAG('test_success_callback_no_race_condition', start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=10)) - task = DummyOperator(task_id='op', email='test@test.test', - on_success_callback=callback_wrapper.success_handler, dag=dag) + dag = DAG( + 'test_success_callback_no_race_condition', + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + ) + task = DummyOperator( + task_id='op', + email='test@test.test', + on_success_callback=callback_wrapper.success_handler, + dag=dag, + ) ti = TI(task=task, execution_date=datetime.datetime.now()) ti.state = State.RUNNING session = settings.Session() @@ -1343,8 +1388,9 @@ def test_success_callback_no_race_condition(self): self.assertEqual(ti.state, State.SUCCESS) @staticmethod - def _test_previous_dates_setup(schedule_interval: Union[str, datetime.timedelta, None], - catchup: bool, scenario: List[str]) -> list: + def _test_previous_dates_setup( + schedule_interval: Union[str, datetime.timedelta, None], catchup: bool, scenario: List[str] + ) -> list: dag_id = 'test_previous_dates' dag = models.DAG(dag_id=dag_id, schedule_interval=schedule_interval, catchup=catchup) task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE) @@ -1355,7 +1401,7 @@ def get_test_ti(session, execution_date: pendulum.DateTime, state: str) -> TI: state=state, execution_date=execution_date, start_date=pendulum.now('UTC'), - session=session + session=session, ) ti = TI(task=task, execution_date=execution_date) ti.set_state(state=State.SUCCESS, session=session) @@ -1392,15 +1438,9 @@ def test_previous_ti(self, _, schedule_interval, catchup) -> None: self.assertIsNone(ti_list[0].get_previous_ti()) - self.assertEqual( - ti_list[2].get_previous_ti().execution_date, - ti_list[1].execution_date - ) + self.assertEqual(ti_list[2].get_previous_ti().execution_date, ti_list[1].execution_date) - self.assertNotEqual( - ti_list[2].get_previous_ti().execution_date, - ti_list[0].execution_date - ) + self.assertNotEqual(ti_list[2].get_previous_ti().execution_date, ti_list[0].execution_date) @parameterized.expand(_prev_dates_param_list) def test_previous_ti_success(self, _, schedule_interval, catchup) -> None: @@ -1413,13 +1453,11 @@ def test_previous_ti_success(self, _, schedule_interval, catchup) -> None: self.assertIsNone(ti_list[1].get_previous_ti(state=State.SUCCESS)) self.assertEqual( - ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date, - ti_list[1].execution_date + ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date, ti_list[1].execution_date ) self.assertNotEqual( - ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date, - ti_list[2].execution_date + ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date, ti_list[2].execution_date ) @parameterized.expand(_prev_dates_param_list) @@ -1432,12 +1470,10 @@ def test_previous_execution_date_success(self, _, schedule_interval, catchup) -> self.assertIsNone(ti_list[0].get_previous_execution_date(state=State.SUCCESS)) self.assertIsNone(ti_list[1].get_previous_execution_date(state=State.SUCCESS)) self.assertEqual( - ti_list[3].get_previous_execution_date(state=State.SUCCESS), - ti_list[1].execution_date + ti_list[3].get_previous_execution_date(state=State.SUCCESS), ti_list[1].execution_date ) self.assertNotEqual( - 
ti_list[3].get_previous_execution_date(state=State.SUCCESS), - ti_list[2].execution_date + ti_list[3].get_previous_execution_date(state=State.SUCCESS), ti_list[2].execution_date ) @parameterized.expand(_prev_dates_param_list) @@ -1460,8 +1496,10 @@ def test_previous_start_date_success(self, _, schedule_interval, catchup) -> Non def test_pendulum_template_dates(self): dag = models.DAG( - dag_id='test_pendulum_template_dates', schedule_interval='0 12 * * *', - start_date=timezone.datetime(2016, 6, 1, 0, 0, 0)) + dag_id='test_pendulum_template_dates', + schedule_interval='0 12 * * *', + start_date=timezone.datetime(2016, 6, 1, 0, 0, 0), + ) task = DummyOperator(task_id='test_pendulum_template_dates_task', dag=dag) ti = TI(task=task, execution_date=timezone.utcnow()) @@ -1544,16 +1582,16 @@ def test_execute_callback(self): def on_execute_callable(context): nonlocal called called = True - self.assertEqual( - context['dag_run'].dag_id, - 'test_dagrun_execute_callback' - ) + self.assertEqual(context['dag_run'].dag_id, 'test_dagrun_execute_callback') - dag = DAG('test_execute_callback', start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=10)) - task = DummyOperator(task_id='op', email='test@test.test', - on_execute_callback=on_execute_callable, - dag=dag) + dag = DAG( + 'test_execute_callback', + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + ) + task = DummyOperator( + task_id='op', email='test@test.test', on_execute_callback=on_execute_callable, dag=dag + ) ti = TI(task=task, execution_date=datetime.datetime.now()) ti.state = State.RUNNING session = settings.Session() @@ -1580,10 +1618,12 @@ def test_handle_failure(self): mock_on_failure_1 = mock.MagicMock() mock_on_retry_1 = mock.MagicMock() - task1 = DummyOperator(task_id="test_handle_failure_on_failure", - on_failure_callback=mock_on_failure_1, - on_retry_callback=mock_on_retry_1, - dag=dag) + task1 = DummyOperator( + task_id="test_handle_failure_on_failure", + on_failure_callback=mock_on_failure_1, + on_retry_callback=mock_on_retry_1, + dag=dag, + ) ti1 = TI(task=task1, execution_date=start_date) ti1.state = State.FAILED ti1.handle_failure("test failure handling") @@ -1594,11 +1634,13 @@ def test_handle_failure(self): mock_on_failure_2 = mock.MagicMock() mock_on_retry_2 = mock.MagicMock() - task2 = DummyOperator(task_id="test_handle_failure_on_retry", - on_failure_callback=mock_on_failure_2, - on_retry_callback=mock_on_retry_2, - retries=1, - dag=dag) + task2 = DummyOperator( + task_id="test_handle_failure_on_retry", + on_failure_callback=mock_on_failure_2, + on_retry_callback=mock_on_retry_2, + retries=1, + dag=dag, + ) ti2 = TI(task=task2, execution_date=start_date) ti2.state = State.FAILED ti2.handle_failure("test retry handling") @@ -1611,11 +1653,13 @@ def test_handle_failure(self): # test the scenario where normally we would retry but have been asked to fail mock_on_failure_3 = mock.MagicMock() mock_on_retry_3 = mock.MagicMock() - task3 = DummyOperator(task_id="test_handle_failure_on_force_fail", - on_failure_callback=mock_on_failure_3, - on_retry_callback=mock_on_retry_3, - retries=1, - dag=dag) + task3 = DummyOperator( + task_id="test_handle_failure_on_force_fail", + on_failure_callback=mock_on_failure_3, + on_retry_callback=mock_on_retry_3, + retries=1, + dag=dag, + ) ti3 = TI(task=task3, execution_date=start_date) ti3.state = State.FAILED ti3.handle_failure("test force_fail handling", force_fail=True) @@ -1635,7 +1679,7 @@ def fail(): python_callable=fail, owner='airflow', 
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0), - retries=1 + retries=1, ) ti = TI(task=task, execution_date=timezone.utcnow()) try: @@ -1655,7 +1699,7 @@ def fail(): python_callable=fail, owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0), - retries=1 + retries=1, ) ti = TI(task=task, execution_date=timezone.utcnow()) try: @@ -1667,23 +1711,27 @@ def fail(): def _env_var_check_callback(self): self.assertEqual('test_echo_env_variables', os.environ['AIRFLOW_CTX_DAG_ID']) self.assertEqual('hive_in_python_op', os.environ['AIRFLOW_CTX_TASK_ID']) - self.assertEqual(DEFAULT_DATE.isoformat(), - os.environ['AIRFLOW_CTX_EXECUTION_DATE']) - self.assertEqual(DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE), - os.environ['AIRFLOW_CTX_DAG_RUN_ID']) + self.assertEqual(DEFAULT_DATE.isoformat(), os.environ['AIRFLOW_CTX_EXECUTION_DATE']) + self.assertEqual( + DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE), os.environ['AIRFLOW_CTX_DAG_RUN_ID'] + ) def test_echo_env_variables(self): - dag = DAG('test_echo_env_variables', start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=10)) - op = PythonOperator(task_id='hive_in_python_op', - dag=dag, - python_callable=self._env_var_check_callback) + dag = DAG( + 'test_echo_env_variables', + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + ) + op = PythonOperator( + task_id='hive_in_python_op', dag=dag, python_callable=self._env_var_check_callback + ) dag.create_dagrun( run_type=DagRunType.MANUAL, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE, state=State.RUNNING, - external_trigger=False) + external_trigger=False, + ) ti = TI(task=op, execution_date=DEFAULT_DATE) ti.state = State.RUNNING session = settings.Session() @@ -1695,15 +1743,19 @@ def test_echo_env_variables(self): @patch.object(Stats, 'incr') def test_task_stats(self, stats_mock): - dag = DAG('test_task_start_end_stats', start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=10)) + dag = DAG( + 'test_task_start_end_stats', + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + ) op = DummyOperator(task_id='dummy_op', dag=dag) dag.create_dagrun( run_id='manual__' + DEFAULT_DATE.isoformat(), execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE, state=State.RUNNING, - external_trigger=False) + external_trigger=False, + ) ti = TI(task=op, execution_date=DEFAULT_DATE) ti.state = State.RUNNING session = settings.Session() @@ -1725,10 +1777,18 @@ def test_generate_command_default_param(self): def test_generate_command_specific_param(self): dag_id = 'test_generate_command_specific_param' task_id = 'task' - assert_command = ['airflow', 'tasks', 'run', dag_id, - task_id, DEFAULT_DATE.isoformat(), '--mark-success'] - generate_command = TI.generate_command(dag_id=dag_id, task_id=task_id, - execution_date=DEFAULT_DATE, mark_success=True) + assert_command = [ + 'airflow', + 'tasks', + 'run', + dag_id, + task_id, + DEFAULT_DATE.isoformat(), + '--mark-success', + ] + generate_command = TI.generate_command( + dag_id=dag_id, task_id=task_id, execution_date=DEFAULT_DATE, mark_success=True + ) assert assert_command == generate_command def test_get_rendered_template_fields(self): @@ -1759,50 +1819,49 @@ def validate_ti_states(self, dag_run, ti_state_mapping, error_message): task_instance = dag_run.get_task_instance(task_id=task_id) self.assertEqual(task_instance.state, expected_state, error_message) - @parameterized.expand([ - ( - {('scheduler', 'schedule_after_task_execution'): 'True'}, 
- {'A': 'B', 'B': 'C'}, - {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE}, - {'A': State.SUCCESS, 'B': State.SCHEDULED, 'C': State.NONE}, - {'A': State.SUCCESS, 'B': State.SUCCESS, 'C': State.SCHEDULED}, - "A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. Same for B and C." - ), - ( - {('scheduler', 'schedule_after_task_execution'): 'False'}, - {'A': 'B', 'B': 'C'}, - {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE}, - {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE}, - None, - "A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED." - ), - ( - {('scheduler', 'schedule_after_task_execution'): 'True'}, - {'A': 'B', 'C': 'B', 'D': 'C'}, - {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE}, - {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE}, - None, - "D -> C -> B & A -> B, when A runs but C isn't QUEUED yet, B shouldn't be QUEUED." - ), - ( - {('scheduler', 'schedule_after_task_execution'): 'True'}, - {'A': 'C', 'B': 'C'}, - {'A': State.QUEUED, 'B': State.FAILED, 'C': State.NONE}, - {'A': State.SUCCESS, 'B': State.FAILED, 'C': State.UPSTREAM_FAILED}, - None, - "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED." - ), - ]) + @parameterized.expand( + [ + ( + {('scheduler', 'schedule_after_task_execution'): 'True'}, + {'A': 'B', 'B': 'C'}, + {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE}, + {'A': State.SUCCESS, 'B': State.SCHEDULED, 'C': State.NONE}, + {'A': State.SUCCESS, 'B': State.SUCCESS, 'C': State.SCHEDULED}, + "A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. Same for B and C.", + ), + ( + {('scheduler', 'schedule_after_task_execution'): 'False'}, + {'A': 'B', 'B': 'C'}, + {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE}, + {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE}, + None, + "A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED.", + ), + ( + {('scheduler', 'schedule_after_task_execution'): 'True'}, + {'A': 'B', 'C': 'B', 'D': 'C'}, + {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE}, + {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE}, + None, + "D -> C -> B & A -> B, when A runs but C isn't QUEUED yet, B shouldn't be QUEUED.", + ), + ( + {('scheduler', 'schedule_after_task_execution'): 'True'}, + {'A': 'C', 'B': 'C'}, + {'A': State.QUEUED, 'B': State.FAILED, 'C': State.NONE}, + {'A': State.SUCCESS, 'B': State.FAILED, 'C': State.UPSTREAM_FAILED}, + None, + "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED.", + ), + ] + ) def test_fast_follow( self, conf, dependencies, init_state, first_run_state, second_run_state, error_message ): with conf_vars(conf): session = settings.Session() - dag = DAG( - 'test_dagrun_fast_follow', - start_date=DEFAULT_DATE - ) + dag = DAG('test_dagrun_fast_follow', start_date=DEFAULT_DATE) dag_model = DagModel( dag_id=dag.dag_id, @@ -1857,9 +1916,16 @@ def test_fast_follow( @pytest.mark.parametrize("pool_override", [None, "test_pool2"]) def test_refresh_from_task(pool_override): - task = DummyOperator(task_id="dummy", queue="test_queue", pool="test_pool1", pool_slots=3, - priority_weight=10, run_as_user="test", retries=30, - executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}}) + task = DummyOperator( + task_id="dummy", + queue="test_queue", + pool="test_pool1", + pool_slots=3, + priority_weight=10, + run_as_user="test", + retries=30, + executor_config={"KubernetesExecutor": {"image": 
"myCustomDockerImage"}}, + ) ti = TI(task, execution_date=pendulum.datetime(2020, 1, 1)) ti.refresh_from_task(task, pool_override=pool_override) @@ -1898,11 +1964,13 @@ def setUp(self) -> None: def tearDown(self) -> None: self._clean() - @parameterized.expand([ - # Expected queries, mark_success - (10, False), - (5, True), - ]) + @parameterized.expand( + [ + # Expected queries, mark_success + (10, False), + (5, True), + ] + ) def test_execute_queries_count(self, expected_query_count, mark_success): with create_session() as session: dag = DAG('test_queries', start_date=DEFAULT_DATE) diff --git a/tests/models/test_timestamp.py b/tests/models/test_timestamp.py index 979fb1d5b6415..8424a687ec94e 100644 --- a/tests/models/test_timestamp.py +++ b/tests/models/test_timestamp.py @@ -38,8 +38,7 @@ def clear_db(session=None): def add_log(execdate, session, timezone_override=None): - dag = DAG(dag_id='logging', - default_args={'start_date': execdate}) + dag = DAG(dag_id='logging', default_args={'start_date': execdate}) task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') task_instance = TaskInstance(task=task, execution_date=execdate, state='success') session.merge(task_instance) diff --git a/tests/models/test_variable.py b/tests/models/test_variable.py index 4b507e616cd1b..54074dd3a6ea3 100644 --- a/tests/models/test_variable.py +++ b/tests/models/test_variable.py @@ -58,10 +58,7 @@ def test_variable_with_encryption(self): self.assertTrue(test_var.is_encrypted) self.assertEqual(test_var.val, 'value') - @parameterized.expand([ - 'value', - '' - ]) + @parameterized.expand(['value', '']) def test_var_with_encryption_rotate_fernet_key(self, test_value): """ Tests rotating encrypted variables. @@ -106,8 +103,7 @@ def test_variable_set_existing_value_to_blank(self): def test_get_non_existing_var_should_return_default(self): default_value = "some default val" - self.assertEqual(default_value, Variable.get("thisIdDoesNotExist", - default_var=default_value)) + self.assertEqual(default_value, Variable.get("thisIdDoesNotExist", default_var=default_value)) def test_get_non_existing_var_should_raise_key_error(self): with self.assertRaises(KeyError): @@ -118,9 +114,10 @@ def test_get_non_existing_var_with_none_default_should_return_none(self): def test_get_non_existing_var_should_not_deserialize_json_default(self): default_value = "}{ this is a non JSON default }{" - self.assertEqual(default_value, Variable.get("thisIdDoesNotExist", - default_var=default_value, - deserialize_json=True)) + self.assertEqual( + default_value, + Variable.get("thisIdDoesNotExist", default_var=default_value, deserialize_json=True), + ) def test_variable_setdefault_round_trip(self): key = "tested_var_setdefault_1_id" diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py index 39586d1638468..9793a26faf6f8 100644 --- a/tests/models/test_xcom.py +++ b/tests/models/test_xcom.py @@ -32,7 +32,6 @@ def serialize_value(_): class TestXCom(unittest.TestCase): - def setUp(self) -> None: db.clear_db_xcom() @@ -45,9 +44,7 @@ def test_resolve_xcom_class(self): assert issubclass(cls, CustomXCom) assert cls().serialize_value(None) == "custom_value" - @conf_vars( - {("core", "xcom_backend"): "", ("core", "enable_xcom_pickling"): "False"} - ) + @conf_vars({("core", "xcom_backend"): "", ("core", "enable_xcom_pickling"): "False"}) def test_resolve_xcom_class_fallback_to_basexcom(self): cls = resolve_xcom_backend() assert issubclass(cls, BaseXCom) @@ -68,24 +65,28 @@ def test_xcom_disable_pickle_type(self): key = "xcom_test1" dag_id 
= "test_dag1" task_id = "test_task1" - XCom.set(key=key, - value=json_obj, - dag_id=dag_id, - task_id=task_id, - execution_date=execution_date) + XCom.set(key=key, value=json_obj, dag_id=dag_id, task_id=task_id, execution_date=execution_date) - ret_value = XCom.get_many(key=key, - dag_ids=dag_id, - task_ids=task_id, - execution_date=execution_date).first().value + ret_value = ( + XCom.get_many(key=key, dag_ids=dag_id, task_ids=task_id, execution_date=execution_date) + .first() + .value + ) self.assertEqual(ret_value, json_obj) session = settings.Session() - ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id, - XCom.task_id == task_id, - XCom.execution_date == execution_date - ).first().value + ret_value = ( + session.query(XCom) + .filter( + XCom.key == key, + XCom.dag_id == dag_id, + XCom.task_id == task_id, + XCom.execution_date == execution_date, + ) + .first() + .value + ) self.assertEqual(ret_value, json_obj) @@ -96,24 +97,24 @@ def test_xcom_get_one_disable_pickle_type(self): key = "xcom_test1" dag_id = "test_dag1" task_id = "test_task1" - XCom.set(key=key, - value=json_obj, - dag_id=dag_id, - task_id=task_id, - execution_date=execution_date) + XCom.set(key=key, value=json_obj, dag_id=dag_id, task_id=task_id, execution_date=execution_date) - ret_value = XCom.get_one(key=key, - dag_id=dag_id, - task_id=task_id, - execution_date=execution_date) + ret_value = XCom.get_one(key=key, dag_id=dag_id, task_id=task_id, execution_date=execution_date) self.assertEqual(ret_value, json_obj) session = settings.Session() - ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id, - XCom.task_id == task_id, - XCom.execution_date == execution_date - ).first().value + ret_value = ( + session.query(XCom) + .filter( + XCom.key == key, + XCom.dag_id == dag_id, + XCom.task_id == task_id, + XCom.execution_date == execution_date, + ) + .first() + .value + ) self.assertEqual(ret_value, json_obj) @@ -124,24 +125,28 @@ def test_xcom_enable_pickle_type(self): key = "xcom_test2" dag_id = "test_dag2" task_id = "test_task2" - XCom.set(key=key, - value=json_obj, - dag_id=dag_id, - task_id=task_id, - execution_date=execution_date) + XCom.set(key=key, value=json_obj, dag_id=dag_id, task_id=task_id, execution_date=execution_date) - ret_value = XCom.get_many(key=key, - dag_ids=dag_id, - task_ids=task_id, - execution_date=execution_date).first().value + ret_value = ( + XCom.get_many(key=key, dag_ids=dag_id, task_ids=task_id, execution_date=execution_date) + .first() + .value + ) self.assertEqual(ret_value, json_obj) session = settings.Session() - ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id, - XCom.task_id == task_id, - XCom.execution_date == execution_date - ).first().value + ret_value = ( + session.query(XCom) + .filter( + XCom.key == key, + XCom.dag_id == dag_id, + XCom.task_id == task_id, + XCom.execution_date == execution_date, + ) + .first() + .value + ) self.assertEqual(ret_value, json_obj) @@ -152,24 +157,24 @@ def test_xcom_get_one_enable_pickle_type(self): key = "xcom_test3" dag_id = "test_dag" task_id = "test_task3" - XCom.set(key=key, - value=json_obj, - dag_id=dag_id, - task_id=task_id, - execution_date=execution_date) + XCom.set(key=key, value=json_obj, dag_id=dag_id, task_id=task_id, execution_date=execution_date) - ret_value = XCom.get_one(key=key, - dag_id=dag_id, - task_id=task_id, - execution_date=execution_date) + ret_value = XCom.get_one(key=key, dag_id=dag_id, task_id=task_id, execution_date=execution_date) 
self.assertEqual(ret_value, json_obj) session = settings.Session() - ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id, - XCom.task_id == task_id, - XCom.execution_date == execution_date - ).first().value + ret_value = ( + session.query(XCom) + .filter( + XCom.key == key, + XCom.dag_id == dag_id, + XCom.task_id == task_id, + XCom.execution_date == execution_date, + ) + .first() + .value + ) self.assertEqual(ret_value, json_obj) @@ -179,12 +184,15 @@ class PickleRce: def __reduce__(self): return os.system, ("ls -alt",) - self.assertRaises(TypeError, XCom.set, - key="xcom_test3", - value=PickleRce(), - dag_id="test_dag3", - task_id="test_task3", - execution_date=timezone.utcnow()) + self.assertRaises( + TypeError, + XCom.set, + key="xcom_test3", + value=PickleRce(), + dag_id="test_dag3", + task_id="test_task3", + execution_date=timezone.utcnow(), + ) @conf_vars({("core", "xcom_enable_pickling"): "True"}) def test_xcom_get_many(self): @@ -196,17 +204,9 @@ def test_xcom_get_many(self): dag_id2 = "test_dag5" task_id2 = "test_task5" - XCom.set(key=key, - value=json_obj, - dag_id=dag_id1, - task_id=task_id1, - execution_date=execution_date) - - XCom.set(key=key, - value=json_obj, - dag_id=dag_id2, - task_id=task_id2, - execution_date=execution_date) + XCom.set(key=key, value=json_obj, dag_id=dag_id1, task_id=task_id1, execution_date=execution_date) + + XCom.set(key=key, value=json_obj, dag_id=dag_id2, task_id=task_id2, execution_date=execution_date) results = XCom.get_many(key=key, execution_date=execution_date) diff --git a/tests/models/test_xcom_arg.py b/tests/models/test_xcom_arg.py index 2f38ed72bdafc..6c08de069e0da 100644 --- a/tests/models/test_xcom_arg.py +++ b/tests/models/test_xcom_arg.py @@ -62,17 +62,21 @@ def test_xcom_ctor(self): assert actual.key == "test_key" # Asserting the overridden __eq__ method assert actual == XComArg(python_op, "test_key") - assert str(actual) == "task_instance.xcom_pull(" \ - "task_ids=\'test_xcom_op\', " \ - "dag_id=\'test_xcom_dag\', " \ - "key=\'test_key\')" + assert ( + str(actual) == "task_instance.xcom_pull(" + "task_ids=\'test_xcom_op\', " + "dag_id=\'test_xcom_dag\', " + "key=\'test_key\')" + ) def test_xcom_key_is_empty_str(self): python_op = build_python_op() actual = XComArg(python_op, key="") assert actual.key == "" - assert str(actual) == "task_instance.xcom_pull(task_ids='test_xcom_op', " \ - "dag_id='test_xcom_dag', key='')" + assert ( + str(actual) == "task_instance.xcom_pull(task_ids='test_xcom_op', " + "dag_id='test_xcom_dag', key='')" + ) def test_set_downstream(self): with DAG("test_set_downstream", default_args=DEFAULT_ARGS): diff --git a/tests/operators/test_bash.py b/tests/operators/test_bash.py index 22dd60051affa..9c0cca18d817e 100644 --- a/tests/operators/test_bash.py +++ b/tests/operators/test_bash.py @@ -36,7 +36,6 @@ class TestBashOperator(unittest.TestCase): - def test_echo_env_variables(self): """ Test that env variables are exported correctly to the @@ -46,13 +45,11 @@ def test_echo_env_variables(self): now = now.replace(tzinfo=timezone.utc) dag = DAG( - dag_id='bash_op_test', default_args={ - 'owner': 'airflow', - 'retries': 100, - 'start_date': DEFAULT_DATE - }, + dag_id='bash_op_test', + default_args={'owner': 'airflow', 'retries': 100, 'start_date': DEFAULT_DATE}, schedule_interval='@daily', - dagrun_timeout=timedelta(minutes=60)) + dagrun_timeout=timedelta(minutes=60), + ) dag.create_dagrun( run_type=DagRunType.MANUAL, @@ -67,19 +64,17 @@ def test_echo_env_variables(self): 
task_id='echo_env_vars', dag=dag, bash_command='echo $AIRFLOW_HOME>> {0};' - 'echo $PYTHONPATH>> {0};' - 'echo $AIRFLOW_CTX_DAG_ID >> {0};' - 'echo $AIRFLOW_CTX_TASK_ID>> {0};' - 'echo $AIRFLOW_CTX_EXECUTION_DATE>> {0};' - 'echo $AIRFLOW_CTX_DAG_RUN_ID>> {0};'.format(tmp_file.name) + 'echo $PYTHONPATH>> {0};' + 'echo $AIRFLOW_CTX_DAG_ID >> {0};' + 'echo $AIRFLOW_CTX_TASK_ID>> {0};' + 'echo $AIRFLOW_CTX_EXECUTION_DATE>> {0};' + 'echo $AIRFLOW_CTX_DAG_RUN_ID>> {0};'.format(tmp_file.name), ) - with mock.patch.dict('os.environ', { - 'AIRFLOW_HOME': 'MY_PATH_TO_AIRFLOW_HOME', - 'PYTHONPATH': 'AWESOME_PYTHONPATH' - }): - task.run(DEFAULT_DATE, DEFAULT_DATE, - ignore_first_depends_on_past=True, ignore_ti_state=True) + with mock.patch.dict( + 'os.environ', {'AIRFLOW_HOME': 'MY_PATH_TO_AIRFLOW_HOME', 'PYTHONPATH': 'AWESOME_PYTHONPATH'} + ): + task.run(DEFAULT_DATE, DEFAULT_DATE, ignore_first_depends_on_past=True, ignore_ti_state=True) with open(tmp_file.name) as file: output = ''.join(file.readlines()) @@ -92,60 +87,44 @@ def test_echo_env_variables(self): self.assertIn(DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE), output) def test_return_value(self): - bash_operator = BashOperator( - bash_command='echo "stdout"', - task_id='test_return_value', - dag=None - ) + bash_operator = BashOperator(bash_command='echo "stdout"', task_id='test_return_value', dag=None) return_value = bash_operator.execute(context={}) self.assertEqual(return_value, 'stdout') def test_raise_exception_on_non_zero_exit_code(self): - bash_operator = BashOperator( - bash_command='exit 42', - task_id='test_return_value', - dag=None - ) + bash_operator = BashOperator(bash_command='exit 42', task_id='test_return_value', dag=None) with self.assertRaisesRegex( - AirflowException, - "Bash command failed\\. The command returned a non-zero exit code\\." + AirflowException, "Bash command failed\\. The command returned a non-zero exit code\\." 
): bash_operator.execute(context={}) def test_task_retries(self): bash_operator = BashOperator( - bash_command='echo "stdout"', - task_id='test_task_retries', - retries=2, - dag=None + bash_command='echo "stdout"', task_id='test_task_retries', retries=2, dag=None ) self.assertEqual(bash_operator.retries, 2) def test_default_retries(self): - bash_operator = BashOperator( - bash_command='echo "stdout"', - task_id='test_default_retries', - dag=None - ) + bash_operator = BashOperator(bash_command='echo "stdout"', task_id='test_default_retries', dag=None) self.assertEqual(bash_operator.retries, 0) @mock.patch.dict('os.environ', clear=True) - @mock.patch("airflow.operators.bash.TemporaryDirectory", **{ # type: ignore - 'return_value.__enter__.return_value': '/tmp/airflowtmpcatcat' - }) - @mock.patch("airflow.operators.bash.Popen", **{ # type: ignore - 'return_value.stdout.readline.side_effect': [b'BAR', b'BAZ'], - 'return_value.returncode': 0 - }) + @mock.patch( + "airflow.operators.bash.TemporaryDirectory", + **{'return_value.__enter__.return_value': '/tmp/airflowtmpcatcat'}, # type: ignore + ) + @mock.patch( + "airflow.operators.bash.Popen", + **{ # type: ignore + 'return_value.stdout.readline.side_effect': [b'BAR', b'BAZ'], + 'return_value.returncode': 0, + }, + ) def test_should_exec_subprocess(self, mock_popen, mock_temporary_directory): - bash_operator = BashOperator( - bash_command='echo "stdout"', - task_id='test_return_value', - dag=None - ) + bash_operator = BashOperator(bash_command='echo "stdout"', task_id='test_return_value', dag=None) bash_operator.execute({}) mock_popen.assert_called_once_with( @@ -154,5 +133,5 @@ def test_should_exec_subprocess(self, mock_popen, mock_temporary_directory): env={}, preexec_fn=mock.ANY, stderr=STDOUT, - stdout=PIPE + stdout=PIPE, ) diff --git a/tests/operators/test_branch_operator.py b/tests/operators/test_branch_operator.py index 9553a725581dc..f97f0dab8b0cf 100644 --- a/tests/operators/test_branch_operator.py +++ b/tests/operators/test_branch_operator.py @@ -51,11 +51,11 @@ def setUpClass(cls): session.query(TI).delete() def setUp(self): - self.dag = DAG('branch_operator_test', - default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE}, - schedule_interval=INTERVAL) + self.dag = DAG( + 'branch_operator_test', + default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE}, + schedule_interval=INTERVAL, + ) self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag) self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag) @@ -79,10 +79,7 @@ def test_without_dag_run(self): self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) with create_session() as session: - tis = session.query(TI).filter( - TI.dag_id == self.dag.dag_id, - TI.execution_date == DEFAULT_DATE - ) + tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE) for ti in tis: if ti.task_id == 'make_choice': @@ -107,10 +104,7 @@ def test_branch_list_without_dag_run(self): self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) with create_session() as session: - tis = session.query(TI).filter( - TI.dag_id == self.dag.dag_id, - TI.execution_date == DEFAULT_DATE - ) + tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE) expected = { "make_choice": State.SUCCESS, @@ -135,7 +129,7 @@ def test_with_dag_run(self): run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) 
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -161,7 +155,7 @@ def test_with_skip_in_branch_downstream_dependencies(self): run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) diff --git a/tests/operators/test_dagrun_operator.py b/tests/operators/test_dagrun_operator.py index 804305b6204d0..0ef41587932df 100644 --- a/tests/operators/test_dagrun_operator.py +++ b/tests/operators/test_dagrun_operator.py @@ -149,11 +149,13 @@ def test_trigger_dagrun_operator_templated_conf(self): def test_trigger_dagrun_with_reset_dag_run_false(self): """Test TriggerDagRunOperator with reset_dag_run.""" execution_date = DEFAULT_DATE - task = TriggerDagRunOperator(task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - execution_date=execution_date, - reset_dag_run=False, - dag=self.dag) + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + execution_date=execution_date, + reset_dag_run=False, + dag=self.dag, + ) task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True) with self.assertRaises(DagRunAlreadyExists): @@ -162,11 +164,13 @@ def test_trigger_dagrun_with_reset_dag_run_false(self): def test_trigger_dagrun_with_reset_dag_run_true(self): """Test TriggerDagRunOperator with reset_dag_run.""" execution_date = DEFAULT_DATE - task = TriggerDagRunOperator(task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - execution_date=execution_date, - reset_dag_run=True, - dag=self.dag) + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + execution_date=execution_date, + reset_dag_run=True, + dag=self.dag, + ) task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True) task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True) diff --git a/tests/operators/test_email.py b/tests/operators/test_email.py index 867c1116a9370..c02b04ff494ea 100644 --- a/tests/operators/test_email.py +++ b/tests/operators/test_email.py @@ -34,15 +34,13 @@ class TestEmailOperator(unittest.TestCase): - def setUp(self): super().setUp() self.dag = DAG( 'test_dag', - default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE}, - schedule_interval=INTERVAL) + default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE}, + schedule_interval=INTERVAL, + ) self.addCleanup(self.dag.clear) def _run_as_operator(self, **kwargs): @@ -52,12 +50,11 @@ def _run_as_operator(self, **kwargs): html_content='The quick brown fox jumps over the lazy dog', task_id='task', dag=self.dag, - **kwargs) + **kwargs, + ) task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) def test_execute(self): - with conf_vars( - {('email', 'email_backend'): 'tests.operators.test_email.send_email_test'} - ): + with conf_vars({('email', 'email_backend'): 'tests.operators.test_email.send_email_test'}): self._run_as_operator() assert send_email_test.call_count == 1 diff --git a/tests/operators/test_generic_transfer.py b/tests/operators/test_generic_transfer.py index 6ad670615b55b..68b5f8f251a2a 100644 --- a/tests/operators/test_generic_transfer.py +++ b/tests/operators/test_generic_transfer.py @@ -37,10 +37,7 @@ @pytest.mark.backend("mysql") class TestMySql(unittest.TestCase): def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} dag = DAG(TEST_DAG_ID, default_args=args) 
self.dag = dag @@ -50,7 +47,12 @@ def tearDown(self): for table in drop_tables: conn.execute(f"DROP TABLE IF EXISTS {table}") - @parameterized.expand([("mysqlclient",), ("mysql-connector-python",), ]) + @parameterized.expand( + [ + ("mysqlclient",), + ("mysql-connector-python",), + ] + ) def test_mysql_to_mysql(self, client): with MySqlContext(client): sql = "SELECT * FROM INFORMATION_SCHEMA.TABLES LIMIT 100;" @@ -58,14 +60,14 @@ def test_mysql_to_mysql(self, client): task_id='test_m2m', preoperator=[ "DROP TABLE IF EXISTS test_mysql_to_mysql", - "CREATE TABLE IF NOT EXISTS " - "test_mysql_to_mysql LIKE INFORMATION_SCHEMA.TABLES" + "CREATE TABLE IF NOT EXISTS test_mysql_to_mysql LIKE INFORMATION_SCHEMA.TABLES", ], source_conn_id='airflow_db', destination_conn_id='airflow_db', destination_table="test_mysql_to_mysql", sql=sql, - dag=self.dag) + dag=self.dag, + ) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -89,12 +91,12 @@ def test_postgres_to_postgres(self): task_id='test_p2p', preoperator=[ "DROP TABLE IF EXISTS test_postgres_to_postgres", - "CREATE TABLE IF NOT EXISTS " - "test_postgres_to_postgres (LIKE INFORMATION_SCHEMA.TABLES)" + "CREATE TABLE IF NOT EXISTS test_postgres_to_postgres (LIKE INFORMATION_SCHEMA.TABLES)", ], source_conn_id='postgres_default', destination_conn_id='postgres_default', destination_table="test_postgres_to_postgres", sql=sql, - dag=self.dag) + dag=self.dag, + ) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) diff --git a/tests/operators/test_latest_only_operator.py b/tests/operators/test_latest_only_operator.py index b3987564b3d8d..a514d642c6808 100644 --- a/tests/operators/test_latest_only_operator.py +++ b/tests/operators/test_latest_only_operator.py @@ -40,23 +40,22 @@ def get_task_instances(task_id): session = settings.Session() - return session \ - .query(TaskInstance) \ - .filter(TaskInstance.task_id == task_id) \ - .order_by(TaskInstance.execution_date) \ + return ( + session.query(TaskInstance) + .filter(TaskInstance.task_id == task_id) + .order_by(TaskInstance.execution_date) .all() + ) class TestLatestOnlyOperator(unittest.TestCase): - def setUp(self): super().setUp() self.dag = DAG( 'test_dag', - default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE}, - schedule_interval=INTERVAL) + default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE}, + schedule_interval=INTERVAL, + ) with create_session() as session: session.query(DagRun).delete() session.query(TaskInstance).delete() @@ -65,25 +64,16 @@ def setUp(self): self.addCleanup(freezer.stop) def test_run(self): - task = LatestOnlyOperator( - task_id='latest', - dag=self.dag) + task = LatestOnlyOperator(task_id='latest', dag=self.dag) task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) def test_skipping_non_latest(self): - latest_task = LatestOnlyOperator( - task_id='latest', - dag=self.dag) - downstream_task = DummyOperator( - task_id='downstream', - dag=self.dag) - downstream_task2 = DummyOperator( - task_id='downstream_2', - dag=self.dag) + latest_task = LatestOnlyOperator(task_id='latest', dag=self.dag) + downstream_task = DummyOperator(task_id='downstream', dag=self.dag) + downstream_task2 = DummyOperator(task_id='downstream_2', dag=self.dag) downstream_task3 = DummyOperator( - task_id='downstream_3', - trigger_rule=TriggerRule.NONE_FAILED, - dag=self.dag) + task_id='downstream_3', trigger_rule=TriggerRule.NONE_FAILED, dag=self.dag + ) downstream_task.set_upstream(latest_task) 
downstream_task2.set_upstream(downstream_task) @@ -116,51 +106,53 @@ def test_skipping_non_latest(self): downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE) latest_instances = get_task_instances('latest') - exec_date_to_latest_state = { - ti.execution_date: ti.state for ti in latest_instances} - self.assertEqual({ - timezone.datetime(2016, 1, 1): 'success', - timezone.datetime(2016, 1, 1, 12): 'success', - timezone.datetime(2016, 1, 2): 'success'}, - exec_date_to_latest_state) + exec_date_to_latest_state = {ti.execution_date: ti.state for ti in latest_instances} + self.assertEqual( + { + timezone.datetime(2016, 1, 1): 'success', + timezone.datetime(2016, 1, 1, 12): 'success', + timezone.datetime(2016, 1, 2): 'success', + }, + exec_date_to_latest_state, + ) downstream_instances = get_task_instances('downstream') - exec_date_to_downstream_state = { - ti.execution_date: ti.state for ti in downstream_instances} - self.assertEqual({ - timezone.datetime(2016, 1, 1): 'skipped', - timezone.datetime(2016, 1, 1, 12): 'skipped', - timezone.datetime(2016, 1, 2): 'success'}, - exec_date_to_downstream_state) + exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances} + self.assertEqual( + { + timezone.datetime(2016, 1, 1): 'skipped', + timezone.datetime(2016, 1, 1, 12): 'skipped', + timezone.datetime(2016, 1, 2): 'success', + }, + exec_date_to_downstream_state, + ) downstream_instances = get_task_instances('downstream_2') - exec_date_to_downstream_state = { - ti.execution_date: ti.state for ti in downstream_instances} - self.assertEqual({ - timezone.datetime(2016, 1, 1): None, - timezone.datetime(2016, 1, 1, 12): None, - timezone.datetime(2016, 1, 2): 'success'}, - exec_date_to_downstream_state) + exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances} + self.assertEqual( + { + timezone.datetime(2016, 1, 1): None, + timezone.datetime(2016, 1, 1, 12): None, + timezone.datetime(2016, 1, 2): 'success', + }, + exec_date_to_downstream_state, + ) downstream_instances = get_task_instances('downstream_3') - exec_date_to_downstream_state = { - ti.execution_date: ti.state for ti in downstream_instances} - self.assertEqual({ - timezone.datetime(2016, 1, 1): 'success', - timezone.datetime(2016, 1, 1, 12): 'success', - timezone.datetime(2016, 1, 2): 'success'}, - exec_date_to_downstream_state) + exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances} + self.assertEqual( + { + timezone.datetime(2016, 1, 1): 'success', + timezone.datetime(2016, 1, 1, 12): 'success', + timezone.datetime(2016, 1, 2): 'success', + }, + exec_date_to_downstream_state, + ) def test_not_skipping_external(self): - latest_task = LatestOnlyOperator( - task_id='latest', - dag=self.dag) - downstream_task = DummyOperator( - task_id='downstream', - dag=self.dag) - downstream_task2 = DummyOperator( - task_id='downstream_2', - dag=self.dag) + latest_task = LatestOnlyOperator(task_id='latest', dag=self.dag) + downstream_task = DummyOperator(task_id='downstream', dag=self.dag) + downstream_task2 = DummyOperator(task_id='downstream_2', dag=self.dag) downstream_task.set_upstream(latest_task) downstream_task2.set_upstream(downstream_task) @@ -194,28 +186,34 @@ def test_not_skipping_external(self): downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE) latest_instances = get_task_instances('latest') - exec_date_to_latest_state = { - ti.execution_date: ti.state for ti in latest_instances} - self.assertEqual({ - 
timezone.datetime(2016, 1, 1): 'success', - timezone.datetime(2016, 1, 1, 12): 'success', - timezone.datetime(2016, 1, 2): 'success'}, - exec_date_to_latest_state) + exec_date_to_latest_state = {ti.execution_date: ti.state for ti in latest_instances} + self.assertEqual( + { + timezone.datetime(2016, 1, 1): 'success', + timezone.datetime(2016, 1, 1, 12): 'success', + timezone.datetime(2016, 1, 2): 'success', + }, + exec_date_to_latest_state, + ) downstream_instances = get_task_instances('downstream') - exec_date_to_downstream_state = { - ti.execution_date: ti.state for ti in downstream_instances} - self.assertEqual({ - timezone.datetime(2016, 1, 1): 'success', - timezone.datetime(2016, 1, 1, 12): 'success', - timezone.datetime(2016, 1, 2): 'success'}, - exec_date_to_downstream_state) + exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances} + self.assertEqual( + { + timezone.datetime(2016, 1, 1): 'success', + timezone.datetime(2016, 1, 1, 12): 'success', + timezone.datetime(2016, 1, 2): 'success', + }, + exec_date_to_downstream_state, + ) downstream_instances = get_task_instances('downstream_2') - exec_date_to_downstream_state = { - ti.execution_date: ti.state for ti in downstream_instances} - self.assertEqual({ - timezone.datetime(2016, 1, 1): 'success', - timezone.datetime(2016, 1, 1, 12): 'success', - timezone.datetime(2016, 1, 2): 'success'}, - exec_date_to_downstream_state) + exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances} + self.assertEqual( + { + timezone.datetime(2016, 1, 1): 'success', + timezone.datetime(2016, 1, 1, 12): 'success', + timezone.datetime(2016, 1, 2): 'success', + }, + exec_date_to_downstream_state, + ) diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py index f804f06e855ad..f1bb085d80334 100644 --- a/tests/operators/test_python.py +++ b/tests/operators/test_python.py @@ -34,7 +34,11 @@ from airflow.models.xcom_arg import XComArg from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python import ( - BranchPythonOperator, PythonOperator, PythonVirtualenvOperator, ShortCircuitOperator, get_current_context, + BranchPythonOperator, + PythonOperator, + PythonVirtualenvOperator, + ShortCircuitOperator, + get_current_context, task as task_decorator, ) from airflow.utils import timezone @@ -49,10 +53,12 @@ INTERVAL = timedelta(hours=12) FROZEN_NOW = timezone.datetime(2016, 1, 2, 12, 1, 1) -TI_CONTEXT_ENV_VARS = ['AIRFLOW_CTX_DAG_ID', - 'AIRFLOW_CTX_TASK_ID', - 'AIRFLOW_CTX_EXECUTION_DATE', - 'AIRFLOW_CTX_DAG_RUN_ID'] +TI_CONTEXT_ENV_VARS = [ + 'AIRFLOW_CTX_DAG_ID', + 'AIRFLOW_CTX_TASK_ID', + 'AIRFLOW_CTX_EXECUTION_DATE', + 'AIRFLOW_CTX_DAG_RUN_ID', +] class Call: @@ -88,11 +94,7 @@ def setUpClass(cls): def setUp(self): super().setUp() - self.dag = DAG( - 'test_dag', - default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE}) + self.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE}) self.addCleanup(self.dag.clear) self.clear_run() self.addCleanup(self.clear_run) @@ -113,21 +115,12 @@ def _assert_calls_equal(self, first, second): self.assertTupleEqual(first.args, second.args) # eliminate context (conf, dag_run, task_instance, etc.) 
test_args = ["an_int", "a_date", "a_templated_string"] - first.kwargs = { - key: value - for (key, value) in first.kwargs.items() - if key in test_args - } - second.kwargs = { - key: value - for (key, value) in second.kwargs.items() - if key in test_args - } + first.kwargs = {key: value for (key, value) in first.kwargs.items() if key in test_args} + second.kwargs = {key: value for (key, value) in second.kwargs.items() if key in test_args} self.assertDictEqual(first.kwargs, second.kwargs) class TestPythonOperator(TestPythonBase): - def do_run(self): self.run = True @@ -136,10 +129,7 @@ def is_run(self): def test_python_operator_run(self): """Tests that the python callable is invoked on task run.""" - task = PythonOperator( - python_callable=self.do_run, - task_id='python_operator', - dag=self.dag) + task = PythonOperator(python_callable=self.do_run, task_id='python_operator', dag=self.dag) self.assertFalse(self.is_run()) task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) self.assertTrue(self.is_run()) @@ -149,16 +139,10 @@ def test_python_operator_python_callable_is_callable(self): the python_callable argument is callable.""" not_callable = {} with self.assertRaises(AirflowException): - PythonOperator( - python_callable=not_callable, - task_id='python_operator', - dag=self.dag) + PythonOperator(python_callable=not_callable, task_id='python_operator', dag=self.dag) not_callable = None with self.assertRaises(AirflowException): - PythonOperator( - python_callable=not_callable, - task_id='python_operator', - dag=self.dag) + PythonOperator(python_callable=not_callable, task_id='python_operator', dag=self.dag) def test_python_callable_arguments_are_templatized(self): """Test PythonOperator op_args are templatized""" @@ -174,19 +158,15 @@ def test_python_callable_arguments_are_templatized(self): # a Mock instance cannot be used as a callable function or test fails with a # TypeError: Object of type Mock is not JSON serializable python_callable=build_recording_function(recorded_calls), - op_args=[ - 4, - date(2019, 1, 1), - "dag {{dag.dag_id}} ran on {{ds}}.", - named_tuple - ], - dag=self.dag) + op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple], + dag=self.dag, + ) self.dag.create_dagrun( run_type=DagRunType.MANUAL, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -194,10 +174,12 @@ def test_python_callable_arguments_are_templatized(self): self.assertEqual(1, len(recorded_calls)) self._assert_calls_equal( recorded_calls[0], - Call(4, - date(2019, 1, 1), - f"dag {self.dag.dag_id} ran on {ds_templated}.", - Named(ds_templated, 'unchanged')) + Call( + 4, + date(2019, 1, 1), + f"dag {self.dag.dag_id} ran on {ds_templated}.", + Named(ds_templated, 'unchanged'), + ), ) def test_python_callable_keyword_arguments_are_templatized(self): @@ -212,25 +194,29 @@ def test_python_callable_keyword_arguments_are_templatized(self): op_kwargs={ 'an_int': 4, 'a_date': date(2019, 1, 1), - 'a_templated_string': "dag {{dag.dag_id}} ran on {{ds}}." 
+ 'a_templated_string': "dag {{dag.dag_id}} ran on {{ds}}.", }, - dag=self.dag) + dag=self.dag, + ) self.dag.create_dagrun( run_type=DagRunType.MANUAL, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) self.assertEqual(1, len(recorded_calls)) self._assert_calls_equal( recorded_calls[0], - Call(an_int=4, - a_date=date(2019, 1, 1), - a_templated_string="dag {} ran on {}.".format( - self.dag.dag_id, DEFAULT_DATE.date().isoformat())) + Call( + an_int=4, + a_date=date(2019, 1, 1), + a_templated_string="dag {} ran on {}.".format( + self.dag.dag_id, DEFAULT_DATE.date().isoformat() + ), + ), ) def test_python_operator_shallow_copy_attr(self): @@ -239,15 +225,15 @@ def test_python_operator_shallow_copy_attr(self): python_callable=not_callable, task_id='python_operator', op_kwargs={'certain_attrs': ''}, - dag=self.dag + dag=self.dag, ) new_task = copy.deepcopy(original_task) # shallow copy op_kwargs - self.assertEqual(id(original_task.op_kwargs['certain_attrs']), - id(new_task.op_kwargs['certain_attrs'])) + self.assertEqual( + id(original_task.op_kwargs['certain_attrs']), id(new_task.op_kwargs['certain_attrs']) + ) # shallow copy python_callable - self.assertEqual(id(original_task.python_callable), - id(new_task.python_callable)) + self.assertEqual(id(original_task.python_callable), id(new_task.python_callable)) def test_conflicting_kwargs(self): self.dag.create_dagrun( @@ -265,10 +251,7 @@ def func(dag): raise RuntimeError(f"Should not be triggered, dag: {dag}") python_operator = PythonOperator( - task_id='python_operator', - op_args=[1], - python_callable=func, - dag=self.dag + task_id='python_operator', op_args=[1], python_callable=func, dag=self.dag ) with self.assertRaises(ValueError) as context: @@ -296,7 +279,7 @@ def func(custom, dag): op_kwargs={'custom': 1}, python_callable=func, provide_context=True, - dag=self.dag + dag=self.dag, ) python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -314,10 +297,7 @@ def func(custom, dag): self.assertIsNotNone(dag, "dag should be set") python_operator = PythonOperator( - task_id='python_operator', - op_kwargs={'custom': 1}, - python_callable=func, - dag=self.dag + task_id='python_operator', op_kwargs={'custom': 1}, python_callable=func, dag=self.dag ) python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -335,16 +315,12 @@ def func(**context): self.assertGreater(len(context), 0, "Context has not been injected") python_operator = PythonOperator( - task_id='python_operator', - op_kwargs={'custom': 1}, - python_callable=func, - dag=self.dag + task_id='python_operator', op_kwargs={'custom': 1}, python_callable=func, dag=self.dag ) python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) class TestAirflowTaskDecorator(TestPythonBase): - def test_python_operator_python_callable_is_callable(self): """Tests that @task will only instantiate if the python_callable argument is callable.""" @@ -369,6 +345,7 @@ def test_fail_method(self): """Tests that @task will fail if signature is not binding.""" with pytest.raises(AirflowException): + class Test: num = 2 @@ -389,7 +366,7 @@ def add_number(num: int): run_id=DagRunType.MANUAL, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) with pytest.raises(AirflowException): @@ -407,7 +384,7 @@ def add_number(num: int): run_id=DagRunType.MANUAL, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE, - state=State.RUNNING + 
state=State.RUNNING, ) with pytest.raises(AirflowException): @@ -427,14 +404,15 @@ def test_python_callable_arguments_are_templatized(self): # a Mock instance cannot be used as a callable function or test fails with a # TypeError: Object of type Mock is not JSON serializable build_recording_function(recorded_calls), - dag=self.dag) + dag=self.dag, + ) ret = task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple) self.dag.create_dagrun( run_id=DagRunType.MANUAL, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) # pylint: disable=maybe-no-member @@ -442,10 +420,12 @@ def test_python_callable_arguments_are_templatized(self): assert len(recorded_calls) == 1 self._assert_calls_equal( recorded_calls[0], - Call(4, - date(2019, 1, 1), - f"dag {self.dag.dag_id} ran on {ds_templated}.", - Named(ds_templated, 'unchanged')) + Call( + 4, + date(2019, 1, 1), + f"dag {self.dag.dag_id} ran on {ds_templated}.", + Named(ds_templated, 'unchanged'), + ), ) def test_python_callable_keyword_arguments_are_templatized(self): @@ -456,24 +436,27 @@ def test_python_callable_keyword_arguments_are_templatized(self): # a Mock instance cannot be used as a callable function or test fails with a # TypeError: Object of type Mock is not JSON serializable build_recording_function(recorded_calls), - dag=self.dag + dag=self.dag, ) ret = task(an_int=4, a_date=date(2019, 1, 1), a_templated_string="dag {{dag.dag_id}} ran on {{ds}}.") self.dag.create_dagrun( run_id=DagRunType.MANUAL, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) # pylint: disable=maybe-no-member assert len(recorded_calls) == 1 self._assert_calls_equal( recorded_calls[0], - Call(an_int=4, - a_date=date(2019, 1, 1), - a_templated_string="dag {} ran on {}.".format( - self.dag.dag_id, DEFAULT_DATE.date().isoformat())) + Call( + an_int=4, + a_date=date(2019, 1, 1), + a_templated_string="dag {} ran on {}.".format( + self.dag.dag_id, DEFAULT_DATE.date().isoformat() + ), + ), ) def test_manual_task_id(self): @@ -523,10 +506,7 @@ def test_multiple_outputs(self): @task_decorator(multiple_outputs=True) def return_dict(number: int): - return { - 'number': number + 1, - '43': 43 - } + return {'number': number + 1, '43': 43} test_number = 10 with self.dag: @@ -536,7 +516,7 @@ def return_dict(number: int): run_id=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) # pylint: disable=maybe-no-member @@ -578,7 +558,7 @@ def add_num(number: int, num2: int = 2): run_id=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) bigger_number.operator.run( # pylint: disable=maybe-no-member @@ -655,11 +635,11 @@ def setUpClass(cls): session.query(TI).delete() def setUp(self): - self.dag = DAG('branch_operator_test', - default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE}, - schedule_interval=INTERVAL) + self.dag = DAG( + 'branch_operator_test', + default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE}, + schedule_interval=INTERVAL, + ) self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag) self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag) @@ -674,9 +654,9 @@ def tearDown(self): def 
test_without_dag_run(self): """This checks the defensive against non existent tasks in a dag run""" - branch_op = BranchPythonOperator(task_id='make_choice', - dag=self.dag, - python_callable=lambda: 'branch_1') + branch_op = BranchPythonOperator( + task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1' + ) self.branch_1.set_upstream(branch_op) self.branch_2.set_upstream(branch_op) self.dag.clear() @@ -684,10 +664,7 @@ def test_without_dag_run(self): branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) with create_session() as session: - tis = session.query(TI).filter( - TI.dag_id == self.dag.dag_id, - TI.execution_date == DEFAULT_DATE - ) + tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE) for ti in tis: if ti.task_id == 'make_choice': @@ -702,9 +679,9 @@ def test_without_dag_run(self): def test_branch_list_without_dag_run(self): """This checks if the BranchPythonOperator supports branching off to a list of tasks.""" - branch_op = BranchPythonOperator(task_id='make_choice', - dag=self.dag, - python_callable=lambda: ['branch_1', 'branch_2']) + branch_op = BranchPythonOperator( + task_id='make_choice', dag=self.dag, python_callable=lambda: ['branch_1', 'branch_2'] + ) self.branch_1.set_upstream(branch_op) self.branch_2.set_upstream(branch_op) self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag) @@ -714,10 +691,7 @@ def test_branch_list_without_dag_run(self): branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) with create_session() as session: - tis = session.query(TI).filter( - TI.dag_id == self.dag.dag_id, - TI.execution_date == DEFAULT_DATE - ) + tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE) expected = { "make_choice": State.SUCCESS, @@ -733,9 +707,9 @@ def test_branch_list_without_dag_run(self): raise ValueError(f'Invalid task id {ti.task_id} found!') def test_with_dag_run(self): - branch_op = BranchPythonOperator(task_id='make_choice', - dag=self.dag, - python_callable=lambda: 'branch_1') + branch_op = BranchPythonOperator( + task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1' + ) self.branch_1.set_upstream(branch_op) self.branch_2.set_upstream(branch_op) @@ -745,7 +719,7 @@ def test_with_dag_run(self): run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -762,9 +736,9 @@ def test_with_dag_run(self): raise ValueError(f'Invalid task id {ti.task_id} found!') def test_with_skip_in_branch_downstream_dependencies(self): - branch_op = BranchPythonOperator(task_id='make_choice', - dag=self.dag, - python_callable=lambda: 'branch_1') + branch_op = BranchPythonOperator( + task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1' + ) branch_op >> self.branch_1 >> self.branch_2 branch_op >> self.branch_2 @@ -774,7 +748,7 @@ def test_with_skip_in_branch_downstream_dependencies(self): run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -791,9 +765,9 @@ def test_with_skip_in_branch_downstream_dependencies(self): raise ValueError(f'Invalid task id {ti.task_id} found!') def test_with_skip_in_branch_downstream_dependencies2(self): - branch_op = BranchPythonOperator(task_id='make_choice', - dag=self.dag, - python_callable=lambda: 'branch_2') + 
branch_op = BranchPythonOperator( + task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_2' + ) branch_op >> self.branch_1 >> self.branch_2 branch_op >> self.branch_2 @@ -803,7 +777,7 @@ def test_with_skip_in_branch_downstream_dependencies2(self): run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -820,9 +794,9 @@ def test_with_skip_in_branch_downstream_dependencies2(self): raise ValueError(f'Invalid task id {ti.task_id} found!') def test_xcom_push(self): - branch_op = BranchPythonOperator(task_id='make_choice', - dag=self.dag, - python_callable=lambda: 'branch_1') + branch_op = BranchPythonOperator( + task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1' + ) self.branch_1.set_upstream(branch_op) self.branch_2.set_upstream(branch_op) @@ -832,7 +806,7 @@ def test_xcom_push(self): run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -840,17 +814,16 @@ def test_xcom_push(self): tis = dr.get_task_instances() for ti in tis: if ti.task_id == 'make_choice': - self.assertEqual( - ti.xcom_pull(task_ids='make_choice'), 'branch_1') + self.assertEqual(ti.xcom_pull(task_ids='make_choice'), 'branch_1') def test_clear_skipped_downstream_task(self): """ After a downstream task is skipped by BranchPythonOperator, clearing the skipped task should not cause it to be executed. """ - branch_op = BranchPythonOperator(task_id='make_choice', - dag=self.dag, - python_callable=lambda: 'branch_1') + branch_op = BranchPythonOperator( + task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1' + ) branches = [self.branch_1, self.branch_2] branch_op >> branches self.dag.clear() @@ -859,7 +832,7 @@ def test_clear_skipped_downstream_task(self): run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -919,15 +892,12 @@ def tearDown(self): def test_without_dag_run(self): """This checks the defensive against non existent tasks in a dag run""" value = False - dag = DAG('shortcircuit_operator_test_without_dag_run', - default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - }, - schedule_interval=INTERVAL) - short_op = ShortCircuitOperator(task_id='make_choice', - dag=dag, - python_callable=lambda: value) + dag = DAG( + 'shortcircuit_operator_test_without_dag_run', + default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE}, + schedule_interval=INTERVAL, + ) + short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value) branch_1 = DummyOperator(task_id='branch_1', dag=dag) branch_1.set_upstream(short_op) branch_2 = DummyOperator(task_id='branch_2', dag=dag) @@ -939,10 +909,7 @@ def test_without_dag_run(self): short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) with create_session() as session: - tis = session.query(TI).filter( - TI.dag_id == dag.dag_id, - TI.execution_date == DEFAULT_DATE - ) + tis = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.execution_date == DEFAULT_DATE) for ti in tis: if ti.task_id == 'make_choice': @@ -972,15 +939,12 @@ def test_without_dag_run(self): def test_with_dag_run(self): value = False - dag = DAG('shortcircuit_operator_test_with_dag_run', - 
default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - }, - schedule_interval=INTERVAL) - short_op = ShortCircuitOperator(task_id='make_choice', - dag=dag, - python_callable=lambda: value) + dag = DAG( + 'shortcircuit_operator_test_with_dag_run', + default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE}, + schedule_interval=INTERVAL, + ) + short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value) branch_1 = DummyOperator(task_id='branch_1', dag=dag) branch_1.set_upstream(short_op) branch_2 = DummyOperator(task_id='branch_2', dag=dag) @@ -994,7 +958,7 @@ def test_with_dag_run(self): run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -1035,15 +999,12 @@ def test_clear_skipped_downstream_task(self): After a downstream task is skipped by ShortCircuitOperator, clearing the skipped task should not cause it to be executed. """ - dag = DAG('shortcircuit_clear_skipped_downstream_task', - default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - }, - schedule_interval=INTERVAL) - short_op = ShortCircuitOperator(task_id='make_choice', - dag=dag, - python_callable=lambda: False) + dag = DAG( + 'shortcircuit_clear_skipped_downstream_task', + default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE}, + schedule_interval=INTERVAL, + ) + short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: False) downstream = DummyOperator(task_id='downstream', dag=dag) short_op >> downstream @@ -1054,7 +1015,7 @@ def test_clear_skipped_downstream_task(self): run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -1072,9 +1033,7 @@ def test_clear_skipped_downstream_task(self): # Clear downstream with create_session() as session: - clear_task_instances([t for t in tis if t.task_id == "downstream"], - session=session, - dag=dag) + clear_task_instances([t for t in tis if t.task_id == "downstream"], session=session, dag=dag) # Run downstream again downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -1093,24 +1052,19 @@ def test_clear_skipped_downstream_task(self): class TestPythonVirtualenvOperator(unittest.TestCase): - def setUp(self): super().setUp() self.dag = DAG( 'test_dag', - default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE}, - schedule_interval=INTERVAL) + default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE}, + schedule_interval=INTERVAL, + ) self.addCleanup(self.dag.clear) def _run_as_operator(self, fn, python_version=sys.version_info[0], **kwargs): task = PythonVirtualenvOperator( - python_callable=fn, - python_version=python_version, - task_id='task', - dag=self.dag, - **kwargs) + python_callable=fn, python_version=python_version, task_id='task', dag=self.dag, **kwargs + ) task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) return task @@ -1146,11 +1100,11 @@ def f(): self._run_as_operator(f, requirements=['funcsigs'], system_site_packages=True) def test_with_requirements_pinned(self): - self.assertNotEqual( - '0.4', funcsigs.__version__, 'Please update this string if this fails') + self.assertNotEqual('0.4', funcsigs.__version__, 'Please update this string if this fails') def f(): import funcsigs # noqa: F401 # pylint: disable=redefined-outer-name,reimported + if funcsigs.__version__ 
!= '0.4': raise Exception @@ -1160,15 +1114,13 @@ def test_unpinned_requirements(self): def f(): import funcsigs # noqa: F401 # pylint: disable=redefined-outer-name,reimported,unused-import - self._run_as_operator( - f, requirements=['funcsigs', 'dill'], system_site_packages=False) + self._run_as_operator(f, requirements=['funcsigs', 'dill'], system_site_packages=False) def test_range_requirements(self): def f(): import funcsigs # noqa: F401 # pylint: disable=redefined-outer-name,reimported,unused-import - self._run_as_operator( - f, requirements=['funcsigs>1.0', 'dill'], system_site_packages=False) + self._run_as_operator(f, requirements=['funcsigs>1.0', 'dill'], system_site_packages=False) def test_fail(self): def f(): @@ -1193,6 +1145,7 @@ def f(): def test_python_3(self): def f(): import sys # pylint: disable=reimported,unused-import,redefined-outer-name + print(sys.version) try: {}.iteritems() # pylint: disable=no-member @@ -1234,8 +1187,7 @@ def f(): if virtualenv_string_args[0] != virtualenv_string_args[2]: raise Exception - self._run_as_operator( - f, python_version=self._invert_python_major_version(), string_args=[1, 2, 1]) + self._run_as_operator(f, python_version=self._invert_python_major_version(), string_args=[1, 2, 1]) def test_with_args(self): def f(a, b, c=False, d=False): @@ -1254,10 +1206,7 @@ def f(): def test_lambda(self): with self.assertRaises(AirflowException): - PythonVirtualenvOperator( - python_callable=lambda x: 4, - task_id='task', - dag=self.dag) + PythonVirtualenvOperator(python_callable=lambda x: 4, task_id='task', dag=self.dag) def test_nonimported_as_arg(self): def f(_): @@ -1305,16 +1254,11 @@ def f( dag_run, task, # other - **context + **context, ): # pylint: disable=unused-argument,too-many-arguments,too-many-locals pass - self._run_as_operator( - f, - use_dill=True, - system_site_packages=True, - requirements=None - ) + self._run_as_operator(f, use_dill=True, system_site_packages=True, requirements=None) def test_pendulum_context(self): def f( @@ -1344,15 +1288,12 @@ def f( prev_execution_date_success, prev_start_date_success, # other - **context + **context, ): # pylint: disable=unused-argument,too-many-arguments,too-many-locals pass self._run_as_operator( - f, - use_dill=True, - system_site_packages=False, - requirements=['pendulum', 'lazy_object_proxy'] + f, use_dill=True, system_site_packages=False, requirements=['pendulum', 'lazy_object_proxy'] ) def test_base_context(self): @@ -1377,16 +1318,11 @@ def f( yesterday_ds, yesterday_ds_nodash, # other - **context + **context, ): # pylint: disable=unused-argument,too-many-arguments,too-many-locals pass - self._run_as_operator( - f, - use_dill=True, - system_site_packages=False, - requirements=None - ) + self._run_as_operator(f, use_dill=True, system_site_packages=False, requirements=None) DEFAULT_ARGS = { @@ -1416,7 +1352,9 @@ def test_context_removed_after_exit(self): with set_current_context(example_context): pass - with pytest.raises(AirflowException, ): + with pytest.raises( + AirflowException, + ): get_current_context() def test_nested_context(self): @@ -1477,7 +1415,7 @@ def test_get_context_in_old_style_context_task(self): [ ("task1", [State.SUCCESS, State.SUCCESS, State.SUCCESS]), ("join", [State.SUCCESS, State.SKIPPED, State.SUCCESS]), - ] + ], ) def test_empty_branch(choice, expected_states): """ diff --git a/tests/operators/test_sql.py b/tests/operators/test_sql.py index b6c9ec12595fe..a2bd128576907 100644 --- a/tests/operators/test_sql.py +++ b/tests/operators/test_sql.py @@ -25,7 +25,10 @@ 
from airflow.exceptions import AirflowException from airflow.models import DAG, DagRun, TaskInstance as TI from airflow.operators.check_operator import ( - CheckOperator, IntervalCheckOperator, ThresholdCheckOperator, ValueCheckOperator, + CheckOperator, + IntervalCheckOperator, + ThresholdCheckOperator, + ValueCheckOperator, ) from airflow.operators.dummy_operator import DummyOperator from airflow.operators.sql import BranchSQLOperator @@ -98,8 +101,7 @@ def _construct_operator(self, sql, pass_value, tolerance=None): def test_pass_value_template_string(self): pass_value_str = "2018-03-22" - operator = self._construct_operator( - "select date from tab1;", "{{ ds }}") + operator = self._construct_operator("select date from tab1;", "{{ ds }}") operator.render_template_fields({"ds": pass_value_str}) @@ -108,8 +110,7 @@ def test_pass_value_template_string(self): def test_pass_value_template_string_float(self): pass_value_float = 4.0 - operator = self._construct_operator( - "select date from tab1;", pass_value_float) + operator = self._construct_operator("select date from tab1;", pass_value_float) operator.render_template_fields({}) @@ -134,8 +135,7 @@ def test_execute_fail(self, mock_get_db_hook): mock_hook.get_first.return_value = [11] mock_get_db_hook.return_value = mock_hook - operator = self._construct_operator( - "select value from tab1 limit 1;", 5, 1) + operator = self._construct_operator("select value from tab1 limit 1;", 5, 1) with self.assertRaisesRegex(AirflowException, "Tolerance:100.0%"): operator.execute() @@ -155,7 +155,9 @@ def test_invalid_ratio_formula(self): with self.assertRaisesRegex(AirflowException, "Invalid diff_method"): self._construct_operator( table="test_table", - metric_thresholds={"f1": 1, }, + metric_thresholds={ + "f1": 1, + }, ratio_formula="abs", ignore_zero=False, ) @@ -168,7 +170,9 @@ def test_execute_not_ignore_zero(self, mock_get_db_hook): operator = self._construct_operator( table="test_table", - metric_thresholds={"f1": 1, }, + metric_thresholds={ + "f1": 1, + }, ratio_formula="max_over_min", ignore_zero=False, ) @@ -184,7 +188,9 @@ def test_execute_ignore_zero(self, mock_get_db_hook): operator = self._construct_operator( table="test_table", - metric_thresholds={"f1": 1, }, + metric_thresholds={ + "f1": 1, + }, ratio_formula="max_over_min", ignore_zero=True, ) @@ -208,7 +214,12 @@ def returned_row(): operator = self._construct_operator( table="test_table", - metric_thresholds={"f0": 1.0, "f1": 1.5, "f2": 2.0, "f3": 2.5, }, + metric_thresholds={ + "f0": 1.0, + "f1": 1.5, + "f2": 2.0, + "f3": 2.5, + }, ratio_formula="max_over_min", ignore_zero=True, ) @@ -233,7 +244,12 @@ def returned_row(): operator = self._construct_operator( table="test_table", - metric_thresholds={"f0": 0.5, "f1": 0.6, "f2": 0.7, "f3": 0.8, }, + metric_thresholds={ + "f0": 0.5, + "f1": 0.6, + "f2": 0.7, + "f3": 0.8, + }, ratio_formula="relative_diff", ignore_zero=True, ) @@ -260,9 +276,7 @@ def test_pass_min_value_max_value(self, mock_get_db_hook): mock_hook.get_first.return_value = (10,) mock_get_db_hook.return_value = mock_hook - operator = self._construct_operator( - "Select avg(val) from table1 limit 1", 1, 100 - ) + operator = self._construct_operator("Select avg(val) from table1 limit 1", 1, 100) operator.execute() @@ -272,9 +286,7 @@ def test_fail_min_value_max_value(self, mock_get_db_hook): mock_hook.get_first.return_value = (10,) mock_get_db_hook.return_value = mock_hook - operator = self._construct_operator( - "Select avg(val) from table1 limit 1", 20, 100 - ) + operator = 
self._construct_operator("Select avg(val) from table1 limit 1", 20, 100) with self.assertRaisesRegex(AirflowException, "10.*20.0.*100.0"): operator.execute() @@ -285,8 +297,7 @@ def test_pass_min_sql_max_sql(self, mock_get_db_hook): mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),) mock_get_db_hook.return_value = mock_hook - operator = self._construct_operator( - "Select 10", "Select 1", "Select 100") + operator = self._construct_operator("Select 10", "Select 1", "Select 100") operator.execute() @@ -296,8 +307,7 @@ def test_fail_min_sql_max_sql(self, mock_get_db_hook): mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),) mock_get_db_hook.return_value = mock_hook - operator = self._construct_operator( - "Select 10", "Select 20", "Select 100") + operator = self._construct_operator("Select 10", "Select 20", "Select 100") with self.assertRaisesRegex(AirflowException, "10.*20.*100"): operator.execute() @@ -367,8 +377,7 @@ def test_unsupported_conn_type(self): ) with self.assertRaises(AirflowException): - op.run(start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, ignore_ti_state=True) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_invalid_conn(self): """Check if BranchSQLOperator throws an exception for invalid connection """ @@ -382,8 +391,7 @@ def test_invalid_conn(self): ) with self.assertRaises(AirflowException): - op.run(start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, ignore_ti_state=True) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_invalid_follow_task_true(self): """Check if BranchSQLOperator throws an exception for invalid connection """ @@ -397,8 +405,7 @@ def test_invalid_follow_task_true(self): ) with self.assertRaises(AirflowException): - op.run(start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, ignore_ti_state=True) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_invalid_follow_task_false(self): """Check if BranchSQLOperator throws an exception for invalid connection """ @@ -412,8 +419,7 @@ def test_invalid_follow_task_false(self): ) with self.assertRaises(AirflowException): - op.run(start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, ignore_ti_state=True) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @pytest.mark.backend("mysql") def test_sql_branch_operator_mysql(self): @@ -426,9 +432,7 @@ def test_sql_branch_operator_mysql(self): follow_task_ids_if_false="branch_2", dag=self.dag, ) - branch_op.run( - start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True - ) + branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @pytest.mark.backend("postgres") def test_sql_branch_operator_postgres(self): @@ -441,9 +445,7 @@ def test_sql_branch_operator_postgres(self): follow_task_ids_if_false="branch_2", dag=self.dag, ) - branch_op.run( - start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True - ) + branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @mock.patch("airflow.operators.sql.BaseHook") def test_branch_single_value_with_dag_run(self, mock_hook): @@ -469,9 +471,7 @@ def test_branch_single_value_with_dag_run(self, mock_hook): ) mock_hook.get_connection("mysql_default").conn_type = "mysql" - mock_get_records = ( - mock_hook.get_connection.return_value.get_hook.return_value.get_first - ) + mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_first mock_get_records.return_value = 1 @@ -512,9 
+512,7 @@ def test_branch_true_with_dag_run(self, mock_hook): ) mock_hook.get_connection("mysql_default").conn_type = "mysql" - mock_get_records = ( - mock_hook.get_connection.return_value.get_hook.return_value.get_first - ) + mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_first for true_value in SUPPORTED_TRUE_VALUES: mock_get_records.return_value = true_value @@ -556,9 +554,7 @@ def test_branch_false_with_dag_run(self, mock_hook): ) mock_hook.get_connection("mysql_default").conn_type = "mysql" - mock_get_records = ( - mock_hook.get_connection.return_value.get_hook.return_value.get_first - ) + mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_first for false_value in SUPPORTED_FALSE_VALUES: mock_get_records.return_value = false_value @@ -602,9 +598,7 @@ def test_branch_list_with_dag_run(self, mock_hook): ) mock_hook.get_connection("mysql_default").conn_type = "mysql" - mock_get_records = ( - mock_hook.get_connection.return_value.get_hook.return_value.get_first - ) + mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_first mock_get_records.return_value = [["1"]] branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -646,9 +640,7 @@ def test_invalid_query_result_with_dag_run(self, mock_hook): ) mock_hook.get_connection("mysql_default").conn_type = "mysql" - mock_get_records = ( - mock_hook.get_connection.return_value.get_hook.return_value.get_first - ) + mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_first mock_get_records.return_value = ["Invalid Value"] @@ -679,9 +671,7 @@ def test_with_skip_in_branch_downstream_dependencies(self, mock_hook): ) mock_hook.get_connection("mysql_default").conn_type = "mysql" - mock_get_records = ( - mock_hook.get_connection.return_value.get_hook.return_value.get_first - ) + mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_first for true_value in SUPPORTED_TRUE_VALUES: mock_get_records.return_value = [true_value] @@ -723,9 +713,7 @@ def test_with_skip_in_branch_downstream_dependencies2(self, mock_hook): ) mock_hook.get_connection("mysql_default").conn_type = "mysql" - mock_get_records = ( - mock_hook.get_connection.return_value.get_hook.return_value.get_first - ) + mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_first for false_value in SUPPORTED_FALSE_VALUES: mock_get_records.return_value = [false_value] diff --git a/tests/operators/test_subdag_operator.py b/tests/operators/test_subdag_operator.py index f1946d36af96e..bd2b578d631fd 100644 --- a/tests/operators/test_subdag_operator.py +++ b/tests/operators/test_subdag_operator.py @@ -42,7 +42,6 @@ class TestSubDagOperator(unittest.TestCase): - def setUp(self): clear_db_runs() self.dag_run_running = DagRun() @@ -63,15 +62,9 @@ def test_subdag_name(self): subdag_bad3 = DAG('bad.bad', default_args=default_args) SubDagOperator(task_id='test', dag=dag, subdag=subdag_good) - self.assertRaises( - AirflowException, - SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad1) - self.assertRaises( - AirflowException, - SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad2) - self.assertRaises( - AirflowException, - SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad3) + self.assertRaises(AirflowException, SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad1) + self.assertRaises(AirflowException, SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad2) + 
self.assertRaises(AirflowException, SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad3) def test_subdag_in_context_manager(self): """ @@ -101,14 +94,12 @@ def test_subdag_pools(self): DummyOperator(task_id='dummy', dag=subdag, pool='test_pool_1') self.assertRaises( - AirflowException, - SubDagOperator, - task_id='child', dag=dag, subdag=subdag, pool='test_pool_1') + AirflowException, SubDagOperator, task_id='child', dag=dag, subdag=subdag, pool='test_pool_1' + ) # recreate dag because failed subdagoperator was already added dag = DAG('parent', default_args=default_args) - SubDagOperator( - task_id='child', dag=dag, subdag=subdag, pool='test_pool_10') + SubDagOperator(task_id='child', dag=dag, subdag=subdag, pool='test_pool_10') session.delete(pool_1) session.delete(pool_10) @@ -132,9 +123,7 @@ def test_subdag_pools_no_possible_conflict(self): DummyOperator(task_id='dummy', dag=subdag, pool='test_pool_10') mock_session = Mock() - SubDagOperator( - task_id='child', dag=dag, subdag=subdag, pool='test_pool_1', - session=mock_session) + SubDagOperator(task_id='child', dag=dag, subdag=subdag, pool='test_pool_1', session=mock_session) self.assertFalse(mock_session.query.called) session.delete(pool_1) @@ -272,22 +261,19 @@ def test_rerun_failed_subdag(self): sub_dagrun.refresh_from_db() self.assertEqual(sub_dagrun.state, State.RUNNING) - @parameterized.expand([ - (SkippedStatePropagationOptions.ALL_LEAVES, [State.SKIPPED, State.SKIPPED], True), - (SkippedStatePropagationOptions.ALL_LEAVES, [State.SKIPPED, State.SUCCESS], False), - (SkippedStatePropagationOptions.ANY_LEAF, [State.SKIPPED, State.SUCCESS], True), - (SkippedStatePropagationOptions.ANY_LEAF, [State.FAILED, State.SKIPPED], True), - (None, [State.SKIPPED, State.SKIPPED], False), - ]) + @parameterized.expand( + [ + (SkippedStatePropagationOptions.ALL_LEAVES, [State.SKIPPED, State.SKIPPED], True), + (SkippedStatePropagationOptions.ALL_LEAVES, [State.SKIPPED, State.SUCCESS], False), + (SkippedStatePropagationOptions.ANY_LEAF, [State.SKIPPED, State.SUCCESS], True), + (SkippedStatePropagationOptions.ANY_LEAF, [State.FAILED, State.SKIPPED], True), + (None, [State.SKIPPED, State.SKIPPED], False), + ] + ) @mock.patch('airflow.operators.subdag_operator.SubDagOperator.skip') @mock.patch('airflow.operators.subdag_operator.get_task_instance') def test_subdag_with_propagate_skipped_state( - self, - propagate_option, - states, - skip_parent, - mock_get_task_instance, - mock_skip + self, propagate_option, states, skip_parent, mock_get_task_instance, mock_skip ): """ Tests that skipped state of leaf tasks propagates to the parent dag. 
@@ -296,15 +282,10 @@ def test_subdag_with_propagate_skipped_state( dag = DAG('parent', default_args=default_args) subdag = DAG('parent.test', default_args=default_args) subdag_task = SubDagOperator( - task_id='test', - subdag=subdag, - dag=dag, - poke_interval=1, - propagate_skipped_state=propagate_option + task_id='test', subdag=subdag, dag=dag, poke_interval=1, propagate_skipped_state=propagate_option ) dummy_subdag_tasks = [ - DummyOperator(task_id=f'dummy_subdag_{i}', dag=subdag) - for i in range(len(states)) + DummyOperator(task_id=f'dummy_subdag_{i}', dag=subdag) for i in range(len(states)) ] dummy_dag_task = DummyOperator(task_id='dummy_dag', dag=dag) subdag_task >> dummy_dag_task @@ -312,25 +293,14 @@ def test_subdag_with_propagate_skipped_state( subdag_task._get_dagrun = Mock() subdag_task._get_dagrun.return_value = self.dag_run_success mock_get_task_instance.side_effect = [ - TaskInstance( - task=task, - execution_date=DEFAULT_DATE, - state=state - ) for task, state in zip(dummy_subdag_tasks, states) + TaskInstance(task=task, execution_date=DEFAULT_DATE, state=state) + for task, state in zip(dummy_subdag_tasks, states) ] - context = { - 'execution_date': DEFAULT_DATE, - 'dag_run': DagRun(), - 'task': subdag_task - } + context = {'execution_date': DEFAULT_DATE, 'dag_run': DagRun(), 'task': subdag_task} subdag_task.post_execute(context) if skip_parent: - mock_skip.assert_called_once_with( - context['dag_run'], - context['execution_date'], - [dummy_dag_task] - ) + mock_skip.assert_called_once_with(context['dag_run'], context['execution_date'], [dummy_dag_task]) else: mock_skip.assert_not_called() diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py index 1f719f16d5fb7..d7a4df1f40fbb 100644 --- a/tests/plugins/test_plugin.py +++ b/tests/plugins/test_plugin.py @@ -20,14 +20,21 @@ from flask_appbuilder import BaseView as AppBuilderBaseView, expose from airflow.executors.base_executor import BaseExecutor + # Importing base classes that we need to derive from airflow.hooks.base_hook import BaseHook from airflow.models.baseoperator import BaseOperator + # This is the class you derive to create a plugin from airflow.plugins_manager import AirflowPlugin from airflow.sensors.base_sensor_operator import BaseSensorOperator from tests.test_utils.mock_operators import ( - AirflowLink, AirflowLink2, CustomBaseIndexOpLink, CustomOpLink, GithubLink, GoogleLink, + AirflowLink, + AirflowLink2, + CustomBaseIndexOpLink, + CustomOpLink, + GithubLink, + GoogleLink, ) @@ -66,22 +73,24 @@ def test(self): v_appbuilder_view = PluginTestAppBuilderBaseView() -v_appbuilder_package = {"name": "Test View", - "category": "Test Plugin", - "view": v_appbuilder_view} +v_appbuilder_package = {"name": "Test View", "category": "Test Plugin", "view": v_appbuilder_view} # Creating a flask appbuilder Menu Item -appbuilder_mitem = {"name": "Google", - "category": "Search", - "category_icon": "fa-th", - "href": "https://www.google.com"} +appbuilder_mitem = { + "name": "Google", + "category": "Search", + "category_icon": "fa-th", + "href": "https://www.google.com", +} # Creating a flask blueprint to intergrate the templates and static folder bp = Blueprint( - "test_plugin", __name__, + "test_plugin", + __name__, template_folder='templates', # registers airflow/plugins/templates as a Jinja template folder static_folder='static', - static_url_path='/static/test_plugin') + static_url_path='/static/test_plugin', +) # Defining the plugin class @@ -99,9 +108,7 @@ class AirflowTestPlugin(AirflowPlugin): 
AirflowLink(), GithubLink(), ] - operator_extra_links = [ - GoogleLink(), AirflowLink2(), CustomOpLink(), CustomBaseIndexOpLink(1) - ] + operator_extra_links = [GoogleLink(), AirflowLink2(), CustomOpLink(), CustomBaseIndexOpLink(1)] class MockPluginA(AirflowPlugin): diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py index 7b5357ba58bd3..07e500dc01b96 100644 --- a/tests/plugins/test_plugins_manager.py +++ b/tests/plugins/test_plugins_manager.py @@ -31,15 +31,20 @@ def setUp(self): def test_flaskappbuilder_views(self): from tests.plugins.test_plugin import v_appbuilder_package + appbuilder_class_name = str(v_appbuilder_package['view'].__class__.__name__) - plugin_views = [view for view in self.appbuilder.baseviews - if view.blueprint.name == appbuilder_class_name] + plugin_views = [ + view for view in self.appbuilder.baseviews if view.blueprint.name == appbuilder_class_name + ] self.assertTrue(len(plugin_views) == 1) # view should have a menu item matching category of v_appbuilder_package - links = [menu_item for menu_item in self.appbuilder.menu.menu - if menu_item.name == v_appbuilder_package['category']] + links = [ + menu_item + for menu_item in self.appbuilder.menu.menu + if menu_item.name == v_appbuilder_package['category'] + ] self.assertTrue(len(links) == 1) @@ -52,8 +57,11 @@ def test_flaskappbuilder_menu_links(self): from tests.plugins.test_plugin import appbuilder_mitem # menu item should exist matching appbuilder_mitem - links = [menu_item for menu_item in self.appbuilder.menu.menu - if menu_item.name == appbuilder_mitem['category']] + links = [ + menu_item + for menu_item in self.appbuilder.menu.menu + if menu_item.name == appbuilder_mitem['category'] + ] self.assertTrue(len(links) == 1) @@ -115,6 +123,7 @@ class TestNonPropertyHook(BaseHook): with mock_plugin_manager(plugins=[AirflowTestPropertyPlugin()]): from airflow import plugins_manager + plugins_manager.integrate_dag_plugins() self.assertIn('AirflowTestPropertyPlugin', str(plugins_manager.plugins)) @@ -132,24 +141,24 @@ class AirflowAdminMenuLinksPlugin(AirflowPlugin): menu_links = [mock.MagicMock()] - with mock_plugin_manager(plugins=[ - AirflowAdminViewsPlugin(), - AirflowAdminMenuLinksPlugin() - ]): + with mock_plugin_manager(plugins=[AirflowAdminViewsPlugin(), AirflowAdminMenuLinksPlugin()]): from airflow import plugins_manager # assert not logs with self.assertLogs(plugins_manager.log) as cm: plugins_manager.initialize_web_ui_plugins() - self.assertEqual(cm.output, [ - 'WARNING:airflow.plugins_manager:Plugin \'test_admin_views_plugin\' may not be ' - 'compatible with the current Airflow version. Please contact the author of ' - 'the plugin.', - 'WARNING:airflow.plugins_manager:Plugin \'test_menu_links_plugin\' may not be ' - 'compatible with the current Airflow version. Please contact the author of ' - 'the plugin.' - ]) + self.assertEqual( + cm.output, + [ + 'WARNING:airflow.plugins_manager:Plugin \'test_admin_views_plugin\' may not be ' + 'compatible with the current Airflow version. Please contact the author of ' + 'the plugin.', + 'WARNING:airflow.plugins_manager:Plugin \'test_menu_links_plugin\' may not be ' + 'compatible with the current Airflow version. 
Please contact the author of ' + 'the plugin.', + ], + ) def test_should_not_warning_about_fab_plugins(self): class AirflowAdminViewsPlugin(AirflowPlugin): @@ -162,10 +171,7 @@ class AirflowAdminMenuLinksPlugin(AirflowPlugin): appbuilder_menu_items = [mock.MagicMock()] - with mock_plugin_manager(plugins=[ - AirflowAdminViewsPlugin(), - AirflowAdminMenuLinksPlugin() - ]): + with mock_plugin_manager(plugins=[AirflowAdminViewsPlugin(), AirflowAdminMenuLinksPlugin()]): from airflow import plugins_manager # assert not logs @@ -185,10 +191,7 @@ class AirflowAdminMenuLinksPlugin(AirflowPlugin): menu_links = [mock.MagicMock()] appbuilder_menu_items = [mock.MagicMock()] - with mock_plugin_manager(plugins=[ - AirflowAdminViewsPlugin(), - AirflowAdminMenuLinksPlugin() - ]): + with mock_plugin_manager(plugins=[AirflowAdminViewsPlugin(), AirflowAdminMenuLinksPlugin()]): from airflow import plugins_manager # assert not logs diff --git a/tests/providers/amazon/aws/hooks/test_glacier.py b/tests/providers/amazon/aws/hooks/test_glacier.py index 63b93d5004fd1..0f60f371884f1 100644 --- a/tests/providers/amazon/aws/hooks/test_glacier.py +++ b/tests/providers/amazon/aws/hooks/test_glacier.py @@ -17,8 +17,8 @@ # under the License. import unittest - from unittest import mock + from testfixtures import LogCapture from airflow.providers.amazon.aws.hooks.glacier import GlacierHook diff --git a/tests/providers/amazon/aws/hooks/test_glue.py b/tests/providers/amazon/aws/hooks/test_glue.py index d450bc780dde3..512e7c84691cd 100644 --- a/tests/providers/amazon/aws/hooks/test_glue.py +++ b/tests/providers/amazon/aws/hooks/test_glue.py @@ -17,7 +17,6 @@ # under the License. import json import unittest - from unittest import mock from airflow.providers.amazon.aws.hooks.glue import AwsGlueJobHook diff --git a/tests/providers/amazon/aws/hooks/test_sagemaker.py b/tests/providers/amazon/aws/hooks/test_sagemaker.py index a94281a6452fa..b6ecfce873f19 100644 --- a/tests/providers/amazon/aws/hooks/test_sagemaker.py +++ b/tests/providers/amazon/aws/hooks/test_sagemaker.py @@ -20,8 +20,8 @@ import time import unittest from datetime import datetime - from unittest import mock + from tzlocal import get_localzone from airflow.exceptions import AirflowException diff --git a/tests/providers/amazon/aws/hooks/test_secrets_manager.py b/tests/providers/amazon/aws/hooks/test_secrets_manager.py index c83260c3ba397..916a5ea04def0 100644 --- a/tests/providers/amazon/aws/hooks/test_secrets_manager.py +++ b/tests/providers/amazon/aws/hooks/test_secrets_manager.py @@ -17,9 +17,9 @@ # under the License. # -import unittest import base64 import json +import unittest from airflow.providers.amazon.aws.hooks.secrets_manager import SecretsManagerHook diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index 1eef2c4b5f484..61e43d26aec1c 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -16,7 +16,6 @@ # under the License. 
import unittest - from unittest import mock from airflow.models import DAG, TaskInstance diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index 4bf36c455132c..f72e12a2ece43 100644 --- a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -20,7 +20,6 @@ # pylint: disable=missing-docstring import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index 16f39bc116342..a8bb79216164f 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -20,8 +20,8 @@ import sys import unittest from copy import deepcopy - from unittest import mock + from parameterized import parameterized from airflow.exceptions import AirflowException diff --git a/tests/providers/amazon/aws/operators/test_glacier.py b/tests/providers/amazon/aws/operators/test_glacier.py index 561ad20683b7b..e86e82cf929b8 100644 --- a/tests/providers/amazon/aws/operators/test_glacier.py +++ b/tests/providers/amazon/aws/operators/test_glacier.py @@ -16,13 +16,9 @@ # specific language governing permissions and limitations # under the License. -from unittest import TestCase +from unittest import TestCase, mock -from unittest import mock - -from airflow.providers.amazon.aws.operators.glacier import ( - GlacierCreateJobOperator, -) +from airflow.providers.amazon.aws.operators.glacier import GlacierCreateJobOperator AWS_CONN_ID = "aws_default" BUCKET_NAME = "airflow_bucket" diff --git a/tests/providers/amazon/aws/operators/test_glacier_system.py b/tests/providers/amazon/aws/operators/test_glacier_system.py index b31c69958e76b..99a5315504b25 100644 --- a/tests/providers/amazon/aws/operators/test_glacier_system.py +++ b/tests/providers/amazon/aws/operators/test_glacier_system.py @@ -18,7 +18,6 @@ from tests.test_utils.amazon_system_helpers import AWS_DAG_FOLDER, AmazonSystemTest from tests.test_utils.gcp_system_helpers import GoogleSystemTest - BUCKET = "data_from_glacier" diff --git a/tests/providers/amazon/aws/operators/test_glue.py b/tests/providers/amazon/aws/operators/test_glue.py index f0a08f5551f22..144bebed29f83 100644 --- a/tests/providers/amazon/aws/operators/test_glue.py +++ b/tests/providers/amazon/aws/operators/test_glue.py @@ -16,7 +16,6 @@ # under the License. import unittest - from unittest import mock from airflow import configuration diff --git a/tests/providers/amazon/aws/operators/test_s3_bucket.py b/tests/providers/amazon/aws/operators/test_s3_bucket.py index 5d090b81a0b40..e501125ae90bf 100644 --- a/tests/providers/amazon/aws/operators/test_s3_bucket.py +++ b/tests/providers/amazon/aws/operators/test_s3_bucket.py @@ -17,8 +17,8 @@ # under the License. import os import unittest - from unittest import mock + from moto import mock_s3 from airflow.providers.amazon.aws.hooks.s3 import S3Hook diff --git a/tests/providers/amazon/aws/operators/test_s3_list.py b/tests/providers/amazon/aws/operators/test_s3_list.py index 5c271a38d23b4..b51c8fac89b65 100644 --- a/tests/providers/amazon/aws/operators/test_s3_list.py +++ b/tests/providers/amazon/aws/operators/test_s3_list.py @@ -17,7 +17,6 @@ # under the License. 
import unittest - from unittest import mock from airflow.providers.amazon.aws.operators.s3_list import S3ListOperator diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py index 44563204571f1..d08c92453842c 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py @@ -17,8 +17,8 @@ # under the License. import unittest - from unittest import mock + from botocore.exceptions import ClientError from airflow.exceptions import AirflowException diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py index 6d36284f01f30..e725c0228a706 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_model.py b/tests/providers/amazon/aws/operators/test_sagemaker_model.py index cf8aaa925a879..6676f1009c395 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_model.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_model.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py index 1e51a9f523211..5e55b39006167 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py @@ -16,8 +16,8 @@ # under the License. import unittest - from unittest import mock + from parameterized import parameterized from airflow.exceptions import AirflowException diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py index e7134860a2207..7424d6f7f8b25 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py @@ -16,7 +16,6 @@ # under the License. import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py index ccbaa09b51918..739baed701dbe 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py index a197ddf3d925b..9978944a18eec 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py @@ -17,7 +17,6 @@ # under the License. 
import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/amazon/aws/sensors/test_athena.py b/tests/providers/amazon/aws/sensors/test_athena.py index ad83cfe29544d..8f5f2c7173cf7 100644 --- a/tests/providers/amazon/aws/sensors/test_athena.py +++ b/tests/providers/amazon/aws/sensors/test_athena.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/amazon/aws/sensors/test_glacier.py b/tests/providers/amazon/aws/sensors/test_glacier.py index 410f984cdacc1..d954c6e7ed599 100644 --- a/tests/providers/amazon/aws/sensors/test_glacier.py +++ b/tests/providers/amazon/aws/sensors/test_glacier.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow import AirflowException diff --git a/tests/providers/amazon/aws/sensors/test_glue.py b/tests/providers/amazon/aws/sensors/test_glue.py index 21f0083869a54..ff503b1467970 100644 --- a/tests/providers/amazon/aws/sensors/test_glue.py +++ b/tests/providers/amazon/aws/sensors/test_glue.py @@ -16,7 +16,6 @@ # under the License. import unittest - from unittest import mock from airflow import configuration diff --git a/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py b/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py index 26f683f6cb0c9..a4c264bbe6359 100644 --- a/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py +++ b/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py index 042859c53b00a..ca4090b12f2dd 100644 --- a/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py +++ b/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_training.py b/tests/providers/amazon/aws/sensors/test_sagemaker_training.py index 011d23f4696dc..c6d5b782a68ed 100644 --- a/tests/providers/amazon/aws/sensors/test_sagemaker_training.py +++ b/tests/providers/amazon/aws/sensors/test_sagemaker_training.py @@ -18,7 +18,6 @@ import unittest from datetime import datetime - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py b/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py index 3f576f11211f8..b0463373c10ac 100644 --- a/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py +++ b/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py b/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py index 90e8a5e7e1845..32b2553f1f5c6 100644 --- a/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py +++ b/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py @@ -17,7 +17,6 @@ # under the License. 
import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py index 404431e81de6b..fc62cc4635c19 100644 --- a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.providers.amazon.aws.hooks.s3 import S3Hook diff --git a/tests/providers/amazon/aws/transfers/test_glacier_to_gcs.py b/tests/providers/amazon/aws/transfers/test_glacier_to_gcs.py index a2a49eb057947..f871f5acfa72c 100644 --- a/tests/providers/amazon/aws/transfers/test_glacier_to_gcs.py +++ b/tests/providers/amazon/aws/transfers/test_glacier_to_gcs.py @@ -15,9 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from unittest import TestCase - -from unittest import mock +from unittest import TestCase, mock from airflow.providers.amazon.aws.transfers.glacier_to_gcs import GlacierToGCSOperator diff --git a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py index 44deffaa489a6..8ac0735b6b98f 100644 --- a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py @@ -16,7 +16,6 @@ # specific language governing permissions and limitations # under the License. import unittest - from unittest import mock from airflow.models import TaskInstance diff --git a/tests/providers/apache/cassandra/hooks/test_cassandra.py b/tests/providers/apache/cassandra/hooks/test_cassandra.py index 391cfd0705e5b..5e37acef674ec 100644 --- a/tests/providers/apache/cassandra/hooks/test_cassandra.py +++ b/tests/providers/apache/cassandra/hooks/test_cassandra.py @@ -17,8 +17,8 @@ # under the License. import unittest - from unittest import mock + import pytest from cassandra.cluster import Cluster from cassandra.policies import ( diff --git a/tests/providers/apache/druid/operators/test_druid_check.py b/tests/providers/apache/druid/operators/test_druid_check.py index b1bc844190db3..bb84e9d47daec 100644 --- a/tests/providers/apache/druid/operators/test_druid_check.py +++ b/tests/providers/apache/druid/operators/test_druid_check.py @@ -19,7 +19,6 @@ import unittest from datetime import datetime - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/apache/hive/transfers/test_s3_to_hive.py b/tests/providers/apache/hive/transfers/test_s3_to_hive.py index ad3fe94f24245..473b889335b65 100644 --- a/tests/providers/apache/hive/transfers/test_s3_to_hive.py +++ b/tests/providers/apache/hive/transfers/test_s3_to_hive.py @@ -26,7 +26,6 @@ from gzip import GzipFile from itertools import product from tempfile import NamedTemporaryFile, mkdtemp - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/apache/pig/hooks/test_pig.py b/tests/providers/apache/pig/hooks/test_pig.py index da8de176a8fa1..a714368b88354 100644 --- a/tests/providers/apache/pig/hooks/test_pig.py +++ b/tests/providers/apache/pig/hooks/test_pig.py @@ -17,7 +17,6 @@ # under the License. 
import unittest - from unittest import mock from airflow.providers.apache.pig.hooks.pig import PigCliHook diff --git a/tests/providers/apache/pig/operators/test_pig.py b/tests/providers/apache/pig/operators/test_pig.py index 174f24c0607ea..92546d122e175 100644 --- a/tests/providers/apache/pig/operators/test_pig.py +++ b/tests/providers/apache/pig/operators/test_pig.py @@ -16,7 +16,6 @@ # under the License. import unittest - from unittest import mock from airflow.providers.apache.pig.hooks.pig import PigCliHook diff --git a/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py b/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py index 1431ec1a86c86..18bb75f192a3a 100644 --- a/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py +++ b/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. from unittest import mock + import pytest from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter diff --git a/tests/providers/cloudant/hooks/test_cloudant.py b/tests/providers/cloudant/hooks/test_cloudant.py index cf291c027cf82..d0c78dbebb97c 100644 --- a/tests/providers/cloudant/hooks/test_cloudant.py +++ b/tests/providers/cloudant/hooks/test_cloudant.py @@ -16,7 +16,6 @@ # specific language governing permissions and limitations # under the License. import unittest - from unittest.mock import patch from airflow.exceptions import AirflowException diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py index e03540abecb61..9a01d8fa5007b 100644 --- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py +++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py @@ -23,11 +23,10 @@ from unittest import mock from unittest.mock import patch +import kubernetes from parameterized import parameterized -import kubernetes from airflow import AirflowException - from airflow.models import Connection from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook from airflow.utils import db diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index ac70edcf91e59..94a094be7ba83 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -20,8 +20,8 @@ import itertools import json import unittest - from unittest import mock + from requests import exceptions as requests_exceptions from airflow import __version__ diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index c26e464f96795..8f43ad5391b65 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -18,7 +18,6 @@ # import unittest from datetime import datetime - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/datadog/sensors/test_datadog.py b/tests/providers/datadog/sensors/test_datadog.py index cf9633b752b61..10dae5f31a508 100644 --- a/tests/providers/datadog/sensors/test_datadog.py +++ b/tests/providers/datadog/sensors/test_datadog.py @@ -19,7 +19,6 @@ import json import unittest from typing import List - from unittest.mock import patch from airflow.models import Connection diff --git a/tests/providers/docker/hooks/test_docker.py b/tests/providers/docker/hooks/test_docker.py index 
d66823a96030b..49e810b3d2859 100644 --- a/tests/providers/docker/hooks/test_docker.py +++ b/tests/providers/docker/hooks/test_docker.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/docker/operators/test_docker.py b/tests/providers/docker/operators/test_docker.py index 49cd6a69bf3d7..631d42be2df54 100644 --- a/tests/providers/docker/operators/test_docker.py +++ b/tests/providers/docker/operators/test_docker.py @@ -17,7 +17,6 @@ # under the License. import logging import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/docker/operators/test_docker_swarm.py b/tests/providers/docker/operators/test_docker_swarm.py index 784e8e7fc3d03..a3208ef00138b 100644 --- a/tests/providers/docker/operators/test_docker_swarm.py +++ b/tests/providers/docker/operators/test_docker_swarm.py @@ -17,8 +17,8 @@ # under the License. import unittest - from unittest import mock + import requests from docker import APIClient diff --git a/tests/providers/exasol/hooks/test_exasol.py b/tests/providers/exasol/hooks/test_exasol.py index 110259e5c6ad2..4b964e102ad33 100644 --- a/tests/providers/exasol/hooks/test_exasol.py +++ b/tests/providers/exasol/hooks/test_exasol.py @@ -19,7 +19,6 @@ import json import unittest - from unittest import mock from airflow import models diff --git a/tests/providers/exasol/operators/test_exasol.py b/tests/providers/exasol/operators/test_exasol.py index a882075a72ac4..2b486a14c8057 100644 --- a/tests/providers/exasol/operators/test_exasol.py +++ b/tests/providers/exasol/operators/test_exasol.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.providers.exasol.operators.exasol import ExasolOperator diff --git a/tests/providers/facebook/ads/hooks/test_ads.py b/tests/providers/facebook/ads/hooks/test_ads.py index 366277a7a03b2..8d4d88b021c98 100644 --- a/tests/providers/facebook/ads/hooks/test_ads.py +++ b/tests/providers/facebook/ads/hooks/test_ads.py @@ -16,6 +16,7 @@ # under the License. from unittest import mock + import pytest from airflow.providers.facebook.ads.hooks.ads import FacebookAdsReportingHook diff --git a/tests/providers/google/ads/hooks/test_ads.py b/tests/providers/google/ads/hooks/test_ads.py index 1928f79760ba9..83d17b7d716c4 100644 --- a/tests/providers/google/ads/hooks/test_ads.py +++ b/tests/providers/google/ads/hooks/test_ads.py @@ -16,6 +16,7 @@ # under the License. from unittest import mock + import pytest from airflow.providers.google.ads.hooks.ads import GoogleAdsHook diff --git a/tests/providers/google/cloud/hooks/test_automl.py b/tests/providers/google/cloud/hooks/test_automl.py index 353921b644a74..898001c3b17a2 100644 --- a/tests/providers/google/cloud/hooks/test_automl.py +++ b/tests/providers/google/cloud/hooks/test_automl.py @@ -17,8 +17,8 @@ # under the License. 
# import unittest - from unittest import mock + from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook diff --git a/tests/providers/google/cloud/hooks/test_bigquery_dts.py b/tests/providers/google/cloud/hooks/test_bigquery_dts.py index 82944f8afb42e..cf8b1cd899fae 100644 --- a/tests/providers/google/cloud/hooks/test_bigquery_dts.py +++ b/tests/providers/google/cloud/hooks/test_bigquery_dts.py @@ -18,8 +18,8 @@ import unittest from copy import deepcopy - from unittest import mock + from google.cloud.bigquery_datatransfer_v1 import DataTransferServiceClient from google.cloud.bigquery_datatransfer_v1.types import TransferConfig from google.protobuf.json_format import ParseDict diff --git a/tests/providers/google/cloud/hooks/test_cloud_build.py b/tests/providers/google/cloud/hooks/test_cloud_build.py index ecd7a16377495..9a3b67ec258e6 100644 --- a/tests/providers/google/cloud/hooks/test_cloud_build.py +++ b/tests/providers/google/cloud/hooks/test_cloud_build.py @@ -21,7 +21,6 @@ import unittest from typing import Optional from unittest import mock - from unittest.mock import PropertyMock from airflow.exceptions import AirflowException diff --git a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py index 4fda2ad776907..b7d869b535cb4 100644 --- a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py +++ b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py @@ -20,9 +20,9 @@ import re import unittest from copy import deepcopy - from unittest import mock from unittest.mock import MagicMock, PropertyMock + from googleapiclient.errors import HttpError from parameterized import parameterized diff --git a/tests/providers/google/cloud/hooks/test_compute.py b/tests/providers/google/cloud/hooks/test_compute.py index 01ee15113f827..64f6d928fcf60 100644 --- a/tests/providers/google/cloud/hooks/test_compute.py +++ b/tests/providers/google/cloud/hooks/test_compute.py @@ -19,7 +19,6 @@ # pylint: disable=too-many-lines import unittest - from unittest import mock from unittest.mock import PropertyMock diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py index 9de46a5b81e09..be486eed511df 100644 --- a/tests/providers/google/cloud/hooks/test_dataflow.py +++ b/tests/providers/google/cloud/hooks/test_dataflow.py @@ -21,9 +21,9 @@ import shlex import unittest from typing import Any, Dict - from unittest import mock from unittest.mock import MagicMock + from parameterized import parameterized from airflow.exceptions import AirflowException diff --git a/tests/providers/google/cloud/hooks/test_dataproc.py b/tests/providers/google/cloud/hooks/test_dataproc.py index ebca9f5f340d7..43a28ee7c0594 100644 --- a/tests/providers/google/cloud/hooks/test_dataproc.py +++ b/tests/providers/google/cloud/hooks/test_dataproc.py @@ -17,8 +17,8 @@ # under the License. 
# import unittest - from unittest import mock + from google.cloud.dataproc_v1beta2.types import JobStatus # pylint: disable=no-name-in-module from airflow.exceptions import AirflowException diff --git a/tests/providers/google/cloud/hooks/test_datastore.py b/tests/providers/google/cloud/hooks/test_datastore.py index dfb15d173808d..3d9216ab9daee 100644 --- a/tests/providers/google/cloud/hooks/test_datastore.py +++ b/tests/providers/google/cloud/hooks/test_datastore.py @@ -18,9 +18,8 @@ # import unittest -from unittest.mock import call, patch - from unittest import mock +from unittest.mock import call, patch from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.datastore import DatastoreHook diff --git a/tests/providers/google/cloud/hooks/test_dlp.py b/tests/providers/google/cloud/hooks/test_dlp.py index 7e5c5c8dc601f..c9f14d99e81cf 100644 --- a/tests/providers/google/cloud/hooks/test_dlp.py +++ b/tests/providers/google/cloud/hooks/test_dlp.py @@ -24,9 +24,9 @@ import unittest from typing import Any, Dict - from unittest import mock from unittest.mock import PropertyMock + from google.cloud.dlp_v2.types import DlpJob from airflow.exceptions import AirflowException diff --git a/tests/providers/google/cloud/hooks/test_functions.py b/tests/providers/google/cloud/hooks/test_functions.py index 02fd13847492e..76304b00b3a07 100644 --- a/tests/providers/google/cloud/hooks/test_functions.py +++ b/tests/providers/google/cloud/hooks/test_functions.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from unittest.mock import PropertyMock diff --git a/tests/providers/google/cloud/hooks/test_kms.py b/tests/providers/google/cloud/hooks/test_kms.py index 94383040cb9f1..894532aaed41b 100644 --- a/tests/providers/google/cloud/hooks/test_kms.py +++ b/tests/providers/google/cloud/hooks/test_kms.py @@ -19,7 +19,6 @@ import unittest from base64 import b64decode, b64encode from collections import namedtuple - from unittest import mock from airflow.providers.google.cloud.hooks.kms import CloudKMSHook diff --git a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py index d93abe52b4009..892c9f7a6fb53 100644 --- a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py @@ -17,9 +17,9 @@ # under the License. 
# import unittest - from unittest import mock from unittest.mock import PropertyMock + from google.cloud.container_v1.types import Cluster from airflow.exceptions import AirflowException diff --git a/tests/providers/google/cloud/hooks/test_life_sciences.py b/tests/providers/google/cloud/hooks/test_life_sciences.py index 9f6c553d997a1..e203e87f3eb79 100644 --- a/tests/providers/google/cloud/hooks/test_life_sciences.py +++ b/tests/providers/google/cloud/hooks/test_life_sciences.py @@ -20,7 +20,6 @@ """ import unittest from unittest import mock - from unittest.mock import PropertyMock from airflow.exceptions import AirflowException diff --git a/tests/providers/google/cloud/hooks/test_natural_language.py b/tests/providers/google/cloud/hooks/test_natural_language.py index 1a4b588017490..745b304506364 100644 --- a/tests/providers/google/cloud/hooks/test_natural_language.py +++ b/tests/providers/google/cloud/hooks/test_natural_language.py @@ -18,8 +18,8 @@ # import unittest from typing import Any, Dict - from unittest import mock + from google.cloud.language_v1.proto.language_service_pb2 import Document from airflow.providers.google.cloud.hooks.natural_language import CloudNaturalLanguageHook diff --git a/tests/providers/google/cloud/hooks/test_pubsub.py b/tests/providers/google/cloud/hooks/test_pubsub.py index 7c1fcad6953eb..f519c70504a4b 100644 --- a/tests/providers/google/cloud/hooks/test_pubsub.py +++ b/tests/providers/google/cloud/hooks/test_pubsub.py @@ -18,8 +18,8 @@ import unittest from typing import List - from unittest import mock + from google.api_core.exceptions import AlreadyExists, GoogleAPICallError from google.cloud.exceptions import NotFound from google.cloud.pubsub_v1.types import ReceivedMessage diff --git a/tests/providers/google/cloud/hooks/test_spanner.py b/tests/providers/google/cloud/hooks/test_spanner.py index 2fb4527a5c5de..27b7a06e67456 100644 --- a/tests/providers/google/cloud/hooks/test_spanner.py +++ b/tests/providers/google/cloud/hooks/test_spanner.py @@ -17,7 +17,6 @@ # under the License. 
import unittest - from unittest import mock from unittest.mock import PropertyMock diff --git a/tests/providers/google/cloud/hooks/test_speech_to_text.py b/tests/providers/google/cloud/hooks/test_speech_to_text.py index cd2c8fe5debfa..be5d933137823 100644 --- a/tests/providers/google/cloud/hooks/test_speech_to_text.py +++ b/tests/providers/google/cloud/hooks/test_speech_to_text.py @@ -18,7 +18,6 @@ # import unittest - from unittest.mock import PropertyMock, patch from airflow.providers.google.cloud.hooks.speech_to_text import CloudSpeechToTextHook diff --git a/tests/providers/google/cloud/hooks/test_stackdriver.py b/tests/providers/google/cloud/hooks/test_stackdriver.py index ae16064829364..6892d0552f559 100644 --- a/tests/providers/google/cloud/hooks/test_stackdriver.py +++ b/tests/providers/google/cloud/hooks/test_stackdriver.py @@ -18,8 +18,8 @@ import json import unittest - from unittest import mock + from google.api_core.gapic_v1.method import DEFAULT from google.cloud import monitoring_v3 from google.protobuf.json_format import ParseDict diff --git a/tests/providers/google/cloud/hooks/test_tasks.py b/tests/providers/google/cloud/hooks/test_tasks.py index 8c7791843c546..8be6686cf6c1d 100644 --- a/tests/providers/google/cloud/hooks/test_tasks.py +++ b/tests/providers/google/cloud/hooks/test_tasks.py @@ -18,8 +18,8 @@ # import unittest from typing import Any, Dict - from unittest import mock + from google.cloud.tasks_v2.types import Queue, Task from airflow.providers.google.cloud.hooks.tasks import CloudTasksHook diff --git a/tests/providers/google/cloud/hooks/test_text_to_speech.py b/tests/providers/google/cloud/hooks/test_text_to_speech.py index ed6338c19a538..cc627a380b0dd 100644 --- a/tests/providers/google/cloud/hooks/test_text_to_speech.py +++ b/tests/providers/google/cloud/hooks/test_text_to_speech.py @@ -18,7 +18,6 @@ # import unittest - from unittest.mock import PropertyMock, patch from airflow.providers.google.cloud.hooks.text_to_speech import CloudTextToSpeechHook diff --git a/tests/providers/google/cloud/hooks/test_translate.py b/tests/providers/google/cloud/hooks/test_translate.py index c00eb026742c7..43c559c3b1a35 100644 --- a/tests/providers/google/cloud/hooks/test_translate.py +++ b/tests/providers/google/cloud/hooks/test_translate.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook diff --git a/tests/providers/google/cloud/hooks/test_video_intelligence.py b/tests/providers/google/cloud/hooks/test_video_intelligence.py index f6973f4c1e840..624d0c597b95c 100644 --- a/tests/providers/google/cloud/hooks/test_video_intelligence.py +++ b/tests/providers/google/cloud/hooks/test_video_intelligence.py @@ -17,8 +17,8 @@ # under the License. # import unittest - from unittest import mock + from google.cloud.videointelligence_v1 import enums from airflow.providers.google.cloud.hooks.video_intelligence import CloudVideoIntelligenceHook diff --git a/tests/providers/google/cloud/hooks/test_vision.py b/tests/providers/google/cloud/hooks/test_vision.py index f6cfbd882118f..31f004c98419f 100644 --- a/tests/providers/google/cloud/hooks/test_vision.py +++ b/tests/providers/google/cloud/hooks/test_vision.py @@ -16,8 +16,8 @@ # specific language governing permissions and limitations # under the License. 
import unittest - from unittest import mock + from google.cloud.vision import enums from google.cloud.vision_v1 import ProductSearchClient from google.cloud.vision_v1.proto.image_annotator_pb2 import ( diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py index 22650a79b3638..903600b0fa815 100644 --- a/tests/providers/google/cloud/operators/test_automl.py +++ b/tests/providers/google/cloud/operators/test_automl.py @@ -18,8 +18,8 @@ # import copy import unittest - from unittest import mock + from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient from airflow.providers.google.cloud.operators.automl import ( diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index b3558bcaddd4f..3fa01f88273e1 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -18,9 +18,9 @@ import unittest from datetime import datetime +from unittest import mock from unittest.mock import MagicMock -from unittest import mock import pytest from google.cloud.exceptions import Conflict from parameterized import parameterized diff --git a/tests/providers/google/cloud/operators/test_bigquery_dts.py b/tests/providers/google/cloud/operators/test_bigquery_dts.py index d25e1abd18140..4d423527acc8e 100644 --- a/tests/providers/google/cloud/operators/test_bigquery_dts.py +++ b/tests/providers/google/cloud/operators/test_bigquery_dts.py @@ -16,7 +16,6 @@ # specific language governing permissions and limitations # under the License. import unittest - from unittest import mock from airflow.providers.google.cloud.operators.bigquery_dts import ( diff --git a/tests/providers/google/cloud/operators/test_cloud_build.py b/tests/providers/google/cloud/operators/test_cloud_build.py index 924bb45df95bd..970ee0606ef0f 100644 --- a/tests/providers/google/cloud/operators/test_cloud_build.py +++ b/tests/providers/google/cloud/operators/test_cloud_build.py @@ -20,9 +20,8 @@ import tempfile from copy import deepcopy from datetime import datetime -from unittest import TestCase +from unittest import TestCase, mock -from unittest import mock from parameterized import parameterized from airflow.exceptions import AirflowException diff --git a/tests/providers/google/cloud/operators/test_cloud_memorystore.py b/tests/providers/google/cloud/operators/test_cloud_memorystore.py index d64acf8bcd388..8ef60bd9b62b1 100644 --- a/tests/providers/google/cloud/operators/test_cloud_memorystore.py +++ b/tests/providers/google/cloud/operators/test_cloud_memorystore.py @@ -19,9 +19,9 @@ from unittest import TestCase, mock from google.api_core.retry import Retry +from google.cloud.memcache_v1beta2.types import cloud_memcache from google.cloud.redis_v1.gapic.enums import FailoverInstanceRequest from google.cloud.redis_v1.types import Instance -from google.cloud.memcache_v1beta2.types import cloud_memcache from airflow.providers.google.cloud.operators.cloud_memorystore import ( CloudMemorystoreCreateInstanceAndImportOperator, @@ -32,13 +32,13 @@ CloudMemorystoreGetInstanceOperator, CloudMemorystoreImportOperator, CloudMemorystoreListInstancesOperator, - CloudMemorystoreScaleInstanceOperator, - CloudMemorystoreUpdateInstanceOperator, CloudMemorystoreMemcachedCreateInstanceOperator, CloudMemorystoreMemcachedDeleteInstanceOperator, CloudMemorystoreMemcachedGetInstanceOperator, CloudMemorystoreMemcachedListInstancesOperator, 
CloudMemorystoreMemcachedUpdateInstanceOperator, + CloudMemorystoreScaleInstanceOperator, + CloudMemorystoreUpdateInstanceOperator, ) TEST_GCP_CONN_ID = "test-gcp-conn-id" diff --git a/tests/providers/google/cloud/operators/test_cloud_sql.py b/tests/providers/google/cloud/operators/test_cloud_sql.py index 13b36d01a8eb2..2e14687ae53ac 100644 --- a/tests/providers/google/cloud/operators/test_cloud_sql.py +++ b/tests/providers/google/cloud/operators/test_cloud_sql.py @@ -20,8 +20,8 @@ import os import unittest - from unittest import mock + from parameterized import parameterized from airflow.exceptions import AirflowException diff --git a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py index 31b8c53b14ce4..9dc3060b4c432 100644 --- a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py +++ b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py @@ -20,8 +20,8 @@ from copy import deepcopy from datetime import date, time from typing import Dict - from unittest import mock + from botocore.credentials import Credentials from freezegun import freeze_time from parameterized import parameterized diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py index 48ccdfbeda6c5..02d95e52179ae 100644 --- a/tests/providers/google/cloud/operators/test_dataflow.py +++ b/tests/providers/google/cloud/operators/test_dataflow.py @@ -18,15 +18,14 @@ # import unittest - from unittest import mock from airflow.providers.google.cloud.operators.dataflow import ( CheckJobRunning, DataflowCreateJavaJobOperator, DataflowCreatePythonJobOperator, - DataflowTemplatedJobStartOperator, DataflowStartFlexTemplateOperator, + DataflowTemplatedJobStartOperator, ) from airflow.version import version diff --git a/tests/providers/google/cloud/operators/test_dataproc_system.py b/tests/providers/google/cloud/operators/test_dataproc_system.py index 48d4fdf8e6bbb..568af28f53fa0 100644 --- a/tests/providers/google/cloud/operators/test_dataproc_system.py +++ b/tests/providers/google/cloud/operators/test_dataproc_system.py @@ -17,7 +17,7 @@ # under the License. 
import pytest -from airflow.providers.google.cloud.example_dags.example_dataproc import PYSPARK_MAIN, BUCKET, SPARKR_MAIN +from airflow.providers.google.cloud.example_dags.example_dataproc import BUCKET, PYSPARK_MAIN, SPARKR_MAIN from tests.providers.google.cloud.utils.gcp_authenticator import GCP_DATAPROC_KEY from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context diff --git a/tests/providers/google/cloud/operators/test_dlp.py b/tests/providers/google/cloud/operators/test_dlp.py index 6b6ad09489208..7c68102784f7e 100644 --- a/tests/providers/google/cloud/operators/test_dlp.py +++ b/tests/providers/google/cloud/operators/test_dlp.py @@ -22,7 +22,6 @@ """ import unittest - from unittest import mock from airflow.providers.google.cloud.operators.dlp import ( diff --git a/tests/providers/google/cloud/operators/test_dlp_system.py b/tests/providers/google/cloud/operators/test_dlp_system.py index d27951fd7c2c5..12296ae248e78 100644 --- a/tests/providers/google/cloud/operators/test_dlp_system.py +++ b/tests/providers/google/cloud/operators/test_dlp_system.py @@ -23,9 +23,9 @@ """ import pytest +from airflow.providers.google.cloud.example_dags.example_dlp import OUTPUT_BUCKET, OUTPUT_FILENAME from tests.providers.google.cloud.utils.gcp_authenticator import GCP_DLP_KEY from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context -from airflow.providers.google.cloud.example_dags.example_dlp import OUTPUT_BUCKET, OUTPUT_FILENAME @pytest.fixture(scope="class") diff --git a/tests/providers/google/cloud/operators/test_functions.py b/tests/providers/google/cloud/operators/test_functions.py index a96d45ae1ee2c..4a8546054eb16 100644 --- a/tests/providers/google/cloud/operators/test_functions.py +++ b/tests/providers/google/cloud/operators/test_functions.py @@ -18,8 +18,8 @@ import unittest from copy import deepcopy - from unittest import mock + from googleapiclient.errors import HttpError from parameterized import parameterized diff --git a/tests/providers/google/cloud/operators/test_gcs.py b/tests/providers/google/cloud/operators/test_gcs.py index 08d649dd9ca5e..b41cf97710a6f 100644 --- a/tests/providers/google/cloud/operators/test_gcs.py +++ b/tests/providers/google/cloud/operators/test_gcs.py @@ -17,7 +17,6 @@ # under the License. 
import unittest - from unittest import mock from airflow.providers.google.cloud.operators.gcs import ( diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py b/tests/providers/google/cloud/operators/test_kubernetes_engine.py index 2bf5933ba0da2..44b9b68daea14 100644 --- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py @@ -18,9 +18,9 @@ import json import os import unittest - from unittest import mock from unittest.mock import PropertyMock + from parameterized import parameterized from airflow.exceptions import AirflowException diff --git a/tests/providers/google/cloud/operators/test_life_sciences.py b/tests/providers/google/cloud/operators/test_life_sciences.py index d7432825657b4..cc08e13ff8ff8 100644 --- a/tests/providers/google/cloud/operators/test_life_sciences.py +++ b/tests/providers/google/cloud/operators/test_life_sciences.py @@ -18,7 +18,6 @@ """Tests for Google Life Sciences Run Pipeline operator """ import unittest - from unittest import mock from airflow.providers.google.cloud.operators.life_sciences import LifeSciencesRunPipelineOperator diff --git a/tests/providers/google/cloud/operators/test_mlengine_utils.py b/tests/providers/google/cloud/operators/test_mlengine_utils.py index 8cf5009a99754..6c133154b61c0 100644 --- a/tests/providers/google/cloud/operators/test_mlengine_utils.py +++ b/tests/providers/google/cloud/operators/test_mlengine_utils.py @@ -17,9 +17,8 @@ import datetime import unittest -from unittest.mock import ANY, patch - from unittest import mock +from unittest.mock import ANY, patch from airflow.exceptions import AirflowException from airflow.models.dag import DAG diff --git a/tests/providers/google/cloud/operators/test_pubsub.py b/tests/providers/google/cloud/operators/test_pubsub.py index 662baa8e7db4d..0bb1dfd8d2a86 100644 --- a/tests/providers/google/cloud/operators/test_pubsub.py +++ b/tests/providers/google/cloud/operators/test_pubsub.py @@ -18,8 +18,8 @@ import unittest from typing import Any, Dict, List - from unittest import mock + from google.cloud.pubsub_v1.types import ReceivedMessage from google.protobuf.json_format import MessageToDict, ParseDict diff --git a/tests/providers/google/cloud/operators/test_spanner.py b/tests/providers/google/cloud/operators/test_spanner.py index 1ce86f2044d0e..4daccdb0e3936 100644 --- a/tests/providers/google/cloud/operators/test_spanner.py +++ b/tests/providers/google/cloud/operators/test_spanner.py @@ -16,8 +16,8 @@ # specific language governing permissions and limitations # under the License. import unittest - from unittest import mock + from parameterized import parameterized from airflow.exceptions import AirflowException diff --git a/tests/providers/google/cloud/operators/test_speech_to_text.py b/tests/providers/google/cloud/operators/test_speech_to_text.py index bd527bdc4cd4f..c9325ebe5924e 100644 --- a/tests/providers/google/cloud/operators/test_speech_to_text.py +++ b/tests/providers/google/cloud/operators/test_speech_to_text.py @@ -17,7 +17,6 @@ # under the License. 
import unittest - from unittest.mock import MagicMock, Mock, patch from airflow.exceptions import AirflowException diff --git a/tests/providers/google/cloud/operators/test_stackdriver.py b/tests/providers/google/cloud/operators/test_stackdriver.py index 3cd126329c5c7..28901b4b32416 100644 --- a/tests/providers/google/cloud/operators/test_stackdriver.py +++ b/tests/providers/google/cloud/operators/test_stackdriver.py @@ -18,8 +18,8 @@ import json import unittest - from unittest import mock + from google.api_core.gapic_v1.method import DEFAULT from airflow.providers.google.cloud.operators.stackdriver import ( diff --git a/tests/providers/google/cloud/operators/test_tasks.py b/tests/providers/google/cloud/operators/test_tasks.py index 73eae98a8949e..cac1441f67698 100644 --- a/tests/providers/google/cloud/operators/test_tasks.py +++ b/tests/providers/google/cloud/operators/test_tasks.py @@ -17,8 +17,8 @@ # under the License. import unittest - from unittest import mock + from google.cloud.tasks_v2.types import Queue, Task from airflow.providers.google.cloud.operators.tasks import ( diff --git a/tests/providers/google/cloud/operators/test_text_to_speech.py b/tests/providers/google/cloud/operators/test_text_to_speech.py index dd6a628cfcc23..006c6b5df77a6 100644 --- a/tests/providers/google/cloud/operators/test_text_to_speech.py +++ b/tests/providers/google/cloud/operators/test_text_to_speech.py @@ -17,8 +17,8 @@ # under the License. import unittest - from unittest.mock import ANY, Mock, PropertyMock, patch + from parameterized import parameterized from airflow.exceptions import AirflowException diff --git a/tests/providers/google/cloud/operators/test_translate.py b/tests/providers/google/cloud/operators/test_translate.py index 5a882e5275053..c32bd2f709701 100644 --- a/tests/providers/google/cloud/operators/test_translate.py +++ b/tests/providers/google/cloud/operators/test_translate.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.providers.google.cloud.operators.translate import CloudTranslateTextOperator diff --git a/tests/providers/google/cloud/operators/test_translate_speech.py b/tests/providers/google/cloud/operators/test_translate_speech.py index 015d4ee00edd6..fc1c6376cadcc 100644 --- a/tests/providers/google/cloud/operators/test_translate_speech.py +++ b/tests/providers/google/cloud/operators/test_translate_speech.py @@ -17,8 +17,8 @@ # under the License. import unittest - from unittest import mock + from google.cloud.speech_v1.proto.cloud_speech_pb2 import ( RecognizeResponse, SpeechRecognitionAlternative, diff --git a/tests/providers/google/cloud/operators/test_video_intelligence.py b/tests/providers/google/cloud/operators/test_video_intelligence.py index f3c8360516588..c6b5f322ae59f 100644 --- a/tests/providers/google/cloud/operators/test_video_intelligence.py +++ b/tests/providers/google/cloud/operators/test_video_intelligence.py @@ -17,8 +17,8 @@ # under the License. import unittest - from unittest import mock + from google.cloud.videointelligence_v1 import enums from google.cloud.videointelligence_v1.proto.video_intelligence_pb2 import AnnotateVideoResponse diff --git a/tests/providers/google/cloud/operators/test_vision.py b/tests/providers/google/cloud/operators/test_vision.py index 5dcaf0ef007d0..2ca8d9ae70b89 100644 --- a/tests/providers/google/cloud/operators/test_vision.py +++ b/tests/providers/google/cloud/operators/test_vision.py @@ -17,8 +17,8 @@ # under the License. 
import unittest - from unittest import mock + from google.api_core.exceptions import AlreadyExists from google.cloud.vision_v1.types import Product, ProductSet, ReferenceImage diff --git a/tests/providers/google/cloud/sensors/test_bigquery_dts.py b/tests/providers/google/cloud/sensors/test_bigquery_dts.py index 4bb466dbb3d1a..92a116ef8e3e4 100644 --- a/tests/providers/google/cloud/sensors/test_bigquery_dts.py +++ b/tests/providers/google/cloud/sensors/test_bigquery_dts.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.providers.google.cloud.sensors.bigquery_dts import BigQueryDataTransferServiceTransferRunSensor diff --git a/tests/providers/google/cloud/sensors/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/sensors/test_cloud_storage_transfer_service.py index e86746986b10d..aa169fdf1dd23 100644 --- a/tests/providers/google/cloud/sensors/test_cloud_storage_transfer_service.py +++ b/tests/providers/google/cloud/sensors/test_cloud_storage_transfer_service.py @@ -16,8 +16,8 @@ # specific language governing permissions and limitations # under the License. import unittest - from unittest import mock + from parameterized import parameterized from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import GcpTransferOperationStatus diff --git a/tests/providers/google/cloud/sensors/test_dataproc.py b/tests/providers/google/cloud/sensors/test_dataproc.py index f9b40035e8cb5..f2a1d45704cc8 100644 --- a/tests/providers/google/cloud/sensors/test_dataproc.py +++ b/tests/providers/google/cloud/sensors/test_dataproc.py @@ -19,9 +19,9 @@ from unittest import mock from google.cloud.dataproc_v1beta2.types import JobStatus + from airflow import AirflowException from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor - from airflow.version import version as airflow_version AIRFLOW_VERSION = "v" + airflow_version.replace(".", "-").replace("+", "-") diff --git a/tests/providers/google/cloud/sensors/test_pubsub.py b/tests/providers/google/cloud/sensors/test_pubsub.py index 46c85e3158e50..5c1ead938fd1a 100644 --- a/tests/providers/google/cloud/sensors/test_pubsub.py +++ b/tests/providers/google/cloud/sensors/test_pubsub.py @@ -18,8 +18,8 @@ import unittest from typing import Any, Dict, List - from unittest import mock + from google.cloud.pubsub_v1.types import ReceivedMessage from google.protobuf.json_format import MessageToDict, ParseDict diff --git a/tests/providers/google/cloud/transfers/test_adls_to_gcs.py b/tests/providers/google/cloud/transfers/test_adls_to_gcs.py index 94003ffbd1a3d..68649c242c6d1 100644 --- a/tests/providers/google/cloud/transfers/test_adls_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_adls_to_gcs.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.providers.google.cloud.transfers.adls_to_gcs import ADLSToGCSOperator diff --git a/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs.py b/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs.py index 11da033e5d944..1a46622869b27 100644 --- a/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs.py @@ -16,7 +16,6 @@ # under the License. 
import unittest - from unittest import mock from airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs import AzureFileShareToGCSOperator diff --git a/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs_system.py index ebd36652ea55b..76e69cdc67500 100644 --- a/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs_system.py +++ b/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs_system.py @@ -19,22 +19,21 @@ import pytest +from airflow.models import Connection from airflow.providers.google.cloud.example_dags.example_azure_fileshare_to_gcs import ( + AZURE_DIRECTORY_NAME, AZURE_SHARE_NAME, DEST_GCS_BUCKET, - AZURE_DIRECTORY_NAME, ) from airflow.utils.session import create_session - -from airflow.models import Connection from tests.providers.google.cloud.utils.gcp_authenticator import GCP_GCS_KEY from tests.test_utils.azure_system_helpers import AzureSystemTest, provide_azure_fileshare from tests.test_utils.db import clear_db_connections from tests.test_utils.gcp_system_helpers import ( - GoogleSystemTest, - provide_gcs_bucket, CLOUD_DAG_FOLDER, + GoogleSystemTest, provide_gcp_context, + provide_gcs_bucket, ) AZURE_LOGIN = os.environ.get('AZURE_LOGIN', 'default_login') diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py b/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py index 11f71fbb4d451..162182958655d 100644 --- a/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py +++ b/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import BigQueryToBigQueryOperator diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py index d25dbf1955bd1..2ddac81e76929 100644 --- a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_mysql.py b/tests/providers/google/cloud/transfers/test_bigquery_to_mysql.py index 29811ab33ac77..acdf63dd0b103 100644 --- a/tests/providers/google/cloud/transfers/test_bigquery_to_mysql.py +++ b/tests/providers/google/cloud/transfers/test_bigquery_to_mysql.py @@ -16,7 +16,6 @@ # specific language governing permissions and limitations # under the License. 
import unittest - from unittest import mock from airflow.providers.google.cloud.transfers.bigquery_to_mysql import BigQueryToMySqlOperator diff --git a/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py b/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py index 7016d8709c8d2..4d6481d1b1e72 100644 --- a/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py @@ -18,7 +18,6 @@ import unittest from unittest import mock - from unittest.mock import call from airflow.providers.google.cloud.transfers.cassandra_to_gcs import CassandraToGCSOperator diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py index a602bee4257bc..9b9aabfe4ebcb 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py index e280a6d9d0188..0b2e1eb0ca2fa 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py @@ -18,7 +18,6 @@ import unittest from datetime import datetime - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_local.py b/tests/providers/google/cloud/transfers/test_gcs_to_local.py index b703d63741114..fcd03d6c3dca1 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_local.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_local.py @@ -16,7 +16,6 @@ # under the License. 
import unittest - from unittest import mock from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_local_system.py b/tests/providers/google/cloud/transfers/test_gcs_to_local_system.py index 1b5447b2ec642..55762bab3836f 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_local_system.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_local_system.py @@ -21,8 +21,8 @@ from airflow.providers.google.cloud.example_dags.example_gcs_to_local import ( BUCKET, - PATH_TO_REMOTE_FILE, PATH_TO_LOCAL_FILE, + PATH_TO_REMOTE_FILE, ) from tests.providers.google.cloud.utils.gcp_authenticator import GCP_GCS_KEY from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py b/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py index 81653447cd93a..3cf3d537392f9 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py @@ -19,7 +19,6 @@ import os import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/google/cloud/transfers/test_local_to_gcs.py b/tests/providers/google/cloud/transfers/test_local_to_gcs.py index 5cf2509f06500..800d9d8a23d35 100644 --- a/tests/providers/google/cloud/transfers/test_local_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_local_to_gcs.py @@ -21,8 +21,8 @@ import os import unittest from glob import glob - from unittest import mock + import pytest from airflow.models.dag import DAG diff --git a/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py b/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py index f537aef205a6b..8f22ef4e26d88 100644 --- a/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow import PY38 diff --git a/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py b/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py index f96f830622d0c..d5690c63afa7d 100644 --- a/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py @@ -19,8 +19,8 @@ import datetime import decimal import unittest - from unittest import mock + from _mysql_exceptions import ProgrammingError from parameterized import parameterized diff --git a/tests/providers/google/cloud/transfers/test_mysql_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_mysql_to_gcs_system.py index 65e2bc7af5330..08aeac929ca8c 100644 --- a/tests/providers/google/cloud/transfers/test_mysql_to_gcs_system.py +++ b/tests/providers/google/cloud/transfers/test_mysql_to_gcs_system.py @@ -16,10 +16,10 @@ # specific language governing permissions and limitations # under the License. 
import pytest -from psycopg2 import ProgrammingError, OperationalError +from psycopg2 import OperationalError, ProgrammingError -from airflow.providers.mysql.hooks.mysql import MySqlHook from airflow.providers.google.cloud.example_dags.example_mysql_to_gcs import GCS_BUCKET +from airflow.providers.mysql.hooks.mysql import MySqlHook from tests.providers.google.cloud.utils.gcp_authenticator import GCP_GCS_KEY from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context diff --git a/tests/providers/google/cloud/transfers/test_s3_to_gcs.py b/tests/providers/google/cloud/transfers/test_s3_to_gcs.py index e59fa34a6cd71..990702fdbdb9a 100644 --- a/tests/providers/google/cloud/transfers/test_s3_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_s3_to_gcs.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator diff --git a/tests/providers/google/cloud/transfers/test_s3_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_s3_to_gcs_system.py index 66b4217452f69..6f25f641443fa 100644 --- a/tests/providers/google/cloud/transfers/test_s3_to_gcs_system.py +++ b/tests/providers/google/cloud/transfers/test_s3_to_gcs_system.py @@ -16,9 +16,10 @@ # under the License. import pytest + from airflow.providers.google.cloud.example_dags.example_s3_to_gcs import UPLOAD_FILE from tests.providers.google.cloud.utils.gcp_authenticator import GCP_GCS_KEY -from tests.test_utils.gcp_system_helpers import GoogleSystemTest, provide_gcp_context, CLOUD_DAG_FOLDER +from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context FILENAME = UPLOAD_FILE.split('/')[-1] diff --git a/tests/providers/google/cloud/transfers/test_salesforce_to_gcs.py b/tests/providers/google/cloud/transfers/test_salesforce_to_gcs.py index e52e8ddf39ded..5aed4b7535662 100644 --- a/tests/providers/google/cloud/transfers/test_salesforce_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_salesforce_to_gcs.py @@ -17,7 +17,6 @@ import unittest from collections import OrderedDict - from unittest import mock from airflow.providers.google.cloud.hooks.gcs import GCSHook diff --git a/tests/providers/google/cloud/transfers/test_salesforce_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_salesforce_to_gcs_system.py index ffdf49584956f..d5e45b6682bab 100644 --- a/tests/providers/google/cloud/transfers/test_salesforce_to_gcs_system.py +++ b/tests/providers/google/cloud/transfers/test_salesforce_to_gcs_system.py @@ -16,6 +16,7 @@ # under the License. 
import os + import pytest from tests.providers.google.cloud.utils.gcp_authenticator import GCP_BIGQUERY_KEY diff --git a/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py b/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py index 681b6fe657ec7..03b21d48c085e 100644 --- a/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py @@ -19,7 +19,6 @@ import os import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py index 889352d6357d4..b4bc5a3263fb9 100644 --- a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py @@ -17,9 +17,9 @@ import json import unittest +from unittest import mock from unittest.mock import Mock -from unittest import mock import unicodecsv as csv from airflow.providers.google.cloud.hooks.gcs import GCSHook diff --git a/tests/providers/google/cloud/utils/test_credentials_provider.py b/tests/providers/google/cloud/utils/test_credentials_provider.py index 0aa9770dd2a74..c49872c0c29a9 100644 --- a/tests/providers/google/cloud/utils/test_credentials_provider.py +++ b/tests/providers/google/cloud/utils/test_credentials_provider.py @@ -20,9 +20,9 @@ import re import unittest from io import StringIO +from unittest import mock from uuid import uuid4 -from unittest import mock from google.auth.environment_vars import CREDENTIALS from parameterized import parameterized diff --git a/tests/providers/google/firebase/hooks/test_firestore.py b/tests/providers/google/firebase/hooks/test_firestore.py index dea7c78691565..ee1117a5875bc 100644 --- a/tests/providers/google/firebase/hooks/test_firestore.py +++ b/tests/providers/google/firebase/hooks/test_firestore.py @@ -21,7 +21,6 @@ import unittest from typing import Optional from unittest import mock - from unittest.mock import PropertyMock from airflow.exceptions import AirflowException diff --git a/tests/providers/google/suite/hooks/test_drive.py b/tests/providers/google/suite/hooks/test_drive.py index ba0e31bdb9cd0..37f224734a8ba 100644 --- a/tests/providers/google/suite/hooks/test_drive.py +++ b/tests/providers/google/suite/hooks/test_drive.py @@ -17,7 +17,6 @@ # under the License. # import unittest - from unittest import mock from airflow.providers.google.suite.hooks.drive import GoogleDriveHook diff --git a/tests/providers/google/suite/hooks/test_sheets.py b/tests/providers/google/suite/hooks/test_sheets.py index adcbb3de28762..56f5f5e93a495 100644 --- a/tests/providers/google/suite/hooks/test_sheets.py +++ b/tests/providers/google/suite/hooks/test_sheets.py @@ -21,7 +21,6 @@ """ import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/grpc/hooks/test_grpc.py b/tests/providers/grpc/hooks/test_grpc.py index 07d432980821b..4ec9267c4670a 100644 --- a/tests/providers/grpc/hooks/test_grpc.py +++ b/tests/providers/grpc/hooks/test_grpc.py @@ -17,7 +17,6 @@ import unittest from io import StringIO - from unittest import mock from airflow.exceptions import AirflowConfigException diff --git a/tests/providers/grpc/operators/test_grpc.py b/tests/providers/grpc/operators/test_grpc.py index 5e96249b5bc24..1bbed8b789c0c 100644 --- a/tests/providers/grpc/operators/test_grpc.py +++ b/tests/providers/grpc/operators/test_grpc.py @@ -16,7 +16,6 @@ # under the License. 
import unittest - from unittest import mock from airflow.providers.grpc.operators.grpc import GrpcOperator diff --git a/tests/providers/http/hooks/test_http.py b/tests/providers/http/hooks/test_http.py index ae666ef4ccc6f..6824d64ae6284 100644 --- a/tests/providers/http/hooks/test_http.py +++ b/tests/providers/http/hooks/test_http.py @@ -17,8 +17,8 @@ # under the License. import json import unittest - from unittest import mock + import requests import requests_mock import tenacity diff --git a/tests/providers/http/operators/test_http.py b/tests/providers/http/operators/test_http.py index d9c7c2ba11f96..75ca9dc91b9bc 100644 --- a/tests/providers/http/operators/test_http.py +++ b/tests/providers/http/operators/test_http.py @@ -17,8 +17,8 @@ # under the License. import unittest - from unittest import mock + import requests_mock from airflow.exceptions import AirflowException diff --git a/tests/providers/http/sensors/test_http.py b/tests/providers/http/sensors/test_http.py index 046390fac846b..04b6c4946c758 100644 --- a/tests/providers/http/sensors/test_http.py +++ b/tests/providers/http/sensors/test_http.py @@ -16,9 +16,9 @@ # specific language governing permissions and limitations # under the License. import unittest +from unittest import mock from unittest.mock import patch -from unittest import mock import requests from airflow.exceptions import AirflowException, AirflowSensorTimeout diff --git a/tests/providers/microsoft/azure/hooks/test_azure_batch.py b/tests/providers/microsoft/azure/hooks/test_azure_batch.py index 6d9188fae4869..9549e5795a2e5 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_batch.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_batch.py @@ -18,8 +18,8 @@ # import json import unittest - from unittest import mock + from azure.batch import BatchServiceClient, models as batch_models from airflow.models import Connection diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py index 1703f4092bfc5..701abc5a34143 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py @@ -22,8 +22,8 @@ import logging import unittest import uuid - from unittest import mock + from azure.cosmos.cosmos_client import CosmosClient from airflow.exceptions import AirflowException diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py index 90a608e1b7464..8fa0feea54146 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py @@ -18,7 +18,6 @@ # import json import unittest - from unittest import mock from airflow.models import Connection diff --git a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py index d298ccd8d19e9..a9a6298ddf641 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py @@ -27,9 +27,9 @@ import json import unittest - from unittest import mock -from azure.storage.file import File, Directory + +from azure.storage.file import Directory, File from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.azure_fileshare import AzureFileShareHook diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py index 
d7cd0dc4bdaca..12a784790c73e 100644 --- a/tests/providers/microsoft/azure/hooks/test_wasb.py +++ b/tests/providers/microsoft/azure/hooks/test_wasb.py @@ -21,7 +21,6 @@ import json import unittest from collections import namedtuple - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/microsoft/azure/operators/test_adls_list.py b/tests/providers/microsoft/azure/operators/test_adls_list.py index c5be2b5ded627..9b0a5c2f7ed09 100644 --- a/tests/providers/microsoft/azure/operators/test_adls_list.py +++ b/tests/providers/microsoft/azure/operators/test_adls_list.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.providers.microsoft.azure.operators.adls_list import AzureDataLakeStorageListOperator diff --git a/tests/providers/microsoft/azure/operators/test_azure_batch.py b/tests/providers/microsoft/azure/operators/test_azure_batch.py index 18702f568db0f..48e5a7a812383 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_batch.py +++ b/tests/providers/microsoft/azure/operators/test_azure_batch.py @@ -18,7 +18,6 @@ # import json import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py index 22c495249d86c..90f9ece7c3435 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py +++ b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py @@ -19,9 +19,9 @@ import unittest from collections import namedtuple +from unittest import mock from unittest.mock import MagicMock -from unittest import mock from azure.mgmt.containerinstance.models import ContainerState, Event from airflow.exceptions import AirflowException diff --git a/tests/providers/microsoft/azure/operators/test_azure_cosmos.py b/tests/providers/microsoft/azure/operators/test_azure_cosmos.py index ac1299012c1c9..26144406cb892 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_cosmos.py +++ b/tests/providers/microsoft/azure/operators/test_azure_cosmos.py @@ -21,7 +21,6 @@ import json import unittest import uuid - from unittest import mock from airflow.models import Connection diff --git a/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py b/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py index 8310f72bedf67..04df34cd89ca6 100644 --- a/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py +++ b/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py @@ -19,7 +19,6 @@ import datetime import unittest - from unittest import mock from airflow.models.dag import DAG diff --git a/tests/providers/microsoft/azure/sensors/test_wasb.py b/tests/providers/microsoft/azure/sensors/test_wasb.py index 8b1c6a1d44509..5aaec19e5e203 100644 --- a/tests/providers/microsoft/azure/sensors/test_wasb.py +++ b/tests/providers/microsoft/azure/sensors/test_wasb.py @@ -19,7 +19,6 @@ import datetime import unittest - from unittest import mock from airflow.models.dag import DAG diff --git a/tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py b/tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py index cc1d7b1d51606..b9517a894c727 100644 --- a/tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py +++ b/tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py @@ -15,12 +15,9 @@ # specific 
language governing permissions and limitations # under the License. import unittest - from unittest import mock -from airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs import ( - AzureBlobStorageToGCSOperator, -) +from airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs import AzureBlobStorageToGCSOperator WASB_CONN_ID = "wasb_default" GCP_CONN_ID = "google_cloud_default" diff --git a/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py b/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py index d7b708ba8b6a7..5a4f14c729259 100644 --- a/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py +++ b/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py @@ -19,7 +19,6 @@ import datetime import unittest - from unittest import mock from airflow.models.dag import DAG diff --git a/tests/providers/microsoft/azure/transfers/test_local_to_adls.py b/tests/providers/microsoft/azure/transfers/test_local_to_adls.py index d79188dd6962f..0bcc371393a35 100644 --- a/tests/providers/microsoft/azure/transfers/test_local_to_adls.py +++ b/tests/providers/microsoft/azure/transfers/test_local_to_adls.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.exceptions import AirflowException diff --git a/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py b/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py index 49f591f4278e7..8d16878b2d45d 100644 --- a/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py +++ b/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py @@ -19,9 +19,9 @@ import os import unittest from tempfile import TemporaryDirectory - from unittest import mock from unittest.mock import MagicMock + import unicodecsv as csv from airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake import ( diff --git a/tests/providers/microsoft/mssql/hooks/test_mssql.py b/tests/providers/microsoft/mssql/hooks/test_mssql.py index 0065f9438ef41..91aed5d76fd44 100644 --- a/tests/providers/microsoft/mssql/hooks/test_mssql.py +++ b/tests/providers/microsoft/mssql/hooks/test_mssql.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow import PY38 diff --git a/tests/providers/microsoft/mssql/operators/test_mssql.py b/tests/providers/microsoft/mssql/operators/test_mssql.py index 4e4c6c272d1f8..1304b123985db 100644 --- a/tests/providers/microsoft/mssql/operators/test_mssql.py +++ b/tests/providers/microsoft/mssql/operators/test_mssql.py @@ -17,7 +17,6 @@ # under the License. 
import unittest - from unittest import mock from airflow import PY38 diff --git a/tests/providers/openfaas/hooks/test_openfaas.py b/tests/providers/openfaas/hooks/test_openfaas.py index 10332e36f9051..5c68def0c5c1c 100644 --- a/tests/providers/openfaas/hooks/test_openfaas.py +++ b/tests/providers/openfaas/hooks/test_openfaas.py @@ -18,8 +18,8 @@ # import unittest - from unittest import mock + import requests_mock from airflow.exceptions import AirflowException diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py index 2176306895d8d..d27ca44d68704 100644 --- a/tests/providers/oracle/hooks/test_oracle.py +++ b/tests/providers/oracle/hooks/test_oracle.py @@ -19,8 +19,8 @@ import json import unittest from datetime import datetime - from unittest import mock + import numpy from airflow.models import Connection diff --git a/tests/providers/oracle/operators/test_oracle.py b/tests/providers/oracle/operators/test_oracle.py index bbc3fc9221c98..2c4d984647e8c 100644 --- a/tests/providers/oracle/operators/test_oracle.py +++ b/tests/providers/oracle/operators/test_oracle.py @@ -16,7 +16,6 @@ # under the License. import unittest - from unittest import mock from airflow.providers.oracle.hooks.oracle import OracleHook diff --git a/tests/providers/oracle/transfers/test_oracle_to_oracle.py b/tests/providers/oracle/transfers/test_oracle_to_oracle.py index f7536afe6d1a5..a036a3e990df0 100644 --- a/tests/providers/oracle/transfers/test_oracle_to_oracle.py +++ b/tests/providers/oracle/transfers/test_oracle_to_oracle.py @@ -18,7 +18,6 @@ import unittest from unittest import mock - from unittest.mock import MagicMock from airflow.providers.oracle.transfers.oracle_to_oracle import OracleToOracleOperator diff --git a/tests/providers/pagerduty/hooks/test_pagerduty.py b/tests/providers/pagerduty/hooks/test_pagerduty.py index 50a115437ff8c..95e9719cd10bb 100644 --- a/tests/providers/pagerduty/hooks/test_pagerduty.py +++ b/tests/providers/pagerduty/hooks/test_pagerduty.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. 
import unittest + import requests_mock from airflow.models import Connection diff --git a/tests/providers/plexus/hooks/test_plexus.py b/tests/providers/plexus/hooks/test_plexus.py index f90d35b90da27..3419777af4119 100644 --- a/tests/providers/plexus/hooks/test_plexus.py +++ b/tests/providers/plexus/hooks/test_plexus.py @@ -21,6 +21,7 @@ import arrow import pytest from requests.exceptions import Timeout + from airflow.exceptions import AirflowException from airflow.providers.plexus.hooks.plexus import PlexusHook diff --git a/tests/providers/plexus/operators/test_job.py b/tests/providers/plexus/operators/test_job.py index 44873ff4f2521..7c902332fc5b9 100644 --- a/tests/providers/plexus/operators/test_job.py +++ b/tests/providers/plexus/operators/test_job.py @@ -17,8 +17,10 @@ from unittest import mock from unittest.mock import Mock + import pytest from requests.exceptions import Timeout + from airflow.exceptions import AirflowException from airflow.providers.plexus.operators.job import PlexusJobOperator diff --git a/tests/providers/qubole/operators/test_qubole_check.py b/tests/providers/qubole/operators/test_qubole_check.py index e3f09a56cc89f..227f0490a5d35 100644 --- a/tests/providers/qubole/operators/test_qubole_check.py +++ b/tests/providers/qubole/operators/test_qubole_check.py @@ -18,8 +18,8 @@ # import unittest from datetime import datetime - from unittest import mock + from qds_sdk.commands import HiveCommand from airflow.exceptions import AirflowException diff --git a/tests/providers/samba/hooks/test_samba.py b/tests/providers/samba/hooks/test_samba.py index e5f294d93a0c7..86d9c764207d8 100644 --- a/tests/providers/samba/hooks/test_samba.py +++ b/tests/providers/samba/hooks/test_samba.py @@ -17,9 +17,9 @@ # under the License. import unittest +from unittest import mock from unittest.mock import call -from unittest import mock import smbclient from airflow.exceptions import AirflowException diff --git a/tests/providers/sendgrid/utils/test_emailer.py b/tests/providers/sendgrid/utils/test_emailer.py index df872013447ba..bb1a5f2cf5f20 100644 --- a/tests/providers/sendgrid/utils/test_emailer.py +++ b/tests/providers/sendgrid/utils/test_emailer.py @@ -21,7 +21,6 @@ import os import tempfile import unittest - from unittest import mock from airflow.providers.sendgrid.utils.emailer import send_email diff --git a/tests/providers/singularity/operators/test_singularity.py b/tests/providers/singularity/operators/test_singularity.py index 324ca35fd9df1..71ab5ace64db9 100644 --- a/tests/providers/singularity/operators/test_singularity.py +++ b/tests/providers/singularity/operators/test_singularity.py @@ -17,8 +17,8 @@ # under the License. import unittest - from unittest import mock + from parameterized import parameterized from spython.instance import Instance diff --git a/tests/providers/slack/hooks/test_slack.py b/tests/providers/slack/hooks/test_slack.py index 4ca0060f8fbaf..6ebec3a38386d 100644 --- a/tests/providers/slack/hooks/test_slack.py +++ b/tests/providers/slack/hooks/test_slack.py @@ -17,8 +17,8 @@ # under the License. 
import unittest - from unittest import mock + from slack.errors import SlackApiError from airflow.exceptions import AirflowException diff --git a/tests/providers/slack/operators/test_slack.py b/tests/providers/slack/operators/test_slack.py index 39e31a3c9bc99..e505282c033b3 100644 --- a/tests/providers/slack/operators/test_slack.py +++ b/tests/providers/slack/operators/test_slack.py @@ -18,7 +18,6 @@ import json import unittest - from unittest import mock from airflow.providers.slack.operators.slack import SlackAPIFileOperator, SlackAPIPostOperator diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py index 9d0d9ff6fa9c4..3bec67174ce71 100644 --- a/tests/providers/snowflake/operators/test_snowflake.py +++ b/tests/providers/snowflake/operators/test_snowflake.py @@ -17,7 +17,6 @@ # under the License. import unittest - from unittest import mock from airflow.models.dag import DAG diff --git a/tests/providers/ssh/hooks/test_ssh.py b/tests/providers/ssh/hooks/test_ssh.py index be52138daedab..d09a6678edc70 100644 --- a/tests/providers/ssh/hooks/test_ssh.py +++ b/tests/providers/ssh/hooks/test_ssh.py @@ -16,13 +16,13 @@ # specific language governing permissions and limitations # under the License. import json +import random +import string import unittest from io import StringIO - from typing import Optional -import random -import string from unittest import mock + import paramiko from airflow.models import Connection diff --git a/tests/secrets/test_local_filesystem.py b/tests/secrets/test_local_filesystem.py index bc6d0dea0a079..5ee4813f8285b 100644 --- a/tests/secrets/test_local_filesystem.py +++ b/tests/secrets/test_local_filesystem.py @@ -107,10 +107,13 @@ def test_missing_file(self, mock_exists): @parameterized.expand( ( ("KEY: AAA", {"KEY": "AAA"}), - (""" + ( + """ KEY_A: AAA KEY_B: BBB - """, {"KEY_A": "AAA", "KEY_B": "BBB"}), + """, + {"KEY_A": "AAA", "KEY_B": "BBB"}, + ), ) ) def test_yaml_file_should_load_variables(self, file_content, expected_variables): @@ -141,8 +144,7 @@ def test_env_file_should_load_connection(self, file_content, expected_connection with mock_local_file(file_content): connection_by_conn_id = local_filesystem.load_connections_dict("a.env") connection_uris_by_conn_id = { - conn_id: connection.get_uri() - for conn_id, connection in connection_by_conn_id.items() + conn_id: connection.get_uri() for conn_id, connection in connection_by_conn_id.items() } self.assertEqual(expected_connection_uris, connection_uris_by_conn_id) @@ -170,8 +172,7 @@ def test_json_file_should_load_connection(self, file_content, expected_connectio with mock_local_file(json.dumps(file_content)): connections_by_conn_id = local_filesystem.load_connections_dict("a.json") connection_uris_by_conn_id = { - conn_id: connection.get_uri() - for conn_id, connection in connections_by_conn_id.items() + conn_id: connection.get_uri() for conn_id, connection in connections_by_conn_id.items() } self.assertEqual(expected_connection_uris, connection_uris_by_conn_id) @@ -203,7 +204,8 @@ def test_missing_file(self, mock_exists): @parameterized.expand( ( ("""CONN_A: 'mysql://host_a'""", {"CONN_A": "mysql://host_a"}), - (""" + ( + """ conn_a: mysql://hosta conn_b: conn_type: scheme @@ -216,18 +218,22 @@ def test_missing_file(self, mock_exists): extra__google_cloud_platform__keyfile_dict: a: b extra__google_cloud_platform__keyfile_path: asaa""", - {"conn_a": "mysql://hosta", - "conn_b": ''.join("""scheme://Login:None@host:1234/lschema? 
+ { + "conn_a": "mysql://hosta", + "conn_b": ''.join( + """scheme://Login:None@host:1234/lschema? extra__google_cloud_platform__keyfile_dict=%7B%27a%27%3A+%27b%27%7D - &extra__google_cloud_platform__keyfile_path=asaa""".split())}), + &extra__google_cloud_platform__keyfile_path=asaa""".split() + ), + }, + ), ) ) def test_yaml_file_should_load_connection(self, file_content, expected_connection_uris): with mock_local_file(file_content): connections_by_conn_id = local_filesystem.load_connections_dict("a.yaml") connection_uris_by_conn_id = { - conn_id: connection.get_uri() - for conn_id, connection in connections_by_conn_id.items() + conn_id: connection.get_uri() for conn_id, connection in connections_by_conn_id.items() } self.assertEqual(expected_connection_uris, connection_uris_by_conn_id) @@ -295,7 +301,8 @@ def test_yaml_file_should_load_connection_extras(self, file_content, expected_ex @parameterized.expand( ( - ("""conn_c: + ( + """conn_c: conn_type: scheme host: host schema: lschema @@ -307,7 +314,9 @@ def test_yaml_file_should_load_connection_extras(self, file_content, expected_ex extra_dejson: aws_conn_id: bbb region_name: ccc - """, "The extra and extra_dejson parameters are mutually exclusive."), + """, + "The extra and extra_dejson parameters are mutually exclusive.", + ), ) ) def test_yaml_invalid_extra(self, file_content, expected_message): @@ -316,9 +325,7 @@ def test_yaml_invalid_extra(self, file_content, expected_message): local_filesystem.load_connections_dict("a.yaml") @parameterized.expand( - ( - "CONN_ID=mysql://host_1/\nCONN_ID=mysql://host_2/", - ), + ("CONN_ID=mysql://host_1/\nCONN_ID=mysql://host_2/",), ) def test_ensure_unique_connection_env(self, file_content): with mock_local_file(file_content): @@ -327,12 +334,8 @@ def test_ensure_unique_connection_env(self, file_content): @parameterized.expand( ( - ( - {"CONN_ID": ["mysql://host_1", "mysql://host_2"]}, - ), - ( - {"CONN_ID": [{"uri": "mysql://host_1"}, {"uri": "mysql://host_2"}]}, - ), + ({"CONN_ID": ["mysql://host_1", "mysql://host_2"]},), + ({"CONN_ID": [{"uri": "mysql://host_1"}, {"uri": "mysql://host_2"}]},), ) ) def test_ensure_unique_connection_json(self, file_content): @@ -342,10 +345,12 @@ def test_ensure_unique_connection_json(self, file_content): @parameterized.expand( ( - (""" + ( + """ conn_a: - mysql://hosta - - mysql://hostb"""), + - mysql://hostb""" + ), ), ) def test_ensure_unique_connection_yaml(self, file_content): diff --git a/tests/secrets/test_secrets.py b/tests/secrets/test_secrets.py index 53a11a14b6c90..ed13adec61dac 100644 --- a/tests/secrets/test_secrets.py +++ b/tests/secrets/test_secrets.py @@ -42,11 +42,15 @@ def test_get_connections_first_try(self, mock_env_get, mock_meta_get): mock_env_get.assert_called_once_with(conn_id="fake_conn_id") mock_meta_get.not_called() - @conf_vars({ - ("secrets", "backend"): - "airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend", - ("secrets", "backend_kwargs"): '{"connections_prefix": "/airflow", "profile_name": null}', - }) + @conf_vars( + { + ( + "secrets", + "backend", + ): "airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend", + ("secrets", "backend_kwargs"): '{"connections_prefix": "/airflow", "profile_name": null}', + } + ) def test_initialize_secrets_backends(self): backends = initialize_secrets_backends() backend_classes = [backend.__class__.__name__ for backend in backends] @@ -54,30 +58,44 @@ def test_initialize_secrets_backends(self): self.assertEqual(3, len(backends)) 
self.assertIn('SystemsManagerParameterStoreBackend', backend_classes) - @conf_vars({ - ("secrets", "backend"): - "airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend", - ("secrets", "backend_kwargs"): '{"use_ssl": false}', - }) + @conf_vars( + { + ( + "secrets", + "backend", + ): "airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend", + ("secrets", "backend_kwargs"): '{"use_ssl": false}', + } + ) def test_backends_kwargs(self): backends = initialize_secrets_backends() systems_manager = [ - backend for backend in backends + backend + for backend in backends if backend.__class__.__name__ == 'SystemsManagerParameterStoreBackend' ][0] self.assertEqual(systems_manager.kwargs, {'use_ssl': False}) - @conf_vars({ - ("secrets", "backend"): - "airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend", - ("secrets", "backend_kwargs"): '{"connections_prefix": "/airflow", "profile_name": null}', - }) - @mock.patch.dict('os.environ', { - 'AIRFLOW_CONN_TEST_MYSQL': 'mysql://airflow:airflow@host:5432/airflow', - }) - @mock.patch("airflow.providers.amazon.aws.secrets.systems_manager." - "SystemsManagerParameterStoreBackend.get_conn_uri") + @conf_vars( + { + ( + "secrets", + "backend", + ): "airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend", + ("secrets", "backend_kwargs"): '{"connections_prefix": "/airflow", "profile_name": null}', + } + ) + @mock.patch.dict( + 'os.environ', + { + 'AIRFLOW_CONN_TEST_MYSQL': 'mysql://airflow:airflow@host:5432/airflow', + }, + ) + @mock.patch( + "airflow.providers.amazon.aws.secrets.systems_manager." + "SystemsManagerParameterStoreBackend.get_conn_uri" + ) def test_backend_fallback_to_env_var(self, mock_get_uri): mock_get_uri.return_value = None @@ -94,7 +112,6 @@ def test_backend_fallback_to_env_var(self, mock_get_uri): class TestVariableFromSecrets(unittest.TestCase): - def setUp(self) -> None: clear_db_variables() diff --git a/tests/secrets/test_secrets_backends.py b/tests/secrets/test_secrets_backends.py index ced829a96ce23..5318b8b1212d9 100644 --- a/tests/secrets/test_secrets_backends.py +++ b/tests/secrets/test_secrets_backends.py @@ -36,14 +36,11 @@ def __init__(self, conn_id, variation: str): self.conn_id = conn_id self.var_name = "AIRFLOW_CONN_" + self.conn_id.upper() self.host = f"host_{variation}.com" - self.conn_uri = ( - "mysql://user:pw@" + self.host + "/schema?extra1=val%2B1&extra2=val%2B2" - ) + self.conn_uri = "mysql://user:pw@" + self.host + "/schema?extra1=val%2B1&extra2=val%2B2" self.conn = Connection(conn_id=self.conn_id, uri=self.conn_uri) class TestBaseSecretsBackend(unittest.TestCase): - def setUp(self) -> None: clear_db_variables() @@ -51,10 +48,12 @@ def tearDown(self) -> None: clear_db_connections() clear_db_variables() - @parameterized.expand([ - ('default', {"path_prefix": "PREFIX", "secret_id": "ID"}, "PREFIX/ID"), - ('with_sep', {"path_prefix": "PREFIX", "secret_id": "ID", "sep": "-"}, "PREFIX-ID") - ]) + @parameterized.expand( + [ + ('default', {"path_prefix": "PREFIX", "secret_id": "ID"}, "PREFIX/ID"), + ('with_sep', {"path_prefix": "PREFIX", "secret_id": "ID", "sep": "-"}, "PREFIX-ID"), + ] + ) def test_build_path(self, _, kwargs, output): build_path = BaseSecretsBackend.build_path self.assertEqual(build_path(**kwargs), output) @@ -78,14 +77,15 @@ def test_connection_metastore_secrets_backend(self): metastore_backend = MetastoreBackend() conn_list = metastore_backend.get_connections("sample_2") 
host_list = {x.host for x in conn_list} - self.assertEqual( - {sample_conn_2.host.lower()}, set(host_list) - ) + self.assertEqual({sample_conn_2.host.lower()}, set(host_list)) - @mock.patch.dict('os.environ', { - 'AIRFLOW_VAR_HELLO': 'World', - 'AIRFLOW_VAR_EMPTY_STR': '', - }) + @mock.patch.dict( + 'os.environ', + { + 'AIRFLOW_VAR_HELLO': 'World', + 'AIRFLOW_VAR_EMPTY_STR': '', + }, + ) def test_variable_env_secrets_backend(self): env_secrets_backend = EnvironmentVariablesBackend() variable_value = env_secrets_backend.get_variable(key="hello") diff --git a/tests/security/test_kerberos.py b/tests/security/test_kerberos.py index 93e880990d90a..3f08540f38491 100644 --- a/tests/security/test_kerberos.py +++ b/tests/security/test_kerberos.py @@ -30,16 +30,18 @@ @unittest.skipIf(KRB5_KTNAME is None, 'Skipping Kerberos API tests due to missing KRB5_KTNAME') class TestKerberos(unittest.TestCase): def setUp(self): - self.args = Namespace(keytab=KRB5_KTNAME, principal=None, pid=None, - daemon=None, stdout=None, stderr=None, log_file=None) + self.args = Namespace( + keytab=KRB5_KTNAME, principal=None, pid=None, daemon=None, stdout=None, stderr=None, log_file=None + ) @conf_vars({('kerberos', 'keytab'): KRB5_KTNAME}) def test_renew_from_kt(self): """ We expect no result, but a successful run. No more TypeError """ - self.assertIsNone(renew_from_kt(principal=self.args.principal, # pylint: disable=no-member - keytab=self.args.keytab)) + self.assertIsNone( + renew_from_kt(principal=self.args.principal, keytab=self.args.keytab) # pylint: disable=no-member + ) @conf_vars({('kerberos', 'keytab'): ''}) def test_args_from_cli(self): @@ -49,13 +51,14 @@ def test_args_from_cli(self): self.args.keytab = "test_keytab" with self.assertRaises(SystemExit) as err: - renew_from_kt(principal=self.args.principal, # pylint: disable=no-member - keytab=self.args.keytab) + renew_from_kt(principal=self.args.principal, keytab=self.args.keytab) # pylint: disable=no-member with self.assertLogs(kerberos.log) as log: self.assertIn( 'kinit: krb5_init_creds_set_keytab: Failed to find ' 'airflow@LUPUS.GRIDDYNAMICS.NET in keytab FILE:{} ' - '(unknown enctype)'.format(self.args.keytab), log.output) + '(unknown enctype)'.format(self.args.keytab), + log.output, + ) self.assertEqual(err.exception.code, 1) diff --git a/tests/sensors/test_base_sensor.py b/tests/sensors/test_base_sensor.py index cc0bbef359524..be3ee694f9876 100644 --- a/tests/sensors/test_base_sensor.py +++ b/tests/sensors/test_base_sensor.py @@ -59,10 +59,7 @@ def clean_db(): db.clear_db_xcom() def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG(TEST_DAG_ID, default_args=args) self.clean_db() @@ -74,7 +71,7 @@ def _make_dag_run(self): run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) def _make_sensor(self, return_value, task_id=SENSOR_OP, **kwargs): @@ -86,17 +83,9 @@ def _make_sensor(self, return_value, task_id=SENSOR_OP, **kwargs): if timeout not in kwargs: kwargs[timeout] = 0 - sensor = DummySensor( - task_id=task_id, - return_value=return_value, - dag=self.dag, - **kwargs - ) + sensor = DummySensor(task_id=task_id, return_value=return_value, dag=self.dag, **kwargs) - dummy_op = DummyOperator( - task_id=DUMMY_OP, - dag=self.dag - ) + dummy_op = DummyOperator(task_id=DUMMY_OP, dag=self.dag) dummy_op.set_upstream(sensor) return sensor @@ -146,10 +135,8 @@ def test_soft_fail(self): def 
test_soft_fail_with_retries(self): sensor = self._make_sensor( - return_value=False, - soft_fail=True, - retries=1, - retry_delay=timedelta(milliseconds=1)) + return_value=False, soft_fail=True, retries=1, retry_delay=timedelta(milliseconds=1) + ) dr = self._make_dag_run() # first run fails and task instance is marked up to retry @@ -175,11 +162,7 @@ def test_soft_fail_with_retries(self): self.assertEqual(ti.state, State.NONE) def test_ok_with_reschedule(self): - sensor = self._make_sensor( - return_value=None, - poke_interval=10, - timeout=25, - mode='reschedule') + sensor = self._make_sensor(return_value=None, poke_interval=10, timeout=25, mode='reschedule') sensor.poke = Mock(side_effect=[False, False, True]) dr = self._make_dag_run() @@ -199,8 +182,9 @@ def test_ok_with_reschedule(self): task_reschedules = TaskReschedule.find_for_task_instance(ti) self.assertEqual(len(task_reschedules), 1) self.assertEqual(task_reschedules[0].start_date, date1) - self.assertEqual(task_reschedules[0].reschedule_date, - date1 + timedelta(seconds=sensor.poke_interval)) + self.assertEqual( + task_reschedules[0].reschedule_date, date1 + timedelta(seconds=sensor.poke_interval) + ) if ti.task_id == DUMMY_OP: self.assertEqual(ti.state, State.NONE) @@ -220,8 +204,9 @@ def test_ok_with_reschedule(self): task_reschedules = TaskReschedule.find_for_task_instance(ti) self.assertEqual(len(task_reschedules), 2) self.assertEqual(task_reschedules[1].start_date, date2) - self.assertEqual(task_reschedules[1].reschedule_date, - date2 + timedelta(seconds=sensor.poke_interval)) + self.assertEqual( + task_reschedules[1].reschedule_date, date2 + timedelta(seconds=sensor.poke_interval) + ) if ti.task_id == DUMMY_OP: self.assertEqual(ti.state, State.NONE) @@ -240,11 +225,7 @@ def test_ok_with_reschedule(self): self.assertEqual(ti.state, State.NONE) def test_fail_with_reschedule(self): - sensor = self._make_sensor( - return_value=False, - poke_interval=10, - timeout=5, - mode='reschedule') + sensor = self._make_sensor(return_value=False, poke_interval=10, timeout=5, mode='reschedule') dr = self._make_dag_run() # first poke returns False and task is re-scheduled @@ -274,11 +255,8 @@ def test_fail_with_reschedule(self): def test_soft_fail_with_reschedule(self): sensor = self._make_sensor( - return_value=False, - poke_interval=10, - timeout=5, - soft_fail=True, - mode='reschedule') + return_value=False, poke_interval=10, timeout=5, soft_fail=True, mode='reschedule' + ) dr = self._make_dag_run() # first poke returns False and task is re-scheduled @@ -312,7 +290,8 @@ def test_ok_with_reschedule_and_retry(self): timeout=5, retries=1, retry_delay=timedelta(seconds=10), - mode='reschedule') + mode='reschedule', + ) sensor.poke = Mock(side_effect=[False, False, False, True]) dr = self._make_dag_run() @@ -329,8 +308,9 @@ def test_ok_with_reschedule_and_retry(self): task_reschedules = TaskReschedule.find_for_task_instance(ti) self.assertEqual(len(task_reschedules), 1) self.assertEqual(task_reschedules[0].start_date, date1) - self.assertEqual(task_reschedules[0].reschedule_date, - date1 + timedelta(seconds=sensor.poke_interval)) + self.assertEqual( + task_reschedules[0].reschedule_date, date1 + timedelta(seconds=sensor.poke_interval) + ) self.assertEqual(task_reschedules[0].try_number, 1) if ti.task_id == DUMMY_OP: self.assertEqual(ti.state, State.NONE) @@ -361,8 +341,9 @@ def test_ok_with_reschedule_and_retry(self): task_reschedules = TaskReschedule.find_for_task_instance(ti) self.assertEqual(len(task_reschedules), 1) 
self.assertEqual(task_reschedules[0].start_date, date3) - self.assertEqual(task_reschedules[0].reschedule_date, - date3 + timedelta(seconds=sensor.poke_interval)) + self.assertEqual( + task_reschedules[0].reschedule_date, date3 + timedelta(seconds=sensor.poke_interval) + ) self.assertEqual(task_reschedules[0].try_number, 2) if ti.task_id == DUMMY_OP: self.assertEqual(ti.state, State.NONE) @@ -391,22 +372,20 @@ def test_should_not_include_ready_to_reschedule_dep_in_poke_mode(self): def test_invalid_mode(self): with self.assertRaises(AirflowException): - self._make_sensor( - return_value=True, - mode='foo') + self._make_sensor(return_value=True, mode='foo') def test_ok_with_custom_reschedule_exception(self): - sensor = self._make_sensor( - return_value=None, - mode='reschedule') + sensor = self._make_sensor(return_value=None, mode='reschedule') date1 = timezone.utcnow() date2 = date1 + timedelta(seconds=60) date3 = date1 + timedelta(seconds=120) - sensor.poke = Mock(side_effect=[ - AirflowRescheduleException(date2), - AirflowRescheduleException(date3), - True, - ]) + sensor.poke = Mock( + side_effect=[ + AirflowRescheduleException(date2), + AirflowRescheduleException(date3), + True, + ] + ) dr = self._make_dag_run() # first poke returns False and task is re-scheduled @@ -455,11 +434,7 @@ def test_ok_with_custom_reschedule_exception(self): self.assertEqual(ti.state, State.NONE) def test_reschedule_with_test_mode(self): - sensor = self._make_sensor( - return_value=None, - poke_interval=10, - timeout=25, - mode='reschedule') + sensor = self._make_sensor(return_value=None, poke_interval=10, timeout=25, mode='reschedule') sensor.poke = Mock(side_effect=[False]) dr = self._make_dag_run() @@ -467,9 +442,7 @@ def test_reschedule_with_test_mode(self): date1 = timezone.utcnow() with freeze_time(date1): for date in self.dag.date_range(DEFAULT_DATE, end_date=DEFAULT_DATE): - TaskInstance(sensor, date).run( - ignore_ti_state=True, - test_mode=True) + TaskInstance(sensor, date).run(ignore_ti_state=True, test_mode=True) tis = dr.get_task_instances() self.assertEqual(len(tis), 2) for ti in tis: @@ -491,20 +464,20 @@ def test_sensor_with_invalid_poke_interval(self): task_id='test_sensor_task_1', return_value=None, poke_interval=negative_poke_interval, - timeout=25) + timeout=25, + ) with self.assertRaises(AirflowException): self._make_sensor( task_id='test_sensor_task_2', return_value=None, poke_interval=non_number_poke_interval, - timeout=25) + timeout=25, + ) self._make_sensor( - task_id='test_sensor_task_3', - return_value=None, - poke_interval=positive_poke_interval, - timeout=25) + task_id='test_sensor_task_3', return_value=None, poke_interval=positive_poke_interval, timeout=25 + ) def test_sensor_with_invalid_timeout(self): negative_timeout = -25 @@ -512,30 +485,20 @@ def test_sensor_with_invalid_timeout(self): positive_timeout = 25 with self.assertRaises(AirflowException): self._make_sensor( - task_id='test_sensor_task_1', - return_value=None, - poke_interval=10, - timeout=negative_timeout) + task_id='test_sensor_task_1', return_value=None, poke_interval=10, timeout=negative_timeout + ) with self.assertRaises(AirflowException): self._make_sensor( - task_id='test_sensor_task_2', - return_value=None, - poke_interval=10, - timeout=non_number_timeout) + task_id='test_sensor_task_2', return_value=None, poke_interval=10, timeout=non_number_timeout + ) self._make_sensor( - task_id='test_sensor_task_3', - return_value=None, - poke_interval=10, - timeout=positive_timeout) + task_id='test_sensor_task_3', 
return_value=None, poke_interval=10, timeout=positive_timeout + ) def test_sensor_with_exponential_backoff_off(self): - sensor = self._make_sensor( - return_value=None, - poke_interval=5, - timeout=60, - exponential_backoff=False) + sensor = self._make_sensor(return_value=None, poke_interval=5, timeout=60, exponential_backoff=False) started_at = timezone.utcnow() - timedelta(seconds=10) self.assertEqual(sensor._get_next_poke_interval(started_at, 1), sensor.poke_interval) @@ -543,11 +506,7 @@ def test_sensor_with_exponential_backoff_off(self): def test_sensor_with_exponential_backoff_on(self): - sensor = self._make_sensor( - return_value=None, - poke_interval=5, - timeout=60, - exponential_backoff=True) + sensor = self._make_sensor(return_value=None, poke_interval=5, timeout=60, exponential_backoff=True) with patch('airflow.utils.timezone.utcnow') as mock_utctime: mock_utctime.return_value = DEFAULT_DATE @@ -582,22 +541,14 @@ def change_mode(self, mode): class TestPokeModeOnly(unittest.TestCase): - def setUp(self): - self.dagbag = DagBag( - dag_folder=DEV_NULL, - include_examples=True - ) - self.args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True) + self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG(TEST_DAG_ID, default_args=self.args) def test_poke_mode_only_allows_poke_mode(self): try: - sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=False, - dag=self.dag) + sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=False, dag=self.dag) except ValueError: self.fail("__init__ failed with mode='poke'.") try: @@ -610,18 +561,15 @@ def test_poke_mode_only_allows_poke_mode(self): self.fail("class method failed without changing mode from 'poke'.") def test_poke_mode_only_bad_class_method(self): - sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=False, - dag=self.dag) + sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=False, dag=self.dag) with self.assertRaises(ValueError): sensor.change_mode('reschedule') def test_poke_mode_only_bad_init(self): with self.assertRaises(ValueError): - DummyPokeOnlySensor(task_id='foo', mode='reschedule', - poke_changes_mode=False, dag=self.dag) + DummyPokeOnlySensor(task_id='foo', mode='reschedule', poke_changes_mode=False, dag=self.dag) def test_poke_mode_only_bad_poke(self): - sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=True, - dag=self.dag) + sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=True, dag=self.dag) with self.assertRaises(ValueError): sensor.poke({}) diff --git a/tests/sensors/test_bash.py b/tests/sensors/test_bash.py index 3322ba2b74b1b..934d67cd74ec4 100644 --- a/tests/sensors/test_bash.py +++ b/tests/sensors/test_bash.py @@ -27,10 +27,7 @@ class TestBashSensor(unittest.TestCase): def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': datetime.datetime(2017, 1, 1) - } + args = {'owner': 'airflow', 'start_date': datetime.datetime(2017, 1, 1)} dag = DAG('test_dag_id', default_args=args) self.dag = dag @@ -41,7 +38,7 @@ def test_true_condition(self): output_encoding='utf-8', poke_interval=1, timeout=2, - dag=self.dag + dag=self.dag, ) op.execute(None) @@ -52,7 +49,7 @@ def test_false_condition(self): output_encoding='utf-8', poke_interval=1, timeout=2, - dag=self.dag + dag=self.dag, ) with self.assertRaises(AirflowSensorTimeout): op.execute(None) diff --git 
a/tests/sensors/test_date_time_sensor.py b/tests/sensors/test_date_time_sensor.py index 88370153c7397..87b7dbdbdd221 100644 --- a/tests/sensors/test_date_time_sensor.py +++ b/tests/sensors/test_date_time_sensor.py @@ -44,13 +44,19 @@ def setup_class(cls): ] ) def test_valid_input(self, task_id, target_time, expected): - op = DateTimeSensor(task_id=task_id, target_time=target_time, dag=self.dag,) + op = DateTimeSensor( + task_id=task_id, + target_time=target_time, + dag=self.dag, + ) assert op.target_time == expected def test_invalid_input(self): with pytest.raises(TypeError): DateTimeSensor( - task_id="test", target_time=timezone.utcnow().time(), dag=self.dag, + task_id="test", + target_time=timezone.utcnow().time(), + dag=self.dag, ) @parameterized.expand( diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index 7a1e917e8b31b..a99edf9260e0e 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -39,24 +39,13 @@ class TestExternalTaskSensor(unittest.TestCase): - def setUp(self): - self.dagbag = DagBag( - dag_folder=DEV_NULL, - include_examples=True - ) - self.args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True) + self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG(TEST_DAG_ID, default_args=self.args) def test_time_sensor(self): - op = TimeSensor( - task_id=TEST_TASK_ID, - target_time=time(0), - dag=self.dag - ) + op = TimeSensor(task_id=TEST_TASK_ID, target_time=time(0), dag=self.dag) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor(self): @@ -65,13 +54,9 @@ def test_external_task_sensor(self): task_id='test_external_task_sensor_check', external_dag_id=TEST_DAG_ID, external_task_id=TEST_TASK_ID, - dag=self.dag - ) - op.run( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_ti_state=True + dag=self.dag, ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_catch_overlap_allowed_failed_state(self): with self.assertRaises(AirflowException): @@ -81,7 +66,7 @@ def test_catch_overlap_allowed_failed_state(self): external_task_id=TEST_TASK_ID, allowed_states=[State.SUCCESS], failed_states=[State.SUCCESS], - dag=self.dag + dag=self.dag, ) def test_external_task_sensor_wrong_failed_states(self): @@ -91,7 +76,7 @@ def test_external_task_sensor_wrong_failed_states(self): external_dag_id=TEST_DAG_ID, external_task_id=TEST_TASK_ID, failed_states=["invalid_state"], - dag=self.dag + dag=self.dag, ) def test_external_task_sensor_failed_states(self): @@ -101,13 +86,9 @@ def test_external_task_sensor_failed_states(self): external_dag_id=TEST_DAG_ID, external_task_id=TEST_TASK_ID, failed_states=["failed"], - dag=self.dag - ) - op.run( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_ti_state=True + dag=self.dag, ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_failed_states_as_success(self): self.test_time_sensor() @@ -117,57 +98,38 @@ def test_external_task_sensor_failed_states_as_success(self): external_task_id=TEST_TASK_ID, allowed_states=["failed"], failed_states=["success"], - dag=self.dag + dag=self.dag, ) with self.assertRaises(AirflowException) as cm: - op.run( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_ti_state=True - ) - self.assertEqual(str(cm.exception), - "The external task " - "time_sensor_check in 
DAG " - "unit_test_dag failed.") + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + self.assertEqual( + str(cm.exception), "The external task " "time_sensor_check in DAG " "unit_test_dag failed." + ) def test_external_dag_sensor(self): - other_dag = DAG( - 'other_dag', - default_args=self.args, - end_date=DEFAULT_DATE, - schedule_interval='@once') + other_dag = DAG('other_dag', default_args=self.args, end_date=DEFAULT_DATE, schedule_interval='@once') other_dag.create_dagrun( - run_id='test', - start_date=DEFAULT_DATE, - execution_date=DEFAULT_DATE, - state=State.SUCCESS) + run_id='test', start_date=DEFAULT_DATE, execution_date=DEFAULT_DATE, state=State.SUCCESS + ) op = ExternalTaskSensor( task_id='test_external_dag_sensor_check', external_dag_id='other_dag', external_task_id=None, - dag=self.dag - ) - op.run( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_ti_state=True + dag=self.dag, ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_templated_sensor(self): with self.dag: sensor = ExternalTaskSensor( - task_id='templated_task', - external_dag_id='dag_{{ ds }}', - external_task_id='task_{{ ds }}' + task_id='templated_task', external_dag_id='dag_{{ ds }}', external_task_id='task_{{ ds }}' ) instance = TaskInstance(sensor, DEFAULT_DATE) instance.render_templates() - self.assertEqual(sensor.external_dag_id, - f"dag_{DEFAULT_DATE.date()}") - self.assertEqual(sensor.external_task_id, - f"task_{DEFAULT_DATE.date()}") + self.assertEqual(sensor.external_dag_id, f"dag_{DEFAULT_DATE.date()}") + self.assertEqual(sensor.external_task_id, f"task_{DEFAULT_DATE.date()}") def test_external_task_sensor_fn_multiple_execution_dates(self): bash_command_code = """ @@ -180,84 +142,71 @@ def test_external_task_sensor_fn_multiple_execution_dates(self): exit 0 """ dag_external_id = TEST_DAG_ID + '_external' - dag_external = DAG( - dag_external_id, - default_args=self.args, - schedule_interval=timedelta(seconds=1)) + dag_external = DAG(dag_external_id, default_args=self.args, schedule_interval=timedelta(seconds=1)) task_external_with_failure = BashOperator( - task_id="task_external_with_failure", - bash_command=bash_command_code, - retries=0, - dag=dag_external) + task_id="task_external_with_failure", bash_command=bash_command_code, retries=0, dag=dag_external + ) task_external_without_failure = DummyOperator( - task_id="task_external_without_failure", - retries=0, - dag=dag_external) + task_id="task_external_without_failure", retries=0, dag=dag_external + ) task_external_without_failure.run( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + timedelta(seconds=1), - ignore_ti_state=True) + start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(seconds=1), ignore_ti_state=True + ) session = settings.Session() TI = TaskInstance try: task_external_with_failure.run( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + timedelta(seconds=1), - ignore_ti_state=True) + start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(seconds=1), ignore_ti_state=True + ) # The test_with_failure task is excepted to fail # once per minute (the run on the first second of # each minute). 
except Exception as e: # pylint: disable=broad-except - failed_tis = session.query(TI).filter( - TI.dag_id == dag_external_id, - TI.state == State.FAILED, - TI.execution_date == DEFAULT_DATE + timedelta(seconds=1)).all() - if len(failed_tis) == 1 and \ - failed_tis[0].task_id == 'task_external_with_failure': + failed_tis = ( + session.query(TI) + .filter( + TI.dag_id == dag_external_id, + TI.state == State.FAILED, + TI.execution_date == DEFAULT_DATE + timedelta(seconds=1), + ) + .all() + ) + if len(failed_tis) == 1 and failed_tis[0].task_id == 'task_external_with_failure': pass else: raise e dag_id = TEST_DAG_ID - dag = DAG( - dag_id, - default_args=self.args, - schedule_interval=timedelta(minutes=1)) + dag = DAG(dag_id, default_args=self.args, schedule_interval=timedelta(minutes=1)) task_without_failure = ExternalTaskSensor( task_id='task_without_failure', external_dag_id=dag_external_id, external_task_id='task_external_without_failure', - execution_date_fn=lambda dt: [dt + timedelta(seconds=i) - for i in range(2)], + execution_date_fn=lambda dt: [dt + timedelta(seconds=i) for i in range(2)], allowed_states=['success'], retries=0, timeout=1, poke_interval=1, - dag=dag) + dag=dag, + ) task_with_failure = ExternalTaskSensor( task_id='task_with_failure', external_dag_id=dag_external_id, external_task_id='task_external_with_failure', - execution_date_fn=lambda dt: [dt + timedelta(seconds=i) - for i in range(2)], + execution_date_fn=lambda dt: [dt + timedelta(seconds=i) for i in range(2)], allowed_states=['success'], retries=0, timeout=1, poke_interval=1, - dag=dag) + dag=dag, + ) - task_without_failure.run( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_ti_state=True) + task_without_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) with self.assertRaises(AirflowSensorTimeout): - task_with_failure.run( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_ti_state=True) + task_with_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_delta(self): self.test_time_sensor() @@ -267,13 +216,9 @@ def test_external_task_sensor_delta(self): external_task_id=TEST_TASK_ID, execution_delta=timedelta(0), allowed_states=['success'], - dag=self.dag - ) - op.run( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_ti_state=True + dag=self.dag, ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_fn(self): self.test_time_sensor() @@ -284,13 +229,9 @@ def test_external_task_sensor_fn(self): external_task_id=TEST_TASK_ID, execution_date_fn=lambda dt: dt + timedelta(0), allowed_states=['success'], - dag=self.dag - ) - op1.run( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_ti_state=True + dag=self.dag, ) + op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) # double check that the execution is being called by failing the test op2 = ExternalTaskSensor( task_id='test_external_task_sensor_check_delta_2', @@ -300,14 +241,10 @@ def test_external_task_sensor_fn(self): allowed_states=['success'], timeout=1, poke_interval=1, - dag=self.dag + dag=self.dag, ) with self.assertRaises(exceptions.AirflowSensorTimeout): - op2.run( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_ti_state=True - ) + op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_fn_multiple_args(self): """Check this task sensor passes multiple args with full context. 
If no failure, means clean run.""" @@ -323,13 +260,9 @@ def my_func(dt, context): external_task_id=TEST_TASK_ID, execution_date_fn=my_func, allowed_states=['success'], - dag=self.dag - ) - op1.run( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_ti_state=True + dag=self.dag, ) + op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_error_delta_and_fn(self): self.test_time_sensor() @@ -342,7 +275,7 @@ def test_external_task_sensor_error_delta_and_fn(self): execution_delta=timedelta(0), execution_date_fn=lambda dt: dt, allowed_states=['success'], - dag=self.dag + dag=self.dag, ) def test_catch_invalid_allowed_states(self): @@ -352,7 +285,7 @@ def test_catch_invalid_allowed_states(self): external_dag_id=TEST_DAG_ID, external_task_id=TEST_TASK_ID, allowed_states=['invalid_state'], - dag=self.dag + dag=self.dag, ) with self.assertRaises(ValueError): @@ -361,7 +294,7 @@ def test_catch_invalid_allowed_states(self): external_dag_id=TEST_DAG_ID, external_task_id=None, allowed_states=['invalid_state'], - dag=self.dag + dag=self.dag, ) def test_external_task_sensor_waits_for_task_check_existence(self): @@ -370,15 +303,11 @@ def test_external_task_sensor_waits_for_task_check_existence(self): external_dag_id="example_bash_operator", external_task_id="non-existing-task", check_existence=True, - dag=self.dag + dag=self.dag, ) with self.assertRaises(AirflowException): - op.run( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_ti_state=True - ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_waits_for_dag_check_existence(self): op = ExternalTaskSensor( @@ -386,15 +315,11 @@ def test_external_task_sensor_waits_for_dag_check_existence(self): external_dag_id="non-existing-dag", external_task_id=None, check_existence=True, - dag=self.dag + dag=self.dag, ) with self.assertRaises(AirflowException): - op.run( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_ti_state=True - ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) class TestExternalTaskMarker(unittest.TestCase): @@ -407,7 +332,7 @@ def test_serialized_external_task_marker(self): task_id="parent_task", external_dag_id="external_task_marker_child", external_task_id="child_task1", - dag=dag + dag=dag, ) serialized_op = SerializedBaseOperator.serialize_operator(task) @@ -438,42 +363,33 @@ def dag_bag_ext(): dag_0 = DAG("dag_0", start_date=DEFAULT_DATE, schedule_interval=None) task_a_0 = DummyOperator(task_id="task_a_0", dag=dag_0) - task_b_0 = ExternalTaskMarker(task_id="task_b_0", - external_dag_id="dag_1", - external_task_id="task_a_1", - recursion_depth=3, - dag=dag_0) + task_b_0 = ExternalTaskMarker( + task_id="task_b_0", external_dag_id="dag_1", external_task_id="task_a_1", recursion_depth=3, dag=dag_0 + ) task_a_0 >> task_b_0 dag_1 = DAG("dag_1", start_date=DEFAULT_DATE, schedule_interval=None) - task_a_1 = ExternalTaskSensor(task_id="task_a_1", - external_dag_id=dag_0.dag_id, - external_task_id=task_b_0.task_id, - dag=dag_1) - task_b_1 = ExternalTaskMarker(task_id="task_b_1", - external_dag_id="dag_2", - external_task_id="task_a_2", - recursion_depth=2, - dag=dag_1) + task_a_1 = ExternalTaskSensor( + task_id="task_a_1", external_dag_id=dag_0.dag_id, external_task_id=task_b_0.task_id, dag=dag_1 + ) + task_b_1 = ExternalTaskMarker( + task_id="task_b_1", external_dag_id="dag_2", external_task_id="task_a_2", recursion_depth=2, dag=dag_1 + ) task_a_1 >> task_b_1 dag_2 = 
DAG("dag_2", start_date=DEFAULT_DATE, schedule_interval=None) - task_a_2 = ExternalTaskSensor(task_id="task_a_2", - external_dag_id=dag_1.dag_id, - external_task_id=task_b_1.task_id, - dag=dag_2) - task_b_2 = ExternalTaskMarker(task_id="task_b_2", - external_dag_id="dag_3", - external_task_id="task_a_3", - recursion_depth=1, - dag=dag_2) + task_a_2 = ExternalTaskSensor( + task_id="task_a_2", external_dag_id=dag_1.dag_id, external_task_id=task_b_1.task_id, dag=dag_2 + ) + task_b_2 = ExternalTaskMarker( + task_id="task_b_2", external_dag_id="dag_3", external_task_id="task_a_3", recursion_depth=1, dag=dag_2 + ) task_a_2 >> task_b_2 dag_3 = DAG("dag_3", start_date=DEFAULT_DATE, schedule_interval=None) - task_a_3 = ExternalTaskSensor(task_id="task_a_3", - external_dag_id=dag_2.dag_id, - external_task_id=task_b_2.task_id, - dag=dag_3) + task_a_3 = ExternalTaskSensor( + task_id="task_a_3", external_dag_id=dag_2.dag_id, external_task_id=task_b_2.task_id, dag=dag_3 + ) task_b_3 = DummyOperator(task_id="task_b_3", dag=dag_3) task_a_3 >> task_b_3 @@ -587,23 +503,18 @@ def dag_bag_cyclic(): dag_0 = DAG("dag_0", start_date=DEFAULT_DATE, schedule_interval=None) task_a_0 = DummyOperator(task_id="task_a_0", dag=dag_0) - task_b_0 = ExternalTaskMarker(task_id="task_b_0", - external_dag_id="dag_1", - external_task_id="task_a_1", - recursion_depth=3, - dag=dag_0) + task_b_0 = ExternalTaskMarker( + task_id="task_b_0", external_dag_id="dag_1", external_task_id="task_a_1", recursion_depth=3, dag=dag_0 + ) task_a_0 >> task_b_0 dag_1 = DAG("dag_1", start_date=DEFAULT_DATE, schedule_interval=None) - task_a_1 = ExternalTaskSensor(task_id="task_a_1", - external_dag_id=dag_0.dag_id, - external_task_id=task_b_0.task_id, - dag=dag_1) - task_b_1 = ExternalTaskMarker(task_id="task_b_1", - external_dag_id="dag_0", - external_task_id="task_a_0", - recursion_depth=2, - dag=dag_1) + task_a_1 = ExternalTaskSensor( + task_id="task_a_1", external_dag_id=dag_0.dag_id, external_task_id=task_b_0.task_id, dag=dag_1 + ) + task_b_1 = ExternalTaskMarker( + task_id="task_b_1", external_dag_id="dag_0", external_task_id="task_a_0", recursion_depth=2, dag=dag_1 + ) task_a_1 >> task_b_1 for dag in [dag_0, dag_1]: @@ -639,11 +550,13 @@ def dag_bag_multiple(): start = DummyOperator(task_id="start", dag=agg_dag) for i in range(25): - task = ExternalTaskMarker(task_id=f"{daily_task.task_id}_{i}", - external_dag_id=daily_dag.dag_id, - external_task_id=daily_task.task_id, - execution_date="{{ macros.ds_add(ds, -1 * %s) }}" % i, - dag=agg_dag) + task = ExternalTaskMarker( + task_id=f"{daily_task.task_id}_{i}", + external_dag_id=daily_dag.dag_id, + external_task_id=daily_task.task_id, + execution_date="{{ macros.ds_add(ds, -1 * %s) }}" % i, + dag=agg_dag, + ) start >> task yield dag_bag @@ -689,16 +602,20 @@ def dag_bag_head_tail(): """ dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False) with DAG("head_tail", start_date=DEFAULT_DATE, schedule_interval="@daily") as dag: - head = ExternalTaskSensor(task_id='head', - external_dag_id=dag.dag_id, - external_task_id="tail", - execution_delta=timedelta(days=1), - mode="reschedule") + head = ExternalTaskSensor( + task_id='head', + external_dag_id=dag.dag_id, + external_task_id="tail", + execution_delta=timedelta(days=1), + mode="reschedule", + ) body = DummyOperator(task_id="body") - tail = ExternalTaskMarker(task_id="tail", - external_dag_id=dag.dag_id, - external_task_id=head.task_id, - execution_date="{{ tomorrow_ds_nodash }}") + tail = ExternalTaskMarker( + task_id="tail", + 
external_dag_id=dag.dag_id, + external_task_id=head.task_id, + execution_date="{{ tomorrow_ds_nodash }}", + ) head >> body >> tail dag_bag.bag_dag(dag=dag, root_dag=dag) diff --git a/tests/sensors/test_filesystem.py b/tests/sensors/test_filesystem.py index 03ba2758b9e3e..8332f9c553cf7 100644 --- a/tests/sensors/test_filesystem.py +++ b/tests/sensors/test_filesystem.py @@ -34,11 +34,9 @@ class TestFileSensor(unittest.TestCase): def setUp(self): from airflow.hooks.filesystem import FSHook + hook = FSHook() - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) self.hook = hook self.dag = dag @@ -53,8 +51,7 @@ def test_simple(self): timeout=0, ) task._hook = self.hook - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_file_in_nonexistent_dir(self): temp_dir = tempfile.mkdtemp() @@ -64,13 +61,12 @@ def test_file_in_nonexistent_dir(self): fs_conn_id='fs_default', dag=self.dag, timeout=0, - poke_interval=1 + poke_interval=1, ) task._hook = self.hook try: with self.assertRaises(AirflowSensorTimeout): - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) finally: shutil.rmtree(temp_dir) @@ -82,13 +78,12 @@ def test_empty_dir(self): fs_conn_id='fs_default', dag=self.dag, timeout=0, - poke_interval=1 + poke_interval=1, ) task._hook = self.hook try: with self.assertRaises(AirflowSensorTimeout): - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) finally: shutil.rmtree(temp_dir) @@ -105,8 +100,7 @@ def test_file_in_dir(self): try: # `touch` the dir open(temp_dir + "/file", "a").close() - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) finally: shutil.rmtree(temp_dir) @@ -119,8 +113,7 @@ def test_default_fs_conn_id(self): timeout=0, ) task._hook = self.hook - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_wildcard_file(self): suffix = '.txt' @@ -134,8 +127,7 @@ def test_wildcard_file(self): timeout=0, ) task._hook = self.hook - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_subdirectory_not_empty(self): suffix = '.txt' @@ -151,8 +143,7 @@ def test_subdirectory_not_empty(self): timeout=0, ) task._hook = self.hook - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) shutil.rmtree(temp_dir) def test_subdirectory_empty(self): @@ -164,11 +155,10 @@ def test_subdirectory_empty(self): fs_conn_id='fs_default', dag=self.dag, timeout=0, - poke_interval=1 + poke_interval=1, ) task._hook = self.hook with self.assertRaises(AirflowSensorTimeout): - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) shutil.rmtree(temp_dir) diff --git a/tests/sensors/test_python.py 
b/tests/sensors/test_python.py index 445496c8b6714..c8eee8428420e 100644 --- a/tests/sensors/test_python.py +++ b/tests/sensors/test_python.py @@ -31,12 +31,8 @@ class TestPythonSensor(TestPythonBase): - def test_python_sensor_true(self): - op = PythonSensor( - task_id='python_sensor_check_true', - python_callable=lambda: True, - dag=self.dag) + op = PythonSensor(task_id='python_sensor_check_true', python_callable=lambda: True, dag=self.dag) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_python_sensor_false(self): @@ -45,15 +41,13 @@ def test_python_sensor_false(self): timeout=0.01, poke_interval=0.01, python_callable=lambda: False, - dag=self.dag) + dag=self.dag, + ) with self.assertRaises(AirflowSensorTimeout): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_python_sensor_raise(self): - op = PythonSensor( - task_id='python_sensor_check_raise', - python_callable=lambda: 1 / 0, - dag=self.dag) + op = PythonSensor(task_id='python_sensor_check_raise', python_callable=lambda: 1 / 0, dag=self.dag) with self.assertRaises(ZeroDivisionError): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -73,19 +67,15 @@ def test_python_callable_arguments_are_templatized(self): # a Mock instance cannot be used as a callable function or test fails with a # TypeError: Object of type Mock is not JSON serializable python_callable=build_recording_function(recorded_calls), - op_args=[ - 4, - date(2019, 1, 1), - "dag {{dag.dag_id}} ran on {{ds}}.", - named_tuple - ], - dag=self.dag) + op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple], + dag=self.dag, + ) self.dag.create_dagrun( run_type=DagRunType.MANUAL, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) with self.assertRaises(AirflowSensorTimeout): task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -95,10 +85,12 @@ def test_python_callable_arguments_are_templatized(self): self.assertEqual(2, len(recorded_calls)) self._assert_calls_equal( recorded_calls[0], - Call(4, - date(2019, 1, 1), - f"dag {self.dag.dag_id} ran on {ds_templated}.", - Named(ds_templated, 'unchanged')) + Call( + 4, + date(2019, 1, 1), + f"dag {self.dag.dag_id} ran on {ds_templated}.", + Named(ds_templated, 'unchanged'), + ), ) def test_python_callable_keyword_arguments_are_templatized(self): @@ -115,15 +107,16 @@ def test_python_callable_keyword_arguments_are_templatized(self): op_kwargs={ 'an_int': 4, 'a_date': date(2019, 1, 1), - 'a_templated_string': "dag {{dag.dag_id}} ran on {{ds}}." 
+ 'a_templated_string': "dag {{dag.dag_id}} ran on {{ds}}.", }, - dag=self.dag) + dag=self.dag, + ) self.dag.create_dagrun( run_type=DagRunType.MANUAL, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) with self.assertRaises(AirflowSensorTimeout): task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -132,8 +125,11 @@ def test_python_callable_keyword_arguments_are_templatized(self): self.assertEqual(2, len(recorded_calls)) self._assert_calls_equal( recorded_calls[0], - Call(an_int=4, - a_date=date(2019, 1, 1), - a_templated_string="dag {} ran on {}.".format( - self.dag.dag_id, DEFAULT_DATE.date().isoformat())) + Call( + an_int=4, + a_date=date(2019, 1, 1), + a_templated_string="dag {} ran on {}.".format( + self.dag.dag_id, DEFAULT_DATE.date().isoformat() + ), + ), ) diff --git a/tests/sensors/test_smart_sensor_operator.py b/tests/sensors/test_smart_sensor_operator.py index 90f5c5c67dd7b..fa7fc079bb208 100644 --- a/tests/sensors/test_smart_sensor_operator.py +++ b/tests/sensors/test_smart_sensor_operator.py @@ -43,13 +43,10 @@ class DummySmartSensor(SmartSensorOperator): - def __init__(self, - shard_max=conf.getint('smart_sensor', 'shard_code_upper_limit'), - shard_min=0, - **kwargs): - super().__init__(shard_min=shard_min, - shard_max=shard_max, - **kwargs) + def __init__( + self, shard_max=conf.getint('smart_sensor', 'shard_code_upper_limit'), shard_min=0, **kwargs + ): + super().__init__(shard_min=shard_min, shard_max=shard_max, **kwargs) class DummySensor(BaseSensorOperator): @@ -73,10 +70,7 @@ def setUp(self): os.environ['AIRFLOW__SMART_SENSER__USE_SMART_SENSOR'] = 'true' os.environ['AIRFLOW__SMART_SENSER__SENSORS_ENABLED'] = 'DummySmartSensor' - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG(TEST_DAG_ID, default_args=args) self.sensor_dag = DAG(TEST_SENSOR_DAG_ID, default_args=args) self.log = logging.getLogger('BaseSmartTest') @@ -102,7 +96,7 @@ def _make_dag_run(self): run_id='manual__' + TEST_DAG_ID, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) def _make_sensor_dag_run(self): @@ -110,7 +104,7 @@ def _make_sensor_dag_run(self): run_id='manual__' + TEST_SENSOR_DAG_ID, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, - state=State.RUNNING + state=State.RUNNING, ) def _make_sensor(self, return_value, **kwargs): @@ -121,12 +115,7 @@ def _make_sensor(self, return_value, **kwargs): if timeout not in kwargs: kwargs[timeout] = 0 - sensor = DummySensor( - task_id=SENSOR_OP, - return_value=return_value, - dag=self.sensor_dag, - **kwargs - ) + sensor = DummySensor(task_id=SENSOR_OP, return_value=return_value, dag=self.sensor_dag, **kwargs) return sensor @@ -139,12 +128,7 @@ def _make_sensor_instance(self, index, return_value, **kwargs): kwargs[timeout] = 0 task_id = SENSOR_OP + str(index) - sensor = DummySensor( - task_id=task_id, - return_value=return_value, - dag=self.sensor_dag, - **kwargs - ) + sensor = DummySensor(task_id=task_id, return_value=return_value, dag=self.sensor_dag, **kwargs) ti = TaskInstance(task=sensor, execution_date=DEFAULT_DATE) @@ -158,16 +142,9 @@ def _make_smart_operator(self, index, **kwargs): if smart_sensor_timeout not in kwargs: kwargs[smart_sensor_timeout] = 0 - smart_task = DummySmartSensor( - task_id=SMART_OP + "_" + str(index), - dag=self.dag, - **kwargs - ) + smart_task = DummySmartSensor(task_id=SMART_OP + "_" + str(index), dag=self.dag, 
**kwargs) - dummy_op = DummyOperator( - task_id=DUMMY_OP, - dag=self.dag - ) + dummy_op = DummyOperator(task_id=DUMMY_OP, dag=self.dag) dummy_op.set_upstream(smart_task) return smart_task @@ -315,11 +292,13 @@ def test_register_in_sensor_service(self): session = settings.Session() SI = SensorInstance - sensor_instance = session.query(SI).filter( - SI.dag_id == si1.dag_id, - SI.task_id == si1.task_id, - SI.execution_date == si1.execution_date) \ + sensor_instance = ( + session.query(SI) + .filter( + SI.dag_id == si1.dag_id, SI.task_id == si1.task_id, SI.execution_date == si1.execution_date + ) .first() + ) self.assertIsNotNone(sensor_instance) self.assertEqual(sensor_instance.state, State.SENSING) diff --git a/tests/sensors/test_sql_sensor.py b/tests/sensors/test_sql_sensor.py index ac58f26926007..18948856abffa 100644 --- a/tests/sensors/test_sql_sensor.py +++ b/tests/sensors/test_sql_sensor.py @@ -32,13 +32,9 @@ class TestSqlSensor(TestHiveEnvironment): - def setUp(self): super().setUp() - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG(TEST_DAG_ID, default_args=args) def test_unsupported_conn_type(self): @@ -46,7 +42,7 @@ def test_unsupported_conn_type(self): task_id='sql_sensor_check', conn_id='redis_default', sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES", - dag=self.dag + dag=self.dag, ) with self.assertRaises(AirflowException): @@ -58,7 +54,7 @@ def test_sql_sensor_mysql(self): task_id='sql_sensor_check_1', conn_id='mysql_default', sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES", - dag=self.dag + dag=self.dag, ) op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -67,7 +63,7 @@ def test_sql_sensor_mysql(self): conn_id='mysql_default', sql="SELECT count(%s) FROM INFORMATION_SCHEMA.TABLES", parameters=["table_name"], - dag=self.dag + dag=self.dag, ) op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -77,7 +73,7 @@ def test_sql_sensor_postgres(self): task_id='sql_sensor_check_1', conn_id='postgres_default', sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES", - dag=self.dag + dag=self.dag, ) op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -86,7 +82,7 @@ def test_sql_sensor_postgres(self): conn_id='postgres_default', sql="SELECT count(%s) FROM INFORMATION_SCHEMA.TABLES", parameters=["table_name"], - dag=self.dag + dag=self.dag, ) op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -125,10 +121,7 @@ def test_sql_sensor_postgres_poke(self, mock_hook): @mock.patch('airflow.sensors.sql_sensor.BaseHook') def test_sql_sensor_postgres_poke_fail_on_empty(self, mock_hook): op = SqlSensor( - task_id='sql_sensor_check', - conn_id='postgres_default', - sql="SELECT 1", - fail_on_empty=True + task_id='sql_sensor_check', conn_id='postgres_default', sql="SELECT 1", fail_on_empty=True ) mock_hook.get_connection('postgres_default').conn_type = "postgres" @@ -140,10 +133,7 @@ def test_sql_sensor_postgres_poke_fail_on_empty(self, mock_hook): @mock.patch('airflow.sensors.sql_sensor.BaseHook') def test_sql_sensor_postgres_poke_success(self, mock_hook): op = SqlSensor( - task_id='sql_sensor_check', - conn_id='postgres_default', - sql="SELECT 1", - success=lambda x: x in [1] + task_id='sql_sensor_check', conn_id='postgres_default', sql="SELECT 1", success=lambda x: x in [1] ) mock_hook.get_connection('postgres_default').conn_type = "postgres" @@ -161,10 +151,7 @@ def 
test_sql_sensor_postgres_poke_success(self, mock_hook): @mock.patch('airflow.sensors.sql_sensor.BaseHook') def test_sql_sensor_postgres_poke_failure(self, mock_hook): op = SqlSensor( - task_id='sql_sensor_check', - conn_id='postgres_default', - sql="SELECT 1", - failure=lambda x: x in [1] + task_id='sql_sensor_check', conn_id='postgres_default', sql="SELECT 1", failure=lambda x: x in [1] ) mock_hook.get_connection('postgres_default').conn_type = "postgres" @@ -183,7 +170,7 @@ def test_sql_sensor_postgres_poke_failure_success(self, mock_hook): conn_id='postgres_default', sql="SELECT 1", failure=lambda x: x in [1], - success=lambda x: x in [2] + success=lambda x: x in [2], ) mock_hook.get_connection('postgres_default').conn_type = "postgres" @@ -205,7 +192,7 @@ def test_sql_sensor_postgres_poke_failure_success_same(self, mock_hook): conn_id='postgres_default', sql="SELECT 1", failure=lambda x: x in [1], - success=lambda x: x in [1] + success=lambda x: x in [1], ) mock_hook.get_connection('postgres_default').conn_type = "postgres" @@ -248,13 +235,13 @@ def test_sql_sensor_postgres_poke_invalid_success(self, mock_hook): self.assertRaises(AirflowException, op.poke, None) @unittest.skipIf( - 'AIRFLOW_RUNALL_TESTS' not in os.environ, - "Skipped because AIRFLOW_RUNALL_TESTS is not set") + 'AIRFLOW_RUNALL_TESTS' not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set" + ) def test_sql_sensor_presto(self): op = SqlSensor( task_id='hdfs_sensor_check', conn_id='presto_default', sql="SELECT 'x' FROM airflow.static_babynames LIMIT 1;", - dag=self.dag) - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) diff --git a/tests/sensors/test_time_sensor.py b/tests/sensors/test_time_sensor.py index 28d8936e338dd..567297d0917dc 100644 --- a/tests/sensors/test_time_sensor.py +++ b/tests/sensors/test_time_sensor.py @@ -28,9 +28,7 @@ DEFAULT_TIMEZONE = "Asia/Singapore" # UTC+08:00 DEFAULT_DATE_WO_TZ = datetime(2015, 1, 1) -DEFAULT_DATE_WITH_TZ = datetime( - 2015, 1, 1, tzinfo=pendulum.tz.timezone(DEFAULT_TIMEZONE) -) +DEFAULT_DATE_WITH_TZ = datetime(2015, 1, 1, tzinfo=pendulum.tz.timezone(DEFAULT_TIMEZONE)) @patch( diff --git a/tests/sensors/test_timedelta_sensor.py b/tests/sensors/test_timedelta_sensor.py index 615c922e68560..d613337161b91 100644 --- a/tests/sensors/test_timedelta_sensor.py +++ b/tests/sensors/test_timedelta_sensor.py @@ -30,14 +30,10 @@ class TestTimedeltaSensor(unittest.TestCase): def setUp(self): - self.dagbag = DagBag( - dag_folder=DEV_NULL, include_examples=True) + self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True) self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG(TEST_DAG_ID, default_args=self.args) def test_timedelta_sensor(self): - op = TimeDeltaSensor( - task_id='timedelta_sensor_check', - delta=timedelta(seconds=2), - dag=self.dag) + op = TimeDeltaSensor(task_id='timedelta_sensor_check', delta=timedelta(seconds=2), dag=self.dag) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) diff --git a/tests/sensors/test_timeout_sensor.py b/tests/sensors/test_timeout_sensor.py index 09b35b3ad2247..3df6e101ece5b 100644 --- a/tests/sensors/test_timeout_sensor.py +++ b/tests/sensors/test_timeout_sensor.py @@ -39,9 +39,7 @@ class TimeoutTestSensor(BaseSensorOperator): """ @apply_defaults - def __init__(self, - return_value=False, - **kwargs): + def __init__(self, return_value=False, **kwargs): 
self.return_value = return_value super().__init__(**kwargs) @@ -65,10 +63,7 @@ def execute(self, context): class TestSensorTimeout(unittest.TestCase): def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG(TEST_DAG_ID, default_args=args) def test_timeout(self): @@ -78,10 +73,8 @@ def test_timeout(self): return_value=False, poke_interval=5, params={'time_jump': timedelta(days=2, seconds=1)}, - dag=self.dag + dag=self.dag, ) self.assertRaises( - AirflowSensorTimeout, - op.run, - start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True + AirflowSensorTimeout, op.run, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True ) diff --git a/tests/sensors/test_weekday_sensor.py b/tests/sensors/test_weekday_sensor.py index 75b5cb9961b17..26b65a06e9147 100644 --- a/tests/sensors/test_weekday_sensor.py +++ b/tests/sensors/test_weekday_sensor.py @@ -37,7 +37,6 @@ class TestDayOfWeekSensor(unittest.TestCase): - @staticmethod def clean_db(): db.clear_db_runs() @@ -45,34 +44,28 @@ def clean_db(): def setUp(self): self.clean_db() - self.dagbag = DagBag( - dag_folder=DEV_NULL, - include_examples=True - ) - self.args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True) + self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} dag = DAG(TEST_DAG_ID, default_args=self.args) self.dag = dag def tearDown(self): self.clean_db() - @parameterized.expand([ - ("with-string", 'Thursday'), - ("with-enum", WeekDay.THURSDAY), - ("with-enum-set", {WeekDay.THURSDAY}), - ("with-enum-set-2-items", {WeekDay.THURSDAY, WeekDay.FRIDAY}), - ("with-string-set", {'Thursday'}), - ("with-string-set-2-items", {'Thursday', 'Friday'}), - ]) + @parameterized.expand( + [ + ("with-string", 'Thursday'), + ("with-enum", WeekDay.THURSDAY), + ("with-enum-set", {WeekDay.THURSDAY}), + ("with-enum-set-2-items", {WeekDay.THURSDAY, WeekDay.FRIDAY}), + ("with-string-set", {'Thursday'}), + ("with-string-set-2-items", {'Thursday', 'Friday'}), + ] + ) def test_weekday_sensor_true(self, _, week_day): op = DayOfWeekSensor( - task_id='weekday_sensor_check_true', - week_day=week_day, - use_task_execution_day=True, - dag=self.dag) + task_id='weekday_sensor_check_true', week_day=week_day, use_task_execution_day=True, dag=self.dag + ) op.run(start_date=WEEKDAY_DATE, end_date=WEEKDAY_DATE, ignore_ti_state=True) self.assertEqual(op.week_day, week_day) @@ -83,32 +76,35 @@ def test_weekday_sensor_false(self): timeout=2, week_day='Tuesday', use_task_execution_day=True, - dag=self.dag) + dag=self.dag, + ) with self.assertRaises(AirflowSensorTimeout): op.run(start_date=WEEKDAY_DATE, end_date=WEEKDAY_DATE, ignore_ti_state=True) def test_invalid_weekday_number(self): invalid_week_day = 'Thsday' - with self.assertRaisesRegex(AttributeError, - f'Invalid Week Day passed: "{invalid_week_day}"'): + with self.assertRaisesRegex(AttributeError, f'Invalid Week Day passed: "{invalid_week_day}"'): DayOfWeekSensor( task_id='weekday_sensor_invalid_weekday_num', week_day=invalid_week_day, use_task_execution_day=True, - dag=self.dag) + dag=self.dag, + ) def test_weekday_sensor_with_invalid_type(self): invalid_week_day = ['Thsday'] - with self.assertRaisesRegex(TypeError, - 'Unsupported Type for week_day parameter:' - ' {}. 
It should be one of str, set or ' - 'Weekday enum type'.format(type(invalid_week_day)) - ): + with self.assertRaisesRegex( + TypeError, + 'Unsupported Type for week_day parameter:' + ' {}. It should be one of str, set or ' + 'Weekday enum type'.format(type(invalid_week_day)), + ): DayOfWeekSensor( task_id='weekday_sensor_check_true', week_day=invalid_week_day, use_task_execution_day=True, - dag=self.dag) + dag=self.dag, + ) def test_weekday_sensor_timeout_with_set(self): op = DayOfWeekSensor( @@ -117,6 +113,7 @@ def test_weekday_sensor_timeout_with_set(self): timeout=2, week_day={WeekDay.MONDAY, WeekDay.TUESDAY}, use_task_execution_day=True, - dag=self.dag) + dag=self.dag, + ) with self.assertRaises(AirflowSensorTimeout): op.run(start_date=WEEKDAY_DATE, end_date=WEEKDAY_DATE, ignore_ti_state=True) diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 08a8aaad938a9..51be0b0abf903 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -44,17 +44,12 @@ executor_config_pod = k8s.V1Pod( metadata=k8s.V1ObjectMeta(name="my-name"), - spec=k8s.V1PodSpec(containers=[ - k8s.V1Container( - name="base", - volume_mounts=[ - k8s.V1VolumeMount( - name="my-vol", - mount_path="/vol/" - ) - ] - ) - ])) + spec=k8s.V1PodSpec( + containers=[ + k8s.V1Container(name="base", volume_mounts=[k8s.V1VolumeMount(name="my-vol", mount_path="/vol/")]) + ] + ), +) serialized_simple_dag_ground_truth = { "__version": 1, @@ -64,24 +59,22 @@ "__var": { "depends_on_past": False, "retries": 1, - "retry_delay": { - "__type": "timedelta", - "__var": 300.0 - } - } + "retry_delay": {"__type": "timedelta", "__var": 300.0}, + }, }, "start_date": 1564617600.0, - '_task_group': {'_group_id': None, - 'prefix_group_id': True, - 'children': {'bash_task': ('operator', 'bash_task'), - 'custom_task': ('operator', 'custom_task')}, - 'tooltip': '', - 'ui_color': 'CornflowerBlue', - 'ui_fgcolor': '#000', - 'upstream_group_ids': [], - 'downstream_group_ids': [], - 'upstream_task_ids': [], - 'downstream_task_ids': []}, + '_task_group': { + '_group_id': None, + 'prefix_group_id': True, + 'children': {'bash_task': ('operator', 'bash_task'), 'custom_task': ('operator', 'custom_task')}, + 'tooltip': '', + 'ui_color': 'CornflowerBlue', + 'ui_fgcolor': '#000', + 'upstream_group_ids': [], + 'downstream_group_ids': [], + 'upstream_task_ids': [], + 'downstream_task_ids': [], + }, "is_paused_upon_creation": False, "_dag_id": "simple_dag", "fileloc": None, @@ -103,12 +96,15 @@ "_task_type": "BashOperator", "_task_module": "airflow.operators.bash", "pool": "default_pool", - "executor_config": {'__type': 'dict', - '__var': {"pod_override": { - '__type': 'k8s.V1Pod', - '__var': PodGenerator.serialize_pod(executor_config_pod)} - } - } + "executor_config": { + '__type': 'dict', + '__var': { + "pod_override": { + '__type': 'k8s.V1Pod', + '__var': PodGenerator.serialize_pod(executor_config_pod), + } + }, + }, }, { "task_id": "custom_task", @@ -134,13 +130,10 @@ "__var": { "test_role": { "__type": "set", - "__var": [ - permissions.ACTION_CAN_READ, - permissions.ACTION_CAN_EDIT - ] + "__var": [permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT], } - } - } + }, + }, }, } @@ -166,13 +159,15 @@ def make_simple_dag(): }, start_date=datetime(2019, 8, 1), is_paused_upon_creation=False, - access_control={ - "test_role": {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT} - } + access_control={"test_role": {permissions.ACTION_CAN_READ, 
permissions.ACTION_CAN_EDIT}}, ) as dag: CustomOperator(task_id='custom_task') - BashOperator(task_id='bash_task', bash_command='echo {{ task.task_id }}', owner='airflow', - executor_config={"pod_override": executor_config_pod}) + BashOperator( + task_id='bash_task', + bash_command='echo {{ task.task_id }}', + owner='airflow', + executor_config={"pod_override": executor_config_pod}, + ) return {'simple_dag': dag} @@ -189,19 +184,15 @@ def make_user_defined_macro_filter_dag(): def compute_next_execution_date(dag, execution_date): return dag.following_schedule(execution_date) - default_args = { - 'start_date': datetime(2019, 7, 10) - } + default_args = {'start_date': datetime(2019, 7, 10)} dag = DAG( 'user_defined_macro_filter_dag', default_args=default_args, user_defined_macros={ 'next_execution_date': compute_next_execution_date, }, - user_defined_filters={ - 'hello': lambda name: 'Hello %s' % name - }, - catchup=False + user_defined_filters={'hello': lambda name: 'Hello %s' % name}, + catchup=False, ) BashOperator( task_id='echo', @@ -252,14 +243,18 @@ def setUp(self): super().setUp() BaseHook.get_connection = mock.Mock( return_value=Connection( - extra=('{' - '"project_id": "mock", ' - '"location": "mock", ' - '"instance": "mock", ' - '"database_type": "postgres", ' - '"use_proxy": "False", ' - '"use_ssl": "False"' - '}'))) + extra=( + '{' + '"project_id": "mock", ' + '"location": "mock", ' + '"instance": "mock", ' + '"database_type": "postgres", ' + '"use_proxy": "False", ' + '"use_ssl": "False"' + '}' + ) + ) + ) self.maxDiff = None # pylint: disable=invalid-name def test_serialization(self): @@ -272,14 +267,11 @@ def test_serialization(self): serialized_dags[v.dag_id] = dag # Compares with the ground truth of JSON string. - self.validate_serialized_dag( - serialized_dags['simple_dag'], - serialized_simple_dag_ground_truth) + self.validate_serialized_dag(serialized_dags['simple_dag'], serialized_simple_dag_ground_truth) def validate_serialized_dag(self, json_dag, ground_truth_dag): """Verify serialized DAGs match the ground truth.""" - self.assertTrue( - json_dag['dag']['fileloc'].split('/')[-1] == 'test_dag_serialization.py') + self.assertTrue(json_dag['dag']['fileloc'].split('/')[-1] == 'test_dag_serialization.py') json_dag['dag']['fileloc'] = None def sorted_serialized_dag(dag_dict: dict): @@ -289,8 +281,7 @@ def sorted_serialized_dag(dag_dict: dict): items should not matter but assertEqual would fail if the order of items changes in the dag dictionary """ - dag_dict["dag"]["tasks"] = sorted(dag_dict["dag"]["tasks"], - key=lambda x: sorted(x.keys())) + dag_dict["dag"]["tasks"] = sorted(dag_dict["dag"]["tasks"], key=lambda x: sorted(x.keys())) dag_dict["dag"]["_access_control"]["__var"]["test_role"]["__var"] = sorted( dag_dict["dag"]["_access_control"]["__var"]["test_role"]["__var"] ) @@ -306,8 +297,7 @@ def test_deserialization_across_process(self): # and once here to get a DAG to compare to) we don't want to load all # dags. 
queue = multiprocessing.Queue() - proc = multiprocessing.Process( - target=serialize_subprocess, args=(queue, "airflow/example_dags")) + proc = multiprocessing.Process(target=serialize_subprocess, args=(queue, "airflow/example_dags")) proc.daemon = True proc.start() @@ -328,10 +318,12 @@ def test_deserialization_across_process(self): self.validate_deserialized_dag(stringified_dags[dag_id], dags[dag_id]) def test_roundtrip_provider_example_dags(self): - dags = collect_dags([ - "airflow/providers/*/example_dags", - "airflow/providers/*/*/example_dags", - ]) + dags = collect_dags( + [ + "airflow/providers/*/example_dags", + "airflow/providers/*/*/example_dags", + ] + ) # Verify deserialized DAGs. for dag in dags.values(): @@ -346,14 +338,14 @@ def validate_deserialized_dag(self, serialized_dag, dag): fields_to_check = dag.get_serialized_fields() - { # Doesn't implement __eq__ properly. Check manually 'timezone', - # Need to check fields in it, to exclude functions 'default_args', - "_task_group" + "_task_group", } for field in fields_to_check: - assert getattr(serialized_dag, field) == getattr(dag, field), \ - f'{dag.dag_id}.{field} does not match' + assert getattr(serialized_dag, field) == getattr( + dag, field + ), f'{dag.dag_id}.{field} does not match' if dag.default_args: for k, v in dag.default_args.items(): @@ -361,8 +353,9 @@ def validate_deserialized_dag(self, serialized_dag, dag): # Check we stored _something_. assert k in serialized_dag.default_args else: - assert v == serialized_dag.default_args[k], \ - f'{dag.dag_id}.default_args[{k}] does not match' + assert ( + v == serialized_dag.default_args[k] + ), f'{dag.dag_id}.default_args[{k}] does not match' assert serialized_dag.timezone.name == dag.timezone.name @@ -373,7 +366,11 @@ def validate_deserialized_dag(self, serialized_dag, dag): # and is equal to fileloc assert serialized_dag.full_filepath == dag.fileloc - def validate_deserialized_task(self, serialized_task, task,): + def validate_deserialized_task( + self, + serialized_task, + task, + ): """Verify non-airflow operators are casted to BaseOperator.""" assert isinstance(serialized_task, SerializedBaseOperator) assert not isinstance(task, SerializedBaseOperator) @@ -381,17 +378,16 @@ def validate_deserialized_task(self, serialized_task, task,): fields_to_check = task.get_serialized_fields() - { # Checked separately - '_task_type', 'subdag', - + '_task_type', + 'subdag', # Type is excluded, so don't check it '_log', - # List vs tuple. 
Check separately 'template_fields', - # We store the string, real dag has the actual code - 'on_failure_callback', 'on_success_callback', 'on_retry_callback', - + 'on_failure_callback', + 'on_success_callback', + 'on_retry_callback', # Checked separately 'resources', } @@ -403,8 +399,9 @@ def validate_deserialized_task(self, serialized_task, task,): assert serialized_task.downstream_task_ids == task.downstream_task_ids for field in fields_to_check: - assert getattr(serialized_task, field) == getattr(task, field), \ - f'{task.dag.dag_id}.{task.task_id}.{field} does not match' + assert getattr(serialized_task, field) == getattr( + task, field + ), f'{task.dag.dag_id}.{task.task_id}.{field} does not match' if serialized_task.resources is None: assert task.resources is None or task.resources == [] @@ -419,17 +416,22 @@ def validate_deserialized_task(self, serialized_task, task,): else: assert serialized_task.subdag is None - @parameterized.expand([ - (datetime(2019, 8, 1, tzinfo=timezone.utc), None, datetime(2019, 8, 1, tzinfo=timezone.utc)), - (datetime(2019, 8, 1, tzinfo=timezone.utc), datetime(2019, 8, 2, tzinfo=timezone.utc), - datetime(2019, 8, 2, tzinfo=timezone.utc)), - (datetime(2019, 8, 1, tzinfo=timezone.utc), datetime(2019, 7, 30, tzinfo=timezone.utc), - datetime(2019, 8, 1, tzinfo=timezone.utc)), - ]) - def test_deserialization_start_date(self, - dag_start_date, - task_start_date, - expected_task_start_date): + @parameterized.expand( + [ + (datetime(2019, 8, 1, tzinfo=timezone.utc), None, datetime(2019, 8, 1, tzinfo=timezone.utc)), + ( + datetime(2019, 8, 1, tzinfo=timezone.utc), + datetime(2019, 8, 2, tzinfo=timezone.utc), + datetime(2019, 8, 2, tzinfo=timezone.utc), + ), + ( + datetime(2019, 8, 1, tzinfo=timezone.utc), + datetime(2019, 7, 30, tzinfo=timezone.utc), + datetime(2019, 8, 1, tzinfo=timezone.utc), + ), + ] + ) + def test_deserialization_start_date(self, dag_start_date, task_start_date, expected_task_start_date): dag = DAG(dag_id='simple_dag', start_date=dag_start_date) BaseOperator(task_id='simple_task', dag=dag, start_date=task_start_date) @@ -446,19 +448,23 @@ def test_deserialization_start_date(self, simple_task = dag.task_dict["simple_task"] self.assertEqual(simple_task.start_date, expected_task_start_date) - @parameterized.expand([ - (datetime(2019, 8, 1, tzinfo=timezone.utc), None, datetime(2019, 8, 1, tzinfo=timezone.utc)), - (datetime(2019, 8, 1, tzinfo=timezone.utc), datetime(2019, 8, 2, tzinfo=timezone.utc), - datetime(2019, 8, 1, tzinfo=timezone.utc)), - (datetime(2019, 8, 1, tzinfo=timezone.utc), datetime(2019, 7, 30, tzinfo=timezone.utc), - datetime(2019, 7, 30, tzinfo=timezone.utc)), - ]) - def test_deserialization_end_date(self, - dag_end_date, - task_end_date, - expected_task_end_date): - dag = DAG(dag_id='simple_dag', start_date=datetime(2019, 8, 1), - end_date=dag_end_date) + @parameterized.expand( + [ + (datetime(2019, 8, 1, tzinfo=timezone.utc), None, datetime(2019, 8, 1, tzinfo=timezone.utc)), + ( + datetime(2019, 8, 1, tzinfo=timezone.utc), + datetime(2019, 8, 2, tzinfo=timezone.utc), + datetime(2019, 8, 1, tzinfo=timezone.utc), + ), + ( + datetime(2019, 8, 1, tzinfo=timezone.utc), + datetime(2019, 7, 30, tzinfo=timezone.utc), + datetime(2019, 7, 30, tzinfo=timezone.utc), + ), + ] + ) + def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_task_end_date): + dag = DAG(dag_id='simple_dag', start_date=datetime(2019, 8, 1), end_date=dag_end_date) BaseOperator(task_id='simple_task', dag=dag, end_date=task_end_date) 
serialized_dag = SerializedDAG.to_dict(dag) @@ -473,12 +479,14 @@ def test_deserialization_end_date(self, simple_task = dag.task_dict["simple_task"] self.assertEqual(simple_task.end_date, expected_task_end_date) - @parameterized.expand([ - (None, None, None), - ("@weekly", "@weekly", "0 0 * * 0"), - ("@once", "@once", None), - ({"__type": "timedelta", "__var": 86400.0}, timedelta(days=1), timedelta(days=1)), - ]) + @parameterized.expand( + [ + (None, None, None), + ("@weekly", "@weekly", "0 0 * * 0"), + ("@once", "@once", None), + ({"__type": "timedelta", "__var": 86400.0}, timedelta(days=1), timedelta(days=1)), + ] + ) def test_deserialization_schedule_interval( self, serialized_schedule_interval, expected_schedule_interval, expected_n_schedule_interval ): @@ -501,14 +509,16 @@ def test_deserialization_schedule_interval( self.assertEqual(dag.schedule_interval, expected_schedule_interval) self.assertEqual(dag.normalized_schedule_interval, expected_n_schedule_interval) - @parameterized.expand([ - (relativedelta(days=-1), {"__type": "relativedelta", "__var": {"days": -1}}), - (relativedelta(month=1, days=-1), {"__type": "relativedelta", "__var": {"month": 1, "days": -1}}), - # Every friday - (relativedelta(weekday=FR), {"__type": "relativedelta", "__var": {"weekday": [4]}}), - # Every second friday - (relativedelta(weekday=FR(2)), {"__type": "relativedelta", "__var": {"weekday": [4, 2]}}) - ]) + @parameterized.expand( + [ + (relativedelta(days=-1), {"__type": "relativedelta", "__var": {"days": -1}}), + (relativedelta(month=1, days=-1), {"__type": "relativedelta", "__var": {"month": 1, "days": -1}}), + # Every friday + (relativedelta(weekday=FR), {"__type": "relativedelta", "__var": {"weekday": [4]}}), + # Every second friday + (relativedelta(weekday=FR(2)), {"__type": "relativedelta", "__var": {"weekday": [4, 2]}}), + ] + ) def test_roundtrip_relativedelta(self, val, expected): serialized = SerializedDAG._serialize(val) self.assertDictEqual(serialized, expected) @@ -516,10 +526,12 @@ def test_roundtrip_relativedelta(self, val, expected): round_tripped = SerializedDAG._deserialize(serialized) self.assertEqual(val, round_tripped) - @parameterized.expand([ - (None, {}), - ({"param_1": "value_1"}, {"param_1": "value_1"}), - ]) + @parameterized.expand( + [ + (None, {}), + ({"param_1": "value_1"}, {"param_1": "value_1"}), + ] + ) def test_dag_params_roundtrip(self, val, expected_val): """ Test that params work both on Serialized DAGs & Tasks @@ -538,17 +550,18 @@ def test_dag_params_roundtrip(self, val, expected_val): self.assertEqual(expected_val, deserialized_dag.params) self.assertEqual(expected_val, deserialized_simple_task.params) - @parameterized.expand([ - (None, {}), - ({"param_1": "value_1"}, {"param_1": "value_1"}), - ]) + @parameterized.expand( + [ + (None, {}), + ({"param_1": "value_1"}, {"param_1": "value_1"}), + ] + ) def test_task_params_roundtrip(self, val, expected_val): """ Test that params work both on Serialized DAGs & Tasks """ dag = DAG(dag_id='simple_dag') - BaseOperator(task_id='simple_task', dag=dag, params=val, - start_date=datetime(2019, 8, 1)) + BaseOperator(task_id='simple_task', dag=dag, params=val, start_date=datetime(2019, 8, 1)) serialized_dag = SerializedDAG.to_dict(dag) if val: @@ -589,7 +602,7 @@ def test_extra_serialized_field_and_operator_links(self): # Check Serialized version of operator link only contains the inbuilt Op Link self.assertEqual( serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], - [{'tests.test_utils.mock_operators.CustomOpLink': {}}] 
+ [{'tests.test_utils.mock_operators.CustomOpLink': {}}], ) # Test all the extra_links are set @@ -614,6 +627,7 @@ def test_extra_operator_links_logs_error_for_non_registered_extra_links(self): class TaskStateLink(BaseOperatorLink): """OperatorLink not registered via Plugins nor a built-in OperatorLink""" + name = 'My Link' def get_link(self, operator, dttm): @@ -621,6 +635,7 @@ def get_link(self, operator, dttm): class MyOperator(BaseOperator): """Just a DummyOperator using above defined Extra Operator Link""" + operator_extra_links = [TaskStateLink()] def execute(self, context): @@ -672,12 +687,14 @@ def test_extra_serialized_field_and_multiple_operator_links(self): [ {'tests.test_utils.mock_operators.CustomBaseIndexOpLink': {'index': 0}}, {'tests.test_utils.mock_operators.CustomBaseIndexOpLink': {'index': 1}}, - ] + ], ) # Test all the extra_links are set - self.assertCountEqual(simple_task.extra_links, [ - 'BigQuery Console #1', 'BigQuery Console #2', 'airflow', 'github', 'google']) + self.assertCountEqual( + simple_task.extra_links, + ['BigQuery Console #1', 'BigQuery Console #2', 'airflow', 'github', 'google'], + ) ti = TaskInstance(task=simple_task, execution_date=test_date) ti.xcom_push('search_query', ["dummy_value_1", "dummy_value_2"]) @@ -715,42 +732,48 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) - @parameterized.expand([ - (None, None), - ([], []), - ({}, {}), - ("{{ task.task_id }}", "{{ task.task_id }}"), - (["{{ task.task_id }}", "{{ task.task_id }}"]), - ({"foo": "{{ task.task_id }}"}, {"foo": "{{ task.task_id }}"}), - ({"foo": {"bar": "{{ task.task_id }}"}}, {"foo": {"bar": "{{ task.task_id }}"}}), - ( - [{"foo1": {"bar": "{{ task.task_id }}"}}, {"foo2": {"bar": "{{ task.task_id }}"}}], - [{"foo1": {"bar": "{{ task.task_id }}"}}, {"foo2": {"bar": "{{ task.task_id }}"}}], - ), - ( - {"foo": {"bar": {"{{ task.task_id }}": ["sar"]}}}, - {"foo": {"bar": {"{{ task.task_id }}": ["sar"]}}}), - ( - ClassWithCustomAttributes( - att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"]), - "ClassWithCustomAttributes(" - "{'att1': '{{ task.task_id }}', 'att2': '{{ task.task_id }}', 'template_fields': ['att1']})", - ), - ( - ClassWithCustomAttributes(nested1=ClassWithCustomAttributes(att1="{{ task.task_id }}", - att2="{{ task.task_id }}", - template_fields=["att1"]), - nested2=ClassWithCustomAttributes(att3="{{ task.task_id }}", - att4="{{ task.task_id }}", - template_fields=["att3"]), - template_fields=["nested1"]), - "ClassWithCustomAttributes(" - "{'nested1': ClassWithCustomAttributes({'att1': '{{ task.task_id }}', " - "'att2': '{{ task.task_id }}', 'template_fields': ['att1']}), " - "'nested2': ClassWithCustomAttributes({'att3': '{{ task.task_id }}', " - "'att4': '{{ task.task_id }}', 'template_fields': ['att3']}), 'template_fields': ['nested1']})", - ), - ]) + @parameterized.expand( + [ + (None, None), + ([], []), + ({}, {}), + ("{{ task.task_id }}", "{{ task.task_id }}"), + (["{{ task.task_id }}", "{{ task.task_id }}"]), + ({"foo": "{{ task.task_id }}"}, {"foo": "{{ task.task_id }}"}), + ({"foo": {"bar": "{{ task.task_id }}"}}, {"foo": {"bar": "{{ task.task_id }}"}}), + ( + [{"foo1": {"bar": "{{ task.task_id }}"}}, {"foo2": {"bar": "{{ task.task_id }}"}}], + [{"foo1": {"bar": "{{ task.task_id }}"}}, {"foo2": {"bar": "{{ task.task_id }}"}}], + ), + ( + {"foo": {"bar": {"{{ task.task_id }}": ["sar"]}}}, + {"foo": {"bar": {"{{ task.task_id }}": ["sar"]}}}, + ), + ( + ClassWithCustomAttributes( + att1="{{ task.task_id }}", 
att2="{{ task.task_id }}", template_fields=["att1"] + ), + "ClassWithCustomAttributes(" + "{'att1': '{{ task.task_id }}', 'att2': '{{ task.task_id }}', 'template_fields': ['att1']})", + ), + ( + ClassWithCustomAttributes( + nested1=ClassWithCustomAttributes( + att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"] + ), + nested2=ClassWithCustomAttributes( + att3="{{ task.task_id }}", att4="{{ task.task_id }}", template_fields=["att3"] + ), + template_fields=["nested1"], + ), + "ClassWithCustomAttributes(" + "{'nested1': ClassWithCustomAttributes({'att1': '{{ task.task_id }}', " + "'att2': '{{ task.task_id }}', 'template_fields': ['att1']}), " + "'nested2': ClassWithCustomAttributes({'att3': '{{ task.task_id }}', 'att4': " + "'{{ task.task_id }}', 'template_fields': ['att3']}), 'template_fields': ['nested1']})", + ), + ] + ) def test_templated_fields_exist_in_serialized_dag(self, templated_field, expected_field): """ Test that templated_fields exists for all Operators in Serialized DAG @@ -781,8 +804,9 @@ def test_dag_serialized_fields_with_schema(self): self.assertEqual(set(DAG.get_serialized_fields()), dag_params) def test_operator_subclass_changing_base_defaults(self): - assert BaseOperator(task_id='dummy').do_xcom_push is True, \ - "Precondition check! If this fails the test won't make sense" + assert ( + BaseOperator(task_id='dummy').do_xcom_push is True + ), "Precondition check! If this fails the test won't make sense" class MyOperator(BaseOperator): def __init__(self, do_xcom_push=False, **kwargs): @@ -804,49 +828,53 @@ def test_no_new_fields_added_to_base_operator(self): """ base_operator = BaseOperator(task_id="10") fields = base_operator.__dict__ - self.assertEqual({'_BaseOperator__instantiated': True, - '_dag': None, - '_downstream_task_ids': set(), - '_inlets': [], - '_log': base_operator.log, - '_outlets': [], - '_upstream_task_ids': set(), - 'depends_on_past': False, - 'do_xcom_push': True, - 'email': None, - 'email_on_failure': True, - 'email_on_retry': True, - 'end_date': None, - 'execution_timeout': None, - 'executor_config': {}, - 'inlets': [], - 'label': '10', - 'max_retry_delay': None, - 'on_execute_callback': None, - 'on_failure_callback': None, - 'on_retry_callback': None, - 'on_success_callback': None, - 'outlets': [], - 'owner': 'airflow', - 'params': {}, - 'pool': 'default_pool', - 'pool_slots': 1, - 'priority_weight': 1, - 'queue': 'default', - 'resources': None, - 'retries': 0, - 'retry_delay': timedelta(0, 300), - 'retry_exponential_backoff': False, - 'run_as_user': None, - 'sla': None, - 'start_date': None, - 'subdag': None, - 'task_concurrency': None, - 'task_id': '10', - 'trigger_rule': 'all_success', - 'wait_for_downstream': False, - 'weight_rule': 'downstream'}, fields, - """ + self.assertEqual( + { + '_BaseOperator__instantiated': True, + '_dag': None, + '_downstream_task_ids': set(), + '_inlets': [], + '_log': base_operator.log, + '_outlets': [], + '_upstream_task_ids': set(), + 'depends_on_past': False, + 'do_xcom_push': True, + 'email': None, + 'email_on_failure': True, + 'email_on_retry': True, + 'end_date': None, + 'execution_timeout': None, + 'executor_config': {}, + 'inlets': [], + 'label': '10', + 'max_retry_delay': None, + 'on_execute_callback': None, + 'on_failure_callback': None, + 'on_retry_callback': None, + 'on_success_callback': None, + 'outlets': [], + 'owner': 'airflow', + 'params': {}, + 'pool': 'default_pool', + 'pool_slots': 1, + 'priority_weight': 1, + 'queue': 'default', + 'resources': None, + 'retries': 0, + 
'retry_delay': timedelta(0, 300), + 'retry_exponential_backoff': False, + 'run_as_user': None, + 'sla': None, + 'start_date': None, + 'subdag': None, + 'task_concurrency': None, + 'task_id': '10', + 'trigger_rule': 'all_success', + 'wait_for_downstream': False, + 'weight_rule': 'downstream', + }, + fields, + """ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ACTION NEEDED! PLEASE READ THIS CAREFULLY AND CORRECT TESTS CAREFULLY @@ -858,8 +886,8 @@ def test_no_new_fields_added_to_base_operator(self): Note that we do not support versioning yet so you should only add optional fields to BaseOperator. !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - """ - ) + """, + ) def test_task_group_serialization(self): """ diff --git a/tests/task/task_runner/test_cgroup_task_runner.py b/tests/task/task_runner/test_cgroup_task_runner.py index c83cf6242fc2a..c5fb97098abe0 100644 --- a/tests/task/task_runner/test_cgroup_task_runner.py +++ b/tests/task/task_runner/test_cgroup_task_runner.py @@ -22,7 +22,6 @@ class TestCgroupTaskRunner(unittest.TestCase): - @mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__") @mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.on_finish") def test_cgroup_task_runner_super_calls(self, mock_super_on_finish, mock_super_init): diff --git a/tests/task/task_runner/test_standard_task_runner.py b/tests/task/task_runner/test_standard_task_runner.py index b37040c2adfe8..b9e1db7a4cfcf 100644 --- a/tests/task/task_runner/test_standard_task_runner.py +++ b/tests/task/task_runner/test_standard_task_runner.py @@ -40,24 +40,16 @@ 'version': 1, 'disable_existing_loggers': False, 'formatters': { - 'airflow.task': { - 'format': '[%(asctime)s] {{%(filename)s:%(lineno)d}} %(levelname)s - %(message)s' - }, + 'airflow.task': {'format': '[%(asctime)s] {{%(filename)s:%(lineno)d}} %(levelname)s - %(message)s'}, }, 'handlers': { 'console': { 'class': 'logging.StreamHandler', 'formatter': 'airflow.task', - 'stream': 'ext://sys.stdout' + 'stream': 'ext://sys.stdout', } }, - 'loggers': { - 'airflow': { - 'handlers': ['console'], - 'level': 'INFO', - 'propagate': False - } - } + 'loggers': {'airflow': {'handlers': ['console'], 'level': 'INFO', 'propagate': False}}, } @@ -79,7 +71,12 @@ def test_start_and_terminate(self): local_task_job.task_instance = mock.MagicMock() local_task_job.task_instance.run_as_user = None local_task_job.task_instance.command_as_list.return_value = [ - 'airflow', 'tasks', 'test', 'test_on_kill', 'task1', '2016-01-01' + 'airflow', + 'tasks', + 'test', + 'test_on_kill', + 'task1', + '2016-01-01', ] runner = StandardTaskRunner(local_task_job) @@ -104,7 +101,12 @@ def test_start_and_terminate_run_as_user(self): local_task_job.task_instance = mock.MagicMock() local_task_job.task_instance.run_as_user = getpass.getuser() local_task_job.task_instance.command_as_list.return_value = [ - 'airflow', 'tasks', 'test', 'test_on_kill', 'task1', '2016-01-01' + 'airflow', + 'tasks', + 'test', + 'test_on_kill', + 'task1', + '2016-01-01', ] runner = StandardTaskRunner(local_task_job) @@ -146,11 +148,13 @@ def test_on_kill(self): session = settings.Session() dag.clear() - dag.create_dagrun(run_id="test", - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) + dag.create_dagrun( + run_id="test", + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) ti = 
TI(task=task, execution_date=DEFAULT_DATE) job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True) session.commit() diff --git a/tests/task/task_runner/test_task_runner.py b/tests/task/task_runner/test_task_runner.py index be4187d36166a..4601333ee18ea 100644 --- a/tests/task/task_runner/test_task_runner.py +++ b/tests/task/task_runner/test_task_runner.py @@ -27,16 +27,11 @@ class GetTaskRunner(unittest.TestCase): - - @parameterized.expand([ - (import_path, ) for import_path in CORE_TASK_RUNNERS.values() - ]) + @parameterized.expand([(import_path,) for import_path in CORE_TASK_RUNNERS.values()]) def test_should_have_valid_imports(self, import_path): self.assertIsNotNone(import_string(import_path)) - @mock.patch( - 'airflow.task.task_runner.base_task_runner.subprocess' - ) + @mock.patch('airflow.task.task_runner.base_task_runner.subprocess') @mock.patch('airflow.task.task_runner._TASK_RUNNER_NAME', "StandardTaskRunner") def test_should_support_core_task_runner(self, mock_subprocess): local_task_job = mock.MagicMock( @@ -48,7 +43,7 @@ def test_should_support_core_task_runner(self, mock_subprocess): @mock.patch( 'airflow.task.task_runner._TASK_RUNNER_NAME', - "tests.task.task_runner.test_task_runner.custom_task_runner" + "tests.task.task_runner.test_task_runner.custom_task_runner", ) def test_should_support_custom_task_runner(self): local_task_job = mock.MagicMock( diff --git a/tests/test_utils/amazon_system_helpers.py b/tests/test_utils/amazon_system_helpers.py index f252a66121711..ac0d153a65e99 100644 --- a/tests/test_utils/amazon_system_helpers.py +++ b/tests/test_utils/amazon_system_helpers.py @@ -28,9 +28,7 @@ from tests.test_utils.logging_command_executor import get_executor from tests.test_utils.system_tests_class import SystemTest -AWS_DAG_FOLDER = os.path.join( - AIRFLOW_MAIN_FOLDER, "airflow", "providers", "amazon", "aws", "example_dags" -) +AWS_DAG_FOLDER = os.path.join(AIRFLOW_MAIN_FOLDER, "airflow", "providers", "amazon", "aws", "example_dags") @contextmanager @@ -53,7 +51,6 @@ def provide_aws_s3_bucket(name): @pytest.mark.system("amazon") class AmazonSystemTest(SystemTest): - @staticmethod def _region_name(): return os.environ.get("REGION_NAME") @@ -85,8 +82,7 @@ def execute_with_ctx(cls, cmd: List[str]): executor.execute_cmd(cmd=cmd) @staticmethod - def create_connection(aws_conn_id: str, - region: str) -> None: + def create_connection(aws_conn_id: str, region: str) -> None: """ Create aws connection with region @@ -137,8 +133,7 @@ def create_emr_default_roles(cls) -> None: cls.execute_with_ctx(cmd) @staticmethod - def create_ecs_cluster(aws_conn_id: str, - cluster_name: str) -> None: + def create_ecs_cluster(aws_conn_id: str, cluster_name: str) -> None: """ Create ecs cluster with given name @@ -174,8 +169,7 @@ def create_ecs_cluster(aws_conn_id: str, ) @staticmethod - def delete_ecs_cluster(aws_conn_id: str, - cluster_name: str) -> None: + def delete_ecs_cluster(aws_conn_id: str, cluster_name: str) -> None: """ Delete ecs cluster with given short name or full Amazon Resource Name (ARN) @@ -193,14 +187,16 @@ def delete_ecs_cluster(aws_conn_id: str, ) @staticmethod - def create_ecs_task_definition(aws_conn_id: str, - task_definition: str, - container: str, - image: str, - execution_role_arn: str, - awslogs_group: str, - awslogs_region: str, - awslogs_stream_prefix: str) -> None: + def create_ecs_task_definition( + aws_conn_id: str, + task_definition: str, + container: str, + image: str, + execution_role_arn: str, + awslogs_group: str, + awslogs_region: str, + 
awslogs_stream_prefix: str, + ) -> None: """ Create ecs task definition with given name @@ -256,8 +252,7 @@ def create_ecs_task_definition(aws_conn_id: str, ) @staticmethod - def delete_ecs_task_definition(aws_conn_id: str, - task_definition: str) -> None: + def delete_ecs_task_definition(aws_conn_id: str, task_definition: str) -> None: """ Delete all revisions of given ecs task definition @@ -283,8 +278,7 @@ def delete_ecs_task_definition(aws_conn_id: str, ) @staticmethod - def is_ecs_task_definition_exists(aws_conn_id: str, - task_definition: str) -> bool: + def is_ecs_task_definition_exists(aws_conn_id: str, task_definition: str) -> bool: """ Check whether given task definition exits in ecs diff --git a/tests/test_utils/api_connexion_utils.py b/tests/test_utils/api_connexion_utils.py index 919c99f57560a..522c70c1b5ae3 100644 --- a/tests/test_utils/api_connexion_utils.py +++ b/tests/test_utils/api_connexion_utils.py @@ -69,5 +69,5 @@ def assert_401(response): 'detail': None, 'status': 401, 'title': 'Unauthorized', - 'type': EXCEPTIONS_LINK_MAP[401] + 'type': EXCEPTIONS_LINK_MAP[401], } diff --git a/tests/test_utils/asserts.py b/tests/test_utils/asserts.py index 220331de1540f..dccaad10e2894 100644 --- a/tests/test_utils/asserts.py +++ b/tests/test_utils/asserts.py @@ -32,6 +32,7 @@ def assert_equal_ignore_multiple_spaces(case, first, second, msg=None): def _trim(s): return re.sub(r"\s+", " ", s.strip()) + return case.assertEqual(_trim(first), _trim(second), msg) @@ -42,6 +43,7 @@ class CountQueries: Does not support multiple processes. When a new process is started in context, its queries will not be included. """ + def __init__(self): self.result = Counter() @@ -55,10 +57,11 @@ def __exit__(self, type_, value, tb): def after_cursor_execute(self, *args, **kwargs): stack = [ - f for f in traceback.extract_stack() - if 'sqlalchemy' not in f.filename and - __file__ != f.filename and - ('session.py' not in f.filename and f.name != 'wrapper') + f + for f in traceback.extract_stack() + if 'sqlalchemy' not in f.filename + and __file__ != f.filename + and ('session.py' not in f.filename and f.name != 'wrapper') ] stack_info = ">".join([f"{f.filename.rpartition('/')[-1]}:{f.name}" for f in stack][-3:]) lineno = stack[-1].lineno @@ -75,9 +78,12 @@ def assert_queries_count(expected_count, message_fmt=None): count = sum(result.values()) if expected_count != count: - message_fmt = message_fmt or "The expected number of db queries is {expected_count}. " \ - "The current number is {current_count}.\n\n" \ - "Recorded query locations:" + message_fmt = ( + message_fmt + or "The expected number of db queries is {expected_count}. 
" + "The current number is {current_count}.\n\n" + "Recorded query locations:" + ) message = message_fmt.format(current_count=count, expected_count=expected_count) for location, count in result.items(): diff --git a/tests/test_utils/azure_system_helpers.py b/tests/test_utils/azure_system_helpers.py index 6f6f1e8b3722c..add526d752109 100644 --- a/tests/test_utils/azure_system_helpers.py +++ b/tests/test_utils/azure_system_helpers.py @@ -39,7 +39,6 @@ def provide_azure_fileshare(share_name: str, wasb_conn_id: str, file_name: str, @pytest.mark.system("azure") class AzureSystemTest(SystemTest): - @classmethod def create_share(cls, share_name: str, wasb_conn_id: str): hook = AzureFileShareHook(wasb_conn_id=wasb_conn_id) diff --git a/tests/test_utils/db.py b/tests/test_utils/db.py index cd2ef469faa35..cfc43682f1ea0 100644 --- a/tests/test_utils/db.py +++ b/tests/test_utils/db.py @@ -17,8 +17,20 @@ # under the License. from airflow.jobs.base_job import BaseJob from airflow.models import ( - Connection, DagModel, DagRun, DagTag, Log, Pool, RenderedTaskInstanceFields, SlaMiss, TaskFail, - TaskInstance, TaskReschedule, Variable, XCom, errors, + Connection, + DagModel, + DagRun, + DagTag, + Log, + Pool, + RenderedTaskInstanceFields, + SlaMiss, + TaskFail, + TaskInstance, + TaskReschedule, + Variable, + XCom, + errors, ) from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel diff --git a/tests/test_utils/gcp_system_helpers.py b/tests/test_utils/gcp_system_helpers.py index b80b76be29825..250593af240d0 100644 --- a/tests/test_utils/gcp_system_helpers.py +++ b/tests/test_utils/gcp_system_helpers.py @@ -86,19 +86,24 @@ def provide_gcp_context( :type project_id: str """ key_file_path = resolve_full_gcp_key_path(key_file_path) # type: ignore - with provide_gcp_conn_and_credentials(key_file_path, scopes, project_id), \ - tempfile.TemporaryDirectory() as gcloud_config_tmp, \ - mock.patch.dict('os.environ', {CLOUD_SDK_CONFIG_DIR: gcloud_config_tmp}): + with provide_gcp_conn_and_credentials( + key_file_path, scopes, project_id + ), tempfile.TemporaryDirectory() as gcloud_config_tmp, mock.patch.dict( + 'os.environ', {CLOUD_SDK_CONFIG_DIR: gcloud_config_tmp} + ): executor = get_executor() if project_id: - executor.execute_cmd([ - "gcloud", "config", "set", "core/project", project_id - ]) + executor.execute_cmd(["gcloud", "config", "set", "core/project", project_id]) if key_file_path: - executor.execute_cmd([ - "gcloud", "auth", "activate-service-account", f"--key-file={key_file_path}", - ]) + executor.execute_cmd( + [ + "gcloud", + "auth", + "activate-service-account", + f"--key-file={key_file_path}", + ] + ) yield @@ -121,8 +126,9 @@ def _service_key(): return os.environ.get(CREDENTIALS) @classmethod - def execute_with_ctx(cls, cmd: List[str], key: str = GCP_GCS_KEY, project_id=None, scopes=None, - silent: bool = False): + def execute_with_ctx( + cls, cmd: List[str], key: str = GCP_GCS_KEY, project_id=None, scopes=None, silent: bool = False + ): """ Executes command with context created by provide_gcp_context and activated service key. 
@@ -149,9 +155,7 @@ def delete_gcs_bucket(cls, name: str): @classmethod def upload_to_gcs(cls, source_uri: str, target_uri: str): - cls.execute_with_ctx( - ["gsutil", "cp", source_uri, target_uri], key=GCP_GCS_KEY - ) + cls.execute_with_ctx(["gsutil", "cp", source_uri, target_uri], key=GCP_GCS_KEY) @classmethod def upload_content_to_gcs(cls, lines: str, bucket: str, filename: str): @@ -192,10 +196,18 @@ def create_secret(cls, name: str, value: str): with tempfile.NamedTemporaryFile() as tmp: tmp.write(value.encode("UTF-8")) tmp.flush() - cmd = ["gcloud", "secrets", "create", name, - "--replication-policy", "automatic", - "--project", GoogleSystemTest._project_id(), - "--data-file", tmp.name] + cmd = [ + "gcloud", + "secrets", + "create", + name, + "--replication-policy", + "automatic", + "--project", + GoogleSystemTest._project_id(), + "--data-file", + tmp.name, + ] cls.execute_with_ctx(cmd, key=GCP_SECRET_MANAGER_KEY) @classmethod @@ -203,7 +215,15 @@ def update_secret(cls, name: str, value: str): with tempfile.NamedTemporaryFile() as tmp: tmp.write(value.encode("UTF-8")) tmp.flush() - cmd = ["gcloud", "secrets", "versions", "add", name, - "--project", GoogleSystemTest._project_id(), - "--data-file", tmp.name] + cmd = [ + "gcloud", + "secrets", + "versions", + "add", + name, + "--project", + GoogleSystemTest._project_id(), + "--data-file", + tmp.name, + ] cls.execute_with_ctx(cmd, key=GCP_SECRET_MANAGER_KEY) diff --git a/tests/test_utils/hdfs_utils.py b/tests/test_utils/hdfs_utils.py index 91a4a464a870f..348396f84e2e4 100644 --- a/tests/test_utils/hdfs_utils.py +++ b/tests/test_utils/hdfs_utils.py @@ -33,7 +33,6 @@ class FakeSnakeBiteClientException(Exception): class FakeSnakeBiteClient: - def __init__(self): self.started = True @@ -48,126 +47,142 @@ def ls(self, path, include_toplevel=False): # pylint: disable=invalid-name if path[0] == '/datadirectory/empty_directory' and not include_toplevel: return [] elif path[0] == '/datadirectory/datafile': - return [{ - 'group': 'supergroup', - 'permission': 420, - 'file_type': 'f', - 'access_time': 1481122343796, - 'block_replication': 3, - 'modification_time': 1481122343862, - 'length': 0, - 'blocksize': 134217728, - 'owner': 'hdfs', - 'path': '/datadirectory/datafile' - }] + return [ + { + 'group': 'supergroup', + 'permission': 420, + 'file_type': 'f', + 'access_time': 1481122343796, + 'block_replication': 3, + 'modification_time': 1481122343862, + 'length': 0, + 'blocksize': 134217728, + 'owner': 'hdfs', + 'path': '/datadirectory/datafile', + } + ] elif path[0] == '/datadirectory/empty_directory' and include_toplevel: - return [{ - 'group': 'supergroup', - 'permission': 493, - 'file_type': 'd', - 'access_time': 0, - 'block_replication': 0, - 'modification_time': 1481132141540, - 'length': 0, - 'blocksize': 0, - 'owner': 'hdfs', - 'path': '/datadirectory/empty_directory' - }] + return [ + { + 'group': 'supergroup', + 'permission': 493, + 'file_type': 'd', + 'access_time': 0, + 'block_replication': 0, + 'modification_time': 1481132141540, + 'length': 0, + 'blocksize': 0, + 'owner': 'hdfs', + 'path': '/datadirectory/empty_directory', + } + ] elif path[0] == '/datadirectory/not_empty_directory' and include_toplevel: - return [{ - 'group': 'supergroup', - 'permission': 493, - 'file_type': 'd', - 'access_time': 0, - 'block_replication': 0, - 'modification_time': 1481132141540, - 'length': 0, - 'blocksize': 0, - 'owner': 'hdfs', - 'path': '/datadirectory/empty_directory' - }, { - 'group': 'supergroup', - 'permission': 420, - 'file_type': 'f', - 
'access_time': 1481122343796, - 'block_replication': 3, - 'modification_time': 1481122343862, - 'length': 0, - 'blocksize': 134217728, - 'owner': 'hdfs', - 'path': '/datadirectory/not_empty_directory/test_file' - }] + return [ + { + 'group': 'supergroup', + 'permission': 493, + 'file_type': 'd', + 'access_time': 0, + 'block_replication': 0, + 'modification_time': 1481132141540, + 'length': 0, + 'blocksize': 0, + 'owner': 'hdfs', + 'path': '/datadirectory/empty_directory', + }, + { + 'group': 'supergroup', + 'permission': 420, + 'file_type': 'f', + 'access_time': 1481122343796, + 'block_replication': 3, + 'modification_time': 1481122343862, + 'length': 0, + 'blocksize': 134217728, + 'owner': 'hdfs', + 'path': '/datadirectory/not_empty_directory/test_file', + }, + ] elif path[0] == '/datadirectory/not_empty_directory': - return [{ - 'group': 'supergroup', - 'permission': 420, - 'file_type': 'f', - 'access_time': 1481122343796, - 'block_replication': 3, - 'modification_time': 1481122343862, - 'length': 0, - 'blocksize': 134217728, - 'owner': 'hdfs', - 'path': '/datadirectory/not_empty_directory/test_file' - }] + return [ + { + 'group': 'supergroup', + 'permission': 420, + 'file_type': 'f', + 'access_time': 1481122343796, + 'block_replication': 3, + 'modification_time': 1481122343862, + 'length': 0, + 'blocksize': 134217728, + 'owner': 'hdfs', + 'path': '/datadirectory/not_empty_directory/test_file', + } + ] elif path[0] == '/datadirectory/not_existing_file_or_directory': raise FakeSnakeBiteClientException elif path[0] == '/datadirectory/regex_dir': - return [{ - 'group': 'supergroup', - 'permission': 420, - 'file_type': 'f', - 'access_time': 1481122343796, - 'block_replication': 3, - 'modification_time': 1481122343862, 'length': 12582912, - 'blocksize': 134217728, - 'owner': 'hdfs', - 'path': '/datadirectory/regex_dir/test1file' - }, { - 'group': 'supergroup', - 'permission': 420, - 'file_type': 'f', - 'access_time': 1481122343796, - 'block_replication': 3, - 'modification_time': 1481122343862, - 'length': 12582912, - 'blocksize': 134217728, - 'owner': 'hdfs', - 'path': '/datadirectory/regex_dir/test2file' - }, { - 'group': 'supergroup', - 'permission': 420, - 'file_type': 'f', - 'access_time': 1481122343796, - 'block_replication': 3, - 'modification_time': 1481122343862, - 'length': 12582912, - 'blocksize': 134217728, - 'owner': 'hdfs', - 'path': '/datadirectory/regex_dir/test3file' - }, { - 'group': 'supergroup', - 'permission': 420, - 'file_type': 'f', - 'access_time': 1481122343796, - 'block_replication': 3, - 'modification_time': 1481122343862, - 'length': 12582912, - 'blocksize': 134217728, - 'owner': 'hdfs', - 'path': '/datadirectory/regex_dir/copying_file_1.txt._COPYING_' - }, { - 'group': 'supergroup', - 'permission': 420, - 'file_type': 'f', - 'access_time': 1481122343796, - 'block_replication': 3, - 'modification_time': 1481122343862, - 'length': 12582912, - 'blocksize': 134217728, - 'owner': 'hdfs', - 'path': '/datadirectory/regex_dir/copying_file_3.txt.sftp' - }] + return [ + { + 'group': 'supergroup', + 'permission': 420, + 'file_type': 'f', + 'access_time': 1481122343796, + 'block_replication': 3, + 'modification_time': 1481122343862, + 'length': 12582912, + 'blocksize': 134217728, + 'owner': 'hdfs', + 'path': '/datadirectory/regex_dir/test1file', + }, + { + 'group': 'supergroup', + 'permission': 420, + 'file_type': 'f', + 'access_time': 1481122343796, + 'block_replication': 3, + 'modification_time': 1481122343862, + 'length': 12582912, + 'blocksize': 134217728, + 'owner': 
'hdfs', + 'path': '/datadirectory/regex_dir/test2file', + }, + { + 'group': 'supergroup', + 'permission': 420, + 'file_type': 'f', + 'access_time': 1481122343796, + 'block_replication': 3, + 'modification_time': 1481122343862, + 'length': 12582912, + 'blocksize': 134217728, + 'owner': 'hdfs', + 'path': '/datadirectory/regex_dir/test3file', + }, + { + 'group': 'supergroup', + 'permission': 420, + 'file_type': 'f', + 'access_time': 1481122343796, + 'block_replication': 3, + 'modification_time': 1481122343862, + 'length': 12582912, + 'blocksize': 134217728, + 'owner': 'hdfs', + 'path': '/datadirectory/regex_dir/copying_file_1.txt._COPYING_', + }, + { + 'group': 'supergroup', + 'permission': 420, + 'file_type': 'f', + 'access_time': 1481122343796, + 'block_replication': 3, + 'modification_time': 1481122343862, + 'length': 12582912, + 'blocksize': 134217728, + 'owner': 'hdfs', + 'path': '/datadirectory/regex_dir/copying_file_3.txt.sftp', + }, + ] else: raise FakeSnakeBiteClientException diff --git a/tests/test_utils/logging_command_executor.py b/tests/test_utils/logging_command_executor.py index 081b21247bcf6..1ebf729acd860 100644 --- a/tests/test_utils/logging_command_executor.py +++ b/tests/test_utils/logging_command_executor.py @@ -24,7 +24,6 @@ class LoggingCommandExecutor(LoggingMixin): - def execute_cmd(self, cmd, silent=False, cwd=None, env=None): if silent: self.log.info("Executing in silent mode: '%s'", " ".join([shlex.quote(c) for c in cmd])) @@ -33,8 +32,12 @@ def execute_cmd(self, cmd, silent=False, cwd=None, env=None): else: self.log.info("Executing: '%s'", " ".join([shlex.quote(c) for c in cmd])) process = subprocess.Popen( - args=cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - universal_newlines=True, cwd=cwd, env=env + args=cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + cwd=cwd, + env=env, ) output, err = process.communicate() retcode = process.poll() @@ -46,16 +49,16 @@ def execute_cmd(self, cmd, silent=False, cwd=None, env=None): def check_output(self, cmd): self.log.info("Executing for output: '%s'", " ".join([shlex.quote(c) for c in cmd])) - process = subprocess.Popen(args=cmd, stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + process = subprocess.Popen(args=cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) output, err = process.communicate() retcode = process.poll() if retcode: self.log.error("Error when executing '%s'", " ".join([shlex.quote(c) for c in cmd])) self.log.info("Stdout: %s", output) self.log.info("Stderr: %s", err) - raise AirflowException("Retcode {} on {} with stdout: {}, stderr: {}". 
- format(retcode, " ".join(cmd), output, err)) + raise AirflowException( + "Retcode {} on {} with stdout: {}, stderr: {}".format(retcode, " ".join(cmd), output, err) + ) return output diff --git a/tests/test_utils/mock_hooks.py b/tests/test_utils/mock_hooks.py index 900cf206720d7..2c6e2893f4822 100644 --- a/tests/test_utils/mock_hooks.py +++ b/tests/test_utils/mock_hooks.py @@ -27,8 +27,7 @@ class MockHiveMetastoreHook(HiveMetastoreHook): def __init__(self, *args, **kwargs): self._find_valid_server = mock.MagicMock(return_value={}) - self.get_metastore_client = mock.MagicMock( - return_value=mock.MagicMock()) + self.get_metastore_client = mock.MagicMock(return_value=mock.MagicMock()) super().__init__() @@ -64,8 +63,7 @@ def __init__(self, *args, **kwargs): self.conn.execute = mock.MagicMock() self.get_conn = mock.MagicMock(return_value=self.conn) - self.get_first = mock.MagicMock( - return_value=[['val_0', 'val_1'], 'val_2']) + self.get_first = mock.MagicMock(return_value=[['val_0', 'val_1'], 'val_2']) super().__init__(*args, **kwargs) diff --git a/tests/test_utils/mock_operators.py b/tests/test_utils/mock_operators.py index ce410876413bf..534770ea4100b 100644 --- a/tests/test_utils/mock_operators.py +++ b/tests/test_utils/mock_operators.py @@ -51,6 +51,7 @@ class AirflowLink(BaseOperatorLink): """ Operator Link for Apache Airflow Website """ + name = 'airflow' def get_link(self, operator, dttm): @@ -62,9 +63,8 @@ class Dummy2TestOperator(BaseOperator): Example of an Operator that has an extra operator link and will be overridden by the one defined in tests/plugins/test_plugin.py """ - operator_extra_links = ( - AirflowLink(), - ) + + operator_extra_links = (AirflowLink(),) class Dummy3TestOperator(BaseOperator): @@ -72,6 +72,7 @@ class Dummy3TestOperator(BaseOperator): Example of an operator that has no extra Operator link. 
An operator link would be added to this operator via Airflow plugin """ + operator_extra_links = () @@ -113,12 +114,8 @@ def operator_extra_links(self): Return operator extra links """ if isinstance(self.bash_command, str) or self.bash_command is None: - return ( - CustomOpLink(), - ) - return ( - CustomBaseIndexOpLink(i) for i, _ in enumerate(self.bash_command) - ) + return (CustomOpLink(),) + return (CustomBaseIndexOpLink(i) for i, _ in enumerate(self.bash_command)) @apply_defaults def __init__(self, bash_command=None, **kwargs): @@ -134,6 +131,7 @@ class GoogleLink(BaseOperatorLink): """ Operator Link for Apache Airflow Website for Google """ + name = 'google' operators = [Dummy3TestOperator, CustomOperator] @@ -145,6 +143,7 @@ class AirflowLink2(BaseOperatorLink): """ Operator Link for Apache Airflow Website for 1.10.5 """ + name = 'airflow' operators = [Dummy2TestOperator, Dummy3TestOperator] @@ -156,6 +155,7 @@ class GithubLink(BaseOperatorLink): """ Operator Link for Apache Airflow GitHub """ + name = 'github' def get_link(self, operator, dttm): diff --git a/tests/test_utils/mock_process.py b/tests/test_utils/mock_process.py index abd1e42b969da..4031d55d8d67e 100644 --- a/tests/test_utils/mock_process.py +++ b/tests/test_utils/mock_process.py @@ -25,8 +25,7 @@ def __init__(self, extra_dejson=None, *args, **kwargs): self.get_records = mock.MagicMock(return_value=[['test_record']]) output = kwargs.get('output', ['' for _ in range(10)]) - self.readline = mock.MagicMock( - side_effect=[line.encode() for line in output]) + self.readline = mock.MagicMock(side_effect=[line.encode() for line in output]) def status(self, *args, **kwargs): return True @@ -35,8 +34,7 @@ def status(self, *args, **kwargs): class MockStdOut: def __init__(self, *args, **kwargs): output = kwargs.get('output', ['' for _ in range(10)]) - self.readline = mock.MagicMock( - side_effect=[line.encode() for line in output]) + self.readline = mock.MagicMock(side_effect=[line.encode() for line in output]) class MockSubProcess: @@ -54,8 +52,10 @@ def wait(self): class MockConnectionCursor: def __init__(self, *args, **kwargs): self.arraysize = None - self.description = [('hive_server_hook.a', 'INT_TYPE', None, None, None, None, True), - ('hive_server_hook.b', 'INT_TYPE', None, None, None, None, True)] + self.description = [ + ('hive_server_hook.a', 'INT_TYPE', None, None, None, None, True), + ('hive_server_hook.b', 'INT_TYPE', None, None, None, None, True), + ] self.iterable = [(1, 1), (2, 2)] self.conn_exists = kwargs.get('exists', True) diff --git a/tests/test_utils/perf/dags/elastic_dag.py b/tests/test_utils/perf/dags/elastic_dag.py index 4de9cd30f5c8b..9aa0a4d621c8c 100644 --- a/tests/test_utils/perf/dags/elastic_dag.py +++ b/tests/test_utils/perf/dags/elastic_dag.py @@ -141,6 +141,7 @@ class DagShape(Enum): """ Define shape of the Dag that will be used for testing. 
""" + NO_STRUCTURE = "no_structure" LINEAR = "linear" BINARY_TREE = "binary_tree" @@ -168,25 +169,25 @@ class DagShape(Enum): for dag_no in range(1, DAG_COUNT + 1): dag = DAG( - dag_id=safe_dag_id("__".join( - [ - DAG_PREFIX, - f"SHAPE={SHAPE.name.lower()}", - f"DAGS_COUNT={dag_no}_of_{DAG_COUNT}", - f"TASKS_COUNT=${TASKS_COUNT}", - f"START_DATE=${START_DATE_ENV}", - f"SCHEDULE_INTERVAL=${SCHEDULE_INTERVAL_ENV}", - ] - )), + dag_id=safe_dag_id( + "__".join( + [ + DAG_PREFIX, + f"SHAPE={SHAPE.name.lower()}", + f"DAGS_COUNT={dag_no}_of_{DAG_COUNT}", + f"TASKS_COUNT=${TASKS_COUNT}", + f"START_DATE=${START_DATE_ENV}", + f"SCHEDULE_INTERVAL=${SCHEDULE_INTERVAL_ENV}", + ] + ) + ), is_paused_upon_creation=False, default_args=args, schedule_interval=SCHEDULE_INTERVAL, ) elastic_dag_tasks = [ - BashOperator( - task_id="__".join(["tasks", f"{i}_of_{TASKS_COUNT}"]), bash_command='echo test', dag=dag - ) + BashOperator(task_id="__".join(["tasks", f"{i}_of_{TASKS_COUNT}"]), bash_command='echo test', dag=dag) for i in range(1, TASKS_COUNT + 1) ] diff --git a/tests/test_utils/perf/dags/perf_dag_1.py b/tests/test_utils/perf/dags/perf_dag_1.py index 021a910b54433..3757c7d40e092 100644 --- a/tests/test_utils/perf/dags/perf_dag_1.py +++ b/tests/test_utils/perf/dags/perf_dag_1.py @@ -30,14 +30,12 @@ } dag = DAG( - dag_id='perf_dag_1', default_args=args, - schedule_interval='@daily', - dagrun_timeout=timedelta(minutes=60)) + dag_id='perf_dag_1', default_args=args, schedule_interval='@daily', dagrun_timeout=timedelta(minutes=60) +) task_1 = BashOperator( - task_id='perf_task_1', - bash_command='sleep 5; echo "run_id={{ run_id }} | dag_run={{ dag_run }}"', - dag=dag) + task_id='perf_task_1', bash_command='sleep 5; echo "run_id={{ run_id }} | dag_run={{ dag_run }}"', dag=dag +) for i in range(2, 5): task = BashOperator( @@ -45,7 +43,8 @@ bash_command=''' sleep 5; echo "run_id={{ run_id }} | dag_run={{ dag_run }}" ''', - dag=dag) + dag=dag, + ) task.set_upstream(task_1) if __name__ == "__main__": diff --git a/tests/test_utils/perf/dags/perf_dag_2.py b/tests/test_utils/perf/dags/perf_dag_2.py index d9ef47efa1195..a0e8bba9429fe 100644 --- a/tests/test_utils/perf/dags/perf_dag_2.py +++ b/tests/test_utils/perf/dags/perf_dag_2.py @@ -30,14 +30,12 @@ } dag = DAG( - dag_id='perf_dag_2', default_args=args, - schedule_interval='@daily', - dagrun_timeout=timedelta(minutes=60)) + dag_id='perf_dag_2', default_args=args, schedule_interval='@daily', dagrun_timeout=timedelta(minutes=60) +) task_1 = BashOperator( - task_id='perf_task_1', - bash_command='sleep 5; echo "run_id={{ run_id }} | dag_run={{ dag_run }}"', - dag=dag) + task_id='perf_task_1', bash_command='sleep 5; echo "run_id={{ run_id }} | dag_run={{ dag_run }}"', dag=dag +) for i in range(2, 5): task = BashOperator( @@ -45,7 +43,8 @@ bash_command=''' sleep 5; echo "run_id={{ run_id }} | dag_run={{ dag_run }}" ''', - dag=dag) + dag=dag, + ) task.set_upstream(task_1) if __name__ == "__main__": diff --git a/tests/test_utils/perf/perf_kit/memory.py b/tests/test_utils/perf/perf_kit/memory.py index f84c505efafd3..5236e236d05db 100644 --- a/tests/test_utils/perf/perf_kit/memory.py +++ b/tests/test_utils/perf/perf_kit/memory.py @@ -36,6 +36,7 @@ def _human_readable_size(size, decimal_places=3): class TraceMemoryResult: """Trace results of memory,""" + def __init__(self): self.before = 0 self.after = 0 diff --git a/tests/test_utils/perf/perf_kit/repeat_and_time.py b/tests/test_utils/perf/perf_kit/repeat_and_time.py index 8efd7f10b6bda..b0434b1c579d0 100644 --- 
a/tests/test_utils/perf/perf_kit/repeat_and_time.py +++ b/tests/test_utils/perf/perf_kit/repeat_and_time.py @@ -24,6 +24,7 @@ class TimingResult: """Timing result.""" + def __init__(self): self.start_time = 0 self.end_time = 0 @@ -86,6 +87,7 @@ def timeout(seconds=1): :param seconds: Number of seconds """ + def handle_timeout(signum, frame): raise TimeoutException("Process timed out.") diff --git a/tests/test_utils/perf/perf_kit/sqlalchemy.py b/tests/test_utils/perf/perf_kit/sqlalchemy.py index e5c7c359cec90..b0c5a3ed70f5e 100644 --- a/tests/test_utils/perf/perf_kit/sqlalchemy.py +++ b/tests/test_utils/perf/perf_kit/sqlalchemy.py @@ -26,9 +26,8 @@ def _pretty_format_sql(text: str): import pygments from pygments.formatters.terminal import TerminalFormatter from pygments.lexers.sql import SqlLexer - text = pygments.highlight( - code=text, formatter=TerminalFormatter(), lexer=SqlLexer() - ).rstrip() + + text = pygments.highlight(code=text, formatter=TerminalFormatter(), lexer=SqlLexer()).rstrip() return text @@ -43,6 +42,7 @@ class TraceQueries: :param display_parameters: If True, display SQL statement parameters :param print_fn: The function used to display the text. By default,``builtins.print`` """ + def __init__( self, *, @@ -51,7 +51,7 @@ def __init__( display_trace: bool = True, display_sql: bool = False, display_parameters: bool = True, - print_fn: Callable[[str], None] = print + print_fn: Callable[[str], None] = print, ): self.display_num = display_num self.display_time = display_time @@ -61,13 +61,15 @@ def __init__( self.print_fn = print_fn self.query_count = 0 - def before_cursor_execute(self, - conn, - cursor, # pylint: disable=unused-argument - statement, # pylint: disable=unused-argument - parameters, # pylint: disable=unused-argument - context, # pylint: disable=unused-argument - executemany): # pylint: disable=unused-argument + def before_cursor_execute( + self, + conn, + cursor, # pylint: disable=unused-argument + statement, # pylint: disable=unused-argument + parameters, # pylint: disable=unused-argument + context, # pylint: disable=unused-argument + executemany, + ): # pylint: disable=unused-argument """ Executed before cursor. @@ -83,13 +85,15 @@ def before_cursor_execute(self, conn.info.setdefault("query_start_time", []).append(time.monotonic()) self.query_count += 1 - def after_cursor_execute(self, - conn, - cursor, # pylint: disable=unused-argument - statement, - parameters, - context, # pylint: disable=unused-argument - executemany): # pylint: disable=unused-argument + def after_cursor_execute( + self, + conn, + cursor, # pylint: disable=unused-argument + statement, + parameters, + context, # pylint: disable=unused-argument + executemany, + ): # pylint: disable=unused-argument """ Executed after cursor. @@ -139,6 +143,7 @@ def __enter__(self): def __exit__(self, type_, value, traceback): # noqa pylint: disable=redefined-outer-name import airflow.settings + event.remove(airflow.settings.engine, "before_cursor_execute", self.before_cursor_execute) event.remove(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute) @@ -150,6 +155,7 @@ class CountQueriesResult: """ Counter for number of queries. 
""" + def __init__(self): self.count = 0 @@ -180,13 +186,15 @@ def __exit__(self, type_, value, traceback): # noqa pylint: disable=redefined-o event.remove(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute) self.print_fn(f"Count SQL queries: {self.result.count}") - def after_cursor_execute(self, - conn, # pylint: disable=unused-argument - cursor, # pylint: disable=unused-argument - statement, # pylint: disable=unused-argument - parameters, # pylint: disable=unused-argument - context, # pylint: disable=unused-argument - executemany): # pylint: disable=unused-argument + def after_cursor_execute( + self, + conn, # pylint: disable=unused-argument + cursor, # pylint: disable=unused-argument + statement, # pylint: disable=unused-argument + parameters, # pylint: disable=unused-argument + context, # pylint: disable=unused-argument + executemany, + ): # pylint: disable=unused-argument """ Executed after cursor. @@ -212,13 +220,16 @@ def case(): from airflow.jobs.scheduler_job import DagFileProcessor - with mock.patch.dict("os.environ", { - "PERF_DAGS_COUNT": "200", - "PERF_TASKS_COUNT": "10", - "PERF_START_AGO": "2d", - "PERF_SCHEDULE_INTERVAL": "None", - "PERF_SHAPE": "no_structure", - }): + with mock.patch.dict( + "os.environ", + { + "PERF_DAGS_COUNT": "200", + "PERF_TASKS_COUNT": "10", + "PERF_START_AGO": "2d", + "PERF_SCHEDULE_INTERVAL": "None", + "PERF_SHAPE": "no_structure", + }, + ): log = logging.getLogger(__name__) processor = DagFileProcessor(dag_ids=[], log=log) dag_file = os.path.join(os.path.dirname(__file__), os.path.pardir, "dags", "elastic_dag.py") diff --git a/tests/test_utils/perf/scheduler_dag_execution_timing.py b/tests/test_utils/perf/scheduler_dag_execution_timing.py index 8b9b979357e44..87ea15f5a05c7 100755 --- a/tests/test_utils/perf/scheduler_dag_execution_timing.py +++ b/tests/test_utils/perf/scheduler_dag_execution_timing.py @@ -33,6 +33,7 @@ class ShortCircuitExecutorMixin: """ Mixin class to manage the scheduler state during the performance test run. """ + def __init__(self, dag_ids_to_watch, num_runs): super().__init__() self.num_runs_per_dag = num_runs @@ -48,8 +49,9 @@ def reset(self, dag_ids_to_watch): # A "cache" of DagRun row, so we don't have to look it up each # time. 
This is to try and reduce the impact of our # benchmarking code on runtime, - runs={} - ) for dag_id in dag_ids_to_watch + runs={}, + ) + for dag_id in dag_ids_to_watch } def change_state(self, key, state, info=None): @@ -58,6 +60,7 @@ def change_state(self, key, state, info=None): and then shut down the scheduler after the task is complete """ from airflow.utils.state import State + super().change_state(key, state, info=info) dag_id, _, execution_date, __ = key @@ -87,8 +90,9 @@ def change_state(self, key, state, info=None): self.log.warning("STOPPING SCHEDULER -- all runs complete") self.scheduler_job.processor_agent._done = True # pylint: disable=protected-access return - self.log.warning("WAITING ON %d RUNS", - sum(map(attrgetter('waiting_for'), self.dags_to_watch.values()))) + self.log.warning( + "WAITING ON %d RUNS", sum(map(attrgetter('waiting_for'), self.dags_to_watch.values())) + ) def get_executor_under_test(dotted_path): @@ -114,6 +118,7 @@ class ShortCircuitExecutor(ShortCircuitExecutorMixin, executor_cls): """ Placeholder class that implements the inheritance hierarchy """ + scheduler_job = None return ShortCircuitExecutor @@ -142,6 +147,7 @@ def pause_all_dags(session): Pause all Dags """ from airflow.models.dag import DagModel + session.query(DagModel).update({'is_paused': True}) @@ -154,9 +160,11 @@ def create_dag_runs(dag, num_runs, session): try: from airflow.utils.types import DagRunType + id_prefix = f'{DagRunType.SCHEDULED.value}__' except ImportError: from airflow.models.dagrun import DagRun + id_prefix = DagRun.ID_PREFIX # pylint: disable=no-member next_run_date = dag.normalize_schedule(dag.start_date or min(t.start_date for t in dag.tasks)) @@ -176,18 +184,27 @@ def create_dag_runs(dag, num_runs, session): @click.command() @click.option('--num-runs', default=1, help='number of DagRun, to run for each DAG') @click.option('--repeat', default=3, help='number of times to run test, to reduce variance') -@click.option('--pre-create-dag-runs', is_flag=True, default=False, - help='''Pre-create the dag runs and stop the scheduler creating more. +@click.option( + '--pre-create-dag-runs', + is_flag=True, + default=False, + help='''Pre-create the dag runs and stop the scheduler creating more. Warning: this makes the scheduler do (slightly) less work so may skew your numbers. Use sparingly! - ''') -@click.option('--executor-class', default='MockExecutor', - help=textwrap.dedent(''' + ''', +) +@click.option( + '--executor-class', + default='MockExecutor', + help=textwrap.dedent( + ''' Dotted path Executor class to test, for example 'airflow.executors.local_executor.LocalExecutor'. Defaults to MockExecutor which doesn't run tasks. - ''')) + ''' + ), # pylint: disable=too-many-locals +) @click.argument('dag_ids', required=True, nargs=-1) -def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids): # pylint: disable=too-many-locals +def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids): """ This script can be used to measure the total "scheduler overhead" of Airflow. @@ -250,7 +267,8 @@ def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids): # pyl message = ( f"DAG {dag_id} has incorrect end_date ({end_date}) for number of runs! 
" f"It should be " - f" {next_run_date}") + f" {next_run_date}" + ) sys.exit(message) if pre_create_dag_runs: @@ -298,13 +316,10 @@ def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids): # pyl msg = "Time for %d dag runs of %d dags with %d total tasks: %.4fs" if len(times) > 1: - print((msg + " (±%.3fs)") % ( - num_runs, - len(dags), - total_tasks, - statistics.mean(times), - statistics.stdev(times) - )) + print( + (msg + " (±%.3fs)") + % (num_runs, len(dags), total_tasks, statistics.mean(times), statistics.stdev(times)) + ) else: print(msg % (num_runs, len(dags), total_tasks, times[0])) diff --git a/tests/test_utils/perf/scheduler_ops_metrics.py b/tests/test_utils/perf/scheduler_ops_metrics.py index 0d3b4a474c640..a0ffd420c1ae9 100644 --- a/tests/test_utils/perf/scheduler_ops_metrics.py +++ b/tests/test_utils/perf/scheduler_ops_metrics.py @@ -59,9 +59,8 @@ class SchedulerMetricsJob(SchedulerJob): You can specify timeout in seconds as an optional parameter. Its default value is 6 seconds. """ - __mapper_args__ = { - 'polymorphic_identity': 'SchedulerMetricsJob' - } + + __mapper_args__ = {'polymorphic_identity': 'SchedulerMetricsJob'} def __init__(self, dag_ids, subdir, max_runtime_secs): self.max_runtime_secs = max_runtime_secs @@ -73,23 +72,32 @@ def print_stats(self): """ session = settings.Session() TI = TaskInstance - tis = ( - session - .query(TI) - .filter(TI.dag_id.in_(DAG_IDS)) - .all() - ) + tis = session.query(TI).filter(TI.dag_id.in_(DAG_IDS)).all() successful_tis = [x for x in tis if x.state == State.SUCCESS] - ti_perf = [(ti.dag_id, ti.task_id, ti.execution_date, - (ti.queued_dttm - self.start_date).total_seconds(), - (ti.start_date - self.start_date).total_seconds(), - (ti.end_date - self.start_date).total_seconds(), - ti.duration) for ti in successful_tis] - ti_perf_df = pd.DataFrame(ti_perf, columns=['dag_id', 'task_id', - 'execution_date', - 'queue_delay', - 'start_delay', 'land_time', - 'duration']) + ti_perf = [ + ( + ti.dag_id, + ti.task_id, + ti.execution_date, + (ti.queued_dttm - self.start_date).total_seconds(), + (ti.start_date - self.start_date).total_seconds(), + (ti.end_date - self.start_date).total_seconds(), + ti.duration, + ) + for ti in successful_tis + ] + ti_perf_df = pd.DataFrame( + ti_perf, + columns=[ + 'dag_id', + 'task_id', + 'execution_date', + 'queue_delay', + 'start_delay', + 'land_time', + 'duration', + ], + ) print('Performance Results') print('###################') @@ -99,9 +107,15 @@ def print_stats(self): print('###################') if len(tis) > len(successful_tis): print("WARNING!! 
The following task instances haven't completed") - print(pd.DataFrame([(ti.dag_id, ti.task_id, ti.execution_date, ti.state) - for ti in filter(lambda x: x.state != State.SUCCESS, tis)], - columns=['dag_id', 'task_id', 'execution_date', 'state'])) + print( + pd.DataFrame( + [ + (ti.dag_id, ti.task_id, ti.execution_date, ti.state) + for ti in filter(lambda x: x.state != State.SUCCESS, tis) + ], + columns=['dag_id', 'task_id', 'execution_date', 'state'], + ) + ) session.commit() @@ -114,23 +128,21 @@ def heartbeat(self): # Get all the relevant task instances TI = TaskInstance successful_tis = ( - session - .query(TI) - .filter(TI.dag_id.in_(DAG_IDS)) - .filter(TI.state.in_([State.SUCCESS])) - .all() + session.query(TI).filter(TI.dag_id.in_(DAG_IDS)).filter(TI.state.in_([State.SUCCESS])).all() ) session.commit() dagbag = DagBag(SUBDIR) dags = [dagbag.dags[dag_id] for dag_id in DAG_IDS] # the tasks in perf_dag_1 and per_dag_2 have a daily schedule interval. - num_task_instances = sum([(timezone.utcnow() - task.start_date).days - for dag in dags for task in dag.tasks]) + num_task_instances = sum( + [(timezone.utcnow() - task.start_date).days for dag in dags for task in dag.tasks] + ) - if (len(successful_tis) == num_task_instances or - (timezone.utcnow() - self.start_date).total_seconds() > - self.max_runtime_secs): + if ( + len(successful_tis) == num_task_instances + or (timezone.utcnow() - self.start_date).total_seconds() > self.max_runtime_secs + ): if len(successful_tis) == num_task_instances: self.log.info("All tasks processed! Printing stats.") else: @@ -145,9 +157,13 @@ def clear_dag_runs(): Remove any existing DAG runs for the perf test DAGs. """ session = settings.Session() - drs = session.query(DagRun).filter( - DagRun.dag_id.in_(DAG_IDS), - ).all() + drs = ( + session.query(DagRun) + .filter( + DagRun.dag_id.in_(DAG_IDS), + ) + .all() + ) for dr in drs: logging.info('Deleting DagRun :: %s', dr) session.delete(dr) @@ -159,12 +175,7 @@ def clear_dag_task_instances(): """ session = settings.Session() TI = TaskInstance - tis = ( - session - .query(TI) - .filter(TI.dag_id.in_(DAG_IDS)) - .all() - ) + tis = session.query(TI).filter(TI.dag_id.in_(DAG_IDS)).all() for ti in tis: logging.info('Deleting TaskInstance :: %s', ti) session.delete(ti) @@ -176,8 +187,7 @@ def set_dags_paused_state(is_paused): Toggle the pause state of the DAGs in the test. 
""" session = settings.Session() - dag_models = session.query(DagModel).filter( - DagModel.dag_id.in_(DAG_IDS)) + dag_models = session.query(DagModel).filter(DagModel.dag_id.in_(DAG_IDS)) for dag_model in dag_models: logging.info('Setting DAG :: %s is_paused=%s', dag_model, is_paused) dag_model.is_paused = is_paused @@ -205,10 +215,7 @@ def main(): clear_dag_runs() clear_dag_task_instances() - job = SchedulerMetricsJob( - dag_ids=DAG_IDS, - subdir=SUBDIR, - max_runtime_secs=max_runtime_secs) + job = SchedulerMetricsJob(dag_ids=DAG_IDS, subdir=SUBDIR, max_runtime_secs=max_runtime_secs) job.run() diff --git a/tests/test_utils/perf/sql_queries.py b/tests/test_utils/perf/sql_queries.py index fc3320cb81823..a8e57aa7fe184 100644 --- a/tests/test_utils/perf/sql_queries.py +++ b/tests/test_utils/perf/sql_queries.py @@ -34,9 +34,7 @@ LOG_LEVEL = "INFO" LOG_FILE = "/files/sql_stats.log" # Default to run in Breeze -os.environ[ - "AIRFLOW__LOGGING__LOGGING_CONFIG_CLASS" -] = "scripts.perf.sql_queries.DEBUG_LOGGING_CONFIG" +os.environ["AIRFLOW__LOGGING__LOGGING_CONFIG_CLASS"] = "scripts.perf.sql_queries.DEBUG_LOGGING_CONFIG" DEBUG_LOGGING_CONFIG = { "version": 1, @@ -76,6 +74,7 @@ class Query(NamedTuple): """ Define attributes of the queries that will be picked up by the performance tests. """ + function: str file: str location: int @@ -110,6 +109,7 @@ def reset_db(): Wrapper function that calls the airflows resetdb function. """ from airflow.utils.db import resetdb + resetdb() diff --git a/tests/test_utils/remote_user_api_auth_backend.py b/tests/test_utils/remote_user_api_auth_backend.py index 3793831d1c953..e5f189a0af94b 100644 --- a/tests/test_utils/remote_user_api_auth_backend.py +++ b/tests/test_utils/remote_user_api_auth_backend.py @@ -38,9 +38,8 @@ def init_app(_): def _lookup_user(user_email_or_username: str): security_manager = current_app.appbuilder.sm - user = ( - security_manager.find_user(email=user_email_or_username) - or security_manager.find_user(username=user_email_or_username) + user = security_manager.find_user(email=user_email_or_username) or security_manager.find_user( + username=user_email_or_username ) if not user: return None @@ -53,6 +52,7 @@ def _lookup_user(user_email_or_username: str): def requires_authentication(function: T): """Decorator for functions that require authentication""" + @wraps(function) def decorated(*args, **kwargs): user_id = request.remote_user diff --git a/tests/test_utils/test_remote_user_api_auth_backend.py b/tests/test_utils/test_remote_user_api_auth_backend.py index c63d0f46fdf5a..966412fbf4e97 100644 --- a/tests/test_utils/test_remote_user_api_auth_backend.py +++ b/tests/test_utils/test_remote_user_api_auth_backend.py @@ -49,9 +49,7 @@ def test_success_using_username(self): clear_db_pools() with self.app.test_client() as test_client: - response = test_client.get( - "/api/experimental/pools", environ_overrides={'REMOTE_USER': "test"} - ) + response = test_client.get("/api/experimental/pools", environ_overrides={'REMOTE_USER': "test"}) self.assertEqual("test@fab.org", current_user.email) self.assertEqual(200, response.status_code) diff --git a/tests/ti_deps/deps/fake_models.py b/tests/ti_deps/deps/fake_models.py index a7f16b2fa44b7..63346101c8aeb 100644 --- a/tests/ti_deps/deps/fake_models.py +++ b/tests/ti_deps/deps/fake_models.py @@ -20,7 +20,6 @@ class FakeTI: - def __init__(self, **kwds): self.__dict__.update(kwds) @@ -32,13 +31,11 @@ def are_dependents_done(self, session): # pylint: disable=unused-argument class FakeTask: - def __init__(self, 
**kwds): self.__dict__.update(kwds) class FakeDag: - def __init__(self, **kwds): self.__dict__.update(kwds) @@ -47,6 +44,5 @@ def get_running_dagruns(self, _): class FakeContext: - def __init__(self, **kwds): self.__dict__.update(kwds) diff --git a/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py b/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py index 3b27fe63be376..f6fba610f6710 100644 --- a/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py +++ b/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py @@ -24,7 +24,6 @@ class TestDagTISlotsAvailableDep(unittest.TestCase): - def test_concurrency_reached(self): """ Test concurrency reached should fail dep diff --git a/tests/ti_deps/deps/test_dag_unpaused_dep.py b/tests/ti_deps/deps/test_dag_unpaused_dep.py index a397af0498074..1303f616c9370 100644 --- a/tests/ti_deps/deps/test_dag_unpaused_dep.py +++ b/tests/ti_deps/deps/test_dag_unpaused_dep.py @@ -24,7 +24,6 @@ class TestDagUnpausedDep(unittest.TestCase): - def test_concurrency_reached(self): """ Test paused DAG should fail dependency diff --git a/tests/ti_deps/deps/test_dagrun_exists_dep.py b/tests/ti_deps/deps/test_dagrun_exists_dep.py index 567a3aa6508e6..a1696c0794810 100644 --- a/tests/ti_deps/deps/test_dagrun_exists_dep.py +++ b/tests/ti_deps/deps/test_dagrun_exists_dep.py @@ -25,7 +25,6 @@ class TestDagrunRunningDep(unittest.TestCase): - @patch('airflow.models.DagRun.find', return_value=()) def test_dagrun_doesnt_exist(self, mock_dagrun_find): """ diff --git a/tests/ti_deps/deps/test_not_in_retry_period_dep.py b/tests/ti_deps/deps/test_not_in_retry_period_dep.py index bab5372ec1d86..2fd28d32e9675 100644 --- a/tests/ti_deps/deps/test_not_in_retry_period_dep.py +++ b/tests/ti_deps/deps/test_not_in_retry_period_dep.py @@ -29,9 +29,7 @@ class TestNotInRetryPeriodDep(unittest.TestCase): - - def _get_task_instance(self, state, end_date=None, - retry_delay=timedelta(minutes=15)): + def _get_task_instance(self, state, end_date=None, retry_delay=timedelta(minutes=15)): task = Mock(retry_delay=retry_delay, retry_exponential_backoff=False) ti = TaskInstance(task=task, state=state, execution_date=None) ti.end_date = end_date @@ -42,8 +40,7 @@ def test_still_in_retry_period(self): """ Task instances that are in their retry period should fail this dep """ - ti = self._get_task_instance(State.UP_FOR_RETRY, - end_date=datetime(2016, 1, 1, 15, 30)) + ti = self._get_task_instance(State.UP_FOR_RETRY, end_date=datetime(2016, 1, 1, 15, 30)) self.assertTrue(ti.is_premature) self.assertFalse(NotInRetryPeriodDep().is_met(ti=ti)) @@ -52,8 +49,7 @@ def test_retry_period_finished(self): """ Task instance's that have had their retry period elapse should pass this dep """ - ti = self._get_task_instance(State.UP_FOR_RETRY, - end_date=datetime(2016, 1, 1)) + ti = self._get_task_instance(State.UP_FOR_RETRY, end_date=datetime(2016, 1, 1)) self.assertFalse(ti.is_premature) self.assertTrue(NotInRetryPeriodDep().is_met(ti=ti)) diff --git a/tests/ti_deps/deps/test_not_previously_skipped_dep.py b/tests/ti_deps/deps/test_not_previously_skipped_dep.py index 658fee3644076..9ddd0287c78d1 100644 --- a/tests/ti_deps/deps/test_not_previously_skipped_dep.py +++ b/tests/ti_deps/deps/test_not_previously_skipped_dep.py @@ -50,9 +50,7 @@ def test_no_skipmixin_parent(): A simple DAG with no branching. Both op1 and op2 are DummyOperator. NotPreviouslySkippedDep is met. 
""" start_date = pendulum.datetime(2020, 1, 1) - dag = DAG( - "test_no_skipmixin_parent_dag", schedule_interval=None, start_date=start_date - ) + dag = DAG("test_no_skipmixin_parent_dag", schedule_interval=None, start_date=start_date) op1 = DummyOperator(task_id="op1", dag=dag) op2 = DummyOperator(task_id="op2", dag=dag) op1 >> op2 @@ -71,9 +69,7 @@ def test_parent_follow_branch(): A simple DAG with a BranchPythonOperator that follows op2. NotPreviouslySkippedDep is met. """ start_date = pendulum.datetime(2020, 1, 1) - dag = DAG( - "test_parent_follow_branch_dag", schedule_interval=None, start_date=start_date - ) + dag = DAG("test_parent_follow_branch_dag", schedule_interval=None, start_date=start_date) dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=start_date) op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op2", dag=dag) op2 = DummyOperator(task_id="op2", dag=dag) @@ -96,9 +92,7 @@ def test_parent_skip_branch(): session.query(DagRun).delete() session.query(TaskInstance).delete() start_date = pendulum.datetime(2020, 1, 1) - dag = DAG( - "test_parent_skip_branch_dag", schedule_interval=None, start_date=start_date - ) + dag = DAG("test_parent_skip_branch_dag", schedule_interval=None, start_date=start_date) dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=start_date) op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3", dag=dag) op2 = DummyOperator(task_id="op2", dag=dag) @@ -120,9 +114,7 @@ def test_parent_not_executed(): executed (no xcom data). NotPreviouslySkippedDep is met (no decision). """ start_date = pendulum.datetime(2020, 1, 1) - dag = DAG( - "test_parent_not_executed_dag", schedule_interval=None, start_date=start_date - ) + dag = DAG("test_parent_not_executed_dag", schedule_interval=None, start_date=start_date) op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3", dag=dag) op2 = DummyOperator(task_id="op2", dag=dag) op3 = DummyOperator(task_id="op3", dag=dag) diff --git a/tests/ti_deps/deps/test_prev_dagrun_dep.py b/tests/ti_deps/deps/test_prev_dagrun_dep.py index 948ebd5d39e71..4a05d77ea72c1 100644 --- a/tests/ti_deps/deps/test_prev_dagrun_dep.py +++ b/tests/ti_deps/deps/test_prev_dagrun_dep.py @@ -28,7 +28,6 @@ class TestPrevDagrunDep(unittest.TestCase): - def _get_task(self, **kwargs): return BaseOperator(task_id='test_task', dag=DAG('test_dag'), **kwargs) @@ -37,14 +36,16 @@ def test_not_depends_on_past(self): If depends on past isn't set in the task then the previous dagrun should be ignored, even though there is no previous_ti which would normally fail the dep """ - task = self._get_task(depends_on_past=False, - start_date=datetime(2016, 1, 1), - wait_for_downstream=False) - prev_ti = Mock(task=task, state=State.SUCCESS, - are_dependents_done=Mock(return_value=True), - execution_date=datetime(2016, 1, 2)) - ti = Mock(task=task, previous_ti=prev_ti, - execution_date=datetime(2016, 1, 3)) + task = self._get_task( + depends_on_past=False, start_date=datetime(2016, 1, 1), wait_for_downstream=False + ) + prev_ti = Mock( + task=task, + state=State.SUCCESS, + are_dependents_done=Mock(return_value=True), + execution_date=datetime(2016, 1, 2), + ) + ti = Mock(task=task, previous_ti=prev_ti, execution_date=datetime(2016, 1, 3)) dep_context = DepContext(ignore_depends_on_past=False) self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) @@ -54,14 +55,16 @@ def test_context_ignore_depends_on_past(self): If the context overrides depends_on_past 
then the dep should be met, even though there is no previous_ti which would normally fail the dep """ - task = self._get_task(depends_on_past=True, - start_date=datetime(2016, 1, 1), - wait_for_downstream=False) - prev_ti = Mock(task=task, state=State.SUCCESS, - are_dependents_done=Mock(return_value=True), - execution_date=datetime(2016, 1, 2)) - ti = Mock(task=task, previous_ti=prev_ti, - execution_date=datetime(2016, 1, 3)) + task = self._get_task( + depends_on_past=True, start_date=datetime(2016, 1, 1), wait_for_downstream=False + ) + prev_ti = Mock( + task=task, + state=State.SUCCESS, + are_dependents_done=Mock(return_value=True), + execution_date=datetime(2016, 1, 2), + ) + ti = Mock(task=task, previous_ti=prev_ti, execution_date=datetime(2016, 1, 3)) dep_context = DepContext(ignore_depends_on_past=True) self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) @@ -70,12 +73,11 @@ def test_first_task_run(self): """ The first task run for a TI should pass since it has no previous dagrun. """ - task = self._get_task(depends_on_past=True, - start_date=datetime(2016, 1, 1), - wait_for_downstream=False) + task = self._get_task( + depends_on_past=True, start_date=datetime(2016, 1, 1), wait_for_downstream=False + ) prev_ti = None - ti = Mock(task=task, previous_ti=prev_ti, - execution_date=datetime(2016, 1, 1)) + ti = Mock(task=task, previous_ti=prev_ti, execution_date=datetime(2016, 1, 1)) dep_context = DepContext(ignore_depends_on_past=False) self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) @@ -84,13 +86,11 @@ def test_prev_ti_bad_state(self): """ If the previous TI did not complete execution this dep should fail. """ - task = self._get_task(depends_on_past=True, - start_date=datetime(2016, 1, 1), - wait_for_downstream=False) - prev_ti = Mock(state=State.NONE, - are_dependents_done=Mock(return_value=True)) - ti = Mock(task=task, previous_ti=prev_ti, - execution_date=datetime(2016, 1, 2)) + task = self._get_task( + depends_on_past=True, start_date=datetime(2016, 1, 1), wait_for_downstream=False + ) + prev_ti = Mock(state=State.NONE, are_dependents_done=Mock(return_value=True)) + ti = Mock(task=task, previous_ti=prev_ti, execution_date=datetime(2016, 1, 2)) dep_context = DepContext(ignore_depends_on_past=False) self.assertFalse(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) @@ -101,13 +101,9 @@ def test_failed_wait_for_downstream(self): previous dagrun then it should fail this dep if the downstream TIs of the previous TI are not done. 
""" - task = self._get_task(depends_on_past=True, - start_date=datetime(2016, 1, 1), - wait_for_downstream=True) - prev_ti = Mock(state=State.SUCCESS, - are_dependents_done=Mock(return_value=False)) - ti = Mock(task=task, previous_ti=prev_ti, - execution_date=datetime(2016, 1, 2)) + task = self._get_task(depends_on_past=True, start_date=datetime(2016, 1, 1), wait_for_downstream=True) + prev_ti = Mock(state=State.SUCCESS, are_dependents_done=Mock(return_value=False)) + ti = Mock(task=task, previous_ti=prev_ti, execution_date=datetime(2016, 1, 2)) dep_context = DepContext(ignore_depends_on_past=False) self.assertFalse(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) @@ -116,16 +112,9 @@ def test_all_met(self): """ Test to make sure all of the conditions for the dep are met """ - task = self._get_task(depends_on_past=True, - start_date=datetime(2016, 1, 1), - wait_for_downstream=True) - prev_ti = Mock(state=State.SUCCESS, - are_dependents_done=Mock(return_value=True)) - ti = Mock( - task=task, - execution_date=datetime(2016, 1, 2), - **{'get_previous_ti.return_value': prev_ti} - ) + task = self._get_task(depends_on_past=True, start_date=datetime(2016, 1, 1), wait_for_downstream=True) + prev_ti = Mock(state=State.SUCCESS, are_dependents_done=Mock(return_value=True)) + ti = Mock(task=task, execution_date=datetime(2016, 1, 2), **{'get_previous_ti.return_value': prev_ti}) dep_context = DepContext(ignore_depends_on_past=False) self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py index 66562b30a8462..f9bbb9f282b3f 100644 --- a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py +++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py @@ -28,7 +28,6 @@ class TestNotInReschedulePeriodDep(unittest.TestCase): - def _get_task_instance(self, state): dag = DAG('test_dag') task = Mock(dag=dag) @@ -37,9 +36,14 @@ def _get_task_instance(self, state): def _get_task_reschedule(self, reschedule_date): task = Mock(dag_id='test_dag', task_id='test_task') - reschedule = TaskReschedule(task=task, execution_date=None, try_number=None, - start_date=reschedule_date, end_date=reschedule_date, - reschedule_date=reschedule_date) + reschedule = TaskReschedule( + task=task, + execution_date=None, + try_number=None, + start_date=reschedule_date, + end_date=reschedule_date, + reschedule_date=reschedule_date, + ) return reschedule def test_should_pass_if_ignore_in_reschedule_period_is_set(self): @@ -59,8 +63,9 @@ def test_should_pass_if_no_reschedule_record_exists(self, mock_query_for_task_in @patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance') def test_should_pass_after_reschedule_date_one(self, mock_query_for_task_instance): - mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = \ + mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = ( self._get_task_reschedule(utcnow() - timedelta(minutes=1)) + ) ti = self._get_task_instance(State.UP_FOR_RESCHEDULE) self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti)) @@ -76,8 +81,9 @@ def test_should_pass_after_reschedule_date_multiple(self, mock_query_for_task_in @patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance') def test_should_fail_before_reschedule_date_one(self, mock_query_for_task_instance): - mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = \ + 
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = ( self._get_task_reschedule(utcnow() + timedelta(minutes=1)) + ) ti = self._get_task_instance(State.UP_FOR_RESCHEDULE) self.assertFalse(ReadyToRescheduleDep().is_met(ti=ti)) diff --git a/tests/ti_deps/deps/test_runnable_exec_date_dep.py b/tests/ti_deps/deps/test_runnable_exec_date_dep.py index 4c08f577d163e..f12f3a284a890 100644 --- a/tests/ti_deps/deps/test_runnable_exec_date_dep.py +++ b/tests/ti_deps/deps/test_runnable_exec_date_dep.py @@ -30,13 +30,16 @@ @freeze_time('2016-11-01') -@pytest.mark.parametrize("allow_trigger_in_future,schedule_interval,execution_date,is_met", [ - (True, None, datetime(2016, 11, 3), True), - (True, "@daily", datetime(2016, 11, 3), False), - (False, None, datetime(2016, 11, 3), False), - (False, "@daily", datetime(2016, 11, 3), False), - (False, "@daily", datetime(2016, 11, 1), True), - (False, None, datetime(2016, 11, 1), True)] +@pytest.mark.parametrize( + "allow_trigger_in_future,schedule_interval,execution_date,is_met", + [ + (True, None, datetime(2016, 11, 3), True), + (True, "@daily", datetime(2016, 11, 3), False), + (False, None, datetime(2016, 11, 3), False), + (False, "@daily", datetime(2016, 11, 3), False), + (False, "@daily", datetime(2016, 11, 1), True), + (False, None, datetime(2016, 11, 1), True), + ], ) def test_exec_date_dep(allow_trigger_in_future, schedule_interval, execution_date, is_met): """ @@ -49,7 +52,8 @@ def test_exec_date_dep(allow_trigger_in_future, schedule_interval, execution_dat 'test_localtaskjob_heartbeat', start_date=datetime(2015, 1, 1), end_date=datetime(2016, 11, 5), - schedule_interval=schedule_interval) + schedule_interval=schedule_interval, + ) with dag: op1 = DummyOperator(task_id='op1') @@ -59,7 +63,6 @@ def test_exec_date_dep(allow_trigger_in_future, schedule_interval, execution_dat class TestRunnableExecDateDep(unittest.TestCase): - def _get_task_instance(self, execution_date, dag_end_date=None, task_end_date=None): dag = Mock(end_date=dag_end_date) task = Mock(dag=dag, end_date=task_end_date) @@ -74,7 +77,8 @@ def test_exec_date_after_end_date(self): 'test_localtaskjob_heartbeat', start_date=datetime(2015, 1, 1), end_date=datetime(2016, 11, 5), - schedule_interval=None) + schedule_interval=None, + ) with dag: op1 = DummyOperator(task_id='op1') diff --git a/tests/ti_deps/deps/test_task_concurrency.py b/tests/ti_deps/deps/test_task_concurrency.py index 4d8f5fba5a8f7..9bdd1fab1c971 100644 --- a/tests/ti_deps/deps/test_task_concurrency.py +++ b/tests/ti_deps/deps/test_task_concurrency.py @@ -27,7 +27,6 @@ class TestTaskConcurrencyDep(unittest.TestCase): - def _get_task(self, **kwargs): return BaseOperator(task_id='test_task', dag=DAG('test_dag'), **kwargs) diff --git a/tests/ti_deps/deps/test_task_not_running_dep.py b/tests/ti_deps/deps/test_task_not_running_dep.py index 2a4ab563f556d..353db3fafb8db 100644 --- a/tests/ti_deps/deps/test_task_not_running_dep.py +++ b/tests/ti_deps/deps/test_task_not_running_dep.py @@ -25,7 +25,6 @@ class TestTaskNotRunningDep(unittest.TestCase): - def test_not_running_state(self): ti = Mock(state=State.QUEUED, end_date=datetime(2016, 1, 1)) self.assertTrue(TaskNotRunningDep().is_met(ti=ti)) diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index 76b61b8af4f1f..c993e6c05fb66 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -34,11 +34,8 @@ class TestTriggerRuleDep(unittest.TestCase): - - 
def _get_task_instance(self, trigger_rule=TriggerRule.ALL_SUCCESS, - state=None, upstream_task_ids=None): - task = BaseOperator(task_id='test_task', trigger_rule=trigger_rule, - start_date=datetime(2015, 1, 1)) + def _get_task_instance(self, trigger_rule=TriggerRule.ALL_SUCCESS, state=None, upstream_task_ids=None): + task = BaseOperator(task_id='test_task', trigger_rule=trigger_rule, start_date=datetime(2015, 1, 1)) if upstream_task_ids: task._upstream_task_ids.update(upstream_task_ids) return TaskInstance(task=task, state=state, execution_date=task.start_date) @@ -62,15 +59,18 @@ def test_one_success_tr_success(self): One-success trigger rule success """ ti = self._get_task_instance(TriggerRule.ONE_SUCCESS, State.UP_FOR_RETRY) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=1, - skipped=2, - failed=2, - upstream_failed=2, - done=2, - flag_upstream_failed=False, - session="Fake Session")) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=2, + failed=2, + upstream_failed=2, + done=2, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 0) def test_one_success_tr_failure(self): @@ -78,15 +78,18 @@ def test_one_success_tr_failure(self): One-success trigger rule failure """ ti = self._get_task_instance(TriggerRule.ONE_SUCCESS) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=0, - skipped=2, - failed=2, - upstream_failed=2, - done=2, - flag_upstream_failed=False, - session="Fake Session")) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=0, + skipped=2, + failed=2, + upstream_failed=2, + done=2, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) @@ -95,15 +98,18 @@ def test_one_failure_tr_failure(self): One-failure trigger rule failure """ ti = self._get_task_instance(TriggerRule.ONE_FAILED) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=2, - skipped=0, - failed=0, - upstream_failed=0, - done=2, - flag_upstream_failed=False, - session="Fake Session")) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=2, + skipped=0, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) @@ -112,61 +118,72 @@ def test_one_failure_tr_success(self): One-failure trigger rule success """ ti = self._get_task_instance(TriggerRule.ONE_FAILED) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=0, - skipped=2, - failed=2, - upstream_failed=0, - done=2, - flag_upstream_failed=False, - session="Fake Session")) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=0, + skipped=2, + failed=2, + upstream_failed=0, + done=2, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 0) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=0, - skipped=2, - failed=0, - upstream_failed=2, - done=2, - flag_upstream_failed=False, - session="Fake Session")) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=0, + skipped=2, + failed=0, + upstream_failed=2, + done=2, + flag_upstream_failed=False, + session="Fake Session", + ) + ) 
self.assertEqual(len(dep_statuses), 0) def test_all_success_tr_success(self): """ All-success trigger rule success """ - ti = self._get_task_instance(TriggerRule.ALL_SUCCESS, - upstream_task_ids=["FakeTaskID"]) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=1, - skipped=0, - failed=0, - upstream_failed=0, - done=1, - flag_upstream_failed=False, - session="Fake Session")) + ti = self._get_task_instance(TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID"]) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=0, + failed=0, + upstream_failed=0, + done=1, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 0) def test_all_success_tr_failure(self): """ All-success trigger rule failure """ - ti = self._get_task_instance(TriggerRule.ALL_SUCCESS, - upstream_task_ids=["FakeTaskID", - "OtherFakeTaskID"]) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=1, - skipped=0, - failed=1, - upstream_failed=0, - done=2, - flag_upstream_failed=False, - session="Fake Session")) + ti = self._get_task_instance( + TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] + ) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=0, + failed=1, + upstream_failed=0, + done=2, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) @@ -174,18 +191,21 @@ def test_all_success_tr_skip(self): """ All-success trigger rule fails when some upstream tasks are skipped. """ - ti = self._get_task_instance(TriggerRule.ALL_SUCCESS, - upstream_task_ids=["FakeTaskID", - "OtherFakeTaskID"]) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=1, - skipped=1, - failed=0, - upstream_failed=0, - done=2, - flag_upstream_failed=False, - session="Fake Session")) + ti = self._get_task_instance( + TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] + ) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=1, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) @@ -194,18 +214,21 @@ def test_all_success_tr_skip_flag_upstream(self): All-success trigger rule fails when some upstream tasks are skipped. The state of the ti should be set to SKIPPED when flag_upstream_failed is True. 
""" - ti = self._get_task_instance(TriggerRule.ALL_SUCCESS, - upstream_task_ids=["FakeTaskID", - "OtherFakeTaskID"]) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=1, - skipped=1, - failed=0, - upstream_failed=0, - done=2, - flag_upstream_failed=True, - session=Mock())) + ti = self._get_task_instance( + TriggerRule.ALL_SUCCESS, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] + ) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=1, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=True, + session=Mock(), + ) + ) self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) self.assertEqual(ti.state, State.SKIPPED) @@ -214,36 +237,42 @@ def test_none_failed_tr_success(self): """ All success including skip trigger rule success """ - ti = self._get_task_instance(TriggerRule.NONE_FAILED, - upstream_task_ids=["FakeTaskID", - "OtherFakeTaskID"]) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=1, - skipped=1, - failed=0, - upstream_failed=0, - done=2, - flag_upstream_failed=False, - session="Fake Session")) + ti = self._get_task_instance( + TriggerRule.NONE_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] + ) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=1, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 0) def test_none_failed_tr_skipped(self): """ All success including all upstream skips trigger rule success """ - ti = self._get_task_instance(TriggerRule.NONE_FAILED, - upstream_task_ids=["FakeTaskID", - "OtherFakeTaskID"]) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=0, - skipped=2, - failed=0, - upstream_failed=0, - done=2, - flag_upstream_failed=True, - session=Mock())) + ti = self._get_task_instance( + TriggerRule.NONE_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] + ) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=0, + skipped=2, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=True, + session=Mock(), + ) + ) self.assertEqual(len(dep_statuses), 0) self.assertEqual(ti.state, State.NONE) @@ -251,19 +280,21 @@ def test_none_failed_tr_failure(self): """ All success including skip trigger rule failure """ - ti = self._get_task_instance(TriggerRule.NONE_FAILED, - upstream_task_ids=["FakeTaskID", - "OtherFakeTaskID", - "FailedFakeTaskID"]) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=1, - skipped=1, - failed=1, - upstream_failed=0, - done=3, - flag_upstream_failed=False, - session="Fake Session")) + ti = self._get_task_instance( + TriggerRule.NONE_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"] + ) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=1, + failed=1, + upstream_failed=0, + done=3, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) @@ -271,36 +302,42 @@ def test_none_failed_or_skipped_tr_success(self): """ All success including skip trigger rule success """ - ti = self._get_task_instance(TriggerRule.NONE_FAILED_OR_SKIPPED, - upstream_task_ids=["FakeTaskID", - "OtherFakeTaskID"]) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - 
ti=ti, - successes=1, - skipped=1, - failed=0, - upstream_failed=0, - done=2, - flag_upstream_failed=False, - session="Fake Session")) + ti = self._get_task_instance( + TriggerRule.NONE_FAILED_OR_SKIPPED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] + ) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=1, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 0) def test_none_failed_or_skipped_tr_skipped(self): """ All success including all upstream skips trigger rule success """ - ti = self._get_task_instance(TriggerRule.NONE_FAILED_OR_SKIPPED, - upstream_task_ids=["FakeTaskID", - "OtherFakeTaskID"]) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=0, - skipped=2, - failed=0, - upstream_failed=0, - done=2, - flag_upstream_failed=True, - session=Mock())) + ti = self._get_task_instance( + TriggerRule.NONE_FAILED_OR_SKIPPED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] + ) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=0, + skipped=2, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=True, + session=Mock(), + ) + ) self.assertEqual(len(dep_statuses), 0) self.assertEqual(ti.state, State.SKIPPED) @@ -308,19 +345,22 @@ def test_none_failed_or_skipped_tr_failure(self): """ All success including skip trigger rule failure """ - ti = self._get_task_instance(TriggerRule.NONE_FAILED_OR_SKIPPED, - upstream_task_ids=["FakeTaskID", - "OtherFakeTaskID", - "FailedFakeTaskID"]) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=1, - skipped=1, - failed=1, - upstream_failed=0, - done=3, - flag_upstream_failed=False, - session="Fake Session")) + ti = self._get_task_instance( + TriggerRule.NONE_FAILED_OR_SKIPPED, + upstream_task_ids=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"], + ) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=1, + failed=1, + upstream_failed=0, + done=3, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) @@ -328,36 +368,42 @@ def test_all_failed_tr_success(self): """ All-failed trigger rule success """ - ti = self._get_task_instance(TriggerRule.ALL_FAILED, - upstream_task_ids=["FakeTaskID", - "OtherFakeTaskID"]) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=0, - skipped=0, - failed=2, - upstream_failed=0, - done=2, - flag_upstream_failed=False, - session="Fake Session")) + ti = self._get_task_instance( + TriggerRule.ALL_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] + ) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=0, + skipped=0, + failed=2, + upstream_failed=0, + done=2, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 0) def test_all_failed_tr_failure(self): """ All-failed trigger rule failure """ - ti = self._get_task_instance(TriggerRule.ALL_FAILED, - upstream_task_ids=["FakeTaskID", - "OtherFakeTaskID"]) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=2, - skipped=0, - failed=0, - upstream_failed=0, - done=2, - flag_upstream_failed=False, - session="Fake Session")) + ti = self._get_task_instance( + TriggerRule.ALL_FAILED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] + ) + 
dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=2, + skipped=0, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) @@ -365,36 +411,42 @@ def test_all_done_tr_success(self): """ All-done trigger rule success """ - ti = self._get_task_instance(TriggerRule.ALL_DONE, - upstream_task_ids=["FakeTaskID", - "OtherFakeTaskID"]) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=2, - skipped=0, - failed=0, - upstream_failed=0, - done=2, - flag_upstream_failed=False, - session="Fake Session")) + ti = self._get_task_instance( + TriggerRule.ALL_DONE, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] + ) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=2, + skipped=0, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 0) def test_all_done_tr_failure(self): """ All-done trigger rule failure """ - ti = self._get_task_instance(TriggerRule.ALL_DONE, - upstream_task_ids=["FakeTaskID", - "OtherFakeTaskID"]) - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=1, - skipped=0, - failed=0, - upstream_failed=0, - done=1, - flag_upstream_failed=False, - session="Fake Session")) + ti = self._get_task_instance( + TriggerRule.ALL_DONE, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"] + ) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=0, + failed=0, + upstream_failed=0, + done=1, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) @@ -403,78 +455,92 @@ def test_none_skipped_tr_success(self): None-skipped trigger rule success """ - ti = self._get_task_instance(TriggerRule.NONE_SKIPPED, - upstream_task_ids=["FakeTaskID", - "OtherFakeTaskID", - "FailedFakeTaskID"]) + ti = self._get_task_instance( + TriggerRule.NONE_SKIPPED, upstream_task_ids=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"] + ) with create_session() as session: - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=2, - skipped=0, - failed=1, - upstream_failed=0, - done=3, - flag_upstream_failed=False, - session=session)) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=2, + skipped=0, + failed=1, + upstream_failed=0, + done=3, + flag_upstream_failed=False, + session=session, + ) + ) self.assertEqual(len(dep_statuses), 0) # with `flag_upstream_failed` set to True - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=0, - skipped=0, - failed=3, - upstream_failed=0, - done=3, - flag_upstream_failed=True, - session=session)) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=0, + skipped=0, + failed=3, + upstream_failed=0, + done=3, + flag_upstream_failed=True, + session=session, + ) + ) self.assertEqual(len(dep_statuses), 0) def test_none_skipped_tr_failure(self): """ None-skipped trigger rule failure """ - ti = self._get_task_instance(TriggerRule.NONE_SKIPPED, - upstream_task_ids=["FakeTaskID", - "SkippedTaskID"]) + ti = self._get_task_instance( + TriggerRule.NONE_SKIPPED, upstream_task_ids=["FakeTaskID", "SkippedTaskID"] + ) with create_session() as session: - dep_statuses = 
tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=1, - skipped=1, - failed=0, - upstream_failed=0, - done=2, - flag_upstream_failed=False, - session=session)) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=1, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=False, + session=session, + ) + ) self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) # with `flag_upstream_failed` set to True - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=1, - skipped=1, - failed=0, - upstream_failed=0, - done=2, - flag_upstream_failed=True, - session=session)) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=1, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=True, + session=session, + ) + ) self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) # Fail until all upstream tasks have completed execution - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=0, - skipped=0, - failed=0, - upstream_failed=0, - done=0, - flag_upstream_failed=False, - session=session)) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=0, + skipped=0, + failed=0, + upstream_failed=0, + done=0, + flag_upstream_failed=False, + session=session, + ) + ) self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) @@ -484,15 +550,18 @@ def test_unknown_tr(self): """ ti = self._get_task_instance() ti.task.trigger_rule = "Unknown Trigger Rule" - dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - successes=1, - skipped=0, - failed=0, - upstream_failed=0, - done=1, - flag_upstream_failed=False, - session="Fake Session")) + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=0, + failed=0, + upstream_failed=0, + done=1, + flag_upstream_failed=False, + session="Fake Session", + ) + ) self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) @@ -506,10 +575,7 @@ def test_get_states_count_upstream_ti(self): get_states_count_upstream_ti = TriggerRuleDep._get_states_count_upstream_ti session = settings.Session() now = timezone.utcnow() - dag = DAG( - 'test_dagrun_with_pre_tis', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('test_dagrun_with_pre_tis', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) with dag: op1 = DummyOperator(task_id='A') @@ -524,10 +590,9 @@ def test_get_states_count_upstream_ti(self): clear_db_runs() dag.clear() - dr = dag.create_dagrun(run_id='test_dagrun_with_pre_tis', - state=State.RUNNING, - execution_date=now, - start_date=now) + dr = dag.create_dagrun( + run_id='test_dagrun_with_pre_tis', state=State.RUNNING, execution_date=now, start_date=now + ) ti_op1 = TaskInstance(task=dag.get_task(op1.task_id), execution_date=dr.execution_date) ti_op2 = TaskInstance(task=dag.get_task(op2.task_id), execution_date=dr.execution_date) @@ -545,13 +610,16 @@ def test_get_states_count_upstream_ti(self): # check handling with cases that tasks are triggered from backfill with no finished tasks finished_tasks = DepContext().ensure_finished_tasks(ti_op2.task.dag, ti_op2.execution_date, session) - self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op2), - (1, 0, 0, 0, 1)) + self.assertEqual( + get_states_count_upstream_ti(finished_tasks=finished_tasks, 
ti=ti_op2), (1, 0, 0, 0, 1) + ) finished_tasks = dr.get_task_instances(state=State.finished, session=session) - self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op4), - (1, 0, 1, 0, 2)) - self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op5), - (2, 0, 1, 0, 3)) + self.assertEqual( + get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op4), (1, 0, 1, 0, 2) + ) + self.assertEqual( + get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op5), (2, 0, 1, 0, 3) + ) dr.update_state() self.assertEqual(State.SUCCESS, dr.state) diff --git a/tests/ti_deps/deps/test_valid_state_dep.py b/tests/ti_deps/deps/test_valid_state_dep.py index 74c7e2aaf9d48..7e6ee7fcf2fac 100644 --- a/tests/ti_deps/deps/test_valid_state_dep.py +++ b/tests/ti_deps/deps/test_valid_state_dep.py @@ -26,7 +26,6 @@ class TestValidStateDep(unittest.TestCase): - def test_valid_state(self): """ Valid state should pass this dep diff --git a/tests/utils/log/test_file_processor_handler.py b/tests/utils/log/test_file_processor_handler.py index a3556e5cc7a83..20115059bc7a8 100644 --- a/tests/utils/log/test_file_processor_handler.py +++ b/tests/utils/log/test_file_processor_handler.py @@ -37,8 +37,7 @@ def setUp(self): def test_non_template(self): date = timezone.utcnow().strftime("%Y-%m-%d") - handler = FileProcessorHandler(base_log_folder=self.base_log_folder, - filename_template=self.filename) + handler = FileProcessorHandler(base_log_folder=self.base_log_folder, filename_template=self.filename) handler.dag_dir = self.dag_dir path = os.path.join(self.base_log_folder, "latest") @@ -50,8 +49,9 @@ def test_non_template(self): def test_template(self): date = timezone.utcnow().strftime("%Y-%m-%d") - handler = FileProcessorHandler(base_log_folder=self.base_log_folder, - filename_template=self.filename_template) + handler = FileProcessorHandler( + base_log_folder=self.base_log_folder, filename_template=self.filename_template + ) handler.dag_dir = self.dag_dir path = os.path.join(self.base_log_folder, "latest") @@ -62,8 +62,7 @@ def test_template(self): self.assertTrue(os.path.exists(os.path.join(path, "logfile.log"))) def test_symlink_latest_log_directory(self): - handler = FileProcessorHandler(base_log_folder=self.base_log_folder, - filename_template=self.filename) + handler = FileProcessorHandler(base_log_folder=self.base_log_folder, filename_template=self.filename) handler.dag_dir = self.dag_dir date1 = (timezone.utcnow() + timedelta(days=1)).strftime("%Y-%m-%d") @@ -92,8 +91,7 @@ def test_symlink_latest_log_directory(self): self.assertTrue(os.path.exists(os.path.join(link, "log2"))) def test_symlink_latest_log_directory_exists(self): - handler = FileProcessorHandler(base_log_folder=self.base_log_folder, - filename_template=self.filename) + handler = FileProcessorHandler(base_log_folder=self.base_log_folder, filename_template=self.filename) handler.dag_dir = self.dag_dir date1 = (timezone.utcnow() + timedelta(days=1)).strftime("%Y-%m-%d") diff --git a/tests/utils/log/test_json_formatter.py b/tests/utils/log/test_json_formatter.py index a14bcc3c80540..ba827fb25e87e 100644 --- a/tests/utils/log/test_json_formatter.py +++ b/tests/utils/log/test_json_formatter.py @@ -30,6 +30,7 @@ class TestJSONFormatter(unittest.TestCase): """ TestJSONFormatter class combine all tests for JSONFormatter """ + def test_json_formatter_is_not_none(self): """ JSONFormatter instance should return not none @@ -61,5 +62,6 @@ def test_format_with_extras(self): log_record = 
makeLogRecord({"label": "value"}) json_fmt = JSONFormatter(json_fields=["label"], extras={'pod_extra': 'useful_message'}) # compare as a dicts to not fail on sorting errors - self.assertDictEqual(json.loads(json_fmt.format(log_record)), - {"label": "value", "pod_extra": "useful_message"}) + self.assertDictEqual( + json.loads(json_fmt.format(log_record)), {"label": "value", "pod_extra": "useful_message"} + ) diff --git a/tests/utils/log/test_log_reader.py b/tests/utils/log/test_log_reader.py index 2ae98c4534cff..b4d1331e4effa 100644 --- a/tests/utils/log/test_log_reader.py +++ b/tests/utils/log/test_log_reader.py @@ -105,10 +105,12 @@ def test_test_read_log_chunks_should_read_one_try(self): self.assertEqual( [ - ('', - f"*** Reading local file: " - f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n" - f"try_number=1.\n") + ( + '', + f"*** Reading local file: " + f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n" + f"try_number=1.\n", + ) ], logs[0], ) @@ -120,18 +122,30 @@ def test_test_read_log_chunks_should_read_all_files(self): self.assertEqual( [ - [('', - "*** Reading local file: " - f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n" - "try_number=1.\n")], - [('', - f"*** Reading local file: " - f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/2.log\n" - f"try_number=2.\n")], - [('', - f"*** Reading local file: " - f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log\n" - f"try_number=3.\n")], + [ + ( + '', + "*** Reading local file: " + f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n" + "try_number=1.\n", + ) + ], + [ + ( + '', + f"*** Reading local file: " + f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/2.log\n" + f"try_number=2.\n", + ) + ], + [ + ( + '', + f"*** Reading local file: " + f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log\n" + f"try_number=3.\n", + ) + ], ], logs, ) diff --git a/tests/utils/test_cli_util.py b/tests/utils/test_cli_util.py index dfa71e10c7413..64cd57fd16775 100644 --- a/tests/utils/test_cli_util.py +++ b/tests/utils/test_cli_util.py @@ -33,19 +33,19 @@ class TestCliUtil(unittest.TestCase): - def test_metrics_build(self): func_name = 'test' exec_date = datetime.utcnow() - namespace = Namespace(dag_id='foo', task_id='bar', - subcommand='test', execution_date=exec_date) + namespace = Namespace(dag_id='foo', task_id='bar', subcommand='test', execution_date=exec_date) metrics = cli._build_metrics(func_name, namespace) - expected = {'user': os.environ.get('USER'), - 'sub_command': 'test', - 'dag_id': 'foo', - 'task_id': 'bar', - 'execution_date': exec_date} + expected = { + 'user': os.environ.get('USER'), + 'sub_command': 'test', + 'dag_id': 'foo', + 'task_id': 'bar', + 'execution_date': exec_date, + } for k, v in expected.items(): self.assertEqual(v, metrics.get(k)) @@ -93,24 +93,24 @@ def test_get_dags(self): [ ( "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password test", - "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password ********" + "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password ********", ), ( "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin -p test", - "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin -p ********" + "airflow users create -u test2 -l doe -f jon -e 
jdoe@apache.org -r admin -p ********", ), ( "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password=test", - "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password=********" + "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password=********", ), ( "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin -p=test", - "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin -p=********" + "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin -p=********", ), ( "airflow connections add dsfs --conn-login asd --conn-password test --conn-type google", "airflow connections add dsfs --conn-login asd --conn-password ******** --conn-type google", - ) + ), ] ) def test_cli_create_user_supplied_password_is_masked(self, given_command, expected_masked_command): diff --git a/tests/utils/test_compression.py b/tests/utils/test_compression.py index 1e62e8cf6ee18..b288ca60fde64 100644 --- a/tests/utils/test_compression.py +++ b/tests/utils/test_compression.py @@ -29,7 +29,6 @@ class TestCompression(unittest.TestCase): - def setUp(self): self.file_names = {} try: @@ -38,15 +37,12 @@ def setUp(self): line2 = b"2\tCompressionUtil\n" self.tmp_dir = tempfile.mkdtemp(prefix='test_utils_compression_') # create sample txt, gz and bz2 files - with tempfile.NamedTemporaryFile(mode='wb+', - dir=self.tmp_dir, - delete=False) as f_txt: + with tempfile.NamedTemporaryFile(mode='wb+', dir=self.tmp_dir, delete=False) as f_txt: self._set_fn(f_txt.name, '.txt') f_txt.writelines([header, line1, line2]) fn_gz = self._get_fn('.txt') + ".gz" - with gzip.GzipFile(filename=fn_gz, - mode="wb") as f_gz: + with gzip.GzipFile(filename=fn_gz, mode="wb") as f_gz: self._set_fn(fn_gz, '.gz') f_gz.writelines([header, line1, line2]) @@ -80,21 +76,22 @@ def _get_fn(self, ext): def test_uncompress_file(self): # Testing txt file type - self.assertRaisesRegex(NotImplementedError, - "^Received .txt format. Only gz and bz2.*", - compression.uncompress_file, - **{'input_file_name': None, - 'file_extension': '.txt', - 'dest_dir': None - }) + self.assertRaisesRegex( + NotImplementedError, + "^Received .txt format. 
Only gz and bz2.*", + compression.uncompress_file, + **{'input_file_name': None, 'file_extension': '.txt', 'dest_dir': None}, + ) # Testing gz file type fn_txt = self._get_fn('.txt') fn_gz = self._get_fn('.gz') txt_gz = compression.uncompress_file(fn_gz, '.gz', self.tmp_dir) - self.assertTrue(filecmp.cmp(txt_gz, fn_txt, shallow=False), - msg="Uncompressed file doest match original") + self.assertTrue( + filecmp.cmp(txt_gz, fn_txt, shallow=False), msg="Uncompressed file doest match original" + ) # Testing bz2 file type fn_bz2 = self._get_fn('.bz2') txt_bz2 = compression.uncompress_file(fn_bz2, '.bz2', self.tmp_dir) - self.assertTrue(filecmp.cmp(txt_bz2, fn_txt, shallow=False), - msg="Uncompressed file doest match original") + self.assertTrue( + filecmp.cmp(txt_bz2, fn_txt, shallow=False), msg="Uncompressed file doest match original" + ) diff --git a/tests/utils/test_dag_cycle.py b/tests/utils/test_dag_cycle.py index 320faafb45d4d..fec7440e82e12 100644 --- a/tests/utils/test_dag_cycle.py +++ b/tests/utils/test_dag_cycle.py @@ -27,19 +27,13 @@ class TestCycleTester(unittest.TestCase): def test_cycle_empty(self): # test empty - dag = DAG( - 'dag', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) self.assertFalse(_test_cycle(dag)) def test_cycle_single_task(self): # test single task - dag = DAG( - 'dag', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) with dag: DummyOperator(task_id='A') @@ -47,10 +41,7 @@ def test_cycle_single_task(self): self.assertFalse(_test_cycle(dag)) def test_semi_complex(self): - dag = DAG( - 'dag', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) # A -> B -> C # B -> D @@ -67,10 +58,7 @@ def test_semi_complex(self): def test_cycle_no_cycle(self): # test no cycle - dag = DAG( - 'dag', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) # A -> B -> C # B -> D @@ -91,10 +79,7 @@ def test_cycle_no_cycle(self): def test_cycle_loop(self): # test self loop - dag = DAG( - 'dag', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) # A -> A with dag: @@ -106,10 +91,7 @@ def test_cycle_loop(self): def test_cycle_downstream_loop(self): # test downstream self loop - dag = DAG( - 'dag', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) # A -> B -> C -> D -> E -> E with dag: @@ -129,10 +111,7 @@ def test_cycle_downstream_loop(self): def test_cycle_large_loop(self): # large loop - dag = DAG( - 'dag', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) # A -> B -> C -> D -> E -> A with dag: @@ -150,10 +129,7 @@ def test_cycle_large_loop(self): def test_cycle_arbitrary_loop(self): # test arbitrary loop - dag = DAG( - 'dag', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) # E-> A -> B -> F -> A # -> C -> F diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py index 84feff7f654f9..8a9dc342555a9 100644 --- a/tests/utils/test_dag_processing.py 
+++ b/tests/utils/test_dag_processing.py @@ -36,7 +36,11 @@ from airflow.utils import timezone from airflow.utils.callback_requests import TaskCallbackRequest from airflow.utils.dag_processing import ( - DagFileProcessorAgent, DagFileProcessorManager, DagFileStat, DagParsingSignal, DagParsingStat, + DagFileProcessorAgent, + DagFileProcessorManager, + DagFileStat, + DagParsingSignal, + DagParsingStat, ) from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped from airflow.utils.session import create_session @@ -45,8 +49,7 @@ from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags -TEST_DAG_FOLDER = os.path.join( - os.path.dirname(os.path.realpath(__file__)), os.pardir, 'dags') +TEST_DAG_FOLDER = os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, 'dags') DEFAULT_DATE = timezone.datetime(2016, 1, 1) @@ -132,7 +135,8 @@ def test_max_runs_when_no_files(self): signal_conn=child_pipe, dag_ids=[], pickle_dags=False, - async_mode=async_mode) + async_mode=async_mode, + ) self.run_processor_manager_one_loop(manager, parent_pipe) child_pipe.close() @@ -147,11 +151,11 @@ def test_set_file_paths_when_processor_file_path_not_in_new_file_paths(self): signal_conn=MagicMock(), dag_ids=[], pickle_dags=False, - async_mode=True) + async_mode=True, + ) mock_processor = MagicMock() - mock_processor.stop.side_effect = AttributeError( - 'DagFileProcessor object has no attribute stop') + mock_processor.stop.side_effect = AttributeError('DagFileProcessor object has no attribute stop') mock_processor.terminate.side_effect = None manager._processors['missing_file.txt'] = mock_processor @@ -169,11 +173,11 @@ def test_set_file_paths_when_processor_file_path_is_in_new_file_paths(self): signal_conn=MagicMock(), dag_ids=[], pickle_dags=False, - async_mode=True) + async_mode=True, + ) mock_processor = MagicMock() - mock_processor.stop.side_effect = AttributeError( - 'DagFileProcessor object has no attribute stop') + mock_processor.stop.side_effect = AttributeError('DagFileProcessor object has no attribute stop') mock_processor.terminate.side_effect = None manager._processors['abc.txt'] = mock_processor @@ -190,7 +194,8 @@ def test_find_zombies(self): signal_conn=MagicMock(), dag_ids=[], pickle_dags=False, - async_mode=True) + async_mode=True, + ) dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False) with create_session() as session: @@ -211,7 +216,8 @@ def test_find_zombies(self): session.commit() manager._last_zombie_query_time = timezone.utcnow() - timedelta( - seconds=manager._zombie_threshold_secs + 1) + seconds=manager._zombie_threshold_secs + 1 + ) manager._find_zombies() # pylint: disable=no-value-for-parameter requests = manager._callback_to_execute[dag.full_filepath] self.assertEqual(1, len(requests)) @@ -232,8 +238,7 @@ def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_p file processors until the next zombie detection logic is invoked. 
""" test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py') - with conf_vars({('scheduler', 'max_threads'): '1', - ('core', 'load_examples'): 'False'}): + with conf_vars({('scheduler', 'max_threads'): '1', ('core', 'load_examples'): 'False'}): dagbag = DagBag(test_dag_path, read_dags_from_db=False) with create_session() as session: session.query(LJ).delete() @@ -257,7 +262,7 @@ def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_p TaskCallbackRequest( full_filepath=dag.full_filepath, simple_task_instance=SimpleTaskInstance(ti), - msg="Message" + msg="Message", ) ] @@ -282,7 +287,8 @@ def fake_processor_factory(*args, **kwargs): signal_conn=child_pipe, dag_ids=[], pickle_dags=False, - async_mode=async_mode) + async_mode=async_mode, + ) self.run_processor_manager_one_loop(manager, parent_pipe) @@ -296,10 +302,9 @@ def fake_processor_factory(*args, **kwargs): assert fake_processors[-1]._file_path == test_dag_path callback_requests = fake_processors[-1]._callback_requests - assert ( - {zombie.simple_task_instance.key for zombie in expected_failure_callback_requests} == - {result.simple_task_instance.key for result in callback_requests} - ) + assert {zombie.simple_task_instance.key for zombie in expected_failure_callback_requests} == { + result.simple_task_instance.key for result in callback_requests + } child_pipe.close() parent_pipe.close() @@ -316,7 +321,8 @@ def test_kill_timed_out_processors_kill(self, mock_kill, mock_pid): signal_conn=MagicMock(), dag_ids=[], pickle_dags=False, - async_mode=True) + async_mode=True, + ) processor = DagFileProcessorProcess('abc.txt', False, [], []) processor._start_time = timezone.make_aware(datetime.min) @@ -336,7 +342,8 @@ def test_kill_timed_out_processors_no_kill(self, mock_dag_file_processor, mock_p signal_conn=MagicMock(), dag_ids=[], pickle_dags=False, - async_mode=True) + async_mode=True, + ) processor = DagFileProcessorProcess('abc.txt', False, [], []) processor._start_time = timezone.make_aware(datetime.max) @@ -373,7 +380,8 @@ def test_dag_with_system_exit(self): processor_timeout=timedelta(seconds=5), signal_conn=child_pipe, pickle_dags=False, - async_mode=True) + async_mode=True, + ) manager._run_parsing_loop() @@ -407,10 +415,7 @@ def tearDown(self): @staticmethod def _processor_factory(file_path, zombies, dag_ids, pickle_dags): - return DagFileProcessorProcess(file_path, - pickle_dags, - dag_ids, - zombies) + return DagFileProcessorProcess(file_path, pickle_dags, dag_ids, zombies) def test_reload_module(self): """ @@ -431,13 +436,9 @@ class path, thus when reloading logging module the airflow.processor_manager pass # Starting dag processing with 0 max_runs to avoid redundant operations. 
- processor_agent = DagFileProcessorAgent(test_dag_path, - 0, - type(self)._processor_factory, - timedelta.max, - [], - False, - async_mode) + processor_agent = DagFileProcessorAgent( + test_dag_path, 0, type(self)._processor_factory, timedelta.max, [], False, async_mode + ) processor_agent.start() if not async_mode: processor_agent.run_single_parsing_loop() @@ -455,13 +456,9 @@ def test_parse_once(self): test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py') async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn') - processor_agent = DagFileProcessorAgent(test_dag_path, - 1, - type(self)._processor_factory, - timedelta.max, - [], - False, - async_mode) + processor_agent = DagFileProcessorAgent( + test_dag_path, 1, type(self)._processor_factory, timedelta.max, [], False, async_mode + ) processor_agent.start() if not async_mode: processor_agent.run_single_parsing_loop() @@ -491,13 +488,9 @@ def test_launch_process(self): pass # Starting dag processing with 0 max_runs to avoid redundant operations. - processor_agent = DagFileProcessorAgent(test_dag_path, - 0, - type(self)._processor_factory, - timedelta.max, - [], - False, - async_mode) + processor_agent = DagFileProcessorAgent( + test_dag_path, 0, type(self)._processor_factory, timedelta.max, [], False, async_mode + ) processor_agent.start() if not async_mode: processor_agent.run_single_parsing_loop() diff --git a/tests/utils/test_dates.py b/tests/utils/test_dates.py index c2f0f7b6861f4..0cab6ae58d574 100644 --- a/tests/utils/test_dates.py +++ b/tests/utils/test_dates.py @@ -25,7 +25,6 @@ class TestDates(unittest.TestCase): - def test_days_ago(self): today = pendulum.today() today_midnight = pendulum.instance(datetime.fromordinal(today.date().toordinal())) @@ -44,39 +43,31 @@ def test_parse_execution_date(self): bad_execution_date_str = '2017-11-06TXX:00:00Z' self.assertEqual( - timezone.datetime(2017, 11, 2, 0, 0, 0), - dates.parse_execution_date(execution_date_str_wo_ms)) + timezone.datetime(2017, 11, 2, 0, 0, 0), dates.parse_execution_date(execution_date_str_wo_ms) + ) self.assertEqual( timezone.datetime(2017, 11, 5, 16, 18, 30, 989729), - dates.parse_execution_date(execution_date_str_w_ms)) + dates.parse_execution_date(execution_date_str_w_ms), + ) self.assertRaises(ValueError, dates.parse_execution_date, bad_execution_date_str) class TestUtilsDatesDateRange(unittest.TestCase): - def test_no_delta(self): - self.assertEqual(dates.date_range(datetime(2016, 1, 1), datetime(2016, 1, 3)), - []) + self.assertEqual(dates.date_range(datetime(2016, 1, 1), datetime(2016, 1, 3)), []) def test_end_date_before_start_date(self): with self.assertRaisesRegex(Exception, "Wait. start_date needs to be before end_date"): - dates.date_range(datetime(2016, 2, 1), - datetime(2016, 1, 1), - delta=timedelta(seconds=1)) + dates.date_range(datetime(2016, 2, 1), datetime(2016, 1, 1), delta=timedelta(seconds=1)) def test_both_end_date_and_num_given(self): with self.assertRaisesRegex(Exception, "Wait. Either specify end_date OR num"): - dates.date_range(datetime(2016, 1, 1), - datetime(2016, 1, 3), - num=2, - delta=timedelta(seconds=1)) + dates.date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), num=2, delta=timedelta(seconds=1)) def test_invalid_delta(self): exception_msg = "Wait. 
delta must be either datetime.timedelta or cron expression as str" with self.assertRaisesRegex(Exception, exception_msg): - dates.date_range(datetime(2016, 1, 1), - datetime(2016, 1, 3), - delta=1) + dates.date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), delta=1) def test_positive_num_given(self): for num in range(1, 10): diff --git a/tests/utils/test_db.py b/tests/utils/test_db.py index e724543f6b951..70a1a9c12794e 100644 --- a/tests/utils/test_db.py +++ b/tests/utils/test_db.py @@ -32,7 +32,6 @@ class TestDb(unittest.TestCase): - def test_database_schema_and_sqlalchemy_model_are_in_sync(self): all_meta_data = MetaData() for (table_name, table) in airflow_base.metadata.tables.items(): @@ -45,62 +44,35 @@ def test_database_schema_and_sqlalchemy_model_are_in_sync(self): # known diffs to ignore ignores = [ # ignore tables created by celery - lambda t: (t[0] == 'remove_table' and - t[1].name == 'celery_taskmeta'), - lambda t: (t[0] == 'remove_table' and - t[1].name == 'celery_tasksetmeta'), - + lambda t: (t[0] == 'remove_table' and t[1].name == 'celery_taskmeta'), + lambda t: (t[0] == 'remove_table' and t[1].name == 'celery_tasksetmeta'), # ignore indices created by celery - lambda t: (t[0] == 'remove_index' and - t[1].name == 'task_id'), - lambda t: (t[0] == 'remove_index' and - t[1].name == 'taskset_id'), - + lambda t: (t[0] == 'remove_index' and t[1].name == 'task_id'), + lambda t: (t[0] == 'remove_index' and t[1].name == 'taskset_id'), # Ignore all the fab tables - lambda t: (t[0] == 'remove_table' and - t[1].name == 'ab_permission'), - lambda t: (t[0] == 'remove_table' and - t[1].name == 'ab_register_user'), - lambda t: (t[0] == 'remove_table' and - t[1].name == 'ab_role'), - lambda t: (t[0] == 'remove_table' and - t[1].name == 'ab_permission_view'), - lambda t: (t[0] == 'remove_table' and - t[1].name == 'ab_permission_view_role'), - lambda t: (t[0] == 'remove_table' and - t[1].name == 'ab_user_role'), - lambda t: (t[0] == 'remove_table' and - t[1].name == 'ab_user'), - lambda t: (t[0] == 'remove_table' and - t[1].name == 'ab_view_menu'), - + lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_permission'), + lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_register_user'), + lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_role'), + lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_permission_view'), + lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_permission_view_role'), + lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_user_role'), + lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_user'), + lambda t: (t[0] == 'remove_table' and t[1].name == 'ab_view_menu'), # Ignore all the fab indices - lambda t: (t[0] == 'remove_index' and - t[1].name == 'permission_id'), - lambda t: (t[0] == 'remove_index' and - t[1].name == 'name'), - lambda t: (t[0] == 'remove_index' and - t[1].name == 'user_id'), - lambda t: (t[0] == 'remove_index' and - t[1].name == 'username'), - lambda t: (t[0] == 'remove_index' and - t[1].name == 'field_string'), - lambda t: (t[0] == 'remove_index' and - t[1].name == 'email'), - lambda t: (t[0] == 'remove_index' and - t[1].name == 'permission_view_id'), - + lambda t: (t[0] == 'remove_index' and t[1].name == 'permission_id'), + lambda t: (t[0] == 'remove_index' and t[1].name == 'name'), + lambda t: (t[0] == 'remove_index' and t[1].name == 'user_id'), + lambda t: (t[0] == 'remove_index' and t[1].name == 'username'), + lambda t: (t[0] == 'remove_index' and t[1].name == 'field_string'), + lambda t: (t[0] == 'remove_index' and 
t[1].name == 'email'), + lambda t: (t[0] == 'remove_index' and t[1].name == 'permission_view_id'), # from test_security unit test - lambda t: (t[0] == 'remove_table' and - t[1].name == 'some_model'), + lambda t: (t[0] == 'remove_table' and t[1].name == 'some_model'), ] for ignore in ignores: diff = [d for d in diff if not ignore(d)] - self.assertFalse( - diff, - 'Database schema and SQLAlchemy model are not in sync: ' + str(diff) - ) + self.assertFalse(diff, 'Database schema and SQLAlchemy model are not in sync: ' + str(diff)) def test_only_single_head_revision_in_migrations(self): config = Config() diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index 94920b9dee4cc..274eb77da0711 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -37,7 +37,6 @@ def __init__(self, test_sub_param, **kwargs): class TestApplyDefault(unittest.TestCase): - def test_apply(self): dummy = DummyClass(test_param=True) self.assertTrue(dummy.test_param) @@ -60,8 +59,7 @@ def test_default_args(self): self.assertTrue(dummy_class.test_param) self.assertTrue(dummy_subclass.test_sub_param) - with self.assertRaisesRegex(AirflowException, - 'Argument.*test_sub_param.*required'): + with self.assertRaisesRegex(AirflowException, 'Argument.*test_sub_param.*required'): DummySubClass(default_args=default_args) # pylint: disable=no-value-for-parameter def test_incorrect_default_args(self): diff --git a/tests/utils/test_docs.py b/tests/utils/test_docs.py index 8a86bcada557c..81d6a1d6e8e95 100644 --- a/tests/utils/test_docs.py +++ b/tests/utils/test_docs.py @@ -24,12 +24,14 @@ class TestGetDocsUrl(unittest.TestCase): - @parameterized.expand([ - ('2.0.0.dev0', None, 'https://airflow.readthedocs.io/en/latest/'), - ('2.0.0.dev0', 'migration.html', 'https://airflow.readthedocs.io/en/latest/migration.html'), - ('1.10.0', None, 'https://airflow.apache.org/docs/1.10.0/'), - ('1.10.0', 'migration.html', 'https://airflow.apache.org/docs/1.10.0/migration.html'), - ]) + @parameterized.expand( + [ + ('2.0.0.dev0', None, 'https://airflow.readthedocs.io/en/latest/'), + ('2.0.0.dev0', 'migration.html', 'https://airflow.readthedocs.io/en/latest/migration.html'), + ('1.10.0', None, 'https://airflow.apache.org/docs/1.10.0/'), + ('1.10.0', 'migration.html', 'https://airflow.apache.org/docs/1.10.0/migration.html'), + ] + ) def test_should_return_link(self, version, page, expected_urk): with mock.patch('airflow.version.version', version): self.assertEqual(expected_urk, get_docs_url(page)) diff --git a/tests/utils/test_email.py b/tests/utils/test_email.py index 2a0810f21d599..8081fc85a74cf 100644 --- a/tests/utils/test_email.py +++ b/tests/utils/test_email.py @@ -34,48 +34,40 @@ class TestEmail(unittest.TestCase): - def test_get_email_address_single_email(self): emails_string = 'test1@example.com' - self.assertEqual( - get_email_address_list(emails_string), [emails_string]) + self.assertEqual(get_email_address_list(emails_string), [emails_string]) def test_get_email_address_comma_sep_string(self): emails_string = 'test1@example.com, test2@example.com' - self.assertEqual( - get_email_address_list(emails_string), EMAILS) + self.assertEqual(get_email_address_list(emails_string), EMAILS) def test_get_email_address_colon_sep_string(self): emails_string = 'test1@example.com; test2@example.com' - self.assertEqual( - get_email_address_list(emails_string), EMAILS) + self.assertEqual(get_email_address_list(emails_string), EMAILS) def test_get_email_address_list(self): emails_list = ['test1@example.com', 
'test2@example.com'] - self.assertEqual( - get_email_address_list(emails_list), EMAILS) + self.assertEqual(get_email_address_list(emails_list), EMAILS) def test_get_email_address_tuple(self): emails_tuple = ('test1@example.com', 'test2@example.com') - self.assertEqual( - get_email_address_list(emails_tuple), EMAILS) + self.assertEqual(get_email_address_list(emails_tuple), EMAILS) def test_get_email_address_invalid_type(self): emails_string = 1 - self.assertRaises( - TypeError, get_email_address_list, emails_string) + self.assertRaises(TypeError, get_email_address_list, emails_string) def test_get_email_address_invalid_type_in_iterable(self): emails_list = ['test1@example.com', 2] - self.assertRaises( - TypeError, get_email_address_list, emails_list) + self.assertRaises(TypeError, get_email_address_list, emails_list) def setUp(self): conf.remove_option('email', 'EMAIL_BACKEND') @@ -91,8 +83,16 @@ def test_custom_backend(self, mock_send_email): with conf_vars({('email', 'email_backend'): 'tests.utils.test_email.send_email_test'}): utils.email.send_email('to', 'subject', 'content') send_email_test.assert_called_once_with( - 'to', 'subject', 'content', files=None, dryrun=False, - cc=None, bcc=None, mime_charset='utf-8', mime_subtype='mixed') + 'to', + 'subject', + 'content', + files=None, + dryrun=False, + cc=None, + bcc=None, + mime_charset='utf-8', + mime_subtype='mixed', + ) self.assertFalse(mock_send_email.called) def test_build_mime_message(self): @@ -162,8 +162,10 @@ def test_send_bcc_smtp(self, mock_send_mime): self.assertEqual('subject', msg['Subject']) self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), msg['From']) self.assertEqual(2, len(msg.get_payload())) - self.assertEqual('attachment; filename="' + os.path.basename(attachment.name) + '"', - msg.get_payload()[-1].get('Content-Disposition')) + self.assertEqual( + 'attachment; filename="' + os.path.basename(attachment.name) + '"', + msg.get_payload()[-1].get('Content-Disposition'), + ) mimeapp = MIMEApplication('attachment') self.assertEqual(mimeapp.get_payload(), msg.get_payload()[-1].get_payload()) @@ -204,10 +206,12 @@ def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl): def test_send_mime_noauth(self, mock_smtp, mock_smtp_ssl): mock_smtp.return_value = mock.Mock() mock_smtp_ssl.return_value = mock.Mock() - with conf_vars({ - ('smtp', 'smtp_user'): None, - ('smtp', 'smtp_password'): None, - }): + with conf_vars( + { + ('smtp', 'smtp_user'): None, + ('smtp', 'smtp_password'): None, + } + ): utils.email.send_mime_email('from', 'to', MIMEMultipart(), dryrun=False) self.assertFalse(mock_smtp_ssl.called) mock_smtp.assert_called_once_with( diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index b513e3e401a0e..85c53c5a0cc48 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -27,7 +27,6 @@ class TestHelpers(unittest.TestCase): - def test_render_log_filename(self): try_number = 1 dag_id = 'test_render_log_filename_dag' @@ -41,10 +40,9 @@ def test_render_log_filename(self): filename_template = "{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log" ts = ti.get_template_context()['ts'] - expected_filename = "{dag_id}/{task_id}/{ts}/{try_number}.log".format(dag_id=dag_id, - task_id=task_id, - ts=ts, - try_number=try_number) + expected_filename = "{dag_id}/{task_id}/{ts}/{try_number}.log".format( + dag_id=dag_id, task_id=task_id, ts=ts, try_number=try_number + ) rendered_filename = helpers.render_log_filename(ti, try_number, filename_template) @@ -62,22 +60,15 @@ def 
test_chunks(self): self.assertEqual(list(helpers.chunks([1, 2, 3], 2)), [[1, 2], [3]]) def test_reduce_in_chunks(self): - self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + [y], - [1, 2, 3, 4, 5], - []), - [[1, 2, 3, 4, 5]]) - - self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + [y], - [1, 2, 3, 4, 5], - [], - 2), - [[1, 2], [3, 4], [5]]) - - self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + y[0] * y[1], - [1, 2, 3, 4], - 0, - 2), - 14) + self.assertEqual( + helpers.reduce_in_chunks(lambda x, y: x + [y], [1, 2, 3, 4, 5], []), [[1, 2, 3, 4, 5]] + ) + + self.assertEqual( + helpers.reduce_in_chunks(lambda x, y: x + [y], [1, 2, 3, 4, 5], [], 2), [[1, 2], [3, 4], [5]] + ) + + self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + y[0] * y[1], [1, 2, 3, 4], 0, 2), 14) def test_is_container(self): self.assertFalse(helpers.is_container("a string is not a container")) @@ -89,14 +80,10 @@ def test_is_container(self): self.assertFalse(helpers.is_container(10)) def test_as_tuple(self): - self.assertEqual( - helpers.as_tuple("a string is not a container"), - ("a string is not a container",) - ) + self.assertEqual(helpers.as_tuple("a string is not a container"), ("a string is not a container",)) self.assertEqual( - helpers.as_tuple(["a", "list", "is", "a", "container"]), - ("a", "list", "is", "a", "container") + helpers.as_tuple(["a", "list", "is", "a", "container"]), ("a", "list", "is", "a", "container") ) def test_as_tuple_iter(self): @@ -111,8 +98,7 @@ def test_as_tuple_no_iter(self): def test_convert_camel_to_snake(self): self.assertEqual(helpers.convert_camel_to_snake('LocalTaskJob'), 'local_task_job') - self.assertEqual(helpers.convert_camel_to_snake('somethingVeryRandom'), - 'something_very_random') + self.assertEqual(helpers.convert_camel_to_snake('somethingVeryRandom'), 'something_very_random') def test_merge_dicts(self): """ diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py index 69363b84feb8d..a127ff98e6cab 100644 --- a/tests/utils/test_json.py +++ b/tests/utils/test_json.py @@ -26,37 +26,21 @@ class TestAirflowJsonEncoder(unittest.TestCase): - def test_encode_datetime(self): obj = datetime.strptime('2017-05-21 00:00:00', '%Y-%m-%d %H:%M:%S') - self.assertEqual( - json.dumps(obj, cls=utils_json.AirflowJsonEncoder), - '"2017-05-21T00:00:00Z"' - ) + self.assertEqual(json.dumps(obj, cls=utils_json.AirflowJsonEncoder), '"2017-05-21T00:00:00Z"') def test_encode_date(self): - self.assertEqual( - json.dumps(date(2017, 5, 21), cls=utils_json.AirflowJsonEncoder), - '"2017-05-21"' - ) + self.assertEqual(json.dumps(date(2017, 5, 21), cls=utils_json.AirflowJsonEncoder), '"2017-05-21"') def test_encode_numpy_int(self): - self.assertEqual( - json.dumps(np.int32(5), cls=utils_json.AirflowJsonEncoder), - '5' - ) + self.assertEqual(json.dumps(np.int32(5), cls=utils_json.AirflowJsonEncoder), '5') def test_encode_numpy_bool(self): - self.assertEqual( - json.dumps(np.bool_(True), cls=utils_json.AirflowJsonEncoder), - 'true' - ) + self.assertEqual(json.dumps(np.bool_(True), cls=utils_json.AirflowJsonEncoder), 'true') def test_encode_numpy_float(self): - self.assertEqual( - json.dumps(np.float16(3.76953125), cls=utils_json.AirflowJsonEncoder), - '3.76953125' - ) + self.assertEqual(json.dumps(np.float16(3.76953125), cls=utils_json.AirflowJsonEncoder), '3.76953125') def test_encode_k8s_v1pod(self): from kubernetes.client import models as k8s @@ -73,17 +57,21 @@ def test_encode_k8s_v1pod(self): image="bar", ) ] - ) + ), ) self.assertEqual( json.loads(json.dumps(pod, 
cls=utils_json.AirflowJsonEncoder)), - {"metadata": {"name": "foo", "namespace": "bar"}, - "spec": {"containers": [{"image": "bar", "name": "foo"}]}} + { + "metadata": {"name": "foo", "namespace": "bar"}, + "spec": {"containers": [{"image": "bar", "name": "foo"}]}, + }, ) def test_encode_raises(self): - self.assertRaisesRegex(TypeError, - "^.*is not JSON serializable$", - json.dumps, - Exception, - cls=utils_json.AirflowJsonEncoder) + self.assertRaisesRegex( + TypeError, + "^.*is not JSON serializable$", + json.dumps, + Exception, + cls=utils_json.AirflowJsonEncoder, + ) diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index df23ed1af4b36..32a8d251226cb 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -38,7 +38,6 @@ class TestFileTaskLogHandler(unittest.TestCase): - def clean_up(self): with create_session() as session: session.query(DagRun).delete() @@ -66,6 +65,7 @@ def test_default_task_logging_setup(self): def test_file_task_handler(self): def task_callable(ti, **kwargs): ti.log.info("test") + dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE) dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) task = PythonOperator( @@ -78,8 +78,9 @@ def task_callable(ti, **kwargs): logger = ti.log ti.log.disabled = False - file_handler = next((handler for handler in logger.handlers - if handler.name == FILE_TASK_HANDLER), None) + file_handler = next( + (handler for handler in logger.handlers if handler.name == FILE_TASK_HANDLER), None + ) self.assertIsNotNone(file_handler) set_context(logger, ti) @@ -106,11 +107,7 @@ def task_callable(ti, **kwargs): # We should expect our log line from the callable above to appear in # the logs we read back - self.assertRegex( - logs[0][0][-1], - target_re, - "Logs were " + str(logs) - ) + self.assertRegex(logs[0][0][-1], target_re, "Logs were " + str(logs)) # Remove the generated tmp log file. 
os.remove(log_filename) @@ -118,6 +115,7 @@ def task_callable(ti, **kwargs): def test_file_task_handler_running(self): def task_callable(ti, **kwargs): ti.log.info("test") + dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE) task = PythonOperator( task_id='task_for_testing_file_log_handler', @@ -131,8 +129,9 @@ def task_callable(ti, **kwargs): logger = ti.log ti.log.disabled = False - file_handler = next((handler for handler in logger.handlers - if handler.name == FILE_TASK_HANDLER), None) + file_handler = next( + (handler for handler in logger.handlers if handler.name == FILE_TASK_HANDLER), None + ) self.assertIsNotNone(file_handler) set_context(logger, ti) @@ -159,25 +158,26 @@ def task_callable(ti, **kwargs): class TestFilenameRendering(unittest.TestCase): - def setUp(self): dag = DAG('dag_for_testing_filename_rendering', start_date=DEFAULT_DATE) task = DummyOperator(task_id='task_for_testing_filename_rendering', dag=dag) self.ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) def test_python_formatting(self): - expected_filename = \ - 'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log' \ + expected_filename = ( + 'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log' % DEFAULT_DATE.isoformat() + ) fth = FileTaskHandler('', '{dag_id}/{task_id}/{execution_date}/{try_number}.log') rendered_filename = fth._render_filename(self.ti, 42) self.assertEqual(expected_filename, rendered_filename) def test_jinja_rendering(self): - expected_filename = \ - 'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log' \ + expected_filename = ( + 'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log' % DEFAULT_DATE.isoformat() + ) fth = FileTaskHandler('', '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log') rendered_filename = fth._render_filename(self.ti, 42) diff --git a/tests/utils/test_logging_mixin.py b/tests/utils/test_logging_mixin.py index 7a6bda50f83d8..583eefd4bccbf 100644 --- a/tests/utils/test_logging_mixin.py +++ b/tests/utils/test_logging_mixin.py @@ -25,18 +25,20 @@ class TestLoggingMixin(unittest.TestCase): def setUp(self): - warnings.filterwarnings( - action='always' - ) + warnings.filterwarnings(action='always') def test_set_context(self): handler1 = mock.MagicMock() handler2 = mock.MagicMock() parent = mock.MagicMock() parent.propagate = False - parent.handlers = [handler1, ] + parent.handlers = [ + handler1, + ] log = mock.MagicMock() - log.handlers = [handler2, ] + log.handlers = [ + handler2, + ] log.parent = parent log.propagate = True diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py index 2217785b4f4d4..2f06371018a7f 100644 --- a/tests/utils/test_net.py +++ b/tests/utils/test_net.py @@ -29,7 +29,6 @@ def get_hostname(): class TestGetHostname(unittest.TestCase): - @mock.patch('socket.getfqdn', return_value='first') @conf_vars({('core', 'hostname_callable'): None}) def test_get_hostname_unset(self, mock_getfqdn): @@ -51,6 +50,6 @@ def test_get_hostname_set_missing(self): re.escape( 'The object could not be loaded. Please check "hostname_callable" key in "core" section. 
' 'Current value: "tests.utils.test_net.missing_func"' - ) + ), ): net.get_hostname() diff --git a/tests/utils/test_operator_helpers.py b/tests/utils/test_operator_helpers.py index d6a861287567d..c448901f3c8b8 100644 --- a/tests/utils/test_operator_helpers.py +++ b/tests/utils/test_operator_helpers.py @@ -24,7 +24,6 @@ class TestOperatorHelpers(unittest.TestCase): - def setUp(self): super().setUp() self.dag_id = 'dag_id' @@ -37,21 +36,15 @@ def setUp(self): 'dag_run': mock.MagicMock( name='dag_run', run_id=self.dag_run_id, - execution_date=datetime.strptime(self.execution_date, - '%Y-%m-%dT%H:%M:%S'), + execution_date=datetime.strptime(self.execution_date, '%Y-%m-%dT%H:%M:%S'), ), 'task_instance': mock.MagicMock( name='task_instance', task_id=self.task_id, dag_id=self.dag_id, - execution_date=datetime.strptime(self.execution_date, - '%Y-%m-%dT%H:%M:%S'), + execution_date=datetime.strptime(self.execution_date, '%Y-%m-%dT%H:%M:%S'), ), - 'task': mock.MagicMock( - name='task', - owner=self.owner, - email=self.email - ) + 'task': mock.MagicMock(name='task', owner=self.owner, email=self.email), } def test_context_to_airflow_vars_empty_context(self): @@ -66,19 +59,18 @@ def test_context_to_airflow_vars_all_context(self): 'airflow.ctx.task_id': self.task_id, 'airflow.ctx.dag_run_id': self.dag_run_id, 'airflow.ctx.dag_owner': 'owner1,owner2', - 'airflow.ctx.dag_email': 'email1@test.com' - } + 'airflow.ctx.dag_email': 'email1@test.com', + }, ) self.assertDictEqual( - operator_helpers.context_to_airflow_vars(self.context, - in_env_var_format=True), + operator_helpers.context_to_airflow_vars(self.context, in_env_var_format=True), { 'AIRFLOW_CTX_DAG_ID': self.dag_id, 'AIRFLOW_CTX_EXECUTION_DATE': self.execution_date, 'AIRFLOW_CTX_TASK_ID': self.task_id, 'AIRFLOW_CTX_DAG_RUN_ID': self.dag_run_id, 'AIRFLOW_CTX_DAG_OWNER': 'owner1,owner2', - 'AIRFLOW_CTX_DAG_EMAIL': 'email1@test.com' - } + 'AIRFLOW_CTX_DAG_EMAIL': 'email1@test.com', + }, ) diff --git a/tests/utils/test_process_utils.py b/tests/utils/test_process_utils.py index 1620bfe747bf9..ee1097e93f22c 100644 --- a/tests/utils/test_process_utils.py +++ b/tests/utils/test_process_utils.py @@ -39,7 +39,6 @@ class TestReapProcessGroup(unittest.TestCase): - @staticmethod def _ignores_sigterm(child_pid, child_setup_done): def signal_handler(unused_signum, unused_frame): @@ -55,11 +54,13 @@ def signal_handler(unused_signum, unused_frame): def _parent_of_ignores_sigterm(parent_pid, child_pid, setup_done): def signal_handler(unused_signum, unused_frame): pass + os.setsid() signal.signal(signal.SIGTERM, signal_handler) child_setup_done = multiprocessing.Semaphore(0) - child = multiprocessing.Process(target=TestReapProcessGroup._ignores_sigterm, - args=[child_pid, child_setup_done]) + child = multiprocessing.Process( + target=TestReapProcessGroup._ignores_sigterm, args=[child_pid, child_setup_done] + ) child.start() child_setup_done.acquire(timeout=5.0) parent_pid.value = os.getpid() @@ -96,19 +97,13 @@ def test_reap_process_group(self): class TestExecuteInSubProcess(unittest.TestCase): - def test_should_print_all_messages1(self): with self.assertLogs(log) as logs: execute_in_subprocess(["bash", "-c", "echo CAT; echo KITTY;"]) msgs = [record.getMessage() for record in logs.records] - self.assertEqual([ - "Executing cmd: bash -c 'echo CAT; echo KITTY;'", - 'Output:', - 'CAT', - 'KITTY' - ], msgs) + self.assertEqual(["Executing cmd: bash -c 'echo CAT; echo KITTY;'", 'Output:', 'CAT', 'KITTY'], msgs) def test_should_raise_exception(self): with 
self.assertRaises(CalledProcessError): diff --git a/tests/utils/test_python_virtualenv.py b/tests/utils/test_python_virtualenv.py index 218bce343ddaf..59ad504b84e1f 100644 --- a/tests/utils/test_python_virtualenv.py +++ b/tests/utils/test_python_virtualenv.py @@ -23,14 +23,10 @@ class TestPrepareVirtualenv(unittest.TestCase): - @mock.patch('airflow.utils.python_virtualenv.execute_in_subprocess') def test_should_create_virtualenv(self, mock_execute_in_subprocess): python_bin = prepare_virtualenv( - venv_directory="/VENV", - python_bin="pythonVER", - system_site_packages=False, - requirements=[] + venv_directory="/VENV", python_bin="pythonVER", system_site_packages=False, requirements=[] ) self.assertEqual("/VENV/bin/python", python_bin) mock_execute_in_subprocess.assert_called_once_with(['virtualenv', '/VENV', '--python=pythonVER']) @@ -38,10 +34,7 @@ def test_should_create_virtualenv(self, mock_execute_in_subprocess): @mock.patch('airflow.utils.python_virtualenv.execute_in_subprocess') def test_should_create_virtualenv_with_system_packages(self, mock_execute_in_subprocess): python_bin = prepare_virtualenv( - venv_directory="/VENV", - python_bin="pythonVER", - system_site_packages=True, - requirements=[] + venv_directory="/VENV", python_bin="pythonVER", system_site_packages=True, requirements=[] ) self.assertEqual("/VENV/bin/python", python_bin) mock_execute_in_subprocess.assert_called_once_with( @@ -54,14 +47,10 @@ def test_should_create_virtualenv_with_extra_packages(self, mock_execute_in_subp venv_directory="/VENV", python_bin="pythonVER", system_site_packages=False, - requirements=['apache-beam[gcp]'] + requirements=['apache-beam[gcp]'], ) self.assertEqual("/VENV/bin/python", python_bin) - mock_execute_in_subprocess.assert_any_call( - ['virtualenv', '/VENV', '--python=pythonVER'] - ) + mock_execute_in_subprocess.assert_any_call(['virtualenv', '/VENV', '--python=pythonVER']) - mock_execute_in_subprocess.assert_called_with( - ['/VENV/bin/pip', 'install', 'apache-beam[gcp]'] - ) + mock_execute_in_subprocess.assert_called_with(['/VENV/bin/pip', 'install', 'apache-beam[gcp]']) diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py index ac8a52a776ef9..dd6729bb6882d 100644 --- a/tests/utils/test_sqlalchemy.py +++ b/tests/utils/test_sqlalchemy.py @@ -94,28 +94,66 @@ def test_process_bind_param_naive(self): state=State.NONE, execution_date=start_date, start_date=start_date, - session=self.session + session=self.session, ) dag.clear() - @parameterized.expand([ - ("postgresql", True, {'skip_locked': True}, ), - ("mysql", False, {}, ), - ("mysql", True, {'skip_locked': True}, ), - ("sqlite", False, {'skip_locked': True}, ), - ]) + @parameterized.expand( + [ + ( + "postgresql", + True, + {'skip_locked': True}, + ), + ( + "mysql", + False, + {}, + ), + ( + "mysql", + True, + {'skip_locked': True}, + ), + ( + "sqlite", + False, + {'skip_locked': True}, + ), + ] + ) def test_skip_locked(self, dialect, supports_for_update_of, expected_return_value): session = mock.Mock() session.bind.dialect.name = dialect session.bind.dialect.supports_for_update_of = supports_for_update_of self.assertEqual(skip_locked(session=session), expected_return_value) - @parameterized.expand([ - ("postgresql", True, {'nowait': True}, ), - ("mysql", False, {}, ), - ("mysql", True, {'nowait': True}, ), - ("sqlite", False, {'nowait': True, }, ), - ]) + @parameterized.expand( + [ + ( + "postgresql", + True, + {'nowait': True}, + ), + ( + "mysql", + False, + {}, + ), + ( + "mysql", + True, + {'nowait': True}, + 
), + ( + "sqlite", + False, + { + 'nowait': True, + }, + ), + ] + ) def test_nowait(self, dialect, supports_for_update_of, expected_return_value): session = mock.Mock() session.bind.dialect.name = dialect diff --git a/tests/utils/test_task_handler_with_custom_formatter.py b/tests/utils/test_task_handler_with_custom_formatter.py index 37f874442252c..3c9ea13b092c3 100644 --- a/tests/utils/test_task_handler_with_custom_formatter.py +++ b/tests/utils/test_task_handler_with_custom_formatter.py @@ -29,8 +29,7 @@ DEFAULT_DATE = datetime(2019, 1, 1) TASK_LOGGER = 'airflow.task' TASK_HANDLER = 'task' -TASK_HANDLER_CLASS = 'airflow.utils.log.task_handler_with_custom_formatter.' \ - 'TaskHandlerWithCustomFormatter' +TASK_HANDLER_CLASS = 'airflow.utils.log.task_handler_with_custom_formatter.TaskHandlerWithCustomFormatter' PREV_TASK_HANDLER = DEFAULT_LOGGING_CONFIG['handlers']['task'] @@ -40,7 +39,7 @@ def setUp(self): DEFAULT_LOGGING_CONFIG['handlers']['task'] = { 'class': TASK_HANDLER_CLASS, 'formatter': 'airflow', - 'stream': 'sys.stdout' + 'stream': 'sys.stdout', } logging.config.dictConfig(DEFAULT_LOGGING_CONFIG) diff --git a/tests/utils/test_timezone.py b/tests/utils/test_timezone.py index 400813e88a4cc..d7249bde1cdd4 100644 --- a/tests/utils/test_timezone.py +++ b/tests/utils/test_timezone.py @@ -55,10 +55,12 @@ def test_convert_to_utc(self): def test_make_naive(self): self.assertEqual( timezone.make_naive(datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT), EAT), - datetime.datetime(2011, 9, 1, 13, 20, 30)) + datetime.datetime(2011, 9, 1, 13, 20, 30), + ) self.assertEqual( timezone.make_naive(datetime.datetime(2011, 9, 1, 17, 20, 30, tzinfo=ICT), EAT), - datetime.datetime(2011, 9, 1, 13, 20, 30)) + datetime.datetime(2011, 9, 1, 13, 20, 30), + ) with self.assertRaises(ValueError): timezone.make_naive(datetime.datetime(2011, 9, 1, 13, 20, 30), EAT) @@ -66,6 +68,7 @@ def test_make_naive(self): def test_make_aware(self): self.assertEqual( timezone.make_aware(datetime.datetime(2011, 9, 1, 13, 20, 30), EAT), - datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT)) + datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT), + ) with self.assertRaises(ValueError): timezone.make_aware(datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT), EAT) diff --git a/tests/utils/test_trigger_rule.py b/tests/utils/test_trigger_rule.py index 1ea399ebde3b7..afbe0fcf48a37 100644 --- a/tests/utils/test_trigger_rule.py +++ b/tests/utils/test_trigger_rule.py @@ -22,7 +22,6 @@ class TestTriggerRule(unittest.TestCase): - def test_valid_trigger_rules(self): self.assertTrue(TriggerRule.is_valid(TriggerRule.ALL_SUCCESS)) self.assertTrue(TriggerRule.is_valid(TriggerRule.ALL_FAILED)) diff --git a/tests/utils/test_weight_rule.py b/tests/utils/test_weight_rule.py index de63d5f9c276f..862e1b426f7fa 100644 --- a/tests/utils/test_weight_rule.py +++ b/tests/utils/test_weight_rule.py @@ -22,7 +22,6 @@ class TestWeightRule(unittest.TestCase): - def test_valid_weight_rules(self): self.assertTrue(WeightRule.is_valid(WeightRule.DOWNSTREAM)) self.assertTrue(WeightRule.is_valid(WeightRule.UPSTREAM)) diff --git a/tests/www/api/experimental/test_dag_runs_endpoint.py b/tests/www/api/experimental/test_dag_runs_endpoint.py index 8e7502f321f06..4e020c76e6188 100644 --- a/tests/www/api/experimental/test_dag_runs_endpoint.py +++ b/tests/www/api/experimental/test_dag_runs_endpoint.py @@ -26,7 +26,6 @@ class TestDagRunsEndpoint(unittest.TestCase): - @classmethod def setUpClass(cls): super().setUpClass() @@ -55,8 +54,7 @@ def 
test_get_dag_runs_success(self): url_template = '/api/experimental/dags/{}/dag_runs' dag_id = 'example_bash_operator' # Create DagRun - dag_run = trigger_dag( - dag_id=dag_id, run_id='test_get_dag_runs_success') + dag_run = trigger_dag(dag_id=dag_id, run_id='test_get_dag_runs_success') response = self.app.get(url_template.format(dag_id)) self.assertEqual(200, response.status_code) @@ -71,8 +69,7 @@ def test_get_dag_runs_success_with_state_parameter(self): url_template = '/api/experimental/dags/{}/dag_runs?state=running' dag_id = 'example_bash_operator' # Create DagRun - dag_run = trigger_dag( - dag_id=dag_id, run_id='test_get_dag_runs_success') + dag_run = trigger_dag(dag_id=dag_id, run_id='test_get_dag_runs_success') response = self.app.get(url_template.format(dag_id)) self.assertEqual(200, response.status_code) @@ -87,8 +84,7 @@ def test_get_dag_runs_success_with_capital_state_parameter(self): url_template = '/api/experimental/dags/{}/dag_runs?state=RUNNING' dag_id = 'example_bash_operator' # Create DagRun - dag_run = trigger_dag( - dag_id=dag_id, run_id='test_get_dag_runs_success') + dag_run = trigger_dag(dag_id=dag_id, run_id='test_get_dag_runs_success') response = self.app.get(url_template.format(dag_id)) self.assertEqual(200, response.status_code) diff --git a/tests/www/api/experimental/test_endpoints.py b/tests/www/api/experimental/test_endpoints.py index 760e0ac639e44..cbdf1c44c6b0c 100644 --- a/tests/www/api/experimental/test_endpoints.py +++ b/tests/www/api/experimental/test_endpoints.py @@ -53,13 +53,11 @@ def assert_deprecated(self, resp): self.assertEqual('true', resp.headers['Deprecation']) self.assertRegex( resp.headers['Link'], - r'\<.+/stable-rest-api/migration.html\>; ' - 'rel="deprecation"; type="text/html"', + r'\<.+/stable-rest-api/migration.html\>; ' 'rel="deprecation"; type="text/html"', ) class TestApiExperimental(TestBase): - @classmethod def setUpClass(cls): super().setUpClass() @@ -94,40 +92,30 @@ def test_info(self): def test_task_info(self): url_template = '/api/experimental/dags/{}/tasks/{}' - response = self.client.get( - url_template.format('example_bash_operator', 'runme_0') - ) + response = self.client.get(url_template.format('example_bash_operator', 'runme_0')) self.assert_deprecated(response) self.assertIn('"email"', response.data.decode('utf-8')) self.assertNotIn('error', response.data.decode('utf-8')) self.assertEqual(200, response.status_code) - response = self.client.get( - url_template.format('example_bash_operator', 'DNE') - ) + response = self.client.get(url_template.format('example_bash_operator', 'DNE')) self.assertIn('error', response.data.decode('utf-8')) self.assertEqual(404, response.status_code) - response = self.client.get( - url_template.format('DNE', 'DNE') - ) + response = self.client.get(url_template.format('DNE', 'DNE')) self.assertIn('error', response.data.decode('utf-8')) self.assertEqual(404, response.status_code) def test_get_dag_code(self): url_template = '/api/experimental/dags/{}/code' - response = self.client.get( - url_template.format('example_bash_operator') - ) + response = self.client.get(url_template.format('example_bash_operator')) self.assert_deprecated(response) self.assertIn('BashOperator(', response.data.decode('utf-8')) self.assertEqual(200, response.status_code) - response = self.client.get( - url_template.format('xyz') - ) + response = self.client.get(url_template.format('xyz')) self.assertEqual(404, response.status_code) def test_dag_paused(self): @@ -135,9 +123,7 @@ def test_dag_paused(self): paused_url_template 
= '/api/experimental/dags/{}/paused' paused_url = paused_url_template.format('example_bash_operator') - response = self.client.get( - pause_url_template.format('example_bash_operator', 'true') - ) + response = self.client.get(pause_url_template.format('example_bash_operator', 'true')) self.assert_deprecated(response) self.assertIn('ok', response.data.decode('utf-8')) self.assertEqual(200, response.status_code) @@ -147,9 +133,7 @@ def test_dag_paused(self): self.assertEqual(200, paused_response.status_code) self.assertEqual({"is_paused": True}, paused_response.json) - response = self.client.get( - pause_url_template.format('example_bash_operator', 'false') - ) + response = self.client.get(pause_url_template.format('example_bash_operator', 'false')) self.assertIn('ok', response.data.decode('utf-8')) self.assertEqual(200, response.status_code) @@ -164,13 +148,12 @@ def test_trigger_dag(self): response = self.client.post( url_template.format('example_bash_operator'), data=json.dumps({'run_id': run_id}), - content_type="application/json" + content_type="application/json", ) self.assert_deprecated(response) self.assertEqual(200, response.status_code) - response_execution_date = parse_datetime( - json.loads(response.data.decode('utf-8'))['execution_date']) + response_execution_date = parse_datetime(json.loads(response.data.decode('utf-8'))['execution_date']) self.assertEqual(0, response_execution_date.microsecond) # Check execution_date is correct @@ -184,9 +167,7 @@ def test_trigger_dag(self): # Test error for nonexistent dag response = self.client.post( - url_template.format('does_not_exist_dag'), - data=json.dumps({}), - content_type="application/json" + url_template.format('does_not_exist_dag'), data=json.dumps({}), content_type="application/json" ) self.assertEqual(404, response.status_code) @@ -200,7 +181,7 @@ def test_trigger_dag_for_date(self): response = self.client.post( url_template.format(dag_id), data=json.dumps({'execution_date': datetime_string}), - content_type="application/json" + content_type="application/json", ) self.assert_deprecated(response) self.assertEqual(200, response.status_code) @@ -209,33 +190,28 @@ def test_trigger_dag_for_date(self): dagbag = DagBag() dag = dagbag.get_dag(dag_id) dag_run = dag.get_dagrun(execution_date) - self.assertTrue(dag_run, - 'Dag Run not found for execution date {}' - .format(execution_date)) + self.assertTrue(dag_run, f'Dag Run not found for execution date {execution_date}') # Test correct execution with execution date and microseconds replaced response = self.client.post( url_template.format(dag_id), data=json.dumps({'execution_date': datetime_string, 'replace_microseconds': 'true'}), - content_type="application/json" + content_type="application/json", ) self.assertEqual(200, response.status_code) - response_execution_date = parse_datetime( - json.loads(response.data.decode('utf-8'))['execution_date']) + response_execution_date = parse_datetime(json.loads(response.data.decode('utf-8'))['execution_date']) self.assertEqual(0, response_execution_date.microsecond) dagbag = DagBag() dag = dagbag.get_dag(dag_id) dag_run = dag.get_dagrun(response_execution_date) - self.assertTrue(dag_run, - 'Dag Run not found for execution date {}' - .format(execution_date)) + self.assertTrue(dag_run, f'Dag Run not found for execution date {execution_date}') # Test error for nonexistent dag response = self.client.post( url_template.format('does_not_exist_dag'), data=json.dumps({'execution_date': datetime_string}), - content_type="application/json" + 
content_type="application/json", ) self.assertEqual(404, response.status_code) @@ -243,7 +219,7 @@ def test_trigger_dag_for_date(self): response = self.client.post( url_template.format(dag_id), data=json.dumps({'execution_date': 'not_a_datetime'}), - content_type="application/json" + content_type="application/json", ) self.assertEqual(400, response.status_code) @@ -253,19 +229,13 @@ def test_task_instance_info(self): task_id = 'also_run_this' execution_date = utcnow().replace(microsecond=0) datetime_string = quote_plus(execution_date.isoformat()) - wrong_datetime_string = quote_plus( - datetime(1990, 1, 1, 1, 1, 1).isoformat() - ) + wrong_datetime_string = quote_plus(datetime(1990, 1, 1, 1, 1, 1).isoformat()) # Create DagRun - trigger_dag(dag_id=dag_id, - run_id='test_task_instance_info_run', - execution_date=execution_date) + trigger_dag(dag_id=dag_id, run_id='test_task_instance_info_run', execution_date=execution_date) # Test Correct execution - response = self.client.get( - url_template.format(dag_id, datetime_string, task_id) - ) + response = self.client.get(url_template.format(dag_id, datetime_string, task_id)) self.assert_deprecated(response) self.assertEqual(200, response.status_code) self.assertIn('state', response.data.decode('utf-8')) @@ -273,30 +243,23 @@ def test_task_instance_info(self): # Test error for nonexistent dag response = self.client.get( - url_template.format('does_not_exist_dag', datetime_string, - task_id), + url_template.format('does_not_exist_dag', datetime_string, task_id), ) self.assertEqual(404, response.status_code) self.assertIn('error', response.data.decode('utf-8')) # Test error for nonexistent task - response = self.client.get( - url_template.format(dag_id, datetime_string, 'does_not_exist_task') - ) + response = self.client.get(url_template.format(dag_id, datetime_string, 'does_not_exist_task')) self.assertEqual(404, response.status_code) self.assertIn('error', response.data.decode('utf-8')) # Test error for nonexistent dag run (wrong execution_date) - response = self.client.get( - url_template.format(dag_id, wrong_datetime_string, task_id) - ) + response = self.client.get(url_template.format(dag_id, wrong_datetime_string, task_id)) self.assertEqual(404, response.status_code) self.assertIn('error', response.data.decode('utf-8')) # Test error for bad datetime format - response = self.client.get( - url_template.format(dag_id, 'not_a_datetime', task_id) - ) + response = self.client.get(url_template.format(dag_id, 'not_a_datetime', task_id)) self.assertEqual(400, response.status_code) self.assertIn('error', response.data.decode('utf-8')) @@ -305,19 +268,13 @@ def test_dagrun_status(self): dag_id = 'example_bash_operator' execution_date = utcnow().replace(microsecond=0) datetime_string = quote_plus(execution_date.isoformat()) - wrong_datetime_string = quote_plus( - datetime(1990, 1, 1, 1, 1, 1).isoformat() - ) + wrong_datetime_string = quote_plus(datetime(1990, 1, 1, 1, 1, 1).isoformat()) # Create DagRun - trigger_dag(dag_id=dag_id, - run_id='test_task_instance_info_run', - execution_date=execution_date) + trigger_dag(dag_id=dag_id, run_id='test_task_instance_info_run', execution_date=execution_date) # Test Correct execution - response = self.client.get( - url_template.format(dag_id, datetime_string) - ) + response = self.client.get(url_template.format(dag_id, datetime_string)) self.assert_deprecated(response) self.assertEqual(200, response.status_code) self.assertIn('state', response.data.decode('utf-8')) @@ -331,16 +288,12 @@ def test_dagrun_status(self): 
self.assertIn('error', response.data.decode('utf-8')) # Test error for nonexistent dag run (wrong execution_date) - response = self.client.get( - url_template.format(dag_id, wrong_datetime_string) - ) + response = self.client.get(url_template.format(dag_id, wrong_datetime_string)) self.assertEqual(404, response.status_code) self.assertIn('error', response.data.decode('utf-8')) # Test error for bad datetime format - response = self.client.get( - url_template.format(dag_id, 'not_a_datetime') - ) + response = self.client.get(url_template.format(dag_id, 'not_a_datetime')) self.assertEqual(400, response.status_code) self.assertIn('error', response.data.decode('utf-8')) @@ -368,19 +321,13 @@ def test_lineage_info(self): dag_id = 'example_papermill_operator' execution_date = utcnow().replace(microsecond=0) datetime_string = quote_plus(execution_date.isoformat()) - wrong_datetime_string = quote_plus( - datetime(1990, 1, 1, 1, 1, 1).isoformat() - ) + wrong_datetime_string = quote_plus(datetime(1990, 1, 1, 1, 1, 1).isoformat()) # create DagRun - trigger_dag(dag_id=dag_id, - run_id='test_lineage_info_run', - execution_date=execution_date) + trigger_dag(dag_id=dag_id, run_id='test_lineage_info_run', execution_date=execution_date) # test correct execution - response = self.client.get( - url_template.format(dag_id, datetime_string) - ) + response = self.client.get(url_template.format(dag_id, datetime_string)) self.assert_deprecated(response) self.assertEqual(200, response.status_code) self.assertIn('task_ids', response.data.decode('utf-8')) @@ -394,16 +341,12 @@ def test_lineage_info(self): self.assertIn('error', response.data.decode('utf-8')) # Test error for nonexistent dag run (wrong execution_date) - response = self.client.get( - url_template.format(dag_id, wrong_datetime_string) - ) + response = self.client.get(url_template.format(dag_id, wrong_datetime_string)) self.assertEqual(404, response.status_code) self.assertIn('error', response.data.decode('utf-8')) # Test error for bad datetime format - response = self.client.get( - url_template.format(dag_id, 'not_a_datetime') - ) + response = self.client.get(url_template.format(dag_id, 'not_a_datetime')) self.assertEqual(400, response.status_code) self.assertIn('error', response.data.decode('utf-8')) @@ -444,14 +387,12 @@ def test_get_pool(self): ) self.assert_deprecated(response) self.assertEqual(response.status_code, 200) - self.assertEqual(json.loads(response.data.decode('utf-8')), - self.pool.to_json()) + self.assertEqual(json.loads(response.data.decode('utf-8')), self.pool.to_json()) def test_get_pool_non_existing(self): response = self.client.get('/api/experimental/pools/foo') self.assertEqual(response.status_code, 404) - self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], - "Pool 'foo' doesn't exist") + self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], "Pool 'foo' doesn't exist") def test_get_pools(self): response = self.client.get('/api/experimental/pools') @@ -465,11 +406,13 @@ def test_get_pools(self): def test_create_pool(self): response = self.client.post( '/api/experimental/pools', - data=json.dumps({ - 'name': 'foo', - 'slots': 1, - 'description': '', - }), + data=json.dumps( + { + 'name': 'foo', + 'slots': 1, + 'description': '', + } + ), content_type='application/json', ) self.assert_deprecated(response) @@ -484,11 +427,13 @@ def test_create_pool_with_bad_name(self): for name in ('', ' '): response = self.client.post( '/api/experimental/pools', - data=json.dumps({ - 'name': name, - 'slots': 1, - 
'description': '', - }), + data=json.dumps( + { + 'name': name, + 'slots': 1, + 'description': '', + } + ), content_type='application/json', ) self.assertEqual(response.status_code, 400) @@ -504,8 +449,7 @@ def test_delete_pool(self): ) self.assert_deprecated(response) self.assertEqual(response.status_code, 200) - self.assertEqual(json.loads(response.data.decode('utf-8')), - self.pool.to_json()) + self.assertEqual(json.loads(response.data.decode('utf-8')), self.pool.to_json()) self.assertEqual(self._get_pool_count(), self.TOTAL_POOL_COUNT - 1) def test_delete_pool_non_existing(self): @@ -513,8 +457,7 @@ def test_delete_pool_non_existing(self): '/api/experimental/pools/foo', ) self.assertEqual(response.status_code, 404) - self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], - "Pool 'foo' doesn't exist") + self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], "Pool 'foo' doesn't exist") def test_delete_default_pool(self): clear_db_pools() @@ -522,5 +465,4 @@ def test_delete_default_pool(self): '/api/experimental/pools/default_pool', ) self.assertEqual(response.status_code, 400) - self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], - "default_pool cannot be deleted") + self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], "default_pool cannot be deleted") diff --git a/tests/www/api/experimental/test_kerberos_endpoints.py b/tests/www/api/experimental/test_kerberos_endpoints.py index 402e5cf059b73..8425d39eb393f 100644 --- a/tests/www/api/experimental/test_kerberos_endpoints.py +++ b/tests/www/api/experimental/test_kerberos_endpoints.py @@ -42,10 +42,12 @@ def setUpClass(cls): for dag in dagbag.dags.values(): dag.sync_to_db() - @conf_vars({ - ("api", "auth_backend"): "airflow.api.auth.backend.kerberos_auth", - ("kerberos", "keytab"): KRB5_KTNAME, - }) + @conf_vars( + { + ("api", "auth_backend"): "airflow.api.auth.backend.kerberos_auth", + ("kerberos", "keytab"): KRB5_KTNAME, + } + ) def setUp(self): self.app = application.create_app(testing=True) @@ -59,7 +61,7 @@ def test_trigger_dag(self): response = client.post( url_template.format('example_bash_operator'), data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())), - content_type="application/json" + content_type="application/json", ) self.assertEqual(401, response.status_code) @@ -84,7 +86,7 @@ class Request: url_template.format('example_bash_operator'), data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())), content_type="application/json", - headers=response.request.headers + headers=response.request.headers, ) self.assertEqual(200, response2.status_code) @@ -94,7 +96,7 @@ def test_unauthorized(self): response = client.post( url_template.format('example_bash_operator'), data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())), - content_type="application/json" + content_type="application/json", ) self.assertEqual(401, response.status_code) diff --git a/tests/www/test_app.py b/tests/www/test_app.py index 198a73ed1f825..5740e9a55ace9 100644 --- a/tests/www/test_app.py +++ b/tests/www/test_app.py @@ -32,16 +32,19 @@ class TestApp(unittest.TestCase): @classmethod def setUpClass(cls) -> None: from airflow import settings + settings.configure_orm() - @conf_vars({ - ('webserver', 'enable_proxy_fix'): 'True', - ('webserver', 'proxy_fix_x_for'): '1', - ('webserver', 'proxy_fix_x_proto'): '1', - ('webserver', 'proxy_fix_x_host'): '1', - ('webserver', 'proxy_fix_x_port'): '1', - ('webserver', 'proxy_fix_x_prefix'): '1' - }) + @conf_vars( + { + 
('webserver', 'enable_proxy_fix'): 'True', + ('webserver', 'proxy_fix_x_for'): '1', + ('webserver', 'proxy_fix_x_proto'): '1', + ('webserver', 'proxy_fix_x_host'): '1', + ('webserver', 'proxy_fix_x_port'): '1', + ('webserver', 'proxy_fix_x_prefix'): '1', + } + ) @mock.patch("airflow.www.app.app", None) def test_should_respect_proxy_fix(self): app = application.cached_app(testing=True) @@ -77,9 +80,11 @@ def debug_view(): self.assertEqual(b"success", response.get_data()) self.assertEqual(response.status_code, 200) - @conf_vars({ - ('webserver', 'base_url'): 'http://localhost:8080/internal-client', - }) + @conf_vars( + { + ('webserver', 'base_url'): 'http://localhost:8080/internal-client', + } + ) @mock.patch("airflow.www.app.app", None) def test_should_respect_base_url_ignore_proxy_headers(self): app = application.cached_app(testing=True) @@ -115,15 +120,17 @@ def debug_view(): self.assertEqual(b"success", response.get_data()) self.assertEqual(response.status_code, 200) - @conf_vars({ - ('webserver', 'base_url'): 'http://localhost:8080/internal-client', - ('webserver', 'enable_proxy_fix'): 'True', - ('webserver', 'proxy_fix_x_for'): '1', - ('webserver', 'proxy_fix_x_proto'): '1', - ('webserver', 'proxy_fix_x_host'): '1', - ('webserver', 'proxy_fix_x_port'): '1', - ('webserver', 'proxy_fix_x_prefix'): '1' - }) + @conf_vars( + { + ('webserver', 'base_url'): 'http://localhost:8080/internal-client', + ('webserver', 'enable_proxy_fix'): 'True', + ('webserver', 'proxy_fix_x_for'): '1', + ('webserver', 'proxy_fix_x_proto'): '1', + ('webserver', 'proxy_fix_x_host'): '1', + ('webserver', 'proxy_fix_x_port'): '1', + ('webserver', 'proxy_fix_x_prefix'): '1', + } + ) @mock.patch("airflow.www.app.app", None) def test_should_respect_base_url_when_proxy_fix_and_base_url_is_set_up_but_headers_missing(self): app = application.cached_app(testing=True) @@ -153,15 +160,17 @@ def debug_view(): self.assertEqual(b"success", response.get_data()) self.assertEqual(response.status_code, 200) - @conf_vars({ - ('webserver', 'base_url'): 'http://localhost:8080/internal-client', - ('webserver', 'enable_proxy_fix'): 'True', - ('webserver', 'proxy_fix_x_for'): '1', - ('webserver', 'proxy_fix_x_proto'): '1', - ('webserver', 'proxy_fix_x_host'): '1', - ('webserver', 'proxy_fix_x_port'): '1', - ('webserver', 'proxy_fix_x_prefix'): '1' - }) + @conf_vars( + { + ('webserver', 'base_url'): 'http://localhost:8080/internal-client', + ('webserver', 'enable_proxy_fix'): 'True', + ('webserver', 'proxy_fix_x_for'): '1', + ('webserver', 'proxy_fix_x_proto'): '1', + ('webserver', 'proxy_fix_x_host'): '1', + ('webserver', 'proxy_fix_x_port'): '1', + ('webserver', 'proxy_fix_x_prefix'): '1', + } + ) @mock.patch("airflow.www.app.app", None) def test_should_respect_base_url_and_proxy_when_proxy_fix_and_base_url_is_set_up(self): app = application.cached_app(testing=True) @@ -197,21 +206,18 @@ def debug_view(): self.assertEqual(b"success", response.get_data()) self.assertEqual(response.status_code, 200) - @conf_vars({ - ('core', 'sql_alchemy_pool_enabled'): 'True', - ('core', 'sql_alchemy_pool_size'): '3', - ('core', 'sql_alchemy_max_overflow'): '5', - ('core', 'sql_alchemy_pool_recycle'): '120', - ('core', 'sql_alchemy_pool_pre_ping'): 'True', - }) + @conf_vars( + { + ('core', 'sql_alchemy_pool_enabled'): 'True', + ('core', 'sql_alchemy_pool_size'): '3', + ('core', 'sql_alchemy_max_overflow'): '5', + ('core', 'sql_alchemy_pool_recycle'): '120', + ('core', 'sql_alchemy_pool_pre_ping'): 'True', + } + ) @mock.patch("airflow.www.app.app", None) 
@pytest.mark.backend("mysql", "postgres") def test_should_set_sqlalchemy_engine_options(self): app = application.cached_app(testing=True) - engine_params = { - 'pool_size': 3, - 'pool_recycle': 120, - 'pool_pre_ping': True, - 'max_overflow': 5 - } + engine_params = {'pool_size': 3, 'pool_recycle': 120, 'pool_pre_ping': True, 'max_overflow': 5} self.assertEqual(app.config['SQLALCHEMY_ENGINE_OPTIONS'], engine_params) diff --git a/tests/www/test_security.py b/tests/www/test_security.py index b44f67482da86..1bf460b8713c1 100644 --- a/tests/www/test_security.py +++ b/tests/www/test_security.py @@ -187,7 +187,14 @@ def test_get_all_permissions_views(self, mock_get_user_roles): username = 'get_all_permissions_views' with self.app.app_context(): - user = fab_utils.create_user(self.app, username, role_name, permissions=[(role_perm, role_vm),],) + user = fab_utils.create_user( + self.app, + username, + role_name, + permissions=[ + (role_perm, role_vm), + ], + ) role = user.roles[0] mock_get_user_roles.return_value = [role] @@ -281,9 +288,7 @@ def test_all_dag_access_doesnt_give_non_dag_access(self): ) self.assertFalse( self.security_manager.has_access( - permissions.ACTION_CAN_READ, - permissions.RESOURCE_TASK_INSTANCE, - user + permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE, user ) ) @@ -293,7 +298,11 @@ def test_access_control_with_invalid_permission(self): 'can_eat_pudding', # clearly not a real permission ] username = "LaUser" - user = fab_utils.create_user(self.app, username=username, role_name='team-a',) + user = fab_utils.create_user( + self.app, + username=username, + role_name='team-a', + ) for permission in invalid_permissions: self.expect_user_is_in_role(user, rolename='team-a') with self.assertRaises(AirflowException) as context: @@ -306,7 +315,12 @@ def test_access_control_is_set_on_init(self): username = 'access_control_is_set_on_init' role_name = 'team-a' with self.app.app_context(): - user = fab_utils.create_user(self.app, username, role_name, permissions=[],) + user = fab_utils.create_user( + self.app, + username, + role_name, + permissions=[], + ) self.expect_user_is_in_role(user, rolename='team-a') self.security_manager.sync_perm_for_dag( 'access_control_test', @@ -329,7 +343,12 @@ def test_access_control_stale_perms_are_revoked(self): username = 'access_control_stale_perms_are_revoked' role_name = 'team-a' with self.app.app_context(): - user = fab_utils.create_user(self.app, username, role_name, permissions=[],) + user = fab_utils.create_user( + self.app, + username, + role_name, + permissions=[], + ) self.expect_user_is_in_role(user, rolename='team-a') self.security_manager.sync_perm_for_dag( 'access_control_test', access_control={'team-a': READ_WRITE} diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py index 51d73fba286b0..efdb249e8cbd2 100644 --- a/tests/www/test_utils.py +++ b/tests/www/test_utils.py @@ -61,7 +61,7 @@ def test_sensitive_variable_fields_should_be_hidden( [ (None, 'TRELLO_API', False), ('token', 'TRELLO_KEY', False), - ('token, mysecretword', 'TRELLO_KEY', False) + ('token, mysecretword', 'TRELLO_KEY', False), ], ) def test_normal_variable_fields_should_not_be_hidden( @@ -70,22 +70,15 @@ def test_normal_variable_fields_should_not_be_hidden( with conf_vars({('admin', 'sensitive_variable_fields'): str(sensitive_variable_fields)}): self.assertEqual(expected_result, utils.should_hide_value_for_key(key)) - def check_generate_pages_html(self, current_page, total_pages, - window=7, check_middle=False): + def check_generate_pages_html(self, 
current_page, total_pages, window=7, check_middle=False):
         extra_links = 4  # first, prev, next, last
         search = "'>\"/><img src=x onerror=alert(1)>"
-        html_str = utils.generate_pages(current_page, total_pages,
-                                        search=search)
+        html_str = utils.generate_pages(current_page, total_pages, search=search)

-        self.assertNotIn(search, html_str,
-                         "The raw search string shouldn't appear in the output")
-        self.assertIn('search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E',
-                      html_str)
+        self.assertNotIn(search, html_str, "The raw search string shouldn't appear in the output")
+        self.assertIn('search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E', html_str)

-        self.assertTrue(
-            callable(html_str.__html__),
-            "Should return something that is HTML-escaping aware"
-        )
+        self.assertTrue(callable(html_str.__html__), "Should return something that is HTML-escaping aware")

         dom = BeautifulSoup(html_str, 'html.parser')
         self.assertIsNotNone(dom)
@@ -112,25 +105,20 @@ def check_generate_pages_html(self, current_page, total_pages,
         self.assertListEqual(query['search'], [search])

     def test_generate_pager_current_start(self):
-        self.check_generate_pages_html(current_page=0,
-                                       total_pages=6)
+        self.check_generate_pages_html(current_page=0, total_pages=6)

     def test_generate_pager_current_middle(self):
-        self.check_generate_pages_html(current_page=10,
-                                       total_pages=20,
-                                       check_middle=True)
+        self.check_generate_pages_html(current_page=10, total_pages=20, check_middle=True)

     def test_generate_pager_current_end(self):
-        self.check_generate_pages_html(current_page=38,
-                                       total_pages=39)
+        self.check_generate_pages_html(current_page=38, total_pages=39)

     def test_params_no_values(self):
         """Should return an empty string if no params are passed"""
         self.assertEqual('', utils.get_params())

     def test_params_search(self):
-        self.assertEqual('search=bash_',
-                         utils.get_params(search='bash_'))
+        self.assertEqual('search=bash_', utils.get_params(search='bash_'))

     def test_params_none_and_zero(self):
         query_str = utils.get_params(a=0, b=None, c='true')
@@ -140,16 +128,13 @@ def test_params_none_and_zero(self):

     def test_params_all(self):
         query = utils.get_params(status='active', page=3, search='bash_')
-        self.assertEqual(
-            {'page': ['3'],
-             'search': ['bash_'],
-             'status': ['active']},
-            parse_qs(query)
-        )
+        self.assertEqual({'page': ['3'], 'search': ['bash_'], 'status': ['active']}, parse_qs(query))

     def test_params_escape(self):
-        self.assertEqual('search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E',
-                         utils.get_params(search="'>\"/><img src=x onerror=alert(1)>"))
+        self.assertEqual(
+            'search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E',
+            utils.get_params(search="'>\"/><img src=x onerror=alert(1)>"),
+        )

     def test_state_token(self):
         # It's shouldn't possible to set these odd values anymore, but lets
@@ -168,12 +153,13 @@ def test_state_token(self):

     def test_task_instance_link(self):
         from airflow.www.app import cached_app
+
         with cached_app(testing=True).test_request_context():
-            html = str(utils.task_instance_link({
-                'dag_id': '<a&1>',
-                'task_id': '<b2>',
-                'execution_date': datetime.now()
-            }))
+            html = str(
+                utils.task_instance_link(
+                    {'dag_id': '<a&1>', 'task_id': '<b2>', 'execution_date': datetime.now()}
+                )
+            )

         self.assertIn('%3Ca%261%3E', html)
         self.assertIn('%3Cb2%3E', html)
@@ -182,23 +168,20 @@ def test_dag_link(self):
         from airflow.www.app import cached_app
+
         with cached_app(testing=True).test_request_context():
-            html = str(utils.dag_link({
-                'dag_id': '<a&1>',
-                'execution_date': datetime.now()
-            }))
+            html = str(utils.dag_link({'dag_id': '<a&1>', 'execution_date': datetime.now()}))

         self.assertIn('%3Ca%261%3E', html)
         self.assertNotIn('<a&1>', html)

     def test_dag_run_link(self):
         from airflow.www.app import cached_app
+
         with cached_app(testing=True).test_request_context():
-            html = str(utils.dag_run_link({
-                'dag_id': '<a&1>',
-                'run_id': '<b2>',
-                'execution_date': datetime.now()
-            }))
+            html = str(
+                utils.dag_run_link({'dag_id': '<a&1>', 'run_id': '<b2>', 'execution_date': datetime.now()})
+            )

         self.assertIn('%3Ca%261%3E', html)
         self.assertIn('%3Cb2%3E', html)
@@ -207,13 +190,13 @@ def test_dag_run_link(self):

 class TestAttrRenderer(unittest.TestCase):
-
     def setUp(self):
         self.attr_renderer = utils.get_attr_renderer()

     def test_python_callable(self):
         def example_callable(unused_self):
             print("example")
+
         rendered = self.attr_renderer["python_callable"](example_callable)
         self.assertIn('"example"', rendered)
@@ -233,7 +216,6 @@ def test_markdown_none(self):

 class TestWrappedMarkdown(unittest.TestCase):
-
     def test_wrapped_markdown_with_docstring_curly_braces(self):
         rendered = wrapped_markdown("{braces}", css_class="a_class")
         self.assertEqual('

{braces}
', rendered)
@@ -242,4 +224,6 @@ def test_wrapped_markdown_with_some_markdown(self):
         rendered = wrapped_markdown("*italic*\n**bold**\n", css_class="a_class")
         self.assertEqual( '''
italic
-bold''', rendered)
+bold
    ''', + rendered, + ) diff --git a/tests/www/test_validators.py b/tests/www/test_validators.py index e63f6b042e529..df36b61f1d294 100644 --- a/tests/www/test_validators.py +++ b/tests/www/test_validators.py @@ -23,7 +23,6 @@ class TestGreaterEqualThan(unittest.TestCase): - def setUp(self): super().setUp() self.form_field_mock = mock.MagicMock(data='2017-05-06') @@ -39,8 +38,7 @@ def _validate(self, fieldname=None, message=None): if fieldname is None: fieldname = 'other_field' - validator = validators.GreaterEqualThan(fieldname=fieldname, - message=message) + validator = validators.GreaterEqualThan(fieldname=fieldname, message=message) return validator(self.form_mock, self.form_field_mock) @@ -92,7 +90,6 @@ def test_validation_raises_custom_message(self): class TestValidJson(unittest.TestCase): - def setUp(self): super().setUp() self.form_field_mock = mock.MagicMock(data='{"valid":"True"}') diff --git a/tests/www/test_views.py b/tests/www/test_views.py index f56d79b7bd454..e17180579447c 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -175,6 +175,7 @@ def capture_templates(self) -> Generator[List[TemplateWithContext], None, None]: def record(sender, template, context, **extra): # pylint: disable=unused-argument recorded.append(TemplateWithContext(template, context)) + template_rendered.connect(record, self.app) # type: ignore try: yield recorded @@ -218,8 +219,7 @@ def create_user_and_login(self, username, role_name, perms): role_name=role_name, permissions=perms, ) - self.login(username=username, - password=username) + self.login(username=username, password=username) class TestConnectionModelView(TestBase): @@ -231,7 +231,7 @@ def setUp(self): 'host': 'localhost', 'port': 8080, 'username': 'root', - 'password': 'admin' + 'password': 'admin', } def tearDown(self): @@ -239,20 +239,14 @@ def tearDown(self): super().tearDown() def test_create_connection(self): - resp = self.client.post('/connection/add', - data=self.connection, - follow_redirects=True) + resp = self.client.post('/connection/add', data=self.connection, follow_redirects=True) self.check_content_in_response('Added Row', resp) class TestVariableModelView(TestBase): def setUp(self): super().setUp() - self.variable = { - 'key': 'test_key', - 'val': 'text_val', - 'is_encrypted': True - } + self.variable = {'key': 'test_key', 'val': 'text_val', 'is_encrypted': True} def tearDown(self): self.clear_table(models.Variable) @@ -265,18 +259,17 @@ def test_can_handle_error_on_decrypt(self): # update the variable with a wrong value, given that is encrypted Var = models.Variable # pylint: disable=invalid-name - (self.session.query(Var) + ( + self.session.query(Var) .filter(Var.key == self.variable['key']) - .update({ - 'val': 'failed_value_not_encrypted' - }, synchronize_session=False)) + .update({'val': 'failed_value_not_encrypted'}, synchronize_session=False) + ) self.session.commit() # retrieve Variables page, should not fail and contain the Invalid # label for the variable resp = self.client.get('/variable/list', follow_redirects=True) - self.check_content_in_response( - 'Invalid', resp) + self.check_content_in_response('Invalid', resp) def test_xss_prevention(self): xss = "/variable/list/" @@ -286,12 +279,10 @@ def test_xss_prevention(self): follow_redirects=True, ) self.assertEqual(resp.status_code, 404) - self.assertNotIn("", - resp.data.decode("utf-8")) + self.assertNotIn("", resp.data.decode("utf-8")) def test_import_variables_no_file(self): - resp = self.client.post('/variable/varimport', - 
follow_redirects=True) + resp = self.client.post('/variable/varimport', follow_redirects=True) self.check_content_in_response('Missing file or syntax error.', resp) def test_import_variables_failed(self): @@ -308,16 +299,17 @@ def test_import_variables_failed(self): # python 2.7 bytes_content = io.BytesIO(bytes(content)) - resp = self.client.post('/variable/varimport', - data={'file': (bytes_content, 'test.json')}, - follow_redirects=True) + resp = self.client.post( + '/variable/varimport', data={'file': (bytes_content, 'test.json')}, follow_redirects=True + ) self.check_content_in_response('1 variable(s) failed to be updated.', resp) def test_import_variables_success(self): self.assertEqual(self.session.query(models.Variable).count(), 0) - content = ('{"str_key": "str_value", "int_key": 60,' - '"list_key": [1, 2], "dict_key": {"k_a": 2, "k_b": 3}}') + content = ( + '{"str_key": "str_value", "int_key": 60, "list_key": [1, 2], "dict_key": {"k_a": 2, "k_b": 3}}' + ) try: # python 3+ bytes_content = io.BytesIO(bytes(content, encoding='utf-8')) @@ -325,9 +317,9 @@ def test_import_variables_success(self): # python 2.7 bytes_content = io.BytesIO(bytes(content)) - resp = self.client.post('/variable/varimport', - data={'file': (bytes_content, 'test.json')}, - follow_redirects=True) + resp = self.client.post( + '/variable/varimport', data={'file': (bytes_content, 'test.json')}, follow_redirects=True + ) self.check_content_in_response('4 variable(s) successfully updated.', resp) @@ -494,19 +486,22 @@ def prepare_dagruns(self): run_type=DagRunType.SCHEDULED, execution_date=self.EXAMPLE_DAG_DEFAULT_DATE, start_date=timezone.utcnow(), - state=State.RUNNING) + state=State.RUNNING, + ) self.sub_dagrun = self.sub_dag.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=self.EXAMPLE_DAG_DEFAULT_DATE, start_date=timezone.utcnow(), - state=State.RUNNING) + state=State.RUNNING, + ) self.xcom_dagrun = self.xcom_dag.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=self.EXAMPLE_DAG_DEFAULT_DATE, start_date=timezone.utcnow(), - state=State.RUNNING) + state=State.RUNNING, + ) def test_index(self): with assert_queries_count(42): @@ -527,55 +522,67 @@ def test_health(self): # case-1: healthy scheduler status last_scheduler_heartbeat_for_testing_1 = timezone.utcnow() - self.session.add(BaseJob(job_type='SchedulerJob', - state='running', - latest_heartbeat=last_scheduler_heartbeat_for_testing_1)) + self.session.add( + BaseJob( + job_type='SchedulerJob', + state='running', + latest_heartbeat=last_scheduler_heartbeat_for_testing_1, + ) + ) self.session.commit() resp_json = json.loads(self.client.get('health', follow_redirects=True).data.decode('utf-8')) self.assertEqual('healthy', resp_json['metadatabase']['status']) self.assertEqual('healthy', resp_json['scheduler']['status']) - self.assertEqual(last_scheduler_heartbeat_for_testing_1.isoformat(), - resp_json['scheduler']['latest_scheduler_heartbeat']) - - self.session.query(BaseJob).\ - filter(BaseJob.job_type == 'SchedulerJob', - BaseJob.state == 'running', - BaseJob.latest_heartbeat == last_scheduler_heartbeat_for_testing_1).\ - delete() + self.assertEqual( + last_scheduler_heartbeat_for_testing_1.isoformat(), + resp_json['scheduler']['latest_scheduler_heartbeat'], + ) + + self.session.query(BaseJob).filter( + BaseJob.job_type == 'SchedulerJob', + BaseJob.state == 'running', + BaseJob.latest_heartbeat == last_scheduler_heartbeat_for_testing_1, + ).delete() self.session.commit() # case-2: unhealthy scheduler status - scenario 1 (SchedulerJob is 
running too slowly) last_scheduler_heartbeat_for_testing_2 = timezone.utcnow() - timedelta(minutes=1) - (self.session - .query(BaseJob) - .filter(BaseJob.job_type == 'SchedulerJob') - .update({'latest_heartbeat': last_scheduler_heartbeat_for_testing_2 - timedelta(seconds=1)})) - self.session.add(BaseJob(job_type='SchedulerJob', - state='running', - latest_heartbeat=last_scheduler_heartbeat_for_testing_2)) + ( + self.session.query(BaseJob) + .filter(BaseJob.job_type == 'SchedulerJob') + .update({'latest_heartbeat': last_scheduler_heartbeat_for_testing_2 - timedelta(seconds=1)}) + ) + self.session.add( + BaseJob( + job_type='SchedulerJob', + state='running', + latest_heartbeat=last_scheduler_heartbeat_for_testing_2, + ) + ) self.session.commit() resp_json = json.loads(self.client.get('health', follow_redirects=True).data.decode('utf-8')) self.assertEqual('healthy', resp_json['metadatabase']['status']) self.assertEqual('unhealthy', resp_json['scheduler']['status']) - self.assertEqual(last_scheduler_heartbeat_for_testing_2.isoformat(), - resp_json['scheduler']['latest_scheduler_heartbeat']) - - self.session.query(BaseJob).\ - filter(BaseJob.job_type == 'SchedulerJob', - BaseJob.state == 'running', - BaseJob.latest_heartbeat == last_scheduler_heartbeat_for_testing_2).\ - delete() + self.assertEqual( + last_scheduler_heartbeat_for_testing_2.isoformat(), + resp_json['scheduler']['latest_scheduler_heartbeat'], + ) + + self.session.query(BaseJob).filter( + BaseJob.job_type == 'SchedulerJob', + BaseJob.state == 'running', + BaseJob.latest_heartbeat == last_scheduler_heartbeat_for_testing_2, + ).delete() self.session.commit() # case-3: unhealthy scheduler status - scenario 2 (no running SchedulerJob) - self.session.query(BaseJob).\ - filter(BaseJob.job_type == 'SchedulerJob', - BaseJob.state == 'running').\ - delete() + self.session.query(BaseJob).filter( + BaseJob.job_type == 'SchedulerJob', BaseJob.state == 'running' + ).delete() self.session.commit() resp_json = json.loads(self.client.get('health', follow_redirects=True).data.decode('utf-8')) @@ -588,13 +595,15 @@ def test_home(self): with self.capture_templates() as templates: resp = self.client.get('home', follow_redirects=True) self.check_content_in_response('DAGs', resp) - val_state_color_mapping = 'const STATE_COLOR = {"failed": "red", ' \ - '"null": "lightblue", "queued": "gray", ' \ - '"removed": "lightgrey", "running": "lime", ' \ - '"scheduled": "tan", "sensing": "lightseagreen", ' \ - '"shutdown": "blue", "skipped": "pink", ' \ - '"success": "green", "up_for_reschedule": "turquoise", ' \ - '"up_for_retry": "gold", "upstream_failed": "orange"};' + val_state_color_mapping = ( + 'const STATE_COLOR = {"failed": "red", ' + '"null": "lightblue", "queued": "gray", ' + '"removed": "lightgrey", "running": "lime", ' + '"scheduled": "tan", "sensing": "lightseagreen", ' + '"shutdown": "blue", "skipped": "pink", ' + '"success": "green", "up_for_reschedule": "turquoise", ' + '"up_for_retry": "gold", "upstream_failed": "orange"};' + ) self.check_content_in_response(val_state_color_mapping, resp) self.assertEqual(len(templates), 1) @@ -629,6 +638,7 @@ def test_permissionsviews_list(self): def test_home_filter_tags(self): from airflow.www.views import FILTER_TAGS_COOKIE + with self.client: self.client.get('home?tags=example&tags=data', follow_redirects=True) self.assertEqual('example,data', flask_session[FILTER_TAGS_COOKIE]) @@ -638,6 +648,7 @@ def test_home_filter_tags(self): def test_home_status_filter_cookie(self): from airflow.www.views import 
FILTER_STATUS_COOKIE + with self.client: self.client.get('home', follow_redirects=True) self.assertEqual('all', flask_session[FILTER_STATUS_COOKIE]) @@ -652,20 +663,23 @@ def test_home_status_filter_cookie(self): self.assertEqual('all', flask_session[FILTER_STATUS_COOKIE]) def test_task(self): - url = ('task?task_id=runme_0&dag_id=example_bash_operator&execution_date={}' - .format(self.percent_encode(self.EXAMPLE_DAG_DEFAULT_DATE))) + url = 'task?task_id=runme_0&dag_id=example_bash_operator&execution_date={}'.format( + self.percent_encode(self.EXAMPLE_DAG_DEFAULT_DATE) + ) resp = self.client.get(url, follow_redirects=True) self.check_content_in_response('Task Instance Details', resp) def test_xcom(self): - url = ('xcom?task_id=runme_0&dag_id=example_bash_operator&execution_date={}' - .format(self.percent_encode(self.EXAMPLE_DAG_DEFAULT_DATE))) + url = 'xcom?task_id=runme_0&dag_id=example_bash_operator&execution_date={}'.format( + self.percent_encode(self.EXAMPLE_DAG_DEFAULT_DATE) + ) resp = self.client.get(url, follow_redirects=True) self.check_content_in_response('XCom', resp) def test_rendered(self): - url = ('rendered?task_id=runme_0&dag_id=example_bash_operator&execution_date={}' - .format(self.percent_encode(self.EXAMPLE_DAG_DEFAULT_DATE))) + url = 'rendered?task_id=runme_0&dag_id=example_bash_operator&execution_date={}'.format( + self.percent_encode(self.EXAMPLE_DAG_DEFAULT_DATE) + ) resp = self.client.get(url, follow_redirects=True) self.check_content_in_response('Rendered Template', resp) @@ -681,8 +695,7 @@ def test_dag_stats(self): def test_task_stats(self): resp = self.client.post('task_stats', follow_redirects=True) self.assertEqual(resp.status_code, 200) - self.assertEqual(set(list(resp.json.items())[0][1][0].keys()), - {'state', 'count'}) + self.assertEqual(set(list(resp.json.items())[0][1][0].keys()), {'state', 'count'}) @conf_vars({("webserver", "show_recent_stats_for_completed_runs"): "False"}) def test_task_stats_only_noncompleted(self): @@ -704,12 +717,14 @@ def test_view_uses_existing_dagbag(self, endpoint): resp = self.client.get(url, follow_redirects=True) self.check_content_in_response('example_bash_operator', resp) - @parameterized.expand([ - ("hello\nworld", r'\"conf\":{\"abc\":\"hello\\nworld\"}'), - ("hello'world", r'\"conf\":{\"abc\":\"hello\\u0027world\"}'), - ("
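Note on the formatting pattern repeated throughout these hunks: Black only explodes a call or collection across multiple lines when it exceeds the configured line length, or when the source carries a trailing ("magic") comma after the last element; otherwise it collapses the brackets onto one line. A minimal, self-contained sketch of the two layouts, assuming black 20.8b1 and the line length of 110 that this repository configures in pyproject.toml (that file is not part of this patch):

# Fits within the line length and has no trailing comma: Black keeps it on one line,
# as with the engine_params dict at the top of this section.
engine_params = {'pool_size': 3, 'pool_recycle': 120, 'pool_pre_ping': True, 'max_overflow': 5}

# Trailing ("magic") comma after the last element: Black keeps one element per line,
# as with the exploded fab_utils.create_user(...) calls in tests/www/test_security.py above.
engine_params_exploded = {
    'pool_size': 3,
    'pool_recycle': 120,
    'pool_pre_ping': True,
    'max_overflow': 5,
}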