diff --git a/.github/workflows/compiler-build.yml b/.github/workflows/compiler-build.yml index c5c082e802..3a3e600c21 100644 --- a/.github/workflows/compiler-build.yml +++ b/.github/workflows/compiler-build.yml @@ -17,7 +17,7 @@ jobs: strategy: matrix: config: - - {name: x86_64-macos, os: macos-latest, cmakeArgs: -DENABLE_X86SIMD=OFF, buildType: Release} + - {name: aarch64-macos, os: macos-14, cmakeArgs: '', buildType: Release} - {name: x86_64-linux, os: ubuntu-latest, cmakeArgs: '', buildType: Release} - {name: x86_64-windows, os: windows-latest, arch: x64, cmakeArgs: -DCMAKE_C_COMPILER=clang-cl -DCMAKE_CXX_COMPILER=clang-cl, buildType: Release} @@ -25,22 +25,17 @@ jobs: - uses: actions/checkout@v3 - uses: seanmiddleditch/gha-setup-ninja@master + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + - name: Set up build environment (Windows, Visual Studio) uses: ilammy/msvc-dev-cmd@v1 with: arch: ${{matrix.config.arch}} if: runner.os == 'Windows' - - name: Set up build environment (Macos) - run: | - brew install sunnycase/core/libomp@11.1.0 - if: runner.os == 'Macos' - - - name: Setup Python - uses: actions/setup-python@v4 - with: - python-version: 3.7 - - name: Install Conan shell: bash run: | @@ -54,6 +49,13 @@ jobs: echo "CXX=g++-10" >> $GITHUB_ENV if: runner.os == 'Linux' + - name: Configure Conan (Macos) + run: | + conan config init + sed -i '' 's/xtensalx7]/xtensalx7, arm64]/g' ~/.conan/settings.yml + sed -i '' 's/"14.0"]/"14.0", "15"]/g' ~/.conan/settings.yml + if: runner.os == 'Macos' + - name: Configure CMake shell: bash run: | @@ -79,12 +81,14 @@ jobs: matrix: dotnet-version: ['7.0'] config: - - {name: x86_64-macos, os: macos-latest, shell: bash, rid: osx-x64, buildType: Release} + - {name: aarch64-macos, os: macos-14, shell: bash, rid: osx-arm64, buildType: Release} - {name: x86_64-linux, os: ubuntu-latest, shell: bash, rid: linux-x64, buildType: Release} - - {name: x86_64-windows, os: windows-latest, shell: bash, rid: win-x64, buildType: Release} + - {name: x86_64-windows, os: windows-latest, arch: x64, shell: bash, rid: win-x64, buildType: Release} steps: - uses: actions/checkout@v2 + - uses: seanmiddleditch/gha-setup-ninja@master + - name: Setup .NET uses: actions/setup-dotnet@v1 with: @@ -104,11 +108,6 @@ jobs: name: nncase-native-${{matrix.config.name}} path: ${{github.workspace}}/install - - name: Set up build environment (Macos) - run: | - brew install sunnycase/core/libomp@11.1.0 - if: runner.os == 'Macos' - - name: Build run: | dotnet restore -r ${{matrix.config.rid}} @@ -142,7 +141,7 @@ jobs: working-directory: ${{github.workspace}} run: | dotnet tool install --global dotnet-coverage - dotnet-coverage collect -s tools/dotnet_coverage.settings.xml -f cobertura -o coverage/unit.xml "dotnet test -c ${{matrix.config.buildType}} -s test.runsettings --no-build --verbosity normal" + dotnet-coverage collect -s tools/dotnet_coverage.settings.xml -f cobertura -o coverage/unit.xml "dotnet test -c ${{matrix.config.buildType}} -s test.runsettings --no-build --verbosity normal --blame" dotnet-coverage merge -o coverage.unit.xml -f cobertura -r coverage/*.xml - name: Upload Coverage @@ -168,20 +167,29 @@ jobs: matrix: dotnet-version: ['7.0'] config: - - {name: x86_64-macos, os: macos-latest, shell: bash} + - {name: aarch64-macos, os: macos-14, shell: bash} - {name: x86_64-linux, os: ubuntu-latest, shell: bash} - - {name: x86_64-windows, os: windows-latest, shell: bash} + - {name: x86_64-windows, os: windows-latest, arch: x64, shell: bash} env: - 
VULKANSDK_VER: 1.3.268.0 + VULKANSDK_VER: 1.3.280.0 steps: - uses: actions/checkout@v3 + - uses: seanmiddleditch/gha-setup-ninja@master + - name: Setup .NET uses: actions/setup-dotnet@v1 with: dotnet-version: ${{matrix.dotnet-version}} + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + cache: 'pip' + cache-dependency-path: '**/requirements.test.txt' + - name: Install nncase native Artifact uses: actions/download-artifact@v3 with: @@ -196,16 +204,11 @@ jobs: - name: Set up test environment (macOS) run: | - brew install sunnycase/core/libomp@11.1.0 - aria2c --parameterized-uri=true https://{sdk.lunarg.com/sdk/download/${VULKANSDK_VER}/mac,distfiles.macports.org/MoltenVK}/vulkansdk-macos-${VULKANSDK_VER}.dmg + aria2c --parameterized-uri=true https://sdk.lunarg.com/sdk/download/${VULKANSDK_VER}/mac/vulkansdk-macos-${VULKANSDK_VER}.dmg hdiutil attach ./vulkansdk-macos-*.dmg sudo /Volumes/vulkansdk-macos-*/InstallVulkan.app/Contents/MacOS/InstallVulkan --root $HOME/VulkanSDK --accept-licenses --default-answer --confirm-command install hdiutil detach /Volumes/vulkansdk-macos-* echo "VULKAN_SDK=$HOME/VulkanSDK/macOS" >> $GITHUB_ENV - wget https://github.com/sunnycase/swiftshader/releases/download/v1.0/swiftshader-macos-10.15-x86_64.zip -O swiftshader.zip - unzip swiftshader.zip - sudo cmake -E make_directory /usr/local/share/vulkan/icd.d - sudo cp lib/* /usr/local/share/vulkan/icd.d cp install/lib/*.dylib install/ echo "PYTHONPATH=$GITHUB_WORKSPACE/install/lib:$GITHUB_WORKSPACE/install/python:$GITHUB_WORKSPACE/tests" >> $GITHUB_ENV if: runner.os == 'macOS' @@ -232,18 +235,12 @@ jobs: Expand-Archive swiftshader.zip Copy-Item swiftshader\lib\vk_swiftshader_icd.json swiftshader\bin\ Copy-Item install/bin/*.dll install/ + Copy-Item install/bin/*.dll install/lib/ echo "VK_ICD_FILENAMES=${env:GITHUB_WORKSPACE}/swiftshader/bin/vk_swiftshader_icd.json" >> $env:GITHUB_ENV echo "PYTHONPATH=${env:GITHUB_WORKSPACE}/install/lib;${env:GITHUB_WORKSPACE}/install/python;${env:GITHUB_WORKSPACE}/tests" >> $env:GITHUB_ENV echo "PATH=${env:PATH};${env:GITHUB_WORKSPACE}/install/bin" >> $env:GITHUB_ENV if: runner.os == 'Windows' - - name: Setup Python - uses: actions/setup-python@v4 - with: - python-version: 3.7 - cache: 'pip' - cache-dependency-path: '**/requirements.test.txt' - - name: Install Python Packages run: python -m pip install --upgrade pip @@ -263,7 +260,7 @@ jobs: dotnet-coverage collect -s tools/dotnet_coverage.settings.xml -f cobertura -o coverage/onnx_combine.xml pytest tests/importer/onnx_/combine/ --doctest-modules --junitxml=test_results/onnx_combine.xml dotnet-coverage collect -s tools/dotnet_coverage.settings.xml -f cobertura -o coverage/tflite_basic.xml pytest tests/importer/tflite_/basic/ --doctest-modules --junitxml=test_results/tflite_basic.xml dotnet-coverage collect -s tools/dotnet_coverage.settings.xml -f cobertura -o coverage/tflite_combine.xml pytest tests/importer/tflite_/combine/ --doctest-modules --junitxml=test_results/tflite_combine.xml - dotnet-coverage collect -s tools/dotnet_coverage.settings.xml -f cobertura -o coverage/tflite_model.xml pytest tests/importer/tflite_/model/ --doctest-modules --junitxml=test_results/tflite_model.xml + #dotnet-coverage collect -s tools/dotnet_coverage.settings.xml -f cobertura -o coverage/tflite_model.xml pytest tests/importer/tflite_/model/ --doctest-modules --junitxml=test_results/tflite_model.xml dotnet-coverage collect -s tools/dotnet_coverage.settings.xml -f cobertura -o coverage/ncnn_basic.xml pytest 
tests/importer/ncnn_/basic/ --doctest-modules --junitxml=test_results/ncnn_basic.xml dotnet-coverage merge -o coverage.integration.xml -f cobertura -r coverage/*.xml @@ -327,4 +324,4 @@ jobs: with: name: nncase-coverage-report path: coveragereport - if-no-files-found: error \ No newline at end of file + if-no-files-found: error diff --git a/.github/workflows/compiler-python-release.yml b/.github/workflows/compiler-python-release.yml index 5e0db927a0..1bf17c17e0 100644 --- a/.github/workflows/compiler-python-release.yml +++ b/.github/workflows/compiler-python-release.yml @@ -14,7 +14,7 @@ jobs: matrix: dotnet-version: ['7.0'] config: - - {name: x86_64-macos, os: macos-latest, shell: bash, rid: osx-x64, buildType: Release} + # - {name: aarch64-macos, os: macos-14, shell: bash, rid: osx-arm64, buildType: Release} - {name: x86_64-linux, os: ubuntu-latest, shell: bash, rid: linux-x64, buildType: Release} - {name: x86_64-windows, os: windows-latest, shell: bash, rid: win-x64, buildType: Release} @@ -53,7 +53,7 @@ jobs: matrix: dotnet-version: ['7.0'] config: - - {name: x86_64-macos, os: macos-latest} + # - {name: aarch64-macos, os: macos-14} - {name: x86_64-linux, os: ubuntu-latest} - {name: x86_64-windows, os: windows-latest, arch: x64} @@ -88,7 +88,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: 3.7 + python-version: '3.10' - name: Install cibuildwheel run: pip install cibuildwheel diff --git a/.github/workflows/jupyter-test.yml b/.github/workflows/jupyter-test.yml index 1d2ee23550..19a74c8086 100755 --- a/.github/workflows/jupyter-test.yml +++ b/.github/workflows/jupyter-test.yml @@ -10,7 +10,7 @@ jobs: strategy: matrix: config: - - {name: x86_64-macos, os: macos-latest} + - {name: aarch64-macos, os: macos-14} - {name: x86_64-linux, os: ubuntu-latest} - {name: x86_64-windows, os: windows-latest} @@ -20,7 +20,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: '3.10' - name: Install dependencies run: pip install --upgrade pip && pip install jupyterlab pytest nbmake diff --git a/.github/workflows/runtime-build.yml b/.github/workflows/runtime-build.yml index c11d287f2d..228e74224c 100644 --- a/.github/workflows/runtime-build.yml +++ b/.github/workflows/runtime-build.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: config: - - { name: x86_64-macos, os: macos-latest, cmakeArgs: '', buildType: Release } + #- { name: aarch64-macos, os: macos-14, cmakeArgs: '', buildType: Release } - { name: x86_64-linux, os: ubuntu-latest, cmakeArgs: '', buildType: Release } - { name: x86_64-windows, os: windows-latest, arch: x64, cmakeArgs: -DCMAKE_C_COMPILER=clang-cl -DCMAKE_CXX_COMPILER=clang-cl, buildType: Release } @@ -27,15 +27,10 @@ jobs: arch: ${{matrix.config.arch}} if: runner.os == 'Windows' - - name: Set up build environment (Macos) - run: | - brew install sunnycase/core/libomp@11.1.0 - if: runner.os == 'Macos' - - name: Setup Python uses: actions/setup-python@v4 with: - python-version: 3.7 + python-version: '3.10' - name: Install Conan run: | @@ -51,10 +46,17 @@ jobs: echo "CXX=g++-10" >> $GITHUB_ENV if: runner.os == 'Linux' + - name: Configure Conan (Macos) + run: | + conan config init + sed -i '' 's/xtensalx7]/xtensalx7, arm64]/g' ~/.conan/settings.yml + sed -i '' 's/"14.0"]/"14.0", "15"]/g' ~/.conan/settings.yml + if: runner.os == 'Macos' + - name: Configure CMake shell: bash run: | - conan install . 
-if build --build=missing -s build_type=${{matrix.config.buildType}} --profile=default -o runtime=True -o python=False -o tests=True -s compiler.cppstd=17 + conan install . -if build --build=missing -s build_type=${{matrix.config.buildType}} --profile=default -o runtime=True -o python=False -o tests=True -s compiler.cppstd=20 - name: Build & Install run: | @@ -101,7 +103,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: 3.7 + python-version: '3.10' - name: Install toolchain and QEMU shell: bash @@ -129,7 +131,7 @@ jobs: - name: Configure CMake run: | - conan install . -if build --build=missing -s build_type=${{matrix.config.buildType}} --profile:host=toolchains/riscv64-unknown-linux.profile.jinja --profile:build=default -o runtime=True -o python=False -o tests=True -s compiler.cppstd=17 + conan install . -if build --build=missing -s build_type=${{matrix.config.buildType}} --profile:host=toolchains/riscv64-unknown-linux.profile.jinja --profile:build=default -o runtime=True -o python=False -o tests=True -s compiler.cppstd=20 - name: Build & Install run: | diff --git a/.gitignore b/.gitignore index 5b1e72c18f..eaffc2eb90 100644 --- a/.gitignore +++ b/.gitignore @@ -261,6 +261,7 @@ __pycache__/ # vscode .vscode/ +.mono/ # clangd .cache/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 7ac7539a47..c5c4cd42a2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,8 +39,6 @@ project(nncase VERSION ${NNCASE_VERSION} LANGUAGES C CXX ASM) -option(ENABLE_OPENMP "OpenMP support" ON) -option(ENABLE_HALIDE "halide kernels support" ON) option(DOTNET_INIT_FOR_CONFIG "Initialize dotnet from runtimeconfig" OFF) option(BUILD_PYTHON_BINDING "Build python binding" ON) option(BUILD_CSHARP_BINDING "Build csharp binding" ON) @@ -106,7 +104,7 @@ if (BUILDING_RUNTIME) else() add_compile_options(-Wall -Wextra -pedantic -Werror -Wno-multichar -Wno-missing-field-initializers -Wno-unused-function -Wno-type-limits) if (APPLE) - add_compile_options(-Wno-four-char-constants -Wno-sometimes-uninitialized) + add_compile_options(-Wno-four-char-constants -Wno-sometimes-uninitialized -Wno-deprecated-declarations) elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") add_compile_options(-Wno-uninitialized -Wno-unused-private-field) else() @@ -124,6 +122,9 @@ if (BUILDING_RUNTIME) # add_subdirectory(src/Native/src/kernels) # add_subdirectory(src/Native/src/runtime) add_subdirectory(src/Native/src) + if(BUILD_TESTING) + add_subdirectory(src/Native/test) + endif() # add_subdirectory(src/Native/src/functional) if(BUILD_BENCHMARK) # add_subdirectory(benchmark) @@ -214,7 +215,9 @@ else() add_subdirectory(src/Native/include/nncase) add_subdirectory(src/Native/src) - +if(BUILD_TESTING) + add_subdirectory(src/Native/test) +endif() # Python binding if(BUILD_PYTHON_BINDING) add_subdirectory(python/nncase/native) diff --git a/Directory.Packages.props b/Directory.Packages.props index 3bb76c949a..2aa955d89e 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -12,13 +12,12 @@ true - - - - - + + + + + - @@ -43,26 +42,29 @@ 1.1.1 - - - - - - - - + + + + + + + + + + + - + @@ -74,7 +76,6 @@ - diff --git a/NuGet.Config b/NuGet.Config index fd11e2a06a..5e7849eabb 100644 --- a/NuGet.Config +++ b/NuGet.Config @@ -2,11 +2,13 @@ + + diff --git a/benchmark/models/models.cpp b/benchmark/models/models.cpp index b7b7239cfb..0b59b54b37 100644 --- a/benchmark/models/models.cpp +++ b/benchmark/models/models.cpp @@ -23,7 +23,7 @@ using namespace nncase; namespace { -gsl::span get_model_impl(const std::string 
&name, size_t id) +std::span get_model_impl(const std::string &name, size_t id) { auto hres = FindResourceW(NULL, MAKEINTRESOURCEW(id), L"Binary"); if (!hres) @@ -33,7 +33,7 @@ gsl::span get_model_impl(const std::string &name, size_t id) if (!hmem) return {}; auto res_data = LockResource(hmem); - return { reinterpret_cast(res_data), (size_t)size }; + return { reinterpret_cast(res_data), (size_t)size }; } } @@ -41,7 +41,7 @@ gsl::span get_model_impl(const std::string &name, size_t id) if (name == #model) \ return get_model_impl(name, IDR_cpu_##model) -gsl::span nncase::get_model(const std::string &name) +std::span nncase::get_model(const std::string &name) { GET_MODEL_IMPL(mnist); GET_MODEL_IMPL(mobilenet_v2); @@ -55,9 +55,9 @@ INCBIN(mobilenet_v2, "cpu/mobilenet_v2.kmodel"); #define GET_MODEL_IMPL(model) \ if (name == #model) \ - return { reinterpret_cast(g##model##_data), g##model##_size } + return { reinterpret_cast(g##model##_data), g##model##_size } -gsl::span nncase::get_model(const std::string &name) +std::span nncase::get_model(const std::string &name) { GET_MODEL_IMPL(mnist); GET_MODEL_IMPL(mobilenet_v2); diff --git a/benchmark/models/models.h b/benchmark/models/models.h index 7ee9ce92c2..56096be505 100644 --- a/benchmark/models/models.h +++ b/benchmark/models/models.h @@ -17,5 +17,5 @@ namespace nncase { -gsl::span get_model(const std::string &name); +std::span get_model(const std::string &name); } diff --git a/cmake/configure-conan.cmake b/cmake/configure-conan.cmake index e5b75ca340..63662a9fd0 100644 --- a/cmake/configure-conan.cmake +++ b/cmake/configure-conan.cmake @@ -14,16 +14,10 @@ endfunction() _SET_CONANOPT(CONAN_OPTS "runtime" BUILDING_RUNTIME) _SET_CONANOPT(CONAN_OPTS "tests" BUILD_TESTING) _SET_CONANOPT(CONAN_OPTS "python" BUILD_PYTHON_BINDING) -_SET_CONANOPT(CONAN_OPTS "openmp" ENABLE_OPENMP) _SET_CONANOPT(CONAN_OPTS "vulkan_runtime" ENABLE_VULKAN_RUNTIME) -_SET_CONANOPT(CONAN_OPTS "halide" ENABLE_HALIDE) if (NOT DEFINED CMAKE_CXX_STANDARD) - if (BUILDING_RUNTIME) - set (CMAKE_CXX_STANDARD 17) - else () - set (CMAKE_CXX_STANDARD 20) - endif () + set (CMAKE_CXX_STANDARD 20) endif () _SET_CONANSETTING(CONAN_SETTINGS "compiler.cppstd" ${CMAKE_CXX_STANDARD}) diff --git a/cmake/dependencies.cmake b/cmake/dependencies.cmake index 85e8e1213e..2827003ab2 100644 --- a/cmake/dependencies.cmake +++ b/cmake/dependencies.cmake @@ -1,25 +1,13 @@ -find_package(gsl-lite REQUIRED) if (ENABLE_OPENMP) find_package(OpenMP COMPONENTS CXX REQUIRED) endif () -if ((NOT BUILDING_RUNTIME) OR ENABLE_VULKAN_RUNTIME) - find_package(Vulkan REQUIRED) -endif () - if (NOT BUILDING_RUNTIME) - find_package(absl REQUIRED) find_package(nethost REQUIRED) find_package(fmt REQUIRED) - find_package(magic_enum REQUIRED) - find_package(spdlog REQUIRED) - find_package(inja REQUIRED) + find_package(nlohmann_json REQUIRED) endif () if (BUILD_TESTING) find_package(GTest REQUIRED) endif () - -if (ENABLE_HALIDE) - find_package(hkg REQUIRED) -endif () \ No newline at end of file diff --git a/cmake/nncaseConfig.cmake.in b/cmake/nncaseConfig.cmake.in index 7d1a54245e..bf853ae583 100644 --- a/cmake/nncaseConfig.cmake.in +++ b/cmake/nncaseConfig.cmake.in @@ -1,3 +1,2 @@ include(${CMAKE_CURRENT_LIST_DIR}/nncaseTargets.cmake) -find_package(gsl-lite REQUIRED) find_package(fmt REQUIRED) diff --git a/cmake/nncaseruntimeConfig.cmake.in b/cmake/nncaseruntimeConfig.cmake.in index cce5810298..b4500a2ae9 100644 --- a/cmake/nncaseruntimeConfig.cmake.in +++ b/cmake/nncaseruntimeConfig.cmake.in @@ -1,5 +1 @@ 
include(${CMAKE_CURRENT_LIST_DIR}/nncaseruntimeTargets.cmake) - -if(NOT TARGET gsl-lite) - find_package(gsl-lite REQUIRED) -endif() \ No newline at end of file diff --git a/conanfile.py b/conanfile.py index 8a3a0c72b7..9a4200dadd 100644 --- a/conanfile.py +++ b/conanfile.py @@ -24,20 +24,16 @@ class nncaseConan(ConanFile): "fPIC": [True, False], "runtime": [True, False], "tests": [True, False], - "halide": [True, False], "python": [True, False], - "vulkan_runtime": [True, False], - "openmp": [True, False] + "vulkan_runtime": [True, False] } default_options = { "shared": False, "fPIC": True, "runtime": False, "tests": False, - "halide": True, "python": True, - "vulkan_runtime": False, - "openmp": True + "vulkan_runtime": False } def imports(self): @@ -46,67 +42,42 @@ def imports(self): self.copy("ortki.dll", "bin", "bin") def requirements(self): - self.requires('gsl-lite/0.37.0') - self.requires('hkg/0.0.1') if self.options.tests: self.requires('gtest/1.10.0') self.requires('ortki/0.0.2') self.requires('rapidjson/1.1.x') if self.options.python: - self.requires('pybind11/2.6.1') + self.requires('pybind11/2.11.1') if not self.options.runtime: - self.requires('abseil/20220623.1') - self.requires('nethost/6.0.11') + self.requires('nethost/7.0.5') self.requires('fmt/7.1.3') - self.requires('magic_enum/0.7.0') - self.requires('spdlog/1.8.2') - self.requires('inja/3.2.0') - if self.options.tests: - self.requires('gtest/1.10.0') - - if (not self.options.runtime) or self.options.vulkan_runtime: - self.requires('vulkan-headers/1.2.182') - self.requires('vulkan-loader/1.2.182') + self.requires('nlohmann_json/3.9.1') def build_requirements(self): pass def configure(self): - min_cppstd = "17" if self.options.runtime else "20" + min_cppstd = "20" tools.check_min_cppstd(self, min_cppstd) if self.settings.os == 'Windows': self.settings.compiler.toolset = 'ClangCL' - - if self.settings.arch not in ("x86_64",): - self.options.halide = False if not self.options.runtime: if self.settings.os == 'Windows': self.options["nethost"].shared = True - if (not self.options.runtime) or self.options.vulkan_runtime: - if self.settings.os == 'Linux': - self.options["vulkan-loader"].with_wsi_xcb = False - self.options["vulkan-loader"].with_wsi_xlib = False - self.options["vulkan-loader"].with_wsi_wayland = False - self.options["vulkan-loader"].with_wsi_directfb = False - if self.options.tests: self.options["ortki"].shared = True def cmake_configure(self): cmake = CMake(self) cmake.definitions['BUILDING_RUNTIME'] = self.options.runtime - cmake.definitions['ENABLE_OPENMP'] = self.options.openmp cmake.definitions['ENABLE_VULKAN_RUNTIME'] = self.options.vulkan_runtime - cmake.definitions['ENABLE_HALIDE'] = self.options.halide cmake.definitions['BUILD_PYTHON_BINDING'] = self.options.python cmake.definitions['BUILD_TESTING'] = self.options.tests - if self.options.runtime: - cmake.definitions["CMAKE_CXX_STANDARD"] = 17 cmake.configure() return cmake diff --git a/csharp/RuntimeTensor.h b/csharp/RuntimeTensor.h index d25c52a565..0b6b9fdf66 100644 --- a/csharp/RuntimeTensor.h +++ b/csharp/RuntimeTensor.h @@ -94,8 +94,8 @@ RuntimeTensor_from_buffer(const uint8_t *buffer_ptr, datatype_t datatype, host_runtime_tensor::create( (datatype_t)datatype, to_shape(shape_ptr, shape_size), to_strides(stride_ptr, shape_size), - gsl::make_span((gsl::byte *)(buffer_ptr), total_items * item_size), - [=](gsl::byte *) {}) + gsl::make_span((std::byte *)(buffer_ptr), total_items * item_size), + [=](std::byte *) {}) .unwrap_or_throw(); auto rt = new 
runtime_tensor(std::move(hostrt)); return rt; diff --git a/csharp/interpreter.cpp b/csharp/interpreter.cpp index ebda591c3f..29cd9ba136 100644 --- a/csharp/interpreter.cpp +++ b/csharp/interpreter.cpp @@ -37,7 +37,7 @@ interpreter_init() { EXPORT_API(void) interpreter_load_model(uint8_t *buffer_ptr, int size) { auto buffer = - gsl::span((const gsl::byte *)(buffer_ptr), size); + std::span((const std::byte *)(buffer_ptr), size); _interp->load_model(buffer).unwrap_or_throw(); } diff --git a/modules/Nncase.Modules.CPU/CPUApplicationPart.cs b/modules/Nncase.Modules.CPU/CPUApplicationPart.cs new file mode 100644 index 0000000000..ecaeb388ad --- /dev/null +++ b/modules/Nncase.Modules.CPU/CPUApplicationPart.cs @@ -0,0 +1,31 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using DryIoc; +using Nncase.Hosting; + +namespace Nncase; + +/// +/// CPU application part extensions. +/// +public static class CPUApplicationPart +{ + /// + /// Add CPU assembly. + /// + /// Service registrator. + /// Configured service registrator. + public static IRegistrator AddCPU(this IRegistrator registrator) + { + return registrator.RegisterModule() + .RegisterModule() + .RegisterModule(); + } +} diff --git a/modules/Nncase.Modules.CPU/CPUModule.cs b/modules/Nncase.Modules.CPU/CPUModule.cs new file mode 100644 index 0000000000..5e91015cef --- /dev/null +++ b/modules/Nncase.Modules.CPU/CPUModule.cs @@ -0,0 +1,19 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using DryIoc; +using Nncase.Hosting; +using Nncase.Targets; + +namespace Nncase; + +/// +/// CPU module. +/// +internal class CPUModule : IApplicationPart +{ + public void ConfigureServices(IRegistrator registrator) + { + registrator.Register(reuse: Reuse.Singleton); + } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs new file mode 100644 index 0000000000..a9e75554b5 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs @@ -0,0 +1,59 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
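+ +// Note (illustrative sketch, not exhaustive): this class assembles the generated C++ sources for a CPU kernel. +// MakeKernel prepends KernelHeader to the kernel body, CMakeDef renders the Razor CMakeLists template, and MakeMain +// wraps the lowered PrimFunction in an entry point of roughly this shape (argument order taken from the code below): +// extern "C" void kernel_entry(nncase_runtime_cpu_mt_t *cpu_mt, uint8_t **inputs, uint8_t *rdata, uint8_t *l1_data); +// where each input/rdata pointer is first rebound to a std::span and a tensor_view before the kernel is invoked. +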
+using System.Runtime.CompilerServices; +using DryIoc.ImTools; +using NetFabric.Hyperlinq; +using Razor.Templating.Core; + +namespace Nncase.CodeGen.CPU; + +public static class CSourceBuiltn +{ + public const string KernelHeader = @"#pragma once +#include +using namespace nncase::ntt; + +"; + + public static string CMakeDef(string name) + { + var cmakePath = CMakePath(Path.Combine(Path.GetDirectoryName(typeof(CSourceBuiltn).Assembly.Location)!, "Runtime", "src", "cpu_runtime.cmake")); + var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/CMakeLists.txt.cshtml", new { CMakePath = cmakePath }).Result; + return content; + } + + public static string MakeKernel(string ctype, string kernelImpl) + { + return KernelHeader + ctype + kernelImpl; + } + + public static string MakeMain(TIR.PrimFunction primFunction, IEnumerable rdataBuffers) + { + string init_tensors = string.Join("\n", primFunction.Parameters.ToArray().Select((b, i) => + { + var buffer = (TIR.Buffer)b; + var size = TensorUtilities.GetSize(b.CheckedShape.ToValueArray(), TensorUtilities.GetStrides(b.CheckedShape.ToValueArray()), 1); + return $@" std::span<{buffer.ElemType.ToC()}, {size}> p{buffer.Name}(({buffer.ElemType.ToC()} *)inputs[{i}], {size}); + tensor_view<{buffer.ElemType.ToC()}, {KernelUtility.DimensionsToC(buffer.Dimensions)}, {KernelUtility.StridesToC(buffer.Strides)}> {buffer.Name}(p{buffer.Name}); +"; + }).Concat(rdataBuffers.Select(b => + { + var size = TensorUtilities.GetSize(b.CheckedShape.ToValueArray(), TensorUtilities.GetStrides(b.CheckedShape.ToValueArray()), 1); + return $@" std::span<{b.ElemType.ToC()}, {size}> p{b.Name}(({b.ElemType.ToC()}*)(rdata + {((IR.TensorConst)b.MemSpan.Start).Value.ToScalar()}), {size}); + tensor_view<{b.ElemType.ToC()}, {KernelUtility.DimensionsToC(b.Dimensions)}, {KernelUtility.StridesToC(b.Strides)}> {b.Name}(p{b.Name});"; + }))); + return @$"#include +#include ""../device.h"" +#include ""kernel.h"" + +extern ""C"" void kernel_entry(nncase_runtime_cpu_mt_t *cpu_mt, uint8_t **inputs, uint8_t *rdata, uint8_t *l1_data) {{ +g_cpu_mt = cpu_mt; +{init_tensors} + + {primFunction.Name}({string.Join(", ", primFunction.Parameters.AsValueEnumerable().Select(b => ((TIR.Buffer)b).Name).ToArray().Concat(rdataBuffers.Select(b => b.Name)).ToArray())}, l1_data); +}}"; + } + + private static string CMakePath(string path) => + path.Replace("\\", "/", StringComparison.Ordinal); +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceCompiler.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceCompiler.cs new file mode 100644 index 0000000000..929219da36 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceCompiler.cs @@ -0,0 +1,202 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using Nncase.IR; +using Nncase.Schedule; +using Nncase.TIR; + +namespace Nncase.CodeGen.CPU; + +/// +/// the csource code compiler. +/// +public class CSourceCompiler +{ + /// + /// compiler exe name. + /// + private string _exe = string.Empty; + + /// + /// compiler exe name. + /// + private string _arch = string.Empty; + + /// + /// compiler exe name. 
+ /// + private string _ext = string.Empty; + + public CSourceCompiler() + { + PlatformSpecific(); + ArchSpecific(); + } + + protected string Exe + { + get => _exe; + } + + protected string Arch + { + get => _arch; + } + + protected string Ext + { + get => _ext; + } + + /// + /// compile the source txt, write to the out_path. + /// + /// c source code. + /// out .so path. + /// outPath. + public string Compile(string sourcePath, string outPath) + { + var errMsg = new StringBuilder(); + using (var errWriter = new StringWriter(errMsg)) + { + using (var proc = new Process()) + { + proc.StartInfo.FileName = Exe; + proc.StartInfo.Arguments = ArgumentsSpecific(sourcePath, outPath); + proc.StartInfo.WorkingDirectory = Directory.GetCurrentDirectory(); + proc.StartInfo.RedirectStandardError = true; + proc.StartInfo.RedirectStandardOutput = true; + proc.OutputDataReceived += (sender, e) => errWriter.WriteLine(e.Data); + proc.ErrorDataReceived += (sender, e) => errWriter.WriteLine(e.Data); + proc.Start(); + proc.BeginErrorReadLine(); + proc.BeginOutputReadLine(); + proc.WaitForExit(); + if (proc.ExitCode != 0) + { + throw new InvalidOperationException(errMsg.ToString()); + } + } + } + + return outPath; + } + + /// + /// create the temp dll file and compile source + /// . + /// + public string Compile(string sourcePath) => Compile(sourcePath, Path.Join(sourcePath, "build", Path.GetFileName(sourcePath))); + + private static string? FindVCVarPath() + { + var vsDir = Environment.GetEnvironmentVariable("VSAPPIDDIR"); + if (!string.IsNullOrEmpty(vsDir)) + { + return Path.Combine(vsDir, "..\\..\\VC\\Auxiliary\\Build\\vcvarsall.bat"); + } + else + { + var vsWhereDir = Path.Combine(Environment.GetEnvironmentVariable("ProgramFiles(x86)")!, "Microsoft Visual Studio\\Installer\\vswhere"); + if (string.IsNullOrEmpty(vsWhereDir)) + { + return null; + } + + using (var proc = new Process()) + { + proc.StartInfo.FileName = vsWhereDir; + proc.StartInfo.Arguments = "-prerelease -latest -property installationPath"; + proc.StartInfo.RedirectStandardOutput = true; + proc.Start(); + proc.WaitForExit(); + vsDir = proc.StandardOutput.ReadLine()!; + return Path.Combine(vsDir, "VC\\Auxiliary\\Build\\vcvarsall.bat"); + } + } + } + + /// + /// select current pattern's exe. + /// + /// NotSupportedException. + private void PlatformSpecific() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + _exe = "/bin/bash"; + _ext = "so"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + _exe = "/bin/bash"; + _ext = "dylib"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + _exe = "cmd"; + _ext = "dll"; + } + + if (System.Environment.GetEnvironmentVariable("NNCASE_CPU_COMPILER") is string exe) + { + _exe = exe; + } + } + + private void ArchSpecific() + { + _arch = RuntimeInformation.OSArchitecture switch + { + Architecture.X64 => RuntimeInformation.IsOSPlatform(OSPlatform.Linux) ? "x86-64" : "x86_64", + Architecture.Arm64 => "arm64", + _ => throw new NotSupportedException(RuntimeInformation.OSArchitecture.ToString()), + }; + } + + private string ArgumentsSpecific(string sourcePath, string outPath) + { + var archConfig = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "-DCMAKE_C_COMPILER=clang-cl -DCMAKE_CXX_COMPILER=clang-cl" : string.Empty; + +#if DEBUG + var config = "Debug"; +#else + var config = "Release"; +#endif + var script = $""" + cd {sourcePath} && + cmake -E remove_directory build && + cmake -G Ninja -S . 
-B build -DCMAKE_BUILD_TYPE={config} {archConfig} && + cmake --build build --config {config} + """.Replace("\r\n", " ", StringComparison.Ordinal); + + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return $"-c \"{script}\""; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return $"-c \"{script}\""; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + var vcVarPath = FindVCVarPath(); + if (!string.IsNullOrEmpty(vcVarPath)) + { + return $"/C \"(\"{vcVarPath}\" x64) && {script}\""; + } + + return $"/C {script}"; + } + + throw new NotSupportedException("Only Support Linux/Osx/Windows"); + } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceExtensions.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceExtensions.cs new file mode 100644 index 0000000000..f1f918613f --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceExtensions.cs @@ -0,0 +1,137 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR; +using Nncase.TIR; + +namespace Nncase.CodeGen.CPU; + +/// +/// convert the type/op to c name. +/// +internal static class CSourceExtensions +{ + private static readonly Dictionary _primTypeToC = new() + { + { DataTypes.Boolean, "uint8_t" }, + { DataTypes.Int8, "int8_t" }, + { DataTypes.Int16, "int16_t" }, + { DataTypes.Int32, "int32_t" }, + { DataTypes.Int64, "int64_t" }, + { DataTypes.UInt8, "uint8_t" }, + { DataTypes.UInt16, "uint16_t" }, + { DataTypes.UInt32, "uint32_t" }, + { DataTypes.UInt64, "uint64_t" }, + { DataTypes.Float32, "float" }, + { DataTypes.Float64, "double" }, + }; + + public static string ToC(this PrimType primType) => + _primTypeToC[primType]; + + public static string ToC(this ReduceArgOp op) => op switch + { + ReduceArgOp.ArgMin => "arg_min", + ReduceArgOp.ArgMax => "arg_max", + _ => throw new NotImplementedException(), + }; + + public static string ToC(this DataType dataType) => dataType switch + { + PrimType ptype => ptype.ToC(), + PointerType => "uint8_t *", + VectorType vtype => $"vector<{vtype.ElemType.ToC()},{string.Join(",", vtype.Lanes)}>", + _ => throw new NotSupportedException(dataType.ToString()), + }; + + public static string ToC(this MemoryLocation location) => location switch + { + MemoryLocation.Output or MemoryLocation.Input or MemoryLocation.Rdata => "loc_t::device", + MemoryLocation.L2Data => "loc_t::shared", + MemoryLocation.L1Data => "loc_t::local", + _ => throw new NotSupportedException(location.ToString()), + }; + + public static string ToC(this ImageResizeMode mode) => mode switch + { + ImageResizeMode.Bilinear => "bilinear", + ImageResizeMode.NearestNeighbor => "nearest_neighbor", + _ => throw new NotImplementedException(), + }; + + public static string ToC(this ImageResizeTransformationMode mode) => mode switch + { + ImageResizeTransformationMode.HalfPixel => "half_pixel", + ImageResizeTransformationMode.PytorchHalfPixel => "pytorch_half_pixel", + ImageResizeTransformationMode.AlignCorners => "align_corners", + ImageResizeTransformationMode.Asymmetric => "asymmetric", + ImageResizeTransformationMode.TFCropAndResize => "tfcrop_and_resize", + _ => throw new NotImplementedException(), + }; + + public static string ToC(this ImageResizeNearestMode mode) => mode switch + { + ImageResizeNearestMode.RoundPreferFloor => "round_prefer_floor", + ImageResizeNearestMode.RoundPreferCeil => "round_prefer_ceil", + ImageResizeNearestMode.Floor => "floor", + 
ImageResizeNearestMode.Ceil => "ceil", + _ => throw new NotImplementedException(), + }; + + public static string ToSlicing(this IEnumerable<string> dims, string[] begins, IRArray<SBP> ndsbp, Placement placement) + { + var hstrides = TensorUtilities.GetStrides(placement.Hierarchy.ToArray()); + var splits = Enumerable.Range(0, begins.Length).Select(_ => new List<(int H, SBPSplit S)>()).ToArray(); + foreach (var (sbp, i) in ndsbp.Select((s, i) => (s, i))) + { + if (sbp is SBPSplit { Axis: int axis } split) + { + splits[axis].Add((i, split)); + } + } + + foreach (var splist in splits) + { + splist.Sort((a, b) => -a.H.CompareTo(b.H)); + } + + for (int i = 0; i < begins.Length; i++) + { + var sp = splits[i]; + if (sp.Count > 0) + { + var dimi = dims.ElementAt(i); + if (dimi.IndexOf('?', System.StringComparison.CurrentCulture) is int s && dimi.IndexOf(':', System.StringComparison.CurrentCulture) is int e && s != -1 && e != -1) + { + dimi = dimi[(s + 1)..e].Trim(); + } + + begins[i] += " + " + sp.Skip(1).Aggregate($"{placement.Name[sp[0].H]}id", (acc, p) => $"({acc} + {TensorUtilities.GetProduct(placement.Hierarchy[(p.H + 1)..])} * {placement.Name[p.H]}id)") + $" * {dimi}"; + } + } + + return $".view(make_ranked_shape({string.Join(',', begins)}), fixed_shape<{string.Join(",", dims.Select(d => d.ToString()))}>{{}})"; + } + + public static string ToSlicing(this IEnumerable<string> dims, IRArray<SBP> ndsbp, Placement placement) => ToSlicing(dims, Enumerable.Repeat("0", dims.Count()).ToArray(), ndsbp, placement); + + public static string ToC(this BinaryOp binaryOp) => binaryOp switch + { + BinaryOp.Add => "+", + BinaryOp.Sub => "-", + BinaryOp.Mul => "*", + BinaryOp.Div => "/", + _ => throw new NotSupportedException(binaryOp.ToString()), + }; + + public static string ToC(this CompareOp op) => op switch + { + CompareOp.Equal => "==", + CompareOp.NotEqual => "!=", + CompareOp.LowerThan => "<", + CompareOp.LowerOrEqual => "<=", + CompareOp.GreaterThan => ">", + CompareOp.GreaterOrEqual => ">=", + _ => throw new NotSupportedException(op.ToString()), + }; +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceUtilities.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceUtilities.cs new file mode 100644 index 0000000000..a1359ebdaa --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceUtilities.cs @@ -0,0 +1,78 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information.
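+ +// Note (illustrative examples, not used by the build): these helpers print C expressions from IR nodes. +// Given CSymbol arguments named x and y: ContertBinary with BinaryOp.Add yields "(x + y)", ContertCompare with +// CompareOp.LowerThan yields "(x < y)", and ContertSelect with predicate p yields "(p ? x : y)". TryGetDivRem +// parses a ternary dimension string such as "(b ? 8 : 3)" into div == 8 and rem == 3. +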
+ +using System.CommandLine; +using System.Globalization; +using DryIoc.ImTools; +using Nncase.Diagnostics; +using Nncase.IR.Math; + +namespace Nncase.CodeGen.CPU; + +internal static class CSourceUtilities +{ + public static string ContertBinary(Binary binary, CSymbol[] arguments) + { + var lhs = arguments[Binary.Lhs.Index].Name; + var rhs = arguments[Binary.Rhs.Index].Name; + string str; + switch (binary.BinaryOp) + { + case BinaryOp.Add or BinaryOp.Sub or BinaryOp.Mul or BinaryOp.Div: + str = $"({lhs} {binary.BinaryOp.ToC()} {rhs})"; + break; + case BinaryOp.Min: + str = $"std::min({lhs}, {rhs})"; + break; + default: + throw new NotSupportedException(); + } + + return str; + } + + public static bool TryGetDivRem(string dim, out int div, out int rem) + { + div = 0; + rem = 0; + if (dim.IndexOf('?', System.StringComparison.CurrentCulture) is int s && dim.IndexOf(':', System.StringComparison.CurrentCulture) is int e && s != -1 && e != -1) + { + div = int.Parse(dim[(s + 1)..e].Trim()); + rem = int.Parse(dim[(e + 1)..^1].Trim()); + return true; + } + + return false; + } + + internal static string ContertUnary(Unary op, CSymbol[] arguments) + { + var input = arguments[Unary.Input.Index].Name; + string str; + switch (op.UnaryOp) + { + default: + str = $"nncase_mt->{arguments[0].Type}_{nameof(Unary).ToLower(CultureInfo.CurrentCulture)}_{op.UnaryOp.ToString().ToLower(CultureInfo.CurrentCulture)}({input})"; + break; + } + + return str; + } + + internal static string ContertCompare(Compare op, CSymbol[] arguments) + { + var lhs = arguments[Compare.Lhs.Index].Name; + var rhs = arguments[Compare.Rhs.Index].Name; + string str = $"({lhs} {op.CompareOp.ToC()} {rhs})"; + return str; + } + + internal static string ContertSelect(Select s, CSymbol[] arguments) + { + var p = arguments[Select.Predicate.Index].Name; + var lhs = arguments[Select.TrueValue.Index].Name; + var rhs = arguments[Select.FalseValue.Index].Name; + string str = $"({p} ? {lhs} : {rhs})"; + return str; + } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/DeviceCSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/DeviceCSourceConvertVisitor.cs new file mode 100644 index 0000000000..bbf91a3810 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/DeviceCSourceConvertVisitor.cs @@ -0,0 +1,390 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information.
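+ +// Note (illustrative sketch): this visitor lowers device-level TIR into C++ text. Assuming an int32 loop +// variable named i with GlobalVarIndex 0, a For node with domain (start 0, stop 16, step 1) is printed by +// VisitFor as: +// for (int32_t i0 = 0; i0 < 16; i0 += 1) { ... } +// while VisitPrimFunction wraps the body in a templated device-function signature of the form +// template<class T0, class T1, ...> void name(T0 &&arg0, T1 &&arg1, ...). +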
+ +#define MULTI_CORE_XPU + +// #define DEBUG_PRINT +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reactive; +using System.Runtime.InteropServices; +using System.Text; +using DryIoc; +using Google.OrTools.Sat; +using NetFabric.Hyperlinq; +using Nncase.IR; +using Nncase.Runtime; +using Nncase.TIR; +using Nncase.Utilities; +using Razor.Templating.Core; + +namespace Nncase.CodeGen.CPU; + +internal sealed class DeviceCSourceConvertVisitor : ExprFunctor +{ + private readonly Dictionary _exprMemo; + private readonly StringBuilder _deviceBuilder; + + public DeviceCSourceConvertVisitor() + { + _exprMemo = new(ReferenceEqualityComparer.Instance); + _deviceBuilder = new(); + } + + public PrimFunction VisitEntry => (TIR.PrimFunction)VisitRoot!; + + public string GetHeader() + { + return _deviceBuilder.ToString(); + } + + /// + protected override CSymbol VisitPrimFunction(PrimFunction expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + if (expr.CheckedType is not CallableType { ReturnType: TupleType r } || r != TupleType.Void) + { + throw new NotSupportedException("The PrimFunction must return void!"); + } + + var ctype = $"template<{string.Join(", ", Enumerable.Range(0, expr.Parameters.Length).Select(x => $"class T{x}"))}>" + + $"void {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(Visit).Select((s, i) => $"T{i} &&{s.Name}").ToArray())})"; + + using (var scope = new IndentScope(_deviceBuilder)) + { + // 1. Function signature + IndentScope.Writer.IndWrite($"{ctype} {{\n"); + + // 2. Function body + using (_ = new IndentScope()) + { + Visit(expr.Body); + } + + // 3. Function closing + IndentScope.Writer.IndWrite("}\n"); + } + + symbol = new(ctype, expr.Name); + _exprMemo.Add(expr, symbol); + return symbol; + } + + protected override CSymbol VisitIfThenElse(IfThenElse expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + var cond = Visit(expr.Condition); + IndentScope.Writer.IndWrite($"if ({cond.Name}) {{\n"); + using (_ = new IndentScope()) + { + Visit(expr.Then); + } + + IndentScope.Writer.IndWrite("}\n"); + IndentScope.Writer.IndWrite("else {\n"); + using (_ = new IndentScope()) + { + Visit(expr.Else); + } + + IndentScope.Writer.IndWrite("}\n"); + + symbol = new(string.Empty, string.Empty); + _exprMemo.Add(expr, symbol); + return symbol; + } + + protected override CSymbol VisitLet(Let expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + var @var = Visit(expr.Var); + var value = Visit(expr.Expression); + +#if DEBUG_PRINT + IndentScope.Writer.IndWrite($"runtime_util->printf(\"let {@var.Name}\\n\");\n"); +#endif + IndentScope.Writer.IndWrite($"{value.Type} {@var.Name} = {value.Name};\n"); + Visit(expr.Body); + + symbol = new(string.Empty, string.Empty); + _exprMemo.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitMemSpan(MemSpan expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + var start = Visit(expr.Start); + var size = Visit(expr.Size); + string name = expr.Location switch + { + MemoryLocation.L2Data => start.Name, + MemoryLocation.Input or MemoryLocation.Output => start.Name, + _ => throw new NotSupportedException(expr.Location.ToString()), + }; + + symbol = new(start.Type, $"std::span({name}, {size.Name})"); + _exprMemo.Add(expr, symbol); + return symbol; + } + + protected override CSymbol VisitBuffer(TIR.Buffer 
expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + var type = $"tensor_view<{expr.ElemType.ToC()}, {KernelUtility.DimensionsToC(expr.Dimensions)}, {KernelUtility.StridesToC(expr.Strides)}> "; + + symbol = new(type, expr.Name); + _exprMemo.Add(expr, symbol); + return symbol; + } + + protected override CSymbol VisitCall(Call expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + string type = expr.CheckedType switch + { + TupleType x when x == TupleType.Void => string.Empty, + TensorType { IsScalar: true } x => x.DType.ToC(), + TensorType { Shape: { IsRanked: true } } x => x.Shape.IsFixed switch + { + true => $"tensor_view<{x.DType.ToC()}, fixed_shape<{x.Shape.ToString()[1..^1]}>>", + false => $"tensor_view<{x.DType.ToC()}, ranked_shape<{x.Shape.Rank}>>", + }, + _ => throw new NotSupportedException(), + }; + + string str = string.Empty; + var arguments = expr.Arguments.AsValueEnumerable().Select(Visit).ToArray(); + switch (expr.Target) + { + case PrimFunction deviceFunc: + IndentScope.Writer.IndWrite($"{deviceFunc.Name}({string.Join(",", arguments.Select(arg => arg.Name))});\n"); + break; + case IR.Math.Binary op: + str = CSourceUtilities.ContertBinary(op, arguments); + break; + case IR.Math.Unary op: + str = CSourceUtilities.ContertUnary(op, arguments); + break; + case IR.Math.Compare op: + str = CSourceUtilities.ContertCompare(op, arguments); + break; + case IR.Math.Select op: + str = CSourceUtilities.ContertSelect(op, arguments); + break; + case TIR.CPU.SramPtr op: + str = $"g_cpu_mt->sram_address(bid, tid) + {arguments[0].Name}"; + break; + case TIR.Load op: + str = $"{arguments[0].Name}[{arguments[1].Name}]"; + break; + case TIR.Store op: +#if DEBUG_PRINT + IndentScope.Writer.IndWrite($"runtime_util->printf(\"{arguments[0].Name}[%d]\\n\", {arguments[1].Name});\n"); +#endif + IndentScope.Writer.IndWrite($"{arguments[0].Name}[{arguments[1].Name}] = {arguments[2].Name};\n"); + break; + case TIR.CPU.PtrOf op: + str = op.PtrName + ".data()"; + break; + case IR.Buffers.Allocate op: + str = $"({type})runtime_util->malloc({arguments[0].Name})"; + break; + case IR.Buffers.AllocateBufferView op: + { + var buffer = (TIR.Buffer)expr.Arguments[0]; + if (buffer.CheckedShape.IsFixed) + { + str = $"{{span_cast<{buffer.ElemType.ToC()}>({Visit(buffer.MemSpan).Name}), {KernelUtility.DimensionsToC(buffer.Dimensions)}{{}}, {KernelUtility.StridesToC(buffer.Strides)}{{}}}}"; + } + else + { + str = $"{{span_cast<{buffer.ElemType.ToC()}>({Visit(buffer.MemSpan).Name}), make_ranked_shape({StringUtility.Join(", ", buffer.Dimensions.AsValueEnumerable().Select(x => Visit(x).Name))})}}"; + } + } + + break; + case IR.Tensors.Cast op: + str = $"(({op.NewType.ToC()}){arguments[0].Name})"; + break; + case TIR.CPU.Memcopy op: + IndentScope.Writer.IndWrite($"tensor_copy({arguments[1].Name}, {arguments[0].Name});\n"); + break; + case TIR.CPU.Unary op: + IndentScope.Writer.IndWrite(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Unary.cshtml", new UnaryKernelTemplateModel + { + Arguments = arguments.Select(x => new KernelArgument { Symbol = x }).ToArray(), + UnaryOp = op.UnaryOp, + }).Result); + break; + case TIR.CPU.Binary op: + IndentScope.Writer.IndWrite(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Binary.cshtml", new BinaryKernelTemplateModel + { + Arguments = arguments.Select(x => new KernelArgument { Symbol = x }).ToArray(), + BinaryOp = op.BinaryOp, + }).Result); + break; + default: + throw new 
NotSupportedException(); + } + + symbol = new(type, str); + _exprMemo.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitConst(Const expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + string type; + string str; + if (expr is TensorConst { Value: Tensor { ElementType: PrimType ptype, Shape: { IsScalar: true } } scalar }) + { + str = scalar[0].ToString() switch + { + "True" => "1", + "False" => "0", + null => string.Empty, + var x => x, + }; + + type = ptype.ToC(); + } + else if (expr is TensorConst { Value: Tensor { ElementType: PointerType { ElemType: PrimType }, Shape: { IsScalar: true } } pointer }) + { + str = pointer.ToScalar().ToString(); + type = pointer.ElementType.ToC(); + } + else + { + throw new NotSupportedException(); + } + + symbol = new(type, str); + _exprMemo.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitSequential(Sequential expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + foreach (var field in expr.Fields) + { + Visit(field); + } + + symbol = new(string.Empty, string.Empty); + _exprMemo.Add(expr, symbol); + return symbol; + } + + protected override CSymbol VisitFor(For expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + // 1. For Loop signature + var loopVar = Visit(expr.LoopVar); + IndentScope.Writer.IndWrite($"for ({loopVar.Type} {loopVar.Name} = {Visit(expr.Domain.Start).Name}; {loopVar.Name} < {Visit(expr.Domain.Stop).Name}; {loopVar.Name} += {Visit(expr.Domain.Step).Name}) {{\n"); +#if DEBUG_PRINT + IndentScope.Writer.IndWrite($"runtime_util->printf(\"{loopVar.Name} = %d\\n\", {loopVar.Name});\n"); +#endif + + using (_ = new IndentScope()) + { + // 2. For Body + Visit(expr.Body); + } + + // 3. 
For closing + IndentScope.Writer.IndWrite("}\n"); + + symbol = new(string.Empty, string.Empty); + _exprMemo.Add(expr, symbol); + return symbol; + } + + protected override CSymbol VisitVar(Var expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + symbol = new( + expr.CheckedType switch + { + TensorType t => t.DType.ToC(), + _ => throw new ArgumentOutOfRangeException(nameof(expr)), + }, + expr.Name + expr.GlobalVarIndex.ToString()); + _exprMemo.Add(expr, symbol); + return symbol; + } + + protected override CSymbol VisitBufferRegion(BufferRegion expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + var buffer = Visit(expr.Buffer); + if (expr.Region.AsValueEnumerable().All(r => r is { Start: TensorConst, Stop: TensorConst, Step: TensorConst step } && step.Value.ToScalar() == 1)) + { + var begins = $"{StringUtility.Join(", ", expr.Region.AsValueEnumerable().Select(x => Visit(x.Start).Name))}"; + var extents = $"{StringUtility.Join(", ", expr.Region.AsValueEnumerable().Select(x => Visit(x.Stop).Name))}"; + symbol = new(string.Empty, $"{buffer.Name}.view(fixed_shape<{begins}>{{}}, fixed_shape<{extents}>{{}})"); + _exprMemo.Add(expr, symbol); + } + else + { + var begins = $"{StringUtility.Join(", ", expr.Region.AsValueEnumerable().Select(x => Visit(x.Start).Name))}"; + var extents = $"{StringUtility.Join(", ", expr.Region.AsValueEnumerable().Select(x => Visit(x.Stop - x.Start).Name))}"; + symbol = new(string.Empty, $"{buffer.Name}.view(make_ranked_shape({begins}), make_ranked_shape({extents}))"); + _exprMemo.Add(expr, symbol); + } + + return symbol; + } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionBuilder.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionBuilder.cs new file mode 100644 index 0000000000..f1625b40b4 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionBuilder.cs @@ -0,0 +1,86 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Tasks; +using NetFabric.Hyperlinq; +using Nncase.CodeGen.CPU; +using Nncase.IR; + +namespace Nncase.CodeGen.CPU; + +/// +/// StackVM function builder. +/// +internal class FunctionBuilder +{ + public const string KernelHeaderSectionName = ".desc"; + private readonly uint _id; + private readonly SectionManager _sectionManager; + private readonly BinaryWriter _textWriter; + private readonly BinaryWriter _rdataWriter; + + public FunctionBuilder(uint id, BinaryWriter rdataWriter) + { + _id = id; + _sectionManager = new(); + _textWriter = _sectionManager.GetWriter(WellknownSectionNames.Text); + _rdataWriter = rdataWriter; + } + + public unsafe ILinkableFunction Build(TIR.PrimFunction function) + { + if (function.Name.EndsWith("kernel")) + { + // 1. convert func to csource + var visitor = new KernelCSourceConvertVisitor(); + visitor.Visit(function); + var functionCSource = visitor.GetCSource(); + + // 2. write the kernel header + using (var writer = _sectionManager.GetWriter(KernelHeaderSectionName)) + { + var header = default(DescHeader); + header.DataPoolSize = function.SchedResult.DataUsage; + header.DataAlign = function.SchedResult.DataAlign; + writer.Write(ref header); + } + + // 3. 
write the rdata + foreach (var (@const, range) in function.SchedResult.Rdatas) + { + var bytes = ((TensorConst)@const).Value.BytesBuffer; + var size = range.Max - range.Min; + if ((uint)bytes.Length != size) + { + throw new InvalidDataException("The Buffer Size Not Equal!"); + } + + _rdataWriter.Position(range.Min); + _rdataWriter.Write(bytes); + } + + return new LinkableKernelFunction(_id, function, functionCSource, _sectionManager.GetContent(WellknownSectionNames.Text)!, new LinkedSection(_sectionManager.GetContent(KernelHeaderSectionName), KernelHeaderSectionName, 0, 8, (uint)sizeof(DescHeader))); + } + else if (function.Name.EndsWith("device")) + { + var visitor = new DeviceCSourceConvertVisitor(); + visitor.Visit(function); + var header = visitor.GetHeader(); + + return new LinkableDeviceFunction(_id, function, header, _sectionManager.GetContent(WellknownSectionNames.Text)!); + } + + throw new NotSupportedException("the function name is invalid"); + } + + [StructLayout(LayoutKind.Sequential)] + private unsafe struct DescHeader + { + [MarshalAs(UnmanagedType.U8)] + public ulong DataPoolSize; + + [MarshalAs(UnmanagedType.U8)] + public ulong DataAlign; + } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionCSource.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionCSource.cs new file mode 100644 index 0000000000..396d1fb986 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionCSource.cs @@ -0,0 +1,20 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. +#define MULTI_CORE_XPU +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using Nncase.IR; +using Nncase.Schedule; +using Nncase.TIR; + +namespace Nncase.CodeGen; + +internal sealed record KernelCSource(string Main, string Kernel) +{ +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelCSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelCSourceConvertVisitor.cs new file mode 100644 index 0000000000..ce8d1ad47d --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelCSourceConvertVisitor.cs @@ -0,0 +1,584 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +#define MULTI_CORE_CPU + +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reactive; +using System.Runtime.InteropServices; +using System.Text; +using DryIoc.ImTools; +using NetFabric.Hyperlinq; +using Nncase.CodeGen.CPU; +using Nncase.IR; +using Nncase.Runtime; +using Nncase.TIR; +using Razor.Templating.Core; + +namespace Nncase.CodeGen.CPU; + +internal struct IndentScope : IDisposable +{ + private static readonly AsyncLocal _writer = new AsyncLocal(); + + private readonly bool _initialized; + + private readonly IndentWriter? 
_originalWriter; + + public IndentScope(StringBuilder sb) + { + _initialized = true; + _originalWriter = _writer.Value; + _writer.Value = new IndentWriter(sb); + } + + public IndentScope() + { + _initialized = true; + if (_writer.Value is null) + { + return; + } + + _originalWriter = _writer.Value; + _writer.Value = new(_originalWriter.GetStringBuilder(), _originalWriter.Indent + 2); + } + + public static IndentWriter Writer => _writer.Value!; + + public void Dispose() + { + if (_initialized) + { + _writer.Value = _originalWriter; + } + } +} + +/// +/// the c symbol define. +/// +public sealed class CSymbol +{ + public CSymbol(string type, string name) + { + Type = type; + Name = name; + } + + public static IReadOnlyList Builtns => new CSymbol[] { + new CSymbol("nncase_mt_t*", "nncase_mt"), + new CSymbol("uint8_t*", "data"), + new CSymbol("const uint8_t*", "rdata"), + }; + + public string Type { get; } + + public string Name { get; } + + public override string ToString() => $"{Type} {Name}"; +} + +internal sealed class IndentWriter : StringWriter +{ + public IndentWriter(StringBuilder sb, int indent = 0) + : base(sb) + { + Indent = indent; + } + + public int Indent { get; set; } + + public void IndWrite(string? value) + { + for (int i = 0; i < Indent; i++) + { + Write(' '); + } + + Write(value); + } +} + +/// +/// convert single prim function to c source. +/// +internal sealed class KernelCSourceConvertVisitor : ExprFunctor, IDisposable +{ + private readonly Dictionary _exprMemo; + private readonly StringBuilder _kernelBuilder; + + private readonly StringBuilder _sharedBuilder; + private readonly HashSet _refFuncs; + private readonly StringWriter _sharedWriter; + + public KernelCSourceConvertVisitor() + { + _kernelBuilder = new StringBuilder(); + _sharedBuilder = new StringBuilder(); + _sharedWriter = new StringWriter(_sharedBuilder); + _exprMemo = new(ReferenceEqualityComparer.Instance); + _refFuncs = new(ReferenceEqualityComparer.Instance); + } + + public PrimFunction VisitEntry => (TIR.PrimFunction)VisitRoot!; + + public KernelCSource GetCSource() + { + var ctype = $"void {VisitEntry.Name}({string.Join(", ", VisitEntry.Parameters.AsValueEnumerable().Select(Visit).Select(s => $"{s.Type} {s.Name}").ToArray().Concat(_exprMemo.Keys.OfType().Where(b => b.MemSpan.Location == MemoryLocation.Rdata).Select(Visit).Select(s => $" {s.Type} {s.Name}").ToArray()))}, uint8_t* l1_data)"; + return new( + CSourceBuiltn.MakeMain(VisitEntry, _exprMemo.Keys.OfType().Where(b => b.MemSpan.Location == MemoryLocation.Rdata)), + CSourceBuiltn.MakeKernel(ctype, _kernelBuilder.ToString())); + } + + /// + public void Dispose() + { + _sharedWriter.Dispose(); + } + + protected override CSymbol VisitVar(Var expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + symbol = new(string.Empty, expr.Name); + _exprMemo.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitPrimFunction(PrimFunction expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + if (expr.CheckedType is not CallableType { ReturnType: TupleType r } || r != TupleType.Void) + { + throw new NotSupportedException("The PrimFunction must return void!"); + } + + var ctype = $"void {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(Visit).Select(s => $"{s.Type} {s.Name}").ToArray())})"; + + using (var scope = new IndentScope(_kernelBuilder)) + { + // 1. Function signature + IndentScope.Writer.IndWrite($"{{\n"); + + // 2. 
Function body + using (_ = new IndentScope()) + { + Visit(expr.Body); + } + + // 3. Function closing + IndentScope.Writer.IndWrite("}\n"); + } + + symbol = new(ctype, expr.Name); + _exprMemo.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitMemSpan(MemSpan expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + var start = Visit(expr.Start); + _ = Visit(expr.Size); + string loc = (expr.Location, expr.Hierarchy) switch + { + (MemoryLocation.Rdata, 0) => "rdata", + (MemoryLocation.Data, 0) => "data", + (MemoryLocation.Data, 1) => "l1_data", + _ => throw new NotSupportedException(), + }; + var ptype = (PointerType)expr.CheckedDataType; + var ptypeName = ptype.ElemType.ToC(); + var spanSize = ((TensorConst)expr.Size).Value.ToScalar() / ptype.ElemType.SizeInBytes; + var name = $"std::span<{ptypeName}, {spanSize}> (reinterpret_cast<{ptypeName}*>({loc} + {start.Name}), {spanSize})"; + + symbol = new(start.Type, name); + _exprMemo.Add(expr, symbol); + return symbol; + } + + protected override CSymbol VisitBuffer(TIR.Buffer expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + var type = VisitEntry.Parameters.AsValueEnumerable().Contains(expr) || expr.MemSpan.Location == MemoryLocation.Rdata || expr.MemSpan.Start is TensorConst + ? $"tensor_view<{expr.ElemType.ToC()}, {KernelUtility.DimensionsToC(expr.Dimensions)}, {KernelUtility.StridesToC(expr.Strides)}> " + : $"tensor<{expr.ElemType.ToC()}, {KernelUtility.DimensionsToC(expr.Dimensions)}> "; + + symbol = new(type, expr.Name); + _exprMemo.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitCall(Call expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + string type = expr.CheckedType switch + { + TupleType x when x == TupleType.Void => string.Empty, + TensorType { IsScalar: true } x => x.DType.ToC(), + _ => throw new NotSupportedException(), + }; + + string str = string.Empty; + if (expr.Target is TIR.CPU.CPUKernelOp xpuOp) + { + foreach (var item in expr.Arguments.ToArray().OfType()) + { + DeclBuffer(item); + } + + var args = expr.Arguments.ToArray().OfType().ToArray(); + switch (xpuOp) + { + case TIR.CPU.Unary unary: + IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Unary.cshtml", new UnaryKernelTemplateModel + { + Arguments = args.Select(x => new KernelArgument { Symbol = Visit(x) }).ToArray(), + UnaryOp = unary.UnaryOp, + }).Result); + break; + case TIR.CPU.TensorLoad load: + if (args.Length == 1) + { + var fullShape = Enumerable.Repeat(1, args[0].Dimensions.Length).ToArray(); + var splitAxisAndScale = load.NdSbp.Select((sbp, i) => sbp is SBPSplit s ? 
(s.Axis, load.Placement.Hierarchy[i]) : (0, 1)).ToArray(); + foreach (var s in splitAxisAndScale) + { + fullShape[s.Item1] *= s.Item2; + } + + foreach (var (dimS, axis) in args[0].Dimensions.ToArray().Select((e, axis) => (Visit(e).Name, axis))) + { + if (int.TryParse(dimS, out var div)) + { + fullShape[axis] *= div; + } + else if (CSourceUtilities.TryGetDivRem(dimS, out div, out var rem)) + { + fullShape[axis] = (fullShape[axis] - 1) * div; + fullShape[axis] += rem; + } + } + + IndentScope.Writer.Write($"tensor_boxing_load({Visit(args[0]).Name}, {{{string.Join(',', fullShape)}}}, {args[0].Dimensions.ToArray().Select(e => Visit(e).Name).ToSlicing(load.NdSbp, load.Placement)[1..^1]}, ctx);\n"); + } + else + { + IndentScope.Writer.Write($"tensor_copy({Visit(args[1]).Name}{args[0].Dimensions.ToArray().Select(e => Visit(e).Name).ToSlicing(load.NdSbp, load.Placement)}, {Visit(args[0]).Name});\n"); + } + + break; + case TIR.CPU.TensorStore store: + if (args.Length == 1) + { + var fullShape = Enumerable.Repeat(1, args[0].Dimensions.Length).ToArray(); + var splitAxisAndScale = store.NdSbp.Select((sbp, i) => sbp is SBPSplit s ? (s.Axis, store.Placement.Hierarchy[i]) : (0, 1)).ToArray(); + foreach (var s in splitAxisAndScale) + { + fullShape[s.Item1] *= s.Item2; + } + + foreach (var (dimS, axis) in args[0].Dimensions.ToArray().Select((e, axis) => (Visit(e).Name, axis))) + { + if (int.TryParse(dimS, out var div)) + { + fullShape[axis] *= div; + } + else if (CSourceUtilities.TryGetDivRem(dimS, out div, out var rem)) + { + fullShape[axis] = (fullShape[axis] - 1) * div; + fullShape[axis] += rem; + } + } + + IndentScope.Writer.Write($"tensor_boxing_store({Visit(args[0]).Name}, {{{string.Join(',', fullShape)}}}, {args[0].Dimensions.ToArray().Select(e => Visit(e).Name).ToSlicing(store.NdSbp, store.Placement)[1..^1]}, ctx);\n"); + } + else + { + IndentScope.Writer.Write($"tensor_copy({Visit(args[0]).Name}, {Visit(args[1]).Name}{args[0].Dimensions.ToArray().Select(e => Visit(e).Name).ToSlicing(store.NdSbp, store.Placement)});\n"); + } + + break; + case TIR.CPU.Binary binary: + { + IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Binary.cshtml", new BinaryKernelTemplateModel + { + Arguments = args.Select(x => new KernelArgument { Symbol = Visit(x) }).ToArray(), + BinaryOp = binary.BinaryOp, + }).Result); + } + + break; + case TIR.CPU.Pack pack: + { + IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Pack.cshtml", new TypedKernelTemplateModel(pack) + { + Arguments = args.Select(x => new KernelArgument { Symbol = Visit(x) }).ToArray(), + }).Result); + } + + break; + + case TIR.CPU.Unpack unpack: + { + IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Unpack.cshtml", new TypedKernelTemplateModel(unpack) + { + Arguments = args.Select(x => new KernelArgument { Symbol = Visit(x) }).ToArray(), + }).Result); + } + + break; + case TIR.CPU.PackedLayerNorm packedLayerNorm: + { + IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/PackedLayerNorm.cshtml", new TypedKernelTemplateModel(packedLayerNorm) + { + Arguments = args.Select(x => new KernelArgument { Symbol = Visit(x) }).ToArray(), + Args = args.ToArray(), + }).Result); + } + + break; + case TIR.CPU.PackedSoftmax packedsoftmax: + { + IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/PackedSoftMax.cshtml", new TypedKernelTemplateModel(packedsoftmax) + { + Arguments = 
args.Select(x => new KernelArgument { Symbol = Visit(x) }).ToArray(), + Args = args.ToArray(), + }).Result); + } + + break; + case TIR.CPU.PackedBinary packedBinary: + { + IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Binary.cshtml", new BinaryKernelTemplateModel + { + BinaryOp = packedBinary.BinaryOp, + Arguments = args.Select(x => new KernelArgument { Symbol = Visit(x) }).ToArray(), + }).Result); + } + + break; + case TIR.CPU.PackedMatMul packedMatmul: + { + IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/PackedMatmul.cshtml", new TypedKernelTemplateModel(packedMatmul) + { + Arguments = args.Select(x => new KernelArgument { Symbol = Visit(x) }).ToArray(), + }).Result); + } + + break; + case TIR.CPU.PackedTranspose transpose: + { + IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/PackedTranspose.cshtml", new TypedKernelTemplateModel(transpose) + { + Arguments = args.Select(x => new KernelArgument { Symbol = Visit(x) }).ToArray(), + Args = args.ToArray(), + }).Result); + } + + break; + + case TIR.CPU.Memcopy copy: + IndentScope.Writer.Write($"tensor_copy({Visit(args[0]).Name}, {Visit(args[1]).Name});\n"); + break; + case TIR.CPU.Gather gather: + IndentScope.Writer.Write($"gather<{gather.Axis}>({Visit(args[0]).Name}, {Visit(args[1]).Name}, {Visit(args[2]).Name});\n"); + break; + case TIR.CPU.Reshape reshape: + { + IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Reshape.cshtml", new TypedKernelTemplateModel(reshape) + { + Arguments = args.Select(x => new KernelArgument { Symbol = Visit(x) }).ToArray(), + Args = args.ToArray(), + }).Result); + } + + break; + case TIR.CPU.Matmul matmul: + IndentScope.Writer.Write($"matmul({Visit(args[0]).Name}, {Visit(args[1]).Name}, {Visit(args[2]).Name});\n"); + break; + case TIR.CPU.Swish swish: + if (swish.Beta != 1.0f) + { + throw new NotSupportedException(); + } + + IndentScope.Writer.Write($"unary({Visit(args[0]).Name}, {Visit(args[1]).Name});\n"); + break; + case TIR.CPU.Slice slice: + IndentScope.Writer.Write($"slice, fixed_shape<{string.Join(",", slice.Ends)}>, fixed_shape<{string.Join(",", slice.Axes)}>, fixed_shape<{string.Join(",", slice.Strides)}>>({Visit(args[0]).Name}, {Visit(args[1]).Name});\n"); + break; + case TIR.CPU.Concat concat: + IndentScope.Writer.Write($"concat<{concat.Axis}>(std::make_tuple({string.Join(",", args.SkipLast(1).Select(Visit).Select(s => s.Name))}), {Visit(args[^1]).Name});\n"); + break; + case TIR.CPU.Transpose transpose: + IndentScope.Writer.Write($"transpose>({Visit(args[0]).Name}, {Visit(args[1]).Name});\n"); + break; + case TIR.CPU.Pad pad: + IndentScope.Writer.Write($"pad<{string.Join(",", pad.Paddings)}>({Visit(args[0]).Name}, {Visit(args[1]).Name}, {args[0].CheckedDataType.ToC()} {{ {pad.PadValue} }} );\n"); + break; + default: + throw new NotSupportedException(xpuOp.ToString()); + } + } + else if (expr.Target is PrimFunction deviceFunc) + { + foreach (var item in expr.Arguments.ToArray().OfType()) + { + DeclBuffer(item); + } +#if DEBUG_PRINT + IndentScope.Writer.IndWrite($"runtime_util->printf(\"call {deviceFunc.Name} bid %d tid %d\\n\", bid, tid);\n"); +#endif + var arguments = expr.Arguments.AsValueEnumerable().Select(Visit).ToArray(); + _refFuncs.Add(deviceFunc); + IndentScope.Writer.IndWrite($"{deviceFunc.Name}({string.Join(",", arguments.Select(arg => arg.Name))});\n"); + } + else + { + var arguments = 
expr.Arguments.AsValueEnumerable().Select(Visit).ToArray(); + switch (expr.Target) + { + case IR.Math.Binary op: + str = CSourceUtilities.ContertBinary(op, arguments); + break; + case IR.Math.Unary op: + str = CSourceUtilities.ContertUnary(op, arguments); + break; + case IR.Math.Compare op: + str = CSourceUtilities.ContertCompare(op, arguments); + break; + case IR.Math.Select op: + str = CSourceUtilities.ContertSelect(op, arguments); + break; + case TIR.Load op: + str = $"{arguments[0].Name}[{arguments[1].Name}]"; + break; + case TIR.Store op: + IndentScope.Writer.IndWrite($"{arguments[0].Name}[{arguments[1].Name}] = {arguments[1].Name};\n"); + break; + case TIR.CPU.PtrOf op: + str = op.PtrName; + break; + default: + throw new NotSupportedException(); + } + } + + symbol = new(type, str); + _exprMemo.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitConst(Const expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + string type; + string str; + if (expr is TensorConst { Value: Tensor { ElementType: PrimType ptype, Shape: { IsScalar: true } } scalar }) + { + str = scalar[0].ToString() switch + { + "True" => "1", + "False" => "0", + null => string.Empty, + var x => x, + }; + + type = ptype.ToC(); + } + else if (expr is TensorConst { Value: Tensor { ElementType: PointerType { ElemType: DataType }, Shape: { IsScalar: true } } pointer }) + { + str = pointer.ToScalar().ToString(); + type = "uint8_t *"; + } + else + { + throw new NotSupportedException(); + } + + symbol = new(type, str); + _exprMemo.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitSequential(Sequential expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + foreach (var field in expr.Fields) + { + if (field is Call call) + { + IndentScope.Writer.IndWrite(Visit(call).Name); + } + else + { + Visit(field); + } + } + + symbol = new(string.Empty, string.Empty); + _exprMemo.Add(expr, symbol); + return symbol; + } + + private void DeclBuffer(TIR.Buffer buffer) + { + if (_exprMemo.ContainsKey(buffer)) + { + return; + } + + var symbol = Visit(buffer); + + if (buffer.MemSpan.Location == MemoryLocation.Rdata) + { + return; + } + + IndentScope.Writer.IndWrite($"{symbol.Type} {symbol.Name}"); + if (buffer.MemSpan.Start is not None) + { + IndentScope.Writer.IndWrite($"({Visit(buffer.MemSpan).Name})"); + } + + IndentScope.Writer.Write($";\n"); + } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelTemplateModel.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelTemplateModel.cs new file mode 100644 index 0000000000..cb84404374 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelTemplateModel.cs @@ -0,0 +1,43 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
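+
+// NOTE: The models below are plain data holders handed to the Razor kernel
+// templates under CodeGen/CPU/Templates/Kernels. A minimal usage sketch
+// (symbol types and names here are illustrative, not taken from a real kernel):
+//
+//   var line = RazorTemplateEngine.RenderAsync(
+//       "~/CodeGen/CPU/Templates/Kernels/Unary.cshtml",
+//       new UnaryKernelTemplateModel
+//       {
+//           UnaryOp = UnaryOp.Cos,
+//           Arguments = new[]
+//           {
+//               new KernelArgument { Symbol = new CSymbol("tensor_view<float, ...>", "v0") },
+//               new KernelArgument { Symbol = new CSymbol("tensor<float, ...>", "v1") },
+//           },
+//       }).Result;
+//
+//   // renders: unary<ops::cos>(v0, v1);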
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Nncase.CodeGen.CPU;
+
+public class KernelArgument
+{
+    public CSymbol Symbol { get; set; } = null!;
+}
+
+public class KernelTemplateModel
+{
+    public KernelArgument[] Arguments { get; set; } = null!;
+}
+
+public class UnaryKernelTemplateModel : KernelTemplateModel
+{
+    public UnaryOp UnaryOp { get; set; }
+}
+
+public class BinaryKernelTemplateModel : KernelTemplateModel
+{
+    public BinaryOp BinaryOp { get; set; }
+}
+
+public class TypedKernelTemplateModel<T> : KernelTemplateModel
+    where T : IR.Op
+{
+    public TypedKernelTemplateModel(T target)
+    {
+        Target = target;
+    }
+
+    public T Target { get; }
+
+    public IR.Expr[] Args { get; set; } = Array.Empty<IR.Expr>();
+}
diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelUtility.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelUtility.cs
new file mode 100644
index 0000000000..9289e43f1c
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelUtility.cs
@@ -0,0 +1,66 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.CommandLine;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+
+namespace Nncase.CodeGen.CPU;
+
+public static class KernelUtility
+{
+    public static ulong GetLength(TIR.Buffer buffer)
+    {
+        // Scalar
+        if (buffer.Dimensions.Length == 0)
+        {
+            return 1;
+        }
+
+        ulong length = 1;
+        foreach (var dim in buffer.Dimensions)
+        {
+            length *= ((TensorConst)dim).Value.Cast<ulong>()[0];
+        }
+
+        return length;
+    }
+
+    public static string DimensionsToC(ReadOnlySpan<Expr> dimensions)
+    {
+        var sb = new StringBuilder("fixed_shape<");
+        for (int i = 0; i < dimensions.Length; i++)
+        {
+            var value = ((TensorConst)dimensions[i]).Value.Cast<ulong>()[0];
+            sb.Append(value);
+            if (i != dimensions.Length - 1)
+            {
+                sb.Append(", ");
+            }
+        }
+
+        sb.Append('>');
+        return sb.ToString();
+    }
+
+    public static string StridesToC(ReadOnlySpan<Expr> dimensions)
+    {
+        var sb = new StringBuilder("fixed_strides<");
+        for (int i = 0; i < dimensions.Length; i++)
+        {
+            var value = ((TensorConst)dimensions[i]).Value.Cast<ulong>()[0];
+            sb.Append(value);
+            if (i != dimensions.Length - 1)
+            {
+                sb.Append(", ");
+            }
+        }
+
+        sb.Append('>');
+        return sb.ToString();
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/LinkableFunction.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/LinkableFunction.cs
new file mode 100644
index 0000000000..c0102dbff9
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/LinkableFunction.cs
@@ -0,0 +1,60 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
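+
+// NOTE: FunctionBuilder produces one of the two ILinkableFunction flavors
+// defined below, keyed off the PrimFunction's name suffix; roughly:
+//
+//   ILinkableFunction f = function.Name.EndsWith("device")
+//       ? new LinkableDeviceFunction(id, function, header, text)          // C header + .text only
+//       : new LinkableKernelFunction(id, function, csource, text, descs); // kernel C source + desc-header section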
+ +using System.Runtime.InteropServices; +using Nncase.IR; + +namespace Nncase.CodeGen.CPU; +internal sealed class LinkableKernelFunction : ILinkableFunction +{ + public LinkableKernelFunction(uint id, TIR.PrimFunction sourceFunction, KernelCSource funcCSource, Stream text, params ILinkedSection[] sections) + { + Id = id; + SourceFunction = sourceFunction; + PrimFunction = sourceFunction; + FunctionCSource = funcCSource; + Text = text; + Sections = sections; + } + + public uint Id { get; } + + public BaseFunction SourceFunction { get; } + + public TIR.PrimFunction PrimFunction { get; } + + public KernelCSource FunctionCSource { get; } + + public Stream Text { get; } + + public IEnumerable FunctionRefs => Enumerable.Empty(); + + public IReadOnlyList Sections { get; } +} + +internal sealed class LinkableDeviceFunction : ILinkableFunction +{ + public LinkableDeviceFunction(uint id, TIR.PrimFunction sourceFunction, string header, Stream text) + { + Id = id; + SourceFunction = sourceFunction; + Header = header; + PrimFunction = sourceFunction; + Text = text; + Sections = Array.Empty(); + } + + public uint Id { get; } + + public BaseFunction SourceFunction { get; } + + public string Header { get; } + + public TIR.PrimFunction PrimFunction { get; } + + public Stream Text { get; } + + public IEnumerable FunctionRefs => Enumerable.Empty(); + + public IReadOnlyList Sections { get; } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/LinkableModule.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/LinkableModule.cs new file mode 100644 index 0000000000..5d0567970b --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/LinkableModule.cs @@ -0,0 +1,110 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
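+
+// NOTE: Link() below works in three steps: (1) write device.h collecting every
+// device-function header, (2) dump main.cpp / kernel.h / CMakeLists.txt per
+// kernel function, (3) compile each dump directory and append the binary to
+// .text. The compile step reduces to (paths illustrative):
+//
+//   var elf = new CSourceCompiler().Compile(dumpPath, Path.Join(dumpPath, "build", "nncase_cpu_module"));
+//   textWriter.Write(File.ReadAllBytes(elf));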
+ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Tasks; +using DryIoc.ImTools; +using Nncase.CodeGen.CPU; +using Nncase.Diagnostics; +using Nncase.Runtime.StackVM; + +namespace Nncase.CodeGen.CPU; + +internal sealed class LinkableModule : ILinkableModule +{ + private readonly Stream _rdata; + + private readonly IReadOnlyList _functions; + private readonly CompileOptions _options; + + public LinkableModule(Stream rdata, IReadOnlyList functions, CompileOptions options) + { + _rdata = rdata; + _functions = functions; + _options = options; + } + + public ILinkedModule Link(ILinkContext linkContext) + { + { + if (!Directory.Exists(_options.DumpDir)) + { + Directory.CreateDirectory(_options.DumpDir); + } + + using (var writer = new StreamWriter(File.Open(Path.Join(_options.DumpDir, "device.h"), FileMode.Create))) + { + writer.Write(CSourceBuiltn.KernelHeader); + + foreach (var func in _functions.OfType()) + { + writer.Write(func.Header); + } + } + } + + foreach (var func in _functions.OfType()) + { + var dumpPath = Path.Join(_options.DumpDir, func.PrimFunction.Name); + if (!Directory.Exists(dumpPath)) + { + Directory.CreateDirectory(dumpPath); + } + + using (var fs = File.Open(Path.Join(dumpPath, "main.cpp"), FileMode.Create)) + { + using (var writer = new StreamWriter(fs)) + { + writer.Write(func.FunctionCSource.Main); + } + } + + using (var fs = File.Open(Path.Join(dumpPath, "kernel.h"), FileMode.Create)) + { + using (var writer = new StreamWriter(fs)) + { + writer.Write(func.FunctionCSource.Kernel); + } + } + + using (var fs = File.Open(Path.Join(dumpPath, "CMakeLists.txt"), FileMode.Create)) + { + using (var writer = new StreamWriter(fs)) + { + writer.Write(CSourceBuiltn.CMakeDef(func.PrimFunction.Name)); + } + } + } + + var manager = new SectionManager(); + var textWriter = manager.GetWriter(WellknownSectionNames.Text); + var linkedFunctions = new List(); + int offset = 0; + foreach (var func in _functions.OfType()) + { + var dumpPath = Path.Join(_options.DumpDir, func.PrimFunction.Name); + var elfPath = CompileCSource(dumpPath); + + var func_text = File.ReadAllBytes(elfPath); + textWriter.Write(func_text); + linkedFunctions.Add(new LinkedFunction(func.Id, func.SourceFunction, (uint)offset, (uint)func_text.Length, func.Sections)); + offset += func_text.Length; + } + + return new LinkedModule(linkedFunctions, manager.GetContent(WellknownSectionNames.Text)!, _rdata); + } + + private string CompileCSource(string sourcePath) + { + var compiler = new CSourceCompiler(); + var binDir = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) + ? Path.Join(sourcePath, "build", "nncase_cpu_module.exe") + : Path.Join(sourcePath, "build", "nncase_cpu_module"); + return compiler.Compile(sourcePath, binDir); + } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/LinkedModule.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/LinkedModule.cs new file mode 100644 index 0000000000..a94e9a76f3 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/LinkedModule.cs @@ -0,0 +1,32 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
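+
+// NOTE: A linked CPU module carries exactly two sections, both 8-byte aligned:
+// .text (the concatenated per-function kernel binaries) and .rdata (constant
+// tensors written by FunctionBuilder). Each function is addressed by its
+// (offset, length) within .text, e.g. (values illustrative):
+//
+//   new LinkedFunction(id, sourceFunction, 0x1000, 0x200, sections);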
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.Runtime.StackVM;
+
+namespace Nncase.CodeGen.CPU;
+
+internal sealed class LinkedModule : ILinkedModule
+{
+    public LinkedModule(IReadOnlyList<LinkedFunction> functions, Stream text, Stream rdata)
+    {
+        Functions = functions;
+        Sections = new[]
+        {
+            new LinkedSection(text, WellknownSectionNames.Text, 0, 8, (ulong)text.Length),
+            new LinkedSection(rdata, WellknownSectionNames.Rdata, 0, 8, (ulong)rdata.Length),
+        };
+    }
+
+    public string ModuleKind => "cpu";
+
+    public uint Version => 0;
+
+    public IReadOnlyList<ILinkedFunction> Functions { get; }
+
+    public IReadOnlyList<ILinkedSection> Sections { get; }
+}
diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/ModuleBuilder.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/ModuleBuilder.cs
new file mode 100644
index 0000000000..ccbdb0d572
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/ModuleBuilder.cs
@@ -0,0 +1,38 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System.Text;
+using Nncase.Diagnostics;
+using Nncase.IR;
+
+namespace Nncase.CodeGen.CPU;
+
+/// <summary>
+/// CPU module builder.
+/// </summary>
+public sealed class CPUModuleBuilder : IModuleBuilder
+{
+    private readonly SectionManager _sectionManager;
+    private readonly BinaryWriter _rdataWriter;
+
+    public CPUModuleBuilder(CompileOptions options)
+    {
+        _sectionManager = new();
+        _rdataWriter = _sectionManager.GetWriter(WellknownSectionNames.Rdata);
+        CompileOptions = options;
+    }
+
+    public CompileOptions CompileOptions { get; }
+
+    /// <inheritdoc/>
+    public string ModuleKind => "cpu";
+
+    /// <inheritdoc/>
+    public ILinkableModule Build(IReadOnlyList<BaseFunction> functions)
+    {
+        var linkableFunctions = functions.OfType<TIR.PrimFunction>().Select((f, i) => new FunctionBuilder((uint)i, _rdataWriter).Build(f)).ToArray();
+        _rdataWriter.Flush();
+
+        return new LinkableModule(_sectionManager.GetContent(WellknownSectionNames.Rdata)!, linkableFunctions, CompileOptions);
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/CMakeLists.txt.cshtml b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/CMakeLists.txt.cshtml
new file mode 100644
index 0000000000..7b48b304d5
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/CMakeLists.txt.cshtml
@@ -0,0 +1,28 @@
+# This file is generated by Nncase CPU module builder.
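+#
+# Model.CMakePath is substituted by the Razor engine with the runtime's CMake
+# script, so a rendered file starts roughly like this (path illustrative):
+#
+#   cmake_minimum_required(VERSION 3.15)
+#   project(nncase_cpu_module)
+#   include(/opt/nncase/runtime/cmake/nncase_cpu_runtime.cmake)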
+ +cmake_minimum_required(VERSION 3.15) + +project(nncase_cpu_module) + +include(@Html.Raw(Model.CMakePath)) + +add_executable(nncase_cpu_module main.cpp) +target_compile_features(nncase_cpu_module PUBLIC cxx_std_20) +target_link_libraries(nncase_cpu_module PRIVATE nncase_cpu_runtime) +target_compile_definitions(nncase_cpu_module PUBLIC -DNNCASE_CPU_MODULE=1) + +if (MSVC) + set_target_properties(nncase_cpu_module PROPERTIES LINK_FLAGS /SUBSYSTEM:CONSOLE) + target_link_options(nncase_cpu_module PRIVATE /ENTRY:kernel_entry /NODEFAULTLIB) + target_link_libraries(nncase_cpu_module PRIVATE libvcruntime msvcrt) + set_property(TARGET nncase_cpu_module PROPERTY + MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>") +else() + target_link_options(nncase_cpu_module PRIVATE -static) + if (APPLE) + target_link_options(nncase_cpu_module PRIVATE -e _kernel_entry -bundle -ld_classic -lc) + else() + target_link_options(nncase_cpu_module PRIVATE -e kernel_entry -nostdlib) + target_link_libraries(nncase_cpu_module PRIVATE gcc) + endif() +endif() diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Binary.cshtml b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Binary.cshtml new file mode 100644 index 0000000000..bbd4779985 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Binary.cshtml @@ -0,0 +1,17 @@ +@model Nncase.CodeGen.CPU.BinaryKernelTemplateModel +@{ + string BinaryToCFunction(BinaryOp op) => + op switch + { + BinaryOp.Add => "ops::add", + BinaryOp.Sub => "ops::sub", + BinaryOp.Mul => "ops::mul", + BinaryOp.Div => "ops::div", + BinaryOp.Mod => "ops::mod", + BinaryOp.Min => "ops::min", + BinaryOp.Max => "ops::max", + BinaryOp.Pow => "ops::pow", + _ => throw new NotSupportedException($"Unsupported binary: {op}."), + }; +} +binary<@BinaryToCFunction(Model.BinaryOp)>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name)); diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Pack.cshtml b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Pack.cshtml new file mode 100644 index 0000000000..952534aaf2 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Pack.cshtml @@ -0,0 +1,4 @@ +@model Nncase.CodeGen.CPU.TypedKernelTemplateModel +@{ +} +pack<@string.Join(",", Model.Target.Axes)>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name)); diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/PackedLayerNorm.cshtml b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/PackedLayerNorm.cshtml new file mode 100644 index 0000000000..0ca328fb0b --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/PackedLayerNorm.cshtml @@ -0,0 +1,4 @@ +@model Nncase.CodeGen.CPU.TypedKernelTemplateModel +@{ +} +packed_layer_norm<@Model.Target.Axis>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name), @Html.Raw(Model.Arguments[3].Symbol.Name), @Html.Raw(Model.Args[0].CheckedTensorType.DType.ToC()) { @Model.Target.Epsilon }, @Model.Target.UseMean.ToString().ToLower(), fixed_shape<@string.Join(",", Model.Target.PackedAxes)>{}, fixed_shape<@string.Join(",", Model.Target.PadedNums)>{}); diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/PackedMatmul.cshtml b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/PackedMatmul.cshtml new file mode 100644 index 0000000000..28f1af3bb9 --- /dev/null +++ 
b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/PackedMatmul.cshtml @@ -0,0 +1,5 @@ +@model Nncase.CodeGen.CPU.TypedKernelTemplateModel +@{ +} +packed_matmul(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name), fixed_shape<@string.Join(",", Model.Target.LhsPackedAxes)>{}, fixed_shape<@string.Join(",", Model.Target.LhsPadedNums)>{}, fixed_shape<@string.Join(",", Model.Target.RhsPackedAxes)>{}, fixed_shape<@string.Join(",", Model.Target.RhsPadedNums)>{}); + diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/PackedSoftMax.cshtml b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/PackedSoftMax.cshtml new file mode 100644 index 0000000000..015c5a10c5 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/PackedSoftMax.cshtml @@ -0,0 +1,5 @@ +@model Nncase.CodeGen.CPU.TypedKernelTemplateModel +@{ +} +packed_softmax<@Model.Target.Axis>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), fixed_shape<@string.Join(",", Model.Target.PackedAxes)>{}); + diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/PackedTranspose.cshtml b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/PackedTranspose.cshtml new file mode 100644 index 0000000000..213d6c78b4 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/PackedTranspose.cshtml @@ -0,0 +1,4 @@ +@model Nncase.CodeGen.CPU.TypedKernelTemplateModel +@{ +} +transpose>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name)); diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Reshape.cshtml b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Reshape.cshtml new file mode 100644 index 0000000000..79d8d8a6bc --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Reshape.cshtml @@ -0,0 +1,4 @@ +@model Nncase.CodeGen.CPU.TypedKernelTemplateModel +@{ +} +tensor_copy(@(Html.Raw(Model.Arguments[0].Symbol.Name)).reshape(fixed_shape<@string.Join(",", Model.Target.NewShape)>{}), @Html.Raw(Model.Arguments[1].Symbol.Name)); diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Unary.cshtml b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Unary.cshtml new file mode 100644 index 0000000000..29b5f56f79 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Unary.cshtml @@ -0,0 +1,29 @@ +@model Nncase.CodeGen.CPU.UnaryKernelTemplateModel +@{ + string UnaryToCFunction(UnaryOp op) => + op switch + { + UnaryOp.Abs => "ops::abs", + UnaryOp.Acos => "ops::acos", + UnaryOp.Acosh => "ops::acosh", + UnaryOp.Asin => "ops::asin", + UnaryOp.Asinh => "ops::asinh", + UnaryOp.Ceil => "ops::ceil", + UnaryOp.Cos => "ops::cos", + UnaryOp.Cosh => "ops::cosh", + UnaryOp.Exp => "ops::exp", + UnaryOp.Floor => "ops::floor", + UnaryOp.Log => "ops::log", + UnaryOp.Neg => "ops::neg", + UnaryOp.Round => "ops::round", + UnaryOp.Rsqrt => "ops::rsqrt", + UnaryOp.Sign => "ops::sign", + UnaryOp.Sin => "ops::sin", + UnaryOp.Sinh => "ops::sinh", + UnaryOp.Sqrt => "ops::sqrt", + UnaryOp.Square => "ops::square", + UnaryOp.Tanh => "ops::tanh", + _ => throw new NotSupportedException($"Unsupported unary: {op}."), + }; +} +unary<@UnaryToCFunction(Model.UnaryOp)>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name)); diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Unpack.cshtml 
b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Unpack.cshtml new file mode 100644 index 0000000000..3154087509 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Unpack.cshtml @@ -0,0 +1,4 @@ +@model Nncase.CodeGen.CPU.TypedKernelTemplateModel +@{ +} +unpack<@string.Join(",", Model.Target.Axes)>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name)); diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/_ViewImports.cshtml b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/_ViewImports.cshtml new file mode 100644 index 0000000000..ad79fd8715 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/_ViewImports.cshtml @@ -0,0 +1,4 @@ +@using Nncase +@using Nncase.CodeGen.CPU +@using Nncase.TIR +@*@addTagHelper *, Microsoft.AspNetCore.Mvc.TagHelpers*@ diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/Boxing.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/Boxing.cs new file mode 100644 index 0000000000..e88422dc16 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/Boxing.cs @@ -0,0 +1,155 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. +#pragma warning disable SA1010, SA1008 +using System; +using System.Collections.Generic; +using System.Linq; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.IR.CPU; +using Nncase.Utilities; + +namespace Nncase.Evaluator.IR.CPU; + +public sealed class BoxingEvaluator : ITypeInferencer, ICostEvaluator, IEvaluator +{ + private const int _burstLength = 256; + + public IRType Visit(ITypeInferenceContext context, Boxing target) + { + return context.GetArgumentType(target, Boxing.Input) switch + { + InvalidType inv => inv, + _ => target.NewType, + }; + } + + public Cost Visit(ICostEvaluateContext context, Boxing target) + { + var inType = context.GetArgumentType(target, Boxing.Input); + var returnType = context.GetReturnType(); + var cost = new Cost() { [CostFactorNames.MemoryLoad] = 0, [CostFactorNames.MemoryStore] = 0 }; + switch (inType, returnType) + { + case (TensorType tensorType, DistributedType distTensorType): + cost = new Cost() + { + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(tensorType), + [CostFactorNames.MemoryStore] = (UInt128)((float)CostUtility.GetMemoryAccess(distTensorType) / DistributedUtility.GetDividedTensorEfficiency(distTensorType, _burstLength)), + }; + break; + case (DistributedType distTensorType, TensorType tensorType): + cost = new Cost() + { + [CostFactorNames.MemoryLoad] = (UInt128)((float)CostUtility.GetMemoryAccess(distTensorType) / DistributedUtility.GetDividedTensorEfficiency(distTensorType, _burstLength)), + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(tensorType), + }; + break; + + case (DistributedType a, DistributedType b) when a.Placement == b.Placement && a.NdSBP != b.NdSBP: + { + var fullLoadStore = new Cost() + { + [CostFactorNames.MemoryStore] = (UInt128)((float)CostUtility.GetMemoryAccess(a) / DistributedUtility.GetDividedTensorEfficiency(a, _burstLength)), + [CostFactorNames.MemoryLoad] = (UInt128)((float)CostUtility.GetMemoryAccess(b) / DistributedUtility.GetDividedTensorEfficiency(b, _burstLength)), + }; + + float scatterPart = 1; + float gatherPart = 1; + for (int i = 0; i < a.Placement.Rank; i++) + { + switch (a.NdSBP[i], b.NdSBP[i]) + { + case (SBPSplit { Axis: int ax }, SBP sbpout): + switch (sbpout) + { + case SBPSplit { Axis: int bx }: + if (ax != bx) + { + // when 
split different axis, need global load store. + return fullLoadStore; + } + + break; + case SBPBroadCast: + scatterPart *= a.Placement.Hierarchy[i]; + gatherPart *= a.Placement.Hierarchy[i]; + break; + default: + throw new NotSupportedException("split to partial"); + } + + break; + case (SBPBroadCast, SBPBroadCast or SBPSplit): + // no cost. + cost += new Cost() + { + [CostFactorNames.CPUCycles] = 1, + }; + break; + case (SBPPartialSum, SBP sbpout): + switch (sbpout) + { + case SBPPartialSum: + break; + case SBPBroadCast or SBPSplit: + gatherPart *= a.Placement.Hierarchy[i]; + if (i == 0) + { + scatterPart *= a.Placement.Hierarchy[i]; + } + + break; + } + + break; + default: + throw new NotSupportedException($"{a} to {b}"); + } + } + + if (gatherPart > 1f) + { + cost += new Cost() + { + [CostFactorNames.MemoryStore] = (UInt128)((gatherPart - 1) * (float)CostUtility.GetMemoryAccess(DistributedUtility.GetDividedTensorType(a)) / gatherPart), + }; + } + + if (scatterPart > 1f) + { + cost += new Cost() + { + [CostFactorNames.MemoryLoad] = (UInt128)((scatterPart - 1) * (float)CostUtility.GetMemoryAccess(DistributedUtility.GetDividedTensorType(b)) / scatterPart), + }; + } + } + + break; + case (DistributedType a, DistributedType b) when a.TensorType != b.TensorType && a.Placement == b.Placement: + cost = new Cost() + { + [CostFactorNames.MemoryStore] = (UInt128)((float)CostUtility.GetMemoryAccess(a) / DistributedUtility.GetDividedTensorEfficiency(a, _burstLength)), + [CostFactorNames.MemoryLoad] = (UInt128)((float)CostUtility.GetMemoryAccess(b) / DistributedUtility.GetDividedTensorEfficiency(b, _burstLength)), + }; + break; + case (DistributedType a, DistributedType b) when a == b: + throw new InvalidOperationException($"the boxing inType == outType"); + default: + throw new NotSupportedException($"{inType} {returnType}"); + } + + return cost; + } + + public IValue Visit(IEvaluateContext context, Boxing target) + { + var input = context.GetArgumentValueAsTensor(target, Boxing.Input); + return target.NewType switch + { + TensorType t => Value.FromTensor(Tensor.FromBytes(input.ElementType, input.BytesBuffer.ToArray(), t.Shape)), + DistributedType d => Value.FromTensor(Tensor.FromBytes(input.ElementType, input.BytesBuffer.ToArray(), d.TensorType.Shape)), + _ => Value.FromTensor(input), + }; + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUKernelOp.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUKernelOp.cs new file mode 100644 index 0000000000..39e1f58601 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUKernelOp.cs @@ -0,0 +1,34 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.IR.CPU; + +namespace Nncase.Evaluator.IR.CPU; + +/// +/// Evaluator for . 
+/// +public class CPUKernelOpEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator +{ + /// + public IValue Visit(IEvaluateContext context, CPUKernelOp target) + { + return CompilerServices.EvaluateOp(target.Target, context); + } + + /// + public IRType Visit(ITypeInferenceContext context, CPUKernelOp target) + { + return CompilerServices.InferenceOp(target.Target, context, new()); + } + + /// + public Cost Visit(ICostEvaluateContext context, CPUKernelOp target) + { + return CompilerServices.EvaluateOpCost(target.Target, context); + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUModule.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUModule.cs new file mode 100644 index 0000000000..70c0fc141c --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUModule.cs @@ -0,0 +1,28 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using DryIoc; +using Nncase.Hosting; + +namespace Nncase.Evaluator.IR.CPU; + +/// +/// CPU module. +/// +internal class CPUModule : IApplicationPart +{ + public void ConfigureServices(IRegistrator registrator) + { + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/Load.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/Load.cs new file mode 100644 index 0000000000..cf0902ce46 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/Load.cs @@ -0,0 +1,27 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.IR.CPU; + +namespace Nncase.Evaluator.IR.CPU; + +public sealed class LoadEvaluator : ITypeInferencer, ICostEvaluator +{ + public IRType Visit(ITypeInferenceContext context, Load target) + { + return context.GetArgumentType(target, Load.Input); + } + + public Cost Visit(ICostEvaluateContext context, Load target) => new Cost() + { + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(context.GetArgumentType(target, Load.Input)), + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(context.GetArgumentType(target, Load.Input)), + }; +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/Pack.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/Pack.cs new file mode 100644 index 0000000000..710a29cbc5 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/Pack.cs @@ -0,0 +1,81 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
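+
+// NOTE: Pack folds Lanes[i] consecutive elements along Axes[i] into vector
+// lanes, so each packed axis shrinks by its lane count while the element type
+// becomes a VectorType. A shape-level sketch (lane/axis values illustrative):
+//
+//   input : f32[8, 16],  Lanes = {4}, Axes = {1}
+//   output: vector<f32, 4>[8, 4]   // 16 scalars -> 4 elements of 4 lanes each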
+#pragma warning disable SA1010, SA1008 +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.IR.CPU; +using Nncase.IR.Tensors; +using Nncase.Utilities; +using OrtKISharp; + +namespace Nncase.Evaluator.IR.CPU; + +public sealed class PackEvaluator : ITypeInferencer, ICostEvaluator, IEvaluator +{ + /// + public IValue Visit(IEvaluateContext context, Pack target) + { + var input = context.GetOrtArgumentValue(target, Pack.Input); + foreach (var (lanes, axis) in target.Lanes.Zip(target.Axes)) + { + input = input.Pack(lanes, axis); + } + + return Value.FromTensor(Tensor.FromBytes(new VectorType(input.DataType.ToDataType(), target.Lanes), input.BytesBuffer.ToArray(), input.Shape.ToArray().SkipLast(target.Lanes.Count).Select(i => (int)i).ToArray())); + } + + /// + public IRType Visit(ITypeInferenceContext context, Pack target) + { + var input = context.CheckArgumentType(target, Pack.Input); + + return input switch + { + DistributedType d => Visit(context, target, d), + TensorType t => Visit(context, target, t), + AnyType => AnyType.Default, + _ => new InvalidType(input.GetType().ToString()), + }; + } + + /// + public Cost Visit(ICostEvaluateContext context, Pack target) + { + var inputType = context.GetArgumentType(target, Pack.Input); + var outputType = context.GetReturnType(); + + return new() + { + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType), + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(outputType), + }; + } + + public Metric Visit(IMetricEvaluateContext context, Pack target) + { + var returnType = context.GetReturnType(); + return new() + { + [MetricFactorNames.OffChipMemoryTraffic] = CostUtility.GetMemoryAccess(returnType) * 2, + }; + } + + private IRType Visit(ITypeInferenceContext context, Pack target, TensorType input) + { + return TypeInference.PackType(input, target.Lanes, target.Axes); + } + + private IRType Visit(ITypeInferenceContext context, Pack target, DistributedType input) + { + if (Visit(context, target, input.TensorType) is not TensorType tensorType) + { + throw new InvalidOperationException(); + } + + return new DistributedType(tensorType, input.NdSBP, input.Placement); + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedBinary.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedBinary.cs new file mode 100644 index 0000000000..8a1d5fa5aa --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedBinary.cs @@ -0,0 +1,230 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
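+
+// NOTE: Type inference below first recovers the pre-pack ("origin") shape of
+// each operand by widening every packed axis with its lane count and removing
+// the pads, then classifies each output dim as elementwise (E) or broadcast
+// (B); packing on a broadcast dim is rejected. Origin-shape recovery, as in
+// the code below:
+//
+//   lhsOrginShape[axes[i]] = (lhsOrginShape[axes[i]] * lanes[i]) - lhsPadedNums[i];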
+ +#pragma warning disable SA1008 // Opening parenthesis should be spaced correctly + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Numerics; +using System.Runtime.InteropServices; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.IR.CPU; +using Nncase.Utilities; +using OrtKISharp; + +namespace Nncase.Evaluator.IR.CPU; + +public sealed class PackedBinaryEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator +{ + internal enum DimKind : int + { + E, // elemwise + B, // broadcast + } + + public IValue Visit(IEvaluateContext context, PackedBinary target) + { + var a = context.GetOrtArgumentValue(target, PackedBinary.Lhs); + var b = context.GetOrtArgumentValue(target, PackedBinary.Rhs); + _ = System.Math.Max(target.LhsPackedAxes.Count, target.RhsPackedAxes.Count); + + switch (target.LhsPackedAxes.Count, target.RhsPackedAxes.Count) + { + case (2, 1): + b = OrtKI.Unsqueeze(b, new long[] { -2 }); + break; + case (1, 2): + a = OrtKI.Unsqueeze(a, new long[] { -2 }); + break; + default: + break; + } + + var binary = target.BinaryOp switch + { + BinaryOp.Add => a + b, + BinaryOp.Sub => a - b, + BinaryOp.Mul => a * b, + BinaryOp.Div => a / b, + _ => throw new ArgumentOutOfRangeException(target.BinaryOp.ToString()), + }; + + return Value.FromTensor(Tensor.FromBytes(context.CurrentCall.CheckedDataType, binary.BytesBuffer.ToArray(), context.CurrentCall.CheckedShape)); + } + + public IRType Visit(ITypeInferenceContext context, PackedBinary target) + { + var lhs = context.CheckArgumentType(target, PackedBinary.Lhs); + var rhs = context.CheckArgumentType(target, PackedBinary.Rhs); + + return (lhs, rhs) switch + { + (DistributedType a, DistributedType b) => Visit(target, a, b), + (TensorType a, TensorType b) => Visit(target, a, b), + _ => new InvalidType("not support"), + }; + } + + public Cost Visit(ICostEvaluateContext context, PackedBinary target) + { + var lhs = context.GetArgumentType(target, PackedBinary.Lhs); + var rhs = context.GetArgumentType(target, PackedBinary.Rhs); + var outputType = context.GetReturnType(); + + uint macPerElement = 1; + if (lhs is TensorType { Shape: Shape lhsShape }) + { + macPerElement = lhsShape[^1].IsFixed ? (uint)lhsShape[^1].FixedValue : 1U; + } + else if (lhs is DistributedType distributedType) + { + var lhsType = DistributedUtility.GetDividedTensorType(distributedType); + macPerElement = lhsType.Shape[^1].IsFixed ? 
(uint)lhsType.Shape[^1].FixedValue : 1U; + } + + return new() + { + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(lhs) + CostUtility.GetMemoryAccess(rhs), + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(outputType), + [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(outputType, macPerElement), + }; + } + + private IRType Visit(PackedBinary target, TensorType a, TensorType b) + { + var rank = System.Math.Max(a.Shape.Rank, b.Shape.Rank); + var outShape = new int[rank]; + var lhsOrginShape = a.Shape.ToValueArray(); + var rhsOrginShape = b.Shape.ToValueArray(); + for (int i = 0; i < target.LhsPackedAxes.Count; i++) + { + lhsOrginShape[target.LhsPackedAxes[i]] = (lhsOrginShape[target.LhsPackedAxes[i]] * ((VectorType)a.DType).Lanes[i]) - target.LhsPadedNums[i]; + } + + for (int i = 0; i < target.RhsPackedAxes.Count; i++) + { + rhsOrginShape[target.RhsPackedAxes[i]] = (rhsOrginShape[target.RhsPackedAxes[i]] * ((VectorType)b.DType).Lanes[i]) - target.RhsPadedNums[i]; + } + + var orginKinds = new DimKind[rank]; + + for (int i = -1; i >= -rank; i--) + { + var aAxis = a.Shape.Rank + i; + var bAxis = b.Shape.Rank + i; + switch (aAxis, bAxis) + { + case ( < 0, _): + outShape[rank + i] = b.Shape[bAxis].FixedValue; + orginKinds[rank + i] = DimKind.B; + break; + case (_, < 0): + outShape[rank + i] = a.Shape[aAxis].FixedValue; + orginKinds[rank + i] = DimKind.B; + break; + case ( >= 0, >= 0): + switch (lhsOrginShape[aAxis], rhsOrginShape[bAxis]) + { + case (int l, int r) when l == r: + outShape[rank + i] = a.Shape[aAxis].FixedValue; + orginKinds[rank + i] = DimKind.E; + break; + case (1, _): + outShape[rank + i] = b.Shape[bAxis].FixedValue; + orginKinds[rank + i] = DimKind.B; + break; + case (_, 1): + outShape[rank + i] = a.Shape[aAxis].FixedValue; + orginKinds[rank + i] = DimKind.B; + break; + default: + return new InvalidType("packed binary not support dim"); + } + + break; + default: + throw new NotSupportedException(); + } + } + + // second check the dtype. + DataType dataType; + switch (a.DType, b.DType) + { + case (VectorType va, VectorType vb): + { + var lanes = System.Math.Max(va.Lanes.Count, vb.Lanes.Count); + var valid = true; + for (int i = -1; i >= -lanes; --i) + { + var ai = va.Lanes.Count + i; + var bi = vb.Lanes.Count + i; + switch (ai, bi) + { + case ( < 0, _): + valid &= orginKinds[target.RhsPackedAxes[bi] - b.Shape.Rank + rank] == DimKind.B && rhsOrginShape[target.RhsPackedAxes[bi]] != 1; + break; + case (_, < 0): + valid &= orginKinds[target.LhsPackedAxes[ai] - a.Shape.Rank + rank] == DimKind.B && lhsOrginShape[target.LhsPackedAxes[ai]] != 1; + break; + case ( >= 0, >= 0): + var laxis = target.LhsPackedAxes[ai] - a.Shape.Rank + rank; + var raxis = target.RhsPackedAxes[bi] - b.Shape.Rank + rank; + valid &= lhsOrginShape[target.LhsPackedAxes[ai]] == rhsOrginShape[target.RhsPackedAxes[bi]] && laxis == raxis && orginKinds[laxis] == orginKinds[raxis] && orginKinds[raxis] == DimKind.E; + break; + } + } + + if (valid) + { + dataType = va.Lanes.Count >= vb.Lanes.Count ? 
va : vb; + } + else + { + return new InvalidType("can't pack on the broadcast axis!"); + } + } + + break; + case (VectorType va, PrimType pb): + if (va.ElemType != pb) + { + return new InvalidType("Shape Can't Broadcast"); + } + + dataType = va; + break; + case (PrimType pa, VectorType vb): + if (vb.ElemType != pa) + { + return new InvalidType("Shape Can't Broadcast"); + } + + dataType = vb; + break; + default: + return new InvalidType("Shape Can't Broadcast"); + } + + return new TensorType(dataType, outShape); + } + + private IRType Visit(PackedBinary target, DistributedType a, DistributedType b) + { + if (a.Placement != b.Placement) + { + return new InvalidType("lhs rhs have different placement"); + } + + var rType = Visit(target, a.TensorType, b.TensorType); + if (rType is not TensorType tensorType) + { + return rType; + } + + return Math.BinaryEvaluator.CheckSBP(target.BinaryOp, tensorType, a, b); + } +} +#pragma warning restore SA1008 // Opening parenthesis should be spaced correctly diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedLayerNorm.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedLayerNorm.cs new file mode 100644 index 0000000000..5d2397daee --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedLayerNorm.cs @@ -0,0 +1,206 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using System.Numerics; +using System.Runtime.InteropServices; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.IR.CPU; +using Nncase.Utilities; +using OrtKISharp; + +namespace Nncase.Evaluator.IR.CPU; + +public sealed class PackedLayerNormEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator, + IShapeEvaluator, IMetricEvaluator +{ + /// + public IValue Visit(IEvaluateContext context, PackedLayerNorm target) + { + var input = context.GetOrtArgumentValue(target, PackedLayerNorm.Input); + var scale = context.GetOrtArgumentValue(target, PackedLayerNorm.Scale); + var bias = context.GetOrtArgumentValue(target, PackedLayerNorm.Bias); + var lanes = input.Shape.TakeLast(target.PackedAxes.Count).Select(i => (int)i).ToArray(); + var unpackedInput = UnpackTensor(input, target.PackedAxes, target.PadedNums); + var packAxes = target.PackedAxes.Where(axis => axis >= target.Axis).Select(axis => axis - target.Axis).ToArray(); + var padedNums = target.PadedNums.Skip(target.PackedAxes.Count - packAxes.Length).ToArray(); + var unpackedScale = UnpackTensor(scale, packAxes, padedNums); + var unpackedBias = UnpackTensor(bias, packAxes, padedNums); + + var shape = unpackedInput.Shape.Select(i => (int)i).ToArray(); + var inputBuffer = unpackedInput.BytesBuffer.ToArray(); + var inputSpan = MemoryMarshal.Cast(inputBuffer); + var scaleBuffer = unpackedScale.BytesBuffer.ToArray(); + var scaleSpan = MemoryMarshal.Cast(scaleBuffer); + var biasBuffer = unpackedBias.BytesBuffer.ToArray(); + var biasSpan = MemoryMarshal.Cast(biasBuffer); + + var output = NN.LayerNormEvaluator.LayerNormImpl(shape, inputSpan, scaleSpan, biasSpan, target.Axis, target.Epsilon, target.UseMean); + var outputTensor = OrtKISharp.Tensor.MakeTensor(new Memory(output), OrtDataType.Float, unpackedInput.Shape); + outputTensor = RepackTensor(outputTensor, lanes, target.PackedAxes, target.PadedNums); + + return Value.FromTensor(Tensor.FromBytes(new VectorType(DataTypes.Float32, lanes), outputTensor.BytesBuffer.ToArray(), outputTensor.Shape.SkipLast(target.PackedAxes.Count).Select(i => 
(int)i).ToArray())); + } + + /// + public IRType Visit(ITypeInferenceContext context, PackedLayerNorm target) + { + var input = context.CheckArgumentType(target, PackedLayerNorm.Input); + var scale = context.CheckArgumentType(target, PackedLayerNorm.Scale); + var bias = context.CheckArgumentType(target, PackedLayerNorm.Bias); + + return (input, scale, bias) switch + { + (DistributedType a, DistributedType b, DistributedType c) => Visit(a, b, c, target.Axis), + (TensorType a, TensorType, TensorType) => Visit(a), + _ => new InvalidType(input.GetType().ToString()), + }; + } + + /// + public Cost Visit(ICostEvaluateContext context, PackedLayerNorm target) + { + var inputType = context.GetArgumentType(target, PackedLayerNorm.Input); + var returnType = context.GetReturnType(); + switch (inputType, returnType) + { + case (TensorType, TensorType): + return new() + { + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType), + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(returnType), + }; + + case (DistributedType inputDistributedType, DistributedType): + var scaleType = context.GetArgumentType(target, PackedLayerNorm.Scale); + var biasType = context.GetArgumentType(target, PackedLayerNorm.Bias); + var ring = GetRingReduceCommunicate(scaleType, new[] { 0, 1 }) + GetRingReduceCommunicate(biasType, new[] { 0, 1 }); + var reCompute = inputDistributedType.NdSBP.Select((sbp, i) => sbp is SBPSplit ? 1 : inputDistributedType.Placement.Hierarchy[i]).ToArray().Aggregate(1, (acc, rep) => acc * rep); + return new() + { + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType) + ring, + [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(inputType, 1) * (UInt128)reCompute, + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(returnType) + ring, + }; + default: + throw new NotSupportedException(); + } + } + + public Metric Visit(IMetricEvaluateContext context, PackedLayerNorm target) + { + var inputType = context.GetArgumentType(target, PackedLayerNorm.Input); + var returnType = context.GetReturnType(); + + var r = MetricUtility.GetFLOPs(returnType); + var i = MetricUtility.GetFLOPs(inputType); + var outter = i / r; + var inner = i / outter; + + return new() + { + [MetricFactorNames.OffChipMemoryTraffic] = CostUtility.GetMemoryAccess(inputType) + CostUtility.GetMemoryAccess(returnType), + [MetricFactorNames.FLOPs] = outter * ((inner * 7) + MetricUtility.SqrtFLOPs), + [MetricFactorNames.Parallel] = 4, + }; + } + + public Expr Visit(IShapeEvaluateContext context, PackedLayerNorm target) => context.GetArgumentShape(target, PackedLayerNorm.Input); + + private static OrtKISharp.Tensor UnpackTensor(OrtKISharp.Tensor input, IRArray packedAxes, IRArray padNums) + { + OrtKISharp.Tensor unpacked = input; + foreach (var axis in packedAxes.Reverse()) + { + unpacked = unpacked.Unpack(axis); + } + + var shape = unpacked.Shape.ToArray(); + + OrtKISharp.Tensor sliced = unpacked; + if (padNums.Any(i => i > 0)) + { + sliced = OrtKI.Slice(unpacked, Enumerable.Repeat(0L, padNums.Count).ToArray(), Enumerable.Range(0, padNums.Count).Select(i => shape[packedAxes[i]] - padNums[i]).ToArray(), packedAxes.Select(i => (long)i).ToArray(), Enumerable.Range(0, padNums.Count).Select(i => 1L).ToArray()); + } + + return sliced; + } + + private static OrtKISharp.Tensor RepackTensor(OrtKISharp.Tensor input, IRArray lanes, IRArray packedAxes, IRArray padNums) + { + OrtKISharp.Tensor paded = input; + var shape = input.Shape; + + if (padNums.Any(i => i > 0)) + { + var pads = 
Enumerable.Repeat(0L, shape.Length * 2).ToArray(); + for (int i = 0; i < packedAxes.Count; i++) + { + pads[shape.Length + packedAxes[i]] = padNums[i]; + } + + // bottom_0,bottom_1,..., top_0, top_1, ... + paded = OrtKI.Pad(paded, pads, 0f, "constant"); + } + + OrtKISharp.Tensor packed = paded; + foreach (var (lane, axis) in lanes.Zip(packedAxes)) + { + packed = packed.Pack(lane, axis); + } + + return packed; + } + + private IRType Visit(TensorType input) + { + return input; + } + + private IRType Visit(DistributedType input, DistributedType scale, DistributedType bias, int raxis) + { + var invalid = new InvalidType($"{input}, {scale}, {bias} not support"); + if (input.Placement != scale.Placement || scale.Placement != bias.Placement) + { + return invalid; + } + + var ndsbp = new SBP[input.Placement.Rank]; + + for (int i = 0; i < input.Placement.Rank; i++) + { + switch (input.NdSBP[i], scale.NdSBP[i], bias.NdSBP[i]) + { + case (SBPSplit { Axis: int ix }, SBPSplit { Axis: int sx }, SBPSplit { Axis: int bx }) when ix >= raxis && sx == (ix - raxis) && bx == sx: + ndsbp[i] = SBP.S(ix); + break; + case (SBPSplit { Axis: int ix }, SBPBroadCast, SBPBroadCast) when ix < raxis: + ndsbp[i] = SBP.S(ix); + break; + case (SBPBroadCast, SBPBroadCast, SBPBroadCast): + ndsbp[i] = SBP.B; + break; + default: + return invalid; + } + } + + return new DistributedType(input.TensorType, ndsbp, input.Placement); + } + + private UInt128 GetRingReduceCommunicate(DistributedType distributedType, int[] axes) + { + var ttype = Utilities.DistributedUtility.GetDividedTensorType(distributedType); + var splits = axes.Where(i => i < distributedType.Placement.Rank && distributedType.NdSBP[i] is SBPSplit); + if (!splits.Any()) + { + return 0; + } + + var p = (UInt128)splits.Select(i => distributedType.Placement.Hierarchy[i]).Aggregate(1, (acc, i) => acc * i); + var v = CostUtility.GetMemoryAccess(distributedType.TensorType); + return (p - 1) * (v / p); + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedMatMul.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedMatMul.cs new file mode 100644 index 0000000000..e327b2b4fd --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedMatMul.cs @@ -0,0 +1,146 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using System.Numerics; +using System.Runtime.InteropServices; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.IR.CPU; +using Nncase.Utilities; +using OrtKISharp; + +namespace Nncase.Evaluator.IR.CPU; + +public sealed class PackedMatMulEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator +{ + public IValue Visit(IEvaluateContext context, PackedMatMul target) + { + var lhs = context.GetOrtArgumentValue(target, PackedMatMul.Lhs); // [x,m/32,k/32,m',k'] + var rhs = context.GetOrtArgumentValue(target, PackedMatMul.Rhs); // [x,k/32,n/32,k',n'] + + var outLanes = target.LhsPackedAxes.Count == 1 ? Array.Empty() : new[] { (int)lhs.Shape[^2], (int)rhs.Shape[^1] }; + var outshape = target.LhsPackedAxes.Count == 1 ? new[] { (int)lhs.Shape[^3], (int)rhs.Shape[^2] } : new[] { (int)lhs.Shape[^4], (int)rhs.Shape[^3] }; + var maxRank = System.Math.Max(lhs.Shape.Length, rhs.Shape.Length); + outshape = Enumerable.Repeat(1L, maxRank - lhs.Shape.Length).Concat(lhs.Shape.SkipLast(2 + target.LhsPackedAxes.Count)). 
+ Zip(Enumerable.Repeat(1L, maxRank - rhs.Shape.Length).Concat(rhs.Shape.SkipLast(2 + target.RhsPackedAxes.Count))). + Select(p => (int)System.Math.Max(p.First, p.Second)). + Concat(outshape).ToArray(); + + foreach (var axis in target.LhsPackedAxes.Reverse()) + { + lhs = lhs.Unpack(axis); + } + + foreach (var axis in target.RhsPackedAxes.Reverse()) + { + rhs = rhs.Unpack(axis); + } + + // lhs = OrtKI.Unsqueeze(lhs, new long[] { -4, -1 }); // [x,m/32,k/32, 1 , m' ,k', 1 ] + // rhs = OrtKI.Unsqueeze(rhs, new long[] { -6, -3 }); // [x, 1 ,k/32,n/32, 1 ,k', n'] + // var matmul = OrtKI.Mul(lhs, rhs); // [x, m/32,k/32,n/32,m',k',n'] + // matmul = OrtKI.ReduceSum(matmul, new long[] { -2, -5 }, 0, 1); + var matmul = OrtKI.MatMul(lhs, rhs); + if (target.LhsPackedAxes.Count == 2) + { + foreach (var (lane, axis) in outLanes.Zip(new[] { -2 + outshape.Length, -1 + outshape.Length })) + { + matmul = matmul.Pack(lane, axis); + } + } + + return Value.FromTensor(Tensor.FromBytes(outLanes.Length == 0 ? DataTypes.Float32 : new VectorType(DataTypes.Float32, outLanes), matmul.BytesBuffer.ToArray(), outshape)); + } + + public IRType Visit(ITypeInferenceContext context, PackedMatMul target) + { + var lhs = context.CheckArgumentType(target, PackedMatMul.Lhs); + var rhs = context.CheckArgumentType(target, PackedMatMul.Rhs); + + bool CheckPackAxes(Shape lhs, Shape rhs) + { + bool valid = true; + switch (target.LhsPackedAxes.Count, target.RhsPackedAxes.Count) + { + case (1, 1): + if (target.LhsPackedAxes[0] != lhs.Rank - 1 || target.RhsPackedAxes[0] != rhs.Rank - 2) + { + valid = false; + } + + break; + case (2, 2): + if (target.LhsPackedAxes[0] != lhs.Rank - 2 || target.LhsPackedAxes[1] != lhs.Rank - 1) + { + valid = false; + } + + if (target.RhsPackedAxes[0] != rhs.Rank - 2 || target.RhsPackedAxes[1] != rhs.Rank - 1) + { + valid = false; + } + + break; + default: + valid = false; + break; + } + + return valid; + } + + IRType rType; + switch (lhs, rhs) + { + case (DistributedType a, DistributedType b): + if (!CheckPackAxes(a.TensorType.Shape, b.TensorType.Shape)) + { + goto ERROR; + } + + rType = Math.MatMulEvaluator.VisitDistributedType(a, b); + + break; + case (TensorType a, TensorType b): + if (!CheckPackAxes(a.Shape, b.Shape)) + { + goto ERROR; + } + + rType = Math.MatMulEvaluator.VisitTensorType(a, b); + break; + default: + ERROR: rType = new InvalidType($"{lhs} {rhs} not support"); + break; + } + + return rType; + } + + public Cost Visit(ICostEvaluateContext context, PackedMatMul target) + { + var lhs = context.GetArgumentType(target, PackedMatMul.Lhs); + var rhs = context.GetArgumentType(target, PackedMatMul.Rhs); + var outputType = context.GetReturnType(); + + uint macPerElement = 1; + if (lhs is TensorType { Shape: Shape lhsShape }) + { + macPerElement = lhsShape[^1].IsFixed ? (uint)lhsShape[^1].FixedValue : 1U; + } + else if (lhs is DistributedType distributedType) + { + var lhsType = DistributedUtility.GetDividedTensorType(distributedType); + macPerElement = lhsType.Shape[^1].IsFixed ? 
(uint)lhsType.Shape[^1].FixedValue : 1U; + } + + return new() + { + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(lhs) + CostUtility.GetMemoryAccess(rhs), + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(outputType), + [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(outputType, macPerElement), + }; + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedSoftMax.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedSoftMax.cs new file mode 100644 index 0000000000..0171708cf1 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedSoftMax.cs @@ -0,0 +1,85 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.IR.CPU; +using Nncase.IR.Tensors; +using Nncase.Utilities; +using OrtKISharp; + +namespace Nncase.Evaluator.IR.CPU; + +public sealed class PackedSoftMaxEvaluator : ITypeInferencer, ICostEvaluator, IEvaluator +{ + public IRType Visit(ITypeInferenceContext context, PackedSoftmax target) + { + var input = context.CheckArgumentType(target, PackedSoftmax.Input); + + return input switch + { + DistributedType d => Visit(context, target, d), + TensorType t => Visit(context, target, t), + AnyType => AnyType.Default, + _ => new InvalidType(input.GetType().ToString()), + }; + } + + public Cost Visit(ICostEvaluateContext context, PackedSoftmax target) + { + var returnType = context.GetReturnType(); + return new() + { + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(returnType), + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(returnType), + }; + } + + public IValue Visit(IEvaluateContext context, PackedSoftmax target) + { + var input = context.GetOrtArgumentValue(target, PackedSoftmax.Input); + var shape = input.Shape.Select(i => (int)i).ToArray(); + OrtKISharp.Tensor softmax; + if (!target.PackedAxes.Any(i => i == target.Axis)) + { + softmax = OrtKI.Softmax(input, target.Axis); + } + else + { + var packedAxis = shape.Length - target.PackedAxes.Count + target.PackedAxes.IndexOf(target.Axis); + var max = OrtKI.ReduceMax(input, new long[] { target.Axis, packedAxis }, 1); + var exp = OrtKI.Exp(input - max); + var reduceSum = OrtKI.ReduceSum(exp, new long[] { target.Axis, packedAxis }, 1, 0); + softmax = OrtKI.Div(exp, reduceSum); + } + + return Value.FromTensor(Tensor.FromBytes(new TensorType(new VectorType(input.DataType.ToDataType(), shape.TakeLast(target.PackedAxes.Count).ToArray()), shape.SkipLast(target.PackedAxes.Count).ToArray()), softmax.BytesBuffer.ToArray())); + } + + private IRType Visit(ITypeInferenceContext context, PackedSoftmax target, TensorType input) + { + foreach (var axis in target.PackedAxes) + { + if (axis >= input.Shape.Rank) + { + return new InvalidType("axis out of range"); + } + } + + return input; + } + + private IRType Visit(ITypeInferenceContext context, PackedSoftmax target, DistributedType input) + { + if (Visit(context, target, input.TensorType) is not TensorType tensorType) + { + throw new InvalidOperationException(); + } + + return new DistributedType(tensorType, input.NdSBP, input.Placement); + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedTranspose.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedTranspose.cs new file mode 100644 index 0000000000..e2b8d1ab9f --- /dev/null +++ 
b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedTranspose.cs @@ -0,0 +1,60 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.IR.CPU; +using Nncase.Utilities; +using OrtKISharp; + +namespace Nncase.Evaluator.IR.CPU; + +public sealed class PackedTransposeEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator +{ + public IValue Visit(IEvaluateContext context, PackedTranspose target) + { + var input = context.GetOrtArgumentValue(target, PackedTranspose.Input); + var perm = context.GetArgumentValueAsArray(target, PackedTranspose.Perm); + + var packedAxes = target.PackedAxes.Select(axis => perm.IndexOf(axis)).ToArray(); + var restAxis = LinqUtility.Range(perm.Length, packedAxes.Length).ToArray(); + restAxis = packedAxes.Zip(restAxis).OrderBy(p => p.First).Select(p => p.Second).ToArray(); + + perm = perm.Concat(restAxis).ToArray(); + + var transposed = OrtKI.Transpose(input, perm); + + return Value.FromTensor(Tensor.FromBytes(context.CurrentCall.CheckedDataType, transposed.BytesBuffer.ToArray(), context.CurrentCall.CheckedShape.ToValueArray())); + } + + public IRType Visit(ITypeInferenceContext context, PackedTranspose target) + { + var input = context.CheckArgumentType(target, PackedTranspose.Input); + var permExpr = context.GetArgument(target, PackedTranspose.Perm); + + return input switch + { + DistributedType d => Tensors.TransposeEvaluator.Visit(d, permExpr), + TensorType t => Tensors.TransposeEvaluator.Visit(t, permExpr), + AnyType => AnyType.Default, + _ => new InvalidType(input.GetType().ToString()), + }; + } + + /// + public Cost Visit(ICostEvaluateContext context, PackedTranspose target) + { + var inputType = context.GetArgumentType(target, PackedTranspose.Input); + var outputType = context.GetReturnType(); + + return new() + { + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType), + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(outputType), + }; + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/Store.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/Store.cs new file mode 100644 index 0000000000..a367696bba --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/Store.cs @@ -0,0 +1,27 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
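+// Note: Store is type-transparent: the evaluator below simply forwards the
+// input's IRType, and its cost model charges one full read plus one full write
+// of that tensor, i.e. pure data movement with no compute term.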
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.CostModel;
+using Nncase.IR;
+using Nncase.IR.CPU;
+
+namespace Nncase.Evaluator.IR.CPU;
+
+public sealed class StoreEvaluator : ITypeInferencer<Store>, ICostEvaluator<Store>
+{
+    public IRType Visit(ITypeInferenceContext context, Store target)
+    {
+        return context.GetArgumentType(target, Store.Input);
+    }
+
+    public Cost Visit(ICostEvaluateContext context, Store target) => new Cost()
+    {
+        [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(context.GetArgumentType<IRType>(target, Store.Input)),
+        [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(context.GetArgumentType<IRType>(target, Store.Input)),
+    };
+}
diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/Unpack.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/Unpack.cs
new file mode 100644
index 0000000000..0f861e7160
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/Unpack.cs
@@ -0,0 +1,82 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+#pragma warning disable SA1010, SA1008
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using DryIoc.ImTools;
+using Nncase.CostModel;
+using Nncase.IR;
+using Nncase.IR.CPU;
+using Nncase.IR.Tensors;
+using Nncase.Utilities;
+using OrtKISharp;
+
+namespace Nncase.Evaluator.IR.CPU;
+
+public sealed class UnpackEvaluator : ITypeInferencer<Unpack>, ICostEvaluator<Unpack>, IEvaluator<Unpack>
+{
+    /// <inheritdoc/>
+    public IValue Visit(IEvaluateContext context, Unpack target)
+    {
+        var input = context.GetOrtArgumentValue(target, Unpack.Input);
+        foreach (var axis in target.Axes.Reverse())
+        {
+            input = input.Unpack(axis);
+        }
+
+        return Value.FromTensor(input.ToTensor());
+    }
+
+    /// <inheritdoc/>
+    public IRType Visit(ITypeInferenceContext context, Unpack target)
+    {
+        var input = context.CheckArgumentType<IRType>(target, Unpack.Input);
+
+        return input switch
+        {
+            DistributedType d => Visit(context, target, d),
+            TensorType t => Visit(context, target, t),
+            AnyType => AnyType.Default,
+            _ => new InvalidType(input.GetType().ToString()),
+        };
+    }
+
+    /// <inheritdoc/>
+    public Cost Visit(ICostEvaluateContext context, Unpack target)
+    {
+        var inputType = context.GetArgumentType<IRType>(target, Unpack.Input);
+        var outputType = context.GetReturnType<IRType>();
+
+        return new()
+        {
+            [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType),
+            [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(outputType),
+        };
+    }
+
+    public Metric Visit(IMetricEvaluateContext context, Unpack target)
+    {
+        var returnType = context.GetReturnType<IRType>();
+        return new()
+        {
+            [MetricFactorNames.OffChipMemoryTraffic] = CostUtility.GetMemoryAccess(returnType) * 2,
+        };
+    }
+
+    private IRType Visit(ITypeInferenceContext context, Unpack target, TensorType input)
+    {
+        return TypeInference.UnpackType(input, target.Axes);
+    }
+
+    private IRType Visit(ITypeInferenceContext context, Unpack target, DistributedType input)
+    {
+        if (Visit(context, target, input.TensorType) is not TensorType tensorType)
+        {
+            throw new InvalidOperationException();
+        }
+
+        return new DistributedType(tensorType, input.NdSBP, input.Placement);
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Binary.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Binary.cs
new file mode 100644
index 0000000000..71f8cb4e8b
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Binary.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR;
+using Nncase.TIR.CPU;
+
+namespace Nncase.Evaluator.TIR.CPU;
+
+public sealed class BinaryEvaluator : ITypeInferencer<Binary>
+{
+    public IRType Visit(ITypeInferenceContext context, Binary target)
+    {
+        return TupleType.Void;
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/CPUModule.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/CPUModule.cs
new file mode 100644
index 0000000000..2f81cdf0ba
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/CPUModule.cs
@@ -0,0 +1,42 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using DryIoc;
+using Nncase.Evaluator.Imaging;
+using Nncase.Evaluator.NN;
+using Nncase.Evaluator.Tensors;
+using Nncase.Hosting;
+
+namespace Nncase.Evaluator.TIR.CPU;
+
+/// <summary>
+/// CPU module.
+/// </summary>
+internal class CPUModule : IApplicationPart
+{
+    public void ConfigureServices(IRegistrator registrator)
+    {
+        registrator.RegisterManyInterface<BinaryEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<ConcatEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<GatherEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<MatmulEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<MemcopyEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<PackEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<PackedBinaryEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<PackedLayerNormEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<PackedMatMulEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<PackedSoftMaxEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<PackedTransposeEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<PadEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<PtrOfEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<ReshapeEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<SliceEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<SramPtrEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<SwishEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<TensorLoadEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<TensorStoreEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<TransposeEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<UnaryEvaluator>(reuse: Reuse.Singleton);
+        registrator.RegisterManyInterface<UnpackEvaluator>(reuse: Reuse.Singleton);
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Concat.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Concat.cs
new file mode 100644
index 0000000000..bb173f3c71
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Concat.cs
@@ -0,0 +1,18 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
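+// Note: as with Binary above and the rest of the TIR.CPU kernels registered in
+// CPUModule, Concat infers to TupleType.Void: at TIR level the op writes
+// through its pre-allocated Output buffer rather than producing a value, so
+// type inference only validates the argument buffers.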
+
+using Nncase.Evaluator;
+using Nncase.IR;
+using Nncase.TIR.CPU;
+
+namespace Nncase.Evaluator.TIR.CPU;
+
+public sealed class ConcatEvaluator : ITypeInferencer<Concat>
+{
+    public IRType Visit(ITypeInferenceContext context, Concat target)
+    {
+        context.CheckArgumentType(target, Concat.Input);
+        context.CheckArgumentType(target, Concat.Output);
+        return TupleType.Void;
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Gather.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Gather.cs
new file mode 100644
index 0000000000..0c2fbc4b0e
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Gather.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR;
+using Nncase.TIR.CPU;
+
+namespace Nncase.Evaluator.TIR.CPU;
+
+public sealed class GatherEvaluator : ITypeInferencer<Gather>
+{
+    public IRType Visit(ITypeInferenceContext context, Gather target)
+    {
+        return TupleType.Void;
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Matmul.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Matmul.cs
new file mode 100644
index 0000000000..6ad2912cfb
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Matmul.cs
@@ -0,0 +1,12 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR;
+using Nncase.TIR.CPU;
+
+namespace Nncase.Evaluator.TIR.CPU;
+
+public sealed class MatmulEvaluator : ITypeInferencer<Matmul>
+{
+    public IRType Visit(ITypeInferenceContext context, Matmul target) => TupleType.Void;
+}
diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Memcopy.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Memcopy.cs
new file mode 100644
index 0000000000..e88830a734
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Memcopy.cs
@@ -0,0 +1,17 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR;
+using Nncase.TIR.CPU;
+
+namespace Nncase.Evaluator.TIR.CPU;
+
+public class MemcopyEvaluator : ITypeInferencer<Memcopy>
+{
+    public IRType Visit(ITypeInferenceContext context, Memcopy target)
+    {
+        _ = context.CheckArgumentType(target, Memcopy.Dest);
+        _ = context.CheckArgumentType(target, Memcopy.Src);
+        return TupleType.Void;
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Pack.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Pack.cs
new file mode 100644
index 0000000000..b85558fae5
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Pack.cs
@@ -0,0 +1,19 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.TIR.CPU; +using Nncase.Utilities; +using OrtKISharp; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class PackEvaluator : ITypeInferencer +{ + public IRType Visit(ITypeInferenceContext context, Pack target) => TupleType.Void; +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedBinary.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedBinary.cs new file mode 100644 index 0000000000..88e65c8e30 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedBinary.cs @@ -0,0 +1,20 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Numerics; +using System.Runtime.InteropServices; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.TIR.CPU; +using Nncase.Utilities; +using OrtKISharp; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class PackedBinaryEvaluator : ITypeInferencer +{ + public IRType Visit(ITypeInferenceContext context, PackedBinary target) => TupleType.Void; +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedLayerNorm.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedLayerNorm.cs new file mode 100644 index 0000000000..6d7bc11e13 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedLayerNorm.cs @@ -0,0 +1,20 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using System.Numerics; +using System.Runtime.InteropServices; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.TIR.CPU; +using Nncase.Utilities; +using OrtKISharp; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class PackedLayerNormEvaluator : ITypeInferencer +{ + /// + public IRType Visit(ITypeInferenceContext context, PackedLayerNorm target) => TupleType.Void; +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedMatMul.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedMatMul.cs new file mode 100644 index 0000000000..7410f8f21f --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedMatMul.cs @@ -0,0 +1,19 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using System.Numerics; +using System.Runtime.InteropServices; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.TIR.CPU; +using Nncase.Utilities; +using OrtKISharp; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class PackedMatMulEvaluator : ITypeInferencer +{ + public IRType Visit(ITypeInferenceContext context, PackedMatMul target) => TupleType.Void; +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedSoftMax.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedSoftMax.cs new file mode 100644 index 0000000000..0035dea489 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedSoftMax.cs @@ -0,0 +1,19 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
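+// Note: this is the TIR counterpart of the functional PackedSoftMaxEvaluator
+// earlier in this diff. The functional evaluator actually computes softmax
+// (reducing jointly over the packed axis and its lane dimension), while this
+// one, like the other TIR kernels, only type-checks and returns void.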
+ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.TIR.CPU; +using Nncase.Utilities; +using OrtKISharp; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class PackedSoftMaxEvaluator : ITypeInferencer +{ + public IRType Visit(ITypeInferenceContext context, PackedSoftmax target) => TupleType.Void; +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedTranspose.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedTranspose.cs new file mode 100644 index 0000000000..1ec6a81748 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedTranspose.cs @@ -0,0 +1,19 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.TIR.CPU; +using Nncase.Utilities; +using OrtKISharp; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class PackedTransposeEvaluator : ITypeInferencer +{ + public IRType Visit(ITypeInferenceContext context, PackedTranspose target) => TupleType.Void; +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Pad.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Pad.cs new file mode 100644 index 0000000000..9b811b7fa5 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Pad.cs @@ -0,0 +1,18 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.Evaluator; +using Nncase.IR; +using Nncase.TIR.CPU; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class PadEvaluator : ITypeInferencer +{ + public IRType Visit(ITypeInferenceContext context, Pad target) + { + context.CheckArgumentType(target, Pad.Input); + context.CheckArgumentType(target, Pad.Output); + return TupleType.Void; + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PtrOf.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PtrOf.cs new file mode 100644 index 0000000000..3508f6f931 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PtrOf.cs @@ -0,0 +1,22 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR; +using Nncase.TIR.CPU; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class PtrOfEvaluator : ITypeInferencer, IOpPrinter +{ + public IRType Visit(ITypeInferenceContext context, PtrOf target) => new PointerType(target.DataType); + + public string Visit(IIRPrinterContext context, PtrOf target, bool iLmode) + { + if (iLmode) + { + throw new NotSupportedException(); + } + + return $"PtrOf({target.PtrName})"; + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Reshape.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Reshape.cs new file mode 100644 index 0000000000..b5e11095b9 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Reshape.cs @@ -0,0 +1,18 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
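+// Note: Reshape, Slice, Swish, Transpose and Unary below all share one
+// buffer-to-buffer pattern: check the Input and Output buffer arguments, then
+// return void; no value is produced at this level.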
+ +using Nncase.Evaluator; +using Nncase.IR; +using Nncase.TIR.CPU; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class ReshapeEvaluator : ITypeInferencer +{ + public IRType Visit(ITypeInferenceContext context, Reshape target) + { + context.CheckArgumentType(target, Reshape.Input); + context.CheckArgumentType(target, Reshape.Output); + return TupleType.Void; + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Slice.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Slice.cs new file mode 100644 index 0000000000..a26491b8eb --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Slice.cs @@ -0,0 +1,18 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.Evaluator; +using Nncase.IR; +using Nncase.TIR.CPU; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class SliceEvaluator : ITypeInferencer +{ + public IRType Visit(ITypeInferenceContext context, Slice target) + { + context.CheckArgumentType(target, Slice.Input); + context.CheckArgumentType(target, Slice.Output); + return TupleType.Void; + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/SramPtr.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/SramPtr.cs new file mode 100644 index 0000000000..c9d591d2ac --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/SramPtr.cs @@ -0,0 +1,12 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR; +using Nncase.TIR.CPU; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class SramPtrEvaluator : ITypeInferencer +{ + public IRType Visit(ITypeInferenceContext context, SramPtr target) => new PointerType(target.DataType); +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Swish.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Swish.cs new file mode 100644 index 0000000000..fb8209afc5 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Swish.cs @@ -0,0 +1,18 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.Evaluator; +using Nncase.IR; +using Nncase.TIR.CPU; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class SwishEvaluator : ITypeInferencer +{ + public IRType Visit(ITypeInferenceContext context, Swish target) + { + context.CheckArgumentType(target, Swish.Input); + context.CheckArgumentType(target, Swish.Output); + return TupleType.Void; + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/TensorLoad.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/TensorLoad.cs new file mode 100644 index 0000000000..c41eacf55f --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/TensorLoad.cs @@ -0,0 +1,17 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
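+// Note: TensorLoad and TensorStore validate both Dest and Src buffer
+// arguments and infer to void; presumably the op itself encodes the transfer
+// direction (into or out of local memory), since the types carry no direction.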
+ +using Nncase.IR; +using Nncase.TIR.CPU; + +namespace Nncase.Evaluator.TIR.CPU; + +public class TensorLoadEvaluator : ITypeInferencer +{ + public IRType Visit(ITypeInferenceContext context, TensorLoad target) + { + _ = context.CheckArgumentType(target, TensorLoad.Dest); + _ = context.CheckArgumentType(target, TensorLoad.Src); + return TupleType.Void; + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/TensorStore.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/TensorStore.cs new file mode 100644 index 0000000000..742a8f1592 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/TensorStore.cs @@ -0,0 +1,17 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR; +using Nncase.TIR.CPU; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class TensorStoreEvaluator : ITypeInferencer +{ + public IRType Visit(ITypeInferenceContext context, TensorStore target) + { + _ = context.CheckArgumentType(target, TensorStore.Src); + _ = context.CheckArgumentType(target, TensorStore.Dest); + return TupleType.Void; + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Transpose.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Transpose.cs new file mode 100644 index 0000000000..c769ce19e6 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Transpose.cs @@ -0,0 +1,18 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.Evaluator; +using Nncase.IR; +using Nncase.TIR.CPU; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class TransposeEvaluator : ITypeInferencer +{ + public IRType Visit(ITypeInferenceContext context, Transpose target) + { + context.CheckArgumentType(target, Transpose.Input); + context.CheckArgumentType(target, Transpose.Output); + return TupleType.Void; + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unary.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unary.cs new file mode 100644 index 0000000000..5fd104b57f --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unary.cs @@ -0,0 +1,18 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.Evaluator; +using Nncase.IR; +using Nncase.TIR.CPU; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class UnaryEvaluator : ITypeInferencer +{ + public IRType Visit(ITypeInferenceContext context, Unary target) + { + context.CheckArgumentType(target, Unary.Input); + context.CheckArgumentType(target, Unary.Output); + return TupleType.Void; + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unpack.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unpack.cs new file mode 100644 index 0000000000..7e4d468377 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unpack.cs @@ -0,0 +1,21 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
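+// Note: contrast with the functional UnpackEvaluator earlier in this diff,
+// which evaluates the tensor and derives its type via TypeInference.UnpackType;
+// inside a prim function the TIR Unpack below only needs to return void.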
+ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using DryIoc.ImTools; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.TIR.CPU; +using Nncase.Utilities; +using OrtKISharp; + +namespace Nncase.Evaluator.TIR.CPU; + +public sealed class UnpackEvaluator : ITypeInferencer +{ + /// + public IRType Visit(ITypeInferenceContext context, Unpack target) => TupleType.Void; +} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/Boxing.cs b/modules/Nncase.Modules.CPU/IR/CPU/Boxing.cs new file mode 100644 index 0000000000..d86c10bdaf --- /dev/null +++ b/modules/Nncase.Modules.CPU/IR/CPU/Boxing.cs @@ -0,0 +1,29 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.PatternMatch; + +namespace Nncase.IR.CPU; + +/// +/// Boxing expression. +/// +[PatternFunctionalGenerator] +public sealed partial class Boxing : Op +{ + /// + /// Gets input. + /// + public static readonly ParameterInfo Input = new(typeof(Boxing), 0, "input"); + + public IRType NewType { get; } + + /// + public override string DisplayProperty() => $"{NewType}"; +} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/CPUKernelOp.cs b/modules/Nncase.Modules.CPU/IR/CPU/CPUKernelOp.cs new file mode 100644 index 0000000000..22a75beb56 --- /dev/null +++ b/modules/Nncase.Modules.CPU/IR/CPU/CPUKernelOp.cs @@ -0,0 +1,33 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.IR.Math; +using Nncase.PatternMatch; + +namespace Nncase.IR.CPU; + +public sealed class CPUKernelOp : Op +{ + private readonly ExprPinner _exprPinner; + + public CPUKernelOp(Op target) + { + _exprPinner = new(target); + Target = target; + } + + /// + /// Gets the target. + /// + public Op Target { get; } + + /// + public override IEnumerable Parameters => Target.Parameters; + + public override string DisplayProperty() => Target.GetType().Name; +} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/Functional.cs b/modules/Nncase.Modules.CPU/IR/CPU/Functional.cs new file mode 100644 index 0000000000..ebf9e8d39f --- /dev/null +++ b/modules/Nncase.Modules.CPU/IR/CPU/Functional.cs @@ -0,0 +1,90 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.IR.CPU; + +namespace Nncase.IR.F; + +public partial class CPU +{ + /// + /// Call cpu kernel. + /// + /// Unary operator. + /// Source inputs. + /// Result expression. 
+ public static Call CPUKernel(Op target, params Expr[] inputs) + { + return new Call(new CPUKernelOp(target), inputs); + } + + public static Call Boxing(Expr input, IRType type) + { + return new Call(new Boxing(type), input); + } + + public static Call Load(Expr input) + { + return new Call(new Load(), input); + } + + public static Call Store(Expr input) + { + return new Call(new Store(), input); + } + + public static Expr Pack(Expr input, int[] lanes, int[] axes) + { + if (lanes.Length != axes.Length) + { + throw new NotSupportedException(); + } + + if (axes.Length == 0) + { + return input; + } + + return new Call(new Pack(lanes, axes), input); + } + + public static Expr Unpack(Expr input, int[] axes) + { + if (axes.Length == 0) + { + return input; + } + + return new Call(new Unpack(axes), input); + } + + public static Expr PackedSoftmax(Expr input, int axis, IRArray packedAxes) + { + return new Call(new PackedSoftmax(axis, packedAxes), input); + } + + public static Expr PackedLayerNorm(Expr input, Expr scale, Expr bias, int axis, float epsilon, bool usemean, IRArray packedAxes, IRArray padedNums) + { + return new Call(new PackedLayerNorm(axis, epsilon, usemean, packedAxes, padedNums), input, scale, bias); + } + + public static Expr PackedMatMul(Expr lhs, Expr rhs, IRArray lhsPackedAxes, IRArray lhsPadedNums, IRArray rhsPackedAxes, IRArray rhsPadedNums) + { + return new Call(new PackedMatMul(lhsPackedAxes, lhsPadedNums, rhsPackedAxes, rhsPadedNums), lhs, rhs); + } + + public static Expr PackedBinary(Expr lhs, Expr rhs, BinaryOp binaryOp, IRArray lhsPackedAxes, IRArray lhsPadedNums, IRArray rhsPackedAxes, IRArray rhsPadedNums) + { + return new Call(new PackedBinary(binaryOp, lhsPackedAxes, lhsPadedNums, rhsPackedAxes, rhsPadedNums), lhs, rhs); + } + + public static Expr PackedTranspose(Expr input, Expr perm, IRArray packedAxes) + { + return new Call(new PackedTranspose(packedAxes), input, perm); + } +} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/Load.cs b/modules/Nncase.Modules.CPU/IR/CPU/Load.cs new file mode 100644 index 0000000000..f92faea4fd --- /dev/null +++ b/modules/Nncase.Modules.CPU/IR/CPU/Load.cs @@ -0,0 +1,21 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.IR.Math; +using Nncase.PatternMatch; + +namespace Nncase.IR.CPU; + +[PatternFunctionalGenerator] +public sealed partial class Load : Op +{ + /// + /// Gets input. + /// + public static readonly ParameterInfo Input = new(typeof(Load), 0, "input", ParameterKind.Input); +} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/Pack.cs b/modules/Nncase.Modules.CPU/IR/CPU/Pack.cs new file mode 100644 index 0000000000..a06a2e20ae --- /dev/null +++ b/modules/Nncase.Modules.CPU/IR/CPU/Pack.cs @@ -0,0 +1,31 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.PatternMatch; + +namespace Nncase.IR.CPU; + +/// +/// Pack expression. +/// +[PatternFunctionalGenerator] +public sealed partial class Pack : Op +{ + /// + /// Gets input. 
+ /// + public static readonly ParameterInfo Input = new(typeof(Pack), 0, "input", ParameterKind.Input); + + public IRArray Lanes { get; } + + public IRArray Axes { get; } + + /// + public override string DisplayProperty() => $"Lanes: {Lanes}, Axes: {Axes}"; +} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/PackedBinary.cs b/modules/Nncase.Modules.CPU/IR/CPU/PackedBinary.cs new file mode 100644 index 0000000000..2ff1c88654 --- /dev/null +++ b/modules/Nncase.Modules.CPU/IR/CPU/PackedBinary.cs @@ -0,0 +1,32 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.PatternMatch; + +namespace Nncase.IR.CPU; + +[PatternFunctionalGenerator] +public sealed partial class PackedBinary : PackedOp +{ + /// + /// Gets input. + /// + public static readonly ParameterInfo Lhs = new(typeof(PackedBinary), 0, "lhs", ParameterKind.Input); + + /// + /// Gets Other. + /// + public static readonly ParameterInfo Rhs = new(typeof(PackedBinary), 1, "rhs", ParameterKind.Input); + + public BinaryOp BinaryOp { get; } + + public IRArray LhsPackedAxes { get; } + + public IRArray LhsPadedNums { get; } + + public IRArray RhsPackedAxes { get; } + + public IRArray RhsPadedNums { get; } + + public override string DisplayProperty() => $"BinaryOp: {BinaryOp}, LhsPackedAxes: {LhsPackedAxes}, LhsPadedNums: {LhsPadedNums}, RhsPackedAxes: {RhsPackedAxes}, RhsPadedNums: {RhsPadedNums}"; +} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/PackedLayerNorm.cs b/modules/Nncase.Modules.CPU/IR/CPU/PackedLayerNorm.cs new file mode 100644 index 0000000000..8b5e96e577 --- /dev/null +++ b/modules/Nncase.Modules.CPU/IR/CPU/PackedLayerNorm.cs @@ -0,0 +1,37 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.PatternMatch; + +namespace Nncase.IR.CPU; + +[PatternFunctionalGenerator] +public sealed partial class PackedLayerNorm : PackedOp +{ + /// + /// Gets input. + /// + public static readonly ParameterInfo Input = new(typeof(PackedLayerNorm), 0, "input", ParameterKind.Input); + + /// + /// Gets scale. + /// + public static readonly ParameterInfo Scale = new(typeof(PackedLayerNorm), 1, "scale", ParameterKind.Input); + + /// + /// Gets bias. + /// + public static readonly ParameterInfo Bias = new(typeof(PackedLayerNorm), 2, "bias", ParameterKind.Input); + + public int Axis { get; } + + public float Epsilon { get; } + + public bool UseMean { get; } + + public IRArray PackedAxes { get; } + + public IRArray PadedNums { get; } + + public override string DisplayProperty() => $"Axis: {Axis}, Epsilon: {Epsilon}, UseMean: {UseMean}, PackedAxes: {PackedAxes}, PadedNums: {PadedNums}"; +} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/PackedMatMul.cs b/modules/Nncase.Modules.CPU/IR/CPU/PackedMatMul.cs new file mode 100644 index 0000000000..ce562d042c --- /dev/null +++ b/modules/Nncase.Modules.CPU/IR/CPU/PackedMatMul.cs @@ -0,0 +1,30 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.PatternMatch; + +namespace Nncase.IR.CPU; + +[PatternFunctionalGenerator] +public sealed partial class PackedMatMul : PackedOp +{ + /// + /// Gets input. + /// + public static readonly ParameterInfo Lhs = new(typeof(PackedMatMul), 0, "lhs", ParameterKind.Input); + + /// + /// Gets Other. 
+ /// + public static readonly ParameterInfo Rhs = new(typeof(PackedMatMul), 1, "rhs", ParameterKind.Input); + + public IRArray LhsPackedAxes { get; } + + public IRArray LhsPadedNums { get; } + + public IRArray RhsPackedAxes { get; } + + public IRArray RhsPadedNums { get; } + + public override string DisplayProperty() => $"LhsPackedAxes: {LhsPackedAxes}, LhsPadedNums: {LhsPadedNums}, RhsPackedAxes: {RhsPackedAxes}, RhsPadedNums: {RhsPadedNums}"; +} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/PackedOp.cs b/modules/Nncase.Modules.CPU/IR/CPU/PackedOp.cs new file mode 100644 index 0000000000..02f53d11ee --- /dev/null +++ b/modules/Nncase.Modules.CPU/IR/CPU/PackedOp.cs @@ -0,0 +1,16 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.IR; + +namespace Nncase.IR.CPU; + +public abstract class PackedOp : Op +{ +} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/PackedSoftMax.cs b/modules/Nncase.Modules.CPU/IR/CPU/PackedSoftMax.cs new file mode 100644 index 0000000000..18994bd010 --- /dev/null +++ b/modules/Nncase.Modules.CPU/IR/CPU/PackedSoftMax.cs @@ -0,0 +1,18 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.PatternMatch; + +namespace Nncase.IR.CPU; + +[PatternFunctionalGenerator] +public sealed partial class PackedSoftmax : PackedOp +{ + public static readonly ParameterInfo Input = new(typeof(PackedSoftmax), 0, "input", ParameterKind.Input); + + public int Axis { get; } + + public IRArray PackedAxes { get; } + + public override string DisplayProperty() => $"{Axis}, {PackedAxes}"; +} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/PackedTranspose.cs b/modules/Nncase.Modules.CPU/IR/CPU/PackedTranspose.cs new file mode 100644 index 0000000000..2acd936ddf --- /dev/null +++ b/modules/Nncase.Modules.CPU/IR/CPU/PackedTranspose.cs @@ -0,0 +1,23 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; + +namespace Nncase.IR.CPU; + +[PatternFunctionalGenerator] +public sealed partial class PackedTranspose : PackedOp +{ + /// + /// Gets input. + /// + public static readonly ParameterInfo Input = new(typeof(PackedTranspose), 0, "input", ParameterKind.Input); + + /// + /// Gets perm. + /// + public static readonly ParameterInfo Perm = new(typeof(PackedTranspose), 1, "perm", HasRank(1) & IsIntegral()); + + public IRArray PackedAxes { get; } +} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/Store.cs b/modules/Nncase.Modules.CPU/IR/CPU/Store.cs new file mode 100644 index 0000000000..aafc7a7773 --- /dev/null +++ b/modules/Nncase.Modules.CPU/IR/CPU/Store.cs @@ -0,0 +1,21 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.IR.Math; +using Nncase.PatternMatch; + +namespace Nncase.IR.CPU; + +[PatternFunctionalGenerator] +public sealed partial class Store : Op +{ + /// + /// Gets input. 
+ /// + public static readonly ParameterInfo Input = new(typeof(Store), 0, "input", ParameterKind.Input); +} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/Unpack.cs b/modules/Nncase.Modules.CPU/IR/CPU/Unpack.cs new file mode 100644 index 0000000000..f446923be1 --- /dev/null +++ b/modules/Nncase.Modules.CPU/IR/CPU/Unpack.cs @@ -0,0 +1,29 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.PatternMatch; + +namespace Nncase.IR.CPU; + +/// +/// Unpack expression. +/// +[PatternFunctionalGenerator] +public sealed partial class Unpack : Op +{ + /// + /// Gets input. + /// + public static readonly ParameterInfo Input = new(typeof(Unpack), 0, "input", ParameterKind.Input); + + public IRArray Axes { get; } + + /// + public override string DisplayProperty() => $"Axes: {Axes}"; +} diff --git a/modules/Nncase.Modules.CPU/Nncase.Modules.CPU.csproj b/modules/Nncase.Modules.CPU/Nncase.Modules.CPU.csproj new file mode 100644 index 0000000000..614d4d0318 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Nncase.Modules.CPU.csproj @@ -0,0 +1,46 @@ + + + + Nncase + enable + true + true + True + true + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Always + + + Always + + + diff --git a/modules/Nncase.Modules.CPU/Passes/BufferSchedule/BufferScheduler.cs b/modules/Nncase.Modules.CPU/Passes/BufferSchedule/BufferScheduler.cs new file mode 100644 index 0000000000..9dbbc2cb8c --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/BufferSchedule/BufferScheduler.cs @@ -0,0 +1,48 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Text.RegularExpressions; +using NetFabric.Hyperlinq; +using Nncase.IR; +using Nncase.TIR; +using Buffer = Nncase.TIR.Buffer; + +namespace Nncase.Passes.BufferSchedule; + +internal class Lifeness +{ + public Lifeness(int start, int end) + { + Start = start; + End = end; + } + + public int Start { get; set; } + + public int End { get; set; } + + public override string ToString() + { + return $"Lifeness({Start}, {End})"; + } +} + +internal class ScheduledBuffer +{ + public ScheduledBuffer(Lifeness lifeness, Buffer buffer) + { + Lifeness = lifeness; + Buffer = buffer; + } + + public Lifeness Lifeness { get; } + + public Buffer Buffer { get; } + + public string Name => Buffer.Name; + + public override string ToString() + { + return $"ScheduledBuffer(\"{Name}\", {Lifeness}, Location({Buffer.MemSpan.Start}, {Buffer.MemSpan.Size}), [{string.Join(",", Buffer.Dimensions.ToArray().Select(s => ((TensorConst)s).Value[0]))}], [{string.Join(",", Buffer.Strides.ToArray().Select(s => ((TensorConst)s).Value[0]))}])"; + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/BufferSchedule/SRAM.cs b/modules/Nncase.Modules.CPU/Passes/BufferSchedule/SRAM.cs new file mode 100644 index 0000000000..1e2e924898 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/BufferSchedule/SRAM.cs @@ -0,0 +1,11 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
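+// Sizing sketch: SramSizePerBlock is 2 * 1024 * 1024 bytes (2 MiB), and
+// SramSizePerThread splits it four ways, giving 512 KiB per thread (the
+// four-way split presumably corresponds to four threads per block). The
+// CP-SAT buffer scheduler later in this diff uses SramSizePerThread as the
+// upper bound of each buffer's start-address domain, and draw.py uses it to
+// clamp the plot's y-range.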
+ +namespace Nncase.Passes.BufferSchedule; + +public class SRAM +{ + public static int SramSizePerBlock { get; } = 2 * 1024 * 1024; + + public static int SramSizePerThread { get; } = SramSizePerBlock / 4; +} diff --git a/modules/Nncase.Modules.CPU/Passes/BufferSchedule/ScheduleResponse.cs b/modules/Nncase.Modules.CPU/Passes/BufferSchedule/ScheduleResponse.cs new file mode 100644 index 0000000000..10acc1fe8a --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/BufferSchedule/ScheduleResponse.cs @@ -0,0 +1,138 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Text; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.TIR; + +namespace Nncase.Passes.BufferSchedule; + +internal sealed class ScheduledResponse +{ + private const string _bufferTypesContents = @"from dataclasses import dataclass +from enum import Enum +from typing import List +@dataclass +class Lifeness(): + start: int + end: int + +@dataclass +class Location(): + start: int + size: int + def __str__(self) -> str: + return f'(start: {self.start}, size {self.size})' + +@dataclass +class ScheduledBuffer(): + name: str + lifeness: Lifeness + location: Location + shape: List[int] + stride: List[int] +"; + + private const string _drawContents = @"from bokeh.models import ColumnDataSource, HoverTool, FuncTickFormatter, SingleIntervalTicker, SaveTool, WheelZoomTool, WheelPanTool, ResetTool +from bokeh.palettes import Category20_20 as palette +from bokeh.plotting import figure, show +from {0} import buffers +import itertools +colors = itertools.cycle(palette) + +source = {{ + ""name"": [], + ""x"": [], + ""y"": [], + ""width"": [], + ""height"": [], + ""color"": [], + ""location"": [], + ""shape"":[], + ""stride"":[], +}} + +y_range_max = 0 +for buffer in buffers: + source[""name""].append(buffer.name) + width = buffer.lifeness.end - buffer.lifeness.start + x = buffer.lifeness.start + (width / 2) + height = buffer.location.size + y = buffer.location.start + (height / 2) + y_range_max = max(y_range_max,y) + source[""x""].append(x) + source[""y""].append(y) + source[""width""].append(width) + source[""height""].append(height) + source[""color""].append(next(colors)) + source[""location""].append(str(buffer.location)) + source[""shape""].append(','.join([str(s) for s in buffer.shape])) + source[""stride""].append(','.join([str(s) for s in buffer.stride])) + +source = ColumnDataSource(source) +hover = HoverTool(tooltips = [('name','@name'),('location','@location'), + ('shape','@shape'),('stride','@stride')]) + +p = figure(tools=[hover, WheelPanTool(), SaveTool(), WheelZoomTool(), ResetTool()], width=1280, height=720, + y_range=(0, min(y_range_max * 2,{1})), + title=""Local Buffer LifeTime (by Steps)"") +p.rect(x=""x"", y=""y"", width=""width"", height=""height"", fill_color=""color"", source=source) + +p.yaxis.ticker = SingleIntervalTicker(interval=1024, num_minor_ticks=0) +p.yaxis.formatter = FuncTickFormatter(code="""""" + return Math.floor(tick / (1024)) +"""""") +p.ygrid.grid_line_color = 'navy' +p.ygrid.grid_line_dash = [6, 4] + +p.xaxis.axis_label = ""Time (steps)"" +p.outline_line_color = None + +show(p) +"; + + private const string _schedBufferContents = @"from buffer_types import Lifeness, Location, ScheduledBuffer +# Generator Information: {0} +buffers = [ +{1} +] +"; + + private readonly IReadOnlyDictionary _bufferLifenessMap; + + public ScheduledResponse( + IReadOnlyDictionary bufferLifenessMap, + 
bool success) + { + _bufferLifenessMap = bufferLifenessMap; + Success = success; + } + + public bool Success { get; } + + public void Dump(string file_name, string generatorInformation) + { + var path = Path.Combine(DumpScope.Current.Directory, "buffer_types.py"); + if (!File.Exists(path)) + { + File.WriteAllText(path, _bufferTypesContents); + } + + path = Path.Combine(DumpScope.Current.Directory, "draw.py"); + if (!File.Exists(path)) + { + File.WriteAllText(path, string.Format(_drawContents, file_name, SRAM.SramSizePerThread)); + } + + var code = string.Format( + _schedBufferContents, + generatorInformation, + string.Join( + ",\n", + _bufferLifenessMap.Select(kv => _bufferLifenessMap[kv.Key]))); + + path = Path.Combine(DumpScope.Current.Directory, $"{file_name}.py"); + File.WriteAllText(path, code, System.Text.Encoding.UTF8); + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/BufferSchedule/SchedulerSolver.cs b/modules/Nncase.Modules.CPU/Passes/BufferSchedule/SchedulerSolver.cs new file mode 100644 index 0000000000..93af045f4f --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/BufferSchedule/SchedulerSolver.cs @@ -0,0 +1,114 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Google.OrTools.Sat; +using NetFabric.Hyperlinq; +using Nncase.IR; +using Nncase.TIR; + +namespace Nncase.Passes.BufferSchedule; + +internal static class SchedulerSolver +{ + public static bool ScheduleByCpModel( + IReadOnlyDictionary lifenessMap, + bool multiWorkers, + float timeout, + out Dictionary scheduledBuffer) + { + scheduledBuffer = new(ReferenceEqualityComparer.Instance); + bool invalidDomain = false; + var model = new CpModel(); + + var yMap = new Dictionary(ReferenceEqualityComparer.Instance); + + // 1. add lifeness overlap constraint + var lifenessNoOverlap = model.AddNoOverlap2D(); + var interval_vars = lifenessMap.Where(sched => sched.Value.Buffer.MemSpan.Location == MemoryLocation.L2Data).Select(sched => + { + var lifeness = lifenessMap[sched.Key].Lifeness; + var buffer = sched.Value.Buffer.MemSpan; + var x = model.NewIntervalVar( + model.NewConstant(lifeness.Start), + model.NewConstant(lifeness.End - lifeness.Start), + model.NewConstant(lifeness.End), + "x"); + + var y_start_domain = SRAM.SramSizePerThread - ((TensorConst)buffer.Size).Value.ToScalar(); + if (y_start_domain <= 0) + { + invalidDomain = true; + } + + var y_start = model.NewIntVar(0, y_start_domain, $"{sched.Value.Buffer.Name}_y_start"); + + var y = model.NewFixedSizeIntervalVar( + y_start, + ((TensorConst)buffer.Size).Value.ToScalar(), + "y"); + + yMap.Add(sched.Value.Buffer, (y, y_start)); + + lifenessNoOverlap.AddRectangle(x, y); + return (x, y); + }).ToList(); + + if (invalidDomain) + { + return false; + } + + var solver = new CpSolver(); + var workers = multiWorkers ? 
'0' : '1'; + solver.StringParameters = $"max_time_in_seconds:{timeout},num_workers:{workers}"; + + var callback = new EarlyStopCallback(3); + CpSolverStatus solve_status = solver.Solve(model, callback); + + if (solve_status == CpSolverStatus.Unknown) + { + return false; + } + + if (solve_status == CpSolverStatus.ModelInvalid) + { + throw new InvalidDataException(model.Validate()); + } + + if (solve_status != CpSolverStatus.Optimal && solve_status != CpSolverStatus.Feasible) + { + return false; + } + + foreach (var (expr, vars) in lifenessMap.Where(sched => sched.Value.Buffer.MemSpan.Location == MemoryLocation.L2Data).Select(kv => kv.Key).Zip(interval_vars)) + { + var buffer = lifenessMap[expr].Buffer; + var start = TIR.F.CPU.SramPtr(solver.Value(vars.y.StartExpr()), buffer.ElemType); + var schedBuffer = buffer.With(memSpan: buffer.MemSpan.With(start: start)); + scheduledBuffer.Add(expr, new ScheduledBuffer(lifenessMap[expr].Lifeness, schedBuffer)); + } + + return true; + } +} + +internal sealed class EarlyStopCallback : CpSolverSolutionCallback +{ + private readonly int _solutionLimit; + + private int _solutionCount; + + public EarlyStopCallback(int limit) + { + _solutionLimit = limit; + } + + public override void OnSolutionCallback() + { + _solutionCount++; + if (_solutionCount > _solutionLimit) + { + StopSearch(); + } + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/CPUFusionToModulePass.cs b/modules/Nncase.Modules.CPU/Passes/CPUFusionToModulePass.cs new file mode 100644 index 0000000000..ec5267cdd8 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/CPUFusionToModulePass.cs @@ -0,0 +1,32 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.Passes.Analysis; +using Nncase.Passes.Mutators; +using Nncase.Passes.Tile; +using Nncase.Targets; +using Nncase.TIR; + +namespace Nncase.Passes; + +internal sealed class CPUFusionToModulePass : ModulePass +{ + /// + protected override Task RunCoreAsync(IRModule module, RunPassContext options) + { + foreach (var item in ExprCollector.Collect(module.Entry!).OfType().Where(f => f.ModuleKind == CPUTarget.Kind)) + { + module.Add(item); + } + + return Task.FromResult(module); + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs b/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs new file mode 100644 index 0000000000..3f279aeafa --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs @@ -0,0 +1,78 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
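+// Note: this pass rewrites each Fusion of module kind CPU whose name ends in
+// "kernel" into a TIR PrimFunction via KernelToTIRVisitor, swaps the fusion
+// for a PrimFunctionWrapper in the module, then appends the generated prim
+// and device functions.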
+        foreach (var item in ExprCollector.Collect(module.Entry!).OfType().Where(f => f.ModuleKind == CPUTarget.Kind))
+        {
+            module.Add(item);
+        }
+
+        return Task.FromResult(module);
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs b/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs
new file mode 100644
index 0000000000..3f279aeafa
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs
@@ -0,0 +1,78 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.Extensions.DependencyInjection;
+using Nncase.Diagnostics;
+using Nncase.IR;
+using Nncase.Passes.Analysis;
+using Nncase.Passes.Mutators;
+using Nncase.Passes.Tile;
+using Nncase.Targets;
+using Nncase.TIR;
+
+namespace Nncase.Passes;
+
+internal sealed class CPUFusionToTirPass : ModulePass
+{
+    private IAnalyzerManager AnalyzerManager => CompileSession.GetRequiredService();
+
+    /// <inheritdoc/>
+    protected override Task RunCoreAsync(IRModule module, RunPassContext options)
+    {
+        HashSet kernelFuncs = new(ReferenceEqualityComparer.Instance);
+        HashSet deviceFuncs = new(ReferenceEqualityComparer.Instance);
+
+        for (int i = 0; i < module.Functions.Count; i++)
+        {
+            if (module.Functions[i] is Fusion { ModuleKind: CPUTarget.Kind } fusion && fusion.Name.EndsWith("kernel"))
+            {
+                // var analysis = new Dictionary
+                // {
+                //     [typeof(IExprUserAnalysisResult)] = AnalyzerManager.GetAnaylsis(module.Functions[i]),
+                // };
+                // var rewriter = new DataFlowMergeRewriter();
+                var fusionCheckCache = new Dictionary(ReferenceEqualityComparer.Instance);
+
+                // var post = (Fusion)rewriter.Rewrite(
+                //     fusion,
+                //     new IMergeRewriteRule[] {
+                //         new CPUSameInputFusionMergeRule(),
+                //         new CPUMultiInputFusionMergeRule(),
+                //     },
+                //     (rule, option) => new CPUFusionGroupMutator(fusionCheckCache, rule, option),
+                //     new() { AnalysisResults = analysis, MatchOptions = new FusionGroupMutator.GroupedMatchOptions() });
+                // if (DumpScope.Current.IsEnabled(DumpFlags.PassIR))
+                // {
+                //     DumpScope.Current.DumpIR(post, string.Empty, "L2Tiled");
+                // }
+                var post = fusion;
+                var primBody = new List();
+                var visitor = new KernelToTIRVisitor(primBody, deviceFuncs, fusionCheckCache);
+                visitor.Convert(post);
+                var primFunc = T.PrimFunc(post.Name, post.ModuleKind, visitor.InputBuffers.Concat(visitor.OutputBuffers).ToArray()).Body(primBody.ToArray()).Build();
+                primFunc.SchedResult.DataUsage = visitor.DataUsage;
+                primFunc.SchedResult.DataAlign = visitor.MaxDTypeSize;
+                var primWrapper = new PrimFunctionWrapper(primFunc, visitor.InputBuffers.Count());
+                module.Replace(i, primWrapper);
+                kernelFuncs.Add(primWrapper);
+            }
+        }
+
+        foreach (var item in kernelFuncs)
+        {
+            module.Add(item.Target);
+        }
+
+        foreach (var item in deviceFuncs)
+        {
+            module.Add(item);
+        }
+
+        return Task.FromResult(module);
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerUnary.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerUnary.cs
new file mode 100644
index 0000000000..f07ce16f14
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerUnary.cs
@@ -0,0 +1,39 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+using Nncase.IR.Affine;
+using Nncase.IR.Math;
+using Nncase.PatternMatch;
+using Nncase.Targets;
+using static Nncase.IR.F.CPU;
+using static Nncase.IR.TypePatternUtility;
+using static Nncase.PatternMatch.F.Math;
+using static Nncase.PatternMatch.Utility;
+
+namespace Nncase.Passes.Rules.CPU.Affine;
+
+[RuleGenerator]
+public partial class LowerUnary : RewriteRule
+{
+    /// <inheritdoc/>
+    public override Pattern Pattern { get; } = IsUnary(
+        target_name: "unary",
+        _ => true,
+        IsWildcard("input") with { TypePattern = HasFixedShape() });
+
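+    // The replacement wires the op into an affine grid: both the read and the
+    // write use AffineMap.Identity(rank), i.e. tile element i maps to tensor
+    // element i, which is all a rank-preserving elementwise op needs.
+    // Non-elementwise lowerings would supply non-identity maps here.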
+    private Expr GetReplace(Unary unary, Expr input)
+    {
+        var rank = input.CheckedShape.Rank;
+        return IR.F.Affine.Grid(CPUTarget.Kind)
+            .Read(input, AffineMap.Identity(rank), out var inTile)
+            .Write(TIR.T.CreateBuffer(input.CheckedTensorType, TIR.MemoryLocation.Data, out _), AffineMap.Identity(rank), out var outTile)
+            .Body(TIR.F.CPU.Unary(unary.UnaryOp, inTile, outTile))
+            .Build();
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/AutoDistributed.cs
new file mode 100644
index 0000000000..c6df852fa0
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/AutoDistributed.cs
@@ -0,0 +1,348 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System.Reactive;
+using System.Runtime.CompilerServices;
+using NetFabric.Hyperlinq;
+using Nncase.CodeGen;
+using Nncase.IR;
+using Nncase.IR.CPU;
+using Nncase.IR.Tensors;
+using Nncase.PatternMatch;
+using Nncase.Targets;
+using static Nncase.PatternMatch.Utility;
+
+[assembly: InternalsVisibleTo("Nncase.Tests")]
+
+namespace Nncase.Passes.Rules;
+
+/// <summary>
+/// Automatically distribute the body of a CPU fusion across the device hierarchy.
+/// </summary>
+[RuleGenerator]
+public sealed partial class AutoDistributed : IRewriteRule
+{
+    private readonly CompileOptions _compileOptions;
+
+    public AutoDistributed(CompileOptions compileOptions)
+    {
+        _compileOptions = compileOptions;
+    }
+
+    public IPattern Pattern { get; } = IsCallWildcard("call", IsFusion("fusion", CPUTarget.Kind, IsWildcard("body"), IsVArgsRepeat("parameters", () => IsVar())));
+
+    private Expr? GetReplace(Call call, Fusion fusion, Expr body, IReadOnlyList parameters, IReadOnlyList callParams)
+    {
+        // 1. convert to distributed graph
+        if (body is Call { Target: Boxing } || (body is IR.Tuple tp && tp.Fields.AsValueEnumerable().Any(e => e is Call { Target: Boxing })))
+        {
+            return null;
+        }
+
+        var distConverter = new AutoDistributedConvertVisitor(_compileOptions.TargetCompileOptions is CPUCompileOptions options ? options : CPUCompileOptions.Default);
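+        // The visitor enumerates, for every expression, a bucket of type-equivalent
+        // distributed rewrites (one per candidate SBP layout), unions each bucket
+        // into an e-graph, and extracts the cheapest equivalent as the new body.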
+        var newbody = distConverter.Convert(body);
+        var newFusion = fusion.With(moduleKind: CPUTarget.Kind, body: newbody, parameters: parameters.Cast().ToArray());
+        return new Call(newFusion, callParams.ToArray());
+    }
+}
+
+internal sealed class AutoDistributedConvertVisitor : ExprVisitor>, Unit>
+{
+    public AutoDistributedConvertVisitor(CPUCompileOptions compileOptions)
+    {
+        Placement = new Placement(compileOptions.Hierarchy, compileOptions.HierarchyNames);
+        CompileOptions = compileOptions;
+    }
+
+    public Placement Placement { get; }
+
+    public CPUCompileOptions CompileOptions { get; }
+
+    public static IReadOnlyList GetLeafCandidateBoxings(Expr expr, Placement placement)
+    {
+        return Utilities.DistributedUtility.GetLeafCandidateNDSBPs((TensorType)expr.CheckedType, placement).
+            Select(ndsbp => IR.F.CPU.Boxing(expr, new DistributedType((TensorType)expr.CheckedType, ndsbp, placement))).
+            ToArray();
+    }
+
+    public Expr Convert(Expr body)
+    {
+        var createFinalBoxing = (Expr e, TensorType type) =>
+        {
+            var d = (DistributedType)e.CheckedType;
+            if (d.NdSBP.Any(s => s is SBPPartialSum))
+            {
+                var boxingP2B = IR.F.CPU.Boxing(e, new DistributedType(type, d.NdSBP.Select(s => s is SBPPartialSum ? SBP.B : s).ToArray(), Placement));
+                return IR.F.CPU.Boxing(boxingP2B, type);
+            }
+
+            return IR.F.CPU.Boxing(e, type);
+        };
+
+        var equivalents = Visit(body).Select(g => g.Value[0] switch
+        {
+            IR.Tuple tp => new IR.Tuple(tp.Fields.ToArray().Select((f, i) => createFinalBoxing(f, (TensorType)((IR.Tuple)body).Fields[i].CheckedType)).ToArray()),
+            Expr e => (Expr)createFinalBoxing(e, (TensorType)body.CheckedType),
+        }).ToArray();
+        using (new ExprPinner(equivalents))
+        {
+            BranchCut();
+        }
+
+        var graph = new EGraph();
+        foreach (var (exprKey, buckets) in ExprMemo.Where(kv => kv.Key is not Op))
+        {
+            foreach (var (typeKey, bucket) in buckets.Where(kv => kv.Value.Any()))
+            {
+                Unions(graph, bucket);
+            }
+        }
+
+        var root = Unions(graph, equivalents);
+        return graph.Extract(root, null);
+    }
+
+    protected override Dictionary> DefaultVisitLeaf(Expr expr)
+    {
+        return new();
+    }
+
+    protected override Dictionary> VisitLeafTuple(IR.Tuple expr)
+    {
+        return expr.Fields.ToArray().
+            Select(Visit).
+            CartesianProduct().
+            Select(e => new IR.Tuple(e.Select(e => e.Value[0]).ToArray())).
+            GroupBy(tp => tp.CheckedType).
+            ToDictionary(g => g.Key, g => g.ToList());
+    }
+
+    protected override Dictionary> VisitLeafCall(Call expr)
+    {
+        if (expr.Target is not Op op)
+        {
+            throw new NotSupportedException("auto distribution only supports Op calls");
+        }
+
+        foreach (var param in op.Parameters)
+        {
+            VisitLeafArgument(param.ParameterKind, expr.Arguments[param.Index]);
+        }
+
+        var results = expr.Arguments.ToArray().
+            Select(Visit).
+            CartesianProduct().
+            Select(args => args.ToArray()).
+            Select(args => BuildEquivalCalls(op, args.Select(kv => kv.Value[0]).ToArray()).ToArray()).
+            SelectMany(i => i).
+            GroupBy(c => c.CheckedType).
+            ToDictionary(g => g.Key, g => new List(g.ToList()));
+
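+        // Fallback: when no candidate SBP combination type-checks, rewrite every
+        // distributed argument to fully broadcast, which is always legal (if
+        // wasteful), so the search never dead-ends on an unsupported operator.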
+        if (results.Count == 0)
+        {
+            return expr.Arguments.ToArray().
+                Select(Visit).
+                CartesianProduct().
+                Select(args => args.ToArray()).
+                Select(args => new[] { new Call(op, args.Select(kv => kv.Value[0]).Select(arg => arg.CheckedType switch
+                {
+                    DistributedType d => d.NdSBP.All(sbp => sbp is SBPBroadCast) ? arg : IR.F.CPU.Boxing(arg, d with { NdSBP = new(Enumerable.Repeat(SBP.B, d.NdSBP.Count)) }),
+                    _ => arg,
+                }).ToArray()), }).
+                SelectMany(i => i).
+                GroupBy(c => c.CheckedType).
+                ToDictionary(g => g.Key, g => new List(g.ToList()));
+        }
+
+        return results;
+    }
+
+    private Dictionary> VisitLeafArgument(ParameterKind parameterKind, Expr expr)
+    {
+        var updateBuckets = (Dictionary> buckets, IEnumerable equivalents) =>
+        {
+            foreach (var eq in equivalents)
+            {
+                if (!buckets.TryGetValue(eq.CheckedType, out var bucket))
+                {
+                    bucket = new();
+                    buckets.Add(eq.CheckedType, bucket);
+                }
+
+                bucket.Add(eq);
+            }
+        };
+
+        var buckets = ExprMemo[expr];
+        if (!buckets.Any())
+        {
+            switch (parameterKind, expr)
+            {
+                case (ParameterKind.Input, Expr e) when e is Const or Var:
+                    updateBuckets(buckets, GetLeafCandidateBoxings(e, Placement));
+                    break;
+                case (ParameterKind.Input, Expr e) when e is IR.Tuple tp:
+                    foreach (var f in tp.Fields)
+                    {
+                        VisitLeafArgument(parameterKind, f);
+                    }
+
+                    foreach (var (k, v) in VisitLeafTuple(tp))
+                    {
+                        buckets.Add(k, v);
+                    }
+
+                    break;
+                case (ParameterKind.Attribute, Var e):
+                    updateBuckets(buckets, new[] { e });
+                    break;
+                case (ParameterKind.Attribute, TensorConst e):
+                    updateBuckets(buckets, new[] { e.With() }); // remove all old users.
+                    break;
+                case (ParameterKind.Attribute, None e):
+                    updateBuckets(buckets, new[] { e.With() });
+                    break;
+                default:
+                    throw new InvalidOperationException();
+            }
+        }
+
+        return buckets;
+    }
+
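+    // For a call whose inputs already carry distributed types, this emits the
+    // call itself plus one boxing per legal partial-sum re-layout, letting the
+    // later extraction trade communication against recomputation.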
+    private IEnumerable BuildEquivalCalls(Op target, Expr[] args)
+    {
+        if (!target.Parameters.Where(p => p.ParameterKind == ParameterKind.Input).All(p => IsDistributed(args[p.Index].CheckedType)))
+        {
+            throw new ArgumentException("some arguments have no distributed type.", nameof(args));
+        }
+
+        var calls = new List();
+        var call = new Call(target, args);
+        var valid = call.InferenceType();
+        if (!valid)
+        {
+            // 1. dispose current call
+            using var pinner = new ExprPinner(args);
+            call.Dispose();
+
+            if (target is CPUKernelOp { Target: Reshape } || target is Reshape)
+            {
+                // reshape needs a forced boxing.
+                var newShape = ((TensorConst)args[1]).Value.ToArray();
+                var inType = (DistributedType)args[0].CheckedType;
+                var tensorType = inType.TensorType with { Shape = newShape };
+                foreach (var boxing in Utilities.DistributedUtility.GetLeafCandidateNDSBPs(tensorType, inType.Placement).
+                    Select(ndsbp => IR.F.CPU.Boxing(args[0], new DistributedType(tensorType, ndsbp, inType.Placement))))
+                {
+                    if (boxing.CheckedType is InvalidType)
+                    {
+                        boxing.Dispose();
+                    }
+                    else
+                    {
+                        calls.Add(boxing);
+                    }
+                }
+            }
+            else
+            {
+                // todo expand search space.
+                // calls.AddRange(Utilities.DistributedUtility.GetLeafCandidateNDSBPs(tensorType, inType.Placement).
+                //     Select(ndsbp => IR.F.CPU.Boxing(args[0], new DistributedType(tensorType, ndsbp, inType.Placement))));
+            }
+        }
+        else
+        {
+            calls.Add(call);
+            if (call.CheckedType is DistributedType distributedType)
+            {
+                calls.AddRange(Utilities.DistributedUtility.GetPartialCandidateNDSBPs(distributedType).
+                    Select(ndsbp => IR.F.CPU.Boxing(call, distributedType with { NdSBP = ndsbp })));
+            }
+        }
+
+        return calls;
+    }
+
+    private IReadOnlyList GetReBoxings(Expr expr)
+    {
+        if (expr is IR.Tuple tuple)
+        {
+            var candidates = tuple.Fields.ToArray().
+                Select(GetReBoxings).
+                CartesianProduct();
+            return candidates.Any() ? candidates.
+                Select(fs => new IR.Tuple(fs.ToArray())).
+                ToArray() : Array.Empty();
+        }
+
+        var type = (DistributedType)expr.CheckedType;
+        var tensorType = type.TensorType;
+        var candidateNdsbps = new List[type.Placement.Rank];
+        for (int i = 0; i < type.Placement.Rank; i++)
+        {
+            candidateNdsbps[i] = new List { SBP.B };
+            for (int axis = 0; axis < tensorType.Shape.Rank; axis++)
+            {
+                if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && Utilities.DistributedUtility.IsDivideExactly(s, type.Placement.Hierarchy[i]))
+                {
+                    candidateNdsbps[i].Add(SBP.S(axis));
+                }
+            }
+        }
+
+        return candidateNdsbps.CartesianProduct().
+            Select(ndsbp => new IRArray(ndsbp)).
+            Where(ndsbp => ndsbp != type.NdSBP).
+            Select(ndsbp => new DistributedType(tensorType, new IRArray(ndsbp), type.Placement)).
+            Select(disttype => IR.F.CPU.Boxing(expr, disttype)).ToArray();
+    }
+
+    private bool IsDistributed(IRType type) => type switch
+    {
+        DistributedType => true,
+        TupleType t => t.All(IsDistributed),
+        _ => false,
+    };
+
+    private EClass Unions(EGraph graph, IEnumerable equivalents)
+    {
+        var eids = equivalents.Select(graph.Add).ToArray();
+        foreach (var cls in eids.Skip(1))
+        {
+            graph.Union(eids[0], cls);
+        }
+
+        graph.Rebuild();
+        return eids[0];
+    }
+
+    private void BranchCut()
+    {
+        bool changed = true;
+        while (changed)
+        {
+            changed = false;
+            foreach (var (_, buckets) in ExprMemo)
+            {
+                foreach (var (_, bucket) in buckets.Where(kv => kv.Value.Any()))
+                {
+                    if (!bucket[0].Users.Any())
+                    {
+                        foreach (var item in bucket)
+                        {
+                            using (new ExprPinner(item.Operands.ToArray()))
+                            {
+                                item.Dispose();
+                            }
+                        }
+
+                        bucket.Clear();
+                        changed = true;
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/AutoPacking.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/AutoPacking.cs
new file mode 100644
index 0000000000..cf22dd34fa
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/AutoPacking.cs
@@ -0,0 +1,64 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System.Reactive;
+using System.Runtime.CompilerServices;
+using NetFabric.Hyperlinq;
+using Nncase.CodeGen;
+using Nncase.IR;
+using Nncase.IR.CPU;
+using Nncase.IR.Tensors;
+using Nncase.PatternMatch;
+using Nncase.Targets;
+using static Nncase.PatternMatch.Utility;
+
+[assembly: InternalsVisibleTo("Nncase.Tests")]
+
+namespace Nncase.Passes.Rules;
+
+/// <summary>
+/// Automatically pack (vectorize) the supported ops inside a CPU fusion.
+/// </summary>
+[RuleGenerator]
+public sealed partial class AutoPacking : IRewriteRule
+{
+    public IPattern Pattern { get; } = IsCallWildcard("call", IsFusion("fusion", CPUTarget.Kind, IsWildcard("body"), IsVArgsRepeat("parameters", () => IsVar())));
+
+    private Expr? GetReplace(Call call, Fusion fusion, Expr body, IReadOnlyList parameters, IReadOnlyList callParams)
+    {
+        // skip fusions that were already packed.
+        if (fusion.Metadata is PackMetaData)
+        {
+            return null;
+        }
+
+        var rank = 1;
+        var lane = System.Runtime.Intrinsics.Vector256.IsHardwareAccelerated ? 8 : 4;
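+        // The lane count is keyed off the host ISA: eight 32-bit lanes fill a
+        // 256-bit vector (e.g. AVX2) when Vector256 is hardware accelerated,
+        // otherwise four lanes target a 128-bit SSE/NEON vector; Rank = 1 keeps
+        // the candidate search to single-axis packing.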
+        var newbody = CompilerServices.ERewrite(
+            body,
+            new IRewriteRule[] {
+                new Passes.Rules.CPU.PackSoftmax() { Rank = rank, Lane = lane },
+                new Passes.Rules.CPU.PackSwish() { Rank = rank, Lane = lane },
+                new Passes.Rules.CPU.PackLayerNorm() { Rank = rank, Lane = lane },
+                new Passes.Rules.CPU.PackMatMul() { Rank = rank, Lane = lane },
+                new Passes.Rules.CPU.PackUnary() { Rank = rank, Lane = lane },
+                new Passes.Rules.CPU.PackBinary() { Rank = rank, Lane = lane },
+                new Passes.Rules.CPU.PackTranspose() { Rank = rank, Lane = lane },
+                new Passes.Rules.CPU.PackUnsqueeze() { Rank = rank, Lane = lane },
+                new Passes.Rules.CPU.PackReshape() { Rank = rank, Lane = lane },
+                new Passes.Rules.CPU.PackSlice() { Rank = rank, Lane = lane },
+                new Passes.Rules.Neutral.FoldConstCall(),
+                new Passes.Rules.CPU.FoldPackUnpack(),
+                new Passes.Rules.CPU.FoldPackConcatUnpack(),
+            },
+            new());
+
+        var newFusion = fusion.With(moduleKind: CPUTarget.Kind, body: newbody, parameters: parameters.Cast().ToArray());
+        newFusion.Metadata = new PackMetaData();
+        return new Call(newFusion, callParams.ToArray());
+    }
+
+    private sealed class PackMetaData : IR.IRMetadata
+    {
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldBoxingConst.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldBoxingConst.cs
new file mode 100644
index 0000000000..f360987b06
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldBoxingConst.cs
@@ -0,0 +1,34 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+using Nncase.IR.CPU;
+using Nncase.PatternMatch;
+using static Nncase.IR.F.NN;
+
+using static Nncase.IR.TypePatternUtility;
+using static Nncase.PatternMatch.F.CPU;
+using static Nncase.PatternMatch.Utility;
+
+namespace Nncase.Passes.Rules;
+
+[RuleGenerator]
+public partial class FoldBoxingConst : RewriteRule
+{
+    /// <inheritdoc/>
+    public override Pattern Pattern { get; } = IsBoxing(
+        target_name: "boxing",
+        _ => true,
+        IsTensorConst("input"));
+
+    private Expr? GetReplace(Boxing boxing, Tensor input)
+    {
+        var type = (DistributedType)boxing.NewType;
+        return new TensorConst(input, type.NdSBP, type.Placement);
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldStoreLoad.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldStoreLoad.cs
new file mode 100644
index 0000000000..ff3fb87174
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldStoreLoad.cs
@@ -0,0 +1,25 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR;
+using Nncase.PatternMatch;
+using static Nncase.PatternMatch.F.CPU;
+using static Nncase.PatternMatch.Utility;
+
+namespace Nncase.Passes.Rules.CPU;
+
+[RuleGenerator]
+public sealed partial class FoldStoreLoad : IRewriteRule
+{
+    public IPattern Pattern { get; } =
+        IsLoad(
+            _ => true,
+            IsStore(
+                _ => true,
+                input: IsWildcard("input")));
+
+    public Expr? GetReplace(Expr input)
+    {
+        return input;
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FusionMerger.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FusionMerger.cs
new file mode 100644
index 0000000000..58cb05506f
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FusionMerger.cs
@@ -0,0 +1,79 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net.Http.Headers;
+using System.Reactive;
+using System.Text;
+using System.Threading.Tasks;
+using DryIoc.ImTools;
+using Google.OrTools.LinearSolver;
+using Nncase.IR;
+using Nncase.IR.Math;
+using Nncase.IR.NN;
+using Nncase.IR.Tensors;
+using Nncase.Passes.Rules.Neutral;
+using Nncase.PatternMatch;
+using Nncase.Targets;
+using static Nncase.PatternMatch.F.Math;
+using static Nncase.PatternMatch.Utility;
+using static Nncase.Utilities.ReplaceUtility;
+
+namespace Nncase.Passes.Rules;
+
+/// <summary>
+/// Clones a matched subgraph while remapping selected calls and vars to new fusion
+/// parameters; used by the fusion-maker rules (e.g. FuseMHA2) to build fusion bodies.
+/// </summary>
+public sealed class FusionMerger : ExprCloner
+{
+    private readonly IReadOnlyDictionary _multiVarMap;
+
+    public FusionMerger(IReadOnlyDictionary multiVarMap)
+    {
+        _multiVarMap = multiVarMap;
+    }
+
+    protected override Expr VisitCall(Call expr, Unit context)
+    {
+        if (_multiVarMap.TryGetValue(expr, out var newVar))
+        {
+            return newVar;
+        }
+
+        return base.VisitCall(expr, context);
+    }
+
+    protected override Expr VisitLeafCall(Call expr, Unit context)
+    {
+        var target = Clone(expr.Target, context);
+        var arguments = CloneArray(expr.Arguments, context);
+        if (target is Binary || target is Where)
+        {
+            arguments = arguments.Select(e => e switch { TensorConst { Value: Tensor { Shape.IsScalar: true } } tc => Const.FromTensor(Tensor.FromBytes(tc.CheckedDataType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })), _ => e }).ToArray();
        }
+
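+        // Conv2D arrives here with a fused bias and clamp; split it into a
+        // zero-bias conv, an explicit per-channel Add, and a Clamp so the pieces
+        // can be treated as separate elementwise stages downstream.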
+        if (target is Conv2D conv)
+        {
+            var bias = (TensorConst)arguments[2];
+            var fusedClamp = ((TensorConst)arguments[7]).Value.ToArray();
+            var newConv = IR.F.NN.Conv2D(arguments[0], arguments[1], Tensor.Zeros(bias.CheckedShape), arguments[3], arguments[4], arguments[5], conv.PadMode, arguments[6], new[] { float.NegativeInfinity, float.PositiveInfinity });
+            var newBias = IR.F.Math.Add(newConv, Tensor.FromBytes(bias.CheckedDataType, bias.Value.BytesBuffer.ToArray(), new[] { bias.CheckedShape[0].FixedValue, 1, 1 }));
+            var newClamp = IR.F.Math.Clamp(newBias, fusedClamp[0], fusedClamp[1]);
+            return newClamp;
+        }
+
+        return expr.With(target: target, arguments: arguments);
+    }
+
+    protected override Expr VisitLeafVar(Var expr, Unit context)
+    {
+        if (_multiVarMap.TryGetValue(expr, out var newVar))
+        {
+            return newVar;
+        }
+
+        throw new InvalidOperationException();
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/LowerBinary.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/LowerBinary.cs
new file mode 100644
index 0000000000..3c56b82258
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/LowerBinary.cs
@@ -0,0 +1,34 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+using Nncase.IR.Math;
+using Nncase.PatternMatch;
+
+using static Nncase.IR.F.CPU;
+using static Nncase.IR.TypePatternUtility;
+using static Nncase.PatternMatch.F.Math;
+using static Nncase.PatternMatch.Utility;
+
+namespace Nncase.Passes.Rules.CPU;
+
+[RuleGenerator]
+public partial class LowerBinary : RewriteRule
+{
+    /// <inheritdoc/>
+    public override Pattern Pattern { get; } = IsBinary(
+        target_name: "binary",
+        _ => true,
+        IsWildcard("lhs") with { TypePattern = IsFloat() & HasFixedShape() },
+        IsWildcard("rhs") with { TypePattern = IsFloat() & HasFixedShape() });
+
+    private Expr? GetReplace(Binary binary, Expr lhs, Expr rhs)
+    {
+        return CPUKernel(binary, lhs, rhs);
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/MakeFusion.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/MakeFusion.cs
new file mode 100644
index 0000000000..83c187a199
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/MakeFusion.cs
@@ -0,0 +1,339 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+using Nncase.IR.CPU;
+using Nncase.IR.Math;
+using Nncase.IR.NN;
+using Nncase.IR.Tensors;
+using Nncase.Passes.Rules.Neutral;
+using Nncase.PatternMatch;
+using Nncase.Targets;
+
+using static Nncase.IR.TypePatternUtility;
+using static Nncase.PatternMatch.F.Math;
+using static Nncase.PatternMatch.F.Tensors;
+using static Nncase.PatternMatch.Utility;
+using static Nncase.Utilities.ReplaceUtility;
+
+namespace Nncase.Passes.Rules;
+
+[RuleGenerator]
+internal sealed partial class CPUDeviceFusion : FusionMaker
+{
+    public override string ModuleKind { get; } = CPUTarget.Kind;
+
+    public override Pattern Pattern => IsCallWildcard(
+        "call",
+        IsOp(
+            "op",
+            op => op is IR.Math.Unary /*or IR.Math.MatMul*/ or IR.Math.Binary));
+
+    private Call? GetReplace(Call call, Op op, IReadOnlyList callParams)
+    {
+        if (call.CheckedType is not DistributedType distributedType)
+        {
+            return null;
+        }
+
+        // Note: bail out when the distributed type cannot be divided evenly;
+        // that case is not supported yet.
+        if (!Utilities.DistributedUtility.TryGetDividedTensorType(distributedType, out _))
+        {
+            return null;
+        }
+
+        var newInputs = new List();
+        for (int i = 0; i < callParams.Count; i++)
+        {
+            if (callParams[i] is Call or Var)
+            {
+                newInputs.Add(new Var(callParams[i].CheckedType!));
+            }
+            else
+            {
+                newInputs.Add(callParams[i]);
+            }
+        }
+
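+        // Wrapping the op so that every tensor operand is Load-ed and the result
+        // is Store-d back is what turns the body into a self-contained device
+        // kernel.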
+        var newCall = IR.F.CPU.Store(new Call(op, newInputs.Select(IR.F.CPU.Load).ToArray()));
+        var callFusion = new Call(new Fusion($"{op.GetType().Name}_{Count++}_device", ModuleKind, newCall, newInputs.OfType().ToArray()), newInputs.Select((e, i) => (e, i)).Where(p => p.e is Var).Select(p => callParams[p.i]).ToArray());
+        return callFusion;
+    }
+}
+
+[RuleGenerator]
+internal sealed partial class CPUSingleKernelFusion : FusionMaker
+{
+    public override string ModuleKind { get; } = CPUTarget.Kind;
+
+    public override Pattern Pattern => IsCallWildcard(
+        "call",
+        IsOp(
+            "op",
+            op => op switch
+            {
+                IR.Math.Unary u => u.UnaryOp is UnaryOp.Abs or UnaryOp.Acos or UnaryOp.Acosh or UnaryOp.Asin or UnaryOp.Asinh or UnaryOp.Ceil or UnaryOp.Cos or UnaryOp.Cosh or UnaryOp.Exp or UnaryOp.Floor or UnaryOp.Log or UnaryOp.Neg or UnaryOp.Round or UnaryOp.Rsqrt or UnaryOp.Sign or UnaryOp.Sin or UnaryOp.Sinh or UnaryOp.Sqrt or UnaryOp.Square or UnaryOp.Tanh,
+                IR.Math.MatMul => true,
+                IR.Tensors.Gather => true,
+                IR.Math.Binary b => b.BinaryOp is BinaryOp.Add or BinaryOp.Sub or BinaryOp.Mul or BinaryOp.Div or BinaryOp.Mod or BinaryOp.Min or BinaryOp.Max or BinaryOp.Pow,
+                _ => false,
+            })) with
+    { TypePattern = TypePatternUtility.HasFixedShape() & TypePatternUtility.HasRank() };
+
+    private Call? GetReplace(Call call, Op op, IReadOnlyList callParams)
+    {
+        var newInputs = new List();
+        for (int i = 0; i < callParams.Count; i++)
+        {
+            if (callParams[i] is Call or Var or If or Marker)
+            {
+                newInputs.Add(new Var(callParams[i].CheckedType switch
+                {
+                    TensorType { IsScalar: true } t => t with { Shape = new Shape(1) },
+                    var x => x,
+                }));
+            }
+            else
+            {
+                if (callParams[i] is TensorConst { Value: Tensor { Shape.IsScalar: true } } tc)
+                {
+                    newInputs.Add(Const.FromTensor(Tensor.FromBytes(tc.CheckedDataType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })));
+                }
+                else
+                {
+                    newInputs.Add(callParams[i]);
+                }
+            }
+        }
+
+        var newCall = new Call(op, newInputs.ToArray());
+        var callFusion = new Call(new Fusion($"{op.GetType().Name}_{Count++}_kernel", ModuleKind, newCall, newInputs.OfType().ToArray()), newInputs.Select((e, i) => (e, i)).Where(p => p.e is Var).Select(p => callParams[p.i] switch
+        {
+            Expr { CheckedShape.IsScalar: true } e => IR.F.Tensors.Unsqueeze(e, new[] { 0 }),
+            var e => e,
+        }).ToArray());
+        return callFusion;
+    }
+}
+
+[RuleGenerator]
+internal sealed partial class FuseMHA2 : FusionMaker
+{
+    public override string ModuleKind { get; } = CPUTarget.Kind;
+
+    public override Pattern Pattern => CreatePattern();
+
+    private static Pattern CreatePattern()
+    {
+        var v1 = IsWildcard("hidden_in");
+
+        var v2 = IsTensorConst("v2");
+        var v3 = IsTensorConst("v3");
+        var v4 = IsCall("v4", IsOp(), IsVArgs(v1, v2, v3));
+        var v5 = IsTensorConst("v5");
+        var v6 = IsCall("v6", IsOp(), IsVArgs(v4, v5));
+        var v7 = IsTensorConst("v7");
+        var v8 = IsCall("v8", IsOp(), IsVArgs(v6, v7));
+
+        // var v9 = IsTensorConst("v9");
+        var v10 = IsWildcard("left_gather");
+
+        // var v11 = IsCall("v11", IsOp(), IsVArgs(v9, v10));
+        var v12 = IsTensorConst("v12");
+        var v13 = IsCall("v13", IsOp(), IsVArgs(v10, v12));
+        var v14 = IsCall("v14", IsOp(), IsVArgs(v8, v13));
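+        // Judging by the dataflow alone (paired strided reads of v8, a unary on
+        // one half, then a tuple fed to a single op), v15..v24 appear to transcribe
+        // the rotate-half step of rotary position embedding; the concrete op types
+        // are elided from this listing, so read the bindings as structural only.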
IsTensorConst("v15"); + var v16 = IsTensorConst("v16"); + var v17 = IsTensorConst("v17"); + var v18 = IsTensorConst("v18"); + var v19 = IsCall("v19", IsOp(), IsVArgs(v8, v15, v16, v17, v18)); + var v20 = IsCall("v20", IsOp(), IsVArgs(v19)); + var v21 = IsTensorConst("v21"); + var v22 = IsCall("v22", IsOp(), IsVArgs(v8, v21, v15, v17, v18)); + var v23 = IsTuple("v23", IsVArgs(v20, v22)); + + var v24 = IsCall("v24", IsOp(), IsVArgs(v23)); + + // var v25 = IsTensorConst("v25"); + // var v26 = IsCall("v26", IsOp(), IsVArgs(v25, v10)); + var v26 = IsWildcard("right_gather"); + var v27 = IsCall("v27", IsOp(), IsVArgs(v26, v12)); + var v28 = IsCall("v28", IsOp(), IsVArgs(v24, v27)); + var v29 = IsCall("v29", IsOp(), IsVArgs(v14, v28)); + var v30 = IsTensorConst("v30"); + var v31 = IsCall("v31", IsOp(), IsVArgs(v4, v30)); + var v32 = IsTensorConst("v32"); + var v33 = IsCall("v33", IsOp(), IsVArgs(v31, v32)); + var v34 = IsCall("v34", IsOp(), IsVArgs(v33, v13)); + var v35 = IsCall("v35", IsOp(), IsVArgs(v33, v15, v16, v17, v18)); + var v36 = IsCall("v36", IsOp(), IsVArgs(v35)); + var v37 = IsCall("v37", IsOp(), IsVArgs(v33, v21, v15, v17, v18)); + var v38 = IsTuple("v38", IsVArgs(v36, v37)); + + var v39 = IsCall("v39", IsOp(), IsVArgs(v38)); + var v40 = IsCall("v40", IsOp(), IsVArgs(v39, v27)); + var v41 = IsCall("v41", IsOp(), IsVArgs(v34, v40)); + var v42 = IsTensorConst("v42"); + var v43 = IsCall("v43", IsOp(), IsVArgs(v41, v42)); + var v44 = IsCall("v44", IsOp(), IsVArgs(v29, v43)); + var v45 = IsTensorConst("v45"); + var v46 = IsCall("v46", IsOp(), IsVArgs(v44, v45)); + var v47 = IsWildcard("attn_mask"); + + var v48 = IsCall("v48", IsOp(), IsVArgs(v46, v47)); + var v49 = IsTensorConst("v49"); + var v50 = IsCall("v50", IsOp(), IsVArgs(v48, v49)); + var v51 = IsTensorConst("v51"); + var v52 = IsCall("v52", IsOp(), IsVArgs(v4, v51)); + var v53 = IsTensorConst("v53"); + var v54 = IsCall("v54", IsOp(), IsVArgs(v52, v53)); + var v55 = IsCall("v55", IsOp(), IsVArgs(v50, v54)); + var v56 = IsTensorConst("v56"); + var v57 = IsCall("v57", IsOp(), IsVArgs(v55, v56)); + var v58 = IsTensorConst("v58"); + var v59 = IsCall("v59", IsOp(), IsVArgs(v57, v58)); + var v60 = IsTensorConst("v60"); + var v61 = IsCall("v61", IsOp(), IsVArgs(v59, v60)); + var v62 = IsCall("v62", IsOp(), IsVArgs(v1, v61)); + var v2_ = IsTensorConst("v2_"); + var v3_ = IsTensorConst("v3_"); + var v63 = IsCall("v63", IsOp(), IsVArgs(v62, v2_, v3_)); + var v64 = IsTensorConst("v64"); + var v65 = IsCall("v65", IsOp(), IsVArgs(v63, v64)); + var v66 = IsTensorConst("v66"); + var v67 = IsCall("v67", IsOp(), IsVArgs(v65, v66)); + var v68 = IsTensorConst("v68"); + var v69 = IsCall("v69", IsOp(), IsVArgs(v63, v68)); + var v70 = IsCall("v70", IsOp(), IsVArgs(v67, v69)); + var v71 = IsTensorConst("v71"); + var v72 = IsCall("v72", IsOp(), IsVArgs(v70, v71)); + var v73 = IsCall("root", IsOp(), IsVArgs(v62, v72)); + + return v73; + } + + private Call? 
+    private Call? GetReplace(Call root, Expr hidden_in, Expr left_gather, Expr right_gather, Expr attn_mask)
+    {
+        var newInputs = new List
+        {
+            new Var(hidden_in.CheckedType!),
+            new Var(left_gather.CheckedType!),
+            new Var(right_gather.CheckedType!),
+            new Var(attn_mask.CheckedType!),
+        };
+
+        var multiVarMap = new Dictionary(ReferenceEqualityComparer.Instance)
+        {
+            { hidden_in, (Var)newInputs[0] },
+            { left_gather, (Var)newInputs[1] },
+            { right_gather, (Var)newInputs[2] },
+            { attn_mask, (Var)newInputs[3] },
+        };
+        var merger = new FusionMerger(multiVarMap);
+        var clonedRoot = merger.Clone(root, default);
+
+        var callFusion = new Call(new Fusion($"MHALLaMA65B_{nameof(FuseMHA2)}_{Count++}_kernel", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), hidden_in, left_gather, right_gather, attn_mask);
+        return callFusion;
+    }
+}
+
+/// <summary>
+/// Convert QKV computation to MHA style.
+/// %9 = MatMul(%2, const(f32[768,768]))
+/// %10 = Add(BinaryOp.Add, const(f32[768]), %9)
+/// %11 = Reshape(%10, const(i32[4] : {1,77,12,64}))
+/// %12 = Transpose(%11, const(i64[4] : {0L,2L,1L,3L}))
+/// %13 = Reshape(%12, const(i32[3] : {12,77,64})).
+/// </summary>
+[RuleGenerator]
+internal sealed partial class CombineMHA : IRewriteRule
+{
+    public CombineMHA()
+    {
+        Pattern v0 = IsMatMul("mm", "mmCall", IsWildcard("x"), IsTensorConst("w"));
+
+        var bias = IsAlt(
+            IsBinary("add", "addCall", op => op.BinaryOp == BinaryOp.Add, IsTensorConst("bias"), v0),
+            IsBinary("add", "addCall", op => op.BinaryOp == BinaryOp.Add, v0, IsTensorConst("bias")),
+            v0);
+        var scale = IsAlt(
+            IsBinary("mul", "mulCall", op => op.BinaryOp == BinaryOp.Mul, bias, IsTensorConst("scale")),
+            IsBinary("mul", "mulCall", op => op.BinaryOp == BinaryOp.Mul, IsTensorConst("scale"), bias),
+            bias);
+
+        var v1 = IsReshape("rshape", "rshapeCall", scale, IsTensorConst("newShape"));
+        var v2 = IsTranspose("tp", "tpCall", v1, IsTensorConst("perm")) with { TypePattern = HasFixedShape() };
+        Pattern = v2;
+    }
+
+    public IPattern Pattern { get; }
+
+    private Expr? GetReplace(Expr x, Call mmCall, TensorConst w, TensorConst newShape, int[] perm, IMatchResult matchResult)
+    {
+        var mmOutShape = mmCall.CheckedShape.ToValueArray();
+        var wReshape = newShape.Value.ToArray().TakeLast(2).ToArray();
+
+        // TODO: add more patterns, only llama65b for now
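+        // The idea in every branch below: reshape the constant weight W[in, out]
+        // into [heads, in, headDim] (the transpose happens at compile time), so a
+        // single batched MatMul yields the head-major layout directly and the
+        // runtime Reshape + Transpose pair disappears.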
+        if (perm.Length == 4 && perm.SequenceEqual(new[] { 0, 2, 1, 3 })
+            && wReshape.Aggregate(1, (x, y) => x * y) == mmOutShape[^1]
+            && (mmOutShape.Length == 2 || (mmOutShape.Length == 3 && mmOutShape[0] == 1)))
+        {
+            var newW = IR.F.Tensors.Transpose(IR.F.Tensors.Reshape(w, new[] { -1, wReshape[0], wReshape[1] }), new[] { 1, 0, 2 });
+            var newMm = IR.F.Tensors.MatMul(IR.F.Tensors.Unsqueeze(x, new[] { 1 }), newW);
+            if (matchResult.GetValueOrDefault("bias") is TensorConst bias)
+            {
+                return null;
+            }
+
+            if (matchResult.GetValueOrDefault("scale") is TensorConst scale)
+            {
+                return null;
+            }
+
+            return newMm;
+        }
+        else if (perm.Length == 3 && perm.SequenceEqual(new[] { 1, 0, 2 })
+            && wReshape.Aggregate(1, (x, y) => x * y) == mmOutShape[^1]
+            && (mmOutShape.Length == 2 || (mmOutShape.Length == 3 && mmOutShape[0] == 1)))
+        {
+            var newW = IR.F.Tensors.Transpose(IR.F.Tensors.Reshape(w, new[] { -1, wReshape[0], wReshape[1] }), new[] { 1, 0, 2 });
+            var newMm = IR.F.Tensors.MatMul(x, newW);
+            if (matchResult.GetValueOrDefault("bias") is TensorConst bias)
+            {
+                newMm = IR.F.Math.Add(newMm, bias.Value.Shape.IsScalar ? bias : IR.F.Tensors.Reshape(bias, new[] { -1, 1, wReshape[1] }));
+            }
+
+            if (matchResult.GetValueOrDefault("scale") is TensorConst scale)
+            {
+                newMm = IR.F.Math.Mul(newMm, scale.Value.Shape.IsScalar ? scale : IR.F.Tensors.Reshape(scale, new[] { -1, 1, wReshape[1] }));
+            }
+
+            return newMm;
+        }
+        else if (perm.Length == 3 && perm.SequenceEqual(new[] { 1, 2, 0 })
+            && wReshape.Aggregate(1, (x, y) => x * y) == mmOutShape[^1]
+            && (mmOutShape.Length == 2 || (mmOutShape.Length == 3 && mmOutShape[0] == 1)))
+        {
+            var newW = IR.F.Tensors.Transpose(IR.F.Tensors.Reshape(w, new[] { -1, wReshape[0], wReshape[1] }), new[] { 1, 0, 2 });
+            var newMm = IR.F.Tensors.MatMul(x, newW);
+            if (matchResult.GetValueOrDefault("bias") is TensorConst bias)
+            {
+                newMm = IR.F.Math.Add(newMm, bias.Value.Shape.IsScalar ? bias : IR.F.Tensors.Reshape(bias, new[] { -1, 1, wReshape[1] }));
+            }
+
+            if (matchResult.GetValueOrDefault("scale") is TensorConst scale)
+            {
+                newMm = IR.F.Math.Mul(newMm, scale.Value.Shape.IsScalar ? scale : IR.F.Tensors.Reshape(scale, new[] { -1, 1, wReshape[1] }));
+            }
+
+            return IR.F.Tensors.Transpose(newMm, new[] { 0, 2, 1 });
+        }
+
+        return null;
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs
new file mode 100644
index 0000000000..584eb165ed
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs
@@ -0,0 +1,655 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+using Nncase.PatternMatch;
+using Nncase.Utilities;
+
+using static Nncase.IR.TypePatternUtility;
+using static Nncase.PatternMatch.F.Math;
+using static Nncase.PatternMatch.F.NN;
+using static Nncase.PatternMatch.F.Tensors;
+using static Nncase.PatternMatch.Utility;
+
+namespace Nncase.Passes.Rules.CPU;
+
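+// Pack rules are searched rather than applied greedily: each GetReplaceCandidates
+// returns every legal packed variant (axis choices x lane widths), and the e-graph
+// rewriter that drives them keeps whichever candidate extraction finds cheapest.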
+public abstract class PackRule : RewriteRule
+{
+    public int Lane { get; set; } = 32;
+
+    public int Rank { get; set; } = 2;
+
+    public override Expr? GetReplace(IMatchResult result, RunPassContext options) => throw new NotImplementedException();
+}
+
+public sealed class PackSoftmax : PackRule
+{
+    public override Pattern Pattern { get; } = IsSoftmax(
+        "target",
+        IsWildcard("input") with { TypePattern = IsFloat() },
+        IsWildcard("axis") with { TypePattern = IsIntegralScalar() });
+
+    public override List GetReplaceCandidates(IMatchResult result, RunPassContext context)
+    {
+        var rets = new List();
+        var input = (Expr)result["input"];
+        var axis = ((TensorConst)result["axis"]).Value.ToScalar();
+        var inShape = input.CheckedShape.ToValueArray();
+
+        void AddCandidate(int[] packedAxes, int[] lanes)
+        {
+            var packed = IR.F.CPU.Pack(PackUtility.PadForPack(input, inShape, packedAxes, lanes, float.NegativeInfinity, out var pads), lanes, packedAxes);
+            var softmax = IR.F.CPU.PackedSoftmax(packed, axis, packedAxes);
+            if (softmax.CheckedType is not InvalidType)
+            {
+                var post = PackUtility.SliceForPack(IR.F.CPU.Unpack(softmax, packedAxes), inShape, pads);
+                rets.Add(post);
+            }
+        }
+
+        for (int i = 0; i < input.CheckedShape.Count; i++)
+        {
+            AddCandidate(new[] { i }, new[] { Lane });
+            for (int j = i + 1; j < input.CheckedShape.Count; j++)
+            {
+                if (Rank > 1)
+                {
+                    AddCandidate(new[] { i, j }, new[] { Lane, Lane });
+                }
+            }
+        }
+
+        return rets;
+    }
+}
+
+public sealed class PackLayerNorm : PackRule
+{
+    public override Pattern Pattern { get; } = IsLayerNorm(
+        "target",
+        _ => true,
+        IsWildcard("input") with { TypePattern = IsFloat() },
+        IsWildcard("scale") with { TypePattern = IsFloat() },
+        IsWildcard("bias") with { TypePattern = IsFloat() });
+
+    public override List GetReplaceCandidates(IMatchResult result, RunPassContext context)
+    {
+        var rets = new List();
+        var op = (IR.NN.LayerNorm)result["target"];
+        var input = (Expr)result["input"];
+        var scale = (Expr)result["scale"];
+        var bias = (Expr)result["bias"];
+        var inShape = input.CheckedShape.ToValueArray();
+        var pshape = inShape.Skip(op.Axis).ToArray();
+
+        void AddCandidate(int[] packedAxes, int[] lanes)
+        {
+            var packedInput = IR.F.CPU.Pack(PackUtility.PadForPack(input, inShape, packedAxes, lanes, 0f, out var padsInput), lanes, packedAxes);
+
+            var pAxes = packedAxes.Where(i => i >= op.Axis).Select(i => i - op.Axis).ToArray();
+            var packedScale = PackUtility.PadForPack(scale, pshape, pAxes, lanes, 0f, out var padsScale);
+            if (pAxes.Length > 0)
+            {
+                packedScale = IR.F.CPU.Pack(packedScale, Enumerable.Repeat(Lane, pAxes.Length).ToArray(), pAxes);
+            }
+
+            var packedBias = PackUtility.PadForPack(bias, pshape, pAxes, lanes, 0f, out var padsBias);
+            if (pAxes.Length > 0)
+            {
+                packedBias = IR.F.CPU.Pack(packedBias, Enumerable.Repeat(Lane, pAxes.Length).ToArray(), pAxes);
+            }
+
+            var layernorm = IR.F.CPU.PackedLayerNorm(packedInput, packedScale, packedBias, op.Axis, op.Epsilon, op.UseMean, packedAxes, padsInput);
+
+            if (layernorm.CheckedType is not InvalidType)
+            {
+                var post = PackUtility.SliceForPack(IR.F.CPU.Unpack(layernorm, packedAxes), inShape, padsInput);
+                rets.Add(post);
+            }
+        }
+
+        for (int i = 0; i < input.CheckedShape.Count; i++)
+        {
+            AddCandidate(new[] { i }, new[] { Lane });
+            for (int j = i + 1; j < input.CheckedShape.Count; j++)
+            {
+                if (Rank > 1)
+                {
+                    AddCandidate(new[] { i, j }, new[] { Lane, Lane });
+                }
+            }
+        }
+
+        return rets;
+    }
+}
+
+public sealed class PackMatMul : PackRule
+{
+    public override Pattern Pattern { get; } = IsMatMul(
+        "target",
+        IsWildcard("lhs") with { TypePattern = IsFloat() },
+        IsWildcard("rhs") with { TypePattern = IsFloat() });
+
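+    // Rank-1 packs only the contraction axis (lhs K against rhs K); rank-2 also
+    // packs [M, K] x [K, N], so both operands arrive fully tiled for the vector
+    // kernel.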
+    public override List GetReplaceCandidates(IMatchResult result, RunPassContext context)
+    {
+        var rets = new List();
+        var lhs = (Expr)result["lhs"];
+        var rhs = (Expr)result["rhs"];
+        var candidate = (Expr)result[Pattern];
+        var lhsShape = lhs.CheckedShape.ToValueArray();
+        var rhsShape = rhs.CheckedShape.ToValueArray();
+
+        void AddCandidate(int[] lhsPackedAxes, int[] rhsPackedAxes, int[] lhsLanes, int[] rhsLanes)
+        {
+            var packedLhs = IR.F.CPU.Pack(PackUtility.PadForPack(lhs, lhsShape, lhsPackedAxes, lhsLanes, 0f, out var lhsPadNums), lhsLanes, lhsPackedAxes);
+            var packedRhs = IR.F.CPU.Pack(PackUtility.PadForPack(rhs, rhsShape, rhsPackedAxes, rhsLanes, 0f, out var rhsPadNums), rhsLanes, rhsPackedAxes);
+
+            var matmul = IR.F.CPU.PackedMatMul(packedLhs, packedRhs, lhsPackedAxes, lhsPadNums, rhsPackedAxes, rhsPadNums);
+            var lhsAlign = System.Math.Max(lhsShape.Length, rhsShape.Length) - lhsShape.Length;
+            var rhsAlign = System.Math.Max(lhsShape.Length, rhsShape.Length) - rhsShape.Length;
+            var post = matmul;
+            if (lhsPackedAxes.Length == 2 && rhsPackedAxes.Length == 2)
+            {
+                post = PackUtility.SliceForPack(IR.F.CPU.Unpack(matmul, new[] { lhsAlign + lhsPackedAxes[0], rhsAlign + rhsPackedAxes[1] }), candidate.CheckedShape.ToValueArray(), new[] { lhsPadNums[0], rhsPadNums[1] });
+            }
+
+            rets.Add(post);
+        }
+
+        AddCandidate(new[] { lhsShape.Length - 1 }, new[] { rhsShape.Length - 2 }, new[] { Lane }, new[] { Lane });
+        if (Rank > 1)
+        {
+            AddCandidate(new[] { lhsShape.Length - 2, lhsShape.Length - 1 }, new[] { rhsShape.Length - 2, rhsShape.Length - 1 }, new[] { Lane, Lane }, new[] { Lane, Lane });
+        }
+
+        return rets;
+    }
+}
+
+public sealed class PackUnary : PackRule
+{
+    public override Pattern Pattern { get; } = IsUnary(
+        "target",
+        _ => true,
+        IsWildcard("input") with { TypePattern = IsFloat() });
+
+    public override List GetReplaceCandidates(IMatchResult result, RunPassContext context)
+    {
+        var rets = new List();
+        var op = (IR.Math.Unary)result["target"];
+        var input = (Expr)result["input"];
+        var inShape = input.CheckedShape.ToValueArray();
+
+        void AddCandidate(int[] packedAxes, int[] lanes)
+        {
+            var packedInput = IR.F.CPU.Pack(PackUtility.PadForPack(input, inShape, packedAxes, lanes, 0f, out var padsInput), lanes, packedAxes);
+            var unary = IR.F.Math.Unary(op.UnaryOp, packedInput);
+            if (unary.CheckedType is not InvalidType)
+            {
+                var post = PackUtility.SliceForPack(IR.F.CPU.Unpack(unary, packedAxes), inShape, padsInput);
+                rets.Add(post);
+            }
+        }
+
+        for (int i = 0; i < input.CheckedShape.Count; i++)
+        {
+            AddCandidate(new[] { i }, new[] { Lane });
+            for (int j = i + 1; j < input.CheckedShape.Count; j++)
+            {
+                if (Rank > 1)
+                {
+                    AddCandidate(new[] { i, j }, new[] { Lane, Lane });
+                }
+            }
+        }
+
+        return rets;
+    }
+}
+
+public sealed class PackBinary : PackRule
+{
+    public override Pattern Pattern { get; } = IsBinary(
+        "target",
+        _ => true,
+        IsWildcard("lhs") with { TypePattern = IsFloat() },
+        IsWildcard("rhs") with { TypePattern = IsFloat() });
+
+    public override List GetReplaceCandidates(IMatchResult result, RunPassContext context)
+    {
+        var rets = new List();
+        var op = (IR.Math.Binary)result["target"];
+        var lhs = (Expr)result["lhs"];
+        var rhs = (Expr)result["rhs"];
+        var candidate = (Expr)result[Pattern];
+        var lhsShape = lhs.CheckedShape.ToValueArray();
+        var rhsShape = rhs.CheckedShape.ToValueArray();
+
+        void AddCandidate(int[] lhsPackedAxes, int[] rhsPackedAxes, int[] lhsLanes, int[] rhsLanes)
+        {
+            var packedLhs = IR.F.CPU.Pack(PackUtility.PadForPack(lhs, lhsShape, lhsPackedAxes, lhsLanes, 0f, out var lhsPadNums), lhsLanes, lhsPackedAxes);
+            var packedRhs = IR.F.CPU.Pack(PackUtility.PadForPack(rhs, rhsShape, rhsPackedAxes, rhsLanes, 0f, out var rhsPadNums), rhsLanes, rhsPackedAxes);
+
+            var binary = IR.F.CPU.PackedBinary(packedLhs, packedRhs, op.BinaryOp, lhsPackedAxes, lhsPadNums, rhsPackedAxes, rhsPadNums);
+            if (binary.CheckedType is not InvalidType)
+            {
+                var post = PackUtility.SliceForPack(IR.F.CPU.Unpack(binary, lhsPackedAxes.Length >= rhsPackedAxes.Length ? lhsPackedAxes : rhsPackedAxes), candidate.CheckedShape.ToValueArray(), lhsPackedAxes.Length >= rhsPackedAxes.Length ? lhsPadNums : rhsPadNums);
+                rets.Add(post);
+            }
+        }
+
+        foreach (var arr in new[] { GeneratePackAxes(lhsShape), GeneratePackAxes(rhsShape) }.CartesianProduct())
+        {
+            var lhsPackedAxes = arr.First();
+            var rhsPackedAxes = arr.Skip(1).First();
+            if (lhsPackedAxes.Length <= Rank && rhsPackedAxes.Length <= Rank)
+            {
+                AddCandidate(lhsPackedAxes, rhsPackedAxes, Enumerable.Repeat(Lane, lhsPackedAxes.Length).ToArray(), Enumerable.Repeat(Lane, rhsPackedAxes.Length).ToArray());
+            }
+        }
+
+        return rets;
+    }
+
+    public IEnumerable GeneratePackAxes(int[] shape)
+    {
+        if (shape.Length == 0 || (shape.Length == 1 && shape[0] == 1))
+        {
+            yield return Array.Empty();
+        }
+        else
+        {
+            for (int i = 0; i < shape.Length; i++)
+            {
+                yield return new[] { i };
+                for (int j = i + 1; j < shape.Length; j++)
+                {
+                    yield return new[] { i, j };
+                }
+            }
+        }
+    }
+}
+
+public sealed class PackSwish : PackRule
+{
+    public override Pattern Pattern { get; } = IsSwish(
+        "target",
+        IsWildcard("input") with { TypePattern = IsFloat() },
+        IsTensorConst("beta") with { TypePattern = IsFloatScalar() });
+
+    public override List GetReplaceCandidates(IMatchResult result, RunPassContext context)
+    {
+        var rets = new List();
+        var input = (Expr)result["input"];
+        var beta = ((TensorConst)result["beta"]).Value.ToScalar();
+        var inShape = input.CheckedShape.ToValueArray();
+
+        void AddCandidate(int[] packedAxes, int[] lanes)
+        {
+            var packed = IR.F.CPU.Pack(PackUtility.PadForPack(input, inShape, packedAxes, lanes, 0f, out var pads), lanes, packedAxes);
+            var swish = IR.F.NN.Swish(packed, beta);
+            var post = PackUtility.SliceForPack(IR.F.CPU.Unpack(swish, packedAxes), inShape, pads);
+            rets.Add(post);
+        }
+
+        for (int i = 0; i < input.CheckedShape.Count; i++)
+        {
+            AddCandidate(new[] { i }, new[] { Lane });
+            for (int j = i + 1; j < input.CheckedShape.Count; j++)
+            {
+                if (Rank > 1)
+                {
+                    AddCandidate(new[] { i, j }, new[] { Lane, Lane });
+                }
+            }
+        }
+
+        return rets;
+    }
+}
+
+public sealed class PackTranspose : PackRule
+{
+    public override Pattern Pattern { get; } = IsTranspose(
+        "trans",
+        IsWildcard("input") with { TypePattern = IsFloat() },
+        IsTensorConst("perm") with { TypePattern = IsIntegral() });
+
+    public override List GetReplaceCandidates(IMatchResult result, RunPassContext context)
+    {
+        var rets = new List();
+
+        var input = (Expr)result["input"];
+        var perm = ((TensorConst)result["perm"]).Value.ToArray();
+        var inShape = input.CheckedShape.ToValueArray();
+
+        void AddCandidate(int[] packedAxes, int[] lanes)
+        {
+            var packed = IR.F.CPU.Pack(PackUtility.PadForPack(input, inShape, packedAxes, lanes, 0f, out var pads), lanes, packedAxes);
+
+            var trans = IR.F.CPU.PackedTranspose(packed, perm, packedAxes);
+            if (trans.CheckedType is not InvalidType)
+            {
+                var unpackAxes = packedAxes.Select(axis => perm.IndexOf(axis)).ToArray();
+                bool swap = unpackAxes.Length == 2 && unpackAxes[0] > unpackAxes[1];
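+                // The unpack axes are derived through the permutation, so a
+                // two-axis pack can come out in descending order; normalize to
+                // ascending and swap the pad records to match.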
+                if (swap)
+                {
+                    (unpackAxes[0], unpackAxes[1]) = (unpackAxes[1], unpackAxes[0]);
+                    (pads[0], pads[1]) = (pads[1], pads[0]);
+                }
+
+                var newShape = perm.Select(i => inShape[i]).ToArray();
+                rets.Add(PackUtility.SliceForPack(IR.F.CPU.Unpack(trans, unpackAxes), newShape, pads));
+            }
+        }
+
+        for (int i = 0; i < input.CheckedShape.Count; i++)
+        {
+            AddCandidate(new[] { i }, new[] { Lane });
+            for (int j = i + 1; j < input.CheckedShape.Count; j++)
+            {
+                if (Rank > 1)
+                {
+                    AddCandidate(new[] { i, j }, new[] { Lane, Lane });
+                }
+            }
+        }
+
+        return rets;
+    }
+}
+
+public sealed class PackUnsqueeze : PackRule
+{
+    public override Pattern Pattern { get; } = IsUnsqueeze(
+        "unsq",
+        IsWildcard("input") with { TypePattern = IsFloat() },
+        IsTensorConst("axes") with { TypePattern = IsIntegral() });
+
+    public override List GetReplaceCandidates(IMatchResult result, RunPassContext context)
+    {
+        var rets = new List();
+
+        var input = (Expr)result["input"];
+        var axes = ((TensorConst)result["axes"]).Value.ToArray();
+        var inShape = input.CheckedShape.ToValueArray();
+
+        void AddCandidate(int[] packedAxes, int[] lanes)
+        {
+            var packed = IR.F.CPU.Pack(PackUtility.PadForPack(input, inShape, packedAxes, lanes, 0f, out var pads), lanes, packedAxes);
+
+            var post = IR.F.Tensors.Unsqueeze(packed, axes);
+            if (post.CheckedType is not InvalidType)
+            {
+                var unpackAxes = packedAxes.Select(axis => axis + axes.Count(i => i <= axis)).ToArray();
+                var outShape = inShape.ToList();
+                foreach (var axis in axes)
+                {
+                    if (axis >= 0)
+                    {
+                        outShape.Insert(axis, 1);
+                    }
+                    else
+                    {
+                        var index = System.Math.Max(outShape.Count + axis + 1, 0);
+                        outShape.Insert(index, 1);
+                    }
+                }
+
+                rets.Add(PackUtility.SliceForPack(IR.F.CPU.Unpack(post, unpackAxes), outShape.ToArray(), pads));
+            }
+        }
+
+        for (int i = 0; i < input.CheckedShape.Count; i++)
+        {
+            AddCandidate(new[] { i }, new[] { Lane });
+            for (int j = i + 1; j < input.CheckedShape.Count; j++)
+            {
+                if (Rank > 1)
+                {
+                    AddCandidate(new[] { i, j }, new[] { Lane, Lane });
+                }
+            }
+        }
+
+        return rets;
+    }
+}
+
+public sealed class PackReshape : PackRule
+{
+    public override Pattern Pattern { get; } = IsReshape(
+        "target",
+        IsWildcard("input") with { TypePattern = IsFloat() },
+        IsTensorConst("newShape") with { TypePattern = IsIntegral() });
+
+    public override List GetReplaceCandidates(IMatchResult result, RunPassContext context)
+    {
+        var rets = new List();
+
+        var input = (Expr)result["input"];
+        var newShape = ((TensorConst)result["newShape"]).Value.ToArray();
+        var inShape = input.CheckedShape.ToValueArray();
+
+        // 1. find the mapping transforms
+        if (!PackUtility.TryGetShapeMapMatrix(inShape, newShape, out var mat))
+        {
+            return new List { };
+        }
+
+        var (forwardDict, backwardDict) = PackUtility.ShapeMapMatrixAsDict(mat);
+
+        void AddCandidate(int[] packedAxes, int[] lanes)
+        {
+            // 1. skip when the packed axes would be split or merged.
+            var unpackAxes = new List();
+            foreach (var axis in packedAxes)
+            {
+                var mapedOutAxes = forwardDict[axis];
+                if (mapedOutAxes.Count > 1)
+                {
+                    // split into more dims.
+                    if (mapedOutAxes.Count(i => newShape[i] != 1) > 1)
+                    {
+                        continue;
+                    }
+                    else
+                    {
+                        // unsqueeze.
+                        var outAxis = mapedOutAxes.FirstOrDefault(i => newShape[i] != 1, mapedOutAxes.First());
+                        if (backwardDict[outAxis].Count != 1)
+                        {
+                            continue;
+                        }
+
+                        unpackAxes.Add(outAxis);
+                    }
+                }
+                else
+                {
+                    var outAxis = mapedOutAxes.First();
+
+                    // when the output axis is a merged dim, packing is only kept for
+                    // the innermost merged input axis with a lane-aligned extent
+                    // (no transposed order, no padding).
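+                    // Example with Lane = 8: packing axis 2 of [2, 8, 16] reshaped
+                    // to [2, 128] survives, since axis 2 is the innermost merged
+                    // axis and 16 % 8 == 0; packing axis 1 would interleave lanes
+                    // across the merge and is dropped.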
+                    var inAxes = backwardDict[outAxis];
+                    if (inAxes.Count == 1 || (inAxes[^1] == axis && inShape[axis] % Lane == 0))
+                    {
+                        unpackAxes.Add(outAxis);
+                    }
+                    else
+                    {
+                        return;
+                    }
+                }
+            }
+
+            var packed = IR.F.CPU.Pack(PackUtility.PadForPack(input, inShape, packedAxes, lanes, 0f, out var pads), lanes, packedAxes);
+            var packedNewShape = newShape.ToArray();
+            foreach (var (lane, axis) in lanes.Zip(unpackAxes))
+            {
+                packedNewShape[axis] = MathUtility.CeilDiv(packedNewShape[axis], lane);
+            }
+
+            var post = IR.F.Tensors.Reshape(packed, packedNewShape);
+            if (post.CheckedType is not InvalidType)
+            {
+                rets.Add(PackUtility.SliceForPack(IR.F.CPU.Unpack(post, unpackAxes.ToArray()), newShape, pads));
+            }
+        }
+
+        for (int i = 0; i < input.CheckedShape.Count; i++)
+        {
+            AddCandidate(new[] { i }, new[] { Lane });
+            for (int j = i + 1; j < input.CheckedShape.Count; j++)
+            {
+                if (Rank > 1)
+                {
+                    AddCandidate(new[] { i, j }, new[] { Lane, Lane });
+                }
+            }
+        }
+
+        return rets;
+    }
+}
+
+public sealed class PackSlice : PackRule
+{
+    public override Pattern Pattern { get; } = IsSlice(
+        "target",
+        IsWildcard("input") with { TypePattern = IsFloat() },
+        IsTensorConst("begins") with { TypePattern = IsIntegral() },
+        IsTensorConst("ends") with { TypePattern = IsIntegral() },
+        IsTensorConst("axes") with { TypePattern = IsIntegral() },
+        IsTensorConst("strides") with { TypePattern = IsIntegral() });
+
+    public override List GetReplaceCandidates(IMatchResult result, RunPassContext context)
+    {
+        var rets = new List();
+
+        var input = (Expr)result["input"];
+        var begins = ((TensorConst)result["begins"]).Value.ToArray();
+        var ends = ((TensorConst)result["ends"]).Value.ToArray();
+        var axes = ((TensorConst)result["axes"]).Value.ToArray();
+        var strides = ((TensorConst)result["strides"]).Value.ToArray();
+        var inShape = input.CheckedShape.ToValueArray();
+        var candidate = (Expr)result[Pattern];
+        for (int i = 0; i < axes.Length; i++)
+        {
+            ends[i] = ends[i] switch
+            {
+                < 0 => inShape[axes[i]] + ends[i],
+                int.MaxValue => inShape[axes[i]],
+                long.MaxValue => inShape[axes[i]],
+                _ => ends[i],
+            };
+        }
+
+        if (strides.Any(s => s != 1))
+        {
+            return rets;
+        }
+
+        void AddCandidate(int[] packAxes, int[] lanes)
+        {
+            var packedBegins = begins.ToArray();
+            var packedEnds = ends.ToArray();
+            for (int i = 0; i < packAxes.Length; i++)
+            {
+                var packAxis = packAxes[i];
+                int j = axes.IndexOf(packAxis);
+
+                // when the sliced axis is packed, its begins/ends must be
+                // lane-aligned so the packed slice needs no padding.
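+                // e.g. with Lane = 8, slicing [0, 16) on a packed axis becomes
+                // [0, 2) in packed units, while [0, 12) is rejected since 12 % 8 != 0.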
+                if (j != -1)
+                {
+                    if (begins[j] % lanes[i] == 0 && ends[j] % lanes[i] == 0)
+                    {
+                        packedBegins[j] = begins[j] / lanes[i];
+                        packedEnds[j] = ends[j] / lanes[i];
+                    }
+                    else
+                    {
+                        return;
+                    }
+                }
+            }
+
+            var packed = IR.F.CPU.Pack(PackUtility.PadForPack(input, inShape, packAxes, lanes, 0f, out var pads), lanes, packAxes);
+            var post = IR.F.Tensors.Slice(packed, packedBegins, packedEnds, axes, strides);
+            if (post.CheckedType is not InvalidType)
+            {
+                rets.Add(PackUtility.SliceForPack(IR.F.CPU.Unpack(post, packAxes), candidate.CheckedShape.ToValueArray(), pads));
+            }
+        }
+
+        for (int i = 0; i < input.CheckedShape.Count; i++)
+        {
+            AddCandidate(new[] { i }, new[] { Lane });
+            for (int j = i + 1; j < input.CheckedShape.Count; j++)
+            {
+                if (Rank > 1)
+                {
+                    AddCandidate(new[] { i, j }, new[] { Lane, Lane });
+                }
+            }
+        }
+
+        return rets;
+    }
+}
+
+[RuleGenerator]
+public sealed partial class FoldPackUnpack : RewriteRule
+{
+    public override Pattern Pattern { get; } = PatternMatch.F.CPU.IsPack("pack", "caller", _ => true, PatternMatch.F.CPU.IsUnpack("unpack", "callee", _ => true, IsWildcard("input")));
+
+    private Expr? GetReplace(IR.CPU.Pack pack, IR.CPU.Unpack unpack, Expr input)
+    {
+        if (pack.Axes.SequenceEqual(unpack.Axes))
+        {
+            return input;
+        }
+
+        return null;
+    }
+}
+
+[RuleGenerator]
+public sealed partial class FoldPackConcatUnpack : RewriteRule
+{
+    public override Pattern Pattern { get; } = PatternMatch.F.CPU.IsPack("pack", "caller", _ => true, PatternMatch.F.Tensors.IsConcat("concat", _ => true, IsTuple("tuple", IsVArgsRepeat("fields", exprs =>
+    {
+        var patterns = new Pattern[exprs.Length];
+        for (int i = 0; i < exprs.Length; i++)
+        {
+            patterns[i] = PatternMatch.F.CPU.IsUnpack($"unpack_{i}", $"callee_{i}", _ => true, IsWildcard($"input_{i}"));
+        }
+
+        return patterns;
+    }))));
+
+    private Expr? GetReplace(IR.CPU.Pack pack, IR.Tensors.Concat concat, IReadOnlyList fields, IMatchResult result)
+    {
+        var inputs = new Expr[fields.Count];
+        for (int i = 0; i < fields.Count; i++)
+        {
+            var unpack = (IR.CPU.Unpack)result[$"unpack_{i}"];
+            if (pack.Axes.SequenceEqual(unpack.Axes))
+            {
+                inputs[i] = (Expr)result[$"input_{i}"];
+            }
+            else
+            {
+                return null;
+            }
+        }
+
+        return IR.F.Tensors.Concat(new IR.Tuple(inputs), concat.Axis);
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/AffineMap.cs b/modules/Nncase.Modules.CPU/Passes/Tile/AffineMap.cs
new file mode 100644
index 0000000000..c45044df5d
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Tile/AffineMap.cs
@@ -0,0 +1,286 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Reactive;
+using NetFabric.Hyperlinq;
+using Nncase.IR;
+
+namespace Nncase.Passes.Tile;
+
+public static class ExprExtensions
+{
+    public static Expr Compose(this Expr expr, AffineMap map)
+    {
+        return expr.ReplaceDimsAndSymbols(map.Results, map.Symbols);
+    }
+
+    public static Expr ReplaceDimsAndSymbols(this Expr expr, Expr[] newDims, Expr[] newSymbols)
+    {
+        int i;
+        switch (expr)
+        {
+            case TensorConst:
+                return expr;
+            case Var dimExpr when dimExpr.Name.StartsWith("d"):
+                i = int.Parse(dimExpr.Name.Substring(1));
+                if (i >= newDims.Length)
+                {
+                    return expr;
+                }
+
+                return newDims[i];
+            case Var symExpr when symExpr.Name.StartsWith("s"):
+                i = int.Parse(symExpr.Name.Substring(1));
+                if (i >= newSymbols.Length)
+                {
+                    return expr;
+                }
+
+                return newSymbols[i];
+            case Call { Target: IR.Math.Binary op } call:
+                var lhs = ReplaceDimsAndSymbols(call[op.Parameters.First()], newDims, newSymbols);
+                var rhs = ReplaceDimsAndSymbols(call[op.Parameters.Last()], newDims, newSymbols);
+                return IR.F.Math.Binary(op.BinaryOp, lhs, rhs);
+            case Call { Target: IR.Math.Unary { UnaryOp: UnaryOp.Neg } op } call:
+                return IR.F.Math.Unary(op.UnaryOp, ReplaceDimsAndSymbols(call[op.Parameters.First()], newDims, newSymbols));
+            case TIR.Range range:
+                return new TIR.Range(ReplaceDimsAndSymbols(range.Start, newDims, newSymbols), ReplaceDimsAndSymbols(range.Stop, newDims, newSymbols), ReplaceDimsAndSymbols(range.Step, newDims, newSymbols));
+            default:
+                throw new InvalidOperationException("Unreachable");
+        }
+    }
+
+    public static Expr[] Dims(int rank)
+    {
+        return Enumerable.Range(0, rank).Select(i => (Expr)new Var($"d{i}", DataTypes.Int32)).ToArray();
+    }
+
+    public static Expr[] Symbols(int rank)
+    {
+        return Enumerable.Range(0, rank).Select(i => (Expr)new Var($"s{i}", DataTypes.Int32)).ToArray();
+    }
+
+    public static string Display(this Expr expr)
+    {
+        switch (expr)
+        {
+            case Var var:
+                return var.Name;
+            case TensorConst @const:
+                return @const.Value.ToScalar().ToString();
+            case Call { Target: IR.Math.Unary op } call:
+                return op.UnaryOp switch
+                {
+                    UnaryOp.Neg => $"-{Display(call[op.Parameters.First()])}",
+                    _ => throw new InvalidOperationException("Unreachable Unary Op"),
+                };
+            case Call { Target: IR.Math.Binary op } call:
+                return op.BinaryOp switch
+                {
+                    BinaryOp.Add => $"{Display(call[op.Parameters.First()])} + {Display(call[op.Parameters.Last()])}",
+                    BinaryOp.Mul => $"{Display(call[op.Parameters.First()])} * {Display(call[op.Parameters.Last()])}",
+                    BinaryOp.Sub => $"{Display(call[op.Parameters.First()])} - {Display(call[op.Parameters.Last()])}",
+                    BinaryOp.Div => $"{Display(call[op.Parameters.First()])} / {Display(call[op.Parameters.Last()])}",
+                    BinaryOp.Mod => $"{Display(call[op.Parameters.First()])} % {Display(call[op.Parameters.Last()])}",
+                    BinaryOp.FloorDiv => $"{Display(call[op.Parameters.First()])} // {Display(call[op.Parameters.Last()])}",
+                    BinaryOp.CeilDiv => $"{Display(call[op.Parameters.First()])} \\\\ {Display(call[op.Parameters.Last()])}",
+                    _ => throw new InvalidOperationException("Unreachable Binary Op"),
+                };
+            case TIR.Range rg:
+                return $"({rg.Start.Display()}, {rg.Stop.Display()}, {rg.Step.Display()})";
+            default:
+                throw new InvalidOperationException("Unreachable Affine Expr");
+        }
+    }
+}
+
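+// Substitutes dim/symbol vars with caller-supplied expressions while cloning;
+// this is what AffineMap.Apply uses to instantiate a map at concrete indices.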
+ protected override Expr VisitLeafVar(Var expr, Unit context)
+ {
+ if (_multiExprMap.TryGetValue(expr, out var newVar))
+ {
+ return newVar;
+ }
+
+ throw new InvalidOperationException("Could not find var in map.");
+ }
+}
+
+public class AffineMap
+{
+ public AffineMap(Expr[] dims, Expr[] symbols, Expr[] results)
+ {
+ Dims = dims;
+ Symbols = symbols;
+ Results = results;
+ }
+
+ public Expr[] Dims { get; set; }
+
+ public Expr[] Symbols { get; set; }
+
+ public Expr[] Results { get; }
+
+ public static AffineMap ConstantMap(int value)
+ {
+ return new AffineMap(Array.Empty<Expr>(), Array.Empty<Expr>(), new[] { (Expr)value });
+ }
+
+ public static AffineMap PointMap(params int[] values)
+ {
+ return new AffineMap(Array.Empty<Expr>(), Array.Empty<Expr>(), values.Select(v => (Expr)v).ToArray());
+ }
+
+ public static AffineMap Identity(int rank)
+ {
+ var dims = Enumerable.Range(0, rank).Select(i => (Expr)new Var($"d{i}", DataTypes.Int32)).ToArray();
+ return new AffineMap(dims, Array.Empty<Expr>(), dims);
+ }
+
+ public static AffineMap TransposeMap()
+ {
+ var dims = new[] { (Expr)new Var("d0", DataTypes.Int32), (Expr)new Var("d1", DataTypes.Int32) };
+ return new AffineMap(dims, Array.Empty<Expr>(), new[] { dims[1], dims[0] });
+ }
+
+ public static AffineMap Empty()
+ {
+ return new AffineMap(Array.Empty<Expr>(), Array.Empty<Expr>(), Array.Empty<Expr>());
+ }
+
+ public static AffineMap FromCallable<T>(T func, int dimsNum, int symbsNum)
+ where T : Delegate
+ {
+ var dims = Enumerable.Range(0, dimsNum).Select(i => (Expr)new Var($"d{i}", DataTypes.Int32)).ToArray();
+ var symbols = Enumerable.Range(0, symbsNum).Select(i => (Expr)new Var($"s{i}", DataTypes.Int32)).ToArray();
+ var funcParams = func.Method.GetParameters();
+ object? results = null;
+ if (funcParams.Length == 1 && funcParams[0].ParameterType.IsArray)
+ {
+ results = func.DynamicInvoke(new object[] { dims.Concat(symbols).ToArray() });
+ }
+ else
+ {
+ results = func.DynamicInvoke(dims.Concat(symbols).ToArray());
+ }
+
+ if (results is Expr[] ret)
+ {
+ return new AffineMap(dims, symbols, ret);
+ }
+
+ throw new NotSupportedException("Only Expr[] is supported.");
+ }
+
+ public AffineMap ReplaceDimsAndSymbols(Expr[] newDims, Expr[] newSymbols, int skipSymbols = 0)
+ {
+ var newResults = Results.Select(expr => expr.ReplaceDimsAndSymbols(newDims, newSymbols.Skip(skipSymbols).ToArray())).ToArray();
+ return new AffineMap(newDims, newSymbols, newResults);
+ }
+
+ /// <summary>
+ /// Y->Z compose X->Y => X->Z.
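+ /// For example, this = (d0, d1) -> (d1, d0) composed with other = (d0, d1) -> (d0 + d1, d1) gives (d0, d1) -> (d1, d0 + d1).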
+ /// </summary>
+ public AffineMap Compose(AffineMap other)
+ {
+ if (Dims.Length != other.Results.Length)
+ {
+ throw new InvalidOperationException("Cannot compose AffineMaps with mismatching dimensions and results.");
+ }
+
+ var numDims = other.Dims.Length;
+ var numSymbols = Symbols.Length + other.Symbols.Length;
+ var newDims = ExprExtensions.Dims(numDims);
+ var newSymbols = ExprExtensions.Symbols(numSymbols);
+
+ var newMap = other.ReplaceDimsAndSymbols(newDims, newSymbols, Symbols.Length);
+ var results = Results.Select(expr => expr.Compose(newMap)).ToArray();
+ return new AffineMap(newMap.Dims, newMap.Symbols, results);
+ }
+
+ public AffineMap InversePermutation()
+ {
+ if (Symbols.Length != 0)
+ {
+ throw new InvalidOperationException("Cannot invert AffineMap with symbols.");
+ }
+
+ var foundDims = new int[Dims.Length];
+ Array.Fill(foundDims, -1);
+
+ for (int i = 0; i < Results.Length; i++)
+ {
+ if (Results[i] is { } dimExpr && foundDims[((TensorConst)dimExpr).Value.ToScalar<int>()] == -1)
+ {
+ foundDims[((TensorConst)dimExpr).Value.ToScalar<int>()] = i;
+ }
+ }
+
+ if (foundDims.Any(d => d == -1))
+ {
+ return null!;
+ }
+
+ var results = foundDims.Select(i => Results[i]).ToArray();
+ return new AffineMap(Results, Array.Empty<Expr>(), results);
+ }
+
+ public List<int> Eval(int[] dims, int[] symbols)
+ {
+ if (dims.Length != Dims.Length || symbols.Length != Symbols.Length)
+ {
+ throw new ArgumentException("Dimension and symbol arrays must match the map's dimensions and symbols.");
+ }
+
+ var feedDict = new Dictionary<Var, IValue>();
+ foreach (var (first, second) in Dims.Zip(dims))
+ {
+ feedDict.Add((Var)first, Value.FromTensor(Tensor.FromScalar(second)));
+ }
+
+ foreach (var (first, second) in Symbols.Zip(symbols))
+ {
+ feedDict.Add((Var)first, Value.FromTensor(Tensor.FromScalar(second)));
+ }
+
+ return Results.Select(expr => expr.Evaluate(feedDict).AsTensor().ToScalar<int>()).ToList();
+ }
+
+ public Expr[] Apply(Expr[] parameters)
+ {
+ if (parameters.Length != Dims.Length + Symbols.Length)
+ {
+ throw new ArgumentException("Parameters must match the map's dimensions and symbols.");
+ }
+
+ Dictionary<Expr, Expr> map = new(ReferenceEqualityComparer.Instance);
+ for (int i = 0; i < parameters.Length; i++)
+ {
+ map.Add(i < Dims.Length ? Dims[i] : Symbols[i - Dims.Length], parameters[i]);
+ }
+
+ var cloner = new MapCloner(map);
+
+ return Results.Select(r => cloner.Clone(r, default)).ToArray();
+ }
+
+ public override string ToString()
+ {
+ var dims = string.Join(", ", Enumerable.Range(0, Dims.Length).Select(i => $"d{i}"));
+ var syms = string.Join(", ", Enumerable.Range(0, Symbols.Length).Select(i => $"s{i}"));
+ var results = string.Join(", ", Results.Select(expr => expr.Display()));
+
+ return Symbols.Length == 0 ? $"({dims}) -> ({results})" : $"({dims})[{syms}] -> ({results})";
+ }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs b/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs
new file mode 100644
index 0000000000..5c490773df
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs
@@ -0,0 +1,90 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
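+
+// The mutator below gates fusion merging on tiling feasibility: a merged
+// fusion is accepted only when FusionChecker.Check finds at least one valid
+// tiling and buffer allocation for its body; accepted checkers are cached per
+// fusion so DeviceFusionToPrimFuncRewriter can reuse them during lowering.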
+ +using System.Runtime.CompilerServices; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.Passes.Mutators; +using Nncase.Targets; + +[assembly: InternalsVisibleTo("Nncase.Tests.CPU")] + +namespace Nncase.Passes.Tile; + +internal sealed class CPUSameInputFusionMergeRule : SameInputFusionMergeRule +{ + public override string ModuleKind => CPUTarget.Kind; +} + +internal sealed class CPUMultiInputFusionMergeRule : MultiInputFusionMergeRule +{ + public override string ModuleKind => CPUTarget.Kind; +} + +internal sealed class CPUShortCutFusionMergeRuleLeft : ShortCutFusionMergeRuleLeft +{ + public override string ModuleKind => CPUTarget.Kind; +} + +internal sealed class CPUShortCutFusionMergeRuleRight : ShortCutFusionMergeRuleRight +{ + public override string ModuleKind => CPUTarget.Kind; +} + +internal sealed class CPUFusionGroupMutator : FusionGroupMutator +{ + private readonly Dictionary _fusioncheckerCache; + private bool _checked; + + // private readonly TileOptions _tileOptions = null!; + public CPUFusionGroupMutator( + Dictionary fusioncheckerCache, + IMergeRewriteRule rule, + RunPassContext passOptions) + : base(rule, passOptions) + { + _fusioncheckerCache = fusioncheckerCache; + _checked = false; + } + + /// + public override bool MergedFusionCheckCallBack(Fusion mergedFusion, HashSet candidateFusions) + { + bool ok = false; + if (!_checked) + { + PrimTileVisitor primTileVisitor = new(); + primTileVisitor.Visit(mergedFusion.Body); + var checker = new FusionChecker(primTileVisitor.TileList); + + // CompilerServices.DumpDotIR(merged_fusion, "before_merge_check", PassOptions.DumpDir,true); // dump sub function. + var ret = checker.Check(mergedFusion.Body); + ok = ret.Count > 0; + + // CompilerServices.DumpDotIR(merged_fusion, "after_merge_check", PassOptions.DumpDir,true); // dump sub function. + if (ok) + { + _checked = true; + _fusioncheckerCache.Add(mergedFusion, checker); + foreach (var cand in candidateFusions) + { + // release the merged fusion. + _fusioncheckerCache.Remove(cand); + } + } + } + + return ok; + } + + public override Expr MergedFusionRewriteCallBack(Expr mergedFusionBody) + { + using var dumpScope = new DumpScope("MergedFusionClear"); + return CompilerServices.ERewrite(mergedFusionBody, new[] { new Passes.Rules.CPU.FoldStoreLoad() }, new()); + } + + protected override Expr RewriteLeafCall(Call expr) + { + return _checked ? expr : base.RewriteLeafCall(expr); + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/DeviceFusionPatterns.cs b/modules/Nncase.Modules.CPU/Passes/Tile/DeviceFusionPatterns.cs new file mode 100644 index 0000000000..b7faa1f663 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Tile/DeviceFusionPatterns.cs @@ -0,0 +1,27 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
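+
+// Seed patterns for device-fusion candidates: a unary(unary(x)) chain and a
+// unary(matmul(lhs, rhs)) chain; "caller"/"callee" label the outer and inner
+// matched calls for the matching machinery.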
+ +using Nncase.PatternMatch; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes.Tile; + +internal static class DeviceFusionPatterns +{ + public static Pattern UnaryUnaryPattern() + { + var v0 = IsVar("input"); + var v1 = PatternMatch.F.Math.IsUnary(null, "callee", _ => true, v0); + var v2 = PatternMatch.F.Math.IsUnary(null, "caller", _ => true, v1); + return v2; + } + + public static Pattern MatmulUnaryPattern() + { + var v00 = IsVar("lhs"); + var v01 = IsVar("rhs"); + var v1 = PatternMatch.F.Math.IsMatMul(null, "callee", _ => true, v00, v01); + var v2 = PatternMatch.F.Math.IsUnary(null, "caller", _ => true, v1); + return v2; + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/DeviceToTIRVisitor.cs b/modules/Nncase.Modules.CPU/Passes/Tile/DeviceToTIRVisitor.cs new file mode 100644 index 0000000000..78fe77aa8e --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Tile/DeviceToTIRVisitor.cs @@ -0,0 +1,622 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +#define USE_KERNEL_LIB +using System.Linq; +using System.Reactive; +using NetFabric.Hyperlinq; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.IR.Imaging; +using Nncase.IR.Math; +using Nncase.IR.NN; +using Nncase.IR.Tensors; +using Nncase.PatternMatch; +using Nncase.TIR; +using Nncase.TIR.Builders; +using Nncase.Utilities; +using Buffer = Nncase.TIR.Buffer; + +namespace Nncase.Passes.Tile; + +internal struct TileScope : IDisposable +{ + private static readonly List> _bufferMapStack = new(); + private static readonly List _blockBuilderStack = new(); + private static readonly Stack _frames = new(); + private static readonly List>> _loopBuildersStack = new(); + private static readonly List> _loopVarsStack = new(); + + public TileScope(TileFrame frame) + { + _frames.Push(frame); + frame.Enter(); + } + + public static IBlockBuilder CurrentBlock => _blockBuilderStack.Count == 0 ? null! : _blockBuilderStack[^1]; + + public static IReadOnlyDictionary CurrentMap => _bufferMapStack.Count == 0 ? null! : _bufferMapStack[^1]; + + public static IReadOnlyList CurrentLoopVars => _loopVarsStack.Count == 0 ? null! : _loopVarsStack[^1]; + + public static IReadOnlyList> LoopVarStack => _loopVarsStack; + + public static IReadOnlyList> CurrentLoops => _loopBuildersStack.Count == 0 ? null! 
: _loopBuildersStack[^1]; + + public void Dispose() + { + var frame = _frames.Pop(); + frame.Exit(); + } + + public abstract class TileFrame + { + public abstract void Enter(); + + public abstract void Exit(); + } + + public sealed class PushMemoryFrame : TileFrame + { + private readonly Dictionary _bufferMap; + private readonly IBlockBuilder _fusionBlock; + private readonly ISequentialBuilder[] _builders; + private readonly Var[] _vars; + + public PushMemoryFrame(Dictionary bufferMap, IBlockBuilder fusionBlock, ISequentialBuilder[] builders, Var[] vars) + { + _bufferMap = bufferMap; + _fusionBlock = fusionBlock; + _builders = builders; + _vars = vars; + } + + public override void Enter() + { + _bufferMapStack.Add(_bufferMap); + _blockBuilderStack.Add(_fusionBlock); + _loopBuildersStack.Add(new(_builders)); + _loopVarsStack.Add(new(_vars)); + } + + public override void Exit() + { + _bufferMapStack.RemoveAt(_bufferMapStack.Count - 1); + _blockBuilderStack.RemoveAt(_blockBuilderStack.Count - 1); + _loopBuildersStack.RemoveAt(_loopBuildersStack.Count - 1); + _loopVarsStack.RemoveAt(_loopVarsStack.Count - 1); + } + } + + public sealed class PushLoopFrame : TileFrame + { + private readonly ISequentialBuilder[] _builders; + private readonly Var[] _vars; + + public PushLoopFrame(ISequentialBuilder[] builders, Var[] vars) + { + _builders = builders; + _vars = vars; + } + + public override void Enter() + { + _loopBuildersStack[^1].AddRange(_builders); + _loopVarsStack[^1].AddRange(_vars); + } + + public override void Exit() + { + var total = _loopBuildersStack[^1].Count; + int length = _builders.Length; + _loopBuildersStack[^1].RemoveRange(total - length, length); + total = _loopVarsStack[^1].Count; + length = _vars.Length; + _loopVarsStack[^1].RemoveRange(total - length, length); + } + } +} + +internal sealed class DeviceFusionToPrimFuncRewriter : ExprRewriter +{ + private readonly HashSet _primFunctions = new(ReferenceEqualityComparer.Instance); + private readonly IReadOnlyDictionary _fusionCheckCache; + + public DeviceFusionToPrimFuncRewriter(Dictionary fusionCheckCache) + { + _fusionCheckCache = fusionCheckCache; + } + + public HashSet PrimFunctions => _primFunctions; + + protected override Expr DefaultRewriteLeaf(Expr expr) => base.DefaultRewriteLeaf(expr); + + protected override Expr RewriteLeafFusion(Fusion expr) + { + if (expr.ModuleKind == Targets.CPUTarget.Kind && expr.Name.EndsWith("device")) + { + // var oldBody = expr.Body; + // PrimTileVisitor primTileVisitor = new(); + // primTileVisitor.Visit(oldBody); + // FusionChecker fusionChecker = new(primTileVisitor.TileList, primTileVisitor.NameList); + // var tileMap = fusionChecker.Check(oldBody)[0]; + if (!_fusionCheckCache.TryGetValue(expr, out var cachedChecker)) + { + PrimTileVisitor primTileVisitor = new(); + primTileVisitor.Visit(expr.Body); + cachedChecker = new FusionChecker(primTileVisitor.TileList); + cachedChecker.Check(expr.Body); + } + + if (cachedChecker.CheckedResult.Count != 1) + { + throw new NotSupportedException("Not support no uniform shard!"); + } + + var (_, tileMap) = cachedChecker.CheckedResult[0]; + + // var tileShape = tileMap[oldBody].OutShape; + // var newBody = IR.F.CPU.Store( + // tileShape, + // new TileType(TIR.MemoryLocation.Output, DistributedUtility.GetDividedTensorType((DistributedType)oldBody.CheckedType)), + // new TileFusionLowerCloner(tileMap).Clone(oldBody, default)); + + // var egraph = new EGraph(newBody); + // CompilerServices.ERewrite(egraph, new IRewriteRule[] { new UnaryL1Fusion(), new 
MatmulL1Fusion() }, new()); + // var tiledBody = egraph.Extract(egraph.Root!, new TileFusionCostEvaluator(), out var _); + // var newfusion = new Fusion(expr.Name, Targets.CPUTarget.Kind, tiledBody, expr.Parameters); + + // if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) + // { + // Diagnostics.DumpScope.Current.DumpIR(newfusion, string.Empty, "L1Tiled"); + // } + + // var allocMap = fusionChecker.ReAllocate(newfusion.Body, true); + var converter = new DeviceToTIRConverter(expr, tileMap); + var primfunc = converter.Convert(); + _primFunctions.Add(primfunc); + return primfunc; + } + + return expr; + } +} + +internal sealed class TileFusionCostEvaluator : Evaluator.IBaseFuncCostEvaluator +{ + public Cost VisitLeaf(BaseFunction target) + { + return new Cost() + { + [CostFactorNames.CPUCycles] = 1000, + }; + } +} + +internal sealed class DeviceToTIRConverter +{ + private readonly Fusion _fusion; + private readonly IReadOnlyDictionary _tileMemo; + private readonly Dictionary _regionMemo; + + public DeviceToTIRConverter(Fusion expr, IReadOnlyDictionary tileMap) + { + _fusion = expr; + _tileMemo = tileMap; + _regionMemo = new(ReferenceEqualityComparer.Instance); + } + + public TIR.PrimFunction Convert() + { + var shape = _fusion.Body.CheckedShape; + var func = T.PrimFunc(_fusion.Name, Targets.CPUTarget.Kind, _fusion.Parameters.ToArray().Select(p => _tileMemo[p].Buffer).Concat(new[] { _tileMemo[_fusion.Body].Buffer }).ToArray()).Body( + Visit(_fusion, AffineMap.Identity(shape.Rank), null!, out _)); + return func.Build(); + } + + public Expr Visit(Expr expr, AffineMap rootMap, BufferRegion outRegion, out AffineMap[] inputMaps) + { + inputMaps = Array.Empty(); + return expr switch + { + Call call => (call.Target switch + { + IR.CPU.Load op => LowerLoad(call, op, rootMap, outRegion, out inputMaps), + IR.CPU.Store op => LowerStore(call, op, rootMap, outRegion, out inputMaps), + IR.Math.Unary op => LowerUnary(call, op, rootMap, outRegion, out inputMaps), + IR.Math.MatMul op => LowerMatmul(call, op, rootMap, outRegion, out inputMaps), + IR.Math.Binary op => LowerBinary(call, op, rootMap, outRegion, out inputMaps), + Fusion func => LowerFusion(call, func, rootMap, outRegion, out inputMaps), + _ => throw new NotSupportedException(), + }).Build(), + Fusion func => LowerFusion(null, func, rootMap, outRegion, out inputMaps).Build(), + _ => T.Nop(), + }; + } + + private ISequentialBuilder LowerMatmul(Call call, MatMul op, AffineMap rootMap, BufferRegion outRegion, out AffineMap[] inputMaps) + { + var lhsTile = GetTile(call.Arguments[0]); + var lhsShape = GetShape(call.Arguments[0]); + var rhsShape = GetShape(call.Arguments[1]); + var rhsTile = GetTile(call.Arguments[1]); + var tileShape = GetTile(call); + var fullShape = GetShape(call); + + Expr[] PostProcessAffineMap(List iters, IReadOnlyList inShape, IReadOnlyList outShape) + { + var ralign = outShape.Count - inShape.Count; + for (int i = outShape.Count - 1; i >= 0; i--) + { + if (i < ralign) + { + iters.RemoveAt(i); + } + else if (i < (outShape.Count - 2) && inShape[i] == 1 && outShape[i] != 1) + { + iters[i] = 0; + } + } + + return iters.ToArray(); + } + + var outKLoop = T.ForLoop(out var ok, new TIR.Range(0, lhsShape[^1], lhsTile[^1]), LoopMode.Serial); + using (new TileScope(new TileScope.PushLoopFrame(new[] { outKLoop }, new[] { ok }))) + { + Expr[] LhsFunc(params Expr[] exprs) + { + return PostProcessAffineMap(exprs[..^2].Concat(new[] { exprs[^1] }).ToList(), lhsShape, fullShape); + } + + Expr[] RhsFunc(params Expr[] exprs) 
+ { + return PostProcessAffineMap(exprs[..^3].Concat(new[] { exprs[^1], exprs[^2] }).ToList(), rhsShape, fullShape); + } + + var lhsMap = AffineMap.FromCallable(LhsFunc, fullShape.Count, 1).Compose(rootMap); + var rhsMap = AffineMap.FromCallable(RhsFunc, fullShape.Count, 1).Compose(rootMap); + + var outStarts = outRegion.Region.ToArray().Select(r => r.Start).ToList(); + outStarts.Add(0); + var outStops = outRegion.Region.ToArray().Select(r => r.Stop).ToList(); + outStops.Add(IR.F.Math.Min(ok + lhsTile[^1], lhsShape[^1]) - ok); + + var lhsRegion = GetBufferRegion(call.Arguments[0], (TIR.Buffer lhsBuffer) => + { + var lhsStarts = lhsMap.Apply(outStarts.ToArray()); + var lhsStops = lhsMap.Apply(outStops.ToArray()); + return new BufferRegion(lhsBuffer, lhsStarts.Zip(lhsStops).Select(p => new TIR.Range(p.First, p.Second, 1)).ToArray()); + }); + + var rhsRegion = GetBufferRegion(call.Arguments[1], (TIR.Buffer rhsBuffer) => + { + var rhsStarts = rhsMap.Apply(outStarts.ToArray()); + var rhsStops = rhsMap.Apply(outStops.ToArray()); + return new BufferRegion(rhsBuffer, rhsStarts.Zip(rhsStops).Select(p => new TIR.Range(p.First, p.Second, 1)).ToArray()); + }); + TileScope.CurrentBlock.Alloc(outRegion.Buffer); + var block = T.Block(nameof(MatMul)). + Reads(lhsRegion, rhsRegion). + Writes(outRegion); + outKLoop.Body( + Visit(call.Arguments[0], lhsMap, lhsRegion, out var lhsInputMaps), + Visit(call.Arguments[1], rhsMap, rhsRegion, out var rhsInputMaps), + block); +#if USE_KERNEL_LIB + block.Body(TIR.F.CPU.Matmul(lhsRegion, rhsRegion, outRegion)); +#else + // var lhsStarts = lhsRegion.Region.ToArray().Select(r => (T.Let(out var start, r.Start), start)).ToArray(); + // var rhsStarts = rhsRegion.Region.ToArray().Select(r => (T.Let(out var start, r.Start), start)).ToArray(); + // var outLetStarts = outStarts.ToArray().Select(r => (T.Let(out var start, r), start)).ToArray(); + var stopLets = outStops.Select((s, i) => (T.Let(out var stop, s, $"stop{i}"), stop)).ToArray(); + var compute = T.Grid(out var vars, LoopMode.Serial, stopLets.Select((p, i) => new TIR.Range(0, p.stop, i < stopLets.Length - 3 ? 
1 : 32)).ToArray()).Body( + T.Let(out var curM, IR.F.Math.Min(stopLets[^3].stop - vars[^3], 32)).Body( + T.Let(out var curN, IR.F.Math.Min(stopLets[^2].stop - vars[^2], 32)).Body( + T.Let(out var curK, IR.F.Math.Min(stopLets[^1].stop - vars[^1], 32)).Body( + TIR.F.CPU.TMMA( + GetBufferPtr(lhsRegion, lhsMap.Apply(vars).Select((v, i) => v + lhsRegion.Region[i].Start).ToArray()), + GetBufferPtr(rhsRegion, rhsMap.Apply(vars).Select((v, i) => v + rhsRegion.Region[i].Start).ToArray()), + GetBufferPtr(outRegion, vars.SkipLast(1).Select((v, i) => v + outRegion.Region[i].Start).ToArray()), + curM, + curK, + curN, + lhsRegion.Buffer.Strides[^2], + rhsRegion.Buffer.Strides[^2], + outRegion.Buffer.Strides[^2], + DataTypes.Float32, + lhsRegion.Buffer.ElemType, + outRegion.Buffer.ElemType, + IR.F.Math.NotEqual(vars[^1] + ok, 0)))))); + + var final = stopLets.Select(p => p.Item1).Aggregate((acc, cur) => + { + acc.Body(cur); + return cur; + }); + final.Body(compute); + block.Body(stopLets[0].Item1); +#endif + } + + // var fullK = ((TileType)call.Arguments[0].CheckedType).TensorType.Shape[^1].FixedValue; + Expr[] LhsInFunc(params Expr[] exprs) => PostProcessAffineMap(exprs[..^1].Concat(new Expr[] { 0 }).ToList(), lhsShape, fullShape); + Expr[] RhsInFunc(params Expr[] exprs) => PostProcessAffineMap(exprs[..^2].Concat(new Expr[] { 0, exprs[^1] }).ToList(), rhsShape, fullShape); + + // root = (b,c,m,n) -> (b,c,m,n) + // lhs loop vars = b,c,m,k + inputMaps = new[] { + AffineMap.FromCallable(LhsInFunc, fullShape.Count, 0).Compose(rootMap), + AffineMap.FromCallable(RhsInFunc, fullShape.Count, 0).Compose(rootMap), + }; + + return T.Sequential().Body(outKLoop); + } + + private ISequentialBuilder LowerLoad(Call call, IR.CPU.Load load, AffineMap rootMap, BufferRegion outRegion, out AffineMap[] inputMaps) + { + var tileShape = GetTile(call); + var inShape = GetShape(call.Arguments[0]); + var iterVars = rootMap.Apply(TileScope.CurrentLoopVars.ToArray()); + inputMaps = new[] { rootMap }; + + var inRegion = GetBufferRegion(call.Arguments[0], (TIR.Buffer inBuffer) => + new BufferRegion(inBuffer, Enumerable.Range(0, tileShape.Count).Select(i => + { + var iterV = iterVars[i]; + return new TIR.Range(iterV, IR.F.Math.Min(iterV + tileShape[i], inShape[i]), 1); + }).ToArray())); + TileScope.CurrentBlock.Alloc(outRegion.Buffer); + var block = T.Block("load"). + Reads(inRegion). + Writes(outRegion); + var seq = T.Sequential().Body( + Visit(call.Arguments[0], rootMap, inRegion, out var _), + block); +#if USE_KERNEL_LIB + block.Body(TIR.F.CPU.Memcopy(outRegion, inRegion)); +#else + // var inStarts = inRegion.Region.ToArray().Select(r => (T.Let(out var start, r.Start), start)).ToArray(); + // var outStarts = outRegion.Region.ToArray().Select(r => (T.Let(out var start, r.Start), start)).ToArray(); + var compute = T.Grid(out var vars, LoopMode.Serial, inRegion.Region.ToArray().Select(r => new TIR.Range(0, r.Stop - r.Start, 1)).ToArray()). 
+ Body( + T.BufferStore(outRegion.Buffer, vars.Select((v, i) => v + outRegion.Region[i].Start).ToArray(), T.BufferLoad(inRegion.Buffer, vars.Select((v, i) => v + inRegion.Region[i].Start).ToArray()))); + + // var final = inStarts.Concat(outStarts).Select(p => p.Item1).Aggregate((acc, cur) => + // { + // acc.Body(cur); + // return cur; + // }); + // final.Body(compute); + // block.Body(inStarts[0].Item1); + block.Body(compute); +#endif + + return seq; + } + + private ISequentialBuilder LowerStore(Call call, IR.CPU.Store store, AffineMap rootMap, BufferRegion outRegion, out AffineMap[] inputMaps) + { + var iterVars = rootMap.Apply(TileScope.CurrentLoopVars.ToArray()); + var tileShape = GetTile(call); + var outShape = GetShape(call); + + outRegion = GetBufferRegion(call, (TIR.Buffer outBuffer) => + new BufferRegion(outBuffer, Enumerable.Range(0, tileShape.Count).Select(i => + { + var iterV = iterVars[i]; + return new TIR.Range(iterV, IR.F.Math.Min(iterV + tileShape[i], outShape[i]), 1); + }).ToArray())); + + var inRegion = GetBufferRegion(call.Arguments[0], (TIR.Buffer inBuffer) => + new BufferRegion(inBuffer, Enumerable.Range(0, tileShape.Count).Select(i => + { + // var iterV = iterVars[i]; + return new TIR.Range(0, outRegion.Region[i].Stop - outRegion.Region[i].Start, 1); + }).ToArray())); + + var block = T.Block(nameof(store)). + Reads(inRegion). + Writes(outRegion); + var seq = T.Sequential().Body( + Visit(call.Arguments[0], rootMap, inRegion, out inputMaps), + block); +#if USE_KERNEL_LIB + block.Body(TIR.F.CPU.Memcopy(outRegion, inRegion)); +#else + // var inStarts = inRegion.Region.ToArray().Select(r => (T.Let(out var start, r.Start), start)).ToArray(); + // var outStarts = outRegion.Region.ToArray().Select(r => (T.Let(out var start, r.Start), start)).ToArray(); + var compute = T.Grid(out var vars, LoopMode.Serial, inRegion.Region.ToArray().Select(r => new TIR.Range(0, r.Stop - r.Start, 1)).ToArray()). 
+ Body( + T.BufferStore(outRegion.Buffer, vars.Select((v, i) => v + outRegion.Region[i].Start).ToArray(), T.BufferLoad(inRegion.Buffer, vars.Select((v, i) => v + inRegion.Region[i].Start).ToArray()))); + + // var final = inStarts.Concat(outStarts).Select(p => p.Item1).Aggregate((acc, cur) => + // { + // acc.Body(cur); + // return cur; + // }); + // final.Body(compute); + // block.Body(inStarts[0].Item1); + block.Body(compute); +#endif + return seq; + } + + private ISequentialBuilder LowerBinary(Call call, Binary op, AffineMap rootMap, BufferRegion outRegion, out AffineMap[] inputMaps) + { + var lhsShape = GetShape(call.Arguments[0]); + var rhsShape = GetShape(call.Arguments[1]); + var fullShape = GetShape(call); + var lhsRegion = GetBufferRegion(call.Arguments[0], (TIR.Buffer inBuffer) => new BufferRegion(inBuffer, outRegion.Region)); + var rhsRegion = GetBufferRegion(call.Arguments[1], (TIR.Buffer inBuffer) => new BufferRegion(inBuffer, outRegion.Region)); + TileScope.CurrentBlock.Alloc(outRegion.Buffer); + + Expr[] PostProcessAffineMap(List iters, IReadOnlyList inShape, IReadOnlyList outShape) + { + var ralign = outShape.Count - inShape.Count; + for (int i = outShape.Count - 1; i >= 0; i--) + { + if (i < ralign) + { + iters.RemoveAt(i); + } + else if (i < (outShape.Count - 2) && inShape[i] == 1 && outShape[i] != 1) + { + iters[i] = 0; + } + } + + return iters.ToArray(); + } + + Expr[] LhsInFunc(params Expr[] exprs) => PostProcessAffineMap(exprs.ToList(), lhsShape, fullShape); + Expr[] RhsInFunc(params Expr[] exprs) => PostProcessAffineMap(exprs.ToList(), rhsShape, fullShape); + + inputMaps = new[] { + AffineMap.FromCallable(LhsInFunc, fullShape.Count, 0).Compose(rootMap), + AffineMap.FromCallable(RhsInFunc, fullShape.Count, 0).Compose(rootMap), + }; + + var block = T.Block("binary"). + Reads(lhsRegion, rhsRegion). + Writes(outRegion); + var seq = T.Sequential().Body( + Visit(call.Arguments[0], rootMap, lhsRegion, out _), + Visit(call.Arguments[1], rootMap, rhsRegion, out _), + block); +#if USE_KERNEL_LIB + block.Body(TIR.F.CPU.Binary(op.BinaryOp, lhsRegion, rhsRegion, outRegion)); +#else + throw new NotSupportedException(); +#endif + return seq; + } + + private ISequentialBuilder LowerUnary(Call call, Unary op, AffineMap rootMap, BufferRegion outRegion, out AffineMap[] inputMaps) + { + // var iterVars = rootMap.Apply(TileScope.CurrentLoopVars.ToArray()); + var inRegion = GetBufferRegion(call.Arguments[0], (TIR.Buffer inBuffer) => new BufferRegion(inBuffer, outRegion.Region)); + TileScope.CurrentBlock.Alloc(outRegion.Buffer); + inputMaps = new[] { rootMap }; + var block = T.Block("unary"). + Reads(inRegion). + Writes(outRegion); + var seq = T.Sequential().Body( + Visit(call.Arguments[0], rootMap, inRegion, out _), + block); +#if USE_KERNEL_LIB + block.Body(TIR.F.CPU.Unary(op.UnaryOp, inRegion, outRegion)); +#else + // var inStarts = inRegion.Region.ToArray().Select(r => (T.Let(out var start, r.Start), start)).ToArray(); + // var outStarts = outRegion.Region.ToArray().Select(r => (T.Let(out var start, r.Start), start)).ToArray(); + var compute = T.Grid(out var vars, LoopMode.Serial, inRegion.Region.ToArray().Select(r => new TIR.Range(0, r.Stop - r.Start, 1)).ToArray()). 
+ Body( + T.BufferStore(outRegion.Buffer, vars.Select((v, i) => v + outRegion.Region[i].Start).ToArray(), IR.F.Math.Unary(op.UnaryOp, T.BufferLoad(inRegion.Buffer, vars.Select((v, i) => v + inRegion.Region[i].Start).ToArray())))); + + // var final = inStarts.Concat(outStarts).Select(p => p.Item1).Aggregate((acc, cur) => + // { + // acc.Body(cur); + // return cur; + // }); + // final.Body(compute); + // block.Body(inStarts[0].Item1); + block.Body(compute); +#endif + return seq; + } + + private ISequentialBuilder LowerFusion(Call? call, Fusion func, AffineMap rootMap, BufferRegion outRegion, out AffineMap[] inputMaps) + { + if (func.Body is not Call { Target: IR.CPU.Store store }) + { + throw new NotSupportedException(); + } + + // var inBuffer = call is null ? GetBuffer(func.Parameters[0]) : GetBuffer(call.Arguments[0]); + // var outBuffer = call is null ? GetBuffer(func.Body) : GetBuffer(call); + + // 1. func body + var fusionBlock = T.Block("main"); + var outShape = GetShape(func.Body); + var outTile = GetTile(func.Body); + var nestBuilder = T.Grid(out var loopVars, out var loops, LoopMode.Serial, Enumerable.Range(0, outShape.Count).Select(i => new TIR.Range(0, outShape[i], outTile[i])).ToArray()); + + AffineMap[] bodyinputMaps; + using (new TileScope( + new TileScope.PushMemoryFrame( + new Dictionary(ReferenceEqualityComparer.Instance) + { + // { func.Parameters[0], inBuffer }, { func.Body, outBuffer }, + }, + fusionBlock, + loops, + loopVars))) + { + fusionBlock.Body( + nestBuilder.Body( + Visit(func.Body, rootMap, outRegion, out bodyinputMaps))); + } + + var seq = T.Sequential(); + + inputMaps = bodyinputMaps; + if (call is not null) + { + for (int i = 0; i < call.Arguments.Length; i++) + { + AffineMap[] inmaps = Array.Empty(); + seq.Body(Visit(call.Arguments[i], bodyinputMaps[i], outRegion, out _)); + } + } + + // 2. visit args. 
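+ // Note: the argument visits above are emitted into `seq` ahead of the fusion
+ // block, so producer tiles are materialized before the block executes.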
+ return seq.Body(fusionBlock); + } + + private TIR.Range[] ComputeRanges(IReadOnlyList tiles, AffineMap rootMap) + { + var starts = rootMap.Apply(TileScope.CurrentLoopVars.ToArray()); + return starts.Zip(tiles).Select(p => new TIR.Range(p.First, p.First + p.Second, 1)).ToArray(); + } + + private Expr[] ComputeIndcies(TIR.Buffer top, Expr[] loopvars, AffineMap rootMap) + { + var topLevel = top.MemSpan.Location switch + { + MemoryLocation.Input or MemoryLocation.Output or MemoryLocation.Rdata => 0, + MemoryLocation.L2Data => 1, + _ => throw new InvalidDataException(), + }; + + var newLoopvars = loopvars.ToArray(); + + for (int level = TileScope.LoopVarStack.Count - 1; level >= topLevel; level--) + { + var mappedVars = rootMap.Apply(TileScope.LoopVarStack[level].ToArray()); + System.Diagnostics.Trace.Assert(mappedVars.Length == newLoopvars.Length); + + for (int i = 0; i < newLoopvars.Length; i++) + { + newLoopvars[i] += mappedVars[i]; + } + } + + return newLoopvars; + } + + private IReadOnlyList GetTile(Expr expr) => _tileMemo[expr].TileShape; + + private IReadOnlyList GetShape(Expr expr) => _tileMemo[expr].OutShape; + + private BufferRegion GetBufferRegion(Expr expr, Func createFunc) + { + var buf = _tileMemo[expr].Buffer; + if (!_regionMemo.TryGetValue(expr, out var region)) + { + region = createFunc(buf); + _regionMemo.Add(expr, region); + } + + return region; + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/FusionChecker.cs b/modules/Nncase.Modules.CPU/Passes/Tile/FusionChecker.cs new file mode 100644 index 0000000000..2482e0821b --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Tile/FusionChecker.cs @@ -0,0 +1,660 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Collections; +using System.Reactive; +using DryIoc; +using NetFabric.Hyperlinq; +using Nncase.Evaluator.Tensors; +using Nncase.IR; +using Nncase.IR.Math; +using Nncase.IR.Tensors; +using Nncase.Passes.BufferSchedule; +using Nncase.TIR; +using Nncase.Utilities; + +namespace Nncase.Passes.Tile; + +public enum ConditionKind +{ + Norm, + Tail, +} + +public record BucketCondition(ConditionKind Bid, ConditionKind Tid, ConditionKind BidTid) +{ +} + +public sealed class NodeInfo : IDisposable +{ + private readonly ExprPinner _pinner; + private readonly TIR.Buffer? _buffer; + + public NodeInfo(TIR.Buffer? buffer, int[] tileShape, int[] outShape) + { + _buffer = buffer; + TileShape = tileShape; + OutShape = outShape; + if (_buffer is not null) + { + _pinner = new ExprPinner(_buffer); + } + else + { + _pinner = new ExprPinner(); + } + } + + public TIR.Buffer Buffer => _buffer!; + + public IReadOnlyList OutShape { get; } + + public int[] TileShape { get; set; } + + public void Dispose() => _pinner.Dispose(); +} + +internal sealed record TileFragment(BucketCondition Condition, IReadOnlyDictionary TileMap) +{ +} + +internal sealed class FusionChecker +{ + private readonly List> _initTileList; + private IReadOnlyList? 
_checkedResult; + + public FusionChecker(List> initTileList) + { + _initTileList = initTileList; + } + + public IReadOnlyList CheckedResult => _checkedResult!; + + public IReadOnlyList Check(Expr root) + { + if (_checkedResult is not null) + { + return _checkedResult; + } + + var (buckets, conditions) = GetSplitBuckets(); + var tileMaps = new Dictionary[buckets.Count]; + + for (var b = 0; b < buckets.Count; b++) + { + var bucket = buckets[b]; + Dictionary tileMap = new(); + + var updatedTileShape = _initTileList.Last().Value.ToArray(); + if (_initTileList.Any(kv => kv.Key is Call { Target: MatMul })) + { + var candidateKs = GetCandidateKs(bucket); + + // search k first + int finalK = 0; + for (var k = 0; k < candidateKs.Count; k++) + { + tileMap.Clear(); + tileMap.Add(root, new(null!, updatedTileShape, bucket[root])); + Visit((Call)root, tileMap, bucket, candidateKs, k); + var ok = TryAllocate(tileMap, bucket); + if (ok) + { + tileMaps[b] = tileMap.ToDictionary(kv => kv.Key, kv => kv.Value); + finalK = k; + } + else + { + break; + } + } + + for (var r = root.CheckedShape.Rank - 1; r >= 0; r--) + { + if (_initTileList.Last().Value[r] == 32) + { + tileMap.Clear(); + while (true) + { + tileMap!.Add(root, new NodeInfo(null!, updatedTileShape, bucket[root])); + Visit((Call)root, tileMap, bucket, candidateKs, finalK); + var ok = TryAllocate(tileMap, bucket); + if (ok) + { + tileMaps[b] = tileMap.ToDictionary(kv => kv.Key, kv => kv.Value); + if (updatedTileShape[r] + 32 > bucket[root][r]) + { + break; + } + + updatedTileShape[r] += 32; + } + else + { + updatedTileShape[r] -= 32; + break; + } + + tileMap.Clear(); + } + } + } + } + else + { + for (var r = root.CheckedShape.Rank - 1; r >= 0; r--) + { + var incr = r == root.CheckedShape.Rank - 1 ? 32 : 1; + tileMap.Clear(); + while (true) + { + tileMap.Add(root, new(null!, updatedTileShape, bucket[root])); + Visit((Call)root, tileMap, bucket, new()); + var ok = TryAllocate(tileMap, bucket); + if (ok) + { + tileMaps[b] = tileMap.ToDictionary(kv => kv.Key, kv => kv.Value); + if (updatedTileShape[r] + incr > bucket[root][r]) + { + break; + } + + updatedTileShape[r] += incr; + } + else + { + updatedTileShape[r] -= incr; + break; + } + + tileMap.Clear(); + } + } + } + } + + for (int b = 0; b < buckets.Count; b++) + { + TryAllocate(tileMaps[b], buckets[b], true); + } + + return _checkedResult = conditions.Zip(tileMaps).Select(p => new TileFragment(p.First, p.Second)).ToList(); + } + + private static List> GetCandidateKs(Dictionary bucket) + { + var allKs = new Dictionary>(); + foreach (var kv in bucket) + { + if (kv.Key is Call { Target: MatMul op } call) + { + var k = bucket[call[op.Parameters.First()]].Last(); + var ks = new List(); + for (int i = 32; i < k; i += 32) + { + ks.Add(i); + } + + ks.Add(k); + allKs.Add(kv.Key, ks); + } + } + + IEnumerable>> ret = new[] { Enumerable.Empty>() }; + foreach (var kvp in allKs) + { + ret = from seq in ret + from item in kvp.Value + select seq.Concat(new[] { new KeyValuePair(kvp.Key, item) }); + } + + return ret.Select(seq => seq.ToDictionary(kv => kv.Key, kv => kv.Value)).ToList(); + } + + private (List> Buckets, List Conditions) GetSplitBuckets() + { + var buckets = new Dictionary>(); + foreach (var s in GetCandidateBuckets()) + { + buckets.Add(s, new()); + } + + foreach (var kv in _initTileList) + { + var ndSbp = ((DistributedType)kv.Key.CheckedType).NdSBP; + var hierarchy = ((DistributedType)kv.Key.CheckedType).Placement.Hierarchy; + var divided = Enumerable.Range(0, ndSbp.Count).Where(i => ndSbp[i] is 
SBPSplit).Select(i => (((SBPSplit)ndSbp[i]).Axis, hierarchy[i])).ToArray(); + var dividedSlice = DistributedUtility.TryGetNonUniformDividedSlice((DistributedType)kv.Key.CheckedType); + if (dividedSlice.Count == 1) + { + foreach (BucketCondition s in GetCandidateBuckets()) + { + buckets[s].Add(kv.Key, dividedSlice[0]); + } + } + else + { + switch (divided.Length) + { + case 1 when hierarchy[0] == divided[0].Item2: + foreach (BucketCondition s in Enum.GetValues(typeof(BucketCondition))) + { + if (s is BucketCondition { Bid: ConditionKind.Norm }) + { + buckets[s].Add(kv.Key, dividedSlice[0]); + } + else + { + buckets[s].Add(kv.Key, dividedSlice[1]); + } + } + + break; + case 1 when hierarchy[1] == divided[0].Item2: + foreach (BucketCondition s in GetCandidateBuckets()) + { + if (s is BucketCondition { Tid: ConditionKind.Norm }) + { + buckets[s].Add(kv.Key, dividedSlice[0]); + } + else + { + buckets[s].Add(kv.Key, dividedSlice[1]); + } + } + + break; + case 2 when divided[0].Axis == divided[1].Axis: + foreach (BucketCondition s in GetCandidateBuckets()) + { + if (s is BucketCondition { BidTid: ConditionKind.Norm }) + { + buckets[s].Add(kv.Key, dividedSlice[0]); + } + else + { + buckets[s].Add(kv.Key, dividedSlice[1]); + } + } + + break; + case 2 when divided[0].Axis != divided[1].Axis: + if (dividedSlice.Count == 2) + { + if (kv.Key.CheckedShape[divided[0].Axis].FixedValue % hierarchy[0] == 0) + { + foreach (BucketCondition s in GetCandidateBuckets()) + { + if (s is BucketCondition { Tid: ConditionKind.Norm }) + { + buckets[s].Add(kv.Key, dividedSlice[0]); + } + else + { + buckets[s].Add(kv.Key, dividedSlice[1]); + } + } + } + else + { + foreach (BucketCondition s in GetCandidateBuckets()) + { + if (s is BucketCondition { BidTid: ConditionKind.Norm }) + { + buckets[s].Add(kv.Key, dividedSlice[0]); + } + else + { + buckets[s].Add(kv.Key, dividedSlice[1]); + } + } + } + } + + if (dividedSlice.Count == 4) + { + foreach (BucketCondition s in GetCandidateBuckets()) + { + if (s is BucketCondition { Bid: ConditionKind.Norm, Tid: ConditionKind.Norm }) + { + buckets[s].Add(kv.Key, dividedSlice[0]); + } + else if (s is BucketCondition { Bid: ConditionKind.Norm, Tid: ConditionKind.Tail }) + { + buckets[s].Add(kv.Key, dividedSlice[1]); + } + else if (s is BucketCondition { Bid: ConditionKind.Tail, Tid: ConditionKind.Norm }) + { + buckets[s].Add(kv.Key, dividedSlice[2]); + } + else + { + buckets[s].Add(kv.Key, dividedSlice[3]); + } + } + } + + break; + default: + throw new NotImplementedException("Not support split"); + } + } + } + + List> ret = new(); + List conditions = new(); + foreach (BucketCondition s in GetCandidateBuckets()) + { + var bucket = buckets[s]; + bool redundant = false; + foreach (var b in ret) + { + if (bucket.All(kv => kv.Value.SequenceEqual(b[kv.Key]))) + { + redundant = true; + } + + if (redundant) + { + break; + } + } + + if (!redundant) + { + conditions.Add(s); + ret.Add(bucket); + } + } + + return (ret, conditions); + } + + private IEnumerable GetCandidateBuckets() => + new[] { + new[] { ConditionKind.Norm, ConditionKind.Tail }, + new[] { ConditionKind.Norm, ConditionKind.Tail }, + new[] { ConditionKind.Norm, ConditionKind.Tail }, + }.CartesianProduct(). + Select(p => p.ToArray()). 
+ Select(a => new BucketCondition(a[0], a[1], a[2])); + + private bool TryAllocate(Dictionary tileMap, Dictionary bucket, bool finalAllocate = false) + { + var tileList = new List>(); + var exprs = ExprCollector.Collect(_initTileList.Last().Key).Where(e => e is not Op); + foreach (var expr in exprs) + { + tileList.Add(new(expr, tileMap[expr])); + } + + var tileBuffer = TryAllocate(tileList, bucket, finalAllocate); + if (tileBuffer.Count > 0) + { + foreach (var kv in tileBuffer) + { + tileMap[kv.Key] = new NodeInfo(kv.Value, tileMap[kv.Key].TileShape, tileMap[kv.Key].OutShape.ToArray()); + } + + return true; + } + + return false; + } + + private Dictionary TryAllocate(List> tileList, Dictionary bucket, bool finalAllocate = false) + { + // TODO: + // 1. 支持不同数据类型的检查 + // 2. 支持weights和数据采用不一样的buffer,可以考虑按pass load weights + // 3. 支持不同层的weights复用或者不复用等 + // 4. 支持线程数可配 + // 5. 如果切K,partial sum 要考虑扩大尺寸 + // 6. cache search的结果,返回时直接输出最终的buffer + Dictionary lifenessMap = new(); + + void UpdateLifeness(int start, Expr expr, TIR.Buffer buffer, bool updateEnd) + { + lifenessMap.Add(expr, new ScheduledBuffer(new Lifeness(start, int.MaxValue), buffer)); + if (updateEnd) + { + foreach (var operand in expr.Operands.ToArray().Where(e => e is not Op)) + { + var userList = operand.Users.Where(u => u is Call).ToList(); + if (userList.All(u => lifenessMap.ContainsKey(u))) + { + lifenessMap[operand].Lifeness.End = start + 1; + } + } + } + } + + foreach (var (kv, i) in tileList.Select((kv, i) => (kv, i))) + { + var shape = kv.Value.TileShape; + var strides = TensorUtilities.GetStrides(shape); + var dtype = kv.Key.CheckedType switch + { + DistributedType d => d.TensorType.DType, + TensorType te => te.DType, + _ => throw new NotSupportedException("Not support type"), + }; + + var location = kv.Key switch + { + TensorConst { ValueType: DistributedType } => MemoryLocation.Rdata, + Var => MemoryLocation.Input, + Call { Target: IR.CPU.Store } => MemoryLocation.Output, + _ => MemoryLocation.L2Data, + }; + + var bfname = kv.Key switch + { + Call c => c.Target.GetType().ToString().Split(".")[^1], + Var v => v.Name, + Const c => "cons", + _ => throw new NotSupportedException(), + } + + + i.ToString(); + Expr start = location switch + { + MemoryLocation.L2Data => IR.None.Default, + MemoryLocation.Rdata => IR.F.Buffer.DDrOf(kv.Key), + _ => TIR.F.CPU.PtrOf(bfname, kv.Key.CheckedDataType), + }; + + if (location is MemoryLocation.Input or MemoryLocation.Output) + { + shape = bucket[kv.Key]; + strides = TensorUtilities.GetStrides(shape); + } + + Expr size; + if (shape.Length == 0) + { + size = dtype.SizeInBytes; + } + else + { + size = shape[0] * strides[0] * dtype.SizeInBytes; + } + + var memSpan = new MemSpan(start, size, location); + var buffer = new TIR.Buffer(bfname, dtype, memSpan, shape.Select(s => (Expr)s).ToArray(), strides.Select(s => (Expr)s).ToArray()); + UpdateLifeness(i, kv.Key, buffer, location == MemoryLocation.L2Data); + } + + foreach (var kv in lifenessMap) + { + if (kv.Value.Lifeness.End == int.MaxValue) + { + kv.Value.Lifeness.End = kv.Value.Lifeness.Start + 2; + } + } + + bool ok = SchedulerSolver.ScheduleByCpModel(lifenessMap, true, 1f, out var scheduledBufferMap); + var ret = new Dictionary(); + if (ok) + { + foreach (var (key, candidateSched) in lifenessMap) + { + if (scheduledBufferMap.TryGetValue(key, out var schedBuffer)) + { + ret.Add(key, schedBuffer.Buffer); + } + else + { + ret.Add(key, candidateSched.Buffer); + } + } + + if (finalAllocate && 
Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Rewrite)) + { + var scheduleResponse = new ScheduledResponse(scheduledBufferMap, ok); + scheduleResponse.Dump("buffers", "auto"); + } + } + + return ret; + } + + private void Visit(Call expr, Dictionary tileMap, Dictionary bucketMap, List> candidateKs, int k = -1) + { + switch (expr.Target) + { + case IR.Math.MatMul op: + VisitMatmul(op, expr, tileMap, bucketMap, candidateKs, k); + break; + case IR.Math.Unary or IR.CPU.Load or IR.CPU.Store: + VisitIdenity(expr, tileMap, bucketMap, candidateKs, k); + break; + case IR.Math.Binary op: + VisitBinary(op, expr, tileMap, bucketMap, candidateKs, k); + break; + default: + throw new NotImplementedException("Not Implemented Op: " + expr.Target); + } + } + + private void VisitIdenity(Call call, Dictionary tileMap, Dictionary bucketMap, List> candidateKs, int k = -1) + { + var inTileShape = tileMap[call].TileShape; + var input = call.Arguments[0]; + if (input is Var or TensorConst) + { + tileMap.Add(input, new(null!, inTileShape, bucketMap[input])); + } + else + { + if (tileMap.ContainsKey(input)) + { + tileMap[input].TileShape = inTileShape.Select((s, i) => Math.Max(s, tileMap[input].TileShape[i])).ToArray(); + } + else + { + tileMap.Add(input, new(null!, inTileShape, bucketMap[input])); + } + + Visit((Call)input, tileMap, bucketMap, candidateKs, k); + } + } + + private void VisitMatmul(IR.Math.MatMul op, Call call, Dictionary tileMap, Dictionary bucketMap, List> candidateKs, int k) + { + var lhs = call.Arguments[0]; + var rhs = call.Arguments[1]; + + var outTileShape = tileMap[call].TileShape; + var inTileShapeA = Enumerable.Repeat(1, lhs.CheckedShape.Rank).ToArray(); + inTileShapeA[^2] = outTileShape[^2]; + inTileShapeA[^1] = candidateKs[k][call]; + var inTileShapeB = Enumerable.Repeat(1, rhs.CheckedShape.Rank).ToArray(); + inTileShapeB[^2] = candidateKs[k][call]; + inTileShapeB[^1] = outTileShape[^1]; + + if (!(lhs is Var or TensorConst)) + { + if (tileMap.ContainsKey(lhs)) + { + tileMap[lhs].TileShape = inTileShapeA.Select((s, i) => Math.Max(s, tileMap[lhs].TileShape[i])).ToArray(); + } + else + { + tileMap.Add(lhs, new(null!, inTileShapeA, bucketMap[lhs])); + } + + Visit((Call)lhs, tileMap, bucketMap, candidateKs, k); + } + else + { + tileMap.Add(lhs, new(null!, inTileShapeA, bucketMap[lhs])); + } + + if (!(rhs is Var or TensorConst)) + { + if (tileMap.ContainsKey(rhs)) + { + tileMap[rhs].TileShape = inTileShapeB.Select((s, i) => Math.Max(s, tileMap[rhs].TileShape[i])).ToArray(); + } + else + { + tileMap.Add(rhs, new(null!, inTileShapeB, bucketMap[rhs])); + } + + Visit((Call)rhs, tileMap, bucketMap, candidateKs, k); + } + else + { + tileMap.Add(rhs, new(null!, inTileShapeB, bucketMap[rhs])); + } + } + + private void VisitBinary(IR.Math.Binary op, Call call, Dictionary tileMap, Dictionary bucketMap, List> candidateKs, int k) + { + var lhs = call.Arguments[0]; + var rhs = call.Arguments[1]; + + var outTileShape = tileMap[call].TileShape; + var padLhs = outTileShape.Length - lhs.CheckedShape.Rank; + var inTileShapeA = Enumerable.Range(0, lhs.CheckedShape.Rank).Select(i => lhs.CheckedShape[i].FixedValue == 1 ? 1 : outTileShape[i + padLhs]).ToArray(); + var padRhs = outTileShape.Length - rhs.CheckedShape.Rank; + var inTileShapeB = Enumerable.Range(0, rhs.CheckedShape.Rank).Select(i => rhs.CheckedShape[i].FixedValue == 1 ? 
1 : outTileShape[i + padRhs]).ToArray(); + + if (!(lhs is Var or TensorConst)) + { + if (tileMap.ContainsKey(lhs)) + { + tileMap[lhs].TileShape = inTileShapeA.Select((s, i) => Math.Max(s, tileMap[lhs].TileShape[i])).ToArray(); + } + else + { + tileMap.Add(lhs, new(null!, inTileShapeA, bucketMap[lhs])); + } + + Visit((Call)lhs, tileMap, bucketMap, candidateKs, k); + } + else + { + tileMap.Add(lhs, new(null!, inTileShapeA, bucketMap[lhs])); + } + + if (!(rhs is Var or TensorConst)) + { + if (tileMap.ContainsKey(rhs)) + { + tileMap[rhs].TileShape = inTileShapeB.Select((s, i) => Math.Max(s, tileMap[rhs].TileShape[i])).ToArray(); + } + else + { + tileMap.Add(rhs, new(null!, inTileShapeB, bucketMap[rhs])); + } + + Visit((Call)rhs, tileMap, bucketMap, candidateKs, k); + } + else + { + tileMap.Add(rhs, new(null!, inTileShapeB, bucketMap[rhs])); + } + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/KernelToTIRVisitor.cs b/modules/Nncase.Modules.CPU/Passes/Tile/KernelToTIRVisitor.cs new file mode 100644 index 0000000000..e5ff6ca849 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Tile/KernelToTIRVisitor.cs @@ -0,0 +1,444 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Reactive; +using NetFabric.Hyperlinq; +using Nncase.IR; +using Nncase.IR.CPU; +using Nncase.IR.Imaging; +using Nncase.IR.Math; +using Nncase.IR.NN; +using Nncase.IR.Tensors; +using Nncase.TIR; +using Nncase.Utilities; +using Buffer = Nncase.TIR.Buffer; + +namespace Nncase.Passes.Tile; + +internal sealed class KernelToTIRVisitor : ExprVisitor +{ + private readonly Dictionary _buffersMap = new(ReferenceEqualityComparer.Instance); + private readonly List _mainBody; + private readonly HashSet _devices; + private readonly List<(int, TIR.Buffer)> _outputbuffers; + private readonly Dictionary _fusionCheckCache; + + public KernelToTIRVisitor(List mainBody, HashSet devices, Dictionary fusionCheckCache) + { + _mainBody = mainBody; + _devices = devices; + _outputbuffers = new(); + _fusionCheckCache = fusionCheckCache; + VisitRootFusion = null!; + DataUsage = 0; + MaxDTypeSize = 0; + } + + public ulong DataUsage { get; private set; } + + public ulong MaxDTypeSize { get; private set; } + + public Fusion VisitRootFusion { get; private set; } + + public IEnumerable OutputBuffers => _outputbuffers.OrderBy(p => p.Item1).Select(p => p.Item2); + + public IEnumerable InputBuffers => VisitRootFusion.Parameters.ToArray().Select(p => _buffersMap[p]).OfType().Where(b => b.MemSpan.Location.HasFlag(MemoryLocation.Input)); + + public void Convert(Fusion post) + { + VisitRootFusion = post; + AllocBuffers(post); + Visit(post); + } + + protected override Unit DefaultVisitLeaf(Expr expr) + { + return default; + } + + protected override Unit VisitLeafCall(Call expr) + { + var arguments = expr.Arguments.AsValueEnumerable().Select(GetBuffer).ToArray(); + var ret = GetBuffer(expr); + var op = expr.Target is IR.CPU.CPUKernelOp kop ? 
kop.Target : expr.Target; + switch (op) + { + case Fusion deviceFunc: + { + var r = new DeviceFusionToPrimFuncRewriter(_fusionCheckCache); + var post = (TIR.PrimFunction)r.Rewrite(deviceFunc); + _devices.Add(post); + _mainBody.Add(new Call(post, arguments.Concat(new[] { ret }).ToArray())); + } + + break; + case IR.Math.Unary unary: + GenerateUnary(unary.UnaryOp, arguments, ret); + break; + case IR.CPU.Boxing boxing: + GenerateBoxing(boxing, arguments, ret, expr); + break; + case Binary binary: + GenerateBinary(binary, arguments, ret, expr); + break; + case IR.CPU.Pack pack: + _mainBody.Add(TIR.F.CPU.Pack(arguments[0], ret, pack.Lanes, pack.Axes)); + break; + case IR.CPU.Unpack unpack: + _mainBody.Add(TIR.F.CPU.Unpack(arguments[0], ret, unpack.Axes)); + break; + case IR.CPU.PackedBinary packed_binary: + // _mainBody.Add(TIR.F.CPU.Binary(arguments[0], arguments[1], ret, packed_binary.BinaryOp, packed_binary.LhsPackedAxes, packed_binary.LhsPadedNums, packed_binary.RhsPackedAxes, packed_binary.RhsPadedNums)); + _mainBody.Add(TIR.F.CPU.Binary(packed_binary.BinaryOp, arguments[0], arguments[1], ret)); + break; + case IR.CPU.PackedMatMul packed_mat_mul: + _mainBody.Add(TIR.F.CPU.PackedMatMul(arguments[0], arguments[1], ret, packed_mat_mul.LhsPackedAxes, packed_mat_mul.LhsPadedNums, packed_mat_mul.RhsPackedAxes, packed_mat_mul.RhsPadedNums)); + break; + case IR.Math.MatMul matmul: + _mainBody.Add(TIR.F.CPU.Matmul(arguments[0], arguments[1], ret)); + break; + case IR.CPU.PackedSoftmax packed_softmax: + _mainBody.Add(TIR.F.CPU.PackedSoftmax(arguments[0], ret, packed_softmax.Axis, packed_softmax.PackedAxes)); + break; + case IR.NN.Softmax softmax: + _mainBody.Add(TIR.F.CPU.PackedSoftmax(arguments[0], ret, ((TensorConst)expr.Arguments[1]).Value.ToScalar(), Array.Empty())); + break; + case IR.CPU.PackedTranspose packed_transpose: + // _mainBody.Add(TIR.F.CPU.PackedTranspose(arguments[0], arguments[1], ret, packed_transpose.PackedAxes)); + _mainBody.Add(TIR.F.CPU.PackedTranspose(arguments[0], ret, ((TensorConst)expr.Arguments[1]).Value.ToArray(), packed_transpose.PackedAxes)); + break; + case IR.CPU.PackedLayerNorm packed_layer_norm: + _mainBody.Add(TIR.F.CPU.PackedLayerNorm(arguments[0], arguments[1], arguments[2], ret, packed_layer_norm.Axis, packed_layer_norm.Epsilon, packed_layer_norm.UseMean, packed_layer_norm.PackedAxes, packed_layer_norm.PadedNums)); + break; + case IR.NN.LayerNorm layernorm: + _mainBody.Add(TIR.F.CPU.PackedLayerNorm(arguments[0], arguments[1], arguments[2], ret, layernorm.Axis, layernorm.Epsilon, layernorm.UseMean, Array.Empty(), Array.Empty())); + break; + case IR.Tensors.Unsqueeze unsqueeze: + _mainBody.Add(TIR.F.CPU.Reshape(arguments[0], ret, expr.CheckedShape.ToValueArray())); + break; + case IR.Tensors.Reshape reshape: + _mainBody.Add(TIR.F.CPU.Reshape(arguments[0], ret, expr.CheckedShape.ToValueArray())); + break; + case IR.Tensors.Slice slice: + _mainBody.Add(TIR.F.CPU.Slice(arguments[0], ret, ((TensorConst)expr.Arguments[1]).Value.ToArray(), ((TensorConst)expr.Arguments[2]).Value.ToArray(), ((TensorConst)expr.Arguments[3]).Value.ToArray(), ((TensorConst)expr.Arguments[4]).Value.ToArray())); + break; + case IR.Tensors.Concat concat: + _mainBody.Add(TIR.F.CPU.Concat(((IR.Tuple)expr.Arguments[0]).Fields.AsValueEnumerable().Select(GetBuffer).ToArray(), ret, concat.Axis)); + break; + case IR.Tensors.Transpose trans: + _mainBody.Add(TIR.F.CPU.Transpose(arguments[0], ret, ((TensorConst)expr.Arguments[1]).Value.ToArray())); + break; + case IR.NN.Swish swish: + 
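+ // Swish(x) = x * sigmoid(beta * x); the scalar beta is folded from the
+ // second argument.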
_mainBody.Add(TIR.F.CPU.Swish(arguments[0], ret, ((TensorConst)expr.Arguments[1]).Value.ToScalar())); + break; + case IR.Tensors.Gather gather: + _mainBody.Add(TIR.F.CPU.Gather(arguments[0], arguments[1], ret, gather.Axis)); + break; + case IR.NN.Pad pad: + _mainBody.Add(TIR.F.CPU.Pad(arguments[0], ret, ((TensorConst)expr.Arguments[1]).Value.ToArray(), ((TensorConst)expr.Arguments[2]).Value.ToScalar())); + break; +#if false + case MatMul matmul: + GenerateMatmul(matmul, arguments, ret); + break; + case LayerNorm layernorm: + GenerateLayerNorm(layernorm, arguments, ret, (DistributedType)expr.Arguments[0].CheckedType); + break; + case InstanceNormalization instnorm: + GenerateInstanceNorm(instnorm, ((TensorConst)expr.Arguments[3]).Value.ToScalar(), arguments, ret, (DistributedType)expr.Arguments[0].CheckedType); + break; + case Gather gather: + GenerateGather(gather, arguments, ret); + break; + case Concat concat: + GenerateConcat(concat, ((IR.Tuple)expr.Arguments[0]).Fields.AsValueEnumerable().Select(AllocOrGetBuffer).ToArray(), ret); + break; + case Slice slice: + GenerateSlice(slice, arguments[0], ret, expr.Arguments[1], expr.Arguments[2], expr.Arguments[3], (DistributedType)expr.CheckedType); + break; + case Softmax softmax: + GenerateSoftmax(softmax, ((TensorConst)expr.Arguments[1]).Value.ToScalar(), arguments, ret, (DistributedType)expr.CheckedType); + break; + case Transpose transpose: + GenerateTranspose(transpose, ((TensorConst)expr.Arguments[1]).Value.ToArray(), arguments, ret); + break; + case Reshape or Unsqueeze: + GenerateReshape(arguments[0], ret); + break; + case Swish: + GenerateSwishB(arguments[0], ret, ((TensorConst)expr.Arguments[1]).Value.ToScalar()); + break; + case Gelu: + GenerateUnary("gelu", arguments, ret); + break; + case Conv2D conv: + GenerateConv2D(conv, arguments, ret, ((TensorConst)expr.Arguments[3]).Value.ToArray(), ((TensorConst)expr.Arguments[4]).Value.ToArray(), ((TensorConst)expr.Arguments[5]).Value.ToArray(), ((TensorConst)expr.Arguments[6]).Value.ToScalar(), (TensorConst)expr.Arguments[7], (DistributedType)expr.CheckedType); + break; + case ReduceArg reduceArg: + GenerateReduceArg(reduceArg, arguments, ret, ((TensorConst)expr.Arguments[1]).Value.ToScalar(), ((TensorConst)expr.Arguments[2]).Value.ToScalar(), ((TensorConst)expr.Arguments[3]).Value.ToScalar(), reduceArg.ReduceArgOp, reduceArg.DestType); + break; + case ResizeImage resize: + float[] roi = expr.Arguments[1] is TensorConst tc ? tc.Value.ToArray() : new[] { 0f, 0f, 1f, 1f }; + int[] newSize = ((TensorConst)expr.Arguments[2]).Value.ToArray(); + float cubicCoeffA = expr.Arguments[3] is TensorConst tc1 ? tc1.Value.ToScalar() : -0.75f; + int excludeOutside = expr.Arguments[4] is TensorConst tc2 ? tc2.Value.ToScalar() : 0; + float extrapolationValue = expr.Arguments[5] is TensorConst tc3 ? 
tc3.Value.ToScalar() : 0f; + GenerateResize(resize, arguments, ret, roi, newSize, cubicCoeffA, excludeOutside, extrapolationValue, (DistributedType)expr.CheckedType); + break; + case Cast cast: + GenerateCast(cast.NewType, cast.CastMode, arguments, ret); + break; + case Expand expand: + GenerateExpand(((TensorConst)expr.Arguments[1]).Value.ToArray(), (DistributedType)expr.CheckedType, arguments, ret); + break; + case Clamp clamp: + GenerateClamp(arguments, ret, ((TensorConst)expr.Arguments[1]).Value.ToArray()[0], ((TensorConst)expr.Arguments[2]).Value.ToArray()[0]); + break; + case Where where: + GenerateWhere(arguments, ret, (DistributedType)expr.CheckedType); + break; +#endif + default: + throw new NotSupportedException(); + } + + return default; + } + + private TIR.Buffer GetBuffer(Expr expr) => _buffersMap.GetValueOrDefault(expr, null!); + + private void AllocBuffers(Fusion fusion) + { + var candidates = ExprCollector.Collect(fusion).Where(e => e is Call or Var or TensorConst); + MaxDTypeSize = (ulong)candidates.Select(e => e.CheckedDataType.SizeInBytes).Max(); + foreach (var expr in candidates) + { + var name = $"buffer_{_buffersMap.Keys.Count}"; + if (!_buffersMap.TryGetValue(expr, out var buffer)) + { + switch (expr) + { + case Call c: + var loc = MemoryLocation.Data; + var hierarchy = 0; + var index = CheckRootCall(c, ref loc); + if (c.Target is Boxing box && box.NewType is DistributedType d && !d.TensorType.Shape.Equals(c.Arguments[0].CheckedShape)) + { + name += "_reshape"; + } + + TensorType? dividedType = null; + if (c.CheckedType is TensorType tensorType) + { + dividedType = tensorType; + } + else if (c.CheckedType is DistributedType distributedType) + { + hierarchy = 1; + if (DistributedUtility.TryGetDividedTensorType(distributedType, out var type)) + { + dividedType = type; + } + } + + if (dividedType is TensorType) + { + T.AttachBuffer(Tensor.FromPointer(DataUsage, dividedType.DType), dividedType, loc, hierarchy, out buffer, name); + DataUsage += (ulong)(dividedType.Shape.Size * dividedType.DType.SizeInBytes); + DataUsage = MathUtility.AlignUp(DataUsage, MaxDTypeSize); + } + else if (c.CheckedType is DistributedType) + { + // deal the not uinform sbp. 
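+ // Supporting a non-uniform SBP would require one buffer shape per shard
+ // (see the commented-out sketch below); for now this path throws
+ // NotSupportedException.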
+    private TIR.Buffer GetBuffer(Expr expr) => _buffersMap.GetValueOrDefault(expr, null!);
+
+    private void AllocBuffers(Fusion fusion)
+    {
+        var candidates = ExprCollector.Collect(fusion).Where(e => e is Call or Var or TensorConst);
+        MaxDTypeSize = (ulong)candidates.Select(e => e.CheckedDataType.SizeInBytes).Max();
+        foreach (var expr in candidates)
+        {
+            var name = $"buffer_{_buffersMap.Keys.Count}";
+            if (!_buffersMap.TryGetValue(expr, out var buffer))
+            {
+                switch (expr)
+                {
+                    case Call c:
+                        var loc = MemoryLocation.Data;
+                        var hierarchy = 0;
+                        var index = CheckRootCall(c, ref loc);
+                        if (c.Target is Boxing box && box.NewType is DistributedType d && !d.TensorType.Shape.Equals(c.Arguments[0].CheckedShape))
+                        {
+                            name += "_reshape";
+                        }
+
+                        TensorType? dividedType = null;
+                        if (c.CheckedType is TensorType tensorType)
+                        {
+                            dividedType = tensorType;
+                        }
+                        else if (c.CheckedType is DistributedType distributedType)
+                        {
+                            hierarchy = 1;
+                            if (DistributedUtility.TryGetDividedTensorType(distributedType, out var type))
+                            {
+                                dividedType = type;
+                            }
+                        }
+
+                        if (dividedType is TensorType)
+                        {
+                            T.AttachBuffer(Tensor.FromPointer(DataUsage, dividedType.DType), dividedType, loc, hierarchy, out buffer, name);
+                            DataUsage += (ulong)(dividedType.Shape.Size * dividedType.DType.SizeInBytes);
+                            DataUsage = MathUtility.AlignUp(DataUsage, MaxDTypeSize);
+                        }
+                        else if (c.CheckedType is DistributedType)
+                        {
+                            // Deal with the non-uniform SBP.
+                            // var shape = DistributedUtility.TryGetNonUniformDividedShape(distributedType);
+                            // var @var = new Var(TensorType.Pointer(distributedType.TensorType.DType));
+                            // var strides = TensorUtilities.GetStrides(shape);
+                            // var size = TensorUtilities.GetProduct(shape) * distributedType.TensorType.DType.SizeInBytes;
+                            // buffer = new Buffer(name, distributedType.TensorType.DType, new MemSpan(@var, size, loc, hierarchy), shape, strides);
+                            throw new NotSupportedException("Non-uniform SBP is not supported.");
+                        }
+                        else
+                        {
+                            throw new NotSupportedException();
+                        }
+
+                        if (index != -1)
+                        {
+                            _outputbuffers.Add((index, buffer));
+                        }
+
+                        break;
+                    case Var v:
+                        buffer = T.AttachBuffer((TensorType)v.CheckedType, MemoryLocation.Input, 0, out _, out _, name);
+                        break;
+                    case TensorConst c:
+                        buffer = T.AttachBuffer(c, out _, name);
+                        break;
+                    default:
+                        throw new NotSupportedException();
+                }
+
+                _buffersMap.Add(expr, buffer);
+            }
+        }
+    }
+
+    private void GenerateUnary(UnaryOp unaryOp, ReadOnlySpan<Buffer> arguments, Buffer ret)
+    {
+        var input = arguments[IR.Math.Unary.Input.Index];
+        _mainBody.Add(TIR.F.CPU.Unary(unaryOp, input, ret));
+    }
+
+    private void GenerateBinary(Binary binary, Buffer[] arguments, Buffer ret, Call expr)
+    {
+        _ = (DistributedType)expr.Arguments[0].CheckedType;
+        _ = (DistributedType)expr.Arguments[1].CheckedType;
+        _ = (DistributedType)expr.CheckedType;
+        _mainBody.Add(TIR.F.CPU.Binary(binary.BinaryOp, arguments[0], arguments[1], ret));
+    }
+
+    private void GenerateBoxing(IR.CPU.Boxing boxing, Buffer[] arguments, Buffer ret, Call expr)
+    {
+        switch (expr.Arguments[0].CheckedType, boxing.NewType)
+        {
+            case (TensorType, DistributedType distTensorType):
+                {
+                    _mainBody.Add(TIR.F.CPU.TensorLoad(ret, arguments[0], distTensorType.NdSBP, distTensorType.Placement));
+                }
+
+                break;
+            case (DistributedType distTensorType, TensorType):
+                {
+                    _mainBody.Add(TIR.F.CPU.TensorStore(arguments[0], ret, distTensorType.NdSBP, distTensorType.Placement));
+                }
+
+                break;
+            case (DistributedType inType, DistributedType outType):
+                {
+                    if (inType.NdSBP.Any(sbp => sbp is SBPPartialSum))
+                    {
+                        // _mainBody.Add(TIR.F.CPU.GatherReduceScatter(arguments[0], ret, inType, outType));
+                    }
+                    else
+                    {
+                        _mainBody.Add(TIR.F.CPU.TensorStore(arguments[0], None.Default, inType.NdSBP, inType.Placement));
+                        _mainBody.Add(TIR.F.CPU.TensorLoad(ret, None.Default, outType.NdSBP, outType.Placement));
+                    }
+                }
+
+                break;
+            default:
+                throw new NotSupportedException();
+        }
+    }
+
+#if false
+    private void GenerateSwishB(Buffer input, Buffer ret, float beta)
+    {
+        _mainBody.Add(TIR.F.CPU.SwishB(input, ret, beta));
+    }
+
+    private void GenerateReshape(Buffer input, Buffer ret)
+    {
+        _mainBody.Add(TIR.F.CPU.ReShape(input, ret));
+    }
+
+    private void GenerateConcat(Concat concat, Buffer[] inputs, Buffer ret)
+    {
+        _mainBody.Add(TIR.F.CPU.Concat(concat.Axis, inputs, ret));
+    }
+
+    private void GenerateSlice(Slice slice, Buffer input, Buffer output, Expr begins, Expr ends, Expr axes, DistributedType distributedType)
+    {
+        _mainBody.Add(TIR.F.CPU.Slice(input, output, begins, ends, axes, distributedType));
+    }
+
+    private void GenerateMatmul(MatMul matmul, Buffer[] arguments, Buffer ret)
+    {
+        _mainBody.Add(TIR.F.CPU.Matmul(arguments[0], arguments[1], ret));
+    }
+
+    private void GenerateLayerNorm(LayerNorm layerNorm, Buffer[] arguments, Buffer ret, DistributedType distributedType)
+    {
+        _mainBody.Add(TIR.F.CPU.LayerNorm(layerNorm.Axis, layerNorm.Epsilon, layerNorm.UseMean, arguments[0], arguments[1], arguments[2], ret, distributedType));
+    }
+
+    private void GenerateInstanceNorm(InstanceNormalization instNorm, float eps, Buffer[] arguments, Buffer ret, DistributedType distributedType)
+    {
+        _mainBody.Add(TIR.F.CPU.InstanceNorm(eps, arguments[0], arguments[1], arguments[2], ret, distributedType));
+    }
+
+    private void GenerateGather(Gather gather, Buffer[] arguments, Buffer ret)
+    {
+        _mainBody.Add(TIR.F.CPU.Gather(gather.Axis, arguments[0], arguments[1], ret));
+    }
+
+    private void GenerateSoftmax(Softmax softmax, int axis, Buffer[] arguments, Buffer ret, DistributedType distributedType)
+    {
+        _mainBody.Add(TIR.F.CPU.Softmax(axis, arguments[0], ret, distributedType));
+    }
+
+    private void GenerateTranspose(Transpose transpose, int[] perm, Buffer[] arguments, Buffer ret)
+    {
+        _mainBody.Add(TIR.F.CPU.Transpose(perm, arguments[0], ret));
+    }
+
+    private void GenerateConv2D(Conv2D conv, Buffer[] arguments, Buffer ret, int[] stride, int[] padding, int[] dilation, int groups, TensorConst fusedClamp, DistributedType distributedType)
+    {
+        _mainBody.Add(TIR.F.CPU.Conv2D(arguments[0], arguments[1], arguments[2], ret, stride, padding, dilation, groups, fusedClamp, distributedType));
+    }
+
+    private void GenerateReduceArg(ReduceArg reduceArg, Buffer[] arguments, Buffer ret, int axis, bool keepdims, bool selectLastIndex, ReduceArgOp op, DataType dataType)
+    {
+        _mainBody.Add(TIR.F.CPU.ReduceArg(arguments[0], ret, axis, keepdims, selectLastIndex, op, dataType));
+    }
+
+    private void GenerateResize(ResizeImage resize, Buffer[] arguments, Buffer ret, float[] roi, int[] newSize, float cubicCoeffA, int excludeOutside, float extrapolationValue, DistributedType distributedType)
+    {
+        _mainBody.Add(TIR.F.CPU.Resize(arguments[0], ret, roi, newSize, cubicCoeffA, excludeOutside, extrapolationValue, resize.ResizeMode, resize.TransformationMode, resize.NearestMode, resize.IsTFResize));
+    }
+
+    private void GenerateCast(DataType dataType, CastMode castMode, ReadOnlySpan<Buffer> arguments, Buffer ret)
+    {
+        _mainBody.Add(TIR.F.CPU.Cast(arguments[0], ret, dataType, castMode));
+    }
+
+    private void GenerateExpand(int[] shape, DistributedType distributedType, ReadOnlySpan<Buffer> arguments, Buffer ret)
+    {
+        _mainBody.Add(TIR.F.CPU.Expand(shape, distributedType, arguments[0], ret));
+    }
+
+    private void GenerateClamp(ReadOnlySpan<Buffer> arguments, Buffer ret, float min, float max)
+    {
+        _mainBody.Add(TIR.F.CPU.Clamp(arguments[0], ret, min, max));
+    }
+
+    private void GenerateWhere(ReadOnlySpan<Buffer> arguments, Buffer ret, DistributedType distributedType)
+    {
+        _mainBody.Add(TIR.F.CPU.Where(arguments[0], arguments[1], arguments[2], ret, distributedType));
+    }
+#endif
+
+    private int CheckRootCall(Call c, ref MemoryLocation loc)
+    {
+        var index = -1;
+        if (VisitRootFusion.Body is Call rootCall && ReferenceEquals(c, rootCall))
+        {
+            index = 0;
+            loc = MemoryLocation.Output;
+        }
+        else if (VisitRootFusion.Body is IR.Tuple tp)
+        {
+            for (int i = 0; i < tp.Fields.Length; i++)
+            {
+                if (ReferenceEquals(tp.Fields[i], c))
+                {
+                    index = i;
+                    loc = MemoryLocation.Output;
+                }
+            }
+        }
+
+        return index;
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/PrimTileVisitor.cs b/modules/Nncase.Modules.CPU/Passes/Tile/PrimTileVisitor.cs
new file mode 100644
index 0000000000..15da060551
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Tile/PrimTileVisitor.cs
@@ -0,0 +1,142 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
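+// PrimTileVisitor seeds a primitive tile shape for every expression it visits:
+// MatMul tiles the two trailing dims at 32, element-wise ops tile only the last
+// dim at 32, and a tile already recorded for a producer is merged in with an
+// element-wise Max. A worked example of the seeding rule (shapes illustrative):
+//
+//   // lhs [1, 64, 128] and rhs [1, 128, 256] both seed the tile [1, 32, 32]:
+//   var tile = Enumerable.Repeat(1, 3).ToArray();
+//   Array.Fill(tile, 32, tile.Length - 2, 2);   // tile is now [1, 32, 32]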
+
+using System.Reactive;
+using Nncase.IR;
+
+namespace Nncase.Passes.Tile;
+
+internal sealed class PrimTileVisitor : ExprVisitor<Unit, Unit>
+{
+    public PrimTileVisitor()
+    {
+        TileList = new();
+        NameList = new();
+        Count = 0;
+    }
+
+    public List<KeyValuePair<Expr, int[]>> TileList { get; }
+
+    public List<KeyValuePair<Expr, string>> NameList { get; }
+
+    public int Count { get; private set; }
+
+    protected override Unit DefaultVisitLeaf(Expr expr)
+    {
+        return Unit.Default;
+    }
+
+    protected override Unit VisitLeafCall(Call expr)
+    {
+        switch (expr.Target)
+        {
+            case IR.Math.MatMul:
+                {
+                    var lhs = expr.Arguments[0];
+                    var rhs = expr.Arguments[1];
+                    var inTileShapeA = Enumerable.Repeat(1, lhs.CheckedShape.Rank).ToArray();
+                    Array.Fill(inTileShapeA, 32, inTileShapeA.Length - 2, 2);
+                    var inTileShapeB = Enumerable.Repeat(1, rhs.CheckedShape.Rank).ToArray();
+                    Array.Fill(inTileShapeB, 32, inTileShapeB.Length - 2, 2);
+
+                    if (!(lhs is Var or TensorConst))
+                    {
+                        var oldTileAShape = TileList.Find(k => k.Key == lhs).Value;
+                        inTileShapeA = inTileShapeA.Select((s, i) => Math.Max(s, oldTileAShape[i])).ToArray();
+                    }
+                    else
+                    {
+                        TileList.Add(new(lhs, inTileShapeA));
+                        NameList.Add(new(lhs, nameof(IR.Math.MatMul) + "_" + Count.ToString() + "_lhs"));
+                    }
+
+                    if (!(rhs is Var or TensorConst))
+                    {
+                        var oldTileBShape = TileList.Find(k => k.Key == rhs).Value;
+                        inTileShapeB = inTileShapeB.Select((s, i) => Math.Max(s, oldTileBShape[i])).ToArray();
+                    }
+                    else
+                    {
+                        TileList.Add(new(rhs, inTileShapeB));
+                        NameList.Add(new(rhs, nameof(IR.Math.MatMul) + "_" + Count.ToString() + "_rhs"));
+                    }
+
+                    var outTileShape = Enumerable.Repeat(1, expr.CheckedShape.Rank).ToArray();
+                    outTileShape[^1] = inTileShapeB[^1];
+                    outTileShape[^2] = inTileShapeA[^2];
+                    TileList.Add(new(expr, outTileShape));
+                    NameList.Add(new(expr, nameof(IR.Math.MatMul) + "_" + Count.ToString()));
+                    Count++;
+                    break;
+                }
+
+            case IR.Math.Unary or IR.CPU.Store or IR.CPU.Load:
+                {
+                    var input = expr.Arguments[0];
+                    var inTileShape = Enumerable.Repeat(1, input.CheckedShape.Rank).ToArray();
+                    inTileShape[^1] = 32;
+
+                    if (!(input is Var or TensorConst))
+                    {
+                        var oldTileShape = TileList.Find(k => k.Key == input).Value;
+                        inTileShape = inTileShape.Select((s, i) => Math.Max(s, oldTileShape[i])).ToArray();
+                    }
+                    else
+                    {
+                        TileList.Add(new(input, inTileShape));
+                        NameList.Add(new(input, expr.Target.GetType().Name + "_" + Count.ToString() + "_input"));
+                    }
+
+                    var outTileShape = inTileShape;
+                    TileList.Add(new(expr, outTileShape));
+                    NameList.Add(new(expr, expr.Target.GetType().Name + "_" + Count.ToString()));
+                    Count++;
+                    break;
+                }
+
+            case IR.Math.Binary:
+                {
+                    var lhs = expr.Arguments[0];
+                    var rhs = expr.Arguments[1];
+                    var inTileShapeA = Enumerable.Repeat(1, lhs.CheckedShape.Rank).ToArray();
+                    inTileShapeA[^1] = 32;
+                    var inTileShapeB = Enumerable.Repeat(1, rhs.CheckedShape.Rank).ToArray();
+                    inTileShapeB[^1] = 32;
+
+                    if (!(lhs is Var or TensorConst))
+                    {
+                        var oldTileAShape = TileList.Find(k => k.Key == lhs).Value;
+                        inTileShapeA = inTileShapeA.Select((s, i) => Math.Max(s, oldTileAShape[i])).ToArray();
+                    }
+                    else
+                    {
+                        TileList.Add(new(lhs, inTileShapeA));
+                        NameList.Add(new(lhs, nameof(IR.Math.Binary) + "_" + Count + "_lhs"));
+                    }
+
+                    if (!(rhs is Var or TensorConst))
+                    {
+                        var oldTileBShape = TileList.Find(k => k.Key == rhs).Value;
+                        inTileShapeB = inTileShapeB.Select((s, i) => Math.Max(s, oldTileBShape[i])).ToArray();
+                    }
+                    else
+                    {
+                        TileList.Add(new(rhs, inTileShapeB));
+                        NameList.Add(new(rhs, nameof(IR.Math.Binary) + "_" + Count + "_rhs"));
+                    }
+
+                    var outTileShape = Enumerable.Repeat(1, expr.CheckedShape.Rank).ToArray();
+                    outTileShape[^1] = 32;
+                    TileList.Add(new(expr, outTileShape));
+                    NameList.Add(new(expr, nameof(IR.Math.Binary) + "_" + Count));
+                    Count++;
+                    break;
+                }
+
+            default:
+                throw new NotImplementedException("Not Implemented Op: " + expr.Target);
+        }
+
+        return Unit.Default;
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/TileOptions.cs b/modules/Nncase.Modules.CPU/Passes/Tile/TileOptions.cs
new file mode 100644
index 0000000000..48a3bdf137
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Tile/TileOptions.cs
@@ -0,0 +1,21 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Nncase.Passes.Tile;
+
+/// <summary>
+/// TileOptions.
+/// </summary>
+/// <param name="TargetTileSize">Target tile size.</param>
+/// <param name="Hierarchy">The hierarchy shapes.</param>
+/// <param name="HierarchySizes">Each hierarchy's RAM size in bytes.</param>
+public sealed record TileOptions(int[] TargetTileSize, int[] Hierarchy, int[] HierarchySizes)
+{
+    public static TileOptions Default { get; } = new(Array.Empty<int>(), new[] { 1 }, new[] { 64 * (int)MathF.Pow(2, 20) }); // 64 MiB.
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Binary.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Binary.cs
new file mode 100644
index 0000000000..6a8c47bd4e
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Binary.cs
@@ -0,0 +1,28 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+
+namespace Nncase.TIR.CPU;
+
+public sealed partial class Binary : CPUKernelOp
+{
+    public static readonly ParameterInfo Lhs = new(typeof(Binary), 0, "lhs");
+
+    public static readonly ParameterInfo Rhs = new(typeof(Binary), 1, "rhs");
+
+    public static readonly ParameterInfo Output = new(typeof(Binary), 2, "output");
+
+    public BinaryOp BinaryOp { get; }
+
+    /// <inheritdoc/>
+    public override string DisplayProperty()
+    {
+        return $"BinaryOp.{BinaryOp}";
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/CPUKernelOp.cs b/modules/Nncase.Modules.CPU/TIR/CPU/CPUKernelOp.cs
new file mode 100644
index 0000000000..ecfc457503
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/CPUKernelOp.cs
@@ -0,0 +1,9 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+using Nncase.IR;
+
+namespace Nncase.TIR.CPU;
+
+public abstract class CPUKernelOp : Op
+{
+}
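Every TIR CPU kernel in the files that follow takes the same shape: a sealed partial class derived from CPUKernelOp whose static ParameterInfo fields name the operands by position, plus plain properties for the attributes. A minimal sketch of what a hypothetical extra kernel would look like (the Relu name is illustrative, not part of this patch):

    public sealed partial class Relu : CPUKernelOp
    {
        public static readonly ParameterInfo Input = new(typeof(Relu), 0, "input");

        public static readonly ParameterInfo Output = new(typeof(Relu), 1, "output");
    }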
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Concat.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Concat.cs
new file mode 100644
index 0000000000..c003525de4
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Concat.cs
@@ -0,0 +1,35 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+using Nncase.PatternMatch;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.TIR.CPU;
+
+/// <summary>
+/// Concat expression.
+/// </summary>
+public sealed partial class Concat : CPUKernelOp
+{
+    /// <summary>
+    /// Gets input.
+    /// </summary>
+    public static readonly ParameterInfo Input = new(typeof(Concat), 0, "input");
+
+    /// <summary>
+    /// Gets output.
+    /// </summary>
+    public static readonly ParameterInfo Output = new(typeof(Concat), 1, "output");
+
+    /// <summary>
+    /// Gets axis.
+    /// </summary>
+    public int Axis { get; }
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Functional.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Functional.cs
new file mode 100644
index 0000000000..7d578fa117
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Functional.cs
@@ -0,0 +1,127 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+using Nncase.IR.Math;
+using Nncase.TIR;
+using Nncase.TIR.CPU;
+
+namespace Nncase.TIR.F;
+
+public partial class CPU
+{
+    /// <summary>
+    /// Creates the pointer expression which emits *PtrName in the generated C code.
+    /// </summary>
+    /// <param name="name">C pointer name.</param>
+    /// <param name="primType">Element type.</param>
+    /// <returns>Call expression.</returns>
+    public static Call PtrOf(string name, DataType primType) => new Call(new PtrOf(name, primType));
+
+    public static Call SramPtr(Expr input, DataType primType) => new Call(new SramPtr(primType), input);
+
+    public static Call TensorLoad(Expr dest, Expr src, IRArray<SBP> ndsbp, Placement placement)
+    {
+        return new Call(new TensorLoad(ndsbp, placement), dest, src);
+    }
+
+    public static Call TensorStore(Expr src, Expr dest, IRArray<SBP> ndsbp, Placement placement)
+    {
+        return new Call(new TensorStore(ndsbp, placement), src, dest);
+    }
+
+    public static Call Memcopy(Expr dest, Expr src)
+    {
+        return new Call(new Memcopy(), dest, src);
+    }
+
+    public static Call Unary(UnaryOp unaryOp, Expr input, Expr output)
+    {
+        return new Call(new TIR.CPU.Unary(unaryOp), input, output);
+    }
+
+    public static Call Binary(BinaryOp binaryOp, Expr lhs, Expr rhs, Expr output)
+    {
+        return new Call(new TIR.CPU.Binary(binaryOp), lhs, rhs, output);
+    }
+
+    public static Call Matmul(Expr lhs, Expr rhs, Expr output)
+    {
+        return new Call(new Matmul(), lhs, rhs, output);
+    }
+
+    public static Expr Pack(Expr input, Expr output, IRArray<int> lanes, IRArray<int> axes)
+    {
+        return new Call(new Pack(lanes, axes), input, output);
+    }
+
+    public static Expr Unpack(Expr input, Expr output, IRArray<int> axes)
+    {
+        return new Call(new Unpack(axes), input, output);
+    }
+
+    public static Expr PackedSoftmax(Expr input, Expr output, int axis, IRArray<int> packedAxes)
+    {
+        return new Call(new PackedSoftmax(axis, packedAxes), input, output);
+    }
+
+    public static Expr PackedLayerNorm(Expr input, Expr scale, Expr bias, Expr output, int axis, float epsilon, bool usemean, IRArray<int> packedAxes, IRArray<int> padedNums)
+    {
+        return new Call(new PackedLayerNorm(axis, epsilon, usemean, packedAxes, padedNums), input, scale, bias, output);
+    }
+
+    public static Expr PackedMatMul(Expr lhs, Expr rhs, Expr output, IRArray<int> lhsPackedAxes, IRArray<int> lhsPadedNums, IRArray<int> rhsPackedAxes, IRArray<int> rhsPadedNums)
+    {
+        return new Call(new PackedMatMul(lhsPackedAxes, lhsPadedNums, rhsPackedAxes, rhsPadedNums), lhs, rhs, output);
+    }
+
+    public static Expr PackedBinary(Expr lhs, Expr rhs, Expr output, BinaryOp binaryOp, IRArray<int> lhsPackedAxes, IRArray<int> lhsPadedNums, IRArray<int> rhsPackedAxes, IRArray<int> rhsPadedNums)
+    {
+        return new Call(new PackedBinary(binaryOp, lhsPackedAxes, lhsPadedNums, rhsPackedAxes, rhsPadedNums), lhs, rhs, output);
+    }
+
+    public static Expr PackedTranspose(Expr input, Expr output, IRArray<int> perm, IRArray<int> packedAxes)
+    {
+        return new Call(new PackedTranspose(perm, packedAxes), input, output);
+    }
+
+    public static Expr Slice(Buffer input, Buffer ret, int[] begin, int[] stop, int[] axes, int[] stride)
+    {
+        return new Call(new Slice(begin, stop, axes, stride), input, ret);
+    }
+
+    public static Expr Concat(Buffer[] inputs, Buffer ret, int axis)
+    {
+        return new Call(new Concat(axis), inputs.Concat(new[] { ret }).ToArray());
+    }
+
+    public static Expr Reshape(Buffer input, Buffer ret, int[] newShape)
+    {
+        return new Call(new Reshape(newShape), input, ret);
+    }
+
+    public static Expr Swish(Buffer buffer, Buffer ret, float v)
+    {
+        return new Call(new Swish(v), buffer, ret);
+    }
+
+    public static Expr Gather(Buffer input, Buffer indices, Buffer ret, int axis)
+    {
+        return new Call(new Gather(axis), input, indices, ret);
+    }
+
+    public static Expr Transpose(Buffer buffer, Buffer ret, int[] perm)
+    {
+        return new Call(new Transpose(perm), buffer, ret);
+    }
+
+    internal static Expr Pad(Buffer input, Buffer ret, int[] pads, float padValue)
+    {
+        return new Call(new Pad(pads, padValue), input, ret);
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Gather.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Gather.cs
new file mode 100644
index 0000000000..ee6533c5b3
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Gather.cs
@@ -0,0 +1,41 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using NetFabric.Hyperlinq;
+using Nncase.IR;
+using Nncase.PatternMatch;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.TIR.CPU;
+
+/// <summary>
+/// Gather expression.
+/// </summary>
+public sealed partial class Gather : CPUKernelOp
+{
+    /// <summary>
+    /// Gets input.
+    /// </summary>
+    public static readonly ParameterInfo Input = new(typeof(Gather), 0, "input", ParameterKind.Input);
+
+    /// <summary>
+    /// Gets index.
+    /// </summary>
+    public static readonly ParameterInfo Index = new(typeof(Gather), 1, "index", IsIntegral(), ParameterKind.Input);
+
+    /// <summary>
+    /// Gets output.
+    /// </summary>
+    public static readonly ParameterInfo Output = new(typeof(Gather), 2, "output");
+
+    /// <summary>
+    /// Gets axis.
+    /// </summary>
+    public int Axis { get; }
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Matmul.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Matmul.cs
new file mode 100644
index 0000000000..8454bfd19b
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Matmul.cs
@@ -0,0 +1,14 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+using Nncase.IR;
+
+namespace Nncase.TIR.CPU;
+
+public sealed partial class Matmul : CPUKernelOp
+{
+    public static readonly ParameterInfo Lhs = new(typeof(Matmul), 0, "lhs");
+
+    public static readonly ParameterInfo Rhs = new(typeof(Matmul), 1, "rhs");
+
+    public static readonly ParameterInfo Output = new(typeof(Matmul), 2, "output");
+}
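The TIR.F.CPU factories above are the emission surface used by the lowering visitor earlier in this patch: a kernel body is just a sequence of such calls appended to the main block. A short sketch (buffer names hypothetical):

    // A matmul followed by a unary negation, written through pre-allocated buffers.
    _mainBody.Add(TIR.F.CPU.Matmul(lhsBuffer, rhsBuffer, tempBuffer));
    _mainBody.Add(TIR.F.CPU.Unary(UnaryOp.Neg, tempBuffer, outputBuffer));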
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Memcopy.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Memcopy.cs
new file mode 100644
index 0000000000..e04b89717a
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Memcopy.cs
@@ -0,0 +1,12 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+using Nncase.IR;
+
+namespace Nncase.TIR.CPU;
+
+public sealed partial class Memcopy : CPUKernelOp
+{
+    public static readonly ParameterInfo Dest = new(typeof(Memcopy), 0, "dest");
+
+    public static readonly ParameterInfo Src = new(typeof(Memcopy), 1, "src");
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Pack.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Pack.cs
new file mode 100644
index 0000000000..b5c212233f
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Pack.cs
@@ -0,0 +1,33 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+using Nncase.PatternMatch;
+
+namespace Nncase.TIR.CPU;
+
+/// <summary>
+/// Pack expression.
+/// </summary>
+public sealed partial class Pack : CPUKernelOp
+{
+    /// <summary>
+    /// Gets input.
+    /// </summary>
+    public static readonly ParameterInfo Input = new(typeof(Pack), 0, "input", ParameterKind.Input);
+
+    public static readonly ParameterInfo Output = new(typeof(Pack), 1, "output", ParameterKind.Input);
+
+    public IRArray<int> Lanes { get; }
+
+    public IRArray<int> Axes { get; }
+
+    /// <inheritdoc/>
+    public override string DisplayProperty() => $"Lanes: {Lanes}, Axes: {Axes}";
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/PackedBinary.cs b/modules/Nncase.Modules.CPU/TIR/CPU/PackedBinary.cs
new file mode 100644
index 0000000000..a310632f5d
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/PackedBinary.cs
@@ -0,0 +1,34 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR;
+using Nncase.PatternMatch;
+
+namespace Nncase.TIR.CPU;
+
+public sealed partial class PackedBinary : CPUKernelOp
+{
+    /// <summary>
+    /// Gets lhs.
+    /// </summary>
+    public static readonly ParameterInfo Lhs = new(typeof(PackedBinary), 0, "lhs", ParameterKind.Input);
+
+    /// <summary>
+    /// Gets rhs.
+    /// </summary>
+    public static readonly ParameterInfo Rhs = new(typeof(PackedBinary), 1, "rhs", ParameterKind.Input);
+
+    public static readonly ParameterInfo Output = new(typeof(PackedBinary), 2, "output", ParameterKind.Input);
+
+    public BinaryOp BinaryOp { get; }
+
+    public IRArray<int> LhsPackedAxes { get; }
+
+    public IRArray<int> LhsPadedNums { get; }
+
+    public IRArray<int> RhsPackedAxes { get; }
+
+    public IRArray<int> RhsPadedNums { get; }
+
+    public override string DisplayProperty() => $"BinaryOp: {BinaryOp}, LhsPackedAxes: {LhsPackedAxes}, LhsPadedNums: {LhsPadedNums}, RhsPackedAxes: {RhsPackedAxes}, RhsPadedNums: {RhsPadedNums}";
+}
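The PackedAxes/PadedNums pairs carried by PackedBinary above (and by the other packed kernels below) record how each operand was vectorized: which axes were folded into lanes, and how many padding elements were appended so those axes divide evenly; the PadForPack helper later in this patch computes exactly that amount. A sketch of the bookkeeping for one operand (shape and lane count illustrative):

    // Packing axis 2 of a [1, 384, 383] operand with 8 lanes pads 383 up to 384,
    // so this operand carries PackedAxes == [2] and PadedNums == [1].
    var lanes = new[] { 8 };
    var packedAxes = new[] { 2 };
    var padedNums = new[] { MathUtility.AlignUp(383, lanes[0]) - 383 }; // [1]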
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/PackedLayerNorm.cs b/modules/Nncase.Modules.CPU/TIR/CPU/PackedLayerNorm.cs
new file mode 100644
index 0000000000..87537a277d
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/PackedLayerNorm.cs
@@ -0,0 +1,39 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR;
+using Nncase.PatternMatch;
+
+namespace Nncase.TIR.CPU;
+
+public sealed partial class PackedLayerNorm : CPUKernelOp
+{
+    /// <summary>
+    /// Gets input.
+    /// </summary>
+    public static readonly ParameterInfo Input = new(typeof(PackedLayerNorm), 0, "input", ParameterKind.Input);
+
+    /// <summary>
+    /// Gets scale.
+    /// </summary>
+    public static readonly ParameterInfo Scale = new(typeof(PackedLayerNorm), 1, "scale", ParameterKind.Input);
+
+    /// <summary>
+    /// Gets bias.
+    /// </summary>
+    public static readonly ParameterInfo Bias = new(typeof(PackedLayerNorm), 2, "bias", ParameterKind.Input);
+
+    public static readonly ParameterInfo Output = new(typeof(PackedLayerNorm), 3, "output", ParameterKind.Input);
+
+    public int Axis { get; }
+
+    public float Epsilon { get; }
+
+    public bool UseMean { get; }
+
+    public IRArray<int> PackedAxes { get; }
+
+    public IRArray<int> PadedNums { get; }
+
+    public override string DisplayProperty() => $"Axis: {Axis}, Epsilon: {Epsilon}, UseMean: {UseMean}, PackedAxes: {PackedAxes}, PadedNums: {PadedNums}";
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/PackedMatMul.cs b/modules/Nncase.Modules.CPU/TIR/CPU/PackedMatMul.cs
new file mode 100644
index 0000000000..08645c72ff
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/PackedMatMul.cs
@@ -0,0 +1,32 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR;
+using Nncase.PatternMatch;
+
+namespace Nncase.TIR.CPU;
+
+public sealed partial class PackedMatMul : CPUKernelOp
+{
+    /// <summary>
+    /// Gets lhs.
+    /// </summary>
+    public static readonly ParameterInfo Lhs = new(typeof(PackedMatMul), 0, "lhs", ParameterKind.Input);
+
+    /// <summary>
+    /// Gets rhs.
+    /// </summary>
+    public static readonly ParameterInfo Rhs = new(typeof(PackedMatMul), 1, "rhs", ParameterKind.Input);
+
+    public static readonly ParameterInfo Output = new(typeof(PackedMatMul), 2, "output", ParameterKind.Input);
+
+    public IRArray<int> LhsPackedAxes { get; }
+
+    public IRArray<int> LhsPadedNums { get; }
+
+    public IRArray<int> RhsPackedAxes { get; }
+
+    public IRArray<int> RhsPadedNums { get; }
+
+    public override string DisplayProperty() => $"LhsPackedAxes: {LhsPackedAxes}, LhsPadedNums: {LhsPadedNums}, RhsPackedAxes: {RhsPackedAxes}, RhsPadedNums: {RhsPadedNums}";
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/PackedSoftMax.cs b/modules/Nncase.Modules.CPU/TIR/CPU/PackedSoftMax.cs
new file mode 100644
index 0000000000..003bf4fbfc
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/PackedSoftMax.cs
@@ -0,0 +1,20 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR;
+using Nncase.PatternMatch;
+
+namespace Nncase.TIR.CPU;
+
+public sealed partial class PackedSoftmax : CPUKernelOp
+{
+    public static readonly ParameterInfo Input = new(typeof(PackedSoftmax), 0, "input", ParameterKind.Input);
+
+    public static readonly ParameterInfo Output = new(typeof(PackedSoftmax), 1, "output", ParameterKind.Input);
+
+    public int Axis { get; }
+
+    public IRArray<int> PackedAxes { get; }
+
+    public override string DisplayProperty() => $"{Axis}, {PackedAxes}";
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/PackedTranspose.cs b/modules/Nncase.Modules.CPU/TIR/CPU/PackedTranspose.cs
new file mode 100644
index 0000000000..2356bfb18c
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/PackedTranspose.cs
@@ -0,0 +1,22 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR;
+using Nncase.PatternMatch;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.TIR.CPU;
+
+public sealed partial class PackedTranspose : CPUKernelOp
+{
+    /// <summary>
+    /// Gets input.
+    /// </summary>
+    public static readonly ParameterInfo Input = new(typeof(PackedTranspose), 0, "input", ParameterKind.Input);
+
+    public static readonly ParameterInfo Output = new(typeof(PackedTranspose), 1, "output", ParameterKind.Input);
+
+    public IRArray<int> Perm { get; }
+
+    public IRArray<int> PackedAxes { get; }
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Pad.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Pad.cs
new file mode 100644
index 0000000000..4f309fe996
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Pad.cs
@@ -0,0 +1,34 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+using Nncase.PatternMatch;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.TIR.CPU;
+
+/// <summary>
+/// Pad expression.
+/// </summary>
+public sealed partial class Pad : CPUKernelOp
+{
+    /// <summary>
+    /// Gets input.
+    /// </summary>
+    public static readonly ParameterInfo Input = new(typeof(Pad), 0, "input");
+
+    /// <summary>
+    /// Gets output.
+    /// </summary>
+    public static readonly ParameterInfo Output = new(typeof(Pad), 1, "output");
+
+    public IRArray<int> Paddings { get; }
+
+    public float PadValue { get; }
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/PtrOf.cs b/modules/Nncase.Modules.CPU/TIR/CPU/PtrOf.cs
new file mode 100644
index 0000000000..f2fe691cf1
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/PtrOf.cs
@@ -0,0 +1,16 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+using Nncase.IR;
+
+namespace Nncase.TIR.CPU;
+
+public sealed partial class PtrOf : Op
+{
+    public string PtrName { get; }
+
+    public DataType DataType { get; }
+
+    public override bool CanFoldConstCall => false;
+
+    public override string DisplayProperty() => $"{PtrName}";
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Reshape.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Reshape.cs
new file mode 100644
index 0000000000..7e87edcb79
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Reshape.cs
@@ -0,0 +1,35 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+using Nncase.PatternMatch;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.TIR.CPU;
+
+/// <summary>
+/// Reshape expression.
+/// </summary>
+public sealed partial class Reshape : CPUKernelOp
+{
+    /// <summary>
+    /// Gets input.
+    /// </summary>
+    public static readonly ParameterInfo Input = new(typeof(Reshape), 0, "input");
+
+    /// <summary>
+    /// Gets output.
+    /// </summary>
+    public static readonly ParameterInfo Output = new(typeof(Reshape), 1, "output");
+
+    /// <summary>
+    /// Gets the new shape.
+    /// </summary>
+    public IRArray<int> NewShape { get; }
+}
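Slice carries the begins/ends/axes/strides quadruple of the tensor-level op directly as attributes. For example, taking rows 2..6 of a [8, 16] buffer along axis 0 with stride 1 goes through the TIR.F.CPU.Slice helper defined above (buffer names hypothetical):

    var slice = TIR.F.CPU.Slice(inputBuffer, outputBuffer, new[] { 2 }, new[] { 6 }, new[] { 0 }, new[] { 1 });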
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Slice.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Slice.cs
new file mode 100644
index 0000000000..013b038584
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Slice.cs
@@ -0,0 +1,50 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+using Nncase.PatternMatch;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.TIR.CPU;
+
+/// <summary>
+/// Slice expression.
+/// </summary>
+public sealed partial class Slice : CPUKernelOp
+{
+    /// <summary>
+    /// Gets input.
+    /// </summary>
+    public static readonly ParameterInfo Input = new(typeof(Slice), 0, "input");
+
+    /// <summary>
+    /// Gets output.
+    /// </summary>
+    public static readonly ParameterInfo Output = new(typeof(Slice), 1, "output");
+
+    /// <summary>
+    /// Gets begins.
+    /// </summary>
+    public IRArray<int> Begins { get; }
+
+    /// <summary>
+    /// Gets ends.
+    /// </summary>
+    public IRArray<int> Ends { get; }
+
+    /// <summary>
+    /// Gets axes.
+    /// </summary>
+    public IRArray<int> Axes { get; }
+
+    /// <summary>
+    /// Gets strides.
+    /// </summary>
+    public IRArray<int> Strides { get; }
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/SramPtr.cs b/modules/Nncase.Modules.CPU/TIR/CPU/SramPtr.cs
new file mode 100644
index 0000000000..e436a61d02
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/SramPtr.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+using Nncase.IR;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.TIR.CPU;
+
+public sealed partial class SramPtr : Op
+{
+    public static readonly ParameterInfo OffSet = new(typeof(SramPtr), 0, "offset", IsIntegralScalar());
+
+    public DataType DataType { get; }
+
+    public override bool CanFoldConstCall => false;
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Swish.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Swish.cs
new file mode 100644
index 0000000000..4d81e33c8b
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Swish.cs
@@ -0,0 +1,35 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+using Nncase.PatternMatch;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.TIR.CPU;
+
+/// <summary>
+/// Swish expression.
+/// </summary>
+public sealed partial class Swish : CPUKernelOp
+{
+    /// <summary>
+    /// Gets input.
+    /// </summary>
+    public static readonly ParameterInfo Input = new(typeof(Swish), 0, "input");
+
+    /// <summary>
+    /// Gets output.
+    /// </summary>
+    public static readonly ParameterInfo Output = new(typeof(Swish), 1, "output");
+
+    /// <summary>
+    /// Gets beta.
+    /// </summary>
+    public float Beta { get; }
+}
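TensorLoad and TensorStore below carry the SBP annotation and placement that GenerateBoxing earlier in this patch threads through: a load scatters a global tensor into its per-device view, a store gathers it back, and a distributed-to-distributed reshard is lowered as a store followed by a load with the new SBP. A sketch of the two directions (buffer names hypothetical; ndsbp and placement taken from an existing DistributedType):

    // Global tensor -> distributed (per-device) view.
    _mainBody.Add(TIR.F.CPU.TensorLoad(shardBuffer, globalBuffer, distType.NdSBP, distType.Placement));

    // Distributed view -> global tensor.
    _mainBody.Add(TIR.F.CPU.TensorStore(shardBuffer, globalBuffer, distType.NdSBP, distType.Placement));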
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/TensorLoad.cs b/modules/Nncase.Modules.CPU/TIR/CPU/TensorLoad.cs
new file mode 100644
index 0000000000..473f1f42db
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/TensorLoad.cs
@@ -0,0 +1,16 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+using Nncase.IR;
+
+namespace Nncase.TIR.CPU;
+
+public sealed partial class TensorLoad : CPUKernelOp
+{
+    public static readonly ParameterInfo Dest = new(typeof(TensorLoad), 0, "dest");
+
+    public static readonly ParameterInfo Src = new(typeof(TensorLoad), 1, "src");
+
+    public IRArray<SBP> NdSbp { get; }
+
+    public Placement Placement { get; }
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/TensorStore.cs b/modules/Nncase.Modules.CPU/TIR/CPU/TensorStore.cs
new file mode 100644
index 0000000000..1942eb8d19
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/TensorStore.cs
@@ -0,0 +1,16 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+using Nncase.IR;
+
+namespace Nncase.TIR.CPU;
+
+public sealed partial class TensorStore : CPUKernelOp
+{
+    public static readonly ParameterInfo Src = new(typeof(TensorStore), 0, "src");
+
+    public static readonly ParameterInfo Dest = new(typeof(TensorStore), 1, "dest");
+
+    public IRArray<SBP> NdSbp { get; }
+
+    public Placement Placement { get; }
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Transpose.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Transpose.cs
new file mode 100644
index 0000000000..568aa61492
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Transpose.cs
@@ -0,0 +1,35 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+using Nncase.PatternMatch;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.TIR.CPU;
+
+/// <summary>
+/// Transpose expression.
+/// </summary>
+public sealed partial class Transpose : CPUKernelOp
+{
+    /// <summary>
+    /// Gets input.
+    /// </summary>
+    public static readonly ParameterInfo Input = new(typeof(Transpose), 0, "input");
+
+    /// <summary>
+    /// Gets output.
+    /// </summary>
+    public static readonly ParameterInfo Output = new(typeof(Transpose), 1, "output");
+
+    /// <summary>
+    /// Gets perm.
+    /// </summary>
+    public IRArray<int> Perm { get; }
+}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Unary.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Unary.cs
new file mode 100644
index 0000000000..cd7e9bd444
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Unary.cs
@@ -0,0 +1,20 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+using Nncase.IR;
+
+namespace Nncase.TIR.CPU;
+
+public sealed partial class Unary : CPUKernelOp
+{
+    public static readonly ParameterInfo Input = new(typeof(Unary), 0, "input");
+
+    public static readonly ParameterInfo Output = new(typeof(Unary), 1, "output");
+
+    public UnaryOp UnaryOp { get; }
+
+    /// <inheritdoc/>
+    public override string DisplayProperty()
+    {
+        return $"UnaryOp.{UnaryOp}";
+    }
+}
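Unpack below reverses Pack. Assuming the usual packing semantics (an axis of length d packed with l lanes becomes length d / l with l-wide vector elements), the shape effect is as follows (a sketch; buffer names hypothetical and IRArray's implicit conversion from arrays assumed):

    // Pack a [4, 64] f32 tensor on axis 1 with 8 lanes: [4, 64] -> [4, 8] of 8-lane vectors.
    var packed = TIR.F.CPU.Pack(input, packedOutput, new[] { 8 }, new[] { 1 });

    // Unpack on the same axis restores [4, 64].
    var unpacked = TIR.F.CPU.Unpack(packedOutput, output, new[] { 1 });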
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Unpack.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Unpack.cs
new file mode 100644
index 0000000000..00b0df769a
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Unpack.cs
@@ -0,0 +1,31 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+using Nncase.PatternMatch;
+
+namespace Nncase.TIR.CPU;
+
+/// <summary>
+/// Unpack expression.
+/// </summary>
+public sealed partial class Unpack : CPUKernelOp
+{
+    /// <summary>
+    /// Gets input.
+    /// </summary>
+    public static readonly ParameterInfo Input = new(typeof(Unpack), 0, "input", ParameterKind.Input);
+
+    public static readonly ParameterInfo Output = new(typeof(Unpack), 1, "output", ParameterKind.Input);
+
+    public IRArray<int> Axes { get; }
+
+    /// <inheritdoc/>
+    public override string DisplayProperty() => $"Axes: {Axes}";
+}
diff --git a/modules/Nncase.Modules.CPU/Targets/CPUCompileOptions.cs b/modules/Nncase.Modules.CPU/Targets/CPUCompileOptions.cs
new file mode 100644
index 0000000000..1744bd786f
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Targets/CPUCompileOptions.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Nncase.Targets;
+
+public sealed record CPUCompileOptions(string ModelName, bool Packing, int[] TargetTileSize, int[] Hierarchy, string HierarchyNames, int[] HierarchySizes) : ITargetCompileOptions
+{
+    public static CPUCompileOptions Default { get; } = new(string.Empty, false, Array.Empty<int>(), new[] { 1 }, "b", new[] { 3 * (int)MathF.Pow(2, 20) });
+}
diff --git a/modules/Nncase.Modules.StackVM/Targets/CPUTarget.cs b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs
similarity index 50%
rename from modules/Nncase.Modules.StackVM/Targets/CPUTarget.cs
rename to modules/Nncase.Modules.CPU/Targets/CPUTarget.cs
index 2f63e02be9..941236878d 100644
--- a/modules/Nncase.Modules.StackVM/Targets/CPUTarget.cs
+++ b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs
@@ -3,16 +3,20 @@
 using System;
 using System.Collections.Generic;
+using System.CommandLine;
 using System.CommandLine.Invocation;
 using System.Linq;
+using System.Runtime.InteropServices;
 using System.Text;
 using System.Threading.Tasks;
 using Microsoft.Extensions.Configuration;
 using Microsoft.Extensions.Options;
 using Nncase.CodeGen;
+using Nncase.CodeGen.CPU;
 using Nncase.CodeGen.StackVM;
 using Nncase.IR;
 using Nncase.Passes;
+using Nncase.Passes.Transforms;
 using Nncase.Quantization;
 
 namespace Nncase.Targets;
@@ -28,7 +32,12 @@ public class CPUTarget : ITarget
     public (System.CommandLine.Command Command, Func<InvocationContext, System.CommandLine.Command, ITargetCompileOptions> Parser) RegisterCommandAndParser()
     {
-        return (new System.CommandLine.Command(Kind), (_, _) => DefaultTargetCompileOptions.Instance);
+        var cmd = new System.CommandLine.Command(Kind);
+        cmd.AddOption(new Option<bool>(
+            name: "--packing",
+            description: "Enable layout optimization.",
+            getDefaultValue: () => false));
+        return (cmd, ParseTargetCompileOptions);
     }
 
     /// <inheritdoc/>
@@ -44,6 +53,24 @@ public void RegisterTargetInDependentPass(IPassManager passManager, CompileOptions options)
     /// <inheritdoc/>
     public void RegisterTargetDependentPass(IPassManager passManager, CompileOptions options)
     {
+        passManager.AddWithName("MakeFusion").Configure(p =>
+        {
+            p.Add();
+            p.Add();
+            p.Add();
+        });
+
+#if false
+        passManager.AddWithName("CPUDeviceFusion").Configure(p =>
+        {
+            p.Add();
+        });
+#endif
+
+        passManager.AddWithName("CPUKernelFusion").Configure(p =>
+        {
+            p.Add();
+        });
     }
 
     /// <inheritdoc/>
@@ -74,6 +101,54 @@ public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, CompileOptions options)
             p.Add();
         });
     }
+
+        if (options.TargetCompileOptions is CPUCompileOptions { Packing: true })
+        {
+            passManager.AddWithName("AutoPacking").Configure(p =>
+            {
+                p.Add();
+            });
+        }
+
+        passManager.AddWithName("AutoDistributed").Configure(p =>
+        {
+            p.Add();
+        });
+
+        passManager.Add();
+
+#if false
+        // FIX ME: Disable macos as macho loader is buggy.
+        if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
+        {
+            passManager.AddWithName("CPUDeviceFusion").Configure(p =>
+            {
+                p.AddAnalysis();
+                p.Add();
+            });
+        }
+#endif
+
+        passManager.Add();
+
+        passManager.Add();
+
+        passManager.Add().Configure(p =>
+        {
+            p.Add();
+            p.Add();
+            p.Add();
+            p.Add();
+        });
+
+        passManager.AddWithName("DDrBufferSchedule");
+
+        passManager.AddWithName("InstStage").Configure(p =>
+        {
+            p.Add();
+            p.Add();
+            p.Add();
+        });
+    }
 
     public void RegisterTargetDependentBeforeCodeGen(IPassManager passManager, CompileOptions options)
@@ -87,9 +162,18 @@ public IModuleBuilder CreateModuleBuilder(string moduleKind, CompileOptions options)
         {
             return new StackVMModuleBuilder();
         }
+        else if (moduleKind == "cpu")
+        {
+            return new CPUModuleBuilder(options);
+        }
         else
         {
             throw new NotSupportedException($"{moduleKind} module is not supported.");
         }
     }
+
+    private static ITargetCompileOptions ParseTargetCompileOptions(InvocationContext context, Command command)
+    {
+        return new CPUCompileOptions(string.Empty, false, Array.Empty<int>(), new[] { 1 }, "b", new[] { 3 * (int)MathF.Pow(2, 20) });
+    }
 }
diff --git a/modules/Nncase.Modules.CPU/Utilities/PackUtility.cs b/modules/Nncase.Modules.CPU/Utilities/PackUtility.cs
new file mode 100644
index 0000000000..91d01f9984
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Utilities/PackUtility.cs
@@ -0,0 +1,149 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR;
+
+namespace Nncase.Utilities;
+
+public static class PackUtility
+{
+    public static Expr PadForPack(Expr input, int[] shape, int[] packedAxes, int[] lanes, Expr value, out int[] padNums)
+    {
+        var isPadded = false;
+        var pads = new int[shape.Length, 2];
+        for (int i = 0; i < packedAxes.Length; i++)
+        {
+            var axis = packedAxes[i];
+            if (shape[axis] % lanes[i] != 0)
+            {
+                pads[axis, 1] = MathUtility.AlignUp(shape[axis], lanes[i]) - shape[axis];
+                isPadded = true;
+            }
+        }
+
+        padNums = new int[packedAxes.Length];
+        for (int i = 0; i < packedAxes.Length; i++)
+        {
+            padNums[i] = pads[packedAxes[i], 1];
+        }
+
+        if (isPadded)
+        {
+            return IR.F.NN.Pad(input, pads, PadMode.Constant, value);
+        }
+
+        return input;
+    }
+
+    public static Expr SliceForPack(Expr input, int[] shape, int[] padNums)
+    {
+        bool isPadded = false;
+        var ends = shape.ToArray();
+        if (padNums.Any(i => i > 0))
+        {
+            isPadded = true;
+        }
+
+        return isPadded ? IR.F.Tensors.Slice(input, Enumerable.Repeat(0, shape.Length).ToArray(), ends, shape.Length) : input;
+    }
+
+    /// <summary>
+    /// Finds the reshape's shape transform matrix.
+    /// </summary>
+    /// <param name="inShape">Input shape.</param>
+    /// <param name="newShape">New shape.</param>
+    /// <param name="mat">The transform matrix.</param>
+    /// <returns>Whether a mapping was found.</returns>
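+    /// <example>
+    /// A worked example (illustrative): mapping inShape [2, 3, 4] to newShape [6, 4]
+    /// yields mat = [[1, 1, 0], [0, 0, 1]], i.e. input dims 0 and 1 fold into output
+    /// dim 0, and input dim 2 maps to output dim 1.
+    /// <code>
+    /// if (PackUtility.TryGetShapeMapMatrix(new[] { 2, 3, 4 }, new[] { 6, 4 }, out var mat))
+    /// {
+    ///     var (forward, backward) = PackUtility.ShapeMapMatrixAsDict(mat);
+    ///     // forward: 0 -> [0], 1 -> [0], 2 -> [1]; backward: 0 -> [0, 1], 1 -> [2]
+    /// }
+    /// </code>
+    /// </example>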
+    public static bool TryGetShapeMapMatrix(int[] inShape, int[] newShape, out int[,] mat)
+    {
+        int Dot(int[,] cmat, int i)
+        {
+            var prod = 1;
+            for (int j = 0; j < inShape.Length; j++)
+            {
+                var v = cmat[i, j] * inShape[j];
+                if (v != 0)
+                {
+                    prod *= v;
+                }
+            }
+
+            return prod;
+        }
+
+        mat = new int[newShape.Length, inShape.Length];
+        int i = 0, j = 0;
+        var paths = new List<(int, int)>();
+        while (i < newShape.Length)
+        {
+            if (paths.IndexOf((i, j)) != -1)
+            {
+                return false;
+            }
+
+            mat[i, j] = 1;
+            paths.Add((i, j));
+            var newDim = Dot(mat, i);
+            switch (newDim - newShape[i])
+            {
+                case 0:
+                    i++; j++;
+                    break;
+                case < 0:
+                    j++;
+                    break;
+                case > 0:
+                    mat[i, j] = 0;
+                    j--;
+                    paths.RemoveAt(paths.Count - 1);
+                    break;
+            }
+        }
+
+        return i == newShape.Length && j == inShape.Length;
+    }
+
+    /// <summary>
+    /// Converts the mapping matrix into dictionaries: the forward key is an input dim
+    /// and its value the output dims it maps to; backward is the inverse.
+    /// </summary>
+    /// <param name="mat">The transform matrix.</param>
+    /// <returns>The forward and backward dictionaries.</returns>
+    public static (Dictionary<int, List<int>> Forward, Dictionary<int, List<int>> Backward) ShapeMapMatrixAsDict(int[,] mat)
+    {
+        var forward = new Dictionary<int, List<int>>();
+        var backward = new Dictionary<int, List<int>>();
+        for (int i = 0; i < mat.GetLength(0); i++)
+        {
+            for (int j = 0; j < mat.GetLength(1); j++)
+            {
+                if (mat[i, j] == 0)
+                {
+                    continue;
+                }
+
+                if (!forward.TryGetValue(j, out var l1))
+                {
+                    l1 = new() { i };
+                    forward.Add(j, l1);
+                }
+                else
+                {
+                    l1.Add(i);
+                }
+
+                if (!backward.TryGetValue(i, out var l2))
+                {
+                    l2 = new() { j };
+                    backward.Add(i, l2);
+                }
+                else
+                {
+                    l2.Add(j);
+                }
+            }
+        }
+
+        return (forward, backward);
+    }
+}
diff --git a/modules/Nncase.Modules.CPU/packages.lock.json b/modules/Nncase.Modules.CPU/packages.lock.json
new file mode 100644
index 0000000000..2976ad6007
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/packages.lock.json
@@ -0,0 +1,334 @@
+{
+  "version": 2,
+  "dependencies": {
+    "net7.0": {
+      "Razor.Templating.Core": {
+        "type": "Direct",
+        "requested": "[1.9.0, )",
+        "resolved": "1.9.0",
+        "contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
+      },
+      "StyleCop.Analyzers": {
+        "type": "Direct",
+        "requested": "[1.2.0-beta.435, )",
+        "resolved": "1.2.0-beta.435",
+        "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
+        "dependencies": {
+          "StyleCop.Analyzers.Unstable": "1.2.0.435"
+        }
+      },
+      "Google.OrTools.runtime.linux-arm64": {
+        "type": "Transitive",
+        "resolved": "9.4.1874",
+        "contentHash": "Z46ndZcZa2Lt5b76xU9kxVYbPLg/LfuMufhUVsu3Qo3L7Bibf7WXd9j7RRldjnuv8RIHWTqb0b+2FwwMxs0c5A=="
+      },
+      "Google.OrTools.runtime.linux-x64": {
+        "type": "Transitive",
+        "resolved": "9.4.1874",
+        "contentHash": "zGeDb8FuvP9HXjrsU7krVXtSDFpR+DUGNEsH51k94jL9tzf2vWYI8+WUBRHZ/cGe50dpLr+vIjfcNo3gFyOpkQ=="
+      },
+      "Google.OrTools.runtime.osx-arm64": {
+        "type": "Transitive",
+        "resolved": "9.4.1874",
+        "contentHash": "Wo0ZfDaH6DhiQw0jZm4HWJm/oPGPpWNwOLUz+EYaoH3MLtocSxItHGQj/Ta3HyhXnYNOv+TliAH8L+8RCXu/2w=="
+      },
+      "Google.OrTools.runtime.osx-x64": {
+        "type": "Transitive",
+        "resolved": "9.4.1874",
+        "contentHash": "IAfGgKR1og6vU87axK1d37Ak/4jy8B4NMoElovG/KZc/2UY+cJEAQDA709UMegtI4lBhuxTWFNUiHQYmRIB9yQ=="
+      },
+      "Google.OrTools.runtime.win-x64": {
+        "type": "Transitive",
+        "resolved": "9.4.1874",
+        "contentHash": "fUs5qDnZA6itygolcX6nPuachQkY9CVvQbakIzIiRAWKcaj8umQAbFdGwbkyzp3qp34BKW5mtPVsmMyfQBBjOQ=="
+      },
+      "libortki": {
+        "type": "Transitive",
+        "resolved": "0.0.2",
+        "contentHash": "svfuG5mxGY/QC/5DVheHOCELmdSP90RtxQ73j23KarPXZ9ZXW+7v1l5J77hGDyQbEh1BGrnGgKBlyn76RauGHg==",
+        "dependencies": {
"libortki-linux": "0.0.2", + "libortki-osx": "0.0.2", + "libortki-osx-arm64": "0.0.2", + "libortki-win": "0.0.2" + } + }, + "libortki-linux": { + "type": "Transitive", + "resolved": "0.0.2", + "contentHash": "b04LWD4lgGy60tys3hPFhnUpgWDM6dN5r1PI7GOcPj8VupXCaI70LKNQ5/5twbDE6rkowOGanVTw0S2wBGBqBQ==" + }, + "libortki-osx": { + "type": "Transitive", + "resolved": "0.0.2", + "contentHash": "O6Q9GLULkDkZEPAZJVKLPH0ROXGVOE7BxuddgOcHNK2oiTEM7wIRnzp2OIlYgLpaOLyxJMisbGOhtWgdzt2Wng==" + }, + "libortki-osx-arm64": { + "type": "Transitive", + "resolved": "0.0.2", + "contentHash": "4Qn2dirJmRicnUG945oWpq7HVGwgqCKKxYPMISv/MRvmpZBbXrZ1cVvRaF8WwTu4XXgfKTa1sLv+i8zLifUMeQ==" + }, + "libortki-win": { + "type": "Transitive", + "resolved": "0.0.2", + "contentHash": "HAoROgAKn8XBun11X43HZuspKlo5JGy8/OYw5IUPo7FVh5TCaPrLjGmyGYYZ2dqLlv31yv/b6s254PIRGn95cA==" + }, + "Microsoft.Extensions.Configuration.Abstractions": { + "type": "Transitive", + "resolved": "8.0.0", + "contentHash": "3lE/iLSutpgX1CC0NOW70FJoGARRHbyKmG7dc0klnUZ9Dd9hS6N/POPWhKhMLCEuNN5nXEY5agmlFtH562vqhQ==", + "dependencies": { + "Microsoft.Extensions.Primitives": "8.0.0" + } + }, + "Microsoft.Extensions.DependencyInjection.Abstractions": { + "type": "Transitive", + "resolved": "8.0.1", + "contentHash": "fGLiCRLMYd00JYpClraLjJTNKLmMJPnqxMaiRzEBIIvevlzxz33mXy39Lkd48hu1G+N21S7QpaO5ZzKsI6FRuA==" + }, + "Microsoft.Extensions.Diagnostics.Abstractions": { + "type": "Transitive", + "resolved": "8.0.0", + "contentHash": "JHYCQG7HmugNYUhOl368g+NMxYE/N/AiclCYRNlgCY9eVyiBkOHMwK4x60RYMxv9EL3+rmj1mqHvdCiPpC+D4Q==", + "dependencies": { + "Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0", + "Microsoft.Extensions.Options": "8.0.0", + "System.Diagnostics.DiagnosticSource": "8.0.0" + } + }, + "Microsoft.Extensions.FileProviders.Abstractions": { + "type": "Transitive", + "resolved": "8.0.0", + "contentHash": "ZbaMlhJlpisjuWbvXr4LdAst/1XxH3vZ6A0BsgTphZ2L4PGuxRLz7Jr/S7mkAAnOn78Vu0fKhEgNF5JO3zfjqQ==", + "dependencies": { + "Microsoft.Extensions.Primitives": "8.0.0" + } + }, + "Microsoft.Extensions.Primitives": { + "type": "Transitive", + "resolved": "8.0.0", + "contentHash": "bXJEZrW9ny8vjMF1JV253WeLhpEVzFo1lyaZu1vQ4ZxWUlVvknZ/+ftFgVheLubb4eZPSwwxBeqS1JkCOjxd8g==" + }, + "NetFabric.Hyperlinq.Abstractions": { + "type": "Transitive", + "resolved": "1.3.0", + "contentHash": "WXnEcGwmXfa8gW9N2MlcaPNUzM3NLMwnAhacbtH554F8YcoXbIkTB+uGa1Aa+9gyb/9JZgYVHnmADgJUKP52nA==" + }, + "StyleCop.Analyzers.Unstable": { + "type": "Transitive", + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" + }, + "System.Buffers": { + "type": "Transitive", + "resolved": "4.5.1", + "contentHash": "Rw7ijyl1qqRS0YQD/WycNst8hUUMgrMH4FCn1nNm27M4VxchZ1js3fVjQaANHO5f3sN4isvP4a+Met9Y4YomAg==" + }, + "System.Diagnostics.DiagnosticSource": { + "type": "Transitive", + "resolved": "8.0.0", + "contentHash": "c9xLpVz6PL9lp/djOWtk5KPDZq3cSYpmXoJQY524EOtuFl5z9ZtsotpsyrDW40U1DRnQSYvcPKEUV0X//u6gkQ==" + }, + "System.Runtime.CompilerServices.Unsafe": { + "type": "Transitive", + "resolved": "5.0.0", + "contentHash": "ZD9TMpsmYJLrxbbmdvhwt9YEgG5WntEnZ/d1eH8JBX9LBp+Ju8BSBhUGbZMNVHHomWo2KVImJhTDl2hIgw/6MA==" + }, + "nncase.codegen": { + "type": "Project", + "dependencies": { + "Extension.Mathematics": "[1.2.12, )", + "Nncase.Core": "[1.0.0, )", + "Nncase.IO": "[1.0.0, )", + "Razor.Templating.Core": "[1.9.0, )" + } + }, + "nncase.core": { + "type": "Project", + "dependencies": { + "CommunityToolkit.HighPerformance": "[8.2.2, )", + 
"DryIoc.dll": "[5.3.1, )", + "GiGraph.Dot": "[2.0.0, )", + "Microsoft.Extensions.Hosting.Abstractions": "[8.0.0, )", + "Microsoft.Extensions.Logging.Abstractions": "[8.0.1, )", + "Microsoft.Extensions.Options": "[8.0.2, )", + "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", + "System.Reactive": "[6.0.0, )" + } + }, + "nncase.diagnostics": { + "type": "Project", + "dependencies": { + "Nncase.Core": "[1.0.0, )" + } + }, + "nncase.egraph": { + "type": "Project", + "dependencies": { + "GiGraph.Dot": "[2.0.0, )", + "Google.OrTools": "[9.4.1874, )", + "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "Nncase.Core": "[1.0.0, )", + "Nncase.Evaluator": "[1.0.0, )", + "Singulink.Collections.Weak": "[1.0.2, )" + } + }, + "nncase.evaluator": { + "type": "Project", + "dependencies": { + "Nncase.Core": "[1.0.0, )", + "OrtKISharp": "[0.0.2, )" + } + }, + "nncase.graph": { + "type": "Project", + "dependencies": { + "Nncase.Core": "[1.0.0, )", + "Nncase.Evaluator": "[1.0.0, )" + } + }, + "nncase.io": { + "type": "Project" + }, + "nncase.modules.stackvm": { + "type": "Project", + "dependencies": { + "Nncase.CodeGen": "[1.0.0, )", + "Nncase.Passes": "[1.0.0, )" + } + }, + "nncase.passes": { + "type": "Project", + "dependencies": { + "Nncase.Core": "[1.0.0, )", + "Nncase.EGraph": "[1.0.0, )", + "Nncase.Evaluator": "[1.0.0, )", + "Nncase.Graph": "[1.0.0, )" + } + }, + "nncase.schedule": { + "type": "Project", + "dependencies": { + "Google.OrTools": "[9.4.1874, )", + "Nncase.Core": "[1.0.0, )", + "Nncase.Passes": "[1.0.0, )" + } + }, + "CommunityToolkit.HighPerformance": { + "type": "CentralTransitive", + "requested": "[8.2.2, )", + "resolved": "8.2.2", + "contentHash": "+zIp8d3sbtYaRbM6hqDs4Ui/z34j7DcUmleruZlYLE4CVxXq+MO8XJyIs42vzeTYFX+k0Iq1dEbBUnQ4z/Gnrw==" + }, + "DryIoc.dll": { + "type": "CentralTransitive", + "requested": "[5.3.1, )", + "resolved": "5.3.1", + "contentHash": "E3zclUh2CIBks1t2uBD1k18pyGFJ1YSKCrbCDbB7qCdl2RAB+k68AyDpjeplhF1ot2XPV82AgyCWBXMf0ggL1g==" + }, + "Extension.Mathematics": { + "type": "CentralTransitive", + "requested": "[1.2.12, )", + "resolved": "1.2.12", + "contentHash": "D4mn5Cab4ztPLJ0V8uMErDrO/Y61098nwrvyIOLZymVAYOQcwP1vomVWKbTagf1aPU3cX5Q7adZtQEQwOy6XEg==" + }, + "GiGraph.Dot": { + "type": "CentralTransitive", + "requested": "[2.0.0, )", + "resolved": "2.0.0", + "contentHash": "ThvS2mQVveSkTMUm04tMbRYzu1XFPV8xBHISrUMp02APjhv9IRbLu3v3upTPCywORx2Ds/c6AqEUL1WU6kPfuQ==" + }, + "Google.OrTools": { + "type": "CentralTransitive", + "requested": "[9.4.1874, )", + "resolved": "9.4.1874", + "contentHash": "jqRoI+pYlym+fhoU25u+13oti5h+772bllQ9zDitTVMclDXVTiG6pxzvmYO74wnADBMdpb2SQlgiNQxoNk5dlA==", + "dependencies": { + "Google.OrTools.runtime.linux-arm64": "9.4.1874", + "Google.OrTools.runtime.linux-x64": "9.4.1874", + "Google.OrTools.runtime.osx-arm64": "9.4.1874", + "Google.OrTools.runtime.osx-x64": "9.4.1874", + "Google.OrTools.runtime.win-x64": "9.4.1874", + "Google.Protobuf": "3.19.4" + } + }, + "Google.Protobuf": { + "type": "CentralTransitive", + "requested": "[3.19.4, )", + "resolved": "3.19.4", + "contentHash": "fd07/ykL4O4FhqrZIELm5lmiyOHfdPg9+o+hWr6tcfRdS7tHXnImg/2wtogLzlW2eEmr0J7j6ZrZvaWOLiJbxQ==" + }, + "Microsoft.Extensions.Hosting.Abstractions": { + "type": "CentralTransitive", + "requested": "[8.0.0, )", + "resolved": "8.0.0", + "contentHash": "AG7HWwVRdCHlaA++1oKDxLsXIBxmDpMPb3VoyOoAghEWnkUvEAdYQUwnV4jJbAaa/nMYNiEh5ByoLauZBEiovg==", + "dependencies": { + "Microsoft.Extensions.Configuration.Abstractions": "8.0.0", + 
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0", + "Microsoft.Extensions.Diagnostics.Abstractions": "8.0.0", + "Microsoft.Extensions.FileProviders.Abstractions": "8.0.0", + "Microsoft.Extensions.Logging.Abstractions": "8.0.0" + } + }, + "Microsoft.Extensions.Logging.Abstractions": { + "type": "CentralTransitive", + "requested": "[8.0.1, )", + "resolved": "8.0.1", + "contentHash": "RIFgaqoaINxkM2KTOw72dmilDmTrYA0ns2KW4lDz4gZ2+o6IQ894CzmdL3StM2oh7QQq44nCWiqKqc4qUI9Jmg==", + "dependencies": { + "Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.1" + } + }, + "Microsoft.Extensions.Options": { + "type": "CentralTransitive", + "requested": "[8.0.2, )", + "resolved": "8.0.2", + "contentHash": "dWGKvhFybsaZpGmzkGCbNNwBD1rVlWzrZKANLW/CcbFJpCEceMCGzT7zZwHOGBCbwM0SzBuceMj5HN1LKV1QqA==", + "dependencies": { + "Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0", + "Microsoft.Extensions.Primitives": "8.0.0" + } + }, + "NetFabric.Hyperlinq": { + "type": "CentralTransitive", + "requested": "[3.0.0-beta48, )", + "resolved": "3.0.0-beta48", + "contentHash": "oYUhXvxNS8bBJWqNkvx5g8y0P/0LtyqS2pN0w4OWjVDNWEpLbdbvPy9w/9z1n2PrqIjX3jxUsEnoCmxxGnI3gw==", + "dependencies": { + "NetFabric.Hyperlinq.Abstractions": "1.3.0", + "System.Buffers": "4.5.1", + "System.Runtime.CompilerServices.Unsafe": "5.0.0" + } + }, + "OrtKISharp": { + "type": "CentralTransitive", + "requested": "[0.0.2, )", + "resolved": "0.0.2", + "contentHash": "q8j0yR5836Zhv9WB9BFkQt1UaEFyibq8bqJcTiULlILF6/sz8z7Wy2N8sgYdDKsdW25zncIz7j6IDbKM5ynePg==", + "dependencies": { + "libortki": "0.0.2" + } + }, + "Singulink.Collections.Weak": { + "type": "CentralTransitive", + "requested": "[1.0.2, )", + "resolved": "1.0.2", + "contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA==" + }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, + "System.Reactive": { + "type": "CentralTransitive", + "requested": "[6.0.0, )", + "resolved": "6.0.0", + "contentHash": "31kfaW4ZupZzPsI5PVe77VhnvFF55qgma7KZr/E0iFTs6fmdhhG8j0mgEx620iLTey1EynOkEfnyTjtNEpJzGw==" + } + } + } +} \ No newline at end of file diff --git a/modules/Nncase.Modules.StackVM/StackVMModule.cs b/modules/Nncase.Modules.StackVM/StackVMModule.cs index 44d5c616e8..fcbeb7a0f5 100644 --- a/modules/Nncase.Modules.StackVM/StackVMModule.cs +++ b/modules/Nncase.Modules.StackVM/StackVMModule.cs @@ -14,6 +14,5 @@ internal class StackVMModule : IApplicationPart { public void ConfigureServices(IRegistrator registrator) { - registrator.Register(reuse: Reuse.Singleton); } } diff --git a/modules/Nncase.Modules.StackVM/packages.lock.json b/modules/Nncase.Modules.StackVM/packages.lock.json index 8820c05237..24f44d41ca 100644 --- a/modules/Nncase.Modules.StackVM/packages.lock.json +++ b/modules/Nncase.Modules.StackVM/packages.lock.json @@ -69,33 +69,40 @@ }, "Microsoft.Extensions.Configuration.Abstractions": { "type": "Transitive", - "resolved": "6.0.0", - "contentHash": "qWzV9o+ZRWq+pGm+1dF+R7qTgTYoXvbyowRoBxQJGfqTpqDun2eteerjRQhq5PQ/14S+lqto3Ft4gYaRyl4rdQ==", + "resolved": "8.0.0", + "contentHash": "3lE/iLSutpgX1CC0NOW70FJoGARRHbyKmG7dc0klnUZ9Dd9hS6N/POPWhKhMLCEuNN5nXEY5agmlFtH562vqhQ==", "dependencies": { - "Microsoft.Extensions.Primitives": "6.0.0" + "Microsoft.Extensions.Primitives": "8.0.0" } }, 
"Microsoft.Extensions.DependencyInjection.Abstractions": { "type": "Transitive", - "resolved": "6.0.0", - "contentHash": "xlzi2IYREJH3/m6+lUrQlujzX8wDitm4QGnUu6kUXTQAWPuZY8i+ticFJbzfqaetLA6KR/rO6Ew/HuYD+bxifg==" + "resolved": "8.0.1", + "contentHash": "fGLiCRLMYd00JYpClraLjJTNKLmMJPnqxMaiRzEBIIvevlzxz33mXy39Lkd48hu1G+N21S7QpaO5ZzKsI6FRuA==" }, - "Microsoft.Extensions.FileProviders.Abstractions": { + "Microsoft.Extensions.Diagnostics.Abstractions": { "type": "Transitive", - "resolved": "6.0.0", - "contentHash": "0pd4/fho0gC12rQswaGQxbU34jOS1TPS8lZPpkFCH68ppQjHNHYle9iRuHeev1LhrJ94YPvzcRd8UmIuFk23Qw==", + "resolved": "8.0.0", + "contentHash": "JHYCQG7HmugNYUhOl368g+NMxYE/N/AiclCYRNlgCY9eVyiBkOHMwK4x60RYMxv9EL3+rmj1mqHvdCiPpC+D4Q==", "dependencies": { - "Microsoft.Extensions.Primitives": "6.0.0" + "Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0", + "Microsoft.Extensions.Options": "8.0.0", + "System.Diagnostics.DiagnosticSource": "8.0.0" } }, - "Microsoft.Extensions.Primitives": { + "Microsoft.Extensions.FileProviders.Abstractions": { "type": "Transitive", - "resolved": "6.0.0", - "contentHash": "9+PnzmQFfEFNR9J2aDTfJGGupShHjOuGw4VUv+JB044biSHrnmCIMD+mJHmb2H7YryrfBEXDurxQ47gJZdCKNQ==", + "resolved": "8.0.0", + "contentHash": "ZbaMlhJlpisjuWbvXr4LdAst/1XxH3vZ6A0BsgTphZ2L4PGuxRLz7Jr/S7mkAAnOn78Vu0fKhEgNF5JO3zfjqQ==", "dependencies": { - "System.Runtime.CompilerServices.Unsafe": "6.0.0" + "Microsoft.Extensions.Primitives": "8.0.0" } }, + "Microsoft.Extensions.Primitives": { + "type": "Transitive", + "resolved": "8.0.0", + "contentHash": "bXJEZrW9ny8vjMF1JV253WeLhpEVzFo1lyaZu1vQ4ZxWUlVvknZ/+ftFgVheLubb4eZPSwwxBeqS1JkCOjxd8g==" + }, "NetFabric.Hyperlinq.Abstractions": { "type": "Transitive", "resolved": "1.3.0", @@ -111,10 +118,15 @@ "resolved": "4.5.1", "contentHash": "Rw7ijyl1qqRS0YQD/WycNst8hUUMgrMH4FCn1nNm27M4VxchZ1js3fVjQaANHO5f3sN4isvP4a+Met9Y4YomAg==" }, + "System.Diagnostics.DiagnosticSource": { + "type": "Transitive", + "resolved": "8.0.0", + "contentHash": "c9xLpVz6PL9lp/djOWtk5KPDZq3cSYpmXoJQY524EOtuFl5z9ZtsotpsyrDW40U1DRnQSYvcPKEUV0X//u6gkQ==" + }, "System.Runtime.CompilerServices.Unsafe": { "type": "Transitive", - "resolved": "6.0.0", - "contentHash": "/iUeP3tq1S0XdNNoMz5C9twLSrM/TH+qElHkXWaPvuNOt+99G75NrV0OS2EqHx5wMN7popYjpc8oTjC1y16DLg==" + "resolved": "5.0.0", + "contentHash": "ZD9TMpsmYJLrxbbmdvhwt9YEgG5WntEnZ/d1eH8JBX9LBp+Ju8BSBhUGbZMNVHHomWo2KVImJhTDl2hIgw/6MA==" }, "nncase.codegen": { "type": "Project", @@ -128,15 +140,15 @@ "nncase.core": { "type": "Project", "dependencies": { + "CommunityToolkit.HighPerformance": "[8.2.2, )", "DryIoc.dll": "[5.3.1, )", "GiGraph.Dot": "[2.0.0, )", - "Microsoft.Extensions.Hosting.Abstractions": "[6.0.0, )", - "Microsoft.Extensions.Logging.Abstractions": "[6.0.0, )", - "Microsoft.Extensions.Options": "[6.0.0, )", - "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", + "Microsoft.Extensions.Hosting.Abstractions": "[8.0.0, )", + "Microsoft.Extensions.Logging.Abstractions": "[8.0.1, )", + "Microsoft.Extensions.Options": "[8.0.2, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", "System.CommandLine": "[2.0.0-beta4.22272.1, )", - "System.Reactive": "[5.0.0, )" + "System.Reactive": "[6.0.0, )" } }, "nncase.egraph": { @@ -176,6 +188,12 @@ "Nncase.Graph": "[1.0.0, )" } }, + "CommunityToolkit.HighPerformance": { + "type": "CentralTransitive", + "requested": "[8.2.2, )", + "resolved": "8.2.2", + "contentHash": "+zIp8d3sbtYaRbM6hqDs4Ui/z34j7DcUmleruZlYLE4CVxXq+MO8XJyIs42vzeTYFX+k0Iq1dEbBUnQ4z/Gnrw==" + }, "DryIoc.dll": { "type": 
"CentralTransitive", "requested": "[5.3.1, )", @@ -216,37 +234,36 @@ }, "Microsoft.Extensions.Hosting.Abstractions": { "type": "CentralTransitive", - "requested": "[6.0.0, )", - "resolved": "6.0.0", - "contentHash": "GcT5l2CYXL6Sa27KCSh0TixsRfADUgth+ojQSD5EkzisZxmGFh7CwzkcYuGwvmXLjr27uWRNrJ2vuuEjMhU05Q==", + "requested": "[8.0.0, )", + "resolved": "8.0.0", + "contentHash": "AG7HWwVRdCHlaA++1oKDxLsXIBxmDpMPb3VoyOoAghEWnkUvEAdYQUwnV4jJbAaa/nMYNiEh5ByoLauZBEiovg==", "dependencies": { - "Microsoft.Extensions.Configuration.Abstractions": "6.0.0", - "Microsoft.Extensions.DependencyInjection.Abstractions": "6.0.0", - "Microsoft.Extensions.FileProviders.Abstractions": "6.0.0" + "Microsoft.Extensions.Configuration.Abstractions": "8.0.0", + "Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0", + "Microsoft.Extensions.Diagnostics.Abstractions": "8.0.0", + "Microsoft.Extensions.FileProviders.Abstractions": "8.0.0", + "Microsoft.Extensions.Logging.Abstractions": "8.0.0" } }, "Microsoft.Extensions.Logging.Abstractions": { "type": "CentralTransitive", - "requested": "[6.0.0, )", - "resolved": "6.0.0", - "contentHash": "/HggWBbTwy8TgebGSX5DBZ24ndhzi93sHUBDvP1IxbZD7FDokYzdAr6+vbWGjw2XAfR2EJ1sfKUotpjHnFWPxA==" + "requested": "[8.0.1, )", + "resolved": "8.0.1", + "contentHash": "RIFgaqoaINxkM2KTOw72dmilDmTrYA0ns2KW4lDz4gZ2+o6IQ894CzmdL3StM2oh7QQq44nCWiqKqc4qUI9Jmg==", + "dependencies": { + "Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.1" + } }, "Microsoft.Extensions.Options": { "type": "CentralTransitive", - "requested": "[6.0.0, )", - "resolved": "6.0.0", - "contentHash": "dzXN0+V1AyjOe2xcJ86Qbo233KHuLEY0njf/P2Kw8SfJU+d45HNS2ctJdnEnrWbM9Ye2eFgaC5Mj9otRMU6IsQ==", + "requested": "[8.0.2, )", + "resolved": "8.0.2", + "contentHash": "dWGKvhFybsaZpGmzkGCbNNwBD1rVlWzrZKANLW/CcbFJpCEceMCGzT7zZwHOGBCbwM0SzBuceMj5HN1LKV1QqA==", "dependencies": { - "Microsoft.Extensions.DependencyInjection.Abstractions": "6.0.0", - "Microsoft.Extensions.Primitives": "6.0.0" + "Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0", + "Microsoft.Extensions.Primitives": "8.0.0" } }, - "Microsoft.Toolkit.HighPerformance": { - "type": "CentralTransitive", - "requested": "[7.1.1, )", - "resolved": "7.1.1", - "contentHash": "TRnvDpZPXO30hTOtjfLw6Y9BtTKtTpzk9lefeh4RMCaUihWrVKQR454nYH4/mMJAh+LXqfAPyk0kfkJs0Amopw==" - }, "NetFabric.Hyperlinq": { "type": "CentralTransitive", "requested": "[3.0.0-beta48, )", @@ -287,9 +304,9 @@ }, "System.Reactive": { "type": "CentralTransitive", - "requested": "[5.0.0, )", - "resolved": "5.0.0", - "contentHash": "erBZjkQHWL9jpasCE/0qKAryzVBJFxGHVBAvgRN1bzM0q2s1S4oYREEEL0Vb+1kA/6BKb5FjUZMp5VXmy+gzkQ==" + "requested": "[6.0.0, )", + "resolved": "6.0.0", + "contentHash": "31kfaW4ZupZzPsI5PVe77VhnvFF55qgma7KZr/E0iFTs6fmdhhG8j0mgEx620iLTey1EynOkEfnyTjtNEpJzGw==" } } } diff --git a/modules/k210/include/nncase/runtime/k210/op_reader.h b/modules/k210/include/nncase/runtime/k210/op_reader.h index 99c84831ce..04df913581 100644 --- a/modules/k210/include/nncase/runtime/k210/op_reader.h +++ b/modules/k210/include/nncase/runtime/k210/op_reader.h @@ -25,7 +25,7 @@ class NNCASE_MODULES_K210_API op_visitor { ~op_visitor() = default; - result visit(gsl::span text) noexcept; + result visit(std::span text) noexcept; virtual result visit(NNCASE_UNUSED const kpu_download_options &op) noexcept { diff --git a/modules/k210/src/runtime/op_reader.cpp b/modules/k210/src/runtime/op_reader.cpp index 119f7f75fb..638ec314a9 100644 --- a/modules/k210/src/runtime/op_reader.cpp +++ 
b/modules/k210/src/runtime/op_reader.cpp @@ -37,7 +37,7 @@ result op_visitor::next() noexcept { return err(nncase_k210_errc::k210_illegal_instruction); } -result op_visitor::visit(gsl::span text) noexcept { +result op_visitor::visit(std::span text) noexcept { reader_ = span_reader(text); interrupted_ = false; diff --git a/modules/k210/src/runtime/ops/copy.cpp b/modules/k210/src/runtime/ops/copy.cpp index 0a0ac60c54..df0a2fda78 100644 --- a/modules/k210/src/runtime/ops/copy.cpp +++ b/modules/k210/src/runtime/ops/copy.cpp @@ -28,7 +28,7 @@ result k210_runtime_function::visit(const copy_options &op) noexcept { runtime_shape_t in_strides{op.in_strides.begin(), op.in_strides.end()}; runtime_shape_t out_strides{op.out_strides.begin(), op.out_strides.end()}; return kernels::copy(op.input.datatype, - reinterpret_cast(input.data()), - reinterpret_cast(output.data()), in_shape, + reinterpret_cast(input.data()), + reinterpret_cast(output.data()), in_shape, in_strides, out_strides); } diff --git a/modules/k210/src/runtime/runtime_function.cpp b/modules/k210/src/runtime/runtime_function.cpp index 0746dd6181..89dc980d32 100644 --- a/modules/k210/src/runtime/runtime_function.cpp +++ b/modules/k210/src/runtime/runtime_function.cpp @@ -77,10 +77,10 @@ result k210_runtime_function::invoke_core() noexcept { return ok(); } -result> +result> k210_runtime_function::memory_at(const memory_range &mrange) noexcept { #define ID_NOT_FOUND ((size_t)-1) - gsl::byte *base; + std::byte *base; switch (mrange.memory_location) { case mem_input: { size_t id = ID_NOT_FOUND; @@ -93,7 +93,7 @@ k210_runtime_function::memory_at(const memory_range &mrange) noexcept { if (id != ID_NOT_FOUND) { try_var(tensor, device_input_tensor(id)); - base = reinterpret_cast( + base = reinterpret_cast( static_cast(tensor.impl()) ->memory_block() .virtual_address - @@ -122,7 +122,7 @@ k210_runtime_function::memory_at(const memory_range &mrange) noexcept { break; } case mem_rdata: - base = const_cast(module().rdata().data()); + base = const_cast(module().rdata().data()); break; case mem_data: base = module().data().data(); diff --git a/modules/k210/src/runtime/runtime_function.h b/modules/k210/src/runtime/runtime_function.h index b4abdc165c..bd2af63582 100644 --- a/modules/k210/src/runtime/runtime_function.h +++ b/modules/k210/src/runtime/runtime_function.h @@ -47,10 +47,10 @@ class k210_runtime_function : public runtime_function, private op_visitor { result visit(const copy_options &op) noexcept override; private: - result> memory_at(const memory_range &mrange) noexcept; + result> memory_at(const memory_range &mrange) noexcept; private: - gsl::span text_; + std::span text_; }; END_NS_NNCASE_RT_MODULE diff --git a/modules/k210/src/runtime/runtime_module.cpp b/modules/k210/src/runtime/runtime_module.cpp index ec0c8cdbbd..8d1af08b33 100644 --- a/modules/k210/src/runtime/runtime_module.cpp +++ b/modules/k210/src/runtime/runtime_module.cpp @@ -42,7 +42,7 @@ result k210_runtime_module::initialize_before_functions( assert(context.is_section_pinned()); auto data_pool = mempool(mem_data); if (data_pool.size) { - data_.reset(new (std::nothrow) gsl::byte[data_pool.size]); + data_.reset(new (std::nothrow) std::byte[data_pool.size]); if (!data_) return err(std::errc::not_enough_memory); } @@ -57,21 +57,21 @@ result k210_runtime_module::initialize_before_functions( return ok(); } -gsl::span k210_runtime_module::data() const noexcept { +std::span k210_runtime_module::data() const noexcept { return {data_.get(), mempool(mem_data).size}; } -gsl::span 
k210_runtime_module::kpu_ram() noexcept { - gsl::byte *base; +std::span k210_runtime_module::kpu_ram() noexcept { + std::byte *base; #ifdef NNCASE_SIMULATOR base = kpu_ram_.data(); #else - base = reinterpret_cast(AI_IO_BASE_ADDR); + base = reinterpret_cast(AI_IO_BASE_ADDR); #endif return {base, KPU_RAM_SIZE}; } -gsl::span k210_runtime_module::rdata() const noexcept { +std::span k210_runtime_module::rdata() const noexcept { return rdata_; } diff --git a/modules/k210/src/runtime/runtime_module.h b/modules/k210/src/runtime/runtime_module.h index 7a62836c27..61e5965827 100644 --- a/modules/k210/src/runtime/runtime_module.h +++ b/modules/k210/src/runtime/runtime_module.h @@ -20,9 +20,9 @@ BEGIN_NS_NNCASE_RT_MODULE(k210) class k210_runtime_module : public runtime_module { public: - gsl::span data() const noexcept; - gsl::span rdata() const noexcept; - gsl::span kpu_ram() noexcept; + std::span data() const noexcept; + std::span rdata() const noexcept; + std::span kpu_ram() noexcept; #if !NNCASE_SIMULATOR uint32_t dma_ch() const noexcept { return dma_ch_; } @@ -35,11 +35,11 @@ class k210_runtime_module : public runtime_module { create_function() noexcept override; private: - std::unique_ptr data_; - gsl::span rdata_; - gsl::span text_; + std::unique_ptr data_; + std::span rdata_; + std::span text_; #ifdef NNCASE_SIMULATOR - std::array kpu_ram_; + std::array kpu_ram_; #else uint32_t dma_ch_; #endif diff --git a/modules/k210/src/runtime/shared_runtime_tensor.platform.cpp b/modules/k210/src/runtime/shared_runtime_tensor.platform.cpp index e2be1ae32f..6fffd2b96f 100644 --- a/modules/k210/src/runtime/shared_runtime_tensor.platform.cpp +++ b/modules/k210/src/runtime/shared_runtime_tensor.platform.cpp @@ -49,7 +49,7 @@ physical_memory_block::operator=(physical_memory_block &&other) noexcept { void physical_memory_block::free( NNCASE_UNUSED host_memory_block &block) noexcept { if (owned) - delete[] reinterpret_cast(physical_address + IOMEM); + delete[] reinterpret_cast(physical_address + IOMEM); physical_address = 0; owned = false; } @@ -70,7 +70,7 @@ physical_memory_block::acknowledge(host_memory_block &block) noexcept { result physical_memory_block::allocate(host_memory_block &block) noexcept { - auto buffer = new (std::nothrow) gsl::byte[block.size_bytes]; + auto buffer = new (std::nothrow) std::byte[block.size_bytes]; CHECK_WITH_ERR(buffer, std::errc::not_enough_memory); block.virtual_address = reinterpret_cast(buffer); block.physical_block.physical_address = block.virtual_address - IOMEM; diff --git a/modules/k210/src/transforms/k210/kpu_conv2d.cpp b/modules/k210/src/transforms/k210/kpu_conv2d.cpp index 0643695396..45189fa20f 100644 --- a/modules/k210/src/transforms/k210/kpu_conv2d.cpp +++ b/modules/k210/src/transforms/k210/kpu_conv2d.cpp @@ -143,9 +143,9 @@ auto quantize_act(quantizer &quantizer, float act_in_scale, fused_unary::compile_graph(fu->subgraph(), builder); auto buf = ss.str(); - std::vector body( - reinterpret_cast(buf.data()), - reinterpret_cast(buf.data() + buf.size())); + std::vector body( + reinterpret_cast(buf.data()), + reinterpret_cast(buf.data() + buf.size())); kernels::nnil_unary_method(samples_x.data(), samples_y.data(), samples_count, body) .unwrap_or_throw(); diff --git a/modules/vulkan/include/nncase/runtime/vulkan/op_reader.h b/modules/vulkan/include/nncase/runtime/vulkan/op_reader.h index 6326936789..1a409847d5 100644 --- a/modules/vulkan/include/nncase/runtime/vulkan/op_reader.h +++ b/modules/vulkan/include/nncase/runtime/vulkan/op_reader.h @@ -25,7 +25,7 @@ class 
NNCASE_MODULES_VULKAN_API op_visitor { ~op_visitor() = default; - result visit(gsl::span text) noexcept; + result visit(std::span text) noexcept; virtual result visit(NNCASE_UNUSED const ldbuf_op_t &op) noexcept { return ok(); diff --git a/modules/vulkan/src/runtime/op_reader.cpp b/modules/vulkan/src/runtime/op_reader.cpp index 94e7e32d98..42f91d564a 100644 --- a/modules/vulkan/src/runtime/op_reader.cpp +++ b/modules/vulkan/src/runtime/op_reader.cpp @@ -42,7 +42,7 @@ result op_visitor::next() noexcept { return err(std::errc::operation_not_supported); } -result op_visitor::visit(gsl::span text) noexcept { +result op_visitor::visit(std::span text) noexcept { reader_ = span_reader(text); interrupted_ = false; diff --git a/modules/vulkan/src/runtime/runtime_function.h b/modules/vulkan/src/runtime/runtime_function.h index 63da9411af..a68fd0a0ce 100644 --- a/modules/vulkan/src/runtime/runtime_function.h +++ b/modules/vulkan/src/runtime/runtime_function.h @@ -71,7 +71,7 @@ class vulkan_runtime_function : public runtime_function, private op_visitor { private: uint32_t input_pool_size_; uint32_t output_pool_size_; - gsl::span text_; + std::span text_; vk::Buffer input_buffer_; vk::Buffer output_buffer_; vk::DeviceMemory input_mem_; diff --git a/modules/vulkan/src/runtime/runtime_module.h b/modules/vulkan/src/runtime/runtime_module.h index 59e8c98dec..dbd19f00c1 100644 --- a/modules/vulkan/src/runtime/runtime_module.h +++ b/modules/vulkan/src/runtime/runtime_module.h @@ -28,7 +28,7 @@ class vulkan_runtime_module : public runtime_module { vk::Buffer data() const noexcept { return data_buffer_; } vk::Buffer rdata() const noexcept { return {}; } - gsl::span shader() const noexcept { return shader_; } + std::span shader() const noexcept { return shader_; } vk::Device device() const noexcept { return ctx_->device(); } vk::CommandPool command_pool() const noexcept { return cmd_pool_; } @@ -71,8 +71,8 @@ class vulkan_runtime_module : public runtime_module { private: uint32_t descriptors_; uint32_t descriptor_sets_; - gsl::span text_; - gsl::span shader_; + std::span text_; + std::span shader_; vulkan_context *ctx_; vk::Buffer data_buffer_; vk::DeviceMemory data_mem_; diff --git a/nncase.sln b/nncase.sln index 77baf826c2..e1d5a9b5f1 100644 --- a/nncase.sln +++ b/nncase.sln @@ -77,7 +77,9 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Tests.TestFixture", EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Passes", "src\Nncase.Passes\Nncase.Passes.csproj", "{E6462E82-B48F-4AFA-AE34-725EF0A9CB42}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Nncase.Studio", "src\Nncase.Studio\Nncase.Studio.csproj", "{B9A09DA2-EF1A-4C0E-A0F5-427AFBB5C769}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Studio", "src\Nncase.Studio\Nncase.Studio.csproj", "{0E5BF964-B878-4BD6-8C84-FFE85E23994B}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Modules.CPU", "modules\Nncase.Modules.CPU\Nncase.Modules.CPU.csproj", "{6AEE2334-CCF4-464E-8C90-C6BC0D930327}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -173,10 +175,14 @@ Global {E6462E82-B48F-4AFA-AE34-725EF0A9CB42}.Debug|Any CPU.Build.0 = Debug|Any CPU {E6462E82-B48F-4AFA-AE34-725EF0A9CB42}.Release|Any CPU.ActiveCfg = Release|Any CPU {E6462E82-B48F-4AFA-AE34-725EF0A9CB42}.Release|Any CPU.Build.0 = Release|Any CPU - {B9A09DA2-EF1A-4C0E-A0F5-427AFBB5C769}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {B9A09DA2-EF1A-4C0E-A0F5-427AFBB5C769}.Debug|Any CPU.Build.0 = 
Debug|Any CPU - {B9A09DA2-EF1A-4C0E-A0F5-427AFBB5C769}.Release|Any CPU.ActiveCfg = Release|Any CPU - {B9A09DA2-EF1A-4C0E-A0F5-427AFBB5C769}.Release|Any CPU.Build.0 = Release|Any CPU + {0E5BF964-B878-4BD6-8C84-FFE85E23994B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {0E5BF964-B878-4BD6-8C84-FFE85E23994B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {0E5BF964-B878-4BD6-8C84-FFE85E23994B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {0E5BF964-B878-4BD6-8C84-FFE85E23994B}.Release|Any CPU.Build.0 = Release|Any CPU + {6AEE2334-CCF4-464E-8C90-C6BC0D930327}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {6AEE2334-CCF4-464E-8C90-C6BC0D930327}.Debug|Any CPU.Build.0 = Debug|Any CPU + {6AEE2334-CCF4-464E-8C90-C6BC0D930327}.Release|Any CPU.ActiveCfg = Release|Any CPU + {6AEE2334-CCF4-464E-8C90-C6BC0D930327}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -207,7 +213,8 @@ Global {E365B1B1-4D13-4839-9763-A7A7C5F32FD4} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822} {98A03405-CA53-4EC4-9B18-94D1C8DF9453} = {E5A4516C-4080-4346-991D-57A7AA76ADA6} {E6462E82-B48F-4AFA-AE34-725EF0A9CB42} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822} - {B9A09DA2-EF1A-4C0E-A0F5-427AFBB5C769} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822} + {0E5BF964-B878-4BD6-8C84-FFE85E23994B} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822} + {6AEE2334-CCF4-464E-8C90-C6BC0D930327} = {9859F5E8-5504-4AFE-B955-9497A0A0CD66} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {9492E141-292E-4D60-9C6E-3738AB234DB2} diff --git a/python/common/pytype_utils.h b/python/common/pytype_utils.h index 0007f4f878..658dfd2860 100644 --- a/python/common/pytype_utils.h +++ b/python/common/pytype_utils.h @@ -127,7 +127,7 @@ strides_t to_rt_strides(size_t elemsize, return strides; } -std::vector to_py_shape(gsl::span value) { +std::vector to_py_shape(std::span value) { namespace py = pybind11; std::vector shape(value.size()); @@ -137,7 +137,7 @@ std::vector to_py_shape(gsl::span value) { } std::vector to_py_strides(size_t elemsize, - gsl::span value) { + std::span value) { namespace py = pybind11; std::vector strides(value.size()); diff --git a/python/common/runtime_tensor.inl b/python/common/runtime_tensor.inl index 11ec6280e6..2f61ed8453 100644 --- a/python/common/runtime_tensor.inl +++ b/python/common/runtime_tensor.inl @@ -23,28 +23,27 @@ py::class_(m, "TensorDesc") .def_readwrite("size", &tensor_desc::size); py::class_(m, "RuntimeTensor") - .def_static("from_numpy", - [](py::array arr) { - auto src_buffer = arr.request(); - auto datatype = from_dtype(arr); - auto tensor = - host_runtime_tensor::create( - datatype, to_rt_shape(src_buffer.shape), - to_rt_strides(src_buffer.itemsize, - src_buffer.strides), - gsl::make_span( - reinterpret_cast(src_buffer.ptr), - src_buffer.size * src_buffer.itemsize), - [=](gsl::byte *) { - if (!py::detail::is_py_shutdown()) { - py::gil_scoped_acquire gil; - arr.dec_ref(); - } - }) - .unwrap_or_throw(); - arr.inc_ref(); - return tensor; - }) + .def_static( + "from_numpy", + [](py::array arr) { + auto src_buffer = arr.request(); + auto datatype = from_dtype(arr); + auto tensor = + host_runtime_tensor::create( + datatype, to_rt_shape(src_buffer.shape), + to_rt_strides(src_buffer.itemsize, src_buffer.strides), + std::span(reinterpret_cast(src_buffer.ptr), + src_buffer.size * src_buffer.itemsize), + [=](std::byte *) { + if (!py::detail::is_py_shutdown()) { + py::gil_scoped_acquire gil; + arr.dec_ref(); + } + }) + .unwrap_or_throw(); + arr.inc_ref(); + 
return tensor; + }) .def("copy_to", [](runtime_tensor &from, runtime_tensor &to) { from.copy_to(to).unwrap_or_throw(); diff --git a/python/common/type_casters.h b/python/common/type_casters.h index 1c9a8f49c8..fb546eed1f 100644 --- a/python/common/type_casters.h +++ b/python/common/type_casters.h @@ -25,9 +25,9 @@ inline bool is_py_shutdown() { g_python_shutdown.load(std::memory_order_acquire); } -template <> struct type_caster> { +template <> struct type_caster> { public: - PYBIND11_TYPE_CASTER(gsl::span, _("bytes")); + PYBIND11_TYPE_CASTER(std::span, _("bytes")); bool load(handle src, bool) { if (!py::isinstance(src)) @@ -38,7 +38,7 @@ template <> struct type_caster> { if (PyBytes_AsStringAndSize( src.ptr(), reinterpret_cast(&buffer), &length)) return false; - value = {(const gsl::byte *)buffer, (size_t)length}; + value = {(const std::byte *)buffer, (size_t)length}; loader_life_support::add_patient(src); return true; } diff --git a/python/nncase/native/ffi.cpp b/python/nncase/native/ffi.cpp index b8c99ed966..13fd2f4ac6 100644 --- a/python/nncase/native/ffi.cpp +++ b/python/nncase/native/ffi.cpp @@ -317,7 +317,7 @@ PYBIND11_MODULE(_nncase, m) { py::class_(m, "Simulator") .def(py::init()) .def("load_model", - [](interpreter &interp, gsl::span buffer) { + [](interpreter &interp, std::span buffer) { interp.load_model(buffer, true).unwrap_or_throw(); }) .def_property_readonly("inputs_size", &interpreter::inputs_size) diff --git a/python/nncaseruntime/native/ffi.cpp b/python/nncaseruntime/native/ffi.cpp index 63393cc32f..374bb72c8d 100644 --- a/python/nncaseruntime/native/ffi.cpp +++ b/python/nncaseruntime/native/ffi.cpp @@ -81,7 +81,7 @@ PYBIND11_MODULE(_nncaseruntime, m) { py::class_(m, "Interpreter") .def(py::init()) .def("load_model", - [](interpreter &interp, gsl::span buffer) { + [](interpreter &interp, std::span buffer) { interp.load_model(buffer, true).unwrap_or_throw(); }) .def_property_readonly("inputs_size", &interpreter::inputs_size) diff --git a/requirements.test.txt b/requirements.test.txt index 2e5eabae04..00d2282ef7 100644 --- a/requirements.test.txt +++ b/requirements.test.txt @@ -1,21 +1,19 @@ -tensorflow==2.10.0 +tensorflow==2.16.1 +torch==2.2.1 +torchvision==0.17.1 +onnx==1.15.0 +onnx-simplifier==0.4.36 +onnxruntime==1.17.1 +ncnn==1.0.20240102 +toml==0.10.2 +numpy +imageio +protobuf matplotlib pillow opencv-python -onnx==1.12.0 -onnx-simplifier==0.3.6 -onnxoptimizer==0.2.6 -onnxruntime==1.12.0 -ncnn==1.0.20230816 -numpy==1.21.0 -torch==1.9.0 -torchvision==0.10.0 -imageio==2.15.0 -protobuf==3.12.2 -kendryte-caffe pytest pytest-xdist pyyaml -toml==0.10.2 pandas tabulate diff --git a/setup.py b/setup.py index 6724cbac65..0c0af6240b 100644 --- a/setup.py +++ b/setup.py @@ -83,8 +83,7 @@ def run(self): os.walk(os.path.join(bin_dir, 'sharplibs')) for _lib in files if os.path.isfile(os.path.join(root, _lib)) and (os.path.splitext(_lib)[-1] in [".dll", ".so", ".dylib", ".json"] or - _lib.startswith("lib")) - and not _lib.endswith(".deps.json")] + _lib.startswith("lib"))] for lib in sharp_libs: shutil.move(lib, os.path.join(self.build_dir, @@ -204,7 +203,7 @@ def build_cmake(self, ext: Extension): extdir += os.path.sep bin_dir = os.path.abspath(os.path.join(self.build_temp, 'install')) - cmake_args = ['-G', 'Ninja', '-DDOTNET_INIT_FOR_CONFIG=ON'] + cmake_args = ['-G', 'Ninja', '-DDOTNET_INIT_FOR_CONFIG=OFF'] if platform.system() == 'Windows': cmake_args += ['-DCMAKE_C_COMPILER=clang-cl'] cmake_args += ['-DCMAKE_CXX_COMPILER=clang-cl'] diff --git a/src/Native/include/nncase/api.h 
b/src/Native/include/nncase/api.h index 1f457d1f8e..4896d0c4f4 100644 --- a/src/Native/include/nncase/api.h +++ b/src/Native/include/nncase/api.h @@ -48,6 +48,9 @@ NNCASE_API int nncase_interp_free(nncase::runtime::interpreter *interp); NNCASE_API int nncase_interp_load_model(nncase::runtime::interpreter *interp, void *model_buffer, uint32_t model_size, bool copy_buffer); +NNCASE_API int +nncase_interp_load_model_from_path(nncase::runtime::interpreter *interp, + const char *model_path); NNCASE_API int nncase_interp_set_dump_root(nncase::runtime::interpreter *interp, const char *path); NNCASE_API int diff --git a/src/Native/include/nncase/compiler.h b/src/Native/include/nncase/compiler.h index 1ef12f990d..9d8c876535 100644 --- a/src/Native/include/nncase/compiler.h +++ b/src/Native/include/nncase/compiler.h @@ -22,7 +22,6 @@ #include #include #include -using nlohmann::json; extern "C" { typedef void *clr_object_handle_t; @@ -451,7 +450,7 @@ class shape_bucket_options : public clr_object_base { std::map> range_info() { return {}; } void range_info(std::map> value) { - json j = value; + nlohmann::json j = value; std::string s = j.dump(); nncase_clr_api()->shape_bucket_options_set_range_info( obj_.get(), s.c_str(), s.length()); @@ -465,7 +464,7 @@ class shape_bucket_options : public clr_object_base { std::map fix_var_map() { return {}; } void fix_var_map(std::map value) { - json j = value; + nlohmann::json j = value; std::string s = j.dump(); nncase_clr_api()->shape_bucket_options_set_fix_var_map( obj_.get(), s.c_str(), s.length()); diff --git a/src/Native/include/nncase/compiler_defs.h b/src/Native/include/nncase/compiler_defs.h index ea44203b9e..3ab64188e0 100644 --- a/src/Native/include/nncase/compiler_defs.h +++ b/src/Native/include/nncase/compiler_defs.h @@ -13,7 +13,6 @@ * limitations under the License. 
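A quick illustration of the new api.h entry point above: with nncase_interp_load_model_from_path, a host no longer has to read the .kmodel into memory before handing it to the interpreter. A minimal sketch, not part of this diff, assuming the usual nncase_interp_create/nncase_interp_free pairing from the same header; the model path is hypothetical:

#include <nncase/api.h>
#include <cstdio>

int main() {
    nncase::runtime::interpreter *interp = nullptr;
    if (nncase_interp_create(&interp) != 0) // assumed create/free pair from api.h
        return 1;
    // New in this diff: load straight from disk instead of
    // nncase_interp_load_model(interp, buffer, size, copy_buffer).
    if (nncase_interp_load_model_from_path(interp, "model.kmodel") != 0) {
        std::puts("failed to load model");
        nncase_interp_free(interp);
        return 1;
    }
    nncase_interp_free(interp);
    return 0;
}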
*/ #pragma once -#include #include #if defined(_MSC_VER) @@ -34,31 +33,17 @@ #define NNCASE_UNREACHABLE() __builtin_unreachable() #endif -#if gsl_CPP17_OR_GREATER #define NNCASE_INLINE_VAR inline #define NNCASE_UNUSED [[maybe_unused]] namespace nncase { template using invoke_result_t = std::invoke_result_t; } -#else -#define NNCASE_INLINE_VAR -#if defined(_MSC_VER) -#define NNCASE_UNUSED -#else -#define NNCASE_UNUSED __attribute__((unused)) -#endif -namespace nncase { -template -using invoke_result_t = std::result_of_t; -} -#endif #define NNCASE_LITTLE_ENDIAN 1 -#define NNCASE_HAVE_STD_BYTE gsl_CPP17_OR_GREATER -#define NNCASE_NODISCARD gsl_NODISCARD -#define NNCASE_NORETURN gsl_NORETURN +#define NNCASE_NODISCARD [[nodiscard]] +#define NNCASE_NORETURN [[noreturn]] #define BEGIN_NS_NNCASE_RUNTIME \ namespace nncase { \ @@ -96,8 +81,35 @@ using invoke_result_t = std::result_of_t; } #ifndef DEFINE_ENUM_BITMASK_OPERATORS -#define DEFINE_ENUM_BITMASK_OPERATORS(ENUMTYPE) \ - gsl_DEFINE_ENUM_BITMASK_OPERATORS(ENUMTYPE) +#define DEFINE_ENUM_BITMASK_OPERATORS(ENUM) \ + [[nodiscard]] inline constexpr ENUM operator~(ENUM val) noexcept { \ + typedef typename std::underlying_type::type U; \ + return ENUM(~U(val)); \ + } \ + [[nodiscard]] inline constexpr ENUM operator|(ENUM lhs, \ + ENUM rhs) noexcept { \ + typedef typename std::underlying_type::type U; \ + return ENUM(U(lhs) | U(rhs)); \ + } \ + [[nodiscard]] inline constexpr ENUM operator&(ENUM lhs, \ + ENUM rhs) noexcept { \ + typedef typename std::underlying_type::type U; \ + return ENUM(U(lhs) & U(rhs)); \ + } \ + [[nodiscard]] inline constexpr ENUM operator^(ENUM lhs, \ + ENUM rhs) noexcept { \ + typedef typename std::underlying_type::type U; \ + return ENUM(U(lhs) ^ U(rhs)); \ + } \ + inline constexpr ENUM &operator|=(ENUM &lhs, ENUM rhs) noexcept { \ + return lhs = lhs | rhs; \ + } \ + inline constexpr ENUM &operator&=(ENUM &lhs, ENUM rhs) noexcept { \ + return lhs = lhs & rhs; \ + } \ + inline constexpr ENUM &operator^=(ENUM &lhs, ENUM rhs) noexcept { \ + return lhs = lhs ^ rhs; \ + } #endif namespace nncase { diff --git a/src/Native/include/nncase/kernels/apply.h b/src/Native/include/nncase/kernels/apply.h index 1df71583fe..17a85a4449 100644 --- a/src/Native/include/nncase/kernels/apply.h +++ b/src/Native/include/nncase/kernels/apply.h @@ -41,49 +41,49 @@ namespace detail { #define APPLY_IMPL_FOR(i) for (index[i] = 0; index[i] < shape[i]; index[i]++) template -result apply_1(gsl::span shape, +result apply_1(std::span shape, Callable &&callable) noexcept { size_t index[1]; APPLY_IMPL_FOR(0) - try_(callable(gsl::span(index))); + try_(callable(std::span(index))); return ok(); } template -result apply_2(gsl::span shape, +result apply_2(std::span shape, Callable &&callable) noexcept { size_t index[2]; APPLY_IMPL_FOR(0) APPLY_IMPL_FOR(1) - try_(callable(gsl::span(index))); + try_(callable(std::span(index))); return ok(); } template -result apply_3(gsl::span shape, +result apply_3(std::span shape, Callable &&callable) noexcept { size_t index[3]; APPLY_IMPL_FOR(0) APPLY_IMPL_FOR(1) APPLY_IMPL_FOR(2) - try_(callable(gsl::span(index))); + try_(callable(std::span(index))); return ok(); } template -result apply_4(gsl::span shape, +result apply_4(std::span shape, Callable &&callable) noexcept { size_t index[4]; APPLY_IMPL_FOR(0) APPLY_IMPL_FOR(1) APPLY_IMPL_FOR(2) APPLY_IMPL_FOR(3) - try_(callable(gsl::span(index))); + try_(callable(std::span(index))); return ok(); } template -result apply_5(gsl::span shape, +result apply_5(std::span shape, Callable 
&&callable) noexcept { size_t index[5]; APPLY_IMPL_FOR(0) @@ -91,12 +91,12 @@ result apply_5(gsl::span shape, APPLY_IMPL_FOR(2) APPLY_IMPL_FOR(3) APPLY_IMPL_FOR(4) - try_(callable(gsl::span(index))); + try_(callable(std::span(index))); return ok(); } template -result apply_generic(gsl::span shape, +result apply_generic(std::span shape, Callable &&callable) noexcept { auto index_buffer = (size_t *) #ifdef _WIN32 @@ -106,7 +106,7 @@ result apply_generic(gsl::span shape, #endif (sizeof(size_t) * shape.size()); - gsl::span index(index_buffer, shape.size()); + std::span index(index_buffer, shape.size()); std::fill(index.begin(), index.end(), 0); auto last_dim_idx = (int32_t)shape.size() - 1; while (true) { @@ -128,7 +128,7 @@ result apply_generic(gsl::span shape, } // namespace detail template -result apply(gsl::span shape, +result apply(std::span shape, Callable &&callable) noexcept { switch (shape.size()) { case 0: diff --git a/src/Native/include/nncase/kernels/kernel_utils.h b/src/Native/include/nncase/kernels/kernel_utils.h index f787f5976f..37aae842ee 100644 --- a/src/Native/include/nncase/kernels/kernel_utils.h +++ b/src/Native/include/nncase/kernels/kernel_utils.h @@ -47,12 +47,12 @@ inline offset_type element_offset(const S &strides, It first, using difference_type = typename std::iterator_traits::difference_type; auto size = static_cast((std::min)( static_cast(std::distance(first, last)), strides.size())); - return std::inner_product(last - size, last, strides.cend() - size, + return std::inner_product(last - size, last, strides.end() - size, offset_type(0)); } -inline size_t offset(gsl::span strides, - gsl::span index) { +inline size_t offset(std::span strides, + std::span index) { // scalar if (strides.size() == 0 || index.size() == 0) { return 0; @@ -92,8 +92,8 @@ inline size_t get_windowed_output_size(size_t size, int32_t filter, stride; } -inline dims_t get_binary_output_shape(gsl::span input_a_shape, - gsl::span input_b_shape) { +inline dims_t get_binary_output_shape(std::span input_a_shape, + std::span input_b_shape) { dims_t out_shape; const auto dest_dims = @@ -129,8 +129,8 @@ inline T apply_activation(T value, value_range activation) { return clamp(value, activation.min, activation.max); } -inline dims_t get_reduced_offset(gsl::span in_offset, - gsl::span reduced_shape) { +inline dims_t get_reduced_offset(std::span in_offset, + std::span reduced_shape) { dims_t off(reduced_shape.size()); const auto dims_ext = in_offset.size() - reduced_shape.size(); for (size_t i = 0; i < reduced_shape.size(); i++) { @@ -143,8 +143,8 @@ inline dims_t get_reduced_offset(gsl::span in_offset, return off; } -inline dims_t get_reduced_shape(gsl::span in_shape, - gsl::span axis, bool keep_dims) { +inline dims_t get_reduced_shape(std::span in_shape, + std::span axis, bool keep_dims) { dims_t shape; shape.reserve(in_shape.size() - (keep_dims ? 
0 : axis.size())); for (size_t i = 0; i < in_shape.size(); i++) { @@ -170,8 +170,8 @@ size_t get_reduce_block_size(const TShape &in_shape, const TShape &axis) { return size; } -inline dims_t get_reduced_offset(gsl::span in_offset, - gsl::span axis, bool keep_dims) { +inline dims_t get_reduced_offset(std::span in_offset, + std::span axis, bool keep_dims) { if (in_offset.size() == 0) { return in_offset; } @@ -221,7 +221,7 @@ constexpr T quantize(float value, const quant_param_t ¶m) noexcept { } inline std::pair -get_resize_scales(gsl::span in_shape, int32_t out_h, +get_resize_scales(std::span in_shape, int32_t out_h, int32_t out_w, bool align_corners) { auto height_scale = (float)in_shape[2] / out_h; auto width_scale = (float)in_shape[3] / out_w; diff --git a/src/Native/include/nncase/ntt/apply.h b/src/Native/include/nncase/ntt/apply.h new file mode 100644 index 0000000000..94cf7a762f --- /dev/null +++ b/src/Native/include/nncase/ntt/apply.h @@ -0,0 +1,45 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "tensor.h" + +namespace nncase::ntt { +namespace detail { +template struct apply_impl { + void operator()(ranked_shape &index, const Shape &shape, + Callable &&callable) { + for (index[Axis] = 0; index[Axis] < shape[Axis]; index[Axis]++) { + if constexpr (Axis == Shape::rank() - 1) { + callable(index); + } else { + apply_impl()( + index, shape, std::forward(callable)); + } + } + } +}; +} // namespace detail + +template +void apply(const Shape &shape, Callable &&callable) { + ranked_shape index; + if constexpr (Shape::rank()) { + detail::apply_impl<0, Shape, Callable>()( + index, shape, std::forward(callable)); + } else { + callable(index); + } +} +} // namespace nncase::ntt diff --git a/src/Native/include/nncase/ntt/arch/aarch64/arch_types.h b/src/Native/include/nncase/ntt/arch/aarch64/arch_types.h new file mode 100644 index 0000000000..8c009c291b --- /dev/null +++ b/src/Native/include/nncase/ntt/arch/aarch64/arch_types.h @@ -0,0 +1,29 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
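The ntt::apply helper added above turns a rank-N shape into N nested loops at compile time via apply_impl's recursion. A minimal usage sketch, assuming ntt::ranked_shape can be brace-initialized with its extents like a std::array (that type comes from the NTT headers in this diff):

#include <nncase/ntt/apply.h>
#include <cstdio>

int main() {
    nncase::ntt::ranked_shape<2> shape{2, 3}; // assumed aggregate-style init
    // Visits (0,0), (0,1), (0,2), (1,0), ... with the last axis varying fastest.
    nncase::ntt::apply(shape, [](auto index) {
        std::printf("(%zu, %zu)\n", index[0], index[1]);
    });
    return 0;
}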
+ */ +#pragma once +#include "../../native_tensor.h" +#include + +NTT_DEFINE_NATIVE_TENSOR(int8_t, int8x16_t, 16) +NTT_DEFINE_NATIVE_TENSOR(uint8_t, uint8x16_t, 16) +NTT_DEFINE_NATIVE_TENSOR(int16_t, int16x8_t, 8) +NTT_DEFINE_NATIVE_TENSOR(uint16_t, uint16x8_t, 8) +NTT_DEFINE_NATIVE_TENSOR(int32_t, int32x4_t, 4) +NTT_DEFINE_NATIVE_TENSOR(uint32_t, uint32x4_t, 4) +NTT_DEFINE_NATIVE_TENSOR(int64_t, int64x2_t, 2) +NTT_DEFINE_NATIVE_TENSOR(uint64_t, uint64x2_t, 2) +NTT_DEFINE_NATIVE_TENSOR(float, float32x4_t, 4) +NTT_DEFINE_NATIVE_TENSOR(float, float32x4x2_t, 8) +NTT_DEFINE_NATIVE_TENSOR(double, float64x2_t, 2) diff --git a/src/Native/include/nncase/ntt/arch/aarch64/arm_math.h b/src/Native/include/nncase/ntt/arch/aarch64/arm_math.h new file mode 100644 index 0000000000..32262a3f4f --- /dev/null +++ b/src/Native/include/nncase/ntt/arch/aarch64/arm_math.h @@ -0,0 +1,301 @@ +/* NEON implementation of sin, cos, exp and log + + Inspired by Intel Approximate Math library, and based on the + corresponding algorithms of the cephes math library +*/ + +/* Copyright (C) 2011 Julien Pommier + + This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. This notice may not be removed or altered from any source distribution. 
+ + (this is the zlib license) +*/ + +#include <arm_neon.h> + +typedef float32x4_t v4sf; // vector of 4 float +typedef uint32x4_t v4su; // vector of 4 uint32 +typedef int32x4_t v4si; // vector of 4 int32 + +#define c_inv_mant_mask ~0x7f800000u +#define c_cephes_SQRTHF 0.707106781186547524 +#define c_cephes_log_p0 7.0376836292E-2 +#define c_cephes_log_p1 -1.1514610310E-1 +#define c_cephes_log_p2 1.1676998740E-1 +#define c_cephes_log_p3 -1.2420140846E-1 +#define c_cephes_log_p4 +1.4249322787E-1 +#define c_cephes_log_p5 -1.6668057665E-1 +#define c_cephes_log_p6 +2.0000714765E-1 +#define c_cephes_log_p7 -2.4999993993E-1 +#define c_cephes_log_p8 +3.3333331174E-1 +#define c_cephes_log_q1 -2.12194440e-4 +#define c_cephes_log_q2 0.693359375 + +/* natural logarithm computed for 4 simultaneous float + return NaN for x <= 0 +*/ +v4sf log_ps(v4sf x) { + v4sf one = vdupq_n_f32(1); + + x = vmaxq_f32(x, + vdupq_n_f32(0)); /* force flush to zero on denormal values */ + v4su invalid_mask = vcleq_f32(x, vdupq_n_f32(0)); + + v4si ux = vreinterpretq_s32_f32(x); + + v4si emm0 = vshrq_n_s32(ux, 23); + + /* keep only the fractional part */ + ux = vandq_s32(ux, vdupq_n_s32(c_inv_mant_mask)); + ux = vorrq_s32(ux, vreinterpretq_s32_f32(vdupq_n_f32(0.5f))); + x = vreinterpretq_f32_s32(ux); + + emm0 = vsubq_s32(emm0, vdupq_n_s32(0x7f)); + v4sf e = vcvtq_f32_s32(emm0); + + e = vaddq_f32(e, one); + + /* part2: + if( x < SQRTHF ) { + e -= 1; + x = x + x - 1.0; + } else { x = x - 1.0; } + */ + v4su mask = vcltq_f32(x, vdupq_n_f32(c_cephes_SQRTHF)); + v4sf tmp = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask)); + x = vsubq_f32(x, one); + e = vsubq_f32( + e, vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(one), mask))); + x = vaddq_f32(x, tmp); + + v4sf z = vmulq_f32(x, x); + + v4sf y = vdupq_n_f32(c_cephes_log_p0); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p1)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p2)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p3)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p4)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p5)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p6)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p7)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p8)); + y = vmulq_f32(y, x); + + y = vmulq_f32(y, z); + + tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q1)); + y = vaddq_f32(y, tmp); + + tmp = vmulq_f32(z, vdupq_n_f32(0.5f)); + y = vsubq_f32(y, tmp); + + tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q2)); + x = vaddq_f32(x, y); + x = vaddq_f32(x, tmp); + x = vreinterpretq_f32_u32(vorrq_u32( + vreinterpretq_u32_f32(x), invalid_mask)); // negative arg will be NAN + return x; +} + +#define c_exp_hi 88.3762626647949f +#define c_exp_lo -88.3762626647949f + +#define c_cephes_LOG2EF 1.44269504088896341 +#define c_cephes_exp_C1 0.693359375 +#define c_cephes_exp_C2 -2.12194440e-4 + +#define c_cephes_exp_p0 1.9875691500E-4 +#define c_cephes_exp_p1 1.3981999507E-3 +#define c_cephes_exp_p2 8.3334519073E-3 +#define c_cephes_exp_p3 4.1665795894E-2 +#define c_cephes_exp_p4 1.6666665459E-1 +#define c_cephes_exp_p5 5.0000001201E-1 + +/* exp() computed for 4 float at once */ +v4sf exp_ps(v4sf x) { + v4sf tmp, fx; + + v4sf one = vdupq_n_f32(1); + x = vminq_f32(x, vdupq_n_f32(c_exp_hi)); + x = vmaxq_f32(x, vdupq_n_f32(c_exp_lo)); + + /* express exp(x) as exp(g + n*log(2)) */ + fx = 
vmlaq_f32(vdupq_n_f32(0.5f), x, vdupq_n_f32(c_cephes_LOG2EF)); + + /* perform a floorf */ + tmp = vcvtq_f32_s32(vcvtq_s32_f32(fx)); + + /* if greater, subtract 1 */ + v4su mask = vcgtq_f32(tmp, fx); + mask = vandq_u32(mask, vreinterpretq_u32_f32(one)); + + fx = vsubq_f32(tmp, vreinterpretq_f32_u32(mask)); + + tmp = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C1)); + v4sf z = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C2)); + x = vsubq_f32(x, tmp); + x = vsubq_f32(x, z); + + static const float cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1, + c_cephes_exp_p2, c_cephes_exp_p3, + c_cephes_exp_p4, c_cephes_exp_p5}; + v4sf y = vld1q_dup_f32(cephes_exp_p + 0); + v4sf c1 = vld1q_dup_f32(cephes_exp_p + 1); + v4sf c2 = vld1q_dup_f32(cephes_exp_p + 2); + v4sf c3 = vld1q_dup_f32(cephes_exp_p + 3); + v4sf c4 = vld1q_dup_f32(cephes_exp_p + 4); + v4sf c5 = vld1q_dup_f32(cephes_exp_p + 5); + + y = vmulq_f32(y, x); + z = vmulq_f32(x, x); + y = vaddq_f32(y, c1); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c2); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c3); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c4); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c5); + + y = vmulq_f32(y, z); + y = vaddq_f32(y, x); + y = vaddq_f32(y, one); + + /* build 2^n */ + int32x4_t mm; + mm = vcvtq_s32_f32(fx); + mm = vaddq_s32(mm, vdupq_n_s32(0x7f)); + mm = vshlq_n_s32(mm, 23); + v4sf pow2n = vreinterpretq_f32_s32(mm); + + y = vmulq_f32(y, pow2n); + return y; +} + +#define c_minus_cephes_DP1 -0.78515625 +#define c_minus_cephes_DP2 -2.4187564849853515625e-4 +#define c_minus_cephes_DP3 -3.77489497744594108e-8 +#define c_sincof_p0 -1.9515295891E-4 +#define c_sincof_p1 8.3321608736E-3 +#define c_sincof_p2 -1.6666654611E-1 +#define c_coscof_p0 2.443315711809948E-005 +#define c_coscof_p1 -1.388731625493765E-003 +#define c_coscof_p2 4.166664568298827E-002 +#define c_cephes_FOPI 1.27323954473516 // 4 / M_PI + +/* evaluation of 4 sines & cosines at once. + + The code is the exact rewriting of the cephes sinf function. + Precision is excellent as long as x < 8192 (I did not bother to + take into account the special handling they have for greater values + -- it does not return garbage for arguments over 8192, though, but + the extra precision is missing). + + Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the + surprising but correct result. + + Note also that when you compute sin(x), cos(x) is available at + almost no extra price so both sin_ps and cos_ps make use of + sincos_ps. + */ +void sincos_ps(v4sf x, v4sf *ysin, v4sf *ycos) { // any x + v4sf xmm1, xmm2, xmm3, y; + + v4su emm2; + + v4su sign_mask_sin, sign_mask_cos; + sign_mask_sin = vcltq_f32(x, vdupq_n_f32(0)); + x = vabsq_f32(x); + + /* scale by 4/Pi */ + y = vmulq_f32(x, vdupq_n_f32(c_cephes_FOPI)); + + /* store the integer part of y in mm0 */ + emm2 = vcvtq_u32_f32(y); + /* j=(j+1) & (~1) (see the cephes sources) */ + emm2 = vaddq_u32(emm2, vdupq_n_u32(1)); + emm2 = vandq_u32(emm2, vdupq_n_u32(~1)); + y = vcvtq_f32_u32(emm2); + + /* get the polynomial selection mask + there is one polynomial for 0 <= x <= Pi/4 + and another one for Pi/4 <= x <= Pi/2 */ template <> struct load_scalar<ntt::vector<float, 4>> { + ntt::vector<float, 4> operator()(float v) const noexcept { + return vdupq_n_f32(v); + } +}; +} // namespace nncase::ntt::tensor_ops diff --git a/src/Native/include/nncase/ntt/arch/x86_64/arch_types.h b/src/Native/include/nncase/ntt/arch/x86_64/arch_types.h new file mode 100644 index 0000000000..dce73fd9d0 --- /dev/null +++ b/src/Native/include/nncase/ntt/arch/x86_64/arch_types.h @@ -0,0 +1,28 @@ +/* Copyright 2019-2021 Canaan Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "../../native_tensor.h" +#include + +NTT_DEFINE_NATIVE_TENSOR(int8_t, __m256i, 32) +NTT_DEFINE_NATIVE_TENSOR(uint8_t, __m256i, 32) +NTT_DEFINE_NATIVE_TENSOR(int16_t, __m256i, 16) +NTT_DEFINE_NATIVE_TENSOR(uint16_t, __m256i, 16) +NTT_DEFINE_NATIVE_TENSOR(int32_t, __m256i, 8) +NTT_DEFINE_NATIVE_TENSOR(uint32_t, __m256i, 8) +NTT_DEFINE_NATIVE_TENSOR(int64_t, __m256i, 4) +NTT_DEFINE_NATIVE_TENSOR(uint64_t, __m256i, 4) +NTT_DEFINE_NATIVE_TENSOR(float, __m256, 8) +NTT_DEFINE_NATIVE_TENSOR(double, __m256d, 4) diff --git a/src/Native/include/nncase/ntt/arch/x86_64/avx_mathfun.h b/src/Native/include/nncase/ntt/arch/x86_64/avx_mathfun.h new file mode 100644 index 0000000000..c76ebd571d --- /dev/null +++ b/src/Native/include/nncase/ntt/arch/x86_64/avx_mathfun.h @@ -0,0 +1,1057 @@ +/* + AVX implementation of sin, cos, sincos, exp and log + + Based on "sse_mathfun.h", by Julien Pommier + http://gruntthepeon.free.fr/ssemath/ + + Copyright (C) 2012 Giovanni Garberoglio + Interdisciplinary Laboratory for Computational Science (LISC) + Fondazione Bruno Kessler and University of Trento + via Sommarive, 18 + I-38123 Trento (Italy) + + This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. This notice may not be removed or altered from any source distribution. + + (this is the zlib license) +*/ + +#ifndef AVX_MATHFUN_H +#define AVX_MATHFUN_H + +#include "x86_usability.h" +#include +#include + +/* yes I know, the top of this file is quite ugly */ + +#ifdef _MSC_VER /* visual c++ */ +#define ALIGN32_BEG __declspec(align(32)) +#define ALIGN32_END +#else /* gcc or icc */ +#define ALIGN32_BEG +#define ALIGN32_END __attribute__((aligned(32))) +#endif + +#define _PI32AVX_CONST(Name, Val) \ + static const ALIGN32_BEG int _pi32avx_##Name[4] ALIGN32_END = {Val, Val, \ + Val, Val} + +_PI32AVX_CONST(1, 1); +_PI32AVX_CONST(inv1, ~1); +_PI32AVX_CONST(2, 2); +_PI32AVX_CONST(4, 4); + +/* declare some AVX constants -- why can't I figure a better way to do that? 
*/ +#define _PS256_CONST(Name, Val) \ + static const ALIGN32_BEG float _ps256_##Name[8] ALIGN32_END = { \ + Val, Val, Val, Val, Val, Val, Val, Val} +#define _PI32_CONST256(Name, Val) \ + static const ALIGN32_BEG int _pi32_256_##Name[8] ALIGN32_END = { \ + Val, Val, Val, Val, Val, Val, Val, Val} +#define _PS256_CONST_TYPE(Name, Type, Val) \ + static const ALIGN32_BEG Type _ps256_##Name[8] ALIGN32_END = { \ + Val, Val, Val, Val, Val, Val, Val, Val} + +_PS256_CONST(1, 1.0f); +_PS256_CONST(0p5, 0.5f); +/* the smallest non denormalized float number */ +_PS256_CONST_TYPE(min_norm_pos, int, 0x00800000); +_PS256_CONST_TYPE(mant_mask, int, 0x7f800000); +_PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000); + +_PS256_CONST_TYPE(sign_mask, int, (int)0x80000000); +_PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000); + +_PI32_CONST256(0, 0); +_PI32_CONST256(1, 1); +_PI32_CONST256(inv1, ~1); +_PI32_CONST256(2, 2); +_PI32_CONST256(4, 4); +_PI32_CONST256(0x7f, 0x7f); + +_PS256_CONST(cephes_SQRTHF, 0.707106781186547524f); +_PS256_CONST(cephes_log_p0, 7.0376836292E-2f); +_PS256_CONST(cephes_log_p1, -1.1514610310E-1f); +_PS256_CONST(cephes_log_p2, 1.1676998740E-1f); +_PS256_CONST(cephes_log_p3, -1.2420140846E-1f); +_PS256_CONST(cephes_log_p4, +1.4249322787E-1f); +_PS256_CONST(cephes_log_p5, -1.6668057665E-1f); +_PS256_CONST(cephes_log_p6, +2.0000714765E-1f); +_PS256_CONST(cephes_log_p7, -2.4999993993E-1f); +_PS256_CONST(cephes_log_p8, +3.3333331174E-1f); +_PS256_CONST(cephes_log_q1, -2.12194440e-4f); +_PS256_CONST(cephes_log_q2, 0.693359375f); + +#ifndef __AVX2__ +typedef union imm_xmm_union { + __m256i imm; + __m128i xmm[2]; +} imm_xmm_union; + +#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \ + { \ + ALIGN32_BEG imm_xmm_union u ALIGN32_END; \ + u.imm = imm_; \ + xmm0_ = u.xmm[0]; \ + xmm1_ = u.xmm[1]; \ + } + +#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \ + { \ + ALIGN32_BEG imm_xmm_union u ALIGN32_END; \ + u.xmm[0] = xmm0_; \ + u.xmm[1] = xmm1_; \ + imm_ = u.imm; \ + } + +#define AVX2_BITOP_USING_SSE2(fn) \ + static inline __m256i _mm256_comp_##fn(__m256i x, int a) { \ + /* use SSE2 instruction to perform the bitop AVX2 */ \ + __m128i x1, x2; \ + __m256i ret; \ + COPY_IMM_TO_XMM(x, x1, x2); \ + x1 = _mm_##fn(x1, a); \ + x2 = _mm_##fn(x2, a); \ + COPY_XMM_TO_IMM(x1, x2, ret); \ + return (ret); \ + } +#define AVX2_INTOP_USING_SSE2(fn) \ + static inline __m256i _mm256_comp_##fn(__m256i x, __m256i y) { \ + /* use SSE2 instructions to perform the AVX2 integer operation */ \ + __m128i x1, x2; \ + __m128i y1, y2; \ + __m256i ret; \ + COPY_IMM_TO_XMM(x, x1, x2); \ + COPY_IMM_TO_XMM(y, y1, y2); \ + x1 = _mm_##fn(x1, y1); \ + x2 = _mm_##fn(x2, y2); \ + COPY_XMM_TO_IMM(x1, x2, ret); \ + return (ret); \ + } +#else +#define AVX2_BITOP_USING_SSE2(fn) \ + static inline __m256i _mm256_comp_##fn(__m256i x, int a) { \ + return _mm256_##fn(x, a); \ + } +#define AVX2_INTOP_USING_SSE2(fn) \ + static inline __m256i _mm256_comp_##fn(__m256i x, __m256i y) { \ + return _mm256_##fn(x, y); \ + } +#endif + +AVX2_BITOP_USING_SSE2(slli_epi32) +AVX2_BITOP_USING_SSE2(srli_epi32) +AVX2_INTOP_USING_SSE2(cmpeq_epi32) +AVX2_INTOP_USING_SSE2(sub_epi32) +AVX2_INTOP_USING_SSE2(add_epi32) + +// Replace 256 bit operations with 128 bit ones when AVX2 is disabled +#ifndef __AVX2__ +AVX2_INTOP_USING_SSE2(and_si128) +AVX2_INTOP_USING_SSE2(andnot_si128) +#endif + +/* natural logarithm computed for 8 simultaneous float + return NaN for x <= 0 +*/ +static inline __m256 log256_ps(__m256 x) { + __m256i imm0; + __m256 one = *(__m256 *)_ps256_1; + + //__m256 
invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps()); + __m256 invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS); + + x = _mm256_max_ps( + x, *(__m256 *)_ps256_min_norm_pos); /* cut off denormalized stuff */ + + // can be done with AVX2 + imm0 = _mm256_comp_srli_epi32(_mm256_castps_si256(x), 23); + + /* keep only the fractional part */ + x = _mm256_and_ps(x, *(__m256 *)_ps256_inv_mant_mask); + x = _mm256_or_ps(x, *(__m256 *)_ps256_0p5); + + // this is again another AVX2 instruction + imm0 = _mm256_comp_sub_epi32(imm0, *(__m256i *)_pi32_256_0x7f); + __m256 e = _mm256_cvtepi32_ps(imm0); + + e = _mm256_add_ps(e, one); + + /* part2: + if( x < SQRTHF ) { + e -= 1; + x = x + x - 1.0; + } else { x = x - 1.0; } + */ + //__m256 mask = _mm256_cmplt_ps(x, *(__m256*)_ps256_cephes_SQRTHF); + __m256 mask = _mm256_cmp_ps(x, *(__m256 *)_ps256_cephes_SQRTHF, _CMP_LT_OS); + __m256 tmp = _mm256_and_ps(x, mask); + x = _mm256_sub_ps(x, one); + e = _mm256_sub_ps(e, _mm256_and_ps(one, mask)); + x = _mm256_add_ps(x, tmp); + + __m256 z = _mm256_mul_ps(x, x); + + __m256 y = *(__m256 *)_ps256_cephes_log_p0; + y = _mm256_comp_fmadd_ps(y, x, *(__m256 *)_ps256_cephes_log_p1); + y = _mm256_comp_fmadd_ps(y, x, *(__m256 *)_ps256_cephes_log_p2); + y = _mm256_comp_fmadd_ps(y, x, *(__m256 *)_ps256_cephes_log_p3); + y = _mm256_comp_fmadd_ps(y, x, *(__m256 *)_ps256_cephes_log_p4); + y = _mm256_comp_fmadd_ps(y, x, *(__m256 *)_ps256_cephes_log_p5); + y = _mm256_comp_fmadd_ps(y, x, *(__m256 *)_ps256_cephes_log_p6); + y = _mm256_comp_fmadd_ps(y, x, *(__m256 *)_ps256_cephes_log_p7); + y = _mm256_comp_fmadd_ps(y, x, *(__m256 *)_ps256_cephes_log_p8); + y = _mm256_mul_ps(y, x); + + y = _mm256_mul_ps(y, z); + + y = _mm256_comp_fmadd_ps(e, *(__m256 *)_ps256_cephes_log_q1, y); + + // y = -z * 0.5 + y + y = _mm256_comp_fnmadd_ps(z, *(__m256 *)_ps256_0p5, y); + + x = _mm256_add_ps(x, y); + x = _mm256_comp_fmadd_ps(e, *(__m256 *)_ps256_cephes_log_q2, x); + y = _mm256_or_ps(x, invalid_mask); // negative arg will be NAN + return y; +} + +_PS256_CONST(exp_hi, 88.3762626647949f); +_PS256_CONST(exp_lo, -88.3762626647949f); + +_PS256_CONST(cephes_LOG2EF, 1.44269504088896341f); +_PS256_CONST(cephes_exp_C1, 0.693359375f); +_PS256_CONST(cephes_exp_C2, -2.12194440e-4f); + +_PS256_CONST(cephes_exp_p0, 1.9875691500E-4f); +_PS256_CONST(cephes_exp_p1, 1.3981999507E-3f); +_PS256_CONST(cephes_exp_p2, 8.3334519073E-3f); +_PS256_CONST(cephes_exp_p3, 4.1665795894E-2f); +_PS256_CONST(cephes_exp_p4, 1.6666665459E-1f); +_PS256_CONST(cephes_exp_p5, 5.0000001201E-1f); + +static inline __m256 exp256_ps(__m256 x) { + __m256 tmp = _mm256_setzero_ps(), fx; + __m256i imm0; + __m256 one = *(__m256 *)_ps256_1; + + x = _mm256_min_ps(x, *(__m256 *)_ps256_exp_hi); + x = _mm256_max_ps(x, *(__m256 *)_ps256_exp_lo); + + /* express exp(x) as exp(g + n*log(2)) */ + fx = _mm256_comp_fmadd_ps(x, *(__m256 *)_ps256_cephes_LOG2EF, + *(__m256 *)_ps256_0p5); + + /* how to perform a floorf with SSE: just below */ + // imm0 = _mm256_cvttps_epi32(fx); + // tmp = _mm256_cvtepi32_ps(imm0); + + tmp = _mm256_floor_ps(fx); + + /* if greater, subtract 1 */ + //__m256 mask = _mm256_cmpgt_ps(tmp, fx); + __m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); + mask = _mm256_and_ps(mask, one); + fx = _mm256_sub_ps(tmp, mask); + + // x = x - fx * exp_C1 + x = _mm256_comp_fnmadd_ps(fx, *(__m256 *)_ps256_cephes_exp_C1, x); + // x = x - fx * exp_C2 + x = _mm256_comp_fnmadd_ps(fx, *(__m256 *)_ps256_cephes_exp_C2, x); + + tmp = _mm256_mul_ps(x, x); + + __m256 y = *(__m256 
*)_ps256_cephes_exp_p0; + y = _mm256_comp_fmadd_ps(y, x, *(__m256 *)_ps256_cephes_exp_p1); + y = _mm256_comp_fmadd_ps(y, x, *(__m256 *)_ps256_cephes_exp_p2); + y = _mm256_comp_fmadd_ps(y, x, *(__m256 *)_ps256_cephes_exp_p3); + y = _mm256_comp_fmadd_ps(y, x, *(__m256 *)_ps256_cephes_exp_p4); + y = _mm256_comp_fmadd_ps(y, x, *(__m256 *)_ps256_cephes_exp_p5); + y = _mm256_comp_fmadd_ps(y, tmp, x); + y = _mm256_add_ps(y, one); + + /* build 2^n */ + imm0 = _mm256_cvttps_epi32(fx); + // another two AVX2 instructions + imm0 = _mm256_comp_add_epi32(imm0, *(__m256i *)_pi32_256_0x7f); + imm0 = _mm256_comp_slli_epi32(imm0, 23); + __m256 pow2n = _mm256_castsi256_ps(imm0); + y = _mm256_mul_ps(y, pow2n); + return y; +} + +_PS256_CONST(tanh_hi, 9.0f); +_PS256_CONST(tanh_lo, -9.0f); + +_PS256_CONST(cephes_tanh_p0, -2.76076847742355E-16f); +_PS256_CONST(cephes_tanh_p1, 2.00018790482477E-13f); +_PS256_CONST(cephes_tanh_p2, -8.60467152213735E-11f); +_PS256_CONST(cephes_tanh_p3, 5.12229709037114E-08f); +_PS256_CONST(cephes_tanh_p4, 1.48572235717979E-05f); +_PS256_CONST(cephes_tanh_p5, 6.37261928875436E-04f); +_PS256_CONST(cephes_tanh_p6, 4.89352455891786E-03f); + +_PS256_CONST(cephes_tanh_p7, 1.19825839466702e-06f); +_PS256_CONST(cephes_tanh_p8, 1.18534705686654e-04f); +_PS256_CONST(cephes_tanh_p9, 2.26843463243900e-03f); + +// an approximation of tanh +static inline __m256 tanh256_ps(const __m256 x) { + __m256 value = x; + value = _mm256_max_ps(*(__m256 *)_ps256_tanh_lo, value); + value = _mm256_min_ps(*(__m256 *)_ps256_tanh_hi, value); + + __m256 value_squared = _mm256_mul_ps(value, value); + + __m256 p; + p = _mm256_comp_fmadd_ps(value_squared, *(__m256 *)_ps256_cephes_tanh_p0, + *(__m256 *)_ps256_cephes_tanh_p1); + p = _mm256_comp_fmadd_ps(p, value_squared, + *(__m256 *)_ps256_cephes_tanh_p2); + p = _mm256_comp_fmadd_ps(p, value_squared, + *(__m256 *)_ps256_cephes_tanh_p3); + p = _mm256_comp_fmadd_ps(p, value_squared, + *(__m256 *)_ps256_cephes_tanh_p4); + p = _mm256_comp_fmadd_ps(p, value_squared, + *(__m256 *)_ps256_cephes_tanh_p5); + p = _mm256_comp_fmadd_ps(p, value_squared, + *(__m256 *)_ps256_cephes_tanh_p6); + p = _mm256_mul_ps(p, value); + + __m256 q; + q = _mm256_comp_fmadd_ps(value_squared, *(__m256 *)_ps256_cephes_tanh_p7, + *(__m256 *)_ps256_cephes_tanh_p8); + q = _mm256_comp_fmadd_ps(q, value_squared, + *(__m256 *)_ps256_cephes_tanh_p9); + q = _mm256_comp_fmadd_ps(q, value_squared, + *(__m256 *)_ps256_cephes_tanh_p6); + + __m256 dst = _mm256_div_ps(p, q); + return dst; +} + +_PS256_CONST(minus_cephes_DP1, -0.78515625f); +_PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4f); +_PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8f); +_PS256_CONST(sincof_p0, -1.9515295891E-4f); +_PS256_CONST(sincof_p1, 8.3321608736E-3f); +_PS256_CONST(sincof_p2, -1.6666654611E-1f); +_PS256_CONST(coscof_p0, 2.443315711809948E-005f); +_PS256_CONST(coscof_p1, -1.388731625493765E-003f); +_PS256_CONST(coscof_p2, 4.166664568298827E-002f); +_PS256_CONST(cephes_FOPI, 1.27323954473516f); // 4 / M_PI + +/* evaluation of 8 sines at once using AVX intrinsics + + The code is the exact rewriting of the cephes sinf function. + Precision is excellent as long as x < 8192 (I did not bother to + take into account the special handling they have for greater values + -- it does not return garbage for arguments over 8192, though, but + the extra precision is missing). + + Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the + surprising but correct result. 
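Since tanh256_ps above is a clamped rational approximation rather than an exact evaluation, its error is worth spot-checking instead of assuming. A standalone sketch, not part of this diff, that compares it lane by lane against std::tanh, assuming avx_mathfun.h is included and the translation unit is built with AVX enabled:

#include <cmath>
#include <cstdio>
#include <immintrin.h>

int main() {
    alignas(32) float in[8] = {-4.f, -1.f, -0.5f, 0.f, 0.5f, 1.f, 2.f, 4.f};
    alignas(32) float out[8];
    // tanh256_ps is the approximation defined above.
    _mm256_store_ps(out, tanh256_ps(_mm256_load_ps(in)));
    for (int i = 0; i < 8; ++i)
        std::printf("approx %f vs std %f\n", out[i], std::tanh(in[i]));
    return 0;
}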
+
+_PS256_CONST(minus_cephes_DP1, -0.78515625f);
+_PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4f);
+_PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8f);
+_PS256_CONST(sincof_p0, -1.9515295891E-4f);
+_PS256_CONST(sincof_p1, 8.3321608736E-3f);
+_PS256_CONST(sincof_p2, -1.6666654611E-1f);
+_PS256_CONST(coscof_p0, 2.443315711809948E-005f);
+_PS256_CONST(coscof_p1, -1.388731625493765E-003f);
+_PS256_CONST(coscof_p2, 4.166664568298827E-002f);
+_PS256_CONST(cephes_FOPI, 1.27323954473516f); // 4 / M_PI
+
+/* evaluation of 8 sines at once using AVX intrinsics
+
+   The code is the exact rewriting of the cephes sinf function.
+   Precision is excellent as long as x < 8192 (I did not bother to
+   take into account the special handling they have for greater values
+   -- it does not return garbage for arguments over 8192, though, but
+   the extra precision is missing).
+
+   Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
+   surprising but correct result.
+*/
+static inline __m256 sin256_ps(__m256 x) { // any x
+    __m256 xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y;
+    __m256i imm0, imm2;
+
+#ifndef __AVX2__
+    __m128i imm0_1, imm0_2;
+    __m128i imm2_1, imm2_2;
+#endif
+
+    sign_bit = x;
+    /* take the absolute value */
+    x = _mm256_and_ps(x, *(__m256 *)_ps256_inv_sign_mask);
+    /* extract the sign bit (upper one) */
+    sign_bit = _mm256_and_ps(sign_bit, *(__m256 *)_ps256_sign_mask);
+
+    /* scale by 4/Pi */
+    y = _mm256_mul_ps(x, *(__m256 *)_ps256_cephes_FOPI);
+
+    /*
+      Here we start a series of integer operations, which are in the
+      realm of AVX2.
+      If we don't have AVX2, let's perform them using SSE2 directives.
+    */
+
+#ifdef __AVX2__
+    /* store the integer part of y in mm0 */
+    imm2 = _mm256_cvttps_epi32(y);
+    /* j=(j+1) & (~1) (see the cephes sources) */
+    // another two AVX2 instructions
+    imm2 = _mm256_comp_add_epi32(imm2, *(__m256i *)_pi32_256_1);
+    imm2 = _mm256_and_si256(imm2, *(__m256i *)_pi32_256_inv1);
+    y = _mm256_cvtepi32_ps(imm2);
+
+    /* get the swap sign flag */
+    imm0 = _mm256_and_si256(imm2, *(__m256i *)_pi32_256_4);
+    imm0 = _mm256_comp_slli_epi32(imm0, 29);
+    /* get the polynom selection mask
+       there is one polynom for 0 <= x <= Pi/4
+       and another one for Pi/4 < x <= Pi/2. */
+
+// abs
+template <> struct abs<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return abs256_ps(v);
+    }
+};
+
+// acos
+template <> struct acos<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return acos256_ps(v);
+    }
+};
+
+// acosh(v) = ln(v + sqrt(v^2 - 1)), v >= 1
+template <> struct acosh<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        auto ones = _mm256_set1_ps(1.0f);
+        return log256_ps(_mm256_add_ps(
+            v, _mm256_sqrt_ps(_mm256_sub_ps(_mm256_mul_ps(v, v), ones))));
+    }
+};
+
+// asin
+template <> struct asin<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return asin256_ps(v);
+    }
+};
+
+// asinh(v) = ln(v + sqrt(v^2 + 1))
+template <> struct asinh<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        auto ones = _mm256_set1_ps(1.0f);
+        return log256_ps(_mm256_add_ps(
+            v, _mm256_sqrt_ps(_mm256_add_ps(_mm256_mul_ps(v, v), ones))));
+    }
+};
+
+// ceil
+template <> struct ceil<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return _mm256_ceil_ps(v);
+    }
+};
+
+// cos
+template <> struct cos<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return cos256_ps(v);
+    }
+};
+
+// cosh(v) = (exp(v) + exp(-v)) / 2
+template <> struct cosh<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        auto zeros = _mm256_setzero_ps();
+        auto twos = _mm256_set1_ps(2.0f);
+        return _mm256_div_ps(
+            _mm256_add_ps(exp256_ps(v), exp256_ps(_mm256_sub_ps(zeros, v))),
+            twos);
+    }
+};
+
+// exp
+template <> struct exp<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return exp256_ps(v);
+    }
+};
+
+// floor
+template <> struct floor<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return _mm256_floor_ps(v);
+    }
+};
+
+// log
+template <> struct log<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return log256_ps(v);
+    }
+};
+
+// neg
+template <> struct neg<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return _mm256_sub_ps(_mm256_setzero_ps(), v);
+    }
+};
+
+// round
+template <> struct round<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return _mm256_round_ps(v,
+                               _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
+    }
+};
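The rsqrt specialization that follows returns the raw _mm256_rsqrt_ps estimate, which is only accurate to roughly 12 bits. If a caller needs near-full single precision, the usual fix is one Newton-Raphson step on top of the estimate; a minimal sketch (illustrative helper name, built on the _mm256_comp_fnmadd_ps shim from x86_usability.h below):

// One Newton-Raphson iteration on the ~12-bit hardware estimate:
// y1 = y0 * (1.5 - 0.5 * x * y0 * y0), roughly doubling the accurate bits.
static inline __m256 rsqrt256_nr_ps(__m256 x) {
    __m256 y = _mm256_rsqrt_ps(x);
    __m256 half_x = _mm256_mul_ps(x, _mm256_set1_ps(0.5f));
    // t = 1.5 - half_x * y * y  (fnmadd computes c - a * b)
    __m256 t = _mm256_comp_fnmadd_ps(half_x, _mm256_mul_ps(y, y),
                                     _mm256_set1_ps(1.5f));
    return _mm256_mul_ps(y, t);
}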
+
+// rsqrt
+template <> struct rsqrt<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return _mm256_rsqrt_ps(v);
+    }
+};
+
+// sign
+template <> struct sign<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+#if 0
+        auto sign_mask = _mm256_set1_ps(-0.0f);
+        auto sign_bits = _mm256_and_ps(v, sign_mask);
+        auto minus_ones = _mm256_set1_ps(-1.0f);
+        auto zeros = _mm256_setzero_ps();
+        auto ret = _mm256_blendv_ps(zeros, minus_ones, sign_bits);
+        auto gt_zero_mask = _mm256_cmp_ps(v, zeros, _CMP_GT_OQ);
+        auto ones = _mm256_set1_ps(1.0f);
+        ret = _mm256_blendv_ps(ret, ones, gt_zero_mask);
+#else
+        auto minus_ones = _mm256_set1_ps(-1.0f);
+        auto ones = _mm256_set1_ps(1.0f);
+        auto zeros = _mm256_setzero_ps();
+        auto ret = _mm256_setzero_ps();
+        auto mask = _mm256_cmp_ps(v, zeros, _CMP_GT_OQ);
+        ret = _mm256_blendv_ps(ret, ones, mask);
+        mask = _mm256_cmp_ps(v, zeros, _CMP_LT_OQ);
+        ret = _mm256_blendv_ps(ret, minus_ones, mask);
+#endif
+        return ret;
+    }
+};
+
+// sin
+template <> struct sin<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return sin256_ps(v);
+    }
+};
+
+// sinh(v) = (exp(v) - exp(-v)) / 2
+template <> struct sinh<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        auto zeros = _mm256_setzero_ps();
+        auto twos = _mm256_set1_ps(2.0f);
+        return _mm256_div_ps(
+            _mm256_sub_ps(exp256_ps(v), exp256_ps(_mm256_sub_ps(zeros, v))),
+            twos);
+    }
+};
+
+// sqrt
+template <> struct sqrt<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return _mm256_sqrt_ps(v);
+    }
+};
+
+// square
+template <> struct square<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return _mm256_mul_ps(v, v);
+    }
+};
+
+// tanh
+template <> struct tanh<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return tanh256_ps(v);
+    }
+};
+
+// swish(v) = v / (1 + std::exp(-v))
+template <> struct swish<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        auto ones = _mm256_set1_ps(1.0f);
+        auto zeros = _mm256_setzero_ps();
+        return _mm256_div_ps(
+            v, _mm256_add_ps(ones, exp256_ps(_mm256_sub_ps(zeros, v))));
+    }
+};
+
+#endif
+} // namespace nncase::ntt::ops
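All the transcendental specializations above lean on exp256_ps/log256_ps/tanh256_ps, so a tiny standalone harness is handy for eyeballing their accuracy against libm. A sketch, assuming avx_mathfun.h compiles standalone with AVX enabled (the include path and build flags are illustrative):

// Illustrative spot-check: compare exp256_ps against std::exp lane by lane.
// Hypothetical build: g++ -mavx2 -mfma -O2 check_exp.cpp
#include <cmath>
#include <cstdio>
#include <immintrin.h>
// #include "avx_mathfun.h" // the header added by this diff

int main() {
    float in[8] = {-4.f, -1.f, -0.5f, 0.f, 0.5f, 1.f, 2.f, 4.f};
    float out[8];
    _mm256_storeu_ps(out, exp256_ps(_mm256_loadu_ps(in)));
    for (int i = 0; i < 8; i++)
        std::printf("exp(%+.2f) = %.7g (libm %.7g)\n", in[i], (double)out[i],
                    (double)std::exp(in[i]));
    return 0;
}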
diff --git a/src/Native/include/nncase/ntt/arch/x86_64/tensor_ops.h b/src/Native/include/nncase/ntt/arch/x86_64/tensor_ops.h
new file mode 100644
index 0000000000..16d3e3e10f
--- /dev/null
+++ b/src/Native/include/nncase/ntt/arch/x86_64/tensor_ops.h
@@ -0,0 +1,26 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include "../../tensor_ops.h"
+#include "arch_types.h"
+#include "avx_mathfun.h"
+
+namespace nncase::ntt::tensor_ops {
+template <> struct load_scalar<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(float v) const noexcept {
+        return _mm256_set1_ps(v);
+    }
+};
+} // namespace nncase::ntt::tensor_ops
diff --git a/src/Native/include/nncase/ntt/arch/x86_64/x86_usability.h b/src/Native/include/nncase/ntt/arch/x86_64/x86_usability.h
new file mode 100644
index 0000000000..3bc92e3450
--- /dev/null
+++ b/src/Native/include/nncase/ntt/arch/x86_64/x86_usability.h
@@ -0,0 +1,1364 @@
+// Tencent is pleased to support the open source community by making ncnn
+// available.
+//
+// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this
+// file except in compliance with the License. You may obtain a copy of the
+// License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+
+#ifndef X86_USABILITY_H
+#define X86_USABILITY_H
+
+#include <math.h>
+#if __SSE2__
+#include <emmintrin.h>
+#if __SSE4_1__
+#include <smmintrin.h>
+#if __AVX__
+#include <immintrin.h>
+#if __XOP__
+#ifdef _MSC_VER
+#include <ammintrin.h>
+#else
+#include <x86intrin.h>
+#endif
+#endif
+#endif
+#endif
+#endif // __SSE2__
+
+static inline signed char float2int8(float v) {
+    int int32 = (int)round(v);
+    if (int32 > 127)
+        return 127;
+    if (int32 < -127)
+        return -127;
+    return (signed char)int32;
+}
+
+#if __SSE2__
+static inline void transpose4x8_epi32(__m128i &_r0, __m128i &_r1, __m128i &_r2,
+                                      __m128i &_r3, __m128i &_r4, __m128i &_r5,
+                                      __m128i &_r6, __m128i &_r7) {
+    __m128i _tmp0 = _mm_unpacklo_epi32(_r0, _r1);
+    __m128i _tmp1 = _mm_unpackhi_epi32(_r0, _r1);
+    __m128i _tmp2 = _mm_unpacklo_epi32(_r2, _r3);
+    __m128i _tmp3 = _mm_unpackhi_epi32(_r2, _r3);
+    __m128i _tmp4 = _mm_unpacklo_epi32(_r4, _r5);
+    __m128i _tmp5 = _mm_unpackhi_epi32(_r4, _r5);
+    __m128i _tmp6 = _mm_unpacklo_epi32(_r6, _r7);
+    __m128i _tmp7 = _mm_unpackhi_epi32(_r6, _r7);
+
+    _r0 = _mm_unpacklo_epi64(_tmp0, _tmp2);
+    _r1 = _mm_unpacklo_epi64(_tmp4, _tmp6);
+    _r2 = _mm_unpackhi_epi64(_tmp0, _tmp2);
+    _r3 = _mm_unpackhi_epi64(_tmp4, _tmp6);
+    _r4 = _mm_unpacklo_epi64(_tmp1, _tmp3);
+    _r5 = _mm_unpacklo_epi64(_tmp5, _tmp7);
+    _r6 = _mm_unpackhi_epi64(_tmp1, _tmp3);
+    _r7 = _mm_unpackhi_epi64(_tmp5, _tmp7);
+}
+
+static inline void transpose4x4_epi32(__m128i &_r0, __m128i &_r1, __m128i &_r2,
+                                      __m128i &_r3) {
+    __m128i _tmp0 = _mm_unpacklo_epi32(_r0, _r1);
+    __m128i _tmp1 = _mm_unpackhi_epi32(_r0, _r1);
+    __m128i _tmp2 = _mm_unpacklo_epi32(_r2, _r3);
+    __m128i _tmp3 = _mm_unpackhi_epi32(_r2, _r3);
+
+    _r0 = _mm_unpacklo_epi64(_tmp0, _tmp2);
+    _r1 = _mm_unpackhi_epi64(_tmp0, _tmp2);
+    _r2 = _mm_unpacklo_epi64(_tmp1, _tmp3);
+    _r3 = _mm_unpackhi_epi64(_tmp1, _tmp3);
+}
+
+static inline void transpose8x8_epi16(__m128i &_r0, __m128i &_r1, __m128i &_r2,
+                                      __m128i &_r3, __m128i &_r4, __m128i &_r5,
+                                      __m128i &_r6, __m128i &_r7) {
+    __m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1);
+    __m128i _tmp1 = _mm_unpackhi_epi16(_r0, _r1);
+    __m128i _tmp2 = _mm_unpacklo_epi16(_r2, _r3);
+    __m128i _tmp3 = _mm_unpackhi_epi16(_r2, _r3);
+    __m128i _tmp4 = _mm_unpacklo_epi16(_r4, _r5);
+    __m128i _tmp5 = _mm_unpackhi_epi16(_r4, _r5); +
__m128i _tmp6 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _tmp7 = _mm_unpackhi_epi16(_r6, _r7); + + __m128i _tmp8 = _mm_unpacklo_epi32(_tmp0, _tmp2); + __m128i _tmp9 = _mm_unpackhi_epi32(_tmp0, _tmp2); + __m128i _tmpa = _mm_unpacklo_epi32(_tmp1, _tmp3); + __m128i _tmpb = _mm_unpackhi_epi32(_tmp1, _tmp3); + __m128i _tmpc = _mm_unpacklo_epi32(_tmp4, _tmp6); + __m128i _tmpd = _mm_unpackhi_epi32(_tmp4, _tmp6); + __m128i _tmpe = _mm_unpacklo_epi32(_tmp5, _tmp7); + __m128i _tmpf = _mm_unpackhi_epi32(_tmp5, _tmp7); + + _r0 = _mm_unpacklo_epi64(_tmp8, _tmpc); + _r1 = _mm_unpackhi_epi64(_tmp8, _tmpc); + _r2 = _mm_unpacklo_epi64(_tmp9, _tmpd); + _r3 = _mm_unpackhi_epi64(_tmp9, _tmpd); + _r4 = _mm_unpacklo_epi64(_tmpa, _tmpe); + _r5 = _mm_unpackhi_epi64(_tmpa, _tmpe); + _r6 = _mm_unpacklo_epi64(_tmpb, _tmpf); + _r7 = _mm_unpackhi_epi64(_tmpb, _tmpf); +} + +static inline void transpose8x4_epi16(__m128i &_r0, __m128i &_r1, __m128i &_r2, + __m128i &_r3) { + __m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi16(_r0, _r1); + __m128i _tmp2 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _tmp3 = _mm_unpackhi_epi16(_r2, _r3); + + _r0 = _mm_unpacklo_epi32(_tmp0, _tmp2); + _r1 = _mm_unpackhi_epi32(_tmp0, _tmp2); + _r2 = _mm_unpacklo_epi32(_tmp1, _tmp3); + _r3 = _mm_unpackhi_epi32(_tmp1, _tmp3); +} + +static inline float _mm_reduce_add_ps(__m128 x128) { + const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + return _mm_cvtss_f32(x32); +} + +static inline float _mm_reduce_max_ps(__m128 x128) { + const __m128 x64 = _mm_max_ps(x128, _mm_movehl_ps(x128, x128)); + const __m128 x32 = _mm_max_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + return _mm_cvtss_f32(x32); +} + +static inline int _mm_reduce_add_epi32(__m128i x) { + __m128i hi64 = _mm_unpackhi_epi64(x, x); + __m128i sum64 = _mm_add_epi32(hi64, x); + __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + __m128i sum32 = _mm_add_epi32(sum64, hi32); + return _mm_cvtsi128_si32(sum32); +} + +static inline int32_t float2int8_sse(const __m128 &_v0) { + // _MM_ROUND_NEAREST round to even + // simulate round to nearest via +/-0.5 with round to zero + __m128 _p5 = _mm_set1_ps(0.5f); + __m128 _signmask = _mm_castsi128_ps(_mm_set1_epi32(1 << 31)); + __m128 _sign0 = _mm_and_ps(_v0, _signmask); + __m128 _v0_p5 = _mm_or_ps(_p5, _sign0); + __m128 _v0_adj = _mm_add_ps(_v0, _v0_p5); + __m128i _v0_i = _mm_cvttps_epi32(_v0_adj); + + __m128i _v0_s16 = _mm_packs_epi32(_v0_i, _v0_i); + + _v0_s16 = _mm_min_epi16(_v0_s16, _mm_set1_epi16(127)); + _v0_s16 = _mm_max_epi16(_v0_s16, _mm_set1_epi16(-127)); + + __m128i _v8 = _mm_packs_epi16(_v0_s16, _v0_s16); + +#if defined(__x86_64__) || defined(_M_X64) + return (int32_t)_mm_cvtsi128_si64(_v8); +#else + return _mm_cvtsi128_si32(_v8); +#endif +} + +static inline int64_t float2int8_sse(const __m128 &_v0, const __m128 &_v1) { + // _MM_ROUND_NEAREST round to even + // simulate round to nearest via +/-0.5 with round to zero + __m128 _p5 = _mm_set1_ps(0.5f); + __m128 _signmask = _mm_castsi128_ps(_mm_set1_epi32(1 << 31)); + __m128 _sign0 = _mm_and_ps(_v0, _signmask); + __m128 _sign1 = _mm_and_ps(_v1, _signmask); + __m128 _v0_p5 = _mm_or_ps(_p5, _sign0); + __m128 _v1_p5 = _mm_or_ps(_p5, _sign1); + __m128 _v0_adj = _mm_add_ps(_v0, _v0_p5); + __m128 _v1_adj = _mm_add_ps(_v1, _v1_p5); + __m128i _v0_i = _mm_cvttps_epi32(_v0_adj); + __m128i _v1_i = _mm_cvttps_epi32(_v1_adj); + + __m128i _v01_s16 = _mm_packs_epi32(_v0_i, _v1_i); + + _v01_s16 = 
_mm_min_epi16(_v01_s16, _mm_set1_epi16(127)); + _v01_s16 = _mm_max_epi16(_v01_s16, _mm_set1_epi16(-127)); + + __m128i _v8 = _mm_packs_epi16(_v01_s16, _v01_s16); + +#if defined(__x86_64__) || defined(_M_X64) + return _mm_cvtsi128_si64(_v8); +#else + int64_t v8[2]; + _mm_storeu_si128((__m128i *)v8, _v8); + return v8[0]; +#endif +} + +static inline __m128i float2int8_sse(const __m128 &_v0, const __m128 &_v1, + const __m128 &_v2, const __m128 &_v3) { + // _MM_ROUND_NEAREST round to even + // simulate round to nearest via +/-0.5 with round to zero + __m128 _p5 = _mm_set1_ps(0.5f); + __m128 _signmask = _mm_castsi128_ps(_mm_set1_epi32(1 << 31)); + __m128 _sign0 = _mm_and_ps(_v0, _signmask); + __m128 _sign1 = _mm_and_ps(_v1, _signmask); + __m128 _sign2 = _mm_and_ps(_v2, _signmask); + __m128 _sign3 = _mm_and_ps(_v3, _signmask); + __m128 _v0_p5 = _mm_or_ps(_p5, _sign0); + __m128 _v1_p5 = _mm_or_ps(_p5, _sign1); + __m128 _v2_p5 = _mm_or_ps(_p5, _sign2); + __m128 _v3_p5 = _mm_or_ps(_p5, _sign3); + __m128 _v0_adj = _mm_add_ps(_v0, _v0_p5); + __m128 _v1_adj = _mm_add_ps(_v1, _v1_p5); + __m128 _v2_adj = _mm_add_ps(_v2, _v2_p5); + __m128 _v3_adj = _mm_add_ps(_v3, _v3_p5); + __m128i _v0_i = _mm_cvttps_epi32(_v0_adj); + __m128i _v1_i = _mm_cvttps_epi32(_v1_adj); + __m128i _v2_i = _mm_cvttps_epi32(_v2_adj); + __m128i _v3_i = _mm_cvttps_epi32(_v3_adj); + + __m128i _v01_s16 = _mm_packs_epi32(_v0_i, _v1_i); + __m128i _v23_s16 = _mm_packs_epi32(_v2_i, _v3_i); + + _v01_s16 = _mm_min_epi16(_v01_s16, _mm_set1_epi16(127)); + _v23_s16 = _mm_min_epi16(_v23_s16, _mm_set1_epi16(127)); + _v01_s16 = _mm_max_epi16(_v01_s16, _mm_set1_epi16(-127)); + _v23_s16 = _mm_max_epi16(_v23_s16, _mm_set1_epi16(-127)); + + __m128i _v8 = _mm_packs_epi16(_v01_s16, _v23_s16); + + return _v8; +} + +static inline __m128 bfloat2float_sse(const __m128i &v0) { + __m128i _zero = _mm_setzero_si128(); + __m128i _a = _mm_unpacklo_epi16(_zero, v0); + __m128 _v = _mm_castsi128_ps(_a); + return _v; +} + +static inline __m128i float2bfloat_sse(const __m128 &v0, const __m128 &v1) { +#if __AVX512BF16__ + __m128i _v = (__m128i)_mm256_cvtneps_pbh( + _mm256_insertf128_ps(_mm256_castps128_ps256(v0), v1, 1)); +#else + __m128i _a = _mm_castps_si128(v0); + __m128i _b = _mm_castps_si128(v1); +#if __SSE4_1__ + _a = _mm_srli_epi32(_a, 16); + _b = _mm_srli_epi32(_b, 16); + __m128i _v = _mm_packus_epi32(_a, _b); +#else + _a = _mm_shufflelo_epi16(_a, _MM_SHUFFLE(2, 0, 3, 1)); + _b = _mm_shufflelo_epi16(_b, _MM_SHUFFLE(2, 0, 3, 1)); + _a = _mm_shufflehi_epi16(_a, _MM_SHUFFLE(2, 0, 3, 1)); + _b = _mm_shufflehi_epi16(_b, _MM_SHUFFLE(2, 0, 3, 1)); + __m128i _v = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(_a), _mm_castsi128_ps(_b), _MM_SHUFFLE(2, 0, 2, 0))); +#endif +#endif + return _v; +} + +#ifndef __FMA__ +static inline __m128 _mm_comp_fmadd_ps(const __m128 &_a, const __m128 &_b, + const __m128 &_c) { + return _mm_add_ps(_mm_mul_ps(_a, _b), _c); +} +static inline __m128 _mm_comp_fnmadd_ps(const __m128 &_a, const __m128 &_b, + const __m128 &_c) { + return _mm_sub_ps(_c, _mm_mul_ps(_a, _b)); +} +static inline __m128 _mm_comp_fmsub_ps(const __m128 &_a, const __m128 &_b, + const __m128 &_c) { + return _mm_sub_ps(_mm_mul_ps(_a, _b), _c); +} +static inline __m128 _mm_comp_fnmsub_ps(const __m128 &_a, const __m128 &_b, + const __m128 &_c) { + return _mm_sub_ps(_c, _mm_mul_ps(_mm_mul_ps(_a, _b), _mm_set1_ps(-1))); +} +#else +static inline __m128 _mm_comp_fmadd_ps(const __m128 &_a, const __m128 &_b, + const __m128 &_c) { + return _mm_fmadd_ps(_a, _b, _c); +} +static 
inline __m128 _mm_comp_fnmadd_ps(const __m128 &_a, const __m128 &_b, + const __m128 &_c) { + // return -a * b + c + return _mm_fnmadd_ps(_a, _b, _c); +} +static inline __m128 _mm_comp_fmsub_ps(const __m128 &_a, const __m128 &_b, + const __m128 &_c) { + return _mm_fmsub_ps(_a, _b, _c); +} +static inline __m128 _mm_comp_fnmsub_ps(const __m128 &_a, const __m128 &_b, + const __m128 &_c) { + return _mm_fnmsub_ps(_a, _b, _c); +} +#endif // !__FMA__ + +#if __AVX__ +#ifndef __FMA__ +static inline __m256 _mm256_comp_fmadd_ps(const __m256 &_a, const __m256 &_b, + const __m256 &_c) { + return _mm256_add_ps(_mm256_mul_ps(_a, _b), _c); +} +static inline __m256 _mm256_comp_fnmadd_ps(const __m256 &_a, const __m256 &_b, + const __m256 &_c) { + return _mm256_sub_ps(_c, _mm256_mul_ps(_a, _b)); +} +static inline __m256 _mm256_comp_fmsub_ps(const __m256 &_a, const __m256 &_b, + const __m256 &_c) { + return _mm256_sub_ps(_mm256_mul_ps(_a, _b), _c); +} +static inline __m256 _mm256_comp_fnmsub_ps(const __m256 &_a, const __m256 &_b, + const __m256 &_c) { + return _mm256_sub_ps( + _c, _mm256_mul_ps(_mm256_mul_ps(_a, _b), _mm256_set1_ps(-1))); +} +#else +static inline __m256 _mm256_comp_fmadd_ps(const __m256 &_a, const __m256 &_b, + const __m256 &_c) { + // return a * b + c + return _mm256_fmadd_ps(_a, _b, _c); +} +static inline __m256 _mm256_comp_fnmadd_ps(const __m256 &_a, const __m256 &_b, + const __m256 &_c) { + // return -a * b + c + return _mm256_fnmadd_ps(_a, _b, _c); +} +static inline __m256 _mm256_comp_fmsub_ps(const __m256 &_a, const __m256 &_b, + const __m256 &_c) { + // return a * b - c + return _mm256_fmsub_ps(_a, _b, _c); +} +static inline __m256 _mm256_comp_fnmsub_ps(const __m256 &_a, const __m256 &_b, + const __m256 &_c) { + // return -(a * b) - c + return _mm256_fnmsub_ps(_a, _b, _c); +} +#endif + +static inline __m256 _mm256_fmadd_1_ps(const __m256 &a, const __m256 &b, + float c) { + return _mm256_comp_fmadd_ps(b, _mm256_set1_ps(c), a); +} + +static inline __m256 _mm256_fmrsub_1_ps(const __m256 &a, const __m256 &b, + float c) { + // return a - b * c + return _mm256_comp_fnmadd_ps(b, _mm256_set1_ps(c), a); +} + +static inline void transpose8x12_ps(__m256 &_r0, __m256 &_r1, __m256 &_r2, + __m256 &_r3, __m256 &_r4, __m256 &_r5, + __m256 &_r6, __m256 &_r7, __m256 &_r8, + __m256 &_r9, __m256 &_ra, __m256 &_rb) { + __m256 _tmp0 = _mm256_unpacklo_ps(_r0, _r1); + __m256 _tmp1 = _mm256_unpackhi_ps(_r0, _r1); + __m256 _tmp2 = _mm256_unpacklo_ps(_r2, _r3); + __m256 _tmp3 = _mm256_unpackhi_ps(_r2, _r3); + __m256 _tmp4 = _mm256_unpacklo_ps(_r4, _r5); + __m256 _tmp5 = _mm256_unpackhi_ps(_r4, _r5); + __m256 _tmp6 = _mm256_unpacklo_ps(_r6, _r7); + __m256 _tmp7 = _mm256_unpackhi_ps(_r6, _r7); + __m256 _tmp8 = _mm256_unpacklo_ps(_r8, _r9); + __m256 _tmp9 = _mm256_unpackhi_ps(_r8, _r9); + __m256 _tmpa = _mm256_unpacklo_ps(_ra, _rb); + __m256 _tmpb = _mm256_unpackhi_ps(_ra, _rb); + + __m256 _tmpc = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpd = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpe = _mm256_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpf = _mm256_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpg = _mm256_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmph = _mm256_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpi = _mm256_shuffle_ps(_tmp5, _tmp7, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpj = _mm256_shuffle_ps(_tmp5, _tmp7, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpk = _mm256_shuffle_ps(_tmp8, 
_tmpa, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpl = _mm256_shuffle_ps(_tmp8, _tmpa, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpm = _mm256_shuffle_ps(_tmp9, _tmpb, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpn = _mm256_shuffle_ps(_tmp9, _tmpb, _MM_SHUFFLE(3, 2, 3, 2)); + + _r0 = _mm256_permute2f128_ps(_tmpc, _tmpg, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_ps(_tmpk, _tmpd, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2f128_ps(_tmph, _tmpl, _MM_SHUFFLE(0, 2, 0, 0)); + _r3 = _mm256_permute2f128_ps(_tmpe, _tmpi, _MM_SHUFFLE(0, 2, 0, 0)); + _r4 = _mm256_permute2f128_ps(_tmpm, _tmpf, _MM_SHUFFLE(0, 2, 0, 0)); + _r5 = _mm256_permute2f128_ps(_tmpj, _tmpn, _MM_SHUFFLE(0, 2, 0, 0)); + _r6 = _mm256_permute2f128_ps(_tmpc, _tmpg, _MM_SHUFFLE(0, 3, 0, 1)); + _r7 = _mm256_permute2f128_ps(_tmpk, _tmpd, _MM_SHUFFLE(0, 3, 0, 1)); + _r8 = _mm256_permute2f128_ps(_tmph, _tmpl, _MM_SHUFFLE(0, 3, 0, 1)); + _r9 = _mm256_permute2f128_ps(_tmpe, _tmpi, _MM_SHUFFLE(0, 3, 0, 1)); + _ra = _mm256_permute2f128_ps(_tmpm, _tmpf, _MM_SHUFFLE(0, 3, 0, 1)); + _rb = _mm256_permute2f128_ps(_tmpj, _tmpn, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static inline void transpose8x8_ps(__m256 &_r0, __m256 &_r1, __m256 &_r2, + __m256 &_r3, __m256 &_r4, __m256 &_r5, + __m256 &_r6, __m256 &_r7) { + __m256 _tmp0 = _mm256_unpacklo_ps(_r0, _r1); + __m256 _tmp1 = _mm256_unpackhi_ps(_r0, _r1); + __m256 _tmp2 = _mm256_unpacklo_ps(_r2, _r3); + __m256 _tmp3 = _mm256_unpackhi_ps(_r2, _r3); + __m256 _tmp4 = _mm256_unpacklo_ps(_r4, _r5); + __m256 _tmp5 = _mm256_unpackhi_ps(_r4, _r5); + __m256 _tmp6 = _mm256_unpacklo_ps(_r6, _r7); + __m256 _tmp7 = _mm256_unpackhi_ps(_r6, _r7); + + __m256 _tmp8 = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmp9 = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpa = _mm256_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpb = _mm256_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpc = _mm256_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpd = _mm256_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpe = _mm256_shuffle_ps(_tmp5, _tmp7, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpf = _mm256_shuffle_ps(_tmp5, _tmp7, _MM_SHUFFLE(3, 2, 3, 2)); + + _r0 = _mm256_permute2f128_ps(_tmp8, _tmpc, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_ps(_tmp9, _tmpd, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2f128_ps(_tmpa, _tmpe, _MM_SHUFFLE(0, 2, 0, 0)); + _r3 = _mm256_permute2f128_ps(_tmpb, _tmpf, _MM_SHUFFLE(0, 2, 0, 0)); + _r4 = _mm256_permute2f128_ps(_tmp8, _tmpc, _MM_SHUFFLE(0, 3, 0, 1)); + _r5 = _mm256_permute2f128_ps(_tmp9, _tmpd, _MM_SHUFFLE(0, 3, 0, 1)); + _r6 = _mm256_permute2f128_ps(_tmpa, _tmpe, _MM_SHUFFLE(0, 3, 0, 1)); + _r7 = _mm256_permute2f128_ps(_tmpb, _tmpf, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static inline void transpose8x4_ps(__m256 &_r0, __m256 &_r1, __m256 &_r2, + __m256 &_r3) { + __m256 _tmp0 = _mm256_unpacklo_ps(_r0, _r1); + __m256 _tmp1 = _mm256_unpackhi_ps(_r0, _r1); + __m256 _tmp2 = _mm256_unpacklo_ps(_r2, _r3); + __m256 _tmp3 = _mm256_unpackhi_ps(_r2, _r3); + + __m256 _tmp4 = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmp5 = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmp6 = _mm256_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmp7 = _mm256_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(3, 2, 3, 2)); + + _r0 = _mm256_permute2f128_ps(_tmp4, _tmp5, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_ps(_tmp6, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = 
_mm256_permute2f128_ps(_tmp4, _tmp5, _MM_SHUFFLE(0, 3, 0, 1)); + _r3 = _mm256_permute2f128_ps(_tmp6, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static inline void transpose8x2_ps(__m256 &_r0, __m256 &_r1) { + __m256 _tmp0 = _mm256_unpacklo_ps(_r0, _r1); + __m256 _tmp1 = _mm256_unpackhi_ps(_r0, _r1); + + _r0 = _mm256_permute2f128_ps(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_ps(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static inline void transpose2x8_ps(__m256 &_r0, __m256 &_r1) { + __m256 _tmp0 = _mm256_permute2f128_ps(_r0, _r1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_r0, _r1, _MM_SHUFFLE(0, 3, 0, 1)); + + _r0 = _mm256_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm256_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); +} + +static inline void transpose3x8_ps(__m256 &_r0, __m256 &_r1, __m256 &_r2) { + __m256 _tmp0 = _mm256_permute2f128_ps(_r0, _r1, _MM_SHUFFLE(0, 3, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_r0, _r2, _MM_SHUFFLE(0, 2, 0, 1)); + __m256 _tmp2 = _mm256_permute2f128_ps(_r1, _r2, _MM_SHUFFLE(0, 3, 0, 0)); + + __m256 _tmp4 = _mm256_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(1, 0, 2, 1)); + __m256 _tmp5 = _mm256_shuffle_ps(_tmp1, _tmp2, _MM_SHUFFLE(2, 1, 3, 2)); + + _r0 = _mm256_shuffle_ps(_tmp0, _tmp5, _MM_SHUFFLE(2, 0, 3, 0)); + _r1 = _mm256_shuffle_ps(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _r2 = _mm256_shuffle_ps(_tmp4, _tmp2, _MM_SHUFFLE(3, 0, 3, 1)); +} + +static inline void transpose8x6_ps(__m256 &_r0, __m256 &_r1, __m256 &_r2, + __m256 &_r3, __m256 &_r4, __m256 &_r5) { + __m256 _tmp0 = _mm256_unpacklo_ps(_r0, _r1); + __m256 _tmp1 = _mm256_unpackhi_ps(_r0, _r1); + __m256 _tmp2 = _mm256_unpacklo_ps(_r2, _r3); + __m256 _tmp3 = _mm256_unpackhi_ps(_r2, _r3); + __m256 _tmp4 = _mm256_unpacklo_ps(_r4, _r5); + __m256 _tmp5 = _mm256_unpackhi_ps(_r4, _r5); + + __m256 _tmp6 = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmp7 = _mm256_shuffle_ps(_tmp4, _tmp0, _MM_SHUFFLE(3, 2, 1, 0)); + __m256 _tmp8 = _mm256_shuffle_ps(_tmp2, _tmp4, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmp9 = _mm256_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpa = _mm256_shuffle_ps(_tmp5, _tmp1, _MM_SHUFFLE(3, 2, 1, 0)); + __m256 _tmpb = _mm256_shuffle_ps(_tmp3, _tmp5, _MM_SHUFFLE(3, 2, 3, 2)); + + _r0 = _mm256_permute2f128_ps(_tmp6, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_ps(_tmp8, _tmp9, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2f128_ps(_tmpa, _tmpb, _MM_SHUFFLE(0, 2, 0, 0)); + _r3 = _mm256_permute2f128_ps(_tmp6, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); + _r4 = _mm256_permute2f128_ps(_tmp8, _tmp9, _MM_SHUFFLE(0, 3, 0, 1)); + _r5 = _mm256_permute2f128_ps(_tmpa, _tmpb, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static inline void transpose8x11_ps(__m256 &_r0, __m256 &_r1, __m256 &_r2, + __m256 &_r3, __m256 &_r4, __m256 &_r5, + __m256 &_r6, __m256 &_r7, __m256 &_r8, + __m256 &_r9, __m256 &_ra) { + __m256 _tmp0 = _mm256_unpacklo_ps(_r0, _r1); + __m256 _tmp1 = _mm256_unpackhi_ps(_r0, _r1); + __m256 _tmp2 = _mm256_unpacklo_ps(_r2, _r3); + __m256 _tmp3 = _mm256_unpackhi_ps(_r2, _r3); + __m256 _tmp4 = _mm256_unpacklo_ps(_r4, _r5); + __m256 _tmp5 = _mm256_unpackhi_ps(_r4, _r5); + __m256 _tmp6 = _mm256_unpacklo_ps(_r6, _r7); + __m256 _tmp7 = _mm256_unpackhi_ps(_r6, _r7); + __m256 _tmp8 = _mm256_unpacklo_ps(_r8, _r9); + __m256 _tmp9 = _mm256_unpackhi_ps(_r8, _r9); + __m256 _tmpa = _mm256_unpacklo_ps(_ra, _r0); + __m256 _tmpb = _mm256_shuffle_ps(_ra, _tmp1, _MM_SHUFFLE(3, 2, 1, 2)); + __m256 _tmpc = 
_mm256_unpacklo_ps(_r1, _r2); + __m256 _tmpd = _mm256_unpackhi_ps(_r1, _r2); + __m256 _tmpe = _mm256_unpacklo_ps(_r3, _r4); + __m256 _tmpf = _mm256_unpackhi_ps(_r3, _r4); + __m256 _tmpg = _mm256_unpacklo_ps(_r5, _r6); + __m256 _tmph = _mm256_unpackhi_ps(_r5, _r6); + __m256 _tmpi = _mm256_unpacklo_ps(_r7, _r8); + __m256 _tmpj = _mm256_unpackhi_ps(_r7, _r8); + __m256 _tmpk = _mm256_unpacklo_ps(_r9, _ra); + __m256 _tmpl = _mm256_unpackhi_ps(_r9, _ra); + + __m256 _tmpm = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpn = _mm256_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpo = _mm256_shuffle_ps(_tmp8, _tmpa, _MM_SHUFFLE(3, 0, 1, 0)); + __m256 _tmpp = _mm256_shuffle_ps(_tmpc, _tmpe, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpq = _mm256_shuffle_ps(_tmpg, _tmpi, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpr = _mm256_shuffle_ps(_tmpk, _tmp1, _MM_SHUFFLE(1, 0, 3, 2)); + __m256 _tmps = _mm256_shuffle_ps(_tmp3, _tmp5, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpt = _mm256_shuffle_ps(_tmp7, _tmp9, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpu = _mm256_shuffle_ps(_tmpb, _tmpd, _MM_SHUFFLE(3, 2, 2, 0)); + __m256 _tmpv = _mm256_shuffle_ps(_tmpf, _tmph, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpw = _mm256_shuffle_ps(_tmpj, _tmpl, _MM_SHUFFLE(3, 2, 3, 2)); + + _r0 = _mm256_permute2f128_ps(_tmpm, _tmpn, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_ps(_tmpo, _tmpp, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2f128_ps(_tmpq, _tmpr, _MM_SHUFFLE(0, 2, 0, 0)); + _r3 = _mm256_permute2f128_ps(_tmps, _tmpt, _MM_SHUFFLE(0, 2, 0, 0)); + _r4 = _mm256_permute2f128_ps(_tmpu, _tmpv, _MM_SHUFFLE(0, 2, 0, 0)); + _r5 = _mm256_permute2f128_ps(_tmpw, _tmpm, _MM_SHUFFLE(0, 3, 0, 0)); + _r6 = _mm256_permute2f128_ps(_tmpn, _tmpo, _MM_SHUFFLE(0, 3, 0, 1)); + _r7 = _mm256_permute2f128_ps(_tmpp, _tmpq, _MM_SHUFFLE(0, 3, 0, 1)); + _r8 = _mm256_permute2f128_ps(_tmpr, _tmps, _MM_SHUFFLE(0, 3, 0, 1)); + _r9 = _mm256_permute2f128_ps(_tmpt, _tmpu, _MM_SHUFFLE(0, 3, 0, 1)); + _ra = _mm256_permute2f128_ps(_tmpv, _tmpw, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static void transpose8x18_ps(__m256 &_r0, __m256 &_r1, __m256 &_r2, __m256 &_r3, + __m256 &_r4, __m256 &_r5, __m256 &_r6, __m256 &_r7, + __m256 &_r8, __m256 &_r9, __m256 &_ra, __m256 &_rb, + __m256 &_rc, __m256 &_rd, __m256 &_re, __m256 &_rf, + __m256 &_rg, __m256 &_rh) { + __m256 _tmp0 = _mm256_unpacklo_ps(_r0, _r1); + __m256 _tmp1 = _mm256_unpackhi_ps(_r0, _r1); + __m256 _tmp2 = _mm256_unpacklo_ps(_r2, _r3); + __m256 _tmp3 = _mm256_unpackhi_ps(_r2, _r3); + __m256 _tmp4 = _mm256_unpacklo_ps(_r4, _r5); + __m256 _tmp5 = _mm256_unpackhi_ps(_r4, _r5); + __m256 _tmp6 = _mm256_unpacklo_ps(_r6, _r7); + __m256 _tmp7 = _mm256_unpackhi_ps(_r6, _r7); + __m256 _tmp8 = _mm256_unpacklo_ps(_r8, _r9); + __m256 _tmp9 = _mm256_unpackhi_ps(_r8, _r9); + __m256 _tmpa = _mm256_unpacklo_ps(_ra, _rb); + __m256 _tmpb = _mm256_unpackhi_ps(_ra, _rb); + __m256 _tmpc = _mm256_unpacklo_ps(_rc, _rd); + __m256 _tmpd = _mm256_unpackhi_ps(_rc, _rd); + __m256 _tmpe = _mm256_unpacklo_ps(_re, _rf); + __m256 _tmpf = _mm256_unpackhi_ps(_re, _rf); + __m256 _tmpg = _mm256_unpacklo_ps(_rg, _rh); + __m256 _tmph = _mm256_unpackhi_ps(_rg, _rh); + + __m256 _tmpi = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpj = _mm256_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpk = _mm256_shuffle_ps(_tmp8, _tmpa, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpl = _mm256_shuffle_ps(_tmpc, _tmpe, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpm = _mm256_shuffle_ps(_tmpg, _tmp0, _MM_SHUFFLE(3, 
2, 1, 0)); + __m256 _tmpn = _mm256_shuffle_ps(_tmp2, _tmp4, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpo = _mm256_shuffle_ps(_tmp6, _tmp8, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpp = _mm256_shuffle_ps(_tmpa, _tmpc, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpq = _mm256_shuffle_ps(_tmpe, _tmpg, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpr = _mm256_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmps = _mm256_shuffle_ps(_tmp5, _tmp7, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpt = _mm256_shuffle_ps(_tmp9, _tmpb, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpu = _mm256_shuffle_ps(_tmpd, _tmpf, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpv = _mm256_shuffle_ps(_tmph, _tmp1, _MM_SHUFFLE(3, 2, 1, 0)); + __m256 _tmpw = _mm256_shuffle_ps(_tmp3, _tmp5, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpx = _mm256_shuffle_ps(_tmp7, _tmp9, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpy = _mm256_shuffle_ps(_tmpb, _tmpd, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpz = _mm256_shuffle_ps(_tmpf, _tmph, _MM_SHUFFLE(3, 2, 3, 2)); + + _r0 = _mm256_permute2f128_ps(_tmpi, _tmpj, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_ps(_tmpk, _tmpl, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2f128_ps(_tmpm, _tmpn, _MM_SHUFFLE(0, 2, 0, 0)); + _r3 = _mm256_permute2f128_ps(_tmpo, _tmpp, _MM_SHUFFLE(0, 2, 0, 0)); + _r4 = _mm256_permute2f128_ps(_tmpq, _tmpr, _MM_SHUFFLE(0, 2, 0, 0)); + _r5 = _mm256_permute2f128_ps(_tmps, _tmpt, _MM_SHUFFLE(0, 2, 0, 0)); + _r6 = _mm256_permute2f128_ps(_tmpu, _tmpv, _MM_SHUFFLE(0, 2, 0, 0)); + _r7 = _mm256_permute2f128_ps(_tmpw, _tmpx, _MM_SHUFFLE(0, 2, 0, 0)); + _r8 = _mm256_permute2f128_ps(_tmpy, _tmpz, _MM_SHUFFLE(0, 2, 0, 0)); + _r9 = _mm256_permute2f128_ps(_tmpi, _tmpj, _MM_SHUFFLE(0, 3, 0, 1)); + _ra = _mm256_permute2f128_ps(_tmpk, _tmpl, _MM_SHUFFLE(0, 3, 0, 1)); + _rb = _mm256_permute2f128_ps(_tmpm, _tmpn, _MM_SHUFFLE(0, 3, 0, 1)); + _rc = _mm256_permute2f128_ps(_tmpo, _tmpp, _MM_SHUFFLE(0, 3, 0, 1)); + _rd = _mm256_permute2f128_ps(_tmpq, _tmpr, _MM_SHUFFLE(0, 3, 0, 1)); + _re = _mm256_permute2f128_ps(_tmps, _tmpt, _MM_SHUFFLE(0, 3, 0, 1)); + _rf = _mm256_permute2f128_ps(_tmpu, _tmpv, _MM_SHUFFLE(0, 3, 0, 1)); + _rg = _mm256_permute2f128_ps(_tmpw, _tmpx, _MM_SHUFFLE(0, 3, 0, 1)); + _rh = _mm256_permute2f128_ps(_tmpy, _tmpz, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static inline __m256 HorizontalSums(__m256 &v0, __m256 &v1, __m256 &v2, + __m256 &v3, __m256 &v4, __m256 &v5, + __m256 &v6, __m256 &v7) { + const __m256 s01 = _mm256_hadd_ps(v0, v1); + const __m256 s23 = _mm256_hadd_ps(v2, v3); + const __m256 s45 = _mm256_hadd_ps(v4, v5); + const __m256 s67 = _mm256_hadd_ps(v6, v7); + const __m256 s0123 = _mm256_hadd_ps(s01, s23); + const __m256 s4556 = _mm256_hadd_ps(s45, s67); + + // inter-lane shuffle + const __m256 vb0 = _mm256_blend_ps(s0123, s4556, 0xF0); + const __m256 vb1 = _mm256_permute2f128_ps(s0123, s4556, 0x21); + + return _mm256_add_ps(vb0, vb1); +} + +static inline __m128 HorizontalSums(__m256 &v0, __m256 &v1, __m256 &v2, + __m256 &v3) { + const __m256 s01 = _mm256_hadd_ps(v0, v1); + const __m256 s23 = _mm256_hadd_ps(v2, v3); + const __m256 s0123 = _mm256_hadd_ps(s01, s23); + + return _mm_add_ps(_mm256_extractf128_ps(s0123, 1), + _mm256_castps256_ps128(s0123)); +} + +static inline __m128 HorizontalSums(__m256 &v0, __m256 &v1, __m256 &v2) { + const __m256 v3 = _mm256_set1_ps(0.0f); + const __m256 s01 = _mm256_hadd_ps(v0, v1); + const __m256 s23 = _mm256_hadd_ps(v2, v3); + const __m256 s0123 = _mm256_hadd_ps(s01, s23); + + return _mm_add_ps(_mm256_extractf128_ps(s0123, 1), + _mm256_castps256_ps128(s0123)); +} + 
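The hadd cascade above, followed by the blend/permute2f128 pair, folds both 128-bit lanes, so lane i of the eight-register HorizontalSums holds the complete 8-element sum of input register i (the 4- and 3-register variants pack their sums into a __m128 the same way). A scalar statement of that contract, useful as the oracle in a unit test of the shuffle network (helper name is illustrative):

// Scalar contract of HorizontalSums(v0..v7): out[i] = sum of the 8 floats
// in row i, matching lane i of the vector result.
static inline void horizontal_sums_ref(const float rows[8][8], float out[8]) {
    for (int i = 0; i < 8; i++) {
        float s = 0.0f;
        for (int j = 0; j < 8; j++)
            s += rows[i][j];
        out[i] = s;
    }
}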
+static inline float _mm256_reduce_add_ps(__m256 x) { + /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */ + const __m128 x128 = + _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); + /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */ + const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */ + const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + /* Conversion to float is a no-op on x86-64 */ + return _mm_cvtss_f32(x32); +} + +static inline float _mm256_reduce_max_ps(__m256 x) { + const __m128 x128 = + _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); + const __m128 x64 = _mm_max_ps(x128, _mm_movehl_ps(x128, x128)); + const __m128 x32 = _mm_max_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + return _mm_cvtss_f32(x32); +} + +static inline int64_t float2int8_avx(const __m256 &_v0) { + // _MM_FROUND_TO_NEAREST_INT round to even + // simulate round to nearest via +/-0.5 with round to zero + __m256 _p5 = _mm256_set1_ps(0.5f); + __m256 _signmask = _mm256_castsi256_ps(_mm256_set1_epi32(1 << 31)); + __m256 _sign = _mm256_and_ps(_v0, _signmask); + __m256 _v0_p5 = _mm256_or_ps(_p5, _sign); + __m256 _v0_adj = _mm256_add_ps(_v0, _v0_p5); + __m256i _v0_i = _mm256_cvttps_epi32(_v0_adj); + +#if __AVX2__ + __m256i _v01_s16 = _mm256_packs_epi32(_v0_i, _v0_i); + _v01_s16 = _mm256_permute4x64_epi64(_v01_s16, 0xd8); + + __m128i _v01_s16low = _mm256_extractf128_si256(_v01_s16, 0); +#else // __AVX2__ + __m128i _v0_i_low = _mm256_extractf128_si256(_v0_i, 0); + __m128i _v0_i_high = _mm256_extractf128_si256(_v0_i, 1); + + __m128i _v01_s16low = _mm_packs_epi32(_v0_i_low, _v0_i_high); +#endif // __AVX2__ + + _v01_s16low = _mm_min_epi16(_v01_s16low, _mm_set1_epi16(127)); + _v01_s16low = _mm_max_epi16(_v01_s16low, _mm_set1_epi16(-127)); + + __m128i _v8 = _mm_packs_epi16(_v01_s16low, _v01_s16low); + +#if defined(__x86_64__) || defined(_M_X64) + return _mm_cvtsi128_si64(_v8); +#else + int64_t v8[2]; + _mm_storeu_si128((__m128i *)v8, _v8); + return v8[0]; +#endif +} + +static inline __m128i float2int8_avx(const __m256 &_v0, const __m256 &_v1) { + // _MM_FROUND_TO_NEAREST_INT round to even + // simulate round to nearest via +/-0.5 with round to zero + __m256 _p5 = _mm256_set1_ps(0.5f); + __m256 _signmask = _mm256_castsi256_ps(_mm256_set1_epi32(1 << 31)); + __m256 _sign0 = _mm256_and_ps(_v0, _signmask); + __m256 _sign1 = _mm256_and_ps(_v1, _signmask); + __m256 _v0_p5 = _mm256_or_ps(_p5, _sign0); + __m256 _v1_p5 = _mm256_or_ps(_p5, _sign1); + __m256 _v0_adj = _mm256_add_ps(_v0, _v0_p5); + __m256 _v1_adj = _mm256_add_ps(_v1, _v1_p5); + __m256i _v0_i = _mm256_cvttps_epi32(_v0_adj); + __m256i _v1_i = _mm256_cvttps_epi32(_v1_adj); + +#if __AVX2__ + __m256i _v01_s16 = _mm256_packs_epi32(_v0_i, _v1_i); + _v01_s16 = _mm256_permute4x64_epi64(_v01_s16, 0xd8); + + _v01_s16 = _mm256_min_epi16(_v01_s16, _mm256_set1_epi16(127)); + _v01_s16 = _mm256_max_epi16(_v01_s16, _mm256_set1_epi16(-127)); + + __m256i _v8 = _mm256_packs_epi16(_v01_s16, _v01_s16); + _v8 = _mm256_permute4x64_epi64(_v8, 0xd8); + + return _mm256_extractf128_si256(_v8, 0); +#else // __AVX2__ + __m128i _v0_i_low = _mm256_extractf128_si256(_v0_i, 0); + __m128i _v0_i_high = _mm256_extractf128_si256(_v0_i, 1); + __m128i _v1_i_low = _mm256_extractf128_si256(_v1_i, 0); + __m128i _v1_i_high = _mm256_extractf128_si256(_v1_i, 1); + + __m128i _v01_s16low = _mm_packs_epi32(_v0_i_low, _v0_i_high); + __m128i _v01_s16high = _mm_packs_epi32(_v1_i_low, _v1_i_high); + + _v01_s16low = _mm_min_epi16(_v01_s16low, 
_mm_set1_epi16(127)); + _v01_s16high = _mm_min_epi16(_v01_s16high, _mm_set1_epi16(127)); + _v01_s16low = _mm_max_epi16(_v01_s16low, _mm_set1_epi16(-127)); + _v01_s16high = _mm_max_epi16(_v01_s16high, _mm_set1_epi16(-127)); + + __m128i _v8 = _mm_packs_epi16(_v01_s16low, _v01_s16high); + return _v8; +#endif // __AVX2__ +} + +static inline void _mm256_comp_fmadd_ps4(__m256 &_sum, const __m256 &_w0, + const __m256 &_w1, const __m256 &_w2, + const __m256 &_w3, const __m256 &_v0, + const __m256 &_v1, const __m256 &_v2, + const __m256 &_v3) { + __m256 _mul0 = _mm256_mul_ps(_w0, _v0); + __m256 _mul1 = _mm256_mul_ps(_w1, _v1); + __m256 _sum01 = _mm256_add_ps(_mul0, _mul1); + __m256 _mul2 = _mm256_mul_ps(_w2, _v2); + __m256 _mul3 = _mm256_mul_ps(_w3, _v3); + __m256 _sum23 = _mm256_add_ps(_mul2, _mul3); + __m256 _sum0123 = _mm256_add_ps(_sum01, _sum23); + _sum = _mm256_add_ps(_sum, _sum0123); +} + +static inline void +_mm256_comp_fmadd_ps8(__m256 &_sum, const __m256 &_w0, const __m256 &_w1, + const __m256 &_w2, const __m256 &_w3, const __m256 &_w4, + const __m256 &_w5, const __m256 &_w6, const __m256 &_w7, + const __m256 &_v0, const __m256 &_v1, const __m256 &_v2, + const __m256 &_v3, const __m256 &_v4, const __m256 &_v5, + const __m256 &_v6, const __m256 &_v7) { + _mm256_comp_fmadd_ps4(_sum, _w0, _w1, _w2, _w3, _v0, _v1, _v2, _v3); + + _mm256_comp_fmadd_ps4(_sum, _w4, _w5, _w6, _w7, _v4, _v5, _v6, _v7); +} + +static inline __m256 bfloat2float_avx(const __m128i &v0) { +#if __AVX512BF16__ + __m256 _v = _mm256_cvtpbh_ps((__m128bh)v0); +#else + __m128i _zero = _mm_setzero_si128(); + __m128i _a = _mm_unpacklo_epi16(_zero, v0); + __m128i _b = _mm_unpackhi_epi16(_zero, v0); + __m256 _v = _mm256_castsi256_ps( + _mm256_insertf128_si256(_mm256_castsi128_si256(_a), _b, 1)); +#endif + return _v; +} + +static inline __m128i float2bfloat_avx(const __m256 &v0) { +#if __AVX512BF16__ + __m128i _v = (__m128i)_mm256_cvtneps_pbh(v0); +#else + __m256i _ab = _mm256_castps_si256(v0); +#if __AVX2__ + _ab = _mm256_srli_epi32(_ab, 16); + __m128i _a = _mm256_extractf128_si256(_ab, 0); + __m128i _b = _mm256_extractf128_si256(_ab, 1); +#else + __m128i _a = _mm256_extractf128_si256(_ab, 0); + __m128i _b = _mm256_extractf128_si256(_ab, 1); + _a = _mm_srli_epi32(_a, 16); + _b = _mm_srli_epi32(_b, 16); +#endif + __m128i _v = _mm_packus_epi32(_a, _b); +#endif + return _v; +} + +static inline __m256i float2bfloat_avx(const __m256 &v0, const __m256 &v1) { +#if __AVX512BF16__ + __m128i _v0 = (__m128i)_mm256_cvtneps_pbh(v0); + __m128i _v1 = (__m128i)_mm256_cvtneps_pbh(v1); + __m256i _v = _mm256_insertf128_si256(_mm256_castsi128_si256(_v0), _v1, 1); +#else + __m256i _a = _mm256_castps_si256(v0); + __m256i _b = _mm256_castps_si256(v1); +#if __AVX2__ + _a = _mm256_srli_epi32(_a, 16); + _b = _mm256_srli_epi32(_b, 16); + __m256i _v = _mm256_packus_epi32(_a, _b); + _v = _mm256_permute4x64_epi64(_v, _MM_SHUFFLE(3, 1, 2, 0)); +#else + __m128i _a0 = _mm256_extractf128_si256(_a, 0); + __m128i _a1 = _mm256_extractf128_si256(_a, 1); + __m128i _b0 = _mm256_extractf128_si256(_b, 0); + __m128i _b1 = _mm256_extractf128_si256(_b, 1); + _a0 = _mm_srli_epi32(_a0, 16); + _a1 = _mm_srli_epi32(_a1, 16); + _b0 = _mm_srli_epi32(_b0, 16); + _b1 = _mm_srli_epi32(_b1, 16); + __m128i _v0 = _mm_packus_epi32(_a0, _a1); + __m128i _v1 = _mm_packus_epi32(_b0, _b1); + __m256i _v = _mm256_insertf128_si256(_mm256_castsi128_si256(_v0), _v1, 1); +#endif +#endif + return _v; +} + +#if __AVX2__ +static inline void transpose8x2_epi32(__m256i &_r0, __m256i &_r1) { + __m256i 
_tmp0 = _mm256_unpacklo_epi32(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi32(_r0, _r1); + + _r0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static inline void transpose16x8_epi16(__m256i &_r0, __m256i &_r1, __m256i &_r2, + __m256i &_r3, __m256i &_r4, __m256i &_r5, + __m256i &_r6, __m256i &_r7) { + __m256i _tmp0 = _mm256_unpacklo_epi16(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi16(_r0, _r1); + __m256i _tmp2 = _mm256_unpacklo_epi16(_r2, _r3); + __m256i _tmp3 = _mm256_unpackhi_epi16(_r2, _r3); + __m256i _tmp4 = _mm256_unpacklo_epi16(_r4, _r5); + __m256i _tmp5 = _mm256_unpackhi_epi16(_r4, _r5); + __m256i _tmp6 = _mm256_unpacklo_epi16(_r6, _r7); + __m256i _tmp7 = _mm256_unpackhi_epi16(_r6, _r7); + + __m256i _tmpg = _mm256_unpacklo_epi32(_tmp0, _tmp2); + __m256i _tmph = _mm256_unpackhi_epi32(_tmp0, _tmp2); + __m256i _tmpi = _mm256_unpacklo_epi32(_tmp1, _tmp3); + __m256i _tmpj = _mm256_unpackhi_epi32(_tmp1, _tmp3); + __m256i _tmpk = _mm256_unpacklo_epi32(_tmp4, _tmp6); + __m256i _tmpl = _mm256_unpackhi_epi32(_tmp4, _tmp6); + __m256i _tmpm = _mm256_unpacklo_epi32(_tmp5, _tmp7); + __m256i _tmpn = _mm256_unpackhi_epi32(_tmp5, _tmp7); + + _tmp0 = _mm256_unpacklo_epi64(_tmpg, _tmpk); + _tmp1 = _mm256_unpackhi_epi64(_tmpg, _tmpk); + _tmp2 = _mm256_unpacklo_epi64(_tmph, _tmpl); + _tmp3 = _mm256_unpackhi_epi64(_tmph, _tmpl); + _tmp4 = _mm256_unpacklo_epi64(_tmpi, _tmpm); + _tmp5 = _mm256_unpackhi_epi64(_tmpi, _tmpm); + _tmp6 = _mm256_unpacklo_epi64(_tmpj, _tmpn); + _tmp7 = _mm256_unpackhi_epi64(_tmpj, _tmpn); + + _r0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2x128_si256(_tmp4, _tmp5, _MM_SHUFFLE(0, 2, 0, 0)); + _r3 = _mm256_permute2x128_si256(_tmp6, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + _r4 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); + _r5 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 3, 0, 1)); + _r6 = _mm256_permute2x128_si256(_tmp4, _tmp5, _MM_SHUFFLE(0, 3, 0, 1)); + _r7 = _mm256_permute2x128_si256(_tmp6, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); +} + +#if __AVX512F__ +static inline void transpose16x16_ps(__m512 &_r0, __m512 &_r1, __m512 &_r2, + __m512 &_r3, __m512 &_r4, __m512 &_r5, + __m512 &_r6, __m512 &_r7, __m512 &_r8, + __m512 &_r9, __m512 &_ra, __m512 &_rb, + __m512 &_rc, __m512 &_rd, __m512 &_re, + __m512 &_rf) { + __m512 _tmp0 = _mm512_unpacklo_ps(_r0, _r1); + __m512 _tmp1 = _mm512_unpackhi_ps(_r0, _r1); + __m512 _tmp2 = _mm512_unpacklo_ps(_r2, _r3); + __m512 _tmp3 = _mm512_unpackhi_ps(_r2, _r3); + __m512 _tmp4 = _mm512_unpacklo_ps(_r4, _r5); + __m512 _tmp5 = _mm512_unpackhi_ps(_r4, _r5); + __m512 _tmp6 = _mm512_unpacklo_ps(_r6, _r7); + __m512 _tmp7 = _mm512_unpackhi_ps(_r6, _r7); + __m512 _tmp8 = _mm512_unpacklo_ps(_r8, _r9); + __m512 _tmp9 = _mm512_unpackhi_ps(_r8, _r9); + __m512 _tmpa = _mm512_unpacklo_ps(_ra, _rb); + __m512 _tmpb = _mm512_unpackhi_ps(_ra, _rb); + __m512 _tmpc = _mm512_unpacklo_ps(_rc, _rd); + __m512 _tmpd = _mm512_unpackhi_ps(_rc, _rd); + __m512 _tmpe = _mm512_unpacklo_ps(_re, _rf); + __m512 _tmpf = _mm512_unpackhi_ps(_re, _rf); + + __m512 _tmpg = _mm512_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmph = _mm512_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpi = _mm512_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpj = _mm512_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(3, 2, 
3, 2)); + __m512 _tmpk = _mm512_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpl = _mm512_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpm = _mm512_shuffle_ps(_tmp5, _tmp7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpn = _mm512_shuffle_ps(_tmp5, _tmp7, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpo = _mm512_shuffle_ps(_tmp8, _tmpa, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpp = _mm512_shuffle_ps(_tmp8, _tmpa, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpq = _mm512_shuffle_ps(_tmp9, _tmpb, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpr = _mm512_shuffle_ps(_tmp9, _tmpb, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmps = _mm512_shuffle_ps(_tmpc, _tmpe, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpt = _mm512_shuffle_ps(_tmpc, _tmpe, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpu = _mm512_shuffle_ps(_tmpd, _tmpf, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpv = _mm512_shuffle_ps(_tmpd, _tmpf, _MM_SHUFFLE(3, 2, 3, 2)); + + _tmp0 = _mm512_shuffle_f32x4(_tmpg, _tmpk, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp1 = _mm512_shuffle_f32x4(_tmpo, _tmps, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp2 = _mm512_shuffle_f32x4(_tmph, _tmpl, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp3 = _mm512_shuffle_f32x4(_tmpp, _tmpt, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp4 = _mm512_shuffle_f32x4(_tmpi, _tmpm, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp5 = _mm512_shuffle_f32x4(_tmpq, _tmpu, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp6 = _mm512_shuffle_f32x4(_tmpj, _tmpn, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp7 = _mm512_shuffle_f32x4(_tmpr, _tmpv, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp8 = _mm512_shuffle_f32x4(_tmpg, _tmpk, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp9 = _mm512_shuffle_f32x4(_tmpo, _tmps, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpa = _mm512_shuffle_f32x4(_tmph, _tmpl, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpb = _mm512_shuffle_f32x4(_tmpp, _tmpt, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpc = _mm512_shuffle_f32x4(_tmpi, _tmpm, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpd = _mm512_shuffle_f32x4(_tmpq, _tmpu, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpe = _mm512_shuffle_f32x4(_tmpj, _tmpn, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpf = _mm512_shuffle_f32x4(_tmpr, _tmpv, _MM_SHUFFLE(3, 1, 3, 1)); + + _r0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _r2 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _r3 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _r4 = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(2, 0, 2, 0)); + _r5 = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(2, 0, 2, 0)); + _r6 = _mm512_shuffle_f32x4(_tmpc, _tmpd, _MM_SHUFFLE(2, 0, 2, 0)); + _r7 = _mm512_shuffle_f32x4(_tmpe, _tmpf, _MM_SHUFFLE(2, 0, 2, 0)); + _r8 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _r9 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _ra = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _rb = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _rc = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(3, 1, 3, 1)); + _rd = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(3, 1, 3, 1)); + _re = _mm512_shuffle_f32x4(_tmpc, _tmpd, _MM_SHUFFLE(3, 1, 3, 1)); + _rf = _mm512_shuffle_f32x4(_tmpe, _tmpf, _MM_SHUFFLE(3, 1, 3, 1)); +} + +static inline void transpose16x12_ps(__m512 &_r0, __m512 &_r1, __m512 &_r2, + __m512 &_r3, __m512 &_r4, __m512 &_r5, + __m512 &_r6, __m512 &_r7, __m512 &_r8, + __m512 &_r9, __m512 &_ra, __m512 &_rb) { + __m512 _tmp0 = _mm512_unpacklo_ps(_r0, _r1); + __m512 _tmp1 = _mm512_unpackhi_ps(_r0, _r1); + __m512 _tmp2 = _mm512_unpacklo_ps(_r2, _r3); + __m512 _tmp3 = _mm512_unpackhi_ps(_r2, _r3); + 
__m512 _tmp4 = _mm512_unpacklo_ps(_r4, _r5); + __m512 _tmp5 = _mm512_unpackhi_ps(_r4, _r5); + __m512 _tmp6 = _mm512_unpacklo_ps(_r6, _r7); + __m512 _tmp7 = _mm512_unpackhi_ps(_r6, _r7); + __m512 _tmp8 = _mm512_unpacklo_ps(_r8, _r9); + __m512 _tmp9 = _mm512_unpackhi_ps(_r8, _r9); + __m512 _tmpa = _mm512_unpacklo_ps(_ra, _rb); + __m512 _tmpb = _mm512_unpackhi_ps(_ra, _rb); + + __m512 _tmpc = _mm512_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpd = _mm512_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpe = _mm512_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpf = _mm512_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpg = _mm512_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmph = _mm512_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpi = _mm512_shuffle_ps(_tmp5, _tmp7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpj = _mm512_shuffle_ps(_tmp5, _tmp7, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpk = _mm512_shuffle_ps(_tmp8, _tmpa, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpl = _mm512_shuffle_ps(_tmp8, _tmpa, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpm = _mm512_shuffle_ps(_tmp9, _tmpb, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpn = _mm512_shuffle_ps(_tmp9, _tmpb, _MM_SHUFFLE(3, 2, 3, 2)); + + _tmp0 = _mm512_shuffle_f32x4(_tmpc, _tmpg, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp1 = _mm512_shuffle_f32x4(_tmpk, _tmpd, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp2 = _mm512_shuffle_f32x4(_tmph, _tmpl, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp3 = _mm512_shuffle_f32x4(_tmpe, _tmpi, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp4 = _mm512_shuffle_f32x4(_tmpm, _tmpf, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp5 = _mm512_shuffle_f32x4(_tmpj, _tmpn, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp6 = _mm512_shuffle_f32x4(_tmpc, _tmpg, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp7 = _mm512_shuffle_f32x4(_tmpk, _tmpd, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp8 = _mm512_shuffle_f32x4(_tmph, _tmpl, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp9 = _mm512_shuffle_f32x4(_tmpe, _tmpi, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpa = _mm512_shuffle_f32x4(_tmpm, _tmpf, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpb = _mm512_shuffle_f32x4(_tmpj, _tmpn, _MM_SHUFFLE(3, 1, 3, 1)); + + _r0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _r2 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _r3 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _r4 = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(2, 0, 2, 0)); + _r5 = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(2, 0, 2, 0)); + _r6 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _r7 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _r8 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _r9 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _ra = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(3, 1, 3, 1)); + _rb = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(3, 1, 3, 1)); +} + +static inline void transpose16x8_ps(__m512 &_r0, __m512 &_r1, __m512 &_r2, + __m512 &_r3, __m512 &_r4, __m512 &_r5, + __m512 &_r6, __m512 &_r7) { + __m512 _tmp0 = _mm512_unpacklo_ps(_r0, _r1); + __m512 _tmp1 = _mm512_unpackhi_ps(_r0, _r1); + __m512 _tmp2 = _mm512_unpacklo_ps(_r2, _r3); + __m512 _tmp3 = _mm512_unpackhi_ps(_r2, _r3); + __m512 _tmp4 = _mm512_unpacklo_ps(_r4, _r5); + __m512 _tmp5 = _mm512_unpackhi_ps(_r4, _r5); + __m512 _tmp6 = _mm512_unpacklo_ps(_r6, _r7); + __m512 _tmp7 = _mm512_unpackhi_ps(_r6, _r7); + + __m512 _tmp8 = _mm512_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 
0, 1, 0)); + __m512 _tmp9 = _mm512_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpa = _mm512_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpb = _mm512_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpc = _mm512_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpd = _mm512_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpe = _mm512_shuffle_ps(_tmp5, _tmp7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpf = _mm512_shuffle_ps(_tmp5, _tmp7, _MM_SHUFFLE(3, 2, 3, 2)); + + _tmp0 = _mm512_shuffle_f32x4(_tmp8, _tmpc, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp1 = _mm512_shuffle_f32x4(_tmp9, _tmpd, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp2 = _mm512_shuffle_f32x4(_tmpa, _tmpe, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp3 = _mm512_shuffle_f32x4(_tmpb, _tmpf, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp4 = _mm512_shuffle_f32x4(_tmp8, _tmpc, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp5 = _mm512_shuffle_f32x4(_tmp9, _tmpd, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp6 = _mm512_shuffle_f32x4(_tmpa, _tmpe, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp7 = _mm512_shuffle_f32x4(_tmpb, _tmpf, _MM_SHUFFLE(3, 1, 3, 1)); + + _r0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _r2 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _r3 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _r4 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _r5 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _r6 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _r7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); +} + +static inline void transpose16x4_ps(__m512 &_r0, __m512 &_r1, __m512 &_r2, + __m512 &_r3) { + __m512 _tmp0 = _mm512_unpacklo_ps(_r0, _r1); + __m512 _tmp1 = _mm512_unpackhi_ps(_r0, _r1); + __m512 _tmp2 = _mm512_unpacklo_ps(_r2, _r3); + __m512 _tmp3 = _mm512_unpackhi_ps(_r2, _r3); + + __m512 _tmp4 = _mm512_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp5 = _mm512_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp6 = _mm512_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp7 = _mm512_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(3, 2, 3, 2)); + + _tmp0 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp1 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp2 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp3 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + + _r0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _r2 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _r3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); +} + +static inline void transpose16x2_ps(__m512 &_r0, __m512 &_r1) { + __m512 _tmp0 = _mm512_unpacklo_ps(_r0, _r1); + __m512 _tmp1 = _mm512_unpackhi_ps(_r0, _r1); + + __m512 _tmp2 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + + _r0 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); +} + +static inline void transpose8x16_ps(__m256 &_r0, __m256 &_r1, __m256 &_r2, + __m256 &_r3, __m256 &_r4, __m256 &_r5, + __m256 &_r6, __m256 &_r7, __m256 &_r8, + __m256 &_r9, __m256 &_ra, __m256 &_rb, + __m256 &_rc, __m256 &_rd, __m256 &_re, + __m256 &_rf) { + __m256 _tmp0 = 
_mm256_unpacklo_ps(_r0, _r1); + __m256 _tmp1 = _mm256_unpackhi_ps(_r0, _r1); + __m256 _tmp2 = _mm256_unpacklo_ps(_r2, _r3); + __m256 _tmp3 = _mm256_unpackhi_ps(_r2, _r3); + __m256 _tmp4 = _mm256_unpacklo_ps(_r4, _r5); + __m256 _tmp5 = _mm256_unpackhi_ps(_r4, _r5); + __m256 _tmp6 = _mm256_unpacklo_ps(_r6, _r7); + __m256 _tmp7 = _mm256_unpackhi_ps(_r6, _r7); + __m256 _tmp8 = _mm256_unpacklo_ps(_r8, _r9); + __m256 _tmp9 = _mm256_unpackhi_ps(_r8, _r9); + __m256 _tmpa = _mm256_unpacklo_ps(_ra, _rb); + __m256 _tmpb = _mm256_unpackhi_ps(_ra, _rb); + __m256 _tmpc = _mm256_unpacklo_ps(_rc, _rd); + __m256 _tmpd = _mm256_unpackhi_ps(_rc, _rd); + __m256 _tmpe = _mm256_unpacklo_ps(_re, _rf); + __m256 _tmpf = _mm256_unpackhi_ps(_re, _rf); + + __m256 _tmpg = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmph = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpi = _mm256_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpj = _mm256_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpk = _mm256_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpl = _mm256_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpm = _mm256_shuffle_ps(_tmp5, _tmp7, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpn = _mm256_shuffle_ps(_tmp5, _tmp7, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpo = _mm256_shuffle_ps(_tmp8, _tmpa, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpp = _mm256_shuffle_ps(_tmp8, _tmpa, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpq = _mm256_shuffle_ps(_tmp9, _tmpb, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpr = _mm256_shuffle_ps(_tmp9, _tmpb, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmps = _mm256_shuffle_ps(_tmpc, _tmpe, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpt = _mm256_shuffle_ps(_tmpc, _tmpe, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpu = _mm256_shuffle_ps(_tmpd, _tmpf, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpv = _mm256_shuffle_ps(_tmpd, _tmpf, _MM_SHUFFLE(3, 2, 3, 2)); + + _r0 = _mm256_permute2f128_ps(_tmpg, _tmpk, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_ps(_tmpo, _tmps, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2f128_ps(_tmph, _tmpl, _MM_SHUFFLE(0, 2, 0, 0)); + _r3 = _mm256_permute2f128_ps(_tmpp, _tmpt, _MM_SHUFFLE(0, 2, 0, 0)); + _r4 = _mm256_permute2f128_ps(_tmpi, _tmpm, _MM_SHUFFLE(0, 2, 0, 0)); + _r5 = _mm256_permute2f128_ps(_tmpq, _tmpu, _MM_SHUFFLE(0, 2, 0, 0)); + _r6 = _mm256_permute2f128_ps(_tmpj, _tmpn, _MM_SHUFFLE(0, 2, 0, 0)); + _r7 = _mm256_permute2f128_ps(_tmpr, _tmpv, _MM_SHUFFLE(0, 2, 0, 0)); + _r8 = _mm256_permute2f128_ps(_tmpg, _tmpk, _MM_SHUFFLE(0, 3, 0, 1)); + _r9 = _mm256_permute2f128_ps(_tmpo, _tmps, _MM_SHUFFLE(0, 3, 0, 1)); + _ra = _mm256_permute2f128_ps(_tmph, _tmpl, _MM_SHUFFLE(0, 3, 0, 1)); + _rb = _mm256_permute2f128_ps(_tmpp, _tmpt, _MM_SHUFFLE(0, 3, 0, 1)); + _rc = _mm256_permute2f128_ps(_tmpi, _tmpm, _MM_SHUFFLE(0, 3, 0, 1)); + _rd = _mm256_permute2f128_ps(_tmpq, _tmpu, _MM_SHUFFLE(0, 3, 0, 1)); + _re = _mm256_permute2f128_ps(_tmpj, _tmpn, _MM_SHUFFLE(0, 3, 0, 1)); + _rf = _mm256_permute2f128_ps(_tmpr, _tmpv, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static inline void +transpose16x16_epi16(__m256i &_r0, __m256i &_r1, __m256i &_r2, __m256i &_r3, + __m256i &_r4, __m256i &_r5, __m256i &_r6, __m256i &_r7, + __m256i &_r8, __m256i &_r9, __m256i &_ra, __m256i &_rb, + __m256i &_rc, __m256i &_rd, __m256i &_re, __m256i &_rf) { + __m256i _tmp0 = _mm256_unpacklo_epi16(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi16(_r0, _r1); + __m256i _tmp2 = _mm256_unpacklo_epi16(_r2, _r3); + __m256i _tmp3 = _mm256_unpackhi_epi16(_r2, 
+
+static inline void
+transpose16x16_epi16(__m256i &_r0, __m256i &_r1, __m256i &_r2, __m256i &_r3,
+                     __m256i &_r4, __m256i &_r5, __m256i &_r6, __m256i &_r7,
+                     __m256i &_r8, __m256i &_r9, __m256i &_ra, __m256i &_rb,
+                     __m256i &_rc, __m256i &_rd, __m256i &_re, __m256i &_rf) {
+    __m256i _tmp0 = _mm256_unpacklo_epi16(_r0, _r1);
+    __m256i _tmp1 = _mm256_unpackhi_epi16(_r0, _r1);
+    __m256i _tmp2 = _mm256_unpacklo_epi16(_r2, _r3);
+    __m256i _tmp3 = _mm256_unpackhi_epi16(_r2, _r3);
+    __m256i _tmp4 = _mm256_unpacklo_epi16(_r4, _r5);
+    __m256i _tmp5 = _mm256_unpackhi_epi16(_r4, _r5);
+    __m256i _tmp6 = _mm256_unpacklo_epi16(_r6, _r7);
+    __m256i _tmp7 = _mm256_unpackhi_epi16(_r6, _r7);
+    __m256i _tmp8 = _mm256_unpacklo_epi16(_r8, _r9);
+    __m256i _tmp9 = _mm256_unpackhi_epi16(_r8, _r9);
+    __m256i _tmpa = _mm256_unpacklo_epi16(_ra, _rb);
+    __m256i _tmpb = _mm256_unpackhi_epi16(_ra, _rb);
+    __m256i _tmpc = _mm256_unpacklo_epi16(_rc, _rd);
+    __m256i _tmpd = _mm256_unpackhi_epi16(_rc, _rd);
+    __m256i _tmpe = _mm256_unpacklo_epi16(_re, _rf);
+    __m256i _tmpf = _mm256_unpackhi_epi16(_re, _rf);
+
+    __m256i _tmpg = _mm256_unpacklo_epi32(_tmp0, _tmp2);
+    __m256i _tmph = _mm256_unpackhi_epi32(_tmp0, _tmp2);
+    __m256i _tmpi = _mm256_unpacklo_epi32(_tmp1, _tmp3);
+    __m256i _tmpj = _mm256_unpackhi_epi32(_tmp1, _tmp3);
+    __m256i _tmpk = _mm256_unpacklo_epi32(_tmp4, _tmp6);
+    __m256i _tmpl = _mm256_unpackhi_epi32(_tmp4, _tmp6);
+    __m256i _tmpm = _mm256_unpacklo_epi32(_tmp5, _tmp7);
+    __m256i _tmpn = _mm256_unpackhi_epi32(_tmp5, _tmp7);
+    __m256i _tmpo = _mm256_unpacklo_epi32(_tmp8, _tmpa);
+    __m256i _tmpp = _mm256_unpackhi_epi32(_tmp8, _tmpa);
+    __m256i _tmpq = _mm256_unpacklo_epi32(_tmp9, _tmpb);
+    __m256i _tmpr = _mm256_unpackhi_epi32(_tmp9, _tmpb);
+    __m256i _tmps = _mm256_unpacklo_epi32(_tmpc, _tmpe);
+    __m256i _tmpt = _mm256_unpackhi_epi32(_tmpc, _tmpe);
+    __m256i _tmpu = _mm256_unpacklo_epi32(_tmpd, _tmpf);
+    __m256i _tmpv = _mm256_unpackhi_epi32(_tmpd, _tmpf);
+
+    _tmp0 = _mm256_unpacklo_epi64(_tmpg, _tmpk);
+    _tmp1 = _mm256_unpackhi_epi64(_tmpg, _tmpk);
+    _tmp2 = _mm256_unpacklo_epi64(_tmph, _tmpl);
+    _tmp3 = _mm256_unpackhi_epi64(_tmph, _tmpl);
+    _tmp4 = _mm256_unpacklo_epi64(_tmpi, _tmpm);
+    _tmp5 = _mm256_unpackhi_epi64(_tmpi, _tmpm);
+    _tmp6 = _mm256_unpacklo_epi64(_tmpj, _tmpn);
+    _tmp7 = _mm256_unpackhi_epi64(_tmpj, _tmpn);
+    _tmp8 = _mm256_unpacklo_epi64(_tmpo, _tmps);
+    _tmp9 = _mm256_unpackhi_epi64(_tmpo, _tmps);
+    _tmpa = _mm256_unpacklo_epi64(_tmpp, _tmpt);
+    _tmpb = _mm256_unpackhi_epi64(_tmpp, _tmpt);
+    _tmpc = _mm256_unpacklo_epi64(_tmpq, _tmpu);
+    _tmpd = _mm256_unpackhi_epi64(_tmpq, _tmpu);
+    _tmpe = _mm256_unpacklo_epi64(_tmpr, _tmpv);
+    _tmpf = _mm256_unpackhi_epi64(_tmpr, _tmpv);
+
+    _r0 = _mm256_permute2x128_si256(_tmp0, _tmp8, _MM_SHUFFLE(0, 2, 0, 0));
+    _r1 = _mm256_permute2x128_si256(_tmp1, _tmp9, _MM_SHUFFLE(0, 2, 0, 0));
+    _r2 = _mm256_permute2x128_si256(_tmp2, _tmpa, _MM_SHUFFLE(0, 2, 0, 0));
+    _r3 = _mm256_permute2x128_si256(_tmp3, _tmpb, _MM_SHUFFLE(0, 2, 0, 0));
+    _r4 = _mm256_permute2x128_si256(_tmp4, _tmpc, _MM_SHUFFLE(0, 2, 0, 0));
+    _r5 = _mm256_permute2x128_si256(_tmp5, _tmpd, _MM_SHUFFLE(0, 2, 0, 0));
+    _r6 = _mm256_permute2x128_si256(_tmp6, _tmpe, _MM_SHUFFLE(0, 2, 0, 0));
+    _r7 = _mm256_permute2x128_si256(_tmp7, _tmpf, _MM_SHUFFLE(0, 2, 0, 0));
+    _r8 = _mm256_permute2x128_si256(_tmp0, _tmp8, _MM_SHUFFLE(0, 3, 0, 1));
+    _r9 = _mm256_permute2x128_si256(_tmp1, _tmp9, _MM_SHUFFLE(0, 3, 0, 1));
+    _ra = _mm256_permute2x128_si256(_tmp2, _tmpa, _MM_SHUFFLE(0, 3, 0, 1));
+    _rb = _mm256_permute2x128_si256(_tmp3, _tmpb, _MM_SHUFFLE(0, 3, 0, 1));
+    _rc = _mm256_permute2x128_si256(_tmp4, _tmpc, _MM_SHUFFLE(0, 3, 0, 1));
+    _rd = _mm256_permute2x128_si256(_tmp5, _tmpd, _MM_SHUFFLE(0, 3, 0, 1));
+    _re = _mm256_permute2x128_si256(_tmp6, _tmpe, _MM_SHUFFLE(0, 3, 0, 1));
+    _rf = _mm256_permute2x128_si256(_tmp7, _tmpf, _MM_SHUFFLE(0, 3, 0, 1));
+}
+
+static inline void transpose8x16_epi16(__m128i &_r0, __m128i &_r1, __m128i &_r2,
+                                       __m128i &_r3, __m128i &_r4, __m128i &_r5,
+                                       __m128i &_r6, __m128i &_r7, __m128i &_r8,
+                                       __m128i &_r9, __m128i &_ra, __m128i &_rb,
+                                       __m128i &_rc, __m128i &_rd, __m128i &_re,
+                                       __m128i &_rf) {
+    __m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1);
+    __m128i _tmp1 = _mm_unpackhi_epi16(_r0, _r1);
+    __m128i _tmp2 = _mm_unpacklo_epi16(_r2, _r3);
+    __m128i _tmp3 = _mm_unpackhi_epi16(_r2, _r3);
+    __m128i _tmp4 = _mm_unpacklo_epi16(_r4, _r5);
+    __m128i _tmp5 = _mm_unpackhi_epi16(_r4, _r5);
+    __m128i _tmp6 = _mm_unpacklo_epi16(_r6, _r7);
+    __m128i _tmp7 = _mm_unpackhi_epi16(_r6, _r7);
+    __m128i _tmp8 = _mm_unpacklo_epi16(_r8, _r9);
+    __m128i _tmp9 = _mm_unpackhi_epi16(_r8, _r9);
+    __m128i _tmpa = _mm_unpacklo_epi16(_ra, _rb);
+    __m128i _tmpb = _mm_unpackhi_epi16(_ra, _rb);
+    __m128i _tmpc = _mm_unpacklo_epi16(_rc, _rd);
+    __m128i _tmpd = _mm_unpackhi_epi16(_rc, _rd);
+    __m128i _tmpe = _mm_unpacklo_epi16(_re, _rf);
+    __m128i _tmpf = _mm_unpackhi_epi16(_re, _rf);
+
+    __m128i _tmpg = _mm_unpacklo_epi32(_tmp0, _tmp2);
+    __m128i _tmph = _mm_unpackhi_epi32(_tmp0, _tmp2);
+    __m128i _tmpi = _mm_unpacklo_epi32(_tmp1, _tmp3);
+    __m128i _tmpj = _mm_unpackhi_epi32(_tmp1, _tmp3);
+    __m128i _tmpk = _mm_unpacklo_epi32(_tmp4, _tmp6);
+    __m128i _tmpl = _mm_unpackhi_epi32(_tmp4, _tmp6);
+    __m128i _tmpm = _mm_unpacklo_epi32(_tmp5, _tmp7);
+    __m128i _tmpn = _mm_unpackhi_epi32(_tmp5, _tmp7);
+    __m128i _tmpo = _mm_unpacklo_epi32(_tmp8, _tmpa);
+    __m128i _tmpp = _mm_unpackhi_epi32(_tmp8, _tmpa);
+    __m128i _tmpq = _mm_unpacklo_epi32(_tmp9, _tmpb);
+    __m128i _tmpr = _mm_unpackhi_epi32(_tmp9, _tmpb);
+    __m128i _tmps = _mm_unpacklo_epi32(_tmpc, _tmpe);
+    __m128i _tmpt = _mm_unpackhi_epi32(_tmpc, _tmpe);
+    __m128i _tmpu = _mm_unpacklo_epi32(_tmpd, _tmpf);
+    __m128i _tmpv = _mm_unpackhi_epi32(_tmpd, _tmpf);
+
+    _r0 = _mm_unpacklo_epi64(_tmpg, _tmpk);
+    _r1 = _mm_unpacklo_epi64(_tmpo, _tmps);
+    _r2 = _mm_unpackhi_epi64(_tmpg, _tmpk);
+    _r3 = _mm_unpackhi_epi64(_tmpo, _tmps);
+    _r4 = _mm_unpacklo_epi64(_tmph, _tmpl);
+    _r5 = _mm_unpacklo_epi64(_tmpp, _tmpt);
+    _r6 = _mm_unpackhi_epi64(_tmph, _tmpl);
+    _r7 = _mm_unpackhi_epi64(_tmpp, _tmpt);
+    _r8 = _mm_unpacklo_epi64(_tmpi, _tmpm);
+    _r9 = _mm_unpacklo_epi64(_tmpq, _tmpu);
+    _ra = _mm_unpackhi_epi64(_tmpi, _tmpm);
+    _rb = _mm_unpackhi_epi64(_tmpq, _tmpu);
+    _rc = _mm_unpacklo_epi64(_tmpj, _tmpn);
+    _rd = _mm_unpacklo_epi64(_tmpr, _tmpv);
+    _re = _mm_unpackhi_epi64(_tmpj, _tmpn);
+    _rf = _mm_unpackhi_epi64(_tmpr, _tmpv);
+}
+
+static inline float _mm512_comp_reduce_add_ps(__m512 x) {
+    const __m256 x256 =
+        _mm256_add_ps(_mm512_castps512_ps256(x), _mm512_extractf32x8_ps(x, 1));
+    const __m128 x128 = _mm_add_ps(_mm256_castps256_ps128(x256),
+                                   _mm256_extractf128_ps(x256, 1));
+    const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
+    const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
+    return _mm_cvtss_f32(x32);
+}
+
+static inline float _mm512_comp_reduce_max_ps(__m512 x) {
+    const __m256 x256 =
+        _mm256_max_ps(_mm512_castps512_ps256(x), _mm512_extractf32x8_ps(x, 1));
+    const __m128 x128 = _mm_max_ps(_mm256_castps256_ps128(x256),
+                                   _mm256_extractf128_ps(x256, 1));
+    const __m128 x64 = _mm_max_ps(x128, _mm_movehl_ps(x128, x128));
+    const __m128 x32 = _mm_max_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
+    return _mm_cvtss_f32(x32);
+}
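+
+// Both reductions above halve the active width per step
+// (512 -> 256 -> 128 -> 64 -> 32 bits), so a 16-lane horizontal add or max
+// takes log2(16) = 4 combining operations instead of 15 scalar ones.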
+
+static inline __m512 bfloat2float_avx512(const __m256i &v0) {
+#if __AVX512BF16__
+    __m512 _v = _mm512_cvtpbh_ps((__m256bh)v0);
+#else
+    __m256i _zero = _mm256_setzero_si256();
+    __m256i _a = _mm256_unpacklo_epi16(_zero, v0);
+    __m256i _b = _mm256_unpackhi_epi16(_zero, v0);
+    __m256i _c = _mm256_permute2x128_si256(_a, _b, _MM_SHUFFLE(0, 2, 0, 0));
+    __m256i _d = _mm256_permute2x128_si256(_a, _b, _MM_SHUFFLE(0, 3, 0, 1));
+    __m512 _v = _mm512_castsi512_ps(
+        _mm512_inserti32x8(_mm512_castsi256_si512(_c), _d, 1));
+#endif
+    return _v;
+}
+
+static inline __m256i float2bfloat_avx512(const __m512 &v0) {
+#if __AVX512BF16__
+    __m256i _v = (__m256i)_mm512_cvtneps_pbh(v0);
+#else
+    __m512i _ab = _mm512_castps_si512(v0);
+    _ab = _mm512_srli_epi32(_ab, 16);
+    __m256i _a = _mm512_extracti32x8_epi32(_ab, 0);
+    __m256i _b = _mm512_extracti32x8_epi32(_ab, 1);
+    __m256i _v = _mm256_packus_epi32(_a, _b);
+    _v = _mm256_permute4x64_epi64(_v, _MM_SHUFFLE(3, 1, 2, 0));
+#endif
+    return _v;
+}
+
+static inline __m512i float2bfloat_avx512(const __m512 &v0, const __m512 &v1) {
+#if __AVX512BF16__
+    __m256bh _v0 = _mm512_cvtneps_pbh(v0);
+    __m256bh _v1 = _mm512_cvtneps_pbh(v1);
+    __m512i _v = _mm512_inserti32x8(_mm512_castsi256_si512((__m256i)_v0),
+                                    (__m256i)_v1, 1);
+#else
+    __m512i _a = _mm512_castps_si512(v0);
+    __m512i _b = _mm512_castps_si512(v1);
+    _a = _mm512_srli_epi32(_a, 16);
+    _b = _mm512_srli_epi32(_b, 16);
+    __m512i _v = _mm512_packus_epi32(_a, _b);
+    _v = _mm512_permutex_epi64(_v, _MM_SHUFFLE(3, 1, 2, 0));
+    _v = _mm512_shuffle_i32x4(_v, _v, _MM_SHUFFLE(3, 1, 2, 0));
+#endif
+    return _v;
+}
+
+#endif // __AVX512F__
+#endif // __AVX2__
+#endif // __AVX__
+#endif // __SSE2__
+
+#endif // X86_USABILITY_H
\ No newline at end of file
diff --git a/src/Native/include/nncase/ntt/compiler_defs.h b/src/Native/include/nncase/ntt/compiler_defs.h
new file mode 100644
index 0000000000..37be148b52
--- /dev/null
+++ b/src/Native/include/nncase/ntt/compiler_defs.h
@@ -0,0 +1,22 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#if defined(_MSC_VER)
+// Fix: https://learn.microsoft.com/en-us/cpp/cpp/empty-bases
+#define NTT_EMPTY_BASES __declspec(empty_bases)
+#else
+#define NTT_EMPTY_BASES
+#endif
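+
+// Illustrative sketch of the layout problem the attribute solves (the structs
+// below are hypothetical, not part of this patch): without empty-base
+// optimization MSVC gives each empty base a distinct nonzero offset, so a
+// class deriving from two empty policy bases grows past its one data member.
+//
+//   struct base_a {};
+//   struct base_b {};
+//   struct NTT_EMPTY_BASES packed : base_a, base_b { int v; };
+//   static_assert(sizeof(packed) == sizeof(int), "empty bases add no size");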
diff --git a/src/Native/include/nncase/ntt/cpu_runtime.h b/src/Native/include/nncase/ntt/cpu_runtime.h
new file mode 100644
index 0000000000..140faaf28d
--- /dev/null
+++ b/src/Native/include/nncase/ntt/cpu_runtime.h
@@ -0,0 +1,49 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <stdarg.h>
+#include <stddef.h>
+#include <stdint.h>
+
+extern "C" {
+struct nncase_runtime_cpu_mt_t {
+    float (*acosf)(float v);
+    float (*acoshf)(float v);
+    float (*asinf)(float v);
+    float (*asinhf)(float v);
+    float (*copysignf)(float mag, float sgn);
+    float (*cosf)(float v);
+    float (*coshf)(float v);
+    float (*expf)(float v);
+    float (*fmodf)(float x, float y);
+    float (*logf)(float v);
+    float (*nearbyintf)(float v);
+    float (*powf)(float x, float y);
+    float (*sinf)(float v);
+    float (*sinhf)(float v);
+    float (*tanhf)(float v);
+    uint8_t *(*sram_address)(int bid, int tid);
+    void (*failfast)(const char *format, va_list args);
+
+#ifndef WIN32
+    void *(*memcpy)(void *dst, const void *src, size_t len);
+#endif
+};
+
+#ifdef NNCASE_CPU_MODULE
+extern nncase_runtime_cpu_mt_t *g_cpu_mt;
+extern size_t bid;
+extern size_t tid;
+#endif
+}
diff --git a/src/Native/include/nncase/ntt/detail/shape_storage.h b/src/Native/include/nncase/ntt/detail/shape_storage.h
new file mode 100644
index 0000000000..fd42ca4d09
--- /dev/null
+++ b/src/Native/include/nncase/ntt/detail/shape_storage.h
@@ -0,0 +1,76 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include "../compiler_defs.h"
+#include "../shape.h"
+
+namespace nncase::ntt::detail {
+template <class Shape> class shape_storage {
+  public:
+    shape_storage(Shape shape) : shape_(shape) {}
+
+    constexpr Shape &shape() noexcept { return shape_; }
+    constexpr const Shape &shape() const noexcept { return shape_; }
+
+  private:
+    Shape shape_;
+};
+
+template <size_t... Dims> class shape_storage<fixed_shape<Dims...>> {
+  public:
+    static constexpr auto shape() noexcept { return fixed_shape<Dims...>{}; }
+};
+
+template <class Strides> class strides_storage {
+  public:
+    strides_storage(Strides strides) : strides_(strides) {}
+
+    constexpr Strides &strides() noexcept { return strides_; }
+    constexpr const Strides &strides() const noexcept { return strides_; }
+
+  private:
+    Strides strides_;
+};
+
+template <size_t... Strides> class strides_storage<fixed_strides<Strides...>> {
+  public:
+    static constexpr auto strides() noexcept {
+        return fixed_strides<Strides...>{};
+    }
+};
+
+template <class Shape, class Strides>
+struct NTT_EMPTY_BASES tensor_size_impl : public shape_storage<Shape>,
+                                          public strides_storage<Strides> {
+    tensor_size_impl(Shape shape, Strides strides)
+        : shape_storage<Shape>(shape), strides_storage<Strides>(strides) {}
+
+    constexpr size_t size() noexcept {
+        return linear_size(this->shape(), this->strides());
+    }
+};
+
+template <size_t... Dims, size_t... Strides>
+class NTT_EMPTY_BASES
+    tensor_size_impl<fixed_shape<Dims...>, fixed_strides<Strides...>>
+    : public shape_storage<fixed_shape<Dims...>>,
+      public strides_storage<fixed_strides<Strides...>> {
+  public:
+    static constexpr size_t size() noexcept {
+        return linear_size(fixed_shape<Dims...>{},
+                           fixed_strides<Strides...>{});
+    }
+};
+} // namespace nncase::ntt::detail
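+
+// Illustrative sketch (assumed shapes, not part of this patch): with fixed
+// shapes and strides both storage bases above are empty, so NTT_EMPTY_BASES
+// keeps tensor_size_impl itself empty and size() folds to a constant, e.g.
+// for a contiguous 2x3 tensor:
+//
+//   using sz = tensor_size_impl<fixed_shape<2, 3>, fixed_strides<3, 1>>;
+//   static_assert(sz::size() == 6);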
diff --git a/src/Native/include/nncase/ntt/detail/tensor_storage.h b/src/Native/include/nncase/ntt/detail/tensor_storage.h
new file mode 100644
index 0000000000..04b887cc33
--- /dev/null
+++ b/src/Native/include/nncase/ntt/detail/tensor_storage.h
@@ -0,0 +1,104 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include "../shape.h"
+#include <array>
+#include <span>
+#include <utility>
+#include <vector>
+
+namespace nncase::ntt::detail {
+template <class T, size_t MaxSize, bool IsView> class tensor_storage;
+
+// fixed tensor
+template <class T, size_t MaxSize> class tensor_storage<T, MaxSize, false> {
+  public:
+    using buffer_type = std::array<T, MaxSize>;
+
+    tensor_storage() = default;
+
+    // ignore size
+    explicit tensor_storage(size_t) noexcept {}
+    tensor_storage(std::in_place_t, buffer_type value) noexcept
+        : buffer_(value) {}
+
+    constexpr const buffer_type &buffer() const noexcept { return buffer_; }
+    constexpr buffer_type &buffer() noexcept { return buffer_; }
+
+    constexpr std::span<const T, MaxSize> elements() const noexcept {
+        return buffer_;
+    }
+    constexpr std::span<T, MaxSize> elements() noexcept { return buffer_; }
+
+  private:
+    buffer_type buffer_;
+};
+
+// fixed view
+template <class T, size_t MaxSize> class tensor_storage<T, MaxSize, true> {
+  public:
+    using buffer_type = std::span<T, MaxSize>;
+
+    tensor_storage(std::in_place_t, buffer_type value) : buffer_(value) {}
+
+    constexpr const buffer_type &buffer() const noexcept { return buffer_; }
+    constexpr buffer_type &buffer() noexcept { return buffer_; }
+
+    constexpr std::span<const T, MaxSize> elements() const noexcept {
+        return buffer_;
+    }
+    constexpr std::span<T, MaxSize> elements() noexcept { return buffer_; }
+
+  private:
+    buffer_type buffer_;
+};
+
+// dynamic tensor
+template <class T> class tensor_storage<T, std::dynamic_extent, false> {
+  public:
+    using buffer_type = std::vector<T>;
+
+    explicit tensor_storage(size_t size) : buffer_(size) {}
+    tensor_storage(std::in_place_t, buffer_type value) : buffer_(value) {}
+
+    constexpr const buffer_type &buffer() const noexcept { return buffer_; }
+    constexpr buffer_type &buffer() noexcept { return buffer_; }
+
+    constexpr std::span<const T> elements() const noexcept {
+        return {buffer_.data(), buffer_.size()};
+    }
+    constexpr std::span<T> elements() noexcept {
+        return {buffer_.data(), buffer_.size()};
+    }
+
+  private:
+    buffer_type buffer_;
+};
+
+// dynamic view
+template <class T> class tensor_storage<T, std::dynamic_extent, true> {
+  public:
+    using const_buffer_type = std::span<const T>;
+    using buffer_type = std::span<T>;
+
+    tensor_storage(std::in_place_t, buffer_type value) : buffer_(value) {}
+
+    constexpr const_buffer_type buffer() const noexcept { return buffer_; }
+    constexpr buffer_type buffer() noexcept { return buffer_; }
+
+    constexpr const_buffer_type elements() const noexcept { return buffer_; }
+    constexpr buffer_type elements() noexcept { return buffer_; }
+
+  private:
+    buffer_type buffer_;
+};
+} // namespace nncase::ntt::detail
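+
+// The four specializations above dispatch on (extent, view-ness):
+// owning fixed -> std::array, fixed view -> std::span<T, N>,
+// owning dynamic -> std::vector, dynamic view -> std::span<T>.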
diff --git a/src/Native/include/nncase/ntt/kernels/arch/aarch64/binary.h b/src/Native/include/nncase/ntt/kernels/arch/aarch64/binary.h
new file mode 100644
index 0000000000..6113623669
--- /dev/null
+++ b/src/Native/include/nncase/ntt/kernels/arch/aarch64/binary.h
@@ -0,0 +1,136 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <arm_neon.h>
+#include <nncase/ntt/vector_type.h>
+
+namespace nncase::ntt::mathops {
+
+template <> struct add<ntt::vector<float, 8>> {
+    inline ntt::vector<float, 8>
+    operator()(ntt::vector<float, 8> v1,
+               ntt::vector<float, 8> v2) const noexcept {
+        float32x4x2_t r;
+        r.val[0] = ((float32x4x2_t)v1).val[0] + ((float32x4x2_t)v2).val[0];
+        r.val[1] = ((float32x4x2_t)v1).val[1] + ((float32x4x2_t)v2).val[1];
+        return r;
+    }
+};
+
+template <> struct sub<ntt::vector<float, 8>> {
+    inline ntt::vector<float, 8>
+    operator()(ntt::vector<float, 8> v1,
+               ntt::vector<float, 8> v2) const noexcept {
+        float32x4x2_t r;
+        r.val[0] = ((float32x4x2_t)v1).val[0] - ((float32x4x2_t)v2).val[0];
+        r.val[1] = ((float32x4x2_t)v1).val[1] - ((float32x4x2_t)v2).val[1];
+        return r;
+    }
+};
+
+template <> struct mul<ntt::vector<float, 8>> {
+    inline ntt::vector<float, 8>
+    operator()(ntt::vector<float, 8> v1,
+               ntt::vector<float, 8> v2) const noexcept {
+        float32x4x2_t r;
+        r.val[0] = ((float32x4x2_t)v1).val[0] * ((float32x4x2_t)v2).val[0];
+        r.val[1] = ((float32x4x2_t)v1).val[1] * ((float32x4x2_t)v2).val[1];
+        return r;
+    }
+};
+
+template <> struct div<ntt::vector<float, 8>> {
+    inline ntt::vector<float, 8>
+    operator()(ntt::vector<float, 8> v1,
+               ntt::vector<float, 8> v2) const noexcept {
+        float32x4x2_t r;
+        r.val[0] = ((float32x4x2_t)v1).val[0] / ((float32x4x2_t)v2).val[0];
+        r.val[1] = ((float32x4x2_t)v1).val[1] / ((float32x4x2_t)v2).val[1];
+        return r;
+    }
+};
+template <> struct max<ntt::vector<float, 8>> {
+    inline ntt::vector<float, 8>
+    operator()(ntt::vector<float, 8> v1,
+               ntt::vector<float, 8> v2) const noexcept {
+        float32x4x2_t r;
+        r.val[0] =
+            vmaxq_f32(((float32x4x2_t)v1).val[0], ((float32x4x2_t)v2).val[0]);
+        r.val[1] =
+            vmaxq_f32(((float32x4x2_t)v1).val[1], ((float32x4x2_t)v2).val[1]);
+        return r;
+    }
+};
+
+template <> struct add<ntt::vector<float, 4>> {
+    inline ntt::vector<float, 4>
+    operator()(ntt::vector<float, 4> v1,
+               ntt::vector<float, 4> v2) const noexcept {
+        return impl(v1, v2);
+    }
+
+    inline float32x4_t impl(float32x4_t v1, float32x4_t v2) const noexcept {
+        return v1 + v2;
+    }
+};
+
+template <> struct sub<ntt::vector<float, 4>> {
+    inline ntt::vector<float, 4>
+    operator()(ntt::vector<float, 4> v1,
+               ntt::vector<float, 4> v2) const noexcept {
+        return impl(v1, v2);
+    }
+
+    inline float32x4_t impl(float32x4_t v1, float32x4_t v2) const noexcept {
+        return v1 - v2;
+    }
+};
+
+template <> struct mul<ntt::vector<float, 4>> {
+    inline ntt::vector<float, 4>
+    operator()(ntt::vector<float, 4> v1,
+               ntt::vector<float, 4> v2) const noexcept {
+        return impl(v1, v2);
+    }
+
+    inline float32x4_t impl(float32x4_t v1, float32x4_t v2) const noexcept {
+        return v1 * v2;
+    }
+};
+
+template <> struct div<ntt::vector<float, 4>> {
+    inline ntt::vector<float, 4>
+    operator()(ntt::vector<float, 4> v1,
+               ntt::vector<float, 4> v2) const noexcept {
+        return impl(v1, v2);
+    }
+
+    inline float32x4_t impl(float32x4_t v1, float32x4_t v2) const noexcept {
+        return v1 / v2;
+    }
+};
+template <> struct max<ntt::vector<float, 4>> {
+    inline ntt::vector<float, 4>
+    operator()(ntt::vector<float, 4> v1,
+               ntt::vector<float, 4> v2) const noexcept {
+        return impl(v1, v2);
+    }
+
+    inline float32x4_t impl(float32x4_t v1, float32x4_t v2) const noexcept {
+        return vmaxq_f32(v1, v2);
+    }
+};
+
+} // namespace nncase::ntt::mathops
\ No newline at end of file
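+
+// Note: the bare +, -, *, / on float32x4_t above rely on the GCC/Clang vector
+// operator extensions (equivalent to vaddq_f32 and friends); the 8-lane
+// specializations simply apply the same operation to both halves of a
+// float32x4x2_t.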
diff --git a/src/Native/include/nncase/ntt/kernels/arch/aarch64/pack_element.h b/src/Native/include/nncase/ntt/kernels/arch/aarch64/pack_element.h
new file mode 100644
index 0000000000..7cda41c746
--- /dev/null
+++ b/src/Native/include/nncase/ntt/kernels/arch/aarch64/pack_element.h
@@ -0,0 +1,26 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <arm_neon.h>
+#include <array>
+
+inline float32x4_t pack_elemt(const std::array<float, 4> &vec) {
+    return vld1q_f32(&vec[0]);
+}
+
+inline float32x2_t pack_elemt(const std::array<float, 2> &vec) {
+    return vld1_f32(&vec[0]);
+}
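+
+// vld1q_f32/vld1_f32 carry no alignment requirement, so pack_elemt can load
+// straight from a std::array at any address.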
diff --git a/src/Native/include/nncase/ntt/kernels/arch/aarch64/unary.h b/src/Native/include/nncase/ntt/kernels/arch/aarch64/unary.h
new file mode 100644
index 0000000000..8e6c539b33
--- /dev/null
+++ b/src/Native/include/nncase/ntt/kernels/arch/aarch64/unary.h
@@ -0,0 +1,47 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include "arm_math.h"
+#include <arm_neon.h>
+
+namespace std {
+inline float32x4_t cos(float32x4_t v) { return cos_ps(v); }
+
+inline float32x4_t exp(float32x4_t v) { return exp_ps(v); }
+
+inline float32x4_t sqrt(float32x4_t v) { return vsqrtq_f32(v); }
+
+inline float32x4x2_t exp(float32x4x2_t v) {
+    return float32x4x2_t{exp_ps(v.val[0]), exp_ps(v.val[1])};
+}
+} // namespace std
+
+namespace nncase::ntt {
+namespace arch {
+template <size_t Extent, class Op, class T>
+constexpr void unary(Op &&op, const T *input_p, T *output_p) {
+    for (size_t i = 0; i < Extent; i++) {
+        output_p[i] = op(input_p[i]);
+    }
+}
+
+template <class Op, class T>
+constexpr void unary(Op &&op, const T *input_p, T *output_p, size_t extent) {
+    for (size_t i = 0; i < extent; i++) {
+        output_p[i] = op(input_p[i]);
+    }
+}
+} // namespace arch
+} // namespace nncase::ntt
diff --git a/src/Native/include/nncase/ntt/kernels/arch/aarch64/unary_mathops.h b/src/Native/include/nncase/ntt/kernels/arch/aarch64/unary_mathops.h
new file mode 100644
index 0000000000..551c23f22b
--- /dev/null
+++ b/src/Native/include/nncase/ntt/kernels/arch/aarch64/unary_mathops.h
@@ -0,0 +1,23 @@
+
+
+namespace nncase::ntt::mathops {
+template <> struct sqrt<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        float32x4x2_t vv = v;
+        return float32x4x2_t{vsqrtq_f32(vv.val[0]), vsqrtq_f32(vv.val[1])};
+    }
+};
+
+template <> struct swish<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        float32x4x2_t vv = v;
+        return float32x4x2_t{impl(vv.val[0]), impl(vv.val[1])};
+    }
+
+    float32x4_t impl(float32x4_t v) const noexcept {
+        auto zero = vdupq_n_f32(0);
+        auto one = vdupq_n_f32(1);
+        // swish(v) = v / (1 + e^-v); the denominator must be parenthesized as
+        // a whole, otherwise this computes (v / e^-v) + 1.
+        return v / (exp_ps(zero - v) + one);
+    }
+};
+} // namespace nncase::ntt::mathops
\ No newline at end of file
diff --git a/src/Native/include/nncase/ntt/kernels/arch/aarch64/unpack_element.h b/src/Native/include/nncase/ntt/kernels/arch/aarch64/unpack_element.h
new file mode 100644
index 0000000000..3997143bf3
--- /dev/null
+++ b/src/Native/include/nncase/ntt/kernels/arch/aarch64/unpack_element.h
@@ -0,0 +1,22 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <arm_neon.h>
+#include <array>
+
+inline void unpack_elemt(std::array<float, 4> &arr, const float32x4_t &vec) {
+    vst1q_f32(&arr[0], vec);
+}
diff --git a/src/Native/include/nncase/ntt/kernels/arch/aarch64/vector_ops.h b/src/Native/include/nncase/ntt/kernels/arch/aarch64/vector_ops.h
new file mode 100644
index 0000000000..52e29fac78
--- /dev/null
+++ b/src/Native/include/nncase/ntt/kernels/arch/aarch64/vector_ops.h
@@ -0,0 +1,57 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <algorithm>
+#include <arm_neon.h>
+
+namespace nncase::ntt::vector_ops {
+template <> struct reduce_sum<ntt::vector<float, 4>> {
+    float operator()(ntt::vector<float, 4> v) const noexcept {
+        float32x2_t vec1 = vadd_f32(vget_low_f32(v), vget_high_f32(v));
+        return vaddv_f32(vec1);
+    }
+};
+
+template <> struct reduce_sum<ntt::vector<float, 8>> {
+    float operator()(ntt::vector<float, 8> v) const noexcept {
+        float32x4x2_t val = v;
+        float result = 0;
+        auto vec = val.val[0];
+        float32x2_t vec1 = vadd_f32(vget_low_f32(vec), vget_high_f32(vec));
+        float32x2_t vec2 = vadd_f32(vec1, vrev64_f32(vec1));
+        result += vget_lane_f32(vec2, 0);
+
+        vec = val.val[1];
+        vec1 = vadd_f32(vget_low_f32(vec), vget_high_f32(vec));
+        vec2 = vadd_f32(vec1, vrev64_f32(vec1));
+        result += vget_lane_f32(vec2, 0);
+
+        return result;
+    }
+};
+
+template <> struct reduce_max<ntt::vector<float, 4>> {
+    float operator()(ntt::vector<float, 4> v) const noexcept {
+        return vmaxvq_f32(v);
+    }
+};
+
+template <> struct reduce_max<ntt::vector<float, 8>> {
+    float operator()(ntt::vector<float, 8> v) const noexcept {
+        float32x4x2_t val = v;
+        return std::max(vmaxvq_f32(val.val[0]), vmaxvq_f32(val.val[1]));
+    }
+};
+
+} // namespace nncase::ntt::vector_ops
diff --git a/src/Native/include/nncase/ntt/kernels/arch/aarch64/vector_types.h b/src/Native/include/nncase/ntt/kernels/arch/aarch64/vector_types.h
new file mode 100644
index 0000000000..e27d102b24
--- /dev/null
+++ b/src/Native/include/nncase/ntt/kernels/arch/aarch64/vector_types.h
@@ -0,0 +1,20 @@
+#pragma once
+#include <arm_neon.h>
+
+namespace nncase::ntt {
+template <> struct native_vector_type<float, 32> {
+    using type = float32x4_t[8];
+};
+
+template <> struct native_vector_type<float, 8> {
+    using type = float32x4x2_t;
+    static type from_element(const float &f) {
+        return type{vdupq_n_f32(f), vdupq_n_f32(f)};
+    }
+};
+
+template <> struct native_vector_type<float, 4> {
+    using type = float32x4_t;
+    static type from_element(const float &f) { return vdupq_n_f32(f); }
+};
+} // namespace nncase::ntt
\ No newline at end of file
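+
+// Illustrative use of the trait above (hypothetical call site, not part of
+// this patch): kernels select the backing register type and splat constants
+// through it, e.g.
+//
+//   using v8 = nncase::ntt::native_vector_type<float, 8>::type; // float32x4x2_t
+//   auto ones = nncase::ntt::native_vector_type<float, 8>::from_element(1.0f);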
diff --git a/src/Native/include/nncase/ntt/kernels/arch/x86_64/binary.h b/src/Native/include/nncase/ntt/kernels/arch/x86_64/binary.h
new file mode 100644
index 0000000000..bc985dc735
--- /dev/null
+++ b/src/Native/include/nncase/ntt/kernels/arch/x86_64/binary.h
@@ -0,0 +1,51 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <immintrin.h>
+#include <nncase/ntt/vector_type.h>
+
+namespace nncase::ntt::mathops {
+
+template <> struct add<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v1,
+                                     ntt::vector<float, 8> v2) const noexcept {
+        return _mm256_add_ps(v1, v2);
+    }
+};
+template <> struct sub<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v1,
+                                     ntt::vector<float, 8> v2) const noexcept {
+        return _mm256_sub_ps(v1, v2);
+    }
+};
+template <> struct mul<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v1,
+                                     ntt::vector<float, 8> v2) const noexcept {
+        return _mm256_mul_ps(v1, v2);
+    }
+};
+template <> struct div<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v1,
+                                     ntt::vector<float, 8> v2) const noexcept {
+        return _mm256_div_ps(v1, v2);
+    }
+};
+template <> struct max<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v1,
+                                     ntt::vector<float, 8> v2) const noexcept {
+        return _mm256_max_ps(v1, v2);
+    }
+};
+} // namespace nncase::ntt::mathops
\ No newline at end of file
diff --git a/src/Native/include/nncase/ntt/kernels/arch/x86_64/pack_element.h b/src/Native/include/nncase/ntt/kernels/arch/x86_64/pack_element.h
new file mode 100644
index 0000000000..30fb55ebbf
--- /dev/null
+++ b/src/Native/include/nncase/ntt/kernels/arch/x86_64/pack_element.h
@@ -0,0 +1,26 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <array>
+#include <immintrin.h>
+
+inline __m128 pack_elemt(const std::array<float, 4> &vec) {
+    return _mm_load_ps(&vec[0]);
+}
+
+inline __m256 pack_elemt(const std::array<float, 8> &vec) {
+    return _mm256_load_ps(&vec[0]);
+}
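+
+// Unlike the NEON counterpart, _mm_load_ps/_mm256_load_ps require 16-/32-byte
+// aligned pointers, while std::array<float, N> only guarantees alignof(float);
+// callers must ensure the alignment (or these would need the _loadu_ forms).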
diff --git a/src/Native/include/nncase/ntt/kernels/arch/x86_64/unary.h b/src/Native/include/nncase/ntt/kernels/arch/x86_64/unary.h
new file mode 100644
index 0000000000..0fcbc41f32
--- /dev/null
+++ b/src/Native/include/nncase/ntt/kernels/arch/x86_64/unary.h
@@ -0,0 +1,65 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include "../../../vector_type.h"
+#include "avx_mathfun.h"
+#include <immintrin.h>
+
+namespace std {
+inline __m256 cos(__m256 v) {
+    __m256 s, c;
+    sincos256_ps(v, &s, &c);
+    // sincos256_ps writes sine into s and cosine into c; cos() must return c.
+    return c;
+}
+
+inline __m128 cos(__m128 v) {
+    float arr[4];
+    // use unaligned accesses: a plain float[4] on the stack is not guaranteed
+    // to be 16-byte aligned
+    _mm_storeu_ps(arr, v);
+    for (size_t i = 0; i < 4; i++) {
+        arr[i] = cosf(arr[i]);
+    }
+    return _mm_loadu_ps(arr);
+}
+
+inline __m128 sqrt(__m128 v) { return _mm_sqrt_ps(v); }
+inline __m256 sqrt(__m256 v) { return _mm256_sqrt_ps(v); }
+inline __m256 exp(__m256 v) { return exp256_ps(v); }
+} // namespace std
+
+namespace nncase::ntt::arch {
+template <size_t Extent, class Op, class T>
+constexpr void unary(Op &&op, const T *input_p, T *output_p) {
+    for (size_t i = 0; i < Extent; i++) {
+        output_p[i] = op(input_p[i]);
+    }
+}
+
+template <class Op, class T>
+constexpr void unary(Op &&op, const T *input_p, T *output_p, size_t extent) {
+    for (size_t i = 0; i < extent; i++) {
+        output_p[i] = op(input_p[i]);
+    }
+}
+} // namespace nncase::ntt::arch
+
+// namespace nncase::ntt::mathops {
+// template <> struct sqrt<ntt::vector<float, 8>> {
+//     ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept
+//     {
+//         return std::sqrt(v);
+//     }
+// };
+
+// } // namespace nncase::ntt::mathops
\ No newline at end of file
diff --git a/src/Native/include/nncase/ntt/kernels/arch/x86_64/unary_mathops.h b/src/Native/include/nncase/ntt/kernels/arch/x86_64/unary_mathops.h
new file mode 100644
index 0000000000..158657624c
--- /dev/null
+++ b/src/Native/include/nncase/ntt/kernels/arch/x86_64/unary_mathops.h
@@ -0,0 +1,21 @@
+
+
+namespace nncase::ntt::mathops {
+template <> struct swish<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return impl(v);
+    }
+
+    __m256 impl(__m256 v) const noexcept {
+        auto zero = _mm256_set1_ps(0);
+        auto one = _mm256_set1_ps(1);
+        // swish(v) = v / (1 + e^-v); the denominator must be parenthesized as
+        // a whole, otherwise this computes (v / e^-v) + 1.
+        return v / (exp256_ps(zero - v) + one);
+    }
+};
+
+template <> struct neg<ntt::vector<float, 8>> {
+    ntt::vector<float, 8> operator()(ntt::vector<float, 8> v) const noexcept {
+        return _mm256_set1_ps(0) - (__m256)v;
+    }
+};
+} // namespace nncase::ntt::mathops
\ No newline at end of file
diff --git a/src/Native/include/nncase/ntt/kernels/arch/x86_64/unpack_element.h b/src/Native/include/nncase/ntt/kernels/arch/x86_64/unpack_element.h
new file mode 100644
index 0000000000..35a59e0adb
--- /dev/null
+++ b/src/Native/include/nncase/ntt/kernels/arch/x86_64/unpack_element.h
@@ -0,0 +1,26 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <array>
+#include <immintrin.h>
+
+inline void unpack_elemt(std::array<float, 4> &arr, const __m128 &vec) {
+    _mm_store_ps(&arr[0], vec);
+}
+
+inline void unpack_elemt(std::array<float, 8> &arr, const __m256 &vec) {
+    _mm256_store_ps(&arr[0], vec);
+}
diff --git a/src/Native/include/nncase/ntt/kernels/arch/x86_64/vector_ops.h b/src/Native/include/nncase/ntt/kernels/arch/x86_64/vector_ops.h
new file mode 100644
index 0000000000..5fb89e1cf4
--- /dev/null
+++ b/src/Native/include/nncase/ntt/kernels/arch/x86_64/vector_ops.h
@@ -0,0 +1,61 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <algorithm>
+#include <immintrin.h>
+
+namespace nncase::ntt::vector_ops {
+template <> struct reduce_sum<ntt::vector<float, 4>> {
+    float operator()(ntt::vector<float, 4> v) const noexcept {
+        auto res0 = _mm_hadd_ps(v, v);  // a,b,c,d -> (a+b, c+d, a+b, c+d)
+        res0 = _mm_hadd_ps(res0, res0); // all four lanes now hold a+b+c+d
+        return _mm_cvtss_f32(res0);
+    }
+};
+
+template <> struct reduce_max<ntt::vector<float, 4>> {
+    float operator()(ntt::vector<float, 4> v) const noexcept {
+        __m128 h = _mm_unpackhi_ps(v, v); // c,d,c,d
+        __m128 l = _mm_unpacklo_ps(v, v); // a,b,a,b
+        auto r = _mm_max_ps(l, h);        // max(a,c),max(b,d), ...
+        return std::max(r[0], r[1]);
+    }
+};
+
+template <> struct reduce_sum<ntt::vector<float, 8>> {
+    float operator()(ntt::vector<float, 8> v) const noexcept {
+        // horizontal add top lane and bottom lane
+        auto res0 = _mm256_hadd_ps(v, v);
+        res0 = _mm256_hadd_ps(res0, res0);
+        __m128 acc1 = _mm256_extractf128_ps(res0, 0);
+        __m128 acc2 = _mm256_extractf128_ps(res0, 1);
+        acc1 = _mm_add_ss(acc1, acc2);
+        return _mm_cvtss_f32(acc1);
+    }
+};
+
+template <> struct reduce_max<ntt::vector<float, 8>> {
+    float operator()(ntt::vector<float, 8> v) const noexcept {
+        __m128 lhs = _mm256_extractf128_ps(v, 0);
+        __m128 rhs = _mm256_extractf128_ps(v, 1);
+        __m128 r = _mm_max_ps(lhs, rhs);  // a,b,c,d
+
+        __m128 h = _mm_unpackhi_ps(r, r); // c,d,c,d
+        __m128 l = _mm_unpacklo_ps(r, r); // a,b,a,b
+        r = _mm_max_ps(l, h);             // max(a,c),max(b,d), ...
+        return std::max(r[0], r[1]);
+    }
+};
+
+} // namespace nncase::ntt::vector_ops
diff --git a/src/Native/include/nncase/ntt/kernels/arch/x86_64/vector_types.h b/src/Native/include/nncase/ntt/kernels/arch/x86_64/vector_types.h
new file mode 100644
index 0000000000..872187bfd5
--- /dev/null
+++ b/src/Native/include/nncase/ntt/kernels/arch/x86_64/vector_types.h
@@ -0,0 +1,15 @@
+#pragma once
+#include <immintrin.h>
+namespace nncase::ntt {
+template <> struct native_vector_type<float, 4> {
+    using type = __m128;
+    static type from_element(const float &f) { return _mm_setr_ps(f, f, f, f); }
+};
+
+template <> struct native_vector_type<float, 8> {
+    using type = __m256;
+    static type from_element(const float &f) {
+        return _mm256_setr_ps(f, f, f, f, f, f, f, f);
+    }
+};
+} // namespace nncase::ntt
\ No newline at end of file
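+
+// _mm_setr_ps(f, f, f, f) and _mm256_setr_ps(f, ...) broadcast a scalar just
+// like _mm_set1_ps/_mm256_set1_ps, which would state the intent more directly.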
diff --git a/src/Native/include/nncase/ntt/kernels/binary.h b/src/Native/include/nncase/ntt/kernels/binary.h
new file mode 100644
index 0000000000..0fd5d86850
--- /dev/null
+++ b/src/Native/include/nncase/ntt/kernels/binary.h
@@ -0,0 +1,35 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include "../apply.h"
+#include "../shape_infer/binary.h"
+#include "../shape_infer/reduce.h"
+#include
+
+namespace nncase::ntt {
+template