diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index ebb3148b285..57cabbd2da5 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -4592,6 +4592,57 @@ default NDArray oneHot(int depth) { return oneHot(depth, 1f, 0f, DataType.FLOAT32); } + /** + * Returns a one-hot {@code NDArray}. + * + * + * + *

Examples + * + *

+     * jshell> NDArray array = manager.create(new int[] {1, 0, 2, 0});
+     * jshell> array.oneHot(3);
+     * ND: (4, 3) cpu() float32
+     * [[0., 1., 0.],
+     *  [1., 0., 0.],
+     *  [0., 0., 1.],
+     *  [1., 0., 0.],
+     * ]
+     * jshell> NDArray array = manager.create(new int[][] {{1, 0}, {1, 0}, {2, 0}});
+     * jshell> array.oneHot(3);
+     * ND: (3, 2, 3) cpu() float32
+     * [[[0., 1., 0.],
+     *   [1., 0., 0.],
+     *  ],
+     *  [[0., 1., 0.],
+     *   [1., 0., 0.],
+     *  ],
+     *  [[0., 0., 1.],
+     *   [1., 0., 0.],
+     *  ],
+     * ]
+     * 
+ * + * @param depth Depth of the one hot dimension. + * @param dataType dataType of the output. + * @return one-hot encoding of this {@code NDArray} + * @see Classification-problems + */ + default NDArray oneHot(int depth, DataType dataType) { + return oneHot(depth, 0f, 1f, dataType); + } + /** * Returns a one-hot {@code NDArray}. * diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 3e0ede826a5..246b9e57fee 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -1408,6 +1408,18 @@ public NDArray norm(int order, int[] axes, boolean keepDims) { return JniUtils.norm(this, order, axes, keepDims); } + /** {@inheritDoc} */ + @Override + public NDArray oneHot(int depth) { + return JniUtils.oneHot(this, depth, DataType.FLOAT32); + } + + /** {@inheritDoc} */ + @Override + public NDArray oneHot(int depth, DataType dataType) { + return JniUtils.oneHot(this, depth, dataType); + } + /** {@inheritDoc} */ @Override public NDArray oneHot(int depth, float onValue, float offValue, DataType dataType) { diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index c98f4e60d75..64a8965f736 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -720,6 +720,14 @@ public static PtNDArray cumSum(PtNDArray ndArray, long dim) { ndArray.getManager(), PyTorchLibrary.LIB.torchCumSum(ndArray.getHandle(), dim)); } + public static PtNDArray oneHot(PtNDArray ndArray, int depth, DataType dataType) { + return new PtNDArray( + ndArray.getManager(), + PyTorchLibrary.LIB.torchNNOneHot( + ndArray.toType(DataType.INT64, false).getHandle(), depth)) + .toType(dataType, false); + } + public static NDList split(PtNDArray ndArray, long size, long axis) { long[] ndPtrs = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), size, axis); NDList list = new NDList(); diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 2145ada3df9..0bb7dbafccf 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -448,6 +448,8 @@ native long torchNNMaxPool( native long torchNNLpPool( long inputHandle, double normType, long[] kernelSize, long[] stride, boolean ceilMode); + native long torchNNOneHot(long inputHandle, int depth); + native boolean torchRequiresGrad(long inputHandle); native String torchGradFnName(long inputHandle); diff --git a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc index 643bd20e5f8..66f2e171ba4 100644 --- a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc +++ b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc @@ -99,8 +99,8 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleWrite( API_BEGIN() auto* module_ptr = reinterpret_cast(module_handle); #if defined(__ANDROID__) - env->ThrowNew(ENGINE_EXCEPTION_CLASS, "This kind of mode is not supported on Android"); - return; + env->ThrowNew(ENGINE_EXCEPTION_CLASS, "This kind of mode is not supported on Android"); + return; #endif std::ostringstream stream; module_ptr->save(stream); @@ -207,8 +207,8 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleSave( API_BEGIN() auto* module_ptr = reinterpret_cast(jhandle); #if defined(__ANDROID__) - env->ThrowNew(ENGINE_EXCEPTION_CLASS, "This kind of mode is not supported on Android"); - return; + env->ThrowNew(ENGINE_EXCEPTION_CLASS, "This kind of mode is not supported on Android"); + return; #endif module_ptr->save(djl::utils::jni::GetStringFromJString(env, jpath)); API_END() diff --git a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc index ecf7bb2b98c..59afabcecf5 100644 --- a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc +++ b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc @@ -37,6 +37,15 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchLogSoftmax( API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNOneHot( + JNIEnv* env, jobject jthis, jlong jhandle, jint jdepth) { + API_BEGIN() + const auto* tensor_ptr = reinterpret_cast(jhandle); + const auto* result_ptr = new torch::Tensor(torch::nn::functional::one_hot(*tensor_ptr, jdepth)); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNInterpolate( JNIEnv* env, jobject jthis, jlong jhandle, jlongArray jsize, jint jmode, jboolean jalign_corners) { API_BEGIN()