-
Notifications
You must be signed in to change notification settings - Fork 685
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Bulk Batch creation * execute "take" instead of "pick" for mxnet if requested * more convenience methods to retrieve a sub dataset * Update api/src/main/java/ai/djl/training/dataset/BulkDataIterable.java * revert initial take implementation * use new take implementation * use NDIndexFullTake approach * added test for BulkDataIterable * tidy up test for LinearCollection * test isRange * add index test and clean code * add warning Co-authored-by: KexinFeng <fengx463@umn.edu>
- Loading branch information
Showing
20 changed files
with
797 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
91 changes: 91 additions & 0 deletions
91
api/src/main/java/ai/djl/ndarray/index/full/NDIndexFullTake.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
/* | ||
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.ndarray.index.full; | ||
|
||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.ndarray.index.NDIndex; | ||
import ai.djl.ndarray.index.dim.NDIndexAll; | ||
import ai.djl.ndarray.index.dim.NDIndexElement; | ||
import ai.djl.ndarray.index.dim.NDIndexTake; | ||
import ai.djl.ndarray.types.Shape; | ||
|
||
import java.util.Optional; | ||
|
||
/** A simplified representation of a take-based {@link NDIndex}. */ | ||
public final class NDIndexFullTake { | ||
|
||
private NDArray indices; | ||
private int axis; | ||
|
||
/** | ||
* Constructs a new {@link NDIndexFullTake}. | ||
* | ||
* @param indices the indices to take | ||
* @param axis the axis to take at | ||
*/ | ||
private NDIndexFullTake(NDArray indices, int axis) { | ||
this.indices = indices; | ||
this.axis = axis; | ||
} | ||
|
||
/** | ||
* Returns (if possible) the {@link NDIndexFullTake} representation of an {@link NDIndex}. | ||
* | ||
* @param index the index to represent | ||
* @param target the shape of the array to index | ||
* @return the full take representation or nothing if it can't represent the index | ||
*/ | ||
public static Optional<NDIndexFullTake> fromIndex(NDIndex index, Shape target) { | ||
int axis = 0; | ||
NDIndexFullTake fullTake = null; | ||
for (NDIndexElement el : index.getIndices()) { | ||
if (el instanceof NDIndexAll) { | ||
axis++; | ||
} else if (el instanceof NDIndexTake) { | ||
if (fullTake != null) { | ||
// Don't support multiple takes | ||
throw new UnsupportedOperationException( | ||
"Only one take per get is currently supported"); | ||
} | ||
NDArray indexElem = ((NDIndexTake) el).getIndex(); | ||
if (!indexElem.getShape().isRankOne()) { | ||
throw new UnsupportedOperationException( | ||
"Only rank-1 indexing array is supported for take"); | ||
} | ||
fullTake = new NDIndexFullTake(indexElem, axis); | ||
} else { | ||
// Invalid dim for fullTake | ||
return Optional.empty(); | ||
} | ||
} | ||
return Optional.ofNullable(fullTake); | ||
} | ||
|
||
/** | ||
* Returns the indices to take. | ||
* | ||
* @return the indices to take | ||
*/ | ||
public NDArray getIndices() { | ||
return indices; | ||
} | ||
|
||
/** | ||
* Returns the axis to take. | ||
* | ||
* @return the axis to take | ||
*/ | ||
public int getAxis() { | ||
return axis; | ||
} | ||
} |
Oops, something went wrong.