Skip to content

Commit

Permalink
Fix NDIndexTest
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Jun 26, 2022
1 parent 61986bc commit 59bd30c
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 25 deletions.
10 changes: 5 additions & 5 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -581,20 +581,20 @@ default NDArray get(NDManager manager, long... indices) {
NDArray gather(NDArray index, int axis);

/**
* Returns a partial {@code NDArray} pointed by index according to linear indexing, and
* the of output is of the same shape as index.
* Returns a partial {@code NDArray} pointed by index according to linear indexing, and the of
* output is of the same shape as index.
*
* @param index picks the elements of an NDArray and output to the same entry as in index
* @return the partial {@code NDArray} of the same shape as index
*/
NDArray take(NDArray index);

/**
* Set the entries of {@code NDArray} pointed by index according to linear indexing, to
* be the numbers in data, which is of the same shape as index.
* Set the entries of {@code NDArray} pointed by index according to linear indexing, to be the
* numbers in data, which is of the same shape as index.
*
* @param index select the entries of an {@code NDArray}
* @param data numbers to assign to the indexed entries
* @param data numbers to assign to the indexed entries
* @return The NDArray with updated values
*/
default NDArray put(NDArray index, NDArray data) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,8 @@ public static PtNDArray put(PtNDArray ndArray, PtNDArray index, PtNDArray data)
}
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchPut(ndArray.getHandle(), index.getHandle(), data.getHandle()));
PyTorchLibrary.LIB.torchPut(
ndArray.getHandle(), index.getHandle(), data.getHandle()));
}

public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,22 @@ public void testSetArrayBroadcast() {
}

@Test
public void testPut() {
public void testSetNumber() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray original = manager.create(new float[] {1, 2, 3, 4}, new Shape(2, 2));
NDArray expected = manager.create(new float[] {1, 8, 666, 77}, new Shape(2, 2));
NDArray idx = manager.create(new long[] {2, 3, 1}, new Shape(3));
NDArray data = manager.create(new float[] {666, 77, 8}, new Shape(3));
Assert.assertEquals(original.put(idx, data), expected);
NDArray expected = manager.create(new float[] {9, 9, 3, 4}, new Shape(2, 2));
original.set(new NDIndex(0), 9);
Assert.assertEquals(original, expected);

original = manager.arange(4f).reshape(2, 2);
expected = manager.ones(new Shape(2, 2));
original.set(new NDIndex("..."), 1);
Assert.assertEquals(original, expected);

original = manager.arange(4f).reshape(2, 2);
expected = manager.create(new float[] {1, 1, 1, 3}).reshape(2, 2);
original.set(new NDIndex("..., 0"), 1);
Assert.assertEquals(original, expected);
}
}

Expand Down Expand Up @@ -220,22 +229,14 @@ public void testSetByFunctionIncrements() {
}
}

public void testSetNumber() {
@Test
public void testPut() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray original = manager.create(new float[] {1, 2, 3, 4}, new Shape(2, 2));
NDArray expected = manager.create(new float[] {9, 9, 3, 4}, new Shape(2, 2));
original.set(new NDIndex(0), 9);
Assert.assertEquals(original, expected);

original = manager.arange(4f).reshape(2, 2);
expected = manager.ones(new Shape(2, 2));
original.set(new NDIndex("..."), 1);
Assert.assertEquals(original, expected);

original = manager.arange(4f).reshape(2, 2);
expected = manager.create(new float[] {1, 1, 1, 3}).reshape(2, 2);
original.set(new NDIndex("..., 0"), 1);
Assert.assertEquals(original, expected);
NDArray expected = manager.create(new float[] {1, 8, 666, 77}, new Shape(2, 2));
NDArray idx = manager.create(new long[] {2, 3, 1}, new Shape(3));
NDArray data = manager.create(new float[] {666, 77, 8}, new Shape(3));
Assert.assertEquals(original.put(idx, data), expected);
}
}
}

0 comments on commit 59bd30c

Please sign in to comment.