Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cherry-pick pr from apache/tvm #8464 and #8454 #221

Merged
merged 2 commits into from
Jul 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 201 additions & 5 deletions python/tvm/relay/frontend/tensorflow2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except, too-many-nested-blocks
"""Tensorflow2.x graph to relay converter.

If model is constructed using tf2.x API, then use this converter:
Expand All @@ -38,12 +38,20 @@
from .common import infer_type as _infer_type

from .tensorflow_ops import _convert_map as _convert_map_common
from .tensorflow_ops import _need_prelude_for_shape_inference
from .tensorflow_ops import _get_more_static_shape_rank
from .tensorflow2_ops import _convert_map as _convert_map_tf2
from .tensorflow2_ops import _need_prelude_for_shape_inference

from ..ty import Any

__all__ = ["from_tensorflow"]

# A map to record tensor list write ops and input tl/tensor indices
# Value is (index of tensor list, index of written node)
_tensor_list_write_ops = {
"TensorListSetItem": (0, 2),
}


def _infer_type_with_prelude(val, prelude):
body = _infer_type(val, prelude.mod)
Expand All @@ -66,6 +74,11 @@ def set_span(sym, node_name):
return sym


def is_tensor_list_constuctor(tf_node):
"""Check whether is tensor list constructor node."""
return tf_node.op == "TensorListReserve"


def convert_const_node(node, shape):
"""convert tf const node into relay const or var"""

Expand Down Expand Up @@ -196,6 +209,10 @@ def __init__(self, module):
self._output_shapes = {}
self._tf_node_map = {}
self._gdef_lib = {}
self._tensor_list_shapes = {}
self._tensor_list_shape_nodes = {}
self._sub_map = {}
self._sub_input_idx_map = {}

def from_tensorflow(
self, graph, layout="NHWC", shape=None, outputs=None, input_types=None, gdef_lib=None
Expand All @@ -215,10 +232,134 @@ def from_tensorflow(
)
return func, self._params

def _analysis_tensor_list_op(
self,
graph,
node,
tl_write_nodes,
tl_stack_nodes,
tl_construct_nodes,
sub_func_name="",
root_node="",
):
if sub_func_name and sub_func_name not in self._sub_input_idx_map:
self._sub_input_idx_map[sub_func_name] = {}

if node.op == "Placeholder":
# record placeholder node in sub functions
self._sub_map[sub_func_name] = node
self._sub_input_idx_map[sub_func_name][node.name] = len(
self._sub_input_idx_map[sub_func_name]
)

if node.op.startswith("TensorList"):
if is_tensor_list_constuctor(node):
tl_construct_nodes.append(node)
else:
for tl_write_name, idx in _tensor_list_write_ops.items():
if node.op.startswith(tl_write_name):
tl_write_nodes.append((node, idx, sub_func_name, root_node))
if node.op.startswith("TensorListStack"):
tl_stack_nodes.append(node)
elif node.op.startswith("StatelessWhile"):
root_node = node.name
cond_fn_name, body_fn_name = [
parse_attr(node.attr).get(x).name for x in ["cond", "body"]
]
for fn_name in [cond_fn_name, body_fn_name]:
subfunction = self._gdef_lib[fn_name]
sub_func_name = fn_name
for sub_node in subfunction.node:
# bypass const node
if sub_node.op == "Const":
continue
self._tf_node_map[sub_node.name] = sub_node
self._analysis_tensor_list_op(
subfunction,
sub_node,
tl_write_nodes,
tl_stack_nodes,
tl_construct_nodes,
sub_func_name=sub_func_name,
root_node=root_node,
)

def _infer_static_shape_stack_node(self, tl_stack_nodes):
for stack_node in tl_stack_nodes:
if len(stack_node.input) < 2:
# Stack node does not have shape
continue
input_shape_name = stack_node.input[1].split(":")[0]
input_shape_node = self._tf_node_map[input_shape_name]
stack = [self._tf_node_map[stack_node.input[0].split(":")[0]]]
in_idx = -1
while stack:
cnode = stack.pop(0)
if not cnode.op.startswith("TensorList"):
if in_idx and cnode.op.startswith("StatelessWhile"):
stack.append(self._tf_node_map[cnode.input[in_idx].split(":")[0]])
else:
for iname in cnode.input:
if self._tf_node_map[iname.split(":")[0]].op.startswith(
"StatelessWhile"
):
# identify input index based on output index
if iname.split(":")[1]:
in_idx = int(iname.split(":")[1])
stack.append(self._tf_node_map[iname.split(":")[0]])
# identify the corresponding constructor node and add shape to _tensor_list_shapes
elif cnode.name != stack_node.name:
if is_tensor_list_constuctor(cnode):
shape_attr = parse_attr(input_shape_node.attr)
if "value" not in shape_attr:
continue
raw_elem_shape = tensor_util.MakeNdarray(shape_attr["value"])
elem_shape = []
for dim in raw_elem_shape:
if dim < 0:
elem_shape.append(Any())
else:
elem_shape.append(int(dim))
self._tensor_list_shapes[cnode.name] = elem_shape
break

def _infer_static_shape_write_node(self, tl_write_nodes):
for item in tl_write_nodes:
wnode = item[0]
ta_idx, inode_idx = item[1]
sub_func_name = item[2]
root_name = item[3]
stack = [self._tf_node_map[wnode.input[ta_idx].split(":")[0]]]
while stack:
cnode = stack.pop(0)

if not cnode.op.startswith("TensorList"):
if cnode.op == "Placeholder" and sub_func_name:
# need to map subfunction
input_idx = self._sub_input_idx_map[sub_func_name][cnode.name]
stack.append(
self._tf_node_map[
self._tf_node_map[root_name].input[input_idx].split(":")[0]
]
)
else:
for iname in cnode.input:
stack.append(self._tf_node_map[iname.split(":")[0]])
# identify the corresponding constructor node and add it to _tensor_list_shape_nodes
elif cnode.name != wnode.name:
if is_tensor_list_constuctor(cnode):
inode = self._tf_node_map[wnode.input[inode_idx].split(":")[0]]
tn = wnode.input[inode_idx].split(":")
output_index = int(tn[1]) if len(tn) > 1 else 0
self._tensor_list_shape_nodes[cnode.name] = (inode, wnode.op, output_index)
break

def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_types=None):
if input_types is None:
input_types = {}

tl_write_nodes = []
tl_stack_nodes = []
tl_construct_nodes = []
self._layout = layout
for node in graph.node:
name = node.name
Expand All @@ -235,6 +376,18 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_
self._nodes[node.name] = sym
if param:
self._params[node.name] = param
# recursivly iterate tensorlist op if seen while loop
else:
self._analysis_tensor_list_op(
graph, node, tl_write_nodes, tl_stack_nodes, tl_construct_nodes
)

# Use tensor list stack to infer static tensor list shape
self._infer_static_shape_stack_node(tl_stack_nodes)

# Fetch node contains static tensor list shape
self._infer_static_shape_write_node(tl_write_nodes)

for node in graph.node:
self._backtrack_construct(graph, node.name)

Expand Down Expand Up @@ -321,16 +474,36 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs):
gdef_lib=self._gdef_lib,
)
elif op_name in _convert_map_common:
# assert op are exclusive
assert not set(_convert_map_common.keys()) & set(_convert_map_tf2.keys())
if _need_prelude_for_shape_inference(op_name):
sym = _convert_map_common[op_name](inputs, attrs, self._params, self._prelude)
else:
sym = _convert_map_common[op_name](inputs, attrs, self._params, self._module.mod)
elif op_name in _convert_map_tf2:
if _need_prelude_for_shape_inference(op_name):
sym = _convert_map_tf2[op_name](inputs, attrs, self._params, self._prelude)
else:
sym = _convert_map_tf2[op_name](inputs, attrs, self._params, self._module.mod)
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))

sym = set_span(sym, node_name)
return sym

def _parse_element_shape(self, elem_shape, shape_attr):
if "value" in shape_attr:
raw_elem_shape = tensor_util.MakeNdarray(shape_attr["value"])

if raw_elem_shape.size == 1 and raw_elem_shape == -1:
elem_shape.append(Any())
else:
for dim in raw_elem_shape:
if dim < 0:
elem_shape.append(Any())
else:
elem_shape.append(dim)

def _backtrack_construct(self, graph, node_name):
"""Convert a specific tensorflow node to relay expression.

Expand Down Expand Up @@ -370,8 +543,8 @@ def _backtrack_construct(self, graph, node_name):
CallNode(Op(add), [Var(x, ty=TensorType([], float32)), Constant(1.0)], (nullptr), [])

"""

input_op_name = node_name.split(":")[0].split("^")[-1]

if input_op_name not in self._nodes:
node = self._tf_node_map[input_op_name]
attr = parse_attr(node.attr)
Expand All @@ -386,8 +559,31 @@ def _backtrack_construct(self, graph, node_name):
attr["_node_name"] = node.name
attr["_target_layout"] = self._layout
inputs = [self._backtrack_construct(graph, iname) for iname in node.input]
op = self._convert_operator(graph, node.op, node.name, inputs, attr)

# infer shape for TensorList op
if is_tensor_list_constuctor(node):
input_shape_name = (
node.input[1] if "TensorListFromTensor" in node.op else node.input[0]
)
input_shape_name = input_shape_name.split(":")[0]
input_shape_node = self._tf_node_map[input_shape_name]
shape_attr = parse_attr(input_shape_node.attr)
elem_shape = []

self._parse_element_shape(elem_shape, shape_attr)

if elem_shape:
attr["shape"] = elem_shape
if (
"identical_element_shapes" in attr and attr["identical_element_shapes"]
) or elem_shape:
shape = elem_shape
if node.name in self._tensor_list_shapes:
preset_shape = self._tensor_list_shapes[node.name]
shape = _get_more_static_shape_rank(shape, preset_shape)
attr["shape"] = shape

op = self._convert_operator(graph, node.op, node.name, inputs, attr)
if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op)
op = [
Expand Down
Loading