From 5b3cef30f963a236205088848d7dc660a1f6c7fc Mon Sep 17 00:00:00 2001 From: Alan MacDonald Date: Tue, 14 Jun 2022 15:28:25 -0700 Subject: [PATCH] [microTVM][zephyr] Add support for host-driven AoT execution on zephyr (#11650) * - add support for host-driven AoT execution on zephyr; - add initial version of reference counting to prevent python code from inadvertently freeing tensors during garbage collection; - add support for numerical indices to host-drive AoT get_input(); - add two initial tests for host-driven AoT execution on zephyr; - rename existing zephyr AoT exec. test; * address PR feedback * increase stack size to accommodate qemu_riscv64 stack usage --- .../template_project/crt_config/crt_config.h | 2 +- .../template_project/microtvm_api_server.py | 2 +- python/tvm/micro/session.py | 10 +- python/tvm/runtime/ndarray.py | 2 +- src/runtime/crt/aot_executor/aot_executor.c | 12 +- .../aot_executor_module/aot_executor_module.c | 30 +++- src/runtime/crt/common/crt_runtime_api.c | 49 +++--- src/runtime/crt/common/ndarray.c | 26 ++- .../crt/graph_executor/graph_executor.c | 4 +- .../graph_executor_module.c | 13 +- src/runtime/crt/host/main.cc | 3 - .../tvm/runtime/crt/internal/common/ndarray.h | 8 + .../crt/microtvm_rpc_server/rpc_server.cc | 6 + src/runtime/graph_executor/graph_executor.h | 2 +- tests/micro/zephyr/conftest.py | 4 +- tests/micro/zephyr/test_zephyr_aot_exec.py | 157 ++++++++++++++++++ ....py => test_zephyr_aot_exec_standalone.py} | 0 17 files changed, 276 insertions(+), 54 deletions(-) create mode 100644 tests/micro/zephyr/test_zephyr_aot_exec.py rename tests/micro/zephyr/{test_zephyr_aot.py => test_zephyr_aot_exec_standalone.py} (100%) diff --git a/apps/microtvm/zephyr/template_project/crt_config/crt_config.h b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h index c3beaed522f2..3481d342a1ce 100644 --- a/apps/microtvm/zephyr/template_project/crt_config/crt_config.h +++ b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h @@ -36,7 +36,7 @@ #define TVM_CRT_MAX_ARGS 10 /*! Size of the global function registry, in bytes. */ -#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 256 +#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py index bcf9f78f4b11..dad4cdf9d64c 100644 --- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py +++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py @@ -420,7 +420,7 @@ def _create_prj_conf(self, project_dir, options): API_SERVER_CRT_LIBS_TOKEN = "" CRT_LIBS_BY_PROJECT_TYPE = { - "host_driven": "microtvm_rpc_server microtvm_rpc_common common", + "host_driven": "microtvm_rpc_server microtvm_rpc_common aot_executor_module aot_executor common", "aot_demo": "memory microtvm_rpc_common common", } diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py index 4c38476207ba..967eaee62958 100644 --- a/python/tvm/micro/session.py +++ b/python/tvm/micro/session.py @@ -39,7 +39,7 @@ @register_error class SessionTerminatedError(Exception): - """Raised when a transport read operationd discovers that the remote session is terminated.""" + """Raised when a transport read operation discovers that the remote session is terminated.""" class Session: @@ -86,12 +86,18 @@ def __init__( self._rpc = None self._graph_executor = None + self._enable_rpc_logger = False self._exit_called = False def get_system_lib(self): return self._rpc.get_function("runtime.SystemLib")() + def create_aot_executor(self): + return self._rpc.get_function("tvm.aot_executor.create")( + self.get_system_lib(), self.device, "default" + ) + def _wrap_transport_read(self, n, timeout_microsec): try: return self.transport.read( @@ -133,7 +139,7 @@ def __enter__(self): int(timeouts.session_start_timeout_sec * 1e6), int(timeouts.session_established_timeout_sec * 1e6), self._cleanup, - False, + self._enable_rpc_logger, ) ) self.device = self._rpc.cpu(0) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 3d4764d6164a..9d3a3aff2165 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -127,7 +127,7 @@ def __setitem__(self, in_slice, value): raise TypeError("type %s not supported" % str(type(value))) def copyfrom(self, source_array): - """Perform an synchronize copy from the array. + """Perform a synchronous copy from the array. Parameters ---------- diff --git a/src/runtime/crt/aot_executor/aot_executor.c b/src/runtime/crt/aot_executor/aot_executor.c index 1360c40b0fa4..1724fabec4a0 100644 --- a/src/runtime/crt/aot_executor/aot_executor.c +++ b/src/runtime/crt/aot_executor/aot_executor.c @@ -173,21 +173,29 @@ int TVMAotExecutor_Init(TVMAotExecutor* executor, TVMModuleHandle module_handle, for (i = 0; i < md->num_inputs; ++i) { LOG_DEBUG("input allocate[%d]: %s\n", i, md->inputs[i].name); + TVMNDArray* array = &executor->args[arg_idx++]; + status = TVMNDArray_Empty(md->inputs[i].num_shape, md->inputs[i].shape, md->inputs[i].dtype, - executor->device, &executor->args[arg_idx++]); + executor->device, array); if (status != 0) { return status; } + + TVMNDArray_IncrementReference(array); } for (i = 0; i < md->num_outputs; ++i) { LOG_DEBUG("output allocate[%d]: %s\n", i, md->outputs[i].name); + TVMNDArray* array = &executor->args[arg_idx++]; + status = TVMNDArray_Empty(md->outputs[i].num_shape, md->outputs[i].shape, md->outputs[i].dtype, - executor->device, &executor->args[arg_idx++]); + executor->device, array); if (status != 0) { return status; } + + TVMNDArray_IncrementReference(array); } for (i = 0; i < md->num_pools; ++i) { diff --git a/src/runtime/crt/aot_executor_module/aot_executor_module.c b/src/runtime/crt/aot_executor_module/aot_executor_module.c index e1dbd533a3ec..5dd11c3dbc7e 100644 --- a/src/runtime/crt/aot_executor_module/aot_executor_module.c +++ b/src/runtime/crt/aot_executor_module/aot_executor_module.c @@ -80,13 +80,27 @@ int32_t TVMAotExecutorModule_NotImplemented(TVMValue* args, int* tcodes, int nar int32_t TVMAotExecutorModule_GetInput(TVMValue* args, int* tcodes, int nargs, TVMValue* ret_values, int* ret_tcodes, void* resource_handle) { - int index = TVMAotExecutor_GetInputIndex(aot_executor.executor, args[0].v_str); + int64_t index; - if (index < 0) { - return kTvmErrorExecutorModuleNoSuchInput; + if (tcodes[0] == kTVMArgInt) { + if (args[0].v_int64 > TVMAotExecutor_GetNumInputs(aot_executor.executor)) { + return kTvmErrorFunctionCallInvalidArg; + } + + index = args[0].v_int64; + } else { + index = TVMAotExecutor_GetInputIndex(aot_executor.executor, args[0].v_str); + + if (index < 0) { + return kTvmErrorExecutorModuleNoSuchInput; + } } - ret_values[0].v_handle = (void*)&aot_executor.executor->args[index].dl_tensor; + TVMNDArray* array = &aot_executor.executor->args[index]; + + TVMNDArray_IncrementReference(array); + + ret_values[0].v_handle = (void*)(&array->dl_tensor); ret_tcodes[0] = kTVMNDArrayHandle; return 0; @@ -103,9 +117,13 @@ int32_t TVMAotExecutorModule_GetOutput(TVMValue* args, int* tcodes, int nargs, T } // index past the input entries - int64_t idx = args[0].v_int64 + TVMAotExecutor_GetNumInputs(aot_executor.executor); + int64_t index = args[0].v_int64 + TVMAotExecutor_GetNumInputs(aot_executor.executor); + + TVMNDArray* array = &aot_executor.executor->args[index]; + + TVMNDArray_IncrementReference(array); - ret_values[0].v_handle = (void*)&aot_executor.executor->args[idx].dl_tensor; + ret_values[0].v_handle = (void*)(&array->dl_tensor); ret_tcodes[0] = kTVMNDArrayHandle; return 0; diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index 31ab3e9a6973..a8a17041f5ea 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -76,9 +76,9 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_ } int TVMArrayFree(TVMArrayHandle handle) { - TVMNDArray arr; - arr.dl_tensor = *handle; - return TVMNDArray_Release(&arr); + TVMNDArray* arr = (TVMNDArray*)handle; + + return TVMNDArray_Release(arr); } int TVMDeviceAllocDataSpace(DLDevice dev, size_t nbytes, size_t alignment, DLDataType type_hint, @@ -149,7 +149,7 @@ static const TVMModule* registered_modules[TVM_CRT_MAX_REGISTERED_MODULES]; /*! \brief Passed as `module_index` to EncodeFunctionHandle. */ static const tvm_module_index_t kGlobalFuncModuleIndex = TVM_CRT_MAX_REGISTERED_MODULES; -/*! \brief Special module handle for retur values from RPCTimeEvaluator. */ +/*! \brief Special module handle for return values from RPCTimeEvaluator. */ static const tvm_module_index_t kTimeEvaluatorModuleIndex = 0x7fff; static int DecodeModuleHandle(TVMModuleHandle handle, tvm_module_index_t* out_module_index) { @@ -202,8 +202,8 @@ int TVMModFree(TVMModuleHandle mod) { return 0; } -int SystemLibraryCreate(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val, - int* ret_type_codes) { +static int SystemLibraryCreate(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val, + int* ret_type_codes) { const TVMModule* system_lib; if (system_lib_handle == kTVMModuleHandleUninitialized) { @@ -400,8 +400,22 @@ int RPCGetCRTMaxPacketSize(TVMValue* args, int* type_codes, int num_args, TVMVal return 0; } -int TVMContribRandomFill(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val, - int* ret_type_code); +// Fill the tensor in args[0] with random data using TVMPlatformGenerateRandom. +static int RandomFill(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val, + int* ret_type_code) { + if (num_args != 1) { + return kTvmErrorFunctionCallNumArguments; + } + + if (type_codes[0] != kTVMDLTensorHandle) { + return kTvmErrorFunctionCallWrongArgType; + } + + DLTensor* tensor = (DLTensor*)args[0].v_handle; + TVMNDArray arr = {*tensor, 0}; + return TVMNDArray_RandomFill(&arr); +} + tvm_crt_error_t TVMInitializeRuntime() { int idx = 0; tvm_crt_error_t error = kTvmErrorNoError; @@ -440,7 +454,7 @@ tvm_crt_error_t TVMInitializeRuntime() { } if (error == kTvmErrorNoError) { - error = TVMFuncRegisterGlobal("tvm.contrib.random.random_fill", &TVMContribRandomFill, 0); + error = TVMFuncRegisterGlobal("tvm.contrib.random.random_fill", &RandomFill, 0); } if (error != kTvmErrorNoError) { @@ -590,20 +604,3 @@ __attribute__((weak)) tvm_crt_error_t TVMPlatformBeforeMeasurement() { return kT // Default implementation, overridden by the platform runtime. __attribute__((weak)) tvm_crt_error_t TVMPlatformAfterMeasurement() { return kTvmErrorNoError; } - -// Fill the tensor in args[0] with random data using TVMPlatformGenerateRandom. -// Named to correspond with the analogous function in the C++ runtime. -int TVMContribRandomFill(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val, - int* ret_type_code) { - if (num_args != 1) { - return kTvmErrorFunctionCallNumArguments; - } - - if (type_codes[0] != kTVMDLTensorHandle) { - return kTvmErrorFunctionCallWrongArgType; - } - - DLTensor* tensor = (DLTensor*)args[0].v_handle; - TVMNDArray arr = {*tensor}; - return TVMNDArray_RandomFill(&arr); -} diff --git a/src/runtime/crt/common/ndarray.c b/src/runtime/crt/common/ndarray.c index 16bde3227f7c..b0e869766bde 100644 --- a/src/runtime/crt/common/ndarray.c +++ b/src/runtime/crt/common/ndarray.c @@ -30,8 +30,8 @@ #include "crt_config.h" -int TVMNDArray_Create(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev, - TVMNDArray* array) { +static int Create(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev, + TVMNDArray* array) { memset(array, 0, sizeof(TVMNDArray)); array->dl_tensor.ndim = ndim; tvm_crt_error_t err; @@ -58,7 +58,7 @@ int64_t TVMNDArray_DataSizeBytes(TVMNDArray* array) { int TVMNDArray_Empty(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev, TVMNDArray* array) { - int status = TVMNDArray_Create(ndim, shape, dtype, dev, array); + int status = Create(ndim, shape, dtype, dev, array); if (status != 0) { return status; } @@ -132,7 +132,7 @@ int TVMNDArray_Load(TVMNDArray* ret, const char** strm) { int TVMNDArray_CreateView(TVMNDArray* arr, const tvm_index_t* shape, int32_t ndim, DLDataType dtype, TVMNDArray* array_view) { - int status = TVMNDArray_Create(ndim, shape, dtype, arr->dl_tensor.device, array_view); + int status = Create(ndim, shape, dtype, arr->dl_tensor.device, array_view); if (status != 0) { return status; } @@ -149,21 +149,35 @@ int TVMNDArray_RandomFill(TVMNDArray* arr) { return TVMPlatformGenerateRandom(arr->dl_tensor.data, (size_t)num_bytes); } +void TVMNDArray_IncrementReference(TVMNDArray* arr) { arr->reference_count++; } + +uint32_t TVMNDArray_DecrementReference(TVMNDArray* arr) { + if (arr->reference_count > 0) { + arr->reference_count--; + } + + return arr->reference_count; +} + int TVMNDArray_Release(TVMNDArray* arr) { tvm_crt_error_t err; DLDevice dev = {kDLCPU, 0}; + if (TVMNDArray_DecrementReference(arr) > 0) { + return 0; + } + err = TVMPlatformMemoryFree(arr->dl_tensor.data, dev); if (err != kTvmErrorNoError) { return err; } + arr->dl_tensor.data = NULL; - arr->dl_tensor.data = 0; err = TVMPlatformMemoryFree(arr->dl_tensor.shape, dev); if (err != kTvmErrorNoError) { return err; } + arr->dl_tensor.shape = NULL; - arr->dl_tensor.shape = 0; return 0; } diff --git a/src/runtime/crt/graph_executor/graph_executor.c b/src/runtime/crt/graph_executor/graph_executor.c index 3fea408d9760..395a343ccb41 100644 --- a/src/runtime/crt/graph_executor/graph_executor.c +++ b/src/runtime/crt/graph_executor/graph_executor.c @@ -1014,7 +1014,7 @@ int TVMGraphExecutor_SetupStorage(TVMGraphExecutor* executor) { executor->storage_pool_count++; } - // Assign the pooled entries. A unified memory pool is used to simplifiy + // Assign the pooled entries. A unified memory pool is used to simplify // memory assignment for each node entry. The allocated memory on each device // is mapped to this pool. executor->data_entry_count = executor->node_row_ptr[executor->node_row_ptr_count - 1]; @@ -1031,6 +1031,8 @@ int TVMGraphExecutor_SetupStorage(TVMGraphExecutor* executor) { attrs->shape + idx * TVM_CRT_MAX_NDIM, attrs->ndim[idx], vtype[idx], &executor->data_entry[idx]); CHECK_EQ(status, 0, "fail to create for node with idx=%d, storage_id=%u\n", idx, storage_id); + + TVMNDArray_IncrementReference(&executor->data_entry[idx]); } // Release memory diff --git a/src/runtime/crt/graph_executor_module/graph_executor_module.c b/src/runtime/crt/graph_executor_module/graph_executor_module.c index 0ae12f5a9e0a..559b6896a55e 100644 --- a/src/runtime/crt/graph_executor_module/graph_executor_module.c +++ b/src/runtime/crt/graph_executor_module/graph_executor_module.c @@ -95,7 +95,12 @@ int32_t TVMGraphExecutorModule_GetInput(TVMValue* args, int* tcodes, int nargs, uint32_t eid = TVMGraphExecutor_GetEntryId(graph_executor.executor, graph_executor.executor->input_nodes[index], 0); - ret_values[0].v_handle = (void*)&graph_executor.executor->data_entry[eid].dl_tensor; + + TVMNDArray* array = &graph_executor.executor->data_entry[eid]; + + TVMNDArray_IncrementReference(array); + + ret_values[0].v_handle = (void*)(&array->dl_tensor); ret_tcodes[0] = kTVMNDArrayHandle; return 0; } @@ -158,7 +163,11 @@ int32_t TVMGraphExecutorModule_GetOutput(TVMValue* args, int* tcodes, int nargs, uint32_t index = graph_executor.executor->outputs[output_index].index; uint32_t eid = TVMGraphExecutor_GetEntryId(graph_executor.executor, nid, index); - ret_values[0].v_handle = (void*)&(graph_executor.executor->data_entry[eid].dl_tensor); + TVMNDArray* array = &graph_executor.executor->data_entry[eid]; + + TVMNDArray_IncrementReference(array); + + ret_values[0].v_handle = (void*)(&array->dl_tensor); ret_tcodes[0] = kTVMNDArrayHandle; return 0; } diff --git a/src/runtime/crt/host/main.cc b/src/runtime/crt/host/main.cc index bf4a98569e33..d8fa95fe236b 100644 --- a/src/runtime/crt/host/main.cc +++ b/src/runtime/crt/host/main.cc @@ -139,9 +139,6 @@ int main(int argc, char** argv) { "failed to register GraphExecutor TVMModule"); #endif - CHECK_EQ(TVMAotExecutorModule_Register(), kTvmErrorNoError, - "failed to register AoT Executor TVMModule"); - int error = TVMFuncRegisterGlobal("tvm.testing.reset_server", (TVMFunctionHandle)&testonly_reset_server, 0); if (error) { diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h b/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h index e5869ed2a303..0162c6eb4de6 100644 --- a/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h +++ b/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h @@ -38,7 +38,11 @@ static const uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F; static const uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7; typedef struct TVMNDArray { + /*! \brief the actual tensor in DLPack format. NOTE: this must be first element in struct */ DLTensor dl_tensor; + + /*! \brief count of references to TVMNDArray to avoid early freeing by host */ + uint32_t reference_count; } TVMNDArray; int TVMNDArray_Create(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev, @@ -56,6 +60,10 @@ int TVMNDArray_Load(TVMNDArray* ret, const char** strm); int TVMNDArray_CreateView(TVMNDArray* arr, const tvm_index_t* shape, int32_t ndim, DLDataType dtype, TVMNDArray* array_view); +void TVMNDArray_IncrementReference(TVMNDArray* arr); + +uint32_t TVMNDArray_DecrementReference(TVMNDArray* arr); + int TVMNDArray_Release(TVMNDArray* arr); #endif // TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_COMMON_NDARRAY_H_ diff --git a/src/runtime/crt/microtvm_rpc_server/rpc_server.cc b/src/runtime/crt/microtvm_rpc_server/rpc_server.cc index b7bae243ecf0..1e5f625998ab 100644 --- a/src/runtime/crt/microtvm_rpc_server/rpc_server.cc +++ b/src/runtime/crt/microtvm_rpc_server/rpc_server.cc @@ -33,6 +33,7 @@ #define DMLC_CMAKE_LITTLE_ENDIAN DMLC_IO_USE_LITTLE_ENDIAN #define DMLC_LITTLE_ENDIAN 1 #include +#include #include #include #include @@ -207,6 +208,11 @@ microtvm_rpc_server_t MicroTVMRpcServerInit(microtvm_rpc_channel_write_t write_f TVMPlatformAbort(err); } + err = TVMAotExecutorModule_Register(); + if (err != kTvmErrorNoError) { + TVMPlatformAbort(err); + } + DLDevice dev = {kDLCPU, 0}; void* receive_buffer_memory; err = TVMPlatformMemoryAllocate(TVM_CRT_MAX_PACKET_SIZE_BYTES, dev, &receive_buffer_memory); diff --git a/src/runtime/graph_executor/graph_executor.h b/src/runtime/graph_executor/graph_executor.h index 25b01a253c7d..2564f5b0d924 100644 --- a/src/runtime/graph_executor/graph_executor.h +++ b/src/runtime/graph_executor/graph_executor.h @@ -61,7 +61,7 @@ struct TVMOpParam { /*! * \brief Tiny graph executor. * - * This runtime can be acccesibly in various language via + * This runtime can be accessible in various languages via * TVM runtime PackedFunc API. */ class TVM_DLL GraphExecutor : public ModuleNode { diff --git a/tests/micro/zephyr/conftest.py b/tests/micro/zephyr/conftest.py index 997237d370a5..c4de48a0a47a 100644 --- a/tests/micro/zephyr/conftest.py +++ b/tests/micro/zephyr/conftest.py @@ -30,7 +30,7 @@ def pytest_addoption(parser): "--zephyr-board", required=True, choices=test_utils.ZEPHYR_BOARDS.keys(), - help=("Zephyr board for test."), + help="Zephyr board for test.", ) parser.addoption( "--west-cmd", default="west", help="Path to `west` command for flashing device." @@ -92,5 +92,5 @@ def skip_by_board(request, board): def pytest_configure(config): config.addinivalue_line( "markers", - "skip_by_board(board): skip test for the given board", + "skip_boards(board): skip test for the given board", ) diff --git a/tests/micro/zephyr/test_zephyr_aot_exec.py b/tests/micro/zephyr/test_zephyr_aot_exec.py new file mode 100644 index 000000000000..1add0063bc9c --- /dev/null +++ b/tests/micro/zephyr/test_zephyr_aot_exec.py @@ -0,0 +1,157 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +import os +import pathlib +import sys +import logging + +import pytest +import numpy as np + +import onnx +from PIL import Image + +import tvm +import tvm.testing +import tvm.relay as relay +from tvm.relay.backend import Executor, Runtime +from tvm.relay.testing import byoc +from tvm.contrib import utils +from tvm.micro.testing.utils import check_tune_log +from tvm._ffi import get_global_func, register_func + +import test_utils + + +def _make_session(temp_dir, zephyr_board, west_cmd, mod, build_config): + config_main_stack_size = None + if test_utils.qemu_boards(zephyr_board): + # fyi: qemu_riscv64 seems to be the greediest stack user + config_main_stack_size = 4096 + + project_options = { + "project_type": "host_driven", + "west_cmd": west_cmd, + "verbose": bool(build_config.get("debug")), + "zephyr_board": zephyr_board, + } + if config_main_stack_size is not None: + project_options["config_main_stack_size"] = config_main_stack_size + + project = tvm.micro.generate_project( + str(test_utils.TEMPLATE_PROJECT_DIR), + mod, + temp_dir / "project", + project_options, + ) + project.build() + project.flash() + return tvm.micro.Session(project.transport()) + + +@tvm.testing.requires_micro +def test_relay(temp_dir, board, west_cmd, tvm_debug): + """Testing a simple relay graph""" + + model = test_utils.ZEPHYR_BOARDS[board] + build_config = {"debug": tvm_debug} + shape = (10,) + dtype = "int8" + + # Construct Relay program. + x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) + xx = relay.multiply(x, x) + z = relay.add(xx, relay.const(np.ones(shape=shape, dtype=dtype))) + func = relay.Function([x], z) + ir_mod = tvm.IRModule.from_expr(func) + + runtime = Runtime("crt", {"system-lib": True}) + executor = Executor("aot") + target = tvm.target.target.micro(model) + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = tvm.relay.build(ir_mod, target=target, runtime=runtime, executor=executor) + + with _make_session(temp_dir, board, west_cmd, mod, build_config) as session: + + aot_executor = tvm.runtime.executor.aot_executor.AotModule(session.create_aot_executor()) + + x_in = np.random.randint(10, size=shape[0], dtype=dtype) + aot_executor.run(x=x_in) + result = aot_executor.get_output(0).numpy() + tvm.testing.assert_allclose(aot_executor.get_input(0).numpy(), x_in) + tvm.testing.assert_allclose(result, x_in * x_in + 1) + + +@tvm.testing.requires_micro +def test_aot_executor(temp_dir, board, west_cmd, tvm_debug): + """Test use of the AOT executor with microTVM.""" + + model = test_utils.ZEPHYR_BOARDS[board] + build_config = {"debug": tvm_debug} + shape = (10,) + dtype = "int8" + + print("test_relay: construct relay program\n") + + # Construct Relay program. + relay_mod = tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8]) { + %0 = %a + %b; + %0 + }""" + ) + + runtime = Runtime("crt", {"system-lib": True}) + executor = Executor("aot") + target = tvm.target.target.micro(model) + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = tvm.relay.build(relay_mod, target=target, runtime=runtime, executor=executor) + + def do_test(): + + aot_executor = tvm.runtime.executor.aot_executor.AotModule(session.create_aot_executor()) + + assert aot_executor.get_input_index("a") == 0 + assert aot_executor.get_input_index("b") == 1 + + assert aot_executor.get_num_inputs() == 2 + assert aot_executor.get_num_outputs() == 1 + + A_np = np.array([[2, 3]], dtype="uint8") + B_np = np.array([[4, 7]], dtype="uint8") + + A_data = aot_executor.get_input("a").copyfrom(A_np) + B_data = aot_executor.get_input("b").copyfrom(B_np) + + aot_executor.run() + + out = aot_executor.get_output(0) + assert (out.numpy() == np.array([6, 10])).all() + + B_np_new = np.array([[5, 8]]) + aot_executor.set_input("b", B_np_new) + assert (B_data.numpy() == B_np_new).all() + + with _make_session(temp_dir, board, west_cmd, mod, build_config) as session: + do_test() + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/micro/zephyr/test_zephyr_aot.py b/tests/micro/zephyr/test_zephyr_aot_exec_standalone.py similarity index 100% rename from tests/micro/zephyr/test_zephyr_aot.py rename to tests/micro/zephyr/test_zephyr_aot_exec_standalone.py