From 47008b66e4f4fb77bcf22fbb9028296c98de432e Mon Sep 17 00:00:00 2001
From: tfx-team <tensorflow-extended-nonhuman@googlegroups.com>
Date: Wed, 7 Aug 2024 15:02:59 -0700
Subject: [PATCH] no-op
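
Thread the pipeline proto and node id through execution registration so
that newly registered executions are also recorded in the orchestrator's
storage backend via env.get_env().create_pipeline_run_node_executions():

* task_gen_utils.register_executions() accepts optional pipeline and
  node_id arguments, which the sync and async task generators now pass.
* register_executions_from_existing_executions() and the partial-run
  cache-and-publish path record cached executions as well;
  publish_cached_executions() now returns the published executions to
  make that possible.
* The importer and resolver node handlers record their single registered
  execution, loading the pipeline proto with the new
  pipeline_state.get_pipeline() helper.
* Every code path logs a warning and skips the extra bookkeeping when
  the MLMD store has no pipeline asset.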

PiperOrigin-RevId: 660551866
---
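A sketch of the new call pattern (names follow the
sync_pipeline_task_gen.py hunk below; metadata_handle stands in for the
caller's MLMD handle):

    # The pipeline proto and node id now ride along with registration so
    # the executions can also be recorded in the storage backend.
    executions = task_gen_utils.register_executions(
        metadata_handle=metadata_handle,
        execution_type=node.node_info.type,
        contexts=resolved_info.contexts,
        input_and_params=resolved_info.input_and_params,
        pipeline=self._pipeline,    # new optional argument
        node_id=node.node_info.id,  # new optional argument
    )
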
 .../core/async_pipeline_task_gen.py           |  2 +
 .../experimental/core/pipeline_state.py       |  8 +++
 .../core/sync_pipeline_task_gen.py            |  2 +
 .../experimental/core/task_gen_utils.py       | 61 ++++++++++++++++---
 .../portable/execution_publish_utils.py       |  7 ++-
 .../portable/importer_node_handler.py         | 22 ++++++-
 .../portable/partial_run_utils.py             | 46 +++++++++++---
 .../portable/resolver_node_handler.py         | 37 ++++++++++-
 8 files changed, 163 insertions(+), 22 deletions(-)

diff --git a/tfx/orchestration/experimental/core/async_pipeline_task_gen.py b/tfx/orchestration/experimental/core/async_pipeline_task_gen.py
index 60a36b773b..0b7204ba5f 100644
--- a/tfx/orchestration/experimental/core/async_pipeline_task_gen.py
+++ b/tfx/orchestration/experimental/core/async_pipeline_task_gen.py
@@ -490,6 +490,8 @@ def _generate_tasks_for_node(
         execution_type=node.node_info.type,
         contexts=resolved_info.contexts,
         input_and_params=unprocessed_inputs,
+        pipeline=self._pipeline,
+        node_id=node.node_info.id,
     )
 
     for execution in executions:
diff --git a/tfx/orchestration/experimental/core/pipeline_state.py b/tfx/orchestration/experimental/core/pipeline_state.py
index 9be76a4792..516087456a 100644
--- a/tfx/orchestration/experimental/core/pipeline_state.py
+++ b/tfx/orchestration/experimental/core/pipeline_state.py
@@ -1673,3 +1673,11 @@ def get_pipeline_and_node(
           'pipeline nodes are supported for external executions.'
       )
     return (pipeline_state.pipeline, node)
+
+
+def get_pipeline(
+    mlmd_handle: metadata.Metadata, pipeline_id: str
+) -> pipeline_pb2.Pipeline:
+  """Loads the pipeline proto for a pipeline from latest execution."""
+  pipeline_view = PipelineView.load(mlmd_handle, pipeline_id)
+  return pipeline_view.pipeline
diff --git a/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py b/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py
index 04f49cdeca..cf2a965a8a 100644
--- a/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py
+++ b/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py
@@ -564,6 +564,8 @@ def _generate_tasks_from_resolved_inputs(
         execution_type=node.node_info.type,
         contexts=resolved_info.contexts,
         input_and_params=resolved_info.input_and_params,
+        pipeline=self._pipeline,
+        node_id=node.node_info.id,
     )
 
     result.extend(
diff --git a/tfx/orchestration/experimental/core/task_gen_utils.py b/tfx/orchestration/experimental/core/task_gen_utils.py
index 514c042fd2..4db7b00e7d 100644
--- a/tfx/orchestration/experimental/core/task_gen_utils.py
+++ b/tfx/orchestration/experimental/core/task_gen_utils.py
@@ -30,6 +30,7 @@
 from tfx.orchestration import metadata
 from tfx.orchestration import node_proto_view
 from tfx.orchestration.experimental.core import constants
+from tfx.orchestration.experimental.core import env
 from tfx.orchestration.experimental.core import mlmd_state
 from tfx.orchestration.experimental.core import task as task_lib
 from tfx.orchestration import mlmd_connection_manager as mlmd_cm
@@ -548,21 +549,41 @@ def register_executions_from_existing_executions(
   contexts = metadata_handle.store.get_contexts_by_execution(
       existing_executions[0].id
   )
-  return execution_lib.put_executions(
+  executions = execution_lib.put_executions(
       metadata_handle,
       new_executions,
       contexts,
       input_artifacts_maps=input_artifacts,
   )
 
+  pipeline_asset = metadata_handle.store.pipeline_asset
+  if pipeline_asset:
+    env.get_env().create_pipeline_run_node_executions(
+        pipeline_asset.owner,
+        pipeline_asset.name,
+        pipeline,
+        node.node_info.id,
+        executions,
+    )
+  else:
+    logging.warning(
+        'Pipeline asset %s not found in MLMD. Unable to create pipeline run'
+        ' node executions.',
+        pipeline_asset,
+    )
+  return executions
+
 
+# TODO(b/349654866): make pipeline and node_id non-optional.
 def register_executions(
     metadata_handle: metadata.Metadata,
     execution_type: metadata_store_pb2.ExecutionType,
     contexts: Sequence[metadata_store_pb2.Context],
     input_and_params: Sequence[InputAndParam],
+    pipeline: Optional[pipeline_pb2.Pipeline] = None,
+    node_id: Optional[str] = None,
 ) -> Sequence[metadata_store_pb2.Execution]:
-  """Registers multiple executions in MLMD.
+  """Registers multiple executions in storage backends.
 
   Along with the execution:
   -  the input artifacts will be linked to the executions.
@@ -575,6 +596,8 @@ def register_executions(
     input_and_params: A list of InputAndParams, which includes input_dicts
       (dictionaries of artifacts. One execution will be registered for each of
       the input_dict) and corresponding exec_properties.
+    pipeline: Optional. The pipeline proto.
+    node_id: Optional. The id of the node the executions belong to.
 
   Returns:
     A list of MLMD executions that are registered in MLMD, with id populated.
@@ -603,7 +626,7 @@ def register_executions(
     executions.append(execution)
 
   if len(executions) == 1:
-    return [
+    new_executions = [
         execution_lib.put_execution(
             metadata_handle,
             executions[0],
@@ -611,13 +634,33 @@ def register_executions(
             input_artifacts=input_and_params[0].input_artifacts,
         )
     ]
+  else:
+    new_executions = execution_lib.put_executions(
+        metadata_handle,
+        executions,
+        contexts,
+        [
+            input_and_param.input_artifacts
+            for input_and_param in input_and_params
+        ],
+    )
 
-  return execution_lib.put_executions(
-      metadata_handle,
-      executions,
-      contexts,
-      [input_and_param.input_artifacts for input_and_param in input_and_params],
-  )
+  pipeline_asset = metadata_handle.store.pipeline_asset
+  if pipeline_asset and pipeline and node_id:
+    env.get_env().create_pipeline_run_node_executions(
+        pipeline_asset.owner,
+        pipeline_asset.name,
+        pipeline,
+        node_id,
+        new_executions,
+    )
+  else:
+    logging.warning(
+        'Skipping creation of pipeline run node executions for pipeline asset %s.',
+        pipeline_asset,
+    )
+
+  return new_executions
 
 
 def update_external_artifact_type(
diff --git a/tfx/orchestration/portable/execution_publish_utils.py b/tfx/orchestration/portable/execution_publish_utils.py
index aa16aa26c7..928e8a5187 100644
--- a/tfx/orchestration/portable/execution_publish_utils.py
+++ b/tfx/orchestration/portable/execution_publish_utils.py
@@ -37,7 +37,7 @@ def publish_cached_executions(
     output_artifacts_maps: Optional[
         Sequence[typing_utils.ArtifactMultiMap]
     ] = None,
-) -> None:
+) -> Sequence[metadata_store_pb2.Execution]:
   """Marks an existing execution as using cached outputs from a previous execution.
 
   Args:
@@ -46,11 +46,14 @@ def publish_cached_executions(
     executions: Executions that will be published as CACHED executions.
     output_artifacts_maps: A list of output artifacts of the executions. Each
       artifact will be linked with the execution through an event of type OUTPUT
+
+  Returns:
+    A list of MLMD executions that are published to MLMD, with id populated.
   """
   for execution in executions:
     execution.last_known_state = metadata_store_pb2.Execution.CACHED
 
-  execution_lib.put_executions(
+  return execution_lib.put_executions(
       metadata_handle,
       executions,
       contexts,
diff --git a/tfx/orchestration/portable/importer_node_handler.py b/tfx/orchestration/portable/importer_node_handler.py
index d3997e8a86..f7b78d75db 100644
--- a/tfx/orchestration/portable/importer_node_handler.py
+++ b/tfx/orchestration/portable/importer_node_handler.py
@@ -20,6 +20,8 @@
 from tfx.dsl.components.common import importer
 from tfx.orchestration import data_types_utils
 from tfx.orchestration import metadata
+from tfx.orchestration.experimental.core import env
+from tfx.orchestration.experimental.core import pipeline_state as pstate
 from tfx.orchestration.portable import data_types
 from tfx.orchestration.portable import execution_publish_utils
 from tfx.orchestration.portable import inputs_utils
@@ -57,7 +59,7 @@ def run(
 
     Args:
       mlmd_connection: ML metadata connection.
-      pipeline_node: The specification of the node that this launcher lauches.
+      pipeline_node: The specification of the node that this launcher launches.
       pipeline_info: The information of the pipeline that this node runs in.
       pipeline_runtime_spec: The runtime information of the pipeline that this
         node runs in.
@@ -78,13 +80,29 @@ def run(
           inputs_utils.resolve_parameters_with_schema(
               node_parameters=pipeline_node.parameters))
 
-      # 3. Registers execution in metadata.
+      # 3. Registers execution in storage backends.
       execution = execution_publish_utils.register_execution(
           metadata_handle=m,
           execution_type=pipeline_node.node_info.type,
           contexts=contexts,
           exec_properties=exec_properties,
       )
+      pipeline_asset = m.store.pipeline_asset
+      if pipeline_asset:
+        env.get_env().create_pipeline_run_node_executions(
+            pipeline_asset.owner,
+            pipeline_asset.name,
+            pstate.get_pipeline(m, pipeline_info.id),
+            pipeline_node.node_info.id,
+            [execution],
+        )
+      else:
+        logging.warning(
+            'Pipeline asset %s not found in MLMD. Unable to create pipeline run'
+            ' node execution %s.',
+            pipeline_asset,
+            execution,
+        )
 
       # 4. Generate output artifacts to represent the imported artifacts.
       output_key = cast(str, exec_properties[importer.OUTPUT_KEY_KEY])
diff --git a/tfx/orchestration/portable/partial_run_utils.py b/tfx/orchestration/portable/partial_run_utils.py
index fe701e9a2c..e8155fc55d 100644
--- a/tfx/orchestration/portable/partial_run_utils.py
+++ b/tfx/orchestration/portable/partial_run_utils.py
@@ -24,6 +24,7 @@
 from tfx.dsl.compiler import constants
 from tfx.orchestration import metadata
 from tfx.orchestration import node_proto_view
+from tfx.orchestration.experimental.core import env
 from tfx.orchestration.portable import execution_publish_utils
 from tfx.orchestration.portable.mlmd import context_lib
 from tfx.orchestration.portable.mlmd import execution_lib
@@ -599,6 +600,8 @@ def __init__(
         for node in node_proto_view.get_view_for_all_in(new_pipeline_run_ir)
     }
 
+    self._pipeline = new_pipeline_run_ir
+
   def _get_base_pipeline_run_context(
       self, base_run_id: Optional[str] = None
   ) -> metadata_store_pb2.Context:
@@ -788,7 +791,12 @@ def _cache_and_publish(
             contexts=[self._new_pipeline_run_context] + node_contexts,
         )
     )
-    if not prev_cache_executions:
+
+    # If there are no previous attempts to cache and publish, create new
+    # cached executions.
+    create_new_cache_executions: bool = not prev_cache_executions
+
+    if create_new_cache_executions:
       new_cached_executions = []
       for e in existing_executions:
         new_cached_executions.append(
@@ -820,12 +828,36 @@ def _cache_and_publish(
         execution_lib.get_output_artifacts(self._mlmd, e.id)
         for e in existing_executions
     ]
-    execution_publish_utils.publish_cached_executions(
-        self._mlmd,
-        contexts=cached_execution_contexts,
-        executions=new_cached_executions,
-        output_artifacts_maps=output_artifacts_maps,
-    )
+
+    if create_new_cache_executions:
+      new_executions = execution_publish_utils.publish_cached_executions(
+          self._mlmd,
+          contexts=cached_execution_contexts,
+          executions=new_cached_executions,
+          output_artifacts_maps=output_artifacts_maps,
+      )
+      pipeline_asset = self._mlmd.store.pipeline_asset
+      if pipeline_asset:
+        env.get_env().create_pipeline_run_node_executions(
+            pipeline_asset.owner,
+            pipeline_asset.name,
+            self._pipeline,
+            node.node_info.id,
+            new_executions,
+        )
+      else:
+        logging.warning(
+            'Pipeline asset %s not found in MLMD. Unable to create pipeline run'
+            ' node executions.',
+            pipeline_asset,
+        )
+    else:
+      execution_publish_utils.publish_cached_executions(
+          self._mlmd,
+          contexts=cached_execution_contexts,
+          executions=new_cached_executions,
+          output_artifacts_maps=output_artifacts_maps,
+      )
 
   def put_parent_context(self):
     """Puts a ParentContext edge in MLMD."""
diff --git a/tfx/orchestration/portable/resolver_node_handler.py b/tfx/orchestration/portable/resolver_node_handler.py
index 221d6bb278..151615bffe 100644
--- a/tfx/orchestration/portable/resolver_node_handler.py
+++ b/tfx/orchestration/portable/resolver_node_handler.py
@@ -20,6 +20,8 @@
 import grpc
 from tfx.orchestration import data_types_utils
 from tfx.orchestration import metadata
+from tfx.orchestration.experimental.core import env
+from tfx.orchestration.experimental.core import pipeline_state as pstate
 from tfx.orchestration.portable import data_types
 from tfx.orchestration.portable import execution_publish_utils
 from tfx.orchestration.portable import inputs_utils
@@ -86,6 +88,22 @@ def run(
             contexts=contexts,
             exec_properties=exec_properties,
         )
+        pipeline_asset = m.store.pipeline_asset
+        if pipeline_asset:
+          env.get_env().create_pipeline_run_node_executions(
+              pipeline_asset.owner,
+              pipeline_asset.name,
+              pstate.get_pipeline(m, pipeline_info.id),
+              pipeline_node.node_info.id,
+              [execution],
+          )
+        else:
+          logging.warning(
+              'Pipeline asset %s not found in MLMD. Unable to create pipeline'
+              ' run node execution %s.',
+              pipeline_asset,
+              execution,
+          )
         execution_publish_utils.publish_failed_execution(
             metadata_handle=m,
             contexts=contexts,
@@ -103,14 +121,29 @@ def run(
       if isinstance(resolved_inputs, inputs_utils.Skip):
         return data_types.ExecutionInfo()
 
-      # 3. Registers execution in metadata.
+      # 3. Registers execution in storage backends.
       execution = execution_publish_utils.register_execution(
           metadata_handle=m,
           execution_type=pipeline_node.node_info.type,
           contexts=contexts,
           exec_properties=exec_properties,
       )
-
+      pipeline_asset = m.store.pipeline_asset
+      if pipeline_asset:
+        env.get_env().create_pipeline_run_node_executions(
+            pipeline_asset.owner,
+            pipeline_asset.name,
+            pstate.get_pipeline(m, pipeline_info.id),
+            pipeline_node.node_info.id,
+            [execution],
+        )
+      else:
+        logging.warning(
+            'Pipeline asset %s not found in MLMD. Unable to create pipeline'
+            ' run node execution %s.',
+            pipeline_asset,
+            execution,
+        )
       # TODO(b/197741942): Support len > 1.
       if len(resolved_inputs) > 1:
         execution_publish_utils.publish_failed_execution(