Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backward compatible with MXNet indexing. #1802

Merged
merged 10 commits into from
Jul 18, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ai.djl.ndarray.index.dim.NDIndexAll;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexPick;
import ai.djl.ndarray.index.dim.NDIndexTake;
import ai.djl.ndarray.types.Shape;

import java.util.Optional;
Expand Down Expand Up @@ -51,9 +52,17 @@ public static Optional<NDIndexFullPick> fromIndex(NDIndex index, Shape target) {
for (NDIndexElement el : index.getIndices()) {
if (el instanceof NDIndexAll) {
axis++;
} else if (el instanceof NDIndexPick) {
} else if (el instanceof NDIndexPick || el instanceof NDIndexTake) {
if (fullPick == null) {
fullPick = new NDIndexFullPick(((NDIndexPick) el).getIndex(), axis);
NDArray indexElem =
el instanceof NDIndexPick
? ((NDIndexPick) el).getIndex()
: ((NDIndexTake) el).getIndex();
if (!indexElem.getShape().isRankOne()) {
throw new UnsupportedOperationException(
"Only rank-1 indexing array is supported for pick");
}
fullPick = new NDIndexFullPick(indexElem, axis);
} else {
// Don't support multiple picks
throw new UnsupportedOperationException(
Expand Down
21 changes: 21 additions & 0 deletions api/src/main/java/ai/djl/ndarray/types/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -477,4 +477,25 @@ public static Shape decode(DataInputStream dis) throws IOException {
}
return new Shape(shapeValue, new String(layout));
}

/**
* Returns if the array is rank-1 which is inferred from the shape.
*
* <p>For example, an array with shape [1, 10, 1] returns true. Array with indeterminate size -1
* returns false.
*
* @return if the array is rank-1
*/
public boolean isRankOne() {
int max = 1;
int ans = 1;
for (long size : shape) {
max = (int) Math.max(max, size);
ans *= size;
if (ans < 0) {
return false;
}
}
return max == ans;
}
}