[pytorch] Add oneHot operator #1014

Merged (1 commit) on Jun 11, 2021
51 changes: 51 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
@@ -4592,6 +4592,57 @@ default NDArray oneHot(int depth) {
return oneHot(depth, 1f, 0f, DataType.FLOAT32);
}

/**
* Returns a one-hot {@code NDArray}.
*
* <ul>
* <li>The locations represented by indices take value 1, while all other locations take value
* 0.
* <li>If the input {@code NDArray} is rank N, the output will have rank N+1. The new axis is
* appended at the end.
* <li>If {@code NDArray} is a scalar, the output shape will be a vector of length depth.
* <li>If {@code NDArray} is a vector of length features, the output shape will be features x
* depth.
* <li>If {@code NDArray} is a matrix with shape [batch, features], the output shape will be
* batch x features x depth.
* </ul>
*
* <p>Examples
*
* <pre>
* jshell&gt; NDArray array = manager.create(new int[] {1, 0, 2, 0});
* jshell&gt; array.oneHot(3);
* ND: (4, 3) cpu() float32
* [[0., 1., 0.],
* [1., 0., 0.],
* [0., 0., 1.],
* [1., 0., 0.],
* ]
* jshell&gt; NDArray array = manager.create(new int[][] {{1, 0}, {1, 0}, {2, 0}});
* jshell&gt; array.oneHot(3);
* ND: (3, 2, 3) cpu() float32
* [[[0., 1., 0.],
* [1., 0., 0.],
* ],
* [[0., 1., 0.],
* [1., 0., 0.],
* ],
* [[0., 0., 1.],
* [1., 0., 0.],
* ],
* ]
* </pre>
*
* @param depth depth of the one-hot dimension
* @param dataType dataType of the output
* @return one-hot encoding of this {@code NDArray}
* @see <a
*     href="https://d2l.djl.ai/chapter_linear-networks/softmax-regression.html#classification-problems">Classification-problems</a>
*/
default NDArray oneHot(int depth, DataType dataType) {
    return oneHot(depth, 1f, 0f, dataType);
}

/**
* Returns a one-hot {@code NDArray}.
*
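For context beyond the diff, here is a minimal usage sketch of the new overload against the public DJL API. The input values, the INT32 target type, and the OneHotDemo class name are illustrative, not part of this PR:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;

public class OneHotDemo {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray labels = manager.create(new int[] {1, 0, 2, 0});
            // New in this PR: pick the output data type directly.
            NDArray encoded = labels.oneHot(3, DataType.INT32);
            System.out.println(encoded); // shape (4, 3), dtype int32
        }
    }
}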
ai/djl/pytorch/engine/PtNDArray.java
@@ -1408,6 +1408,18 @@ public NDArray norm(int order, int[] axes, boolean keepDims) {
return JniUtils.norm(this, order, axes, keepDims);
}

/** {@inheritDoc} */
@Override
public NDArray oneHot(int depth) {
return JniUtils.oneHot(this, depth, DataType.FLOAT32);
}

/** {@inheritDoc} */
@Override
public NDArray oneHot(int depth, DataType dataType) {
return JniUtils.oneHot(this, depth, dataType);
}

/** {@inheritDoc} */
@Override
public NDArray oneHot(int depth, float onValue, float offValue, DataType dataType) {
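Both overrides delegate to JniUtils; the four-argument overload visible at the end of the hunk lets callers pick custom on/off values. A hedged sketch of such a call, assuming an NDManager named manager is in scope (the values are illustrative):

NDArray labels = manager.create(new int[] {1, 0, 2});
// 8 at the index positions, -8 everywhere else; shape (3, 3)
NDArray scaled = labels.oneHot(3, 8f, -8f, DataType.FLOAT32);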
ai/djl/pytorch/jni/JniUtils.java
@@ -720,6 +720,14 @@ public static PtNDArray cumSum(PtNDArray ndArray, long dim) {
ndArray.getManager(), PyTorchLibrary.LIB.torchCumSum(ndArray.getHandle(), dim));
}

public static PtNDArray oneHot(PtNDArray ndArray, int depth, DataType dataType) {
    // torch one_hot only accepts int64 indices: cast the input first,
    // then cast the one-hot result back to the requested data type.
    return new PtNDArray(
            ndArray.getManager(),
            PyTorchLibrary.LIB.torchNNOneHot(
                ndArray.toType(DataType.INT64, false).getHandle(), depth))
        .toType(dataType, false);
}

public static NDList split(PtNDArray ndArray, long size, long axis) {
long[] ndPtrs = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), size, axis);
NDList list = new NDList();
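Note the round trip in JniUtils.oneHot: PyTorch's one_hot only accepts int64 indices, so the input is cast before the native call and the result is cast to the requested type afterwards. One practical consequence, sketched below with illustrative values (again assuming a manager in scope), is that integral labels stored in a float array work without manual conversion:

NDArray floatLabels = manager.create(new float[] {1f, 0f, 2f});
// Internally: toType(INT64) -> torch one_hot -> toType(FLOAT32)
NDArray encoded = floatLabels.oneHot(3, DataType.FLOAT32);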
ai/djl/pytorch/jni/PyTorchLibrary.java
@@ -448,6 +448,8 @@ native long torchNNMaxPool(
native long torchNNLpPool(
long inputHandle, double normType, long[] kernelSize, long[] stride, boolean ceilMode);

native long torchNNOneHot(long inputHandle, int depth);

native boolean torchRequiresGrad(long inputHandle);

native String torchGradFnName(long inputHandle);
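By standard JNI name mangling, this declaration binds to the exported C++ symbol Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNOneHot, added in the native sources below; no extra registration code is needed beyond the matching name.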
PyTorch JNI native source (module save/write)
@@ -99,8 +99,8 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleWrite(
API_BEGIN()
auto* module_ptr = reinterpret_cast<torch::jit::script::Module*>(module_handle);
#if defined(__ANDROID__)
  env->ThrowNew(ENGINE_EXCEPTION_CLASS, "This kind of mode is not supported on Android");
  return;
#endif
std::ostringstream stream;
module_ptr->save(stream);
@@ -207,8 +207,8 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleSave(
API_BEGIN()
auto* module_ptr = reinterpret_cast<torch::jit::script::Module*>(jhandle);
#if defined(__ANDROID__)
  env->ThrowNew(ENGINE_EXCEPTION_CLASS, "This kind of mode is not supported on Android");
  return;
#endif
module_ptr->save(djl::utils::jni::GetStringFromJString(env, jpath));
API_END()
PyTorch JNI native source (torch.nn functional ops)
@@ -37,6 +37,15 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchLogSoftmax(
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNOneHot(
    JNIEnv* env, jobject jthis, jlong jhandle, jint jdepth) {
  API_BEGIN()
  const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
  // one_hot requires an int64 input; JniUtils.oneHot converts before calling in.
  // The result is heap-allocated so its address can be returned to Java as a handle.
  const auto* result_ptr = new torch::Tensor(torch::nn::functional::one_hot(*tensor_ptr, jdepth));
  return reinterpret_cast<uintptr_t>(result_ptr);
  API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNInterpolate(
JNIEnv* env, jobject jthis, jlong jhandle, jlongArray jsize, jint jmode, jboolean jalign_corners) {
API_BEGIN()