Skip to content

Commit

Permalink
add index test and clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Aug 28, 2022
1 parent 5f3ebfb commit 8ffac00
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 12 deletions.
12 changes: 0 additions & 12 deletions api/src/main/java/ai/djl/ndarray/index/full/NDIndexFullPick.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,6 @@ public static Optional<NDIndexFullPick> fromIndex(NDIndex index, Shape target) {
}
NDArray indexElem = ((NDIndexPick) el).getIndex();
fullPick = new NDIndexFullPick(indexElem, axis);
} else if (el instanceof NDIndexTake) {
if (fullPick != null) {
// Don't support multiple picks
throw new UnsupportedOperationException(
"Only one pick per get is currently supported");
}
NDArray indexElem = ((NDIndexTake) el).getIndex();
if (!indexElem.getShape().isRankOne()) {
throw new UnsupportedOperationException(
"Only rank-1 indexing array is supported for pick");
}
fullPick = new NDIndexFullPick(indexElem, axis);
} else {
// Invalid dim for fullPick
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ public void testPick() {
original.get(
new NDIndex().addAllDim().addPickDim(manager.create(new int[] {0, 1})));
Assert.assertEquals(actual, expected);

// The difference between take and pick used combined with addAllDim()
NDArray yHat = manager.create(new float[][]{{0.1f, 0.3f, 0.6f}, {0.3f, 0.2f, 0.5f}});
NDArray yGet = yHat.get(new NDIndex(":, {}", manager.create(new int[]{0, 2})));
NDArray yPick = yHat.get(new NDIndex().addAllDim().addPickDim(manager.create(new int[]{0, 2})));
Assert.assertEquals(yGet, manager.create(new float[][]{{0.1f, 0.6f}, {0.3f, 0.5f}}));
Assert.assertEquals(yPick, manager.create(new float[][]{{0.1f}, {0.5f}}));
}
}

Expand Down

0 comments on commit 8ffac00

Please sign in to comment.