Skip to content

Commit

Permalink
Rebase to pass CI and reflect suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
sunggg authored and junrushao committed Jan 29, 2022
1 parent 419d756 commit c45d16a
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 25 deletions.
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
"""Testing utilities in meta schedule"""
from .local_rpc import LocalRPC
from .relay_workload import get_network
from .byoc_trt import relay_build_with_tensorrt
17 changes: 17 additions & 0 deletions python/tvm/meta_schedule/testing/byoc_trt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import tvm
from tvm.runtime import Module
from tvm.target import Target
from tvm.meta_schedule.builder import BuilderInput, LocalBuilder, BuilderResult
from typing import List


def relay_build_with_tensorrt(
    mod: Module,
    target: Target,
    params: dict,
) -> Module:
    """Build a Relay module through the TensorRT BYOC flow.

    Intended to be passed as ``LocalBuilder(f_build=relay_build_with_tensorrt)``:
    it partitions the module for TensorRT offload, then builds the result.

    Parameters
    ----------
    mod : Module
        The Relay IRModule to build.
    target : Target
        Unused — the device/host targets are hard-coded to ``"cuda"``/``"llvm"``
        below. Kept only to satisfy the ``f_build`` callback signature.
        NOTE(review): confirm ignoring ``target`` is intentional.
    params : dict
        The weight parameters to bind into the module.

    Returns
    -------
    Module
        The built runtime module produced by ``_build_module_no_factory``.
        (The original annotation ``List[BuilderResult]`` did not match the
        single value actually returned.)
    """
    # Deferred import: the TensorRT contrib package is optional, so only pull
    # it in when this builder is actually invoked.
    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

    mod, config = partition_for_tensorrt(mod, params)
    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
        # Device target "cuda", host target "llvm" — see the NOTE on `target`.
        return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params)
29 changes: 4 additions & 25 deletions tests/python/unittest/test_meta_schedule_byoc_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,8 @@
from tvm.relay import testing
from tvm.relay.op.contrib import tensorrt
import numpy as np
from typing import List, Tuple

# from tvm import script
# from tvm._ffi import register_func
# from tvm.runtime import Module
from typing import List
from tvm._ffi import register_func
from tvm.relay.testing.init import Initializer
from tvm.target import Target
from tvm.runtime import Module
from tvm.meta_schedule.arg_info import TensorInfo
Expand Down Expand Up @@ -94,25 +89,11 @@ def verify_meta_schedule_with_tensorrt(
):
if use_meta_sched:
# With meta_schedule
dev = "nvidia/geforce-rtx-2080"
dev = "cuda"

# Build
if use_trt:

def relay_build_with_tensorrt(
mod: Module,
target: Target,
params: dict,
) -> List[BuilderResult]:
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

mod, config = partition_for_tensorrt(mod, params)
with tvm.transform.PassContext(
opt_level=3, config={"relay.ext.tensorrt.options": config}
):
return tvm.relay.build_module._build_module_no_factory(
mod, "cuda", "llvm", params
)
from tvm.meta_schedule.testing import relay_build_with_tensorrt

builder = LocalBuilder(f_build=relay_build_with_tensorrt)
else:
Expand All @@ -122,7 +103,6 @@ def relay_build_without_tensorrt(
target: Target,
params: dict,
) -> List[BuilderResult]:
# @Sung: Weird. Cannot pass keyword arg
return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params)

builder = LocalBuilder(f_build=relay_build_without_tensorrt)
Expand Down Expand Up @@ -235,7 +215,7 @@ def test_conv2d_relu():
"model_name",
["resnet-50", "mobilenet"],
)
@pytest.mark.parametrize("batch_size", [1, 8, 16])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("use_meta_sched", [True])
@pytest.mark.parametrize("use_trt", [True, False])
def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool):
Expand All @@ -246,6 +226,5 @@ def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use
)


# @sunggg: memory verification error at test_relay_model("resnet-50", 1, use_meta_sched=False, use_trt=True)
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit c45d16a

Please sign in to comment.