Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backward compatible with MXNet indexing. #1802

Merged
merged 10 commits into from
Jul 18, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ai.djl.ndarray.index.dim.NDIndexAll;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexPick;
import ai.djl.ndarray.index.dim.NDIndexTake;
import ai.djl.ndarray.types.Shape;

import java.util.Optional;
Expand Down Expand Up @@ -52,13 +53,25 @@ public static Optional<NDIndexFullPick> fromIndex(NDIndex index, Shape target) {
if (el instanceof NDIndexAll) {
axis++;
} else if (el instanceof NDIndexPick) {
if (fullPick == null) {
fullPick = new NDIndexFullPick(((NDIndexPick) el).getIndex(), axis);
} else {
if (fullPick != null) {
// Don't support multiple picks
throw new UnsupportedOperationException(
"Only one pick per get is currently supported");
}
NDArray indexElem = ((NDIndexPick) el).getIndex();
fullPick = new NDIndexFullPick(indexElem, axis);
} else if (el instanceof NDIndexTake) {
if (fullPick != null) {
// Don't support multiple picks
throw new UnsupportedOperationException(
"Only one pick per get is currently supported");
}
NDArray indexElem = ((NDIndexTake) el).getIndex();
if (!indexElem.getShape().isRankOne()) {
throw new UnsupportedOperationException(
"Only rank-1 indexing array is supported for pick");
}
fullPick = new NDIndexFullPick(indexElem, axis);
} else {
// Invalid dim for fullPick
return Optional.empty();
Expand Down
22 changes: 22 additions & 0 deletions api/src/main/java/ai/djl/ndarray/types/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -477,4 +477,26 @@ public static Shape decode(DataInputStream dis) throws IOException {
}
return new Shape(shapeValue, new String(layout));
}

/**
* Returns if the array is rank-1 which is inferred from the shape.
*
* <p>For example, an array with shape [1, 10, 1] returns true. Array with indeterminate size -1
* returns false.
*
* @return if the array is rank-1
*/
public boolean isRankOne() {
int max = 1;
int ans = 1;
for (long s : shape) {
int size = Math.toIntExact(s);
max = Math.max(max, size);
ans *= size;
if (ans < 0) {
return false;
}
}
return max == ans;
}
}