Skip to content

Commit

Permalink
Set array with advanced index on PyTorch
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Jul 1, 2022
1 parent 33ce569 commit e68031b
Show file tree
Hide file tree
Showing 11 changed files with 209 additions and 106 deletions.
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ default NDArray get(NDManager manager, long... indices) {
*
* @param index select the entries of an {@code NDArray}
* @param data numbers to assign to the indexed entries
* @return The NDArray with updated values
* @return the NDArray with updated values
*/
default NDArray put(NDArray index, NDArray data) {
throw new UnsupportedOperationException("Not implemented yet.");
Expand Down
99 changes: 45 additions & 54 deletions api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,89 +79,80 @@ public NDArray get(NDArray array, NDIndex index) {
}

/**
* Sets the values of the array at the fullSlice with an array.
*
* @param array the array to set
* @param fullSlice the fullSlice of the index to set in the array
* @param value the value to set with
*/
public abstract void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value);

/**
* Sets the values of the array at the boolean locations with an array.
*
* @param array the array to set
* @param indices a boolean array where true indicates values to update
* @param value the value to set with when condition is true
*/
public void set(NDArray array, NDIndexBooleans indices, NDArray value) {
array.intern(NDArrays.where(indices.getIndex(), value, array));
}

/**
* Sets the values of the array at the index locations with an array.
* Sets the entries of array at the indexed locations with the parameter value. The value can be
* only Number or NDArray.
*
* @param array the array to set
* @param index the index to set at in the array
* @param value the value to set with
*/
public void set(NDArray array, NDIndex index, NDArray value) {
public void set(NDArray array, NDIndex index, Object value) {
NDIndexFullSlice fullSlice =
NDIndexFullSlice.fromIndex(index, array.getShape()).orElse(null);
if (fullSlice != null) {
if (value instanceof Number) {
set(array, fullSlice, (Number) value);
} else if (value instanceof NDArray) {
set(array, fullSlice, (NDArray) value);
} else {
throw new IllegalArgumentException(
"The type of value to assign cannot be other than NDArray and Number.");
}
return;
}

List<NDIndexElement> indices = index.getIndices();
if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) {
if (indices.size() != 1) {
throw new IllegalArgumentException(
"get() currently didn't support more that one boolean NDArray");
"set() currently doesn't support more than one boolean NDArray");
}
if (value instanceof Number) {
set(
array,
(NDIndexBooleans) indices.get(0),
array.getManager().create((Number) value));
} else if (value instanceof NDArray) {
set(array, (NDIndexBooleans) indices.get(0), (NDArray) value);
} else {
throw new IllegalArgumentException(
"The type of value to assign cannot be other than NDArray and Number.");
}
set(array, (NDIndexBooleans) indices.get(0), value);
}

NDIndexFullSlice fullSlice =
NDIndexFullSlice.fromIndex(index, array.getShape()).orElse(null);
if (fullSlice != null) {
set(array, fullSlice, value);
return;
}
throw new UnsupportedOperationException(
"set() currently supports all, fixed, and slices indices");
}

/**
* Sets the values of the array at the fullSlice with a number.
* Sets the values of the array at the fullSlice with an array.
*
* @param array the array to set
* @param fullSlice the fullSlice of the index to set in the array
* @param value the value to set with
*/
public abstract void set(NDArray array, NDIndexFullSlice fullSlice, Number value);
public abstract void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value);

/**
* Sets the values of the array at the index locations with a number.
* Sets the values of the array at the boolean locations with an array.
*
* @param array the array to set
* @param index the index to set at in the array
* @param value the value to set with
* @param indices a boolean array where true indicates values to update
* @param value the value to set with when condition is true
*/
public void set(NDArray array, NDIndex index, Number value) {
NDIndexFullSlice fullSlice =
NDIndexFullSlice.fromIndex(index, array.getShape()).orElse(null);
if (fullSlice != null) {
set(array, fullSlice, value);
return;
}
// use booleanMask for NDIndexBooleans case
List<NDIndexElement> indices = index.getIndices();
if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) {
if (indices.size() != 1) {
throw new IllegalArgumentException(
"set() currently didn't support more that one boolean NDArray");
}
set(array, (NDIndexBooleans) indices.get(0), array.getManager().create(value));
return;
}
throw new UnsupportedOperationException(
"set() currently supports all, fixed, and slices indices");
public void set(NDArray array, NDIndexBooleans indices, NDArray value) {
array.intern(NDArrays.where(indices.getIndex(), value, array));
}

/**
* Sets the values of the array at the fullSlice with a number.
*
* @param array the array to set
* @param fullSlice the fullSlice of the index to set in the array
* @param value the value to set with
*/
public abstract void set(NDArray array, NDIndexFullSlice fullSlice, Number value);

/**
* Sets a scalar value in the array at the indexed location.
*
Expand Down
27 changes: 16 additions & 11 deletions api/src/main/java/ai/djl/ndarray/index/NDIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,44 +73,49 @@ public NDIndex() {
* <pre>
* NDArray a = manager.ones(new Shape(5, 4, 3));
*
* // Gets a subsection of the NDArray in the first axis.
* // Get a subsection of the NDArray in the first axis.
* assertEquals(a.get(new NDIndex("2")).getShape(), new Shape(4, 3));
*
* // Gets a subsection of the NDArray indexing from the end (-i == length - i).
* // Get a subsection of the NDArray indexing from the end (-i == length - i).
* assertEquals(a.get(new NDIndex("-1")).getShape(), new Shape(4, 3));
*
* // Gets everything in the first axis and a subsection in the second axis.
* // Get everything in the first axis and a subsection in the second axis.
* // You can use either : or * to represent everything
* assertEquals(a.get(new NDIndex(":, 2")).getShape(), new Shape(5, 3));
* assertEquals(a.get(new NDIndex("*, 2")).getShape(), new Shape(5, 3));
*
* // Gets a range of values along the second axis that is inclusive on the bottom and exclusive on the top.
* // Get a range of values along the second axis that is inclusive on the bottom and exclusive on the top.
* assertEquals(a.get(new NDIndex(":, 1:3")).getShape(), new Shape(5, 2, 3));
*
* // Excludes either the min or the max of the range to go all the way to the beginning or end.
* // Exclude either the min or the max of the range to go all the way to the beginning or end.
* assertEquals(a.get(new NDIndex(":, :3")).getShape(), new Shape(5, 3, 3));
* assertEquals(a.get(new NDIndex(":, 1:")).getShape(), new Shape(5, 4, 3));
*
* // Uses the value after the second colon in a slicing range, the step, to get every other result.
* // Use the value after the second colon in a slicing range, the step, to get every other result.
* assertEquals(a.get(new NDIndex(":, 1::2")).getShape(), new Shape(5, 2, 3));
*
* // Uses a negative step to reverse along the dimension.
* // Use a negative step to reverse along the dimension.
* assertEquals(a.get(new NDIndex("-1")).getShape(), new Shape(5, 4, 3));
*
* // Uses a variable argument to the index
* // Use a variable argument to the index
* // It can replace any number in any of these formats with {} and then the value of {}
* // is specified in an argument following the indices string.
* assertEquals(a.get(new NDIndex("{}, {}:{}", 0, 1, 3)).getShape(), new Shape(2, 3));
*
* // Uses ellipsis to insert many full slices
* // Use ellipsis to insert many full slices
* assertEquals(a.get(new NDIndex("...")).getShape(), new Shape(5, 4, 3));
*
* // Uses ellipsis to select all the dimensions except for last axis where we only get a subsection.
* // Use ellipsis to select all the dimensions except for last axis where we only get a subsection.
* assertEquals(a.get(new NDIndex("..., 2")).getShape(), new Shape(5, 4));
*
* // Uses null to add an extra axis to the output array
* // Use null to add an extra axis to the output array
* assertEquals(a.get(new NDIndex(":2, null, 0, :2")).getShape(), new Shape(2, 1, 2));
*
* // Get entries of an NDArray with mixed index
* index1 = manager.create(new long[] {0, 1, 1}, new Shape(2));
* bool1 = manager.create(new boolean[] {true, false, true});
* assertEquals(a.get(new NDIndex(":{}, {}, {}, {}" 2, index1, bool1, null).getShape(), new Shape(2, 2, 1));
*
* </pre>
*
* @param indices a comma separated list of indices corresponding to either subsections,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,7 @@ public NDArray take(NDArray index) {
return JniUtils.take(this, (PtNDArray) index);
}

/**
* {@inheritDoc}
*
* @return
*/
/** {@inheritDoc} */
@Override
public NDArray put(NDArray index, NDArray data) {
if (!(index instanceof PtNDArray) || !(data instanceof PtNDArray)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,25 @@ public NDArray get(NDArray array, NDIndex index) {
}
}

/** {@inheritDoc} */
@Override
public void set(NDArray array, NDIndex index, Object data) {
PtNDArray ptArray =
array instanceof PtNDArray
? (PtNDArray) array
: manager.create(
array.toByteBuffer(), array.getShape(), array.getDataType());

if (data instanceof Number) {
JniUtils.indexAdvPut(ptArray, index, (PtNDArray) manager.create((Number) data));
} else if (data instanceof NDArray) {
JniUtils.indexAdvPut(ptArray, index, (PtNDArray) data);
} else {
throw new IllegalArgumentException(
"The type of value to assign cannot be other than NDArray and Number.");
}
}

/** {@inheritDoc} */
@Override
public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,69 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index) {

return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchIndexReturn(ndArray.getHandle(), torchIndexHandle));
PyTorchLibrary.LIB.torchIndexAdvGet(ndArray.getHandle(), torchIndexHandle));
}

@SuppressWarnings("OptionalGetWithoutIsPresent")
public static void indexAdvPut(PtNDArray ndArray, NDIndex index, PtNDArray data) {
if (ndArray == null) {
return;
}

// Index aggregation
List<NDIndexElement> indices = index.getIndices();
long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size());
ListIterator<NDIndexElement> it = indices.listIterator();
while (it.hasNext()) {
if (it.nextIndex() == index.getEllipsisIndex()) {
PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
}

NDIndexElement elem = it.next();
if (elem instanceof NDIndexNull) {
PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false);
} else if (elem instanceof NDIndexSlice) {
Long min = ((NDIndexSlice) elem).getMin();
Long max = ((NDIndexSlice) elem).getMax();
Long step = ((NDIndexSlice) elem).getStep();
int nullSliceBin = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0);
// nullSliceBin encodes whether the slice (min, max) is null:
// is_null == 1, ! is_null == 0;
// 0b11 == 3, 0b10 = 2, ...
PyTorchLibrary.LIB.torchIndexAppendSlice(
torchIndexHandle,
min == null ? 0 : min,
max == null ? 0 : max,
step == null ? 1 : step,
nullSliceBin);
} else if (elem instanceof NDIndexAll) {
PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, 0, 0, 1, 3);
} else if (elem instanceof NDIndexFixed) {
PyTorchLibrary.LIB.torchIndexAppendFixed(
torchIndexHandle, ((NDIndexFixed) elem).getIndex());
} else if (elem instanceof NDIndexBooleans) {
PtNDArray indexArr = (PtNDArray) ((NDIndexBooleans) elem).getIndex();
PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle());
} else if (elem instanceof NDIndexTake) {
PtNDArray indexArr = (PtNDArray) ((NDIndexTake) elem).getIndex();
if (indexArr.getDataType() != DataType.INT64) {
indexArr = indexArr.toType(DataType.INT64, true);
}
PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle());
} else if (elem instanceof NDIndexPick) {
// Backward compatible
NDIndexFullPick fullPick =
NDIndexFullPick.fromIndex(index, ndArray.getShape()).get();
pick(ndArray, ndArray.getManager().from(fullPick.getIndices()), fullPick.getAxis());
return;
}
}
if (indices.size() == index.getEllipsisIndex()) {
PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
}

PyTorchLibrary.LIB.torchIndexAdvPut(
ndArray.getHandle(), torchIndexHandle, data.getHandle());
}

public static void indexSet(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ native void torchIndexPut(
long[] maxIndices,
long[] stepIndices);

native void torchIndexAdvPut(long handle, long torchIndexHandle, long data);

native void torchSet(long handle, ByteBuffer data);

native long torchSlice(long handle, long dim, long start, long end, long step);
Expand Down Expand Up @@ -605,7 +607,7 @@ native void sgdUpdate(

native long torchIndexInit(int size);

native long torchIndexReturn(long handle, long torchIndexHandle);
native long torchIndexAdvGet(long handle, long torchIndexHandle);

native void torchIndexAppendNoneEllipsis(long torchIndexHandle, boolean isEllipsis);

Expand Down
Loading

0 comments on commit e68031b

Please sign in to comment.