Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SkipFailedPipelineNodes API to orchestrator. #6888

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion tfx/orchestration/experimental/core/pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def _check_nodes_exist(
raise status_lib.StatusNotOkError(
code=status_lib.Code.INVALID_ARGUMENT,
message=(
f'`f{op_name}` operation failed, cannot find node(s) '
f'`{op_name}` operation failed, cannot find node(s) '
f'{", ".join(node_id_set)} in the pipeline IR.'
),
)
Expand Down Expand Up @@ -554,6 +554,67 @@ def skip_nodes(
)


@_pipeline_op()
def skip_failed_nodes(
mlmd_handle: metadata.Metadata, node_uids: Sequence[task_lib.NodeUid]
) -> None:
"""Marks the given failed nodes as skipped instead."""
# All node_uids must have the same pipeline_uid.
pipeline_uids_set = set(n.pipeline_uid for n in node_uids)
if len(pipeline_uids_set) != 1:
raise status_lib.StatusNotOkError(
code=status_lib.Code.INVALID_ARGUMENT,
message=(
'All nodes must belong to the same pipeline, but the given '
f'nodes do not. Node UIDs were: {node_uids}'
),
)
pipeline_uid = pipeline_uids_set.pop()
with pstate.PipelineState.load_run(
mlmd_handle,
pipeline_id=pipeline_uid.pipeline_id,
run_id=pipeline_uid.pipeline_run_id,
) as pipeline_state:
pipeline = pipeline_state.pipeline
if pipeline.execution_mode != pipeline_pb2.Pipeline.SYNC:
raise status_lib.StatusNotOkError(
code=status_lib.Code.FAILED_PRECONDITION,
message=(
'Can only skip failed nodes for SYNC pipelines, but pipeline had'
f'execution mode: {pipeline.execution_mode}'
),
)
if not execution_lib.is_execution_failed(pipeline_state.execution):
state_str = metadata_store_pb2.Execution.State.Name(
pipeline_state.execution.last_known_state
)
raise status_lib.StatusNotOkError(
code=status_lib.Code.FAILED_PRECONDITION,
message=(
'Can only skip failed nodes for a pipeline in FAILED state, but '
f'pipeline in state: {state_str}'
),
)
_check_nodes_exist(node_uids, pipeline_state.pipeline, 'skip_nodes')
for node_uid in node_uids:
with pipeline_state.node_state_update_context(node_uid) as node_state:
if node_state.state != pstate.NodeState.FAILED:
raise status_lib.StatusNotOkError(
code=status_lib.Code.FAILED_PRECONDITION,
message=(
'Can only skip nodes that are in a FAILED state, but node '
f'{node_uid} was in state {node_state.state}'
),
)
node_state.update(
pstate.NodeState.SKIPPED,
status_lib.Status(
code=status_lib.Code.OK,
message='Failed node marked as skipped using SkipFailedNodes',
),
)


@_pipeline_op()
def resume_manual_node(
mlmd_handle: metadata.Metadata, node_uid: task_lib.NodeUid
Expand Down
74 changes: 74 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3206,6 +3206,80 @@ def test_skip_nodes(self):
states_dict[task_lib.NodeUid(pipeline_uid, 'Pusher')].state,
)

def test_skip_failed_nodes(self):
with self._mlmd_cm as mlmd_connection_manager:
m = mlmd_connection_manager.primary_mlmd_handle
pipeline = _test_pipeline(
'pipeline1', execution_mode=pipeline_pb2.Pipeline.SYNC
)
pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
pipeline_ops.initiate_pipeline_start(m, pipeline)

# Can't skip failed nodes if the pipeline isn't in a FAILED state
with self.assertRaises(status_lib.StatusNotOkError) as exception_context:
pipeline_ops.skip_failed_nodes(
m,
[task_lib.NodeUid(pipeline_uid, 'ExampleGen')],
)
self.assertEqual(
status_lib.Code.FAILED_PRECONDITION, exception_context.exception.code
)

with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
# Change state of ExampleGen node to COMPLETE.
with pipeline_state.node_state_update_context(
task_lib.NodeUid(pipeline_uid, 'ExampleGen')
) as node_state:
node_state.state = pstate.NodeState.COMPLETE
# Change state of Transform node to FAILED.
with pipeline_state.node_state_update_context(
task_lib.NodeUid(pipeline_uid, 'Transform')
) as node_state:
node_state.state = pstate.NodeState.FAILED

# Can't skip failed nodes if the pipeline isn't in a FAILED state,
# even if the node is in a FAILED state.
with self.assertRaises(status_lib.StatusNotOkError) as exception_context:
pipeline_ops.skip_failed_nodes(
m,
[task_lib.NodeUid(pipeline_uid, 'Transform')],
)
self.assertEqual(
status_lib.Code.FAILED_PRECONDITION, exception_context.exception.code
)

with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
# Now mark the pipeline as FAILED.
pipeline_state.set_pipeline_execution_state(
metadata_store_pb2.Execution.FAILED
)

# Can't skip non-failed nodes.
with self.assertRaises(status_lib.StatusNotOkError) as exception_context:
pipeline_ops.skip_failed_nodes(
m,
[task_lib.NodeUid(pipeline_uid, 'ExampleGen')],
)
self.assertEqual(
status_lib.Code.FAILED_PRECONDITION, exception_context.exception.code
)

# Skip Transform
pipeline_ops.skip_failed_nodes(
m,
[task_lib.NodeUid(pipeline_uid, 'Transform')],
)

pipeline_view = pstate.PipelineView.load(
m,
pipeline_id=pipeline_uid.pipeline_id,
pipeline_run_id=pipeline_uid.pipeline_run_id,
)
states_dict = pipeline_view.get_node_states_dict()
self.assertEqual(pstate.NodeState.SKIPPED, states_dict['Transform'].state)

def test_exception_while_orchestrating_active_pipeline(self):
with self._mlmd_cm as mlmd_connection_manager:
m = mlmd_connection_manager.primary_mlmd_handle
Expand Down