From a002d0794af33849878d9df77e0abf727ae4cd1d Mon Sep 17 00:00:00 2001 From: Daniel Xu <11912921@mail.sustech.edu.cn> Date: Sat, 16 Apr 2022 15:05:16 +0800 Subject: [PATCH 1/3] [examples] Fix ImageClassification invalid probability --- .../inference/ImageClassification.java | 1 + .../inference/ImageClassificationTest.java | 37 +++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 examples/src/test/java/ai/djl/examples/inference/ImageClassificationTest.java diff --git a/examples/src/main/java/ai/djl/examples/inference/ImageClassification.java b/examples/src/main/java/ai/djl/examples/inference/ImageClassification.java index 9712ecb8260..69ae0f4422c 100644 --- a/examples/src/main/java/ai/djl/examples/inference/ImageClassification.java +++ b/examples/src/main/java/ai/djl/examples/inference/ImageClassification.java @@ -68,6 +68,7 @@ public static Classifications predict() throws IOException, ModelException, Tran ImageClassificationTranslator.builder() .addTransform(new ToTensor()) .optSynset(classes) + .optApplySoftmax(true) .build(); try (Predictor predictor = model.newPredictor(translator)) { diff --git a/examples/src/test/java/ai/djl/examples/inference/ImageClassificationTest.java b/examples/src/test/java/ai/djl/examples/inference/ImageClassificationTest.java new file mode 100644 index 00000000000..b95b6b784d0 --- /dev/null +++ b/examples/src/test/java/ai/djl/examples/inference/ImageClassificationTest.java @@ -0,0 +1,37 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.examples.inference; + +import ai.djl.ModelException; +import ai.djl.examples.training.TrainMnist; +import ai.djl.modality.Classifications; +import ai.djl.testing.TestRequirements; +import ai.djl.translate.TranslateException; +import java.io.IOException; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class ImageClassificationTest { + @Test + public void testImageClassification() throws IOException, TranslateException, ModelException { + TestRequirements.engine("MXNet", "PyTorch"); + String[] args = new String[] {"-g", "1", "-m", "2"}; + + TrainMnist.runExample(args); + Classifications classifications = ImageClassification.predict(); + for (Classifications.Classification classification : classifications.items()) { + Assert.assertTrue( + classification.getProbability() >= 0f && classification.getProbability() <= 1f); + } + } +} From fdb26ba73684d56858117f8d0a40a8c008f94434 Mon Sep 17 00:00:00 2001 From: Daniel Xu <11912921@mail.sustech.edu.cn> Date: Sun, 17 Apr 2022 00:33:55 +0800 Subject: [PATCH 2/3] Delete unnecessary Test --- .../inference/ImageClassificationTest.java | 37 ------------------- 1 file changed, 37 deletions(-) delete mode 100644 examples/src/test/java/ai/djl/examples/inference/ImageClassificationTest.java diff --git a/examples/src/test/java/ai/djl/examples/inference/ImageClassificationTest.java b/examples/src/test/java/ai/djl/examples/inference/ImageClassificationTest.java deleted file mode 100644 index b95b6b784d0..00000000000 --- a/examples/src/test/java/ai/djl/examples/inference/ImageClassificationTest.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.examples.inference; - -import ai.djl.ModelException; -import ai.djl.examples.training.TrainMnist; -import ai.djl.modality.Classifications; -import ai.djl.testing.TestRequirements; -import ai.djl.translate.TranslateException; -import java.io.IOException; -import org.testng.Assert; -import org.testng.annotations.Test; - -public class ImageClassificationTest { - @Test - public void testImageClassification() throws IOException, TranslateException, ModelException { - TestRequirements.engine("MXNet", "PyTorch"); - String[] args = new String[] {"-g", "1", "-m", "2"}; - - TrainMnist.runExample(args); - Classifications classifications = ImageClassification.predict(); - for (Classifications.Classification classification : classifications.items()) { - Assert.assertTrue( - classification.getProbability() >= 0f && classification.getProbability() <= 1f); - } - } -} From 2d82a8da2d7b11c39af5cca4ea7df3a347a63379 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sat, 16 Apr 2022 09:55:04 -0700 Subject: [PATCH 3/3] Add unitest to check probability less than 1 Change-Id: Ib6a2113f4ed669be6b5ffc4b0fdb5463b2289734 --- .../ai/djl/examples/training/TrainMnistTest.java | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/src/test/java/ai/djl/examples/training/TrainMnistTest.java b/examples/src/test/java/ai/djl/examples/training/TrainMnistTest.java index 511e780eca5..5c29059be71 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainMnistTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainMnistTest.java @@ -28,6 +28,7 @@ public class TrainMnistTest { public void testTrainMnist() throws ModelException, TranslateException, IOException { TestRequirements.engine("MXNet", "PyTorch"); + double expectedProb; if (Boolean.getBoolean("nightly")) { String[] args = new String[] {"-g", "1"}; @@ -39,14 +40,18 @@ public void testTrainMnist() throws ModelException, TranslateException, IOExcept Assert.assertTrue(accuracy > 0.9f, "Accuracy: " + accuracy); Assert.assertTrue(loss < 0.35f, "Loss: " + loss); - Classifications classifications = ImageClassification.predict(); - Classifications.Classification best = classifications.best(); - Assert.assertEquals(best.getClassName(), "0"); - Assert.assertTrue(Double.compare(best.getProbability(), 0.9) > 0); + expectedProb = 0.9; } else { String[] args = new String[] {"-g", "1", "-m", "2"}; TrainMnist.runExample(args); + expectedProb = 0; } + + Classifications classifications = ImageClassification.predict(); + Classifications.Classification best = classifications.best(); + Assert.assertEquals(best.getClassName(), "0"); + double probability = best.getProbability(); + Assert.assertTrue(probability > expectedProb && probability <= 1); } }