From 47008b66e4f4fb77bcf22fbb9028296c98de432e Mon Sep 17 00:00:00 2001 From: tfx-team <tensorflow-extended-nonhuman@googlegroups.com> Date: Wed, 7 Aug 2024 15:02:59 -0700 Subject: [PATCH] no-op PiperOrigin-RevId: 660551866 --- .../core/async_pipeline_task_gen.py | 2 + .../experimental/core/pipeline_state.py | 8 +++ .../core/sync_pipeline_task_gen.py | 2 + .../experimental/core/task_gen_utils.py | 61 ++++++++++++++++--- .../portable/execution_publish_utils.py | 7 ++- .../portable/importer_node_handler.py | 22 ++++++- .../portable/partial_run_utils.py | 46 +++++++++++--- .../portable/resolver_node_handler.py | 37 ++++++++++- 8 files changed, 163 insertions(+), 22 deletions(-) diff --git a/tfx/orchestration/experimental/core/async_pipeline_task_gen.py b/tfx/orchestration/experimental/core/async_pipeline_task_gen.py index 60a36b773b..0b7204ba5f 100644 --- a/tfx/orchestration/experimental/core/async_pipeline_task_gen.py +++ b/tfx/orchestration/experimental/core/async_pipeline_task_gen.py @@ -490,6 +490,8 @@ def _generate_tasks_for_node( execution_type=node.node_info.type, contexts=resolved_info.contexts, input_and_params=unprocessed_inputs, + pipeline=self._pipeline, + node_id=node.node_info.id, ) for execution in executions: diff --git a/tfx/orchestration/experimental/core/pipeline_state.py b/tfx/orchestration/experimental/core/pipeline_state.py index 9be76a4792..516087456a 100644 --- a/tfx/orchestration/experimental/core/pipeline_state.py +++ b/tfx/orchestration/experimental/core/pipeline_state.py @@ -1673,3 +1673,11 @@ def get_pipeline_and_node( 'pipeline nodes are supported for external executions.' ) return (pipeline_state.pipeline, node) + + +def get_pipeline( + mlmd_handle: metadata.Metadata, pipeline_id: str +) -> pipeline_pb2.Pipeline: + """Loads the pipeline proto for a pipeline from latest execution.""" + pipeline_view = PipelineView.load(mlmd_handle, pipeline_id) + return pipeline_view.pipeline diff --git a/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py b/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py index 04f49cdeca..cf2a965a8a 100644 --- a/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py +++ b/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py @@ -564,6 +564,8 @@ def _generate_tasks_from_resolved_inputs( execution_type=node.node_info.type, contexts=resolved_info.contexts, input_and_params=resolved_info.input_and_params, + pipeline=self._pipeline, + node_id=node.node_info.id, ) result.extend( diff --git a/tfx/orchestration/experimental/core/task_gen_utils.py b/tfx/orchestration/experimental/core/task_gen_utils.py index 514c042fd2..4db7b00e7d 100644 --- a/tfx/orchestration/experimental/core/task_gen_utils.py +++ b/tfx/orchestration/experimental/core/task_gen_utils.py @@ -30,6 +30,7 @@ from tfx.orchestration import metadata from tfx.orchestration import node_proto_view from tfx.orchestration.experimental.core import constants +from tfx.orchestration.experimental.core import env from tfx.orchestration.experimental.core import mlmd_state from tfx.orchestration.experimental.core import task as task_lib from tfx.orchestration import mlmd_connection_manager as mlmd_cm @@ -548,21 +549,41 @@ def register_executions_from_existing_executions( contexts = metadata_handle.store.get_contexts_by_execution( existing_executions[0].id ) - return execution_lib.put_executions( + executions = execution_lib.put_executions( metadata_handle, new_executions, contexts, input_artifacts_maps=input_artifacts, ) + pipeline_asset = metadata_handle.store.pipeline_asset + if pipeline_asset: + env.get_env().create_pipeline_run_node_executions( + pipeline_asset.owner, + pipeline_asset.name, + pipeline, + node.node_info.id, + executions, + ) + else: + logging.warning( + 'Pipeline asset %s not found in MLMD. Unable to create pipeline run' + ' node executions.', + pipeline_asset, + ) + return executions + +# TODO(b/349654866): make pipeline and node_id non-optional. def register_executions( metadata_handle: metadata.Metadata, execution_type: metadata_store_pb2.ExecutionType, contexts: Sequence[metadata_store_pb2.Context], input_and_params: Sequence[InputAndParam], + pipeline: Optional[pipeline_pb2.Pipeline] = None, + node_id: Optional[str] = None, ) -> Sequence[metadata_store_pb2.Execution]: - """Registers multiple executions in MLMD. + """Registers multiple executions in storage backends. Along with the execution: - the input artifacts will be linked to the executions. @@ -575,6 +596,8 @@ def register_executions( input_and_params: A list of InputAndParams, which includes input_dicts (dictionaries of artifacts. One execution will be registered for each of the input_dict) and corresponding exec_properties. + pipeline: Optional. The pipeline proto. + node_id: Optional. The node id of the executions to be registered. Returns: A list of MLMD executions that are registered in MLMD, with id populated. @@ -603,7 +626,7 @@ def register_executions( executions.append(execution) if len(executions) == 1: - return [ + new_executions = [ execution_lib.put_execution( metadata_handle, executions[0], @@ -611,13 +634,33 @@ def register_executions( input_artifacts=input_and_params[0].input_artifacts, ) ] + else: + new_executions = execution_lib.put_executions( + metadata_handle, + executions, + contexts, + [ + input_and_param.input_artifacts + for input_and_param in input_and_params + ], + ) - return execution_lib.put_executions( - metadata_handle, - executions, - contexts, - [input_and_param.input_artifacts for input_and_param in input_and_params], - ) + pipeline_asset = metadata_handle.store.pipeline_asset + if pipeline_asset and pipeline and node_id: + env.get_env().create_pipeline_run_node_executions( + pipeline_asset.owner, + pipeline_asset.name, + pipeline, + node_id, + new_executions, + ) + else: + logging.warning( + 'Skipping creating pipeline run node executions for pipeline asset %s.', + pipeline_asset, + ) + + return new_executions def update_external_artifact_type( diff --git a/tfx/orchestration/portable/execution_publish_utils.py b/tfx/orchestration/portable/execution_publish_utils.py index aa16aa26c7..928e8a5187 100644 --- a/tfx/orchestration/portable/execution_publish_utils.py +++ b/tfx/orchestration/portable/execution_publish_utils.py @@ -37,7 +37,7 @@ def publish_cached_executions( output_artifacts_maps: Optional[ Sequence[typing_utils.ArtifactMultiMap] ] = None, -) -> None: +) -> Sequence[metadata_store_pb2.Execution]: """Marks an existing execution as using cached outputs from a previous execution. Args: @@ -46,11 +46,14 @@ def publish_cached_executions( executions: Executions that will be published as CACHED executions. output_artifacts_maps: A list of output artifacts of the executions. Each artifact will be linked with the execution through an event of type OUTPUT + + Returns: + A list of MLMD executions that are published to MLMD, with id pupulated. """ for execution in executions: execution.last_known_state = metadata_store_pb2.Execution.CACHED - execution_lib.put_executions( + return execution_lib.put_executions( metadata_handle, executions, contexts, diff --git a/tfx/orchestration/portable/importer_node_handler.py b/tfx/orchestration/portable/importer_node_handler.py index d3997e8a86..f7b78d75db 100644 --- a/tfx/orchestration/portable/importer_node_handler.py +++ b/tfx/orchestration/portable/importer_node_handler.py @@ -20,6 +20,8 @@ from tfx.dsl.components.common import importer from tfx.orchestration import data_types_utils from tfx.orchestration import metadata +from tfx.orchestration.experimental.core import env +from tfx.orchestration.experimental.core import pipeline_state as pstate from tfx.orchestration.portable import data_types from tfx.orchestration.portable import execution_publish_utils from tfx.orchestration.portable import inputs_utils @@ -57,7 +59,7 @@ def run( Args: mlmd_connection: ML metadata connection. - pipeline_node: The specification of the node that this launcher lauches. + pipeline_node: The specification of the node that this launcher launches. pipeline_info: The information of the pipeline that this node runs in. pipeline_runtime_spec: The runtime information of the pipeline that this node runs in. @@ -78,13 +80,29 @@ def run( inputs_utils.resolve_parameters_with_schema( node_parameters=pipeline_node.parameters)) - # 3. Registers execution in metadata. + # 3. Registers execution in storage backend. execution = execution_publish_utils.register_execution( metadata_handle=m, execution_type=pipeline_node.node_info.type, contexts=contexts, exec_properties=exec_properties, ) + pipeline_asset = m.store.pipeline_asset + if pipeline_asset: + env.get_env().create_pipeline_run_node_executions( + pipeline_asset.owner, + pipeline_asset.name, + pstate.get_pipeline(m, pipeline_info.id), + pipeline_node.node_info.id, + [execution], + ) + else: + logging.warning( + 'Pipeline asset %s not found in MLMD. Unable to create pipeline run' + ' node execution %s.', + pipeline_asset, + execution, + ) # 4. Generate output artifacts to represent the imported artifacts. output_key = cast(str, exec_properties[importer.OUTPUT_KEY_KEY]) diff --git a/tfx/orchestration/portable/partial_run_utils.py b/tfx/orchestration/portable/partial_run_utils.py index fe701e9a2c..e8155fc55d 100644 --- a/tfx/orchestration/portable/partial_run_utils.py +++ b/tfx/orchestration/portable/partial_run_utils.py @@ -24,6 +24,7 @@ from tfx.dsl.compiler import constants from tfx.orchestration import metadata from tfx.orchestration import node_proto_view +from tfx.orchestration.experimental.core import env from tfx.orchestration.portable import execution_publish_utils from tfx.orchestration.portable.mlmd import context_lib from tfx.orchestration.portable.mlmd import execution_lib @@ -599,6 +600,8 @@ def __init__( for node in node_proto_view.get_view_for_all_in(new_pipeline_run_ir) } + self._pipeline = new_pipeline_run_ir + def _get_base_pipeline_run_context( self, base_run_id: Optional[str] = None ) -> metadata_store_pb2.Context: @@ -788,7 +791,12 @@ def _cache_and_publish( contexts=[self._new_pipeline_run_context] + node_contexts, ) ) - if not prev_cache_executions: + + # If there are no previous attempts to cache and publish, we will create new + # cache executions. + create_new_cache_executions: bool = not prev_cache_executions + + if create_new_cache_executions: new_cached_executions = [] for e in existing_executions: new_cached_executions.append( @@ -820,12 +828,36 @@ def _cache_and_publish( execution_lib.get_output_artifacts(self._mlmd, e.id) for e in existing_executions ] - execution_publish_utils.publish_cached_executions( - self._mlmd, - contexts=cached_execution_contexts, - executions=new_cached_executions, - output_artifacts_maps=output_artifacts_maps, - ) + + if create_new_cache_executions: + new_executions = execution_publish_utils.publish_cached_executions( + self._mlmd, + contexts=cached_execution_contexts, + executions=new_cached_executions, + output_artifacts_maps=output_artifacts_maps, + ) + pipeline_asset = self._mlmd.store.pipeline_asset + if pipeline_asset: + env.get_env().create_pipeline_run_node_executions( + pipeline_asset.owner, + pipeline_asset.name, + self._pipeline, + node.node_info.id, + new_executions, + ) + else: + logging.warning( + 'Pipeline asset %s not found in MLMD. Unable to create pipeline run' + ' node executions.', + pipeline_asset, + ) + else: + execution_publish_utils.publish_cached_executions( + self._mlmd, + contexts=cached_execution_contexts, + executions=new_cached_executions, + output_artifacts_maps=output_artifacts_maps, + ) def put_parent_context(self): """Puts a ParentContext edge in MLMD.""" diff --git a/tfx/orchestration/portable/resolver_node_handler.py b/tfx/orchestration/portable/resolver_node_handler.py index 221d6bb278..151615bffe 100644 --- a/tfx/orchestration/portable/resolver_node_handler.py +++ b/tfx/orchestration/portable/resolver_node_handler.py @@ -20,6 +20,8 @@ import grpc from tfx.orchestration import data_types_utils from tfx.orchestration import metadata +from tfx.orchestration.experimental.core import env +from tfx.orchestration.experimental.core import pipeline_state as pstate from tfx.orchestration.portable import data_types from tfx.orchestration.portable import execution_publish_utils from tfx.orchestration.portable import inputs_utils @@ -86,6 +88,22 @@ def run( contexts=contexts, exec_properties=exec_properties, ) + pipeline_asset = m.store.pipeline_asset + if pipeline_asset: + env.get_env().create_pipeline_run_node_executions( + pipeline_asset.owner, + pipeline_asset.name, + pstate.get_pipeline(m, pipeline_info.id), + pipeline_node.node_info.id, + [execution], + ) + else: + logging.warning( + 'Pipeline asset %s not found in MLMD. Unable to create pipeline' + ' run node execution %s.', + pipeline_asset, + execution, + ) execution_publish_utils.publish_failed_execution( metadata_handle=m, contexts=contexts, @@ -103,14 +121,29 @@ def run( if isinstance(resolved_inputs, inputs_utils.Skip): return data_types.ExecutionInfo() - # 3. Registers execution in metadata. + # 3. Registers execution in storage backends. execution = execution_publish_utils.register_execution( metadata_handle=m, execution_type=pipeline_node.node_info.type, contexts=contexts, exec_properties=exec_properties, ) - + pipeline_asset = m.store.pipeline_asset + if pipeline_asset: + env.get_env().create_pipeline_run_node_executions( + pipeline_asset.owner, + pipeline_asset.name, + pstate.get_pipeline(m, pipeline_info.id), + pipeline_node.node_info.id, + [execution], + ) + else: + logging.warning( + 'Pipeline asset %s not found in MLMD. Unable to create pipeline' + ' run node execution %s.', + pipeline_asset, + execution, + ) # TODO(b/197741942): Support len > 1. if len(resolved_inputs) > 1: execution_publish_utils.publish_failed_execution(