Skip to content

Commit

Permalink
[examples] Fix ImageClassification invalid probability (#1575)
Browse files Browse the repository at this point in the history
* [examples] Fix ImageClassification invalid probability

* Delete unnecessary Test

* Add unitest to check probability less than 1

Change-Id: Ib6a2113f4ed669be6b5ffc4b0fdb5463b2289734

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
  • Loading branch information
dandansamax and frankfliu authored Apr 16, 2022
1 parent 0d06682 commit 64b29e9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public static Classifications predict() throws IOException, ModelException, Tran
ImageClassificationTranslator.builder()
.addTransform(new ToTensor())
.optSynset(classes)
.optApplySoftmax(true)
.build();

try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public class TrainMnistTest {
public void testTrainMnist() throws ModelException, TranslateException, IOException {
TestRequirements.engine("MXNet", "PyTorch");

double expectedProb;
if (Boolean.getBoolean("nightly")) {
String[] args = new String[] {"-g", "1"};

Expand All @@ -39,14 +40,18 @@ public void testTrainMnist() throws ModelException, TranslateException, IOExcept
Assert.assertTrue(accuracy > 0.9f, "Accuracy: " + accuracy);
Assert.assertTrue(loss < 0.35f, "Loss: " + loss);

Classifications classifications = ImageClassification.predict();
Classifications.Classification best = classifications.best();
Assert.assertEquals(best.getClassName(), "0");
Assert.assertTrue(Double.compare(best.getProbability(), 0.9) > 0);
expectedProb = 0.9;
} else {
String[] args = new String[] {"-g", "1", "-m", "2"};

TrainMnist.runExample(args);
expectedProb = 0;
}

Classifications classifications = ImageClassification.predict();
Classifications.Classification best = classifications.best();
Assert.assertEquals(best.getClassName(), "0");
double probability = best.getProbability();
Assert.assertTrue(probability > expectedProb && probability <= 1);
}
}

0 comments on commit 64b29e9

Please sign in to comment.