diff --git a/api/src/main/java/ai/djl/ndarray/index/full/NDIndexFullPick.java b/api/src/main/java/ai/djl/ndarray/index/full/NDIndexFullPick.java index 7171ecec709..f373bb7197c 100644 --- a/api/src/main/java/ai/djl/ndarray/index/full/NDIndexFullPick.java +++ b/api/src/main/java/ai/djl/ndarray/index/full/NDIndexFullPick.java @@ -60,18 +60,6 @@ public static Optional fromIndex(NDIndex index, Shape target) { } NDArray indexElem = ((NDIndexPick) el).getIndex(); fullPick = new NDIndexFullPick(indexElem, axis); - } else if (el instanceof NDIndexTake) { - if (fullPick != null) { - // Don't support multiple picks - throw new UnsupportedOperationException( - "Only one pick per get is currently supported"); - } - NDArray indexElem = ((NDIndexTake) el).getIndex(); - if (!indexElem.getShape().isRankOne()) { - throw new UnsupportedOperationException( - "Only rank-1 indexing array is supported for pick"); - } - fullPick = new NDIndexFullPick(indexElem, axis); } else { // Invalid dim for fullPick return Optional.empty(); diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 5c59092f8f6..01182303b72 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -50,6 +50,13 @@ public void testPick() { original.get( new NDIndex().addAllDim().addPickDim(manager.create(new int[] {0, 1}))); Assert.assertEquals(actual, expected); + + // The difference between take and pick used combined with addAllDim() + NDArray yHat = manager.create(new float[][]{{0.1f, 0.3f, 0.6f}, {0.3f, 0.2f, 0.5f}}); + NDArray yGet = yHat.get(new NDIndex(":, {}", manager.create(new int[]{0, 2}))); + NDArray yPick = yHat.get(new NDIndex().addAllDim().addPickDim(manager.create(new int[]{0, 2}))); + Assert.assertEquals(yGet, manager.create(new float[][]{{0.1f, 0.6f}, {0.3f, 0.5f}})); + Assert.assertEquals(yPick, manager.create(new float[][]{{0.1f}, {0.5f}})); } }