Skip to content

Commit

Permalink
BUGFIX: Make sure XComs work correctly in MSGraphAsyncOperator with p…
Browse files Browse the repository at this point in the history
…aged results and dynamic task mapping (apache#40301)



---------

Co-authored-by: David Blain <david.blain@infrabel.be>
  • Loading branch information
2 people authored and romsharon98 committed Jul 26, 2024
1 parent 9424c91 commit a213bbc
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 24 deletions.
57 changes: 34 additions & 23 deletions airflow/providers/microsoft/azure/operators/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,9 @@ def execute_complete(
event["response"] = result

try:
self.trigger_next_link(response, method_name=self.pull_execute_complete.__name__)
self.trigger_next_link(response=response, method_name=self.execute_complete.__name__)
except TaskDeferred as exception:
self.results = self.pull_xcom(context=context)
self.append_result(
result=result,
append_result_as_list_if_absent=True,
Expand All @@ -188,7 +189,6 @@ def execute_complete(
raise exception

self.append_result(result=result)
self.log.debug("results: %s", self.results)

return self.results
return None
Expand All @@ -198,8 +198,6 @@ def append_result(
result: Any,
append_result_as_list_if_absent: bool = False,
):
self.log.debug("value: %s", result)

if isinstance(self.results, list):
if isinstance(result, list):
self.results.extend(result)
Expand All @@ -214,30 +212,43 @@ def append_result(
else:
self.results = result

def push_xcom(self, context: Context, value) -> None:
self.log.debug("do_xcom_push: %s", self.do_xcom_push)
if self.do_xcom_push:
self.log.info("Pushing XCom with key '%s': %s", self.key, value)
self.xcom_push(context=context, key=self.key, value=value)

def pull_execute_complete(self, context: Context, event: dict[Any, Any] | None = None) -> Any:
self.results = list(
self.xcom_pull(
context=context,
def pull_xcom(self, context: Context) -> list:
map_index = context["ti"].map_index
value = list(
context["ti"].xcom_pull(
key=self.key,
task_ids=self.task_id,
dag_id=self.dag_id,
key=self.key,
map_indexes=map_index,
)
or []
)
self.log.info(
"Pulled XCom with task_id '%s' and dag_id '%s' and key '%s': %s",
self.task_id,
self.dag_id,
self.key,
self.results,
)
return self.execute_complete(context, event)

if map_index:
self.log.info(
"Pulled XCom with task_id '%s' and dag_id '%s' and key '%s' and map_index %s: %s",
self.task_id,
self.dag_id,
self.key,
map_index,
value,
)
else:
self.log.info(
"Pulled XCom with task_id '%s' and dag_id '%s' and key '%s': %s",
self.task_id,
self.dag_id,
self.key,
value,
)

return value

def push_xcom(self, context: Context, value) -> None:
self.log.debug("do_xcom_push: %s", self.do_xcom_push)
if self.do_xcom_push:
self.log.info("Pushing XCom with key '%s': %s", self.key, value)
self.xcom_push(context=context, key=self.key, value=value)

@staticmethod
def paginate(operator: MSGraphAsyncOperator, response: dict) -> tuple[Any, dict[str, Any] | None]:
Expand Down
4 changes: 3 additions & 1 deletion tests/providers/microsoft/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def xcom_pull(
map_indexes: Iterable[int] | int | None = None,
default: Any | None = None,
) -> Any:
if map_indexes:
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}")
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}")

def xcom_push(
Expand All @@ -152,7 +154,7 @@ def xcom_push(
execution_date: datetime | None = None,
session: Session = NEW_SESSION,
) -> None:
values[f"{self.task_id}_{self.dag_id}_{key}"] = value
values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value

values["ti"] = MockedTaskInstance(task=task)

Expand Down

0 comments on commit a213bbc

Please sign in to comment.