Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pytorch] Advanced indexing that supports all indexing features on PyTorch #1719

Merged
merged 15 commits into from
Jun 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies {
}
testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
testRuntimeOnly project(":engines:pytorch:pytorch-model-zoo")
testRuntimeOnly project(":engines:pytorch:pytorch-jni")
}

javadoc {
Expand Down
6 changes: 1 addition & 5 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -530,11 +530,7 @@ default NDArray get(NDManager manager, NDIndex index) {
* @return the partial {@code NDArray}
*/
default NDArray get(NDArray index) {
if (index.getDataType() == DataType.BOOLEAN) {
return get(new NDIndex().addBooleanIndex(index));
} else {
return take(index);
}
return get(new NDIndex("{}", index));
}

/**
Expand Down
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/ndarray/NDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ static NDManager subManagerOf(NDResource resource) {
ByteBuffer allocateDirect(int capacity);

/**
* Creates a new {@code NDArray} if the input {@link NDArray} is from external engine.
* Creates a new {@code NDArray} if the input {@link NDArray} is from an external engine.
*
* @param array the input {@code NDArray}
* @return a new {@code NDArray} if the input {@code NDArray} is from external engine
Expand Down
18 changes: 16 additions & 2 deletions api/src/main/java/ai/djl/ndarray/index/NDIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexFixed;
import ai.djl.ndarray.index.dim.NDIndexNull;
import ai.djl.ndarray.index.dim.NDIndexPick;
import ai.djl.ndarray.index.dim.NDIndexSlice;
import ai.djl.ndarray.index.dim.NDIndexTake;
import ai.djl.ndarray.types.DataType;

import java.util.ArrayList;
Expand Down Expand Up @@ -50,7 +52,7 @@ public class NDIndex {
/* Android regex requires escape } char as well */
private static final Pattern ITEM_PATTERN =
Pattern.compile(
"(\\*)|((-?\\d+|\\{\\})?:(-?\\d+|\\{\\})?(:(-?\\d+|\\{\\}))?)|(-?\\d+|\\{\\})");
"(\\*)|((-?\\d+|\\{\\})?:(-?\\d+|\\{\\})?(:(-?\\d+|\\{\\}))?)|(-?\\d+|\\{\\})|null");

private int rank;
private List<NDIndexElement> indices;
Expand Down Expand Up @@ -105,6 +107,10 @@ public NDIndex() {
*
* // Uses ellipsis to select all the dimensions except for last axis where we only get a subsection.
* assertEquals(a.get(new NDIndex("..., 2")).getShape(), new Shape(5, 4));
*
* // Uses null to add an extra axis to the output array
* assertEquals(a.get(new NDIndex(":2, null, 0, :2")).getShape(), new Shape(2, 1, 2));
*
* </pre>
*
* @param indices a comma separated list of indices corresponding to either subsections,
Expand Down Expand Up @@ -335,6 +341,11 @@ private int addIndexItem(String indexItem, int argIndex, Object[] args) {
if (!m.matches()) {
throw new IllegalArgumentException("Invalid argument index: " + indexItem);
}
// "null" case
if ("null".equals(indexItem)) {
indices.add(new NDIndexNull());
return argIndex;
}
// "*" case
String star = m.group(1);
if (star != null) {
Expand All @@ -358,9 +369,12 @@ private int addIndexItem(String indexItem, int argIndex, Object[] args) {
indices.add(new NDIndexBooleans(array));
return argIndex + 1;
} else if (array.getDataType().isInteger()) {
indices.add(new NDIndexPick(array));
indices.add(new NDIndexTake(array));
return argIndex + 1;
}
} else if (arg == null) {
indices.add(new NDIndexNull());
return argIndex + 1;
}
throw new IllegalArgumentException("Unknown argument: " + arg);
} else {
Expand Down
23 changes: 23 additions & 0 deletions api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNull.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray.index.dim;

/** An {@code NDIndexElement} to return all values in a particular dimension. */
public class NDIndexNull implements NDIndexElement {

/** {@inheritDoc} */
@Override
public int getRank() {
return 1;
}
}
16 changes: 8 additions & 8 deletions api/src/main/java/ai/djl/ndarray/index/dim/NDIndexPick.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
/** An {@link NDIndexElement} that gets elements by index in the specified axis. */
public class NDIndexPick implements NDIndexElement {

private NDArray indices;
private NDArray index;

/**
* Constructs a pick.
*
* @param indices the indices to pick
* @param index the index to pick
*/
public NDIndexPick(NDArray indices) {
this.indices = indices;
public NDIndexPick(NDArray index) {
this.index = index;
}

/** {@inheritDoc} */
Expand All @@ -35,11 +35,11 @@ public int getRank() {
}

/**
* Returns the indices to pick.
* Returns the index to pick.
*
* @return the indices to pick
* @return the index to pick
*/
public NDArray getIndices() {
return indices;
public NDArray getIndex() {
return index;
}
}
45 changes: 45 additions & 0 deletions api/src/main/java/ai/djl/ndarray/index/dim/NDIndexTake.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray.index.dim;

import ai.djl.ndarray.NDArray;

/** An {@link NDIndexElement} that gets elements by index in the specified axis. */
public class NDIndexTake implements NDIndexElement {

private NDArray index;

/**
* Constructs a pick.
*
* @param index the index to pick
*/
public NDIndexTake(NDArray index) {
this.index = index;
}

/** {@inheritDoc} */
@Override
public int getRank() {
return 1;
}

/**
* Returns the index to pick.
*
* @return the index to pick
*/
public NDArray getIndex() {
return index;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public static Optional<NDIndexFullPick> fromIndex(NDIndex index, Shape target) {
axis++;
} else if (el instanceof NDIndexPick) {
if (fullPick == null) {
fullPick = new NDIndexFullPick(((NDIndexPick) el).getIndices(), axis);
fullPick = new NDIndexFullPick(((NDIndexPick) el).getIndex(), axis);
} else {
// Don't support multiple picks
throw new UnsupportedOperationException(
Expand Down
1 change: 1 addition & 0 deletions engines/dlr/dlr-engine/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies {
}
testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
testRuntimeOnly project(":engines:pytorch:pytorch-engine")
testRuntimeOnly project(":engines:pytorch:pytorch-jni")
}

compileJava.dependsOn(processResources)
Expand Down
1 change: 1 addition & 0 deletions engines/paddlepaddle/paddlepaddle-model-zoo/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies {
testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
testImplementation(project(":testing"))
testRuntimeOnly project(":engines:pytorch:pytorch-engine")
testRuntimeOnly project(":engines:pytorch:pytorch-jni")
}

task syncS3(type: Exec) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.index.full.NDIndexFullSlice;
Expand Down Expand Up @@ -49,6 +50,28 @@ public NDArray get(NDArray array, NDIndexFullSlice fullSlice) {
}
}

/** {@inheritDoc} */
@Override
public NDArray get(NDArray array, NDIndex index) {
if (index.getRank() == 0) {
if (array.getShape().isScalar()) {
return array.getManager() == manager
? array.duplicate()
: manager.create(
array.toByteBuffer(), array.getShape(), array.getDataType());
}
index.addAllDim();
}

if (array == null || array instanceof PtNDArray && array.getManager() == manager) {
return JniUtils.indexAdv((PtNDArray) array, index);
} else {
PtNDArray arrayNew =
manager.create(array.toByteBuffer(), array.getShape(), array.getDataType());
return JniUtils.indexAdv(arrayNew, index);
}
}

/** {@inheritDoc} */
@Override
public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.dim.NDIndexAll;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexFixed;
import ai.djl.ndarray.index.dim.NDIndexNull;
import ai.djl.ndarray.index.dim.NDIndexPick;
import ai.djl.ndarray.index.dim.NDIndexSlice;
import ai.djl.ndarray.index.dim.NDIndexTake;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
Expand All @@ -35,6 +45,8 @@
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Set;

/**
Expand Down Expand Up @@ -337,6 +349,69 @@ public static PtNDArray index(
ndArray.getHandle(), minIndices, maxIndices, stepIndices));
}

@SuppressWarnings("OptionalGetWithoutIsPresent")
public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index) {
if (ndArray == null) {
return ndArray;
}
List<NDIndexElement> indices = index.getIndices();
long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size());
ListIterator<NDIndexElement> it = indices.listIterator();
while (it.hasNext()) {
if (it.nextIndex() == index.getEllipsisIndex()) {
PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
}

NDIndexElement elem = it.next();
if (elem instanceof NDIndexNull) {
PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false);
} else if (elem instanceof NDIndexSlice) {
Long min = ((NDIndexSlice) elem).getMin();
Long max = ((NDIndexSlice) elem).getMax();
Long step = ((NDIndexSlice) elem).getStep();
int nullSliceBin = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0);
// nullSliceBin encodes whether the slice (min, max) is null:
// is_null == 1, ! is_null == 0;
// 0b11 == 3, 0b10 = 2, ...
PyTorchLibrary.LIB.torchIndexAppendSlice(
torchIndexHandle,
min == null ? 0 : min,
max == null ? 0 : max,
step == null ? 1 : step,
nullSliceBin);
} else if (elem instanceof NDIndexAll) {
PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, 0, 0, 1, 3);
} else if (elem instanceof NDIndexFixed) {
PyTorchLibrary.LIB.torchIndexAppendFixed(
torchIndexHandle, ((NDIndexFixed) elem).getIndex());
} else if (elem instanceof NDIndexBooleans) {
PtNDArray indexArr = (PtNDArray) ((NDIndexBooleans) elem).getIndex();
PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle());
} else if (elem instanceof NDIndexTake) {
PtNDArray indexArr = (PtNDArray) ((NDIndexTake) elem).getIndex();
if (indexArr.getDataType() != DataType.INT64) {
indexArr = indexArr.toType(DataType.INT64, true);
}
PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle());
} else if (elem instanceof NDIndexPick) {
// Backward compatible
NDIndexFullPick fullPick =
NDIndexFullPick.fromIndex(index, ndArray.getShape()).get();
return pick(
ndArray,
ndArray.getManager().from(fullPick.getIndices()),
fullPick.getAxis());
}
}
if (indices.size() == index.getEllipsisIndex()) {
PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
}

return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchIndexReturn(ndArray.getHandle(), torchIndexHandle));
}

public static void indexSet(
PtNDArray ndArray,
PtNDArray value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -600,4 +600,17 @@ native void sgdUpdate(
native long torchNorm(long handle, int ord, long[] axis, boolean keepDims);

native long torchNonZeros(long handle);

native long torchIndexInit(int size);

native long torchIndexReturn(long handle, long torchIndexHandle);

native void torchIndexAppendNoneEllipsis(long torchIndexHandle, boolean isEllipsis);

native void torchIndexAppendSlice(
long torchIndexHandle, long min, long max, long step, int nullSliceBinary);

native void torchIndexAppendFixed(long torchIndexHandle, long idx);

native void torchIndexAppendArray(long torchIndexHandle, long arrayHandle);
}
Loading