Add docs for BigGAN demo
AzizZayed committed Jun 21, 2021
1 parent 0cec711 commit 6e052e9
Showing 3 changed files with 49 additions and 10 deletions.
42 changes: 42 additions & 0 deletions examples/docs/biggan.md
@@ -0,0 +1,42 @@
# Image Generation with BigGAN from the Model Zoo

[Generative Adversarial Networks](https://en.wikipedia.org/wiki/Generative_adversarial_network) (GANs) are a branch of deep learning used for generative modeling.
They consist of two neural networks that act as adversaries: the Generator and the Discriminator. The Generator tries to produce fake images that look real, while the Discriminator tries to tell the fake images apart from the real ones.

In this example, you will learn how to create images with a [BigGAN](https://deepmind.com/research/open-source/biggan) generator loaded directly from the [ModelZoo](../../docs/model-zoo.md).

The source code for this example can be found at [BigGAN.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/BigGAN.java).

## Setup guide

To configure your development environment, follow [setup](../../docs/development/setup.md).

## Run Generation

### Introduction

BigGAN is trained on a subset of the [ImageNet dataset](https://en.wikipedia.org/wiki/ImageNet) with 1000 categories.
You can see the labels in [this file](https://github.com/deepjavalibrary/djl/blob/master/model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/zoo/synset_imagenet.txt).
The model was trained to take the ID of a category, between 0 and 999, as its input. Here, the ID is the zero-based line number of the category's label in that file.

Thus, the input to the translator will be an array of category IDs:

```java
int[] input = {100, 207, 971, 970, 933};
```
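
To check which label a category ID refers to, the line-number rule above can be sketched as follows. This is a minimal illustration and not part of the DJL example: the class `SynsetLookup` and its methods are hypothetical names, and it assumes you have saved a local copy of the synset file. Run without arguments, it falls back to a tiny stand-in list so the indexing is still visible.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper, not part of DJL: maps a category ID to its label line.
class SynsetLookup {

    /** Returns the line at the zero-based index categoryId. */
    static String labelFor(List<String> synsetLines, int categoryId) {
        if (categoryId < 0 || categoryId >= synsetLines.size()) {
            throw new IllegalArgumentException("Category ID out of range: " + categoryId);
        }
        return synsetLines.get(categoryId);
    }

    /** Reads all lines of a local copy of the synset file (assumed to be UTF-8). */
    static List<String> loadSynset(Path file) throws IOException {
        return Files.readAllLines(file);
    }

    public static void main(String[] args) throws IOException {
        List<String> lines;
        int[] input;
        if (args.length > 0) {
            // Path to a local copy of synset_imagenet.txt, supplied by the caller.
            lines = loadSynset(Paths.get(args[0]));
            input = new int[] {100, 207, 971, 970, 933};
        } else {
            // Stand-in data so the sketch runs without the real file.
            lines = Arrays.asList("label-0", "label-1", "label-2");
            input = new int[] {0, 2};
        }
        for (int id : input) {
            System.out.println(id + " -> " + labelFor(lines, id));
        }
    }
}
```

Passing the path of a downloaded synset file as the first argument prints the label lines for the five IDs used in this example.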

### Build the project and run
Use the following commands to run the project:

```
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.BigGAN -Dai.djl.default_engine=PyTorch
```

### Output

Your output will vary since the generation depends on a random seed. Here are a few examples:

Black Swan | Golden Retriever | Bubble | Alp | Cheeseburger
:-------------------------:|:-------------------------: |:-------------------------: | :----------------------: | :----------------------:
![](https://djl-ai.s3.amazonaws.com/resources/images/biggan/black-swan.png) | ![](https://djl-ai.s3.amazonaws.com/resources/images/biggan/golden-retriever.png)| ![](https://djl-ai.s3.amazonaws.com/resources/images/biggan/bubble.png) | ![](https://djl-ai.s3.amazonaws.com/resources/images/biggan/hills.png) | ![](https://djl-ai.s3.amazonaws.com/resources/images/biggan/cheeseburger.png)
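
The run-to-run variation comes from the random noise vector fed to the generator alongside the category ID. The example passes a `truncation` argument of `0.4f`, and the sketch below shows one common way such a value is applied in BigGAN demos: draw each noise component from a standard normal, reject draws outside [-2, 2], and scale the result by the truncation value. This is an assumption about the general technique, not a description of how the DJL translator implements it; the class `TruncatedNoise`, the vector size of 128, and the seed are all hypothetical.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical sketch of truncated noise sampling; not taken from DJL.
class TruncatedNoise {

    /**
     * Samples `size` values from a standard normal restricted to [-2, 2]
     * by rejection, then scales each value by `truncation`.
     */
    static float[] sample(int size, float truncation, long seed) {
        Random rng = new Random(seed);
        float[] z = new float[size];
        for (int i = 0; i < size; i++) {
            double v;
            do {
                v = rng.nextGaussian();
            } while (Math.abs(v) > 2.0); // redraw values outside [-2, 2]
            z[i] = (float) (v * truncation);
        }
        return z;
    }

    public static void main(String[] args) {
        // 128 components and seed 42 are illustrative choices only.
        float[] z = sample(128, 0.4f, 42L);
        System.out.println(Arrays.toString(Arrays.copyOf(z, 5)));
    }
}
```

Under this scheme a smaller truncation value keeps the noise closer to zero, which generally trades image variety for fidelity; fixing the seed makes a run repeatable.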
4 changes: 2 additions & 2 deletions examples/src/main/java/ai/djl/examples/inference/BigGAN.java
@@ -66,12 +66,12 @@ public static Image[] generate() throws IOException, ModelException, TranslateEx
Criteria.builder()
.optApplication(Application.CV.GAN)
.setTypes(int[].class, Image[].class)
.optEngine("PyTorch")
.optFilter("size", "256")
.optArgument("truncation", 0.4f)
.optProgress(new ProgressBar())
.build();

int[] input = {0, 100, 200, 300, 400};
int[] input = {100, 207, 971, 970, 933};

try (ZooModel<int[], Image[]> model = criteria.loadModel();
Predictor<int[], Image[]> generator = model.newPredictor()) {
13 changes: 5 additions & 8 deletions examples/src/test/java/ai/djl/examples/inference/BigGANTest.java
@@ -13,10 +13,8 @@
package ai.djl.examples.inference;

import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.modality.cv.Image;
import ai.djl.translate.TranslateException;
import java.io.File;
import java.io.IOException;
import org.testng.Assert;
import org.testng.SkipException;
@@ -26,17 +24,16 @@ public class BigGANTest {

@Test
public void testBigGAN() throws ModelException, TranslateException, IOException {
if (!"PyTorch".equals(Engine.getInstance().getEngineName())) {
Image[] generatedImages = BigGAN.generate();

if (generatedImages == null) {
throw new SkipException("Only works for PyTorch engine.");
}

Image[] generatedImages = BigGAN.generate();
Assert.assertEquals(generatedImages.length, 5);
Assert.assertEquals(new File("build/output/gan/").list().length, 5);

for (Image img : generatedImages) {
Assert.assertEquals(img.getWidth(), 128);
Assert.assertEquals(img.getHeight(), 128);
Assert.assertEquals(img.getWidth(), 256);
Assert.assertEquals(img.getHeight(), 256);
}
}
}
