Skip to content

Commit

Permalink
[BugFix][Runtime] Fix Incorrect node information (#13693)
Browse files Browse the repository at this point in the history
* [BugFix][Runtime] Fix Incorrect node information

* 1

* 1
  • Loading branch information
zhaojinxi authored Jan 5, 2023
1 parent 048028b commit 721f115
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 11 deletions.
24 changes: 14 additions & 10 deletions python/tvm/contrib/debugger/debug_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,25 @@ def _update_graph_json(self):
"""update the nodes_list with name, shape and data type,
for temporarily storing the output.
"""
nodes_len = len(self._nodes_list)
for i in range(nodes_len):
node = self._nodes_list[i]
eid = 0
for node in self._nodes_list:
input_list = []
for input_node in node["inputs"]:
input_list.append(self._nodes_list[input_node[0]]["name"])
node["inputs"] = input_list
dtype = str("type: " + self._dtype_list[1][i])
if "attrs" not in node:
if node["op"] == "null":
node["attrs"] = {}
node["op"] = "param"
else:
num_outputs = 1
elif node["op"] == "tvm_op":
for input_node in node["inputs"]:
input_list.append(self._nodes_list[input_node[0]]["name"])
node["op"] = node["attrs"]["func_name"]
num_outputs = int(node["attrs"]["num_outputs"])
else:
raise ValueError("")
node["inputs"] = input_list
dtype = str("type: " + self._dtype_list[1][eid])
node["attrs"].update({"T": dtype})
node["shape"] = self._shapes_list[1][i]
node["shape"] = self._shapes_list[1][eid]
eid += num_outputs

def _cleanup_tensors(self):
"""Remove the tensor dump file (graph wont be removed)"""
Expand Down
26 changes: 25 additions & 1 deletion tests/python/unittest/test_runtime_graph_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from tvm._ffi.base import TVMError
from tvm.contrib import utils
from tvm.contrib.debugger import debug_executor

from tvm import relay

# Constants for creating simple graphs, fixtures to avoid free globals
@pytest.fixture
Expand Down Expand Up @@ -275,5 +275,29 @@ def test_run_single_node(graph, n, A, myadd):
mod.run_individual_node(2)


@tvm.testing.requires_llvm
def test_multiple_output():
x = relay.var("x", shape=(1, 3, 48, 16), dtype="float32")
t = relay.split(x, [12, 16, 32], 2).astuple()
x0 = relay.TupleGetItem(t, 0)
x1 = relay.TupleGetItem(t, 1)
x2 = relay.TupleGetItem(t, 2)
x3 = relay.TupleGetItem(t, 3)
p0 = relay.const(np.random.uniform(-1, 1, (3, 3, 1, 1)).astype("float32"))
y = relay.nn.conv2d(x2, p0, kernel_size=(1, 1), kernel_layout="OIHW", out_dtype="float32") + x3

func = relay.Function([x], relay.Tuple([x0, x1, y]))
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
target = tvm.target.Target("llvm")
device = tvm.cpu()
lib = relay.build(mod, target=target)
m = debug_executor.GraphModuleDebug(
lib["debug_create"]("default", device), [device], lib.get_graph_json(), None
)
nodes = m.debug_datum.get_graph_nodes()
assert nodes[2]["shape"] == [3, 3, 1, 1]


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 721f115

Please sign in to comment.