Skip to content

Commit

Permalink
Bulk Batch creation (#1869)
Browse files Browse the repository at this point in the history
* Bulk Batch creation

* execute "take" instead of "pick" for mxnet if requested

* more convenience methods to retrieve a sub dataset

* Update api/src/main/java/ai/djl/training/dataset/BulkDataIterable.java

* revert initial take implementation

* use new take implementation

* use NDIndexFullTake approach

* added test for BulkDataIterable

* tidy up test for LinearCollection

* test isRange

* add index test and clean code

* add warning

Co-authored-by: KexinFeng <fengx463@umn.edu>
  • Loading branch information
patins1 and KexinFeng authored Sep 1, 2022
1 parent eb0d1a9 commit e9ae8d3
Show file tree
Hide file tree
Showing 20 changed files with 797 additions and 92 deletions.
14 changes: 13 additions & 1 deletion api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,19 @@ default NDArray get(NDManager manager, long... indices) {
* @param index picks the elements of an NDArray and output to the same entry as in index
* @return the partial {@code NDArray} of the same shape as index
*/
NDArray take(NDArray index);
default NDArray take(NDArray index) {
    // Delegate to the manager-aware overload, supplying this array's own manager.
    NDManager ownManager = getManager();
    return take(ownManager, index);
}

/**
* Returns a partial {@code NDArray} pointed by index according to linear indexing; the output
* has the same shape as index.
*
* @param manager the manager used to create the arrays
* @param index picks the elements of an NDArray and outputs them to the same entry as in index
* @return the partial {@code NDArray} of the same shape as index
*/
NDArray take(NDManager manager, NDArray index);

/**
* Set the entries of {@code NDArray} pointed by index according to linear indexing, to be the
Expand Down
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ public NDArray gather(NDArray index, int axis) {

/** {@inheritDoc} */
@Override
public NDArray take(NDArray index) {
public NDArray take(NDManager manager, NDArray index) {
// This adapter does not implement take: the call is always rejected.
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

Expand Down
23 changes: 23 additions & 0 deletions api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.index.full.NDIndexFullSlice;
import ai.djl.ndarray.index.full.NDIndexFullTake;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.Optional;
Expand All @@ -34,6 +38,15 @@ public abstract class NDArrayIndexer {
*/
public abstract NDArray get(NDArray array, NDIndexFullPick fullPick);

/**
* Returns a subarray by taking the elements from one axis.
*
* @param array the array to get from
* @param fullTake the elements to take
* @return the subArray
*/
public abstract NDArray get(NDArray array, NDIndexFullTake fullTake);

/**
* Returns a subarray at the slice.
*
Expand Down Expand Up @@ -65,6 +78,16 @@ public NDArray get(NDArray array, NDIndex index) {
return array.booleanMask(((NDIndexBooleans) indices.get(0)).getIndex());
}

Optional<NDIndexFullTake> fullTake = NDIndexFullTake.fromIndex(index, array.getShape());
if (fullTake.isPresent()) {
Logger logger = LoggerFactory.getLogger(NDArrayIndexer.class);
logger.warn(
"The definition of the getter by array NDIndex: get(NDIndex array) has changed"
+ " from pick to take.If you still want to use array index as pick, then do"
+ " it explicitly by get(new NDIndex().addPickDim(array));");
return get(array, fullTake.get());
}

Optional<NDIndexFullPick> fullPick = NDIndexFullPick.fromIndex(index, array.getShape());
if (fullPick.isPresent()) {
return get(array, fullPick.get());
Expand Down
13 changes: 0 additions & 13 deletions api/src/main/java/ai/djl/ndarray/index/full/NDIndexFullPick.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import ai.djl.ndarray.index.dim.NDIndexAll;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexPick;
import ai.djl.ndarray.index.dim.NDIndexTake;
import ai.djl.ndarray.types.Shape;

import java.util.Optional;
Expand Down Expand Up @@ -60,18 +59,6 @@ public static Optional<NDIndexFullPick> fromIndex(NDIndex index, Shape target) {
}
NDArray indexElem = ((NDIndexPick) el).getIndex();
fullPick = new NDIndexFullPick(indexElem, axis);
} else if (el instanceof NDIndexTake) {
if (fullPick != null) {
// Don't support multiple picks
throw new UnsupportedOperationException(
"Only one pick per get is currently supported");
}
NDArray indexElem = ((NDIndexTake) el).getIndex();
if (!indexElem.getShape().isRankOne()) {
throw new UnsupportedOperationException(
"Only rank-1 indexing array is supported for pick");
}
fullPick = new NDIndexFullPick(indexElem, axis);
} else {
// Invalid dim for fullPick
return Optional.empty();
Expand Down
91 changes: 91 additions & 0 deletions api/src/main/java/ai/djl/ndarray/index/full/NDIndexFullTake.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray.index.full;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.dim.NDIndexAll;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexTake;
import ai.djl.ndarray.types.Shape;

import java.util.Optional;

/**
 * A simplified representation of a take-based {@link NDIndex}.
 *
 * <p>Holds the index array of a single {@link NDIndexTake} element together with its axis, where
 * the axis is the number of {@link NDIndexAll} elements that precede the take element in the
 * index. Instances are created only through {@link #fromIndex(NDIndex, Shape)} and both fields
 * are assigned exactly once.
 */
public final class NDIndexFullTake {

    private final NDArray indices;
    private final int axis;

    /**
     * Constructs a new {@link NDIndexFullTake}.
     *
     * @param indices the indices to take
     * @param axis the axis to take at
     */
    private NDIndexFullTake(NDArray indices, int axis) {
        this.indices = indices;
        this.axis = axis;
    }

    /**
     * Returns (if possible) the {@link NDIndexFullTake} representation of an {@link NDIndex}.
     *
     * <p>The index is representable when it consists solely of {@link NDIndexAll} elements and
     * exactly one {@link NDIndexTake} element whose index array is rank one. Any other kind of
     * element makes the index non-representable and yields an empty result, as does an index
     * that contains no take element at all.
     *
     * @param index the index to represent
     * @param target the shape of the array to index (not consulted by this implementation)
     * @return the full take representation or nothing if it can't represent the index
     * @throws UnsupportedOperationException if the index contains more than one take element,
     *     or if the take element's index array is not rank one
     */
    public static Optional<NDIndexFullTake> fromIndex(NDIndex index, Shape target) {
        int axis = 0;
        NDIndexFullTake fullTake = null;
        for (NDIndexElement el : index.getIndices()) {
            if (el instanceof NDIndexAll) {
                // An "all" element leaves its dimension untouched; move on to the next axis.
                axis++;
            } else if (el instanceof NDIndexTake) {
                if (fullTake != null) {
                    // Don't support multiple takes
                    throw new UnsupportedOperationException(
                            "Only one take per get is currently supported");
                }
                NDArray indexElem = ((NDIndexTake) el).getIndex();
                if (!indexElem.getShape().isRankOne()) {
                    throw new UnsupportedOperationException(
                            "Only rank-1 indexing array is supported for take");
                }
                fullTake = new NDIndexFullTake(indexElem, axis);
            } else {
                // Invalid dim for fullTake
                return Optional.empty();
            }
        }
        // Still null when the index held no take element; ofNullable maps that to empty.
        return Optional.ofNullable(fullTake);
    }

    /**
     * Returns the indices to take.
     *
     * @return the indices to take
     */
    public NDArray getIndices() {
        return indices;
    }

    /**
     * Returns the axis to take.
     *
     * @return the axis to take
     */
    public int getAxis() {
        return axis;
    }
}
Loading

0 comments on commit e9ae8d3

Please sign in to comment.