Enable Black - Python Auto Formatter #9550

Merged (2 commits), Nov 3, 2020
5 changes: 2 additions & 3 deletions .pre-commit-config.yaml
@@ -159,8 +159,7 @@ repos:
     rev: 20.8b1
     hooks:
       - id: black
-        files: api_connexion/.*\.py|.*providers.*\.py|^chart/tests/.*\.py
-        exclude: .*kubernetes_pod\.py|.*google/common/hooks/base_google\.py$
+        exclude: .*kubernetes_pod\.py|.*google/common/hooks/base_google\.py$|^airflow/configuration.py$
Member commented:

Any reason for this exclude?

@kaxil (Member Author) commented on Nov 3, 2020:

Yes, black had issues with it, i.e. it couldn't reproduce the same file pre- and post-black, so it would be good to ignore them for now:

❯ pre-commit run black -a
black....................................................................Failed
- hook id: black
- exit code: 123

error: cannot format /Users/kaxilnaik/Documents/Github/astronomer/airflow/airflow/configuration.py: INTERNAL ERROR: Black produced different code on the second pass of the formatter.  Please report a bug on https://github.com/psf/black/issues.  This diff might be helpful: /var/folders/jk/z68c_8nd1w5ggdc51r5zhnxm0000gn/T/blk_madaodvm.log
Oh no! 💥 💔 💥
2548 files left unchanged, 1 file failed to reformat.
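Black performs this stability check itself: after reformatting, it formats the result a second time and raises the INTERNAL ERROR above if the two outputs differ. A minimal way to reproduce it outside pre-commit, assuming black is installed and the path is adjusted to your checkout (the Mode settings here are black's defaults, not the repo's pyproject.toml config):

```python
import black

# Format the file once, then format the result again. Black is expected to
# be idempotent, so any difference is the "second pass" bug shown above.
mode = black.Mode()  # assumption: default settings, not Airflow's real config

with open("airflow/configuration.py") as f:
    source = f.read()

first_pass = black.format_str(source, mode=mode)
second_pass = black.format_str(first_pass, mode=mode)

if first_pass != second_pass:
    print("unstable: black produced different code on the second pass")
else:
    print("stable")
```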

Member commented:

Interesting 👀
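As a side note on the exclude pattern under discussion: pre-commit matches `files` and `exclude` regexes against the repo-relative path with Python's `re.search`, so the new pattern can be sanity-checked in isolation. A small sketch (the provider file path below is illustrative, not taken from the PR):

```python
import re

# The new exclude pattern from the black hook above.
EXCLUDE = r".*kubernetes_pod\.py|.*google/common/hooks/base_google\.py$|^airflow/configuration.py$"

# Example repo-relative paths; the kubernetes_pod.py location is an assumption.
for path in (
    "airflow/configuration.py",
    "airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py",
    "airflow/models/dag.py",
):
    skipped = re.search(EXCLUDE, path) is not None
    print(f"{path}: {'excluded from black' if skipped else 'formatted'}")
```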

         args: [--config=./pyproject.toml]
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v3.3.0
@@ -203,7 +202,7 @@ repos:
         name: Run isort to sort imports
         types: [python]
         # To keep consistent with the global isort skip config defined in setup.cfg
-        exclude: ^build/.*$|^.tox/.*$|^venv/.*$|.*api_connexion/.*\.py|.*providers.*\.py
+        exclude: ^build/.*$|^.tox/.*$|^venv/.*$
   - repo: https://github.com/pycqa/pydocstyle
     rev: 5.1.1
     hooks:
3 changes: 3 additions & 0 deletions airflow/__init__.py
@@ -55,15 +55,18 @@ def __getattr__(name):
     # PEP-562: Lazy loaded attributes on python modules
     if name == "DAG":
         from airflow.models.dag import DAG  # pylint: disable=redefined-outer-name
+
         return DAG
     if name == "AirflowException":
         from airflow.exceptions import AirflowException  # pylint: disable=redefined-outer-name
+
         return AirflowException
     raise AttributeError(f"module {__name__} has no attribute {name}")


 if not settings.LAZY_LOAD_PLUGINS:
     from airflow import plugins_manager
+
     plugins_manager.ensure_plugins_loaded()
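Black only adds blank lines in this hunk, but the `# PEP-562` comment is worth unpacking: a module-level `__getattr__` (Python 3.7+) defers a heavy import until the attribute is first accessed. A stripped-down sketch of the same pattern, with illustrative module and class names:

```python
# mypackage/__init__.py -- lazy attribute loading via PEP 562 (Python 3.7+)


def __getattr__(name):
    # Called only when normal attribute lookup on the module fails, so the
    # import below runs on the first access of mypackage.Client.
    if name == "Client":
        from mypackage.client import Client  # hypothetical submodule

        return Client
    raise AttributeError(f"module {__name__} has no attribute {name}")
```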
5 changes: 1 addition & 4 deletions airflow/api/__init__.py
@@ -38,8 +38,5 @@ def load_auth():
         log.info("Loaded API auth backend: %s", auth_backend)
         return auth_backend
     except ImportError as err:
-        log.critical(
-            "Cannot import %s for API authentication due to: %s",
-            auth_backend, err
-        )
+        log.critical("Cannot import %s for API authentication due to: %s", auth_backend, err)
         raise AirflowException(err)
5 changes: 2 additions & 3 deletions airflow/api/auth/backend/basic_auth.py
@@ -53,13 +53,12 @@ def auth_current_user() -> Optional[User]:

 def requires_authentication(function: T):
     """Decorator for functions that require authentication"""
+
     @wraps(function)
     def decorated(*args, **kwargs):
         if auth_current_user() is not None:
             return function(*args, **kwargs)
         else:
-            return Response(
-                "Unauthorized", 401, {"WWW-Authenticate": "Basic"}
-            )
+            return Response("Unauthorized", 401, {"WWW-Authenticate": "Basic"})

     return cast(T, decorated)
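All of the auth backends touched in this PR share one shape: a `requires_authentication` decorator that runs a credential check and either calls the wrapped view or short-circuits with an unauthorized response. A framework-free sketch of that shape (`check_credentials` and the tuple-style response are stand-ins, not Airflow's API):

```python
from functools import wraps
from typing import Callable, TypeVar, cast

T = TypeVar("T", bound=Callable)


def check_credentials() -> bool:
    """Stand-in for a real check such as auth_current_user()."""
    return True


def requires_authentication(function: T):
    """Decorator for functions that require authentication."""

    @wraps(function)
    def decorated(*args, **kwargs):
        if check_credentials():
            return function(*args, **kwargs)
        return "Unauthorized", 401  # Flask-style (body, status) pair

    return cast(T, decorated)


@requires_authentication
def my_view():
    return "ok"


print(my_view())  # -> "ok" while check_credentials() returns True
```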
1 change: 1 addition & 0 deletions airflow/api/auth/backend/default.py
@@ -33,6 +33,7 @@ def init_app(_):

 def requires_authentication(function: T):
     """Decorator for functions that require authentication"""
+
     @wraps(function)
     def decorated(*args, **kwargs):
         return function(*args, **kwargs)
5 changes: 3 additions & 2 deletions airflow/api/auth/backend/kerberos_auth.py
@@ -132,6 +132,7 @@ def _gssapi_authenticate(token):

 def requires_authentication(function: T):
     """Decorator for functions that require authentication with Kerberos"""
+
     @wraps(function)
     def decorated(*args, **kwargs):
         header = request.headers.get("Authorization")
@@ -144,11 +145,11 @@ def decorated(*args, **kwargs):
             response = function(*args, **kwargs)
             response = make_response(response)
             if ctx.kerberos_token is not None:
-                response.headers['WWW-Authenticate'] = ' '.join(['negotiate',
-                                                                 ctx.kerberos_token])
+                response.headers['WWW-Authenticate'] = ' '.join(['negotiate', ctx.kerberos_token])
+
             return response
         if return_code != kerberos.AUTH_GSS_CONTINUE:
             return _forbidden()
         return _unauthorized()

     return cast(T, decorated)
2 changes: 1 addition & 1 deletion airflow/api/client/__init__.py
@@ -35,6 +35,6 @@ def get_current_api_client() -> Client:
     api_client = api_module.Client(
         api_base_url=conf.get('cli', 'endpoint_url'),
         auth=getattr(auth_backend, 'CLIENT_AUTH', None),
-        session=session
+        session=session,
     )
     return api_client
30 changes: 18 additions & 12 deletions airflow/api/client/json_client.py
@@ -45,12 +45,15 @@ def _request(self, url, method='GET', json=None):
     def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
         endpoint = f'/api/experimental/dags/{dag_id}/dag_runs'
         url = urljoin(self._api_base_url, endpoint)
-        data = self._request(url, method='POST',
-                             json={
-                                 "run_id": run_id,
-                                 "conf": conf,
-                                 "execution_date": execution_date,
-                             })
+        data = self._request(
+            url,
+            method='POST',
+            json={
+                "run_id": run_id,
+                "conf": conf,
+                "execution_date": execution_date,
+            },
+        )
         return data['message']

     def delete_dag(self, dag_id):
@@ -74,12 +77,15 @@ def get_pools(self):
     def create_pool(self, name, slots, description):
         endpoint = '/api/experimental/pools'
         url = urljoin(self._api_base_url, endpoint)
-        pool = self._request(url, method='POST',
-                             json={
-                                 'name': name,
-                                 'slots': slots,
-                                 'description': description,
-                             })
+        pool = self._request(
+            url,
+            method='POST',
+            json={
+                'name': name,
+                'slots': slots,
+                'description': description,
+            },
+        )
         return pool['pool'], pool['slots'], pool['description']

     def delete_pool(self, name):
7 changes: 3 additions & 4 deletions airflow/api/client/local_client.py
@@ -26,10 +26,9 @@ class Client(api_client.Client):
"""Local API client implementation."""

def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
dag_run = trigger_dag.trigger_dag(dag_id=dag_id,
run_id=run_id,
conf=conf,
execution_date=execution_date)
dag_run = trigger_dag.trigger_dag(
dag_id=dag_id, run_id=run_id, conf=conf, execution_date=execution_date
)
return f"Created {dag_run}"

def delete_dag(self, dag_id):
8 changes: 2 additions & 6 deletions airflow/api/common/experimental/__init__.py
@@ -29,10 +29,7 @@ def check_and_get_dag(dag_id: str, task_id: Optional[str] = None) -> DagModel:
     if dag_model is None:
         raise DagNotFound(f"Dag id {dag_id} not found in DagModel")

-    dagbag = DagBag(
-        dag_folder=dag_model.fileloc,
-        read_dags_from_db=True
-    )
+    dagbag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
     dag = dagbag.get_dag(dag_id)
     if not dag:
         error_message = f"Dag id {dag_id} not found"
@@ -47,7 +44,6 @@ def check_and_get_dagrun(dag: DagModel, execution_date: datetime) -> DagRun:
"""Get DagRun object and check that it exists"""
dagrun = dag.get_dagrun(execution_date=execution_date)
if not dagrun:
error_message = ('Dag Run for date {} not found in dag {}'
.format(execution_date, dag.dag_id))
error_message = f'Dag Run for date {execution_date} not found in dag {dag.dag_id}'
raise DagRunNotFound(error_message)
return dagrun
11 changes: 6 additions & 5 deletions airflow/api/common/experimental/delete_dag.py
@@ -60,13 +60,14 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> i
     if dag.is_subdag:
         parent_dag_id, task_id = dag_id.rsplit(".", 1)
         for model in TaskFail, models.TaskInstance:
-            count += session.query(model).filter(model.dag_id == parent_dag_id,
-                                                 model.task_id == task_id).delete()
+            count += (
+                session.query(model).filter(model.dag_id == parent_dag_id, model.task_id == task_id).delete()
+            )

     # Delete entries in Import Errors table for a deleted DAG
     # This handles the case when the dag_id is changed in the file
-    session.query(models.ImportError).filter(
-        models.ImportError.filename == dag.fileloc
-    ).delete(synchronize_session='fetch')
+    session.query(models.ImportError).filter(models.ImportError.filename == dag.fileloc).delete(
+        synchronize_session='fetch'
+    )

     return count
22 changes: 11 additions & 11 deletions airflow/api/common/experimental/get_dag_runs.py
@@ -38,16 +38,16 @@ def get_dag_runs(dag_id: str, state: Optional[str] = None) -> List[Dict[str, Any
     dag_runs = []
     state = state.lower() if state else None
     for run in DagRun.find(dag_id=dag_id, state=state):
-        dag_runs.append({
-            'id': run.id,
-            'run_id': run.run_id,
-            'state': run.state,
-            'dag_id': run.dag_id,
-            'execution_date': run.execution_date.isoformat(),
-            'start_date': ((run.start_date or '') and
-                           run.start_date.isoformat()),
-            'dag_run_url': url_for('Airflow.graph', dag_id=run.dag_id,
-                                   execution_date=run.execution_date)
-        })
+        dag_runs.append(
+            {
+                'id': run.id,
+                'run_id': run.run_id,
+                'state': run.state,
+                'dag_id': run.dag_id,
+                'execution_date': run.execution_date.isoformat(),
+                'start_date': ((run.start_date or '') and run.start_date.isoformat()),
+                'dag_run_url': url_for('Airflow.graph', dag_id=run.dag_id, execution_date=run.execution_date),
+            }
+        )

     return dag_runs
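The start_date line above leans on short-circuit evaluation: `(run.start_date or '') and run.start_date.isoformat()` yields `''` when start_date is unset and the ISO string otherwise, without ever calling `.isoformat()` on None. A quick standalone demonstration:

```python
from datetime import datetime

for start_date in (None, datetime(2020, 11, 3)):
    # (x or '') is '' when x is falsy, and '' and ... short-circuits to ''.
    # When x is truthy, the expression evaluates to x.isoformat().
    value = (start_date or '') and start_date.isoformat()
    print(repr(value))  # '' then '2020-11-03T00:00:00'
```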
10 changes: 6 additions & 4 deletions airflow/api/common/experimental/get_lineage.py
@@ -31,10 +31,12 @@ def get_lineage(dag_id: str, execution_date: datetime.datetime, session=None) ->
     dag = check_and_get_dag(dag_id)
     check_and_get_dagrun(dag, execution_date)

-    inlets: List[XCom] = XCom.get_many(dag_ids=dag_id, execution_date=execution_date,
-                                       key=PIPELINE_INLETS, session=session).all()
-    outlets: List[XCom] = XCom.get_many(dag_ids=dag_id, execution_date=execution_date,
-                                        key=PIPELINE_OUTLETS, session=session).all()
+    inlets: List[XCom] = XCom.get_many(
+        dag_ids=dag_id, execution_date=execution_date, key=PIPELINE_INLETS, session=session
+    ).all()
+    outlets: List[XCom] = XCom.get_many(
+        dag_ids=dag_id, execution_date=execution_date, key=PIPELINE_OUTLETS, session=session
+    ).all()

     lineage: Dict[str, Dict[str, Any]] = {}
     for meta in inlets:
3 changes: 1 addition & 2 deletions airflow/api/common/experimental/get_task_instance.py
@@ -31,8 +31,7 @@ def get_task_instance(dag_id: str, task_id: str, execution_date: datetime) -> Ta
     # Get task instance object and check that it exists
     task_instance = dagrun.get_task_instance(task_id)
     if not task_instance:
-        error_message = ('Task {} instance for date {} not found'
-                         .format(task_id, execution_date))
+        error_message = f'Task {task_id} instance for date {execution_date} not found'
         raise TaskInstanceNotFound(error_message)

     return task_instance
73 changes: 36 additions & 37 deletions airflow/api/common/experimental/mark_tasks.py
@@ -69,7 +69,7 @@ def set_state(
     past: bool = False,
     state: str = State.SUCCESS,
     commit: bool = False,
-    session=None
+    session=None,
 ):  # pylint: disable=too-many-arguments,too-many-locals
     """
     Set the state of a task instance and if needed its relatives. Can set state
@@ -137,33 +137,24 @@
 # Flake and pylint disagree about correct indents here
 def all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates):  # noqa: E123
     """Get *all* tasks of the sub dags"""
-    qry_sub_dag = session.query(TaskInstance). \
-        filter(
-            TaskInstance.dag_id.in_(sub_dag_run_ids),
-            TaskInstance.execution_date.in_(confirmed_dates)
-        ). \
-        filter(
-            or_(
-                TaskInstance.state.is_(None),
-                TaskInstance.state != state
-            )
+    qry_sub_dag = (
+        session.query(TaskInstance)
+        .filter(TaskInstance.dag_id.in_(sub_dag_run_ids), TaskInstance.execution_date.in_(confirmed_dates))
+        .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
     )  # noqa: E123
     return qry_sub_dag


 def get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates):
     """Get all tasks of the main dag that will be affected by a state change"""
-    qry_dag = session.query(TaskInstance). \
-        filter(
-            TaskInstance.dag_id == dag.dag_id,
-            TaskInstance.execution_date.in_(confirmed_dates),
-            TaskInstance.task_id.in_(task_ids)  # noqa: E123
-        ). \
-        filter(
-            or_(
-                TaskInstance.state.is_(None),
-                TaskInstance.state != state
+    qry_dag = (
+        session.query(TaskInstance)
+        .filter(
+            TaskInstance.dag_id == dag.dag_id,
+            TaskInstance.execution_date.in_(confirmed_dates),
+            TaskInstance.task_id.in_(task_ids),  # noqa: E123
+        )
+        .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
     )
     return qry_dag

@@ -186,10 +177,12 @@ def get_subdag_runs(dag, session, state, task_ids, commit, confirmed_dates):
             # this works as a kind of integrity check
             # it creates missing dag runs for subdag operators,
             # maybe this should be moved to dagrun.verify_integrity
-            dag_runs = _create_dagruns(current_task.subdag,
-                                       execution_dates=confirmed_dates,
-                                       state=State.RUNNING,
-                                       run_type=DagRunType.BACKFILL_JOB)
+            dag_runs = _create_dagruns(
+                current_task.subdag,
+                execution_dates=confirmed_dates,
+                state=State.RUNNING,
+                run_type=DagRunType.BACKFILL_JOB,
+            )

             verify_dagruns(dag_runs, commit, state, session, current_task)

@@ -279,10 +272,9 @@ def _set_dag_run_state(dag_id, execution_date, state, session=None):
     :param state: target state
     :param session: database session
     """
-    dag_run = session.query(DagRun).filter(
-        DagRun.dag_id == dag_id,
-        DagRun.execution_date == execution_date
-    ).one()
+    dag_run = (
+        session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date).one()
+    )
     dag_run.state = state
     if state == State.RUNNING:
         dag_run.start_date = timezone.utcnow()
@@ -316,8 +308,9 @@ def set_dag_run_state_to_success(dag, execution_date, commit=False, session=None
     # Mark all task instances of the dag run to success.
     for task in dag.tasks:
         task.dag = dag
-    return set_state(tasks=dag.tasks, execution_date=execution_date,
-                     state=State.SUCCESS, commit=commit, session=session)
+    return set_state(
+        tasks=dag.tasks, execution_date=execution_date, state=State.SUCCESS, commit=commit, session=session
+    )


@provide_session
@@ -343,10 +336,15 @@ def set_dag_run_state_to_failed(dag, execution_date, commit=False, session=None)

# Mark only RUNNING task instances.
task_ids = [task.task_id for task in dag.tasks]
tis = session.query(TaskInstance).filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.execution_date == execution_date,
TaskInstance.task_id.in_(task_ids)).filter(TaskInstance.state == State.RUNNING)
tis = (
session.query(TaskInstance)
.filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.execution_date == execution_date,
TaskInstance.task_id.in_(task_ids),
)
.filter(TaskInstance.state == State.RUNNING)
)
task_ids_of_running_tis = [task_instance.task_id for task_instance in tis]

tasks = []
@@ -356,8 +354,9 @@ def set_dag_run_state_to_failed(dag, execution_date, commit=False, session=None)
             task.dag = dag
             tasks.append(task)

-    return set_state(tasks=tasks, execution_date=execution_date,
-                     state=State.FAILED, commit=commit, session=session)
+    return set_state(
+        tasks=tasks, execution_date=execution_date, state=State.FAILED, commit=commit, session=session
+    )


@provide_session