diff --git a/tfx/orchestration/experimental/core/async_pipeline_task_gen.py b/tfx/orchestration/experimental/core/async_pipeline_task_gen.py index 60a36b773b..0b7204ba5f 100644 --- a/tfx/orchestration/experimental/core/async_pipeline_task_gen.py +++ b/tfx/orchestration/experimental/core/async_pipeline_task_gen.py @@ -490,6 +490,8 @@ def _generate_tasks_for_node( execution_type=node.node_info.type, contexts=resolved_info.contexts, input_and_params=unprocessed_inputs, + pipeline=self._pipeline, + node_id=node.node_info.id, ) for execution in executions: diff --git a/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py b/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py index 04f49cdeca..cf2a965a8a 100644 --- a/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py +++ b/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py @@ -564,6 +564,8 @@ def _generate_tasks_from_resolved_inputs( execution_type=node.node_info.type, contexts=resolved_info.contexts, input_and_params=resolved_info.input_and_params, + pipeline=self._pipeline, + node_id=node.node_info.id, ) result.extend( diff --git a/tfx/orchestration/experimental/core/task_gen_utils.py b/tfx/orchestration/experimental/core/task_gen_utils.py index 514c042fd2..9a943ba2a9 100644 --- a/tfx/orchestration/experimental/core/task_gen_utils.py +++ b/tfx/orchestration/experimental/core/task_gen_utils.py @@ -30,6 +30,7 @@ from tfx.orchestration import metadata from tfx.orchestration import node_proto_view from tfx.orchestration.experimental.core import constants +from tfx.orchestration.experimental.core import env from tfx.orchestration.experimental.core import mlmd_state from tfx.orchestration.experimental.core import task as task_lib from tfx.orchestration import mlmd_connection_manager as mlmd_cm @@ -548,21 +549,35 @@ def register_executions_from_existing_executions( contexts = metadata_handle.store.get_contexts_by_execution( existing_executions[0].id ) - return execution_lib.put_executions( + executions = execution_lib.put_executions( metadata_handle, new_executions, contexts, input_artifacts_maps=input_artifacts, ) + pipeline_asset = metadata_handle.store.pipeline_asset + if pipeline_asset: + env.get_env().create_pipeline_run_node_executions( + pipeline_asset.owner, + pipeline_asset.name, + pipeline, + node.node_info.id, + executions, + ) + + return executions + def register_executions( metadata_handle: metadata.Metadata, execution_type: metadata_store_pb2.ExecutionType, contexts: Sequence[metadata_store_pb2.Context], input_and_params: Sequence[InputAndParam], + pipeline: Optional[pipeline_pb2.Pipeline] = None, + node_id: Optional[str] = None, ) -> Sequence[metadata_store_pb2.Execution]: - """Registers multiple executions in MLMD. + """Registers multiple executions in MLMD and Tflex backends. Along with the execution: - the input artifacts will be linked to the executions. @@ -575,6 +590,8 @@ def register_executions( input_and_params: A list of InputAndParams, which includes input_dicts (dictionaries of artifacts. One execution will be registered for each of the input_dict) and corresponding exec_properties. + pipeline: Optional. The pipeline proto. + node_id: Optional. The node id of the executions to be registered. Returns: A list of MLMD executions that are registered in MLMD, with id populated. @@ -603,7 +620,7 @@ def register_executions( executions.append(execution) if len(executions) == 1: - return [ + new_executions = [ execution_lib.put_execution( metadata_handle, executions[0], @@ -611,13 +628,27 @@ def register_executions( input_artifacts=input_and_params[0].input_artifacts, ) ] + else: + new_executions = execution_lib.put_executions( + metadata_handle, + executions, + contexts, + [ + input_and_param.input_artifacts + for input_and_param in input_and_params + ], + ) - return execution_lib.put_executions( - metadata_handle, - executions, - contexts, - [input_and_param.input_artifacts for input_and_param in input_and_params], - ) + pipeline_asset = metadata_handle.store.pipeline_asset + if pipeline_asset and pipeline and node_id: + env.get_env().create_pipeline_run_node_executions( + pipeline_asset.owner, + pipeline_asset.name, + pipeline, + node_id, + new_executions, + ) + return new_executions def update_external_artifact_type(