Skip to content

Commit

Permalink
Return Batches with ArrayDataset helpers (#1995)
Browse files: browse the repository at this point in the history
Fixes a minor issue with #1869.
  • Loading branch information
zachgk authored Sep 6, 2022
1 parent 5def6e3 commit bce07da
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 17 deletions.
33 changes: 21 additions & 12 deletions api/src/main/java/ai/djl/training/dataset/ArrayDataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ public Record get(NDManager manager, long index) {
}

/**
* Gets the {@link Record} for the given indices from the dataset.
* Gets the {@link Batch} for the given indices from the dataset.
*
* @param manager the manager used to create the arrays
* @param indices indices of the requested data items
* @return a {@link Record} that contains the data and label of the requested data items
* @return a {@link Batch} that contains the data and label of the requested data items
*/
public Record getByIndices(NDManager manager, long... indices) {
public Batch getByIndices(NDManager manager, long... indices) {
try (NDArray ndIndices = manager.create(indices)) {
NDIndex index = new NDIndex("{}", ndIndices);
NDList datum = new NDList();
Expand All @@ -122,19 +122,27 @@ public Record getByIndices(NDManager manager, long... indices) {
label.add(array.get(manager, index));
}
}
return new Record(datum, label);
return new Batch(
manager,
datum,
label,
indices.length,
Batchifier.STACK,
Batchifier.STACK,
-1,
-1);
}
}

/**
* Gets the {@link Record} for the given range from the dataset.
* Gets the {@link Batch} for the given range from the dataset.
*
* @param manager the manager used to create the arrays
* @param fromIndex low endpoint (inclusive) of the dataset
* @param toIndex high endpoint (exclusive) of the dataset
* @return a {@link Record} that contains the data and label of the requested data items
* @return a {@link Batch} that contains the data and label of the requested data items
*/
public Record getByRange(NDManager manager, long fromIndex, long toIndex) {
public Batch getByRange(NDManager manager, long fromIndex, long toIndex) {
NDIndex index = new NDIndex().addSliceDim(fromIndex, toIndex);
NDList datum = new NDList();
NDList label = new NDList();
Expand All @@ -146,7 +154,8 @@ public Record getByRange(NDManager manager, long fromIndex, long toIndex) {
label.add(array.get(manager, index));
}
}
return new Record(datum, label);
int size = Math.toIntExact(toIndex - fromIndex);
return new Batch(manager, datum, label, size, Batchifier.STACK, Batchifier.STACK, -1, -1);
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -277,7 +286,7 @@ public Record get(NDManager manager, long index) {

/** {@inheritDoc} */
@Override
public Record getByIndices(NDManager manager, long... indices) {
public Batch getByIndices(NDManager manager, long... indices) {
long[] resolvedIndices = new long[indices.length];
int i = 0;
for (long index : indices) {
Expand All @@ -288,7 +297,7 @@ public Record getByIndices(NDManager manager, long... indices) {

/** {@inheritDoc} */
@Override
public Record getByRange(NDManager manager, long fromIndex, long toIndex) {
public Batch getByRange(NDManager manager, long fromIndex, long toIndex) {
return dataset.getByRange(manager, fromIndex + from, toIndex + from);
}

Expand Down Expand Up @@ -330,7 +339,7 @@ public Record get(NDManager manager, long index) {

/** {@inheritDoc} */
@Override
public Record getByIndices(NDManager manager, long... indices) {
public Batch getByIndices(NDManager manager, long... indices) {
long[] resolvedIndices = new long[indices.length];
int i = 0;
for (long index : indices) {
Expand All @@ -341,7 +350,7 @@ public Record getByIndices(NDManager manager, long... indices) {

/** {@inheritDoc} */
@Override
public Record getByRange(NDManager manager, long fromIndex, long toIndex) {
public Batch getByRange(NDManager manager, long fromIndex, long toIndex) {
long[] resolvedIndices = new long[(int) (toIndex - fromIndex)];
int i = 0;
for (long index = fromIndex; index < toIndex; index++) {
Expand Down
10 changes: 5 additions & 5 deletions api/src/main/java/ai/djl/training/dataset/BulkDataIterable.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,23 +73,23 @@ protected Batch fetch(List<Long> indices, int progress) throws IOException {
subManager.setName("dataIter fetch");
int batchSize = indices.size();

Record record;
Batch raw;
if (isRange(indices)) {
long fromIndex = indices.get(0);
long toIndex = fromIndex + indices.size();
record = ((ArrayDataset) dataset).getByRange(subManager, fromIndex, toIndex);
raw = ((ArrayDataset) dataset).getByRange(subManager, fromIndex, toIndex);
} else {
long[] indicesArr = indices.stream().mapToLong(Long::longValue).toArray();
record = ((ArrayDataset) dataset).getByIndices(subManager, indicesArr);
raw = ((ArrayDataset) dataset).getByIndices(subManager, indicesArr);
}

NDList batchData = record.getData();
NDList batchData = raw.getData();
// apply transform
if (pipeline != null) {
batchData = pipeline.transform(batchData);
}

NDList batchLabels = record.getLabels();
NDList batchLabels = raw.getLabels();

// apply label transform
if (targetPipeline != null) {
Expand Down

0 comments on commit bce07da

Please sign in to comment.