From c9d5412f780bdb504cc6095ff27897b84d4e4c4e Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Thu, 12 May 2022 11:10:13 -0700 Subject: [PATCH] Support of take from pytorch (#1627) * take_dev * assertion_err * Keep test EngineAgnostic * testTake skipping windows is removed * combine get(NDArray(BOOLEAN) and NDArray(INT)) * add test * add_test * comb_get * bug * DataType check --- api/src/main/java/ai/djl/ndarray/NDArray.java | 32 +++++++++++++------ .../java/ai/djl/ndarray/NDArrayAdapter.java | 6 ++++ .../java/ai/djl/mxnet/engine/MxNDArray.java | 6 ++++ .../java/ai/djl/pytorch/engine/PtNDArray.java | 9 ++++++ .../java/ai/djl/pytorch/jni/JniUtils.java | 12 +++++++ .../ai/djl/pytorch/jni/PyTorchLibrary.java | 2 ++ ...i_djl_pytorch_jni_PyTorchLibrary_tensor.cc | 10 ++++++ .../ai/djl/tensorflow/engine/TfNDArray.java | 6 ++++ .../tests/ndarray/NDIndexTest.java | 21 ++++++++++-- 9 files changed, 92 insertions(+), 12 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index c35c7f09545..c7c6f96aa4e 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -511,6 +511,20 @@ default NDArray get(NDIndex index) { return getNDArrayInternal().getIndexer().get(this, index); } + /** + * Returns a partial {@code NDArray}. + * + * @param index the boolean or int {@code NDArray} that indicates what to get + * @return the partial {@code NDArray} + */ + default NDArray get(NDArray index) { + if (index.getDataType() == DataType.BOOLEAN) { + return get(new NDIndex().addBooleanIndex(index)); + } else { + return take(index); + } + } + /** * Returns a partial {@code NDArray}. * @@ -535,16 +549,6 @@ default NDArray get(long... indices) { return get(new NDIndex(indices)); } - /** - * Returns a partial {@code NDArray}. - * - * @param index the boolean {@code NDArray} that indicates what to get - * @return the partial {@code NDArray} - */ - default NDArray get(NDArray index) { - return get(new NDIndex().addBooleanIndex(index)); - } - /** * Returns a partial {@code NDArray} pointed by the indexed array. Given NDArray arr, NDArray * idx, and long axis, the output is out_{ijk} = arr_{idx_{ijk}, j, k} if axis=0 or arr_{i, @@ -556,6 +560,14 @@ default NDArray get(NDArray index) { */ NDArray gather(NDArray index, int axis); + /** + * Returns a partial {@code NDArray} pointed by the indexed array, according to linear indexing. + * + * @param index picks the elements of an NDArray to the same position as index + * @return the partial {@code NDArray} of the same shape as index + */ + NDArray take(NDArray index); + /** * Returns a scalar {@code NDArray} corresponding to a single element. * diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index 46f2634c351..c36762010a9 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -182,6 +182,12 @@ public NDArray gather(NDArray index, int axis) { throw new UnsupportedOperationException(UNSUPPORTED_MSG); } + /** {@inheritDoc} */ + @Override + public NDArray take(NDArray index) { + throw new UnsupportedOperationException(UNSUPPORTED_MSG); + } + /** {@inheritDoc} */ @Override public void set(Buffer data) { diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index f786b022930..66fe02d5d03 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -312,6 +312,12 @@ public NDArray gather(NDArray index, int axis) { throw new UnsupportedOperationException("Not implemented yet."); } + /** {@inheritDoc} */ + @Override + public NDArray take(NDArray index) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public void copyTo(NDArray ndArray) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index a1cacb83e93..a8944cd92c2 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -248,6 +248,15 @@ public NDArray gather(NDArray index, int axis) { return JniUtils.gather(this, (PtNDArray) index, axis); } + /** {@inheritDoc} */ + @Override + public NDArray take(NDArray index) { + if (!(index instanceof PtNDArray)) { + throw new IllegalArgumentException("Only PtNDArray is supported."); + } + return JniUtils.take(this, (PtNDArray) index); + } + /** {@inheritDoc} */ @Override public void copyTo(NDArray array) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index a11719612ae..5d0f21b619a 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -347,11 +347,23 @@ public static void set(PtNDArray self, ByteBuffer data) { } public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) { + if (index.getDataType() != DataType.INT64) { + index = index.toType(DataType.INT64, true); + } return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false)); } + public static PtNDArray take(PtNDArray ndArray, PtNDArray index) { + if (index.getDataType() != DataType.INT64) { + index = index.toType(DataType.INT64, true); + } + return new PtNDArray( + ndArray.getManager(), + PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle())); + } + public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) { Shape indexShape = index.getShape(); Shape ndShape = ndArray.getShape(); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 63317cb812a..7bee9f7abd1 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -198,6 +198,8 @@ native void torchIndexPut( native long torchGather(long handle, long index, long dim, boolean sparseGrad); + native long torchTake(long handle, long index); + native long torchMaskedSelect(long handle, long maskHandle); native void torchMaskedPut(long handle, long valueHandle, long maskHandle); diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc index 6ec0bf3c943..0e876784be3 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc @@ -176,6 +176,16 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchGather( API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchTake( + JNIEnv* env, jobject jthis, jlong jhandle, jlong jindex_handle) { + API_BEGIN() + const auto* tensor_ptr = reinterpret_cast(jhandle); + const auto* index_ptr = reinterpret_cast(jindex_handle); + const auto* result_ptr = new torch::Tensor(tensor_ptr->take(*index_ptr)); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMaskedSelect( JNIEnv* env, jobject jthis, jlong jhandle, jlong jmasked_handle) { API_BEGIN() diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 026680b97dd..421cd1fe706 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -212,6 +212,12 @@ public NDArray gather(NDArray index, int axis) { throw new UnsupportedOperationException("Not implemented yet."); } + /** {@inheritDoc} */ + @Override + public NDArray take(NDArray index) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public void attach(NDManager manager) { diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 59e1b14ac1b..49e9f652c4a 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -56,14 +56,24 @@ public void testPick() { public void testGather() { try (NDManager manager = NDManager.newBaseManager()) { NDArray arr = manager.arange(20f).reshape(-1, 4); - long[] idx = {0, 0, 2, 1, 1, 2}; - NDArray index = manager.create(idx, new Shape(3, 2)); + NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2)); NDArray actual = arr.gather(index, 1); NDArray expected = manager.create(new float[] {0, 0, 6, 5, 9, 10}, new Shape(3, 2)); Assert.assertEquals(actual, expected); } } + @Test + public void testTake() { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray arr = manager.arange(1, 7f).reshape(-1, 3); + NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); + NDArray actual = arr.take(index); + NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); + Assert.assertEquals(actual, expected); + } + } + @Test public void testGet() { try (NDManager manager = NDManager.newBaseManager()) { @@ -107,6 +117,13 @@ public void testGet() { NDArray bool = manager.create(new boolean[] {true, false}); expected = manager.arange(5).reshape(1, 5); Assert.assertEquals(original.get(bool), expected); + + // get from int array + original = manager.arange(1, 7f).reshape(-1, 3); + NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); + NDArray actual = original.get(index); + expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); + Assert.assertEquals(actual, expected); } }