diff --git a/api/src/main/java/ai/djl/modality/Classifications.java b/api/src/main/java/ai/djl/modality/Classifications.java index 237fb49d83c..cb75c423976 100644 --- a/api/src/main/java/ai/djl/modality/Classifications.java +++ b/api/src/main/java/ai/djl/modality/Classifications.java @@ -45,6 +45,7 @@ public class Classifications implements JsonSerializable { protected List classNames; protected List probabilities; + private int topK; /** * Constructs a {@code Classifications} using a parallel list of classNames and probabilities. @@ -55,6 +56,7 @@ public class Classifications implements JsonSerializable { public Classifications(List classNames, List probabilities) { this.classNames = classNames; this.probabilities = probabilities; + this.topK = 5; } /** @@ -65,11 +67,33 @@ public Classifications(List classNames, List probabilities) { * @param probabilities the probabilities for each class for the input */ public Classifications(List classNames, NDArray probabilities) { + this(classNames, probabilities, 5); + } + + /** + * Constructs a {@code Classifications} using list of classNames parallel to an NDArray of + * probabilities. + * + * @param classNames the names of the classes + * @param probabilities the probabilities for each class for the input + * @param topK the number of top classes to return + */ + public Classifications(List classNames, NDArray probabilities, int topK) { this.classNames = classNames; NDArray array = probabilities.toType(DataType.FLOAT64, false); this.probabilities = Arrays.stream(array.toDoubleArray()).boxed().collect(Collectors.toList()); array.close(); + this.topK = topK; + } + + /** + * Set the topK number of classes to be displayed. + * + * @param topK the number of top classes to return + */ + public final void setTopK(int topK) { + this.topK = topK; } /** @@ -99,6 +123,16 @@ public T item(int index) { return (T) new Classification(classNames.get(index), probabilities.get(index)); } + /** + * Returns a list of the top classes. + * + * @param the type of the classification item for the task + * @return the list of classification items for the best classes in order of best to worst + */ + public List topK() { + return topK(topK); + } + /** * Returns a list of the top {@code k} best classes. * @@ -163,7 +197,7 @@ public ByteBuffer toByteBuffer() { public String toString() { StringBuilder sb = new StringBuilder(); sb.append('[').append(System.lineSeparator()); - for (Classification item : topK(5)) { + for (Classification item : topK(topK)) { sb.append('\t').append(item).append(System.lineSeparator()); } sb.append(']'); @@ -227,7 +261,7 @@ public static final class ClassificationsSerializer implements JsonSerializer list = src.topK(5); + List list = src.topK(); return ctx.serialize(list); } } diff --git a/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java b/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java index ffb5b3c410a..8373308b3ae 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java +++ b/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java @@ -45,6 +45,7 @@ public DetectedObjects( List classNames, List probabilities, List boundingBoxes) { super(classNames, probabilities); this.boundingBoxes = boundingBoxes; + setTopK(Integer.MAX_VALUE); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslator.java index 713c809b817..feb98f5a3a5 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslator.java @@ -27,6 +27,7 @@ public class ImageClassificationTranslator extends BaseImageTranslator classes; @@ -39,6 +40,7 @@ public ImageClassificationTranslator(Builder builder) { super(builder); this.synsetLoader = builder.synsetLoader; this.applySoftmax = builder.applySoftmax; + this.topK = builder.topK; } /** {@inheritDoc} */ @@ -56,7 +58,7 @@ public Classifications processOutput(TranslatorContext ctx, NDList list) { if (applySoftmax) { probabilitiesNd = probabilitiesNd.softmax(0); } - return new Classifications(classes, probabilitiesNd); + return new Classifications(classes, probabilitiesNd, topK); } /** @@ -85,9 +87,21 @@ public static Builder builder(Map arguments) { public static class Builder extends ClassificationBuilder { private boolean applySoftmax; + private int topK = 5; Builder() {} + /** + * Set the topK number of classes to be displayed. + * + * @param topK the number of top classes to return + * @return the builder + */ + public Builder optTopK(int topK) { + this.topK = topK; + return this; + } + /** * Sets whether to apply softmax when processing output. Some models already include softmax * in the last layer, so don't apply softmax when processing model output. @@ -111,6 +125,7 @@ protected Builder self() { protected void configPostProcess(Map arguments) { super.configPostProcess(arguments); applySoftmax = ArgumentsUtil.booleanValue(arguments, "applySoftmax"); + topK = ArgumentsUtil.intValue(arguments, "topK", 5); } /**