Skip to content

Commit

Permalink
[api] Fixes topK items for DetectedObjects and make it configurable t…
Browse files Browse the repository at this point in the history
…o Classifications

Change-Id: I4ffd29cff5205f32ee18afa5269d70491634936e
  • Loading branch information
frankfliu committed Jan 20, 2022
1 parent 3a67b8e commit 4250ad5
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 3 deletions.
38 changes: 36 additions & 2 deletions api/src/main/java/ai/djl/modality/Classifications.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public class Classifications implements JsonSerializable {

protected List<String> classNames;
protected List<Double> probabilities;
private int topK;

/**
* Constructs a {@code Classifications} using a parallel list of classNames and probabilities.
Expand All @@ -55,6 +56,7 @@ public class Classifications implements JsonSerializable {
public Classifications(List<String> classNames, List<Double> probabilities) {
this.classNames = classNames;
this.probabilities = probabilities;
this.topK = 5;
}

/**
Expand All @@ -65,11 +67,33 @@ public Classifications(List<String> classNames, List<Double> probabilities) {
* @param probabilities the probabilities for each class for the input
*/
public Classifications(List<String> classNames, NDArray probabilities) {
this(classNames, probabilities, 5);
}

/**
* Constructs a {@code Classifications} using list of classNames parallel to an NDArray of
* probabilities.
*
* @param classNames the names of the classes
* @param probabilities the probabilities for each class for the input
* @param topK the number of top classes to return
*/
public Classifications(List<String> classNames, NDArray probabilities, int topK) {
this.classNames = classNames;
NDArray array = probabilities.toType(DataType.FLOAT64, false);
this.probabilities =
Arrays.stream(array.toDoubleArray()).boxed().collect(Collectors.toList());
array.close();
this.topK = topK;
}

/**
* Set the topK number of classes to be displayed.
*
* @param topK the number of top classes to return
*/
public final void setTopK(int topK) {
this.topK = topK;
}

/**
Expand Down Expand Up @@ -99,6 +123,16 @@ public <T extends Classification> T item(int index) {
return (T) new Classification(classNames.get(index), probabilities.get(index));
}

/**
* Returns a list of the top classes.
*
* @param <T> the type of the classification item for the task
* @return the list of classification items for the best classes in order of best to worst
*/
public <T extends Classification> List<T> topK() {
return topK(topK);
}

/**
* Returns a list of the top {@code k} best classes.
*
Expand Down Expand Up @@ -163,7 +197,7 @@ public ByteBuffer toByteBuffer() {
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append('[').append(System.lineSeparator());
for (Classification item : topK(5)) {
for (Classification item : topK(topK)) {
sb.append('\t').append(item).append(System.lineSeparator());
}
sb.append(']');
Expand Down Expand Up @@ -227,7 +261,7 @@ public static final class ClassificationsSerializer implements JsonSerializer<Cl
/** {@inheritDoc} */
@Override
public JsonElement serialize(Classifications src, Type type, JsonSerializationContext ctx) {
List<?> list = src.topK(5);
List<?> list = src.topK();
return ctx.serialize(list);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public DetectedObjects(
List<String> classNames, List<Double> probabilities, List<BoundingBox> boundingBoxes) {
super(classNames, probabilities);
this.boundingBoxes = boundingBoxes;
setTopK(Integer.MAX_VALUE);
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class ImageClassificationTranslator extends BaseImageTranslator<Classific

private SynsetLoader synsetLoader;
private boolean applySoftmax;
private int topK;

private List<String> classes;

Expand All @@ -39,6 +40,7 @@ public ImageClassificationTranslator(Builder builder) {
super(builder);
this.synsetLoader = builder.synsetLoader;
this.applySoftmax = builder.applySoftmax;
this.topK = builder.topK;
}

/** {@inheritDoc} */
Expand All @@ -56,7 +58,7 @@ public Classifications processOutput(TranslatorContext ctx, NDList list) {
if (applySoftmax) {
probabilitiesNd = probabilitiesNd.softmax(0);
}
return new Classifications(classes, probabilitiesNd);
return new Classifications(classes, probabilitiesNd, topK);
}

/**
Expand Down Expand Up @@ -85,9 +87,21 @@ public static Builder builder(Map<String, ?> arguments) {
public static class Builder extends ClassificationBuilder<Builder> {

private boolean applySoftmax;
private int topK = 5;

Builder() {}

/**
* Set the topK number of classes to be displayed.
*
* @param topK the number of top classes to return
* @return the builder
*/
public Builder optTopK(int topK) {
this.topK = topK;
return this;
}

/**
* Sets whether to apply softmax when processing output. Some models already include softmax
* in the last layer, so don't apply softmax when processing model output.
Expand All @@ -111,6 +125,7 @@ protected Builder self() {
protected void configPostProcess(Map<String, ?> arguments) {
super.configPostProcess(arguments);
applySoftmax = ArgumentsUtil.booleanValue(arguments, "applySoftmax");
topK = ArgumentsUtil.intValue(arguments, "topK", 5);
}

/**
Expand Down

0 comments on commit 4250ad5

Please sign in to comment.