diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index edcc507969c8..314413cea6b0 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -581,8 +581,8 @@ default NDArray get(NDManager manager, long... indices) { NDArray gather(NDArray index, int axis); /** - * Returns a partial {@code NDArray} pointed by index according to linear indexing, and - * the of output is of the same shape as index. + * Returns a partial {@code NDArray} pointed by index according to linear indexing, and the of + * output is of the same shape as index. * * @param index picks the elements of an NDArray and output to the same entry as in index * @return the partial {@code NDArray} of the same shape as index @@ -590,11 +590,11 @@ default NDArray get(NDManager manager, long... indices) { NDArray take(NDArray index); /** - * Set the entries of {@code NDArray} pointed by index according to linear indexing, to - * be the numbers in data, which is of the same shape as index. + * Set the entries of {@code NDArray} pointed by index according to linear indexing, to be the + * numbers in data, which is of the same shape as index. * * @param index select the entries of an {@code NDArray} - * @param data numbers to assign to the indexed entries + * @param data numbers to assign to the indexed entries * @return The NDArray with updated values */ default NDArray put(NDArray index, NDArray data) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 140395b4af97..a4f2a610021d 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -451,7 +451,8 @@ public static PtNDArray put(PtNDArray ndArray, PtNDArray index, PtNDArray data) } return new PtNDArray( ndArray.getManager(), - PyTorchLibrary.LIB.torchPut(ndArray.getHandle(), index.getHandle(), data.getHandle())); + PyTorchLibrary.LIB.torchPut( + ndArray.getHandle(), index.getHandle(), data.getHandle())); } public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) { diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 8767b1f097f8..49c165fd1115 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -172,13 +172,22 @@ public void testSetArrayBroadcast() { } @Test - public void testPut() { + public void testSetNumber() { try (NDManager manager = NDManager.newBaseManager()) { NDArray original = manager.create(new float[] {1, 2, 3, 4}, new Shape(2, 2)); - NDArray expected = manager.create(new float[] {1, 8, 666, 77}, new Shape(2, 2)); - NDArray idx = manager.create(new long[] {2, 3, 1}, new Shape(3)); - NDArray data = manager.create(new float[] {666, 77, 8}, new Shape(3)); - Assert.assertEquals(original.put(idx, data), expected); + NDArray expected = manager.create(new float[] {9, 9, 3, 4}, new Shape(2, 2)); + original.set(new NDIndex(0), 9); + Assert.assertEquals(original, expected); + + original = manager.arange(4f).reshape(2, 2); + expected = manager.ones(new Shape(2, 2)); + original.set(new NDIndex("..."), 1); + Assert.assertEquals(original, expected); + + original = manager.arange(4f).reshape(2, 2); + expected = manager.create(new float[] {1, 1, 1, 3}).reshape(2, 2); + original.set(new NDIndex("..., 0"), 1); + Assert.assertEquals(original, expected); } } @@ -220,22 +229,14 @@ public void testSetByFunctionIncrements() { } } - public void testSetNumber() { + @Test + public void testPut() { try (NDManager manager = NDManager.newBaseManager()) { NDArray original = manager.create(new float[] {1, 2, 3, 4}, new Shape(2, 2)); - NDArray expected = manager.create(new float[] {9, 9, 3, 4}, new Shape(2, 2)); - original.set(new NDIndex(0), 9); - Assert.assertEquals(original, expected); - - original = manager.arange(4f).reshape(2, 2); - expected = manager.ones(new Shape(2, 2)); - original.set(new NDIndex("..."), 1); - Assert.assertEquals(original, expected); - - original = manager.arange(4f).reshape(2, 2); - expected = manager.create(new float[] {1, 1, 1, 3}).reshape(2, 2); - original.set(new NDIndex("..., 0"), 1); - Assert.assertEquals(original, expected); + NDArray expected = manager.create(new float[] {1, 8, 666, 77}, new Shape(2, 2)); + NDArray idx = manager.create(new long[] {2, 3, 1}, new Shape(3)); + NDArray data = manager.create(new float[] {666, 77, 8}, new Shape(3)); + Assert.assertEquals(original.put(idx, data), expected); } } }