Skip to content

Commit

Permalink
Merge pull request apache#6 from prashantsail/ps_spop
Browse files Browse the repository at this point in the history
merging changes required to get parent subgraph for processing.
  • Loading branch information
deepakbabel23 authored May 15, 2020
2 parents b77cece + 14982af commit 4b396c0
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 47 deletions.
25 changes: 15 additions & 10 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1931,30 +1931,37 @@ def _impl(inputs, attr, params, mod):
return _impl

def _partitioned_call():
def _impl(inputs, attr, params, mod, graph):
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import ops

def _impl(inputs, attr, params, mod):
node_func_name = attr.get('f').name
func = next((f for f in graph.library.function if f.signature.name == node_func_name), None)
if func:
from tensorflow.python.framework import function_def_to_graph

outer_graph = ops.get_default_graph()
outer_graph_def = outer_graph.as_graph_def(add_shapes=True)

func = next((f for f in outer_graph_def.library.function if f.signature.name == node_func_name), None)
if func:
# Convert function definition to graph
func_input_shapes = func.attr["_input_shapes"].list.shape
subgraph, flat_tensor_name = function_def_to_graph.function_def_to_graph_def(func, func_input_shapes)
subgraph = function_def_to_graph.function_def_to_graph(func, func_input_shapes)
subgraph_def = subgraph.as_graph_def(add_shapes=True)

# Computing subgraph's input shape dictionary
subgraph_shape_dict = {}
for f_arg, input in zip(func.signature.input_arg, inputs):
subgraph_shape_dict[f_arg.name] = _infer_shape(input)

# Construct relay nodes from the subgraph
# Construct relay nodes from the subgraph_def
g = GraphProto()
mod, params = g.from_tensorflow(subgraph, shape=subgraph_shape_dict)
mod, params = g.from_tensorflow(subgraph_def, shape=subgraph_shape_dict)
wl = tvm.relay.var('partitioned_call')
sb = tvm.relay.scope_builder.ScopeBuilder()
sb.let(wl, mod["main"])
sb.ret(wl(*inputs))
op = sb.get()
return op

return _impl

def _stateful_partitioned_call():
Expand Down Expand Up @@ -3222,9 +3229,7 @@ def _convert_operator(self, op_name, inputs, attrs,
if op_name in identity_list:
sym = get_relay_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
if op_name in ["PartitionedCall", "StatefulPartitionedCall"]:
sym = convert_map[op_name](inputs, attrs, self._params, self._mod, self._graph)
elif _need_prelude_for_shape_inference(op_name):
if _need_prelude_for_shape_inference(op_name):
sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
else:
sym = convert_map[op_name](inputs, attrs, self._params, self._mod)
Expand Down
83 changes: 46 additions & 37 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3168,9 +3168,7 @@ def test_forward_isfinite():
_verify_infiniteness_ops(tf.is_finite, "isfinite")

def _test_spop_placeholder_one():
tf.reset_default_graph()
g = tf.Graph()
with g.as_default():
with tf.Graph().as_default():

@function.Defun(*[tf.int32]*2)
def Forward(x,y):
Expand All @@ -3190,7 +3188,6 @@ def Forward(x,y):
['StatefulPartitionedCall:0',z2.name], mode='vm', init_global_variables=True)

def _test_spop_placeholder_two():

with tf.Graph().as_default():
data = np.ones([1], dtype=int).astype(np.int32)
dataVar = tf.Variable(data, shape=data.shape)
Expand All @@ -3205,37 +3202,34 @@ def pl_with_default(pl):
compare_tf_with_tvm(data, ['pl1:0'], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True)

def _test_spop_placeholder_three():
tf.disable_eager_execution()
t1 = tf.placeholder(tf.int32, (3, 3, 3), "t1")
t1_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3))
t2 = tf.placeholder(tf.int32, (3, 3, 3), "t2")
t2_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3))
with tf.Graph().as_default():
t1 = tf.placeholder(tf.int32, (3, 3, 3), "t1")
t1_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3))
t2 = tf.placeholder(tf.int32, (3, 3, 3), "t2")
t2_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3))

@tf.function
def add(x, y):
return tf.add(x, y, "add_t1_t2")
@tf.function
def add(x, y):
return tf.add(x, y, "add_t1_t2")

t3 = add(t1, t2)
compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [t3.name], mode='vm', init_global_variables=True)
t3 = add(t1, t2)
compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [t3.name], mode='vm', init_global_variables=True)

def _test_spop_placeholder_four():
tf.disable_eager_execution()
t1_data = np.array([[-1, 1, 3], [2, -2, 4], [2, -3, 14]], dtype=np.int32)
t2_data = np.array([[-2, 1, 2], [12, -2, 14], [12, -3, 4]], dtype=np.int32)
tf.reset_default_graph()
t1 = tf.placeholder(tf.int32, name="t1")
t2 = tf.placeholder(tf.int32, name="t2")
with tf.Graph().as_default():
t1_data = np.array([[-1, 1, 3], [2, -2, 4], [2, -3, 14]], dtype=np.int32)
t2_data = np.array([[-2, 1, 2], [12, -2, 14], [12, -3, 4]], dtype=np.int32)
t1 = tf.placeholder(tf.int32, name="t1")
t2 = tf.placeholder(tf.int32, name="t2")

@tf.function
def add(x, y):
return tf.add(x, y, "add_t1_t2")
@tf.function
def add(x, y):
return tf.add(x, y, "add_t1_t2")

t3 = add(t1, t2)
compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [t3.name], mode='vm', init_global_variables=True)
t3 = add(t1, t2)
compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [t3.name], mode='vm', init_global_variables=True)

def _test_spop_function_invocation_basic():
tf.disable_eager_execution()
tf.reset_default_graph()
with tf.Graph().as_default():

def fun1(a):
Expand All @@ -3255,9 +3249,30 @@ def fun3(x,y):

compare_tf_with_tvm([], [], [t3.name], mode='vm', init_global_variables=True)

def _test_spop_function_invocation_nested():
with tf.Graph().as_default():
t1 = tf.compat.v1.placeholder(tf.int32, (3, 3, 3), name="t1")
t1_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3))
t2 = tf.compat.v1.placeholder(tf.int32, name="t2")
t2_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3))

@tf.function
def myfunc(x, y):
return tf.add(x, y, "myfunc")

@tf.function
def myfunc2(x, y):
z = myfunc(x, y)
l = myfunc(z, y)
m = myfunc(l,z)
return tf.add(l, m, "myfunc2")

res1 = myfunc(t1, t2)
res2 = myfunc2(res1, t1)

compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [res2.name], mode='vm', init_global_variables=True)

def _test_spop_function_invocation_autograph():
tf.disable_eager_execution()
tf.reset_default_graph()
with tf.Graph().as_default():

@tf.function
Expand All @@ -3280,7 +3295,6 @@ def fun3(x,y):
compare_tf_with_tvm([], [], [t3.name], mode='vm', init_global_variables=True)

def _test_spop_function_invocation_defun():
tf.reset_default_graph()
with tf.Graph().as_default():

def fun1(a):
Expand All @@ -3301,7 +3315,6 @@ def fun3(x,y):
compare_tf_with_tvm([],[], 'SpopFnInvocation:0', mode='vm', init_global_variables=True)

def _test_spop_arithmetic():
tf.reset_default_graph()
with tf.Graph().as_default():
@function.Defun(*[dtypes.int32]*3)
def arithmetic(m,x,c):
Expand All @@ -3316,7 +3329,6 @@ def arithmetic(m,x,c):
compare_tf_with_tvm([],[],'StatefulPartitionedCall:0', mode='vm', init_global_variables=True)

def _test_spop_control_flow():
tf.reset_default_graph()
with tf.Graph().as_default():

@function.Defun(*[dtypes.float32] * 2)
Expand All @@ -3335,10 +3347,7 @@ def Body1(x, y):
compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True)

def _test_spop_variables():
tf.reset_default_graph()
g = tf.Graph()
with g.as_default():

with tf.Graph().as_default():
const1 = tf.constant(10)
const2 = tf.constant(20)
var1 = tf.Variable(const1, dtype=tf.int32)
Expand All @@ -3352,7 +3361,6 @@ def Forward(x,y):
compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', init_global_variables=True, mode="vm")

def _test_spop_constants():
tf.reset_default_graph()
with tf.Graph().as_default():
@function.Defun(*[dtypes.int32] * 2)
def constantsFn(x, y):
Expand All @@ -3373,6 +3381,7 @@ def _test_spop_placeholder():

def _test_spop_function_invocation():
_test_spop_function_invocation_basic()
_test_spop_function_invocation_nested()
_test_spop_function_invocation_autograph()
_test_spop_function_invocation_defun()

Expand Down

0 comments on commit 4b396c0

Please sign in to comment.