-
Notifications
You must be signed in to change notification settings - Fork 685
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Clean up and add docs for BigGAN demo (#1038)
* [pytorch] Add BigGAN demo in examples * [api] Start support for GAN * Fix biggan model issues Change-Id: If6bfbb4989a6054934bcf6dff8ace09ec729d08a * Fix minor issues and add comments * Fix build issue * Simplify the BigGAN demo * Add test for BigGAN demo * Add docs for BigGAN demo Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
- Loading branch information
Showing
9 changed files
with
455 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
api/src/main/java/ai/djl/modality/cv/translator/BigGANTranslator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
/* | ||
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.modality.cv.translator; | ||
|
||
import ai.djl.modality.cv.Image; | ||
import ai.djl.modality.cv.ImageFactory; | ||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.ndarray.NDList; | ||
import ai.djl.ndarray.NDManager; | ||
import ai.djl.ndarray.types.DataType; | ||
import ai.djl.ndarray.types.Shape; | ||
import ai.djl.translate.Batchifier; | ||
import ai.djl.translate.Translator; | ||
import ai.djl.translate.TranslatorContext; | ||
|
||
/** Built-in {@code Translator} that provides preprocessing and postprocessing for BigGAN. */ | ||
public final class BigGANTranslator implements Translator<int[], Image[]> { | ||
|
||
private static final int NUMBER_OF_CATEGORIES = 1000; | ||
private static final int SEED_COLUMN_SIZE = 128; | ||
private float truncation; | ||
|
||
/** | ||
* Construct a translator for BigGAN. | ||
* | ||
* @param truncation value used to scale the normal seed for BigGAN | ||
*/ | ||
public BigGANTranslator(float truncation) { | ||
this.truncation = truncation; | ||
} | ||
|
||
@Override | ||
public Image[] processOutput(TranslatorContext ctx, NDList list) { | ||
NDArray output = list.get(0).addi(1).muli(128).clip(0, 255).toType(DataType.UINT8, false); | ||
|
||
int sampleSize = (int) output.getShape().get(0); | ||
Image[] images = new Image[sampleSize]; | ||
|
||
for (int i = 0; i < sampleSize; ++i) { | ||
images[i] = ImageFactory.getInstance().fromNDArray(output.get(i)); | ||
} | ||
|
||
return images; | ||
} | ||
|
||
@Override | ||
public NDList processInput(TranslatorContext ctx, int[] input) throws Exception { | ||
NDManager manager = ctx.getNDManager(); | ||
|
||
NDArray classes = manager.create(input).oneHot(NUMBER_OF_CATEGORIES); | ||
NDArray seed = | ||
manager.truncatedNormal(new Shape(input.length, SEED_COLUMN_SIZE)).muli(truncation); | ||
|
||
return new NDList(seed, classes, manager.create(truncation)); | ||
} | ||
|
||
@Override | ||
public Batchifier getBatchifier() { | ||
return null; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Image Generation with BigGAN from the Model Zoo | ||
|
||
[Generative Adversarial Networks](https://en.wikipedia.org/wiki/Generative_adversarial_network) (GANs) are a branch of deep learning used for generative modeling. | ||
They consist of 2 neural networks that act as adversaries, the Generator and the Discriminator. The Generator is assigned to generated fake images that look real, and the Discriminator needs to correctly identify the fake ones. | ||
|
||
In this example, you will learn how to use a [BigGAN](https://deepmind.com/research/open-source/biggan) generator to create images, using the generator directly from the [ModelZoo](../../docs/model-zoo.md). | ||
|
||
The source code for this example can be found at [BigGAN.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/BigGAN.java). | ||
|
||
## Setup guide | ||
|
||
To configure your development environment, follow [setup](../../docs/development/setup.md). | ||
|
||
## Run Generation | ||
|
||
### Introduction | ||
|
||
BigGAN is trained on a subset of the [ImageNet dataset](https://en.wikipedia.org/wiki/ImageNet) with 1000 categories. | ||
You can see the labels in [this file](https://github.com/deepjavalibrary/djl/blob/master/model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/zoo/synset_imagenet.txt). | ||
The training was done such that the input to the model uses the ID of the category, between 0 and 999. For us, the ID is the line number in the file, starting at 0. | ||
|
||
Thus, the input to the translator will be an array of category IDs: | ||
|
||
```java | ||
int[] input = {100, 207, 971, 970, 933}; | ||
``` | ||
|
||
### Build the project and run | ||
Use the following commands to run the project: | ||
|
||
``` | ||
cd examples | ||
./gradlew run -Dmain=ai.djl.examples.inference.BigGAN -Dai.djl.default_engine=PyTorch | ||
``` | ||
|
||
### Output | ||
|
||
Your output will vary since the generation depends on a random seed. Here are a few examples: | ||
|
||
Black Swan | Golden Retriever | Bubble | Alp | Cheeseburger | ||
:-------------------------:|:-------------------------: |:-------------------------: | :----------------------: | :----------------------: | ||
 | |  |  |  |
81 changes: 81 additions & 0 deletions
81
examples/src/main/java/ai/djl/examples/inference/BigGAN.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
/* | ||
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.examples.inference; | ||
|
||
import ai.djl.Application; | ||
import ai.djl.ModelException; | ||
import ai.djl.engine.Engine; | ||
import ai.djl.inference.Predictor; | ||
import ai.djl.modality.cv.Image; | ||
import ai.djl.repository.zoo.Criteria; | ||
import ai.djl.repository.zoo.ZooModel; | ||
import ai.djl.training.util.ProgressBar; | ||
import ai.djl.translate.TranslateException; | ||
import java.io.IOException; | ||
import java.nio.file.Files; | ||
import java.nio.file.Path; | ||
import java.nio.file.Paths; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
/** An example of generation using BigGAN. */ | ||
public final class BigGAN { | ||
|
||
private static final Logger logger = LoggerFactory.getLogger(BigGAN.class); | ||
|
||
private BigGAN() {} | ||
|
||
public static void main(String[] args) throws ModelException, TranslateException, IOException { | ||
Image[] generatedImages = BigGAN.generate(); | ||
|
||
if (generatedImages == null) { | ||
logger.info("This example only works for PyTorch Engine"); | ||
} else { | ||
logger.info("Using PyTorch Engine. {} images generated.", generatedImages.length); | ||
saveImages(generatedImages); | ||
} | ||
} | ||
|
||
private static void saveImages(Image[] generatedImages) throws IOException { | ||
Path outputPath = Paths.get("build/output/gan/"); | ||
Files.createDirectories(outputPath); | ||
|
||
for (int i = 0; i < generatedImages.length; ++i) { | ||
Path imagePath = outputPath.resolve("image" + i + ".png"); | ||
generatedImages[i].save(Files.newOutputStream(imagePath), "png"); | ||
} | ||
logger.info("Generated images have been saved in: {}", outputPath); | ||
} | ||
|
||
public static Image[] generate() throws IOException, ModelException, TranslateException { | ||
if (!"PyTorch".equals(Engine.getInstance().getEngineName())) { | ||
return null; | ||
} | ||
|
||
Criteria<int[], Image[]> criteria = | ||
Criteria.builder() | ||
.optApplication(Application.CV.GAN) | ||
.setTypes(int[].class, Image[].class) | ||
.optFilter("size", "256") | ||
.optArgument("truncation", 0.4f) | ||
.optProgress(new ProgressBar()) | ||
.build(); | ||
|
||
int[] input = {100, 207, 971, 970, 933}; | ||
|
||
try (ZooModel<int[], Image[]> model = criteria.loadModel(); | ||
Predictor<int[], Image[]> generator = model.newPredictor()) { | ||
return generator.predict(input); | ||
} | ||
} | ||
} |
39 changes: 39 additions & 0 deletions
39
examples/src/test/java/ai/djl/examples/inference/BigGANTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
/* | ||
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.examples.inference; | ||
|
||
import ai.djl.ModelException; | ||
import ai.djl.modality.cv.Image; | ||
import ai.djl.translate.TranslateException; | ||
import java.io.IOException; | ||
import org.testng.Assert; | ||
import org.testng.SkipException; | ||
import org.testng.annotations.Test; | ||
|
||
public class BigGANTest { | ||
|
||
@Test | ||
public void testBigGAN() throws ModelException, TranslateException, IOException { | ||
Image[] generatedImages = BigGAN.generate(); | ||
|
||
if (generatedImages == null) { | ||
throw new SkipException("Only works for PyTorch engine."); | ||
} | ||
|
||
Assert.assertEquals(generatedImages.length, 5); | ||
for (Image img : generatedImages) { | ||
Assert.assertEquals(img.getWidth(), 256); | ||
Assert.assertEquals(img.getHeight(), 256); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
117 changes: 117 additions & 0 deletions
117
pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/cv/gan/BigGANModelLoader.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
/* | ||
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.pytorch.zoo.cv.gan; | ||
|
||
import ai.djl.Application; | ||
import ai.djl.Device; | ||
import ai.djl.MalformedModelException; | ||
import ai.djl.Model; | ||
import ai.djl.modality.cv.Image; | ||
import ai.djl.modality.cv.translator.BigGANTranslator; | ||
import ai.djl.pytorch.zoo.PtModelZoo; | ||
import ai.djl.repository.MRL; | ||
import ai.djl.repository.Repository; | ||
import ai.djl.repository.zoo.BaseModelLoader; | ||
import ai.djl.repository.zoo.Criteria; | ||
import ai.djl.repository.zoo.ModelNotFoundException; | ||
import ai.djl.repository.zoo.ZooModel; | ||
import ai.djl.translate.Translator; | ||
import ai.djl.translate.TranslatorFactory; | ||
import ai.djl.util.Pair; | ||
import ai.djl.util.Progress; | ||
import java.io.IOException; | ||
import java.util.Map; | ||
|
||
/** Model loader for BigGAN. */ | ||
public class BigGANModelLoader extends BaseModelLoader { | ||
|
||
private static final Application APPLICATION = Application.CV.GAN; | ||
private static final String GROUP_ID = PtModelZoo.GROUP_ID; | ||
private static final String ARTIFACT_ID = "biggan-deep"; | ||
private static final String VERSION = "0.0.1"; | ||
|
||
/** | ||
* Creates the Model loader from the given repository. | ||
* | ||
* @param repository the repository to load the model from | ||
*/ | ||
public BigGANModelLoader(Repository repository) { | ||
super(repository, MRL.model(APPLICATION, GROUP_ID, ARTIFACT_ID), VERSION, new PtModelZoo()); | ||
FactoryImpl factory = new FactoryImpl(); | ||
factories.put(new Pair<>(int[].class, Image[].class), factory); | ||
} | ||
|
||
/** | ||
* Loads the model. | ||
* | ||
* @return the loaded model | ||
* @throws IOException for various exceptions loading data from the repository | ||
* @throws ModelNotFoundException if no model with the specified criteria is found | ||
* @throws MalformedModelException if the model data is malformed | ||
*/ | ||
public ZooModel<int[], Image[]> loadModel() | ||
throws MalformedModelException, ModelNotFoundException, IOException { | ||
return loadModel(null, null, null); | ||
} | ||
|
||
/** | ||
* Loads the model. | ||
* | ||
* @param progress the progress tracker to update while loading the model | ||
* @return the loaded model | ||
* @throws IOException for various exceptions loading data from the repository | ||
* @throws ModelNotFoundException if no model with the specified criteria is found | ||
* @throws MalformedModelException if the model data is malformed | ||
*/ | ||
public ZooModel<int[], Image[]> loadModel(Progress progress) | ||
throws MalformedModelException, ModelNotFoundException, IOException { | ||
return loadModel(null, null, progress); | ||
} | ||
|
||
/** | ||
* Loads the model with the given search filters. | ||
* | ||
* @param filters the search filters to match against the loaded model | ||
* @param device the device the loaded model should use | ||
* @param progress the progress tracker to update while loading the model | ||
* @return the loaded model | ||
* @throws IOException for various exceptions loading data from the repository | ||
* @throws ModelNotFoundException if no model with the specified criteria is found | ||
* @throws MalformedModelException if the model data is malformed | ||
*/ | ||
public ZooModel<int[], Image[]> loadModel( | ||
Map<String, String> filters, Device device, Progress progress) | ||
throws IOException, ModelNotFoundException, MalformedModelException { | ||
Criteria<int[], Image[]> criteria = | ||
Criteria.builder() | ||
.setTypes(int[].class, Image[].class) | ||
.optModelZoo(modelZoo) | ||
.optGroupId(resource.getMrl().getGroupId()) | ||
.optArtifactId(resource.getMrl().getArtifactId()) | ||
.optFilters(filters) | ||
.optDevice(device) | ||
.optProgress(progress) | ||
.build(); | ||
return loadModel(criteria); | ||
} | ||
|
||
private static final class FactoryImpl implements TranslatorFactory<int[], Image[]> { | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public Translator<int[], Image[]> newInstance(Model model, Map<String, ?> arguments) { | ||
Float truncation = (Float) arguments.get("truncation"); | ||
return new BigGANTranslator(truncation == null ? 0.5f : truncation); | ||
} | ||
} | ||
} |
Oops, something went wrong.