From 0472c2dfe3513151776a720c67d9db913643c8d7 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Mon, 5 Aug 2019 16:37:19 +0200 Subject: [PATCH 01/15] Remove provide context --- .../example_dags/example_qubole_operator.py | 1 - .../contrib/utils/mlengine_operator_utils.py | 1 - airflow/example_dags/docker_copy_data.py | 1 - .../example_branch_python_dop_operator_3.py | 1 - ...example_passing_params_via_test_command.py | 9 ++--- .../example_dags/example_python_operator.py | 1 - .../example_trigger_target_dag.py | 1 - airflow/example_dags/example_xcom.py | 1 - airflow/operators/python_operator.py | 39 +++++++++---------- docs/howto/operator/python.rst | 3 +- .../operators/test_aws_athena_operator.py | 1 - .../operators/test_s3_to_sftp_operator.py | 1 - tests/contrib/operators/test_sftp_operator.py | 1 - .../operators/test_sftp_to_s3_operator.py | 1 - tests/contrib/operators/test_ssh_operator.py | 1 - tests/core.py | 1 - tests/dags/test_cli_triggered_dags.py | 1 - tests/operators/test_python_operator.py | 15 ++++--- tests/operators/test_virtualenv_operator.py | 14 ------- tests/utils/test_log_handlers.py | 2 - 20 files changed, 31 insertions(+), 65 deletions(-) diff --git a/airflow/contrib/example_dags/example_qubole_operator.py b/airflow/contrib/example_dags/example_qubole_operator.py index b07f2734e8ff0..45f0f30d6a939 100644 --- a/airflow/contrib/example_dags/example_qubole_operator.py +++ b/airflow/contrib/example_dags/example_qubole_operator.py @@ -88,7 +88,6 @@ def compare_result(ds, **kwargs): t3 = PythonOperator( task_id='compare_result', - provide_context=True, python_callable=compare_result, trigger_rule="all_done", dag=dag) diff --git a/airflow/contrib/utils/mlengine_operator_utils.py b/airflow/contrib/utils/mlengine_operator_utils.py index e1682ef45ade0..ed545fdb46638 100644 --- a/airflow/contrib/utils/mlengine_operator_utils.py +++ b/airflow/contrib/utils/mlengine_operator_utils.py @@ -238,7 +238,6 @@ def apply_validate_fn(*args, **kwargs): 
evaluate_validation = PythonOperator( task_id=(task_prefix + "-validation"), python_callable=apply_validate_fn, - provide_context=True, templates_dict={"prediction_path": prediction_path}, dag=dag) evaluate_validation.set_upstream(evaluate_summary) diff --git a/airflow/example_dags/docker_copy_data.py b/airflow/example_dags/docker_copy_data.py index 6aba3f2cb3557..aa7600f3d8f2f 100644 --- a/airflow/example_dags/docker_copy_data.py +++ b/airflow/example_dags/docker_copy_data.py @@ -70,7 +70,6 @@ # # t_is_data_available = ShortCircuitOperator( # task_id='check_if_data_available', -# provide_context=True, # python_callable=is_data_available, # dag=dag) # diff --git a/airflow/example_dags/example_branch_python_dop_operator_3.py b/airflow/example_dags/example_branch_python_dop_operator_3.py index ec60cfc01b903..7455ef7ebbd23 100644 --- a/airflow/example_dags/example_branch_python_dop_operator_3.py +++ b/airflow/example_dags/example_branch_python_dop_operator_3.py @@ -58,7 +58,6 @@ def should_run(**kwargs): cond = BranchPythonOperator( task_id='condition', - provide_context=True, python_callable=should_run, dag=dag, ) diff --git a/airflow/example_dags/example_passing_params_via_test_command.py b/airflow/example_dags/example_passing_params_via_test_command.py index 152b8cde9e63b..e8fc9c963a916 100644 --- a/airflow/example_dags/example_passing_params_via_test_command.py +++ b/airflow/example_dags/example_passing_params_via_test_command.py @@ -37,17 +37,17 @@ ) -def my_py_command(**kwargs): +def my_py_command(test_mode, params): """ Print out the "foo" param passed in via `airflow tasks test example_passing_params_via_test_command run_this -tp '{"foo":"bar"}'` """ - if kwargs["test_mode"]: + if test_mode: print(" 'foo' was passed in via test={} command : kwargs[params][foo] \ - = {}".format(kwargs["test_mode"], kwargs["params"]["foo"])) + = {}".format(test_mode, params["foo"])) # Print out the value of "miff", passed in below via the Python Operator - print(" 'miff' was 
passed in via task params = {}".format(kwargs["params"]["miff"])) + print(" 'miff' was passed in via task params = {}".format(params["miff"])) return 1 @@ -58,7 +58,6 @@ def my_py_command(**kwargs): run_this = PythonOperator( task_id='run_this', - provide_context=True, python_callable=my_py_command, params={"miff": "agg"}, dag=dag, diff --git a/airflow/example_dags/example_python_operator.py b/airflow/example_dags/example_python_operator.py index 29c664f0a65ec..86403ceb25b63 100644 --- a/airflow/example_dags/example_python_operator.py +++ b/airflow/example_dags/example_python_operator.py @@ -48,7 +48,6 @@ def print_context(ds, **kwargs): run_this = PythonOperator( task_id='print_the_context', - provide_context=True, python_callable=print_context, dag=dag, ) diff --git a/airflow/example_dags/example_trigger_target_dag.py b/airflow/example_dags/example_trigger_target_dag.py index 4758176981152..32255103d804e 100644 --- a/airflow/example_dags/example_trigger_target_dag.py +++ b/airflow/example_dags/example_trigger_target_dag.py @@ -69,7 +69,6 @@ def run_this_func(**kwargs): run_this = PythonOperator( task_id='run_this', - provide_context=True, python_callable=run_this_func, dag=dag, ) diff --git a/airflow/example_dags/example_xcom.py b/airflow/example_dags/example_xcom.py index 5b7b79aca80f9..1c2bb2e532966 100644 --- a/airflow/example_dags/example_xcom.py +++ b/airflow/example_dags/example_xcom.py @@ -26,7 +26,6 @@ args = { 'owner': 'Airflow', 'start_date': airflow.utils.dates.days_ago(2), - 'provide_context': True, } dag = DAG('example_xcom', schedule_interval="@once", default_args=args) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 46430b215e93c..ad674d12f5a14 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -25,6 +25,7 @@ import types from textwrap import dedent from typing import Optional, Iterable, Dict, Callable +from inspect import signature import dill @@ -51,12 
+52,6 @@ class PythonOperator(BaseOperator): :param op_args: a list of positional arguments that will get unpacked when calling your callable :type op_args: list (templated) - :param provide_context: if set to true, Airflow will pass a set of - keyword arguments that can be used in your function. This set of - kwargs correspond exactly to what you can use in your jinja - templates. For this to work, you need to define `**kwargs` in your - function header. - :type provide_context: bool :param templates_dict: a dictionary where the values are templates that will get templated by the Airflow engine sometime between ``__init__`` and ``execute`` takes place and are made available @@ -79,7 +74,6 @@ def __init__( python_callable: Callable, op_args: Optional[Iterable] = None, op_kwargs: Optional[Dict] = None, - provide_context: bool = False, templates_dict: Optional[Dict] = None, templates_exts: Optional[Iterable[str]] = None, *args, @@ -91,7 +85,6 @@ def __init__( self.python_callable = python_callable self.op_args = op_args or [] self.op_kwargs = op_kwargs or {} - self.provide_context = provide_context self.templates_dict = templates_dict if templates_exts: self.template_ext = templates_exts @@ -104,10 +97,21 @@ def execute(self, context): for k, v in airflow_context_vars.items()])) os.environ.update(airflow_context_vars) - if self.provide_context: - context.update(self.op_kwargs) - context['templates_dict'] = self.templates_dict + context.update(self.op_kwargs) + context['templates_dict'] = self.templates_dict + + if {parameter for name, parameter + in signature(self.python_callable).parameters.items() + if str(parameter).startswith("**")}: + # If there is a **kwargs, **context or **_ then just pass everything. 
self.op_kwargs = context + else: + # If there is only for example, an execution_date, then pass only these in :-) + self.op_kwargs = { + name: context[name] for name, parameter + in signature(self.python_callable).parameters.items() + if name in context # If it isn't available on the context, then ignore + } return_value = self.execute_callable() self.log.info("Done. Returned value was: %s", return_value) @@ -130,6 +134,7 @@ class BranchPythonOperator(PythonOperator, SkipMixin): downstream to allow for the DAG state to fill up and the DAG run's state to be inferred. """ + def execute(self, context): branch = super().execute(context) self.skip_all_except(context['ti'], branch) @@ -147,6 +152,7 @@ class ShortCircuitOperator(PythonOperator, SkipMixin): The condition is determined by the result of `python_callable`. """ + def execute(self, context): condition = super().execute(context) self.log.info("Condition result is %s", condition) @@ -200,12 +206,6 @@ class PythonVirtualenvOperator(PythonOperator): :type op_kwargs: list :param op_kwargs: A dict of keyword arguments to pass to python_callable. :type op_kwargs: dict - :param provide_context: if set to true, Airflow will pass a set of - keyword arguments that can be used in your function. This set of - kwargs correspond exactly to what you can use in your jinja - templates. For this to work, you need to define `**kwargs` in your - function header. - :type provide_context: bool :param string_args: Strings that are present in the global var virtualenv_string_args, available to python_callable at runtime as a list[str]. Note that args are split by newline. 
@@ -219,6 +219,7 @@ class PythonVirtualenvOperator(PythonOperator): processing templated fields, for examples ``['.sql', '.hql']`` :type templates_exts: list[str] """ + @apply_defaults def __init__( self, @@ -229,7 +230,6 @@ def __init__( system_site_packages: bool = True, op_args: Iterable = None, op_kwargs: Dict = None, - provide_context: bool = False, string_args: Optional[Iterable[str]] = None, templates_dict: Optional[Dict] = None, templates_exts: Optional[Iterable[str]] = None, @@ -242,7 +242,6 @@ def __init__( op_kwargs=op_kwargs, templates_dict=templates_dict, templates_exts=templates_exts, - provide_context=provide_context, *args, **kwargs) self.requirements = requirements or [] @@ -383,7 +382,7 @@ def _generate_python_code(self): fn = self.python_callable # dont try to read pickle if we didnt pass anything if self._pass_op_args(): - load_args_line = 'with open(sys.argv[1], "rb") as file: arg_dict = {}.load(file)'\ + load_args_line = 'with open(sys.argv[1], "rb") as file: arg_dict = {}.load(file)' \ .format(pickling_library) else: load_args_line = 'arg_dict = {"args": [], "kwargs": {}}' diff --git a/docs/howto/operator/python.rst b/docs/howto/operator/python.rst index da2180138e069..1f361735fbfec 100644 --- a/docs/howto/operator/python.rst +++ b/docs/howto/operator/python.rst @@ -42,8 +42,7 @@ to the Python callable. Templating ^^^^^^^^^^ -When you set the ``provide_context`` argument to ``True``, Airflow passes in -an additional set of keyword arguments: one for each of the :doc:`Jinja +Airflow passes in a set of keyword arguments: one for each of the :doc:`Jinja template variables <../../macros>` and a ``templates_dict`` argument. 
The ``templates_dict`` argument is templated, so each value in the dictionary diff --git a/tests/contrib/operators/test_aws_athena_operator.py b/tests/contrib/operators/test_aws_athena_operator.py index b86bbab0f1369..c3d98e25db97d 100644 --- a/tests/contrib/operators/test_aws_athena_operator.py +++ b/tests/contrib/operators/test_aws_athena_operator.py @@ -53,7 +53,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', diff --git a/tests/contrib/operators/test_s3_to_sftp_operator.py b/tests/contrib/operators/test_s3_to_sftp_operator.py index fc78f0c8698d3..a40fc4557247e 100644 --- a/tests/contrib/operators/test_s3_to_sftp_operator.py +++ b/tests/contrib/operators/test_s3_to_sftp_operator.py @@ -69,7 +69,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py index b54c328ba5efa..30fa74101d7b7 100644 --- a/tests/contrib/operators/test_sftp_operator.py +++ b/tests/contrib/operators/test_sftp_operator.py @@ -53,7 +53,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/contrib/operators/test_sftp_to_s3_operator.py b/tests/contrib/operators/test_sftp_to_s3_operator.py index 02f4e84c010b2..9b45e1da17232 100644 --- a/tests/contrib/operators/test_sftp_to_s3_operator.py +++ b/tests/contrib/operators/test_sftp_to_s3_operator.py @@ -68,7 +68,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' 
diff --git a/tests/contrib/operators/test_ssh_operator.py b/tests/contrib/operators/test_ssh_operator.py index a27dc27bc7ca1..a7fe90fa95eeb 100644 --- a/tests/contrib/operators/test_ssh_operator.py +++ b/tests/contrib/operators/test_ssh_operator.py @@ -51,7 +51,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/core.py b/tests/core.py index c6de2e0f9679f..698c996e6cbfd 100644 --- a/tests/core.py +++ b/tests/core.py @@ -567,7 +567,6 @@ def test_py_op(templates_dict, ds, **kwargs): t = PythonOperator( task_id='test_py_op', - provide_context=True, python_callable=test_py_op, templates_dict={'ds': "{{ ds }}"}, dag=self.dag) diff --git a/tests/dags/test_cli_triggered_dags.py b/tests/dags/test_cli_triggered_dags.py index 9f53ca4c3ab0b..f2dc7b63895d6 100644 --- a/tests/dags/test_cli_triggered_dags.py +++ b/tests/dags/test_cli_triggered_dags.py @@ -51,6 +51,5 @@ def success(ti=None, *args, **kwargs): dag1_task2 = PythonOperator( task_id='test_run_dependent_task', python_callable=success, - provide_context=True, dag=dag1) dag1_task1.set_downstream(dag1_task2) diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index e5e8049aa1340..de931fd863e2b 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -57,8 +57,8 @@ def build_recording_function(calls_collection): Then using this custom function recording custom Call objects for further testing (replacing Mock.assert_called_with assertion method) """ - def recording_function(*args, **kwargs): - calls_collection.append(Call(*args, **kwargs)) + def recording_function(*args): + calls_collection.append(Call(*args)) return recording_function @@ -129,11 +129,10 @@ def test_python_operator_python_callable_is_callable(self): task_id='python_operator', dag=self.dag) - def 
_assertCallsEqual(self, first, second): + def _assert_calls_equal(self, first, second): self.assertIsInstance(first, Call) self.assertIsInstance(second, Call) self.assertTupleEqual(first.args, second.args) - self.assertDictEqual(first.kwargs, second.kwargs) def test_python_callable_arguments_are_templatized(self): """Test PythonOperator op_args are templatized""" @@ -148,7 +147,7 @@ def test_python_callable_arguments_are_templatized(self): task_id='python_operator', # a Mock instance cannot be used as a callable function or test fails with a # TypeError: Object of type Mock is not JSON serializable - python_callable=(build_recording_function(recorded_calls)), + python_callable=build_recording_function(recorded_calls), op_args=[ 4, date(2019, 1, 1), @@ -167,7 +166,7 @@ def test_python_callable_arguments_are_templatized(self): ds_templated = DEFAULT_DATE.date().isoformat() self.assertEqual(1, len(recorded_calls)) - self._assertCallsEqual( + self._assert_calls_equal( recorded_calls[0], Call(4, date(2019, 1, 1), @@ -183,7 +182,7 @@ def test_python_callable_keyword_arguments_are_templatized(self): task_id='python_operator', # a Mock instance cannot be used as a callable function or test fails with a # TypeError: Object of type Mock is not JSON serializable - python_callable=(build_recording_function(recorded_calls)), + python_callable=build_recording_function(recorded_calls), op_kwargs={ 'an_int': 4, 'a_date': date(2019, 1, 1), @@ -200,7 +199,7 @@ def test_python_callable_keyword_arguments_are_templatized(self): task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) self.assertEqual(1, len(recorded_calls)) - self._assertCallsEqual( + self._assert_calls_equal( recorded_calls[0], Call(an_int=4, a_date=date(2019, 1, 1), diff --git a/tests/operators/test_virtualenv_operator.py b/tests/operators/test_virtualenv_operator.py index 52c14f88806d4..95ff2142510e4 100644 --- a/tests/operators/test_virtualenv_operator.py +++ b/tests/operators/test_virtualenv_operator.py @@ -202,17 
+202,3 @@ def test_context(self): def f(**kwargs): return kwargs['templates_dict']['ds'] self._run_as_operator(f, templates_dict={'ds': '{{ ds }}'}) - - def test_provide_context(self): - def fn(): - pass - task = PythonVirtualenvOperator( - python_callable=fn, - python_version=sys.version_info[0], - task_id='task', - dag=self.dag, - provide_context=True, - ) - self.assertTrue( - task.provide_context - ) diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index e06bdae90891e..7126871cb5807 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -71,7 +71,6 @@ def task_callable(ti, **kwargs): task_id='task_for_testing_file_log_handler', dag=dag, python_callable=task_callable, - provide_context=True ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) @@ -123,7 +122,6 @@ def task_callable(ti, **kwargs): task_id='task_for_testing_file_log_handler', dag=dag, python_callable=task_callable, - provide_context=True ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.try_number = 2 From 66ab4f71092cfc8e02afa26d10d5b89d5730c156 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 3 Sep 2019 11:32:02 +0200 Subject: [PATCH 02/15] Simplify the arguments --- tests/operators/test_virtualenv_operator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/operators/test_virtualenv_operator.py b/tests/operators/test_virtualenv_operator.py index 95ff2142510e4..97c3dcc4eb473 100644 --- a/tests/operators/test_virtualenv_operator.py +++ b/tests/operators/test_virtualenv_operator.py @@ -199,6 +199,6 @@ def f(_): self._run_as_operator(f, op_args=[datetime.datetime.utcnow()]) def test_context(self): - def f(**kwargs): - return kwargs['templates_dict']['ds'] + def f(templates_dict): + return templates_dict['ds'] self._run_as_operator(f, templates_dict={'ds': '{{ ds }}'}) From 15b8eefd8972a67e5ee9a12c7ae40ab11eb08eea Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 3 Sep 2019 
11:47:12 +0200 Subject: [PATCH 03/15] Feedback is een cadeautje --- airflow/operators/python_operator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index ad674d12f5a14..5ec578f6aa9d6 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -100,16 +100,16 @@ def execute(self, context): context.update(self.op_kwargs) context['templates_dict'] = self.templates_dict - if {parameter for name, parameter + if {param for param in signature(self.python_callable).parameters.items() - if str(parameter).startswith("**")}: + if str(param).startswith("**")}: # If there is a **kwargs, **context or **_ then just pass everything. self.op_kwargs = context else: # If there is only for example, an execution_date, then pass only these in :-) self.op_kwargs = { - name: context[name] for name, parameter - in signature(self.python_callable).parameters.items() + name: context[name] + for name in signature(self.python_callable).parameters.keys() if name in context # If it isn't available on the context, then ignore } From 6c0c8bb61f4c23c5b12a8757e7fdf76eff33016c Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 3 Sep 2019 12:22:03 +0200 Subject: [PATCH 04/15] Add additional tests --- tests/operators/test_python_operator.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index 31f18edf1dde9..76867f1b4ab73 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -250,6 +250,28 @@ def test_echo_env_variables(self): ) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + def test_dynamic_provide_context(self): + def fn(dag): + if dag != 1: + raise ValueError("Should be 1") + + python_operator = PythonOperator( + op_kwargs={'dag': 1}, + python_callable=fn + ) + 
python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + def test_dynamic_provide_context(self): + def fn(dag): + if dag != 1: + raise ValueError("Should be 1") + + python_operator = PythonOperator( + op_args=[1], + python_callable=fn + ) + python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + class TestBranchOperator(unittest.TestCase): @classmethod From c6ba062e90ab400771e136b12915b2d8e785facb Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 3 Sep 2019 12:44:46 +0200 Subject: [PATCH 05/15] Cover some edge cases --- airflow/operators/python_operator.py | 17 +++++++++++------ tests/operators/test_python_operator.py | 4 ++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 5ec578f6aa9d6..3cc8f23b71a80 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -23,9 +23,10 @@ import subprocess import sys import types +from inspect import signature +from itertools import islice from textwrap import dedent from typing import Optional, Iterable, Dict, Callable -from inspect import signature import dill @@ -100,16 +101,20 @@ def execute(self, context): context.update(self.op_kwargs) context['templates_dict'] = self.templates_dict - if {param for param - in signature(self.python_callable).parameters.items() - if str(param).startswith("**")}: - # If there is a **kwargs, **context or **_ then just pass everything. + sig_full = signature(self.python_callable).parameters.items() + # Remove the first n arguments equal to len(op_args). + # The notation is a bit akward since the OrderedDict is not slice-able + # https://stackoverflow.com/questions/30975339/slicing-a-python-ordereddict + sig_without_op_args = islice(sig_full, len(self.op_args), sys.maxsize) + + if any(str(param).startswith("**") for param in sig_without_op_args): + # If there is a **kwargs, **context or **_ then just dump everything. 
self.op_kwargs = context else: # If there is only for example, an execution_date, then pass only these in :-) self.op_kwargs = { name: context[name] - for name in signature(self.python_callable).parameters.keys() + for name in sig_without_op_args if name in context # If it isn't available on the context, then ignore } diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index 76867f1b4ab73..852a6bdf8970e 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -250,7 +250,7 @@ def test_echo_env_variables(self): ) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_dynamic_provide_context(self): + def test_conflicting_kwargs(self): def fn(dag): if dag != 1: raise ValueError("Should be 1") @@ -261,7 +261,7 @@ def fn(dag): ) python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_dynamic_provide_context(self): + def test_context_with_conflicting_op_args(self): def fn(dag): if dag != 1: raise ValueError("Should be 1") From d99bb9c96b56aa16e0cd26877f743ee7e952d8df Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 3 Sep 2019 22:30:21 +0200 Subject: [PATCH 06/15] Extend the tests --- tests/operators/test_python_operator.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index 852a6bdf8970e..27abed22a8b38 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -251,24 +251,44 @@ def test_echo_env_variables(self): t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) def test_conflicting_kwargs(self): + self.dag.create_dagrun( + run_id='manual__' + DEFAULT_DATE.isoformat(), + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + state=State.RUNNING, + external_trigger=False, + ) + def fn(dag): if dag != 1: raise ValueError("Should be 1") python_operator = PythonOperator( 
+ task_id='python_operator', op_kwargs={'dag': 1}, - python_callable=fn + python_callable=fn, + dag=self.dag ) python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) def test_context_with_conflicting_op_args(self): + self.dag.create_dagrun( + run_id='manual__' + DEFAULT_DATE.isoformat(), + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + state=State.RUNNING, + external_trigger=False, + ) + def fn(dag): if dag != 1: raise ValueError("Should be 1") python_operator = PythonOperator( + task_id='python_operator', op_args=[1], - python_callable=fn + python_callable=fn, + dag=self.dag ) python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) From 145f83594c1cebf7456c2082b3e57e3444de65fc Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 4 Sep 2019 10:41:56 +0200 Subject: [PATCH 07/15] Update the tests --- airflow/operators/python_operator.py | 22 +++++++++++++--------- tests/operators/test_python_operator.py | 21 +++++++++++++-------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 3cc8f23b71a80..6d5a92b1b339f 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -100,21 +100,25 @@ def execute(self, context): context.update(self.op_kwargs) context['templates_dict'] = self.templates_dict + context_keys = context.keys() - sig_full = signature(self.python_callable).parameters.items() - # Remove the first n arguments equal to len(op_args). 
- # The notation is a bit akward since the OrderedDict is not slice-able - # https://stackoverflow.com/questions/30975339/slicing-a-python-ordereddict - sig_without_op_args = islice(sig_full, len(self.op_args), sys.maxsize) + sig = signature(self.python_callable).parameters.items() + op_args_names = islice(sig, len(self.op_args)) - if any(str(param).startswith("**") for param in sig_without_op_args): + for name in op_args_names: + # Check if it part of the context + if name in context_keys: + # Raise an exception + raise ValueError("The key {} in the op_args is part of the context, and therefore reserved".format(name)) + + if any(str(param).startswith("**") for param in sig): # If there is a **kwargs, **context or **_ then just dump everything. self.op_kwargs = context else: # If there is only for example, an execution_date, then pass only these in :-) self.op_kwargs = { name: context[name] - for name in sig_without_op_args + for name in sig if name in context # If it isn't available on the context, then ignore } @@ -268,8 +272,8 @@ def __init__( self.__class__.__name__) # check that args are passed iff python major version matches if (python_version is not None and - str(python_version)[0] != str(sys.version_info[0]) and - self._pass_op_args()): + str(python_version)[0] != str(sys.version_info[0]) and + self._pass_op_args()): raise AirflowException("Passing op_args or op_kwargs is not supported across " "different Python major versions " "for PythonVirtualenvOperator. 
" diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index 27abed22a8b38..8a34887098a08 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -24,6 +24,8 @@ from collections import namedtuple from datetime import timedelta, date +import pytest + from airflow.exceptions import AirflowException from airflow.models import TaskInstance as TI, DAG, DagRun from airflow.operators.dummy_operator import DummyOperator @@ -259,17 +261,20 @@ def test_conflicting_kwargs(self): external_trigger=False, ) + # dag is not allowed since it is a reserved keyword def fn(dag): - if dag != 1: - raise ValueError("Should be 1") + # An ValueError should be triggered since we're using dag as a + # reserved keyword + raise RuntimeError("Should not be triggered, dag: {}".format(dag)) python_operator = PythonOperator( task_id='python_operator', - op_kwargs={'dag': 1}, + op_args=[1], python_callable=fn, dag=self.dag ) - python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + with pytest.raises(ValueError, match=r".* dag .*"): + python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) def test_context_with_conflicting_op_args(self): self.dag.create_dagrun( @@ -280,13 +285,13 @@ def test_context_with_conflicting_op_args(self): external_trigger=False, ) - def fn(dag): - if dag != 1: - raise ValueError("Should be 1") + def fn(custom, dag): + if custom != 1: + raise ValueError("Should be 1, but was {}, dag: {}".format(custom, dag)) python_operator = PythonOperator( task_id='python_operator', - op_args=[1], + op_kwargs={'custom': 1}, python_callable=fn, dag=self.dag ) From d0be33328b16b0794ef483da5938b6404c5f9242 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 4 Sep 2019 11:08:30 +0200 Subject: [PATCH 08/15] Remove flake8 violations --- airflow/operators/python_operator.py | 8 +++++--- tests/contrib/hooks/test_aws_glue_catalog_hook.py | 1 + tests/core.py | 1 + 
tests/dags/test_dag_serialization.py | 4 ++-- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 6d5a92b1b339f..bc2b6d850ed44 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -109,7 +109,9 @@ def execute(self, context): # Check if it part of the context if name in context_keys: # Raise an exception - raise ValueError("The key {} in the op_args is part of the context, and therefore reserved".format(name)) + raise ValueError( + "The key {} in the op_args is part of the context, and therefore reserved".format(name) + ) if any(str(param).startswith("**") for param in sig): # If there is a **kwargs, **context or **_ then just dump everything. @@ -272,8 +274,8 @@ def __init__( self.__class__.__name__) # check that args are passed iff python major version matches if (python_version is not None and - str(python_version)[0] != str(sys.version_info[0]) and - self._pass_op_args()): + str(python_version)[0] != str(sys.version_info[0]) and + self._pass_op_args()): raise AirflowException("Passing op_args or op_kwargs is not supported across " "different Python major versions " "for PythonVirtualenvOperator. 
" diff --git a/tests/contrib/hooks/test_aws_glue_catalog_hook.py b/tests/contrib/hooks/test_aws_glue_catalog_hook.py index 311dbcaf3390f..85b2777e88c51 100644 --- a/tests/contrib/hooks/test_aws_glue_catalog_hook.py +++ b/tests/contrib/hooks/test_aws_glue_catalog_hook.py @@ -43,6 +43,7 @@ } } + @unittest.skipIf(mock_glue is None, "Skipping test because moto.mock_glue is not available") class TestAwsGlueCatalogHook(unittest.TestCase): diff --git a/tests/core.py b/tests/core.py index 62a1271139f05..f7e6f4c5f5ef6 100644 --- a/tests/core.py +++ b/tests/core.py @@ -2178,6 +2178,7 @@ def test_init_proxy_user(self): HDFSHook = None # type: Optional[hdfs_hook.HDFSHook] snakebite = None # type: None + @unittest.skipIf(HDFSHook is None, "Skipping test because HDFSHook is not installed") class TestHDFSHook(unittest.TestCase): diff --git a/tests/dags/test_dag_serialization.py b/tests/dags/test_dag_serialization.py index 9618925efa312..1934c642a4232 100644 --- a/tests/dags/test_dag_serialization.py +++ b/tests/dags/test_dag_serialization.py @@ -158,7 +158,7 @@ def make_example_dags(module, dag_ids): def make_simple_dag(): """Make very simple DAG to verify serialization result.""" dag = DAG(dag_id='simple_dag') - _ = BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1)) + BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1)) return {'simple_dag': dag} @@ -186,7 +186,7 @@ def compute_next_execution_date(dag, execution_date): }, catchup=False ) - _ = BashOperator( + BashOperator( task_id='echo', bash_command='echo "{{ next_execution_date(dag, execution_date) }}"', dag=dag, From 3fc2e98f8f29a5056373573847148be012363a88 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 4 Sep 2019 11:40:24 +0200 Subject: [PATCH 09/15] Works on my machine (using Breeze :-) --- airflow/operators/python_operator.py | 14 +++++++---- tests/operators/test_python_operator.py | 31 +++++++++++++++++++++---- 2 files changed, 35 insertions(+), 10 
deletions(-) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index bc2b6d850ed44..ec846556543be 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -104,26 +104,30 @@ def execute(self, context): sig = signature(self.python_callable).parameters.items() op_args_names = islice(sig, len(self.op_args)) - - for name in op_args_names: + for name, _ in op_args_names: # Check if it part of the context if name in context_keys: - # Raise an exception + # Raise an exception to let the user know that the keyword is reserved raise ValueError( "The key {} in the op_args is part of the context, and therefore reserved".format(name) ) - if any(str(param).startswith("**") for param in sig): + print(sig) + + if any(str(param).startswith("**") for _, param in sig): # If there is a **kwargs, **context or **_ then just dump everything. self.op_kwargs = context else: # If there is only for example, an execution_date, then pass only these in :-) self.op_kwargs = { name: context[name] - for name in sig + for name, _ in sig if name in context # If it isn't available on the context, then ignore } + print(self.op_kwargs) + print(sig) + return_value = self.execute_callable() self.log.info("Done. 
Returned value was: %s", return_value) return return_value diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index 8a34887098a08..497a001939262 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -24,8 +24,6 @@ from collections import namedtuple from datetime import timedelta, date -import pytest - from airflow.exceptions import AirflowException from airflow.models import TaskInstance as TI, DAG, DagRun from airflow.operators.dummy_operator import DummyOperator @@ -273,8 +271,10 @@ def fn(dag): python_callable=fn, dag=self.dag ) - with pytest.raises(ValueError, match=r".* dag .*"): + + with self.assertRaises(ValueError) as context: python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.assertTrue('dag' in context.exception, "'dag' not found in the exception") def test_context_with_conflicting_op_args(self): self.dag.create_dagrun( @@ -286,8 +286,29 @@ def test_context_with_conflicting_op_args(self): ) def fn(custom, dag): - if custom != 1: - raise ValueError("Should be 1, but was {}, dag: {}".format(custom, dag)) + self.assertEqual(1, custom, "custom should be 1") + self.assertIsNotNone(dag, "dag should be set") + + python_operator = PythonOperator( + task_id='python_operator', + op_kwargs={'custom': 1}, + python_callable=fn, + dag=self.dag + ) + python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + def test_context_with_kwargs(self): + self.dag.create_dagrun( + run_id='manual__' + DEFAULT_DATE.isoformat(), + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + state=State.RUNNING, + external_trigger=False, + ) + + def fn(**context): + # check if context is being set + self.assertGreater(len(context), 0, "Context has not been injected") python_operator = PythonOperator( task_id='python_operator', From c9995eef84afbbb476476d4411f57b62e39cfddf Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 4 Sep 2019 12:51:26 +0200 
Subject: [PATCH 10/15] Clean up occurrences with provide_context --- UPDATING.md | 5 ++ ...kins_job_trigger_operator.py.notexecutable | 5 +- airflow/contrib/sensors/python_sensor.py | 27 ++++------- airflow/gcp/utils/mlengine_operator_utils.py | 1 - airflow/operators/python_operator.py | 42 +++++++++-------- airflow/sensors/http_sensor.py | 46 +++++++++---------- docs/concepts.rst | 13 +++--- tests/contrib/sensors/test_file_sensor.py | 3 +- .../sensors/test_gcs_upload_session_sensor.py | 1 - tests/sensors/test_http_sensor.py | 13 ++---- 10 files changed, 72 insertions(+), 84 deletions(-) diff --git a/UPDATING.md b/UPDATING.md index daac8adf23b59..dd80e9e252661 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -37,8 +37,13 @@ assists users migrating to a new version. - [Airflow 1.7.1.2](#airflow-1712) + ## Airflow Master +### Remove provide_context + +Instead of settings `provide_context` we're automagically inferring the signature of the callable that is being passed to the PythonOperator. The only behavioural change in is that using a key that is already in the context in the function, such as `dag` or `ds` is not allowed anymore and will thrown an exception. If the `provide_context` is still explicitly passed to the function, it will just end up in the `kwargs`, which can cause no harm. + ### Change dag loading duration metric name Change DAG file loading duration metric from `dag.loading-duration.` to `dag.loading-duration.`. 
This is to diff --git a/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable b/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable index 2d8906b0b3b9d..c0c3df61f5c13 100644 --- a/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable +++ b/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable @@ -7,9 +7,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -75,7 +75,6 @@ def grabArtifactFromJenkins(**context): artifact_grabber = PythonOperator( task_id='artifact_grabber', - provide_context=True, python_callable=grabArtifactFromJenkins, dag=dag) diff --git a/airflow/contrib/sensors/python_sensor.py b/airflow/contrib/sensors/python_sensor.py index a2dc2031a86bb..ec611be3a236e 100644 --- a/airflow/contrib/sensors/python_sensor.py +++ b/airflow/contrib/sensors/python_sensor.py @@ -16,9 +16,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from airflow.operators.python_operator import PythonOperator from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults +from typing import Optional, Iterable, Dict, Callable class PythonSensor(BaseSensorOperator): @@ -38,12 +40,6 @@ class PythonSensor(BaseSensorOperator): :param op_args: a list of positional arguments that will get unpacked when calling your callable :type op_args: list - :param provide_context: if set to true, Airflow will pass a set of - keyword arguments that can be used in your function. 
This set of - kwargs correspond exactly to what you can use in your jinja - templates. For this to work, you need to define `**kwargs` in your - function header. - :type provide_context: bool :param templates_dict: a dictionary where the values are templates that will get templated by the Airflow engine sometime between ``__init__`` and ``execute`` takes place and are made available @@ -56,24 +52,21 @@ class PythonSensor(BaseSensorOperator): @apply_defaults def __init__( self, - python_callable, - op_args=None, - op_kwargs=None, - provide_context=False, - templates_dict=None, + python_callable: Callable, + op_args: Optional[Iterable] = None, + op_kwargs: Optional[Dict] = None, + templates_dict: Optional[Dict] = None, *args, **kwargs): super().__init__(*args, **kwargs) self.python_callable = python_callable self.op_args = op_args or [] self.op_kwargs = op_kwargs or {} - self.provide_context = provide_context self.templates_dict = templates_dict - def poke(self, context): - if self.provide_context: - context.update(self.op_kwargs) - context['templates_dict'] = self.templates_dict - self.op_kwargs = context + def poke(self, context: Dict): + context.update(self.op_kwargs) + context['templates_dict'] = self.templates_dict + self.op_kwargs = PythonOperator.determine_op_kwargs(self.python_callable, context, len(self.op_args)) self.log.info("Poking callable: %s", str(self.python_callable)) return_value = self.python_callable(*self.op_args, **self.op_kwargs) diff --git a/airflow/gcp/utils/mlengine_operator_utils.py b/airflow/gcp/utils/mlengine_operator_utils.py index 66cdad8a171d4..658a2088b5425 100644 --- a/airflow/gcp/utils/mlengine_operator_utils.py +++ b/airflow/gcp/utils/mlengine_operator_utils.py @@ -240,7 +240,6 @@ def apply_validate_fn(*args, **kwargs): evaluate_validation = PythonOperator( task_id=(task_prefix + "-validation"), python_callable=apply_validate_fn, - provide_context=True, templates_dict={"prediction_path": prediction_path}, dag=dag) 
evaluate_validation.set_upstream(evaluate_summary) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index ec846556543be..c760dbec20d34 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -26,7 +26,7 @@ from inspect import signature from itertools import islice from textwrap import dedent -from typing import Optional, Iterable, Dict, Callable +from typing import Optional, Iterable, Dict, Callable, Tuple import dill @@ -90,20 +90,13 @@ def __init__( if templates_exts: self.template_ext = templates_exts - def execute(self, context): - # Export context to make it available for callables to use. - airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) - self.log.info("Exporting the following env vars:\n%s", - '\n'.join(["{}={}".format(k, v) - for k, v in airflow_context_vars.items()])) - os.environ.update(airflow_context_vars) - - context.update(self.op_kwargs) - context['templates_dict'] = self.templates_dict + @staticmethod + def determine_op_kwargs(python_callable: Callable, + context: Dict, + num_op_args: int = 0) -> Dict: context_keys = context.keys() - - sig = signature(self.python_callable).parameters.items() - op_args_names = islice(sig, len(self.op_args)) + sig = signature(python_callable).parameters.items() + op_args_names = islice(sig, num_op_args) for name, _ in op_args_names: # Check if it part of the context if name in context_keys: @@ -112,21 +105,30 @@ def execute(self, context): "The key {} in the op_args is part of the context, and therefore reserved".format(name) ) - print(sig) - if any(str(param).startswith("**") for _, param in sig): # If there is a **kwargs, **context or **_ then just dump everything. 
- self.op_kwargs = context + op_kwargs = context else: # If there is only for example, an execution_date, then pass only these in :-) - self.op_kwargs = { + op_kwargs = { name: context[name] for name, _ in sig if name in context # If it isn't available on the context, then ignore } + return op_kwargs + + def execute(self, context: Dict): + # Export context to make it available for callables to use. + airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) + self.log.info("Exporting the following env vars:\n%s", + '\n'.join(["{}={}".format(k, v) + for k, v in airflow_context_vars.items()])) + os.environ.update(airflow_context_vars) + + context.update(self.op_kwargs) + context['templates_dict'] = self.templates_dict - print(self.op_kwargs) - print(sig) + self.op_kwargs = PythonOperator.determine_op_kwargs(self.python_callable, context, len(self.op_args)) return_value = self.execute_callable() self.log.info("Done. Returned value was: %s", return_value) diff --git a/airflow/sensors/http_sensor.py b/airflow/sensors/http_sensor.py index 6e5d69946232e..3102d33f4c049 100644 --- a/airflow/sensors/http_sensor.py +++ b/airflow/sensors/http_sensor.py @@ -16,6 +16,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Dict, Callable + +from airflow.operators.python_operator import PythonOperator from airflow.exceptions import AirflowException from airflow.hooks.http_hook import HttpHook @@ -31,13 +34,17 @@ class HttpSensor(BaseSensorOperator): HTTP Error codes other than 404 (like 403) or Connection Refused Error would fail the sensor itself directly (no more poking). - The response check can access the template context by passing ``provide_context=True`` to the operator:: + The response check can access the template context to the operator: - def response_check(response, **context): - # Can look at context['ti'] etc. 
+ def response_check(response, task_instance): + # The task_instance is injected, so you can pull data from xcom + # Other context variables such as dag, ds, execution_date are also available. + xcom_data = task_instance.xcom_pull(task_ids='pushing_task') + # In practice you would do something more sensible with this data.. + print(xcom_data) return True - HttpSensor(task_id='my_http_sensor', ..., provide_context=True, response_check=response_check) + HttpSensor(task_id='my_http_sensor', ..., response_check=response_check) :param http_conn_id: The connection to run the sensor against @@ -50,12 +57,6 @@ def response_check(response, **context): :type request_params: a dictionary of string key/value pairs :param headers: The HTTP headers to be added to the GET request :type headers: a dictionary of string key/value pairs - :param provide_context: if set to true, Airflow will pass a set of - keyword arguments that can be used in your function. This set of - kwargs correspond exactly to what you can use in your jinja - templates. For this to work, you need to define context in your - function header. - :type provide_context: bool :param response_check: A check against the 'requests' response object. Returns True for 'pass' and False otherwise. :type response_check: A lambda or defined function. 
@@ -69,14 +70,13 @@ def response_check(response, **context): @apply_defaults def __init__(self, - endpoint, - http_conn_id='http_default', - method='GET', - request_params=None, - headers=None, - response_check=None, - provide_context=False, - extra_options=None, *args, **kwargs): + endpoint: str, + http_conn_id: str = 'http_default', + method: str = 'GET', + request_params: Dict = None, + headers: Dict = None, + response_check: Callable = None, + extra_options: Dict = None, *args, **kwargs): super().__init__(*args, **kwargs) self.endpoint = endpoint self.http_conn_id = http_conn_id @@ -84,13 +84,12 @@ def __init__(self, self.headers = headers or {} self.extra_options = extra_options or {} self.response_check = response_check - self.provide_context = provide_context self.hook = HttpHook( method=method, http_conn_id=http_conn_id) - def poke(self, context): + def poke(self, context: Dict): self.log.info('Poking: %s', self.endpoint) try: response = self.hook.run(self.endpoint, @@ -98,10 +97,9 @@ def poke(self, context): headers=self.headers, extra_options=self.extra_options) if self.response_check: - if self.provide_context: - return self.response_check(response, **context) - else: - return self.response_check(response) + op_kwargs = PythonOperator.determine_op_kwargs(self.response_check, context) + return self.response_check(response, **op_kwargs) + except AirflowException as ae: if str(ae).startswith("404"): return False diff --git a/docs/concepts.rst b/docs/concepts.rst index d258b4da38155..96930d185a8dd 100644 --- a/docs/concepts.rst +++ b/docs/concepts.rst @@ -548,9 +548,12 @@ passed, then a corresponding list of XCom values is returned. 
def push_function(): return value - # inside another PythonOperator where provide_context=True - def pull_function(**context): - value = context['task_instance'].xcom_pull(task_ids='pushing_task') + # inside another PythonOperator + def pull_function(task_instance): + value = task_instance.xcom_pull(task_ids='pushing_task') + +When specifying arguments that are part of the context, they will be +automatically passed to the function. It is also possible to pull XCom directly in a template, here's an example of what this may look like: @@ -632,8 +635,7 @@ For example: .. code:: python - def branch_func(**kwargs): - ti = kwargs['ti'] + def branch_func(ti): xcom_value = int(ti.xcom_pull(task_ids='start_task')) if xcom_value >= 5: return 'continue_task' @@ -648,7 +650,6 @@ For example: branch_op = BranchPythonOperator( task_id='branch_task', - provide_context=True, python_callable=branch_func, dag=dag) diff --git a/tests/contrib/sensors/test_file_sensor.py b/tests/contrib/sensors/test_file_sensor.py index 34720f5cbccca..8d520ce2572c3 100644 --- a/tests/contrib/sensors/test_file_sensor.py +++ b/tests/contrib/sensors/test_file_sensor.py @@ -49,8 +49,7 @@ def setUp(self): hook = FSHook() args = { 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - 'provide_context': True + 'start_date': DEFAULT_DATE } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/contrib/sensors/test_gcs_upload_session_sensor.py b/tests/contrib/sensors/test_gcs_upload_session_sensor.py index c230835923120..ea4515b6ecb57 100644 --- a/tests/contrib/sensors/test_gcs_upload_session_sensor.py +++ b/tests/contrib/sensors/test_gcs_upload_session_sensor.py @@ -62,7 +62,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/sensors/test_http_sensor.py 
b/tests/sensors/test_http_sensor.py index 37a0a48e50906..40c374eb21853 100644 --- a/tests/sensors/test_http_sensor.py +++ b/tests/sensors/test_http_sensor.py @@ -96,19 +96,13 @@ def resp_check(_): @patch("airflow.hooks.http_hook.requests.Session.send") def test_poke_context(self, mock_session_send): - """ - test provide_context - """ response = requests.Response() response.status_code = 200 mock_session_send.return_value = response - def resp_check(resp, **context): - if context: - if "execution_date" in context: - if context["execution_date"] == DEFAULT_DATE: - return True - + def resp_check(resp, execution_date): + if execution_date == DEFAULT_DATE: + return True raise AirflowException('AirflowException raised here!') task = HttpSensor( @@ -117,7 +111,6 @@ def resp_check(resp, **context): endpoint='', request_params={}, response_check=resp_check, - provide_context=True, timeout=5, poke_interval=1, dag=self.dag) From d7f35d6b8e0061ef9d8b3fcec09d40c0e6297c16 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 4 Sep 2019 13:02:00 +0200 Subject: [PATCH 11/15] Some cleanup --- airflow/operators/python_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index c760dbec20d34..f1e6d1dbd15fe 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -26,7 +26,7 @@ from inspect import signature from itertools import islice from textwrap import dedent -from typing import Optional, Iterable, Dict, Callable, Tuple +from typing import Optional, Iterable, Dict, Callable import dill From 7ab68bc90d4cf34d93ef01f491df1e425e2c89d7 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 4 Sep 2019 13:47:47 +0200 Subject: [PATCH 12/15] Fix the types --- airflow/contrib/sensors/python_sensor.py | 4 ++-- airflow/operators/python_operator.py | 19 ++++++++++++++----- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git 
a/airflow/contrib/sensors/python_sensor.py b/airflow/contrib/sensors/python_sensor.py index ec611be3a236e..146ab7ec39ba1 100644 --- a/airflow/contrib/sensors/python_sensor.py +++ b/airflow/contrib/sensors/python_sensor.py @@ -20,7 +20,7 @@ from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults -from typing import Optional, Iterable, Dict, Callable +from typing import Optional, Iterable, Dict, Callable, List class PythonSensor(BaseSensorOperator): @@ -53,7 +53,7 @@ class PythonSensor(BaseSensorOperator): def __init__( self, python_callable: Callable, - op_args: Optional[Iterable] = None, + op_args: Optional[List] = None, op_kwargs: Optional[Dict] = None, templates_dict: Optional[Dict] = None, *args, **kwargs): diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index f1e6d1dbd15fe..49b89658aa9c2 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -26,7 +26,7 @@ from inspect import signature from itertools import islice from textwrap import dedent -from typing import Optional, Iterable, Dict, Callable +from typing import Optional, Iterable, Dict, Callable, List import dill @@ -73,10 +73,10 @@ class PythonOperator(BaseOperator): def __init__( self, python_callable: Callable, - op_args: Optional[Iterable] = None, + op_args: Optional[List] = None, op_kwargs: Optional[Dict] = None, templates_dict: Optional[Dict] = None, - templates_exts: Optional[Iterable[str]] = None, + templates_exts: Optional[List[str]] = None, *args, **kwargs ) -> None: @@ -94,6 +94,15 @@ def __init__( def determine_op_kwargs(python_callable: Callable, context: Dict, num_op_args: int = 0) -> Dict: + """ + Function that will inspect the signature of a python_callable to determine which + values need to be passed to the function. 
+ + :param python_callable: The function that you want to invoke + :param context: The context provided by the execute method of the Operator/Sensor + :param num_op_args: The number of op_args provided, so we know how many to skip + :return: The op_kwargs dictionary which contains the values that are compatible with the Callable + """ context_keys = context.keys() sig = signature(python_callable).parameters.items() op_args_names = islice(sig, num_op_args) @@ -152,7 +161,7 @@ class BranchPythonOperator(PythonOperator, SkipMixin): to be inferred. """ - def execute(self, context): + def execute(self, context: Dict): branch = super().execute(context) self.skip_all_except(context['ti'], branch) @@ -170,7 +179,7 @@ class ShortCircuitOperator(PythonOperator, SkipMixin): The condition is determined by the result of `python_callable`. """ - def execute(self, context): + def execute(self, context: Dict): condition = super().execute(context) self.log.info("Condition result is %s", condition) From 2cde97b63a1670185de714f0bfd61ed1604c3db6 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 4 Sep 2019 13:56:27 +0200 Subject: [PATCH 13/15] Less is more --- airflow/contrib/sensors/python_sensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/contrib/sensors/python_sensor.py b/airflow/contrib/sensors/python_sensor.py index 146ab7ec39ba1..a4e5ec77aa520 100644 --- a/airflow/contrib/sensors/python_sensor.py +++ b/airflow/contrib/sensors/python_sensor.py @@ -20,7 +20,7 @@ from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults -from typing import Optional, Iterable, Dict, Callable, List +from typing import Optional, Dict, Callable, List class PythonSensor(BaseSensorOperator): From 95764ee35ca8295099c1dd95f3458b93f26c7795 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 4 Sep 2019 14:35:58 +0200 Subject: [PATCH 14/15] Patch tests --- tests/sensors/test_http_sensor.py | 4 ++-- 1
file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/sensors/test_http_sensor.py b/tests/sensors/test_http_sensor.py index 40c374eb21853..1b6c3667ba653 100644 --- a/tests/sensors/test_http_sensor.py +++ b/tests/sensors/test_http_sensor.py @@ -63,7 +63,7 @@ def resp_check(resp): timeout=5, poke_interval=1) with self.assertRaisesRegex(AirflowException, 'AirflowException raised here!'): - task.execute(None) + task.execute(context={}) @patch("airflow.hooks.http_hook.requests.Session.send") def test_head_method(self, mock_session_send): @@ -81,7 +81,7 @@ def resp_check(_): timeout=5, poke_interval=1) - task.execute(None) + task.execute(context={}) args, kwargs = mock_session_send.call_args received_request = args[0] From b592e68f7ed0bff18dbbe4b326d18e4db97afe3f Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 10 Sep 2019 12:59:58 +0200 Subject: [PATCH 15/15] Feedback from Bas --- UPDATING.md | 38 +++++++++++++++++++++++++++- airflow/operators/python_operator.py | 4 +-- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/UPDATING.md b/UPDATING.md index 6792ef6948076..306a2da454784 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -42,7 +42,43 @@ assists users migrating to a new version. ### Remove provide_context -Instead of settings `provide_context` we're automagically inferring the signature of the callable that is being passed to the PythonOperator. The only behavioural change in is that using a key that is already in the context in the function, such as `dag` or `ds` is not allowed anymore and will thrown an exception. If the `provide_context` is still explicitly passed to the function, it will just end up in the `kwargs`, which can cause no harm. +`provide_context` argument on the PythonOperator was removed. The signature of the callable passed to the PythonOperator is now inferred and argument values are always automatically provided. There is no need to explicitly provide or not provide the context anymore. 
For example: + +```python +def myfunc(execution_date): + print(execution_date) + +python_operator = PythonOperator(task_id='mytask', python_callable=myfunc, dag=dag) +``` + +Notice you don't have to set provide_context=True, variables from the task context are now automatically detected and provided. + +All context variables can still be provided with a double-asterisk argument: + +```python +def myfunc(**context): + print(context) # all variables will be provided to context + +python_operator = PythonOperator(task_id='mytask', python_callable=myfunc) +``` + +The task context variable names are reserved names in the callable function, hence a clash with `op_args` and `op_kwargs` results in an exception: + +```python +def myfunc(dag): + # raises a ValueError because "dag" is a reserved name + # valid signature example: myfunc(mydag) + +python_operator = PythonOperator( + task_id='mytask', + op_args=[1], + python_callable=myfunc, +) +``` + +The change is backwards compatible, setting `provide_context` will add the `provide_context` variable to the `kwargs` (but won't do anything). + +PR: [#5990](https://github.com/apache/airflow/pull/5990) ### Changes to FileSensor diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 49b89658aa9c2..4d3c8da19f60c 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -107,7 +107,7 @@ def determine_op_kwargs(python_callable: Callable, sig = signature(python_callable).parameters.items() op_args_names = islice(sig, num_op_args) for name, _ in op_args_names: - # Check if it part of the context + # Check if it is part of the context if name in context_keys: # Raise an exception to let the user know that the keyword is reserved raise ValueError( @@ -115,7 +115,7 @@ def determine_op_kwargs(python_callable: Callable, ) if any(str(param).startswith("**") for _, param in sig): - # If there is a **kwargs, **context or **_ then just dump everything. 
+ # If there is a ** argument then just dump everything. op_kwargs = context else: # If there is only for example, an execution_date, then pass only these in :-)