Skip to content

Commit

Permalink
[Frontend][Relay][Parser] fix unparsable yolo formals (apache#6963)
Browse files Browse the repository at this point in the history
* fix yolo formals

* fix lint

* move test to test_forward
  • Loading branch information
hypercubestart authored and Trevor Morris committed Dec 2, 2020
1 parent 7751a66 commit f737788
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _darknet_not_support(attr, op="relay"):

def _get_params_prefix(opname, layer_num):
"""Makes the params prefix name from opname and layer number."""
return str(opname) + str(layer_num)
return str(opname).replace(".", "_") + str(layer_num)


def _get_params_name(prefix, item):
Expand Down
15 changes: 15 additions & 0 deletions tests/python/frontend/darknet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@
)


def astext(program, unify_free_vars=False):
"""check that program is parsable in text format"""
text = program.astext()
if isinstance(program, relay.Expr):
roundtrip_program = tvm.parser.parse_expr(text)
else:
roundtrip_program = tvm.parser.fromtext(text)

tvm.ir.assert_structural_equal(roundtrip_program, program, map_free_vars=True)


def _read_memory_buffer(shape, data, dtype="float32"):
length = 1
for x in shape:
Expand All @@ -60,6 +71,10 @@ def _get_tvm_output(net, data, build_dtype="float32", states=None):
"""Compute TVM output"""
dtype = "float32"
mod, params = relay.frontend.from_darknet(net, data.shape, dtype)
# verify that from_darknet creates a valid, parsable relay program
mod = relay.transform.InferType()(mod)
astext(mod)

target = "llvm"
shape_dict = {"data": data.shape}
lib = relay.build(mod, target, params=params)
Expand Down
5 changes: 2 additions & 3 deletions tests/python/relay/test_ir_text_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np
from tvm.relay import Expr
from tvm.relay.analysis import free_vars
import pytest

DEBUG_PRINT = False

Expand Down Expand Up @@ -269,6 +270,4 @@ def test_span():


if __name__ == "__main__":
import sys

pytext.argv(sys.argv)
pytest.main([__file__])

0 comments on commit f737788

Please sign in to comment.