From dc1b7a8ea552c0e5966309b8d60466e2b4ef5851 Mon Sep 17 00:00:00 2001 From: Kexin Date: Fri, 6 May 2022 14:00:55 -0700 Subject: [PATCH 01/12] take_dev --- api/src/main/java/ai/djl/ndarray/NDArray.java | 8 ++++++++ .../main/java/ai/djl/ndarray/NDArrayAdapter.java | 6 ++++++ .../main/java/ai/djl/mxnet/engine/MxNDArray.java | 6 ++++++ .../java/ai/djl/pytorch/engine/PtNDArray.java | 9 +++++++++ .../main/java/ai/djl/pytorch/jni/JniUtils.java | 6 ++++++ .../java/ai/djl/pytorch/jni/PyTorchLibrary.java | 2 ++ .../ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc | 10 ++++++++++ .../java/ai/djl/tensorflow/engine/TfNDArray.java | 6 ++++++ .../integration/tests/ndarray/NDIndexTest.java | 16 ++++++++++++++-- 9 files changed, 67 insertions(+), 2 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index c35c7f09545..f83daf3e445 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -556,6 +556,14 @@ default NDArray get(NDArray index) { */ NDArray gather(NDArray index, int axis); + /** + * Returns a partial {@code NDArray} pointed by the indexed array, according to linear indexing. + * + * @param index picks the elements of an NDArray to the same position as index + * @return the partial {@code NDArray} of the same shape as index + */ + NDArray take(NDArray index); + /** * Returns a scalar {@code NDArray} corresponding to a single element. * diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index 46f2634c351..c36762010a9 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -182,6 +182,12 @@ public NDArray gather(NDArray index, int axis) { throw new UnsupportedOperationException(UNSUPPORTED_MSG); } + /** {@inheritDoc} */ + @Override + public NDArray take(NDArray index) { + throw new UnsupportedOperationException(UNSUPPORTED_MSG); + } + /** {@inheritDoc} */ @Override public void set(Buffer data) { diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index f786b022930..66fe02d5d03 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -312,6 +312,12 @@ public NDArray gather(NDArray index, int axis) { throw new UnsupportedOperationException("Not implemented yet."); } + /** {@inheritDoc} */ + @Override + public NDArray take(NDArray index) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public void copyTo(NDArray ndArray) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index a1cacb83e93..a8944cd92c2 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -248,6 +248,15 @@ public NDArray gather(NDArray index, int axis) { return JniUtils.gather(this, (PtNDArray) index, axis); } + /** {@inheritDoc} */ + @Override + public NDArray take(NDArray index) { + if (!(index instanceof PtNDArray)) { + throw new IllegalArgumentException("Only PtNDArray is supported."); + } + return JniUtils.take(this, (PtNDArray) index); + } + /** {@inheritDoc} */ @Override public void copyTo(NDArray array) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index a11719612ae..b3ae21101b8 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -352,6 +352,12 @@ public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) { PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false)); } + public static PtNDArray take(PtNDArray ndArray, PtNDArray index) { + return new PtNDArray( + ndArray.getManager(), + PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle())); + } + public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) { Shape indexShape = index.getShape(); Shape ndShape = ndArray.getShape(); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 63317cb812a..7bee9f7abd1 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -198,6 +198,8 @@ native void torchIndexPut( native long torchGather(long handle, long index, long dim, boolean sparseGrad); + native long torchTake(long handle, long index); + native long torchMaskedSelect(long handle, long maskHandle); native void torchMaskedPut(long handle, long valueHandle, long maskHandle); diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc index 6ec0bf3c943..0e876784be3 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc @@ -176,6 +176,16 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchGather( API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchTake( + JNIEnv* env, jobject jthis, jlong jhandle, jlong jindex_handle) { + API_BEGIN() + const auto* tensor_ptr = reinterpret_cast(jhandle); + const auto* index_ptr = reinterpret_cast(jindex_handle); + const auto* result_ptr = new torch::Tensor(tensor_ptr->take(*index_ptr)); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMaskedSelect( JNIEnv* env, jobject jthis, jlong jhandle, jlong jmasked_handle) { API_BEGIN() diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 026680b97dd..421cd1fe706 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -212,6 +212,12 @@ public NDArray gather(NDArray index, int axis) { throw new UnsupportedOperationException("Not implemented yet."); } + /** {@inheritDoc} */ + @Override + public NDArray take(NDArray index) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public void attach(NDManager manager) { diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index a954f3e0fd0..41db19eda6b 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -12,6 +12,7 @@ */ package ai.djl.integration.tests.ndarray; +import ai.djl.engine.Engine; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.index.NDIndex; @@ -61,14 +62,25 @@ public void testGather() { TestRequirements.notWindows(); try (NDManager manager = NDManager.newBaseManager()) { NDArray arr = manager.arange(20f).reshape(-1, 4); - long[] idx = {0, 0, 2, 1, 1, 2}; - NDArray index = manager.create(idx, new Shape(3, 2)); + NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2)); NDArray actual = arr.gather(index, 1); NDArray expected = manager.create(new float[] {0, 0, 6, 5, 9, 10}, new Shape(3, 2)); Assert.assertEquals(actual, expected); } } + @Test + public void testTake() { + Engine engine = Engine.getEngine("PyTorch"); + try (NDManager manager = engine.newBaseManager()) { + NDArray arr = manager.arange(6f).reshape(-1, 3); + NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); + NDArray actual = arr.take(index); + NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); + Assert.assertEquals(actual, expected); + } + } + @Test public void testGet() { try (NDManager manager = NDManager.newBaseManager()) { From 239c908a494dbc7b75688348436de21dda258cdc Mon Sep 17 00:00:00 2001 From: Kexin Date: Fri, 6 May 2022 19:39:28 -0700 Subject: [PATCH 02/12] assertion_err --- .../java/ai/djl/integration/tests/ndarray/NDIndexTest.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 41db19eda6b..1d677b9c28b 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -60,7 +60,9 @@ public void testGather() { // In the dependencies, changing runtimeOnly to api however will remedy the problem. // TODO: remove this when gradle problem is fixed. TestRequirements.notWindows(); - try (NDManager manager = NDManager.newBaseManager()) { + Engine engine = Engine.getEngine("PyTorch"); + try (NDManager manager = engine.newBaseManager()) { +// try (NDManager manager = NDManager.newBaseManager()) { NDArray arr = manager.arange(20f).reshape(-1, 4); NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2)); NDArray actual = arr.gather(index, 1); @@ -73,7 +75,7 @@ public void testGather() { public void testTake() { Engine engine = Engine.getEngine("PyTorch"); try (NDManager manager = engine.newBaseManager()) { - NDArray arr = manager.arange(6f).reshape(-1, 3); + NDArray arr = manager.arange(1,7f).reshape(-1, 3); NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); NDArray actual = arr.take(index); NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); From a7208229b32874b94dd6cf8e835af0f7057f8074 Mon Sep 17 00:00:00 2001 From: Kexin Date: Fri, 6 May 2022 19:45:20 -0700 Subject: [PATCH 03/12] Keep test EngineAgnostic --- .../integration/tests/ndarray/NDIndexTest.java | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 1d677b9c28b..4cf0acb7622 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -12,7 +12,6 @@ */ package ai.djl.integration.tests.ndarray; -import ai.djl.engine.Engine; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.index.NDIndex; @@ -56,13 +55,12 @@ public void testPick() { @Test public void testGather() { - // Currently in windows gradle cannot find all the engines. + // Currently in windows gradle cannot find all the engines to fill in -classpath, except for + // MXNet. // In the dependencies, changing runtimeOnly to api however will remedy the problem. // TODO: remove this when gradle problem is fixed. TestRequirements.notWindows(); - Engine engine = Engine.getEngine("PyTorch"); - try (NDManager manager = engine.newBaseManager()) { -// try (NDManager manager = NDManager.newBaseManager()) { + try (NDManager manager = NDManager.newBaseManager()) { NDArray arr = manager.arange(20f).reshape(-1, 4); NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2)); NDArray actual = arr.gather(index, 1); @@ -73,9 +71,10 @@ public void testGather() { @Test public void testTake() { - Engine engine = Engine.getEngine("PyTorch"); - try (NDManager manager = engine.newBaseManager()) { - NDArray arr = manager.arange(1,7f).reshape(-1, 3); + // TODO: remove this when gradle problem in windows shown above is fixed. + TestRequirements.notWindows(); + try (NDManager manager = NDManager.newBaseManager()) { + NDArray arr = manager.arange(1, 7f).reshape(-1, 3); NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); NDArray actual = arr.take(index); NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); From 6af0c984f5d07262ffefa1f3f1f952e75fbac71e Mon Sep 17 00:00:00 2001 From: Kexin Date: Mon, 9 May 2022 15:21:44 -0700 Subject: [PATCH 04/12] testTake skipping windows is removed --- .../ai/djl/integration/tests/ndarray/NDIndexTest.java | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 4cf0acb7622..820a2be709c 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -12,11 +12,11 @@ */ package ai.djl.integration.tests.ndarray; +import ai.djl.engine.Engine; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.index.NDIndex; import ai.djl.ndarray.types.Shape; -import ai.djl.testing.TestRequirements; import org.testng.Assert; import org.testng.annotations.Test; @@ -55,11 +55,6 @@ public void testPick() { @Test public void testGather() { - // Currently in windows gradle cannot find all the engines to fill in -classpath, except for - // MXNet. - // In the dependencies, changing runtimeOnly to api however will remedy the problem. - // TODO: remove this when gradle problem is fixed. - TestRequirements.notWindows(); try (NDManager manager = NDManager.newBaseManager()) { NDArray arr = manager.arange(20f).reshape(-1, 4); NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2)); @@ -71,8 +66,6 @@ public void testGather() { @Test public void testTake() { - // TODO: remove this when gradle problem in windows shown above is fixed. - TestRequirements.notWindows(); try (NDManager manager = NDManager.newBaseManager()) { NDArray arr = manager.arange(1, 7f).reshape(-1, 3); NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); From 9d5ebeb0624894a4f1209829712e0788da6a2d86 Mon Sep 17 00:00:00 2001 From: Kexin Date: Mon, 9 May 2022 16:35:21 -0700 Subject: [PATCH 05/12] combine get(NDArray(BOOLEAN) and NDArray(INT)) --- api/src/main/java/ai/djl/ndarray/NDArray.java | 25 +++++++++++-------- .../tests/ndarray/NDIndexTest.java | 1 - 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index f83daf3e445..63751ba22f3 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -511,6 +511,21 @@ default NDArray get(NDIndex index) { return getNDArrayInternal().getIndexer().get(this, index); } + /** + * Returns a partial {@code NDArray}. + * + * @param index the boolean or int {@code NDArray} that indicates what to get + * @return the partial {@code NDArray} + */ + default NDArray get(NDArray index) { + DataType indexType = index.getDataType(); + if (indexType == DataType.BOOLEAN) { + return get(new NDIndex().addBooleanIndex(index)); + } else { + return take(index); + } + } + /** * Returns a partial {@code NDArray}. * @@ -535,16 +550,6 @@ default NDArray get(long... indices) { return get(new NDIndex(indices)); } - /** - * Returns a partial {@code NDArray}. - * - * @param index the boolean {@code NDArray} that indicates what to get - * @return the partial {@code NDArray} - */ - default NDArray get(NDArray index) { - return get(new NDIndex().addBooleanIndex(index)); - } - /** * Returns a partial {@code NDArray} pointed by the indexed array. Given NDArray arr, NDArray * idx, and long axis, the output is out_{ijk} = arr_{idx_{ijk}, j, k} if axis=0 or arr_{i, diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 820a2be709c..aa4b8af5a59 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -12,7 +12,6 @@ */ package ai.djl.integration.tests.ndarray; -import ai.djl.engine.Engine; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.index.NDIndex; From 84938f11bf833ed439cdb3773df90641dd5d2718 Mon Sep 17 00:00:00 2001 From: Kexin Date: Mon, 9 May 2022 16:54:37 -0700 Subject: [PATCH 06/12] add test --- .../main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index aa4b8af5a59..0419513ec0c 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -69,8 +69,10 @@ public void testTake() { NDArray arr = manager.arange(1, 7f).reshape(-1, 3); NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); NDArray actual = arr.take(index); + NDArray actual2 = arr.get(index); NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); Assert.assertEquals(actual, expected); + Assert.assertEquals(actual2, expected); } } From 9ac41a68bda9591b593a2a887405d002a51e0db5 Mon Sep 17 00:00:00 2001 From: Kexin Date: Mon, 9 May 2022 17:03:54 -0700 Subject: [PATCH 07/12] add_test --- .../ai/djl/integration/tests/ndarray/NDIndexTest.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 0419513ec0c..3c95f501008 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -69,10 +69,8 @@ public void testTake() { NDArray arr = manager.arange(1, 7f).reshape(-1, 3); NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); NDArray actual = arr.take(index); - NDArray actual2 = arr.get(index); NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); Assert.assertEquals(actual, expected); - Assert.assertEquals(actual2, expected); } } @@ -119,6 +117,13 @@ public void testGet() { NDArray bool = manager.create(new boolean[] {true, false}); expected = manager.arange(5).reshape(1, 5); Assert.assertEquals(original.get(bool), expected); + + // get from int array + original = manager.arange(1, 7f).reshape(-1, 3); + NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); + NDArray actual = original.take(index); + expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); + Assert.assertEquals(actual, expected); } } From aef770ed7fbe2aa135ec3e3bba3c5a2135e06f79 Mon Sep 17 00:00:00 2001 From: Kexin Date: Mon, 9 May 2022 18:06:35 -0700 Subject: [PATCH 08/12] comb_get --- api/src/main/java/ai/djl/ndarray/NDArray.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 63751ba22f3..c7c6f96aa4e 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -518,8 +518,7 @@ default NDArray get(NDIndex index) { * @return the partial {@code NDArray} */ default NDArray get(NDArray index) { - DataType indexType = index.getDataType(); - if (indexType == DataType.BOOLEAN) { + if (index.getDataType() == DataType.BOOLEAN) { return get(new NDIndex().addBooleanIndex(index)); } else { return take(index); From ea84044bb3a75854cb942c841abc2dc2f2bcd893 Mon Sep 17 00:00:00 2001 From: Kexin Date: Tue, 10 May 2022 10:37:36 -0700 Subject: [PATCH 09/12] bug --- .../main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 3c95f501008..49e9f652c4a 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -121,7 +121,7 @@ public void testGet() { // get from int array original = manager.arange(1, 7f).reshape(-1, 3); NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); - NDArray actual = original.take(index); + NDArray actual = original.get(index); expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); Assert.assertEquals(actual, expected); } From 7725af96f051c92281793a2e65df0ec65830768e Mon Sep 17 00:00:00 2001 From: Kexin Date: Wed, 11 May 2022 18:40:07 -0700 Subject: [PATCH 10/12] DataType check --- api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java | 2 +- .../src/main/java/ai/djl/pytorch/engine/PtNDArray.java | 2 +- .../src/main/java/ai/djl/pytorch/jni/JniUtils.java | 6 ++++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java b/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java index 31f4008120f..fd29e8f8bf9 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java @@ -59,7 +59,7 @@ public NDArray get(NDArray array, NDIndex index) { if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) { if (indices.size() != 1) { throw new IllegalArgumentException( - "get() currently didn't support more that one boolean NDArray"); + "get() currently doesn't support more that one boolean NDArray"); } return array.booleanMask(((NDIndexBooleans) indices.get(0)).getIndex()); } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index a8944cd92c2..43501db4e28 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -301,7 +301,7 @@ public PtNDArray booleanMask(NDArray index, int axis) { // Result is flattened since shape is undetermined return JniUtils.booleanMask(this, manager.from(index)); } else if (indexShape.equals(getShape().slice(axis))) { - // index will be broadcasted by default + // index will be broadcast by default try (PtNDArray flattedResult = JniUtils.booleanMask(this, manager.from(index))) { // Shape recovery Shape remainder = getShape().slice(0, axis); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index b3ae21101b8..5d0f21b619a 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -347,12 +347,18 @@ public static void set(PtNDArray self, ByteBuffer data) { } public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) { + if (index.getDataType() != DataType.INT64) { + index = index.toType(DataType.INT64, true); + } return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false)); } public static PtNDArray take(PtNDArray ndArray, PtNDArray index) { + if (index.getDataType() != DataType.INT64) { + index = index.toType(DataType.INT64, true); + } return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle())); From e94bd17fefeb2735f593f842e333d698657f8703 Mon Sep 17 00:00:00 2001 From: Kexin Date: Thu, 12 May 2022 19:47:40 -0700 Subject: [PATCH 11/12] Add mxnet support of indexing `take` --- api/src/main/java/ai/djl/ndarray/NDArray.java | 3 ++- .../src/main/java/ai/djl/mxnet/engine/MxNDArray.java | 8 +++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index c7c6f96aa4e..0a1130e2297 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -561,7 +561,8 @@ default NDArray get(long... indices) { NDArray gather(NDArray index, int axis); /** - * Returns a partial {@code NDArray} pointed by the indexed array, according to linear indexing. + * Returns a partial {@code NDArray} pointed by the indexed array according to linear indexing, + * and the of output is of the same shape as index. * * @param index picks the elements of an NDArray to the same position as index * @return the partial {@code NDArray} of the same shape as index diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index 66fe02d5d03..d7f74a0b16b 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -315,7 +315,13 @@ public NDArray gather(NDArray index, int axis) { /** {@inheritDoc} */ @Override public NDArray take(NDArray index) { - throw new UnsupportedOperationException("Not implemented yet."); + MxOpParams params = new MxOpParams(); + params.addParam("shape", "(-1,)"); + NDList flattened = manager.invoke("reshape", new NDList(this), params); + params.clear(); + params.add("mode", "wrap"); + flattened.add(index); + return manager.invoke("take", flattened, params).singletonOrThrow(); } /** {@inheritDoc} */ From d6aab6b5f23817d63128c2313bac6c0e78d95589 Mon Sep 17 00:00:00 2001 From: Kexin Date: Fri, 13 May 2022 18:19:21 -0700 Subject: [PATCH 12/12] Use existing flatten --- .../src/main/java/ai/djl/mxnet/engine/MxNDArray.java | 6 +----- .../ai/djl/integration/tests/ndarray/NDIndexTest.java | 8 ++++---- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index d7f74a0b16b..fc68b4b739c 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -316,12 +316,8 @@ public NDArray gather(NDArray index, int axis) { @Override public NDArray take(NDArray index) { MxOpParams params = new MxOpParams(); - params.addParam("shape", "(-1,)"); - NDList flattened = manager.invoke("reshape", new NDList(this), params); - params.clear(); params.add("mode", "wrap"); - flattened.add(index); - return manager.invoke("take", flattened, params).singletonOrThrow(); + return manager.invoke("take", new NDList(this.flatten(), index), params).singletonOrThrow(); } /** {@inheritDoc} */ diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 49e9f652c4a..88abee3ad28 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -55,9 +55,9 @@ public void testPick() { @Test public void testGather() { try (NDManager manager = NDManager.newBaseManager()) { - NDArray arr = manager.arange(20f).reshape(-1, 4); + NDArray original = manager.arange(20f).reshape(-1, 4); NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2)); - NDArray actual = arr.gather(index, 1); + NDArray actual = original.gather(index, 1); NDArray expected = manager.create(new float[] {0, 0, 6, 5, 9, 10}, new Shape(3, 2)); Assert.assertEquals(actual, expected); } @@ -66,9 +66,9 @@ public void testGather() { @Test public void testTake() { try (NDManager manager = NDManager.newBaseManager()) { - NDArray arr = manager.arange(1, 7f).reshape(-1, 3); + NDArray original = manager.arange(1, 7f).reshape(-1, 3); NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); - NDArray actual = arr.take(index); + NDArray actual = original.take(index); NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); Assert.assertEquals(actual, expected); }