From b959c5a938e2e226d199676012a979d73f899920 Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Thu, 17 Mar 2022 20:36:33 -0700 Subject: [PATCH 1/5] introduce vm compile path --- python/tvm/driver/tvmc/compiler.py | 47 +++++++++++--- python/tvm/driver/tvmc/model.py | 6 +- python/tvm/driver/tvmc/runner.py | 99 ++++++++++++++++-------------- 3 files changed, 98 insertions(+), 54 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index df56e3b9825d..a6d856bc97c9 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -196,6 +196,7 @@ def compile_model( disabled_pass: Optional[str] = None, pass_context_configs: Optional[List[str]] = None, additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None, + use_vm: bool = False, ): """Compile a model from a supported framework into a TVM module. @@ -286,8 +287,13 @@ def compile_model( opt_level=opt_level, config=config, disabled_pass=disabled_pass ): logger.debug("building relay graph with autoscheduler") - graph_module = relay.build( - mod, target=tvm_target, executor=executor, runtime=runtime, params=params + graph_module = build( + mod, + tvm_target=tvm_target, + executor=executor, + runtime=runtime, + params=params, + use_vm=use_vm ) else: with autotvm.apply_history_best(tuning_records): @@ -295,16 +301,27 @@ def compile_model( opt_level=opt_level, config=config, disabled_pass=disabled_pass ): logger.debug("building relay graph with tuning records") - graph_module = relay.build( - mod, target=tvm_target, executor=executor, runtime=runtime, params=params + # TODO: replace with vm.compile + graph_module = build( + mod, + tvm_target=tvm_target, + executor=executor, + runtime=runtime, + params=params, + use_vm=use_vm ) else: with tvm.transform.PassContext( opt_level=opt_level, config=config, disabled_pass=disabled_pass ): logger.debug("building relay graph (no tuning records provided)") - graph_module = relay.build( - mod, target=tvm_target, executor=executor, runtime=runtime, params=params + graph_module = build( + mod, + tvm_target=tvm_target, + executor=executor, + runtime=runtime, + params=params, + use_vm=use_vm ) # Generate output dump files with sources @@ -333,7 +350,23 @@ def compile_model( if dumps: save_dumps(package_path, dumps) - return TVMCPackage(package_path) + return TVMCPackage(package_path, use_vm=use_vm) + + +def build( + mod: tvm.IRModule, + tvm_target: str, + executor: Executor, + runtime: Runtime, + params: Dict[str, tvm.nd.NDArray], + use_vm: bool +): + if use_vm: + return relay.vm.compile(mod, target=tvm_target, params=params) + else: + return relay.build( + mod, target=tvm_target, executor=executor, runtime=runtime, params=params + ) def save_dumps(module_name: str, dumps: Dict[str, str], dump_root: str = "."): diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index 9a2617f3ed53..779e5b4502dd 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -314,12 +314,16 @@ class TVMCPackage(object): project_dir : Path, str If given and loading a MLF file, the path to the project directory that contains the file. + + use_vm : bool + Whether the graph module was compiled with vm or not. 
""" - def __init__(self, package_path: str, project_dir: Optional[Union[Path, str]] = None): + def __init__(self, package_path: str, project_dir: Optional[Union[Path, str]] = None, use_vm: bool = False): self._tmp_dir = utils.tempdir() self.package_path = package_path self.import_package(self.package_path) + self.use_vm = use_vm if project_dir and self.type != "mlf": raise TVMCException("Setting 'project_dir' is only allowed when importing a MLF.!") diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 14227b4d8bda..5e960ac81e98 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -28,8 +28,9 @@ import tvm from tvm import rpc +from tvm.runtime import vm from tvm.autotvm.measure import request_remote -from tvm.contrib import graph_executor as runtime +from tvm.contrib import graph_executor from tvm.contrib.debugger import debug_executor from . import TVMCException from .arguments import TVMCSuppressedArgumentParser @@ -530,58 +531,64 @@ def run_module( assert device == "cpu" dev = session.cpu() - # TODO(gromero): Adjust for micro targets. - if profile: - logger.debug("Creating runtime with profiling enabled.") - module = debug_executor.create(tvmc_package.graph, lib, dev, dump_root="./prof") + if tvmc_package.use_vm: + exe = vm.VirtualMachine(tvmc_package.graph, device) + inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode) + outputs = exe.invoke("main", inputs_dict.values()) + times = exe.benchmark(dev, inputs_dict.values(), func_name="main", repeat=repeat, number=number, end_to_end=end_to_end) else: - if device == "micro": - logger.debug("Creating runtime (micro) with profiling disabled.") - module = tvm.micro.create_local_graph_executor(tvmc_package.graph, lib, dev) + # TODO(gromero): Adjust for micro targets. 
+ if profile: + logger.debug("Creating runtime with profiling enabled.") + module = debug_executor.create(tvmc_package.graph, lib, dev, dump_root="./prof") else: - logger.debug("Creating runtime with profiling disabled.") - module = runtime.create(tvmc_package.graph, lib, dev) + if device == "micro": + logger.debug("Creating runtime (micro) with profiling disabled.") + module = tvm.micro.create_local_graph_executor(tvmc_package.graph, lib, dev) + else: + logger.debug("Creating runtime with profiling disabled.") + module = graph_executor.create(tvmc_package.graph, lib, dev) - logger.debug("Loading params into the runtime module.") - module.load_params(tvmc_package.params) + logger.debug("Loading params into the runtime module.") + module.load_params(tvmc_package.params) - logger.debug("Collecting graph input shape and type:") - shape_dict, dtype_dict = module.get_input_info() - logger.debug("Graph input shape: %s", shape_dict) - logger.debug("Graph input type: %s", dtype_dict) + logger.debug("Collecting graph input shape and type:") + shape_dict, dtype_dict = module.get_input_info() + logger.debug("Graph input shape: %s", shape_dict) + logger.debug("Graph input type: %s", dtype_dict) - inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode) + inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode) - logger.debug("Setting inputs to the module.") - module.set_input(**inputs_dict) + logger.debug("Setting inputs to the module.") + module.set_input(**inputs_dict) - # Run must be called explicitly if profiling - if profile: - logger.info("Running the module with profiling enabled.") - report = module.profile() - # This print is intentional - print(report) + # Run must be called explicitly if profiling + if profile: + logger.info("Running the module with profiling enabled.") + report = module.profile() + # This print is intentional + print(report) - if device == "micro": - # TODO(gromero): Fix time_evaluator() for micro targets. Once it's - # fixed module.benchmark() can be used instead and this if/else can - # be removed. - module.run() - times = [] - else: - # Call the benchmarking function of the executor. - # Optionally measure e2e data transfers from the - # CPU to device memory overheads (e.g. PCIE - # overheads if the device is a discrete GPU). - if end_to_end: - dev = session.cpu() - times = module.benchmark(dev, number=number, repeat=repeat, end_to_end=end_to_end) - - logger.debug("Collecting the output tensors.") - num_outputs = module.get_num_outputs() - outputs = {} - for i in range(num_outputs): - output_name = "output_{}".format(i) - outputs[output_name] = module.get_output(i).numpy() + if device == "micro": + # TODO(gromero): Fix time_evaluator() for micro targets. Once it's + # fixed module.benchmark() can be used instead and this if/else can + # be removed. + module.run() + times = [] + else: + # Call the benchmarking function of the executor. + # Optionally measure e2e data transfers from the + # CPU to device memory overheads (e.g. PCIE + # overheads if the device is a discrete GPU). 
+ if end_to_end: + dev = session.cpu() + times = module.benchmark(dev, number=number, repeat=repeat, end_to_end=end_to_end) + + logger.debug("Collecting the output tensors.") + num_outputs = module.get_num_outputs() + outputs = {} + for i in range(num_outputs): + output_name = "output_{}".format(i) + outputs[output_name] = module.get_output(i).numpy() return TVMCResult(outputs, times) From 2f6c8b51bb4f4a2d4443931c71c9fd6be43807d2 Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Tue, 22 Mar 2022 19:47:02 -0700 Subject: [PATCH 2/5] support vm in tvmc --- python/tvm/driver/tvmc/compiler.py | 12 ++++- python/tvm/driver/tvmc/model.py | 53 ++++++++++++++++++--- python/tvm/driver/tvmc/runner.py | 23 ++++++--- python/tvm/runtime/vm.py | 1 + tests/python/driver/tvmc/test_compiler.py | 57 ++++++++++++++++------- tests/python/driver/tvmc/test_model.py | 23 ++++++++- 6 files changed, 136 insertions(+), 33 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index a6d856bc97c9..da0177102493 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -28,6 +28,7 @@ from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity from tvm.target import Target from tvm.relay.backend import Executor, Runtime +from tvm.contrib import utils from . import composite_target, frontends from .model import TVMCModel, TVMCPackage @@ -244,7 +245,8 @@ def compile_model( PassContext. additional_target_options: Optional[Dict[str, Dict[str, Any]]] Additional target options in a dictionary to combine with initial Target arguments - + use_vm: bool + Whether to use the VM to compile the model as opposed to the graph executor Returns ------- @@ -331,7 +333,10 @@ def compile_model( dump_code = [dump_code] dumps = {} for source_type in dump_code: - lib = graph_module.get_lib() + if use_vm: + _, lib = graph_module.save() + else: + lib = graph_module.get_lib() # TODO lib.get_source call have inconsistent behavior for unsupported # formats (@leandron). source = str(mod) if source_type == "relay" else lib.get_source(source_type) @@ -344,6 +349,7 @@ def compile_model( cross, cross_options, output_format, + use_vm=use_vm ) # Write dumps to file. @@ -362,8 +368,10 @@ def build( use_vm: bool ): if use_vm: + logger.debug("building with vm compile") return relay.vm.compile(mod, target=tvm_target, params=params) else: + logger.debug("building with relay build") return relay.build( mod, target=tvm_target, executor=executor, runtime=runtime, params=params ) diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index 779e5b4502dd..05aec64699e9 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -182,6 +182,27 @@ def default_package_path(self): """ return self._tmp_dir.relpath("model_package.tar") + def export_vm_format( + self, + vm_exec: tvm.runtime.vm.Executable, + package_path: Optional[str] = None, + lib_format: str = "so", + ): + # TODO: write some docs + lib_name = "lib." + lib_format + temp = self._tmp_dir + if package_path is None: + package_path = self.default_package_path() + + path_lib = temp.relpath(lib_name) + vm_exec.mod.export_library(path_lib) + self.lib_path = path_lib + # Package up all the temp files into a tar file. 
+ with tarfile.open(package_path, "w") as tar: + tar.add(path_lib, lib_name) + + return package_path + def export_classic_format( self, executor_factory: GraphExecutorFactoryModule, @@ -248,11 +269,12 @@ def export_classic_format( def export_package( self, - executor_factory: GraphExecutorFactoryModule, + executor_factory: Union[GraphExecutorFactoryModule, tvm.runtime.vm.Executable], package_path: Optional[str] = None, cross: Optional[Union[str, Callable]] = None, cross_options: Optional[str] = None, output_format: str = "so", + use_vm: bool = False, ): """Save this TVMCModel to file. Parameters @@ -281,7 +303,9 @@ def export_package( if output_format == "mlf" and cross: raise TVMCException("Specifying the MLF output and a cross compiler is not supported.") - if output_format in ["so", "tar"]: + if use_vm: + package_path = self.export_vm_format(executor_factory, package_path, output_format) + elif output_format in ["so", "tar"]: package_path = self.export_classic_format( executor_factory, package_path, cross, cross_options, output_format ) @@ -322,8 +346,8 @@ class TVMCPackage(object): def __init__(self, package_path: str, project_dir: Optional[Union[Path, str]] = None, use_vm: bool = False): self._tmp_dir = utils.tempdir() self.package_path = package_path - self.import_package(self.package_path) self.use_vm = use_vm + self.import_package(self.package_path) if project_dir and self.type != "mlf": raise TVMCException("Setting 'project_dir' is only allowed when importing a MLF.!") @@ -341,7 +365,21 @@ def import_package(self, package_path: str): t = tarfile.open(package_path) t.extractall(temp.relpath(".")) - if os.path.exists(temp.relpath("metadata.json")): + if self.use_vm: + self.type = "vm" + graph = None + params = None + lib_name_so = "lib.so" + lib_name_tar = "lib.tar" + if os.path.exists(temp.relpath(lib_name_so)): + self.lib_name = lib_name_so + elif os.path.exists(temp.relpath(lib_name_tar)): + self.lib_name = lib_name_tar + else: + raise TVMCException("Couldn't find exported library in the package.") + + self.lib_path = temp.relpath(self.lib_name) + elif os.path.exists(temp.relpath("metadata.json")): # Model Library Format (MLF) self.lib_name = None self.lib_path = None @@ -370,8 +408,11 @@ def import_package(self, package_path: str): self.type = "classic" - with open(params, "rb") as param_file: - self.params = bytearray(param_file.read()) + if params is not None: + with open(params, "rb") as param_file: + self.params = bytearray(param_file.read()) + else: + self.params = None if graph is not None: with open(graph) as graph_file: diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 5e960ac81e98..5ea3467c77f8 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -509,8 +509,12 @@ def run_module( # must be already flashed into the micro target before one tries # to run it. Hence skip model upload for micro targets. 
if device != "micro": - session.upload(tvmc_package.lib_path) - lib = session.load_module(tvmc_package.lib_name) + if tvmc_package.use_vm: + session.upload(tvmc_package.lib_path) + rexec = session.load_module(tvmc_package.lib_name) + else: + session.upload(tvmc_package.lib_path) + lib = session.load_module(tvmc_package.lib_name) # TODO expand to other supported devices, as listed in tvm.rpc.client (@leandron) logger.debug("Device is %s.", device) @@ -532,10 +536,17 @@ def run_module( dev = session.cpu() if tvmc_package.use_vm: - exe = vm.VirtualMachine(tvmc_package.graph, device) - inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode) - outputs = exe.invoke("main", inputs_dict.values()) - times = exe.benchmark(dev, inputs_dict.values(), func_name="main", repeat=repeat, number=number, end_to_end=end_to_end) + exe = vm.VirtualMachine(rexec, dev) + input_tensor = tvm.nd.array(inputs, dev) + exe.set_input("main", input_tensor) + exe.invoke_stateful("main") + times = exe.benchmark(dev, input_tensor, func_name="main", repeat=repeat, number=number, end_to_end=end_to_end) + outputs = exe.get_outputs() + outputs_dict = {} + for i in range(len(outputs)): + output_name = "output_{}".format(i) + outputs_dict[output_name] = outputs[i] + outputs = outputs_dict else: # TODO(gromero): Adjust for micro targets. if profile: diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 27fd5af51a27..a25e1bfc2b82 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -441,6 +441,7 @@ def set_input(self, func_name, *args, **kwargs): idx = func_params.index(k) new_args[idx] = kwargs[k] cnt += 1 + breakpoint() assert len(args) + cnt == len(func_params) idx = 0 for i, arg in enumerate(new_args): diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 5ebcb8eea27d..0182e4789149 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -48,19 +48,26 @@ def test_save_dumps(tmpdir_factory): # End to end tests for compilation +def verify_tvmc_package(tvmc_package, dumps_path, use_vm=False): + # check for output types + assert type(tvmc_package) is TVMCPackage + assert os.path.exists(dumps_path) + assert type(tvmc_package.lib_path) is str + + if use_vm: + assert tvmc_package.graph is None + assert tvmc_package.params is None + else: + assert type(tvmc_package.graph) is str + assert type(tvmc_package.params) is bytearray -def verify_compile_tflite_module(model, shape_dict=None): + +def verify_compile_tflite_module(model, shape_dict=None, use_vm=False): pytest.importorskip("tflite") tvmc_model = tvmc.load(model, shape_dict=shape_dict) - tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW") + tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW", use_vm=use_vm) dumps_path = tvmc_package.package_path + ".ll" - - # check for output types - assert type(tvmc_package) is TVMCPackage - assert type(tvmc_package.graph) is str - assert type(tvmc_package.lib_path) is str - assert type(tvmc_package.params) is bytearray - assert os.path.exists(dumps_path) + verify_tvmc_package(tvmc_package, dumps_path, use_vm=use_vm) def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): @@ -74,6 +81,17 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict) +def test_compile_tflite_module_use_vm(tflite_mobilenet_v1_1_quant): + # some CI environments 
wont offer tflite, so skip in case it is not present + pytest.importorskip("tflite") + # Check default compilation. + verify_compile_tflite_module(tflite_mobilenet_v1_1_quant) + # Check with manual shape override + shape_string = "input:[1,224,224,3]" + shape_dict = tvmc.shape_parser.parse_shape_string(shape_string) + verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict, use_vm=True) + + # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. @pytest.mark.skipif( not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" @@ -198,19 +216,13 @@ def test_cross_compile_options_aarch64_keras_module(keras_resnet50): assert os.path.exists(dumps_path) -def verify_compile_onnx_module(model, shape_dict=None): +def verify_compile_onnx_module(model, shape_dict=None, use_vm=False): # some CI environments wont offer onnx, so skip in case it is not present pytest.importorskip("onnx") tvmc_model = tvmc.load(model, shape_dict=shape_dict) - tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll") + tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll", use_vm=use_vm) dumps_path = tvmc_package.package_path + ".ll" - - # check for output types - assert type(tvmc_package) is TVMCPackage - assert type(tvmc_package.graph) is str - assert type(tvmc_package.lib_path) is str - assert type(tvmc_package.params) is bytearray - assert os.path.exists(dumps_path) + verify_tvmc_package(tvmc_package, dumps_path, use_vm=use_vm) def test_compile_onnx_module(onnx_resnet50): @@ -222,6 +234,15 @@ def test_compile_onnx_module(onnx_resnet50): verify_compile_onnx_module(onnx_resnet50, shape_dict) +def test_compile_onnx_module_use_vm(onnx_resnet50): + # Test default compilation + verify_compile_onnx_module(onnx_resnet50, use_vm=True) + # Test with manual shape dict + shape_string = "data:[1,3,200,200]" + shape_dict = tvmc.shape_parser.parse_shape_string(shape_string) + verify_compile_onnx_module(onnx_resnet50, shape_dict, use_vm=True) + + # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. @pytest.mark.skipif( not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" diff --git a/tests/python/driver/tvmc/test_model.py b/tests/python/driver/tvmc/test_model.py index d0d398b75521..be112d7d8ba4 100644 --- a/tests/python/driver/tvmc/test_model.py +++ b/tests/python/driver/tvmc/test_model.py @@ -16,13 +16,14 @@ # under the License. 
import pytest import os +import numpy as np from os import path from tvm.driver import tvmc from tvm.driver.tvmc.model import TVMCModel, TVMCPackage, TVMCResult from tvm.runtime.module import BenchmarkResult - +from tvm import nd def test_tvmc_workflow(keras_simple): pytest.importorskip("tensorflow") @@ -40,6 +41,26 @@ def test_tvmc_workflow(keras_simple): assert "output_0" in result.outputs.keys() +def test_tvmc_workflow_use_vm(keras_simple): + pytest.importorskip("tensorflow") + + tvmc_model = tvmc.load(keras_simple) + tuning_records = tvmc.tune(tvmc_model, target="llvm", enable_autoscheduler=True, trials=2) + tvmc_package = tvmc.compile(tvmc_model, tuning_records=tuning_records, target="llvm", use_vm=True) + + np_input = np.random.uniform(size=(1, 32, 32, 3)).astype("float32") + # input_tensor = nd.array(np_input) + result = tvmc.run(tvmc_package, device="cpu", end_to_end=True, inputs=np_input) + + assert type(tvmc_model) is TVMCModel + assert type(tvmc_package) is TVMCPackage + assert type(result) is TVMCResult + assert path.exists(tuning_records) + assert type(result.outputs) is dict + assert type(result.times) is BenchmarkResult + assert "output_0" in result.outputs.keys() + + def test_save_load_model(keras_simple, tmpdir_factory): pytest.importorskip("onnx") From af68d7dacc41ddd09735ea2e5916555b1b51bd6a Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Tue, 22 Mar 2022 23:03:29 -0700 Subject: [PATCH 3/5] cleanup + lint --- python/tvm/driver/tvmc/compiler.py | 46 ++++++++++++++--------- python/tvm/driver/tvmc/model.py | 30 ++++++++++++--- python/tvm/driver/tvmc/runner.py | 37 ++++++++++-------- python/tvm/runtime/vm.py | 1 - tests/python/driver/tvmc/test_compiler.py | 5 ++- tests/python/driver/tvmc/test_model.py | 11 +++--- 6 files changed, 85 insertions(+), 45 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index da0177102493..e0ef54cc0dbd 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -28,7 +28,6 @@ from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity from tvm.target import Target from tvm.relay.backend import Executor, Runtime -from tvm.contrib import utils from . import composite_target, frontends from .model import TVMCModel, TVMCPackage @@ -295,7 +294,7 @@ def compile_model( executor=executor, runtime=runtime, params=params, - use_vm=use_vm + use_vm=use_vm, ) else: with autotvm.apply_history_best(tuning_records): @@ -303,14 +302,13 @@ def compile_model( opt_level=opt_level, config=config, disabled_pass=disabled_pass ): logger.debug("building relay graph with tuning records") - # TODO: replace with vm.compile graph_module = build( mod, tvm_target=tvm_target, executor=executor, runtime=runtime, params=params, - use_vm=use_vm + use_vm=use_vm, ) else: with tvm.transform.PassContext( @@ -323,7 +321,7 @@ def compile_model( executor=executor, runtime=runtime, params=params, - use_vm=use_vm + use_vm=use_vm, ) # Generate output dump files with sources @@ -344,12 +342,7 @@ def compile_model( # Create a new tvmc model package object from the graph definition. package_path = tvmc_model.export_package( - graph_module, - package_path, - cross, - cross_options, - output_format, - use_vm=use_vm + graph_module, package_path, cross, cross_options, output_format, use_vm=use_vm ) # Write dumps to file. 
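For context, the end-to-end flow the use_vm flag enables looks roughly like the following. This is a minimal sketch, not part of the patch: the model file name and the input name "data" are hypothetical, and the VM runner expects inputs as a dict of named arrays (see the runner.py changes below).

import numpy as np

from tvm.driver import tvmc

tvmc_model = tvmc.load("my_model.onnx")  # hypothetical model file
tvmc_package = tvmc.compile(tvmc_model, target="llvm", use_vm=True)

# The VM runner requires inputs as a dict of named arrays.
inputs = {"data": np.random.uniform(size=(1, 3, 224, 224)).astype("float32")}
result = tvmc.run(tvmc_package, device="cpu", inputs=inputs)
print(result.outputs["output_0"])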
@@ -361,20 +354,37 @@ def compile_model(
 
 
 def build(
     mod: tvm.IRModule,
-    tvm_target: str, 
+    tvm_target: str,
     executor: Executor,
     runtime: Runtime,
     params: Dict[str, tvm.nd.NDArray],
-    use_vm: bool
+    use_vm: bool,
 ):
+    """
+    Builds the model with the provided executor.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The relay module corresponding to this model.
+    tvm_target : str
+        The target for which to compile. Can be a plain string or
+        a path.
+    executor : Executor
+        The graph executor with which to build the model when use_vm is not True.
+    runtime : Runtime
+        The runtime configuration.
+    params : dict
+        A parameter dictionary for the model.
+    use_vm : bool
+        Whether to use the VM to compile the model as opposed to the graph executor.
+
+    """
     if use_vm:
         logger.debug("building with vm compile")
         return relay.vm.compile(mod, target=tvm_target, params=params)
-    else:
-        logger.debug("building with relay build")
-        return relay.build(
-            mod, target=tvm_target, executor=executor, runtime=runtime, params=params
-        )
+    logger.debug("building with relay build")
+    return relay.build(mod, target=tvm_target, executor=executor, runtime=runtime, params=params)
 
 
 def save_dumps(module_name: str, dumps: Dict[str, str], dump_root: str = "."):
diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py
index 05aec64699e9..1d6557826d79 100644
--- a/python/tvm/driver/tvmc/model.py
+++ b/python/tvm/driver/tvmc/model.py
@@ -188,7 +188,22 @@ def export_vm_format(
         package_path: Optional[str] = None,
         lib_format: str = "so",
     ):
-        # TODO: write some docs
+        """Save this TVMCModel compiled via the VM to file.
+        Parameters
+        ----------
+        vm_exec : vm.Executable
+            The VM Executable containing the compiled artifacts needed to run this model.
+        package_path : str, None
+            Where the model should be saved. Note that it will be packaged as a .tar file.
+            If not provided, the package will be saved to a generically named file in tmp.
+        lib_format : str
+            How to export the module's function library. Must be one of "so" or "tar".
+
+        Returns
+        -------
+        package_path : str
+            The path that the package was saved to.
+        """
         lib_name = "lib." + lib_format
         temp = self._tmp_dir
         if package_path is None:
@@ -200,7 +215,7 @@ def export_vm_format(
         path_lib = temp.relpath(lib_name)
         vm_exec.mod.export_library(path_lib)
         self.lib_path = path_lib
-        # Package up all the temp files into a tar file.
         with tarfile.open(package_path, "w") as tar:
             tar.add(path_lib, lib_name)
-
+
         return package_path
 
     def export_classic_format(
@@ -343,7 +358,12 @@ class TVMCPackage(object):
         Whether the graph module was compiled with vm or not.
""" - def __init__(self, package_path: str, project_dir: Optional[Union[Path, str]] = None, use_vm: bool = False): + def __init__( + self, + package_path: str, + project_dir: Optional[Union[Path, str]] = None, + use_vm: bool = False, + ): self._tmp_dir = utils.tempdir() self.package_path = package_path self.use_vm = use_vm @@ -376,9 +396,9 @@ def import_package(self, package_path: str): elif os.path.exists(temp.relpath(lib_name_tar)): self.lib_name = lib_name_tar else: - raise TVMCException("Couldn't find exported library in the package.") + raise TVMCException("Couldn't find exported library in the package.") - self.lib_path = temp.relpath(self.lib_name) + self.lib_path = temp.relpath(self.lib_name) elif os.path.exists(temp.relpath("metadata.json")): # Model Library Format (MLF) self.lib_name = None diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 5ea3467c77f8..30edf9dc6a8b 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -509,12 +509,8 @@ def run_module( # must be already flashed into the micro target before one tries # to run it. Hence skip model upload for micro targets. if device != "micro": - if tvmc_package.use_vm: - session.upload(tvmc_package.lib_path) - rexec = session.load_module(tvmc_package.lib_name) - else: - session.upload(tvmc_package.lib_path) - lib = session.load_module(tvmc_package.lib_name) + session.upload(tvmc_package.lib_path) + lib = session.load_module(tvmc_package.lib_name) # TODO expand to other supported devices, as listed in tvm.rpc.client (@leandron) logger.debug("Device is %s.", device) @@ -536,17 +532,28 @@ def run_module( dev = session.cpu() if tvmc_package.use_vm: - exe = vm.VirtualMachine(rexec, dev) - input_tensor = tvm.nd.array(inputs, dev) - exe.set_input("main", input_tensor) + assert inputs is not None and isinstance( + inputs, dict + ), "vm runner requires inputs to be provided as a dict" + exe = vm.VirtualMachine(lib, dev) + input_tensor = {} + for e, i in inputs.items(): + input_tensor[e] = tvm.nd.array(i, dev) + exe.set_input("main", **input_tensor) exe.invoke_stateful("main") - times = exe.benchmark(dev, input_tensor, func_name="main", repeat=repeat, number=number, end_to_end=end_to_end) - outputs = exe.get_outputs() - outputs_dict = {} - for i in range(len(outputs)): + times = exe.benchmark( + dev, + **input_tensor, + func_name="main", + repeat=repeat, + number=number, + end_to_end=end_to_end, + ) + exe_outputs = exe.get_outputs() + outputs = {} + for i, val in enumerate(exe_outputs): output_name = "output_{}".format(i) - outputs_dict[output_name] = outputs[i] - outputs = outputs_dict + outputs[output_name] = val else: # TODO(gromero): Adjust for micro targets. 
if profile: diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index a25e1bfc2b82..27fd5af51a27 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -441,7 +441,6 @@ def set_input(self, func_name, *args, **kwargs): idx = func_params.index(k) new_args[idx] = kwargs[k] cnt += 1 - breakpoint() assert len(args) + cnt == len(func_params) idx = 0 for i, arg in enumerate(new_args): diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 0182e4789149..7b31ebf670f3 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -48,6 +48,7 @@ def test_save_dumps(tmpdir_factory): # End to end tests for compilation + def verify_tvmc_package(tvmc_package, dumps_path, use_vm=False): # check for output types assert type(tvmc_package) is TVMCPackage @@ -65,7 +66,9 @@ def verify_tvmc_package(tvmc_package, dumps_path, use_vm=False): def verify_compile_tflite_module(model, shape_dict=None, use_vm=False): pytest.importorskip("tflite") tvmc_model = tvmc.load(model, shape_dict=shape_dict) - tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW", use_vm=use_vm) + tvmc_package = tvmc.compile( + tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW", use_vm=use_vm + ) dumps_path = tvmc_package.package_path + ".ll" verify_tvmc_package(tvmc_package, dumps_path, use_vm=use_vm) diff --git a/tests/python/driver/tvmc/test_model.py b/tests/python/driver/tvmc/test_model.py index be112d7d8ba4..23a39473c537 100644 --- a/tests/python/driver/tvmc/test_model.py +++ b/tests/python/driver/tvmc/test_model.py @@ -23,7 +23,7 @@ from tvm.driver import tvmc from tvm.driver.tvmc.model import TVMCModel, TVMCPackage, TVMCResult from tvm.runtime.module import BenchmarkResult -from tvm import nd + def test_tvmc_workflow(keras_simple): pytest.importorskip("tensorflow") @@ -46,11 +46,12 @@ def test_tvmc_workflow_use_vm(keras_simple): tvmc_model = tvmc.load(keras_simple) tuning_records = tvmc.tune(tvmc_model, target="llvm", enable_autoscheduler=True, trials=2) - tvmc_package = tvmc.compile(tvmc_model, tuning_records=tuning_records, target="llvm", use_vm=True) + tvmc_package = tvmc.compile( + tvmc_model, tuning_records=tuning_records, target="llvm", use_vm=True + ) - np_input = np.random.uniform(size=(1, 32, 32, 3)).astype("float32") - # input_tensor = nd.array(np_input) - result = tvmc.run(tvmc_package, device="cpu", end_to_end=True, inputs=np_input) + input_dict = {"input_2": np.random.uniform(size=(1, 32, 32, 3)).astype("float32")} + result = tvmc.run(tvmc_package, device="cpu", end_to_end=True, inputs=input_dict) assert type(tvmc_model) is TVMCModel assert type(tvmc_package) is TVMCPackage From 554af6411b3f5425f631373b15b07cff31749235 Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Tue, 29 Mar 2022 18:55:04 -0700 Subject: [PATCH 4/5] add profiler + simplify vm case in tvmcpackage --- python/tvm/driver/tvmc/compiler.py | 2 +- python/tvm/driver/tvmc/model.py | 58 ++++++++++++------------- python/tvm/driver/tvmc/runner.py | 30 +++++++++---- tests/python/driver/tvmc/test_model.py | 6 ++- tests/python/driver/tvmc/test_runner.py | 38 +++++++++++++++- 5 files changed, 93 insertions(+), 41 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index e0ef54cc0dbd..ed2e9f0f590a 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -349,7 +349,7 @@ def compile_model( if dumps: 
save_dumps(package_path, dumps) - return TVMCPackage(package_path, use_vm=use_vm) + return TVMCPackage(package_path) def build( diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index 1d6557826d79..bf3cea80ca7e 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -57,6 +57,8 @@ from tvm.driver.tvmc import TVMCException from tvm.relay.backend.executor_factory import GraphExecutorFactoryModule from tvm.runtime.module import BenchmarkResult +from tvm.runtime.vm import Executable + try: from tvm.micro import export_model_library_format @@ -184,7 +186,7 @@ def default_package_path(self): def export_vm_format( self, - vm_exec: tvm.runtime.vm.Executable, + vm_exec: Executable, package_path: Optional[str] = None, lib_format: str = "so", ): @@ -284,7 +286,7 @@ def export_classic_format( def export_package( self, - executor_factory: Union[GraphExecutorFactoryModule, tvm.runtime.vm.Executable], + executor_factory: Union[GraphExecutorFactoryModule, Executable], package_path: Optional[str] = None, cross: Optional[Union[str, Callable]] = None, cross_options: Optional[str] = None, @@ -362,11 +364,9 @@ def __init__( self, package_path: str, project_dir: Optional[Union[Path, str]] = None, - use_vm: bool = False, ): self._tmp_dir = utils.tempdir() self.package_path = package_path - self.use_vm = use_vm self.import_package(self.package_path) if project_dir and self.type != "mlf": @@ -385,21 +385,7 @@ def import_package(self, package_path: str): t = tarfile.open(package_path) t.extractall(temp.relpath(".")) - if self.use_vm: - self.type = "vm" - graph = None - params = None - lib_name_so = "lib.so" - lib_name_tar = "lib.tar" - if os.path.exists(temp.relpath(lib_name_so)): - self.lib_name = lib_name_so - elif os.path.exists(temp.relpath(lib_name_tar)): - self.lib_name = lib_name_tar - else: - raise TVMCException("Couldn't find exported library in the package.") - - self.lib_path = temp.relpath(self.lib_name) - elif os.path.exists(temp.relpath("metadata.json")): + if os.path.exists(temp.relpath("metadata.json")): # Model Library Format (MLF) self.lib_name = None self.lib_path = None @@ -413,20 +399,34 @@ def import_package(self, package_path: str): self.type = "mlf" else: # Classic format - lib_name_so = "mod.so" - lib_name_tar = "mod.tar" - if os.path.exists(temp.relpath(lib_name_so)): - self.lib_name = lib_name_so - elif os.path.exists(temp.relpath(lib_name_tar)): - self.lib_name = lib_name_tar + classic_lib_name_so = "mod.so" + classic_lib_name_tar = "mod.tar" + + # VM format + vm_lib_name_so = "lib.so" + vm_lib_name_tar = "lib.tar" + + if os.path.exists(temp.relpath(classic_lib_name_so)): + self.lib_name = classic_lib_name_so + self.type = "classic" + elif os.path.exists(temp.relpath(classic_lib_name_tar)): + self.lib_name = classic_lib_name_tar + self.type = "classic" + elif os.path.exists(temp.relpath(vm_lib_name_so)): + self.lib_name = vm_lib_name_so + self.type = "vm" + elif os.path.exists(temp.relpath(vm_lib_name_tar)): + self.lib_name = vm_lib_name_tar + self.type = "vm" else: raise TVMCException("Couldn't find exported library in the package.") - self.lib_path = temp.relpath(self.lib_name) - graph = temp.relpath("mod.json") - params = temp.relpath("mod.params") + self.lib_path = temp.relpath(self.lib_name) - self.type = "classic" + graph, params = None, None + if self.type == "classic": + graph = temp.relpath("mod.json") + params = temp.relpath("mod.params") if params is not None: with open(params, "rb") as param_file: diff --git 
a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 30edf9dc6a8b..c66f4d6c63e7 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -32,6 +32,7 @@ from tvm.autotvm.measure import request_remote from tvm.contrib import graph_executor from tvm.contrib.debugger import debug_executor +from tvm.runtime import profiler_vm from . import TVMCException from .arguments import TVMCSuppressedArgumentParser from .project import ( @@ -531,16 +532,23 @@ def run_module( assert device == "cpu" dev = session.cpu() - if tvmc_package.use_vm: - assert inputs is not None and isinstance( - inputs, dict - ), "vm runner requires inputs to be provided as a dict" - exe = vm.VirtualMachine(lib, dev) + if tvmc_package.type == "vm": + assert inputs is not None, "vm runner requires inputs to be provided as a dict" + input_tensor = {} for e, i in inputs.items(): input_tensor[e] = tvm.nd.array(i, dev) - exe.set_input("main", **input_tensor) - exe.invoke_stateful("main") + + if profile: + logger.debug("Creating vm with profile enabled.") + exe = profiler_vm.VirtualMachineProfiler(lib, dev) + res = exe.profile(**input_tensor, func_name="main") + # This print is intentional + print(res) + else: + exe = vm.VirtualMachine(lib, dev) + + exe_outputs = exe.invoke("main", **input_tensor) times = exe.benchmark( dev, **input_tensor, @@ -549,11 +557,15 @@ def run_module( number=number, end_to_end=end_to_end, ) - exe_outputs = exe.get_outputs() + + # Special handling if the output only has a single value + if not isinstance(exe_outputs, list): + exe_outputs = [exe_outputs] + outputs = {} for i, val in enumerate(exe_outputs): output_name = "output_{}".format(i) - outputs[output_name] = val + outputs[output_name] = val.numpy() else: # TODO(gromero): Adjust for micro targets. 
if profile: diff --git a/tests/python/driver/tvmc/test_model.py b/tests/python/driver/tvmc/test_model.py index 23a39473c537..e116d5801da6 100644 --- a/tests/python/driver/tvmc/test_model.py +++ b/tests/python/driver/tvmc/test_model.py @@ -43,6 +43,10 @@ def test_tvmc_workflow(keras_simple): def test_tvmc_workflow_use_vm(keras_simple): pytest.importorskip("tensorflow") + import tensorflow as tf + + # Reset so the input name remains consistent across unit test runs + tf.keras.backend.clear_session() tvmc_model = tvmc.load(keras_simple) tuning_records = tvmc.tune(tvmc_model, target="llvm", enable_autoscheduler=True, trials=2) @@ -50,7 +54,7 @@ def test_tvmc_workflow_use_vm(keras_simple): tvmc_model, tuning_records=tuning_records, target="llvm", use_vm=True ) - input_dict = {"input_2": np.random.uniform(size=(1, 32, 32, 3)).astype("float32")} + input_dict = {"input_1": np.random.uniform(size=(1, 32, 32, 3)).astype("float32")} result = tvmc.run(tvmc_package, device="cpu", end_to_end=True, inputs=input_dict) assert type(tvmc_model) is TVMCModel diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py index 30ce2c6f2191..b7f655ac7d84 100644 --- a/tests/python/driver/tvmc/test_runner.py +++ b/tests/python/driver/tvmc/test_runner.py @@ -79,11 +79,47 @@ def test_run_tflite_module__with_profile__valid_input( pytest.importorskip("tflite") inputs = np.load(imagenet_cat) + input_dict = {"input": inputs["input"].astype("uint8")} tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant) result = tvmc.run( tflite_compiled_model, - inputs=inputs, + inputs=input_dict, + hostname=None, + device="cpu", + profile=True, + ) + + # collect the top 5 results + top_5_results = get_top_results(result, 5) + top_5_ids = top_5_results[0] + + # IDs were collected from this reference: + # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/ + # java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt + tiger_cat_mobilenet_id = 283 + + assert ( + tiger_cat_mobilenet_id in top_5_ids + ), "tiger cat is expected in the top-5 for mobilenet v1" + assert type(result.outputs) is dict + assert type(result.times) is BenchmarkResult + assert "output_0" in result.outputs.keys() + + +def test_run_tflite_module__with_profile_vm__valid_input( + tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat +): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip("tflite") + + inputs = np.load(imagenet_cat) + input_dict = {"input": inputs["input"].astype("uint8")} + + tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant, use_vm=True) + result = tvmc.run( + tflite_compiled_model, + inputs=input_dict, hostname=None, device="cpu", profile=True, From ebfc60839e152087e02de9e6de7d14be3165e5ec Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Wed, 30 Mar 2022 21:25:59 -0700 Subject: [PATCH 5/5] address comments + parametrize tests --- python/tvm/driver/tvmc/compiler.py | 4 +-- python/tvm/driver/tvmc/model.py | 3 +- tests/python/driver/tvmc/test_compiler.py | 30 ++++------------- tests/python/driver/tvmc/test_model.py | 29 ++++------------ tests/python/driver/tvmc/test_runner.py | 40 ++--------------------- 5 files changed, 19 insertions(+), 87 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 7a18b213c01c..8f24dd4d7536 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -337,7 +337,7 @@ def compile_model( 
dumps = {} for source_type in dump_code: if use_vm: - _, lib = graph_module.save() + lib = graph_module.lib else: lib = graph_module.get_lib() # TODO lib.get_source call have inconsistent behavior for unsupported @@ -347,7 +347,7 @@ def compile_model( # Create a new tvmc model package object from the graph definition. package_path = tvmc_model.export_package( - graph_module, package_path, cross, cross_options, output_format, use_vm=use_vm + graph_module, package_path, cross, cross_options, output_format ) # Write dumps to file. diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index bf3cea80ca7e..93ca27c60947 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -291,7 +291,6 @@ def export_package( cross: Optional[Union[str, Callable]] = None, cross_options: Optional[str] = None, output_format: str = "so", - use_vm: bool = False, ): """Save this TVMCModel to file. Parameters @@ -320,7 +319,7 @@ def export_package( if output_format == "mlf" and cross: raise TVMCException("Specifying the MLF output and a cross compiler is not supported.") - if use_vm: + if isinstance(executor_factory, Executable): package_path = self.export_vm_format(executor_factory, package_path, output_format) elif output_format in ["so", "tar"]: package_path = self.export_classic_format( diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index e75ab9c2e136..bc836de7d554 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -73,7 +73,8 @@ def verify_compile_tflite_module(model, shape_dict=None, use_vm=False): verify_tvmc_package(tvmc_package, dumps_path, use_vm=use_vm) -def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): +@pytest.mark.parametrize("use_vm", [True, False]) +def test_compile_tflite_module(use_vm, tflite_mobilenet_v1_1_quant): # some CI environments wont offer tflite, so skip in case it is not present pytest.importorskip("tflite") # Check default compilation. @@ -81,18 +82,7 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): # Check with manual shape override shape_string = "input:[1,224,224,3]" shape_dict = tvmc.shape_parser.parse_shape_string(shape_string) - verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict) - - -def test_compile_tflite_module_use_vm(tflite_mobilenet_v1_1_quant): - # some CI environments wont offer tflite, so skip in case it is not present - pytest.importorskip("tflite") - # Check default compilation. - verify_compile_tflite_module(tflite_mobilenet_v1_1_quant) - # Check with manual shape override - shape_string = "input:[1,224,224,3]" - shape_dict = tvmc.shape_parser.parse_shape_string(shape_string) - verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict, use_vm=True) + verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict, use_vm=use_vm) # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. 
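The per-flavor package layout that verify_tvmc_package asserts can also be inspected directly on an imported package. A small sketch, assuming a package already exported by the VM path; the .tar path is hypothetical (it matches the default name TVMCModel uses):

from tvm.driver.tvmc.model import TVMCPackage

tvmc_package = TVMCPackage("model_package.tar")  # hypothetical path to an exported package
if tvmc_package.type == "vm":
    # VM packages ship only the exported library; no graph JSON or params blob.
    assert tvmc_package.graph is None and tvmc_package.params is None
    print("VM library:", tvmc_package.lib_path)
else:
    print("classic package:", tvmc_package.lib_name)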
@@ -228,22 +218,14 @@ def verify_compile_onnx_module(model, shape_dict=None, use_vm=False): verify_tvmc_package(tvmc_package, dumps_path, use_vm=use_vm) -def test_compile_onnx_module(onnx_resnet50): +@pytest.mark.parametrize("use_vm", [True, False]) +def test_compile_onnx_module(use_vm, onnx_resnet50): # Test default compilation verify_compile_onnx_module(onnx_resnet50) # Test with manual shape dict shape_string = "data:[1,3,200,200]" shape_dict = tvmc.shape_parser.parse_shape_string(shape_string) - verify_compile_onnx_module(onnx_resnet50, shape_dict) - - -def test_compile_onnx_module_use_vm(onnx_resnet50): - # Test default compilation - verify_compile_onnx_module(onnx_resnet50, use_vm=True) - # Test with manual shape dict - shape_string = "data:[1,3,200,200]" - shape_dict = tvmc.shape_parser.parse_shape_string(shape_string) - verify_compile_onnx_module(onnx_resnet50, shape_dict, use_vm=True) + verify_compile_onnx_module(onnx_resnet50, shape_dict, use_vm=use_vm) # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. diff --git a/tests/python/driver/tvmc/test_model.py b/tests/python/driver/tvmc/test_model.py index a0b5b8dd6c96..74c1c4ded8a4 100644 --- a/tests/python/driver/tvmc/test_model.py +++ b/tests/python/driver/tvmc/test_model.py @@ -30,23 +30,8 @@ platform.machine() == "aarch64", reason="Currently failing on AArch64 - see https://github.com/apache/tvm/issues/10673", ) -def test_tvmc_workflow(keras_simple): - pytest.importorskip("tensorflow") - - tvmc_model = tvmc.load(keras_simple) - tuning_records = tvmc.tune(tvmc_model, target="llvm", enable_autoscheduler=True, trials=2) - tvmc_package = tvmc.compile(tvmc_model, tuning_records=tuning_records, target="llvm") - result = tvmc.run(tvmc_package, device="cpu", end_to_end=True) - assert type(tvmc_model) is TVMCModel - assert type(tvmc_package) is TVMCPackage - assert type(result) is TVMCResult - assert path.exists(tuning_records) - assert type(result.outputs) is dict - assert type(result.times) is BenchmarkResult - assert "output_0" in result.outputs.keys() - - -def test_tvmc_workflow_use_vm(keras_simple): +@pytest.mark.parametrize("use_vm", [True, False]) +def test_tvmc_workflow(use_vm, keras_simple): pytest.importorskip("tensorflow") import tensorflow as tf @@ -56,12 +41,11 @@ def test_tvmc_workflow_use_vm(keras_simple): tvmc_model = tvmc.load(keras_simple) tuning_records = tvmc.tune(tvmc_model, target="llvm", enable_autoscheduler=True, trials=2) tvmc_package = tvmc.compile( - tvmc_model, tuning_records=tuning_records, target="llvm", use_vm=True + tvmc_model, tuning_records=tuning_records, target="llvm", use_vm=use_vm ) - input_dict = {"input_1": np.random.uniform(size=(1, 32, 32, 3)).astype("float32")} - result = tvmc.run(tvmc_package, device="cpu", end_to_end=True, inputs=input_dict) + result = tvmc.run(tvmc_package, device="cpu", end_to_end=True, inputs=input_dict) assert type(tvmc_model) is TVMCModel assert type(tvmc_package) is TVMCPackage assert type(result) is TVMCResult @@ -71,7 +55,8 @@ def test_tvmc_workflow_use_vm(keras_simple): assert "output_0" in result.outputs.keys() -def test_save_load_model(keras_simple, tmpdir_factory): +@pytest.mark.parametrize("use_vm", [True, False]) +def test_save_load_model(use_vm, keras_simple, tmpdir_factory): pytest.importorskip("onnx") tmpdir = tmpdir_factory.mktemp("data") @@ -81,7 +66,7 @@ def test_save_load_model(keras_simple, tmpdir_factory): tvmc.tune(tvmc_model, target="llvm", trials=2) # Create package artifacts - tvmc.compile(tvmc_model, 
target="llvm") + tvmc.compile(tvmc_model, target="llvm", use_vm=use_vm) # Save the model to disk model_path = os.path.join(tmpdir, "saved_model.tar") diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py index b7f655ac7d84..3f4ab11f6ba2 100644 --- a/tests/python/driver/tvmc/test_runner.py +++ b/tests/python/driver/tvmc/test_runner.py @@ -72,8 +72,9 @@ def test_get_top_results_keep_results(): assert len(sut[1]) == expected_number_of_results_per_line +@pytest.mark.parametrize("use_vm", [True, False]) def test_run_tflite_module__with_profile__valid_input( - tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat + use_vm, tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat ): # some CI environments wont offer TFLite, so skip in case it is not present pytest.importorskip("tflite") @@ -81,42 +82,7 @@ def test_run_tflite_module__with_profile__valid_input( inputs = np.load(imagenet_cat) input_dict = {"input": inputs["input"].astype("uint8")} - tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant) - result = tvmc.run( - tflite_compiled_model, - inputs=input_dict, - hostname=None, - device="cpu", - profile=True, - ) - - # collect the top 5 results - top_5_results = get_top_results(result, 5) - top_5_ids = top_5_results[0] - - # IDs were collected from this reference: - # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/ - # java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt - tiger_cat_mobilenet_id = 283 - - assert ( - tiger_cat_mobilenet_id in top_5_ids - ), "tiger cat is expected in the top-5 for mobilenet v1" - assert type(result.outputs) is dict - assert type(result.times) is BenchmarkResult - assert "output_0" in result.outputs.keys() - - -def test_run_tflite_module__with_profile_vm__valid_input( - tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat -): - # some CI environments wont offer TFLite, so skip in case it is not present - pytest.importorskip("tflite") - - inputs = np.load(imagenet_cat) - input_dict = {"input": inputs["input"].astype("uint8")} - - tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant, use_vm=True) + tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant, use_vm=use_vm) result = tvmc.run( tflite_compiled_model, inputs=input_dict,