Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Black on Connexion API folders #10545

Merged
merged 8 commits into from
Aug 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[flake8]
max-line-length = 110
ignore = E731,W504,I001,W503
ignore = E231,E731,W504,I001,W503
exclude = .svn,CVS,.bzr,.hg,.git,__pycache__,.eggs,*.egg,node_modules
format = ${cyan}%(path)s${reset}:${yellow_bold}%(row)d${reset}:${green_bold}%(col)d${reset}: ${red_bold}%(code)s${reset} %(text)s
per-file-ignores =
Expand Down
8 changes: 7 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ repos:
- repo: meta
hooks:
- id: check-hooks-apply
- repo: https://github.com/psf/black
rev: stable
hooks:
- id: black
files: api_connexion/.*\.py
args: [--config=./pyproject.toml]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
Expand Down Expand Up @@ -184,7 +190,7 @@ repos:
name: Run isort to sort imports
types: [python]
# To keep consistent with the global isort skip config defined in setup.cfg
exclude: ^build/.*$|^.tox/.*$|^venv/.*$
exclude: ^build/.*$|^.tox/.*$|^venv/.*$|.*api_connexion/.*\.py
- repo: https://github.com/pycqa/pydocstyle
rev: 5.0.2
hooks:
Expand Down
12 changes: 5 additions & 7 deletions airflow/api_connexion/endpoints/config_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,7 @@ def _conf_dict_to_config(conf_dict: dict) -> Config:
config = Config(
sections=[
ConfigSection(
name=section,
options=[
ConfigOption(key=key, value=value)
for key, value in options.items()
]
name=section, options=[ConfigOption(key=key, value=value) for key, value in options.items()]
)
for section, options in conf_dict.items()
]
Expand All @@ -49,8 +45,10 @@ def _option_to_text(config_option: ConfigOption) -> str:

def _section_to_text(config_section: ConfigSection) -> str:
"""Convert a single config section to text"""
return (f'[{config_section.name}]{LINE_SEP}'
f'{LINE_SEP.join(_option_to_text(option) for option in config_section.options)}{LINE_SEP}')
return (
f'[{config_section.name}]{LINE_SEP}'
f'{LINE_SEP.join(_option_to_text(option) for option in config_section.options)}{LINE_SEP}'
)


def _config_to_text(config: Config) -> str:
Expand Down
14 changes: 8 additions & 6 deletions airflow/api_connexion/endpoints/connection_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
from airflow.api_connexion.parameters import check_limit, format_parameters
from airflow.api_connexion.schemas.connection_schema import (
ConnectionCollection, connection_collection_item_schema, connection_collection_schema, connection_schema,
ConnectionCollection,
connection_collection_item_schema,
connection_collection_schema,
connection_schema,
)
from airflow.models import Connection
from airflow.utils.session import provide_session
Expand Down Expand Up @@ -56,9 +59,7 @@ def get_connection(connection_id, session):


@security.requires_authentication
@format_parameters({
'limit': check_limit
})
@format_parameters({'limit': check_limit})
@provide_session
def get_connections(session, limit, offset=0):
"""
Expand All @@ -67,8 +68,9 @@ def get_connections(session, limit, offset=0):
total_entries = session.query(func.count(Connection.id)).scalar()
query = session.query(Connection)
connections = query.order_by(Connection.id).offset(offset).limit(limit).all()
return connection_collection_schema.dump(ConnectionCollection(connections=connections,
total_entries=total_entries))
return connection_collection_schema.dump(
ConnectionCollection(connections=connections, total_entries=total_entries)
)


@security.requires_authentication
Expand Down
9 changes: 5 additions & 4 deletions airflow/api_connexion/endpoints/dag_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from airflow.api_connexion.exceptions import BadRequest, NotFound
from airflow.api_connexion.parameters import check_limit, format_parameters
from airflow.api_connexion.schemas.dag_schema import (
DAGCollection, dag_detail_schema, dag_schema, dags_collection_schema,
DAGCollection,
dag_detail_schema,
dag_schema,
dags_collection_schema,
)
from airflow.models.dag import DagModel
from airflow.utils.session import provide_session
Expand Down Expand Up @@ -55,9 +58,7 @@ def get_dag_details(dag_id):


@security.requires_authentication
@format_parameters({
'limit': check_limit
})
@format_parameters({'limit': check_limit})
@provide_session
def get_dags(session, limit, offset=0):
"""
Expand Down
105 changes: 68 additions & 37 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
from airflow.api_connexion.parameters import check_limit, format_datetime, format_parameters
from airflow.api_connexion.schemas.dag_run_schema import (
DAGRunCollection, dagrun_collection_schema, dagrun_schema, dagruns_batch_form_schema,
DAGRunCollection,
dagrun_collection_schema,
dagrun_schema,
dagruns_batch_form_schema,
)
from airflow.models import DagModel, DagRun
from airflow.utils.session import provide_session
Expand All @@ -36,11 +39,7 @@ def delete_dag_run(dag_id, dag_run_id, session):
"""
Delete a DAG Run
"""
if (
session.query(DagRun)
.filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id)
.delete() == 0
):
if session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).delete() == 0:
raise NotFound(detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found")
return NoContent, 204

Expand All @@ -51,23 +50,24 @@ def get_dag_run(dag_id, dag_run_id, session):
"""
Get a DAG Run.
"""
dag_run = session.query(DagRun).filter(
DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none()
dag_run = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none()
if dag_run is None:
raise NotFound("DAGRun not found")
return dagrun_schema.dump(dag_run)


@security.requires_authentication
@format_parameters({
'start_date_gte': format_datetime,
'start_date_lte': format_datetime,
'execution_date_gte': format_datetime,
'execution_date_lte': format_datetime,
'end_date_gte': format_datetime,
'end_date_lte': format_datetime,
'limit': check_limit
})
@format_parameters(
{
'start_date_gte': format_datetime,
'start_date_lte': format_datetime,
'execution_date_gte': format_datetime,
'execution_date_lte': format_datetime,
'end_date_gte': format_datetime,
'end_date_lte': format_datetime,
'limit': check_limit,
}
)
@provide_session
def get_dag_runs(
session,
Expand All @@ -91,27 +91,52 @@ def get_dag_runs(
if dag_id != "~":
query = query.filter(DagRun.dag_id == dag_id)

dag_run, total_entries = _fetch_dag_runs(query, session, end_date_gte, end_date_lte, execution_date_gte,
execution_date_lte, start_date_gte, start_date_lte,
limit, offset)
dag_run, total_entries = _fetch_dag_runs(
query,
session,
end_date_gte,
end_date_lte,
execution_date_gte,
execution_date_lte,
start_date_gte,
start_date_lte,
limit,
offset,
)

return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run,
total_entries=total_entries))
return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run, total_entries=total_entries))


def _fetch_dag_runs(query, session, end_date_gte, end_date_lte,
execution_date_gte, execution_date_lte,
start_date_gte, start_date_lte, limit, offset):
query = _apply_date_filters_to_query(query, end_date_gte, end_date_lte, execution_date_gte,
execution_date_lte, start_date_gte, start_date_lte)
def _fetch_dag_runs(
query,
session,
end_date_gte,
end_date_lte,
execution_date_gte,
execution_date_lte,
start_date_gte,
start_date_lte,
limit,
offset,
):
query = _apply_date_filters_to_query(
query,
end_date_gte,
end_date_lte,
execution_date_gte,
execution_date_lte,
start_date_gte,
start_date_lte,
)
# apply offset and limit
dag_run = query.order_by(DagRun.id).offset(offset).limit(limit).all()
total_entries = session.query(func.count(DagRun.id)).scalar()
return dag_run, total_entries


def _apply_date_filters_to_query(query, end_date_gte, end_date_lte, execution_date_gte,
execution_date_lte, start_date_gte, start_date_lte):
def _apply_date_filters_to_query(
query, end_date_gte, end_date_lte, execution_date_gte, execution_date_lte, start_date_gte, start_date_lte
):
# filter start date
if start_date_gte:
query = query.filter(DagRun.start_date >= start_date_gte)
Expand Down Expand Up @@ -147,13 +172,20 @@ def get_dag_runs_batch(session):
if data["dag_ids"]:
query = query.filter(DagRun.dag_id.in_(data["dag_ids"]))

dag_runs, total_entries = _fetch_dag_runs(query, session, data["end_date_gte"], data["end_date_lte"],
data["execution_date_gte"], data["execution_date_lte"],
data["start_date_gte"], data["start_date_lte"],
data["page_limit"], data["page_offset"])
dag_runs, total_entries = _fetch_dag_runs(
query,
session,
data["end_date_gte"],
data["end_date_lte"],
data["execution_date_gte"],
data["execution_date_lte"],
data["start_date_gte"],
data["start_date_lte"],
data["page_limit"],
data["page_offset"],
)

return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs,
total_entries=total_entries))
return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries))


@security.requires_authentication
Expand All @@ -167,8 +199,7 @@ def post_dag_run(dag_id, session):

post_body = dagrun_schema.load(request.json, session=session)
dagrun_instance = (
session.query(DagRun).filter(
DagRun.dag_id == dag_id, DagRun.run_id == post_body["run_id"]).first()
session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == post_body["run_id"]).first()
)
if not dagrun_instance:
dag_run = DagRun(dag_id=dag_id, run_type=DagRunType.MANUAL.value, **post_body)
Expand Down
13 changes: 7 additions & 6 deletions airflow/api_connexion/endpoints/event_log_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from airflow.api_connexion.exceptions import NotFound
from airflow.api_connexion.parameters import check_limit, format_parameters
from airflow.api_connexion.schemas.event_log_schema import (
EventLogCollection, event_log_collection_schema, event_log_schema,
EventLogCollection,
event_log_collection_schema,
event_log_schema,
)
from airflow.models import Log
from airflow.utils.session import provide_session
Expand All @@ -41,9 +43,7 @@ def get_event_log(event_log_id, session):


@security.requires_authentication
@format_parameters({
'limit': check_limit
})
@format_parameters({'limit': check_limit})
@provide_session
def get_event_logs(session, limit, offset=None):
"""
Expand All @@ -52,5 +52,6 @@ def get_event_logs(session, limit, offset=None):

total_entries = session.query(func.count(Log.id)).scalar()
event_logs = session.query(Log).order_by(Log.id).offset(offset).limit(limit).all()
return event_log_collection_schema.dump(EventLogCollection(event_logs=event_logs,
total_entries=total_entries))
return event_log_collection_schema.dump(
EventLogCollection(event_logs=event_logs, total_entries=total_entries)
)
5 changes: 1 addition & 4 deletions airflow/api_connexion/endpoints/health_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@ def get_health():

payload = {
"metadatabase": {"status": metadatabase_status},
"scheduler": {
"status": scheduler_status,
"latest_scheduler_heartbeat": latest_scheduler_heartbeat,
},
"scheduler": {"status": scheduler_status, "latest_scheduler_heartbeat": latest_scheduler_heartbeat,},
}

return health_schema.dump(payload)
8 changes: 4 additions & 4 deletions airflow/api_connexion/endpoints/import_error_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from airflow.api_connexion.exceptions import NotFound
from airflow.api_connexion.parameters import check_limit, format_parameters
from airflow.api_connexion.schemas.error_schema import (
ImportErrorCollection, import_error_collection_schema, import_error_schema,
ImportErrorCollection,
import_error_collection_schema,
import_error_schema,
)
from airflow.models.errors import ImportError # pylint: disable=redefined-builtin
from airflow.utils.session import provide_session
Expand All @@ -41,9 +43,7 @@ def get_import_error(import_error_id, session):


@security.requires_authentication
@format_parameters({
'limit': check_limit
})
@format_parameters({'limit': check_limit})
@provide_session
def get_import_errors(session, limit, offset=None):
"""
Expand Down
12 changes: 3 additions & 9 deletions airflow/api_connexion/endpoints/log_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@

@security.requires_authentication
@provide_session
def get_log(session, dag_id, dag_run_id, task_id, task_try_number,
full_content=False, token=None):
def get_log(session, dag_id, dag_run_id, task_id, task_try_number, full_content=False, token=None):
"""
Get logs for specific task instance
"""
Expand Down Expand Up @@ -77,13 +76,8 @@ def get_log(session, dag_id, dag_run_id, task_id, task_try_number,
logs, metadata = task_log_reader.read_log_chunks(ti, task_try_number, metadata)
logs = logs[0] if task_try_number is not None else logs
token = URLSafeSerializer(key).dumps(metadata)
return logs_schema.dump(LogResponseObject(continuation_token=token,
content=logs)
)
return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs))
# text/plain. Stream
logs = task_log_reader.read_log_stream(ti, task_try_number, metadata)

return Response(
logs,
headers={"Content-Type": return_type}
)
return Response(logs, headers={"Content-Type": return_type})
8 changes: 2 additions & 6 deletions airflow/api_connexion/endpoints/pool_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def get_pool(pool_name, session):


@security.requires_authentication
@format_parameters({
'limit': check_limit
})
@format_parameters({'limit': check_limit})
@provide_session
def get_pools(session, limit, offset=None):
"""
Expand All @@ -65,9 +63,7 @@ def get_pools(session, limit, offset=None):

total_entries = session.query(func.count(Pool.id)).scalar()
pools = session.query(Pool).order_by(Pool.id).offset(offset).limit(limit).all()
return pool_collection_schema.dump(
PoolCollection(pools=pools, total_entries=total_entries)
)
return pool_collection_schema.dump(PoolCollection(pools=pools, total_entries=total_entries))


@security.requires_authentication
Expand Down
Loading