Skip to content

Commit

Permalink
refactor: Make sure xcoms work correctly in multi-threaded environmen…
Browse files Browse the repository at this point in the history
…t by taking the map_index into account (apache#40297)

Co-authored-by: David Blain <david.blain@infrabel.be>
  • Loading branch information
2 people authored and romsharon98 committed Jul 26, 2024
1 parent d61c207 commit 4572a67
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions airflow/providers/microsoft/azure/operators/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,10 @@ def execute_complete(
event["response"] = result

try:
self.trigger_next_link(response, method_name=self.pull_execute_complete.__name__)
self.trigger_next_link(response=response, method_name=self.execute_complete.__name__)
except TaskDeferred as exception:
self.results = self.pull_xcom(context=context)
self.log.debug("value: %s", result)
self.append_result(
result=result,
append_result_as_list_if_absent=True,
Expand All @@ -198,8 +200,6 @@ def append_result(
result: Any,
append_result_as_list_if_absent: bool = False,
):
self.log.debug("value: %s", result)

if isinstance(self.results, list):
if isinstance(result, list):
self.results.extend(result)
Expand All @@ -214,30 +214,38 @@ def append_result(
else:
self.results = result

def push_xcom(self, context: Context, value) -> None:
self.log.debug("do_xcom_push: %s", self.do_xcom_push)
if self.do_xcom_push:
self.log.info("Pushing XCom with key '%s': %s", self.key, value)
self.xcom_push(context=context, key=self.key, value=value)
def xcom_key(self, context: Context) -> str:
map_index = context["ti"].map_index
return f"{self.key}_{map_index}" if map_index else self.key

def pull_execute_complete(self, context: Context, event: dict[Any, Any] | None = None) -> Any:
self.results = list(
def pull_xcom(self, context: Context) -> list:
key = self.xcom_key(context=context)
value = list(
self.xcom_pull(
context=context,
task_ids=self.task_id,
dag_id=self.dag_id,
key=self.key,
key=key,
)
or []
)

self.log.info(
"Pulled XCom with task_id '%s' and dag_id '%s' and key '%s': %s",
self.task_id,
self.dag_id,
self.key,
self.results,
key,
value,
)
return self.execute_complete(context, event)

return value

def push_xcom(self, context: Context, value) -> None:
self.log.debug("do_xcom_push: %s", self.do_xcom_push)
if self.do_xcom_push:
key = self.xcom_key(context=context)
self.log.info("Pushing XCom with key '%s': %s", key, value)
self.xcom_push(context=context, key=key, value=value)

@staticmethod
def paginate(operator: MSGraphAsyncOperator, response: dict) -> tuple[Any, dict[str, Any] | None]:
Expand Down

0 comments on commit 4572a67

Please sign in to comment.