From 6e052e9a818109d0b9db23412bdcf4a8c98d782e Mon Sep 17 00:00:00 2001 From: Aziz Zayed Date: Sun, 20 Jun 2021 16:03:15 -0700 Subject: [PATCH] Add docs for BigGAN demo --- examples/docs/biggan.md | 42 +++++++++++++++++++ .../ai/djl/examples/inference/BigGAN.java | 4 +- .../ai/djl/examples/inference/BigGANTest.java | 13 +++--- 3 files changed, 49 insertions(+), 10 deletions(-) create mode 100644 examples/docs/biggan.md diff --git a/examples/docs/biggan.md b/examples/docs/biggan.md new file mode 100644 index 000000000000..604434587b7e --- /dev/null +++ b/examples/docs/biggan.md @@ -0,0 +1,42 @@ +# Image Generation with BigGAN from the Model Zoo + +[Generative Adversarial Networks](https://en.wikipedia.org/wiki/Generative_adversarial_network) (GANs) are a branch of deep learning used for generative modeling. +They consist of 2 neural networks that act as adversaries, the Generator and the Discriminator. The Generator is assigned to generated fake images that look real, and the Discriminator needs to correctly identify the fake ones. + +In this example, you will learn how to use a [BigGAN](https://deepmind.com/research/open-source/biggan) generator to create images, using the generator directly from the [ModelZoo](../../docs/model-zoo.md). + +The source code for this example can be found at [BigGAN.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/BigGAN.java). + +## Setup guide + +To configure your development environment, follow [setup](../../docs/development/setup.md). + +## Run Generation + +### Introduction + +BigGAN is trained on a subset of the [ImageNet dataset](https://en.wikipedia.org/wiki/ImageNet) with 1000 categories. +You can see the labels in [this file](https://github.com/deepjavalibrary/djl/blob/master/model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/zoo/synset_imagenet.txt). +The training was done such that the input to the model uses the ID of the category, between 0 and 999. For us, the ID is the line number in the file, starting at 0. + +Thus, the input to the translator will be an array of category IDs: + +```java +int[] input = {100, 207, 971, 970, 933}; +``` + +### Build the project and run +Use the following commands to run the project: + +``` +cd examples +./gradlew run -Dmain=ai.djl.examples.inference.BigGAN -Dai.djl.default_engine=PyTorch +``` + +### Output + +Your output will vary since the generation depends on a random seed. Here are a few examples: + +Black Swan | Golden Retriever | Bubble | Alp | Cheeseburger +:-------------------------:|:-------------------------: |:-------------------------: | :----------------------: | :----------------------: +![](https://djl-ai.s3.amazonaws.com/resources/images/biggan/black-swan.png) | ![](https://djl-ai.s3.amazonaws.com/resources/images/biggan/golden-retriever.png)| ![](https://djl-ai.s3.amazonaws.com/resources/images/biggan/bubble.png) | ![](https://djl-ai.s3.amazonaws.com/resources/images/biggan/hills.png) | ![](https://djl-ai.s3.amazonaws.com/resources/images/biggan/cheeseburger.png) diff --git a/examples/src/main/java/ai/djl/examples/inference/BigGAN.java b/examples/src/main/java/ai/djl/examples/inference/BigGAN.java index 36e025b6246f..eeca209bc417 100644 --- a/examples/src/main/java/ai/djl/examples/inference/BigGAN.java +++ b/examples/src/main/java/ai/djl/examples/inference/BigGAN.java @@ -66,12 +66,12 @@ public static Image[] generate() throws IOException, ModelException, TranslateEx Criteria.builder() .optApplication(Application.CV.GAN) .setTypes(int[].class, Image[].class) - .optEngine("PyTorch") + .optFilter("size", "256") .optArgument("truncation", 0.4f) .optProgress(new ProgressBar()) .build(); - int[] input = {0, 100, 200, 300, 400}; + int[] input = {100, 207, 971, 970, 933}; try (ZooModel model = criteria.loadModel(); Predictor generator = model.newPredictor()) { diff --git a/examples/src/test/java/ai/djl/examples/inference/BigGANTest.java b/examples/src/test/java/ai/djl/examples/inference/BigGANTest.java index 19af0bd8e02a..47b83938c92c 100644 --- a/examples/src/test/java/ai/djl/examples/inference/BigGANTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/BigGANTest.java @@ -13,10 +13,8 @@ package ai.djl.examples.inference; import ai.djl.ModelException; -import ai.djl.engine.Engine; import ai.djl.modality.cv.Image; import ai.djl.translate.TranslateException; -import java.io.File; import java.io.IOException; import org.testng.Assert; import org.testng.SkipException; @@ -26,17 +24,16 @@ public class BigGANTest { @Test public void testBigGAN() throws ModelException, TranslateException, IOException { - if (!"PyTorch".equals(Engine.getInstance().getEngineName())) { + Image[] generatedImages = BigGAN.generate(); + + if (generatedImages == null) { throw new SkipException("Only works for PyTorch engine."); } - Image[] generatedImages = BigGAN.generate(); Assert.assertEquals(generatedImages.length, 5); - Assert.assertEquals(new File("build/output/gan/").list().length, 5); - for (Image img : generatedImages) { - Assert.assertEquals(img.getWidth(), 128); - Assert.assertEquals(img.getHeight(), 128); + Assert.assertEquals(img.getWidth(), 256); + Assert.assertEquals(img.getHeight(), 256); } } }