From e6f30dbd9515f005c503c4a4bbbad12a06ca4902 Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Sat, 8 Sep 2018 22:42:42 +0530 Subject: [PATCH] [FRONTEND][TENSORFLOW] Helper function to add shapes into the graph. Use tmp folder for model files and clean it. --- nnvm/python/nnvm/testing/tf.py | 35 +++++++++++++++++++++++++++++-- tutorials/nnvm/from_tensorflow.py | 4 ++-- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/nnvm/python/nnvm/testing/tf.py b/nnvm/python/nnvm/testing/tf.py index 0372d7450586..f5b49b2280b4 100644 --- a/nnvm/python/nnvm/testing/tf.py +++ b/nnvm/python/nnvm/testing/tf.py @@ -8,6 +8,7 @@ import os.path import collections import numpy as np +from tvm.contrib import util # Tensorflow imports import tensorflow as tf @@ -43,6 +44,31 @@ def ProcessGraphDefParam(graph_def): raise TypeError('graph_def must be a GraphDef proto.') return graph_def + +def AddShapesToGraphDef(out_node): + """ Add shapes attribute to nodes of the graph. + Input graph here is the default graph in context. + + Parameters + ---------- + out_node: String + Final output node of the graph. + + Returns + ------- + graph_def : Obj + tensorflow graph definition with shapes attribute added to nodes. + + """ + + with tf.Session() as sess: + graph_def = tf.graph_util.convert_variables_to_constants( + sess, + sess.graph.as_graph_def(add_shapes=True), + [out_node], + ) + return graph_def + class NodeLookup(object): """Converts integer node ID's to human readable labels.""" @@ -128,13 +154,18 @@ def get_workload(model_path): model_url = os.path.join(repo_base, model_path) from mxnet.gluon.utils import download - download(model_url, model_name) + + temp = util.tempdir() + path_model = temp.relpath(model_name) + + download(model_url, path_model) # Creates graph from saved graph_def.pb. - with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f: + with tf.gfile.FastGFile(path_model, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) graph = tf.import_graph_def(graph_def, name='') + temp.remove() return graph_def ####################################################################### diff --git a/tutorials/nnvm/from_tensorflow.py b/tutorials/nnvm/from_tensorflow.py index ee025c5b09ff..db6cd0a43654 100644 --- a/tutorials/nnvm/from_tensorflow.py +++ b/tutorials/nnvm/from_tensorflow.py @@ -64,7 +64,6 @@ download(map_proto_url, map_proto) download(lable_map_url, lable_map) - ###################################################################### # Import model # ------------ @@ -76,7 +75,8 @@ graph = tf.import_graph_def(graph_def, name='') # Call the utility to import the graph definition into default graph. graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) - + # Add shapes to the graph. + graph_def = nnvm.testing.tf.AddShapesToGraphDef('softmax') ###################################################################### # Decode image