Skip to content

Commit

Permalink
Add unitest to check probability less than 1
Browse files Browse the repository at this point in the history
Change-Id: Ib6a2113f4ed669be6b5ffc4b0fdb5463b2289734
  • Loading branch information
frankfliu committed Apr 16, 2022
1 parent fdb26ba commit 2d82a8d
Showing 1 changed file with 9 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public class TrainMnistTest {
public void testTrainMnist() throws ModelException, TranslateException, IOException {
TestRequirements.engine("MXNet", "PyTorch");

double expectedProb;
if (Boolean.getBoolean("nightly")) {
String[] args = new String[] {"-g", "1"};

Expand All @@ -39,14 +40,18 @@ public void testTrainMnist() throws ModelException, TranslateException, IOExcept
Assert.assertTrue(accuracy > 0.9f, "Accuracy: " + accuracy);
Assert.assertTrue(loss < 0.35f, "Loss: " + loss);

Classifications classifications = ImageClassification.predict();
Classifications.Classification best = classifications.best();
Assert.assertEquals(best.getClassName(), "0");
Assert.assertTrue(Double.compare(best.getProbability(), 0.9) > 0);
expectedProb = 0.9;
} else {
String[] args = new String[] {"-g", "1", "-m", "2"};

TrainMnist.runExample(args);
expectedProb = 0;
}

Classifications classifications = ImageClassification.predict();
Classifications.Classification best = classifications.best();
Assert.assertEquals(best.getClassName(), "0");
double probability = best.getProbability();
Assert.assertTrue(probability > expectedProb && probability <= 1);
}
}

0 comments on commit 2d82a8d

Please sign in to comment.