From 72d702aaa77609503d1fc998778df1ad7c59028a Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Tue, 25 Aug 2020 10:11:12 +0100 Subject: [PATCH 1/8] Black --- .pre-commit-config.yaml | 5 +++++ pyproject.toml | 26 ++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) create mode 100644 pyproject.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f15fa94686370..0d6c91334ff98 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -149,6 +149,11 @@ repos: - repo: meta hooks: - id: check-hooks-apply + - repo: https://github.com/psf/black + rev: stable + hooks: + - id: black + files: .*providers.*\.py - repo: https://github.com/pre-commit/pre-commit-hooks rev: v3.2.0 hooks: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000..920b195075a48 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,26 @@ +[tool.black] +line-length = 110 +target-version = ['py36', 'py37', 'py38'] +skip-string-normalization = true +include = '\.pyi?$' +exclude = ''' +/( + \.eggs + | \.build + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + | logs +)/ +''' + +[tool.isort] +line_length = 110 +combine_as_imports = true +profile = "black" From 0be5bfcae54dd5393b2a6a09616581b2e10ff1b0 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Tue, 25 Aug 2020 10:18:17 +0100 Subject: [PATCH 2/8] Enable Black on api_connexion --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0d6c91334ff98..a484f86a5c0dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -153,7 +153,7 @@ repos: rev: stable hooks: - id: black - files: .*providers.*\.py + files: ^api_connexion/.*\.py - repo: https://github.com/pre-commit/pre-commit-hooks rev: v3.2.0 hooks: From 1a29fb31736a35aa850e36323568739758393081 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Tue, 25 Aug 2020 10:18:44 +0100 Subject: [PATCH 3/8] fixup! 
Enable Black on api_connexion --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a484f86a5c0dc..178e772b1274d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -153,7 +153,7 @@ repos: rev: stable hooks: - id: black - files: ^api_connexion/.*\.py + files: api_connexion/.*\.py - repo: https://github.com/pre-commit/pre-commit-hooks rev: v3.2.0 hooks: From 4d991cd0bd64455b87c0304070be508348207623 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Tue, 25 Aug 2020 10:33:33 +0100 Subject: [PATCH 4/8] Enable Black on api_connexion folder and its tests --- .flake8 | 2 +- .pre-commit-config.yaml | 1 + .../endpoints/config_endpoint.py | 12 +- .../endpoints/connection_endpoint.py | 14 +- .../api_connexion/endpoints/dag_endpoint.py | 9 +- .../endpoints/dag_run_endpoint.py | 105 ++++++--- .../endpoints/event_log_endpoint.py | 13 +- .../endpoints/health_endpoint.py | 5 +- .../endpoints/import_error_endpoint.py | 8 +- .../api_connexion/endpoints/log_endpoint.py | 12 +- .../api_connexion/endpoints/pool_endpoint.py | 8 +- .../endpoints/variable_endpoint.py | 9 +- .../endpoints/version_endpoint.py | 1 + .../api_connexion/endpoints/xcom_endpoint.py | 25 +- airflow/api_connexion/exceptions.py | 11 +- airflow/api_connexion/parameters.py | 4 +- .../api_connexion/schemas/common_schema.py | 13 +- .../api_connexion/schemas/config_schema.py | 6 + .../schemas/connection_schema.py | 3 + .../api_connexion/schemas/dag_run_schema.py | 2 + airflow/api_connexion/schemas/dag_schema.py | 2 + .../schemas/dag_source_schema.py | 1 + airflow/api_connexion/schemas/enum_schemas.py | 5 +- .../api_connexion/schemas/event_log_schema.py | 2 + .../api_connexion/schemas/health_schema.py | 2 + airflow/api_connexion/schemas/log_schema.py | 1 + airflow/api_connexion/schemas/pool_schema.py | 1 + airflow/api_connexion/schemas/task_schema.py | 9 +- .../api_connexion/schemas/variable_schema.py | 2 + .../api_connexion/schemas/version_schema.py | 1 + airflow/api_connexion/schemas/xcom_schema.py | 3 + airflow/api_connexion/security.py | 1 + .../endpoints/test_config_endpoint.py | 42 ++-- .../endpoints/test_connection_endpoint.py | 219 ++++++------------ .../endpoints/test_dag_endpoint.py | 89 ++----- .../endpoints/test_dag_run_endpoint.py | 157 +++++-------- .../endpoints/test_dag_source_endpoint.py | 52 ++--- .../endpoints/test_event_log_endpoint.py | 81 ++----- .../endpoints/test_extra_link_endpoint.py | 10 +- .../endpoints/test_health_endpoint.py | 4 +- .../endpoints/test_import_error_endpoint.py | 19 +- .../endpoints/test_log_endpoint.py | 75 +++--- .../endpoints/test_pool_endpoint.py | 83 ++----- .../endpoints/test_task_endpoint.py | 18 +- .../endpoints/test_task_instance_endpoint.py | 6 +- .../endpoints/test_variable_endpoint.py | 127 ++++------ .../endpoints/test_version_endpoint.py | 8 +- .../endpoints/test_xcom_endpoint.py | 127 +++++----- .../schemas/test_common_schema.py | 16 +- .../schemas/test_config_schema.py | 21 +- .../schemas/test_connection_schema.py | 67 +++--- .../schemas/test_dag_run_schema.py | 19 +- .../api_connexion/schemas/test_dag_schema.py | 7 +- .../schemas/test_error_schema.py | 4 +- .../schemas/test_event_log_schema.py | 48 ++-- .../schemas/test_health_schema.py | 5 +- .../api_connexion/schemas/test_task_schema.py | 9 +- .../schemas/test_version_schema.py | 8 +- .../api_connexion/schemas/test_xcom_schema.py | 28 ++- tests/api_connexion/test_parameters.py | 3 - 60 files changed, 629 insertions(+),
1016 deletions(-) diff --git a/.flake8 b/.flake8 index f9bf91ec16ebc..cffaf32f138d9 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,6 @@ [flake8] max-line-length = 110 -ignore = E731,W504,I001,W503 +ignore = E231,E731,W504,I001,W503 exclude = .svn,CVS,.bzr,.hg,.git,__pycache__,.eggs,*.egg,node_modules format = ${cyan}%(path)s${reset}:${yellow_bold}%(row)d${reset}:${green_bold}%(col)d${reset}: ${red_bold}%(code)s${reset} %(text)s per-file-ignores = diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 178e772b1274d..88da0d928315c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -154,6 +154,7 @@ repos: hooks: - id: black files: api_connexion/.*\.py + args: [--config=./pyproject.toml] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v3.2.0 hooks: diff --git a/airflow/api_connexion/endpoints/config_endpoint.py b/airflow/api_connexion/endpoints/config_endpoint.py index e813af11afaf3..675fcd2df3597 100644 --- a/airflow/api_connexion/endpoints/config_endpoint.py +++ b/airflow/api_connexion/endpoints/config_endpoint.py @@ -30,11 +30,7 @@ def _conf_dict_to_config(conf_dict: dict) -> Config: config = Config( sections=[ ConfigSection( - name=section, - options=[ - ConfigOption(key=key, value=value) - for key, value in options.items() - ] + name=section, options=[ConfigOption(key=key, value=value) for key, value in options.items()] ) for section, options in conf_dict.items() ] @@ -49,8 +45,10 @@ def _option_to_text(config_option: ConfigOption) -> str: def _section_to_text(config_section: ConfigSection) -> str: """Convert a single config section to text""" - return (f'[{config_section.name}]{LINE_SEP}' - f'{LINE_SEP.join(_option_to_text(option) for option in config_section.options)}{LINE_SEP}') + return ( + f'[{config_section.name}]{LINE_SEP}' + f'{LINE_SEP.join(_option_to_text(option) for option in config_section.options)}{LINE_SEP}' + ) def _config_to_text(config: Config) -> str: diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py b/airflow/api_connexion/endpoints/connection_endpoint.py index d5ed5fbd043eb..daab4f50a8b39 100644 --- a/airflow/api_connexion/endpoints/connection_endpoint.py +++ b/airflow/api_connexion/endpoints/connection_endpoint.py @@ -24,7 +24,10 @@ from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound from airflow.api_connexion.parameters import check_limit, format_parameters from airflow.api_connexion.schemas.connection_schema import ( - ConnectionCollection, connection_collection_item_schema, connection_collection_schema, connection_schema, + ConnectionCollection, + connection_collection_item_schema, + connection_collection_schema, + connection_schema, ) from airflow.models import Connection from airflow.utils.session import provide_session @@ -56,9 +59,7 @@ def get_connection(connection_id, session): @security.requires_authentication -@format_parameters({ - 'limit': check_limit -}) +@format_parameters({'limit': check_limit}) @provide_session def get_connections(session, limit, offset=0): """ @@ -67,8 +68,9 @@ def get_connections(session, limit, offset=0): total_entries = session.query(func.count(Connection.id)).scalar() query = session.query(Connection) connections = query.order_by(Connection.id).offset(offset).limit(limit).all() - return connection_collection_schema.dump(ConnectionCollection(connections=connections, - total_entries=total_entries)) + return connection_collection_schema.dump( + ConnectionCollection(connections=connections, total_entries=total_entries) + ) 
@security.requires_authentication diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index 27634885f8d11..6f6c0bf9d9938 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -23,7 +23,10 @@ from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.parameters import check_limit, format_parameters from airflow.api_connexion.schemas.dag_schema import ( - DAGCollection, dag_detail_schema, dag_schema, dags_collection_schema, + DAGCollection, + dag_detail_schema, + dag_schema, + dags_collection_schema, ) from airflow.models.dag import DagModel from airflow.utils.session import provide_session @@ -55,9 +58,7 @@ def get_dag_details(dag_id): @security.requires_authentication -@format_parameters({ - 'limit': check_limit -}) +@format_parameters({'limit': check_limit}) @provide_session def get_dags(session, limit, offset=0): """ diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index fb1fd1c4855e1..0a66b06ba1c6b 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -23,7 +23,10 @@ from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound from airflow.api_connexion.parameters import check_limit, format_datetime, format_parameters from airflow.api_connexion.schemas.dag_run_schema import ( - DAGRunCollection, dagrun_collection_schema, dagrun_schema, dagruns_batch_form_schema, + DAGRunCollection, + dagrun_collection_schema, + dagrun_schema, + dagruns_batch_form_schema, ) from airflow.models import DagModel, DagRun from airflow.utils.session import provide_session @@ -36,11 +39,7 @@ def delete_dag_run(dag_id, dag_run_id, session): """ Delete a DAG Run """ - if ( - session.query(DagRun) - .filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id) - .delete() == 0 - ): + if session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).delete() == 0: raise NotFound(detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found") return NoContent, 204 @@ -51,23 +50,24 @@ def get_dag_run(dag_id, dag_run_id, session): """ Get a DAG Run. 
""" - dag_run = session.query(DagRun).filter( - DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() + dag_run = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() if dag_run is None: raise NotFound("DAGRun not found") return dagrun_schema.dump(dag_run) @security.requires_authentication -@format_parameters({ - 'start_date_gte': format_datetime, - 'start_date_lte': format_datetime, - 'execution_date_gte': format_datetime, - 'execution_date_lte': format_datetime, - 'end_date_gte': format_datetime, - 'end_date_lte': format_datetime, - 'limit': check_limit -}) +@format_parameters( + { + 'start_date_gte': format_datetime, + 'start_date_lte': format_datetime, + 'execution_date_gte': format_datetime, + 'execution_date_lte': format_datetime, + 'end_date_gte': format_datetime, + 'end_date_lte': format_datetime, + 'limit': check_limit, + } +) @provide_session def get_dag_runs( session, @@ -91,27 +91,52 @@ def get_dag_runs( if dag_id != "~": query = query.filter(DagRun.dag_id == dag_id) - dag_run, total_entries = _fetch_dag_runs(query, session, end_date_gte, end_date_lte, execution_date_gte, - execution_date_lte, start_date_gte, start_date_lte, - limit, offset) + dag_run, total_entries = _fetch_dag_runs( + query, + session, + end_date_gte, + end_date_lte, + execution_date_gte, + execution_date_lte, + start_date_gte, + start_date_lte, + limit, + offset, + ) - return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run, - total_entries=total_entries)) + return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run, total_entries=total_entries)) -def _fetch_dag_runs(query, session, end_date_gte, end_date_lte, - execution_date_gte, execution_date_lte, - start_date_gte, start_date_lte, limit, offset): - query = _apply_date_filters_to_query(query, end_date_gte, end_date_lte, execution_date_gte, - execution_date_lte, start_date_gte, start_date_lte) +def _fetch_dag_runs( + query, + session, + end_date_gte, + end_date_lte, + execution_date_gte, + execution_date_lte, + start_date_gte, + start_date_lte, + limit, + offset, +): + query = _apply_date_filters_to_query( + query, + end_date_gte, + end_date_lte, + execution_date_gte, + execution_date_lte, + start_date_gte, + start_date_lte, + ) # apply offset and limit dag_run = query.order_by(DagRun.id).offset(offset).limit(limit).all() total_entries = session.query(func.count(DagRun.id)).scalar() return dag_run, total_entries -def _apply_date_filters_to_query(query, end_date_gte, end_date_lte, execution_date_gte, - execution_date_lte, start_date_gte, start_date_lte): +def _apply_date_filters_to_query( + query, end_date_gte, end_date_lte, execution_date_gte, execution_date_lte, start_date_gte, start_date_lte +): # filter start date if start_date_gte: query = query.filter(DagRun.start_date >= start_date_gte) @@ -147,13 +172,20 @@ def get_dag_runs_batch(session): if data["dag_ids"]: query = query.filter(DagRun.dag_id.in_(data["dag_ids"])) - dag_runs, total_entries = _fetch_dag_runs(query, session, data["end_date_gte"], data["end_date_lte"], - data["execution_date_gte"], data["execution_date_lte"], - data["start_date_gte"], data["start_date_lte"], - data["page_limit"], data["page_offset"]) + dag_runs, total_entries = _fetch_dag_runs( + query, + session, + data["end_date_gte"], + data["end_date_lte"], + data["execution_date_gte"], + data["execution_date_lte"], + data["start_date_gte"], + data["start_date_lte"], + data["page_limit"], + data["page_offset"], + ) - return 
dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, - total_entries=total_entries)) + return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries)) @security.requires_authentication @@ -167,8 +199,7 @@ def post_dag_run(dag_id, session): post_body = dagrun_schema.load(request.json, session=session) dagrun_instance = ( - session.query(DagRun).filter( - DagRun.dag_id == dag_id, DagRun.run_id == post_body["run_id"]).first() + session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == post_body["run_id"]).first() ) if not dagrun_instance: dag_run = DagRun(dag_id=dag_id, run_type=DagRunType.MANUAL.value, **post_body) diff --git a/airflow/api_connexion/endpoints/event_log_endpoint.py b/airflow/api_connexion/endpoints/event_log_endpoint.py index edbdc23d07911..5ce4400821dbc 100644 --- a/airflow/api_connexion/endpoints/event_log_endpoint.py +++ b/airflow/api_connexion/endpoints/event_log_endpoint.py @@ -22,7 +22,9 @@ from airflow.api_connexion.exceptions import NotFound from airflow.api_connexion.parameters import check_limit, format_parameters from airflow.api_connexion.schemas.event_log_schema import ( - EventLogCollection, event_log_collection_schema, event_log_schema, + EventLogCollection, + event_log_collection_schema, + event_log_schema, ) from airflow.models import Log from airflow.utils.session import provide_session @@ -41,9 +43,7 @@ def get_event_log(event_log_id, session): @security.requires_authentication -@format_parameters({ - 'limit': check_limit -}) +@format_parameters({'limit': check_limit}) @provide_session def get_event_logs(session, limit, offset=None): """ @@ -52,5 +52,6 @@ def get_event_logs(session, limit, offset=None): total_entries = session.query(func.count(Log.id)).scalar() event_logs = session.query(Log).order_by(Log.id).offset(offset).limit(limit).all() - return event_log_collection_schema.dump(EventLogCollection(event_logs=event_logs, - total_entries=total_entries)) + return event_log_collection_schema.dump( + EventLogCollection(event_logs=event_logs, total_entries=total_entries) + ) diff --git a/airflow/api_connexion/endpoints/health_endpoint.py b/airflow/api_connexion/endpoints/health_endpoint.py index f3f18aebaf1ee..6927f960de4fb 100644 --- a/airflow/api_connexion/endpoints/health_endpoint.py +++ b/airflow/api_connexion/endpoints/health_endpoint.py @@ -40,10 +40,7 @@ def get_health(): payload = { "metadatabase": {"status": metadatabase_status}, - "scheduler": { - "status": scheduler_status, - "latest_scheduler_heartbeat": latest_scheduler_heartbeat, - }, + "scheduler": {"status": scheduler_status, "latest_scheduler_heartbeat": latest_scheduler_heartbeat,}, } return health_schema.dump(payload) diff --git a/airflow/api_connexion/endpoints/import_error_endpoint.py b/airflow/api_connexion/endpoints/import_error_endpoint.py index ca2b8cbfef672..12baeda9fca31 100644 --- a/airflow/api_connexion/endpoints/import_error_endpoint.py +++ b/airflow/api_connexion/endpoints/import_error_endpoint.py @@ -21,7 +21,9 @@ from airflow.api_connexion.exceptions import NotFound from airflow.api_connexion.parameters import check_limit, format_parameters from airflow.api_connexion.schemas.error_schema import ( - ImportErrorCollection, import_error_collection_schema, import_error_schema, + ImportErrorCollection, + import_error_collection_schema, + import_error_schema, ) from airflow.models.errors import ImportError # pylint: disable=redefined-builtin from airflow.utils.session import provide_session @@ -41,9 +43,7 @@ def 
get_import_error(import_error_id, session): @security.requires_authentication -@format_parameters({ - 'limit': check_limit -}) +@format_parameters({'limit': check_limit}) @provide_session def get_import_errors(session, limit, offset=None): """ diff --git a/airflow/api_connexion/endpoints/log_endpoint.py b/airflow/api_connexion/endpoints/log_endpoint.py index a0330e388aaba..3bb6d78fb6cc3 100644 --- a/airflow/api_connexion/endpoints/log_endpoint.py +++ b/airflow/api_connexion/endpoints/log_endpoint.py @@ -29,8 +29,7 @@ @security.requires_authentication @provide_session -def get_log(session, dag_id, dag_run_id, task_id, task_try_number, - full_content=False, token=None): +def get_log(session, dag_id, dag_run_id, task_id, task_try_number, full_content=False, token=None): """ Get logs for specific task instance """ @@ -77,13 +76,8 @@ def get_log(session, dag_id, dag_run_id, task_id, task_try_number, logs, metadata = task_log_reader.read_log_chunks(ti, task_try_number, metadata) logs = logs[0] if task_try_number is not None else logs token = URLSafeSerializer(key).dumps(metadata) - return logs_schema.dump(LogResponseObject(continuation_token=token, - content=logs) - ) + return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs)) # text/plain. Stream logs = task_log_reader.read_log_stream(ti, task_try_number, metadata) - return Response( - logs, - headers={"Content-Type": return_type} - ) + return Response(logs, headers={"Content-Type": return_type}) diff --git a/airflow/api_connexion/endpoints/pool_endpoint.py b/airflow/api_connexion/endpoints/pool_endpoint.py index 2ff599ab036d1..b9c51b04f6191 100644 --- a/airflow/api_connexion/endpoints/pool_endpoint.py +++ b/airflow/api_connexion/endpoints/pool_endpoint.py @@ -54,9 +54,7 @@ def get_pool(pool_name, session): @security.requires_authentication -@format_parameters({ - 'limit': check_limit -}) +@format_parameters({'limit': check_limit}) @provide_session def get_pools(session, limit, offset=None): """ @@ -65,9 +63,7 @@ def get_pools(session, limit, offset=None): total_entries = session.query(func.count(Pool.id)).scalar() pools = session.query(Pool).order_by(Pool.id).offset(offset).limit(limit).all() - return pool_collection_schema.dump( - PoolCollection(pools=pools, total_entries=total_entries) - ) + return pool_collection_schema.dump(PoolCollection(pools=pools, total_entries=total_entries)) @security.requires_authentication diff --git a/airflow/api_connexion/endpoints/variable_endpoint.py b/airflow/api_connexion/endpoints/variable_endpoint.py index 6c36599aa68d5..e76e5f2152c1f 100644 --- a/airflow/api_connexion/endpoints/variable_endpoint.py +++ b/airflow/api_connexion/endpoints/variable_endpoint.py @@ -51,9 +51,7 @@ def get_variable(variable_key: str) -> Response: @security.requires_authentication -@format_parameters({ - 'limit': check_limit -}) +@format_parameters({'limit': check_limit}) @provide_session def get_variables(session, limit: Optional[int], offset: Optional[int] = None) -> Response: """ @@ -66,10 +64,7 @@ def get_variables(session, limit: Optional[int], offset: Optional[int] = None) - if limit: query = query.limit(limit) variables = query.all() - return variable_collection_schema.dump({ - "variables": variables, - "total_entries": total_entries, - }) + return variable_collection_schema.dump({"variables": variables, "total_entries": total_entries,}) @security.requires_authentication diff --git a/airflow/api_connexion/endpoints/version_endpoint.py b/airflow/api_connexion/endpoints/version_endpoint.py index 
8c7cf8329f882..2175f0314e2b9 100644 --- a/airflow/api_connexion/endpoints/version_endpoint.py +++ b/airflow/api_connexion/endpoints/version_endpoint.py @@ -27,6 +27,7 @@ class VersionInfo(NamedTuple): """Version information""" + version: str git_version: Optional[str] diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py index 48d553d7b4bc5..af0cf1c5fde43 100644 --- a/airflow/api_connexion/endpoints/xcom_endpoint.py +++ b/airflow/api_connexion/endpoints/xcom_endpoint.py @@ -23,7 +23,10 @@ from airflow.api_connexion.exceptions import NotFound from airflow.api_connexion.parameters import check_limit, format_parameters from airflow.api_connexion.schemas.xcom_schema import ( - XComCollection, XComCollectionItemSchema, XComCollectionSchema, xcom_collection_item_schema, + XComCollection, + XComCollectionItemSchema, + XComCollectionSchema, + xcom_collection_item_schema, xcom_collection_schema, ) from airflow.models import DagRun as DR, XCom @@ -31,9 +34,7 @@ @security.requires_authentication -@format_parameters({ - 'limit': check_limit -}) +@format_parameters({'limit': check_limit}) @provide_session def get_xcom_entries( dag_id: str, @@ -41,7 +42,7 @@ def get_xcom_entries( task_id: str, session: Session, limit: Optional[int], - offset: Optional[int] = None + offset: Optional[int] = None, ) -> XComCollectionSchema: """ Get all XCom values @@ -57,9 +58,7 @@ def get_xcom_entries( query = query.filter(XCom.task_id == task_id) if dag_run_id != '~': query = query.filter(DR.run_id == dag_run_id) - query = query.order_by( - XCom.execution_date, XCom.task_id, XCom.dag_id, XCom.key - ) + query = query.order_by(XCom.execution_date, XCom.task_id, XCom.dag_id, XCom.key) total_entries = session.query(func.count(XCom.key)).scalar() query = query.offset(offset).limit(limit) return xcom_collection_schema.dump(XComCollection(xcom_entries=query.all(), total_entries=total_entries)) @@ -68,18 +67,12 @@ def get_xcom_entries( @security.requires_authentication @provide_session def get_xcom_entry( - dag_id: str, - task_id: str, - dag_run_id: str, - xcom_key: str, - session: Session + dag_id: str, task_id: str, dag_run_id: str, xcom_key: str, session: Session ) -> XComCollectionItemSchema: """ Get an XCom entry """ - query = session.query(XCom).filter(XCom.dag_id == dag_id, - XCom.task_id == task_id, - XCom.key == xcom_key) + query = session.query(XCom).filter(XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.key == xcom_key) query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.execution_date == DR.execution_date)) query = query.filter(DR.run_id == dag_run_id) diff --git a/airflow/api_connexion/exceptions.py b/airflow/api_connexion/exceptions.py index e88372722bed6..7e34f31a14666 100644 --- a/airflow/api_connexion/exceptions.py +++ b/airflow/api_connexion/exceptions.py @@ -21,40 +21,43 @@ class NotFound(ProblemException): """Raise when the object cannot be found""" + def __init__(self, title='Object not found', detail=None): super().__init__(status=404, title=title, detail=detail) class BadRequest(ProblemException): """Raise when the server processes a bad request""" + def __init__(self, title='Bad request', detail=None): super().__init__(status=400, title=title, detail=detail) class Unauthenticated(ProblemException): """Raise when the user is not authenticated""" + def __init__( - self, - title: str = 'Unauthorized', - detail: Optional[str] = None, - headers: Optional[Dict] = None, + self, title: str = 'Unauthorized', detail: Optional[str] = None, 
headers: Optional[Dict] = None, ): super().__init__(status=401, title=title, detail=detail, headers=headers) class PermissionDenied(ProblemException): """Raise when the user does not have the required permissions""" + def __init__(self, title='Forbidden', detail=None): super().__init__(status=403, title=title, detail=detail) class AlreadyExists(ProblemException): """Raise when the object already exists""" + def __init__(self, title='Object already exists', detail=None): super().__init__(status=409, title=title, detail=detail) class Unknown(ProblemException): """Returns a response body and status code for HTTP 500 exception""" + def __init__(self, title='Unknown server error', detail=None): super().__init__(status=500, title=title, detail=detail) diff --git a/airflow/api_connexion/parameters.py b/airflow/api_connexion/parameters.py index 6c99f74387028..f3c471a54f894 100644 --- a/airflow/api_connexion/parameters.py +++ b/airflow/api_connexion/parameters.py @@ -37,9 +37,7 @@ def format_datetime(value: str): try: return timezone.parse(value) except (ParserError, TypeError) as err: - raise BadRequest( - "Incorrect datetime argument", detail=str(err) - ) + raise BadRequest("Incorrect datetime argument", detail=str(err)) def check_limit(value: int): diff --git a/airflow/api_connexion/schemas/common_schema.py b/airflow/api_connexion/schemas/common_schema.py index 27fd413206710..7f1020cb01f14 100644 --- a/airflow/api_connexion/schemas/common_schema.py +++ b/airflow/api_connexion/schemas/common_schema.py @@ -30,6 +30,7 @@ class CronExpression(typing.NamedTuple): """Cron expression schema""" + value: str @@ -102,6 +103,7 @@ class ScheduleIntervalSchema(OneOfSchema): * RelativeDelta * CronExpression """ + type_field = "__type" type_schemas = { "TimeDelta": TimeDeltaSchema, @@ -129,20 +131,18 @@ def get_obj_type(self, obj): class ColorField(fields.String): """Schema for color property""" + def __init__(self, **metadata): super().__init__(**metadata) - self.validators = ( - [validate.Regexp("^#[a-fA-F0-9]{3,6}$")] + list(self.validators) - ) + self.validators = [validate.Regexp("^#[a-fA-F0-9]{3,6}$")] + list(self.validators) class WeightRuleField(fields.String): """Schema for WeightRule""" + def __init__(self, **metadata): super().__init__(**metadata) - self.validators = ( - [validate.OneOf(WeightRule.all_weight_rules())] + list(self.validators) - ) + self.validators = [validate.OneOf(WeightRule.all_weight_rules())] + list(self.validators) class TimezoneField(fields.String): @@ -153,6 +153,7 @@ class ClassReferenceSchema(Schema): """ Class reference schema. 
""" + module_path = fields.Method("_get_module", required=True) class_name = fields.Method("_get_class_name", required=True) diff --git a/airflow/api_connexion/schemas/config_schema.py b/airflow/api_connexion/schemas/config_schema.py index bb7be9cd33dfc..2eb459ce14263 100644 --- a/airflow/api_connexion/schemas/config_schema.py +++ b/airflow/api_connexion/schemas/config_schema.py @@ -22,35 +22,41 @@ class ConfigOptionSchema(Schema): """Config Option Schema""" + key = fields.String(required=True) value = fields.String(required=True) class ConfigOption(NamedTuple): """Config option""" + key: str value: str class ConfigSectionSchema(Schema): """Config Section Schema""" + name = fields.String(required=True) options = fields.List(fields.Nested(ConfigOptionSchema)) class ConfigSection(NamedTuple): """List of config options within a section""" + name: str options: List[ConfigOption] class ConfigSchema(Schema): """Config Schema""" + sections = fields.List(fields.Nested(ConfigSectionSchema)) class Config(NamedTuple): """List of config sections with their options""" + sections: List[ConfigSection] diff --git a/airflow/api_connexion/schemas/connection_schema.py b/airflow/api_connexion/schemas/connection_schema.py index e1d0a78cadc52..7c4e49f8bd100 100644 --- a/airflow/api_connexion/schemas/connection_schema.py +++ b/airflow/api_connexion/schemas/connection_schema.py @@ -30,6 +30,7 @@ class ConnectionCollectionItemSchema(SQLAlchemySchema): class Meta: """Meta""" + model = Connection connection_id = auto_field('conn_id', required=True) @@ -51,12 +52,14 @@ class ConnectionSchema(ConnectionCollectionItemSchema): # pylint: disable=too-m class ConnectionCollection(NamedTuple): """List of Connections with meta""" + connections: List[Connection] total_entries: int class ConnectionCollectionSchema(Schema): """Connection Collection Schema""" + connections = fields.List(fields.Nested(ConnectionCollectionItemSchema)) total_entries = fields.Int() diff --git a/airflow/api_connexion/schemas/dag_run_schema.py b/airflow/api_connexion/schemas/dag_run_schema.py index f13d3d787f015..fcab35771af56 100644 --- a/airflow/api_connexion/schemas/dag_run_schema.py +++ b/airflow/api_connexion/schemas/dag_run_schema.py @@ -30,6 +30,7 @@ class ConfObject(fields.Field): """The conf field""" + def _serialize(self, value, attr, obj, **kwargs): if not value: return {} @@ -94,6 +95,7 @@ class DagRunsBatchFormSchema(Schema): class Meta: """Meta""" + datetimeformat = 'iso' strict = True diff --git a/airflow/api_connexion/schemas/dag_schema.py b/airflow/api_connexion/schemas/dag_schema.py index ec07695eb403b..7a50799f1ada3 100644 --- a/airflow/api_connexion/schemas/dag_schema.py +++ b/airflow/api_connexion/schemas/dag_schema.py @@ -26,6 +26,7 @@ class DagTagSchema(SQLAlchemySchema): """Dag Tag schema""" + class Meta: """Meta""" @@ -39,6 +40,7 @@ class DAGSchema(SQLAlchemySchema): class Meta: """Meta""" + model = DagModel dag_id = auto_field(dump_only=True) diff --git a/airflow/api_connexion/schemas/dag_source_schema.py b/airflow/api_connexion/schemas/dag_source_schema.py index 6ce65c80ac642..d142454bc1f6d 100644 --- a/airflow/api_connexion/schemas/dag_source_schema.py +++ b/airflow/api_connexion/schemas/dag_source_schema.py @@ -20,6 +20,7 @@ class DagSourceSchema(Schema): """Dag Source schema""" + content = fields.String(dump_only=True) diff --git a/airflow/api_connexion/schemas/enum_schemas.py b/airflow/api_connexion/schemas/enum_schemas.py index 8e7280a877a00..352540a10f15c 100644 --- a/airflow/api_connexion/schemas/enum_schemas.py +++ 
b/airflow/api_connexion/schemas/enum_schemas.py @@ -22,8 +22,7 @@ class DagStateField(fields.String): """Schema for DagState Enum""" + def __init__(self, **metadata): super().__init__(**metadata) - self.validators = ( - [validate.OneOf(State.dag_states)] + list(self.validators) - ) + self.validators = [validate.OneOf(State.dag_states)] + list(self.validators) diff --git a/airflow/api_connexion/schemas/event_log_schema.py b/airflow/api_connexion/schemas/event_log_schema.py index c0b1bb280be9a..d97c223bffa23 100644 --- a/airflow/api_connexion/schemas/event_log_schema.py +++ b/airflow/api_connexion/schemas/event_log_schema.py @@ -28,6 +28,7 @@ class EventLogSchema(SQLAlchemySchema): class Meta: """Meta""" + model = Log id = auto_field(data_key='event_log_id', dump_only=True) @@ -42,6 +43,7 @@ class Meta: class EventLogCollection(NamedTuple): """List of import errors with metadata""" + event_logs: List[Log] total_entries: int diff --git a/airflow/api_connexion/schemas/health_schema.py b/airflow/api_connexion/schemas/health_schema.py index bccfc0ae0fb03..b8be4d88048b4 100644 --- a/airflow/api_connexion/schemas/health_schema.py +++ b/airflow/api_connexion/schemas/health_schema.py @@ -20,6 +20,7 @@ class BaseInfoSchema(Schema): """Base status field for metadatabase and scheduler""" + status = fields.String(dump_only=True) @@ -29,6 +30,7 @@ class MetaDatabaseInfoSchema(BaseInfoSchema): class SchedulerInfoSchema(BaseInfoSchema): """Schema for Metadatabase info""" + latest_scheduler_heartbeat = fields.String(dump_only=True) diff --git a/airflow/api_connexion/schemas/log_schema.py b/airflow/api_connexion/schemas/log_schema.py index b9b7817dc3b30..66f12e6dd202a 100644 --- a/airflow/api_connexion/schemas/log_schema.py +++ b/airflow/api_connexion/schemas/log_schema.py @@ -28,6 +28,7 @@ class LogsSchema(Schema): class LogResponseObject(NamedTuple): """Log Response Object""" + content: str continuation_token: str diff --git a/airflow/api_connexion/schemas/pool_schema.py b/airflow/api_connexion/schemas/pool_schema.py index 4b7f62eb13021..c1b92f3ba6b91 100644 --- a/airflow/api_connexion/schemas/pool_schema.py +++ b/airflow/api_connexion/schemas/pool_schema.py @@ -28,6 +28,7 @@ class PoolSchema(SQLAlchemySchema): class Meta: """Meta""" + model = Pool name = auto_field("pool") diff --git a/airflow/api_connexion/schemas/task_schema.py b/airflow/api_connexion/schemas/task_schema.py index 52a6a3034b7d6..d87123fb5d8d4 100644 --- a/airflow/api_connexion/schemas/task_schema.py +++ b/airflow/api_connexion/schemas/task_schema.py @@ -20,7 +20,10 @@ from marshmallow import Schema, fields from airflow.api_connexion.schemas.common_schema import ( - ClassReferenceSchema, ColorField, TimeDeltaSchema, WeightRuleField, + ClassReferenceSchema, + ColorField, + TimeDeltaSchema, + WeightRuleField, ) from airflow.api_connexion.schemas.dag_schema import DAGSchema from airflow.models.baseoperator import BaseOperator @@ -36,9 +39,7 @@ class TaskSchema(Schema): end_date = fields.DateTime(dump_only=True) trigger_rule = fields.String(dump_only=True) extra_links = fields.List( - fields.Nested(ClassReferenceSchema), - dump_only=True, - attribute="operator_extra_links" + fields.Nested(ClassReferenceSchema), dump_only=True, attribute="operator_extra_links" ) depends_on_past = fields.Boolean(dump_only=True) wait_for_downstream = fields.Boolean(dump_only=True) diff --git a/airflow/api_connexion/schemas/variable_schema.py b/airflow/api_connexion/schemas/variable_schema.py index 4c73f8c36e3d3..6b5d16e4227d6 100644 --- 
a/airflow/api_connexion/schemas/variable_schema.py +++ b/airflow/api_connexion/schemas/variable_schema.py @@ -20,12 +20,14 @@ class VariableSchema(Schema): """Variable Schema""" + key = fields.String(required=True) value = fields.String(attribute="val", required=True) class VariableCollectionSchema(Schema): """Variable Collection Schema""" + variables = fields.List(fields.Nested(VariableSchema)) total_entries = fields.Int() diff --git a/airflow/api_connexion/schemas/version_schema.py b/airflow/api_connexion/schemas/version_schema.py index e2ca25528b86c..24bd9337c1c36 100644 --- a/airflow/api_connexion/schemas/version_schema.py +++ b/airflow/api_connexion/schemas/version_schema.py @@ -20,6 +20,7 @@ class VersionInfoSchema(Schema): """Version information schema""" + version = fields.String(dump_only=True) git_version = fields.String(dump_only=True) diff --git a/airflow/api_connexion/schemas/xcom_schema.py b/airflow/api_connexion/schemas/xcom_schema.py index 268732534851f..01b93b5caf13b 100644 --- a/airflow/api_connexion/schemas/xcom_schema.py +++ b/airflow/api_connexion/schemas/xcom_schema.py @@ -29,6 +29,7 @@ class XComCollectionItemSchema(SQLAlchemySchema): class Meta: """Meta""" + model = XCom key = auto_field() @@ -48,12 +49,14 @@ class XComSchema(XComCollectionItemSchema): class XComCollection(NamedTuple): """List of XComs with meta""" + xcom_entries: List[XCom] total_entries: int class XComCollectionSchema(Schema): """XCom Collection Schema""" + xcom_entries = fields.List(fields.Nested(XComCollectionItemSchema)) total_entries = fields.Int() diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index 5ededfcad79da..ca01d3cd63b2d 100644 --- a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -27,6 +27,7 @@ def requires_authentication(function: T): """Decorator for functions that require authentication""" + @wraps(function) def decorated(*args, **kwargs): response = current_app.api_auth.requires_authentication(Response)() diff --git a/tests/api_connexion/endpoints/test_config_endpoint.py b/tests/api_connexion/endpoints/test_config_endpoint.py index 3d0c3b9046489..62a0c08c5975f 100644 --- a/tests/api_connexion/endpoints/test_config_endpoint.py +++ b/tests/api_connexion/endpoints/test_config_endpoint.py @@ -24,26 +24,16 @@ from tests.test_utils.config import conf_vars MOCK_CONF = { - 'core': { - 'parallelism': '1024', - }, - 'smtp': { - 'smtp_host': 'localhost', - 'smtp_mail_from': 'airflow@example.com', - }, + 'core': {'parallelism': '1024',}, + 'smtp': {'smtp_host': 'localhost', 'smtp_mail_from': 'airflow@example.com',}, } -@patch( - "airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", - return_value=MOCK_CONF -) +@patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) class TestGetConfig: @classmethod def setup_class(cls) -> None: - with conf_vars( - {("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"} - ): + with conf_vars({("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}): cls.app = app.create_app(testing=True) # type:ignore # TODO: Add new role for each view to test permission create_user(cls.app, username="test", role="Admin") # type: ignore @@ -62,37 +52,34 @@ def test_should_response_200_text_plain(self, mock_as_dict): "/api/v1/config", headers={'Accept': 'text/plain'}, environ_overrides={'REMOTE_USER': "test"} ) assert response.status_code == 200 - expected = textwrap.dedent("""\ + expected = textwrap.dedent( + """\ [core] 
parallelism = 1024 [smtp] smtp_host = localhost smtp_mail_from = airflow@example.com - """) + """ + ) assert expected == response.data.decode() def test_should_response_200_application_json(self, mock_as_dict): response = self.client.get( "/api/v1/config", headers={'Accept': 'application/json'}, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) assert response.status_code == 200 expected = { 'sections': [ - { - 'name': 'core', - 'options': [ - {'key': 'parallelism', 'value': '1024'}, - ] - }, + {'name': 'core', 'options': [{'key': 'parallelism', 'value': '1024'},]}, { 'name': 'smtp', 'options': [ {'key': 'smtp_host', 'value': 'localhost'}, {'key': 'smtp_mail_from', 'value': 'airflow@example.com'}, - ] + ], }, ] } @@ -102,14 +89,11 @@ def test_should_response_406(self, mock_as_dict): response = self.client.get( "/api/v1/config", headers={'Accept': 'application/octet-stream'}, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) assert response.status_code == 406 def test_should_raises_401_unauthenticated(self, mock_as_dict): - response = self.client.get( - "/api/v1/config", - headers={'Accept': 'application/json'} - ) + response = self.client.get("/api/v1/config", headers={'Accept': 'application/json'}) assert_401(response) diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py b/tests/api_connexion/endpoints/test_connection_endpoint.py index e45e013164c21..7bde11f90f191 100644 --- a/tests/api_connexion/endpoints/test_connection_endpoint.py +++ b/tests/api_connexion/endpoints/test_connection_endpoint.py @@ -30,9 +30,7 @@ class TestConnectionEndpoint(unittest.TestCase): @classmethod def setUpClass(cls) -> None: super().setUpClass() - with conf_vars( - {("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"} - ): + with conf_vars({("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}): cls.app = app.create_app(testing=True) # type:ignore # TODO: Add new role for each view to test permission. 
create_user(cls.app, username="test", role="Admin") # type: ignore @@ -50,18 +48,15 @@ def tearDown(self) -> None: clear_db_connections() def _create_connection(self, session): - connection_model = Connection(conn_id='test-connection-id', - conn_type='test_type') + connection_model = Connection(conn_id='test-connection-id', conn_type='test_type') session.add(connection_model) session.commit() class TestDeleteConnection(TestConnectionEndpoint): - @provide_session def test_delete_should_response_204(self, session): - connection_model = Connection(conn_id='test-connection', - conn_type='test_type') + connection_model = Connection(conn_id='test-connection', conn_type='test_type') session.add(connection_model) session.commit() @@ -81,12 +76,7 @@ def test_delete_should_response_404(self): assert response.status_code == 404 self.assertEqual( response.json, - { - 'detail': None, - 'status': 404, - 'title': 'Connection not found', - 'type': 'about:blank' - } + {'detail': None, 'status': 404, 'title': 'Connection not found', 'type': 'about:blank'}, ) def test_should_raises_401_unauthenticated(self): @@ -96,16 +86,16 @@ def test_should_raises_401_unauthenticated(self): class TestGetConnection(TestConnectionEndpoint): - @provide_session def test_should_response_200(self, session): - connection_model = Connection(conn_id='test-connection-id', - conn_type='mysql', - host='mysql', - login='login', - schema='testschema', - port=80 - ) + connection_model = Connection( + conn_id='test-connection-id', + conn_type='mysql', + host='mysql', + login='login', + schema='testschema', + port=80, + ) session.add(connection_model) session.commit() result = session.query(Connection).all() @@ -122,7 +112,7 @@ def test_should_response_200(self, session): "host": 'mysql', "login": 'login', 'schema': 'testschema', - 'port': 80 + 'port': 80, }, ) @@ -132,13 +122,8 @@ def test_should_response_404(self): ) assert response.status_code == 404 self.assertEqual( - { - 'detail': None, - 'status': 404, - 'title': 'Connection not found', - 'type': 'about:blank' - }, - response.json + {'detail': None, 'status': 404, 'title': 'Connection not found', 'type': 'about:blank'}, + response.json, ) def test_should_raises_401_unauthenticated(self): @@ -148,21 +133,16 @@ def test_should_raises_401_unauthenticated(self): class TestGetConnections(TestConnectionEndpoint): - @provide_session def test_should_response_200(self, session): - connection_model_1 = Connection(conn_id='test-connection-id-1', - conn_type='test_type') - connection_model_2 = Connection(conn_id='test-connection-id-2', - conn_type='test_type') + connection_model_1 = Connection(conn_id='test-connection-id-1', conn_type='test_type') + connection_model_2 = Connection(conn_id='test-connection-id-2', conn_type='test_type') connections = [connection_model_1, connection_model_2] session.add_all(connections) session.commit() result = session.query(Connection).all() assert len(result) == 2 - response = self.client.get( - "/api/v1/connections", environ_overrides={'REMOTE_USER': "test"} - ) + response = self.client.get("/api/v1/connections", environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 200 self.assertEqual( response.json, @@ -174,7 +154,7 @@ def test_should_response_200(self, session): "host": None, "login": None, 'schema': None, - 'port': None + 'port': None, }, { "connection_id": "test-connection-id-2", @@ -182,11 +162,11 @@ def test_should_response_200(self, session): "host": None, "login": None, 'schema': None, - 'port': None - } + 'port': None, + }, ], - 
'total_entries': 2 - } + 'total_entries': 2, + }, ) def test_should_raises_401_unauthenticated(self): @@ -196,20 +176,13 @@ def test_should_raises_401_unauthenticated(self): class TestGetConnectionsPagination(TestConnectionEndpoint): - @parameterized.expand( [ ("/api/v1/connections?limit=1", ['TEST_CONN_ID1']), ("/api/v1/connections?limit=2", ['TEST_CONN_ID1', "TEST_CONN_ID2"]), ( "/api/v1/connections?offset=5", - [ - "TEST_CONN_ID6", - "TEST_CONN_ID7", - "TEST_CONN_ID8", - "TEST_CONN_ID9", - "TEST_CONN_ID10", - ], + ["TEST_CONN_ID6", "TEST_CONN_ID7", "TEST_CONN_ID8", "TEST_CONN_ID9", "TEST_CONN_ID10",], ), ( "/api/v1/connections?offset=0", @@ -228,10 +201,7 @@ class TestGetConnectionsPagination(TestConnectionEndpoint): ), ("/api/v1/connections?limit=1&offset=5", ["TEST_CONN_ID6"]), ("/api/v1/connections?limit=1&offset=1", ["TEST_CONN_ID2"]), - ( - "/api/v1/connections?limit=2&offset=2", - ["TEST_CONN_ID3", "TEST_CONN_ID4"], - ), + ("/api/v1/connections?limit=2&offset=2", ["TEST_CONN_ID3", "TEST_CONN_ID4"],), ] ) @provide_session @@ -281,28 +251,17 @@ def test_should_return_conf_max_if_req_max_above_conf(self, session): self.assertEqual(len(response.json['connections']), 150) def _create_connections(self, count): - return [Connection( - conn_id='TEST_CONN_ID' + str(i), - conn_type='TEST_CONN_TYPE' + str(i) - ) for i in range(1, count + 1)] + return [ + Connection(conn_id='TEST_CONN_ID' + str(i), conn_type='TEST_CONN_TYPE' + str(i)) + for i in range(1, count + 1) + ] class TestPatchConnection(TestConnectionEndpoint): - @parameterized.expand( [ - ( - { - "connection_id": "test-connection-id", - "conn_type": 'test_type', - "extra": "{'key': 'var'}" - }, - ), - ( - { - "extra": "{'key': 'var'}" - }, - ) + ({"connection_id": "test-connection-id", "conn_type": 'test_type', "extra": "{'key': 'var'}"},), + ({"extra": "{'key': 'var'}"},), ] ) @provide_session @@ -310,9 +269,7 @@ def test_patch_should_response_200(self, payload, session): self._create_connection(session) response = self.client.patch( - "/api/v1/connections/test-connection-id", - json=payload, - environ_overrides={'REMOTE_USER': "test"} + "/api/v1/connections/test-connection-id", json=payload, environ_overrides={'REMOTE_USER': "test"} ) assert response.status_code == 200 @@ -330,7 +287,7 @@ def test_patch_should_response_200_with_update_mask(self, session): response = self.client.patch( "/api/v1/connections/test-connection-id?update_mask=port,login", json=payload, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) assert response.status_code == 200 connection = session.query(Connection).filter_by(conn_id=test_connection).first() @@ -344,9 +301,8 @@ def test_patch_should_response_200_with_update_mask(self, session): 'login': "login", # updated "port": 80, # updated "schema": None, - "host": None - - } + "host": None, + }, ) @parameterized.expand( @@ -360,8 +316,7 @@ def test_patch_should_response_200_with_update_mask(self, session): "port": 80, }, 'update_mask=ports, login', # posts is unknown - "'ports' is unknown or cannot be updated." - + "'ports' is unknown or cannot be updated.", ), ( { @@ -372,8 +327,7 @@ def test_patch_should_response_200_with_update_mask(self, session): "port": 80, }, 'update_mask=port, login, conn_id', # conn_id is unknown - "'conn_id' is unknown or cannot be updated." 
- + "'conn_id' is unknown or cannot be updated.", ), ( { @@ -384,8 +338,7 @@ def test_patch_should_response_200_with_update_mask(self, session): "port": 80, }, 'update_mask=port, login, connection_id', # connection_id cannot be updated - "'connection_id' is unknown or cannot be updated." - + "'connection_id' is unknown or cannot be updated.", ), ( { @@ -394,7 +347,7 @@ def test_patch_should_response_200_with_update_mask(self, session): "login": "login", }, '', # not necessary - "The connection_id cannot be updated." + "The connection_id cannot be updated.", ), ] ) @@ -406,7 +359,7 @@ def test_patch_should_response_400_for_invalid_fields_in_update_mask( response = self.client.patch( f"/api/v1/connections/test-connection-id?{update_mask}", json=payload, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) assert response.status_code == 400 self.assertEqual(response.json['title'], error_message) @@ -418,14 +371,16 @@ def test_patch_should_response_400_for_invalid_fields_in_update_mask( "connection_id": "test-connection-id", "conn_type": "test-type", "extra": 0, # expected string - }, "0 is not of type 'string' - 'extra'" + }, + "0 is not of type 'string' - 'extra'", ), ( { "connection_id": "test-connection-id", "conn_type": "test-type", "extras": "{}", # extras not a known field e.g typo - }, "extras" + }, + "extras", ), ( { @@ -433,43 +388,29 @@ def test_patch_should_response_400_for_invalid_fields_in_update_mask( "conn_type": "test-type", "invalid_field": "invalid field", # unknown field "_password": "{}", # _password not a known field - }, "_password" + }, + "_password", ), ] ) @provide_session - def test_patch_should_response_400_for_invalid_update( - self, payload, error_message, session - ): + def test_patch_should_response_400_for_invalid_update(self, payload, error_message, session): self._create_connection(session) response = self.client.patch( - "/api/v1/connections/test-connection-id", - json=payload, - environ_overrides={'REMOTE_USER': "test"} + "/api/v1/connections/test-connection-id", json=payload, environ_overrides={'REMOTE_USER': "test"} ) assert response.status_code == 400 self.assertIn(error_message, response.json['detail']) def test_patch_should_response_404_not_found(self): - payload = { - "connection_id": "test-connection-id", - "conn_type": "test-type", - "port": 90 - } + payload = {"connection_id": "test-connection-id", "conn_type": "test-type", "port": 90} response = self.client.patch( - "/api/v1/connections/test-connection-id", - json=payload, - environ_overrides={'REMOTE_USER': "test"} + "/api/v1/connections/test-connection-id", json=payload, environ_overrides={'REMOTE_USER': "test"} ) assert response.status_code == 404 self.assertEqual( - { - 'detail': None, - 'status': 404, - 'title': 'Connection not found', - 'type': 'about:blank' - }, - response.json + {'detail': None, 'status': 404, 'title': 'Connection not found', 'type': 'about:blank'}, + response.json, ) @provide_session @@ -478,28 +419,18 @@ def test_should_raises_401_unauthenticated(self, session): response = self.client.patch( "/api/v1/connections/test-connection-id", - json={ - "connection_id": "test-connection-id", - "conn_type": 'test_type', - "extra": "{'key': 'var'}" - } + json={"connection_id": "test-connection-id", "conn_type": 'test_type', "extra": "{'key': 'var'}"}, ) assert_401(response) class TestPostConnection(TestConnectionEndpoint): - @provide_session def test_post_should_response_200(self, session): - payload = { - "connection_id": "test-connection-id", 
- "conn_type": 'test_type' - } + payload = {"connection_id": "test-connection-id", "conn_type": 'test_type'} response = self.client.post( - "/api/v1/connections", - json=payload, - environ_overrides={'REMOTE_USER': "test"} + "/api/v1/connections", json=payload, environ_overrides={'REMOTE_USER': "test"} ) assert response.status_code == 200 connection = session.query(Connection).all() @@ -511,29 +442,29 @@ def test_post_should_response_400_for_invalid_payload(self): "connection_id": "test-connection-id", } # conn_type missing response = self.client.post( - "/api/v1/connections", - json=payload, - environ_overrides={'REMOTE_USER': "test"} + "/api/v1/connections", json=payload, environ_overrides={'REMOTE_USER': "test"} ) assert response.status_code == 400 - self.assertEqual(response.json, - {'detail': "{'conn_type': ['Missing data for required field.']}", - 'status': 400, - 'title': 'Bad request', - 'type': 'about:blank'} - ) + self.assertEqual( + response.json, + { + 'detail': "{'conn_type': ['Missing data for required field.']}", + 'status': 400, + 'title': 'Bad request', + 'type': 'about:blank', + }, + ) def test_post_should_response_409_already_exist(self): - payload = { - "connection_id": "test-connection-id", - "conn_type": 'test_type' - } + payload = {"connection_id": "test-connection-id", "conn_type": 'test_type'} response = self.client.post( - "/api/v1/connections", json=payload, environ_overrides={'REMOTE_USER': "test"}) + "/api/v1/connections", json=payload, environ_overrides={'REMOTE_USER': "test"} + ) assert response.status_code == 200 # Another request response = self.client.post( - "/api/v1/connections", json=payload, environ_overrides={'REMOTE_USER': "test"}) + "/api/v1/connections", json=payload, environ_overrides={'REMOTE_USER': "test"} + ) assert response.status_code == 409 self.assertEqual( response.json, @@ -541,17 +472,13 @@ def test_post_should_response_409_already_exist(self): 'detail': None, 'status': 409, 'title': 'Connection already exist. ID: test-connection-id', - 'type': 'about:blank' - } + 'type': 'about:blank', + }, ) def test_should_raises_401_unauthenticated(self): response = self.client.post( - "/api/v1/connections", - json={ - "connection_id": "test-connection-id", - "conn_type": 'test_type' - } + "/api/v1/connections", json={"connection_id": "test-connection-id", "conn_type": 'test_type'} ) assert_401(response) diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index 66c7cf7ec8ea8..e219a285908ae 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -44,9 +44,7 @@ def clean_db(): @classmethod def setUpClass(cls) -> None: super().setUpClass() - with conf_vars( - {("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"} - ): + with conf_vars({("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}): cls.app = app.create_app(testing=True) # type:ignore # TODO: Add new role for each view to test permission. 
create_user(cls.app, username="test", role="Admin") # type: ignore @@ -74,9 +72,7 @@ def tearDown(self) -> None: def _create_dag_models(self, count, session=None): for num in range(1, count + 1): dag_model = DagModel( - dag_id=f"TEST_DAG_{num}", - fileloc=f"/tmp/dag_{num}.py", - schedule_interval="2 2 * * *", + dag_id=f"TEST_DAG_{num}", fileloc=f"/tmp/dag_{num}.py", schedule_interval="2 2 * * *", ) session.add(dag_model) @@ -135,12 +131,7 @@ def test_should_response_200(self): "is_subdag": False, "orientation": "LR", "owners": [], - "schedule_interval": { - "__type": "TimeDelta", - "days": 1, - "microseconds": 0, - "seconds": 0, - }, + "schedule_interval": {"__type": "TimeDelta", "days": 1, "microseconds": 0, "seconds": 0,}, "start_date": "2020-06-15T00:00:00+00:00", "tags": None, "timezone": "Timezone('UTC')", @@ -169,12 +160,7 @@ def test_should_response_200_serialized(self): "is_subdag": False, "orientation": "LR", "owners": [], - "schedule_interval": { - "__type": "TimeDelta", - "days": 1, - "microseconds": 0, - "seconds": 0, - }, + "schedule_interval": {"__type": "TimeDelta", "days": 1, "microseconds": 0, "seconds": 0,}, "start_date": "2020-06-15T00:00:00+00:00", "tags": None, "timezone": "Timezone('UTC')", @@ -202,15 +188,10 @@ def test_should_response_200_serialized(self): 'is_subdag': False, 'orientation': 'LR', 'owners': [], - 'schedule_interval': { - '__type': 'TimeDelta', - 'days': 1, - 'microseconds': 0, - 'seconds': 0 - }, + 'schedule_interval': {'__type': 'TimeDelta', 'days': 1, 'microseconds': 0, 'seconds': 0}, 'start_date': '2020-06-15T00:00:00+00:00', 'tags': None, - 'timezone': "Timezone('UTC')" + 'timezone': "Timezone('UTC')", } assert response.json == expected @@ -239,10 +220,7 @@ def test_should_response_200(self): "is_subdag": False, "owners": [], "root_dag_id": None, - "schedule_interval": { - "__type": "CronExpression", - "value": "2 2 * * *", - }, + "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *",}, "tags": [], }, { @@ -253,10 +231,7 @@ def test_should_response_200(self): "is_subdag": False, "owners": [], "root_dag_id": None, - "schedule_interval": { - "__type": "CronExpression", - "value": "2 2 * * *", - }, + "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *",}, "tags": [], }, ], @@ -269,10 +244,7 @@ def test_should_response_200(self): [ ("api/v1/dags?limit=1", ["TEST_DAG_1"]), ("api/v1/dags?limit=2", ["TEST_DAG_1", "TEST_DAG_10"]), - ( - "api/v1/dags?offset=5", - ["TEST_DAG_5", "TEST_DAG_6", "TEST_DAG_7", "TEST_DAG_8", "TEST_DAG_9"], - ), + ("api/v1/dags?offset=5", ["TEST_DAG_5", "TEST_DAG_6", "TEST_DAG_7", "TEST_DAG_8", "TEST_DAG_9"],), ( "api/v1/dags?offset=0", [ @@ -326,10 +298,8 @@ def test_should_response_200_on_patch_is_paused(self): dag_model = self._create_dag_model() response = self.client.patch( f"/api/v1/dags/{dag_model.dag_id}", - json={ - "is_paused": False, - }, - environ_overrides={'REMOTE_USER': "test"} + json={"is_paused": False,}, + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual(response.status_code, 200) expected_response = { @@ -340,10 +310,7 @@ def test_should_response_200_on_patch_is_paused(self): "is_subdag": False, "owners": [], "root_dag_id": None, - "schedule_interval": { - "__type": "CronExpression", - "value": "2 2 * * *", - }, + "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *",}, "tags": [], } self.assertEqual(response.json, expected_response) @@ -351,20 +318,20 @@ def test_should_response_200_on_patch_is_paused(self): def 
test_should_response_400_on_invalid_request(self): patch_body = { "is_paused": True, - "schedule_interval": { - "__type": "CronExpression", - "value": "1 1 * * *", - }, + "schedule_interval": {"__type": "CronExpression", "value": "1 1 * * *",}, } dag_model = self._create_dag_model() response = self.client.patch(f"/api/v1/dags/{dag_model.dag_id}", json=patch_body) self.assertEqual(response.status_code, 400) - self.assertEqual(response.json, { - 'detail': "Property is read-only - 'schedule_interval'", - 'status': 400, - 'title': 'Bad Request', - 'type': 'about:blank' - }) + self.assertEqual( + response.json, + { + 'detail': "Property is read-only - 'schedule_interval'", + 'status': 400, + 'title': 'Bad Request', + 'type': 'about:blank', + }, + ) def test_should_response_404(self): response = self.client.get("/api/v1/dags/INVALID_DAG", environ_overrides={'REMOTE_USER': "test"}) @@ -373,21 +340,13 @@ def test_should_response_404(self): @provide_session def _create_dag_model(self, session=None): dag_model = DagModel( - dag_id="TEST_DAG_1", - fileloc="/tmp/dag_1.py", - schedule_interval="2 2 * * *", - is_paused=True + dag_id="TEST_DAG_1", fileloc="/tmp/dag_1.py", schedule_interval="2 2 * * *", is_paused=True ) session.add(dag_model) return dag_model def test_should_raises_401_unauthenticated(self): dag_model = self._create_dag_model() - response = self.client.patch( - f"/api/v1/dags/{dag_model.dag_id}", - json={ - "is_paused": False, - }, - ) + response = self.client.patch(f"/api/v1/dags/{dag_model.dag_id}", json={"is_paused": False,},) assert_401(response) diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 41b98e19415d4..816beb5dc4218 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -34,9 +34,7 @@ class TestDagRunEndpoint(unittest.TestCase): def setUpClass(cls) -> None: super().setUpClass() - with conf_vars( - {("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"} - ): + with conf_vars({("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}): cls.app = app.create_app(testing=True) # type:ignore # TODO: Add new role for each view to test permission. 
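
Note on the formatter output above: the one-line literals that keep a dangling comma, e.g.

    json={"is_paused": False,},
    "schedule_interval": {"__type": "TimeDelta", "days": 1, "microseconds": 0, "seconds": 0,},

are characteristic of Black releases before 20.8b0, which collapsed any collection that fit the line length and simply carried the trailing comma along. From 20.8b0 onwards the "magic trailing comma" makes Black keep such collections exploded one element per line, so re-running a newer Black over this tree would rewrite these hunks again. (Background on the formatter's behaviour, not part of the patch.)
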
create_user(cls.app, username="test", role="Admin") # type: ignore @@ -77,14 +75,17 @@ def _create_test_dag_run(self, state='running', extra_dag=False, commit=True): ) dag_runs.append(dagrun_model_2) if extra_dag: - dagrun_extra = [DagRun( - dag_id='TEST_DAG_ID_' + str(i), - run_id='TEST_DAG_RUN_ID_' + str(i), - run_type=DagRunType.MANUAL.value, - execution_date=timezone.parse(self.default_time_2), - start_date=timezone.parse(self.default_time), - external_trigger=True, - ) for i in range(3, 5)] + dagrun_extra = [ + DagRun( + dag_id='TEST_DAG_ID_' + str(i), + run_id='TEST_DAG_RUN_ID_' + str(i), + run_type=DagRunType.MANUAL.value, + execution_date=timezone.parse(self.default_time_2), + start_date=timezone.parse(self.default_time), + external_trigger=True, + ) + for i in range(3, 5) + ] dag_runs.extend(dagrun_extra) if commit: with create_session() as session: @@ -98,8 +99,7 @@ def test_should_response_204(self, session): session.add_all(self._create_test_dag_run()) session.commit() response = self.client.delete( - "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", - environ_overrides={'REMOTE_USER': "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", environ_overrides={'REMOTE_USER': "test"} ) self.assertEqual(response.status_code, 204) # Check if the Dag Run is deleted from the database @@ -128,9 +128,7 @@ def test_should_raises_401_unauthenticated(self, session): session.add_all(self._create_test_dag_run()) session.commit() - response = self.client.delete( - "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", - ) + response = self.client.delete("api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1",) assert_401(response) @@ -151,8 +149,7 @@ def test_should_response_200(self, session): result = session.query(DagRun).all() assert len(result) == 1 response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID", - environ_overrides={'REMOTE_USER': "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID", environ_overrides={'REMOTE_USER': "test"} ) assert response.status_code == 200 expected_response = { @@ -163,7 +160,7 @@ def test_should_response_200(self, session): 'execution_date': self.default_time, 'external_trigger': True, 'start_date': self.default_time, - 'conf': {} + 'conf': {}, } assert response.json == expected_response @@ -172,12 +169,7 @@ def test_should_response_404(self): "api/v1/dags/invalid-id/dagRuns/invalid-id", environ_overrides={'REMOTE_USER': "test"} ) assert response.status_code == 404 - expected_resp = { - 'detail': None, - 'status': 404, - 'title': 'DAGRun not found', - 'type': 'about:blank' - } + expected_resp = {'detail': None, 'status': 404, 'title': 'DAGRun not found', 'type': 'about:blank'} assert expected_resp == response.json @provide_session @@ -193,9 +185,7 @@ def test_should_raises_401_unauthenticated(self, session): session.add(dagrun_model) session.commit() - response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID" - ) + response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID") assert_401(response) @@ -239,8 +229,7 @@ def test_should_response_200(self, session): @provide_session def test_should_return_all_with_tilde_as_dag_id(self, session): self._create_test_dag_run(extra_dag=True) - expected_dag_run_ids = ['TEST_DAG_ID', 'TEST_DAG_ID', - "TEST_DAG_ID_3", "TEST_DAG_ID_4"] + expected_dag_run_ids = ['TEST_DAG_ID', 'TEST_DAG_ID', "TEST_DAG_ID_3", "TEST_DAG_ID_4"] result = session.query(DagRun).all() assert len(result) == 4 response = self.client.get("api/v1/dags/~/dagRuns", 
environ_overrides={'REMOTE_USER': "test"}) @@ -251,9 +240,7 @@ def test_should_return_all_with_tilde_as_dag_id(self, session): def test_should_raises_401_unauthenticated(self): self._create_test_dag_run() - response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns" - ) + response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns") assert_401(response) @@ -262,10 +249,7 @@ class TestGetDagRunsPagination(TestDagRunEndpoint): @parameterized.expand( [ ("api/v1/dags/TEST_DAG_ID/dagRuns?limit=1", ["TEST_DAG_RUN_ID1"]), - ( - "api/v1/dags/TEST_DAG_ID/dagRuns?limit=2", - ["TEST_DAG_RUN_ID1", "TEST_DAG_RUN_ID2"], - ), + ("api/v1/dags/TEST_DAG_ID/dagRuns?limit=2", ["TEST_DAG_RUN_ID1", "TEST_DAG_RUN_ID2"],), ( "api/v1/dags/TEST_DAG_ID/dagRuns?offset=5", [ @@ -293,10 +277,7 @@ class TestGetDagRunsPagination(TestDagRunEndpoint): ), ("api/v1/dags/TEST_DAG_ID/dagRuns?limit=1&offset=5", ["TEST_DAG_RUN_ID6"]), ("api/v1/dags/TEST_DAG_ID/dagRuns?limit=1&offset=1", ["TEST_DAG_RUN_ID2"]), - ( - "api/v1/dags/TEST_DAG_ID/dagRuns?limit=2&offset=2", - ["TEST_DAG_RUN_ID3", "TEST_DAG_RUN_ID4"], - ), + ("api/v1/dags/TEST_DAG_ID/dagRuns?limit=2&offset=2", ["TEST_DAG_RUN_ID3", "TEST_DAG_RUN_ID4"],), ] ) def test_handle_limit_and_offset(self, url, expected_dag_run_ids): @@ -444,8 +425,7 @@ def test_end_date_gte_lte(self, url, expected_dag_run_ids): response = self.client.get(url, environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 200 assert response.json["total_entries"] == 2 - dag_run_ids = [ - dag_run["dag_run_id"] for dag_run in response.json["dag_runs"] if dag_run] + dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json["dag_runs"] if dag_run] assert dag_run_ids == expected_dag_run_ids @@ -454,10 +434,8 @@ def test_should_respond_200(self): self._create_test_dag_run() response = self.client.post( "api/v1/dags/~/dagRuns/list", - json={ - "dag_ids": ["TEST_DAG_ID"] - }, - environ_overrides={'REMOTE_USER': "test"} + json={"dag_ids": ["TEST_DAG_ID"]}, + environ_overrides={'REMOTE_USER': "test"}, ) assert response.status_code == 200 assert response.json == { @@ -488,32 +466,27 @@ def test_should_respond_200(self): @parameterized.expand( [ - ({"dag_ids": ["TEST_DAG_ID"], "page_offset": -1}, - "-1 is less than the minimum of 0 - 'page_offset'"), - ({"dag_ids": ["TEST_DAG_ID"], "page_limit": 0}, - "0 is less than the minimum of 1 - 'page_limit'"), - ({"dag_ids": "TEST_DAG_ID"}, - "'TEST_DAG_ID' is not of type 'array' - 'dag_ids'"), - ({"start_date_gte": "2020-06-12T18"}, - "{'start_date_gte': ['Not a valid datetime.']}"), + ( + {"dag_ids": ["TEST_DAG_ID"], "page_offset": -1}, + "-1 is less than the minimum of 0 - 'page_offset'", + ), + ({"dag_ids": ["TEST_DAG_ID"], "page_limit": 0}, "0 is less than the minimum of 1 - 'page_limit'"), + ({"dag_ids": "TEST_DAG_ID"}, "'TEST_DAG_ID' is not of type 'array' - 'dag_ids'"), + ({"start_date_gte": "2020-06-12T18"}, "{'start_date_gte': ['Not a valid datetime.']}"), ] ) def test_payload_validation(self, payload, error): self._create_test_dag_run() response = self.client.post( - "api/v1/dags/~/dagRuns/list", json=payload, environ_overrides={'REMOTE_USER': "test"}) + "api/v1/dags/~/dagRuns/list", json=payload, environ_overrides={'REMOTE_USER': "test"} + ) assert response.status_code == 400 assert error == response.json.get("detail") def test_should_raises_401_unauthenticated(self): self._create_test_dag_run() - response = self.client.post( - "api/v1/dags/~/dagRuns/list", - json={ - "dag_ids": ["TEST_DAG_ID"] - } - ) + response = 
self.client.post("api/v1/dags/~/dagRuns/list", json={"dag_ids": ["TEST_DAG_ID"]}) assert_401(response) @@ -602,20 +575,31 @@ class TestGetDagRunBatchDateFilters(TestDagRunEndpoint): ["TEST_START_EXEC_DAY_10", "TEST_START_EXEC_DAY_11"], ), ( - {"start_date_lte": "2020-06-15T18:00:00+00:00", - "start_date_gte": "2020-06-12T18:00:00Z"}, - ["TEST_START_EXEC_DAY_12", "TEST_START_EXEC_DAY_13", - "TEST_START_EXEC_DAY_14", "TEST_START_EXEC_DAY_15"], + {"start_date_lte": "2020-06-15T18:00:00+00:00", "start_date_gte": "2020-06-12T18:00:00Z"}, + [ + "TEST_START_EXEC_DAY_12", + "TEST_START_EXEC_DAY_13", + "TEST_START_EXEC_DAY_14", + "TEST_START_EXEC_DAY_15", + ], ), ( {"execution_date_lte": "2020-06-13T18:00:00+00:00"}, - ["TEST_START_EXEC_DAY_10", "TEST_START_EXEC_DAY_11", - "TEST_START_EXEC_DAY_12", "TEST_START_EXEC_DAY_13"], + [ + "TEST_START_EXEC_DAY_10", + "TEST_START_EXEC_DAY_11", + "TEST_START_EXEC_DAY_12", + "TEST_START_EXEC_DAY_13", + ], ), ( {"execution_date_gte": "2020-06-16T18:00:00+00:00"}, - ["TEST_START_EXEC_DAY_16", "TEST_START_EXEC_DAY_17", - "TEST_START_EXEC_DAY_18", "TEST_START_EXEC_DAY_19"], + [ + "TEST_START_EXEC_DAY_16", + "TEST_START_EXEC_DAY_17", + "TEST_START_EXEC_DAY_18", + "TEST_START_EXEC_DAY_19", + ], ), ] ) @@ -662,10 +646,7 @@ def _create_dag_runs(self): @parameterized.expand( [ - ( - {"end_date_gte": f"{(timezone.utcnow() + timedelta(days=1)).isoformat()}"}, - [], - ), + ({"end_date_gte": f"{(timezone.utcnow() + timedelta(days=1)).isoformat()}"}, [],), ( {"end_date_lte": f"{(timezone.utcnow() + timedelta(days=1)).isoformat()}"}, ["TEST_DAG_RUN_ID_1"], @@ -675,9 +656,7 @@ def _create_dag_runs(self): def test_end_date_gte_lte(self, payload, expected_dag_run_ids): self._create_test_dag_run('success') # state==success, then end date is today response = self.client.post( - "api/v1/dags/~/dagRuns/list", - json=payload, - environ_overrides={'REMOTE_USER': "test"} + "api/v1/dags/~/dagRuns/list", json=payload, environ_overrides={'REMOTE_USER': "test"} ) assert response.status_code == 200 assert response.json["total_entries"] == 2 @@ -690,10 +669,7 @@ class TestPostDagRun(TestDagRunEndpoint): [ ( "All fields present", - { - "dag_run_id": "TEST_DAG_RUN", - "execution_date": "2020-06-11T18:00:00+00:00", - }, + {"dag_run_id": "TEST_DAG_RUN", "execution_date": "2020-06-11T18:00:00+00:00",}, ), ("dag_run_id missing", {"execution_date": "2020-06-11T18:00:00+00:00"}), ("dag_run_id and execution_date missing", {}), @@ -727,7 +703,7 @@ def test_response_404(self): response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={"dag_run_id": "TEST_DAG_RUN", "execution_date": self.default_time}, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual(response.status_code, 404) self.assertEqual( @@ -745,10 +721,7 @@ def test_response_404(self): ( "start_date in request json", "api/v1/dags/TEST_DAG_ID/dagRuns", - { - "start_date": "2020-06-11T18:00:00+00:00", - "execution_date": "2020-06-12T18:00:00+00:00", - }, + {"start_date": "2020-06-11T18:00:00+00:00", "execution_date": "2020-06-12T18:00:00+00:00",}, { "detail": "Property is read-only - 'start_date'", "status": 400, @@ -787,11 +760,8 @@ def test_response_409(self, session): session.commit() response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", - json={ - "dag_run_id": "TEST_DAG_RUN_ID_1", - "execution_date": self.default_time, - }, - environ_overrides={'REMOTE_USER': "test"} + json={"dag_run_id": "TEST_DAG_RUN_ID_1", "execution_date": self.default_time,}, + 
environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual(response.status_code, 409, response.data) self.assertEqual( @@ -808,10 +778,7 @@ def test_response_409(self, session): def test_should_raises_401_unauthenticated(self): response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", - json={ - "dag_run_id": "TEST_DAG_RUN_ID_1", - "execution_date": self.default_time, - }, + json={"dag_run_id": "TEST_DAG_RUN_ID_1", "execution_date": self.default_time,}, ) assert_401(response) diff --git a/tests/api_connexion/endpoints/test_dag_source_endpoint.py b/tests/api_connexion/endpoints/test_dag_source_endpoint.py index 1373243a6f09f..1b76c28742774 100644 --- a/tests/api_connexion/endpoints/test_dag_source_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_source_endpoint.py @@ -38,9 +38,7 @@ class TestGetSource(unittest.TestCase): @classmethod def setUpClass(cls) -> None: super().setUpClass() - with conf_vars( - {("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"} - ): + with conf_vars({("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}): cls.app = app.create_app(testing=True) # type:ignore # TODO: Add new role for each view to test permission. create_user(cls.app, username="test", role="Admin") # type: ignore @@ -73,9 +71,9 @@ def _get_dag_file_docstring(fileloc: str) -> str: @parameterized.expand([(True,), (False,)]) def test_should_response_200_text(self, store_dag_code): serializer = URLSafeSerializer(conf.get('webserver', 'SECRET_KEY')) - with mock.patch( - "airflow.models.dag.settings.STORE_DAG_CODE", store_dag_code - ), mock.patch("airflow.models.dagcode.STORE_DAG_CODE", store_dag_code): + with mock.patch("airflow.models.dag.settings.STORE_DAG_CODE", store_dag_code), mock.patch( + "airflow.models.dagcode.STORE_DAG_CODE", store_dag_code + ): dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) dagbag.sync_to_db() first_dag: DAG = next(iter(dagbag.dags.values())) @@ -83,9 +81,7 @@ def test_should_response_200_text(self, store_dag_code): url = f"/api/v1/dagSources/{serializer.dumps(first_dag.fileloc)}" response = self.client.get( - url, - headers={"Accept": "text/plain"}, - environ_overrides={'REMOTE_USER': "test"} + url, headers={"Accept": "text/plain"}, environ_overrides={'REMOTE_USER': "test"} ) self.assertEqual(200, response.status_code) @@ -95,9 +91,9 @@ def test_should_response_200_text(self, store_dag_code): @parameterized.expand([(True,), (False,)]) def test_should_response_200_json(self, store_dag_code): serializer = URLSafeSerializer(conf.get('webserver', 'SECRET_KEY')) - with mock.patch( - "airflow.models.dag.settings.STORE_DAG_CODE", store_dag_code - ), mock.patch("airflow.models.dagcode.STORE_DAG_CODE", store_dag_code): + with mock.patch("airflow.models.dag.settings.STORE_DAG_CODE", store_dag_code), mock.patch( + "airflow.models.dagcode.STORE_DAG_CODE", store_dag_code + ): dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) dagbag.sync_to_db() first_dag: DAG = next(iter(dagbag.dags.values())) @@ -105,48 +101,39 @@ def test_should_response_200_json(self, store_dag_code): url = f"/api/v1/dagSources/{serializer.dumps(first_dag.fileloc)}" response = self.client.get( - url, - headers={"Accept": 'application/json'}, - environ_overrides={'REMOTE_USER': "test"} + url, headers={"Accept": 'application/json'}, environ_overrides={'REMOTE_USER': "test"} ) self.assertEqual(200, response.status_code) - self.assertIn( - dag_docstring, - response.json['content'] - ) + self.assertIn(dag_docstring, response.json['content']) 
self.assertEqual('application/json', response.headers['Content-Type']) @parameterized.expand([(True,), (False,)]) def test_should_response_406(self, store_dag_code): serializer = URLSafeSerializer(conf.get('webserver', 'SECRET_KEY')) - with mock.patch( - "airflow.models.dag.settings.STORE_DAG_CODE", store_dag_code - ), mock.patch("airflow.models.dagcode.STORE_DAG_CODE", store_dag_code): + with mock.patch("airflow.models.dag.settings.STORE_DAG_CODE", store_dag_code), mock.patch( + "airflow.models.dagcode.STORE_DAG_CODE", store_dag_code + ): dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) dagbag.sync_to_db() first_dag: DAG = next(iter(dagbag.dags.values())) url = f"/api/v1/dagSources/{serializer.dumps(first_dag.fileloc)}" response = self.client.get( - url, - headers={"Accept": 'image/webp'}, - environ_overrides={'REMOTE_USER': "test"} + url, headers={"Accept": 'image/webp'}, environ_overrides={'REMOTE_USER': "test"} ) self.assertEqual(406, response.status_code) @parameterized.expand([(True,), (False,)]) def test_should_response_404(self, store_dag_code): - with mock.patch( - "airflow.models.dag.settings.STORE_DAG_CODE", store_dag_code - ), mock.patch("airflow.models.dagcode.STORE_DAG_CODE", store_dag_code): + with mock.patch("airflow.models.dag.settings.STORE_DAG_CODE", store_dag_code), mock.patch( + "airflow.models.dagcode.STORE_DAG_CODE", store_dag_code + ): wrong_fileloc = "abcd1234" url = f"/api/v1/dagSources/{wrong_fileloc}" response = self.client.get( - url, - headers={"Accept": 'application/json'}, - environ_overrides={'REMOTE_USER': "test"} + url, headers={"Accept": 'application/json'}, environ_overrides={'REMOTE_USER': "test"} ) self.assertEqual(404, response.status_code) @@ -158,8 +145,7 @@ def test_should_raises_401_unauthenticated(self): first_dag: DAG = next(iter(dagbag.dags.values())) response = self.client.get( - f"/api/v1/dagSources/{serializer.dumps(first_dag.fileloc)}", - headers={"Accept": "text/plain"}, + f"/api/v1/dagSources/{serializer.dumps(first_dag.fileloc)}", headers={"Accept": "text/plain"}, ) assert_401(response) diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py b/tests/api_connexion/endpoints/test_event_log_endpoint.py index 84fcad7f68d36..780f447250f65 100644 --- a/tests/api_connexion/endpoints/test_event_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py @@ -33,9 +33,7 @@ class TestEventLogEndpoint(unittest.TestCase): @classmethod def setUpClass(cls) -> None: super().setUpClass() - with conf_vars( - {("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"} - ): + with conf_vars({("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}): cls.app = app.create_app(testing=True) # type:ignore # TODO: Add new role for each view to test permission. 
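
The dagSources tests above all build their URLs the same way; a condensed sketch using only names from the hunks (the 406 case simply requests an Accept type the endpoint cannot produce):

    serializer = URLSafeSerializer(conf.get('webserver', 'SECRET_KEY'))  # signs the DAG file path
    url = f"/api/v1/dagSources/{serializer.dumps(first_dag.fileloc)}"
    # Content negotiation: text/plain and application/json succeed, image/webp -> 406
    response = client.get(url, headers={"Accept": "text/plain"}, environ_overrides={'REMOTE_USER': "test"})
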
create_user(cls.app, username="test", role="Admin") # type: ignore @@ -54,23 +52,21 @@ def tearDown(self) -> None: clear_db_logs() def _create_task_instance(self): - dag = DAG('TEST_DAG_ID', start_date=timezone.parse(self.default_time), - end_date=timezone.parse(self.default_time)) - op1 = DummyOperator(task_id="TEST_TASK_ID", owner="airflow", - ) + dag = DAG( + 'TEST_DAG_ID', + start_date=timezone.parse(self.default_time), + end_date=timezone.parse(self.default_time), + ) + op1 = DummyOperator(task_id="TEST_TASK_ID", owner="airflow",) dag.add_task(op1) ti = TaskInstance(task=op1, execution_date=timezone.parse(self.default_time)) return ti class TestGetEventLog(TestEventLogEndpoint): - @provide_session def test_should_response_200(self, session): - log_model = Log( - event='TEST_EVENT', - task_instance=self._create_task_instance(), - ) + log_model = Log(event='TEST_EVENT', task_instance=self._create_task_instance(),) log_model.dttm = timezone.parse(self.default_time) session.add(log_model) session.commit() @@ -89,26 +85,21 @@ def test_should_response_200(self, session): "execution_date": self.default_time, "owner": 'airflow', "when": self.default_time, - "extra": None - } + "extra": None, + }, ) def test_should_response_404(self): - response = self.client.get( - "/api/v1/eventLogs/1", environ_overrides={'REMOTE_USER': "test"} - ) + response = self.client.get("/api/v1/eventLogs/1", environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 404 self.assertEqual( {'detail': None, 'status': 404, 'title': 'Event Log not found', 'type': 'about:blank'}, - response.json + response.json, ) @provide_session def test_should_raises_401_unauthenticated(self, session): - log_model = Log( - event='TEST_EVENT', - task_instance=self._create_task_instance(), - ) + log_model = Log(event='TEST_EVENT', task_instance=self._create_task_instance(),) log_model.dttm = timezone.parse(self.default_time) session.add(log_model) session.commit() @@ -120,17 +111,10 @@ def test_should_raises_401_unauthenticated(self, session): class TestGetEventLogs(TestEventLogEndpoint): - @provide_session def test_should_response_200(self, session): - log_model_1 = Log( - event='TEST_EVENT_1', - task_instance=self._create_task_instance(), - ) - log_model_2 = Log( - event='TEST_EVENT_2', - task_instance=self._create_task_instance(), - ) + log_model_1 = Log(event='TEST_EVENT_1', task_instance=self._create_task_instance(),) + log_model_2 = Log(event='TEST_EVENT_2', task_instance=self._create_task_instance(),) log_model_1.dttm = timezone.parse(self.default_time) log_model_2.dttm = timezone.parse(self.default_time_2) session.add_all([log_model_1, log_model_2]) @@ -142,7 +126,6 @@ def test_should_response_200(self, session): { "event_logs": [ { - "event_log_id": log_model_1.id, "event": "TEST_EVENT_1", "dag_id": "TEST_DAG_ID", @@ -150,8 +133,7 @@ def test_should_response_200(self, session): "execution_date": self.default_time, "owner": 'airflow', "when": self.default_time, - "extra": None - + "extra": None, }, { "event_log_id": log_model_2.id, @@ -161,23 +143,17 @@ def test_should_response_200(self, session): "execution_date": self.default_time, "owner": 'airflow', "when": self.default_time_2, - "extra": None - } + "extra": None, + }, ], - "total_entries": 2 - } + "total_entries": 2, + }, ) @provide_session def test_should_raises_401_unauthenticated(self, session): - log_model_1 = Log( - event='TEST_EVENT_1', - task_instance=self._create_task_instance(), - ) - log_model_2 = Log( - event='TEST_EVENT_2', - 
task_instance=self._create_task_instance(), - ) + log_model_1 = Log(event='TEST_EVENT_1', task_instance=self._create_task_instance(),) + log_model_2 = Log(event='TEST_EVENT_2', task_instance=self._create_task_instance(),) log_model_1.dttm = timezone.parse(self.default_time) log_model_2.dttm = timezone.parse(self.default_time_2) session.add_all([log_model_1, log_model_2]) @@ -195,13 +171,7 @@ class TestGetEventLogPagination(TestEventLogEndpoint): ("api/v1/eventLogs?limit=2", ["TEST_EVENT_1", "TEST_EVENT_2"]), ( "api/v1/eventLogs?offset=5", - [ - "TEST_EVENT_6", - "TEST_EVENT_7", - "TEST_EVENT_8", - "TEST_EVENT_9", - "TEST_EVENT_10", - ], + ["TEST_EVENT_6", "TEST_EVENT_7", "TEST_EVENT_8", "TEST_EVENT_9", "TEST_EVENT_10",], ), ( "api/v1/eventLogs?offset=0", @@ -261,9 +231,6 @@ def test_should_return_conf_max_if_req_max_above_conf(self, session): def _create_event_logs(self, count): return [ - Log( - event="TEST_EVENT_" + str(i), - task_instance=self._create_task_instance() - ) + Log(event="TEST_EVENT_" + str(i), task_instance=self._create_task_instance()) for i in range(1, count + 1) ] diff --git a/tests/api_connexion/endpoints/test_extra_link_endpoint.py b/tests/api_connexion/endpoints/test_extra_link_endpoint.py index 87e7946aea3d3..f0bdd7ca38e22 100644 --- a/tests/api_connexion/endpoints/test_extra_link_endpoint.py +++ b/tests/api_connexion/endpoints/test_extra_link_endpoint.py @@ -129,7 +129,7 @@ def test_should_response_200(self): ) response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_SINGLE_QUERY/links", - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual(200, response.status_code, response.data) @@ -141,7 +141,7 @@ def test_should_response_200(self): def test_should_response_200_missing_xcom(self): response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_SINGLE_QUERY/links", - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual(200, response.status_code, response.data) @@ -160,7 +160,7 @@ def test_should_response_200_multiple_links(self): ) response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_MULTIPLE_QUERY/links", - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual(200, response.status_code, response.data) @@ -176,7 +176,7 @@ def test_should_response_200_multiple_links(self): def test_should_response_200_multiple_links_missing_xcom(self): response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_MULTIPLE_QUERY/links", - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual(200, response.status_code, response.data) @@ -214,7 +214,7 @@ class AirflowTestPlugin(AirflowPlugin): with mock_plugin_manager(plugins=[AirflowTestPlugin]): response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_SINGLE_QUERY/links", - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual(200, response.status_code, response.data) diff --git a/tests/api_connexion/endpoints/test_health_endpoint.py b/tests/api_connexion/endpoints/test_health_endpoint.py index 9765c46283ce3..d85dea4857657 100644 --- a/tests/api_connexion/endpoints/test_health_endpoint.py +++ b/tests/api_connexion/endpoints/test_health_endpoint.py @@ -66,9 +66,7 @@ def 
test_healthy_scheduler_status(self, session): @provide_session def test_unhealthy_scheduler_is_slow(self, session): - last_scheduler_heartbeat_for_testing_2 = timezone.utcnow() - timedelta( - minutes=1 - ) + last_scheduler_heartbeat_for_testing_2 = timezone.utcnow() - timedelta(minutes=1) session.add( BaseJob( job_type="SchedulerJob", diff --git a/tests/api_connexion/endpoints/test_import_error_endpoint.py b/tests/api_connexion/endpoints/test_import_error_endpoint.py index 5565969cb83c9..78439faa3d1ed 100644 --- a/tests/api_connexion/endpoints/test_import_error_endpoint.py +++ b/tests/api_connexion/endpoints/test_import_error_endpoint.py @@ -31,9 +31,7 @@ class TestBaseImportError(unittest.TestCase): @classmethod def setUpClass(cls) -> None: super().setUpClass() - with conf_vars( - {("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"} - ): + with conf_vars({("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}): cls.app = app.create_app(testing=True) # type:ignore # TODO: Add new role for each view to test permission. create_user(cls.app, username="test", role="Admin") # type: ignore @@ -89,12 +87,7 @@ def test_response_404(self): response = self.client.get("/api/v1/importErrors/2", environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 404 self.assertEqual( - { - "detail": None, - "status": 404, - "title": "Import error not found", - "type": "about:blank", - }, + {"detail": None, "status": 404, "title": "Import error not found", "type": "about:blank",}, response.json, ) @@ -108,9 +101,7 @@ def test_should_raises_401_unauthenticated(self, session): session.add(import_error) session.commit() - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}" - ) + response = self.client.get(f"/api/v1/importErrors/{import_error.id}") assert_401(response) @@ -202,9 +193,7 @@ def test_limit_and_offset(self, url, expected_import_error_ids, session): response = self.client.get(url, environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 200 - import_ids = [ - pool["filename"] for pool in response.json["import_errors"] - ] + import_ids = [pool["filename"] for pool in response.json["import_errors"]] self.assertEqual(import_ids, expected_import_error_ids) @provide_session diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py index ba5b9b37b2507..fb8df5a774ae9 100644 --- a/tests/api_connexion/endpoints/test_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_log_endpoint.py @@ -48,9 +48,7 @@ class TestGetLog(unittest.TestCase): def setUpClass(cls): settings.configure_orm() cls.session = settings.Session - with conf_vars( - {("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"} - ): + with conf_vars({("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}): cls.app = app.create_app(testing=True) # TODO: Add new role for each view to test permission. 
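
The log-endpoint hunks that follow drive pagination with a signed continuation token; a sketch of the round trip, assuming a `{"download_logs": False}` metadata payload (the exact dict passed to `dumps` is not visible in these hunks):

    serializer = URLSafeSerializer(conf.get('webserver', 'SECRET_KEY'))
    token = serializer.dumps({"download_logs": False})  # assumed payload; sent as ?token=...
    response = client.get(
        f"api/v1/dags/{DAG_ID}/dagRuns/TEST_DAG_RUN_ID/taskInstances/{TASK_ID}/logs/1?token={token}",
        headers={'Accept': 'application/json'},
        environ_overrides={'REMOTE_USER': "test"},
    )
    info = serializer.loads(response.json['continuation_token'])  # e.g. {'end_of_log': True}
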
create_user(cls.app, username="test", role="Admin") @@ -86,9 +84,9 @@ def _configure_loggers(self): logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG) logging_config['handlers']['task']['base_log_folder'] = self.log_dir - logging_config['handlers']['task']['filename_template'] = \ - '{{ ti.dag_id }}/{{ ti.task_id }}/' \ - '{{ ts | replace(":", ".") }}/{{ try_number }}.log' + logging_config['handlers']['task']['filename_template'] = ( + '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log' + ) # Write the custom logging configuration to a file self.settings_folder = tempfile.mkdtemp() @@ -98,10 +96,12 @@ def _configure_loggers(self): handle.writelines(new_logging_file) sys.path.append(self.settings_folder) - with conf_vars({ - ('logging', 'logging_config_class'): 'airflow_local_settings.LOGGING_CONFIG', - ("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend" - }): + with conf_vars( + { + ('logging', 'logging_config_class'): 'airflow_local_settings.LOGGING_CONFIG', + ("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend", + } + ): self.app = app.create_app(testing=True) self.client = self.app.test_client() settings.configure_logging() @@ -114,14 +114,13 @@ def _prepare_db(self): with create_session() as session: self.ti = TaskInstance( task=DummyOperator(task_id=self.TASK_ID, dag=dag), - execution_date=timezone.parse(self.default_time) + execution_date=timezone.parse(self.default_time), ) self.ti.try_number = 1 session.merge(self.ti) def _prepare_log_files(self): - dir_path = f"{self.log_dir}/{self.DAG_ID}/{self.TASK_ID}/" \ - f"{self.default_time.replace(':', '.')}/" + dir_path = f"{self.log_dir}/{self.DAG_ID}/{self.TASK_ID}/" f"{self.default_time.replace(':', '.')}/" os.makedirs(dir_path) with open(f"{dir_path}/1.log", "w+") as file: file.write("Log for testing.") @@ -152,23 +151,16 @@ def test_should_response_200_json(self, session): f"api/v1/dags/{self.DAG_ID}/dagRuns/TEST_DAG_RUN_ID/" f"taskInstances/{self.TASK_ID}/logs/1?token={token}", headers={'Accept': 'application/json'}, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) expected_filename = "{}/{}/{}/{}/1.log".format( - self.log_dir, - self.DAG_ID, - self.TASK_ID, - self.default_time.replace(":", ".") + self.log_dir, self.DAG_ID, self.TASK_ID, self.default_time.replace(":", ".") ) self.assertEqual( - response.json['content'], - f"*** Reading local file: {expected_filename}\nLog for testing." + response.json['content'], f"*** Reading local file: {expected_filename}\nLog for testing." 
) info = serializer.loads(response.json['continuation_token']) - self.assertEqual( - info, - {'end_of_log': True} - ) + self.assertEqual(info, {'end_of_log': True}) self.assertEqual(200, response.status_code) @provide_session @@ -182,18 +174,14 @@ def test_should_response_200_text_plain(self, session): f"api/v1/dags/{self.DAG_ID}/dagRuns/TEST_DAG_RUN_ID/" f"taskInstances/{self.TASK_ID}/logs/1?token={token}", headers={'Accept': 'text/plain'}, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) expected_filename = "{}/{}/{}/{}/1.log".format( - self.log_dir, - self.DAG_ID, - self.TASK_ID, - self.default_time.replace(':', '.') + self.log_dir, self.DAG_ID, self.TASK_ID, self.default_time.replace(':', '.') ) self.assertEqual(200, response.status_code) self.assertEqual( - response.data.decode('utf-8'), - f"*** Reading local file: {expected_filename}\nLog for testing.\n" + response.data.decode('utf-8'), f"*** Reading local file: {expected_filename}\nLog for testing.\n" ) @provide_session @@ -206,7 +194,7 @@ def test_get_logs_response_with_ti_equal_to_none(self, session): response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/TEST_DAG_RUN_ID/" f"taskInstances/Invalid-Task-ID/logs/1?token={token}", - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual(response.status_code, 400) self.assertEqual(response.json['detail'], "Task instance did not exist in the DB") @@ -225,7 +213,7 @@ def test_get_logs_with_metadata_as_download_large_file(self, session): f"api/v1/dags/{self.DAG_ID}/dagRuns/TEST_DAG_RUN_ID/" f"taskInstances/{self.TASK_ID}/logs/1?full_content=True", headers={"Accept": 'text/plain'}, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertIn('1st line', response.data.decode('utf-8')) @@ -246,12 +234,10 @@ def test_get_logs_for_handler_without_read_method(self, mock_log_reader): f"api/v1/dags/{self.DAG_ID}/dagRuns/TEST_DAG_RUN_ID/" f"taskInstances/{self.TASK_ID}/logs/1?token={token}", headers={'Content-Type': 'application/jso'}, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual(400, response.status_code) - self.assertIn( - 'Task log handler does not support read logs.', - response.data.decode('utf-8')) + self.assertIn('Task log handler does not support read logs.', response.data.decode('utf-8')) @provide_session def test_bad_signature_raises(self, session): @@ -262,7 +248,7 @@ def test_bad_signature_raises(self, session): f"api/v1/dags/{self.DAG_ID}/dagRuns/TEST_DAG_RUN_ID/" f"taskInstances/{self.TASK_ID}/logs/1?token={token}", headers={'Accept': 'application/json'}, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual( response.json, @@ -270,8 +256,8 @@ def test_bad_signature_raises(self, session): 'detail': None, 'status': 400, 'title': "Bad Signature. 
Please use only the tokens provided by the API.", - 'type': 'about:blank' - } + 'type': 'about:blank', + }, ) def test_raises_404_for_invalid_dag_run_id(self): @@ -279,16 +265,11 @@ def test_raises_404_for_invalid_dag_run_id(self): f"api/v1/dags/{self.DAG_ID}/dagRuns/TEST_DAG_RUN/" # invalid dagrun_id f"taskInstances/{self.TASK_ID}/logs/1?", headers={'Accept': 'application/json'}, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual( response.json, - { - 'detail': None, - 'status': 404, - 'title': "DAG Run not found", - 'type': 'about:blank' - } + {'detail': None, 'status': 404, 'title': "DAG Run not found", 'type': 'about:blank'}, ) def test_should_raises_401_unauthenticated(self): diff --git a/tests/api_connexion/endpoints/test_pool_endpoint.py b/tests/api_connexion/endpoints/test_pool_endpoint.py index a6df5b9b55722..baba612f87ba7 100644 --- a/tests/api_connexion/endpoints/test_pool_endpoint.py +++ b/tests/api_connexion/endpoints/test_pool_endpoint.py @@ -30,9 +30,7 @@ class TestBasePoolEndpoints(unittest.TestCase): @classmethod def setUpClass(cls) -> None: super().setUpClass() - with conf_vars( - {("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"} - ): + with conf_vars({("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}): cls.app = app.create_app(testing=True) # type:ignore # TODO: Add new role for each view to test permission. create_user(cls.app, username="test", role="Admin") # type: ignore @@ -101,15 +99,9 @@ class TestGetPoolsPagination(TestBasePoolEndpoints): ("/api/v1/pools?limit=2", ["default_pool", "test_pool1"]), ("/api/v1/pools?limit=1", ["default_pool"]), # Limit and offset test data - ( - "/api/v1/pools?limit=100&offset=1", - [f"test_pool{i}" for i in range(1, 101)], - ), + ("/api/v1/pools?limit=100&offset=1", [f"test_pool{i}" for i in range(1, 101)],), ("/api/v1/pools?limit=2&offset=1", ["test_pool1", "test_pool2"]), - ( - "/api/v1/pools?limit=3&offset=2", - ["test_pool2", "test_pool3", "test_pool4"], - ), + ("/api/v1/pools?limit=3&offset=2", ["test_pool2", "test_pool3", "test_pool4"],), ] ) @provide_session @@ -154,9 +146,7 @@ def test_response_200(self, session): pool_model = Pool(pool="test_pool_a", slots=3) session.add(pool_model) session.commit() - response = self.client.get( - "/api/v1/pools/test_pool_a", environ_overrides={'REMOTE_USER': "test"} - ) + response = self.client.get("/api/v1/pools/test_pool_a", environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 200 self.assertEqual( { @@ -171,9 +161,7 @@ def test_response_200(self, session): ) def test_response_404(self): - response = self.client.get( - "/api/v1/pools/invalid_pool", environ_overrides={'REMOTE_USER': "test"} - ) + response = self.client.get("/api/v1/pools/invalid_pool", environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 404 self.assertEqual( { @@ -230,9 +218,7 @@ def test_should_raises_401_unauthenticated(self, session): assert_401(response) # Should still exists - response = self.client.get( - f"/api/v1/pools/{pool_name}", environ_overrides={'REMOTE_USER': "test"} - ) + response = self.client.get(f"/api/v1/pools/{pool_name}", environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 200 @@ -241,7 +227,7 @@ def test_response_200(self): response = self.client.post( "api/v1/pools", json={"name": "test_pool_a", "slots": 3}, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) assert response.status_code == 200 
self.assertEqual( @@ -265,7 +251,7 @@ def test_response_409(self, session): response = self.client.post( "api/v1/pools", json={"name": "test_pool_a", "slots": 3}, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) assert response.status_code == 409 self.assertEqual( @@ -281,11 +267,7 @@ def test_response_409(self, session): @parameterized.expand( [ ("for missing pool name", {"slots": 3}, "'name' is a required property",), - ( - "for missing slots", - {"name": "invalid_pool"}, - "'slots' is a required property", - ), + ("for missing slots", {"name": "invalid_pool"}, "'slots' is a required property",), ( "for extra fields", {"name": "invalid_pool", "slots": 3, "extra_field_1": "extra"}, @@ -296,26 +278,16 @@ def test_response_409(self, session): def test_response_400(self, name, request_json, error_detail): del name response = self.client.post( - "api/v1/pools", - json=request_json, - environ_overrides={'REMOTE_USER': "test"} + "api/v1/pools", json=request_json, environ_overrides={'REMOTE_USER': "test"} ) assert response.status_code == 400 self.assertDictEqual( - { - "detail": error_detail, - "status": 400, - "title": "Bad request", - "type": "about:blank", - }, + {"detail": error_detail, "status": 400, "title": "Bad request", "type": "about:blank",}, response.json, ) def test_should_raises_401_unauthenticated(self): - response = self.client.post( - "api/v1/pools", - json={"name": "test_pool_a", "slots": 3} - ) + response = self.client.post("api/v1/pools", json={"name": "test_pool_a", "slots": 3}) assert_401(response) @@ -329,7 +301,7 @@ def test_response_200(self, session): response = self.client.patch( "api/v1/pools/test_pool", json={"name": "test_pool_a", "slots": 3}, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual(response.status_code, 200) self.assertEqual( @@ -366,12 +338,7 @@ def test_response_400(self, error_detail, request_json, session): ) assert response.status_code == 400 self.assertEqual( - { - "detail": error_detail, - "status": 400, - "title": "Bad request", - "type": "about:blank", - }, + {"detail": error_detail, "status": 400, "title": "Bad request", "type": "about:blank",}, response.json, ) @@ -381,9 +348,7 @@ def test_should_raises_401_unauthenticated(self, session): session.add(pool) session.commit() - response = self.client.patch( - "api/v1/pools/test_pool", json={"name": "test_pool_a", "slots": 3}, - ) + response = self.client.patch("api/v1/pools/test_pool", json={"name": "test_pool_a", "slots": 3},) assert_401(response) @@ -500,18 +465,11 @@ class TestPatchPoolWithUpdateMask(TestBasePoolEndpoints): "test_pool", 2, ), - ( - "api/v1/pools/test_pool?update_mask=slots", - {"slots": 2}, - "test_pool", - 2, - ), + ("api/v1/pools/test_pool?update_mask=slots", {"slots": 2}, "test_pool", 2,), ] ) @provide_session - def test_response_200( - self, url, patch_json, expected_name, expected_slots, session - ): + def test_response_200(self, url, patch_json, expected_name, expected_slots, session): pool = Pool(pool="test_pool", slots=3) session.add(pool) session.commit() @@ -567,11 +525,6 @@ def test_response_400(self, name, error_detail, url, patch_json, session): assert response.status_code == 400 self.assertEqual ( - { - "detail": error_detail, - "status": 400, - "title": "Bad Request", - "type": "about:blank", - }, + {"detail": error_detail, "status": 400, "title": "Bad Request", "type": "about:blank",}, response.json, ) diff --git 
a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index 192d368622ced..6230099ecc60d 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -41,9 +41,7 @@ def clean_db(): @classmethod def setUpClass(cls) -> None: super().setUpClass() - with conf_vars( - {("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"} - ): + with conf_vars({("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}): cls.app = app.create_app(testing=True) # type:ignore # TODO: Add new role for each view to test permission. create_user(cls.app, username="test", role="Admin") # type: ignore @@ -71,10 +69,7 @@ def tearDown(self) -> None: class TestGetTask(TestTaskEndpoint): def test_should_response_200(self): expected = { - "class_ref": { - "class_name": "DummyOperator", - "module_path": "airflow.operators.dummy_operator", - }, + "class_ref": {"class_name": "DummyOperator", "module_path": "airflow.operators.dummy_operator",}, "depends_on_past": False, "downstream_task_ids": [], "end_date": None, @@ -114,10 +109,7 @@ def test_should_response_200_serialized(self): SerializedDagModel.write_dag(self.dag) expected = { - "class_ref": { - "class_name": "DummyOperator", - "module_path": "airflow.operators.dummy_operator", - }, + "class_ref": {"class_name": "DummyOperator", "module_path": "airflow.operators.dummy_operator",}, "depends_on_past": False, "downstream_task_ids": [], "end_date": None, @@ -201,9 +193,7 @@ def test_should_response_200(self): def test_should_response_404(self): dag_id = "xxxx_not_existing" - response = self.client.get( - f"/api/v1/dags/{dag_id}/tasks", environ_overrides={'REMOTE_USER': "test"} - ) + response = self.client.get(f"/api/v1/dags/{dag_id}/tasks", environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 404 def test_should_raises_401_unauthenticated(self): diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index b66ea9f4c966d..aaf6f2deecfe1 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -27,9 +27,7 @@ class TestTaskInstanceEndpoint(unittest.TestCase): @classmethod def setUpClass(cls) -> None: super().setUpClass() - with conf_vars( - {("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"} - ): + with conf_vars({("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}): cls.app = app.create_app(testing=True) # type:ignore # TODO: Add new role for each view to test permission. 
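
In the task-endpoint hunks above, the serialized variant differs from the plain one only in seeding the serialized representation first; a sketch with names taken from the patch:

    SerializedDagModel.write_dag(dag)  # endpoint should now answer from the serialized DAG
    response = client.get(f"/api/v1/dags/{dag.dag_id}/tasks", environ_overrides={'REMOTE_USER': "test"})
    # Either path reports the operator as a class reference:
    # {"class_name": "DummyOperator", "module_path": "airflow.operators.dummy_operator"}
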
create_user(cls.app, username="test", role="Admin") # type: ignore @@ -47,7 +45,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint): def test_should_response_200(self): response = self.client.get( "/api/v1/dags/TEST_DG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_TASK_ID", - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) assert response.status_code == 200 diff --git a/tests/api_connexion/endpoints/test_variable_endpoint.py b/tests/api_connexion/endpoints/test_variable_endpoint.py index bbdb8d60e3cb2..9413a3c34e324 100644 --- a/tests/api_connexion/endpoints/test_variable_endpoint.py +++ b/tests/api_connexion/endpoints/test_variable_endpoint.py @@ -29,9 +29,7 @@ class TestVariableEndpoint(unittest.TestCase): @classmethod def setUpClass(cls) -> None: super().setUpClass() - with conf_vars( - {("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"} - ): + with conf_vars({("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}): cls.app = app.create_app(testing=True) # type:ignore # TODO: Add new role for each view to test permission. create_user(cls.app, username="test", role="Admin") # type: ignore @@ -52,9 +50,7 @@ class TestDeleteVariable(TestVariableEndpoint): def test_should_delete_variable(self): Variable.set("delete_var1", 1) # make sure variable is added - response = self.client.get( - "/api/v1/variables/delete_var1", environ_overrides={'REMOTE_USER': "test"} - ) + response = self.client.get("/api/v1/variables/delete_var1", environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 200 response = self.client.delete( @@ -63,9 +59,7 @@ def test_should_delete_variable(self): assert response.status_code == 204 # make sure variable is deleted - response = self.client.get( - "/api/v1/variables/delete_var1", environ_overrides={'REMOTE_USER': "test"} - ) + response = self.client.get("/api/v1/variables/delete_var1", environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 404 def test_should_response_404_if_key_does_not_exist(self): @@ -77,21 +71,16 @@ def test_should_response_404_if_key_does_not_exist(self): def test_should_raises_401_unauthenticated(self): Variable.set("delete_var1", 1) # make sure variable is added - response = self.client.delete( - "/api/v1/variables/delete_var1" - ) + response = self.client.delete("/api/v1/variables/delete_var1") assert_401(response) # make sure variable is not deleted - response = self.client.get( - "/api/v1/variables/delete_var1", environ_overrides={'REMOTE_USER': "test"} - ) + response = self.client.get("/api/v1/variables/delete_var1", environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 200 class TestGetVariable(TestVariableEndpoint): - def test_should_response_200(self): expected_value = '{"foo": 1}' Variable.set("TEST_VARIABLE_KEY", expected_value) @@ -110,36 +99,34 @@ def test_should_response_404_if_not_found(self): def test_should_raises_401_unauthenticated(self): Variable.set("TEST_VARIABLE_KEY", '{"foo": 1}') - response = self.client.get( - "/api/v1/variables/TEST_VARIABLE_KEY" - ) + response = self.client.get("/api/v1/variables/TEST_VARIABLE_KEY") assert_401(response) class TestGetVariables(TestVariableEndpoint): - @parameterized.expand([ - ("/api/v1/variables?limit=2&offset=0", { - "variables": [ - {"key": "var1", "value": "1"}, - {"key": "var2", "value": "foo"}, - ], - "total_entries": 3, - }), - ("/api/v1/variables?limit=2&offset=1", { - "variables": [ - {"key": "var2", "value": "foo"}, - {"key": "var3", 
"value": "[100, 101]"}, - ], - "total_entries": 3, - }), - ("/api/v1/variables?limit=1&offset=2", { - "variables": [ - {"key": "var3", "value": "[100, 101]"}, - ], - "total_entries": 3, - }), - ]) + @parameterized.expand( + [ + ( + "/api/v1/variables?limit=2&offset=0", + { + "variables": [{"key": "var1", "value": "1"}, {"key": "var2", "value": "foo"},], + "total_entries": 3, + }, + ), + ( + "/api/v1/variables?limit=2&offset=1", + { + "variables": [{"key": "var2", "value": "foo"}, {"key": "var3", "value": "[100, 101]"},], + "total_entries": 3, + }, + ), + ( + "/api/v1/variables?limit=1&offset=2", + {"variables": [{"key": "var3", "value": "[100, 101]"},], "total_entries": 3,}, + ), + ] + ) def test_should_get_list_variables(self, query, expected): Variable.set("var1", 1) Variable.set("var2", "foo") @@ -160,9 +147,7 @@ def test_should_respect_page_size_limit_default(self): def test_should_return_conf_max_if_req_max_above_conf(self): for i in range(200): Variable.set(f"var{i}", i) - response = self.client.get( - "/api/v1/variables?limit=180", environ_overrides={'REMOTE_USER': "test"} - ) + response = self.client.get("/api/v1/variables?limit=180", environ_overrides={'REMOTE_USER': "test"}) assert response.status_code == 200 self.assertEqual(len(response.json['variables']), 150) @@ -179,11 +164,8 @@ def test_should_update_variable(self): Variable.set("var1", "foo") response = self.client.patch( "/api/v1/variables/var1", - json={ - "key": "var1", - "value": "updated", - }, - environ_overrides={'REMOTE_USER': "test"} + json={"key": "var1", "value": "updated",}, + environ_overrides={'REMOTE_USER': "test"}, ) assert response.status_code == 204 response = self.client.get("/api/v1/variables/var1", environ_overrides={'REMOTE_USER': "test"}) @@ -196,11 +178,8 @@ def test_should_reject_invalid_update(self): Variable.set("var1", "foo") response = self.client.patch( "/api/v1/variables/var1", - json={ - "key": "var2", - "value": "updated", - }, - environ_overrides={'REMOTE_USER': "test"} + json={"key": "var2", "value": "updated",}, + environ_overrides={'REMOTE_USER': "test"}, ) assert response.status_code == 400 assert response.json == { @@ -211,11 +190,7 @@ def test_should_reject_invalid_update(self): } response = self.client.patch( - "/api/v1/variables/var1", - json={ - "key": "var2", - }, - environ_overrides={'REMOTE_USER': "test"} + "/api/v1/variables/var1", json={"key": "var2",}, environ_overrides={'REMOTE_USER': "test"} ) assert response.json == { "title": "Invalid Variable schema", @@ -227,13 +202,7 @@ def test_should_reject_invalid_update(self): def test_should_raises_401_unauthenticated(self): Variable.set("var1", "foo") - response = self.client.patch( - "/api/v1/variables/var1", - json={ - "key": "var1", - "value": "updated", - }, - ) + response = self.client.patch("/api/v1/variables/var1", json={"key": "var1", "value": "updated",},) assert_401(response) @@ -242,16 +211,11 @@ class TestPostVariables(TestVariableEndpoint): def test_should_create_variable(self): response = self.client.post( "/api/v1/variables", - json={ - "key": "var_create", - "value": "{}", - }, - environ_overrides={'REMOTE_USER': "test"} + json={"key": "var_create", "value": "{}",}, + environ_overrides={'REMOTE_USER': "test"}, ) assert response.status_code == 200 - response = self.client.get( - "/api/v1/variables/var_create", environ_overrides={'REMOTE_USER': "test"} - ) + response = self.client.get("/api/v1/variables/var_create", environ_overrides={'REMOTE_USER': "test"}) assert response.json == { "key": "var_create", "value": 
"{}", @@ -260,11 +224,8 @@ def test_should_create_variable(self): def test_should_reject_invalid_request(self): response = self.client.post( "/api/v1/variables", - json={ - "key": "var_create", - "v": "{}", - }, - environ_overrides={'REMOTE_USER': "test"} + json={"key": "var_create", "v": "{}",}, + environ_overrides={'REMOTE_USER': "test"}, ) assert response.status_code == 400 assert response.json == { @@ -275,12 +236,6 @@ def test_should_reject_invalid_request(self): } def test_should_raises_401_unauthenticated(self): - response = self.client.post( - "/api/v1/variables", - json={ - "key": "var_create", - "value": "{}", - }, - ) + response = self.client.post("/api/v1/variables", json={"key": "var_create", "value": "{}",},) assert_401(response) diff --git a/tests/api_connexion/endpoints/test_version_endpoint.py b/tests/api_connexion/endpoints/test_version_endpoint.py index f890e3f62c094..b898178ed22d9 100644 --- a/tests/api_connexion/endpoints/test_version_endpoint.py +++ b/tests/api_connexion/endpoints/test_version_endpoint.py @@ -25,17 +25,13 @@ class TestGetHealthTest(unittest.TestCase): @classmethod def setUpClass(cls) -> None: super().setUpClass() - with conf_vars( - {("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"} - ): + with conf_vars({("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}): cls.app = app.create_app(testing=True) # type:ignore def setUp(self) -> None: self.client = self.app.test_client() # type:ignore - @mock.patch( - "airflow.api_connexion.endpoints.version_endpoint.airflow.__version__", "MOCK_VERSION" - ) + @mock.patch("airflow.api_connexion.endpoints.version_endpoint.airflow.__version__", "MOCK_VERSION") @mock.patch( "airflow.api_connexion.endpoints.version_endpoint.get_airflow_git_version", return_value="GIT_COMMIT" ) diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py index f5496fd73fa75..506b9100749eb 100644 --- a/tests/api_connexion/endpoints/test_xcom_endpoint.py +++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py @@ -32,9 +32,7 @@ class TestXComEndpoint(unittest.TestCase): @classmethod def setUpClass(cls) -> None: super().setUpClass() - with conf_vars( - {("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"} - ): + with conf_vars({("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}): cls.app = app.create_app(testing=True) # type:ignore # TODO: Add new role for each view to test permission. 
create_user(cls.app, username="test", role="Admin") # type: ignore @@ -64,7 +62,6 @@ def tearDown(self) -> None: class TestGetXComEntry(TestXComEndpoint): - def test_should_response_200(self): dag_id = 'test-dag-id' task_id = 'test-task-id' @@ -75,7 +72,7 @@ def test_should_response_200(self): self._create_xcom_entry(dag_id, dag_run_id, execution_date_parsed, task_id, xcom_key) response = self.client.get( f"/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}", - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual(200, response.status_code) @@ -88,8 +85,8 @@ def test_should_response_200(self): 'execution_date': execution_date, 'key': xcom_key, 'task_id': task_id, - 'timestamp': 'TIMESTAMP' - } + 'timestamp': 'TIMESTAMP', + }, ) def test_should_raises_401_unauthenticated(self): @@ -109,16 +106,16 @@ def test_should_raises_401_unauthenticated(self): @provide_session def _create_xcom_entry(self, dag_id, dag_run_id, execution_date, task_id, xcom_key, session=None): - XCom.set(key=xcom_key, - value="TEST_VALUE", - execution_date=execution_date, - task_id=task_id, - dag_id=dag_id,) - dagrun = DR(dag_id=dag_id, - run_id=dag_run_id, - execution_date=execution_date, - start_date=execution_date, - run_type=DagRunType.MANUAL.value) + XCom.set( + key=xcom_key, value="TEST_VALUE", execution_date=execution_date, task_id=task_id, dag_id=dag_id, + ) + dagrun = DR( + dag_id=dag_id, + run_id=dag_run_id, + execution_date=execution_date, + start_date=execution_date, + run_type=DagRunType.MANUAL.value, + ) session.add(dagrun) @@ -133,7 +130,7 @@ def test_should_response_200(self): self._create_xcom_entries(dag_id, dag_run_id, execution_date_parsed, task_id) response = self.client.get( f"/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries", - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual(200, response.status_code) @@ -149,18 +146,18 @@ def test_should_response_200(self): 'execution_date': execution_date, 'key': 'test-xcom-key-1', 'task_id': task_id, - 'timestamp': "TIMESTAMP" + 'timestamp': "TIMESTAMP", }, { 'dag_id': dag_id, 'execution_date': execution_date, 'key': 'test-xcom-key-2', 'task_id': task_id, - 'timestamp': "TIMESTAMP" - } + 'timestamp': "TIMESTAMP", + }, ], 'total_entries': 2, - } + }, ) def test_should_raises_401_unauthenticated(self): @@ -187,16 +184,17 @@ def _create_xcom_entries(self, dag_id, dag_run_id, execution_date, task_id, sess task_id=task_id, dag_id=dag_id, ) - dagrun = DR(dag_id=dag_id, - run_id=dag_run_id, - execution_date=execution_date, - start_date=execution_date, - run_type=DagRunType.MANUAL.value) + dagrun = DR( + dag_id=dag_id, + run_id=dag_run_id, + execution_date=execution_date, + start_date=execution_date, + run_type=DagRunType.MANUAL.value, + ) session.add(dagrun) class TestPaginationGetXComEntries(TestXComEndpoint): - def setUp(self): super().setUp() self.dag_id = 'test-dag-id' @@ -207,23 +205,11 @@ def setUp(self): @parameterized.expand( [ - ( - "limit=1", - ["TEST_XCOM_KEY1"], - ), - ( - "limit=2", - ["TEST_XCOM_KEY1", "TEST_XCOM_KEY10"], - ), + ("limit=1", ["TEST_XCOM_KEY1"],), + ("limit=2", ["TEST_XCOM_KEY1", "TEST_XCOM_KEY10"],), ( "offset=5", - [ - "TEST_XCOM_KEY5", - "TEST_XCOM_KEY6", - "TEST_XCOM_KEY7", - "TEST_XCOM_KEY8", - "TEST_XCOM_KEY9", - ] + ["TEST_XCOM_KEY5", "TEST_XCOM_KEY6", "TEST_XCOM_KEY7", "TEST_XCOM_KEY8", "TEST_XCOM_KEY9",], ), ( "offset=0", @@ -237,35 +223,27 @@ def 
                 "TEST_XCOM_KEY6",
                 "TEST_XCOM_KEY7",
                 "TEST_XCOM_KEY8",
-                "TEST_XCOM_KEY9"
-            ]
-        ),
-        (
-            "limit=1&offset=5",
-            ["TEST_XCOM_KEY5"],
-        ),
-        (
-            "limit=1&offset=1",
-            ["TEST_XCOM_KEY10"],
-        ),
-        (
-            "limit=2&offset=2",
-            ["TEST_XCOM_KEY2", "TEST_XCOM_KEY3"],
+                "TEST_XCOM_KEY9",
+            ],
             ),
+            ("limit=1&offset=5", ["TEST_XCOM_KEY5"],),
+            ("limit=1&offset=1", ["TEST_XCOM_KEY10"],),
+            ("limit=2&offset=2", ["TEST_XCOM_KEY2", "TEST_XCOM_KEY3"],),
         ]
     )
     @provide_session
     def test_handle_limit_offset(self, query_params, expected_xcom_ids, session):
         url = "/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries?{query_params}"
-        url = url.format(dag_id=self.dag_id,
-                         dag_run_id=self.dag_run_id,
-                         task_id=self.task_id,
-                         query_params=query_params)
-        dagrun = DR(dag_id=self.dag_id,
-                    run_id=self.dag_run_id,
-                    execution_date=self.execution_date_parsed,
-                    start_date=self.execution_date_parsed,
-                    run_type=DagRunType.MANUAL.value)
+        url = url.format(
+            dag_id=self.dag_id, dag_run_id=self.dag_run_id, task_id=self.task_id, query_params=query_params
+        )
+        dagrun = DR(
+            dag_id=self.dag_id,
+            run_id=self.dag_run_id,
+            execution_date=self.execution_date_parsed,
+            start_date=self.execution_date_parsed,
+            run_type=DagRunType.MANUAL.value,
+        )
         xcom_models = self._create_xcoms(10)
         session.add_all(xcom_models)
         session.add(dagrun)
@@ -277,10 +255,13 @@ def test_handle_limit_offset(self, query_params, expected_xcom_ids, session):
         self.assertEqual(conn_ids, expected_xcom_ids)

     def _create_xcoms(self, count):
-        return [XCom(
-            key=f'TEST_XCOM_KEY{i}',
-            execution_date=self.execution_date_parsed,
-            task_id=self.task_id,
-            dag_id=self.dag_id,
-            timestamp=self.execution_date_parsed,
-        ) for i in range(1, count + 1)]
+        return [
+            XCom(
+                key=f'TEST_XCOM_KEY{i}',
+                execution_date=self.execution_date_parsed,
+                task_id=self.task_id,
+                dag_id=self.dag_id,
+                timestamp=self.execution_date_parsed,
+            )
+            for i in range(1, count + 1)
+        ]
diff --git a/tests/api_connexion/schemas/test_common_schema.py b/tests/api_connexion/schemas/test_common_schema.py
index 618989dfca340..8ea8f7e57ceac 100644
--- a/tests/api_connexion/schemas/test_common_schema.py
+++ b/tests/api_connexion/schemas/test_common_schema.py
@@ -21,7 +21,11 @@
 from dateutil import relativedelta

 from airflow.api_connexion.schemas.common_schema import (
-    CronExpression, CronExpressionSchema, RelativeDeltaSchema, ScheduleIntervalSchema, TimeDeltaSchema,
+    CronExpression,
+    CronExpressionSchema,
+    RelativeDeltaSchema,
+    ScheduleIntervalSchema,
+    TimeDeltaSchema,
 )


@@ -30,10 +34,7 @@ def test_should_serialize(self):
         instance = datetime.timedelta(days=12)
         schema_instance = TimeDeltaSchema()
         result = schema_instance.dump(instance)
-        self.assertEqual(
-            {"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0},
-            result
-        )
+        self.assertEqual({"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0}, result)

     def test_should_deserialize(self):
         instance = {"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0}
@@ -92,10 +93,7 @@ def test_should_serialize_timedelta(self):
         instance = datetime.timedelta(days=12)
         schema_instance = ScheduleIntervalSchema()
         result = schema_instance.dump(instance)
-        self.assertEqual(
-            {"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0},
-            result
-        )
+        self.assertEqual({"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0}, result)

     def test_should_deserialize_timedelta(self):
         instance = {"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0}
diff --git a/tests/api_connexion/schemas/test_config_schema.py b/tests/api_connexion/schemas/test_config_schema.py
index 473dabc37a0a4..423ee1cb0e8ad 100644
--- a/tests/api_connexion/schemas/test_config_schema.py
+++ b/tests/api_connexion/schemas/test_config_schema.py
@@ -27,14 +27,9 @@ def test_serialize(self):
                     options=[
                         ConfigOption(key='apache', value='airflow'),
                         ConfigOption(key='hello', value='world'),
-                    ]
-                ),
-                ConfigSection(
-                    name='sec2',
-                    options=[
-                        ConfigOption(key='foo', value='bar'),
-                    ]
+                    ],
                 ),
+                ConfigSection(name='sec2', options=[ConfigOption(key='foo', value='bar'),]),
             ]
         )
         result = config_schema.dump(config)
@@ -42,17 +37,9 @@ def test_serialize(self):
             'sections': [
                 {
                     'name': 'sec1',
-                    'options': [
-                        {'key': 'apache', 'value': 'airflow'},
-                        {'key': 'hello', 'value': 'world'},
-                    ]
-                },
-                {
-                    'name': 'sec2',
-                    'options': [
-                        {'key': 'foo', 'value': 'bar'},
-                    ]
+                    'options': [{'key': 'apache', 'value': 'airflow'}, {'key': 'hello', 'value': 'world'},],
                 },
+                {'name': 'sec2', 'options': [{'key': 'foo', 'value': 'bar'},]},
             ]
         }
         assert result == expected
diff --git a/tests/api_connexion/schemas/test_connection_schema.py b/tests/api_connexion/schemas/test_connection_schema.py
index dea887fb50c9a..c76abe3ad1004 100644
--- a/tests/api_connexion/schemas/test_connection_schema.py
+++ b/tests/api_connexion/schemas/test_connection_schema.py
@@ -20,7 +20,10 @@
 import marshmallow

 from airflow.api_connexion.schemas.connection_schema import (
-    ConnectionCollection, connection_collection_item_schema, connection_collection_schema, connection_schema,
+    ConnectionCollection,
+    connection_collection_item_schema,
+    connection_collection_schema,
+    connection_schema,
 )
 from airflow.models import Connection
 from airflow.utils.session import create_session, provide_session
@@ -28,7 +31,6 @@


 class TestConnectionCollectionItemSchema(unittest.TestCase):
-
     def setUp(self) -> None:
         with create_session() as session:
             session.query(Connection).delete()
@@ -44,7 +46,7 @@ def test_serialize(self, session):
             host='mysql',
             login='login',
             schema='testschema',
-            port=80
+            port=80,
         )
         session.add(connection_model)
         session.commit()
@@ -58,8 +60,8 @@ def test_serialize(self, session):
                 'host': 'mysql',
                 'login': 'login',
                 'schema': 'testschema',
-                'port': 80
-            }
+                'port': 80,
+            },
         )

     def test_deserialize(self):
@@ -69,7 +71,7 @@ def test_deserialize(self):
             'host': 'mysql',
             'login': 'login',
             'schema': 'testschema',
-            'port': 80
+            'port': 80,
         }
         connection_dump_2 = {
             'connection_id': "mysql_default_2",
@@ -86,16 +88,10 @@ def test_deserialize(self):
                 'host': 'mysql',
                 'login': 'login',
                 'schema': 'testschema',
-                'port': 80
-            }
-        )
-        self.assertEqual(
-            result_2,
-            {
-                'conn_id': "mysql_default_2",
-                'conn_type': "postgres",
-            }
+                'port': 80,
+            },
         )
+        self.assertEqual(result_2, {'conn_id': "mysql_default_2", 'conn_type': "postgres",})

     def test_deserialize_required_fields(self):
         connection_dump_1 = {
@@ -103,13 +99,12 @@ def test_deserialize_required_fields(self):
         }
         with self.assertRaisesRegex(
             marshmallow.exceptions.ValidationError,
-            re.escape("{'conn_type': ['Missing data for required field.']}")
+            re.escape("{'conn_type': ['Missing data for required field.']}"),
         ):
             connection_collection_item_schema.load(connection_dump_1)


 class TestConnectionCollectionSchema(unittest.TestCase):
-
     def setUp(self) -> None:
         with create_session() as session:
             session.query(Connection).delete()
@@ -119,21 +114,12 @@ def tearDown(self) -> None:

     @provide_session
     def test_serialize(self, session):
-        connection_model_1 = Connection(
-            conn_id='mysql_default_1',
-            conn_type='test-type'
-        )
-        connection_model_2 = Connection(
-            conn_id='mysql_default_2',
-            conn_type='test-type2'
-        )
+        connection_model_1 = Connection(conn_id='mysql_default_1', conn_type='test-type')
+        connection_model_2 = Connection(conn_id='mysql_default_2', conn_type='test-type2')
         connections = [connection_model_1, connection_model_2]
         session.add_all(connections)
         session.commit()
-        instance = ConnectionCollection(
-            connections=connections,
-            total_entries=2
-        )
+        instance = ConnectionCollection(connections=connections, total_entries=2)
         deserialized_connections = connection_collection_schema.dump(instance)
         self.assertEqual(
             deserialized_connections,
@@ -145,7 +131,7 @@ def test_serialize(self, session):
                         "host": None,
                         "login": None,
                         'schema': None,
-                        'port': None
+                        'port': None,
                     },
                     {
                         "connection_id": "mysql_default_2",
@@ -153,16 +139,15 @@ def test_serialize(self, session):
                         "host": None,
                         "login": None,
                         'schema': None,
-                        'port': None
-                    }
+                        'port': None,
+                    },
                 ],
-                'total_entries': 2
-            }
+                'total_entries': 2,
+            },
         )


 class TestConnectionSchema(unittest.TestCase):
-
     def setUp(self) -> None:
         with create_session() as session:
             session.query(Connection).delete()
@@ -180,7 +165,7 @@ def test_serialize(self, session):
             schema='testschema',
             port=80,
             password='test-password',
-            extra="{'key':'string'}"
+            extra="{'key':'string'}",
         )
         session.add(connection_model)
         session.commit()
@@ -195,8 +180,8 @@ def test_serialize(self, session):
                 'login': 'login',
                 'schema': 'testschema',
                 'port': 80,
-                'extra': "{'key':'string'}"
-            }
+                'extra': "{'key':'string'}",
+            },
         )

     def test_deserialize(self):
@@ -207,7 +192,7 @@ def test_deserialize(self):
             'login': 'login',
             'schema': 'testschema',
             'port': 80,
-            'extra': "{'key':'string'}"
+            'extra': "{'key':'string'}",
         }
         result = connection_schema.load(den)
         self.assertEqual(
@@ -219,6 +204,6 @@ def test_deserialize(self):
                 'login': 'login',
                 'schema': 'testschema',
                 'port': 80,
-                'extra': "{'key':'string'}"
-            }
+                'extra': "{'key':'string'}",
+            },
         )
diff --git a/tests/api_connexion/schemas/test_dag_run_schema.py b/tests/api_connexion/schemas/test_dag_run_schema.py
index de8019ad389e6..addc2cb86bb16 100644
--- a/tests/api_connexion/schemas/test_dag_run_schema.py
+++ b/tests/api_connexion/schemas/test_dag_run_schema.py
@@ -20,7 +20,9 @@
 from parameterized import parameterized

 from airflow.api_connexion.schemas.dag_run_schema import (
-    DAGRunCollection, dagrun_collection_schema, dagrun_schema,
+    DAGRunCollection,
+    dagrun_collection_schema,
+    dagrun_schema,
 )
 from airflow.models import DagRun
 from airflow.utils import timezone
@@ -76,16 +78,8 @@ def test_serialze(self, session):
                 {"run_id": "my-dag-run", "execution_date": parse(DEFAULT_TIME)},
             ),
             (
-                {
-                    "dag_run_id": "my-dag-run",
-                    "execution_date": DEFAULT_TIME,
-                    "conf": {"start": "stop"},
-                },
-                {
-                    "run_id": "my-dag-run",
-                    "execution_date": parse(DEFAULT_TIME),
-                    "conf": {"start": "stop"},
-                },
+                {"dag_run_id": "my-dag-run", "execution_date": DEFAULT_TIME, "conf": {"start": "stop"},},
+                {"run_id": "my-dag-run", "execution_date": parse(DEFAULT_TIME), "conf": {"start": "stop"},},
             ),
         ]
     )
@@ -98,8 +92,7 @@ def test_autofill_fields(self):
         serialized_dagrun = {}
         result = dagrun_schema.load(serialized_dagrun)
         self.assertDictEqual(
-            result,
-            {"execution_date": result["execution_date"], "run_id": result["run_id"]},
+            result, {"execution_date": result["execution_date"], "run_id": result["run_id"]},
         )
diff --git a/tests/api_connexion/schemas/test_dag_schema.py b/tests/api_connexion/schemas/test_dag_schema.py
index 1fb0f0b30f653..fdc9ad2b7236a 100644
--- a/tests/api_connexion/schemas/test_dag_schema.py
+++ b/tests/api_connexion/schemas/test_dag_schema.py
@@ -20,7 +20,10 @@
 from airflow import DAG
 from airflow.api_connexion.schemas.dag_schema import (
-    DAGCollection, DAGCollectionSchema, DAGDetailSchema, DAGSchema,
+    DAGCollection,
+    DAGCollectionSchema,
+    DAGDetailSchema,
+    DAGSchema,
 )
 from airflow.models import DagModel, DagTag
@@ -119,6 +122,6 @@ def test_serialize(self):
             'schedule_interval': {'__type': 'TimeDelta', 'days': 1, 'seconds': 0, 'microseconds': 0},
             'start_date': '2020-06-19T00:00:00+00:00',
             'tags': None,
-            'timezone': "Timezone('UTC')"
+            'timezone': "Timezone('UTC')",
         }
         assert schema.dump(dag) == expected
diff --git a/tests/api_connexion/schemas/test_error_schema.py b/tests/api_connexion/schemas/test_error_schema.py
index 521e3e6a76948..cf7d555940eb0 100644
--- a/tests/api_connexion/schemas/test_error_schema.py
+++ b/tests/api_connexion/schemas/test_error_schema.py
@@ -17,7 +17,9 @@
 import unittest

 from airflow.api_connexion.schemas.error_schema import (
-    ImportErrorCollection, import_error_collection_schema, import_error_schema,
+    ImportErrorCollection,
+    import_error_collection_schema,
+    import_error_schema,
 )
 from airflow.models.errors import ImportError  # pylint: disable=redefined-builtin
 from airflow.utils import timezone
diff --git a/tests/api_connexion/schemas/test_event_log_schema.py b/tests/api_connexion/schemas/test_event_log_schema.py
index 8e58c923722a1..c7da442449cc3 100644
--- a/tests/api_connexion/schemas/test_event_log_schema.py
+++ b/tests/api_connexion/schemas/test_event_log_schema.py
@@ -19,7 +19,9 @@
 from airflow import DAG
 from airflow.api_connexion.schemas.event_log_schema import (
-    EventLogCollection, event_log_collection_schema, event_log_schema,
+    EventLogCollection,
+    event_log_collection_schema,
+    event_log_schema,
 )
 from airflow.models import Log, TaskInstance
 from airflow.operators.dummy_operator import DummyOperator
@@ -28,7 +30,6 @@


 class TestEventLogSchemaBase(unittest.TestCase):
-
     def setUp(self) -> None:
         with create_session() as session:
             session.query(Log).delete()
@@ -40,20 +41,19 @@ def tearDown(self) -> None:
             session.query(Log).delete()

     def _create_task_instance(self):
-        with DAG('TEST_DAG_ID', start_date=timezone.parse(self.default_time),
-                 end_date=timezone.parse(self.default_time)):
+        with DAG(
+            'TEST_DAG_ID',
+            start_date=timezone.parse(self.default_time),
+            end_date=timezone.parse(self.default_time),
+        ):
             op1 = DummyOperator(task_id="TEST_TASK_ID", owner="airflow")
         return TaskInstance(task=op1, execution_date=timezone.parse(self.default_time))


 class TestEventLogSchema(TestEventLogSchemaBase):
-
     @provide_session
     def test_serialize(self, session):
-        event_log_model = Log(
-            event="TEST_EVENT",
-            task_instance=self._create_task_instance()
-        )
+        event_log_model = Log(event="TEST_EVENT", task_instance=self._create_task_instance())
         session.add(event_log_model)
         session.commit()
         event_log_model.dttm = timezone.parse(self.default_time)
@@ -69,37 +69,28 @@ def test_serialize(self, session):
                 "execution_date": self.default_time,
                 "owner": 'airflow',
                 "when": self.default_time,
-                "extra": None
-            }
+                "extra": None,
+            },
         )


 class TestEventLogCollection(TestEventLogSchemaBase):
-
     @provide_session
     def test_serialize(self, session):
-        event_log_model_1 = Log(
-            event="TEST_EVENT_1",
-            task_instance=self._create_task_instance()
-        )
-        event_log_model_2 = Log(
-            event="TEST_EVENT_2",
-            task_instance=self._create_task_instance()
-        )
Log(event="TEST_EVENT_1", task_instance=self._create_task_instance()) + event_log_model_2 = Log(event="TEST_EVENT_2", task_instance=self._create_task_instance()) event_logs = [event_log_model_1, event_log_model_2] session.add_all(event_logs) session.commit() event_log_model_1.dttm = timezone.parse(self.default_time) event_log_model_2.dttm = timezone.parse(self.default_time2) - instance = EventLogCollection(event_logs=event_logs, - total_entries=2) + instance = EventLogCollection(event_logs=event_logs, total_entries=2) deserialized_event_logs = event_log_collection_schema.dump(instance) self.assertEqual( deserialized_event_logs, { "event_logs": [ { - "event_log_id": event_log_model_1.id, "event": "TEST_EVENT_1", "dag_id": "TEST_DAG_ID", @@ -107,8 +98,7 @@ def test_serialize(self, session): "execution_date": self.default_time, "owner": 'airflow', "when": self.default_time, - "extra": None - + "extra": None, }, { "event_log_id": event_log_model_2.id, @@ -118,9 +108,9 @@ def test_serialize(self, session): "execution_date": self.default_time, "owner": 'airflow', "when": self.default_time2, - "extra": None - } + "extra": None, + }, ], - "total_entries": 2 - } + "total_entries": 2, + }, ) diff --git a/tests/api_connexion/schemas/test_health_schema.py b/tests/api_connexion/schemas/test_health_schema.py index e7e1ff6336efc..0a5bda46146e9 100644 --- a/tests/api_connexion/schemas/test_health_schema.py +++ b/tests/api_connexion/schemas/test_health_schema.py @@ -26,10 +26,7 @@ def setUp(self): def test_serialize(self): payload = { "metadatabase": {"status": "healthy"}, - "scheduler": { - "status": "healthy", - "latest_scheduler_heartbeat": self.default_datetime, - }, + "scheduler": {"status": "healthy", "latest_scheduler_heartbeat": self.default_datetime,}, } serialized_data = health_schema.dump(payload) self.assertDictEqual(serialized_data, payload) diff --git a/tests/api_connexion/schemas/test_task_schema.py b/tests/api_connexion/schemas/test_task_schema.py index 96ad28a487ecd..534712ed065e2 100644 --- a/tests/api_connexion/schemas/test_task_schema.py +++ b/tests/api_connexion/schemas/test_task_schema.py @@ -24,16 +24,11 @@ class TestTaskSchema: def test_serialize(self): op = DummyOperator( - task_id="task_id", - start_date=datetime(2020, 6, 16), - end_date=datetime(2020, 6, 26), + task_id="task_id", start_date=datetime(2020, 6, 16), end_date=datetime(2020, 6, 26), ) result = task_schema.dump(op) expected = { - "class_ref": { - "module_path": "airflow.operators.dummy_operator", - "class_name": "DummyOperator", - }, + "class_ref": {"module_path": "airflow.operators.dummy_operator", "class_name": "DummyOperator",}, "depends_on_past": False, "downstream_task_ids": [], "end_date": "2020-06-26T00:00:00+00:00", diff --git a/tests/api_connexion/schemas/test_version_schema.py b/tests/api_connexion/schemas/test_version_schema.py index b5f7b2476e39d..d5b49fac7bf69 100644 --- a/tests/api_connexion/schemas/test_version_schema.py +++ b/tests/api_connexion/schemas/test_version_schema.py @@ -24,11 +24,9 @@ class TestVersionInfoSchema(unittest.TestCase): - - @parameterized.expand([ - ("GIT_COMMIT", ), - (None, ), - ]) + @parameterized.expand( + [("GIT_COMMIT",), (None,),] + ) def test_serialize(self, git_commit): version_info = VersionInfo("VERSION", git_commit) current_data = version_info_schema.dump(version_info) diff --git a/tests/api_connexion/schemas/test_xcom_schema.py b/tests/api_connexion/schemas/test_xcom_schema.py index ef079834b9ba0..8e2f514ab1f7d 100644 --- 
--- a/tests/api_connexion/schemas/test_xcom_schema.py
+++ b/tests/api_connexion/schemas/test_xcom_schema.py
@@ -19,7 +19,10 @@
 from sqlalchemy import or_

 from airflow.api_connexion.schemas.xcom_schema import (
-    XComCollection, xcom_collection_item_schema, xcom_collection_schema, xcom_schema,
+    XComCollection,
+    xcom_collection_item_schema,
+    xcom_collection_schema,
+    xcom_schema,
 )
 from airflow.models import XCom
 from airflow.utils.dates import parse_execution_date
@@ -27,7 +30,6 @@


 class TestXComSchemaBase(unittest.TestCase):
-
     def setUp(self):
         """
         Clear Hanging XComs pre test
@@ -44,7 +46,6 @@ def tearDown(self) -> None:


 class TestXComCollectionItemSchema(TestXComSchemaBase):
-
     def setUp(self) -> None:
         super().setUp()
         self.default_time = '2005-04-02T21:00:00+00:00'
@@ -71,7 +72,7 @@ def test_serialize(self, session):
                 'execution_date': self.default_time,
                 'task_id': 'test_task_id',
                 'dag_id': 'test_dag',
-            }
+            },
         )

     def test_deserialize(self):
@@ -91,12 +92,11 @@ def test_deserialize(self):
                 'execution_date': self.default_time_parsed,
                 'task_id': 'test_task_id',
                 'dag_id': 'test_dag',
-            }
+            },
         )


 class TestXComCollectionSchema(TestXComSchemaBase):
-
     def setUp(self) -> None:
         super().setUp()
         self.default_time_1 = '2005-04-02T21:00:00+00:00'
@@ -127,10 +127,9 @@ def test_serialize(self, session):
             or_(XCom.execution_date == self.time_1, XCom.execution_date == self.time_2)
         )
         xcom_models_queried = xcom_models_query.all()
-        deserialized_xcoms = xcom_collection_schema.dump(XComCollection(
-            xcom_entries=xcom_models_queried,
-            total_entries=xcom_models_query.count(),
-        ))
+        deserialized_xcoms = xcom_collection_schema.dump(
+            XComCollection(xcom_entries=xcom_models_queried, total_entries=xcom_models_query.count(),)
+        )
         self.assertEqual(
             deserialized_xcoms,
             {
@@ -148,15 +147,14 @@ def test_serialize(self, session):
                         'execution_date': self.default_time_2,
                         'task_id': 'test_task_id_2',
                         'dag_id': 'test_dag_2',
-                    }
+                    },
                 ],
                 'total_entries': len(xcom_models),
-            }
+            },
         )


 class TestXComSchema(TestXComSchemaBase):
-
     def setUp(self) -> None:
         super().setUp()
         self.default_time = '2005-04-02T21:00:00+00:00'
@@ -185,7 +183,7 @@ def test_serialize(self, session):
                 'task_id': 'test_task_id',
                 'dag_id': 'test_dag',
                 'value': 'test_binary',
-            }
+            },
         )

     def test_deserialize(self):
@@ -207,5 +205,5 @@ def test_deserialize(self):
                 'task_id': 'test_task_id',
                 'dag_id': 'test_dag',
                 'value': 'test_binary',
-            }
+            },
         )
diff --git a/tests/api_connexion/test_parameters.py b/tests/api_connexion/test_parameters.py
index 50f3f11167cc4..bc9a73f6145df 100644
--- a/tests/api_connexion/test_parameters.py
+++ b/tests/api_connexion/test_parameters.py
@@ -28,7 +28,6 @@


 class TestDateTimeParser(unittest.TestCase):
-
     def setUp(self) -> None:
         self.default_time = '2020-06-13T22:44:00+00:00'
         self.default_time_2 = '2020-06-13T22:44:00Z'
@@ -52,7 +51,6 @@ def test_raises_400_for_invalid_arg(self):


 class TestMaximumPagelimit(unittest.TestCase):
-
     @conf_vars({("api", "maximum_page_limit"): "320"})
     def test_maximum_limit_return_val(self):
         limit = check_limit(300)
@@ -80,7 +78,6 @@ def test_negative_limit_raises(self):


 class TestFormatParameters(unittest.TestCase):
-
     def test_should_works_with_datetime_formatter(self):
         decorator = format_parameters({"param_a": format_datetime})
         endpoint = mock.MagicMock()
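A note for reviewers skimming the bulk reformat above: the `json={"key": "var_create", "v": "{}",},` style in these hunks is what the pinned `stable` Black rev emits when it collapses a call that already carried trailing commas; newer Black releases instead treat such a "magic trailing comma" as a request to keep the call exploded. A minimal, self-contained sketch of the collapse under this pyproject.toml (line-length = 110, skip-string-normalization); the client class below is invented purely for illustration and is not part of the test suite:

    # Illustrative stand-in for the Flask test client used in the tests above.
    class FakeClient:
        def post(self, url, json=None):
            return {"url": url, "json": json}

    client = FakeClient()

    # Previously hand-wrapped over several lines; Black joins the call because
    # the joined form fits within 110 columns, and skip-string-normalization
    # leaves the existing quote style untouched.
    response = client.post("/api/v1/variables", json={"key": "var_create", "value": "{}"})
    print(response)
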
From 430396fc3cddc9c196a681ed44ecfa1b827efed2 Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Tue, 25 Aug 2020 10:54:22 +0100
Subject: [PATCH 5/8] fixup! Enable black on api_connextion folder and its tests

---
 .pre-commit-config.yaml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 88da0d928315c..178e772b1274d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -154,7 +154,6 @@ repos:
     hooks:
       - id: black
        files: api_connexion/.*\.py
-        args: [--config=./pyproject.toml]
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v3.2.0
     hooks:

From 34ee0c919307d7b7b32b7dbad0f47db82535eb36 Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Tue, 25 Aug 2020 11:21:43 +0100
Subject: [PATCH 6/8] fixup! fixup! Enable black on api_connextion folder and its tests

---
 .pre-commit-config.yaml |  1 +
 pyproject.toml          | 17 -----------------
 setup.cfg               |  1 +
 3 files changed, 2 insertions(+), 17 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 178e772b1274d..88da0d928315c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -154,6 +154,7 @@ repos:
     hooks:
       - id: black
         files: api_connexion/.*\.py
+        args: [--config=./pyproject.toml]
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v3.2.0
     hooks:
diff --git a/pyproject.toml b/pyproject.toml
index 920b195075a48..602897e9062ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,23 +2,6 @@
 line-length = 110
 target-version = ['py36', 'py37', 'py38']
 skip-string-normalization = true
-include = '\.pyi?$'
-exclude = '''
-/(
-    \.eggs
-  | \.build
-  | \.git
-  | \.hg
-  | \.mypy_cache
-  | \.tox
-  | \.venv
-  | _build
-  | buck-out
-  | build
-  | dist
-  | logs
-)/
-'''

 [tool.isort]
 line_length = 110
diff --git a/setup.cfg b/setup.cfg
index 24c02b03b0c5d..d33e1096f32dd 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -76,3 +76,4 @@ known_first_party=airflow,tests
 multi_line_output=5
 # Need to be consistent with the exclude config defined in pre-commit-config.yaml
 skip=build,.tox,venv
+profile = "black"
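The `profile = "black"` option staged for setup.cfg above is what makes isort agree with Black on import wrapping, and it is why the import hunks earlier in the series switch to parenthesised, one-name-per-line blocks with a trailing comma. A sketch of the convention, reusing an import block from the diff; running it assumes an Airflow checkout is importable, and the profile is roughly equivalent to setting `multi_line_output = 3`, `use_parentheses = True` and `include_trailing_comma = True` by hand:

    # How isort (profile = "black") and Black both render a long import:
    from airflow.api_connexion.schemas.connection_schema import (
        ConnectionCollection,
        connection_collection_item_schema,
        connection_collection_schema,
        connection_schema,
    )

    print(ConnectionCollection, connection_schema)
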
From f01efbedca4dedc5a439a8c6a7816426264ed49b Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Tue, 25 Aug 2020 11:29:41 +0100
Subject: [PATCH 7/8] fixup! fixup! fixup! Enable black on api_connextion folder and its tests

---
 .pre-commit-config.yaml | 2 +-
 pyproject.toml          | 5 -----
 setup.cfg               | 3 ++-
 3 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 88da0d928315c..31972006dcbcc 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -190,7 +190,7 @@ repos:
         name: Run isort to sort imports
         types: [python]
         # To keep consistent with the global isort skip config defined in setup.cfg
-        exclude: ^build/.*$|^.tox/.*$|^venv/.*$
+        exclude: ^build/.*$|^.tox/.*$|^venv/.*$|.*api_connexion/.*\.py
   - repo: https://github.com/pycqa/pydocstyle
     rev: 5.0.2
     hooks:
diff --git a/pyproject.toml b/pyproject.toml
index 602897e9062ee..8446e16145e52 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,8 +2,3 @@
 line-length = 110
 target-version = ['py36', 'py37', 'py38']
 skip-string-normalization = true
-
-[tool.isort]
-line_length = 110
-combine_as_imports = true
-profile = "black"
diff --git a/setup.cfg b/setup.cfg
index d33e1096f32dd..aac6b42ba740b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -76,4 +76,5 @@ known_first_party=airflow,tests
 multi_line_output=5
 # Need to be consistent with the exclude config defined in pre-commit-config.yaml
 skip=build,.tox,venv
-profile = "black"
+# ToDo: Enable the below before Airflow 2.0
+# profile = "black"

From 72c84d0abcb5e38ac75230366ffa8737afe39f9a Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Tue, 25 Aug 2020 11:30:08 +0100
Subject: [PATCH 8/8] fixup! fixup! fixup! fixup! Enable black on api_connextion folder and its tests

---
 tests/api_connexion/endpoints/test_log_endpoint.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py
index fb8df5a774ae9..2acacf3eb30dd 100644
--- a/tests/api_connexion/endpoints/test_log_endpoint.py
+++ b/tests/api_connexion/endpoints/test_log_endpoint.py
@@ -84,9 +84,9 @@ def _configure_loggers(self):
         logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
         logging_config['handlers']['task']['base_log_folder'] = self.log_dir

-        logging_config['handlers']['task']['filename_template'] = (
-            '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'
-        )
+        logging_config['handlers']['task'][
+            'filename_template'
+        ] = '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'

         # Write the custom logging configuration to a file
         self.settings_folder = tempfile.mkdtemp()
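One idiom from the closing hunk is worth calling out: when the right-hand side of an assignment is a single long string that cannot be split, Black moves the line break into the subscript chain rather than parenthesising the value. A standalone sketch of the same shape; the dict below is an invented stand-in for the real DEFAULT_LOGGING_CONFIG:

    # Minimal stand-in for the real logging config -- illustrative only.
    logging_config = {'handlers': {'task': {}}}

    # Black's output style: the break lands inside the subscript so the long
    # template string stays on one physical line.
    logging_config['handlers']['task'][
        'filename_template'
    ] = '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'

    print(logging_config['handlers']['task']['filename_template'])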