diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index c35c7f09545..f83daf3e445 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -556,6 +556,14 @@ default NDArray get(NDArray index) { */ NDArray gather(NDArray index, int axis); + /** + * Returns a partial {@code NDArray} pointed by the indexed array, according to linear indexing. + * + * @param index picks the elements of an NDArray to the same position as index + * @return the partial {@code NDArray} of the same shape as index + */ + NDArray take(NDArray index); + /** * Returns a scalar {@code NDArray} corresponding to a single element. * diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index 46f2634c351..c36762010a9 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -182,6 +182,12 @@ public NDArray gather(NDArray index, int axis) { throw new UnsupportedOperationException(UNSUPPORTED_MSG); } + /** {@inheritDoc} */ + @Override + public NDArray take(NDArray index) { + throw new UnsupportedOperationException(UNSUPPORTED_MSG); + } + /** {@inheritDoc} */ @Override public void set(Buffer data) { diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index f786b022930..66fe02d5d03 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -312,6 +312,12 @@ public NDArray gather(NDArray index, int axis) { throw new UnsupportedOperationException("Not implemented yet."); } + /** {@inheritDoc} */ + @Override + public NDArray take(NDArray index) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public void copyTo(NDArray ndArray) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index a1cacb83e93..a8944cd92c2 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -248,6 +248,15 @@ public NDArray gather(NDArray index, int axis) { return JniUtils.gather(this, (PtNDArray) index, axis); } + /** {@inheritDoc} */ + @Override + public NDArray take(NDArray index) { + if (!(index instanceof PtNDArray)) { + throw new IllegalArgumentException("Only PtNDArray is supported."); + } + return JniUtils.take(this, (PtNDArray) index); + } + /** {@inheritDoc} */ @Override public void copyTo(NDArray array) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index a11719612ae..b3ae21101b8 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -352,6 +352,12 @@ public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) { PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false)); } + public static PtNDArray take(PtNDArray ndArray, PtNDArray index) { + return new PtNDArray( + ndArray.getManager(), + PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle())); + } + public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) { Shape indexShape = index.getShape(); Shape ndShape = ndArray.getShape(); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 63317cb812a..7bee9f7abd1 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -198,6 +198,8 @@ native void torchIndexPut( native long torchGather(long handle, long index, long dim, boolean sparseGrad); + native long torchTake(long handle, long index); + native long torchMaskedSelect(long handle, long maskHandle); native void torchMaskedPut(long handle, long valueHandle, long maskHandle); diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc index 6ec0bf3c943..0e876784be3 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc @@ -176,6 +176,16 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchGather( API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchTake( + JNIEnv* env, jobject jthis, jlong jhandle, jlong jindex_handle) { + API_BEGIN() + const auto* tensor_ptr = reinterpret_cast(jhandle); + const auto* index_ptr = reinterpret_cast(jindex_handle); + const auto* result_ptr = new torch::Tensor(tensor_ptr->take(*index_ptr)); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMaskedSelect( JNIEnv* env, jobject jthis, jlong jhandle, jlong jmasked_handle) { API_BEGIN() diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 026680b97dd..421cd1fe706 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -212,6 +212,12 @@ public NDArray gather(NDArray index, int axis) { throw new UnsupportedOperationException("Not implemented yet."); } + /** {@inheritDoc} */ + @Override + public NDArray take(NDArray index) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public void attach(NDManager manager) { diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index a954f3e0fd0..41db19eda6b 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -12,6 +12,7 @@ */ package ai.djl.integration.tests.ndarray; +import ai.djl.engine.Engine; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.index.NDIndex; @@ -61,14 +62,25 @@ public void testGather() { TestRequirements.notWindows(); try (NDManager manager = NDManager.newBaseManager()) { NDArray arr = manager.arange(20f).reshape(-1, 4); - long[] idx = {0, 0, 2, 1, 1, 2}; - NDArray index = manager.create(idx, new Shape(3, 2)); + NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2)); NDArray actual = arr.gather(index, 1); NDArray expected = manager.create(new float[] {0, 0, 6, 5, 9, 10}, new Shape(3, 2)); Assert.assertEquals(actual, expected); } } + @Test + public void testTake() { + Engine engine = Engine.getEngine("PyTorch"); + try (NDManager manager = engine.newBaseManager()) { + NDArray arr = manager.arange(6f).reshape(-1, 3); + NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); + NDArray actual = arr.take(index); + NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); + Assert.assertEquals(actual, expected); + } + } + @Test public void testGet() { try (NDManager manager = NDManager.newBaseManager()) {