Skip to content

Commit

Permalink
Starting to check in IR-based execution stack. When finished, it will…
Browse files Browse the repository at this point in the history
… realize the IR-based execution workflow introduced in the TFX IR RFC: tensorflow/community#271

PiperOrigin-RevId: 325953252
  • Loading branch information
tfx-copybara committed Aug 11, 2020
1 parent 55154bc commit e1f7326
Show file tree
Hide file tree
Showing 36 changed files with 4,679 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tfx/orchestration/portable/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# IR-based portable execution stack

This module is WIP. When finished, it will realize the IR-based execution
workflow introduced in the TFX IR
[RFC](https://github.com/tensorflow/community/pull/271).
14 changes: 14 additions & 0 deletions tfx/orchestration/portable/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Lint as: python2, python3
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
99 changes: 99 additions & 0 deletions tfx/orchestration/portable/base_executor_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class to define how to operator an executor."""

import abc
from typing import Any, Dict, List, Optional, Text

import attr
import six
from tfx import types
from tfx.proto.orchestration import execution_result_pb2
from tfx.proto.orchestration import pipeline_pb2
from tfx.utils import abc_utils

from google.protobuf import message
from ml_metadata.proto import metadata_store_pb2


# TODO(b/150979622): We should introduce an id that is not changed across
# retries of the same component run and pass it to executor operators for
# human-readability purposes.
@attr.s(auto_attribs=True)
class ExecutionInfo:
  """A struct to store information for an execution.

  Every field defaults to None, so each annotation is spelled as Optional to
  match the value a freshly constructed instance actually holds.
  """
  # The metadata of this execution that is registered in MLMD.
  execution_metadata: Optional[metadata_store_pb2.Execution] = None
  # The input map to feed to the executor.
  input_dict: Optional[Dict[Text, List[types.Artifact]]] = None
  # The output map to feed to the executor.
  output_dict: Optional[Dict[Text, List[types.Artifact]]] = None
  # The exec_properties to feed to the executor.
  exec_properties: Optional[Dict[Text, Any]] = None
  # The uri of the executor result. Executors and Launchers may not run in the
  # same process, so executors should use this uri to "return" ExecutorOutput
  # to the launcher.
  executor_output_uri: Optional[Text] = None
  # Stateful working dir will be deterministic given pipeline, node and run_id.
  # The typical use case is to restore a long-running executor's state after
  # eviction. For example, a Trainer can use this directory to store
  # checkpoints.
  stateful_working_dir: Optional[Text] = None
  # The config of this Node.
  pipeline_node: Optional[pipeline_pb2.PipelineNode] = None
  # The config of the pipeline that this node is running in.
  pipeline_info: Optional[pipeline_pb2.PipelineInfo] = None


class BaseExecutorOperator(six.with_metaclass(abc.ABCMeta, object)):
  """Abstract base for objects that know how to run an executor.

  Concrete operators declare the spec message types they accept through the
  two class attributes below and implement `run_executor`.
  """

  # Iterables of message types; the constructor turns each into a tuple and
  # uses it for an isinstance check on the corresponding spec.
  SUPPORTED_EXECUTOR_SPEC_TYPE = abc_utils.abstract_property()
  SUPPORTED_PLATFORM_SPEC_TYPE = abc_utils.abstract_property()

  def __init__(self,
               executor_spec: message.Message,
               platform_spec: Optional[message.Message] = None):
    """Validates and stores the executor and platform specs.

    Args:
      executor_spec: The specification of how to initialize the executor.
      platform_spec: The specification of how to allocate resource for the
        executor. Only validated when it is truthy.

    Raises:
      RuntimeError: if the executor_spec or platform_spec is not supported.
    """
    accepted_executor_specs = tuple(self.SUPPORTED_EXECUTOR_SPEC_TYPE)
    if not isinstance(executor_spec, accepted_executor_specs):
      raise RuntimeError('Executor spec not supported: %s' % executor_spec)

    if platform_spec:
      accepted_platform_specs = tuple(self.SUPPORTED_PLATFORM_SPEC_TYPE)
      if not isinstance(platform_spec, accepted_platform_specs):
        raise RuntimeError('Platform spec not supported: %s' % platform_spec)

    self._executor_spec = executor_spec
    self._platform_spec = platform_spec

  @abc.abstractmethod
  def run_executor(
      self,
      execution_info: ExecutionInfo,
  ) -> execution_result_pb2.ExecutorOutput:
    """Invokes the executor given input from the Launcher.

    Args:
      execution_info: A wrapper of the details of this execution.

    Returns:
      The output from the executor.
    """
134 changes: 134 additions & 0 deletions tfx/orchestration/portable/beam_dag_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definition of Beam TFX runner."""

import os
from typing import Any, Iterable

from absl import logging
import apache_beam as beam
from tfx.orchestration import metadata
from tfx.orchestration.portable import launcher
from tfx.orchestration.portable import tfx_runner
from tfx.proto.orchestration import pipeline_pb2
from tfx.utils import telemetry_utils

from ml_metadata.proto import metadata_store_pb2


# TODO(jyzhao): confirm it's re-executable, add test case.
@beam.typehints.with_input_types(Any)
@beam.typehints.with_output_types(Any)
class _PipelineNodeAsDoFn(beam.DoFn):
  """Beam DoFn that launches one pipeline node each time it is triggered."""

  def __init__(self,
               pipeline_node: pipeline_pb2.PipelineNode,
               mlmd_connection: metadata.Metadata,
               pipeline_info: pipeline_pb2.PipelineInfo,
               pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec):
    """Initializes the _PipelineNodeAsDoFn.

    Args:
      pipeline_node: The specification of the node that this launcher launches.
      mlmd_connection: ML metadata connection. The connection is expected to
        not be opened before launcher is initiated.
      pipeline_info: The information of the pipeline that this node runs in.
      pipeline_runtime_spec: The runtime information of the pipeline that this
        node runs in.
    """
    # Kept separately from the launcher because it is only used in log lines.
    self._component_id = pipeline_node.node_info.id
    self._launcher = launcher.Launcher(
        pipeline_node=pipeline_node,
        mlmd_connection=mlmd_connection,
        pipeline_info=pipeline_info,
        pipeline_runtime_spec=pipeline_runtime_spec)

  def process(self, element: Any, *signals: Iterable[Any]) -> None:
    """Runs the component once every upstream signal is checked to be empty.

    Args:
      element: a signal element to trigger the component.
      *signals: side input signals indicate completeness of upstream components.
    """
    # Each signal is materialized in turn; the first non-empty one fails the
    # check, exactly as a per-signal assert would.
    assert not any(list(upstream_signal) for upstream_signal in signals), (
        'Signal PCollection should be empty.')
    self._run_component()

  def _run_component(self) -> None:
    """Launches the node, logging before and after the launch."""
    logging.info('Component %s is running.', self._component_id)
    self._launcher.launch()
    logging.info('Component %s is finished.', self._component_id)


class BeamDagRunner(tfx_runner.TfxRunner):
  """TFX runner that executes an IR pipeline as a Beam pipeline."""

  def __init__(self):
    """Initializes BeamDagRunner as a TFX orchestrator."""

  def run(self, pipeline: pipeline_pb2.Pipeline) -> None:
    """Deploys given logical pipeline on Beam.

    Each node becomes one ParDo step that is triggered from a shared root
    PCollection and takes the output PCollections of its upstream nodes as
    side inputs.

    Args:
      pipeline: Logical pipeline in IR format.
    """
    # For CLI, while creating or updating pipeline, pipeline_args are extracted
    # and hence we avoid deploying the pipeline.
    if 'TFX_JSON_EXPORT_PIPELINE_ARGS_PATH' in os.environ:
      return

    # TODO(b/163003901): Support beam DAG runner args through IR.
    # TODO(b/163003901): MLMD connection config should be passed in via IR.
    mlmd_config = metadata_store_pb2.ConnectionConfig()
    mlmd_config.sqlite.SetInParent()
    mlmd_connection = metadata.Metadata(connection_config=mlmd_config)

    runner_labels = {telemetry_utils.LABEL_TFX_RUNNER: 'beam'}
    with telemetry_utils.scoped_labels(runner_labels), beam.Pipeline() as p:
      # Used to trigger the component DoFns.
      root = p | 'CreateRoot' >> beam.Create([None])

      # Maps each already-scheduled node id to the PCollection it produces.
      done_signals = {}
      # pipeline.nodes are expected to be in topological order; the assert
      # below fails otherwise.
      for node in pipeline.nodes:
        # TODO(b/160882349): Support subpipeline
        node_spec = node.pipeline_node
        node_id = node_spec.node_info.id

        # Signals from upstream components.
        upstream_signals = []
        for upstream_id in node_spec.upstream_nodes:
          assert upstream_id in done_signals, (
              'Components is not in topological order')
          upstream_signals.append(done_signals[upstream_id])
        logging.info('Component %s depends on %s.', node_id,
                     [s.producer.full_label for s in upstream_signals])

        # Each signal is an empty PCollection. AsIter ensures component will
        # be triggered after upstream components are finished.
        # LINT.IfChange
        node_do_fn = _PipelineNodeAsDoFn(node_spec, mlmd_connection,
                                         pipeline.pipeline_info,
                                         pipeline.runtime_spec)
        side_inputs = [beam.pvalue.AsIter(s) for s in upstream_signals]
        # Parenthesized label only for readability: `%` and `>>` already bind
        # tighter than `|`.
        done_signals[node_id] = (
            root
            | ('Run[%s]' % node_id) >> beam.ParDo(node_do_fn, *side_inputs))
        # LINT.ThenChange(//tfx/orchestration/beam/beam_dag_runner.py)
        logging.info('Component %s is scheduled.', node_id)
64 changes: 64 additions & 0 deletions tfx/orchestration/portable/beam_dag_runner_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx.orchestration.portable.beam_dag_runner."""

import mock
import tensorflow as tf
from tfx.orchestration import metadata
from tfx.orchestration.portable import beam_dag_runner
from tfx.orchestration.portable import test_utils
from tfx.proto.orchestration import pipeline_pb2


_executed_components = []


# TODO(b/162980675): When PythonExecutorOperator is implemented, we won't
# need to fake the whole _PipelineNodeAsDoFn. Instead, just fake or mock
# executors.
class _FakeComponentAsDoFn(beam_dag_runner._PipelineNodeAsDoFn):
  """Test double that records node ids instead of launching the nodes."""

  def __init__(self,
               pipeline_node: pipeline_pb2.PipelineNode,
               mlmd_connection: metadata.Metadata,
               pipeline_info: pipeline_pb2.PipelineInfo,
               pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec):
    """Stores only the node id; the other arguments are ignored.

    The parent constructor is not called, so no launcher is created.
    """
    self._component_id = pipeline_node.node_info.id

  def _run_component(self):
    """Appends this node's id to the module-level execution record."""
    _executed_components.append(self._component_id)


class BeamDagRunnerTest(test_utils.TfxTest):
  """Checks that BeamDagRunner triggers every node in the expected order."""

  def setUp(self):
    """Resets the shared execution record and loads the pipeline under test."""
    super().setUp()
    # The record is a module-level list that the fake DoFn appends to. Clear it
    # in place so the assertion below does not depend on test order or on how
    # many times the suite has run in this process.
    _executed_components.clear()
    # Setup pipelines
    self._pipeline = pipeline_pb2.Pipeline()
    self.load_proto_from_text('pipeline_for_launcher_test.pbtxt',
                              self._pipeline)

  @mock.patch.multiple(
      beam_dag_runner,
      _PipelineNodeAsDoFn=_FakeComponentAsDoFn,
  )
  def testRun(self):
    """Runs the pipeline with the fake DoFn and checks the execution order."""
    beam_dag_runner.BeamDagRunner().run(self._pipeline)
    self.assertEqual(_executed_components, [
        'my_example_gen', 'my_transform', 'my_trainer'
    ])


if __name__ == '__main__':
tf.test.main()
Loading

0 comments on commit e1f7326

Please sign in to comment.