Skip to content

Commit

Permalink
Fix rendering the mapped parameters when using expand_kwargs method (
Browse files Browse the repository at this point in the history
…#32272)

* Fix rendering the mapped parameters in the mapped operator

Signed-off-by: Hussein Awala <hussein@awala.fr>

* add template_in_template arg to expand method to tell Airflow whether to resolve the xcom data or not

* fix dag serialization tests

* Revert "fix dag serialization tests"

This reverts commit 191351c.

* Revert "add template_in_template arg to expand method to tell Airflow whether to resolve the xcom data or not"

This reverts commit 14bd392.

* Fix ListOfDictsExpandInput resolve method

* remove _iter_parse_time_resolved_kwargs method

* remove unnecessary step

---------

Signed-off-by: Hussein Awala <hussein@awala.fr>
  • Loading branch information
hussein-awala authored Aug 18, 2023
1 parent 58eb19f commit d1e6a5c
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 13 deletions.
5 changes: 4 additions & 1 deletion airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,10 @@ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any]
f"expand_kwargs() input dict keys must all be str, "
f"but {key!r} is of type {_describe_type(key)}"
)
return mapping, {id(v) for v in mapping.values()}
# filter out parse time resolved values from the resolved_oids
resolved_oids = {id(v) for k, v in mapping.items() if not _is_parse_time_mappable(v)}

return mapping, resolved_oids


EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value.
Expand Down
89 changes: 77 additions & 12 deletions tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import logging
from collections import defaultdict
from datetime import timedelta
from unittest import mock
from unittest.mock import patch

import pendulum
Expand Down Expand Up @@ -399,17 +400,31 @@ def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expect


def test_mapped_render_template_fields_validating_operator(dag_maker, session):
class MyOperator(MockOperator):
def __init__(self, value, arg1, **kwargs):
assert isinstance(value, str), "value should have been resolved before unmapping"
assert isinstance(arg1, str), "value should have been resolved before unmapping"
super().__init__(arg1=arg1, **kwargs)
self.value = value
class MyOperator(BaseOperator):
template_fields = ("partial_template", "map_template", "file_template")
template_ext = (".ext",)

def __init__(
self, partial_template, partial_static, map_template, map_static, file_template, **kwargs
):
for value in [partial_template, partial_static, map_template, map_static, file_template]:
assert isinstance(value, str), "value should have been resolved before unmapping"
super().__init__(**kwargs)
self.partial_template = partial_template
self.partial_static = partial_static
self.map_template = map_template
self.map_static = map_static
self.file_template = file_template

def execute(self, context):
pass

with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
output1 = task1.output
mapped = MyOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand(value=output1, arg1=output1)
mapped = MyOperator.partial(
task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}"
).expand(map_template=output1, map_static=output1, file_template=["/path/to/file.ext"])

dr = dag_maker.create_dagrun()
ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
Expand All @@ -432,12 +447,62 @@ def __init__(self, value, arg1, **kwargs):
mapped_ti.map_index = 0

assert isinstance(mapped_ti.task, MappedOperator)
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch(
"os.path.isfile", return_value=True
), patch("os.path.getmtime", return_value=0):
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
assert isinstance(mapped_ti.task, MyOperator)

assert mapped_ti.task.partial_template == "a", "Should be templated!"
assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!"
assert mapped_ti.task.map_template == "{{ ds }}", "Should not be templated!"
assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!"
assert mapped_ti.task.file_template == "loaded data", "Should be templated!"


def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session):
class MyOperator(BaseOperator):
template_fields = ("partial_template", "map_template", "file_template")
template_ext = (".ext",)

def __init__(
self, partial_template, partial_static, map_template, map_static, file_template, **kwargs
):
for value in [partial_template, partial_static, map_template, map_static, file_template]:
assert isinstance(value, str), "value should have been resolved before unmapping"
super().__init__(**kwargs)
self.partial_template = partial_template
self.partial_static = partial_static
self.map_template = map_template
self.map_static = map_static
self.file_template = file_template

def execute(self, context):
pass

with dag_maker(session=session):
mapped = MyOperator.partial(
task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}"
).expand_kwargs(
[{"map_template": "{{ ds }}", "map_static": "{{ ds }}", "file_template": "/path/to/file.ext"}]
)

dr = dag_maker.create_dagrun()

mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session, map_index=0)

assert isinstance(mapped_ti.task, MappedOperator)
with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch(
"os.path.isfile", return_value=True
), patch("os.path.getmtime", return_value=0):
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
assert isinstance(mapped_ti.task, MyOperator)

assert mapped_ti.task.value == "{{ ds }}", "Should not be templated!"
assert mapped_ti.task.arg1 == "{{ ds }}", "Should not be templated!"
assert mapped_ti.task.arg2 == "a"
assert mapped_ti.task.partial_template == "a", "Should be templated!"
assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!"
assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!"
assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!"
assert mapped_ti.task.file_template == "loaded data", "Should be templated!"


def test_mapped_render_nested_template_fields(dag_maker, session):
Expand Down Expand Up @@ -534,7 +599,7 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis
@pytest.mark.parametrize(
"map_index, expected",
[
pytest.param(0, "{{ ds }}", id="0"),
pytest.param(0, "2016-01-01", id="0"),
pytest.param(1, 2, id="1"),
],
)
Expand Down

0 comments on commit d1e6a5c

Please sign in to comment.