
Commit

Clean up and add docs for BigGAN demo (#1038)
* [pytorch] Add BigGAN demo in examples

* [api] Start support for GAN

* Fix biggan model issues

Change-Id: If6bfbb4989a6054934bcf6dff8ace09ec729d08a

* Fix minor issues and add comments

* Fix build issue

* Simplify the BigGAN demo

* Add test for BigGAN demo

* Add docs for BigGAN demo

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
AzizZayed and frankfliu authored Jun 23, 2021
1 parent 79a6720 commit 6a81d9d
Showing 9 changed files with 455 additions and 0 deletions.
3 changes: 3 additions & 0 deletions api/src/main/java/ai/djl/Application.java
@@ -181,6 +181,9 @@ public interface CV {
         * String}&gt;.
         */
        Application WORD_RECOGNITION = new Application("cv/word_recognition");

        /** An application that accepts a seed and returns generated images. */
        Application GAN = new Application("cv/gan");
    }

    /** The common set of applications for natural language processing (text data). */
@@ -0,0 +1,71 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

/** Built-in {@code Translator} that provides preprocessing and postprocessing for BigGAN. */
public final class BigGANTranslator implements Translator<int[], Image[]> {

    private static final int NUMBER_OF_CATEGORIES = 1000;
    private static final int SEED_COLUMN_SIZE = 128;
    private float truncation;

    /**
     * Construct a translator for BigGAN.
     *
     * @param truncation value used to scale the normal seed for BigGAN
     */
    public BigGANTranslator(float truncation) {
        this.truncation = truncation;
    }

    @Override
    public Image[] processOutput(TranslatorContext ctx, NDList list) {
        NDArray output = list.get(0).addi(1).muli(128).clip(0, 255).toType(DataType.UINT8, false);

        int sampleSize = (int) output.getShape().get(0);
        Image[] images = new Image[sampleSize];

        for (int i = 0; i < sampleSize; ++i) {
            images[i] = ImageFactory.getInstance().fromNDArray(output.get(i));
        }

        return images;
    }

    @Override
    public NDList processInput(TranslatorContext ctx, int[] input) throws Exception {
        NDManager manager = ctx.getNDManager();

        NDArray classes = manager.create(input).oneHot(NUMBER_OF_CATEGORIES);
        NDArray seed =
                manager.truncatedNormal(new Shape(input.length, SEED_COLUMN_SIZE)).muli(truncation);

        return new NDList(seed, classes, manager.create(truncation));
    }

    @Override
    public Batchifier getBatchifier() {
        return null;
    }
}
42 changes: 42 additions & 0 deletions examples/docs/biggan.md
@@ -0,0 +1,42 @@
# Image Generation with BigGAN from the Model Zoo

[Generative Adversarial Networks](https://en.wikipedia.org/wiki/Generative_adversarial_network) (GANs) are a branch of deep learning used for generative modeling.
They consist of two neural networks that act as adversaries: the Generator and the Discriminator. The Generator is tasked with generating fake images that look real, and the Discriminator must correctly identify the fake ones.

In this example, you will learn how to create images with a [BigGAN](https://deepmind.com/research/open-source/biggan) generator loaded directly from the [ModelZoo](../../docs/model-zoo.md).

The source code for this example can be found at [BigGAN.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/BigGAN.java).

## Setup guide

To configure your development environment, follow [setup](../../docs/development/setup.md).

## Run Generation

### Introduction

BigGAN is trained on a subset of the [ImageNet dataset](https://en.wikipedia.org/wiki/ImageNet) with 1000 categories.
You can see the labels in [this file](https://github.com/deepjavalibrary/djl/blob/master/model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/zoo/synset_imagenet.txt).
The model was trained to take a category ID between 0 and 999 as input. Here, the ID of a category is its zero-based line number in that file.

Thus, the input to the translator will be an array of category IDs:

```java
int[] input = {100, 207, 971, 970, 933};
```
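In this commit, `BigGANTranslator.processInput` turns these IDs into one-hot rows with `oneHot(NUMBER_OF_CATEGORIES)`. The following is a minimal plain-Java sketch of that encoding, without any DJL dependency. The class name `OneHotSketch` and the range check are illustrative assumptions and are not part of the commit.

```java
/** Plain-Java illustration of one-hot encoding category IDs (hypothetical helper, not DJL code). */
class OneHotSketch {

    // 1000 ImageNet categories, mirroring NUMBER_OF_CATEGORIES in BigGANTranslator.
    static final int NUMBER_OF_CATEGORIES = 1000;

    /** Returns one row per ID; each row is all zeros except for a 1 at the ID's index. */
    static float[][] oneHot(int[] ids) {
        float[][] rows = new float[ids.length][NUMBER_OF_CATEGORIES];
        for (int i = 0; i < ids.length; i++) {
            // Assumption: reject IDs outside 0..999 instead of failing later inside the model.
            if (ids[i] < 0 || ids[i] >= NUMBER_OF_CATEGORIES) {
                throw new IllegalArgumentException("Category ID out of range: " + ids[i]);
            }
            rows[i][ids[i]] = 1f;
        }
        return rows;
    }

    public static void main(String[] args) {
        int[] input = {100, 207, 971, 970, 933};
        float[][] encoded = oneHot(input);
        System.out.println(encoded.length + " x " + encoded[0].length);
        System.out.println("row 0, column 100 = " + encoded[0][100]);
    }
}
```

Each input ID yields one generated image, so an array of five IDs produces a batch of five one-hot rows.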

### Build the project and run

Use the following commands to run the project:

```
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.BigGAN -Dai.djl.default_engine=PyTorch
```

### Output

Your output will vary since the generation depends on a random seed. Here are a few examples:

Black Swan | Golden Retriever | Bubble | Alp | Cheeseburger
:-------------------------:|:-------------------------: |:-------------------------: | :----------------------: | :----------------------:
![]( https://resources.djl.ai/images/biggan/black-swan.png) | ![]( https://resources.djl.ai/images/biggan/golden-retriever.png)| ![]( https://resources.djl.ai/images/biggan/bubble.png) | ![]( https://resources.djl.ai/images/biggan/hills.png) | ![]( https://resources.djl.ai/images/biggan/cheeseburger.png)
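The randomness comes from the seed that `BigGANTranslator` draws with `manager.truncatedNormal(...)` and scales by the `truncation` argument. The sketch below imitates that step in plain Java, under two assumptions that the commit does not confirm: that the truncated normal re-draws any sample more than two standard deviations from the mean, and that a fixed `java.util.Random` seed is acceptable for illustration (the demo itself does not set one). The class name `TruncatedSeedSketch` is hypothetical.

```java
import java.util.Random;

/** Plain-Java illustration of a truncated-normal seed scaled by a truncation factor. */
class TruncatedSeedSketch {

    // Width of the seed row per image, mirroring SEED_COLUMN_SIZE in BigGANTranslator.
    static final int SEED_COLUMN_SIZE = 128;

    /** Draws a standard-normal value, re-drawing until it lies within [-2, 2] (assumed cutoff). */
    static float truncatedNormal(Random random) {
        double value;
        do {
            value = random.nextGaussian();
        } while (Math.abs(value) > 2.0);
        return (float) value;
    }

    /** Builds a samples x 128 seed matrix with every entry scaled by the truncation factor. */
    static float[][] seed(int samples, float truncation, Random random) {
        float[][] seed = new float[samples][SEED_COLUMN_SIZE];
        for (float[] row : seed) {
            for (int j = 0; j < row.length; j++) {
                row[j] = truncatedNormal(random) * truncation;
            }
        }
        return seed;
    }

    public static void main(String[] args) {
        // A fixed Random seed makes the sketch repeatable; the demo does not do this.
        float[][] seed = seed(5, 0.4f, new Random(42));
        System.out.println(seed.length + " x " + seed[0].length);
    }
}
```

Under these assumptions, a smaller truncation value keeps the seed closer to zero, which in BigGAN generally trades variety for fidelity.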
81 changes: 81 additions & 0 deletions examples/src/main/java/ai/djl/examples/inference/BigGAN.java
@@ -0,0 +1,81 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** An example of generation using BigGAN. */
public final class BigGAN {

    private static final Logger logger = LoggerFactory.getLogger(BigGAN.class);

    private BigGAN() {}

    public static void main(String[] args) throws ModelException, TranslateException, IOException {
        Image[] generatedImages = BigGAN.generate();

        if (generatedImages == null) {
            logger.info("This example only works for PyTorch Engine");
        } else {
            logger.info("Using PyTorch Engine. {} images generated.", generatedImages.length);
            saveImages(generatedImages);
        }
    }

    private static void saveImages(Image[] generatedImages) throws IOException {
        Path outputPath = Paths.get("build/output/gan/");
        Files.createDirectories(outputPath);

        for (int i = 0; i < generatedImages.length; ++i) {
            Path imagePath = outputPath.resolve("image" + i + ".png");
            generatedImages[i].save(Files.newOutputStream(imagePath), "png");
        }
        logger.info("Generated images have been saved in: {}", outputPath);
    }

    public static Image[] generate() throws IOException, ModelException, TranslateException {
        if (!"PyTorch".equals(Engine.getInstance().getEngineName())) {
            return null;
        }

        Criteria<int[], Image[]> criteria =
                Criteria.builder()
                        .optApplication(Application.CV.GAN)
                        .setTypes(int[].class, Image[].class)
                        .optFilter("size", "256")
                        .optArgument("truncation", 0.4f)
                        .optProgress(new ProgressBar())
                        .build();

        int[] input = {100, 207, 971, 970, 933};

        try (ZooModel<int[], Image[]> model = criteria.loadModel();
                Predictor<int[], Image[]> generator = model.newPredictor()) {
            return generator.predict(input);
        }
    }
}
39 changes: 39 additions & 0 deletions examples/src/test/java/ai/djl/examples/inference/BigGANTest.java
@@ -0,0 +1,39 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference;

import ai.djl.ModelException;
import ai.djl.modality.cv.Image;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import org.testng.Assert;
import org.testng.SkipException;
import org.testng.annotations.Test;

public class BigGANTest {

    @Test
    public void testBigGAN() throws ModelException, TranslateException, IOException {
        Image[] generatedImages = BigGAN.generate();

        if (generatedImages == null) {
            throw new SkipException("Only works for PyTorch engine.");
        }

        Assert.assertEquals(generatedImages.length, 5);
        for (Image img : generatedImages) {
            Assert.assertEquals(img.getWidth(), 256);
            Assert.assertEquals(img.getHeight(), 256);
        }
    }
}
@@ -14,6 +14,7 @@

import ai.djl.modality.cv.zoo.ImageClassificationModelLoader;
import ai.djl.pytorch.engine.PtEngine;
import ai.djl.pytorch.zoo.cv.gan.BigGANModelLoader;
import ai.djl.pytorch.zoo.cv.objectdetection.PtSsdModelLoader;
import ai.djl.pytorch.zoo.nlp.qa.BertQAModelLoader;
import ai.djl.pytorch.zoo.nlp.sentimentanalysis.DistilBertSentimentAnalysisModelLoader;
@@ -43,6 +44,8 @@ public class PtModelZoo implements ModelZoo {
    public static final DistilBertSentimentAnalysisModelLoader DB_SENTIMENT_ANALYSIS =
            new DistilBertSentimentAnalysisModelLoader(REPOSITORY);

    public static final BigGANModelLoader BIG_GAN = new BigGANModelLoader(REPOSITORY);

    /** {@inheritDoc} */
    @Override
    public String getGroupId() {
@@ -0,0 +1,117 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.pytorch.zoo.cv.gan;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.translator.BigGANTranslator;
import ai.djl.pytorch.zoo.PtModelZoo;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.BaseModelLoader;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import ai.djl.util.Progress;
import java.io.IOException;
import java.util.Map;

/** Model loader for BigGAN. */
public class BigGANModelLoader extends BaseModelLoader {

    private static final Application APPLICATION = Application.CV.GAN;
    private static final String GROUP_ID = PtModelZoo.GROUP_ID;
    private static final String ARTIFACT_ID = "biggan-deep";
    private static final String VERSION = "0.0.1";

    /**
     * Creates the Model loader from the given repository.
     *
     * @param repository the repository to load the model from
     */
    public BigGANModelLoader(Repository repository) {
        super(repository, MRL.model(APPLICATION, GROUP_ID, ARTIFACT_ID), VERSION, new PtModelZoo());
        FactoryImpl factory = new FactoryImpl();
        factories.put(new Pair<>(int[].class, Image[].class), factory);
    }

    /**
     * Loads the model.
     *
     * @return the loaded model
     * @throws IOException for various exceptions loading data from the repository
     * @throws ModelNotFoundException if no model with the specified criteria is found
     * @throws MalformedModelException if the model data is malformed
     */
    public ZooModel<int[], Image[]> loadModel()
            throws MalformedModelException, ModelNotFoundException, IOException {
        return loadModel(null, null, null);
    }

    /**
     * Loads the model.
     *
     * @param progress the progress tracker to update while loading the model
     * @return the loaded model
     * @throws IOException for various exceptions loading data from the repository
     * @throws ModelNotFoundException if no model with the specified criteria is found
     * @throws MalformedModelException if the model data is malformed
     */
    public ZooModel<int[], Image[]> loadModel(Progress progress)
            throws MalformedModelException, ModelNotFoundException, IOException {
        return loadModel(null, null, progress);
    }

    /**
     * Loads the model with the given search filters.
     *
     * @param filters the search filters to match against the loaded model
     * @param device the device the loaded model should use
     * @param progress the progress tracker to update while loading the model
     * @return the loaded model
     * @throws IOException for various exceptions loading data from the repository
     * @throws ModelNotFoundException if no model with the specified criteria is found
     * @throws MalformedModelException if the model data is malformed
     */
    public ZooModel<int[], Image[]> loadModel(
            Map<String, String> filters, Device device, Progress progress)
            throws IOException, ModelNotFoundException, MalformedModelException {
        Criteria<int[], Image[]> criteria =
                Criteria.builder()
                        .setTypes(int[].class, Image[].class)
                        .optModelZoo(modelZoo)
                        .optGroupId(resource.getMrl().getGroupId())
                        .optArtifactId(resource.getMrl().getArtifactId())
                        .optFilters(filters)
                        .optDevice(device)
                        .optProgress(progress)
                        .build();
        return loadModel(criteria);
    }

    private static final class FactoryImpl implements TranslatorFactory<int[], Image[]> {

        /** {@inheritDoc} */
        @Override
        public Translator<int[], Image[]> newInstance(Model model, Map<String, ?> arguments) {
            Float truncation = (Float) arguments.get("truncation");
            return new BigGANTranslator(truncation == null ? 0.5f : truncation);
        }
    }
}