Skip to content

Commit

Permalink
doc change
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Jun 29, 2022
1 parent 973e74b commit 952eef2
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 18 deletions.
27 changes: 16 additions & 11 deletions api/src/main/java/ai/djl/ndarray/index/NDIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,44 +73,49 @@ public NDIndex() {
* <pre>
* NDArray a = manager.ones(new Shape(5, 4, 3));
*
* // Gets a subsection of the NDArray in the first axis.
* // Get a subsection of the NDArray in the first axis.
* assertEquals(a.get(new NDIndex("2")).getShape(), new Shape(4, 3));
*
* // Gets a subsection of the NDArray indexing from the end (-i == length - i).
* // Get a subsection of the NDArray indexing from the end (-i == length - i).
* assertEquals(a.get(new NDIndex("-1")).getShape(), new Shape(4, 3));
*
* // Gets everything in the first axis and a subsection in the second axis.
* // Get everything in the first axis and a subsection in the second axis.
* // You can use either : or * to represent everything
* assertEquals(a.get(new NDIndex(":, 2")).getShape(), new Shape(5, 3));
* assertEquals(a.get(new NDIndex("*, 2")).getShape(), new Shape(5, 3));
*
* // Gets a range of values along the second axis that is inclusive on the bottom and exclusive on the top.
* // Get a range of values along the second axis that is inclusive on the bottom and exclusive on the top.
* assertEquals(a.get(new NDIndex(":, 1:3")).getShape(), new Shape(5, 2, 3));
*
* // Excludes either the min or the max of the range to go all the way to the beginning or end.
* // Exclude either the min or the max of the range to go all the way to the beginning or end.
* assertEquals(a.get(new NDIndex(":, :3")).getShape(), new Shape(5, 3, 3));
* assertEquals(a.get(new NDIndex(":, 1:")).getShape(), new Shape(5, 4, 3));
*
* // Uses the value after the second colon in a slicing range, the step, to get every other result.
* // Use the value after the second colon in a slicing range, the step, to get every other result.
* assertEquals(a.get(new NDIndex(":, 1::2")).getShape(), new Shape(5, 2, 3));
*
* // Uses a negative step to reverse along the dimension.
* // Use a negative step to reverse along the dimension.
* assertEquals(a.get(new NDIndex("-1")).getShape(), new Shape(5, 4, 3));
*
* // Uses a variable argument to the index
* // Use a variable argument to the index
* // It can replace any number in any of these formats with {} and then the value of {}
* // is specified in an argument following the indices string.
* assertEquals(a.get(new NDIndex("{}, {}:{}", 0, 1, 3)).getShape(), new Shape(2, 3));
*
* // Uses ellipsis to insert many full slices
* // Use ellipsis to insert many full slices
* assertEquals(a.get(new NDIndex("...")).getShape(), new Shape(5, 4, 3));
*
* // Uses ellipsis to select all the dimensions except for last axis where we only get a subsection.
* // Use ellipsis to select all the dimensions except for last axis where we only get a subsection.
* assertEquals(a.get(new NDIndex("..., 2")).getShape(), new Shape(5, 4));
*
* // Uses null to add an extra axis to the output array
* // Use null to add an extra axis to the output array
* assertEquals(a.get(new NDIndex(":2, null, 0, :2")).getShape(), new Shape(2, 1, 2));
*
* // Get entries of an NDArray with mixed index
* index1 = manager.create(new long[] {0, 1, 1}, new Shape(2));
* bool1 = manager.create(new boolean[] {true, false, true});
* assertEquals(a.get(new NDIndex(":{}, {}, {}, {}" 2, index1, bool1, null).getShape(), new Shape(2, 2, 1));
*
* </pre>
*
* @param indices a comma separated list of indices corresponding to either subsections,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ public void testSetArray() {
expected = manager.create(new int[] {0, 1, 9, 10, 4, 5, 11, 12}, new Shape(2, 4));
original.set(new NDIndex(":, 2:"), manager.arange(9, 13).reshape(2, 2));
Assert.assertEquals(original, expected);

// set by index array
NDArray index = manager.create(new long[] {0, 1}, new Shape(2));
value = manager.create(new int[] {666, 777, 888, 999}, new Shape(2, 2));
original.set(new NDIndex("{}, :{}", index, 2), value);
expected =
manager.create(new int[] {666, 777, 3, 888, 999, 6, 7, 8, 9}, new Shape(3, 3));
Assert.assertEquals(original, expected);
}
}

Expand Down Expand Up @@ -189,19 +197,13 @@ public void testSetNumber() {
original.set(new NDIndex("..., 0"), 1);
Assert.assertEquals(original, expected);

// set from int array
// set by index array
original = manager.arange(1, 10).reshape(3, 3);
NDArray index = manager.create(new long[] {0, 1}, new Shape(2));
original.set(new NDIndex("{}, :{}", index, 2), 666);
expected =
manager.create(new int[] {666, 666, 3, 666, 666, 6, 7, 8, 9}, new Shape(3, 3));
Assert.assertEquals(original, expected);

NDArray value = manager.create(new int[] {666, 777, 888, 999}, new Shape(2, 2));
original.set(new NDIndex("{}, :{}", index, 2), value);
expected =
manager.create(new int[] {666, 777, 3, 888, 999, 6, 7, 8, 9}, new Shape(3, 3));
Assert.assertEquals(original, expected);
}
}

Expand Down

0 comments on commit 952eef2

Please sign in to comment.