diff --git a/.gitignore b/.gitignore index 37cf70c73..5d27f62b6 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,7 @@ examples/csharp/HelloPhi/models !test/test_models/hf-internal-testing/ !test/test_models/hf-internal-testing/tiny-random-gpt2*/*.onnx -.ipynb_checkpoints/ \ No newline at end of file +.ipynb_checkpoints/ +/src/java/.gradle +/src/java/local.properties +/src/java/build diff --git a/CMakeLists.txt b/CMakeLists.txt index 01c3d45f4..65e9a5605 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,7 +8,6 @@ project(Generators LANGUAGES C CXX) # All Options should be defined in cmake/options.cmake This must be included before any other cmake file is included include(cmake/options.cmake) - include(cmake/external/onnxruntime_external_deps.cmake) # All Global variables, including GLOB, for the top level CMakeLists.txt should be defined here include(cmake/global_variables.cmake) @@ -19,24 +18,21 @@ include(cmake/check_dml.cmake) include(cmake/cxx_standard.cmake) - - -if (ANDROID) - # Paths are based on the directory structure of the ORT Android AAR. - set(ORT_HEADER_DIR ${ORT_HOME}/headers) - set(ORT_LIB_DIR ${ORT_HOME}/jni/${ANDROID_ABI}) -else() -endif() - add_compile_definitions(BUILDING_ORT_GENAI_C) if(MSVC) # set updated value for __cplusplus macro instead of 199711L add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/Zc:__cplusplus>) endif() -if(ENABLE_TESTS AND TEST_PHI2) - add_compile_definitions(TEST_PHI2=1) -else() - add_compile_definitions(TEST_PHI2=0) + +if(ENABLE_TESTS) + # call enable_testing so we can add tests from subdirectories (e.g. 
test and src/java) + # it applies recursively to all subdirectories + enable_testing() + if (TEST_PHI2) + add_compile_definitions(TEST_PHI2=1) + else() + add_compile_definitions(TEST_PHI2=0) + endif() endif() @@ -117,18 +113,24 @@ if(USE_DML) endif() if(ENABLE_TESTS) - add_subdirectory("${CMAKE_SOURCE_DIR}/test") message("------------------Enabling tests------------------") + add_subdirectory("${REPO_ROOT}/test") endif() if(ENABLE_PYTHON) - add_subdirectory("${CMAKE_SOURCE_DIR}/src/python") message("------------------Enabling Python Wheel------------------") + add_subdirectory("${SRC_ROOT}/python") endif() +if (ENABLE_JAVA) + message("------------------Enabling Java Jar------------------") + add_subdirectory("${SRC_ROOT}/java") +endif() + + if(ENABLE_MODEL_BENCHMARK) - add_subdirectory("${CMAKE_SOURCE_DIR}/benchmark/c") message("------------------Enabling model benchmark------------------") + add_subdirectory("${REPO_ROOT}/benchmark/c") endif() # Copy the onnxruntime binaries into the build folder so it's found on launch diff --git a/build.py b/build.py index 84df6cdbe..5e5b1c063 100644 --- a/build.py +++ b/build.py @@ -61,6 +61,9 @@ def _parse_args(): parser.add_argument("--skip_wheel", action="store_true", help="Skip building the Python wheel.") parser.add_argument("--skip_csharp", action="store_true", help="Skip building the C# API.") + # Default to not building the Java bindings + parser.add_argument("--build_java", action="store_true", help="Build Java bindings.") + parser.add_argument("--parallel", action="store_true", help="Enable parallel build.") # CI's sometimes explicitly set the path to the CMake and CTest executables. 
@@ -363,8 +366,9 @@ def update(args: argparse.Namespace, env: dict[str, str]): "-B", str(args.build_dir), "-DCMAKE_POSITION_INDEPENDENT_CODE=ON", - "-DUSE_CUDA=ON" if args.use_cuda else "-DUSE_CUDA=OFF", - "-DUSE_DML=ON" if args.use_dml else "-DUSE_DML=OFF", + f"-DUSE_CUDA={'ON' if args.use_cuda else 'OFF'}", + f"-DUSE_DML={'ON' if args.use_dml else 'OFF'}", + f"-DENABLE_JAVA={'ON' if args.build_java else 'OFF'}", f"-DBUILD_WHEEL={build_wheel}", ] diff --git a/build.sh b/build.sh old mode 100644 new mode 100755 diff --git a/cmake/global_variables.cmake b/cmake/global_variables.cmake index d5b9fbb11..bb7dcab6b 100644 --- a/cmake/global_variables.cmake +++ b/cmake/global_variables.cmake @@ -7,9 +7,11 @@ message("Building onnxruntime-genai for version ${VERSION_INFO}") # Define the project directories -set(GENERATORS_ROOT ${PROJECT_SOURCE_DIR}/src) -set(MODELS_ROOT ${PROJECT_SOURCE_DIR}/src/models) -set(ORT_HOME ${CMAKE_SOURCE_DIR}/ort CACHE PATH "Path to the onnxruntime root directory.") +set(REPO_ROOT ${PROJECT_SOURCE_DIR}) +set(SRC_ROOT ${REPO_ROOT}/src) +set(GENERATORS_ROOT ${SRC_ROOT}) +set(MODELS_ROOT ${SRC_ROOT}/models) +set(ORT_HOME ${REPO_ROOT}/ort CACHE PATH "Path to the onnxruntime root directory.") if (ANDROID) # Paths are based on the directory structure of the ORT Android AAR. @@ -53,4 +55,59 @@ if(NOT EXISTS "${ORT_LIB_DIR}/${ONNXRUNTIME_LIB}") endif() if(NOT EXISTS "${ORT_HEADER_DIR}/onnxruntime_c_api.h") message(FATAL_ERROR "Expected the ONNX Runtime C API header to be found at \"${ORT_HEADER_DIR}/onnxruntime_c_api.h\". Actual: Not found.") -endif() \ No newline at end of file +endif() + + +# normalize the target platform to x64 or arm64. additional architectures can be added as needed. 
+if (MSVC) + if (CMAKE_VS_PLATFORM_NAME) + # Visual Studio generator: use its target platform name (e.g. x64, ARM64) + set(genai_target_platform ${CMAKE_VS_PLATFORM_NAME}) + else() + set(genai_target_platform ${CMAKE_SYSTEM_PROCESSOR}) + endif() + # Windows reports ARM64 in upper case, so match case-insensitively and normalize to "arm64" + if (genai_target_platform MATCHES "^[Aa][Rr][Mm]64$") + set(genai_target_platform "arm64") + elseif (genai_target_platform STREQUAL "x64" OR + genai_target_platform STREQUAL "x86_64" OR + genai_target_platform STREQUAL "AMD64" OR + CMAKE_GENERATOR MATCHES "Win64") + set(genai_target_platform "x64") + else() + message(FATAL_ERROR "Unsupported architecture. CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") + endif() +elseif(APPLE) + # TODO: do we need to support CMAKE_OSX_ARCHITECTURES having multiple values? + set(_apple_target_arch ${CMAKE_OSX_ARCHITECTURES}) + if (NOT _apple_target_arch) + set(_apple_target_arch ${CMAKE_HOST_SYSTEM_PROCESSOR}) + endif() + + if (_apple_target_arch STREQUAL "arm64") + set(genai_target_platform "arm64") + elseif (_apple_target_arch STREQUAL "x86_64") + set(genai_target_platform "x64") + else() + message(FATAL_ERROR "Unsupported architecture. ${_apple_target_arch}") + endif() +elseif(ANDROID) + if (CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a") + set(genai_target_platform "arm64") + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64") + set(genai_target_platform "x64") + else() + message(FATAL_ERROR "Unsupported architecture. CMAKE_ANDROID_ARCH_ABI: ${CMAKE_ANDROID_ARCH_ABI}") + endif() +else() + if(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm64.*") + set(genai_target_platform "arm64") + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") + set(genai_target_platform "arm64") + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64|AMD64)$") + set(genai_target_platform "x64") + else() + message(FATAL_ERROR "Unsupported architecture. 
CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") + endif() +endif() + diff --git a/cmake/options.cmake b/cmake/options.cmake index 688633fda..c1aeaa1a5 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -1,10 +1,18 @@ include(CMakeDependentOption) +# features option(USE_CUDA "Build with CUDA support" ON) option(USE_DML "Build with DML support" OFF) + +# bindings +option(ENABLE_JAVA "Build the Java API." OFF) option(ENABLE_PYTHON "Build the Python API." ON) +cmake_dependent_option(BUILD_WHEEL "Build the python wheel" ON "ENABLE_PYTHON" OFF) + +# testing option(ENABLE_TESTS "Enable tests" ON) option(TEST_PHI2 "Enable tests for Phi2" OFF) + +# performance option(ENABLE_MODEL_BENCHMARK "Build model benchmark program" ON) -cmake_dependent_option(BUILD_WHEEL "Build the python wheel" ON "ENABLE_PYTHON" OFF) \ No newline at end of file diff --git a/src/java/CMakeLists.txt b/src/java/CMakeLists.txt new file mode 100644 index 000000000..4f723cd20 --- /dev/null +++ b/src/java/CMakeLists.txt @@ -0,0 +1,203 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +set(JAVA_AWT_LIBRARY NotNeeded) +set(JAVA_AWT_INCLUDE_PATH NotNeeded) +include(FindJava) +find_package(Java REQUIRED) +include(UseJava) + +if (NOT ANDROID) + find_package(JNI REQUIRED) +endif() + +set(JAVA_SRC_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) +# <CMake build dir>/src/java (path used with add_subdirectory in root CMakeLists.txt) +set(JAVA_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) + +# Should we use onnxruntime-genai or onnxruntime-genai-static? Using onnxruntime-genai for now. 
+# Add dependency on native target +set(JAVA_DEPENDS onnxruntime-genai) + +set(GRADLE_EXECUTABLE "${JAVA_SRC_ROOT}/gradlew") + +# inputs of the gradle jar rule below: all gradle scripts and every Java source under src/main/java +file(GLOB_RECURSE genai4j_gradle_files "${JAVA_SRC_ROOT}/*.gradle") +file(GLOB_RECURSE genai4j_srcs "${JAVA_SRC_ROOT}/src/main/java/*.java") + +# set gradle options that are used with multiple gradle commands +if(WIN32) + set(GRADLE_OPTIONS --console=plain -Dorg.gradle.daemon=false) +elseif (ANDROID) + # For Android build, we may run gradle multiple times in same build. Sometimes gradle JVM will run out of memory + # if we keep the daemon running, so we use no-daemon to avoid that + set(GRADLE_OPTIONS --console=plain --no-daemon) +endif() + +# this jar is used solely as a signaling mechanism for dependency management in CMake +# if any of the Java sources change, the jar (and generated headers) will be regenerated +# and the onnxruntime-genai-jni target will be rebuilt +# NOTE: the jar is expected under the source tree because gradle runs in JAVA_SRC_ROOT (presumably with its default build dir - confirm in build.gradle) +set(JAVA_OUTPUT_JAR ${JAVA_SRC_ROOT}/build/libs/onnxruntime-genai.jar) +set(GRADLE_ARGS clean jar -x test) + +add_custom_command(OUTPUT ${JAVA_OUTPUT_JAR} + COMMAND ${GRADLE_EXECUTABLE} ${GRADLE_OPTIONS} ${GRADLE_ARGS} + WORKING_DIRECTORY ${JAVA_SRC_ROOT} + DEPENDS ${genai4j_gradle_files} ${genai4j_srcs}) +add_custom_target(onnxruntime-genai4j DEPENDS ${JAVA_OUTPUT_JAR}) + +set_source_files_properties(${JAVA_OUTPUT_JAR} PROPERTIES GENERATED TRUE) +set_property(TARGET onnxruntime-genai4j APPEND PROPERTY ADDITIONAL_CLEAN_FILES "${JAVA_OUTPUT_DIR}") + +# Specify the JNI native sources +file(GLOB genai4j_native_src + "${JAVA_SRC_ROOT}/src/main/native/*.cpp" + "${JAVA_SRC_ROOT}/src/main/native/*.h" + 
"${SRC_ROOT}/ort_genai_c.h" + ) + +add_library(onnxruntime-genai-jni SHARED ${genai4j_native_src}) +set_property(TARGET onnxruntime-genai-jni PROPERTY CXX_STANDARD 17) +add_dependencies(onnxruntime-genai-jni onnxruntime-genai4j) +# the JNI headers are generated in the genai4j target +target_include_directories(onnxruntime-genai-jni PRIVATE ${SRC_ROOT} + ${JAVA_SRC_ROOT}/build/headers + ${JNI_INCLUDE_DIRS}) +target_link_libraries(onnxruntime-genai-jni PUBLIC onnxruntime-genai) + +set(JAVA_PACKAGE_OUTPUT_DIR ${JAVA_OUTPUT_DIR}/build) +file(MAKE_DIRECTORY ${JAVA_PACKAGE_OUTPUT_DIR}) +if (ANDROID) + set(ANDROID_PACKAGE_OUTPUT_DIR ${JAVA_PACKAGE_OUTPUT_DIR}/android) + file(MAKE_DIRECTORY ${ANDROID_PACKAGE_OUTPUT_DIR}) +endif() + +if (WIN32) + set(JAVA_PLAT "win") +elseif (APPLE) + set(JAVA_PLAT "osx") +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux") # the LINUX variable only exists in CMake >= 3.25 + set(JAVA_PLAT "linux") +elseif (ANDROID) + set(JAVA_PLAT "android") +else() + message(FATAL_ERROR "GenAI with Java is not currently supported on this platform") +endif() + +# Set platform and arch for packaging +if (genai_target_platform STREQUAL "x64") + set(JNI_ARCH x64) +elseif (genai_target_platform STREQUAL "arm64") + set(JNI_ARCH aarch64) +else() + message(FATAL_ERROR "GenAI with Java is not currently supported on this platform") +endif() + +# Similar to Nuget schema +set(JAVA_OS_ARCH ${JAVA_PLAT}-${JNI_ARCH}) + +# expose native libraries to the gradle build process +set(JAVA_PACKAGE_DIR ai/onnxruntime-genai/native/${JAVA_OS_ARCH}) +set(JAVA_NATIVE_LIB_DIR ${JAVA_OUTPUT_DIR}/native-lib) +set(JAVA_PACKAGE_LIB_DIR ${JAVA_NATIVE_LIB_DIR}/${JAVA_PACKAGE_DIR}) +file(MAKE_DIRECTORY ${JAVA_PACKAGE_LIB_DIR}) + +# Add the native genai library to the native-lib dir +add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + $<TARGET_FILE:onnxruntime-genai> + ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_FILE_NAME:onnxruntime-genai>) + +# also need the onnxruntime libraries in the same directory as onnxruntime-genai.dll has a dependency on it +foreach (ort_dll ${onnxruntime_libs}) + 
add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${ort_dll} ${JAVA_PACKAGE_LIB_DIR}) +endforeach() + +# Add the JNI bindings next to the native libraries (same native-lib package dir) +add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + $<TARGET_FILE:onnxruntime-genai-jni> + ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_FILE_NAME:onnxruntime-genai-jni>) + +# run the build process +set(GRADLE_ARGS cmakeBuild -DcmakeBuildDir=${JAVA_OUTPUT_DIR} -DnativeLibDir=${JAVA_PACKAGE_LIB_DIR}) +add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD + COMMAND ${GRADLE_EXECUTABLE} ${GRADLE_OPTIONS} ${GRADLE_ARGS} + WORKING_DIRECTORY ${JAVA_SRC_ROOT}) + +if (ANDROID) + set(ANDROID_PACKAGE_JNILIBS_DIR ${JAVA_OUTPUT_DIR}/android) + set(ANDROID_PACKAGE_ABI_DIR ${ANDROID_PACKAGE_JNILIBS_DIR}/${ANDROID_ABI}) + file(MAKE_DIRECTORY ${ANDROID_PACKAGE_JNILIBS_DIR}) + file(MAKE_DIRECTORY ${ANDROID_PACKAGE_ABI_DIR}) + + # Copy onnxruntime-genai.so and genai-jni.so for building Android AAR package + add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + $<TARGET_FILE:onnxruntime-genai> + ${ANDROID_PACKAGE_ABI_DIR}/$<TARGET_FILE_NAME:onnxruntime-genai>) + + add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + $<TARGET_FILE:onnxruntime-genai-jni> + ${ANDROID_PACKAGE_ABI_DIR}/$<TARGET_FILE_NAME:onnxruntime-genai-jni>) + + # Generate the Android AAR package + add_custom_command(TARGET onnxruntime-genai-jni + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E echo "Generating Android AAR package..." 
+ COMMAND ${GRADLE_EXECUTABLE} + build + -b build-android.gradle -c settings-android.gradle + -DjniLibsDir=${ANDROID_PACKAGE_JNILIBS_DIR} -DbuildDir=${ANDROID_PACKAGE_OUTPUT_DIR} + -DminSdkVer=${ANDROID_MIN_SDK} -DheadersDir=${ANDROID_HEADERS_DIR} + --stacktrace + WORKING_DIRECTORY ${JAVA_SRC_ROOT}) + + # unit tests + set(ANDROID_TEST_PACKAGE_ROOT ${JAVA_SRC_ROOT}/src/test/android) + set(ANDROID_TEST_PACKAGE_DIR ${JAVA_OUTPUT_DIR}/androidtest/android) + # copy the androidtest project into cmake binary directory + file(MAKE_DIRECTORY ${JAVA_OUTPUT_DIR}/androidtest) + file(COPY ${ANDROID_TEST_PACKAGE_ROOT} DESTINATION ${JAVA_OUTPUT_DIR}/androidtest) + set(ANDROID_TEST_PACKAGE_LIB_DIR ${ANDROID_TEST_PACKAGE_DIR}/app/libs) + file(MAKE_DIRECTORY ${ANDROID_TEST_PACKAGE_LIB_DIR}) + + # Copy the built Android AAR package to libs folder of our test app + add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${ANDROID_PACKAGE_OUTPUT_DIR}/outputs/aar/onnxruntime-genai-debug.aar + ${ANDROID_TEST_PACKAGE_LIB_DIR}/onnxruntime-genai.aar) + + # Build Android test apk for java package + add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD + COMMAND ${CMAKE_COMMAND} -E echo "Building and running Android test for Android AAR package..." + COMMAND ${GRADLE_EXECUTABLE} clean assembleDebug assembleDebugAndroidTest + -DminSdkVer=${ANDROID_MIN_SDK} --stacktrace + WORKING_DIRECTORY ${ANDROID_TEST_PACKAGE_DIR}) +endif() + +if (ENABLE_TESTS) + message(STATUS "Adding Java tests") + if (WIN32) + # On windows ctest requires a test to be an .exe(.com) file + # With gradle wrapper we get gradlew.bat. 
We delegate execution to a separate .cmake file + # That can handle both .exe and .bat + add_test(NAME onnxruntime-genai4j_test + COMMAND ${CMAKE_COMMAND} + -DGRADLE_EXECUTABLE=${GRADLE_EXECUTABLE} + -DBIN_DIR=${JAVA_OUTPUT_DIR} + -DJAVA_SRC_ROOT=${JAVA_SRC_ROOT} + -DJAVA_PACKAGE_LIB_DIR=${JAVA_PACKAGE_LIB_DIR} + -P ${JAVA_SRC_ROOT}/windows-unittests.cmake) + else() + add_test(NAME onnxruntime-genai4j_test + COMMAND ${GRADLE_EXECUTABLE} cmakeCheck + -DcmakeBuildDir=${JAVA_OUTPUT_DIR} -DnativeLibDir=${JAVA_PACKAGE_LIB_DIR} + WORKING_DIRECTORY ${JAVA_SRC_ROOT}) + endif() + + set_property(TEST onnxruntime-genai4j_test APPEND PROPERTY DEPENDS onnxruntime-genai-jni) +endif() diff --git a/src/java/Debugging.md b/src/java/Debugging.md new file mode 100644 index 000000000..b56189f84 --- /dev/null +++ b/src/java/Debugging.md @@ -0,0 +1,81 @@ +# Debugging Notes for Windows + +## To debug using VS Code. + +Create a config for java tests in settings.json. Adjust the paths based on your setup. +My repo root is D:\src\github\ort.genai, and I was testing a Debug build on Windows. 
+ +```jsonc + "java.test.config": { + "testKind": "junit", + "workingDirectory": "D:\\src\\github\\ort.genai\\build\\Windows\\Debug\\src\\java", + "classPaths": [ + "D:\\src\\github\\ort.genai\\src\\java\\build\\classes\\java\\main", + "D:\\src\\github\\ort.genai\\src\\java\\build\\classes\\java\\test", + "D:\\src\\github\\ort.genai\\src\\java\\build\\resources\\test" + ], + "sourcePaths": [ + "D:\\src\\github\\ort.genai\\src\\java\\src\\main\\java", + "D:\\src\\github\\ort.genai\\src\\java\\src\\test\\java" + ], + "vmArgs": [ "-Djava.library.path=D:\\src\\github\\ort.genai\\build\\Windows\\Debug\\src\\java\\native-lib\\ai\\onnxruntime-genai\\native\\win-x64" ], + }, +``` + +You may also want to set this in the VS Code settings: +```jsonc + "java.debug.settings.onBuildFailureProceed": true, +``` + +I didn't try to set up VS Code to be able to build the tests using cmake or gradle, so the build VS Code attempts +before run/debug of a test always fails. Instead I built the binding/test code from the command line, and then debugged +the tests from VS Code. + +You can do a top level build (`./build --build_java --config Debug --build --test ...` from the repo root), +or manually run gradlew from the src/java directory. + +e.g. the gradlew command line looks something like this to update the tests. Adjust build output path as needed. +> D:\src\github\ort.genai\src\java>D:/src/github/ort.genai/src/java/gradlew --info test -DcmakeBuildDir="D:\src\github\ort.genai\build\Windows\Debug\src\java" -DnativeLibDir="D:\src\github\ort.genai\build\Windows\Debug\src\java\native-lib\ai\onnxruntime-genai\native\win-x64" -Dorg.gradle.daemon=false + +NOTE: If using the top-level build, the unit test code gets built in the 'test' phase - that's just how the gradle build is set up. + +## To debug using IntelliJ + +The test Debug/Run config for the test needs the cmakeBuildDir and nativeLibDir values to be added. 
+ +The easiest way to create a test config is to right-click on the test or test directory in the Project window and run it. +The run will fail, but the bulk of the configuration will be created for you. + +Now open the test configuration, and in the 'Run' command which will start with something like + `:test --tests "ai.onnxruntime.genai..."` +add the values for cmakeBuildDir and nativeLibDir. + +e.g. +`-DcmakeBuildDir=D:\src\github\ort.genai\build\Windows\Debug\src\java -DnativeLibDir=D:\src\github\ort.genai\build\Windows\Debug\src\java\native-lib\ai\onnxruntime-genai\native\win-x64` + +----- + +# Debugging native code + +Download a junit-platform-console-standalone jar file from https://central.sonatype.com/artifact/org.junit.platform/junit-platform-console-standalone/versions + +With that, the command to run the tests from the command line (on Windows at least) is... + +All tests: + +> D:\Java\jdk-11.0.17\bin\java.exe "-Djava.library.path=D:\src\github\ort.genai\build\Windows\Debug\src\java\native-lib\ai\onnxruntime-genai\native\win-x64" -jar junit-platform-console-standalone-1.10.2.jar -cp D:\src\github\ort.genai\src\java\build\classes\java\test -cp D:\src\github\ort.genai\src\java\build\resources\test -cp D:\src\github\ort.genai\src\java\build\classes\java\main --scan-classpath + +To run a specific test class, use `-c` and the full class name. e.g. + +> D:\Java\jdk-11.0.17\bin\java.exe "-Djava.library.path=D:\src\github\ort.genai\build\Windows\Debug\src\java\native-lib\ai\onnxruntime-genai\native\win-x64" -jar junit-platform-console-standalone-1.10.2.jar -cp D:\src\github\ort.genai\src\java\build\classes\java\test -cp D:\src\github\ort.genai\src\java\build\resources\test -cp D:\src\github\ort.genai\src\java\build\classes\java\main -c ai.onnxruntime.genai.GenerationTest + +Adjust the paths for your setup. Run from the java build output directory: e.g. 
D:\src\github\ort.genai\build\Windows\Debug\src\java + +That command can also be run from Visual Studio using the solution file for the native library you need to debug (onnxruntime-genai or onnxruntime) by setting the debug command, arguments and working directory in the project properties. + +e.g. to debug the `onnxruntime` project (which builds the onnxruntime shared library) in onnxruntime.sln, in the project properties for the `onnxruntime` project, under Debugging, set Command/Command Arguments/Working Directory to the above values. +You can then right-click on the `onnxruntime` project -> Debug -> Start new instance. That should run java.exe and let you break on any exceptions with full symbols for the native code. + +To also be able to set breakpoints, make sure a local debug build of the library is in the nativeLibDir so that java.exe is loading that. + + diff --git a/src/java/README.md b/src/java/README.md new file mode 100644 index 000000000..f02a92414 --- /dev/null +++ b/src/java/README.md @@ -0,0 +1,71 @@ +# ONNX Runtime GenAI Java API + +This directory contains the Java language binding for the ONNX Runtime GenAI. +Java Native Interface (JNI) is used to allow for seamless calls to ONNX Runtime GenAI from Java. + +## Usage + +This document pertains to developing, building, running, and testing the API itself in your local environment. +For general purpose usage of the publicly distributed API, please see the [general Java API documentation](https://www.onnxruntime.ai/docs/reference/api/java-api.html). + +### Building + +Build with the `--build_java` option. + +Windows: `REPO_ROOT/build --build_java` +*nix: `REPO_ROOT/build.sh --build_java` + +#### Requirements + +Java 11 or later is required to build the library. The compiled jar file will run on Java 8 or later. + +The [Gradle](https://gradle.org/) build system is used here to manage the Java project's dependency management, compilation, testing, and assembly. 
+In particular, the Gradle [wrapper](https://docs.gradle.org/current/userguide/gradle_wrapper.html) at `src/java/gradlew[.bat]` is used, locking the Gradle version to the one specified in the `src/java/gradle/wrapper/gradle-wrapper.properties` configuration. +Using the Gradle wrapper removes the need to have the right version of Gradle installed on the system. + +#### Build Output + +The build will generate output in `$REPO_ROOT/build/$OS/$CONFIGURATION/src/java`: + +* `build/docs/javadoc/` - HTML javadoc +* `build/reports/` - detailed test results and other reports +* `build/libs/onnxruntime-genai-VERSION.jar` - JAR with compiled classes +* `native-lib` - platform-specific onnxruntime-genai and onnxruntime shared libraries. +* `native-lib` also receives the platform-specific JNI shared library (`onnxruntime-genai-jni`), next to the libraries above. + +#### Build System Overview + +The main CMake build system delegates building and testing to Gradle. +This allows the CMake system to ensure all of the C/C++ compilation is completed prior to the Java build. +The Java build depends on the C/C++ onnxruntime-genai shared library and a C++ JNI shared library (source located in the `src/main/native` directory). +The JNI shared library is the glue that allows Java to call functions in the onnxruntime-genai shared library. +Because CMake injects the native dependencies during CMake builds, some gradle tasks (primarily, `build`, `test`, and `check`) may fail when run directly, outside of the CMake build. + +When running the build script, CMake will compile the `onnxruntime-genai` target and the JNI glue `onnxruntime-genai-jni` target and expose the resulting libraries in a place where Gradle can ingest them. +Upon successful compilation of those targets, a special Gradle task to build will be executed. The results will be placed in the output directory stated above. + +### Advanced Loading + +The default behavior is to load the shared libraries using classpath resources. 
+If your use case requires custom loading of the shared libraries, please consult the javadoc in the [package-info.java](src/main/java/ai/onnxruntime-genai/package-info.java) or [GenAI.java](src/main/java/ai/onnxruntime-genai/GenAI.java) files. + +## Development + +### Code Formatting + +[Spotless](https://github.com/diffplug/spotless/tree/master/plugin-gradle) is used to keep the code properly formatted. +Gradle's `spotlessCheck` task will show any misformatted code. +Gradle's `spotlessApply` task will try to fix the formatting. +Misformatted code will raise failures when checks are run during a test run. + +### JNI Headers + +When adding or updating native methods in the Java files, the auto-generated JNI headers in `build/headers/ai_onnxruntime-genai*.h` can be used to determine the JNI function signature. + +These header files can be manually generated using Gradle's `compileJava` task, which will compile the Java sources and update the header files accordingly. + +Copy and paste the function declaration from the auto-generated .h file to add the implementation in the `./src/main/native/ai_onnxruntime-genai_*.cpp` file. + +### Dependencies + +The Java API does not have any runtime or compile dependencies. 
diff --git a/src/java/build-android.gradle b/src/java/build-android.gradle new file mode 100644 index 000000000..5b61fd03b --- /dev/null +++ b/src/java/build-android.gradle @@ -0,0 +1,198 @@ +apply plugin: 'com.android.library' +apply plugin: 'maven-publish' + +def jniLibsDir = System.properties['jniLibsDir'] +def buildDir = System.properties['buildDir'] +def headersDir = System.properties['headersDir'] +def publishDir = System.properties['publishDir'] +def minSdkVer = System.properties['minSdkVer'] +def targetSdkVer = System.properties['targetSdkVer'] + +// Android requires higher version codes to indicate more recent versions. +// This function assumes the version number is in the format A.B.C such as 1.7.0 (anything after a '-', e.g. '-dev', is ignored) +// We generate version code A[0{0,1}]B[0{0,1}]C, +// for example '1.7.0' -> 10700, '1.6.15' -> 10615 +def static getVersionCode(String version){ + String[] codes = version.split('-')[0].split('\\.'); + // This will misbehave if a [sub]version number has 3 digits, such as 1.7.199 + // but it is highly unlikely to happen + String versionCodeStr = String.format("%d%02d%02d", codes[0] as int, codes[1] as int, codes[2] as int); + return versionCodeStr as int; +} + +project.buildDir = buildDir +project.version = rootProject.file('../../VERSION_INFO').text.trim() +project.group = "com.microsoft.onnxruntime" + +def mavenArtifactId = project.name + '-android' +def defaultDescription = 'ONNX Runtime GenAI is ... ' + + 'This package contains the Android (AAR) build of ONNX Runtime GenAI, including Java bindings.' 
+ +buildscript { + repositories { + google() + mavenCentral() + } + dependencies { + classpath 'com.android.tools.build:gradle:7.4.2' + + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files + } +} + +allprojects { + repositories { + google() + mavenCentral() + } +} + +android { + compileSdkVersion 32 + + defaultConfig { + minSdkVersion minSdkVer + targetSdkVersion targetSdkVer + versionCode = getVersionCode(project.version) + versionName = project.version + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + } + + android { + lintOptions { + abortOnError false + } + } + + buildTypes { + release { + minifyEnabled false + debuggable false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + + compileOptions { + sourceCompatibility = JavaVersion.VERSION_1_8 + targetCompatibility = JavaVersion.VERSION_1_8 + } + + sourceSets { + main { + jniLibs.srcDirs = [jniLibsDir] + java { + srcDirs = ['src/main/java', 'src/main/android'] + } + } + } + + namespace 'ai.onnxruntime.genai' +} + +task sourcesJar(type: Jar) { + archiveClassifier = "sources" + from android.sourceSets.main.java.srcDirs +} + +task javadoc(type: Javadoc) { + source = android.sourceSets.main.java.srcDirs + classpath += project.files(android.getBootClasspath().join(File.pathSeparator)) +} + +task javadocJar(type: Jar, dependsOn: javadoc) { + archiveClassifier = 'javadoc' + from javadoc.destinationDir +} + +artifacts { + archives javadocJar + archives sourcesJar +} + +dependencies { + testImplementation 'org.junit.jupiter:junit-jupiter-api:5.7.0' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.7.0' +} + +publishing { + publications { + maven(MavenPublication) { + groupId = project.group + artifactId = mavenArtifactId + version = project.version + + // Three artifacts, the `aar`, the sources and the javadoc + 
artifact("$buildDir/outputs/aar/${project.name}-release.aar") + artifact javadocJar + artifact sourcesJar + + pom { + name = 'onnxruntime-genai' + description = defaultDescription + // TODO: Setup https://microsoft.github.io/onnxruntime-genai/ for equivalence with ORT? + url = 'https://github.com/microsoft/onnxruntime-genai/' + licenses { + license { + name = 'MIT License' + url = 'https://opensource.org/licenses/MIT' + } + } + organization { + name = 'Microsoft' + url = 'http://www.microsoft.com' + } + scm { + connection = 'scm:git:git://github.com:microsoft/onnxruntime-genai.git' + developerConnection = 'scm:git:ssh://github.com/microsoft/onnxruntime-genai.git' + url = 'https://github.com/microsoft/onnxruntime-genai' + } + developers { + // TODO: Does this need updating? + developer { + id = 'onnxruntime' + name = 'ONNX Runtime' + email = 'onnxruntime@microsoft.com' + } + } + } + } + } + + //publish to filesystem repo + repositories{ + maven { + url "$publishDir" + } + } +} + +// Add ORT C and C++ API headers to the AAR package, after task bundleDebugAar or bundleReleaseAar +// Such that developers using ORT native API can extract libraries and headers from AAR package without building ORT +tasks.whenTaskAdded { task -> + if (task.name.startsWith("bundle") && task.name.endsWith("Aar")) { + doLast { + addFolderToAar("addHeadersTo" + task.name, task.archivePath, headersDir, 'headers') + } + } +} + +def addFolderToAar(taskName, aarPath, folderPath, folderPathInAar) { + def tmpDir = file("${buildDir}/${taskName}") + tmpDir.mkdir() + def tmpDirFolder = file("${tmpDir.path}/${folderPathInAar}") + tmpDirFolder.mkdir() + copy { + from zipTree(aarPath) + into tmpDir + } + copy { + from fileTree(folderPath) + into tmpDirFolder + } + ant.zip(destfile: aarPath) { + fileset(dir: tmpDir.path) + } + delete tmpDir +} diff --git a/src/java/build.gradle b/src/java/build.gradle new file mode 100644 index 000000000..a601885e6 --- /dev/null +++ b/src/java/build.gradle @@ -0,0 +1,249 
// Gradle build for the ONNX Runtime GenAI Java bindings.
//
// The build can be run standalone, or driven from CMake. When driven from CMake the following
// JVM system properties are read:
//   cmakeBuildDir - build directory of the current CMake run; also used as the test working dir
//   nativeLibDir  - directory containing the compiled native libraries to bundle and test against
//   USE_CUDA      - when set, the published artifact id gets a "_gpu" suffix
//   mavenUser / mavenPwd - credentials for the remote Maven repository
plugins {
	id 'java-library'
	id 'maven-publish'
	id 'signing'
	id 'jacoco'
	id "com.diffplug.spotless" version "6.25.0"
}

allprojects {
	repositories {
		mavenCentral()
	}
}

project.group = "com.microsoft.onnxruntime"
version = rootProject.file('../../VERSION_INFO').text.trim()

// cmake runs will inform us of the build directory of the current run
def cmakeBuildDir = System.properties['cmakeBuildDir']
def useCUDA = System.properties['USE_CUDA']
def cmakeJavaDir = "${cmakeBuildDir}"
def cmakeNativeLibDir = System.properties['nativeLibDir']
def cmakeBuildOutputDir = "${cmakeJavaDir}/build"

def mavenUser = System.properties['mavenUser']
def mavenPwd = System.properties['mavenPwd']
def mavenArtifactId = useCUDA == null ? project.name : project.name + "_gpu"

// NOTE(review): this description looks truncated; it is published verbatim in the pom.
def defaultDescription = 'ONNX Runtime GenAI is '

logger.lifecycle("cmakeBuildDir:${cmakeBuildDir}")
logger.lifecycle("cmakeNativeLibDir:${cmakeNativeLibDir}")

java {
	sourceCompatibility = JavaVersion.VERSION_1_8
	targetCompatibility = JavaVersion.VERSION_1_8
}

// This jar tasks serves as a CMAKE signalling
// mechanism. The jar will be overwritten by allJar task
jar {
}

// Add explicit sources jar with pom file.
task sourcesJar(type: Jar, dependsOn: classes) {
	archiveClassifier = "sources"
	from sourceSets.main.allSource
	into("META-INF/maven/$project.group/$mavenArtifactId") {
		from { generatePomFileForMavenPublication }
		rename ".*", "pom.xml"
	}
}

// Add explicit javadoc jar with pom file
task javadocJar(type: Jar, dependsOn: javadoc) {
	archiveClassifier = "javadoc"
	from javadoc.destinationDir
	into("META-INF/maven/$project.group/$mavenArtifactId") {
		from { generatePomFileForMavenPublication }
		rename ".*", "pom.xml"
	}
}

spotless {
	java {
		removeUnusedImports()
		googleJavaFormat()
	}
	format 'gradle', {
		target '**/*.gradle'
		trimTrailingWhitespace()
		indentWithTabs()
	}
}

compileJava {
	dependsOn spotlessJava
	// "-h" makes javac emit the JNI headers for native methods into build/headers
	options.compilerArgs += ["-h", "${project.buildDir}/headers/"]
	if (!JavaVersion.current().isJava8()) {
		// Ensures only methods present in Java 8 are used
		options.compilerArgs.addAll(['--release', '8'])
		// Gradle versions before 6.6 require that these flags are unset when using "-release"
		java.sourceCompatibility = null
		java.targetCompatibility = null
	}
}

compileTestJava {
	if (!JavaVersion.current().isJava8()) {
		// Ensures only methods present in Java 8 are used
		options.compilerArgs.addAll(['--release', '8'])
		// Gradle versions before 6.6 require that these flags are unset when using "-release"
		java.sourceCompatibility = null
		java.targetCompatibility = null
	}
}

sourceSets.main.java {
	srcDirs = ['src/main/java', 'src/main/jvm']
}

// Test resources: the shared test models, plus the native libraries when CMake supplies them.
sourceSets.test {
	// add test resource files
	resources.srcDirs += [
		"../../test/test_models"
	]

	if (cmakeNativeLibDir != null) {
		// add compiled native libs
		resources.srcDirs += [cmakeNativeLibDir]
	}
}

if (cmakeNativeLibDir != null) {
	// generate tasks to be called from cmake

	// Overwrite jar location
	task allJar(type: Jar) {
		manifest {
			attributes('Automatic-Module-Name': "com.microsoft.onnxruntime.genai",
					'Implementation-Title': 'onnxruntime-genai',
					'Implementation-Version': project.version)
		}
		into("META-INF/maven/$project.group/$mavenArtifactId") {
			from { generatePomFileForMavenPublication }
			rename ".*", "pom.xml"
		}
		from sourceSets.main.output
		from cmakeNativeLibDir
	}

	// Copies the built jars and docs into the CMake build output directory.
	task cmakeBuild(type: Copy) {
		from project.buildDir
		include 'libs/**'
		include 'docs/**'
		into cmakeBuildOutputDir
	}

	cmakeBuild.dependsOn allJar
	cmakeBuild.dependsOn sourcesJar
	cmakeBuild.dependsOn javadocJar
	cmakeBuild.dependsOn javadoc

	// Runs `check` and copies the resulting reports into the CMake build output directory.
	task cmakeCheck(type: Copy) {
		from project.buildDir
		include 'reports/**'
		into cmakeBuildOutputDir
	}
	cmakeCheck.dependsOn check
}

dependencies {
	testImplementation 'org.junit.jupiter:junit-jupiter-api:5.9.2'
	testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.9.2'
}

processTestResources {
	duplicatesStrategy(DuplicatesStrategy.INCLUDE) // allows duplicates in the test resources
}

test {
	// NOTE(review): presumably intended as a task dependency on spotlessJava - confirm.
	java {
		dependsOn spotlessJava
	}
	useJUnitPlatform()
	if (cmakeBuildDir != null) {
		workingDir cmakeBuildDir
	}

	// Only override java.library.path when CMake told us where the native libraries are;
	// passing a null value would replace the JVM's default search path with a blank one.
	if (cmakeNativeLibDir != null) {
		systemProperty "java.library.path", cmakeNativeLibDir
	}
	systemProperties System.getProperties().subMap(['USE_CUDA'])
	testLogging {
		events "passed", "skipped", "failed"
		showStandardStreams = true
		showStackTraces = true
		exceptionFormat = "full"
	}
}

jacocoTestReport {
	reports {
		xml.required = true
		csv.required = true
		html.destination file("${buildDir}/jacocoHtml")
	}
}

publishing {
	publications {
		maven(MavenPublication) {
			groupId = project.group
			artifactId = mavenArtifactId

			from components.java
			pom {
				name = 'onnxruntime-genai'
				description = defaultDescription
				// TODO: Setup https://microsoft.github.io/onnxruntime-genai/ for equivalence with ORT?
				url = 'https://github.com/microsoft/onnxruntime-genai/'
				licenses {
					license {
						name = 'MIT License'
						url = 'https://opensource.org/licenses/MIT'
					}
				}
				organization {
					name = 'Microsoft'
					url = 'https://www.microsoft.com'
				}
				scm {
					// git:// URLs separate host and path with '/'; a ':' after the host would be read as a port
					connection = 'scm:git:git://github.com/microsoft/onnxruntime-genai.git'
					developerConnection = 'scm:git:ssh://github.com/microsoft/onnxruntime-genai.git'
					url = 'https://github.com/microsoft/onnxruntime-genai'
				}
				developers {
					// TODO: Does this need updating?
					developer {
						id = 'onnxruntime'
						name = 'ONNX Runtime'
						email = 'onnxruntime@microsoft.com'
					}
				}
			}
		}
	}
	repositories {
		maven {
			url 'https://oss.sonatype.org/service/local/staging/deploy/maven2/'
			credentials {
				username mavenUser
				password mavenPwd
			}
		}
	}
}

// Generates a task signMavenPublication that will
// build all artifacts.
signing {
	// Queries env vars:
	//  ORG_GRADLE_PROJECT_signingKey
	//  ORG_GRADLE_PROJECT_signingPassword but can be changed to properties
	def signingKey = findProperty("signingKey")
	def signingPassword = findProperty("signingPassword")
	useInMemoryPgpKeys(signingKey, signingPassword)
	sign publishing.publications.maven
}
+zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/src/java/gradlew b/src/java/gradlew new file mode 100755 index 000000000..1aa94a426 --- /dev/null +++ b/src/java/gradlew @@ -0,0 +1,249 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. 
+# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). 
+cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. 
+ +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. 
+# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/src/java/gradlew.bat b/src/java/gradlew.bat new file mode 100644 index 000000000..93e3f59f1 --- /dev/null +++ b/src/java/gradlew.bat @@ -0,0 +1,92 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. 
+@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! 
/**
 * Locates and loads the native libraries required by the GenAI Java bindings.
 *
 * <p>Three libraries are loaded, in order: ONNX Runtime, ONNX Runtime GenAI, and the GenAI JNI
 * layer. For each library the lookup order is: an optional per-library skip flag, a shared
 * directory given by {@link #GENAI_NATIVE_PATH}, a per-library path property, a copy bundled in
 * the classpath resources, and finally {@link System#loadLibrary(String)}.
 */
final class GenAI {
  private static final Logger logger = Logger.getLogger(GenAI.class.getName());

  /**
   * The name of the system property which when set gives the path on disk where the ONNX Runtime
   * native libraries are stored.
   */
  static final String GENAI_NATIVE_PATH = "onnxruntime-genai.native.path";

  /** The short name of the ONNX Runtime GenAI shared library */
  static final String GENAI_LIBRARY_NAME = "onnxruntime-genai";

  /** The short name of the ONNX Runtime GenAI JNI shared library */
  static final String GENAI_JNI_LIBRARY_NAME = "onnxruntime-genai-jni";

  /** The short name of the ONNX runtime shared library */
  static final String ONNXRUNTIME_LIBRARY_NAME = "onnxruntime";

  /** The value of the {@link #GENAI_NATIVE_PATH} system property */
  private static String libraryDirPathProperty;

  /** The OS & CPU architecture string */
  private static final String OS_ARCH_STR = initOsArch();

  /** Have the native libraries been loaded */
  private static boolean loaded = false;

  /** The temp directory where native libraries are extracted */
  private static Path tempDirectory;

  /**
   * Loads the native libraries if they have not been loaded yet. Subsequent calls after a
   * successful load are no-ops.
   *
   * @throws IOException If the temp directory could not be created or a library file could not be
   *     found at an explicitly configured path.
   */
  static synchronized void init() throws IOException {
    if (loaded) {
      return;
    }

    tempDirectory = isAndroid() ? null : Files.createTempDirectory("onnxruntime-genai-java");
    if (tempDirectory != null) {
      // Register the directory BEFORE any file extracted into it. deleteOnExit processes entries
      // in reverse registration order, so the files are removed first and the directory is empty
      // by the time its own deletion is attempted.
      cleanUp(tempDirectory.toFile());
    }

    libraryDirPathProperty = System.getProperty(GENAI_NATIVE_PATH);

    load(ONNXRUNTIME_LIBRARY_NAME); // ORT native
    load(GENAI_LIBRARY_NAME); // ORT GenAI native
    load(GENAI_JNI_LIBRARY_NAME); // GenAI JNI layer
    loaded = true;
  }

  /* Computes and initializes OS_ARCH_STR (such as linux-x64) */
  private static String initOsArch() {
    String os = System.getProperty("os.name", "generic").toLowerCase(Locale.ENGLISH);
    String detectedOS;
    // Android has to be tested first: an os.name containing "nux" would otherwise be classified
    // as plain linux and the android branch could never be taken.
    if (isAndroid()) {
      detectedOS = "android";
    } else if (os.contains("mac") || os.contains("darwin")) {
      detectedOS = "osx";
    } else if (os.contains("win")) {
      detectedOS = "win";
    } else if (os.contains("nux")) {
      detectedOS = "linux";
    } else {
      throw new IllegalStateException("Unsupported os:" + os);
    }

    String arch = System.getProperty("os.arch", "generic").toLowerCase(Locale.ENGLISH);
    String detectedArch;
    if (arch.startsWith("amd64") || arch.startsWith("x86_64")) {
      detectedArch = "x64";
    } else if (arch.startsWith("x86")) {
      // 32-bit x86 is not supported by the Java API
      detectedArch = "x86";
    } else if (arch.startsWith("aarch64")) {
      detectedArch = "aarch64";
    } else if (arch.startsWith("ppc64")) {
      detectedArch = "ppc64";
    } else if (isAndroid()) {
      // Any other architecture reported on Android is passed through unchanged.
      detectedArch = arch;
    } else {
      throw new IllegalStateException("Unsupported arch:" + arch);
    }

    return detectedOS + '-' + detectedArch;
  }

  /**
   * Check if we're running on Android.
   *
   * @return True if the property java.vendor equals The Android Project, false otherwise.
   */
  static boolean isAndroid() {
    return System.getProperty("java.vendor", "generic").equals("The Android Project");
  }

  /**
   * Marks the file for delete on exit. Does nothing if the file does not exist.
   *
   * @param file The file to remove.
   */
  private static void cleanUp(File file) {
    if (!file.exists()) {
      return;
    }

    logger.log(Level.FINE, "Deleting " + file + " on exit");
    file.deleteOnExit();
  }

  /**
   * Load a shared library by name.
   *
   * <p>If the library path is not specified via a system property then it attempts to extract the
   * library from the classpath before loading it.
   *
   * @param library The bare name of the library.
   * @throws IOException If the file failed to read or write.
   */
  private static void load(String library) throws IOException {
    if (isAndroid()) {
      // On Android, we simply use System.loadLibrary. TODO: is this sufficient?
      // Note that the JNI library is requested regardless of which name was passed in.
      System.loadLibrary(GENAI_JNI_LIBRARY_NAME);
      return;
    }

    // 1) The user may skip loading of this library:
    String skip = System.getProperty("onnxruntime-genai.native." + library + ".skip");
    if (Boolean.TRUE.toString().equalsIgnoreCase(skip)) {
      logger.log(Level.FINE, "Skipping load of native library '" + library + "'");
      return;
    }

    // Resolve the platform dependent library name.
    String libraryFileName = mapLibraryName(library);

    // 2) The user may explicitly specify the path to a directory containing all shared libraries:
    if (libraryDirPathProperty != null) {
      logger.log(
          Level.FINE,
          "Attempting to load native library '"
              + library
              + "' from specified path: "
              + libraryDirPathProperty);

      // TODO: Switch this to Path.of when the minimum Java version is 11.
      File libraryFile = Paths.get(libraryDirPathProperty, libraryFileName).toFile();
      String libraryFilePath = libraryFile.getAbsolutePath();
      if (!libraryFile.exists()) {
        throw new IOException("Native library '" + library + "' not found at " + libraryFilePath);
      }

      System.load(libraryFilePath);
      logger.log(Level.FINE, "Loaded native library '" + library + "' from specified path");
      return;
    }

    // 3) The user may explicitly specify the path to their shared library:
    String libraryPathProperty =
        System.getProperty("onnxruntime-genai.native." + library + ".path");
    if (libraryPathProperty != null) {
      logger.log(
          Level.FINE,
          "Attempting to load native library '"
              + library
              + "' from specified path: "
              + libraryPathProperty);
      File libraryFile = new File(libraryPathProperty);
      String libraryFilePath = libraryFile.getAbsolutePath();
      if (!libraryFile.exists()) {
        throw new IOException("Native library '" + library + "' not found at " + libraryFilePath);
      }

      System.load(libraryFilePath);
      logger.log(Level.FINE, "Loaded native library '" + library + "' from specified path");
      return;
    }

    // 4) try loading from resources or library path:
    Optional<File> extractedPath = extractFromResources(library);
    if (extractedPath.isPresent()) {
      // extracted library from resources
      System.load(extractedPath.get().getAbsolutePath());
      logger.log(Level.FINE, "Loaded native library '" + library + "' from resource path");
    } else {
      // failed to load library from resources, try to load it from the library path
      logger.log(
          Level.FINE, "Attempting to load native library '" + library + "' from library path");
      System.loadLibrary(library);
      logger.log(Level.FINE, "Loaded native library '" + library + "' from library path");
    }
  }

  /**
   * Extracts the library from the classpath resources. returns optional.empty if it failed to
   * extract or couldn't be found.
   *
   * @param library The library name
   * @return An optional containing the file if it is successfully extracted, or an empty optional
   *     if it failed to extract or couldn't be found.
   */
  private static Optional<File> extractFromResources(String library) {
    String libraryFileName = mapLibraryName(library);
    String resourcePath = "/ai/onnxruntime/genai/native/" + OS_ARCH_STR + '/' + libraryFileName;
    File tempFile = tempDirectory.resolve(libraryFileName).toFile();

    try (InputStream is = GenAI.class.getResourceAsStream(resourcePath)) {
      if (is == null) {
        // Not found in classpath resources
        return Optional.empty();
      } else {
        // Found in classpath resources, load via temporary file
        logger.log(
            Level.FINE,
            "Attempting to load native library '"
                + library
                + "' from resource path "
                + resourcePath
                + " copying to "
                + tempFile);

        byte[] buffer = new byte[4096];
        int readBytes;
        try (FileOutputStream os = new FileOutputStream(tempFile)) {
          while ((readBytes = is.read(buffer)) != -1) {
            os.write(buffer, 0, readBytes);
          }
        }

        logger.log(Level.FINE, "Extracted native library '" + library + "' from resource path");
        return Optional.of(tempFile);
      }
    } catch (IOException e) {
      logger.log(
          Level.WARNING, "Failed to extract library '" + library + "' from the resources", e);
      return Optional.empty();
    } finally {
      // Runs on every path; cleanUp ignores the file if it was never created.
      cleanUp(tempFile);
    }
  }

  /**
   * Maps the library name into a platform dependent library filename. Converts macOS's "jnilib" to
   * "dylib" but otherwise is the same as {@link System#mapLibraryName(String)}.
   *
   * @param library The library name
   * @return The library filename.
   */
  private static String mapLibraryName(String library) {
    return System.mapLibraryName(library).replace("jnilib", "dylib");
  }
}
/**
 * A checked exception carrying the error message produced by the GenAI native layer, optionally
 * wrapping the exception that caused it.
 */
public final class GenAIException extends Exception {
  // Exception is Serializable; pin the serial form explicitly.
  private static final long serialVersionUID = 1L;

  /**
   * Creates an exception with the given message and no cause.
   *
   * @param message The error message.
   */
  public GenAIException(String message) {
    super(message);
  }

  /**
   * Creates an exception with the given message that wraps another exception.
   *
   * @param message The error message.
   * @param innerException The exception recorded as the cause.
   */
  public GenAIException(String message, Exception innerException) {
    super(message, innerException);
  }
}

The expected usage is to loop until isDone returns false. Within the loop, call computeLogits + * followed by generateNextToken. + * + *

The newly generated token can be retrieved with getLastTokenInSequence and decoded with + * TokenizerStream.Decode. + * + *

After the generation process is done, getSequence can be used to retrieve the complete
+ * generated sequence if needed.
+ */
+public final class Generator implements AutoCloseable {
+  private long nativeHandle = 0;
+
+  /**
+   * Constructs a Generator object with the given model and generator parameters.
+   *
+   * @param model The model.
+   * @param generatorParams The generator parameters.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public Generator(Model model, GeneratorParams generatorParams) throws GenAIException {
+    if (model.nativeHandle() == 0) {
+      throw new IllegalArgumentException("model has been freed and is invalid");
+    }
+
+    if (generatorParams.nativeHandle() == 0) {
+      throw new IllegalArgumentException("generatorParams has been freed and is invalid");
+    }
+
+    nativeHandle = createGenerator(model.nativeHandle(), generatorParams.nativeHandle());
+  }
+
+  /**
+   * Checks if the generation process is done.
+   *
+   * @return true if the generation process is done, false otherwise.
+   */
+  public boolean isDone() {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    return isDone(nativeHandle);
+  }
+
+  /**
+   * Computes the logits for the next token in the sequence.
+   *
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public void computeLogits() throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    computeLogits(nativeHandle);
+  }
+
+  /**
+   * Generates the next token in the sequence.
+   *
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public void generateNextToken() throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    generateNextTokenNative(nativeHandle);
+  }
+
+  /**
+   * Retrieves a sequence of token ids for the specified sequence index.
+   *
+   * @param sequenceIndex The index of the sequence.
+   * @return An array of integers with the sequence token ids.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public int[] getSequence(long sequenceIndex) throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    return getSequenceNative(nativeHandle, sequenceIndex);
+  }
+
+  /**
+   * Retrieves the last token in the sequence for the specified sequence index.
+   *
+   * @param sequenceIndex The index of the sequence.
+   * @return The last token in the sequence.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public int getLastTokenInSequence(long sequenceIndex) throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    return getSequenceLastToken(nativeHandle, sequenceIndex);
+  }
+
+  /** Closes the Generator and releases any associated resources. Safe to call repeatedly. */
+  @Override
+  public void close() {
+    if (nativeHandle != 0) {
+      destroyGenerator(nativeHandle);
+      nativeHandle = 0;
+    }
+  }
+
+  static {
+    try {
+      GenAI.init();
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to load onnxruntime-genai native libraries", e);
+    }
+  }
+
+  private native long createGenerator(long modelHandle, long generatorParamsHandle)
+      throws GenAIException;
+
+  private native void destroyGenerator(long nativeHandle);
+
+  private native boolean isDone(long nativeHandle);
+
+  private native void computeLogits(long nativeHandle) throws GenAIException;
+
+  private native void generateNextTokenNative(long nativeHandle) throws GenAIException;
+
+  private native int[] getSequenceNative(long nativeHandle, long sequenceIndex)
+      throws GenAIException;
+
+  private native int getSequenceLastToken(long nativeHandle, long sequenceIndex)
+      throws GenAIException;
+}
diff --git a/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java b/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java
new file mode 100644
index 000000000..d0bcc3983
--- /dev/null
+++ b/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java
@@ -0,0 +1,122 @@
+package ai.onnxruntime.genai;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+/**
+ * The `GeneratorParams` class represents the parameters used for generating sequences with a model.
+ * Set the prompt using setInput, and any other search options using setSearchOption.
+ */
+public final class GeneratorParams implements AutoCloseable {
+  private long nativeHandle = 0;
+  private ByteBuffer tokenIdsBuffer;
+
+  GeneratorParams(Model model) throws GenAIException {
+    if (model.nativeHandle() == 0) {
+      throw new IllegalStateException("model has been freed and is invalid");
+    }
+
+    nativeHandle = createGeneratorParams(model.nativeHandle());
+  }
+
+  public void setSearchOption(String optionName, double value) throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    setSearchOptionNumber(nativeHandle, optionName, value);
+  }
+
+  public void setSearchOption(String optionName, boolean value) throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    setSearchOptionBool(nativeHandle, optionName, value);
+  }
+
+  /**
+   * Sets the prompt/s for model execution. The `sequences` are created by using Tokenizer.encode or
+   * encodeBatch.
+   *
+   * @param sequences Sequences containing the encoded prompt.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public void setInput(Sequences sequences) throws GenAIException {
+    if (sequences.nativeHandle() == 0) {
+      throw new IllegalArgumentException("sequences has been freed and is invalid");
+    }
+
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    tokenIdsBuffer = null; // free the token ids buffer if previously used.
+    setInputSequences(nativeHandle, sequences.nativeHandle());
+  }
+
+  /**
+   * Sets the prompt/s token ids for model execution. The `tokenIds` are the encoded prompt/s.
+   * NOTE: All sequences in the batch must be the same length.
+   *
+   * @param tokenIds The token ids of the encoded prompt/s.
+   * @param sequenceLength The length of each sequence.
+   * @param batchSize The batch size.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public void setInput(int[] tokenIds, int sequenceLength, int batchSize) throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    if (sequenceLength < 0 || (long) sequenceLength * batchSize != tokenIds.length) {
+      throw new IllegalArgumentException(
+          "tokenIds length must be equal to sequenceLength * batchSize");
+    }
+
+    // allocate a direct buffer to store the token ids so that they remain valid throughout the
+    // generation process as the GenAI layer does not copy the token ids.
+    tokenIdsBuffer = ByteBuffer.allocateDirect(tokenIds.length * Integer.BYTES);
+    tokenIdsBuffer.order(ByteOrder.nativeOrder());
+    tokenIdsBuffer.asIntBuffer().put(tokenIds);
+
+    setInputIDs(nativeHandle, tokenIdsBuffer, sequenceLength, batchSize);
+  }
+
+  @Override
+  public void close() {
+    if (nativeHandle != 0) {
+      destroyGeneratorParams(nativeHandle);
+      nativeHandle = 0;
+    }
+  }
+
+  long nativeHandle() {
+    return nativeHandle;
+  }
+
+  static {
+    try {
+      GenAI.init();
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to load onnxruntime-genai native libraries", e);
+    }
+  }
+
+  private native long createGeneratorParams(long modelHandle) throws GenAIException;
+
+  private native void destroyGeneratorParams(long nativeHandle);
+
+  private native void setSearchOptionNumber(long nativeHandle, String optionName, double value)
+      throws GenAIException;
+
+  private native void setSearchOptionBool(long nativeHandle, String optionName, boolean value)
+      throws GenAIException;
+
+  private native void setInputSequences(long nativeHandle, long sequencesHandle)
+      throws GenAIException;
+
+  private native void setInputIDs(
+      long nativeHandle, ByteBuffer tokenIds, int sequenceLength, int batchSize)
+      throws GenAIException;
+}
diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Model.java b/src/java/src/main/java/ai/onnxruntime/genai/Model.java
new file mode 100644
index
000000000..cac90136e
--- /dev/null
+++ b/src/java/src/main/java/ai/onnxruntime/genai/Model.java
@@ -0,0 +1,92 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License.
+ */
+package ai.onnxruntime.genai;
+
+public final class Model implements AutoCloseable {
+  private long nativeHandle;
+
+  public Model(String modelPath) throws GenAIException {
+    nativeHandle = createModel(modelPath);
+  }
+
+  /**
+   * Creates a Tokenizer instance for this model. The model contains the configuration information
+   * that determines the tokenizer to use.
+   *
+   * @return The Tokenizer instance.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public Tokenizer createTokenizer() throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    return new Tokenizer(this);
+  }
+
+  // NOTE: Having model.createGeneratorParams is still under discussion.
+  // model.createTokenizer is consistent with the python setup at least and agreed upon.
+
+  /**
+   * Creates a GeneratorParams instance for executing the model. NOTE: GeneratorParams internally
+   * uses the Model, so the Model instance must remain valid
+   *
+   * @return The GeneratorParams instance.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public GeneratorParams createGeneratorParams() throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    return new GeneratorParams(this);
+  }
+
+  /**
+   * Run the model to generate output sequences. Generation is limited to the "max_length" value
+   * (default:300) in the generator parameters. Use a Tokenizer to decode the generated sequences.
+   *
+   * @param generatorParams The generator parameters.
+   * @return The generated sequences.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public Sequences generate(GeneratorParams generatorParams) throws GenAIException {
+    if (generatorParams.nativeHandle() == 0) {
+      throw new IllegalArgumentException("generatorParams has been freed and is invalid");
+    }
+
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    long sequencesHandle = generate(nativeHandle, generatorParams.nativeHandle());
+    return new Sequences(sequencesHandle);
+  }
+
+  @Override
+  public void close() {
+    if (nativeHandle != 0) {
+      destroyModel(nativeHandle);
+      nativeHandle = 0;
+    }
+  }
+
+  long nativeHandle() {
+    return nativeHandle;
+  }
+
+  static {
+    try {
+      GenAI.init();
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to load onnxruntime-genai native libraries", e);
+    }
+  }
+
+  private native long createModel(String modelPath) throws GenAIException;
+
+  private native void destroyModel(long modelHandle);
+
+  private native long generate(long modelHandle, long generatorParamsHandle) throws GenAIException;
+}
diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Sequences.java b/src/java/src/main/java/ai/onnxruntime/genai/Sequences.java
new file mode 100644
index 000000000..2c100622c
--- /dev/null
+++ b/src/java/src/main/java/ai/onnxruntime/genai/Sequences.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License.
+ */
+package ai.onnxruntime.genai;
+
+/** Represents a collection of encoded prompts/responses. */
+public final class Sequences implements AutoCloseable {
+  private long nativeHandle;
+  private long numSequences;
+
+  Sequences(long sequencesHandle) {
+    assert (sequencesHandle != 0); // internal usage should never pass an invalid handle
+
+    nativeHandle = sequencesHandle;
+    numSequences = getSequencesCount(sequencesHandle);
+  }
+
+  /**
+   * Gets the number of sequences in the collection. This is equivalent to the batch size.
+   *
+   * @return The number of sequences.
+   */
+  public long numSequences() {
+    return numSequences;
+  }
+
+  /**
+   * Gets the sequence at the specified index.
+   *
+   * @param sequenceIndex The index of the sequence.
+   * @return The sequence as an array of integers.
+   */
+  int[] getSequence(long sequenceIndex) {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    return getSequenceNative(nativeHandle, sequenceIndex);
+  }
+
+  @Override
+  public void close() {
+    if (nativeHandle != 0) {
+      destroySequences(nativeHandle);
+      nativeHandle = 0;
+    }
+  }
+
+  long nativeHandle() {
+    return nativeHandle;
+  }
+
+  static {
+    try {
+      GenAI.init();
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to load onnxruntime-genai native libraries", e);
+    }
+  }
+
+  private native long getSequencesCount(long sequencesHandle);
+
+  private native int[] getSequenceNative(long sequencesHandle, long sequenceIndex);
+
+  private native void destroySequences(long sequencesHandle);
+}
diff --git a/src/java/src/main/java/ai/onnxruntime/genai/SimpleGenAI.java b/src/java/src/main/java/ai/onnxruntime/genai/SimpleGenAI.java
new file mode 100644
index 000000000..3fe095581
--- /dev/null
+++ b/src/java/src/main/java/ai/onnxruntime/genai/SimpleGenAI.java
@@ -0,0 +1,127 @@
+package ai.onnxruntime.genai;
+
+import java.util.function.Consumer;
+
+/**
+ * The `SimpleGenAI` class provides a simple usage example of the GenAI API. It works with a model
+ * that generates text based on a prompt, processing a single prompt at a time.
+ *
+ * <p>Usage:
+ *
+ * <ul>
+ *   <li>Create an instance of the class with the path to the model. The path should also contain
+ *       the GenAI configuration files.
+ *   <li>Call createGeneratorParams with the prompt text.
+ *   <li>Set any other search options via the GeneratorParams object as needed using
+ *       `setSearchOption`.
+ *   <li>Call generate with the GeneratorParams object and an optional listener.
+ * </ul>
+ *
+ * <p>The listener is used as a callback mechanism so that tokens can be used as they are generated.
+ * Provide a {@code Consumer<String>} (for example a lambda or method reference) as the `listener`
+ * argument; it receives the decoded text of each token.
+ */
+public class SimpleGenAI {
+  private Model model;
+  private Tokenizer tokenizer;
+
+  public SimpleGenAI(String modelPath) throws GenAIException {
+    model = new Model(modelPath);
+    tokenizer = new Tokenizer(model);
+  }
+
+  /**
+   * Create the generator parameters and add the prompt text. The user can set other search options
+   * via the GeneratorParams object prior to running `generate`.
+   *
+   * @param prompt The prompt text to encode.
+   * @return The generator parameters.
+   * @throws GenAIException on failure
+   */
+  public GeneratorParams createGeneratorParams(String prompt) throws GenAIException {
+    GeneratorParams generatorParams = model.createGeneratorParams();
+
+    try (Sequences encodedPrompt = tokenizer.encode(prompt)) {
+      generatorParams.setInput(encodedPrompt);
+    } catch (GenAIException e) {
+      generatorParams.close();
+      throw e;
+    }
+
+    return generatorParams;
+  }
+
+  /**
+   * Create the generator parameters and add the prompt token ids. The user can set other search
+   * options via the GeneratorParams object prior to running `generate`.
+   *
+   * @param tokenIds The token ids of the encoded prompt/s.
+   * @param sequenceLength The length of each sequence.
+   * @param batchSize The batch size.
+   * @return The generator parameters.
+   * @throws GenAIException on failure
+   */
+  public GeneratorParams createGeneratorParams(int[] tokenIds, int sequenceLength, int batchSize)
+      throws GenAIException {
+    GeneratorParams generatorParams = model.createGeneratorParams();
+    try {
+      generatorParams.setInput(tokenIds, sequenceLength, batchSize);
+    } catch (GenAIException e) {
+      generatorParams.close();
+      throw e;
+    }
+
+    return generatorParams;
+  }
+
+  /**
+   * Generate text based on the prompt and settings in GeneratorParams.
+   *
+   * <p>NOTE: This only handles a single sequence of input (i.e. a single prompt which equates to
+   * batch size of 1)
+   *
+   * @param generatorParams The prompt and settings to run the model with.
+   * @param listener Optional callback for tokens to be provided as they are generated. NOTE: Token
+   *     generation will be blocked until the listener's `accept` method returns.
+   * @return The generated text.
+   * @throws GenAIException on failure
+   */
+  public String generate(GeneratorParams generatorParams, Consumer<String> listener)
+      throws GenAIException {
+    String result;
+    try (Tokenizer tokenizer = new Tokenizer(model)) {
+      int[] output_ids;
+
+      if (listener != null) {
+        try (TokenizerStream stream = tokenizer.createStream();
+            Generator generator = new Generator(model, generatorParams)) {
+          while (!generator.isDone()) {
+            // generate next token
+            generator.computeLogits();
+            generator.generateNextToken();
+
+            // decode and call listener
+            int token_id = generator.getLastTokenInSequence(0);
+            String token = stream.decode(token_id);
+            listener.accept(token);
+          }
+
+          output_ids = generator.getSequence(0);
+        } catch (GenAIException e) {
+          throw new GenAIException("Token generation loop failed.", e);
+        }
+      } else {
+        // getSequence copies the ids into a Java array, so the native Sequences can be released.
+        try (Sequences output_sequences = model.generate(generatorParams)) {
+          output_ids = output_sequences.getSequence(0);
+        }
+      }
+
+      result = tokenizer.decode(output_ids);
+    } catch (GenAIException e) {
+      throw new GenAIException("Failed to create Tokenizer.", e);
+    }
+
+    return result;
+  }
+}
diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Tokenizer.java b/src/java/src/main/java/ai/onnxruntime/genai/Tokenizer.java
new file mode 100644
index 000000000..6a0bf703f
--- /dev/null
+++ b/src/java/src/main/java/ai/onnxruntime/genai/Tokenizer.java
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License.
+ */
+package ai.onnxruntime.genai;
+
+/** The Tokenizer class is responsible for converting between text and token ids. */
+public class Tokenizer implements AutoCloseable {
+  private long nativeHandle;
+
+  Tokenizer(Model model) throws GenAIException {
+    assert (model.nativeHandle() != 0); // internal code should never pass an invalid model
+
+    nativeHandle = createTokenizer(model.nativeHandle());
+  }
+
+  /**
+   * Encodes a string into a sequence of token ids.
+   *
+   * @param string Text to encode as token ids.
+   * @return a Sequences object with a single sequence in it; the caller should close it.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public Sequences encode(String string) throws GenAIException {
+    return encodeBatch(new String[] {string});
+  }
+
+  /**
+   * Encodes an array of strings into a sequence of token ids for each input.
+   *
+   * @param strings Collection of strings to encode as token ids.
+   * @return a Sequences object with one sequence per input string; the caller should close it.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public Sequences encodeBatch(String[] strings) throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    long sequencesHandle = tokenizerEncode(nativeHandle, strings);
+    return new Sequences(sequencesHandle);
+  }
+
+  /**
+   * Decodes a sequence of token ids into text.
+   *
+   * @param sequence Collection of token ids to decode to text.
+   * @return The text representation of the sequence.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public String decode(int[] sequence) throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    return tokenizerDecode(nativeHandle, sequence);
+  }
+
+  /**
+   * Decodes a batch of sequences of token ids into text, one `decode` call per sequence.
+   *
+   * @param sequences A Sequences object with one or more sequences of token ids.
+   * @return An array of strings with the text representation of each sequence.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public String[] decodeBatch(Sequences sequences) throws GenAIException {
+    int numSequences = (int) sequences.numSequences();
+
+    String[] result = new String[numSequences];
+    for (int i = 0; i < numSequences; i++) {
+      result[i] = decode(sequences.getSequence(i));
+    }
+
+    return result;
+  }
+
+  /**
+   * Creates a TokenizerStream object for streaming tokenization. This is used with Generator class
+   * to provide each token as it is generated.
+   *
+   * @return The new TokenizerStream instance; the caller should close it.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public TokenizerStream createStream() throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    return new TokenizerStream(createTokenizerStream(nativeHandle));
+  }
+
+  @Override
+  public void close() {
+    if (nativeHandle != 0) {
+      destroyTokenizer(nativeHandle);
+      nativeHandle = 0;
+    }
+  }
+
+  static {
+    try {
+      GenAI.init();
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to load onnxruntime-genai native libraries", e);
+    }
+  }
+
+  private native long createTokenizer(long modelHandle) throws GenAIException;
+
+  private native void destroyTokenizer(long tokenizerHandle);
+
+  private native long tokenizerEncode(long tokenizerHandle, String[] strings) throws GenAIException;
+
+  private native String tokenizerDecode(long tokenizerHandle, int[] sequence) throws GenAIException;
+
+  private native long createTokenizerStream(long tokenizerHandle) throws GenAIException;
+}
diff --git a/src/java/src/main/java/ai/onnxruntime/genai/TokenizerStream.java b/src/java/src/main/java/ai/onnxruntime/genai/TokenizerStream.java
new file mode 100644
index 000000000..a30c1724c
--- /dev/null
+++ b/src/java/src/main/java/ai/onnxruntime/genai/TokenizerStream.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License.
+ */
+package ai.onnxruntime.genai;
+
+/**
+ * A TokenizerStream is used to convert individual tokens when using Generator.generateNextToken.
+ */
+public class TokenizerStream implements AutoCloseable {
+
+  private long nativeHandle = 0;
+
+  TokenizerStream(long tokenizerStreamHandle) {
+    assert (tokenizerStreamHandle != 0); // internal usage should never pass an invalid handle
+    nativeHandle = tokenizerStreamHandle;
+  }
+
+  public String decode(int token) throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    return tokenizerStreamDecode(nativeHandle, token);
+  }
+
+  @Override
+  public void close() {
+    if (nativeHandle != 0) {
+      destroyTokenizerStream(nativeHandle);
+      nativeHandle = 0;
+    }
+  }
+
+  static {
+    try {
+      GenAI.init();
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to load onnxruntime-genai native libraries", e);
+    }
+  }
+
+  private native String tokenizerStreamDecode(long tokenizerStreamHandle, int token)
+      throws GenAIException;
+
+  private native void destroyTokenizerStream(long tokenizerStreamHandle);
+}
diff --git a/src/java/src/main/java/ai/onnxruntime/genai/package-info.java b/src/java/src/main/java/ai/onnxruntime/genai/package-info.java
new file mode 100644
index 000000000..9a43f976a
--- /dev/null
+++ b/src/java/src/main/java/ai/onnxruntime/genai/package-info.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License.
+ */
+
+/**
+ * A Java interface to the ONNX Runtime GenAI library.
+ *
+ * <p>There are two shared libraries required: <code>onnxruntime-genai</code> and <code>
+ * onnxruntime-genai-jni
+ * </code>. The loader is in {@link ai.onnxruntime.genai.GenAI} and the logic is in this order:
+ *
+ * <ol>
+ *   <li>The user may signal to skip loading of a shared library using a property in the form
+ *       <code>onnxruntime-genai.native.LIB_NAME.skip</code> with a value of <code>true</code>. This means
+ *       the user has decided to load the library by some other means.
+ *   <li>The user may specify an explicit location of all native library files using a property in
+ *       the form <code>onnxruntime-genai.native.path</code>. This uses {@link
+ *       java.lang.System#load}.
+ *   <li>The user may specify an explicit location of the shared library file using a property in
+ *       the form <code>onnxruntime-genai.native.LIB_NAME.path</code>. This uses {@link
+ *       java.lang.System#load}.
+ *   <li>The shared library is autodiscovered:
+ *       <ol>
+ *         <li>If the shared library is present in the classpath resources, load using {@link
+ *             java.lang.System#load} via a temporary file. Ideally, this should be the default use
+ *             case when adding JAR's/dependencies containing the shared libraries to your
+ *             classpath.
+ *         <li>If the shared library is not present in the classpath resources, then load using
+ *             {@link java.lang.System#loadLibrary}, which usually looks elsewhere on the filesystem
+ *             for the library. The semantics and behavior of that method are system/JVM dependent.
+ *             Typically, the <code>java.library.path</code> property is used to specify the
+ *             location of native libraries.
+ *       </ol>
+ * </ol>
+ *
+ * For troubleshooting, all shared library loading events are reported to Java logging at the level
+ * FINE.
+ */
+package ai.onnxruntime.genai;
diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp
new file mode 100644
index 000000000..6a3e953bd
--- /dev/null
+++ b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp
@@ -0,0 +1,84 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License.
+ */
+#include <jni.h>
+#include "ort_genai_c.h"
+#include "utils.h"
+
+#include <cstdint>
+
+using namespace Helpers;
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_ai_onnxruntime_genai_Generator_createGenerator(JNIEnv* env, jobject thiz, jlong model_handle,
+                                                    jlong generator_params_handle) {
+  const OgaModel* model = reinterpret_cast<const OgaModel*>(model_handle);
+  const OgaGeneratorParams* params = reinterpret_cast<const OgaGeneratorParams*>(generator_params_handle);
+  OgaGenerator* generator = nullptr;
+  if (ThrowIfError(env, OgaCreateGenerator(model, params, &generator))) {
+    return 0;
+  }
+
+  return reinterpret_cast<jlong>(generator);
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_ai_onnxruntime_genai_Generator_destroyGenerator(JNIEnv* env, jobject thiz, jlong native_handle) {
+  OgaDestroyGenerator(reinterpret_cast<OgaGenerator*>(native_handle));
+}
+
+extern "C" JNIEXPORT jboolean JNICALL
+Java_ai_onnxruntime_genai_Generator_isDone(JNIEnv* env, jobject thiz, jlong native_handle) {
+  return OgaGenerator_IsDone(reinterpret_cast<const OgaGenerator*>(native_handle));
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_ai_onnxruntime_genai_Generator_computeLogits(JNIEnv* env, jobject thiz, jlong native_handle) {
+  ThrowIfError(env, OgaGenerator_ComputeLogits(reinterpret_cast<OgaGenerator*>(native_handle)));
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_ai_onnxruntime_genai_Generator_generateNextTokenNative(JNIEnv* env, jobject thiz, jlong native_handle) {
+  ThrowIfError(env, OgaGenerator_GenerateNextToken(reinterpret_cast<OgaGenerator*>(native_handle)));
+}
+
+extern "C" JNIEXPORT jintArray JNICALL
+Java_ai_onnxruntime_genai_Generator_getSequenceNative(JNIEnv* env, jobject thiz, jlong generator, jlong index) {
+  const OgaGenerator* oga_generator = reinterpret_cast<const OgaGenerator*>(generator);
+
+  size_t num_tokens = OgaGenerator_GetSequenceCount(oga_generator, index);
+  const int32_t* tokens = OgaGenerator_GetSequenceData(oga_generator, index);
+
+  if (num_tokens == 0) {
+    ThrowException(env, "OgaGenerator_GetSequenceCount returned 0 tokens.");
+    return nullptr;
+  }
+
+  // as there's no 'destroy' function in GenAI C API for the tokens we assume the OgaGenerator owns the memory.
+  // copy the tokens so there's no potential for Java code to write to it (values should be treated as const)
+  // or attempt to access the memory after the OgaGenerator is destroyed.
+  jintArray java_int_array = env->NewIntArray(static_cast<jsize>(num_tokens));
+  if (java_int_array == nullptr) {
+    return nullptr;  // allocation failed; NewIntArray leaves an OutOfMemoryError pending
+  }
+  // jint is `long` on Windows and `int` on linux. 32-bit but requires reinterpret_cast.
+  env->SetIntArrayRegion(java_int_array, 0, static_cast<jsize>(num_tokens), reinterpret_cast<const jint*>(tokens));
+
+  return java_int_array;
+}
+
+extern "C" JNIEXPORT jint JNICALL
+Java_ai_onnxruntime_genai_Generator_getSequenceLastToken(JNIEnv* env, jobject thiz, jlong generator, jlong index) {
+  const OgaGenerator* oga_generator = reinterpret_cast<const OgaGenerator*>(generator);
+
+  size_t num_tokens = OgaGenerator_GetSequenceCount(oga_generator, index);
+  const int32_t* tokens = OgaGenerator_GetSequenceData(oga_generator, index);
+
+  if (num_tokens == 0) {
+    ThrowException(env, "OgaGenerator_GetSequenceCount returned 0 tokens.");
+    return -1;
+  }
+
+  return jint(tokens[num_tokens - 1]);
+}
\ No newline at end of file
diff --git a/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp
new file mode 100644
index 000000000..78a3adda0
--- /dev/null
+++ b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License.
+ */
+#include <jni.h>
+#include "ort_genai_c.h"
+#include "utils.h"
+
+using namespace Helpers;
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_ai_onnxruntime_genai_GeneratorParams_createGeneratorParams(JNIEnv* env, jobject thiz, jlong model_handle) {
+  const OgaModel* model = reinterpret_cast<const OgaModel*>(model_handle);
+  OgaGeneratorParams* generator_params = nullptr;
+  if (ThrowIfError(env, OgaCreateGeneratorParams(model, &generator_params))) {
+    return 0;
+  }
+
+  return reinterpret_cast<jlong>(generator_params);
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_ai_onnxruntime_genai_GeneratorParams_destroyGeneratorParams(JNIEnv* env, jobject thiz, jlong native_handle) {
+  OgaGeneratorParams* generator_params = reinterpret_cast<OgaGeneratorParams*>(native_handle);
+  OgaDestroyGeneratorParams(generator_params);
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_ai_onnxruntime_genai_GeneratorParams_setSearchOptionNumber(JNIEnv* env, jobject thiz, jlong native_handle,
+                                                                jstring option_name, jdouble value) {
+  OgaGeneratorParams* generator_params = reinterpret_cast<OgaGeneratorParams*>(native_handle);
+  CString name{env, option_name};
+
+  ThrowIfError(env, OgaGeneratorParamsSetSearchNumber(generator_params, name, value));
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_ai_onnxruntime_genai_GeneratorParams_setSearchOptionBool(JNIEnv* env, jobject thiz, jlong native_handle,
+                                                              jstring option_name, jboolean value) {
+  OgaGeneratorParams* generator_params = reinterpret_cast<OgaGeneratorParams*>(native_handle);
+  CString name{env, option_name};
+
+  ThrowIfError(env, OgaGeneratorParamsSetSearchBool(generator_params, name, value));
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_ai_onnxruntime_genai_GeneratorParams_setInputSequences(JNIEnv* env, jobject thiz, jlong native_handle,
+                                                            jlong sequences_handle) {
+  OgaGeneratorParams* generator_params = reinterpret_cast<OgaGeneratorParams*>(native_handle);
+  const OgaSequences* sequences = reinterpret_cast<const OgaSequences*>(sequences_handle);
+
+  ThrowIfError(env, OgaGeneratorParamsSetInputSequences(generator_params, sequences));
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_ai_onnxruntime_genai_GeneratorParams_setInputIDs(JNIEnv* env, jobject thiz, jlong native_handle,
+                                                      jobject token_ids, jint sequence_length, jint batch_size) {
+  OgaGeneratorParams* generator_params = reinterpret_cast<OgaGeneratorParams*>(native_handle);
+
+  auto num_tokens = sequence_length * batch_size;
+  const int32_t* tokens = reinterpret_cast<const int32_t*>(env->GetDirectBufferAddress(token_ids));
+
+  ThrowIfError(env, OgaGeneratorParamsSetInputIDs(generator_params, tokens, num_tokens, sequence_length, batch_size));
+}
\ No newline at end of file
diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Model.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Model.cpp
new file mode 100644
index 000000000..61e4a0828
--- /dev/null
+++ b/src/java/src/main/native/ai_onnxruntime_genai_Model.cpp
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License.
+ */
+#include <jni.h>
+#include "ort_genai_c.h"
+#include "utils.h"
+
+using namespace Helpers;
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_ai_onnxruntime_genai_Model_createModel(JNIEnv* env, jobject thiz, jstring model_path) {
+  CString path{env, model_path};
+
+  OgaModel* model = nullptr;
+  if (ThrowIfError(env, OgaCreateModel(path, &model))) {
+    return 0;
+  }
+
+  return reinterpret_cast<jlong>(model);
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_ai_onnxruntime_genai_Model_destroyModel(JNIEnv* env, jobject thiz, jlong model_handle) {
+  OgaModel* model = reinterpret_cast<OgaModel*>(model_handle);
+  OgaDestroyModel(model);
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_ai_onnxruntime_genai_Model_generate(JNIEnv* env, jobject thiz, jlong model_handle,
+                                         jlong generator_params_handle) {
+  const OgaModel* model = reinterpret_cast<const OgaModel*>(model_handle);
+  const OgaGeneratorParams* params = reinterpret_cast<const OgaGeneratorParams*>(generator_params_handle);
+  OgaSequences* sequences = nullptr;
+  if (ThrowIfError(env, OgaGenerate(model, params, &sequences))) {
+    return 0;
+  }
+
+  return reinterpret_cast<jlong>(sequences);
+}
\ No newline at end of file
diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Sequences.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Sequences.cpp
new file mode 100644
index 000000000..ac57f2337
--- /dev/null
+++ b/src/java/src/main/native/ai_onnxruntime_genai_Sequences.cpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License.
+ */
+#include <jni.h>
+#include "ort_genai_c.h"
+#include "utils.h"
+
+using namespace Helpers;
+
+extern "C" JNIEXPORT void JNICALL
+Java_ai_onnxruntime_genai_Sequences_destroySequences(JNIEnv* env, jobject thiz, jlong sequences_handle) {
+  OgaSequences* sequences = reinterpret_cast<OgaSequences*>(sequences_handle);
+  OgaDestroySequences(sequences);
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_ai_onnxruntime_genai_Sequences_getSequencesCount(JNIEnv* env, jobject thiz, jlong sequences_handle) {
+  const OgaSequences* sequences = reinterpret_cast<const OgaSequences*>(sequences_handle);
+  size_t num_sequences = OgaSequencesCount(sequences);
+  return static_cast<jlong>(num_sequences);
+}
+
+extern "C" JNIEXPORT jintArray JNICALL
+Java_ai_onnxruntime_genai_Sequences_getSequenceNative(JNIEnv* env, jobject thiz, jlong sequences_handle,
+                                                      jlong sequence_index) {
+  const OgaSequences* sequences = reinterpret_cast<const OgaSequences*>(sequences_handle);
+
+  size_t num_tokens = OgaSequencesGetSequenceCount(sequences, (size_t)sequence_index);
+  const int32_t* tokens = OgaSequencesGetSequenceData(sequences, (size_t)sequence_index);
+
+  // as there's no 'destroy' function in GenAI C API for the tokens we assume OgaSequences owns the memory.
+  // copy the tokens so there's no potential for Java code to write to it (values should be treated as const),
+  // or attempt to access the memory after the OgaSequences is destroyed.
+  // note: jint is `long` on Windows and `int` on linux. both are 32-bit but require reinterpret_cast.
+  jintArray java_int_array = env->NewIntArray(static_cast<jsize>(num_tokens));
+  if (java_int_array == nullptr) {
+    return nullptr;  // allocation failed; NewIntArray leaves an OutOfMemoryError pending
+  }
+  env->SetIntArrayRegion(java_int_array, 0, static_cast<jsize>(num_tokens), reinterpret_cast<const jint*>(tokens));
+
+  return java_int_array;
+}
diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Tokenizer.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Tokenizer.cpp
new file mode 100644
index 000000000..92b56e7f5
--- /dev/null
+++ b/src/java/src/main/native/ai_onnxruntime_genai_Tokenizer.cpp
@@ -0,0 +1,84 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License.
+ */
+#include <jni.h>
+#include "ort_genai_c.h"
+#include "utils.h"
+
+using namespace Helpers;
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_ai_onnxruntime_genai_Tokenizer_createTokenizer(JNIEnv* env, jobject thiz, jlong model_handle) {
+  const OgaModel* model = reinterpret_cast<const OgaModel*>(model_handle);
+  OgaTokenizer* tokenizer = nullptr;
+
+  if (ThrowIfError(env, OgaCreateTokenizer(model, &tokenizer))) {
+    return 0;
+  }
+
+  return reinterpret_cast<jlong>(tokenizer);
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_ai_onnxruntime_genai_Tokenizer_destroyTokenizer(JNIEnv* env, jobject thiz, jlong tokenizer_handle) {
+  OgaTokenizer* tokenizer = reinterpret_cast<OgaTokenizer*>(tokenizer_handle);
+  OgaDestroyTokenizer(tokenizer);
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_ai_onnxruntime_genai_Tokenizer_tokenizerEncode(JNIEnv* env, jobject thiz, jlong tokenizer_handle,
+                                                    jobjectArray strings) {
+  const OgaTokenizer* tokenizer = reinterpret_cast<const OgaTokenizer*>(tokenizer_handle);
+  auto num_strings = env->GetArrayLength(strings);
+
+  OgaSequences* sequences = nullptr;
+  if (ThrowIfError(env, OgaCreateSequences(&sequences))) {
+    return 0;
+  }
+
+  for (int i = 0; i < num_strings; i++) {
+    jstring string = static_cast<jstring>(env->GetObjectArrayElement(strings, i));
+    CString c_string{env, string};
+    if (ThrowIfError(env, OgaTokenizerEncode(tokenizer, c_string, sequences))) {
+      OgaDestroySequences(sequences);
+      return 0;
+    }
+  }
+
+  return reinterpret_cast<jlong>(sequences);
+}
+
+extern "C" JNIEXPORT jstring JNICALL
+Java_ai_onnxruntime_genai_Tokenizer_tokenizerDecode(JNIEnv* env, jobject thiz, jlong tokenizer_handle,
+                                                    jintArray sequence) {
+  const OgaTokenizer* tokenizer = reinterpret_cast<const OgaTokenizer*>(tokenizer_handle);
+  auto num_tokens = env->GetArrayLength(sequence);
+  jint* jtokens = env->GetIntArrayElements(sequence, nullptr);
+  const int32_t* tokens = reinterpret_cast<const int32_t*>(jtokens);  // convert between 32-bit types
+  const char* decoded_text = nullptr;
+
+  bool error = ThrowIfError(env, OgaTokenizerDecode(tokenizer, tokens, num_tokens, &decoded_text));
+  env->ReleaseIntArrayElements(sequence, jtokens, JNI_ABORT);
+
+  if (error) {
+    return nullptr;
+  }
+
+  jstring result = env->NewStringUTF(decoded_text);
+  OgaDestroyString(decoded_text);
+
+  return result;
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_ai_onnxruntime_genai_Tokenizer_createTokenizerStream(JNIEnv* env, jobject thiz, jlong tokenizer_handle) {
+  const OgaTokenizer* tokenizer = reinterpret_cast<const OgaTokenizer*>(tokenizer_handle);
+  OgaTokenizerStream* tokenizer_stream = nullptr;
+
+  if (ThrowIfError(env, OgaCreateTokenizerStream(tokenizer, &tokenizer_stream))) {
+    return 0;
+  }
+
+  return reinterpret_cast<jlong>(tokenizer_stream);
+}
\ No newline at end of file
diff --git a/src/java/src/main/native/ai_onnxruntime_genai_TokenizerStream.cpp b/src/java/src/main/native/ai_onnxruntime_genai_TokenizerStream.cpp
new file mode 100644
index 000000000..8e725c807
--- /dev/null
+++ b/src/java/src/main/native/ai_onnxruntime_genai_TokenizerStream.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License.
+ */ +#include +#include "ort_genai_c.h" +#include "utils.h" + +#include + +using namespace Helpers; + +extern "C" JNIEXPORT jstring JNICALL +Java_ai_onnxruntime_genai_TokenizerStream_tokenizerStreamDecode(JNIEnv* env, jobject thiz, + jlong tokenizer_stream_handle, jint token) { + OgaTokenizerStream* tokenizer_stream = reinterpret_cast(tokenizer_stream_handle); + const char* decoded_text = nullptr; + + // The const char* returned in decoded_text is the result of calling c_str on a std::string in the tokenizer cache. + // The std::string is owned by the tokenizer cache. + // Due to that, it is invalid to call `OgaDestroyString(decoded_text)`, and doing so will result in a crash. + if (ThrowIfError(env, OgaTokenizerStreamDecode(tokenizer_stream, token, &decoded_text))) { + return nullptr; + } + + jstring result = env->NewStringUTF(decoded_text); + return result; +} + +extern "C" JNIEXPORT void JNICALL +Java_ai_onnxruntime_genai_TokenizerStream_destroyTokenizerStream(JNIEnv* env, jobject thiz, + jlong tokenizer_stream_handle) { + OgaTokenizerStream* tokenizer_stream = reinterpret_cast(tokenizer_stream_handle); + OgaDestroyTokenizerStream(tokenizer_stream); +} \ No newline at end of file diff --git a/src/java/src/main/native/utils.cpp b/src/java/src/main/native/utils.cpp new file mode 100644 index 000000000..d272b449d --- /dev/null +++ b/src/java/src/main/native/utils.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ +#include +#include "utils.h" + +jint JNI_OnLoad(JavaVM* vm, void* reserved) { + // To silence unused-parameter error. + // This function must exist according to the JNI spec, but the arguments aren't necessary for the library + // to request a specific version. + (void)vm; + (void)reserved; + // Requesting 1.6 to support Android. Will need to be bumped to a later version to call interface default methods + // from native code, or to access other new Java features. 
+ return JNI_VERSION_1_6; +} + +namespace { +void ThrowExceptionImpl(JNIEnv* env, const char* error_message) { + static const char* className = "ai/onnxruntime/genai/GenAIException"; + env->ThrowNew(env->FindClass(className), error_message); +} +} // namespace + +namespace Helpers { +void ThrowException(JNIEnv* env, const char* message) { + ThrowExceptionImpl(env, message); +} + +bool ThrowIfError(JNIEnv* env, OgaResult* result) { + bool error = result != nullptr; + + if (error) { + ThrowExceptionImpl(env, OgaResultGetError(result)); + OgaDestroyResult(result); + } + + return error; +} +} // namespace Helpers \ No newline at end of file diff --git a/src/java/src/main/native/utils.h b/src/java/src/main/native/utils.h new file mode 100644 index 000000000..b56fdbafc --- /dev/null +++ b/src/java/src/main/native/utils.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ +#pragma once + +#include +#include +#include "ort_genai_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +jint JNI_OnLoad(JavaVM* vm, void* reserved); + +#ifdef __cplusplus +} +#endif + +namespace Helpers { +void ThrowException(JNIEnv* env, const char* message); + +/// @brief Throw a GenAIException if the result is an error. +/// @param env JNI environment +/// @param result Result from GenAI C API call +/// @return True if there was an error. JNI code should generally return immediately if this is true. 
+bool ThrowIfError(JNIEnv* env, OgaResult* result); + +// RAII wrapper: obtains the modified UTF-8 chars of a jstring via GetStringUTFChars and releases them on destruction. +// Non-copyable, as a copy would call ReleaseStringUTFChars a second time on the same buffer. +struct CString { + CString(JNIEnv* env, jstring str) + : cstr{env->GetStringUTFChars(str, /* isCopy */ nullptr)}, env_{env}, str_{str} { + } + + CString(const CString&) = delete; + CString& operator=(const CString&) = delete; + + // nullptr if GetStringUTFChars failed + const char* cstr; + + operator const char*() const { return cstr; } + + ~CString() { + // only release a buffer that was actually obtained + if (cstr != nullptr) { + env_->ReleaseStringUTFChars(str_, cstr); + } + } + + private: + JNIEnv* env_; + jstring str_; +}; +} // namespace Helpers diff --git a/src/java/src/test/android/.gitignore b/src/java/src/test/android/.gitignore new file mode 100644 index 000000000..e099cb50e --- /dev/null +++ b/src/java/src/test/android/.gitignore @@ -0,0 +1,90 @@ +# Built application files +*.apk +*.aar +*.ap_ +*.aab + +# Files for the ART/Dalvik VM +*.dex + +# Java class files +*.class + +# Generated files +bin/ +gen/ +out/ +# Uncomment the following line in case you need and you don't have the release build type files in your app +# release/ + +# Gradle files +.gradle/ +build/ + +# Local configuration file (sdk path, etc) +local.properties + +# Proguard folder generated by Eclipse +proguard/ + +# Log Files +*.log + +# Android Studio Navigation editor temp files +.navigation/ + +# Android Studio captures folder +captures/ + +# IntelliJ +*.iml +.idea/workspace.xml +.idea/tasks.xml +.idea/gradle.xml +.idea/assetWizardSettings.xml +.idea/dictionaries +.idea/libraries +# Android Studio 3 in .gitignore file. +.idea/caches +.idea/modules.xml +# Comment next line if keeping position of elements in Navigation Editor is relevant for you +.idea/navEditor.xml + +# Keystore files +# Uncomment the following lines if you do not want to check your keystore files in. +#*.jks +#*.keystore + +# External native build folder generated in Android Studio 2.2 and later +.externalNativeBuild +.cxx/ + +# Google Services (e.g.
APIs or Firebase) +# google-services.json + +# Freeline +freeline.py +freeline/ +freeline_project_description.json + +# fastlane +fastlane/report.xml +fastlane/Preview.html +fastlane/screenshots +fastlane/test_output +fastlane/readme.md + +# Version control +vcs.xml + +# lint +lint/intermediates/ +lint/generated/ +lint/outputs/ +lint/tmp/ +# lint/reports/ +.DS_Store +.idea/* +gradlew +gradlew.bat +gradle \ No newline at end of file diff --git a/src/java/src/test/android/README.md b/src/java/src/test/android/README.md new file mode 100644 index 000000000..b03da44f4 --- /dev/null +++ b/src/java/src/test/android/README.md @@ -0,0 +1,37 @@ +# Android Test Application for ORT-Mobile + +This directory contains a simple Android application for testing the ONNX Runtime GenAI AAR package. + +### Test Android Application Overview + +This Android application is mainly intended for testing: + +- Model used: TBD - need the smallest model with which GenAI produces some output +- Main test file: An Android instrumentation test under `app/src/androidTest/java/ai/onnxruntime_genai/example/javavalidator/SimpleTest.kt` +- The main dependency of this application is the `onnxruntime-genai` AAR package under `app/libs`. +- The MainActivity of this application is set to be empty. + +### Requirements + +- JDK version 11 or later is required. +- The [Gradle](https://gradle.org/) build system is required for building the APKs used to run [android instrumentation tests](https://source.android.com/compatibility/tests/development/instrumentation). Version 7.5 or newer is required. + The Gradle wrapper at `java/gradlew[.bat]` may be used. + +### Building + +Build for Android with the additional `--build_java` and `--android_run_emulator` options. + +e.g.
+`./build --android --android_ndk D:\Android\ndk\26.2.11394342\ --android_abi x86_64 --ort_home 'path to unzipped onnxruntime-android.aar from https://mvnrepository.com/artifact/com.microsoft.onnxruntime/onnxruntime-android/' --build_java --android_run_emulator` + +Please note that you must set the `--android_abi` value to match the local system architecture, as the Android instrumentation test is run on an Android emulator on the local system. + +#### Build Output + +The build will generate two APKs, which are required to run the test application, in `$YOUR_BUILD_DIR/java/androidtest/android/app/build/outputs/apk`: + +* `androidtest/debug/app-debug-androidtest.apk` +* `debug/app-debug.apk` + +**TODO**: Update emulator name if it is not `ort_android` once we finish adding the `android_run_emulator` logic to build.py. +After running the build script, the two APKs will be installed on the `ort_android` emulator and the test application will be run automatically in an adb shell. diff --git a/src/java/src/test/android/app/build.gradle b/src/java/src/test/android/app/build.gradle new file mode 100644 index 000000000..ed2379da0 --- /dev/null +++ b/src/java/src/test/android/app/build.gradle @@ -0,0 +1,51 @@ +plugins { + id 'com.android.application' + id 'kotlin-android' +} + +def minSdkVer = System.properties.get("minSdkVer")?:24 + +android { + compileSdkVersion 32 + + defaultConfig { + applicationId "ai.onnxruntime.genai.example.javavalidator" + minSdkVersion minSdkVer + targetSdkVersion 32 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } + kotlinOptions { + jvmTarget = '1.8' + } + namespace 'ai.onnxruntime.genai.example.javavalidator' +} +
+dependencies { + implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlin_version" + implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version" + implementation 'androidx.core:core-ktx:1.3.2' + implementation 'androidx.appcompat:appcompat:1.2.0' + implementation 'com.google.android.material:material:1.3.0' + implementation 'androidx.constraintlayout:constraintlayout:2.0.4' + implementation(name: "onnxruntime-genai", ext: "aar") + + testImplementation 'junit:junit:4.+' + androidTestImplementation 'androidx.test.ext:junit:1.1.3' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0' + androidTestImplementation 'androidx.test:runner:1.4.0' + androidTestImplementation 'androidx.test:rules:1.4.0' +} diff --git a/src/java/src/test/android/app/proguard-rules.pro b/src/java/src/test/android/app/proguard-rules.pro new file mode 100644 index 000000000..481bb4348 --- /dev/null +++ b/src/java/src/test/android/app/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
+#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/src/java/src/test/android/app/src/androidTest/java/ai/onnxruntime_genai/example/javavalidator/SimpleTest.kt b/src/java/src/test/android/app/src/androidTest/java/ai/onnxruntime_genai/example/javavalidator/SimpleTest.kt new file mode 100644 index 000000000..a6effb453 --- /dev/null +++ b/src/java/src/test/android/app/src/androidTest/java/ai/onnxruntime_genai/example/javavalidator/SimpleTest.kt @@ -0,0 +1,42 @@ +package ai.onnxruntime.genai.example.javavalidator + +import ai.onnxruntime.genai.* + +import android.os.Build; +import android.util.Log +import androidx.test.ext.junit.rules.ActivityScenarioRule +import androidx.test.ext.junit.runners.AndroidJUnit4 +import androidx.test.platform.app.InstrumentationRegistry +import com.microsoft.appcenter.espresso.Factory +import com.microsoft.appcenter.espresso.ReportHelper +import org.junit.* +import org.junit.runner.RunWith +import java.io.IOException +import java.util.* + +private const val TAG = "ORTGenAIAndroidTest" + +@RunWith(AndroidJUnit4::class) +class SimpleTest { + @get:Rule + val activityTestRule = ActivityScenarioRule(MainActivity::class.java) + + @get:Rule + var reportHelper: ReportHelper = Factory.getReportHelper() + + @Before + fun Start() { + reportHelper.label("Starting App") + Log.println(Log.INFO, TAG, "SystemABI=" + Build.SUPPORTED_ABIS[0]) + } + + @After + fun TearDown() { + reportHelper.label("Stopping App") + } + + @Test + fun runBasicTest() { + // TODO: Add tests + } +} diff --git a/src/java/src/test/android/app/src/main/AndroidManifest.xml b/src/java/src/test/android/app/src/main/AndroidManifest.xml new file mode 100644 index 000000000..2938b7e8b --- /dev/null +++ b/src/java/src/test/android/app/src/main/AndroidManifest.xml @@ -0,0 +1,20 @@ + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/java/src/test/android/app/src/main/java/ai/onnxruntime_genai/example/javavalidator/MainActivity.kt 
b/src/java/src/test/android/app/src/main/java/ai/onnxruntime_genai/example/javavalidator/MainActivity.kt new file mode 100644 index 000000000..1943ad9ea --- /dev/null +++ b/src/java/src/test/android/app/src/main/java/ai/onnxruntime_genai/example/javavalidator/MainActivity.kt @@ -0,0 +1,11 @@ +package ai.onnxruntime.genai.example.javavalidator + +import android.os.Bundle +import androidx.appcompat.app.AppCompatActivity + +/*Empty activity app mainly used for testing*/ +class MainActivity : AppCompatActivity() { + override fun onCreate(savedInstanceState: Bundle?) { + super.onCreate(savedInstanceState) + } +} \ No newline at end of file diff --git a/src/java/src/test/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml b/src/java/src/test/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 000000000..2b068d114 --- /dev/null +++ b/src/java/src/test/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,30 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/src/java/src/test/android/app/src/main/res/drawable/ic_launcher_background.xml b/src/java/src/test/android/app/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 000000000..07d5da9cb --- /dev/null +++ b/src/java/src/test/android/app/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/java/src/test/android/app/src/main/res/layout/activity_main.xml b/src/java/src/test/android/app/src/main/res/layout/activity_main.xml new file mode 100644 index 000000000..4fc244418 --- /dev/null +++ b/src/java/src/test/android/app/src/main/res/layout/activity_main.xml @@ -0,0 +1,18 @@ + + + + + + \ No newline at end of file diff --git a/src/java/src/test/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml b/src/java/src/test/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml new file mode 100644 index 
000000000..eca70cfe5 --- /dev/null +++ b/src/java/src/test/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/src/java/src/test/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml b/src/java/src/test/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml new file mode 100644 index 000000000..eca70cfe5 --- /dev/null +++ b/src/java/src/test/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/src/java/src/test/android/app/src/main/res/mipmap-hdpi/ic_launcher.png b/src/java/src/test/android/app/src/main/res/mipmap-hdpi/ic_launcher.png new file mode 100644 index 000000000..a571e6009 Binary files /dev/null and b/src/java/src/test/android/app/src/main/res/mipmap-hdpi/ic_launcher.png differ diff --git a/src/java/src/test/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.png b/src/java/src/test/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.png new file mode 100644 index 000000000..61da551c5 Binary files /dev/null and b/src/java/src/test/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.png differ diff --git a/src/java/src/test/android/app/src/main/res/mipmap-mdpi/ic_launcher.png b/src/java/src/test/android/app/src/main/res/mipmap-mdpi/ic_launcher.png new file mode 100644 index 000000000..c41dd2853 Binary files /dev/null and b/src/java/src/test/android/app/src/main/res/mipmap-mdpi/ic_launcher.png differ diff --git a/src/java/src/test/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.png b/src/java/src/test/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.png new file mode 100644 index 000000000..db5080a75 Binary files /dev/null and b/src/java/src/test/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.png differ diff --git a/src/java/src/test/android/app/src/main/res/mipmap-xhdpi/ic_launcher.png 
b/src/java/src/test/android/app/src/main/res/mipmap-xhdpi/ic_launcher.png new file mode 100644 index 000000000..6dba46dab Binary files /dev/null and b/src/java/src/test/android/app/src/main/res/mipmap-xhdpi/ic_launcher.png differ diff --git a/src/java/src/test/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png b/src/java/src/test/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png new file mode 100644 index 000000000..da31a871c Binary files /dev/null and b/src/java/src/test/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png differ diff --git a/src/java/src/test/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.png b/src/java/src/test/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.png new file mode 100644 index 000000000..15ac68172 Binary files /dev/null and b/src/java/src/test/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.png differ diff --git a/src/java/src/test/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png b/src/java/src/test/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png new file mode 100644 index 000000000..b216f2d31 Binary files /dev/null and b/src/java/src/test/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png differ diff --git a/src/java/src/test/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png b/src/java/src/test/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png new file mode 100644 index 000000000..f25a41974 Binary files /dev/null and b/src/java/src/test/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png differ diff --git a/src/java/src/test/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png b/src/java/src/test/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png new file mode 100644 index 000000000..e96783ccc Binary files /dev/null and b/src/java/src/test/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png differ diff --git a/src/java/src/test/android/app/src/main/res/values-night/themes.xml 
b/src/java/src/test/android/app/src/main/res/values-night/themes.xml new file mode 100644 index 000000000..02020fdea --- /dev/null +++ b/src/java/src/test/android/app/src/main/res/values-night/themes.xml @@ -0,0 +1,16 @@ + + + + \ No newline at end of file diff --git a/src/java/src/test/android/app/src/main/res/values/colors.xml b/src/java/src/test/android/app/src/main/res/values/colors.xml new file mode 100644 index 000000000..f8c6127d3 --- /dev/null +++ b/src/java/src/test/android/app/src/main/res/values/colors.xml @@ -0,0 +1,10 @@ + + + #FFBB86FC + #FF6200EE + #FF3700B3 + #FF03DAC5 + #FF018786 + #FF000000 + #FFFFFFFF + \ No newline at end of file diff --git a/src/java/src/test/android/app/src/main/res/values/strings.xml b/src/java/src/test/android/app/src/main/res/values/strings.xml new file mode 100644 index 000000000..c7668bbe1 --- /dev/null +++ b/src/java/src/test/android/app/src/main/res/values/strings.xml @@ -0,0 +1,3 @@ + + JavaValidator + \ No newline at end of file diff --git a/src/java/src/test/android/app/src/main/res/values/themes.xml b/src/java/src/test/android/app/src/main/res/values/themes.xml new file mode 100644 index 000000000..39a613d4d --- /dev/null +++ b/src/java/src/test/android/app/src/main/res/values/themes.xml @@ -0,0 +1,16 @@ + + + + \ No newline at end of file diff --git a/src/java/src/test/android/build.gradle b/src/java/src/test/android/build.gradle new file mode 100644 index 000000000..d7672e1e1 --- /dev/null +++ b/src/java/src/test/android/build.gradle @@ -0,0 +1,27 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. 
+buildscript { + ext.kotlin_version = '1.6.21' + + repositories { + google() + mavenCentral() + } + dependencies { + classpath 'com.android.tools.build:gradle:7.4.2' + classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version" + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files + } +} + +allprojects { + repositories { + google() + mavenCentral() + flatDir{dirs 'libs'} + } +} + +task clean(type: Delete) { + delete rootProject.buildDir +} diff --git a/src/java/src/test/android/gradle.properties b/src/java/src/test/android/gradle.properties new file mode 100644 index 000000000..aa69f30e6 --- /dev/null +++ b/src/java/src/test/android/gradle.properties @@ -0,0 +1,21 @@ +# Project-wide Gradle settings. +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx4096m -Dfile.encoding=UTF-8 +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. 
More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true +# AndroidX package structure to make it clearer which packages are bundled with the +# Android operating system, and which are packaged with your app"s APK +# https://developer.android.com/topic/libraries/support-library/androidx-rn +android.useAndroidX=true +# Automatically convert third-party libraries to use AndroidX +android.enableJetifier=true +# Kotlin code style for this project: "official" or "obsolete": +kotlin.code.style=official diff --git a/src/java/src/test/android/settings.gradle b/src/java/src/test/android/settings.gradle new file mode 100644 index 000000000..982409cdb --- /dev/null +++ b/src/java/src/test/android/settings.gradle @@ -0,0 +1,2 @@ +include ':app' +rootProject.name = "JavaValidator" \ No newline at end of file diff --git a/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java b/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java new file mode 100644 index 000000000..770ed037f --- /dev/null +++ b/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java @@ -0,0 +1,103 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime.genai; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.File; +import java.util.function.Consumer; +import java.util.logging.Logger; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIf; + +// Test the overall generation. +// Uses SimpleGenAI with phi-2 (if available) for text -> text generation. +// Uses the HF test model with pre-defined input tokens for token -> token generation +// +// This indirectly tests the majority of the bindings. Any gaps are covered in the class specific +// tests. 
+public class GenerationTest { + private static final Logger logger = Logger.getLogger(GenerationTest.class.getName()); + + // phi-2 can be used in full end-to-end testing but needs to be manually downloaded. + // it's also used this way in the C# unit tests. + // Returns the model directory, or null (after logging a warning) if it does not exist. + private static final String phi2ModelPath() { + String repoRoot = TestUtils.getRepoRoot(); + File f = new File(repoRoot + "examples/python/example-models/phi2-int4-cpu"); + + if (!f.exists()) { + logger.warning("phi2 model not found at: " + f.getPath()); + logger.warning( + "Please install as per https://github.com/microsoft/onnxruntime-genai/tree/rel-0.2.0/examples/csharp/HelloPhi2"); + return null; + } + + return f.getPath(); + } + + @SuppressWarnings("unused") // Used in EnabledIf + private static boolean havePhi2() { + return phi2ModelPath() != null; + } + + @Test + @EnabledIf("havePhi2") + public void testUsageNoListener() throws GenAIException { + SimpleGenAI generator = new SimpleGenAI(phi2ModelPath()); + GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?"); + + String result = generator.generate(params, null); + logger.info("Result: " + result); + assertTrue(result.contains("Answer: 42")); + } + + @Test + @EnabledIf("havePhi2") + public void testUsageWithListener() throws GenAIException { + SimpleGenAI generator = new SimpleGenAI(phi2ModelPath()); + GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?"); + Consumer listener = token -> logger.info("onTokenGenerate: " + token); + String result = generator.generate(params, listener); + + logger.info("Result: " + result); + assertTrue(result.contains("Answer: 42")); + } + + @Test + public void testWithInputIds() throws GenAIException { + // test using the HF model. input id values must be < 1000 so we use manually created input.
+ // Input/expected output copied from the C# unit tests + Model model = new Model(TestUtils.testModelPath()); + GeneratorParams params = new GeneratorParams(model); + int batchSize = 2; + int sequenceLength = 4; + int maxLength = 10; + int[] inputIDs = + new int[] { + 0, 0, 0, 52, + 0, 0, 195, 731 + }; + + params.setInput(inputIDs, sequenceLength, batchSize); + params.setSearchOption("max_length", maxLength); + + int[] expectedOutput = + new int[] { + 0, 0, 0, 52, 204, 204, 204, 204, 204, 204, + 0, 0, 195, 731, 731, 114, 114, 114, 114, 114 + }; + + Sequences output = model.generate(params); + // JUnit argument order is (expected, actual) so failure messages read correctly + assertEquals(batchSize, output.numSequences()); + + for (int i = 0; i < batchSize; i++) { + int[] outputIds = output.getSequence(i); + for (int j = 0; j < maxLength; j++) { + assertEquals(expectedOutput[i * maxLength + j], outputIds[j]); + } + } + } +} diff --git a/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java b/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java new file mode 100644 index 000000000..5fc2413c3 --- /dev/null +++ b/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java @@ -0,0 +1,25 @@ +package ai.onnxruntime.genai; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.junit.jupiter.api.Test; + +// NOTE: Typical usage is covered in GenerationTest.java so we are just filling test gaps here.
+public class GeneratorParamsTest { + @Test + public void testValidSearchOption() throws GenAIException { + // test that setting valid search options (one boolean, one numeric) does not throw + SimpleGenAI generator = new SimpleGenAI(TestUtils.testModelPath()); + GeneratorParams params = generator.createGeneratorParams("Ignoed"); + params.setSearchOption("early_stopping", true); // boolean + params.setSearchOption("max_length", 20); // number + } + + @Test + public void testInvalidSearchOption() throws GenAIException { + // test setting an invalid search option throws a GenAIException + SimpleGenAI generator = new SimpleGenAI(TestUtils.testModelPath()); + GeneratorParams params = generator.createGeneratorParams("This is a testing prompt"); + assertThrows(GenAIException.class, () -> params.setSearchOption("invalid", true)); + } +} diff --git a/src/java/src/test/java/ai/onnxruntime/genai/TestUtils.java b/src/java/src/test/java/ai/onnxruntime/genai/TestUtils.java new file mode 100644 index 000000000..4d2520bef --- /dev/null +++ b/src/java/src/test/java/ai/onnxruntime/genai/TestUtils.java @@ -0,0 +1,42 @@ +package ai.onnxruntime.genai; + +import java.io.File; +import java.net.URL; +import java.util.logging.Logger; + +public class TestUtils { + private static final Logger logger = Logger.getLogger(TestUtils.class.getName()); + + public static final String testModelPath() { + // look up the test model directory as a classpath resource; logs a warning and returns null if it is missing + URL url = TestUtils.class.getResource("/hf-internal-testing/tiny-random-gpt2-fp32"); + if (url == null) { + logger.warning("Model not found at /hf-internal-testing/tiny-random-gpt2-fp32"); + return null; + } + + File f = new File(url.getFile()); + return f.getPath(); + } + + public static final String getRepoRoot() { + String classDirFileUrl = SimpleGenAI.class.getResource("").getFile(); + String repoRoot = classDirFileUrl.substring(0, classDirFileUrl.lastIndexOf("src/java/build")); + return repoRoot; + } + + public static final boolean
setLocalNativeLibraryPath() { + // set to /src/java/native-jni/ai/onnxruntime/genai/native/win-x64, + // adjusting for your build output location and platform as needed + String nativeJniBuildOutput = + "build/Windows/Debug/src/java/native-jni/ai/onnxruntime/genai/native/win-x64"; + File fullPath = new File(getRepoRoot() + nativeJniBuildOutput); + if (!fullPath.exists()) { + logger.warning("Local native-jni build output not found at: " + fullPath.getPath()); + return false; + } + + System.setProperty("onnxruntime-genai.native.path", fullPath.getPath()); + return true; + } +} diff --git a/src/java/src/test/java/ai/onnxruntime/genai/TokenizerTest.java b/src/java/src/test/java/ai/onnxruntime/genai/TokenizerTest.java new file mode 100644 index 000000000..7a099f323 --- /dev/null +++ b/src/java/src/test/java/ai/onnxruntime/genai/TokenizerTest.java @@ -0,0 +1,23 @@ +package ai.onnxruntime.genai; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +// NOTE: Typical usage is covered in GenerationTest.java so we are just filling test gaps here. +public class TokenizerTest { + @Test + public void testBatchEncodeDecode() throws GenAIException { + try (Model model = new Model(TestUtils.testModelPath()); + Tokenizer tokenizer = new Tokenizer(model)) { + String[] inputs = new String[] {"This is a test", "This is another test"}; + Sequences encoded = tokenizer.encodeBatch(inputs); + String[] decoded = tokenizer.decodeBatch(encoded); + + assertEquals(inputs.length, decoded.length); + for (int i = 0; i < inputs.length; i++) { + assert inputs[i].equals(decoded[i]); + } + } + } +} diff --git a/src/java/windows-unittests.cmake b/src/java/windows-unittests.cmake new file mode 100644 index 000000000..a9aa84c78 --- /dev/null +++ b/src/java/windows-unittests.cmake @@ -0,0 +1,19 @@ +# This is a windows only file so we can run gradle tests via ctest + +# Are these needed? 
+file(TO_NATIVE_PATH "${GRADLE_EXECUTABLE}" GRADLE_NATIVE_PATH)
+file(TO_NATIVE_PATH "${BIN_DIR}" BINDIR_NATIVE_PATH)
+file(TO_NATIVE_PATH "${JAVA_PACKAGE_LIB_DIR}" PACKAGE_LIB_DIR_NATIVE_PATH)
+
+# Run the Gradle 'cmakeCheck' task through cmd.exe with the Gradle daemon disabled.
+# HAD_ERROR receives the exit status, or an error string if the process could not be run.
+execute_process(
+  COMMAND cmd /C "${GRADLE_NATIVE_PATH}" --console=plain cmakeCheck
+    "-DcmakeBuildDir=${BINDIR_NATIVE_PATH}"
+    "-DnativeLibDir=${PACKAGE_LIB_DIR_NATIVE_PATH}"
+    -Dorg.gradle.daemon=false
+  WORKING_DIRECTORY "${JAVA_SRC_ROOT}"
+  RESULT_VARIABLE HAD_ERROR)
+if(HAD_ERROR)
+  message(FATAL_ERROR "Java unit tests failed (gradle result: ${HAD_ERROR})")
+endif()
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 01d1887e0..c9054026f 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -1,4 +1,4 @@
-include(${CMAKE_SOURCE_DIR}/cmake/cxx_standard.cmake)
+include(${REPO_ROOT}/cmake/cxx_standard.cmake)
 
 file(GLOB python_srcs CONFIGURE_DEPENDS
   "${CMAKE_CURRENT_SOURCE_DIR}/*.h"
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 5d446cc40..1d31e0f61 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -1,5 +1,3 @@
-enable_testing()
-
 include(${CMAKE_SOURCE_DIR}/cmake/cxx_standard.cmake)
 
 set(TESTS_ROOT ${CMAKE_CURRENT_SOURCE_DIR} PARENT_SCOPE)