Skip to content

Commit

Permalink
Expose end node API through python NodeExecutionOptions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 574505882
  • Loading branch information
kmonte authored and tfx-copybara committed Oct 18, 2023
1 parent 1269182 commit ceea6bd
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 12 deletions.
4 changes: 4 additions & 0 deletions tfx/dsl/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,10 @@ def _compile_node(
node.execution_options.node_success_optional = node_execution_options.success_optional
node.execution_options.max_execution_retries = node_execution_options.max_execution_retries
node.execution_options.execution_timeout_sec = node_execution_options.execution_timeout_sec
if node_execution_options.lifetime_start:
node.execution_options.resource_lifetime.lifetime_start = (
node_execution_options.lifetime_start
)

if pipeline_ctx.is_async_mode:
input_triggers = node.execution_options.async_trigger.input_triggers
Expand Down
13 changes: 13 additions & 0 deletions tfx/dsl/experimental/node_execution_options/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
This is only used for the experimental orchestrator.
"""
import dataclasses
from typing import Optional

from tfx.proto.orchestration import pipeline_pb2

Expand All @@ -32,5 +33,17 @@ class NodeExecutionOptions:
max_execution_retries: int = 0
execution_timeout_sec: int = 0

# This is an experimental feature to enable "end nodes" in a pipeline to
# support resource lifetimes. If this field is set then the node which this
# NodeExecutionOptions belongs to will run during pipeline finalization if the
# "lifetime_start" has run succesfully.
# Pipeline finalization happens when:
# 1. All nodes in the pipeline completed, this is the "happy path".
# 2. A user requests for the pipeline to stop
# 3. A node fails in the pipeline and it cannot continue executing.
# This should be the id of the node "starting" a lifetime.
# If you want to use this feature please contact kmonte@ first.
lifetime_start: Optional[str] = None

def __post_init__(self):
self.max_execution_retries = max(self.max_execution_retries, 0)
31 changes: 19 additions & 12 deletions tfx/dsl/experimental/node_execution_options/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for utils."""

import tensorflow as tf
from tfx import types
from tfx.dsl.components.base import base_component
Expand All @@ -28,8 +29,9 @@ class _BasicComponentSpec(types.ComponentSpec):
PARAMETERS = {}
INPUTS = {}
OUTPUTS = {
'examples':
component_spec.ChannelParameter(type=standard_artifacts.Examples)
"examples": component_spec.ChannelParameter(
type=standard_artifacts.Examples
)
}


Expand All @@ -41,30 +43,35 @@ class _BasicComponent(base_component.BaseComponent):
def __init__(self, component_spec_args):
super().__init__(_BasicComponentSpec(**component_spec_args))

_COMPONENT_SPEC_ARGS = {
"examples": channel.Channel(standard_artifacts.Examples)
}


class UtilsTest(tf.test.TestCase):

def test_execution_options(self):
component = _BasicComponent(component_spec_args={
'examples': channel.Channel(standard_artifacts.Examples)
})
component = _BasicComponent(component_spec_args=_COMPONENT_SPEC_ARGS)
component.node_execution_options = utils.NodeExecutionOptions(
trigger_strategy=pipeline_pb2.NodeExecutionOptions
.ALL_UPSTREAM_NODES_COMPLETED,
trigger_strategy=pipeline_pb2.NodeExecutionOptions.ALL_UPSTREAM_NODES_COMPLETED,
success_optional=True,
max_execution_retries=-1,
execution_timeout_sec=100)
execution_timeout_sec=100,
lifetime_start="foo",
)
self.assertEqual(
component.node_execution_options,
utils.NodeExecutionOptions(
trigger_strategy=pipeline_pb2.NodeExecutionOptions
.ALL_UPSTREAM_NODES_COMPLETED,
trigger_strategy=pipeline_pb2.NodeExecutionOptions.ALL_UPSTREAM_NODES_COMPLETED,
success_optional=True,
max_execution_retries=0,
execution_timeout_sec=100))
execution_timeout_sec=100,
lifetime_start="foo",
),
)
component.node_execution_options = None
self.assertIsNone(component.node_execution_options)


if __name__ == '__main__':
if __name__ == "__main__":
tf.test.main()

0 comments on commit ceea6bd

Please sign in to comment.