Skip to content

Commit

Permalink
[Target][BugFix] Convert dict and str to TVM object (#9807)
Browse files Browse the repository at this point in the history
* [Target][BugFix] Convert dict and str to TVM object

* Add tests
  • Loading branch information
leeexyz authored Dec 31, 2021
1 parent 0d7e2ec commit a5ac362
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 8 deletions.
22 changes: 15 additions & 7 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@

import tvm._ffi
from tvm._ffi import register_func as _register_func
from tvm.runtime import Object
from tvm.runtime import Object, convert
from tvm.runtime.container import String
from tvm.ir.container import Map

from . import _ffi_api

Expand Down Expand Up @@ -107,10 +109,14 @@ def __init__(self, target, host=None):
When using a dictionary or json string to configure target, the possible values are
same as target.
"""
if target is None or not isinstance(target, (dict, str, Target)):
if isinstance(target, (dict, str)):
target = convert(target)
if isinstance(host, (dict, str)):
host = convert(host)
if target is None or not isinstance(target, (Map, String, Target)):
raise ValueError("target has to be a string or dictionary.")
if host is not None:
if not isinstance(host, (dict, str, Target)):
if not isinstance(host, (Map, String, Target)):
raise ValueError("target host has to be a string or dictionary.")
self.__init_handle_by_constructor__(_ffi_api.Target, Target(target), Target(host))
else:
Expand Down Expand Up @@ -221,15 +227,19 @@ def check_and_update_host_consist(target, host=None, target_is_dict_key=True):
target_is_dict_key : Bool
When the type of target is dict, whether Target is the key (Otherwise the value)
"""
if isinstance(target, (dict, str)):
target = convert(target)
if isinstance(host, (dict, str)):
host = convert(host)
if target is None:
assert host is None, "Target host is not empty when target is empty."
return target, host
if isinstance(target, dict) and "kind" not in target:
if isinstance(target, Map) and "kind" not in target:
new_target = {}
for tgt, mod in target.items():
if not target_is_dict_key:
tgt, mod = mod, tgt
if isinstance(tgt, (dict, str, Target)):
if isinstance(tgt, (Map, String, Target)):
tgt, host = Target.check_and_update_host_consist(tgt, host)
if not target_is_dict_key:
tgt, mod = mod, tgt
Expand All @@ -242,8 +252,6 @@ def check_and_update_host_consist(target, host=None, target_is_dict_key=True):


# TODO(@tvm-team): Deprecate the helper functions below. Encourage the usage of config dict instead.


def _merge_opts(opts, new_opts):
"""Helper function to merge options"""
if isinstance(new_opts, str):
Expand Down
26 changes: 25 additions & 1 deletion tests/python/relay/test_build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

import pytest

import tvm
from tvm import relay
from tvm.target.target import Target
from tvm.relay.backend import Runtime, Executor
from tvm.relay.backend import Runtime, Executor, graph_executor_codegen
from tvm.relay.build_module import _reconstruct_from_deprecated_options


Expand Down Expand Up @@ -58,5 +60,27 @@ def test_deprecated_target_parameters(target, executor, runtime):
assert runtime == actual_runtime


def test_build_relay_graph_():
"""Test to build a simple relay graph by using APIs directly"""

def build_graph(mod, target):
target = relay.build_module.build_target_by_device_type_map(target)
target, target_host = tvm.target.Target.check_and_update_host_consist(target)
mod, _ = relay.optimize(mod, target, None)
grc = graph_executor_codegen.GraphExecutorCodegen(None, target)
_, lowered_funcs, _ = grc.codegen(mod, mod["main"])
_ = relay.backend._backend.build(lowered_funcs, target, target_host)

def add(shape, dtype):
lhs = relay.var("A", shape=shape, dtype=dtype)
rhs = relay.var("B", shape=shape, dtype=dtype)
out = relay.add(lhs, rhs)
expr = relay.Function((lhs, rhs), out)
mod = tvm.IRModule.from_expr(expr)
return mod

build_graph(add((1, 8), "float32"), tvm.target.Target("llvm"))


if __name__ == "__main__":
pytest.main()
32 changes: 32 additions & 0 deletions tests/python/unittest/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,17 @@ def test_target_host_merge_2():
assert tgt.host.kind.name == "llvm"


def test_target_tvm_object():
"""Test creating Target by using TVM Objects"""
String = tvm.runtime.container.String
tgt = tvm.target.Target(target=String("cuda --host llvm"))
assert tgt.kind.name == "cuda"
assert tgt.host.kind.name == "llvm"
tgt = tvm.target.Target(target=String("cuda"), host=String("llvm"))
assert tgt.kind.name == "cuda"
assert tgt.host.kind.name == "llvm"


@pytest.mark.skip(reason="Causing infinite loop because of pytest and handle issue")
def test_target_host_merge_3():
with pytest.raises(ValueError, match=r"target host has to be a string or dictionary."):
Expand Down Expand Up @@ -372,6 +383,27 @@ def test_check_and_update_host_consist_3():
assert target.host == host


def test_check_and_update_host_consist_4():
"""Test `check_and_update_host_consist` by using TVM Objects"""
cuda_device_type = tvm.device("cuda").device_type
target = {cuda_device_type: Target(target="cuda", host="llvm")}
host = None
target_1, host_1 = Target.check_and_update_host_consist(target, host)
assert isinstance(target_1, dict)
assert target_1[cuda_device_type].kind.name == "cuda"
assert target_1[cuda_device_type].host.kind.name == "llvm"
assert host_1 is None

target = {cuda_device_type: Target(tvm.runtime.container.String("cuda"))}
host = Target(tvm.runtime.container.String("llvm"))
target = tvm.runtime.convert(target)
assert isinstance(target, tvm.ir.container.Map)
target_2, host_2 = Target.check_and_update_host_consist(target, host)
assert isinstance(target_2, dict)
assert target_2[cuda_device_type].kind.name == "cuda"
assert host_2.kind.name == "llvm"


def test_target_attr_bool_value():
target0 = Target("vulkan --supports_float16=True")
assert target0.attrs["supports_float16"] == 1
Expand Down

0 comments on commit a5ac362

Please sign in to comment.