-
Notifications
You must be signed in to change notification settings - Fork 720
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Starting to check in IR-based execution stack. When finished, it will…
… realize the IR-based execution workflow introduced in the TFX IR RFC: tensorflow/community#271 PiperOrigin-RevId: 325953252
- Loading branch information
1 parent
55154bc
commit e1f7326
Showing
36 changed files
with
4,679 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# IR-based portable execution stack | ||
|
||
This module is WIP. When finished, it will realize the IR-based execution | ||
workflow introduced in the TFX IR | ||
[RFC](https://github.com/tensorflow/community/pull/271). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Lint as: python2, python3 | ||
# Copyright 2020 Google LLC. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright 2020 Google LLC. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Base class to define how to operator an executor.""" | ||
|
||
import abc | ||
from typing import Any, Dict, List, Optional, Text | ||
|
||
import attr | ||
import six | ||
from tfx import types | ||
from tfx.proto.orchestration import execution_result_pb2 | ||
from tfx.proto.orchestration import pipeline_pb2 | ||
from tfx.utils import abc_utils | ||
|
||
from google.protobuf import message | ||
from ml_metadata.proto import metadata_store_pb2 | ||
|
||
|
||
# TODO(b/150979622): We should introduce an id that is not changed across | ||
# retires of the same component run and pass it to executor operators for | ||
# human-readability purpose. | ||
@attr.s(auto_attribs=True) | ||
class ExecutionInfo: | ||
"""A struct to store information for an execution.""" | ||
# The metadata of this execution that is registered in MLMD. | ||
execution_metadata: metadata_store_pb2.Execution = None | ||
# The input map to feed to executor | ||
input_dict: Dict[Text, List[types.Artifact]] = None | ||
# The output map to feed to executor | ||
output_dict: Dict[Text, List[types.Artifact]] = None | ||
# The exec_properties to feed to executor | ||
exec_properties: Dict[Text, Any] = None | ||
# The uri to executor result, note that Executors and Launchers may not run | ||
# in the same process, so executors should use this uri to "return" | ||
# ExecutorOutput to the launcher. | ||
executor_output_uri: Text = None | ||
# Stateful working dir will be deterministic given pipeline, node and run_id. | ||
# The typical usecase is to restore long running executor's state after | ||
# eviction. For examples, a Trainer can use this directory to store | ||
# checkpoints. | ||
stateful_working_dir: Text = None | ||
# The config of this Node. | ||
pipeline_node: pipeline_pb2.PipelineNode = None | ||
# The config of the pipeline that this node is running in. | ||
pipeline_info: pipeline_pb2.PipelineInfo = None | ||
|
||
|
||
class BaseExecutorOperator(six.with_metaclass(abc.ABCMeta, object)): | ||
"""The base class of all executor operators.""" | ||
|
||
SUPPORTED_EXECUTOR_SPEC_TYPE = abc_utils.abstract_property() | ||
SUPPORTED_PLATFORM_SPEC_TYPE = abc_utils.abstract_property() | ||
|
||
def __init__(self, | ||
executor_spec: message.Message, | ||
platform_spec: Optional[message.Message] = None): | ||
"""Constructor. | ||
Args: | ||
executor_spec: The specification of how to initialize the executor. | ||
platform_spec: The specification of how to allocate resource for the | ||
executor. | ||
Raises: | ||
RuntimeError: if the executor_spec or platform_spec is not supported. | ||
""" | ||
if not isinstance(executor_spec, | ||
tuple(t for t in self.SUPPORTED_EXECUTOR_SPEC_TYPE)): | ||
raise RuntimeError('Executor spec not supported: %s' % executor_spec) | ||
if platform_spec and not isinstance( | ||
platform_spec, tuple(t for t in self.SUPPORTED_PLATFORM_SPEC_TYPE)): | ||
raise RuntimeError('Platform spec not supported: %s' % platform_spec) | ||
self._executor_spec = executor_spec | ||
self._platform_spec = platform_spec | ||
|
||
@abc.abstractmethod | ||
def run_executor( | ||
self, | ||
execution_info: ExecutionInfo, | ||
) -> execution_result_pb2.ExecutorOutput: | ||
"""Invokers executors given input from the Launcher. | ||
Args: | ||
execution_info: A wrapper of the details of this execution. | ||
Returns: | ||
The output from executor. | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# Copyright 2020 Google LLC. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Definition of Beam TFX runner.""" | ||
|
||
import os | ||
from typing import Any, Iterable | ||
|
||
from absl import logging | ||
import apache_beam as beam | ||
from tfx.orchestration import metadata | ||
from tfx.orchestration.portable import launcher | ||
from tfx.orchestration.portable import tfx_runner | ||
from tfx.proto.orchestration import pipeline_pb2 | ||
from tfx.utils import telemetry_utils | ||
|
||
from ml_metadata.proto import metadata_store_pb2 | ||
|
||
|
||
# TODO(jyzhao): confirm it's re-executable, add test case. | ||
@beam.typehints.with_input_types(Any) | ||
@beam.typehints.with_output_types(Any) | ||
class _PipelineNodeAsDoFn(beam.DoFn): | ||
"""Wrap component as beam DoFn.""" | ||
|
||
def __init__(self, | ||
pipeline_node: pipeline_pb2.PipelineNode, | ||
mlmd_connection: metadata.Metadata, | ||
pipeline_info: pipeline_pb2.PipelineInfo, | ||
pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec): | ||
"""Initializes the _PipelineNodeAsDoFn. | ||
Args: | ||
pipeline_node: The specification of the node that this launcher lauches. | ||
mlmd_connection: ML metadata connection. The connection is expected to | ||
not be opened before launcher is initiated. | ||
pipeline_info: The information of the pipeline that this node runs in. | ||
pipeline_runtime_spec: The runtime information of the pipeline that this | ||
node runs in. | ||
""" | ||
self._launcher = launcher.Launcher( | ||
pipeline_node=pipeline_node, | ||
mlmd_connection=mlmd_connection, | ||
pipeline_info=pipeline_info, | ||
pipeline_runtime_spec=pipeline_runtime_spec) | ||
self._component_id = pipeline_node.node_info.id | ||
|
||
def process(self, element: Any, *signals: Iterable[Any]) -> None: | ||
"""Executes component based on signals. | ||
Args: | ||
element: a signal element to trigger the component. | ||
*signals: side input signals indicate completeness of upstream components. | ||
""" | ||
for signal in signals: | ||
assert not list(signal), 'Signal PCollection should be empty.' | ||
self._run_component() | ||
|
||
def _run_component(self) -> None: | ||
logging.info('Component %s is running.', self._component_id) | ||
self._launcher.launch() | ||
logging.info('Component %s is finished.', self._component_id) | ||
|
||
|
||
class BeamDagRunner(tfx_runner.TfxRunner): | ||
"""Tfx runner on Beam.""" | ||
|
||
def __init__(self): | ||
"""Initializes BeamDagRunner as a TFX orchestrator. | ||
""" | ||
|
||
def run(self, pipeline: pipeline_pb2.Pipeline) -> None: | ||
"""Deploys given logical pipeline on Beam. | ||
Args: | ||
pipeline: Logical pipeline in IR format. | ||
""" | ||
# For CLI, while creating or updating pipeline, pipeline_args are extracted | ||
# and hence we avoid deploying the pipeline. | ||
if 'TFX_JSON_EXPORT_PIPELINE_ARGS_PATH' in os.environ: | ||
return | ||
|
||
# TODO(b/163003901): Support beam DAG runner args through IR. | ||
# TODO(b/163003901): MLMD connection config should be passed in via IR. | ||
connection_config = metadata_store_pb2.ConnectionConfig() | ||
connection_config.sqlite.SetInParent() | ||
mlmd_connection = metadata.Metadata( | ||
connection_config=connection_config) | ||
|
||
with telemetry_utils.scoped_labels( | ||
{telemetry_utils.LABEL_TFX_RUNNER: 'beam'}): | ||
with beam.Pipeline() as p: | ||
# Uses for triggering the component DoFns. | ||
root = p | 'CreateRoot' >> beam.Create([None]) | ||
|
||
# Stores mapping of component to its signal. | ||
signal_map = {} | ||
# pipeline.components are in topological order. | ||
for node in pipeline.nodes: | ||
# TODO(b/160882349): Support subpipeline | ||
pipeline_node = node.pipeline_node | ||
component_id = pipeline_node.node_info.id | ||
|
||
# Signals from upstream components. | ||
signals_to_wait = [] | ||
for upstream_node in pipeline_node.upstream_nodes: | ||
assert upstream_node in signal_map, ('Components is not in ' | ||
'topological order') | ||
signals_to_wait.append(signal_map[upstream_node]) | ||
logging.info('Component %s depends on %s.', component_id, | ||
[s.producer.full_label for s in signals_to_wait]) | ||
|
||
# Each signal is an empty PCollection. AsIter ensures component will | ||
# be triggered after upstream components are finished. | ||
# LINT.IfChange | ||
signal_map[component_id] = ( | ||
root | ||
| 'Run[%s]' % component_id >> beam.ParDo( | ||
_PipelineNodeAsDoFn(pipeline_node, mlmd_connection, | ||
pipeline.pipeline_info, | ||
pipeline.runtime_spec), * | ||
[beam.pvalue.AsIter(s) for s in signals_to_wait])) | ||
# LINT.ThenChange(//tfx/orchestration/beam/beam_dag_runner.py) | ||
logging.info('Component %s is scheduled.', component_id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright 2020 Google LLC. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Tests for tfx.orchestration.portable.beam_dag_runner.""" | ||
|
||
import mock | ||
import tensorflow as tf | ||
from tfx.orchestration import metadata | ||
from tfx.orchestration.portable import beam_dag_runner | ||
from tfx.orchestration.portable import test_utils | ||
from tfx.proto.orchestration import pipeline_pb2 | ||
|
||
|
||
_executed_components = [] | ||
|
||
|
||
# TODO(b/162980675): When PythonExecutorOperator is implemented. We don't | ||
# Need to Fake the whole FakeComponentAsDoFn. Instead, just fake or mock | ||
# executors. | ||
class _FakeComponentAsDoFn(beam_dag_runner._PipelineNodeAsDoFn): | ||
|
||
def __init__(self, | ||
pipeline_node: pipeline_pb2.PipelineNode, | ||
mlmd_connection: metadata.Metadata, | ||
pipeline_info: pipeline_pb2.PipelineInfo, | ||
pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec): | ||
self._component_id = pipeline_node.node_info.id | ||
|
||
def _run_component(self): | ||
_executed_components.append(self._component_id) | ||
|
||
|
||
class BeamDagRunnerTest(test_utils.TfxTest): | ||
|
||
def setUp(self): | ||
super(BeamDagRunnerTest, self).setUp() | ||
# Setup pipelines | ||
self._pipeline = pipeline_pb2.Pipeline() | ||
self.load_proto_from_text('pipeline_for_launcher_test.pbtxt', | ||
self._pipeline) | ||
|
||
@mock.patch.multiple( | ||
beam_dag_runner, | ||
_PipelineNodeAsDoFn=_FakeComponentAsDoFn, | ||
) | ||
def testRun(self): | ||
beam_dag_runner.BeamDagRunner().run(self._pipeline) | ||
self.assertEqual(_executed_components, [ | ||
'my_example_gen', 'my_transform', 'my_trainer' | ||
]) | ||
|
||
|
||
if __name__ == '__main__': | ||
tf.test.main() |
Oops, something went wrong.