diff --git a/UPDATING.md b/UPDATING.md
index 98dc7faf07ee0..306a2da454784 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -37,9 +37,51 @@ assists users migrating to a new version.
 - [Airflow 1.7.1.2](#airflow-1712)
 
 ## Airflow Master
+
+### Remove provide_context
+
+The `provide_context` argument of the PythonOperator was removed. The signature of the callable passed to the PythonOperator is now inferred, and argument values are always provided automatically. There is no longer any need to explicitly provide (or not provide) the context. For example:
+
+```python
+def myfunc(execution_date):
+    print(execution_date)
+
+python_operator = PythonOperator(task_id='mytask', python_callable=myfunc, dag=dag)
+```
+
+Notice that you no longer have to set `provide_context=True`; variables from the task context are detected and provided automatically.
+
+All context variables can still be provided with a double-asterisk argument:
+
+```python
+def myfunc(**context):
+    print(context)  # all context variables are provided
+
+python_operator = PythonOperator(task_id='mytask', python_callable=myfunc)
+```
+
+The task context variable names are reserved names in the callable function; a clash with `op_args` or `op_kwargs` therefore results in an exception:
+
+```python
+def myfunc(dag):
+    # raises a ValueError because "dag" is a reserved name
+    # valid signature example: myfunc(mydag)
+    pass
+
+python_operator = PythonOperator(
+    task_id='mytask',
+    op_args=[1],
+    python_callable=myfunc,
+)
+```
+
+The change is backwards compatible: setting `provide_context` will add a `provide_context` variable to the `kwargs` (but it won't do anything).
+
+PR: [#5990](https://github.com/apache/airflow/pull/5990)
+
 ### Changes to FileSensor
 
 FileSensor now takes a glob pattern, not just a filename. If the filename you are looking for has `*`, `?`, or `[` in it then you should replace these with `[*]`, `[?]`, and `[[]`.
 
 ### Change dag loading duration metric name
diff --git a/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable b/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable
index 2d8906b0b3b9d..c0c3df61f5c13 100644
--- a/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable
+++ b/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.
You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -75,7 +75,6 @@ def grabArtifactFromJenkins(**context): artifact_grabber = PythonOperator( task_id='artifact_grabber', - provide_context=True, python_callable=grabArtifactFromJenkins, dag=dag) diff --git a/airflow/contrib/example_dags/example_qubole_operator.py b/airflow/contrib/example_dags/example_qubole_operator.py index 1f7e2a8ce9d8f..ef4681a85b798 100644 --- a/airflow/contrib/example_dags/example_qubole_operator.py +++ b/airflow/contrib/example_dags/example_qubole_operator.py @@ -97,7 +97,6 @@ def compare_result(**kwargs): t3 = PythonOperator( task_id='compare_result', - provide_context=True, python_callable=compare_result, trigger_rule="all_done", dag=dag) diff --git a/airflow/contrib/sensors/python_sensor.py b/airflow/contrib/sensors/python_sensor.py index a2dc2031a86bb..a4e5ec77aa520 100644 --- a/airflow/contrib/sensors/python_sensor.py +++ b/airflow/contrib/sensors/python_sensor.py @@ -16,9 +16,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from airflow.operators.python_operator import PythonOperator from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults +from typing import Optional, Dict, Callable, List class PythonSensor(BaseSensorOperator): @@ -38,12 +40,6 @@ class PythonSensor(BaseSensorOperator): :param op_args: a list of positional arguments that will get unpacked when calling your callable :type op_args: list - :param provide_context: if set to true, Airflow will pass a set of - keyword arguments that can be used in your function. This set of - kwargs correspond exactly to what you can use in your jinja - templates. For this to work, you need to define `**kwargs` in your - function header. 
-    :type provide_context: bool
     :param templates_dict: a dictionary where the values are templates that
         will get templated by the Airflow engine sometime between
         ``__init__`` and ``execute`` takes place and are made available
@@ -56,24 +52,21 @@ class PythonSensor(BaseSensorOperator):
     @apply_defaults
     def __init__(
             self,
-            python_callable,
-            op_args=None,
-            op_kwargs=None,
-            provide_context=False,
-            templates_dict=None,
+            python_callable: Callable,
+            op_args: Optional[List] = None,
+            op_kwargs: Optional[Dict] = None,
+            templates_dict: Optional[Dict] = None,
             *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.python_callable = python_callable
         self.op_args = op_args or []
         self.op_kwargs = op_kwargs or {}
-        self.provide_context = provide_context
         self.templates_dict = templates_dict
 
-    def poke(self, context):
-        if self.provide_context:
-            context.update(self.op_kwargs)
-            context['templates_dict'] = self.templates_dict
-            self.op_kwargs = context
+    def poke(self, context: Dict):
+        context.update(self.op_kwargs)
+        context['templates_dict'] = self.templates_dict
+        self.op_kwargs = PythonOperator.determine_op_kwargs(self.python_callable, context, len(self.op_args))
 
         self.log.info("Poking callable: %s", str(self.python_callable))
         return_value = self.python_callable(*self.op_args, **self.op_kwargs)
diff --git a/airflow/example_dags/docker_copy_data.py b/airflow/example_dags/docker_copy_data.py
index f091969777eeb..484f82f683df3 100644
--- a/airflow/example_dags/docker_copy_data.py
+++ b/airflow/example_dags/docker_copy_data.py
@@ -69,7 +69,6 @@
 #
 # t_is_data_available = ShortCircuitOperator(
 #     task_id='check_if_data_available',
-#     provide_context=True,
 #     python_callable=is_data_available,
 #     dag=dag)
 #
diff --git a/airflow/example_dags/example_branch_python_dop_operator_3.py b/airflow/example_dags/example_branch_python_dop_operator_3.py
index ec60cfc01b903..7455ef7ebbd23 100644
--- a/airflow/example_dags/example_branch_python_dop_operator_3.py
+++ b/airflow/example_dags/example_branch_python_dop_operator_3.py
@@ -58,7 +58,6 @@ def should_run(**kwargs):
 
 cond = BranchPythonOperator(
     task_id='condition',
-    provide_context=True,
     python_callable=should_run,
     dag=dag,
 )
diff --git a/airflow/example_dags/example_passing_params_via_test_command.py b/airflow/example_dags/example_passing_params_via_test_command.py
index 152b8cde9e63b..e8fc9c963a916 100644
--- a/airflow/example_dags/example_passing_params_via_test_command.py
+++ b/airflow/example_dags/example_passing_params_via_test_command.py
@@ -37,17 +37,17 @@
 )
 
 
-def my_py_command(**kwargs):
+def my_py_command(test_mode, params):
     """
     Print out the "foo" param passed in via
     `airflow tasks test example_passing_params_via_test_command run_this -tp '{"foo":"bar"}'`
     """
-    if kwargs["test_mode"]:
-        print(" 'foo' was passed in via test={} command : kwargs[params][foo] \
-            = {}".format(kwargs["test_mode"], kwargs["params"]["foo"]))
+    if test_mode:
+        print(" 'foo' was passed in via test={} command : params[foo] \
+            = {}".format(test_mode, params["foo"]))
     # Print out the value of "miff", passed in below via the Python Operator
-    print(" 'miff' was passed in via task params = {}".format(kwargs["params"]["miff"]))
+    print(" 'miff' was passed in via task params = {}".format(params["miff"]))
     return 1
 
 
@@ -58,7 +58,6 @@ def my_py_command(**kwargs):
 
 run_this = PythonOperator(
     task_id='run_this',
-    provide_context=True,
     python_callable=my_py_command,
     params={"miff": "agg"},
     dag=dag,
diff --git a/airflow/example_dags/example_python_operator.py b/airflow/example_dags/example_python_operator.py
index 29c664f0a65ec..86403ceb25b63 100644
---
a/airflow/example_dags/example_python_operator.py +++ b/airflow/example_dags/example_python_operator.py @@ -48,7 +48,6 @@ def print_context(ds, **kwargs): run_this = PythonOperator( task_id='print_the_context', - provide_context=True, python_callable=print_context, dag=dag, ) diff --git a/airflow/example_dags/example_trigger_target_dag.py b/airflow/example_dags/example_trigger_target_dag.py index 4758176981152..32255103d804e 100644 --- a/airflow/example_dags/example_trigger_target_dag.py +++ b/airflow/example_dags/example_trigger_target_dag.py @@ -69,7 +69,6 @@ def run_this_func(**kwargs): run_this = PythonOperator( task_id='run_this', - provide_context=True, python_callable=run_this_func, dag=dag, ) diff --git a/airflow/example_dags/example_xcom.py b/airflow/example_dags/example_xcom.py index 8bd8e93b38cf0..fb043b021ebc3 100644 --- a/airflow/example_dags/example_xcom.py +++ b/airflow/example_dags/example_xcom.py @@ -26,7 +26,6 @@ args = { 'owner': 'Airflow', 'start_date': airflow.utils.dates.days_ago(2), - 'provide_context': True, } dag = DAG('example_xcom', schedule_interval="@once", default_args=args) diff --git a/airflow/gcp/utils/mlengine_operator_utils.py b/airflow/gcp/utils/mlengine_operator_utils.py index d09c2318061f3..dd44f7c543d25 100644 --- a/airflow/gcp/utils/mlengine_operator_utils.py +++ b/airflow/gcp/utils/mlengine_operator_utils.py @@ -241,7 +241,6 @@ def apply_validate_fn(*args, **kwargs): evaluate_validation = PythonOperator( task_id=(task_prefix + "-validation"), python_callable=apply_validate_fn, - provide_context=True, templates_dict={"prediction_path": prediction_path}, dag=dag) evaluate_validation.set_upstream(evaluate_summary) diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 46430b215e93c..4d3c8da19f60c 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -23,8 +23,10 @@ import subprocess import sys import types +from inspect import signature +from itertools import islice from textwrap import dedent -from typing import Optional, Iterable, Dict, Callable +from typing import Optional, Iterable, Dict, Callable, List import dill @@ -51,12 +53,6 @@ class PythonOperator(BaseOperator): :param op_args: a list of positional arguments that will get unpacked when calling your callable :type op_args: list (templated) - :param provide_context: if set to true, Airflow will pass a set of - keyword arguments that can be used in your function. This set of - kwargs correspond exactly to what you can use in your jinja - templates. For this to work, you need to define `**kwargs` in your - function header. 
-    :type provide_context: bool
     :param templates_dict: a dictionary where the values are templates that
         will get templated by the Airflow engine sometime between
         ``__init__`` and ``execute`` takes place and are made available
@@ -77,11 +73,10 @@ def __init__(
         self,
         python_callable: Callable,
-        op_args: Optional[Iterable] = None,
+        op_args: Optional[List] = None,
         op_kwargs: Optional[Dict] = None,
-        provide_context: bool = False,
         templates_dict: Optional[Dict] = None,
-        templates_exts: Optional[Iterable[str]] = None,
+        templates_exts: Optional[List[str]] = None,
         *args,
         **kwargs
     ) -> None:
@@ -91,12 +86,47 @@ def __init__(
         self.python_callable = python_callable
         self.op_args = op_args or []
         self.op_kwargs = op_kwargs or {}
-        self.provide_context = provide_context
         self.templates_dict = templates_dict
         if templates_exts:
             self.template_ext = templates_exts
 
-    def execute(self, context):
+    @staticmethod
+    def determine_op_kwargs(python_callable: Callable,
+                            context: Dict,
+                            num_op_args: int = 0) -> Dict:
+        """
+        Inspect the signature of a python_callable to determine which
+        values need to be passed to the function.
+
+        :param python_callable: The function that you want to invoke
+        :param context: The context provided by the execute method of the Operator/Sensor
+        :param num_op_args: The number of op_args provided, so we know how many to skip
+        :return: The op_kwargs dictionary which contains the values that are compatible with the Callable
+        """
+        context_keys = context.keys()
+        sig = signature(python_callable).parameters.items()
+        op_args_names = islice(sig, num_op_args)
+        for name, _ in op_args_names:
+            # Check if it is part of the context
+            if name in context_keys:
+                # Raise an exception to let the user know that the keyword is reserved
+                raise ValueError(
+                    "The key {} in the op_args is part of the context, and therefore reserved".format(name)
+                )
+
+        if any(str(param).startswith("**") for _, param in sig):
+            # If there is a ** argument then just dump everything.
+            op_kwargs = context
+        else:
+            # If the signature names, for example, only an execution_date, then pass only that in.
+            op_kwargs = {
+                name: context[name]
+                for name, _ in sig
+                if name in context  # If it isn't available on the context, then ignore
+            }
+        return op_kwargs
+
+    def execute(self, context: Dict):
         # Export context to make it available for callables to use.
         airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
         self.log.info("Exporting the following env vars:\n%s",
@@ -104,10 +134,10 @@ def execute(self, context):
                                  for k, v in airflow_context_vars.items()]))
         os.environ.update(airflow_context_vars)
 
-        if self.provide_context:
-            context.update(self.op_kwargs)
-            context['templates_dict'] = self.templates_dict
-            self.op_kwargs = context
+        context.update(self.op_kwargs)
+        context['templates_dict'] = self.templates_dict
+
+        self.op_kwargs = PythonOperator.determine_op_kwargs(self.python_callable, context, len(self.op_args))
 
         return_value = self.execute_callable()
         self.log.info("Done. Returned value was: %s", return_value)
@@ -130,7 +160,8 @@ class BranchPythonOperator(PythonOperator, SkipMixin):
     downstream to allow for the DAG state to fill up and the DAG run's state
     to be inferred.
     """
-    def execute(self, context):
+
+    def execute(self, context: Dict):
         branch = super().execute(context)
         self.skip_all_except(context['ti'], branch)
 
@@ -147,7 +178,8 @@ class ShortCircuitOperator(PythonOperator, SkipMixin):
 
     The condition is determined by the result of `python_callable`.
""" - def execute(self, context): + + def execute(self, context: Dict): condition = super().execute(context) self.log.info("Condition result is %s", condition) @@ -200,12 +232,6 @@ class PythonVirtualenvOperator(PythonOperator): :type op_kwargs: list :param op_kwargs: A dict of keyword arguments to pass to python_callable. :type op_kwargs: dict - :param provide_context: if set to true, Airflow will pass a set of - keyword arguments that can be used in your function. This set of - kwargs correspond exactly to what you can use in your jinja - templates. For this to work, you need to define `**kwargs` in your - function header. - :type provide_context: bool :param string_args: Strings that are present in the global var virtualenv_string_args, available to python_callable at runtime as a list[str]. Note that args are split by newline. @@ -219,6 +245,7 @@ class PythonVirtualenvOperator(PythonOperator): processing templated fields, for examples ``['.sql', '.hql']`` :type templates_exts: list[str] """ + @apply_defaults def __init__( self, @@ -229,7 +256,6 @@ def __init__( system_site_packages: bool = True, op_args: Iterable = None, op_kwargs: Dict = None, - provide_context: bool = False, string_args: Optional[Iterable[str]] = None, templates_dict: Optional[Dict] = None, templates_exts: Optional[Iterable[str]] = None, @@ -242,7 +268,6 @@ def __init__( op_kwargs=op_kwargs, templates_dict=templates_dict, templates_exts=templates_exts, - provide_context=provide_context, *args, **kwargs) self.requirements = requirements or [] @@ -264,8 +289,8 @@ def __init__( self.__class__.__name__) # check that args are passed iff python major version matches if (python_version is not None and - str(python_version)[0] != str(sys.version_info[0]) and - self._pass_op_args()): + str(python_version)[0] != str(sys.version_info[0]) and + self._pass_op_args()): raise AirflowException("Passing op_args or op_kwargs is not supported across " "different Python major versions " "for PythonVirtualenvOperator. " @@ -383,7 +408,7 @@ def _generate_python_code(self): fn = self.python_callable # dont try to read pickle if we didnt pass anything if self._pass_op_args(): - load_args_line = 'with open(sys.argv[1], "rb") as file: arg_dict = {}.load(file)'\ + load_args_line = 'with open(sys.argv[1], "rb") as file: arg_dict = {}.load(file)' \ .format(pickling_library) else: load_args_line = 'arg_dict = {"args": [], "kwargs": {}}' diff --git a/airflow/sensors/http_sensor.py b/airflow/sensors/http_sensor.py index 6e5d69946232e..3102d33f4c049 100644 --- a/airflow/sensors/http_sensor.py +++ b/airflow/sensors/http_sensor.py @@ -16,6 +16,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Dict, Callable + +from airflow.operators.python_operator import PythonOperator from airflow.exceptions import AirflowException from airflow.hooks.http_hook import HttpHook @@ -31,13 +34,17 @@ class HttpSensor(BaseSensorOperator): HTTP Error codes other than 404 (like 403) or Connection Refused Error would fail the sensor itself directly (no more poking). - The response check can access the template context by passing ``provide_context=True`` to the operator:: + The response check can access the template context to the operator: - def response_check(response, **context): - # Can look at context['ti'] etc. 
+        def response_check(response, task_instance):
+            # The task_instance is injected, so you can pull data from XCom
+            # Other context variables such as dag, ds, execution_date are also available.
+            xcom_data = task_instance.xcom_pull(task_ids='pushing_task')
+            # In practice you would do something more sensible with this data.
+            print(xcom_data)
             return True
 
-    HttpSensor(task_id='my_http_sensor', ..., provide_context=True, response_check=response_check)
+    HttpSensor(task_id='my_http_sensor', ..., response_check=response_check)
 
     :param http_conn_id: The connection to run the sensor against
@@ -50,12 +57,6 @@ def response_check(response, **context):
     :type request_params: a dictionary of string key/value pairs
     :param headers: The HTTP headers to be added to the GET request
     :type headers: a dictionary of string key/value pairs
-    :param provide_context: if set to true, Airflow will pass a set of
-        keyword arguments that can be used in your function. This set of
-        kwargs correspond exactly to what you can use in your jinja
-        templates. For this to work, you need to define context in your
-        function header.
-    :type provide_context: bool
     :param response_check: A check against the 'requests' response object.
         Returns True for 'pass' and False otherwise.
     :type response_check: A lambda or defined function.
@@ -69,14 +70,13 @@ def response_check(response, **context):
 
     @apply_defaults
     def __init__(self,
-                 endpoint,
-                 http_conn_id='http_default',
-                 method='GET',
-                 request_params=None,
-                 headers=None,
-                 response_check=None,
-                 provide_context=False,
-                 extra_options=None, *args, **kwargs):
+                 endpoint: str,
+                 http_conn_id: str = 'http_default',
+                 method: str = 'GET',
+                 request_params: Dict = None,
+                 headers: Dict = None,
+                 response_check: Callable = None,
+                 extra_options: Dict = None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.endpoint = endpoint
         self.http_conn_id = http_conn_id
@@ -84,13 +84,12 @@ def __init__(self,
         self.headers = headers or {}
         self.extra_options = extra_options or {}
         self.response_check = response_check
-        self.provide_context = provide_context
 
         self.hook = HttpHook(
             method=method,
             http_conn_id=http_conn_id)
 
-    def poke(self, context):
+    def poke(self, context: Dict):
         self.log.info('Poking: %s', self.endpoint)
         try:
             response = self.hook.run(self.endpoint,
@@ -98,10 +97,9 @@ def poke(self, context):
                                      headers=self.headers,
                                      extra_options=self.extra_options)
             if self.response_check:
-                if self.provide_context:
-                    return self.response_check(response, **context)
-                else:
-                    return self.response_check(response)
+                op_kwargs = PythonOperator.determine_op_kwargs(self.response_check, context)
+                return self.response_check(response, **op_kwargs)
+
         except AirflowException as ae:
             if str(ae).startswith("404"):
                 return False
diff --git a/docs/concepts.rst b/docs/concepts.rst
index db4a0f34d23b5..83c9f0e1c79fd 100644
--- a/docs/concepts.rst
+++ b/docs/concepts.rst
@@ -549,9 +549,12 @@ passed, then a corresponding list of XCom values is returned.
         def push_function():
             return value
 
-        # inside another PythonOperator where provide_context=True
-        def pull_function(**context):
-            value = context['task_instance'].xcom_pull(task_ids='pushing_task')
+        # inside another PythonOperator
+        def pull_function(task_instance):
+            value = task_instance.xcom_pull(task_ids='pushing_task')
+
+When you specify arguments that are part of the context, they are
+automatically passed to the function.
 
 It is also possible to pull XCom directly in a template, here's an example
 of what this may look like:
@@ -633,8 +636,7 @@ For example:
 
 ..
code:: python - def branch_func(**kwargs): - ti = kwargs['ti'] + def branch_func(ti): xcom_value = int(ti.xcom_pull(task_ids='start_task')) if xcom_value >= 5: return 'continue_task' @@ -649,7 +651,6 @@ For example: branch_op = BranchPythonOperator( task_id='branch_task', - provide_context=True, python_callable=branch_func, dag=dag) diff --git a/docs/howto/operator/python.rst b/docs/howto/operator/python.rst index d0a0da4fb7e45..e3ff9d61ac5e7 100644 --- a/docs/howto/operator/python.rst +++ b/docs/howto/operator/python.rst @@ -44,9 +44,9 @@ to the Python callable. Templating ^^^^^^^^^^ -When you set the ``provide_context`` argument to ``True``, Airflow passes in -an additional set of keyword arguments: one for each of the :doc:`Jinja -template variables <../../macros-ref>` and a ``templates_dict`` argument. +Airflow passes in an additional set of keyword arguments: one for each of the +:doc:`Jinja template variables <../../macros-ref>` and a ``templates_dict`` +argument. The ``templates_dict`` argument is templated, so each value in the dictionary is evaluated as a :ref:`Jinja template `. diff --git a/tests/contrib/hooks/test_aws_glue_catalog_hook.py b/tests/contrib/hooks/test_aws_glue_catalog_hook.py index 311dbcaf3390f..85b2777e88c51 100644 --- a/tests/contrib/hooks/test_aws_glue_catalog_hook.py +++ b/tests/contrib/hooks/test_aws_glue_catalog_hook.py @@ -43,6 +43,7 @@ } } + @unittest.skipIf(mock_glue is None, "Skipping test because moto.mock_glue is not available") class TestAwsGlueCatalogHook(unittest.TestCase): diff --git a/tests/contrib/operators/test_aws_athena_operator.py b/tests/contrib/operators/test_aws_athena_operator.py index b86bbab0f1369..c3d98e25db97d 100644 --- a/tests/contrib/operators/test_aws_athena_operator.py +++ b/tests/contrib/operators/test_aws_athena_operator.py @@ -53,7 +53,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', diff --git a/tests/contrib/operators/test_s3_to_sftp_operator.py b/tests/contrib/operators/test_s3_to_sftp_operator.py index c6cd0369557d4..4d1e2fd5134b9 100644 --- a/tests/contrib/operators/test_s3_to_sftp_operator.py +++ b/tests/contrib/operators/test_s3_to_sftp_operator.py @@ -69,7 +69,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py index 298cdb244bd4c..b5aa652691158 100644 --- a/tests/contrib/operators/test_sftp_operator.py +++ b/tests/contrib/operators/test_sftp_operator.py @@ -53,7 +53,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/contrib/operators/test_sftp_to_s3_operator.py b/tests/contrib/operators/test_sftp_to_s3_operator.py index c86c51b8e8a1c..00214107f6f6d 100644 --- a/tests/contrib/operators/test_sftp_to_s3_operator.py +++ b/tests/contrib/operators/test_sftp_to_s3_operator.py @@ -68,7 +68,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/contrib/operators/test_ssh_operator.py 
b/tests/contrib/operators/test_ssh_operator.py index f66541d010a13..e8cb8d00d72a2 100644 --- a/tests/contrib/operators/test_ssh_operator.py +++ b/tests/contrib/operators/test_ssh_operator.py @@ -51,7 +51,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/contrib/sensors/test_file_sensor.py b/tests/contrib/sensors/test_file_sensor.py index f7704ddd76d34..6d7b3f3d5e0e8 100644 --- a/tests/contrib/sensors/test_file_sensor.py +++ b/tests/contrib/sensors/test_file_sensor.py @@ -50,8 +50,7 @@ def setUp(self): hook = FSHook() args = { 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - 'provide_context': True + 'start_date': DEFAULT_DATE } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/contrib/sensors/test_gcs_upload_session_sensor.py b/tests/contrib/sensors/test_gcs_upload_session_sensor.py index c230835923120..ea4515b6ecb57 100644 --- a/tests/contrib/sensors/test_gcs_upload_session_sensor.py +++ b/tests/contrib/sensors/test_gcs_upload_session_sensor.py @@ -62,7 +62,6 @@ def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, - 'provide_context': True } dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) dag.schedule_interval = '@once' diff --git a/tests/core.py b/tests/core.py index 0dfb9ba3d0ba1..f7e6f4c5f5ef6 100644 --- a/tests/core.py +++ b/tests/core.py @@ -568,7 +568,6 @@ def test_py_op(templates_dict, ds, **kwargs): t = PythonOperator( task_id='test_py_op', - provide_context=True, python_callable=test_py_op, templates_dict={'ds': "{{ ds }}"}, dag=self.dag) @@ -2179,6 +2178,7 @@ def test_init_proxy_user(self): HDFSHook = None # type: Optional[hdfs_hook.HDFSHook] snakebite = None # type: None + @unittest.skipIf(HDFSHook is None, "Skipping test because HDFSHook is not installed") class TestHDFSHook(unittest.TestCase): diff --git a/tests/dags/test_cli_triggered_dags.py b/tests/dags/test_cli_triggered_dags.py index 7747d20710b1d..64d827dc9cf83 100644 --- a/tests/dags/test_cli_triggered_dags.py +++ b/tests/dags/test_cli_triggered_dags.py @@ -50,6 +50,5 @@ def success(ti=None, *args, **kwargs): dag1_task2 = PythonOperator( task_id='test_run_dependent_task', python_callable=success, - provide_context=True, dag=dag1) dag1_task1.set_downstream(dag1_task2) diff --git a/tests/dags/test_dag_serialization.py b/tests/dags/test_dag_serialization.py index 9618925efa312..1934c642a4232 100644 --- a/tests/dags/test_dag_serialization.py +++ b/tests/dags/test_dag_serialization.py @@ -158,7 +158,7 @@ def make_example_dags(module, dag_ids): def make_simple_dag(): """Make very simple DAG to verify serialization result.""" dag = DAG(dag_id='simple_dag') - _ = BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1)) + BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1)) return {'simple_dag': dag} @@ -186,7 +186,7 @@ def compute_next_execution_date(dag, execution_date): }, catchup=False ) - _ = BashOperator( + BashOperator( task_id='echo', bash_command='echo "{{ next_execution_date(dag, execution_date) }}"', dag=dag, diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index 3dd8b323fc40d..497a001939262 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -57,8 +57,8 @@ def 
build_recording_function(calls_collection):
     Then using this custom function recording custom Call objects for further testing
     (replacing Mock.assert_called_with assertion method)
     """
-    def recording_function(*args, **kwargs):
-        calls_collection.append(Call(*args, **kwargs))
+    def recording_function(*args):
+        calls_collection.append(Call(*args))
 
     return recording_function
 
@@ -129,11 +129,10 @@ def test_python_operator_python_callable_is_callable(self):
                 task_id='python_operator',
                 dag=self.dag)
 
-    def _assertCallsEqual(self, first, second):
+    def _assert_calls_equal(self, first, second):
         self.assertIsInstance(first, Call)
         self.assertIsInstance(second, Call)
         self.assertTupleEqual(first.args, second.args)
-        self.assertDictEqual(first.kwargs, second.kwargs)
 
     def test_python_callable_arguments_are_templatized(self):
         """Test PythonOperator op_args are templatized"""
@@ -148,7 +147,7 @@ def test_python_callable_arguments_are_templatized(self):
             task_id='python_operator',
             # a Mock instance cannot be used as a callable function or test fails with a
             # TypeError: Object of type Mock is not JSON serializable
-            python_callable=(build_recording_function(recorded_calls)),
+            python_callable=build_recording_function(recorded_calls),
             op_args=[
                 4,
                 date(2019, 1, 1),
@@ -167,7 +166,7 @@ def test_python_callable_arguments_are_templatized(self):
 
         ds_templated = DEFAULT_DATE.date().isoformat()
         self.assertEqual(1, len(recorded_calls))
-        self._assertCallsEqual(
+        self._assert_calls_equal(
             recorded_calls[0],
             Call(4,
                  date(2019, 1, 1),
@@ -183,7 +182,7 @@ def test_python_callable_keyword_arguments_are_templatized(self):
             task_id='python_operator',
             # a Mock instance cannot be used as a callable function or test fails with a
             # TypeError: Object of type Mock is not JSON serializable
-            python_callable=(build_recording_function(recorded_calls)),
+            python_callable=build_recording_function(recorded_calls),
             op_kwargs={
                 'an_int': 4,
                 'a_date': date(2019, 1, 1),
@@ -200,7 +199,7 @@ def test_python_callable_keyword_arguments_are_templatized(self):
         task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
         self.assertEqual(1, len(recorded_calls))
-        self._assertCallsEqual(
+        self._assert_calls_equal(
             recorded_calls[0],
             Call(an_int=4,
                  a_date=date(2019, 1, 1),
@@ -251,6 +250,74 @@ def test_echo_env_variables(self):
         )
         t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
+    def test_conflicting_kwargs(self):
+        self.dag.create_dagrun(
+            run_id='manual__' + DEFAULT_DATE.isoformat(),
+            execution_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE,
+            state=State.RUNNING,
+            external_trigger=False,
+        )
+
+        # dag is not allowed since it is a reserved keyword
+        def fn(dag):
+            # A ValueError should be triggered since we're using dag as a
+            # reserved keyword
+            raise RuntimeError("Should not be triggered, dag: {}".format(dag))
+
+        python_operator = PythonOperator(
+            task_id='python_operator',
+            op_args=[1],
+            python_callable=fn,
+            dag=self.dag
+        )
+
+        with self.assertRaises(ValueError) as context:
+            python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        self.assertIn('dag', str(context.exception), "'dag' not found in the exception")
+
+    def test_context_with_conflicting_op_args(self):
+        self.dag.create_dagrun(
+            run_id='manual__' + DEFAULT_DATE.isoformat(),
+            execution_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE,
+            state=State.RUNNING,
+            external_trigger=False,
+        )
+
+        def fn(custom, dag):
+            self.assertEqual(1, custom, "custom should be 1")
+            self.assertIsNotNone(dag, "dag should be set")
+
+        python_operator = PythonOperator(
task_id='python_operator', + op_kwargs={'custom': 1}, + python_callable=fn, + dag=self.dag + ) + python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + def test_context_with_kwargs(self): + self.dag.create_dagrun( + run_id='manual__' + DEFAULT_DATE.isoformat(), + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + state=State.RUNNING, + external_trigger=False, + ) + + def fn(**context): + # check if context is being set + self.assertGreater(len(context), 0, "Context has not been injected") + + python_operator = PythonOperator( + task_id='python_operator', + op_kwargs={'custom': 1}, + python_callable=fn, + dag=self.dag + ) + python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + class TestBranchOperator(unittest.TestCase): @classmethod diff --git a/tests/operators/test_virtualenv_operator.py b/tests/operators/test_virtualenv_operator.py index 52c14f88806d4..97c3dcc4eb473 100644 --- a/tests/operators/test_virtualenv_operator.py +++ b/tests/operators/test_virtualenv_operator.py @@ -199,20 +199,6 @@ def f(_): self._run_as_operator(f, op_args=[datetime.datetime.utcnow()]) def test_context(self): - def f(**kwargs): - return kwargs['templates_dict']['ds'] + def f(templates_dict): + return templates_dict['ds'] self._run_as_operator(f, templates_dict={'ds': '{{ ds }}'}) - - def test_provide_context(self): - def fn(): - pass - task = PythonVirtualenvOperator( - python_callable=fn, - python_version=sys.version_info[0], - task_id='task', - dag=self.dag, - provide_context=True, - ) - self.assertTrue( - task.provide_context - ) diff --git a/tests/sensors/test_http_sensor.py b/tests/sensors/test_http_sensor.py index 37a0a48e50906..1b6c3667ba653 100644 --- a/tests/sensors/test_http_sensor.py +++ b/tests/sensors/test_http_sensor.py @@ -63,7 +63,7 @@ def resp_check(resp): timeout=5, poke_interval=1) with self.assertRaisesRegex(AirflowException, 'AirflowException raised here!'): - task.execute(None) + task.execute(context={}) @patch("airflow.hooks.http_hook.requests.Session.send") def test_head_method(self, mock_session_send): @@ -81,7 +81,7 @@ def resp_check(_): timeout=5, poke_interval=1) - task.execute(None) + task.execute(context={}) args, kwargs = mock_session_send.call_args received_request = args[0] @@ -96,19 +96,13 @@ def resp_check(_): @patch("airflow.hooks.http_hook.requests.Session.send") def test_poke_context(self, mock_session_send): - """ - test provide_context - """ response = requests.Response() response.status_code = 200 mock_session_send.return_value = response - def resp_check(resp, **context): - if context: - if "execution_date" in context: - if context["execution_date"] == DEFAULT_DATE: - return True - + def resp_check(resp, execution_date): + if execution_date == DEFAULT_DATE: + return True raise AirflowException('AirflowException raised here!') task = HttpSensor( @@ -117,7 +111,6 @@ def resp_check(resp, **context): endpoint='', request_params={}, response_check=resp_check, - provide_context=True, timeout=5, poke_interval=1, dag=self.dag) diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index e06bdae90891e..7126871cb5807 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -71,7 +71,6 @@ def task_callable(ti, **kwargs): task_id='task_for_testing_file_log_handler', dag=dag, python_callable=task_callable, - provide_context=True ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) @@ -123,7 +122,6 @@ def task_callable(ti, **kwargs): task_id='task_for_testing_file_log_handler', 
dag=dag, python_callable=task_callable, - provide_context=True ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.try_number = 2
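
---

Below is a self-contained sketch of the signature-inspection logic that the new `PythonOperator.determine_op_kwargs` static method in this patch implements, extracted so it can be run without an Airflow installation; the `context` dict and the lambdas are made-up stand-ins for a real task context and real callables:

```python
from inspect import signature
from itertools import islice
from typing import Callable, Dict


def determine_op_kwargs(python_callable: Callable, context: Dict, num_op_args: int = 0) -> Dict:
    """Mirror of the static method added to PythonOperator in this patch."""
    context_keys = context.keys()
    sig = signature(python_callable).parameters.items()
    for name, _ in islice(sig, num_op_args):
        # Positional op_args may not shadow reserved context keys.
        if name in context_keys:
            raise ValueError(
                "The key {} in the op_args is part of the context, and therefore reserved".format(name)
            )

    if any(str(param).startswith("**") for _, param in sig):
        # A **kwargs parameter receives the entire context.
        return context

    # Otherwise pass only the context values the signature explicitly names.
    return {name: context[name] for name, _ in sig if name in context}


# A made-up context; a real Airflow task context carries many more keys.
context = {'execution_date': '2019-08-01', 'dag': 'dag-placeholder', 'ti': 'ti-placeholder'}

# Only the named context variables are passed...
assert determine_op_kwargs(lambda execution_date: None, context) == {'execution_date': '2019-08-01'}
# ...a ** parameter receives everything...
assert determine_op_kwargs(lambda **kwargs: None, context) == context
# ...and a positional op_arg that shadows a context key raises ValueError.
try:
    determine_op_kwargs(lambda dag: None, context, num_op_args=1)
except ValueError as err:
    print(err)
```

Note that iterating the `parameters.items()` view twice is safe: `islice` draws from a fresh iterator over the view rather than consuming it.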
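A minimal illustration of the escaping rule from the `Changes to FileSensor` note in UPDATING.md, shown here with `fnmatch`, whose bracket-class syntax behaves like the `glob`-style patterns FileSensor now accepts; the filename is hypothetical:

```python
from fnmatch import fnmatch

# A hypothetical literal filename that contains a glob metacharacter.
filename = "data_?.csv"

# Escape '[' first so the brackets added for '*' and '?' are not themselves re-escaped.
pattern = filename.replace("[", "[[]").replace("*", "[*]").replace("?", "[?]")

assert fnmatch("data_?.csv", pattern)      # the escaped pattern matches the literal name
assert not fnmatch("data_x.csv", pattern)  # '?' no longer matches an arbitrary character
```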
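The reworked templating section of `docs/howto/operator/python.rst` can be exercised with a short sketch like the following, assuming an existing `dag` object; the task id, callable name, and `report_date` key are illustrative. Under the new behaviour, simply naming `templates_dict` in the signature is enough for it to be injected (compare `test_py_op` in `tests/core.py` above):

```python
from airflow.operators.python_operator import PythonOperator


def render_report(templates_dict):
    # The values of templates_dict were rendered as Jinja templates before
    # execution, so 'report_date' arrives as a concrete date string.
    print("Rendering report for {}".format(templates_dict['report_date']))


render_report_task = PythonOperator(
    task_id='render_report',
    python_callable=render_report,
    templates_dict={'report_date': '{{ ds }}'},
    dag=dag,  # assumes a DAG defined elsewhere in the file
)
```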