From 30eb5c7d79787473c57f986b45798e4c02c405cd Mon Sep 17 00:00:00 2001 From: Wang Date: Mon, 4 Jun 2018 12:06:38 -0700 Subject: [PATCH] Address comments --- nnvm/python/nnvm/compiler/build_module.py | 6 +++--- nnvm/python/nnvm/compiler/graph_util.py | 3 +++ nnvm/src/top/vision/nms.cc | 10 +++------- nnvm/src/top/vision/ssd/mutibox_op.cc | 12 ++++++------ 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py index dcaa258a0f728..bac92028f24f8 100644 --- a/nnvm/python/nnvm/compiler/build_module.py +++ b/nnvm/python/nnvm/compiler/build_module.py @@ -331,9 +331,9 @@ def build(graph, target=None, shape=None, dtype="float32", if params is None: params = {} params.update(init_var) - if not build_extra: - return graph, libmod, params - return graph, libmod, params, extra_lib + if build_extra: + return graph, libmod, params, extra_lib + return graph, libmod, params def _run_graph(graph, params): diff --git a/nnvm/python/nnvm/compiler/graph_util.py b/nnvm/python/nnvm/compiler/graph_util.py index 621872ead98b5..3b2915b38b6dc 100644 --- a/nnvm/python/nnvm/compiler/graph_util.py +++ b/nnvm/python/nnvm/compiler/graph_util.py @@ -167,6 +167,9 @@ def split_last_op(graph): """ graph_idx = graph.index last_op_node = graph_idx.nodes[-1] + if last_op_node["op"] == "null": + raise RuntimeError("split_last_op doesn't support sast operator " + "to be null.") last_op_func = getattr(sym, last_op_node["op"]) if "attrs" in last_op_node: last_op_attr = last_op_node["attrs"] diff --git a/nnvm/src/top/vision/nms.cc b/nnvm/src/top/vision/nms.cc index 4a6723222b21b..2680b894255b4 100644 --- a/nnvm/src/top/vision/nms.cc +++ b/nnvm/src/top/vision/nms.cc @@ -27,17 +27,13 @@ bool NMSShape(const NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 2U) << "Inputs: [data, valid_count]"; TShape dshape = in_attrs->at(0); TShape vshape = in_attrs->at(1); - CHECK_EQ(dshape.ndim(), 3U) << "Provided: " << dshape; - CHECK_EQ(vshape.ndim(), 1U) << "Provided: " << vshape; + CHECK_EQ(dshape.ndim(), 3U) << "Input data should be 3-D."; + CHECK_EQ(vshape.ndim(), 1U) << "Input valid count should be 1-D."; CHECK_EQ(dshape[2], 6U) << "Data input should have shape " "(batch_size, num_anchors, 6)."; CHECK_EQ(dshape[0], vshape[0]) << "batch_size mismatch."; - TShape oshape = TShape(3); - oshape[0] = dshape[0]; - oshape[1] = dshape[1]; - oshape[2] = 6; // [id, prob, xmin, ymin, xmax, ymax] out_attrs->clear(); - NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape); + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, dshape); return true; } diff --git a/nnvm/src/top/vision/ssd/mutibox_op.cc b/nnvm/src/top/vision/ssd/mutibox_op.cc index 577657ecb304c..d02ae802c636f 100644 --- a/nnvm/src/top/vision/ssd/mutibox_op.cc +++ b/nnvm/src/top/vision/ssd/mutibox_op.cc @@ -91,12 +91,12 @@ bool MultiBoxDetectionShape(const NodeAttrs& attrs, TShape cshape = in_attrs->at(0); TShape lshape = in_attrs->at(1); TShape ashape = in_attrs->at(2); - CHECK_EQ(cshape.ndim(), 3U) << "Provided: " << cshape; - CHECK_EQ(lshape.ndim(), 2U) << "Provided: " << lshape; - CHECK_EQ(ashape.ndim(), 3U) << "Provided: " << ashape; - CHECK_EQ(cshape[2], ashape[1]) << "Number of anchors mismatch"; - CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc"; - CHECK_GT(ashape[1], 0U) << "Number of anchors must > 0"; + CHECK_EQ(cshape.ndim(), 3U) << "Class probability should be 3-D."; + CHECK_EQ(lshape.ndim(), 2U) << "Location prediction should be 2-D."; + CHECK_EQ(ashape.ndim(), 3U) << "Anchor should be 3-D."; + CHECK_EQ(cshape[2], ashape[1]) << "Number of anchors mismatch."; + CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc."; + CHECK_GT(ashape[1], 0U) << "Number of anchors must > 0."; CHECK_EQ(ashape[2], 4U); TShape oshape = TShape(3); oshape[0] = cshape[0];