diff --git a/api/src/main/native/djl/utils.h b/api/src/main/native/djl/utils.h index edce9e54a9c7..8aa9b857c192 100644 --- a/api/src/main/native/djl/utils.h +++ b/api/src/main/native/djl/utils.h @@ -29,9 +29,20 @@ inline std::string GetStringFromJString(JNIEnv* env, jstring jstr) { if (jstr == nullptr) { return std::string(); } - const char* c_str = env->GetStringUTFChars(jstr, JNI_FALSE); - std::string str = std::string(c_str); - env->ReleaseStringUTFChars(jstr, c_str); + + const jclass string_class = env->GetObjectClass(jstr); + const jmethodID getbytes_method = env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B"); + + const jstring charset = env->NewStringUTF("UTF-8"); + const jbyteArray jbytes = (jbyteArray) env->CallObjectMethod(jstr, getbytes_method, charset); + env->DeleteLocalRef(charset); + + const jsize length = env->GetArrayLength(jbytes); + jbyte* c_str = env->GetByteArrayElements(jbytes, NULL); + std::string str = std::string(reinterpret_cast(c_str), length); + + env->ReleaseByteArrayElements(jbytes, c_str, RELEASE_MODE); + env->DeleteLocalRef(jbytes); return str; } @@ -100,9 +111,22 @@ inline std::vector GetVecFromJStringArray(JNIEnv* env, jobjectArray // String[] inline jobjectArray GetStringArrayFromVec(JNIEnv* env, const std::vector &vec) { jobjectArray array = env->NewObjectArray(vec.size(), env->FindClass("Ljava/lang/String;"), nullptr); + + const jclass string_class = env->FindClass("java/lang/String"); + const jmethodID ctor = env->GetMethodID(string_class, "", "([BLjava/lang/String;)V"); + const jstring charset = env->NewStringUTF("UTF-8"); + for (int i = 0; i < vec.size(); ++i) { - env->SetObjectArrayElement(array, i, env->NewStringUTF(vec[i].c_str())); + const char* c_str = vec[i].c_str(); + int len = vec[i].length(); + auto jbytes = env->NewByteArray(len); + env->SetByteArrayRegion(jbytes, 0, len, reinterpret_cast(c_str)); + jobject jstr = env->NewObject(string_class, ctor, jbytes, charset); + env->DeleteLocalRef(jbytes); + env->SetObjectArrayElement(array, i, jstr); } + + env->DeleteLocalRef(charset); return array; } diff --git a/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java b/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java index 803a22b0146f..12dcdc13dc92 100644 --- a/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java +++ b/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java @@ -53,6 +53,20 @@ public void testTokenize() throws IOException { } } + @Test + public void testUtf16Tokenize() throws IOException { + if (System.getProperty("os.name").startsWith("Win")) { + throw new SkipException("Skip windows test."); + } + Path modelPath = Paths.get("build/test/models/sententpiece_test_model.model"); + try (SpTokenizer tokenizer = new SpTokenizer(modelPath)) { + String original = "\uD83D\uDC4B\uD83D\uDC4B"; + List tokens = tokenizer.tokenize(original); + List expected = Arrays.asList("▁", "\uD83D\uDC4B\uD83D\uDC4B"); + Assert.assertEquals(tokens, expected); + } + } + @Test public void testEncodeDecode() throws IOException { if (System.getProperty("os.name").startsWith("Win")) {